[
  {
    "path": ".clang-format",
    "content": "---\nAccessModifierOffset: -1\nAlignAfterOpenBracket: AlwaysBreak\nAlignConsecutiveAssignments: false\nAlignConsecutiveDeclarations: false\nAlignEscapedNewlinesLeft: true\nAlignOperands:   false\nAlignTrailingComments: false\nAllowAllParametersOfDeclarationOnNextLine: false\nAllowShortBlocksOnASingleLine: false\nAllowShortCaseLabelsOnASingleLine: false\nAllowShortFunctionsOnASingleLine: Empty\nAllowShortIfStatementsOnASingleLine: false\nAllowShortLoopsOnASingleLine: false\nAlwaysBreakAfterReturnType: None\nAlwaysBreakBeforeMultilineStrings: true\nAlwaysBreakTemplateDeclarations: true\nBinPackArguments: false\nBinPackParameters: false\nBraceWrapping:\n  AfterClass:      false\n  AfterControlStatement: false\n  AfterEnum:       false\n  AfterFunction:   false\n  AfterNamespace:  false\n  AfterObjCDeclaration: false\n  AfterStruct:     false\n  AfterUnion:      false\n  BeforeCatch:     false\n  BeforeElse:      false\n  IndentBraces:    false\nBreakBeforeBinaryOperators: None\nBreakBeforeBraces: Attach\nBreakBeforeTernaryOperators: true\nBreakConstructorInitializersBeforeComma: false\nBreakAfterJavaFieldAnnotations: false\nBreakStringLiterals: false\nColumnLimit:     120\nCommentPragmas:  '^ IWYU pragma:'\nCompactNamespaces: false\nConstructorInitializerAllOnOneLineOrOnePerLine: true\nConstructorInitializerIndentWidth: 4\nContinuationIndentWidth: 4\nCpp11BracedListStyle: true\nDerivePointerAlignment: false\nDisableFormat:   false\nForEachMacros:   [ FOR_EACH_RANGE, FOR_EACH, ]\nIncludeCategories:\n  - Regex:           '^<.*\\.h(pp)?>'\n    Priority:        1\n  - Regex:           '^<.*'\n    Priority:        2\n  - Regex:           '.*'\n    Priority:        3\nIndentCaseLabels: true\nIndentWidth:     2\nIndentWrappedFunctionNames: false\nKeepEmptyLinesAtTheStartOfBlocks: false\nMacroBlockBegin: ''\nMacroBlockEnd:   ''\nMaxEmptyLinesToKeep: 1\nNamespaceIndentation: None\nObjCBlockIndentWidth: 2\nObjCSpaceAfterProperty: 
false\nObjCSpaceBeforeProtocolList: false\nPenaltyBreakBeforeFirstCallParameter: 1\nPenaltyBreakComment: 300\nPenaltyBreakFirstLessLess: 120\nPenaltyBreakString: 1000\nPenaltyExcessCharacter: 1000000\nPenaltyReturnTypeOnItsOwnLine: 2000000\nPointerAlignment: Left\nReflowComments:  true\nSortIncludes:    true\nSpaceAfterCStyleCast: false\nSpaceBeforeAssignmentOperators: true\nSpaceBeforeParens: ControlStatements\nSpaceInEmptyParentheses: false\nSpacesBeforeTrailingComments: 1\nSpacesInAngles:  false\nSpacesInContainerLiterals: true\nSpacesInCStyleCastParentheses: false\nSpacesInParentheses: false\nSpacesInSquareBrackets: false\nStandard:        Cpp11\nTabWidth:        8\nUseTab:          Never\n...\n"
  },
  {
    "path": ".coderabbit.yaml",
    "content": "# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json\n\n# This file configures CodeRabbit with the various options described in https://docs.coderabbit.ai/configure-coderabbit.\n# CodeRabbit also has a set of commands here: https://docs.coderabbit.ai/guides/commands/\n\nlanguage: \"en-US\"\nearly_access: false\ntone_instructions: \"Be terse and to the point in all statements and commentary.\"\nreviews:\n  # chill is less verbose, assertive is more verbose with more nitpick feedback\n  profile: chill\n  high_level_summary: false\n  high_level_summary_placeholder: \"@coderabbitai summary\"\n  sequence_diagrams: false\n  auto_apply_labels: false\n  suggested_reviewers: false\n  changed_files_summary: false\n  suggested_labels: false\n  abort_on_close: true\n  poem: false\n  path_instructions:\n    - path: '**/*.md'\n      instructions: Remember that documentation must be updated with the latest information.\n    - path: '**/*.rst'\n      instructions: Remember that documentation must be updated with the latest information.\n    - path: '**/*.py'\n      instructions: >-\n        Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are\n        sensible and informative in regards to their function, though permitting simple names for loop and comprehension\n        variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and\n        nouns in a semantically appropriate way. Docstrings should be present for all definitions and describe each\n        variable, return value, and raised exception in the appropriate section of the Google-style of docstrings.\n        Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest\n        any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new\n        or modified definitions will be covered by existing or new unit tests.\n\n  auto_review:\n    # Automatic Review | Automatic code review\n    enabled: true\n    # Review draft PRs/MRs.\n    drafts: false\n    # ignore PRs with these in the title, these sorts of PRs should be drafts anyway\n    ignore_title_keywords:\n    - \"WIP\"\n    - \"DO NOT MERGE\"\n\n# opt out for now until it's clear this isn't too much info and is useful\nknowledge_base:\n  opt_out: true\n\n# chat is allowed\nchat:\n  auto_reply: true\n"
  },
  {
    "path": ".deepsource.toml",
    "content": "version = 1\n\ntest_patterns = [\"tests/**\"]\n\nexclude_patterns = [\n    \"monai/_version.py\",\n    \"versioneer.py\"\n]\n\n[[analyzers]]\nname = \"python\"\nenabled = true\n\n  [analyzers.meta]\n  runtime_version = \"3.x.x\"\n\n[[analyzers]]\nname = \"test-coverage\"\nenabled = true\n\n[[analyzers]]\nname = \"docker\"\nenabled = true\n\n[[analyzers]]\nname = \"shell\"\nenabled = true\n"
  },
  {
    "path": ".dockerignore",
    "content": "# Ignore the following files/folders during docker build\n\n__pycache__/\ndocs/\n\n.vscode\n.git\n.mypy_cache\n.ruff_cache\n.pytype\n.coverage\n.coverage.*\n.coverage/\ncoverage.xml\n.readthedocs.yml\n\n!README.md\n"
  },
  {
    "path": ".gitattributes",
    "content": "monai/_version.py export-subst\n"
  },
  {
    "path": ".github/CODEOWNERS",
    "content": "/monai/ @KumoLiu @ericspod @Nic-Ma\n/docs/ @KumoLiu @ericspod @Nic-Ma\n/tests/ @KumoLiu @ericspod @Nic-Ma\n/.github/ @KumoLiu\n/monai/networks/schedulers/ @virginiafdez\n/monai/inferers/inferer.py @virginiafdez\n/monai/losses/adversarial_loss.py @virginiafdez\n/monai/losses/perceptual.py @virginiafdez\n/monai/networks/blocks/spade_norm.py @virginiafdez\n/monai/networks/nets/autoencoderkl.py @virginiafdez\n/monai/networks/nets/controlnet.py @virginiafdez\n/monai/networks/nets/diffusion_model_unet.py @virginiafdez\n/monai/networks/nets/patchgan_discriminator.py @virginiafdez\n/monai/networks/nets/spade_autoencoderkl.py @virginiafdez\n/monai/networks/nets/spade_diffusion_model_unet.py @virginiafdez\n/monai/networks/nets/spade_network.py @virginiafdez\n/monai/networks/nets/vqvae.py @virginiafdez\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is.\n\n**To Reproduce**\nSteps to reproduce the behavior:\n1. Go to '...'\n2. Install '....'\n3. Run commands '....'\n\n**Expected behavior**\nA clear and concise description of what you expected to happen.\n\n**Screenshots**\nIf applicable, add screenshots to help explain your problem.\n\n**Environment**\n\nEnsuring you use the relevant python executable, please paste the output of:\n\n```\npython -c \"import monai; monai.config.print_debug_info()\"\n```\n\n**Additional context**\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Is your feature request related to a problem? Please describe.**\nA clear and concise description of what the problem is. Ex. I'm always frustrated when [...]\n\n**Describe the solution you'd like**\nA clear and concise description of what you want to happen.\n\n**Describe alternatives you've considered**\nA clear and concise description of any alternative solutions or features you've considered.\n\n**Additional context**\nAdd any other context or screenshots about the feature request here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/question.md",
    "content": "---\nname: Question (please use the Discussion tab)\nabout: https://github.com/Project-MONAI/MONAI/discussions\ntitle: 'Please use MONAI Discussion tab for questions'\nlabels: ''\nassignees: ''\n---\n\n**Please use MONAI's Discussions tab**\nFor questions relating to MONAI usage, please do not create an issue.\n\nInstead, use [MONAI's GitHub Discussions tab](https://github.com/Project-MONAI/MONAI/discussions). This can be found next to Issues and Pull Requests along the top of our repository.\n"
  },
  {
    "path": ".github/codecov.yml",
    "content": "coverage:\n  status:\n    project:\n      default:\n        target: 70%\n        threshold: 10\n        base: parent\n        if_no_uploads: error\n        if_not_found: success\n        if_ci_failed: error\n        only_pulls: false\n        flags: null\n        paths: null\n    patch:\n      default:\n        target: auto\n        # Allows PRs without tests, overall stats count\n        threshold: 100\n        base: auto\n        if_no_uploads: error\n        if_not_found: success\n        if_ci_failed: error\n        only_pulls: false\n        flags: null\n        paths: null\n\ncomment: # enable code coverage comment on PR\n  layout: \"diff, flags, files\"\n  behavior: default\n  require_changes: false\n  require_base: false\n  require_head: true\n  hide_project_coverage: true\n\nignore:\n  - \"versioneer.py\"\n  - \"monai/_version.py\"\n"
  },
  {
    "path": ".github/dco.yml",
    "content": "allowRemediationCommits:\n  individual: true\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "# Set update schedule for GitHub Actions\n\nversion: 2\nupdates:\n\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      # Check for updates to GitHub Actions every month\n      interval: \"monthly\"\n"
  },
  {
    "path": ".github/pull_request_template.md",
    "content": "Fixes # .\n\n### Description\n\nA few sentences describing the changes proposed in this pull request.\n\n### Types of changes\n<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->\n- [x] Non-breaking change (fix or new feature that would not break existing functionality).\n- [ ] Breaking change (fix or new feature that would cause existing functionality to change).\n- [ ] New tests added to cover the changes.\n- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.\n- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests  --disttests`.\n- [ ] In-line docstrings updated.\n- [ ] Documentation updated, tested `make html` command in the `docs/` folder.\n"
  },
  {
    "path": ".github/workflows/blossom-ci.yml",
    "content": "# A workflow to trigger ci on hybrid infra (github + self hosted runner)\nname: Blossom-CI\non:\n  issue_comment:\n    types: [created]\n  workflow_dispatch:\n      inputs:\n          platform:\n            description: 'runs-on argument'\n            required: false\n          args:\n            description: 'argument'\n            required: false\n\npermissions:\n  actions: write\n  checks: write\n  contents: write\n  issues: write\n  pull-requests: write\n  repository-projects: write\n  statuses: write\n\njobs:\n  Authorization:\n    name: Authorization\n    runs-on: blossom\n    outputs:\n      args: ${{ env.args }}\n\n    # This job only runs for pull request comments\n    if: |\n      github.event.comment.body == '/build' &&\n      (\n        github.actor == 'Nic-Ma' ||\n        github.actor == 'wyli' ||\n        github.actor == 'wendell-hom' ||\n        github.actor == 'KumoLiu'\n      )\n    steps:\n      - name: Check if comment is issued by authorized person\n        run: blossom-ci\n        env:\n          OPERATION: 'AUTH'\n          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }}\n\n  Vulnerability-scan:\n    name: Vulnerability scan\n    needs: [Authorization]\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          repository: ${{ fromJson(needs.Authorization.outputs.args).repo }}\n          ref: ${{ fromJson(needs.Authorization.outputs.args).ref }}\n          lfs: 'true'\n\n      # repo specific steps\n      #- name: Setup java\n      #  uses: actions/setup-java@v1\n      #  with:\n      #    java-version: '1.8'\n\n      # add blackduck properties https://synopsys.atlassian.net/wiki/spaces/INTDOCS/pages/631308372/Methods+for+Configuring+Analysis#Using-a-configuration-file\n      #- name: Setup blackduck properties\n      #  run: |\n      #       PROJECTS=$(mvn -am dependency:tree | grep maven-dependency-plugin | awk 
'{ out=\"com.nvidia:\"$(NF-1);print out }' | grep rapids | xargs | sed -e 's/ /,/g')\n      #       echo detect.maven.build.command=\"-pl=$PROJECTS -am\" >> application.properties\n      #       echo detect.maven.included.scopes=compile >> application.properties\n      - name: Setup blackduck properties\n        run: |\n             echo detect.excluded.detector.types=PIP >> application.properties\n\n      - name: Run blossom action\n        uses: NVIDIA/blossom-action@main\n        env:\n          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          REPO_KEY_DATA: ${{ secrets.BLOSSOM_KEY }}\n        with:\n          args1: ${{ fromJson(needs.Authorization.outputs.args).args1 }}\n          args2: ${{ fromJson(needs.Authorization.outputs.args).args2 }}\n          args3: ${{ fromJson(needs.Authorization.outputs.args).args3 }}\n\n  Job-trigger:\n    name: Start ci job\n    needs: [Vulnerability-scan]\n    runs-on: blossom\n    steps:\n      - name: Start ci job\n        run: blossom-ci\n        env:\n          OPERATION: 'START-CI-JOB'\n          CI_SERVER: ${{ secrets.CI_SERVER }}\n          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n\n  Post-processing:\n    name: Post processing\n    runs-on: blossom\n    if : github.event_name == 'workflow_dispatch'\n    steps:\n      - name: Start post processing\n        run: blossom-ci\n        env:\n          OPERATION: 'POST-PROCESSING'\n          CI_SERVER: ${{ secrets.CI_SERVER }}\n          REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/chatops.yml",
    "content": "# triggering the workflows by commenting `/black` and `/integration-test`\nname: chatops\n\n# currently dispatches /black command to project-monai/monai-code-formatter\non:\n  issue_comment:\n    types: [created, edited]\njobs:\n  dispatch_command:\n    runs-on: ubuntu-latest\n    steps:\n      - name: dispatch\n        uses: peter-evans/slash-command-dispatch@v5.0.2\n        with:\n          token: ${{ secrets.PR_MAINTAIN }}\n          reaction-token: ${{ secrets.GITHUB_TOKEN }}\n          reactions: false\n          config: >\n            [\n              {\n                \"command\": \"black\",\n                \"permission\": \"none\",\n                \"issue_type\": \"pull-request\",\n                \"allow_edits\": true,\n                \"repository\": \"project-monai/monai-code-formatter\"\n              },\n              {\n                \"command\": \"integration-test\",\n                \"permission\": \"none\",\n                \"issue_type\": \"pull-request\",\n                \"allow_edits\": true\n              }\n            ]\n"
  },
  {
    "path": ".github/workflows/codeql-analysis.yml",
    "content": "# For most projects, this workflow file will not need changing; you simply need\n# to commit it to your repository.\n#\n# You may wish to alter this file to override the set of languages analyzed,\n# or to provide custom queries or build logic.\n#\n# ******** NOTE ********\n# We have attempted to detect the languages in your repository. Please check\n# the `language` matrix defined below to confirm you have the correct set of\n# supported CodeQL languages.\n#\nname: \"CodeQL\"\n\non:\n  push:\n    branches: [ dev, main ]\n  pull_request:\n    # The branches below must be a subset of the branches above\n    branches: [ dev ]\n  schedule:\n    - cron: '18 1 * * 0'\n\njobs:\n  analyze:\n    name: Analyze\n    runs-on: ubuntu-latest\n    permissions:\n      actions: read\n      contents: read\n      security-events: write\n\n    strategy:\n      fail-fast: false\n      matrix:\n        language: [ 'cpp', 'python' ]\n        # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]\n        # Learn more about CodeQL language support at https://git.io/codeql-language-support\n\n    steps:\n    - name: Checkout repository\n      uses: actions/checkout@v6\n\n    # Initializes the CodeQL tools for scanning.\n    - name: Initialize CodeQL\n      uses: github/codeql-action/init@v4\n      with:\n        languages: ${{ matrix.language }}\n        # If you wish to specify custom queries, you can do so here or in a config file.\n        # By default, queries listed here will override any specified in a config file.\n        # Prefix the list here with \"+\" to use these queries and those in the config file.\n        # queries: ./path/to/local/query, your-org/your-repo/queries@main\n\n    # Autobuild attempts to build any compiled languages  (C/C++, C#, or Java).\n    # If this step fails, then you should remove it and run the build manually (see below)\n    # - name: Autobuild\n    #   uses: github/codeql-action/autobuild@v2\n\n    # ℹ️ 
Command-line programs to run using the OS shell.\n    # 📚 https://git.io/JvXDl\n\n    # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines\n    #    and modify them (or add more) to build your code if your project\n    #    uses a compiled language\n\n    - name: Build\n      run: |\n        rm -rf /opt/hostedtoolcache/{node,go,Ruby,Java*}\n        ls -al /opt/hostedtoolcache\n        rm -rf /usr/share/dotnet/\n        python -m pip install -U pip wheel\n        python -m pip install -r requirements-dev.txt\n        BUILD_MONAI=1 ./runtests.sh --build\n\n    - name: Perform CodeQL Analysis\n      uses: github/codeql-action/analyze@v4\n"
  },
  {
    "path": ".github/workflows/conda.yml",
    "content": "# daily tests for different OS with conda\nname: cron-conda\n\non:\n  schedule:\n    - cron: \"0 3 * * *\"  # at 03:00 UTC\n  # Allows you to run this workflow manually from the Actions tab\n  workflow_dispatch:\n\nconcurrency:\n  # automatically cancel the previously triggered workflows when there's a newer version\n  group: conda-tests-${{ github.event.pull_request.number || github.ref }}\n  cancel-in-progress: true\n\njobs:\n  cron-conda:\n    if: github.repository == 'Project-MONAI/MONAI'\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [ubuntu-latest]\n        python-version: [\"3.9\", \"3.10\"]\n    runs-on: ${{ matrix.os }}\n    timeout-minutes: 46 # equal to max + 3*std over the last 600 successful runs\n    env:\n      QUICKTEST: True\n    steps:\n    - if: runner.os == 'windows'\n      name: Config pagefile (Windows only)\n      uses: al-cheb/configure-pagefile-action@v1.5\n      with:\n        minimum-size: 8GB\n        maximum-size: 16GB\n        disk-root: \"D:\"\n    - uses: actions/checkout@v6\n    - name: Clean up disk space\n      run: |\n        find /opt/hostedtoolcache/* -maxdepth 0 ! 
-name 'Python' -exec rm -rf {} \\;\n        rm -rf /usr/share/dotnet/\n    - uses: conda-incubator/setup-miniconda@v3\n      with:\n        auto-update-conda: true\n        python-version: ${{ matrix.python-version }}\n        auto-activate-base: false\n        environment-file: environment-dev.yml\n        activate-environment: monai\n    - name: Env info (CPU ${{ runner.os }})\n      shell: bash -el {0}\n      run: |\n        conda info\n        conda list\n    - if: runner.os == 'windows'\n      name: Windows only install\n      shell: bash -el {0}\n      run: |\n        conda activate monai\n        # this `cpuonly` and -c conda-forge is needed to reduce the paging file size on a github instance\n        # force to install `cpuonly==2.0.0` is to fix the same issue as:\n        # https://github.com/pytorch/vision/issues/4240\n        conda install pytorch torchvision torchaudio cpuonly==2.0.0 -c pytorch -c conda-forge\n        conda deactivate\n    - name: Test env (CPU ${{ runner.os }})\n      shell: bash -el {0}\n      env:\n        NGC_API_KEY: ${{ secrets.NGC_API_KEY }}\n        NGC_ORG: ${{ secrets.NGC_ORG }}\n        NGC_TEAM: ${{ secrets.NGC_TEAM }}\n      run: |\n        conda activate monai\n        $(pwd)/runtests.sh --build --unittests\n        conda deactivate\n"
  },
  {
    "path": ".github/workflows/cron-ngc-bundle.yml",
    "content": "# daily tests for ngc bundles\nname: cron-ngc-bundle\n\non:\n  schedule:\n    - cron: \"0 2 * * *\"  # at 02:00 UTC\n  # Allows you to run this workflow manually from the Actions tab\n  workflow_dispatch:\n\nconcurrency:\n  # automatically cancel the previously triggered workflows when there's a newer version\n  group: bundle-tests-${{ github.event.pull_request.number || github.ref }}\n  cancel-in-progress: true\n\njobs:\n  cron-load:\n    if: github.repository == 'Project-MONAI/MONAI'\n    runs-on: ubuntu-latest\n    steps:\n    - uses: actions/checkout@v6\n    - name: Set up Python 3.9\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.9'\n    - name: cache weekly timestamp\n      id: pip-cache\n      run: echo \"datew=$(date '+%Y-%V')\" >> $GITHUB_OUTPUT\n    - name: cache for pip\n      uses: actions/cache@v5\n      id: cache\n      with:\n        path: ~/.cache/pip\n        key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}\n    - name: Install dependencies\n      run: |\n        rm -rf /github/home/.cache/torch/hub/bundle/\n        python -m pip install --upgrade pip wheel\n        python -m pip install -r requirements-dev.txt\n    - name: Loading Bundles\n      run: |\n        # clean up temporary files\n        $(pwd)/runtests.sh --build --clean\n        # run tests\n        python -m tests.ngc_bundle_download\n"
  },
  {
    "path": ".github/workflows/cron.yml",
    "content": "# nightly: Jenkinsfile.monai-pytorch-versions, monai-latest-image, monai-pip, monai-latest-docker, monai-notebooks\nname: nightly-crons\n\non:\n  # schedule:\n  #   - cron: \"0 2 * * *\"  # at 02:00 UTC\n  # Allows you to run this workflow manually from the Actions tab\n  workflow_dispatch:\n\njobs:\n  cron-gpu:\n    if: github.repository == 'Project-MONAI/MONAI'\n    strategy:\n      matrix:\n        environment:\n          - \"PT230+CUDA121\"\n          - \"PT240+CUDA126\"\n          - \"PTLATEST+CUDA126\"\n        include:\n          # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes\n          - environment: PT230+CUDA121\n            pytorch: \"pytorch==2.3.0 torchvision==0.18.0 --extra-index-url https://download.pytorch.org/whl/cu121\"\n            base: \"nvcr.io/nvidia/pytorch:23.08-py3\"  # CUDA 12.1\n          - environment: PT240+CUDA126\n            pytorch: \"pytorch==2.4.0 torchvision==0.19.0 --extra-index-url https://download.pytorch.org/whl/cu121\"\n            base: \"nvcr.io/nvidia/pytorch:24.08-py3\"  # CUDA 12.6\n          - environment: PTLATEST+CUDA126\n            pytorch: \"-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121\"\n            base: \"nvcr.io/nvidia/pytorch:24.10-py3\"  # CUDA 12.6\n    container:\n      image: ${{ matrix.base }}\n      options: \"--gpus all\"\n    runs-on: [self-hosted, linux, x64, common]\n    steps:\n    - uses: actions/checkout@v6\n    - name: apt install\n      run: |\n        apt-get update\n        apt-get install -y wget\n    - name: Install the dependencies\n      run: |\n        which python\n        python -m pip install --upgrade pip wheel\n        python -m pip uninstall -y torch torchvision\n        python -m pip install ${{ matrix.pytorch }}\n        python -m pip install -r requirements-dev.txt\n        python -m pip list\n    - name: Run tests report coverage\n      env:\n          NGC_API_KEY: ${{ secrets.NGC_API_KEY }}\n          
NGC_ORG: ${{ secrets.NGC_ORG }}\n          NGC_TEAM: ${{ secrets.NGC_TEAM }}\n      run: |\n        export LAUNCH_DELAY=$[ $RANDOM % 16 * 60 ]\n        echo \"Sleep $LAUNCH_DELAY\"\n        sleep $LAUNCH_DELAY\n        nvidia-smi\n        export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)\n        echo $CUDA_VISIBLE_DEVICES\n        trap 'if pgrep python; then pkill python; fi;' ERR\n        python -c $'import torch\\na,b=torch.zeros(1,device=\"cuda:0\"),torch.zeros(1,device=\"cuda:1\");\\nwhile True:print(a,b)' > /dev/null &\n        python -c \"import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))\"\n        python -c 'import torch; print(torch.rand(5, 3, device=torch.device(\"cuda:0\")))'\n        BUILD_MONAI=1 ./runtests.sh --build --coverage --unittests --disttests  # unit tests with coverage report\n        BUILD_MONAI=1 ./runtests.sh --build --coverage --net  # integration tests with coverage report\n        coverage xml --ignore-errors\n        if pgrep python; then pkill python; fi\n      shell: bash\n    - name: Upload coverage\n      uses: codecov/codecov-action@v5\n      with:\n        fail_ci_if_error: false\n        files: ./coverage.xml\n\n  cron-pt-image:\n    if: github.repository == 'Project-MONAI/MONAI'\n    strategy:\n      matrix:\n        container: [\"pytorch:23.08\", \"pytorch:24.08\", \"pytorch:24.10\"]\n    container:\n      image: nvcr.io/nvidia/${{ matrix.container }}-py3  # testing with the latest pytorch base image\n      options: \"--gpus all\"\n    runs-on: [self-hosted, linux, x64, integration]\n    steps:\n    - uses: actions/checkout@v6\n    - name: Install APT dependencies\n      run: |\n        apt-get update\n        DEBIAN_FRONTEND=\"noninteractive\" apt-get install -y libopenslide0\n    - name: Install Python dependencies\n      run: |\n        which python\n        python -m pip install --upgrade pip wheel\n        python -m pip install -r requirements-dev.txt\n 
       python -m pip list\n    - name: Run tests report coverage\n      env:\n        NGC_API_KEY: ${{ secrets.NGC_API_KEY }}\n        NGC_ORG: ${{ secrets.NGC_ORG }}\n        NGC_TEAM: ${{ secrets.NGC_TEAM }}\n      run: |\n        export LAUNCH_DELAY=$[ $RANDOM % 16 * 60 ]\n        echo \"Sleep $LAUNCH_DELAY\"\n        sleep $LAUNCH_DELAY\n        nvidia-smi\n        export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)\n        echo $CUDA_VISIBLE_DEVICES\n        trap 'if pgrep python; then pkill python; fi;' ERR\n        python -c $'import torch\\na,b=torch.zeros(1,device=\"cuda:0\"),torch.zeros(1,device=\"cuda:1\");\\nwhile True:print(a,b)' > /dev/null &\n        python -c \"import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))\"\n        python -c 'import torch; print(torch.rand(5, 3, device=torch.device(\"cuda:0\")))'\n        BUILD_MONAI=1 ./runtests.sh --build --coverage --unittests --disttests  # unit tests with coverage report\n        BUILD_MONAI=1 ./runtests.sh --build --coverage --net  # integration tests with coverage report\n        coverage xml --ignore-errors\n        if pgrep python; then pkill python; fi\n      shell: bash\n    - name: Upload coverage\n      uses: codecov/codecov-action@v5\n      with:\n        fail_ci_if_error: false\n        files: ./coverage.xml\n\n  cron-pip:\n    # pip install monai[all] and use it to run unit tests\n    if: github.repository == 'Project-MONAI/MONAI'\n    strategy:\n      matrix:\n        container: [\"pytorch:24.10\"]\n    container:\n      image: nvcr.io/nvidia/${{ matrix.container }}-py3  # testing with the latest pytorch base image\n      options: \"--gpus all\"\n    runs-on: [self-hosted, linux, x64, integration]\n    steps:\n      - uses: actions/checkout@v6\n        with:\n          fetch-depth: 0\n      - name: Install the dependencies\n        run: |\n          which python\n          python -m pip install --upgrade pip wheel twine\n      
    python -m pip list\n      - name: Run tests report coverage\n        shell: bash\n        run: |\n          pip uninstall monai\n          pip list | grep -iv monai\n          git fetch --depth=1 origin +refs/tags/*:refs/tags/*\n          root_dir=$PWD\n          echo \"$root_dir\"\n          set -e\n\n          # build tar.gz and wheel\n          bash runtests.sh --clean  # clear any existing dev temp files\n          python -m pip uninstall -y torch torchvision\n          python setup.py check -m -s\n          python setup.py sdist bdist_wheel\n          python -m twine check dist/*\n\n          # move packages to a temp dir\n          tmp_dir=$(mktemp -d)\n          cp dist/monai* \"$tmp_dir\"\n          rm -r build dist monai.egg-info\n          cd \"$tmp_dir\"\n          ls -al\n\n          # install from tar.gz\n          name=$(ls *.tar.gz | head -n1)\n          echo $name\n          python -m pip install $name[all]\n          python -c 'import monai; monai.config.print_config()' 2>&1 | grep -iv \"unknown\"\n          python -c 'import monai; print(monai.__file__)'\n\n          # run tests\n          cp $root_dir/requirements*.txt \"$tmp_dir\"\n          cp -r $root_dir/tests \"$tmp_dir\"\n          pwd\n          ls -al\n\n          export LAUNCH_DELAY=$[ $RANDOM % 16 * 60 ]\n          echo \"Sleep $LAUNCH_DELAY\"\n          sleep $LAUNCH_DELAY\n          nvidia-smi\n          export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)\n          echo $CUDA_VISIBLE_DEVICES\n          trap 'if pgrep python; then pkill python; fi;' ERR\n          python -c $'import torch\\na,b=torch.zeros(1,device=\"cuda:0\"),torch.zeros(1,device=\"cuda:1\");\\nwhile True:print(a,b)' > /dev/null &\n          python -c \"import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))\"\n\n          python -m pip install -r requirements-dev.txt\n          PYTHONPATH=\"$tmp_dir\":$PYTHONPATH BUILD_MONAI=1 python ./tests/runner.py 
-p 'test_((?!integration).)'  # unit tests\n          if pgrep python; then pkill python; fi\n\n  cron-docker:\n    if: github.repository == 'Project-MONAI/MONAI'\n    container:\n      image: docker://projectmonai/monai:latest  # this might be slow and has the pull count limitations\n      options: \"--gpus all\"\n    runs-on: [self-hosted, linux, x64, integration]\n    steps:\n    - name: Run tests report coverage\n      # The docker image process has done the compilation.\n      # BUILD_MONAI=1 is necessary for triggering the USE_COMPILED flag.\n      env:\n        NGC_API_KEY: ${{ secrets.NGC_API_KEY }}\n        NGC_ORG: ${{ secrets.NGC_ORG }}\n        NGC_TEAM: ${{ secrets.NGC_TEAM }}\n      run: |\n        cd /opt/monai\n        nvidia-smi\n        export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)\n        echo $CUDA_VISIBLE_DEVICES\n        trap 'if pgrep python; then pkill python; fi;' ERR\n        python -c $'import torch\\na,b=torch.zeros(1,device=\"cuda:0\"),torch.zeros(1,device=\"cuda:1\");\\nwhile True:print(a,b)' > /dev/null &\n        python -c \"import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))\"\n        python -c 'import torch; print(torch.rand(5,3, device=torch.device(\"cuda:0\")))'\n        ngc --version\n        BUILD_MONAI=1 ./runtests.sh --build --coverage --pytype --unittests --disttests  # unit tests with pytype checks, coverage report\n        BUILD_MONAI=1 ./runtests.sh --build --coverage --net  # integration tests with coverage report\n        coverage xml --ignore-errors\n        if pgrep python; then pkill python; fi\n      shell: bash\n    - name: Upload coverage\n      uses: codecov/codecov-action@v5\n      with:\n        fail_ci_if_error: false\n        files: ./coverage.xml\n\n  cron-tutorial-notebooks:\n    if: github.repository == 'Project-MONAI/MONAI'\n    needs: cron-gpu  # so that monai itself is verified first\n    container:\n      image: 
nvcr.io/nvidia/pytorch:24.10-py3  # testing with the latest pytorch base image\n      options: \"--gpus all --ipc=host\"\n    runs-on: [self-hosted, linux, x64, integration]\n    steps:\n    - uses: actions/checkout@v6\n    - name: Install MONAI\n      id: monai-install\n      run: |\n        which python\n        python -m pip install --upgrade pip wheel\n        python -m pip install -r requirements-dev.txt\n        BUILD_MONAI=1 python setup.py develop  # install monai\n        nvidia-smi\n        export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)\n        echo $CUDA_VISIBLE_DEVICES\n        echo \"devices=$CUDA_VISIBLE_DEVICES\" >> $GITHUB_OUTPUT\n    - name: Checkout tutorials and install their requirements\n      run: |\n        cd /opt\n        git clone --depth 1 --branch main --single-branch https://github.com/Project-MONAI/tutorials.git  # latest commit of main branch\n        cd tutorials\n        python -m pip install -r requirements.txt\n    - name: Run tutorial notebooks\n      timeout-minutes: 150\n      run: |\n        export CUDA_VISIBLE_DEVICES=${{ steps.monai-install.outputs.devices }}\n        echo $CUDA_VISIBLE_DEVICES\n        trap 'if pgrep python; then pkill python; fi;' ERR\n        python -c $'import torch\\na,b=torch.zeros(1,device=\"cuda:0\"),torch.zeros(1,device=\"cuda:1\");\\nwhile True:print(a,b)' > /dev/null &\n        cd /opt/tutorials\n        python -c 'import monai; monai.config.print_debug_info()'\n        $(pwd)/runner.sh\n        python -c 'import monai; monai.config.print_debug_info()'\n        if pgrep python; then pkill python; fi\n      shell: bash\n"
  },
  {
    "path": ".github/workflows/docker.yml",
    "content": "# this is the docker image releasing pipeline, pushing to https://hub.docker.com/r/projectmonai/monai\nname: docker\n# versioning: compute a static version file\n# local_docker: use the version file to build docker images\n# docker_test_latest: test the latest internal docker image (has flake)\n# docker_test_dockerhub: test the latest dockerhub release (no flake)\non:\n  # dev only docker deployment and quick tests\n  push:\n    branches:\n      - dev\n  # Allows you to run this workflow manually from the Actions tab\n  # This is to trigger building/testing docker image from dev only.\n  workflow_dispatch:\n\njobs:\n  versioning_dev:\n    # compute versioning file from python setup.py\n    # upload as artifact\n    # if: github.repository == 'Project-MONAI/MONAI'\n    if: ${{ false }}  # disable docker build job  project-monai/monai#7450\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v6\n        # full history so that we can git describe\n        with:\n          ref: dev\n          fetch-depth: 0\n      - name: Set up Python 3.9\n        uses: actions/setup-python@v6\n        with:\n          python-version: '3.9'\n      - shell: bash\n        run: |\n          git describe\n          python -m pip install -U pip wheel setuptools\n          python setup.py build\n          cat build/lib/monai/_version.py\n      - name: Upload version\n        uses: actions/upload-artifact@v6\n        with:\n          name: _version.py\n          path: build/lib/monai/_version.py\n      - name: Clean up directory\n        shell: bash\n        run: |\n          ls -al\n          rm -rf {*,.[^.]*}\n\n  docker_build_dev:\n    # if: github.repository == 'Project-MONAI/MONAI'\n    if: ${{ false }}  # disable docker build job  project-monai/monai#7450\n    needs: versioning_dev\n    runs-on: ubuntu-latest\n    steps:\n    - uses: actions/checkout@v6\n      with:\n        ref: dev\n    - name: Download version\n      uses: 
actions/download-artifact@v6\n      with:\n        name: _version.py\n    - name: docker_build\n      shell: bash\n      run: |\n        find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \\;\n        docker --version\n        # get tag info for versioning\n        cat _version.py\n        mv _version.py monai/\n\n        # build \"latest\": remove flake package as it is not needed on hub.docker.com\n        sed -i '/flake/d' requirements-dev.txt\n        docker build -t projectmonai/monai:latest -f Dockerfile .\n\n        # distribute as always w/ tag \"latest\" to hub.docker.com\n        echo \"${{ secrets.DOCKER_PW }}\" | docker login -u projectmonai --password-stdin\n\n        docker push projectmonai/monai:latest\n        docker logout\n        docker image prune -f\n\n  docker_test_dockerhub:\n    # if: github.repository == 'Project-MONAI/MONAI'\n    if: ${{ false }}  # disable self-hosted job  project-monai/monai#7039\n    needs: docker_build_dev\n    container:\n      image: docker://projectmonai/monai:latest\n      options: \"--shm-size=4g --ipc=host\"\n    runs-on: [self-hosted, linux, X64, docker]\n    steps:\n    - name: Import\n      run: |\n        export OMP_NUM_THREADS=4 MKL_NUM_THREADS=4 CUDA_VISIBLE_DEVICES=  # cpu-only\n        python -c 'import monai; monai.config.print_debug_info()'\n        cd /opt/monai\n        ls -al\n        ngc --version\n        ./runtests.sh --min\n      shell: bash\n      env:\n        QUICKTEST: True\n        NGC_API_KEY: ${{ secrets.NGC_API_KEY }}\n        NGC_ORG: ${{ secrets.NGC_ORG }}\n        NGC_TEAM: ${{ secrets.NGC_TEAM }}\n"
  },
  {
    "path": ".github/workflows/integration.yml",
    "content": "# manually trigger integration with the latest pytorch\nname: integration\n\non:\n  repository_dispatch:\n    type: [integration-test-command]\n\njobs:\n  integration-auto3dseg:\n    container:\n      image: nvcr.io/nvidia/pytorch:22.04-py3  # CUDA 11.6 py38\n      options: --gpus \"device=1\" --ipc host  # shm-size 4g works fine\n    runs-on: [self-hosted, linux, x64, command]\n    steps:\n    # checkout the pull request branch\n    - uses: actions/checkout@v6\n      with:\n        token: ${{ secrets.PR_MAINTAIN }}\n        repository: ${{ github.event.client_payload.pull_request.head.repo.full_name }}\n        ref: ${{ github.event.client_payload.pull_request.head.ref }}\n    - name: cache weekly timestamp\n      id: pip-cache\n      run: echo \"datew=$(date '+%Y-%V')\" >> $GITHUB_OUTPUT\n    - name: cache for pip\n      uses: actions/cache@v5\n      id: cache\n      with:\n        path: |\n          ~/.cache/pip\n          ~/.cache/torch\n        key: docker-py3-pip-${{ steps.pip-cache.outputs.datew }}\n    - name: Install the dependencies\n      run: |\n        pwd && git log -1 && which python\n        python -m pip install --upgrade pip wheel\n        pip uninstall -y monai\n        pip uninstall -y monai\n        pip uninstall -y monai-weekly\n        pip uninstall -y monai-weekly\n        python -m pip install --upgrade torch torchvision torchaudio torchtext\n        python -m pip install -r requirements-dev.txt\n        rm -rf /github/home/.cache/torch/hub/mmars/\n    - name: Clean directory\n      run: |\n        python -m pip list\n        git config --global --add safe.directory /__w/MONAI/MONAI\n        git clean -ffdx && git reset --hard HEAD\n        nvidia-smi\n        export CUDA_VISIBLE_DEVICES=$(python -m tests.utils -c 1 | tail -n 1)\n        echo $CUDA_VISIBLE_DEVICES\n        python -c \"import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))\"\n        python -c 'import torch; 
print(torch.rand(5,3, device=torch.device(\"cuda:0\")))'\n\n    - name: Auto3dseg tag algo\n      shell: bash\n      env:\n        BUILD_MONAI: 0\n      run: |\n        pwd && git log -1 && which python\n        ./runtests.sh -b\n        export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python\n        python -m tests.test_auto3dseg_bundlegen\n        python -m tests.test_auto3dseg_ensemble\n        python -m tests.test_auto3dseg_hpo\n        python -m tests.test_integration_autorunner\n        python -m tests.test_integration_gpu_customization\n    - name: Integration tests\n      shell: bash\n      env:\n        BUILD_MONAI: 1\n        NGC_API_KEY: ${{ secrets.NGC_API_KEY }}\n        NGC_ORG: ${{ secrets.NGC_ORG }}\n        NGC_TEAM: ${{ secrets.NGC_TEAM }}\n      run: ./runtests.sh --build --net\n\n    - name: Add reaction\n      uses: peter-evans/create-or-update-comment@v5\n      with:\n        token: ${{ secrets.PR_MAINTAIN }}\n        repository: ${{ github.event.client_payload.github.payload.repository.full_name }}\n        comment-id: ${{ github.event.client_payload.github.payload.comment.id }}\n        reactions: rocket\n\n\n  integration-unit:\n    container:\n      image: nvcr.io/nvidia/pytorch:22.04-py3  # CUDA 11.6 py38\n      options: --gpus \"device=2\" --ipc host  # shm-size 4g works fine\n    runs-on: [self-hosted, linux, x64, command1]\n    steps:\n    # checkout the pull request branch\n    - uses: actions/checkout@v6\n      with:\n        token: ${{ secrets.PR_MAINTAIN }}\n        repository: ${{ github.event.client_payload.pull_request.head.repo.full_name }}\n        ref: ${{ github.event.client_payload.pull_request.head.ref }}\n    - name: cache weekly timestamp\n      id: pip-cache\n      run: echo \"datew=$(date '+%Y-%V')\" >> $GITHUB_OUTPUT\n    - name: cache for pip\n      uses: actions/cache@v5\n      id: cache\n      with:\n        path: |\n          ~/.cache/pip\n          ~/.cache/torch\n        key: docker-py3-pip-${{ 
steps.pip-cache.outputs.datew }}\n    - name: Install the dependencies\n      run: |\n        pwd && git log -1 && which python\n        python -m pip install --upgrade pip wheel\n        pip uninstall -y monai\n        pip uninstall -y monai\n        pip uninstall -y monai-weekly\n        pip uninstall -y monai-weekly\n        python -m pip install --upgrade torch torchvision torchaudio torchtext\n        python -m pip install -r requirements-dev.txt\n        rm -rf /github/home/.cache/torch/hub/mmars/\n    - name: Clean directory\n      run: |\n        python -m pip list\n        git config --global --add safe.directory /__w/MONAI/MONAI\n        git clean -ffdx && git reset --hard HEAD\n        nvidia-smi\n        export CUDA_VISIBLE_DEVICES=$(python -m tests.utils -c 1 | tail -n 1)\n        echo $CUDA_VISIBLE_DEVICES\n        python -c \"import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))\"\n        python -c 'import torch; print(torch.rand(5,3, device=torch.device(\"cuda:0\")))'\n\n    - name: Auto3dseg latest algo\n      shell: bash\n      env:\n        BUILD_MONAI: 0\n      run: |\n        pwd\n        cd ../\n        rm -rf research-contributions\n        rm -rf algorithm_templates\n        git clone --depth 1 --branch main --single-branch https://github.com/Project-MONAI/research-contributions.git\n        ls research-contributions/\n        cp -r research-contributions/auto3dseg/algorithm_templates MONAI/\n        cd research-contributions && git log -1 && cd ../MONAI\n        pwd\n        ls -ll\n        export OMP_NUM_THREADS=4\n        export MKL_NUM_THREADS=4\n        export MONAI_TESTING_ALGO_TEMPLATE=algorithm_templates\n        pwd && git log -1 && which python\n        ./runtests.sh -b\n        export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python\n        python -m tests.test_auto3dseg_ensemble\n        python -m tests.test_auto3dseg_hpo\n        python -m tests.test_integration_autorunner\n       
 python -m tests.test_integration_gpu_customization\n\n    - name: Add reaction\n      uses: peter-evans/create-or-update-comment@v5\n      with:\n        token: ${{ secrets.PR_MAINTAIN }}\n        repository: ${{ github.event.client_payload.github.payload.repository.full_name }}\n        comment-id: ${{ github.event.client_payload.github.payload.comment.id }}\n        reactions: +1\n"
  },
  {
    "path": ".github/workflows/pythonapp-gpu.yml",
    "content": "# Jenkinsfile.monai-premerge\nname: premerge-gpu\n\non:\n  # quick tests for pull requests and the releasing branches\n  push:\n    branches:\n      - main\n      - releasing/*\n  pull_request:\n    types: [opened, synchronize, closed]\n\nconcurrency:\n  # automatically cancel the previously triggered workflows when there's a newer version\n  group: build-gpu-${{ github.event.pull_request.number || github.ref }}\n  cancel-in-progress: true\n\njobs:\n  GPU-quick-py3:  # GPU with full dependencies\n    # if: ${{ github.repository == 'Project-MONAI/MONAI' && github.event.pull_request.merged != true }}\n    if: ${{ false }}  # disable self-hosted job project-monai/monai#7039\n    strategy:\n      matrix:\n        environment:\n          - \"PT230+CUDA124DOCKER\"\n          - \"PT240+CUDA125DOCKER\"\n          - \"PT250+CUDA126DOCKER\"\n        include:\n          # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes\n          - environment: PT230+CUDA124DOCKER\n            # 24.04: 2.3.0a0+6ddf5cf85e\n            pytorch: \"-h\"  # we explicitly set pytorch to -h to avoid pip install error\n            base: \"nvcr.io/nvidia/pytorch:24.04-py3\"\n          - environment: PT240+CUDA125DOCKER\n            # 24.06: 2.4.0a0+f70bd71a48\n            pytorch: \"-h\"  # we explicitly set pytorch to -h to avoid pip install error\n            base: \"nvcr.io/nvidia/pytorch:24.06-py3\"\n          - environment: PT250+CUDA126DOCKER\n            # 24.08: 2.5.0a0+872d972e41\n            pytorch: \"-h\"  # we explicitly set pytorch to -h to avoid pip install error\n            base: \"nvcr.io/nvidia/pytorch:24.08-py3\"\n    container:\n      image: ${{ matrix.base }}\n      options: --gpus all --env NVIDIA_DISABLE_REQUIRE=true  # workaround for unsatisfied condition: cuda>=11.6\n    runs-on: [self-hosted, linux, x64, common]\n    steps:\n    - uses: actions/checkout@v6\n    - name: apt install\n      if: github.event.pull_request.merged != true\n     
 run: |\n        apt-get update\n        apt-get install -y wget\n\n        if [ ${{ matrix.environment }} = \"PT230+CUDA124\" ]\n        then\n        PYVER=3.9 PYSFX=3 DISTUTILS=python3-distutils && \\\n        apt-get update && apt-get install -y --no-install-recommends \\\n          curl \\\n          pkg-config \\\n          python$PYVER \\\n          python$PYVER-dev \\\n          python$PYSFX-pip \\\n          $DISTUTILS \\\n          rsync \\\n          swig \\\n          unzip \\\n          zip \\\n          zlib1g-dev \\\n          libboost-locale-dev \\\n          libboost-program-options-dev \\\n          libboost-system-dev \\\n          libboost-thread-dev \\\n          libboost-test-dev \\\n          libgoogle-glog-dev \\\n          libjsoncpp-dev \\\n          cmake \\\n          git && \\\n        rm -rf /var/lib/apt/lists/* && \\\n        export PYTHONIOENCODING=utf-8 LC_ALL=C.UTF-8 && \\\n        rm -f /usr/bin/python && \\\n        rm -f /usr/bin/python`echo $PYVER | cut -c1-1` && \\\n        ln -s /usr/bin/python$PYVER /usr/bin/python && \\\n        ln -s /usr/bin/python$PYVER /usr/bin/python`echo $PYVER | cut -c1-1` &&\n        curl -O https://bootstrap.pypa.io/get-pip.py && \\\n        python get-pip.py && \\\n        rm get-pip.py;\n        fi\n    - name: Install dependencies\n      if: github.event.pull_request.merged != true\n      run: |\n        which python\n        python -m pip install --upgrade pip wheel\n        # fixes preinstalled ruamel_yaml error from the docker image\n        rm -rf $(python -c \"from distutils.sysconfig import get_python_lib; print(get_python_lib())\")/ruamel*\n        rm -rf $(python -c \"from distutils.sysconfig import get_python_lib; print(get_python_lib())\")/llvmlite*  #6377\n        python -m pip install ${{ matrix.pytorch }}\n        python -m pip install -r requirements-dev.txt\n        python -m pip list\n    - name: Run quick tests (GPU)\n      if: github.event.pull_request.merged != true\n      
run: |\n        git clone --depth 1 \\\n          https://github.com/Project-MONAI/MONAI-extra-test-data.git /MONAI-extra-test-data\n        export MONAI_EXTRA_TEST_DATA=\"/MONAI-extra-test-data\"\n        nvidia-smi\n        export LAUNCH_DELAY=$(python -c \"import numpy; print(numpy.random.randint(30) * 10)\")\n        echo \"Sleep $LAUNCH_DELAY\"\n        sleep $LAUNCH_DELAY\n        export CUDA_VISIBLE_DEVICES=$(coverage run -m tests.utils | tail -n 1)\n        echo $CUDA_VISIBLE_DEVICES\n        trap 'if pgrep python; then pkill python; fi;' ERR\n        python -c $'import torch\\na,b=torch.zeros(1,device=\"cuda:0\"),torch.zeros(1,device=\"cuda:1\");\\nwhile True:print(a,b)' > /dev/null &\n        python -c \"import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))\"\n        python -c 'import torch; print(torch.rand(5, 3, device=torch.device(\"cuda:0\")))'\n        python -c \"import monai; monai.config.print_config()\"\n        # build for the current self-hosted CI Tesla V100\n        BUILD_MONAI=1 TORCH_CUDA_ARCH_LIST=\"7.0\" ./runtests.sh --build --disttests\n        ./runtests.sh --quick --unittests\n        if [ ${{ matrix.environment }} = \"PT230+CUDA124\" ]; then\n          # test the clang-format tool downloading once\n          coverage run -m tests.clang_format_utils\n        fi\n        coverage xml --ignore-errors\n        if pgrep python; then pkill python; fi\n      shell: bash\n    - name: Upload coverage\n      if: ${{ github.head_ref != 'dev' && github.event.pull_request.merged != true }}\n      uses: codecov/codecov-action@v5\n      with:\n        files: ./coverage.xml\n"
  },
  {
    "path": ".github/workflows/pythonapp-min.yml",
    "content": "# Jenkinsfile.monai-premerge\nname: premerge-min\n\non:\n  # quick tests for pull requests and the releasing branches\n  push:\n    branches:\n      - dev\n      - main\n      - releasing/*\n  pull_request:\n    head_ref-ignore:\n      - dev\n\nconcurrency:\n  # automatically cancel the previously triggered workflows when there's a newer version\n  group: build-min-${{ github.event.pull_request.number || github.ref }}\n  cancel-in-progress: true\n\njobs:\n  # caching of these jobs:\n  #   - docker-py3-pip- (shared)\n  #   - ubuntu py37 pip-\n  #   - os-latest-pip- (shared)\n  min-dep-os:  # min dependencies installed tests for different OS\n    runs-on: ${{ matrix.os }}\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [windows-latest, macOS-latest, ubuntu-latest]\n    timeout-minutes: 40\n    steps:\n    - uses: actions/checkout@v6\n    - name: Set up Python 3.9\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.9'\n    - name: Prepare pip wheel\n      run: |\n        which python\n        python -m pip install --upgrade pip wheel\n    - name: cache weekly timestamp\n      id: pip-cache\n      run: |\n        echo \"datew=$(date '+%Y-%V')\" >> $GITHUB_OUTPUT\n        echo \"dir=$(pip cache dir)\" >> $GITHUB_OUTPUT\n      shell: bash\n    - name: cache for pip\n      uses: actions/cache@v5\n      id: cache\n      with:\n        path: ${{ steps.pip-cache.outputs.dir }}\n        key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }}\n    - name: Install the dependencies\n      run: |\n        # min. 
requirements\n        python -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n        python -m pip install -r requirements-min.txt\n        python -m pip list\n        BUILD_MONAI=0 python setup.py develop  # no compile of extensions\n      shell: bash\n    - name: Run quick tests (CPU ${{ runner.os }})\n      run: |\n        python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'\n        python -c \"import monai; monai.config.print_config()\"\n        ./runtests.sh --min\n      shell: bash\n      env:\n        QUICKTEST: True\n        NGC_API_KEY: ${{ secrets.NGC_API_KEY }}\n        NGC_ORG: ${{ secrets.NGC_ORG }}\n        NGC_TEAM: ${{ secrets.NGC_TEAM }}\n\n  min-dep-py3:  # min dependencies installed tests for different python\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        python-version: ['3.9', '3.10', '3.11', '3.12']\n    timeout-minutes: 40\n    steps:\n    - uses: actions/checkout@v6\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v6\n      with:\n        python-version: ${{ matrix.python-version }}\n    - name: Prepare pip wheel\n      run: |\n        which python\n        python -m pip install --user --upgrade pip setuptools wheel\n        python -m pip install --user \"more-itertools>=8.0\"\n    - name: cache weekly timestamp\n      id: pip-cache\n      run: |\n        echo \"datew=$(date '+%Y-%V')\" >> $GITHUB_OUTPUT\n        echo \"dir=$(pip cache dir)\" >> $GITHUB_OUTPUT\n      shell: bash\n    - name: cache for pip\n      uses: actions/cache@v5\n      id: cache\n      with:\n        path: ${{ steps.pip-cache.outputs.dir }}\n        key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }}\n    - name: Install the dependencies\n      run: |\n        # min. 
requirements\n        python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu\n        python -m pip install -r requirements-min.txt\n        python -m pip list\n        BUILD_MONAI=0 python setup.py develop  # no compile of extensions\n      shell: bash\n    - name: Run quick tests (CPU ${{ runner.os }})\n      run: |\n        python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'\n        python -c \"import monai; monai.config.print_config()\"\n        ./runtests.sh --min\n      env:\n        QUICKTEST: True\n        NGC_API_KEY: ${{ secrets.NGC_API_KEY }}\n        NGC_ORG: ${{ secrets.NGC_ORG }}\n        NGC_TEAM: ${{ secrets.NGC_TEAM }}\n\n  min-dep-pytorch:  # min dependencies installed tests for different pytorch\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        pytorch-version: ['2.5.1', '2.6.0', '2.7.1', '2.8.0']\n    timeout-minutes: 40\n    steps:\n    - uses: actions/checkout@v6\n    - name: Set up Python 3.9\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.9'\n    - name: Prepare pip wheel\n      run: |\n        which python\n        python -m pip install --user --upgrade pip setuptools wheel\n        python -m pip install --user \"more-itertools>=8.0\"\n    - name: cache weekly timestamp\n      id: pip-cache\n      run: |\n        echo \"datew=$(date '+%Y-%V')\" >> $GITHUB_OUTPUT\n        echo \"dir=$(pip cache dir)\" >> $GITHUB_OUTPUT\n      shell: bash\n    - name: cache for pip\n      uses: actions/cache@v5\n      id: cache\n      with:\n        path: ${{ steps.pip-cache.outputs.dir }}\n        key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }}\n    - name: Install the dependencies\n      run: |\n        # min. 
requirements\n        python -m pip install torch==${{ matrix.pytorch-version }}\n        python -m pip install -r requirements-min.txt\n        python -m pip list\n        BUILD_MONAI=0 python setup.py develop  # no compile of extensions\n      shell: bash\n    - name: Run quick tests (pytorch ${{ matrix.pytorch-version }})\n      run: |\n        python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'\n        python -c \"import monai; monai.config.print_config()\"\n        ./runtests.sh --min\n      env:\n        QUICKTEST: True\n        NGC_API_KEY: ${{ secrets.NGC_API_KEY }}\n        NGC_ORG: ${{ secrets.NGC_ORG }}\n        NGC_TEAM: ${{ secrets.NGC_TEAM }}\n"
  },
  {
    "path": ".github/workflows/pythonapp.yml",
    "content": "# Jenkinsfile.monai-premerge\nname: premerge\n\non:\n  # quick tests for pull requests and the releasing branches\n  push:\n    branches:\n      - dev\n      - main\n      - releasing/*\n  pull_request:\n    head_ref-ignore:\n      - dev\n\nconcurrency:\n  # automatically cancel the previously triggered workflows when there's a newer version\n  group: build-${{ github.event.pull_request.number || github.ref }}\n  cancel-in-progress: true\n\njobs:\n  # caching of these jobs:\n  #   - docker-py3-pip- (shared)\n  #   - ubuntu py37 pip-\n  #   - os-latest-pip- (shared)\n  flake8-py3:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        opt: [\"codeformat\", \"pytype\", \"mypy\"]\n    steps:\n    - uses: actions/checkout@v6\n    - name: Set up Python 3.9\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.9'\n        cache: 'pip'\n    - name: Install dependencies\n      run: |\n        find  /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \\;\n        python -m pip install --upgrade pip wheel\n        python -m pip install --no-build-isolation -r requirements-dev.txt\n    - name: Lint and type check\n      run: |\n        # clean up temporary files\n        $(pwd)/runtests.sh --build --clean\n        # Github actions have 2 cores, so parallelize pytype\n        $(pwd)/runtests.sh --build --${{ matrix.opt }} -j 2\n\n  quick-py3:  # full dependencies installed tests for different OS\n    runs-on: ${{ matrix.os }}\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [windows-latest, macOS-latest, ubuntu-latest]\n    timeout-minutes: 120\n    steps:\n    - if: runner.os == 'windows'\n      name: Config pagefile (Windows only)\n      uses: al-cheb/configure-pagefile-action@v1.5\n      with:\n        minimum-size: 8GB\n        maximum-size: 16GB\n        disk-root: \"D:\"\n    - uses: actions/checkout@v6\n    - name: Set up Python 3.9\n      uses: actions/setup-python@v6\n      with:\n   
     python-version: '3.9'\n        cache: 'pip'\n    - name: Prepare pip wheel\n      run: |\n        which python\n        python -m pip install --upgrade pip wheel\n    - if: runner.os == 'windows'\n      name: Install torch cpu from pytorch.org (Windows only)\n      run: |\n        python -m pip install torch==2.5.1 torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu\n    - if: runner.os == 'Linux'\n      name: Install itk pre-release (Linux only)\n      run: |\n        python -m pip install --pre -U itk\n        find  /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \\;\n    - name: Install the dependencies\n      run: |\n        python -m pip install --user --upgrade pip wheel\n        python -m pip install torch==2.5.1 torchvision==0.20.1\n        cat \"requirements-dev.txt\"\n        python -m pip install --no-build-isolation -r requirements-dev.txt\n        python -m pip list\n        python -m pip install -e .  # test no compile installation\n      shell: bash\n    - name: Run compiled (${{ runner.os }})\n      run: |\n        python -m pip uninstall -y monai\n        BUILD_MONAI=1 python -m pip install -e .  
# compile the cpp extensions\n      shell: bash\n    - name: Run quick tests (CPU ${{ runner.os }})\n      run: |\n        python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'\n        python -c \"import monai; monai.config.print_config()\"\n        python -m unittest -v\n      env:\n        QUICKTEST: True\n        PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION: python  # https://github.com/Project-MONAI/MONAI/issues/4354\n\n  packaging:\n    runs-on: ubuntu-latest\n    env:\n      QUICKTEST: True\n      shell: bash\n    steps:\n    - uses: actions/checkout@v6\n      with:\n        fetch-depth: 0\n    - name: Set up Python 3.9\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.9'\n        cache: 'pip'\n    - name: Install dependencies\n      run: |\n        find  /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \\;\n        python -m pip install --user --upgrade pip setuptools wheel twine packaging\n        # install the latest pytorch for testing\n        # however, \"pip install monai*.tar.gz\" will build cpp/cuda with an isolated\n        # fresh torch installation according to pyproject.toml\n        python -m pip install \"torch>=2.5.1\" torchvision --extra-index-url https://download.pytorch.org/whl/cpu\n    - name: Check packages\n      run: |\n        pip uninstall monai\n        pip list | grep -iv monai\n        git fetch --depth=1 origin +refs/tags/*:refs/tags/*\n        set -e\n\n        # build tar.gz and wheel\n        python setup.py check -m -s\n        python setup.py sdist bdist_wheel\n        python -m twine check dist/*\n    - run: echo \"pwd=$PWD\" >> $GITHUB_OUTPUT\n      id: root\n    - run: echo \"tmp_dir=$(mktemp -d)\" >> $GITHUB_OUTPUT\n      id: mktemp\n    - name: Move packages\n      run: |\n        printf ${{ steps.root.outputs.pwd }}\n        printf ${{ steps.mktemp.outputs.tmp_dir }}\n        # move packages to a temp dir\n        cp dist/monai* \"${{ 
steps.mktemp.outputs.tmp_dir }}\"\n        rm -r build dist monai.egg-info\n        cd \"${{ steps.mktemp.outputs.tmp_dir }}\"\n        ls -al\n    - name: Install wheel file\n      working-directory: ${{ steps.mktemp.outputs.tmp_dir }}\n      run: |\n        # install from wheel\n        python -m pip install monai*.whl --extra-index-url https://download.pytorch.org/whl/cpu\n        python -c 'import monai; monai.config.print_config()' 2>&1 | grep -iv \"unknown\"\n        python -c 'import monai; print(monai.__file__)'\n        python -m pip uninstall -y monai\n        rm monai*.whl\n    - name: Install source archive\n      working-directory: ${{ steps.mktemp.outputs.tmp_dir }}\n      run: |\n        # install from tar.gz\n        name=$(ls *.tar.gz | head -n1)\n        echo $name\n        python -m pip install $name[all] --extra-index-url https://download.pytorch.org/whl/cpu\n        python -c 'import monai; monai.config.print_config()' 2>&1 | grep -iv \"unknown\"\n        python -c 'import monai; print(monai.__file__)'\n    - name: Quick test\n      working-directory: ${{ steps.mktemp.outputs.tmp_dir }}\n      run: |\n        # run min tests\n        cp ${{ steps.root.outputs.pwd }}/requirements*.txt .\n        cp -r ${{ steps.root.outputs.pwd }}/tests .\n        ls -al\n        python -m pip install --no-build-isolation -r requirements-dev.txt --extra-index-url https://download.pytorch.org/whl/cpu\n        python -m unittest -v\n      env:\n        PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION: python  # https://github.com/Project-MONAI/MONAI/issues/4354\n\n  build-docs:\n    runs-on: ubuntu-latest\n    steps:\n    - uses: actions/checkout@v6\n    - name: Set up Python 3.9\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.9'\n        cache: 'pip'\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip wheel\n        python -m pip install -r docs/requirements.txt --extra-index-url 
https://download.pytorch.org/whl/cpu\n    - name: Make html\n      run: |\n        cd docs/\n        make clean\n        make html 2>&1 | tee tmp_log\n        if [[ $(grep -c \"ERROR:\" tmp_log) != 0 ]]; then echo \"found errors\"; grep \"ERROR:\" tmp_log; exit 1; fi\n        sed '/WARNING.*pip/d' tmp_log > tmp_log1; mv tmp_log1 tmp_log  # monai#7133\n        if [[ $(grep -c \"WARNING:\" tmp_log) != 0 ]]; then echo \"found warnings\"; grep \"WARNING:\" tmp_log; exit 1; fi\n      shell: bash\n"
  },
  {
    "path": ".github/workflows/release.yml",
    "content": "name: release\n# generating and testing package artefacts from the main branch\n\non:\n  push:\n    branches:\n      - main\n    tags:\n      - '*'\n\njobs:\n  packaging:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: ['3.9', '3.10', '3.11']\n    steps:\n    - uses: actions/checkout@v6\n      with:\n        fetch-depth: 0\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v6\n      with:\n        python-version: ${{ matrix.python-version }}\n    - name: Install setuptools\n      run: |\n        python -m pip install --user --upgrade setuptools wheel packaging\n    - name: Build and test source archive and wheel file\n      run: |\n        find  /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \\;\n        git fetch --depth=1 origin +refs/tags/*:refs/tags/*\n        root_dir=$PWD\n        echo \"$root_dir\"\n        set -e\n\n        # build tar.gz and wheel\n        python setup.py sdist bdist_wheel --build-number $(date +'%Y%m%d%H%M')\n        tmp_dir=$(mktemp -d)\n        cp dist/monai* \"$tmp_dir\"\n        cd \"$tmp_dir\"\n        ls -al\n\n        # install from tar.gz\n        python -m pip install monai*.tar.gz\n        python -c 'import monai; monai.config.print_config()' 2>&1 | grep -iv \"unknown\"\n        python -c 'import monai; print(monai.__file__)'\n        python -m pip uninstall -y monai\n        rm monai*.tar.gz\n\n        # install from wheel\n        python -m pip install monai*.whl\n        python -c 'import monai; monai.config.print_config()' 2>&1 | grep -iv \"unknown\"\n        python -c 'import monai; print(monai.__file__)'\n\n        # clean up\n        cd \"$root_dir\"\n        rm -r \"$tmp_dir\"\n        rm -rf monai/\n        ls -al .\n    - name: Quick test installed\n      run: |\n        python -m pip install -r requirements-min.txt\n        python -m tests.min_tests\n      env:\n        QUICKTEST: True\n\n    - if: 
matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/')\n      name: Upload artifacts\n      uses: actions/upload-artifact@v6\n      with:\n        name: dist\n        path: dist/\n\n    - if: matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/')\n      name: Check artifacts\n      run: |\n        ls -al dist/\n        rm dist/monai*.tar.gz\n        ls -al dist/\n\n    # remove publishing to Test PyPI as it is moved to blossom\n    # - if: matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/')\n    #   name: Publish to Test PyPI\n    #   uses: pypa/gh-action-pypi-publish@release/v1\n    #   with:\n    #     password: ${{ secrets.TEST_PYPI }}\n    #     repository-url: https://test.pypi.org/legacy/\n\n  versioning:\n    # compute versioning file from python setup.py\n    # upload as artifact\n    if: github.repository == 'Project-MONAI/MONAI'\n    needs: packaging\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v6\n        # full history so that we can git describe\n        with:\n          fetch-depth: 0\n      - name: Set up Python 3.9\n        uses: actions/setup-python@v6\n        with:\n          python-version: '3.9'\n      - shell: bash\n        run: |\n          find  /opt/hostedtoolcache/* -maxdepth 0 ! 
-name 'Python' -exec rm -rf {} \\;\n          git describe\n          python -m pip install --user --upgrade setuptools wheel packaging\n          python setup.py build\n          cat build/lib/monai/_version.py\n      - name: Upload version\n        uses: actions/upload-artifact@v6\n        with:\n          name: _version.py\n          path: build/lib/monai/_version.py\n      - name: Clean up directory\n        shell: bash\n        run: |\n          ls -al\n          rm -rf {*,.[^.]*}\n\n  release_tag_docker:\n    # if: github.repository == 'Project-MONAI/MONAI'\n    if: ${{ false }}\n    needs: versioning\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v6\n      - name: Download version\n        uses: actions/download-artifact@v6\n        with:\n          name: _version.py\n      - name: Set tag\n        id: versioning\n        run: echo \"tag=${GITHUB_REF#refs/*/}\" >> $GITHUB_OUTPUT\n      - name: Check tag\n        env:\n          RELEASE_VERSION: ${{ steps.versioning.outputs.tag }}\n        run: |\n          echo \"$RELEASE_VERSION\"\n          cat _version.py\n      - if: startsWith(github.ref, 'refs/tags/')\n        name: build with the tag\n        env:\n          RELEASE_VERSION: ${{ steps.versioning.outputs.tag }}\n        shell: bash\n        run: |\n          find  /opt/hostedtoolcache/* -maxdepth 0 ! 
-name 'Python' -exec rm -rf {} \\;\n          # get tag info for versioning\n          mv _version.py monai/\n          # version checks\n          target=\" \\\"version\\\": \\\"$RELEASE_VERSION\\\"\"\n          local=`grep \"\\\"version\\\"\" monai/_version.py`\n          echo \"$target\"\n          echo \"$local\"\n          if [[ \"$local\" == \"$target\" ]]; then\n            echo \"matched version string\"\n          else\n            echo \"unmatched version string, please check the tagging branch.\"\n            exit 1\n          fi\n          # remove flake package as it is not needed on hub.docker.com\n          sed -i '/flake/d' requirements-dev.txt\n          docker build -t projectmonai/monai:\"$RELEASE_VERSION\" -f Dockerfile .\n          # distribute with a tag to hub.docker.com\n          echo \"${{ secrets.DOCKER_PW }}\" | docker login -u projectmonai --password-stdin\n          docker push projectmonai/monai:\"$RELEASE_VERSION\"\n          docker logout\n"
  },
  {
    "path": ".github/workflows/setupapp.yml",
    "content": "# Jenkinsfile.monai-postmerge\nname: deploy\n\non:\n  # full tests for all the important branches\n  push:\n    branches:\n      - main\n      - releasing/*\n      - feature/*\n      - dev\n\nconcurrency:\n  # automatically cancel the previously triggered workflows when there's a newer version\n  group: deploy-${{ github.event.pull_request.number || github.ref }}\n  cancel-in-progress: true\n\njobs:\n  # caching of these jobs:\n  #   - docker-py3-pip- (shared)\n  #   - ubuntu 37 38 39 310-pip-\n  #   - os-latest-pip (shared)\n  coverage-py3:\n    # if: github.repository == 'Project-MONAI/MONAI'\n    if: ${{ false }}  # disable self-hosted job project-monai/monai#7039\n    container:\n      image: nvcr.io/nvidia/pytorch:22.04-py3\n      options: --gpus all\n    runs-on: [self-hosted, linux, x64, integration]\n    steps:\n    - uses: actions/checkout@v6\n    - name: cache weekly timestamp\n      id: pip-cache\n      run: |\n        echo \"datew=$(date '+%Y-%V')\" >> $GITHUB_OUTPUT\n    - name: cache for pip\n      if: ${{ startsWith(github.ref, 'refs/heads/dev') }}\n      uses: actions/cache@v5\n      id: cache\n      with:\n        path: |\n          ~/.cache/pip\n          ~/.cache/torch\n        key: docker-py3-pip-${{ steps.pip-cache.outputs.datew }}\n    - name: Install the dependencies\n      run: |\n        which python\n        python -m pip install --upgrade pip wheel\n        python -m pip install --upgrade torch torchvision\n        python -m pip install -r requirements-dev.txt\n    - name: Run unit tests report coverage\n      env:\n        NGC_API_KEY: ${{ secrets.NGC_API_KEY }}\n        NGC_ORG: ${{ secrets.NGC_ORG }}\n        NGC_TEAM: ${{ secrets.NGC_TEAM }}\n      run: |\n        python -m pip list\n        git config --global --add safe.directory /__w/MONAI/MONAI\n        git clean -ffdx\n        df -h\n        # python -m pip cache info\n        nvidia-smi\n        export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)\n   
     echo $CUDA_VISIBLE_DEVICES\n        trap 'if pgrep python; then pkill python; fi;' ERR\n        python -c $'import torch\\na,b=torch.zeros(1,device=\"cuda:0\"),torch.zeros(1,device=\"cuda:1\");\\nwhile True:print(a,b)' > /dev/null &\n        python -c \"import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))\"\n        python -c 'import torch; print(torch.rand(5, 3, device=torch.device(\"cuda:0\")))'\n        BUILD_MONAI=1 ./runtests.sh --build --coverage --unittests --disttests  # unit tests with coverage report\n        BUILD_MONAI=1 ./runtests.sh --build --coverage --net  # integration tests with coverage report\n        coverage xml --ignore-errors\n        if pgrep python; then pkill python; fi\n      shell: bash\n    - name: Upload coverage\n      uses: codecov/codecov-action@v5\n      with:\n        fail_ci_if_error: false\n        files: ./coverage.xml\n\n  test-py3x:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: ['3.9', '3.10', '3.11']\n    steps:\n    - uses: actions/checkout@v6\n      with:\n        fetch-depth: 0\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v6\n      with:\n        python-version: ${{ matrix.python-version }}\n    - name: cache weekly timestamp\n      id: pip-cache\n      run: |\n        echo \"datew=$(date '+%Y-%V')\" >> $GITHUB_OUTPUT\n    - name: cache for pip\n      uses: actions/cache@v5\n      id: cache\n      with:\n        path: |\n          ~/.cache/pip\n          ~/.cache/torch\n        key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ steps.pip-cache.outputs.datew }}\n    - name: Install the dependencies\n      run: |\n        find  /opt/hostedtoolcache/* -maxdepth 0 ! 
-name 'Python' -exec rm -rf {} \\;\n        python -m pip install --upgrade pip wheel\n        python -m pip install -r requirements-dev.txt\n    - name: Run quick tests CPU ubuntu\n      env:\n        NGC_API_KEY: ${{ secrets.NGC_API_KEY }}\n        NGC_ORG: ${{ secrets.NGC_ORG }}\n        NGC_TEAM: ${{ secrets.NGC_TEAM }}\n      run: |\n        python -m pip list\n        python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'\n        BUILD_MONAI=0 ./runtests.sh --build --quick --unittests\n        BUILD_MONAI=1 ./runtests.sh --build --quick --min\n        coverage xml --ignore-errors\n    - name: Upload coverage\n      uses: codecov/codecov-action@v5\n      with:\n        fail_ci_if_error: false\n        files: ./coverage.xml\n\n  install:  # pip install from github url, the default branch is dev\n    runs-on: ubuntu-latest\n    steps:\n    - name: Set up Python 3.9\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.9'\n    - name: cache weekly timestamp\n      id: pip-cache\n      run: |\n        echo \"datew=$(date '+%Y-%V')\" >> $GITHUB_OUTPUT\n    - name: cache for pip\n      uses: actions/cache@v5\n      id: cache\n      with:\n        path: |\n          ~/.cache/pip\n          ~/.cache/torch\n        key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}\n    - name: Install the default branch no build (dev branch only)\n      if: github.ref == 'refs/heads/dev'\n      run: |\n        BUILD_MONAI=0 pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI\n        python -c 'import monai; monai.config.print_config()'\n        cd $(python -c 'import monai; import os; print(os.path.dirname(monai.__file__))')\n        ls .\n        pip uninstall -y monai\n    - name: Install the default branch with build (dev branch only)\n      if: github.ref == 'refs/heads/dev'\n      run: |\n        find  /opt/hostedtoolcache/* -maxdepth 0 ! 
-name 'Python' -exec rm -rf {} \\;\n        BUILD_MONAI=1 pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI\n        python -c 'import monai; monai.config.print_config()'\n    - name: Get the test cases (dev branch only)\n      if: github.ref == 'refs/heads/dev'\n      uses: actions/checkout@v6\n      with:\n        ref: dev\n    - name: Quick test installed (dev branch only)\n      if: github.ref == 'refs/heads/dev'\n      run: |\n        cd $GITHUB_WORKSPACE\n        rm -rf monai/\n        ls -al .\n        python -m pip install -r requirements-min.txt\n        python -m tests.min_tests\n      env:\n        QUICKTEST: True\n"
  },
  {
    "path": ".github/workflows/weekly-preview.yml",
    "content": "name: weekly-preview\n\non:\n  schedule:\n  - cron: \"0 2 * * 0\"  # 02:00 of every Sunday\n\njobs:\n  flake8-py3:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        opt: [\"codeformat\", \"pytype\", \"mypy\"]\n    steps:\n    - uses: actions/checkout@v6\n    - name: Set up Python 3.9\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.9'\n    - name: cache weekly timestamp\n      id: pip-cache\n      run: |\n        echo \"datew=$(date '+%Y-%V')\" >> $GITHUB_OUTPUT\n    - name: cache for pip\n      uses: actions/cache@v5\n      id: cache\n      with:\n        path: ~/.cache/pip\n        key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}\n    - name: Install dependencies\n      run: |\n        find  /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \\;\n        python -m pip install --upgrade pip wheel\n        python -m pip install -r requirements-dev.txt\n    - name: Lint and type check\n      run: |\n        # clean up temporary files\n        $(pwd)/runtests.sh --build --clean\n        # Github actions have 2 cores, so parallelize pytype\n        $(pwd)/runtests.sh --build --${{ matrix.opt }} -j 2\n\n  packaging:\n    if: github.repository == 'Project-MONAI/MONAI'\n    runs-on: ubuntu-latest\n    steps:\n    - uses: actions/checkout@v6\n      with:\n        ref: dev\n        fetch-depth: 0\n    - name: Set up Python 3.9\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.9'\n    - name: Install setuptools\n      run: |\n        python -m pip install --user --upgrade setuptools wheel packaging\n    - name: Build distribution\n      run: |\n        export HEAD_COMMIT_ID=$(git rev-parse HEAD)\n        sed -i 's/name\\ =\\ monai$/name\\ =\\ monai-weekly/g' setup.cfg\n        echo \"__commit_id__ = \\\"$HEAD_COMMIT_ID\\\"\" >> monai/__init__.py\n        git diff setup.cfg monai/__init__.py\n        git config user.name \"CI Builder\"\n        git config 
user.email \"monai.contact@gmail.com\"\n        git add setup.cfg monai/__init__.py\n        git commit -m \"Weekly build at $HEAD_COMMIT_ID\"\n        export YEAR_WEEK=$(date +'%y%U')\n        echo \"Year week for tag is ${YEAR_WEEK}\"\n        if ! [[ $YEAR_WEEK =~ ^[0-9]{4}$ ]] ; then echo \"Wrong 'year week' format.  Should be 4 digits.\"; exit 1 ; fi\n        git tag \"1.6.dev${YEAR_WEEK}\"\n        git log -1\n        git tag --list\n        python setup.py sdist bdist_wheel\n\n    - name: Publish to PyPI\n      uses: pypa/gh-action-pypi-publish@release/v1\n      with:\n        user: __token__\n        password: ${{ secrets.PYPI_TOKEN }}\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.coverage/\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# temporary unittest artifacts\ntests/testing_data/temp_*\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/build/\ndocs/source/_gen\ndocs/source/*_properties.csv\n_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# pytype cache\n.pytype/\n\n# mypy\n.mypy_cache/\nexamples/scd_lvsegs.npz\ntemp/\n.idea/\n.dmypy.json\n\n*~\n\n# Remove .pyre temporary config files\n.pyre\n.pyre_configuration\n\n# temporary editor files that should not be in git\n*.orig\n*.bak\n*.swp\n.DS_Store\n\n# temporary testing data 
MedNIST\ntests/testing_data/MedNIST*\ntests/testing_data/*Hippocampus*\ntests/testing_data/*.tiff\ntests/testing_data/schema.json\ntests/testing_data/endo.mp4\ntests/testing_data/ultrasound.avi\ntests/testing_data/train_data_stats.yaml\ntests/testing_data/eval_data_stats.yaml\ntests/testing_data/train_data_stats_by_case.yaml\ntests/testing_data/eval_data_stats_by_case.yaml\ntests/testing_data/CT_2D_head_fixed.mha\ntests/testing_data/CT_2D_head_moving.mha\ntests/testing_data/config_executed.json\ntests/testing_data/eval\ntests/testing_data/nrrd_example.nrrd\n\n# clang format tool\n.clang-format-bin/\n\n# ctags\ntags\n\n# VSCode\n.vscode/\n*.zip\n\n# profiling results\n*.prof\nruns\n\n*.gz\n\n*.pth\n\n*zarr/*\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "default_language_version:\n  python: python3\n\nci:\n  autofix_prs: true\n  autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'\n  autoupdate_schedule: quarterly\n  # submodules: true\n\nrepos:\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v6.0.0\n    hooks:\n      - id: end-of-file-fixer\n      - id: trailing-whitespace\n      - id: check-yaml\n      - id: check-docstring-first\n      - id: check-executables-have-shebangs\n      - id: check-toml\n      - id: check-case-conflict\n      - id: check-added-large-files\n        args: ['--maxkb=1024']\n      - id: detect-private-key\n      - id: forbid-new-submodules\n      - id: pretty-format-json\n        args: ['--autofix', '--no-sort-keys', '--indent=4']\n      - id: end-of-file-fixer\n      - id: mixed-line-ending\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.14.11\n    hooks:\n    -   id: ruff-check\n        args: [\"--fix\"]\n        exclude: |\n          (?x)(\n              ^versioneer.py|\n              ^monai/_version.py\n          )\n\n  - repo: https://github.com/hadialqattan/pycln\n    rev: v2.6.0\n    hooks:\n      - id: pycln\n        args: [--config=pyproject.toml]\n"
  },
  {
    "path": ".readthedocs.yml",
    "content": "# .readthedocs.yml\n# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details\n\n# Required\nversion: 2\n\n# Build documentation in the docs/ directory with Sphinx\nsphinx:\n  configuration: docs/source/conf.py\n\n# Build documentation with MkDocs\n#mkdocs:\n#  configuration: mkdocs.yml\n\n# Optionally build your docs in additional formats such as PDF and ePub\n# formats: all\n\n# Optionally set the version of Python and requirements required to build your docs\npython:\n  version: 3\n  install:\n    - requirements: docs/requirements.txt\n#  system_packages: true\n\n\nbuild:\n  image: stable\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "# Changelog\nAll notable changes to MONAI are documented in this file.\n\nThe format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).\n\n## [Unreleased]\n\n## [1.5.2] - 2026-01-28\n\n## What's Changed\n### Fixed\n* Fix Zip Slip vulnerability in NGC private bundle download (#8682)\n\n## [1.5.1] - 2025-09-22\n\n## What's Changed\n### Added\n* PyTorch 2.7 and 2.8 support (#8429, #8530)\n* Create SECURITY.md (#8546)\n* Add kwargs in array and functional file (#8508)\n* Add .coderabbit.yaml File (#8513)\n* Add input validation to ImageStats class (#8501)\n* Add support for optional conditioning in PatchInferer, SliceInferer, and SlidingWindowInferer (#8400)\n* Add classifier free guidance unconditioned value (#8562)\n* Improved `DiffusionModelEncoder` to support output linear layers of different dimensions (#8578, #8580)\n\n### Fixed\n* Fix for insecure zip file extraction to address [GHSA-x6ww-pf9m-m73m](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-x6ww-pf9m-m73m) (#8568)\n* Fix for insecure use of `torch.load` and `pickle` to address [GHSA-6vm5-6jv9-rjpj](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-6vm5-6jv9-rjpj) and [GHSA-p8cm-mm2v-gwjm](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-p8cm-mm2v-gwjm) (#8566)\n* Torchvision fix for loading pretrained weights using current syntax (#8563)\n* Fix bug in MAISI vae (#8517)\n* Throw exception on invalid images in retinanet detector (#8515)\n* Fix: HistogramNormalized doc (#8543)\n* Fix build failure by pinning pyamg to versions below 5.3.0 (#8548)\n* Fix hardcoded input dim in DiffusionModelEncoder (#8514)\n* Fix for gdown downloading fails (#8576)\n\n### Changed\n* Update README badges to add research paper citations number (#8494)\n* CI: Add custom timeout to ci job in order to save resources (#8504)\n* Improve documentation on the datalist format (#8539)\n* Tests Cleanup and refactor (#8405, #8535)\n* Improve Orientation transform 
to use the \"space\" (LPS vs RAS) of a metatensor by default (#8473)\n* Updated supported version of Huggingface Transformers (#8574)\n\n## [1.5.0] - 2025-06-13\n\n## What's Changed\n### Added\n* Add platform-specific constraints to setup.cfg (#8260)\n* Add PythonicWorkflow (#8151)\n* Add SM architecture version check (#8199)\n* Add MedNext implementation (#8004)\n* Added a top button to CONTRIBUTING.md (#8163)\n* Adding CODEOWNERS (#8457)\n* Restormer Implementation (#8312)\n* Add rectified flow noise scheduler for accelerated diffusion model (#8374)\n* Add prediction type for rflow scheduler (#8386)\n* Add Average Precision to metrics (#8089)\n* Implementation of a Masked Autoencoder for representation learning (#8152)\n* Implement TorchIO transforms wrapper analogous to TorchVision transfo… (#7579)\n* 8328 nnunet bundle integration (#8329)\n* Adding Support Policy + Doc Updates (#8458)\n* Classifier free guidance (#8460)\n\n### Fixed\n* Fix Ruff Numpy2 deprecation rules (#8179)\n* Fix `torch.load()` frequently warning in PersistentDataset and GDSDataset (#8177)\n* Fix the logging of a nested dictionary metric in MLflow (#8169)\n* Fix ImageFilter to allow Gaussian filter without filter_size (#8189)\n* Fix fold_constants, test_handler switched to onnx (#8211)\n* Fix TypeError in meshgrid (#8252)\n* Fix PatchMerging duplicate merging (#8285)\n* Fix test load image issue (#8297)\n* Fix bundle download error from ngc source (#8307)\n* Fix deprecated usage in zarr (#8313, #8477)\n* Fix DataFrame subsets indexing in CSVDataset() (#8351)\n* Fix `packaging` imports in version comparison logic (#8347)\n* Fix CommonKeys docstring (#8342)\n* Fix: correctly apply fftshift to real-valued data inputs (#8407)\n* Fix OptionalImportError: required package `openslide` is not installed (#8419)\n* Fix cosine noise scheduler (#8427)\n* Fix AutoencoderKL docstrings. 
(#8445)\n* Inverse Threading Fix (#8418)\n* Fix normalize intensity (#8286)\n* Fix path at test onnx trt export (#8361)\n* Fix broken urls (#8481, #8483)\n\n### Changed\n* [DOC] Update README.md (#8157)\n* Streamlined Rearrange in SpatialAttentionBlock (#8130)\n* Optimize VISTA3D (#8123)\n* Skip torch trt convert test with torch newer than or equal to 2.5.0 (#8165)\n* Enable redirection of all loggers by configuring a FileHandler within the bundle (#8142)\n* Apply pyupgrade fixes for Python 3.9+ syntax (#8150)\n* Update base image to 2410 (#8164)\n* TRT support for MAISI (#8153)\n* 8134 Add unit test for responsive inference (#8146)\n* SwinUNETR refactor to accept additional parameters (#8212)\n* Allow an arbitrary mask to be used in the self attention (#8235)\n* Bump codecov/codecov-action from 4 to 5 (#8245)\n* Docs: update brats classes description (#8246)\n* Change default value of `patch_norm` to False in `SwinUNETR` (#8249)\n* Modify Dice, Jaccard and Tversky losses (#8138)\n* Modify Workflow to Allow IterableDataset Inputs (#8263)\n* Enhance download_and_extract (#8216)\n* Relax gpu load check (#8282, #8275)\n* Using LocalStore in Zarr v3 (#8299)\n* Enable gpu load nifti (#8188)\n* update pydicom reader to enable gpu load (#8283)\n* Zarr compression tests only with versions before 3.0 (#8319)\n* Changing utils.py to test_utils.py (#8335)\n* Refactor testd (#8231)\n* Recursive Item Mapping for Nested Lists in Compose  (#8187)\n* Bump min torch to 1.13.1 to mitigate CVE-2022-45907 unsafe usage of eval (#8296)\n* Inferer modification - save_intermediates clashes with latent shape adjustment in latent diffusion inferers (#8343)\n* Solves path problem in test_bundle_trt_export.py (#8357)\n* Modify ControlNet inferer so that it takes in context when the diffus… (#8360)\n* Update monaihosting download method (#8364)\n* Bump torch minimum to mitigate CVE-2024-31580 & CVE-2024-31583 and enable numpy 2 compatibility (#8368)\n* Auto3DSeg algo_template hash update 
(#8378)\n* Enable PyTorch 2.6 (#8309)\n* Auto3DSeg algo_template hash update (#8393, #8397)\n* Update Dice Metric Docs (#8388)\n* Auto3DSeg algo_template hash update (#8406)\n* Update bundle download API (#8403)\n* Add Skip test in TestTranschex (#8416)\n* Update get latest bundle version function (#8420)\n* Temporarily Restrict setuptools Version to 79.0.1 (#8441)\n* Update default overlap value in occlusion_sensitivity to 0.6 (#8446)\n* Enable code coverage comments on PRs in codecov configuration (#8402)\n* Migrate to modern Python Logger API (#8449)\n\n### Deprecated\n### Removed\n* Remove deprecated functionality for v1.5 (#8430)\n* Remove deprecated `return_state_dict` in bundle `load` (#8454)\n* Remove deprecated `net_name` in test file (#8461)\n* Remove unused test cases in bundle load (#8463)\n* selfattention block: Remove the fc linear layer if it is not used (#8325)\n* Removed outdated `torch` version checks from transform functions (#8359)\n\n## [1.4.0] - 2024-10-17\n## What's Changed\n### Added\n* Implemented Conjugate Gradient Solver to generate confidence maps. 
(#7876)\n* Added norm parameter to `ResNet` (#7752, #7805)\n* Introduced alpha parameter to `DiceFocalLoss` for improved flexibility (#7841)\n* Integrated Tailored ControlNet Implementations (#7875)\n* Integrated Tailored Auto-Encoder Model (#7861)\n* Integrated Tailored Diffusion U-Net Model (#7867)\n* Added MAISI morphological functions (#7893)\n* Added support for downloading bundles from NGC private registry (#7907, #7929, #8076)\n* Integrated generative refactor into the core (#7886, #7962)\n* Made `ViT` and `UNETR` models compatible with TorchScript (#7937)\n* Implemented post-download checks for MONAI bundles and compatibility warnings (#7938)\n* Added NGC prefix argument when downloading bundles (#7974)\n* Added flash attention support in the attention block for improved performance (#7977)\n* Enhanced `MLPBlock` for compatibility with VISTA-3D (#7995)\n* Added support for Neighbor-Aware Calibration Loss (NACL) for calibrated models in segmentation tasks (#7819)\n* Added label_smoothing parameter to `DiceCELoss` for enhanced model calibration (#8000)\n* Add `include_fc` and `use_combined_linear` argument in the `SABlock` (#7996)\n* Added utilities, networks, and an inferer specific to VISTA-3D (#7999, #7987, #8047, #8059, #8021)\n* Integrated a new network, `CellSamWrapper`, for cell-based applications (#7981)\n* Introduced `WriteFileMapping` transform to map between input image paths and their corresponding output paths (#7769)\n* Added `TrtHandler` to accelerate models using TensorRT (#7990, #8064)\n* Added box and points conversion transforms for more flexible spatial manipulation (#8053)\n* Enhanced `RandSimulateLowResolutiond` transform with deterministic support (#8057)\n* Added a contiguous argument to the `Fourier` class to facilitate contiguous tensor outputs (#7969)\n* Allowed `ApplyTransformToPointsd` to receive a sequence of reference keys for more versatile point manipulation (#8063)\n* Made `MetaTensor` an optional print in `DataStats` and 
`DataStatsd` for more concise logging (#7814)\n#### misc.\n* Refactored Dataset to utilize Compose for handling transforms. (#7784)\n* Combined `map_classes_to_indices` and `generate_label_classes_crop_centers` into a unified function (#7712)\n* Introduced metadata schema directly into the codebase for improved structure and validation (#7409)\n* Renamed `optional_packages_version` to `required_packages_version` for clearer package dependency management. (#7253)\n* Replaced `pkg_resources` with the more modern packaging module for package handling (#7953)\n* Refactored MAISI-related networks to align with the new generative components (#7989, #7993, #8005)\n* Added a badge displaying monthly download statistics to enhance project visibility (#7891)\n### Fixed\n#### transforms\n* Ensured deterministic behavior in `MixUp`, `CutMix`, and `CutOut` transforms (#7813)\n* Applied a minor correction to `AsDiscrete` transform (#7984)\n* Fixed handling of integer weightmaps in `RandomWeightedCrop` (#8097)\n* Resolved data type bug in `ScaleIntensityRangePercentile` (#8109)\n#### data\n* Fixed negative strides issue in the `NrrdReader` (#7809)\n* Addressed wsireader issue with retrieving MPP (#7921)\n* Ensured location is returned as a tuple in wsireader (#8007)\n* Corrected interpretation of space directions in nrrd reader (#8091)\n#### metrics and losses\n* Improved memory management for `NACLLoss` (#8020)\n* Fixed reduction logic in `GeneralizedDiceScore` (#7970)\n#### networks\n* Resolved issue with loading pre-trained weights in `ResNet` (#7924)\n* Fixed error where `torch.device` object had no attribute gpu_id during TensorRT export (#8019)\n* Corrected function for loading older weights in `DiffusionModelUNet` (#8031)\n* Switched to `torch_tensorrt.Device` instead of `torch.device` during TensorRT compilation (#8051)\n#### engines and handlers\n* Attempted to resolve the \"experiment already exists\" issue in `MLFlowHandler` (#7916)\n* Refactored the model export 
process for conversion and saving (#7934)\n#### misc.\n* Adjusted requirements to exclude Numpy version 2.0 (#7859)\n* Updated deprecated `scipy.ndimage` namespaces in optional imports (#7847, #7897)\n* Resolved `load_module()` deprecation in Python 3.12 (#7881)\n* Fixed Ruff type check issues (#7885)\n* Cleaned disk space in the conda test pipeline (#7902)\n* Replaced deprecated `pkgutil.find_loader` usage  (#7906)\n* Enhanced docstrings in various modules (#7913, #8055)\n* Test cases fixing (#7905, #7794, #7808)\n* Fix mypy issue introduced in 1.11.0 (#7941)\n* Cleaned up warnings during test collection (#7914)\n* Fix incompatible types in assignment issue (#7950)\n* Fix outdated link in the docs (#7971)\n* Addressed CI issues (#7983, #8013)\n* Fix module can not import correctly issue (#8015)\n* Fix AttributeError when using torch.min and max (#8041)\n* Ensure synchronization by adding `cuda.synchronize` (#8058)\n* Ignore warning from nptyping as workaround (#8062)\n* Suppress deprecated warning when importing monai (#8067)\n* Fix link in test bundle under MONAI-extra-test-data (#8092)\n### Changed\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:24.08-py3` from `nvcr.io/nvidia/pytorch:23.08-py3`\n* Change blossom-ci to ACL security format (#7843)\n* Move PyType test to weekly test (#8025)\n* Adjusted to meet Numpy 2.0 requirements (#7857)\n### Deprecated\n* Dropped support for Python 3.8 (#7909)\n* Remove deprecated arguments and class for v1.4 (#8079)\n### Removed\n* Remove use of deprecated python 3.12 strtobool (#7900)\n* Removed the pipeline for publishing to testpypi (#8086)\n* Cleaning up some very old and now obsolete infrastructure (#8113, #8118, #8121)\n\n## [1.3.2] - 2024-06-25\n### Fixed\n#### misc.\n* Updated Numpy version constraint to < 2.0 (#7859)\n\n## [1.3.1] - 2024-05-17\n### Added\n* Support for `by_measure` argument in `RemoveSmallObjects` (#7137)\n* Support for `pretrained` flag in `ResNet` (#7095)\n* Support for uploading and 
downloading bundles to and from the Hugging Face Hub (#6454)\n* Added weight parameter in DiceLoss to apply weight to voxels of each class (#7158)\n* Support for returning dice for each class in `DiceMetric` (#7163)\n* Introduced `ComponentStore` for storage purposes (#7159)\n* Added utilities used in MONAI Generative (#7134)\n* Enabled Python 3.11 support for `convert_to_torchscript` and `convert_to_onnx` (#7182)\n* Support for MLflow in `AutoRunner` (#7176)\n* `fname_regex` option in PydicomReader (#7181)\n* Allowed setting AutoRunner parameters from config (#7175)\n* `VoxelMorphUNet` and `VoxelMorph` (#7178)\n* Enabled `cache` option in `GridPatchDataset` (#7180)\n* Introduced `class_labels` option in `write_metrics_reports` for improved readability (#7249)\n* `DiffusionLoss` for image registration task (#7272)\n* Supported specifying `filename` in `SaveImage` (#7318)\n* Compile support in `SupervisedTrainer` and `SupervisedEvaluator` (#7375)\n* `mlflow_experiment_name` support in `Auto3DSeg` (#7442)\n* Arm support (#7500)\n* `BarlowTwinsLoss` for representation learning (#7530)\n* `SURELoss` and `ConjugateGradient` for diffusion models (#7308)\n* Support for `CutMix`, `CutOut`, and `MixUp` augmentation techniques (#7198)\n* `meta_file` and `logging_file` options to `BundleWorkflow` (#7549)\n* `properties_path` option to `BundleWorkflow` for customized properties (#7542)\n* Support for both soft and hard clipping in `ClipIntensityPercentiles` (#7535)\n* Support for not saving artifacts in `MLFlowHandler` (#7604)\n* Support for multi-channel images in `PerceptualLoss` (#7568)\n* Added ResNet backbone for `FlexibleUNet` (#7571)\n* Introduced `dim_head` option in `SABlock` to set dimensions for each head (#7664)\n* Direct links to GitHub source code in docs (#7738, #7779)\n#### misc.\n* Refactored `list_data_collate` and `collate_meta_tensor` to utilize the latest PyTorch API (#7165)\n* Added `__str__` method in `Metric` base class (#7487)\n* Made enhancements for 
testing files (#7662, #7670, #7663, #7671, #7672)\n* Improved documentation for bundles (#7116)\n### Fixed\n#### transforms\n* Addressed issue where lazy mode was ignored in `SpatialPadd` (#7316)\n* Tracked applied operations in `ImageFilter` (#7395)\n* Warnings are now given only if missing class is not set to 0 in `generate_label_classes_crop_centers` (#7602)\n* Input is now always converted to C-order in `distance_transform_edt` to ensure consistent behavior (#7675)\n#### data\n* Modified .npz file behavior to use keys in `NumpyReader` (#7148)\n* Handled corrupted cached files in `PersistentDataset` (#7244)\n* Corrected affine update in `NrrdReader` (#7415)\n#### metrics and losses\n* Addressed precision issue in `get_confusion_matrix` (#7187)\n* Harmonized and clarified documentation and tests for dice losses variants (#7587)\n#### networks\n* Removed hard-coded `spatial_dims` in `SwinTransformer` (#7302)\n* Fixed learnable `position_embeddings` in `PatchEmbeddingBlock` (#7564, #7605)\n* Removed `memory_pool_limit` in TRT config (#7647)\n* Propagated `kernel_size` to `ConvBlocks` within `AttentionUnet` (#7734)\n* Addressed hard-coded activation layer in `ResNet` (#7749)\n#### bundle\n* Resolved bundle download issue (#7280)\n* Updated `bundle_root` directory for `NNIGen` (#7586)\n* Checked for `num_fold` and failed early if incorrect (#7634)\n* Enhanced logging logic in `ConfigWorkflow` (#7745)\n#### misc.\n* Enabled chaining in `Auto3DSeg` CLI (#7168)\n* Addressed useless error message in `nnUNetV2Runner` (#7217)\n* Resolved typing and deprecation issues in Mypy (#7231)\n* Quoted `$PY_EXE` variable to handle Python path that contains spaces in Bash (#7268)\n* Improved documentation, code examples, and warning messages in various modules (#7234, #7213, #7271, #7326, #7569, #7584)\n* Fixed typos in various modules (#7321, #7322, #7458, #7595, #7612)\n* Enhanced docstrings in various modules (#7245, #7381, #7746)\n* Handled error when data is on CPU in 
`DataAnalyzer` (#7310)\n* Updated version requirements for third-party packages (#7343, #7344, #7384, #7448, #7659, #7704, #7744, #7742, #7780)\n* Addressed incorrect slice compute in `ImageStats` (#7374)\n* Avoided editing a loop's mutable iterable to address B308 (#7397)\n* Fixed issue with `CUDA_VISIBLE_DEVICES` setting being ignored (#7408, #7581)\n* Avoided changing Python version in CICD (#7424)\n* Renamed partial to callable in instantiate mode (#7413)\n* Imported AttributeError for Python 3.12 compatibility (#7482)\n* Updated `nnUNetV2Runner` to support nnunetv2 2.2 (#7483)\n* Used uint8 instead of int8 in `LabelStats` (#7489)\n* Utilized subprocess for nnUNet training (#7576)\n* Addressed deprecated warning in ruff (#7625)\n* Fixed downloading failure on FIPS machine (#7698)\n* Updated `torch_tensorrt` compile parameters to avoid warning (#7714)\n* Restrict `Auto3DSeg` fold input based on datalist (#7778)\n### Changed\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:24.03-py3` from `nvcr.io/nvidia/pytorch:23.08-py3`\n### Removed\n* Removed unrecommended star-arg unpacking after a keyword argument, addressed B026 (#7262)\n* Skipped old PyTorch version test for `SwinUNETR` (#7266)\n* Dropped docker build workflow and migrated to Nvidia Blossom system (#7450)\n* Dropped Python 3.8 test on quick-py3 workflow (#7719)\n\n## [1.3.0] - 2023-10-12\n### Added\n* Intensity transforms `ScaleIntensityFixedMean` and `RandScaleIntensityFixedMean` (#6542)\n* `UltrasoundConfidenceMapTransform` used for computing confidence map from an ultrasound image (#6709)\n* `channel_wise` support in `RandScaleIntensity` and `RandShiftIntensity` (#6793, #7025)\n* `RandSimulateLowResolution` and `RandSimulateLowResolutiond` (#6806)\n* `SignalFillEmptyd` (#7011)\n* Euclidean distance transform `DistanceTransformEDT` with GPU support (#6981)\n* Port loss and metrics from `monai-generative` (#6729, #6836)\n* Support `invert_image` and `retain_stats` in `AdjustContrast` and 
`RandAdjustContrast` (#6542)\n* New network `DAF3D` and `Quicknat` (#6306)\n* Support `sincos` position embedding (#6986)\n* `ZarrAvgMerger` used for patch inference (#6633)\n* Dataset tracking support to `MLFlowHandler` (#6616)\n* Considering spacing and subvoxel borders in `SurfaceDiceMetric` (#6681)\n* CUCIM support for surface-related metrics (#7008)\n* `loss_fn` support in `IgniteMetric` and renamed it to `IgniteMetricHandler` (#6695)\n* `CallableEventWithFilter` and `Events` options for `trigger_event` in `GarbageCollector` (#6663)\n* Support random sorting option to `GridPatch`, `RandGridPatch`, `GridPatchd` and `RandGridPatchd` (#6701)\n* Support multi-threaded batch sampling in `PatchInferer` (#6139)\n* `SoftclDiceLoss` and `SoftDiceclDiceLoss` (#6763)\n* `HausdorffDTLoss` and `LogHausdorffDTLoss` (#6994)\n* Documentation for `TensorFloat-32` (#6770)\n* Docstring format guide (#6780)\n* `GDSDataset` support for GDS (#6778)\n* PyTorch backend support for `MapLabelValue` (#6872)\n* `filter_func` in `copy_model_state` to filter the weights to be loaded  and `filter_swinunetr` (#6917)\n* `stats_sender` to `MonaiAlgo` for FL stats (#6984)\n* `freeze_layers` to help freeze specific layers (#6970)\n#### misc.\n* Refactor multi-node running command used in `Auto3DSeg` into dedicated functions (#6623)\n* Support str type annotation to `device` in `ToTensorD` (#6737)\n* Improve logging message and file name extension in `DataAnalyzer` for `Auto3DSeg` (#6758)\n* Set `data_range` as a property in `SSIMLoss` (#6788)\n* Unify environment variable access (#7084)\n* `end_lr` support in `WarmupCosineSchedule` (#6662)\n* Add `ClearML` as optional dependency (#6827)\n* `yandex.disk` support in `download_url` (#6667)\n* Improve config expression error message (#6977)\n### Fixed\n#### transforms\n* Make `convert_box_to_mask` throw errors when box size larger than the image (#6637)\n* Fix lazy mode in `RandAffine` (#6774)\n* Raise `ValueError` when `map_items` is bool in 
`Compose` (#6882)\n* Improve performance for `NormalizeIntensity` (#6887)\n* Fix mismatched shape in `Spacing` (#6912)\n* Avoid FutureWarning in `CropForeground` (#6934)\n* Fix `Lazy=True` ignored when using `Dataset` call (#6975)\n* Shape check for arbitrary types for DataStats (#7082)\n#### data\n* Fix wrong spacing checking logic in `PydicomReader` and broken link in `ITKReader` (#6660)\n* Fix boolean indexing of batched `MetaTensor` (#6781)\n* Raise warning when multiprocessing in `DataLoader` (#6830)\n* Remove `shuffle` in `DistributedWeightedRandomSampler` (#6886)\n* Fix missing `SegmentDescription` in `PydicomReader` (#6937)\n* Fix reading dicom series error in `ITKReader` (#6943)\n* Fix KeyError in `PydicomReader` (#6946)\n* Update `metatensor_to_itk_image` to accept RAS `MetaTensor` and update default 'space' in `NrrdReader` to `SpaceKeys.LPS` (#7000)\n* Collate common meta dictionary keys (#7054)\n#### metrics and losses\n* Fixed bug in `GeneralizedDiceLoss` when `batch=True` (#6775)\n* Support for `BCEWithLogitsLoss` in `DiceCELoss` (#6924)\n* Support for `weight` in Dice and related losses (#7098)\n#### networks\n* Use `np.prod` instead of `np.product` (#6639)\n* Fix dimension issue in `MBConvBlock` (#6672)\n* Fix hard-coded `up_kernel_size` in `ViTAutoEnc` (#6735)\n* Remove hard-coded `bias_downsample` in `resnet` (#6848)\n* Fix unused `kernel_size` in `ResBlock` (#6999)\n* Allow for defining reference grid on non-integer coordinates (#7032)\n* Padding option for autoencoder (#7068)\n* Lower peak memory usage for SegResNetDS (#7066)\n#### bundle\n* Set `train_dataset_data` and `dataset_data` to unrequired in BundleProperty (#6607)\n* Set `None` to properties that do not have `REF_ID` (#6607)\n* Fix `AttributeError` for default value in `get_parsed_content` for `ConfigParser` (#6756)\n* Update `monai.bundle.scripts` to support NGC hosting (#6828, #6997)\n* Add `MetaProperties` (#6835)\n* Add `create_workflow` and update `load` function (#6835)\n* Add 
bundle root directory to Python search directories automatically (#6910)\n* Generate properties for bundle docs automatically (#6918)\n* Move `download_large_files` from model zoo to core (#6958)\n* Bundle syntax `#` as alias of `::` (#6955)\n* Fix bundle download naming issue (#6969, #6963)\n* Simplify the usage of `ckpt_export` (#6965)\n* `update_kwargs` in `monai.bundle.script` for merging multiple configs (#7109)\n#### engines and handlers\n* Added int options for `iteration_log` and `epoch_log` in `TensorBoardStatsHandler` (#7027)\n* Support to run validator at training start (#7108)\n#### misc.\n* Fix device fallback error in `DataAnalyzer` (#6658)\n* Add int check for  `current_mode` in `convert_applied_interp_mode` (#6719)\n* Consistent type in `convert_to_contiguous` (#6849)\n* Label `argmax` in `DataAnalyzer` when retry on CPU (#6852)\n* Fix `DataAnalyzer` with `histogram_only=True` (#6874)\n* Fix `AttributeError` in `RankFilter` in single GPU environment (#6895)\n* Remove the default warning on `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE` and add debug print info (#6909)\n* Hide user information in `print_config` (#6913, #6922)\n* Optionally pass coordinates to predictor during sliding window (#6795)\n* Proper ensembling when trained with a sigmoid in `AutoRunner` (#6588)\n* Fixed `test_retinanet` by increasing absolute differences (#6615)\n* Add type check to avoid comparing a np.array with a string in `_check_kwargs_are_present` (#6624)\n* Fix md5 hashing with FIPS mode (#6635)\n* Capture failures from Auto3DSeg related subprocess calls (#6596)\n* Code formatting tool for user-specified directory (#7106)\n* Various docstring fixes\n### Changed\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:23.08-py3` from `nvcr.io/nvidia/pytorch:23.03-py3`\n### Deprecated\n* `allow_smaller=True`; `allow_smaller=False` will be the new default in `CropForeground` and `generate_spatial_bounding_box` (#6736)\n* `dropout_prob` in `VNet` in favor of `dropout_prob_down` and 
`dropout_prob_up` (#6768)\n* `workflow` in `BundleWorkflow` in favor of `workflow_type` (#6768)\n* `pos_embed` in `PatchEmbeddingBlock` in favor of `proj_type` (#6986)\n* `net_name` and `net_kwargs` in `download` in favor of `model` (#7016)\n* `img_size` parameter in SwinUNETR (#7093)\n### Removed\n* `pad_val`, `stride`, `per_channel` and `upsampler` in `OcclusionSensitivity` (#6642)\n* `compute_meaniou` (#7019)\n* `AsChannelFirst`, `AddChannel` and `SplitChannel` (#7019)\n* `create_multigpu_supervised_trainer` and `create_multigpu_supervised_evaluator` (#7019)\n* `runner_id` in `run` (#7019)\n* `data_src_cfg_filename` in `AlgoEnsembleBuilder` (#7019)\n* `get_validation_stats` in `Evaluator` and `get_train_stats` in `Trainer` (#7019)\n* `epoch_interval` and `iteration_interval` in `TensorBoardStatsHandler` (#7019)\n* some self-hosted test (#7041)\n\n## [1.2.0] - 2023-06-08\n### Added\n* Various Auto3DSeg enhancements and integration tests including multi-node multi-GPU optimization, major usability improvements\n* TensorRT and ONNX support for `monai.bundle` API and the relevant models\n* nnU-Net V2 integration `monai.apps.nnunet`\n* Binary and categorical metrics and event handlers using `MetricsReloaded`\n* Python module and CLI entry point for bundle workflows in `monai.bundle.workflows` and `monai.fl.client`\n* Modular patch inference API including `PatchInferer`, `merger`, and `splitter`\n* Initial release of lazy resampling including transforms and MetaTensor implementations\n* Bridge for ITK Image object and MetaTensor `monai.data.itk_torch_bridge`\n* Sliding window inference memory efficiency optimization including `SlidingWindowInfererAdapt`\n* Generic kernel filtering transforms `ImageFiltered` and `RandImageFiltered`\n* Trainable bilateral filters and joint bilateral filters\n* ClearML stats and image handlers for experiment tracking\n#### misc.\n* Utility functions to warn API default value changes (#5738)\n* Support of dot notation to access content of 
`ConfigParser` (#5813)\n* Softmax version to focal loss (#6544)\n* FROC metric for N-dimensional (#6528)\n* Extend SurfaceDiceMetric for 3D images (#6549)\n* A `track_meta` option for Lambda and derived transforms (#6385)\n* CLIP pre-trained text-to-vision embedding (#6282)\n* Optional spacing to surface distances calculations (#6144)\n* `WSIReader` read by power and mpp (#6244)\n* Support GPU tensor for `GridPatch` and `GridPatchDataset` (#6246)\n* `SomeOf` transform composer (#6143)\n* GridPatch with both count and threshold filtering (#6055)\n### Fixed\n#### transforms\n* `map_classes_to_indices` efficiency issue (#6468)\n* Adaptive resampling mode based on backends (#6429)\n* Improve Compose encapsulation (#6224)\n* User-provided `FolderLayout` in `SaveImage` and `SaveImaged` transforms (#6213)\n* `SpacingD` output shape compute stability (#6126)\n* No mutate ratio /user inputs `croppad` (#6127)\n* A `warn` flag to RandCropByLabelClasses (#6121)\n* `nan` to indicate `no_channel`, split dim singleton (#6090)\n* Compatible padding mode (#6076)\n* Allow for missing `filename_or_obj` key (#5980)\n* `Spacing` pixdim in-place change (#5950)\n* Add warning in `RandHistogramShift` (#5877)\n* Exclude `cuCIM` wrappers from `get_transform_backends` (#5838)\n#### data\n* `__format__` implementation of MetaTensor (#6523)\n* `channel_dim` in `TiffFileWSIReader` and `CuCIMWSIReader` (#6514)\n* Prepend `\"meta\"` to `MetaTensor.__repr__` and `MetaTensor.__str__` for easier identification (#6214)\n* MetaTensor slicing issue (#5845)\n* Default writer flags (#6147)\n* `WSIReader` defaults and tensor conversion (#6058)\n* Remove redundant array copy for WSITiffFileReader (#6089)\n* Fix unused arg in `SlidingPatchWSIDataset` (#6047)\n* `reverse_indexing` for PILReader (#6008)\n* Use `np.linalg` for the small affine inverse (#5967)\n#### metrics and losses\n* Removing L2-norm in contrastive loss (L2-norm already present in CosSim) (#6550)\n* Fixes the SSIM metric (#6250)\n* 
Efficiency issues of Dice metrics (#6412)\n* Generalized Dice issue (#5929)\n* Unify output tensor devices for multiple metrics (#5924)\n#### networks\n* Make `RetinaNet` throw errors for NaN only when training (#6479)\n* Replace deprecated arg in torchvision models (#6401)\n* Improves NVFuser import check (#6399)\n* Add `device` in `HoVerNetNuclearTypePostProcessing` and `HoVerNetInstanceMapPostProcessing` (#6333)\n* Enhance hovernet load pretrained function (#6269)\n* Access to the `att_mat` in self-attention modules (#6493)\n* Optional swinunetr-v2 (#6203)\n* Add transform to handle empty box as training data for `retinanet_detector` (#6170)\n* GPU utilization of DiNTS network (#6050)\n* A pixelshuffle upsample shape mismatch problem (#5982)\n* GEGLU activation function for the MLP Block (#5856)\n* Constructors for `DenseNet` derived classes (#5846)\n* Flexible interpolation modes in `regunet` (#5807)\n#### bundle\n* Optimized the `deepcopy` logic in `ConfigParser` (#6464)\n* Improve check and error message of bundle run (#6400)\n* Warn or raise ValueError on duplicated key in json/yaml config (#6252)\n* Default metadata and logging values for bundle run (#6072)\n* `pprint` head and tail in bundle script (#5969)\n* Config parsing issue for substring reference (#5932)\n* Fix instantiate for object instantiation with attribute `path` (#5866)\n* Fix `_get_latest_bundle_version` issue on Windows (#5787)\n#### engines and handlers\n* MLflow handler run bug (#6446)\n* `monai.engine` training attribute check (#6132)\n* Update StatsHandler logging message (#6051)\n* Added callable options for `iteration_log` and `epoch_log` in TensorBoard and MLFlow (#5976)\n* `CheckpointSaver` logging error (#6026)\n* Callable options for `iteration_log` and `epoch_log` in StatsHandler (#5965)\n#### misc.\n* Avoid creating cufile.log when `import monai` (#6106)\n* `monai._extensions` module compatibility with rocm (#6161)\n* Issue of repeated UserWarning: \"TypedStorage is deprecated\" 
(#6105)\n* Use logging config at module level (#5960)\n* Add ITK to the list of optional dependencies (#5858)\n* `RankFilter` to skip logging when the rank is not meeting criteria (#6243)\n* Various documentation issues\n### Changed\n* Overall more precise and consistent type annotations\n* Optionally depend on PyTorch-Ignite v0.4.11 instead of v0.4.10\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:23.03-py3` from `nvcr.io/nvidia/pytorch:22.10-py3`\n### Deprecated\n* `resample=True`; `resample=False` will be the new default in `SaveImage`\n* `random_size=True`; `random_size=False` will be the new default for the random cropping transforms\n* `image_only=False`; `image_only=True` will be the new default in `LoadImage`\n* `AddChannel` and `AsChannelFirst` in favor of `EnsureChannelFirst`\n### Removed\n* Deprecated APIs since v0.9, including WSIReader from `monai.apps`, `NiftiSaver` and `PNGSaver` from `monai.data`\n* Support for PyTorch 1.8\n* Support for Python 3.7\n\n## [1.1.0] - 2022-12-19\n### Added\n* Hover-Net based digital pathology workflows including new network, loss, postprocessing, metric, training, and inference modules\n* Various enhancements for Auto3dSeg `AutoRunner` including template caching, selection, and a dry-run mode `nni_dry_run`\n* Various enhancements for Auto3dSeg algo templates including new state-of-the-art configurations, optimized GPU memory utilization\n* New bundle API and configurations to support experiment management including `MLFlowHandler`\n* New `bundle.script` API to support model zoo query and download\n* `LossMetric` metric to compute loss as cumulative metric measurement\n* Transforms and base transform APIs including `RandomizableTrait` and `MedianSmooth`\n* `runtime_cache` option for `CacheDataset` and the derived classes to allow for shared caching on the fly\n* Flexible name formatter for `SaveImage` transform\n* `pending_operations` MetaTensor property and basic APIs for lazy image resampling\n* Contrastive 
sensitivity for SSIM metric\n* Extensible backbones for `FlexibleUNet`\n* Generalize `SobelGradients` to 3D and any spatial axes\n* `warmup_multiplier` option for `WarmupCosineSchedule`\n* F beta score metric based on confusion matrix metric\n* Support of key overwriting in `Lambdad`\n* Basic premerge tests for Python 3.11\n* Unit and integration tests for CUDA 11.6, 11.7 and A100 GPU\n* `DataAnalyzer` handles minor image-label shape inconsistencies\n### Fixed\n* Review and enhance previously untyped APIs with additional type annotations and casts\n* `switch_endianness` in LoadImage now supports tensor input\n* Reduced memory footprint for various Auto3dSeg tests\n* Issue of `@` in `monai.bundle.ReferenceResolver`\n* Compatibility issue with ITK-Python 5.3 (converting `itkMatrixF44` for default collate)\n* Inconsistency of sform and qform when using different backends for `SaveImage`\n* `MetaTensor.shape` call now returns a `torch.Size` instead of tuple\n* Issue of channel reduction in `GeneralizedDiceLoss`\n* Issue of background handling before softmax in `DiceFocalLoss`\n* Numerical issue of `LocalNormalizedCrossCorrelationLoss`\n* Issue of incompatible view size in `ConfusionMatrixMetric`\n* `NetAdapter` compatibility with Torchscript\n* Issue of `extract_levels` in `RegUNet`\n* Optional `bias_downsample` in `ResNet`\n* `dtype` overflow for `ShiftIntensity` transform\n* Randomized transforms such as `RandCuCIM` now inherit `RandomizableTrait`\n* `fg_indices.size` compatibility issue in `generate_pos_neg_label_crop_centers`\n* Issue when inverting `ToTensor`\n* Issue of capital letters in filename suffixes check in `LoadImage`\n* Minor tensor compatibility issues in `apps.nuclick.transforms`\n* Issue of float16 in `verify_net_in_out`\n* `std` variable type issue for `RandRicianNoise`\n* `DataAnalyzer` accepts `None` as label key and checks empty labels\n* `iter_patch_position` now has a smaller memory footprint\n* `CumulativeAverage` has been refactored and 
enhanced to allow for simple tracking of metric running stats.\n* Multi-threading issue for `MLFlowHandler`\n### Changed\n* Printing a MetaTensor now generates a less verbose representation\n* `DistributedSampler` raises a ValueError if there are too few devices\n* OpenCV and `VideoDataset` modules are loaded lazily to avoid dependency issues\n* `device` in `monai.engines.Workflow` supports string values\n* `Activations` and `AsDiscrete` take `kwargs` as additional arguments\n* `DataAnalyzer` is now more efficient and writes summary stats before detailed all case stats\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:22.10-py3` from `nvcr.io/nvidia/pytorch:22.09-py3`\n* Simplified Conda environment file `environment-dev.yml`\n* Versioneer dependency upgraded to `0.23` from `0.19`\n### Deprecated\n* `NibabelReader` input argument `dtype` is deprecated, the reader will use the original dtype of the image\n### Removed\n* Support for PyTorch 1.7\n\n## [1.0.1] - 2022-10-24\n### Fixes\n* DiceCELoss for multichannel targets\n* Auto3DSeg DataAnalyzer out-of-memory error and other minor issues\n* An optional flag issue in the RetinaNet detector\n* An issue with output offset for Spacing\n* A `LoadImage` issue when `track_meta` is `False`\n* 1D data output error in `VarAutoEncoder`\n* An issue with resolution computing in `ImageStats`\n### Added\n* Flexible min/max pixdim options for Spacing\n* Upsample mode `deconvgroup` and optional kernel sizes\n* Docstrings for gradient-based saliency maps\n* Occlusion sensitivity to use sliding window inference\n* Enhanced Gaussian window and device assignments for sliding window inference\n* Multi-GPU support for MonaiAlgo\n* `ClientAlgoStats` and `MonaiAlgoStats` for federated summary statistics\n* MetaTensor support for `OneOf`\n* Add a file check for bundle logging config\n* Additional content and an authentication token option for bundle info API\n* An anti-aliasing option for `Resized`\n* `SlidingWindowInferer` adaptive 
device based on `cpu_thresh`\n* `SegResNetDS` with deep supervision and non-isotropic kernel support\n* Premerge tests for Python 3.10\n### Changed\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:22.09-py3` from `nvcr.io/nvidia/pytorch:22.08-py3`\n* Replace `None` type metadata content with `\"none\"` for `collate_fn` compatibility\n* HoVerNet Mode and Branch to independent StrEnum\n* Automatically infer device from the first item in random elastic deformation dict\n* Add channel dim in `ComputeHoVerMaps` and `ComputeHoVerMapsd`\n* Remove batch dim in `SobelGradients` and `SobelGradientsd`\n### Deprecated\n* Deprecating `compute_meandice`, `compute_meaniou` in `monai.metrics`, in favor of\n`compute_dice` and `compute_iou` respectively\n\n## [1.0.0] - 2022-09-16\n### Added\n* `monai.auto3dseg` base APIs and `monai.apps.auto3dseg` components for automated machine learning (AutoML) workflow\n* `monai.fl` module with base APIs and `MonaiAlgo` for federated learning client workflow\n* An initial backwards compatibility [guide](https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md#backwards-compatibility)\n* Initial release of accelerated MRI reconstruction components, including `CoilSensitivityModel`\n* Support of `MetaTensor` and new metadata attributes for various digital pathology components\n* Various `monai.bundle` enhancements for MONAI model-zoo usability, including config debug mode and `get_all_bundles_list`\n* new `monai.transforms` components including `SignalContinuousWavelet` for 1D signal, `ComputeHoVerMaps` for digital pathology, and `SobelGradients` for spatial gradients\n* `VarianceMetric` and `LabelQualityScore` metrics for active learning\n* Dataset API for real-time stream and videos\n* Several networks and building blocks including `FlexibleUNet` and `HoVerNet`\n* `MeanIoUHandler` and `LogfileHandler` workflow event handlers\n* `WSIReader` with the TiffFile backend\n* Multi-threading in `WSIReader` with cuCIM backend\n* 
`get_stats` API in `monai.engines.Workflow`\n* `prune_meta_pattern` in `monai.transforms.LoadImage`\n* `max_interactions` for deepedit interaction workflow\n* Various profiling utilities in `monai.utils.profiling`\n### Changed\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:22.08-py3` from `nvcr.io/nvidia/pytorch:22.06-py3`\n* Optionally depend on PyTorch-Ignite v0.4.10 instead of v0.4.9\n* The cache-based dataset now matches the transform information when read/write the cache\n* `monai.losses.ContrastiveLoss` now infers `batch_size` during `forward()`\n* Rearrange the spatial axes in `RandSmoothDeform` transforms following PyTorch's convention\n* Unified several environment flags into `monai.utils.misc.MONAIEnvVars`\n* Simplified `__str__` implementation of `MetaTensor` instead of relying on the `__repr__` implementation\n### Fixed\n* Improved error messages when both `monai` and `monai-weekly` are pip-installed\n* Inconsistent pseudo number sequences for different `num_workers` in `DataLoader`\n* Issue of repeated sequences for `monai.data.ShuffleBuffer`\n* Issue of not preserving the physical extent in `monai.transforms.Spacing`\n* Issue of using `inception_v3` as the backbone of `monai.networks.nets.TorchVisionFCModel`\n* Index device issue for `monai.transforms.Crop`\n* Efficiency issue when converting the array dtype and contiguous memory\n### Deprecated\n* `AddChannel` and `AsChannelFirst` transforms in favor of `EnsureChannelFirst`\n* `monai.apps.pathology.data` components in favor of the corresponding components from `monai.data`\n* `monai.apps.pathology.handlers` in favor of the corresponding components from `monai.handlers`\n### Removed\n* `Status` section in the pull request template in favor of the pull request draft mode\n* `monai.engines.BaseWorkflow`\n* `ndim` and `dimensions` arguments in favor of `spatial_dims`\n* `n_classes`, `num_classes` arguments in `AsDiscrete` in favor of `to_onehot`\n* `logit_thresh`, `threshold_values` arguments 
in `AsDiscrete` in favor of `threshold`\n* `torch.testing.assert_allclose` in favor of `tests.utils.assert_allclose`\n\n## [0.9.1] - 2022-07-22\n### Added\n* Support of `monai.data.MetaTensor` as core data structure across the modules\n* Support of `inverse` in array-based transforms\n* `monai.apps.TciaDataset` APIs for The Cancer Imaging Archive (TCIA) datasets, including a pydicom-backend reader\n* Initial release of components for MRI reconstruction in `monai.apps.reconstruction`, including various FFT utilities\n* New metrics and losses, including mean IoU and structural similarity index\n* `monai.utils.StrEnum` class to simplify Enum-based type annotations\n### Changed\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:22.06-py3` from `nvcr.io/nvidia/pytorch:22.04-py3`\n* Optionally depend on PyTorch-Ignite v0.4.9 instead of v0.4.8\n### Fixed\n* Fixed issue of not skipping post activations in `Convolution` when input arguments are None\n* Fixed issue of ignoring dropout arguments in `DynUNet`\n* Fixed issue of hard-coded non-linear function in ViT classification head\n* Fixed issue of in-memory config overriding with `monai.bundle.ConfigParser.update`\n* 2D SwinUNETR incompatible shapes\n* Fixed issue with `monai.bundle.verify_metadata` not raising exceptions\n* Fixed issue with `monai.transforms.GridPatch` returns inconsistent type location when padding\n* Wrong generalized Dice score metric when denominator is 0 but prediction is non-empty\n* Docker image build error due to NGC CLI upgrade\n* Optional default value when parsing id unavailable in a ConfigParser instance\n* Immutable data input for the patch-based WSI datasets\n### Deprecated\n* `*_transforms` and `*_meta_dict` fields in dictionary-based transforms in favor of MetaTensor\n* `meta_keys`, `meta_key_postfix`, `src_affine` arguments in various transforms, in favor of MetaTensor\n* `AsChannelFirst` and `AddChannel`, in favor of `EnsureChannelFirst` transform\n\n## [0.9.0] - 2022-06-08\n### 
Added\n* `monai.bundle` primary module with a `ConfigParser` and command-line interfaces for configuration-based workflows\n* Initial release of MONAI bundle specification\n* Initial release of volumetric image detection modules including bounding boxes handling, RetinaNet-based architectures\n* API preview `monai.data.MetaTensor`\n* Unified `monai.data.image_writer` to support flexible IO backends including an ITK writer\n* Various new network blocks and architectures including `SwinUNETR`\n* DeepEdit interactive training/validation workflow\n* NuClick interactive segmentation transforms\n* Patch-based readers and datasets for whole-slide imaging\n* New losses and metrics including `SurfaceDiceMetric`, `GeneralizedDiceFocalLoss`\n* New pre-processing transforms including `RandIntensityRemap`, `SpatialResample`\n* Multi-output and slice-based inference for `SlidingWindowInferer`\n* `NrrdReader` for NRRD file support\n* Torchscript utilities to save models with meta information\n* Gradient-based visualization module `SmoothGrad`\n* Automatic regular source code scanning for common vulnerabilities and coding errors\n\n### Changed\n* Simplified `TestTimeAugmentation` using de-collate and invertible transforms APIs\n* Refactoring `monai.apps.pathology` modules into `monai.handlers` and `monai.transforms`\n* Flexible activation and normalization layers for `TopologySearch` and `DiNTS`\n* Anisotropic first layers for 3D resnet\n* Flexible ordering of activation, normalization in `UNet`\n* Enhanced performance of connected-components analysis using Cupy\n* `INSTANCE_NVFUSER` for enhanced performance in 3D instance norm\n* Support of string representation of dtype in `convert_data_type`\n* Added new options `iteration_log`, `epoch_log` to the logging handlers\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:22.04-py3` from `nvcr.io/nvidia/pytorch:21.10-py3`\n* `collate_fn` generates more data-related debugging info with `dev_collate`\n\n### Fixed\n* Unified the 
spellings of \"meta data\", \"metadata\", \"meta-data\" to \"metadata\"\n* Various inaccurate error messages when input data are in invalid shapes\n* Issue of computing symmetric distances in `compute_average_surface_distance`\n* Unnecessary layer  `self.conv3` in `UnetResBlock`\n* Issue of torchscript compatibility for `ViT` and self-attention blocks\n* Issue of hidden layers in `UNETR`\n* `allow_smaller` in spatial cropping transforms\n* Antialiasing in `Resize`\n* Issue of bending energy loss value at different resolutions\n* `kwargs_read_csv` in `CSVDataset`\n* In-place modification in `Metric` reduction\n* `wrap_array` for `ensure_tuple`\n* Contribution guide for introducing new third-party dependencies\n\n### Removed\n* Deprecated `nifti_writer`, `png_writer` in favor of `monai.data.image_writer`\n* Support for PyTorch 1.6\n\n## [0.8.1] - 2022-02-16\n### Added\n* Support of `matshow3d` with given `channel_dim`\n* Support of spatial 2D for `ViTAutoEnc`\n* Support of `dataframe` object input in `CSVDataset`\n* Support of tensor backend for `Orientation`\n* Support of configurable delimiter for CSV writers\n* A base workflow API\n* `DataFunc` API for dataset-level preprocessing\n* `write_scalar` API for logging with additional `engine` parameter in `TensorBoardHandler`\n* Enhancements for NVTX Range transform logging\n* Enhancements for `set_determinism`\n* Performance enhancements in the cache-based datasets\n* Configurable metadata keys for `monai.data.DatasetSummary`\n* Flexible `kwargs` for `WSIReader`\n* Logging for the learning rate schedule handler\n* `GridPatchDataset` as subclass of `monai.data.IterableDataset`\n* `is_onehot` option in `KeepLargestConnectedComponent`\n* `channel_dim` in the image readers and support of stacking images with channels\n* Skipping workflow `run` if epoch length is 0\n* Enhanced `CacheDataset` to avoid duplicated cache items\n* `save_state` utility function\n\n### Changed\n* Optionally depend on PyTorch-Ignite v0.4.8 instead 
of v0.4.6\n* `monai.apps.mmars.load_from_mmar` defaults to the latest version\n\n### Fixed\n* Issue when caching large items with `pickle`\n* Issue of hard-coded activation functions in `ResBlock`\n* Issue of `create_file_name` assuming local disk file creation\n* Issue of `WSIReader` when the backend is `TiffFile`\n* Issue of `deprecated_args` when the function signature contains kwargs\n* Issue of `channel_wise` computations for the intensity-based transforms\n* Issue of inverting `OneOf`\n* Issue of removing temporary caching file for the persistent dataset\n* Error messages when reader backend is not available\n* Output type casting issue in `ScaleIntensityRangePercentiles`\n* Various docstring typos and broken URLs\n* `mode` in the evaluator engine\n* Ordering of `Orientation` and `Spacing` in `monai.apps.deepgrow.dataset`\n\n### Removed\n* Additional deep supervision modules in `DynUnet`\n* Deprecated `reduction` argument for `ContrastiveLoss`\n* Decollate warning in `Workflow`\n* Unique label exception in `ROCAUCMetric`\n* Logger configuration logic in the event handlers\n\n## [0.8.0] - 2021-11-25\n### Added\n* Overview of [new features in v0.8](docs/source/whatsnew_0_8.md)\n* Network modules for differentiable neural network topology search (DiNTS)\n* Multiple Instance Learning transforms and models for digital pathology WSI analysis\n* Vision transformers for self-supervised representation learning\n* Contrastive loss for self-supervised learning\n* Finalized major improvements of 200+ components in `monai.transforms` to support input and backend in PyTorch and NumPy\n* Initial registration module benchmarking with `GlobalMutualInformationLoss` as an example\n* `monai.transforms` documentation with visual examples and the utility functions\n* Event handler for `MLflow` integration\n* Enhanced data visualization functions including `blend_images` and `matshow3d`\n* `RandGridDistortion` and `SmoothField` in `monai.transforms`\n* Support of randomized shuffle 
buffer in iterable datasets\n* Performance review and enhancements for data type casting\n* Cumulative averaging API with distributed environment support\n* Module utility functions including `require_pkg` and `pytorch_after`\n* Various usability enhancements such as `allow_smaller` when sampling ROI and `wrap_sequence` when casting object types\n* `tifffile` support in `WSIReader`\n* Regression tests for the fast training workflows\n* Various tutorials and demos including educational contents at [MONAI Bootcamp 2021](https://github.com/Project-MONAI/MONAIBootcamp2021)\n### Changed\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.10-py3` from `nvcr.io/nvidia/pytorch:21.08-py3`\n* Decoupled `TraceKeys` and `TraceableTransform` APIs from `InvertibleTransform`\n* Skipping affine-based resampling when `resample=False` in `NiftiSaver`\n* Deprecated `threshold_values: bool` and `num_classes: int` in `AsDiscrete`\n* Enhanced `apply_filter` for spatially 1D, 2D and 3D inputs with non-separable kernels\n* Logging with `logging` in downloading and model archives in `monai.apps`\n* API documentation site now defaults to `stable` instead of `latest`\n* `skip-magic-trailing-comma` in coding style enforcements\n* Pre-merge CI pipelines now include unit tests with Nvidia Ampere architecture\n### Removed\n* Support for PyTorch 1.5\n* The deprecated `DynUnetV1` and the related network blocks\n* GitHub self-hosted CI/CD pipelines for package releases\n### Fixed\n* Support of path-like objects as file path inputs in most modules\n* Issue of `decollate_batch` for dictionary of empty lists\n* Typos in documentation and code examples in various modules\n* Issue of no available keys when `allow_missing_keys=True` for the `MapTransform`\n* Issue of redundant computation when normalization factors are 0.0 and 1.0 in `ScaleIntensity`\n* Incorrect reports of registered readers in `ImageReader`\n* Wrong numbering of iterations in `StatsHandler`\n* Naming conflicts in network 
modules and aliases\n* Incorrect output shape when `reduction=\"none\"` in `FocalLoss`\n* Various usability issues reported by users\n\n## [0.7.0] - 2021-09-24\n### Added\n* Overview of [new features in v0.7](docs/source/whatsnew_0_7.md)\n* Initial phase of major usability improvements in `monai.transforms` to support input and backend in PyTorch and NumPy\n* Performance enhancements, with [profiling and tuning guides](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md) for typical use cases\n* Reproducing [training modules and workflows](https://github.com/Project-MONAI/tutorials/tree/master/kaggle/RANZCR/4th_place_solution) of state-of-the-art Kaggle competition solutions\n* 24 new transforms, including\n  * `OneOf` meta transform\n  * DeepEdit guidance signal transforms for interactive segmentation\n  * Transforms for self-supervised pre-training\n  * Integration of [NVIDIA Tools Extension](https://developer.nvidia.com/blog/nvidia-tools-extension-api-nvtx-annotation-tool-for-profiling-code-in-python-and-c-c/) (NVTX)\n  * Integration of [cuCIM](https://github.com/rapidsai/cucim)\n  * Stain normalization and contextual grid for digital pathology\n* `Transchex` network for vision-language transformers for chest X-ray analysis\n* `DatasetSummary` utility in `monai.data`\n* `WarmupCosineSchedule`\n* Deprecation warnings and documentation support for better backwards compatibility\n* Padding with additional `kwargs` and different backend API\n* Additional options such as `dropout` and `norm` in various networks and their submodules\n\n### Changed\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.08-py3` from `nvcr.io/nvidia/pytorch:21.06-py3`\n* Deprecated input argument `n_classes`, in favor of `num_classes`\n* Deprecated input argument `dimensions` and `ndims`, in favor of `spatial_dims`\n* Updated the Sphinx-based documentation theme for better readability\n* `NdarrayTensor` type is replaced by 
`NdarrayOrTensor` for simpler annotations\n* Self-attention-based network blocks now support both 2D and 3D inputs\n\n### Removed\n* The deprecated `TransformInverter`, in favor of `monai.transforms.InvertD`\n* GitHub self-hosted CI/CD pipelines for nightly and post-merge tests\n* `monai.handlers.utils.evenly_divisible_all_gather`\n* `monai.handlers.utils.string_list_all_gather`\n\n### Fixed\n* A Multi-thread cache writing issue in `LMDBDataset`\n* Output shape convention inconsistencies of the image readers\n* Output directory and file name flexibility issue for `NiftiSaver`, `PNGSaver`\n* Requirement of the `label` field in test-time augmentation\n* Input argument flexibility issues for  `ThreadDataLoader`\n* Decoupled `Dice` and `CrossEntropy` intermediate results in `DiceCELoss`\n* Improved documentation, code examples, and warning messages in various modules\n* Various usability issues reported by users\n\n## [0.6.0] - 2021-07-08\n### Added\n* 10 new transforms, a masked loss wrapper, and a `NetAdapter` for transfer learning\n* APIs to load networks and pre-trained weights from Clara Train [Medical Model ARchives (MMARs)](https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html)\n* Base metric and cumulative metric APIs, 4 new regression metrics\n* Initial CSV dataset support\n* Decollating mini-batch as the default first postprocessing step, [Migrating your v0.5 code to v0.6](https://github.com/Project-MONAI/MONAI/wiki/v0.5-to-v0.6-migration-guide) wiki shows how to adapt to the breaking changes\n* Initial backward compatibility support via `monai.utils.deprecated`\n* Attention-based vision modules and `UNETR` for segmentation\n* Generic module loaders and Gaussian mixture models using the PyTorch JIT compilation\n* Inverse of image patch sampling transforms\n* Network block utilities `get_[norm, act, dropout, pool]_layer`\n* `unpack_items` mode for `apply_transform` and `Compose`\n* New event `INNER_ITERATION_STARTED` in the deepgrow interactive 
workflow\n* `set_data` API for cache-based datasets to dynamically update the dataset content\n* Fully compatible with PyTorch 1.9\n* `--disttests` and `--min` options for `runtests.sh`\n* Initial support of pre-merge tests with Nvidia Blossom system\n\n### Changed\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.06-py3` from\n  `nvcr.io/nvidia/pytorch:21.04-py3`\n* Optionally depend on PyTorch-Ignite v0.4.5 instead of v0.4.4\n* Unified the demo, tutorial, testing data to the project shared drive, and\n  [`Project-MONAI/MONAI-extra-test-data`](https://github.com/Project-MONAI/MONAI-extra-test-data)\n* Unified the terms: `post_transform` is renamed to `postprocessing`, `pre_transform` is renamed to `preprocessing`\n* Unified the postprocessing transforms and event handlers to accept the \"channel-first\" data format\n* `evenly_divisible_all_gather` and `string_list_all_gather` moved to `monai.utils.dist`\n\n### Removed\n* Support of 'batched' input for postprocessing transforms and event handlers\n* `TorchVisionFullyConvModel`\n* `set_visible_devices` utility function\n* `SegmentationSaver` and `TransformsInverter` handlers\n\n### Fixed\n* Issue of handling big-endian image headers\n* Multi-thread issue for non-random transforms in the cache-based datasets\n* Persistent dataset issue when multiple processes share a non-existent cache location\n* Typing issue with Numpy 1.21.0\n* Loading checkpoint with both `model` and `optimizer` using `CheckpointLoader` when `strict_shape=False`\n* `SplitChannel` has different behaviour depending on numpy/torch inputs\n* Transform pickling issue caused by the Lambda functions\n* Issue of filtering by name in `generate_param_groups`\n* Inconsistencies in the return value types of `class_activation_maps`\n* Various docstring typos\n* Various usability enhancements in `monai.transforms`\n\n## [0.5.3] - 2021-05-28\n### Changed\n* Project default branch renamed to `dev` from `master`\n* Base Docker image upgraded to 
`nvcr.io/nvidia/pytorch:21.04-py3` from `nvcr.io/nvidia/pytorch:21.02-py3`\n* Enhanced type checks for the `iteration_metric` handler\n* Enhanced `PersistentDataset` to use `tempfile` during caching computation\n* Enhanced various info/error messages\n* Enhanced performance of `RandAffine`\n* Enhanced performance of `SmartCacheDataset`\n* Optionally requires `cucim` when the platform is `Linux`\n* Default `device` of `TestTimeAugmentation` changed to `cpu`\n\n### Fixed\n* Download utilities now provide better default parameters\n* Duplicated `key_transforms` in the patch-based transforms\n* A multi-GPU issue in `ClassificationSaver`\n* A default `meta_data` issue in `SpacingD`\n* Dataset caching issue with the persistent data loader workers\n* A memory issue in `permutohedral_cuda`\n* Dictionary key issue in `CopyItemsd`\n* `box_start` and `box_end` parameters for deepgrow `SpatialCropForegroundd`\n* Tissue mask array transpose issue in `MaskedInferenceWSIDataset`\n* Various type hint errors\n* Various docstring typos\n\n### Added\n* Support of `to_tensor` and `device` arguments for `TransformInverter`\n* Slicing options with SpatialCrop\n* Class name alias for the networks for backward compatibility\n* `k_divisible` option for CropForeground\n* `map_items` option for `Compose`\n* Warnings of `inf` and `nan` for surface distance computation\n* A `print_log` flag to the image savers\n* Basic testing pipelines for Python 3.9\n\n## [0.5.0] - 2021-04-09\n### Added\n* Overview document for [feature highlights in v0.5.0](https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md)\n* Invertible spatial transforms\n  * `InvertibleTransform` base APIs\n  * Batch inverse and decollating APIs\n  * Inverse of `Compose`\n  * Batch inverse event handling\n  * Test-time augmentation as an application\n* Initial support of learning-based image registration:\n  * Bending energy, LNCC, and global mutual information loss\n  * Fully convolutional architectures\n  * 
Dense displacement field, dense velocity field computation\n  * Warping with high-order interpolation with C++/CUDA implementations\n* Deepgrow modules for interactive segmentation:\n  * Workflows with simulations of clicks\n  * Distance-based transforms for guidance signals\n* Digital pathology support:\n  * Efficient whole slide imaging IO and sampling with Nvidia cuCIM and SmartCache\n  * FROC measurements for lesion\n  * Probabilistic post-processing for lesion detection\n  * TorchVision classification model adaptor for fully convolutional analysis\n* 12 new transforms, grid patch dataset, `ThreadDataLoader`, EfficientNets B0-B7\n* 4 iteration events for the engine for finer control of workflows\n* New C++/CUDA extensions:\n  * Conditional random field\n  * Fast bilateral filtering using the permutohedral lattice\n* Metrics summary reporting and saving APIs\n* DiceCELoss, DiceFocalLoss, a multi-scale wrapper for segmentation loss computation\n* Data loading utilities:\n  * `decollate_batch`\n  * `PadListDataCollate` with inverse support\n* Support of slicing syntax for `Dataset`\n* Initial Torchscript support for the loss modules\n* Learning rate finder\n* Allow for missing keys in the dictionary-based transforms\n* Support of checkpoint loading for transfer learning\n* Various summary and plotting utilities for Jupyter notebooks\n* Contributor Covenant Code of Conduct\n* Major CI/CD enhancements covering the tutorial repository\n* Fully compatible with PyTorch 1.8\n* Initial nightly CI/CD pipelines using Nvidia Blossom Infrastructure\n\n### Changed\n* Enhanced `list_data_collate` error handling\n* Unified iteration metric APIs\n* `densenet*` extensions are renamed to `DenseNet*`\n* `se_res*` network extensions are renamed to `SERes*`\n* Transform base APIs are rearranged into `compose`, `inverse`, and `transform`\n* `_do_transform` flag for the random augmentations is unified via `RandomizableTransform`\n* Decoupled post-processing steps, e.g. 
`softmax`, `to_onehot_y`, from the metrics computations\n* Moved the distributed samplers to `monai.data.samplers` from `monai.data.utils`\n* Engine's data loaders now accept generic iterables as input\n* Workflows now accept additional custom events and state properties\n* Various type hints according to Numpy 1.20\n* Refactored testing utility `runtests.sh` to have `--unittest` and `--net` (integration tests) options\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:21.02-py3` from `nvcr.io/nvidia/pytorch:20.10-py3`\n* Docker images are now built with self-hosted environments\n* Primary contact email updated to `monai.contact@gmail.com`\n* Now using GitHub Discussions as the primary communication forum\n\n### Removed\n* Compatibility tests for PyTorch 1.5.x\n* Format specific loaders, e.g. `LoadNifti`, `NiftiDataset`\n* Assert statements from non-test files\n* `from module import *` statements, addressed flake8 F403\n\n### Fixed\n* Uses American English spelling for code, as per PyTorch\n* Code coverage now takes multiprocessing runs into account\n* SmartCache with initial shuffling\n* `ConvertToMultiChannelBasedOnBratsClasses` now supports channel-first inputs\n* Checkpoint handler to save with non-root permissions\n* Fixed an issue for exiting the distributed unit tests\n* Unified `DynUNet` to have single tensor output w/o deep supervision\n* `SegmentationSaver` now supports user-specified data types and a `squeeze_end_dims` flag\n* Fixed `*Saver` event handlers output filenames with a `data_root_dir` option\n* Load image functions now ensure little-endian\n* Fixed the test runner to support regex-based test case matching\n* Usability issues in the event handlers\n\n## [0.4.0] - 2020-12-15\n### Added\n* Overview document for [feature highlights in v0.4.0](https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md)\n* Torchscript support for the net modules\n* New networks and layers:\n  * Discrete Gaussian kernels\n  * Hilbert transform 
and envelope detection\n  * Swish and mish activation\n  * Acti-norm-dropout block\n  * Upsampling layer\n  * Autoencoder, Variational autoencoder\n  * FCNet\n* Support of initialisation from pretrained weights for densenet, senet, multichannel AHNet\n* Layer-wise learning rate API\n* New model metrics and event handlers based on occlusion sensitivity, confusion matrix, surface distance\n* CAM/GradCAM/GradCAM++\n* File format-agnostic image loader APIs with Nibabel, ITK readers\n* Enhancements for dataset partition, cross-validation APIs\n* New data APIs:\n  * LMDB-based caching dataset\n  * Cache-N-transforms dataset\n  * Iterable dataset\n  * Patch dataset\n* Weekly PyPI release\n* Fully compatible with PyTorch 1.7\n* CI/CD enhancements:\n  * Skipping, speed up, fail fast, timed, quick tests\n  * Distributed training tests\n  * Performance profiling utilities\n* New tutorials and demos:\n  * Autoencoder, VAE tutorial\n  * Cross-validation demo\n  * Model interpretability tutorial\n  * COVID-19 Lung CT segmentation challenge open-source baseline\n  * Threadbuffer demo\n  * Dataset partitioning tutorial\n  * Layer-wise learning rate demo\n  * [MONAI Bootcamp 2020](https://github.com/Project-MONAI/MONAIBootcamp2020)\n\n### Changed\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:20.10-py3` from `nvcr.io/nvidia/pytorch:20.08-py3`\n\n#### Backwards Incompatible Changes\n* `monai.apps.CVDecathlonDataset` is extended to a generic `monai.apps.CrossValidation` with a `dataset_cls` option\n* Cache dataset now requires a `monai.transforms.Compose` instance as the transform argument\n* Model checkpoint file name extensions changed from `.pth` to `.pt`\n* Readers' `get_spatial_shape` returns a numpy array instead of list\n* Decoupled postprocessing steps such as `sigmoid`, `to_onehot_y`, `mutually_exclusive`, `logit_thresh` from metrics and event handlers,\nthe postprocessing steps should be used before calling the metrics methods\n* `ConfusionMatrixMetric` and 
`DiceMetric` computation now returns an additional `not_nans` flag to indicate valid results\n* `UpSample` optional `mode` now supports `\"deconv\"`, `\"nontrainable\"`, `\"pixelshuffle\"`; `interp_mode` is only used when `mode` is `\"nontrainable\"`\n* `SegResNet` optional `upsample_mode` now supports `\"deconv\"`, `\"nontrainable\"`, `\"pixelshuffle\"`\n* `monai.transforms.Compose` class inherits `monai.transforms.Transform`\n* In `Rotate`, `Rotated`, `RandRotate`, `RandRotated`  transforms, the `angle` related parameters are interpreted as angles in radians instead of degrees.\n* `SplitChannel` and `SplitChanneld` moved from `transforms.post` to `transforms.utility`\n\n### Removed\n* Support of PyTorch 1.4\n\n### Fixed\n* Enhanced loss functions for stability and flexibility\n* Sliding window inference memory and device issues\n* Revised transforms:\n  * Normalize intensity datatype and normalizer types\n  * Padding modes for zoom\n  * Crop returns coordinates\n  * Select items transform\n  * Weighted patch sampling\n  * Option to keep aspect ratio for zoom\n* Various CI/CD issues\n\n## [0.3.0] - 2020-10-02\n### Added\n* Overview document for [feature highlights in v0.3.0](https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md)\n* Automatic mixed precision support\n* Multi-node, multi-GPU data parallel model training support\n* 3 new evaluation metric functions\n* 11 new network layers and blocks\n* 6 new network architectures\n* 14 new transforms, including an I/O adaptor\n* Cross validation module for `DecathlonDataset`\n* Smart Cache module in dataset\n* `monai.optimizers` module\n* `monai.csrc` module\n* Experimental feature of ImageReader using ITK, Nibabel, Numpy, Pillow (PIL Fork)\n* Experimental feature of differentiable image resampling in C++/CUDA\n* Ensemble evaluator module\n* GAN trainer module\n* Initial cross-platform CI environment for C++/CUDA code\n* Code style enforcement now includes isort and clang-format\n* Progress 
bar with tqdm\n\n### Changed\n* Now fully compatible with PyTorch 1.6\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:20.08-py3` from `nvcr.io/nvidia/pytorch:20.03-py3`\n* Code contributions now require signing off on the [Developer Certificate of Origin (DCO)](https://developercertificate.org/)\n* Major work in type hinting finished\n* Remote datasets migrated to [Open Data on AWS](https://registry.opendata.aws/)\n* Optionally depend on PyTorch-Ignite v0.4.2 instead of v0.3.0\n* Optionally depend on torchvision, ITK\n* Enhanced CI tests with 8 new testing environments\n\n### Removed\n* `MONAI/examples` folder (relocated into [`Project-MONAI/tutorials`](https://github.com/Project-MONAI/tutorials))\n* `MONAI/research` folder (relocated to [`Project-MONAI/research-contributions`](https://github.com/Project-MONAI/research-contributions))\n\n### Fixed\n* `dense_patch_slices` incorrect indexing\n* Data type issue in `GeneralizedWassersteinDiceLoss`\n* `ZipDataset` return value inconsistencies\n* `sliding_window_inference` indexing and `device` issues\n* importing monai modules may cause namespace pollution\n* Random data splits issue in `DecathlonDataset`\n* Issue of randomising a `Compose` transform\n* Various issues in function type hints\n* Typos in docstring and documentation\n* `PersistentDataset` issue with existing file folder\n* Filename issue in the output writers\n\n## [0.2.0] - 2020-07-02\n### Added\n* Overview document for [feature highlights in v0.2.0](https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md)\n* Type hints and static type analysis support\n* `MONAI/research` folder\n* `monai.engine.workflow` APIs for supervised training\n* `monai.inferers` APIs for validation and inference\n* 7 new tutorials and examples\n* 3 new loss functions\n* 4 new event handlers\n* 8 new layers, blocks, and networks\n* 12 new transforms, including post-processing transforms\n* `monai.apps.datasets` APIs, including `MedNISTDataset` and 
`DecathlonDataset`\n* Persistent caching, `ZipDataset`, and `ArrayDataset` in `monai.data`\n* Cross-platform CI tests supporting multiple Python versions\n* Optional import mechanism\n* Experimental features for third-party transforms integration\n\n### Changed\n> For more details please visit [the project wiki](https://github.com/Project-MONAI/MONAI/wiki/Notable-changes-between-0.1.0-and-0.2.0)\n* Core modules now require numpy >= 1.17\n* Categorized `monai.transforms` modules into crop and pad, intensity, IO, post-processing, spatial, and utility.\n* Most transforms are now implemented with PyTorch native APIs\n* Code style enforcement and automated formatting workflows now use autopep8 and black\n* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:20.03-py3` from `nvcr.io/nvidia/pytorch:19.10-py3`\n* Enhanced local testing tools\n* Documentation website domain changed to https://docs.monai.io\n\n### Removed\n* Support of Python < 3.6\n* Automatic installation of optional dependencies including pytorch-ignite, nibabel, tensorboard, pillow, scipy, scikit-image\n\n### Fixed\n* Various issues in type and argument names consistency\n* Various issues in docstring and documentation site\n* Various issues in unit and integration tests\n* Various issues in examples and notebooks\n\n## [0.1.0] - 2020-04-17\n### Added\n* Public alpha source code release under the Apache 2.0 license ([highlights](https://github.com/Project-MONAI/MONAI/blob/0.1.0/docs/source/highlights.md))\n* Various tutorials and examples\n  - Medical image classification and segmentation workflows\n  - Spacing/orientation-aware preprocessing with CPU/GPU and caching\n  - Flexible workflows with PyTorch Ignite and Lightning\n* Various GitHub Actions\n  - CI/CD pipelines via self-hosted runners\n  - Documentation publishing via readthedocs.org\n  - PyPI package publishing\n* Contributing guidelines\n* A project logo and badges\n\n[highlights]: 
https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md\n\n[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/1.5.2...HEAD\n[1.5.2]: https://github.com/Project-MONAI/MONAI/compare/1.5.1...1.5.2\n[1.5.1]: https://github.com/Project-MONAI/MONAI/compare/1.5.0...1.5.1\n[1.5.0]: https://github.com/Project-MONAI/MONAI/compare/1.4.0...1.5.0\n[1.4.0]: https://github.com/Project-MONAI/MONAI/compare/1.3.2...1.4.0\n[1.3.2]: https://github.com/Project-MONAI/MONAI/compare/1.3.1...1.3.2\n[1.3.1]: https://github.com/Project-MONAI/MONAI/compare/1.3.0...1.3.1\n[1.3.0]: https://github.com/Project-MONAI/MONAI/compare/1.2.0...1.3.0\n[1.2.0]: https://github.com/Project-MONAI/MONAI/compare/1.1.0...1.2.0\n[1.1.0]: https://github.com/Project-MONAI/MONAI/compare/1.0.1...1.1.0\n[1.0.1]: https://github.com/Project-MONAI/MONAI/compare/1.0.0...1.0.1\n[1.0.0]: https://github.com/Project-MONAI/MONAI/compare/0.9.1...1.0.0\n[0.9.1]: https://github.com/Project-MONAI/MONAI/compare/0.9.0...0.9.1\n[0.9.0]: https://github.com/Project-MONAI/MONAI/compare/0.8.1...0.9.0\n[0.8.1]: https://github.com/Project-MONAI/MONAI/compare/0.8.0...0.8.1\n[0.8.0]: https://github.com/Project-MONAI/MONAI/compare/0.7.0...0.8.0\n[0.7.0]: https://github.com/Project-MONAI/MONAI/compare/0.6.0...0.7.0\n[0.6.0]: https://github.com/Project-MONAI/MONAI/compare/0.5.3...0.6.0\n[0.5.3]: https://github.com/Project-MONAI/MONAI/compare/0.5.0...0.5.3\n[0.5.0]: https://github.com/Project-MONAI/MONAI/compare/0.4.0...0.5.0\n[0.4.0]: https://github.com/Project-MONAI/MONAI/compare/0.3.0...0.4.0\n[0.3.0]: https://github.com/Project-MONAI/MONAI/compare/0.2.0...0.3.0\n[0.2.0]: https://github.com/Project-MONAI/MONAI/compare/0.1.0...0.2.0\n[0.1.0]: https://github.com/Project-MONAI/MONAI/commits/0.1.0\n"
  },
  {
    "path": "CITATION.cff",
    "content": "# YAML 1.2\n# Metadata for citation of this software according to the CFF format (https://citation-file-format.github.io/)\n#\n---\ntitle: \"MONAI: Medical Open Network for AI\"\nabstract: \"AI Toolkit for Healthcare Imaging\"\nauthors:\n  - name: \"MONAI Consortium\"\ndate-released: 2026-01-29\nversion: \"1.5.2\"\nidentifiers:\n  - description: \"This DOI represents all versions of MONAI, and will always resolve to the latest one.\"\n    type: doi\n    value: \"10.5281/zenodo.4323058\"\nlicense: \"Apache-2.0\"\nrepository-code: \"https://github.com/Project-MONAI/MONAI\"\nurl: \"https://project-monai.github.io/\"\ncff-version: \"1.2.0\"\nmessage: \"If you use this software, please cite it using these metadata.\"\npreferred-citation:\n  type: article\n  authors:\n  - given-names: \"M. Jorge\"\n    family-names: \"Cardoso\"\n  - given-names: \"Wenqi\"\n    family-names: \"Li\"\n  - given-names: \"Richard\"\n    family-names: \"Brown\"\n  - given-names: \"Nic\"\n    family-names: \"Ma\"\n  - given-names: \"Eric\"\n    family-names: \"Kerfoot\"\n  - given-names: \"Yiheng\"\n    family-names: \"Wang\"\n  - given-names: \"Benjamin\"\n    family-names: \"Murray\"\n  - given-names: \"Andriy\"\n    family-names: \"Myronenko\"\n  - given-names: \"Can\"\n    family-names: \"Zhao\"\n  - given-names: \"Dong\"\n    family-names: \"Yang\"\n  - given-names: \"Vishwesh\"\n    family-names: \"Nath\"\n  - given-names: \"Yufan\"\n    family-names: \"He\"\n  - given-names: \"Ziyue\"\n    family-names: \"Xu\"\n  - given-names: \"Ali\"\n    family-names: \"Hatamizadeh\"\n  - given-names: \"Wentao\"\n    family-names: \"Zhu\"\n  - given-names: \"Yun\"\n    family-names: \"Liu\"\n  - given-names: \"Mingxin\"\n    family-names: \"Zheng\"\n  - given-names: \"Yucheng\"\n    family-names: \"Tang\"\n  - given-names: \"Isaac\"\n    family-names: \"Yang\"\n  - given-names: \"Michael\"\n    family-names: \"Zephyr\"\n  - given-names: \"Behrooz\"\n    family-names: \"Hashemian\"\n  - 
given-names: \"Sachidanand\"\n    family-names: \"Alle\"\n  - given-names: \"Mohammad\"\n    family-names: \"Zalbagi Darestani\"\n  - given-names: \"Charlie\"\n    family-names: \"Budd\"\n  - given-names: \"Marc\"\n    family-names: \"Modat\"\n  - given-names: \"Tom\"\n    family-names: \"Vercauteren\"\n  - given-names: \"Guotai\"\n    family-names: \"Wang\"\n  - given-names: \"Yiwen\"\n    family-names: \"Li\"\n  - given-names: \"Yipeng\"\n    family-names: \"Hu\"\n  - given-names: \"Yunguan\"\n    family-names: \"Fu\"\n  - given-names: \"Benjamin\"\n    family-names: \"Gorman\"\n  - given-names: \"Hans\"\n    family-names: \"Johnson\"\n  - given-names: \"Brad\"\n    family-names: \"Genereaux\"\n  - given-names: \"Barbaros S.\"\n    family-names: \"Erdal\"\n  - given-names: \"Vikash\"\n    family-names: \"Gupta\"\n  - given-names: \"Andres\"\n    family-names: \"Diaz-Pinto\"\n  - given-names: \"Andre\"\n    family-names: \"Dourson\"\n  - given-names: \"Lena\"\n    family-names: \"Maier-Hein\"\n  - given-names: \"Paul F.\"\n    family-names: \"Jaeger\"\n  - given-names: \"Michael\"\n    family-names: \"Baumgartner\"\n  - given-names: \"Jayashree\"\n    family-names: \"Kalpathy-Cramer\"\n  - given-names: \"Mona\"\n    family-names: \"Flores\"\n  - given-names: \"Justin\"\n    family-names: \"Kirby\"\n  - given-names: \"Lee A.D.\"\n    family-names: \"Cooper\"\n  - given-names: \"Holger R.\"\n    family-names: \"Roth\"\n  - given-names: \"Daguang\"\n    family-names: \"Xu\"\n  - given-names: \"David\"\n    family-names: \"Bericat\"\n  - given-names: \"Ralf\"\n    family-names: \"Floca\"\n  - given-names: \"S. 
Kevin\"\n    family-names: \"Zhou\"\n  - given-names: \"Haris\"\n    family-names: \"Shuaib\"\n  - given-names: \"Keyvan\"\n    family-names: \"Farahani\"\n  - given-names: \"Klaus H.\"\n    family-names: \"Maier-Hein\"\n  - given-names: \"Stephen\"\n    family-names: \"Aylward\"\n  - given-names: \"Prerna\"\n    family-names: \"Dogra\"\n  - given-names: \"Sebastien\"\n    family-names: \"Ourselin\"\n  - given-names: \"Andrew\"\n    family-names: \"Feng\"\n  doi: \"https://doi.org/10.48550/arXiv.2211.02701\"\n  month: 11\n  year: 2022\n  title: \"MONAI: An open-source framework for deep learning in healthcare\"\n...\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and maintainers pledge to making participation in our project and\nour community a harassment-free experience for everyone, regardless of age, body\nsize, disability, ethnicity, sex characteristics, gender identity and expression,\nlevel of experience, education, socio-economic status, nationality, personal\nappearance, race, religion, or sexual identity and orientation.\n\n## Our Standards\n\nExamples of behavior that contributes to creating a positive environment\ninclude:\n\n* Using welcoming and inclusive language\n* Being respectful of differing viewpoints and experiences\n* Gracefully accepting constructive criticism\n* Focusing on what is best for the community\n* Showing empathy towards other community members\n\nExamples of unacceptable behavior by participants include:\n\n* The use of sexualized language or imagery and unwelcome sexual attention or\n advances\n* Trolling, insulting/derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or electronic\n address, without explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\n professional setting\n\n## Our Responsibilities\n\nProject maintainers are responsible for clarifying the standards of acceptable\nbehavior and are expected to take appropriate and fair corrective action in\nresponse to any instances of unacceptable behavior.\n\nProject maintainers have the right and responsibility to remove, edit, or\nreject comments, commits, code, wiki edits, issues, and other contributions\nthat are not aligned to this Code of Conduct, or to ban temporarily or\npermanently any contributor for other behaviors that they deem inappropriate,\nthreatening, offensive, or harmful.\n\n## Scope\n\nThis Code of Conduct applies both within 
project spaces and in public spaces\nwhen an individual is representing the project or its community. Examples of\nrepresenting a project or community include using an official project e-mail\naddress, posting via an official social media account, or acting as an appointed\nrepresentative at an online or offline event. Representation of a project may be\nfurther defined and clarified by project maintainers.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported by contacting the project team at monai.contact@gmail.com. All\ncomplaints will be reviewed and investigated and will result in a response that\nis deemed necessary and appropriate to the circumstances. The project team is\nobligated to maintain confidentiality with regard to the reporter of an incident.\nFurther details of specific enforcement policies may be posted separately.\n\nProject maintainers who do not follow or enforce the Code of Conduct in good\nfaith may face temporary or permanent repercussions as determined by other\nmembers of the project's leadership.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,\navailable at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html\n\n[homepage]: https://www.contributor-covenant.org\n\nFor answers to common questions about this code of conduct, see\nhttps://www.contributor-covenant.org/faq\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "- [Introduction](#introduction)\n- [The contribution process](#the-contribution-process)\n  - [Preparing pull requests](#preparing-pull-requests)\n    1. [Checking the coding style](#checking-the-coding-style)\n    1. [Unit testing](#unit-testing)\n    1. [Building the documentation](#building-the-documentation)\n    1. [Automatic code formatting](#automatic-code-formatting)\n    1. [Adding new optional dependencies](#adding-new-optional-dependencies)\n    1. [Signing your work](#signing-your-work)\n    1. [Utility functions](#utility-functions)\n    1. [Backwards compatibility](#backwards-compatibility)\n  - [Submitting pull requests](#submitting-pull-requests)\n- [The code reviewing process (for the maintainers)](#the-code-reviewing-process)\n  - [Reviewing pull requests](#reviewing-pull-requests)\n- [Admin tasks (for the maintainers)](#admin-tasks)\n  - [Releasing a new version](#release-a-new-version)\n\n## Introduction\n\nWelcome to Project MONAI! We're excited you're here and want to contribute. This documentation is intended for individuals and institutions interested in contributing to MONAI. MONAI is an open-source project and, as such, its success relies on its community of contributors willing to keep improving it. Your contribution will be a valued addition to the code base; we simply ask that you read this page and understand our contribution process, whether you are a seasoned open-source contributor or whether you are a first-time contributor.\n\n### Communicate with us\n\nWe are happy to talk with you about your needs for MONAI and your ideas for contributing to the project. One way to do this is to create an issue discussing your thoughts. It might be that a very similar feature is under development or already exists, so an issue is a great starting point. 
If you are looking for an issue to resolve that will help Project MONAI, see the [*good first issue*](https://github.com/Project-MONAI/MONAI/labels/good%20first%20issue) and [*Contribution wanted*](https://github.com/Project-MONAI/MONAI/labels/Contribution%20wanted) labels.\n\n### Does it belong in PyTorch instead of MONAI?\n\nMONAI is part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/), and mainly based on the PyTorch and Numpy libraries. These libraries implement what we consider to be best practice for general scientific computing and deep learning functionality. MONAI builds on these with a strong focus on medical applications. As such, it is a good idea to consider whether your functionality is medical-application specific or not. General deep learning functionality may be better off in PyTorch; you can find their contribution guidelines [here](https://pytorch.org/docs/stable/community/contribution_guide.html).\n\n## The contribution process\n\n*Pull request early*\n\nWe encourage you to create pull requests early. It helps us track the contributions under development, whether they are ready to be merged or not. [Create a draft pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/changing-the-stage-of-a-pull-request) until it is ready for formal review.\n\nPlease note that, as per PyTorch, MONAI uses American English spelling. 
This means classes and variables should be: normali**z**e, visuali**z**e, colo~~u~~r, etc.\n\n### Preparing pull requests\n\nTo ensure the code quality, MONAI relies on several linting tools ([black](https://github.com/psf/black), [isort](https://github.com/timothycrosley/isort), [ruff](https://github.com/astral-sh/ruff)),\nstatic type analysis tools ([mypy](https://github.com/python/mypy), [pytype](https://github.com/google/pytype)), as well as a set of unit/integration tests.\n\nThis section highlights all the necessary preparation steps required before sending a pull request.\nTo collaborate efficiently, please read through this section and follow them.\n\n- [Checking the coding style](#checking-the-coding-style)\n- [Licensing information](#licensing-information)\n- [Unit testing](#unit-testing)\n- [Building documentation](#building-the-documentation)\n- [Signing your work](#signing-your-work)\n\n#### Checking the coding style\n\nCoding style is checked and enforced by black, isort, and ruff.\nBefore submitting a pull request, we recommend that all linting should pass, by running the following command locally:\n\n```bash\n# optionally update the dependencies and dev tools\npython -m pip install -U pip\npython -m pip install -U -r requirements-dev.txt\n\n# run the linting and type checking tools\n./runtests.sh --codeformat\n\n# try to fix the coding style errors automatically\n./runtests.sh --autofix\n```\n\nFull linting and type checking may take some time. 
If you need a quick check, run\n\n```bash\n# run ruff only\n./runtests.sh --ruff\n```\n\n#### Licensing information\n\nAll source code files should start with this paragraph:\n\n```\n# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n```\n\n##### Exporting modules\n\nIf you intend for any variables/functions/classes to be available outside of the file with the edited functionality, then:\n\n- Create or append to the `__all__` variable (in the file in which functionality has been added), and\n- Add to the `__init__.py` file.\n\n#### Unit testing\n\nMONAI tests are located under `tests/`.\n\n- The unit test's file name currently follows `test_[module_name].py` or `test_[module_name]_dist.py`.\n- The `test_[module_name]_dist.py` subset of unit tests requires a distributed environment to verify the module with distributed GPU-based computation.\n- The integration test's file name follows `test_integration_[workflow_name].py`.\n\nA bash script (`runtests.sh`) is provided to run all tests locally.\nPlease run ``./runtests.sh -h`` to see all options.\n\nTo run a particular test, for example `tests/losses/test_dice_loss.py`:\n\n```\npython -m tests.losses.test_dice_loss\n```\n\nBefore submitting a pull request, we recommend that all linting and unit tests\nshould pass, by running the following command locally:\n\n```bash\n./runtests.sh -f -u --net --coverage\n```\n\nor (for new features that would not break existing 
functionality):\n\n```bash\n./runtests.sh --quick --unittests\n```\n\nIt is recommended that the new test `test_[module_name].py` is constructed by using only\npython 3.9+ built-in functions, `torch`, `numpy`, `coverage` (for reporting code coverage) and `parameterized` (for organising test cases) packages.\nIf it requires any other external packages, please make sure:\n\n- the packages are listed in [`requirements-dev.txt`](requirements-dev.txt)\n- the new test `test_[module_name].py` is added to the `exclude_cases` in [`./tests/min_tests.py`](./tests/min_tests.py) so that\nthe minimal CI runner will not execute it.\n\n##### Testing data\n\nTesting data such as images and binary files should not be placed in the source code repository.\nPlease deploy them to a reliable file sharing location (the current preferred one is [https://github.com/Project-MONAI/MONAI-extra-test-data/releases](https://github.com/Project-MONAI/MONAI-extra-test-data/releases)).\nAt test time, the URLs within `tests/testing_data/data_config.json` are accessible\nvia the APIs provided in `tests.utils`: `tests.utils.testing_data_config` and `tests.utils.download_url_or_skip_test`.\n\n*If it's not tested, it's broken*\n\nAll new functionality should be accompanied by an appropriate set of tests.\nMONAI functionality has plenty of unit tests from which you can draw inspiration,\nand you can reach out to us if you are unsure of how to proceed with testing.\n\nMONAI's code coverage report is available at [CodeCov](https://codecov.io/gh/Project-MONAI/MONAI).\n\n#### Building the documentation\n\nMONAI's documentation is located at `docs/`.\n\n```bash\n# install the doc-related dependencies\npip install --upgrade pip\npip install -r docs/requirements.txt\n\n# build the docs\ncd docs/\nmake html\n```\n\nThe above commands build html documentation, they are used to automatically generate [https://docs.monai.io](https://docs.monai.io).\n\nThe Python code docstrings are written 
in\n[reStructuredText](https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html) and\nthe documentation pages can be in either [reStructuredText](https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html) or [Markdown](https://en.wikipedia.org/wiki/Markdown).  In general the Python docstrings follow the [Google style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings).\n\nBefore submitting a pull request, it is recommended to:\n\n- edit the relevant `.rst` files in [`docs/source`](./docs/source) accordingly.\n- build html documentation locally\n- check the auto-generated documentation (by browsing `./docs/build/html/index.html` with a web browser)\n- type `make clean` in `docs/` folder to remove the current build files.\n\nPlease type `make help` in `docs/` folder for all supported format options.\n\n#### Automatic code formatting\n\nMONAI provides support of automatic Python code formatting via [a customised GitHub action](https://github.com/Project-MONAI/monai-code-formatter).\nThis makes the project's Python coding style consistent and reduces maintenance burdens.\nCommenting a pull request with `/black` triggers the formatting action based on [`psf/Black`](https://github.com/psf/black) (this is implemented with [`slash command dispatch`](https://github.com/marketplace/actions/slash-command-dispatch)).\n\nSteps for the formatting process:\n\n- After submitting a pull request or push to an existing pull request,\nmake a comment to the pull request to trigger the formatting action.\nThe first line of the comment must be `/black` so that it will be interpreted by [the comment parser](https://github.com/marketplace/actions/slash-command-dispatch#how-are-comments-parsed-for-slash-commands).\n- [Auto] The GitHub action tries to format all Python files (using [`psf/Black`](https://github.com/psf/black)) in the branch and makes a commit under the name \"MONAI bot\" if there's code change. 
The actual formatting action is deployed at [project-monai/monai-code-formatter](https://github.com/Project-MONAI/monai-code-formatter).\n- [Auto] After the formatting commit, the GitHub action adds an emoji to the comment that triggered the process.\n- Repeat the above steps if necessary.\n\n#### Adding new optional dependencies\n\nIn addition to the minimal requirements of PyTorch and Numpy, MONAI's core modules are built optionally based on 3rd-party packages.\nThe current set of dependencies is listed in [installing dependencies](https://monai.readthedocs.io/en/stable/installation.html#installing-the-recommended-dependencies).\n\nTo allow for flexible integration of MONAI with other systems and environments,\nthe optional dependency APIs are always invoked lazily. For example,\n\n```py\nfrom monai.utils import optional_import\nitk, _ = optional_import(\"itk\", ...)\n\nclass ITKReader(ImageReader):\n    ...\n    def read(self, ...):\n        return itk.imread(...)\n```\n\nThe availability of the external `itk.imread` API is not required unless `monai.data.ITKReader.read` is called by the user.\nIntegration tests with minimal requirements are deployed to ensure this strategy.\n\nTo add new optional dependencies, please communicate with the core team during pull request reviews,\nand add the necessary information (at least) to the following files:\n\n- [setup.cfg](https://github.com/Project-MONAI/MONAI/blob/dev/setup.cfg)  (for package's `[options.extras_require]` config)\n- [requirements-dev.txt](https://github.com/Project-MONAI/MONAI/blob/dev/requirements-dev.txt) (pip requirements file)\n- [docs/requirements.txt](https://github.com/Project-MONAI/MONAI/blob/dev/docs/requirements.txt) (docs pip requirements file)\n- [environment-dev.yml](https://github.com/Project-MONAI/MONAI/blob/dev/environment-dev.yml) (conda environment file)\n- [installation.md](https://github.com/Project-MONAI/MONAI/blob/dev/docs/source/installation.md) (documentation)\n\nWhen writing unit 
tests that use 3rd-party packages, it is a good practice to always consider\nan appropriate fallback default behaviour when the packages are not installed in\nthe testing environment. For example:\n\n```py\nfrom monai.utils import optional_import\nplt, has_matplotlib = optional_import(\"matplotlib.pyplot\")\n\n@skipUnless(has_matplotlib, \"Matplotlib required\")\nclass TestBlendImages(unittest.TestCase):\n```\n\nIt skips the test cases when `matplotlib.pyplot` APIs are not available.\n\nAlternatively, add the test file name to the ``exclude_cases`` in `tests/min_tests.py` to completely skip the test\ncases when running in a minimal setup.\n\n#### Signing your work\n\nMONAI enforces the [Developer Certificate of Origin](https://developercertificate.org/) (DCO) on all pull requests.\nAll commit messages should contain the `Signed-off-by` line with an email address. The [GitHub DCO app](https://github.com/apps/dco) is deployed on MONAI. The pull request's status will be `failed` if commits do not contain a valid `Signed-off-by` line.\n\nGit has a `-s` (or `--signoff`) command-line option to append this automatically to your commit message:\n\n```bash\ngit commit -s -m 'a new commit'\n```\n\nThe commit message will be:\n\n```\n    a new commit\n\n    Signed-off-by: Your Name <yourname@example.org>\n```\n\nFull text of the DCO:\n\n```\nDeveloper Certificate of Origin\nVersion 1.1\n\nCopyright (C) 2004, 2006 The Linux Foundation and its contributors.\n1 Letterman Drive\nSuite D4700\nSan Francisco, CA, 94129\n\nEveryone is permitted to copy and distribute verbatim copies of this\nlicense document, but changing it is not allowed.\n\n\nDeveloper's Certificate of Origin 1.1\n\nBy making a contribution to this project, I certify that:\n\n(a) The contribution was created in whole or in part by me and I\n    have the right to submit it under the open source license\n    indicated in the file; or\n\n(b) The contribution is based upon previous work that, to the best\n    of my 
knowledge, is covered under an appropriate open source\n    license and I have the right under that license to submit that\n    work with modifications, whether created in whole or in part\n    by me, under the same open source license (unless I am\n    permitted to submit under a different license), as indicated\n    in the file; or\n\n(c) The contribution was provided directly to me by some other\n    person who certified (a), (b) or (c) and I have not modified\n    it.\n\n(d) I understand and agree that this project and the contribution\n    are public and that a record of the contribution (including all\n    personal information I submit with it, including my sign-off) is\n    maintained indefinitely and may be redistributed consistent with\n    this project or the open source license(s) involved.\n```\n\n#### Utility functions\n\nMONAI provides a set of generic utility functions and frequently used routines.\nThese are located in [``monai/utils``](./monai/utils/) and in the module folders such as [``networks/utils.py``](./monai/networks/).\nUsers are encouraged to use these common routines to improve code readability and reduce the code maintenance burdens.\n\nNotably,\n\n- ``monai.module.export`` decorator can make the module name shorter when importing,\nfor example, ``import monai.transforms.Spacing`` is the equivalent of ``monai.transforms.spatial.array.Spacing`` if\n``class Spacing`` defined in file `monai/transforms/spatial/array.py` is decorated with ``@export(\"monai.transforms\")``.\n\nFor string definition, [f-string](https://www.python.org/dev/peps/pep-0498/) is recommended to use over `%-print` and `format-print`. 
So please try to use `f-string` if you need to define any string object.\n\n#### Backwards compatibility\n\nMONAI in general follows [PyTorch's policy for backward compatibility](https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy).\nUtility functions are provided in `monai.utils.deprecated` to help migrate from the deprecated to new APIs. The use of these utilities is encouraged.\nThe pull request [template contains checkboxes](https://github.com/Project-MONAI/MONAI/blame/dev/.github/pull_request_template.md#L11-L12) that\nthe contributor should use accordingly to clearly indicate breaking changes.\n\nThe process of releasing backwards incompatible API changes is as follows:\n\n1. discuss the breaking changes during pull requests or in dev meetings with a feature proposal if needed.\n1. add a warning message in the upcoming release (version `X.Y`), the warning message should include a forecast of removing the deprecated API in:\n   1. `X+1.0` -- major version `X+1` and minor version `0` (the next major version), if it's a significant change,\n   1. `X.Y+2` -- major version `X` and minor version `Y+2` (the minor version after the next one), if it's a minor API change.\n   1. Note that the versioning policy is similar to PyTorch's approach which does not precisely follow [the semantic versioning](https://semver.org/) definition.\n      Major version numbers are instead used to represent major product version (which is currently not planned to be greater than 1),\n      minor version for both compatible and incompatible, and patch version for bug fixes.\n   1. when recommending new API to use in place of a deprecated API, the recommended version should\n      provide exact feature-like behaviour otherwise users will have a harder time migrating.\n1. add new test cases by extending the existing unit tests to cover both the deprecated and updated APIs.\n1. 
collect feedback from the users during the subsequent few releases, and reconsider step 1 if needed.\n1. before each release, review the deprecating APIs and relevant tests, and clean up the removed APIs described in step 2.\n\n### Submitting pull requests\n\nAll code changes to the dev branch must be done via [pull requests](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/proposing-changes-to-your-work-with-pull-requests).\n\n1. Create a new ticket or take a known ticket from [the issue list][monai issue list].\n1. Check if there's already a branch dedicated to the task.\n1. If the task has not been taken, [create a new branch in your fork](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request-from-a-fork)\nof the codebase named `[ticket_id]-[task_name]`.\nFor example, branch name `19-ci-pipeline-setup` corresponds to [issue #19](https://github.com/Project-MONAI/MONAI/issues/19).\nIdeally, the new branch should be based on the latest `dev` branch.\n1. Make changes to the branch ([use detailed commit messages if possible](https://chris.beams.io/posts/git-commit/)).\n1. Make sure that new tests cover the changes and the changed codebase [passes all tests locally](#unit-testing).\n1. [Create a new pull request](https://help.github.com/en/desktop/contributing-to-projects/creating-a-pull-request) from the task branch to the dev branch, with detailed descriptions of the purpose of this pull request.\n1. Check [the CI/CD status of the pull request][github ci], make sure all CI/CD tests passed.\n1. Wait for reviews; if there are reviews, make point-to-point responses, make further code changes if needed.\n1. If there are conflicts between the pull request branch and the dev branch, pull the changes from the dev and resolve the conflicts locally.\n1. Reviewer and contributor may have discussions back and forth until all comments are addressed.\n1. 
Wait for the pull request to be merged.\n\n## The code reviewing process\n\n### Reviewing pull requests\n\nAll code review comments should be specific, constructive, and actionable.\n\n1. Check [the CI/CD status of the pull request][github ci], make sure all CI/CD tests passed before reviewing (contact the branch owner if needed).\n1. Read carefully the descriptions of the pull request and the files changed, write comments if needed.\n1. Make in-line comments to specific code segments, [request for changes](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-request-reviews) if needed.\n1. Review any further code changes until all comments are addressed by the contributors.\n1. Comment to trigger `/black` and/or `/integration-test` for optional auto code formatting and [integration tests](.github/workflows/integration.yml).\n1. [Maintainers] Review the changes and comment `/build` to trigger internal full tests.\n1. Merge the pull request to the dev branch.\n1. Close the corresponding task ticket on [the issue list][monai issue list].\n\n[github ci]: https://github.com/Project-MONAI/MONAI/actions\n[monai issue list]: https://github.com/Project-MONAI/MONAI/issues\n\n## Admin tasks\n\n### Release a new version\n\nThe `dev` branch's `HEAD` always corresponds to MONAI Docker image's latest tag: `projectmonai/monai:latest`. 
(No\nrelease is currently done for the slim MONAI image; this is built locally by users.)\nThe `main` branch's `HEAD` always corresponds to the latest MONAI milestone release.\n\nWhen major features are ready for a milestone, to prepare for a new release:\n\n- Prepare [a release note](https://github.com/Project-MONAI/MONAI/releases) and release checklist.\n- Check out or cherry-pick a new branch `releasing/[version number]` locally from the `dev` branch and push to the codebase.\n- Create a release candidate tag, for example, `git tag -a 0.1.0rc1 -m \"release candidate 1 of version 0.1.0\"`.\n- Push the tag to the codebase, for example, `git push origin 0.1.0rc1`.\n  This step will trigger package building and testing.\n  The resultant packages are automatically uploaded to\n  [TestPyPI](https://test.pypi.org/project/monai/).  The packages are also available for downloading as\n  repository's artifacts (e.g. the file at <https://github.com/Project-MONAI/MONAI/actions/runs/66570977>).\n- Check the release test at [TestPyPI](https://test.pypi.org/project/monai/), download the artifacts when the CI finishes.\n- Optionally run [the cron testing jobs](https://github.com/Project-MONAI/MONAI/blob/dev/.github/workflows/cron.yml) on `releasing/[version number]`.\n- Rebase `releasing/[version number]` to `main`, make sure all the test pipelines succeed.\n- Once the release candidate is verified, tag and push a milestone, for example, `git push origin 0.1.0`.\n  The tag must be with the latest commit of `releasing/[version number]`.\n- Upload the packages to [PyPI](https://pypi.org/project/monai/).\n  This could be done manually by ``twine upload dist/*``, given the artifacts are unzipped to the folder ``dist/``.\n- Merge `releasing/[version number]` to `dev`, this step must make sure that the tagging commit is unchanged on `dev`.\n- Publish the release note.\n\nNote that the release should be tagged with a [PEP440](https://www.python.org/dev/peps/pep-0440/) compliant version 
number.\n\nIf any error occurs during the release process, first check out a new hotfix branch from the `releasing/[version number]`,\nthen make PRs to the `releasing/[version number]` to fix the bugs via the regular contribution procedure.\n\nIf any error occurs after the release process, first check out a new hotfix branch from the `main` branch,\nmake a patch version release following the semantic versioning, for example, `releasing/0.1.1`.\nMake sure the `releasing/0.1.1` is merged back into both `dev` and `main` and all the test pipelines succeed.\n\n<p align=\"right\">\n  <a href=\"#introduction\">⬆️ Back to Top</a>\n</p>\n"
  },
  {
    "path": "Dockerfile",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# To build with a different base image\n# please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag.\nARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:25.12-py3\nFROM ${PYTORCH_IMAGE}\n\nLABEL maintainer=\"monai.contact@gmail.com\"\n\n# TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431)\nRUN if [[ $(uname -m) =~ \"aarch64\" ]]; then \\\n      export CFLAGS=\"-O3\" && \\\n      export DISABLE_NUMCODECS_SSE2=true && \\\n      export DISABLE_NUMCODECS_AVX2=true && \\\n      pip install numcodecs; \\\n    fi\n\nWORKDIR /opt/monai\n\n# install full deps\nCOPY requirements.txt requirements-min.txt requirements-dev.txt /tmp/\nRUN cp /tmp/requirements.txt /tmp/req.bak \\\n  && awk '!/torch/' /tmp/requirements.txt > /tmp/tmp && mv /tmp/tmp /tmp/requirements.txt \\\n  && python -m pip install --upgrade --no-cache-dir pip \\\n  && python -m pip install --no-cache-dir -r /tmp/requirements-dev.txt\n\n# compile ext and remove temp files\n# TODO: remark for issue [revise the dockerfile #1276](https://github.com/Project-MONAI/MONAI/issues/1276)\n# please specify exact files and folders to be copied -- else, basically always, the Docker build process cannot cache\n# this or anything below it and always will build from at most here; one file change leads to no caching from here on...\n\nCOPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md 
CONTRIBUTING.md README.md versioneer.py setup.py setup.cfg runtests.sh MANIFEST.in ./\nCOPY tests ./tests\nCOPY monai ./monai\n\nRUN BUILD_MONAI=1 FORCE_CUDA=1 python setup.py develop \\\n  && rm -rf build __pycache__\n\n# NGC Client\nWORKDIR /opt/tools\nARG NGC_CLI_URI=\"https://ngc.nvidia.com/downloads/ngccli_linux.zip\"\nRUN wget -q ${NGC_CLI_URI} && unzip ngccli_linux.zip && chmod u+x ngc-cli/ngc && \\\n    find ngc-cli/ -type f -exec md5sum {} + | LC_ALL=C sort | md5sum -c ngc-cli.md5 && \\\n    rm -rf ngccli_linux.zip ngc-cli.md5\nENV PATH=${PATH}:/opt/tools:/opt/tools/ngc-cli\nRUN apt-get update \\\n  && DEBIAN_FRONTEND=\"noninteractive\" apt-get install -y libopenslide0  \\\n  && rm -rf /var/lib/apt/lists/*\n# append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations\nENV PATH=${PATH}:/opt/tools\nENV POLYGRAPHY_AUTOINSTALL_DEPS=1\n\n\nWORKDIR /opt/monai\n"
  },
  {
    "path": "Dockerfile.slim",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# This is a slimmed down version of the MONAI Docker image using a smaller base image and multi-stage building. Not all\n# NVIDIA tools will be present but all libraries and compiled code are included. This image isn't provided through\n# Dockerhub so users must build locally: `docker build -t monai_slim -f Dockerfile.slim .`\n# Containers may require more shared memory, eg.: `docker run -ti --rm --gpus all --shm-size=10gb monai_slim /bin/bash`\n\nARG IMAGE=debian:12-slim\n\nFROM ${IMAGE} AS build\n\nARG TORCH_CUDA_ARCH_LIST=\"7.5 8.0 8.6 8.9 9.0+PTX\"\n\nENV DEBIAN_FRONTEND=noninteractive\nENV APT_INSTALL=\"apt install -y --no-install-recommends\"\n\nRUN apt update && apt upgrade -y && \\\n    ${APT_INSTALL} ca-certificates python3-pip python-is-python3 git wget libopenslide0 unzip python3-dev && \\\n    wget https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb && \\\n    dpkg -i cuda-keyring_1.1-1_all.deb && \\\n    apt update && \\\n    ${APT_INSTALL} cuda-toolkit-12 && \\\n    rm -rf /usr/lib/python*/EXTERNALLY-MANAGED /var/lib/apt/lists/* && \\\n    python -m pip install --upgrade --no-cache-dir --no-build-isolation pip\n\n# TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431)\nRUN if [[ $(uname -m) =~ \"aarch64\" ]]; then \\\n      CFLAGS=\"-O3\" 
DISABLE_NUMCODECS_SSE2=true DISABLE_NUMCODECS_AVX2=true python -m pip install numcodecs; \\\n    fi\n\n# NGC Client\nWORKDIR /opt/tools\nARG NGC_CLI_URI=\"https://ngc.nvidia.com/downloads/ngccli_linux.zip\"\nRUN wget -q ${NGC_CLI_URI} && unzip ngccli_linux.zip && chmod u+x ngc-cli/ngc && \\\n    find ngc-cli/ -type f -exec md5sum {} + | LC_ALL=C sort | md5sum -c ngc-cli.md5 && \\\n    rm -rf ngccli_linux.zip ngc-cli.md5\n\nWORKDIR /opt/monai\n\n# copy relevant parts of repo\nCOPY requirements.txt requirements-min.txt requirements-dev.txt versioneer.py setup.py setup.cfg pyproject.toml ./\nCOPY LICENSE CHANGELOG.md CODE_OF_CONDUCT.md CONTRIBUTING.md README.md MANIFEST.in runtests.sh ./\nCOPY tests ./tests\nCOPY monai ./monai\n\n# install full deps\nRUN python -m pip install --no-cache-dir --no-build-isolation -r requirements-dev.txt\n\n# compile ext\nRUN CUDA_HOME=/usr/local/cuda FORCE_CUDA=1 USE_COMPILED=1 BUILD_MONAI=1 python setup.py develop\n\n# recreate the image without the installed CUDA packages then copy the installed MONAI and Python directories\nFROM ${IMAGE} AS build2\n\nENV DEBIAN_FRONTEND=noninteractive\nENV APT_INSTALL=\"apt install -y --no-install-recommends\"\n\nRUN apt update && apt upgrade -y && \\\n    ${APT_INSTALL} ca-certificates python3-pip python-is-python3 git libopenslide0 && \\\n    apt clean && \\\n    rm -rf /usr/lib/python*/EXTERNALLY-MANAGED /var/lib/apt/lists/* && \\\n    python -m pip install --upgrade --no-cache-dir --no-build-isolation pip\n\nCOPY --from=build /opt/monai /opt/monai\nCOPY --from=build /opt/tools /opt/tools\nARG PYTHON_VERSION=3.11\nCOPY --from=build /usr/local/lib/python${PYTHON_VERSION}/dist-packages /usr/local/lib/python${PYTHON_VERSION}/dist-packages\nCOPY --from=build /usr/local/bin /usr/local/bin\n\nRUN rm -rf /opt/monai/build /opt/monai/monai.egg-info && \\\n    find /opt /usr/local/lib -type d -name __pycache__ -exec rm -rf {} +\n\n# flatten all layers down to one\nFROM ${IMAGE}\nLABEL 
maintainer=\"monai.contact@gmail.com\"\n\nCOPY --from=build2 / /\n\nWORKDIR /opt/monai\n\nENV PATH=${PATH}:/opt/tools:/opt/tools/ngc-cli\nENV POLYGRAPHY_AUTOINSTALL_DEPS=1\nENV CUDA_HOME=/usr/local/cuda\nENV BUILD_MONAI=1\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include versioneer.py\ninclude monai/_version.py\n\ninclude README.md\ninclude LICENSE\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n<img src=\"https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/docs/images/MONAI-logo-color.png\" width=\"50%\" alt='project-monai'>\n</p>\n\n**M**edical **O**pen **N**etwork for **AI**\n\n![Supported Python versions](https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/docs/images/python.svg)\n[![License](https://img.shields.io/badge/license-Apache%202.0-green.svg)](https://opensource.org/licenses/Apache-2.0)\n[![auto-commit-msg](https://img.shields.io/badge/dynamic/json?label=citations&query=%24.citationCount&url=https%3A%2F%2Fapi.semanticscholar.org%2Fgraph%2Fv1%2Fpaper%2FDOI%3A10.48550%2FarXiv.2211.02701%3Ffields%3DcitationCount)](https://arxiv.org/abs/2211.02701)\n[![PyPI version](https://badge.fury.io/py/monai.svg)](https://badge.fury.io/py/monai)\n[![docker](https://img.shields.io/badge/docker-pull-green.svg?logo=docker&logoColor=white)](https://hub.docker.com/r/projectmonai/monai)\n[![conda](https://img.shields.io/conda/vn/conda-forge/monai?color=green)](https://anaconda.org/conda-forge/monai)\n\n[![premerge](https://github.com/Project-MONAI/MONAI/actions/workflows/pythonapp.yml/badge.svg?branch=dev)](https://github.com/Project-MONAI/MONAI/actions/workflows/pythonapp.yml)\n[![postmerge](https://img.shields.io/github/checks-status/project-monai/monai/dev?label=postmerge)](https://github.com/Project-MONAI/MONAI/actions?query=branch%3Adev)\n[![Documentation Status](https://readthedocs.org/projects/monai/badge/?version=latest)](https://monai.readthedocs.io/en/latest/)\n[![codecov](https://codecov.io/gh/Project-MONAI/MONAI/branch/dev/graph/badge.svg?token=6FTC7U1JJ4)](https://codecov.io/gh/Project-MONAI/MONAI)\n[![monai Downloads Last Month](https://assets.piptrends.com/get-last-month-downloads-badge/monai.svg 'monai Downloads Last Month by pip Trends')](https://piptrends.com/package/monai)\n\nMONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) 
framework for deep learning in healthcare imaging, part of the [PyTorch Ecosystem](https://pytorch.org/ecosystem/).\nIts ambitions are as follows:\n\n- Developing a community of academic, industrial and clinical researchers collaborating on a common foundation;\n- Creating state-of-the-art, end-to-end training workflows for healthcare imaging;\n- Providing researchers with the optimized and standardized way to create and evaluate deep learning models.\n\n## Features\n\n> _Please see [the technical highlights](https://monai.readthedocs.io/en/latest/highlights.html) and [What's New](https://monai.readthedocs.io/en/latest/whatsnew.html) of the milestone releases._\n\n- flexible pre-processing for multi-dimensional medical imaging data;\n- compositional & portable APIs for ease of integration in existing workflows;\n- domain-specific implementations for networks, losses, evaluation metrics and more;\n- customizable design for varying user expertise;\n- multi-GPU multi-node data parallelism support.\n\n## Requirements\n\nMONAI works with the [currently supported versions of Python](https://devguide.python.org/versions), and depends directly on NumPy and PyTorch with many optional dependencies.\n\n* Major releases of MONAI will have dependency versions stated for them. The current state of the `dev` branch in this repository is the unreleased development version of MONAI which typically will support current versions of dependencies and include updates and bug fixes to do so.\n* PyTorch support covers [the current version](https://github.com/pytorch/pytorch/releases) plus three previous minor versions. If compatibility issues with a PyTorch version and other dependencies arise, support for a version may be delayed until a major release.\n* Our support policy for other dependencies adheres for the most part to [SPEC0](https://scientific-python.org/specs/spec-0000), where dependency versions are supported where possible for up to two years. 
Discovered vulnerabilities or defects may require certain versions to be explicitly not supported.\n* See the `requirements*.txt` files for dependency version information.\n\n## Installation\n\nTo install [the current release](https://pypi.org/project/monai/), you can simply run:\n\n```bash\npip install monai\n```\n\nPlease refer to [the installation guide](https://monai.readthedocs.io/en/latest/installation.html) for other installation options.\n\n## Getting Started\n\n[MedNIST demo](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/2d_classification/mednist_tutorial.ipynb) and [MONAI for PyTorch Users](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/developer_guide.ipynb) are available on Colab.\n\nExamples and notebook tutorials are located at [Project-MONAI/tutorials](https://github.com/Project-MONAI/tutorials).\n\nTechnical documentation is available at [docs.monai.io](https://docs.monai.io).\n\n## Docker\n\nThe MONAI Docker image is available from [Dockerhub](https://hub.docker.com/r/projectmonai/monai),\ntagged as `latest` for the latest state of `dev` or with a release version. A slimmed down image can also be built\nlocally using `Dockerfile.slim`, see that file for instructions.\n\nTo get started with the latest MONAI, use `docker run -ti --rm --gpus all projectmonai/monai:latest /bin/bash`.\n\n## Citation\n\nIf you have used MONAI in your research, please cite us! 
The citation can be exported from: <https://arxiv.org/abs/2211.02701>.\n\n## Model Zoo\n\n[The MONAI Model Zoo](https://github.com/Project-MONAI/model-zoo) is a place for researchers and data scientists to share the latest and greatest models from the community.\nUtilizing [the MONAI Bundle format](https://monai.readthedocs.io/en/latest/bundle_intro.html) makes it easy to [get started](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo) building workflows with MONAI.\n\n## Contributing\n\nFor guidance on making a contribution to MONAI, see the [contributing guidelines](https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md).\n\n## Community\n\nJoin the conversation on Twitter/X [@ProjectMONAI](https://twitter.com/ProjectMONAI), [LinkedIn](https://www.linkedin.com/company/projectmonai), or join our [Slack channel](https://forms.gle/QTxJq3hFictp31UM9).\n\nAsk and answer questions over on [MONAI's GitHub Discussions tab](https://github.com/Project-MONAI/MONAI/discussions).\n\n## Links\n\n- Website: <https://project-monai.github.io/>\n- API documentation (milestone): <https://monai.readthedocs.io/>\n- API documentation (latest dev): <https://monai.readthedocs.io/en/latest/>\n- Code: <https://github.com/Project-MONAI/MONAI>\n- Project tracker: <https://github.com/Project-MONAI/MONAI/projects>\n- Issue tracker: <https://github.com/Project-MONAI/MONAI/issues>\n- Wiki: <https://github.com/Project-MONAI/MONAI/wiki>\n- Test status: <https://github.com/Project-MONAI/MONAI/actions>\n- PyPI package: <https://pypi.org/project/monai/>\n- conda-forge: <https://anaconda.org/conda-forge/monai>\n- Weekly previews: <https://pypi.org/project/monai-weekly/>\n- Docker Hub: <https://hub.docker.com/r/projectmonai/monai>\n"
  },
  {
    "path": "SECURITY.md",
    "content": "# Security Policy\n\n## Reporting a Vulnerability\nMONAI takes security seriously and appreciates your efforts to responsibly disclose vulnerabilities. If you discover a security issue, please report it as soon as possible.\n\nTo report a security issue:\n* Please use the GitHub Security Advisories tab to \"[Open a draft security advisory](https://github.com/Project-MONAI/MONAI/security/advisories/new)\".\n* Include a detailed description of the issue, steps to reproduce, potential impact, and any possible mitigations.\n* If applicable, please also attach proof-of-concept code or screenshots.\n* We aim to acknowledge your report within 72 hours and provide a status update as we investigate.\n* Please do not create public issues for security-related reports.\n\n## Disclosure Policy\n* We follow a coordinated disclosure approach.\n* We will not publicly disclose vulnerabilities until a fix has been developed and released.\n* Credit will be given to researchers who responsibly disclose vulnerabilities, if requested.\n\n## Acknowledgements\nWe greatly appreciate contributions from the security community and strive to recognize all researchers who help keep MONAI safe.\n"
  },
  {
    "path": "docs/.readthedocs.yaml",
    "content": "# Read the Docs configuration file\n# See https://docs.readthedocs.io/en/stable/config-file for details\n\nversion: 2\n\nbuild:\n  os: ubuntu-22.04\n  tools:\n    python: \"3.9\"\nsphinx:\n   configuration: docs/source/conf.py\npython:\n   install:\n   - requirements: docs/requirements.txt\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = build\n\n# https://github.com/Project-MONAI/MONAI/issues/4354\nexport PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION := python\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\tPIP_ROOT_USER_ACTION=ignore pip install -r requirements.txt\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\nclean:\n\trm -rf build/\n\trm -rf source/_gen\n\trm -rf source/*_properties.csv\n"
  },
  {
    "path": "docs/_static/custom.css",
    "content": "@import url('https://fonts.googleapis.com/css?family=Lekton:700|Roboto&display=swap');\nbody{font-family:'Roboto',sans-serif;}.wy-menu-vertical p.caption{color:#7cccc7;}\n*{font-variant-ligatures: none;}.autoclasstoc td {padding:0.2rem;line-height:normal;}\ndl.field-list>dt{word-break: normal}\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "-f https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp39-cp39-linux_x86_64.whl\ntorch>=2.4.1\npytorch-ignite==0.4.11\nnumpy>=1.20\nitk>=5.2\nnibabel\nparameterized\nscikit-image>=0.19.0\nscipy>=1.12.0; python_version >= '3.9'\ntensorboard\ncommonmark==0.9.1\nrecommonmark==0.6.0\nSphinx\npydata-sphinx-theme\nsphinxcontrib-applehelp\nsphinxcontrib-devhelp\nsphinxcontrib-htmlhelp\nsphinxcontrib-jsmath\nsphinxcontrib-qthelp\nsphinxcontrib-serializinghtml\nsphinx-autodoc-typehints==1.11.1\npandas\neinops\ntransformers>=4.53.0\nmlflow>=2.12.2\nclearml>=1.10.0rc0\ntensorboardX\nimagecodecs; platform_system == \"Linux\" or platform_system == \"Darwin\"\ntifffile; platform_system == \"Linux\" or platform_system == \"Darwin\"\npyyaml\nfire\njsonschema\npynrrd\npydicom\nh5py\nnni; platform_system == \"Linux\"\noptuna\nopencv-python-headless\nonnx>=1.13.0\nonnxruntime; python_version <= '3.10'\nzarr\nhuggingface_hub\npyamg>=5.0.0, <5.3.0\npackaging\npolygraphy\n"
  },
  {
    "path": "docs/source/api.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\nAPI Reference\n=============\n\n.. toctree::\n   :maxdepth: 1\n\n   apps\n   auto3dseg\n   fl\n   bundle\n   transforms\n   losses\n   networks\n   metrics\n   optimizers\n   data\n   engines\n   inferers\n   handlers\n   visualize\n   utils\n"
  },
  {
    "path": "docs/source/apidocs/modules.rst",
    "content": ":orphan:\n\nmonai\n=====\n\n.. toctree::\n   :maxdepth: 4\n\n   monai\n"
  },
  {
    "path": "docs/source/apidocs/monai.rst",
    "content": "monai package\n=============\n\nModule contents\n---------------\n\n.. automodule:: monai\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/applications.md",
    "content": "# Research and Application Highlights\n\n### COPLE-Net for COVID-19 Pneumonia Lesion Segmentation\n[A reimplementation](https://project-monai.github.io/research/coplenet-pneumonia-lesion-segmentation.html) of the COPLE-Net originally proposed by:\n\nG. Wang, X. Liu, C. Li, Z. Xu, J. Ruan, H. Zhu, T. Meng, K. Li, N. Huang, S. Zhang. (2020) \"A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions from CT Images.\" IEEE Transactions on Medical Imaging. 2020. [DOI: 10.1109/TMI.2020.3000314](https://doi.org/10.1109/TMI.2020.3000314)\n![coplenet](../images/coplenet.png)\n\n### LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation\n[A reimplementation](https://project-monai.github.io/research/lamp-automated-model-parallelism.html) of the LAMP system originally proposed by:\n\nWentao Zhu, Can Zhao, Wenqi Li, Holger Roth, Ziyue Xu, and Daguang Xu (2020) \"LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation.\" MICCAI 2020 (Early Accept, paper link: https://arxiv.org/abs/2006.12575)\n\n![LAMP UNet](../images/unet-pipe.png)\n\n### DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation\nMONAI integrated the `DiNTS` module to support more flexible topologies and joint two-level search. It provides a topology guaranteed discretization algorithm and a discretization aware topology loss for the search stage to minimize the discretization gap, and a cost usage aware search method which can search 3D networks with different GPU memory requirements. 
For more details, please check the [DiNTS tutorial](https://project-monai.github.io/research/dints.html).\n\n![DiNTS](../images/dints-overview.png)\n\n### Accounting for Dependencies in Deep Learning Based Multiple Instance Learning for Whole Slide Imaging\nFor [classification of digital pathology whole slide images (WSI)](https://arxiv.org/abs/2111.01556), MONAI introduces new transforms and network modules for multiple instance learning. These include self-attention transformer blocks for explicitly accounting for the dependencies between instances (image patches) during training. For more details, please check out the [multiple instance learning tutorial](https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning). ![multi-instance](../images/mil-patches.jpg)\n\n### Self-supervised representation learning\nMONAI starts to explore self-supervised representation learning in this milestone release. The Vision Transformer has been extended to learn from self-supervised reconstruction tasks with various data augmentation and a regularized contrastive loss. The weights of the pre-trained backbone could be used to enhance the performance of the novel downstream deep learning tasks.\n\nThe [tutorial](https://github.com/Project-MONAI/tutorials/tree/master/self_supervised_pretraining) shows how to generate a good set of pre-trained weights using unlabeled data with self-supervised tasks, then use the pre-trained weights to perform fine-tuning on a fully supervised volumetric segmentation task using a transformer based `UNETR`.\n\n![self-supervised](../images/ssl_overview.png)\n\n### Swin UNETR model for the task of multi-organ segmentation\nFor [Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images](https://arxiv.org/abs/2201.01266), MONAI introduces new network modules for multi-organ segmentation task using the BTCV challenge dataset. 
The architecture of Swin UNETR:\n\n![swin-unetr](../images/swin_unetr.png)\n\nThe [tutorial](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb) shows a typical pipeline of multi-organ segmentation based on Swin UNETR model, DiceCE loss function, Mean Dice, etc. And we used weights from self-supervised pre-training of Swin UNETR encoder (3D Swin Transformer) on a cohort of 5050 CT scans from publicly available datasets.\n\n### DeepGrow modules for interactive segmentation\n[A reimplementation](https://github.com/Project-MONAI/MONAI/tree/master/monai/apps/deepgrow) of the DeepGrow components, which is a deep learning based semi-automated segmentation approach that aims to be a \"smart\" interactive tool for region of interest delineation in medical images, originally proposed by:\n\nSakinis, Tomas, et al. \"Interactive segmentation of medical images through fully convolutional neural networks.\" arXiv preprint arXiv:1903.08205 (2019).\n\n![deepgrow scheme](../images/deepgrow.png)\n\n### DeepEdit workflow for interactive segmentation\nDeepEdit is a method that combines an automatic and a semi-automatic approach for 3D medical images into a single deep learning-based model. The [implementation](https://github.com/Project-MONAI/MONAI/tree/dev/monai/apps/deepedit) of the DeepEdit modules provides essential components for interactive segmentation. More details are available in the training and inference [tutorial](https://github.com/Project-MONAI/tutorials/tree/main/deepedit/ignite).\n\nThe following figure shows the typical workflow of interactive segmentation:\n\n![deepedit workflow](../images/deepedit.png)\n\n### NuClick modules for interactive nuclei segmentation\nNuClick is a CNN-based approach to speed up collecting annotations for microscopic objects requiring minimum interaction from the annotator. 
The [implementation](https://github.com/Project-MONAI/MONAI/tree/dev/monai/apps/nuclick) contains essential components for the training and inference workflows of NuClick interactive nuclei segmentation.\n\nThe following figure shows example outputs of NuClick (the annotator clicks inside the nucleus and the mask is generated by the CNN):\n\n![nuclick output](../images/nuclick.png)\n\n### Lesion detection in digital pathology\n[Implementation](https://github.com/Project-MONAI/MONAI/tree/master/monai/apps/pathology) of the pathology detection components, which includes efficient whole slide imaging IO and several patch sampling methods with NVIDIA cuCIM library and SmartCache mechanism, FROC measurements for lesion and probabilistic post-processing for lesion detection.\n\n![digital pathology](../images/pathology.png)\n\n### Learning-based image registration\nStarting from v0.5.0, MONAI provides experimental features for building learning-based 2D/3D registration workflows.  These include image similarity measures as loss functions, bending energy as model regularization, network architectures, warping modules. 
The components can be used to build the major unsupervised and weakly-supervised algorithms.\n\nThe following figure shows the registration of CT images acquired at different time points for a single patient using MONAI:\n\n![3d registration](../images/3d_paired.png)\n\n### 2D and 3D detection workflow\nThe [implementation](https://github.com/Project-MONAI/MONAI/tree/dev/monai/apps/detection) contains 2D and 3D bounding box detection components of `RetinaNet`, which includes: bounding box operations, hard negative sampler, and RetinaNet detectors.\n\nThe following figure shows the detection training and inference workflows:\n\n![detection workflow](../images/detection.png)\n\n### Reproducing the state-of-the-art Kaggle competition solutions\n[A reimplementation](https://github.com/Project-MONAI/tutorials/tree/main/competitions/kaggle/RANZCR/4th_place_solution) of the 4th place solution of RANZCR CLiP - Catheter and Line Position Challenge in Kaggle: https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification\n"
  },
  {
    "path": "docs/source/apps.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _apps:\n\nApplications\n============\n.. currentmodule:: monai.apps\n\n`Datasets`\n----------\n\n.. autoclass:: MedNISTDataset\n    :members:\n\n.. autoclass:: DecathlonDataset\n    :members:\n\n.. autoclass:: TciaDataset\n    :members:\n\n.. autoclass:: CrossValidation\n    :members:\n\n\n`Clara MMARs`\n-------------\n.. autofunction:: download_mmar\n\n.. autofunction:: load_from_mmar\n\n.. autodata:: monai.apps.MODEL_DESC\n    :annotation:\n\n\n`Utilities`\n-----------\n\n.. autofunction:: check_hash\n\n.. autofunction:: download_url\n\n.. autofunction:: extractall\n\n.. autofunction:: download_and_extract\n\n`Deepgrow`\n----------\n\n.. automodule:: monai.apps.deepgrow.dataset\n.. autofunction:: create_dataset\n\n.. automodule:: monai.apps.deepgrow.interaction\n.. autoclass:: Interaction\n    :members:\n\n.. automodule:: monai.apps.deepgrow.transforms\n.. autoclass:: AddInitialSeedPointd\n    :members:\n.. autoclass:: AddGuidanceSignald\n    :members:\n.. autoclass:: AddRandomGuidanced\n    :members:\n.. autoclass:: AddGuidanceFromPointsd\n    :members:\n.. autoclass:: SpatialCropForegroundd\n    :members:\n.. autoclass:: SpatialCropGuidanced\n    :members:\n.. autoclass:: RestoreLabeld\n    :members:\n.. autoclass:: ResizeGuidanced\n    :members:\n.. autoclass:: FindDiscrepancyRegionsd\n    :members:\n.. autoclass:: FindAllValidSlicesd\n    :members:\n.. autoclass:: Fetch2DSliced\n    :members:\n\n`Pathology`\n-----------\n\n.. automodule:: monai.apps.pathology.inferers\n.. autoclass:: SlidingWindowHoVerNetInferer\n    :members:\n\n.. automodule:: monai.apps.pathology.losses.hovernet_loss\n.. autoclass:: HoVerNetLoss\n    :members:\n\n.. automodule:: monai.apps.pathology.metrics\n.. autoclass:: LesionFROC\n    :members:\n\n.. automodule:: monai.apps.pathology.utils\n.. autofunction:: compute_multi_instance_mask\n.. autofunction:: compute_isolated_tumor_cells\n.. 
autoclass:: PathologyProbNMS\n    :members:\n\n.. automodule:: monai.apps.pathology.transforms.stain.array\n.. autoclass:: ExtractHEStains\n    :members:\n.. autoclass:: NormalizeHEStains\n    :members:\n\n.. automodule:: monai.apps.pathology.transforms.stain.dictionary\n.. autoclass:: ExtractHEStainsd\n    :members:\n.. autoclass:: NormalizeHEStainsd\n    :members:\n\n.. automodule:: monai.apps.pathology.transforms.post.array\n.. autoclass:: GenerateSuccinctContour\n    :members:\n.. autoclass:: GenerateInstanceContour\n    :members:\n.. autoclass:: GenerateInstanceCentroid\n    :members:\n.. autoclass:: GenerateInstanceType\n    :members:\n.. autoclass:: Watershed\n    :members:\n.. autoclass:: GenerateWatershedMask\n    :members:\n.. autoclass:: GenerateInstanceBorder\n    :members:\n.. autoclass:: GenerateDistanceMap\n    :members:\n.. autoclass:: GenerateWatershedMarkers\n    :members:\n.. autoclass:: HoVerNetNuclearTypePostProcessing\n    :members:\n.. autoclass:: HoVerNetInstanceMapPostProcessing\n    :members:\n\n.. automodule:: monai.apps.pathology.transforms.post.dictionary\n.. autoclass:: GenerateSuccinctContourd\n    :members:\n.. autoclass:: GenerateInstanceContourd\n    :members:\n.. autoclass:: GenerateInstanceCentroidd\n    :members:\n.. autoclass:: GenerateInstanceTyped\n    :members:\n.. autoclass:: Watershedd\n    :members:\n.. autoclass:: GenerateWatershedMaskd\n    :members:\n.. autoclass:: GenerateInstanceBorderd\n    :members:\n.. autoclass:: GenerateDistanceMapd\n    :members:\n.. autoclass:: GenerateWatershedMarkersd\n    :members:\n.. autoclass:: HoVerNetInstanceMapPostProcessingd\n    :members:\n.. autoclass:: HoVerNetNuclearTypePostProcessingd\n    :members:\n\n`Detection`\n-----------\n\n`Hard Negative Sampler`\n~~~~~~~~~~~~~~~~~~~~~~~\n.. automodule:: monai.apps.detection.utils.hard_negative_sampler\n    :members:\n\n`RetinaNet Network`\n~~~~~~~~~~~~~~~~~~~\n.. 
automodule:: monai.apps.detection.networks.retinanet_network\n    :members:\n\n`RetinaNet Detector`\n~~~~~~~~~~~~~~~~~~~~\n.. automodule:: monai.apps.detection.networks.retinanet_detector\n    :members:\n\n`Transforms`\n~~~~~~~~~~~~\n.. automodule:: monai.apps.detection.transforms.box_ops\n    :members:\n.. automodule:: monai.apps.detection.transforms.array\n    :members:\n.. automodule:: monai.apps.detection.transforms.dictionary\n    :members:\n\n`Anchor`\n~~~~~~~~\n.. automodule:: monai.apps.detection.utils.anchor_utils\n    :members:\n\n`Matcher`\n~~~~~~~~~\n.. automodule:: monai.apps.detection.utils.ATSS_matcher\n    :members:\n\n`Box coder`\n~~~~~~~~~~~\n.. automodule:: monai.apps.detection.utils.box_coder\n    :members:\n\n`Detection Utilities`\n~~~~~~~~~~~~~~~~~~~~~\n.. automodule:: monai.apps.detection.utils.detector_utils\n    :members:\n\n.. automodule:: monai.apps.detection.utils.predict_utils\n    :members:\n\n`Inference box selector`\n~~~~~~~~~~~~~~~~~~~~~~~~\n.. automodule:: monai.apps.detection.utils.box_selector\n    :members:\n\n`Detection metrics`\n~~~~~~~~~~~~~~~~~~~\n.. automodule:: monai.apps.detection.metrics.coco\n    :members:\n.. automodule:: monai.apps.detection.metrics.matching\n    :members:\n\n`Reconstruction`\n----------------\n\nFastMRIReader\n~~~~~~~~~~~~~\n.. autoclass:: monai.apps.reconstruction.fastmri_reader.FastMRIReader\n  :members:\n\n`ConvertToTensorComplex`\n~~~~~~~~~~~~~~~~~~~~~~~~\n.. autofunction:: monai.apps.reconstruction.complex_utils.convert_to_tensor_complex\n\n`ComplexAbs`\n~~~~~~~~~~~~\n.. autofunction:: monai.apps.reconstruction.complex_utils.complex_abs\n\n`RootSumOfSquares`\n~~~~~~~~~~~~~~~~~~\n.. autofunction:: monai.apps.reconstruction.mri_utils.root_sum_of_squares\n\n`ComplexMul`\n~~~~~~~~~~~~\n.. autofunction:: monai.apps.reconstruction.complex_utils.complex_mul\n\n`ComplexConj`\n~~~~~~~~~~~~~\n.. autofunction:: monai.apps.reconstruction.complex_utils.complex_conj\n\n`Vista3d`\n---------\n.. 
automodule:: monai.apps.vista3d.inferer\n.. autofunction:: point_based_window_inferer\n\n.. automodule:: monai.apps.vista3d.transforms\n.. autoclass:: VistaPreTransformd\n    :members:\n.. autoclass:: VistaPostTransformd\n    :members:\n.. autoclass:: Relabeld\n    :members:\n\n.. automodule:: monai.apps.vista3d.sampler\n.. autofunction:: sample_prompt_pairs\n\n`Auto3DSeg`\n-----------\n.. automodule:: monai.apps.auto3dseg\n  :members:\n  :special-members: __call__\n  :imported-members:\n\n`nnUNet`\n--------\n.. automodule:: monai.apps.nnunet.__main__\n\n.. autoclass:: monai.apps.nnunet.nnUNetV2Runner\n  :members:\n\n`nnUNet Bundle`\n---------------\n.. autoclass:: monai.apps.nnunet.ModelnnUNetWrapper\n    :members:\n    :special-members:\n\n.. autofunction:: monai.apps.nnunet.get_nnunet_trainer\n.. autofunction:: monai.apps.nnunet.get_nnunet_monai_predictor\n.. autofunction:: monai.apps.nnunet.convert_nnunet_to_monai_bundle\n.. autofunction:: monai.apps.nnunet.convert_monai_bundle_to_nnunet\n.. autofunction:: monai.apps.nnunet.get_network_from_nnunet_plans\n"
  },
  {
    "path": "docs/source/auto3dseg.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _auto3dseg:\n\nAuto3dseg\n=========\n\n.. automodule:: monai.auto3dseg\n  :members:\n  :imported-members:\n"
  },
  {
    "path": "docs/source/bundle.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _bundle:\n\nModel Bundle\n============\n.. currentmodule:: monai.bundle\n\n`Config Item`\n-------------\n.. autoclass:: Instantiable\n    :members:\n\n.. autoclass:: ComponentLocator\n    :members:\n\n.. autoclass:: ConfigComponent\n    :members:\n\n.. autoclass:: ConfigExpression\n    :members:\n\n.. autoclass:: ConfigItem\n    :members:\n\n`Reference Resolver`\n--------------------\n.. autoclass:: ReferenceResolver\n    :members:\n\n`Config Parser`\n---------------\n.. autoclass:: ConfigParser\n    :members:\n    :special-members:\n\n\n`Scripts`\n---------\n.. autofunction:: ckpt_export\n.. autofunction:: trt_export\n.. autofunction:: onnx_export\n.. autofunction:: download\n.. autofunction:: load\n.. autofunction:: get_all_bundles_list\n.. autofunction:: get_bundle_info\n.. autofunction:: get_bundle_versions\n.. autofunction:: run\n.. autofunction:: verify_metadata\n.. autofunction:: verify_net_in_out\n.. autofunction:: init_bundle\n.. autofunction:: push_to_hf_hub\n.. autofunction:: update_kwargs\n"
  },
  {
    "path": "docs/source/bundle_intro.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\nBundle\n======\n\nMONAI Bundles are a specification and file structure based way of distributing trained MONAI models with associated\nmetadata, code, documentation, and other resources. These are meant to make it easier for you to distribute your model\nin a format that explains what the model is for, how to use it, how to reproduce the science you've done with it, and\nuse it in other applications such as Label and Deploy.\n\n.. toctree::\n   :maxdepth: 1\n\n   mb_specification\n   config_syntax.md\n   mb_properties\n\nDetailed bundle examples and get started tutorial: https://github.com/Project-MONAI/tutorials/tree/main/bundle\n\nA collection of medical imaging models in the MONAI Bundle format: https://github.com/Project-MONAI/model-zoo\n\n\nBundle vs. MAPs\n---------------\n\nBundles differ from MONAI Application Packages (MAPs) in that they focus on description, code definition, application,\nusage. MAPs focus on deployment, containerisation, integration into existing clinical systems, and other application\nareas relating to putting models into use.\n\n.. image:: ../images/MONAI_clouds.png\n    :alt: Bundle and MAP Concepts\n    :align: center\n\nAs a user, bundles are networks and \"programs\" you would use directly for training, inference, reproducing results,\nand other tasks. Bundles can be integrated into MONAI Label apps to perform segmentation tasks through user interfaces,\nor into MAPs for deployment. They can be integrated into other container environments but this isn't their focus. A\nbundle in general is a more lightweight concept with less infrastructure\n\nFor all applications relating to containerisation, portability, and deployment, MAPs are what you're looking for. A MAP\nis the contained environment for running an inference application directly or within an orchestration system. 
A bundle\nalone doesn't have the structure suitable for this use; a MAP must be provided which uses a bundle as the inference object.\nMAPs are also meant for inference only, unlike bundles, which should include training scripts. DICOM access is emphasised in\nMAPs since they are meant for clinical deployment and so must interface with clinical databases.\n"
  },
  {
    "path": "docs/source/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport subprocess\nimport sys\nimport importlib\nimport inspect\n\nsys.path.insert(0, os.path.abspath(\"..\"))\nsys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), \"..\", \"..\")))\nprint(sys.path)\n\nimport monai  # noqa: E402\n\n# -- Project information -----------------------------------------------------\nproject = \"MONAI\"\ncopyright = \"MONAI Consortium\"\nauthor = \"MONAI Contributors\"\n\n# The full version, including alpha/beta/rc tags\nshort_version = monai.__version__.split(\"+\")[0]\nrelease = short_version\nversion = short_version\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = [\n    \"transforms\",\n    \"networks\",\n    \"metrics\",\n    \"engines\",\n    \"data\",\n    \"apps\",\n    \"fl\",\n    \"bundle\",\n    \"config\",\n    \"handlers\",\n    \"losses\",\n    \"visualize\",\n    \"utils\",\n    \"inferers\",\n    \"optimizers\",\n    \"auto3dseg\",\n]\n\n\ndef generate_apidocs(*args):\n    \"\"\"Generate API docs automatically by trawling the available modules\"\"\"\n\n    import pandas as pd\n    from monai.bundle.properties import TrainProperties, InferProperties, MetaProperties\n\n    csv_file = os.path.join(os.path.dirname(__file__), \"train_properties.csv\")  # used 
in mb_properties.rst\n    pd.DataFrame.from_dict(TrainProperties, orient=\"index\").iloc[:, :3].to_csv(csv_file)\n    csv_file = os.path.join(os.path.dirname(__file__), \"infer_properties.csv\")\n    pd.DataFrame.from_dict(InferProperties, orient=\"index\").iloc[:, :3].to_csv(csv_file)\n    csv_file = os.path.join(os.path.dirname(__file__), \"meta_properties.csv\")\n    pd.DataFrame.from_dict(MetaProperties, orient=\"index\").iloc[:, :3].to_csv(csv_file)\n\n    module_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, \"monai\"))\n    output_path = os.path.abspath(os.path.join(os.path.dirname(__file__), \"apidocs\"))\n    apidoc_command_path = \"sphinx-apidoc\"\n    if hasattr(sys, \"real_prefix\"):  # called from a virtualenv\n        apidoc_command_path = os.path.join(sys.prefix, \"bin\", \"sphinx-apidoc\")\n        apidoc_command_path = os.path.abspath(apidoc_command_path)\n    print(f\"output_path {output_path}\")\n    print(f\"module_path {module_path}\")\n    subprocess.check_call(\n        [apidoc_command_path, \"-e\"]\n        + [\"-o\", output_path]\n        + [module_path]\n        + [os.path.join(module_path, p) for p in exclude_patterns]\n    )\n\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nsource_suffix = {\".rst\": \"restructuredtext\", \".txt\": \"restructuredtext\", \".md\": \"markdown\"}\n\nextensions = [\n    \"recommonmark\",\n    \"sphinx.ext.intersphinx\",\n    \"sphinx.ext.mathjax\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.linkcode\",\n    \"sphinx.ext.autosectionlabel\",\n    \"sphinx.ext.autosummary\",\n    \"sphinx_autodoc_typehints\",\n]\n\nautoclass_content = \"class\"\nadd_module_names = True\nsource_encoding = \"utf-8\"\nautosectionlabel_prefix_document = True\nnapoleon_use_param = True\nnapoleon_include_init_with_doc = True\nset_type_checking_flag = True\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = \"pydata_sphinx_theme\"\n# html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]\nhtml_theme_options = {\n    \"external_links\": [{\"url\": \"https://github.com/Project-MONAI/tutorials\", \"name\": \"Tutorials\"}],\n    \"icon_links\": [\n        {\"name\": \"GitHub\", \"url\": \"https://github.com/project-monai/monai\", \"icon\": \"fab fa-github-square\"},\n        {\"name\": \"Twitter\", \"url\": \"https://twitter.com/projectmonai\", \"icon\": \"fab fa-twitter-square\"},\n    ],\n    \"collapse_navigation\": True,\n    \"navigation_with_keys\": True,\n    \"navigation_depth\": 1,\n    \"show_toc_level\": 1,\n    \"footer_start\": [\"copyright\"],\n    \"navbar_align\": \"content\",\n    \"logo\": {\"image_light\": \"MONAI-logo-color.png\", \"image_dark\": \"MONAI-logo-color.png\"},\n}\nhtml_context = {\n    \"github_user\": \"Project-MONAI\",\n    \"github_repo\": \"MONAI\",\n    \"github_version\": \"dev\",\n    \"doc_path\": 
\"docs/source\",\n    \"conf_py_path\": \"/docs/source\",\n    \"VERSION\": version,\n}\nhtml_scaled_image_link = False\nhtml_show_sourcelink = True\nhtml_favicon = \"../images/favicon.ico\"\nhtml_logo = \"../images/MONAI-logo-color.png\"\nhtml_sidebars = {\"**\": [\"search-field\", \"sidebar-nav-bs\"]}\npygments_style = \"sphinx\"\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = [\"../_static\"]\nhtml_css_files = [\"custom.css\"]\nhtml_title = f\"{project} {version} Documentation\"\n\n# -- Auto-convert markdown pages to demo --------------------------------------\n\n\ndef setup(app):\n    # Hook to allow for automatic generation of API docs\n    # before doc deployment begins.\n    app.connect(\"builder-inited\", generate_apidocs)\n\n\n# -- Linkcode configuration --------------------------------------------------\nDEFAULT_REF = \"dev\"\nread_the_docs_ref = os.environ.get(\"READTHEDOCS_GIT_IDENTIFIER\", None)\nif read_the_docs_ref:\n    # When building on ReadTheDocs, link to the specific commit\n    # https://docs.readthedocs.io/en/stable/reference/environment-variables.html#envvar-READTHEDOCS_GIT_IDENTIFIER\n    git_ref = read_the_docs_ref\nelif os.environ.get(\"GITHUB_REF_TYPE\", \"branch\") == \"tag\":\n    # When building a tag, link to the tag itself\n    git_ref = os.environ.get(\"GITHUB_REF\", DEFAULT_REF)\nelse:\n    git_ref = os.environ.get(\"GITHUB_SHA\", DEFAULT_REF)\n\nDEFAULT_REPOSITORY = \"Project-MONAI/MONAI\"\nrepository = os.environ.get(\"GITHUB_REPOSITORY\", DEFAULT_REPOSITORY)\n\nbase_code_url = f\"https://github.com/{repository}/blob/{git_ref}\"\nMODULE_ROOT_FOLDER = \"monai\"\nrepo_root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), \"..\", \"..\"))\n\n\n# Adjusted from 
https://github.com/python-websockets/websockets/blob/main/docs/conf.py\ndef linkcode_resolve(domain, info):\n    if domain != \"py\":\n        raise ValueError(\n            f\"expected domain to be 'py', got {domain}.\"\n            \"Please adjust linkcode_resolve to either handle this domain or ignore it.\"\n        )\n\n    mod = importlib.import_module(info[\"module\"])\n    if \".\" in info[\"fullname\"]:\n        objname, attrname = info[\"fullname\"].split(\".\")\n        obj = getattr(mod, objname)\n        try:\n            # object is a method of a class\n            obj = getattr(obj, attrname)\n        except AttributeError:\n            # object is an attribute of a class\n            return None\n    else:\n        obj = getattr(mod, info[\"fullname\"])\n\n    try:\n        file = inspect.getsourcefile(obj)\n        source, lineno = inspect.getsourcelines(obj)\n    except TypeError:\n        # e.g. object is a typing.Union\n        return None\n    file = os.path.relpath(file, repo_root_path)\n    if not file.startswith(MODULE_ROOT_FOLDER):\n        # e.g. object is a typing.NewType\n        return None\n    start, end = lineno, lineno + len(source) - 1\n    url = f\"{base_code_url}/{file}#L{start}-L{end}\"\n    return url\n"
  },
  {
    "path": "docs/source/config_syntax.md",
    "content": "# MONAI Bundle Configuration\n\nThe `monai.bundle` module supports building Python-based workflows via structured configurations.\n\nThe main benefits are threefold:\n\n- it provides good readability and usability by separating system parameter settings from the Python code.\n- it describes workflow at a relatively high level and allows for different low-level implementations.\n- learning paradigms at a higher level such as federated learning and AutoML can be decoupled from the component details.\n\nContent:\n\n- [A basic example](#a-basic-example)\n- [Syntax examples explained](#syntax-examples-explained)\n  - [`@` to reference Python objects in configurations](#to-reference-python-objects-in-configurations)\n  - [`$` to evaluate as Python expressions](#to-evaluate-as-python-expressions)\n  - [`%` to textually replace configuration elements](#to-textually-replace-configuration-elements)\n  - [`_target_` (`_disabled_`, `_desc_`, `_requires_`, `_mode_`) to instantiate a Python object](#instantiate-a-python-object)\n  - [`+` to alter semantics of merging config keys from multiple configuration files](#multiple-config-files)\n- [The command line interface](#the-command-line-interface)\n- [Recommendations](#recommendations)\n\n## A basic example\n\nComponents as part of a workflow can be specified using `JSON` or `YAML` syntax, for example, a network architecture\ndefinition could be stored in a `demo_config.json` file with the following content:\n\n```json\n{\n  \"demo_net\": {\n    \"_target_\": \"monai.networks.nets.BasicUNet\",\n    \"spatial_dims\": 3,\n    \"in_channels\": 1,\n    \"out_channels\": 2,\n    \"features\": [16, 16, 32, 32, 64, 64]\n  }\n}\n```\n\nor alternatively, in `YAML` format (`demo_config.yaml`):\n\n```yaml\ndemo_net:\n  _target_: monai.networks.nets.BasicUNet\n  spatial_dims: 3\n  in_channels: 1\n  out_channels: 2\n  features: [16, 16, 32, 32, 64, 64]\n```\n\nThe configuration parser can instantiate the component as a Python 
object:\n\n```py\n>>> from monai.bundle import ConfigParser\n>>> config = ConfigParser()\n>>> config.read_config(\"demo_config.json\")\n>>> net = config.get_parsed_content(\"demo_net\", instantiate=True)\nBasicUNet features: (16, 16, 32, 32, 64, 64).\n>>> print(type(net))\n<class 'monai.networks.nets.basic_unet.BasicUNet'>\n```\n\nor additionally, tune the input parameters then instantiate the component:\n\n```py\n>>> config[\"demo_net\"][\"features\"] = [32, 32, 32, 64, 64, 64]\n>>> net = config.get_parsed_content(\"demo_net\", instantiate=True)\nBasicUNet features: (32, 32, 32, 64, 64, 64).\n```\n\nFor more details on the `ConfigParser` API, please see [`monai.bundle.ConfigParser`](https://monai.readthedocs.io/en/latest/bundle.html#config-parser).\n\n## Syntax examples explained\n\nA few characters and keywords are interpreted beyond plain text; here are examples of the syntax:\n\n### To reference Python objects in configurations\n\n```json\n\"@preprocessing::transforms::keys\"\n```\n\n_Description:_ `@` character indicates a reference to another configuration value defined at `preprocessing::transforms::keys`,\nwhere `::` indicates a sub-structure of this configuration file. (`#` is a synonym for `::`, `preprocessing#transforms#keys`\nrefers to the same object.)\n\n```json\n\"@preprocessing::1\"\n```\n\n_Description:_ `1` is interpreted as an integer, which is used to index (zero-based indexing) the `preprocessing` sub-structure.\n\nRelative reference is supported by starting the reference with `#`. 
For example, `@#A` is to use `A` at the\nsame config structure level, and `@##A` refers to `A` at one level above.\n\n### To evaluate as Python expressions\n\n```json\n\"$print(42)\"\n```\n\n_Description:_ `$` is a special character to indicate evaluating `print(42)` at runtime.\n\n```json\n\"$[i for i in @datalist]\"\n```\n\n_Description:_ Create a list at runtime using the values in `datalist` as input.\n\n```json\n\"$from torchvision.models import resnet18\"\n```\n\n_Description:_ `$` followed by an import statement is handled slightly differently from the\nPython expressions. The imported module `resnet18` will be available as a global variable\nto the other configuration sections. This is to simplify the use of external modules in the configuration.\n\nThe config expressions may use `@` to reference other config items. For example, in `$lambda x: x + @a + @b`,\n`@a` and `@b` are references to other Python objects and are made available to the anonymous function\nas 'globals'.\nIt's therefore possible to modify the Python objects within an expression, for example,\n`$lambda x: @my_list.pop() + x` will pop the last element from `@my_list` and add it to `x`.\n\n### To textually replace configuration elements\n\n```json\n\"%demo_config.json::demo_net::in_channels\"\n```\n\n_Description:_ `%` character indicates a macro to replace the current configuration element with the texts at `demo_net::in_channels` in the\n`demo_config.json` file. 
The replacement is done before instantiating or evaluating the components.\n\n### Instantiate a Python object\n\n```json\n{\n  \"demo_name\":{\n    \"_target_\": \"my.python.module.Class\",\n    \"args1\": \"string\",\n    \"args2\": 42}\n}\n```\n\n_Description:_ This dictionary defines an object with a reference name `demo_name`, with an instantiable type\nspecified at `_target_` and with input arguments `args1` and `args2`.\nThis dictionary will be instantiated as a Pytorch object at runtime.\n\n`_target_` is a required key by monai bundle syntax for the Python object name.\n`args1` and `args2` should be compatible with the Python object to instantiate.\n\n```json\n{\n  \"component_name\": {\n    \"_target_\": \"my.module.Class\",\n    \"_desc_\": \"this is a customized class which also triggers 'cudnn_opt' reference\",\n    \"_requires_\": \"@cudnn_opt\",\n    \"_disabled_\": \"true\",\n    \"_mode_\":  \"default\"}\n}\n```\n\n_Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional keys.\n- `_requires_` specifies references (string starts with `@`) or\n  Python expression that will be evaluated/instantiated before `_target_` object is instantiated.\n  It is useful when the component does not explicitly depend on the other ConfigItems via\n  its arguments, but requires the dependencies to be instantiated/evaluated beforehand.\n- `_disabled_` specifies a flag to indicate whether to skip the instantiation.\n- `_desc_` can be used for providing free text descriptions.\n- `_mode_` specifies the operating mode when the component is instantiated or the callable is called.\n  it currently supports the following values:\n  - `\"default\"` (default) -- return the return value of ``_target_(**kwargs)``\n  - `\"callable\"` -- return a callable, either as ``_target_`` itself or, if ``kwargs`` are provided, as a\n    partial function of ``functools.partial(_target_, **kwargs)``. 
Useful for defining a class or function\n    that will be instantiated or called later. Users can pre-define some arguments to the ``_target_`` and call\n    it with additional arguments later.\n  - `\"debug\"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``,\n    see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).\n\n## Multiple config files\n\n_Description:_ Multiple config files may be specified on the command line.\nThe content of those config files is merged. When the same keys are specified in more than one config file,\nthe value associated with the key is overridden, in the order config files are specified.\nIf the desired behaviour is to merge values from both files, the key in the second config file should be prefixed with `+`.\nThe value types for the merged contents must match and be both of `dict` or both of `list` type.\n`dict` values will be merged via update(), `list` values - concatenated via extend().\nHere's an example. In this case, the \"amp\" value will be overridden by extra_config.json.\n`imports` and `preprocessing#transforms` lists will be merged. 
An error would be thrown if the value type in `\"+imports\"` is not `list`:\n\nconfig.json:\n```json\n{\n    \"amp\": \"$True\",\n    \"imports\": [\n\t\"$import torch\"\n    ],\n    \"preprocessing\": {\n        \"_target_\": \"Compose\",\n        \"transforms\": [\n\t  \"$@t1\",\n\t  \"$@t2\"\n        ]\n    }\n}\n```\n\nextra_config.json:\n```json\n{\n    \"amp\": \"$False\",\n    \"+imports\": [\n\t\"$from monai.networks import trt_compile\"\n    ],\n    \"+preprocessing#transforms\": [\n        \"$@t3\"\n    ]\n}\n```\n\n## The command line interface\n\nIn addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle.\nThe primary usage is:\n```bash\npython -m monai.bundle COMMANDS\n```\n\nwhere `COMMANDS` is one of the following: `run`, `verify_metadata`, `ckpt_export`, ...\n(please see `python -m monai.bundle --help` for a list of available options).\n\nThe CLI supports flexible use cases, such as overriding configs at runtime and predefining arguments in a file.\nTo display a usage page for a command, for example `run`:\n```bash\npython -m monai.bundle run -- --help\n```\n\nThe support is provided by [Python Fire](https://github.com/google/python-fire); please\nmake sure the optional dependency is installed, for example,\nusing `pip install monai[fire]` or `pip install fire`.\nDetails on the CLI argument parsing are provided in the\n[Python Fire Guide](https://github.com/google/python-fire/blob/master/docs/guide.md#argument-parsing).\n\n## Recommendations\n- Both `YAML` and `JSON` are supported, but the advanced features of these formats are not supported.\n- Using meaningful names for the configuration elements can improve the readability.\n- While it is possible to build complex configurations with the bundle syntax,\n  simple structures with sparse uses of expressions or references are preferred.\n- For `$import <module>` in the configuration, please make sure there are instructions for the users to install\n  
the `<module>` if it is not an (optional) dependency of MONAI.\n- As `#`, `::`, and `$` might be interpreted differently by the `shell` or `CLI` tools, you may need to add escape characters\n  or quotes for them in the command line, like: `\"\\$torch.device('cuda:1')\"`, `\"'train_part#trainer'\"`.\n- For more details and examples, please see [the tutorials](https://github.com/Project-MONAI/tutorials/tree/main/bundle).\n"
  },
  {
    "path": "docs/source/contrib.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\nDevelopment\n===========\n\nFor guidance on making a contribution to MONAI, see the `contributing guidelines\n<https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md>`_.\n"
  },
  {
    "path": "docs/source/data.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _data:\n\nData\n====\n\nGeneric Interfaces\n------------------\n.. currentmodule:: monai.data\n\n`Dataset`\n~~~~~~~~~\n.. autoclass:: Dataset\n  :members:\n  :special-members: __getitem__\n\n`IterableDataset`\n~~~~~~~~~~~~~~~~~\n.. autoclass:: IterableDataset\n  :members:\n  :special-members: __next__\n\n`DatasetFunc`\n~~~~~~~~~~~~~\n.. autoclass:: DatasetFunc\n  :members:\n  :special-members: __next__\n\n`ShuffleBuffer`\n~~~~~~~~~~~~~~~\n.. autoclass:: ShuffleBuffer\n  :members:\n  :special-members: __next__\n\n`CSVIterableDataset`\n~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: CSVIterableDataset\n  :members:\n  :special-members: __next__\n\n`PersistentDataset`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: PersistentDataset\n  :members:\n  :special-members: __getitem__\n\n`GDSDataset`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: GDSDataset\n  :members:\n  :special-members: __getitem__\n\n\n`CacheNTransDataset`\n~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: CacheNTransDataset\n  :members:\n  :special-members: __getitem__\n\n`LMDBDataset`\n~~~~~~~~~~~~~\n.. autoclass:: LMDBDataset\n  :members:\n  :special-members: __getitem__\n\n`CacheDataset`\n~~~~~~~~~~~~~~\n.. autoclass:: CacheDataset\n  :members:\n  :special-members: __getitem__\n\n`SmartCacheDataset`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: SmartCacheDataset\n  :members:\n  :special-members: __getitem__\n\n`ZipDataset`\n~~~~~~~~~~~~\n.. autoclass:: ZipDataset\n  :members:\n  :special-members: __getitem__\n\n`ArrayDataset`\n~~~~~~~~~~~~~~\n.. autoclass:: ArrayDataset\n  :members:\n  :special-members: __getitem__\n\n`ImageDataset`\n~~~~~~~~~~~~~~\n.. autoclass:: ImageDataset\n  :members:\n  :special-members: __getitem__\n\n`NPZDictItemDataset`\n~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: NPZDictItemDataset\n  :members:\n  :special-members: __getitem__\n\n`CSVDataset`\n~~~~~~~~~~~~\n.. 
autoclass:: CSVDataset\n  :members:\n  :special-members: __getitem__\n\nPatch-based dataset\n-------------------\n\n`GridPatchDataset`\n~~~~~~~~~~~~~~~~~~\n.. autoclass:: GridPatchDataset\n  :members:\n\n`PatchDataset`\n~~~~~~~~~~~~~~\n.. autoclass:: PatchDataset\n  :members:\n\n`PatchIter`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: PatchIter\n    :members:\n    :special-members: __call__\n\n`PatchIterd`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: PatchIterd\n    :members:\n    :special-members: __call__\n\nImage reader\n------------\n\nImageReader\n~~~~~~~~~~~\n.. autoclass:: ImageReader\n  :members:\n\nITKReader\n~~~~~~~~~\n.. autoclass:: ITKReader\n  :members:\n\nNibabelReader\n~~~~~~~~~~~~~\n.. autoclass:: NibabelReader\n  :members:\n\nNumpyReader\n~~~~~~~~~~~\n.. autoclass:: NumpyReader\n  :members:\n\nPILReader\n~~~~~~~~~\n.. autoclass:: PILReader\n  :members:\n\nNrrdReader\n~~~~~~~~~~\n.. autoclass:: NrrdReader\n  :members:\n\nImage writer\n------------\n\nresolve_writer\n~~~~~~~~~~~~~~\n.. autofunction:: resolve_writer\n\nregister_writer\n~~~~~~~~~~~~~~~\n.. autofunction:: register_writer\n\nImageWriter\n~~~~~~~~~~~\n.. autoclass:: ImageWriter\n  :members:\n\nITKWriter\n~~~~~~~~~\n.. autoclass:: ITKWriter\n  :members:\n\nNibabelWriter\n~~~~~~~~~~~~~\n.. autoclass:: NibabelWriter\n  :members:\n\nPILWriter\n~~~~~~~~~\n.. autoclass:: PILWriter\n  :members:\n\nSynthetic\n---------\n.. automodule:: monai.data.synthetic\n  :members:\n\n\nOutput folder layout\n--------------------\n.. automodule:: monai.data.folder_layout\n  :members:\n\n\nUtilities\n---------\n.. automodule:: monai.data.utils\n  :members:\n\nPartition Dataset\n~~~~~~~~~~~~~~~~~\n.. autofunction:: monai.data.partition_dataset\n\nPartition Dataset based on classes\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autofunction:: monai.data.partition_dataset_classes\n\nDistributedSampler\n~~~~~~~~~~~~~~~~~~\n.. 
autoclass:: monai.data.DistributedSampler\n\nDistributedWeightedRandomSampler\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: monai.data.DistributedWeightedRandomSampler\n\nDatasetSummary\n~~~~~~~~~~~~~~\n.. autoclass:: monai.data.DatasetSummary\n\nDecathlon Datalist\n~~~~~~~~~~~~~~~~~~\n.. autofunction:: monai.data.load_decathlon_datalist\n.. autofunction:: monai.data.load_decathlon_properties\n.. autofunction:: monai.data.check_missing_files\n.. autofunction:: monai.data.create_cross_validation_datalist\n\n\nDataLoader\n~~~~~~~~~~\n.. autoclass:: monai.data.DataLoader\n\n\nThreadBuffer\n~~~~~~~~~~~~\n.. autoclass:: monai.data.ThreadBuffer\n\nThreadDataLoader\n~~~~~~~~~~~~~~~~\n.. autoclass:: monai.data.ThreadDataLoader\n\nTestTimeAugmentation\n~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: monai.data.TestTimeAugmentation\n\nN-Dim Fourier Transform\n~~~~~~~~~~~~~~~~~~~~~~~~~\n.. automodule:: monai.data.fft_utils\n.. autofunction:: monai.data.fft_utils.fftn_centered\n.. autofunction:: monai.data.fft_utils.ifftn_centered\n\nITK Torch Bridge\n~~~~~~~~~~~~~~~~\n.. automodule:: monai.data.itk_torch_bridge\n  :members:\n\n\nMeta Object\n-----------\n.. automodule:: monai.data.meta_obj\n  :members:\n\nMetaTensor\n----------\n.. autoclass:: monai.data.MetaTensor\n  :members:\n  :show-inheritance:\n  :inherited-members: MetaObj\n\n\n\nWhole slide image reader\n------------------------\n\nBaseWSIReader\n~~~~~~~~~~~~~\n.. autoclass:: monai.data.BaseWSIReader\n  :members:\n\nWSIReader\n~~~~~~~~~\n.. autoclass:: monai.data.WSIReader\n  :members:\n\nCuCIMWSIReader\n~~~~~~~~~~~~~~\n.. autoclass:: monai.data.CuCIMWSIReader\n  :members:\n\nOpenSlideWSIReader\n~~~~~~~~~~~~~~~~~~\n.. autoclass:: monai.data.OpenSlideWSIReader\n  :members:\n\nTiffFileWSIReader\n~~~~~~~~~~~~~~~~~\n.. autoclass:: monai.data.TiffFileWSIReader\n  :members:\n\n\nWhole slide image datasets\n--------------------------\n\nPatchWSIDataset\n~~~~~~~~~~~~~~~\n.. 
autoclass:: monai.data.PatchWSIDataset\n    :members:\n\nMaskedPatchWSIDataset\n~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: monai.data.MaskedPatchWSIDataset\n    :members:\n\nSlidingPatchWSIDataset\n~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: monai.data.SlidingPatchWSIDataset\n    :members:\n\nBounding box\n------------\n.. automodule:: monai.data.box_utils\n    :members:\n\nVideo datasets\n--------------\n\nVideoDataset\n~~~~~~~~~~~~\n.. autoclass:: monai.data.video_dataset.VideoDataset\n\nVideoFileDataset\n~~~~~~~~~~~~~~~~\n.. autoclass:: monai.data.video_dataset.VideoFileDataset\n\nCameraDataset\n~~~~~~~~~~~~~\n.. autoclass:: monai.data.video_dataset.CameraDataset\n"
  },
  {
    "path": "docs/source/engines.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _engines:\n\nEngines\n=======\n\nWorkflows\n---------\n\n.. currentmodule:: monai.engines\n\n`Workflow`\n~~~~~~~~~~\n.. autoclass:: Workflow\n    :members:\n\n`Trainer`\n~~~~~~~~~\n.. autoclass:: Trainer\n    :members:\n\n`SupervisedTrainer`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: SupervisedTrainer\n    :members:\n\n`GanTrainer`\n~~~~~~~~~~~~\n.. autoclass:: GanTrainer\n    :members:\n\n`AdversarialTrainer`\n~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: AdversarialTrainer\n    :members:\n\n`Evaluator`\n~~~~~~~~~~~\n.. autoclass:: Evaluator\n    :members:\n\n`SupervisedEvaluator`\n~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: SupervisedEvaluator\n    :members:\n\n`EnsembleEvaluator`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: EnsembleEvaluator\n    :members:\n\nUtilities\n---------\n.. automodule:: monai.engines.utils\n  :members:\n"
  },
  {
    "path": "docs/source/fl.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _fl:\n\nFederated Learning\n==================\n.. currentmodule:: monai.fl.client\n\n`Client Base Classes`\n---------------------\n\n.. autoclass:: BaseClient\n    :members:\n\n.. autoclass:: ClientAlgo\n    :members:\n\n.. autoclass:: ClientAlgoStats\n    :members:\n\n`MONAI Bundle Reference Implementations`\n----------------------------------------\n\n.. autoclass:: MonaiAlgo\n    :members:\n\n.. autoclass:: MonaiAlgoStats\n    :members:\n"
  },
  {
    "path": "docs/source/handlers.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _handlers:\n\nEvent handlers\n==============\n.. currentmodule:: monai.handlers\n\nModel checkpoint loader\n-----------------------\n.. autoclass:: CheckpointLoader\n  :members:\n\nModel checkpoint saver\n----------------------\n.. autoclass:: CheckpointSaver\n  :members:\n\n\nMetrics saver\n-------------\n.. autoclass:: MetricsSaver\n    :members:\n\n\nCSV saver\n---------\n.. autoclass:: ClassificationSaver\n    :members:\n\n\nIgnite Metric Handler\n---------------------\n.. autoclass:: IgniteMetricHandler\n    :members:\n\n\nMean Dice metrics handler\n-------------------------\n.. autoclass:: MeanDice\n    :members:\n\n\nMean IoU metric handler\n-----------------------\n.. autoclass:: MeanIoUHandler\n    :members:\n\n\nROC AUC metrics handler\n-----------------------\n.. autoclass:: ROCAUC\n    :members:\n\n\nAverage Precision metric handler\n--------------------------------\n.. autoclass:: AveragePrecision\n    :members:\n\n\nConfusion matrix metrics handler\n--------------------------------\n.. autoclass:: ConfusionMatrix\n    :members:\n\n\nHausdorff distance metrics handler\n----------------------------------\n.. autoclass:: HausdorffDistance\n    :members:\n\n\nSurface distance metrics handler\n--------------------------------\n.. autoclass:: SurfaceDistance\n    :members:\n\n\nPanoptic Quality metrics handler\n--------------------------------\n.. autoclass:: PanopticQuality\n    :members:\n\n\nCalibration Error metrics handler\n---------------------------------\n.. autoclass:: CalibrationError\n    :members:\n\n\nMean squared error metrics handler\n----------------------------------\n.. autoclass:: MeanSquaredError\n    :members:\n\n\nMean absolute error metrics handler\n-----------------------------------\n.. autoclass:: MeanAbsoluteError\n    :members:\n\n\nRoot mean squared error metrics handler\n---------------------------------------\n.. 
autoclass:: RootMeanSquaredError\n    :members:\n\n\nPeak signal to noise ratio metrics handler\n------------------------------------------\n.. autoclass:: PeakSignalToNoiseRatio\n    :members:\n\n\nMetrics reloaded binary handler\n-------------------------------\n.. autoclass:: MetricsReloadedBinaryHandler\n    :members:\n\n\nMetrics reloaded categorical handler\n------------------------------------\n.. autoclass:: MetricsReloadedCategoricalHandler\n    :members:\n\n\nMetric logger\n-------------\n.. autoclass:: MetricLogger\n    :members:\n\n\nLogfile handler\n---------------\n.. autoclass:: LogfileHandler\n    :members:\n\n\nTraining stats handler\n----------------------\n.. autoclass:: StatsHandler\n    :members:\n\n\nTensorboard handlers\n--------------------\n.. autoclass:: TensorBoardHandler\n    :members:\n\n.. autoclass:: TensorBoardStatsHandler\n    :members:\n\n.. autoclass:: TensorBoardImageHandler\n    :members:\n\n\nLR Schedule handler\n-------------------\n.. autoclass:: LrScheduleHandler\n    :members:\n\n\nValidation handler\n------------------\n.. autoclass:: ValidationHandler\n    :members:\n\nSmartCache handler\n------------------\n.. autoclass:: SmartCacheHandler\n    :members:\n\nParameter Scheduler handler\n---------------------------\n.. autoclass:: ParamSchedulerHandler\n    :members:\n\nEarlyStop handler\n-----------------\n.. autoclass:: EarlyStopHandler\n    :members:\n\nGarbageCollector handler\n------------------------\n.. autoclass:: GarbageCollector\n    :members:\n\nPost processing\n---------------\n.. autoclass:: PostProcessing\n    :members:\n\nDecollate batch\n---------------\n.. autoclass:: DecollateBatch\n    :members:\n\nMLFlow handler\n--------------\n.. autoclass:: MLFlowHandler\n    :members:\n\nClearML handlers\n----------------\n.. autoclass:: ClearMLHandler\n    :members:\n\n.. autoclass:: ClearMLStatsHandler\n    :members:\n\n.. autoclass:: ClearMLImageHandler\n    :members:\n\nNVTX Handlers\n-------------\n.. 
automodule:: monai.handlers.nvtx_handlers\n  :members:\n\nUtilities\n---------\n.. automodule:: monai.handlers.utils\n  :members:\n\nProbability Map Handlers\n------------------------\n.. automodule:: monai.handlers.probability_maps\n  :members:\n"
  },
  {
    "path": "docs/source/highlights.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\nHighlights\n==========\n\n.. toctree::\n   :maxdepth: 1\n\n   modules.md\n   applications.md\n"
  },
  {
    "path": "docs/source/index.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. MONAI documentation main file, created by\n   sphinx-quickstart on Wed Feb  5 09:40:29 2020.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\nProject MONAI\n=============\n\n\n*Medical Open Network for AI*\n\nMONAI is a `PyTorch <https://pytorch.org/>`_-based, `open-source <https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE>`_ framework\nfor deep learning in healthcare imaging, part of the `PyTorch Ecosystem <https://pytorch.org/ecosystem/>`_.\n\nIts ambitions are:\n\n- developing a community of academic, industrial and clinical researchers collaborating on a common foundation;\n- creating state-of-the-art, end-to-end training workflows for healthcare imaging;\n- providing researchers with an optimized and standardized way to create and evaluate deep learning models.\n\n.. image:: ../images/MONAI_arch.png\n    :alt: MONAI Architecture\n    :align: center\n\nFeatures\n--------\n\n- flexible pre-processing for multi-dimensional medical imaging data;\n- compositional & portable APIs for ease of integration in existing workflows;\n- domain-specific implementations for networks, losses, evaluation metrics and more;\n- customizable design for varying user expertise;\n- multi-GPU multi-node data parallelism support.\n\n\nGetting started\n---------------\n\n`MedNIST demo <https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/2d_classification/mednist_tutorial.ipynb>`_ and `MONAI for PyTorch Users <https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/developer_guide.ipynb>`_ are available on Colab.\n\nExamples and notebook tutorials are located at `Project-MONAI/tutorials <https://github.com/Project-MONAI/tutorials>`_.\n\nTechnical documentation is available at `docs.monai.io <https://docs.monai.io>`_.\n\n.. 
toctree::\n   :maxdepth: 1\n   :caption: Feature highlights\n\n   whatsnew\n   highlights.md\n\n.. toctree::\n   :maxdepth: 1\n   :caption: API Reference\n\n   api\n\n.. toctree::\n  :maxdepth: 1\n  :caption: Installation\n\n  installation\n\n.. toctree::\n  :maxdepth: 1\n  :caption: Precision and Accelerating\n\n  precision_accelerating\n\n.. toctree::\n  :maxdepth: 1\n  :caption: Contributing\n\n  contrib\n\n.. toctree::\n  :maxdepth: 1\n  :caption: Specifications\n\n  bundle_intro\n  lazy_resampling\n\nModel Zoo\n---------\n\n`The MONAI Model Zoo <https://github.com/Project-MONAI/model-zoo>`_ is a place for researchers and data scientists to share the latest and great models from the community.\nUtilizing `the MONAI Bundle format <https://monai.readthedocs.io/en/latest/bundle_intro.html>`_ makes it easy to `get started <https://github.com/Project-MONAI/tutorials/tree/main/model_zoo>`_ building workflows with MONAI.\n\n\nLinks\n-----\n\n- Website: https://project-monai.github.io/\n- API documentation (milestone): https://monai.readthedocs.io/\n- API documentation (latest dev): https://monai.readthedocs.io/en/latest/\n- Code: https://github.com/Project-MONAI/MONAI\n- Project tracker: https://github.com/Project-MONAI/MONAI/projects\n- Issue tracker: https://github.com/Project-MONAI/MONAI/issues\n- Changelog: https://github.com/Project-MONAI/MONAI/blob/dev/CHANGELOG.md\n- Wiki: https://github.com/Project-MONAI/MONAI/wiki\n- FAQ: https://github.com/Project-MONAI/MONAI/wiki/Frequently-asked-questions-and-answers\n- Test status: https://github.com/Project-MONAI/MONAI/actions\n- PyPI package: https://pypi.org/project/monai/\n- conda-forge: https://anaconda.org/conda-forge/monai\n- Weekly previews: https://pypi.org/project/monai-weekly/\n- Docker Hub: https://hub.docker.com/r/projectmonai/monai\n\n\nIndices and tables\n==================\n\n* :ref:`genindex`\n* :ref:`modindex`\n"
  },
  {
    "path": "docs/source/inferers.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _inferers:\n\nInference methods\n=================\n\nInferers\n--------\n\n.. currentmodule:: monai.inferers\n.. autoclass:: Inferer\n    :members:\n    :special-members: __call__\n\n`PatchInferer`\n~~~~~~~~~~~~~~\n.. autoclass:: PatchInferer\n    :members:\n    :special-members: __call__\n\n`SimpleInferer`\n~~~~~~~~~~~~~~~\n.. autoclass:: SimpleInferer\n    :members:\n    :special-members: __call__\n\n`SlidingWindowInferer`\n~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: SlidingWindowInferer\n    :members:\n    :special-members: __call__\n\n`SlidingWindowInfererAdapt`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: SlidingWindowInfererAdapt\n    :members:\n    :special-members: __call__\n\n`SaliencyInferer`\n~~~~~~~~~~~~~~~~~\n.. autoclass:: SaliencyInferer\n    :members:\n    :special-members: __call__\n\n`SliceInferer`\n~~~~~~~~~~~~~~\n.. autoclass:: SliceInferer\n    :members:\n    :special-members: __call__\n\n`DiffusionInferer`\n~~~~~~~~~~~~~~~~~~\n.. autoclass:: DiffusionInferer\n    :members:\n    :special-members: __call__\n\n`LatentDiffusionInferer`\n~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: LatentDiffusionInferer\n    :members:\n    :special-members: __call__\n\n`ControlNetDiffusionInferer`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: ControlNetDiffusionInferer\n    :members:\n    :special-members: __call__\n\n`ControlNetLatentDiffusionInferer`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: ControlNetLatentDiffusionInferer\n    :members:\n    :special-members: __call__\n\nSplitters\n---------\n.. currentmodule:: monai.inferers\n.. autoclass:: Splitter\n    :members:\n    :special-members: __call__\n\n`SlidingWindowSplitter`\n~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: SlidingWindowSplitter\n    :members:\n    :special-members: __call__\n\n`WSISlidingWindowSplitter`\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. 
autoclass:: WSISlidingWindowSplitter\n    :members:\n    :special-members: __call__\n\n\nMergers\n-------\n.. currentmodule:: monai.inferers\n.. autoclass:: Merger\n    :members:\n    :special-members: __call__\n\n`AvgMerger`\n~~~~~~~~~~~\n.. autoclass:: AvgMerger\n    :members:\n    :special-members: __call__\n\n`ZarrAvgMerger`\n~~~~~~~~~~~~~~~\n.. autoclass:: ZarrAvgMerger\n    :members:\n    :special-members: __call__\n\n\nSliding Window Inference Function\n---------------------------------\n\n.. autofunction:: monai.inferers.sliding_window_inference\n"
  },
  {
    "path": "docs/source/installation.md",
    "content": "# Installation Guide\n\n## Table of Contents\n\n- [Installation Guide](#installation-guide)\n\t- [Table of Contents](#table-of-contents)\n\t- [From PyPI](#from-pypi)\n\t\t- [Milestone release](#milestone-release)\n\t\t- [Weekly preview release](#weekly-preview-release)\n\t\t- [Uninstall the packages](#uninstall-the-packages)\n\t- [From conda-forge](#from-conda-forge)\n\t- [From GitHub](#from-github)\n\t\t- [Option 1 (as a part of your system-wide module):](#option-1-as-a-part-of-your-system-wide-module)\n\t\t- [Option 2 (editable installation):](#option-2-editable-installation)\n\t- [Validating the install](#validating-the-install)\n\t- [MONAI version string](#monai-version-string)\n\t- [From DockerHub](#from-dockerhub)\n\t- [Installing the recommended dependencies](#installing-the-recommended-dependencies)\n\n---\n\nMONAI's core functionality is written in Python 3 (>= 3.9) and only requires [Numpy](https://numpy.org/) and [Pytorch](https://pytorch.org/).\n\nThe package is currently distributed via Github as the primary source code repository,\nand the Python package index (PyPI). The pre-built Docker images are made available on DockerHub.\n\nTo install optional features such as handling the NIfTI files using\n[Nibabel](https://nipy.org/nibabel/), or building workflows using [Pytorch\nIgnite](https://pytorch.org/ignite/), please follow the instructions:\n\n- [Installing the recommended dependencies](#installing-the-recommended-dependencies)\n\nThe installation commands below usually end up installing CPU variant of PyTorch. To install GPU-enabled PyTorch:\n\n1. Install the latest NVIDIA driver.\n1. Check [PyTorch Official Guide](https://pytorch.org/get-started/locally/) for the recommended CUDA versions. For Pip package, the user needs to download the CUDA manually, install it on the system, and ensure CUDA_PATH is set properly.\n1. Continue to follow the guide and install PyTorch.\n1. 
Install MONAI using one of the ways described below.\n\n---\n\n## From PyPI\n\n### Milestone release\n\nTo install the [current milestone release](https://pypi.org/project/monai/):\n\n```bash\npip install monai\n```\n\n### Weekly preview release\n\nTo install the [weekly preview release](https://pypi.org/project/monai-weekly/):\n\n```bash\npip install monai-weekly\n```\n\nThe weekly build is released to PyPI every Sunday with a pre-release build number `dev[%y%U]`.\nTo report any issues on the weekly preview, please include the version information:\n\n```bash\npython -c \"import monai; print(monai.__version__)\"\n```\n\nCoexistence of package `monai` and `monai-weekly` in a system may cause namespace conflicts\nand `ImportError`.\nThis is usually a result of running both `pip install monai` and `pip install monai-weekly`\nwithout uninstalling the existing one first.\nTo address this issue, please uninstall both packages, and retry the installation.\n\n### Uninstall the packages\n\nThe packages installed using `pip install` could be removed by:\n\n```bash\npip uninstall -y monai\npip uninstall -y monai-weekly\n```\n\n## From conda-forge\n\nTo install the [current milestone release](https://anaconda.org/conda-forge/monai):\n\n```bash\nconda install -c conda-forge monai\n```\n\n## From GitHub\n\n(_If you have installed the\nPyPI release version using `pip install monai`, please run `pip uninstall\nmonai` before using the commands from this section. Because `pip` by\ndefault prefers the milestone release_.)\n\nThe milestone versions are currently planned and released every few months. 
As the\ncodebase is under active development, you may want to install MONAI from GitHub\nfor the latest features:\n\n### Option 1 (as a part of your system-wide module):\n\n```bash\npip install git+https://github.com/Project-MONAI/MONAI\n```\n\nor, to build with MONAI C++/CUDA extensions:\n\n```bash\nBUILD_MONAI=1 pip install git+https://github.com/Project-MONAI/MONAI\n```\n\nTo build the extensions, if the system environment already has a version of Pytorch installed,\n`--no-build-isolation` might be preferred:\n\n```bash\nBUILD_MONAI=1 pip install --no-build-isolation git+https://github.com/Project-MONAI/MONAI\n```\n\nthis command will download and install the current `dev` branch of [MONAI from\nGitHub](https://github.com/Project-MONAI/MONAI).\n\nThis documentation website by default shows the information for the latest version.\n\n### Option 2 (editable installation):\n\nTo install an editable version of MONAI, it is recommended to clone the codebase directly:\n\n```bash\ngit clone https://github.com/Project-MONAI/MONAI.git\n```\n\nThis command will create a `MONAI/` folder in your current directory.\nYou can install it by running:\n\n```bash\ncd MONAI/\npip install -e .\n```\n\nor, to build with MONAI C++/CUDA extensions and install:\n\n```bash\ncd MONAI/\nBUILD_MONAI=1 pip install -e .\n# for MacOS\nBUILD_MONAI=1 CC=clang CXX=clang++ pip install -e .\n```\n\nTo uninstall the package please run:\n\n```bash\ncd MONAI/\npip uninstall -y monai\n\n# to further clean up the MONAI/ folder (Bash script)\n./runtests.sh --clean\n```\n\nAlternatively, simply adding the root directory of the cloned source code (e.g., `/workspace/Documents/MONAI`) to your `$PYTHONPATH`\nand the codebase is ready to use (without the additional features of MONAI C++/CUDA extensions).\n\n> The C++/CUDA extension features are currently experimental, a pre-compiled version is made available via\n> [the recent docker image releases](https://hub.docker.com/r/projectmonai/monai).\n> Building the 
extensions from source may require [Ninja](https://ninja-build.org/) and [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit).\n> By default, CUDA extension is built if `torch.cuda.is_available()`. It's possible to force building by\n> setting `FORCE_CUDA=1` environment variable.\n\n## Validating the install\n\nYou can verify the installation by:\n\n```bash\npython -c \"import monai; monai.config.print_config()\"\n```\n\nIf the installation is successful, this command will print out the MONAI version information, and this confirms the core\nmodules of MONAI are ready-to-use.\n\n## MONAI version string\n\nThe MONAI version string shows the current status of your local installation. For example:\n\n```\nMONAI version: 0.1.0+144.g52c763d.dirty\n```\n\n- `0.1.0` indicates that your installation is based on the `0.1.0` milestone release.\n- `+144` indicates that your installation is 144 git commits ahead of the milestone release.\n- `g52c763d` indicates that your installation corresponds to the git commit hash `52c763d`.\n- `dirty` indicates that you have modified the codebase locally, and the codebase is inconsistent with `52c763d`.\n\n## From DockerHub\n\nMake sure you have installed the NVIDIA driver and Docker 19.03+ for your Linux distribution.\nNote that you do not need to install the CUDA toolkit on the host, but the driver needs to be installed.\nPlease find out more information on [nvidia-docker](https://github.com/NVIDIA/nvidia-docker).\n\nAssuming that you have the Nvidia driver and Docker 19.03+ installed, running the following command will\ndownload and start a container with the latest version of MONAI. 
The latest `dev` branch of MONAI from GitHub\nis included in the image.\n\n```bash\ndocker run --gpus all --rm -ti --ipc=host projectmonai/monai:latest\n```\n\nYou can also run a milestone release docker image by specifying the image tag, for example:\n\n```\ndocker run --gpus all --rm -ti --ipc=host projectmonai/monai:0.1.0\n```\n\n## Installing the recommended dependencies\n\nBy default, the installation steps will only download and install the minimal requirements of MONAI.\nOptional dependencies can be installed using [the extras syntax](https://packaging.python.org/tutorials/installing-packages/#installing-setuptools-extras) to support additional features.\n\nFor example, to install MONAI with Nibabel and Scikit-image support:\n\n```bash\ngit clone https://github.com/Project-MONAI/MONAI.git\ncd MONAI/\npip install -e '.[nibabel,skimage]'\n```\n\nAlternatively, to install all optional dependencies:\n\n```bash\ngit clone https://github.com/Project-MONAI/MONAI.git\ncd MONAI/\npip install -e \".[all]\"\n```\n\nTo install all optional dependencies with `pip` based on MONAI development environment settings:\n\n```bash\ngit clone https://github.com/Project-MONAI/MONAI.git\ncd MONAI/\npip install -r requirements-dev.txt\n```\n\nTo install all optional dependencies with `conda` based on MONAI development environment settings (`environment-dev.yml`;\nthis will install PyTorch as well as `pytorch-cuda`, please follow https://pytorch.org/get-started/locally/#start-locally for more details about installing PyTorch):\n\n```bash\ngit clone https://github.com/Project-MONAI/MONAI.git\ncd MONAI/\nconda create -n <name> python=<ver>  # eg 3.9\nconda env update -n <name> -f environment-dev.yml\n```\n\nSince MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is available via PyPI.\n\n- The options are\n\n```\n[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, 
mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub]\n```\n\nwhich correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`,\n`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively.\n\n- `pip install 'monai[all]'` installs all the optional dependencies.\n"
  },
  {
    "path": "docs/source/lazy_resampling.rst",
    "content": ".. _lazy_resampling:\n\n:github_url: https://github.com/Project-MONAI/MONAI\n\nLazy Resampling\n===============\n\n.. toctree::\n   :maxdepth: 2\n\nIntroduction\n^^^^^^^^^^^^\n\nLazy Resampling is a new feature introduced in MONAI 1.2. This feature is still experimental at this time and it is\npossible that behaviour and APIs will change in upcoming releases.\n\nLazy resampling reworks the way that preprocessing is performed. It improves upon standard preprocessing pipelines and\ncan provide significant benefits over traditional preprocessing. It can improve:\n* pipeline execution time\n* pipeline memory usage in CPU or GPU\n* image and segmentation quality by reducing incidental noise and artifacts caused by resampling\n\nThe way it does this is by adopting the methods used in computer graphics pipelines, in which transformations to objects\nin a scene are modified by composing together a sequence of \"homogeneous matrices\".\n\nRather than each transform being executed in isolation, potentially requiring the data to be resampled to make a new\ntensor, transforms whose operations can be described in terms of homogeneous transforms do not execute their transforms\nimmediately. 
Instead, they create a \"pending operation\", which is added to a list of operations that will be fused\ntogether and carried out at the point that they are required.\n\n\nHow Lazy Resampling changes preprocessing\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nIn order to understand the difference between traditional pipelines and lazy pipelines, it is best to look at an example\npipeline and the differences between their execution strategies:\n\n\nTraditional execution\n+++++++++++++++++++++\n\nWith traditional resampling, found both in MONAI and many other preprocessing libraries, you typically define a sequence\nof transforms and pass them to a ``Compose`` object, such as :class:`monai.transforms.compose.Compose`.\n\nExample::\n\n    transforms = [\n        Spacingd(keys=[\"img\", \"seg\"], ...),\n        Orientationd(keys=[\"img\", \"seg\"], ...),\n        RandSpatialCropd(keys=[\"img\", \"seg\"], ...),\n        RandRotate90d(keys=[\"img\", \"seg\"], ...),\n        RandRotated(keys=[\"img\", \"seg\"], ...),\n        RandZoomd(keys=[\"img\", \"seg\"], ...),\n        RandGaussianNoised(keys=\"img\", ...),\n    ]\n    pipeline = Compose(transforms)\n\n    # elsewhere this will be called many times (such as in a Dataset instance)\n    outputs = pipeline(inputs)\n\n\nThe following will then happen when we call ``pipeline(inputs)``:\n\n1. ``Spacingd`` is called and interpolates the data samples\n2. ``Orientationd`` permutes the data samples so that their spatial dimensions are reorganised\n3. ``RandSpatialCropd`` crops a random patch of the data samples, throwing away the rest of the data in the process\n4. ``RandRotate90d`` has a chance of performing a tensor-based rotation of the data samples\n5. ``RandRotated`` has a chance of performing a full resample of the data samples\n6. ``RandZoomd`` has a chance of performing an interpolation of the data samples\n7. ``RandGaussianNoised`` has a chance of adding noise to ``img``\n\n.. 
figure:: ../images/lazy_resampling_trad_example_1.svg\n\n   Figure showing traditional pipeline execution. Tensors (the boxes in the main body of the image) are passed through\n   the pipeline, and the state of their `applied_operations` property is shown at each step. Tensors with a thick red\n   border have undergone some kind of resample operation at that stage.\n\nOverall, there are up to three occasions where the data is either interpolated or resampled through spatial transforms\n(``Spacingd``, ``RandRotated`` and ``RandZoomd``). Furthermore, the crop that occurs means that the output data\nsamples might contain pixels for which there is data but that show padding values, because the data was thrown away by\n``RandSpatialCrop``.\n\nEach of these operations takes time and memory, but, as we can see in the example above, also creates resampling\nartifacts and can even destroy data in the resulting data samples.\n\nLazy execution\n++++++++++++++\n\nLazy resampling works very differently. When you execute the same pipeline with `lazy=True`, the following happens:\n\n#. ``Spacingd`` is executed lazily. It puts a description of the operation that it wants to perform onto a list of\n   pending operations\n#. ``Orientationd`` is executed lazily. It adds a description of its own operation to the pending operation list so\n   now there are 2 pending operations\n#. ``RandSpatialCropd`` is executed lazily. It adds a description of its own operation to the pending\n   operation list so now there are 3 pending operations\n#. ``RandRotate90d`` is executed lazily. It adds a description of its own operation to the pending operation\n   list so now there are 4 pending operations\n#. ``RandRotated`` is executed lazily. It adds a description of its own operation to the pending operation\n   list so now there are 5 pending operations\n#. ``RandZoomd`` is executed lazily. 
It adds a description of its own operation to the pending operation\n   list so now there are 6 pending operations\n\n   #. [Spacingd, Orientationd, RandSpatialCropd, RandRotate90d, RandRotated, RandZoomd] are all on the pending\n      operations list but have yet to be carried out on the data\n#. ``RandGaussianNoised`` is not a lazy transform. It is now time for the pending operations to be evaluated. Their\n   descriptions are mathematically composited together, to determine the operation that results from all of them being\n   carried out. This is then applied in a single resample operation. Once that is done, RandGaussianNoised operates on\n   the resulting data\n\n.. figure:: ../images/lazy_resampling_lazy_example_1.svg\n\n   Figure showing lazy pipeline execution. We show the state of the `pending_operations` and `applied_operations`\n   properties of the tensor as it is processed by the pipeline. Thick red borders indicate some kind of resampling\n   operation has taken place at that step. Lazy resampling performs far fewer of these operations.\n\nThe single resampling operation has less noise induced by resampling, as it only occurs once in this pipeline rather\nthan three times in the traditional pipeline. More importantly, although the crop describes an operation to keep only a\nsubset of the data sample, the crop is not performed until after the spatial transforms are completed, which means that\nall of the data sample that is within bounds is preserved and is part of the resulting output.\n\n\nComposing homogeneous matrices\n++++++++++++++++++++++++++++++\n\n.. image:: ../images/lazy_resampling_homogeneous_matrices.svg\n\n\nAlthough a full treatment of homogeneous matrices is outside the scope of this document, a brief overview of them is\nuseful to understand the mechanics of lazy resampling. Homogeneous matrices are used in computer graphics to describe\noperations in cartesian space in a unified (homogeneous) fashion. 
Rotation, scaling, translation, and skewing are\namongst the operations that can be performed. Homogeneous matrices have the interesting property that they can be\ncomposited together, thus describing the result of a sequence of operations. Note that ordering is important;\n`scale -> rotate -> translation` gives a very different result to `translation -> rotate -> scale`.\n\nThe ability to composite homogeneous matrices together allows a sequence of operations to be carried out as a single\noperation, which is the key mechanism by which lazy resampling functions.\n\n\nAPI changes\n^^^^^^^^^^^\n\nA number of new arguments have been added to existing properties, which we'll go over in detail here. In particular,\nwe'll focus on :class:`Compose<monai.transforms.compose.Compose>` and\n:class:`LazyTrait<monai.transforms.traits.LazyTrait>`/ :class:`LazyTransform<monai.transforms.transform.LazyTransform>`\nand the way that they interact with each other.\n\n\nCompose\n+++++++\n\n:class:`Compose<monai.transforms.compose.Compose>` gains a number of new arguments that can be used to control\nresampling behaviour. Each of them is covered in its own section:\n\n\nlazy\n\"\"\"\"\n\n``lazy`` controls whether execution is carried out in a lazy manner or not. It has three values that it can take:\n\n* `lazy=False` forces the pipeline to be executed in the standard way with every transform applied immediately\n* `lazy=True` forces the pipeline to be executed lazily. Every transform that implements\n  :class:`LazyTrait<monai.transforms.traits.LazyTrait>` (or inherits\n  :class:`LazyTransform<monai.transforms.transform.LazyTransform>`) will be executed lazily\n* `lazy=None` means that the pipeline can execute lazily, but only on transforms that have their own `lazy` property\n  set to True.\n\n\noverrides\n\"\"\"\"\"\"\"\"\"\n\n``overrides`` allows the user to specify certain parameters that transforms can be overridden with when they are\nexecuted lazily. 
This parameter is primarily provided to allow you to run a pipeline without having to modify fields\nlike ``mode`` and ``padding_mode``.\nWhen executing dictionary-based transforms, you provide a dictionary containing overrides for each key, as follows. You\ncan omit keys that don't require overrides:\n\n.. code-block::\n\n    {\n        \"image\": {\"mode\": \"bilinear\"},\n        \"label\": {\"padding_mode\": \"zeros\"}\n    }\n\n\nlog_stats\n\"\"\"\"\"\"\"\"\"\n\nLogging of transform execution is provided if you wish to understand exactly how your pipelines execute. It can take a\n``bool`` or ``str`` value, and is False by default, which disables logging. Otherwise, you can enable it by passing it\nthe name of a logger that you wish to use (note, you don't have to construct the logger beforehand).\n\n\nLazyTrait / LazyTransform\n+++++++++++++++++++++++++\n\nMany transforms now implement either :class:`LazyTrait<monai.transforms.traits.LazyTrait>` or\n:class:`LazyTransform<monai.transforms.transform.LazyTransform>`. Doing so marks the transform for lazy execution. Lazy\ntransforms have the following in common:\n\n\n``__init__`` has a ``lazy`` argument\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n``lazy`` is a ``bool`` value that can be passed to the initialiser when a lazy transform is instantiated. This\nindicates to the transform that it should execute lazily or not lazily. Note that this value can be overridden by\npassing ``lazy`` to ``__call__``. ``lazy`` is ``False`` by default.\n\n\n``__call__`` has a ``lazy`` argument\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n``lazy`` is an optional ``bool`` value that can be passed at call time to override the behaviour defined during\ninitialisation. It has a default value of ``None``. If it is not ``None``, then this value is used instead of\n``self.lazy``. 
This allows the calling :class:`Compose<monai.transforms.compose.Compose>` instance to override\ndefault values rather than having to set it on every lazy transform (unless the user sets\n:class:`Compose.lazy<monai.transforms.compose.Compose>` to ``None``).\n\n\nlazy property\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nThe lazy property allows you to get or set the lazy status of a lazy transform after constructing it.\n\n\nrequires_current_data property (get only)\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\nThe ``requires_current_data`` property indicates that a transform makes use of the data in one or more of the tensors\nthat it is passed during its execution. Such transforms require that the tensors must therefore be up to date, even if\nthe transform itself is executing lazily. This is required for transforms such as ``CropForeground[d]``,\n``RandCropByPosNegLabel[d]``, and ``RandCropByLabelClasses[d]``. This property is implemented to return ``False`` on\n``LazyTransform`` and must be overridden to return ``True`` by transforms that check data values when executing.\n\n\nControlling laziness\n^^^^^^^^^^^^^^^^^^^^\n\nThere are two ways that a user can provide more fine-grained control over laziness. One is to make use of lazy=None\nwhen initialising or calling ``Compose`` instances. The other is to use the ``ApplyPending[d]`` transforms. These\ntechniques can be freely mixed and matched.\n\n\nUsing ``lazy=None``\n+++++++++++++++++++\n\n``lazy=None`` tells ``Compose`` to honor the lazy flags set on each lazy transform. These are set to False by default\nso the user must set lazy=True on the transforms that they still wish to execute lazily.\n\n\n``lazy=None`` example:\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n.. figure:: ../images/lazy_resampling_none_example.svg\n\n    Figure showing the effect of using ``lazy=False`` when ``Compose`` is being executed with ``lazy=None``. 
Note\n    the additional resamples that occur due to ``RandRotate90d`` being executed in a non-lazy fashion.\n\n\nUsing ``ApplyPending[d]``\n+++++++++++++++++++++++++\n\n``ApplyPending[d]`` causes all pending transforms to be executed before the following transform, regardless of whether\nthe following transform is a lazy transform, or is configured to execute lazily.\n\n\n``ApplyPending`` Example:\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n.. figure:: ../images/lazy_resampling_apply_pending_example.svg\n\n    Figure showing the use of :class:`ApplyPendingd<monai.transforms.lazy.dictionary.ApplyPendingd>` to cause\n    resampling to occur in the middle of a chain of lazy transforms.\n"
  },
  {
    "path": "docs/source/losses.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _losses:\n\nLoss functions\n==============\n\nSegmentation Losses\n-------------------\n\n.. automodule:: monai.losses\n.. currentmodule:: monai.losses\n\n`DiceLoss`\n~~~~~~~~~~\n.. autoclass:: DiceLoss\n    :members:\n\n.. autoclass:: Dice\n    :members:\n\n.. autoclass:: dice\n    :members:\n\n`MaskedDiceLoss`\n~~~~~~~~~~~~~~~~\n.. autoclass:: MaskedDiceLoss\n    :members:\n\n`GeneralizedDiceLoss`\n~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: GeneralizedDiceLoss\n    :members:\n\n.. autoclass:: generalized_dice\n    :members:\n\n`GeneralizedWassersteinDiceLoss`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: GeneralizedWassersteinDiceLoss\n    :members:\n\n.. autoclass:: generalized_wasserstein_dice\n    :members:\n\n`DiceCELoss`\n~~~~~~~~~~~~\n.. autoclass:: DiceCELoss\n    :members:\n\n`DiceFocalLoss`\n~~~~~~~~~~~~~~~\n.. autoclass:: DiceFocalLoss\n    :members:\n\n`GeneralizedDiceFocalLoss`\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: GeneralizedDiceFocalLoss\n    :members:\n\n`FocalLoss`\n~~~~~~~~~~~\n.. autoclass:: FocalLoss\n    :members:\n\n`TverskyLoss`\n~~~~~~~~~~~~~\n.. autoclass:: TverskyLoss\n    :members:\n\n`ContrastiveLoss`\n~~~~~~~~~~~~~~~~~\n.. autoclass:: ContrastiveLoss\n    :members:\n\n`BarlowTwinsLoss`\n~~~~~~~~~~~~~~~~~\n.. autoclass:: BarlowTwinsLoss\n    :members:\n\n`HausdorffDTLoss`\n~~~~~~~~~~~~~~~~~\n.. autoclass:: HausdorffDTLoss\n    :members:\n\n`SoftclDiceLoss`\n~~~~~~~~~~~~~~~~\n.. autoclass:: SoftclDiceLoss\n    :members:\n\n`SoftDiceclDiceLoss`\n~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: SoftDiceclDiceLoss\n    :members:\n\n`NACLLoss`\n~~~~~~~~~~\n.. autoclass:: NACLLoss\n    :members:\n\nRegistration Losses\n-------------------\n\n`BendingEnergyLoss`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: BendingEnergyLoss\n    :members:\n\n`DiffusionLoss`\n~~~~~~~~~~~~~~~\n.. 
autoclass:: DiffusionLoss\n    :members:\n\n`LocalNormalizedCrossCorrelationLoss`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: LocalNormalizedCrossCorrelationLoss\n    :members:\n\n`GlobalMutualInformationLoss`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: GlobalMutualInformationLoss\n    :members:\n\nReconstruction Losses\n---------------------\n\n`SSIMLoss`\n~~~~~~~~~~\n.. autoclass:: monai.losses.ssim_loss.SSIMLoss\n    :members:\n\n`PatchAdversarialLoss`\n~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: PatchAdversarialLoss\n    :members:\n\n`PerceptualLoss`\n~~~~~~~~~~~~~~~~~\n.. autoclass:: PerceptualLoss\n    :members:\n\n`JukeboxLoss`\n~~~~~~~~~~~~~~\n.. autoclass:: JukeboxLoss\n    :members:\n\n`SURELoss`\n~~~~~~~~~~\n.. autoclass:: SURELoss\n    :members:\n\n\nLoss Wrappers\n-------------\n\n`MultiScaleLoss`\n~~~~~~~~~~~~~~~~\n.. autoclass:: MultiScaleLoss\n    :members:\n\n`MaskedLoss`\n~~~~~~~~~~~~\n.. autoclass:: MaskedLoss\n    :members:\n\n`DeepSupervisionLoss`\n~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: DeepSupervisionLoss\n    :members:\n"
  },
  {
    "path": "docs/source/mb_properties.rst",
    "content": "MONAI Bundle Properties\n=======================\n\n\nTrain properties\n----------------\n\n.. csv-table::\n   :header-rows: 1\n   :file: train_properties.csv\n   :class: longtable\n   :widths: 10, 55, 5, 30\n\nInfer properties\n----------------\n\n.. csv-table::\n   :header-rows: 1\n   :file: infer_properties.csv\n   :class: longtable\n   :widths: 10, 55, 5, 30\n\nMeta properties\n---------------\n\n.. csv-table::\n   :header-rows: 1\n   :file: meta_properties.csv\n   :class: longtable\n   :widths: 10, 55, 5, 30\n"
  },
  {
    "path": "docs/source/mb_specification.rst",
    "content": "\n==========================\nMONAI Bundle Specification\n==========================\n\nOverview\n========\n\nThis is the specification for the MONAI Bundle (MB) format of portable described deep learning models. The objective of a MB is to define a packaged network or model which includes the critical information necessary to allow users and programs to understand how the model is used and for what purpose. A bundle includes the stored weights of a single network as a pickled state dictionary plus optionally a Torchscript object and/or an ONNX object. Additional JSON files are included to store metadata about the model, information for constructing training, inference, and post-processing transform sequences, plain-text description, legal information, and other data the model creator wishes to include.\n\nThis specification defines the directory structure a bundle must have and the necessary files it must contain. Additional files may be included and the directory packaged into a zip file or included as extra files directly in a Torchscript file.\n\nDirectory Structure\n===================\n\nA MONAI Bundle is defined primarily as a directory with a set of specifically named subdirectories containing the model, metadata files and license. 
The root directory should be named for the model, given as \"ModelName\" in this example, and should contain the following structure:\n\n::\n\n  ModelName\n  ┣━ LICENSE\n  ┣━ configs\n  ┃  ┗━ metadata.json\n  ┣━ models\n  ┃  ┣━ model.pt\n  ┃  ┣━ *model.ts\n  ┃  ┗━ *model.onnx\n  ┗━ docs\n     ┣━ *README.md\n     ┗━ *license.txt\n\n\nThe following files are **required** to be present with the given filenames for the directory to define a valid bundle:\n\n* **LICENSE**: a license for the software itself comprising the configuration files and model weights.\n* **metadata.json**: metadata information in JSON format relating to the type of model, definition of input and output tensors, versions of the model and used software, and other information described below.\n* **model.pt**: the state dictionary of a saved model, the information to instantiate the model must be found in the metadata file.\n\nThe following files are optional but must have these names in the directory given above:\n\n* **model.ts**: the Torchscript saved model if the model is compatible with being saved correctly in this format.\n* **model.onnx**: the ONNX model if the model is compatible with being saved correctly in this format.\n* **README.md**: plain-language information on the model, how to use it, author information, etc. in Markdown format.\n* **license.txt**: software license attached to the data, can be left blank if no license needed.\n\nOther files can be included in any of the above directories. For example, `configs` can contain further configuration JSON or YAML files to define scripts for training or inference, overriding configuration values, environment definitions such as network instantiations, and so forth. 
One common file to include is `inference.json` which is used to define a basic inference script which uses input files with the stored network to produce prediction output files.\n\nArchive Format\n==============\n\nThe bundle directory and its contents can be compressed into a zip file to constitute a single file package. When unzipped into a directory this file will reproduce the above directory structure, and should itself also be named after the model it contains. For example, `ModelName.zip` would contain at least `ModelName/configs/metadata.json` and `ModelName/models/model.pt`, thus when unzipped would place files into the directory `ModelName` rather than into the current working directory.\n\nThe Torchscript file format is also just a zip file with a specific structure. When creating such an archive with `save_net_with_metadata` a MB-compliant Torchscript file can be created by including the contents of `metadata.json` as the `meta_values` argument of the function, and other files included as `more_extra_files` entries. These will be stored in a `extras` directory in the zip file and can be retrieved with `load_net_with_metadata` or with any other library/tool that can read zip data. In this format the `model.*` files are obviously not needed but `README.md` and `license.txt` as well as any others provided can be added as more extra files.\n\nThe `bundle` submodule of MONAI contains a number of command line programs. To produce a Torchscript bundle use `ckpt_export` with a set of specified components such as the saved weights file and metadata file. 
Config files can be provided as JSON or YAML dictionaries defining Python constructs used by the `ConfigParser`, however regardless of format the produced bundle Torchscript object will store the files as JSON.\n\nmetadata.json File\n==================\n\nThis file contains the metadata information relating to the model, including what the shape and format of inputs and outputs are, what the meaning of the outputs are, what type of model is present, and other information. The JSON structure is a dictionary containing a defined set of keys with additional user-specified keys. The mandatory keys are as follows:\n\n* **version**: version of the stored model, this allows multiple versions of the same model to be differentiated. Versions should follow semantic versioning and contain only characters valid in filenames as we may include the version to construct bundle file name.\n* **monai_version**: version of MONAI the bundle was generated on, later versions expected to work.\n* **pytorch_version**: version of Pytorch the bundle was generated on, later versions expected to work.\n* **numpy_version**: version of Numpy the bundle was generated on, later versions expected to work.\n* **required_packages_version**: dictionary relating required package names to their versions. These are packages in addition to the base requirements of MONAI which this bundle absolutely needs. For example, if the bundle must load Nifti files the Nibabel package will be required.\n* **task**: plain-language description of what the model is meant to do.\n* **description**: longer form plain-language description of what the model is, what it does, etc.\n* **authors**: state author(s) of the model.\n* **copyright**: state model copyright.\n* **network_data_format**: defines the format, shape, and meaning of inputs and outputs to the (primary) model, contains keys \"inputs\" and \"outputs\" relating named inputs/outputs to their format specifiers (defined below). 
There is also an optional \"post_processed_outputs\" key stating the format of \"outputs\" after postprocessing transforms are applied, this is used to describe the final output from the bundle if it varies from the raw network output. These keys can also relate to primitive values (number, string, boolean), instead of the tensor format specified below.\n\nTensor format specifiers are used to define input and output tensors and their meanings, and must be a dictionary containing at least these keys:\n\n* **type**: what sort of data the tensor represents: \"image\" for any spatial regular data whether an actual image or just data with that sort of shape, \"series\" for (time-) sequences of values such as signals, \"tuples\" for a series of items defined by a known number of values such as N-sized points in ND space, \"probabilities\" for a set of probabilities such as classifier output, this is useful for interpreting what the dimensions and shape of the data represent and allows users to guess how to plot the data\n* **format**: what format of information is stored, see below for list of known formats\n* **modality**: describes the modality, protocol type, sort of capturing technology, or other property of the data not described by either its type or format, known modalities are \"MR\", \"CT\", \"US\", \"EKG\", but can include any custom types or protocol types (eg. \"T1\"), default value is \"n/a\"\n* **num_channels**: number of channels the tensor has, assumed channel dimension first\n* **spatial_shape**: shape of the spatial dimensions of the form \"[H]\", \"[H, W]\", or \"[H, W, D]\", see below for possible values of H, W, and D\n* **dtype**: data type of tensor, eg. 
\"float32\", \"int32\"\n* **value_range**: minimum and maximum values the input data is expected to have of the form \"[MIN, MAX]\" or \"[]\" if not known\n* **is_patch_data**: \"true\" if the data is a patch of an input/output tensor or the entirety of the tensor, \"false\" otherwise\n* **channel_def**: dictionary relating channel indices to plain-language description of what the channel contains\n\nOptional keys:\n\n* **changelog**: dictionary relating previous version names to strings describing the version.\n* **intended_use**: what the model is to be used for, ie. what task it accomplishes.\n* **data_source**: description of where training/validation data can be sourced.\n* **data_type**: type of source data used for training/validation.\n* **references**: list of published references relating to the model.\n* **supported_apps**: list of supported applications which use bundles, eg. 'monai-label' would be present if the bundle is compatible with MONAI Label applications.\n* **\\*_data_format**: defines the format, shape, and meaning of inputs and outputs to additional models which are secondary to the main model. This contains the same sort of information as **network_data_format** which describes networks providing secondary functionality, eg. a localisation network used to identify ROI in an image for cropping before data is sent to the primary network of this bundle.\n\nThe format for tensors used as inputs and outputs can be used to specify semantic meaning of these values, and later is used by software handling bundles to determine how to process and interpret this data. There are various types of image data that MONAI uses, and other data types such as point clouds, dictionary sequences, time signals, and others. 
The following list is provided as a set of supported definitions of what a tensor \"format\" is but is not exhaustive and users can provide their own which would be left up to the model users to interpret:\n\n* **magnitude**: ND field of continuous magnitude values with one or more channels, eg. MR T1 image having 1 channel or natural RGB image with 3 channels\n* **hounsfield**: ND field of semi-categorical values given in Hounsfield, eg. CT image\n* **kspace**: 2D/3D fourier transform image associated with MR imaging\n* **raw**: ND field of values considered unprocessed from an image acquisition device, eg. directly from a MR scanner without reconstruction or other processing\n* **labels**: ND categorical image with N one-hot channels for N-class segmentation/labels, the \"channel_def\" states in plain language what the interpretation of each channel is, for each pixel/voxel the predicted label is the index of the largest channel value\n* **classes**: ND categorical image with N channels for N-class classes, the \"channel_def\" states in plain language what the interpretation of each channel is, this permits multi-class labeling as the channels need not be one-hot encoded\n* **segmentation**: ND categorical image with one channel assigning each pixel/voxel to a label described in \"channel_def\"\n* **points**: list of points/nodes/coordinates/vertices/vectors in ND space, so having a shape of [I, N] for I points with N dimensions\n* **normals**: list of vectors (possibly of unit length) in ND space, so having a shape of [I, N] for I vectors with N dimensions\n* **indices**: list of indices into a vertices array and/or other array representing a set of shapes, so having a shape of [I, N] for I shapes defined by N values\n* **sequence**: time-related sequence of values having one or more channels, such as a signal or dictionary lookup sentence, so having a shape of [C, N] for C channels of data at N time points.\n* **latent**: ND tensor of data from the latent 
space from some layer of a network\n* **gradient**: ND tensor of gradients from some layer of a network\n\nSpatial shape definition can be complex for models accepting inputs of varying shapes, especially if there are specific conditions on what those shapes can be. Shapes are specified as lists of either positive integers for fixed sizes or strings containing expressions defining the condition a size depends on. This can be \"*\" to mean any size, or use an expression with Python mathematical operators and one character variables to represent dependence on an unknown quantity. For example, \"2**p\" represents a size which must be a power of 2, \"2**p*n\" must be a multiple of a power of 2. Variables are shared between dimension expressions, a spatial shape example: `[\"*\", \"16*n\", \"2**p*n\"]`.\n\nThe download link of a JSON schema to verify this file can be found within it with key \"schema\".\n\nAn example JSON metadata file:\n\n::\n\n  {\n      \"schema\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json\",\n      \"version\": \"0.1.0\",\n      \"changelog\": {\n          \"0.1.0\": \"complete the model package\",\n          \"0.0.1\": \"initialize the model package structure\"\n      },\n      \"monai_version\": \"0.9.0\",\n      \"pytorch_version\": \"1.10.0\",\n      \"numpy_version\": \"1.21.2\",\n      \"required_packages_version\": {\"nibabel\": \"3.2.1\"},\n      \"task\": \"Decathlon spleen segmentation\",\n      \"description\": \"A pre-trained model for volumetric (3D) segmentation of the spleen from CT image\",\n      \"authors\": \"MONAI team\",\n      \"copyright\": \"Copyright (c) MONAI Consortium\",\n      \"data_source\": \"Task09_Spleen.tar from http://medicaldecathlon.com/\",\n      \"data_type\": \"dicom\",\n      \"image_classes\": \"single channel data, intensity scaled to [0, 1]\",\n      \"label_classes\": \"single channel data, 1 is spleen, 0 is everything else\",\n      
\"pred_classes\": \"2 channels OneHot data, channel 1 is spleen, channel 0 is background\",\n      \"eval_metrics\": {\n          \"mean_dice\": 0.96\n      },\n      \"intended_use\": \"This is an example, not to be used for diagnostic purposes\",\n      \"references\": [\n          \"Xia, Yingda, et al. '3D Semi-Supervised Learning with Uncertainty-Aware Multi-View Co-Training.' arXiv preprint arXiv:1811.12506 (2018). https://arxiv.org/abs/1811.12506.\",\n          \"Kerfoot E., Clough J., Oksuz I., Lee J., King A.P., Schnabel J.A. (2019) Left-Ventricle Quantification Using Residual U-Net. In: Pop M. et al. (eds) Statistical Atlases and Computational Models of the Heart. Atrial Segmentation and LV Quantification Challenges. STACOM 2018. Lecture Notes in Computer Science, vol 11395. Springer, Cham. https://doi.org/10.1007/978-3-030-12029-0_40\"\n      ],\n      \"network_data_format\":{\n          \"inputs\": {\n              \"image\": {\n                  \"type\": \"image\",\n                  \"format\": \"magnitude\",\n                  \"modality\": \"MR\",\n                  \"num_channels\": 1,\n                  \"spatial_shape\": [160, 160, 160],\n                  \"dtype\": \"float32\",\n                  \"value_range\": [0, 1],\n                  \"is_patch_data\": false,\n                  \"channel_def\": {\"0\": \"image\"}\n              }\n          },\n          \"outputs\":{\n              \"pred\": {\n                  \"type\": \"image\",\n                  \"format\": \"labels\",\n                  \"num_channels\": 2,\n                  \"spatial_shape\": [160, 160, 160],\n                  \"dtype\": \"float32\",\n                  \"value_range\": [],\n                  \"is_patch_data\": false,\n                  \"channel_def\": {\"0\": \"background\", \"1\": \"spleen\"}\n              }\n          }\n      }\n  }\n"
  },
  {
    "path": "docs/source/metrics.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _metrics:\n\nMetrics\n=======\n.. currentmodule:: monai.metrics\n\n`FROC`\n------\n.. autofunction:: compute_fp_tp_probs\n.. autofunction:: compute_froc_curve_data\n.. autofunction:: compute_froc_score\n\n`Metric`\n--------\n.. autoclass:: Metric\n    :members:\n\n`Variance`\n--------------\n.. autofunction:: compute_variance\n\n.. autoclass:: VarianceMetric\n    :members:\n\n`LabelQualityScore`\n--------------------\n.. autofunction:: label_quality_score\n\n.. autoclass:: LabelQualityScore\n    :members:\n\n`IterationMetric`\n-----------------\n.. autoclass:: IterationMetric\n    :members:\n\n`Cumulative`\n------------\n.. autoclass:: Cumulative\n    :members:\n\n`CumulativeIterationMetric`\n---------------------------\n.. autoclass:: CumulativeIterationMetric\n    :members:\n\n`LossMetric`\n------------\n.. autoclass:: LossMetric\n    :members:\n\n`Mean Dice`\n-----------\n.. autoclass:: DiceMetric\n    :members:\n\n.. autoclass:: DiceHelper\n    :members:\n\n`Mean IoU`\n----------\n.. autofunction:: compute_iou\n\n.. autoclass:: MeanIoU\n    :members:\n\n`Generalized Dice Score`\n------------------------\n.. autofunction:: compute_generalized_dice\n\n.. autoclass:: GeneralizedDiceScore\n    :members:\n\n`Area under the ROC curve`\n--------------------------\n.. autofunction:: compute_roc_auc\n\n.. autoclass:: ROCAUCMetric\n    :members:\n\n`Average Precision`\n-------------------\n.. autofunction:: compute_average_precision\n\n.. autoclass:: AveragePrecisionMetric\n    :members:\n\n`Confusion matrix`\n------------------\n.. autofunction:: get_confusion_matrix\n.. autofunction:: compute_confusion_matrix_metric\n\n.. autoclass:: ConfusionMatrixMetric\n    :members:\n\n`Hausdorff distance`\n--------------------\n.. autofunction:: compute_hausdorff_distance\n\n.. autoclass:: HausdorffDistanceMetric\n    :members:\n\n`Average surface distance`\n--------------------------\n.. 
autofunction:: compute_average_surface_distance\n\n.. autoclass:: SurfaceDistanceMetric\n    :members:\n\n`Surface dice`\n--------------\n.. autofunction:: compute_surface_dice\n\n.. autoclass:: SurfaceDiceMetric\n    :members:\n\n`PanopticQualityMetric`\n-----------------------\n.. autofunction:: compute_panoptic_quality\n\n.. autoclass:: PanopticQualityMetric\n    :members:\n\n`Mean squared error`\n--------------------\n.. autoclass:: MSEMetric\n    :members:\n\n`Mean absolute error`\n---------------------\n.. autoclass:: MAEMetric\n    :members:\n\n`Root mean squared error`\n-------------------------\n.. autoclass:: RMSEMetric\n    :members:\n\n`Peak signal to noise ratio`\n----------------------------\n.. autoclass:: PSNRMetric\n    :members:\n\n`Mean absolute percentage error`\n---------------------------------\n.. autoclass:: MAPEMetric\n    :members:\n\n`Structural similarity index measure`\n-------------------------------------\n.. autoclass:: monai.metrics.regression.SSIMMetric\n\n`Multi-scale structural similarity index measure`\n-------------------------------------------------\n.. autoclass:: MultiScaleSSIMMetric\n\n`Fréchet Inception Distance`\n------------------------------\n.. autofunction:: compute_frechet_distance\n\n.. autoclass:: FIDMetric\n    :members:\n\n`Maximum Mean Discrepancy`\n------------------------------\n.. autofunction:: compute_mmd\n\n.. autoclass:: MMDMetric\n    :members:\n\n`Cumulative average`\n--------------------\n.. autoclass:: CumulativeAverage\n    :members:\n\n`Metrics reloaded binary`\n-------------------------\n.. autoclass:: MetricsReloadedBinary\n    :members:\n\n`Metrics reloaded categorical`\n------------------------------\n.. autoclass:: MetricsReloadedCategorical\n    :members:\n\n`Calibration Error`\n-------------------\n.. autofunction:: calibration_binning\n\n.. autoclass:: CalibrationReduction\n    :members:\n\n.. autoclass:: CalibrationErrorMetric\n    :members:\n\n\nUtilities\n---------\n.. 
automodule:: monai.metrics.utils\n  :members:\n"
  },
  {
    "path": "docs/source/modules.md",
    "content": "# Modules\n\nMONAI aims at facilitating deep learning in medical image analysis at multiple granularities. This document provides an\noverview of the modules and highlights the key capabilities.\n\nThe core codebase is designed as a library of lightweight, flexible, and comprehensive APIs for users with varying expertise.\nThe building blocks are made easy to understand and use, they are carefully decoupled and can be readily integrated\ninto existing PyTorch programs and larger systems. By leveraging the workflow and bundle APIs, users can also quickly\nset up efficient and robust model training or evaluation pipelines for various domain-specific applications.\n\nThe overall architecture and modules are shown in the following figure:\n\n![architecture overview](../images/arch_modules.png)\n\n* [I/O, processing and augmentation](#i-o-processing-and-augmentation)\n* [Datasets and Data Loading](#datasets-and-data-loading)\n* [Differentiable components, networks, losses and optimizers](#differentiable-components-networks-losses-and-optimizers)\n* [Evaluation](#evaluation)\n* [Visualization](#visualization)\n* [Workflows](#workflows)\n* [Bundle](#bundle)\n* [Federated Learning](#federated-learning)\n* [Auto3dseg](#auto3dseg)\n* [GPU acceleration, performance profiling and optimization](#gpu-acceleration-performance-profiling-and-optimization)\n\n## I/O, processing and augmentation\nMedical images require specialized methods for I/O, preprocessing and augmentation. 
They often follow specific formats,\nare handled with specific protocols, and the data arrays are often high-dimensional.\n[`monai.transforms`](https://github.com/Project-MONAI/MONAI/tree/dev/monai/transforms) and\n[`monai.data`](https://github.com/Project-MONAI/MONAI/tree/dev/monai/data) modules include a set of domain-specific APIs\nfor various deep learning applications:\n\n### Transforms with data in array and dictionary styles\n\n![3d transform examples](../images/affine.png)\n\nThis enables basic image transformations, as well as more complex preprocessing pipelines such as synchronized operations\nacross different modalities and model supervision inputs. [[array and dict examples]](https://github.com/Project-MONAI/tutorials/tree/main/3d_segmentation/torch)\n\n### Various image patch-based sampling mechanisms\n\n![2d transform examples](../images/medical_transforms.png)\n\nAdvanced patch sampling methods are implemented for selective preprocessing, such as weighted, class-balanced sampling\nfrom user-specified sampling weight maps.\nThe output can be in a sequence or iterator pattern which allows for different types of shuffling strategies.\n\n### Image IO with third-party library integrations\n\nSeveral backends are built-in and can support various formats. It is easily extensible for customized format readers.\n\n### monai.data.MetaTensor\n\nCore data structure combines PyTorch native Tensor APIs with metadata handling,\nso that the deep learning models and pipelines can readily incorporate the meta information. [[MetaTensor]](https://colab.research.google.com/drive/1T4iAys-cC2qL80oJkIbAXAPlWNPwp4H7)\n\n### GPU-based accelerations\n\nImplementations are provided to ensure optimal usage of the underlying hardware resources. 
[[fast training guide]](https://github.com/Project-MONAI/tutorials/blob/main/acceleration)\n\n### Determinism and reproducibility\n\nThey can be achieved with fine-level of local controls via the `Randomizable` API as well as globally\nusing `set_determinism`.\n\n### Decollating and invertible transforms\n\n![invert transform](../images/invert_transforms.png)\nThe mini-batch data output from a model can be decollated, post-processed independently, including inverting\nthe outputs to an earlier step of the preprocessing according to the tracked metadata and applied operations.\n[[inverse transform demo]](https://github.com/Project-MONAI/tutorials/blob/main/modules/inverse_transforms_and_test_time_augmentations.ipynb)\n\n### Enhanced usability\n\nAdditionally, utilities such as `DataStats` transform, `dev_collate`, and [visualization\nmethods](https://github.com/Project-MONAI/tutorials/blob/main/modules/transform_visualization.ipynb) are provided as\nextensions to PyTorch for improved overall debugability.\n\n## Datasets and Data Loading\nFollowing PyTorch's design pattern, MONAI extends the `Dataset` and `DataLoader` APIs as major enhancements in terms of\ndomain-specific usability and pipeline performance.\n\n### Cache IO and transforms data to accelerate training\n\nData-driven methods require many (potentially thousands of) epochs of training data reading and preprocessing. MONAI\nprovides multi-threaded cache-based datasets to accelerate the process [[Datasets experiment]](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/dataset_type_performance.ipynb). 
The\ncache can be persistent and dynamic (`SmartCacheDataset`) and reused across different experiments [[SmartCache example]](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/distributed_training/unet_training_smartcache.py).\nThe following figure illustrates the training speedup compared with a regular PyTorch program.\n\n![cachedataset speed](../images/datasets_speed.png)\n\n### `ThreadDataLoader` vs. `DataLoader`\n\nIf the transforms are light-weighted, especially when we cache all the data in RAM, the multiprocessing of PyTorch\n`DataLoader` may cause unnecessary IPC time and decrease GPU utilization. MONAI provides `ThreadDataLoader` which\nexecutes the transforms in a separate thread:\n\n![threaddataloader](../images/threaddataloader.png)\n\na `ThreadDataLoader` example is within the [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/fast_training_tutorial.ipynb).\n\n### Public datasets\n\nTo quickly get started with popular training data, MONAI provides several ready-to-integrate Dataset classes\n(such as `MedNISTDataset`, `DecathlonDataset`, [`TciaDataset`](https://github.com/Project-MONAI/tutorials/blob/main/modules/tcia_dataset.ipynb)), which include data downloading, and support training/evaluation splits generation with transforms.\n[[Public datasets tutorial]](https://github.com/Project-MONAI/tutorials/blob/main/modules/public_datasets.ipynb)\nThe common workflow of predefined datasets:\n\n![pre-defined dataset](../images/dataset_progress.png)\n\n### Dataset type extensions\n\nOther extensions of the `Dataset` API include: `ZipDataset` for associating multiple data sources, `PatchDataset` for\nhandling both image- and patch-level preprocessing, `CSVDataset` for multi-modal inputs, and `partition_dataset` for\ncross-validation data preparations.\n\n## Differentiable components, networks, losses and optimizers\n\nSome deep neural network architectures have shown to be particularly effective 
for medical imaging analysis tasks.\nMONAI implements reference networks with the aim of both flexibility and code readability.\n\n### Predefined layers and blocks\n\nNetwork layers and blocks are in general implemented to be compatible with spatial 1D, 2D and 3D inputs.\nUsers can easily integrate the layers, blocks and networks as part of their customised pipelines.\nVarious utilities are provided to leverage the existing model weights, e.g. [from a bundle in MONAI model-zoo](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo).\n\n### C++/CUDA optimized modules\n\nTo further accelerate the domain-specific routines, MONAI C++/CUDA implementation is introduced as extensions of the PyTorch native implementations.\nMONAI provides the modules using [the two ways of building C++ extensions from PyTorch](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions):\n- via `setuptools`, for modules including `Resampler`, `Conditional random field (CRF)`, `Fast bilateral filtering using the permutohedral lattice`.\n- via just-in-time (JIT) compilation, for the `Gaussian mixtures` module. 
This approach allows for dynamic optimisation according to the user-specified parameters and local system environments.\nThe following figure shows results of MONAI's Gaussian mixture models applied to tissue and surgical tools segmentation:\n![Gaussian mixture models as a postprocessing step](../images/gmm_feature_set_comparison_s.png)\n\n\n### Losses and optimizers\n\nCommonly used loss functions for various applications are (re-)implemented from the literature, such as `DiceLoss`, `GeneralizedDiceLoss`, `TverskyLoss`, `DiceFocalLoss`.\nThe numerical optimizations and relevant utilities include `Novograd` and `LearningRateFinder`.\nThe following figure shows a learning rate search process.\n\n![learning rate finder plot](../images/lr_finder.png)\n\n## Evaluation\nTo run model inferences and evaluate the model quality, MONAI provides reference implementations for the relevant\nwidely-used approaches. Currently, several popular evaluation metrics and inference patterns are included:\n\n### Sliding window inference\n\nFor model inferences on large volumes, the sliding window approach is a popular choice to achieve high performance while\nhaving flexible memory requirements (_alternatively, please check out the latest research on [model parallel\ntraining](https://github.com/Project-MONAI/research-contributions/tree/main/lamp-automated-model-parallelism)_). 
It also supports\n`overlap` and `blending_mode` configurations to handle the overlapped windows for better performances.\n\n![sliding window scheme](../images/sliding_window.png)\n\n### Metrics for medical tasks\n\nVarious useful evaluation metrics have been implemented to measure the quality of medical image specific models.\nThese include `Mean Dice`, `ROCAUC`, `Confusion Matrices`, `Hausdorff\nDistance`, `Surface Distance`, `Occlusion Sensitivity`.\nThe APIs also support [multi-processing computation](https://github.com/Project-MONAI/tutorials/blob/main/modules/compute_metric.py).\n\n### Report generation\n`MetricsSaver` is provided to write the final metric summary report: `mean`, `median`, `max`, `min`, `<int>percentile`, `std`:\n\n![metrics report example](../images/metrics_report.png)\n\n## Visualization\nBeyond the simple point and curve plotting, intuitive interfaces are provided to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/torch/unet_training_dict.py). To work with ignite programs, MONAI also provides several ignite handlers to visualize training curves and metrics with `TensorBoard` or `MLFlow`; more details are available in [TensorBoard and MLFlow handlers example](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unet_segmentation_3d_ignite.ipynb).\n\nTo easily visualize a 3D image as frames of 2D images, MONAI provides the utility `matshow3d` based on the `matplotlib` library. 
It can plot frames of image for the specified dimension, showing a spleen 3D image as example:\n`matshow3d(volume=image, figsize=(100, 100), every_n=10, frame_dim=-1, show=True, cmap=\"gray\")`\n\n![matshow3d example](../images/matshow3d.png)\n\nMONAI also provides the `blend_images` utility to blend the `image` and `label` to an RGB color image to better visualize the segmentation regions with the specified `cmap` mode and weights, etc. Showing a spleen segmentation `image` and the corresponding `label` as example:\n\n![blend example](../images/blend.png)\n\nFor more details of `TensorBoard utility`, `matshow3d` and `blend_images`, please check the [visualization tutorial](https://github.com/Project-MONAI/tutorials/blob/main/modules/transform_visualization.ipynb).\n\nAnd to visualize the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models:\n\n![CAM visualization example](../images/cam.png)\n\nThe above example is generated by computing [GradCAM/GradCAM++ from a lung CT lesion classification model](https://github.com/Project-MONAI/tutorials/tree/main/modules/interpretability).\n\n## Workflows\n\nMONAI engines and workflows enable quick start of training and evaluation experiments.\n\nThese features decouple the domain-specific components and the generic machine learning processes.\nThey also provide a set of unified APIs for higher level applications (such as AutoML, Federated Learning).\nThe trainers and evaluators of the workflows are compatible with pytorch-ignite `Engine` and `Event-Handler` mechanism.\n\n### General workflows pipeline\n\nThe workflow and some of MONAI event handlers are shown as below [[Workflow examples]](https://github.com/Project-MONAI/tutorials/tree/main/modules/engines):\n\n![workflow pipeline](../images/workflows.png)\n\n\n### EnsembleEvaluator\n\nA typical ensemble process is implemented as a ready-to-use workflow [[Cross validation and model ensemble 
tutorial]](https://github.com/Project-MONAI/tutorials/blob/main/modules/cross_validation_models_ensemble.ipynb):\n1. Split all the training dataset into K folds.\n2. Train K models with every K-1 folds data.\n3. Execute inference on the test data with all the K models.\n4. Compute the average values with weights or vote the most common value as the final result.\n\n![model ensemble](../images/models_ensemble.png)\n\n\n### Decollate batch data for flexible post-processings\n\n`decollate batch` is introduced since MONAI v0.6, which simplifies the post-processing transforms and provides flexible following operations on a batch of data with various data shapes. It can decollate batched data (e.g. model predictions) into a list of tensors, for the benefits such as:\n1. enabling postprocessing transforms for each item independently -- randomised transforms could be applied differently for each predicted item in a batch.\n2. simplifying the transform APIs and reducing the input validation burdens because both the preprocessing and postprocessing transforms now only need to support the \"channel-first\" input format.\n3. enabling the `Invertd` transform for the predictions and the inverted data with different shapes, as the data items are in a list, not stacked in a single tensor.\n4. allowing for both batch-first tensor and list of channel-first tensors in a flexible metric computation. [[decollate batch tutorial]](https://github.com/Project-MONAI/tutorials/blob/main/modules/decollate_batch.ipynb)\n\nA typical process of `decollate batch` is illustrated as follows (with a `batch_size=N` model predictions and labels as an example):\n\n![decollate_batch](../images/decollate_batch.png)\n\n### Easy to integrate into popular workflows\n\nExcept for the pytorch-ignite based `monai.engines`, most of the MONAI modules could be used independently or combined\nwith other software packages. 
For example, MONAI can be easily integrated into popular frameworks such as\n[PyTorch-Lightning](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/spleen_segmentation_3d_lightning.ipynb)\nand [MLflow](https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/spleen_segmentation_mlflow.ipynb).\n\n## Bundle\n\nThe objective of a MONAI bundle is to define a packaged model which includes the critical information necessary to allow\nusers and programs to understand how the model is used and for what purpose. A bundle includes the stored weights of a\nsingle network as a pickled state dictionary plus optionally a Torchscript object and/or an ONNX object. Additional JSON\nfiles are included to store metadata about the model, information for constructing training, inference, and\npost-processing transform sequences, plain-text description, legal information, and other data the model creator wishes\nto include. More details are available at [bundle specification](https://monai.readthedocs.io/en/latest/mb_specification.html).\n\nThe key benefits of bundle are to define the model package and support building Python-based workflows via structured configurations:\n- Self-contained model package includes all the necessary information.\n- Structured config can be used to easily reconstruct or prototype deep learning workflows.\n- Config files can provide good readability and usability by separating parameter settings from the Python code.\n- Config files can describe flexible workflows and components, allowing for different low-level Python implementations.\n- Learning paradigms at a higher level such as federated learning and AutoML can be decoupled from the component details.\n\nA typical bundle example can include:\n```\n  ModelName\n  ┣━ configs\n  ┃  ┗━ metadata.json\n  ┣━ models\n  ┃  ┣━ model.pt\n  ┃  ┣━ *model.ts\n  ┃  ┗━ *model.onnx\n  ┗━ docs\n     ┣━ *README.md\n     ┗━ *license.txt\n```\nDetails about the bundle config definition and 
syntax & examples are at [config syntax](https://monai.readthedocs.io/en/latest/config_syntax.html).\nA step-by-step [get started](https://github.com/Project-MONAI/tutorials/blob/main/bundle/README.md) tutorial notebook can help users quickly set up a bundle. [[bundle examples](https://github.com/Project-MONAI/tutorials/tree/main/bundle), [model-zoo](https://github.com/Project-MONAI/model-zoo)]\n\n## Federated Learning\n\n![federated-learning](../images/federated.svg)\n\nUsing the MONAI bundle configurations, we can use MONAI's [`MonaiAlgo`](https://monai.readthedocs.io/en/latest/fl.html#monai.fl.client.MonaiAlgo)\nclass, an implementation of the abstract [`ClientAlgo`](https://monai.readthedocs.io/en/latest/fl.html#clientalgo) class for federated learning (FL),\nto execute bundles from the [MONAI model zoo](https://github.com/Project-MONAI/model-zoo).\nNote that [`ClientAlgo`](https://monai.readthedocs.io/en/latest/fl.html#clientalgo) is provided as an abstract base class for\ndefining an algorithm to be run on any federated learning platform.\n[`MonaiAlgo`](https://monai.readthedocs.io/en/latest/fl.html#monai.fl.client.MonaiAlgo) implements the main functionalities needed\nto run federated learning experiments, namely `train()`, `get_weights()`, and `evaluate()`, that can be run using single- or multi-GPU training.\nOn top, it provides implementations for life-cycle management of the component such as `initialize()`, `abort()`, and `finalize()`.\nThe MONAI FL client also allows computing summary data statistics (e.g., intensity histograms) on the datasets defined in the bundle configs\nusing the [`MonaiAlgoStats`](https://monai.readthedocs.io/en/latest/fl.html#monai.fl.client.MonaiAlgoStats) class.\nThese statistics can be shared and visualized on the FL server.\n[NVIDIA FLARE](https://github.com/NVIDIA/NVFlare), the federated learning platform developed by NVIDIA, has already built [the integration 
piece](https://github.com/NVIDIA/NVFlare/tree/2.2/integration/monai)\nwith [`ClientAlgo`](https://monai.readthedocs.io/en/latest/fl.html#clientalgo) to allow easy experimentation with MONAI bundles within their federated environment.\nOur [[federated learning tutorials]](https://github.com/Project-MONAI/tutorials/tree/main/federated_learning/nvflare) show\nexamples of single- & multi-GPU training and federated statistics workflows.\n\n## Auto3dseg\n\n![auto3dseg](../images/auto3dseg.png)\n\n[Auto3DSeg](https://project-monai.github.io/apps/auto3dseg.html) is a comprehensive solution for large-scale 3D medical image segmentation.\nIt leverages the latest advances in MONAI\nand GPUs to efficiently develop and deploy algorithms with state-of-the-art performance.\nIt first analyzes the global information such as intensity, dimensionality, and resolution of the dataset,\nthen generates algorithms in MONAI bundle format based on data statistics and [algorithm templates](https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg).\nNext, all algorithms initiate model training to obtain checkpoints with the best validation performance.\nFinally, the ensemble module selects the algorithms via ranking trained checkpoints and creates ensemble predictions.\n\nThe solution offers different levels of user experience for beginners and advanced researchers.\nIt has been tested on large-scale 3D medical imaging datasets in different modalities.\n\n## GPU acceleration, performance profiling and optimization\n\nMONAI provides state-of-the-art performance optimization methods including:\n\n### Auto mixed precision (AMP)\n\nSimply set `amp=True/False` in `SupervisedTrainer` or `SupervisedEvaluator` during training or evaluation to enable/disable AMP.\nExample benchmark results are as follows [[AMP training tutorial]](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/automatic_mixed_precision.ipynb):\n\ntraining with AMP ON/OFF on an NVIDIA V100 GPU with 
CUDA 11 and PyTorch 1.6:\n\n![amp v100 results](../images/amp_training_v100.png)\n\ntraining with AMP ON/OFF on an NVIDIA A100 GPU with CUDA 11 and PyTorch 1.6:\n\n![amp a100 results](../images/amp_training_a100.png)\n\nSeveral tools including `DLProf`, `Nsight`, `NVTX` and `NVML` can be used with MONAI to identify the performance bottleneck. [[profiling tutorial]](https://github.com/Project-MONAI/tutorials/blob/main/performance_profiling/radiology/profiling_train_base_nvtx.md)\n\n### Distributed training\n\nThe distributed data-parallel APIs of MONAI are compatible with the native PyTorch distributed module, pytorch-ignite distributed module, Horovod, XLA, and the SLURM platform.\n[[distributed training tutorial]](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/distributed_training/brats_training_ddp.py)\n\n![distributed training results](../images/brats_distributed.png)\n\nThe [fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/fast_training_tutorial.ipynb)\ncombines `AMP` with `CacheDataset`, `GPU cache`, `GPU transforms`, `ThreadDataLoader`, and tuning of networks and optimizers to achieve substantial speedup compared\nwith a regular PyTorch implementation.\n"
  },
  {
    "path": "docs/source/networks.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _networks:\n\nNetwork architectures\n=====================\n\nBlocks\n------\n.. automodule:: monai.networks.blocks\n.. currentmodule:: monai.networks.blocks\n\n`ADN`\n~~~~~\n.. autoclass:: ADN\n    :members:\n\n`Convolution`\n~~~~~~~~~~~~~\n.. autoclass:: Convolution\n    :members:\n\n`CRF`\n~~~~~\n.. autoclass:: CRF\n    :members:\n\n`ResidualUnit`\n~~~~~~~~~~~~~~\n.. autoclass:: ResidualUnit\n    :members:\n\n`Swish`\n~~~~~~~\n.. autoclass:: Swish\n    :members:\n\n`MemoryEfficientSwish`\n~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: MemoryEfficientSwish\n    :members:\n\n`FPN`\n~~~~~\n.. autoclass:: ExtraFPNBlock\n    :members:\n.. autoclass:: FeaturePyramidNetwork\n    :members:\n.. autoclass:: LastLevelMaxPool\n    :members:\n.. autoclass:: LastLevelP6P7\n    :members:\n.. autoclass:: BackboneWithFPN\n    :members:\n\n`Mish`\n~~~~~~\n.. autoclass:: Mish\n    :members:\n\n`GEGLU`\n~~~~~~~\n.. autoclass:: GEGLU\n    :members:\n\n`GCN Module`\n~~~~~~~~~~~~\n.. autoclass:: GCN\n    :members:\n\n`Refinement Module`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: Refine\n    :members:\n\n`FCN Module`\n~~~~~~~~~~~~\n.. autoclass:: FCN\n    :members:\n\n`Multi-Channel FCN Module`\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: MCFCN\n    :members:\n\n`Dynamic-Unet Block`\n~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: UnetResBlock\n    :members:\n.. autoclass:: UnetBasicBlock\n    :members:\n.. autoclass:: UnetUpBlock\n    :members:\n.. autoclass:: UnetOutBlock\n    :members:\n\n`DenseBlock`\n~~~~~~~~~~~~~\n.. autoclass:: DenseBlock\n   :members:\n\n`SegResnet Block`\n~~~~~~~~~~~~~~~~~\n.. autoclass:: ResBlock\n    :members:\n\n`SABlock Block`\n~~~~~~~~~~~~~~~\n.. autoclass:: SABlock\n    :members:\n\n`CABlock Block`\n~~~~~~~~~~~~~~~\n.. autoclass:: CABlock\n    :members:\n\n`FeedForward Block`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: FeedForward\n    :members:\n\n`Squeeze-and-Excitation`\n~~~~~~~~~~~~~~~~~~~~~~~~\n.. 
autoclass:: ChannelSELayer\n    :members:\n\n`Transformer Block`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: TransformerBlock\n    :members:\n\n`UNETR Block`\n~~~~~~~~~~~~~\n.. autoclass:: UnetrBasicBlock\n    :members:\n.. autoclass:: UnetrUpBlock\n    :members:\n.. autoclass:: UnetrPrUpBlock\n    :members:\n\n`Residual Squeeze-and-Excitation`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: ResidualSELayer\n    :members:\n\n`Squeeze-and-Excitation Block`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: SEBlock\n    :members:\n\n`Squeeze-and-Excitation Bottleneck`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: SEBottleneck\n    :members:\n\n`Squeeze-and-Excitation Resnet Bottleneck`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: SEResNetBottleneck\n    :members:\n\n`Squeeze-and-Excitation ResNeXt Bottleneck`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: SEResNeXtBottleneck\n    :members:\n\n`Simple ASPP`\n~~~~~~~~~~~~~\n.. autoclass:: SimpleASPP\n    :members:\n\n`MaxAvgPooling`\n~~~~~~~~~~~~~~~\n.. autoclass:: MaxAvgPool\n    :members:\n\n`Upsampling`\n~~~~~~~~~~~~\n.. autoclass:: UpSample\n    :members:\n.. autoclass:: Upsample\n.. autoclass:: SubpixelUpsample\n    :members:\n.. autoclass:: Subpixelupsample\n.. autoclass:: SubpixelUpSample\n\n`Downsampling`\n~~~~~~~~~~~~~~\n.. autoclass:: DownSample\n    :members:\n.. autoclass:: Downsample\n.. autoclass:: SubpixelDownsample\n    :members:\n.. autoclass:: Subpixeldownsample\n.. autoclass:: SubpixelDownSample\n\n`Registration Residual Conv Block`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: RegistrationResidualConvBlock\n    :members:\n\n`Registration Down Sample Block`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: RegistrationDownSampleBlock\n    :members:\n\n`Registration Extraction Block`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: RegistrationExtractionBlock\n    :members:\n\n`LocalNet DownSample Block`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. 
autoclass:: LocalNetDownSampleBlock\n    :members:\n\n`LocalNet UpSample Block`\n~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: LocalNetUpSampleBlock\n    :members:\n\n`LocalNet Feature Extractor Block`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: LocalNetFeatureExtractorBlock\n    :members:\n\n`MLP Block`\n~~~~~~~~~~~\n.. autoclass:: MLPBlock\n    :members:\n\n`Patch Embedding Block`\n~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: PatchEmbeddingBlock\n    :members:\n\n`FactorizedIncreaseBlock`\n~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: FactorizedIncreaseBlock\n    :members:\n\n`FactorizedReduceBlock`\n~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: FactorizedReduceBlock\n    :members:\n\n`P3DActiConvNormBlock`\n~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: P3DActiConvNormBlock\n    :members:\n\n`ActiConvNormBlock`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: ActiConvNormBlock\n    :members:\n\n`Warp`\n~~~~~~\n.. autoclass:: Warp\n    :members:\n\n`DVF2DDF`\n~~~~~~~~~\n.. autoclass:: DVF2DDF\n    :members:\n\n`VarNetBlock`\n~~~~~~~~~~~~~\n.. autoclass:: monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock\n   :members:\n\nN-Dim Fourier Transform\n~~~~~~~~~~~~~~~~~~~~~~~~\n.. automodule:: monai.networks.blocks.fft_utils_t\n.. autofunction:: monai.networks.blocks.fft_utils_t.fftn_centered_t\n.. autofunction:: monai.networks.blocks.fft_utils_t.ifftn_centered_t\n.. autofunction:: monai.networks.blocks.fft_utils_t.roll\n.. autofunction:: monai.networks.blocks.fft_utils_t.roll_1d\n.. autofunction:: monai.networks.blocks.fft_utils_t.fftshift\n.. autofunction:: monai.networks.blocks.fft_utils_t.ifftshift\n\nLayers\n------\n\n`Factories`\n~~~~~~~~~~~\n.. automodule:: monai.networks.layers.factories\n\n.. autoclass:: monai.networks.layers.LayerFactory\n  :members:\n\n.. currentmodule:: monai.networks.layers\n\n`split_args`\n~~~~~~~~~~~~\n.. autofunction:: monai.networks.layers.split_args\n\n`Dropout`\n~~~~~~~~~\n.. automodule:: monai.networks.layers.Dropout\n  :members:\n\n`Act`\n~~~~~\n.. 
automodule:: monai.networks.layers.Act\n  :members:\n\n`Norm`\n~~~~~~\n.. automodule:: monai.networks.layers.Norm\n  :members:\n\n`Conv`\n~~~~~~\n.. automodule:: monai.networks.layers.Conv\n  :members:\n\n`Pad`\n~~~~~\n.. automodule:: monai.networks.layers.Pad\n  :members:\n\n`Pool`\n~~~~~~\n.. automodule:: monai.networks.layers.Pool\n  :members:\n\n.. currentmodule:: monai.networks.layers\n\n`ChannelPad`\n~~~~~~~~~~~~\n.. autoclass:: ChannelPad\n    :members:\n\n`SkipConnection`\n~~~~~~~~~~~~~~~~\n.. autoclass:: SkipConnection\n    :members:\n\n`Flatten`\n~~~~~~~~~\n.. autoclass:: Flatten\n    :members:\n\n`Reshape`\n~~~~~~~~~\n.. autoclass:: Reshape\n    :members:\n\n`separable_filtering`\n~~~~~~~~~~~~~~~~~~~~~\n.. autofunction:: separable_filtering\n\n`apply_filter`\n~~~~~~~~~~~~~~\n.. autofunction:: apply_filter\n\n`GaussianFilter`\n~~~~~~~~~~~~~~~~\n.. autoclass:: GaussianFilter\n    :members:\n\n`MedianFilter`\n~~~~~~~~~~~~~~\n.. autoclass:: MedianFilter\n    :members:\n\n`median_filter`\n~~~~~~~~~~~~~~~\n.. autoclass:: median_filter\n    :members:\n\n`BilateralFilter`\n~~~~~~~~~~~~~~~~~\n.. autoclass:: BilateralFilter\n    :members:\n\n`TrainableBilateralFilter`\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: TrainableBilateralFilter\n    :members:\n\n`TrainableJointBilateralFilter`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: TrainableJointBilateralFilter\n    :members:\n\n`PHLFilter`\n~~~~~~~~~~~\n.. autoclass:: PHLFilter\n\n`GaussianMixtureModel`\n~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: GaussianMixtureModel\n\n`SavitzkyGolayFilter`\n~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: SavitzkyGolayFilter\n    :members:\n\n`HilbertTransform`\n~~~~~~~~~~~~~~~~~~\n.. autoclass:: HilbertTransform\n    :members:\n\n`Affine Transform`\n~~~~~~~~~~~~~~~~~~\n.. autoclass:: monai.networks.layers.AffineTransform\n    :members:\n\n`grid_pull`\n~~~~~~~~~~~\n.. autofunction:: monai.networks.layers.grid_pull\n\n`grid_push`\n~~~~~~~~~~~\n.. 
autofunction:: monai.networks.layers.grid_push\n\n`grid_count`\n~~~~~~~~~~~~\n.. autofunction:: monai.networks.layers.grid_count\n\n`grid_grad`\n~~~~~~~~~~~\n.. autofunction:: monai.networks.layers.grid_grad\n\n`LLTM`\n~~~~~~\n.. autoclass:: LLTM\n    :members:\n\n`ConjugateGradient`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: ConjugateGradient\n    :members:\n\n`Utilities`\n~~~~~~~~~~~\n.. automodule:: monai.networks.layers.convutils\n    :members:\n.. automodule:: monai.networks.layers.utils\n    :members:\n\n\nNets\n----\n.. currentmodule:: monai.networks.nets\n\n`AHNet`\n~~~~~~~\n.. autoclass:: AHNet\n  :members:\n\n`DenseNet`\n~~~~~~~~~~\n.. autoclass:: DenseNet\n  :members:\n\n`DenseNet121`\n~~~~~~~~~~~~~\n.. autoclass:: DenseNet121\n\n`DenseNet169`\n~~~~~~~~~~~~~\n.. autoclass:: DenseNet169\n\n`DenseNet201`\n~~~~~~~~~~~~~\n.. autoclass:: DenseNet201\n\n`DenseNet264`\n~~~~~~~~~~~~~\n.. autoclass:: DenseNet264\n\n`EfficientNet`\n~~~~~~~~~~~~~~\n.. autoclass:: EfficientNet\n  :members:\n\n`BlockArgs`\n~~~~~~~~~~~\n.. autoclass:: BlockArgs\n  :members:\n\n`EfficientNetBN`\n~~~~~~~~~~~~~~~~\n.. autoclass:: EfficientNetBN\n  :members:\n\n`EfficientNetBNFeatures`\n~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: EfficientNetBNFeatures\n  :members:\n\n`SegResNet`\n~~~~~~~~~~~\n.. autoclass:: SegResNet\n  :members:\n\n`SegResNetDS`\n~~~~~~~~~~~~~\n.. autoclass:: SegResNetDS\n  :members:\n\n`SegResNetDS2`\n~~~~~~~~~~~~~~\n.. autoclass:: SegResNetDS2\n  :members:\n\n`SegResNetVAE`\n~~~~~~~~~~~~~~\n.. autoclass:: SegResNetVAE\n  :members:\n\n`ResNet`\n~~~~~~~~\n.. autoclass:: ResNet\n  :members:\n\n`ResNetFeatures`\n~~~~~~~~~~~~~~~~\n.. autoclass:: ResNetFeatures\n  :members:\n\n`SENet`\n~~~~~~~\n.. autoclass:: SENet\n  :members:\n\n`SENet154`\n~~~~~~~~~~\n.. autoclass:: SENet154\n\n`SEResNet50`\n~~~~~~~~~~~~\n.. autoclass:: SEResNet50\n\n`SEResNet101`\n~~~~~~~~~~~~~\n.. autoclass:: SEResNet101\n\n`SEResNet152`\n~~~~~~~~~~~~~\n.. 
autoclass:: SEResNet152\n\n`SEResNext50`\n~~~~~~~~~~~~~\n.. autoclass:: SEResNext50\n\n`SEResNext101`\n~~~~~~~~~~~~~~\n.. autoclass:: SEResNext101\n\n`HighResNet`\n~~~~~~~~~~~~\n.. autoclass:: HighResNet\n  :members:\n.. autoclass:: HighResBlock\n  :members:\n\n`DynUNet`\n~~~~~~~~~\n.. autoclass:: DynUNet\n  :members:\n.. autoclass:: DynUnet\n.. autoclass:: Dynunet\n\n`UNet`\n~~~~~~\n.. autoclass:: UNet\n  :members:\n.. autoclass:: Unet\n.. autoclass:: unet\n\n`AttentionUnet`\n~~~~~~~~~~~~~~~\n.. autoclass:: AttentionUnet\n  :members:\n\n`UNETR`\n~~~~~~~\n.. autoclass:: UNETR\n    :members:\n\n`VISTA3D`\n~~~~~~~~~\n.. autoclass:: VISTA3D\n    :members:\n\n`SwinUNETR`\n~~~~~~~~~~~\n.. autoclass:: SwinUNETR\n    :members:\n\n`BasicUNet`\n~~~~~~~~~~~\n.. autoclass:: BasicUNet\n  :members:\n.. autoclass:: BasicUnet\n.. autoclass:: Basicunet\n\n`BasicUNetPlusPlus`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: BasicUNetPlusPlus\n  :members:\n.. autoclass:: BasicUnetPlusPlus\n.. autoclass:: BasicunetPlusPlus\n\n`FlexibleUNet`\n~~~~~~~~~~~~~~\n.. autoclass:: FlexibleUNet\n  :members:\n\n`VNet`\n~~~~~~\n.. autoclass:: VNet\n  :members:\n\n`RegUNet`\n~~~~~~~~~\n.. autoclass:: RegUNet\n  :members:\n\n`GlobalNet`\n~~~~~~~~~~~~\n.. autoclass:: GlobalNet\n  :members:\n\n`LocalNet`\n~~~~~~~~~~~\n.. autoclass:: LocalNet\n  :members:\n\n`AutoEncoder`\n~~~~~~~~~~~~~\n.. autoclass:: AutoEncoder\n  :members:\n\n`VarAutoEncoder`\n~~~~~~~~~~~~~~~~\n.. autoclass:: VarAutoEncoder\n  :members:\n\n`ViT`\n~~~~~\n.. autoclass:: ViT\n  :members:\n\n`Restormer`\n~~~~~~~~~~~\n.. autoclass:: restormer\n  :members:\n\n`ViTAutoEnc`\n~~~~~~~~~~~~\n.. autoclass:: ViTAutoEnc\n  :members:\n\n`MaskedAutoEncoderViT`\n~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: MaskedAutoEncoderViT\n  :members:\n\n`FullyConnectedNet`\n~~~~~~~~~~~~~~~~~~~\n.. autoclass:: FullyConnectedNet\n  :members:\n\n`VarFullyConnectedNet`\n~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: VarFullyConnectedNet\n  :members:\n\n`Generator`\n~~~~~~~~~~~\n.. 
autoclass:: Generator\n  :members:\n\n`Regressor`\n~~~~~~~~~~~\n.. autoclass:: Regressor\n  :members:\n\n`Classifier`\n~~~~~~~~~~~~\n.. autoclass:: Classifier\n  :members:\n\n`Discriminator`\n~~~~~~~~~~~~~~~\n.. autoclass:: Discriminator\n  :members:\n\n`Critic`\n~~~~~~~~\n.. autoclass:: Critic\n  :members:\n\n`Transchex`\n~~~~~~~~~~~~~~~~\n.. autoclass:: Transchex\n  :members:\n\n`NetAdapter`\n~~~~~~~~~~~~\n.. autoclass:: NetAdapter\n  :members:\n\n`TorchVisionFCModel`\n~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: TorchVisionFCModel\n  :members:\n\n`MILModel`\n~~~~~~~~~~\n.. autoclass:: MILModel\n  :members:\n\n`DiNTS`\n~~~~~~~\n.. autoclass:: DiNTS\n  :members:\n\n`TopologyConstruction for DiNTS`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: TopologyConstruction\n  :members:\n\n`TopologyInstance for DiNTS`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: TopologyInstance\n  :members:\n\n`TopologySearch for DiNTS`\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: TopologySearch\n  :members:\n\n`ComplexUnet`\n~~~~~~~~~~~~~\n.. autoclass:: monai.apps.reconstruction.networks.nets.complex_unet.ComplexUnet\n   :members:\n\n`CoilSensitivityModel`\n~~~~~~~~~~~~~~~~~~~~~~\n.. autoclass:: monai.apps.reconstruction.networks.nets.coil_sensitivity_model.CoilSensitivityModel\n   :members:\n\n`e2e-VarNet`\n~~~~~~~~~~~~\n.. autoclass:: monai.apps.reconstruction.networks.nets.varnet.VariationalNetworkModel\n   :members:\n\n`DAF3D`\n~~~~~~~~~~~~\n.. autoclass:: DAF3D\n   :members:\n\n`Quicknat`\n~~~~~~~~~~~~\n.. autoclass:: Quicknat\n   :members:\n\n`VoxelMorph`\n~~~~~~~~~~~~\n.. autoclass:: VoxelMorphUNet\n   :members:\n\n.. autoclass:: VoxelMorph\n   :members:\n\nUtilities\n---------\n.. automodule:: monai.networks.utils\n  :members:\n\n.. automodule:: monai.apps.reconstruction.networks.nets.utils\n  :members:\n\nNoise Schedulers\n----------------\n.. automodule:: monai.networks.schedulers\n.. currentmodule:: monai.networks.schedulers\n\n`Scheduler`\n~~~~~~~~~~~\n.. 
autoclass:: Scheduler\n  :members:\n\n`NoiseSchedules`\n~~~~~~~~~~~~~~~~\n.. autoclass:: NoiseSchedules\n  :members:\n\n`DDPMScheduler`\n~~~~~~~~~~~~~~~\n.. autoclass:: DDPMScheduler\n  :members:\n\n`DDIMScheduler`\n~~~~~~~~~~~~~~~\n.. autoclass:: DDIMScheduler\n  :members:\n\n`PNDMScheduler`\n~~~~~~~~~~~~~~~\n.. autoclass:: PNDMScheduler\n  :members:\n\n`RFlowScheduler`\n~~~~~~~~~~~~~~~~\n.. autoclass:: RFlowScheduler\n  :members:\n"
  },
  {
    "path": "docs/source/optimizers.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _optimizers:\n\nOptimizers\n==========\n.. currentmodule:: monai.optimizers\n\n`LearningRateFinder`\n--------------------\n.. autoclass:: LearningRateFinder\n    :members:\n\n`Novograd`\n----------\n.. autoclass:: Novograd\n    :members:\n\n`Generate parameter groups`\n---------------------------\n.. autofunction:: generate_param_groups\n\n`ExponentialLR`\n---------------\n.. autoclass:: ExponentialLR\n    :members:\n\n`LinearLR`\n----------\n.. autoclass:: LinearLR\n    :members:\n\n`WarmupCosineSchedule`\n----------------------\n.. autoclass:: WarmupCosineSchedule\n    :members:\n"
  },
  {
    "path": "docs/source/precision_accelerating.md",
    "content": "# Precision and Accelerating\n\nModern GPU architectures usually can use reduced precision tensor data or computational operations to save memory and increase throughput. However, in some cases, the reduced precision will cause numerical stability issues, and further cause reproducibility issues. Therefore, please ensure that you are using appropriate precision.\n\n<!-- Maybe adding Automatic Mixed Precision, Float16 or BFloat16 in the future-->\n\n## TensorFloat-32 (TF32)\n\n### Introduction\n\nNVIDIA introduced a new math mode TensorFloat-32 (TF32) for NVIDIA Ampere GPUs and above, see [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/), [TRAINING NEURAL NETWORKS\nWITH TENSOR CORES](https://nvlabs.github.io/eccv2020-mixed-precision-tutorial/files/dusan_stosic-training-neural-networks-with-tensor-cores.pdf), [CUDA 11](https://developer.nvidia.com/blog/cuda-11-features-revealed/) and [Ampere architecture](https://developer.nvidia.com/blog/nvidia-ampere-architecture-in-depth/).\n\nTF32 adopts 8 exponent bits, 10 bits of mantissa, and one sign bit.\n\n![Precision options used for AI training.](../images/precision_options.png)\n\n### Potential Impact\n\nAlthough NVIDIA has shown that TF32 mode can reach the same accuracy and convergence as float32 for most AI workloads, some users still find some significant effect on their applications, see [PyTorch and TensorFloat32](https://dev-discuss.pytorch.org/t/pytorch-and-tensorfloat32/504). 
Users who need high-precision matrix operations, such as traditional computer graphics operations and kernel methods, may be affected by TF32 precision.\n\nNote that all operations that use `cuda.matmul` may be affected\nby TF32 mode so the impact is very wide.\n\n### Settings\n\n[PyTorch TF32](https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) default value:\n```python\ntorch.backends.cuda.matmul.allow_tf32 = False # in PyTorch 1.12 and later.\ntorch.backends.cudnn.allow_tf32 = True\n```\nPlease note that there are environment variables that can override the flags above. For example, the environment variable `NVIDIA_TF32_OVERRIDE` mentioned in [Accelerating AI Training with NVIDIA TF32 Tensor Cores](https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/) and `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE` used by PyTorch. Thus, in some cases, the flags may be accidentally changed or overridden.\n\nIf you are using an [NGC PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch), the container includes a layer `ENV TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=1`.\nThe default value `torch.backends.cuda.matmul.allow_tf32` will be overridden to `True`.\nTo restore the upstream default value, please run `unset TORCH_ALLOW_TF32_CUBLAS_OVERRIDE` in the container,\nand use the PyTorch API `torch.set_float32_matmul_precision`, `torch.backends.cudnn.allow_tf32=False` accordingly.\n\n\nWe recommend that users print out these two flags for confirmation when unsure.\n\nIf you can confirm through experiments that your model has no accuracy or convergence issues in TF32 mode and you have NVIDIA Ampere GPUs or above, you can set the two flags above to `True` to speed up your model.\n"
  },
  {
    "path": "docs/source/transforms.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _transform_api:\n\nTransforms\n==========\n\nGeneric Interfaces\n------------------\n.. automodule:: monai.transforms\n.. currentmodule:: monai.transforms\n\n`Transform`\n^^^^^^^^^^^\n.. autoclass:: Transform\n    :members:\n    :special-members: __call__\n\n`MapTransform`\n^^^^^^^^^^^^^^\n.. autoclass:: MapTransform\n    :members:\n    :special-members: __call__\n\n`RandomizableTrait`\n^^^^^^^^^^^^^^^^^^^\n.. autoclass:: RandomizableTrait\n    :members:\n\n`LazyTrait`\n^^^^^^^^^^^\n.. autoclass:: LazyTrait\n    :members:\n\n`MultiSampleTrait`\n^^^^^^^^^^^^^^^^^^\n.. autoclass:: MultiSampleTrait\n    :members:\n\n`ReduceTrait`\n^^^^^^^^^^^^^^^^^^\n.. autoclass:: ReduceTrait\n    :members:\n\n`Randomizable`\n^^^^^^^^^^^^^^\n.. autoclass:: Randomizable\n    :members:\n\n`LazyTransform`\n^^^^^^^^^^^^^^^\n.. autoclass:: LazyTransform\n    :members:\n\n`RandomizableTransform`\n^^^^^^^^^^^^^^^^^^^^^^^\n.. autoclass:: RandomizableTransform\n    :members:\n\n`Compose`\n^^^^^^^^^\n.. autoclass:: Compose\n    :members:\n    :special-members: __call__\n\n`InvertibleTransform`\n^^^^^^^^^^^^^^^^^^^^^\n.. autoclass:: InvertibleTransform\n    :members:\n\n`TraceableTransform`\n^^^^^^^^^^^^^^^^^^^^\n.. autoclass:: TraceableTransform\n    :members:\n\n`BatchInverseTransform`\n^^^^^^^^^^^^^^^^^^^^^^^\n.. autoclass:: BatchInverseTransform\n    :members:\n\n`Decollated`\n^^^^^^^^^^^^\n.. autoclass:: Decollated\n    :members:\n\n`OneOf`\n^^^^^^^\n.. autoclass:: OneOf\n    :members:\n\n`RandomOrder`\n^^^^^^^^^^^^^\n.. autoclass:: RandomOrder\n    :members:\n\n`SomeOf`\n^^^^^^^^^^^^^\n.. autoclass:: SomeOf\n    :members:\n\nFunctionals\n-----------\n\nCrop and Pad (functional)\n^^^^^^^^^^^^^^^^^^^^^^^^^\n.. automodule:: monai.transforms.croppad.functional\n    :members:\n\nSpatial (functional)\n^^^^^^^^^^^^^^^^^^^^\n.. automodule:: monai.transforms.spatial.functional\n    :members:\n\n.. 
currentmodule:: monai.transforms\n\nVanilla Transforms\n------------------\n\nCrop and Pad\n^^^^^^^^^^^^\n\n`PadListDataCollate`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: PadListDataCollate\n    :members:\n    :special-members: __call__\n\n`Pad`\n\"\"\"\"\"\n.. autoclass:: Pad\n    :members:\n    :special-members: __call__\n\n`SpatialPad`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/SpatialPad.png\n    :alt: example of SpatialPad\n.. autoclass:: SpatialPad\n    :members:\n    :special-members: __call__\n\n`BorderPad`\n\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/BorderPad.png\n    :alt: example of BorderPad\n.. autoclass:: BorderPad\n    :members:\n    :special-members: __call__\n\n`DivisiblePad`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/DivisiblePad.png\n    :alt: example of DivisiblePad\n.. autoclass:: DivisiblePad\n    :members:\n    :special-members: __call__\n\n`Crop`\n\"\"\"\"\"\"\n.. autoclass:: Crop\n    :members:\n    :special-members: __call__\n\n`SpatialCrop`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/SpatialCrop.png\n    :alt: example of SpatialCrop\n.. autoclass:: SpatialCrop\n    :members:\n    :special-members: __call__\n\n`CenterSpatialCrop`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/CenterSpatialCrop.png\n    :alt: example of CenterSpatialCrop\n.. autoclass:: CenterSpatialCrop\n    :members:\n    :special-members: __call__\n\n`RandSpatialCrop`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandSpatialCrop.png\n    :alt: example of RandSpatialCrop\n.. 
autoclass:: RandSpatialCrop\n    :members:\n    :special-members: __call__\n\n`RandSpatialCropSamples`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandSpatialCropSamples.png\n    :alt: example of RandSpatialCropSamples\n.. autoclass:: RandSpatialCropSamples\n    :members:\n    :special-members: __call__\n\n`CropForeground`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/CropForeground.png\n    :alt: example of CropForeground\n.. autoclass:: CropForeground\n    :members:\n    :special-members: __call__\n\n`RandWeightedCrop`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandWeightedCrop.png\n    :alt: example of RandWeightedCrop\n.. autoclass:: RandWeightedCrop\n    :members:\n    :special-members: __call__\n\n`RandCropByPosNegLabel`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandCropByPosNegLabel.png\n    :alt: example of RandCropByPosNegLabel\n.. autoclass:: RandCropByPosNegLabel\n    :members:\n    :special-members: __call__\n\n`RandCropByLabelClasses`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandCropByLabelClasses.png\n    :alt: example of RandCropByLabelClasses\n.. autoclass:: RandCropByLabelClasses\n    :members:\n    :special-members: __call__\n\n`ResizeWithPadOrCrop`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ResizeWithPadOrCrop.png\n    :alt: example of ResizeWithPadOrCrop\n.. autoclass:: ResizeWithPadOrCrop\n    :members:\n    :special-members: __call__\n\n`BoundingRect`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
autoclass:: BoundingRect\n    :members:\n    :special-members: __call__\n\n`RandScaleCrop`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandScaleCrop.png\n    :alt: example of RandScaleCrop\n.. autoclass:: RandScaleCrop\n    :members:\n    :special-members: __call__\n\n`CenterScaleCrop`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/CenterScaleCrop.png\n    :alt: example of CenterScaleCrop\n.. autoclass:: CenterScaleCrop\n    :members:\n    :special-members: __call__\n\nIntensity\n^^^^^^^^^\n\n`RandGaussianNoise`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandGaussianNoise.png\n    :alt: example of RandGaussianNoise\n.. autoclass:: RandGaussianNoise\n    :members:\n    :special-members: __call__\n\n`ShiftIntensity`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ShiftIntensity.png\n    :alt: example of ShiftIntensity\n.. autoclass:: ShiftIntensity\n    :members:\n    :special-members: __call__\n\n`RandShiftIntensity`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandShiftIntensity.png\n    :alt: example of RandShiftIntensity\n.. autoclass:: RandShiftIntensity\n    :members:\n    :special-members: __call__\n\n`StdShiftIntensity`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/StdShiftIntensity.png\n    :alt: example of StdShiftIntensity\n.. autoclass:: StdShiftIntensity\n    :members:\n    :special-members: __call__\n\n`RandStdShiftIntensity`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandStdShiftIntensity.png\n    :alt: example of RandStdShiftIntensity\n.. autoclass:: RandStdShiftIntensity\n    :members:\n    :special-members: __call__\n\n`RandBiasField`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandBiasField.png\n    :alt: example of RandBiasField\n.. autoclass:: RandBiasField\n    :members:\n    :special-members: __call__\n\n`ScaleIntensity`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ScaleIntensity.png\n    :alt: example of ScaleIntensity\n.. autoclass:: ScaleIntensity\n    :members:\n    :special-members: __call__\n\n`ClipIntensityPercentiles`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ClipIntensityPercentiles\n    :members:\n    :special-members: __call__\n\n`RandScaleIntensity`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandScaleIntensity.png\n    :alt: example of RandScaleIntensity\n.. autoclass:: RandScaleIntensity\n    :members:\n    :special-members: __call__\n\n`ScaleIntensityFixedMean`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ScaleIntensityFixedMean\n    :members:\n    :special-members: __call__\n\n`RandScaleIntensityFixedMean`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandScaleIntensityFixedMean\n    :members:\n    :special-members: __call__\n\n`NormalizeIntensity`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/NormalizeIntensity.png\n    :alt: example of NormalizeIntensity\n.. autoclass:: NormalizeIntensity\n    :members:\n    :special-members: __call__\n\n`ThresholdIntensity`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ThresholdIntensity.png\n    :alt: example of ThresholdIntensity\n.. autoclass:: ThresholdIntensity\n    :members:\n    :special-members: __call__\n\n`ScaleIntensityRange`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ScaleIntensityRange.png\n    :alt: example of ScaleIntensityRange\n.. autoclass:: ScaleIntensityRange\n    :members:\n    :special-members: __call__\n\n`ScaleIntensityRangePercentiles`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ScaleIntensityRangePercentiles.png\n    :alt: example of ScaleIntensityRangePercentiles\n.. autoclass:: ScaleIntensityRangePercentiles\n    :members:\n    :special-members: __call__\n\n`AdjustContrast`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/AdjustContrast.png\n    :alt: example of AdjustContrast\n.. autoclass:: AdjustContrast\n    :members:\n    :special-members: __call__\n\n`RandAdjustContrast`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandAdjustContrast.png\n    :alt: example of RandAdjustContrast\n.. autoclass:: RandAdjustContrast\n    :members:\n    :special-members: __call__\n\n`MaskIntensity`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/MaskIntensity.png\n    :alt: example of MaskIntensity\n.. autoclass:: MaskIntensity\n    :members:\n    :special-members: __call__\n\n`SavitzkyGolaySmooth`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/SavitzkyGolaySmooth.png\n    :alt: example of SavitzkyGolaySmooth\n.. 
autoclass:: SavitzkyGolaySmooth\n    :members:\n    :special-members: __call__\n\n`MedianSmooth`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/MedianSmooth.png\n    :alt: example of MedianSmooth\n.. autoclass:: MedianSmooth\n    :members:\n    :special-members: __call__\n\n`GaussianSmooth`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/GaussianSmooth.png\n    :alt: example of GaussianSmooth\n.. autoclass:: GaussianSmooth\n    :members:\n    :special-members: __call__\n\n`RandGaussianSmooth`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandGaussianSmooth.png\n    :alt: example of RandGaussianSmooth\n.. autoclass:: RandGaussianSmooth\n    :members:\n    :special-members: __call__\n\n`GaussianSharpen`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/GaussianSharpen.png\n    :alt: example of GaussianSharpen\n.. autoclass:: GaussianSharpen\n    :members:\n    :special-members: __call__\n\n`RandGaussianSharpen`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandGaussianSharpen.png\n    :alt: example of RandGaussianSharpen\n.. autoclass:: RandGaussianSharpen\n    :members:\n    :special-members: __call__\n\n`RandHistogramShift`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandHistogramShift.png\n    :alt: example of RandHistogramShift\n.. autoclass:: RandHistogramShift\n    :members:\n    :special-members: __call__\n\n`DetectEnvelope`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: DetectEnvelope\n    :members:\n    :special-members: __call__\n\n`GibbsNoise`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/GibbsNoise.png\n    :alt: example of GibbsNoise\n.. autoclass:: GibbsNoise\n    :members:\n    :special-members: __call__\n\n`RandGibbsNoise`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandGibbsNoise.png\n    :alt: example of RandGibbsNoise\n.. autoclass:: RandGibbsNoise\n    :members:\n    :special-members: __call__\n\n`KSpaceSpikeNoise`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/KSpaceSpikeNoise.png\n    :alt: example of KSpaceSpikeNoise\n.. autoclass:: KSpaceSpikeNoise\n    :members:\n    :special-members: __call__\n\n`RandKSpaceSpikeNoise`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandKSpaceSpikeNoise.png\n    :alt: example of RandKSpaceSpikeNoise\n.. autoclass:: RandKSpaceSpikeNoise\n    :members:\n    :special-members: __call__\n\n`RandRicianNoise`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandRicianNoise.png\n    :alt: example of RandRicianNoise\n.. autoclass:: RandRicianNoise\n    :members:\n    :special-members: __call__\n\n`RandCoarseTransform`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandCoarseTransform\n    :members:\n    :special-members: __call__\n\n`RandCoarseDropout`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandCoarseDropout.png\n    :alt: example of RandCoarseDropout\n.. autoclass:: RandCoarseDropout\n    :members:\n    :special-members: __call__\n\n`RandCoarseShuffle`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandCoarseShuffle.png\n    :alt: example of RandCoarseShuffle\n.. autoclass:: RandCoarseShuffle\n    :members:\n    :special-members: __call__\n\n`HistogramNormalize`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/HistogramNormalize.png\n    :alt: example of HistogramNormalize\n.. autoclass:: HistogramNormalize\n    :members:\n    :special-members: __call__\n\n\n`ForegroundMask`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ForegroundMask.png\n    :alt: example of ForegroundMask\n.. autoclass:: ForegroundMask\n    :members:\n    :special-members: __call__\n\n`ComputeHoVerMaps`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ComputeHoVerMaps\n    :members:\n    :special-members: __call__\n\n\nIO\n^^\n\n`LoadImage`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: LoadImage\n    :members:\n    :special-members: __call__\n\n`SaveImage`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SaveImage\n    :members:\n    :special-members: __call__\n\n`WriteFileMapping`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: WriteFileMapping\n    :members:\n    :special-members: __call__\n\n\nNVIDIA Tool Extension (NVTX)\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n`RangePush`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RangePush\n\n`RandRangePush`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandRangePush\n\n`RangePop`\n\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RangePop\n\n`RandRangePop`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandRangePop\n\n`Mark`\n\"\"\"\"\"\"\n.. autoclass:: Mark\n\n`RandMark`\n\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandMark\n\n\nPost-processing\n^^^^^^^^^^^^^^^\n\n`Activations`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: Activations\n    :members:\n    :special-members: __call__\n\n`AsDiscrete`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/AsDiscrete.png\n    :alt: example of AsDiscrete\n.. autoclass:: AsDiscrete\n    :members:\n    :special-members: __call__\n\n`KeepLargestConnectedComponent`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/KeepLargestConnectedComponent.png\n    :alt: example of KeepLargestConnectedComponent\n.. autoclass:: KeepLargestConnectedComponent\n    :members:\n    :special-members: __call__\n\n`DistanceTransformEDT`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: DistanceTransformEDT\n    :members:\n    :special-members: __call__\n\n`RemoveSmallObjects`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RemoveSmallObjects.png\n    :alt: example of RemoveSmallObjects\n.. autoclass:: RemoveSmallObjects\n    :members:\n    :special-members: __call__\n\n`LabelFilter`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/LabelFilter.png\n    :alt: example of LabelFilter\n.. autoclass:: LabelFilter\n    :members:\n    :special-members: __call__\n\n`FillHoles`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: FillHoles\n    :members:\n    :special-members: __call__\n\n`LabelToContour`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/LabelToContour.png\n    :alt: example of LabelToContour\n.. autoclass:: LabelToContour\n    :members:\n    :special-members: __call__\n\n`MeanEnsemble`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: MeanEnsemble\n    :members:\n    :special-members: __call__\n\n`ProbNMS`\n\"\"\"\"\"\"\"\"\"\n.. autoclass:: ProbNMS\n  :members:\n\n`SobelGradients`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
autoclass:: SobelGradients\n  :members:\n  :special-members: __call__\n\n`VoteEnsemble`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: VoteEnsemble\n    :members:\n    :special-members: __call__\n\n`Invert`\n\"\"\"\"\"\"\"\"\"\n.. autoclass:: Invert\n    :members:\n    :special-members: __call__\n\nRegularization\n^^^^^^^^^^^^^^\n\n`CutMix`\n\"\"\"\"\"\"\"\"\n.. autoclass:: CutMix\n    :members:\n    :special-members: __call__\n\n`CutOut`\n\"\"\"\"\"\"\"\"\n.. autoclass:: CutOut\n    :members:\n    :special-members: __call__\n\n`MixUp`\n\"\"\"\"\"\"\"\n.. autoclass:: MixUp\n    :members:\n    :special-members: __call__\n\nSignal\n^^^^^^^\n\n`SignalRandDrop`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SignalRandDrop\n    :members:\n    :special-members: __call__\n\n`SignalRandScale`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SignalRandScale\n    :members:\n    :special-members: __call__\n\n`SignalRandShift`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SignalRandShift\n    :members:\n    :special-members: __call__\n\n`SignalRandAddSine`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SignalRandAddSine\n    :members:\n    :special-members: __call__\n\n`SignalRandAddSquarePulse`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SignalRandAddSquarePulse\n    :members:\n    :special-members: __call__\n\n`SignalRandAddGaussianNoise`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SignalRandAddGaussianNoise\n    :members:\n    :special-members: __call__\n\n`SignalRandAddSinePartial`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SignalRandAddSinePartial\n    :members:\n    :special-members: __call__\n\n`SignalRandAddSquarePulsePartial`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SignalRandAddSquarePulsePartial\n    :members:\n    :special-members: __call__\n\n`SignalFillEmpty`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
autoclass:: SignalFillEmpty\n    :members:\n    :special-members: __call__\n\n`SignalRemoveFrequency`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SignalRemoveFrequency\n    :members:\n    :special-members: __call__\n\n`SignalContinuousWavelet`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SignalContinuousWavelet\n    :members:\n    :special-members: __call__\n\nSpatial\n^^^^^^^\n\n`SpatialResample`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SpatialResample\n    :members:\n    :special-members: __call__\n\n`ResampleToMatch`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ResampleToMatch\n    :members:\n    :special-members: __call__\n\n`Spacing`\n\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Spacing.png\n    :alt: example of Spacing\n.. autoclass:: Spacing\n    :members:\n    :special-members: __call__\n\n`Orientation`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Orientation.png\n    :alt: example of Orientation\n.. autoclass:: Orientation\n    :members:\n    :special-members: __call__\n\n`RandRotate`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandRotate.png\n    :alt: example of RandRotate\n.. autoclass:: RandRotate\n    :members:\n    :special-members: __call__\n\n`RandFlip`\n\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandFlip.png\n    :alt: example of RandFlip\n.. autoclass:: RandFlip\n    :members:\n    :special-members: __call__\n\n`RandAxisFlip`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandAxisFlip.png\n    :alt: example of RandAxisFlip\n.. autoclass:: RandAxisFlip\n    :members:\n    :special-members: __call__\n\n`RandZoom`\n\"\"\"\"\"\"\"\"\"\"\n.. 
image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandZoom.png\n    :alt: example of RandZoom\n.. autoclass:: RandZoom\n    :members:\n    :special-members: __call__\n\n`Affine`\n\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Affine.png\n    :alt: example of Affine\n.. autoclass:: Affine\n    :members:\n    :special-members: __call__\n\n`Resample`\n\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: Resample\n    :members:\n    :special-members: __call__\n\n`RandAffine`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandAffine.png\n    :alt: example of RandAffine\n.. autoclass:: RandAffine\n    :members:\n    :special-members: __call__\n\n`RandDeformGrid`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandDeformGrid\n    :members:\n    :special-members: __call__\n\n`AffineGrid`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: AffineGrid\n    :members:\n    :special-members: __call__\n\n`RandAffineGrid`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandAffineGrid\n    :members:\n    :special-members: __call__\n\n`GridDistortion`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/GridDistortion.png\n    :alt: example of GridDistortion\n.. autoclass:: GridDistortion\n    :members:\n    :special-members: __call__\n\n`RandGridDistortion`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandGridDistortion.png\n    :alt: example of RandGridDistortion\n.. autoclass:: RandGridDistortion\n    :members:\n    :special-members: __call__\n\n`Rand2DElastic`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Rand2DElastic.png\n    :alt: example of Rand2DElastic\n.. 
autoclass:: Rand2DElastic\n    :members:\n    :special-members: __call__\n\n`Rand3DElastic`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Rand3DElastic.png\n    :alt: example of Rand3DElastic\n.. autoclass:: Rand3DElastic\n    :members:\n    :special-members: __call__\n\n`Rotate90`\n\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Rotate90.png\n    :alt: example of Rotate90\n.. autoclass:: Rotate90\n    :members:\n    :special-members: __call__\n\n`RandRotate90`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandRotate90.png\n    :alt: example of RandRotate90\n.. autoclass:: RandRotate90\n    :members:\n    :special-members: __call__\n\n`Flip`\n\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Flip.png\n    :alt: example of Flip\n.. autoclass:: Flip\n    :members:\n    :special-members: __call__\n\n`Resize`\n\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Resize.png\n    :alt: example of Resize\n.. autoclass:: Resize\n    :members:\n    :special-members: __call__\n\n`Rotate`\n\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Rotate.png\n    :alt: example of Rotate\n.. autoclass:: Rotate\n    :members:\n    :special-members: __call__\n\n`Zoom`\n\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Zoom.png\n    :alt: example of Zoom\n.. autoclass:: Zoom\n    :members:\n    :special-members: __call__\n\n`GridPatch`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: GridPatch\n    :members:\n    :special-members: __call__\n\n`RandGridPatch`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
autoclass:: RandGridPatch\n    :members:\n    :special-members: __call__\n\n`GridSplit`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: GridSplit\n    :members:\n    :special-members: __call__\n\n`RandSimulateLowResolution`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandSimulateLowResolution\n    :members:\n    :special-members: __call__\n\n`ConvertBoxToPoints`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ConvertBoxToPoints\n    :members:\n    :special-members: __call__\n\n`ConvertPointsToBoxes`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ConvertPointsToBoxes\n    :members:\n    :special-members: __call__\n\n\nSmooth Field\n^^^^^^^^^^^^\n\n`RandSmoothFieldAdjustContrast`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandSmoothFieldAdjustContrast.png\n    :alt: example of RandSmoothFieldAdjustContrast\n.. autoclass:: RandSmoothFieldAdjustContrast\n    :members:\n    :special-members: __call__\n\n`RandSmoothFieldAdjustIntensity`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandSmoothFieldAdjustIntensity.png\n    :alt: example of RandSmoothFieldAdjustIntensity\n.. autoclass:: RandSmoothFieldAdjustIntensity\n    :members:\n    :special-members: __call__\n\n`RandSmoothDeform`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandSmoothDeform.png\n    :alt: example of RandSmoothDeform\n.. autoclass:: RandSmoothDeform\n    :members:\n    :special-members: __call__\n\n\nMRI Transforms\n^^^^^^^^^^^^^^\n\n`Kspace under-sampling`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: monai.apps.reconstruction.transforms.array.KspaceMask\n    :members:\n    :special-members: __call__\n\n.. 
autoclass:: monai.apps.reconstruction.transforms.array.RandomKspaceMask\n    :special-members: __call__\n\n.. autoclass:: monai.apps.reconstruction.transforms.array.EquispacedKspaceMask\n    :special-members: __call__\n\n\nLazy\n^^^^\n\n`ApplyPending`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n.. autoclass:: ApplyPending\n    :members:\n    :special-members: __call__\n\n\nUtility\n^^^^^^^\n\n`Identity`\n\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: Identity\n    :members:\n    :special-members: __call__\n\n`AsChannelLast`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: AsChannelLast\n    :members:\n    :special-members: __call__\n\n`EnsureChannelFirst`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: EnsureChannelFirst\n    :members:\n    :special-members: __call__\n\n`RepeatChannel`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RepeatChannel\n    :members:\n    :special-members: __call__\n\n`SplitDim`\n\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SplitDim\n    :members:\n    :special-members: __call__\n\n`CastToType`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: CastToType\n    :members:\n    :special-members: __call__\n\n`ToTensor`\n\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ToTensor\n    :members:\n    :special-members: __call__\n\n`ToNumpy`\n\"\"\"\"\"\"\"\"\"\n.. autoclass:: ToNumpy\n    :members:\n    :special-members: __call__\n\n`ToCupy`\n\"\"\"\"\"\"\"\"\n.. autoclass:: ToCupy\n    :members:\n    :special-members: __call__\n\n`Transpose`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: Transpose\n    :members:\n    :special-members: __call__\n\n`SqueezeDim`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SqueezeDim\n    :members:\n    :special-members: __call__\n\n`DataStats`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: DataStats\n    :members:\n    :special-members: __call__\n\n`SimulateDelay`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SimulateDelay\n    :members:\n    :special-members: __call__\n\n\n`Lambda`\n\"\"\"\"\"\"\"\"\n.. 
autoclass:: Lambda\n    :members:\n    :special-members: __call__\n\n`RandLambda`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandLambda\n    :members:\n    :special-members: __call__\n\n`RemoveRepeatedChannel`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RemoveRepeatedChannel\n    :members:\n    :special-members: __call__\n\n`LabelToMask`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: LabelToMask\n    :members:\n    :special-members: __call__\n\n`FgBgToIndices`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: FgBgToIndices\n    :members:\n    :special-members: __call__\n\n`ClassesToIndices`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ClassesToIndices\n    :members:\n    :special-members: __call__\n\n`ConvertToMultiChannelBasedOnBratsClasses`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ConvertToMultiChannelBasedOnBratsClasses\n    :members:\n    :special-members: __call__\n\n`AddExtremePointsChannel`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: AddExtremePointsChannel\n    :members:\n    :special-members: __call__\n\n`TorchVision`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: TorchVision\n    :members:\n    :special-members: __call__\n\n`TorchIO`\n\"\"\"\"\"\"\"\"\"\n.. autoclass:: TorchIO\n    :members:\n    :special-members: __call__\n\n`RandTorchIO`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandTorchIO\n    :members:\n    :special-members: __call__\n\n`MapLabelValue`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: MapLabelValue\n    :members:\n    :special-members: __call__\n\n`EnsureType`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: EnsureType\n    :members:\n    :special-members: __call__\n\n`IntensityStats`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: IntensityStats\n    :members:\n    :special-members: __call__\n\n`ToDevice`\n\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ToDevice\n     :members:\n     :special-members: __call__\n\n`CuCIM`\n\"\"\"\"\"\"\"\n.. 
autoclass:: CuCIM\n    :members:\n    :special-members: __call__\n\n`RandCuCIM`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandCuCIM\n    :members:\n    :special-members: __call__\n\n`AddCoordinateChannels`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: AddCoordinateChannels\n    :members:\n    :special-members: __call__\n\n`ImageFilter`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ImageFilter\n    :members:\n    :special-members: __call__\n\n`RandImageFilter`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandImageFilter\n    :members:\n    :special-members: __call__\n\n`ApplyTransformToPoints`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ApplyTransformToPoints\n    :members:\n    :special-members: __call__\n\n`FlattenSequence`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: FlattenSequence\n    :members:\n    :special-members: __call__\n\nDictionary Transforms\n---------------------\n\nCrop and Pad (Dict)\n^^^^^^^^^^^^^^^^^^^\n\n`Padd`\n\"\"\"\"\"\"\n.. autoclass:: Padd\n    :members:\n    :special-members: __call__\n\n`SpatialPadd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/SpatialPadd.png\n    :alt: example of SpatialPadd\n.. autoclass:: SpatialPadd\n    :members:\n    :special-members: __call__\n\n`BorderPadd`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/BorderPadd.png\n    :alt: example of BorderPadd\n.. autoclass:: BorderPadd\n    :members:\n    :special-members: __call__\n\n`DivisiblePadd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/DivisiblePadd.png\n    :alt: example of DivisiblePadd\n.. autoclass:: DivisiblePadd\n    :members:\n    :special-members: __call__\n\n`Cropd`\n\"\"\"\"\"\"\"\n.. 
autoclass:: Cropd\n    :members:\n    :special-members: __call__\n\n`RandCropd`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandCropd\n    :members:\n    :special-members: __call__\n\n`SpatialCropd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/SpatialCropd.png\n    :alt: example of SpatialCropd\n.. autoclass:: SpatialCropd\n    :members:\n    :special-members: __call__\n\n`CenterSpatialCropd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/CenterSpatialCropd.png\n    :alt: example of CenterSpatialCropd\n.. autoclass:: CenterSpatialCropd\n    :members:\n    :special-members: __call__\n\n`RandSpatialCropd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandSpatialCropd.png\n    :alt: example of RandSpatialCropd\n.. autoclass:: RandSpatialCropd\n    :members:\n    :special-members: __call__\n\n`RandSpatialCropSamplesd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandSpatialCropSamplesd.png\n    :alt: example of RandSpatialCropSamplesd\n.. autoclass:: RandSpatialCropSamplesd\n    :members:\n    :special-members: __call__\n\n`CropForegroundd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/CropForegroundd.png\n    :alt: example of CropForegroundd\n.. autoclass:: CropForegroundd\n    :members:\n    :special-members: __call__\n\n`RandWeightedCropd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandWeightedCropd.png\n    :alt: example of RandWeightedCropd\n.. 
autoclass:: RandWeightedCropd\n    :members:\n    :special-members: __call__\n\n`RandCropByPosNegLabeld`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandCropByPosNegLabeld.png\n    :alt: example of RandCropByPosNegLabeld\n.. autoclass:: RandCropByPosNegLabeld\n    :members:\n    :special-members: __call__\n\n`RandCropByLabelClassesd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandCropByLabelClassesd.png\n    :alt: example of RandCropByLabelClassesd\n.. autoclass:: RandCropByLabelClassesd\n    :members:\n    :special-members: __call__\n\n`ResizeWithPadOrCropd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ResizeWithPadOrCropd.png\n    :alt: example of ResizeWithPadOrCropd\n.. autoclass:: ResizeWithPadOrCropd\n    :members:\n    :special-members: __call__\n\n`BoundingRectd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: BoundingRectd\n    :members:\n    :special-members: __call__\n\n`RandScaleCropd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandScaleCropd.png\n    :alt: example of RandScaleCropd\n.. autoclass:: RandScaleCropd\n    :members:\n    :special-members: __call__\n\n`CenterScaleCropd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/CenterScaleCropd.png\n    :alt: example of CenterScaleCropd\n.. autoclass:: CenterScaleCropd\n    :members:\n    :special-members: __call__\n\nIntensity (Dict)\n^^^^^^^^^^^^^^^^\n\n`RandGaussianNoised`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandGaussianNoised.png\n    :alt: example of RandGaussianNoised\n.. 
autoclass:: RandGaussianNoised\n    :members:\n    :special-members: __call__\n\n`ShiftIntensityd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ShiftIntensityd.png\n    :alt: example of ShiftIntensityd\n.. autoclass:: ShiftIntensityd\n    :members:\n    :special-members: __call__\n\n`RandShiftIntensityd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandShiftIntensityd.png\n    :alt: example of RandShiftIntensityd\n.. autoclass:: RandShiftIntensityd\n    :members:\n    :special-members: __call__\n\n`StdShiftIntensityd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/StdShiftIntensityd.png\n    :alt: example of StdShiftIntensityd\n.. autoclass:: StdShiftIntensityd\n    :members:\n    :special-members: __call__\n\n`RandStdShiftIntensityd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandStdShiftIntensityd.png\n    :alt: example of RandStdShiftIntensityd\n.. autoclass:: RandStdShiftIntensityd\n    :members:\n    :special-members: __call__\n\n`RandBiasFieldd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandBiasFieldd.png\n    :alt: example of RandBiasFieldd\n.. autoclass:: RandBiasFieldd\n    :members:\n    :special-members: __call__\n\n`ScaleIntensityd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ScaleIntensityd.png\n    :alt: example of ScaleIntensityd\n.. autoclass:: ScaleIntensityd\n    :members:\n    :special-members: __call__\n\n`ClipIntensityPercentilesd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
autoclass:: ClipIntensityPercentilesd\n    :members:\n    :special-members: __call__\n\n`RandScaleIntensityd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandScaleIntensityd.png\n    :alt: example of RandScaleIntensityd\n.. autoclass:: RandScaleIntensityd\n    :members:\n    :special-members: __call__\n\n`RandScaleIntensityFixedMeand`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandScaleIntensityFixedMeand\n    :members:\n    :special-members: __call__\n\n`NormalizeIntensityd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/NormalizeIntensityd.png\n    :alt: example of NormalizeIntensityd\n.. autoclass:: NormalizeIntensityd\n    :members:\n    :special-members: __call__\n\n`ThresholdIntensityd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ThresholdIntensityd.png\n    :alt: example of ThresholdIntensityd\n.. autoclass:: ThresholdIntensityd\n    :members:\n    :special-members: __call__\n\n`ScaleIntensityRanged`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ScaleIntensityRanged.png\n    :alt: example of ScaleIntensityRanged\n.. autoclass:: ScaleIntensityRanged\n    :members:\n    :special-members: __call__\n\n`GibbsNoised`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/GibbsNoised.png\n    :alt: example of GibbsNoised\n.. autoclass:: GibbsNoised\n    :members:\n    :special-members: __call__\n\n`RandGibbsNoised`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandGibbsNoised.png\n    :alt: example of RandGibbsNoised\n.. 
autoclass:: RandGibbsNoised\n    :members:\n    :special-members: __call__\n\n`KSpaceSpikeNoised`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/KSpaceSpikeNoised.png\n    :alt: example of KSpaceSpikeNoised\n.. autoclass:: KSpaceSpikeNoised\n    :members:\n    :special-members: __call__\n\n`RandKSpaceSpikeNoised`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandKSpaceSpikeNoised.png\n    :alt: example of RandKSpaceSpikeNoised\n.. autoclass:: RandKSpaceSpikeNoised\n    :members:\n    :special-members: __call__\n\n`RandRicianNoised`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandRicianNoised.png\n    :alt: example of RandRicianNoised\n.. autoclass:: RandRicianNoised\n    :members:\n    :special-members: __call__\n\n`ScaleIntensityRangePercentilesd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ScaleIntensityRangePercentilesd.png\n    :alt: example of ScaleIntensityRangePercentilesd\n.. autoclass:: ScaleIntensityRangePercentilesd\n    :members:\n    :special-members: __call__\n\n`AdjustContrastd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/AdjustContrastd.png\n    :alt: example of AdjustContrastd\n.. autoclass:: AdjustContrastd\n    :members:\n    :special-members: __call__\n\n`RandAdjustContrastd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandAdjustContrastd.png\n    :alt: example of RandAdjustContrastd\n.. autoclass:: RandAdjustContrastd\n    :members:\n    :special-members: __call__\n\n`MaskIntensityd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/MaskIntensityd.png\n    :alt: example of MaskIntensityd\n.. autoclass:: MaskIntensityd\n    :members:\n    :special-members: __call__\n\n`SavitzkyGolaySmoothd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/SavitzkyGolaySmoothd.png\n    :alt: example of SavitzkyGolaySmoothd\n.. autoclass:: SavitzkyGolaySmoothd\n    :members:\n    :special-members: __call__\n\n`MedianSmoothd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/MedianSmoothd.png\n    :alt: example of MedianSmoothd\n.. autoclass:: MedianSmoothd\n    :members:\n    :special-members: __call__\n\n`GaussianSmoothd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/GaussianSmoothd.png\n    :alt: example of GaussianSmoothd\n.. autoclass:: GaussianSmoothd\n    :members:\n    :special-members: __call__\n\n`RandGaussianSmoothd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandGaussianSmoothd.png\n    :alt: example of RandGaussianSmoothd\n.. autoclass:: RandGaussianSmoothd\n    :members:\n    :special-members: __call__\n\n`GaussianSharpend`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/GaussianSharpend.png\n    :alt: example of GaussianSharpend\n.. autoclass:: GaussianSharpend\n    :members:\n    :special-members: __call__\n\n`RandGaussianSharpend`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandGaussianSharpend.png\n    :alt: example of RandGaussianSharpend\n.. 
autoclass:: RandGaussianSharpend\n    :members:\n    :special-members: __call__\n\n`RandHistogramShiftd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandHistogramShiftd.png\n    :alt: example of RandHistogramShiftd\n.. autoclass:: RandHistogramShiftd\n    :members:\n    :special-members: __call__\n\n`RandCoarseDropoutd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandCoarseDropoutd.png\n    :alt: example of RandCoarseDropoutd\n.. autoclass:: RandCoarseDropoutd\n    :members:\n    :special-members: __call__\n\n`RandCoarseShuffled`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandCoarseShuffled.png\n    :alt: example of RandCoarseShuffled\n.. autoclass:: RandCoarseShuffled\n    :members:\n    :special-members: __call__\n\n`HistogramNormalized`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/HistogramNormalized.png\n    :alt: example of HistogramNormalized\n.. autoclass:: HistogramNormalized\n    :members:\n    :special-members: __call__\n\n`ForegroundMaskd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/ForegroundMaskd.png\n    :alt: example of ForegroundMaskd\n.. autoclass:: ForegroundMaskd\n    :members:\n    :special-members: __call__\n\n`ComputeHoVerMapsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ComputeHoVerMapsd\n    :members:\n    :special-members: __call__\n\nIO (Dict)\n^^^^^^^^^\n\n`LoadImaged`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: LoadImaged\n    :members:\n    :special-members: __call__\n\n`SaveImaged`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
autoclass:: SaveImaged\n    :members:\n    :special-members: __call__\n\n`WriteFileMappingd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: WriteFileMappingd\n    :members:\n    :special-members: __call__\n\nPost-processing (Dict)\n^^^^^^^^^^^^^^^^^^^^^^\n\n`Activationsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: Activationsd\n    :members:\n    :special-members: __call__\n\n`AsDiscreted`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/AsDiscreted.png\n    :alt: example of AsDiscreted\n.. autoclass:: AsDiscreted\n    :members:\n    :special-members: __call__\n\n`KeepLargestConnectedComponentd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/KeepLargestConnectedComponentd.png\n    :alt: example of KeepLargestConnectedComponentd\n.. autoclass:: KeepLargestConnectedComponentd\n    :members:\n    :special-members: __call__\n\n`DistanceTransformEDTd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: DistanceTransformEDTd\n    :members:\n    :special-members: __call__\n\n`RemoveSmallObjectsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RemoveSmallObjectsd.png\n    :alt: example of RemoveSmallObjectsd\n.. autoclass:: RemoveSmallObjectsd\n    :members:\n    :special-members: __call__\n\n`LabelFilterd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/LabelFilterd.png\n    :alt: example of LabelFilterd\n.. autoclass:: LabelFilterd\n    :members:\n    :special-members: __call__\n\n`FillHolesd`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: FillHolesd\n    :members:\n    :special-members: __call__\n\n`LabelToContourd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/LabelToContourd.png\n    :alt: example of LabelToContourd\n.. autoclass:: LabelToContourd\n    :members:\n    :special-members: __call__\n\n`Ensembled`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: Ensembled\n    :members:\n    :special-members: __call__\n\n`MeanEnsembled`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: MeanEnsembled\n    :members:\n    :special-members: __call__\n\n`VoteEnsembled`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: VoteEnsembled\n    :members:\n    :special-members: __call__\n\n`Invertd`\n\"\"\"\"\"\"\"\"\"\n.. autoclass:: Invertd\n    :members:\n    :special-members: __call__\n\n`SaveClassificationd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SaveClassificationd\n    :members:\n    :special-members: __call__\n\n`ProbNMSd`\n\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ProbNMSd\n  :members:\n  :special-members: __call__\n\n\n`SobelGradientsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SobelGradientsd\n  :members:\n  :special-members: __call__\n\nRegularization (Dict)\n^^^^^^^^^^^^^^^^^^^^^\n\n`CutMixd`\n\"\"\"\"\"\"\"\"\"\n.. autoclass:: CutMixd\n    :members:\n    :special-members: __call__\n\n`CutOutd`\n\"\"\"\"\"\"\"\"\"\n.. autoclass:: CutOutd\n    :members:\n    :special-members: __call__\n\n`MixUpd`\n\"\"\"\"\"\"\"\"\n.. autoclass:: MixUpd\n    :members:\n    :special-members: __call__\n\nSignal (Dict)\n^^^^^^^^^^^^^\n\n`SignalFillEmptyd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SignalFillEmptyd\n    :members:\n    :special-members: __call__\n\n\nSpatial (Dict)\n^^^^^^^^^^^^^^\n\n`SpatialResampled`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SpatialResampled\n    :members:\n    :special-members: __call__\n\n`ResampleToMatchd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ResampleToMatchd\n    :members:\n    :special-members: __call__\n\n`Spacingd`\n\"\"\"\"\"\"\"\"\"\"\n.. 
image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Spacingd.png\n    :alt: example of Spacingd\n.. autoclass:: Spacingd\n    :members:\n    :special-members: __call__\n\n`Orientationd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Orientationd.png\n    :alt: example of Orientationd\n.. autoclass:: Orientationd\n    :members:\n    :special-members: __call__\n\n`Flipd`\n\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Flipd.png\n    :alt: example of Flipd\n.. autoclass:: Flipd\n    :members:\n    :special-members: __call__\n\n`RandFlipd`\n\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandFlipd.png\n    :alt: example of RandFlipd\n.. autoclass:: RandFlipd\n    :members:\n    :special-members: __call__\n\n`RandAxisFlipd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandAxisFlipd.png\n    :alt: example of RandAxisFlipd\n.. autoclass:: RandAxisFlipd\n    :members:\n    :special-members: __call__\n\n`Rotated`\n\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Rotated.png\n    :alt: example of Rotated\n.. autoclass:: Rotated\n    :members:\n    :special-members: __call__\n\n`RandRotated`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandRotated.png\n    :alt: example of RandRotated\n.. autoclass:: RandRotated\n    :members:\n    :special-members: __call__\n\n`Zoomd`\n\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Zoomd.png\n    :alt: example of Zoomd\n.. autoclass:: Zoomd\n    :members:\n    :special-members: __call__\n\n`RandZoomd`\n\"\"\"\"\"\"\"\"\"\"\"\n.. 
image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandZoomd.png\n    :alt: example of RandZoomd\n.. autoclass:: RandZoomd\n    :members:\n    :special-members: __call__\n\n`GridPatchd`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: GridPatchd\n    :members:\n    :special-members: __call__\n\n`RandGridPatchd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandGridPatchd\n    :members:\n    :special-members: __call__\n\n`GridSplitd`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: GridSplitd\n    :members:\n    :special-members: __call__\n\n\n`RandRotate90d`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandRotate90d.png\n    :alt: example of RandRotate90d\n.. autoclass:: RandRotate90d\n    :members:\n    :special-members: __call__\n\n`Rotate90d`\n\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Rotate90d.png\n    :alt: example of Rotate90d\n.. autoclass:: Rotate90d\n    :members:\n    :special-members: __call__\n\n`Resized`\n\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Resized.png\n    :alt: example of Resized\n.. autoclass:: Resized\n    :members:\n    :special-members: __call__\n\n`Affined`\n\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Affined.png\n    :alt: example of Affined\n.. autoclass:: Affined\n    :members:\n    :special-members: __call__\n\n`RandAffined`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandAffined.png\n    :alt: example of RandAffined\n.. autoclass:: RandAffined\n    :members:\n    :special-members: __call__\n\n`Rand2DElasticd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Rand2DElasticd.png\n    :alt: example of Rand2DElasticd\n.. 
autoclass:: Rand2DElasticd\n    :members:\n    :special-members: __call__\n\n`Rand3DElasticd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/Rand3DElasticd.png\n    :alt: example of Rand3DElasticd\n.. autoclass:: Rand3DElasticd\n    :members:\n    :special-members: __call__\n\n`GridDistortiond`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/GridDistortiond.png\n    :alt: example of GridDistortiond\n.. autoclass:: GridDistortiond\n    :members:\n    :special-members: __call__\n\n`RandGridDistortiond`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandGridDistortiond.png\n    :alt: example of RandGridDistortiond\n.. autoclass:: RandGridDistortiond\n    :members:\n    :special-members: __call__\n\n`RandSimulateLowResolutiond`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandSimulateLowResolutiond\n    :members:\n    :special-members: __call__\n\n`ConvertBoxToPointsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ConvertBoxToPointsd\n    :members:\n    :special-members: __call__\n\n`ConvertPointsToBoxesd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ConvertPointsToBoxesd\n    :members:\n    :special-members: __call__\n\n\nSmooth Field (Dict)\n^^^^^^^^^^^^^^^^^^^\n\n`RandSmoothFieldAdjustContrastd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandSmoothFieldAdjustContrastd.png\n    :alt: example of RandSmoothFieldAdjustContrastd\n.. autoclass:: RandSmoothFieldAdjustContrastd\n    :members:\n    :special-members: __call__\n\n`RandSmoothFieldAdjustIntensityd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandSmoothFieldAdjustIntensityd.png\n    :alt: example of RandSmoothFieldAdjustIntensityd\n.. autoclass:: RandSmoothFieldAdjustIntensityd\n    :members:\n    :special-members: __call__\n\n`RandSmoothDeformd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandSmoothDeformd.png\n    :alt: example of RandSmoothDeformd\n.. autoclass:: RandSmoothDeformd\n    :members:\n    :special-members: __call__\n\n\n`MRI transforms (Dict)`\n^^^^^^^^^^^^^^^^^^^^^^^\n\n`Kspace under-sampling (Dict)`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: monai.apps.reconstruction.transforms.dictionary.RandomKspaceMaskd\n    :special-members: __call__\n\n.. autoclass:: monai.apps.reconstruction.transforms.dictionary.EquispacedKspaceMaskd\n    :special-members: __call__\n\n`ExtractDataKeyFromMetaKeyd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: monai.apps.reconstruction.transforms.dictionary.ExtractDataKeyFromMetaKeyd\n    :special-members: __call__\n\n`ReferenceBasedSpatialCropd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: monai.apps.reconstruction.transforms.dictionary.ReferenceBasedSpatialCropd\n    :special-members: __call__\n\n`ReferenceBasedNormalizeIntensityd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: monai.apps.reconstruction.transforms.dictionary.ReferenceBasedNormalizeIntensityd\n    :special-members: __call__\n\n\nLazy (Dict)\n^^^^^^^^^^^\n\n`ApplyPendingd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n\n.. autoclass:: ApplyPendingd\n    :members:\n    :special-members: __call__\n\n\nUtility (Dict)\n^^^^^^^^^^^^^^\n\n`Identityd`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: Identityd\n    :members:\n    :special-members: __call__\n\n`AsChannelLastd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
autoclass:: AsChannelLastd\n    :members:\n    :special-members: __call__\n\n`EnsureChannelFirstd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: EnsureChannelFirstd\n    :members:\n    :special-members: __call__\n\n`RepeatChanneld`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RepeatChanneld\n    :members:\n    :special-members: __call__\n\n`SplitDimd`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SplitDimd\n    :members:\n    :special-members: __call__\n\n`CastToTyped`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: CastToTyped\n    :members:\n    :special-members: __call__\n\n`ToTensord`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ToTensord\n    :members:\n    :special-members: __call__\n\n`ToNumpyd`\n\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ToNumpyd\n    :members:\n    :special-members: __call__\n\n`ToPIL`\n\"\"\"\"\"\"\"\n.. autoclass:: ToPIL\n    :members:\n    :special-members: __call__\n\n`ToCupyd`\n\"\"\"\"\"\"\"\"\"\n.. autoclass:: ToCupyd\n    :members:\n    :special-members: __call__\n\n`ToPILd`\n\"\"\"\"\"\"\"\"\n.. autoclass:: ToPILd\n    :members:\n    :special-members: __call__\n\n`DeleteItemsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: DeleteItemsd\n    :members:\n    :special-members: __call__\n\n`SelectItemsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SelectItemsd\n    :members:\n    :special-members: __call__\n\n`FlattenSubKeysd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: FlattenSubKeysd\n    :members:\n    :special-members: __call__\n\n`Transposed`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: Transposed\n    :members:\n    :special-members: __call__\n\n`SqueezeDimd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: SqueezeDimd\n    :members:\n    :special-members: __call__\n\n`DataStatsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: DataStatsd\n    :members:\n    :special-members: __call__\n\n`SimulateDelayd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
autoclass:: SimulateDelayd\n    :members:\n    :special-members: __call__\n\n`CopyItemsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: CopyItemsd\n    :members:\n    :special-members: __call__\n\n`ConcatItemsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ConcatItemsd\n    :members:\n    :special-members: __call__\n\n`Lambdad`\n\"\"\"\"\"\"\"\"\"\n.. autoclass:: Lambdad\n    :members:\n    :special-members: __call__\n\n`RandLambdad`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandLambdad\n    :members:\n    :special-members: __call__\n\n`RemoveRepeatedChanneld`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RemoveRepeatedChanneld\n    :members:\n    :special-members: __call__\n\n`LabelToMaskd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: LabelToMaskd\n    :members:\n    :special-members: __call__\n\n`FgBgToIndicesd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: FgBgToIndicesd\n    :members:\n    :special-members: __call__\n\n`ClassesToIndicesd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ClassesToIndicesd\n    :members:\n    :special-members: __call__\n\n`ConvertToMultiChannelBasedOnBratsClassesd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ConvertToMultiChannelBasedOnBratsClassesd\n    :members:\n    :special-members: __call__\n\n`AddExtremePointsChanneld`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: AddExtremePointsChanneld\n    :members:\n    :special-members: __call__\n\n`TorchVisiond`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: TorchVisiond\n    :members:\n    :special-members: __call__\n\n`RandTorchVisiond`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandTorchVisiond\n    :members:\n    :special-members: __call__\n\n`TorchIOd`\n\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: TorchIOd\n    :members:\n    :special-members: __call__\n\n`RandTorchIOd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. 
autoclass:: RandTorchIOd\n    :members:\n    :special-members: __call__\n\n`MapLabelValued`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: MapLabelValued\n    :members:\n    :special-members: __call__\n\n`EnsureTyped`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: EnsureTyped\n    :members:\n    :special-members: __call__\n\n`IntensityStatsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: IntensityStatsd\n    :members:\n    :special-members: __call__\n\n`ToDeviced`\n\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ToDeviced\n    :members:\n    :special-members: __call__\n\n`CuCIMd`\n\"\"\"\"\"\"\"\"\n.. autoclass:: CuCIMd\n    :members:\n    :special-members: __call__\n\n`RandCuCIMd`\n\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandCuCIMd\n    :members:\n    :special-members: __call__\n\n`AddCoordinateChannelsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: AddCoordinateChannelsd\n    :members:\n    :special-members: __call__\n\n`ImageFilterd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ImageFilterd\n    :members:\n    :special-members: __call__\n\n`RandImageFilterd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: RandImageFilterd\n    :members:\n    :special-members: __call__\n\n`ApplyTransformToPointsd`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ApplyTransformToPointsd\n    :members:\n    :special-members: __call__\n\n`FlattenSequenced`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: FlattenSequenced\n    :members:\n    :special-members: __call__\n\n\nMetaTensor\n^^^^^^^^^^\n\n`ToMetaTensord`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: ToMetaTensord\n    :members:\n    :special-members: __call__\n\n`FromMetaTensord`\n\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\"\n.. autoclass:: FromMetaTensord\n    :members:\n    :special-members: __call__\n\nTransform Adaptors\n------------------\n.. automodule:: monai.transforms.adaptors\n\n`FunctionSignature`\n^^^^^^^^^^^^^^^^^^^\n.. 
autoclass:: FunctionSignature\n    :members:\n\n`adaptor`\n^^^^^^^^^\n.. autofunction:: monai.transforms.adaptors.adaptor\n\n`apply_alias`\n^^^^^^^^^^^^^\n.. autofunction:: monai.transforms.adaptors.apply_alias\n\n`to_kwargs`\n^^^^^^^^^^^\n.. autofunction:: monai.transforms.adaptors.to_kwargs\n\nUtilities\n---------\n.. automodule:: monai.transforms.utils\n    :members:\n\n.. automodule:: monai.transforms.utils_pytorch_numpy_unification\n    :members:\n\n.. automodule:: monai.transforms.utils_morphological_ops\n    :members:\n\nBy Categories\n-------------\n.. toctree::\n   :maxdepth: 1\n\n   transforms_idx\n"
  },
  {
    "path": "docs/source/transforms_idx.rst",
    "content": ".. _transforms_idx:\n\n.. currentmodule:: monai.transforms\n\nCrop and pad\n^^^^^^^^^^^^\n\n.. autosummary::\n   :toctree: _gen\n   :nosignatures:\n\n   croppad.array\n   croppad.dictionary\n   croppad.batch\n\nSpatial\n^^^^^^^\n\n.. autosummary::\n   :toctree: _gen\n   :nosignatures:\n\n   spatial.array\n   spatial.dictionary\n\n\nIntensity\n^^^^^^^^^\n\n.. autosummary::\n   :toctree: _gen\n   :nosignatures:\n\n   intensity.array\n   intensity.dictionary\n\nIO\n^^\n\n.. autosummary::\n   :toctree: _gen\n   :nosignatures:\n\n   io.array\n   io.dictionary\n\nLazy\n^^^^\n\n.. autosummary::\n   :toctree: _gen\n   :nosignatures:\n\n   lazy.array\n   lazy.dictionary\n   lazy.utils\n\nMetaTensor utilities\n^^^^^^^^^^^^^^^^^^^^\n\n.. autosummary::\n   :toctree: _gen\n   :nosignatures:\n\n   meta_utility.dictionary\n\nPost-processing\n^^^^^^^^^^^^^^^\n\n.. autosummary::\n   :toctree: _gen\n   :nosignatures:\n\n   post.array\n   post.dictionary\n\nRegularization\n^^^^^^^^^^^^^^\n\n.. autosummary::\n   :toctree: _gen\n   :nosignatures:\n\n   regularization.array\n   regularization.dictionary\n\nSignal\n^^^^^^\n\n.. autosummary::\n   :toctree: _gen\n   :nosignatures:\n\n   signal.array\n\nSmooth field\n^^^^^^^^^^^^\n\n.. autosummary::\n   :toctree: _gen\n   :nosignatures:\n\n   smooth_field.array\n   smooth_field.dictionary\n\nUtility\n^^^^^^^\n\n.. autosummary::\n   :toctree: _gen\n   :nosignatures:\n\n   utility.array\n   utility.dictionary\n"
  },
  {
    "path": "docs/source/utils.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _utils:\n\nUtilities\n=========\n\nConfigurations\n--------------\n.. automodule:: monai.config.deviceconfig\n  :members:\n\n\nModule utils\n------------\n.. automodule:: monai.utils.module\n  :members:\n\n\nMisc\n----\n.. automodule:: monai.utils.misc\n  :members:\n\n\nNVTX Annotations\n----------------\n.. automodule:: monai.utils.nvtx\n  :members:\n\n\nProfiling\n---------\n.. automodule:: monai.utils.profiling\n  :members:\n\n\nDeprecated\n----------\n.. automodule:: monai.utils.deprecate_utils\n  :members:\n\n\nType conversion\n---------------\n.. automodule:: monai.utils.type_conversion\n  :members:\n\nDecorators\n----------\n.. automodule:: monai.utils.decorators\n  :members:\n\nDistributed Data Parallel\n-------------------------\n.. automodule:: monai.utils.dist\n  :members:\n\nEnums\n-----\n.. automodule:: monai.utils.enums\n  :members:\n\nJupyter Utilities\n-----------------\n.. automodule:: monai.utils.jupyter_utils\n  :members:\n\nState Cacher\n------------\n.. automodule:: monai.utils.state_cacher\n  :members:\n\nComponent store\n---------------\n.. autoclass:: monai.utils.component_store.ComponentStore\n  :members:\n\nOrdering\n--------\n.. automodule:: monai.utils.ordering\n  :members:\n"
  },
  {
    "path": "docs/source/visualize.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\n.. _visualize:\n\nVisualizations\n==============\n\n.. currentmodule:: monai.visualize\n\nTensorboard visuals\n-------------------\n\n.. automodule:: monai.visualize.img2tensorboard\n  :members:\n\nClass activation map\n--------------------\n\n.. automodule:: monai.visualize.class_activation_maps\n  :members:\n\nOcclusion sensitivity\n---------------------\n\n.. automodule:: monai.visualize.occlusion_sensitivity\n  :members:\n\nGradient-based saliency maps\n----------------------------\n\n.. automodule:: monai.visualize.gradient_based\n  :members:\n\n\nUtilities\n---------\n\n.. automodule:: monai.visualize.utils\n  :members:\n"
  },
  {
    "path": "docs/source/whatsnew.rst",
    "content": ":github_url: https://github.com/Project-MONAI/MONAI\n\nWhat's New\n==========\n\n.. toctree::\n   :maxdepth: 1\n\n   whatsnew_1_5_2.md\n   whatsnew_1_5_1.md\n   whatsnew_1_5.md\n   whatsnew_1_4.md\n   whatsnew_1_3.md\n   whatsnew_1_2.md\n   whatsnew_1_1.md\n   whatsnew_1_0.md\n   whatsnew_0_9.md\n   whatsnew_0_8.md\n   whatsnew_0_7.md\n   whatsnew_0_6.md\n   whatsnew_0_5.md\n"
  },
  {
    "path": "docs/source/whatsnew_0_5.md",
    "content": "# What's new in 0.5\n\n- Invert spatial transforms and test-time augmentations\n- Lesion detection in digital pathology\n- DeepGrow modules for interactive segmentation\n- Various usability improvements\n\n## Invert spatial transforms and test-time augmentations\nIt is often desirable to invert the previously applied spatial transforms (resize, flip, rotate, zoom, crop, pad, etc.) with the deep learning workflows, for example, to resume to the original imaging space after processing the image data in a normalized data space.  We enhance almost all the spatial transforms with an `inverse` operation and release this experimental feature in v0.5. Users can easily invert all the spatial transforms for one transformed data item or a batch of data items. It also can be achieved within the workflows by using the `TransformInverter` handler.\n\nIf the pipeline includes random transformations, users may want to observe the effect that these transformations have on the output. The typical approach is that we pass the same input through the transforms multiple times with different random realizations. Then use the inverse transforms to move all the results to a common space, and calculate the metrics. MONAI provided `TestTimeAugmentation` for this feature, which by default will calculate the `mode`, `mean`, `standard deviation` and `volume variation coefficient`.\n\n[Invert transforms and TTA tutorials](https://github.com/Project-MONAI/tutorials/blob/master/modules/inverse_transforms_and_test_time_augmentations.ipynb) introduce details about the API with examples.\n\n(1) The last column is the inverted data of model output:\n![invert transform](../images/invert_transforms.png)\n\n(2) The TTA results of `mode`, `mean` and `standard deviation`:\n![test time augmentation](../images/tta.png)\n\n## Lesion detection in digital pathology\nMONAI starts to support digital pathology deep learning tasks. 
The initial [implementation](https://github.com/Project-MONAI/MONAI/tree/master/monai/apps/pathology) of the pathology detection components includes:\n- Efficient whole slide imaging IO with NVIDIA cuCIM library\n- Patch-based sampling and training strategies with the SmartCache mechanism\n- FROC measurements for lesion detection\n- Probabilistic post-processing for lesion ROIs.\n\n![digital pathology](../images/pathology.png)\n\n## DeepGrow modules for interactive segmentation\nTowards an interactive workflow with manual input during training and inference,\n[a reimplementation](https://github.com/Project-MONAI/MONAI/tree/master/monai/apps/deepgrow) of the DeepGrow components is included in this release.\nDeepGrow is a deep learning based semi-automated segmentation approach that aims to be a \"smart\" interactive tool for regions of interest delineation in medical images.\n\n![deepgrow scheme](../images/deepgrow_scheme.png)\n\nAn end-to-end example is presented at [`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/tree/master/deepgrow/ignite).\n![deepgrow end-to-end](../images/deepgrow.png)\n\n## Learning-based image registration\nStarting from v0.5, MONAI provides experimental features for building learning-based 2D/3D registration workflows. These include image similarity measures as loss functions, bending energy as model regularization, network architectures, warping modules. The components can be used to build the major unsupervised and weakly-supervised algorithms.\n\nThe following figure shows the registration of CT images acquired at different time points for a single patient using MONAI:\n\n![3d registration](../images/3d_paired.png)\n\n## Various usability improvements\n### IO factory for medical image formats\nMany popular image formats exist in the medical domain, and they are quite different with rich metadata information. 
To easily handle different medical image formats in the same pipeline, [MONAI provides `LoadImage` transform](https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb), which can automatically choose image readers based on the supported suffixes and in the below priority order:\n- User-specified reader at runtime when calling this loader.\n- Registered readers from the latest to the first in list.\n- Default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), (npz, npy -> NumpyReader), (others -> ITKReader).\n\nThe `ImageReader` API is quite straightforward; users can easily extend it for their own customized image readers.\n\nWith these pre-defined image readers, MONAI can load images in formats: `NIfTI`, `DICOM`, `PNG`, `JPG`, `BMP`, `NPY/NPZ`, etc.\n\n### Save transform data into NIfTI or PNG files\nTo convert images into files or debug the transform chain, MONAI provides `SaveImage` transform. Users can inject this transform into the transform chain to save the results.\n\n### Automatically ensure `channel-first` data shape\nMedical images have different shape formats. They can be `channel-last`, `channel-first` or even `no-channel`. We may, for example, want to load several `no-channel` images and stack them as `channel-first` data. To improve the user experience, MONAI provided an `EnsureChannelFirst` transform to automatically detect data shape according to the meta information and convert it to the `channel-first` format consistently.\n\n### Network architectures\nVarious ready-to-use architectures with pretrained model weights from `torch.hub`.\n\n### Result writing\nCurrently MONAI supports writing the model outputs as NIfTI files or PNG files for segmentation tasks, and as CSV files for classification tasks. 
And the writers can restore the data spacing, orientation or shape according to the `original_shape` or `original_affine` information from the input image.\n\nA rich set of formats will be supported soon, along with relevant statistics and evaluation metrics automatically computed from the outputs.\n\n### Transfer learning for different input / output classes\n`Transfer-learning` is a very common and efficient training approach, especially in the medical-specific domain where obtaining large datasets for training can be difficult. So transfer-learning from a pre-trained checkpoint can significantly improve the model metrics and shorten training time.\n\nMONAI provided `CheckpointLoader` to load a checkpoint for the workflow before training, and it tolerates cases where some `layer names` or `layer shapes` of the current network don't match the checkpoint, which can be useful if the current task has different input image classes or output classes.\n\n### C++/CUDA optimized modules\nTo accelerate some heavy computation processes, C++/CUDA implementation can be an impressive method, which usually brings even hundreds of times faster performance. MONAI contains some C++/CUDA optimized modules, like `Resampler`, `Conditional random field (CRF)`, `Fast bilateral filtering using the permutohedral lattice`, and fully supports C++/CUDA programs in CI/CD and package building.\n"
  },
  {
    "path": "docs/source/whatsnew_0_6.md",
    "content": "# What's new in 0.6\n\n- Decollating mini-batches as an essential post-processing step\n- Pythonic APIs to load the pretrained models from Clara Train MMARs\n- UNETR: Transformers for Medical Image Segmentation\n- Enhancements of the base metric interfaces\n- C++/CUDA extension modules via PyTorch JIT compilation\n- Backward compatibility and enhanced continuous integration/continuous delivery\n- Collaboration with Project-MONAI/MONAILabel for smooth integration\n\n\n## Decollating mini-batches as an essential post-processing step\n`decollate batch` is introduced in MONAI v0.6, to simplify the post-processing transforms and enable flexible operations on a batch of model outputs.\nIt can decollate batched data (e.g. model inference results) into a list of tensors -- as an 'inverse' operation of `collate_fn` of the PyTorch data loader. It has the benefits such as:\n- enabling postprocessing transforms for each item independently, for example, randomised transforms could be applied differently for each predicted item in a batch.\n- simplifying the transform APIs and reducing the input validation burdens, because both the preprocessing and postprocessing transforms now only support the \"channel-first\" input format.\n- enabling the transform inverse operation for data items in different original shapes, as the inverted items are in a list, instead of being stacked in a single tensor.\n- allowing for both a \"batch-first\" tensor and a list of \"channel-first\" tensors for flexible metric computation.\n\nA typical process of `decollate batch` is illustrated as follows (with a `batch_size=N` model predictions and labels as an example):\n![decollate_batch](../images/decollate_batch.png)\n\n[decollate batch tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/decollate_batch.ipynb) shows a detailed usage example based on a PyTorch native workflow.\n\n[Migrating your v0.5 code to 
v0.6](https://github.com/Project-MONAI/MONAI/wiki/v0.5-to-v0.6-migration-guide) wiki shows how to migrate an existing program from v0.5 to v0.6 to adapt to the `decollate batch` logic.\n\n## UNETR: Transformers for Medical Image Segmentation\n[UNETR](https://arxiv.org/abs/2103.10504) is a transformer-based model for volumetric (3D) medical image segmentation and is currently the state-of-the-art on [BTCV dataset](https://www.synapse.org/#!Synapse:syn3193805/wiki/217752) test server for the task of multi-organ semantic segmentation. UNETR is introduced in MONAI v0.6 and its flexible implementation supports various segmentation tasks.\n![UNETR](../images/UNETR.png)\n\nA tutorial for the task of 3D multi-organ semantic segmentation using UNETR is provided within\n[`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unetr_btcv_segmentation_3d.ipynb).\nAnd it contains the following features:\n- Transforms for dictionary format data,\n- Defining a new transform according to MONAI transform API,\n- Loading Nifti image with metadata, loading a list of images and stacking them,\n- Randomly adjusting the intensity for data augmentation,\n- Optimized cache IO and transforms to accelerate training and validation,\n- 3D UNETR model, DiceCE loss function and Mean Dice metric for multi-organ segmentation task,\n\nThe following illustrates target body organs that are segmented in this tutorial:\n![BTCV_organs](../images/BTCV_organs.png)\n\nPlease visit UNETR repository for more details:\nhttps://project-monai.github.io/research/unetr-btcv-multi-organ-segmentation\n\n## Pythonic APIs to load the pretrained models from Clara Train MMARs\n[The MMAR (Medical Model ARchive)](https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html)\ndefines a data structure for organizing all artifacts produced during the model development life cycle.\nNVIDIA Clara provides [various MMARs of medical domain-specific 
models](https://ngc.nvidia.com/catalog/models?orderBy=scoreDESC&pageNumber=0&query=clara_pt&quickFilter=&filters=).\nThese MMARs include all the information about the model including configurations and scripts to provide a workspace to perform model development tasks. To better leverage the trained MMARs released on Nvidia GPU cloud, MONAI provides pythonic APIs to access them.\n\nTo demonstrate this new feature, a medical image segmentation tutorial is created within\n[`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/master/modules/transfer_mmar.ipynb).\nIt mainly produces the following figure to compare the loss curves and validation scores for\n- training from scratch (the green line),\n- applying pretrained MMAR weights without training (the magenta line),\n- training from the MMAR model weights (the blue line),\n\naccording to the number of training epochs:\n\n![transfer_mmar](../images/transfer_mmar.png)\n\nThe tutorial shows the capability of encapsulating the details of MMAR parsing, as well as the potential of using pretrained MMARs for transfer learning.\nThese APIs are also being integrated into AI-assisted interactive workflows to accelerate the manual annotating processes (e.g. via [project-MONAI/MONAILabel](https://github.com/Project-MONAI/MONAILabel)).\n\n## Enhancements of the base metric interfaces\nThe base API for metrics is now enhanced to support the essential computation logic for both iteration and epoch-based metrics.\nWith this update, the MONAI metrics module becomes more extensible, and thus a good starting point for customised metrics.\nThe APIs also by default support data parallel computation and consider the computation efficiency:  with a `Cumulative` base class, intermediate metric outcomes can be automatically buffered, cumulated, synced across distributed processes, and aggregated for the final results. 
The [multi-processing computation example](https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py) shows how to compute metrics based on saved predictions and labels in a multi-processing environment.\n\n## C++/CUDA extension modules via PyTorch JIT compilation\nTo further accelerate the domain-specific routines in the workflows, MONAI C++/CUDA modules are introduced as extensions of the PyTorch native implementation.\nIt now provides modules using [the two ways of building C++ extensions from PyTorch](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions):\n- via `setuptools` (since MONAI v0.5), for modules including `Resampler`, `Conditional random field (CRF)`, `Fast bilateral filtering using the permutohedral lattice`.\n- via just-in-time (JIT) compilation (since MONAI v0.6), for the `Gaussian mixtures` module. This approach allows for dynamic optimisation according to the user-specified parameters and local system environments.\n\nThe following figure shows results of MONAI's Gaussian mixture models applied to a tissue and surgical tools segmentation task:\n![Gaussian mixture models as a postprocessing step](../images/gmm_feature_set_comparison_s.png)\n\n## Backward compatibility and enhanced continuous integration/continuous delivery\nStarting from this version, we experiment with basic policies of backward compatibility.\nNew utilities are introduced on top of the existing semantic versioning modules, and the git branching model.\n\nAt the same time, we actively analyze efficient, scalable, and secure CI/CD solutions to accommodate fast and collaborative codebase development.\n\nAlthough a complete mechanism is still under development, these provide another essential step towards API-stable versions of MONAI, sustainable release cycles, and efficient open-source collaborations.\n\n## Collaboration with [`Project-MONAI/MONAILabel`](https://github.com/Project-MONAI/MONAILabel) for smooth integration\nSince 
MONAI v0.6, we welcome [`MONAILabel`](https://github.com/Project-MONAI/MONAILabel) under [`Project-MONAI`](https://github.com/Project-MONAI).\n\nMONAI Label is an intelligent open source image labeling and learning tool that enables users to create annotated datasets and build AI annotation models for clinical evaluation.\nMONAI Label enables application developers to build labeling apps in a serverless way,\nwhere custom labeling apps are exposed as a service through the MONAI Label Server.\n\nPlease visit MONAILabel documentation website for details:\nhttps://monai.readthedocs.io/projects/label/en/latest/\n"
  },
  {
    "path": "docs/source/whatsnew_0_7.md",
    "content": "# What's new in 0.7\n\n- Performance enhancements with profiling and tuning guides\n- Major usability improvements in `monai.transforms`\n- Reimplementing state-of-the-art Kaggle solutions\n- Vision-language multimodal transformer architectures\n\n## Performance enhancements with profiling and tuning guides\n\nModel training is often a time-consuming step during deep learning development,\nespecially for medical imaging applications. Even with powerful hardware (e.g.\nCPU/GPU with large RAM), the workflows often require careful profiling and\ntuning to achieve high performance. MONAI has been focusing on performance\nenhancements, and in this version, a fast model training guide is provided\nto help build highly performant workflows, with a comprehensive overview of\nthe profiling tools and practical strategies:\nhttps://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md.\n\nThe following figure shows the use of [Nvidia Nsight™ Systems](https://developer.nvidia.com/nsight-systems) for system-wide\nperformance analysis during a performance enhancement study.\n![nsight_vis](../images/nsight_comparison.png)\n\nWith the performance profiling and enhancements, several typical use cases were studied to\nimprove the training efficiency.  The following figure shows that fast\ntraining using MONAI can be `200` times faster than a regular baseline ([learn\nmore](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb)), and it's `20` times faster than the MONAI v0.6 fast training solution.\n![fast_training](../images/fast_training.png)\n\n## Major usability improvements in `monai.transforms` for NumPy/PyTorch inputs and backends\n\n MONAI starts to roll out major usability enhancements for the\n `monai.transforms` module. Many transforms are now supporting both NumPy and\n PyTorch, as input types and computational backends. 
To get the supported backends of every transform, please execute: `python monai/transforms/utils.py`.\n\nOne benefit of these enhancements is that the users can now better leverage the\nGPUs for preprocessing. By transferring the input data onto GPU using\n`ToTensor` or `EnsureType`, and applying the GPU-based transforms to the data,\n[the tutorial of spleen\nsegmentation](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb)\nshows the great potential of using the flexible modules for fast and efficient\ntraining.\n\n## Reimplementing state-of-the-art Kaggle solutions\n\nWith this release, we actively evaluate and enhance the quality and flexibility\nof the MONAI core modules, using the public Kaggle challenge as a testbed. [A\nreimplementation](https://github.com/Project-MONAI/tutorials/tree/main/competitions/kaggle/RANZCR/4th_place_solution)\nof a state-of-the-art solution at [Kaggle RANZCR CLiP - Catheter and Line\nPosition\nChallenge](https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification)\nis made available in this version.\n\n##  Vision-language multimodal transformers\n\nIn this release, MONAI adds support for training multimodal (vision + language)\ntransformers that can handle both image and textual data. MONAI introduces the\n`TransCheX` model which consists of vision, language, and mixed-modality\ntransformer layers for processing chest X-ray and their corresponding\nradiological reports within a unified framework. In addition to `TransCheX`,\nusers have the flexibility to alter the architecture by varying the number of\nvision, language and mixed-modality layers and customizing the classification\nhead. In addition, the model can be initialized from pre-trained BERT language\nmodels for fine-tuning.\n"
  },
  {
    "path": "docs/source/whatsnew_0_8.md",
    "content": "# What's new in 0.8\n\n- Differentiable neural network topology search\n- Multiple instance learning for digital pathology WSI analysis\n- Self-supervised representation learning\n- Major usability improvements in `monai.transforms`\n\n## Differentiable neural network topology search\nMONAI integrates `DiNTS`: [Differentiable Neural Network Topology Search for 3D\nMedical Image Segmentation](https://arxiv.org/abs/2103.15954). The neural\narchitecture search module supports flexible multi-path topology search with\nhigh search efficiency and budgeted memory usage.\n\nIt provides a topology guaranteed discretization algorithm and a\ndiscretization-aware topology loss for the search stage to minimize the\ndiscretization gap. The module is memory usage aware and is able to search 3D\nnetworks with different GPU memory requirements. For more details, please check out the\n[DiNTS tutorial](https://project-monai.github.io/research/dints.html).\n\n![DiNTS](../images/dints-overview.png)\n\n## Multiple instance learning for digital pathology WSI analysis\nFor [classification of digital pathology whole slide images\n(WSI)](https://arxiv.org/abs/2111.01556), MONAI introduces new transforms and\nnetwork modules for multiple instance learning. These include self-attention\ntransformer blocks for explicitly accounting for the dependencies between instances\n(image patches) during training. For more details,\nplease check out the [multiple instance learning tutorial](https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning).\n\n![multi-instance](../images/mil-patches.jpg)\n\n## Self-supervised representation learning\nMONAI starts to explore self-supervised representation learning in this\nmilestone release. The Vision Transformer has been extended to learn from self-supervised\nreconstruction tasks with various data augmentation and a regularized\ncontrastive loss. 
The weights of the pre-trained backbone could be used to\nenhance the performance of the novel downstream deep learning tasks.\n\nThe [tutorial](https://github.com/Project-MONAI/tutorials/tree/master/self_supervised_pretraining)\nshows how to generate a good set of pre-trained weights using unlabeled data\nwith self-supervised tasks, then use the pre-trained weights to perform\nfine-tuning on a fully supervised volumetric segmentation task using a transformer based `UNETR`.\n\n![self-supervised](../images/ssl_overview.png)\n\n## Major usability improvements in `monai.transforms`\n`monai.transforms` are now more flexible and easy to use in version 0.8.\n- Input type handling and backend APIs are improved to support both\n  NumPy and PyTorch where possible.\n- Visual examples are added to the documentation to illustrate the effects of\n  various image processing.\n- New visualization utilities are provided and enhanced for quick qualitative\n  assessments of the model by visualizing, for example, the volumetric image\n  inputs, segmentation maps, and intermediate feature maps.\n  The visualization tutorial is available for\n  [TensorBoard utility, `matshow3d` and `blend_images`](https://github.com/Project-MONAI/tutorials/blob/master/modules/transform_visualization.ipynb).\n"
  },
  {
    "path": "docs/source/whatsnew_0_9.md",
    "content": "# What's new in 0.9\n\n- MONAI Bundle\n- Object detection in medical images\n- Swin Transformers for 3D medical image analysis\n- New interactive segmentation components\n- MetaTensor API preview\n\n## MONAI Bundle\nMONAI Bundle format defines portable descriptions of deep learning models ([docs](https://monai.readthedocs.io/en/latest/bundle_intro.html)).\nA bundle includes the critical information necessary during a model development life cycle,\nand allows users and programs to understand the purpose and usage of the models.\nThe key benefits of Bundle and the `monai.bundle` APIs are:\n- Standardized packaging format for storing and sharing models,\n- Structured configuration files for fast prototyping of deep learning workflows,\n- Easy to program APIs to separate deep learning hyperparameter settings from the Python code,\n- Flexible config components to allow for different low-level Python implementations,\n- Help to decouple the component details from higher level learning paradigms such as federated learning and AutoML.\n\nMore details are [in the tutorials](https://github.com/Project-MONAI/tutorials/tree/main/bundle).\n\n## Object detection in medical images\nThis release includes essential components for object localization and categorization workflows.\nThe initial developments include 2D and 3D bounding box handling, network blocks and architectures of RetinaNet,\nand common utility modules such as coordinate-based preprocessing, hard negative sampler.\n\nThe application specific modules are made available at\n[monai.apps.detection](https://github.com/Project-MONAI/MONAI/tree/dev/monai/apps/detection).\n\n![detection workflow](../images/detection.png)\n\n\n## Swin Transformers for 3D medical image analysis\nThe Swin UNETR model is now implemented in MONAI.\n[The tutorial](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb)\nshows examples of multi-organ segmentation using this 
state-of-the-art model,\nwith weights from self-supervised pre-training of\nSwin UNETR encoder (3D Swin Transformer) on a cohort of 5050 CT scans from publicly available datasets.\n[The research-contribution entry](https://github.com/Project-MONAI/research-contributions/tree/main/SwinUNETR)\nincludes further technical details.\n\n![swin-unetr](../images/swin_unetr.png)\n\n## New interactive segmentation components\nNew components from deep learning interactive segmentation workflows\nsuch as [DeepEdit](https://github.com/Project-MONAI/tutorials/tree/main/deepedit/ignite)\nand NuClick are integrated into the core codebase. They serve as basic building blocks for\n[the latest features in MONAILabel](https://github.com/Project-MONAI/MONAILabel).\n\n![deepedit](../images/deepedit.png)\n\n![nuclick](../images/nuclick.png)\n\n## MetaTensor API preview\nThe metadata associated with the primary imaging modalities is important in many biomedical applications,\nespecially for the data-driven approaches that MONAI has been focusing on.\nStarting from this release, we roll out a major refactoring for data representation in MONAI. For the first\nstep, [the core data structures](https://github.com/Project-MONAI/MONAI/blob/dev/monai/data/meta_tensor.py)\n`MetaTensor` and `MetaObj` are implemented as a feature preview.\nFurther developments [on the feature branch](https://github.com/Project-MONAI/MONAI/pull/4539)\nwill be made available in future milestone releases.\n"
  },
  {
    "path": "docs/source/whatsnew_1_0.md",
    "content": "# What's new in 1.0\n\n- Model Zoo\n- Auto3DSeg\n- Federated Learning Client\n- MetaTensor Support for Digital Pathology Workflows\n- Accelerated MRI Reconstruction\n\n\n## Model Zoo\nThe MONAI Model Zoo is a place for researchers and data scientists to use and share the latest and greatest models from the community.\nUtilizing [the MONAI Bundle format](https://github.com/Project-MONAI/tutorials/tree/main/bundle) makes it easy to quickly get started using any model with any MONAI Framework (Core, Label, or Deploy).\nOr, if you're interested in [contributing your models](https://github.com/project-monai/model-zoo), take a look at our contributing guidelines,\nwhich walk you through the process and requirements for submitting your model.\nFor more details about how to use the models, please see [the tutorials](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo).\n\n## Auto3DSeg\n![auto3dseg](../images/auto3dseg.png)\n\n[Auto3DSeg](https://project-monai.github.io/apps/auto3dseg.html) is a comprehensive solution for large-scale 3D medical image segmentation.\nIt leverages the latest advances in MONAI\nand GPUs to efficiently develop and deploy algorithms with state-of-the-art performance.\nIt first analyzes the global information such as intensity, dimensionality, and resolution of the dataset,\nthen generates algorithms in MONAI bundle format based on data statistics and [algorithm templates](https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg).\nNext, all algorithms initiate model training to obtain checkpoints with the best validation performance.\nFinally, the ensemble module selects the algorithms via ranking trained checkpoints and creates ensemble predictions.\n\nThe solution offers different levels of user experience for beginners and advanced researchers.\nIt has been tested on large-scale 3D medical imaging datasets in different modalities.\n\n## Federated Learning 
Client\n![federated-learning](../images/federated.svg)\n\nMONAI now includes the federated learning (FL) client algorithm APIs that are exposed as an abstract base class\nfor defining an algorithm to be run on any federated learning platform.\n[NVIDIA FLARE](https://github.com/NVIDIA/NVFlare), the federated learning platform developed by [NVIDIA](https://www.nvidia.com/en-us/),\nhas already built [the integration piece](https://github.com/NVIDIA/NVFlare/tree/dev/integration/monai) with these new APIs.\nWith [the new federated learning APIs](https://monai.readthedocs.io/en/latest/fl.html), MONAI bundles can seamlessly be extended to a federated paradigm\nand executed using single- or multi-GPU training.\nThe MONAI FL client also allows computing summary data statistics (e.g., intensity histograms) on the datasets defined in the bundle configs.\nThese can be shared and visualized on the FL server, for example, using NVIDIA FLARE's federated statistics operators,\nsee [here](https://github.com/NVIDIA/NVFlare/tree/dev/integration/monai/examples) for an example.\n\nWe welcome other federated learning toolkits to integrate with MONAI FL APIs, building a common foundation for\ncollaborative learning in medical imaging.\n\n## MetaTensor Support for Digital Pathology Workflows\n![pathology](../images/pathology-meta.png)\n\nIn this release, we support MetaTensor in all digital pathology components, and\nmake sure that the future development can benefit from them. With the help of\nMONAI Pathology Working Group, we have standardized a set of metadata\nattributes for patches of images extracted from WSI to ensure reproducibility\nand enhance functionality via relying on a standard set of attributes. The\nfigure above shows all the pathology metadata attributes and their relation to\nMetaTensors. 
Please see [the tutorials and\nexamples](https://github.com/Project-MONAI/tutorials/tree/main/pathology).\n\n## Accelerated MRI Reconstruction\n![MRI-reconstruction](../images/mri_recon.png)\n\nThis release includes initial components for various popular accelerated MRI reconstruction workflows.\nMany of them are general-purpose tools, for example the [`SSIMLoss`](https://monai.readthedocs.io/en/latest/losses.html?highlight=ssimloss#ssimloss) function.\nSome new functionalities are task-specific, for example [`FastMRIReader`](https://monai.readthedocs.io/en/latest/data.html?highlight=fastmri#monai.apps.reconstruction.fastmri_reader.FastMRIReader).\n\nFor more details, please see [this tutorial](https://github.com/Project-MONAI/tutorials/tree/main/reconstruction/MRI_reconstruction/unet_demo) for using a baseline model for this task,\nand [this tutorial](https://github.com/Project-MONAI/tutorials/tree/main/reconstruction/MRI_reconstruction/varnet_demo) for using a state-of-the-art model.\n"
  },
  {
    "path": "docs/source/whatsnew_1_1.md",
    "content": "# What's new in 1.1\n\n- Digital pathology workflows\n- Experiment management for MONAI bundle\n- Auto3dSeg enhancements\n- New models in MONAI Model Zoo\n- State-of-the-art SurgToolLoc solution\n\n## Digital pathology workflows\n\n![hovernet](../images/hovernet_diagram.png)\n\nHover-Net is a model for simultaneous segmentation and classification of nuclei in multi-tissue histology images (Graham et al. Medical Image Analysis, 2019).\nWe have added support for this model in MONAI by implementing several new components, enhancing existing ones and providing pipelines and examples for training, validation and inference.\n\nAlong with the modules release, new digital pathology analysis tutorials are made available:\n\n- [HoVerNet pipelines](https://github.com/Project-MONAI/tutorials/tree/main/pathology/hovernet) based on MONAI workflows for training, validation and inference\n- [HoVerNet tutorial](https://github.com/Project-MONAI/tutorials/blob/main/pathology/hovernet/hovernet_torch.ipynb) for training, validation and inference\n- NuClick (Interactive Annotation for Pathology) tutorials for [training](https://github.com/Project-MONAI/tutorials/blob/main/pathology/nuclick/nuclick_training_notebook.ipynb)\nand [inference](https://github.com/Project-MONAI/tutorials/blob/main/pathology/nuclick/nuclick_infer.ipynb)\n- Nuclei classification tutorials for [training](https://github.com/Project-MONAI/tutorials/blob/main/pathology/nuclick/nuclei_classification_training_notebook.ipynb)\nand [inference](https://github.com/Project-MONAI/tutorials/blob/main/pathology/nuclick/nuclei_classification_infer.ipynb)\n\n## Experiment management for MONAI bundle\n\n![exp_mgmt](../images/exp_mgmt.png)\n\nIn this release, experiment management features are integrated with MONAI bundle.\nIt provides essential APIs for managing the end-to-end model bundle lifecycle.\nUsers can start tracking experiments by, for example, appending `--tracking \"mlflow\"` to the training or 
inference commands to enable the MLFlow-based management.\nBy default, MLFlow will track the executed bundle config, model quality measurements, and source code versioning.\nFor more details, please refer to the [tutorial](https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb).\n\n## Auto3dSeg enhancements\n\nMultiple improvements have been added in `Auto3DSeg` both in terms of\nusability and performance.\n- Multi-modality support is added and applied for\nautomated segmentation of the HECKTOR22 challenge dataset, which includes input 3D\nCT and PET images of various resolutions and sizes. A tutorial example of\nrunning Auto3DSeg on the HECKTOR22 challenge dataset is available in MONAI\nTutorials. The tutorial is based on [the HECKTOR22 challenge 1st place solution](https://arxiv.org/abs/2209.10809).\n- A new improved version of `Segresnet` Algo is now available in `AutoRunner`.\nIn this version, data caching is more efficient and the preprocessing transforms are more flexible.\nThe workflow progresses including the timings of steps are written to console output as well as a YAML file.\n- Automatic customization and optimization of the model training configuration\ncan be achieved according to the GPU devices used. The feature\nfocuses on determining parameters including batch size of model\ntraining and sliding-window inference, allocated devices for\ndata in sliding-window inference. For more details about how to enable it, please see [the tutorials](https://github.com/Project-MONAI/tutorials/tree/main/auto3dseg).\n\n## New models in MONAI Model Zoo\n\nNew pretrained models are being created and released [in the Model Zoo](https://project-monai.github.io/model-zoo.html).\nNotably,\n\n- The `mednist_reg` model demonstrates how to build image registration workflows in MONAI bundle\nformat. 
The model uses a ResNet and spatial transformer for hand X-ray image registration based on\n[the registration_mednist tutorial](https://github.com/Project-MONAI/tutorials/blob/main/2d_registration/registration_mednist.ipynb),\n- `pathology_nuclei_segmentation_and_classification`,\n  `pathology_nuclick_annotation`, and `pathology_nuclei_classification` bundles\n  are built for [digital pathology image\n  analysis](https://github.com/Project-MONAI/model-zoo/tree/dev/models/pathology_nuclei_segmentation_classification).\n\nFor more details about how to use the models, please see [the tutorials](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo).\n\n## State-of-the-art SurgToolLoc solution\n\n[SurgToolLoc](https://surgtoolloc.grand-challenge.org/Home/) is a part of the\n[EndoVis](https://endovis.grand-challenge.org/) challenge at [MICCAI 2022](https://conferences.miccai.org/2022/en/).\nThe challenge focuses on endoscopic video analysis and is divided into (1) fully supervised tool classification\nand (2) weakly supervised tool classification/localization.\nTeam NVIDIA won prizes by finishing [third](https://surgtoolloc.grand-challenge.org/results/) in both categories.\nThe core components of the solutions [are released in MONAI](https://github.com/Project-MONAI/tutorials/tree/main/competitions/MICCAI/surgtoolloc).\n"
  },
  {
    "path": "docs/source/whatsnew_1_2.md",
    "content": "# What's new in 1.2\n\n- Auto3DSeg enhancements and benchmarks\n- nnUNet integration\n- TensorRT-optimized networks\n- MetricsReloaded integration\n- Bundle workflow APIs\n- Modular patch inference\n- Lazy resampling for preprocessing\n\n## Auto3DSeg enhancements and benchmarks\nAuto3DSeg is an innovative solution for 3D medical image segmentation, leveraging the advancements in MONAI and GPUs for algorithm development and deployment.\nKey improvements in this release include:\n- Several new modules to the training pipelines, such as automated GPU-based hyperparameter scaling, early stopping mechanisms, and dynamic validation frequency.\n- Multi-GPU parallelism has been activated for all GPU-related components including data analysis, model training, and model ensemble, to augment overall performance and capabilities.\n- The algorithms were benchmarked for computational efficiency on the TotalSegmentator dataset, containing over 1,000 CT images.\n- Multi-node training is implemented, reducing model training time significantly.\n\n\n## nnUNet integration\nThe integration introduces a new class, `nnUNetV2Runner`, which leverages Python APIs to facilitate model training, validation,\nand ensemble, thereby simplifying the data conversion process for users.\nBenchmarking results from various public datasets confirm that nnUNetV2Runner performs as expected.\nUsers are required to prepare a data list and create an `input.yaml` file to install and use the system.\nThe framework also allows automatic execution of the entire nnU-Net pipeline, from model training to ensemble,\nwith options to specify the number of epochs. 
Users can access APIs for training, dataset conversion, data preprocessing, and other components.\nPlease check out [the tutorials](https://github.com/Project-MONAI/tutorials/tree/main/nnunet) for more details.\n\n## TensorRT-optimized networks\n[NVIDIA TensorRT](https://developer.nvidia.com/tensorrt) is an SDK for high-performance deep learning inference,\nincludes a deep learning inference optimizer and runtime that delivers low latency and high throughput for inference applications.\nIt can accelerate the deep learning model forward computation on the NVIDIA GPU.\nIn this release, the `trt_export` API to export the TensorRT engine-based TorchScript model has been integrated into the MONAI bundle.\nUsers can try to export bundles with it. A few bundles in the MONAI model zoo,\nlike the [spleen_ct_segmentation](https://github.com/Project-MONAI/model-zoo/tree/dev/models/spleen_ct_segmentation)\nand [endoscopic_tool_segmentation](https://github.com/Project-MONAI/model-zoo/tree/dev/models/endoscopic_tool_segmentation) bundles,\nhave already been exported and benchmarked. For more details about how to export and benchmark a model,\nplease go to this [tutorial](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/TensorRT_inference_acceleration.ipynb).\n\n\n## MetricsReloaded integration\nMetricsReloaded - a new recommendation framework for biomedical image analysis validation - is released publicly\nvia https://github.com/Project-MONAI/MetricsReloaded. Binary and categorical metrics computing modules are included in this release,\nusing MetricsReloaded as the backend. 
[Example scripts](https://github.com/Project-MONAI/tutorials/tree/main/modules/metrics_reloaded) are made available to demonstrate the usage.\n\n\n## Bundle workflow APIs\n`BundleWorkflow` abstracts the typical workflows (such as training, evaluation, and inference) of a bundle with three main interfaces:\n`initialize`, `run`, and `finalize`, applications use these APIs to execute a bundle.\nIt unifies the required properties and optional properties for the workflows, downstream applications\ncan invoke the components instead of parsing configs with keys.\nIn this release, `ConfigWorkflow` class is also created for JSON and YAML config-based bundle workflows for improved Pythonic usability.\n\n\n## Modular patch inference\nIn patch inference, patches are extracted from the image, the inference is run on those patches, and outputs are merged\nto construct the result image corresponding to the input image. Although depending on the task, model, and computational/memory resources,\nthe exact implementations of a patch inference may vary, the overall process of splitting, running inference, and merging the results remains the same.\nIn this release, we have created a modular design for patch inference, which defines the overall process while abstracting away the specific\nbehavior of how to split the image into patches, how to pre and post process each patch, and how to merge the output patches.\n\n## Lazy Resampling for preprocessing\nLazy Resampling is a new, experimental feature for preprocessing. It works under\nthe hood along with MONAI transforms to combine adjacent spatial and\ncropping transforms into a single operation. This allows MONAI to reduce the number of data resamples\n a pipeline undergoes. 
Depending on the preprocessing pipeline, it can potentially:\n\n* reduce processing time\n* reduce processing memory\n* reduce incidental artifacts added by resampling\n* preserve data that would otherwise be cropped and replaced with padding\n\nLazy Resampling pipelines can use a mixture of MONAI and non-MONAI transforms, so\nshould work with almost all existing pipelines simply by setting `lazy=True`\non MONAI `Compose` instances.  See the\n[Lazy Resampling topic](https://monai.readthedocs.io/en/stable/lazy_resampling.html)\nin the documentation for more details.\n"
  },
  {
    "path": "docs/source/whatsnew_1_3.md",
    "content": "# What's new in 1.3\n\n- Bundle usability enhancements\n- Integrating MONAI Generative into MONAI core\n\n\n## Bundle usability enhancements\n\nBased on the experience of building MONAI model zoo and the feedback from the community,\nMONAI 1.3 provides major enhancements in MONAI Bundle usability. These include:\n- Pythonic APIs for Bundle trying to strike a balance between code readability and workflow standardization;\n- Streamlined Bundle building processes with step-by-step guides to the concepts;\n- Various utility functions for fetching and fine-tuning models from [MONAI Model Zoo](https://github.com/Project-MONAI/model-zoo);\n- Various fixes for Bundle syntax and documentation, improved test coverage across the Bundle module and Model Zoo.\n\nFor more details please visit [the Bundle tutorials](https://github.com/Project-MONAI/tutorials/tree/main/bundle) and\n[the Model Zoo demos](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo).\n\n## Integrating MONAI Generative into MONAI Core\n\nMain modules developed at [MONAI GenerativeModels](https://github.com/Project-MONAI/GenerativeModels)\nare being ported into the core codebase, allowing for consistent maintenance and release of the key components for generative AI.\nAs a starting point, loss functions and metrics are integrated into this version.\n"
  },
  {
    "path": "docs/source/whatsnew_1_4.md",
    "content": "# What's new in 1.4\n\n- MAISI: state-of-the-art 3D Latent Diffusion Model\n- VISTA-3D: interactive foundation model for segmenting and anotating human anatomies\n- VISTA-2D: cell segmentation pipeline\n- Integrating MONAI Generative into MONAI core\n- Lazy TensorRT export via `trt_compile`\n- Geometric Data Support\n\n\n## MAISI: state-of-the-art 3D Latent Diffusion Model\n\n![maisi](../images/maisi_train.png)\n\nMAISI (Medical AI for Synthetic Imaging) is a state-of-the-art three-dimensional (3D) Latent Diffusion Model designed for generating high-quality synthetic CT images with or without anatomical annotations. This AI model excels in data augmentation and creating realistic medical imaging data to supplement limited datasets due to privacy concerns or rare conditions. It can also significantly enhance the performance of other medical imaging AI models by generating diverse and realistic training data.\n\nA tutorial for generating large CT images accompanied by corresponding segmentation masks using MAISI is provided within\n[`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/main/generation/maisi).\nIt contains the following features:\n- A foundation Variational Auto-Encoder (VAE) model for latent feature compression that works for both CT and MRI with flexible volume size and voxel size\n- A foundation Diffusion model that can generate large CT volumes up to 512 × 512 × 768 size, with flexible volume size and voxel size\n- A ControlNet to generate image/mask pairs that can improve downstream tasks, with controllable organ/tumor size\n\n## VISTA-3D: state-of-the-art 3D Latent Diffusion Model\n\n![vista-3d](../images/vista3d.png)\n\nVISTA-3D is a specialized interactive foundation model for 3D medical imaging. It excels in providing accurate and adaptable segmentation analysis across anatomies and modalities. 
Utilizing a multi-head architecture, VISTA-3D adapts to varying conditions and anatomical areas, helping guide users' annotation workflow.\n\nA tutorial showing how to finetune VISTA-3D on spleen dataset is provided within\n[`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/main/vista_3d).\nIt supports three core workflows:\n- Segment everything: Enables whole body exploration, crucial for understanding complex diseases affecting multiple organs and for holistic treatment planning.\n- Segment using class: Provides detailed sectional views based on specific classes, essential for targeted disease analysis or organ mapping, such as tumor identification in critical organs.\n- Segment point prompts: Enhances segmentation precision through user-directed, click-based selection. This interactive approach accelerates the creation of accurate ground-truth data, essential in medical imaging analysis.\n\n## VISTA-2D: cell segmentation pipeline\n\n![vista-2d](../images/vista2d.png)\n\nVISTA-2D is a comprehensive training and inference pipeline for cell segmentation in imaging applications. 
For more information, refer to this [Blog](https://developer.nvidia.com/blog/advancing-cell-segmentation-and-morphology-analysis-with-nvidia-ai-foundation-model-vista-2d/)\n\nKey features of the model include:\n- A robust deep learning algorithm utilizing transformers\n- Foundational model as compared to specialist models\n- Supports a wide variety of datasets and file formats\n- Capable of handling multiple imaging modalities\n- Multi-GPU and multinode training support\n\nA tutorial demonstrating how to train a cell segmentation model using the MONAI framework on the Cellpose dataset can be found in [`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/main/vista_2d).\n\n## Integrating MONAI Generative into MONAI Core\n\nKey modules originally developed in the [MONAI GenerativeModels](https://github.com/Project-MONAI/GenerativeModels) repository have been integrated into the core MONAI codebase. This integration ensures consistent maintenance and streamlined release of essential components for generative AI. In this version, all utilities, networks, diffusion schedulers, inferers, and engines have been migrated into the core codebase. 
Special care has been taken to ensure saved weights from models trained using GenerativeModels can be loaded into those now integrated into core.\n\nAdditionally, several tutorials have been ported and are available within [`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/main/generation)\n\n## Lazy TensorRT export via `trt_compile`\nThis release expands TensorRT optimization options for MONAI bundles with `trt_compile` API.\nThe existing `trt_export` API requires the user to run a separate export script to prepare a TensorRT engine-based TorchScript model.\n`trt_compile` builds and saves a TensorRT engine the first time a bundle is run and provides limited dependency support.\nIt also allows partial TensorRT export where only a certain submodule is being optimized, which improves usability.\nA few bundles in the MONAI model zoo, like the new [VISTA-3D](https://github.com/Project-MONAI/model-zoo/tree/dev/models/vista3d)\nand [VISTA-2D](https://github.com/Project-MONAI/model-zoo/tree/dev/models/vista2d) bundles, already come with `trt_inference.json` config files which use `trt_compile`.\n\n## Geometric Data Support\n\nMONAI introduces support for geometric data transformations as a key feature. As a starting point, ApplyTransformToPoints transform is added to facilitate matrix operations on points, enabling flexible and efficient handling of geometric transformations. Alongside this, the framework now supports conversions between boxes and points, providing seamless interoperability within detection pipelines. These updates have been integrated into existing pipelines, such as the [detection tutorial](https://github.com/Project-MONAI/tutorials/blob/main/detection) and the [3D registration workflow](https://github.com/Project-MONAI/tutorials/blob/main/3d_registration/learn2reg_nlst_paired_lung_ct.ipynb), leveraging the latest APIs for improved functionality.\n"
  },
  {
    "path": "docs/source/whatsnew_1_5.md",
    "content": "\n# What's new in 1.5\n\n- Support numpy 2.x and Pytorch 2.6\n- MAISI inference accelerate\n- Bundles storage changed to huggingface and correspoinding api updated in core\n- Ported remaining generative tutorials and bundles\n- New tutorials:\n  - [2d_regression/image_restoration.ipynb](https://github.com/Project-MONAI/tutorials/blob/main/2d_regression/image_restoration.ipynb)\n  - [generation/2d_diffusion_autoencoder/2d_diffusion_autoencoder_tutorial.ipynb](https://github.com/Project-MONAI/tutorials/blob/main/generation/2d_diffusion_autoencoder/2d_diffusion_autoencoder_tutorial.ipynb)\n  - [generation/3d_ddpm/3d_ddpm_tutorial.ipynb](https://github.com/Project-MONAI/tutorials/blob/main/generation/3d_ddpm/3d_ddpm_tutorial.ipynb)\n  - [generation/classifier_free_guidance/2d_ddpm_classifier_free_guidance_tutorial.ipynb](https://github.com/Project-MONAI/tutorials/blob/main/generation/classifier_free_guidance/2d_ddpm_classifier_free_guidance_tutorial.ipynb)\n  - [hugging_face/finetune_vista3d_for_hugging_face_pipeline.ipynb](https://github.com/Project-MONAI/tutorials/blob/main/hugging_face/finetune_vista3d_for_hugging_face_pipeline.ipynb)\n  - [hugging_face/hugging_face_pipeline_for_monai.ipynb](https://github.com/Project-MONAI/tutorials/blob/main/hugging_face/hugging_face_pipeline_for_monai.ipynb)\n  - [modules/omniverse/omniverse_integration.ipynb](https://github.com/Project-MONAI/tutorials/blob/main/modules/omniverse/omniverse_integration.ipynb)\n- New Bundles:\n  - [models/cxr_image_synthesis_latent_diffusion_model](https://github.com/Project-MONAI/model-zoo/blob/dev/models/cxr_image_synthesis_latent_diffusion_model)\n  - [models/mednist_ddpm](https://github.com/Project-MONAI/model-zoo/blob/dev/models/mednist_ddpm)\n  - [models/brain_image_synthesis_latent_diffusion_model](https://github.com/Project-MONAI/model-zoo/blob/dev/models/mednist_ddpm)\n  - 
[hf_models/exaonepath-crc-msi-predictor](https://github.com/Project-MONAI/model-zoo/blob/dev/hf_models/exaonepath-crc-msi-predictor)\n  - All existing bundles are also now [hosted on Huggingface](https://huggingface.co/MONAI)!\n\n## Supported Dependency Versions\n\nThis release adds support for NumPy 2.0 and PyTorch 2.6. We plan to add support for PyTorch 2.7 in an upcoming version once some compatibility issues have been addressed.\n\nAs stated in the updated [README.md](https://github.com/Project-MONAI/MONAI/blob/main/README.md) file, MONAI's policy for the support of dependency versions has been updated for clarity.\n\nMONAI will continue to support [currently supported versions of Python](https://devguide.python.org/versions), and for other dependencies the following apply:\n\n* Major releases of MONAI will have dependency versions stated for them. The current state of the `dev` branch in this repository is the unreleased development version of MONAI which typically will support current versions of dependencies and include updates and bug fixes to do so.\n* PyTorch support covers [the current version](https://github.com/pytorch/pytorch/releases) plus three previous minor versions. If compatibility issues with a PyTorch version and other dependencies arise, support for a version may be delayed until a major release.\n* Our support policy for other dependencies adheres for the most part to [SPEC0](https://scientific-python.org/specs/spec-0000), where dependency versions are supported where possible for up to two years. Discovered vulnerabilities or defects may require certain versions to be explicitly not supported.\n* See the `requirements*.txt` files for dependency version information.\n\n## MAISI Update: Introducing MAISI Version maisi3d-rflow\n\n![maisi](../images/maisi_infer.png)\n\nWe are excited to announce the release of MAISI Version _maisi3d-rflow_. 
This update brings significant improvements over the previous version, _maisi3d-ddpm_, with a remarkable 33x acceleration in latent diffusion model inference speed. The MAISI VAE remains unchanged. Here are the key differences:\n  1. Scheduler Update:\n\n     * _maisi3d-ddpm_: Uses the basic DDPM noise scheduler.\n\n     * _maisi3d-rflow_: Introduces the Rectified Flow scheduler, allowing diffusion model inference to be 33 times faster.\n  2. Training Data Preparation:\n\n     * _maisi3d-ddpm_: Requires training images to be labeled with body regions (specifically “top_region_index” and “bottom_region_index”).\n\n     * _maisi3d-rflow_: No such labeling is required, making it easier to prepare the training data.\n  3. Image Quality:\n\n     * For the released model weights, _maisi3d-rflow_ generates better-quality images for head regions and smaller output volumes compared to _maisi3d-ddpm_. For other regions, the image quality is comparable.\n  4. Modality Input:\n\n     * _maisi3d-rflow_ adds a new modality input to the diffusion model, offering flexibility for future extensions to other modalities. Currently, this input is set to always equal 1, as this version supports CT generation exclusively.\n"
  },
  {
    "path": "docs/source/whatsnew_1_5_1.md",
    "content": "\n# What's new in 1.5.1\n\nThis is a minor update for MONAI to address security concerns and improve compatibility with the newest PyTorch release.\n\nWith the upgrade support for PyTorch 2.8, MONAI now directly support NVIDIA GeForce RTX 50 series GPUs and other Blackwell-based GPUs!\n\n- Support up to PyTorch 2.8.\n- Security fixes to address advisories [GHSA-x6ww-pf9m-m73m](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-x6ww-pf9m-m73m), [GHSA-6vm5-6jv9-rjpj](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-6vm5-6jv9-rjpj), and [GHSA-p8cm-mm2v-gwjm](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-p8cm-mm2v-gwjm),\n- Updated version of supported Huggingface Transformers library to address security advisories raised for it.\n- Updated Torchvision pretrained network loading to use current arguments.\n- Many minor fixes to identified issues, see release notes for details on merged PRs.\n"
  },
  {
    "path": "docs/source/whatsnew_1_5_2.md",
    "content": "\n# What's new in 1.5.2 🎉🎉\n\nThis is a minor update for MONAI to address a security concern.\n\n- Security fix to address advisory [GHSA-9rg3-9pvr-6p27](https://github.com/Project-MONAI/MONAI/security/advisories/GHSA-9rg3-9pvr-6p27).\n"
  },
  {
    "path": "environment-dev.yml",
    "content": "name: monai\nchannels:\n  - pytorch\n  - defaults\n  - nvidia\n  - conda-forge\ndependencies:\n  - numpy>=1.24,<3.0\n  - pytorch>=2.3.0\n  - torchio\n  - torchvision\n  - pytorch-cuda>=11.6\n  - pip\n  - pip:\n    - -r requirements-dev.txt\n"
  },
  {
    "path": "monai/README.md",
    "content": "# MONAI\n\n* **apps**: high level medical domain specific deep learning applications.\n\n* **auto3dseg**: automated machine learning (AutoML) components for volumetric image analysis.\n\n* **bundle**: components to build the portable self-descriptive model bundle.\n\n* **config**: for system configuration and diagnostic output.\n\n* **csrc**: for C++/CUDA extensions.\n\n* **data**: for the datasets, readers/writers, and synthetic data.\n\n* **engines**: engine-derived classes for extending Ignite behaviour.\n\n* **fl**: federated learning components to allow pipeline integration with any federated learning framework.\n\n* **handlers**: defines handlers for implementing functionality at various stages in the training process.\n\n* **inferers**: defines model inference methods.\n\n* **losses**: classes defining loss functions, which follow the pattern of `torch.nn.modules.loss`.\n\n* **metrics**: defines metric tracking types.\n\n* **networks**: contains network definitions, component definitions, and Pytorch specific utilities.\n\n* **optimizers**: classes defining optimizers, which follow the pattern of `torch.optim`.\n\n* **transforms**: defines data transforms for preprocessing and postprocessing.\n\n* **utils**: generic utilities intended to be implemented in pure Python or using Numpy,\nand not with Pytorch, such as namespace aliasing, auto module loading.\n\n* **visualize**: utilities for data visualization.\n\n* **_extensions**: C++/CUDA extensions to be loaded in a just-in-time manner using `torch.utils.cpp_extension.load`.\n"
  },
  {
    "path": "monai/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport sys\nimport warnings\n\nfrom ._version import get_versions\n\nold_showwarning = warnings.showwarning\n\n\ndef custom_warning_handler(message, category, filename, lineno, file=None, line=None):\n    ignore_files = [\"ignite/handlers/checkpoint\", \"modelopt/torch/quantization/tensor_quant\"]\n    if any(ignore in filename for ignore in ignore_files):\n        return\n    old_showwarning(message, category, filename, lineno, file, line)\n\n\nclass DeprecatedTypesWarningFilter(logging.Filter):\n    def filter(self, record):\n        message_bodies_to_ignore = [\n            \"np.bool8\",\n            \"np.object0\",\n            \"np.int0\",\n            \"np.uint0\",\n            \"np.void0\",\n            \"np.str0\",\n            \"np.bytes0\",\n            \"@validator\",\n            \"@root_validator\",\n            \"class-based `config`\",\n            \"pkg_resources\",\n            \"Implicitly cleaning up\",\n        ]\n        for message in message_bodies_to_ignore:\n            if message in record.getMessage():\n                return False\n        return True\n\n\n# workaround for https://github.com/Project-MONAI/MONAI/issues/8060\n# TODO: remove this workaround after upstream fixed the warning\n# Set the custom warning handler to filter warning\nwarnings.showwarning = 
custom_warning_handler\n# Get the logger for warnings and add the filter to the logger\nlogging.getLogger(\"py.warnings\").addFilter(DeprecatedTypesWarningFilter())\n\n\nPY_REQUIRED_MAJOR = 3\nPY_REQUIRED_MINOR = 9\n\nversion_dict = get_versions()\n__version__: str = version_dict.get(\"version\", \"0+unknown\")\n__revision_id__: str = version_dict.get(\"full-revisionid\")\ndel get_versions, version_dict\n\n__copyright__ = \"(c) MONAI Consortium\"\n\n__basedir__ = os.path.dirname(__file__)\n\nif sys.version_info.major != PY_REQUIRED_MAJOR or sys.version_info.minor < PY_REQUIRED_MINOR:\n    import warnings\n\n    warnings.warn(\n        f\"MONAI requires Python {PY_REQUIRED_MAJOR}.{PY_REQUIRED_MINOR} or higher. \"\n        f\"But the current Python is: {sys.version}\",\n        category=RuntimeWarning,\n    )\n\n\nfrom .utils.module import load_submodules  # noqa: E402\n\n# handlers_* have some external decorators the users may not have installed\n# *.so files and folder \"_C\" may not exist when the cpp extensions are not compiled\nexcludes = \"|\".join(\n    [\n        \"(^(monai.handlers))\",\n        \"(^(monai.bundle))\",\n        \"(^(monai.fl))\",\n        \"((\\\\.so)$)\",\n        \"(^(monai._C))\",\n        \"(.*(__main__)$)\",\n        \"(.*(video_dataset)$)\",\n        \"(.*(nnunet).*$)\",\n    ]\n)\n\n# load directory modules only, skip loading individual files\nload_submodules(sys.modules[__name__], False, exclude_pattern=excludes)\n\n# load all modules, this will trigger all export decorations\nload_submodules(sys.modules[__name__], True, exclude_pattern=excludes)\n\n__all__ = [\n    \"apps\",\n    \"auto3dseg\",\n    \"bundle\",\n    \"config\",\n    \"data\",\n    \"engines\",\n    \"fl\",\n    \"handlers\",\n    \"inferers\",\n    \"losses\",\n    \"metrics\",\n    \"networks\",\n    \"optimizers\",\n    \"transforms\",\n    \"utils\",\n    \"visualize\",\n]\n\ntry:\n    from .utils.tf32 import detect_default_tf32\n\n    detect_default_tf32()\n    
import torch\n\n    # workaround related to https://github.com/Project-MONAI/MONAI/issues/7575\n    if hasattr(torch.cuda.device_count, \"cache_clear\"):\n        torch.cuda.device_count.cache_clear()\nexcept BaseException:\n    from .utils.misc import MONAIEnvVars\n\n    if MONAIEnvVars.debug():\n        raise\n"
  },
  {
    "path": "monai/_extensions/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .loader import load_module\n"
  },
  {
    "path": "monai/_extensions/gmm/gmm.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <torch/extension.h>\n\n#include \"gmm.h\"\n\npy::tuple init() {\n  torch::Tensor gmm_tensor =\n      torch::zeros({GMM_COUNT, GMM_COMPONENT_COUNT}, torch::dtype(torch::kFloat32).device(torch::kCUDA));\n  torch::Tensor scratch_tensor = torch::empty({1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));\n  return py::make_tuple(gmm_tensor, scratch_tensor);\n}\n\nvoid learn(\n    torch::Tensor gmm_tensor,\n    torch::Tensor scratch_tensor,\n    torch::Tensor input_tensor,\n    torch::Tensor label_tensor) {\n  c10::DeviceType device_type = input_tensor.device().type();\n\n  unsigned int batch_count = input_tensor.size(0);\n  unsigned int element_count = input_tensor.stride(1);\n\n  unsigned int scratch_size =\n      batch_count * (element_count + GMM_COMPONENT_COUNT * GMM_COUNT * (element_count / (32 * 32)));\n\n  if (scratch_tensor.size(0) < scratch_size) {\n    scratch_tensor.resize_({scratch_size});\n  }\n\n  float* gmm = gmm_tensor.data_ptr<float>();\n  float* scratch = scratch_tensor.data_ptr<float>();\n  float* input = input_tensor.data_ptr<float>();\n  int* labels = label_tensor.data_ptr<int>();\n\n  if (device_type == torch::kCUDA) {\n    learn_cuda(input, labels, gmm, scratch, batch_count, element_count);\n  } else {\n    learn_cpu(input, labels, gmm, scratch, batch_count, element_count);\n  }\n}\n\ntorch::Tensor apply(torch::Tensor gmm_tensor, torch::Tensor 
input_tensor) {\n  c10::DeviceType device_type = input_tensor.device().type();\n\n  unsigned int dim = input_tensor.dim();\n  unsigned int batch_count = input_tensor.size(0);\n  unsigned int element_count = input_tensor.stride(1);\n\n  auto output_size = input_tensor.sizes().vec();\n  output_size[1] = MIXTURE_COUNT;\n  torch::Tensor output_tensor =\n      torch::empty(c10::IntArrayRef(output_size), torch::dtype(torch::kFloat32).device(device_type));\n\n  const float* gmm = gmm_tensor.data_ptr<float>();\n  const float* input = input_tensor.data_ptr<float>();\n  float* output = output_tensor.data_ptr<float>();\n\n  if (device_type == torch::kCUDA) {\n    apply_cuda(gmm, input, output, batch_count, element_count);\n  } else {\n    apply_cpu(gmm, input, output, batch_count, element_count);\n  }\n\n  return output_tensor;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"init\", torch::wrap_pybind_function(init));\n  m.def(\"learn\", torch::wrap_pybind_function(learn));\n  m.def(\"apply\", torch::wrap_pybind_function(apply));\n}\n"
  },
  {
    "path": "monai/_extensions/gmm/gmm.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#if !defined(CHANNEL_COUNT) || !defined(MIXTURE_COUNT) || !defined(MIXTURE_SIZE)\n#error Definition of CHANNEL_COUNT, MIXTURE_COUNT, and MIXTURE_SIZE required\n#endif\n\n#if CHANNEL_COUNT < 1 || MIXTURE_COUNT < 1 || MIXTURE_SIZE < 1\n#error CHANNEL_COUNT, MIXTURE_COUNT, and MIXTURE_SIZE must be positive\n#endif\n\n#define MATRIX_COMPONENT_COUNT ((CHANNEL_COUNT + 1) * (CHANNEL_COUNT + 2) / 2)\n#define SUB_MATRIX_COMPONENT_COUNT (CHANNEL_COUNT * (CHANNEL_COUNT + 1) / 2)\n#define GMM_COMPONENT_COUNT (MATRIX_COMPONENT_COUNT + 1)\n#define GMM_COUNT (MIXTURE_COUNT * MIXTURE_SIZE)\n\nvoid learn_cpu(\n    const float* input,\n    const int* labels,\n    float* gmm,\n    float* scratch_memory,\n    unsigned int batch_count,\n    unsigned int element_count);\nvoid apply_cpu(\n    const float* gmm,\n    const float* input,\n    float* output,\n    unsigned int batch_count,\n    unsigned int element_count);\n\nvoid learn_cuda(\n    const float* input,\n    const int* labels,\n    float* gmm,\n    float* scratch_memory,\n    unsigned int batch_count,\n    unsigned int element_count);\nvoid apply_cuda(\n    const float* gmm,\n    const float* input,\n    float* output,\n    unsigned int batch_count,\n    unsigned int element_count);\n"
  },
  {
    "path": "monai/_extensions/gmm/gmm_cpu.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <stdexcept>\n\n#include \"gmm.h\"\n\nvoid learn_cpu(\n    const float* input,\n    const int* labels,\n    float* gmm,\n    float* scratch_memory,\n    unsigned int batch_count,\n    unsigned int element_count) {\n  throw std::invalid_argument(\"GMM received a cpu tensor but is not yet implemented for the cpu\");\n}\n\nvoid apply_cpu(\n    const float* gmm,\n    const float* input,\n    float* output,\n    unsigned int batch_count,\n    unsigned int element_count) {\n  throw std::invalid_argument(\"GMM received a cpu tensor but is not yet implemented for the cpu\");\n}\n"
  },
  {
    "path": "monai/_extensions/gmm/gmm_cuda.cu",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include \"gmm.h\"\n\n#include \"gmm_cuda_linalg.cuh\"\n\n#define EPSILON 1e-5\n#define BLOCK_SIZE 32\n#define TILE(SIZE, STRIDE) ((((SIZE)-1) / (STRIDE)) + 1)\n#ifdef __HIP_PLATFORM_AMD__\n#define __SHFL_DOWN(a, b) __shfl_down(a, b)\n#define __SHFL_XOR(a, b) __shfl_xor(a, b)\n#else\n#define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff, a, b)\n#define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b)\n#endif\n\ntemplate <int warp_count, int load_count>\n__global__ void CovarianceReductionKernel(\n    int gaussian_index,\n    const float* g_image,\n    const int* g_alpha,\n    float* g_matrices,\n    int element_count) {\n  constexpr int block_size = warp_count * 32;\n\n  __shared__ float s_matrix_component[warp_count];\n\n  int batch_index = blockIdx.z;\n\n  const float* g_batch_image = g_image + batch_index * element_count * CHANNEL_COUNT;\n  const int* g_batch_alpha = g_alpha + batch_index * element_count;\n  float* g_batch_matrices = g_matrices + batch_index * GMM_COUNT * GMM_COMPONENT_COUNT * gridDim.x;\n\n  int local_index = threadIdx.x;\n  int block_index = blockIdx.x;\n  int warp_index = local_index >> 5;\n  int lane_index = local_index & 31;\n  int global_index = local_index + block_index * block_size * load_count;\n  int matrix_offset = (gaussian_index * gridDim.x + block_index) * GMM_COMPONENT_COUNT;\n\n  float 
matrix[MATRIX_COMPONENT_COUNT];\n\n  for (int i = 0; i < MATRIX_COMPONENT_COUNT; i++) {\n    matrix[i] = 0;\n  }\n\n  for (int load = 0; load < load_count; load++) {\n    global_index += load * block_size;\n\n    if (global_index < element_count) {\n      int my_alpha = g_batch_alpha[global_index];\n\n      if (my_alpha != -1) {\n        if (gaussian_index == (my_alpha & 15) + (my_alpha >> 4) * MIXTURE_COUNT) {\n          float feature[CHANNEL_COUNT + 1];\n\n          feature[0] = 1;\n\n          for (int i = 0; i < CHANNEL_COUNT; i++) {\n            feature[i + 1] = g_batch_image[global_index + i * element_count];\n          }\n\n          for (int index = 0, i = 0; i < CHANNEL_COUNT + 1; i++) {\n            for (int j = i; j < CHANNEL_COUNT + 1; j++, index++) {\n              matrix[index] += feature[i] * feature[j];\n            }\n          }\n        }\n      }\n    }\n  }\n\n  __syncthreads();\n\n  for (int i = 0; i < MATRIX_COMPONENT_COUNT; i++) {\n    float matrix_component = matrix[i];\n    matrix_component += __SHFL_DOWN(matrix_component, 16);\n    matrix_component += __SHFL_DOWN(matrix_component, 8);\n    matrix_component += __SHFL_DOWN(matrix_component, 4);\n    matrix_component += __SHFL_DOWN(matrix_component, 2);\n    matrix_component += __SHFL_DOWN(matrix_component, 1);\n    if (lane_index == 0) {\n      s_matrix_component[warp_index] = matrix_component;\n    }\n\n    __syncthreads();\n\n    if (warp_index == 0) {\n      matrix_component = s_matrix_component[lane_index];\n      if (warp_count >= 32) {\n        matrix_component += __SHFL_DOWN(matrix_component, 16);\n      }\n      if (warp_count >= 16) {\n        matrix_component += __SHFL_DOWN(matrix_component, 8);\n      }\n      if (warp_count >= 8) {\n        matrix_component += __SHFL_DOWN(matrix_component, 4);\n      }\n      if (warp_count >= 4) {\n        matrix_component += __SHFL_DOWN(matrix_component, 2);\n      }\n      if (warp_count >= 2) {\n        matrix_component += 
__SHFL_DOWN(matrix_component, 1);\n      }\n      if (lane_index == 0) {\n        g_batch_matrices[matrix_offset + i] = matrix_component;\n      }\n    }\n\n    __syncthreads();\n  }\n}\n\ntemplate <int warp_count, bool invert_matrix>\n__global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_gmm, int matrix_count) {\n  constexpr int block_size = warp_count * 32;\n\n  __shared__ float s_matrix_component[warp_count];\n  __shared__ float s_gmm[GMM_COMPONENT_COUNT];\n\n  int batch_index = blockIdx.z;\n\n  const float* g_batch_matrices = g_matrices + batch_index * GMM_COUNT * GMM_COMPONENT_COUNT * matrix_count;\n  float* g_batch_gmm = g_gmm + batch_index * GMM_COUNT * GMM_COMPONENT_COUNT;\n\n  int local_index = threadIdx.x;\n  int warp_index = local_index >> 5;\n  int lane_index = local_index & 31;\n  int gmm_index = blockIdx.x;\n  int matrix_offset = gmm_index * matrix_count;\n\n  int load_count = TILE(matrix_count, block_size);\n\n  float norm_factor = 1.0f;\n\n  for (int index = 0, i = 0; i < CHANNEL_COUNT + 1; i++) {\n    for (int j = i; j < CHANNEL_COUNT + 1; j++, index++) {\n      float matrix_component = 0.0f;\n\n      for (int load = 0; load < load_count; load++) {\n        int matrix_index = local_index + load * block_size;\n\n        if (matrix_index < matrix_count) {\n          matrix_component += g_batch_matrices[(matrix_offset + matrix_index) * GMM_COMPONENT_COUNT + index];\n        }\n      }\n      matrix_component += __SHFL_DOWN(matrix_component, 16);\n      matrix_component += __SHFL_DOWN(matrix_component, 8);\n      matrix_component += __SHFL_DOWN(matrix_component, 4);\n      matrix_component += __SHFL_DOWN(matrix_component, 2);\n      matrix_component += __SHFL_DOWN(matrix_component, 1);\n      if (lane_index == 0) {\n        s_matrix_component[warp_index] = matrix_component;\n      }\n\n      __syncthreads();\n\n      if (warp_index == 0) {\n        matrix_component = s_matrix_component[lane_index];\n        if (warp_count >= 
32) {\n          matrix_component += __SHFL_DOWN(matrix_component, 16);\n        }\n        if (warp_count >= 16) {\n          matrix_component += __SHFL_DOWN(matrix_component, 8);\n        }\n        if (warp_count >= 8) {\n          matrix_component += __SHFL_DOWN(matrix_component, 4);\n        }\n        if (warp_count >= 4) {\n          matrix_component += __SHFL_DOWN(matrix_component, 2);\n        }\n        if (warp_count >= 2) {\n          matrix_component += __SHFL_DOWN(matrix_component, 1);\n        }\n        if (lane_index == 0) {\n          float constant = i == 0 ? 0.0f : s_gmm[i] * s_gmm[j];\n\n          if (i != 0 && i == j) {\n            constant -= EPSILON;\n          }\n\n          s_gmm[index] = norm_factor * matrix_component - constant;\n\n          if (index == 0 && matrix_component > 0) {\n            norm_factor = 1.0f / matrix_component;\n          }\n        }\n      }\n\n      __syncthreads();\n    }\n  }\n\n  float* matrix = s_gmm + (CHANNEL_COUNT + 1);\n  float* det_ptr = s_gmm + MATRIX_COMPONENT_COUNT;\n\n  if (local_index == 0) {\n    float square_mat[CHANNEL_COUNT][CHANNEL_COUNT];\n    float cholesky_mat[CHANNEL_COUNT][CHANNEL_COUNT];\n\n    for (int i = 0; i < CHANNEL_COUNT; i++) {\n      for (int j = 0; j < CHANNEL_COUNT; j++) {\n        square_mat[i][j] = 0.0f;\n        cholesky_mat[i][j] = 0.0f;\n      }\n    }\n\n    to_square(matrix, square_mat);\n    cholesky(square_mat, cholesky_mat);\n\n    *det_ptr = chol_det(cholesky_mat);\n\n    if (invert_matrix) {\n      chol_inv(cholesky_mat, square_mat);\n      to_triangle(square_mat, matrix);\n    }\n  }\n\n  if (local_index < GMM_COMPONENT_COUNT) {\n    g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + local_index] = s_gmm[local_index];\n  }\n}\n\nstruct GMMSplit_t {\n  int idx;\n  float threshold;\n  float eigenvector[CHANNEL_COUNT];\n};\n\n// 1 Block, 32xMIXTURE_COUNT\n__global__ void GMMFindSplit(GMMSplit_t* gmmSplit, int gmmK, float* gmm) {\n  int batch_index = blockIdx.z;\n\n  
float* g_batch_gmm = gmm + batch_index * GMM_COUNT * GMM_COMPONENT_COUNT;\n  GMMSplit_t* g_batch_gmmSplit = gmmSplit + batch_index * MIXTURE_COUNT;\n\n  int gmm_idx = threadIdx.x * MIXTURE_COUNT + threadIdx.y;\n\n  float eigenvalue = 0;\n  float eigenvector[CHANNEL_COUNT];\n\n  if (threadIdx.x < gmmK) {\n    float* matrix = g_batch_gmm + gmm_idx * GMM_COMPONENT_COUNT + (CHANNEL_COUNT + 1);\n    largest_eigenpair(matrix, eigenvector, &eigenvalue);\n  }\n\n  float max_value = eigenvalue;\n  max_value = max(max_value, __SHFL_XOR(max_value, 16));\n  max_value = max(max_value, __SHFL_XOR(max_value, 8));\n  max_value = max(max_value, __SHFL_XOR(max_value, 4));\n  max_value = max(max_value, __SHFL_XOR(max_value, 2));\n  max_value = max(max_value, __SHFL_XOR(max_value, 1));\n  if (max_value == eigenvalue) {\n    GMMSplit_t split;\n\n    float* average_feature = gmm + gmm_idx * GMM_COMPONENT_COUNT + 1;\n\n    split.idx = threadIdx.x;\n    split.threshold = scalar_prod(average_feature, eigenvector);\n\n    for (int i = 0; i < CHANNEL_COUNT; i++) {\n      split.eigenvector[i] = eigenvector[i];\n    }\n\n    g_batch_gmmSplit[threadIdx.y] = split;\n  }\n}\n\n#define DO_SPLIT_DEGENERACY 4\n\n__global__ void GMMDoSplit(const GMMSplit_t* gmmSplit, int k, const float* image, int* alpha, int element_count) {\n  __shared__ GMMSplit_t s_gmmSplit[MIXTURE_COUNT];\n\n  int batch_index = blockIdx.z;\n\n  const GMMSplit_t* g_batch_gmmSplit = gmmSplit + batch_index * MIXTURE_COUNT;\n  const float* g_batch_image = image + batch_index * element_count * CHANNEL_COUNT;\n  int* g_batch_alpha = alpha + batch_index * element_count;\n\n  int* s_linear = (int*)s_gmmSplit;\n  int* g_linear = (int*)g_batch_gmmSplit;\n\n  if (threadIdx.x < MIXTURE_COUNT * sizeof(GMMSplit_t)) {\n    s_linear[threadIdx.x] = g_linear[threadIdx.x];\n  }\n\n  __syncthreads();\n\n  int index = threadIdx.x + blockIdx.x * BLOCK_SIZE * DO_SPLIT_DEGENERACY;\n\n  for (int i = 0; i < DO_SPLIT_DEGENERACY; i++) {\n    index += 
BLOCK_SIZE;\n\n    if (index < element_count) {\n      int my_alpha = g_batch_alpha[index];\n\n      if (my_alpha != -1) {\n        int select = my_alpha & 15;\n        int gmm_idx = my_alpha >> 4;\n\n        if (gmm_idx == s_gmmSplit[select].idx) {\n          // in the split cluster now\n          float feature[CHANNEL_COUNT];\n\n          for (int i = 0; i < CHANNEL_COUNT; i++) {\n            feature[i] = g_batch_image[index + i * element_count];\n          }\n\n          float value = scalar_prod(s_gmmSplit[select].eigenvector, feature);\n\n          if (value > s_gmmSplit[select].threshold) {\n            // assign pixel to new cluster\n            g_batch_alpha[index] = k + select;\n          }\n        }\n      }\n    }\n  }\n}\n\n// Single block, 32xMIXTURE_COUNT\n__global__ void GMMcommonTerm(float* g_gmm) {\n  int batch_index = blockIdx.z;\n\n  float* g_batch_gmm = g_gmm + batch_index * GMM_COUNT * GMM_COMPONENT_COUNT;\n\n  int gmm_index = (threadIdx.x * MIXTURE_COUNT) + threadIdx.y;\n\n  float gmm_n = threadIdx.x < MIXTURE_SIZE ? g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT] : 0.0f;\n\n  float sum = gmm_n;\n  sum += __SHFL_XOR(sum, 1);\n  sum += __SHFL_XOR(sum, 2);\n  sum += __SHFL_XOR(sum, 4);\n  sum += __SHFL_XOR(sum, 8);\n  sum += __SHFL_XOR(sum, 16);\n\n  if (threadIdx.x < MIXTURE_SIZE) {\n    float det = g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] + EPSILON;\n    float commonTerm = det > 0.0f ? 
gmm_n / (sqrtf(det) * sum) : gmm_n / sum;\n\n    g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] = commonTerm;\n  }\n}\n\n__device__ float GMMTerm(float* feature, const float* gmm) {\n  const float* average_feature = gmm + 1;\n  const float* matrix = gmm + CHANNEL_COUNT + 1;\n\n  float diff[CHANNEL_COUNT];\n\n  for (int i = 0; i < CHANNEL_COUNT; i++) {\n    diff[i] = feature[i] - average_feature[i];\n  }\n\n  float value = 0.0f;\n\n  for (int index = 0, i = 0; i < CHANNEL_COUNT; i++) {\n    for (int j = i; j < CHANNEL_COUNT; j++, index++) {\n      float term = diff[i] * diff[j] * matrix[index];\n\n      value += i == j ? term : 2 * term;\n    }\n  }\n\n  return gmm[MATRIX_COMPONENT_COUNT] * expf(-0.5 * value);\n}\n\n__global__ void GMMDataTermKernel(const float* image, const float* gmm, float* output, int element_count) {\n  int batch_index = blockIdx.z;\n\n  const float* g_batch_image = image + batch_index * element_count * CHANNEL_COUNT;\n  const float* g_batch_gmm = gmm + batch_index * GMM_COUNT * GMM_COMPONENT_COUNT;\n  float* g_batch_output = output + batch_index * element_count * MIXTURE_COUNT;\n\n  int index = blockIdx.x * blockDim.x + threadIdx.x;\n\n  if (index >= element_count)\n    return;\n\n  float feature[CHANNEL_COUNT];\n\n  for (int i = 0; i < CHANNEL_COUNT; i++) {\n    feature[i] = g_batch_image[index + i * element_count];\n  }\n\n  float weights[MIXTURE_COUNT];\n  float weight_total = 0.0f;\n\n  for (int i = 0; i < MIXTURE_COUNT; i++) {\n    float mixture_weight = 0.0f;\n\n    for (int j = 0; j < MIXTURE_SIZE; j++) {\n      mixture_weight += GMMTerm(feature, &g_batch_gmm[(MIXTURE_COUNT * j + i) * GMM_COMPONENT_COUNT]);\n    }\n\n    weights[i] = mixture_weight;\n    weight_total += mixture_weight;\n  }\n\n  for (int i = 0; i < MIXTURE_COUNT; i++) {\n    // protecting against pixels with 0 in all mixtures\n    float final_weight = weight_total > 0.0f ? 
weights[i] / weight_total : 0.0f;\n    g_batch_output[index + i * element_count] = final_weight;\n  }\n}\n\n#define THREADS 512\n#define WARPS 16\n#define BLOCK (WARPS << 5)\n#define LOAD 4\n\nvoid GMMInitialize(\n    const float* image,\n    int* alpha,\n    float* gmm,\n    float* scratch_mem,\n    unsigned int batch_count,\n    unsigned int element_count) {\n  unsigned int block_count = TILE(element_count, BLOCK * LOAD);\n\n  float* block_gmm_scratch = scratch_mem;\n  GMMSplit_t* gmm_split_scratch = (GMMSplit_t*)scratch_mem;\n\n  int gmm_N = MIXTURE_COUNT * MIXTURE_SIZE;\n\n  for (unsigned int k = MIXTURE_COUNT; k < gmm_N; k += MIXTURE_COUNT) {\n    for (unsigned int i = 0; i < k; ++i) {\n      CovarianceReductionKernel<WARPS, LOAD>\n          <<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);\n    }\n\n    CovarianceFinalizationKernel<WARPS, false><<<dim3(k, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);\n\n    GMMFindSplit<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(\n        gmm_split_scratch, k / MIXTURE_COUNT, gmm);\n    GMMDoSplit<<<dim3(TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count), BLOCK_SIZE>>>(\n        gmm_split_scratch, (k / MIXTURE_COUNT) << 4, image, alpha, element_count);\n  }\n}\n\nvoid GMMUpdate(\n    const float* image,\n    int* alpha,\n    float* gmm,\n    float* scratch_mem,\n    unsigned int batch_count,\n    unsigned int element_count) {\n  unsigned int block_count = TILE(element_count, BLOCK * LOAD);\n\n  float* block_gmm_scratch = scratch_mem;\n\n  unsigned int gmm_N = MIXTURE_COUNT * MIXTURE_SIZE;\n\n  for (unsigned int i = 0; i < gmm_N; ++i) {\n    CovarianceReductionKernel<WARPS, LOAD>\n        <<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);\n  }\n\n  CovarianceFinalizationKernel<WARPS, true>\n      <<<dim3(gmm_N, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, 
block_count);\n\n  GMMcommonTerm<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);\n}\n\nvoid GMMDataTerm(\n    const float* image,\n    const float* gmm,\n    float* output,\n    unsigned int batch_count,\n    unsigned int element_count) {\n  dim3 block(BLOCK_SIZE, 1);\n  dim3 grid(TILE(element_count, BLOCK_SIZE), 1, batch_count);\n\n  GMMDataTermKernel<<<grid, block>>>(image, gmm, output, element_count);\n}\n\nvoid learn_cuda(\n    const float* input,\n    const int* labels,\n    float* gmm,\n    float* scratch_memory,\n    unsigned int batch_count,\n    unsigned int element_count) {\n  int* alpha = (int*)scratch_memory;\n  float* scratch_mem = scratch_memory + batch_count * element_count;\n\n  cudaMemcpyAsync(alpha, labels, batch_count * element_count * sizeof(int), cudaMemcpyDeviceToDevice);\n\n  GMMInitialize(input, alpha, gmm, scratch_mem, batch_count, element_count);\n  GMMUpdate(input, alpha, gmm, scratch_mem, batch_count, element_count);\n}\n\nvoid apply_cuda(\n    const float* gmm,\n    const float* input,\n    float* output,\n    unsigned int batch_count,\n    unsigned int element_count) {\n  GMMDataTerm(input, gmm, output, batch_count, element_count);\n}\n"
  },
  {
    "path": "monai/_extensions/gmm/gmm_cuda_linalg.cuh",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n__device__ void to_square(float in[SUB_MATRIX_COMPONENT_COUNT], float out[CHANNEL_COUNT][CHANNEL_COUNT]) {\n  for (int index = 0, i = 0; i < CHANNEL_COUNT; i++) {\n    for (int j = i; j < CHANNEL_COUNT; j++, index++) {\n      out[i][j] = in[index];\n      out[j][i] = in[index];\n    }\n  }\n}\n\n__device__ void to_triangle(float in[CHANNEL_COUNT][CHANNEL_COUNT], float out[SUB_MATRIX_COMPONENT_COUNT]) {\n  for (int index = 0, i = 0; i < CHANNEL_COUNT; i++) {\n    for (int j = i; j < CHANNEL_COUNT; j++, index++) {\n      out[index] = in[j][i];\n    }\n  }\n}\n\n__device__ void cholesky(float in[CHANNEL_COUNT][CHANNEL_COUNT], float out[CHANNEL_COUNT][CHANNEL_COUNT]) {\n  for (int i = 0; i < CHANNEL_COUNT; i++) {\n    for (int j = 0; j < i + 1; j++) {\n      float sum = 0.0f;\n\n      for (int k = 0; k < j; k++) {\n        sum += out[i][k] * out[j][k];\n      }\n\n      if (i == j) {\n        out[i][j] = sqrtf(in[i][i] - sum);\n      } else {\n        out[i][j] = (in[i][j] - sum) / out[j][j];\n      }\n    }\n  }\n}\n\n__device__ float chol_det(float in[CHANNEL_COUNT][CHANNEL_COUNT]) {\n  float det = 1.0f;\n\n  for (int i = 0; i < CHANNEL_COUNT; i++) {\n    det *= in[i][i];\n  }\n\n  return det * det;\n}\n\n__device__ void chol_inv(float in[CHANNEL_COUNT][CHANNEL_COUNT], float out[CHANNEL_COUNT][CHANNEL_COUNT]) {\n  // Invert cholesky matrix\n  for (int i = 0; i < CHANNEL_COUNT; i++) {\n 
   in[i][i] = 1.0f / (in[i][i] + 0.0001f);\n\n    for (int j = 0; j < i; j++) {\n      float sum = 0.0f;\n\n      for (int k = j; k < i; k++) {\n        sum += in[i][k] * in[k][j];\n      }\n\n      in[i][j] = -in[i][i] * sum;\n    }\n  }\n\n  // Dot with transpose of self\n  for (int i = 0; i < CHANNEL_COUNT; i++) {\n    for (int j = 0; j < CHANNEL_COUNT; j++) {\n      out[i][j] = 0.0f;\n\n      for (int k = max(i, j); k < CHANNEL_COUNT; k++) {\n        out[i][j] += in[k][i] * in[k][j];\n      }\n    }\n  }\n}\n\n__device__ void normalize(float* v) {\n  float norm = 0.0f;\n\n  for (int i = 0; i < CHANNEL_COUNT; i++) {\n    norm += v[i] * v[i];\n  }\n\n  norm = 1.0f / sqrtf(norm);\n\n  for (int i = 0; i < CHANNEL_COUNT; i++) {\n    v[i] *= norm;\n  }\n}\n\n__device__ float scalar_prod(float* a, float* b) {\n  float product = 0.0f;\n\n  for (int i = 0; i < CHANNEL_COUNT; i++) {\n    product += a[i] * b[i];\n  }\n\n  return product;\n}\n\n__device__ void largest_eigenpair(const float* M, float* evec, float* eval) {\n  float scratch[CHANNEL_COUNT];\n\n  for (int i = 0; i < CHANNEL_COUNT; i++) {\n    scratch[i] = i + 1;\n  }\n\n  for (int itr = 0; itr < 10; itr++) {\n    *eval = 0.0f;\n\n    for (int i = 0; i < CHANNEL_COUNT; i++) {\n      int index = i;\n\n      evec[i] = 0.0f;\n\n      for (int j = 0; j < CHANNEL_COUNT; j++) {\n        evec[i] += M[index] * scratch[j];\n\n        if (j < i) {\n          index += CHANNEL_COUNT - (j + 1);\n        } else {\n          index += 1;\n        }\n      }\n\n      *eval = max(*eval, evec[i]);\n    }\n\n    for (int i = 0; i < CHANNEL_COUNT; i++) {\n      evec[i] /= *eval;\n      scratch[i] = evec[i];\n    }\n  }\n}\n"
  },
  {
    "path": "monai/_extensions/loader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport platform\nfrom _thread import interrupt_main\nfrom contextlib import contextmanager\nfrom glob import glob\nfrom os import path\nfrom threading import Timer\nfrom types import ModuleType\n\nimport torch\n\nfrom monai.utils.module import get_torch_version_tuple, optional_import\n\ndir_path = path.dirname(path.realpath(__file__))\n\n\n@contextmanager\ndef timeout(time, message):\n    timer = None\n    try:\n        timer = Timer(time, interrupt_main)\n        timer.daemon = True\n        timer.start()\n        yield\n    except KeyboardInterrupt as e:\n        if timer is not None and timer.is_alive():\n            raise e  # interrupt from user?\n        raise TimeoutError(message) from e\n    finally:\n        if timer is not None:\n            try:\n                timer.cancel()\n            finally:\n                pass\n\n\ndef load_module(\n    module_name: str, defines: dict | None = None, verbose_build: bool = False, build_timeout: int = 300\n) -> ModuleType:\n    \"\"\"\n    Handles the loading of c++ extension modules.\n\n    Args:\n        module_name: Name of the module to load.\n            Must match the name of the relevant source directory in the `_extensions` directory.\n        defines: Dictionary containing names and values of compilation defines.\n        verbose_build: Set to true to enable build logging.\n 
       build_timeout: Time in seconds before the build will throw an exception to prevent hanging.\n    \"\"\"\n\n    # Ensuring named module exists in _extensions directory.\n    module_dir = path.join(dir_path, module_name)\n    if not path.exists(module_dir):\n        raise ValueError(f\"No extension module named {module_name}\")\n\n    platform_str = f\"_{platform.system()}_{platform.python_version()}_\"\n    platform_str += \"\".join(f\"{v}\" for v in get_torch_version_tuple()[:2])\n    # Adding configuration to module name.\n    if defines is not None:\n        module_name = \"_\".join([module_name] + [f\"{v}\" for v in defines.values()])\n\n    # Gathering source files.\n    source = glob(path.join(module_dir, \"**\", \"*.cpp\"), recursive=True)\n    if torch.cuda.is_available():\n        source += glob(path.join(module_dir, \"**\", \"*.cu\"), recursive=True)\n        platform_str += f\"_{torch.version.cuda}\"\n\n    # Constructing compilation argument list.\n    define_args = [] if not defines else [f\"-D {key}={defines[key]}\" for key in defines]\n\n    # Ninja may be blocked by something out of our control.\n    # This will error if the build takes longer than expected.\n    with timeout(build_timeout, \"Build appears to be blocked. Is there a stopped process building the same extension?\"):\n        load, _ = optional_import(\"torch.utils.cpp_extension\", name=\"load\")  # main trigger some JIT config in pytorch\n        # This will either run the build or return the existing .so object.\n        name = module_name + platform_str.replace(\".\", \"_\")\n        module = load(\n            name=name, sources=source, extra_cflags=define_args, extra_cuda_cflags=define_args, verbose=verbose_build\n        )\n\n    return module  # type: ignore[no-any-return]\n"
  },
  {
    "path": "monai/_version.py",
    "content": "\n# This file helps to compute a version number in source trees obtained from\n# git-archive tarball (such as those provided by githubs download-from-tag\n# feature). Distribution tarballs (built by setup.py sdist) and build\n# directories (produced by setup.py build) will contain a much shorter file\n# that just contains the computed version number.\n\n# This file is released into the public domain. Generated by\n# versioneer-0.23 (https://github.com/python-versioneer/python-versioneer)\n\n\"\"\"Git implementation of _version.py.\"\"\"\n\nimport errno\nimport os\nimport re\nimport subprocess\nimport sys\nfrom typing import Callable, Dict\nimport functools\n\n\ndef get_keywords():\n    \"\"\"Get the keywords needed to look up the version information.\"\"\"\n    # these strings will be replaced by git during git-archive.\n    # setup.py/versioneer.py will grep for the variable names, so they must\n    # each be defined on a line of their own. _version.py will just call\n    # get_keywords().\n    git_refnames = \"$Format:%d$\"\n    git_full = \"$Format:%H$\"\n    git_date = \"$Format:%ci$\"\n    keywords = {\"refnames\": git_refnames, \"full\": git_full, \"date\": git_date}\n    return keywords\n\n\nclass VersioneerConfig:\n    \"\"\"Container for Versioneer configuration parameters.\"\"\"\n\n\ndef get_config():\n    \"\"\"Create, populate and return the VersioneerConfig() object.\"\"\"\n    # these strings are filled in when 'setup.py versioneer' creates\n    # _version.py\n    cfg = VersioneerConfig()\n    cfg.VCS = \"git\"\n    cfg.style = \"pep440\"\n    cfg.tag_prefix = \"\"\n    cfg.parentdir_prefix = \"\"\n    cfg.versionfile_source = \"monai/_version.py\"\n    cfg.verbose = False\n    return cfg\n\n\nclass NotThisMethod(Exception):\n    \"\"\"Exception raised if a method is not valid for the current scenario.\"\"\"\n\n\nLONG_VERSION_PY: Dict[str, str] = {}\nHANDLERS: Dict[str, Dict[str, Callable]] = {}\n\n\ndef register_vcs_handler(vcs, 
method):  # decorator\n    \"\"\"Create decorator to mark a method as the handler of a VCS.\"\"\"\n    def decorate(f):\n        \"\"\"Store f in HANDLERS[vcs][method].\"\"\"\n        if vcs not in HANDLERS:\n            HANDLERS[vcs] = {}\n        HANDLERS[vcs][method] = f\n        return f\n    return decorate\n\n\ndef run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,\n                env=None):\n    \"\"\"Call the given command(s).\"\"\"\n    assert isinstance(commands, list)\n    process = None\n\n    popen_kwargs = {}\n    if sys.platform == \"win32\":\n        # This hides the console window if pythonw.exe is used\n        startupinfo = subprocess.STARTUPINFO()\n        startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW\n        popen_kwargs[\"startupinfo\"] = startupinfo\n\n    for command in commands:\n        try:\n            dispcmd = str([command] + args)\n            # remember shell=False, so use git.cmd on windows, not just git\n            process = subprocess.Popen([command] + args, cwd=cwd, env=env,\n                                       stdout=subprocess.PIPE,\n                                       stderr=(subprocess.PIPE if hide_stderr\n                                               else None), **popen_kwargs)\n            break\n        except OSError:\n            e = sys.exc_info()[1]\n            if e.errno == errno.ENOENT:\n                continue\n            if verbose:\n                print(\"unable to run %s\" % dispcmd)\n                print(e)\n            return None, None\n    else:\n        if verbose:\n            print(\"unable to find command, tried %s\" % (commands,))\n        return None, None\n    stdout = process.communicate()[0].strip().decode()\n    if process.returncode != 0:\n        if verbose:\n            print(\"unable to run %s (error)\" % dispcmd)\n            print(\"stdout was %s\" % stdout)\n        return None, process.returncode\n    return stdout, process.returncode\n\n\ndef 
versions_from_parentdir(parentdir_prefix, root, verbose):\n    \"\"\"Try to determine the version from the parent directory name.\n\n    Source tarballs conventionally unpack into a directory that includes both\n    the project name and a version string. We will also support searching up\n    two directory levels for an appropriately named parent directory\n    \"\"\"\n    rootdirs = []\n\n    for _ in range(3):\n        dirname = os.path.basename(root)\n        if dirname.startswith(parentdir_prefix):\n            return {\"version\": dirname[len(parentdir_prefix):],\n                    \"full-revisionid\": None,\n                    \"dirty\": False, \"error\": None, \"date\": None}\n        rootdirs.append(root)\n        root = os.path.dirname(root)  # up a level\n\n    if verbose:\n        print(\"Tried directories %s but none started with prefix %s\" %\n              (str(rootdirs), parentdir_prefix))\n    raise NotThisMethod(\"rootdir doesn't start with parentdir_prefix\")\n\n\n@register_vcs_handler(\"git\", \"get_keywords\")\ndef git_get_keywords(versionfile_abs):\n    \"\"\"Extract version information from the given file.\"\"\"\n    # the code embedded in _version.py can just fetch the value of these\n    # keywords. When used from setup.py, we don't want to import _version.py,\n    # so we do it with a regexp instead. 
This function is not used from\n    # _version.py.\n    keywords = {}\n    try:\n        with open(versionfile_abs, \"r\") as fobj:\n            for line in fobj:\n                if line.strip().startswith(\"git_refnames =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"refnames\"] = mo.group(1)\n                if line.strip().startswith(\"git_full =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"full\"] = mo.group(1)\n                if line.strip().startswith(\"git_date =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"date\"] = mo.group(1)\n    except OSError:\n        pass\n    return keywords\n\n\n@register_vcs_handler(\"git\", \"keywords\")\ndef git_versions_from_keywords(keywords, tag_prefix, verbose):\n    \"\"\"Get version information from git keywords.\"\"\"\n    if \"refnames\" not in keywords:\n        raise NotThisMethod(\"Short version file found\")\n    date = keywords.get(\"date\")\n    if date is not None:\n        # Use only the last line.  Previous lines may contain GPG signature\n        # information.\n        date = date.splitlines()[-1]\n\n        # git-2.2.0 added \"%cI\", which expands to an ISO-8601 -compliant\n        # datestamp. 
However we prefer \"%ci\" (which expands to an \"ISO-8601\n        # -like\" string, which we must then edit to make compliant), because\n        # it's been around since git-1.5.3, and it's too difficult to\n        # discover which version we're using, or to work around using an\n        # older one.\n        date = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n    refnames = keywords[\"refnames\"].strip()\n    if refnames.startswith(\"$Format\"):\n        if verbose:\n            print(\"keywords are unexpanded, not using\")\n        raise NotThisMethod(\"unexpanded keywords, not a git-archive tarball\")\n    refs = {r.strip() for r in refnames.strip(\"()\").split(\",\")}\n    # starting in git-1.8.3, tags are listed as \"tag: foo-1.0\" instead of\n    # just \"foo-1.0\". If we see a \"tag: \" prefix, prefer those.\n    TAG = \"tag: \"\n    tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}\n    if not tags:\n        # Either we're using git < 1.8.3, or there really are no tags. We use\n        # a heuristic: assume all version tags have a digit. The old git %d\n        # expansion behaves like git log --decorate=short and strips out the\n        # refs/heads/ and refs/tags/ prefixes that would let us distinguish\n        # between branches and tags. By ignoring refnames without digits, we\n        # filter out many common branch names like \"release\" and\n        # \"stabilization\", as well as \"HEAD\" and \"master\".\n        tags = {r for r in refs if re.search(r'\\d', r)}\n        if verbose:\n            print(\"discarding '%s', no digits\" % \",\".join(refs - tags))\n    if verbose:\n        print(\"likely tags: %s\" % \",\".join(sorted(tags)))\n    for ref in sorted(tags):\n        # sorting will prefer e.g. 
\"2.0\" over \"2.0rc1\"\n        if ref.startswith(tag_prefix):\n            r = ref[len(tag_prefix):]\n            # Filter out refs that exactly match prefix or that don't start\n            # with a number once the prefix is stripped (mostly a concern\n            # when prefix is '')\n            if not re.match(r'\\d', r):\n                continue\n            if verbose:\n                print(\"picking %s\" % r)\n            return {\"version\": r,\n                    \"full-revisionid\": keywords[\"full\"].strip(),\n                    \"dirty\": False, \"error\": None,\n                    \"date\": date}\n    # no suitable tags, so version is \"0+unknown\", but full hex is still there\n    if verbose:\n        print(\"no suitable tags, using unknown + full revision id\")\n    return {\"version\": \"0+unknown\",\n            \"full-revisionid\": keywords[\"full\"].strip(),\n            \"dirty\": False, \"error\": \"no suitable tags\", \"date\": None}\n\n\n@register_vcs_handler(\"git\", \"pieces_from_vcs\")\ndef git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):\n    \"\"\"Get version from 'git describe' in the root of the source tree.\n\n    This only gets called if the git-archive 'subst' keywords were *not*\n    expanded, and _version.py hasn't already been rewritten with a short\n    version string, meaning we're inside a checked out source tree.\n    \"\"\"\n    GITS = [\"git\"]\n    if sys.platform == \"win32\":\n        GITS = [\"git.cmd\", \"git.exe\"]\n\n    # GIT_DIR can interfere with correct operation of Versioneer.\n    # It may be intended to be passed to the Versioneer-versioned project,\n    # but that should not change where we get our version from.\n    env = os.environ.copy()\n    env.pop(\"GIT_DIR\", None)\n    runner = functools.partial(runner, env=env)\n\n    _, rc = runner(GITS, [\"rev-parse\", \"--git-dir\"], cwd=root,\n                   hide_stderr=True)\n    if rc != 0:\n        if verbose:\n            
print(\"Directory %s not under git control\" % root)\n        raise NotThisMethod(\"'git rev-parse --git-dir' returned error\")\n\n    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]\n    # if there isn't one, this yields HEX[-dirty] (no NUM)\n    describe_out, rc = runner(GITS, [\n        \"describe\", \"--tags\", \"--dirty\", \"--always\", \"--long\",\n        \"--match\", f\"{tag_prefix}[[:digit:]]*\"\n        ], cwd=root)\n    # --long was added in git-1.5.5\n    if describe_out is None:\n        raise NotThisMethod(\"'git describe' failed\")\n    describe_out = describe_out.strip()\n    full_out, rc = runner(GITS, [\"rev-parse\", \"HEAD\"], cwd=root)\n    if full_out is None:\n        raise NotThisMethod(\"'git rev-parse' failed\")\n    full_out = full_out.strip()\n\n    pieces = {}\n    pieces[\"long\"] = full_out\n    pieces[\"short\"] = full_out[:7]  # maybe improved later\n    pieces[\"error\"] = None\n\n    branch_name, rc = runner(GITS, [\"rev-parse\", \"--abbrev-ref\", \"HEAD\"],\n                             cwd=root)\n    # --abbrev-ref was added in git-1.6.3\n    if rc != 0 or branch_name is None:\n        raise NotThisMethod(\"'git rev-parse --abbrev-ref' returned error\")\n    branch_name = branch_name.strip()\n\n    if branch_name == \"HEAD\":\n        # If we aren't exactly on a branch, pick a branch which represents\n        # the current commit. 
If all else fails, we are on a branchless\n        # commit.\n        branches, rc = runner(GITS, [\"branch\", \"--contains\"], cwd=root)\n        # --contains was added in git-1.5.4\n        if rc != 0 or branches is None:\n            raise NotThisMethod(\"'git branch --contains' returned error\")\n        branches = branches.split(\"\\n\")\n\n        # Remove the first line if we're running detached\n        if \"(\" in branches[0]:\n            branches.pop(0)\n\n        # Strip off the leading \"* \" from the list of branches.\n        branches = [branch[2:] for branch in branches]\n        if \"master\" in branches:\n            branch_name = \"master\"\n        elif not branches:\n            branch_name = None\n        else:\n            # Pick the first branch that is returned. Good or bad.\n            branch_name = branches[0]\n\n    pieces[\"branch\"] = branch_name\n\n    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]\n    # TAG might have hyphens.\n    git_describe = describe_out\n\n    # look for -dirty suffix\n    dirty = git_describe.endswith(\"-dirty\")\n    pieces[\"dirty\"] = dirty\n    if dirty:\n        git_describe = git_describe[:git_describe.rindex(\"-dirty\")]\n\n    # now we have TAG-NUM-gHEX or HEX\n\n    if \"-\" in git_describe:\n        # TAG-NUM-gHEX\n        mo = re.search(r'^(.+)-(\\d+)-g([0-9a-f]+)$', git_describe)\n        if not mo:\n            # unparsable. 
Maybe git-describe is misbehaving?\n            pieces[\"error\"] = (\"unable to parse git-describe output: '%s'\"\n                               % describe_out)\n            return pieces\n\n        # tag\n        full_tag = mo.group(1)\n        if not full_tag.startswith(tag_prefix):\n            if verbose:\n                fmt = \"tag '%s' doesn't start with prefix '%s'\"\n                print(fmt % (full_tag, tag_prefix))\n            pieces[\"error\"] = (\"tag '%s' doesn't start with prefix '%s'\"\n                               % (full_tag, tag_prefix))\n            return pieces\n        pieces[\"closest-tag\"] = full_tag[len(tag_prefix):]\n\n        # distance: number of commits since tag\n        pieces[\"distance\"] = int(mo.group(2))\n\n        # commit: short hex revision ID\n        pieces[\"short\"] = mo.group(3)\n\n    else:\n        # HEX: no tags\n        pieces[\"closest-tag\"] = None\n        out, rc = runner(GITS, [\"rev-list\", \"HEAD\", \"--left-right\"], cwd=root)\n        pieces[\"distance\"] = len(out.split())  # total number of commits\n\n    # commit date: see ISO-8601 comment in git_versions_from_keywords()\n    date = runner(GITS, [\"show\", \"-s\", \"--format=%ci\", \"HEAD\"], cwd=root)[0].strip()\n    # Use only the last line.  Previous lines may contain GPG signature\n    # information.\n    date = date.splitlines()[-1]\n    pieces[\"date\"] = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n\n    return pieces\n\n\ndef plus_or_dot(pieces):\n    \"\"\"Return a + if we don't already have one, else return a .\"\"\"\n    if \"+\" in pieces.get(\"closest-tag\", \"\"):\n        return \".\"\n    return \"+\"\n\n\ndef render_pep440(pieces):\n    \"\"\"Build up version string, with post-release \"local version identifier\".\n\n    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you\n    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty\n\n    Exceptions:\n    1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += plus_or_dot(pieces)\n            rendered += \"%d.g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0+untagged.%d.g%s\" % (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_branch(pieces):\n    \"\"\"TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .\n\n    The \".dev0\" means not master branch. Note that .dev0 sorts backwards\n    (a feature branch will appear \"older\" than the master branch).\n\n    Exceptions:\n    1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            if pieces[\"branch\"] != \"master\":\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"%d.g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0\"\n        if pieces[\"branch\"] != \"master\":\n            rendered += \".dev0\"\n        rendered += \"+untagged.%d.g%s\" % (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef pep440_split_post(ver):\n    \"\"\"Split pep440 version string at the post-release segment.\n\n    Returns the release segments before the post-release and the\n    post-release version number (or -1 if no post-release segment is present).\n    \"\"\"\n    
vc = str.split(ver, \".post\")\n    return vc[0], int(vc[1] or 0) if len(vc) == 2 else None\n\n\ndef render_pep440_pre(pieces):\n    \"\"\"TAG[.postN.devDISTANCE] -- No -dirty.\n\n    Exceptions:\n    1: no tags. 0.post0.devDISTANCE\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        if pieces[\"distance\"]:\n            # update the post release segment\n            tag_version, post_version = pep440_split_post(pieces[\"closest-tag\"])\n            rendered = tag_version\n            if post_version is not None:\n                rendered += \".post%d.dev%d\" % (post_version + 1, pieces[\"distance\"])\n            else:\n                rendered += \".post0.dev%d\" % (pieces[\"distance\"])\n        else:\n            # no commits, use the tag as the version\n            rendered = pieces[\"closest-tag\"]\n    else:\n        # exception #1\n        rendered = \"0.post0.dev%d\" % pieces[\"distance\"]\n    return rendered\n\n\ndef render_pep440_post(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX] .\n\n    The \".dev0\" means dirty. Note that .dev0 sorts backwards\n    (a dirty tree will appear \"older\" than the corresponding clean one),\n    but you shouldn't be releasing software with -dirty anyways.\n\n    Exceptions:\n    1: no tags. 
0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%s\" % pieces[\"short\"]\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n        rendered += \"+g%s\" % pieces[\"short\"]\n    return rendered\n\n\ndef render_pep440_post_branch(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .\n\n    The \".dev0\" means not master branch.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"branch\"] != \"master\":\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%s\" % pieces[\"short\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"branch\"] != \"master\":\n            rendered += \".dev0\"\n        rendered += \"+g%s\" % pieces[\"short\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_old(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]] .\n\n    The \".dev0\" means dirty.\n\n    Exceptions:\n    1: no tags. 
0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n    return rendered\n\n\ndef render_git_describe(pieces):\n    \"\"\"TAG[-DISTANCE-gHEX][-dirty].\n\n    Like 'git describe --tags --dirty --always'.\n\n    Exceptions:\n    1: no tags. HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"]:\n            rendered += \"-%d-g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render_git_describe_long(pieces):\n    \"\"\"TAG-DISTANCE-gHEX[-dirty].\n\n    Like 'git describe --tags --dirty --always -long'.\n    The distance/hash is unconditional.\n\n    Exceptions:\n    1: no tags. 
HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        rendered += \"-%d-g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render(pieces, style):\n    \"\"\"Render the given version pieces into the requested style.\"\"\"\n    if pieces[\"error\"]:\n        return {\"version\": \"unknown\",\n                \"full-revisionid\": pieces.get(\"long\"),\n                \"dirty\": None,\n                \"error\": pieces[\"error\"],\n                \"date\": None}\n\n    if not style or style == \"default\":\n        style = \"pep440\"  # the default\n\n    if style == \"pep440\":\n        rendered = render_pep440(pieces)\n    elif style == \"pep440-branch\":\n        rendered = render_pep440_branch(pieces)\n    elif style == \"pep440-pre\":\n        rendered = render_pep440_pre(pieces)\n    elif style == \"pep440-post\":\n        rendered = render_pep440_post(pieces)\n    elif style == \"pep440-post-branch\":\n        rendered = render_pep440_post_branch(pieces)\n    elif style == \"pep440-old\":\n        rendered = render_pep440_old(pieces)\n    elif style == \"git-describe\":\n        rendered = render_git_describe(pieces)\n    elif style == \"git-describe-long\":\n        rendered = render_git_describe_long(pieces)\n    else:\n        raise ValueError(\"unknown style '%s'\" % style)\n\n    return {\"version\": rendered, \"full-revisionid\": pieces[\"long\"],\n            \"dirty\": pieces[\"dirty\"], \"error\": None,\n            \"date\": pieces.get(\"date\")}\n\n\ndef get_versions():\n    \"\"\"Get version information or return default if unable to do so.\"\"\"\n    # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have\n    # __file__, we can work backwards from there to the root. 
Some\n    # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which\n    # case we can only use expanded keywords.\n\n    cfg = get_config()\n    verbose = cfg.verbose\n\n    try:\n        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,\n                                          verbose)\n    except NotThisMethod:\n        pass\n\n    try:\n        root = os.path.realpath(__file__)\n        # versionfile_source is the relative path from the top of the source\n        # tree (where the .git directory might live) to this file. Invert\n        # this to find the root from __file__.\n        for _ in cfg.versionfile_source.split('/'):\n            root = os.path.dirname(root)\n    except NameError:\n        return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n                \"dirty\": None,\n                \"error\": \"unable to find root of source tree\",\n                \"date\": None}\n\n    try:\n        pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)\n        return render(pieces, cfg.style)\n    except NotThisMethod:\n        pass\n\n    try:\n        if cfg.parentdir_prefix:\n            return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)\n    except NotThisMethod:\n        pass\n\n    return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n            \"dirty\": None,\n            \"error\": \"unable to compute version\", \"date\": None}\n"
  },
  {
    "path": "monai/apps/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .datasets import CrossValidation, DecathlonDataset, MedNISTDataset, TciaDataset\nfrom .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar\nfrom .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger\n"
  },
  {
    "path": "monai/apps/auto3dseg/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .auto_runner import AutoRunner\nfrom .bundle_gen import BundleAlgo, BundleGen\nfrom .data_analyzer import DataAnalyzer\nfrom .ensemble_builder import (\n    AlgoEnsemble,\n    AlgoEnsembleBestByFold,\n    AlgoEnsembleBestN,\n    AlgoEnsembleBuilder,\n    EnsembleRunner,\n)\nfrom .hpo_gen import NNIGen, OptunaGen\nfrom .utils import export_bundle_algo_history, get_name_from_algo_id, import_bundle_algo_history\n"
  },
  {
    "path": "monai/apps/auto3dseg/__main__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom monai.apps.auto3dseg.auto_runner import AutoRunner\nfrom monai.apps.auto3dseg.bundle_gen import BundleAlgo, BundleGen\nfrom monai.apps.auto3dseg.data_analyzer import DataAnalyzer\nfrom monai.apps.auto3dseg.ensemble_builder import AlgoEnsembleBuilder, EnsembleRunner\nfrom monai.apps.auto3dseg.hpo_gen import NNIGen, OptunaGen\n\nif __name__ == \"__main__\":\n    from monai.utils import optional_import\n\n    fire, _ = optional_import(\"fire\")\n    fire.Fire(\n        {\n            \"DataAnalyzer\": DataAnalyzer,\n            \"BundleGen\": BundleGen,\n            \"BundleAlgo\": BundleAlgo,\n            \"AlgoEnsembleBuilder\": AlgoEnsembleBuilder,\n            \"EnsembleRunner\": EnsembleRunner,\n            \"AutoRunner\": AutoRunner,\n            \"NNIGen\": NNIGen,\n            \"OptunaGen\": OptunaGen,\n        }\n    )\n"
  },
  {
    "path": "monai/apps/auto3dseg/auto_runner.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport warnings\nfrom copy import deepcopy\nfrom time import sleep\nfrom typing import Any, cast\n\nimport torch\n\nfrom monai.apps.auto3dseg.bundle_gen import BundleGen\nfrom monai.apps.auto3dseg.data_analyzer import DataAnalyzer\nfrom monai.apps.auto3dseg.ensemble_builder import EnsembleRunner\nfrom monai.apps.auto3dseg.hpo_gen import NNIGen\nfrom monai.apps.auto3dseg.utils import export_bundle_algo_history, import_bundle_algo_history\nfrom monai.apps.utils import get_logger\nfrom monai.auto3dseg.utils import algo_to_pickle\nfrom monai.bundle import ConfigParser\nfrom monai.transforms import SaveImage\nfrom monai.utils import AlgoKeys, has_option, look_up_option, optional_import\nfrom monai.utils.misc import check_kwargs_exist_in_class_init, run_cmd\n\nlogger = get_logger(module_name=__name__)\n\nnni, has_nni = optional_import(\"nni\")\n\n\nclass AutoRunner:\n    \"\"\"\n    An interface for handling Auto3Dseg with minimal inputs and understanding of the internal states in Auto3Dseg.\n    The users can run the Auto3Dseg with default settings in one line of code. They can also customize the advanced\n    features Auto3Dseg in a few additional lines. 
Examples of customization include\n\n        - change cross-validation folds\n        - change training/prediction parameters\n        - change ensemble methods\n        - automatic hyperparameter optimization.\n\n    The output of the interface is a directory that contains\n\n        - data statistics analysis report\n        - algorithm definition files (scripts, configs, pickle objects) and training results (checkpoints, accuracies)\n        - the predictions on the testing datasets from the final algorithm ensemble\n        - a copy of the input arguments in form of YAML\n        - cached intermediate results\n\n    Args:\n        work_dir: working directory to save the intermediate and final results.\n        input: the configuration dictionary or the file path to the configuration in form of YAML.\n            The configuration should contain datalist, dataroot, modality, multigpu, and class_names info.\n        algos: optionally specify algorithms to use.  If a dictionary, must be in the form\n            {\"algname\": dict(_target_=\"algname.scripts.algo.AlgnameAlgo\", template_path=\"algname\"), ...}\n            If a list or a string, defines a subset of names of the algorithms to use, e.g. 'segresnet' or\n            ['segresnet', 'dints'] out of the full set of algorithm templates provided by templates_path_or_url.\n            Defaults to None, to use all available algorithms.\n        analyze: on/off switch to run DataAnalyzer and generate a datastats report. Defaults to None, to automatically\n            decide based on cache, and run data analysis only if we have not completed this step yet.\n        algo_gen: on/off switch to run AlgoGen and generate templated BundleAlgos. Defaults to None, to automatically\n            decide based on cache, and run algorithm folders generation only if we have not completed this step yet.\n        train: on/off switch to run training and generate algorithm checkpoints. 
Defaults to None, to automatically\n            decide based on cache, and run training only if we have not completed this step yet.\n        hpo: use hyperparameter optimization (HPO) in the training phase. Users can provide a list of\n            hyper-parameter and a search will be performed to investigate the algorithm performances.\n        hpo_backend: a string that indicates the backend of the HPO. Currently, only NNI Grid-search mode\n            is supported\n        ensemble: on/off switch to run model ensemble and use the ensemble to predict outputs in testing\n            datasets.\n        not_use_cache: if the value is True, it will ignore all cached results in data analysis,\n            algorithm generation, or training, and start the pipeline from scratch.\n        templates_path_or_url: the folder with the algorithm templates or a url. If None provided, the default template\n            zip url will be downloaded and extracted into the work_dir.\n        allow_skip: a switch passed to BundleGen process which determines if some Algo in the default templates\n            can be skipped based on the analysis on the dataset from Auto3DSeg DataAnalyzer.\n        mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of the remote\n            tracking Server; MLflow runs will be recorded locally in algorithms' model folder if the value is None.\n        mlflow_experiment_name: the name of the experiment in MLflow server.\n        kwargs: image writing parameters for the ensemble inference. The kwargs format follows the SaveImage\n            transform. For more information, check https://monai.readthedocs.io/en/stable/transforms.html#saveimage.\n\n\n    Examples:\n        - User can use the one-liner to start the Auto3Dseg workflow\n\n        .. 
code-block:: bash\n\n            python -m monai.apps.auto3dseg AutoRunner run --input \\\n            '{\"modality\": \"ct\", \"datalist\": \"dl.json\", \"dataroot\": \"/dr\", \"multigpu\": true, \"class_names\": [\"A\", \"B\"]}'\n\n        - User can also save the input dictionary as an input YAML file and use the following one-liner\n\n        .. code-block:: bash\n\n            python -m monai.apps.auto3dseg AutoRunner run --input=./input.yaml\n\n        - User can specify work_dir and data source config input and run AutoRunner:\n\n        .. code-block:: python\n\n            work_dir = \"./work_dir\"\n            input = \"path/to/input_yaml\"\n            runner = AutoRunner(work_dir=work_dir, input=input)\n            runner.run()\n\n        - User can specify a subset of algorithms to use and run AutoRunner:\n\n        .. code-block:: python\n\n            work_dir = \"./work_dir\"\n            input = \"path/to/input_yaml\"\n            algos = [\"segresnet\", \"dints\"]\n            runner = AutoRunner(work_dir=work_dir, input=input, algos=algos)\n            runner.run()\n\n        - User can specify a local folder with algorithm templates and run AutoRunner:\n\n        .. code-block:: python\n\n            work_dir = \"./work_dir\"\n            input = \"path/to/input_yaml\"\n            algos = \"segresnet\"\n            templates_path_or_url = \"./local_path_to/algorithm_templates\"\n            runner = AutoRunner(work_dir=work_dir, input=input, algos=algos, templates_path_or_url=templates_path_or_url)\n            runner.run()\n\n        - User can specify training parameters by:\n\n        .. 
code-block:: python\n\n            input = \"path/to/input_yaml\"\n            runner = AutoRunner(input=input)\n            train_param = {\n                \"num_epochs_per_validation\": 1,\n                \"num_images_per_batch\": 2,\n                \"num_epochs\": 2,\n            }\n            runner.set_training_params(params=train_param)  # 2 epochs\n            runner.run()\n\n        - User can specify the fold number of cross validation\n\n        .. code-block:: python\n\n            input = \"path/to/input_yaml\"\n            runner = AutoRunner(input=input)\n            runner.set_num_fold(num_fold=2)\n            runner.run()\n\n        - User can specify the prediction parameters during algo ensemble inference:\n\n        .. code-block:: python\n\n            input = \"path/to/input_yaml\"\n            pred_params = {\n                'files_slices': slice(0,2),\n                'mode': \"vote\",\n                'sigmoid': True,\n            }\n            runner = AutoRunner(input=input)\n            runner.set_prediction_params(params=pred_params)\n            runner.run()\n\n        - User can define a grid search space and use the HPO during training.\n\n        .. 
code-block:: python\n\n            input = \"path/to/input_yaml\"\n            runner = AutoRunner(input=input, hpo=True)\n            runner.set_nni_search_space({\"learning_rate\": {\"_type\": \"choice\", \"_value\": [0.0001, 0.001, 0.01, 0.1]}})\n            runner.run()\n\n    Notes:\n        Expected results in the work_dir as below::\n\n            work_dir/\n            ├── algorithm_templates # bundle algo templates (scripts/configs)\n            ├── cache.yaml          # Autorunner will automatically cache results to save time\n            ├── datastats.yaml      # datastats of the dataset\n            ├── dints_0             # network scripts/configs/checkpoints and pickle object of the algo\n            ├── ensemble_output     # the prediction of testing datasets from the ensemble of the algos\n            ├── input.yaml          # copy of the input data source configs\n            ├── segresnet_0         # network scripts/configs/checkpoints and pickle object of the algo\n            ├── segresnet2d_0       # network scripts/configs/checkpoints and pickle object of the algo\n            └── swinunetr_0         # network scripts/configs/checkpoints and pickle object of the algo\n\n\n        The input config requires at least the following keys:\n            - ``modality``: the modality of the data, e.g. 
\"ct\", \"mri\", etc.\n            - ``datalist``: the path to the datalist file in JSON format.\n            - ``dataroot``: the root directory of the data files.\n\n        For the datalist file format, see the description under :py:func:`monai.data.load_decathlon_datalist`.\n        Note that the AutoRunner will use the \"validation\" key in the datalist file if it exists, otherwise\n        it will do cross-validation, by default with five folds (this is hardcoded).\n    \"\"\"\n\n    analyze_params: dict | None\n\n    def __init__(\n        self,\n        work_dir: str = \"./work_dir\",\n        input: dict[str, Any] | str | None = None,\n        algos: dict | list | str | None = None,\n        analyze: bool | None = None,\n        algo_gen: bool | None = None,\n        train: bool | None = None,\n        hpo: bool = False,\n        hpo_backend: str = \"nni\",\n        ensemble: bool = True,\n        not_use_cache: bool = False,\n        templates_path_or_url: str | None = None,\n        allow_skip: bool = True,\n        mlflow_tracking_uri: str | None = None,\n        mlflow_experiment_name: str | None = None,\n        **kwargs: Any,\n    ):\n        if input is None and os.path.isfile(os.path.join(os.path.abspath(work_dir), \"input.yaml\")):\n            input = os.path.join(os.path.abspath(work_dir), \"input.yaml\")\n            logger.info(f\"Input config is not provided, using the default {input}\")\n\n        self.data_src_cfg = dict()\n        if isinstance(input, dict):\n            self.data_src_cfg = input\n        elif isinstance(input, str) and os.path.isfile(input):\n            self.data_src_cfg = ConfigParser.load_config_file(input)\n            logger.info(f\"Loading input config {input}\")\n        else:\n            raise ValueError(f\"{input} is not a valid file or dict\")\n\n        if \"work_dir\" in self.data_src_cfg:  # override from config\n            work_dir = self.data_src_cfg[\"work_dir\"]\n        self.work_dir = 
os.path.abspath(work_dir)\n\n        logger.info(f\"AutoRunner using work directory {self.work_dir}\")\n        os.makedirs(self.work_dir, exist_ok=True)\n        self.data_src_cfg_name = os.path.join(self.work_dir, \"input.yaml\")\n\n        self.algos = algos\n        self.templates_path_or_url = templates_path_or_url\n        self.allow_skip = allow_skip\n\n        # cache.yaml\n        self.not_use_cache = not_use_cache\n        self.cache_filename = os.path.join(self.work_dir, \"cache.yaml\")\n        self.cache = self.read_cache()\n        self.export_cache()\n\n        # determine if we need to analyze, algo_gen or train from cache, unless manually provided\n        self.analyze = not self.cache[\"analyze\"] if analyze is None else analyze\n        self.algo_gen = not self.cache[\"algo_gen\"] if algo_gen is None else algo_gen\n        self.train = train\n        self.ensemble = ensemble  # last step, no need to check\n        self.hpo = hpo and has_nni\n        self.hpo_backend = hpo_backend\n        self.mlflow_tracking_uri = mlflow_tracking_uri\n        self.mlflow_experiment_name = mlflow_experiment_name\n        self.kwargs = deepcopy(kwargs)\n\n        # parse input config for AutoRunner param overrides\n        for param in [\n            \"analyze\",\n            \"algo_gen\",\n            \"train\",\n            \"hpo\",\n            \"ensemble\",\n            \"not_use_cache\",\n            \"allow_skip\",\n        ]:  # override from config\n            if param in self.data_src_cfg and isinstance(self.data_src_cfg[param], bool):\n                setattr(self, param, self.data_src_cfg[param])  # e.g. 
self.analyze = self.data_src_cfg[\"analyze\"]\n\n        for param in [\n            \"algos\",\n            \"hpo_backend\",\n            \"templates_path_or_url\",\n            \"mlflow_tracking_uri\",\n            \"mlflow_experiment_name\",\n        ]:  # override from config\n            if param in self.data_src_cfg:\n                setattr(self, param, self.data_src_cfg[param])  # e.g. self.algos = self.data_src_cfg[\"algos\"]\n\n        missing_keys = {\"dataroot\", \"datalist\", \"modality\"}.difference(self.data_src_cfg.keys())\n        if len(missing_keys) > 0:\n            raise ValueError(f\"Config keys are missing {missing_keys}\")\n\n        if not os.path.exists(self.data_src_cfg[\"datalist\"]):\n            raise ValueError(f\"Datalist file is not found {self.data_src_cfg['datalist']}\")\n\n        # copy datalist to work_dir\n        datalist_filename = os.path.join(self.work_dir, os.path.basename(self.data_src_cfg[\"datalist\"]))\n        if datalist_filename != self.data_src_cfg[\"datalist\"]:\n            try:\n                shutil.copyfile(self.data_src_cfg[\"datalist\"], datalist_filename)\n                logger.info(f\"Datalist was copied to work_dir: {datalist_filename}\")\n            except shutil.SameFileError:\n                pass\n\n        # inspect and update folds\n        self.max_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename)\n        if \"num_fold\" in self.data_src_cfg:\n            num_fold = int(self.data_src_cfg[\"num_fold\"])  # override from config\n            logger.info(f\"Setting num_fold {num_fold} based on the input config.\")\n        else:\n            num_fold = self.max_fold\n            logger.info(f\"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.\")\n\n        self.data_src_cfg[\"datalist\"] = datalist_filename  # update path to a version in work_dir and save user input\n        ConfigParser.export_config_file(\n            config=self.data_src_cfg, 
filepath=self.data_src_cfg_name, fmt=\"yaml\", sort_keys=False\n        )\n\n        self.dataroot = self.data_src_cfg[\"dataroot\"]\n        self.datastats_filename = os.path.join(self.work_dir, \"datastats.yaml\")\n        self.datalist_filename = datalist_filename\n\n        self.set_training_params()\n        self.set_device_info()\n        self.set_prediction_params()\n        self.set_analyze_params()\n        self.set_ensemble_method()\n        self.set_num_fold(num_fold=num_fold)\n\n        self.gpu_customization = False\n        self.gpu_customization_specs: dict[str, Any] = {}\n\n        # hpo\n        if self.hpo_backend.lower() != \"nni\":\n            raise NotImplementedError(\"HPOGen backend only supports NNI\")\n        self.hpo = self.hpo and has_nni\n        self.set_hpo_params()\n        self.search_space: dict[str, dict[str, Any]] = {}\n        self.hpo_tasks = 0\n\n        if \"sigmoid\" not in self.kwargs and \"sigmoid\" in self.data_src_cfg:\n            self.kwargs[\"sigmoid\"] = self.data_src_cfg[\"sigmoid\"]\n\n    def read_cache(self):\n        \"\"\"\n        Check if the intermediate result is cached after each step in the current working directory\n\n        Returns:\n            a dict of cache results. 
If not_use_cache is set to True, or there is no cache file in the\n            working directory, the result will be ``empty_cache`` in which all ``has_cache`` keys are\n            set to False.\n        \"\"\"\n\n        empty_cache = {\"analyze\": False, \"datastats\": None, \"algo_gen\": False, \"train\": False}\n\n        if self.not_use_cache or not os.path.isfile(self.cache_filename):\n            return empty_cache\n\n        cache = ConfigParser.load_config_file(self.cache_filename)\n\n        for k, v in empty_cache.items():\n            cache.setdefault(k, v)\n\n        if cache[\"analyze\"]:\n            if not (isinstance(cache[\"datastats\"], str) and os.path.isfile(cache[\"datastats\"])):\n                cache[\"analyze\"] = False\n                cache[\"datastats\"] = None\n\n        if cache[\"algo_gen\"]:\n            history = import_bundle_algo_history(self.work_dir, only_trained=False)\n            if len(history) == 0:  # no saved algo_objects\n                cache[\"algo_gen\"] = False\n\n        if cache[\"train\"]:\n            trained_history = import_bundle_algo_history(self.work_dir, only_trained=True)\n            if len(trained_history) == 0:\n                cache[\"train\"] = False\n\n        return cache\n\n    def export_cache(self, **kwargs):\n        \"\"\"\n        Save the cache state as ``cache.yaml`` in the working directory\n        \"\"\"\n        self.cache.update(kwargs)\n        ConfigParser.export_config_file(\n            self.cache, self.cache_filename, fmt=\"yaml\", default_flow_style=None, sort_keys=False\n        )\n\n    def inspect_datalist_folds(self, datalist_filename: str) -> int:\n        \"\"\"\n        Returns number of folds in the datalist file, and assigns fold numbers if not provided.\n\n        Args:\n            datalist_filename: path to the datalist file.\n\n        Notes:\n            If the fold key is not provided, it auto generates 5 folds assignments in the training key list.\n            If 
validation key list is available, then it assumes a single fold validation.\n        \"\"\"\n\n        datalist = ConfigParser.load_config_file(datalist_filename)\n        if \"training\" not in datalist:\n            raise ValueError(\"Datalist files has no training key:\" + str(datalist_filename))\n\n        fold_list = [int(d[\"fold\"]) for d in datalist[\"training\"] if \"fold\" in d]\n\n        if len(fold_list) > 0:\n            num_fold = max(fold_list) + 1\n            logger.info(f\"Found num_fold {num_fold} based on the input datalist {datalist_filename}.\")\n            # check if every fold is present\n            if len(set(fold_list)) != num_fold:\n                raise ValueError(f\"Fold numbers are not continuous from 0 to {num_fold - 1}\")\n        elif \"validation\" in datalist and len(datalist[\"validation\"]) > 0:\n            logger.info(\"No fold numbers provided, attempting to use a single fold based on the validation key\")\n            # update the datalist file\n            for d in datalist[\"training\"]:\n                d[\"fold\"] = 1\n            for d in datalist[\"validation\"]:\n                d[\"fold\"] = 0\n\n            val_labels = {d[\"label\"]: d for d in datalist[\"validation\"] if \"label\" in d}\n            logger.info(\n                f\"Found {len(val_labels)} items in the validation key, saving updated datalist to\", datalist_filename\n            )\n\n            # check for duplicates\n            for d in datalist[\"training\"]:\n                if d[\"label\"] in val_labels:\n                    d[\"fold\"] = 0\n                    del val_labels[d[\"label\"]]\n\n            datalist[\"training\"] = datalist[\"training\"] + list(val_labels.values())\n\n            ConfigParser.export_config_file(datalist, datalist_filename, fmt=\"json\", indent=4)\n            num_fold = 1\n\n        else:\n            num_fold = 5\n\n            warnings.warn(\n                f\"Datalist has no folds specified 
{datalist_filename}...\"\n                f\"Generating {num_fold} folds randomly.\"\n                f\"Please consider presaving fold numbers beforehand for repeated experiments.\"\n            )\n\n            from sklearn.model_selection import KFold\n\n            kf = KFold(n_splits=num_fold, shuffle=True, random_state=0)\n            for i, (_, valid_idx) in enumerate(kf.split(datalist[\"training\"])):\n                for vi in valid_idx:\n                    datalist[\"training\"][vi][\"fold\"] = i\n\n            ConfigParser.export_config_file(datalist, datalist_filename, fmt=\"json\", indent=4)\n\n        return num_fold\n\n    def set_gpu_customization(\n        self, gpu_customization: bool = False, gpu_customization_specs: dict[str, Any] | None = None\n    ) -> AutoRunner:\n        \"\"\"\n        Set options for GPU-based parameter customization/optimization.\n\n        Args:\n            gpu_customization: the switch to determine automatically customize/optimize bundle script/config\n                parameters for each bundleAlgo based on gpus. Custom parameters are obtained through dummy\n                training to simulate the actual model training process and hyperparameter optimization (HPO)\n                experiments.\n            gpu_customization_specs (optional): the dictionary to enable users overwrite the HPO settings. user can\n                overwrite part of variables as follows or all of them. The structure is as follows.\n\n                .. code-block:: python\n\n                    gpu_customization_specs = {\n                        'ALGO': {\n                            'num_trials': 6,\n                            'range_num_images_per_batch': [1, 20],\n                            'range_num_sw_batch_size': [1, 20]\n                        }\n                    }\n\n            ALGO: the name of algorithm. 
It could be one of algorithm names (e.g., 'dints') or 'universal' which\n                would apply changes to all algorithms. Possible options are\n\n                - {``\"universal\"``, ``\"dints\"``, ``\"segresnet\"``, ``\"segresnet2d\"``, ``\"swinunetr\"``}.\n\n            num_trials: the number of HPO trials/experiments to run.\n            range_num_images_per_batch: the range of number of images per mini-batch.\n            range_num_sw_batch_size: the range of batch size in sliding-window inferer.\n        \"\"\"\n        self.gpu_customization = gpu_customization\n        if gpu_customization_specs is not None:\n            self.gpu_customization_specs = gpu_customization_specs\n\n        return self\n\n    def set_num_fold(self, num_fold: int = 5) -> AutoRunner:\n        \"\"\"\n        Set the number of cross validation folds for all algos.\n\n        Args:\n            num_fold: a positive integer to define the number of folds.\n        \"\"\"\n\n        if num_fold <= 0:\n            raise ValueError(f\"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}\")\n        if num_fold > self.max_fold:\n            # Auto3DSeg must contain validation set, so the maximum fold number is max_fold.\n            raise ValueError(\n                f\"num_fold is greater than the maximum fold number {self.max_fold} in {self.datalist_filename}.\"\n            )\n        self.num_fold = num_fold\n\n        return self\n\n    def set_training_params(self, params: dict[str, Any] | None = None) -> AutoRunner:\n        \"\"\"\n        Set the training params for all algos.\n\n        Args:\n            params: a dict that defines the overriding key-value pairs during training. 
The overriding method\n                is defined by the algo class.\n\n        Examples:\n            For BundleAlgo objects, the training parameter to shorten the training time to a few epochs can be\n                {\"num_epochs\": 2, \"num_epochs_per_validation\": 1}\n\n        \"\"\"\n        self.train_params = deepcopy(params) if params is not None else {}\n        if \"CUDA_VISIBLE_DEVICES\" in self.train_params:\n            warnings.warn(\n                \"CUDA_VISIBLE_DEVICES is deprecated from 'set_training_params'. Use 'set_device_info' instead.\",\n                DeprecationWarning,\n            )\n\n        return self\n\n    def set_device_info(\n        self,\n        cuda_visible_devices: list[int] | str | None = None,\n        num_nodes: int | None = None,\n        mn_start_method: str | None = None,\n        cmd_prefix: str | None = None,\n    ) -> AutoRunner:\n        \"\"\"\n        Set the device related info\n\n        Args:\n            cuda_visible_devices: define GPU ids for data analyzer, training, and ensembling.\n                List of GPU ids [0,1,2,3] or a string \"0,1,2,3\".\n                Default using env \"CUDA_VISIBLE_DEVICES\" or all devices available.\n            num_nodes: number of nodes for training and ensembling.\n                Default using env \"NUM_NODES\" or 1 if \"NUM_NODES\" is unset.\n            mn_start_method: multi-node start method. 
Autorunner will use the method to start multi-node processes.\n                Default using env \"MN_START_METHOD\" or 'bcprun' if \"MN_START_METHOD\" is unset.\n            cmd_prefix: command line prefix for subprocess running in BundleAlgo and EnsembleRunner.\n                Default using env \"CMD_PREFIX\" or None, examples are:\n\n                    - single GPU/CPU or multinode bcprun: \"python \" or \"/opt/conda/bin/python3.9 \",\n                    - single node multi-GPU running \"torchrun --nnodes=1 --nproc_per_node=2 \"\n\n                If user define this prefix, please make sure --nproc_per_node matches cuda_visible_device or\n                os.env['CUDA_VISIBLE_DEVICES']. Also always set --nnodes=1. Set num_nodes for multi-node.\n        \"\"\"\n        self.device_setting: dict[str, Any] = {}\n        if cuda_visible_devices is None:\n            cuda_visible_devices = os.environ.get(\"CUDA_VISIBLE_DEVICES\")\n        if cuda_visible_devices is None:  # still None after reading the environ\n            self.device_setting[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([str(x) for x in range(torch.cuda.device_count())])\n            self.device_setting[\"n_devices\"] = torch.cuda.device_count()\n        elif isinstance(cuda_visible_devices, str):\n            self.device_setting[\"CUDA_VISIBLE_DEVICES\"] = cuda_visible_devices\n            self.device_setting[\"n_devices\"] = len(cuda_visible_devices.split(\",\"))\n        elif isinstance(cuda_visible_devices, (list, tuple)):\n            self.device_setting[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([str(x) for x in cuda_visible_devices])\n            self.device_setting[\"n_devices\"] = len(cuda_visible_devices)\n        else:\n            logger.warning(f\"Wrong format of cuda_visible_devices {cuda_visible_devices}, devices not set\")\n\n        if num_nodes is None:\n            num_nodes = int(os.environ.get(\"NUM_NODES\", 1))\n        self.device_setting[\"NUM_NODES\"] = num_nodes\n\n        if 
mn_start_method is None:\n            mn_start_method = os.environ.get(\"MN_START_METHOD\", \"bcprun\")\n        self.device_setting[\"MN_START_METHOD\"] = mn_start_method\n\n        if cmd_prefix is None:\n            cmd_prefix = os.environ.get(\"CMD_PREFIX\", \"\")\n        self.device_setting[\"CMD_PREFIX\"] = cmd_prefix\n\n        if cmd_prefix is not None:\n            logger.info(f\"Using user defined command running prefix {cmd_prefix}, will override other settings\")\n\n        return self\n\n    def set_ensemble_method(self, ensemble_method_name: str = \"AlgoEnsembleBestByFold\", **kwargs: Any) -> AutoRunner:\n        \"\"\"\n        Set the bundle ensemble method name and parameters for save image transform parameters.\n\n        Args:\n            ensemble_method_name: the name of the ensemble method. Only two methods are supported \"AlgoEnsembleBestN\"\n                and \"AlgoEnsembleBestByFold\".\n            kwargs: the keyword arguments used to define the ensemble method. Currently only ``n_best`` for\n                ``AlgoEnsembleBestN`` is supported.\n        \"\"\"\n        self.ensemble_method_name = look_up_option(\n            ensemble_method_name, supported=[\"AlgoEnsembleBestN\", \"AlgoEnsembleBestByFold\"]\n        )\n        self.kwargs.update(kwargs)\n\n        return self\n\n    def set_image_save_transform(self, **kwargs: Any) -> AutoRunner:\n        \"\"\"\n        Set the ensemble output transform.\n\n        Args:\n            kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage\n                transform. 
For more information, check https://monai.readthedocs.io/en/stable/transforms.html#saveimage.\n\n        \"\"\"\n\n        are_all_args_present, extra_args = check_kwargs_exist_in_class_init(SaveImage, kwargs)\n        if are_all_args_present:\n            self.kwargs.update(kwargs)\n        else:\n            raise ValueError(\n                f\"{extra_args} are not supported in monai.transforms.SaveImage,\"\n                \"Check https://monai.readthedocs.io/en/stable/transforms.html#saveimage for more information.\"\n            )\n\n        return self\n\n    def set_prediction_params(self, params: dict[str, Any] | None = None) -> AutoRunner:\n        \"\"\"\n        Set the prediction params for all algos.\n\n        Args:\n            params: a dict that defines the overriding key-value pairs during prediction. The overriding method\n                is defined by the algo class.\n\n        Examples:\n\n            For BundleAlgo objects, this set of param will specify the algo ensemble to only inference the first\n                two files in the testing datalist {\"file_slices\": slice(0, 2)}\n\n        \"\"\"\n        self.pred_params = deepcopy(params) if params is not None else {}\n\n        return self\n\n    def set_analyze_params(self, params: dict[str, Any] | None = None) -> AutoRunner:\n        \"\"\"\n        Set the data analysis extra params.\n\n        Args:\n            params: a dict that defines the overriding key-value pairs during training. The overriding method\n                is defined by the algo class.\n\n        \"\"\"\n        if params is None:\n            self.analyze_params = {\"do_ccp\": False, \"device\": \"cuda\"}\n        else:\n            self.analyze_params = deepcopy(params)\n\n        return self\n\n    def set_hpo_params(self, params: dict[str, Any] | None = None) -> AutoRunner:\n        \"\"\"\n        Set parameters for the HPO module and the algos before the training. 
It will attempt to (1) override bundle\n        templates with the key-value pairs in ``params`` (2) change the config of the HPO module (e.g. NNI) if the\n        key is found to be one of:\n\n            - \"trialCodeDirectory\"\n            - \"trialGpuNumber\"\n            - \"trialConcurrency\"\n            - \"maxTrialNumber\"\n            - \"maxExperimentDuration\"\n            - \"tuner\"\n            - \"trainingService\"\n\n        and (3) enable the dry-run mode if the user would generate the NNI configs without starting the NNI service.\n\n        Args:\n            params: a dict that defines the overriding key-value pairs during instantiation of the algo. For\n                BundleAlgo, it will override the template config filling.\n\n        Notes:\n            Users can set ``nni_dry_run`` to ``True`` in the ``params`` to enable the dry-run mode for the NNI backend.\n\n        \"\"\"\n        self.hpo_params = self.train_params if params is None else params\n\n        return self\n\n    def set_nni_search_space(self, search_space: dict[str, Any]) -> AutoRunner:\n        \"\"\"\n        Set the search space for NNI parameter search.\n\n        Args:\n            search_space: hyper parameter search space in the form of dict. For more information, please check\n                NNI documentation: https://nni.readthedocs.io/en/v2.2/Tutorial/SearchSpaceSpec.html .\n        \"\"\"\n        value_combinations = 1\n        for k, v in search_space.items():\n            if \"_value\" not in v:\n                raise ValueError(f\"{search_space} key {k} value {v} has not _value\")\n            value_combinations *= len(v[\"_value\"])\n\n        self.search_space = search_space\n        self.hpo_tasks = value_combinations\n\n        return self\n\n    def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None:\n        \"\"\"\n        Train the Algos in a sequential scheme. 
The order of training is randomized.\n\n        Args:\n            history: the history of generated Algos. It is a list of dicts. Each element has the task name\n                (e.g. \"dints_0\" for dints network in fold 0) as the key and the algo object as the value.\n                After the training, the algo object with the ``best_metric`` will be saved as a pickle file.\n\n        Note:\n            The final results of the model training will be written to all the generated algorithm's output\n            folders under the working directory. The results include the model checkpoints, a\n            progress.yaml, accuracies in CSV and a pickle file of the Algo object.\n        \"\"\"\n        for algo_dict in history:\n            algo = algo_dict[AlgoKeys.ALGO]\n            if has_option(algo.train, \"device_setting\"):\n                algo.train(self.train_params, self.device_setting)\n            else:\n                algo.train(self.train_params)\n            acc = algo.get_score()\n\n            algo_meta_data = {str(AlgoKeys.SCORE): acc}\n            algo_to_pickle(algo, template_path=algo.template_path, **algo_meta_data)\n\n    def _train_algo_in_nni(self, history: list[dict[str, Any]]) -> None:\n        \"\"\"\n        Train the Algos using HPO.\n\n        Args:\n            history: the history of generated Algos. It is a list of dicts. Each element has the task name\n                (e.g. \"dints_0\" for dints network in fold 0) as the key and the algo object as the value.\n                After the training, the algo object with the ``best_metric`` will be saved as a pickle file.\n\n        Note:\n            The final results of the model training will not be written to all the previously generated\n            algorithm's output folders. 
Instead, HPO will generate a new algo during the searching, and\n            the new algo will be saved under the working directory with a different format of the name.\n            For example, if the searching space has \"learning_rate\", the result of HPO will be written to\n            a folder name with original task name and the param (e.g. \"dints_0_learning_rate_0.001\").\n            The results include the model checkpoints, a progress.yaml, accuracies in CSV and a pickle\n            file of the Algo object.\n\n        \"\"\"\n        default_nni_config = {\n            \"trialCodeDirectory\": \".\",\n            \"trialGpuNumber\": torch.cuda.device_count(),\n            \"trialConcurrency\": 1,\n            \"maxTrialNumber\": 10,\n            \"maxExperimentDuration\": \"1h\",\n            \"tuner\": {\"name\": \"GridSearch\"},\n            \"trainingService\": {\"platform\": \"local\", \"useActiveGpu\": True},\n        }\n\n        last_total_tasks = len(import_bundle_algo_history(self.work_dir, only_trained=True))\n        mode_dry_run = self.hpo_params.pop(\"nni_dry_run\", False)\n        for algo_dict in history:\n            name = algo_dict[AlgoKeys.ID]\n            algo = algo_dict[AlgoKeys.ALGO]\n            nni_gen = NNIGen(algo=algo, params=self.hpo_params)\n            obj_filename = nni_gen.get_obj_filename()\n            nni_config = deepcopy(default_nni_config)\n            # override the default nni config with the same key in hpo_params\n            for key in self.hpo_params:\n                if key in nni_config:\n                    nni_config[key] = self.hpo_params[key]\n            nni_config.update({\"experimentName\": name})\n            nni_config.update({\"search_space\": self.search_space})\n            trial_cmd = \"python -m monai.apps.auto3dseg NNIGen run_algo \" + obj_filename + \" \" + self.work_dir\n            nni_config.update({\"trialCommand\": trial_cmd})\n            nni_config_filename = 
os.path.abspath(os.path.join(self.work_dir, f\"{name}_nni_config.yaml\"))\n            ConfigParser.export_config_file(nni_config, nni_config_filename, fmt=\"yaml\", default_flow_style=None)\n\n            max_trial = min(self.hpo_tasks, cast(int, default_nni_config[\"maxTrialNumber\"]))\n            cmd = \"nnictl create --config \" + nni_config_filename + \" --port 8088\"\n\n            if mode_dry_run:\n                logger.info(f\"AutoRunner HPO is in dry-run mode. Please manually launch: {cmd}\")\n                continue\n\n            run_cmd(cmd.split(), check=True)\n\n            n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))\n            while n_trainings - last_total_tasks < max_trial:\n                sleep(1)\n                n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))\n\n            cmd = \"nnictl stop --all\"\n            run_cmd(cmd.split(), check=True)\n            logger.info(f\"NNI completes HPO on {name}\")\n            last_total_tasks = n_trainings\n\n    def run(self):\n        \"\"\"\n        Run the AutoRunner pipeline\n        \"\"\"\n        # step 1: data analysis\n        if self.analyze and self.analyze_params is not None:\n            logger.info(\"Running data analysis...\")\n            da = DataAnalyzer(\n                self.datalist_filename, self.dataroot, output_path=self.datastats_filename, **self.analyze_params\n            )\n            da.get_all_case_stats()\n\n            da = None  # type: ignore\n            torch.cuda.empty_cache()\n\n            self.export_cache(analyze=True, datastats=self.datastats_filename)\n        else:\n            logger.info(\"Skipping data analysis...\")\n\n        # step 2: algorithm generation\n        if self.algo_gen:\n            if not os.path.isfile(self.datastats_filename):\n                raise ValueError(\n                    f\"Could not find the datastats file {self.datastats_filename}. 
\"\n                    \"Possibly the required data analysis step was not completed.\"\n                )\n\n            bundle_generator = BundleGen(\n                algos=self.algos,\n                algo_path=self.work_dir,\n                templates_path_or_url=self.templates_path_or_url,\n                data_stats_filename=self.datastats_filename,\n                data_src_cfg_name=self.data_src_cfg_name,\n                mlflow_tracking_uri=self.mlflow_tracking_uri,\n                mlflow_experiment_name=self.mlflow_experiment_name,\n            )\n\n            if self.gpu_customization:\n                bundle_generator.generate(\n                    self.work_dir,\n                    num_fold=self.num_fold,\n                    gpu_customization=self.gpu_customization,\n                    gpu_customization_specs=self.gpu_customization_specs,\n                    allow_skip=self.allow_skip,\n                )\n            else:\n                bundle_generator.generate(self.work_dir, num_fold=self.num_fold, allow_skip=self.allow_skip)\n            history = bundle_generator.get_history()\n            export_bundle_algo_history(history)\n            self.export_cache(algo_gen=True)\n        else:\n            logger.info(\"Skipping algorithm generation...\")\n\n        # step 3: algo training\n        auto_train_choice = self.train is None\n        if self.train or (auto_train_choice and not self.cache[\"train\"]):\n            history = import_bundle_algo_history(self.work_dir, only_trained=False)\n\n            if len(history) == 0:\n                raise ValueError(\n                    f\"Could not find training scripts in {self.work_dir}. 
\"\n                    \"Possibly the required algorithms generation step was not completed.\"\n                )\n\n            if auto_train_choice:\n                skip_algos = [h[AlgoKeys.ID] for h in history if h[AlgoKeys.IS_TRAINED]]\n                if skip_algos:\n                    logger.info(\n                        f\"Skipping already trained algos {skip_algos}.\"\n                        \"Set option train=True to always retrain all algos.\"\n                    )\n                    history = [h for h in history if not h[AlgoKeys.IS_TRAINED]]\n\n            if len(history) > 0:\n                if not self.hpo:\n                    self._train_algo_in_sequence(history)\n                else:\n                    self._train_algo_in_nni(history)\n\n            self.export_cache(train=True)\n        else:\n            logger.info(\"Skipping algorithm training...\")\n\n        # step 4: model ensemble and write the prediction to disks.\n        if self.ensemble:\n            ensemble_runner = EnsembleRunner(\n                data_src_cfg_name=self.data_src_cfg_name,\n                work_dir=self.work_dir,\n                num_fold=self.num_fold,\n                ensemble_method_name=self.ensemble_method_name,\n                mgpu=int(self.device_setting[\"n_devices\"]) > 1,\n                **self.kwargs,  # for set_image_save_transform\n                **self.pred_params,\n            )  # for inference\n            ensemble_runner.run(self.device_setting)\n        logger.info(\"Auto3Dseg pipeline is completed successfully.\")\n"
  },
  {
    "path": "monai/apps/auto3dseg/bundle_gen.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport importlib\nimport os\nimport re\nimport shutil\nimport subprocess\nimport sys\nimport time\nimport warnings\nfrom copy import deepcopy\nfrom pathlib import Path\nfrom tempfile import TemporaryDirectory\nfrom typing import Any\nfrom urllib.parse import urlparse\n\nimport torch\n\nfrom monai.apps import download_and_extract\nfrom monai.apps.utils import get_logger\nfrom monai.auto3dseg.algo_gen import Algo, AlgoGen\nfrom monai.auto3dseg.utils import (\n    _prepare_cmd_bcprun,\n    _prepare_cmd_default,\n    _prepare_cmd_torchrun,\n    _run_cmd_bcprun,\n    _run_cmd_torchrun,\n    algo_to_pickle,\n)\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.config import PathLike\nfrom monai.utils import ensure_tuple, look_up_option, run_cmd\nfrom monai.utils.enums import AlgoKeys\nfrom monai.utils.misc import MONAIEnvVars\n\nlogger = get_logger(module_name=__name__)\nALGO_HASH = MONAIEnvVars.algo_hash()\n\n__all__ = [\"BundleAlgo\", \"BundleGen\"]\n\n\nclass BundleAlgo(Algo):\n    \"\"\"\n    An algorithm represented by a set of bundle configurations and scripts.\n\n    ``BundleAlgo.cfg`` is a ``monai.bundle.ConfigParser`` instance.\n\n    .. 
code-block:: python\n\n        from monai.apps.auto3dseg import BundleAlgo\n\n        data_stats_yaml = \"../datastats.yaml\"\n        algo = BundleAlgo(template_path=\"../algorithm_templates\")\n        algo.set_data_stats(data_stats_yaml)\n        # algo.set_data_src(\"../data_src.json\")\n        algo.export_to_disk(\".\", algo_name=\"segresnet2d_1\")\n\n    This class creates MONAI bundles from a directory of 'bundle template'. Different from the regular MONAI bundle\n    format, the bundle template may contain placeholders that must be filled using ``fill_template_config`` during\n    ``export_to_disk``. Then created bundle keeps the same file structure as the template.\n\n    \"\"\"\n\n    def __init__(self, template_path: PathLike):\n        \"\"\"\n        Create an Algo instance based on the predefined Algo template.\n\n        Args:\n            template_path: path to a folder that contains the algorithm templates.\n                Please check https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg/algorithm_templates\n\n        \"\"\"\n\n        self.template_path = template_path\n        self.data_stats_files = \"\"\n        self.data_list_file = \"\"\n        self.mlflow_tracking_uri: str | None = None\n        self.mlflow_experiment_name: str | None = None\n        self.output_path = \"\"\n        self.name = \"\"\n        self.best_metric = None\n        # track records when filling template config: {\"<config name>\": {\"<placeholder key>\": value, ...}, ...}\n        self.fill_records: dict = {}\n        # device_setting set default value and sanity check, in case device_setting not from autorunner\n        self.device_setting: dict[str, int | str] = {\n            \"CUDA_VISIBLE_DEVICES\": \",\".join([str(x) for x in range(torch.cuda.device_count())]),\n            \"n_devices\": int(torch.cuda.device_count()),\n            \"NUM_NODES\": int(os.environ.get(\"NUM_NODES\", 1)),\n            \"MN_START_METHOD\": 
os.environ.get(\"MN_START_METHOD\", \"bcprun\"),\n            \"CMD_PREFIX\": os.environ.get(\"CMD_PREFIX\", \"\"),\n        }\n\n    def pre_check_skip_algo(self, skip_bundlegen: bool = False, skip_info: str = \"\") -> tuple[bool, str]:\n        \"\"\"\n        Analyse the data analysis report and check if the algorithm needs to be skipped.\n        This function is overridden within algo.\n        Args:\n            skip_bundlegen: skip generating bundles for this algo if true.\n            skip_info: info to print when skipped.\n        \"\"\"\n        return skip_bundlegen, skip_info\n\n    def set_data_stats(self, data_stats_files: str) -> None:\n        \"\"\"\n        Set the data analysis report (generated by DataAnalyzer).\n\n        Args:\n            data_stats_files: path to the datastats yaml file\n        \"\"\"\n        self.data_stats_files = data_stats_files\n\n    def set_data_source(self, data_src_cfg: str) -> None:\n        \"\"\"\n        Set the data source configuration file\n\n        Args:\n            data_src_cfg: path to a configuration file (yaml) that contains datalist, dataroot, and other params.\n                The config will be in a form of {\"modality\": \"ct\", \"datalist\": \"path_to_json_datalist\", \"dataroot\":\n                \"path_dir_data\"}\n        \"\"\"\n        self.data_list_file = data_src_cfg\n\n    def set_mlflow_tracking_uri(self, mlflow_tracking_uri: str | None) -> None:\n        \"\"\"\n        Set the tracking URI for MLflow server\n\n        Args:\n            mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of\n                the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if\n                the value is None.\n        \"\"\"\n        self.mlflow_tracking_uri = mlflow_tracking_uri\n\n    def set_mlflow_experiment_name(self, mlflow_experiment_name: str | None) -> None:\n        \"\"\"\n        Set the experiment 
name for MLflow server\n\n        Args:\n            mlflow_experiment_name: a string to specify the experiment name for MLflow server.\n        \"\"\"\n        self.mlflow_experiment_name = mlflow_experiment_name\n\n    def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwargs: Any) -> dict:\n        \"\"\"\n        The configuration files defined when constructing this Algo instance might not have a complete training\n        and validation pipelines. Some configuration components and hyperparameters of the pipelines depend on the\n        training data and other factors. This API is provided to allow the creation of fully functioning config files.\n        Return the records of filling template config: {\"<config name>\": {\"<placeholder key>\": value, ...}, ...}.\n\n        Args:\n            data_stats_filename: filename of the data stats report (generated by DataAnalyzer)\n\n        Notes:\n            Template filling is optional. The user can construct a set of pre-filled configs without replacing values\n            by using the data analysis results. It is also intended to be re-implemented in subclasses of BundleAlgo\n            if the user wants their own way of auto-configured template filling.\n        \"\"\"\n        return {}\n\n    def export_to_disk(self, output_path: str, algo_name: str, **kwargs: Any) -> None:\n        \"\"\"\n        Fill the configuration templates, write the bundle (configs + scripts) to folder `output_path/algo_name`.\n\n        Args:\n            output_path: Path to export the 'scripts' and 'configs' directories.\n            algo_name: the identifier of the algorithm (usually contains the name and extra info like fold ID).\n            kwargs: other parameters, including: \"copy_dirs=True/False\" means whether to copy the template as output\n                instead of inplace operation, \"fill_template=True/False\" means whether to fill the placeholders\n                in the template. 
other parameters are for `fill_template_config` function.\n\n        \"\"\"\n        if kwargs.pop(\"copy_dirs\", True):\n            self.output_path = os.path.join(output_path, algo_name)\n            os.makedirs(self.output_path, exist_ok=True)\n            if os.path.isdir(self.output_path):\n                shutil.rmtree(self.output_path)\n            # copy algorithm_templates/<Algo> to the working directory output_path\n            shutil.copytree(os.path.join(str(self.template_path), self.name), self.output_path)\n        else:\n            self.output_path = str(self.template_path)\n        if kwargs.pop(\"fill_template\", True):\n            self.fill_records = self.fill_template_config(self.data_stats_files, self.output_path, **kwargs)\n        logger.info(f\"Generated:{self.output_path}\")\n\n    def _create_cmd(self, train_params: None | dict = None) -> tuple[str, str]:\n        \"\"\"\n        Create the command to execute training.\n\n        \"\"\"\n        if train_params is None:\n            train_params = {}\n        params = deepcopy(train_params)\n\n        train_py = os.path.join(self.output_path, \"scripts\", \"train.py\")\n        config_dir = os.path.join(self.output_path, \"configs\")\n\n        config_files = []\n        if os.path.isdir(config_dir):\n            for file in sorted(os.listdir(config_dir)):\n                if file.endswith((\"yaml\", \"json\")):\n                    # Python Fire may be confused by single-quoted WindowsPath\n                    config_files.append(Path(os.path.join(config_dir, file)).as_posix())\n\n        if int(self.device_setting[\"NUM_NODES\"]) > 1:\n            # multi-node command\n            # only bcprun is supported for now\n            try:\n                look_up_option(self.device_setting[\"MN_START_METHOD\"], [\"bcprun\"])\n            except ValueError as err:\n                raise NotImplementedError(\n                    f\"{self.device_setting['MN_START_METHOD']} is not supported 
yet.\"\n                    \"Try modify BundleAlgo._create_cmd for your cluster.\"\n                ) from err\n\n            return (\n                _prepare_cmd_bcprun(\n                    f\"{train_py} run\",\n                    cmd_prefix=f\"{self.device_setting['CMD_PREFIX']}\",\n                    config_file=config_files,\n                    **params,\n                ),\n                \"\",\n            )\n        elif int(self.device_setting[\"n_devices\"]) > 1:\n            return _prepare_cmd_torchrun(f\"{train_py} run\", config_file=config_files, **params), \"\"\n        else:\n            return (\n                _prepare_cmd_default(\n                    f\"{train_py} run\",\n                    cmd_prefix=f\"{self.device_setting['CMD_PREFIX']}\",\n                    config_file=config_files,\n                    **params,\n                ),\n                \"\",\n            )\n\n    def _run_cmd(self, cmd: str, devices_info: str = \"\") -> subprocess.CompletedProcess:\n        \"\"\"\n        Execute the training command with target devices information.\n\n        \"\"\"\n        if devices_info:\n            warnings.warn(f\"input devices_info {devices_info} is deprecated and ignored.\")\n\n        ps_environ = os.environ.copy()\n        ps_environ[\"CUDA_VISIBLE_DEVICES\"] = str(self.device_setting[\"CUDA_VISIBLE_DEVICES\"])\n\n        # delete pattern \"VAR=VALUE\" at the beginning of the string, with optional leading/trailing whitespaces\n        cmd = re.sub(r\"^\\s*\\w+=.*?\\s+\", \"\", cmd)\n\n        if int(self.device_setting[\"NUM_NODES\"]) > 1:\n            try:\n                look_up_option(self.device_setting[\"MN_START_METHOD\"], [\"bcprun\"])\n            except ValueError as err:\n                raise NotImplementedError(\n                    f\"{self.device_setting['MN_START_METHOD']} is not supported yet.\"\n                    \"Try modify BundleAlgo._run_cmd for your cluster.\"\n                ) from err\n\n      
      return _run_cmd_bcprun(cmd, n=self.device_setting[\"NUM_NODES\"], p=self.device_setting[\"n_devices\"])\n        elif int(self.device_setting[\"n_devices\"]) > 1:\n            return _run_cmd_torchrun(\n                cmd, nnodes=1, nproc_per_node=self.device_setting[\"n_devices\"], env=ps_environ, check=True\n            )\n        else:\n            return run_cmd(cmd.split(), run_cmd_verbose=True, env=ps_environ, check=True)\n\n    def train(\n        self, train_params: None | dict = None, device_setting: None | dict = None\n    ) -> subprocess.CompletedProcess:\n        \"\"\"\n        Load the run function in the training script of each model. Training parameter is predefined by the\n        algo_config.yaml file, which is pre-filled by the fill_template_config function in the same instance.\n\n        Args:\n            train_params:  training parameters\n            device_setting: device related settings, should follow the device_setting in auto_runner.set_device_info.\n                'CUDA_VISIBLE_DEVICES' should be a string e.g. 
'0,1,2,3'\n        \"\"\"\n        if device_setting is not None:\n            self.device_setting.update(device_setting)\n            self.device_setting[\"n_devices\"] = len(str(self.device_setting[\"CUDA_VISIBLE_DEVICES\"]).split(\",\"))\n\n        if train_params is not None and \"CUDA_VISIBLE_DEVICES\" in train_params:\n            warnings.warn(\"CUDA_VISIBLE_DEVICES is deprecated from train_params!\")\n            train_params.pop(\"CUDA_VISIBLE_DEVICES\")\n\n        cmd, _unused_return = self._create_cmd(train_params)\n        return self._run_cmd(cmd)\n\n    def get_score(self, *args, **kwargs):\n        \"\"\"\n        Returns validation scores of the model trained by the current Algo.\n        \"\"\"\n        config_yaml = os.path.join(self.output_path, \"configs\", \"hyper_parameters.yaml\")\n        parser = ConfigParser()\n        parser.read_config(config_yaml)\n        ckpt_path = parser.get_parsed_content(\"ckpt_path\", default=self.output_path)\n\n        dict_file = ConfigParser.load_config_file(os.path.join(ckpt_path, \"progress.yaml\"))\n        # dict_file: a list of scores saved in the form of dict in progress.yaml\n        return dict_file[-1][\"best_avg_dice_score\"]  # the last one is the best one\n\n    def get_inferer(self, *args, **kwargs):\n        \"\"\"\n        Load the InferClass from the infer.py. The InferClass should be defined in the template under the path of\n        `\"scripts/infer.py\"`. It is required to define the \"InferClass\" (name is fixed) with two functions at least\n        (``__init__`` and ``infer``). The init class has an override kwargs that can be used to override parameters in\n        the run-time optionally.\n\n        Examples:\n\n        .. 
code-block:: python\n\n            class InferClass\n                def __init__(self, config_file: Optional[Union[str, Sequence[str]]] = None, **override):\n                    # read configs from config_file (sequence)\n                    # set up transforms\n                    # set up model\n                    # set up other hyper parameters\n                    return\n\n                @torch.no_grad()\n                def infer(self, image_file):\n                    # infer the model and save the results to output\n                    return output\n\n        \"\"\"\n        infer_py = os.path.join(self.output_path, \"scripts\", \"infer.py\")\n        if not os.path.isfile(infer_py):\n            raise ValueError(f\"{infer_py} is not found, please check the path.\")\n\n        config_dir = os.path.join(self.output_path, \"configs\")\n        configs_path = [os.path.join(config_dir, f) for f in os.listdir(config_dir)]\n\n        spec = importlib.util.spec_from_file_location(\"InferClass\", infer_py)\n        infer_class = importlib.util.module_from_spec(spec)  # type: ignore\n        sys.modules[\"InferClass\"] = infer_class\n        spec.loader.exec_module(infer_class)  # type: ignore\n        return infer_class.InferClass(configs_path, *args, **kwargs)\n\n    def predict(self, predict_files: list, predict_params: dict | None = None) -> list:\n        \"\"\"\n        Use the trained model to predict the outputs with a given input image.\n\n        Args:\n            predict_files: a list of paths to files to run inference on [\"path_to_image_1\", \"path_to_image_2\"]\n            predict_params: a dict to override the parameters in the bundle config (including the files to predict).\n\n        \"\"\"\n        params = {} if predict_params is None else deepcopy(predict_params)\n        inferer = self.get_inferer(**params)\n        return [inferer.infer(f) for f in ensure_tuple(predict_files)]\n\n    def get_output_path(self):\n        \"\"\"Returns the 
algo output paths to find the algo scripts and configs.\"\"\"\n        return self.output_path\n\n\n# path to download the algo_templates\ndefault_algo_zip = (\n    f\"https://github.com/Project-MONAI/research-contributions/releases/download/algo_templates/{ALGO_HASH}.tar.gz\"\n)\n\n# default algorithms\ndefault_algos = {\n    \"segresnet2d\": dict(_target_=\"segresnet2d.scripts.algo.Segresnet2dAlgo\"),\n    \"dints\": dict(_target_=\"dints.scripts.algo.DintsAlgo\"),\n    \"swinunetr\": dict(_target_=\"swinunetr.scripts.algo.SwinunetrAlgo\"),\n    \"segresnet\": dict(_target_=\"segresnet.scripts.algo.SegresnetAlgo\"),\n}\n\n\ndef _download_algos_url(url: str, at_path: str) -> dict[str, dict[str, str]]:\n    \"\"\"\n    Downloads the algorithm templates release archive, and extracts it into a parent directory of the at_path folder.\n    Returns a dictionary of the algorithm templates.\n    \"\"\"\n    at_path = os.path.abspath(at_path)\n    zip_download_dir = TemporaryDirectory()\n    algo_compressed_file = os.path.join(zip_download_dir.name, \"algo_templates.tar.gz\")\n\n    download_attempts = 3\n    for i in range(download_attempts):\n        try:\n            download_and_extract(url=url, filepath=algo_compressed_file, output_dir=os.path.dirname(at_path))\n        except Exception as e:\n            msg = f\"Download and extract of {url} failed, attempt {i + 1}/{download_attempts}.\"\n            if i < download_attempts - 1:\n                warnings.warn(msg)\n                time.sleep(i)\n            else:\n                zip_download_dir.cleanup()\n                raise ValueError(msg) from e\n        else:\n            break\n\n    zip_download_dir.cleanup()\n\n    algos_all = deepcopy(default_algos)\n    for name in algos_all:\n        algos_all[name][\"template_path\"] = at_path\n\n    return algos_all\n\n\ndef _copy_algos_folder(folder, at_path):\n    \"\"\"\n    Copies the algorithm templates folder to at_path.\n    Returns a dictionary of algorithm 
templates.\n    \"\"\"\n    folder = os.path.abspath(folder)\n    at_path = os.path.abspath(at_path)\n\n    if folder != at_path:\n        if os.path.exists(at_path):\n            shutil.rmtree(at_path)\n        shutil.copytree(folder, at_path)\n\n    algos_all = {}\n    for name in os.listdir(at_path):\n        if os.path.exists(os.path.join(folder, name, \"scripts\", \"algo.py\")):\n            algos_all[name] = dict(_target_=f\"{name}.scripts.algo.{name.capitalize()}Algo\", template_path=at_path)\n            logger.info(f\"Copying template: {name} -- {algos_all[name]}\")\n    if not algos_all:\n        raise ValueError(f\"Unable to find any algos in {folder}\")\n\n    return algos_all\n\n\nclass BundleGen(AlgoGen):\n    \"\"\"\n    This class generates a set of bundles according to the cross-validation folds, each of them can run independently.\n\n    Args:\n        algo_path: the directory path to save the algorithm templates. Default is the current working dir.\n        algos: If dictionary, it outlines the algorithm to use. If a list or a string, defines a subset of names of\n            the algorithms to use, e.g. ('segresnet', 'dints') out of the full set of algorithm templates provided\n            by templates_path_or_url. Defaults to None - to use all available algorithms.\n        templates_path_or_url: the folder with the algorithm templates or a url. If None provided, the default template\n            zip url will be downloaded and extracted into the algo_path. The current default options are released at:\n            https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg.\n        data_stats_filename: the path to the data stats file (generated by DataAnalyzer).\n        data_src_cfg_name: the path to the data source config YAML file. 
The config will be in a form of\n                           {\"modality\": \"ct\", \"datalist\": \"path_to_json_datalist\", \"dataroot\": \"path_dir_data\"}.\n        mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of\n            the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if\n            the value is None.\n        mlflow_experiment_name: a string to specify the experiment name for MLflow server.\n    .. code-block:: bash\n\n        python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename=\"../algorithms/datastats.yaml\"\n    \"\"\"\n\n    def __init__(\n        self,\n        algo_path: str = \".\",\n        algos: dict | list | str | None = None,\n        templates_path_or_url: str | None = None,\n        data_stats_filename: str | None = None,\n        data_src_cfg_name: str | None = None,\n        mlflow_tracking_uri: str | None = None,\n        mlflow_experiment_name: str | None = None,\n    ):\n        if algos is None or isinstance(algos, (list, tuple, str)):\n            if templates_path_or_url is None:\n                templates_path_or_url = default_algo_zip\n\n            at_path = os.path.join(os.path.abspath(algo_path), \"algorithm_templates\")\n\n            if os.path.isdir(templates_path_or_url):\n                # if a local folder, copy if necessary\n                logger.info(f\"BundleGen from directory {templates_path_or_url}\")\n                algos_all = _copy_algos_folder(folder=templates_path_or_url, at_path=at_path)\n            elif urlparse(templates_path_or_url).scheme in (\"http\", \"https\"):\n                # if url, trigger the download and extract process\n                logger.info(f\"BundleGen from {templates_path_or_url}\")\n                algos_all = _download_algos_url(url=templates_path_or_url, at_path=at_path)\n            else:\n                raise ValueError(f\"{self.__class__} received invalid 
templates_path_or_url: {templates_path_or_url}\")\n\n            if algos is not None:\n                algos = {k: v for k, v in algos_all.items() if k in ensure_tuple(algos)}  # keep only provided\n                if len(algos) == 0:\n                    raise ValueError(f\"Unable to find provided algos in {algos_all}\")\n            else:\n                algos = algos_all\n\n        self.algos: Any = []\n        if isinstance(algos, dict):\n            for algo_name, algo_params in sorted(algos.items()):\n                template_path = algo_params.get(\"template_path\", \".\")\n                if len(template_path) > 0 and template_path not in sys.path:\n                    sys.path.append(template_path)\n\n                try:\n                    onealgo = ConfigParser(algo_params).get_parsed_content()\n                    onealgo.name = algo_name\n                    self.algos.append(onealgo)\n                except RuntimeError as e:\n                    msg = \"\"\"Please make sure the folder structure of an Algo Template follows\n                        [algo_name]\n                        ├── configs\n                        │   ├── hyper_parameters.yaml  # automatically generated yaml from a set of ``template_configs``\n                        └── scripts\n                            ├── test.py\n                            ├── __init__.py\n                            └── validate.py\n                    \"\"\"\n                    raise RuntimeError(msg) from e\n        else:\n            raise ValueError(\"Unexpected error algos is not a dict\")\n\n        self.data_stats_filename = data_stats_filename\n        self.data_src_cfg_name = data_src_cfg_name\n        self.mlflow_tracking_uri = mlflow_tracking_uri\n        self.mlflow_experiment_name = mlflow_experiment_name\n        self.history: list[dict] = []\n\n    def set_data_stats(self, data_stats_filename: str) -> None:\n        \"\"\"\n        Set the data stats filename\n\n        Args:\n       
     data_stats_filename: filename of datastats\n        \"\"\"\n        self.data_stats_filename = data_stats_filename\n\n    def get_data_stats(self):\n        \"\"\"Get the filename of the data stats\"\"\"\n        return self.data_stats_filename\n\n    def set_data_src(self, data_src_cfg_name):\n        \"\"\"\n        Set the data source filename\n\n        Args:\n            data_src_cfg_name: filename of data_source file\n        \"\"\"\n        self.data_src_cfg_name = data_src_cfg_name\n\n    def get_data_src(self):\n        \"\"\"Get the data source filename\"\"\"\n        return self.data_src_cfg_name\n\n    def set_mlflow_tracking_uri(self, mlflow_tracking_uri):\n        \"\"\"\n        Set the tracking URI for MLflow server\n\n        Args:\n            mlflow_tracking_uri: a tracking URI for MLflow server which could be local directory or address of\n                the remote tracking Server; MLflow runs will be recorded locally in algorithms' model folder if\n                the value is None.\n        \"\"\"\n        self.mlflow_tracking_uri = mlflow_tracking_uri\n\n    def set_mlflow_experiment_name(self, mlflow_experiment_name):\n        \"\"\"\n        Set the experiment name for MLflow server\n\n        Args:\n            mlflow_experiment_name: a string to specify the experiment name for MLflow server.\n        \"\"\"\n        self.mlflow_experiment_name = mlflow_experiment_name\n\n    def get_mlflow_tracking_uri(self):\n        \"\"\"Get the tracking URI for MLflow server\"\"\"\n        return self.mlflow_tracking_uri\n\n    def get_mlflow_experiment_name(self):\n        \"\"\"Get the experiment name for MLflow server\"\"\"\n        return self.mlflow_experiment_name\n\n    def get_history(self) -> list:\n        \"\"\"Get the history of the bundleAlgo object with their names/identifiers\"\"\"\n        return self.history\n\n    def generate(\n        self,\n        output_folder: str = \".\",\n        num_fold: int = 5,\n        
gpu_customization: bool = False,\n        gpu_customization_specs: dict[str, Any] | None = None,\n        allow_skip: bool = True,\n    ) -> None:\n        \"\"\"\n        Generate the bundle scripts/configs for each bundleAlgo\n\n        Args:\n            output_folder: the output folder to save each algorithm.\n            num_fold: the number of cross validation fold.\n            gpu_customization: the switch to determine automatically customize/optimize bundle script/config\n                parameters for each bundleAlgo based on gpus. Custom parameters are obtained through dummy\n                training to simulate the actual model training process and hyperparameter optimization (HPO)\n                experiments.\n            gpu_customization_specs: the dictionary to enable users overwrite the HPO settings. user can\n                overwrite part of variables as follows or all of them. The structure is as follows.\n            allow_skip: a switch to determine if some Algo in the default templates can be skipped based on the\n                analysis on the dataset from Auto3DSeg DataAnalyzer.\n\n                .. code-block:: python\n\n                    gpu_customization_specs = {\n                        'ALGO': {\n                            'num_trials': 6,\n                            'range_num_images_per_batch': [1, 20],\n                            'range_num_sw_batch_size': [1, 20]\n                        }\n                    }\n\n            ALGO: the name of algorithm. It could be one of algorithm names (e.g., 'dints') or 'universal' which\n                would apply changes to all algorithms. 
Possible options are\n\n                - {``\"universal\"``, ``\"dints\"``, ``\"segresnet\"``, ``\"segresnet2d\"``, ``\"swinunetr\"``}.\n\n            num_trials: the number of HPO trials/experiments to run.\n            range_num_images_per_batch: the range of number of images per mini-batch.\n            range_num_sw_batch_size: the range of batch size in sliding-window inferer.\n        \"\"\"\n        fold_idx = list(range(num_fold))\n        for algo in self.algos:\n            for f_id in ensure_tuple(fold_idx):\n                data_stats = self.get_data_stats()\n                data_src_cfg = self.get_data_src()\n                mlflow_tracking_uri = self.get_mlflow_tracking_uri()\n                mlflow_experiment_name = self.get_mlflow_experiment_name()\n                gen_algo = deepcopy(algo)\n                gen_algo.set_data_stats(data_stats)\n                gen_algo.set_data_source(data_src_cfg)\n                gen_algo.set_mlflow_tracking_uri(mlflow_tracking_uri)\n                gen_algo.set_mlflow_experiment_name(mlflow_experiment_name)\n                name = f\"{gen_algo.name}_{f_id}\"\n\n                if allow_skip:\n                    skip_bundlegen, skip_info = gen_algo.pre_check_skip_algo()\n                    if skip_bundlegen:\n                        logger.info(f\"{name} is skipped! 
{skip_info}\")\n                        continue\n\n                if gpu_customization:\n                    gen_algo.export_to_disk(\n                        output_folder,\n                        name,\n                        fold=f_id,\n                        gpu_customization=True,\n                        gpu_customization_specs=gpu_customization_specs,\n                    )\n                else:\n                    gen_algo.export_to_disk(output_folder, name, fold=f_id)\n\n                algo_to_pickle(gen_algo, template_path=algo.template_path)\n                self.history.append(\n                    {AlgoKeys.ID: name, AlgoKeys.ALGO: gen_algo}\n                )  # track the previous, may create a persistent history\n"
  },
  {
    "path": "monai/apps/auto3dseg/data_analyzer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom os import path\nfrom typing import Any, cast\n\nimport numpy as np\nimport torch\nfrom torch.multiprocessing import get_context\n\nfrom monai.apps.auto3dseg.transforms import EnsureSameShaped\nfrom monai.apps.utils import get_logger\nfrom monai.auto3dseg import SegSummarizer\nfrom monai.auto3dseg.utils import datafold_read\nfrom monai.bundle import config_parser\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.data import DataLoader, Dataset, partition_dataset\nfrom monai.data.utils import no_collation\nfrom monai.transforms import Compose, EnsureTyped, LoadImaged, Orientationd\nfrom monai.utils import ImageMetaKey, StrEnum, min_version, optional_import\nfrom monai.utils.enums import DataStatsKeys, ImageStatsKeys\n\n\ndef strenum_representer(dumper, data):\n    return dumper.represent_scalar(\"tag:yaml.org,2002:str\", data.value)\n\n\nif optional_import(\"yaml\")[1]:\n    config_parser.yaml.SafeDumper.add_multi_representer(StrEnum, strenum_representer)\n\ntqdm, has_tqdm = optional_import(\"tqdm\", \"4.47.0\", min_version, \"tqdm\")\nlogger = get_logger(module_name=__name__)\n\n__all__ = [\"DataAnalyzer\"]\n\n\nclass DataAnalyzer:\n    \"\"\"\n    The DataAnalyzer automatically analyzes given medical image dataset and reports the statistics.\n    The module expects file paths to the image data and 
utilizes the LoadImaged transform to read the\n    files, which supports nii, nii.gz, png, jpg, bmp, npz, npy, and dcm formats. Currently, only\n    segmentation task is supported, so the user needs to provide paths to the image and label files\n    (if have). Also, label data format is preferred to be (1,H,W,D), with the label index in the\n    first dimension. If it is in onehot format, it will be converted to the preferred format.\n\n    Args:\n        datalist: a Python dictionary storing group, fold, and other information of the medical\n            image dataset, or a string to the JSON file storing the dictionary.\n        dataroot: user's local directory containing the datasets.\n        output_path: path to save the analysis result.\n        average: whether to average the statistical value across different image modalities.\n        do_ccp: apply the connected component algorithm to process the labels/images\n        device: a string specifying hardware (CUDA/CPU) utilized for the operations.\n        worker: number of workers to use for loading datasets in each GPU/CPU sub-process.\n        image_key: a string that user specify for the image. The DataAnalyzer will look it up in the\n            datalist to locate the image files of the dataset.\n        label_key: a string that user specify for the label. The DataAnalyzer will look it up in the\n            datalist to locate the label files of the dataset. If label_key is NoneType or \"None\",\n            the DataAnalyzer will skip looking for labels and all label-related operations.\n        hist_bins: bins to compute histogram for each image channel.\n        hist_range: ranges to compute histogram for each image channel.\n        fmt: format used to save the analysis results. Currently support ``\"json\"`` and ``\"yaml\"``, defaults to \"yaml\".\n        histogram_only: whether to only compute histograms. Defaults to False.\n        extra_params: other optional arguments. 
Currently supported arguments are :\n            'allowed_shape_difference' (default 5) can be used to change the default tolerance of\n            the allowed shape differences between the image and label items. In case of shape mismatch below\n            the tolerance, the label image will be resized to match the image using nearest interpolation.\n\n\n    Examples:\n        .. code-block:: python\n\n            from monai.apps.auto3dseg.data_analyzer import DataAnalyzer\n\n            datalist = {\n                \"testing\": [{\"image\": \"image_003.nii.gz\"}],\n                \"training\": [\n                    {\"fold\": 0, \"image\": \"image_001.nii.gz\", \"label\": \"label_001.nii.gz\"},\n                    {\"fold\": 0, \"image\": \"image_002.nii.gz\", \"label\": \"label_002.nii.gz\"},\n                    {\"fold\": 1, \"image\": \"image_001.nii.gz\", \"label\": \"label_001.nii.gz\"},\n                    {\"fold\": 1, \"image\": \"image_004.nii.gz\", \"label\": \"label_004.nii.gz\"},\n                ],\n            }\n\n            dataroot = '/datasets' # the directory where you have the image files (nii.gz)\n            DataAnalyzer(datalist, dataroot)\n\n    Notes:\n        The module can also be called from the command line interface (CLI).\n\n    For example:\n\n    .. 
code-block:: bash\n\n        python -m monai.apps.auto3dseg \\\\\n            DataAnalyzer \\\\\n            get_all_case_stats \\\\\n            --datalist=\"my_datalist.json\" \\\\\n            --dataroot=\"my_dataroot_dir\"\n\n    \"\"\"\n\n    def __init__(\n        self,\n        datalist: str | dict,\n        dataroot: str = \"\",\n        output_path: str = \"./datastats.yaml\",\n        average: bool = True,\n        do_ccp: bool = False,\n        device: str | torch.device = \"cuda\",\n        worker: int = 4,\n        image_key: str = \"image\",\n        label_key: str | None = \"label\",\n        hist_bins: list | int | None = 0,\n        hist_range: list | None = None,\n        fmt: str = \"yaml\",\n        histogram_only: bool = False,\n        **extra_params: Any,\n    ):\n        if path.isfile(output_path):\n            warnings.warn(f\"File {output_path} already exists and will be overwritten.\")\n            logger.debug(f\"{output_path} will be overwritten by a new datastat.\")\n\n        self.datalist = datalist\n        self.dataroot = dataroot\n        self.output_path = output_path\n        self.average = average\n        self.do_ccp = do_ccp\n        self.device = torch.device(device)\n        self.worker = worker\n        self.image_key = image_key\n        self.label_key = None if label_key == \"None\" else label_key\n        self.hist_bins = hist_bins\n        self.hist_range: list = [-500, 500] if hist_range is None else hist_range\n        self.fmt = fmt\n        self.histogram_only = histogram_only\n        self.extra_params = extra_params\n\n    @staticmethod\n    def _check_data_uniformity(keys: list[str], result: dict) -> bool:\n        \"\"\"\n        Check data uniformity since DataAnalyzer provides no support to multi-modal images with different\n        affine matrices/spacings due to monai transforms.\n\n        Args:\n            keys: a list of string-type keys under image_stats dictionary.\n\n        Returns:\n            
False if one of the selected key values is not constant across the dataset images.\n\n        \"\"\"\n\n        if DataStatsKeys.SUMMARY not in result or DataStatsKeys.IMAGE_STATS not in result[DataStatsKeys.SUMMARY]:\n            return True\n        constant_props = [result[DataStatsKeys.SUMMARY][DataStatsKeys.IMAGE_STATS][key] for key in keys]\n        for prop in constant_props:\n            if \"stdev\" in prop and np.any(prop[\"stdev\"]):\n                logger.debug(f\"summary image_stats {prop} has non-zero stdev {prop['stdev']}.\")\n                return False\n\n        return True\n\n    def get_all_case_stats(self, key=\"training\", transform_list=None):\n        \"\"\"\n        Get all case stats. Caller of the DataAnalyzer class. The function initiates multiple GPU or CPU processes of the internal\n        _get_all_case_stats functions, which iterate the datalist and call SegSummarizer to generate stats for each case.\n        After all case stats are generated, SegSummarizer is called to combine results.\n\n        Args:\n            key: dataset key\n            transform_list: optional list of transforms before SegSummarizer\n\n        Returns:\n            A data statistics dictionary containing\n                \"stats_summary\" (summary statistics of the entire datasets). Within stats_summary\n                there are \"image_stats\"  (summarizing info of shape, channel, spacing, and etc\n                using operations_summary), \"image_foreground_stats\" (info of the intensity for the\n                non-zero labeled voxels), and \"label_stats\" (info of the labels, pixel percentage,\n                image_intensity, and each individual label in a list)\n                \"stats_by_cases\" (List type value. Each element of the list is statistics of\n                an image-label info. 
Within each element, there are: \"image\" (value is the\n                path to an image), \"label\" (value is the path to the corresponding label), \"image_stats\"\n                (summarizing info of shape, channel, spacing, and etc using operations),\n                \"image_foreground_stats\" (similar to the previous one but one foreground image), and\n                \"label_stats\" (stats of the individual labels )\n\n        Notes:\n            Since the backend of the statistics computation are torch/numpy, nan/inf value\n            may be generated and carried over in the computation. In such cases, the output\n            dictionary will include .nan/.inf in the statistics.\n\n        \"\"\"\n        result: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []}\n        result_bycase: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []}\n        if self.device.type == \"cpu\":\n            nprocs = 1\n            logger.info(\"Using CPU for data analyzing!\")\n        else:\n            nprocs = torch.cuda.device_count()\n            logger.info(f\"Found {nprocs} GPUs for data analyzing!\")\n        if nprocs > 1:\n            tmp_ctx: Any = get_context(\"forkserver\")\n            with tmp_ctx.Manager() as manager:\n                manager_list = manager.list()\n                processes = []\n                for rank in range(nprocs):\n                    p = tmp_ctx.Process(\n                        target=self._get_all_case_stats, args=(rank, nprocs, manager_list, key, transform_list)\n                    )\n                    processes.append(p)\n                for p in processes:\n                    p.start()\n                for p in processes:\n                    p.join()\n                # merge DataStatsKeys.BY_CASE\n                for _ in manager_list:\n                    result_bycase[DataStatsKeys.BY_CASE].extend(_[DataStatsKeys.BY_CASE])\n        else:\n            
result_bycase = self._get_all_case_stats(0, 1, None, key, transform_list)\n\n        summarizer = SegSummarizer(\n            self.image_key,\n            self.label_key,\n            average=self.average,\n            do_ccp=self.do_ccp,\n            hist_bins=self.hist_bins,\n            hist_range=self.hist_range,\n            histogram_only=self.histogram_only,\n        )\n        n_cases = len(result_bycase[DataStatsKeys.BY_CASE])\n        result[DataStatsKeys.SUMMARY] = summarizer.summarize(cast(list, result_bycase[DataStatsKeys.BY_CASE]))\n        result[DataStatsKeys.SUMMARY][\"n_cases\"] = n_cases\n        result_bycase[DataStatsKeys.SUMMARY] = result[DataStatsKeys.SUMMARY]\n        if not self._check_data_uniformity([ImageStatsKeys.SPACING], result):\n            logger.info(\"Data spacing is not completely uniform. MONAI transforms may provide unexpected result\")\n        if self.output_path:\n            logger.info(f\"Writing data stats to {self.output_path}.\")\n            ConfigParser.export_config_file(\n                result, self.output_path, fmt=self.fmt, default_flow_style=None, sort_keys=False\n            )\n            by_case_path = self.output_path.replace(f\".{self.fmt}\", f\"_by_case.{self.fmt}\")\n            if by_case_path == self.output_path:  # self.output_path not ended with self.fmt?\n                by_case_path += f\".by_case.{self.fmt}\"\n            logger.info(f\"Writing by-case data stats to {by_case_path}, this may take a while.\")\n            ConfigParser.export_config_file(\n                result_bycase, by_case_path, fmt=self.fmt, default_flow_style=None, sort_keys=False\n            )\n        # release memory\n        if self.device.type == \"cuda\":\n            # release unreferenced tensors to mitigate OOM\n            # limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237\n            torch.cuda.empty_cache()\n        result[DataStatsKeys.BY_CASE] = 
result_bycase[DataStatsKeys.BY_CASE]\n        return result\n\n    def _get_all_case_stats(\n        self,\n        rank: int = 0,\n        world_size: int = 1,\n        manager_list: list | None = None,\n        key: str = \"training\",\n        transform_list: list | None = None,\n    ) -> Any:\n        \"\"\"\n        Get all case stats from a partitioned datalist. The function can only be called internally by get_all_case_stats.\n        Args:\n            rank: GPU process rank, 0 for CPU process\n            world_size: total number of GPUs, 1 for CPU process\n            manager_list: multiprocessing manager list object, if using multi-GPU.\n            key: dataset key\n            transform_list: option list of transforms before SegSummarizer\n        \"\"\"\n        summarizer = SegSummarizer(\n            self.image_key,\n            self.label_key,\n            average=self.average,\n            do_ccp=self.do_ccp,\n            hist_bins=self.hist_bins,\n            hist_range=self.hist_range,\n            histogram_only=self.histogram_only,\n        )\n        keys = list(filter(None, [self.image_key, self.label_key]))\n        if transform_list is None:\n            transform_list = [\n                LoadImaged(keys=keys, ensure_channel_first=True, image_only=True),\n                EnsureTyped(keys=keys, data_type=\"tensor\", dtype=torch.float),\n                Orientationd(keys=keys, axcodes=\"RAS\"),\n            ]\n            if self.label_key is not None:\n                allowed_shape_difference = self.extra_params.pop(\"allowed_shape_difference\", 5)\n                transform_list.append(\n                    EnsureSameShaped(\n                        keys=self.label_key,\n                        source_key=self.image_key,\n                        allowed_shape_difference=allowed_shape_difference,\n                    )\n                )\n\n        transform = Compose(transform_list)\n        files, _ = 
datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1, key=key)\n        if world_size <= len(files):\n            files = partition_dataset(data=files, num_partitions=world_size)[rank]\n        else:\n            files = partition_dataset(data=files, num_partitions=len(files))[rank] if rank < len(files) else []\n        dataset = Dataset(data=files, transform=transform)\n        dataloader = DataLoader(\n            dataset,\n            batch_size=1,\n            shuffle=False,\n            num_workers=self.worker,\n            collate_fn=no_collation,\n            pin_memory=self.device.type == \"cuda\",\n        )\n        result_bycase: dict[DataStatsKeys, Any] = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []}\n        device = self.device if self.device.type == \"cpu\" else torch.device(\"cuda\", rank)\n        if device.type == \"cuda\" and not (torch.cuda.is_available() and torch.cuda.device_count() > 0):\n            logger.info(f\"device={device} but CUDA device is not available, using CPU instead.\")\n            device = torch.device(\"cpu\")\n        if not has_tqdm:\n            warnings.warn(\"tqdm is not installed. 
not displaying the caching progress.\")\n\n        for batch_data in tqdm(dataloader) if (has_tqdm and rank == 0) else dataloader:\n            batch_data = batch_data[0]\n            try:\n                batch_data[self.image_key] = batch_data[self.image_key].to(device)\n                _label_argmax = False\n                if self.label_key is not None:\n                    label = batch_data[self.label_key]\n                    label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0]\n                    _label_argmax = True  # track if label is argmaxed\n                    batch_data[self.label_key] = label.to(device)\n                d = summarizer(batch_data)\n            except BaseException as err:\n                if \"image_meta_dict\" in batch_data.keys():\n                    filename = batch_data[\"image_meta_dict\"][ImageMetaKey.FILENAME_OR_OBJ]\n                else:\n                    filename = batch_data[self.image_key].meta[ImageMetaKey.FILENAME_OR_OBJ]\n                logger.info(f\"Unable to process data {filename} on {device}. {err}\")\n                if self.device.type == \"cuda\":\n                    logger.info(\"DataAnalyzer `device` set to GPU execution hit an exception. Falling back to `cpu`.\")\n                    try:\n                        batch_data[self.image_key] = batch_data[self.image_key].to(\"cpu\")\n                        if self.label_key is not None:\n                            label = batch_data[self.label_key]\n                            if not _label_argmax:\n                                label = torch.argmax(label, dim=0) if label.shape[0] > 1 else label[0]\n                            batch_data[self.label_key] = label.to(\"cpu\")\n                        d = summarizer(batch_data)\n                    except BaseException as err:\n                        logger.info(f\"Unable to process data {filename} on {device}. 
{err}\")\n                        continue\n                else:\n                    continue\n\n            stats_by_cases = {\n                DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH],\n                DataStatsKeys.BY_CASE_LABEL_PATH: d[DataStatsKeys.BY_CASE_LABEL_PATH],\n            }\n            if not self.histogram_only:\n                stats_by_cases[DataStatsKeys.IMAGE_STATS] = d[DataStatsKeys.IMAGE_STATS]\n            if self.hist_bins != 0:\n                stats_by_cases[DataStatsKeys.IMAGE_HISTOGRAM] = d[DataStatsKeys.IMAGE_HISTOGRAM]\n\n            if self.label_key is not None:\n                stats_by_cases.update(\n                    {\n                        DataStatsKeys.FG_IMAGE_STATS: d[DataStatsKeys.FG_IMAGE_STATS],\n                        DataStatsKeys.LABEL_STATS: d[DataStatsKeys.LABEL_STATS],\n                    }\n                )\n            result_bycase[DataStatsKeys.BY_CASE].append(stats_by_cases)\n        if manager_list is None:\n            return result_bycase\n        else:\n            manager_list.append(result_bycase)\n"
  },
  {
    "path": "monai/apps/auto3dseg/ensemble_builder.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Mapping, Sequence\nfrom copy import deepcopy\nfrom typing import Any, cast\nfrom warnings import warn\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom monai.apps.auto3dseg.bundle_gen import BundleAlgo\nfrom monai.apps.auto3dseg.utils import get_name_from_algo_id, import_bundle_algo_history\nfrom monai.apps.utils import get_logger\nfrom monai.auto3dseg import concat_val_to_np\nfrom monai.auto3dseg.utils import (\n    _prepare_cmd_bcprun,\n    _prepare_cmd_torchrun,\n    _run_cmd_bcprun,\n    _run_cmd_torchrun,\n    datafold_read,\n)\nfrom monai.bundle import ConfigParser\nfrom monai.data import partition_dataset\nfrom monai.transforms import MeanEnsemble, SaveImage, VoteEnsemble\nfrom monai.utils import RankFilter\nfrom monai.utils.enums import AlgoKeys\nfrom monai.utils.misc import check_kwargs_exist_in_class_init, prob2class\nfrom monai.utils.module import look_up_option, optional_import\n\ntqdm, has_tqdm = optional_import(\"tqdm\", name=\"tqdm\")\n\nlogger = get_logger(module_name=__name__)\n\n\nclass AlgoEnsemble(ABC):\n    \"\"\"\n    The base class of Ensemble methods\n    \"\"\"\n\n    def __init__(self):\n        self.algos = []\n        self.mode = \"mean\"\n        self.infer_files = []\n        self.algo_ensemble = []\n\n 
   def set_algos(self, infer_algos):\n        \"\"\"\n        Register model in the ensemble\n        \"\"\"\n        self.algos = deepcopy(infer_algos)\n\n    def get_algo(self, identifier):\n        \"\"\"\n        Get a model by identifier.\n\n        Args:\n            identifier: the name of the bundleAlgo\n        \"\"\"\n        for algo in self.algos:\n            if identifier == algo[AlgoKeys.ID]:\n                return algo\n\n    def get_algo_ensemble(self):\n        \"\"\"\n        Get the algo ensemble after ranking or an empty list if ranking was not started.\n\n        Returns:\n            A list of Algo\n        \"\"\"\n        return self.algo_ensemble\n\n    def set_infer_files(self, dataroot: str, data_list_or_path: str | list, data_key: str = \"testing\") -> None:\n        \"\"\"\n        Set the files to perform model inference.\n\n        Args:\n            dataroot: the path of the files\n            data_list_or_path: the data source file path\n        \"\"\"\n\n        self.infer_files = []\n\n        if isinstance(data_list_or_path, list):\n            self.infer_files = data_list_or_path\n        elif isinstance(data_list_or_path, str):\n            datalist = ConfigParser.load_config_file(data_list_or_path)\n            if data_key in datalist:\n                self.infer_files, _ = datafold_read(datalist=datalist, basedir=dataroot, fold=-1, key=data_key)\n            elif not hasattr(self, \"rank\") or self.rank == 0:\n                logger.info(f\"Datalist file has no testing key - {data_key}. 
No data for inference is specified\")\n\n        else:\n            raise ValueError(\"Unsupported parameter type\")\n\n    def ensemble_pred(self, preds, sigmoid=False):\n        \"\"\"\n        ensemble the results using either \"mean\" or \"vote\" method\n\n        Args:\n            preds: a list of probability prediction in Tensor-Like format.\n            sigmoid: use the sigmoid function to threshold probability one-hot map,\n                otherwise argmax is used. Defaults to False\n\n        Returns:\n            a tensor which is the ensembled prediction.\n        \"\"\"\n\n        if any(not p.is_cuda for p in preds):\n            preds = [p.cpu() for p in preds]  # ensure CPU if at least one is on CPU\n\n        if self.mode == \"mean\":\n            prob = MeanEnsemble()(preds)\n            return prob2class(cast(torch.Tensor, prob), dim=0, keepdim=True, sigmoid=sigmoid)\n        elif self.mode == \"vote\":\n            classes = [prob2class(p, dim=0, keepdim=True, sigmoid=sigmoid) for p in preds]\n            if sigmoid:\n                return VoteEnsemble()(classes)  # do not specify num_classes for one-hot encoding\n            else:\n                return VoteEnsemble(num_classes=preds[0].shape[0])(classes)\n\n    def _apply_algo_specific_param(self, algo_spec_param: dict, param: dict, algo_name: str) -> dict:\n        \"\"\"\n        Apply the model-specific params to the prediction params based on the name of the Algo.\n\n        Args:\n            algo_spec_param: a dict that has structure of {\"<name of algo>\": \"<pred_params for that algo>\"}.\n            param: the prediction params to override.\n            algo_name: name of the Algo\n\n        Returns:\n            param after being updated with the model-specific param\n        \"\"\"\n        _param_to_override = deepcopy(algo_spec_param)\n        _param = deepcopy(param)\n        for k, v in _param_to_override.items():\n            if k.lower() == algo_name.lower():\n              
  _param.update(v)\n        return _param\n\n    def __call__(self, pred_param: dict | None = None) -> list:\n        \"\"\"\n        Use the ensembled model to predict result.\n\n        Args:\n            pred_param: prediction parameter dictionary. The key has two groups: the first one will be consumed\n                in this function, and the second group will be passed to the `InferClass` to override the\n                parameters of the class functions.\n                The first group contains:\n\n                    - ``\"infer_files\"``: file paths to the images to read in a list.\n                    - ``\"files_slices\"``: a value type of `slice`. The files_slices will slice the ``\"infer_files\"`` and\n                      only make prediction on the infer_files[file_slices].\n                    - ``\"mode\"``: ensemble mode. Currently \"mean\" and \"vote\" (majority voting) schemes are supported.\n                    - ``\"image_save_func\"``: a dictionary used to instantiate the ``SaveImage`` transform. When specified,\n                      the ensemble prediction will save the prediction files, instead of keeping the files in the memory.\n                      Example: `{\"_target_\": \"SaveImage\", \"output_dir\": \"./\"}`\n                    - ``\"sigmoid\"``: use the sigmoid function (e.g. x > 0.5) to convert the prediction probability map\n                      to the label class prediction, otherwise argmax(x) is used.\n                    - ``\"algo_spec_params\"``: a dictionary to add pred_params that are specific to a model.\n                      The dict has a format of {\"<name of algo>\": \"<pred_params for that algo>\"}.\n\n                The parameters in the second group is defined in the ``config`` of each Algo templates. 
Please check:\n                https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg/algorithm_templates\n\n        Returns:\n            A list of tensors or file paths, depending on whether ``\"image_save_func\"`` is set.\n        \"\"\"\n        param = {} if pred_param is None else deepcopy(pred_param)\n        files = self.infer_files\n\n        if \"infer_files\" in param:\n            files = param.pop(\"infer_files\")\n\n        if \"files_slices\" in param:\n            slices = param.pop(\"files_slices\")\n            files = files[slices]\n\n        if \"mode\" in param:\n            mode = param.pop(\"mode\")\n            self.mode = look_up_option(mode, supported=[\"mean\", \"vote\"])\n\n        sigmoid = param.pop(\"sigmoid\", False)\n\n        if \"image_save_func\" in param:\n            img_saver = ConfigParser(param[\"image_save_func\"]).get_parsed_content()\n\n        algo_spec_params = param.pop(\"algo_spec_params\", {})\n\n        outputs = []\n        for _, file in (\n            enumerate(tqdm(files, desc=\"Ensembling (rank 0)...\"))\n            if has_tqdm and pred_param and pred_param.get(\"rank\", 0) == 0\n            else enumerate(files)\n        ):\n            preds = []\n            for algo in self.algo_ensemble:\n                infer_algo_name = get_name_from_algo_id(algo[AlgoKeys.ID])\n                infer_instance = algo[AlgoKeys.ALGO]\n                _param = self._apply_algo_specific_param(algo_spec_params, param, infer_algo_name)\n                pred = infer_instance.predict(predict_files=[file], predict_params=_param)\n                preds.append(pred[0])\n            if \"image_save_func\" in param:\n                try:\n                    ensemble_preds = self.ensemble_pred(preds, sigmoid=sigmoid)\n                except BaseException:\n                    ensemble_preds = self.ensemble_pred([_.to(\"cpu\") for _ in preds], sigmoid=sigmoid)\n                res = img_saver(ensemble_preds)\n      
          # res is the path to the saved results\n                if hasattr(res, \"meta\") and \"saved_to\" in res.meta.keys():\n                    res = res.meta[\"saved_to\"]\n                else:\n                    warn(\"Image save path not returned.\")\n                    res = None\n            else:\n                warn(\"Prediction returned in list instead of disk, provide image_save_func to avoid out of memory.\")\n                res = self.ensemble_pred(preds, sigmoid=sigmoid)\n            outputs.append(res)\n        return outputs\n\n    @abstractmethod\n    def collect_algos(self, *args, **kwargs):\n        raise NotImplementedError\n\n\nclass AlgoEnsembleBestN(AlgoEnsemble):\n    \"\"\"\n    Ensemble method that select N model out of all using the models' best_metric scores\n\n    Args:\n        n_best: number of models to pick for ensemble (N).\n    \"\"\"\n\n    def __init__(self, n_best: int = 5):\n        super().__init__()\n        self.n_best = n_best\n\n    def sort_score(self):\n        \"\"\"\n        Sort the best_metrics\n        \"\"\"\n        scores = concat_val_to_np(self.algos, [AlgoKeys.SCORE])\n        return np.argsort(scores).tolist()\n\n    def collect_algos(self, n_best: int = -1) -> None:\n        \"\"\"\n        Rank the algos by finding the top N (n_best) validation scores.\n        \"\"\"\n\n        if n_best <= 0:\n            n_best = self.n_best\n\n        ranks = self.sort_score()\n        if len(ranks) < n_best:\n            warn(f\"Found {len(ranks)} available algos (pre-defined n_best={n_best}). 
All {len(ranks)} will be used.\")\n            n_best = len(ranks)\n\n        # get the ranks for which the indices are lower than N-n_best\n        indices = [r for (i, r) in enumerate(ranks) if i < (len(ranks) - n_best)]\n\n        # remove the found indices\n        indices = sorted(indices, reverse=True)\n\n        self.algo_ensemble = deepcopy(self.algos)\n        for idx in indices:\n            if idx < len(self.algo_ensemble):\n                self.algo_ensemble.pop(idx)\n\n\nclass AlgoEnsembleBestByFold(AlgoEnsemble):\n    \"\"\"\n    Ensemble method that select the best models that are the tops in each fold.\n\n    Args:\n        n_fold: number of cross-validation folds used in training\n    \"\"\"\n\n    def __init__(self, n_fold: int = 5):\n        super().__init__()\n        self.n_fold = n_fold\n\n    def collect_algos(self) -> None:\n        \"\"\"\n        Rank the algos by finding the best model in each cross-validation fold\n        \"\"\"\n\n        self.algo_ensemble = []\n        for f_idx in range(self.n_fold):\n            best_score = -1.0\n            best_model: BundleAlgo | None = None\n            for algo in self.algos:\n                # algorithm folder: {net}_{fold_index}_{other}\n                identifier = algo[AlgoKeys.ID].split(\"_\")[1]\n                try:\n                    algo_id = int(identifier)\n                except ValueError as err:\n                    raise ValueError(f\"model identifier {identifier} is not number.\") from err\n                if algo_id == f_idx and algo[AlgoKeys.SCORE] > best_score:\n                    best_model = algo\n                    best_score = algo[AlgoKeys.SCORE]\n            self.algo_ensemble.append(best_model)\n\n\nclass AlgoEnsembleBuilder:\n    \"\"\"\n    Build ensemble workflow from configs and arguments.\n\n    Args:\n        history: a collection of trained bundleAlgo algorithms.\n        data_src_cfg_name: filename of the data source.\n\n    Examples:\n\n        .. 
code-block:: python\n\n            builder = AlgoEnsembleBuilder(history, data_src_cfg)\n            builder.set_ensemble_method(BundleAlgoEnsembleBestN(3))\n            ensemble = builder.get_ensemble()\n\n    \"\"\"\n\n    def __init__(self, history: Sequence[dict[str, Any]], data_src_cfg_name: str | None = None):\n        self.infer_algos: list[dict[AlgoKeys, Any]] = []\n        self.ensemble: AlgoEnsemble\n        self.data_src_cfg = ConfigParser(globals=False)\n\n        if data_src_cfg_name is not None and os.path.exists(str(data_src_cfg_name)):\n            self.data_src_cfg.read_config(data_src_cfg_name)\n\n        for algo_dict in history:\n            # load inference_config_paths\n\n            name = algo_dict[AlgoKeys.ID]\n            gen_algo = algo_dict[AlgoKeys.ALGO]\n\n            best_metric = gen_algo.get_score()\n            algo_path = gen_algo.output_path\n            infer_path = os.path.join(algo_path, \"scripts\", \"infer.py\")\n\n            if not os.path.isdir(algo_path):\n                warn(f\"{gen_algo.output_path} is not a directory. Please check the path.\")\n\n            if not os.path.isfile(infer_path):\n                warn(f\"{infer_path} is not found. 
Please check the path.\")\n\n            self.add_inferer(name, gen_algo, best_metric)\n\n    def add_inferer(self, identifier: str, gen_algo: BundleAlgo, best_metric: float | None = None) -> None:\n        \"\"\"\n        Add model inferer to the builder.\n\n        Args:\n            identifier: name of the bundleAlgo.\n            gen_algo: a trained BundleAlgo model object.\n            best_metric: the best metric in validation of the trained model.\n        \"\"\"\n\n        if best_metric is None:\n            raise ValueError(\"Feature to re-validate is to be implemented\")\n\n        algo = {AlgoKeys.ID: identifier, AlgoKeys.ALGO: gen_algo, AlgoKeys.SCORE: best_metric}\n        self.infer_algos.append(algo)\n\n    def set_ensemble_method(self, ensemble: AlgoEnsemble, *args: Any, **kwargs: Any) -> None:\n        \"\"\"\n        Set the ensemble method.\n\n        Args:\n            ensemble: the AlgoEnsemble to build.\n        \"\"\"\n\n        ensemble.set_algos(self.infer_algos)\n        ensemble.collect_algos(*args, **kwargs)\n        ensemble.set_infer_files(self.data_src_cfg[\"dataroot\"], self.data_src_cfg[\"datalist\"])\n\n        self.ensemble = ensemble\n\n    def get_ensemble(self):\n        \"\"\"Get the ensemble\"\"\"\n\n        return self.ensemble\n\n\nclass EnsembleRunner:\n    \"\"\"\n    The Runner for ensembler. It ensembles predictions and saves them to the disk with a support of using multi-GPU.\n\n    Args:\n        data_src_cfg_name: filename of the data source.\n        work_dir: working directory to save the intermediate and final results. Default is `./work_dir`.\n        num_fold: number of fold. Default is 5.\n        ensemble_method_name: method to ensemble predictions from different model. Default is AlgoEnsembleBestByFold.\n                              Supported methods: [\"AlgoEnsembleBestN\", \"AlgoEnsembleBestByFold\"].\n        mgpu: if using multi-gpu. 
Default is True.\n        kwargs: additional image writing, ensembling parameters and prediction parameters for the ensemble inference.\n              - for image saving, please check the supported parameters in SaveImage transform.\n              - for prediction parameters, please check the supported parameters in the ``AlgoEnsemble`` callables.\n              - for ensemble parameters, please check the documentation of the selected AlgoEnsemble callable.\n\n    Example:\n\n        .. code-block:: python\n\n            ensemble_runner = EnsembleRunner(data_src_cfg_name,\n                                             work_dir,\n                                             ensemble_method_name,\n                                             mgpu=device_setting['n_devices']>1,\n                                             **kwargs,\n                                             **pred_params)\n            ensemble_runner.run(device_setting)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        data_src_cfg_name: str,\n        work_dir: str = \"./work_dir\",\n        num_fold: int = 5,\n        ensemble_method_name: str = \"AlgoEnsembleBestByFold\",\n        mgpu: bool = True,\n        **kwargs: Any,\n    ) -> None:\n        self.data_src_cfg_name = data_src_cfg_name\n        self.work_dir = work_dir\n        self.num_fold = num_fold\n        self.ensemble_method_name = ensemble_method_name\n        self.mgpu = mgpu\n        self.kwargs = deepcopy(kwargs)\n        self.rank = 0\n        self.world_size = 1\n        self.device_setting: dict[str, int | str] = {\n            \"CUDA_VISIBLE_DEVICES\": \",\".join([str(x) for x in range(torch.cuda.device_count())]),\n            \"n_devices\": torch.cuda.device_count(),\n            \"NUM_NODES\": int(os.environ.get(\"NUM_NODES\", 1)),\n            \"MN_START_METHOD\": os.environ.get(\"MN_START_METHOD\", \"bcprun\"),\n            \"CMD_PREFIX\": os.environ.get(\"CMD_PREFIX\", \"\"),\n        }\n\n    def 
set_ensemble_method(self, ensemble_method_name: str = \"AlgoEnsembleBestByFold\", **kwargs: Any) -> None:\n        \"\"\"\n        Set the bundle ensemble method\n\n        Args:\n            ensemble_method_name: the name of the ensemble method. Only two methods are supported \"AlgoEnsembleBestN\"\n                and \"AlgoEnsembleBestByFold\".\n            kwargs: the keyword arguments used to define the ensemble method. Currently only ``n_best`` for\n                ``AlgoEnsembleBestN`` is supported.\n\n        \"\"\"\n        self.ensemble_method_name = look_up_option(\n            ensemble_method_name, supported=[\"AlgoEnsembleBestN\", \"AlgoEnsembleBestByFold\"]\n        )\n        if self.ensemble_method_name == \"AlgoEnsembleBestN\":\n            n_best = kwargs.pop(\"n_best\", 2)\n            self.ensemble_method = AlgoEnsembleBestN(n_best=n_best)\n        elif self.ensemble_method_name == \"AlgoEnsembleBestByFold\":\n            self.ensemble_method = AlgoEnsembleBestByFold(n_fold=self.num_fold)  # type: ignore\n        else:\n            raise NotImplementedError(f\"Ensemble method {self.ensemble_method_name} is not implemented.\")\n\n    def _pop_kwargs_to_get_image_save_transform(self, **kwargs):\n        \"\"\"\n        Pop the kwargs used to define ImageSave class for the ensemble output.\n\n        Args:\n            kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage\n                transform. For more information, check https://monai.readthedocs.io/en/stable/transforms.html#saveimage .\n\n        Returns:\n            save_image: a dictionary that can be used to instantiate a SaveImage class in ConfigParser.\n        \"\"\"\n\n        output_dir = kwargs.pop(\"output_dir\", None)\n\n        if output_dir is None:\n            output_dir = os.path.join(self.work_dir, \"ensemble_output\")\n            logger.info(f\"The output_dir is not specified. 
{output_dir} will be used to save ensemble predictions.\")\n\n        if not os.path.isdir(output_dir):\n            os.makedirs(output_dir, exist_ok=True)\n            logger.info(f\"Directory {output_dir} is created to save ensemble predictions\")\n\n        input_yaml = ConfigParser.load_config_file(self.data_src_cfg_name)\n        data_root_dir = input_yaml.get(\"dataroot\", \"\")\n\n        save_image = {\n            \"_target_\": \"SaveImage\",\n            \"output_dir\": output_dir,\n            \"output_postfix\": kwargs.pop(\"output_postfix\", \"ensemble\"),\n            \"output_dtype\": kwargs.pop(\"output_dtype\", \"$np.uint8\"),\n            \"resample\": kwargs.pop(\"resample\", False),\n            \"print_log\": False,\n            \"savepath_in_metadict\": True,\n            \"data_root_dir\": kwargs.pop(\"data_root_dir\", data_root_dir),\n            \"separate_folder\": kwargs.pop(\"separate_folder\", False),\n        }\n\n        are_all_args_save_image, extra_args = check_kwargs_exist_in_class_init(SaveImage, kwargs)\n        if are_all_args_save_image:\n            save_image.update(kwargs)\n        else:\n            # kwargs has extra values for other purposes, for example, pred_params\n            for args in list(kwargs):\n                if args not in extra_args:\n                    save_image.update({args: kwargs.pop(args)})\n\n        return save_image\n\n    def set_image_save_transform(self, **kwargs: Any) -> None:\n        \"\"\"\n        Set the ensemble output transform.\n\n        Args:\n            kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage\n                transform. 
For more information, check https://monai.readthedocs.io/en/stable/transforms.html#saveimage .\n\n        \"\"\"\n        are_all_args_present, extra_args = check_kwargs_exist_in_class_init(SaveImage, kwargs)\n        if are_all_args_present:\n            self.kwargs.update(kwargs)\n        else:\n            raise ValueError(\n                f\"{extra_args} are not supported in monai.transforms.SaveImage,\"\n                \"Check https://monai.readthedocs.io/en/stable/transforms.html#saveimage for more information.\"\n            )\n\n    def set_num_fold(self, num_fold: int = 5) -> None:\n        \"\"\"\n        Set the number of cross validation folds for all algos.\n\n        Args:\n            num_fold: a positive integer to define the number of folds.\n        \"\"\"\n\n        if num_fold <= 0:\n            raise ValueError(f\"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}\")\n        self.num_fold = num_fold\n\n    def ensemble(self):\n        if self.mgpu:  # torch.cuda.device_count() is not used because env is not set by autorunner\n            # init multiprocessing and update infer_files\n            dist.init_process_group(backend=\"nccl\", init_method=\"env://\")\n            self.world_size = dist.get_world_size()\n            self.rank = dist.get_rank()\n            logger.addFilter(RankFilter())\n        # set params after init_process_group to know the rank\n        self.set_num_fold(num_fold=self.num_fold)\n        self.set_ensemble_method(self.ensemble_method_name, **self.kwargs)\n        # self.kwargs needs to pop out args for set_image_save_transform\n        save_image = self._pop_kwargs_to_get_image_save_transform(**self.kwargs)\n\n        history = import_bundle_algo_history(self.work_dir, only_trained=False)\n        history_untrained = [h for h in history if not h[AlgoKeys.IS_TRAINED]]\n        if history_untrained:\n            logger.warning(\n                f\"Ensembling step will skip 
{[h[AlgoKeys.ID] for h in history_untrained]} untrained algos.\"\n                \"Generally it means these algos did not complete training.\"\n            )\n            history = [h for h in history if h[AlgoKeys.IS_TRAINED]]\n        if len(history) == 0:\n            raise ValueError(\n                f\"Could not find the trained results in {self.work_dir}. \"\n                \"Possibly the required training step was not completed.\"\n            )\n\n        builder = AlgoEnsembleBuilder(history, self.data_src_cfg_name)\n        builder.set_ensemble_method(self.ensemble_method)\n        self.ensembler = builder.get_ensemble()\n        infer_files = self.ensembler.infer_files\n        if len(infer_files) < self.world_size:\n            if len(infer_files) == 0:\n                logger.info(\"No testing files for inference is provided. Ensembler ending.\")\n                return\n            infer_files = [infer_files[self.rank]] if self.rank < len(infer_files) else []\n        else:\n            infer_files = partition_dataset(\n                data=infer_files, shuffle=False, num_partitions=self.world_size, even_divisible=False\n            )[self.rank]\n\n        # TO DO: Add some function in ensembler for infer_files update?\n        self.ensembler.infer_files = infer_files\n        # add rank to pred_params\n        self.kwargs[\"rank\"] = self.rank\n        self.kwargs[\"image_save_func\"] = save_image\n        logger.info(\"Auto3Dseg picked the following networks to ensemble:\")\n        for algo in self.ensembler.get_algo_ensemble():\n            logger.info(algo[AlgoKeys.ID])\n        output_dir = save_image[\"output_dir\"]\n        logger.info(f\"Auto3Dseg ensemble prediction outputs will be saved in {output_dir}.\")\n        self.ensembler(pred_param=self.kwargs)\n\n        if self.mgpu:\n            dist.destroy_process_group()\n\n    def run(self, device_setting: dict | None = None) -> None:\n        \"\"\"\n        Load the run function in the 
training script of each model. Training parameter is predefined by the\n        algo_config.yaml file, which is pre-filled by the fill_template_config function in the same instance.\n\n        Args:\n            device_setting: device related settings, should follow the device_setting in auto_runner.set_device_info.\n                'CUDA_VISIBLE_DEVICES' should be a string e.g. '0,1,2,3'\n        \"\"\"\n        # device_setting set default value and sanity check, in case device_setting not from autorunner\n        if device_setting is not None:\n            self.device_setting.update(device_setting)\n            self.device_setting[\"n_devices\"] = len(str(self.device_setting[\"CUDA_VISIBLE_DEVICES\"]).split(\",\"))\n        self._create_cmd()\n\n    def _create_cmd(self) -> None:\n        if int(self.device_setting[\"NUM_NODES\"]) <= 1 and int(self.device_setting[\"n_devices\"]) <= 1:\n            # if single GPU\n            logger.info(\"Ensembling using single GPU!\")\n            self.ensemble()\n            return\n\n        # define base cmd for subprocess\n        base_cmd = f\"monai.apps.auto3dseg EnsembleRunner ensemble \\\n                --data_src_cfg_name {self.data_src_cfg_name} \\\n                --work_dir {self.work_dir} \\\n                --num_fold {self.num_fold} \\\n                --ensemble_method_name {self.ensemble_method_name} \\\n                --mgpu True\"\n\n        if self.kwargs and isinstance(self.kwargs, Mapping):\n            for k, v in self.kwargs.items():\n                base_cmd += f\" --{k}={v}\"\n        # define env for subprocess\n        ps_environ = os.environ.copy()\n        ps_environ[\"CUDA_VISIBLE_DEVICES\"] = str(self.device_setting[\"CUDA_VISIBLE_DEVICES\"])\n        if int(self.device_setting[\"NUM_NODES\"]) > 1:\n            if self.device_setting[\"MN_START_METHOD\"] != \"bcprun\":\n                raise NotImplementedError(\n                    f\"{self.device_setting['MN_START_METHOD']} is not supported 
yet. \"\n                    \"Try modify EnsembleRunner._create_cmd for your cluster.\"\n                )\n            logger.info(f\"Ensembling on {self.device_setting['NUM_NODES']} nodes!\")\n            cmd = _prepare_cmd_bcprun(\"-m \" + base_cmd, cmd_prefix=f\"{self.device_setting['CMD_PREFIX']}\")\n            _run_cmd_bcprun(cmd, n=self.device_setting[\"NUM_NODES\"], p=self.device_setting[\"n_devices\"])\n\n        else:\n            logger.info(f\"Ensembling using {self.device_setting['n_devices']} GPU!\")\n            cmd = _prepare_cmd_torchrun(\"-m \" + base_cmd)\n            _run_cmd_torchrun(\n                cmd, nnodes=1, nproc_per_node=self.device_setting[\"n_devices\"], env=ps_environ, check=True\n            )\n        return\n"
  },
  {
    "path": "monai/apps/auto3dseg/hpo_gen.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nfrom abc import abstractmethod\nfrom copy import deepcopy\nfrom typing import Any, cast\nfrom warnings import warn\n\nfrom monai.apps.auto3dseg.bundle_gen import BundleAlgo\nfrom monai.apps.utils import get_logger\nfrom monai.auto3dseg import Algo, AlgoGen, algo_from_pickle, algo_to_pickle\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.config import PathLike\nfrom monai.utils import optional_import\nfrom monai.utils.enums import AlgoKeys\n\nnni, has_nni = optional_import(\"nni\")\noptuna, has_optuna = optional_import(\"optuna\")\nlogger = get_logger(module_name=__name__)\n\n__all__ = [\"HPOGen\", \"NNIGen\", \"OptunaGen\"]\n\n\nclass HPOGen(AlgoGen):\n    \"\"\"\n    The base class for hyperparameter optimization (HPO) interfaces to generate algos in the Auto3Dseg pipeline.\n    The auto-generated algos are saved at their ``output_path`` on the disk. The files in the ``output_path``\n    may contain scripts that define the algo, configuration files, and pickle files that save the internal states\n    of the algo before/after the training. 
Compared to the BundleGen class, HPOGen generates Algo on-the-fly, so\n    training and algo generation may be executed alternatively and take a long time to finish the generation process.\n\n    \"\"\"\n\n    @abstractmethod\n    def get_hyperparameters(self):\n        \"\"\"Get the hyperparameter from HPO.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def update_params(self, *args, **kwargs):\n        \"\"\"Update Algo parameters according to the hyperparameters to be evaluated.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def set_score(self, *args, **kwargs):\n        \"\"\"Report the evaluated results to HPO.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def run_algo(self, *args, **kwargs):\n        \"\"\"Interface for launch the training given the fetched hyperparameters.\"\"\"\n        raise NotImplementedError\n\n\nclass NNIGen(HPOGen):\n    \"\"\"\n    Generate algorithms for the NNI to automate hyperparameter tuning. The module has two major interfaces:\n    ``__init__`` which prints out how to set up the NNI, and a trialCommand function ``run_algo`` for the NNI library to\n    start the trial of the algo. More about trialCommand function can be found in ``trail code`` section in NNI webpage\n    https://nni.readthedocs.io/en/latest/tutorials/hpo_quickstart_pytorch/main.html .\n\n    Args:\n        algo: an Algo object (e.g. 
BundleAlgo) with defined methods: ``get_output_path`` and train\n            and supports saving to and loading from pickle files via ``algo_from_pickle`` and ``algo_to_pickle``.\n        params: a set of parameter to override the algo if override is supported by Algo subclass.\n\n    Examples::\n\n        The experiment will keep generating new folders to save the model checkpoints, scripts, and configs if available.\n        ├── algorithm_templates\n        │   └── unet\n        ├── unet_0\n        │   ├── algo_object.pkl\n        │   ├── configs\n        │   └── scripts\n        ├── unet_0_learning_rate_0.01\n        │   ├── algo_object.pkl\n        │   ├── configs\n        │   ├── model_fold0\n        │   └── scripts\n        └── unet_0_learning_rate_0.1\n            ├── algo_object.pkl\n            ├── configs\n            ├── model_fold0\n            └── scripts\n\n        .. code-block:: python\n            # Bundle Algorithms are already generated by BundleGen in work_dir\n            import_bundle_algo_history(work_dir, only_trained=False)\n            algo_dict = self.history[0]  # pick the first algorithm\n            algo_name = algo_dict[AlgoKeys.ID]\n            onealgo = algo_dict[AlgoKeys.ALGO]\n            nni_gen = NNIGen(algo=onealgo)\n            nni_gen.print_bundle_algo_instruction()\n\n    Notes:\n        The NNIGen will prepare the algorithms in a folder and suggest a command to replace trialCommand in the experiment\n        config. However, NNIGen will not trigger NNI. 
User needs to write their NNI experiment configs, and then run the\n        NNI command manually.\n    \"\"\"\n\n    def __init__(self, algo: Algo | None = None, params: dict | None = None):\n        self.algo: Algo\n        self.hint = \"\"\n        self.obj_filename = \"\"\n\n        if algo is not None:\n            if isinstance(algo, BundleAlgo):\n                if params is None:\n                    self.algo = algo\n                else:\n                    self.algo = deepcopy(algo)\n                    name = os.path.basename(algo.get_output_path()) + \"_override\"\n                    output_folder = os.path.dirname(algo.get_output_path())\n\n                    params.update({\"fill_with_datastats\": False})  # just copy, not using datastats to fill\n                    self.algo.export_to_disk(output_folder, name, **params)\n            else:\n                self.algo = algo\n\n            self.obj_filename = algo_to_pickle(self.algo, template_path=self.algo.template_path)\n\n    def get_obj_filename(self):\n        \"\"\"Return the filename of the dumped pickle algo object.\"\"\"\n        return self.obj_filename\n\n    def print_bundle_algo_instruction(self):\n        \"\"\"\n        Print how to write the trial commands for Bundle Algo.\n        \"\"\"\n        hint = \"python -m monai.apps.auto3dseg NNIGen run_algo \"\n        logger.info(\"=\" * 140)\n        logger.info(\"If NNI will run in your local env: \")\n        logger.info(\"1. Add the following line to the trialCommand in your NNI config: \")\n        logger.info(f\"{hint} {self.obj_filename} {{result_dir}}\")\n        logger.info(\"-\" * 140)\n        logger.info(\"If NNI will run in a remote env: \")\n        logger.info(\n            f\"1. Copy the algorithm_templates folder {cast(BundleAlgo, self.algo).template_path} \"\n            f\"to remote {{remote_algorithm_templates_dir}}\"\n        )\n        logger.info(f\"2. 
Copy the older {self.algo.get_output_path()} to the remote machine {{remote_algo_dir}}\")\n        logger.info(\"Then add the following line to the trialCommand in your NNI config: \")\n        logger.info(f\"{hint} {{remote_algo_dir}} {{result_dir}} {{remote_algorithm_templates_dir}}\")\n        logger.info(\"=\" * 140)\n\n    def get_hyperparameters(self):\n        \"\"\"\n        Get parameter for next round of training from NNI server.\n        \"\"\"\n        if has_nni:\n            return nni.get_next_parameter()\n        warn(\"NNI is not detected. The code will continue to run without NNI.\")\n        return {}\n\n    def update_params(self, params: dict) -> None:\n        \"\"\"\n        Translate the parameter from monai bundle to meet NNI requirements.\n\n        Args:\n            params: a dict of parameters.\n        \"\"\"\n        self.params = params\n\n    def get_task_id(self):\n        \"\"\"\n        Get the identifier of the current experiment. In the format of listing the searching parameter name and values\n        connected by underscore in the file name.\n        \"\"\"\n        return \"\".join(f\"_{k}_{v}\" for k, v in self.params.items()) or \"_None\"\n\n    def generate(self, output_folder: str = \".\") -> None:\n        \"\"\"\n        Generate the record for each Algo. 
If it is a BundleAlgo, it will generate the config files.\n\n        Args:\n            output_folder: the directory nni will save the results to.\n        \"\"\"\n        task_id = self.get_task_id()\n        task_prefix = os.path.basename(self.algo.get_output_path())\n        write_path = os.path.join(output_folder, task_prefix + task_id)\n        self.obj_filename = os.path.join(write_path, \"algo_object.pkl\")\n\n        if isinstance(self.algo, BundleAlgo):\n            self.algo.export_to_disk(\n                output_folder, task_prefix + task_id, bundle_root=write_path, fill_with_datastats=False\n            )\n        else:\n            ConfigParser.export_config_file(self.params, write_path)\n            logger.info(write_path)\n\n    def set_score(self, acc):\n        \"\"\"\n        Report the acc to NNI server.\n        \"\"\"\n        if has_nni:\n            nni.report_final_result(acc)\n        else:\n            warn(\"NNI is not detected. The code will continue to run without NNI.\")\n\n    def run_algo(self, obj_filename: str, output_folder: str = \".\", template_path: PathLike | None = None) -> None:\n        \"\"\"\n        The python interface for NNI to run.\n\n        Args:\n            obj_filename: the pickle-exported Algo object.\n            output_folder: the root path of the algorithms templates.\n            template_path: the algorithm_template. 
It must contain algo.py in the follow path:\n                ``{algorithm_templates_dir}/{network}/scripts/algo.py``\n        \"\"\"\n        if not os.path.isfile(obj_filename):\n            raise ValueError(f\"{obj_filename} is not found\")\n\n        self.algo, algo_meta_data = algo_from_pickle(obj_filename, template_path=template_path)\n\n        # step 1 sample hyperparams\n        params = self.get_hyperparameters()\n        # step 2 set the update params for the algo to run in the next trial\n        self.update_params(params)\n        # step 3 generate the folder to save checkpoints and train\n        self.generate(output_folder)\n        self.algo.train(self.params)\n        # step 4 report validation acc to controller\n        acc = self.algo.get_score()\n        algo_meta_data = {str(AlgoKeys.SCORE): acc}\n\n        algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)\n        self.set_score(acc)\n\n\nclass OptunaGen(HPOGen):\n    \"\"\"\n    Generate algorithms for the Optuna to automate hyperparameter tuning. Please refer to NNI and Optuna\n    (https://optuna.readthedocs.io/en/stable/) for more information. Optuna has different running scheme\n    compared to NNI. The hyperparameter samples come from a trial object (trial.suggest...) created by Optuna,\n    so OptunaGen needs to accept this trial object as input. Meanwhile, Optuna calls OptunaGen,\n    thus OptunaGen.__call__() should return the accuracy. Use functools.partial to wrap OptunaGen\n    for addition input arguments.\n\n    Args:\n        algo: an Algo object (e.g. BundleAlgo). 
The object must at least define two methods: get_output_path and train\n            and supports saving to and loading from pickle files via ``algo_from_pickle`` and ``algo_to_pickle``.\n        params: a set of parameter to override the algo if override is supported by Algo subclass.\n\n    Examples::\n\n        The experiment will keep generating new folders to save the model checkpoints, scripts, and configs if available.\n        ├── algorithm_templates\n        │   └── unet\n        ├── unet_0\n        │   ├── algo_object.pkl\n        │   ├── configs\n        │   └── scripts\n        ├── unet_0_learning_rate_0.01\n        │   ├── algo_object.pkl\n        │   ├── configs\n        │   ├── model_fold0\n        │   └── scripts\n        └── unet_0_learning_rate_0.1\n            ├── algo_object.pkl\n            ├── configs\n            ├── model_fold0\n            └── scripts\n\n    Notes:\n        Different from NNI and NNIGen, OptunaGen and Optuna can be ran within the Python process.\n\n    \"\"\"\n\n    def __init__(self, algo: Algo | None = None, params: dict | None = None) -> None:\n        self.algo: Algo\n        self.obj_filename = \"\"\n\n        if algo is not None:\n            if isinstance(algo, BundleAlgo):\n                if params is None:\n                    self.algo = algo\n                else:\n                    self.algo = deepcopy(algo)\n                    name = os.path.basename(algo.get_output_path()) + \"_override\"\n                    output_folder = os.path.dirname(algo.get_output_path())\n\n                    params.update({\"fill_with_datastats\": False})  # just copy, not using datastats to fill\n                    self.algo.export_to_disk(output_folder, name, **params)\n            else:\n                self.algo = algo\n\n            self.obj_filename = algo_to_pickle(self.algo, template_path=self.algo.template_path)\n\n    def get_obj_filename(self):\n        \"\"\"Return the dumped pickle object of algo.\"\"\"\n        
return self.obj_filename\n\n    def get_hyperparameters(self):\n        \"\"\"\n        Get parameter for next round of training from optuna trial object.\n        This function requires user rewrite during usage for different search space.\n        \"\"\"\n        if has_optuna:\n            logger.info(\"Please rewrite this code by creating a child class\")\n            return {\"learning_rate\": self.trial.suggest_float(\"learning_rate\", 0.0001, 0.1)}\n        else:\n            warn(\"Optuna is not detected. The code will continue to run without Optuna.\")\n            return {}\n\n    def set_score(self, acc):\n        \"\"\"Set the accuracy score\"\"\"\n        self.acc = acc\n\n    def set_trial(self, trial):\n        \"\"\"Set the Optuna trial\"\"\"\n        self.trial = trial\n\n    def __call__(\n        self, trial: Any, obj_filename: str, output_folder: str = \".\", template_path: PathLike | None = None\n    ) -> Any:\n        \"\"\"\n        Callable that Optuna will use to optimize the hyper-parameters\n\n        Args:\n            obj_filename: the pickle-exported Algo object.\n            output_folder: the root path of the algorithms templates.\n            template_path: the algorithm_template. It must contain algo.py in the follow path:\n                ``{algorithm_templates_dir}/{network}/scripts/algo.py``\n        \"\"\"\n        self.set_trial(trial)\n        self.run_algo(obj_filename, output_folder, template_path)\n        return self.acc\n\n    def update_params(self, params: dict) -> None:\n        \"\"\"\n        Translate the parameter from monai bundle.\n\n        Args:\n            params: a dict of parameters.\n        \"\"\"\n        self.params = params\n\n    def get_task_id(self):\n        \"\"\"\n        Get the identifier of the current experiment. 
In the format of listing the searching parameter name and values\n        connected by underscore in the file name.\n        \"\"\"\n        return \"\".join(f\"_{k}_{v}\" for k, v in self.params.items()) or \"_None\"\n\n    def generate(self, output_folder: str = \".\") -> None:\n        \"\"\"\n        Generate the record for each Algo. If it is a BundleAlgo, it will generate the config files.\n\n        Args:\n            output_folder: the directory nni will save the results to.\n        \"\"\"\n        task_id = self.get_task_id()\n        task_prefix = os.path.basename(self.algo.get_output_path())\n        write_path = os.path.join(output_folder, task_prefix + task_id)\n        self.obj_filename = os.path.join(write_path, \"algo_object.pkl\")\n\n        if isinstance(self.algo, BundleAlgo):\n            self.algo.export_to_disk(output_folder, task_prefix + task_id, fill_with_datastats=False)\n        else:\n            ConfigParser.export_config_file(self.params, write_path)\n            logger.info(write_path)\n\n    def run_algo(self, obj_filename: str, output_folder: str = \".\", template_path: PathLike | None = None) -> None:\n        \"\"\"\n        The python interface for NNI to run.\n\n        Args:\n            obj_filename: the pickle-exported Algo object.\n            output_folder: the root path of the algorithms templates.\n            template_path: the algorithm_template. 
It must contain algo.py in the follow path:\n                ``{algorithm_templates_dir}/{network}/scripts/algo.py``\n        \"\"\"\n        if not os.path.isfile(obj_filename):\n            raise ValueError(f\"{obj_filename} is not found\")\n\n        self.algo, algo_meta_data = algo_from_pickle(obj_filename, template_path=template_path)\n\n        # step 1 sample hyperparams\n        params = self.get_hyperparameters()\n        # step 2 set the update params for the algo to run in the next trial\n        self.update_params(params)\n        # step 3 generate the folder to save checkpoints and train\n        self.generate(output_folder)\n        self.algo.train(self.params)\n        # step 4 report validation acc to controller\n        acc = self.algo.get_score()\n        algo_meta_data = {str(AlgoKeys.SCORE): acc}\n        algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)\n        self.set_score(acc)\n"
  },
  {
    "path": "monai/apps/auto3dseg/transforms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Hashable, Mapping\n\nimport numpy as np\nimport torch\n\nfrom monai.config import KeysCollection\nfrom monai.transforms import MapTransform\nfrom monai.utils.misc import ImageMetaKey\n\n\nclass EnsureSameShaped(MapTransform):\n    \"\"\"\n    Checks if segmentation label images (in keys) have the same spatial shape as the main image (in source_key),\n    and raise an error if the shapes are significantly different.\n    If the shapes are only slightly different (within an allowed_shape_difference in each dim), then resize the label using\n    nearest interpolation. 
This transform is designed to correct datasets with slight label shape mismatches.\n    Generally image and segmentation label must have the same spatial shape, however some public datasets are having slight\n    shape mismatches, which will cause potential crashes when calculating loss or metric functions.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection = \"label\",\n        allow_missing_keys: bool = False,\n        source_key: str = \"image\",\n        allowed_shape_difference: int = 5,\n        warn: bool = True,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be compared to the source_key item shape.\n            allow_missing_keys: do not raise exception if key is missing.\n            source_key: key of the item with the reference shape.\n            allowed_shape_difference: raises error if shapes are different more than this value in any dimension,\n                otherwise corrects for the shape mismatch using nearest interpolation.\n            warn: if `True` prints a warning if the label image is resized\n\n\n        \"\"\"\n        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)\n        self.source_key = source_key\n        self.allowed_shape_difference = allowed_shape_difference\n        self.warn = warn\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        image_shape = d[self.source_key].shape[1:]\n        for key in self.key_iterator(d):\n            label_shape = d[key].shape[1:]\n            if label_shape != image_shape:\n                filename = \"\"\n                if hasattr(d[key], \"meta\") and isinstance(d[key].meta, Mapping):  # type: ignore[attr-defined]\n                    filename = d[key].meta.get(ImageMetaKey.FILENAME_OR_OBJ)  # type: ignore[attr-defined]\n\n                if np.allclose(list(label_shape), list(image_shape), 
atol=self.allowed_shape_difference):\n                    if self.warn:\n                        warnings.warn(\n                            f\"The {key} with shape {label_shape} was resized to match the source shape {image_shape}\"\n                            f\", the metadata was not updated {filename}.\"\n                        )\n                    d[key] = torch.nn.functional.interpolate(\n                        input=d[key].unsqueeze(0), size=image_shape, mode=\"nearest-exact\"\n                    ).squeeze(0)\n                else:\n                    raise ValueError(\n                        f\"The {key} shape {label_shape} is different from the source shape {image_shape} {filename}.\"\n                    )\n        return d\n"
  },
  {
    "path": "monai/apps/auto3dseg/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\n\nfrom monai.apps.auto3dseg.bundle_gen import BundleAlgo\nfrom monai.auto3dseg import algo_from_pickle, algo_to_pickle\nfrom monai.utils.enums import AlgoKeys\n\n__all__ = [\"import_bundle_algo_history\", \"export_bundle_algo_history\", \"get_name_from_algo_id\"]\n\n\ndef import_bundle_algo_history(\n    output_folder: str = \".\", template_path: str | None = None, only_trained: bool = True\n) -> list:\n    \"\"\"\n    import the history of the bundleAlgo objects as a list of algo dicts.\n    each algo_dict has keys name (folder name), algo (bundleAlgo), is_trained (bool),\n\n    Args:\n        output_folder: the root path of the algorithms templates.\n        template_path: the algorithm_template. 
It must contain algo.py in the follow path:\n            ``{algorithm_templates_dir}/{network}/scripts/algo.py``.\n        only_trained: only read the algo history if the algo is trained.\n    \"\"\"\n\n    history = []\n\n    for name in sorted(os.listdir(output_folder)):\n        write_path = os.path.join(output_folder, name)\n\n        if not os.path.isdir(write_path):\n            continue\n\n        obj_filename = os.path.join(write_path, \"algo_object.pkl\")\n        if not os.path.isfile(obj_filename):  # saved mode pkl\n            continue\n\n        algo, algo_meta_data = algo_from_pickle(obj_filename, template_path=template_path)\n\n        best_metric = algo_meta_data.get(AlgoKeys.SCORE, None)\n        if best_metric is None:\n            try:\n                best_metric = algo.get_score()\n            except BaseException:\n                pass\n\n        is_trained = best_metric is not None\n\n        if (only_trained and is_trained) or not only_trained:\n            history.append(\n                {AlgoKeys.ID: name, AlgoKeys.ALGO: algo, AlgoKeys.SCORE: best_metric, AlgoKeys.IS_TRAINED: is_trained}\n            )\n\n    return history\n\n\ndef export_bundle_algo_history(history: list[dict[str, BundleAlgo]]) -> None:\n    \"\"\"\n    Save all the BundleAlgo in the history to algo_object.pkl in each individual folder\n\n    Args:\n        history: a List of Bundle. Typically, the history can be obtained from BundleGen get_history method\n    \"\"\"\n    for algo_dict in history:\n        algo = algo_dict[AlgoKeys.ALGO]\n        algo_to_pickle(algo, template_path=algo.template_path)\n\n\ndef get_name_from_algo_id(id: str) -> str:\n    \"\"\"\n    Get the name of Algo from the identifier of the Algo.\n\n    Args:\n        id: identifier which follows a convention of \"name_fold_other\".\n\n    Returns:\n        name of the Algo.\n    \"\"\"\n    return id.split(\"_\")[0]\n"
  },
  {
    "path": "monai/apps/datasets.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport sys\nimport warnings\nfrom collections.abc import Callable, Sequence\nfrom pathlib import Path\nfrom typing import Any\n\nimport numpy as np\n\nfrom monai.apps.tcia import (\n    DCM_FILENAME_REGEX,\n    download_tcia_series_instance,\n    get_tcia_metadata,\n    get_tcia_ref_uid,\n    match_tcia_ref_uid_in_study,\n)\nfrom monai.apps.utils import download_and_extract\nfrom monai.config.type_definitions import PathLike\nfrom monai.data import (\n    CacheDataset,\n    PydicomReader,\n    load_decathlon_datalist,\n    load_decathlon_properties,\n    partition_dataset,\n    select_cross_validation_folds,\n)\nfrom monai.transforms import LoadImaged, Randomizable\nfrom monai.utils import ensure_tuple\n\n__all__ = [\"MedNISTDataset\", \"DecathlonDataset\", \"CrossValidation\", \"TciaDataset\"]\n\n\nclass MedNISTDataset(Randomizable, CacheDataset):\n    \"\"\"\n    The Dataset to automatically download MedNIST data and generate items for training, validation or test.\n    It's based on `CacheDataset` to accelerate the training process.\n\n    Args:\n        root_dir: target directory to download and load MedNIST dataset.\n        section: expected data section, can be: `training`, `validation` or `test`.\n        transform: transforms to execute operations on input data.\n        download: whether to download 
and extract the MedNIST from resource link, default is False.\n            if expected file already exists, skip downloading even set it to True.\n            user can manually copy `MedNIST.tar.gz` file or `MedNIST` folder to root directory.\n        seed: random seed to randomly split training, validation and test datasets, default is 0.\n        val_frac: percentage of validation fraction in the whole dataset, default is 0.1.\n        test_frac: percentage of test fraction in the whole dataset, default is 0.1.\n        cache_num: number of items to be cached. Default is `sys.maxsize`.\n            will take the minimum of (cache_num, data_length x cache_rate, data_length).\n        cache_rate: percentage of cached data in total, default is 1.0 (cache all).\n            will take the minimum of (cache_num, data_length x cache_rate, data_length).\n        num_workers: the number of worker threads if computing cache in the initialization.\n            If num_workers is None then the number returned by os.cpu_count() is used.\n            If a value less than 1 is specified, 1 will be used instead.\n        progress: whether to display a progress bar when downloading dataset and computing the transform cache content.\n        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,\n            default to `True`. if the random transforms don't modify the cached content\n            (for example, randomly crop from the cached image and deepcopy the crop region)\n            or if every cache item is only used once in a `multi-processing` environment,\n            may set `copy=False` for better performance.\n        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.\n            it may help improve the performance of following logic.\n        runtime_cache: whether to compute cache at the runtime, default to `False` to prepare\n            the cache content at initialization. 
See: :py:class:`monai.data.CacheDataset`.\n\n    Raises:\n        ValueError: When ``root_dir`` is not a directory.\n        RuntimeError: When ``dataset_dir`` doesn't exist and downloading is not selected (``download=False``).\n\n    \"\"\"\n\n    resource = \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz\"\n    md5 = \"0bc7306e7427e00ad1c5526a6677552d\"\n    compressed_file_name = \"MedNIST.tar.gz\"\n    dataset_folder_name = \"MedNIST\"\n\n    def __init__(\n        self,\n        root_dir: PathLike,\n        section: str,\n        transform: Sequence[Callable] | Callable = (),\n        download: bool = False,\n        seed: int = 0,\n        val_frac: float = 0.1,\n        test_frac: float = 0.1,\n        cache_num: int = sys.maxsize,\n        cache_rate: float = 1.0,\n        num_workers: int | None = 1,\n        progress: bool = True,\n        copy_cache: bool = True,\n        as_contiguous: bool = True,\n        runtime_cache: bool = False,\n    ) -> None:\n        root_dir = Path(root_dir)\n        if not root_dir.is_dir():\n            raise ValueError(\"Root directory root_dir must be a directory.\")\n        self.section = section\n        self.val_frac = val_frac\n        self.test_frac = test_frac\n        self.set_random_state(seed=seed)\n        tarfile_name = root_dir / self.compressed_file_name\n        dataset_dir = root_dir / self.dataset_folder_name\n        self.num_class = 0\n        if download:\n            download_and_extract(\n                url=self.resource,\n                filepath=tarfile_name,\n                output_dir=root_dir,\n                hash_val=self.md5,\n                hash_type=\"md5\",\n                progress=progress,\n            )\n\n        if not dataset_dir.is_dir():\n            raise RuntimeError(\n                f\"Cannot find dataset directory: {dataset_dir}, please use download=True to download it.\"\n            )\n        data = 
self._generate_data_list(dataset_dir)\n        if transform == ():\n            transform = LoadImaged(\"image\")\n        CacheDataset.__init__(\n            self,\n            data=data,\n            transform=transform,\n            cache_num=cache_num,\n            cache_rate=cache_rate,\n            num_workers=num_workers,\n            progress=progress,\n            copy_cache=copy_cache,\n            as_contiguous=as_contiguous,\n            runtime_cache=runtime_cache,\n        )\n\n    def randomize(self, data: np.ndarray) -> None:\n        self.R.shuffle(data)\n\n    def get_num_classes(self) -> int:\n        \"\"\"Get number of classes.\"\"\"\n        return self.num_class\n\n    def _generate_data_list(self, dataset_dir: PathLike) -> list[dict]:\n        \"\"\"\n        Raises:\n            ValueError: When ``section`` is not one of [\"training\", \"validation\", \"test\"].\n\n        \"\"\"\n        dataset_dir = Path(dataset_dir)\n        class_names = sorted(f\"{x.name}\" for x in dataset_dir.iterdir() if x.is_dir())  # folder name as the class name\n        self.num_class = len(class_names)\n        image_files = [[f\"{x}\" for x in (dataset_dir / class_names[i]).iterdir()] for i in range(self.num_class)]\n        num_each = [len(image_files[i]) for i in range(self.num_class)]\n        image_files_list = []\n        image_class = []\n        class_name = []\n        for i in range(self.num_class):\n            image_files_list.extend(image_files[i])\n            image_class.extend([i] * num_each[i])\n            class_name.extend([class_names[i]] * num_each[i])\n\n        length = len(image_files_list)\n        indices = np.arange(length)\n        self.randomize(indices)\n\n        test_length = int(length * self.test_frac)\n        val_length = int(length * self.val_frac)\n        if self.section == \"test\":\n            section_indices = indices[:test_length]\n        elif self.section == \"validation\":\n            section_indices = 
indices[test_length : test_length + val_length]\n        elif self.section == \"training\":\n            section_indices = indices[test_length + val_length :]\n        else:\n            raise ValueError(\n                f'Unsupported section: {self.section}, available options are [\"training\", \"validation\", \"test\"].'\n            )\n        # the types of label and class name should be compatible with the pytorch dataloader\n        return [\n            {\"image\": image_files_list[i], \"label\": image_class[i], \"class_name\": class_name[i]}\n            for i in section_indices\n        ]\n\n\nclass DecathlonDataset(Randomizable, CacheDataset):\n    \"\"\"\n    The Dataset to automatically download the data of Medical Segmentation Decathlon challenge\n    (http://medicaldecathlon.com/) and generate items for training, validation or test.\n    It will also load these properties from the JSON config file of dataset. user can call `get_properties()`\n    to get specified properties or all the properties loaded.\n    It's based on :py:class:`monai.data.CacheDataset` to accelerate the training process.\n\n    Args:\n        root_dir: user's local directory for caching and loading the MSD datasets.\n        task: which task to download and execute: one of list (\"Task01_BrainTumour\", \"Task02_Heart\",\n            \"Task03_Liver\", \"Task04_Hippocampus\", \"Task05_Prostate\", \"Task06_Lung\", \"Task07_Pancreas\",\n            \"Task08_HepaticVessel\", \"Task09_Spleen\", \"Task10_Colon\").\n        section: expected data section, can be: `training`, `validation` or `test`.\n        transform: transforms to execute operations on input data.\n            for further usage, use `EnsureChannelFirstd` to convert the shape to [C, H, W, D].\n        download: whether to download and extract the Decathlon from resource link, default is False.\n            if expected file already exists, skip downloading even set it to True.\n            user can manually copy tar file 
or dataset folder to the root directory.\n        val_frac: percentage of validation fraction in the whole dataset, default is 0.2.\n        seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0.\n            note to set same seed for `training` and `validation` sections.\n        cache_num: number of items to be cached. Default is `sys.maxsize`.\n            will take the minimum of (cache_num, data_length x cache_rate, data_length).\n        cache_rate: percentage of cached data in total, default is 1.0 (cache all).\n            will take the minimum of (cache_num, data_length x cache_rate, data_length).\n        num_workers: the number of worker threads if computing cache in the initialization.\n            If num_workers is None then the number returned by os.cpu_count() is used.\n            If a value less than 1 is specified, 1 will be used instead.\n        progress: whether to display a progress bar when downloading dataset and computing the transform cache content.\n        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,\n            default to `True`. if the random transforms don't modify the cached content\n            (for example, randomly crop from the cached image and deepcopy the crop region)\n            or if every cache item is only used once in a `multi-processing` environment,\n            may set `copy=False` for better performance.\n        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.\n            it may help improve the performance of following logic.\n        runtime_cache: whether to compute cache at the runtime, default to `False` to prepare\n            the cache content at initialization. 
See: :py:class:`monai.data.CacheDataset`.\n\n    Raises:\n        ValueError: When ``root_dir`` is not a directory.\n        ValueError: When ``task`` is not one of [\"Task01_BrainTumour\", \"Task02_Heart\",\n            \"Task03_Liver\", \"Task04_Hippocampus\", \"Task05_Prostate\", \"Task06_Lung\", \"Task07_Pancreas\",\n            \"Task08_HepaticVessel\", \"Task09_Spleen\", \"Task10_Colon\"].\n        RuntimeError: When ``dataset_dir`` doesn't exist and downloading is not selected (``download=False``).\n\n    Example::\n\n        transform = Compose(\n            [\n                LoadImaged(keys=[\"image\", \"label\"]),\n                EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n                ScaleIntensityd(keys=\"image\"),\n                ToTensord(keys=[\"image\", \"label\"]),\n            ]\n        )\n\n        val_data = DecathlonDataset(\n            root_dir=\"./\", task=\"Task09_Spleen\", transform=transform, section=\"validation\", seed=12345, download=True\n        )\n\n        print(val_data[0][\"image\"], val_data[0][\"label\"])\n\n    \"\"\"\n\n    resource = {\n        \"Task01_BrainTumour\": \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task01_BrainTumour.tar\",\n        \"Task02_Heart\": \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task02_Heart.tar\",\n        \"Task03_Liver\": \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task03_Liver.tar\",\n        \"Task04_Hippocampus\": \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task04_Hippocampus.tar\",\n        \"Task05_Prostate\": \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task05_Prostate.tar\",\n        \"Task06_Lung\": \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task06_Lung.tar\",\n        \"Task07_Pancreas\": \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task07_Pancreas.tar\",\n        \"Task08_HepaticVessel\": \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task08_HepaticVessel.tar\",\n        \"Task09_Spleen\": 
\"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar\",\n        \"Task10_Colon\": \"https://msd-for-monai.s3-us-west-2.amazonaws.com/Task10_Colon.tar\",\n    }\n    md5 = {\n        \"Task01_BrainTumour\": \"240a19d752f0d9e9101544901065d872\",\n        \"Task02_Heart\": \"06ee59366e1e5124267b774dbd654057\",\n        \"Task03_Liver\": \"a90ec6c4aa7f6a3d087205e23d4e6397\",\n        \"Task04_Hippocampus\": \"9d24dba78a72977dbd1d2e110310f31b\",\n        \"Task05_Prostate\": \"35138f08b1efaef89d7424d2bcc928db\",\n        \"Task06_Lung\": \"8afd997733c7fc0432f71255ba4e52dc\",\n        \"Task07_Pancreas\": \"4f7080cfca169fa8066d17ce6eb061e4\",\n        \"Task08_HepaticVessel\": \"641d79e80ec66453921d997fbf12a29c\",\n        \"Task09_Spleen\": \"410d4a301da4e5b2f6f86ec3ddba524e\",\n        \"Task10_Colon\": \"bad7a188931dc2f6acf72b08eb6202d0\",\n    }\n\n    def __init__(\n        self,\n        root_dir: PathLike,\n        task: str,\n        section: str,\n        transform: Sequence[Callable] | Callable = (),\n        download: bool = False,\n        seed: int = 0,\n        val_frac: float = 0.2,\n        cache_num: int = sys.maxsize,\n        cache_rate: float = 1.0,\n        num_workers: int = 1,\n        progress: bool = True,\n        copy_cache: bool = True,\n        as_contiguous: bool = True,\n        runtime_cache: bool = False,\n    ) -> None:\n        root_dir = Path(root_dir)\n        if not root_dir.is_dir():\n            raise ValueError(\"Root directory root_dir must be a directory.\")\n        self.section = section\n        self.val_frac = val_frac\n        self.set_random_state(seed=seed)\n        if task not in self.resource:\n            raise ValueError(f\"Unsupported task: {task}, available options are: {list(self.resource.keys())}.\")\n        dataset_dir = root_dir / task\n        tarfile_name = f\"{dataset_dir}.tar\"\n        if download:\n            download_and_extract(\n                url=self.resource[task],\n               
 filepath=tarfile_name,\n                output_dir=root_dir,\n                hash_val=self.md5[task],\n                hash_type=\"md5\",\n                progress=progress,\n            )\n\n        if not dataset_dir.exists():\n            raise RuntimeError(\n                f\"Cannot find dataset directory: {dataset_dir}, please use download=True to download it.\"\n            )\n        self.indices: np.ndarray = np.array([])\n        data = self._generate_data_list(dataset_dir)\n        # as `release` key has typo in Task04 config file, ignore it.\n        property_keys = [\n            \"name\",\n            \"description\",\n            \"reference\",\n            \"licence\",\n            \"tensorImageSize\",\n            \"modality\",\n            \"labels\",\n            \"numTraining\",\n            \"numTest\",\n        ]\n        self._properties = load_decathlon_properties(dataset_dir / \"dataset.json\", property_keys)\n        if transform == ():\n            transform = LoadImaged([\"image\", \"label\"])\n        CacheDataset.__init__(\n            self,\n            data=data,\n            transform=transform,\n            cache_num=cache_num,\n            cache_rate=cache_rate,\n            num_workers=num_workers,\n            progress=progress,\n            copy_cache=copy_cache,\n            as_contiguous=as_contiguous,\n            runtime_cache=runtime_cache,\n        )\n\n    def get_indices(self) -> np.ndarray:\n        \"\"\"\n        Get the indices of datalist used in this dataset.\n\n        \"\"\"\n        return self.indices\n\n    def randomize(self, data: np.ndarray) -> None:\n        self.R.shuffle(data)\n\n    def get_properties(self, keys: Sequence[str] | str | None = None) -> dict:\n        \"\"\"\n        Get the loaded properties of dataset with specified keys.\n        If no keys specified, return all the loaded properties.\n\n        \"\"\"\n        if keys is None:\n            return self._properties\n        if 
self._properties is not None:\n            return {key: self._properties[key] for key in ensure_tuple(keys)}\n        return {}\n\n    def _generate_data_list(self, dataset_dir: PathLike) -> list[dict]:\n        # the types of the item in data list should be compatible with the dataloader\n        dataset_dir = Path(dataset_dir)\n        section = \"training\" if self.section in [\"training\", \"validation\"] else \"test\"\n        datalist = load_decathlon_datalist(dataset_dir / \"dataset.json\", True, section)\n        return self._split_datalist(datalist)\n\n    def _split_datalist(self, datalist: list[dict]) -> list[dict]:\n        if self.section == \"test\":\n            return datalist\n        length = len(datalist)\n        indices = np.arange(length)\n        self.randomize(indices)\n\n        val_length = int(length * self.val_frac)\n        if self.section == \"training\":\n            self.indices = indices[val_length:]\n        else:\n            self.indices = indices[:val_length]\n\n        return [datalist[i] for i in self.indices]\n\n\nclass TciaDataset(Randomizable, CacheDataset):\n    \"\"\"\n    The Dataset to automatically download the data from a public The Cancer Imaging Archive (TCIA) dataset\n    and generate items for training, validation or test.\n\n    The Highdicom library is used to load dicom data with modality \"SEG\", but only a part of collections are\n    supported, such as: \"C4KC-KiTS\", \"NSCLC-Radiomics\", \"NSCLC-Radiomics-Interobserver1\", \"QIN-PROSTATE-Repeatability\"\n    and \"PROSTATEx\". Therefore, if \"seg\" is included in `keys` of the `LoadImaged` transform and loading some\n    other collections, errors may be raised. For supported collections, the original \"SEG\" information may not\n    always be consistent for each dicom file. Therefore, to avoid creating different format of labels, please use\n    the `label_dict` argument of `PydicomReader` when calling the `LoadImaged` transform. 
The prepared label dicts\n    of collections that are mentioned above are also saved in: `monai.apps.tcia.TCIA_LABEL_DICT`. You can also refer\n    to the second example below.\n\n\n    This class is based on :py:class:`monai.data.CacheDataset` to accelerate the training process.\n\n    Args:\n        root_dir: user's local directory for caching and loading the TCIA dataset.\n        collection: name of a TCIA collection.\n            a TCIA dataset is defined as a collection. Please check the following list to browse\n            the collection list (only public collections can be downloaded):\n            https://www.cancerimagingarchive.net/collections/\n        section: expected data section, can be: `training`, `validation` or `test`.\n        transform: transforms to execute operations on input data.\n            for further usage, use `EnsureChannelFirstd` to convert the shape to [C, H, W, D].\n            If not specified, `LoadImaged(reader=\"PydicomReader\", keys=[\"image\"])` will be used as the default\n            transform. In addition, we suggest to set the argument `label_dict` for `PydicomReader` if segmentations\n            are needed to be loaded. The original labels for each dicom series may be different, using this argument\n            is able to unify the format of labels.\n        download: whether to download and extract the dataset, default is False.\n            if expected file already exists, skip downloading even set it to True.\n            user can manually copy tar file or dataset folder to the root directory.\n        download_len: number of series that will be downloaded, the value should be larger than 0 or -1, where -1 means\n            all series will be downloaded. Default is -1.\n        seg_type: modality type of segmentation that is used to do the first step download. Default is \"SEG\".\n        modality_tag: tag of modality. Default is (0x0008, 0x0060).\n        ref_series_uid_tag: tag of referenced Series Instance UID. 
Default is (0x0020, 0x000e).\n        ref_sop_uid_tag: tag of referenced SOP Instance UID. Default is (0x0008, 0x1155).\n        specific_tags: tags that will be loaded for \"SEG\" series. This argument will be used in\n            `monai.data.PydicomReader`. Default is [(0x0008, 0x1115), (0x0008,0x1140), (0x3006, 0x0010),\n            (0x0020,0x000D), (0x0010,0x0010), (0x0010,0x0020), (0x0020,0x0011), (0x0020,0x0012)].\n        fname_regex: a regular expression to match the file names when the input is a folder.\n            If provided, only the matched files will be included. For example, to include the file name\n            \"image_0001.dcm\", the regular expression could be `\".*image_(\\\\d+).dcm\"`.\n            Default to `\"^(?!.*LICENSE).*\"`, ignoring any file name containing `\"LICENSE\"`.\n        val_frac: percentage of validation fraction in the whole dataset, default is 0.2.\n        seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0.\n            note to set same seed for `training` and `validation` sections.\n        cache_num: number of items to be cached. Default is `sys.maxsize`.\n            will take the minimum of (cache_num, data_length x cache_rate, data_length).\n        cache_rate: percentage of cached data in total, default is 0.0 (no cache).\n            will take the minimum of (cache_num, data_length x cache_rate, data_length).\n        num_workers: the number of worker threads if computing cache in the initialization.\n            If num_workers is None then the number returned by os.cpu_count() is used.\n            If a value less than 1 is specified, 1 will be used instead.\n        progress: whether to display a progress bar when downloading dataset and computing the transform cache content.\n        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,\n            default to `True`. 
if the random transforms don't modify the cached content\n            (for example, randomly crop from the cached image and deepcopy the crop region)\n            or if every cache item is only used once in a `multi-processing` environment,\n            may set `copy_cache=False` for better performance.\n        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.\n            it may help improve the performance of following logic.\n        runtime_cache: whether to compute cache at the runtime, default to `False` to prepare\n            the cache content at initialization. See: :py:class:`monai.data.CacheDataset`.\n\n    Example::\n\n        # collection is \"Pancreatic-CT-CBCT-SEG\", seg_type is \"RTSTRUCT\"\n        data = TciaDataset(\n            root_dir=\"./\", collection=\"Pancreatic-CT-CBCT-SEG\", section=\"training\", seg_type=\"RTSTRUCT\", download=True\n        )\n\n        # collection is \"C4KC-KiTS\", seg_type is \"SEG\", and load both images and segmentations\n        from monai.apps.tcia import TCIA_LABEL_DICT\n        transform = Compose(\n            [\n                LoadImaged(reader=\"PydicomReader\", keys=[\"image\", \"seg\"], label_dict=TCIA_LABEL_DICT[\"C4KC-KiTS\"]),\n                EnsureChannelFirstd(keys=[\"image\", \"seg\"]),\n                ResampleToMatchd(keys=\"image\", key_dst=\"seg\"),\n            ]\n        )\n        data = TciaDataset(\n            root_dir=\"./\", collection=\"C4KC-KiTS\", section=\"validation\", transform=transform, seed=12345, download=True\n        )\n\n        print(data[0][\"seg\"].shape)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        root_dir: PathLike,\n        collection: str,\n        section: str,\n        transform: Sequence[Callable] | Callable = (),\n        download: bool = False,\n        download_len: int = -1,\n        seg_type: str = \"SEG\",\n        modality_tag: tuple = (0x0008, 0x0060),\n        ref_series_uid_tag: tuple = (0x0020, 0x000E),\n        ref_sop_uid_tag: tuple = 
(0x0008, 0x1155),\n        specific_tags: tuple = (\n            (0x0008, 0x1115),  # Referenced Series Sequence\n            (0x0008, 0x1140),  # Referenced Image Sequence\n            (0x3006, 0x0010),  # Referenced Frame of Reference Sequence\n            (0x0020, 0x000D),  # Study Instance UID\n            (0x0010, 0x0010),  # Patient's Name\n            (0x0010, 0x0020),  # Patient ID\n            (0x0020, 0x0011),  # Series Number\n            (0x0020, 0x0012),  # Acquisition Number\n        ),\n        fname_regex: str = DCM_FILENAME_REGEX,\n        seed: int = 0,\n        val_frac: float = 0.2,\n        cache_num: int = sys.maxsize,\n        cache_rate: float = 0.0,\n        num_workers: int = 1,\n        progress: bool = True,\n        copy_cache: bool = True,\n        as_contiguous: bool = True,\n        runtime_cache: bool = False,\n    ) -> None:\n        root_dir = Path(root_dir)\n        if not root_dir.is_dir():\n            raise ValueError(\"Root directory root_dir must be a directory.\")\n\n        self.section = section\n        self.val_frac = val_frac\n        self.seg_type = seg_type\n        self.modality_tag = modality_tag\n        self.ref_series_uid_tag = ref_series_uid_tag\n        self.ref_sop_uid_tag = ref_sop_uid_tag\n\n        self.set_random_state(seed=seed)\n        download_dir = os.path.join(root_dir, collection)\n        load_tags = list(specific_tags)\n        load_tags += [modality_tag]\n        self.load_tags = load_tags\n        if download:\n            seg_series_list = get_tcia_metadata(\n                query=f\"getSeries?Collection={collection}&Modality={seg_type}\", attribute=\"SeriesInstanceUID\"\n            )\n            if download_len > 0:\n                seg_series_list = seg_series_list[:download_len]\n            if len(seg_series_list) == 0:\n                raise ValueError(f\"Cannot find data with collection: {collection} seg_type: {seg_type}\")\n            for series_uid in seg_series_list:\n              
  self._download_series_reference_data(series_uid, download_dir)\n\n        if not os.path.exists(download_dir):\n            raise RuntimeError(f\"Cannot find dataset directory: {download_dir}.\")\n        self.fname_regex = fname_regex\n\n        self.indices: np.ndarray = np.array([])\n        self.datalist = self._generate_data_list(download_dir)\n\n        if transform == ():\n            transform = LoadImaged(keys=[\"image\"], reader=\"PydicomReader\", fname_regex=self.fname_regex)\n        CacheDataset.__init__(\n            self,\n            data=self.datalist,\n            transform=transform,\n            cache_num=cache_num,\n            cache_rate=cache_rate,\n            num_workers=num_workers,\n            progress=progress,\n            copy_cache=copy_cache,\n            as_contiguous=as_contiguous,\n            runtime_cache=runtime_cache,\n        )\n\n    def get_indices(self) -> np.ndarray:\n        \"\"\"\n        Get the indices of datalist used in this dataset.\n\n        \"\"\"\n        return self.indices\n\n    def randomize(self, data: np.ndarray) -> None:\n        self.R.shuffle(data)\n\n    def _download_series_reference_data(self, series_uid: str, download_dir: str) -> None:\n        \"\"\"\n        First of all, download a series from TCIA according to `series_uid`.\n        Then find all referenced series and download.\n        \"\"\"\n        seg_first_dir = os.path.join(download_dir, \"raw\", series_uid)\n        download_tcia_series_instance(\n            series_uid=series_uid, download_dir=download_dir, output_dir=seg_first_dir, check_md5=False\n        )\n        dicom_files = [f for f in sorted(os.listdir(seg_first_dir)) if f.endswith(\".dcm\")]\n        # achieve series number and patient id from the first dicom file\n        dcm_path = os.path.join(seg_first_dir, dicom_files[0])\n        ds = PydicomReader(stop_before_pixels=True, specific_tags=self.load_tags).read(dcm_path)\n        # (0x0010,0x0020) and (0x0010,0x0010), 
better to be contained in `specific_tags`\n        patient_id = ds.PatientID if ds.PatientID else ds.PatientName\n        if not patient_id:\n            warnings.warn(f\"unable to find patient name of dicom file: {dcm_path}, use 'patient' instead.\")\n            patient_id = \"patient\"\n        # (0x0020,0x0011) and (0x0020,0x0012), better to be contained in `specific_tags`\n        series_num = ds.SeriesNumber if ds.SeriesNumber else ds.AcquisitionNumber\n        if not series_num:\n            warnings.warn(f\"unable to find series number of dicom file: {dcm_path}, use '0' instead.\")\n            series_num = 0\n\n        series_num = str(series_num)\n        seg_dir = os.path.join(download_dir, patient_id, series_num, self.seg_type.lower())\n        dcm_dir = os.path.join(download_dir, patient_id, series_num, \"image\")\n\n        # get ref uuid\n        ref_uid_list = []\n        for dcm_file in dicom_files:\n            dcm_path = os.path.join(seg_first_dir, dcm_file)\n            ds = PydicomReader(stop_before_pixels=True, specific_tags=self.load_tags).read(dcm_path)\n            if ds[self.modality_tag].value == self.seg_type:\n                ref_uid = get_tcia_ref_uid(\n                    ds, find_sop=False, ref_series_uid_tag=self.ref_series_uid_tag, ref_sop_uid_tag=self.ref_sop_uid_tag\n                )\n                if ref_uid == \"\":\n                    ref_sop_uid = get_tcia_ref_uid(\n                        ds,\n                        find_sop=True,\n                        ref_series_uid_tag=self.ref_series_uid_tag,\n                        ref_sop_uid_tag=self.ref_sop_uid_tag,\n                    )\n                    ref_uid = match_tcia_ref_uid_in_study(ds.StudyInstanceUID, ref_sop_uid)\n                if ref_uid != \"\":\n                    ref_uid_list.append(ref_uid)\n        if not ref_uid_list:\n            warnings.warn(f\"Cannot find the referenced Series Instance UID from series: {series_uid}.\")\n        else:\n           
 download_tcia_series_instance(\n                series_uid=ref_uid_list[0], download_dir=download_dir, output_dir=dcm_dir, check_md5=False\n            )\n        if not os.path.exists(seg_dir):\n            shutil.copytree(seg_first_dir, seg_dir)\n\n    def _generate_data_list(self, dataset_dir: PathLike) -> list[dict]:\n        # the types of the item in data list should be compatible with the dataloader\n        dataset_dir = Path(dataset_dir)\n        datalist = []\n        patient_list = [f.name for f in os.scandir(dataset_dir) if f.is_dir() and f.name != \"raw\"]\n        for patient_id in patient_list:\n            series_list = [f.name for f in os.scandir(os.path.join(dataset_dir, patient_id)) if f.is_dir()]\n            for series_num in series_list:\n                seg_key = self.seg_type.lower()\n                image_path = os.path.join(dataset_dir, patient_id, series_num, \"image\")\n                mask_path = os.path.join(dataset_dir, patient_id, series_num, seg_key)\n\n                if os.path.exists(image_path):\n                    datalist.append({\"image\": image_path, seg_key: mask_path})\n                else:\n                    datalist.append({seg_key: mask_path})\n\n        return self._split_datalist(datalist)\n\n    def _split_datalist(self, datalist: list[dict]) -> list[dict]:\n        if self.section == \"test\":\n            return datalist\n        length = len(datalist)\n        indices = np.arange(length)\n        self.randomize(indices)\n\n        val_length = int(length * self.val_frac)\n        if self.section == \"training\":\n            self.indices = indices[val_length:]\n        else:\n            self.indices = indices[:val_length]\n\n        return [datalist[i] for i in self.indices]\n\n\nclass CrossValidation:\n    \"\"\"\n    Cross validation dataset based on the general dataset which must have `_split_datalist` API.\n\n    Args:\n        dataset_cls: dataset class to be used to create the cross validation 
partitions.\n            It must have `_split_datalist` API.\n        nfolds: number of folds to split the data for cross validation.\n        seed: random seed to randomly shuffle the datalist before splitting into N folds, default is 0.\n        dataset_params: other additional parameters for the dataset_cls base class.\n\n    Example of 5 folds cross validation training::\n\n        cvdataset = CrossValidation(\n            dataset_cls=DecathlonDataset,\n            nfolds=5,\n            seed=12345,\n            root_dir=\"./\",\n            task=\"Task09_Spleen\",\n            section=\"training\",\n            transform=train_transform,\n            download=True,\n        )\n        dataset_fold0_train = cvdataset.get_dataset(folds=[1, 2, 3, 4])\n        dataset_fold0_val = cvdataset.get_dataset(folds=0, transform=val_transform, download=False)\n        # execute training for fold 0 ...\n\n        dataset_fold1_train = cvdataset.get_dataset(folds=[0, 2, 3, 4])\n        dataset_fold1_val = cvdataset.get_dataset(folds=1, transform=val_transform, download=False)\n        # execute training for fold 1 ...\n\n        ...\n\n        dataset_fold4_train = ...\n        # execute training for fold 4 ...\n\n    \"\"\"\n\n    def __init__(self, dataset_cls: object, nfolds: int = 5, seed: int = 0, **dataset_params: Any) -> None:\n        if not hasattr(dataset_cls, \"_split_datalist\"):\n            raise ValueError(\"dataset class must have _split_datalist API.\")\n        self.dataset_cls = dataset_cls\n        self.nfolds = nfolds\n        self.seed = seed\n        self.dataset_params = dataset_params\n\n    def get_dataset(self, folds: Sequence[int] | int, **dataset_params: Any) -> object:\n        \"\"\"\n        Generate dataset based on the specified fold indices in the cross validation group.\n\n        Args:\n            folds: index of folds for training or validation, if a list of values, concatenate the data.\n            dataset_params: other additional 
parameters for the dataset_cls base class, will override\n                the same parameters in `self.dataset_params`.\n\n        \"\"\"\n        nfolds = self.nfolds\n        seed = self.seed\n        dataset_params_ = dict(self.dataset_params)\n        dataset_params_.update(dataset_params)\n\n        class _NsplitsDataset(self.dataset_cls):  # type: ignore\n\n            def _split_datalist(self, datalist: list[dict]) -> list[dict]:\n                data = partition_dataset(data=datalist, num_partitions=nfolds, shuffle=True, seed=seed)\n                return select_cross_validation_folds(partitions=data, folds=folds)\n\n        return _NsplitsDataset(**dataset_params_)\n"
  },
  {
    "path": "monai/apps/deepedit/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/deepedit/interaction.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Sequence\n\nimport numpy as np\nimport torch\n\nfrom monai.data import decollate_batch, list_data_collate\nfrom monai.engines import SupervisedEvaluator, SupervisedTrainer\nfrom monai.engines.utils import IterationEvents\nfrom monai.transforms import Compose\nfrom monai.utils.enums import CommonKeys\n\n\nclass Interaction:\n    \"\"\"\n    Ignite process_function used to introduce interactions (simulation of clicks) for DeepEdit Training/Evaluation.\n\n    More details about this can be found at:\n\n        Diaz-Pinto et al., MONAI Label: A framework for AI-assisted Interactive\n        Labeling of 3D Medical Images. 
(2022) https://arxiv.org/abs/2203.12362\n\n    Args:\n        deepgrow_probability: probability of simulating clicks in an iteration\n        transforms: execute additional transformation during every iteration (before train).\n            Typically, several Tensor based transforms composed by `Compose`.\n        train: True for training mode or False for evaluation mode\n        click_probability_key: key to click/interaction probability\n        label_names: Dict of label names\n        max_interactions: maximum number of interactions per iteration\n    \"\"\"\n\n    def __init__(\n        self,\n        deepgrow_probability: float,\n        transforms: Sequence[Callable] | Callable,\n        train: bool,\n        label_names: None | dict[str, int] = None,\n        click_probability_key: str = \"probability\",\n        max_interactions: int = 1,\n    ) -> None:\n        self.deepgrow_probability = deepgrow_probability\n        self.transforms = Compose(transforms) if not isinstance(transforms, Compose) else transforms\n        self.train = train\n        self.label_names = label_names\n        self.click_probability_key = click_probability_key\n        self.max_interactions = max_interactions\n\n    def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict:\n        if batchdata is None:\n            raise ValueError(\"Must provide batch data for current iteration.\")\n\n        if np.random.choice([True, False], p=[self.deepgrow_probability, 1 - self.deepgrow_probability]):\n            for j in range(self.max_interactions):\n                inputs, _ = engine.prepare_batch(batchdata)\n                inputs = inputs.to(engine.state.device)\n\n                engine.fire_event(IterationEvents.INNER_ITERATION_STARTED)\n                engine.network.eval()\n\n                with torch.no_grad():\n                    if engine.amp:\n                        with torch.autocast(\"cuda\"):\n                       
     predictions = engine.inferer(inputs, engine.network)\n                    else:\n                        predictions = engine.inferer(inputs, engine.network)\n                batchdata.update({CommonKeys.PRED: predictions})\n\n                # decollate/collate batchdata to execute click transforms\n                batchdata_list = decollate_batch(batchdata, detach=True)\n                for i in range(len(batchdata_list)):\n                    batchdata_list[i][self.click_probability_key] = (\n                        (1.0 - ((1.0 / self.max_interactions) * j)) if self.train else 1.0\n                    )\n                    batchdata_list[i] = self.transforms(batchdata_list[i])\n\n                batchdata = list_data_collate(batchdata_list)\n                engine.fire_event(IterationEvents.INNER_ITERATION_COMPLETED)\n        else:\n            # zero out input guidance channels\n            batchdata_list = decollate_batch(batchdata, detach=True)\n            for i in range(1, len(batchdata_list[0][CommonKeys.IMAGE])):\n                batchdata_list[0][CommonKeys.IMAGE][i] *= 0\n            batchdata = list_data_collate(batchdata_list)\n\n        # first item in batch only\n        engine.state.batch = batchdata\n        return engine._iteration(engine, batchdata)  # type: ignore[arg-type]\n"
  },
  {
    "path": "monai/apps/deepedit/transforms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport logging\nimport random\nimport warnings\nfrom collections.abc import Hashable, Mapping, Sequence, Sized\n\nimport numpy as np\nimport torch\n\nfrom monai.config import KeysCollection\nfrom monai.data import MetaTensor\nfrom monai.networks.layers import GaussianFilter\nfrom monai.transforms.transform import MapTransform, Randomizable, Transform\nfrom monai.utils import deprecated, min_version, optional_import\n\nmeasure, _ = optional_import(\"skimage.measure\", \"0.14.2\", min_version)\n\nlogger = logging.getLogger(__name__)\n\ndistance_transform_cdt, _ = optional_import(\"scipy.ndimage\", name=\"distance_transform_cdt\")\n\n\nclass DiscardAddGuidanced(MapTransform):\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        number_intensity_ch: int = 1,\n        probability: float = 1.0,\n        label_names: Sized | None = None,\n        allow_missing_keys: bool = False,\n    ):\n        \"\"\"\n        Discard positive and negative points according to discard probability\n\n        Args:\n            keys: The ``keys`` parameter will be used to get and set the actual data item to transform\n            number_intensity_ch: number of intensity channels\n            probability: probability of discarding clicks\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n\n        
self.number_intensity_ch = number_intensity_ch\n        self.discard_probability = probability\n        self.label_names = label_names or []\n\n    def _apply(self, image):\n        if self.discard_probability >= 1.0 or np.random.choice(\n            [True, False], p=[self.discard_probability, 1 - self.discard_probability]\n        ):\n            signal = np.zeros(\n                (len(self.label_names), image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32\n            )\n            if image.shape[0] == self.number_intensity_ch + len(self.label_names):\n                image[self.number_intensity_ch :, ...] = signal\n            else:\n                image = np.concatenate([image, signal], axis=0)\n        return image\n\n    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:\n        d: dict = dict(data)\n        for key in self.key_iterator(d):\n            if key == \"image\":\n                tmp_image = self._apply(d[key])\n                if isinstance(d[key], MetaTensor):\n                    d[key].array = tmp_image\n                else:\n                    d[key] = tmp_image\n            else:\n                print(\"This transform only applies to the image\")\n        return d\n\n\nclass RemapLabelsToSequentiald(MapTransform):\n    \"\"\"\n    Remap label values from a dataset-specific schema to sequential indices (0, 1, 2, 3, ...).\n\n    This transform takes labels with arbitrary values defined in a label dictionary and remaps them\n    to a sequential range starting from 1 (with background always set to 0). 
This is useful for\n    standardizing labels across different datasets or ensuring labels are in a contiguous range.\n\n    The output label indices are assigned in alphabetical order by label name to ensure\n    deterministic behavior regardless of input dictionary ordering.\n\n    Args:\n        keys: The ``keys`` parameter will be used to get and set the actual data item to transform\n        label_names: Dictionary mapping label names to their current values in the dataset.\n            For example: {\"spleen\": 1, \"liver\": 6, \"background\": 0}\n            Will be remapped to: {\"background\": 0, \"liver\": 1, \"spleen\": 2}\n            (alphabetically sorted, excluding background)\n        allow_missing_keys: If True, missing keys in the data dictionary will not raise an error\n\n    Example:\n        >>> transform = RemapLabelsToSequentiald(\n        ...     keys=\"label\",\n        ...     label_names={\"liver\": 6, \"spleen\": 1, \"background\": 0}\n        ... )\n        >>> # Input label has values [0, 1, 6]\n        >>> # Output label will have values [0, 1, 2] (background=0, liver=1, spleen=2)\n        >>> # And updates d[\"label_names\"] to {\"background\": 0, \"liver\": 1, \"spleen\": 2}\n\n    Note:\n        - Background label (if present) is always mapped to 0\n        - Non-background labels are mapped to sequential indices 1, 2, 3, ... 
in alphabetical order\n        - Undefined labels (not in label_names) will be set to 0 (background)\n        - The transform updates the data dictionary with a new \"label_names\" key containing the remapped values\n    \"\"\"\n\n    def __init__(\n        self, keys: KeysCollection, label_names: dict[str, int] | None = None, allow_missing_keys: bool = False\n    ):\n        super().__init__(keys, allow_missing_keys)\n\n        self.label_names = label_names or {}\n\n    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:\n        d: dict = dict(data)\n        for key in self.key_iterator(d):\n            # Dictionary containing new label numbers\n            new_label_names = {}\n            label = np.zeros(d[key].shape)\n\n            # Sort label names to ensure deterministic ordering (exclude background)\n            sorted_labels = sorted([(k, v) for k, v in self.label_names.items() if k != \"background\"])\n\n            # Always set background to 0 first\n            if \"background\" in self.label_names:\n                new_label_names[\"background\"] = 0\n\n            # Assign sequential indices to sorted non-background labels\n            for idx, (key_label, val_label) in enumerate(sorted_labels, start=1):\n                new_label_names[key_label] = idx\n                label[d[key] == val_label] = idx\n\n            d[\"label_names\"] = new_label_names\n            if isinstance(d[key], MetaTensor):\n                d[key].array = label\n            else:\n                d[key] = label\n        return d\n\n\n@deprecated(since=\"1.6\", removed=\"1.8\", msg_suffix=\"Use `RemapLabelsToSequentiald` instead.\")\nclass NormalizeLabelsInDatasetd(RemapLabelsToSequentiald):\n    \"\"\"\n    .. deprecated:: 1.6.0\n        `NormalizeLabelsInDatasetd` is deprecated and will be removed in version 1.8.0.\n        Use :class:`RemapLabelsToSequentiald` instead.\n\n    This class is maintained for backward compatibility. 
Please use RemapLabelsToSequentiald\n    which better describes the transform's functionality.\n    \"\"\"\n\n\nclass SingleLabelSelectiond(MapTransform):\n\n    def __init__(\n        self, keys: KeysCollection, label_names: Sequence[str] | None = None, allow_missing_keys: bool = False\n    ):\n        \"\"\"\n        Selects one label at a time to train the DeepEdit\n\n        Args:\n            keys: The ``keys`` parameter will be used to get and set the actual data item to transform\n            label_names: all label names\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n\n        self.label_names: Sequence[str] = label_names or []\n        self.all_label_values = {\n            \"spleen\": 1,\n            \"right kidney\": 2,\n            \"left kidney\": 3,\n            \"gallbladder\": 4,\n            \"esophagus\": 5,\n            \"liver\": 6,\n            \"stomach\": 7,\n            \"aorta\": 8,\n            \"inferior vena cava\": 9,\n            \"portal_vein\": 10,\n            \"splenic_vein\": 11,\n            \"pancreas\": 12,\n            \"right adrenal gland\": 13,\n            \"left adrenal gland\": 14,\n        }\n\n    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:\n        d: dict = dict(data)\n        for key in self.key_iterator(d):\n            if key == \"label\":\n                # Taking one label at a time\n                t_label = np.random.choice(self.label_names)\n                d[\"current_label\"] = t_label\n                d[key][d[key] != self.all_label_values[t_label]] = 0.0\n                # Convert label to index values following label_names argument\n                max_label_val = self.label_names.index(t_label) + 1\n                d[key][d[key] > 0] = max_label_val\n                print(f\"Using label {t_label} with number: {d[key].max()}\")\n            else:\n                warnings.warn(\"This transform only applies to the label\")\n        return 
d\n\n\nclass AddGuidanceSignalDeepEditd(MapTransform):\n    \"\"\"\n    Add Guidance signal for input image. Multilabel DeepEdit\n\n    Based on the \"guidance\" points, apply Gaussian to them and add them as new channel for input image.\n\n    Args:\n        guidance: key to store guidance.\n        sigma: standard deviation for Gaussian kernel.\n        number_intensity_ch: channel index.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        guidance: str = \"guidance\",\n        sigma: int = 3,\n        number_intensity_ch: int = 1,\n        allow_missing_keys: bool = False,\n    ):\n        super().__init__(keys, allow_missing_keys)\n        self.guidance = guidance\n        self.sigma = sigma\n        self.number_intensity_ch = number_intensity_ch\n\n    def _get_signal(self, image, guidance):\n        dimensions = 3 if len(image.shape) > 3 else 2\n        guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance\n        guidance = json.loads(guidance) if isinstance(guidance, str) else guidance\n\n        # In inference the user may not provide clicks for some channels/labels\n        if len(guidance):\n            if dimensions == 3:\n                # Assume channel is first and depth is last CHWD\n                signal = np.zeros((1, image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32)\n            else:\n                signal = np.zeros((1, image.shape[-2], image.shape[-1]), dtype=np.float32)\n\n            sshape = signal.shape\n            for point in guidance:  # TO DO: make the guidance a list only - it is currently a list of list\n                if np.any(np.asarray(point) < 0):\n                    continue\n\n                if dimensions == 3:\n                    # Making sure points fall inside the image dimension\n                    p1 = max(0, min(int(point[-3]), sshape[-3] - 1))\n                    p2 = max(0, min(int(point[-2]), sshape[-2] - 1))\n                   
 p3 = max(0, min(int(point[-1]), sshape[-1] - 1))\n                    signal[:, p1, p2, p3] = 1.0\n                else:\n                    p1 = max(0, min(int(point[-2]), sshape[-2] - 1))\n                    p2 = max(0, min(int(point[-1]), sshape[-1] - 1))\n                    signal[:, p1, p2] = 1.0\n\n            # Apply a Gaussian filter to the signal\n            if np.max(signal[0]) > 0:\n                signal_tensor = torch.tensor(signal[0])\n                pt_gaussian = GaussianFilter(len(signal_tensor.shape), sigma=self.sigma)\n                signal_tensor = pt_gaussian(signal_tensor.unsqueeze(0).unsqueeze(0))\n                signal_tensor = signal_tensor.squeeze(0).squeeze(0)\n                signal[0] = signal_tensor.detach().cpu().numpy()\n                signal[0] = (signal[0] - np.min(signal[0])) / (np.max(signal[0]) - np.min(signal[0]))\n            return signal\n        else:\n            if dimensions == 3:\n                signal = np.zeros((1, image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32)\n            else:\n                signal = np.zeros((1, image.shape[-2], image.shape[-1]), dtype=np.float32)\n            return signal\n\n    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:\n        d: dict = dict(data)\n        for key in self.key_iterator(d):\n            if key == \"image\":\n                image = d[key]\n                tmp_image = image[0 : 0 + self.number_intensity_ch, ...]\n                guidance = d[self.guidance]\n                for key_label in guidance.keys():\n                    # Getting signal based on guidance\n                    signal = self._get_signal(image, guidance[key_label])\n                    tmp_image = np.concatenate([tmp_image, signal], axis=0)\n                    if isinstance(d[key], MetaTensor):\n                        d[key].array = tmp_image\n                    else:\n                        d[key] = tmp_image\n                
return d\n            else:\n                print(\"This transform only applies to image key\")\n        return d\n\n\nclass FindAllValidSlicesDeepEditd(MapTransform):\n    \"\"\"\n    Find/List all valid slices in the labels.\n    Label is assumed to be a 4D Volume with shape CHWD, where C=1.\n\n    Args:\n        sids: key to store slices indices having valid label map.\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, sids: Hashable = \"sids\", allow_missing_keys: bool = False):\n        super().__init__(keys, allow_missing_keys)\n        self.sids = sids\n\n    def _apply(self, label, d):\n        sids = {}\n        for key_label in d[\"label_names\"].keys():\n            l_ids = []\n            for sid in range(label.shape[-1]):  # Assume channel is first and depth is last CHWD\n                if d[\"label_names\"][key_label] in label[0][..., sid]:\n                    l_ids.append(sid)\n            sids[key_label] = l_ids\n        return sids\n\n    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:\n        d: dict = dict(data)\n        for key in self.key_iterator(d):\n            if key == \"label\":\n                label = d[key]\n                if label.shape[0] != 1:\n                    raise ValueError(\"Only supports single channel labels!\")\n\n                if len(label.shape) != 4:  # only for 3D\n                    raise ValueError(\"Only supports label with shape CHWD!\")\n\n                sids = self._apply(label, d)\n                if sids is not None and len(sids.keys()):\n                    d[self.sids] = sids\n                return d\n            else:\n                print(\"This transform only applies to label key\")\n        return d\n\n\nclass AddInitialSeedPointDeepEditd(Randomizable, MapTransform):\n    \"\"\"\n    Add random guidance as initial seed point for a given label.\n\n    Note that the label is of size (C, D, H, W) or (C, H, W)\n\n    The guidance is of size (2, N, # 
of dims) where N is number of guidance added.\n    # of dims = 4 when C, D, H, W; # of dims = 3 when (C, H, W)\n\n    Args:\n        guidance: key to store guidance.\n        sids: key that represents lists of valid slice indices for the given label.\n        sid: key that represents the slice to add initial seed point.  If not present, random sid will be chosen.\n        connected_regions: maximum connected regions to use for adding initial points.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        guidance: str = \"guidance\",\n        sids: str = \"sids\",\n        sid: str = \"sid\",\n        connected_regions: int = 5,\n        allow_missing_keys: bool = False,\n    ):\n        super().__init__(keys, allow_missing_keys)\n        self.sids_key = sids\n        self.sid_key = sid\n        self.sid: dict[str, int] = dict()\n        self.guidance = guidance\n        self.connected_regions = connected_regions\n\n    def _apply(self, label, sid, key_label):\n        dimensions = 3 if len(label.shape) > 3 else 2\n        self.default_guidance = [-1] * (dimensions + 1)\n\n        dims = dimensions\n        if sid is not None and dimensions == 3:\n            dims = 2\n            label = label[0][..., sid][np.newaxis]  # Assume channel is first and depth is last CHWD\n\n        # THERE MAY BE MULTIPLE BLOBS FOR SINGLE LABEL IN THE SELECTED SLICE\n        label = (label > 0.5).astype(np.float32)\n        # measure.label: Label connected regions of an integer array - Two pixels are connected\n        # when they are neighbors and have the same value\n        blobs_labels = measure.label(label.astype(int), background=0) if dims == 2 else label\n        if np.max(blobs_labels) <= 0:\n            raise AssertionError(f\"SLICES NOT FOUND FOR LABEL: {key_label}\")\n\n        pos_guidance = []\n        for ridx in range(1, 2 if dims == 3 else self.connected_regions + 1):\n            if dims == 2:\n                label = (blobs_labels == 
ridx).astype(np.float32)\n                if np.sum(label) == 0:\n                    pos_guidance.append(self.default_guidance)\n                    continue\n\n            # The distance transform provides a metric or measure of the separation of points in the image.\n            # This function calculates the distance between each pixel that is set to off (0) and\n            # the nearest nonzero pixel for binary images - http://matlab.izmiran.ru/help/toolbox/images/morph14.html\n            distance = distance_transform_cdt(label).flatten()\n            probability = np.exp(distance) - 1.0\n\n            idx = np.where(label.flatten() > 0)[0]\n            seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx]))\n            dst = distance[seed]\n\n            g = np.asarray(np.unravel_index(seed, label.shape)).transpose().tolist()[0]\n            g[0] = dst[0]  # for debug\n            if dimensions == 2 or dims == 3:\n                pos_guidance.append(g)\n            else:\n                # Clicks are created using this convention Channel Height Width Depth (CHWD)\n                pos_guidance.append([g[0], g[-2], g[-1], sid])  # Assume channel is first and depth is last CHWD\n\n        return np.asarray([pos_guidance])\n\n    def _randomize(self, d, key_label):\n        sids = d.get(self.sids_key).get(key_label) if d.get(self.sids_key) is not None else None\n        sid = d.get(self.sid_key).get(key_label) if d.get(self.sid_key) is not None else None\n        if sids is not None and sids:\n            if sid is None or sid not in sids:\n                sid = self.R.choice(sids, replace=False)\n        else:\n            logger.info(f\"Not slice IDs for label: {key_label}\")\n            sid = None\n        self.sid[key_label] = sid\n\n    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:\n        d: dict = dict(data)\n        for key in self.key_iterator(d):\n            if key == \"label\":\n   
             label_guidances = {}\n                for key_label in d[\"sids\"].keys():\n                    # Randomize: Select a random slice\n                    self._randomize(d, key_label)\n                    # Generate guidance base on selected slice\n                    tmp_label = np.copy(d[key])\n                    # Taking one label to create the guidance\n                    if key_label != \"background\":\n                        tmp_label[tmp_label != float(d[\"label_names\"][key_label])] = 0\n                    else:\n                        tmp_label[tmp_label != float(d[\"label_names\"][key_label])] = 1\n                        tmp_label = 1 - tmp_label\n                    label_guidances[key_label] = json.dumps(\n                        self._apply(tmp_label, self.sid.get(key_label), key_label).astype(int).tolist()\n                    )\n                d[self.guidance] = label_guidances\n                return d\n            else:\n                print(\"This transform only applies to label key\")\n        return d\n\n\nclass FindDiscrepancyRegionsDeepEditd(MapTransform):\n    \"\"\"\n    Find discrepancy between prediction and actual during click interactions during training.\n\n    Args:\n        pred: key to prediction source.\n        discrepancy: key to store discrepancies found between label and prediction.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        pred: str = \"pred\",\n        discrepancy: str = \"discrepancy\",\n        allow_missing_keys: bool = False,\n    ):\n        super().__init__(keys, allow_missing_keys)\n        self.pred = pred\n        self.discrepancy = discrepancy\n\n    @staticmethod\n    def disparity(label, pred):\n        disparity = label - pred\n        # Negative ONES mean predicted label is not part of the ground truth\n        # Positive ONES mean predicted label missed that region of the ground truth\n        pos_disparity = (disparity > 0).astype(np.float32)\n    
    neg_disparity = (disparity < 0).astype(np.float32)\n        return [pos_disparity, neg_disparity]\n\n    def _apply(self, label, pred):\n        return self.disparity(label, pred)\n\n    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:\n        d: dict = dict(data)\n        for key in self.key_iterator(d):\n            if key == \"label\":\n                all_discrepancies = {}\n                for _, (key_label, val_label) in enumerate(d[\"label_names\"].items()):\n                    if key_label != \"background\":\n                        # Taking single label\n                        label = np.copy(d[key])\n                        label[label != val_label] = 0\n                        # Label should be represented in 1\n                        label = (label > 0.5).astype(np.float32)\n                        # Taking single prediction\n                        pred = np.copy(d[self.pred])\n                        pred[pred != val_label] = 0\n                        # Prediction should be represented in one\n                        pred = (pred > 0.5).astype(np.float32)\n                    else:\n                        # Taking single label\n                        label = np.copy(d[key])\n                        label[label != val_label] = 1\n                        label = 1 - label\n                        # Label should be represented in 1\n                        label = (label > 0.5).astype(np.float32)\n                        # Taking single prediction\n                        pred = np.copy(d[self.pred])\n                        pred[pred != val_label] = 1\n                        pred = 1 - pred\n                        # Prediction should be represented in one\n                        pred = (pred > 0.5).astype(np.float32)\n                    all_discrepancies[key_label] = self._apply(label, pred)\n                d[self.discrepancy] = all_discrepancies\n                return d\n            else:\n         
       print(\"This transform only applies to 'label' key\")\n        return d\n\n\nclass AddRandomGuidanceDeepEditd(Randomizable, MapTransform):\n    \"\"\"\n    Add random guidance based on discrepancies that were found between label and prediction.\n\n    Args:\n        guidance: key to guidance source, shape (2, N, # of dim)\n        discrepancy: key to discrepancy map between label and prediction shape (2, C, H, W, D) or (2, C, H, W)\n        probability: key to click/interaction probability, shape (1)\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        guidance: str = \"guidance\",\n        discrepancy: str = \"discrepancy\",\n        probability: str = \"probability\",\n        allow_missing_keys: bool = False,\n    ):\n        super().__init__(keys, allow_missing_keys)\n        self.guidance_key = guidance\n        self.discrepancy = discrepancy\n        self.probability = probability\n        self._will_interact = None\n        self.is_pos: bool | None = None\n        self.is_other: bool | None = None\n        self.default_guidance = None\n        self.guidance: dict[str, list[list[int]]] = {}\n\n    def randomize(self, data=None):\n        probability = data[self.probability]\n        self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])\n\n    def find_guidance(self, discrepancy):\n        distance = distance_transform_cdt(discrepancy).flatten()\n        probability = np.exp(distance.flatten()) - 1.0\n        idx = np.where(discrepancy.flatten() > 0)[0]\n\n        if np.sum(discrepancy > 0) > 0:\n            seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx]))\n            dst = distance[seed]\n\n            g = np.asarray(np.unravel_index(seed, discrepancy.shape)).transpose().tolist()[0]\n            g[0] = dst[0]\n            return g\n        return None\n\n    def add_guidance(self, guidance, discrepancy, label_names, labels):\n        # Positive clicks 
of the segment in the iteration\n        pos_discr = discrepancy[0]  # idx 0 is positive discrepancy and idx 1 is negative discrepancy\n\n        # Check the areas that belong to other segments\n        other_discrepancy_areas = {}\n        for _, (key_label, val_label) in enumerate(label_names.items()):\n            if key_label != \"background\":\n                tmp_label = np.copy(labels)\n                tmp_label[tmp_label != val_label] = 0\n                tmp_label = (tmp_label > 0.5).astype(np.float32)\n                other_discrepancy_areas[key_label] = np.sum(discrepancy[1] * tmp_label)\n            else:\n                tmp_label = np.copy(labels)\n                tmp_label[tmp_label != val_label] = 1\n                tmp_label = 1 - tmp_label\n                other_discrepancy_areas[key_label] = np.sum(discrepancy[1] * tmp_label)\n\n        # Add guidance to the current key label\n        if np.sum(pos_discr) > 0:\n            guidance.append(self.find_guidance(pos_discr))\n            self.is_pos = True\n\n        # Add guidance to the other areas\n        for key_label in label_names.keys():\n            # Areas that cover more than 50 voxels\n            if other_discrepancy_areas[key_label] > 50:\n                self.is_other = True\n                if key_label != \"background\":\n                    tmp_label = np.copy(labels)\n                    tmp_label[tmp_label != label_names[key_label]] = 0\n                    tmp_label = (tmp_label > 0.5).astype(np.float32)\n                    self.guidance[key_label].append(self.find_guidance(discrepancy[1] * tmp_label))\n                else:\n                    tmp_label = np.copy(labels)\n                    tmp_label[tmp_label != label_names[key_label]] = 1\n                    tmp_label = 1 - tmp_label\n                    self.guidance[key_label].append(self.find_guidance(discrepancy[1] * tmp_label))\n\n    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, 
np.ndarray]:\n        d: dict = dict(data)\n        guidance = d[self.guidance_key]\n        discrepancy = d[self.discrepancy]\n        self.randomize(data)\n        if self._will_interact:\n            # Convert all guidance to lists so new guidance can be easily appended\n            for key_label in d[\"label_names\"].keys():\n                tmp_gui = guidance[key_label]\n                tmp_gui = tmp_gui.tolist() if isinstance(tmp_gui, np.ndarray) else tmp_gui\n                tmp_gui = json.loads(tmp_gui) if isinstance(tmp_gui, str) else tmp_gui\n                self.guidance[key_label] = [j for j in tmp_gui if -1 not in j]\n\n            # Add guidance according to discrepancy\n            for key_label in d[\"label_names\"].keys():\n                # Add guidance based on discrepancy\n                self.add_guidance(self.guidance[key_label], discrepancy[key_label], d[\"label_names\"], d[\"label\"])\n\n            # Checking the number of clicks\n            num_clicks = random.randint(1, 10)\n            counter = 0\n            keep_guidance = []\n            while True:\n                aux_label = random.choice(list(d[\"label_names\"].keys()))\n                if aux_label in keep_guidance:\n                    pass\n                else:\n                    keep_guidance.append(aux_label)\n                    counter = counter + len(self.guidance[aux_label])\n                    # If collected clicks is bigger than max clicks, discard the others\n                    if counter >= num_clicks:\n                        for key_label in d[\"label_names\"].keys():\n                            if key_label not in keep_guidance:\n                                self.guidance[key_label] = []\n                        logger.info(f\"Number of simulated clicks: {counter}\")\n                        break\n\n                # Breaking once all labels are covered\n                if len(keep_guidance) == len(d[\"label_names\"].keys()):\n                    
logger.info(f\"Number of simulated clicks: {counter}\")\n                    break\n        d[self.guidance_key] = self.guidance  # Update the guidance\n        return d\n\n\nclass AddGuidanceFromPointsDeepEditd(Transform):\n    \"\"\"\n    Add guidance based on user clicks. ONLY WORKS FOR 3D\n\n    We assume the input is loaded by LoadImaged and has the shape of (H, W, D) originally.\n    Clicks always specify the coordinates in (H, W, D)\n\n    Args:\n        ref_image: key to reference image to fetch current and original image details.\n        guidance: output key to store guidance.\n        meta_keys: explicitly indicate the key of the metadata dictionary of `ref_image`.\n            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n            the metadata is a dictionary object which contains: filename, original_shape, etc.\n            if None, will try to construct meta_keys by `{ref_image}_{meta_key_postfix}`.\n        meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to fetch the metadata according\n            to the key data, default is `meta_dict`, the metadata is a dictionary object.\n            For example, to handle key `image`,  read/write affine matrices from the\n            metadata `image_meta_dict` dictionary's `affine` field.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        ref_image: str,\n        guidance: str = \"guidance\",\n        label_names: dict | None = None,\n        meta_keys: str | None = None,\n        meta_key_postfix: str = \"meta_dict\",\n    ):\n        self.ref_image = ref_image\n        self.guidance = guidance\n        self.label_names = label_names or {}\n        self.meta_keys = meta_keys\n        self.meta_key_postfix = meta_key_postfix\n\n    @staticmethod\n    def _apply(clicks, factor):\n        if len(clicks):\n            guidance = np.multiply(clicks, factor).astype(int).tolist()\n            return guidance\n        else:\n            
return []\n\n    def __call__(self, data):\n        d = dict(data)\n        meta_dict_key = self.meta_keys or f\"{self.ref_image}_{self.meta_key_postfix}\"\n        # extract affine matrix from metadata\n        if isinstance(d[self.ref_image], MetaTensor):\n            meta_dict = d[self.ref_image].meta\n        elif meta_dict_key in d:\n            meta_dict = d[meta_dict_key]\n        else:\n            raise ValueError(\n                f\"{meta_dict_key} is not found. Please check whether it is the correct the image meta key.\"\n            )\n\n        if \"spatial_shape\" not in meta_dict:\n            raise RuntimeError('Missing \"spatial_shape\" in meta_dict!')\n\n        # Assume channel is first and depth is last CHWD\n        original_shape = meta_dict[\"spatial_shape\"]\n        current_shape = list(d[self.ref_image].shape)[1:]\n\n        # in here we assume the depth dimension is in the last dimension of \"original_shape\" and \"current_shape\"\n        factor = np.array(current_shape) / original_shape\n\n        # Creating guidance for all clicks\n        all_guidances = {}\n        for key_label in self.label_names.keys():\n            clicks = d.get(key_label, [])\n            clicks = list(np.array(clicks).astype(int))\n            all_guidances[key_label] = self._apply(clicks, factor)\n        d[self.guidance] = all_guidances\n        return d\n\n\nclass ResizeGuidanceMultipleLabelDeepEditd(Transform):\n    \"\"\"\n    Resize the guidance based on cropped vs resized image.\n\n    \"\"\"\n\n    def __init__(self, guidance: str, ref_image: str) -> None:\n        self.guidance = guidance\n        self.ref_image = ref_image\n\n    def __call__(self, data):\n        d = dict(data)\n        # Assume channel is first and depth is last CHWD\n        current_shape = d[self.ref_image].shape[1:]\n\n        meta_dict_key = \"image_meta_dict\"\n        # extract affine matrix from metadata\n        if isinstance(d[self.ref_image], MetaTensor):\n            
meta_dict = d[self.ref_image].meta\n        elif meta_dict_key in d:\n            meta_dict = d[meta_dict_key]\n        else:\n            raise ValueError(\n                f\"{meta_dict_key} is not found. Please check whether it is the correct the image meta key.\"\n            )\n\n        original_shape = meta_dict[\"spatial_shape\"]\n\n        factor = np.divide(current_shape, original_shape)\n        all_guidances = {}\n        for key_label in d[self.guidance].keys():\n            guidance = (\n                np.multiply(d[self.guidance][key_label], factor).astype(int).tolist()\n                if len(d[self.guidance][key_label])\n                else []\n            )\n            all_guidances[key_label] = guidance\n\n        d[self.guidance] = all_guidances\n        return d\n\n\nclass SplitPredsLabeld(MapTransform):\n    \"\"\"\n    Split preds and labels for individual evaluation\n\n    \"\"\"\n\n    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:\n        d: dict = dict(data)\n        for key in self.key_iterator(d):\n            if key == \"pred\":\n                for idx, (key_label, _) in enumerate(d[\"label_names\"].items()):\n                    if key_label != \"background\":\n                        d[f\"pred_{key_label}\"] = d[key][idx + 1, ...][None]\n                        d[f\"label_{key_label}\"] = d[\"label\"][idx + 1, ...][None]\n            elif key != \"pred\":\n                logger.info(\"This is only for pred key\")\n        return d\n\n\nclass AddInitialSeedPointMissingLabelsd(Randomizable, MapTransform):\n    \"\"\"\n    Add random guidance as initial seed point for a given label.\n    Note that the label is of size (C, D, H, W) or (C, H, W)\n    The guidance is of size (2, N, # of dims) where N is number of guidance added.\n    # of dims = 4 when C, D, H, W; # of dims = 3 when (C, H, W)\n    Args:\n        guidance: key to store guidance.\n        sids: key that represents lists of valid 
slice indices for the given label.\n        sid: key that represents the slice to add initial seed point.  If not present, random sid will be chosen.\n        connected_regions: maximum connected regions to use for adding initial points.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        guidance: str = \"guidance\",\n        sids: str = \"sids\",\n        sid: str = \"sid\",\n        connected_regions: int = 5,\n        allow_missing_keys: bool = False,\n    ):\n        super().__init__(keys, allow_missing_keys)\n        self.sids_key = sids\n        self.sid_key = sid\n        self.sid: dict[str, int] = dict()\n        self.guidance = guidance\n        self.connected_regions = connected_regions\n\n    def _apply(self, label, sid):\n        dimensions = 3 if len(label.shape) > 3 else 2\n        self.default_guidance = [-1] * (dimensions + 1)\n\n        dims = dimensions\n        if sid is not None and dimensions == 3:\n            dims = 2\n            label = label[0][..., sid][np.newaxis]  # Assume channel is first and depth is last CHWD\n\n        # THERE MAY BE MULTIPLE BLOBS FOR SINGLE LABEL IN THE SELECTED SLICE\n        label = (label > 0.5).astype(np.float32)\n        # measure.label: Label connected regions of an integer array - Two pixels are connected\n        # when they are neighbors and have the same value\n        blobs_labels = measure.label(label.astype(int), background=0) if dims == 2 else label\n\n        label_guidance = []\n        # If there are is presence of that label in this slice\n        if np.max(blobs_labels) <= 0:\n            label_guidance.append(self.default_guidance)\n        else:\n            for ridx in range(1, 2 if dims == 3 else self.connected_regions + 1):\n                if dims == 2:\n                    label = (blobs_labels == ridx).astype(np.float32)\n                    if np.sum(label) == 0:\n                        label_guidance.append(self.default_guidance)\n                  
      continue\n\n                # The distance transform provides a metric or measure of the separation of points in the image.\n                # This function calculates the distance between each pixel that is set to off (0) and\n                # the nearest nonzero pixel for binary images\n                # http://matlab.izmiran.ru/help/toolbox/images/morph14.html\n                distance = distance_transform_cdt(label).flatten()\n                probability = np.exp(distance) - 1.0\n\n                idx = np.where(label.flatten() > 0)[0]\n                seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx]))\n                dst = distance[seed]\n\n                g = np.asarray(np.unravel_index(seed, label.shape)).transpose().tolist()[0]\n                g[0] = dst[0]  # for debug\n                if dimensions == 2 or dims == 3:\n                    label_guidance.append(g)\n                else:\n                    # Clicks are created using this convention Channel Height Width Depth (CHWD)\n                    label_guidance.append([g[0], g[-2], g[-1], sid])  # Assume channel is first and depth is last CHWD\n\n        return np.asarray(label_guidance)\n\n    def _randomize(self, d, key_label):\n        sids = d.get(self.sids_key).get(key_label) if d.get(self.sids_key) is not None else None\n        sid = d.get(self.sid_key).get(key_label) if d.get(self.sid_key) is not None else None\n        if sids is not None and sids:\n            if sid is None or sid not in sids:\n                sid = self.R.choice(sids, replace=False)\n        else:\n            logger.info(f\"Not slice IDs for label: {key_label}\")\n            sid = None\n        self.sid[key_label] = sid\n\n    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:\n        d: dict = dict(data)\n        for key in self.key_iterator(d):\n            if key == \"label\":\n                label_guidances = {}\n                for 
key_label in d[\"sids\"].keys():\n                    # Randomize: Select a random slice\n                    self._randomize(d, key_label)\n                    # Generate guidance base on selected slice\n                    tmp_label = np.copy(d[key])\n                    # Taking one label to create the guidance\n                    if key_label != \"background\":\n                        tmp_label[tmp_label != float(d[\"label_names\"][key_label])] = 0\n                    else:\n                        tmp_label[tmp_label != float(d[\"label_names\"][key_label])] = 1\n                        tmp_label = 1 - tmp_label\n                    label_guidances[key_label] = json.dumps(\n                        self._apply(tmp_label, self.sid.get(key_label)).astype(int).tolist()\n                    )\n                d[self.guidance] = label_guidances\n                return d\n            else:\n                print(\"This transform only applies to label key\")\n        return d\n\n\nclass FindAllValidSlicesMissingLabelsd(MapTransform):\n    \"\"\"\n    Find/List all valid slices in the labels.\n    Label is assumed to be a 4D Volume with shape CHWD, where C=1.\n    Args:\n        sids: key to store slices indices having valid label map.\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, sids: Hashable = \"sids\", allow_missing_keys: bool = False):\n        super().__init__(keys, allow_missing_keys)\n        self.sids = sids\n\n    def _apply(self, label, d):\n        sids = {}\n        for key_label in d[\"label_names\"].keys():\n            l_ids = []\n            for sid in range(label.shape[-1]):  # Assume channel is first and depth is last CHWD\n                if d[\"label_names\"][key_label] in label[0][..., sid]:\n                    l_ids.append(sid)\n            # If there are not slices with the label\n            if l_ids == []:\n                l_ids = [-1] * 10\n            sids[key_label] = l_ids\n        return sids\n\n    def __call__(self, 
data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:\n        d: dict = dict(data)\n        for key in self.key_iterator(d):\n            if key == \"label\":\n                label = d[key]\n                if label.shape[0] != 1:\n                    raise ValueError(\"Only supports single channel labels!\")\n\n                if len(label.shape) != 4:  # only for 3D\n                    raise ValueError(\"Only supports label with shape CHWD!\")\n\n                sids = self._apply(label, d)\n                if sids is not None and len(sids.keys()):\n                    d[self.sids] = sids\n                return d\n            else:\n                print(\"This transform only applies to label key\")\n        return d\n"
  },
  {
    "path": "monai/apps/deepgrow/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/deepgrow/dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nfrom collections.abc import Sequence\n\nimport numpy as np\n\nfrom monai.config import PathLike\nfrom monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, Orientationd, Spacingd, SqueezeDimd, Transform\nfrom monai.utils import GridSampleMode\n\n\ndef create_dataset(\n    datalist: list[dict],\n    output_dir: str,\n    dimension: int,\n    pixdim: Sequence[float] | float,\n    image_key: str = \"image\",\n    label_key: str = \"label\",\n    base_dir: PathLike | None = None,\n    limit: int = 0,\n    relative_path: bool = False,\n    transforms: Transform | None = None,\n) -> list[dict]:\n    \"\"\"\n    Utility to pre-process and create dataset list for Deepgrow training over on existing one.\n    The input data list is normally a list of images and labels (3D volume) that needs pre-processing\n    for Deepgrow training pipeline.\n\n    Args:\n        datalist: A list of data dictionary. Each entry should at least contain 'image_key': <image filename>.\n            For example, typical input data can be a list of dictionaries::\n\n                [{'image': <image filename>, 'label': <label filename>}]\n\n        output_dir: target directory to store the training data for Deepgrow Training\n        pixdim: output voxel spacing.\n        dimension: dimension for Deepgrow training.  
It can be 2 or 3.\n        image_key: image key in input datalist. Defaults to 'image'.\n        label_key: label key in input datalist. Defaults to 'label'.\n        base_dir: base directory in case related path is used for the keys in datalist.  Defaults to None.\n        limit: limit number of inputs for pre-processing.  Defaults to 0 (no limit).\n        relative_path: output keys values should be based on relative path.  Defaults to False.\n        transforms: explicit transforms to execute operations on input data.\n\n    Raises:\n        ValueError: When ``dimension`` is not one of [2, 3]\n        ValueError: When ``datalist`` is Empty\n\n    Returns:\n        A new datalist that contains path to the images/labels after pre-processing.\n\n    Example::\n\n        datalist = create_dataset(\n            datalist=[{'image': 'img1.nii', 'label': 'label1.nii'}],\n            base_dir=None,\n            output_dir=output_2d,\n            dimension=2,\n            image_key='image',\n            label_key='label',\n            pixdim=(1.0, 1.0),\n            limit=0,\n            relative_path=True\n        )\n\n        print(datalist[0][\"image\"], datalist[0][\"label\"])\n    \"\"\"\n\n    if dimension not in [2, 3]:\n        raise ValueError(\"Dimension can be only 2 or 3 as Deepgrow supports only 2D/3D Training\")\n\n    if not len(datalist):\n        raise ValueError(\"Input datalist is empty\")\n\n    transforms = _default_transforms(image_key, label_key, pixdim) if transforms is None else transforms\n    new_datalist = []\n    for idx, item in enumerate(datalist):\n        if limit and idx >= limit:\n            break\n\n        image = item[image_key]\n        label = item.get(label_key, None)\n        if base_dir:\n            image = os.path.join(base_dir, image)\n            label = os.path.join(base_dir, label) if label else None\n\n        image = os.path.abspath(image)\n        label = os.path.abspath(label) if label else None\n\n        
logging.info(f\"Image: {image}; Label: {label if label else None}\")\n        data = transforms({image_key: image, label_key: label})\n\n        vol_image = data[image_key]\n        vol_label = data.get(label_key)\n        logging.info(f\"Image (transform): {vol_image.shape}; Label: {None if vol_label is None else vol_label.shape}\")\n\n        vol_image = np.moveaxis(vol_image, -1, 0)\n        if vol_label is not None:\n            vol_label = np.moveaxis(vol_label, -1, 0)\n        logging.info(f\"Image (final): {vol_image.shape}; Label: {None if vol_label is None else vol_label.shape}\")\n\n        if dimension == 2:\n            data = _save_data_2d(\n                vol_idx=idx,\n                vol_image=vol_image,\n                vol_label=vol_label,\n                dataset_dir=output_dir,\n                relative_path=relative_path,\n            )\n        else:\n            data = _save_data_3d(\n                vol_idx=idx,\n                vol_image=vol_image,\n                vol_label=vol_label,\n                dataset_dir=output_dir,\n                relative_path=relative_path,\n            )\n        new_datalist.extend(data)\n    return new_datalist\n\n\ndef _default_transforms(image_key, label_key, pixdim):\n    keys = [image_key] if label_key is None else [image_key, label_key]\n    mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST] if len(keys) == 2 else [GridSampleMode.BILINEAR]\n    return Compose(\n        [\n            LoadImaged(keys=keys),\n            EnsureChannelFirstd(keys=keys),\n            Orientationd(keys=keys, axcodes=\"RAS\"),\n            Spacingd(keys=keys, pixdim=pixdim, mode=mode),\n            SqueezeDimd(keys=keys),\n        ]\n    )\n\n\ndef _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):\n    data_list: list[dict[str, str | int]] = []\n\n    image_count = 0\n    label_count = 0\n    unique_labels_count = 0\n    for sid in range(vol_image.shape[0]):\n        image = vol_image[sid, 
...]\n        label = vol_label[sid, ...] if vol_label is not None else None\n\n        if vol_label is not None and np.sum(label) == 0:\n            continue\n\n        image_file_prefix = f\"vol_idx_{vol_idx:0>4d}_slice_{sid:0>3d}\"\n        image_file = os.path.join(dataset_dir, \"images\", image_file_prefix)\n        image_file += \".npy\"\n\n        os.makedirs(os.path.join(dataset_dir, \"images\"), exist_ok=True)\n        np.save(image_file, image)\n        image_count += 1\n\n        # Test Data\n        if vol_label is None:\n            data_list.append(\n                {\"image\": image_file.replace(dataset_dir + os.pathsep, \"\") if relative_path else image_file}\n            )\n            continue\n\n        # For all Labels\n        unique_labels = np.unique(label.flatten())\n        unique_labels = unique_labels[unique_labels != 0]\n        unique_labels_count = max(unique_labels_count, len(unique_labels))\n\n        for idx in unique_labels:\n            label_file_prefix = f\"{image_file_prefix}_region_{int(idx):0>2d}\"\n            label_file = os.path.join(dataset_dir, \"labels\", label_file_prefix)\n            label_file += \".npy\"\n\n            os.makedirs(os.path.join(dataset_dir, \"labels\"), exist_ok=True)\n            curr_label = (label == idx).astype(np.float32)\n            np.save(label_file, curr_label)\n\n            label_count += 1\n            data_list.append(\n                {\n                    \"image\": image_file.replace(dataset_dir + os.pathsep, \"\") if relative_path else image_file,\n                    \"label\": label_file.replace(dataset_dir + os.pathsep, \"\") if relative_path else label_file,\n                    \"region\": int(idx),\n                }\n            )\n\n    if unique_labels_count >= 20:\n        logging.warning(f\"Unique labels {unique_labels_count} exceeds 20. 
Please check if this is correct.\")\n\n    logging.info(\n        f\"{vol_idx} => Image Shape: {vol_image.shape} => {image_count};\"\n        f\" Label Shape: {vol_label.shape if vol_label is not None else None} => {label_count};\"\n        f\" Unique Labels: {unique_labels_count}\"\n    )\n    return data_list\n\n\ndef _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):\n    data_list: list[dict[str, str | int]] = []\n\n    image_count = 0\n    label_count = 0\n    unique_labels_count = 0\n\n    image_file_prefix = f\"vol_idx_{vol_idx:0>4d}\"\n    image_file = os.path.join(dataset_dir, \"images\", image_file_prefix)\n    image_file += \".npy\"\n\n    os.makedirs(os.path.join(dataset_dir, \"images\"), exist_ok=True)\n    np.save(image_file, vol_image)\n    image_count += 1\n\n    # Test Data\n    if vol_label is None:\n        data_list.append({\"image\": image_file.replace(dataset_dir + os.pathsep, \"\") if relative_path else image_file})\n    else:\n        # For all Labels\n        unique_labels = np.unique(vol_label.flatten())\n        unique_labels = unique_labels[unique_labels != 0]\n        unique_labels_count = max(unique_labels_count, len(unique_labels))\n\n        for idx in unique_labels:\n            label_file_prefix = f\"{image_file_prefix}_region_{int(idx):0>2d}\"\n            label_file = os.path.join(dataset_dir, \"labels\", label_file_prefix)\n            label_file += \".npy\"\n\n            curr_label = (vol_label == idx).astype(np.float32)\n            os.makedirs(os.path.join(dataset_dir, \"labels\"), exist_ok=True)\n            np.save(label_file, curr_label)\n\n            label_count += 1\n            data_list.append(\n                {\n                    \"image\": image_file.replace(dataset_dir + os.pathsep, \"\") if relative_path else image_file,\n                    \"label\": label_file.replace(dataset_dir + os.pathsep, \"\") if relative_path else label_file,\n                    \"region\": int(idx),\n         
       }\n            )\n\n    if unique_labels_count >= 20:\n        logging.warning(f\"Unique labels {unique_labels_count} exceeds 20. Please check if this is correct.\")\n\n    logging.info(\n        f\"{vol_idx} => Image Shape: {vol_image.shape} => {image_count};\"\n        f\" Label Shape: {vol_label.shape if vol_label is not None else None} => {label_count};\"\n        f\" Unique Labels: {unique_labels_count}\"\n    )\n    return data_list\n"
  },
  {
    "path": "monai/apps/deepgrow/interaction.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Sequence\n\nimport torch\n\nfrom monai.data import decollate_batch, list_data_collate\nfrom monai.engines import SupervisedEvaluator, SupervisedTrainer\nfrom monai.engines.utils import IterationEvents\nfrom monai.transforms import Compose\nfrom monai.utils.enums import CommonKeys\n\n\nclass Interaction:\n    \"\"\"\n    Ignite process_function used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation.\n    For more details please refer to: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.\n    This implementation is based on:\n\n        Sakinis et al., Interactive segmentation of medical images through\n        fully convolutional neural networks. 
(2019) https://arxiv.org/abs/1903.08205\n\n    Args:\n        transforms: execute additional transformation during every iteration (before train).\n            Typically, several Tensor based transforms composed by `Compose`.\n        max_interactions: maximum number of interactions per iteration\n        train: training or evaluation\n        key_probability: field name to fill probability for every interaction\n    \"\"\"\n\n    def __init__(\n        self,\n        transforms: Sequence[Callable] | Callable,\n        max_interactions: int,\n        train: bool,\n        key_probability: str = \"probability\",\n    ) -> None:\n        if not isinstance(transforms, Compose):\n            transforms = Compose(transforms)\n\n        self.transforms: Compose = transforms\n        self.max_interactions = max_interactions\n        self.train = train\n        self.key_probability = key_probability\n\n    def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict:\n        if batchdata is None:\n            raise ValueError(\"Must provide batch data for current iteration.\")\n\n        for j in range(self.max_interactions):\n            inputs, _ = engine.prepare_batch(batchdata)\n            inputs = inputs.to(engine.state.device)\n\n            engine.fire_event(IterationEvents.INNER_ITERATION_STARTED)\n\n            engine.network.eval()\n            with torch.no_grad():\n                if engine.amp:\n                    with torch.autocast(\"cuda\"):\n                        predictions = engine.inferer(inputs, engine.network)\n                else:\n                    predictions = engine.inferer(inputs, engine.network)\n\n            engine.fire_event(IterationEvents.INNER_ITERATION_COMPLETED)\n\n            batchdata.update({CommonKeys.PRED: predictions})\n\n            # decollate batch data to execute click transforms\n            batchdata_list = decollate_batch(batchdata, detach=True)\n            for i 
in range(len(batchdata_list)):\n                batchdata_list[i][self.key_probability] = (\n                    (1.0 - ((1.0 / self.max_interactions) * j)) if self.train else 1.0\n                )\n                batchdata_list[i] = self.transforms(batchdata_list[i])\n\n            # collate list into a batch for next round interaction\n            batchdata = list_data_collate(batchdata_list)\n\n        return engine._iteration(engine, batchdata)  # type: ignore[arg-type]\n"
  },
  {
    "path": "monai/apps/deepgrow/transforms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nfrom collections.abc import Callable, Hashable, Iterable, Sequence\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.config import IndexSelection, KeysCollection, NdarrayOrTensor\nfrom monai.networks.layers import GaussianFilter\nfrom monai.transforms import Resize, SpatialCrop\nfrom monai.transforms.transform import MapTransform, Randomizable, Transform\nfrom monai.transforms.utils import generate_spatial_bounding_box, is_positive\nfrom monai.utils import InterpolateMode, ensure_tuple, ensure_tuple_rep, min_version, optional_import\nfrom monai.utils.enums import PostFix\n\nmeasure, _ = optional_import(\"skimage.measure\", \"0.14.2\", min_version)\ndistance_transform_cdt, _ = optional_import(\"scipy.ndimage\", name=\"distance_transform_cdt\")\n\nDEFAULT_POST_FIX = PostFix.meta()\n\n\n# Transforms to support Training for Deepgrow models\nclass FindAllValidSlicesd(Transform):\n    \"\"\"\n    Find/List all valid slices in the label.\n    Label is assumed to be a 4D Volume with shape CDHW, where C=1.\n\n    Args:\n        label: key to the label source.\n        sids: key to store slices indices having valid label map.\n    \"\"\"\n\n    def __init__(self, label: str = \"label\", sids: str = \"sids\"):\n        self.label = label\n        self.sids = sids\n\n    def _apply(self, label):\n        sids 
= []\n        for sid in range(label.shape[1]):  # Assume channel is first\n            if np.sum(label[0][sid]) != 0:\n                sids.append(sid)\n        return np.asarray(sids)\n\n    def __call__(self, data: Any) -> dict:\n        d: dict = dict(data)\n        label = d[self.label].numpy() if isinstance(data[self.label], torch.Tensor) else data[self.label]\n        if label.shape[0] != 1:\n            raise ValueError(f\"Only supports single channel labels, got label shape {label.shape}!\")\n\n        if len(label.shape) != 4:  # only for 3D\n            raise ValueError(f\"Only supports label with shape CDHW, got label shape {label.shape}!\")\n\n        sids = self._apply(label)\n        if sids is not None and len(sids):\n            d[self.sids] = sids\n        return d\n\n\nclass AddInitialSeedPointd(Randomizable, Transform):\n    \"\"\"\n    Add random guidance as initial seed point for a given label.\n\n    Note that the label is of size (C, D, H, W) or (C, H, W)\n\n    The guidance is of size (2, N, # of dims) where N is number of guidance added.\n    # of dims = 4 when C, D, H, W; # of dims = 3 when (C, H, W)\n\n    Args:\n        label: label source.\n        guidance: key to store guidance.\n        sids: key that represents list of valid slice indices for the given label.\n        sid: key that represents the slice to add initial seed point.  
If not present, random sid will be chosen.\n        connected_regions: maximum connected regions to use for adding initial points.\n    \"\"\"\n\n    def __init__(\n        self,\n        label: str = \"label\",\n        guidance: str = \"guidance\",\n        sids: str = \"sids\",\n        sid: str = \"sid\",\n        connected_regions: int = 5,\n    ):\n        self.label = label\n        self.sids_key = sids\n        self.sid_key = sid\n        self.sid = None\n        self.guidance = guidance\n        self.connected_regions = connected_regions\n\n    def randomize(self, data):\n        sid = data.get(self.sid_key, None)\n        sids = data.get(self.sids_key, None)\n        if sids is not None:\n            if sid is None or sid not in sids:\n                sid = self.R.choice(sids, replace=False)\n        else:\n            sid = None\n        self.sid = sid\n\n    def _apply(self, label, sid):\n        dimensions = 3 if len(label.shape) > 3 else 2\n        default_guidance = [-1] * (dimensions + 1)\n\n        dims = dimensions\n        if sid is not None and dimensions == 3:\n            dims = 2\n            label = label[0][sid][np.newaxis]  # Assume channel is first\n\n        label = (label > 0.5).astype(np.float32)\n        blobs_labels = measure.label(label.astype(int), background=0) if dims == 2 else label\n        if np.max(blobs_labels) <= 0:\n            raise AssertionError(\"Not a valid Label\")\n\n        pos_guidance = []\n        for ridx in range(1, 2 if dims == 3 else self.connected_regions + 1):\n            if dims == 2:\n                label = (blobs_labels == ridx).astype(np.float32)\n                if np.sum(label) == 0:\n                    pos_guidance.append(default_guidance)\n                    continue\n\n            distance = distance_transform_cdt(label).flatten()\n            probability = np.exp(distance) - 1.0\n\n            idx = np.where(label.flatten() > 0)[0]\n            seed = self.R.choice(idx, size=1, 
p=probability[idx] / np.sum(probability[idx]))\n            dst = distance[seed]\n\n            g = np.asarray(np.unravel_index(seed, label.shape)).transpose().tolist()[0]\n            g[0] = dst[0]  # for debug\n            if dimensions == 2 or dims == 3:\n                pos_guidance.append(g)\n            else:\n                pos_guidance.append([g[0], sid, g[-2], g[-1]])\n\n        return np.asarray([pos_guidance, [default_guidance] * len(pos_guidance)])\n\n    def __call__(self, data):\n        d = dict(data)\n        self.randomize(data)\n        d[self.guidance] = json.dumps(self._apply(d[self.label], self.sid).astype(int, copy=False).tolist())\n        return d\n\n\nclass AddGuidanceSignald(Transform):\n    \"\"\"\n    Add Guidance signal for input image.\n\n    Based on the \"guidance\" points, apply gaussian to them and add them as new channel for input image.\n\n    Args:\n        image: key to the image source.\n        guidance: key to store guidance.\n        sigma: standard deviation for Gaussian kernel.\n        number_intensity_ch: channel index.\n\n    \"\"\"\n\n    def __init__(self, image: str = \"image\", guidance: str = \"guidance\", sigma: int = 2, number_intensity_ch: int = 1):\n        self.image = image\n        self.guidance = guidance\n        self.sigma = sigma\n        self.number_intensity_ch = number_intensity_ch\n\n    def _get_signal(self, image, guidance):\n        dimensions = 3 if len(image.shape) > 3 else 2\n        guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance\n        guidance = json.loads(guidance) if isinstance(guidance, str) else guidance\n        if dimensions == 3:\n            signal = np.zeros((len(guidance), image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32)\n        else:\n            signal = np.zeros((len(guidance), image.shape[-2], image.shape[-1]), dtype=np.float32)\n\n        sshape = signal.shape\n        for i, g_i in enumerate(guidance):\n            for 
point in g_i:\n                if np.any(np.asarray(point) < 0):\n                    continue\n\n                if dimensions == 3:\n                    p1 = max(0, min(int(point[-3]), sshape[-3] - 1))\n                    p2 = max(0, min(int(point[-2]), sshape[-2] - 1))\n                    p3 = max(0, min(int(point[-1]), sshape[-1] - 1))\n                    signal[i, p1, p2, p3] = 1.0\n                else:\n                    p1 = max(0, min(int(point[-2]), sshape[-2] - 1))\n                    p2 = max(0, min(int(point[-1]), sshape[-1] - 1))\n                    signal[i, p1, p2] = 1.0\n\n            if np.max(signal[i]) > 0:\n                signal_tensor = torch.tensor(signal[i])\n                pt_gaussian = GaussianFilter(len(signal_tensor.shape), sigma=self.sigma)\n                signal_tensor = pt_gaussian(signal_tensor.unsqueeze(0).unsqueeze(0))\n                signal_tensor = signal_tensor.squeeze(0).squeeze(0)\n                signal[i] = signal_tensor.detach().cpu().numpy()\n                signal[i] = (signal[i] - np.min(signal[i])) / (np.max(signal[i]) - np.min(signal[i]))\n        return signal\n\n    def _apply(self, image, guidance):\n        signal = self._get_signal(image, guidance)\n\n        if isinstance(image, torch.Tensor):\n            image = image.detach().cpu().numpy()\n\n        image = image[0 : 0 + self.number_intensity_ch, ...]\n        return np.concatenate([image, signal], axis=0)\n\n    def __call__(self, data):\n        d = dict(data)\n        image = d[self.image]\n        guidance = d[self.guidance]\n\n        d[self.image] = self._apply(image, guidance)\n        return d\n\n\nclass FindDiscrepancyRegionsd(Transform):\n    \"\"\"\n    Find discrepancy between prediction and actual during click interactions during training.\n\n    Args:\n        label: key to label source.\n        pred: key to prediction source.\n        discrepancy: key to store discrepancies found between label and prediction.\n\n    \"\"\"\n\n    
def __init__(self, label: str = \"label\", pred: str = \"pred\", discrepancy: str = \"discrepancy\"):\n        self.label = label\n        self.pred = pred\n        self.discrepancy = discrepancy\n\n    @staticmethod\n    def disparity(label, pred):\n        label = (label > 0.5).astype(np.float32)\n        pred = (pred > 0.5).astype(np.float32)\n        disparity = label - pred\n\n        pos_disparity = (disparity > 0).astype(np.float32)\n        neg_disparity = (disparity < 0).astype(np.float32)\n        return [pos_disparity, neg_disparity]\n\n    def _apply(self, label, pred):\n        return self.disparity(label, pred)\n\n    def __call__(self, data):\n        d = dict(data)\n        label = d[self.label]\n        pred = d[self.pred]\n\n        d[self.discrepancy] = self._apply(label, pred)\n        return d\n\n\nclass AddRandomGuidanced(Randomizable, Transform):\n    \"\"\"\n    Add random guidance based on discrepancies that were found between label and prediction.\n    input shape is as below:\n    Guidance is of shape (2, N, # of dim)\n    Discrepancy is of shape (2, C, D, H, W) or (2, C, H, W)\n    Probability is of shape (1)\n\n    Args:\n        guidance: key to guidance source.\n        discrepancy: key that represents discrepancies found between label and prediction.\n        probability: key that represents click/interaction probability.\n\n    \"\"\"\n\n    def __init__(self, guidance: str = \"guidance\", discrepancy: str = \"discrepancy\", probability: str = \"probability\"):\n        self.guidance = guidance\n        self.discrepancy = discrepancy\n        self.probability = probability\n        self._will_interact = None\n\n    def randomize(self, data=None):\n        probability = data[self.probability]\n        self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])\n\n    def find_guidance(self, discrepancy):\n        distance = distance_transform_cdt(discrepancy).flatten()\n        probability = 
np.exp(distance) - 1.0\n        idx = np.where(discrepancy.flatten() > 0)[0]\n\n        if np.sum(discrepancy > 0) > 0:\n            seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx]))\n            dst = distance[seed]\n\n            g = np.asarray(np.unravel_index(seed, discrepancy.shape)).transpose().tolist()[0]\n            g[0] = dst[0]\n            return g\n        return None\n\n    def add_guidance(self, discrepancy, will_interact):\n        if not will_interact:\n            return None, None\n\n        pos_discr = discrepancy[0]\n        neg_discr = discrepancy[1]\n\n        can_be_positive = np.sum(pos_discr) > 0\n        can_be_negative = np.sum(neg_discr) > 0\n        correct_pos = np.sum(pos_discr) >= np.sum(neg_discr)\n\n        if correct_pos and can_be_positive:\n            return self.find_guidance(pos_discr), None\n\n        if not correct_pos and can_be_negative:\n            return None, self.find_guidance(neg_discr)\n        return None, None\n\n    def _apply(self, guidance, discrepancy):\n        guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance\n        guidance = json.loads(guidance) if isinstance(guidance, str) else guidance\n        pos, neg = self.add_guidance(discrepancy, self._will_interact)\n        if pos:\n            guidance[0].append(pos)\n            guidance[1].append([-1] * len(pos))\n        if neg:\n            guidance[0].append([-1] * len(neg))\n            guidance[1].append(neg)\n\n        return json.dumps(np.asarray(guidance, dtype=int).tolist())\n\n    def __call__(self, data):\n        d = dict(data)\n        guidance = d[self.guidance]\n        discrepancy = d[self.discrepancy]\n\n        self.randomize(data)\n        d[self.guidance] = self._apply(guidance, discrepancy)\n        return d\n\n\nclass SpatialCropForegroundd(MapTransform):\n    \"\"\"\n    Crop only the foreground object of the expected images.\n\n    Difference VS 
:py:class:`monai.transforms.CropForegroundd`:\n\n      1. If the bounding box is smaller than spatial size in all dimensions then this transform will crop the\n         object using box's center and spatial_size.\n\n      2. This transform will set \"start_coord_key\", \"end_coord_key\", \"original_shape_key\" and \"cropped_shape_key\"\n         in data[{key}_{meta_key_postfix}]\n\n    The typical usage is to help training and evaluation if the valid part is small in the whole medical image.\n    The valid part can be determined by any field in the data with `source_key`, for example:\n\n    - Select values > 0 in image field as the foreground and crop on all fields specified by `keys`.\n    - Select label = 3 in label field as the foreground to crop on all fields specified by `keys`.\n    - Select label > 0 in the third channel of a One-Hot label field as the foreground to crop all `keys` fields.\n\n    Users can define arbitrary function to select expected foreground from the whole source image or specified\n    channels. And it can also add margin to every dim of the bounding box of foreground object.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.MapTransform`\n        source_key: data source to generate the bounding box of foreground, can be image or label, etc.\n        spatial_size: minimal spatial size of the image patch e.g. [128, 128, 128] to fit in.\n        select_fn: function to select expected foreground, default is to select values > 0.\n        channel_indices: if defined, select foreground only on the specified channels\n            of image. if None, select foreground on the whole image.\n        margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.\n        allow_smaller: when computing box size with `margin`, whether allow the image size to be smaller\n            than box size, default to `True`. 
if the margined size is bigger than image size, will pad with\n            specified `mode`.\n        meta_keys: explicitly indicate the key of the corresponding metadata dictionary.\n            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n            the metadata is a dictionary object which contains: filename, original_shape, etc.\n            it can be a sequence of string, map to the `keys`.\n            if None, will try to construct meta_keys by `key_{meta_key_postfix}`.\n        meta_key_postfix: if meta_keys is None, use `{key}_{meta_key_postfix}` to fetch/store the metadata according\n            to the key data, default is `meta_dict`, the metadata is a dictionary object.\n            For example, to handle key `image`,  read/write affine matrices from the\n            metadata `image_meta_dict` dictionary's `affine` field.\n        start_coord_key: key to record the start coordinate of spatial bounding box for foreground.\n        end_coord_key: key to record the end coordinate of spatial bounding box for foreground.\n        original_shape_key: key to record original shape for foreground.\n        cropped_shape_key: key to record cropped shape for foreground.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        source_key: str,\n        spatial_size: Sequence[int] | np.ndarray,\n        select_fn: Callable = is_positive,\n        channel_indices: IndexSelection | None = None,\n        margin: int = 0,\n        allow_smaller: bool = True,\n        meta_keys: KeysCollection | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        start_coord_key: str = \"foreground_start_coord\",\n        end_coord_key: str = \"foreground_end_coord\",\n        original_shape_key: str = \"foreground_original_shape\",\n        cropped_shape_key: str = \"foreground_cropped_shape\",\n        
allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n\n        self.source_key = source_key\n        self.spatial_size = list(spatial_size)\n        self.select_fn = select_fn\n        self.channel_indices = channel_indices\n        self.margin = margin\n        self.allow_smaller = allow_smaller\n        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)\n        if len(self.keys) != len(self.meta_keys):\n            raise ValueError(\"meta_keys should have the same length as keys.\")\n        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))\n        self.start_coord_key = start_coord_key\n        self.end_coord_key = end_coord_key\n        self.original_shape_key = original_shape_key\n        self.cropped_shape_key = cropped_shape_key\n\n    def __call__(self, data):\n        d = dict(data)\n        box_start, box_end = generate_spatial_bounding_box(\n            d[self.source_key], self.select_fn, self.channel_indices, self.margin, self.allow_smaller\n        )\n\n        center = list(np.mean([box_start, box_end], axis=0).astype(int, copy=False))\n        current_size = list(np.subtract(box_end, box_start).astype(int, copy=False))\n\n        if np.all(np.less(current_size, self.spatial_size)):\n            cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)\n            box_start = [s.start for s in cropper.slices]\n            box_end = [s.stop for s in cropper.slices]\n        else:\n            cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)\n\n        for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):\n            meta_key = meta_key or f\"{key}_{meta_key_postfix}\"\n            d[meta_key][self.start_coord_key] = box_start\n            d[meta_key][self.end_coord_key] = box_end\n            d[meta_key][self.original_shape_key] = d[key].shape\n\n        
    image = cropper(d[key])\n            d[meta_key][self.cropped_shape_key] = image.shape\n            d[key] = image\n        return d\n\n\n# Transforms to support Inference for Deepgrow models\nclass AddGuidanceFromPointsd(Transform):\n    \"\"\"\n    Add guidance based on user clicks.\n\n    We assume the input is loaded by LoadImaged and has the shape of (H, W, D) originally.\n    Clicks always specify the coordinates in (H, W, D)\n\n    If depth_first is True:\n\n        Input is now of shape (D, H, W), will return guidance that specifies the coordinates in (D, H, W)\n\n    else:\n\n        Input is now of shape (H, W, D), will return guidance that specifies the coordinates in (H, W, D)\n\n    Args:\n        ref_image: key to reference image to fetch current and original image details.\n        guidance: output key to store guidance.\n        foreground: key that represents user foreground (+ve) clicks.\n        background: key that represents user background (-ve) clicks.\n        axis: axis that represents slices in 3D volume. 
(axis to Depth)\n        depth_first: if depth (slices) is positioned at first dimension.\n        spatial_dims: dimensions based on model used for deepgrow (2D vs 3D).\n        slice_key: key that represents applicable slice to add guidance.\n        meta_keys: explicitly indicate the key of the metadata dictionary of `ref_image`.\n            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n            the metadata is a dictionary object which contains: filename, original_shape, etc.\n            if None, will try to construct meta_keys by `{ref_image}_{meta_key_postfix}`.\n        meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to fetch the metadata according\n            to the key data, default is `meta_dict`, the metadata is a dictionary object.\n            For example, to handle key `image`,  read/write affine matrices from the\n            metadata `image_meta_dict` dictionary's `affine` field.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        ref_image: str,\n        guidance: str = \"guidance\",\n        foreground: str = \"foreground\",\n        background: str = \"background\",\n        axis: int = 0,\n        depth_first: bool = True,\n        spatial_dims: int = 2,\n        slice_key: str = \"slice\",\n        meta_keys: str | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n    ):\n        self.ref_image = ref_image\n        self.guidance = guidance\n        self.foreground = foreground\n        self.background = background\n        self.axis = axis\n        self.depth_first = depth_first\n        self.dimensions = spatial_dims\n        self.slice = slice_key\n        self.meta_keys = meta_keys\n        self.meta_key_postfix = meta_key_postfix\n\n    def _apply(self, pos_clicks, neg_clicks, factor, slice_num):\n        pos = neg = []\n\n        if self.dimensions == 2:\n            points: list = list(pos_clicks)\n            points.extend(neg_clicks)\n\n   
         slices = list(np.unique(np.array(points)[:, self.axis]))\n            slice_idx = slices[0] if slice_num is None else next(x for x in slices if x == slice_num)\n\n            if len(pos_clicks):\n                pos_clicks = np.array(pos_clicks)\n                pos = (pos_clicks[np.where(pos_clicks[:, self.axis] == slice_idx)] * factor)[:, 1:].astype(int).tolist()\n            if len(neg_clicks):\n                neg_clicks = np.array(neg_clicks)\n                neg = (neg_clicks[np.where(neg_clicks[:, self.axis] == slice_idx)] * factor)[:, 1:].astype(int).tolist()\n\n            guidance = [pos, neg, slice_idx]\n        else:\n            if len(pos_clicks):\n                pos = np.multiply(pos_clicks, factor).astype(int, copy=False).tolist()\n            if len(neg_clicks):\n                neg = np.multiply(neg_clicks, factor).astype(int, copy=False).tolist()\n            guidance = [pos, neg]\n        return guidance\n\n    def __call__(self, data):\n        d = dict(data)\n        meta_dict_key = self.meta_keys or f\"{self.ref_image}_{self.meta_key_postfix}\"\n        if meta_dict_key not in d:\n            raise RuntimeError(f\"Missing meta_dict {meta_dict_key} in data!\")\n        if \"spatial_shape\" not in d[meta_dict_key]:\n            raise RuntimeError('Missing \"spatial_shape\" in meta_dict!')\n        original_shape = d[meta_dict_key][\"spatial_shape\"]\n        current_shape = list(d[self.ref_image].shape)\n\n        if self.depth_first:\n            if self.axis != 0:\n                raise RuntimeError(\"Depth first means the depth axis should be 0.\")\n            # in here we assume the depth dimension was in the last dimension of \"original_shape\"\n            original_shape = np.roll(original_shape, 1)\n\n        factor = np.array(current_shape) / original_shape\n\n        fg_bg_clicks = []\n        for key in [self.foreground, self.background]:\n            clicks = d[key]\n            clicks = list(np.array(clicks, dtype=int))\n 
           if self.depth_first:\n                for i in range(len(clicks)):\n                    clicks[i] = list(np.roll(clicks[i], 1))\n            fg_bg_clicks.append(clicks)\n        d[self.guidance] = self._apply(fg_bg_clicks[0], fg_bg_clicks[1], factor, d.get(self.slice))\n        return d\n\n\nclass SpatialCropGuidanced(MapTransform):\n    \"\"\"\n    Crop image based on guidance with minimal spatial size.\n\n    - If the bounding box is smaller than spatial size in all dimensions then this transform will crop the\n      object using box's center and spatial_size.\n\n    - This transform will set \"start_coord_key\", \"end_coord_key\", \"original_shape_key\" and \"cropped_shape_key\"\n      in data[{key}_{meta_key_postfix}]\n\n    Input data is of shape (C, spatial_1, [spatial_2, ...])\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        guidance: key to the guidance. It is used to generate the bounding box of foreground\n        spatial_size: minimal spatial size of the image patch e.g. 
[128, 128, 128] to fit in.\n        margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.\n        meta_keys: explicitly indicate the key of the corresponding metadata dictionary.\n            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n            the metadata is a dictionary object which contains: filename, original_shape, etc.\n            it can be a sequence of string, map to the `keys`.\n            if None, will try to construct meta_keys by `key_{meta_key_postfix}`.\n        meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according\n            to the key data, default is `meta_dict`, the metadata is a dictionary object.\n            For example, to handle key `image`,  read/write affine matrices from the\n            metadata `image_meta_dict` dictionary's `affine` field.\n        start_coord_key: key to record the start coordinate of spatial bounding box for foreground.\n        end_coord_key: key to record the end coordinate of spatial bounding box for foreground.\n        original_shape_key: key to record original shape for foreground.\n        cropped_shape_key: key to record cropped shape for foreground.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        guidance: str,\n        spatial_size: Iterable[int],\n        margin: int = 20,\n        meta_keys: KeysCollection | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        start_coord_key: str = \"foreground_start_coord\",\n        end_coord_key: str = \"foreground_end_coord\",\n        original_shape_key: str = \"foreground_original_shape\",\n        cropped_shape_key: str = \"foreground_cropped_shape\",\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n\n        self.guidance = 
guidance\n        self.spatial_size = list(spatial_size)\n        self.margin = margin\n        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)\n        if len(self.keys) != len(self.meta_keys):\n            raise ValueError(\"meta_keys should have the same length as keys.\")\n        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))\n        self.start_coord_key = start_coord_key\n        self.end_coord_key = end_coord_key\n        self.original_shape_key = original_shape_key\n        self.cropped_shape_key = cropped_shape_key\n\n    def bounding_box(self, points, img_shape):\n        ndim = len(img_shape)\n        margin = ensure_tuple_rep(self.margin, ndim)\n        for m in margin:\n            if m < 0:\n                raise ValueError(\"margin value should not be negative number.\")\n\n        box_start = [0] * ndim\n        box_end = [0] * ndim\n\n        for di in range(ndim):\n            dt = points[..., di]\n            min_d = max(min(dt - margin[di]), 0)\n            max_d = min(img_shape[di], max(dt + margin[di] + 1))\n            box_start[di], box_end[di] = min_d, max_d\n        return box_start, box_end\n\n    def __call__(self, data: Any) -> dict:\n        d: dict = dict(data)\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            return d\n\n        guidance = d[self.guidance]\n        original_spatial_shape = d[first_key].shape[1:]\n        box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape)\n        center = list(np.mean([box_start, box_end], axis=0).astype(int, copy=False))\n        spatial_size = self.spatial_size\n\n        box_size = list(np.subtract(box_end, box_start).astype(int, copy=False))\n        spatial_size = spatial_size[-len(box_size) :]\n\n        if len(spatial_size) < len(box_size):\n            # If the data is in 3D and spatial_size is specified as 2D 
[256,256]\n            # Then we will get all slices in such case\n            diff = len(box_size) - len(spatial_size)\n            spatial_size = list(original_spatial_shape[1 : (1 + diff)]) + spatial_size\n\n        if np.all(np.less(box_size, spatial_size)):\n            if len(center) == 3:\n                # 3D Deepgrow: set center to be middle of the depth dimension (D)\n                center[0] = spatial_size[0] // 2\n            cropper = SpatialCrop(roi_center=center, roi_size=spatial_size)\n        else:\n            cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)\n\n        # update bounding box in case it was corrected by the SpatialCrop constructor\n        box_start = np.array([s.start for s in cropper.slices])\n        box_end = np.array([s.stop for s in cropper.slices])\n        for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):\n            if not np.array_equal(d[key].shape[1:], original_spatial_shape):\n                raise RuntimeError(\"All the image specified in keys should have same spatial shape\")\n            meta_key = meta_key or f\"{key}_{meta_key_postfix}\"\n            d[meta_key][self.start_coord_key] = box_start\n            d[meta_key][self.end_coord_key] = box_end\n            d[meta_key][self.original_shape_key] = d[key].shape\n\n            image = cropper(d[key])\n            d[meta_key][self.cropped_shape_key] = image.shape\n            d[key] = image\n\n        pos_clicks, neg_clicks = guidance[0], guidance[1]\n        pos = np.subtract(pos_clicks, box_start).tolist() if len(pos_clicks) else []\n        neg = np.subtract(neg_clicks, box_start).tolist() if len(neg_clicks) else []\n\n        d[self.guidance] = [pos, neg]\n        return d\n\n\nclass ResizeGuidanced(Transform):\n    \"\"\"\n    Resize the guidance based on cropped vs resized image.\n\n    This transform assumes that the images have been cropped and resized. 
And the shape after cropped is store inside\n    the meta dict of ref image.\n\n    Args:\n        guidance: key to guidance\n        ref_image: key to reference image to fetch current and original image details\n        meta_keys: explicitly indicate the key of the metadata dictionary of `ref_image`.\n            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n            the metadata is a dictionary object which contains: filename, original_shape, etc.\n            if None, will try to construct meta_keys by `{ref_image}_{meta_key_postfix}`.\n        meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to fetch the metadata according\n            to the key data, default is `meta_dict`, the metadata is a dictionary object.\n            For example, to handle key `image`,  read/write affine matrices from the\n            metadata `image_meta_dict` dictionary's `affine` field.\n        cropped_shape_key: key that records cropped shape for foreground.\n    \"\"\"\n\n    def __init__(\n        self,\n        guidance: str,\n        ref_image: str,\n        meta_keys: str | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        cropped_shape_key: str = \"foreground_cropped_shape\",\n    ) -> None:\n        self.guidance = guidance\n        self.ref_image = ref_image\n        self.meta_keys = meta_keys\n        self.meta_key_postfix = meta_key_postfix\n        self.cropped_shape_key = cropped_shape_key\n\n    def __call__(self, data: Any) -> dict:\n        d = dict(data)\n        guidance = d[self.guidance]\n        meta_dict: dict = d[self.meta_keys or f\"{self.ref_image}_{self.meta_key_postfix}\"]\n        current_shape = d[self.ref_image].shape[1:]\n        cropped_shape = meta_dict[self.cropped_shape_key][1:]\n        factor = np.divide(current_shape, cropped_shape)\n\n        pos_clicks, neg_clicks = guidance[0], guidance[1]\n        pos = np.multiply(pos_clicks, factor).astype(int, 
copy=False).tolist() if len(pos_clicks) else []\n        neg = np.multiply(neg_clicks, factor).astype(int, copy=False).tolist() if len(neg_clicks) else []\n\n        d[self.guidance] = [pos, neg]\n        return d\n\n\nclass RestoreLabeld(MapTransform):\n    \"\"\"\n    Restores label based on the ref image.\n\n    The ref_image is assumed that it went through the following transforms:\n\n        1. Fetch2DSliced (If 2D)\n        2. Spacingd\n        3. SpatialCropGuidanced\n        4. Resized\n\n    And its shape is assumed to be (C, D, H, W)\n\n    This transform tries to undo these operation so that the result label can be overlapped with original volume.\n    It does the following operation:\n\n        1. Undo Resized\n        2. Undo SpatialCropGuidanced\n        3. Undo Spacingd\n        4. Undo Fetch2DSliced\n\n    The resulting label is of shape (D, H, W)\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        ref_image: reference image to fetch current and original image details\n        slice_only: apply only to an applicable slice, in case of 2D model/prediction\n        mode: {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``, ``\"mean\"``,\n            ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            One of the listed string values or a user supplied function for padding. 
Defaults to ``\"constant\"``.\n            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n        align_corners: Geometrically, we consider the pixels of the input as squares rather than points.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            It also can be a sequence of bool, each element corresponds to a key in ``keys``.\n        meta_keys: explicitly indicate the key of the corresponding metadata dictionary.\n            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n            the metadata is a dictionary object which contains: filename, original_shape, etc.\n            it can be a sequence of string, map to the `keys`.\n            if None, will try to construct meta_keys by `key_{meta_key_postfix}`.\n        meta_key_postfix: if meta_key is None, use `key_{meta_key_postfix} to fetch the metadata according\n            to the key data, default is `meta_dict`, the metadata is a dictionary object.\n            For example, to handle key `image`,  read/write affine matrices from the\n            metadata `image_meta_dict` dictionary's `affine` field.\n        start_coord_key: key that records the start coordinate of spatial bounding box for foreground.\n        end_coord_key: key that records the end coordinate of spatial bounding box for foreground.\n        original_shape_key: key that records original shape for foreground.\n        cropped_shape_key: key that records cropped shape for foreground.\n        allow_missing_keys: don't raise exception if key is missing.\n        restore_resizing: used to enable or disable resizing restoration, default is True.\n            If True, the transform will resize the items back to its original shape.\n        restore_cropping: used to enable or disable cropping restoration, default is True.\n            If True, the transform will restore the items to its uncropped size.\n        
restore_spacing: used to enable or disable spacing restoration, default is True.\n            If True, the transform will resample the items back to the spacing it had before being altered.\n        restore_slicing: used to enable or disable slicing restoration, default is True.\n            If True, the transform will reassemble the full volume by restoring the slices to their original positions.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        ref_image: str,\n        slice_only: bool = False,\n        mode: Sequence[InterpolateMode | str] | InterpolateMode | str = InterpolateMode.NEAREST,\n        align_corners: Sequence[bool | None] | bool | None = None,\n        meta_keys: str | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        start_coord_key: str = \"foreground_start_coord\",\n        end_coord_key: str = \"foreground_end_coord\",\n        original_shape_key: str = \"foreground_original_shape\",\n        cropped_shape_key: str = \"foreground_cropped_shape\",\n        allow_missing_keys: bool = False,\n        restore_resizing: bool = True,\n        restore_cropping: bool = True,\n        restore_spacing: bool = True,\n        restore_slicing: bool = True,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.ref_image = ref_image\n        self.slice_only = slice_only\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))\n        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)\n        if len(self.keys) != len(self.meta_keys):\n            raise ValueError(\"meta_keys should have the same length as keys.\")\n        self.meta_key_postfix = meta_key_postfix\n        self.start_coord_key = start_coord_key\n        self.end_coord_key = end_coord_key\n        self.original_shape_key = original_shape_key\n        self.cropped_shape_key 
= cropped_shape_key\n        self.restore_resizing = restore_resizing\n        self.restore_cropping = restore_cropping\n        self.restore_spacing = restore_spacing\n        self.restore_slicing = restore_slicing\n\n    def __call__(self, data: Any) -> dict:\n        d = dict(data)\n        meta_dict: dict = d[f\"{self.ref_image}_{self.meta_key_postfix}\"]\n\n        for key, mode, align_corners, meta_key in self.key_iterator(d, self.mode, self.align_corners, self.meta_keys):\n            image = d[key]\n\n            # Undo Resize\n            if self.restore_resizing:\n                current_shape = image.shape\n                cropped_shape = meta_dict[self.cropped_shape_key]\n                if np.any(np.not_equal(current_shape, cropped_shape)):\n                    resizer = Resize(spatial_size=cropped_shape[1:], mode=mode)\n                    image = resizer(image, mode=mode, align_corners=align_corners)\n\n            # Undo Crop\n            if self.restore_cropping:\n                original_shape = meta_dict[self.original_shape_key]\n                result = np.zeros(original_shape, dtype=np.float32)\n                box_start = meta_dict[self.start_coord_key]\n                box_end = meta_dict[self.end_coord_key]\n\n                spatial_dims = min(len(box_start), len(image.shape[1:]))\n                slices = tuple(\n                    [slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])]\n                )\n                result[slices] = image\n            else:\n                result = image\n\n            # Undo Spacing\n            if self.restore_spacing:\n                current_size = result.shape[1:]\n                # change spatial_shape from HWD to DHW\n                spatial_shape = list(np.roll(meta_dict[\"spatial_shape\"], 1))\n                spatial_size = spatial_shape[-len(current_size) :]\n\n                if np.any(np.not_equal(current_size, spatial_size)):\n                   
 resizer = Resize(spatial_size=spatial_size, mode=mode)\n                    result = resizer(result, mode=mode, align_corners=align_corners)  # type: ignore\n\n            # Undo Slicing\n            slice_idx = meta_dict.get(\"slice_idx\")\n            final_result: NdarrayOrTensor\n            if not self.restore_slicing:  # do nothing if restore slicing isn't requested\n                final_result = result\n            elif slice_idx is None or self.slice_only:\n                final_result = result if len(result.shape) <= 3 else result[0]\n            else:\n                slice_idx = meta_dict[\"slice_idx\"][0]\n                final_result = np.zeros(tuple(spatial_shape))\n                final_result[slice_idx] = result\n            d[key] = final_result\n\n            meta_key = meta_key or f\"{key}_{self.meta_key_postfix}\"\n            meta = d.get(meta_key)\n            if meta is None:\n                meta = dict()\n                d[meta_key] = meta\n            meta[\"slice_idx\"] = slice_idx\n            meta[\"affine\"] = meta_dict[\"original_affine\"]\n        return d\n\n\nclass Fetch2DSliced(MapTransform):\n    \"\"\"\n    Fetch one slice in case of a 3D volume.\n\n    The volume only contains spatial coordinates.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        guidance: key that represents guidance.\n        axis: axis that represents slice in 3D volume.\n        meta_keys: explicitly indicate the key of the corresponding metadata dictionary.\n            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n            the metadata is a dictionary object which contains: filename, original_shape, etc.\n            it can be a sequence of string, map to the `keys`.\n            if None, will try to construct meta_keys by `key_{meta_key_postfix}`.\n        meta_key_postfix: use `key_{meta_key_postfix}` to fetch the metadata according to the key data,\n            
default is `meta_dict`, the metadata is a dictionary object.\n            For example, to handle key `image`,  read/write affine matrices from the\n            metadata `image_meta_dict` dictionary's `affine` field.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        guidance: str = \"guidance\",\n        axis: int = 0,\n        meta_keys: KeysCollection | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        allow_missing_keys: bool = False,\n    ):\n        super().__init__(keys, allow_missing_keys)\n        self.guidance = guidance\n        self.axis = axis\n        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)\n        if len(self.keys) != len(self.meta_keys):\n            raise ValueError(\"meta_keys should have the same length as keys.\")\n        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))\n\n    def _apply(self, image, guidance):\n        slice_idx = guidance[2]  # (pos, neg, slice_idx)\n        idx = []\n        for i, size_i in enumerate(image.shape):\n            idx.append(slice_idx) if i == self.axis else idx.append(slice(0, size_i))\n\n        return image[tuple(idx)], tuple(idx)\n\n    def __call__(self, data):\n        d = dict(data)\n        guidance = d[self.guidance]\n        if len(guidance) < 3:\n            raise RuntimeError(\"Guidance does not container slice_idx!\")\n        for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):\n            img_slice, idx = self._apply(d[key], guidance)\n            d[key] = img_slice\n            d[meta_key or f\"{key}_{meta_key_postfix}\"][\"slice_idx\"] = idx\n        return d\n"
  },
  {
    "path": "monai/apps/detection/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/detection/metrics/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/detection/metrics/coco.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/evaluator/detection/coco.py\n# which has the following license...\n# https://github.com/MIC-DKFZ/nnDetection/blob/main/LICENSE\n#\n# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#    http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/cocodataset/cocoapi\n# which has the following license...\n# https://github.com/cocodataset/cocoapi/blob/master/license.txt\n\n# Copyright (c) 2014, Piotr Dollar and Tsung-Yi Lin\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following 
conditions are met:\n\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n# The views and conclusions contained in the software and documentation are those\n# of the authors and should not be interpreted as representing official policies,\n# either expressed or implied, of the FreeBSD Project.\n\"\"\"\nThis script is almost the same as https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/evaluator/detection/coco.py\nThe changes include 1) code reformatting, 2) docstrings.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging as logger\nimport time\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport numpy as np\n\n\nclass COCOMetric:\n\n    def __init__(\n        self,\n        classes: Sequence[str],\n        iou_list: Sequence[float] = (0.1, 0.5, 0.75),\n        iou_range: Sequence[float] = (0.1, 0.5, 0.05),\n        max_detection: Sequence[int] = (1, 5, 
100),\n        per_class: bool = True,\n        verbose: bool = True,\n    ):\n        \"\"\"\n        Class to compute COCO metrics\n        Metrics computed includes,\n\n        - mAP over the IoU range specified by `iou_range` at last value of `max_detection`\n        - AP values at IoU thresholds specified by `iou_list` at last value of `max_detection`\n        - AR over max detections thresholds defined by `max_detection` (over iou range)\n\n        Args:\n            classes (Sequence[str]): name of each class (index needs to correspond to predicted class indices!)\n            iou_list (Sequence[float]): specific thresholds where ap is evaluated and saved\n            iou_range (Sequence[float]): (start, stop, step) for mAP iou thresholds\n            max_detection (Sequence[int]): maximum number of detections per image\n            verbose (bool): log time needed for evaluation\n\n        Example:\n\n            .. code-block:: python\n\n                from monai.data.box_utils import box_iou\n                from monai.apps.detection.metrics.coco import COCOMetric\n                from monai.apps.detection.metrics.matching import matching_batch\n                # 3D example outputs of one image from detector\n                val_outputs_all = [\n                        {\"boxes\": torch.tensor([[1,1,1,3,4,5]],dtype=torch.float16),\n                        \"labels\": torch.randint(3,(1,)),\n                        \"scores\": torch.randn((1,)).absolute()},\n                ]\n                val_targets_all = [\n                        {\"boxes\": torch.tensor([[1,1,1,2,6,4]],dtype=torch.float16),\n                        \"labels\": torch.randint(3,(1,))},\n                ]\n\n                coco_metric = COCOMetric(\n                    classes=['c0','c1','c2'], iou_list=[0.1], max_detection=[10]\n                )\n                results_metric = matching_batch(\n                    iou_fn=box_iou,\n                    
iou_thresholds=coco_metric.iou_thresholds,\n                    pred_boxes=[val_data_i[\"boxes\"].numpy() for val_data_i in val_outputs_all],\n                    pred_classes=[val_data_i[\"labels\"].numpy() for val_data_i in val_outputs_all],\n                    pred_scores=[val_data_i[\"scores\"].numpy() for val_data_i in val_outputs_all],\n                    gt_boxes=[val_data_i[\"boxes\"].numpy() for val_data_i in val_targets_all],\n                    gt_classes=[val_data_i[\"labels\"].numpy() for val_data_i in val_targets_all],\n                )\n                val_metric_dict = coco_metric(results_metric)\n                print(val_metric_dict)\n        \"\"\"\n        self.verbose = verbose\n        self.classes = classes\n        self.per_class = per_class\n\n        iou_list_np = np.array(iou_list)\n        _iou_range = np.linspace(\n            iou_range[0], iou_range[1], int(np.round((iou_range[1] - iou_range[0]) / iou_range[2])) + 1, endpoint=True\n        )\n        self.iou_thresholds = np.union1d(iou_list_np, _iou_range)\n        self.iou_range = iou_range\n\n        # get indices of iou values of ious range and ious list for later evaluation\n        self.iou_list_idx = np.nonzero(iou_list_np[:, np.newaxis] == self.iou_thresholds[np.newaxis])[1]\n        self.iou_range_idx = np.nonzero(_iou_range[:, np.newaxis] == self.iou_thresholds[np.newaxis])[1]\n\n        if (\n            not (self.iou_thresholds[self.iou_list_idx] == iou_list_np).all()\n            or not (self.iou_thresholds[self.iou_range_idx] == _iou_range).all()\n        ):\n            raise ValueError(\n                \"Require self.iou_thresholds[self.iou_list_idx] == iou_list_np and \"\n                \"self.iou_thresholds[self.iou_range_idx] == _iou_range.\"\n            )\n\n        self.recall_thresholds = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True)\n        self.max_detections = max_detection\n\n    def __call__(self, *args: Any, **kwargs: 
Any) -> tuple[dict[str, float], dict[str, np.ndarray] | None]:\n        \"\"\"\n        Compute metric. See :func:`compute` for more information.\n\n        Args:\n            *args: positional arguments passed to :func:`compute`\n            **kwargs: keyword arguments passed to :func:`compute`\n\n        Returns:\n            dict[str, float]: dictionary with scalar values for evaluation\n            dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs\n        \"\"\"\n        return self.compute(*args, **kwargs)\n\n    def check_number_of_iou(self, *args: np.ndarray) -> None:\n        \"\"\"\n        Check if shape of input in first dimension is consistent with expected IoU values\n        (assumes IoU dimension is the first dimension)\n\n        Args:\n            args: array like inputs with shape function\n        \"\"\"\n        num_ious = len(self.get_iou_thresholds())\n        for arg in args:\n            if arg.shape[0] != num_ious:\n                raise ValueError(\n                    f\"Require arg.shape[0] == len(self.get_iou_thresholds()). Got arg.shape[0]={arg.shape[0]}, \"\n                    f\"self.get_iou_thresholds()={self.get_iou_thresholds()}.\"\n                )\n\n    def get_iou_thresholds(self) -> Sequence[float]:\n        \"\"\"\n        Return IoU thresholds needed for this metric as a list\n\n        Returns:\n            Sequence[float]: IoU thresholds [M], M is the number of thresholds\n        \"\"\"\n        return list(self.iou_thresholds)\n\n    def compute(self, results_list: list[dict[int, dict[str, np.ndarray]]]) -> tuple[dict[str, float], None]:\n        \"\"\"\n        Compute COCO metrics\n\n        Args:\n            results_list (list[dict[int, dict[str, np.ndarray]]]): list with results per image (in list)\n                per category (dict). 
Inner dict contains multiple results obtained by :func:`box_matching_batch`.\n\n                - `dtMatches`: matched detections [T, D], where T = number of\n                  thresholds, D = number of detections\n                - `gtMatches`: matched ground truth boxes [T, G], where T = number\n                  of thresholds, G = number of ground truth\n                - `dtScores`: prediction scores [D] detection scores\n                - `gtIgnore`: ground truth boxes which should be ignored\n                  [G] indicate whether ground truth should be ignored\n                - `dtIgnore`: detections which should be ignored [T, D],\n                  indicate which detections should be ignored\n\n        Returns:\n            dict[str, float], dictionary with coco metrics\n        \"\"\"\n        if self.verbose:\n            logger.info(\"Start COCO metric computation...\")\n            tic = time.time()\n\n        dataset_statistics = self._compute_statistics(results_list=results_list)  # dict[str, Union[np.ndarray, list]]\n\n        if self.verbose:\n            toc = time.time()\n            logger.info(f\"Statistics for COCO metrics finished (t={(toc - tic):0.2f}s).\")\n\n        results = {}\n        results.update(self._compute_ap(dataset_statistics))\n        results.update(self._compute_ar(dataset_statistics))\n\n        if self.verbose:\n            toc = time.time()\n            logger.info(f\"COCO metrics computed in t={(toc - tic):0.2f}s.\")\n        return results, None\n\n    def _compute_ap(self, dataset_statistics: dict[str, np.ndarray | list]) -> dict[str, float]:\n        \"\"\"\n        Compute AP metrics\n\n        Args:\n            dataset_statistics (list[dict[int, dict[str, np.ndarray]]]): list with results per image (in list)\n                per category (dict). 
Inner dict contains multiple results obtained by :func:`box_matching_batch`.\n\n                - `dtMatches`: matched detections [T, D], where T = number of\n                  thresholds, D = number of detections\n                - `gtMatches`: matched ground truth boxes [T, G], where T = number\n                  of thresholds, G = number of ground truth\n                - `dtScores`: prediction scores [D] detection scores\n                - `gtIgnore`: ground truth boxes which should be ignored\n                  [G] indicate whether ground truth should be ignored\n                - `dtIgnore`: detections which should be ignored [T, D],\n                  indicate which detections should be ignored\n        \"\"\"\n        results = {}\n        if self.iou_range:  # mAP\n            key = (\n                f\"mAP_IoU_{self.iou_range[0]:.2f}_{self.iou_range[1]:.2f}_{self.iou_range[2]:.2f}_\"\n                f\"MaxDet_{self.max_detections[-1]}\"\n            )\n            results[key] = self._select_ap(dataset_statistics, iou_idx=self.iou_range_idx, max_det_idx=-1)\n\n            if self.per_class:\n                for cls_idx, cls_str in enumerate(self.classes):  # per class results\n                    key = (\n                        f\"{cls_str}_\"\n                        f\"mAP_IoU_{self.iou_range[0]:.2f}_{self.iou_range[1]:.2f}_{self.iou_range[2]:.2f}_\"\n                        f\"MaxDet_{self.max_detections[-1]}\"\n                    )\n                    results[key] = self._select_ap(\n                        dataset_statistics, iou_idx=self.iou_range_idx, cls_idx=cls_idx, max_det_idx=-1\n                    )\n\n        for idx in self.iou_list_idx:  # AP@IoU\n            key = f\"AP_IoU_{self.iou_thresholds[idx]:.2f}_MaxDet_{self.max_detections[-1]}\"\n            results[key] = self._select_ap(dataset_statistics, iou_idx=[idx], max_det_idx=-1)\n\n            if self.per_class:\n                for cls_idx, cls_str in enumerate(self.classes):  # 
per class results\n                    key = f\"{cls_str}_\" f\"AP_IoU_{self.iou_thresholds[idx]:.2f}_\" f\"MaxDet_{self.max_detections[-1]}\"\n                    results[key] = self._select_ap(dataset_statistics, iou_idx=[idx], cls_idx=cls_idx, max_det_idx=-1)\n        return results\n\n    def _compute_ar(self, dataset_statistics: dict[str, np.ndarray | list]) -> dict[str, float]:\n        \"\"\"\n        Compute AR metrics\n\n        Args:\n            dataset_statistics (list[dict[int, dict[str, np.ndarray]]]): list with results per image (in list)\n                per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`.\n\n                - `dtMatches`: matched detections [T, D], where T = number of\n                  thresholds, D = number of detections\n                - `gtMatches`: matched ground truth boxes [T, G], where T = number\n                  of thresholds, G = number of ground truth\n                - `dtScores`: prediction scores [D] detection scores\n                - `gtIgnore`: ground truth boxes which should be ignored\n                  [G] indicate whether ground truth should be ignored\n                - `dtIgnore`: detections which should be ignored [T, D],\n                  indicate which detections should be ignored\n        \"\"\"\n        results = {}\n        for max_det_idx, max_det in enumerate(self.max_detections):  # mAR\n            key = f\"mAR_IoU_{self.iou_range[0]:.2f}_{self.iou_range[1]:.2f}_{self.iou_range[2]:.2f}_MaxDet_{max_det}\"\n            results[key] = self._select_ar(dataset_statistics, max_det_idx=max_det_idx)\n\n            if self.per_class:\n                for cls_idx, cls_str in enumerate(self.classes):  # per class results\n                    key = (\n                        f\"{cls_str}_\"\n                        f\"mAR_IoU_{self.iou_range[0]:.2f}_{self.iou_range[1]:.2f}_{self.iou_range[2]:.2f}_\"\n                        f\"MaxDet_{max_det}\"\n                
    )\n                    results[key] = self._select_ar(dataset_statistics, cls_idx=cls_idx, max_det_idx=max_det_idx)\n\n        for idx in self.iou_list_idx:  # AR@IoU\n            key = f\"AR_IoU_{self.iou_thresholds[idx]:.2f}_MaxDet_{self.max_detections[-1]}\"\n            results[key] = self._select_ar(dataset_statistics, iou_idx=idx, max_det_idx=-1)\n\n            if self.per_class:\n                for cls_idx, cls_str in enumerate(self.classes):  # per class results\n                    key = f\"{cls_str}_\" f\"AR_IoU_{self.iou_thresholds[idx]:.2f}_\" f\"MaxDet_{self.max_detections[-1]}\"\n                    results[key] = self._select_ar(dataset_statistics, iou_idx=idx, cls_idx=cls_idx, max_det_idx=-1)\n        return results\n\n    @staticmethod\n    def _select_ap(\n        dataset_statistics: dict,\n        iou_idx: int | list[int] | np.ndarray | None = None,\n        cls_idx: int | Sequence[int] | None = None,\n        max_det_idx: int = -1,\n    ) -> float:\n        \"\"\"\n        Compute average precision\n\n        Args:\n            dataset_statistics (dict): computed statistics over dataset\n\n                - `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max\n                  detection thresholds\n                - `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]\n                - `precision`: Precision values at specified recall thresholds\n                  [num_iou_th, num_recall_th, num_classes, num_max_detections]\n                - `scores`: Scores corresponding to specified recall thresholds\n                  [num_iou_th, num_recall_th, num_classes, num_max_detections]\n            iou_idx: index of IoU values to select for evaluation(if None, all values are used)\n            cls_idx: class indices to select, if None all classes will be selected\n            max_det_idx (int): index to select max detection threshold from data\n\n        Returns:\n            
np.ndarray: AP value\n        \"\"\"\n        prec = dataset_statistics[\"precision\"]\n        if iou_idx is not None:\n            prec = prec[iou_idx]\n        if cls_idx is not None:\n            prec = prec[..., cls_idx, :]\n        prec = prec[..., max_det_idx]\n        return float(np.mean(prec))\n\n    @staticmethod\n    def _select_ar(\n        dataset_statistics: dict,\n        iou_idx: int | Sequence[int] | None = None,\n        cls_idx: int | Sequence[int] | None = None,\n        max_det_idx: int = -1,\n    ) -> float:\n        \"\"\"\n        Compute average recall\n\n        Args:\n            dataset_statistics (dict): computed statistics over dataset\n\n                - `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max\n                  detection thresholds\n                - `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]\n                - `precision`: Precision values at specified recall thresholds\n                  [num_iou_th, num_recall_th, num_classes, num_max_detections]\n                - `scores`: Scores corresponding to specified recall thresholds\n                  [num_iou_th, num_recall_th, num_classes, num_max_detections]\n            iou_idx: index of IoU values to select for evaluation(if None, all values are used)\n            cls_idx: class indices to select, if None all classes will be selected\n            max_det_idx (int): index to select max detection threshold from data\n\n        Returns:\n            np.ndarray: recall value\n        \"\"\"\n        rec = dataset_statistics[\"recall\"]\n        if iou_idx is not None:\n            rec = rec[iou_idx]\n        if cls_idx is not None:\n            rec = rec[..., cls_idx, :]\n        rec = rec[..., max_det_idx]\n\n        if len(rec[rec > -1]) == 0:\n            return -1.0\n\n        return float(np.mean(rec[rec > -1]))\n\n    def _compute_statistics(self, results_list: list[dict[int, dict[str, 
np.ndarray]]]) -> dict[str, np.ndarray | list]:\n        \"\"\"\n        Compute statistics needed for COCO metrics (mAP, AP of individual classes, mAP@IoU_Thresholds, AR)\n        Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py\n\n        Args:\n            results_list (list[dict[int, dict[str, np.ndarray]]]): list with results per image (in list)\n                per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`.\n\n                - `dtMatches`: matched detections [T, D], where T = number of\n                  thresholds, D = number of detections\n                - `gtMatches`: matched ground truth boxes [T, G], where T = number\n                  of thresholds, G = number of ground truth\n                - `dtScores`: prediction scores [D] detection scores\n                - `gtIgnore`: ground truth boxes which should be ignored\n                  [G] indicate whether ground truth should be ignored\n                - `dtIgnore`: detections which should be ignored [T, D],\n                  indicate which detections should be ignored\n\n        Returns:\n            dict: computed statistics over dataset\n                - `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max\n                  detection thresholds\n                - `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]\n                - `precision`: Precision values at specified recall thresholds\n                  [num_iou_th, num_recall_th, num_classes, num_max_detections]\n                - `scores`: Scores corresponding to specified recall thresholds\n                  [num_iou_th, num_recall_th, num_classes, num_max_detections]\n        \"\"\"\n        num_iou_th = len(self.iou_thresholds)\n        num_recall_th = len(self.recall_thresholds)\n        num_classes = len(self.classes)\n        num_max_detections = 
len(self.max_detections)\n\n        # -1 for the precision of absent categories\n        precision = -np.ones((num_iou_th, num_recall_th, num_classes, num_max_detections))\n        recall = -np.ones((num_iou_th, num_classes, num_max_detections))\n        scores = -np.ones((num_iou_th, num_recall_th, num_classes, num_max_detections))\n\n        for cls_idx, cls_i in enumerate(self.classes):  # for each class\n            for max_det_idx, max_det in enumerate(self.max_detections):  # for each maximum number of detections\n                results = [r[cls_idx] for r in results_list if cls_idx in r]  # len is num_images\n\n                if len(results) == 0:\n                    logger.warning(f\"WARNING, no results found for coco metric for class {cls_i}\")\n                    continue\n\n                dt_scores = np.concatenate([r[\"dtScores\"][0:max_det] for r in results])\n                # different sorting method generates slightly different results.\n                # mergesort is used to be consistent as Matlab implementation.\n                inds = np.argsort(-dt_scores, kind=\"mergesort\")\n                dt_scores_sorted = dt_scores[inds]\n\n                # r['dtMatches'] [T, R], where R = sum(all detections)\n                dt_matches = np.concatenate([r[\"dtMatches\"][:, 0:max_det] for r in results], axis=1)[:, inds]\n                dt_ignores = np.concatenate([r[\"dtIgnore\"][:, 0:max_det] for r in results], axis=1)[:, inds]\n                self.check_number_of_iou(dt_matches, dt_ignores)\n                gt_ignore = np.concatenate([r[\"gtIgnore\"] for r in results])\n                num_gt = int(np.count_nonzero(gt_ignore == 0))  # number of ground truth boxes (non ignored)\n                if num_gt == 0:\n                    logger.warning(f\"WARNING, no gt found for coco metric for class {cls_i}\")\n                    continue\n\n                # ignore cases need to be handled differently for tp and fp\n                tps = 
np.logical_and(dt_matches, np.logical_not(dt_ignores))\n                fps = np.logical_and(np.logical_not(dt_matches), np.logical_not(dt_ignores))\n\n                tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float32)\n                fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float32)\n\n                for th_ind, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):  # for each threshold th_ind\n                    tp, fp = np.array(tp), np.array(fp)\n                    r, p, s = _compute_stats_single_threshold(tp, fp, dt_scores_sorted, self.recall_thresholds, num_gt)\n                    recall[th_ind, cls_idx, max_det_idx] = r\n                    precision[th_ind, :, cls_idx, max_det_idx] = p\n                    # corresponding score thresholds for recall steps\n                    scores[th_ind, :, cls_idx, max_det_idx] = s\n\n        return {\n            \"counts\": [num_iou_th, num_recall_th, num_classes, num_max_detections],  # [4]\n            \"recall\": recall,  # [num_iou_th, num_classes, num_max_detections]\n            \"precision\": precision,  # [num_iou_th, num_recall_th, num_classes, num_max_detections]\n            \"scores\": scores,  # [num_iou_th, num_recall_th, num_classes, num_max_detections]\n        }\n\n\ndef _compute_stats_single_threshold(\n    tp: np.ndarray,\n    fp: np.ndarray,\n    dt_scores_sorted: np.ndarray,\n    recall_thresholds: np.ndarray | Sequence[float],\n    num_gt: int,\n) -> tuple[float, np.ndarray, np.ndarray]:\n    \"\"\"\n    Compute recall value, precision curve and scores thresholds\n    Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py\n\n    Args:\n        tp (np.ndarray): cumsum over true positives [R], R is the number of detections\n        fp (np.ndarray): cumsum over false positives [R], R is the number of detections\n        dt_scores_sorted (np.ndarray): sorted (descending) scores [R], R is the number of detections\n        recall_thresholds 
(Sequence[float]): recall thresholds which should be evaluated\n        num_gt (int): number of ground truth bounding boxes (excluding boxes which are ignored)\n\n    Returns:\n        - float, overall recall for given IoU value\n        - np.ndarray, precision values at defined recall values\n          [RTH], where RTH is the number of recall thresholds\n        - np.ndarray, prediction scores corresponding to recall values\n          [RTH], where RTH is the number of recall thresholds\n    \"\"\"\n    num_recall_th = len(recall_thresholds)\n\n    rc = tp / num_gt\n    # np.spacing(1) is the smallest representable epsilon with float\n    pr = tp / (fp + tp + np.spacing(1))\n\n    if len(tp):\n        recall = rc[-1]\n    else:\n        # no prediction\n        recall = 0\n\n    # array where precision values nearest to given recall th are saved\n    precision = [0.0] * num_recall_th\n    # save scores for corresponding recall value in here\n    th_scores = np.zeros((num_recall_th,))\n    # numpy is slow without cython optimization for accessing elements\n    # use python array gets significant speed improvement\n    pr = pr.tolist()\n\n    # smooth precision curve (create box shape)\n    for i in range(len(tp) - 1, 0, -1):\n        if pr[i] > pr[i - 1]:\n            pr[i - 1] = pr[i]\n\n    # get indices to nearest given recall threshold (nn interpolation!)\n    inds = np.searchsorted(rc, recall_thresholds, side=\"left\")\n    try:\n        for save_idx, array_index in enumerate(inds):\n            precision[save_idx] = pr[array_index]\n            th_scores[save_idx] = dt_scores_sorted[array_index]\n    except BaseException:\n        pass\n\n    return recall, np.array(precision), np.array(th_scores)\n"
  },
  {
    "path": "monai/apps/detection/metrics/matching.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/evaluator/detection/matching.py\n# which has the following license...\n# https://github.com/MIC-DKFZ/nnDetection/blob/main/LICENSE\n#\n# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#    http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/cocodataset/cocoapi\n# which has the following license...\n# https://github.com/cocodataset/cocoapi/blob/master/license.txt\n\n# Copyright (c) 2014, Piotr Dollar and Tsung-Yi Lin\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following 
conditions are met:\n\n# 1. Redistributions of source code must retain the above copyright notice, this\n#    list of conditions and the following disclaimer.\n# 2. Redistributions in binary form must reproduce the above copyright notice,\n#    this list of conditions and the following disclaimer in the documentation\n#    and/or other materials provided with the distribution.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n# The views and conclusions contained in the software and documentation are those\n# of the authors and should not be interpreted as representing official policies,\n# either expressed or implied, of the FreeBSD Project.\n\"\"\"\nThis script is almost same with https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/evaluator/detection/matching.py\nThe changes include 1) code reformatting, 2) docstrings,\n3) allow input args gt_ignore to be optional. 
(If so, no GT boxes will be ignored.)\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Sequence\n\nimport numpy as np\n\n__all__ = [\"matching_batch\"]\n\n\ndef matching_batch(\n    iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],\n    iou_thresholds: Sequence[float],\n    pred_boxes: Sequence[np.ndarray],\n    pred_classes: Sequence[np.ndarray],\n    pred_scores: Sequence[np.ndarray],\n    gt_boxes: Sequence[np.ndarray],\n    gt_classes: Sequence[np.ndarray],\n    gt_ignore: Sequence[Sequence[bool]] | Sequence[np.ndarray] | None = None,\n    max_detections: int = 100,\n) -> list[dict[int, dict[str, np.ndarray]]]:\n    \"\"\"\n    Match boxes of a batch to corresponding ground truth for each category\n    independently.\n\n    Args:\n        iou_fn: compute overlap for each pair\n        iou_thresholds: defined which IoU thresholds should be evaluated\n        pred_boxes: predicted boxes from single batch; List[[D, dim * 2]],\n            D number of predictions\n        pred_classes: predicted classes from a single batch; List[[D]],\n            D number of predictions\n        pred_scores: predicted score for each bounding box; List[[D]],\n            D number of predictions\n        gt_boxes: ground truth boxes; List[[G, dim * 2]], G number of ground\n            truth\n        gt_classes: ground truth classes; List[[G]], G number of ground truth\n        gt_ignore: specified if which ground truth boxes are not counted as\n            true positives. 
If not given, we use all the gt_boxes.\n            (detections which match these boxes are not counted as false\n            positives either); List[[G]], G number of ground truth\n        max_detections: maximum number of detections which should be evaluated\n\n    Returns:\n        List[Dict[int, Dict[str, np.ndarray]]], each Dict[int, Dict[str, np.ndarray]] corresponds to an image\n        and maps a class label to its matching result. Each matching result Dict has the following keys.\n\n        - `dtMatches`: matched detections [T, D], where T = number of\n          thresholds, D = number of detections\n        - `gtMatches`: matched ground truth boxes [T, G], where T = number\n          of thresholds, G = number of ground truth\n        - `dtScores`: prediction scores [D] detection scores\n        - `gtIgnore`: ground truth boxes which should be ignored\n          [G] indicate whether ground truth should be ignored\n        - `dtIgnore`: detections which should be ignored [T, D],\n          indicate which detections should be ignored\n\n    Example:\n\n        .. 
code-block:: python\n\n            from monai.data.box_utils import box_iou\n            from monai.apps.detection.metrics.coco import COCOMetric\n            from monai.apps.detection.metrics.matching import matching_batch\n            # 3D example outputs of one image from detector\n            val_outputs_all = [\n                    {\"boxes\": torch.tensor([[1,1,1,3,4,5]],dtype=torch.float16),\n                    \"labels\": torch.randint(3,(1,)),\n                    \"scores\": torch.randn((1,)).absolute()},\n            ]\n            val_targets_all = [\n                    {\"boxes\": torch.tensor([[1,1,1,2,6,4]],dtype=torch.float16),\n                    \"labels\": torch.randint(3,(1,))},\n            ]\n\n            coco_metric = COCOMetric(\n                classes=['c0','c1','c2'], iou_list=[0.1], max_detection=[10]\n            )\n            results_metric = matching_batch(\n                iou_fn=box_iou,\n                iou_thresholds=coco_metric.iou_thresholds,\n                pred_boxes=[val_data_i[\"boxes\"].numpy() for val_data_i in val_outputs_all],\n                pred_classes=[val_data_i[\"labels\"].numpy() for val_data_i in val_outputs_all],\n                pred_scores=[val_data_i[\"scores\"].numpy() for val_data_i in val_outputs_all],\n                gt_boxes=[val_data_i[\"boxes\"].numpy() for val_data_i in val_targets_all],\n                gt_classes=[val_data_i[\"labels\"].numpy() for val_data_i in val_targets_all],\n            )\n            val_metric_dict = coco_metric(results_metric)\n            print(val_metric_dict)\n    \"\"\"\n    results = []\n    if gt_ignore is None:\n        gt_ignore = [np.full_like(gt_c, False) for gt_c in gt_classes]\n    # iterate over images/batches\n    for pboxes, pclasses, pscores, gboxes, gclasses, gignore in zip(\n        pred_boxes, pred_classes, pred_scores, gt_boxes, gt_classes, gt_ignore\n    ):\n        # for each image\n        img_classes = np.union1d(pclasses, gclasses)  # 
possible class labels\n        result = {}  # dict contains results for each class in one image\n        for c in img_classes:\n            pred_mask = pclasses == c  # bool mask predictions with current class\n            gt_mask = gclasses == c  # bool mask ground truth with current class\n\n            if not np.any(gt_mask):  # no ground truth\n                result[c] = _matching_no_gt(\n                    iou_thresholds=iou_thresholds, pred_scores=pscores[pred_mask], max_detections=max_detections\n                )\n            elif not np.any(pred_mask):  # no predictions\n                result[c] = _matching_no_pred(iou_thresholds=iou_thresholds, gt_ignore=gignore[gt_mask])\n            else:  # at least one prediction and one ground truth\n                result[c] = _matching_single_image_single_class(\n                    iou_fn=iou_fn,\n                    pred_boxes=pboxes[pred_mask],\n                    pred_scores=pscores[pred_mask],\n                    gt_boxes=gboxes[gt_mask],\n                    gt_ignore=gignore[gt_mask],\n                    max_detections=max_detections,\n                    iou_thresholds=iou_thresholds,\n                )\n        results.append(result)\n    return results\n\n\ndef _matching_no_gt(\n    iou_thresholds: Sequence[float], pred_scores: np.ndarray, max_detections: int\n) -> dict[str, np.ndarray]:\n    \"\"\"\n    Matching result with no ground truth in image\n\n    Args:\n        iou_thresholds: defined which IoU thresholds should be evaluated\n        pred_scores: predicted scores\n        max_detections: maximum number of allowed detections per image.\n            This function uses this parameter to stay consistent with\n            the actual matching function which needs this limit.\n\n    Returns:\n        computed matching, a Dict[str, np.ndarray]\n\n        - `dtMatches`: matched detections [T, D], where T = number of\n          thresholds, D = number of detections\n        - `gtMatches`: matched 
ground truth boxes [T, G], where T = number\n          of thresholds, G = number of ground truth\n        - `dtScores`: prediction scores [D] detection scores\n        - `gtIgnore`: ground truth boxes which should be ignored\n          [G] indicate whether ground truth should be ignored\n        - `dtIgnore`: detections which should be ignored [T, D],\n          indicate which detections should be ignored\n    \"\"\"\n    dt_ind = np.argsort(-pred_scores, kind=\"mergesort\")\n    dt_ind = dt_ind[:max_detections]\n    dt_scores = pred_scores[dt_ind]\n\n    num_preds = len(dt_scores)\n\n    gt_match: np.ndarray = np.array([[]] * len(iou_thresholds))\n    dt_match: np.ndarray = np.zeros((len(iou_thresholds), num_preds))\n    dt_ignore: np.ndarray = np.zeros((len(iou_thresholds), num_preds))\n\n    return {\n        \"dtMatches\": dt_match,  # [T, D], where T = number of thresholds, D = number of detections\n        \"gtMatches\": gt_match,  # [T, G], where T = number of thresholds, G = number of ground truth\n        \"dtScores\": dt_scores,  # [D] detection scores\n        \"gtIgnore\": np.array([]).reshape(-1),  # [G] indicate whether ground truth should be ignored\n        \"dtIgnore\": dt_ignore,  # [T, D], indicate which detections should be ignored\n    }\n\n\ndef _matching_no_pred(iou_thresholds: Sequence[float], gt_ignore: np.ndarray) -> dict[str, np.ndarray]:\n    \"\"\"\n    Matching result with no predictions\n\n    Args:\n        iou_thresholds: defined which IoU thresholds should be evaluated\n        gt_ignore: specified if which ground truth boxes are not counted as\n            true positives (detections which match theses boxes are not\n            counted as false positives either); [G], G number of ground truth\n\n    Returns:\n        dict: computed matching\n\n        - `dtMatches`: matched detections [T, D], where T = number of\n          thresholds, D = number of detections\n        - `gtMatches`: matched ground truth boxes [T, G], where T = 
number\n          of thresholds, G = number of ground truth\n        - `dtScores`: prediction scores [D] detection scores\n        - `gtIgnore`: ground truth boxes which should be ignored\n          [G] indicate whether ground truth should be ignored\n        - `dtIgnore`: detections which should be ignored [T, D],\n          indicate which detections should be ignored\n    \"\"\"\n    dt_scores: np.ndarray = np.array([])\n    dt_match: np.ndarray = np.array([[]] * len(iou_thresholds))\n    dt_ignore: np.ndarray = np.array([[]] * len(iou_thresholds))\n\n    n_gt = 0 if gt_ignore.size == 0 else gt_ignore.shape[0]\n    gt_match = np.zeros((len(iou_thresholds), n_gt))\n\n    return {\n        \"dtMatches\": dt_match,  # [T, D], where T = number of thresholds, D = number of detections\n        \"gtMatches\": gt_match,  # [T, G], where T = number of thresholds, G = number of ground truth\n        \"dtScores\": dt_scores,  # [D] detection scores\n        \"gtIgnore\": gt_ignore.reshape(-1),  # [G] indicate whether ground truth should be ignored\n        \"dtIgnore\": dt_ignore,  # [T, D], indicate which detections should be ignored\n    }\n\n\ndef _matching_single_image_single_class(\n    iou_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],\n    pred_boxes: np.ndarray,\n    pred_scores: np.ndarray,\n    gt_boxes: np.ndarray,\n    gt_ignore: np.ndarray,\n    max_detections: int,\n    iou_thresholds: Sequence[float],\n) -> dict[str, np.ndarray]:\n    \"\"\"\n    Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py\n\n    Args:\n        iou_fn: compute overlap for each pair\n        iou_thresholds: defined which IoU thresholds should be evaluated\n        pred_boxes: predicted boxes from single batch; [D, dim * 2], D number\n            of predictions\n        pred_scores: predicted score for each bounding box; [D], D number of\n            predictions\n        gt_boxes: ground truth boxes; [G, dim * 2], G number of ground 
truth\n        gt_ignore: specifies which ground truth boxes are not counted as\n            true positives (detections which match these boxes are not\n            counted as false positives either); [G], G number of ground truth\n        max_detections: maximum number of detections which should be evaluated\n\n    Returns:\n        dict: computed matching\n\n        - `dtMatches`: matched detections [T, D], where T = number of\n          thresholds, D = number of detections\n        - `gtMatches`: matched ground truth boxes [T, G], where T = number\n          of thresholds, G = number of ground truth\n        - `dtScores`: prediction scores [D] detection scores\n        - `gtIgnore`: ground truth boxes which should be ignored\n          [G] indicate whether ground truth should be ignored\n        - `dtIgnore`: detections which should be ignored [T, D],\n          indicate which detections should be ignored\n    \"\"\"\n    # filter for max_detections highest scoring predictions to speed up computation\n    dt_ind = np.argsort(-pred_scores, kind=\"mergesort\")\n    dt_ind = dt_ind[:max_detections]\n\n    pred_boxes = pred_boxes[dt_ind]\n    pred_scores = pred_scores[dt_ind]\n\n    # sort ignored ground truth to last positions\n    gt_ind = np.argsort(gt_ignore, kind=\"mergesort\")\n    gt_boxes = gt_boxes[gt_ind]\n    gt_ignore = gt_ignore[gt_ind]\n\n    # ious between sorted(!) 
predictions and ground truth\n    ious = iou_fn(pred_boxes, gt_boxes)  # array sized (num_preds, num_gts)\n\n    num_preds, num_gts = ious.shape[0], ious.shape[1]\n    gt_match = np.zeros((len(iou_thresholds), num_gts))\n    dt_match = np.zeros((len(iou_thresholds), num_preds))\n    dt_ignore = np.zeros((len(iou_thresholds), num_preds))\n\n    for tind, t in enumerate(iou_thresholds):\n        for dind, _d in enumerate(pred_boxes):  # iterate detections starting from highest scoring one\n            # information about best match so far (m=-1 -> unmatched)\n            iou = min([t, 1 - 1e-10])\n            m = -1\n\n            for gind, _g in enumerate(gt_boxes):  # iterate ground truth\n                # if this gt already matched, continue\n                if gt_match[tind, gind] > 0:\n                    continue\n\n                # if dt matched to reg gt, and on ignore gt, stop\n                if m > -1 and gt_ignore[m] == 0 and gt_ignore[gind] == 1:\n                    break\n\n                # continue to next gt unless better match made\n                if ious[dind, gind] < iou:\n                    continue\n\n                # if match successful and best so far, store appropriately\n                iou = ious[dind, gind]\n                m = gind\n\n            # if match made, store id of match for both dt and gt\n            if m == -1:\n                continue\n            else:\n                dt_ignore[tind, dind] = int(gt_ignore[m])\n                dt_match[tind, dind] = 1\n                gt_match[tind, m] = 1\n\n    # store results for given image and category\n    return {\n        \"dtMatches\": dt_match,  # [T, D], where T = number of thresholds, D = number of detections\n        \"gtMatches\": gt_match,  # [T, G], where T = number of thresholds, G = number of ground truth\n        \"dtScores\": pred_scores,  # [D] detection scores\n        \"gtIgnore\": gt_ignore.reshape(-1),  # [G] indicate whether ground truth should be ignored\n  
      \"dtIgnore\": dt_ignore,  # [T, D], indicate which detections should be ignored\n    }\n"
  },
  {
    "path": "monai/apps/detection/networks/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/detection/networks/retinanet_detector.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py\n# which has the following license...\n# https://github.com/pytorch/vision/blob/main/LICENSE\n\n# BSD 3-Clause License\n\n# Copyright (c) Soumith Chintala 2016,\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# * Redistributions of source code must retain the above copyright notice, this\n#   list of conditions and the following disclaimer.\n\n# * Redistributions in binary form must reproduce the above copyright notice,\n#   this list of conditions and the following disclaimer in the documentation\n#   and/or other materials provided with the distribution.\n\n# * Neither the name of the copyright holder nor the names of its\n#   contributors may be used to endorse or promote products derived from\n#   this software without specific prior written permission.\n\"\"\"\nPart of this script is adapted from\nhttps://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py\n\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable, Sequence\nfrom typing import Any\n\nimport torch\nfrom torch import Tensor, 
nn\n\nfrom monai.apps.detection.networks.retinanet_network import RetinaNet, resnet_fpn_feature_extractor\nfrom monai.apps.detection.utils.anchor_utils import AnchorGenerator\nfrom monai.apps.detection.utils.ATSS_matcher import ATSSMatcher\nfrom monai.apps.detection.utils.box_coder import BoxCoder\nfrom monai.apps.detection.utils.box_selector import BoxSelector\nfrom monai.apps.detection.utils.detector_utils import check_training_targets, preprocess_images\nfrom monai.apps.detection.utils.hard_negative_sampler import HardNegativeSampler\nfrom monai.apps.detection.utils.predict_utils import ensure_dict_value_to_list_, predict_with_inferer\nfrom monai.data.box_utils import box_iou\nfrom monai.inferers import SlidingWindowInferer\nfrom monai.networks.nets import resnet\nfrom monai.utils import BlendMode, PytorchPadMode, ensure_tuple_rep, optional_import\n\nBalancedPositiveNegativeSampler, _ = optional_import(\n    \"torchvision.models.detection._utils\", name=\"BalancedPositiveNegativeSampler\"\n)\nMatcher, _ = optional_import(\"torchvision.models.detection._utils\", name=\"Matcher\")\n\n\nclass RetinaNetDetector(nn.Module):\n    \"\"\"\n    Retinanet detector, expandable to other one stage anchor based box detectors in the future.\n    An example of construction can found in the source code of\n    :func:`~monai.apps.detection.networks.retinanet_detector.retinanet_resnet50_fpn_detector` .\n\n    The input to the model is expected to be a list of tensors, each of shape (C, H, W) or  (C, H, W, D),\n    one for each image, and should be in 0-1 range. Different images can have different sizes.\n    Or it can also be a Tensor sized (B, C, H, W) or  (B, C, H, W, D). 
In this case, all images have same size.\n\n    The behavior of the model changes depending if it is in training or evaluation mode.\n\n    During training, the model expects both the input tensors, as well as a targets (list of dictionary),\n    containing:\n\n    - boxes (``FloatTensor[N, 4]`` or ``FloatTensor[N, 6]``): the ground-truth boxes in ``StandardMode``, i.e.,\n      ``[xmin, ymin, xmax, ymax]`` or ``[xmin, ymin, zmin, xmax, ymax, zmax]`` format,\n      with ``0 <= xmin < xmax <= H``, ``0 <= ymin < ymax <= W``, ``0 <= zmin < zmax <= D``.\n    - labels: the class label for each ground-truth box\n\n    The model returns a Dict[str, Tensor] during training, containing the classification and regression\n    losses.\n    When saving the model, only self.network contains trainable parameters and needs to be saved.\n\n    During inference, the model requires only the input tensors, and returns the post-processed\n    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as\n    follows:\n\n    - boxes (``FloatTensor[N, 4]`` or ``FloatTensor[N, 6]``): the predicted boxes in ``StandardMode``, i.e.,\n      ``[xmin, ymin, xmax, ymax]`` or ``[xmin, ymin, zmin, xmax, ymax, zmax]`` format,\n      with ``0 <= xmin < xmax <= H``, ``0 <= ymin < ymax <= W``, ``0 <= zmin < zmax <= D``.\n    - labels (Int64Tensor[N]): the predicted labels for each image\n    - labels_scores (Tensor[N]): the scores for each prediction\n\n    Args:\n        network: a network that takes an image Tensor sized (B, C, H, W) or (B, C, H, W, D) as input\n            and outputs a dictionary Dict[str, List[Tensor]] or Dict[str, Tensor].\n        anchor_generator: anchor generator.\n        box_overlap_metric: func that compute overlap between two sets of boxes, default is Intersection over Union (IoU).\n        debug: whether to print out internal parameters, used for debugging and parameter tuning.\n\n    Notes:\n\n        Input argument ``network`` can be a 
monai.apps.detection.networks.retinanet_network.RetinaNet(*) object,\n        but any network that meets the following rules is a valid input ``network``.\n\n        1. It should have attributes including spatial_dims, num_classes, cls_key, box_reg_key, num_anchors, size_divisible.\n\n            - spatial_dims (int) is the spatial dimension of the network, we support both 2D and 3D.\n            - num_classes (int) is the number of classes, excluding the background.\n            - size_divisible (int or Sequence[int]) is the expectation on the input image shape.\n              The network needs the input spatial_size to be divisible by size_divisible, length should be 2 or 3.\n            - cls_key (str) is the key to represent classification in the output dict.\n            - box_reg_key (str) is the key to represent box regression in the output dict.\n            - num_anchors (int) is the number of anchor shapes at each location. it should equal to\n              ``self.anchor_generator.num_anchors_per_location()[0]``.\n\n            If network does not have these attributes, user needs to provide them for the detector.\n\n        2. Its input should be an image Tensor sized (B, C, H, W) or (B, C, H, W, D).\n\n        3. About its output ``head_outputs``, it should be either a list of tensors or a dictionary of str: List[Tensor]:\n\n            - If it is a dictionary, it needs to have at least two keys:\n              ``network.cls_key`` and ``network.box_reg_key``, representing predicted classification maps and box regression maps.\n              ``head_outputs[network.cls_key]`` should be List[Tensor] or Tensor. Each Tensor represents\n              classification logits map at one resolution level,\n              sized (B, num_classes*num_anchors, H_i, W_i) or (B, num_classes*num_anchors, H_i, W_i, D_i).\n              ``head_outputs[network.box_reg_key]`` should be List[Tensor] or Tensor. 
Each Tensor represents\n              box regression map at one resolution level,\n              sized (B, 2*spatial_dims*num_anchors, H_i, W_i)or (B, 2*spatial_dims*num_anchors, H_i, W_i, D_i).\n              ``len(head_outputs[network.cls_key]) == len(head_outputs[network.box_reg_key])``.\n            - If it is a list of 2N tensors, the first N tensors should be the predicted classification maps,\n              and the second N tensors should be the predicted box regression maps.\n\n    Example:\n\n        .. code-block:: python\n\n            # define a naive network\n            import torch\n            class NaiveNet(torch.nn.Module):\n                def __init__(self, spatial_dims: int, num_classes: int):\n                    super().__init__()\n                    self.spatial_dims = spatial_dims\n                    self.num_classes = num_classes\n                    self.size_divisible = 2\n                    self.cls_key = \"cls\"\n                    self.box_reg_key = \"box_reg\"\n                    self.num_anchors = 1\n                def forward(self, images: torch.Tensor):\n                    spatial_size = images.shape[-self.spatial_dims:]\n                    out_spatial_size = tuple(s//self.size_divisible for s in spatial_size)  # half size of input\n                    out_cls_shape = (images.shape[0],self.num_classes*self.num_anchors) + out_spatial_size\n                    out_box_reg_shape = (images.shape[0],2*self.spatial_dims*self.num_anchors) + out_spatial_size\n                    return {self.cls_key: [torch.randn(out_cls_shape)], self.box_reg_key: [torch.randn(out_box_reg_shape)]}\n\n            # create a RetinaNetDetector detector\n            spatial_dims = 3\n            num_classes = 5\n            anchor_generator = monai.apps.detection.utils.anchor_utils.AnchorGeneratorWithAnchorShape(\n                feature_map_scales=(1, ), base_anchor_shapes=((8,) * spatial_dims)\n            )\n            net = NaiveNet(spatial_dims, 
num_classes)\n            detector = RetinaNetDetector(net, anchor_generator)\n\n            # only detector.network may contain trainable parameters.\n            optimizer = torch.optim.SGD(\n                detector.network.parameters(),\n                1e-3,\n                momentum=0.9,\n                weight_decay=3e-5,\n                nesterov=True,\n            )\n            torch.save(detector.network.state_dict(), 'model.pt')  # save model\n            detector.network.load_state_dict(torch.load('model.pt', weights_only=True))  # load model\n    \"\"\"\n\n    def __init__(\n        self,\n        network: nn.Module,\n        anchor_generator: AnchorGenerator,\n        box_overlap_metric: Callable = box_iou,\n        spatial_dims: int | None = None,  # used only when network.spatial_dims does not exist\n        num_classes: int | None = None,  # used only when network.num_classes does not exist\n        size_divisible: Sequence[int] | int = 1,  # used only when network.size_divisible does not exist\n        cls_key: str = \"classification\",  # used only when network.cls_key does not exist\n        box_reg_key: str = \"box_regression\",  # used only when network.box_reg_key does not exist\n        debug: bool = False,\n    ):\n        super().__init__()\n\n        self.network = network\n        # network attribute\n        self.spatial_dims = self.get_attribute_from_network(\"spatial_dims\", default_value=spatial_dims)\n        self.num_classes = self.get_attribute_from_network(\"num_classes\", default_value=num_classes)\n\n        self.size_divisible = self.get_attribute_from_network(\"size_divisible\", default_value=size_divisible)\n        self.size_divisible = ensure_tuple_rep(self.size_divisible, self.spatial_dims)\n        # keys for the network output\n        self.cls_key = self.get_attribute_from_network(\"cls_key\", default_value=cls_key)\n        self.box_reg_key = self.get_attribute_from_network(\"box_reg_key\", 
default_value=box_reg_key)\n\n        # check if anchor_generator matches with network\n        self.anchor_generator = anchor_generator\n        self.num_anchors_per_loc = self.anchor_generator.num_anchors_per_location()[0]\n        network_num_anchors = self.get_attribute_from_network(\"num_anchors\", default_value=self.num_anchors_per_loc)\n        if self.num_anchors_per_loc != network_num_anchors:\n            raise ValueError(\n                f\"Number of feature map channels ({network_num_anchors}) \"\n                f\"should match with number of anchors at each location ({self.num_anchors_per_loc}).\"\n            )\n        # if new coming input images has same shape with\n        # self.previous_image_shape, there is no need to generate new anchors.\n        self.anchors: list[Tensor] | None = None\n        self.previous_image_shape: Any | None = None\n\n        self.box_overlap_metric = box_overlap_metric\n        self.debug = debug\n\n        # default setting for training\n        self.fg_bg_sampler: Any | None = None\n        self.set_cls_loss(torch.nn.BCEWithLogitsLoss(reduction=\"mean\"))  # classification loss\n        self.set_box_regression_loss(\n            torch.nn.SmoothL1Loss(beta=1.0 / 9, reduction=\"mean\"), encode_gt=True, decode_pred=False\n        )  # box regression loss\n\n        # default setting for both training and inference\n        # can be updated by self.set_box_coder_weights(*)\n        self.box_coder = BoxCoder(weights=(1.0,) * 2 * self.spatial_dims)\n\n        # default keys in the ground truth targets and predicted boxes,\n        # can be updated by self.set_target_keys(*)\n        self.target_box_key = \"boxes\"\n        self.target_label_key = \"labels\"\n        self.pred_score_key = self.target_label_key + \"_scores\"  # score key for the detected boxes\n\n        # default setting for inference,\n        # can be updated by self.set_sliding_window_inferer(*)\n        self.inferer: SlidingWindowInferer | None = 
None\n        # can be updated by self.set_box_selector_parameters(*),\n        self.box_selector = BoxSelector(\n            box_overlap_metric=self.box_overlap_metric,\n            score_thresh=0.05,\n            topk_candidates_per_level=1000,\n            nms_thresh=0.5,\n            detections_per_img=300,\n            apply_sigmoid=True,\n        )\n\n    def get_attribute_from_network(self, attr_name, default_value=None):\n        if hasattr(self.network, attr_name):\n            return getattr(self.network, attr_name)\n        elif default_value is not None:\n            return default_value\n        else:\n            raise ValueError(f\"network does not have attribute {attr_name}, please provide it in the detector.\")\n\n    def set_box_coder_weights(self, weights: tuple[float]) -> None:\n        \"\"\"\n        Set the weights for box coder.\n\n        Args:\n            weights: a list/tuple with length of 2*self.spatial_dims\n\n        \"\"\"\n        if len(weights) != 2 * self.spatial_dims:\n            raise ValueError(f\"len(weights) should be {2 * self.spatial_dims}, got weights={weights}.\")\n        self.box_coder = BoxCoder(weights=weights)\n\n    def set_target_keys(self, box_key: str, label_key: str) -> None:\n        \"\"\"\n        Set keys for the training targets and inference outputs.\n        During training, both box_key and label_key should be keys in the targets\n        when performing ``self.forward(input_images, targets)``.\n        During inference, they will be the keys in the output dict of ``self.forward(input_images)``.\n        \"\"\"\n        self.target_box_key = box_key\n        self.target_label_key = label_key\n        self.pred_score_key = label_key + \"_scores\"\n\n    def set_cls_loss(self, cls_loss: nn.Module) -> None:\n        \"\"\"\n        Using for training. 
Set loss for classification that takes logits as inputs, make sure sigmoid/softmax is built in.\n\n        Args:\n            cls_loss: loss module for classification\n\n        Example:\n            .. code-block:: python\n\n                detector.set_cls_loss(torch.nn.BCEWithLogitsLoss(reduction=\"mean\"))\n                detector.set_cls_loss(FocalLoss(reduction=\"mean\", gamma=2.0))\n        \"\"\"\n        self.cls_loss_func = cls_loss\n\n    def set_box_regression_loss(self, box_loss: nn.Module, encode_gt: bool, decode_pred: bool) -> None:\n        \"\"\"\n        Using for training. Set loss for box regression.\n\n        Args:\n            box_loss: loss module for box regression\n            encode_gt: if True, will encode ground truth boxes to target box regression\n                before computing the losses. Should be True for L1 loss and False for GIoU loss.\n            decode_pred: if True, will decode predicted box regression into predicted boxes\n                before computing losses. Should be False for L1 loss and True for GIoU loss.\n\n        Example:\n            .. code-block:: python\n\n                detector.set_box_regression_loss(\n                    torch.nn.SmoothL1Loss(beta=1.0 / 9, reduction=\"mean\"),\n                    encode_gt = True, decode_pred = False\n                )\n                detector.set_box_regression_loss(\n                    monai.losses.giou_loss.BoxGIoULoss(reduction=\"mean\"),\n                    encode_gt = False, decode_pred = True\n                )\n        \"\"\"\n        self.box_loss_func = box_loss\n        self.encode_gt = encode_gt\n        self.decode_pred = decode_pred\n\n    def set_regular_matcher(\n        self, fg_iou_thresh: float, bg_iou_thresh: float, allow_low_quality_matches: bool = True\n    ) -> None:\n        \"\"\"\n        Using for training. 
Set torchvision matcher that matches anchors with ground truth boxes.\n\n        Args:\n            fg_iou_thresh: foreground IoU threshold for Matcher, considered as matched if IoU > fg_iou_thresh\n            bg_iou_thresh: background IoU threshold for Matcher, considered as not matched if IoU < bg_iou_thresh\n            allow_low_quality_matches: if True, produce additional matches\n                for predictions that have only low-quality match candidates.\n        \"\"\"\n        if fg_iou_thresh < bg_iou_thresh:\n            raise ValueError(\n                \"Require fg_iou_thresh >= bg_iou_thresh. \"\n                f\"Got fg_iou_thresh={fg_iou_thresh}, bg_iou_thresh={bg_iou_thresh}.\"\n            )\n        self.proposal_matcher = Matcher(\n            fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=allow_low_quality_matches\n        )\n\n    def set_atss_matcher(self, num_candidates: int = 4, center_in_gt: bool = False) -> None:\n        \"\"\"\n        Using for training. Set ATSS matcher that matches anchors with ground truth boxes\n\n        Args:\n            num_candidates: number of positions to select candidates from.\n                Smaller value will result in a higher matcher threshold and less matched candidates.\n            center_in_gt: If False (default), matched anchor center points do not need\n                to lie within the ground truth box. Recommend False for small objects.\n                If True, will result in a strict matcher and less matched candidates.\n        \"\"\"\n        self.proposal_matcher = ATSSMatcher(num_candidates, self.box_overlap_metric, center_in_gt, debug=self.debug)\n\n    def set_hard_negative_sampler(\n        self, batch_size_per_image: int, positive_fraction: float, min_neg: int = 1, pool_size: float = 10\n    ) -> None:\n        \"\"\"\n        Using for training. 
Set hard negative sampler that samples part of the anchors for training.\n\n        HardNegativeSampler is used to suppress false positive rate in classification tasks.\n        During training, it select negative samples with high prediction scores.\n\n        Args:\n            batch_size_per_image: number of elements to be selected per image\n            positive_fraction: percentage of positive elements in the selected samples\n            min_neg: minimum number of negative samples to select if possible.\n            pool_size: when we need ``num_neg`` hard negative samples, they will be randomly selected from\n                ``num_neg * pool_size`` negative samples with the highest prediction scores.\n                Larger ``pool_size`` gives more randomness, yet selects negative samples that are less 'hard',\n                i.e., negative samples with lower prediction scores.\n        \"\"\"\n        self.fg_bg_sampler = HardNegativeSampler(\n            batch_size_per_image=batch_size_per_image,\n            positive_fraction=positive_fraction,\n            min_neg=min_neg,\n            pool_size=pool_size,\n        )\n\n    def set_balanced_sampler(self, batch_size_per_image: int, positive_fraction: float) -> None:\n        \"\"\"\n        Using for training. 
Set torchvision balanced sampler that samples part of the anchors for training.\n\n        Args:\n            batch_size_per_image: number of elements to be selected per image\n            positive_fraction: percentage of positive elements per batch\n\n        \"\"\"\n        self.fg_bg_sampler = BalancedPositiveNegativeSampler(\n            batch_size_per_image=batch_size_per_image, positive_fraction=positive_fraction\n        )\n\n    def set_sliding_window_inferer(\n        self,\n        roi_size: Sequence[int] | int,\n        sw_batch_size: int = 1,\n        overlap: float = 0.5,\n        mode: BlendMode | str = BlendMode.CONSTANT,\n        sigma_scale: Sequence[float] | float = 0.125,\n        padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT,\n        cval: float = 0.0,\n        sw_device: torch.device | str | None = None,\n        device: torch.device | str | None = None,\n        progress: bool = False,\n        cache_roi_weight_map: bool = False,\n    ) -> None:\n        \"\"\"\n        Define sliding window inferer and store it to self.inferer.\n        \"\"\"\n        self.inferer = SlidingWindowInferer(\n            roi_size,\n            sw_batch_size,\n            overlap,\n            mode,\n            sigma_scale,\n            padding_mode,\n            cval,\n            sw_device,\n            device,\n            progress,\n            cache_roi_weight_map,\n        )\n\n    def set_box_selector_parameters(\n        self,\n        score_thresh: float = 0.05,\n        topk_candidates_per_level: int = 1000,\n        nms_thresh: float = 0.5,\n        detections_per_img: int = 300,\n        apply_sigmoid: bool = True,\n    ) -> None:\n        \"\"\"\n        Using for inference. Set the parameters that are used for box selection during inference.\n        The box selection is performed with the following steps:\n\n        #. For each level, discard boxes with scores less than self.score_thresh.\n        #. 
For each level, keep boxes with top self.topk_candidates_per_level scores.\n        #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.\n        #. For the whole image, keep boxes with top self.detections_per_img scores.\n\n        Args:\n            score_thresh: no box with scores less than score_thresh will be kept\n            topk_candidates_per_level: max number of boxes to keep for each level\n            nms_thresh: box overlapping threshold for NMS\n            detections_per_img: max number of boxes to keep for each image\n        \"\"\"\n\n        self.box_selector = BoxSelector(\n            box_overlap_metric=self.box_overlap_metric,\n            apply_sigmoid=apply_sigmoid,\n            score_thresh=score_thresh,\n            topk_candidates_per_level=topk_candidates_per_level,\n            nms_thresh=nms_thresh,\n            detections_per_img=detections_per_img,\n        )\n\n    def forward(\n        self,\n        input_images: list[Tensor] | Tensor,\n        targets: list[dict[str, Tensor]] | None = None,\n        use_inferer: bool = False,\n    ) -> dict[str, Tensor] | list[dict[str, Tensor]]:\n        \"\"\"\n        Returns a dict of losses during training, or a list of predicted dicts of boxes and labels during inference.\n\n        Args:\n            input_images: The input to the model is expected to be a list of tensors, each of shape (C, H, W) or  (C, H, W, D),\n                one for each image, and should be in 0-1 range. Different images can have different sizes.\n                Or it can also be a Tensor sized (B, C, H, W) or  (B, C, H, W, D). In this case, all images have same size.\n            targets: a list of dict. 
Each dict with two keys: self.target_box_key and self.target_label_key,\n                ground-truth boxes present in the image (optional).\n            use_inferer: whether to use self.inferer, a sliding window inferer, to do the inference.\n                If False, will simply forward the network.\n                If True, will use self.inferer, and requires\n                ``self.set_sliding_window_inferer(*args)`` to have been called before.\n\n        Return:\n            If training mode, will return a dict with at least two keys,\n            including self.cls_key and self.box_reg_key, representing classification loss and box regression loss.\n\n            If evaluation mode, will return a list of detection results.\n            Each element corresponds to an image in ``input_images`` and is a dict with at least three keys,\n            including self.target_box_key, self.target_label_key, self.pred_score_key,\n            representing predicted boxes, classification labels, and classification scores.\n\n        \"\"\"\n        # 1. Check if input arguments are valid\n        if self.training:\n            targets = check_training_targets(\n                input_images, targets, self.spatial_dims, self.target_label_key, self.target_box_key\n            )\n            self._check_detector_training_components()\n\n        # 2. Pad list of images to a single Tensor `images` with spatial size divisible by self.size_divisible.\n        # image_sizes stores the original spatial_size of each image before padding.\n        images, image_sizes = preprocess_images(input_images, self.spatial_dims, self.size_divisible)\n\n        # 3. Generate network outputs. 
Use inferer only in evaluation mode.\n        if self.training or (not use_inferer):\n            head_outputs = self.network(images)\n            if isinstance(head_outputs, (tuple, list)):\n                tmp_dict = {}\n                tmp_dict[self.cls_key] = head_outputs[: len(head_outputs) // 2]\n                tmp_dict[self.box_reg_key] = head_outputs[len(head_outputs) // 2 :]\n                head_outputs = tmp_dict\n            else:\n                # ensure head_outputs is Dict[str, List[Tensor]]\n                ensure_dict_value_to_list_(head_outputs)\n        else:\n            if self.inferer is None:\n                raise ValueError(\n                    \"`self.inferer` is not defined.\" \"Please refer to function self.set_sliding_window_inferer(*).\"\n                )\n            head_outputs = predict_with_inferer(\n                images, self.network, keys=[self.cls_key, self.box_reg_key], inferer=self.inferer\n            )\n\n        # 4. Generate anchors and store it in self.anchors: List[Tensor]\n        self.generate_anchors(images, head_outputs)\n        # num_anchor_locs_per_level: List[int], list of HW or HWD for each level\n        num_anchor_locs_per_level = [x.shape[2:].numel() for x in head_outputs[self.cls_key]]\n\n        # 5. Reshape and concatenate head_outputs values from List[Tensor] to Tensor\n        # head_outputs, originally being Dict[str, List[Tensor]], will be reshaped to Dict[str, Tensor]\n        for key in [self.cls_key, self.box_reg_key]:\n            # reshape to Tensor sized(B, sum(HWA), self.num_classes) for self.cls_key\n            # or (B, sum(HWA), 2* self.spatial_dims) for self.box_reg_key\n            # A = self.num_anchors_per_loc\n            head_outputs[key] = self._reshape_maps(head_outputs[key])\n\n        # 6(1). 
If during training, return losses\n        if self.training:\n            losses = self.compute_loss(head_outputs, targets, self.anchors, num_anchor_locs_per_level)  # type: ignore\n            return losses\n\n        # 6(2). If during inference, return detection results\n        detections = self.postprocess_detections(\n            head_outputs, self.anchors, image_sizes, num_anchor_locs_per_level  # type: ignore\n        )\n        return detections\n\n    def _check_detector_training_components(self):\n        \"\"\"\n        Check if self.proposal_matcher and self.fg_bg_sampler have been set for training.\n        \"\"\"\n        if not hasattr(self, \"proposal_matcher\"):\n            raise AttributeError(\n                \"Matcher is not set. Please refer to self.set_regular_matcher(*) or self.set_atss_matcher(*).\"\n            )\n        if self.fg_bg_sampler is None and self.debug:\n            warnings.warn(\n                \"No balanced sampler is used. Negative samples are likely to \"\n                \"be much more than positive samples. Please set balanced samplers with self.set_balanced_sampler(*) \"\n                \"or self.set_hard_negative_sampler(*), \"\n                \"or set classification loss function as Focal loss with self.set_cls_loss(*)\"\n            )\n\n    def generate_anchors(self, images: Tensor, head_outputs: dict[str, list[Tensor]]) -> None:\n        \"\"\"\n        Generate anchors and store it in self.anchors: List[Tensor].\n        We generate anchors only when there is no stored anchors,\n        or the new coming images has different shape with self.previous_image_shape\n\n        Args:\n            images: input images, a (B, C, H, W) or (B, C, H, W, D) Tensor.\n            head_outputs: head_outputs. ``head_output_reshape[self.cls_key]`` is a Tensor\n              sized (B, sum(HW(D)A), self.num_classes). 
``head_output_reshape[self.box_reg_key]`` is a Tensor\n              sized (B, sum(HW(D)A), 2*self.spatial_dims)\n        \"\"\"\n        if (self.anchors is None) or (self.previous_image_shape != images.shape):\n            self.anchors = self.anchor_generator(images, head_outputs[self.cls_key])  # List[Tensor], len = batchsize\n            self.previous_image_shape = images.shape\n\n    def _reshape_maps(self, result_maps: list[Tensor]) -> Tensor:\n        \"\"\"\n        Concat network output map list to a single Tensor.\n        This function is used in both training and inference.\n\n        Args:\n            result_maps: a list of Tensor, each Tensor is a (B, num_channel*A, H, W) or (B, num_channel*A, H, W, D) map.\n                A = self.num_anchors_per_loc\n\n        Return:\n            reshaped and concatenated result, sized (B, sum(HWA), num_channel) or (B, sum(HWDA), num_channel)\n        \"\"\"\n        all_reshaped_result_map = []\n\n        for result_map in result_maps:\n            batch_size = result_map.shape[0]\n            num_channel = result_map.shape[1] // self.num_anchors_per_loc\n            spatial_size = result_map.shape[-self.spatial_dims :]\n\n            # reshaped_result_map will become (B, A, num_channel, H, W) or (B, A, num_channel, H, W, D)\n            # A = self.num_anchors_per_loc\n            view_shape = (batch_size, -1, num_channel) + spatial_size\n            reshaped_result_map = result_map.view(view_shape)\n\n            # permute output to (B, H, W, A, num_channel) or (B, H, W, D, A, num_channel)\n            if self.spatial_dims == 2:\n                reshaped_result_map = reshaped_result_map.permute(0, 3, 4, 1, 2)\n            elif self.spatial_dims == 3:\n                reshaped_result_map = reshaped_result_map.permute(0, 3, 4, 5, 1, 2)\n            else:\n                raise ValueError(\"Images can only be 2D or 3D.\")\n\n            # reshaped_result_map will become (B, HWA, num_channel) or (B, HWDA, 
num_channel)\n            reshaped_result_map = reshaped_result_map.reshape(batch_size, -1, num_channel)\n\n            if torch.isnan(reshaped_result_map).any() or torch.isinf(reshaped_result_map).any():\n                if torch.is_grad_enabled():\n                    raise ValueError(\"Concatenated result is NaN or Inf.\")\n                else:\n                    warnings.warn(\"Concatenated result is NaN or Inf.\")\n\n            all_reshaped_result_map.append(reshaped_result_map)\n\n        return torch.cat(all_reshaped_result_map, dim=1)\n\n    def postprocess_detections(\n        self,\n        head_outputs_reshape: dict[str, Tensor],\n        anchors: list[Tensor],\n        image_sizes: list[list[int]],\n        num_anchor_locs_per_level: Sequence[int],\n        need_sigmoid: bool = True,\n    ) -> list[dict[str, Tensor]]:\n        \"\"\"\n        Postprocessing to generate detection result from classification logits and box regression.\n        Use self.box_selector to select the final output boxes for each image.\n\n        Args:\n            head_outputs_reshape: reshaped head_outputs. ``head_output_reshape[self.cls_key]`` is a Tensor\n              sized (B, sum(HW(D)A), self.num_classes). ``head_output_reshape[self.box_reg_key]`` is a Tensor\n              sized (B, sum(HW(D)A), 2*self.spatial_dims)\n            targets: a list of dict. Each dict with two keys: self.target_box_key and self.target_label_key,\n                ground-truth boxes present in the image.\n            anchors: a list of Tensor. 
Each Tensor represents anchors for each image,\n                sized (sum(HWA), 2*spatial_dims) or (sum(HWDA), 2*spatial_dims).\n                A = self.num_anchors_per_loc.\n\n        Return:\n            a list of dict, each dict corresponds to detection result on image.\n        \"\"\"\n\n        # recover level sizes, HWA or HWDA for each level\n        num_anchors_per_level = [\n            num_anchor_locs * self.num_anchors_per_loc for num_anchor_locs in num_anchor_locs_per_level\n        ]\n\n        # split outputs per level\n        split_head_outputs: dict[str, list[Tensor]] = {}\n        for k in head_outputs_reshape:\n            split_head_outputs[k] = list(head_outputs_reshape[k].split(num_anchors_per_level, dim=1))\n        split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]  # List[List[Tensor]]\n\n        class_logits = split_head_outputs[self.cls_key]  # List[Tensor], each sized (B, HWA, self.num_classes)\n        box_regression = split_head_outputs[self.box_reg_key]  # List[Tensor], each sized (B, HWA, 2*spatial_dims)\n        compute_dtype = class_logits[0].dtype\n\n        num_images = len(image_sizes)  # B\n\n        detections: list[dict[str, Tensor]] = []\n\n        for index in range(num_images):\n            box_regression_per_image = [\n                br[index] for br in box_regression\n            ]  # List[Tensor], each sized (HWA, 2*spatial_dims)\n            logits_per_image = [cl[index] for cl in class_logits]  # List[Tensor], each sized (HWA, self.num_classes)\n            anchors_per_image, img_spatial_size = split_anchors[index], image_sizes[index]\n            # decode box regression into boxes\n            boxes_per_image = [\n                self.box_coder.decode_single(b.to(torch.float32), a).to(compute_dtype)\n                for b, a in zip(box_regression_per_image, anchors_per_image)\n            ]  # List[Tensor], each sized (HWA, 2*spatial_dims)\n\n            selected_boxes, selected_scores, 
selected_labels = self.box_selector.select_boxes_per_image(\n                boxes_per_image, logits_per_image, img_spatial_size\n            )\n\n            detections.append(\n                {\n                    self.target_box_key: selected_boxes,  # Tensor, sized (N, 2*spatial_dims)\n                    self.pred_score_key: selected_scores,  # Tensor, sized (N, )\n                    self.target_label_key: selected_labels,  # Tensor, sized (N, )\n                }\n            )\n\n        return detections\n\n    def compute_loss(\n        self,\n        head_outputs_reshape: dict[str, Tensor],\n        targets: list[dict[str, Tensor]],\n        anchors: list[Tensor],\n        num_anchor_locs_per_level: Sequence[int],\n    ) -> dict[str, Tensor]:\n        \"\"\"\n        Compute losses.\n\n        Args:\n            head_outputs_reshape: reshaped head_outputs. ``head_output_reshape[self.cls_key]`` is a Tensor\n              sized (B, sum(HW(D)A), self.num_classes). ``head_output_reshape[self.box_reg_key]`` is a Tensor\n              sized (B, sum(HW(D)A), 2*self.spatial_dims)\n            targets: a list of dict. Each dict with two keys: self.target_box_key and self.target_label_key,\n                ground-truth boxes present in the image.\n            anchors: a list of Tensor. 
Each Tensor represents anchors for each image,\n                sized (sum(HWA), 2*spatial_dims) or (sum(HWDA), 2*spatial_dims).\n                A = self.num_anchors_per_loc.\n\n        Return:\n            a dict of several kinds of losses.\n        \"\"\"\n        matched_idxs = self.compute_anchor_matched_idxs(anchors, targets, num_anchor_locs_per_level)\n        losses_cls = self.compute_cls_loss(head_outputs_reshape[self.cls_key], targets, matched_idxs)\n        losses_box_regression = self.compute_box_loss(\n            head_outputs_reshape[self.box_reg_key], targets, anchors, matched_idxs\n        )\n        return {self.cls_key: losses_cls, self.box_reg_key: losses_box_regression}\n\n    def compute_anchor_matched_idxs(\n        self, anchors: list[Tensor], targets: list[dict[str, Tensor]], num_anchor_locs_per_level: Sequence[int]\n    ) -> list[Tensor]:\n        \"\"\"\n        Compute the matched indices between anchors and ground truth (gt) boxes in targets.\n        output[k][i] represents the matched gt index for anchor[i] in image k.\n        Suppose there are M gt boxes for image k. The value of output[k][i] lies in [-2, -1, 0, ..., M-1].\n        [0, M - 1] indicates this anchor is matched with a gt box,\n        while a negative value indicates that it is not matched.\n\n        Args:\n            anchors: a list of Tensor. Each Tensor represents anchors for each image,\n                sized (sum(HWA), 2*spatial_dims) or (sum(HWDA), 2*spatial_dims).\n                A = self.num_anchors_per_loc.\n            targets: a list of dict. Each dict with two keys: self.target_box_key and self.target_label_key,\n                ground-truth boxes present in the image.\n            num_anchor_locs_per_level: each element represents HW or HWD at this level.\n\n\n        Return:\n            a list of matched index `matched_idxs_per_image` (Tensor[int64]), Tensor sized (sum(HWA),) or (sum(HWDA),).\n            Suppose there are M gt boxes. 
`matched_idxs_per_image[i]` is a matched gt index in [0, M - 1]\n            or a negative value indicating that anchor i could not be matched.\n            BELOW_LOW_THRESHOLD = -1, BETWEEN_THRESHOLDS = -2\n        \"\"\"\n        matched_idxs = []\n        for anchors_per_image, targets_per_image in zip(anchors, targets):\n            # anchors_per_image: Tensor, targets_per_image: Dict[str, Tensor]\n            if targets_per_image[self.target_box_key].numel() == 0:\n                # if no GT boxes\n                matched_idxs.append(\n                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)\n                )\n                continue\n\n            # matched_idxs_per_image (Tensor[int64]): Tensor sized (sum(HWA),) or (sum(HWDA),)\n            # Suppose there are M gt boxes. matched_idxs_per_image[i] is a matched gt index in [0, M - 1]\n            # or a negative value indicating that anchor i could not be matched.\n            # BELOW_LOW_THRESHOLD = -1, BETWEEN_THRESHOLDS = -2\n            if isinstance(self.proposal_matcher, Matcher):\n                # if torchvision matcher\n                match_quality_matrix = self.box_overlap_metric(\n                    targets_per_image[self.target_box_key].to(anchors_per_image.device), anchors_per_image\n                )\n                matched_idxs_per_image = self.proposal_matcher(match_quality_matrix)\n            elif isinstance(self.proposal_matcher, ATSSMatcher):\n                # if monai ATSS matcher\n                match_quality_matrix, matched_idxs_per_image = self.proposal_matcher(\n                    targets_per_image[self.target_box_key].to(anchors_per_image.device),\n                    anchors_per_image,\n                    num_anchor_locs_per_level,\n                    self.num_anchors_per_loc,\n                )\n            else:\n                raise NotImplementedError(\n                    \"Currently support torchvision 
Matcher and monai ATSS matcher. Other types of matcher not supported. \"\n                    \"Please override self.compute_anchor_matched_idxs(*) for your own matcher.\"\n                )\n\n            if self.debug:\n                print(f\"Max box overlap between anchors and gt boxes: {torch.max(match_quality_matrix, dim=1)[0]}.\")\n\n            if torch.max(matched_idxs_per_image) < 0:\n                warnings.warn(\n                    f\"No anchor is matched with GT boxes. Please adjust matcher setting, anchor setting,\"\n                    \" or the network setting to change zoom scale between network output and input images.\"\n                    f\"GT boxes are {targets_per_image[self.target_box_key]}.\"\n                )\n\n            matched_idxs.append(matched_idxs_per_image)\n        return matched_idxs\n\n    def compute_cls_loss(\n        self, cls_logits: Tensor, targets: list[dict[str, Tensor]], matched_idxs: list[Tensor]\n    ) -> Tensor:\n        \"\"\"\n        Compute classification losses.\n\n        Args:\n            cls_logits: classification logits, sized (B, sum(HW(D)A), self.num_classes)\n            targets: a list of dict. Each dict with two keys: self.target_box_key and self.target_label_key,\n                ground-truth boxes present in the image.\n            matched_idxs: a list of matched index. 
each element is sized (sum(HWA),) or  (sum(HWDA),)\n\n        Return:\n            classification losses.\n        \"\"\"\n        total_cls_logits_list = []\n        total_gt_classes_target_list = []\n        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):\n            # for each image, get training samples\n            sampled_cls_logits_per_image, sampled_gt_classes_target = self.get_cls_train_sample_per_image(\n                cls_logits_per_image, targets_per_image, matched_idxs_per_image\n            )\n            total_cls_logits_list.append(sampled_cls_logits_per_image)\n            total_gt_classes_target_list.append(sampled_gt_classes_target)\n\n        total_cls_logits = torch.cat(total_cls_logits_list, dim=0)\n        total_gt_classes_target = torch.cat(total_gt_classes_target_list, dim=0)\n        losses: Tensor = self.cls_loss_func(total_cls_logits, total_gt_classes_target).to(total_cls_logits.dtype)\n        return losses\n\n    def compute_box_loss(\n        self,\n        box_regression: Tensor,\n        targets: list[dict[str, Tensor]],\n        anchors: list[Tensor],\n        matched_idxs: list[Tensor],\n    ) -> Tensor:\n        \"\"\"\n        Compute box regression losses.\n\n        Args:\n            box_regression: box regression results, sized (B, sum(HWA), 2*self.spatial_dims)\n            targets: a list of dict. Each dict with two keys: self.target_box_key and self.target_label_key,\n                ground-truth boxes present in the image.\n            anchors: a list of Tensor. Each Tensor represents anchors for each image,\n                sized (sum(HWA), 2*spatial_dims) or (sum(HWDA), 2*spatial_dims).\n                A = self.num_anchors_per_loc.\n            matched_idxs: a list of matched index. 
each element is sized (sum(HWA),) or  (sum(HWDA),)\n\n        Return:\n            box regression losses.\n        \"\"\"\n        total_box_regression_list = []\n        total_target_regression_list = []\n\n        for targets_per_image, box_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(\n            targets, box_regression, anchors, matched_idxs\n        ):\n            # for each image, get training samples\n            decode_box_regression_per_image, matched_gt_boxes_per_image = self.get_box_train_sample_per_image(\n                box_regression_per_image, targets_per_image, anchors_per_image, matched_idxs_per_image\n            )\n            total_box_regression_list.append(decode_box_regression_per_image)\n            total_target_regression_list.append(matched_gt_boxes_per_image)\n\n        total_box_regression = torch.cat(total_box_regression_list, dim=0)\n        total_target_regression = torch.cat(total_target_regression_list, dim=0)\n\n        if total_box_regression.shape[0] == 0:\n            # if there is no training sample.\n            losses = torch.tensor(0.0)\n            return losses\n\n        losses = self.box_loss_func(total_box_regression, total_target_regression).to(total_box_regression.dtype)\n\n        return losses\n\n    def get_cls_train_sample_per_image(\n        self, cls_logits_per_image: Tensor, targets_per_image: dict[str, Tensor], matched_idxs_per_image: Tensor\n    ) -> tuple[Tensor, Tensor]:\n        \"\"\"\n        Get samples from one image for classification losses computation.\n\n        Args:\n            cls_logits_per_image: classification logits for one image, (sum(HWA), self.num_classes)\n            targets_per_image: a dict with at least two keys: self.target_box_key and self.target_label_key,\n                ground-truth boxes present in the image.\n            matched_idxs_per_image: matched index, Tensor sized (sum(HWA),) or (sum(HWDA),)\n                Suppose there are M gt boxes. 
matched_idxs_per_image[i] is a matched gt index in [0, M - 1]\n                or a negative value indicating that anchor i could not be matched.\n                BELOW_LOW_THRESHOLD = -1, BETWEEN_THRESHOLDS = -2\n\n        Return:\n            paired predicted and GT samples from one image for classification losses computation\n        \"\"\"\n\n        if torch.isnan(cls_logits_per_image).any() or torch.isinf(cls_logits_per_image).any():\n            if torch.is_grad_enabled():\n                raise ValueError(\"NaN or Inf in predicted classification logits.\")\n            else:\n                warnings.warn(\"NaN or Inf in predicted classification logits.\")\n\n        foreground_idxs_per_image = matched_idxs_per_image >= 0\n\n        num_foreground = int(foreground_idxs_per_image.sum())\n        num_gt_box = targets_per_image[self.target_box_key].shape[0]\n\n        if self.debug:\n            print(f\"Number of positive (matched) anchors: {num_foreground}; Number of GT box: {num_gt_box}.\")\n            if num_gt_box > 0 and num_foreground < 2 * num_gt_box:\n                print(\n                    f\"Only {num_foreground} anchors are matched with {num_gt_box} GT boxes. 
\"\n                    \"Please consider adjusting matcher setting, anchor setting,\"\n                    \" or the network setting to change zoom scale between network output and input images.\"\n                )\n\n        # create the target classification with one-hot encoding\n        gt_classes_target = torch.zeros_like(cls_logits_per_image)  # (sum(HW(D)A), self.num_classes)\n        gt_classes_target[\n            foreground_idxs_per_image,  # fg anchor idx in\n            targets_per_image[self.target_label_key][\n                matched_idxs_per_image[foreground_idxs_per_image]\n            ],  # fg class label\n        ] = 1.0\n\n        if self.fg_bg_sampler is None:\n            # if no balanced sampling\n            valid_idxs_per_image = matched_idxs_per_image != self.proposal_matcher.BETWEEN_THRESHOLDS\n        else:\n            # The input of fg_bg_sampler: list of tensors containing -1, 0 or positive values.\n            # Each tensor corresponds to a specific image.\n            # -1 values are ignored, 0 are considered as negatives and > 0 as positives.\n\n            # matched_idxs_per_image (Tensor[int64]): an N tensor where N[i] is a matched gt in\n            # [0, M - 1] or a negative value indicating that prediction i could not\n            # be matched. 
BELOW_LOW_THRESHOLD = -1, BETWEEN_THRESHOLDS = -2\n            if isinstance(self.fg_bg_sampler, HardNegativeSampler):\n                max_cls_logits_per_image = torch.max(cls_logits_per_image.to(torch.float32), dim=1)[0]\n                sampled_pos_inds_list, sampled_neg_inds_list = self.fg_bg_sampler(\n                    [matched_idxs_per_image + 1], max_cls_logits_per_image\n                )\n            elif isinstance(self.fg_bg_sampler, BalancedPositiveNegativeSampler):\n                sampled_pos_inds_list, sampled_neg_inds_list = self.fg_bg_sampler([matched_idxs_per_image + 1])\n            else:\n                raise NotImplementedError(\n                    \"Currently support torchvision BalancedPositiveNegativeSampler and monai HardNegativeSampler matcher. \"\n                    \"Other types of sampler not supported. \"\n                    \"Please override self.get_cls_train_sample_per_image(*) for your own sampler.\"\n                )\n\n            sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds_list, dim=0))[0]\n            sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds_list, dim=0))[0]\n            valid_idxs_per_image = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)\n\n        return cls_logits_per_image[valid_idxs_per_image, :], gt_classes_target[valid_idxs_per_image, :]\n\n    def get_box_train_sample_per_image(\n        self,\n        box_regression_per_image: Tensor,\n        targets_per_image: dict[str, Tensor],\n        anchors_per_image: Tensor,\n        matched_idxs_per_image: Tensor,\n    ) -> tuple[Tensor, Tensor]:\n        \"\"\"\n        Get samples from one image for box regression losses computation.\n\n        Args:\n            box_regression_per_image: box regression result for one image, (sum(HWA), 2*self.spatial_dims)\n            targets_per_image: a dict with at least two keys: self.target_box_key and self.target_label_key,\n                ground-truth boxes present in the image.\n       
     anchors_per_image: anchors of one image,\n                sized (sum(HWA), 2*spatial_dims) or (sum(HWDA), 2*spatial_dims).\n                A = self.num_anchors_per_loc.\n            matched_idxs_per_image: matched index, sized (sum(HWA),) or  (sum(HWDA),)\n\n        Return:\n            paired predicted and GT samples from one image for box regression losses computation\n        \"\"\"\n\n        if torch.isnan(box_regression_per_image).any() or torch.isinf(box_regression_per_image).any():\n            if torch.is_grad_enabled():\n                raise ValueError(\"NaN or Inf in predicted box regression.\")\n            else:\n                warnings.warn(\"NaN or Inf in predicted box regression.\")\n\n        foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]\n        num_gt_box = targets_per_image[self.target_box_key].shape[0]\n\n        # if no GT box, return empty arrays\n        if num_gt_box == 0:\n            return box_regression_per_image[0:0, :], box_regression_per_image[0:0, :]\n\n        # select only the foreground boxes\n        # matched GT boxes for foreground anchors\n        matched_gt_boxes_per_image = targets_per_image[self.target_box_key][\n            matched_idxs_per_image[foreground_idxs_per_image]\n        ].to(box_regression_per_image.device)\n        # predicted box regression for foreground anchors\n        box_regression_per_image = box_regression_per_image[foreground_idxs_per_image, :]\n        # foreground anchors\n        anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]\n\n        # encode GT boxes or decode predicted box regression before computing losses\n        matched_gt_boxes_per_image_ = matched_gt_boxes_per_image\n        box_regression_per_image_ = box_regression_per_image\n        if self.encode_gt:\n            matched_gt_boxes_per_image_ = self.box_coder.encode_single(matched_gt_boxes_per_image_, anchors_per_image)\n        if self.decode_pred:\n            
box_regression_per_image_ = self.box_coder.decode_single(box_regression_per_image_, anchors_per_image)\n\n        return box_regression_per_image_, matched_gt_boxes_per_image_\n\n\ndef retinanet_resnet50_fpn_detector(\n    num_classes: int,\n    anchor_generator: AnchorGenerator,\n    returned_layers: Sequence[int] = (1, 2, 3),\n    pretrained: bool = False,\n    progress: bool = True,\n    **kwargs: Any,\n) -> RetinaNetDetector:\n    \"\"\"\n    Returns a RetinaNet detector using a ResNet-50 as backbone, which can be pretrained from\n    `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.\n\n    Args:\n        num_classes: number of output classes of the model (excluding the background).\n        anchor_generator: AnchorGenerator,\n        returned_layers: returned layers to extract feature maps. Each returned layer should be in the range [1,4].\n            len(returned_layers)+1 will be the number of extracted feature maps.\n            There is an extra maxpooling layer LastLevelMaxPool() appended.\n        pretrained: If True, returns a backbone pre-trained on 23 medical datasets\n        progress: If True, displays a progress bar of the download to stderr\n\n    Return:\n        A RetinaNetDetector object with resnet50 as backbone\n\n    Example:\n\n        .. 
code-block:: python\n\n            # define a naive network\n            resnet_param = {\n                \"pretrained\": False,\n                \"spatial_dims\": 3,\n                \"n_input_channels\": 2,\n                \"num_classes\": 3,\n                \"conv1_t_size\": 7,\n                \"conv1_t_stride\": (2, 2, 2)\n            }\n            returned_layers = [1]\n            anchor_generator = monai.apps.detection.utils.anchor_utils.AnchorGeneratorWithAnchorShape(\n                feature_map_scales=(1, 2), base_anchor_shapes=((8,) * resnet_param[\"spatial_dims\"])\n            )\n            detector = retinanet_resnet50_fpn_detector(\n                **resnet_param, anchor_generator=anchor_generator, returned_layers=returned_layers\n            )\n    \"\"\"\n\n    backbone = resnet.resnet50(pretrained, progress, **kwargs)\n    spatial_dims = len(backbone.conv1.stride)\n    # number of output feature maps is len(returned_layers)+1\n    feature_extractor = resnet_fpn_feature_extractor(\n        backbone=backbone,\n        spatial_dims=spatial_dims,\n        pretrained_backbone=pretrained,\n        trainable_backbone_layers=None,\n        returned_layers=returned_layers,\n    )\n    num_anchors = anchor_generator.num_anchors_per_location()[0]\n    size_divisible = [s * 2 * 2 ** max(returned_layers) for s in feature_extractor.body.conv1.stride]\n    network = RetinaNet(\n        spatial_dims=spatial_dims,\n        num_classes=num_classes,\n        num_anchors=num_anchors,\n        feature_extractor=feature_extractor,\n        size_divisible=size_divisible,\n    )\n    return RetinaNetDetector(network, anchor_generator)\n"
  },
  {
    "path": "monai/apps/detection/networks/retinanet_network.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py\n# which has the following license...\n# https://github.com/pytorch/vision/blob/main/LICENSE\n\n# BSD 3-Clause License\n\n# Copyright (c) Soumith Chintala 2016,\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# * Redistributions of source code must retain the above copyright notice, this\n#   list of conditions and the following disclaimer.\n\n# * Redistributions in binary form must reproduce the above copyright notice,\n#   this list of conditions and the following disclaimer in the documentation\n#   and/or other materials provided with the distribution.\n\n# * Neither the name of the copyright holder nor the names of its\n#   contributors may be used to endorse or promote products derived from\n#   this software without specific prior written permission.\n\"\"\"\nPart of this script is adapted from\nhttps://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py\n\"\"\"\n\nfrom __future__ import annotations\n\nimport math\nimport warnings\nfrom collections.abc import Callable, Sequence\nfrom typing import Any\n\nimport torch\nfrom torch import 
Tensor, nn\n\nfrom monai.networks.blocks.backbone_fpn_utils import BackboneWithFPN, _resnet_fpn_extractor\nfrom monai.networks.layers.factories import Conv\nfrom monai.networks.nets import resnet\nfrom monai.utils import ensure_tuple_rep, look_up_option, optional_import\n\n_validate_trainable_layers, _ = optional_import(\n    \"torchvision.models.detection.backbone_utils\", name=\"_validate_trainable_layers\"\n)\n\n\nclass RetinaNetClassificationHead(nn.Module):\n    \"\"\"\n    A classification head for use in RetinaNet.\n\n    This head takes a list of feature maps as inputs, and outputs a list of classification maps.\n    Each output map has same spatial size with the corresponding input feature map,\n    and the number of output channel is num_anchors * num_classes.\n\n    Args:\n        in_channels: number of channels of the input feature\n        num_anchors: number of anchors to be predicted\n        num_classes: number of classes to be predicted\n        spatial_dims: spatial dimension of the network, should be 2 or 3.\n        prior_probability: prior probability to initialize classification convolutional layers.\n    \"\"\"\n\n    def __init__(\n        self, in_channels: int, num_anchors: int, num_classes: int, spatial_dims: int, prior_probability: float = 0.01\n    ):\n        super().__init__()\n\n        conv_type: Callable = Conv[Conv.CONV, spatial_dims]\n        conv = []\n        for _ in range(4):\n            conv.append(conv_type(in_channels, in_channels, kernel_size=3, stride=1, padding=1))\n            conv.append(nn.GroupNorm(num_groups=8, num_channels=in_channels))\n            conv.append(nn.ReLU())\n        self.conv = nn.Sequential(*conv)\n\n        for layer in self.conv.children():\n            if isinstance(layer, conv_type):  # type: ignore\n                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]\n                torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]\n\n        self.cls_logits 
= conv_type(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)\n        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)\n        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))\n\n        self.num_classes = num_classes\n        self.num_anchors = num_anchors\n\n    def forward(self, x: list[Tensor]) -> list[Tensor]:\n        \"\"\"\n        It takes a list of feature maps as inputs, and outputs a list of classification maps.\n        Each output classification map has same spatial size with the corresponding input feature map,\n        and the number of output channel is num_anchors * num_classes.\n\n        Args:\n            x: list of feature map, x[i] is a (B, in_channels, H_i, W_i) or (B, in_channels, H_i, W_i, D_i) Tensor.\n\n        Return:\n            cls_logits_maps, list of classification map. cls_logits_maps[i] is a\n            (B, num_anchors * num_classes, H_i, W_i) or (B, num_anchors * num_classes, H_i, W_i, D_i) Tensor.\n\n        \"\"\"\n        cls_logits_maps = []\n\n        if isinstance(x, Tensor):\n            feature_maps = [x]\n        else:\n            feature_maps = x\n\n        for features in feature_maps:\n            cls_logits = self.conv(features)\n            cls_logits = self.cls_logits(cls_logits)\n\n            cls_logits_maps.append(cls_logits)\n\n            if not torch.compiler.is_compiling():\n                if torch.isnan(cls_logits).any() or torch.isinf(cls_logits).any():\n                    if torch.is_grad_enabled():\n                        raise ValueError(\"cls_logits is NaN or Inf.\")\n                    else:\n                        warnings.warn(\"cls_logits is NaN or Inf.\")\n\n        return cls_logits_maps\n\n\nclass RetinaNetRegressionHead(nn.Module):\n    \"\"\"\n    A regression head for use in RetinaNet.\n\n    This head takes a list of feature maps as inputs, and outputs a list of box regression maps.\n    Each 
output box regression map has same spatial size with the corresponding input feature map,\n    and the number of output channel is num_anchors * 2 * spatial_dims.\n\n    Args:\n        in_channels: number of channels of the input feature\n        num_anchors: number of anchors to be predicted\n        spatial_dims: spatial dimension of the network, should be 2 or 3.\n    \"\"\"\n\n    def __init__(self, in_channels: int, num_anchors: int, spatial_dims: int):\n        super().__init__()\n\n        conv_type: Callable = Conv[Conv.CONV, spatial_dims]\n\n        conv = []\n        for _ in range(4):\n            conv.append(conv_type(in_channels, in_channels, kernel_size=3, stride=1, padding=1))\n            conv.append(nn.GroupNorm(num_groups=8, num_channels=in_channels))\n            conv.append(nn.ReLU())\n\n        self.conv = nn.Sequential(*conv)\n\n        self.bbox_reg = conv_type(in_channels, num_anchors * 2 * spatial_dims, kernel_size=3, stride=1, padding=1)\n        torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)\n        torch.nn.init.zeros_(self.bbox_reg.bias)\n\n        for layer in self.conv.children():\n            if isinstance(layer, conv_type):  # type: ignore\n                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]\n                torch.nn.init.zeros_(layer.bias)  # type: ignore[arg-type]\n\n    def forward(self, x: list[Tensor]) -> list[Tensor]:\n        \"\"\"\n        It takes a list of feature maps as inputs, and outputs a list of box regression maps.\n        Each output box regression map has same spatial size with the corresponding input feature map,\n        and the number of output channel is num_anchors * 2 * spatial_dims.\n\n        Args:\n            x: list of feature map, x[i] is a (B, in_channels, H_i, W_i) or (B, in_channels, H_i, W_i, D_i) Tensor.\n\n        Return:\n            box_regression_maps, list of box regression map. 
box_regression_maps[i] is a\n            (B, num_anchors * 2 * spatial_dims, H_i, W_i) or (B, num_anchors * 2 * spatial_dims, H_i, W_i, D_i) Tensor.\n\n        \"\"\"\n        box_regression_maps = []\n\n        if isinstance(x, Tensor):\n            feature_maps = [x]\n        else:\n            feature_maps = x\n\n        for features in feature_maps:\n            box_regression = self.conv(features)\n            box_regression = self.bbox_reg(box_regression)\n\n            box_regression_maps.append(box_regression)\n\n            if not torch.compiler.is_compiling():\n                if torch.isnan(box_regression).any() or torch.isinf(box_regression).any():\n                    if torch.is_grad_enabled():\n                        raise ValueError(\"box_regression is NaN or Inf.\")\n                    else:\n                        warnings.warn(\"box_regression is NaN or Inf.\")\n\n        return box_regression_maps\n\n\nclass RetinaNet(nn.Module):\n    \"\"\"\n    The network used in RetinaNet.\n\n    It takes an image tensor as inputs, and outputs either 1) a dictionary ``head_outputs``.\n    ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor.\n    ``head_outputs[self.box_reg_key]`` is the predicted box regression maps, a list of Tensor.\n    or 2) a list of 2N tensors ``head_outputs``, with first N tensors being the predicted\n    classification maps and second N tensors being the predicted box regression maps.\n\n    Args:\n        spatial_dims: number of spatial dimensions of the images. 
We support both 2D and 3D images.\n        num_classes: number of output classes of the model (excluding the background).\n        num_anchors: number of anchors at each location.\n        feature_extractor: a network that outputs feature maps from the input images,\n            each feature map corresponds to a different resolution.\n            Its output can have a format of Tensor, Dict[Any, Tensor], or Sequence[Tensor].\n            It can be the output of ``resnet_fpn_feature_extractor(*args, **kwargs)``.\n        size_divisible: the spatial size of the network input should be divisible by size_divisible,\n            decided by the feature_extractor.\n        use_list_output: default False. If False, the network outputs a dictionary ``head_outputs``,\n            ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor.\n            ``head_outputs[self.box_reg_key]`` is the predicted box regression maps, a list of Tensor.\n            If True, the network outputs a list of 2N tensors ``head_outputs``, with first N tensors being\n            the predicted classification maps and second N tensors being the predicted box regression maps.\n\n    Example:\n\n        .. 
code-block:: python\n\n            from monai.networks.nets import resnet\n            spatial_dims = 3  # 3D network\n            conv1_t_stride = (2,2,1)  # stride of first convolutional layer in backbone\n            backbone = resnet.ResNet(\n                spatial_dims = spatial_dims,\n                block = resnet.ResNetBottleneck,\n                layers = [3, 4, 6, 3],\n                block_inplanes = resnet.get_inplanes(),\n                n_input_channels= 1,\n                conv1_t_stride = conv1_t_stride,\n                conv1_t_size = (7,7,7),\n            )\n            # This feature_extractor outputs 4-level feature maps.\n            # number of output feature maps is len(returned_layers)+1\n            returned_layers = [1,2,3]  # returned layer from feature pyramid network\n            feature_extractor = resnet_fpn_feature_extractor(\n                backbone = backbone,\n                spatial_dims = spatial_dims,\n                pretrained_backbone = False,\n                trainable_backbone_layers = None,\n                returned_layers = returned_layers,\n            )\n            # This feature_extractor requires input image spatial size\n            # to be divisible by (32, 32, 16).\n            size_divisible = tuple(2*s*2**max(returned_layers) for s in conv1_t_stride)\n            model = RetinaNet(\n                spatial_dims = spatial_dims,\n                num_classes = 5,\n                num_anchors = 6,\n                feature_extractor=feature_extractor,\n                size_divisible = size_divisible,\n            ).to(device)\n            result = model(torch.rand(2, 1, 128,128,128))\n            cls_logits_maps = result[\"classification\"]  # a list of len(returned_layers)+1 Tensor\n            box_regression_maps = result[\"box_regression\"]  # a list of len(returned_layers)+1 Tensor\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        num_classes: int,\n        num_anchors: 
int,\n        feature_extractor: nn.Module,\n        size_divisible: Sequence[int] | int = 1,\n        use_list_output: bool = False,\n    ):\n        super().__init__()\n\n        self.spatial_dims = look_up_option(spatial_dims, supported=[1, 2, 3])\n        self.num_classes = num_classes\n        self.size_divisible = ensure_tuple_rep(size_divisible, self.spatial_dims)\n        self.use_list_output = use_list_output\n\n        if not hasattr(feature_extractor, \"out_channels\"):\n            raise ValueError(\n                \"feature_extractor should contain an attribute out_channels \"\n                \"specifying the number of output channels (assumed to be the \"\n                \"same for all the levels)\"\n            )\n        self.feature_extractor = feature_extractor\n\n        self.feature_map_channels: int = self.feature_extractor.out_channels  # type: ignore[assignment]\n        self.num_anchors = num_anchors\n        self.classification_head = RetinaNetClassificationHead(\n            self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims\n        )\n        self.regression_head = RetinaNetRegressionHead(\n            self.feature_map_channels, self.num_anchors, spatial_dims=self.spatial_dims\n        )\n\n        self.cls_key: str = \"classification\"\n        self.box_reg_key: str = \"box_regression\"\n\n    def forward(self, images: Tensor) -> Any:\n        \"\"\"\n        It takes an image tensor as inputs, and outputs predicted classification maps\n        and predicted box regression maps in ``head_outputs``.\n\n        Args:\n            images: input images, sized (B, img_channels, H, W) or (B, img_channels, H, W, D).\n\n        Return:\n            1) If self.use_list_output is False, output a dictionary ``head_outputs`` with\n            keys including self.cls_key and self.box_reg_key.\n            ``head_outputs[self.cls_key]`` is the predicted classification maps, a list of Tensor.\n            
``head_outputs[self.box_reg_key]`` is the predicted box regression maps, a list of Tensor.\n            2) if self.use_list_output is True, outputs a list of 2N tensors ``head_outputs``, with first N tensors being\n            the predicted classification maps and second N tensors being the predicted box regression maps.\n\n        \"\"\"\n        # compute features maps list from the input images.\n        features = self.feature_extractor(images)\n        if isinstance(features, Tensor):\n            feature_maps = [features]\n        elif torch.jit.isinstance(features, dict[str, Tensor]):\n            feature_maps = list(features.values())\n        else:\n            feature_maps = list(features)\n\n        if not isinstance(feature_maps[0], Tensor):\n            raise ValueError(\"feature_extractor output format must be Tensor, Dict[str, Tensor], or Sequence[Tensor].\")\n\n        # compute classification and box regression maps from the feature maps\n        # expandable for mask prediction in the future\n\n        if not self.use_list_output:\n            # output dict\n            head_outputs = {self.cls_key: self.classification_head(feature_maps)}\n            head_outputs[self.box_reg_key] = self.regression_head(feature_maps)\n            return head_outputs\n        else:\n            # output list of tensor, first half is classification, second half is box regression\n            head_outputs_sequence = self.classification_head(feature_maps) + self.regression_head(feature_maps)\n            return head_outputs_sequence\n\n\ndef resnet_fpn_feature_extractor(\n    backbone: resnet.ResNet,\n    spatial_dims: int,\n    pretrained_backbone: bool = False,\n    returned_layers: Sequence[int] = (1, 2, 3),\n    trainable_backbone_layers: int | None = None,\n) -> BackboneWithFPN:\n    \"\"\"\n    Constructs a feature extractor network with a ResNet-FPN backbone, used as feature_extractor in RetinaNet.\n\n    Reference: `\"Focal Loss for Dense Object Detection\" 
<https://arxiv.org/abs/1708.02002>`_.\n\n    The returned feature_extractor network takes an image tensor as inputs,\n    and outputs a dictionary that maps string to the extracted feature maps (Tensor).\n\n    The input to the returned feature_extractor is expected to be a list of tensors,\n    each of shape ``[C, H, W]`` or ``[C, H, W, D]``,\n    one for each image. Different images can have different sizes.\n\n\n    Args:\n        backbone: a ResNet model, used as backbone.\n        spatial_dims: number of spatial dimensions of the images. We support both 2D and 3D images.\n        pretrained_backbone: whether the backbone has been pre-trained.\n        returned_layers: returned layers to extract feature maps. Each returned layer should be in the range [1,4].\n            len(returned_layers)+1 will be the number of extracted feature maps.\n            There is an extra maxpooling layer LastLevelMaxPool() appended.\n        trainable_backbone_layers: number of trainable (not frozen) resnet layers starting from final block.\n            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.\n            When pretrained_backbone is False, this value is set to be 5.\n            When pretrained_backbone is True, if ``None`` is passed (the default) this value is set to 3.\n\n    Example:\n\n        .. 
code-block:: python\n\n            from monai.networks.nets import resnet\n            spatial_dims = 3 # 3D network\n            backbone = resnet.ResNet(\n                spatial_dims = spatial_dims,\n                block = resnet.ResNetBottleneck,\n                layers = [3, 4, 6, 3],\n                block_inplanes = resnet.get_inplanes(),\n                n_input_channels= 1,\n                conv1_t_stride = (2,2,1),\n                conv1_t_size = (7,7,7),\n            )\n            # This feature_extractor outputs 4-level feature maps.\n            # number of output feature maps is len(returned_layers)+1\n            feature_extractor = resnet_fpn_feature_extractor(\n                backbone = backbone,\n                spatial_dims = spatial_dims,\n                pretrained_backbone = False,\n                trainable_backbone_layers = None,\n                returned_layers = [1,2,3],\n            )\n            model = RetinaNet(\n                spatial_dims = spatial_dims,\n                num_classes = 5,\n                num_anchors = 6,\n                feature_extractor=feature_extractor,\n                size_divisible = 32,\n            ).to(device)\n    \"\"\"\n    # If pretrained_backbone is False, valid_trainable_backbone_layers = 5.\n    # If pretrained_backbone is True, valid_trainable_backbone_layers = trainable_backbone_layers or 3 if None.\n    valid_trainable_backbone_layers: int = _validate_trainable_layers(\n        pretrained_backbone, trainable_backbone_layers, max_value=5, default_value=3\n    )\n\n    feature_extractor = _resnet_fpn_extractor(\n        backbone,\n        spatial_dims,\n        valid_trainable_backbone_layers,\n        returned_layers=list(returned_layers),\n        extra_blocks=None,\n    )\n    return feature_extractor\n"
  },
  {
    "path": "monai/apps/detection/transforms/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/detection/transforms/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of \"vanilla\" transforms for box operations\nhttps://github.com/Project-MONAI/MONAI/wiki/MONAI_Design\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.config.type_definitions import DtypeLike, NdarrayOrTensor, NdarrayTensor\nfrom monai.data.box_utils import (\n    BoxMode,\n    clip_boxes_to_image,\n    convert_box_mode,\n    convert_box_to_standard_mode,\n    get_spatial_dims,\n    spatial_crop_boxes,\n    standardize_empty_box,\n)\nfrom monai.transforms import Rotate90, SpatialCrop\nfrom monai.transforms.transform import Transform\nfrom monai.utils import ensure_tuple, ensure_tuple_rep, fall_back_tuple, look_up_option\nfrom monai.utils.enums import TransformBackends\n\nfrom .box_ops import (\n    apply_affine_to_boxes,\n    convert_box_to_mask,\n    convert_mask_to_box,\n    flip_boxes,\n    resize_boxes,\n    rot90_boxes,\n    select_labels,\n    zoom_boxes,\n)\n\n__all__ = [\n    \"StandardizeEmptyBox\",\n    \"ConvertBoxToStandardMode\",\n    \"ConvertBoxMode\",\n    \"AffineBox\",\n    \"ZoomBox\",\n    \"ResizeBox\",\n    \"FlipBox\",\n    \"ClipBoxToImage\",\n    \"BoxToMask\",\n    \"MaskToBox\",\n    \"SpatialCropBox\",\n    \"RotateBox90\",\n]\n\n\nclass StandardizeEmptyBox(Transform):\n    \"\"\"\n    When boxes are empty, this 
transform standardize it to shape of (0,4) or (0,6).\n\n    Args:\n        spatial_dims: number of spatial dimensions of the bounding boxes.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, spatial_dims: int) -> None:\n        self.spatial_dims = spatial_dims\n\n    def __call__(self, boxes: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            boxes: source bounding boxes, Nx4 or Nx6 or 0xM torch tensor or ndarray.\n        \"\"\"\n        return standardize_empty_box(boxes, spatial_dims=self.spatial_dims)\n\n\nclass ConvertBoxMode(Transform):\n    \"\"\"\n    This transform converts the boxes in src_mode to the dst_mode.\n\n    Args:\n        src_mode: source box mode. If it is not given, this func will assume it is ``StandardMode()``.\n        dst_mode: target box mode. If it is not given, this func will assume it is ``StandardMode()``.\n\n    Note:\n        ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,\n        also represented as \"xyxy\" for 2D and \"xyzxyz\" for 3D.\n\n        src_mode and dst_mode can be:\n            #. str: choose from :class:`~monai.utils.enums.BoxModeName`, for example,\n                - \"xyxy\": boxes has format [xmin, ymin, xmax, ymax]\n                - \"xyzxyz\": boxes has format [xmin, ymin, zmin, xmax, ymax, zmax]\n                - \"xxyy\": boxes has format [xmin, xmax, ymin, ymax]\n                - \"xxyyzz\": boxes has format [xmin, xmax, ymin, ymax, zmin, zmax]\n                - \"xyxyzz\": boxes has format [xmin, ymin, xmax, ymax, zmin, zmax]\n                - \"xywh\": boxes has format [xmin, ymin, xsize, ysize]\n                - \"xyzwhd\": boxes has format [xmin, ymin, zmin, xsize, ysize, zsize]\n                - \"ccwh\": boxes has format [xcenter, ycenter, xsize, ysize]\n                - \"cccwhd\": boxes has format [xcenter, ycenter, zcenter, xsize, ysize, zsize]\n            #. 
BoxMode class: choose from the subclasses of :class:`~monai.data.box_utils.BoxMode`, for example,\n                - CornerCornerModeTypeA: equivalent to \"xyxy\" or \"xyzxyz\"\n                - CornerCornerModeTypeB: equivalent to \"xxyy\" or \"xxyyzz\"\n                - CornerCornerModeTypeC: equivalent to \"xyxy\" or \"xyxyzz\"\n                - CornerSizeMode: equivalent to \"xywh\" or \"xyzwhd\"\n                - CenterSizeMode: equivalent to \"ccwh\" or \"cccwhd\"\n            #. BoxMode object: choose from the subclasses of :class:`~monai.data.box_utils.BoxMode`, for example,\n                - CornerCornerModeTypeA(): equivalent to \"xyxy\" or \"xyzxyz\"\n                - CornerCornerModeTypeB(): equivalent to \"xxyy\" or \"xxyyzz\"\n                - CornerCornerModeTypeC(): equivalent to \"xyxy\" or \"xyxyzz\"\n                - CornerSizeMode(): equivalent to \"xywh\" or \"xyzwhd\"\n                - CenterSizeMode(): equivalent to \"ccwh\" or \"cccwhd\"\n            #. None: will assume mode is ``StandardMode()``\n\n    Example:\n        .. code-block:: python\n\n            boxes = torch.ones(10,4)\n            # convert boxes with format [xmin, ymin, xmax, ymax] to [xcenter, ycenter, xsize, ysize].\n            box_converter = ConvertBoxMode(src_mode=\"xyxy\", dst_mode=\"ccwh\")\n            box_converter(boxes)\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        src_mode: str | BoxMode | type[BoxMode] | None = None,\n        dst_mode: str | BoxMode | type[BoxMode] | None = None,\n    ) -> None:\n        self.src_mode = src_mode\n        self.dst_mode = dst_mode\n\n    def __call__(self, boxes: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Converts the boxes in src_mode to the dst_mode.\n\n        Args:\n            boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is given by ``self.src_mode``\n\n        Returns:\n            bounding boxes with target mode, with same data type as ``boxes``, does not share memory with ``boxes``\n        \"\"\"\n        return convert_box_mode(boxes, src_mode=self.src_mode, dst_mode=self.dst_mode)\n\n\nclass ConvertBoxToStandardMode(Transform):\n    \"\"\"\n    Convert given boxes to standard mode.\n    Standard mode is \"xyxy\" or \"xyzxyz\",\n    representing box format of [xmin, ymin, xmax, ymax] or [xmin, ymin, zmin, xmax, ymax, zmax].\n\n    Args:\n        mode: source box mode. If it is not given, this func will assume it is ``StandardMode()``.\n            It follows the same format with ``src_mode`` in :class:`~monai.apps.detection.transforms.array.ConvertBoxMode` .\n\n    Example:\n        .. code-block:: python\n\n            boxes = torch.ones(10,6)\n            # convert boxes with format [xmin, xmax, ymin, ymax, zmin, zmax] to [xmin, ymin, zmin, xmax, ymax, zmax]\n            box_converter = ConvertBoxToStandardMode(mode=\"xxyyzz\")\n            box_converter(boxes)\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None:\n        self.mode = mode\n\n    def __call__(self, boxes: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Convert given boxes to standard mode.\n        Standard mode is \"xyxy\" or \"xyzxyz\",\n        representing box format of [xmin, ymin, xmax, ymax] or [xmin, ymin, zmin, xmax, ymax, zmax].\n\n        Args:\n            boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is assumed to be ``StandardMode``\n\n        Returns:\n            bounding boxes with standard mode, with same data type as ``boxes``, does not share memory with ``boxes``\n        \"\"\"\n        return convert_box_to_standard_mode(boxes, mode=self.mode)\n\n\nclass AffineBox(Transform):\n    \"\"\"\n    Applies affine matrix to the boxes\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __call__(self, boxes: NdarrayOrTensor, affine: NdarrayOrTensor | None) -> NdarrayOrTensor:  # type: ignore\n        \"\"\"\n        Args:\n            boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n            affine: affine matrix to be applied to the box coordinate\n        \"\"\"\n        if affine is None:\n            return boxes\n\n        return apply_affine_to_boxes(boxes, affine=affine)\n\n\nclass ZoomBox(Transform):\n    \"\"\"\n    Zooms an ND Box with same padding or slicing setting with Zoom().\n\n    Args:\n        zoom: The zoom factor along the spatial axes.\n            If a float, zoom is the same for each spatial axis.\n            If a sequence, zoom should contain one value for each spatial axis.\n        keep_size: Should keep original size (padding/slicing if needed), default is True.\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, zoom: Sequence[float] | float, keep_size: bool = False, **kwargs: Any) -> None:\n        self.zoom = zoom\n        self.keep_size = keep_size\n        self.kwargs = kwargs\n\n    def __call__(self, boxes: NdarrayTensor, src_spatial_size: Sequence[int] | int | None = None) -> NdarrayTensor:\n        \"\"\"\n        Args:\n            boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is assumed to be ``StandardMode``\n            src_spatial_size: original image spatial size before zooming, used only when keep_size=True.\n        \"\"\"\n        spatial_dims: int = get_spatial_dims(boxes=boxes)\n        self._zoom = ensure_tuple_rep(self.zoom, spatial_dims)  # match the spatial image dim\n\n        if not self.keep_size:\n            return zoom_boxes(boxes, self._zoom)\n\n        if src_spatial_size is None:\n            raise ValueError(\"keep_size=True, src_spatial_size must be provided.\")\n\n        src_spatial_size = ensure_tuple_rep(src_spatial_size, spatial_dims)\n        dst_spatial_size = [int(round(z * ss)) for z, ss in zip(self._zoom, src_spatial_size)]\n        self._zoom = tuple(ds / float(ss) for ss, ds in zip(src_spatial_size, dst_spatial_size))\n        zoomed_boxes = zoom_boxes(boxes, self._zoom)\n\n        # See also keep_size in monai.transforms.spatial.array.Zoom()\n        if not np.allclose(np.array(src_spatial_size), np.array(dst_spatial_size)):\n            for axis, (od, zd) in enumerate(zip(src_spatial_size, dst_spatial_size)):\n                diff = od - zd\n                half = abs(diff) // 2\n                if diff > 0:  # need padding (half, diff - half)\n                    zoomed_boxes[:, axis] = zoomed_boxes[:, axis] + half\n                    zoomed_boxes[:, axis + spatial_dims] = zoomed_boxes[:, axis + spatial_dims] + half\n                elif diff < 0:  # need slicing (half, half + od)\n                    zoomed_boxes[:, axis] = zoomed_boxes[:, axis] - half\n                    zoomed_boxes[:, axis + spatial_dims] = zoomed_boxes[:, axis + spatial_dims] - half\n        return zoomed_boxes\n\n\nclass ResizeBox(Transform):\n    \"\"\"\n    Resize the input boxes when the corresponding image is\n    resized to given spatial size (with scaling, not cropping/padding).\n\n    Args:\n        spatial_size: expected shape of spatial dimensions after resize operation.\n            if some components 
of the `spatial_size` are non-positive values, the transform will use the\n            corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        size_mode: should be \"all\" or \"longest\", if \"all\", will use `spatial_size` for all the spatial dims,\n            if \"longest\", rescale the image so that only the longest side is equal to specified `spatial_size`,\n            which must be an int number in this case, keeping the aspect ratio of the initial image, refer to:\n            https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/\n            #albumentations.augmentations.geometric.resize.LongestMaxSize.\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, spatial_size: Sequence[int] | int, size_mode: str = \"all\", **kwargs: Any) -> None:\n        self.size_mode = look_up_option(size_mode, [\"all\", \"longest\"])\n        self.spatial_size = spatial_size\n\n    def __call__(self, boxes: NdarrayOrTensor, src_spatial_size: Sequence[int] | int) -> NdarrayOrTensor:  # type: ignore[override]\n        \"\"\"\n        Args:\n            boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is assumed to be ``StandardMode``\n            src_spatial_size: original image spatial size before resizing.\n\n        Raises:\n            ValueError: When ``self.spatial_size`` length is less than ``boxes`` spatial dimensions.\n        \"\"\"\n        input_ndim = get_spatial_dims(boxes=boxes)  # spatial ndim\n        src_spatial_size_ = ensure_tuple_rep(src_spatial_size, input_ndim)\n\n        if self.size_mode == \"all\":\n            # spatial_size must be a Sequence if size_mode is 'all'\n            output_ndim = len(ensure_tuple(self.spatial_size))\n            if output_ndim != input_ndim:\n                raise ValueError(\n                    \"len(spatial_size) must be greater or equal to img spatial dimensions, \"\n                    f\"got spatial_size={output_ndim} img={input_ndim}.\"\n                )\n            spatial_size_ = fall_back_tuple(self.spatial_size, src_spatial_size_)\n        else:  # for the \"longest\" mode\n            if not isinstance(self.spatial_size, int):\n                raise ValueError(\"spatial_size must be an int number if size_mode is 'longest'.\")\n            scale = self.spatial_size / max(src_spatial_size_)\n            spatial_size_ = tuple(int(round(s * scale)) for s in src_spatial_size_)\n\n        return resize_boxes(boxes, src_spatial_size_, spatial_size_)\n\n\nclass FlipBox(Transform):\n    \"\"\"\n    Reverses the box coordinates along the given spatial axis. Preserves shape.\n\n    Args:\n        spatial_axis: spatial axes along which to flip over. 
Default is None.\n            The default `axis=None` will flip over all of the axes of the input array.\n            If axis is negative it counts from the last to the first axis.\n            If axis is a tuple of ints, flipping is performed on all of the axes\n            specified in the tuple.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, spatial_axis: Sequence[int] | int | None = None) -> None:\n        self.spatial_axis = spatial_axis\n\n    def __call__(self, boxes: NdarrayOrTensor, spatial_size: Sequence[int] | int):  # type: ignore\n        \"\"\"\n        Args:\n            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n            spatial_size: image spatial size.\n        \"\"\"\n\n        return flip_boxes(boxes, spatial_size=spatial_size, flip_axes=self.spatial_axis)\n\n\nclass ClipBoxToImage(Transform):\n    \"\"\"\n    Clip the bounding boxes and the associated labels/scores to make sure they are within the image.\n    There might be multiple arrays of labels/scores associated with one array of boxes.\n\n    Args:\n        remove_empty: whether to remove the boxes and corresponding labels that are actually empty\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, remove_empty: bool = False) -> None:\n        self.remove_empty = remove_empty\n\n    def __call__(  # type: ignore\n        self,\n        boxes: NdarrayOrTensor,\n        labels: Sequence[NdarrayOrTensor] | NdarrayOrTensor,\n        spatial_size: Sequence[int] | int,\n    ) -> tuple[NdarrayOrTensor, tuple | NdarrayOrTensor]:\n        \"\"\"\n        Args:\n            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n            labels: Sequence of array. 
Each element represents classification labels or scores\n                corresponding to ``boxes``, sized (N,).\n            spatial_size: The spatial size of the image where the boxes are attached. len(spatial_size) should be in [2, 3].\n\n        Returns:\n            - clipped boxes, does not share memory with original boxes\n            - clipped labels, does not share memory with original labels\n\n        Example:\n            .. code-block:: python\n\n                box_clipper = ClipBoxToImage(remove_empty=True)\n                boxes = torch.ones(2, 6)\n                class_labels = torch.Tensor([0, 1])\n                pred_scores = torch.Tensor([[0.4,0.3,0.3], [0.5,0.1,0.4]])\n                labels = (class_labels, pred_scores)\n                spatial_size = [32, 32, 32]\n                boxes_clip, labels_clip_tuple = box_clipper(boxes, labels, spatial_size)\n        \"\"\"\n        spatial_dims: int = get_spatial_dims(boxes=boxes)\n        spatial_size = ensure_tuple_rep(spatial_size, spatial_dims)  # match the spatial image dim\n\n        boxes_clip, keep = clip_boxes_to_image(boxes, spatial_size, self.remove_empty)\n        return boxes_clip, select_labels(labels, keep)\n\n\nclass BoxToMask(Transform):\n    \"\"\"\n    Convert box to int16 mask image, which has the same size with the input image.\n\n    Args:\n        bg_label: background labels for the output mask image, make sure it is smaller than any foreground(fg) labels.\n        ellipse_mask: bool.\n\n            - If True, it assumes the object shape is close to ellipse or ellipsoid.\n            - If False, it assumes the object shape is close to rectangle or cube and well occupies the bounding box.\n            - If the users are going to apply random rotation as data augmentation, we suggest setting ellipse_mask=True\n              See also Kalra et al. 
\"Towards Rotation Invariance in Object Detection\", ICCV 2021.\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(self, bg_label: int = -1, ellipse_mask: bool = False) -> None:\n        self.bg_label = bg_label\n        self.ellipse_mask = ellipse_mask\n\n    def __call__(  # type: ignore\n        self, boxes: NdarrayOrTensor, labels: NdarrayOrTensor, spatial_size: Sequence[int] | int\n    ) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``.\n            labels: classification foreground(fg) labels corresponding to `boxes`, dtype should be int, sized (N,).\n            spatial_size: image spatial size.\n\n        Return:\n            - int16 array, sized (num_box, H, W). Each channel represents a box.\n                The foreground region in channel c has intensity of labels[c].\n                The background intensity is bg_label.\n        \"\"\"\n        return convert_box_to_mask(boxes, labels, spatial_size, self.bg_label, self.ellipse_mask)\n\n\nclass MaskToBox(Transform):\n    \"\"\"\n    Convert int16 mask image to box, which has the same size with the input image.\n    Pairs with :py:class:`monai.apps.detection.transforms.array.BoxToMask`.\n    Please make sure the same ``min_fg_label`` is used when using the two transforms in pairs.\n\n    Args:\n        bg_label: background labels for the output mask image, make sure it is smaller than any foreground(fg) labels.\n        box_dtype: output dtype for boxes\n        label_dtype: output dtype for labels\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        bg_label: int = -1,\n        box_dtype: DtypeLike | torch.dtype = torch.float32,\n        label_dtype: DtypeLike | torch.dtype = torch.long,\n    ) -> None:\n        self.bg_label = bg_label\n        self.box_dtype = box_dtype\n        self.label_dtype = 
label_dtype\n\n    def __call__(self, boxes_mask: NdarrayOrTensor) -> tuple[NdarrayOrTensor, NdarrayOrTensor]:\n        \"\"\"\n        Args:\n            boxes_mask: int16 array, sized (num_box, H, W). Each channel represents a box.\n                The foreground region in channel c has intensity of labels[c].\n                The background intensity is bg_label.\n\n        Return:\n            - bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``.\n            - classification foreground(fg) labels, dtype should be int, sized (N,).\n        \"\"\"\n        return convert_mask_to_box(boxes_mask, self.bg_label, self.box_dtype, self.label_dtype)\n\n\nclass SpatialCropBox(SpatialCrop):\n    \"\"\"\n    General purpose box cropper when the corresponding image is cropped by SpatialCrop(*) with the same ROI.\n    The difference is that we do not support negative indexing for roi_slices.\n\n    If a dimension of the expected ROI size is bigger than the input image size, will not crop that dimension.\n    So the cropped result may be smaller than the expected ROI, and the cropped results of several images may\n    not have exactly the same shape.\n    It can support to crop ND spatial boxes.\n\n    The cropped region can be parameterised in various ways:\n        - a list of slices for each spatial dimension (do not allow for use of negative indexing)\n        - a spatial center and size\n        - the start and end coordinates of the ROI\n\n    Args:\n        roi_center: voxel coordinates for center of the crop ROI.\n        roi_size: size of the crop ROI, if a dimension of ROI size is bigger than image size,\n            will not crop that dimension of the image.\n        roi_start: voxel coordinates for start of the crop ROI.\n        roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image,\n            use the end coordinate of image.\n        roi_slices: list of slices for each of the 
spatial dimensions.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        roi_center: Sequence[int] | NdarrayOrTensor | None = None,\n        roi_size: Sequence[int] | NdarrayOrTensor | None = None,\n        roi_start: Sequence[int] | NdarrayOrTensor | None = None,\n        roi_end: Sequence[int] | NdarrayOrTensor | None = None,\n        roi_slices: Sequence[slice] | None = None,\n    ) -> None:\n        super().__init__(roi_center, roi_size, roi_start, roi_end, roi_slices)\n        for s in self.slices:\n            if s.start < 0 or s.stop < 0 or (s.step is not None and s.step < 0):\n                raise ValueError(\"Currently negative indexing is not supported for SpatialCropBox.\")\n\n    def __call__(  # type: ignore[override]\n        self, boxes: NdarrayTensor, labels: Sequence[NdarrayOrTensor] | NdarrayOrTensor\n    ) -> tuple[NdarrayTensor, tuple | NdarrayOrTensor]:\n        \"\"\"\n        Args:\n            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n            labels: Sequence of array. Each element represents classification labels or scores\n\n        Returns:\n            - cropped boxes, does not share memory with original boxes\n            - cropped labels, does not share memory with original labels\n\n        Example:\n            .. 
code-block:: python\n\n                box_cropper = SpatialCropPadBox(roi_start=[0, 1, 4], roi_end=[21, 15, 8])\n                boxes = torch.ones(2, 6)\n                class_labels = torch.Tensor([0, 1])\n                pred_scores = torch.Tensor([[0.4,0.3,0.3], [0.5,0.1,0.4]])\n                labels = (class_labels, pred_scores)\n                boxes_crop, labels_crop_tuple = box_cropper(boxes, labels)\n        \"\"\"\n        spatial_dims = min(len(self.slices), get_spatial_dims(boxes=boxes))  # spatial dims\n        boxes_crop, keep = spatial_crop_boxes(\n            boxes,\n            [self.slices[axis].start for axis in range(spatial_dims)],\n            [self.slices[axis].stop for axis in range(spatial_dims)],\n        )\n        return boxes_crop, select_labels(labels, keep)\n\n\nclass RotateBox90(Rotate90):\n    \"\"\"\n    Rotate a boxes by 90 degrees in the plane specified by `axes`.\n    See box_ops.rot90_boxes for additional details\n\n    Args:\n        k: number of times to rotate by 90 degrees.\n        spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.\n            Default: (0, 1), this is the first two axis in spatial dimensions.\n            If axis is negative it counts from the last to the first axis.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __call__(self, boxes: NdarrayTensor, spatial_size: Sequence[int] | int) -> NdarrayTensor:  # type: ignore[override]\n        \"\"\"\n        Args:\n            img: channel first array, must have shape: (num_channels, H[, W, ..., ]),\n        \"\"\"\n        return rot90_boxes(boxes, spatial_size, self.k, self.spatial_axes)\n"
  },
  {
    "path": "monai/apps/detection/transforms/box_ops.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\n\nfrom monai.config.type_definitions import DtypeLike, NdarrayOrTensor, NdarrayTensor\nfrom monai.data.box_utils import COMPUTE_DTYPE, TO_REMOVE, get_spatial_dims\nfrom monai.transforms import Resize\nfrom monai.transforms.utils import create_scale\nfrom monai.utils import look_up_option\nfrom monai.utils.misc import ensure_tuple, ensure_tuple_rep\nfrom monai.utils.type_conversion import convert_data_type, convert_to_dst_type\n\n\ndef _apply_affine_to_points(points: torch.Tensor, affine: torch.Tensor, include_shift: bool = True) -> torch.Tensor:\n    \"\"\"\n    This internal function applies affine matrices to the point coordinate\n\n    Args:\n        points: point coordinates, Nx2 or Nx3 torch tensor or ndarray, representing [x, y] or [x, y, z]\n        affine: affine matrix to be applied to the point coordinates, sized (spatial_dims+1,spatial_dims+1)\n        include_shift: default True, whether the function apply translation (shift) in the affine transform\n\n    Returns:\n        transformed point coordinates, with same data type as ``points``, does not share memory with ``points``\n    \"\"\"\n\n    spatial_dims = get_spatial_dims(points=points)\n\n    # compute new points\n    if include_shift:\n        # append 1 to form 
Nx(spatial_dims+1) vector, then transpose\n        points_affine = torch.cat(\n            [points, torch.ones(points.shape[0], 1, device=points.device, dtype=points.dtype)], dim=1\n        ).transpose(0, 1)\n        # apply affine\n        points_affine = torch.matmul(affine, points_affine)\n        # remove appended 1 and transpose back\n        points_affine = points_affine[:spatial_dims, :].transpose(0, 1)\n    else:\n        points_affine = points.transpose(0, 1)\n        points_affine = torch.matmul(affine[:spatial_dims, :spatial_dims], points_affine)\n        points_affine = points_affine.transpose(0, 1)\n\n    return points_affine\n\n\ndef apply_affine_to_boxes(boxes: NdarrayTensor, affine: NdarrayOrTensor) -> NdarrayTensor:\n    \"\"\"\n    This function applies affine matrices to the boxes\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode\n        affine: affine matrix to be applied to the box coordinates, sized (spatial_dims+1,spatial_dims+1)\n\n    Returns:\n        returned affine transformed boxes, with same data type as ``boxes``, does not share memory with ``boxes``\n    \"\"\"\n\n    # convert numpy to tensor if needed\n    boxes_t, *_ = convert_data_type(boxes, torch.Tensor)\n\n    # some operation does not support torch.float16\n    # convert to float32\n\n    boxes_t = boxes_t.to(dtype=COMPUTE_DTYPE)\n    affine_t, *_ = convert_to_dst_type(src=affine, dst=boxes_t)\n\n    spatial_dims = get_spatial_dims(boxes=boxes_t)\n\n    # affine transform left top and bottom right points\n    # might flipped, thus lt may not be left top any more\n    lt: torch.Tensor = _apply_affine_to_points(boxes_t[:, :spatial_dims], affine_t, include_shift=True)\n    rb: torch.Tensor = _apply_affine_to_points(boxes_t[:, spatial_dims:], affine_t, include_shift=True)\n\n    # make sure lt_new is left top, and rb_new is bottom right\n    lt_new, _ = torch.min(torch.stack([lt, rb], dim=2), dim=2)\n    
rb_new, _ = torch.max(torch.stack([lt, rb], dim=2), dim=2)\n\n    boxes_t_affine = torch.cat([lt_new, rb_new], dim=1)\n\n    # convert tensor back to numpy if needed\n    boxes_affine: NdarrayOrTensor\n    boxes_affine, *_ = convert_to_dst_type(src=boxes_t_affine, dst=boxes)\n    return boxes_affine  # type: ignore[return-value]\n\n\ndef zoom_boxes(boxes: NdarrayTensor, zoom: Sequence[float] | float) -> NdarrayTensor:\n    \"\"\"\n    Zoom boxes\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be StandardMode\n        zoom: The zoom factor along the spatial axes.\n            If a float, zoom is the same for each spatial axis.\n            If a sequence, zoom should contain one value for each spatial axis.\n\n    Returns:\n        zoomed boxes, with same data type as ``boxes``, does not share memory with ``boxes``\n\n    Example:\n        .. code-block:: python\n\n            boxes = torch.ones(1,4)\n            zoom_boxes(boxes, zoom=[0.5,2.2]) #  will return tensor([[0.5, 2.2, 0.5, 2.2]])\n    \"\"\"\n    spatial_dims = get_spatial_dims(boxes=boxes)\n\n    # generate affine transform corresponding to ``zoom``\n    affine = create_scale(spatial_dims=spatial_dims, scaling_factor=zoom)\n\n    return apply_affine_to_boxes(boxes=boxes, affine=affine)\n\n\ndef resize_boxes(\n    boxes: NdarrayOrTensor, src_spatial_size: Sequence[int] | int, dst_spatial_size: Sequence[int] | int\n) -> NdarrayOrTensor:\n    \"\"\"\n    Resize boxes when the corresponding image is resized\n\n    Args:\n        boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n        src_spatial_size: source image spatial size.\n        dst_spatial_size: target image spatial size.\n\n    Returns:\n        resized boxes, with same data type as ``boxes``, does not share memory with ``boxes``\n\n    Example:\n        .. 
code-block:: python\n\n            boxes = torch.ones(1,4)\n            src_spatial_size = [100, 100]\n            dst_spatial_size = [128, 256]\n            resize_boxes(boxes, src_spatial_size, dst_spatial_size) #  will return tensor([[1.28, 2.56, 1.28, 2.56]])\n    \"\"\"\n    spatial_dims: int = get_spatial_dims(boxes=boxes)\n\n    src_spatial_size = ensure_tuple_rep(src_spatial_size, spatial_dims)\n    dst_spatial_size = ensure_tuple_rep(dst_spatial_size, spatial_dims)\n\n    zoom = [dst_spatial_size[axis] / float(src_spatial_size[axis]) for axis in range(spatial_dims)]\n\n    return zoom_boxes(boxes=boxes, zoom=zoom)\n\n\ndef flip_boxes(\n    boxes: NdarrayTensor, spatial_size: Sequence[int] | int, flip_axes: Sequence[int] | int | None = None\n) -> NdarrayTensor:\n    \"\"\"\n    Flip boxes when the corresponding image is flipped\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n        spatial_size: image spatial size.\n        flip_axes: spatial axes along which to flip over. 
Default is None.\n            The default `axis=None` will flip over all of the axes of the input array.\n            If axis is negative it counts from the last to the first axis.\n            If axis is a tuple of ints, flipping is performed on all of the axes\n            specified in the tuple.\n\n    Returns:\n        flipped boxes, with same data type as ``boxes``, does not share memory with ``boxes``\n    \"\"\"\n    spatial_dims: int = get_spatial_dims(boxes=boxes)\n    spatial_size = ensure_tuple_rep(spatial_size, spatial_dims)\n    if flip_axes is None:\n        flip_axes = tuple(range(spatial_dims))\n    flip_axes = ensure_tuple(flip_axes)\n\n    # flip box\n    _flip_boxes: NdarrayTensor = boxes.clone() if isinstance(boxes, torch.Tensor) else deepcopy(boxes)  # type: ignore[assignment]\n\n    for axis in flip_axes:\n        _flip_boxes[:, axis + spatial_dims] = spatial_size[axis] - boxes[:, axis] - TO_REMOVE\n        _flip_boxes[:, axis] = spatial_size[axis] - boxes[:, axis + spatial_dims] - TO_REMOVE\n\n    return _flip_boxes\n\n\ndef convert_box_to_mask(\n    boxes: NdarrayOrTensor,\n    labels: NdarrayOrTensor,\n    spatial_size: Sequence[int] | int,\n    bg_label: int = -1,\n    ellipse_mask: bool = False,\n) -> NdarrayOrTensor:\n    \"\"\"\n    Convert box to int16 mask image, which has the same size with the input image.\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is assumed to be ``StandardMode``.\n        labels: classification foreground(fg) labels corresponding to `boxes`, dtype should be int, sized (N,).\n        spatial_size: image spatial size.\n        bg_label: background labels for the output mask image, make sure it is smaller than any fg labels.\n        ellipse_mask: bool.\n\n            - If True, it assumes the object shape is close to ellipse or ellipsoid.\n            - If False, it assumes the object shape is close to rectangle or cube and well occupies the bounding box.\n            - If the users are going to apply random rotation as data augmentation, we suggest setting ellipse_mask=True\n              See also Kalra et al. \"Towards Rotation Invariance in Object Detection\", ICCV 2021.\n\n    Return:\n        - int16 array, sized (num_box, H, W). Each channel represents a box.\n            The foreground region in channel c has intensity of labels[c].\n            The background intensity is bg_label.\n    \"\"\"\n    spatial_dims: int = get_spatial_dims(boxes=boxes)\n    spatial_size = ensure_tuple_rep(spatial_size, spatial_dims)\n\n    # if no box, return empty mask\n    if labels.shape[0] == 0:\n        boxes_mask_np = np.ones((1,) + spatial_size, dtype=np.int16) * np.int16(bg_label)\n        boxes_mask, *_ = convert_to_dst_type(src=boxes_mask_np, dst=boxes, dtype=torch.int16)\n        return boxes_mask\n\n    # bg_label should be smaller than labels\n    if bg_label >= min(labels):\n        raise ValueError(\n            f\"bg_label should be smaller than any foreground box labels.\\n\"\n            f\"min(labels)={min(labels)}, while bg_label={bg_label}\"\n        )\n\n    if labels.shape[0] != boxes.shape[0]:\n        raise ValueError(\"Number of labels should equal to number of boxes.\")\n\n    # allocate memory for boxes_mask_np\n    boxes_mask_np = np.ones((labels.shape[0],) + spatial_size, dtype=np.int16) * np.int16(bg_label)\n\n    boxes_np: np.ndarray = convert_data_type(boxes, 
np.ndarray, dtype=np.int32)[0]\n    if np.any(boxes_np[:, spatial_dims:] > np.array(spatial_size)):\n        raise ValueError(\"Some boxes are larger than the image.\")\n\n    labels_np, *_ = convert_to_dst_type(src=labels, dst=boxes_np)\n    for b in range(boxes_np.shape[0]):\n        # generate a foreground mask\n        box_size = [boxes_np[b, axis + spatial_dims] - boxes_np[b, axis] for axis in range(spatial_dims)]\n        if ellipse_mask:\n            # initialize a square/cube mask\n            max_box_size = max(box_size)  # max of box w/h/d\n            radius = max_box_size / 2.0\n            center = (max_box_size - 1) / 2.0\n            boxes_only_mask = np.ones([max_box_size] * spatial_dims, dtype=np.int16) * np.int16(bg_label)\n            # apply label intensity to generate circle/ball foreground\n            ranges = tuple(slice(0, max_box_size) for _ in range(spatial_dims))\n            dist_from_center = sum((grid - center) ** 2 for grid in np.ogrid[ranges])\n            boxes_only_mask[dist_from_center <= radius**2] = np.int16(labels_np[b])\n            # squeeze it to a ellipse/ellipsoid mask\n            resizer = Resize(spatial_size=box_size, mode=\"nearest\", anti_aliasing=False)\n            boxes_only_mask = resizer(boxes_only_mask[None])[0]  # type: ignore\n        else:\n            # generate a rect mask\n            boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b])\n        # apply to global mask\n        slicing = [b]\n        slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims))  # type:ignore\n        boxes_mask_np[tuple(slicing)] = boxes_only_mask\n    return convert_to_dst_type(src=boxes_mask_np, dst=boxes, dtype=torch.int16)[0]\n\n\ndef convert_mask_to_box(\n    boxes_mask: NdarrayOrTensor,\n    bg_label: int = -1,\n    box_dtype: DtypeLike | torch.dtype = torch.float32,\n    label_dtype: DtypeLike | torch.dtype = torch.long,\n) -> tuple[NdarrayOrTensor, 
NdarrayOrTensor]:\n    \"\"\"\n    Convert int16 mask image to box, which has the same size with the input image\n\n    Args:\n        boxes_mask: int16 array, sized (num_box, H, W). Each channel represents a box.\n            The foreground region in channel c has intensity of labels[c].\n            The background intensity is bg_label.\n        bg_label: background labels for the boxes_mask\n        box_dtype: output dtype for boxes\n        label_dtype: output dtype for labels\n\n    Return:\n        - bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``.\n        - classification foreground(fg) labels, dtype should be int, sized (N,).\n    \"\"\"\n    look_up_option(len(boxes_mask.shape), [3, 4])\n    spatial_size = list(boxes_mask.shape[1:])\n    spatial_dims = get_spatial_dims(spatial_size=spatial_size)\n\n    boxes_mask_np, *_ = convert_data_type(boxes_mask, np.ndarray)\n\n    boxes_list = []\n    labels_list = []\n    for b in range(boxes_mask_np.shape[0]):\n        fg_indices = np.nonzero(boxes_mask_np[b, ...] 
- bg_label)\n        if fg_indices[0].shape[0] == 0:\n            continue\n        boxes_b = []\n        for fd_i in fg_indices:\n            boxes_b.append(min(fd_i))  # top left corner\n        for fd_i in fg_indices:\n            boxes_b.append(max(fd_i) + 1 - TO_REMOVE)  # bottom right corner\n        boxes_list.append(boxes_b)\n        if spatial_dims == 2:\n            labels_list.append(boxes_mask_np[b, fg_indices[0][0], fg_indices[1][0]])\n        if spatial_dims == 3:\n            labels_list.append(boxes_mask_np[b, fg_indices[0][0], fg_indices[1][0], fg_indices[2][0]])\n\n    if len(boxes_list) == 0:\n        boxes_np, labels_np = np.zeros([0, 2 * spatial_dims]), np.zeros([0])\n    else:\n        boxes_np, labels_np = np.asarray(boxes_list), np.asarray(labels_list)\n    boxes, *_ = convert_to_dst_type(src=boxes_np, dst=boxes_mask, dtype=box_dtype)\n    labels, *_ = convert_to_dst_type(src=labels_np, dst=boxes_mask, dtype=label_dtype)\n    return boxes, labels\n\n\ndef select_labels(\n    labels: Sequence[NdarrayOrTensor] | NdarrayOrTensor, keep: NdarrayOrTensor\n) -> tuple | NdarrayOrTensor:\n    \"\"\"\n    For element in labels, select indices keep from it.\n\n    Args:\n        labels: Sequence of array. 
Each element represents classification labels or scores\n            corresponding to ``boxes``, sized (N,).\n        keep: the indices to keep, same length with each element in labels.\n\n    Return:\n        selected labels, does not share memory with original labels.\n    \"\"\"\n    labels_tuple = ensure_tuple(labels, True)\n\n    labels_select_list = []\n    keep_t: torch.Tensor = convert_data_type(keep, torch.Tensor)[0]\n    for item in labels_tuple:\n        labels_t: torch.Tensor = convert_data_type(item, torch.Tensor)[0]\n        labels_t = labels_t[keep_t, ...]\n        labels_select_list.append(convert_to_dst_type(src=labels_t, dst=item)[0])\n\n    if isinstance(labels, (torch.Tensor, np.ndarray)):\n        return labels_select_list[0]  # type: ignore\n\n    return tuple(labels_select_list)\n\n\ndef swapaxes_boxes(boxes: NdarrayTensor, axis1: int, axis2: int) -> NdarrayTensor:\n    \"\"\"\n    Interchange two axes of boxes.\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is assumed to be ``StandardMode``\n        axis1: First axis.\n        axis2: Second axis.\n\n    Returns:\n        boxes with two axes interchanged.\n\n    \"\"\"\n    spatial_dims: int = get_spatial_dims(boxes=boxes)\n\n    if isinstance(boxes, torch.Tensor):\n        boxes_swap = boxes.clone()\n    else:\n        boxes_swap = deepcopy(boxes)  # type: ignore\n    boxes_swap[:, [axis1, axis2]] = boxes_swap[:, [axis2, axis1]]\n\n    boxes_swap[:, [spatial_dims + axis1, spatial_dims + axis2]] = boxes_swap[\n        :, [spatial_dims + axis2, spatial_dims + axis1]\n    ]\n    return boxes_swap  # type: ignore[return-value]\n\n\ndef rot90_boxes(\n    boxes: NdarrayTensor, spatial_size: Sequence[int] | int, k: int = 1, axes: tuple[int, int] = (0, 1)\n) -> NdarrayTensor:\n    \"\"\"\n    Rotate boxes by 90 degrees in the plane specified by axes.\n    Rotation direction is from the first towards the second axis.\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n        spatial_size: image spatial size.\n        k : number of times the array is rotated by 90 degrees.\n        axes: (2,) array_like\n            The array is rotated in the plane defined by the axes. 
Axes must be different.\n\n    Returns:\n        A rotated view of `boxes`.\n\n    Notes:\n        ``rot90_boxes(boxes, spatial_size, k=1, axes=(1,0))``  is the reverse of\n        ``rot90_boxes(boxes, spatial_size, k=1, axes=(0,1))``\n        ``rot90_boxes(boxes, spatial_size, k=1, axes=(1,0))`` is equivalent to\n        ``rot90_boxes(boxes, spatial_size, k=-1, axes=(0,1))``\n    \"\"\"\n    spatial_dims: int = get_spatial_dims(boxes=boxes)\n    spatial_size_ = list(ensure_tuple_rep(spatial_size, spatial_dims))\n\n    axes = ensure_tuple(axes)\n\n    if len(axes) != 2:\n        raise ValueError(\"len(axes) must be 2.\")\n\n    if axes[0] == axes[1] or abs(axes[0] - axes[1]) == spatial_dims:\n        raise ValueError(\"Axes must be different.\")\n\n    if axes[0] >= spatial_dims or axes[0] < -spatial_dims or axes[1] >= spatial_dims or axes[1] < -spatial_dims:\n        raise ValueError(f\"Axes={axes} out of range for array of ndim={spatial_dims}.\")\n\n    k %= 4\n\n    if k == 0:\n        return boxes\n    if k == 2:\n        return flip_boxes(flip_boxes(boxes, spatial_size_, axes[0]), spatial_size_, axes[1])\n\n    if k == 1:\n        boxes_ = flip_boxes(boxes, spatial_size_, axes[1])\n        return swapaxes_boxes(boxes_, axes[0], axes[1])\n    else:\n        # k == 3\n        boxes_ = swapaxes_boxes(boxes, axes[0], axes[1])\n        spatial_size_[axes[0]], spatial_size_[axes[1]] = spatial_size_[axes[1]], spatial_size_[axes[0]]\n        return flip_boxes(boxes_, spatial_size_, axes[1])\n"
  },
  {
    "path": "monai/apps/detection/transforms/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of dictionary-based wrappers around the \"vanilla\" transforms for box operations\ndefined in :py:class:`monai.apps.detection.transforms.array`.\n\nClass names are ended with 'd' to denote dictionary-based transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Hashable, Mapping, Sequence\nfrom copy import deepcopy\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.apps.detection.transforms.array import (\n    AffineBox,\n    BoxToMask,\n    ClipBoxToImage,\n    ConvertBoxMode,\n    ConvertBoxToStandardMode,\n    FlipBox,\n    MaskToBox,\n    RotateBox90,\n    SpatialCropBox,\n    StandardizeEmptyBox,\n    ZoomBox,\n)\nfrom monai.apps.detection.transforms.box_ops import convert_box_to_mask\nfrom monai.config import KeysCollection, SequenceStr\nfrom monai.config.type_definitions import DtypeLike, NdarrayOrTensor\nfrom monai.data.box_utils import COMPUTE_DTYPE, BoxMode, clip_boxes_to_image\nfrom monai.data.meta_tensor import MetaTensor, get_track_meta\nfrom monai.data.utils import orientation_ras_lps\nfrom monai.transforms import Flip, RandFlip, RandZoom, Rotate90, SpatialCrop, Zoom\nfrom monai.transforms.inverse import InvertibleTransform\nfrom monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform\nfrom monai.transforms.utils import generate_pos_neg_label_crop_centers, 
map_binary_to_indices\nfrom monai.utils import InterpolateMode, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple\nfrom monai.utils.enums import PostFix, TraceKeys\nfrom monai.utils.type_conversion import convert_data_type, convert_to_tensor\n\n__all__ = [\n    \"StandardizeEmptyBoxd\",\n    \"StandardizeEmptyBoxD\",\n    \"StandardizeEmptyBoxDict\",\n    \"ConvertBoxModed\",\n    \"ConvertBoxModeD\",\n    \"ConvertBoxModeDict\",\n    \"ConvertBoxToStandardModed\",\n    \"ConvertBoxToStandardModeD\",\n    \"ConvertBoxToStandardModeDict\",\n    \"AffineBoxToImageCoordinated\",\n    \"AffineBoxToImageCoordinateD\",\n    \"AffineBoxToImageCoordinateDict\",\n    \"ZoomBoxd\",\n    \"ZoomBoxD\",\n    \"ZoomBoxDict\",\n    \"RandZoomBoxd\",\n    \"RandZoomBoxD\",\n    \"RandZoomBoxDict\",\n    \"FlipBoxd\",\n    \"FlipBoxD\",\n    \"FlipBoxDict\",\n    \"RandFlipBoxd\",\n    \"RandFlipBoxD\",\n    \"RandFlipBoxDict\",\n    \"ClipBoxToImaged\",\n    \"ClipBoxToImageD\",\n    \"ClipBoxToImageDict\",\n    \"BoxToMaskd\",\n    \"BoxToMaskD\",\n    \"BoxToMaskDict\",\n    \"MaskToBoxd\",\n    \"MaskToBoxD\",\n    \"MaskToBoxDict\",\n    \"RandCropBoxByPosNegLabeld\",\n    \"RandCropBoxByPosNegLabelD\",\n    \"RandCropBoxByPosNegLabelDict\",\n    \"RotateBox90d\",\n    \"RotateBox90D\",\n    \"RotateBox90Dict\",\n    \"RandRotateBox90d\",\n    \"RandRotateBox90D\",\n    \"RandRotateBox90Dict\",\n]\n\nDEFAULT_POST_FIX = PostFix.meta()\n\n\nclass StandardizeEmptyBoxd(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.StandardizeEmptyBox`.\n\n    When boxes are empty, this transform standardize it to shape of (0,4) or (0,6).\n\n    Example:\n        .. 
code-block:: python\n\n            data = {\"boxes\": torch.ones(0,), \"image\": torch.ones(1, 128, 128, 128)}\n            box_converter = StandardizeEmptyBoxd(box_keys=[\"boxes\"], box_ref_image_keys=\"image\")\n            box_converter(data)\n    \"\"\"\n\n    def __init__(self, box_keys: KeysCollection, box_ref_image_keys: str, allow_missing_keys: bool = False) -> None:\n        \"\"\"\n        Args:\n            box_keys: Keys to pick data for transformation.\n            box_ref_image_keys: The single key that represents the reference image to which ``box_keys`` are attached.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        See also :py:class:`monai.apps.detection,transforms.array.ConvertBoxToStandardMode`\n        \"\"\"\n        super().__init__(box_keys, allow_missing_keys)\n        box_ref_image_keys_tuple = ensure_tuple(box_ref_image_keys)\n        if len(box_ref_image_keys_tuple) > 1:\n            raise ValueError(\n                \"Please provide a single key for box_ref_image_keys.\\\n                All boxes of box_keys are attached to box_ref_image_keys.\"\n            )\n        self.box_ref_image_keys = box_ref_image_keys\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        spatial_dims = len(d[self.box_ref_image_keys].shape) - 1\n        self.converter = StandardizeEmptyBox(spatial_dims=spatial_dims)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        return dict(data)\n\n\nclass ConvertBoxModed(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.ConvertBoxMode`.\n\n    This transform converts the boxes in src_mode to the dst_mode.\n\n    Example:\n        .. 
code-block:: python\n\n            data = {\"boxes\": torch.ones(10,4)}\n            # convert boxes with format [xmin, ymin, xmax, ymax] to [xcenter, ycenter, xsize, ysize].\n            box_converter = ConvertBoxModed(box_keys=[\"boxes\"], src_mode=\"xyxy\", dst_mode=\"ccwh\")\n            box_converter(data)\n    \"\"\"\n\n    def __init__(\n        self,\n        box_keys: KeysCollection,\n        src_mode: str | BoxMode | type[BoxMode] | None = None,\n        dst_mode: str | BoxMode | type[BoxMode] | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            box_keys: Keys to pick data for transformation.\n            src_mode: source box mode. If it is not given, this func will assume it is ``StandardMode()``.\n                It follows the same format with ``src_mode`` in :class:`~monai.apps.detection.transforms.array.ConvertBoxMode` .\n            dst_mode: target box mode. If it is not given, this func will assume it is ``StandardMode()``.\n                It follows the same format with ``src_mode`` in :class:`~monai.apps.detection.transforms.array.ConvertBoxMode` .\n            allow_missing_keys: don't raise exception if key is missing.\n\n        See also :py:class:`monai.apps.detection,transforms.array.ConvertBoxMode`\n        \"\"\"\n        super().__init__(box_keys, allow_missing_keys)\n        self.converter = ConvertBoxMode(src_mode=src_mode, dst_mode=dst_mode)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n            self.push_transform(d, key, extra_info={\"src\": self.converter.src_mode, \"dst\": self.converter.dst_mode})\n        return d\n\n    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            
tr = self.get_most_recent_transform(d, key)\n            src_mode, dst_mode = tr[TraceKeys.EXTRA_INFO][\"src\"], tr[TraceKeys.EXTRA_INFO][\"dst\"]\n            inverse_converter = ConvertBoxMode(src_mode=dst_mode, dst_mode=src_mode)\n            # Inverse is same as forward\n            d[key] = inverse_converter(d[key])\n            # Remove the applied transform\n            self.pop_transform(d, key)\n        return d\n\n\nclass ConvertBoxToStandardModed(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.ConvertBoxToStandardMode`.\n\n    Convert given boxes to standard mode.\n    Standard mode is \"xyxy\" or \"xyzxyz\",\n    representing box format of [xmin, ymin, xmax, ymax] or [xmin, ymin, zmin, xmax, ymax, zmax].\n\n    Example:\n        .. code-block:: python\n\n            data = {\"boxes\": torch.ones(10,6)}\n            # convert boxes with format [xmin, xmax, ymin, ymax, zmin, zmax] to [xmin, ymin, zmin, xmax, ymax, zmax]\n            box_converter = ConvertBoxToStandardModed(box_keys=[\"boxes\"], mode=\"xxyyzz\")\n            box_converter(data)\n    \"\"\"\n\n    def __init__(\n        self,\n        box_keys: KeysCollection,\n        mode: str | BoxMode | type[BoxMode] | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            box_keys: Keys to pick data for transformation.\n            mode: source box mode. 
If it is not given, this func will assume it is ``StandardMode()``.\n                It follows the same format with ``src_mode`` in :class:`~monai.apps.detection.transforms.array.ConvertBoxMode` .\n            allow_missing_keys: don't raise exception if key is missing.\n\n        See also :py:class:`monai.apps.detection.transforms.array.ConvertBoxToStandardMode`\n        \"\"\"\n        super().__init__(box_keys, allow_missing_keys)\n        self.converter = ConvertBoxToStandardMode(mode=mode)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n            self.push_transform(d, key, extra_info={\"mode\": self.converter.mode})\n        return d\n\n    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            tr = self.get_most_recent_transform(d, key)\n            original_mode = tr[TraceKeys.EXTRA_INFO][\"mode\"]\n            inverse_converter = ConvertBoxMode(src_mode=None, dst_mode=original_mode)\n            # Inverse is same as forward\n            d[key] = inverse_converter(d[key])\n            # Remove the applied transform\n            self.pop_transform(d, key)\n        return d\n\n\nclass AffineBoxToImageCoordinated(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based transform that converts box in world coordinate to image coordinate.\n\n    Args:\n        box_keys: Keys to pick box data for transformation. 
The box mode is assumed to be ``StandardMode``.\n        box_ref_image_keys: The single key that represents the reference image to which ``box_keys`` are attached.\n        remove_empty: whether to remove the boxes that are actually empty\n        allow_missing_keys: don't raise exception if key is missing.\n        image_meta_key: explicitly indicate the key of the corresponding metadata dictionary.\n            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n            the metadata is a dictionary object which contains: filename, affine, original_shape, etc.\n            it is a string, map to the `box_ref_image_key`.\n            if None, will try to construct meta_keys by `box_ref_image_key_{meta_key_postfix}`.\n        image_meta_key_postfix: if image_meta_keys=None, use `box_ref_image_key_{postfix}` to fetch the metadata according\n            to the key data, default is `meta_dict`, the metadata is a dictionary object.\n            For example, to handle key `image`,  read/write affine matrices from the\n            metadata `image_meta_dict` dictionary's `affine` field.\n        affine_lps_to_ras: default ``False``. 
Yet if 1) the image is read by ITKReader,\n            and 2) the ITKReader has affine_lps_to_ras=True, and 3) the box is in world coordinate,\n            then set ``affine_lps_to_ras=True``.\n    \"\"\"\n\n    def __init__(\n        self,\n        box_keys: KeysCollection,\n        box_ref_image_keys: str,\n        allow_missing_keys: bool = False,\n        image_meta_key: str | None = None,\n        image_meta_key_postfix: str | None = DEFAULT_POST_FIX,\n        affine_lps_to_ras: bool = False,\n    ) -> None:\n        super().__init__(box_keys, allow_missing_keys)\n        box_ref_image_keys_tuple = ensure_tuple(box_ref_image_keys)\n        if len(box_ref_image_keys_tuple) > 1:\n            raise ValueError(\n                \"Please provide a single key for box_ref_image_keys.\\\n                All boxes of box_keys are attached to box_ref_image_keys.\"\n            )\n        self.box_ref_image_keys = box_ref_image_keys\n        self.image_meta_key = image_meta_key or f\"{box_ref_image_keys}_{image_meta_key_postfix}\"\n        self.converter_to_image_coordinate = AffineBox()\n        self.affine_lps_to_ras = affine_lps_to_ras\n\n    def extract_affine(self, data: Mapping[Hashable, torch.Tensor]) -> tuple[NdarrayOrTensor, torch.Tensor]:\n        d = dict(data)\n\n        meta_key = self.image_meta_key\n        # extract affine matrix from metadata\n        if isinstance(d[self.box_ref_image_keys], MetaTensor):\n            meta_dict = d[self.box_ref_image_keys].meta  # type: ignore\n        elif meta_key in d:\n            meta_dict = d[meta_key]\n        else:\n            raise ValueError(f\"{meta_key} is not found. Please check whether it is the correct the image meta key.\")\n        if \"affine\" not in meta_dict:\n            raise ValueError(\n                f\"'affine' is not found in {meta_key}. 
\\\n                Please check whether it is the correct the image meta key.\"\n            )\n        affine: NdarrayOrTensor = meta_dict[\"affine\"]\n\n        if self.affine_lps_to_ras:  # RAS affine\n            affine = orientation_ras_lps(affine)\n\n        # when convert boxes from world coordinate to image coordinate,\n        # we apply inverse affine transform\n        affine_t, *_ = convert_data_type(affine, torch.Tensor)\n        # torch.inverse should not run in half precision\n        inv_affine_t = torch.inverse(affine_t.to(COMPUTE_DTYPE))\n        return affine, inv_affine_t\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n\n        affine, inv_affine_t = self.extract_affine(data)  # type: ignore\n\n        for key in self.key_iterator(d):\n            d[key] = self.converter_to_image_coordinate(d[key], affine=inv_affine_t)\n            self.push_transform(d, key, extra_info={\"affine\": affine})\n        return d\n\n    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            transform = self.get_most_recent_transform(d, key)\n            affine = transform[\"extra_info\"][\"affine\"]\n            d[key] = AffineBox()(d[key], affine=affine)\n            self.pop_transform(d, key)\n        return d\n\n\nclass AffineBoxToWorldCoordinated(AffineBoxToImageCoordinated):\n    \"\"\"\n    Dictionary-based transform that converts box in image coordinate to world coordinate.\n\n    Args:\n        box_keys: Keys to pick box data for transformation. 
The box mode is assumed to be ``StandardMode``.\n        box_ref_image_keys: The single key that represents the reference image to which ``box_keys`` are attached.\n        remove_empty: whether to remove the boxes that are actually empty\n        allow_missing_keys: don't raise exception if key is missing.\n        image_meta_key: explicitly indicate the key of the corresponding metadata dictionary.\n            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n            the metadata is a dictionary object which contains: filename, affine, original_shape, etc.\n            it is a string, map to the `box_ref_image_key`.\n            if None, will try to construct meta_keys by `box_ref_image_key_{meta_key_postfix}`.\n        image_meta_key_postfix: if image_meta_keys=None, use `box_ref_image_key_{postfix}` to fetch the metadata according\n            to the key data, default is `meta_dict`, the metadata is a dictionary object.\n            For example, to handle key `image`,  read/write affine matrices from the\n            metadata `image_meta_dict` dictionary's `affine` field.\n        affine_lps_to_ras: default ``False``. 
Yet if 1) the image is read by ITKReader,\n            and 2) the ITKReader has affine_lps_to_ras=True, and 3) the box is in world coordinate,\n            then set ``affine_lps_to_ras=True``.\n    \"\"\"\n\n    def __init__(\n        self,\n        box_keys: KeysCollection,\n        box_ref_image_keys: str,\n        allow_missing_keys: bool = False,\n        image_meta_key: str | None = None,\n        image_meta_key_postfix: str | None = DEFAULT_POST_FIX,\n        affine_lps_to_ras: bool = False,\n    ) -> None:\n        super().__init__(\n            box_keys, box_ref_image_keys, allow_missing_keys, image_meta_key, image_meta_key_postfix, affine_lps_to_ras\n        )\n        self.converter_to_world_coordinate = AffineBox()\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n\n        affine, inv_affine_t = self.extract_affine(data)  # type: ignore\n\n        for key in self.key_iterator(d):\n            d[key] = self.converter_to_world_coordinate(d[key], affine=affine)\n            self.push_transform(d, key, extra_info={\"affine\": inv_affine_t})\n        return d\n\n\nclass ZoomBoxd(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based transform that zooms input boxes and images with the given zoom scale.\n\n    Args:\n        image_keys: Keys to pick image data for transformation.\n        box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``.\n        box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached.\n        zoom: The zoom factor along the spatial axes.\n            If a float, zoom is the same for each spatial axis.\n            If a sequence, zoom should contain one value for each spatial axis.\n        mode: {``\"nearest\"``, ``\"nearest-exact\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}\n            The interpolation mode. 
Defaults to ``\"area\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            It also can be a sequence of string, each element corresponds to a key in ``keys``.\n        padding_mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. Defaults to ``\"edge\"``.\n            The mode to pad data after zooming.\n            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        align_corners: This only has an effect when mode is\n            'linear', 'bilinear', 'bicubic' or 'trilinear'. 
Default: None.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            It also can be a sequence of bool or None, each element corresponds to a key in ``keys``.\n        keep_size: Should keep original size (pad if needed), default is True.\n        allow_missing_keys: don't raise exception if key is missing.\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_keys: KeysCollection,\n        box_keys: KeysCollection,\n        box_ref_image_keys: KeysCollection,\n        zoom: Sequence[float] | float,\n        mode: SequenceStr = InterpolateMode.AREA,\n        padding_mode: SequenceStr = NumpyPadMode.EDGE,\n        align_corners: Sequence[bool | None] | bool | None = None,\n        keep_size: bool = True,\n        allow_missing_keys: bool = False,\n        **kwargs: Any,\n    ) -> None:\n        self.image_keys = ensure_tuple(image_keys)\n        self.box_keys = ensure_tuple(box_keys)\n        super().__init__(self.image_keys + self.box_keys, allow_missing_keys)\n        self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys))\n\n        self.mode = ensure_tuple_rep(mode, len(self.image_keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.image_keys))\n        self.align_corners = ensure_tuple_rep(align_corners, len(self.image_keys))\n        self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs)\n        self.keep_size = keep_size\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d: dict[Hashable, torch.Tensor] = dict(data)\n\n        # zoom box\n        for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):\n            src_spatial_size = d[box_ref_image_key].shape[1:]\n            dst_spatial_size = 
[int(round(z * ss)) for z, ss in zip(self.zoomer.zoom, src_spatial_size)]  # type: ignore\n            self.zoomer.zoom = [ds / float(ss) for ss, ds in zip(src_spatial_size, dst_spatial_size)]\n            d[box_key] = ZoomBox(zoom=self.zoomer.zoom, keep_size=self.keep_size)(\n                d[box_key], src_spatial_size=src_spatial_size\n            )\n            self.push_transform(\n                d,\n                box_key,\n                extra_info={\"zoom\": self.zoomer.zoom, \"src_spatial_size\": src_spatial_size, \"type\": \"box_key\"},\n            )\n\n        # zoom image\n        for key, mode, padding_mode, align_corners in zip(\n            self.image_keys, self.mode, self.padding_mode, self.align_corners\n        ):\n            d[key] = self.zoomer(d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners)\n\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d: dict[Hashable, torch.Tensor] = dict(data)\n\n        for key in self.key_iterator(d):\n            transform = self.get_most_recent_transform(d, key, check=False)\n            key_type = transform[TraceKeys.EXTRA_INFO].get(\"type\", \"image_key\")\n            # zoom image, copied from monai.transforms.spatial.dictionary.Zoomd\n            if key_type == \"image_key\":\n                d[key] = self.zoomer.inverse(d[key])\n\n            # zoom boxes\n            if key_type == \"box_key\":\n                zoom = np.array(transform[TraceKeys.EXTRA_INFO][\"zoom\"])\n                src_spatial_size = transform[TraceKeys.EXTRA_INFO][\"src_spatial_size\"]\n                box_inverse_transform = ZoomBox(zoom=(1 / zoom).tolist(), keep_size=self.zoomer.keep_size)\n                d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size)\n                # Remove the applied transform\n                self.pop_transform(d, key)\n\n        return d\n\n\nclass 
RandZoomBoxd(RandomizableTransform, MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based transform that randomly zooms input boxes and images with given probability within given zoom range.\n\n    Args:\n        image_keys: Keys to pick image data for transformation.\n        box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``.\n        box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached.\n        prob: Probability of zooming.\n        min_zoom: Min zoom factor. Can be float or sequence same size as image.\n            If a float, select a random factor from `[min_zoom, max_zoom]` then apply to all spatial dims\n            to keep the original spatial shape ratio.\n            If a sequence, min_zoom should contain one value for each spatial axis.\n            If 2 values provided for 3D data, use the first value for both H & W dims to keep the same zoom ratio.\n        max_zoom: Max zoom factor. Can be float or sequence same size as image.\n            If a float, select a random factor from `[min_zoom, max_zoom]` then apply to all spatial dims\n            to keep the original spatial shape ratio.\n            If a sequence, max_zoom should contain one value for each spatial axis.\n            If 2 values provided for 3D data, use the first value for both H & W dims to keep the same zoom ratio.\n        mode: {``\"nearest\"``, ``\"nearest-exact\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}\n            The interpolation mode. 
Defaults to ``\"area\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            It also can be a sequence of string, each element corresponds to a key in ``keys``.\n        padding_mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. Defaults to ``\"edge\"``.\n            The mode to pad data after zooming.\n            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        align_corners: This only has an effect when mode is\n            'linear', 'bilinear', 'bicubic' or 'trilinear'. 
Default: None.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            It also can be a sequence of bool or None, each element corresponds to a key in ``keys``.\n        keep_size: Should keep original size (pad if needed), default is True.\n        allow_missing_keys: don't raise exception if key is missing.\n        kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension.\n            more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n    \"\"\"\n\n    backend = RandZoom.backend\n\n    def __init__(\n        self,\n        image_keys: KeysCollection,\n        box_keys: KeysCollection,\n        box_ref_image_keys: KeysCollection,\n        prob: float = 0.1,\n        min_zoom: Sequence[float] | float = 0.9,\n        max_zoom: Sequence[float] | float = 1.1,\n        mode: SequenceStr = InterpolateMode.AREA,\n        padding_mode: SequenceStr = NumpyPadMode.EDGE,\n        align_corners: Sequence[bool | None] | bool | None = None,\n        keep_size: bool = True,\n        allow_missing_keys: bool = False,\n        **kwargs: Any,\n    ) -> None:\n        self.image_keys = ensure_tuple(image_keys)\n        self.box_keys = ensure_tuple(box_keys)\n        MapTransform.__init__(self, self.image_keys + self.box_keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys))\n\n        self.rand_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, keep_size=keep_size, **kwargs)\n        self.mode = ensure_tuple_rep(mode, len(self.image_keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.image_keys))\n        self.align_corners = ensure_tuple_rep(align_corners, len(self.image_keys))\n        self.keep_size = keep_size\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | 
None = None) -> RandZoomBoxd:\n        super().set_random_state(seed, state)\n        self.rand_zoom.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            return d\n\n        self.randomize(None)\n\n        # all the keys share the same random zoom factor\n        self.rand_zoom.randomize(d[first_key])\n\n        # zoom box\n        for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):\n            if self._do_transform:\n                src_spatial_size = d[box_ref_image_key].shape[1:]\n                dst_spatial_size = [int(round(z * ss)) for z, ss in zip(self.rand_zoom._zoom, src_spatial_size)]\n                self.rand_zoom._zoom = [ds / float(ss) for ss, ds in zip(src_spatial_size, dst_spatial_size)]\n\n                d[box_key] = ZoomBox(zoom=self.rand_zoom._zoom, keep_size=self.keep_size)(\n                    d[box_key], src_spatial_size=src_spatial_size\n                )\n                self.push_transform(\n                    d,\n                    box_key,\n                    extra_info={\"zoom\": self.rand_zoom._zoom, \"src_spatial_size\": src_spatial_size, \"type\": \"box_key\"},\n                )\n\n        # zoom image, copied from monai.transforms.spatial.dictionary.RandZoomd\n        for key, mode, padding_mode, align_corners in zip(\n            self.image_keys, self.mode, self.padding_mode, self.align_corners\n        ):\n            if self._do_transform:\n                d[key] = self.rand_zoom(\n                    d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, randomize=False\n                )\n            else:\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            if get_track_meta():\n                xform = 
self.pop_transform(d[key], check=False) if self._do_transform else {}\n                self.push_transform(d[key], extra_info=xform)\n\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n\n        for key in self.key_iterator(d):\n            transform = self.get_most_recent_transform(d, key, check=False)\n            key_type = transform[TraceKeys.EXTRA_INFO].get(\"type\", \"image_key\")\n            # Check if random transform was actually performed (based on `prob`)\n            if transform[TraceKeys.DO_TRANSFORM]:\n                # zoom image, copied from monai.transforms.spatial.dictionary.Zoomd\n                if key_type == \"image_key\":\n                    xform = self.pop_transform(d[key])\n                    d[key].applied_operations.append(xform[TraceKeys.EXTRA_INFO])  # type: ignore\n                    d[key] = self.rand_zoom.inverse(d[key])\n\n                # zoom boxes\n                if key_type == \"box_key\":\n                    # Create inverse transform\n                    zoom = np.array(transform[TraceKeys.EXTRA_INFO][\"zoom\"])\n                    src_spatial_size = transform[TraceKeys.EXTRA_INFO][\"src_spatial_size\"]\n                    box_inverse_transform = ZoomBox(zoom=(1.0 / zoom).tolist(), keep_size=self.rand_zoom.keep_size)\n                    d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size)\n                    # Remove the applied transform\n                    self.pop_transform(d, key)\n        return d\n\n\nclass FlipBoxd(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based transform that flip boxes and images.\n\n    Args:\n        image_keys: Keys to pick image data for transformation.\n        box_keys: Keys to pick box data for transformation. 
The box mode is assumed to be ``StandardMode``.\n        box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached.\n        spatial_axis: Spatial axes along which to flip over. Default is None.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = Flip.backend\n\n    def __init__(\n        self,\n        image_keys: KeysCollection,\n        box_keys: KeysCollection,\n        box_ref_image_keys: KeysCollection,\n        spatial_axis: Sequence[int] | int | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        self.image_keys = ensure_tuple(image_keys)\n        self.box_keys = ensure_tuple(box_keys)\n        super().__init__(self.image_keys + self.box_keys, allow_missing_keys)\n        self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys))\n\n        self.flipper = Flip(spatial_axis=spatial_axis)\n        self.box_flipper = FlipBox(spatial_axis=self.flipper.spatial_axis)\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n\n        for key in self.image_keys:\n            d[key] = self.flipper(d[key])\n\n        for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):\n            spatial_size = d[box_ref_image_key].shape[1:]\n            d[box_key] = self.box_flipper(d[box_key], spatial_size)\n            self.push_transform(d, box_key, extra_info={\"spatial_size\": spatial_size, \"type\": \"box_key\"})\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n\n        for key in self.key_iterator(d):\n            transform = self.get_most_recent_transform(d, key, check=False)\n            key_type = transform.get(TraceKeys.EXTRA_INFO, {}).get(\"type\", \"image_key\")\n\n            # flip image, copied from monai.transforms.spatial.dictionary.Flipd\n        
    if key_type == \"image_key\":\n                d[key] = self.flipper.inverse(d[key])\n\n            # flip boxes\n            if key_type == \"box_key\":\n                spatial_size = transform[TraceKeys.EXTRA_INFO][\"spatial_size\"]\n                d[key] = self.box_flipper(d[key], spatial_size)\n                # Remove the applied transform\n                self.pop_transform(d, key)\n        return d\n\n\nclass RandFlipBoxd(RandomizableTransform, MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based transform that randomly flip boxes and images with the given probabilities.\n\n    Args:\n        image_keys: Keys to pick image data for transformation.\n        box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``.\n        box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached.\n        prob: Probability of flipping.\n        spatial_axis: Spatial axes along which to flip over. 
Default is None.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = RandFlip.backend\n\n    def __init__(\n        self,\n        image_keys: KeysCollection,\n        box_keys: KeysCollection,\n        box_ref_image_keys: KeysCollection,\n        prob: float = 0.1,\n        spatial_axis: Sequence[int] | int | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        self.image_keys = ensure_tuple(image_keys)\n        self.box_keys = ensure_tuple(box_keys)\n        MapTransform.__init__(self, self.image_keys + self.box_keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys))\n\n        self.flipper = Flip(spatial_axis=spatial_axis)\n        self.box_flipper = FlipBox(spatial_axis=spatial_axis)\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandFlipBoxd:\n        super().set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        self.randomize(None)\n\n        for key in self.image_keys:\n            if self._do_transform:\n                d[key] = self.flipper(d[key])\n            else:\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            if get_track_meta():\n                xform_info = self.pop_transform(d[key], check=False) if self._do_transform else {}\n                self.push_transform(d[key], extra_info=xform_info)\n\n        for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):\n            spatial_size = d[box_ref_image_key].shape[1:]\n            if self._do_transform:\n                d[box_key] = self.box_flipper(d[box_key], spatial_size)\n            self.push_transform(d, box_key, extra_info={\"spatial_size\": spatial_size, 
\"type\": \"box_key\"})\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n\n        for key in self.key_iterator(d):\n            transform = self.get_most_recent_transform(d, key, check=False)\n            key_type = transform[TraceKeys.EXTRA_INFO].get(\"type\", \"image_key\")\n            # Check if random transform was actually performed (based on `prob`)\n            if transform[TraceKeys.DO_TRANSFORM]:\n                # flip image, copied from monai.transforms.spatial.dictionary.RandFlipd\n                if key_type == \"image_key\":\n                    with self.flipper.trace_transform(False):\n                        d[key] = self.flipper(d[key])\n\n                # flip boxes\n                if key_type == \"box_key\":\n                    spatial_size = transform[TraceKeys.EXTRA_INFO][\"spatial_size\"]\n                    d[key] = self.box_flipper(d[key], spatial_size)\n\n            # Remove the applied transform\n            self.pop_transform(d, key, check=False)\n        return d\n\n\nclass ClipBoxToImaged(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.ClipBoxToImage`.\n\n    Clip the bounding boxes and the associated labels/scores to makes sure they are within the image.\n    There might be multiple keys of labels/scores associated with one key of boxes.\n\n    Args:\n        box_keys: The single key to pick box data for transformation. The box mode is assumed to be ``StandardMode``.\n        label_keys: Keys that represent the labels corresponding to the ``box_keys``. 
Multiple keys are allowed.\n        box_ref_image_keys: The single key that represents the reference image\n            to which ``box_keys`` and ``label_keys`` are attached.\n        remove_empty: whether to remove the boxes that are actually empty\n        allow_missing_keys: don't raise exception if key is missing.\n\n    Example:\n        .. code-block:: python\n\n            ClipBoxToImaged(\n                box_keys=\"boxes\", box_ref_image_keys=\"image\", label_keys=[\"labels\", \"scores\"], remove_empty=True\n            )\n    \"\"\"\n\n    def __init__(\n        self,\n        box_keys: KeysCollection,\n        label_keys: KeysCollection,\n        box_ref_image_keys: KeysCollection,\n        remove_empty: bool = True,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        box_keys_tuple = ensure_tuple(box_keys)\n        if len(box_keys_tuple) != 1:\n            raise ValueError(\n                \"Please provide a single key for box_keys.\\\n                All label_keys are attached to this box_keys.\"\n            )\n        box_ref_image_keys_tuple = ensure_tuple(box_ref_image_keys)\n        if len(box_ref_image_keys_tuple) != 1:\n            raise ValueError(\n                \"Please provide a single key for box_ref_image_keys.\\\n                All box_keys and label_keys are attached to this box_ref_image_keys.\"\n            )\n        self.label_keys = ensure_tuple(label_keys)\n        super().__init__(box_keys_tuple, allow_missing_keys)\n\n        self.box_keys = box_keys_tuple[0]\n        self.box_ref_image_keys = box_ref_image_keys_tuple[0]\n        self.clipper = ClipBoxToImage(remove_empty=remove_empty)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        spatial_size = d[self.box_ref_image_keys].shape[1:]\n        labels = [d[label_key] for label_key in self.label_keys]  # could be multiple arrays\n        d[self.box_keys], clipped_labels = 
self.clipper(d[self.box_keys], labels, spatial_size)\n\n        for label_key, clipped_labels_i in zip(self.label_keys, clipped_labels):\n            d[label_key] = clipped_labels_i\n        return d\n\n\nclass BoxToMaskd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.BoxToMask`.\n    Pairs with :py:class:`monai.apps.detection.transforms.dictionary.MaskToBoxd` .\n    Please make sure the same ``min_fg_label`` is used when using the two transforms in pairs.\n    The output ``d[box_mask_key]`` will have background intensity 0, since the following operations\n    may pad 0 on the border.\n\n    This is the general solution for transforms that need to be applied on images and boxes simultaneously.\n    It is performed with the following steps.\n\n        1) use ``BoxToMaskd`` to convert boxes and labels to box_masks;\n        2) do transforms, e.g., rotation or cropping, on images and box_masks together;\n        3) use ``MaskToBoxd`` to convert box_masks back to boxes and labels.\n\n    Args:\n        box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``.\n        box_mask_keys: Keys to store output box mask results for transformation. Same length with ``box_keys``.\n        label_keys: Keys that represent the labels corresponding to the ``box_keys``. Same length with ``box_keys``.\n        box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached.\n        min_fg_label: min foreground box label.\n        ellipse_mask: bool.\n\n            - If True, it assumes the object shape is close to ellipse or ellipsoid.\n            - If False, it assumes the object shape is close to rectangle or cube and well occupies the bounding box.\n            - If the users are going to apply random rotation as data augmentation, we suggest setting ellipse_mask=True\n              See also Kalra et al. 
\"Towards Rotation Invariance in Object Detection\", ICCV 2021.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    Example:\n        .. code-block:: python\n\n            # This code snippet creates transforms (random rotation and cropping) on boxes, labels, and image together.\n            import numpy as np\n            from monai.transforms import Compose, RandRotated, RandSpatialCropd, DeleteItemsd\n            transforms = Compose(\n                [\n                    BoxToMaskd(\n                        box_keys=\"boxes\", label_keys=\"labels\",\n                        box_mask_keys=\"box_mask\", box_ref_image_keys=\"image\",\n                        min_fg_label=0, ellipse_mask=True\n                    ),\n                    RandRotated(keys=[\"image\",\"box_mask\"],mode=[\"nearest\",\"nearest\"],\n                        prob=0.2,range_x=np.pi/6,range_y=np.pi/6,range_z=np.pi/6,\n                        keep_size=True,padding_mode=\"zeros\"\n                    ),\n                    RandSpatialCropd(keys=[\"image\",\"box_mask\"],roi_size=128, random_size=False),\n                    MaskToBoxd(\n                        box_mask_keys=\"box_mask\", box_keys=\"boxes\",\n                        label_keys=\"labels\", min_fg_label=0\n                    )\n                    DeleteItemsd(keys=[\"box_mask\"]),\n                ]\n            )\n\n    \"\"\"\n\n    def __init__(\n        self,\n        box_keys: KeysCollection,\n        box_mask_keys: KeysCollection,\n        label_keys: KeysCollection,\n        box_ref_image_keys: KeysCollection,\n        min_fg_label: int,\n        ellipse_mask: bool = False,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(box_keys, allow_missing_keys)\n        self.box_keys = ensure_tuple(box_keys)\n        self.label_keys = ensure_tuple(label_keys)\n        self.box_mask_keys = ensure_tuple(box_mask_keys)\n        if not len(self.label_keys) == 
len(self.box_keys) == len(self.box_mask_keys):\n            raise ValueError(\"Please make sure len(label_keys)==len(box_keys)==len(box_mask_keys)!\")\n        self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys))\n        self.bg_label = min_fg_label - 1  # make sure background label is always smaller than fg labels.\n        self.converter = BoxToMask(bg_label=self.bg_label, ellipse_mask=ellipse_mask)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n\n        for box_key, label_key, box_mask_key, box_ref_image_key in zip(\n            self.box_keys, self.label_keys, self.box_mask_keys, self.box_ref_image_keys\n        ):\n            spatial_size = d[box_ref_image_key].shape[1:]\n            d[box_mask_key] = self.converter(d[box_key], d[label_key], spatial_size)\n            # make box mask background intensity to be 0, since the following operations may pad 0 on the border.\n            d[box_mask_key] -= self.bg_label\n        return d\n\n\nclass MaskToBoxd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.detection.transforms.array.MaskToBox`.\n    Pairs with :py:class:`monai.apps.detection.transforms.dictionary.BoxToMaskd` .\n    Please make sure the same ``min_fg_label`` is used when using the two transforms in pairs.\n\n    This is the general solution for transforms that need to be applied on images and boxes simultaneously.\n    It is performed with the following steps.\n\n        1) use ``BoxToMaskd`` to convert boxes and labels to box_masks;\n        2) do transforms, e.g., rotation or cropping, on images and box_masks together;\n        3) use ``MaskToBoxd`` to convert box_masks back to boxes and labels.\n\n    Args:\n        box_keys: Keys to pick box data for transformation. The box mode is assumed to be ``StandardMode``.\n        box_mask_keys: Keys to store output box mask results for transformation. 
Same length with ``box_keys``.\n        label_keys: Keys that represent the labels corresponding to the ``box_keys``. Same length with ``box_keys``.\n        min_fg_label: min foreground box label.\n        box_dtype: output dtype for box_keys\n        label_dtype: output dtype for label_keys\n        allow_missing_keys: don't raise exception if key is missing.\n\n    Example:\n        .. code-block:: python\n\n            # This code snippet creates transforms (random rotation and cropping) on boxes, labels, and images together.\n            import numpy as np\n            from monai.transforms import Compose, RandRotated, RandSpatialCropd, DeleteItemsd\n            transforms = Compose(\n                [\n                    BoxToMaskd(\n                        box_keys=\"boxes\", label_keys=\"labels\",\n                        box_mask_keys=\"box_mask\", box_ref_image_keys=\"image\",\n                        min_fg_label=0, ellipse_mask=True\n                    ),\n                    RandRotated(keys=[\"image\",\"box_mask\"],mode=[\"nearest\",\"nearest\"],\n                        prob=0.2,range_x=np.pi/6,range_y=np.pi/6,range_z=np.pi/6,\n                        keep_size=True,padding_mode=\"zeros\"\n                    ),\n                    RandSpatialCropd(keys=[\"image\",\"box_mask\"],roi_size=128, random_size=False),\n                    MaskToBoxd(\n                        box_mask_keys=\"box_mask\", box_keys=\"boxes\",\n                        label_keys=\"labels\", min_fg_label=0\n                    ),\n                    DeleteItemsd(keys=[\"box_mask\"]),\n                ]\n            )\n    \"\"\"\n\n    def __init__(\n        self,\n        box_keys: KeysCollection,\n        box_mask_keys: KeysCollection,\n        label_keys: KeysCollection,\n        min_fg_label: int,\n        box_dtype: DtypeLike | torch.dtype = torch.float32,\n        label_dtype: DtypeLike | torch.dtype = torch.long,\n        allow_missing_keys: bool = False,\n    ) -> 
None:\n        super().__init__(box_keys, allow_missing_keys)\n        self.box_keys = ensure_tuple(box_keys)\n        self.label_keys = ensure_tuple(label_keys)\n        self.box_mask_keys = ensure_tuple(box_mask_keys)\n        if not len(self.label_keys) == len(self.box_keys) == len(self.box_mask_keys):\n            raise ValueError(\"Please make sure len(label_keys)==len(box_keys)==len(box_mask_keys)!\")\n        self.bg_label = min_fg_label - 1  # make sure background label is always smaller than fg labels.\n        self.converter = MaskToBox(bg_label=self.bg_label, box_dtype=box_dtype, label_dtype=label_dtype)\n        self.box_dtype = box_dtype\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n\n        for box_key, label_key, box_mask_key in zip(self.box_keys, self.label_keys, self.box_mask_keys):\n            d[box_mask_key] += self.bg_label  # pairs with the operation in BoxToMaskd\n            d[box_key], d[label_key] = self.converter(d[box_mask_key])\n        return d\n\n\nclass RandCropBoxByPosNegLabeld(Randomizable, MapTransform):\n    \"\"\"\n    Crop random fixed sized regions that contain foreground boxes.\n    Suppose all the expected fields specified by `image_keys` have the same shape,\n    and add `patch_index` to the corresponding meta data.\n    And will return a list of dictionaries for all the cropped images.\n    If a dimension of the expected spatial size is bigger than the input image size,\n    will not crop that dimension. So the cropped result may be smaller than the expected size,\n    and the cropped results of several images may not have exactly the same shape.\n\n    Args:\n        image_keys: Keys to pick image data for transformation. They need to have the same spatial size.\n        box_keys: The single key to pick box data for transformation. 
The box mode is assumed to be ``StandardMode``.\n        label_keys: Keys that represent the labels corresponding to the ``box_keys``. Multiple keys are allowed.\n        spatial_size: the spatial size of the crop region e.g. [224, 224, 128].\n            if a dimension of ROI size is bigger than image size, will not crop that dimension of the image.\n            if its components have non-positive values, the corresponding size of `data[label_key]` will be used.\n            for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`,\n            the spatial size of output data will be [32, 40, 40].\n        pos: used with `neg` together to calculate the ratio ``pos / (pos + neg)`` for the probability\n            to pick a foreground voxel as a center rather than a background voxel.\n        neg: used with `pos` together to calculate the ratio ``pos / (pos + neg)`` for the probability\n            to pick a foreground voxel as a center rather than a background voxel.\n        num_samples: number of samples (crop regions) to take in each list.\n        whole_box: Bool, default True, whether we prefer to contain at least one whole box in the cropped foreground patch.\n            Even if True, it is still possible to get partial box if there are multiple boxes in the image.\n        thresh_image_key: if thresh_image_key is not None, use ``label == 0 & thresh_image > image_threshold`` to select\n            the negative sample(background) center. 
so the crop center will only exist on valid image area.\n        image_threshold: if enabled thresh_image_key, use ``thresh_image > image_threshold`` to determine\n            the valid image content area.\n        fg_indices_key: if provided pre-computed foreground indices of `label`, will ignore above `image_key` and\n            `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices_key`\n            and `bg_indices_key` together, expect to be 1 dim array of spatial indices after flattening.\n            a typical usage is to call `FgBgToIndicesd` transform first and cache the results.\n        bg_indices_key: if provided pre-computed background indices of `label`, will ignore above `image_key` and\n            `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices_key`\n            and `bg_indices_key` together, expect to be 1 dim array of spatial indices after flattening.\n            a typical usage is to call `FgBgToIndicesd` transform first and cache the results.\n        meta_keys: explicitly indicate the key of the corresponding metadata dictionary.\n            used to add `patch_index` to the meta dict.\n            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n            the metadata is a dictionary object which contains: filename, original_shape, etc.\n            it can be a sequence of string, map to the `keys`.\n            if None, will try to construct meta_keys by `key_{meta_key_postfix}`.\n        meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according\n            to the key data, default is `meta_dict`, the metadata is a dictionary object.\n            used to add `patch_index` to the meta dict.\n        allow_smaller: if `False`, an exception will be raised if the image is smaller than\n            the requested ROI in any dimension. 
If `True`, any smaller dimensions will be set to\n            match the cropped size (i.e., no cropping in that dimension).\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_keys: KeysCollection,\n        box_keys: str,\n        label_keys: KeysCollection,\n        spatial_size: Sequence[int] | int,\n        pos: float = 1.0,\n        neg: float = 1.0,\n        num_samples: int = 1,\n        whole_box: bool = True,\n        thresh_image_key: str | None = None,\n        image_threshold: float = 0.0,\n        fg_indices_key: str | None = None,\n        bg_indices_key: str | None = None,\n        meta_keys: KeysCollection | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        allow_smaller: bool = False,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        self.image_keys = ensure_tuple(image_keys)\n        if len(self.image_keys) < 1:\n            raise ValueError(\"At least one image_keys should be provided.\")\n\n        MapTransform.__init__(self, self.image_keys, allow_missing_keys)\n\n        box_keys_tuple = ensure_tuple(box_keys)\n        if len(box_keys_tuple) != 1:\n            raise ValueError(\n                \"Please provide a single key for box_keys.\\\n                All label_keys are attached to this box_keys.\"\n            )\n        self.box_keys = box_keys_tuple[0]\n        self.label_keys = ensure_tuple(label_keys)\n\n        self.spatial_size_: tuple[int, ...] 
| Sequence[int] | int = spatial_size\n\n        if pos < 0 or neg < 0:\n            raise ValueError(f\"pos and neg must be nonnegative, got pos={pos} neg={neg}.\")\n        if pos + neg == 0:\n            raise ValueError(\"Incompatible values: pos=0 and neg=0.\")\n        self.pos_ratio = pos / (pos + neg)\n        if num_samples < 1:\n            raise ValueError(f\"num_samples needs to be positive int, got num_samples={num_samples}.\")\n        self.num_samples = num_samples\n        self.whole_box = whole_box\n\n        self.thresh_image_key = thresh_image_key\n        self.image_threshold = image_threshold\n        self.fg_indices_key = fg_indices_key\n        self.bg_indices_key = bg_indices_key\n\n        self.meta_keys = ensure_tuple_rep(None, len(self.image_keys)) if meta_keys is None else ensure_tuple(meta_keys)\n        if len(self.image_keys) != len(self.meta_keys):\n            raise ValueError(\"meta_keys should have the same length as keys.\")\n        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.image_keys))\n        self.centers: tuple[tuple] | None = None\n        self.allow_smaller = allow_smaller\n\n    def generate_fg_center_boxes_np(self, boxes: NdarrayOrTensor, image_size: Sequence[int]) -> np.ndarray:\n        # We don't require crop center to be within the boxes.\n        # As along as the cropped patch contains a box, it is considered as a foreground patch.\n        # Positions within extended_boxes are crop centers for foreground patches\n        spatial_dims = len(image_size)\n        boxes_np, *_ = convert_data_type(boxes, np.ndarray)\n\n        extended_boxes = np.zeros_like(boxes_np, dtype=int)\n        boxes_start = np.ceil(boxes_np[:, :spatial_dims]).astype(int)\n        boxes_stop = np.floor(boxes_np[:, spatial_dims:]).astype(int)\n        for axis in range(spatial_dims):\n            if not self.whole_box:\n                extended_boxes[:, axis] = boxes_start[:, axis] - self.spatial_size[axis] // 2 + 1\n   
             extended_boxes[:, axis + spatial_dims] = boxes_stop[:, axis] + self.spatial_size[axis] // 2 - 1\n            else:\n                # the cropper will extend an additional pixel to the left side when the size is even\n                radius_left = self.spatial_size[axis] // 2\n                radius_right = self.spatial_size[axis] - radius_left - 1  # we subtract 1 for the center voxel\n                # extended box start\n                extended_boxes[:, axis] = boxes_stop[:, axis] - radius_right\n                extended_boxes[:, axis] = np.minimum(extended_boxes[:, axis], boxes_start[:, axis])\n                # extended box stop\n                extended_boxes[:, axis + spatial_dims] = boxes_start[:, axis] + radius_left\n                extended_boxes[:, axis + spatial_dims] = np.maximum(\n                    extended_boxes[:, axis + spatial_dims], boxes_stop[:, axis]\n                )\n        extended_boxes, _ = clip_boxes_to_image(extended_boxes, image_size, remove_empty=True)  # type: ignore\n        return extended_boxes\n\n    def randomize(  # type: ignore\n        self,\n        boxes: NdarrayOrTensor,\n        image_size: Sequence[int],\n        fg_indices: NdarrayOrTensor | None = None,\n        bg_indices: NdarrayOrTensor | None = None,\n        thresh_image: NdarrayOrTensor | None = None,\n    ) -> None:\n        if fg_indices is None or bg_indices is None:\n            # We don't require crop center to be within the boxes.\n            # As along as the cropped patch contains a box, it is considered as a foreground patch.\n            # Positions within extended_boxes are crop centers for foreground patches\n            extended_boxes_np = self.generate_fg_center_boxes_np(boxes, image_size)\n            mask_img = convert_box_to_mask(\n                extended_boxes_np, np.ones(extended_boxes_np.shape[0]), image_size, bg_label=0, ellipse_mask=False\n            )\n            mask_img = np.amax(mask_img, axis=0, keepdims=True)[0:1, 
...]\n            fg_indices_, bg_indices_ = map_binary_to_indices(mask_img, thresh_image, self.image_threshold)\n        else:\n            fg_indices_ = fg_indices\n            bg_indices_ = bg_indices\n\n        self.centers = generate_pos_neg_label_crop_centers(\n            self.spatial_size,\n            self.num_samples,\n            self.pos_ratio,\n            image_size,\n            fg_indices_,\n            bg_indices_,\n            self.R,\n            self.allow_smaller,\n        )\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, torch.Tensor]]:\n        d = dict(data)\n        image_size = d[self.image_keys[0]].shape[1:]\n        self.spatial_size = fall_back_tuple(self.spatial_size_, image_size)\n\n        # randomly sample crop centers\n        boxes = d[self.box_keys]\n        labels = [d[label_key] for label_key in self.label_keys]  # could be multiple arrays\n        fg_indices = d.pop(self.fg_indices_key, None) if self.fg_indices_key is not None else None\n        bg_indices = d.pop(self.bg_indices_key, None) if self.bg_indices_key is not None else None\n        thresh_image = d[self.thresh_image_key] if self.thresh_image_key else None\n        self.randomize(boxes, image_size, fg_indices, bg_indices, thresh_image)\n\n        if self.centers is None:\n            raise ValueError(\"no available ROI centers to crop.\")\n\n        # initialize returned list with shallow copy to preserve key ordering\n        results: list[dict[Hashable, torch.Tensor]] = [dict(d) for _ in range(self.num_samples)]\n\n        # crop images and boxes for each center.\n        for i, center in enumerate(self.centers):\n            results[i] = deepcopy(d)\n            # compute crop start and end, always crop, no padding\n            cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)\n            crop_start = [max(s.start, 0) for s in cropper.slices]\n            crop_end = [min(s.stop, image_size_a) for 
s, image_size_a in zip(cropper.slices, image_size)]\n            crop_slices = [slice(int(s), int(e)) for s, e in zip(crop_start, crop_end)]\n\n            # crop images\n            cropper = SpatialCrop(roi_slices=crop_slices)\n            for image_key in self.image_keys:\n                results[i][image_key] = cropper(d[image_key])\n\n            # crop boxes and labels\n            boxcropper = SpatialCropBox(roi_slices=crop_slices)\n            results[i][self.box_keys], cropped_labels = boxcropper(boxes, labels)\n            for label_key, cropped_labels_i in zip(self.label_keys, cropped_labels):\n                results[i][label_key] = cropped_labels_i\n\n        return results\n\n\nclass RotateBox90d(MapTransform, InvertibleTransform):\n    \"\"\"\n    Input boxes and images are rotated by 90 degrees\n    in the plane specified by ``spatial_axes`` for ``k`` times\n\n    Args:\n        image_keys: Keys to pick image data for transformation.\n        box_keys: Keys to pick box data for transformation. 
The box mode is assumed to be ``StandardMode``.\n        box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached.\n        k: number of times to rotate by 90 degrees.\n        spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.\n            Default (0, 1), this is the first two axis in spatial dimensions.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = RotateBox90.backend\n\n    def __init__(\n        self,\n        image_keys: KeysCollection,\n        box_keys: KeysCollection,\n        box_ref_image_keys: KeysCollection,\n        k: int = 1,\n        spatial_axes: tuple[int, int] = (0, 1),\n        allow_missing_keys: bool = False,\n    ) -> None:\n        self.image_keys = ensure_tuple(image_keys)\n        self.box_keys = ensure_tuple(box_keys)\n        super().__init__(self.image_keys + self.box_keys, allow_missing_keys)\n        self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys))\n        self.img_rotator = Rotate90(k, spatial_axes)\n        self.box_rotator = RotateBox90(k, spatial_axes)\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):\n            spatial_size = list(d[box_ref_image_key].shape[1:])\n            d[key] = self.box_rotator(d[key], spatial_size)\n            if self.img_rotator.k % 2 == 1:\n                # if k = 1 or 3, spatial_size will be transposed\n                spatial_size[self.img_rotator.spatial_axes[0]], spatial_size[self.img_rotator.spatial_axes[1]] = (\n                    spatial_size[self.img_rotator.spatial_axes[1]],\n                    spatial_size[self.img_rotator.spatial_axes[0]],\n                )\n            self.push_transform(d, key, extra_info={\"spatial_size\": spatial_size, \"type\": \"box_key\"})\n\n        for 
key in self.image_keys:\n            d[key] = self.img_rotator(d[key])\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n\n        for key in self.key_iterator(d):\n            transform = self.get_most_recent_transform(d, key, check=False)\n            key_type = transform[TraceKeys.EXTRA_INFO].get(\"type\", \"image_key\")\n            num_times_to_rotate = 4 - self.img_rotator.k\n\n            if key_type == \"image_key\":\n                d[key] = self.img_rotator.inverse(d[key])\n            if key_type == \"box_key\":\n                spatial_size = transform[TraceKeys.EXTRA_INFO][\"spatial_size\"]\n                inverse_transform = RotateBox90(num_times_to_rotate, self.box_rotator.spatial_axes)\n                d[key] = inverse_transform(d[key], spatial_size)\n                self.pop_transform(d, key)\n        return d\n\n\nclass RandRotateBox90d(RandomizableTransform, MapTransform, InvertibleTransform):\n    \"\"\"\n    With probability `prob`, input boxes and images are rotated by 90 degrees\n    in the plane specified by `spatial_axes`.\n\n    Args:\n        image_keys: Keys to pick image data for transformation.\n        box_keys: Keys to pick box data for transformation. 
The box mode is assumed to be ``StandardMode``.\n        box_ref_image_keys: Keys that represent the reference images to which ``box_keys`` are attached.\n        prob: probability of rotating.\n            (Default 0.1, with 10% probability it returns a rotated array.)\n        max_k: number of rotations will be sampled from `np.random.randint(max_k) + 1`.\n            (Default 3)\n        spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.\n            Default: (0, 1), this is the first two axis in spatial dimensions.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = RotateBox90.backend\n\n    def __init__(\n        self,\n        image_keys: KeysCollection,\n        box_keys: KeysCollection,\n        box_ref_image_keys: KeysCollection,\n        prob: float = 0.1,\n        max_k: int = 3,\n        spatial_axes: tuple[int, int] = (0, 1),\n        allow_missing_keys: bool = False,\n    ) -> None:\n        self.image_keys = ensure_tuple(image_keys)\n        self.box_keys = ensure_tuple(box_keys)\n\n        MapTransform.__init__(self, self.image_keys + self.box_keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n\n        self.max_k = max_k\n        self.spatial_axes = spatial_axes\n        self._rand_k = 0\n        self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys))\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, torch.Tensor]:\n        self.randomize()\n        d = dict(data)\n\n        if self._rand_k % 4 == 0:\n            return d\n\n        # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need\n        # to be compatible with the random status of some previous integration tests\n        box_rotator = RotateBox90(self._rand_k, self.spatial_axes)\n        img_rotator = Rotate90(self._rand_k, self.spatial_axes)\n\n        for key, box_ref_image_key in 
zip(self.box_keys, self.box_ref_image_keys):\n            if self._do_transform:\n                spatial_size = list(d[box_ref_image_key].shape[1:])\n                d[key] = box_rotator(d[key], spatial_size)\n                if self._rand_k % 2 == 1:\n                    # if k = 1 or 3, spatial_size will be transposed\n                    spatial_size[self.spatial_axes[0]], spatial_size[self.spatial_axes[1]] = (\n                        spatial_size[self.spatial_axes[1]],\n                        spatial_size[self.spatial_axes[0]],\n                    )\n                self.push_transform(\n                    d, key, extra_info={\"rand_k\": self._rand_k, \"spatial_size\": spatial_size, \"type\": \"box_key\"}\n                )\n\n        for key in self.image_keys:\n            if self._do_transform:\n                d[key] = (\n                    img_rotator(d[key])\n                    if self._do_transform\n                    else convert_to_tensor(d[key], track_meta=get_track_meta())\n                )\n                if get_track_meta():\n                    xform = self.pop_transform(d[key], check=False) if self._do_transform else {}\n                    self.push_transform(d[key], extra_info=xform)\n        return d\n\n    def randomize(self, data: Any | None = None) -> None:\n        self._rand_k = self.R.randint(self.max_k) + 1\n        super().randomize(None)\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        if self._rand_k % 4 == 0:\n            return d\n\n        for key in self.key_iterator(d):\n            transform = self.get_most_recent_transform(d, key, check=False)\n            key_type = transform[TraceKeys.EXTRA_INFO].get(\"type\", \"image_key\")\n            # Check if random transform was actually performed (based on `prob`)\n            if transform[TraceKeys.DO_TRANSFORM]:\n                # flip image, copied from 
monai.transforms.spatial.dictionary.RandFlipd\n                if key_type == \"image_key\":\n                    xform = self.pop_transform(d, key, check=False)\n                    d[key] = Rotate90().inverse_transform(d[key], xform[TraceKeys.EXTRA_INFO])\n                if key_type == \"box_key\":\n                    num_times_rotated = transform[TraceKeys.EXTRA_INFO][\"rand_k\"]\n                    num_times_to_rotate = 4 - num_times_rotated\n                    spatial_size = transform[TraceKeys.EXTRA_INFO][\"spatial_size\"]\n                    inverse_transform = RotateBox90(num_times_to_rotate, self.spatial_axes)\n                    d[key] = inverse_transform(d[key], spatial_size)\n                    self.pop_transform(d, key)\n        return d\n\n\nConvertBoxModeD = ConvertBoxModeDict = ConvertBoxModed\nConvertBoxToStandardModeD = ConvertBoxToStandardModeDict = ConvertBoxToStandardModed\nZoomBoxD = ZoomBoxDict = ZoomBoxd\nRandZoomBoxD = RandZoomBoxDict = RandZoomBoxd\nAffineBoxToImageCoordinateD = AffineBoxToImageCoordinateDict = AffineBoxToImageCoordinated\nFlipBoxD = FlipBoxDict = FlipBoxd\nRandFlipBoxD = RandFlipBoxDict = RandFlipBoxd\nClipBoxToImageD = ClipBoxToImageDict = ClipBoxToImaged\nBoxToMaskD = BoxToMaskDict = BoxToMaskd\nMaskToBoxD = MaskToBoxDict = MaskToBoxd\nRandCropBoxByPosNegLabelD = RandCropBoxByPosNegLabelDict = RandCropBoxByPosNegLabeld\nRotateBox90D = RotateBox90Dict = RotateBox90d\nRandRotateBox90D = RandRotateBox90Dict = RandRotateBox90d\nStandardizeEmptyBoxD = StandardizeEmptyBoxDict = StandardizeEmptyBoxd\n"
  },
  {
    "path": "monai/apps/detection/utils/ATSS_matcher.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/matcher.py\n# which has the following license...\n# https://github.com/MIC-DKFZ/nnDetection/blob/main/LICENSE\n#\n# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#    http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/detection/_utils.py\n# which has the following license...\n# https://github.com/pytorch/vision/blob/main/LICENSE\n#\n# BSD 3-Clause License\n\n# Copyright (c) Soumith Chintala 2016,\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are 
permitted provided that the following conditions are met:\n\n# * Redistributions of source code must retain the above copyright notice, this\n#   list of conditions and the following disclaimer.\n\n# * Redistributions in binary form must reproduce the above copyright notice,\n#   this list of conditions and the following disclaimer in the documentation\n#   and/or other materials provided with the distribution.\n\n# * Neither the name of the copyright holder nor the names of its\n#   contributors may be used to endorse or promote products derived from\n#   this software without specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\"\"\"\nThe functions in this script are adapted from nnDetection,\nhttps://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/matcher.py\nwhich is adapted from torchvision.\n\nThese are the changes compared with nndetection:\n1) comments and docstrings;\n2) reformat;\n3) add a debug option to ATSSMatcher to help the users to tune parameters;\n4) add a corner case return in ATSSMatcher.compute_matches;\n5) add support for float16 cpu\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Callable, 
Sequence\nfrom typing import TypeVar\n\nimport torch\nfrom torch import Tensor\n\nfrom monai.data.box_utils import COMPUTE_DTYPE, box_iou, boxes_center_distance, centers_in_boxes\nfrom monai.utils.type_conversion import convert_to_tensor\n\n# -INF should be smaller than the lower bound of similarity_fn output.\nINF = float(\"inf\")\n\n\nclass Matcher(ABC):\n    \"\"\"\n    Base class of Matcher, which matches boxes and anchors to each other\n\n    Args:\n        similarity_fn: function for similarity computation between\n            boxes and anchors\n    \"\"\"\n\n    BELOW_LOW_THRESHOLD: int = -1\n    BETWEEN_THRESHOLDS: int = -2\n\n    def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou):  # type: ignore\n        self.similarity_fn = similarity_fn\n\n    def __call__(\n        self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Compute matches for a single image\n\n        Args:\n            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is assumed to be ``StandardMode``\n            anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``.\n            num_anchors_per_level: number of anchors per feature pyramid level\n            num_anchors_per_loc: number of anchors per position\n\n        Returns:\n            - matrix which contains the similarity from each boxes to each anchor [N, M]\n            - vector which contains the matched box index for all\n                anchors (if background `BELOW_LOW_THRESHOLD` is used\n                and if it should be ignored `BETWEEN_THRESHOLDS` is used) [M]\n\n        Note:\n            ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,\n            also represented as \"xyxy\" ([xmin, ymin, xmax, ymax]) for 2D\n            and \"xyzxyz\" ([xmin, ymin, zmin, xmax, ymax, zmax]) for 3D.\n        \"\"\"\n        if boxes.numel() == 0:\n            # no ground truth\n            num_anchors = anchors.shape[0]\n            match_quality_matrix = torch.tensor([]).to(anchors)\n            matches = torch.empty(num_anchors, dtype=torch.int64).fill_(self.BELOW_LOW_THRESHOLD)\n            return match_quality_matrix, matches\n        # at least one ground truth\n        return self.compute_matches(\n            boxes=boxes,\n            anchors=anchors,\n            num_anchors_per_level=num_anchors_per_level,\n            num_anchors_per_loc=num_anchors_per_loc,\n        )\n\n    @abstractmethod\n    def compute_matches(\n        self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Compute matches\n\n        Args:\n            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is assumed to be ``StandardMode``\n            anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``.\n            num_anchors_per_level: number of anchors per feature pyramid level\n            num_anchors_per_loc: number of anchors per position\n\n        Returns:\n            - matrix which contains the similarity from each boxes to each anchor [N, M]\n            - vector which contains the matched box index for all\n              anchors (if background `BELOW_LOW_THRESHOLD` is used\n              and if it should be ignored `BETWEEN_THRESHOLDS` is used) [M]\n        \"\"\"\n        raise NotImplementedError\n\n\nclass ATSSMatcher(Matcher):\n\n    def __init__(\n        self,\n        num_candidates: int = 4,\n        similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,  # type: ignore\n        center_in_gt: bool = True,\n        debug: bool = False,\n    ):\n        \"\"\"\n        Compute matching based on ATSS https://arxiv.org/abs/1912.02424\n        `Bridging the Gap Between Anchor-based and Anchor-free Detection\n        via Adaptive Training Sample Selection`\n\n        Args:\n            num_candidates: number of positions to select candidates from.\n                Smaller value will result in a higher matcher threshold and fewer matched candidates.\n            similarity_fn: function for similarity computation between boxes and anchors\n            center_in_gt: defaults to True. If False, matched anchor center points do not need\n                to lie within the ground truth box. 
Recommend False for small objects.\n                If True, will result in a strict matcher and fewer matched candidates.\n            debug: if True, will print the matcher threshold in order to\n                tune ``num_candidates`` and ``center_in_gt``.\n        \"\"\"\n        super().__init__(similarity_fn=similarity_fn)\n        self.num_candidates = num_candidates\n        self.min_dist = 0.01\n        self.center_in_gt = center_in_gt\n        self.debug = debug\n        logging.info(\n            f\"Running ATSS Matching with num_candidates={self.num_candidates} and center_in_gt {self.center_in_gt}.\"\n        )\n\n    def compute_matches(\n        self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Compute matches according to ATSS for a single image\n        Adapted from\n        (https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss/loss.py#L180-L184)\n\n        Args:\n            boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is assumed to be ``StandardMode``\n            anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``.\n            num_anchors_per_level: number of anchors per feature pyramid level\n            num_anchors_per_loc: number of anchors per position\n\n        Returns:\n            - matrix which contains the similarity from each boxes to each anchor [N, M]\n            - vector which contains the matched box index for all\n              anchors (if background `BELOW_LOW_THRESHOLD` is used\n              and if it should be ignored `BETWEEN_THRESHOLDS` is used) [M]\n\n        Note:\n            ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,\n            also represented as \"xyxy\" ([xmin, ymin, xmax, ymax]) for 2D\n            and \"xyzxyz\" ([xmin, ymin, zmin, xmax, ymax, zmax]) for 3D.\n        \"\"\"\n        num_gt = boxes.shape[0]\n        num_anchors = anchors.shape[0]\n\n        distances_, _, anchors_center = boxes_center_distance(boxes, anchors)  # num_boxes x anchors\n        distances = convert_to_tensor(distances_)\n\n        # select candidates based on center distance\n        candidate_idx_list = []\n        start_idx = 0\n        for _, apl in enumerate(num_anchors_per_level):\n            end_idx = start_idx + apl * num_anchors_per_loc\n\n            # topk: total number of candidates per position\n            topk = min(self.num_candidates * num_anchors_per_loc, apl)\n            # torch.topk() does not support float16 cpu, need conversion to float32 or float64\n            _, idx = distances[:, start_idx:end_idx].to(COMPUTE_DTYPE).topk(topk, dim=1, largest=False)\n            # idx: shape [num_boxes x topk]\n            candidate_idx_list.append(idx + start_idx)\n\n            start_idx = end_idx\n        # [num_boxes x num_candidates] (index of candidate anchors)\n        candidate_idx = torch.cat(candidate_idx_list, dim=1)\n\n        match_quality_matrix = self.similarity_fn(boxes, 
anchors)  # [num_boxes x anchors]\n        candidate_ious = match_quality_matrix.gather(1, candidate_idx)  # [num_boxes, n_candidates]\n\n        # corner case, n_candidates<=1 will make iou_std_per_gt NaN\n        if candidate_idx.shape[1] <= 1:\n            matches = -1 * torch.ones((num_anchors,), dtype=torch.long, device=boxes.device)\n            matches[candidate_idx] = 0\n            return match_quality_matrix, matches\n\n        # compute adaptive iou threshold\n        iou_mean_per_gt = candidate_ious.mean(dim=1)  # [num_boxes]\n        iou_std_per_gt = candidate_ious.std(dim=1)  # [num_boxes]\n        iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt  # [num_boxes]\n        is_pos = candidate_ious >= iou_thresh_per_gt[:, None]  # [num_boxes x n_candidates]\n        if self.debug:\n            print(f\"Anchor matcher threshold: {iou_thresh_per_gt}\")\n\n        if self.center_in_gt:  # can discard all candidates in case of very small objects :/\n            # center point of selected anchors needs to lie within the ground truth\n            boxes_idx = (\n                torch.arange(num_gt, device=boxes.device, dtype=torch.long)[:, None]\n                .expand_as(candidate_idx)\n                .contiguous()\n            )  # [num_boxes x n_candidates]\n            is_in_gt_ = centers_in_boxes(\n                anchors_center[candidate_idx.view(-1)], boxes[boxes_idx.view(-1)], eps=self.min_dist\n            )\n            is_in_gt = convert_to_tensor(is_in_gt_)\n            is_pos = is_pos & is_in_gt.view_as(is_pos)  # [num_boxes x n_candidates]\n\n        # in case one anchor is assigned to multiple boxes, use box with highest IoU\n        # TODO: think about a better way to do this\n        for ng in range(num_gt):\n            candidate_idx[ng, :] += ng * num_anchors\n        ious_inf = torch.full_like(match_quality_matrix, -INF).view(-1)\n        index = candidate_idx.view(-1)[is_pos.view(-1)]\n        ious_inf[index] = 
match_quality_matrix.view(-1)[index]\n        ious_inf = ious_inf.view_as(match_quality_matrix)\n\n        matched_vals, matches = ious_inf.to(COMPUTE_DTYPE).max(dim=0)\n        matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD\n        return match_quality_matrix, matches\n\n\nMatcherType = TypeVar(\"MatcherType\", bound=Matcher)\n"
  },
  {
    "path": "monai/apps/detection/utils/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/detection/utils/anchor_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/anchor_utils.py\n# which has the following license...\n# https://github.com/pytorch/vision/blob/main/LICENSE\n\n# BSD 3-Clause License\n\n# Copyright (c) Soumith Chintala 2016,\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# * Redistributions of source code must retain the above copyright notice, this\n#   list of conditions and the following disclaimer.\n\n# * Redistributions in binary form must reproduce the above copyright notice,\n#   this list of conditions and the following disclaimer in the documentation\n#   and/or other materials provided with the distribution.\n\n# * Neither the name of the copyright holder nor the names of its\n#   contributors may be used to endorse or promote products derived from\n#   this software without specific prior written permission.\n\"\"\"\nThis script is adapted from\nhttps://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/anchor_utils.py\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nfrom torch import Tensor, nn\n\nfrom monai.utils import 
ensure_tuple\nfrom monai.utils.misc import issequenceiterable\nfrom monai.utils.module import look_up_option\n\n\nclass AnchorGenerator(nn.Module):\n    \"\"\"\n    This module is modified from torchvision to support both 2D and 3D images.\n\n    Module that generates anchors for a set of feature maps and\n    image sizes.\n\n    The module supports computing anchors at multiple sizes and aspect ratios\n    per feature map.\n\n    sizes and aspect_ratios should have the same number of elements, which should\n    correspond to the number of feature maps.\n\n    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements.\n    For 2D images, anchor width and height w:h = 1:aspect_ratios[i,j]\n    For 3D images, anchor width, height, and depth w:h:d = 1:aspect_ratios[i,j,0]:aspect_ratios[i,j,1]\n\n    AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors\n    per spatial location for feature map i.\n\n    Args:\n        sizes: base size of each anchor.\n            len(sizes) is the number of feature maps, i.e., the number of output levels for\n            the feature pyramid network (FPN).\n            Each element of ``sizes`` is a Sequence which represents several anchor sizes for each feature map.\n        aspect_ratios: the aspect ratios of anchors. 
``len(aspect_ratios) = len(sizes)``.\n            For 2D images, each element of ``aspect_ratios[i]`` is a Sequence of float.\n            For 3D images, each element of ``aspect_ratios[i]`` is a Sequence of 2 value Sequence.\n        indexing: choose from {``'ij'``, ``'xy'``}, optional,\n            Matrix (``'ij'``, default and recommended) or Cartesian (``'xy'``) indexing of output.\n\n            - Matrix (``'ij'``, default and recommended) indexing keeps the original axes unchanged.\n            - To use other monai detection components, please set ``indexing = 'ij'``.\n            - Cartesian (``'xy'``) indexing swaps axis 0 and 1.\n            - For 2D cases, monai ``AnchorGenerator(sizes, aspect_ratios, indexing='xy')`` and\n              ``torchvision.models.detection.anchor_utils.AnchorGenerator(sizes, aspect_ratios)`` are equivalent.\n\n\n    Reference:\n        https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/anchor_utils.py\n\n    Example:\n        .. 
code-block:: python\n\n            # 2D example inputs for a 2-level feature maps\n            sizes = ((10,12,14,16), (20,24,28,32))\n            base_aspect_ratios = (1., 0.5,  2.)\n            aspect_ratios = (base_aspect_ratios, base_aspect_ratios)\n            anchor_generator = AnchorGenerator(sizes, aspect_ratios)\n\n            # 3D example inputs for a 2-level feature maps\n            sizes = ((10,12,14,16), (20,24,28,32))\n            base_aspect_ratios = ((1., 1.), (1., 0.5), (0.5, 1.), (2., 2.))\n            aspect_ratios = (base_aspect_ratios, base_aspect_ratios)\n            anchor_generator = AnchorGenerator(sizes, aspect_ratios)\n    \"\"\"\n\n    __annotations__ = {\"cell_anchors\": list[torch.Tensor]}\n\n    def __init__(\n        self,\n        sizes: Sequence[Sequence[int]] = ((20, 30, 40),),\n        aspect_ratios: Sequence = (((0.5, 1), (1, 0.5)),),\n        indexing: str = \"ij\",\n    ) -> None:\n        super().__init__()\n\n        if not issequenceiterable(sizes[0]):\n            self.sizes = tuple((s,) for s in sizes)\n        else:\n            self.sizes = ensure_tuple(sizes)\n        if not issequenceiterable(aspect_ratios[0]):\n            aspect_ratios = (aspect_ratios,) * len(self.sizes)\n\n        if len(self.sizes) != len(aspect_ratios):\n            raise ValueError(\n                \"len(sizes) and len(aspect_ratios) should be equal. 
\\\n                It represents the number of feature maps.\"\n            )\n\n        spatial_dims = len(ensure_tuple(aspect_ratios[0][0])) + 1\n        spatial_dims = look_up_option(spatial_dims, [2, 3])\n        self.spatial_dims = spatial_dims\n\n        self.indexing = look_up_option(indexing, [\"ij\", \"xy\"])\n\n        self.aspect_ratios = aspect_ratios\n        self.cell_anchors = [\n            self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(self.sizes, aspect_ratios)\n        ]\n\n    # This comment comes from torchvision.\n    # TODO: https://github.com/pytorch/pytorch/issues/26792\n    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.\n    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)\n    # This method assumes aspect ratio = height / width for an anchor.\n    def generate_anchors(\n        self,\n        scales: Sequence,\n        aspect_ratios: Sequence,\n        dtype: torch.dtype = torch.float32,\n        device: torch.device | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Compute cell anchor shapes at multiple sizes and aspect ratios for the current feature map.\n\n        Args:\n            scales: a sequence which represents several anchor sizes for the current feature map.\n            aspect_ratios: a sequence which represents several aspect_ratios for the current feature map.\n                For 2D images, it is a Sequence of float aspect_ratios[j],\n                anchor width and height w:h = 1:aspect_ratios[j].\n                For 3D images, it is a Sequence of 2 value Sequence aspect_ratios[j,0] and aspect_ratios[j,1],\n                anchor width, height, and depth w:h:d = 1:aspect_ratios[j,0]:aspect_ratios[j,1]\n            dtype: target data type of the output Tensor.\n            device: target device to put the output Tensor data.\n\n            Returns:\n                For each s in scales, 
returns [s, s*aspect_ratios[j]] for 2D images,\n                and [s, s*aspect_ratios[j,0],s*aspect_ratios[j,1]] for 3D images.\n        \"\"\"\n        scales_t = torch.as_tensor(scales, dtype=dtype, device=device)  # sized (N,)\n        aspect_ratios_t = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)  # sized (M,) or (M,2)\n        if (self.spatial_dims >= 3) and (len(aspect_ratios_t.shape) != 2):\n            raise ValueError(\n                f\"In {self.spatial_dims}-D image, aspect_ratios for each level should be \\\n                {len(aspect_ratios_t.shape) - 1}-D. But got aspect_ratios with shape {aspect_ratios_t.shape}.\"\n            )\n\n        if (self.spatial_dims >= 3) and (aspect_ratios_t.shape[1] != self.spatial_dims - 1):\n            raise ValueError(\n                f\"In {self.spatial_dims}-D image, aspect_ratios for each level should has \\\n                shape (_,{self.spatial_dims - 1}). But got aspect_ratios with shape {aspect_ratios_t.shape}.\"\n            )\n\n        # if 2d, w:h = 1:aspect_ratios\n        if self.spatial_dims == 2:\n            area_scale = torch.sqrt(aspect_ratios_t)\n            w_ratios = 1 / area_scale\n            h_ratios = area_scale\n        # if 3d, w:h:d = 1:aspect_ratios[:,0]:aspect_ratios[:,1]\n        else:\n            area_scale = torch.pow(aspect_ratios_t[:, 0] * aspect_ratios_t[:, 1], 1 / 3.0)\n            w_ratios = 1 / area_scale\n            h_ratios = aspect_ratios_t[:, 0] / area_scale\n            d_ratios = aspect_ratios_t[:, 1] / area_scale\n\n        ws = (w_ratios[:, None] * scales_t[None, :]).view(-1)\n        hs = (h_ratios[:, None] * scales_t[None, :]).view(-1)\n        if self.spatial_dims == 2:\n            base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2.0\n        else:  # elif self.spatial_dims == 3:\n            ds = (d_ratios[:, None] * scales_t[None, :]).view(-1)\n            base_anchors = torch.stack([-ws, -hs, -ds, ws, hs, ds], dim=1) / 2.0\n\n        
return base_anchors.round()\n\n    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device) -> None:\n        \"\"\"\n        Convert each element in self.cell_anchors to ``dtype`` and send to ``device``.\n        \"\"\"\n        self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]\n\n    def num_anchors_per_location(self):\n        \"\"\"\n        Return number of anchor shapes for each feature map.\n        \"\"\"\n        return [c.shape[0] for c in self.cell_anchors]\n\n    def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]]) -> list[Tensor]:\n        \"\"\"\n        Every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:spatial_dims)\n        corresponds to a feature map.\n        It outputs g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.\n\n        Args:\n            grid_sizes: spatial size of the feature maps\n            strides: strides of the feature maps regarding to the original image\n\n        Example:\n            .. code-block:: python\n\n                grid_sizes = [[100,100],[50,50]]\n                strides = [[torch.tensor(2),torch.tensor(2)], [torch.tensor(4),torch.tensor(4)]]\n        \"\"\"\n        anchors = []\n        cell_anchors = self.cell_anchors\n        if cell_anchors is None:\n            raise AssertionError\n\n        if not (len(grid_sizes) == len(strides) == len(cell_anchors)):\n            raise ValueError(\n                \"Anchors should be Tuple[Tuple[int]] because each feature \"\n                \"map could potentially have different sizes and aspect ratios. 
\"\n                \"There needs to be a match between the number of \"\n                \"feature maps passed and the number of sizes / aspect ratios specified.\"\n            )\n\n        for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):\n            # for each feature map\n            device = base_anchors.device\n\n            # compute anchor centers regarding to the image.\n            # shifts_centers is [x_center, y_center] or [x_center, y_center, z_center]\n            shifts_centers = [\n                torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis]\n                for axis in range(self.spatial_dims)\n            ]\n\n            # to support torchscript, cannot directly use torch.meshgrid(shifts_centers).\n            shifts_centers = list(torch.meshgrid(shifts_centers[: self.spatial_dims], indexing=\"ij\"))\n\n            for axis in range(self.spatial_dims):\n                # each element of shifts_centers is sized (HW,) or (HWD,)\n                shifts_centers[axis] = shifts_centers[axis].reshape(-1)\n\n            # Expand to [x_center, y_center, x_center, y_center],\n            # or [x_center, y_center, z_center, x_center, y_center, z_center]\n            if self.indexing == \"xy\":\n                # Cartesian ('xy') indexing swaps axis 0 and 1.\n                shifts_centers[1], shifts_centers[0] = shifts_centers[0], shifts_centers[1]\n            shifts = torch.stack(shifts_centers * 2, dim=1)  # sized (HW,4) or (HWD,6)\n\n            # For every (base anchor, output anchor) pair,\n            # offset each zero-centered base anchor by the center of the output anchor.\n            anchors.append(\n                (shifts.view(-1, 1, self.spatial_dims * 2) + base_anchors.view(1, -1, self.spatial_dims * 2)).reshape(\n                    -1, self.spatial_dims * 2\n                )  # each element sized (AHWD,4) or (AHWD,6)\n            )\n\n        return anchors\n\n    def forward(self, 
images: Tensor, feature_maps: list[Tensor]) -> list[Tensor]:\n        \"\"\"\n        Generate anchor boxes for each image.\n\n        Args:\n            images: sized (B, C, W, H) or (B, C, W, H, D)\n            feature_maps: for FPN level i, feature_maps[i] is sized (B, C_i, W_i, H_i) or (B, C_i, W_i, H_i, D_i).\n                This input argument does not have to be the actual feature maps.\n                Any list variable with the same (C_i, W_i, H_i) or (C_i, W_i, H_i, D_i) as feature maps works.\n\n        Return:\n            A list with length of B. Each element represents the anchors for this image.\n            The B elements are identical.\n\n        Example:\n            .. code-block:: python\n\n                images = torch.zeros((3,1,128,128,128))\n                feature_maps = [torch.zeros((3,6,64,64,32)), torch.zeros((3,6,32,32,16))]\n                anchor_generator(images, feature_maps)\n        \"\"\"\n        grid_sizes = [list(feature_map.shape[-self.spatial_dims :]) for feature_map in feature_maps]\n        image_size = images.shape[-self.spatial_dims :]\n        batchsize = images.shape[0]\n        dtype, device = feature_maps[0].dtype, feature_maps[0].device\n        strides = [\n            [\n                torch.tensor(image_size[axis] // g[axis], dtype=torch.int64, device=device)\n                for axis in range(self.spatial_dims)\n            ]\n            for g in grid_sizes\n        ]\n\n        self.set_cell_anchors(dtype, device)\n        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)\n\n        anchors_per_image = torch.cat(list(anchors_over_all_feature_maps))\n        return [anchors_per_image] * batchsize\n\n\nclass AnchorGeneratorWithAnchorShape(AnchorGenerator):\n    \"\"\"\n    Module that generates anchors for a set of feature maps and\n    image sizes, inherited from :py:class:`~monai.apps.detection.networks.utils.anchor_utils.AnchorGenerator`\n\n    The module support computing anchors at 
multiple base anchor shapes\n    per feature map.\n\n    ``feature_map_scales`` should have the same number of elements with the number of feature maps.\n\n    base_anchor_shapes can have an arbitrary number of elements.\n    For 2D images, each element represents anchor width and height [w,h].\n    For 2D images, each element represents anchor width, height, and depth [w,h,d].\n\n    AnchorGenerator will output a set of ``len(base_anchor_shapes)`` anchors\n    per spatial location for feature map ``i``.\n\n    Args:\n        feature_map_scales: scale of anchors for each feature map, i.e., each output level of\n            the feature pyramid network (FPN). ``len(feature_map_scales)`` is the number of feature maps.\n            ``scale[i]*base_anchor_shapes`` represents the anchor shapes for feature map ``i``.\n        base_anchor_shapes: a sequence which represents several anchor shapes for one feature map.\n            For N-D images, it is a Sequence of N value Sequence.\n        indexing: choose from {'xy', 'ij'}, optional\n            Cartesian ('xy') or matrix ('ij', default) indexing of output.\n            Cartesian ('xy') indexing swaps axis 0 and 1, which is the setting inside torchvision.\n            matrix ('ij', default) indexing keeps the original axis not changed.\n            See also indexing in https://pytorch.org/docs/stable/generated/torch.meshgrid.html\n\n    Example:\n        .. 
code-block:: python\n\n            # 2D example inputs for a 2-level feature maps\n            feature_map_scales = (1, 2)\n            base_anchor_shapes = ((10, 10), (6, 12), (12, 6))\n            anchor_generator = AnchorGeneratorWithAnchorShape(feature_map_scales, base_anchor_shapes)\n\n            # 3D example inputs for a 2-level feature maps\n            feature_map_scales = (1, 2)\n            base_anchor_shapes = ((10, 10, 10), (12, 12, 8), (10, 10, 6), (16, 16, 10))\n            anchor_generator = AnchorGeneratorWithAnchorShape(feature_map_scales, base_anchor_shapes)\n    \"\"\"\n\n    __annotations__ = {\"cell_anchors\": list[torch.Tensor]}\n\n    def __init__(\n        self,\n        feature_map_scales: Sequence[int] | Sequence[float] = (1, 2, 4, 8),\n        base_anchor_shapes: Sequence[Sequence[int]] | Sequence[Sequence[float]] = (\n            (32, 32, 32),\n            (48, 20, 20),\n            (20, 48, 20),\n            (20, 20, 48),\n        ),\n        indexing: str = \"ij\",\n    ) -> None:\n        nn.Module.__init__(self)\n\n        spatial_dims = len(base_anchor_shapes[0])\n        spatial_dims = look_up_option(spatial_dims, [2, 3])\n        self.spatial_dims = spatial_dims\n\n        self.indexing = look_up_option(indexing, [\"ij\", \"xy\"])\n\n        base_anchor_shapes_t = torch.Tensor(base_anchor_shapes)\n        self.cell_anchors = [self.generate_anchors_using_shape(s * base_anchor_shapes_t) for s in feature_map_scales]\n\n    @staticmethod\n    def generate_anchors_using_shape(\n        anchor_shapes: torch.Tensor, dtype: torch.dtype = torch.float32, device: torch.device | None = None\n    ) -> torch.Tensor:\n        \"\"\"\n        Compute cell anchor shapes at multiple sizes and aspect ratios for the current feature map.\n\n        Args:\n            anchor_shapes: [w, h] or [w, h, d], sized (N, spatial_dims),\n                represents N anchor shapes for the current feature map.\n            dtype: target data type of the output 
Tensor.\n            device: target device to put the output Tensor data.\n\n        Returns:\n            For 2D images, returns [-w/2, -h/2, w/2, h/2];\n            For 3D images, returns [-w/2, -h/2, -d/2, w/2, h/2, d/2]\n        \"\"\"\n        half_anchor_shapes = anchor_shapes / 2.0\n        base_anchors = torch.cat([-half_anchor_shapes, half_anchor_shapes], dim=1)\n        return base_anchors.round().to(dtype=dtype, device=device)\n"
  },
  {
    "path": "monai/apps/detection/utils/box_coder.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/detection/_utils.py\n# which has the following license...\n# https://github.com/pytorch/vision/blob/main/LICENSE\n#\n# BSD 3-Clause License\n\n# Copyright (c) Soumith Chintala 2016,\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# * Redistributions of source code must retain the above copyright notice, this\n#   list of conditions and the following disclaimer.\n\n# * Redistributions in binary form must reproduce the above copyright notice,\n#   this list of conditions and the following disclaimer in the documentation\n#   and/or other materials provided with the distribution.\n\n# * Neither the name of the copyright holder nor the names of its\n#   contributors may be used to endorse or promote products derived from\n#   this software without specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\"\"\"\nThis script is modified from torchvision to support N-D images,\n\nhttps://github.com/pytorch/vision/blob/main/torchvision/models/detection/_utils.py\n\"\"\"\n\nfrom __future__ import annotations\n\nimport math\nfrom collections.abc import Sequence\n\nimport torch\nfrom torch import Tensor\n\nfrom monai.data.box_utils import COMPUTE_DTYPE, CenterSizeMode, StandardMode, convert_box_mode, is_valid_box_values\nfrom monai.utils.module import look_up_option\n\n\ndef encode_boxes(gt_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:\n    \"\"\"\n    Encode a set of proposals with respect to some reference ground truth (gt) boxes.\n\n    Args:\n        gt_boxes: gt boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``\n        proposals: boxes to be encoded, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``\n        weights: the weights for ``(cx, cy, w, h) or (cx,cy,cz, w,h,d)``\n\n    Return:\n        encoded gt, target of box regression that is used to convert proposals into gt_boxes, Nx4 or Nx6 torch tensor.\n    \"\"\"\n\n    if gt_boxes.shape[0] != proposals.shape[0]:\n        raise ValueError(\"gt_boxes.shape[0] should be equal to proposals.shape[0].\")\n    spatial_dims = look_up_option(len(weights), [4, 6]) // 2\n\n    if not is_valid_box_values(gt_boxes):\n        raise ValueError(\"gt_boxes is not valid. 
Please check if it contains empty boxes.\")\n    if not is_valid_box_values(proposals):\n        raise ValueError(\"proposals is not valid. Please check if it contains empty boxes.\")\n\n    # implementation starts here\n    ex_cccwhd: Tensor = convert_box_mode(proposals, src_mode=StandardMode, dst_mode=CenterSizeMode)  # type: ignore\n    gt_cccwhd: Tensor = convert_box_mode(gt_boxes, src_mode=StandardMode, dst_mode=CenterSizeMode)  # type: ignore\n    targets_dxyz = (\n        weights[:spatial_dims].unsqueeze(0)\n        * (gt_cccwhd[:, :spatial_dims] - ex_cccwhd[:, :spatial_dims])\n        / ex_cccwhd[:, spatial_dims:]\n    )\n    targets_dwhd = weights[spatial_dims:].unsqueeze(0) * torch.log(\n        gt_cccwhd[:, spatial_dims:] / ex_cccwhd[:, spatial_dims:]\n    )\n\n    targets = torch.cat((targets_dxyz, targets_dwhd), dim=1)\n    # torch.log may cause NaN or Inf\n    if torch.isnan(targets).any() or torch.isinf(targets).any():\n        raise ValueError(\"targets is NaN or Inf.\")\n    return targets\n\n\nclass BoxCoder:\n    \"\"\"\n    This class encodes and decodes a set of bounding boxes into\n    the representation used for training the regressors.\n\n    Args:\n        weights: 4-element tuple or 6-element tuple\n        boxes_xform_clip: high threshold to prevent sending too large values into torch.exp()\n\n    Example:\n        .. 
code-block:: python\n\n            box_coder = BoxCoder(weights=[1., 1., 1., 1., 1., 1.])\n            gt_boxes = torch.tensor([[1,2,1,4,5,6],[1,3,2,7,8,9]])\n            proposals = gt_boxes + torch.rand(gt_boxes.shape)\n            rel_gt_boxes = box_coder.encode_single(gt_boxes, proposals)\n            gt_back = box_coder.decode_single(rel_gt_boxes, proposals)\n            # We expect gt_back to be equal to gt_boxes\n    \"\"\"\n\n    def __init__(self, weights: Sequence[float], boxes_xform_clip: float | None = None) -> None:\n        if boxes_xform_clip is None:\n            boxes_xform_clip = math.log(1000.0 / 16)\n        self.spatial_dims = look_up_option(len(weights), [4, 6]) // 2\n        self.weights = weights\n        self.boxes_xform_clip = boxes_xform_clip\n\n    def encode(self, gt_boxes: Sequence[Tensor], proposals: Sequence[Tensor]) -> tuple[Tensor]:\n        \"\"\"\n        Encode a set of proposals with respect to some ground truth (gt) boxes.\n\n        Args:\n            gt_boxes: list of gt boxes, Nx4 or Nx6 torch tensor. 
The box mode is assumed to be ``StandardMode``\n            proposals: list of boxes to be encoded, each element is Mx4 or Mx6 torch tensor.\n                The box mode is assumed to be ``StandardMode``\n\n        Return:\n            A tuple of encoded gt, target of box regression that is used to\n                convert proposals into gt_boxes, Nx4 or Nx6 torch tensor.\n        \"\"\"\n        boxes_per_image = [len(b) for b in gt_boxes]\n        # concat the lists to do computation\n        concat_gt_boxes = torch.cat(tuple(gt_boxes), dim=0)\n        concat_proposals = torch.cat(tuple(proposals), dim=0)\n        concat_targets = self.encode_single(concat_gt_boxes, concat_proposals)\n        # split to tuple\n        targets: tuple[Tensor] = concat_targets.split(boxes_per_image, 0)\n        return targets\n\n    def encode_single(self, gt_boxes: Tensor, proposals: Tensor) -> Tensor:\n        \"\"\"\n        Encode proposals with respect to ground truth (gt) boxes.\n\n        Args:\n            gt_boxes: gt boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``\n            proposals: boxes to be encoded, Nx4 or Nx6 torch tensor. 
The box mode is assumed to be ``StandardMode``\n\n        Return:\n            encoded gt, target of box regression that is used to convert proposals into gt_boxes, Nx4 or Nx6 torch tensor.\n        \"\"\"\n        dtype = gt_boxes.dtype\n        device = gt_boxes.device\n        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)\n        targets = encode_boxes(gt_boxes, proposals, weights)\n        return targets\n\n    def decode(self, rel_codes: Tensor, reference_boxes: Sequence[Tensor]) -> Tensor:\n        \"\"\"\n        From a set of original reference_boxes and encoded relative box offsets,\n\n        Args:\n            rel_codes: encoded boxes, Nx4 or Nx6 torch tensor.\n            reference_boxes: a list of reference boxes, each element is Mx4 or Mx6 torch tensor.\n                The box mode is assumed to be ``StandardMode``\n\n        Return:\n            decoded boxes, Nx1x4 or Nx1x6 torch tensor. The box mode will be ``StandardMode``\n        \"\"\"\n        if not isinstance(reference_boxes, Sequence) or (not isinstance(rel_codes, torch.Tensor)):\n            raise ValueError(\"Input arguments wrong type.\")\n        boxes_per_image = [b.size(0) for b in reference_boxes]\n        # concat the lists to do computation\n        concat_boxes = torch.cat(tuple(reference_boxes), dim=0)\n        box_sum = 0\n        for val in boxes_per_image:\n            box_sum += val\n        if box_sum > 0:\n            rel_codes = rel_codes.reshape(box_sum, -1)\n        pred_boxes = self.decode_single(rel_codes, concat_boxes)\n        if box_sum > 0:\n            pred_boxes = pred_boxes.reshape(box_sum, -1, 2 * self.spatial_dims)\n        return pred_boxes\n\n    def decode_single(self, rel_codes: Tensor, reference_boxes: Tensor) -> Tensor:\n        \"\"\"\n        From a set of original boxes and encoded relative box offsets,\n\n        Args:\n            rel_codes: encoded boxes, Nx(4*num_box_reg) or Nx(6*num_box_reg) torch tensor.\n            
reference_boxes: reference boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``\n\n        Return:\n            decoded boxes, Nx(4*num_box_reg) or Nx(6*num_box_reg) torch tensor. The box mode will to be ``StandardMode``\n        \"\"\"\n        reference_boxes = reference_boxes.to(rel_codes.dtype)\n        offset = reference_boxes.shape[-1]\n\n        pred_boxes = []\n        boxes_cccwhd: torch.Tensor = convert_box_mode(\n            reference_boxes, src_mode=StandardMode, dst_mode=CenterSizeMode\n        )  # type: ignore[assignment]\n\n        for axis in range(self.spatial_dims):\n            whd_axis = boxes_cccwhd[:, axis + self.spatial_dims]\n            ctr_xyz_axis = boxes_cccwhd[:, axis]\n            dxyz_axis = rel_codes[:, axis::offset] / self.weights[axis]\n            dwhd_axis = rel_codes[:, self.spatial_dims + axis :: offset] / self.weights[axis + self.spatial_dims]\n            # Prevent sending too large values into torch.exp()\n            dwhd_axis = torch.clamp(dwhd_axis.to(COMPUTE_DTYPE), max=self.boxes_xform_clip)\n\n            pred_ctr_xyx_axis = dxyz_axis * whd_axis[:, None] + ctr_xyz_axis[:, None]\n            pred_whd_axis = torch.exp(dwhd_axis) * whd_axis[:, None]\n            pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)  # type: ignore[union-attr]\n\n            # When convert float32 to float16, Inf or Nan may occur\n            if torch.isnan(pred_whd_axis).any() or torch.isinf(pred_whd_axis).any():\n                raise ValueError(\"pred_whd_axis is NaN or Inf.\")\n\n            # Distance from center to box's corner.\n            c_to_c_whd_axis = (\n                torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis  # type: ignore[arg-type]\n            )\n\n            pred_boxes.append(pred_ctr_xyx_axis - c_to_c_whd_axis)\n            pred_boxes.append(pred_ctr_xyx_axis + c_to_c_whd_axis)\n\n        pred_boxes = pred_boxes[::2] + pred_boxes[1::2]\n       
 pred_boxes_final = torch.stack(pred_boxes, dim=2).flatten(1)\n        return pred_boxes_final\n"
  },
  {
    "path": "monai/apps/detection/utils/box_selector.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py\n# which has the following license...\n# https://github.com/pytorch/vision/blob/main/LICENSE\n\n# BSD 3-Clause License\n\n# Copyright (c) Soumith Chintala 2016,\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# * Redistributions of source code must retain the above copyright notice, this\n#   list of conditions and the following disclaimer.\n\n# * Redistributions in binary form must reproduce the above copyright notice,\n#   this list of conditions and the following disclaimer in the documentation\n#   and/or other materials provided with the distribution.\n\n# * Neither the name of the copyright holder nor the names of its\n#   contributors may be used to endorse or promote products derived from\n#   this software without specific prior written permission.\n\"\"\"\nPart of this script is adapted from\nhttps://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nimport torch\nfrom torch import Tensor\n\nfrom monai.data.box_utils import batched_nms, box_iou, 
clip_boxes_to_image\nfrom monai.transforms.utils_pytorch_numpy_unification import floor_divide\n\n\nclass BoxSelector:\n    \"\"\"\n    Box selector which selects the predicted boxes.\n    The box selection is performed with the following steps:\n\n    #. For each level, discard boxes with scores less than self.score_thresh.\n    #. For each level, keep boxes with top self.topk_candidates_per_level scores.\n    #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.\n    #. For the whole image, keep boxes with top self.detections_per_img scores.\n\n    Args:\n        apply_sigmoid: whether to apply sigmoid to get scores from classification logits\n        score_thresh: no box with scores less than score_thresh will be kept\n        topk_candidates_per_level: max number of boxes to keep for each level\n        nms_thresh: box overlapping threshold for NMS\n        detections_per_img: max number of boxes to keep for each image\n\n    Example:\n\n        .. 
code-block:: python\n\n            input_param = {\n                \"apply_sigmoid\": True,\n                \"score_thresh\": 0.1,\n                \"topk_candidates_per_level\": 2,\n                \"nms_thresh\": 0.1,\n                \"detections_per_img\": 5,\n            }\n            box_selector = BoxSelector(**input_param)\n            boxes = [torch.randn([3,6]), torch.randn([7,6])]\n            logits = [torch.randn([3,3]), torch.randn([7,3])]\n            spatial_size = (8,8,8)\n            selected_boxes, selected_scores, selected_labels = box_selector.select_boxes_per_image(\n                boxes, logits, spatial_size\n            )\n    \"\"\"\n\n    def __init__(\n        self,\n        box_overlap_metric: Callable = box_iou,\n        apply_sigmoid: bool = True,\n        score_thresh: float = 0.05,\n        topk_candidates_per_level: int = 1000,\n        nms_thresh: float = 0.5,\n        detections_per_img: int = 300,\n    ):\n        self.box_overlap_metric = box_overlap_metric\n\n        self.apply_sigmoid = apply_sigmoid\n        self.score_thresh = score_thresh\n        self.topk_candidates_per_level = topk_candidates_per_level\n        self.nms_thresh = nms_thresh\n        self.detections_per_img = detections_per_img\n\n    def select_top_score_idx_per_level(self, logits: Tensor) -> tuple[Tensor, Tensor, Tensor]:\n        \"\"\"\n        Select indices with highest scores.\n\n        The indices selection is performed with the following steps:\n\n        #. If self.apply_sigmoid, get scores by applying sigmoid to logits. Otherwise, use logits as scores.\n        #. Discard indices with scores less than self.score_thresh\n        #. 
Keep indices with top self.topk_candidates_per_level scores\n\n        Args:\n            logits: predicted classification logits, Tensor sized (N, num_classes)\n\n        Return:\n            - topk_idxs: selected M indices, Tensor sized (M, )\n            - selected_scores: selected M scores, Tensor sized (M, )\n            - selected_labels: selected M labels, Tensor sized (M, )\n        \"\"\"\n        num_classes = logits.shape[-1]\n\n        # apply sigmoid to classification logits if asked\n        if self.apply_sigmoid:\n            scores = torch.sigmoid(logits.to(torch.float32)).flatten()\n        else:\n            scores = logits.flatten()\n\n        # remove low scoring boxes\n        keep_idxs = scores > self.score_thresh\n        scores = scores[keep_idxs]\n        flatten_topk_idxs = torch.where(keep_idxs)[0]\n\n        # keep only topk scoring predictions\n        num_topk = min(self.topk_candidates_per_level, flatten_topk_idxs.size(0))\n        selected_scores, idxs = scores.to(torch.float32).topk(\n            num_topk\n        )  # half precision not implemented for cpu float16\n        flatten_topk_idxs = flatten_topk_idxs[idxs]\n\n        selected_labels = flatten_topk_idxs % num_classes\n\n        topk_idxs = floor_divide(flatten_topk_idxs, num_classes)\n        return topk_idxs, selected_scores, selected_labels  # type: ignore\n\n    def select_boxes_per_image(\n        self, boxes_list: list[Tensor], logits_list: list[Tensor], spatial_size: list[int] | tuple[int]\n    ) -> tuple[Tensor, Tensor, Tensor]:\n        \"\"\"\n        Postprocessing to generate detection result from classification logits and boxes.\n\n        The box selection is performed with the following steps:\n\n        #. For each level, discard boxes with scores less than self.score_thresh.\n        #. For each level, keep boxes with top self.topk_candidates_per_level scores.\n        #. 
For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.\n        #. For the whole image, keep boxes with top self.detections_per_img scores.\n\n        Args:\n            boxes_list: list of predicted boxes from a single image,\n                each element i is a Tensor sized (N_i, 2*spatial_dims)\n            logits_list: list of predicted classification logits from a single image,\n                each element i is a Tensor sized (N_i, num_classes)\n            spatial_size: spatial size of the image\n\n        Return:\n            - selected boxes, Tensor sized (P, 2*spatial_dims)\n            - selected_scores, Tensor sized (P, )\n            - selected_labels, Tensor sized (P, )\n        \"\"\"\n\n        if len(boxes_list) != len(logits_list):\n            raise ValueError(\n                \"len(boxes_list) should equal to len(logits_list). \"\n                f\"Got len(boxes_list)={len(boxes_list)}, len(logits_list)={len(logits_list)}\"\n            )\n\n        image_boxes = []\n        image_scores = []\n        image_labels = []\n\n        boxes_dtype = boxes_list[0].dtype\n        logits_dtype = logits_list[0].dtype\n\n        for boxes_per_level, logits_per_level in zip(boxes_list, logits_list):\n            # select topk boxes for each level\n            topk_idxs: Tensor\n            topk_idxs, scores_per_level, labels_per_level = self.select_top_score_idx_per_level(logits_per_level)\n            boxes_per_level = boxes_per_level[topk_idxs]\n\n            keep: Tensor\n            boxes_per_level, keep = clip_boxes_to_image(  # type: ignore\n                boxes_per_level, spatial_size, remove_empty=True\n            )\n            image_boxes.append(boxes_per_level)\n            image_scores.append(scores_per_level[keep])\n            image_labels.append(labels_per_level[keep])\n\n        image_boxes_t: Tensor = torch.cat(image_boxes, dim=0)\n        image_scores_t: Tensor = 
torch.cat(image_scores, dim=0)\n        image_labels_t: Tensor = torch.cat(image_labels, dim=0)\n\n        # non-maximum suppression on detected boxes from all levels\n        keep_t: Tensor = batched_nms(  # type: ignore\n            image_boxes_t,\n            image_scores_t,\n            image_labels_t,\n            self.nms_thresh,\n            box_overlap_metric=self.box_overlap_metric,\n            max_proposals=self.detections_per_img,\n        )\n\n        selected_boxes = image_boxes_t[keep_t].to(boxes_dtype)\n        selected_scores = image_scores_t[keep_t].to(logits_dtype)\n        selected_labels = image_labels_t[keep_t]\n\n        return selected_boxes, selected_scores, selected_labels\n"
  },
  {
    "path": "monai/apps/detection/utils/detector_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom monai.data.box_utils import standardize_empty_box\nfrom monai.transforms.croppad.array import SpatialPad\nfrom monai.transforms.utils import compute_divisible_spatial_size, convert_pad_mode\nfrom monai.utils import PytorchPadMode, ensure_tuple_rep\n\n\ndef check_input_images(input_images: list[Tensor] | Tensor, spatial_dims: int) -> None:\n    \"\"\"\n    Validate the input dimensionality (raise a `ValueError` if invalid).\n\n    Args:\n        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),\n            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).\n        spatial_dims: number of spatial dimensions of the images, 2 or 3.\n    \"\"\"\n    if isinstance(input_images, Tensor):\n        if len(input_images.shape) != spatial_dims + 2:\n            raise ValueError(\n                \"When input_images is a Tensor, its need to be (spatial_dims + 2)-D.\"\n                f\"In this case, it should be a {(spatial_dims + 2)}-D Tensor, got Tensor shape {input_images.shape}.\"\n            )\n    elif isinstance(input_images, list):\n        for img in input_images:\n            if len(img.shape) != 
spatial_dims + 1:\n                raise ValueError(\n                    \"When input_images is a List[Tensor], each element should have be (spatial_dims + 1)-D.\"\n                    f\"In this case, it should be a {(spatial_dims + 1)}-D Tensor, got Tensor shape {img.shape}.\"\n                )\n    else:\n        raise ValueError(\"input_images needs to be a List[Tensor] or Tensor.\")\n    return\n\n\ndef check_training_targets(\n    input_images: list[Tensor] | Tensor,\n    targets: list[dict[str, Tensor]] | None,\n    spatial_dims: int,\n    target_label_key: str,\n    target_box_key: str,\n) -> list[dict[str, Tensor]]:\n    \"\"\"\n    Validate the input images/targets during training (raise a `ValueError` if invalid).\n\n    Args:\n        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),\n            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).\n        targets: a list of dict. Each dict with two keys: target_box_key and target_label_key,\n            ground-truth boxes present in the image.\n        spatial_dims: number of spatial dimensions of the images, 2 or 3.\n        target_label_key: the expected key of target labels.\n        target_box_key: the expected key of target boxes.\n    \"\"\"\n    if targets is None:\n        raise ValueError(\"Please provide ground truth targets during training.\")\n\n    if len(input_images) != len(targets):\n        raise ValueError(f\"len(input_images) should equal to len(targets), got {len(input_images)}, {len(targets)}.\")\n\n    for i in range(len(targets)):\n        target = targets[i]\n        if (target_label_key not in target.keys()) or (target_box_key not in target.keys()):\n            raise ValueError(\n                f\"{target_label_key} and {target_box_key} are expected keys in targets. 
Got {target.keys()}.\"\n            )\n\n        boxes = target[target_box_key]\n        if not isinstance(boxes, torch.Tensor):\n            raise ValueError(f\"Expected target boxes to be of type Tensor, got {type(boxes)}.\")\n        if len(boxes.shape) != 2 or boxes.shape[-1] != 2 * spatial_dims:\n            if boxes.numel() == 0:\n                warnings.warn(\n                    f\"Warning: Given target boxes has shape of {boxes.shape}. \"\n                    f\"The detector reshaped it with boxes = torch.reshape(boxes, [0, {2 * spatial_dims}]).\"\n                )\n            else:\n                raise ValueError(\n                    f\"Expected target boxes to be a tensor of shape [N, {2 * spatial_dims}], got {boxes.shape}.).\"\n                )\n        if not torch.is_floating_point(boxes):\n            raise ValueError(f\"Expected target boxes to be a float tensor, got {boxes.dtype}.\")\n        targets[i][target_box_key] = standardize_empty_box(boxes, spatial_dims=spatial_dims)  # type: ignore\n\n        labels = target[target_label_key]\n        if torch.is_floating_point(labels):\n            warnings.warn(f\"Warning: Given target labels is {labels.dtype}. 
The detector converted it to torch.long.\")\n            targets[i][target_label_key] = labels.long()\n    return targets\n\n\ndef pad_images(\n    input_images: list[Tensor] | Tensor,\n    spatial_dims: int,\n    size_divisible: int | Sequence[int],\n    mode: PytorchPadMode | str = PytorchPadMode.CONSTANT,\n    **kwargs: Any,\n) -> tuple[Tensor, list[list[int]]]:\n    \"\"\"\n    Pad the input images, so that the output spatial sizes are divisible by `size_divisible`.\n    It pads them at the end to create a (B, C, H, W) or (B, C, H, W, D) Tensor.\n    Padded size (H, W) or (H, W, D) is divisible by size_divisible.\n    Default padding uses constant padding with value 0.0\n\n    Args:\n        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),\n            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).\n        spatial_dims: number of spatial dimensions of the images, 2D or 3D.\n        size_divisible: int or Sequence[int], is the expected pattern on the input image shape.\n            If an int, the same `size_divisible` will be applied to all the input spatial dimensions.\n        mode: available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. 
Defaults to ``\"constant\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        kwargs: other arguments for `torch.pad` function.\n\n    Return:\n        - images, a (B, C, H, W) or (B, C, H, W, D) Tensor\n        - image_sizes, the original spatial size of each image\n    \"\"\"\n    size_divisible = ensure_tuple_rep(size_divisible, spatial_dims)\n\n    # If input_images: Tensor\n    if isinstance(input_images, Tensor):\n        orig_size = list(input_images.shape[-spatial_dims:])\n        new_size = compute_divisible_spatial_size(spatial_shape=orig_size, k=size_divisible)\n        all_pad_width = [(0, max(sp_i - orig_size[i], 0)) for i, sp_i in enumerate(new_size)]\n        pt_pad_width = [val for sublist in all_pad_width for val in sublist[::-1]][::-1]\n        if max(pt_pad_width) == 0:\n            # if there is no need to pad\n            return input_images, [orig_size] * input_images.shape[0]\n        mode_: str = convert_pad_mode(dst=input_images, mode=mode)\n        return F.pad(input_images, pt_pad_width, mode=mode_, **kwargs), [orig_size] * input_images.shape[0]\n\n    # If input_images: List[Tensor])\n    image_sizes = [img.shape[-spatial_dims:] for img in input_images]\n    in_channels = input_images[0].shape[0]\n    dtype = input_images[0].dtype\n    device = input_images[0].device\n\n    # compute max_spatial_size\n    image_sizes_t = torch.tensor(image_sizes)\n    max_spatial_size_t, _ = torch.max(image_sizes_t, dim=0)\n\n    if len(max_spatial_size_t) != spatial_dims or len(size_divisible) != spatial_dims:\n        raise ValueError(\" Require len(max_spatial_size_t) == spatial_dims ==len(size_divisible).\")\n\n    max_spatial_size = compute_divisible_spatial_size(spatial_shape=list(max_spatial_size_t), k=size_divisible)\n\n    # allocate memory for the padded images\n    images = torch.zeros([len(image_sizes), in_channels] + list(max_spatial_size), dtype=dtype, device=device)\n\n    # Use 
`SpatialPad` to match sizes, padding in the end will not affect boxes\n    padder = SpatialPad(spatial_size=max_spatial_size, method=\"end\", mode=mode, **kwargs)\n    for idx, img in enumerate(input_images):\n        images[idx, ...] = padder(img)\n\n    return images, [list(ss) for ss in image_sizes]\n\n\ndef preprocess_images(\n    input_images: list[Tensor] | Tensor,\n    spatial_dims: int,\n    size_divisible: int | Sequence[int],\n    mode: PytorchPadMode | str = PytorchPadMode.CONSTANT,\n    **kwargs: Any,\n) -> tuple[Tensor, list[list[int]]]:\n    \"\"\"\n    Preprocess the input images, including\n\n    - validate of the inputs\n    - pad the inputs so that the output spatial sizes are divisible by `size_divisible`.\n      It pads them at the end to create a (B, C, H, W) or (B, C, H, W, D) Tensor.\n      Padded size (H, W) or (H, W, D) is divisible by size_divisible.\n      Default padding uses constant padding with value 0.0\n\n    Args:\n        input_images: It can be 1) a tensor sized (B, C, H, W) or  (B, C, H, W, D),\n            or 2) a list of image tensors, each image i may have different size (C, H_i, W_i) or  (C, H_i, W_i, D_i).\n        spatial_dims: number of spatial dimensions of the images, 2 or 3.\n        size_divisible: int or Sequence[int], is the expected pattern on the input image shape.\n            If an int, the same `size_divisible` will be applied to all the input spatial dimensions.\n        mode: available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. 
Defaults to ``\"constant\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        kwargs: other arguments for the `torch.nn.functional.pad` function.\n\n    Return:\n        - images, a (B, C, H, W) or (B, C, H, W, D) Tensor\n        - image_sizes, the original spatial size of each image\n    \"\"\"\n    check_input_images(input_images, spatial_dims)\n    size_divisible = ensure_tuple_rep(size_divisible, spatial_dims)\n\n    return pad_images(input_images, spatial_dims, size_divisible, mode, **kwargs)\n"
  },
  {
    "path": "monai/apps/detection/utils/hard_negative_sampler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/sampler.py\n# which has the following license...\n# https://github.com/MIC-DKFZ/nnDetection/blob/main/LICENSE\n#\n# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#    http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe functions in this script are adapted from nnDetection,\nhttps://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/sampler.py\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\n\nimport torch\nfrom torch import Tensor\n\n\nclass HardNegativeSamplerBase:\n    \"\"\"\n    Base class of hard negative sampler.\n\n    Hard negative sampler is used to suppress false positive rate in classification tasks.\n    During 
training, it selects negative samples with high prediction scores.\n\n    The training workflow is described as follows:\n    1) forward network and get prediction scores (classification prob/logits) for all the samples;\n    2) use hard negative sampler to choose negative samples with high prediction scores and some positive samples;\n    3) compute classification loss for the selected samples;\n    4) do back propagation.\n\n    Args:\n        pool_size: when we need ``num_neg`` hard negative samples, they will be randomly selected from\n            ``num_neg * pool_size`` negative samples with the highest prediction scores.\n            Larger ``pool_size`` gives more randomness, yet selects negative samples that are less 'hard',\n            i.e., negative samples with lower prediction scores.\n    \"\"\"\n\n    def __init__(self, pool_size: float = 10) -> None:\n        self.pool_size = pool_size\n\n    def select_negatives(self, negative: Tensor, num_neg: int, fg_probs: Tensor) -> Tensor:\n        \"\"\"\n        Select hard negative samples.\n\n        Args:\n            negative: indices of all the negative samples, sized (P,),\n                where P is the number of negative samples\n            num_neg: number of negative samples to sample\n            fg_probs: maximum foreground prediction scores (probability) across all the classes\n                for each sample, sized (A,), where A is the number of samples.\n\n        Returns:\n            binary mask of negative samples to choose, sized (A,),\n                where A is the number of samples in one image\n        \"\"\"\n        if negative.numel() > fg_probs.numel():\n            raise ValueError(\"The number of negative samples should not be larger than the number of all samples.\")\n\n        # sample pool size is ``num_neg * self.pool_size``\n        pool = int(num_neg * self.pool_size)\n        pool = min(negative.numel(), pool)  # protect against not enough negatives\n\n        # create 
a sample pool of highest scoring negative samples\n        _, negative_idx_pool = fg_probs[negative].to(torch.float32).topk(pool, dim=0, sorted=True)\n        hard_negative = negative[negative_idx_pool]\n\n        # select negatives from pool\n        perm2 = torch.randperm(hard_negative.numel(), device=hard_negative.device)[:num_neg]\n        selected_neg_idx = hard_negative[perm2]\n\n        # output a binary mask with same size of fg_probs that indicates selected negative samples.\n        neg_mask = torch.zeros_like(fg_probs, dtype=torch.uint8)\n        neg_mask[selected_neg_idx] = 1\n        return neg_mask\n\n\nclass HardNegativeSampler(HardNegativeSamplerBase):\n    \"\"\"\n    HardNegativeSampler is used to suppress false positive rate in classification tasks.\n    During training, it selects negative samples with high prediction scores.\n\n    The training workflow is described as the follows:\n    1) forward network and get prediction scores (classification prob/logits) for all the samples;\n    2) use hard negative sampler to choose negative samples with high prediction scores and some positive samples;\n    3) compute classification loss for the selected samples;\n    4) do back propagation.\n\n    Args:\n        batch_size_per_image: number of training samples to be randomly selected per image\n        positive_fraction: percentage of positive elements in the selected samples\n        min_neg: minimum number of negative samples to select if possible.\n        pool_size: when we need ``num_neg`` hard negative samples, they will be randomly selected from\n            ``num_neg * pool_size`` negative samples with the highest prediction scores.\n            Larger ``pool_size`` gives more randomness, yet selects negative samples that are less 'hard',\n            i.e., negative samples with lower prediction scores.\n    \"\"\"\n\n    def __init__(\n        self, batch_size_per_image: int, positive_fraction: float, min_neg: int = 1, pool_size: float = 10\n  
  ) -> None:\n        super().__init__(pool_size=pool_size)\n        self.min_neg = min_neg\n        self.batch_size_per_image = batch_size_per_image\n        self.positive_fraction = positive_fraction\n        logging.info(\"Sampling hard negatives on a per batch basis\")\n\n    def __call__(self, target_labels: list[Tensor], concat_fg_probs: Tensor) -> tuple[list[Tensor], list[Tensor]]:\n        \"\"\"\n        Select positives and hard negatives from list samples per image.\n        Hard negative sampler will be applied to each image independently.\n\n        Args:\n            target_labels: list of labels per image.\n                For image i in the batch, target_labels[i] is a Tensor sized (A_i,),\n                where A_i is the number of samples in image i.\n                Positive samples have positive labels, negative samples have label 0.\n            concat_fg_probs: concatenated maximum foreground probability for all the images, sized (R,),\n                where R is the sum of all samples inside one batch, i.e., R = A_0 + A_1 + ...\n\n        Returns:\n            - list of binary mask for positive samples\n            - list of binary mask for negative samples\n\n        Example:\n            .. 
code-block:: python\n\n                sampler = HardNegativeSampler(\n                    batch_size_per_image=6, positive_fraction=0.5, min_neg=1, pool_size=2\n                )\n                # two images with different number of samples\n                target_labels = [ torch.tensor([0,1]), torch.tensor([1,0,2,1])]\n                concat_fg_probs = torch.rand(6)\n                pos_idx_list, neg_idx_list = sampler(target_labels, concat_fg_probs)\n        \"\"\"\n        samples_per_image = [samples_in_image.shape[0] for samples_in_image in target_labels]\n        fg_probs = concat_fg_probs.split(samples_per_image, 0)\n        return self.select_samples_img_list(target_labels, fg_probs)\n\n    def select_samples_img_list(\n        self, target_labels: list[Tensor], fg_probs: list[Tensor]\n    ) -> tuple[list[Tensor], list[Tensor]]:\n        \"\"\"\n        Select positives and hard negatives from list samples per image.\n        Hard negative sampler will be applied to each image independently.\n\n        Args:\n            target_labels: list of labels per image.\n                For image i in the batch, target_labels[i] is a Tensor sized (A_i,),\n                where A_i is the number of samples in image i.\n                Positive samples have positive labels, negative samples have label 0.\n            fg_probs: list of maximum foreground probability per image.\n                For image i in the batch, fg_probs[i] is a Tensor sized (A_i,),\n                where A_i is the number of samples in image i.\n\n        Returns:\n            - list of binary mask for positive samples\n            - list of binary mask for negative samples\n\n        Example:\n            .. 
code-block:: python\n\n                sampler = HardNegativeSampler(\n                    batch_size_per_image=6, positive_fraction=0.5, min_neg=1, pool_size=2\n                )\n                # two images with different number of samples\n                target_labels = [ torch.tensor([0,1]), torch.tensor([1,0,2,1])]\n                fg_probs = [ torch.rand(2), torch.rand(4)]\n                pos_idx_list, neg_idx_list = sampler.select_samples_img_list(target_labels, fg_probs)\n        \"\"\"\n        pos_idx = []\n        neg_idx = []\n\n        if len(target_labels) != len(fg_probs):\n            raise ValueError(\n                \"Require len(target_labels) == len(fg_probs). \"\n                f\"Got len(target_labels)={len(target_labels)}, len(fg_probs)={len(fg_probs)}.\"\n            )\n        for labels_per_img, fg_probs_per_img in zip(target_labels, fg_probs):\n            pos_idx_per_image_mask, neg_idx_per_image_mask = self.select_samples_per_img(\n                labels_per_img, fg_probs_per_img\n            )\n            pos_idx.append(pos_idx_per_image_mask)\n            neg_idx.append(neg_idx_per_image_mask)\n\n        return pos_idx, neg_idx\n\n    def select_samples_per_img(self, labels_per_img: Tensor, fg_probs_per_img: Tensor) -> tuple[Tensor, Tensor]:\n        \"\"\"\n        Select positives and hard negatives from samples.\n\n        Args:\n            labels_per_img: labels, sized (A,).\n                Positive samples have positive labels, negative samples have label 0.\n            fg_probs_per_img: maximum foreground probability, sized (A,)\n\n        Returns:\n            - binary mask for positive samples, sized (A,)\n            - binary mask for negative samples, sized (A,)\n\n        Example:\n            .. 
code-block:: python\n\n                sampler = HardNegativeSampler(\n                    batch_size_per_image=6, positive_fraction=0.5, min_neg=1, pool_size=2\n                )\n                # one image with four samples\n                target_labels = torch.tensor([1,0,2,1])\n                fg_probs = torch.rand(4)\n                pos_idx, neg_idx = sampler.select_samples_per_img(target_labels, fg_probs)\n        \"\"\"\n        # for each image, find positive sample indices and negative sample indices\n        if labels_per_img.numel() != fg_probs_per_img.numel():\n            raise ValueError(\"labels_per_img and fg_probs_per_img should have same number of elements.\")\n\n        positive = torch.where(labels_per_img >= 1)[0]\n        negative = torch.where(labels_per_img == 0)[0]\n\n        num_pos = self.get_num_pos(positive)\n        pos_idx_per_image_mask = self.select_positives(positive, num_pos, labels_per_img)\n\n        num_neg = self.get_num_neg(negative, num_pos)\n        neg_idx_per_image_mask = self.select_negatives(negative, num_neg, fg_probs_per_img)\n\n        return pos_idx_per_image_mask, neg_idx_per_image_mask\n\n    def get_num_pos(self, positive: torch.Tensor) -> int:\n        \"\"\"\n        Number of positive samples to draw\n\n        Args:\n            positive: indices of positive samples\n\n        Returns:\n            number of positive samples\n        \"\"\"\n        # positive sample sampling\n        num_pos = int(self.batch_size_per_image * self.positive_fraction)\n        # protect against not enough positive examples\n        num_pos = min(positive.numel(), num_pos)\n        return num_pos\n\n    def get_num_neg(self, negative: torch.Tensor, num_pos: int) -> int:\n        \"\"\"\n        Sample enough negatives to fill up ``self.batch_size_per_image``\n\n        Args:\n            negative: indices of negative samples\n            num_pos: number of positive samples to draw\n\n        Returns:\n          
  number of negative samples\n        \"\"\"\n        # always assume at least one pos sample was sampled\n        num_neg = int(max(1, num_pos) * abs(1 - 1.0 / float(self.positive_fraction)))\n        # protect against not enough negative examples and sample at least self.min_neg if possible\n        num_neg = min(negative.numel(), max(num_neg, self.min_neg))\n        return num_neg\n\n    def select_positives(self, positive: Tensor, num_pos: int, labels: Tensor) -> Tensor:\n        \"\"\"\n        Select positive samples\n\n        Args:\n            positive: indices of positive samples, sized (P,),\n                where P is the number of positive samples\n            num_pos: number of positive samples to sample\n            labels: labels for all samples, sized (A,),\n                where A is the number of samples.\n\n        Returns:\n            binary mask of positive samples to choose, sized (A,),\n                where A is the number of samples in one image\n        \"\"\"\n        if positive.numel() > labels.numel():\n            raise ValueError(\"The number of positive samples should not be larger than the number of all samples.\")\n\n        perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]\n        pos_idx_per_image = positive[perm1]\n\n        # output a binary mask with same size of labels that indicates selected positive samples.\n        pos_idx_per_image_mask = torch.zeros_like(labels, dtype=torch.uint8)\n        pos_idx_per_image_mask[pos_idx_per_image] = 1\n        return pos_idx_per_image_mask\n"
  },
  {
    "path": "monai/apps/detection/utils/predict_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch import Tensor, nn\n\nfrom monai.inferers import SlidingWindowInferer\n\n\ndef ensure_dict_value_to_list_(head_outputs: dict[str, list[Tensor]], keys: list[str] | None = None) -> None:\n    \"\"\"\n    An in-place function. We expect ``head_outputs`` to be Dict[str, List[Tensor]].\n    Yet if it is Dict[str, Tensor], this func converts it to Dict[str, List[Tensor]].\n    It will be modified in-place.\n\n    Args:\n        head_outputs: a Dict[str, List[Tensor]] or Dict[str, Tensor], will be modifier in-place\n        keys: the keys in head_output that need to have value type List[Tensor]. 
If not provided, will use head_outputs.keys().\n    \"\"\"\n    if keys is None:\n        keys = list(head_outputs.keys())\n\n    for k in keys:\n        value_k = head_outputs[k]  # Tensor or List[Tensor]\n        # convert value_k to List[Tensor]\n        if isinstance(value_k, Tensor):\n            head_outputs[k] = [value_k]\n        elif isinstance(value_k[0], Tensor):\n            head_outputs[k] = list(value_k)\n        else:\n            raise ValueError(\"The output of network should be Dict[str, List[Tensor]] or Dict[str, Tensor].\")\n\n\ndef check_dict_values_same_length(head_outputs: dict[str, list[Tensor]], keys: list[str] | None = None) -> None:\n    \"\"\"\n    We expect the values in ``head_outputs``: Dict[str, List[Tensor]] to have the same length.\n    Will raise ValueError if not.\n\n    Args:\n        head_outputs: a Dict[str, List[Tensor]] or Dict[str, Tensor]\n        keys: the keys in head_output that need to have values (List) with same length.\n            If not provided, will use head_outputs.keys().\n    \"\"\"\n    if keys is None:\n        keys = list(head_outputs.keys())\n\n    num_output_levels_list: list[int] = [len(head_outputs[k]) for k in keys]\n    num_output_levels = torch.unique(torch.tensor(num_output_levels_list))\n    if len(num_output_levels) != 1:\n        raise ValueError(f\"The values in the input dict should have the same length, Got {num_output_levels_list}.\")\n\n\ndef _network_sequence_output(images: Tensor, network: nn.Module, keys: list[str] | None = None) -> list[Tensor]:\n    \"\"\"\n    Decompose the output of network (a dict) into a list.\n\n    Args:\n        images: input of the network\n        keys: the keys in the network output whose values will be output in this func.\n            If not provided, will use all keys.\n\n    Return:\n        network output values concat to a single List[Tensor]\n    \"\"\"\n    head_outputs = network(images)\n\n    # if head_outputs is already a sequence of tensors, 
directly output it\n    if isinstance(head_outputs, (tuple, list)):\n        return list(head_outputs)\n\n    # if head_outputs is a dict\n    ensure_dict_value_to_list_(head_outputs, keys)\n    if keys is None:\n        keys = list(head_outputs.keys())\n    check_dict_values_same_length(head_outputs, keys)\n    head_outputs_sequence = []\n    for k in keys:\n        head_outputs_sequence += list(head_outputs[k])\n    return head_outputs_sequence\n\n\ndef predict_with_inferer(\n    images: Tensor, network: nn.Module, keys: list[str], inferer: SlidingWindowInferer | None = None\n) -> dict[str, list[Tensor]]:\n    \"\"\"\n    Predict network dict output with an inferer. Compared with directly output network(images),\n    it enables a sliding window inferer that can be used to handle large inputs.\n\n    Args:\n        images: input of the network, Tensor sized (B, C, H, W) or  (B, C, H, W, D)\n        network: a network that takes an image Tensor sized (B, C, H, W) or (B, C, H, W, D) as input\n            and outputs a dictionary Dict[str, List[Tensor]] or Dict[str, Tensor].\n        keys: the keys in the output dict, should be network output keys or a subset of them.\n        inferer: a SlidingWindowInferer to handle large inputs.\n\n    Return:\n        The predicted head_output from network, a Dict[str, List[Tensor]]\n\n    Example:\n        .. 
code-block:: python\n\n            # define a naive network\n            import torch\n            import monai\n            class NaiveNet(torch.nn.Module):\n                def __init__(self, ):\n                    super().__init__()\n\n                def forward(self, images: torch.Tensor):\n                    return {\"cls\": torch.randn(images.shape), \"box_reg\": [torch.randn(images.shape)]}\n\n            # create a predictor\n            network = NaiveNet()\n            inferer = monai.inferers.SlidingWindowInferer(\n                roi_size = (128, 128, 128),\n                overlap = 0.25,\n                cache_roi_weight_map = True,\n            )\n            network_output_keys=[\"cls\", \"box_reg\"]\n            images = torch.randn((2, 3, 512, 512, 512))  # a large input\n            head_outputs = predict_with_inferer(images, network, network_output_keys, inferer)\n\n    \"\"\"\n    if inferer is None:\n        raise ValueError(\"Please set inferer as a monai.inferers.inferer.SlidingWindowInferer(*)\")\n    head_outputs_sequence = inferer(images, _network_sequence_output, network, keys=keys)\n    num_output_levels: int = len(head_outputs_sequence) // len(keys)\n    head_outputs = {}\n    for i, k in enumerate(keys):\n        head_outputs[k] = list(head_outputs_sequence[num_output_levels * i : num_output_levels * (i + 1)])\n    return head_outputs\n"
  },
  {
    "path": "monai/apps/generation/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/generation/maisi/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/generation/maisi/networks/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/generation/maisi/networks/autoencoderkl_maisi.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport gc\nimport logging\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.blocks import Convolution\nfrom monai.networks.blocks.spatialattention import SpatialAttentionBlock\nfrom monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL\nfrom monai.utils.type_conversion import convert_to_tensor\n\n# Set up logging configuration\nlogger = logging.getLogger(__name__)\n\n\ndef _empty_cuda_cache(save_mem: bool) -> None:\n    if torch.cuda.is_available() and save_mem:\n        torch.cuda.empty_cache()\n    return\n\n\nclass MaisiGroupNorm3D(nn.GroupNorm):\n    \"\"\"\n    Custom 3D Group Normalization with optional print_info output.\n\n    Args:\n        num_groups: Number of groups for the group norm.\n        num_channels: Number of channels for the group norm.\n        eps: Epsilon value for numerical stability.\n        affine: Whether to use learnable affine parameters, default to `True`.\n        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.\n        print_info: Whether to print information, default to `False`.\n        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_groups: int,\n        
num_channels: int,\n        eps: float = 1e-5,\n        affine: bool = True,\n        norm_float16: bool = False,\n        print_info: bool = False,\n        save_mem: bool = True,\n    ):\n        super().__init__(num_groups, num_channels, eps, affine)\n        self.norm_float16 = norm_float16\n        self.print_info = print_info\n        self.save_mem = save_mem\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        if self.print_info:\n            logger.info(f\"MaisiGroupNorm3D with input size: {input.size()}\")\n\n        if len(input.shape) != 5:\n            raise ValueError(\"Expected a 5D tensor\")\n\n        param_n, param_c, param_d, param_h, param_w = input.shape\n        input = input.view(param_n, self.num_groups, param_c // self.num_groups, param_d, param_h, param_w)\n\n        inputs = []\n        for i in range(input.size(1)):\n            array = input[:, i : i + 1, ...].to(dtype=torch.float32)\n            mean = array.mean([2, 3, 4, 5], keepdim=True)\n            std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()\n            if self.norm_float16:\n                inputs.append(((array - mean) / std).to(dtype=torch.float16))\n            else:\n                inputs.append((array - mean) / std)\n\n        del input\n        _empty_cuda_cache(self.save_mem)\n\n        input = torch.cat(inputs, dim=1) if max(inputs[0].size()) < 500 else self._cat_inputs(inputs)\n\n        input = input.view(param_n, param_c, param_d, param_h, param_w)\n        if self.affine:\n            input.mul_(self.weight.view(1, param_c, 1, 1, 1)).add_(self.bias.view(1, param_c, 1, 1, 1))\n\n        if self.print_info:\n            logger.info(f\"MaisiGroupNorm3D with output size: {input.size()}\")\n\n        return input\n\n    def _cat_inputs(self, inputs):\n        input_type = inputs[0].device.type\n        input = inputs[0].clone().to(\"cpu\", non_blocking=True) if input_type == \"cuda\" else inputs[0].clone()\n        
inputs[0] = 0\n        _empty_cuda_cache(self.save_mem)\n\n        for k in range(len(inputs) - 1):\n            input = torch.cat((input, inputs[k + 1].cpu()), dim=1)\n            inputs[k + 1] = 0\n            _empty_cuda_cache(self.save_mem)\n            gc.collect()\n\n            if self.print_info:\n                logger.info(f\"MaisiGroupNorm3D concat progress: {k + 1}/{len(inputs) - 1}.\")\n\n        return input.to(\"cuda\", non_blocking=True) if input_type == \"cuda\" else input\n\n\nclass MaisiConvolution(nn.Module):\n    \"\"\"\n    Convolutional layer with optional print_info output and custom splitting mechanism.\n\n    Args:\n        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).\n        in_channels: Number of input channels.\n        out_channels: Number of output channels.\n        num_splits: Number of splits for the input tensor.\n        dim_split: Dimension of splitting for the input tensor.\n        print_info: Whether to print information.\n        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.\n        Additional arguments for the convolution operation.\n        https://monai.readthedocs.io/en/stable/networks.html#convolution\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        num_splits: int,\n        dim_split: int,\n        print_info: bool,\n        save_mem: bool = True,\n        strides: Sequence[int] | int = 1,\n        kernel_size: Sequence[int] | int = 3,\n        adn_ordering: str = \"NDA\",\n        act: tuple | str | None = \"PRELU\",\n        norm: tuple | str | None = \"INSTANCE\",\n        dropout: tuple | str | float | None = None,\n        dropout_dim: int = 1,\n        dilation: Sequence[int] | int = 1,\n        groups: int = 1,\n        bias: bool = True,\n        conv_only: bool = False,\n        is_transposed: bool = False,\n        padding: Sequence[int] | int | None = None,\n        
output_padding: Sequence[int] | int | None = None,\n    ) -> None:\n        super().__init__()\n        self.conv = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            strides=strides,\n            kernel_size=kernel_size,\n            adn_ordering=adn_ordering,\n            act=act,\n            norm=norm,\n            dropout=dropout,\n            dropout_dim=dropout_dim,\n            dilation=dilation,\n            groups=groups,\n            bias=bias,\n            conv_only=conv_only,\n            is_transposed=is_transposed,\n            padding=padding,\n            output_padding=output_padding,\n        )\n\n        self.dim_split = dim_split\n        self.stride = strides[self.dim_split] if isinstance(strides, list) else strides\n        self.num_splits = num_splits\n        self.print_info = print_info\n        self.save_mem = save_mem\n\n    def _split_tensor(self, x: torch.Tensor, split_size: int, padding: int) -> list[torch.Tensor]:\n        overlaps = [0] + [padding] * (self.num_splits - 1)\n        last_padding = x.size(self.dim_split + 2) % split_size\n\n        slices = [slice(None)] * 5\n        splits: list[torch.Tensor] = []\n        for i in range(self.num_splits):\n            slices[self.dim_split + 2] = slice(\n                i * split_size - overlaps[i],\n                (i + 1) * split_size + (padding if i != self.num_splits - 1 else last_padding),\n            )\n            splits.append(x[tuple(slices)])\n\n        if self.print_info:\n            for j in range(len(splits)):\n                logger.info(f\"Split {j + 1}/{len(splits)} size: {splits[j].size()}\")\n\n        return splits\n\n    def _concatenate_tensors(self, outputs: list[torch.Tensor], split_size: int, padding: int) -> torch.Tensor:\n        slices = [slice(None)] * 5\n        for i in range(self.num_splits):\n            slices[self.dim_split + 2] = slice(None, split_size) if 
i == 0 else slice(padding, padding + split_size)\n            outputs[i] = outputs[i][tuple(slices)]\n\n        if self.print_info:\n            for i in range(self.num_splits):\n                logger.info(f\"Output {i + 1}/{len(outputs)} size after: {outputs[i].size()}\")\n\n        if max(outputs[0].size()) < 500:\n            x = torch.cat(outputs, dim=self.dim_split + 2)\n        else:\n            target_device = outputs[0].device\n\n            x = outputs[0].clone().to(\"cpu\", non_blocking=True)\n            outputs[0] = torch.Tensor(0)\n            _empty_cuda_cache(self.save_mem)\n            for k in range(len(outputs) - 1):\n                x = torch.cat((x, outputs[k + 1].cpu()), dim=self.dim_split + 2)\n                outputs[k + 1] = torch.Tensor(0)\n                _empty_cuda_cache(self.save_mem)\n                gc.collect()\n                if self.print_info:\n                    logger.info(f\"MaisiConvolution concat progress: {k + 1}/{len(outputs) - 1}.\")\n\n            if target_device.type != \"cpu\":\n                x = x.to(target_device, non_blocking=True)\n\n        return x\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.print_info:\n            logger.info(f\"Number of splits: {self.num_splits}\")\n\n        if self.num_splits <= 1:\n            x = self.conv(x)\n            return x\n\n        # compute size of splits\n        l = x.size(self.dim_split + 2)\n        split_size = l // self.num_splits\n\n        # update padding length if necessary\n        padding = 3\n        if padding % self.stride > 0:\n            padding = (padding // self.stride + 1) * self.stride\n        if self.print_info:\n            logger.info(f\"Padding size: {padding}\")\n\n        # split tensor into a list of tensors\n        splits = self._split_tensor(x, split_size, padding)\n\n        del x\n        _empty_cuda_cache(self.save_mem)\n\n        # convolution\n        outputs = [self.conv(split) for split in splits]\n   
     if self.print_info:\n            for j in range(len(outputs)):\n                logger.info(f\"Output {j + 1}/{len(outputs)} size before: {outputs[j].size()}\")\n\n        # update size of splits and padding length for output\n        split_size_out = split_size\n        padding_s = padding\n        non_dim_split = self.dim_split + 1 if self.dim_split < 2 else 0\n        if outputs[0].size(non_dim_split + 2) // splits[0].size(non_dim_split + 2) == 2:\n            split_size_out *= 2\n            padding_s *= 2\n        elif splits[0].size(non_dim_split + 2) // outputs[0].size(non_dim_split + 2) == 2:\n            split_size_out //= 2\n            padding_s //= 2\n\n        # concatenate list of tensors\n        x = self._concatenate_tensors(outputs, split_size_out, padding_s)\n\n        del outputs\n        _empty_cuda_cache(self.save_mem)\n\n        return x\n\n\nclass MaisiUpsample(nn.Module):\n    \"\"\"\n    Convolution-based upsampling layer.\n\n    Args:\n        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).\n        in_channels: Number of input channels to the layer.\n        use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.\n        num_splits: Number of splits for the input tensor.\n        dim_split: Dimension of splitting for the input tensor.\n        print_info: Whether to print information.\n        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        use_convtranspose: bool,\n        num_splits: int,\n        dim_split: int,\n        print_info: bool,\n        save_mem: bool = True,\n    ) -> None:\n        super().__init__()\n        self.conv = MaisiConvolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=in_channels,\n            strides=2 if use_convtranspose else 1,\n            kernel_size=3,\n  
          padding=1,\n            conv_only=True,\n            is_transposed=use_convtranspose,\n            num_splits=num_splits,\n            dim_split=dim_split,\n            print_info=print_info,\n            save_mem=save_mem,\n        )\n        self.use_convtranspose = use_convtranspose\n        self.save_mem = save_mem\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.use_convtranspose:\n            x = self.conv(x)\n            x_tensor: torch.Tensor = convert_to_tensor(x)\n            return x_tensor\n\n        x = F.interpolate(x, scale_factor=2.0, mode=\"trilinear\")\n        _empty_cuda_cache(self.save_mem)\n        x = self.conv(x)\n        _empty_cuda_cache(self.save_mem)\n\n        out_tensor: torch.Tensor = convert_to_tensor(x)\n        return out_tensor\n\n\nclass MaisiDownsample(nn.Module):\n    \"\"\"\n    Convolution-based downsampling layer.\n\n    Args:\n        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).\n        in_channels: Number of input channels.\n        num_splits: Number of splits for the input tensor.\n        dim_split: Dimension of splitting for the input tensor.\n        print_info: Whether to print information.\n        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        num_splits: int,\n        dim_split: int,\n        print_info: bool,\n        save_mem: bool = True,\n    ) -> None:\n        super().__init__()\n        self.pad = (0, 1) * spatial_dims\n        self.conv = MaisiConvolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=in_channels,\n            strides=2,\n            kernel_size=3,\n            padding=0,\n            conv_only=True,\n            num_splits=num_splits,\n            dim_split=dim_split,\n            print_info=print_info,\n            
save_mem=save_mem,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = F.pad(x, self.pad, mode=\"constant\", value=0.0)\n        x = self.conv(x)\n        return x\n\n\nclass MaisiResBlock(nn.Module):\n    \"\"\"\n    Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a\n    residual connection between input and output.\n\n    Args:\n        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).\n        in_channels: Input channels to the layer.\n        norm_num_groups: Number of groups for the group norm layer.\n        norm_eps: Epsilon for the normalization.\n        out_channels: Number of output channels.\n        num_splits: Number of splits for the input tensor.\n        dim_split: Dimension of splitting for the input tensor.\n        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.\n        print_info: Whether to print information, default to `False`.\n        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        norm_num_groups: int,\n        norm_eps: float,\n        out_channels: int,\n        num_splits: int,\n        dim_split: int,\n        norm_float16: bool = False,\n        print_info: bool = False,\n        save_mem: bool = True,\n    ) -> None:\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = in_channels if out_channels is None else out_channels\n        self.save_mem = save_mem\n\n        self.norm1 = MaisiGroupNorm3D(\n            num_groups=norm_num_groups,\n            num_channels=in_channels,\n            eps=norm_eps,\n            affine=True,\n            norm_float16=norm_float16,\n            print_info=print_info,\n            save_mem=save_mem,\n        )\n        self.conv1 = MaisiConvolution(\n            
spatial_dims=spatial_dims,\n            in_channels=self.in_channels,\n            out_channels=self.out_channels,\n            strides=1,\n            kernel_size=3,\n            padding=1,\n            conv_only=True,\n            num_splits=num_splits,\n            dim_split=dim_split,\n            print_info=print_info,\n            save_mem=save_mem,\n        )\n        self.norm2 = MaisiGroupNorm3D(\n            num_groups=norm_num_groups,\n            num_channels=out_channels,\n            eps=norm_eps,\n            affine=True,\n            norm_float16=norm_float16,\n            print_info=print_info,\n            save_mem=save_mem,\n        )\n        self.conv2 = MaisiConvolution(\n            spatial_dims=spatial_dims,\n            in_channels=self.out_channels,\n            out_channels=self.out_channels,\n            strides=1,\n            kernel_size=3,\n            padding=1,\n            conv_only=True,\n            num_splits=num_splits,\n            dim_split=dim_split,\n            print_info=print_info,\n            save_mem=save_mem,\n        )\n\n        self.nin_shortcut = (\n            MaisiConvolution(\n                spatial_dims=spatial_dims,\n                in_channels=self.in_channels,\n                out_channels=self.out_channels,\n                strides=1,\n                kernel_size=1,\n                padding=0,\n                conv_only=True,\n                num_splits=num_splits,\n                dim_split=dim_split,\n                print_info=print_info,\n                save_mem=save_mem,\n            )\n            if self.in_channels != self.out_channels\n            else nn.Identity()\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        h = self.norm1(x)\n        _empty_cuda_cache(self.save_mem)\n\n        h = F.silu(h)\n        _empty_cuda_cache(self.save_mem)\n        h = self.conv1(h)\n        _empty_cuda_cache(self.save_mem)\n\n        h = self.norm2(h)\n        
_empty_cuda_cache(self.save_mem)\n\n        h = F.silu(h)\n        _empty_cuda_cache(self.save_mem)\n        h = self.conv2(h)\n        _empty_cuda_cache(self.save_mem)\n\n        if self.in_channels != self.out_channels:\n            x = self.nin_shortcut(x)\n            _empty_cuda_cache(self.save_mem)\n\n        out = x + h\n        out_tensor: torch.Tensor = convert_to_tensor(out)\n        return out_tensor\n\n\nclass MaisiEncoder(nn.Module):\n    \"\"\"\n    Convolutional cascade that downsamples the image into a spatial latent space.\n\n    Args:\n        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).\n        in_channels: Number of input channels.\n        num_channels: Sequence of block output channels.\n        out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.\n        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.\n        norm_num_groups: Number of groups for the group norm layers.\n        norm_eps: Epsilon for the normalization.\n        attention_levels: Indicate which level from num_channels contain an attention block.\n        with_nonlocal_attn: If True, use non-local attention block.\n        include_fc: whether to include the final linear layer in the attention block. 
Default to False.\n        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.\n        use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.\n        num_splits: Number of splits for the input tensor.\n        dim_split: Dimension of splitting for the input tensor.\n        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.\n        print_info: Whether to print information, default to `False`.\n        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        num_channels: Sequence[int],\n        out_channels: int,\n        num_res_blocks: Sequence[int],\n        norm_num_groups: int,\n        norm_eps: float,\n        attention_levels: Sequence[bool],\n        num_splits: int,\n        dim_split: int,\n        norm_float16: bool = False,\n        print_info: bool = False,\n        save_mem: bool = True,\n        with_nonlocal_attn: bool = True,\n        include_fc: bool = False,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n\n        # Check if attention_levels and num_channels have the same size\n        if len(attention_levels) != len(num_channels):\n            raise ValueError(\"attention_levels and num_channels must have the same size\")\n\n        # Check if num_res_blocks and num_channels have the same size\n        if len(num_res_blocks) != len(num_channels):\n            raise ValueError(\"num_res_blocks and num_channels must have the same size\")\n\n        self.save_mem = save_mem\n\n        blocks: list[nn.Module] = []\n\n        blocks.append(\n            MaisiConvolution(\n                spatial_dims=spatial_dims,\n                in_channels=in_channels,\n                
out_channels=num_channels[0],\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n                num_splits=num_splits,\n                dim_split=dim_split,\n                print_info=print_info,\n                save_mem=save_mem,\n            )\n        )\n\n        output_channel = num_channels[0]\n        for i in range(len(num_channels)):\n            input_channel = output_channel\n            output_channel = num_channels[i]\n            is_final_block = i == len(num_channels) - 1\n\n            for _ in range(num_res_blocks[i]):\n                blocks.append(\n                    MaisiResBlock(\n                        spatial_dims=spatial_dims,\n                        in_channels=input_channel,\n                        norm_num_groups=norm_num_groups,\n                        norm_eps=norm_eps,\n                        out_channels=output_channel,\n                        num_splits=num_splits,\n                        dim_split=dim_split,\n                        norm_float16=norm_float16,\n                        print_info=print_info,\n                        save_mem=save_mem,\n                    )\n                )\n                input_channel = output_channel\n                if attention_levels[i]:\n                    blocks.append(\n                        SpatialAttentionBlock(\n                            spatial_dims=spatial_dims,\n                            num_channels=input_channel,\n                            norm_num_groups=norm_num_groups,\n                            norm_eps=norm_eps,\n                            include_fc=include_fc,\n                            use_combined_linear=use_combined_linear,\n                            use_flash_attention=use_flash_attention,\n                        )\n                    )\n\n            if not is_final_block:\n                blocks.append(\n                    MaisiDownsample(\n                        
spatial_dims=spatial_dims,\n                        in_channels=input_channel,\n                        num_splits=num_splits,\n                        dim_split=dim_split,\n                        print_info=print_info,\n                        save_mem=save_mem,\n                    )\n                )\n\n        if with_nonlocal_attn:\n            blocks.append(\n                AEKLResBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=num_channels[-1],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    out_channels=num_channels[-1],\n                )\n            )\n\n            blocks.append(\n                SpatialAttentionBlock(\n                    spatial_dims=spatial_dims,\n                    num_channels=num_channels[-1],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    include_fc=include_fc,\n                    use_combined_linear=use_combined_linear,\n                    use_flash_attention=use_flash_attention,\n                )\n            )\n            blocks.append(\n                AEKLResBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=num_channels[-1],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    out_channels=num_channels[-1],\n                )\n            )\n\n        blocks.append(\n            MaisiGroupNorm3D(\n                num_groups=norm_num_groups,\n                num_channels=num_channels[-1],\n                eps=norm_eps,\n                affine=True,\n                norm_float16=norm_float16,\n                print_info=print_info,\n                save_mem=save_mem,\n            )\n        )\n        blocks.append(\n            MaisiConvolution(\n                spatial_dims=spatial_dims,\n                in_channels=num_channels[-1],\n             
   out_channels=out_channels,\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n                num_splits=num_splits,\n                dim_split=dim_split,\n                print_info=print_info,\n                save_mem=save_mem,\n            )\n        )\n\n        self.blocks = nn.ModuleList(blocks)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for block in self.blocks:\n            x = block(x)\n            _empty_cuda_cache(self.save_mem)\n        return x\n\n\nclass MaisiDecoder(nn.Module):\n    \"\"\"\n    Convolutional cascade upsampling from a spatial latent space into an image space.\n\n    Args:\n        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).\n        num_channels: Sequence of block output channels.\n        in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.\n        out_channels: Number of output channels.\n        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.\n        norm_num_groups: Number of groups for the group norm layers.\n        norm_eps: Epsilon for the normalization.\n        attention_levels: Indicate which level from num_channels contain an attention block.\n        with_nonlocal_attn: If True, use non-local attention block.\n        include_fc: whether to include the final linear layer in the attention block. 
Default to False.\n        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.\n        use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.\n        use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.\n        num_splits: Number of splits for the input tensor.\n        dim_split: Dimension of splitting for the input tensor.\n        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.\n        print_info: Whether to print information, default to `False`.\n        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        num_channels: Sequence[int],\n        in_channels: int,\n        out_channels: int,\n        num_res_blocks: Sequence[int],\n        norm_num_groups: int,\n        norm_eps: float,\n        attention_levels: Sequence[bool],\n        num_splits: int,\n        dim_split: int,\n        norm_float16: bool = False,\n        print_info: bool = False,\n        save_mem: bool = True,\n        with_nonlocal_attn: bool = True,\n        include_fc: bool = False,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n        use_convtranspose: bool = False,\n    ) -> None:\n        super().__init__()\n        self.print_info = print_info\n        self.save_mem = save_mem\n\n        reversed_block_out_channels = list(reversed(num_channels))\n\n        blocks: list[nn.Module] = []\n\n        blocks.append(\n            MaisiConvolution(\n                spatial_dims=spatial_dims,\n                in_channels=in_channels,\n                out_channels=reversed_block_out_channels[0],\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n                
num_splits=num_splits,\n                dim_split=dim_split,\n                print_info=print_info,\n                save_mem=save_mem,\n            )\n        )\n\n        if with_nonlocal_attn:\n            blocks.append(\n                AEKLResBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=reversed_block_out_channels[0],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    out_channels=reversed_block_out_channels[0],\n                )\n            )\n            blocks.append(\n                SpatialAttentionBlock(\n                    spatial_dims=spatial_dims,\n                    num_channels=reversed_block_out_channels[0],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    include_fc=include_fc,\n                    use_combined_linear=use_combined_linear,\n                    use_flash_attention=use_flash_attention,\n                )\n            )\n            blocks.append(\n                AEKLResBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=reversed_block_out_channels[0],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    out_channels=reversed_block_out_channels[0],\n                )\n            )\n\n        reversed_attention_levels = list(reversed(attention_levels))\n        reversed_num_res_blocks = list(reversed(num_res_blocks))\n        block_out_ch = reversed_block_out_channels[0]\n        for i in range(len(reversed_block_out_channels)):\n            block_in_ch = block_out_ch\n            block_out_ch = reversed_block_out_channels[i]\n            is_final_block = i == len(num_channels) - 1\n\n            for _ in range(reversed_num_res_blocks[i]):\n                blocks.append(\n                    MaisiResBlock(\n                        spatial_dims=spatial_dims,\n     
                   in_channels=block_in_ch,\n                        norm_num_groups=norm_num_groups,\n                        norm_eps=norm_eps,\n                        out_channels=block_out_ch,\n                        num_splits=num_splits,\n                        dim_split=dim_split,\n                        norm_float16=norm_float16,\n                        print_info=print_info,\n                        save_mem=save_mem,\n                    )\n                )\n                block_in_ch = block_out_ch\n\n                if reversed_attention_levels[i]:\n                    blocks.append(\n                        SpatialAttentionBlock(\n                            spatial_dims=spatial_dims,\n                            num_channels=block_in_ch,\n                            norm_num_groups=norm_num_groups,\n                            norm_eps=norm_eps,\n                            include_fc=include_fc,\n                            use_combined_linear=use_combined_linear,\n                            use_flash_attention=use_flash_attention,\n                        )\n                    )\n\n            if not is_final_block:\n                blocks.append(\n                    MaisiUpsample(\n                        spatial_dims=spatial_dims,\n                        in_channels=block_in_ch,\n                        use_convtranspose=use_convtranspose,\n                        num_splits=num_splits,\n                        dim_split=dim_split,\n                        print_info=print_info,\n                        save_mem=save_mem,\n                    )\n                )\n\n        blocks.append(\n            MaisiGroupNorm3D(\n                num_groups=norm_num_groups,\n                num_channels=block_in_ch,\n                eps=norm_eps,\n                affine=True,\n                norm_float16=norm_float16,\n                print_info=print_info,\n                save_mem=save_mem,\n            )\n        )\n        blocks.append(\n    
        MaisiConvolution(\n                spatial_dims=spatial_dims,\n                in_channels=block_in_ch,\n                out_channels=out_channels,\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n                num_splits=num_splits,\n                dim_split=dim_split,\n                print_info=print_info,\n                save_mem=save_mem,\n            )\n        )\n\n        self.blocks = nn.ModuleList(blocks)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for block in self.blocks:\n            x = block(x)\n            _empty_cuda_cache(self.save_mem)\n        return x\n\n\nclass AutoencoderKlMaisi(AutoencoderKL):\n    \"\"\"\n    AutoencoderKL with custom MaisiEncoder and MaisiDecoder.\n\n    Args:\n        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).\n        in_channels: Number of input channels.\n        out_channels: Number of output channels.\n        num_res_blocks: Number of residual blocks per level.\n        num_channels: Sequence of block output channels.\n        attention_levels: Indicate which level from num_channels contain an attention block.\n        latent_channels: Number of channels in the latent space.\n        norm_num_groups: Number of groups for the group norm layers.\n        norm_eps: Epsilon for the normalization.\n        with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.\n        with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.\n        include_fc: whether to include the final linear layer. 
Default to False.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.\n        use_checkpointing: If True, use activation checkpointing.\n        use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.\n        num_splits: Number of splits for the input tensor.\n        dim_split: Dimension of splitting for the input tensor.\n        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.\n        print_info: Whether to print information, default to `False`.\n        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        num_res_blocks: Sequence[int],\n        num_channels: Sequence[int],\n        attention_levels: Sequence[bool],\n        latent_channels: int = 3,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        with_encoder_nonlocal_attn: bool = False,\n        with_decoder_nonlocal_attn: bool = False,\n        include_fc: bool = False,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n        use_checkpointing: bool = False,\n        use_convtranspose: bool = False,\n        num_splits: int = 16,\n        dim_split: int = 0,\n        norm_float16: bool = False,\n        print_info: bool = False,\n        save_mem: bool = True,\n    ) -> None:\n        super().__init__(\n            spatial_dims,\n            in_channels,\n            out_channels,\n            num_res_blocks,\n            num_channels,\n            attention_levels,\n            latent_channels,\n            norm_num_groups,\n            norm_eps,\n            with_encoder_nonlocal_attn,\n            with_decoder_nonlocal_attn,\n            
use_checkpointing,\n            use_convtranspose,\n            include_fc,\n            use_combined_linear,\n            use_flash_attention,\n        )\n\n        self.encoder: nn.Module = MaisiEncoder(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            num_channels=num_channels,\n            out_channels=latent_channels,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            attention_levels=attention_levels,\n            with_nonlocal_attn=with_encoder_nonlocal_attn,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n            num_splits=num_splits,\n            dim_split=dim_split,\n            norm_float16=norm_float16,\n            print_info=print_info,\n            save_mem=save_mem,\n        )\n\n        self.decoder: nn.Module = MaisiDecoder(\n            spatial_dims=spatial_dims,\n            num_channels=num_channels,\n            in_channels=latent_channels,\n            out_channels=out_channels,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            attention_levels=attention_levels,\n            with_nonlocal_attn=with_decoder_nonlocal_attn,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n            use_convtranspose=use_convtranspose,\n            num_splits=num_splits,\n            dim_split=dim_split,\n            norm_float16=norm_float16,\n            print_info=print_info,\n            save_mem=save_mem,\n        )\n"
  },
  {
    "path": "monai/apps/generation/maisi/networks/controlnet_maisi.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\n\nfrom monai.networks.nets.controlnet import ControlNet\nfrom monai.networks.nets.diffusion_model_unet import get_timestep_embedding\n\n\nclass ControlNetMaisi(ControlNet):\n    \"\"\"\n    Control network for diffusion models based on Zhang and Agrawala \"Adding Conditional Control to Text-to-Image\n    Diffusion Models\" (https://arxiv.org/abs/2302.05543)\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        num_res_blocks: number of residual blocks (see ResnetBlock) per level.\n        num_channels: tuple of block output channels.\n        attention_levels: list of levels to add attention.\n        norm_num_groups: number of groups for the normalization.\n        norm_eps: epsilon for the normalization.\n        resblock_updown: if True use residual blocks for up/downsampling.\n        num_head_channels: number of channels in each attention head.\n        with_conditioning: if True add spatial transformers to perform conditioning.\n        transformer_num_layers: number of layers of Transformer blocks to use.\n        cross_attention_dim: number of context dimensions to use.\n        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`\n            
classes.\n        upcast_attention: if True, upcast attention operations to full precision.\n        conditioning_embedding_in_channels: number of input channels for the conditioning embedding.\n        conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.\n        use_checkpointing: if True, use activation checkpointing to save memory.\n        include_fc: whether to include the final linear layer. Default to False.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),\n        num_channels: Sequence[int] = (32, 64, 64, 64),\n        attention_levels: Sequence[bool] = (False, False, True, True),\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        resblock_updown: bool = False,\n        num_head_channels: int | Sequence[int] = 8,\n        with_conditioning: bool = False,\n        transformer_num_layers: int = 1,\n        cross_attention_dim: int | None = None,\n        num_class_embeds: int | None = None,\n        upcast_attention: bool = False,\n        conditioning_embedding_in_channels: int = 1,\n        conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),\n        use_checkpointing: bool = True,\n        include_fc: bool = False,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__(\n            spatial_dims,\n            in_channels,\n            num_res_blocks,\n            num_channels,\n            attention_levels,\n            norm_num_groups,\n            norm_eps,\n            resblock_updown,\n            num_head_channels,\n            with_conditioning,\n            
transformer_num_layers,\n            cross_attention_dim,\n            num_class_embeds,\n            upcast_attention,\n            conditioning_embedding_in_channels,\n            conditioning_embedding_num_channels,\n            include_fc,\n            use_combined_linear,\n            use_flash_attention,\n        )\n        self.use_checkpointing = use_checkpointing\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        timesteps: torch.Tensor,\n        controlnet_cond: torch.Tensor,\n        conditioning_scale: float = 1.0,\n        context: torch.Tensor | None = None,\n        class_labels: torch.Tensor | None = None,\n    ) -> tuple[list[torch.Tensor], torch.Tensor]:\n        emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)\n        h = self._apply_initial_convolution(x)\n        if self.use_checkpointing:\n            controlnet_cond = torch.utils.checkpoint.checkpoint(\n                self.controlnet_cond_embedding, controlnet_cond, use_reentrant=False\n            )\n        else:\n            controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)\n        h += controlnet_cond\n        down_block_res_samples, h = self._apply_down_blocks(emb, context, h)\n        h = self._apply_mid_block(emb, context, h)\n        down_block_res_samples, mid_block_res_sample = self._apply_controlnet_blocks(h, down_block_res_samples)\n        # scaling\n        down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples]\n        mid_block_res_sample *= conditioning_scale\n\n        return down_block_res_samples, mid_block_res_sample\n\n    def _prepare_time_and_class_embedding(self, x, timesteps, class_labels):\n        # 1. time\n        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. 
so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=x.dtype)\n        emb = self.time_embed(t_emb)\n\n        # 2. class\n        if self.num_class_embeds is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n            class_emb = self.class_embedding(class_labels)\n            class_emb = class_emb.to(dtype=x.dtype)\n            emb = emb + class_emb\n\n        return emb\n\n    def _apply_initial_convolution(self, x):\n        # 3. initial convolution\n        h = self.conv_in(x)\n        return h\n\n    def _apply_down_blocks(self, emb, context, h):\n        # 4. down\n        if context is not None and self.with_conditioning is False:\n            raise ValueError(\"model should have with_conditioning = True if context is provided\")\n        down_block_res_samples: list[torch.Tensor] = [h]\n        for downsample_block in self.down_blocks:\n            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)\n            for residual in res_samples:\n                down_block_res_samples.append(residual)\n\n        return down_block_res_samples, h\n\n    def _apply_mid_block(self, emb, context, h):\n        # 5. mid\n        h = self.middle_block(hidden_states=h, temb=emb, context=context)\n        return h\n\n    def _apply_controlnet_blocks(self, h, down_block_res_samples):\n        # 6. Control net blocks\n        controlnet_down_block_res_samples = []\n        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):\n            down_block_res_sample = controlnet_block(down_block_res_sample)\n            controlnet_down_block_res_samples.append(down_block_res_sample)\n\n        mid_block_res_sample = self.controlnet_mid_block(h)\n\n        return controlnet_down_block_res_samples, mid_block_res_sample\n"
  },
  {
    "path": "monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# =========================================================================\n# Adapted from https://github.com/huggingface/diffusers\n# which has the following license:\n# https://github.com/huggingface/diffusers/blob/main/LICENSE\n#\n# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =========================================================================\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nfrom torch import nn\n\nfrom monai.networks.blocks import Convolution\nfrom monai.networks.nets.diffusion_model_unet import (\n    get_down_block,\n    get_mid_block,\n    get_timestep_embedding,\n    get_up_block,\n    zero_module,\n)\nfrom monai.utils import ensure_tuple_rep\nfrom monai.utils.type_conversion import 
convert_to_tensor\n\n__all__ = [\"DiffusionModelUNetMaisi\"]\n\n\nclass DiffusionModelUNetMaisi(nn.Module):\n    \"\"\"\n    U-Net network with timestep embedding and attention mechanisms for conditioning based on\n    Rombach et al. \"High-Resolution Image Synthesis with Latent Diffusion Models\" https://arxiv.org/abs/2112.10752\n    and Pinaya et al. \"Brain Imaging Generation with Latent Diffusion Models\" https://arxiv.org/abs/2209.07162\n\n    Args:\n        spatial_dims: Number of spatial dimensions.\n        in_channels: Number of input channels.\n        out_channels: Number of output channels.\n        num_res_blocks: Number of residual blocks (see ResnetBlock) per level. Can be a single integer or a sequence of integers.\n        num_channels: Tuple of block output channels.\n        attention_levels: List of levels to add attention.\n        norm_num_groups: Number of groups for the normalization.\n        norm_eps: Epsilon for the normalization.\n        resblock_updown: If True, use residual blocks for up/downsampling.\n        num_head_channels: Number of channels in each attention head. Can be a single integer or a sequence of integers.\n        with_conditioning: If True, add spatial transformers to perform conditioning.\n        transformer_num_layers: Number of layers of Transformer blocks to use.\n        cross_attention_dim: Number of context dimensions to use.\n        num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.\n        upcast_attention: If True, upcast attention operations to full precision.\n        include_fc: whether to include the final linear layer. 
Default to False.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.\n        dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.\n        include_top_region_index_input: If True, use top region index input.\n        include_bottom_region_index_input: If True, use bottom region index input.\n        include_spacing_input: If True, use spacing input.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),\n        num_channels: Sequence[int] = (32, 64, 64, 64),\n        attention_levels: Sequence[bool] = (False, False, True, True),\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        resblock_updown: bool = False,\n        num_head_channels: int | Sequence[int] = 8,\n        with_conditioning: bool = False,\n        transformer_num_layers: int = 1,\n        cross_attention_dim: int | None = None,\n        num_class_embeds: int | None = None,\n        upcast_attention: bool = False,\n        include_fc: bool = False,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n        dropout_cattn: float = 0.0,\n        include_top_region_index_input: bool = False,\n        include_bottom_region_index_input: bool = False,\n        include_spacing_input: bool = False,\n    ) -> None:\n        super().__init__()\n        if with_conditioning is True and cross_attention_dim is None:\n            raise ValueError(\n                \"DiffusionModelUNetMaisi expects dimension of the cross-attention conditioning (cross_attention_dim) \"\n                \"when using with_conditioning.\"\n            )\n        if cross_attention_dim is not None and with_conditioning is False:\n        
    raise ValueError(\n                \"DiffusionModelUNetMaisi expects with_conditioning=True when specifying the cross_attention_dim.\"\n            )\n        if dropout_cattn > 1.0 or dropout_cattn < 0.0:\n            raise ValueError(\"Dropout cannot be negative or >1.0!\")\n\n        # All number of channels should be multiple of num_groups\n        if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):\n            raise ValueError(\n                f\"DiffusionModelUNetMaisi expects all num_channels being multiple of norm_num_groups, \"\n                f\"but get num_channels: {num_channels} and norm_num_groups: {norm_num_groups}\"\n            )\n\n        if len(num_channels) != len(attention_levels):\n            raise ValueError(\n                f\"DiffusionModelUNetMaisi expects num_channels being same size of attention_levels, \"\n                f\"but get num_channels: {len(num_channels)} and attention_levels: {len(attention_levels)}\"\n            )\n\n        if isinstance(num_head_channels, int):\n            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))\n\n        if len(num_head_channels) != len(attention_levels):\n            raise ValueError(\n                \"num_head_channels should have the same length as attention_levels. For the i levels without attention,\"\n                \" i.e. 
`attention_level[i]=False`, the num_head_channels[i] will be ignored.\"\n            )\n\n        if isinstance(num_res_blocks, int):\n            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))\n\n        if len(num_res_blocks) != len(num_channels):\n            raise ValueError(\n                \"`num_res_blocks` should be a single integer or a tuple of integers with the same length as \"\n                \"`num_channels`.\"\n            )\n\n        if use_flash_attention is True and not torch.cuda.is_available():\n            raise ValueError(\n                \"torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU.\"\n            )\n\n        self.in_channels = in_channels\n        self.block_out_channels = num_channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_levels = attention_levels\n        self.num_head_channels = num_head_channels\n        self.with_conditioning = with_conditioning\n\n        # input\n        self.conv_in = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=num_channels[0],\n            strides=1,\n            kernel_size=3,\n            padding=1,\n            conv_only=True,\n        )\n\n        # time\n        time_embed_dim = num_channels[0] * 4\n        self.time_embed = self._create_embedding_module(num_channels[0], time_embed_dim)\n\n        # class embedding\n        self.num_class_embeds = num_class_embeds\n        if num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n\n        self.include_top_region_index_input = include_top_region_index_input\n        self.include_bottom_region_index_input = include_bottom_region_index_input\n        self.include_spacing_input = include_spacing_input\n\n        new_time_embed_dim = time_embed_dim\n        if 
self.include_top_region_index_input:\n            self.top_region_index_layer = self._create_embedding_module(4, time_embed_dim)\n            new_time_embed_dim += time_embed_dim\n        if self.include_bottom_region_index_input:\n            self.bottom_region_index_layer = self._create_embedding_module(4, time_embed_dim)\n            new_time_embed_dim += time_embed_dim\n        if self.include_spacing_input:\n            self.spacing_layer = self._create_embedding_module(3, time_embed_dim)\n            new_time_embed_dim += time_embed_dim\n\n        # down\n        self.down_blocks = nn.ModuleList([])\n        output_channel = num_channels[0]\n        for i in range(len(num_channels)):\n            input_channel = output_channel\n            output_channel = num_channels[i]\n            is_final_block = i == len(num_channels) - 1\n            down_block = get_down_block(\n                spatial_dims=spatial_dims,\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=new_time_embed_dim,\n                num_res_blocks=num_res_blocks[i],\n                norm_num_groups=norm_num_groups,\n                norm_eps=norm_eps,\n                add_downsample=not is_final_block,\n                resblock_updown=resblock_updown,\n                with_attn=(attention_levels[i] and not with_conditioning),\n                with_cross_attn=(attention_levels[i] and with_conditioning),\n                num_head_channels=num_head_channels[i],\n                transformer_num_layers=transformer_num_layers,\n                cross_attention_dim=cross_attention_dim,\n                upcast_attention=upcast_attention,\n                include_fc=include_fc,\n                use_combined_linear=use_combined_linear,\n                use_flash_attention=use_flash_attention,\n                dropout_cattn=dropout_cattn,\n            )\n\n            self.down_blocks.append(down_block)\n\n        # mid\n        
self.middle_block = get_mid_block(\n            spatial_dims=spatial_dims,\n            in_channels=num_channels[-1],\n            temb_channels=new_time_embed_dim,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            with_conditioning=with_conditioning,\n            num_head_channels=num_head_channels[-1],\n            transformer_num_layers=transformer_num_layers,\n            cross_attention_dim=cross_attention_dim,\n            upcast_attention=upcast_attention,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n            dropout_cattn=dropout_cattn,\n        )\n\n        # up\n        self.up_blocks = nn.ModuleList([])\n        reversed_block_out_channels = list(reversed(num_channels))\n        reversed_num_res_blocks = list(reversed(num_res_blocks))\n        reversed_attention_levels = list(reversed(attention_levels))\n        reversed_num_head_channels = list(reversed(num_head_channels))\n        output_channel = reversed_block_out_channels[0]\n        for i in range(len(reversed_block_out_channels)):\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)]\n\n            is_final_block = i == len(num_channels) - 1\n\n            up_block = get_up_block(\n                spatial_dims=spatial_dims,\n                in_channels=input_channel,\n                prev_output_channel=prev_output_channel,\n                out_channels=output_channel,\n                temb_channels=new_time_embed_dim,\n                num_res_blocks=reversed_num_res_blocks[i] + 1,\n                norm_num_groups=norm_num_groups,\n                norm_eps=norm_eps,\n                add_upsample=not is_final_block,\n                resblock_updown=resblock_updown,\n                
with_attn=(reversed_attention_levels[i] and not with_conditioning),\n                with_cross_attn=(reversed_attention_levels[i] and with_conditioning),\n                num_head_channels=reversed_num_head_channels[i],\n                transformer_num_layers=transformer_num_layers,\n                cross_attention_dim=cross_attention_dim,\n                upcast_attention=upcast_attention,\n                include_fc=include_fc,\n                use_combined_linear=use_combined_linear,\n                use_flash_attention=use_flash_attention,\n                dropout_cattn=dropout_cattn,\n            )\n\n            self.up_blocks.append(up_block)\n\n        # out\n        self.out = nn.Sequential(\n            nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True),\n            nn.SiLU(),\n            zero_module(\n                Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=num_channels[0],\n                    out_channels=out_channels,\n                    strides=1,\n                    kernel_size=3,\n                    padding=1,\n                    conv_only=True,\n                )\n            ),\n        )\n\n    def _create_embedding_module(self, input_dim, embed_dim):\n        model = nn.Sequential(nn.Linear(input_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim))\n        return model\n\n    def _get_time_and_class_embedding(self, x, timesteps, class_labels):\n        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. 
so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=x.dtype)\n        emb = self.time_embed(t_emb)\n\n        if self.num_class_embeds is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n            class_emb = self.class_embedding(class_labels)\n            class_emb = class_emb.to(dtype=x.dtype)\n            emb += class_emb\n        return emb\n\n    def _get_input_embeddings(self, emb, top_index, bottom_index, spacing):\n        if self.include_top_region_index_input:\n            _emb = self.top_region_index_layer(top_index)\n            emb = torch.cat((emb, _emb), dim=1)\n        if self.include_bottom_region_index_input:\n            _emb = self.bottom_region_index_layer(bottom_index)\n            emb = torch.cat((emb, _emb), dim=1)\n        if self.include_spacing_input:\n            _emb = self.spacing_layer(spacing)\n            emb = torch.cat((emb, _emb), dim=1)\n        return emb\n\n    def _apply_down_blocks(self, h, emb, context, down_block_additional_residuals):\n        if context is not None and self.with_conditioning is False:\n            raise ValueError(\"model should have with_conditioning = True if context is provided\")\n        down_block_res_samples: list[torch.Tensor] = [h]\n        for downsample_block in self.down_blocks:\n            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)\n            down_block_res_samples.extend(res_samples)\n\n        # Additional residual connections for Controlnets\n        if down_block_additional_residuals is not None:\n            new_down_block_res_samples: list[torch.Tensor] = []\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample += 
down_block_additional_residual\n                new_down_block_res_samples.append(down_block_res_sample)\n\n            down_block_res_samples = new_down_block_res_samples\n        return h, down_block_res_samples\n\n    def _apply_up_blocks(self, h, emb, context, down_block_res_samples):\n        for upsample_block in self.up_blocks:\n            idx: int = -len(upsample_block.resnets)  # type: ignore\n            res_samples = down_block_res_samples[idx:]\n            down_block_res_samples = down_block_res_samples[:idx]\n            h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)\n\n        return h\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        timesteps: torch.Tensor,\n        context: torch.Tensor | None = None,\n        class_labels: torch.Tensor | None = None,\n        down_block_additional_residuals: tuple[torch.Tensor] | None = None,\n        mid_block_additional_residual: torch.Tensor | None = None,\n        top_region_index_tensor: torch.Tensor | None = None,\n        bottom_region_index_tensor: torch.Tensor | None = None,\n        spacing_tensor: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Forward pass through the UNet model.\n\n        Args:\n            x: Input tensor of shape (N, C, SpatialDims).\n            timesteps: Timestep tensor of shape (N,).\n            context: Context tensor of shape (N, 1, ContextDim).\n            class_labels: Class labels tensor of shape (N,).\n            down_block_additional_residuals: Additional residual tensors for down blocks of shape (N, C, FeatureMapsDims).\n            mid_block_additional_residual: Additional residual tensor for mid block of shape (N, C, FeatureMapsDims).\n            top_region_index_tensor: Tensor representing top region index of shape (N, 4).\n            bottom_region_index_tensor: Tensor representing bottom region index of shape (N, 4).\n            spacing_tensor: Tensor 
representing spacing of shape (N, 3).\n\n        Returns:\n            A tensor representing the output of the UNet model.\n        \"\"\"\n\n        emb = self._get_time_and_class_embedding(x, timesteps, class_labels)\n        emb = self._get_input_embeddings(emb, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor)\n        h = self.conv_in(x)\n        h, _updated_down_block_res_samples = self._apply_down_blocks(h, emb, context, down_block_additional_residuals)\n        h = self.middle_block(h, emb, context)\n\n        # Additional residual connections for Controlnets\n        if mid_block_additional_residual is not None:\n            h += mid_block_additional_residual\n\n        h = self._apply_up_blocks(h, emb, context, _updated_down_block_res_samples)\n        h = self.out(h)\n        h_tensor: torch.Tensor = convert_to_tensor(h)\n        return h_tensor\n"
  },
  {
    "path": "monai/apps/mmars/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .mmars import download_mmar, get_model_spec, load_from_mmar\nfrom .model_desc import MODEL_DESC, RemoteMMARKeys\n"
  },
  {
    "path": "monai/apps/mmars/mmars.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities for accessing Nvidia MMARs\n\nSee Also:\n    - https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport warnings\nfrom collections.abc import Mapping\nfrom pathlib import Path\nfrom typing import Any\n\nimport torch\n\nimport monai.networks.nets as monai_nets\nfrom monai.apps.utils import download_and_extract, logger\nfrom monai.config.type_definitions import PathLike\nfrom monai.networks.utils import copy_model_state\nfrom monai.utils.module import optional_import\n\nfrom .model_desc import MODEL_DESC\nfrom .model_desc import RemoteMMARKeys as Keys\n\n__all__ = [\"get_model_spec\", \"download_mmar\", \"load_from_mmar\"]\n\n\ndef get_model_spec(idx: int | str) -> dict | Any:\n    \"\"\"get model specification by `idx`. 
`idx` could be index of the constant tuple of dict or the actual model ID.\"\"\"\n    if isinstance(idx, int):\n        return MODEL_DESC[idx]\n    if isinstance(idx, str):\n        key = idx.strip().lower()\n        for cand in MODEL_DESC:\n            if str(cand.get(Keys.ID)).strip().lower() == key:\n                return cand\n    return idx\n\n\ndef _get_all_ngc_models(pattern, page_index=0, page_size=50):\n    url = \"https://api.ngc.nvidia.com/v2/search/catalog/resources/MODEL\"\n    query_dict = {\n        \"query\": \"\",\n        \"orderBy\": [{\"field\": \"score\", \"value\": \"DESC\"}],\n        \"queryFields\": [\"all\", \"description\", \"displayName\", \"name\", \"resourceId\"],\n        \"fields\": [\n            \"isPublic\",\n            \"attributes\",\n            \"guestAccess\",\n            \"name\",\n            \"orgName\",\n            \"teamName\",\n            \"displayName\",\n            \"dateModified\",\n            \"labels\",\n            \"description\",\n        ],\n        \"page\": 0,\n    }\n\n    filter = [dict(field=\"name\", value=f\"*{pattern}*\")]\n    query_dict[\"page\"] = page_index\n    query_dict[\"pageSize\"] = page_size\n    query_dict[\"filters\"] = filter\n    query_str = json.dumps(query_dict)\n    full_url = f\"{url}?q={query_str}\"\n    requests_get, has_requests = optional_import(\"requests\", name=\"get\")\n    if has_requests:\n        resp = requests_get(full_url)\n        resp.raise_for_status()\n    else:\n        raise ValueError(\"NGC API requires requests package.  
Please install it.\")\n    model_list = json.loads(resp.text)\n    model_dict = {}\n    for result in model_list[\"results\"]:\n        for model in result[\"resources\"]:\n            current_res_id = model[\"resourceId\"]\n            model_dict[current_res_id] = {\"name\": model[\"name\"]}\n            for attribute in model[\"attributes\"]:\n                if attribute[\"key\"] == \"latestVersionIdStr\":\n                    model_dict[current_res_id][\"latest\"] = attribute[\"value\"]\n    return model_dict\n\n\ndef _get_ngc_url(model_name: str, version: str, model_prefix: str = \"\") -> str:\n    return f\"https://api.ngc.nvidia.com/v2/models/{model_prefix}{model_name}/versions/{version}/zip\"\n\n\ndef _get_ngc_doc_url(model_name: str, model_prefix: str = \"\") -> str:\n    return f\"https://ngc.nvidia.com/catalog/models/{model_prefix}{model_name}\"\n\n\ndef download_mmar(\n    item: str | Mapping, mmar_dir: PathLike | None = None, progress: bool = True, api: bool = True, version: int = -1\n) -> Path:\n    \"\"\"\n    Download and extract Medical Model Archive (MMAR) from Nvidia Clara Train.\n\n    See Also:\n        - https://docs.nvidia.com/clara/\n        - Nvidia NGC Registry CLI\n        - https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html\n\n    Args:\n        item: the corresponding model item from `MODEL_DESC`.\n          Or when api is True, the substring to query NGC's model name field.\n        mmar_dir: target directory to store the MMAR, default is `mmars` subfolder under `torch.hub get_dir()`.\n        progress: whether to display a progress bar.\n        api: whether to query NGC and download via api\n        version: which version of MMAR to download.  
-1 means the latest from ngc.\n\n    Examples::\n        >>> from monai.apps import download_mmar\n        >>> download_mmar(\"clara_pt_prostate_mri_segmentation_1\", mmar_dir=\".\")\n        >>> download_mmar(\"prostate_mri_segmentation\", mmar_dir=\".\", api=True)\n\n\n    Returns:\n        The local directory of the downloaded model.\n        If api is True, a list of local directories of downloaded models.\n    \"\"\"\n    if not mmar_dir:\n        get_dir, has_home = optional_import(\"torch.hub\", name=\"get_dir\")\n        if has_home:\n            mmar_dir = Path(get_dir()) / \"mmars\"\n        else:\n            raise ValueError(\"mmar_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?\")\n    _mmar_dir = Path(mmar_dir)\n    model_dir: Path\n    if api:\n        model_dict = _get_all_ngc_models(item.get(Keys.NAME, f\"{item}\") if isinstance(item, Mapping) else f\"{item}\")\n        if len(model_dict) == 0:\n            raise ValueError(f\"api query returns no item for pattern {item}.  Please change or shorten it.\")\n        model_dir_list: list[Path] = []\n        for k, v in model_dict.items():\n            ver = v[\"latest\"] if version == -1 else str(version)\n            download_url = _get_ngc_url(k, ver)\n            model_dir = _mmar_dir / v[\"name\"]\n            download_and_extract(\n                url=download_url,\n                filepath=_mmar_dir / f'{v[\"name\"]}_{ver}.zip',\n                output_dir=model_dir,\n                hash_val=None,\n                hash_type=\"md5\",\n                file_type=\"zip\",\n                has_base=False,\n                progress=progress,\n            )\n            model_dir_list.append(model_dir)\n        if not model_dir_list:\n            raise ValueError(f\"api query download no item for pattern {item}.  
Please change or shorten it.\")\n        return model_dir_list[0]\n\n    if not isinstance(item, Mapping):\n        item = get_model_spec(item)\n    ver = item.get(Keys.VERSION, 1)\n    if version > 0:\n        ver = str(version)\n    model_fullname = f\"{item[Keys.NAME]}_{ver}\"\n    model_dir = _mmar_dir / model_fullname\n    model_url = item.get(Keys.URL) or _get_ngc_url(item[Keys.NAME], version=ver, model_prefix=\"nvidia/med/\")\n    download_and_extract(\n        url=model_url,\n        filepath=_mmar_dir / f\"{model_fullname}.{item[Keys.FILE_TYPE]}\",\n        output_dir=model_dir,\n        hash_val=item[Keys.HASH_VAL],\n        hash_type=item[Keys.HASH_TYPE],\n        file_type=item[Keys.FILE_TYPE],\n        has_base=False,\n        progress=progress,\n    )\n    return model_dir\n\n\ndef load_from_mmar(\n    item: Mapping | str | int,\n    mmar_dir: PathLike | None = None,\n    progress: bool = True,\n    version: int = -1,\n    map_location: Any | None = None,\n    pretrained: bool = True,\n    weights_only: bool = False,\n    model_key: str = \"model\",\n    api: bool = True,\n    model_file: PathLike | None = None,\n) -> Any:\n    \"\"\"\n    Download and extract Medical Model Archive (MMAR) model weights from Nvidia Clara Train.\n\n    Args:\n        item: the corresponding model item from `MODEL_DESC`.\n        mmar_dir: : target directory to store the MMAR, default is mmars subfolder under `torch.hub get_dir()`.\n        progress: whether to display a progress bar when downloading the content.\n        version: version number of the MMAR. 
Set it to `-1` to use `item[Keys.VERSION]`.\n        map_location: pytorch API parameter for `torch.load` or `torch.jit.load`.\n        pretrained: whether to load the pretrained weights after initializing a network module.\n        weights_only: whether to load only the weights instead of initializing the network module and assign weights.\n        model_key: a key to search in the model file or config file for the model dictionary.\n            Currently this function assumes that the model dictionary has\n            `{\"[name|path]\": \"test.module\", \"args\": {'kw': 'test'}}`.\n        api: whether to query NGC API to get model information.\n        model_file: the relative path to the model file within an MMAR.\n\n    Examples::\n        >>> from monai.apps import load_from_mmar\n        >>> unet_model = load_from_mmar(\"clara_pt_prostate_mri_segmentation_1\", mmar_dir=\".\", map_location=\"cpu\")\n        >>> print(unet_model)\n\n    See Also:\n        https://docs.nvidia.com/clara/\n    \"\"\"\n    if api:\n        item = {Keys.NAME: get_model_spec(item)[Keys.NAME] if isinstance(item, int) else f\"{item}\"}\n    if not isinstance(item, Mapping):\n        item = get_model_spec(item)\n    model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version, api=api)\n    if model_file is None:\n        model_file = os.path.join(\"models\", \"model.pt\")\n    _model_file = model_dir / item.get(Keys.MODEL_FILE, model_file)\n    logger.info(f'\\n*** \"{item.get(Keys.NAME)}\" available at {model_dir}.')\n\n    # loading with `torch.jit.load`\n    if _model_file.name.endswith(\".ts\"):\n        if not pretrained:\n            warnings.warn(\"Loading a ScriptModule, 'pretrained' option ignored.\")\n        if weights_only:\n            warnings.warn(\"Loading a ScriptModule, 'weights_only' option ignored.\")\n        return torch.jit.load(_model_file, map_location=map_location)\n\n    # loading with `torch.load`\n    model_dict = 
torch.load(_model_file, map_location=map_location, weights_only=True)\n    if weights_only:\n        return model_dict.get(model_key, model_dict)  # model_dict[model_key] or model_dict directly\n\n    # 1. search `model_dict['train_config]` for model config spec.\n    model_config = _get_val(dict(model_dict).get(\"train_conf\", {}), key=model_key, default={})\n    if not model_config or not isinstance(model_config, Mapping):\n        # 2. search json CONFIG_FILE for model config spec.\n        json_path = model_dir / item.get(Keys.CONFIG_FILE, os.path.join(\"config\", \"config_train.json\"))\n        with open(json_path) as f:\n            conf_dict = json.load(f)\n        conf_dict = dict(conf_dict)\n        model_config = _get_val(conf_dict, key=model_key, default={})\n    if not model_config:\n        # 3. search `model_dict` for model config spec.\n        model_config = _get_val(dict(model_dict), key=model_key, default={})\n\n    if not (model_config and isinstance(model_config, Mapping)):\n        raise ValueError(\n            f\"Could not load model config dictionary from config: {item.get(Keys.CONFIG_FILE)}, \"\n            f\"or from model file: {item.get(Keys.MODEL_FILE)}.\"\n        )\n\n    # parse `model_config` for model class and model parameters\n    if model_config.get(\"name\"):  # model config section is a \"name\"\n        model_name = model_config[\"name\"]\n        model_cls = monai_nets.__dict__[model_name]\n    elif model_config.get(\"path\"):  # model config section is a \"path\"\n        # https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html\n        model_module, model_name = model_config.get(\"path\", \".\").rsplit(\".\", 1)\n        model_cls, has_cls = optional_import(module=model_module, name=model_name)\n        if not has_cls:\n            raise ValueError(\n                f\"Could not load MMAR model config {model_config.get('path', '')}, \"\n                f\"Please make sure MMAR's sub-folders in '{model_dir}' is on the 
PYTHONPATH.\"\n                \"See also: https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html\"\n            )\n    else:\n        raise ValueError(f\"Could not load model config {model_config}.\")\n\n    logger.info(f\"*** Model: {model_cls}\")\n    model_kwargs = model_config.get(\"args\", None)\n    if model_kwargs:\n        model_inst = model_cls(**model_kwargs)\n        logger.info(f\"*** Model params: {model_kwargs}\")\n    else:\n        model_inst = model_cls()\n    if pretrained:\n        _, changed, unchanged = copy_model_state(model_inst, model_dict.get(model_key, model_dict), inplace=True)\n        if not (changed and not unchanged):  # not all model_inst variables are changed\n            logger.warning(f\"*** Loading model state -- unchanged: {len(unchanged)}, changed: {len(changed)}.\")\n    logger.info(\"\\n---\")\n    doc_url = item.get(Keys.DOC) or _get_ngc_doc_url(item[Keys.NAME], model_prefix=\"nvidia:med:\")\n    logger.info(f\"For more information, please visit {doc_url}\\n\")\n    return model_inst\n\n\ndef _get_val(input_dict: Mapping, key: str = \"model\", default: Any | None = None) -> Any | None:\n    \"\"\"\n    Search for the item with `key` in `config_dict`.\n    Returns: the first occurrence of `key` in a breadth first search.\n    \"\"\"\n    if key in input_dict:\n        return input_dict[key]\n    for sub_dict in input_dict:\n        val = input_dict[sub_dict]\n        if isinstance(val, Mapping):\n            found_val = _get_val(val, key=key, default=None)\n            if found_val is not None:\n                return found_val\n    return default\n"
  },
  {
    "path": "monai/apps/mmars/model_desc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nCollection of the remote MMAR descriptors\n\nSee Also:\n    - https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nfrom typing import Any\n\n__all__ = [\"MODEL_DESC\", \"RemoteMMARKeys\"]\n\n\nclass RemoteMMARKeys:\n    \"\"\"\n    Data keys used for loading MMAR.\n    ID must uniquely define an MMAR.\n    \"\"\"\n\n    ID = \"id\"  # unique MMAR\n    NAME = \"name\"  # MMAR name for readability\n    URL = \"url\"  # remote location of the MMAR, see also: `monai.apps.mmars.mmars._get_ngc_url`\n    DOC = \"doc\"  # documentation page of the remote model, see also: `monai.apps.mmars.mmars._get_ngc_doc_url`\n    FILE_TYPE = \"file_type\"  # type of the compressed MMAR\n    HASH_TYPE = \"hash_type\"  # hashing method for the compressed MMAR\n    HASH_VAL = \"hash_val\"  # hashing value for the compressed MMAR\n    MODEL_FILE = \"model_file\"  # within an MMAR folder, the relative path to the model file\n    CONFIG_FILE = \"config_file\"  # within an MMAR folder, the relative path to the config file (for model config)\n    VERSION = \"version\"  # version of the MMAR\n\n\nMODEL_DESC: tuple[dict[Any, Any], ...] 
= (\n    {\n        RemoteMMARKeys.ID: \"clara_pt_spleen_ct_segmentation_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_spleen_ct_segmentation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_prostate_mri_segmentation_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_prostate_mri_segmentation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_covid19_ct_lesion_segmentation_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_covid19_ct_lesion_segmentation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_covid19_3d_ct_classification_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_covid19_3d_ct_classification\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_covid19_ct_lung_annotation_1\",\n        RemoteMMARKeys.NAME: 
\"clara_pt_covid19_ct_lung_annotation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_fed_learning_brain_tumor_mri_segmentation_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_fed_learning_brain_tumor_mri_segmentation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"server\", \"best_FL_global_model.pt\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_pathology_metastasis_detection_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_pathology_metastasis_detection\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_brain_mri_segmentation_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_brain_mri_segmentation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_brain_mri_segmentation_t1c_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_brain_mri_segmentation_t1c\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        
RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_liver_and_tumor_ct_segmentation_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_liver_and_tumor_ct_segmentation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_pancreas_and_tumor_ct_segmentation_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_pancreas_and_tumor_ct_segmentation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_brain_mri_annotation_t1c_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_brain_mri_annotation_t1c\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_spleen_ct_annotation_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_spleen_ct_annotation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_deepgrow_3d_annotation_1\",\n        RemoteMMARKeys.NAME: 
\"clara_pt_deepgrow_3d_annotation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_deepgrow_2d_annotation_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_deepgrow_2d_annotation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_covid19_ct_lung_segmentation_1\",\n        RemoteMMARKeys.NAME: \"clara_pt_covid19_ct_lung_segmentation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_unetr_ct_btcv_segmentation\",\n        RemoteMMARKeys.NAME: \"clara_pt_unetr_ct_btcv_segmentation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 4.1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_chest_xray_classification\",\n        RemoteMMARKeys.NAME: \"clara_pt_chest_xray_classification\",\n        RemoteMMARKeys.FILE_TYPE: 
\"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models\", \"model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 4.1,\n    },\n    {\n        RemoteMMARKeys.ID: \"clara_pt_self_supervised_learning_segmentation\",\n        RemoteMMARKeys.NAME: \"clara_pt_self_supervised_learning_segmentation\",\n        RemoteMMARKeys.FILE_TYPE: \"zip\",\n        RemoteMMARKeys.HASH_TYPE: \"md5\",\n        RemoteMMARKeys.HASH_VAL: None,\n        RemoteMMARKeys.MODEL_FILE: os.path.join(\"models_2gpu\", \"best_metric_model.pt\"),\n        RemoteMMARKeys.CONFIG_FILE: os.path.join(\"config\", \"config_train.json\"),\n        RemoteMMARKeys.VERSION: 4.1,\n    },\n)\n"
  },
  {
    "path": "monai/apps/nnunet/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .nnunet_bundle import (\n    ModelnnUNetWrapper,\n    convert_monai_bundle_to_nnunet,\n    convert_nnunet_to_monai_bundle,\n    get_network_from_nnunet_plans,\n    get_nnunet_monai_predictor,\n    get_nnunet_trainer,\n)\nfrom .nnunetv2_runner import nnUNetV2Runner\nfrom .utils import NNUNETMode, analyze_data, create_new_data_copy, create_new_dataset_json\n"
  },
  {
    "path": "monai/apps/nnunet/__main__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom monai.apps.nnunet.nnunetv2_runner import nnUNetV2Runner\n\nif __name__ == \"__main__\":\n    from monai.utils import optional_import\n\n    fire, _ = optional_import(\"fire\")\n    fire.Fire({\"nnUNetV2Runner\": nnUNetV2Runner})\n"
  },
  {
    "path": "monai/apps/nnunet/nnunet_bundle.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nimport os\nimport shutil\nfrom pathlib import Path\nfrom typing import Any\n\nimport numpy as np\nimport torch\nfrom torch.backends import cudnn\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.utils import optional_import\n\njoin, _ = optional_import(\"batchgenerators.utilities.file_and_folder_operations\", name=\"join\")\nload_json, _ = optional_import(\"batchgenerators.utilities.file_and_folder_operations\", name=\"load_json\")\n\n__all__ = [\n    \"get_nnunet_trainer\",\n    \"get_nnunet_monai_predictor\",\n    \"get_network_from_nnunet_plans\",\n    \"convert_nnunet_to_monai_bundle\",\n    \"convert_monai_bundle_to_nnunet\",\n    \"ModelnnUNetWrapper\",\n]\n\n\ndef get_nnunet_trainer(\n    dataset_name_or_id: str | int,\n    configuration: str,\n    fold: int | str,\n    trainer_class_name: str = \"nnUNetTrainer\",\n    plans_identifier: str = \"nnUNetPlans\",\n    use_compressed_data: bool = False,\n    continue_training: bool = False,\n    only_run_validation: bool = False,\n    disable_checkpointing: bool = False,\n    device: str = \"cuda\",\n    pretrained_model: str | None = None,\n) -> Any:  # type: ignore\n    \"\"\"\n    Get the nnUNet trainer instance based on the provided configuration.\n    The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network,\n    
optimizer, loss function, DataLoader, etc.\n\n    Example::\n\n        from monai.engines import SupervisedTrainer\n        from monai.apps.nnunet import get_nnunet_trainer\n\n        dataset_name_or_id = 'Task009_Spleen'\n        fold = 0\n        configuration = '3d_fullres'\n        nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)\n\n        trainer = SupervisedTrainer(\n            device=nnunet_trainer.device,\n            max_epochs=nnunet_trainer.num_epochs,\n            train_data_loader=nnunet_trainer.dataloader_train,\n            network=nnunet_trainer.network,\n            optimizer=nnunet_trainer.optimizer,\n            loss_function=nnunet_trainer.loss_function,\n            epoch_length=nnunet_trainer.num_iterations_per_epoch,\n        )\n\n    Parameters\n    ----------\n    dataset_name_or_id : Union[str, int]\n        The name or ID of the dataset to be used.\n    configuration : str\n        The configuration name for the training.\n    fold : Union[int, str]\n        The fold number or 'all' for cross-validation.\n    trainer_class_name : str, optional\n        The class name of the trainer to be used. Default is 'nnUNetTrainer'.\n        For a complete list of supported trainers, check:\n        https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants\n    plans_identifier : str, optional\n        Identifier for the plans to be used. Default is 'nnUNetPlans'.\n    use_compressed_data : bool, optional\n        Whether to use compressed data. Default is False.\n    continue_training : bool, optional\n        Whether to continue training from a checkpoint. Default is False.\n    only_run_validation : bool, optional\n        Whether to only run validation. Default is False.\n    disable_checkpointing : bool, optional\n        Whether to disable checkpointing. Default is False.\n    device : str, optional\n        The device to be used for training. 
Default is 'cuda'.\n    pretrained_model : Optional[str], optional\n        Path to the pretrained model file.\n\n    Returns\n    -------\n    nnunet_trainer : object\n        The nnUNet trainer instance.\n    \"\"\"\n    # From nnUNet/nnunetv2/run/run_training.py#run_training\n    if isinstance(fold, str):\n        if fold != \"all\":\n            try:\n                fold = int(fold)\n            except ValueError as e:\n                print(\n                    f'Unable to convert given value for fold to int: {fold}. fold must bei either \"all\" or an integer!'\n                )\n                raise e\n\n    from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint\n\n    nnunet_trainer = get_trainer_from_args(\n        str(dataset_name_or_id), configuration, fold, trainer_class_name, plans_identifier, device=torch.device(device)\n    )\n    if disable_checkpointing:\n        nnunet_trainer.disable_checkpointing = disable_checkpointing\n\n    assert not (continue_training and only_run_validation), \"Cannot set --c and --val flag at the same time. 
Dummy.\"\n\n    maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation)\n    nnunet_trainer.on_train_start()  # Added to Initialize Trainer\n    if torch.cuda.is_available():\n        cudnn.deterministic = False\n        cudnn.benchmark = True\n\n    if pretrained_model is not None:\n        state_dict = torch.load(pretrained_model, weights_only=True)\n        if \"network_weights\" in state_dict:\n            nnunet_trainer.network._orig_mod.load_state_dict(state_dict[\"network_weights\"])\n    return nnunet_trainer\n\n\nclass ModelnnUNetWrapper(torch.nn.Module):\n    \"\"\"\n    A wrapper class for nnUNet model integration with MONAI framework.\n    The wrapper can be used to integrate the nnUNet Bundle within MONAI framework for inference.\n\n    Parameters\n    ----------\n    predictor : nnUNetPredictor\n        The nnUNet predictor object used for inference.\n    model_folder : Union[str, Path]\n        The folder path where the model and related files are stored.\n    model_name : str, optional\n        The name of the model file, by default \"model.pt\".\n\n    Attributes\n    ----------\n    predictor : nnUNetPredictor\n        The nnUNet predictor object used for inference.\n    network_weights : torch.nn.Module\n        The network weights of the model.\n\n    Notes\n    -----\n    This class integrates nnUNet model with MONAI framework by loading necessary configurations,\n    restoring network architecture, and setting up the predictor for inference.\n    \"\"\"\n\n    def __init__(self, predictor: object, model_folder: str | Path, model_name: str = \"model.pt\"):  # type: ignore\n        super().__init__()\n        self.predictor = predictor\n\n        model_training_output_dir = model_folder\n\n        from nnunetv2.utilities.plans_handling.plans_handler import PlansManager\n\n        # Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor\n        dataset_json = 
load_json(join(Path(model_training_output_dir).parent, \"dataset.json\"))\n        plans = load_json(join(Path(model_training_output_dir).parent, \"plans.json\"))\n        plans_manager = PlansManager(plans)\n\n        parameters = []\n\n        checkpoint = torch.load(\n            join(Path(model_training_output_dir).parent, \"nnunet_checkpoint.pth\"),\n            map_location=torch.device(\"cpu\"),\n            weights_only=True,\n        )\n        trainer_name = checkpoint[\"trainer_name\"]\n        configuration_name = checkpoint[\"init_args\"][\"configuration\"]\n        inference_allowed_mirroring_axes = (\n            checkpoint[\"inference_allowed_mirroring_axes\"]\n            if \"inference_allowed_mirroring_axes\" in checkpoint.keys()\n            else None\n        )\n        if Path(model_training_output_dir).joinpath(model_name).is_file():\n            monai_checkpoint = torch.load(\n                join(model_training_output_dir, model_name), map_location=torch.device(\"cpu\"), weights_only=True\n            )\n            if \"network_weights\" in monai_checkpoint.keys():\n                parameters.append(monai_checkpoint[\"network_weights\"])\n            else:\n                parameters.append(monai_checkpoint)\n\n        configuration_manager = plans_manager.get_configuration(configuration_name)\n        import nnunetv2\n        from nnunetv2.utilities.find_class_by_name import recursive_find_python_class\n        from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels\n\n        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)\n        trainer_class = recursive_find_python_class(\n            join(nnunetv2.__path__[0], \"training\", \"nnUNetTrainer\"), trainer_name, \"nnunetv2.training.nnUNetTrainer\"\n        )\n        if trainer_class is None:\n            raise RuntimeError(\n                f\"Unable to locate trainer class {trainer_name} in 
nnunetv2.training.nnUNetTrainer. \"\n                f\"Please place it there (in any .py file)!\"\n            )\n        network = trainer_class.build_network_architecture(\n            configuration_manager.network_arch_class_name,\n            configuration_manager.network_arch_init_kwargs,\n            configuration_manager.network_arch_init_kwargs_req_import,\n            num_input_channels,\n            plans_manager.get_label_manager(dataset_json).num_segmentation_heads,\n            enable_deep_supervision=False,\n        )\n\n        predictor.plans_manager = plans_manager  # type: ignore\n        predictor.configuration_manager = configuration_manager  # type: ignore\n        predictor.list_of_parameters = parameters  # type: ignore\n        predictor.network = network  # type: ignore\n        predictor.dataset_json = dataset_json  # type: ignore\n        predictor.trainer_name = trainer_name  # type: ignore\n        predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes  # type: ignore\n        predictor.label_manager = plans_manager.get_label_manager(dataset_json)  # type: ignore\n\n        self.network_weights = self.predictor.network  # type: ignore\n\n    def forward(self, x: MetaTensor) -> MetaTensor:\n        \"\"\"\n        Forward pass for the nnUNet model.\n\n        Args:\n            x (MetaTensor): Input tensor. If the input is a tuple,\n                it is assumed to be a decollated batch (list of tensors). 
Otherwise, it is assumed to be a collated batch.\n\n        Returns:\n            MetaTensor: The output tensor with the same metadata as the input.\n\n        Raises:\n            TypeError: If the input is not a torch.Tensor or a tuple of MetaTensors.\n\n        Notes:\n            - If the input is a tuple, the filenames are extracted from the metadata of each tensor in the tuple.\n            - If the input is a collated batch, the filenames are extracted from the metadata of the input tensor.\n            - The filenames are used to generate predictions using the nnUNet predictor.\n            - The predictions are converted to torch tensors, with added batch and channel dimensions.\n            - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata.\n        \"\"\"\n        if isinstance(x, MetaTensor):\n            if \"pixdim\" in x.meta:\n                properties_or_list_of_properties = {\"spacing\": x.meta[\"pixdim\"][0][1:4].numpy().tolist()}\n            elif \"affine\" in x.meta:\n                spacing = [\n                    abs(x.meta[\"affine\"][0][0].item()),\n                    abs(x.meta[\"affine\"][1][1].item()),\n                    abs(x.meta[\"affine\"][2][2].item()),\n                ]\n                properties_or_list_of_properties = {\"spacing\": spacing}\n            else:\n                properties_or_list_of_properties = {\"spacing\": [1.0, 1.0, 1.0]}\n        else:\n            raise TypeError(\"Input must be a MetaTensor or a tuple of MetaTensors.\")\n\n        image_or_list_of_images = x.cpu().numpy()[0, :]\n\n        # input_files should be a list of file paths, one per modality\n        prediction_output = self.predictor.predict_from_list_of_npy_arrays(  # type: ignore\n            image_or_list_of_images,\n            None,\n            properties_or_list_of_properties,\n            truncated_ofname=None,\n            save_probabilities=False,\n            
num_processes=2,\n            num_processes_segmentation_export=2,\n        )\n        # prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax\n\n        out_tensors = []\n        for out in prediction_output:  # Add batch and channel dimensions\n            out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)))\n        out_tensor = torch.cat(out_tensors, 0)  # Concatenate along batch dimension\n\n        return MetaTensor(out_tensor, meta=x.meta)\n\n\ndef get_nnunet_monai_predictor(model_folder: str | Path, model_name: str = \"model.pt\") -> ModelnnUNetWrapper:\n    \"\"\"\n    Initializes and returns a `ModelnnUNetWrapper` containing the corresponding `nnUNetPredictor`.\n    The following files, created during training, are required. `dataset.json`, `plans.json` and\n    `nnunet_checkpoint.pth` are read from the parent directory of `model_folder`; the model weights\n    file is read from `model_folder` itself:\n\n        - dataset.json: from the nnUNet results folder\n        - plans.json: from the nnUNet results folder\n        - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration\n        - model.pt: The checkpoint file containing the model weights.\n\n    The returned wrapper object can be used for inference with MONAI framework:\n\n    Example::\n\n        from monai.bundle.nnunet import get_nnunet_monai_predictor\n\n        model_folder = 'path/to/monai_bundle/model'\n        model_name = 'model.pt'\n        wrapper = get_nnunet_monai_predictor(model_folder, model_name)\n\n        # Perform inference\n        input_data = ...\n        output = wrapper(input_data)\n\n\n    Parameters\n    ----------\n    model_folder : Union[str, Path]\n        The folder where the model is stored.\n    model_name : str, optional\n        The name of the model file, by default \"model.pt\".\n\n    Returns\n    -------\n    ModelnnUNetWrapper\n        A wrapper object that contains the nnUNetPredictor and the loaded model.\n    \"\"\"\n\n    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\n\n    predictor 
= nnUNetPredictor(\n        tile_step_size=0.5,\n        use_gaussian=True,\n        use_mirroring=False,\n        device=torch.device(\"cuda\", 0),\n        verbose=False,\n        verbose_preprocessing=False,\n        allow_tqdm=True,\n    )\n    # initializes the network architecture, loads the checkpoint\n    wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name)\n    return wrapper\n\n\ndef convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str, fold: int = 0) -> None:\n    \"\"\"\n    Convert nnUNet model checkpoints and configuration to MONAI bundle format.\n\n    Parameters\n    ----------\n    nnunet_config : dict\n        Configuration dictionary for nnUNet, containing keys such as 'dataset_name_or_id', 'nnunet_configuration',\n        'nnunet_trainer', and 'nnunet_plans'.\n    bundle_root_folder : str\n        Root folder where the MONAI bundle will be saved.\n    fold : int, optional\n        Fold number of the nnUNet model to be converted, by default 0.\n\n    Returns\n    -------\n    None\n    \"\"\"\n\n    nnunet_trainer = \"nnUNetTrainer\"\n    nnunet_plans = \"nnUNetPlans\"\n    nnunet_configuration = \"3d_fullres\"\n\n    if \"nnunet_trainer\" in nnunet_config:\n        nnunet_trainer = nnunet_config[\"nnunet_trainer\"]\n\n    if \"nnunet_plans\" in nnunet_config:\n        nnunet_plans = nnunet_config[\"nnunet_plans\"]\n\n    if \"nnunet_configuration\" in nnunet_config:\n        nnunet_configuration = nnunet_config[\"nnunet_configuration\"]\n\n    from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\n\n    dataset_name = maybe_convert_to_dataset_name(nnunet_config[\"dataset_name_or_id\"])\n    nnunet_model_folder = Path(os.environ[\"nnUNet_results\"]).joinpath(\n        dataset_name, f\"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}\"\n    )\n\n    nnunet_checkpoint_final = torch.load(\n        Path(nnunet_model_folder).joinpath(f\"fold_{fold}\", 
\"checkpoint_final.pth\"), weights_only=True\n    )\n    nnunet_checkpoint_best = torch.load(\n        Path(nnunet_model_folder).joinpath(f\"fold_{fold}\", \"checkpoint_best.pth\"), weights_only=True\n    )\n\n    nnunet_checkpoint = {}\n    nnunet_checkpoint[\"inference_allowed_mirroring_axes\"] = nnunet_checkpoint_final[\"inference_allowed_mirroring_axes\"]\n    nnunet_checkpoint[\"init_args\"] = nnunet_checkpoint_final[\"init_args\"]\n    nnunet_checkpoint[\"trainer_name\"] = nnunet_checkpoint_final[\"trainer_name\"]\n\n    torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath(\"models\", \"nnunet_checkpoint.pth\"))\n\n    Path(bundle_root_folder).joinpath(\"models\", f\"fold_{fold}\").mkdir(parents=True, exist_ok=True)\n    monai_last_checkpoint = {}\n    monai_last_checkpoint[\"network_weights\"] = nnunet_checkpoint_final[\"network_weights\"]\n    torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath(\"models\", f\"fold_{fold}\", \"model.pt\"))\n\n    monai_best_checkpoint = {}\n    monai_best_checkpoint[\"network_weights\"] = nnunet_checkpoint_best[\"network_weights\"]\n    torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath(\"models\", f\"fold_{fold}\", \"best_model.pt\"))\n\n    if not os.path.exists(os.path.join(bundle_root_folder, \"models\", \"plans.json\")):\n        shutil.copy(\n            Path(nnunet_model_folder).joinpath(\"plans.json\"), Path(bundle_root_folder).joinpath(\"models\", \"plans.json\")\n        )\n\n    if not os.path.exists(os.path.join(bundle_root_folder, \"models\", \"dataset.json\")):\n        shutil.copy(\n            Path(nnunet_model_folder).joinpath(\"dataset.json\"),\n            Path(bundle_root_folder).joinpath(\"models\", \"dataset.json\"),\n        )\n\n\ndef get_network_from_nnunet_plans(\n    plans_file: str,\n    dataset_file: str,\n    configuration: str,\n    model_ckpt: str | None = None,\n    model_key_in_ckpt: str = \"model\",\n) -> torch.nn.Module | Any:\n    \"\"\"\n    
Load and initialize a nnUNet network based on nnUNet plans and configuration.\n\n    Parameters\n    ----------\n    plans_file : str\n        Path to the JSON file containing the nnUNet plans.\n    dataset_file : str\n        Path to the JSON file containing the dataset information.\n    configuration : str\n        The configuration name to be used from the plans.\n    model_ckpt : Optional[str], optional\n        Path to the model checkpoint file. If None, the network is returned without loading weights (default is None).\n    model_key_in_ckpt : str, optional\n        The key in the checkpoint file that contains the model state dictionary (default is \"model\").\n\n    Returns\n    -------\n    network : torch.nn.Module\n        The initialized neural network, with weights loaded if `model_ckpt` is provided.\n    \"\"\"\n    from batchgenerators.utilities.file_and_folder_operations import load_json\n    from nnunetv2.utilities.get_network_from_plans import get_network_from_plans\n    from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels\n    from nnunetv2.utilities.plans_handling.plans_handler import PlansManager\n\n    plans = load_json(plans_file)\n    dataset_json = load_json(dataset_file)\n\n    plans_manager = PlansManager(plans)\n    configuration_manager = plans_manager.get_configuration(configuration)\n    num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)\n    label_manager = plans_manager.get_label_manager(dataset_json)\n\n    enable_deep_supervision = True\n\n    network = get_network_from_plans(\n        configuration_manager.network_arch_class_name,\n        configuration_manager.network_arch_init_kwargs,\n        configuration_manager.network_arch_init_kwargs_req_import,\n        num_input_channels,\n        label_manager.num_segmentation_heads,\n        allow_init=True,\n        deep_supervision=enable_deep_supervision,\n    )\n\n    if model_ckpt is None:\n    
    return network\n    else:\n        state_dict = torch.load(model_ckpt, weights_only=True)\n        network.load_state_dict(state_dict[model_key_in_ckpt])\n        return network\n\n\ndef convert_monai_bundle_to_nnunet(nnunet_config: dict, bundle_root_folder: str, fold: int = 0) -> None:\n    \"\"\"\n    Convert a MONAI bundle to nnU-Net format.\n\n    Parameters\n    ----------\n    nnunet_config : dict\n        Configuration dictionary for nnU-Net. Expected keys are:\n        - \"dataset_name_or_id\": str, name or ID of the dataset.\n        - \"nnunet_trainer\": str, optional, name of the nnU-Net trainer (default is \"nnUNetTrainer\").\n        - \"nnunet_plans\": str, optional, name of the nnU-Net plans (default is \"nnUNetPlans\").\n    bundle_root_folder : str\n        Path to the root folder of the MONAI bundle.\n    fold : int, optional\n        Fold number for cross-validation (default is 0).\n\n    Returns\n    -------\n    None\n    \"\"\"\n    from odict import odict\n\n    nnunet_trainer: str = \"nnUNetTrainer\"\n    nnunet_plans: str = \"nnUNetPlans\"\n\n    if \"nnunet_trainer\" in nnunet_config:\n        nnunet_trainer = nnunet_config[\"nnunet_trainer\"]\n\n    if \"nnunet_plans\" in nnunet_config:\n        nnunet_plans = nnunet_config[\"nnunet_plans\"]\n\n    from nnunetv2.training.logging.nnunet_logger import nnUNetLogger\n    from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\n\n    def subfiles(\n        folder: str | Path, prefix: str | None = None, suffix: str | None = None, sort: bool = True\n    ) -> list[str]:\n        res = [\n            i.name\n            for i in Path(folder).iterdir()\n            if i.is_file()\n            and (prefix is None or i.name.startswith(prefix))\n            and (suffix is None or i.name.endswith(suffix))\n        ]\n        if sort:\n            res.sort()\n        return res\n\n    nnunet_model_folder: Path = Path(os.environ[\"nnUNet_results\"]).joinpath(\n        
maybe_convert_to_dataset_name(nnunet_config[\"dataset_name_or_id\"]),\n        f\"{nnunet_trainer}__{nnunet_plans}__3d_fullres\",\n    )\n\n    nnunet_preprocess_model_folder: Path = Path(os.environ[\"nnUNet_preprocessed\"]).joinpath(\n        maybe_convert_to_dataset_name(nnunet_config[\"dataset_name_or_id\"])\n    )\n\n    Path(nnunet_model_folder).joinpath(f\"fold_{fold}\").mkdir(parents=True, exist_ok=True)\n\n    nnunet_checkpoint: dict = torch.load(f\"{bundle_root_folder}/models/nnunet_checkpoint.pth\", weights_only=True)\n    latest_checkpoints: list[str] = subfiles(\n        Path(bundle_root_folder).joinpath(\"models\", f\"fold_{fold}\"), prefix=\"checkpoint_epoch\", sort=True\n    )\n    epochs: list[int] = []\n    for latest_checkpoint in latest_checkpoints:\n        epochs.append(int(latest_checkpoint[len(\"checkpoint_epoch=\") : -len(\".pt\")]))\n\n    epochs.sort()\n    final_epoch: int = epochs[-1]\n    monai_last_checkpoint: dict = torch.load(\n        f\"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt\", weights_only=True\n    )\n\n    best_checkpoints: list[str] = subfiles(\n        Path(bundle_root_folder).joinpath(\"models\", f\"fold_{fold}\"), prefix=\"checkpoint_key_metric\", sort=True\n    )\n    key_metrics: list[str] = []\n    for best_checkpoint in best_checkpoints:\n        key_metrics.append(str(best_checkpoint[len(\"checkpoint_key_metric=\") : -len(\".pt\")]))\n\n    key_metrics.sort()\n    best_key_metric: str = key_metrics[-1]\n    monai_best_checkpoint: dict = torch.load(\n        f\"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt\", weights_only=True\n    )\n\n    nnunet_checkpoint[\"optimizer_state\"] = monai_last_checkpoint[\"optimizer_state\"]\n\n    nnunet_checkpoint[\"network_weights\"] = odict()\n\n    for key in monai_last_checkpoint[\"network_weights\"]:\n        nnunet_checkpoint[\"network_weights\"][key] = monai_last_checkpoint[\"network_weights\"][key]\n\n    
nnunet_checkpoint[\"current_epoch\"] = final_epoch\n    nnunet_checkpoint[\"logging\"] = nnUNetLogger().get_checkpoint()\n    nnunet_checkpoint[\"_best_ema\"] = 0\n    nnunet_checkpoint[\"grad_scaler_state\"] = None\n\n    torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(f\"fold_{fold}\", \"checkpoint_final.pth\"))\n\n    nnunet_checkpoint[\"network_weights\"] = odict()\n\n    nnunet_checkpoint[\"optimizer_state\"] = monai_best_checkpoint[\"optimizer_state\"]\n\n    for key in monai_best_checkpoint[\"network_weights\"]:\n        nnunet_checkpoint[\"network_weights\"][key] = monai_best_checkpoint[\"network_weights\"][key]\n\n    torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(f\"fold_{fold}\", \"checkpoint_best.pth\"))\n\n    if not os.path.exists(os.path.join(nnunet_model_folder, \"dataset.json\")):\n        shutil.copy(f\"{bundle_root_folder}/models/dataset.json\", nnunet_model_folder)\n    if not os.path.exists(os.path.join(nnunet_model_folder, \"plans.json\")):\n        shutil.copy(f\"{bundle_root_folder}/models/plans.json\", nnunet_model_folder)\n    if not os.path.exists(os.path.join(nnunet_model_folder, \"dataset_fingerprint.json\")):\n        shutil.copy(f\"{nnunet_preprocess_model_folder}/dataset_fingerprint.json\", nnunet_model_folder)\n    if not os.path.exists(os.path.join(nnunet_model_folder, \"nnunet_checkpoint.pth\")):\n        shutil.copy(f\"{bundle_root_folder}/models/nnunet_checkpoint.pth\", nnunet_model_folder)\n"
  },
  {
    "path": "monai/apps/nnunet/nnunetv2_runner.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: disable=import-error\nfrom __future__ import annotations\n\nimport glob\nimport os\nimport shlex\nimport subprocess\nfrom typing import Any\n\nimport monai\nfrom monai.apps.nnunet.utils import NNUNETMode as M\nfrom monai.apps.nnunet.utils import analyze_data, create_new_data_copy, create_new_dataset_json\nfrom monai.bundle import ConfigParser\nfrom monai.utils import ensure_tuple, optional_import\nfrom monai.utils.misc import run_cmd\n\nload_pickle, _ = optional_import(\"batchgenerators.utilities.file_and_folder_operations\", name=\"load_pickle\")\njoin, _ = optional_import(\"batchgenerators.utilities.file_and_folder_operations\", name=\"join\")\ntqdm, has_tqdm = optional_import(\"tqdm\", name=\"tqdm\")\nnib, _ = optional_import(\"nibabel\")\n\nlogger = monai.apps.utils.get_logger(__name__)\n\n__all__ = [\"nnUNetV2Runner\"]\n\n\nclass nnUNetV2Runner:  # noqa: N801\n    \"\"\"\n    ``nnUNetV2Runner`` provides an interface in MONAI to use `nnU-Net` V2 library to analyze, train, and evaluate\n    neural networks for medical image segmentation tasks.\n    A version of nnunetv2 higher than 2.2 is needed for this class.\n\n    ``nnUNetV2Runner`` can be used in two ways:\n\n    #. with one line of code to execute the complete pipeline.\n    #. 
with a series of commands to run each module in the pipeline.\n\n    The output of the interface is a directory that contains:\n\n    #. the converted dataset that meets the requirements of nnU-Net V2\n    #. data analysis results\n    #. checkpoints from the trained U-Net models\n    #. validation accuracy in each fold of cross-validation\n    #. the predictions on the testing datasets from the final algorithm ensemble and potential post-processing\n\n    Args:\n        input_config: the configuration dictionary or the file path to the configuration in the form of YAML.\n            The keys required in the configuration are:\n            - ``\"datalist\"``: File path to the datalist for the train/testing splits\n            - ``\"dataroot\"``: File path to the dataset\n            - ``\"modality\"``: Imaging modality, e.g. \"CT\", [\"T2\", \"ADC\"]\n            Currently, the configuration supports these optional keys:\n            - ``\"nnunet_raw\"``: File path that will be written to env variable for nnU-Net\n            - ``\"nnunet_preprocessed\"``: File path that will be written to env variable for nnU-Net\n            - ``\"nnunet_results\"``: File path that will be written to env variable for nnU-Net\n            - ``\"nnUNet_trained_models\"``\n            - ``\"dataset_name_or_id\"``: Name or Integer ID of the dataset\n            If an optional key is not specified, then the pipeline will use the default values.\n        trainer_class_name: the trainer class names offered by nnUNetV2 exhibit variations in training duration.\n            Default: \"nnUNetTrainer\". Other options: \"nnUNetTrainer_Xepoch\". X could be one of 1,5,10,20,50,100,\n            250,2000,4000,8000.\n        export_validation_probabilities: True to save softmax predictions from final validation as npz\n            files (in addition to predicted segmentations). 
Needed for finding the best ensemble.\n            Default: True.\n        work_dir: working directory to save the intermediate and final results.\n\n    Examples:\n        - Use the one-liner to start the nnU-Net workflow\n\n        .. code-block:: bash\n\n            python -m monai.apps.nnunet nnUNetV2Runner run --input_config ./input.yaml\n\n        - Use `convert_dataset` to prepare the data to meet nnU-Net requirements, generate dataset JSON file,\n            and copy the dataset to a location specified by ``nnunet_raw`` in the input config file\n\n        .. code-block:: bash\n\n            python -m monai.apps.nnunet nnUNetV2Runner convert_dataset --input_config=\"./input.yaml\"\n\n        - `convert_msd_dataset` is an alternative option to prepare the data if the dataset is MSD.\n\n        .. code-block:: bash\n\n            python -m monai.apps.nnunet nnUNetV2Runner convert_msd_dataset \\\\\n                --input_config \"./input.yaml\" --data_dir \"/path/to/Task09_Spleen\"\n\n        - experiment planning and data pre-processing\n\n        .. code-block:: bash\n\n            python -m monai.apps.nnunet nnUNetV2Runner plan_and_process --input_config \"./input.yaml\"\n\n        - training all 20 models using all GPUs available.\n            \"CUDA_VISIBLE_DEVICES\" environment variable is not supported.\n\n        .. code-block:: bash\n\n            python -m monai.apps.nnunet nnUNetV2Runner train --input_config \"./input.yaml\"\n\n        - training a single model on a single GPU for 5 epochs. Here ``config`` is used to specify the configuration.\n\n        .. 
code-block:: bash\n\n            python -m monai.apps.nnunet nnUNetV2Runner train_single_model --input_config \"./input.yaml\" \\\\\n                --config \"3d_fullres\" \\\\\n                --fold 0 \\\\\n                --gpu_id 0 \\\\\n                --trainer_class_name \"nnUNetTrainer_5epochs\" \\\\\n                --export_validation_probabilities True\n\n        - training for all 20 models (4 configurations by 5 folds) on 2 GPUs\n\n        .. code-block:: bash\n\n            python -m monai.apps.nnunet nnUNetV2Runner train --input_config \"./input.yaml\" --gpu_id_for_all \"0,1\"\n\n        - 5-fold training for a single model on 2 GPUs. Here ``configs`` is used to specify the configurations.\n\n        .. code-block:: bash\n\n            python -m monai.apps.nnunet nnUNetV2Runner train --input_config \"./input.yaml\" \\\\\n                --configs \"3d_fullres\" \\\\\n                --trainer_class_name \"nnUNetTrainer_5epochs\" \\\\\n                --export_validation_probabilities True \\\\\n                --gpu_id_for_all \"0,1\"\n\n        - find the best configuration\n\n        .. code-block:: bash\n\n            python -m monai.apps.nnunet nnUNetV2Runner find_best_configuration --input_config \"./input.yaml\"\n\n        - predict, ensemble, and post-process\n\n        .. 
code-block:: bash\n\n            python -m monai.apps.nnunet nnUNetV2Runner predict_ensemble_postprocessing --input_config \"./input.yaml\"\n\n    \"\"\"\n\n    def __init__(\n        self,\n        input_config: Any,\n        trainer_class_name: str = \"nnUNetTrainer\",\n        work_dir: str = \"work_dir\",\n        export_validation_probabilities: bool = True,\n    ) -> None:\n        self.input_info: dict = {}\n        self.input_config_or_dict = input_config\n        self.trainer_class_name = trainer_class_name\n        self.export_validation_probabilities = export_validation_probabilities\n        self.work_dir = work_dir\n\n        if isinstance(self.input_config_or_dict, dict):\n            self.input_info = self.input_config_or_dict\n        elif isinstance(self.input_config_or_dict, str) and os.path.isfile(self.input_config_or_dict):\n            self.input_info = ConfigParser.load_config_file(self.input_config_or_dict)\n        else:\n            raise ValueError(f\"{input_config} is not a valid file or dict\")\n\n        self.nnunet_raw = self.input_info.pop(\"nnunet_raw\", os.path.join(\".\", self.work_dir, \"nnUNet_raw_data_base\"))\n        self.nnunet_preprocessed = self.input_info.pop(\n            \"nnunet_preprocessed\", os.path.join(\".\", self.work_dir, \"nnUNet_preprocessed\")\n        )\n        self.nnunet_results = self.input_info.pop(\n            \"nnunet_results\", os.path.join(\".\", self.work_dir, \"nnUNet_trained_models\")\n        )\n\n        if not os.path.exists(self.nnunet_raw):\n            os.makedirs(self.nnunet_raw)\n\n        if not os.path.exists(self.nnunet_preprocessed):\n            os.makedirs(self.nnunet_preprocessed)\n\n        if not os.path.exists(self.nnunet_results):\n            os.makedirs(self.nnunet_results)\n\n        # claim environment variable\n        os.environ[\"nnUNet_raw\"] = self.nnunet_raw\n        os.environ[\"nnUNet_preprocessed\"] = self.nnunet_preprocessed\n        os.environ[\"nnUNet_results\"] 
= self.nnunet_results\n        os.environ[\"OMP_NUM_THREADS\"] = str(1)\n\n        # dataset_name_or_id has to be a string\n        self.dataset_name_or_id = str(self.input_info.pop(\"dataset_name_or_id\", 1))\n\n        try:\n            from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\n\n            self.dataset_name = maybe_convert_to_dataset_name(int(self.dataset_name_or_id))\n        except BaseException:\n            logger.warning(\n                f\"Dataset with name/ID: {self.dataset_name_or_id} cannot be found in the record. \"\n                \"Please ignore the message above if you are running the pipeline from a fresh start. \"\n                \"But if the dataset is expected to be found, please check your input_config.\"\n            )\n\n        from nnunetv2.configuration import default_num_processes\n\n        self.default_num_processes = default_num_processes\n\n        self.num_folds = 5\n        self.best_configuration: dict = {}\n\n    def convert_dataset(self):\n        \"\"\"Convert and make a copy the dataset to meet the requirements of nnU-Net workflow.\"\"\"\n        try:\n            raw_data_foldername_prefix = str(int(self.dataset_name_or_id) + 1000)\n            raw_data_foldername_prefix = \"Dataset\" + raw_data_foldername_prefix[-3:]\n\n            # check if the dataset is created\n            subdirs = glob.glob(f\"{self.nnunet_raw}/*\")\n            dataset_ids = [_item.split(os.sep)[-1] for _item in subdirs]\n            dataset_ids = [_item.split(\"_\")[0] for _item in dataset_ids]\n            if raw_data_foldername_prefix in dataset_ids:\n                logger.warning(\"Dataset with the same ID exists!\")\n                return\n\n            data_dir = self.input_info.pop(\"dataroot\")\n            if data_dir[-1] == os.sep:\n                data_dir = data_dir[:-1]\n\n            raw_data_foldername = raw_data_foldername_prefix + \"_\" + data_dir.split(os.sep)[-1]\n            
raw_data_foldername = os.path.join(self.nnunet_raw, raw_data_foldername)\n            if not os.path.exists(raw_data_foldername):\n                os.makedirs(raw_data_foldername)\n\n            from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name\n\n            self.dataset_name = maybe_convert_to_dataset_name(int(self.dataset_name_or_id))\n\n            datalist_json = ConfigParser.load_config_file(self.input_info.pop(\"datalist\"))\n\n            if \"training\" in datalist_json:\n                os.makedirs(os.path.join(raw_data_foldername, \"imagesTr\"))\n                os.makedirs(os.path.join(raw_data_foldername, \"labelsTr\"))\n            else:\n                logger.error(\"The datalist file has incorrect format: the `training` key is not found.\")\n                return\n\n            test_key = None\n            if \"test\" in datalist_json or \"testing\" in datalist_json:\n                os.makedirs(os.path.join(raw_data_foldername, \"imagesTs\"))\n                test_key = \"test\" if \"test\" in datalist_json else \"testing\"\n                if isinstance(datalist_json[test_key][0], dict) and \"label\" in datalist_json[test_key][0]:\n                    os.makedirs(os.path.join(raw_data_foldername, \"labelsTs\"))\n\n            num_input_channels, num_foreground_classes = analyze_data(datalist_json=datalist_json, data_dir=data_dir)\n\n            modality = self.input_info.pop(\"modality\")\n            if not isinstance(modality, list):\n                modality = [modality]\n\n            create_new_dataset_json(\n                modality=modality,\n                num_foreground_classes=num_foreground_classes,\n                num_input_channels=num_input_channels,\n                num_training_data=len(datalist_json[\"training\"]),\n                output_filepath=os.path.join(raw_data_foldername, \"dataset.json\"),\n            )\n\n            create_new_data_copy(\n                test_key=test_key,  # 
type: ignore\n                datalist_json=datalist_json,\n                data_dir=data_dir,\n                num_input_channels=num_input_channels,\n                output_datafolder=raw_data_foldername,\n            )\n        except BaseException as err:\n            logger.warning(f\"Input config may be incorrect. Detail info: error/exception message is:\\n {err}\")\n            return\n\n    def convert_msd_dataset(self, data_dir: str, overwrite_id: str | None = None, n_proc: int = -1) -> None:\n        \"\"\"\n        Convert and make a copy of the MSD dataset to meet the requirements of the nnU-Net workflow.\n\n        Args:\n            data_dir: downloaded and extracted MSD dataset folder. CANNOT be nnUNetv1 dataset!\n                Example: \"/workspace/downloads/Task05_Prostate\".\n            overwrite_id: Overwrite the dataset id. If not set then use the id of the MSD task (inferred from\n                the folder name). Only use this if you already have an equivalently numbered dataset!\n            n_proc: Number of processes used. Note: a negative value passes None to the nnU-Net converter; any\n                non-negative value currently selects the runner's default number of processes (the value itself\n                is not forwarded).\n        \"\"\"\n        from nnunetv2.dataset_conversion.convert_MSD_dataset import convert_msd_dataset\n\n        num_processes = None if n_proc < 0 else self.default_num_processes\n        convert_msd_dataset(data_dir, overwrite_id, num_processes)\n\n    def extract_fingerprints(\n        self,\n        fpe: str = \"DatasetFingerprintExtractor\",\n        npfp: int = -1,\n        verify_dataset_integrity: bool = False,\n        clean: bool = False,\n        verbose: bool = False,\n    ) -> None:\n        \"\"\"\n        Extracts the dataset fingerprint used for experiment planning.\n\n        Args:\n            fpe: [OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. 
Default is\n                \"DatasetFingerprintExtractor\".\n            npfp: [OPTIONAL] Number of processes used for fingerprint extraction.\n            verify_dataset_integrity: [RECOMMENDED] set this flag to check the dataset integrity. This is\n                useful and should be done once for each dataset!\n            clean: [OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a\n                fingerprint already exists, the fingerprint extractor will not run.\n            verbose: set this to print a lot of stuff. Useful for debugging. Will disable progress bar!\n                Recommended for cluster environments.\n        \"\"\"\n        from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints\n\n        npfp = self.default_num_processes if npfp < 0 else npfp\n\n        logger.info(\"Fingerprint extraction...\")\n        extract_fingerprints([int(self.dataset_name_or_id)], fpe, npfp, verify_dataset_integrity, clean, verbose)\n\n    def plan_experiments(\n        self,\n        pl: str = \"ExperimentPlanner\",\n        gpu_memory_target: float = 8,\n        preprocessor_name: str = \"DefaultPreprocessor\",\n        overwrite_target_spacing: Any = None,\n        overwrite_plans_name: str = \"nnUNetPlans\",\n    ) -> None:\n        \"\"\"\n        Generate a configuration file that specifies the details of the experiment.\n\n        Args:\n            pl: [OPTIONAL] Name of the Experiment Planner class that should be used. Default is \"ExperimentPlanner\".\n                Note: There is no longer a distinction between 2d and 3d planner. It's an all-in-one solution now.\n            gpu_memory_target: [OPTIONAL] DANGER ZONE! Sets a custom GPU memory target. 
Default: 8 [GB].\n                Changing this will affect patch and batch size and will definitely affect your models' performance!\n                Only use this if you really know what you are doing and NEVER use this without running the\n                default nnU-Net first (as a baseline).\n            preprocessor_name: [OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in\n                nnunetv2.preprocessing. Default: \"DefaultPreprocessor\". Changing this may affect your models'\n                performance! Only use this if you really know what you are doing and NEVER use this without running the\n                default nnU-Net first (as a baseline).\n            overwrite_target_spacing: [OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres\n                and 3d_cascade_fullres configurations. Default: None [no changes]. Changing this will affect\n                image size and potentially patch and batch size. This will definitely affect your models' performance!\n                Only use this if you really know what you are doing and NEVER use this without running the\n                default nnU-Net first (as a baseline). Changing the target spacing for the other configurations\n                is currently not implemented. New target spacing must be a list of three numbers!\n            overwrite_plans_name: [OPTIONAL] DANGER ZONE! 
If you used -gpu_memory_target, -preprocessor_name or\n                -overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate\n                a differently named plans file such that the nnunet default plans are not overwritten.\n                You will then need to specify your custom plan.\n        \"\"\"\n        from nnunetv2.experiment_planning.plan_and_preprocess_api import plan_experiments\n\n        logger.info(\"Experiment planning...\")\n        plan_experiments(\n            [int(self.dataset_name_or_id)],\n            pl,\n            gpu_memory_target,\n            preprocessor_name,\n            overwrite_target_spacing,\n            overwrite_plans_name,\n        )\n\n    def preprocess(\n        self,\n        c: tuple = (M.N_2D, M.N_3D_FULLRES, M.N_3D_LOWRES),\n        n_proc: tuple = (8, 8, 8),\n        overwrite_plans_name: str = \"nnUNetPlans\",\n        verbose: bool = False,\n    ) -> None:\n        \"\"\"\n        Apply a set of preprocessing operations to the input data before the training.\n\n        Args:\n            overwrite_plans_name: [OPTIONAL] You can use this to specify a custom plans file that you may have\n                generated.\n            c: [OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3d_fullres\n                3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data\n                from 3d_fullres. Configurations that do not exist for some datasets will be skipped.\n            n_proc: [OPTIONAL] Use this to define how many processes are to be used. If this is just one number then\n                this number of processes is used for all configurations specified with -c. If it's a\n                list of numbers this list must have as many elements as there are configurations. 
We\n                then iterate over zip(configs, num_processes) to determine the number of processes\n                used for each configuration. More processes are always faster (up to the number of\n                threads your PC can support, so 8 for a 4-core CPU with hyperthreading. If you don't\n                know what that is then don't touch it, or at least don't increase it!). DANGER: More\n                often than not the number of processes that can be used is limited by the amount of\n                RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND\n                DECREASE -n_proc IF YOUR RAM FILLS UP TOO MUCH! Default: 8 8 8 (=8 processes each for 2d,\n                3d_fullres and 3d_lowres if -c is at its default).\n            verbose: Set this to print a lot of stuff. Useful for debugging. Will disable the progress bar!\n                Recommended for cluster environments.\n        \"\"\"\n        from nnunetv2.experiment_planning.plan_and_preprocess_api import preprocess\n\n        logger.info(\"Preprocessing...\")\n        preprocess(\n            [int(self.dataset_name_or_id)],\n            overwrite_plans_name,\n            configurations=c,\n            num_processes=n_proc,\n            verbose=verbose,\n        )\n\n    def plan_and_process(\n        self,\n        fpe: str = \"DatasetFingerprintExtractor\",\n        npfp: int = 8,\n        verify_dataset_integrity: bool = False,\n        no_pp: bool = False,\n        clean: bool = False,\n        pl: str = \"ExperimentPlanner\",\n        gpu_memory_target: int = 8,\n        preprocessor_name: str = \"DefaultPreprocessor\",\n        overwrite_target_spacing: Any = None,\n        overwrite_plans_name: str = \"nnUNetPlans\",\n        c: tuple = (M.N_2D, M.N_3D_FULLRES, M.N_3D_LOWRES),\n        n_proc: tuple = (8, 8, 8),\n        verbose: bool = False,\n    ) -> None:\n        \"\"\"\n        Performs experiment planning and preprocessing before the 
training.\n\n        Args:\n            fpe: [OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is\n                \"DatasetFingerprintExtractor\".\n            npfp: [OPTIONAL] Number of processes used for fingerprint extraction. Default: 8.\n            verify_dataset_integrity: [RECOMMENDED] set this flag to check the dataset integrity.\n                This is useful and should be done once for each dataset!\n            no_pp: [OPTIONAL] Set this to only run fingerprint extraction and experiment planning (no\n                preprocessing). Useful for debugging.\n            clean:[OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a\n                fingerprint already exists, the fingerprint extractor will not run. REQUIRED IF YOU\n                CHANGE THE DATASET FINGERPRINT EXTRACTOR OR MAKE CHANGES TO THE DATASET!\n            pl: [OPTIONAL] Name of the Experiment Planner class that should be used. Default is \"ExperimentPlanner\".\n                Note: There is no longer a distinction between 2d and 3d planner. It's an all-in-one solution now.\n            gpu_memory_target: [OPTIONAL] DANGER ZONE! Sets a custom GPU memory target. Default: 8 [GB].\n                Changing this will affect patch and batch size and will\n                definitely affect your models' performance! Only use this if you really know what you\n                are doing and NEVER use this without running the default nnU-Net first (as a baseline).\n            preprocessor_name: [OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in\n                nnunetv2.preprocessing. Default: \"DefaultPreprocessor\". Changing this may affect your\n                models' performance! 
Only use this if you really know what you\n                are doing and NEVER use this without running the default nnU-Net first (as a baseline).\n            overwrite_target_spacing: [OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres and\n                3d_cascade_fullres configurations. Default: None [no changes]. Changing this will affect image size and\n                potentially patch and batch size. This will definitely affect your models' performance!\n                Only use this if you really know what you are doing and NEVER use this without running the\n                default nnU-Net first (as a baseline). Changing the target spacing for the other\n                configurations is currently not implemented. New target spacing must be a list of three numbers!\n            overwrite_plans_name: [OPTIONAL] USE A CUSTOM PLANS IDENTIFIER. If you used -gpu_memory_target,\n                -preprocessor_name or -overwrite_target_spacing it is best practice to use -overwrite_plans_name to\n                generate a differently named plans file such that the nnunet default plans are not\n                overwritten. You will then need to specify your custom plans file with -p whenever\n                running other nnunet commands (training, inference, etc.).\n            c: [OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3d_fullres\n                3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data\n                from 3d_fullres. Configurations that do not exist for some datasets will be skipped.\n            n_proc: [OPTIONAL] Use this to define how many processes are to be used. If this is just one number then\n                this number of processes is used for all configurations specified with -c. If it's a\n                list of numbers this list must have as many elements as there are configurations. 
We\n                then iterate over zip(configs, num_processes) to determine the number of processes\n                used for each configuration. More processes are always faster (up to the number of\n                threads your PC can support, so 8 for a 4-core CPU with hyperthreading. If you don't\n                know what that is then don't touch it, or at least don't increase it!). DANGER: More\n                often than not the number of processes that can be used is limited by the amount of\n                RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND\n                DECREASE -n_proc IF YOUR RAM FILLS UP TOO MUCH! Default: 8 8 8 (=8 processes each for 2d,\n                3d_fullres and 3d_lowres if -c is at its default).\n            verbose: Set this to print a lot of stuff. Useful for debugging. Will disable progress bar!\n                (Recommended for cluster environments).\n        \"\"\"\n        self.extract_fingerprints(fpe, npfp, verify_dataset_integrity, clean, verbose)\n        self.plan_experiments(pl, gpu_memory_target, preprocessor_name, overwrite_target_spacing, overwrite_plans_name)\n\n        if not no_pp:\n            self.preprocess(c, n_proc, overwrite_plans_name, verbose)\n\n    def train_single_model(self, config: Any, fold: int, gpu_id: tuple | list | int | str = 0, **kwargs: Any) -> None:\n        \"\"\"\n        Run the training on a single GPU with one specified configuration provided.\n        Note: if CUDA_VISIBLE_DEVICES is already set and gpu_id resolves to 0, the existing value is preserved;\n        otherwise it is set to gpu_id.\n\n        Args:\n            config: configuration that should be trained. Examples: \"2d\", \"3d_fullres\", \"3d_lowres\".\n            fold: fold of the 5-fold cross-validation. Should be an int between 0 and 4.\n            gpu_id: an int, MIG UUID (str), or tuple/list of GPU indices for multi-GPU training (e.g., (0,1)). 
Default: 0.\n            kwargs: this optional parameter allows you to specify additional arguments in\n                ``nnunetv2.run.run_training.run_training_entry``.\n\n                Currently supported args are:\n\n                - p: custom plans identifier. Default: \"nnUNetPlans\".\n                - pretrained_weights: path to nnU-Net checkpoint file to be used as pretrained model. Will only be\n                    used when actually training. Beta. Use with caution. Default: False.\n                - use_compressed: True to use compressed data for training. Reading compressed data is much\n                    more CPU and (potentially) RAM intensive and should only be used if you know what you are\n                    doing. Default: False.\n                - c: continue training from latest checkpoint. Default: False.\n                - val: True to run the validation only. Requires training to have finished.\n                    Default: False.\n                - disable_checkpointing: True to disable checkpointing. Ideal for testing things out and you\n                    don't want to flood your hard drive with checkpoints. 
Default: False.\n        \"\"\"\n        if \"num_gpus\" in kwargs:\n            kwargs.pop(\"num_gpus\")\n            logger.warning(\"please use gpu_id to set the GPUs to use\")\n\n        if \"tr\" in kwargs:\n            kwargs.pop(\"tr\")\n            logger.warning(\"please specify the `trainer_class_name` in the __init__ of `nnUNetV2Runner`.\")\n\n        if \"npz\" in kwargs:\n            kwargs.pop(\"npz\")\n            logger.warning(\"please specify the `export_validation_probabilities` in the __init__ of `nnUNetV2Runner`.\")\n\n        cmd, env = self.train_single_model_command(config, fold, gpu_id, kwargs)\n        run_cmd(cmd, env=env)\n\n    def train_single_model_command(\n        self, config: str, fold: int, gpu_id: int | str | tuple | list, kwargs: dict[str, Any]\n    ) -> tuple[list[str], dict[str, str]]:\n        \"\"\"\n        Build the shell command string for training a single nnU-Net model.\n\n        Args:\n            config: Configuration name (e.g., \"3d_fullres\").\n            fold: Cross-validation fold index (0-4).\n            gpu_id: Device selector—int, str (MIG UUID), or tuple/list for multi-GPU.\n            kwargs: Additional CLI arguments forwarded to nnUNetv2_train.\n\n        Returns:\n            Tuple of (cmd, env) where cmd is a list[str] of argv entries and env is a dict[str, str]\n            passed to the subprocess.\n\n        Raises:\n            ValueError: If gpu_id is an empty tuple or list.\n        \"\"\"\n        env = os.environ.copy()\n        device_setting: str = \"0\"\n        num_gpus = 1\n        if isinstance(gpu_id, str):\n            device_setting = gpu_id\n            num_gpus = 1\n        elif isinstance(gpu_id, (tuple, list)):\n            if len(gpu_id) == 0:\n                raise ValueError(\"gpu_id tuple/list cannot be empty\")\n            if len(gpu_id) > 1:\n                device_setting = \",\".join(str(x) for x in gpu_id)\n                num_gpus = len(gpu_id)\n            elif 
len(gpu_id) == 1:\n                device_setting = str(gpu_id[0])\n                num_gpus = 1\n        else:\n            device_setting = str(gpu_id)\n            num_gpus = 1\n        env_cuda = env.get(\"CUDA_VISIBLE_DEVICES\")\n        if env_cuda is not None and device_setting == \"0\":\n            logger.info(f\"Using existing environment variable CUDA_VISIBLE_DEVICES='{env_cuda}'\")\n        else:\n            env[\"CUDA_VISIBLE_DEVICES\"] = device_setting\n\n        cmd = [\n            \"nnUNetv2_train\",\n            f\"{self.dataset_name_or_id}\",\n            f\"{config}\",\n            f\"{fold}\",\n            \"-tr\",\n            f\"{self.trainer_class_name}\",\n            \"-num_gpus\",\n            f\"{num_gpus}\",\n        ]\n        if self.export_validation_probabilities:\n            cmd.append(\"--npz\")\n        for _key, _value in kwargs.items():\n            if _key == \"p\" or _key == \"pretrained_weights\":\n                cmd.extend([f\"-{_key}\", f\"{_value}\"])\n            else:\n                cmd.extend([f\"--{_key}\", f\"{_value}\"])\n        return cmd, env\n\n    def train(\n        self,\n        configs: tuple | str = (M.N_3D_FULLRES, M.N_2D, M.N_3D_LOWRES, M.N_3D_CASCADE_FULLRES),\n        gpu_id_for_all: tuple | list | int | None = None,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"\n        Run the training for all the models specified by the configurations.\n        Note: to set the number of GPUs to use, use ``gpu_id_for_all`` instead of the `CUDA_VISIBLE_DEVICES`\n        environment variable.\n\n        Args:\n            configs: configurations that should be trained.\n                Default: (\"2d\", \"3d_fullres\", \"3d_lowres\", \"3d_cascade_fullres\").\n            gpu_id_for_all: a tuple/list/integer of GPU device ID(s) to use for the training. 
Default:\n                None (all available GPUs).\n            kwargs: this optional parameter allows you to specify additional arguments defined in the\n                ``train_single_model`` method.\n        \"\"\"\n        if gpu_id_for_all is None:\n            result = subprocess.run([\"nvidia-smi\", \"--list-gpus\"], stdout=subprocess.PIPE)\n            output = result.stdout.decode(\"utf-8\")\n            num_gpus = len(output.strip().split(\"\\n\"))\n            gpu_id_for_all = tuple(range(num_gpus))\n        elif isinstance(gpu_id_for_all, int):\n            gpu_id_for_all = ensure_tuple(gpu_id_for_all)\n        logger.info(f\"number of GPUs is {len(gpu_id_for_all)}, device ids are {gpu_id_for_all}\")\n        if len(gpu_id_for_all) > 1:\n            self.train_parallel(configs=ensure_tuple(configs), gpu_id_for_all=gpu_id_for_all, **kwargs)\n        else:\n            for cfg in ensure_tuple(configs):\n                for _fold in range(self.num_folds):\n                    self.train_single_model(config=cfg, fold=_fold, gpu_id=gpu_id_for_all, **kwargs)\n\n    def train_parallel_cmd(\n        self,\n        configs: tuple | str = (M.N_3D_FULLRES, M.N_2D, M.N_3D_LOWRES, M.N_3D_CASCADE_FULLRES),\n        gpu_id_for_all: tuple | list | int | None = None,\n        **kwargs: Any,\n    ) -> list:\n        \"\"\"\n        Create the line command for subprocess call for parallel training.\n\n        Args:\n            configs: configurations that should be trained.\n                Default: (\"2d\", \"3d_fullres\", \"3d_lowres\", \"3d_cascade_fullres\").\n            gpu_id_for_all: a tuple/list/integer of GPU device ID(s) to use for the training. 
Default:\n                None (all available GPUs).\n            kwargs: this optional parameter allows you to specify additional arguments defined in the\n                ``train_single_model`` method.\n        \"\"\"\n        # unpack compressed files\n        folder_names = []\n        for root, _, files in os.walk(os.path.join(self.nnunet_preprocessed, self.dataset_name)):\n            if any(file.endswith(\".npz\") for file in files):\n                folder_names.append(root)\n\n        from nnunetv2.training.dataloading.utils import unpack_dataset\n\n        for folder_name in folder_names:\n            logger.info(f\"unpacking '{folder_name}'...\")\n            unpack_dataset(\n                folder=folder_name,\n                unpack_segmentation=True,\n                overwrite_existing=False,\n                num_processes=self.default_num_processes,\n            )\n\n        # model training\n        kwargs = kwargs or {}\n        devices = ensure_tuple(gpu_id_for_all)\n        n_devices = len(devices)\n        _configs = [[M.N_3D_FULLRES, M.N_2D, M.N_3D_LOWRES], [M.N_3D_CASCADE_FULLRES]]\n        all_cmds: list = []\n        for _stage in range(len(_configs)):\n            all_cmds.append({_j: [] for _j in devices})\n            _index = 0\n\n            for _config in _configs[_stage]:\n                if _config in ensure_tuple(configs):\n                    for _i in range(self.num_folds):\n                        the_device = gpu_id_for_all[_index % n_devices]  # type: ignore\n                        cmd, env = self.train_single_model_command(_config, _i, the_device, kwargs)\n                        all_cmds[-1][the_device].append((cmd, env))\n                        _index += 1\n        return all_cmds\n\n    def train_parallel(\n        self,\n        configs: tuple | str = (M.N_3D_FULLRES, M.N_2D, M.N_3D_LOWRES, M.N_3D_CASCADE_FULLRES),\n        gpu_id_for_all: tuple | list | int | None = None,\n        **kwargs: Any,\n    ) -> None:\n        
\"\"\"\n        Create the line command for subprocess call for parallel training.\n        Note: to set the number of GPUs to use, use ``gpu_id_for_all`` instead of the `CUDA_VISIBLE_DEVICES`\n        environment variable.\n\n        Args:\n            configs: configurations that should be trained.\n                default: (\"2d\", \"3d_fullres\", \"3d_lowres\", \"3d_cascade_fullres\").\n            gpu_id_for_all: a tuple/list/integer of GPU device ID(s) to use for the training. Default:\n                None (all available GPUs).\n            kwargs: this optional parameter allows you to specify additional arguments defined in the\n                ``train_single_model`` method.\n        \"\"\"\n        all_cmds = self.train_parallel_cmd(configs=configs, gpu_id_for_all=gpu_id_for_all, **kwargs)\n        for s, cmds in enumerate(all_cmds):\n            for gpu_id, gpu_cmd in cmds.items():\n                if not gpu_cmd:\n                    continue\n                cmds_for_log = [shlex.join(cmd) for cmd, _ in gpu_cmd]\n                logger.info(\n                    f\"training - stage {s + 1}:\\n\"\n                    f\"for gpu {gpu_id}, commands: {cmds_for_log}\\n\"\n                    f\"log '.txt' inside '{os.path.join(self.nnunet_results, self.dataset_name)}'\"\n                )\n        for stage in all_cmds:\n            processes = []\n            for device_id in stage:\n                if not stage[device_id]:\n                    continue\n                cmd_str = \"; \".join(shlex.join(cmd) for cmd, _ in stage[device_id])\n                env = stage[device_id][0][1]\n                logger.info(f\"Current running command on GPU device {device_id}:\\n{cmd_str}\\n\")\n                processes.append(subprocess.Popen(cmd_str, shell=True, env=env, stdout=subprocess.DEVNULL))\n            # finish this stage first\n            for p in processes:\n                p.wait()\n\n    def validate_single_model(self, config: str, fold: int, **kwargs: 
Any) -> None:\n        \"\"\"\n        Perform validation on single model.\n\n        Args:\n            config: configuration that should be trained.\n            fold: fold of the 5-fold cross-validation. Should be an int between 0 and 4.\n            kwargs: this optional parameter allows you to specify additional arguments defined in the\n                ``train_single_model`` method.\n        \"\"\"\n        self.train_single_model(config=config, fold=fold, only_run_validation=True, **kwargs)\n\n    def validate(\n        self, configs: tuple = (M.N_3D_FULLRES, M.N_2D, M.N_3D_LOWRES, M.N_3D_CASCADE_FULLRES), **kwargs: Any\n    ) -> None:\n        \"\"\"\n        Perform validation in all models defined by the configurations over 5 folds.\n\n        Args:\n            configs: configurations that should be trained.\n                default: (\"2d\", \"3d_fullres\", \"3d_lowres\", \"3d_cascade_fullres\").\n            kwargs: this optional parameter allows you to specify additional arguments defined in the\n                ``train_single_model`` method.\n        \"\"\"\n        for cfg in ensure_tuple(configs):\n            for _fold in range(self.num_folds):\n                self.validate_single_model(config=cfg, fold=_fold, **kwargs)\n\n    def find_best_configuration(\n        self,\n        plans: tuple | str = \"nnUNetPlans\",\n        configs: tuple | str = (M.N_2D, M.N_3D_FULLRES, M.N_3D_LOWRES, M.N_3D_CASCADE_FULLRES),\n        trainers: tuple | str | None = None,\n        allow_ensembling: bool = True,\n        num_processes: int = -1,\n        overwrite: bool = True,\n        folds: list[int] | tuple[int, ...] = (0, 1, 2, 3, 4),\n        strict: bool = False,\n    ) -> None:\n        \"\"\"\n        Find the best model configurations.\n\n        Args:\n            plans: list of plan identifiers. Default: nnUNetPlans.\n            configs: list of configurations. 
Default: [\"2d\", \"3d_fullres\", \"3d_lowres\", \"3d_cascade_fullres\"].\n            trainers: list of trainers. Default: nnUNetTrainer.\n            allow_ensembling: set this flag to enable ensembling.\n            num_processes: number of processes to use for ensembling, postprocessing, etc.\n            overwrite: if set we will overwrite already ensembled files etc. May speed up consecutive\n                runs of this command (not recommended) at the risk of not updating outdated results.\n            folds: folds to use. Default: (0, 1, 2, 3, 4).\n            strict: a switch that triggers RunTimeError if the logging folder cannot be found. Default: False.\n        \"\"\"\n        from nnunetv2.evaluation.find_best_configuration import (\n            dumb_trainer_config_plans_to_trained_models_dict,\n            find_best_configuration,\n        )\n\n        configs = ensure_tuple(configs)\n        plans = ensure_tuple(plans)\n\n        if trainers is None:\n            trainers = self.trainer_class_name\n        trainers = ensure_tuple(trainers)\n\n        models = dumb_trainer_config_plans_to_trained_models_dict(trainers, configs, plans)\n        num_processes = self.default_num_processes if num_processes < 0 else num_processes\n        find_best_configuration(\n            int(self.dataset_name_or_id),\n            models,\n            allow_ensembling=allow_ensembling,\n            num_processes=num_processes,\n            overwrite=overwrite,\n            folds=folds,\n            strict=strict,\n        )\n\n    def predict(\n        self,\n        list_of_lists_or_source_folder: str | list[list[str]],\n        output_folder: str | None | list[str],\n        model_training_output_dir: str,\n        use_folds: tuple[int, ...] 
| str | None = None,\n        tile_step_size: float = 0.5,\n        use_gaussian: bool = True,\n        use_mirroring: bool = True,\n        perform_everything_on_gpu: bool = True,\n        verbose: bool = True,\n        save_probabilities: bool = False,\n        overwrite: bool = True,\n        checkpoint_name: str = \"checkpoint_final.pth\",\n        folder_with_segs_from_prev_stage: str | None = None,\n        num_parts: int = 1,\n        part_id: int = 0,\n        num_processes_preprocessing: int = -1,\n        num_processes_segmentation_export: int = -1,\n        gpu_id: int | str = 0,\n    ) -> None:\n        \"\"\"\n        Use this to run inference with nnU-Net. This function is used when you want to manually specify a folder containing\n            a trained nnU-Net model. This is useful when the nnunet environment variables (nnUNet_results) are not set.\n\n        Args:\n            list_of_lists_or_source_folder: input folder. Remember to use the correct channel numberings for\n                your files (_0000 etc). File endings must be the same as the training dataset!\n            output_folder: Output folder. If it does not exist it will be created. Predicted segmentations will\n                have the same name as their source images.\n            model_training_output_dir: folder in which the trained model is. Must have subfolders fold_X for the\n                different folds you trained.\n            use_folds: specify the folds of the trained model that should be used for prediction\n                Default: (0, 1, 2, 3, 4).\n            tile_step_size: step size for sliding window prediction. The larger it is the faster but less accurate\n                the prediction. Default: 0.5. Cannot be larger than 1. 
We recommend the default.\n            use_gaussian: use Gaussian smoothing as test-time augmentation.\n            use_mirroring: use mirroring/flipping as test-time augmentation.\n            verbose: set this if you like being talked to. You will have to be a good listener/reader.\n            save_probabilities: set this to export predicted class \"probabilities\". Required if you want to ensemble\n                multiple configurations.\n            overwrite: overwrite an existing previous prediction (will not overwrite existing files)\n            checkpoint_name: name of the checkpoint you want to use. Default: checkpoint_final.pth.\n            folder_with_segs_from_prev_stage: folder containing the predictions of the previous stage.\n                Required for cascaded models.\n            num_parts: number of separate nnUNetv2_predict call that you will be making. Default: 1 (= this one\n                call predicts everything).\n            part_id: if multiple nnUNetv2_predict exist, which one is this? IDs start with 0 can end with\n                num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set -num_parts\n                5 and use -part_id 0, 1, 2, 3 and 4.\n            num_processes_preprocessing: out-of-RAM issues.\n            num_processes_segmentation_export: Number of processes used for segmentation export.\n                More is not always better. 
Beware of out-of-RAM issues.\n            gpu_id: GPU device index (int) or MIG UUID (str) for prediction.\n                If CUDA_VISIBLE_DEVICES is already set and gpu_id is 0, the existing\n                environment variable is preserved.\n        \"\"\"\n        if \"CUDA_VISIBLE_DEVICES\" in os.environ and gpu_id in {0, \"0\"}:\n            logger.info(f\"Predict: Using existing CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}\")\n        else:\n            os.environ[\"CUDA_VISIBLE_DEVICES\"] = f\"{gpu_id}\"\n\n        from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\n\n        n_processes_preprocessing = (\n            self.default_num_processes if num_processes_preprocessing < 0 else num_processes_preprocessing\n        )\n        n_processes_segmentation_export = (\n            self.default_num_processes if num_processes_segmentation_export < 0 else num_processes_segmentation_export\n        )\n        predictor = nnUNetPredictor(\n            tile_step_size=tile_step_size,\n            use_gaussian=use_gaussian,\n            use_mirroring=use_mirroring,\n            perform_everything_on_device=perform_everything_on_gpu,\n            verbose=verbose,\n        )\n        predictor.initialize_from_trained_model_folder(\n            model_training_output_dir=model_training_output_dir, use_folds=use_folds, checkpoint_name=checkpoint_name\n        )\n        predictor.predict_from_files(\n            list_of_lists_or_source_folder=list_of_lists_or_source_folder,\n            output_folder_or_list_of_truncated_output_files=output_folder,\n            save_probabilities=save_probabilities,\n            overwrite=overwrite,\n            num_processes_preprocessing=n_processes_preprocessing,\n            num_processes_segmentation_export=n_processes_segmentation_export,\n            folder_with_segs_from_prev_stage=folder_with_segs_from_prev_stage,\n            num_parts=num_parts,\n            part_id=part_id,\n        )\n\n    def 
predict_ensemble_postprocessing(\n        self,\n        folds: tuple = (0, 1, 2, 3, 4),\n        run_ensemble: bool = True,\n        run_predict: bool = True,\n        run_postprocessing: bool = True,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"\n        Run prediction, ensemble, and/or postprocessing optionally.\n\n        Args:\n            folds: which folds to use\n            run_ensemble: whether to run ensembling.\n            run_predict: whether to predict using trained checkpoints\n            run_postprocessing: whether to conduct post-processing\n            kwargs: this optional parameter allows you to specify additional arguments defined in the\n                ``predict`` method.\n        \"\"\"\n        from nnunetv2.ensembling.ensemble import ensemble_folders\n        from nnunetv2.postprocessing.remove_connected_components import apply_postprocessing_to_folder\n        from nnunetv2.utilities.file_path_utilities import get_output_folder\n\n        source_dir = join(self.nnunet_raw, self.dataset_name, \"imagesTs\")\n        target_dir_base = join(self.nnunet_results, self.dataset_name)\n\n        self.best_configuration = ConfigParser.load_config_file(\n            os.path.join(self.nnunet_results, self.dataset_name, \"inference_information.json\")\n        )\n\n        run_ensemble = (\n            run_ensemble and len(self.best_configuration[\"best_model_or_ensemble\"][\"selected_model_or_models\"]) > 1\n        )\n\n        used_folds = folds\n        output_folders = []\n        for im in self.best_configuration[\"best_model_or_ensemble\"][\"selected_model_or_models\"]:\n            output_dir = join(target_dir_base, f\"pred_{im['configuration']}\")\n            output_folders.append(output_dir)\n\n            if run_predict:\n                model_folder = get_output_folder(\n                    int(self.dataset_name_or_id), im[\"trainer\"], im[\"plans_identifier\"], im[\"configuration\"]\n                )\n                
self.predict(\n                    list_of_lists_or_source_folder=source_dir,\n                    output_folder=output_dir,\n                    model_training_output_dir=model_folder,\n                    use_folds=used_folds,\n                    save_probabilities=run_ensemble,\n                    verbose=False,\n                    overwrite=True,\n                    **kwargs,\n                )\n\n        # if we have an ensemble, we need to ensemble the results\n        if run_ensemble:\n            ensemble_folders(\n                output_folders, join(target_dir_base, \"ensemble_predictions\"), save_merged_probabilities=False\n            )\n            if run_postprocessing:\n                folder_for_pp = join(target_dir_base, \"ensemble_predictions\")\n        elif run_postprocessing:\n            folder_for_pp = output_folders[0]\n\n        # apply postprocessing\n        if run_postprocessing:\n            pp_fns, pp_fn_kwargs = load_pickle(self.best_configuration[\"best_model_or_ensemble\"][\"postprocessing_file\"])\n            apply_postprocessing_to_folder(\n                folder_for_pp,\n                join(target_dir_base, \"ensemble_predictions_postprocessed\"),\n                pp_fns,\n                pp_fn_kwargs,\n                plans_file_or_dict=self.best_configuration[\"best_model_or_ensemble\"][\"some_plans_file\"],\n            )\n\n    def run(\n        self,\n        run_convert_dataset: bool = True,\n        run_plan_and_process: bool = True,\n        run_train: bool = True,\n        run_find_best_configuration: bool = True,\n        run_predict_ensemble_postprocessing: bool = True,\n    ) -> None:\n        \"\"\"\n        Run the nnU-Net pipeline.\n\n        Args:\n            run_convert_dataset: whether to convert datasets, defaults to True.\n            run_plan_and_process: whether to preprocess and analyze the dataset, defaults to True.\n            run_train: whether to train models, defaults to True.\n            
run_find_best_configuration: whether to find the best model (ensemble) configurations, defaults to True.\n            run_predict_ensemble_postprocessing: whether to make predictions on test datasets, defaults to True.\n        \"\"\"\n        if run_convert_dataset:\n            self.convert_dataset()\n\n        if run_plan_and_process:\n            self.plan_and_process()\n\n        if run_train:\n            self.train()\n\n        if run_find_best_configuration:\n            self.find_best_configuration()\n\n        if run_predict_ensemble_postprocessing:\n            self.predict_ensemble_postprocessing()\n\n        return\n"
  },
  {
    "path": "monai/apps/nnunet/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport copy\nimport os\n\nimport numpy as np\n\nimport monai\nfrom monai.bundle import ConfigParser\nfrom monai.utils import StrEnum, ensure_tuple, optional_import\n\ntqdm, has_tqdm = optional_import(\"tqdm\", name=\"tqdm\")\nnib, _ = optional_import(\"nibabel\")\n\nlogger = monai.apps.utils.get_logger(__name__)\n\n__all__ = [\"analyze_data\", \"create_new_data_copy\", \"create_new_dataset_json\", \"NNUNETMode\"]\n\n\nclass NNUNETMode(StrEnum):\n    N_2D = \"2d\"\n    N_3D_FULLRES = \"3d_fullres\"\n    N_3D_LOWRES = \"3d_lowres\"\n    N_3D_CASCADE_FULLRES = \"3d_cascade_fullres\"\n\n\ndef analyze_data(datalist_json: dict, data_dir: str) -> tuple[int, int]:\n    \"\"\"\n    Analyze (training) data\n\n    Args:\n        datalist_json: original data list .json (required by most monai tutorials).\n        data_dir: raw data directory.\n    \"\"\"\n    img = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True, simple_keys=True)(\n        os.path.join(data_dir, datalist_json[\"training\"][0][\"image\"])\n    )\n    num_input_channels = img.size()[0] if img.dim() == 4 else 1\n    logger.info(f\"num_input_channels: {num_input_channels}\")\n\n    num_foreground_classes = 0\n    for _i in range(len(datalist_json[\"training\"])):\n        seg = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True, 
simple_keys=True)(\n            os.path.join(data_dir, datalist_json[\"training\"][_i][\"label\"])\n        )\n        num_foreground_classes = max(num_foreground_classes, int(seg.max()))\n    logger.info(f\"num_foreground_classes: {num_foreground_classes}\")\n\n    return num_input_channels, num_foreground_classes\n\n\ndef create_new_data_copy(\n    test_key: str, datalist_json: dict, data_dir: str, num_input_channels: int, output_datafolder: str\n) -> None:\n    \"\"\"\n    Create and organize a new copy of data to meet the requirements of nnU-Net V2\n\n    Args:\n        test_key: key for test data in the data list .json.\n        datalist_json: original data list .json (required by most monai tutorials).\n        data_dir: raw data directory.\n        num_input_channels: number of input (image) channels.\n        output_datafolder: output folder.\n    \"\"\"\n    _index = 0\n    new_datalist_json: dict = {\"training\": [], test_key: []}\n\n    for _key, _folder, _label_folder in list(\n        zip([\"training\", test_key], [\"imagesTr\", \"imagesTs\"], [\"labelsTr\", \"labelsTs\"])\n    ):\n        if _key is None:\n            continue\n\n        logger.info(f\"converting data section: {_key}...\")\n        for _k in tqdm(range(len(datalist_json[_key]))) if has_tqdm else range(len(datalist_json[_key])):\n            orig_img_name = (\n                datalist_json[_key][_k][\"image\"]\n                if isinstance(datalist_json[_key][_k], dict)\n                else datalist_json[_key][_k]\n            )\n            img_name = f\"case_{_index}\"\n            _index += 1\n\n            # copy image\n            nda = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True, simple_keys=True)(\n                os.path.join(data_dir, orig_img_name)\n            )\n            affine = nda.meta[\"original_affine\"]\n            nda = nda.numpy()\n            for _l in range(num_input_channels):\n                outimg = nib.Nifti1Image(nda[_l, ...], 
affine)\n                index = \"_\" + str(_l + 10000)[-4:]\n                nib.save(outimg, os.path.join(output_datafolder, _folder, img_name + index + \".nii.gz\"))\n\n            # copy label\n            if isinstance(datalist_json[_key][_k], dict) and \"label\" in datalist_json[_key][_k]:\n                nda = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True, simple_keys=True)(\n                    os.path.join(data_dir, datalist_json[_key][_k][\"label\"])\n                )\n                affine = nda.meta[\"original_affine\"]\n                nda = nda.numpy().astype(np.uint8)\n                nda = nda[0, ...] if nda.ndim == 4 and nda.shape[0] == 1 else nda\n                nib.save(\n                    nib.Nifti1Image(nda, affine), os.path.join(output_datafolder, _label_folder, img_name + \".nii.gz\")\n                )\n\n            if isinstance(datalist_json[_key][_k], dict):\n                _val = copy.deepcopy(datalist_json[_key][_k])\n                _val[\"new_name\"] = img_name\n                new_datalist_json[_key].append(_val)\n            else:\n                new_datalist_json[_key].append({\"image\": datalist_json[_key][_k], \"new_name\": img_name})\n\n            ConfigParser.export_config_file(\n                config=new_datalist_json,\n                filepath=os.path.join(output_datafolder, \"datalist.json\"),\n                fmt=\"json\",\n                sort_keys=True,\n                indent=4,\n                ensure_ascii=False,\n            )\n\n    return\n\n\ndef create_new_dataset_json(\n    modality: str, num_foreground_classes: int, num_input_channels: int, num_training_data: int, output_filepath: str\n) -> None:\n    \"\"\"\n    Create a new copy of dataset .json to meet the requirements of nnU-Net V2\n\n    Args:\n        modality: image modality, could a string or a list of strings.\n        num_foreground_classes: number of foreground classes.\n        num_input_channels: number of input 
(image) channels.\n        num_training_data: number of training data.\n        output_filepath: output file path/name.\n    \"\"\"\n    new_json_data: dict = {}\n\n    # modality = self.input_info.pop(\"modality\")\n    modality = ensure_tuple(modality)  # type: ignore\n\n    new_json_data[\"channel_names\"] = {}\n    for _j in range(num_input_channels):\n        new_json_data[\"channel_names\"][str(_j)] = modality[_j]\n\n    new_json_data[\"labels\"] = {}\n    new_json_data[\"labels\"][\"background\"] = 0\n    for _j in range(num_foreground_classes):\n        new_json_data[\"labels\"][f\"class{_j + 1}\"] = _j + 1\n\n    # new_json_data[\"numTraining\"] = len(datalist_json[\"training\"])\n    new_json_data[\"numTraining\"] = num_training_data\n    new_json_data[\"file_ending\"] = \".nii.gz\"\n\n    ConfigParser.export_config_file(\n        config=new_json_data,\n        # filepath=os.path.join(raw_data_foldername, \"dataset.json\"),\n        filepath=output_filepath,\n        fmt=\"json\",\n        sort_keys=True,\n        indent=4,\n        ensure_ascii=False,\n    )\n\n    return\n"
  },
  {
    "path": "monai/apps/nuclick/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/nuclick/transforms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.config import KeysCollection, NdarrayOrTensor\nfrom monai.networks.layers import GaussianFilter\nfrom monai.transforms import MapTransform, Randomizable, SpatialPad\nfrom monai.utils import StrEnum, convert_to_numpy, optional_import\n\nmeasure, _ = optional_import(\"skimage.measure\")\nmorphology, _ = optional_import(\"skimage.morphology\")\ndistance_transform_cdt, _ = optional_import(\"scipy.ndimage\", name=\"distance_transform_cdt\")\n\n\nclass NuclickKeys(StrEnum):\n    \"\"\"\n    Keys for nuclick transforms.\n    \"\"\"\n\n    IMAGE = \"image\"\n    LABEL = \"label\"\n    OTHERS = \"others\"  # key of other labels from the binary mask which are not being used for training\n    FOREGROUND = \"foreground\"\n\n    CENTROID = \"centroid\"  # key where the centroid values are stored\n    MASK_VALUE = \"mask_value\"\n    LOCATION = \"location\"\n\n    NUC_POINTS = \"nuc_points\"\n    BOUNDING_BOXES = \"bounding_boxes\"\n    IMG_HEIGHT = \"img_height\"\n    IMG_WIDTH = \"img_width\"\n    PRED_CLASSES = \"pred_classes\"\n\n\nclass FlattenLabeld(MapTransform):\n    \"\"\"\n    FlattenLabeld creates labels per closed object contour (defined by a connectivity). 
For e.g if there are\n    12 small regions of 1's it will delineate them into 12 different label classes\n\n    Args:\n        connectivity: Max no. of orthogonal hops to consider a pixel/voxel as a neighbor. Refer skimage.measure.label\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, connectivity: int = 1, allow_missing_keys: bool = False):\n        super().__init__(keys, allow_missing_keys)\n        self.connectivity = connectivity\n\n    def __call__(self, data):\n        d = dict(data)\n        for key in self.keys:\n            img = convert_to_numpy(d[key]) if isinstance(d[key], torch.Tensor) else d[key]\n            d[key] = measure.label(img, connectivity=self.connectivity).astype(np.uint8)\n        return d\n\n\nclass ExtractPatchd(MapTransform):\n    \"\"\"\n    Extracts a patch from the given image and label, however it is based on the centroid location.\n    The centroid location is a 2D coordinate (H, W). 
The extracted patch is extracted around the centroid,\n    if the centroid is towards the edge, the centroid will not be the center of the image as the patch will be\n    extracted from the edges onwards\n\n    Args:\n        keys: image, label\n        centroid_key: key where the centroid values are stored, defaults to ``\"centroid\"``\n        patch_size: size of the extracted patch\n        allow_missing_keys: don't raise exception if key is missing.\n        pad_kwargs: other arguments for the SpatialPad transform\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        centroid_key: str = NuclickKeys.CENTROID,\n        patch_size: tuple[int, int] | int = 128,\n        allow_missing_keys: bool = False,\n        **kwargs: Any,\n    ):\n        super().__init__(keys, allow_missing_keys)\n        self.centroid_key = centroid_key\n        self.patch_size = patch_size\n        self.kwargs = kwargs\n\n    def __call__(self, data):\n        d = dict(data)\n\n        centroid = d[self.centroid_key]  # create mask based on centroid (select nuclei based on centroid)\n        roi_size = (self.patch_size, self.patch_size)\n\n        for key in self.keys:\n            img = d[key]\n            x_start, x_end, y_start, y_end = self.bbox(self.patch_size, centroid, img.shape[-2:])\n            cropped = img[:, x_start:x_end, y_start:y_end]\n            d[key] = SpatialPad(spatial_size=roi_size, **self.kwargs)(cropped)\n        return d\n\n    def bbox(self, patch_size, centroid, size):\n        x, y = centroid\n        m, n = size\n\n        x_start = int(max(x - patch_size / 2, 0))\n        y_start = int(max(y - patch_size / 2, 0))\n        x_end = x_start + patch_size\n        y_end = y_start + patch_size\n        if x_end > m:\n            x_end = m\n            x_start = m - patch_size\n        if y_end > n:\n            y_end = n\n            y_start = n - patch_size\n        return x_start, x_end, y_start, y_end\n\n\nclass 
SplitLabeld(MapTransform):\n    \"\"\"\n    Extracts a single label from all the given classes, the single label is defined by mask_value, the remaining\n    labels are kept in others\n\n    Args:\n        label: key of the label source\n        others: other labels storage key, defaults to ``\"others\"``\n        mask_value: the mask_value that will be kept for binarization of the label, defaults to ``\"mask_value\"``\n        min_area: The smallest allowable object size.\n        others_value: Value/class for other nuclei;  Use this to separate core nuclei vs others.\n        to_binary_mask: Convert mask to binary;  Set it false to restore original class values\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        others: str = NuclickKeys.OTHERS,\n        mask_value: str | None = NuclickKeys.MASK_VALUE,\n        min_area: int = 5,\n        others_value: int = 0,\n        to_binary_mask: bool = True,\n    ):\n        super().__init__(keys, allow_missing_keys=False)\n        self.others = others\n        self.mask_value = mask_value\n        self.min_area = min_area\n        self.others_value = others_value\n        self.to_binary_mask = to_binary_mask\n\n    def __call__(self, data):\n        d = dict(data)\n\n        if len(self.keys) > 1:\n            print(\"Only 'label' key is supported, more than 1 key was found\")\n            return None\n\n        for key in self.keys:\n            label = d[key] if isinstance(d[key], torch.Tensor) else torch.from_numpy(d[key])\n\n            mask = torch.clone(label)\n            if self.mask_value:\n                mask_value = d[self.mask_value]\n                mask[label != mask_value] = 0\n            else:\n                mask[label >= self.others_value] = 0\n                mask_value = int(torch.max(mask))\n\n            if self.to_binary_mask:\n                mask[mask > 0] = 1\n\n            others = torch.clone(label)\n            others[label == mask_value] = 0\n            
others[others > 0] = 1\n            if torch.count_nonzero(others):\n                others = measure.label(convert_to_numpy(others)[0], connectivity=1)\n                others = torch.from_numpy(others)[None]\n\n            label = mask.type(torch.uint8) if isinstance(mask, torch.Tensor) else mask\n            others = others.type(torch.uint8) if isinstance(others, torch.Tensor) else others\n\n            d[key] = label if isinstance(d[key], torch.Tensor) else convert_to_numpy(label)\n            d[self.others] = others if isinstance(d[key], torch.Tensor) else convert_to_numpy(others)\n\n        return d\n\n\nclass FilterImaged(MapTransform):\n    \"\"\"\n    Filters Green and Gray channel of the image using an allowable object size, this pre-processing transform\n    is specific towards NuClick training process. More details can be referred in this paper Koohbanani,\n    Navid Alemi, et al. \"NuClick: a deep learning framework for interactive segmentation of microscopic images.\"\n    Medical Image Analysis 65 (2020): 101771.\n\n    Args:\n        min_size: The smallest allowable object size\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, min_size: int = 500, allow_missing_keys: bool = False):\n        super().__init__(keys, allow_missing_keys)\n        self.min_size = min_size\n\n    def __call__(self, data):\n        d = dict(data)\n        for key in self.keys:\n            img = convert_to_numpy(d[key]) if isinstance(d[key], torch.Tensor) else d[key]\n            d[key] = self.filter(img)\n        return d\n\n    def filter(self, rgb):\n        mask_not_green = self.filter_green_channel(rgb)\n        mask_not_gray = self.filter_grays(rgb)\n        mask_gray_green = mask_not_gray & mask_not_green\n        mask = (\n            self.filter_remove_small_objects(mask_gray_green, min_size=self.min_size)\n            if self.min_size\n            else mask_gray_green\n        )\n\n    
    return rgb * np.dstack([mask, mask, mask])\n\n    def filter_green_channel(\n        self, img_np, green_thresh=200, avoid_overmask=True, overmask_thresh=90, output_type=\"bool\"\n    ):\n        g = img_np[:, :, 1]\n        gr_ch_mask = (g < green_thresh) & (g > 0)\n        mask_percentage = self.mask_percent(gr_ch_mask)\n        if (mask_percentage >= overmask_thresh) and (green_thresh < 255) and (avoid_overmask is True):\n            new_green_thresh = math.ceil((255 - green_thresh) / 2 + green_thresh)\n            gr_ch_mask = self.filter_green_channel(\n                img_np, new_green_thresh, avoid_overmask, overmask_thresh, output_type\n            )\n        return gr_ch_mask\n\n    def filter_grays(self, rgb, tolerance=15):\n        rg_diff = abs(rgb[:, :, 0] - rgb[:, :, 1]) <= tolerance\n        rb_diff = abs(rgb[:, :, 0] - rgb[:, :, 2]) <= tolerance\n        gb_diff = abs(rgb[:, :, 1] - rgb[:, :, 2]) <= tolerance\n        return ~(rg_diff & rb_diff & gb_diff)\n\n    def mask_percent(self, img_np):\n        if (len(img_np.shape) == 3) and (img_np.shape[2] == 3):\n            np_sum = img_np[:, :, 0] + img_np[:, :, 1] + img_np[:, :, 2]\n            mask_percentage = 100 - np.count_nonzero(np_sum) / np_sum.size * 100\n        else:\n            mask_percentage = 100 - np.count_nonzero(img_np) / img_np.size * 100\n        return mask_percentage\n\n    def filter_remove_small_objects(self, img_np, min_size=3000, avoid_overmask=True, overmask_thresh=95):\n        rem_sm = morphology.remove_small_objects(img_np.astype(bool), min_size=min_size)\n        mask_percentage = self.mask_percent(rem_sm)\n        if (mask_percentage >= overmask_thresh) and (min_size >= 1) and (avoid_overmask is True):\n            new_min_size = round(min_size / 2)\n            rem_sm = self.filter_remove_small_objects(img_np, new_min_size, avoid_overmask, overmask_thresh)\n        return rem_sm\n\n\nclass AddPointGuidanceSignald(Randomizable, MapTransform):\n    \"\"\"\n    Adds 
Guidance Signal to the input image\n\n    Args:\n        image: key of source image, defaults to ``\"image\"``\n        label: key of source label, defaults to ``\"label\"``\n        others: source others (other labels from the binary mask which are not being used for training)\n            defaults to ``\"others\"``\n        drop_rate: probability of dropping the signal, defaults to ``0.5``\n        jitter_range: noise added to the points in the point mask for exclusion mask, defaults to ``3``\n        gaussian: add gaussian\n        sigma: sigma value for gaussian\n        truncated: spreads how many stds for gaussian\n        add_exclusion_map: add exclusion map/signal\n    \"\"\"\n\n    def __init__(\n        self,\n        image: str = NuclickKeys.IMAGE,\n        label: str = NuclickKeys.LABEL,\n        others: str = NuclickKeys.OTHERS,\n        drop_rate: float = 0.5,\n        jitter_range: int = 0,\n        gaussian: bool = False,\n        sigma: float = 1.0,\n        truncated: float = 2.0,\n        add_exclusion_map: bool = True,\n        use_distance: bool = False,\n    ):\n        MapTransform.__init__(self, image)\n\n        self.image = image\n        self.label = label\n        self.others = others\n        self.drop_rate = drop_rate\n        self.jitter_range = jitter_range\n        self.gaussian = gaussian\n        self.sigma = sigma\n        self.truncated = truncated\n        self.add_exclusion_map = add_exclusion_map\n        self.use_distance = use_distance\n\n    def __call__(self, data):\n        d = dict(data)\n\n        image = d[self.image] if isinstance(d[self.image], torch.Tensor) else torch.from_numpy(d[self.image])\n        mask = d[self.label] if isinstance(d[self.label], torch.Tensor) else torch.from_numpy(d[self.label])\n\n        inc_sig = self.inclusion_map(mask[0], dtype=image.dtype)\n        inc_sig = self._apply_gaussian(inc_sig)\n        if self.add_exclusion_map:\n            others = d[self.others] if 
isinstance(d[self.others], torch.Tensor) else torch.from_numpy(d[self.others])\n            exc_sig = self.exclusion_map(\n                others[0], dtype=image.dtype, drop_rate=self.drop_rate, jitter_range=self.jitter_range\n            )\n            exc_sig = self._apply_gaussian(exc_sig)\n            image = torch.cat((image, inc_sig[None], exc_sig[None]), dim=0)\n        else:\n            image = torch.cat((image, inc_sig[None]), dim=0)\n\n        d[self.image] = image if isinstance(d[self.image], torch.Tensor) else convert_to_numpy(image)\n        return d\n\n    def _apply_gaussian(self, t):\n        if not self.gaussian or torch.count_nonzero(t) == 0:\n            return t\n        x = GaussianFilter(spatial_dims=2, truncated=self.truncated, sigma=self.sigma)(t.unsqueeze(0).unsqueeze(0))\n        return x.squeeze(0).squeeze(0)\n\n    def _seed_point(self, label):\n        if distance_transform_cdt is None or not self.use_distance:\n            indices: NdarrayOrTensor\n            if hasattr(torch, \"argwhere\"):\n                indices = torch.argwhere(label > 0)\n            else:\n                indices = np.argwhere(convert_to_numpy(label) > 0)\n\n            if len(indices) > 0:\n                index = self.R.randint(0, len(indices))\n                return indices[index, 0], indices[index, 1]\n            return None\n\n        distance = distance_transform_cdt(label).flatten()\n        probability = np.exp(distance) - 1.0\n\n        idx = np.where(label.flatten() > 0)[0]\n        seed = self.R.choice(idx, size=1, p=probability[idx] / np.sum(probability[idx]))\n        g = np.asarray(np.unravel_index(seed, label.shape)).transpose().tolist()[0]\n        return g[-2], g[-1]\n\n    def inclusion_map(self, mask, dtype):\n        point_mask = torch.zeros_like(mask, dtype=dtype)\n        pt = self._seed_point(mask)\n        if pt is not None:\n            point_mask[pt[0], pt[1]] = 1\n\n        return point_mask\n\n    def exclusion_map(self, others, 
dtype, jitter_range, drop_rate):\n        point_mask = torch.zeros_like(others, dtype=dtype)\n        if np.random.choice([True, False], p=[drop_rate, 1 - drop_rate]):\n            return point_mask\n\n        max_x = point_mask.shape[0] - 1\n        max_y = point_mask.shape[1] - 1\n        stats = measure.regionprops(convert_to_numpy(others))\n        for stat in stats:\n            if np.random.choice([True, False], p=[drop_rate, 1 - drop_rate]):\n                continue\n\n            # random jitter\n            x, y = stat.centroid\n            x = int(math.floor(x))\n            y = int(math.floor(y))\n            if jitter_range:\n                x = x + self.R.randint(low=-jitter_range, high=jitter_range)\n                y = y + self.R.randint(low=-jitter_range, high=jitter_range)\n                x = min(max(0, x), max_x)\n                y = min(max(0, y), max_y)\n            point_mask[x, y] = 1\n\n        return point_mask\n\n\nclass AddClickSignalsd(MapTransform):\n    \"\"\"\n    Adds Click Signal to the input image\n\n    Args:\n        image: source image, defaults to ``\"image\"``\n        foreground: 2D click indices as list, defaults to ``\"foreground\"``\n        bb_size: single integer size, defines a bounding box like (bb_size, bb_size)\n        gaussian: add gaussian\n        sigma: sigma value for gaussian\n        truncated: spreads how many stds for gaussian\n        add_exclusion_map: add exclusion map/signal\n    \"\"\"\n\n    def __init__(\n        self,\n        image: str = NuclickKeys.IMAGE,\n        foreground: str = NuclickKeys.FOREGROUND,\n        bb_size: int = 128,\n        gaussian: bool = False,\n        sigma: float = 1.0,\n        truncated: float = 2.0,\n        add_exclusion_map: bool = True,\n    ):\n        self.image = image\n        self.foreground = foreground\n        self.bb_size = bb_size\n        self.gaussian = gaussian\n        self.sigma = sigma\n        self.truncated = truncated\n        
self.add_exclusion_map = add_exclusion_map\n\n    def __call__(self, data):\n        d = dict(data)\n\n        img = d[self.image] if isinstance(d[self.image], torch.Tensor) else torch.from_numpy(d[self.image])\n        x = img.shape[-2]\n        y = img.shape[-1]\n\n        location = d.get(NuclickKeys.LOCATION.value, (0, 0))\n        tx, ty = location[0], location[1]\n        pos = d.get(self.foreground)\n        pos = (np.array(pos) - (tx, ty)).astype(int).tolist() if pos else []\n\n        cx = [xy[0] for xy in pos]\n        cy = [xy[1] for xy in pos]\n\n        click_map, bounding_boxes = self.get_clickmap_boundingbox(img, cx=cx, cy=cy, x=x, y=y, bb=self.bb_size)\n        if not bounding_boxes:\n            raise ValueError(\"Failed to create patches from given click points\")\n\n        patches = self.get_patches_and_signals(\n            img=img, click_map=click_map, bounding_boxes=bounding_boxes, cx=cx, cy=cy, x=x, y=y\n        )\n\n        d[NuclickKeys.BOUNDING_BOXES.value] = bounding_boxes\n        d[NuclickKeys.IMG_WIDTH.value] = x\n        d[NuclickKeys.IMG_HEIGHT.value] = y\n\n        d[self.image] = patches if isinstance(d[self.image], torch.Tensor) else convert_to_numpy(patches)\n        return d\n\n    def get_clickmap_boundingbox(self, img, cx, cy, x, y, bb=128):\n        click_map = torch.zeros_like(img[0])\n\n        x_del_indices = {i for i in range(len(cx)) if cx[i] >= x or cx[i] < 0}\n        y_del_indices = {i for i in range(len(cy)) if cy[i] >= y or cy[i] < 0}\n        del_indices = list(x_del_indices.union(y_del_indices))\n        cx = np.delete(cx, del_indices)\n        cy = np.delete(cy, del_indices)\n\n        click_map[cx, cy] = 1\n        bounding_boxes = []\n        for i in range(len(cx)):\n            x_start = max(0, cx[i] - bb // 2)\n            y_start = max(0, cy[i] - bb // 2)\n            x_end = min(x_start + bb, x)\n            y_end = min(y_start + bb, y)\n\n            if x_end - x_start != bb:\n                x_start = 
x_end - bb\n            if y_end - y_start != bb:\n                y_start = y_end - bb\n            if x_end - x_start == bb and y_end - y_start == bb:\n                bounding_boxes.append([x_start, y_start, x_end, y_end])\n            else:\n                print(f\"Ignore smaller sized bbox ({x_start}, {y_start}, {x_end}, {y_end}) (Min size: {bb}x{bb})\")\n        return click_map, bounding_boxes\n\n    def get_patches_and_signals(self, img, click_map, bounding_boxes, cx, cy, x, y):\n        patches = []\n\n        x_del_indices = {i for i in range(len(cx)) if cx[i] >= x or cx[i] < 0}\n        y_del_indices = {i for i in range(len(cy)) if cy[i] >= y or cy[i] < 0}\n        del_indices = list(x_del_indices.union(y_del_indices))\n        cx = np.delete(cx, del_indices)\n        cy = np.delete(cy, del_indices)\n\n        for i, bounding_box in enumerate(bounding_boxes):\n            x_start = bounding_box[0]\n            y_start = bounding_box[1]\n            x_end = bounding_box[2]\n            y_end = bounding_box[3]\n\n            patch = img[:, x_start:x_end, y_start:y_end]\n\n            this_click_map = torch.zeros_like(img[0])\n            this_click_map[cx[i], cy[i]] = 1\n\n            nuc_points = this_click_map[x_start:x_end, y_start:y_end]\n            nuc_points = self._apply_gaussian(nuc_points)\n\n            if self.add_exclusion_map:\n                others_click_map = ((click_map - this_click_map) > 0).type(img.dtype)\n                other_points = others_click_map[x_start:x_end, y_start:y_end]\n                other_points = self._apply_gaussian(other_points)\n                patches.append(torch.cat([patch, nuc_points[None], other_points[None]]))\n            else:\n                patches.append(torch.cat([patch, nuc_points[None]]))\n\n        return torch.stack(patches)\n\n    def _apply_gaussian(self, t):\n        if not self.gaussian or torch.count_nonzero(t) == 0:\n            return t\n        x = GaussianFilter(spatial_dims=2, 
truncated=self.truncated, sigma=self.sigma)(t.unsqueeze(0).unsqueeze(0))\n        return x.squeeze(0).squeeze(0)\n\n\nclass PostFilterLabeld(MapTransform):\n    \"\"\"\n    Performs Filtering of Labels on the predicted probability map\n\n    Args:\n        thresh: probability threshold for classifying a pixel as a mask\n        min_size: min_size objects that will be removed from the image, refer skimage remove_small_objects\n        min_hole: min_hole that will be removed from the image, refer skimage remove_small_holes\n        do_reconstruction: Boolean Flag, Perform a morphological reconstruction of an image, refer skimage\n        allow_missing_keys: don't raise exception if key is missing.\n        pred_classes: List of Predicted class for each instance\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        nuc_points: str = NuclickKeys.NUC_POINTS,\n        bounding_boxes: str = NuclickKeys.BOUNDING_BOXES,\n        img_height: str = NuclickKeys.IMG_HEIGHT,\n        img_width: str = NuclickKeys.IMG_WIDTH,\n        thresh: float = 0.33,\n        min_size: int = 10,\n        min_hole: int = 30,\n        do_reconstruction: bool = False,\n        allow_missing_keys: bool = False,\n        pred_classes: str = NuclickKeys.PRED_CLASSES,\n    ):\n        super().__init__(keys, allow_missing_keys)\n        self.nuc_points = nuc_points\n        self.bounding_boxes = bounding_boxes\n        self.img_height = img_height\n        self.img_width = img_width\n\n        self.thresh = thresh\n        self.min_size = min_size\n        self.min_hole = min_hole\n        self.do_reconstruction = do_reconstruction\n        self.pred_classes = pred_classes\n\n    def __call__(self, data):\n        d = dict(data)\n\n        pred_classes = d.get(self.pred_classes)\n        bounding_boxes = d[self.bounding_boxes]\n        x = d[self.img_width]\n        y = d[self.img_height]\n\n        for key in self.keys:\n            label = d[key].astype(np.uint8)\n 
           masks = self.post_processing(label, self.thresh, self.min_size, self.min_hole)\n            d[key] = self.gen_instance_map(masks, bounding_boxes, x, y, pred_classes=pred_classes).astype(np.uint8)\n        return d\n\n    def post_processing(self, preds, thresh=0.33, min_size=10, min_hole=30):\n        masks = preds > thresh\n        for i in range(preds.shape[0]):\n            masks[i] = morphology.remove_small_objects(masks[i], min_size=min_size)\n            masks[i] = morphology.remove_small_holes(masks[i], area_threshold=min_hole)\n        return masks\n\n    def gen_instance_map(self, masks, bounding_boxes, x, y, flatten=True, pred_classes=None):\n        instance_map = np.zeros((x, y), dtype=np.uint16)\n        for i, mask in enumerate(masks):\n            bb = bounding_boxes[i]\n            c = pred_classes[i] if pred_classes and i < len(pred_classes) else 1\n            c = c if flatten else i + 1\n\n            this_map = instance_map[bb[0] : bb[2], bb[1] : bb[3]]\n            this_map = np.where(mask > 0, c, this_map)\n            instance_map[bb[0] : bb[2], bb[1] : bb[3]] = this_map\n\n        return instance_map\n\n\nclass AddLabelAsGuidanced(MapTransform):\n    \"\"\"\n    Add Label as new guidance channel\n\n    Args:\n        source: label/source key which gets added as additional guidance channel\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, source: str = \"label\") -> None:\n        super().__init__(keys, allow_missing_keys=False)\n        self.source = source\n\n    def __call__(self, data):\n        d = dict(data)\n        for key in self.keys:\n            image = d[key] if isinstance(d[key], torch.Tensor) else torch.from_numpy(d[key])\n            label = d[self.source] if isinstance(d[self.source], torch.Tensor) else torch.from_numpy(d[self.source])\n\n            label = label > 0\n            if len(label.shape) < len(image.shape):\n                label = label[None]\n            image = torch.cat([image, 
label.type(image.dtype)], dim=len(label.shape) - 3)\n            d[key] = image if isinstance(d[key], torch.Tensor) else convert_to_numpy(image)\n        return d\n\n\nclass SetLabelClassd(MapTransform):\n    \"\"\"\n    Assign class value from the labelmap.  This converts multi-dimension tensor to single scalar tensor.\n\n    Args:\n        offset: offset value to be added to the mask value to determine the final class\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, offset: int = -1) -> None:\n        super().__init__(keys, allow_missing_keys=False)\n        self.offset = offset\n\n    def __call__(self, data):\n        d = dict(data)\n        for key in self.keys:\n            label = d[key] if isinstance(d[key], torch.Tensor) else torch.from_numpy(d[key])\n            mask_value = int(torch.max(label))\n            d[key] = mask_value + self.offset\n        return d\n"
  },
  {
    "path": "monai/apps/pathology/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .losses import HoVerNetLoss\nfrom .metrics import LesionFROC\nfrom .transforms.stain.array import ExtractHEStains, NormalizeHEStains\nfrom .transforms.stain.dictionary import (\n    ExtractHEStainsd,\n    ExtractHEStainsD,\n    ExtractHEStainsDict,\n    NormalizeHEStainsd,\n    NormalizeHEStainsD,\n    NormalizeHEStainsDict,\n)\nfrom .utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask\n"
  },
  {
    "path": "monai/apps/pathology/engines/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .utils import PrepareBatchHoVerNet\n"
  },
  {
    "path": "monai/apps/pathology/engines/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport torch\n\nfrom monai.engines import PrepareBatch, PrepareBatchExtraInput\nfrom monai.utils import ensure_tuple\nfrom monai.utils.enums import HoVerNetBranch\n\n__all__ = [\"PrepareBatchHoVerNet\"]\n\n\nclass PrepareBatchHoVerNet(PrepareBatch):\n    \"\"\"\n    Customized prepare batch callable for trainers or evaluators which support label to be a dictionary.\n    Extra items are specified by the `extra_keys` parameter and are extracted from the input dictionary (ie. 
the batch).\n    This assumes label is a dictionary.\n\n    Args:\n        extra_keys: If a sequence of strings is provided, values from the input dictionary are extracted from\n            those keys and passed to the network as extra positional arguments.\n    \"\"\"\n\n    def __init__(self, extra_keys: Sequence[str]) -> None:\n        if len(ensure_tuple(extra_keys)) != 2:\n            raise ValueError(f\"length of `extra_keys` should be 2, get {len(ensure_tuple(extra_keys))}\")\n        self.prepare_batch = PrepareBatchExtraInput(extra_keys)\n\n    def __call__(\n        self,\n        batchdata: dict[str, torch.Tensor],\n        device: str | torch.device | None = None,\n        non_blocking: bool = False,\n        **kwargs: Any,\n    ) -> tuple[torch.Tensor, dict[HoVerNetBranch, torch.Tensor]]:\n        \"\"\"\n        Args `batchdata`, `device`, `non_blocking` refer to the ignite API:\n        https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.\n        `kwargs` supports other args for `Tensor.to()` API.\n        \"\"\"\n        image, _label, extra_label, _ = self.prepare_batch(batchdata, device, non_blocking, **kwargs)\n        label = {HoVerNetBranch.NP: _label, HoVerNetBranch.NC: extra_label[0], HoVerNetBranch.HV: extra_label[1]}\n\n        return image, label\n"
  },
  {
    "path": "monai/apps/pathology/handlers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n"
  },
  {
    "path": "monai/apps/pathology/handlers/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Hashable\nfrom typing import Any\n\nfrom monai.config import KeysCollection\nfrom monai.utils import ensure_tuple\n\n\ndef from_engine_hovernet(keys: KeysCollection, nested_key: str) -> Callable[[Any], Any]:\n    \"\"\"\n    Since the output of HoVerNet is a dictionary, this function is to extend `monai.handlers.from_engine`\n    to work with HoVerNet.\n\n    If data is a list of nested dictionaries after decollating, extract nested value with expected keys and\n    construct lists respectively, for example,\n    if data is `[{\"A\": {\"C\": 1, \"D\": 2}, \"B\": {\"C\": 2, \"D\": 2}}, {\"A\":  {\"C\": 3, \"D\": 2}, \"B\":  {\"C\": 4, \"D\": 2}}]`,\n    from_engine_hovernet([\"A\", \"B\"], \"C\"): `([1, 3], [2, 4])`.\n\n    Here is a simple example::\n\n        from monai.handlers import MeanDice, from_engine_hovernet\n\n        metric = MeanDice(\n            include_background=False,\n            output_transform=from_engine_hovernet(keys=[\"pred\", \"label\"], nested_key=HoVerNetBranch.NP.value)\n        )\n\n    Args:\n        keys: specified keys to extract data from dictionary or decollated list of dictionaries.\n        nested_key: specified key to extract nested data from dictionary or decollated list of dictionaries.\n\n    \"\"\"\n    _keys: tuple[Hashable, ...] 
= ensure_tuple(keys)\n\n    def _wrapper(data):\n        if isinstance(data, dict):\n            return tuple(data[k][nested_key] for k in _keys)\n        if isinstance(data, list) and isinstance(data[0], dict):\n            # if data is a list of dictionaries, extract expected keys and construct lists,\n            ret = [[i[k][nested_key] for i in data] for k in _keys]\n            return tuple(ret) if len(ret) > 1 else ret[0]\n\n    return _wrapper\n"
  },
  {
    "path": "monai/apps/pathology/inferers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .inferer import SlidingWindowHoVerNetInferer\n"
  },
  {
    "path": "monai/apps/pathology/inferers/inferer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\nfrom typing import Any, Callable\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom monai.inferers import SlidingWindowInferer\nfrom monai.inferers.utils import sliding_window_inference\nfrom monai.utils import BlendMode, PytorchPadMode, look_up_option\n\n__all__ = [\"SlidingWindowHoVerNetInferer\"]\n\n\nclass SlidingWindowHoVerNetInferer(SlidingWindowInferer):\n    \"\"\"\n    Sliding window method for HoVerNet model inference,\n    with `sw_batch_size` windows for every model.forward().\n    Usage example can be found in the :py:class:`monai.inferers.Inferer` base class.\n\n    Args:\n        roi_size: the window size to execute SlidingWindow evaluation.\n            If it has non-positive components, the corresponding `inputs` size will be used.\n            if the components of the `roi_size` are non-positive values, the transform will use the\n            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        sw_batch_size: the batch size to run window slices.\n        overlap: Amount of overlap between scans.\n        mode: {``\"constant\"``, ``\"gaussian\"``}\n            How to blend output of overlapping windows. 
Defaults to ``\"constant\"``.\n\n            - ``\"constant\"``: gives equal weight to all predictions.\n            - ``\"gaussian\"``: gives less weight to predictions on edges of windows.\n\n        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``\"gaussian\"``.\n            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.\n            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding\n            spatial dimensions.\n        padding_mode: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}\n            Padding mode when ``roi_size`` is larger than inputs. Defaults to ``\"constant\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        cval: fill value for 'constant' padding mode. Default: 0\n        sw_device: device for the window data.\n            By default the device (and accordingly the memory) of the `inputs` is used.\n            Normally `sw_device` should be consistent with the device where `predictor` is defined.\n        device: device for the stitched output prediction.\n            By default the device (and accordingly the memory) of the `inputs` is used. If for example\n            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the\n            `inputs` and `roi_size`. Output is on the `device`.\n        progress: whether to print a tqdm progress bar.\n        cache_roi_weight_map: whether to pre-compute the ROI weight map.\n        cpu_thresh: when provided, dynamically switch to stitching on cpu (to save gpu memory)\n            when input image volume is larger than this threshold (in pixels/voxels).\n            Otherwise use ``\"device\"``. 
Thus, the output may end up on either cpu or gpu.\n        extra_input_padding: the amount of padding for the input image, which is a tuple with an even number of pads.\n            Refer to the `pad` argument of `torch.nn.functional.pad` for more details.\n\n    Note:\n        ``sw_batch_size`` denotes the max number of windows per network inference iteration,\n        not the batch size of inputs.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        roi_size: Sequence[int] | int,\n        sw_batch_size: int = 1,\n        overlap: float = 0.25,\n        mode: BlendMode | str = BlendMode.CONSTANT,\n        sigma_scale: Sequence[float] | float = 0.125,\n        padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT,\n        cval: float = 0.0,\n        sw_device: torch.device | str | None = None,\n        device: torch.device | str | None = None,\n        progress: bool = False,\n        cache_roi_weight_map: bool = False,\n        cpu_thresh: int | None = None,\n        extra_input_padding: tuple[int] | None = None,\n    ) -> None:\n        super().__init__(\n            roi_size=roi_size,\n            sw_batch_size=sw_batch_size,\n            overlap=overlap,\n            mode=mode,\n            sigma_scale=sigma_scale,\n            padding_mode=padding_mode,\n            cval=cval,\n            sw_device=sw_device,\n            device=device,\n            progress=progress,\n            cache_roi_weight_map=cache_roi_weight_map,\n            cpu_thresh=cpu_thresh,\n        )\n        self.extra_input_padding = extra_input_padding\n\n    def process_output(self, seg_prob_tuple, window_data, importance_map_):\n        window_shape = window_data.shape[2:]\n        seg_shape = seg_prob_tuple[0].shape[2:]\n\n        window_pad_size = []\n        window_pad_slices = []\n        for window_s, output_s in zip(window_shape, seg_shape):\n            pad_width = max(window_s - output_s, 0)\n            pad_half_1 = pad_width // 2\n            pad_half_2 = 
pad_width - pad_half_1\n            window_pad_size.extend([pad_half_1, pad_half_2])\n            window_pad_slices.append(slice(pad_half_1, window_s - pad_half_2))\n\n        # Make the padding area of the importance map zero\n        importance_map = torch.zeros(window_shape, dtype=importance_map_.dtype, device=importance_map_.device)\n        importance_map[window_pad_slices] = importance_map_[window_pad_slices]\n\n        seg_prob_tuple = tuple(\n            F.pad(seg_prob, pad=tuple(window_pad_size), mode=self.padding_mode, value=self.cval)\n            for seg_prob in seg_prob_tuple\n        )\n\n        return seg_prob_tuple, importance_map\n\n    def __call__(\n        self,\n        inputs: torch.Tensor,\n        network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],\n        *args: Any,\n        **kwargs: Any,\n    ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:\n        \"\"\"\n\n        Args:\n            inputs: model input data for inference.\n            network: target model to execute inference.\n                supports callables such as ``lambda x: my_torch_model(x, additional_config)``\n            args: optional args to be passed to ``network``.\n            kwargs: optional keyword args to be passed to ``network``.\n\n        \"\"\"\n\n        device = self.device\n        if device is None and self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh:\n            device = \"cpu\"  # stitch in cpu memory if image is too large\n\n        if self.extra_input_padding:\n            image_size_original = inputs.shape[2:]\n            num_spatial_dims = len(image_size_original)\n            inputs = F.pad(\n                inputs,\n                pad=tuple(self.extra_input_padding),\n                mode=look_up_option(self.padding_mode, PytorchPadMode),\n                value=self.cval,\n            )\n\n        results = sliding_window_inference(\n            inputs,\n    
        self.roi_size,\n            self.sw_batch_size,\n            network,\n            self.overlap,\n            self.mode,\n            self.sigma_scale,\n            self.padding_mode,\n            self.cval,\n            self.sw_device,\n            device,\n            self.progress,\n            self.roi_weight_map,\n            self.process_output,\n            self.buffer_steps,\n            self.buffer_dim,\n            False,\n            *args,\n            **kwargs,\n        )\n\n        if self.extra_input_padding:\n            extra_slicing: list[slice] = []\n            num_padded_dims = len(self.extra_input_padding) // 2\n            for sp in range(num_padded_dims):\n                slice_dim = slice(\n                    self.extra_input_padding[sp * 2],\n                    image_size_original[num_spatial_dims - sp - 1] + self.extra_input_padding[sp * 2],\n                )\n                extra_slicing.insert(0, slice_dim)\n            for _ in range(len(inputs.shape) - num_padded_dims):\n                extra_slicing.insert(0, slice(None))\n\n            if isinstance(results, dict):\n                for k, v in results.items():\n                    results[k] = v[extra_slicing]\n            elif isinstance(results, (list, tuple)):\n                results = type(results)([res[extra_slicing] for res in results])\n            elif isinstance(results, (torch.Tensor, np.ndarray)):\n                results = results[extra_slicing]\n            else:\n                raise ValueError(\n                    f\"The output [{type(results)}] should be either dict, list, tuple, torch.Tensor, or numpy array.\"\n                )\n\n        return results\n"
  },
  {
    "path": "monai/apps/pathology/losses/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .hovernet_loss import HoVerNetLoss\n"
  },
  {
    "path": "monai/apps/pathology/losses/hovernet_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch.nn import CrossEntropyLoss\nfrom torch.nn import functional as F\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.losses import DiceLoss\nfrom monai.transforms import SobelGradients\nfrom monai.utils.enums import HoVerNetBranch\n\n\nclass HoVerNetLoss(_Loss):\n    \"\"\"\n    Loss function for HoVerNet pipeline, which is combination of losses across the three branches.\n    The NP (nucleus prediction) branch uses Dice + CrossEntropy.\n    The HV (Horizontal and Vertical) distance from centroid branch uses MSE + MSE of the gradient.\n    The NC (Nuclear Class prediction) branch uses Dice + CrossEntropy\n    The result is a weighted sum of these losses.\n\n    Args:\n        lambda_hv_mse: Weight factor to apply to the HV regression MSE part of the overall loss\n        lambda_hv_mse_grad: Weight factor to apply to the MSE of the HV gradient part of the overall loss\n        lambda_np_ce: Weight factor to apply to the nuclei prediction CrossEntropyLoss part\n            of the overall loss\n        lambda_np_dice: Weight factor to apply to the nuclei prediction DiceLoss part of overall loss\n        lambda_nc_ce: Weight factor to apply to the nuclei class prediction CrossEntropyLoss part\n            of the overall loss\n        lambda_nc_dice: Weight factor to apply to the nuclei class prediction DiceLoss 
part of the\n            overall loss\n\n    \"\"\"\n\n    def __init__(\n        self,\n        lambda_hv_mse: float = 2.0,\n        lambda_hv_mse_grad: float = 1.0,\n        lambda_np_ce: float = 1.0,\n        lambda_np_dice: float = 1.0,\n        lambda_nc_ce: float = 1.0,\n        lambda_nc_dice: float = 1.0,\n    ) -> None:\n        self.lambda_hv_mse = lambda_hv_mse\n        self.lambda_hv_mse_grad = lambda_hv_mse_grad\n        self.lambda_np_ce = lambda_np_ce\n        self.lambda_np_dice = lambda_np_dice\n        self.lambda_nc_ce = lambda_nc_ce\n        self.lambda_nc_dice = lambda_nc_dice\n        super().__init__()\n\n        self.dice = DiceLoss(softmax=True, smooth_dr=1e-03, smooth_nr=1e-03, reduction=\"sum\", batch=True)\n        self.ce = CrossEntropyLoss(reduction=\"mean\")\n        self.sobel_v = SobelGradients(kernel_size=5, spatial_axes=0)\n        self.sobel_h = SobelGradients(kernel_size=5, spatial_axes=1)\n\n    def _compute_sobel(self, image: torch.Tensor) -> torch.Tensor:\n        \"\"\"Compute the Sobel gradients of the horizontal vertical map (HoVerMap).\n        More specifically, it will compute horizontal gradient of the input horizontal gradient map (channel=0) and\n        vertical gradient of the input vertical gradient map (channel=1).\n\n        Args:\n            image: a tensor with the shape of BxCxHxW representing HoVerMap\n\n        \"\"\"\n        result_h = self.sobel_h(image[:, 0])\n        result_v = self.sobel_v(image[:, 1])\n        return torch.stack([result_h, result_v], dim=1)\n\n    def _mse_gradient_loss(self, prediction: torch.Tensor, target: torch.Tensor, focus: torch.Tensor) -> torch.Tensor:\n        \"\"\"Compute the MSE loss of the gradients of the horizontal and vertical centroid distance maps\"\"\"\n\n        pred_grad = self._compute_sobel(prediction)\n        true_grad = self._compute_sobel(target)\n\n        loss = pred_grad - true_grad\n\n        # The focus constrains the loss computation to the detected 
nuclear regions\n        # (i.e. background is excluded)\n        focus = focus[:, None, ...]\n        focus = torch.cat((focus, focus), 1)\n\n        loss = focus * (loss * loss)\n        loss = loss.sum() / (focus.sum() + 1.0e-8)\n\n        return loss\n\n    def forward(self, prediction: dict[str, torch.Tensor], target: dict[str, torch.Tensor]) -> torch.Tensor:\n        \"\"\"\n        Args:\n            prediction: dictionary of predicted outputs for three branches,\n                each of which should have the shape of BNHW.\n            target: dictionary of ground truths for three branches,\n                each of which should have the shape of BNHW.\n        \"\"\"\n\n        if not (HoVerNetBranch.NP.value in prediction and HoVerNetBranch.HV.value in prediction):\n            raise ValueError(\n                \"nucleus prediction (NP) and horizontal_vertical (HV) branches must be \"\n                \"present for prediction and target parameters\"\n            )\n        if not (HoVerNetBranch.NP.value in target and HoVerNetBranch.HV.value in target):\n            raise ValueError(\n                \"nucleus prediction (NP) and horizontal_vertical (HV) branches must be \"\n                \"present for prediction and target parameters\"\n            )\n        if HoVerNetBranch.NC.value not in target and HoVerNetBranch.NC.value in target:\n            raise ValueError(\n                \"type_prediction (NC) must be present in both or neither of the prediction and target parameters\"\n            )\n        if HoVerNetBranch.NC.value in target and HoVerNetBranch.NC.value not in target:\n            raise ValueError(\n                \"type_prediction (NC) must be present in both or neither of the prediction and target parameters\"\n            )\n\n        # Compute the NP branch loss\n        dice_loss_np = (\n            self.dice(prediction[HoVerNetBranch.NP.value], target[HoVerNetBranch.NP.value]) * self.lambda_np_dice\n        )\n        # convert 
to target class indices\n        argmax_target = target[HoVerNetBranch.NP.value].argmax(dim=1)\n        ce_loss_np = self.ce(prediction[HoVerNetBranch.NP.value], argmax_target) * self.lambda_np_ce\n        loss_np = dice_loss_np + ce_loss_np\n\n        # Compute the HV branch loss\n        loss_hv_mse = (\n            F.mse_loss(prediction[HoVerNetBranch.HV.value], target[HoVerNetBranch.HV.value]) * self.lambda_hv_mse\n        )\n\n        # Use the nuclei class, one hot encoded, as the mask\n        loss_hv_mse_grad = (\n            self._mse_gradient_loss(\n                prediction[HoVerNetBranch.HV.value],\n                target[HoVerNetBranch.HV.value],\n                target[HoVerNetBranch.NP.value][:, 1],\n            )\n            * self.lambda_hv_mse_grad\n        )\n        loss_hv = loss_hv_mse_grad + loss_hv_mse\n\n        # Compute the NC branch loss\n        loss_nc = 0\n        if HoVerNetBranch.NC.value in prediction:\n            dice_loss_nc = (\n                self.dice(prediction[HoVerNetBranch.NC.value], target[HoVerNetBranch.NC.value]) * self.lambda_nc_dice\n            )\n            # Convert to target class indices\n            argmax_target = target[HoVerNetBranch.NC.value].argmax(dim=1)\n            ce_loss_nc = self.ce(prediction[HoVerNetBranch.NC.value], argmax_target) * self.lambda_nc_ce\n            loss_nc = dice_loss_nc + ce_loss_nc\n\n        # Sum the losses from each branch\n        loss: torch.Tensor = loss_hv + loss_np + loss_nc\n\n        return loss\n"
  },
  {
    "path": "monai/apps/pathology/metrics/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .lesion_froc import LesionFROC\n"
  },
  {
    "path": "monai/apps/pathology/metrics/lesion_froc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Iterable\nfrom typing import TYPE_CHECKING, Any\n\nimport numpy as np\n\nfrom monai.apps.pathology.utils import PathologyProbNMS, compute_isolated_tumor_cells, compute_multi_instance_mask\nfrom monai.config import NdarrayOrTensor\nfrom monai.data.wsi_reader import WSIReader\nfrom monai.metrics import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score\nfrom monai.utils import min_version, optional_import\n\nif TYPE_CHECKING:\n    from tqdm import tqdm\n\n    has_tqdm = True\nelse:\n    tqdm, has_tqdm = optional_import(\"tqdm\", \"4.47.0\", min_version, \"tqdm\")\n\nif not has_tqdm:\n\n    def tqdm(x):\n        return x\n\n\nclass LesionFROC:\n    \"\"\"\n    Evaluate with Free Response Operating Characteristic (FROC) score.\n\n    Args:\n        data: either the list of dictionaries containing probability maps (inference result) and\n            tumor mask (ground truth), as below, or the path to a json file containing such list.\n            `{\n            \"prob_map\": \"path/to/prob_map_1.npy\",\n            \"tumor_mask\": \"path/to/ground_truth_1.tiff\",\n            \"level\": 6,\n            \"pixel_spacing\": 0.243\n            }`\n        grow_distance: Euclidean distance (in micrometer) by which to grow the label the ground truth's tumors.\n            Defaults to 75, which is the 
equivalent size of 5 tumor cells.\n        itc_diameter: the maximum diameter of a region (in micrometer) to be considered as an isolated tumor cell.\n            Defaults to 200.\n        eval_thresholds: the false positive rates for calculating the average sensitivity.\n            Defaults to (0.25, 0.5, 1, 2, 4, 8) which is the same as the CAMELYON 16 Challenge.\n        nms_sigma: the standard deviation for gaussian filter of non-maximal suppression. Defaults to 0.0.\n        nms_prob_threshold: the probability threshold of non-maximal suppression. Defaults to 0.5.\n        nms_box_size: the box size (in pixel) to be removed around the pixel for non-maximal suppression.\n        image_reader_name: the name of library to be used for loading whole slide imaging, either CuCIM or OpenSlide.\n            Defaults to CuCIM.\n\n    Note:\n        For more info on `nms_*` parameters look at monai.utils.prob_nms.ProbNMS`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        data: list[dict],\n        grow_distance: int = 75,\n        itc_diameter: int = 200,\n        eval_thresholds: tuple = (0.25, 0.5, 1, 2, 4, 8),\n        nms_sigma: float = 0.0,\n        nms_prob_threshold: float = 0.5,\n        nms_box_size: int = 48,\n        image_reader_name: str = \"cuCIM\",\n    ) -> None:\n        self.data = data\n        self.grow_distance = grow_distance\n        self.itc_diameter = itc_diameter\n        self.eval_thresholds = eval_thresholds\n        self.image_reader = WSIReader(image_reader_name)\n        self.nms = PathologyProbNMS(sigma=nms_sigma, prob_threshold=nms_prob_threshold, box_size=nms_box_size)\n\n    def prepare_inference_result(self, sample: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n        \"\"\"\n        Prepare the probability map for detection evaluation.\n\n        \"\"\"\n        # load the probability map (the result of model inference)\n        prob_map = np.load(sample[\"prob_map\"])\n\n        # apply non-maximal suppression\n    
    nms_outputs = self.nms(probs_map=prob_map, resolution_level=sample[\"level\"])\n\n        # separate nms outputs\n        probs: Iterable[Any]\n        x_coord: Iterable[Any]\n        y_coord: Iterable[Any]\n        if nms_outputs:\n            probs, x_coord, y_coord = zip(*nms_outputs)\n        else:\n            probs, x_coord, y_coord = [], [], []\n\n        return np.array(probs), np.array(x_coord), np.array(y_coord)\n\n    def prepare_ground_truth(self, sample):\n        \"\"\"\n        Prepare the ground truth for evaluation based on the binary tumor mask\n\n        \"\"\"\n        # load binary tumor masks\n        img_obj = self.image_reader.read(sample[\"tumor_mask\"])\n        tumor_mask = self.image_reader.get_data(img_obj, level=sample[\"level\"])[0][0]\n\n        # calculate pixel spacing at the mask level\n        mask_pixel_spacing = sample[\"pixel_spacing\"] * pow(2, sample[\"level\"])\n\n        # compute multi-instance mask from a binary mask\n        grow_pixel_threshold = self.grow_distance / (mask_pixel_spacing * 2)\n        tumor_mask = compute_multi_instance_mask(mask=tumor_mask, threshold=grow_pixel_threshold)\n\n        # identify isolated tumor cells\n        itc_threshold = (self.itc_diameter + self.grow_distance) / mask_pixel_spacing\n        itc_labels = compute_isolated_tumor_cells(tumor_mask=tumor_mask, threshold=itc_threshold)\n\n        return tumor_mask, itc_labels\n\n    def compute_fp_tp(self):\n        \"\"\"\n        Compute false positive and true positive probabilities for tumor detection,\n        by comparing the model outputs with the prepared ground truths for all samples\n\n        \"\"\"\n        total_fp_probs: list[NdarrayOrTensor] = []\n        total_tp_probs: list[NdarrayOrTensor] = []\n        total_num_targets = 0\n        num_images = len(self.data)\n\n        for sample in tqdm(self.data):\n            probs, y_coord, x_coord = self.prepare_inference_result(sample)\n            ground_truth, itc_labels = 
self.prepare_ground_truth(sample)\n            # compute FP and TP probabilities for a pair of an image and a ground truth mask\n            fp_probs, tp_probs, num_targets = compute_fp_tp_probs(\n                probs=probs,\n                y_coord=y_coord,\n                x_coord=x_coord,\n                evaluation_mask=ground_truth,\n                labels_to_exclude=itc_labels,\n                resolution_level=sample[\"level\"],\n            )\n            total_fp_probs.extend(fp_probs)\n            total_tp_probs.extend(tp_probs)\n            total_num_targets += num_targets\n\n        return np.array(total_fp_probs), np.array(total_tp_probs), total_num_targets, num_images\n\n    def evaluate(self):\n        \"\"\"\n        Evaluate the detection performance of a model based on the model probability map output,\n        the ground truth tumor mask, and their associated metadata (e.g., pixel_spacing, level)\n        \"\"\"\n        # compute false positive (FP) and true positive (TP) probabilities for all images\n        fp_probs, tp_probs, num_targets, num_images = self.compute_fp_tp()\n\n        # compute FROC curve given the evaluation of all images\n        fps_per_image, total_sensitivity = compute_froc_curve_data(\n            fp_probs=fp_probs, tp_probs=tp_probs, num_targets=num_targets, num_images=num_images\n        )\n\n        # compute FROC score given the specified evaluation thresholds\n        froc_score = compute_froc_score(\n            fps_per_image=fps_per_image, total_sensitivity=total_sensitivity, eval_thresholds=self.eval_thresholds\n        )\n\n        return froc_score\n"
  },
  {
    "path": "monai/apps/pathology/transforms/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .post.array import (\n    GenerateDistanceMap,\n    GenerateInstanceBorder,\n    GenerateInstanceCentroid,\n    GenerateInstanceContour,\n    GenerateInstanceType,\n    GenerateSuccinctContour,\n    GenerateWatershedMarkers,\n    GenerateWatershedMask,\n    HoVerNetInstanceMapPostProcessing,\n    HoVerNetNuclearTypePostProcessing,\n    Watershed,\n)\nfrom .post.dictionary import (\n    GenerateDistanceMapD,\n    GenerateDistanceMapd,\n    GenerateDistanceMapDict,\n    GenerateInstanceBorderD,\n    GenerateInstanceBorderd,\n    GenerateInstanceBorderDict,\n    GenerateInstanceCentroidD,\n    GenerateInstanceCentroidd,\n    GenerateInstanceCentroidDict,\n    GenerateInstanceContourD,\n    GenerateInstanceContourd,\n    GenerateInstanceContourDict,\n    GenerateInstanceTypeD,\n    GenerateInstanceTyped,\n    GenerateInstanceTypeDict,\n    GenerateSuccinctContourD,\n    GenerateSuccinctContourd,\n    GenerateSuccinctContourDict,\n    GenerateWatershedMarkersD,\n    GenerateWatershedMarkersd,\n    GenerateWatershedMarkersDict,\n    GenerateWatershedMaskD,\n    GenerateWatershedMaskd,\n    GenerateWatershedMaskDict,\n    HoVerNetInstanceMapPostProcessingD,\n    HoVerNetInstanceMapPostProcessingd,\n    HoVerNetInstanceMapPostProcessingDict,\n    HoVerNetNuclearTypePostProcessingD,\n    HoVerNetNuclearTypePostProcessingd,\n    
HoVerNetNuclearTypePostProcessingDict,\n    WatershedD,\n    Watershedd,\n    WatershedDict,\n)\nfrom .stain.array import ExtractHEStains, NormalizeHEStains\nfrom .stain.dictionary import (\n    ExtractHEStainsd,\n    ExtractHEStainsD,\n    ExtractHEStainsDict,\n    NormalizeHEStainsd,\n    NormalizeHEStainsD,\n    NormalizeHEStainsDict,\n)\n"
  },
  {
    "path": "monai/apps/pathology/transforms/post/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .array import (\n    GenerateDistanceMap,\n    GenerateInstanceBorder,\n    GenerateInstanceCentroid,\n    GenerateInstanceContour,\n    GenerateInstanceType,\n    GenerateSuccinctContour,\n    GenerateWatershedMarkers,\n    GenerateWatershedMask,\n    HoVerNetInstanceMapPostProcessing,\n    HoVerNetNuclearTypePostProcessing,\n    Watershed,\n)\nfrom .dictionary import (\n    GenerateDistanceMapD,\n    GenerateDistanceMapd,\n    GenerateDistanceMapDict,\n    GenerateInstanceBorderD,\n    GenerateInstanceBorderd,\n    GenerateInstanceBorderDict,\n    GenerateInstanceCentroidD,\n    GenerateInstanceCentroidd,\n    GenerateInstanceCentroidDict,\n    GenerateInstanceContourD,\n    GenerateInstanceContourd,\n    GenerateInstanceContourDict,\n    GenerateInstanceTypeD,\n    GenerateInstanceTyped,\n    GenerateInstanceTypeDict,\n    GenerateSuccinctContourD,\n    GenerateSuccinctContourd,\n    GenerateSuccinctContourDict,\n    GenerateWatershedMarkersD,\n    GenerateWatershedMarkersd,\n    GenerateWatershedMarkersDict,\n    GenerateWatershedMaskD,\n    GenerateWatershedMaskd,\n    GenerateWatershedMaskDict,\n    HoVerNetInstanceMapPostProcessingD,\n    HoVerNetInstanceMapPostProcessingd,\n    HoVerNetInstanceMapPostProcessingDict,\n    HoVerNetNuclearTypePostProcessingD,\n    HoVerNetNuclearTypePostProcessingd,\n    
HoVerNetNuclearTypePostProcessingDict,\n    WatershedD,\n    Watershedd,\n    WatershedDict,\n)\n"
  },
  {
    "path": "monai/apps/pathology/transforms/post/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Sequence\nfrom typing import Callable\n\nimport numpy as np\nimport torch\n\nfrom monai.config.type_definitions import DtypeLike, NdarrayOrTensor\nfrom monai.transforms import (\n    Activations,\n    AsDiscrete,\n    BoundingRect,\n    FillHoles,\n    GaussianSmooth,\n    RemoveSmallObjects,\n    SobelGradients,\n)\nfrom monai.transforms.transform import Transform\nfrom monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique, where\nfrom monai.utils import TransformBackends, convert_to_numpy, optional_import\nfrom monai.utils.misc import ensure_tuple_rep\nfrom monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor\n\nlabel, _ = optional_import(\"scipy.ndimage\", name=\"label\")\ndisk, _ = optional_import(\"skimage.morphology\", name=\"disk\")\nopening, _ = optional_import(\"skimage.morphology\", name=\"opening\")\nwatershed, _ = optional_import(\"skimage.segmentation\", name=\"watershed\")\nfind_contours, _ = optional_import(\"skimage.measure\", name=\"find_contours\")\ncentroid, _ = optional_import(\"skimage.measure\", name=\"centroid\")\n\n__all__ = [\n    \"Watershed\",\n    \"GenerateWatershedMask\",\n    \"GenerateInstanceBorder\",\n    \"GenerateDistanceMap\",\n    \"GenerateWatershedMarkers\",\n    \"GenerateSuccinctContour\",\n    
\"GenerateInstanceContour\",\n    \"GenerateInstanceCentroid\",\n    \"GenerateInstanceType\",\n    \"HoVerNetInstanceMapPostProcessing\",\n    \"HoVerNetNuclearTypePostProcessing\",\n]\n\n\nclass Watershed(Transform):\n    \"\"\"\n    Use `skimage.segmentation.watershed` to get instance segmentation results from images.\n    See: https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.watershed.\n\n    Args:\n        connectivity: an array with the same number of dimensions as image whose non-zero elements indicate\n            neighbors for connection. Following the scipy convention, default is a one-connected array of\n            the dimension of the image.\n        dtype: target data content type to convert, default is np.int64.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(self, connectivity: int | None = 1, dtype: DtypeLike = np.int64) -> None:\n        self.connectivity = connectivity\n        self.dtype = dtype\n\n    def __call__(\n        self, image: NdarrayOrTensor, mask: NdarrayOrTensor | None = None, markers: NdarrayOrTensor | None = None\n    ) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            image: image where the lowest value points are labeled first. Shape must be [1, H, W, [D]].\n            mask: optional, the same shape as image. Only points at which mask == True will be labeled.\n                If None (no mask given), it is a volume of all 1s.\n            markers: optional, the same shape as image. The desired number of markers, or an array marking\n                the basins with the values to be assigned in the label matrix. 
Zero means not a marker.\n                If None (no markers given), the local minima of the image are used as markers.\n        \"\"\"\n\n        image = convert_to_numpy(image)\n        markers = convert_to_numpy(markers)\n        mask = convert_to_numpy(mask)\n\n        instance_seg = watershed(image, markers=markers, mask=mask, connectivity=self.connectivity)\n\n        return convert_to_dst_type(instance_seg, image, dtype=self.dtype)[0]\n\n\nclass GenerateWatershedMask(Transform):\n    \"\"\"\n    generate mask used in `watershed`. Only points at which mask == True will be labeled.\n\n    Args:\n        activation: the activation layer to be applied on the input probability map.\n            It can be \"softmax\" or \"sigmoid\" string, or any callable. Defaults to \"softmax\".\n        threshold: an optional float value to threshold to binarize probability map.\n            If not provided, defaults to 0.5 when activation is not \"softmax\", otherwise None.\n        min_object_size: objects smaller than this size (in pixel) are removed. Defaults to 10.\n        dtype: target data content type to convert, default is np.uint8.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        activation: str | Callable = \"softmax\",\n        threshold: float | None = None,\n        min_object_size: int = 10,\n        dtype: DtypeLike = np.uint8,\n    ) -> None:\n        self.dtype = dtype\n\n        # set activation layer\n        use_softmax = False\n        use_sigmoid = False\n        activation_fn = None\n        if isinstance(activation, str):\n            if activation.lower() == \"softmax\":\n                use_softmax = True\n            elif activation.lower() == \"sigmoid\":\n                use_sigmoid = True\n            else:\n                raise ValueError(\n                    f\"The activation should be 'softmax' or 'sigmoid' string, or any callable. 
'{activation}' was given.\"\n                )\n        elif callable(activation):\n            activation_fn = activation\n        else:\n            raise ValueError(f\"The activation type should be either str or callable. '{type(activation)}' was given.\")\n        self.activation = Activations(softmax=use_softmax, sigmoid=use_sigmoid, other=activation_fn)\n\n        # set discretization transform\n        if not use_softmax and threshold is None:\n            threshold = 0.5\n        self.as_discrete = AsDiscrete(threshold=threshold, argmax=use_softmax)\n\n        # set small object removal transform\n        self.remove_small_objects = RemoveSmallObjects(min_size=min_object_size) if min_object_size > 0 else None\n\n    def __call__(self, prob_map: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            prob_map: probability map of segmentation, shape must be [C, H, W, [D]]\n        \"\"\"\n\n        pred = self.activation(prob_map)\n        pred = self.as_discrete(pred)\n\n        pred = convert_to_numpy(pred)\n\n        pred = label(pred)[0]\n        if self.remove_small_objects is not None:\n            pred = self.remove_small_objects(pred)\n        pred_indices = np.where(pred > 0)\n        pred[pred_indices] = 1\n\n        return convert_to_dst_type(pred, prob_map, dtype=self.dtype)[0]\n\n\nclass GenerateInstanceBorder(Transform):\n    \"\"\"\n    Generate instance border by hover map. The more parts of the image that cannot be identified as foreground areas,\n    the larger the grey scale value. The grey value of the instance's border will be larger.\n\n    Args:\n        kernel_size: the size of the Sobel kernel. Defaults to 5.\n        dtype: target data type to convert to. 
Defaults to np.float32.\n\n\n    Raises:\n        ValueError: when the `mask` shape is not [1, H, W].\n        ValueError: when the `hover_map` shape is not [2, H, W].\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(self, kernel_size: int = 5, dtype: DtypeLike = np.float32) -> None:\n        self.dtype = dtype\n        self.sobel_gradient = SobelGradients(kernel_size=kernel_size)\n\n    def __call__(self, mask: NdarrayOrTensor, hover_map: NdarrayOrTensor) -> NdarrayOrTensor:  # type: ignore\n        \"\"\"\n        Args:\n            mask: binary segmentation map, the output of :py:class:`GenerateWatershedMask`.\n                Shape must be [1, H, W] or [H, W].\n            hover_map:  horizontal and vertical distances of nuclear pixels to their centres of mass. Shape must be [2, H, W].\n                The first and second channel represent the horizontal and vertical maps respectively. For more details refer\n                to papers: https://arxiv.org/abs/1812.06499.\n        \"\"\"\n        if len(hover_map.shape) != 3:\n            raise ValueError(f\"The hover map should have the shape of [C, H, W], but got {hover_map.shape}.\")\n        if len(mask.shape) == 3:\n            if mask.shape[0] != 1:\n                raise ValueError(f\"The mask should have only one channel, but got {mask.shape[0]}.\")\n        elif len(mask.shape) == 2:\n            mask = mask[None]\n        else:\n            raise ValueError(f\"The mask should have the shape of [1, H, W] or [H, W], but got {mask.shape}.\")\n        if hover_map.shape[0] != 2:\n            raise ValueError(f\"Suppose the hover map only has two channels, but got {hover_map.shape[0]}\")\n\n        hover_h = hover_map[0:1, ...]\n        hover_v = hover_map[1:2, ...]\n\n        hover_h_min, hover_h_max = min(hover_h), max(hover_h)\n        hover_v_min, hover_v_max = min(hover_v), max(hover_v)\n        if (hover_h_max - hover_h_min) == 0 or (hover_v_max - hover_v_min) == 0:\n        
    raise ValueError(\"Not a valid hover map, please check your input\")\n        hover_h = (hover_h - hover_h_min) / (hover_h_max - hover_h_min)\n        hover_v = (hover_v - hover_v_min) / (hover_v_max - hover_v_min)\n        sobelh = self.sobel_gradient(hover_h)[1, ...]\n        sobelv = self.sobel_gradient(hover_v)[0, ...]\n        sobelh_min, sobelh_max = min(sobelh), max(sobelh)\n        sobelv_min, sobelv_max = min(sobelv), max(sobelv)\n        if (sobelh_max - sobelh_min) == 0 or (sobelv_max - sobelv_min) == 0:\n            raise ValueError(\"Not a valid sobel gradient map\")\n        sobelh = 1 - (sobelh - sobelh_min) / (sobelh_max - sobelh_min)\n        sobelv = 1 - (sobelv - sobelv_min) / (sobelv_max - sobelv_min)\n\n        # combine the h & v values using max\n        overall = maximum(sobelh, sobelv)\n        overall = overall - (1 - mask)\n        overall[overall < 0] = 0\n\n        return convert_to_dst_type(overall, mask, dtype=self.dtype)[0]\n\n\nclass GenerateDistanceMap(Transform):\n    \"\"\"\n    Generate distance map.\n    In general, the instance map is calculated from the distance to the background.\n    Here, we use 1 - \"instance border map\" to generate the distance map.\n    Nuclei values form mountains so invert them to get basins.\n\n    Args:\n        smooth_fn: smoothing function for distance map, which can be any callable object.\n            If not provided :py:class:`monai.transforms.GaussianSmooth()` is used.\n        dtype: target data type to convert to. 
Defaults to np.float32.\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(self, smooth_fn: Callable | None = None, dtype: DtypeLike = np.float32) -> None:\n        self.smooth_fn = smooth_fn if smooth_fn is not None else GaussianSmooth()\n        self.dtype = dtype\n\n    def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> NdarrayOrTensor:  # type: ignore\n        \"\"\"\n        Args:\n            mask: binary segmentation map, the output of :py:class:`GenerateWatershedMask`.\n                Shape must be [1, H, W] or [H, W].\n            instance_border: instance border map, the output of :py:class:`GenerateInstanceBorder`.\n                Shape must be [1, H, W].\n        \"\"\"\n        if len(mask.shape) == 3:\n            if mask.shape[0] != 1:\n                raise ValueError(f\"The mask should have only one channel, but got {mask.shape[0]}.\")\n        elif len(mask.shape) == 2:\n            mask = mask[None]\n        else:\n            raise ValueError(f\"The mask should have the shape of [1, H, W] or [H, W], but got {mask.shape}.\")\n        if instance_border.shape[0] != 1 or instance_border.ndim != 3:\n            raise ValueError(f\"Input instance_border should be with size of [1, H, W], but got {instance_border.shape}\")\n\n        distance_map = (1.0 - instance_border) * mask\n        distance_map = self.smooth_fn(distance_map)  # type: ignore\n\n        return convert_to_dst_type(-distance_map, mask, dtype=self.dtype)[0]\n\n\nclass GenerateWatershedMarkers(Transform):\n    \"\"\"\n    Generate markers to be used in `watershed`. The watershed algorithm treats pixels values as a local topography\n    (elevation). The algorithm floods basins from the markers until basins attributed to different markers meet on\n    watershed lines. 
Generally, markers are chosen as local minima of the image, from which basins are flooded.\n    Here is the implementation from HoVerNet paper.\n    For more details refer to papers: https://arxiv.org/abs/1812.06499.\n\n    Args:\n        threshold: a float value to threshold to binarize instance border map.\n            It turns uncertain area to 1 and other area to 0. Defaults to 0.4.\n        radius: the radius of the disk-shaped footprint used in `opening`. Defaults to 2.\n        min_object_size: objects smaller than this size (in pixel) are removed. Defaults to 10.\n        postprocess_fn: additional post-process function on the markers.\n            If not provided, :py:class:`monai.transforms.post.FillHoles()` will be used.\n        dtype: target data type to convert to. Defaults to np.int64.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        threshold: float = 0.4,\n        radius: int = 2,\n        min_object_size: int = 10,\n        postprocess_fn: Callable | None = None,\n        dtype: DtypeLike = np.int64,\n    ) -> None:\n        self.threshold = threshold\n        self.radius = radius\n        self.dtype = dtype\n        if postprocess_fn is None:\n            postprocess_fn = FillHoles()\n\n        self.postprocess_fn = postprocess_fn\n        self.remove_small_objects = RemoveSmallObjects(min_size=min_object_size) if min_object_size > 0 else None\n\n    def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> NdarrayOrTensor:  # type: ignore\n        \"\"\"\n        Args:\n            mask: binary segmentation map, the output of :py:class:`GenerateWatershedMask`.\n                Shape must be [1, H, W] or [H, W].\n            instance_border: instance border map, the output of :py:class:`GenerateInstanceBorder`.\n                Shape must be [1, H, W].\n        \"\"\"\n        if len(mask.shape) == 3:\n            if mask.shape[0] != 1:\n                raise 
ValueError(f\"The mask should have only one channel, but got {mask.shape[0]}.\")\n        elif len(mask.shape) == 2:\n            mask = mask[None]\n        else:\n            raise ValueError(f\"The mask should have the shape of [1, H, W] or [H, W], but got {mask.shape}.\")\n        if instance_border.shape[0] != 1 or instance_border.ndim != 3:\n            raise ValueError(f\"Input instance_border should be with size of [1, H, W], but got {instance_border.shape}\")\n\n        instance_border = instance_border >= self.threshold  # uncertain area\n\n        marker = mask - convert_to_dst_type(instance_border, mask)[0]  # certain foreground\n        marker_indices = where(marker < 0)\n        marker[marker_indices] = 0  # type: ignore[index]\n        marker = self.postprocess_fn(marker)\n        marker = convert_to_numpy(marker)\n\n        marker = opening(marker.squeeze(), disk(self.radius))\n        marker = label(marker)[0][None]\n        if self.remove_small_objects is not None:\n            marker = self.remove_small_objects(marker)\n\n        return convert_to_dst_type(marker, mask, dtype=self.dtype)[0]\n\n\nclass GenerateSuccinctContour(Transform):\n    \"\"\"\n    Converts SciPy-style contours (generated by skimage.measure.find_contours) to a more succinct version which only includes\n    the pixels to which lines need to be drawn (i.e. 
not the intervening pixels along each line).\n\n    Args:\n        height: height of bounding box, used to detect direction of line segment.\n        width: width of bounding box, used to detect direction of line segment.\n\n    Returns:\n        the pixels that need to be joined by straight lines to describe the outmost pixels of the foreground similar to\n            OpenCV's cv.CHAIN_APPROX_SIMPLE (counterclockwise)\n    \"\"\"\n\n    def __init__(self, height: int, width: int) -> None:\n        self.height = height\n        self.width = width\n\n    def _generate_contour_coord(self, current: np.ndarray, previous: np.ndarray) -> tuple[int, int]:\n        \"\"\"\n        Generate contour coordinates. Given the previous and current coordinates of border positions,\n        returns the int pixel that marks the extremity of the segmented pixels.\n\n        Args:\n            current: coordinates of the current border position.\n            previous: coordinates of the previous border position.\n        \"\"\"\n\n        p_delta = (current[0] - previous[0], current[1] - previous[1])\n        row, col = -1, -1\n\n        if p_delta in ((0.0, 1.0), (0.5, 0.5), (1.0, 0.0)):\n            row = int(current[0] + 0.5)\n            col = int(current[1])\n        elif p_delta in ((0.0, -1.0), (0.5, -0.5)):\n            row = int(current[0])\n            col = int(current[1])\n        elif p_delta in ((-1, 0.0), (-0.5, -0.5)):\n            row = int(current[0])\n            col = int(current[1] + 0.5)\n        elif p_delta == (-0.5, 0.5):\n            row = int(current[0] + 0.5)\n            col = int(current[1] + 0.5)\n\n        return row, col\n\n    def _calculate_distance_from_top_left(self, sequence: Sequence[tuple[int, int]]) -> int:\n        \"\"\"\n        Each sequence of coordinates describes a boundary between foreground and background starting and ending at two sides\n        of the bounding box. 
To order the sequences correctly, we compute the distance from the top-left of the bounding box\n        around the perimeter in a clockwise direction.\n\n        Args:\n            sequence: list of border points coordinates.\n\n        Returns:\n            the distance round the perimeter of the bounding box from the top-left origin\n        \"\"\"\n        distance: int\n        first_coord = sequence[0]\n        if first_coord[0] == 0:\n            distance = first_coord[1]\n        elif first_coord[1] == self.width - 1:\n            distance = self.width + first_coord[0]\n        elif first_coord[0] == self.height - 1:\n            distance = 2 * self.width + self.height - first_coord[1]\n        else:\n            distance = 2 * (self.width + self.height) - first_coord[0]\n\n        return distance\n\n    def __call__(self, contours: list[np.ndarray]) -> np.ndarray:\n        \"\"\"\n        Args:\n            contours: list of (n, 2)-ndarrays, scipy-style clockwise line segments, with lines separating foreground/background.\n                Each contour is an ndarray of shape (n, 2), consisting of n (row, column) coordinates along the contour.\n        \"\"\"\n        pixels: list[tuple[int, int]] = []\n        sequences = []\n        corners = [False, False, False, False]\n\n        for group in contours:\n            sequence: list[tuple[int, int]] = []\n            last_added = None\n            prev = None\n            corner = -1\n\n            for i, coord in enumerate(group):\n                if i == 0:\n                    # originating from the top, so must be heading south east\n                    if coord[0] == 0.0:\n                        corner = 1\n                        pixel = (0, int(coord[1] - 0.5))\n                        if pixel[1] == self.width - 1:\n                            corners[1] = True\n                        elif pixel[1] == 0.0:\n                            corners[0] = True\n                    # originating from the 
left, so must be heading north east\n                    elif coord[1] == 0.0:\n                        corner = 0\n                        pixel = (int(coord[0] + 0.5), 0)\n                    # originating from the bottom, so must be heading north west\n                    elif coord[0] == self.height - 1:\n                        corner = 3\n                        pixel = (int(coord[0]), int(coord[1] + 0.5))\n                        if pixel[1] == self.width - 1:\n                            corners[2] = True\n                    # originating from the right, so must be heading south west\n                    elif coord[1] == self.width - 1:\n                        corner = 2\n                        pixel = (int(coord[0] - 0.5), int(coord[1]))\n                    else:\n                        warnings.warn(f\"Invalid contour coord {coord} is generated, skip this instance.\")\n                        return None  # type: ignore\n                    sequence.append(pixel)\n                    last_added = pixel\n                elif i == len(group) - 1:\n                    # add this point\n                    pixel = self._generate_contour_coord(coord, prev)  # type: ignore\n                    if pixel != last_added:\n                        sequence.append(pixel)\n                        last_added = pixel\n                elif np.any(coord - prev != group[i + 1] - coord):\n                    pixel = self._generate_contour_coord(coord, prev)  # type: ignore\n                    if pixel != last_added:\n                        sequence.append(pixel)\n                        last_added = pixel\n\n                # flag whether each corner has been crossed\n                if i == len(group) - 1:\n                    if corner == 0:\n                        if coord[0] == 0:\n                            corners[corner] = True\n                    elif corner == 1:\n                        if coord[1] == self.width - 1:\n                            
corners[corner] = True\n                    elif corner == 2:\n                        if coord[0] == self.height - 1:\n                            corners[corner] = True\n                    elif corner == 3:\n                        if coord[1] == 0.0:\n                            corners[corner] = True\n\n                prev = coord\n            dist = self._calculate_distance_from_top_left(sequence)\n\n            sequences.append({\"distance\": dist, \"sequence\": sequence})\n\n        # check whether we need to insert any missing corners\n        if corners[0] is False:\n            sequences.append({\"distance\": 0, \"sequence\": [(0, 0)]})\n        if corners[1] is False:\n            sequences.append({\"distance\": self.width, \"sequence\": [(0, self.width - 1)]})\n        if corners[2] is False:\n            sequences.append({\"distance\": self.width + self.height, \"sequence\": [(self.height - 1, self.width - 1)]})\n        if corners[3] is False:\n            sequences.append({\"distance\": 2 * self.width + self.height, \"sequence\": [(self.height - 1, 0)]})\n\n        # join the sequences into a single contour\n        # starting at top left and rotating clockwise\n        sequences.sort(key=lambda x: x.get(\"distance\"))  # type: ignore\n\n        last = (-1, -1)\n        for _sequence in sequences:\n            if _sequence[\"sequence\"][0] == last:  # type: ignore\n                pixels.pop()\n            if pixels:\n                pixels = [*pixels, *_sequence[\"sequence\"]]  # type: ignore\n            else:\n                pixels = _sequence[\"sequence\"]  # type: ignore\n            last = pixels[-1]\n\n        if pixels[0] == last:\n            pixels.pop(0)\n\n        if pixels[0] == (0, 0):\n            pixels.append(pixels.pop(0))\n\n        return np.flip(convert_to_numpy(pixels, dtype=np.int32))  # type: ignore\n\n\nclass GenerateInstanceContour(Transform):\n    \"\"\"\n    Generate contour for each instance in a 2D array. 
Use `GenerateSuccinctContour` to only include\n    the pixels to which lines need to be drawn\n\n    Args:\n        min_num_points: assumed that the created contour does not form a contour if it does not contain more points\n            than the specified value. Defaults to 3.\n        contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array.\n            If not provided, the level is set to `(max(image) + min(image)) / 2`.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(self, min_num_points: int = 3, contour_level: float | None = None) -> None:\n        self.contour_level = contour_level\n        self.min_num_points = min_num_points\n\n    def __call__(self, inst_mask: NdarrayOrTensor, offset: Sequence[int] | None = (0, 0)) -> np.ndarray | None:\n        \"\"\"\n        Args:\n            inst_mask: segmentation mask for a single instance. Shape should be [1, H, W, [D]]\n            offset: optional offset of starting position of the instance mask in the original array. 
Default to 0 for each dim.\n        \"\"\"\n        inst_mask = inst_mask.squeeze()  # squeeze channel dim\n        inst_mask = convert_to_numpy(inst_mask)\n        inst_contour_cv = find_contours(inst_mask, level=self.contour_level)\n        generate_contour = GenerateSuccinctContour(inst_mask.shape[0], inst_mask.shape[1])\n        inst_contour = generate_contour(inst_contour_cv)\n        if inst_contour is None:\n            return None\n        # less than `self.min_num_points` points don't make a contour, so skip.\n        # They are likely to be artifacts as the contours obtained via approximation.\n        if inst_contour.shape[0] < self.min_num_points:\n            print(f\"< {self.min_num_points} points don't make a contour, so skipped!\")\n            return None\n        # check for tricky shape\n        elif len(inst_contour.shape) != 2:\n            print(f\"{len(inst_contour.shape)} != 2, check for tricky shapes!\")\n            return None\n        else:\n            inst_contour[:, 0] += offset[0]  # type: ignore\n            inst_contour[:, 1] += offset[1]  # type: ignore\n            return inst_contour\n\n\nclass GenerateInstanceCentroid(Transform):\n    \"\"\"\n    Generate instance centroid using `skimage.measure.centroid`.\n\n    Args:\n        dtype: the data type of output centroid.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(self, dtype: DtypeLike | None = int) -> None:\n        self.dtype = dtype\n\n    def __call__(self, inst_mask: NdarrayOrTensor, offset: Sequence[int] | int = 0) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            inst_mask: segmentation mask for a single instance. Shape should be [1, H, W, [D]]\n            offset: optional offset of starting position of the instance mask in the original array. 
Default to 0 for each dim.\n\n        \"\"\"\n        inst_mask = convert_to_numpy(inst_mask)\n        inst_mask = inst_mask.squeeze(0)  # squeeze channel dim\n        ndim = len(inst_mask.shape)\n        offset = ensure_tuple_rep(offset, ndim)\n\n        inst_centroid = centroid(inst_mask)\n        for i in range(ndim):\n            inst_centroid[i] += offset[i]\n\n        return convert_to_dst_type(inst_centroid, inst_mask, dtype=self.dtype)[0]\n\n\nclass GenerateInstanceType(Transform):\n    \"\"\"\n    Generate instance type and probability for each instance.\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __call__(  # type: ignore\n        self, type_pred: NdarrayOrTensor, seg_pred: NdarrayOrTensor, bbox: np.ndarray, instance_id: int\n    ) -> tuple[int, float]:\n        \"\"\"\n        Args:\n            type_pred: pixel-level type prediction map after activation function.\n            seg_pred: pixel-level segmentation prediction map after activation function.\n            bbox: bounding box coordinates of the instance, shape is [channel, 2 * spatial dims].\n            instance_id: get instance type from specified instance id.\n        \"\"\"\n\n        rmin, rmax, cmin, cmax = bbox.flatten()\n        seg_map_crop = seg_pred[0, rmin:rmax, cmin:cmax]\n        type_map_crop = type_pred[0, rmin:rmax, cmin:cmax]\n\n        seg_map_crop = convert_to_dst_type(seg_map_crop == instance_id, type_map_crop, dtype=bool)[0]\n\n        inst_type = type_map_crop[seg_map_crop]  # type: ignore[index]\n        type_list, type_pixels = unique(inst_type, return_counts=True)\n        type_list = list(zip(type_list, type_pixels))\n        type_list = sorted(type_list, key=lambda x: x[1], reverse=True)\n        inst_type = type_list[0][0]\n        if inst_type == 0:  # ! 
pick the 2nd most dominant if exist\n            if len(type_list) > 1:\n                inst_type = type_list[1][0]\n        type_dict = {v[0]: v[1] for v in type_list}\n        type_prob = type_dict[inst_type] / (sum(seg_map_crop) + 1.0e-6)\n\n        return (int(inst_type), float(type_prob))\n\n\nclass HoVerNetInstanceMapPostProcessing(Transform):\n    \"\"\"\n    The post-processing transform for HoVerNet model to generate instance segmentation map.\n    It generates an instance segmentation map as well as a dictionary containing centroids, bounding boxes, and contours\n    for each instance.\n\n    Args:\n        activation: the activation layer to be applied on the input probability map.\n            It can be \"softmax\" or \"sigmoid\" string, or any callable. Defaults to \"softmax\".\n        mask_threshold: a float value to threshold to binarize probability map to generate mask.\n        min_object_size: objects smaller than this size (in pixel) are removed. Defaults to 10.\n        sobel_kernel_size: the size of the Sobel kernel used in :py:class:`GenerateInstanceBorder`. Defaults to 5.\n        distance_smooth_fn: smoothing function for distance map.\n            If not provided, :py:class:`monai.transforms.intensity.GaussianSmooth()` will be used.\n        marker_threshold: a float value to threshold to binarize instance border map for markers.\n            It turns uncertain area to 1 and other area to 0. Defaults to 0.4.\n        marker_radius: the radius of the disk-shaped footprint used in `opening` of markers. Defaults to 2.\n        marker_postprocess_fn: post-process function for watershed markers.\n            If not provided, :py:class:`monai.transforms.post.FillHoles()` will be used.\n        watershed_connectivity: `connectivity` argument of `skimage.segmentation.watershed`.\n        min_num_points: minimum number of points to be considered as a contour. 
Defaults to 3.\n        contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array.\n            If not provided, the level is set to `(max(image) + min(image)) / 2`.\n        device: target device to put the output Tensor data.\n    \"\"\"\n\n    def __init__(\n        self,\n        activation: str | Callable = \"softmax\",\n        mask_threshold: float | None = None,\n        min_object_size: int = 10,\n        sobel_kernel_size: int = 5,\n        distance_smooth_fn: Callable | None = None,\n        marker_threshold: float = 0.4,\n        marker_radius: int = 2,\n        marker_postprocess_fn: Callable | None = None,\n        watershed_connectivity: int | None = 1,\n        min_num_points: int = 3,\n        contour_level: float | None = None,\n        device: str | torch.device | None = None,\n    ) -> None:\n        super().__init__()\n        self.device = device\n        self.generate_watershed_mask = GenerateWatershedMask(\n            activation=activation, threshold=mask_threshold, min_object_size=min_object_size\n        )\n        self.generate_instance_border = GenerateInstanceBorder(kernel_size=sobel_kernel_size)\n        self.generate_distance_map = GenerateDistanceMap(smooth_fn=distance_smooth_fn)\n        self.generate_watershed_markers = GenerateWatershedMarkers(\n            threshold=marker_threshold,\n            radius=marker_radius,\n            postprocess_fn=marker_postprocess_fn,\n            min_object_size=min_object_size,\n        )\n        self.watershed = Watershed(connectivity=watershed_connectivity)\n        self.generate_instance_contour = GenerateInstanceContour(\n            min_num_points=min_num_points, contour_level=contour_level\n        )\n        self.generate_instance_centroid = GenerateInstanceCentroid()\n\n    def __call__(  # type: ignore\n        self, nuclear_prediction: NdarrayOrTensor, hover_map: NdarrayOrTensor\n    ) -> tuple[dict, NdarrayOrTensor]:\n        
\"\"\"post-process instance segmentation branches (NP and HV) to generate instance segmentation map.\n\n        Args:\n            nuclear_prediction: the output of NP (nuclear prediction) branch of HoVerNet model\n            hover_map: the output of HV (hover map) branch of HoVerNet model\n        \"\"\"\n\n        # Process NP and HV branch using watershed algorithm\n        watershed_mask = self.generate_watershed_mask(nuclear_prediction)\n        instance_borders = self.generate_instance_border(watershed_mask, hover_map)\n        distance_map = self.generate_distance_map(watershed_mask, instance_borders)\n        watershed_markers = self.generate_watershed_markers(watershed_mask, instance_borders)\n        instance_map = self.watershed(distance_map, watershed_mask, watershed_markers)\n\n        # Create bounding boxes, contours and centroids\n        instance_ids = set(np.unique(instance_map)) - {0}  # exclude background\n        instance_info = {}\n        for inst_id in instance_ids:\n            instance_mask = instance_map == inst_id\n            instance_bbox = BoundingRect()(instance_mask)\n\n            instance_mask = instance_mask[\n                :, instance_bbox[0][0] : instance_bbox[0][1], instance_bbox[0][2] : instance_bbox[0][3]\n            ]\n            offset = [instance_bbox[0][2], instance_bbox[0][0]]\n            instance_contour = self.generate_instance_contour(FillHoles()(instance_mask), offset)\n            if instance_contour is not None:\n                instance_centroid = self.generate_instance_centroid(instance_mask, offset)\n                instance_info[inst_id] = {\n                    \"bounding_box\": instance_bbox,\n                    \"centroid\": instance_centroid,\n                    \"contour\": instance_contour,\n                }\n        instance_map = convert_to_tensor(instance_map, device=self.device)\n        return instance_info, instance_map\n\n\nclass HoVerNetNuclearTypePostProcessing(Transform):\n    \"\"\"\n 
   The post-processing transform for HoVerNet model to generate nuclear type information.\n    It updates the input instance info dictionary with information about types of the nuclei (value and probability).\n    Also if requested (`return_type_map=True`), it generates a pixel-level type map.\n\n    Args:\n        activation: the activation layer to be applied on nuclear type branch. It can be \"softmax\" or \"sigmoid\" string,\n            or any callable. Defaults to \"softmax\".\n        threshold: an optional float value to threshold to binarize probability map.\n            If not provided, defaults to 0.5 when activation is not \"softmax\", otherwise None.\n        return_type_map: whether to calculate and return pixel-level type map.\n        device: target device to put the output Tensor data.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        activation: str | Callable = \"softmax\",\n        threshold: float | None = None,\n        return_type_map: bool = True,\n        device: str | torch.device | None = None,\n    ) -> None:\n        super().__init__()\n        self.device = device\n        self.return_type_map = return_type_map\n        self.generate_instance_type = GenerateInstanceType()\n\n        # set activation layer\n        use_softmax = False\n        use_sigmoid = False\n        activation_fn = None\n        if isinstance(activation, str):\n            if activation.lower() == \"softmax\":\n                use_softmax = True\n            elif activation.lower() == \"sigmoid\":\n                use_sigmoid = True\n            else:\n                raise ValueError(\n                    f\"The activation should be 'softmax' or 'sigmoid' string, or any callable. '{activation}' was given.\"\n                )\n        elif callable(activation):\n            activation_fn = activation\n        else:\n            raise ValueError(f\"The activation type should be either str or callable. 
'{type(activation)}' was given.\")\n        self.activation = Activations(softmax=use_softmax, sigmoid=use_sigmoid, other=activation_fn)\n\n        # set discretization transform\n        if not use_softmax and threshold is None:\n            threshold = 0.5\n        self.as_discrete = AsDiscrete(threshold=threshold, argmax=use_softmax)\n\n    def __call__(  # type: ignore\n        self, type_prediction: NdarrayOrTensor, instance_info: dict[int, dict], instance_map: NdarrayOrTensor\n    ) -> tuple[dict, NdarrayOrTensor | None]:\n        \"\"\"Process NC (type prediction) branch and combine it with instance segmentation\n        It updates the instance_info with instance type and associated probability, and generate instance type map.\n\n        Args:\n            instance_info: instance information dictionary, the output of :py:class:`HoVerNetInstanceMapPostProcessing`\n            instance_map: instance segmentation map, the output of :py:class:`HoVerNetInstanceMapPostProcessing`\n            type_prediction: the output of NC (type prediction) branch of HoVerNet model\n        \"\"\"\n        type_prediction = self.activation(type_prediction)\n        type_prediction = self.as_discrete(type_prediction)\n\n        type_map = None\n        if self.return_type_map:\n            type_map = convert_to_dst_type(torch.zeros(instance_map.shape), instance_map)[0]\n\n        for inst_id in instance_info:\n            instance_type, instance_type_prob = self.generate_instance_type(\n                type_pred=type_prediction,\n                seg_pred=instance_map,\n                bbox=instance_info[inst_id][\"bounding_box\"],\n                instance_id=inst_id,\n            )\n            # update instance info dict with type data\n            instance_info[inst_id][\"type_prob\"] = instance_type_prob\n            instance_info[inst_id][\"type\"] = instance_type\n\n            # update instance type map\n            if type_map is not None:\n                
type_map[instance_map == inst_id] = instance_type\n                type_map = convert_to_tensor(type_map, device=self.device)\n\n        return instance_info, type_map\n"
  },
  {
    "path": "monai/apps/pathology/transforms/post/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Hashable, Mapping\n\nimport numpy as np\nimport torch\n\nfrom monai.apps.pathology.transforms.post.array import (\n    GenerateDistanceMap,\n    GenerateInstanceBorder,\n    GenerateInstanceCentroid,\n    GenerateInstanceContour,\n    GenerateInstanceType,\n    GenerateSuccinctContour,\n    GenerateWatershedMarkers,\n    GenerateWatershedMask,\n    HoVerNetInstanceMapPostProcessing,\n    HoVerNetNuclearTypePostProcessing,\n    Watershed,\n)\nfrom monai.config.type_definitions import DtypeLike, KeysCollection, NdarrayOrTensor\nfrom monai.transforms.transform import MapTransform, Transform\nfrom monai.utils import optional_import\nfrom monai.utils.enums import HoVerNetBranch\n\nfind_contours, _ = optional_import(\"skimage.measure\", name=\"find_contours\")\n\n__all__ = [\n    \"WatershedD\",\n    \"WatershedDict\",\n    \"Watershedd\",\n    \"GenerateWatershedMaskD\",\n    \"GenerateWatershedMaskDict\",\n    \"GenerateWatershedMaskd\",\n    \"GenerateInstanceBorderD\",\n    \"GenerateInstanceBorderDict\",\n    \"GenerateInstanceBorderd\",\n    \"GenerateDistanceMapD\",\n    \"GenerateDistanceMapDict\",\n    \"GenerateDistanceMapd\",\n    \"GenerateWatershedMarkersD\",\n    \"GenerateWatershedMarkersDict\",\n    \"GenerateWatershedMarkersd\",\n    \"GenerateSuccinctContourDict\",\n    
\"GenerateSuccinctContourD\",\n    \"GenerateSuccinctContourd\",\n    \"GenerateInstanceContourDict\",\n    \"GenerateInstanceContourD\",\n    \"GenerateInstanceContourd\",\n    \"GenerateInstanceCentroidDict\",\n    \"GenerateInstanceCentroidD\",\n    \"GenerateInstanceCentroidd\",\n    \"GenerateInstanceTypeDict\",\n    \"GenerateInstanceTypeD\",\n    \"GenerateInstanceTyped\",\n    \"HoVerNetInstanceMapPostProcessingDict\",\n    \"HoVerNetInstanceMapPostProcessingD\",\n    \"HoVerNetInstanceMapPostProcessingd\",\n    \"HoVerNetNuclearTypePostProcessingDict\",\n    \"HoVerNetNuclearTypePostProcessingD\",\n    \"HoVerNetNuclearTypePostProcessingd\",\n]\n\n\nclass Watershedd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.Watershed`.\n    Use `skimage.segmentation.watershed` to get instance segmentation results from images.\n    See: https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.watershed.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        mask_key: keys of mask used in watershed. Only points at which mask == True will be labeled.\n        markers_key: keys of markers used in watershed. If None (no markers given), the local minima of the image are\n            used as markers.\n        connectivity: An array with the same number of dimensions as image whose non-zero elements indicate neighbors\n            for connection. Following the scipy convention, default is a one-connected array of the dimension of the\n            image.\n        dtype: target data content type to convert. 
Defaults to np.uint8.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    Raises:\n        ValueError: when the `image` shape is not [1, H, W].\n        ValueError: when the `mask` shape is not [1, H, W].\n\n    \"\"\"\n\n    backend = Watershed.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        mask_key: str | None = \"mask\",\n        markers_key: str | None = None,\n        connectivity: int | None = 1,\n        dtype: DtypeLike = np.uint8,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.mask_key = mask_key\n        self.markers_key = markers_key\n        self.transform = Watershed(connectivity=connectivity, dtype=dtype)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        markers = d[self.markers_key] if self.markers_key else None\n        mask = d[self.mask_key] if self.mask_key else None\n\n        for key in self.key_iterator(d):\n            d[key] = self.transform(d[key], mask, markers)\n\n        return d\n\n\nclass GenerateWatershedMaskd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.GenerateWatershedMask`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        mask_key: the mask will be written to the value of `{mask_key}`.\n        activation: the activation layer to be applied on nuclear type branch. It can be \"softmax\" or \"sigmoid\" string,\n            or any callable. Defaults to \"softmax\".\n        threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold.\n        min_object_size: objects smaller than this size are removed. 
Defaults to 10.\n        dtype: target data content type to convert, default is np.uint8.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = GenerateWatershedMask.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        mask_key: str = \"mask\",\n        activation: str | Callable = \"softmax\",\n        threshold: float | None = None,\n        min_object_size: int = 10,\n        dtype: DtypeLike = np.uint8,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.mask_key = mask_key\n        self.transform = GenerateWatershedMask(\n            activation=activation, threshold=threshold, min_object_size=min_object_size, dtype=dtype\n        )\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            mask = self.transform(d[key])\n            if self.mask_key in d:\n                raise KeyError(f\"Mask with key {self.mask_key} already exists.\")\n            d[self.mask_key] = mask\n        return d\n\n\nclass GenerateInstanceBorderd(Transform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.GenerateInstanceBorder`.\n\n    Args:\n        mask_key: the input key where the watershed mask is stored. Defaults to `\"mask\"`.\n        hover_map_key: the input key where hover map is stored. Defaults to `\"hover_map\"`.\n        border_key: the output key where instance border map is written. Defaults to `\"border\"`.\n        kernel_size: the size of the Sobel kernel. 
Defaults to 21.\n        dtype: target data content type to convert, default is np.float32.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    Raises:\n        ValueError: when the `hover_map` has only one value.\n        ValueError: when the `sobel gradient map` has only one value.\n\n    \"\"\"\n\n    backend = GenerateInstanceBorder.backend\n\n    def __init__(\n        self,\n        mask_key: str = \"mask\",\n        hover_map_key: str = \"hover_map\",\n        border_key: str = \"border\",\n        kernel_size: int = 21,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        self.mask_key = mask_key\n        self.hover_map_key = hover_map_key\n        self.border_key = border_key\n        self.transform = GenerateInstanceBorder(kernel_size=kernel_size, dtype=dtype)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        if self.border_key in d:\n            raise KeyError(f\"The key '{self.border_key}' for instance border map already exists.\")\n        d[self.border_key] = self.transform(d[self.mask_key], d[self.hover_map_key])\n        return d\n\n\nclass GenerateDistanceMapd(Transform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.GenerateDistanceMap`.\n\n    Args:\n        mask_key: the input key where the watershed mask is stored. Defaults to `\"mask\"`.\n        border_key: the input key where instance border map is stored. Defaults to `\"border\"`.\n        dist_map_key: the output key where distance map is written. 
Defaults to `\"dist_map\"`.\n        smooth_fn: smoothing function for distance map, which can be any callable object.\n            If not provided :py:class:`monai.transforms.GaussianSmooth()` is used.\n        dtype: target data content type to convert, default is np.float32.\n    \"\"\"\n\n    backend = GenerateDistanceMap.backend\n\n    def __init__(\n        self,\n        mask_key: str = \"mask\",\n        border_key: str = \"border\",\n        dist_map_key: str = \"dist_map\",\n        smooth_fn: Callable | None = None,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        self.mask_key = mask_key\n        self.border_key = border_key\n        self.dist_map_key = dist_map_key\n        self.transform = GenerateDistanceMap(smooth_fn=smooth_fn, dtype=dtype)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        if self.dist_map_key in d:\n            raise KeyError(f\"The key '{self.dist_map_key}' for distance map already exists.\")\n        d[self.dist_map_key] = self.transform(d[self.mask_key], d[self.border_key])\n        return d\n\n\nclass GenerateWatershedMarkersd(Transform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.GenerateWatershedMarkers`.\n\n    Args:\n        mask_key: the input key where the watershed mask is stored. Defaults to `\"mask\"`.\n        border_key: the input key where instance border map is stored. Defaults to `\"border\"`.\n        markers_key: the output key where markers is written. Defaults to `\"markers\"`.\n        threshold: threshold the float values of instance border map to int 0 or 1 with specified threshold.\n            It turns uncertain area to 1 and other area to 0. Defaults to 0.4.\n        radius: the radius of the disk-shaped footprint used in `opening`. Defaults to 2.\n        min_object_size: objects smaller than this size are removed. 
Defaults to 10.\n        postprocess_fn: execute additional post transformation on marker. Defaults to None.\n        dtype: target data content type to convert, default is np.uint8.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = GenerateWatershedMarkers.backend\n\n    def __init__(\n        self,\n        mask_key: str = \"mask\",\n        border_key: str = \"border\",\n        markers_key: str = \"markers\",\n        threshold: float = 0.4,\n        radius: int = 2,\n        min_object_size: int = 10,\n        postprocess_fn: Callable | None = None,\n        dtype: DtypeLike = np.uint8,\n    ) -> None:\n        self.mask_key = mask_key\n        self.border_key = border_key\n        self.markers_key = markers_key\n        self.transform = GenerateWatershedMarkers(\n            threshold=threshold,\n            radius=radius,\n            min_object_size=min_object_size,\n            postprocess_fn=postprocess_fn,\n            dtype=dtype,\n        )\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        if self.markers_key in d:\n            raise KeyError(f\"The key '{self.markers_key}' for markers already exists.\")\n        d[self.markers_key] = self.transform(d[self.mask_key], d[self.border_key])\n        return d\n\n\nclass GenerateSuccinctContourd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.GenerateSuccinctContour`.\n    Converts SciPy-style contours (generated by skimage.measure.find_contours) to a more succinct version which\n    only includes the pixels to which lines need to be drawn (i.e. 
not the intervening pixels along each line).\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        height: height of bounding box, used to detect direction of line segment.\n        width: width of bounding box, used to detect direction of line segment.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = GenerateSuccinctContour.backend\n\n    def __init__(self, keys: KeysCollection, height: int, width: int, allow_missing_keys: bool = False) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.converter = GenerateSuccinctContour(height=height, width=width)\n\n    def __call__(self, data):\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n\n        return d\n\n\nclass GenerateInstanceContourd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.GenerateInstanceContour`.\n    Generate contour for each instance in a 2D array. Use `GenerateSuccinctContour` to only include the pixels\n    to which lines need to be drawn\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        contour_key_postfix: the output contour coordinates will be written to the value of\n            `{key}_{contour_key_postfix}`.\n        offset_key: keys of offset used in `GenerateInstanceContour`.\n        min_num_points: assumed that the created contour does not form a contour if it does not contain more points\n            than the specified value. Defaults to 3.\n        level: optional. Value along which to find contours in the array. 
By default, the level is set\n            to (max(image) + min(image)) / 2.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = GenerateInstanceContour.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        contour_key_postfix: str = \"contour\",\n        offset_key: str | None = None,\n        min_num_points: int = 3,\n        level: float | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.converter = GenerateInstanceContour(min_num_points=min_num_points, contour_level=level)\n        self.contour_key_postfix = contour_key_postfix\n        self.offset_key = offset_key\n\n    def __call__(self, data):\n        d = dict(data)\n        for key in self.key_iterator(d):\n            offset = d[self.offset_key] if self.offset_key else None\n            contour = self.converter(d[key], offset)\n            key_to_add = f\"{key}_{self.contour_key_postfix}\"\n            if key_to_add in d:\n                raise KeyError(f\"Contour with key {key_to_add} already exists.\")\n            d[key_to_add] = contour\n        return d\n\n\nclass GenerateInstanceCentroidd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.GenerateInstanceCentroid`.\n    Generate instance centroid using `skimage.measure.centroid`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        centroid_key_postfix: the output centroid coordinates will be written to the value of\n            `{key}_{centroid_key_postfix}`.\n        offset_key: keys of offset used in `GenerateInstanceCentroid`.\n        dtype: the data type of output centroid.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = GenerateInstanceCentroid.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        
centroid_key_postfix: str = \"centroid\",\n        offset_key: str | None = None,\n        dtype: DtypeLike | None = int,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.converter = GenerateInstanceCentroid(dtype=dtype)\n        self.centroid_key_postfix = centroid_key_postfix\n        self.offset_key = offset_key\n\n    def __call__(self, data):\n        d = dict(data)\n        for key in self.key_iterator(d):\n            offset = d[self.offset_key] if self.offset_key else None\n            centroid = self.converter(d[key], offset)\n            key_to_add = f\"{key}_{self.centroid_key_postfix}\"\n            if key_to_add in d:\n                raise KeyError(f\"Centroid with key {key_to_add} already exists.\")\n            d[key_to_add] = centroid\n        return d\n\n\nclass GenerateInstanceTyped(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.GenerateInstanceType`.\n    Generate instance type and probability for each instance.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        type_info_key: the output instance type and probability will be written to the value of\n            `{type_info_key}`.\n        bbox_key: keys of bounding box.\n        seg_pred_key: keys of segmentation prediction map.\n        instance_id_key: keys of instance id.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = GenerateInstanceType.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        type_info_key: str = \"type_info\",\n        bbox_key: str = \"bbox\",\n        seg_pred_key: str = \"seg\",\n        instance_id_key: str = \"id\",\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.converter = GenerateInstanceType()\n        self.type_info_key = 
type_info_key\n        self.bbox_key = bbox_key\n        self.seg_pred_key = seg_pred_key\n        self.instance_id_key = instance_id_key\n\n    def __call__(self, data):\n        d = dict(data)\n        for key in self.key_iterator(d):\n            seg = d[self.seg_pred_key]\n            bbox = d[self.bbox_key]\n            id = d[self.instance_id_key]\n            instance_type, type_prob = self.converter(d[key], seg, bbox, id)\n            key_to_add = f\"{self.type_info_key}\"\n            if key_to_add in d:\n                raise KeyError(f\"Type information with key {key_to_add} already exists.\")\n            d[key_to_add] = {\"inst_type\": instance_type, \"type_prob\": type_prob}\n        return d\n\n\nclass HoVerNetInstanceMapPostProcessingd(Transform):\n    \"\"\"\n    Dictionary-based wrapper for :py:class:`monai.apps.pathology.transforms.post.array.HoVerNetInstanceMapPostProcessing`.\n    The post-processing transform for HoVerNet model to generate instance segmentation map.\n    It generates an instance segmentation map as well as a dictionary containing centroids, bounding boxes, and contours\n    for each instance.\n\n    Args:\n        nuclear_prediction_key: the key for HoVerNet NP (nuclear prediction) branch. Defaults to `HoVerNetBranch.NP`.\n        hover_map_key: the key for HoVerNet NC (nuclear prediction) branch. Defaults to `HoVerNetBranch.HV`.\n        instance_info_key: the output key where instance information (contour, bounding boxes, and centroids)\n            is written. Defaults to `\"instance_info\"`.\n        instance_map_key: the output key where instance map is written. Defaults to `\"instance_map\"`.\n        activation: the activation layer to be applied on the input probability map.\n            It can be \"softmax\" or \"sigmoid\" string, or any callable. 
Defaults to \"softmax\".\n        mask_threshold: a float value to threshold to binarize probability map to generate mask.\n        min_object_size: objects smaller than this size are removed. Defaults to 10.\n        sobel_kernel_size: the size of the Sobel kernel used in :py:class:`GenerateInstanceBorder`. Defaults to 5.\n        distance_smooth_fn: smoothing function for distance map.\n            If not provided, :py:class:`monai.transforms.intensity.GaussianSmooth()` will be used.\n        marker_threshold: a float value to threshold to binarize instance border map for markers.\n            It turns uncertain area to 1 and other area to 0. Defaults to 0.4.\n        marker_radius: the radius of the disk-shaped footprint used in `opening` of markers. Defaults to 2.\n        marker_postprocess_fn: post-process function for watershed markers.\n            If not provided, :py:class:`monai.transforms.post.FillHoles()` will be used.\n        watershed_connectivity: `connectivity` argument of `skimage.segmentation.watershed`.\n        min_num_points: minimum number of points to be considered as a contour. 
Defaults to 3.\n        contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array.\n            If not provided, the level is set to `(max(image) + min(image)) / 2`.\n        device: target device to put the output Tensor data.\n    \"\"\"\n\n    def __init__(\n        self,\n        nuclear_prediction_key: str = HoVerNetBranch.NP.value,\n        hover_map_key: str = HoVerNetBranch.HV.value,\n        instance_info_key: str = \"instance_info\",\n        instance_map_key: str = \"instance_map\",\n        activation: str | Callable = \"softmax\",\n        mask_threshold: float | None = None,\n        min_object_size: int = 10,\n        sobel_kernel_size: int = 5,\n        distance_smooth_fn: Callable | None = None,\n        marker_threshold: float = 0.4,\n        marker_radius: int = 2,\n        marker_postprocess_fn: Callable | None = None,\n        watershed_connectivity: int | None = 1,\n        min_num_points: int = 3,\n        contour_level: float | None = None,\n        device: str | torch.device | None = None,\n    ) -> None:\n        super().__init__()\n        self.instance_map_post_process = HoVerNetInstanceMapPostProcessing(\n            activation=activation,\n            mask_threshold=mask_threshold,\n            min_object_size=min_object_size,\n            sobel_kernel_size=sobel_kernel_size,\n            distance_smooth_fn=distance_smooth_fn,\n            marker_threshold=marker_threshold,\n            marker_radius=marker_radius,\n            marker_postprocess_fn=marker_postprocess_fn,\n            watershed_connectivity=watershed_connectivity,\n            min_num_points=min_num_points,\n            contour_level=contour_level,\n            device=device,\n        )\n        self.nuclear_prediction_key = nuclear_prediction_key\n        self.hover_map_key = hover_map_key\n        self.instance_info_key = instance_info_key\n        self.instance_map_key = instance_map_key\n\n    def __call__(self, data):\n     
   d = dict(data)\n\n        for k in [self.instance_info_key, self.instance_map_key]:\n            if k in d:\n                raise ValueError(\"The output key ['{k}'] already exists in the input dictionary!\")\n\n        d[self.instance_info_key], d[self.instance_map_key] = self.instance_map_post_process(\n            d[self.nuclear_prediction_key], d[self.hover_map_key]\n        )\n\n        return d\n\n\nclass HoVerNetNuclearTypePostProcessingd(Transform):\n    \"\"\"\n    Dictionary-based wrapper for :py:class:`monai.apps.pathology.transforms.post.array.HoVerNetNuclearTypePostProcessing`.\n    It updates the input instance info dictionary with information about types of the nuclei (value and probability).\n    Also if requested (`return_type_map=True`), it generates a pixel-level type map.\n\n    Args:\n        type_prediction_key: the key for HoVerNet NC (type prediction) branch. Defaults to `HoVerNetBranch.NC`.\n        instance_info_key: the key where instance information (contour, bounding boxes, and centroids) is stored.\n            Defaults to `\"instance_info\"`.\n        instance_map_key: the key where instance map is stored. Defaults to `\"instance_map\"`.\n        type_map_key: the output key where type map is written. 
Defaults to `\"type_map\"`.\n        device: target device to put the output Tensor data.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        type_prediction_key: str = HoVerNetBranch.NC.value,\n        instance_info_key: str = \"instance_info\",\n        instance_map_key: str = \"instance_map\",\n        type_map_key: str = \"type_map\",\n        activation: str | Callable = \"softmax\",\n        threshold: float | None = None,\n        return_type_map: bool = True,\n        device: str | torch.device | None = None,\n    ) -> None:\n        super().__init__()\n        self.type_post_process = HoVerNetNuclearTypePostProcessing(\n            activation=activation, threshold=threshold, return_type_map=return_type_map, device=device\n        )\n        self.type_prediction_key = type_prediction_key\n        self.instance_info_key = instance_info_key\n        self.instance_map_key = instance_map_key\n        self.type_map_key = type_map_key\n        self.return_type_map = return_type_map\n\n    def __call__(self, data):\n        d = dict(data)\n\n        d[self.instance_info_key], type_map = self.type_post_process(\n            d[self.type_prediction_key], d[self.instance_info_key], d[self.instance_map_key]\n        )\n        if self.return_type_map:\n            if self.type_map_key in d:\n                raise ValueError(\"The output key ['{self.type_map_key}'] already exists in the input dictionary!\")\n            d[self.type_map_key] = type_map\n\n        return d\n\n\nWatershedD = WatershedDict = Watershedd\nGenerateWatershedMaskD = GenerateWatershedMaskDict = GenerateWatershedMaskd\nGenerateInstanceBorderD = GenerateInstanceBorderDict = GenerateInstanceBorderd\nGenerateDistanceMapD = GenerateDistanceMapDict = GenerateDistanceMapd\nGenerateWatershedMarkersD = GenerateWatershedMarkersDict = GenerateWatershedMarkersd\nGenerateSuccinctContourDict = GenerateSuccinctContourD = GenerateSuccinctContourd\nGenerateInstanceContourDict = GenerateInstanceContourD = 
GenerateInstanceContourd\nGenerateInstanceCentroidDict = GenerateInstanceCentroidD = GenerateInstanceCentroidd\nGenerateInstanceTypeDict = GenerateInstanceTypeD = GenerateInstanceTyped\nHoVerNetInstanceMapPostProcessingDict = HoVerNetInstanceMapPostProcessingD = HoVerNetInstanceMapPostProcessingd\nHoVerNetNuclearTypePostProcessingDict = HoVerNetNuclearTypePostProcessingD = HoVerNetNuclearTypePostProcessingd\n"
  },
  {
    "path": "monai/apps/pathology/transforms/stain/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .array import ExtractHEStains, NormalizeHEStains\nfrom .dictionary import (\n    ExtractHEStainsd,\n    ExtractHEStainsD,\n    ExtractHEStainsDict,\n    NormalizeHEStainsd,\n    NormalizeHEStainsD,\n    NormalizeHEStainsDict,\n)\n"
  },
  {
    "path": "monai/apps/pathology/transforms/stain/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport numpy as np\n\nfrom monai.transforms.transform import Transform\n\n\nclass ExtractHEStains(Transform):\n    \"\"\"Class to extract a target stain from an image, using stain deconvolution (see Note).\n\n    Args:\n        tli: transmitted light intensity. Defaults to 240.\n        alpha: tolerance in percentile for the pseudo-min (alpha percentile)\n            and pseudo-max (100 - alpha percentile). Defaults to 1.\n        beta: absorbance threshold for transparent pixels. 
Defaults to 0.15\n        max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E).\n            Defaults to (1.9705, 1.0308).\n\n    Note:\n        For more information refer to:\n        - the original paper: Macenko et al., 2009 http://wwwx.cs.unc.edu/~mn/sites/default/files/macenko2009.pdf\n        - the previous implementations:\n\n          - MATLAB: https://github.com/mitkovetta/staining-normalization\n          - Python: https://github.com/schaugf/HEnorm_python\n    \"\"\"\n\n    def __init__(\n        self, tli: float = 240, alpha: float = 1, beta: float = 0.15, max_cref: tuple | np.ndarray = (1.9705, 1.0308)\n    ) -> None:\n        self.tli = tli\n        self.alpha = alpha\n        self.beta = beta\n        self.max_cref = np.array(max_cref)\n\n    def _deconvolution_extract_stain(self, image: np.ndarray) -> np.ndarray:\n        \"\"\"Perform Stain Deconvolution and return stain matrix for the image.\n\n        Args:\n            image: uint8 RGB image to perform stain deconvolution on\n\n        Return:\n            he: H&E absorbance matrix for the image (first column is H, second column is E, rows are RGB values)\n        \"\"\"\n        # check image type and values\n        if not isinstance(image, np.ndarray):\n            raise TypeError(\"Image must be of type numpy.ndarray.\")\n        if image.min() < 0:\n            raise ValueError(\"Image should not have negative values.\")\n        if image.max() > 255:\n            raise ValueError(\"Image should not have values greater than 255.\")\n\n        # reshape image and calculate absorbance\n        image = image.reshape((-1, 3))\n        image = image.astype(np.float32, copy=False) + 1.0\n        absorbance = -np.log(image.clip(max=self.tli) / self.tli)\n\n        # remove transparent pixels\n        absorbance_hat = absorbance[np.all(absorbance > self.beta, axis=1)]\n        if len(absorbance_hat) == 0:\n            raise ValueError(\"All pixels of the input image are 
below the absorbance threshold.\")\n\n        # compute eigenvectors\n        _, eigvecs = np.linalg.eigh(np.cov(absorbance_hat.T).astype(np.float32, copy=False))\n\n        # project on the plane spanned by the eigenvectors corresponding to the two largest eigenvalues\n        t_hat = absorbance_hat.dot(eigvecs[:, 1:3])\n\n        # find the min and max vectors and project back to absorbance space\n        phi = np.arctan2(t_hat[:, 1], t_hat[:, 0])\n        min_phi = np.percentile(phi, self.alpha)\n        max_phi = np.percentile(phi, 100 - self.alpha)\n        v_min = eigvecs[:, 1:3].dot(np.array([(np.cos(min_phi), np.sin(min_phi))], dtype=np.float32).T)\n        v_max = eigvecs[:, 1:3].dot(np.array([(np.cos(max_phi), np.sin(max_phi))], dtype=np.float32).T)\n\n        # a heuristic to make the vector corresponding to hematoxylin first and the one corresponding to eosin second\n        # Hematoxylin: high blue, lower red (low R/B ratio)\n        # Eosin: high red, lower blue (high R/B ratio)\n        eps = np.finfo(np.float32).eps\n        v_min_rb_ratio = v_min[0] / (v_min[2] + eps)\n        v_max_rb_ratio = v_max[0] / (v_max[2] + eps)\n        if v_min_rb_ratio < v_max_rb_ratio:\n            he = np.array((v_min[:, 0], v_max[:, 0]), dtype=np.float32).T\n        else:\n            he = np.array((v_max[:, 0], v_min[:, 0]), dtype=np.float32).T\n\n        return he\n\n    def __call__(self, image: np.ndarray) -> np.ndarray:\n        \"\"\"Perform stain extraction.\n\n        Args:\n            image: uint8 RGB image to extract stain from\n\n        return:\n            target_he: H&E absorbance matrix for the image (first column is H, second column is E, rows are RGB values)\n        \"\"\"\n        if not isinstance(image, np.ndarray):\n            raise TypeError(\"Image must be of type numpy.ndarray.\")\n\n        target_he = self._deconvolution_extract_stain(image)\n        return target_he\n\n\nclass NormalizeHEStains(Transform):\n    \"\"\"Class to normalize 
patches/images to a reference or target image stain (see Note).\n\n    Performs stain deconvolution of the source image using the ExtractHEStains\n    class, to obtain the stain matrix and calculate the stain concentration matrix\n    for the image. Then, performs the inverse Beer-Lambert transform to recreate the\n    patch using the target H&E stain matrix provided. If no target stain provided, a default\n    reference stain is used. Similarly, if no maximum stain concentrations are provided, a\n    reference maximum stain concentrations matrix is used.\n\n    Args:\n        tli: transmitted light intensity. Defaults to 240.\n        alpha: tolerance in percentile for the pseudo-min (alpha percentile) and\n            pseudo-max (100 - alpha percentile). Defaults to 1.\n        beta: absorbance threshold for transparent pixels. Defaults to 0.15.\n        target_he: target stain matrix. Defaults to ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)).\n        max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E).\n            Defaults to [1.9705, 1.0308].\n\n    Note:\n        For more information refer to:\n        - the original paper: Macenko et al., 2009 http://wwwx.cs.unc.edu/~mn/sites/default/files/macenko2009.pdf\n        - the previous implementations:\n\n            - MATLAB: https://github.com/mitkovetta/staining-normalization\n            - Python: https://github.com/schaugf/HEnorm_python\n    \"\"\"\n\n    def __init__(\n        self,\n        tli: float = 240,\n        alpha: float = 1,\n        beta: float = 0.15,\n        target_he: tuple | np.ndarray = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)),\n        max_cref: tuple | np.ndarray = (1.9705, 1.0308),\n    ) -> None:\n        self.tli = tli\n        self.target_he = np.array(target_he)\n        self.max_cref = np.array(max_cref)\n        self.stain_extractor = ExtractHEStains(tli=self.tli, alpha=alpha, beta=beta, max_cref=self.max_cref)\n\n    def 
__call__(self, image: np.ndarray) -> np.ndarray:\n        \"\"\"Perform stain normalization.\n\n        Args:\n            image: uint8 RGB image/patch to be stain normalized, pixel values between 0 and 255\n\n        Return:\n            image_norm: stain normalized image/patch\n        \"\"\"\n        # check image type and values\n        if not isinstance(image, np.ndarray):\n            raise TypeError(\"Image must be of type numpy.ndarray.\")\n        if image.min() < 0:\n            raise ValueError(\"Image should not have negative values.\")\n        if image.max() > 255:\n            raise ValueError(\"Image should not have values greater than 255.\")\n\n        # extract stain of the image\n        he = self.stain_extractor(image)\n\n        # reshape image and calculate absorbance\n        h, w, _ = image.shape\n        image = image.reshape((-1, 3))\n        image = image.astype(np.float32) + 1.0\n        absorbance = -np.log(image.clip(max=self.tli) / self.tli)\n\n        # rows correspond to channels (RGB), columns to absorbance values\n        y = np.reshape(absorbance, (-1, 3)).T\n\n        # determine concentrations of the individual stains\n        conc = np.linalg.lstsq(he, y, rcond=None)[0]\n\n        # normalize stain concentrations\n        max_conc = np.asarray([np.percentile(conc[0, :], 99), np.percentile(conc[1, :], 99)], dtype=np.float32)\n        tmp = np.divide(max_conc, self.max_cref, dtype=np.float32)\n        image_c = np.divide(conc, tmp[:, np.newaxis], dtype=np.float32)\n\n        image_norm: np.ndarray = np.multiply(self.tli, np.exp(-self.target_he.dot(image_c)), dtype=np.float32)\n        image_norm[image_norm > 255] = 254\n        image_norm = np.reshape(image_norm.T, (h, w, 3)).astype(np.uint8)\n        return image_norm\n"
  },
  {
    "path": "monai/apps/pathology/transforms/stain/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of dictionary-based wrappers around the pathology transforms\ndefined in :py:class:`monai.apps.pathology.transforms.array`.\n\nClass names are ended with 'd' to denote dictionary-based transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Hashable, Mapping\n\nimport numpy as np\n\nfrom monai.config import KeysCollection\nfrom monai.transforms.transform import MapTransform\n\nfrom .array import ExtractHEStains, NormalizeHEStains\n\n\nclass ExtractHEStainsd(MapTransform):\n    \"\"\"Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.ExtractHEStains`.\n    Class to extract a target stain from an image, using stain deconvolution.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        tli: transmitted light intensity. Defaults to 240.\n        alpha: tolerance in percentile for the pseudo-min (alpha percentile)\n            and pseudo-max (100 - alpha percentile). Defaults to 1.\n        beta: absorbance threshold for transparent pixels. 
Defaults to 0.15\n        max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E).\n            Defaults to (1.9705, 1.0308).\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        tli: float = 240,\n        alpha: float = 1,\n        beta: float = 0.15,\n        max_cref: tuple | np.ndarray = (1.9705, 1.0308),\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.extractor = ExtractHEStains(tli=tli, alpha=alpha, beta=beta, max_cref=max_cref)\n\n    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.extractor(d[key])\n        return d\n\n\nclass NormalizeHEStainsd(MapTransform):\n    \"\"\"Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.NormalizeHEStains`.\n\n    Class to normalize patches/images to a reference or target image stain.\n\n    Performs stain deconvolution of the source image using the ExtractHEStains\n    class, to obtain the stain matrix and calculate the stain concentration matrix\n    for the image. Then, performs the inverse Beer-Lambert transform to recreate the\n    patch using the target H&E stain matrix provided. If no target stain provided, a default\n    reference stain is used. Similarly, if no maximum stain concentrations are provided, a\n    reference maximum stain concentrations matrix is used.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        tli: transmitted light intensity. Defaults to 240.\n        alpha: tolerance in percentile for the pseudo-min (alpha percentile) and\n            pseudo-max (100 - alpha percentile). 
Defaults to 1.\n        beta: absorbance threshold for transparent pixels. Defaults to 0.15.\n        target_he: target stain matrix. Defaults to ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)).\n        max_cref: reference maximum stain concentrations for Hematoxylin & Eosin (H&E).\n            Defaults to (1.9705, 1.0308).\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        tli: float = 240,\n        alpha: float = 1,\n        beta: float = 0.15,\n        target_he: tuple | np.ndarray = ((0.5626, 0.2159), (0.7201, 0.8012), (0.4062, 0.5581)),\n        max_cref: tuple | np.ndarray = (1.9705, 1.0308),\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.normalizer = NormalizeHEStains(tli=tli, alpha=alpha, beta=beta, target_he=target_he, max_cref=max_cref)\n\n    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.normalizer(d[key])\n        return d\n\n\nExtractHEStainsDict = ExtractHEStainsD = ExtractHEStainsd\nNormalizeHEStainsDict = NormalizeHEStainsD = NormalizeHEStainsd\n"
  },
  {
    "path": "monai/apps/pathology/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms.post.array import ProbNMS\nfrom monai.utils import optional_import\n\nmeasure, _ = optional_import(\"skimage.measure\")\nndimage, _ = optional_import(\"scipy.ndimage\")\n\n\ndef compute_multi_instance_mask(mask: np.ndarray, threshold: float) -> Any:\n    \"\"\"\n    This method computes the segmentation mask according to the binary tumor mask.\n\n    Args:\n        mask: the binary mask array\n        threshold: the threshold to fill holes\n    \"\"\"\n\n    neg = 255 - mask * 255\n    distance = ndimage.distance_transform_edt(neg)\n    binary = distance < threshold\n\n    filled_image = ndimage.binary_fill_holes(binary)\n    multi_instance_mask = measure.label(filled_image, connectivity=2)\n\n    return multi_instance_mask\n\n\ndef compute_isolated_tumor_cells(tumor_mask: np.ndarray, threshold: float) -> list[int]:\n    \"\"\"\n    This method computes identifies Isolated Tumor Cells (ITC) and return their labels.\n\n    Args:\n        tumor_mask: the tumor mask.\n        threshold: the threshold (at the mask level) to define an isolated tumor cell (ITC).\n            A region with the longest diameter less than this threshold is considered as an ITC.\n    \"\"\"\n    max_label = np.amax(tumor_mask)\n    properties = measure.regionprops(tumor_mask)\n  
  itc_list = [i + 1 for i in range(max_label) if properties[i].major_axis_length < threshold]\n\n    return itc_list\n\n\nclass PathologyProbNMS(ProbNMS):\n    \"\"\"\n    This class extends monai.transforms.post.array.ProbNMS and adds the `resolution_level` option for\n    Pathology.\n    \"\"\"\n\n    def __call__(self, probs_map: np.ndarray | torch.Tensor, resolution_level: int = 0) -> list[list]:\n        \"\"\"\n        probs_map: the input probabilities map, it must have shape (H[, W, ...]).\n        resolution_level: the level at which the probabilities map is made.\n        \"\"\"\n        resolution = pow(2, resolution_level)\n        org_outputs = ProbNMS.__call__(self, probs_map)\n        outputs = []\n        for org_output in org_outputs:\n            prob = org_output[0]\n            coord = np.asarray(org_output[1:])\n            coord_wsi = ((coord + 0.5) * resolution).astype(int)\n            outputs.append([prob] + list(coord_wsi))\n        return outputs\n"
  },
  {
    "path": "monai/apps/reconstruction/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/reconstruction/complex_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis script contains utility functions for complex-value PyTorch tensor.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport re\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.utils.type_conversion import convert_to_numpy, convert_to_tensor\n\n\ndef convert_to_tensor_complex(\n    data: NdarrayOrTensor | list | int | float,\n    dtype: torch.dtype | None = None,\n    device: torch.device | None = None,\n    wrap_sequence: bool = True,\n    track_meta: bool = False,\n) -> Tensor:\n    \"\"\"\n    Convert complex-valued data to a 2-channel PyTorch tensor.\n    The real and imaginary parts are stacked along the last dimension.\n    This function relies on 'monai.utils.type_conversion.convert_to_tensor'\n\n    Args:\n        data: input data can be PyTorch Tensor, numpy array, list, int, and float.\n            will convert Tensor, Numpy array, float, int, bool to Tensor, strings and objects keep the original.\n            for list, convert every item to a Tensor if applicable.\n        dtype: target data type to when converting to Tensor.\n        device: target device to put the converted Tensor data.\n        wrap_sequence: if `False`, then lists will recursively call this function.\n            E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. 
If `True`, then `[1, 2]` -> `tensor([1, 2])`.\n        track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.\n            default to `False`.\n\n    Returns:\n        PyTorch version of the data\n\n    Example:\n        .. code-block:: python\n\n            import numpy as np\n            data = np.array([ [1+1j, 1-1j], [2+2j, 2-2j] ])\n            # the following line prints (2,2)\n            print(data.shape)\n            # the following line prints torch.Size([2, 2, 2])\n            print(convert_to_tensor_complex(data).shape)\n    \"\"\"\n    # if data is not complex, just turn it into a tensor\n    if isinstance(data, Tensor):\n        if not torch.is_complex(data):\n            converted_data: Tensor = convert_to_tensor(\n                data, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta\n            )\n            return converted_data\n    else:\n        if not np.iscomplexobj(data):\n            converted_data = convert_to_tensor(\n                data, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta\n            )\n            return converted_data\n\n    # if data is complex, turn its stacked version into a tensor\n    if isinstance(data, torch.Tensor):\n        data = torch.stack([data.real, data.imag], dim=-1)\n\n    elif isinstance(data, np.ndarray):\n        if re.search(r\"[SaUO]\", data.dtype.str) is None:\n            # numpy array with 0 dims is also sequence iterable,\n            # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims\n            if data.ndim > 0:\n                data = np.ascontiguousarray(data)\n            data = np.stack((data.real, data.imag), axis=-1)\n\n    elif isinstance(data, (float, int)):\n        data = [[data.real, data.imag]]\n\n    elif isinstance(data, list):\n        data = convert_to_numpy(data, wrap_sequence=True)\n        data = np.stack((data.real, data.imag), 
axis=-1).tolist()  # type: ignore\n\n    converted_data = convert_to_tensor(\n        data, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta\n    )\n    return converted_data\n\n\ndef complex_abs_t(x: Tensor) -> Tensor:\n    \"\"\"\n    Compute the absolute value of a complex tensor.\n\n    Args:\n        x: Input tensor with 2 channels in the last dimension representing real and imaginary parts.\n\n    Returns:\n        Absolute value along the last dimension\n    \"\"\"\n    if x.shape[-1] != 2:\n        raise ValueError(f\"x.shape[-1] is not 2 ({x.shape[-1]}).\")\n    return (x[..., 0] ** 2 + x[..., 1] ** 2) ** 0.5  # type: ignore\n\n\ndef complex_abs(x: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"\n    Compute the absolute value of a complex array.\n\n    Args:\n        x: Input array/tensor with 2 channels in the last dimension representing real and imaginary parts.\n\n    Returns:\n        Absolute value along the last dimension\n\n    Example:\n        .. code-block:: python\n\n            import numpy as np\n            x = np.array([3,4])[np.newaxis]\n            # the following line prints 5\n            print(complex_abs(x))\n    \"\"\"\n    return complex_abs_t(x)  # type: ignore\n\n\ndef complex_mul_t(x: Tensor, y: Tensor) -> Tensor:\n    \"\"\"\n    Compute complex-valued multiplication. 
Supports Ndim inputs with last dim equal to 2 (real/imaginary channels)\n\n    Args:\n        x: Input tensor with 2 channels in the last dimension representing real and imaginary parts.\n        y: Input tensor with 2 channels in the last dimension representing real and imaginary parts.\n\n    Returns:\n        Complex multiplication of x and y\n    \"\"\"\n    if x.shape[-1] != 2 or y.shape[-1] != 2:\n        raise ValueError(f\"last dim must be 2, but x.shape[-1] is {x.shape[-1]} and y.shape[-1] is {y.shape[-1]}.\")\n\n    real_part = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]\n    imag_part = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]\n\n    return torch.stack((real_part, imag_part), dim=-1)\n\n\ndef complex_mul(x: NdarrayOrTensor, y: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"\n    Compute complex-valued multiplication. Supports Ndim inputs with last dim equal to 2 (real/imaginary channels)\n\n    Args:\n        x: Input array/tensor with 2 channels in the last dimension representing real and imaginary parts.\n        y: Input array/tensor with 2 channels in the last dimension representing real and imaginary parts.\n\n    Returns:\n        Complex multiplication of x and y\n\n    Example:\n        .. 
code-block:: python\n\n            import numpy as np\n            x = np.array([[1,2],[3,4]])\n            y = np.array([[1,1],[1,1]])\n            # the following line prints array([[-1,  3], [-1,  7]])\n            print(complex_mul(x,y))\n    \"\"\"\n    if x.shape[-1] != 2 or y.shape[-1] != 2:\n        raise ValueError(f\"last dim must be 2, but x.shape[-1] is {x.shape[-1]} and y.shape[-1] is {y.shape[-1]}.\")\n\n    if isinstance(x, Tensor):\n        return complex_mul_t(x, y)  # type: ignore\n\n    else:\n        real_part = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]\n        imag_part = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]\n\n        mult: np.ndarray = np.stack((real_part, imag_part), axis=-1)\n        return mult\n\n\ndef complex_conj_t(x: Tensor) -> Tensor:\n    \"\"\"\n    Compute complex conjugate of a tensor. Supports Ndim inputs with last dim equal to 2 (real/imaginary channels)\n\n    Args:\n        x: Input tensor with 2 channels in the last dimension representing real and imaginary parts.\n\n    Returns:\n        Complex conjugate of x\n    \"\"\"\n    if x.shape[-1] != 2:\n        raise ValueError(f\"last dim must be 2, but x.shape[-1] is {x.shape[-1]}.\")\n\n    return torch.stack((x[..., 0], -x[..., 1]), dim=-1)\n\n\ndef complex_conj(x: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"\n    Compute complex conjugate of an/a array/tensor. Supports Ndim inputs with last dim equal to 2 (real/imaginary channels)\n\n    Args:\n        x: Input array/tensor with 2 channels in the last dimension representing real and imaginary parts.\n\n    Returns:\n        Complex conjugate of x\n\n    Example:\n        .. 
code-block:: python\n\n            import numpy as np\n            x = np.array([[1,2],[3,4]])\n            # the following line prints array([[ 1, -2], [ 3, -4]])\n            print(complex_conj(x))\n    \"\"\"\n    if x.shape[-1] != 2:\n        raise ValueError(f\"last dim must be 2, but x.shape[-1] is {x.shape[-1]}.\")\n\n    if isinstance(x, Tensor):\n        return complex_conj_t(x)\n    else:\n        np_conj: np.ndarray = np.stack((x[..., 0], -x[..., 1]), axis=-1)\n        return np_conj\n"
  },
  {
    "path": "monai/apps/reconstruction/fastmri_reader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nfrom collections.abc import Sequence\n\nimport numpy as np\nfrom numpy import ndarray\n\nfrom monai.config import PathLike\nfrom monai.data.image_reader import ImageReader\nfrom monai.data.utils import is_supported_format\nfrom monai.utils import FastMRIKeys, optional_import, require_pkg\n\nh5py, has_h5py = optional_import(\"h5py\")\n\n\n@require_pkg(pkg_name=\"h5py\")\nclass FastMRIReader(ImageReader):\n    \"\"\"\n    Load fastMRI files with '.h5' suffix. fastMRI files, when loaded with \"h5py\",\n    are HDF5 dictionary-like datasets. The keys are:\n\n    - kspace: contains the fully-sampled kspace\n    - reconstruction_rss: contains the root sum of squares of ifft of kspace. 
This\n        is the ground-truth image.\n\n    It also has several attributes with the following keys:\n\n    - acquisition (str): acquisition mode of the data (e.g., AXT2 denotes T2 brain MRI scans)\n    - max (float): dynamic range of the data\n    - norm (float): norm of the kspace\n    - patient_id (str): the patient's id whose measurements were recorded\n    \"\"\"\n\n    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:\n        \"\"\"\n         Verify whether the specified file format is supported by h5py reader.\n\n        Args:\n             filename: file name\n        \"\"\"\n        suffixes: Sequence[str] = [\".h5\"]\n        return has_h5py and is_supported_format(filename, suffixes)\n\n    def read(self, data: Sequence[PathLike] | PathLike) -> dict:  # type: ignore\n        \"\"\"\n        Read data from specified h5 file.\n        Note that the returned object is a dictionary.\n\n        Args:\n            data: file name to read.\n        \"\"\"\n        if isinstance(data, (tuple, list)):\n            data = data[0]\n\n        with h5py.File(data, \"r\") as f:\n            # extract everything from the h5 file\n            dat = dict(\n                [(key, f[key][()]) for key in f]\n                + [(key, f.attrs[key]) for key in f.attrs]\n                + [(FastMRIKeys.FILENAME, os.path.basename(data))]  # type: ignore\n            )\n        f.close()\n\n        return dat\n\n    def get_data(self, dat: dict) -> tuple[ndarray, dict]:\n        \"\"\"\n        Extract data array and metadata from the loaded data and return them.\n        This function returns two objects, first is numpy array of image data, second is dict of metadata.\n\n        Args:\n            dat: a dictionary loaded from an h5 file\n        \"\"\"\n        header = self._get_meta_dict(dat)\n        data: ndarray = np.array(dat[FastMRIKeys.KSPACE])\n        header[FastMRIKeys.MASK] = (\n            np.expand_dims(np.array(dat[FastMRIKeys.MASK]), 
0)[None, ..., None]\n            if FastMRIKeys.MASK in dat.keys()\n            else np.zeros(data.shape)\n        )\n        return data, header\n\n    def _get_meta_dict(self, dat: dict) -> dict:\n        \"\"\"\n        Get all the metadata of the loaded dict and return the meta dict.\n\n        Args:\n            dat: a dictionary object loaded from an h5 file.\n        \"\"\"\n        return {k.value: dat[k.value] for k in FastMRIKeys if k.value in dat}\n"
  },
  {
    "path": "monai/apps/reconstruction/mri_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom torch import Tensor\n\nfrom monai.config.type_definitions import NdarrayOrTensor\n\n\ndef root_sum_of_squares_t(x: Tensor, spatial_dim: int) -> Tensor:\n    \"\"\"\n    Compute the root sum of squares (rss) of the data (typically done for multi-coil MRI samples)\n\n    Args:\n        x: Input tensor\n        spatial_dim: dimension along which rss is applied\n\n    Returns:\n        rss of x along spatial_dim\n\n    Example:\n        .. code-block:: python\n\n            import numpy as np\n            x = torch.ones([2,3])\n            # the following line prints Tensor([1.41421356, 1.41421356, 1.41421356])\n            print(rss(x,spatial_dim=0))\n    \"\"\"\n    rss_x: Tensor = (x**2).sum(spatial_dim) ** 0.5\n    return rss_x\n\n\ndef root_sum_of_squares(x: NdarrayOrTensor, spatial_dim: int) -> NdarrayOrTensor:\n    \"\"\"\n    Compute the root sum of squares (rss) of the data (typically done for multi-coil MRI samples)\n\n    Args:\n        x: Input array/tensor\n        spatial_dim: dimension along which rss is applied\n\n    Returns:\n        rss of x along spatial_dim\n\n    Example:\n        .. 
code-block:: python\n\n            import numpy as np\n            x = np.ones([2,3])\n            # the following line prints array([1.41421356, 1.41421356, 1.41421356])\n            print(root_sum_of_squares(x,spatial_dim=0))\n    \"\"\"\n    rss_x: NdarrayOrTensor = root_sum_of_squares_t(x, spatial_dim)  # type: ignore\n    return rss_x\n"
  },
  {
    "path": "monai/apps/reconstruction/networks/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/reconstruction/networks/blocks/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/reconstruction/networks/blocks/varnetblock.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor\n\nfrom monai.apps.reconstruction.networks.nets.utils import sensitivity_map_expand, sensitivity_map_reduce\n\n\nclass VarNetBlock(nn.Module):\n    \"\"\"\n    A variational block based on Sriram et. al., \"End-to-end variational networks for accelerated MRI reconstruction\".\n    It applies data consistency and refinement to the intermediate kspace and combines those results.\n\n    Modified and adopted from: https://github.com/facebookresearch/fastMRI\n\n    Args:\n        refinement_model: the model used for refinement (typically a U-Net but can be any deep learning model\n            that performs well when the input and output are in image domain (e.g., a convolutional network).\n        spatial_dims: is 2 for 2D data and is 3 for 3D data\n    \"\"\"\n\n    def __init__(self, refinement_model: nn.Module, spatial_dims: int = 2):\n        super().__init__()\n        self.model = refinement_model\n        self.spatial_dims = spatial_dims\n        self.dc_weight = nn.Parameter(torch.ones(1))  # learned scalar as the multiplier of the DC block\n\n        buffer_shape = [1 for _ in range(spatial_dims + 3)]  # 3 denotes the batch, channel, and real/complex dimensions\n        self.register_buffer(\"zeros\", torch.zeros(buffer_shape))\n\n    def soft_dc(self, x: Tensor, 
ref_kspace: Tensor, mask: Tensor) -> Tensor:\n        \"\"\"\n        Applies data consistency to input x. Suppose x is an intermediate estimate of the kspace and ref_kspace\n        is the reference under-sampled measurement. This function returns mask * (x - ref_kspace). View this as the\n        residual between the original under-sampled kspace and the estimate given by the network.\n\n        Args:\n            x: 2D kspace (B,C,H,W,2) with the last dimension being 2 (for real/imaginary parts) and C denoting the\n                coil dimension. 3D data will have the shape (B,C,H,W,D,2).\n            ref_kspace: original under-sampled kspace with the same shape as x.\n            mask: the under-sampling mask with shape (1,1,1,W,1) for 2D data or (1,1,1,1,D,1) for 3D data.\n\n        Returns:\n            Output of DC block with the same shape as x\n        \"\"\"\n        return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight  # type: ignore\n\n    def forward(self, current_kspace: Tensor, ref_kspace: Tensor, mask: Tensor, sens_maps: Tensor) -> Tensor:\n        \"\"\"\n        Args:\n            current_kspace: Predicted kspace from the previous block. It's a 2D kspace (B,C,H,W,2)\n                with the last dimension being 2 (for real/imaginary parts) and C denoting the\n                coil dimension. 
3D data will have the shape (B,C,H,W,D,2).\n            ref_kspace: reference kspace for applying data consistency (is the under-sampled kspace in MRI reconstruction).\n                Its shape is the same as current_kspace.\n            mask: the under-sampling mask with shape (1,1,1,W,1) for 2D data or (1,1,1,1,D,1) for 3D data.\n            sens_maps: coil sensitivity maps with the same shape as current_kspace\n\n        Returns:\n            Output of VarNetBlock with the same shape as current_kspace\n        \"\"\"\n        dc_out = self.soft_dc(current_kspace, ref_kspace, mask)  # output of DC block\n        refinement_out = sensitivity_map_expand(\n            self.model(sensitivity_map_reduce(current_kspace, sens_maps, spatial_dims=self.spatial_dims)),\n            sens_maps,\n            spatial_dims=self.spatial_dims,\n        )  # output of refinement model\n        output = current_kspace - dc_out - refinement_out\n        return output\n"
  },
  {
    "path": "monai/apps/reconstruction/networks/nets/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor\n\nfrom monai.apps.reconstruction.mri_utils import root_sum_of_squares_t\nfrom monai.apps.reconstruction.networks.nets.complex_unet import ComplexUnet\nfrom monai.apps.reconstruction.networks.nets.utils import (\n    reshape_batch_channel_to_channel_dim,\n    reshape_channel_to_batch_dim,\n)\nfrom monai.networks.blocks.fft_utils_t import ifftn_centered_t\n\n\nclass CoilSensitivityModel(nn.Module):\n    \"\"\"\n    This class uses a convolutional model to learn coil sensitivity maps for multi-coil MRI reconstruction.\n    The convolutional model is :py:class:`monai.apps.reconstruction.networks.nets.complex_unet` by default\n    but can be specified by the user as well. Learning is done on the center of the under-sampled\n    kspace (that region is fully sampled).\n\n    The data being a (complex) 2-channel tensor is a requirement for using this model.\n\n    Modified and adopted from: https://github.com/facebookresearch/fastMRI\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        features: six integers as numbers of features. denotes number of channels in each layer.\n        act: activation type and arguments. Defaults to LeakyReLU.\n        norm: feature normalization type and arguments. 
Defaults to instance norm.\n        bias: whether to have a bias term in convolution blocks. Defaults to True.\n        dropout: dropout ratio. Defaults to 0.0.\n        upsample: upsampling mode, available options are\n            ``\"deconv\"``, ``\"pixelshuffle\"``, ``\"nontrainable\"``.\n        coil_dim: coil dimension in the data\n        conv_net: the learning model used to estimate the coil sensitivity maps. default\n            is :py:class:`monai.apps.reconstruction.networks.nets.complex_unet`. The only\n            requirement on the model is to have 2 as input and output number of channels.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int = 2,\n        features: Sequence[int] = (32, 32, 64, 128, 256, 32),\n        act: str | tuple = (\"LeakyReLU\", {\"negative_slope\": 0.1, \"inplace\": True}),\n        norm: str | tuple = (\"instance\", {\"affine\": True}),\n        bias: bool = True,\n        dropout: float | tuple = 0.0,\n        upsample: str = \"deconv\",\n        coil_dim: int = 1,\n        conv_net: nn.Module | None = None,\n    ):\n        super().__init__()\n        if conv_net is None:\n            self.conv_net = ComplexUnet(\n                spatial_dims=spatial_dims,\n                features=features,\n                act=act,\n                norm=norm,\n                bias=bias,\n                dropout=dropout,\n                upsample=upsample,\n            )\n        else:\n            # assume the first layer is convolutional and\n            # check whether in_channels == 2\n            params = [p.shape for p in conv_net.parameters()]\n            if params[0][1] != 2:\n                raise ValueError(f\"in_channels should be 2 but it's {params[0][1]}.\")\n            self.conv_net = conv_net  # type: ignore\n        self.spatial_dims = spatial_dims\n        self.coil_dim = coil_dim\n\n    def get_fully_sampled_region(self, mask: Tensor) -> tuple[int, int]:\n        \"\"\"\n        Extracts the size of 
the fully-sampled part of the kspace. Note that when a kspace\n        is under-sampled, a part of its center is fully sampled. This part is called the Auto\n        Calibration Region (ACR). ACR is used for sensitivity map computation.\n\n        Args:\n            mask: the under-sampling mask of shape (..., S, 1) where S denotes the sampling dimension\n\n        Returns:\n            A tuple containing\n                (1) left index of the region\n                (2) right index of the region\n\n        Note:\n            Suppose the mask is of shape (1,1,20,1). If this function returns 8,12 as left and right\n                indices, then it means that the fully-sampled center region has size 4 starting from 8 to 12.\n        \"\"\"\n        left = right = mask.shape[-2] // 2\n        while mask[..., right, :]:\n            right += 1\n\n        while mask[..., left, :]:\n            left -= 1\n\n        return left + 1, right\n\n    def forward(self, masked_kspace: Tensor, mask: Tensor) -> Tensor:\n        \"\"\"\n        Args:\n            masked_kspace: the under-sampled kspace (which is the input measurement). 
Its shape\n                is (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data.\n            mask: the under-sampling mask with shape (1,1,1,W,1) for 2D data or (1,1,1,1,D,1) for 3D data.\n\n        Returns:\n            predicted coil sensitivity maps with shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data.\n        \"\"\"\n        left, right = self.get_fully_sampled_region(mask)\n        num_low_freqs = right - left  # size of the fully-sampled center\n\n        # take out the fully-sampled region and set the rest of the data to zero\n        x = torch.zeros_like(masked_kspace)\n        start = (mask.shape[-2] - num_low_freqs + 1) // 2  # this marks the start of center extraction\n        x[..., start : start + num_low_freqs, :] = masked_kspace[..., start : start + num_low_freqs, :]\n\n        # apply inverse fourier to the extracted fully-sampled data\n        x = ifftn_centered_t(x, spatial_dims=self.spatial_dims, is_complex=True)\n\n        x, b = reshape_channel_to_batch_dim(x)  # shape of x will be (B*C,1,...)\n        x = self.conv_net(x)\n        x = reshape_batch_channel_to_channel_dim(x, b)  # shape will be (B,C,...)\n        # normalize the maps\n        x = x / root_sum_of_squares_t(x, spatial_dim=self.coil_dim).unsqueeze(self.coil_dim)\n\n        return x\n"
  },
  {
    "path": "monai/apps/reconstruction/networks/nets/complex_unet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch.nn as nn\nfrom torch import Tensor\n\nfrom monai.apps.reconstruction.networks.nets.utils import (\n    complex_normalize,\n    divisible_pad_t,\n    inverse_divisible_pad_t,\n    reshape_channel_complex_to_last_dim,\n    reshape_complex_to_channel_dim,\n)\nfrom monai.networks.nets.basic_unet import BasicUNet\n\n\nclass ComplexUnet(nn.Module):\n    \"\"\"\n    This variant of U-Net handles complex-value input/output. It can be\n    used as a model to learn sensitivity maps in multi-coil MRI data. It is\n    built based on :py:class:`monai.networks.nets.BasicUNet` by default but the user\n    can input their convolutional model as well.\n    ComplexUnet also applies default normalization to the input which makes it more stable to train.\n\n    The data being a (complex) 2-channel tensor is a requirement for using this model.\n\n    Modified and adopted from: https://github.com/facebookresearch/fastMRI\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        features: six integers as numbers of features. denotes number of channels in each layer.\n        act: activation type and arguments. Defaults to LeakyReLU.\n        norm: feature normalization type and arguments. Defaults to instance norm.\n        bias: whether to have a bias term in convolution blocks. 
Defaults to True.\n        dropout: dropout ratio. Defaults to 0.0.\n        upsample: upsampling mode, available options are\n            ``\"deconv\"``, ``\"pixelshuffle\"``, ``\"nontrainable\"``.\n        pad_factor: an integer denoting the number which each padded dimension will be divisible to.\n            For example, 16 means each dimension will be divisible by 16 after padding\n        conv_net: the learning model used inside the ComplexUnet. The default\n            is :py:class:`monai.networks.nets.basic_unet`. The only requirement on the model is to\n            have 2 as input and output number of channels.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int = 2,\n        features: Sequence[int] = (32, 32, 64, 128, 256, 32),\n        act: str | tuple = (\"LeakyReLU\", {\"negative_slope\": 0.1, \"inplace\": True}),\n        norm: str | tuple = (\"instance\", {\"affine\": True}),\n        bias: bool = True,\n        dropout: float | tuple = 0.0,\n        upsample: str = \"deconv\",\n        pad_factor: int = 16,\n        conv_net: nn.Module | None = None,\n    ):\n        super().__init__()\n        self.unet: nn.Module\n        if conv_net is None:\n            self.unet = BasicUNet(\n                spatial_dims=spatial_dims,\n                in_channels=2,\n                out_channels=2,\n                features=features,\n                act=act,\n                norm=norm,\n                bias=bias,\n                dropout=dropout,\n                upsample=upsample,\n            )\n        else:\n            # assume the first layer is convolutional and\n            # check whether in_channels == 2\n            params = [p.shape for p in conv_net.parameters()]\n            if params[0][1] != 2:\n                raise ValueError(f\"in_channels should be 2 but it's {params[0][1]}.\")\n            self.unet = conv_net\n\n        self.pad_factor = pad_factor\n\n    def forward(self, x: Tensor) -> Tensor:\n        \"\"\"\n     
   Args:\n            x: input of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data\n\n        Returns:\n            output of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data\n        \"\"\"\n        # suppose the input is 2D, the comment in front of each operator below shows the shape after that operator\n        x = reshape_complex_to_channel_dim(x)  # x will be of shape (B,C*2,H,W)\n        x, mean, std = complex_normalize(x)  # x will be of shape (B,C*2,H,W)\n        # pad input\n        x, padding_sizes = divisible_pad_t(\n            x, k=self.pad_factor\n        )  # x will be of shape (B,C*2,H',W') where H' and W' are for after padding\n\n        x = self.unet(x)\n        # inverse padding\n        x = inverse_divisible_pad_t(x, padding_sizes)  # x will be of shape (B,C*2,H,W)\n\n        x = x * std + mean\n        x = reshape_channel_complex_to_last_dim(x)  # x will be of shape (B,C,H,W,2)\n        return x\n"
  },
  {
    "path": "monai/apps/reconstruction/networks/nets/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis script contains utility functions for developing new networks/blocks in PyTorch.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport math\n\nfrom torch import Tensor\nfrom torch.nn import functional as F\n\nfrom monai.apps.reconstruction.complex_utils import complex_conj_t, complex_mul_t\nfrom monai.networks.blocks.fft_utils_t import fftn_centered_t, ifftn_centered_t\n\n\ndef reshape_complex_to_channel_dim(x: Tensor) -> Tensor:\n    \"\"\"\n    Swaps the complex dimension with the channel dimension so that the network treats real/imaginary\n    parts as two separate channels.\n\n    Args:\n        x: input of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data\n\n    Returns:\n        output of shape (B,C*2,H,W) for 2D data or (B,C*2,H,W,D) for 3D data\n    \"\"\"\n    if x.shape[-1] != 2:\n        raise ValueError(f\"last dim must be 2, but x.shape[-1] is {x.shape[-1]}.\")\n\n    if len(x.shape) == 5:  # this is 2D\n        b, c, h, w, two = x.shape\n        return x.permute(0, 4, 1, 2, 3).contiguous().view(b, 2 * c, h, w)\n\n    elif len(x.shape) == 6:  # this is 3D\n        b, c, h, w, d, two = x.shape\n        return x.permute(0, 5, 1, 2, 3, 4).contiguous().view(b, 2 * c, h, w, d)\n\n    else:\n        raise ValueError(f\"only 2D (B,C,H,W,2) and 3D (B,C,H,W,D,2) data are supported but x has shape {x.shape}\")\n\n\ndef 
reshape_channel_complex_to_last_dim(x: Tensor) -> Tensor:\n    \"\"\"\n    Swaps the complex dimension with the channel dimension so that the network output has 2 as its last dimension\n\n    Args:\n        x: input of shape (B,C*2,H,W) for 2D data or (B,C*2,H,W,D) for 3D data\n\n    Returns:\n        output of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data\n    \"\"\"\n    if x.shape[1] % 2 != 0:\n        raise ValueError(f\"channel dimension should be even but ({x.shape[1]}) is odd.\")\n\n    if len(x.shape) == 4:  # this is 2D\n        b, c2, h, w = x.shape  # c2 means c*2\n        c = c2 // 2\n        return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1)\n\n    elif len(x.shape) == 5:  # this is 3D\n        b, c2, h, w, d = x.shape  # c2 means c*2\n        c = c2 // 2\n        return x.view(b, 2, c, h, w, d).permute(0, 2, 3, 4, 5, 1)\n\n    else:\n        raise ValueError(f\"only 2D (B,C*2,H,W) and 3D (B,C*2,H,W,D) data are supported but x has shape {x.shape}\")\n\n\ndef reshape_channel_to_batch_dim(x: Tensor) -> tuple[Tensor, int]:\n    \"\"\"\n    Combines batch and channel dimensions.\n\n    Args:\n        x: input of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data\n\n    Returns:\n        A tuple containing:\n            (1) output of shape (B*C,1,...)\n            (2) batch size\n    \"\"\"\n\n    if len(x.shape) == 5:  # this is 2D\n        b, c, h, w, two = x.shape\n        return x.contiguous().view(b * c, 1, h, w, two), b\n\n    elif len(x.shape) == 6:  # this is 3D\n        b, c, h, w, d, two = x.shape\n        return x.contiguous().view(b * c, 1, h, w, d, two), b\n\n    else:\n        raise ValueError(f\"only 2D (B,C,H,W,2) and 3D (B,C,H,W,D,2) data are supported but x has shape {x.shape}\")\n\n\ndef reshape_batch_channel_to_channel_dim(x: Tensor, batch_size: int) -> Tensor:\n    \"\"\"\n    Detaches batch and channel dimensions.\n\n    Args:\n        x: input of shape (B*C,1,H,W,2) for 2D data or (B*C,1,H,W,D,2) for 3D data\n    
    batch_size: batch size\n\n    Returns:\n        output of shape (B,C,...)\n    \"\"\"\n    if len(x.shape) == 5:  # this is 2D\n        bc, one, h, w, two = x.shape  # bc represents B*C\n        c = bc // batch_size\n        return x.view(batch_size, c, h, w, two)\n\n    elif len(x.shape) == 6:  # this is 3D\n        bc, one, h, w, d, two = x.shape  # bc represents B*C\n        c = bc // batch_size\n        return x.view(batch_size, c, h, w, d, two)\n\n    else:\n        raise ValueError(f\"only 2D (B*C,1,H,W,2) and 3D (B*C,1,H,W,D,2) data are supported but x has shape {x.shape}\")\n\n\ndef complex_normalize(x: Tensor) -> tuple[Tensor, Tensor, Tensor]:\n    \"\"\"\n    Performs layer mean-std normalization for complex data. Normalization is done for each batch member\n    along each part (part refers to real and imaginary parts), separately.\n\n    Args:\n        x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data\n\n    Returns:\n        A tuple containing\n            (1) normalized output of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data\n            (2) mean\n            (3) std\n    \"\"\"\n    if len(x.shape) == 4:  # this is 2D\n        b, c, h, w = x.shape\n        x = x.contiguous().view(b, 2, c // 2 * h * w)\n        mean = x.mean(dim=2).view(b, 2, 1, 1, 1).expand(b, 2, c // 2, 1, 1).contiguous().view(b, c, 1, 1)\n        std = x.std(dim=2, unbiased=False).view(b, 2, 1, 1, 1).expand(b, 2, c // 2, 1, 1).contiguous().view(b, c, 1, 1)\n        x = x.view(b, c, h, w)\n        return (x - mean) / std, mean, std\n\n    elif len(x.shape) == 5:  # this is 3D\n        b, c, h, w, d = x.shape\n        x = x.contiguous().view(b, 2, c // 2 * h * w * d)\n        mean = x.mean(dim=2).view(b, 2, 1, 1, 1, 1).expand(b, 2, c // 2, 1, 1, 1).contiguous().view(b, c, 1, 1, 1)\n        std = (\n            x.std(dim=2, unbiased=False)\n            .view(b, 2, 1, 1, 1, 1)\n            .expand(b, 2, c // 2, 1, 1, 1)\n            .contiguous()\n          
  .view(b, c, 1, 1, 1)\n        )\n        x = x.view(b, c, h, w, d)\n        return (x - mean) / std, mean, std\n\n    else:\n        raise ValueError(f\"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}\")\n\n\ndef divisible_pad_t(\n    x: Tensor, k: int = 16\n) -> tuple[Tensor, tuple[tuple[int, int], tuple[int, int], tuple[int, int], int, int, int]]:\n    \"\"\"\n    Pad input to feed into the network (torch script compatible)\n\n    Args:\n        x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data\n        k: padding factor. each padded dimension will be divisible by k.\n\n    Returns:\n        A tuple containing\n            (1) padded input\n            (2) pad sizes (in order to reverse padding if needed)\n\n    Example:\n        .. code-block:: python\n\n            import torch\n\n            # 2D data\n            x = torch.ones([3,2,50,70])\n            x_pad,padding_sizes = divisible_pad_t(x, k=16)\n            # the following line should print (3, 2, 64, 80)\n            print(x_pad.shape)\n\n            # 3D data\n            x = torch.ones([3,2,50,70,80])\n            x_pad,padding_sizes = divisible_pad_t(x, k=16)\n            # the following line should print (3, 2, 64, 80, 80)\n            print(x_pad.shape)\n\n    \"\"\"\n    if len(x.shape) == 4:  # this is 2D\n        b, c, h, w = x.shape\n        w_mult = ((w - 1) | (k - 1)) + 1  # OR with (k-1) and then +1 makes sure padding is divisible by k\n        h_mult = ((h - 1) | (k - 1)) + 1\n        w_pad = floor_ceil((w_mult - w) / 2)\n        h_pad = floor_ceil((h_mult - h) / 2)\n        x = F.pad(x, w_pad + h_pad)\n        # dummy values for the 3rd spatial dimension\n        d_mult = -1\n        d_pad = (-1, -1)\n        pad_sizes = (h_pad, w_pad, d_pad, h_mult, w_mult, d_mult)\n\n    elif len(x.shape) == 5:  # this is 3D\n        b, c, h, w, d = x.shape\n        w_mult = ((w - 1) | (k - 1)) + 1\n        h_mult = ((h - 1) | (k - 1)) + 1\n        
d_mult = ((d - 1) | (k - 1)) + 1\n        w_pad = floor_ceil((w_mult - w) / 2)\n        h_pad = floor_ceil((h_mult - h) / 2)\n        d_pad = floor_ceil((d_mult - d) / 2)\n        x = F.pad(x, d_pad + w_pad + h_pad)\n        pad_sizes = (h_pad, w_pad, d_pad, h_mult, w_mult, d_mult)\n\n    else:\n        raise ValueError(f\"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}\")\n\n    return x, pad_sizes\n\n\ndef inverse_divisible_pad_t(\n    x: Tensor, pad_sizes: tuple[tuple[int, int], tuple[int, int], tuple[int, int], int, int, int]\n) -> Tensor:\n    \"\"\"\n    De-pad network output to match its original shape\n\n    Args:\n        x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data\n        pad_sizes: padding values\n\n    Returns:\n        de-padded input\n    \"\"\"\n    h_pad, w_pad, d_pad, h_mult, w_mult, d_mult = pad_sizes\n\n    if len(x.shape) == 4:  # this is 2D\n        return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]\n\n    elif len(x.shape) == 5:  # this is 3D\n        return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1], d_pad[0] : d_mult - d_pad[1]]\n\n    else:\n        raise ValueError(f\"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}\")\n\n\ndef floor_ceil(n: float) -> tuple[int, int]:\n    \"\"\"\n    Returns floor and ceil of the input\n\n    Args:\n        n: input number\n\n    Returns:\n        A tuple containing:\n            (1) floor(n)\n            (2) ceil(n)\n    \"\"\"\n    return math.floor(n), math.ceil(n)\n\n\ndef sensitivity_map_reduce(kspace: Tensor, sens_maps: Tensor, spatial_dims: int = 2) -> Tensor:\n    \"\"\"\n    Reduces coil measurements to a corresponding image based on the given sens_maps. Let's say there\n    are C coil measurements inside kspace, then this function multiplies the conjugate of each coil sensitivity map with the\n    corresponding coil image. 
The result of this process will be C images. Summing those images together gives the\n    resulting \"reduced image.\"\n\n    Args:\n        kspace: 2D kspace (B,C,H,W,2) with the last dimension being 2 (for real/imaginary parts) and C denoting the\n            coil dimension. 3D data will have the shape (B,C,H,W,D,2).\n        sens_maps: sensitivity maps of the same shape as kspace.\n        spatial_dims: is 2 for 2D data and is 3 for 3D data\n\n    Returns:\n        reduction of kspace to (B,1,H,W,2) for 2D data or (B,1,H,W,D,2) for 3D data.\n    \"\"\"\n    img = ifftn_centered_t(kspace, spatial_dims=spatial_dims, is_complex=True)  # inverse fourier transform\n    return complex_mul_t(img, complex_conj_t(sens_maps)).sum(dim=1, keepdim=True)\n\n\ndef sensitivity_map_expand(img: Tensor, sens_maps: Tensor, spatial_dims: int = 2) -> Tensor:\n    \"\"\"\n    Expands an image to its corresponding coil images based on the given sens_maps. Let's say there\n    are C coils. This function multiplies image img with each coil sensitivity map in sens_maps and stacks\n    the resulting C coil images along the channel dimension which is reserved for coils.\n\n    Args:\n        img: 2D image (B,1,H,W,2) with the last dimension being 2 (for real/imaginary parts). 3D data will have\n            the shape (B,1,H,W,D,2).\n        sens_maps: Sensitivity maps for combining coil images. The shape is (B,C,H,W,2) for 2D data\n            or (B,C,H,W,D,2) for 3D data (C denotes the coil dimension).\n        spatial_dims: is 2 for 2D data and is 3 for 3D data\n\n    Returns:\n        Expansion of img to (B,C,H,W,2) for 2D data and (B,C,H,W,D,2) for 3D data. The output is transferred\n            to the frequency domain to yield coil measurements.\n    \"\"\"\n    return fftn_centered_t(complex_mul_t(img, sens_maps), spatial_dims=spatial_dims, is_complex=True)\n"
  },
  {
    "path": "monai/apps/reconstruction/networks/nets/varnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport copy\n\nimport torch.nn as nn\nfrom torch import Tensor\n\nfrom monai.apps.reconstruction.complex_utils import complex_abs_t\nfrom monai.apps.reconstruction.mri_utils import root_sum_of_squares_t\nfrom monai.apps.reconstruction.networks.blocks.varnetblock import VarNetBlock\nfrom monai.networks.blocks.fft_utils_t import ifftn_centered_t\n\n\nclass VariationalNetworkModel(nn.Module):\n    \"\"\"\n    The end-to-end variational network (or simply e2e-VarNet) based on Sriram et. al., \"End-to-end variational\n    networks for accelerated MRI reconstruction\".\n    It comprises several cascades each consisting of refinement and data consistency steps. The network takes in\n    the under-sampled kspace and estimates the ground-truth reconstruction.\n\n    Modified and adopted from: https://github.com/facebookresearch/fastMRI\n\n    Args:\n        coil_sensitivity_model: A convolutional model for learning coil sensitivity maps. An example is\n            :py:class:`monai.apps.reconstruction.networks.nets.coil_sensitivity_model.CoilSensitivityModel`.\n        refinement_model: A convolutional network used in the refinement step of e2e-VarNet. An example\n            is :py:class:`monai.apps.reconstruction.networks.nets.complex_unet.ComplexUnet`.\n        num_cascades: Number of cascades. 
Each cascade is a\n            :py:class:`monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock` which consists of\n            refinement and data consistency steps.\n        spatial_dims: number of spatial dimensions.\n    \"\"\"\n\n    def __init__(\n        self,\n        coil_sensitivity_model: nn.Module,\n        refinement_model: nn.Module,\n        num_cascades: int = 12,\n        spatial_dims: int = 2,\n    ):\n        super().__init__()\n        self.coil_sensitivity_model = coil_sensitivity_model\n        self.cascades = nn.ModuleList([VarNetBlock(copy.deepcopy(refinement_model)) for i in range(num_cascades)])\n        self.spatial_dims = spatial_dims\n\n    def forward(self, masked_kspace: Tensor, mask: Tensor) -> Tensor:\n        \"\"\"\n        Args:\n            masked_kspace: The under-sampled kspace. It's a 2D kspace (B,C,H,W,2)\n                with the last dimension being 2 (for real/imaginary parts) and C denoting the\n                coil dimension. 3D data will have the shape (B,C,H,W,D,2).\n            mask: The under-sampling mask with shape (1,1,1,W,1) for 2D data or (1,1,1,1,D,1) for 3D data.\n\n        Returns:\n            The reconstructed image which is the root sum of squares (rss) of the absolute value\n                of the inverse fourier of the predicted kspace (note that rss combines coil images into one image).\n        \"\"\"\n        sensitivity_maps = self.coil_sensitivity_model(masked_kspace, mask)  # shape is similar to masked_kspace\n        kspace_pred = masked_kspace.clone()\n\n        for cascade in self.cascades:\n            kspace_pred = cascade(kspace_pred, masked_kspace, mask, sensitivity_maps)\n\n        output_image = root_sum_of_squares_t(\n            complex_abs_t(ifftn_centered_t(kspace_pred, spatial_dims=self.spatial_dims)),\n            spatial_dim=1,  # 1 is for C which is the coil dimension\n        )  # shape is (B,H,W) for 2D and (B,H,W,D) for 3D data.\n        return output_image\n"
  },
  {
    "path": "monai/apps/reconstruction/transforms/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/reconstruction/transforms/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom abc import abstractmethod\nfrom collections.abc import Sequence\n\nimport numpy as np\nfrom torch import Tensor\n\nfrom monai.apps.reconstruction.complex_utils import complex_abs, convert_to_tensor_complex\nfrom monai.apps.reconstruction.mri_utils import root_sum_of_squares\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.fft_utils import ifftn_centered\nfrom monai.transforms.transform import RandomizableTransform\nfrom monai.utils.enums import TransformBackends\nfrom monai.utils.type_conversion import convert_to_tensor\n\n\nclass KspaceMask(RandomizableTransform):\n    \"\"\"\n    A basic class for under-sampling mask setup. 
It provides common\n    features for under-sampling mask generators.\n    For example, RandomKspaceMask and EquispacedKspaceMask (two mask\n    transform objects defined right after this class)\n    both inherit KspaceMask to properly set up properties like the\n    acceleration factor.\n    \"\"\"\n\n    def __init__(\n        self,\n        center_fractions: Sequence[float],\n        accelerations: Sequence[float],\n        spatial_dims: int = 2,\n        is_complex: bool = True,\n    ):\n        \"\"\"\n        Args:\n            center_fractions: Fraction of low-frequency columns to be retained.\n                If multiple values are provided, then one of these numbers\n                is chosen uniformly each time.\n            accelerations: Amount of under-sampling. This should have the\n                same length as center_fractions. If multiple values are\n                provided, then one of these is chosen uniformly each time.\n            spatial_dims: Number of spatial dims (e.g., it's 2 for a 2D data;\n                it's also 2 for pseudo-3D datasets like the fastMRI dataset).\n                The last spatial dim is selected for sampling. For the fastMRI\n                dataset, k-space has the form (...,num_slices,num_coils,H,W)\n                and sampling is done along W. 
For a general 3D data with the\n                shape (...,num_coils,H,W,D), sampling is done along D.\n            is_complex: if True, then the last dimension will be reserved for\n                real/imaginary parts.\n        \"\"\"\n        if len(center_fractions) != len(accelerations):\n            raise ValueError(\n                \"Number of center fractions \\\n                should match number of accelerations\"\n            )\n\n        self.center_fractions = center_fractions\n        self.accelerations = accelerations\n        self.spatial_dims = spatial_dims\n        self.is_complex = is_complex\n\n    @abstractmethod\n    def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]:\n        \"\"\"\n        This is an extra instance to allow for defining new mask generators.\n        For creating other mask transforms, define a new class and simply\n        override __call__. See an example of this in\n        :py:class:`monai.apps.reconstruction.transforms.array.RandomKspaceMask`.\n\n        Args:\n            kspace: The input k-space data. 
The shape is (...,num_coils,H,W,2)\n                for complex 2D inputs and (...,num_coils,H,W,D) for real 3D\n                data.\n        \"\"\"\n        raise NotImplementedError\n\n    def randomize_choose_acceleration(self) -> Sequence[float]:\n        \"\"\"\n        If multiple values are provided for center_fractions and\n        accelerations, this function selects one value uniformly\n        for each training/test sample.\n\n        Returns:\n            A tuple containing\n                (1) center_fraction: chosen fraction of center kspace\n                lines to exclude from under-sampling\n                (2) acceleration: chosen acceleration factor\n        \"\"\"\n        choice = self.R.randint(0, len(self.accelerations))\n        center_fraction = self.center_fractions[choice]\n        acceleration = self.accelerations[choice]\n        return center_fraction, acceleration\n\n\nclass RandomKspaceMask(KspaceMask):\n    \"\"\"\n    This k-space mask transform under-samples the k-space according to a\n    random sampling pattern. Precisely, it uniformly selects a subset of\n    columns from the input k-space data. If the k-space data has N columns,\n    the mask picks out:\n\n    1. N_low_freqs = (N * center_fraction) columns in the center\n    corresponding to low-frequencies\n\n    2. 
The other columns are selected uniformly at random with a probability\n    equal to:\n    prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs).\n    This ensures that the expected number of columns selected is equal to\n    (N / acceleration)\n\n    It is possible to use multiple center_fractions and accelerations,\n    in which case one possible (center_fraction, acceleration) is chosen\n    uniformly at random each time the transform is called.\n\n    Example:\n        If accelerations = [4, 8] and center_fractions = [0.08, 0.04],\n        then there is a 50% probability that 4-fold acceleration with 8%\n        center fraction is selected and a 50% probability that 8-fold\n        acceleration with 4% center fraction is selected.\n\n    Modified and adopted from:\n        https://github.com/facebookresearch/fastMRI/tree/master/fastmri\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]:\n        \"\"\"\n        Args:\n            kspace: The input k-space data. The shape is (...,num_coils,H,W,2)\n                for complex 2D inputs and (...,num_coils,H,W,D) for real 3D\n                data. The last spatial dim is selected for sampling. 
For the\n                fastMRI dataset, k-space has the form\n                (...,num_slices,num_coils,H,W) and sampling is done along W.\n                For a general 3D data with the shape (...,num_coils,H,W,D),\n                sampling is done along D.\n\n        Returns:\n            A tuple containing\n                (1) the under-sampled kspace\n                (2) absolute value of the inverse fourier of the under-sampled kspace\n        \"\"\"\n        kspace_t = convert_to_tensor_complex(kspace)\n        spatial_size = kspace_t.shape\n        num_cols = spatial_size[-1]\n        if self.is_complex:  # for complex data\n            num_cols = spatial_size[-2]\n\n        center_fraction, acceleration = self.randomize_choose_acceleration()\n\n        # Create the mask\n        num_low_freqs = int(round(num_cols * center_fraction))\n        prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs)\n        mask = self.R.uniform(size=num_cols) < prob\n        pad = (num_cols - num_low_freqs + 1) // 2\n        mask[pad : pad + num_low_freqs] = True\n\n        # Reshape the mask\n        mask_shape = [1 for _ in spatial_size]\n        if self.is_complex:\n            mask_shape[-2] = num_cols\n        else:\n            mask_shape[-1] = num_cols\n\n        mask = convert_to_tensor(mask.reshape(*mask_shape).astype(np.float32))\n\n        # under-sample the ksapce\n        masked = mask * kspace_t\n        masked_kspace: Tensor = convert_to_tensor(masked)\n        self.mask = mask\n\n        # compute inverse fourier of the masked kspace\n        masked_kspace_ifft: Tensor = convert_to_tensor(\n            complex_abs(ifftn_centered(masked_kspace, spatial_dims=self.spatial_dims, is_complex=self.is_complex))\n        )\n        # combine coil images (it is assumed that the coil dimension is\n        # the first dimension before spatial dimensions)\n        masked_kspace_ifft_rss: Tensor = convert_to_tensor(\n            
root_sum_of_squares(masked_kspace_ifft, spatial_dim=-self.spatial_dims - 1)\n        )\n        return masked_kspace, masked_kspace_ifft_rss\n\n\nclass EquispacedKspaceMask(KspaceMask):\n    \"\"\"\n    This k-space mask transform under-samples the k-space according to an\n    equi-distant sampling pattern. Precisely, it selects an equi-distant\n    subset of columns from the input k-space data. If the k-space data has N\n    columns, the mask picks out:\n\n    1. N_low_freqs = (N * center_fraction) columns in the center corresponding\n    to low-frequencies\n\n    2. The other columns are selected with equal spacing at a proportion that\n    reaches the desired acceleration rate taking into consideration the number\n    of low frequencies. This ensures that the expected number of columns\n    selected is equal to (N / acceleration)\n\n    It is possible to use multiple center_fractions and accelerations, in\n    which case one possible (center_fraction, acceleration) is chosen\n    uniformly at random each time the EquispacedKspaceMask object is called.\n\n    Example:\n        If accelerations = [4, 8] and center_fractions = [0.08, 0.04],\n        then there is a 50% probability that 4-fold acceleration with 8%\n        center fraction is selected and a 50% probability that 8-fold\n        acceleration with 4% center fraction is selected.\n\n    Modified and adopted from:\n        https://github.com/facebookresearch/fastMRI/tree/master/fastmri\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]:\n        \"\"\"\n        Args:\n            kspace: The input k-space data. The shape is (...,num_coils,H,W,2)\n                for complex 2D inputs and (...,num_coils,H,W,D) for real 3D\n                data. The last spatial dim is selected for sampling. 
For the\n                fastMRI multi-coil dataset, k-space has the form\n                (...,num_slices,num_coils,H,W) and sampling is done along W.\n                For a general 3D data with the shape (...,num_coils,H,W,D),\n                sampling is done along D.\n\n        Returns:\n            A tuple containing\n                (1) the under-sampled kspace\n                (2) absolute value of the inverse fourier of the under-sampled kspace\n        \"\"\"\n        kspace_t = convert_to_tensor_complex(kspace)\n        spatial_size = kspace_t.shape\n        num_cols = spatial_size[-1]\n        if self.is_complex:  # for complex data\n            num_cols = spatial_size[-2]\n\n        center_fraction, acceleration = self.randomize_choose_acceleration()\n        num_low_freqs = int(round(num_cols * center_fraction))\n\n        # Create the mask\n        mask = np.zeros(num_cols, dtype=np.float32)\n        pad = (num_cols - num_low_freqs + 1) // 2\n        mask[pad : pad + num_low_freqs] = True\n\n        # Determine acceleration rate by adjusting for the\n        # number of low frequencies\n        adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols)\n        offset = self.R.randint(0, round(adjusted_accel))\n\n        accel_samples = np.arange(offset, num_cols - 1, adjusted_accel)\n        accel_samples = np.around(accel_samples).astype(np.uint)\n        mask[accel_samples] = True\n\n        # Reshape the mask\n        mask_shape = [1 for _ in spatial_size]\n        if self.is_complex:\n            mask_shape[-2] = num_cols\n        else:\n            mask_shape[-1] = num_cols\n\n        mask = convert_to_tensor(mask.reshape(*mask_shape).astype(np.float32))\n\n        # under-sample the ksapce\n        masked = mask * kspace_t\n        masked_kspace: Tensor = convert_to_tensor(masked)\n        self.mask = mask\n\n        # compute inverse fourier of the masked kspace\n        masked_kspace_ifft: Tensor 
= convert_to_tensor(\n            complex_abs(ifftn_centered(masked_kspace, spatial_dims=self.spatial_dims, is_complex=self.is_complex))\n        )\n        # combine coil images (it is assumed that the coil dimension is\n        # the first dimension before spatial dimensions)\n        masked_kspace_ifft_rss: Tensor = convert_to_tensor(\n            root_sum_of_squares(masked_kspace_ifft, spatial_dim=-self.spatial_dims - 1)\n        )\n        return masked_kspace, masked_kspace_ifft_rss\n"
  },
  {
    "path": "monai/apps/reconstruction/transforms/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Hashable, Mapping, Sequence\n\nimport numpy as np\nfrom numpy import ndarray\nfrom torch import Tensor\n\nfrom monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask\nfrom monai.config import DtypeLike, KeysCollection\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.transforms import InvertibleTransform\nfrom monai.transforms.croppad.array import SpatialCrop\nfrom monai.transforms.intensity.array import NormalizeIntensity\nfrom monai.transforms.transform import MapTransform, RandomizableTransform\nfrom monai.utils import FastMRIKeys\nfrom monai.utils.type_conversion import convert_to_tensor\n\n\nclass ExtractDataKeyFromMetaKeyd(MapTransform):\n    \"\"\"\n    Moves keys from meta to data. 
It is useful when a dataset of paired samples\n    is loaded and certain keys should be moved from meta to data.\n\n    Args:\n        keys: keys to be transferred from meta to data\n        meta_key: the meta key where all the meta-data is stored\n        allow_missing_keys: don't raise exception if key is missing\n\n    Example:\n        When the fastMRI dataset is loaded, \"kspace\" is stored in the data dictionary,\n        but the ground-truth image with the key \"reconstruction_rss\" is stored in the meta data.\n        In this case, ExtractDataKeyFromMetaKeyd moves \"reconstruction_rss\" to data.\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, meta_key: str, allow_missing_keys: bool = False) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        self.meta_key = meta_key\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, Tensor]:\n        \"\"\"\n        Args:\n            data: is a dictionary containing (key,value) pairs from the\n                loaded dataset\n\n        Returns:\n            the new data dictionary\n        \"\"\"\n        d = dict(data)\n        for key in self.keys:\n            if key in d[self.meta_key]:\n                d[key] = d[self.meta_key][key]  # type: ignore\n            elif not self.allow_missing_keys:\n                raise KeyError(\n                    f\"Key `{key}` of transform `{self.__class__.__name__}` was missing in the meta data\"\n                    \" and allow_missing_keys==False.\"\n                )\n        return d  # type: ignore\n\n\nclass RandomKspaceMaskd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.apps.reconstruction.transforms.array.RandomKspaceMask`.\n    Other mask transforms can inherit from this class, for example:\n    :py:class:`monai.apps.reconstruction.transforms.dictionary.EquispacedKspaceMaskd`.\n\n    Args:\n        keys: keys of the corresponding items to be 
transformed.\n            See also: monai.transforms.MapTransform\n        center_fractions: Fraction of low-frequency columns to be retained.\n            If multiple values are provided, then one of these numbers is\n            chosen uniformly each time.\n        accelerations: Amount of under-sampling. This should have the\n            same length as center_fractions. If multiple values are provided,\n            then one of these is chosen uniformly each time.\n        spatial_dims: Number of spatial dims (e.g., it's 2 for a 2D data; it's\n            also 2 for pseudo-3D datasets like the fastMRI dataset).\n            The last spatial dim is selected for sampling. For the fastMRI\n            dataset, k-space has the form (...,num_slices,num_coils,H,W)\n            and sampling is done along W. For a general 3D data with the\n            shape (...,num_coils,H,W,D), sampling is done along D.\n        is_complex: if True, then the last dimension will be reserved\n            for real/imaginary parts.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = RandomKspaceMask.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        center_fractions: Sequence[float],\n        accelerations: Sequence[float],\n        spatial_dims: int = 2,\n        is_complex: bool = True,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        self.masker = RandomKspaceMask(\n            center_fractions=center_fractions,\n            accelerations=accelerations,\n            spatial_dims=spatial_dims,\n            is_complex=is_complex,\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandomKspaceMaskd:\n        super().set_random_state(seed, state)\n        self.masker.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: 
Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, Tensor]:\n        \"\"\"\n        Args:\n            data: is a dictionary containing (key,value) pairs from the\n                loaded dataset\n\n        Returns:\n            the new data dictionary\n        \"\"\"\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key + \"_masked\"], d[key + \"_masked_ifft\"] = self.masker(d[key])\n            d[FastMRIKeys.MASK] = self.masker.mask\n\n        return d  # type: ignore\n\n\nclass EquispacedKspaceMaskd(RandomKspaceMaskd):\n    \"\"\"\n    Dictionary-based wrapper of\n    :py:class:`monai.apps.reconstruction.transforms.array.EquispacedKspaceMask`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        center_fractions: Fraction of low-frequency columns to be retained.\n            If multiple values are provided, then one of these numbers is\n            chosen uniformly each time.\n        accelerations: Amount of under-sampling. This should have the same\n            length as center_fractions. If multiple values are provided,\n            then one of these is chosen uniformly each time.\n        spatial_dims: Number of spatial dims (e.g., it's 2 for a 2D data;\n            it's also 2 for  pseudo-3D datasets like the fastMRI dataset).\n            The last spatial dim is selected for sampling. For the fastMRI\n            dataset, k-space has the form (...,num_slices,num_coils,H,W)\n            and sampling is done along W. 
For a general 3D data with the shape\n            (...,num_coils,H,W,D), sampling is done along D.\n        is_complex: if True, then the last dimension will be reserved\n            for real/imaginary parts.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = EquispacedKspaceMask.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        center_fractions: Sequence[float],\n        accelerations: Sequence[float],\n        spatial_dims: int = 2,\n        is_complex: bool = True,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        self.masker = EquispacedKspaceMask(  # type: ignore\n            center_fractions=center_fractions,\n            accelerations=accelerations,\n            spatial_dims=spatial_dims,\n            is_complex=is_complex,\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> EquispacedKspaceMaskd:\n        super().set_random_state(seed, state)\n        self.masker.set_random_state(seed, state)\n        return self\n\n\nclass ReferenceBasedSpatialCropd(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.SpatialCrop`.\n    This is similar to :py:class:`monai.transforms.SpatialCropd` which is a\n    general purpose cropper to produce sub-volume region of interest (ROI).\n    Their difference is that this transform does cropping according to a reference image.\n\n    If a dimension of the expected ROI size is larger than the input image size, will not crop that dimension.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        ref_key: key of the item to be used to crop items of \"keys\"\n        allow_missing_keys: don't raise exception if key is missing.\n\n    
Example:\n        In an image reconstruction task, let keys=[\"image\"] and ref_key=[\"target\"].\n        Also, let data be the data dictionary. Then, ReferenceBasedSpatialCropd\n        center-crops data[\"image\"] based on the spatial size of data[\"target\"] by\n        calling :py:class:`monai.transforms.SpatialCrop`.\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, ref_key: str, allow_missing_keys: bool = False) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        self.ref_key = ref_key\n\n    def __call__(self, data: Mapping[Hashable, Tensor]) -> dict[Hashable, Tensor]:\n        \"\"\"\n        This transform can support to crop ND spatial (channel-first) data.\n        It also supports pseudo ND spatial data (e.g., (C,H,W) is a pseudo-3D\n        data point where C is the number of slices)\n\n        Args:\n            data: is a dictionary containing (key,value) pairs from\n                the loaded dataset\n\n        Returns:\n            the new data dictionary\n        \"\"\"\n        d = dict(data)\n\n        # compute roi_size according to self.ref_key\n        roi_size = d[self.ref_key].shape[1:]  # first dimension is not spatial (could be channel)\n\n        # crop keys\n        for key in self.key_iterator(d):\n            image = d[key]\n            roi_center = tuple(i // 2 for i in image.shape[1:])\n            cropper = SpatialCrop(roi_center=roi_center, roi_size=roi_size)\n            d[key] = convert_to_tensor(cropper(d[key]))\n        return d\n\n\nclass ReferenceBasedNormalizeIntensityd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of\n    :py:class:`monai.transforms.NormalizeIntensity`.\n    This is similar to :py:class:`monai.transforms.NormalizeIntensityd`\n    and can normalize non-zero values or the entire image. 
The difference\n    is that this transform does normalization according to a reference image.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        ref_key: key of the item to be used to normalize items of \"keys\"\n        subtrahend: the amount to subtract by (usually the mean)\n        divisor: the amount to divide by (usually the standard deviation)\n        nonzero: whether only normalize non-zero values.\n        channel_wise: if True, calculate on each channel separately,\n            otherwise, calculate on the entire image directly. default\n            to False.\n        dtype: output data type, if None, same as input image. defaults\n            to float32.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    Example:\n        In an image reconstruction task, let keys=[\"image\", \"target\"] and ref_key=[\"image\"].\n        Also, let data be the data dictionary. Then, ReferenceBasedNormalizeIntensityd\n        normalizes data[\"target\"] and data[\"image\"] based on the mean-std of data[\"image\"] by\n        calling :py:class:`monai.transforms.NormalizeIntensity`.\n    \"\"\"\n\n    backend = NormalizeIntensity.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        ref_key: str,\n        subtrahend: NdarrayOrTensor | None = None,\n        divisor: NdarrayOrTensor | None = None,\n        nonzero: bool = False,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.default_normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype)\n        self.ref_key = ref_key\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        \"\"\"\n        This transform can support to normalize ND spatial 
(channel-first) data.\n        It also supports pseudo ND spatial data (e.g., (C,H,W) is a pseudo-3D\n        data point where C is the number of slices)\n\n        Args:\n            data: is a dictionary containing (key,value) pairs from\n                the loaded dataset\n\n        Returns:\n            the new data dictionary\n        \"\"\"\n        d = dict(data)\n\n        # prepare the normalizer based on self.ref_key\n        if self.default_normalizer.channel_wise:\n            # perform channel-wise normalization\n            # compute mean of each channel in the input for mean-std normalization\n            # subtrahend will have the same shape as image, for example (C,W,D) for a 2D data\n            if self.default_normalizer.subtrahend is None:\n                subtrahend = np.array(\n                    [val.mean() if isinstance(val, ndarray) else val.float().mean().item() for val in d[self.ref_key]]\n                )\n            # users can define default values instead of mean\n            else:\n                subtrahend = self.default_normalizer.subtrahend  # type: ignore\n\n            # compute std of each channel in the input for mean-std normalization\n            # will have the same shape as subtrahend\n            if self.default_normalizer.divisor is None:\n                divisor = np.array(\n                    [\n                        val.std() if isinstance(val, ndarray) else val.float().std(unbiased=False).item()\n                        for val in d[self.ref_key]\n                    ]\n                )\n            else:\n                # users can define default values instead of std\n                divisor = self.default_normalizer.divisor  # type: ignore\n        else:\n            # perform ordinary normalization (not channel-wise)\n            # subtrahend will be a scalar and is the mean of d[self.ref_key], unless user specifies another value\n            if self.default_normalizer.subtrahend is None:\n               
 if isinstance(d[self.ref_key], ndarray):\n                    subtrahend = d[self.ref_key].mean()  # type: ignore\n                else:\n                    subtrahend = d[self.ref_key].float().mean().item()  # type: ignore\n            # users can define default values instead of mean\n            else:\n                subtrahend = self.default_normalizer.subtrahend  # type: ignore\n\n            # divisor will be a scalar and is the std of d[self.ref_key], unless user specifies another value\n            if self.default_normalizer.divisor is None:\n                if isinstance(d[self.ref_key], ndarray):\n                    divisor = d[self.ref_key].std()  # type: ignore\n                else:\n                    divisor = d[self.ref_key].float().std(unbiased=False).item()  # type: ignore\n            else:\n                # users can define default values instead of std\n                divisor = self.default_normalizer.divisor  # type: ignore\n\n        # this creates a new normalizer instance based on self.ref_key\n        normalizer = NormalizeIntensity(\n            subtrahend,\n            divisor,\n            self.default_normalizer.nonzero,\n            self.default_normalizer.channel_wise,\n            self.default_normalizer.dtype,\n        )\n\n        # save mean and std\n        d[\"mean\"] = subtrahend\n        d[\"std\"] = divisor\n\n        # perform normalization\n        for key in self.key_iterator(d):\n            d[key] = normalizer(d[key])\n\n        return d\n"
  },
  {
    "path": "monai/apps/tcia/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .label_desc import TCIA_LABEL_DICT\nfrom .utils import (\n    BASE_URL,\n    DCM_FILENAME_REGEX,\n    download_tcia_series_instance,\n    get_tcia_metadata,\n    get_tcia_ref_uid,\n    match_tcia_ref_uid_in_study,\n)\n"
  },
  {
    "path": "monai/apps/tcia/label_desc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\n__all__ = [\"TCIA_LABEL_DICT\"]\n\nTCIA_LABEL_DICT: dict[str, dict[str, int]] = {\n    \"C4KC-KiTS\": {\"Kidney\": 0, \"Renal Tumor\": 1},\n    \"NSCLC-Radiomics\": {\n        \"Esophagus\": 0,\n        \"GTV-1\": 1,\n        \"Lungs-Total\": 2,\n        \"Spinal-Cord\": 3,\n        \"Lung-Left\": 4,\n        \"Lung-Right\": 5,\n        \"Heart\": 6,\n    },\n    \"NSCLC-Radiomics-Interobserver1\": {\n        \"GTV-1auto-1\": 0,\n        \"GTV-1auto-2\": 1,\n        \"GTV-1auto-3\": 2,\n        \"GTV-1auto-4\": 3,\n        \"GTV-1auto-5\": 4,\n        \"GTV-1vis-1\": 5,\n        \"GTV-1vis-2\": 6,\n        \"GTV-1vis-3\": 7,\n        \"GTV-1vis-4\": 8,\n        \"GTV-1vis-5\": 9,\n    },\n    \"QIN-PROSTATE-Repeatability\": {\"NormalROI_PZ_1\": 0, \"TumorROI_PZ_1\": 1, \"PeripheralZone\": 2, \"WholeGland\": 3},\n    \"PROSTATEx\": {\n        \"Prostate\": 0,\n        \"Peripheral zone of prostate\": 1,\n        \"Transition zone of prostate\": 2,\n        \"Distal prostatic urethra\": 3,\n        \"Anterior fibromuscular stroma of prostate\": 4,\n    },\n}\n"
  },
  {
    "path": "monai/apps/tcia/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nfrom collections.abc import Iterable\n\nimport monai\nfrom monai.config.type_definitions import PathLike\nfrom monai.utils import optional_import\n\nrequests_get, has_requests = optional_import(\"requests\", name=\"get\")\npd, has_pandas = optional_import(\"pandas\")\n\nDCM_FILENAME_REGEX = r\"^(?!.*LICENSE).*\"  # excluding the file with \"LICENSE\" in its name\nBASE_URL = \"https://services.cancerimagingarchive.net/nbia-api/services/v1/\"\n\n__all__ = [\n    \"get_tcia_metadata\",\n    \"download_tcia_series_instance\",\n    \"get_tcia_ref_uid\",\n    \"match_tcia_ref_uid_in_study\",\n    \"DCM_FILENAME_REGEX\",\n    \"BASE_URL\",\n]\n\n\ndef get_tcia_metadata(query: str, attribute: str | None = None) -> list:\n    \"\"\"\n    Achieve metadata of a public The Cancer Imaging Archive (TCIA) dataset.\n\n    This function makes use of The National Biomedical Imaging Archive (NBIA) REST APIs to access the metadata\n    of objects in the TCIA database.\n    Please refer to the following link for more details:\n    https://wiki.cancerimagingarchive.net/display/Public/NBIA+Search+REST+API+Guide\n\n    This function relies on `requests` package.\n\n    Args:\n        query: queries used to achieve the corresponding metadata. A query is consisted with query name and\n            query parameters. 
The format is like: <query name>?<parameter 1>&<parameter 2>.\n            For example: \"getSeries?Collection=C4KC-KiTS&Modality=SEG\"\n            Please refer to the section of Image Metadata APIs in the link mentioned\n            above for more details.\n        attribute: Achieved metadata may contain multiple attributes, if specifying an attribute name, other attributes\n            will be ignored.\n\n    \"\"\"\n\n    if not has_requests:\n        raise ValueError(\"requests package is necessary, please install it.\")\n    full_url = f\"{BASE_URL}{query}\"\n    resp = requests_get(full_url)\n    resp.raise_for_status()\n    metadata_list: list = []\n    if len(resp.text) == 0:\n        return metadata_list\n    for d in resp.json():\n        if attribute is not None and attribute in d:\n            metadata_list.append(d[attribute])\n        else:\n            metadata_list.append(d)\n\n    return metadata_list\n\n\ndef download_tcia_series_instance(\n    series_uid: str,\n    download_dir: PathLike,\n    output_dir: PathLike,\n    check_md5: bool = False,\n    hashes_filename: str = \"md5hashes.csv\",\n    progress: bool = True,\n) -> None:\n    \"\"\"\n    Download a dicom series from a public The Cancer Imaging Archive (TCIA) dataset.\n    The downloaded compressed file will be stored in `download_dir`, and the uncompressed folder will be saved\n    in `output_dir`.\n\n    Args:\n        series_uid: SeriesInstanceUID of a dicom series.\n        download_dir: the path to store the downloaded compressed file. The full path of the file is:\n            `os.path.join(download_dir, f\"{series_uid}.zip\")`.\n        output_dir: target directory to save extracted dicom series.\n        check_md5: whether to download the MD5 hash values as well. 
If True, will check hash values for all images in\n            the downloaded dicom series.\n        hashes_filename: file that contains hashes.\n        progress: whether to display progress bar.\n\n    \"\"\"\n    query_name = \"getImageWithMD5Hash\" if check_md5 else \"getImage\"\n    download_url = f\"{BASE_URL}{query_name}?SeriesInstanceUID={series_uid}\"\n\n    monai.apps.utils.download_and_extract(\n        url=download_url,\n        filepath=os.path.join(download_dir, f\"{series_uid}.zip\"),\n        output_dir=output_dir,\n        progress=progress,\n    )\n    if check_md5:\n        if not has_pandas:\n            raise ValueError(\"pandas package is necessary, please install it.\")\n        hashes_df = pd.read_csv(os.path.join(output_dir, hashes_filename))\n        for dcm, md5hash in hashes_df.values:\n            monai.apps.utils.check_hash(filepath=os.path.join(output_dir, dcm), val=md5hash, hash_type=\"md5\")\n\n\ndef get_tcia_ref_uid(\n    ds: Iterable,\n    find_sop: bool = False,\n    ref_series_uid_tag: tuple = (0x0020, 0x000E),\n    ref_sop_uid_tag: tuple = (0x0008, 0x1155),\n) -> str:\n    \"\"\"\n    Achieve the referenced UID from the referenced Series Sequence for the input pydicom dataset object.\n    The referenced UID could be Series Instance UID or SOP Instance UID. The UID will be detected from\n    the data element of the input object. If the data element is a sequence, each dataset within the sequence\n    will be detected iteratively. 
The first detected UID will be returned.\n\n    Args:\n        ds: a pydicom dataset object.\n        find_sop: whether to retrieve the referenced SOP Instance UID.\n        ref_series_uid_tag: tag of the referenced Series Instance UID.\n        ref_sop_uid_tag: tag of the referenced SOP Instance UID.\n\n    \"\"\"\n    ref_uid_tag = ref_sop_uid_tag if find_sop else ref_series_uid_tag\n    output = \"\"\n\n    for elem in ds:\n        if elem.VR == \"SQ\":\n            for item in elem:\n                output = get_tcia_ref_uid(item, find_sop)\n        if elem.tag == ref_uid_tag:\n            return elem.value  # type: ignore[no-any-return]\n\n    return output\n\n\ndef match_tcia_ref_uid_in_study(study_uid, ref_sop_uid):\n    \"\"\"\n    Match the SeriesInstanceUID from all series in a study according to the input SOPInstanceUID.\n\n    Args:\n        study_uid: StudyInstanceUID.\n        ref_sop_uid: SOPInstanceUID.\n\n    \"\"\"\n    series_list = get_tcia_metadata(query=f\"getSeries?StudyInstanceUID={study_uid}\", attribute=\"SeriesInstanceUID\")\n    for series_id in series_list:\n        sop_id_list = get_tcia_metadata(\n            query=f\"getSOPInstanceUIDs?SeriesInstanceUID={series_id}\", attribute=\"SOPInstanceUID\"\n        )\n        if ref_sop_uid in sop_id_list:\n            return series_id\n    return \"\"\n"
  },
  {
    "path": "monai/apps/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport hashlib\nimport json\nimport logging\nimport os\nimport re\nimport shutil\nimport sys\nimport tarfile\nimport tempfile\nimport warnings\nimport zipfile\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any\nfrom urllib.error import ContentTooShortError, HTTPError, URLError\nfrom urllib.parse import urlparse\nfrom urllib.request import urlopen, urlretrieve\n\nfrom monai.config.type_definitions import PathLike\nfrom monai.utils import look_up_option, min_version, optional_import\n\nrequests, has_requests = optional_import(\"requests\")\ngdown, has_gdown = optional_import(\"gdown\", \"4.7.3\")\nBeautifulSoup, has_bs4 = optional_import(\"bs4\", name=\"BeautifulSoup\")\n\nif TYPE_CHECKING:\n    from tqdm import tqdm\n\n    has_tqdm = True\nelse:\n    tqdm, has_tqdm = optional_import(\"tqdm\", \"4.47.0\", min_version, \"tqdm\")\n\n__all__ = [\"check_hash\", \"download_url\", \"extractall\", \"download_and_extract\", \"get_logger\", \"SUPPORTED_HASH_TYPES\"]\n\nDEFAULT_FMT = \"%(asctime)s - %(levelname)s - %(message)s\"\nSUPPORTED_HASH_TYPES = {\"md5\": hashlib.md5, \"sha1\": hashlib.sha1, \"sha256\": hashlib.sha256, \"sha512\": hashlib.sha512}\n\n\ndef get_logger(\n    module_name: str = \"monai.apps\",\n    fmt: str = DEFAULT_FMT,\n    datefmt: str | None = None,\n    logger_handler: logging.Handler | None = None,\n) -> 
logging.Logger:\n    \"\"\"\n    Get a `module_name` logger with the specified format and date format.\n    By default, the logger will print to `stdout` at the INFO level.\n    If `module_name` is `None`, return the root logger.\n    `fmt` and `datefmt` are passed to a `logging.Formatter` object\n    (https://docs.python.org/3/library/logging.html#formatter-objects).\n    `logger_handler` can be used to add an additional handler.\n    \"\"\"\n    adds_stdout_handler = module_name is not None and module_name not in logging.root.manager.loggerDict\n    logger = logging.getLogger(module_name)\n    logger.propagate = False\n    logger.setLevel(logging.INFO)\n    if adds_stdout_handler:  # don't add multiple stdout or add to the root\n        handler = logging.StreamHandler(sys.stdout)\n        formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)\n        handler.setFormatter(formatter)\n        logger.addHandler(handler)\n    if logger_handler is not None:\n        logger.addHandler(logger_handler)\n    return logger\n\n\n# apps module-level default logger\nlogger = get_logger(\"monai.apps\")\n__all__.append(\"logger\")\n\n\ndef _basename(p: PathLike) -> str:\n    \"\"\"get the last part of the path (removing the trailing slash if it exists)\"\"\"\n    sep = os.path.sep + (os.path.altsep or \"\") + \"/ \"\n    return Path(f\"{p}\".rstrip(sep)).name\n\n\ndef _download_with_progress(url: str, filepath: Path, progress: bool = True) -> None:\n    \"\"\"\n    Retrieve file from `url` to `filepath`, optionally showing a progress bar.\n    \"\"\"\n    try:\n        if has_tqdm and progress:\n\n            class TqdmUpTo(tqdm):\n                \"\"\"\n                Provides `update_to(n)` which uses `tqdm.update(delta_n)`.\n                Inspired by the example in https://github.com/tqdm/tqdm.\n                \"\"\"\n\n                def update_to(self, b: int = 1, bsize: int = 1, tsize: int | None = None) -> None:\n                    \"\"\"\n                    
Args:\n                        b: number of blocks transferred so far, default: 1.\n                        bsize: size of each block (in tqdm units), default: 1.\n                        tsize: total size (in tqdm units). if None, remains unchanged.\n                    \"\"\"\n                    if tsize is not None:\n                        self.total = tsize\n                    self.update(b * bsize - self.n)  # will also set self.n = b * bsize\n\n            with TqdmUpTo(unit=\"B\", unit_scale=True, unit_divisor=1024, miniters=1, desc=_basename(filepath)) as t:\n                urlretrieve(url, filepath, reporthook=t.update_to)\n        else:\n            if not has_tqdm and progress:\n                warnings.warn(\"tqdm is not installed, will not show the downloading progress bar.\")\n            urlretrieve(url, filepath)\n    except (URLError, HTTPError, ContentTooShortError, OSError) as e:\n        logger.error(f\"Download failed from {url} to {filepath}.\")\n        raise e\n\n\ndef safe_extract_member(member, extract_to):\n    \"\"\"Securely verify compressed package member paths to prevent path traversal attacks\"\"\"\n    # Get member path (handle different compression formats)\n    if hasattr(member, \"filename\"):\n        member_path = member.filename  # zipfile\n    elif hasattr(member, \"name\"):\n        member_path = member.name  # tarfile\n    else:\n        member_path = str(member)\n\n    if hasattr(member, \"issym\") and member.issym():\n        raise ValueError(f\"Symbolic link detected in archive: {member_path}\")\n    if hasattr(member, \"islnk\") and member.islnk():\n        raise ValueError(f\"Hard link detected in archive: {member_path}\")\n\n    member_path = os.path.normpath(member_path)\n\n    if os.path.isabs(member_path) or \"..\" in member_path.split(os.sep):\n        raise ValueError(f\"Unsafe path detected in archive: {member_path}\")\n\n    full_path = os.path.join(extract_to, member_path)\n    full_path = 
os.path.normpath(full_path)\n\n    extract_root = os.path.realpath(extract_to)\n    target_real = os.path.realpath(full_path)\n    # Ensure the resolved path stays within the extraction root\n    if os.path.commonpath([extract_root, target_real]) != extract_root:\n        raise ValueError(f\"Unsafe path: path traversal {member_path}\")\n\n    return full_path\n\n\ndef check_hash(filepath: PathLike, val: str | None = None, hash_type: str = \"md5\") -> bool:\n    \"\"\"\n    Verify hash signature of specified file.\n\n    Args:\n        filepath: path of source file to verify hash value.\n        val: expected hash value of the file.\n        hash_type: type of hash algorithm to use, default is `\"md5\"`.\n            The supported hash types are `\"md5\"`, `\"sha1\"`, `\"sha256\"`, `\"sha512\"`.\n            See also: :py:data:`monai.apps.utils.SUPPORTED_HASH_TYPES`.\n\n    \"\"\"\n    if val is None:\n        logger.info(f\"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.\")\n        return True\n    actual_hash_func = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES)\n\n    actual_hash = actual_hash_func(usedforsecurity=False)  # allows checks on FIPS enabled machines\n\n    try:\n        with open(filepath, \"rb\") as f:\n            for chunk in iter(lambda: f.read(1024 * 1024), b\"\"):\n                actual_hash.update(chunk)\n    except Exception as e:\n        logger.error(f\"Exception in check_hash: {e}\")\n        return False\n    if val != actual_hash.hexdigest():\n        logger.error(f\"check_hash failed {actual_hash.hexdigest()}.\")\n        return False\n\n    logger.info(f\"Verified '{_basename(filepath)}', {hash_type}: {val}.\")\n    return True\n\n\ndef download_url(\n    url: str,\n    filepath: PathLike = \"\",\n    hash_val: str | None = None,\n    hash_type: str = \"md5\",\n    progress: bool = True,\n    **gdown_kwargs: Any,\n) -> None:\n    \"\"\"\n    Download file from specified URL link, support process bar 
and hash check.\n\n    Args:\n        url: source URL link to download file.\n        filepath: target filepath to save the downloaded file (including the filename).\n            If undefined, `os.path.basename(url)` will be used.\n        hash_val: expected hash value to validate the downloaded file.\n            if None, skip hash validation.\n        hash_type: 'md5' or 'sha1', defaults to 'md5'.\n        progress: whether to display a progress bar.\n        gdown_kwargs: other args for `gdown` except for the `url`, `output` and `quiet`.\n            these args will only be used if download from google drive.\n            details of the args of it:\n            https://github.com/wkentaro/gdown/blob/main/gdown/download.py\n\n    Raises:\n        RuntimeError: When the hash validation of the ``filepath`` existing file fails.\n        RuntimeError: When a network issue or denied permission prevents the\n            file download from ``url`` to ``filepath``.\n        URLError: See urllib.request.urlretrieve.\n        HTTPError: See urllib.request.urlretrieve.\n        ContentTooShortError: See urllib.request.urlretrieve.\n        IOError: See urllib.request.urlretrieve.\n        RuntimeError: When the hash validation of the ``url`` downloaded file fails.\n\n    \"\"\"\n    if not filepath:\n        filepath = Path(\".\", _basename(url)).resolve()\n        logger.info(f\"Default downloading to '{filepath}'\")\n    filepath = Path(filepath)\n    if filepath.exists():\n        if not check_hash(filepath, hash_val, hash_type):\n            raise RuntimeError(\n                f\"{hash_type} check of existing file failed: filepath={filepath}, expected {hash_type}={hash_val}.\"\n            )\n        logger.info(f\"File exists: {filepath}, skipped downloading.\")\n        return\n    try:\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            tmp_name = Path(tmp_dir, _basename(filepath))\n            if urlparse(url).netloc == \"drive.google.com\":\n      
          if not has_gdown:\n                    raise RuntimeError(\"To download files from Google Drive, please install the gdown dependency.\")\n                if \"fuzzy\" not in gdown_kwargs:\n                    gdown_kwargs[\"fuzzy\"] = True  # default to true for flexible url\n                gdown.download(url, f\"{tmp_name}\", quiet=not progress, **gdown_kwargs)\n            elif urlparse(url).netloc == \"cloud-api.yandex.net\":\n                with urlopen(url) as response:\n                    code = response.getcode()\n                    if code == 200:\n                        download_url = json.load(response)[\"href\"]\n                        _download_with_progress(download_url, tmp_name, progress=progress)\n                    else:\n                        raise RuntimeError(\n                            f\"Download of file from {download_url}, received from {url} \"\n                            + f\" to {filepath} failed due to network issue or denied permission.\"\n                        )\n            else:\n                _download_with_progress(url, tmp_name, progress=progress)\n            if not tmp_name.exists():\n                raise RuntimeError(\n                    f\"Download of file from {url} to {filepath} failed due to network issue or denied permission.\"\n                )\n            file_dir = filepath.parent\n            if file_dir:\n                os.makedirs(file_dir, exist_ok=True)\n            shutil.move(f\"{tmp_name}\", f\"{filepath}\")  # copy the downloaded to a user-specified cache.\n    except (PermissionError, NotADirectoryError):  # project-monai/monai issue #3613 #3757 for windows\n        pass\n    logger.info(f\"Downloaded: {filepath}\")\n    if not check_hash(filepath, hash_val, hash_type):\n        raise RuntimeError(\n            f\"{hash_type} check of downloaded file failed: URL={url}, \"\n            f\"filepath={filepath}, expected {hash_type}={hash_val}.\"\n        )\n\n\ndef 
_extract_zip(filepath, output_dir):\n    with zipfile.ZipFile(filepath, \"r\") as zip_file:\n        for member in zip_file.infolist():\n            safe_path = safe_extract_member(member, output_dir)\n            if member.is_dir():\n                continue\n            os.makedirs(os.path.dirname(safe_path), exist_ok=True)\n            with zip_file.open(member) as source:\n                with open(safe_path, \"wb\") as target:\n                    shutil.copyfileobj(source, target)\n\n\ndef _extract_tar(filepath, output_dir):\n    with tarfile.open(filepath, \"r\") as tar_file:\n        for member in tar_file.getmembers():\n            safe_path = safe_extract_member(member, output_dir)\n            if not member.isfile():\n                continue\n            os.makedirs(os.path.dirname(safe_path), exist_ok=True)\n            source = tar_file.extractfile(member)\n            if source is not None:\n                with source:\n                    with open(safe_path, \"wb\") as target:\n                        shutil.copyfileobj(source, target)\n\n\ndef extractall(\n    filepath: PathLike,\n    output_dir: PathLike = \".\",\n    hash_val: str | None = None,\n    hash_type: str = \"md5\",\n    file_type: str = \"\",\n    has_base: bool = True,\n) -> None:\n    \"\"\"\n    Extract file to the output directory.\n    Expected file types are: `zip`, `tar.gz` and `tar`.\n\n    Args:\n        filepath: the file path of compressed file.\n        output_dir: target directory to save extracted files.\n        hash_val: expected hash value to validate the compressed file.\n            if None, skip hash validation.\n        hash_type: 'md5' or 'sha1', defaults to 'md5'.\n        file_type: string of file type for decompressing. Leave it empty to infer the type from the filepath basename.\n        has_base: whether the extracted files have a base folder. 
This flag is used when checking if the existing\n            folder is a result of `extractall`, if it is, the extraction is skipped. For example, if A.zip is unzipped\n            to folder structure `A/*.png`, this flag should be True; if B.zip is unzipped to `*.png`, this flag should\n            be False.\n\n    Raises:\n        RuntimeError: When the hash validation of the ``filepath`` compressed file fails.\n        NotImplementedError: When the ``filepath`` file extension is not one of [zip\", \"tar.gz\", \"tar\"].\n\n    \"\"\"\n    if has_base:\n        # the extracted files will be in this folder\n        cache_dir = Path(output_dir, _basename(filepath).split(\".\")[0])\n    else:\n        cache_dir = Path(output_dir)\n    if cache_dir.exists() and next(cache_dir.iterdir(), None) is not None:\n        logger.info(f\"Non-empty folder exists in {cache_dir}, skipped extracting.\")\n        return\n    filepath = Path(filepath)\n    if hash_val and not check_hash(filepath, hash_val, hash_type):\n        raise RuntimeError(\n            f\"{hash_type} check of compressed file failed: \" f\"filepath={filepath}, expected {hash_type}={hash_val}.\"\n        )\n    logger.info(f\"Writing into directory: {output_dir}.\")\n    _file_type = file_type.lower().strip()\n    if filepath.name.endswith(\"zip\") or _file_type == \"zip\":\n        _extract_zip(filepath, output_dir)\n        return\n    if filepath.name.endswith(\"tar\") or filepath.name.endswith(\"tar.gz\") or \"tar\" in _file_type:\n        _extract_tar(filepath, output_dir)\n        return\n    raise NotImplementedError(\n        f'Unsupported file type, available options are: [\"zip\", \"tar.gz\", \"tar\"]. 
name={filepath} type={file_type}.'\n    )\n\n\ndef get_filename_from_url(data_url: str) -> str:\n    \"\"\"\n    Get the filename from the URL link.\n    \"\"\"\n    try:\n        response = requests.head(data_url, allow_redirects=True)\n        content_disposition = response.headers.get(\"Content-Disposition\")\n        if content_disposition:\n            filename = re.findall('filename=\"?([^\";]+)\"?', content_disposition)\n            if filename:\n                return str(filename[0])\n        if \"drive.google.com\" in data_url:\n            response = requests.get(data_url)\n            if \"text/html\" in response.headers.get(\"Content-Type\", \"\"):\n                soup = BeautifulSoup(response.text, \"html.parser\")\n                filename_div = soup.find(\"span\", {\"class\": \"uc-name-size\"})\n                if filename_div:\n                    return str(filename_div.find(\"a\").text)\n        return _basename(data_url)\n    except Exception as e:\n        raise Exception(f\"Error processing URL: {e}\") from e\n\n\ndef download_and_extract(\n    url: str,\n    filepath: PathLike = \"\",\n    output_dir: PathLike = \".\",\n    hash_val: str | None = None,\n    hash_type: str = \"md5\",\n    file_type: str = \"\",\n    has_base: bool = True,\n    progress: bool = True,\n) -> None:\n    \"\"\"\n    Download file from URL and extract it to the output directory.\n\n    Args:\n        url: source URL link to download file.\n        filepath: the file path of the downloaded compressed file.\n            use this option to keep the directly downloaded compressed file, to avoid further repeated downloads.\n        output_dir: target directory to save extracted files.\n            default is the current directory.\n        hash_val: expected hash value to validate the downloaded file.\n            if None, skip hash validation.\n        hash_type: 'md5' or 'sha1', defaults to 'md5'.\n        file_type: string of file type for decompressing. 
Leave it empty to infer the type from url's base file name.\n        has_base: whether the extracted files have a base folder. This flag is used when checking if the existing\n            folder is a result of `extractall`, if it is, the extraction is skipped. For example, if A.zip is unzipped\n            to folder structure `A/*.png`, this flag should be True; if B.zip is unzipped to `*.png`, this flag should\n            be False.\n        progress: whether to display progress bar.\n    \"\"\"\n    url_filename_ext = \"\".join(Path(get_filename_from_url(url)).suffixes)\n    filepath_ext = \"\".join(Path(_basename(filepath)).suffixes)\n    if filepath not in [\"\", \".\"]:\n        if filepath_ext == \"\":\n            new_filepath = Path(filepath).with_suffix(url_filename_ext)\n            logger.warning(\n                f\"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}\"\n            )\n            filepath = new_filepath\n    if filepath_ext and filepath_ext != url_filename_ext:\n        raise ValueError(f\"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}\")\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        filename = filepath or Path(tmp_dir, get_filename_from_url(url)).resolve()\n        download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)\n        extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)\n"
  },
  {
    "path": "monai/apps/vista3d/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/apps/vista3d/inferer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport copy\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport torch\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.utils import optional_import\n\ntqdm, _ = optional_import(\"tqdm\", name=\"tqdm\")\n\n__all__ = [\"point_based_window_inferer\"]\n\n\ndef point_based_window_inferer(\n    inputs: torch.Tensor | MetaTensor,\n    roi_size: Sequence[int],\n    predictor: torch.nn.Module,\n    point_coords: torch.Tensor,\n    point_labels: torch.Tensor,\n    class_vector: torch.Tensor | None = None,\n    prompt_class: torch.Tensor | None = None,\n    prev_mask: torch.Tensor | MetaTensor | None = None,\n    point_start: int = 0,\n    center_only: bool = True,\n    margin: int = 5,\n    **kwargs: Any,\n) -> torch.Tensor:\n    \"\"\"\n    Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image.\n    The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by\n    patch inference and average output stitching, and finally returns the segmented mask.\n\n    Args:\n        inputs: [1CHWD], input image to be processed.\n        roi_size: the spatial window size for inferences.\n            When its components have None or non-positives, the corresponding inputs dimension will be used.\n            if the 
components of the `roi_size` are non-positive values, the transform will use the\n            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        sw_batch_size: the batch size to run window slices.\n        predictor: the model. For vista3D, the output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D].\n            Add transpose=True in kwargs for vista3d.\n        point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points.\n        point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes.\n            2/3 means negative/positive points for special supported classes (e.g. tumor, vessel).\n        class_vector: [B]. Used for class-head automatic segmentation. Can be None value.\n        prompt_class: [B]. The same as class_vector representing the point class and inform point head about\n            supported class or zeroshot, not used for automatic segmentation. If None, point head is default\n            to supported class segmentation.\n        prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks.\n        point_start: only use points starting from this number. All points before this number is used to generate\n            prev_mask. This is used to avoid re-calculating the points in previous iterations if given prev_mask.\n        center_only: for each point, only crop the patch centered at this point. If false, crop 3 patches for each point.\n        margin: if center_only is false, this value is the distance between point to the patch boundary.\n    Returns:\n        stitched_output: [1, B, H, W, D]. 
The value is before sigmoid.\n    Notice: The function only supports SINGLE OBJECT INFERENCE with B=1.\n    \"\"\"\n    if not point_coords.shape[0] == 1:\n        raise ValueError(\"Only supports single object point click.\")\n    if not len(inputs.shape) == 5:\n        raise ValueError(\"Input image should be 5D.\")\n    image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size)\n    point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device)\n    prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None\n    stitched_output = None\n    for p in point_coords[0][point_start:]:\n        lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=margin)\n        ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=margin)\n        lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=margin)\n        for i in range(len(lx_)):\n            for j in range(len(ly_)):\n                for k in range(len(lz_)):\n                    lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k])\n                    unravel_slice = (\n                        slice(None),\n                        slice(None),\n                        slice(int(lx), int(rx)),\n                        slice(int(ly), int(ry)),\n                        slice(int(lz), int(rz)),\n                    )\n                    batch_image = image[unravel_slice]\n                    output = predictor(\n                        batch_image,\n                        point_coords=point_coords,\n                        point_labels=point_labels,\n                        class_vector=class_vector,\n                        prompt_class=prompt_class,\n                        patch_coords=[unravel_slice],\n                        prev_mask=prev_mask,\n                        **kwargs,\n          
          )\n                    if stitched_output is None:\n                        stitched_output = torch.zeros(\n                            [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device=\"cpu\"\n                        )\n                        stitched_mask = torch.zeros(\n                            [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device=\"cpu\"\n                        )\n                    stitched_output[unravel_slice] += output.to(\"cpu\")\n                    stitched_mask[unravel_slice] = 1\n    # if stitched_mask is 0, then NaN value\n    stitched_output = stitched_output / stitched_mask\n    # revert padding\n    stitched_output = stitched_output[\n        :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]\n    ]\n    stitched_mask = stitched_mask[\n        :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]\n    ]\n    if prev_mask is not None:\n        prev_mask = prev_mask[\n            :,\n            :,\n            pad[4] : image.shape[-3] - pad[5],\n            pad[2] : image.shape[-2] - pad[3],\n            pad[0] : image.shape[-1] - pad[1],\n        ]\n        prev_mask = prev_mask.to(\"cpu\")  # type: ignore\n        # for un-calculated place, use previous mask\n        stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1]\n    if isinstance(inputs, torch.Tensor):\n        inputs = MetaTensor(inputs)\n    if not hasattr(stitched_output, \"meta\"):\n        stitched_output = MetaTensor(stitched_output, affine=inputs.meta[\"affine\"], meta=inputs.meta)\n    return stitched_output\n\n\ndef _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]:\n    \"\"\"Helper function to get the window index.\"\"\"\n    if p - roi // 2 < 0:\n        left, right = 0, roi\n    elif p + roi // 2 > s:\n        left, right = s - roi, s\n    else:\n    
    left, right = int(p) - roi // 2, int(p) + roi // 2\n    return left, right\n\n\ndef _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]:\n    \"\"\"Get the window index.\"\"\"\n    left, right = _get_window_idx_c(p, roi, s)\n    if center_only:\n        return [left], [right]\n    left_most = max(0, p - roi + margin)\n    right_most = min(s, p + roi - margin)\n    left_list = [left_most, right_most - roi, left]\n    right_list = [left_most + roi, right_most, right]\n    return left_list, right_list\n\n\ndef _pad_previous_mask(\n    inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0\n) -> tuple[torch.Tensor | MetaTensor, list[int]]:\n    \"\"\"Helper function to pad inputs.\"\"\"\n    pad_size = []\n    for k in range(len(inputs.shape) - 1, 1, -1):\n        diff = max(roi_size[k - 2] - inputs.shape[k], 0)\n        half = diff // 2\n        pad_size.extend([half, diff - half])\n    if any(pad_size):\n        inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode=\"constant\", value=padvalue)  # type: ignore\n    return inputs, pad_size\n"
  },
  {
    "path": "monai/apps/vista3d/sampler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport copy\nimport random\nfrom collections.abc import Callable, Sequence\nfrom typing import Any\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nENABLE_SPECIAL = True\nSPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)\nMERGE_LIST = {\n    1: [25, 26],  # hepatic tumor and vessel merge into liver\n    4: [24],  # pancreatic tumor merge into pancreas\n    132: [57],  # overlap with trachea merge into airway\n}\n\n__all__ = [\"sample_prompt_pairs\"]\n\n\ndef _get_point_label(id: int) -> tuple[int, int]:\n    if id in SPECIAL_INDEX and ENABLE_SPECIAL:\n        return 2, 3\n    else:\n        return 0, 1\n\n\ndef sample_prompt_pairs(\n    labels: Tensor,\n    label_set: Sequence[int],\n    max_prompt: int | None = None,\n    max_foreprompt: int | None = None,\n    max_backprompt: int = 1,\n    max_point: int = 20,\n    include_background: bool = False,\n    drop_label_prob: float = 0.2,\n    drop_point_prob: float = 0.2,\n    point_sampler: Callable | None = None,\n    **point_sampler_kwargs: Any,\n) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:\n    \"\"\"\n    Sample training pairs for VISTA3D training.\n\n    Args:\n        labels: [1, 1, H, W, D], ground truth labels.\n        label_set: the label list for the specific dataset. 
Note if 0 is included in label_set,\n            it will be added into automatic branch training. Recommend removing 0 from label_set\n            for multi-partially-labeled-dataset training, and adding 0 for finetuning specific dataset.\n            The reason is region with 0 in one partially labeled dataset may contain foregrounds in\n            another dataset.\n        max_prompt: int, max number of total prompt, including foreground and background.\n        max_foreprompt: int, max number of prompt from foreground.\n        max_backprompt: int, max number of prompt from background.\n        max_point: maximum number of points for each object.\n        include_background: if include 0 into training prompt. If included, background 0 is treated\n            the same as foreground and points will be sampled. Can be true only if user want to segment\n            background 0 with point clicks, otherwise always be false.\n        drop_label_prob: probability to drop label prompt.\n        drop_point_prob: probability to drop point prompt.\n        point_sampler: sampler to augment masks with supervoxel.\n        point_sampler_kwargs: arguments for point_sampler.\n\n    Returns:\n        tuple:\n            - label_prompt (Tensor | None): Tensor of shape [B, 1] containing the classes used for\n              training automatic segmentation.\n            - point (Tensor | None): Tensor of shape [B, N, 3] representing the corresponding points\n              for each class. Note that background label prompts require matching points as well\n              (e.g., [0, 0, 0] is used).\n            - point_label (Tensor | None): Tensor of shape [B, N] representing the corresponding point\n              labels for each point (negative or positive). 
-1 is used for padding the background\n              label prompt and will be ignored.\n            - prompt_class (Tensor | None): Tensor of shape [B, 1], exactly the same as label_prompt\n              for label indexing during training. If label_prompt is None, prompt_class is used to\n              identify point classes.\n\n    \"\"\"\n\n    # class label number\n    if not labels.shape[0] == 1:\n        raise ValueError(\"only support batch size 1\")\n    labels = labels[0, 0]\n    device = labels.device\n    unique_labels = labels.unique().cpu().numpy().tolist()\n    if include_background:\n        unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)))\n    else:\n        unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)) - {0})\n    background_labels = list(set(label_set) - set(unique_labels))\n    # during training, balance background and foreground prompts\n    if max_backprompt is not None:\n        if len(background_labels) > max_backprompt:\n            random.shuffle(background_labels)\n            background_labels = background_labels[:max_backprompt]\n\n    if max_foreprompt is not None:\n        if len(unique_labels) > max_foreprompt:\n            random.shuffle(unique_labels)\n            unique_labels = unique_labels[:max_foreprompt]\n\n    if max_prompt is not None:\n        if len(unique_labels) + len(background_labels) > max_prompt:\n            if len(unique_labels) > max_prompt:\n                unique_labels = random.sample(unique_labels, max_prompt)\n                background_labels = []\n            else:\n                background_labels = random.sample(background_labels, max_prompt - len(unique_labels))\n    _point = []\n    _point_label = []\n    # if use regular sampling\n    if point_sampler is None:\n        num_p = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1)\n        num_n = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))))\n  
      for id in unique_labels:\n            neg_id, pos_id = _get_point_label(id)\n            plabels = labels == int(id)\n            nlabels = ~plabels\n            plabelpoints = torch.nonzero(plabels)\n            nlabelpoints = torch.nonzero(nlabels)\n            # final sampled positive points\n            num_pa = min(len(plabelpoints), num_p)\n            # final sampled negative points\n            num_na = min(len(nlabelpoints), num_n)\n            _point.append(\n                torch.stack(\n                    random.choices(plabelpoints, k=num_pa)\n                    + random.choices(nlabelpoints, k=num_na)\n                    + [torch.tensor([0, 0, 0], device=device)] * (num_p + num_n - num_pa - num_na)\n                )\n            )\n            _point_label.append(\n                torch.tensor([pos_id] * num_pa + [neg_id] * num_na + [-1] * (num_p + num_n - num_pa - num_na)).to(\n                    device\n                )\n            )\n        for _ in background_labels:\n            # pad the background labels\n            _point.append(torch.zeros(num_p + num_n, 3).to(device))  # all 0\n            _point_label.append(torch.zeros(num_p + num_n).to(device) - 1)  # -1 not a point\n    else:\n        _point, _point_label = point_sampler(unique_labels, **point_sampler_kwargs)\n        for _ in background_labels:\n            # pad the background labels\n            _point.append(torch.zeros(len(_point_label[0]), 3).to(device))  # all 0\n            _point_label.append(torch.zeros(len(_point_label[0])).to(device) - 1)  # -1 not a point\n    if len(unique_labels) == 0 and len(background_labels) == 0:\n        # if max_backprompt is 0 and len(unique_labels), there is no effective prompt and the iteration must\n        # be skipped. 
Handle this in trainer.\n        label_prompt, point, point_label, prompt_class = None, None, None, None\n    else:\n        label_prompt = torch.tensor(unique_labels + background_labels).unsqueeze(-1).to(device).long()\n        point = torch.stack(_point)\n        point_label = torch.stack(_point_label)\n        prompt_class = copy.deepcopy(label_prompt)\n        if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0:\n            label_prompt = None\n            # If label prompt is dropped, there is no need to pad with points with label -1.\n            pad = len(background_labels)\n            point = point[: len(point) - pad]  # type: ignore\n            point_label = point_label[: len(point_label) - pad]\n            prompt_class = prompt_class[: len(prompt_class) - pad]\n        else:\n            if random.uniform(0, 1) < drop_point_prob:\n                point = None\n                point_label = None\n    return label_prompt, point, point_label, prompt_class\n"
  },
  {
    "path": "monai/apps/vista3d/transforms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Sequence\n\nimport numpy as np\nimport torch\n\nfrom monai.config import DtypeLike, KeysCollection\nfrom monai.transforms import MapLabelValue\nfrom monai.transforms.transform import MapTransform\nfrom monai.transforms.utils import keep_components_with_positive_points\nfrom monai.utils import look_up_option\n\n__all__ = [\"VistaPreTransformd\", \"VistaPostTransformd\", \"Relabeld\"]\n\n\ndef _get_name_to_index_mapping(labels_dict: dict | None) -> dict:\n    \"\"\"get the label name to index mapping\"\"\"\n    name_to_index_mapping = {}\n    if labels_dict is not None:\n        name_to_index_mapping = {v.lower(): int(k) for k, v in labels_dict.items()}\n    return name_to_index_mapping\n\n\ndef _convert_name_to_index(name_to_index_mapping: dict, label_prompt: list | None) -> list | None:\n    \"\"\"convert the label name to index\"\"\"\n    if label_prompt is not None and isinstance(label_prompt, list):\n        converted_label_prompt = []\n        # for new class, add to the mapping\n        for l in label_prompt:\n            if isinstance(l, str) and not l.isdigit():\n                if l.lower() not in name_to_index_mapping:\n                    name_to_index_mapping[l.lower()] = len(name_to_index_mapping)\n        for l in label_prompt:\n            if isinstance(l, (int, str)):\n       
         converted_label_prompt.append(\n                    name_to_index_mapping.get(l.lower(), int(l) if l.isdigit() else 0) if isinstance(l, str) else int(l)\n                )\n            else:\n                converted_label_prompt.append(l)\n        return converted_label_prompt\n    return label_prompt\n\n\nclass VistaPreTransformd(MapTransform):\n    def __init__(\n        self,\n        keys: KeysCollection,\n        allow_missing_keys: bool = False,\n        special_index: Sequence[int] = (25, 26, 27, 28, 29, 117),\n        labels_dict: dict | None = None,\n        subclass: dict | None = None,\n    ) -> None:\n        \"\"\"\n        Pre-transform for Vista3d.\n\n        It performs two functionalities:\n\n        1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels),\n           convert point labels from 0 (negative), 1 (positive) to special 2 (negative), 3 (positive).\n\n        2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key].\n           e.g. 
\"lung\" label is converted to [\"left lung\", \"right lung\"].\n\n        The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B,\n        where each element is an int value of length [B, N].\n\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            special_index: the index that defines the special class.\n            subclass: a dictionary that maps a label prompt to its subclasses.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.special_index = special_index\n        self.subclass = subclass\n        self.name_to_index_mapping = _get_name_to_index_mapping(labels_dict)\n\n    def __call__(self, data):\n        label_prompt = data.get(\"label_prompt\", None)\n        point_labels = data.get(\"point_labels\", None)\n        # convert the label name to index if needed\n        label_prompt = _convert_name_to_index(self.name_to_index_mapping, label_prompt)\n        try:\n            # The evaluator will check prompt. 
The invalid prompt will be skipped here and captured by evaluator.\n            if self.subclass is not None and label_prompt is not None:\n                _label_prompt = []\n                subclass_keys = list(map(int, self.subclass.keys()))\n                for i in range(len(label_prompt)):\n                    if label_prompt[i] in subclass_keys:\n                        _label_prompt.extend(self.subclass[str(label_prompt[i])])\n                    else:\n                        _label_prompt.append(label_prompt[i])\n                data[\"label_prompt\"] = _label_prompt\n            if label_prompt is not None and point_labels is not None:\n                if label_prompt[0] in self.special_index:\n                    point_labels = np.array(point_labels)\n                    point_labels[point_labels == 0] = 2\n                    point_labels[point_labels == 1] = 3\n                    point_labels = point_labels.tolist()\n                data[\"point_labels\"] = point_labels\n        except Exception:\n            # There is specific requirements for `label_prompt` and `point_labels`.\n            # If B > 1 or `label_prompt` is in subclass_keys, `point_labels` must be None.\n            # Those formatting errors should be captured later.\n            warnings.warn(\"VistaPreTransformd failed to transform label prompt or point labels.\")\n\n        return data\n\n\nclass VistaPostTransformd(MapTransform):\n    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:\n        \"\"\"\n        Post-transform for Vista3d. 
It converts the model output logits into final segmentation masks.\n        If `label_prompt` is None, the output will be thresholded to be sequential indexes [0,1,2,...],\n        else the indexes will be [0, label_prompt[0], label_prompt[1], ...].\n        If `label_prompt` is None while `points` are provided, the model will perform postprocess to remove\n        regions that does not contain positive points.\n\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            dataset_transforms: a dictionary specifies the transform for corresponding dataset:\n                key: dataset name, value: list of data transforms.\n            dataset_key: key to get the dataset name from the data dictionary, default to \"dataset_name\".\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n\n    def __call__(self, data):\n        \"\"\"data[\"label_prompt\"] should not contain 0\"\"\"\n        for keys in self.keys:\n            if keys in data:\n                pred = data[keys]\n                object_num = pred.shape[0]\n                device = pred.device\n                if data.get(\"label_prompt\", None) is None and data.get(\"points\", None) is not None:\n                    pred = keep_components_with_positive_points(\n                        pred.unsqueeze(0),\n                        point_coords=data.get(\"points\").to(device),\n                        point_labels=data.get(\"point_labels\").to(device),\n                    )[0]\n                pred[pred < 0] = 0.0\n                # if it's multichannel, perform argmax\n                if object_num > 1:\n                    # concate background channel. 
Make sure user did not provide 0 as prompt.\n                    is_bk = torch.all(pred <= 0, dim=0, keepdim=True)\n                    pred = pred.argmax(0).unsqueeze(0).float() + 1.0\n                    pred[is_bk] = 0.0\n                else:\n                    # AsDiscrete will remove NaN\n                    # pred = monai.transforms.AsDiscrete(threshold=0.5)(pred)\n                    pred[pred > 0] = 1.0\n                if \"label_prompt\" in data and data[\"label_prompt\"] is not None:\n                    pred += 0.5  # inplace mapping to avoid cloning pred\n                    label_prompt = data[\"label_prompt\"].to(device)  # Ensure label_prompt is on the same device\n                    for i in range(1, object_num + 1):\n                        frac = i + 0.5\n                        pred[pred == frac] = label_prompt[i - 1].to(pred.dtype)\n                    pred[pred == 0.5] = 0.0\n                data[keys] = pred\n        return data\n\n\nclass Relabeld(MapTransform):\n    def __init__(\n        self,\n        keys: KeysCollection,\n        label_mappings: dict[str, list[tuple[int, int]]],\n        dtype: DtypeLike = np.int16,\n        dataset_key: str = \"dataset_name\",\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Remap the voxel labels in the input data dictionary based on the specified mapping.\n\n        This list of local -> global label mappings will be applied to each input `data[keys]`.\n        if `data[dataset_key]` is not in `label_mappings`, label_mappings['default']` will be used.\n        if `label_mappings[data[dataset_key]]` is None, no relabeling will be performed.\n\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            label_mappings: a dictionary specifies how local dataset class indices are mapped to the\n                global class indices. 
The dictionary keys are dataset names and the values are lists of\n                list of (local label, global label) pairs. This list of local -> global label mappings\n                will be applied to each input `data[keys]`. If `data[dataset_key]` is not in `label_mappings`,\n                label_mappings['default']` will be used. if `label_mappings[data[dataset_key]]` is None,\n                no relabeling will be performed. Please set `label_mappings={}` to completely skip this transform.\n            dtype: convert the output data to dtype, default to float32.\n            dataset_key: key to get the dataset name from the data dictionary, default to \"dataset_name\".\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.mappers = {}\n        self.dataset_key = dataset_key\n        for name, mapping in label_mappings.items():\n            self.mappers[name] = MapLabelValue(\n                orig_labels=[int(pair[0]) for pair in mapping],\n                target_labels=[int(pair[1]) for pair in mapping],\n                dtype=dtype,\n            )\n\n    def __call__(self, data):\n        d = dict(data)\n        dataset_name = d.get(self.dataset_key, \"default\")\n        _m = look_up_option(dataset_name, self.mappers, default=None)\n        if _m is None:\n            return d\n        for key in self.key_iterator(d):\n            d[key] = _m(d[key])\n        return d\n"
  },
  {
    "path": "monai/auto3dseg/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .algo_gen import Algo, AlgoGen\nfrom .analyzer import (\n    Analyzer,\n    FgImageStats,\n    FgImageStatsSumm,\n    FilenameStats,\n    ImageStats,\n    ImageStatsSumm,\n    LabelStats,\n    LabelStatsSumm,\n)\nfrom .operations import Operations, SampleOperations, SummaryOperations\nfrom .seg_summarizer import SegSummarizer\nfrom .utils import (\n    algo_from_pickle,\n    algo_to_pickle,\n    concat_multikeys_to_dict,\n    concat_val_to_np,\n    datafold_read,\n    get_foreground_image,\n    get_foreground_label,\n    get_label_ccp,\n    verify_report_format,\n)\n"
  },
  {
    "path": "monai/auto3dseg/algo_gen.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom monai.config import PathLike\nfrom monai.transforms import Randomizable\n\n\nclass Algo:\n    \"\"\"\n    An algorithm in this context is loosely defined as a data processing pipeline consisting of multiple components\n    such as image preprocessing, followed by deep learning model training and evaluation.\n    \"\"\"\n\n    template_path: PathLike | None = None\n\n    def set_data_stats(self, *args, **kwargs):\n        \"\"\"Provide dataset (and summaries) so that the model creation can depend on the input datasets.\"\"\"\n\n    def train(self, *args, **kwargs):\n        \"\"\"Read training/validation data and output a model.\"\"\"\n\n    def predict(self, *args, **kwargs):\n        \"\"\"Read test data and output model predictions.\"\"\"\n\n    def get_score(self, *args, **kwargs):\n        \"\"\"Returns the model quality measurement based on training and validation datasets.\"\"\"\n\n    def get_output_path(self, *args, **kwargs):\n        \"\"\"Returns the algo output paths for scripts location\"\"\"\n\n\nclass AlgoGen(Randomizable):\n    \"\"\"\n    A data-driven algorithm generator. 
It optionally takes the following inputs:\n\n        - training dataset properties (such as data statistics from ``monai.auto3dseg.analyzer``),\n        - previous algorithm's scores measuring the model quality,\n        - computational budgets,\n\n    and generates ``Algo`` instances. The generated algos are to be trained with the training datasets::\n\n                                  scores\n                        +------------------------+\n                        |   +---------+          |\n        +-----------+   +-->|         |    +-----+----+\n        | Dataset,  |       | AlgoGen |--->|   Algo   |\n        | summaries |------>|         |    +----------+\n        +-----+-----+       +---------+          ^\n              |                                  |\n              +----------------------------------+\n\n    This class also maintains a history of previously generated Algo and their corresponding validation scores.\n    The Algo generation process may be stochastic (using ``Randomizable.R`` as the source random state).\n    \"\"\"\n\n    def set_data_stats(self, *args, **kwargs):  # type ignore\n        \"\"\"Provide dataset summaries/properties so that the generator can be conditioned on the input datasets.\"\"\"\n\n    def set_budget(self, *args, **kwargs):\n        \"\"\"Provide computational budget so that the generator outputs algorithms that requires reasonable resources.\"\"\"\n\n    def set_score(self, *args, **kwargs):\n        \"\"\"Feedback from the previously generated algo, the score can be used for new Algo generations.\"\"\"\n\n    def get_data_stats(self, *args, **kwargs):\n        \"\"\"Get current dataset summaries.\"\"\"\n\n    def get_budget(self, *args, **kwargs):\n        \"\"\"Get the current computational budget.\"\"\"\n\n    def get_history(self, *args, **kwargs):\n        \"\"\"Get the previously generated algo.\"\"\"\n\n    def generate(self):\n        \"\"\"Generate new Algo -- based on data_stats, budget, and history of 
previous algo generations.\"\"\"\n\n    def run_algo(self, *args, **kwargs):\n        \"\"\"\n        Launch the Algos. This is useful for light-weight Algos where there's no need to distribute the training jobs.\n\n        If the generated Algos require significant scheduling of parallel executions, a job scheduler/controller\n        implemented separately is preferred to run them. In this case the controller should also report back the\n        scores and the algo history, so that the future ``AlgoGen.generate`` can leverage the information.\n        \"\"\"\n"
  },
  {
    "path": "monai/auto3dseg/analyzer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport time\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Hashable, Mapping\nfrom copy import deepcopy\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.apps.utils import get_logger\nfrom monai.auto3dseg.operations import Operations, SampleOperations, SummaryOperations\nfrom monai.auto3dseg.utils import (\n    concat_multikeys_to_dict,\n    concat_val_to_np,\n    get_foreground_image,\n    get_foreground_label,\n    get_label_ccp,\n    verify_report_format,\n)\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.bundle.utils import ID_SEP_KEY\nfrom monai.data import MetaTensor, affine_to_spacing\nfrom monai.transforms.transform import MapTransform\nfrom monai.transforms.utils_pytorch_numpy_unification import sum, unique\nfrom monai.utils import convert_to_numpy\nfrom monai.utils.enums import DataStatsKeys, ImageStatsKeys, LabelStatsKeys\nfrom monai.utils.misc import ImageMetaKey, label_union\n\nlogger = get_logger(module_name=__name__)\n\n__all__ = [\n    \"Analyzer\",\n    \"ImageStats\",\n    \"FgImageStats\",\n    \"LabelStats\",\n    \"ImageStatsSumm\",\n    \"FgImageStatsSumm\",\n    \"LabelStatsSumm\",\n    \"FilenameStats\",\n    \"ImageHistogram\",\n    \"ImageHistogramSumm\",\n]\n\n\nclass Analyzer(MapTransform, ABC):\n    \"\"\"\n    The Analyzer component is a base 
class. Other classes inherit this class will provide a callable\n    with the same class name and produces one pre-formatted dictionary for the input data. The format\n    is pre-defined by the init function of the class that inherit this base class. Function operations\n    can also be registered before the runtime of the callable.\n\n    Args:\n        report_format: a dictionary that outlines the key structures of the report format.\n\n    \"\"\"\n\n    def __init__(self, stats_name: str, report_format: dict) -> None:\n        super().__init__(None)\n        parser = ConfigParser(report_format, globals=False)  # ConfigParser.globals not picklable\n        self.report_format = parser.get(\"\")\n        self.stats_name = stats_name\n        self.ops = ConfigParser({}, globals=False)\n\n    def update_ops(self, key: str, op: Operations) -> None:\n        \"\"\"\n        Register a statistical operation to the Analyzer and update the report_format.\n\n        Args:\n            key: value key in the report.\n            op: Operation sub-class object that represents statistical operations.\n\n        \"\"\"\n        self.ops[key] = op\n        parser = ConfigParser(self.report_format)\n\n        if parser.get(key, \"None\") != \"None\":\n            parser[key] = op\n\n        self.report_format = parser.get(\"\")\n\n    def update_ops_nested_label(self, nested_key: str, op: Operations) -> None:\n        \"\"\"\n        Update operations for nested label format. Operation value in report_format will be resolved\n        to a dict with only keys.\n\n        Args:\n            nested_key: str that has format of 'key1#0#key2'.\n            op: Operation sub-class object that represents statistical operations.\n        \"\"\"\n        keys = nested_key.split(ID_SEP_KEY)\n        if len(keys) != 3:\n            raise ValueError(\"Nested_key input format is wrong. 
Please ensure it is like key1#0#key2\")\n        root: str\n        child_key: str\n        (root, _, child_key) = keys\n        if root not in self.ops:\n            self.ops[root] = [{}]\n        self.ops[root][0].update({child_key: None})\n\n        self.ops[nested_key] = op\n\n        parser = ConfigParser(self.report_format)\n        if parser.get(nested_key, \"NA\") != \"NA\":\n            parser[nested_key] = op\n\n    def get_report_format(self) -> dict:\n        \"\"\"\n        Get the report format by resolving the registered operations recursively.\n\n        Returns:\n            a dictionary with {keys: None} pairs.\n\n        \"\"\"\n        self.resolve_format(self.report_format)\n        return self.report_format  # type: ignore[no-any-return]\n\n    @staticmethod\n    def unwrap_ops(func):\n        \"\"\"\n        Unwrap a function value and generates the same set keys in a dict when the function is actually\n        called in runtime\n\n        Args:\n            func: Operation sub-class object that represents statistical operations. The func object\n                should have a `data` dictionary which stores the statistical operation information.\n                For some operations (ImageStats for example), it may also contain the data_addon\n                property, which is part of the update process.\n\n        Returns:\n            a dict with a set of keys.\n\n        \"\"\"\n        ret = dict.fromkeys(list(func.data))\n        if hasattr(func, \"data_addon\"):\n            for key in func.data_addon:\n                ret.update({key: None})\n        return ret\n\n    def resolve_format(self, report: dict) -> None:\n        \"\"\"\n        Resolve the format of the pre-defined report.\n\n        Args:\n            report: the dictionary to resolve. 
Values will be replaced in-place.\n\n        \"\"\"\n        for k, v in report.items():\n            if isinstance(v, Operations):\n                report[k] = self.unwrap_ops(v)\n            elif isinstance(v, list) and len(v) > 0:\n                self.resolve_format(v[0])\n            else:\n                report[k] = v\n\n    @abstractmethod\n    def __call__(self, data: Any) -> dict:\n        \"\"\"Analyze the dict format dataset, return the summary report\"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n\nclass ImageStats(Analyzer):\n    \"\"\"\n    Analyzer to extract image stats properties for each case(image).\n\n    Args:\n        image_key: the key to find image data in the callable function input (data)\n\n    Examples:\n\n    .. code-block:: python\n\n        import numpy as np\n        from monai.auto3dseg import ImageStats\n        from monai.data import MetaTensor\n\n        input = {}\n        input['image'] = np.random.rand(1,30,30,30)\n        input['image'] = MetaTensor(np.random.rand(1,30,30,30))  # MetaTensor\n        analyzer = ImageStats(image_key=\"image\")\n        print(analyzer(input)[\"image_stats\"])\n\n    Notes:\n        if the image data is NumPy array, the spacing stats will be [1.0] * `ndims` of the array,\n        where the `ndims` is the lesser value between the image dimension and 3.\n\n    \"\"\"\n\n    def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS) -> None:\n        if not isinstance(image_key, str):\n            raise ValueError(\"image_key input must be str\")\n\n        self.image_key = image_key\n\n        report_format = {\n            ImageStatsKeys.SHAPE: None,\n            ImageStatsKeys.CHANNELS: None,\n            ImageStatsKeys.CROPPED_SHAPE: None,\n            ImageStatsKeys.SPACING: None,\n            ImageStatsKeys.SIZEMM: None,\n            ImageStatsKeys.INTENSITY: None,\n        }\n\n        
super().__init__(stats_name, report_format)\n        self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())\n\n    def __call__(self, data):\n        # Input Validation Addition\n        if not isinstance(data, dict):\n            raise TypeError(f\"Input data must be a dict, but got {type(data).__name__}.\")\n        if self.image_key not in data:\n            raise KeyError(f\"Key '{self.image_key}' not found in input data.\")\n        image = data[self.image_key]\n        if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):\n            raise TypeError(\n                f\"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, \"\n                f\"but got {type(image).__name__}.\"\n            )\n        if image.ndim < 3:\n            raise ValueError(\n                f\"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}.\"\n            )\n            # --- End of validation ---\n        \"\"\"\n        Callable to execute the pre-defined functions\n\n        Returns:\n            A dictionary. The dict has the key in self.report_format. The value of\n            ImageStatsKeys.INTENSITY is in a list format. Each element of the value list\n            has stats pre-defined by SampleOperations (max, min, ....).\n\n        Raises:\n            RuntimeError if the stats report generated is not consistent with the pre-\n                defined report_format.\n\n        Note:\n            The stats operation uses numpy and torch to compute max, min, and other\n            functions. 
If the input has nan/inf, the stats results will be nan/inf.\n\n        \"\"\"\n        d = dict(data)\n        start = time.time()\n        restore_grad_state = torch.is_grad_enabled()\n        torch.set_grad_enabled(False)\n\n        ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]\n        if \"nda_croppeds\" not in d:\n            nda_croppeds = [get_foreground_image(nda) for nda in ndas]\n\n        # perform calculation\n        report = deepcopy(self.get_report_format())\n\n        report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]\n        report[ImageStatsKeys.CHANNELS] = len(ndas)\n        report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]\n        report[ImageStatsKeys.SPACING] = (\n            affine_to_spacing(data[self.image_key].affine).tolist()\n            if isinstance(data[self.image_key], MetaTensor)\n            else [1.0] * min(3, data[self.image_key].ndim)\n        )\n\n        report[ImageStatsKeys.SIZEMM] = [\n            a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])\n        ]\n\n        report[ImageStatsKeys.INTENSITY] = [\n            self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds\n        ]\n\n        if not verify_report_format(report, self.get_report_format()):\n            raise RuntimeError(f\"report generated by {self.__class__} differs from the report format.\")\n\n        d[self.stats_name] = report\n\n        torch.set_grad_enabled(restore_grad_state)\n        logger.debug(f\"Get image stats spent {time.time() - start}\")\n        return d\n\n\nclass FgImageStats(Analyzer):\n    \"\"\"\n    Analyzer to extract foreground label properties for each case(image and label).\n\n    Args:\n        image_key: the key to find image data in the callable function input (data)\n        label_key: the key to find label data in the callable function input (data)\n\n    Examples:\n\n    .. 
code-block:: python\n\n        import numpy as np\n        from monai.auto3dseg import FgImageStats\n\n        input = {}\n        input['image'] = np.random.rand(1,30,30,30)\n        input['label'] = np.ones([30,30,30])\n        analyzer = FgImageStats(image_key='image', label_key='label')\n        print(analyzer(input)[\"image_foreground_stats\"])\n\n    \"\"\"\n\n    def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKeys.FG_IMAGE_STATS):\n        self.image_key = image_key\n        self.label_key = label_key\n\n        report_format = {ImageStatsKeys.INTENSITY: None}\n\n        super().__init__(stats_name, report_format)\n        self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())\n\n    def __call__(self, data: Mapping) -> dict:\n        \"\"\"\n        Callable to execute the pre-defined functions\n\n        Returns:\n            A dictionary. The dict has the key in self.report_format and value\n            in a list format. Each element of the value list has stats pre-defined\n            by SampleOperations (max, min, ....).\n\n        Raises:\n            RuntimeError if the stats report generated is not consistent with the pre-\n                defined report_format.\n\n        Note:\n            The stats operation uses numpy and torch to compute max, min, and other\n            functions. 
If the input has nan/inf, the stats results will be nan/inf.\n        \"\"\"\n\n        d = dict(data)\n        start = time.time()\n        restore_grad_state = torch.is_grad_enabled()\n        torch.set_grad_enabled(False)\n\n        ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]\n        ndas_label = d[self.label_key]  # (H,W,D)\n\n        if ndas_label.shape != ndas[0].shape:\n            raise ValueError(f\"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}\")\n\n        nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]\n        nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]\n\n        # perform calculation\n        report = deepcopy(self.get_report_format())\n\n        report[ImageStatsKeys.INTENSITY] = [\n            self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds\n        ]\n\n        if not verify_report_format(report, self.get_report_format()):\n            raise RuntimeError(f\"report generated by {self.__class__} differs from the report format.\")\n\n        d[self.stats_name] = report\n\n        torch.set_grad_enabled(restore_grad_state)\n        logger.debug(f\"Get foreground image stats spent {time.time() - start}\")\n        return d\n\n\nclass LabelStats(Analyzer):\n    \"\"\"\n    Analyzer to extract label stats properties for each case(image and label).\n\n    Args:\n        image_key: the key to find image data in the callable function input (data)\n        label_key: the key to find label data in the callable function input (data)\n        do_ccp: performs connected component analysis. Default is True.\n\n    Examples:\n\n    .. 
code-block:: python\n\n        import numpy as np\n        from monai.auto3dseg import LabelStats\n\n        input = {}\n        input['image'] = np.random.rand(1,30,30,30)\n        input['label'] = np.ones([30,30,30])\n        analyzer = LabelStats(image_key='image', label_key='label')\n        print(analyzer(input)[\"label_stats\"])\n\n    \"\"\"\n\n    def __init__(\n        self, image_key: str, label_key: str, stats_name: str = DataStatsKeys.LABEL_STATS, do_ccp: bool | None = True\n    ):\n        self.image_key = image_key\n        self.label_key = label_key\n        self.do_ccp = do_ccp\n\n        report_format: dict[LabelStatsKeys, Any] = {\n            LabelStatsKeys.LABEL_UID: None,\n            LabelStatsKeys.IMAGE_INTST: None,\n            LabelStatsKeys.LABEL: [{LabelStatsKeys.PIXEL_PCT: None, LabelStatsKeys.IMAGE_INTST: None}],\n        }\n\n        if self.do_ccp:\n            report_format[LabelStatsKeys.LABEL][0].update(\n                {LabelStatsKeys.LABEL_SHAPE: None, LabelStatsKeys.LABEL_NCOMP: None}\n            )\n\n        super().__init__(stats_name, report_format)\n        self.update_ops(LabelStatsKeys.IMAGE_INTST, SampleOperations())\n\n        id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, \"0\", LabelStatsKeys.IMAGE_INTST])\n        self.update_ops_nested_label(id_seq, SampleOperations())\n\n    def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor | dict]:\n        \"\"\"\n        Callable to execute the pre-defined functions.\n\n        Returns:\n            A dictionary. The dict has the key in self.report_format and value\n            in a list format. 
Each element of the value list has stats pre-defined\n            by SampleOperations (max, min, ....).\n\n        Examples:\n            output dict contains {\n                LabelStatsKeys.LABEL_UID:[0,1,3],\n                LabelStatsKeys.IMAGE_INTST: {...},\n                LabelStatsKeys.LABEL:[\n                    {\n                        LabelStatsKeys.PIXEL_PCT: 0.8,\n                        LabelStatsKeys.IMAGE_INTST: {...},\n                        LabelStatsKeys.LABEL_SHAPE: [...],\n                        LabelStatsKeys.LABEL_NCOMP: 1\n                    }\n                    {\n                        LabelStatsKeys.PIXEL_PCT: 0.1,\n                        LabelStatsKeys.IMAGE_INTST: {...},\n                        LabelStatsKeys.LABEL_SHAPE: [...],\n                        LabelStatsKeys.LABEL_NCOMP: 1\n                    }\n                    {\n                        LabelStatsKeys.PIXEL_PCT: 0.1,\n                        LabelStatsKeys.IMAGE_INTST: {...},\n                        LabelStatsKeys.LABEL_SHAPE: [...],\n                        LabelStatsKeys.LABEL_NCOMP: 1\n                    }\n                ]\n                }\n\n        Raises:\n            RuntimeError if the stats report generated is not consistent with the pre-\n                defined report_format.\n\n        Notes:\n            The label class_ID of the dictionary in LabelStatsKeys.LABEL IS NOT the\n            index. Instead, the class_ID is the LabelStatsKeys.LABEL_UID with the same\n            index. For instance, the last dict in LabelStatsKeys.LABEL in the Examples\n            is 3, which is the last element under LabelStatsKeys.LABEL_UID.\n\n            The stats operation uses numpy and torch to compute max, min, and other\n            functions. 
If the input has nan/inf, the stats results will be nan/inf.\n        \"\"\"\n        d: dict[Hashable, MetaTensor] = dict(data)\n        start = time.time()\n        if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == \"cuda\":\n            using_cuda = True\n        else:\n            using_cuda = False\n        restore_grad_state = torch.is_grad_enabled()\n        torch.set_grad_enabled(False)\n\n        ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]  # type: ignore\n        ndas_label: MetaTensor = d[self.label_key].astype(torch.int16)  # (H,W,D)\n\n        if ndas_label.shape != ndas[0].shape:\n            raise ValueError(f\"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}\")\n\n        nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]\n        nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds]\n\n        unique_label = unique(ndas_label)\n        if isinstance(ndas_label, (MetaTensor, torch.Tensor)):\n            unique_label = unique_label.data.cpu().numpy()  # type: ignore[assignment]\n\n        unique_label = unique_label.astype(np.int16).tolist()\n\n        label_substats = []  # each element is one label\n        pixel_sum = 0\n        pixel_arr = []\n        for index in unique_label:\n            start_label = time.time()\n            label_dict: dict[str, Any] = {}\n            mask_index = ndas_label == index\n\n            nda_masks = [nda[mask_index] for nda in ndas]\n            label_dict[LabelStatsKeys.IMAGE_INTST] = [\n                self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks\n            ]\n\n            pixel_count = sum(mask_index)\n            pixel_arr.append(pixel_count)\n            pixel_sum += pixel_count\n            if self.do_ccp:  # apply connected component\n                if using_cuda:\n       
             # The back end of get_label_ccp is CuPy\n                    # which is unable to automatically release CUDA GPU memory held by PyTorch\n                    del nda_masks\n                    torch.cuda.empty_cache()\n                shape_list, ncomponents = get_label_ccp(mask_index)\n                label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list\n                label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents\n\n            label_substats.append(label_dict)\n            logger.debug(f\" label {index} stats takes {time.time() - start_label}\")\n\n        for i, _ in enumerate(unique_label):\n            label_substats[i].update({LabelStatsKeys.PIXEL_PCT: float(pixel_arr[i] / pixel_sum)})\n\n        report = deepcopy(self.get_report_format())\n        report[LabelStatsKeys.LABEL_UID] = unique_label\n        report[LabelStatsKeys.IMAGE_INTST] = [\n            self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds\n        ]\n        report[LabelStatsKeys.LABEL] = label_substats\n\n        if not verify_report_format(report, self.get_report_format()):\n            raise RuntimeError(f\"report generated by {self.__class__} differs from the report format.\")\n\n        d[self.stats_name] = report  # type: ignore[assignment]\n\n        torch.set_grad_enabled(restore_grad_state)\n        logger.debug(f\"Get label stats spent {time.time() - start}\")\n        return d  # type: ignore[return-value]\n\n\nclass ImageStatsSumm(Analyzer):\n    \"\"\"\n    This summary analyzer processes the values of specific key `stats_name` in a list of dict.\n    Typically, the list of dict is the output of case analyzer under the same prefix\n    (ImageStats).\n\n    Args:\n        stats_name: the key of the to-process value in the dict.\n        average: whether to average the statistical value across different image modalities.\n\n    \"\"\"\n\n    def __init__(self, stats_name: str = DataStatsKeys.IMAGE_STATS, average: bool | None = 
True):\n        self.summary_average = average\n        report_format = {\n            ImageStatsKeys.SHAPE: None,\n            ImageStatsKeys.CHANNELS: None,\n            ImageStatsKeys.CROPPED_SHAPE: None,\n            ImageStatsKeys.SPACING: None,\n            ImageStatsKeys.SIZEMM: None,\n            ImageStatsKeys.INTENSITY: None,\n        }\n        super().__init__(stats_name, report_format)\n\n        self.update_ops(ImageStatsKeys.SHAPE, SampleOperations())\n        self.update_ops(ImageStatsKeys.CHANNELS, SampleOperations())\n        self.update_ops(ImageStatsKeys.CROPPED_SHAPE, SampleOperations())\n        self.update_ops(ImageStatsKeys.SPACING, SampleOperations())\n        self.update_ops(ImageStatsKeys.SIZEMM, SampleOperations())\n        self.update_ops(ImageStatsKeys.INTENSITY, SummaryOperations())\n\n    def __call__(self, data: list[dict]) -> dict:\n        \"\"\"\n        Callable to execute the pre-defined functions\n\n        Returns:\n            A dictionary. The dict has the key in self.report_format and value\n            in a list format. Each element of the value list has stats pre-defined\n            by SampleOperations (max, min, ....).\n\n        Raises:\n            RuntimeError if the stats report generated is not consistent with the pre-\n                defined report_format.\n\n        Examples:\n            output dict contains a dictionary for all of the following keys{\n                ImageStatsKeys.SHAPE:{...}\n                ImageStatsKeys.CHANNELS: {...},\n                ImageStatsKeys.CROPPED_SHAPE: {...},\n                ImageStatsKeys.SPACING: {...},\n                ImageStatsKeys.SIZEMM: {...},\n                ImageStatsKeys.INTENSITY: {...},\n                }\n\n        Notes:\n            The stats operation uses numpy and torch to compute max, min, and other\n            functions. 
If the input has nan/inf, the stats results will be nan/inf.\n        \"\"\"\n        if not isinstance(data, list):\n            raise ValueError(f\"Callable {self.__class__} requires list inputs\")\n\n        if len(data) == 0:\n            raise ValueError(f\"Callable {self.__class__} input list is empty\")\n\n        if self.stats_name not in data[0]:\n            raise KeyError(f\"{self.stats_name} is not in input data\")\n\n        report = deepcopy(self.get_report_format())\n\n        for k in [\n            ImageStatsKeys.SHAPE,\n            ImageStatsKeys.CHANNELS,\n            ImageStatsKeys.CROPPED_SHAPE,\n            ImageStatsKeys.SPACING,\n            ImageStatsKeys.SIZEMM,\n        ]:\n            v_np = concat_val_to_np(data, [self.stats_name, k])\n            report[k] = self.ops[k].evaluate(v_np, dim=(0, 1) if v_np.ndim > 2 and self.summary_average else 0)\n\n        intst_str = ImageStatsKeys.INTENSITY\n        op_keys = report[intst_str].keys()  # template, max/min/...\n        intst_dict = concat_multikeys_to_dict(data, [self.stats_name, intst_str], op_keys)\n        report[intst_str] = self.ops[intst_str].evaluate(intst_dict, dim=None if self.summary_average else 0)\n\n        if not verify_report_format(report, self.get_report_format()):\n            raise RuntimeError(f\"report generated by {self.__class__} differs from the report format.\")\n\n        return report\n\n\nclass FgImageStatsSumm(Analyzer):\n    \"\"\"\n    This summary analyzer processes the values of specific key `stats_name` in a list of\n    dict. 
Typically, the list of dict is the output of case analyzer under the similar name\n    (FgImageStats).\n\n    Args:\n        stats_name: the key of the to-process value in the dict.\n        average: whether to average the statistical value across different image modalities.\n\n    \"\"\"\n\n    def __init__(self, stats_name: str = DataStatsKeys.FG_IMAGE_STATS, average: bool | None = True):\n        self.summary_average = average\n\n        report_format = {ImageStatsKeys.INTENSITY: None}\n        super().__init__(stats_name, report_format)\n        self.update_ops(ImageStatsKeys.INTENSITY, SummaryOperations())\n\n    def __call__(self, data: list[dict]) -> dict:\n        \"\"\"\n        Callable to execute the pre-defined functions.\n\n        Returns:\n            A dictionary. The dict has the key in self.report_format and value\n            in a list format. Each element of the value list has stats pre-defined\n            by SampleOperations (max, min, ....) and SummaryOperation (max of the\n            max, mean of the mean, etc).\n\n        Raises:\n            RuntimeError if the stats report generated is not consistent with the pre-\n                defined report_format.\n\n        Examples:\n            output dict contains a dictionary for all of the following keys{\n                ImageStatsKeys.INTENSITY: {...},\n                }\n\n        Notes:\n            The stats operation uses numpy and torch to compute max, min, and other\n            functions. 
If the input has nan/inf, the stats results will be nan/inf.\n        \"\"\"\n        if not isinstance(data, list):\n            raise ValueError(f\"Callable {self.__class__} requires list inputs\")\n\n        if len(data) == 0:\n            raise ValueError(f\"Callable {self.__class__} input list is empty\")\n\n        if self.stats_name not in data[0]:\n            raise KeyError(f\"{self.stats_name} is not in input data.\")\n\n        report = deepcopy(self.get_report_format())\n        intst_str = ImageStatsKeys.INTENSITY\n        op_keys = report[intst_str].keys()  # template, max/min/...\n        intst_dict = concat_multikeys_to_dict(data, [self.stats_name, intst_str], op_keys)\n\n        report[intst_str] = self.ops[intst_str].evaluate(intst_dict, dim=None if self.summary_average else 0)\n\n        if not verify_report_format(report, self.get_report_format()):\n            raise RuntimeError(f\"report generated by {self.__class__} differs from the report format.\")\n\n        return report\n\n\nclass LabelStatsSumm(Analyzer):\n    \"\"\"\n    This summary analyzer processes the values of specific key `stats_name` in a list of\n    dict. 
Typically, the list of dict is the output of case analyzer under the similar name\n    (LabelStats).\n\n    Args:\n        stats_name: the key of the to-process value in the dict.\n        average: whether to average the statistical value across different image modalities.\n\n    \"\"\"\n\n    def __init__(\n        self, stats_name: str = DataStatsKeys.LABEL_STATS, average: bool | None = True, do_ccp: bool | None = True\n    ):\n        self.summary_average = average\n        self.do_ccp = do_ccp\n\n        report_format: dict[str, Any] = {\n            LabelStatsKeys.LABEL_UID: None,\n            LabelStatsKeys.IMAGE_INTST: None,\n            LabelStatsKeys.LABEL: [{LabelStatsKeys.PIXEL_PCT: None, LabelStatsKeys.IMAGE_INTST: None}],\n        }\n        if self.do_ccp:\n            report_format[LabelStatsKeys.LABEL][0].update(\n                {LabelStatsKeys.LABEL_SHAPE: None, LabelStatsKeys.LABEL_NCOMP: None}\n            )\n\n        super().__init__(stats_name, report_format)\n        self.update_ops(LabelStatsKeys.IMAGE_INTST, SummaryOperations())\n\n        # label-0-'pixel percentage'\n        id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, \"0\", LabelStatsKeys.PIXEL_PCT])\n        self.update_ops_nested_label(id_seq, SampleOperations())\n        # label-0-'image intensity'\n        id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, \"0\", LabelStatsKeys.IMAGE_INTST])\n        self.update_ops_nested_label(id_seq, SummaryOperations())\n        # label-0-shape\n        id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, \"0\", LabelStatsKeys.LABEL_SHAPE])\n        self.update_ops_nested_label(id_seq, SampleOperations())\n        # label-0-ncomponents\n        id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, \"0\", LabelStatsKeys.LABEL_NCOMP])\n        self.update_ops_nested_label(id_seq, SampleOperations())\n\n    def __call__(self, data: list[dict]) -> dict:\n        \"\"\"\n        Callable to execute the pre-defined functions\n\n        Returns:\n            A 
dictionary. The dict has the key in self.report_format and value\n            in a list format. Each element of the value list has stats pre-defined\n            by SampleOperations (max, min, ....) and SummaryOperation (max of the\n            max, mean of the mean, etc).\n\n        Raises:\n            RuntimeError if the stats report generated is not consistent with the pre-\n                defined report_format.\n\n        Notes:\n            The stats operation uses numpy and torch to compute max, min, and other\n            functions. If the input has nan/inf, the stats results will be nan/inf.\n        \"\"\"\n        if not isinstance(data, list):\n            raise ValueError(f\"Callable {self.__class__} requires list inputs\")\n\n        if len(data) == 0:\n            raise ValueError(f\"Callable {self.__class__} input list is empty\")\n\n        if self.stats_name not in data[0]:\n            raise KeyError(f\"{self.stats_name} is not in input data\")\n\n        report = deepcopy(self.get_report_format())\n        # unique class ID\n        uid_np = concat_val_to_np(data, [self.stats_name, LabelStatsKeys.LABEL_UID], axis=None, ragged=True)\n        unique_label = label_union(uid_np)\n        report[LabelStatsKeys.LABEL_UID] = unique_label\n\n        # image intensity\n        intst_str = LabelStatsKeys.IMAGE_INTST\n        op_keys = report[intst_str].keys()  # template, max/min/...\n        intst_dict = concat_multikeys_to_dict(data, [self.stats_name, intst_str], op_keys)\n        report[intst_str] = self.ops[intst_str].evaluate(intst_dict, dim=None if self.summary_average else 0)\n\n        detailed_label_list = []\n        # iterate through each label\n        label_str = LabelStatsKeys.LABEL\n        for label_id in unique_label:\n            stats = {}\n\n            pct_str = LabelStatsKeys.PIXEL_PCT\n            pct_fixed_keys = [self.stats_name, label_str, label_id, pct_str]\n            pct_np = concat_val_to_np(data, pct_fixed_keys, 
allow_missing=True)\n            stats[pct_str] = self.ops[label_str][0][pct_str].evaluate(\n                pct_np, dim=(0, 1) if pct_np.ndim > 2 and self.summary_average else 0\n            )\n\n            if self.do_ccp:\n                ncomp_str = LabelStatsKeys.LABEL_NCOMP\n                ncomp_fixed_keys = [self.stats_name, LabelStatsKeys.LABEL, label_id, ncomp_str]\n                ncomp_np = concat_val_to_np(data, ncomp_fixed_keys, allow_missing=True)\n                stats[ncomp_str] = self.ops[label_str][0][ncomp_str].evaluate(\n                    ncomp_np, dim=(0, 1) if ncomp_np.ndim > 2 and self.summary_average else 0\n                )\n\n                shape_str = LabelStatsKeys.LABEL_SHAPE\n                shape_fixed_keys = [self.stats_name, label_str, label_id, LabelStatsKeys.LABEL_SHAPE]\n                shape_np = concat_val_to_np(data, shape_fixed_keys, ragged=True, allow_missing=True)\n                stats[shape_str] = self.ops[label_str][0][shape_str].evaluate(\n                    shape_np, dim=(0, 1) if shape_np.ndim > 2 and self.summary_average else 0\n                )\n                # label shape is a 3-element value, but the number of labels in each image\n                # can vary from 0 to N. 
So the value in a list format is \"ragged\"\n\n            intst_str = LabelStatsKeys.IMAGE_INTST\n            intst_fixed_keys = [self.stats_name, label_str, label_id, intst_str]\n            op_keys = report[label_str][0][intst_str].keys()\n            intst_dict = concat_multikeys_to_dict(data, intst_fixed_keys, op_keys, allow_missing=True)\n            stats[intst_str] = self.ops[label_str][0][intst_str].evaluate(\n                intst_dict, dim=None if self.summary_average else 0\n            )\n\n            detailed_label_list.append(stats)\n\n        report[LabelStatsKeys.LABEL] = detailed_label_list\n\n        if not verify_report_format(report, self.get_report_format()):\n            raise RuntimeError(f\"report generated by {self.__class__} differs from the report format.\")\n\n        return report\n\n\nclass FilenameStats(Analyzer):\n    \"\"\"\n    This class finds the file path for the loaded image/label and writes the info\n    into the data pipeline as a monai transforms.\n\n    Args:\n        key: the key to fetch the filename (for example, \"image\", \"label\").\n        stats_name: the key to store the filename in the output stats report.\n\n    \"\"\"\n\n    def __init__(self, key: str | None, stats_name: str) -> None:\n        self.key = key\n        super().__init__(stats_name, {})\n\n    def __call__(self, data):\n        d = dict(data)\n\n        if self.key:  # when there is no (label) file, key can be None\n            if self.key not in d:  # check whether image/label is in the data\n                raise ValueError(f\"Data with key {self.key} is missing.\")\n            if not isinstance(d[self.key], MetaTensor):\n                raise ValueError(f\"Value type of {self.key} is not MetaTensor.\")\n            if ImageMetaKey.FILENAME_OR_OBJ not in d[self.key].meta:\n                raise ValueError(f\"{ImageMetaKey.FILENAME_OR_OBJ} not found in MetaTensor {d[self.key]}.\")\n            d[self.stats_name] = 
d[self.key].meta[ImageMetaKey.FILENAME_OR_OBJ]\n        else:\n            d[self.stats_name] = \"None\"\n\n        return d\n\n\nclass ImageHistogram(Analyzer):\n    \"\"\"\n    Analyzer to compute intensity histogram.\n\n    Args:\n        image_key: the key to find image data in the callable function input (data)\n        hist_bins: list of positive integers (one for each channel) for setting the number of bins used to\n            compute the histogram. Defaults to [100].\n        hist_range: list of lists of two floats (one for each channel) setting the intensity range to\n            compute the histogram. Defaults to [-500, 500].\n\n    Examples:\n\n    .. code-block:: python\n\n        import numpy as np\n        from monai.auto3dseg.analyzer import ImageHistogram\n\n        input = {}\n        input['image'] = np.random.rand(1,30,30,30)\n        input['label'] = np.ones([30,30,30])\n        analyzer = ImageHistogram(image_key='image')\n        print(analyzer(input))\n\n    \"\"\"\n\n    def __init__(\n        self,\n        image_key: str,\n        stats_name: str = DataStatsKeys.IMAGE_HISTOGRAM,\n        hist_bins: list[int] | int | None = None,\n        hist_range: list | None = None,\n    ):\n        self.image_key = image_key\n\n        # set defaults\n        self.hist_bins: list[int] = (\n            [100] if hist_bins is None else hist_bins if isinstance(hist_bins, list) else [hist_bins]\n        )\n        self.hist_range: list = [-500, 500] if hist_range is None else hist_range\n\n        report_format = {\"counts\": None, \"bin_edges\": None}\n\n        super().__init__(stats_name, report_format)\n        self.update_ops(ImageStatsKeys.HISTOGRAM, SampleOperations())\n\n        # check histogram configurations for each channel in list\n        if not all(isinstance(hr, list) for hr in self.hist_range):\n            self.hist_range = [self.hist_range]\n        if len(self.hist_bins) != len(self.hist_range):\n            raise ValueError(\n          
      f\"Number of histogram bins ({len(self.hist_bins)}) and \"\n                f\"histogram ranges ({len(self.hist_range)}) need to be the same!\"\n            )\n        for i, hist_params in enumerate(zip(self.hist_bins, self.hist_range)):\n            _hist_bins, _hist_range = hist_params\n            if not isinstance(_hist_bins, int) or _hist_bins < 0:\n                raise ValueError(f\"Expected {i + 1}. hist_bins value to be positive integer but got {_hist_bins}\")\n            if not isinstance(_hist_range, list) or len(_hist_range) != 2:\n                raise ValueError(\n                    f\"Expected {i + 1}. hist_range values to be list of length 2 but received {_hist_range}\"\n                )\n\n    def __call__(self, data: dict) -> dict:\n        \"\"\"\n        Callable to execute the pre-defined functions\n\n        Returns:\n            A dictionary. The dict has the key in self.report_format and value\n\n        Raises:\n            RuntimeError if the stats report generated is not consistent with the pre-\n                defined report_format.\n\n        Note:\n            The stats operation uses numpy and torch to compute max, min, and other\n            functions. 
If the input has nan/inf, the stats results will be nan/inf.\n        \"\"\"\n\n        d = dict(data)\n\n        ndas = convert_to_numpy(d[self.image_key], wrap_sequence=True)  # (1,H,W,D) or (C,H,W,D)\n        nr_channels = np.shape(ndas)[0]\n\n        # adjust histogram params to match channels\n        if len(self.hist_bins) == 1:\n            self.hist_bins = nr_channels * self.hist_bins\n        if len(self.hist_bins) != nr_channels:\n            raise ValueError(\n                f\"There is a mismatch between the number of channels ({nr_channels}) \"\n                f\"and number histogram bins ({len(self.hist_bins)}).\"\n            )\n        if len(self.hist_range) == 1:\n            self.hist_range = nr_channels * self.hist_range\n        if len(self.hist_range) != nr_channels:\n            raise ValueError(\n                f\"There is a mismatch between the number of channels ({nr_channels}) \"\n                f\"and histogram ranges ({len(self.hist_range)}).\"\n            )\n\n        # perform calculation\n        reports = []\n        for channel in range(nr_channels):\n            counts, bin_edges = np.histogram(\n                ndas[channel, ...],\n                bins=self.hist_bins[channel],\n                range=(self.hist_range[channel][0], self.hist_range[channel][1]),\n            )\n            _report = {\"counts\": counts.tolist(), \"bin_edges\": bin_edges.tolist()}\n            if not verify_report_format(_report, self.get_report_format()):\n                raise RuntimeError(f\"report generated by {self.__class__} differs from the report format.\")\n            reports.append(_report)\n\n        d[self.stats_name] = reports\n        return d\n\n\nclass ImageHistogramSumm(Analyzer):\n    \"\"\"\n    This summary analyzer processes the values of specific key `stats_name` in a list of dict.\n    Typically, the list of dict is the output of case analyzer under the same prefix\n    (ImageHistogram).\n\n    Args:\n        stats_name: 
the key of the to-process value in the dict.\n        average: whether to average the statistical value across different image modalities.\n\n    \"\"\"\n\n    def __init__(self, stats_name: str = DataStatsKeys.IMAGE_HISTOGRAM, average: bool | None = True):\n        self.summary_average = average\n        report_format = {ImageStatsKeys.HISTOGRAM: None}\n        super().__init__(stats_name, report_format)\n\n        self.update_ops(ImageStatsKeys.HISTOGRAM, SummaryOperations())\n\n    def __call__(self, data: list[dict]) -> dict:\n        \"\"\"\n        Callable to execute the pre-defined functions\n\n        Returns:\n            A dictionary. The dict has the key in self.report_format and value\n            in a list format. Each element of the value list holds the summed\n            histogram \"counts\" and the \"bin_edges\" of one image channel.\n\n        Raises:\n            RuntimeError if the stats report generated is not consistent with the pre-\n                defined report_format.\n\n        Examples:\n            output dict contains the following key{\n                ImageStatsKeys.HISTOGRAM: [{\"counts\": [...], \"bin_edges\": [...]}, ...],\n                }\n\n        Notes:\n            The stats operation uses numpy and torch to compute max, min, and other\n            functions. 
If the input has nan/inf, the stats results will be nan/inf.\n        \"\"\"\n        if not isinstance(data, list):\n            raise ValueError(f\"Callable {self.__class__} requires list inputs\")\n\n        if len(data) == 0:\n            raise ValueError(f\"Callable {self.__class__} input list is empty\")\n\n        if self.stats_name not in data[0]:\n            raise KeyError(f\"{self.stats_name} is not in input data\")\n\n        summ_histogram: dict = {}\n\n        for d in data:\n            if not summ_histogram:\n                summ_histogram = d[DataStatsKeys.IMAGE_HISTOGRAM]\n                # convert to numpy for computing total histogram\n                for k in range(len(summ_histogram)):\n                    summ_histogram[k][\"counts\"] = np.array(summ_histogram[k][\"counts\"])\n            else:\n                for k in range(len(summ_histogram)):\n                    summ_histogram[k][\"counts\"] += np.array(d[DataStatsKeys.IMAGE_HISTOGRAM][k][\"counts\"])\n                    if np.all(summ_histogram[k][\"bin_edges\"] != d[DataStatsKeys.IMAGE_HISTOGRAM][k][\"bin_edges\"]):\n                        raise ValueError(\n                            f\"bin edges are not consistent! {summ_histogram[k]['bin_edges']} vs. \"\n                            f\"{d[DataStatsKeys.IMAGE_HISTOGRAM][k]['bin_edges']}\"\n                        )\n\n        # convert back to list\n        for k in range(len(summ_histogram)):\n            summ_histogram[k][\"counts\"] = summ_histogram[k][\"counts\"].tolist()\n\n        report = {ImageStatsKeys.HISTOGRAM: summ_histogram}\n        if not verify_report_format(report, self.get_report_format()):\n            raise RuntimeError(f\"report generated by {self.__class__} differs from the report format.\")\n\n        return report\n"
  },
  {
    "path": "monai/auto3dseg/operations.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections import UserDict\nfrom functools import partial\nfrom typing import Any\n\nfrom monai.transforms.utils_pytorch_numpy_unification import max, mean, median, min, percentile, std\n\n__all__ = [\"Operations\", \"SampleOperations\", \"SummaryOperations\"]\n\n\nclass Operations(UserDict):\n    \"\"\"\n    Base class of operation interface\n    \"\"\"\n\n    def evaluate(self, data: Any, **kwargs: Any) -> dict:\n        \"\"\"\n        For key-value pairs in the self.data, if the value is a callable,\n        then this function will apply the callable to the input data.\n        The result will be written under the same key under the output dict.\n\n        Args:\n            data: input data.\n\n        Returns:\n            a dictionary which has same keys as the self.data if the value\n                is callable.\n        \"\"\"\n        return {k: v(data, **kwargs) for k, v in self.data.items() if callable(v)}\n\n\nclass SampleOperations(Operations):\n    \"\"\"\n    Apply statistical operation to a sample (image/ndarray/tensor).\n\n    Notes:\n        Percentile operation uses a partial function that embeds different kwargs (q).\n        In order to print the result nicely, data_addon is added to map the numbers\n        generated by percentile to different keys (\"percentile_00_5\" for example).\n        Annotation of 
the postfix means the percentage for percentile computation.\n        For example, _00_5 means 0.5% and _99_5 means 99.5%.\n\n    Example:\n\n        .. code-block:: python\n\n            # use the existing operations\n            import numpy as np\n            op = SampleOperations()\n            data_np = np.random.rand(10, 10).astype(np.float64)\n            print(op.evaluate(data_np))\n\n            # add a new operation\n            op.update({\"sum\": np.sum})\n            print(op.evaluate(data_np))\n    \"\"\"\n\n    def __init__(self) -> None:\n        self.data = {\n            \"max\": max,\n            \"mean\": mean,\n            \"median\": median,\n            \"min\": min,\n            \"stdev\": std,\n            \"percentile\": partial(percentile, q=[0.5, 10, 90, 99.5]),\n        }\n        self.data_addon = {\n            \"percentile_00_5\": (\"percentile\", 0),\n            \"percentile_10_0\": (\"percentile\", 1),\n            \"percentile_90_0\": (\"percentile\", 2),\n            \"percentile_99_5\": (\"percentile\", 3),\n        }\n\n    def evaluate(self, data: Any, **kwargs: Any) -> dict:\n        \"\"\"\n        Applies the callables to the data, and convert the\n        numerics to list or Python numeric types (int/float).\n\n        Args:\n            data: input data\n        \"\"\"\n        ret = super().evaluate(data, **kwargs)\n        for k, v in self.data_addon.items():\n            cache = v[0]\n            idx = v[1]\n            if isinstance(v, tuple) and cache in ret:\n                ret.update({k: ret[cache][idx]})\n\n        for k, v in ret.items():\n            ret[k] = v.tolist()  # type: ignore\n        return ret\n\n\nclass SummaryOperations(Operations):\n    \"\"\"\n    Apply statistical operation to summarize a dict. The key-value looks like: {\"max\", \"min\"\n    ,\"mean\", ....}. The value may contain multiple values in a list format. Then this operation\n    will apply the operation to the list. 
Typically, the dict is generated by multiple\n    `SampleOperations` and `concat_multikeys_to_dict` functions.\n\n    Examples:\n\n        .. code-block:: python\n\n            import numpy as np\n            data = {\n                \"min\": np.random.rand(4),\n                \"max\": np.random.rand(4),\n                \"mean\": np.random.rand(4),\n                \"sum\": np.random.rand(4),\n            }\n            op = SummaryOperations()\n            print(op.evaluate(data)) # \"sum\" is not registered yet, so it won't contain \"sum\"\n\n            op.update({\"sum\": np.sum})\n            print(op.evaluate(data)) # output has \"sum\"\n    \"\"\"\n\n    def __init__(self) -> None:\n        self.data = {\n            \"max\": max,\n            \"mean\": mean,\n            \"median\": mean,\n            \"min\": min,\n            \"stdev\": mean,\n            \"percentile_00_5\": mean,\n            \"percentile_10_0\": mean,\n            \"percentile_90_0\": mean,\n            \"percentile_99_5\": mean,\n        }\n\n    def evaluate(self, data: Any, **kwargs: Any) -> dict:\n        \"\"\"\n        Applies the callables to the data, and convert the numerics to list or Python\n        numeric types (int/float).\n\n        Args:\n            data: input data\n        \"\"\"\n        return {k: v(data[k], **kwargs).tolist() for k, v in self.data.items() if (callable(v) and k in data)}\n"
  },
  {
    "path": "monai/auto3dseg/seg_summarizer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nfrom monai.auto3dseg.analyzer import (\n    Analyzer,\n    FgImageStats,\n    FgImageStatsSumm,\n    FilenameStats,\n    ImageHistogram,\n    ImageHistogramSumm,\n    ImageStats,\n    ImageStatsSumm,\n    LabelStats,\n    LabelStatsSumm,\n)\nfrom monai.transforms import Compose\nfrom monai.utils.enums import DataStatsKeys\n\n__all__ = [\"SegSummarizer\"]\n\n\nclass SegSummarizer(Compose):\n    \"\"\"\n    SegSummarizer serializes the operations for data analysis in Auto3Dseg pipeline. It loads\n    two types of analyzer functions and execute differently. The first type of analyzer is\n    CaseAnalyzer which is similar to traditional monai transforms. It can be composed with other\n    transforms to process the data dict which has image/label keys. The second type of analyzer\n    is SummaryAnalyzer which works only on a list of dictionary. Each dictionary is the output\n    of the case analyzers on a single dataset.\n\n    Args:\n        image_key: a string that user specify for the image. The DataAnalyzer will look it up in the\n            datalist to locate the image files of the dataset.\n        label_key: a string that user specify for the label. The DataAnalyzer will look it up in the\n            datalist to locate the label files of the dataset. 
If label_key is None, the DataAnalyzer\n            will skip looking for labels and all label-related operations.\n        do_ccp: apply the connected component algorithm to process the labels/images.\n        hist_bins: list of positive integers (one for each channel) for setting the number of bins used to\n            compute the histogram. Defaults to [100].\n        hist_range: list of lists of two floats (one for each channel) setting the intensity range to\n            compute the histogram. Defaults to [-500, 500].\n        histogram_only: whether to only compute histograms. Defaults to False.\n\n    Examples:\n        .. code-block:: python\n\n            # imports\n\n            summarizer = SegSummarizer(\"image\", \"label\")\n            transform_list = [\n                LoadImaged(keys=keys),\n                EnsureChannelFirstd(keys=keys),  # this creates label to be (1,H,W,D)\n                ToDeviced(keys=keys, device=device, non_blocking=True),\n                Orientationd(keys=keys, axcodes=\"RAS\"),\n                EnsureTyped(keys=keys, data_type=\"tensor\"),\n                Lambdad(keys=\"label\", func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),\n                SqueezeDimd(keys=[\"label\"], dim=0),\n                summarizer,\n            ]\n            ...\n            # skip some steps to set up data loader\n            dataset = data.DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)\n            transform = Compose(transform_list)\n            stats = []\n            for batch_data in dataset:\n                d = transform(batch_data[0])\n                stats.append(d)\n            report = summarizer.summarize(stats)\n    \"\"\"\n\n    def __init__(\n        self,\n        image_key: str,\n        label_key: str | None,\n        average: bool = True,\n        do_ccp: bool = True,\n        hist_bins: list[int] | int | None = None,\n        hist_range: list | 
None = None,\n        histogram_only: bool = False,\n    ) -> None:\n        self.image_key = image_key\n        self.label_key = label_key\n        # set defaults\n        self.hist_bins: list[int] | int = [100] if hist_bins is None else hist_bins\n        self.hist_range: list = [-500, 500] if hist_range is None else hist_range\n        self.histogram_only = histogram_only\n\n        self.summary_analyzers: list[Any] = []\n        super().__init__()\n\n        self.add_analyzer(FilenameStats(image_key, DataStatsKeys.BY_CASE_IMAGE_PATH), None)\n        self.add_analyzer(FilenameStats(label_key, DataStatsKeys.BY_CASE_LABEL_PATH), None)\n        if not self.histogram_only:\n            self.add_analyzer(ImageStats(image_key), ImageStatsSumm(average=average))\n\n            if label_key is None:\n                return\n\n            self.add_analyzer(FgImageStats(image_key, label_key), FgImageStatsSumm(average=average))\n\n            self.add_analyzer(\n                LabelStats(image_key, label_key, do_ccp=do_ccp), LabelStatsSumm(average=average, do_ccp=do_ccp)\n            )\n\n        # compute histograms\n        if self.hist_bins != 0:\n            self.add_analyzer(\n                ImageHistogram(image_key=image_key, hist_bins=hist_bins, hist_range=hist_range), ImageHistogramSumm()\n            )\n\n    def add_analyzer(self, case_analyzer: Analyzer, summary_analyzer: Analyzer | None) -> None:\n        \"\"\"\n        Add new analyzers to the engine so that the callable and summarize functions will\n        utilize the new analyzers for stats computations.\n\n        Args:\n            case_analyzer: analyzer that works on each data.\n            summary_analyzer: analyzer that works on list of stats dict (output from case_analyzers).\n\n        Examples:\n\n            .. 
code-block:: python\n\n                from monai.auto3dseg import Analyzer\n                from monai.auto3dseg.utils import concat_val_to_np\n                from monai.auto3dseg.analyzer_engine import SegSummarizer\n\n                class UserAnalyzer(Analyzer):\n                    def __init__(self, image_key=\"image\", stats_name=\"user_stats\"):\n                        self.image_key = image_key\n                        report_format = {\"ndims\": None}\n                        super().__init__(stats_name, report_format)\n\n                    def __call__(self, data):\n                        d = dict(data)\n                        report = deepcopy(self.get_report_format())\n                        report[\"ndims\"] = d[self.image_key].ndim\n                        d[self.stats_name] = report\n                        return d\n\n                class UserSummaryAnalyzer(Analyzer):\n                    def __init__(stats_name=\"user_stats\"):\n                        report_format = {\"ndims\": None}\n                        super().__init__(stats_name, report_format)\n                        self.update_ops(\"ndims\", SampleOperations())\n\n                    def __call__(self, data):\n                        report = deepcopy(self.get_report_format())\n                        v_np = concat_val_to_np(data, [self.stats_name, \"ndims\"])\n                        report[\"ndims\"] = self.ops[\"ndims\"].evaluate(v_np)\n                        return report\n\n                summarizer = SegSummarizer()\n                summarizer.add_analyzer(UserAnalyzer, UserSummaryAnalyzer)\n\n        \"\"\"\n        self.transforms += (case_analyzer,)\n        if summary_analyzer is not None:\n            self.summary_analyzers.append(summary_analyzer)\n\n    def summarize(self, data: list[dict]) -> dict[str, dict]:\n        \"\"\"\n        Summarize the input list of data and generates a report ready for json/yaml export.\n\n        Args:\n            data: a list of 
data dicts.\n\n        Returns:\n            a dict that summarizes the stats across data samples.\n\n        Examples:\n            stats_summary:\n                image_foreground_stats:\n                    intensity: {...}\n                image_stats:\n                    channels: {...}\n                    cropped_shape: {...}\n                    ...\n                label_stats:\n                    image_intensity: {...}\n                    label:\n                    - image_intensity: {...}\n                    - image_intensity: {...}\n                    - image_intensity: {...}\n                    - image_intensity: {...}\n        \"\"\"\n        if not isinstance(data, list):\n            raise ValueError(f\"{self.__class__} summarize function needs input to be a list of dict\")\n\n        report: dict[str, dict] = {}\n        if len(data) == 0:\n            return report\n\n        if not isinstance(data[0], dict):\n            raise ValueError(f\"{self.__class__} summarize function needs a list of dict. Now we have {type(data[0])}\")\n\n        for analyzer in self.summary_analyzers:\n            if callable(analyzer):\n                report.update({analyzer.stats_name: analyzer(data)})\n\n        return report\n"
  },
  {
    "path": "monai/auto3dseg/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport pickle\nimport subprocess\nimport sys\nfrom copy import deepcopy\nfrom numbers import Number\nfrom typing import Any, cast\n\nimport numpy as np\nimport torch\n\nfrom monai.auto3dseg import Algo\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.bundle.utils import ID_SEP_KEY\nfrom monai.config import PathLike\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import CropForeground, ToCupy\nfrom monai.utils import min_version, optional_import, run_cmd\n\n__all__ = [\n    \"get_foreground_image\",\n    \"get_foreground_label\",\n    \"get_label_ccp\",\n    \"concat_val_to_np\",\n    \"concat_multikeys_to_dict\",\n    \"datafold_read\",\n    \"verify_report_format\",\n    \"algo_to_pickle\",\n    \"algo_from_pickle\",\n]\n\nmeasure_np, has_measure = optional_import(\"skimage.measure\", \"0.14.2\", min_version)\ncp, has_cp = optional_import(\"cupy\")\n\n\ndef get_foreground_image(image: MetaTensor) -> np.ndarray:\n    \"\"\"\n    Get a foreground image by removing all-zero rectangles on the edges of the image\n    Note for the developer: update select_fn if the foreground is defined differently.\n\n    Args:\n        image: ndarray image to segment.\n\n    Returns:\n        ndarray of foreground image by removing all-zero edges.\n\n    Notes:\n        the size of the 
output is smaller than the input.\n    \"\"\"\n\n    copper = CropForeground(select_fn=lambda x: x > 0, allow_smaller=True)\n    image_foreground = copper(image)\n    return cast(np.ndarray, image_foreground)\n\n\ndef get_foreground_label(image: MetaTensor, label: MetaTensor) -> MetaTensor:\n    \"\"\"\n    Get foreground image pixel values and mask out the non-labeled area.\n\n    Args\n        image: ndarray image to segment.\n        label: ndarray the image input and annotated with class IDs.\n\n    Returns:\n        1D array of foreground image with label > 0\n    \"\"\"\n\n    label_foreground = MetaTensor(image[label > 0])\n    return label_foreground\n\n\ndef get_label_ccp(mask_index: MetaTensor, use_gpu: bool = True) -> tuple[list[Any], int]:\n    \"\"\"\n    Find all connected components and their bounding shape. Backend can be cuPy/cuCIM or Numpy\n    depending on the hardware.\n\n    Args:\n        mask_index: a binary mask.\n        use_gpu: a switch to use GPU/CUDA or not. 
If GPU is unavailable, CPU will be used\n            regardless of this setting.\n\n    \"\"\"\n    skimage, has_cucim = optional_import(\"cucim.skimage\")\n    shape_list = []\n    if mask_index.device.type == \"cuda\" and has_cp and has_cucim and use_gpu:\n        mask_cupy = ToCupy()(mask_index.short())\n        labeled = skimage.measure.label(mask_cupy)\n        vals = cp.unique(labeled[cp.nonzero(labeled)])\n\n        for ncomp in vals:\n            comp_idx = cp.argwhere(labeled == ncomp)\n            comp_idx_min = cp.min(comp_idx, axis=0).tolist()\n            comp_idx_max = cp.max(comp_idx, axis=0).tolist()\n            bbox_shape = [comp_idx_max[i] - comp_idx_min[i] + 1 for i in range(len(comp_idx_max))]\n            shape_list.append(bbox_shape)\n        ncomponents = len(vals)\n\n        del mask_cupy, labeled, vals, comp_idx, ncomp\n        cp.get_default_memory_pool().free_all_blocks()\n\n    elif has_measure:\n        labeled, ncomponents = measure_np.label(mask_index.data.cpu().numpy(), background=-1, return_num=True)\n        for ncomp in range(1, ncomponents + 1):\n            comp_idx = np.argwhere(labeled == ncomp)\n            comp_idx_min = np.min(comp_idx, axis=0).tolist()\n            comp_idx_max = np.max(comp_idx, axis=0).tolist()\n            bbox_shape = [comp_idx_max[i] - comp_idx_min[i] + 1 for i in range(len(comp_idx_max))]\n            shape_list.append(bbox_shape)\n    else:\n        raise RuntimeError(\"Cannot find one of the following required dependencies: {cuPy+cuCIM} or {scikit-image}\")\n\n    return shape_list, ncomponents\n\n\ndef concat_val_to_np(\n    data_list: list[dict],\n    fixed_keys: list[str | int],\n    ragged: bool | None = False,\n    allow_missing: bool | None = False,\n    **kwargs: Any,\n) -> np.ndarray:\n    \"\"\"\n    Get the nested value in a list of dictionary that shares the same structure.\n\n    Args:\n       data_list: a list of dictionary {key1: {key2: np.ndarray}}.\n       fixed_keys: a list of 
keys that records to path to the value in the dict elements.\n       ragged: if True, numbers can be in list of lists or ragged format so concat mode needs change.\n       allow_missing: if True, it will return a None if the value cannot be found.\n\n    Returns:\n        nd.array of concatenated array.\n\n    \"\"\"\n\n    np_list: list[np.ndarray | None] = []\n    for data in data_list:\n        parser = ConfigParser(data)\n        for i, key in enumerate(fixed_keys):\n            fixed_keys[i] = str(key)\n\n        val: Any\n        val = parser.get(ID_SEP_KEY.join(fixed_keys))  # type: ignore\n\n        if val is None:\n            if allow_missing:\n                np_list.append(None)\n            else:\n                raise AttributeError(f\"{fixed_keys} is not nested in the dictionary\")\n        elif isinstance(val, list):\n            np_list.append(np.array(val))\n        elif isinstance(val, (torch.Tensor, MetaTensor)):\n            np_list.append(val.cpu().numpy())\n        elif isinstance(val, np.ndarray):\n            np_list.append(val)\n        elif isinstance(val, Number):\n            np_list.append(np.array(val))\n        else:\n            raise NotImplementedError(f\"{val.__class__} concat is not supported.\")\n\n    if allow_missing:\n        np_list = [x for x in np_list if x is not None]\n\n    if len(np_list) == 0:\n        return np.array([0])\n    elif ragged:\n        return np.concatenate(np_list, **kwargs)  # type: ignore\n    else:\n        return np.concatenate([np_list], **kwargs)\n\n\ndef concat_multikeys_to_dict(\n    data_list: list[dict], fixed_keys: list[str | int], keys: list[str], zero_insert: bool = True, **kwargs: Any\n) -> dict[str, np.ndarray]:\n    \"\"\"\n    Get the nested value in a list of dictionary that shares the same structure iteratively on all keys.\n    It returns a dictionary with keys with the found values in nd.ndarray.\n\n    Args:\n        data_list: a list of dictionary {key1: {key2: np.ndarray}}.\n    
    fixed_keys: a list of keys that records to path to the value in the dict elements.\n        keys: a list of string keys that will be iterated to generate a dict output.\n        zero_insert: insert a zero in the list so that it can find the value in element 0 before getting the keys\n        flatten: if True, numbers are flattened before concat.\n\n    Returns:\n        a dict with keys - nd.array of concatenated array pair.\n    \"\"\"\n\n    ret_dict = {}\n    for key in keys:\n        addon: list[str | int] = [0, key] if zero_insert else [key]\n        val = concat_val_to_np(data_list, fixed_keys + addon, **kwargs)\n        ret_dict.update({key: val})\n\n    return ret_dict\n\n\ndef datafold_read(datalist: str | dict, basedir: str, fold: int = 0, key: str = \"training\") -> tuple[list, list]:\n    \"\"\"\n    Read a list of data dictionary `datalist`\n\n    Args:\n        datalist: the name of a JSON file listing the data, or a dictionary.\n        basedir: directory of image files.\n        fold: which fold to use (0..1 if in training set).\n        key: usually 'training' , but can try 'validation' or 'testing' to get the list data without labels (used in challenges).\n\n    Returns:\n        A tuple of two arrays (training, validation).\n    \"\"\"\n\n    if isinstance(datalist, str):\n        json_data = ConfigParser.load_config_file(datalist)\n    else:\n        json_data = datalist\n\n    dict_data = deepcopy(json_data[key])\n\n    for d in dict_data:\n        for k, _ in d.items():\n            if isinstance(d[k], list):\n                d[k] = [os.path.join(basedir, iv) for iv in d[k]]\n            elif isinstance(d[k], str):\n                d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]\n\n    tr = []\n    val = []\n    for d in dict_data:\n        if \"fold\" in d and d[\"fold\"] == fold:\n            val.append(d)\n        else:\n            tr.append(d)\n\n    return tr, val\n\n\ndef verify_report_format(report: dict, 
report_format: dict) -> bool:\n    \"\"\"\n    Compares the report and the report_format that has only keys.\n\n    Args:\n        report: dict that has real values.\n        report_format: dict that only has keys and list-nested value.\n    \"\"\"\n    for k_fmt, v_fmt in report_format.items():\n        if k_fmt not in report:\n            return False\n\n        v = report[k_fmt]\n\n        if isinstance(v_fmt, list) and isinstance(v, list):\n            if len(v_fmt) != 1:\n                raise UserWarning(\"list length in report_format is not 1\")\n            if len(v_fmt) > 0 and len(v) > 0:\n                return verify_report_format(v[0], v_fmt[0])\n            else:\n                return False\n\n    return True\n\n\ndef algo_to_pickle(algo: Algo, template_path: PathLike | None = None, **algo_meta_data: Any) -> str:\n    \"\"\"\n    Export the Algo object to pickle file.\n\n    Args:\n        algo: Algo-like object.\n        template_path: a str path that is needed to be added to the sys.path to instantiate the class.\n        algo_meta_data: additional keyword to save into the dictionary, for example, model training info\n            such as acc/best_metrics\n\n    Returns:\n        filename of the pickled Algo object\n    \"\"\"\n    data = {\"algo_bytes\": pickle.dumps(algo), \"template_path\": str(template_path)}\n    pkl_filename = os.path.join(algo.get_output_path(), \"algo_object.pkl\")\n    for k, v in algo_meta_data.items():\n        data.update({k: v})\n    data_bytes = pickle.dumps(data)\n    with open(pkl_filename, \"wb\") as f_pi:\n        f_pi.write(data_bytes)\n    return pkl_filename\n\n\ndef algo_from_pickle(pkl_filename: str, template_path: PathLike | None = None, **kwargs: Any) -> Any:\n    \"\"\"\n    Import the Algo object from a pickle file.\n\n    Args:\n        pkl_filename: the name of the pickle file.\n        template_path: a folder containing files to instantiate the Algo. 
Besides the `template_path`,\n        this function will also attempt to use the `template_path` saved in the pickle file and a directory\n        named `algorithm_templates` in the parent folder of the folder containing the pickle file.\n\n    Returns:\n        algo: the Algo object saved in the pickle file.\n        algo_meta_data: additional keyword saved in the pickle file, for example, acc/best_metrics.\n\n    Raises:\n        ValueError if the pkl_filename does not contain a dict, or the dict does not contain `algo_bytes`.\n        ModuleNotFoundError if it is unable to instantiate the Algo class.\n\n    \"\"\"\n    with open(pkl_filename, \"rb\") as f_pi:\n        data_bytes = f_pi.read()\n    data = pickle.loads(data_bytes)\n\n    if not isinstance(data, dict):\n        raise ValueError(f\"the data object is {data.__class__}. Dict is expected.\")\n\n    if \"algo_bytes\" not in data:\n        raise ValueError(f\"key [algo_bytes] not found in {data}. Unable to instantiate.\")\n\n    algo_bytes = data.pop(\"algo_bytes\")\n    algo_template_path = data.pop(\"template_path\", None)\n\n    template_paths_candidates: list[str] = []\n\n    if os.path.isdir(str(template_path)):\n        template_paths_candidates.append(os.path.abspath(str(template_path)))\n        template_paths_candidates.append(os.path.abspath(os.path.join(str(template_path), \"..\")))\n\n    if os.path.isdir(str(algo_template_path)):\n        template_paths_candidates.append(os.path.abspath(algo_template_path))\n        template_paths_candidates.append(os.path.abspath(os.path.join(algo_template_path, \"..\")))\n\n    pkl_dir = os.path.dirname(pkl_filename)\n    algo_template_path_fuzzy = os.path.join(pkl_dir, \"..\", \"algorithm_templates\")\n\n    if os.path.isdir(algo_template_path_fuzzy):\n        template_paths_candidates.append(os.path.abspath(algo_template_path_fuzzy))\n\n    if len(template_paths_candidates) == 0:\n        # no template_path provided or needed\n        algo = 
pickle.loads(algo_bytes)\n        algo.template_path = None\n    else:\n        for i, p in enumerate(template_paths_candidates):\n            try:\n                sys.path.append(p)\n                algo = pickle.loads(algo_bytes)\n                break\n            except ModuleNotFoundError as not_found_err:\n                logging.debug(f\"Folder {p} doesn't contain the Algo templates for Algo instantiation.\")\n                sys.path.pop()\n                if i == len(template_paths_candidates) - 1:\n                    raise ValueError(\n                        f\"Failed to instantiate {pkl_filename} with {template_paths_candidates}\"\n                    ) from not_found_err\n        algo.template_path = p\n\n    if os.path.abspath(pkl_dir) != os.path.abspath(algo.get_output_path()):\n        logging.debug(f\"{algo.get_output_path()} is changed. Now override the Algo output_path with: {pkl_dir}.\")\n        algo.output_path = pkl_dir\n\n    algo_meta_data = {}\n    for k, v in data.items():\n        algo_meta_data.update({k: v})\n\n    return algo, algo_meta_data\n\n\ndef list_to_python_fire_arg_str(args: list) -> str:\n    \"\"\"\n    Convert a list of arguments to a string that can be used in python-fire.\n\n    Args:\n        args: the list of arguments.\n\n    Returns:\n        the string that can be used in python-fire.\n    \"\"\"\n    args_str = \",\".join([str(arg) for arg in args])\n    return f\"'{args_str}'\"\n\n\ndef check_and_set_optional_args(params: dict) -> str:\n    \"\"\"convert `params` into '--key_1=value_1 --key_2=value_2 ...'\"\"\"\n    cmd_mod_opt = \"\"\n    for k, v in params.items():\n        if isinstance(v, dict):\n            raise ValueError(\"Nested dict is not supported.\")\n        elif isinstance(v, list):\n            v = list_to_python_fire_arg_str(v)\n        cmd_mod_opt += f\" --{k}={v}\"\n    return cmd_mod_opt\n\n\ndef _prepare_cmd_default(cmd: str, cmd_prefix: str | None = None, **kwargs: Any) -> str:\n    
\"\"\"\n    Prepare the command for subprocess to run the script with the given arguments.\n\n    Args:\n        cmd: the command or script to run in the distributed job.\n        cmd_prefix: the command prefix to run the script, e.g., \"python\", \"python -m\", \"python3\", \"/opt/conda/bin/python3.9 \".\n        kwargs: the keyword arguments to be passed to the script.\n\n    Returns:\n        the command to run with ``subprocess``.\n\n    Examples:\n        To prepare a subprocess command\n        \"python train.py run -k --config 'a,b'\", the function can be called as\n        - _prepare_cmd_default(\"train.py run -k\", config=['a','b'])\n        - _prepare_cmd_default(\"train.py run -k --config 'a,b'\")\n\n    \"\"\"\n    params = kwargs.copy()\n\n    if not cmd_prefix or \"None\" in cmd_prefix:  # defaulting to 'python'\n        cmd_prefix = \"python\"\n\n    if not cmd_prefix.endswith(\" \"):\n        cmd_prefix += \" \"  # ensure a space after the command prefix so that the script can be appended\n\n    return cmd_prefix + cmd + check_and_set_optional_args(params)\n\n\ndef _prepare_cmd_torchrun(cmd: str, **kwargs: Any) -> str:\n    \"\"\"\n    Prepare the command for multi-gpu/multi-node job execution using torchrun.\n\n    Args:\n        cmd: the command or script to run in the distributed job.\n        kwargs: the keyword arguments to be passed to the script.\n\n    Returns:\n        the command to append to ``torchrun``\n\n    Examples:\n        For command \"torchrun --nnodes=1 --nproc_per_node=8 train.py run -k --config 'a,b'\",\n        it only prepares command after the torchrun arguments, i.e., \"train.py run -k --config 'a,b'\".\n        The function can be called as\n        - _prepare_cmd_torchrun(\"train.py run -k\", config=['a','b'])\n        - _prepare_cmd_torchrun(\"train.py run -k --config 'a,b'\")\n    \"\"\"\n    params = kwargs.copy()\n    return cmd + check_and_set_optional_args(params)\n\n\ndef _prepare_cmd_bcprun(cmd: str, cmd_prefix: 
str | None = None, **kwargs: Any) -> str:\n    \"\"\"\n    Prepare the command for distributed job running using bcprun.\n\n    Args:\n        script: the script to run in the distributed job.\n        cmd_prefix: the command prefix to run the script, e.g., \"python\".\n        kwargs: the keyword arguments to be passed to the script.\n\n    Returns:\n        The command to run the script in the distributed job.\n\n    Examples:\n        For command \"bcprun -n 2 -p 8 -c python train.py run -k --config 'a,b'\",\n        it only prepares command after the bcprun arguments, i.e., \"train.py run -k --config 'a,b'\".\n        the function can be called as\n        - _prepare_cmd_bcprun(\"train.py run -k\", config=['a','b'], n=2, p=8)\n        - _prepare_cmd_bcprun(\"train.py run -k --config 'a,b'\", n=2, p=8)\n    \"\"\"\n\n    return _prepare_cmd_default(cmd, cmd_prefix=cmd_prefix, **kwargs)\n\n\ndef _run_cmd_torchrun(cmd: str, **kwargs: Any) -> subprocess.CompletedProcess:\n    \"\"\"\n    Run the command with torchrun.\n\n    Args:\n        cmd: the command to run. Typically it is prepared by ``_prepare_cmd_torchrun``.\n        kwargs: the keyword arguments to be passed to the ``torchrun``.\n\n    Return:\n        the return code of the subprocess command.\n    \"\"\"\n    params = kwargs.copy()\n\n    cmd_list = cmd.split()\n\n    # append arguments to the command list\n    torchrun_list = [\"torchrun\"]\n    required_args = [\"nnodes\", \"nproc_per_node\"]\n    for arg in required_args:\n        if arg not in params:\n            raise ValueError(f\"Missing required argument {arg} for torchrun.\")\n        torchrun_list += [f\"--{arg}\", str(params.pop(arg))]\n    torchrun_list += cmd_list\n    return run_cmd(torchrun_list, run_cmd_verbose=True, **params)\n\n\ndef _run_cmd_bcprun(cmd: str, **kwargs: Any) -> subprocess.CompletedProcess:\n    \"\"\"\n    Run the command with bcprun.\n\n    Args:\n        cmd: the command to run. 
Typically it is prepared by ``_prepare_cmd_bcprun``.\n        kwargs: the keyword arguments to be passed to the ``bcprun``.\n\n    Returns:\n        the return code of the subprocess command.\n    \"\"\"\n    params = kwargs.copy()\n    cmd_list = [\"bcprun\"]\n    required_args = [\"n\", \"p\"]\n    for arg in required_args:\n        if arg not in params:\n            raise ValueError(f\"Missing required argument {arg} for bcprun.\")\n        cmd_list += [f\"-{arg}\", str(params.pop(arg))]\n    cmd_list.extend([\"-c\", cmd])\n    return run_cmd(cmd_list, run_cmd_verbose=True, **params)\n"
  },
  {
    "path": "monai/bundle/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable\nfrom .config_parser import ConfigParser\nfrom .properties import InferProperties, MetaProperties, TrainProperties\nfrom .reference_resolver import ReferenceResolver\nfrom .scripts import (\n    ckpt_export,\n    create_workflow,\n    download,\n    download_large_files,\n    get_all_bundles_list,\n    get_bundle_info,\n    get_bundle_versions,\n    init_bundle,\n    load,\n    onnx_export,\n    push_to_hf_hub,\n    run,\n    run_workflow,\n    trt_export,\n    update_kwargs,\n    verify_metadata,\n    verify_net_in_out,\n)\nfrom .utils import (\n    DEFAULT_EXP_MGMT_SETTINGS,\n    DEFAULT_MLFLOW_SETTINGS,\n    EXPR_KEY,\n    ID_REF_KEY,\n    ID_SEP_KEY,\n    MACRO_KEY,\n    load_bundle_config,\n)\nfrom .workflows import BundleWorkflow, ConfigWorkflow, PythonicWorkflow\n"
  },
  {
    "path": "monai/bundle/__main__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom monai.bundle.scripts import (\n    ckpt_export,\n    download,\n    download_large_files,\n    init_bundle,\n    onnx_export,\n    run,\n    run_workflow,\n    trt_export,\n    verify_metadata,\n    verify_net_in_out,\n)\n\nif __name__ == \"__main__\":\n    from monai.utils import optional_import\n\n    fire, _ = optional_import(\"fire\")\n    fire.Fire()\n"
  },
  {
    "path": "monai/bundle/config_item.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport ast\nimport inspect\nimport sys\nimport warnings\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Mapping, Sequence\nfrom importlib import import_module\nfrom pprint import pformat\nfrom typing import Any\n\nfrom monai.bundle.utils import EXPR_KEY\nfrom monai.utils import CompInitMode, ensure_tuple, first, instantiate, optional_import, run_debug, run_eval\n\n__all__ = [\"ComponentLocator\", \"ConfigItem\", \"ConfigExpression\", \"ConfigComponent\", \"Instantiable\"]\n\n\nclass Instantiable(ABC):\n    \"\"\"\n    Base class for an instantiable object.\n    \"\"\"\n\n    @abstractmethod\n    def is_disabled(self, *args: Any, **kwargs: Any) -> bool:\n        \"\"\"\n        Return a boolean flag to indicate whether the object should be instantiated.\n        \"\"\"\n        raise NotImplementedError(f\"subclass {self.__class__.__name__} must implement this method.\")\n\n    @abstractmethod\n    def instantiate(self, *args: Any, **kwargs: Any) -> object:\n        \"\"\"\n        Instantiate the target component and return the instance.\n        \"\"\"\n        raise NotImplementedError(f\"subclass {self.__class__.__name__} must implement this method.\")\n\n\nclass ComponentLocator:\n    \"\"\"\n    Scan all the available classes and functions in the MONAI package and map them with the module paths in a table.\n    
It's used to locate the module path for provided component name.\n\n    Args:\n        excludes: if any string of the `excludes` exists in the full module name, don't import this module.\n\n    \"\"\"\n\n    MOD_START = \"monai\"\n\n    def __init__(self, excludes: Sequence[str] | str | None = None):\n        self.excludes = [] if excludes is None else ensure_tuple(excludes)\n        self._components_table: dict[str, list] | None = None\n\n    def _find_module_names(self) -> list[str]:\n        \"\"\"\n        Find all the modules start with MOD_START and don't contain any of `excludes`.\n\n        \"\"\"\n        return [m for m in sys.modules if m.startswith(self.MOD_START) and all(s not in m for s in self.excludes)]\n\n    def _find_classes_or_functions(self, modnames: Sequence[str] | str) -> dict[str, list]:\n        \"\"\"\n        Find all the classes and functions in the modules with specified `modnames`.\n\n        Args:\n            modnames: names of the target modules to find all the classes and functions.\n\n        \"\"\"\n        table: dict[str, list] = {}\n        # all the MONAI modules are already loaded by `load_submodules`\n        for modname in ensure_tuple(modnames):\n            try:\n                # scan all the classes and functions in the module\n                module = import_module(modname)\n                for name, obj in inspect.getmembers(module):\n                    if (inspect.isclass(obj) or inspect.isfunction(obj)) and obj.__module__ == modname:\n                        if name not in table:\n                            table[name] = []\n                        table[name].append(modname)\n            except ModuleNotFoundError:\n                pass\n        return table\n\n    def get_component_module_name(self, name: str) -> list[str] | str | None:\n        \"\"\"\n        Get the full module name of the class or function with specified ``name``.\n        If target component name exists in multiple packages or modules, 
return a list of full module names.\n\n        Args:\n            name: name of the expected class or function.\n\n        \"\"\"\n        if not isinstance(name, str):\n            raise ValueError(f\"`name` must be a valid string, but got: {name}.\")\n        if self._components_table is None:\n            # init component and module mapping table\n            self._components_table = self._find_classes_or_functions(self._find_module_names())\n\n        mods: list[str] | str | None = self._components_table.get(name)\n        if isinstance(mods, list) and len(mods) == 1:\n            mods = mods[0]\n        return mods\n\n\nclass ConfigItem:\n    \"\"\"\n    Basic data structure to represent a configuration item.\n\n    A `ConfigItem` instance can optionally have a string id, so that other items can refer to it.\n    It has a build-in `config` property to store the configuration object.\n\n    Args:\n        config: content of a config item, can be objects of any types,\n            a configuration resolver may interpret the content to generate a configuration object.\n        id: name of the current config item, defaults to empty string.\n\n    \"\"\"\n\n    def __init__(self, config: Any, id: str = \"\") -> None:\n        self.config = config\n        self.id = id\n\n    def get_id(self) -> str:\n        \"\"\"\n        Get the ID name of current config item, useful to identify config items during parsing.\n\n        \"\"\"\n        return self.id\n\n    def update_config(self, config: Any) -> None:\n        \"\"\"\n        Replace the content of `self.config` with new `config`.\n        A typical usage is to modify the initial config content at runtime.\n\n        Args:\n            config: content of a `ConfigItem`.\n\n        \"\"\"\n        self.config = config\n\n    def get_config(self):\n        \"\"\"\n        Get the config content of current config item.\n\n        \"\"\"\n        return self.config\n\n    def __repr__(self) -> str:\n        return 
f\"{type(self).__name__}: \\n{pformat(self.config)}\"\n\n\nclass ConfigComponent(ConfigItem, Instantiable):\n    \"\"\"\n    Subclass of :py:class:`monai.bundle.ConfigItem`, this class uses a dictionary with string keys to\n    represent a component of `class` or `function` and supports instantiation.\n\n    Currently, three special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:\n\n        - class or function identifier of the python module, specified by ``\"_target_\"``,\n          indicating a monai built-in Python class or function such as ``\"LoadImageDict\"``,\n          or a full module name, e.g. ``\"monai.transforms.LoadImageDict\"``, or a callable, e.g. ``\"$@model.forward\"``.\n        - ``\"_requires_\"`` (optional): specifies reference IDs (string starts with ``\"@\"``) or ``ConfigExpression``\n          of the dependencies for this ``ConfigComponent`` object. These dependencies will be\n          evaluated/instantiated before this object is instantiated.  It is useful when the\n          component doesn't explicitly depend on the other `ConfigItems` via its arguments,\n          but requires the dependencies to be instantiated/evaluated beforehand.\n        - ``\"_disabled_\"`` (optional): a flag to indicate whether to skip the instantiation.\n        - ``\"_desc_\"`` (optional): free text descriptions of the component for code readability.\n        - ``\"_mode_\"`` (optional): operating mode for invoking the callable ``component`` defined by ``\"_target_\"``:\n\n            - ``\"default\"``: returns ``component(**kwargs)``\n            - ``\"callable\"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``\n            - ``\"debug\"``: returns ``pdb.runcall(component, **kwargs)``\n\n    Other fields in the config content are input arguments to the python module.\n\n    .. 
code-block:: python\n\n        from monai.bundle import ComponentLocator, ConfigComponent\n\n        locator = ComponentLocator(excludes=[\"modules_to_exclude\"])\n        config = {\n            \"_target_\": \"LoadImaged\",\n            \"keys\": [\"image\", \"label\"]\n        }\n\n        configer = ConfigComponent(config, id=\"test\", locator=locator)\n        image_loader = configer.instantiate()\n        print(image_loader)  # <monai.transforms.io.dictionary.LoadImaged object at 0x7fba7ad1ee50>\n\n    Args:\n        config: content of a config item.\n        id: name of the current config item, defaults to empty string.\n        locator: a ``ComponentLocator`` to convert a module name string into the actual python module.\n            if `None`, a ``ComponentLocator(excludes=excludes)`` will be used.\n        excludes: if ``locator`` is None, create a new ``ComponentLocator`` with ``excludes``.\n            See also: :py:class:`monai.bundle.ComponentLocator`.\n\n    \"\"\"\n\n    non_arg_keys = {\"_target_\", \"_disabled_\", \"_requires_\", \"_desc_\", \"_mode_\"}\n\n    def __init__(\n        self,\n        config: Any,\n        id: str = \"\",\n        locator: ComponentLocator | None = None,\n        excludes: Sequence[str] | str | None = None,\n    ) -> None:\n        super().__init__(config=config, id=id)\n        self.locator = ComponentLocator(excludes=excludes) if locator is None else locator\n\n    @staticmethod\n    def is_instantiable(config: Any) -> bool:\n        \"\"\"\n        Check whether this config represents a `class` or `function` that is to be instantiated.\n\n        Args:\n            config: input config content to check.\n\n        \"\"\"\n        return isinstance(config, Mapping) and \"_target_\" in config\n\n    def resolve_module_name(self):\n        \"\"\"\n        Resolve the target module name from current config content.\n        The config content must have ``\"_target_\"`` key.\n\n        \"\"\"\n        config = 
dict(self.get_config())\n        target = config.get(\"_target_\")\n        if not isinstance(target, str):\n            return target  # for feature discussed in project-monai/monai#5852\n\n        module = self.locator.get_component_module_name(target)\n        if module is None:\n            # target is the full module name, no need to parse\n            return target\n\n        if isinstance(module, list):\n            warnings.warn(\n                f\"there are more than 1 component have name `{target}`: {module}, use the first one `{module[0]}.\"\n                f\" if want to use others, please set its full module path in `_target_` directly.\"\n            )\n            module = module[0]\n        return f\"{module}.{target}\"\n\n    def resolve_args(self):\n        \"\"\"\n        Utility function used in `instantiate()` to resolve the arguments from current config content.\n\n        \"\"\"\n        return {k: v for k, v in self.get_config().items() if k not in self.non_arg_keys}\n\n    def is_disabled(self) -> bool:\n        \"\"\"\n        Utility function used in `instantiate()` to check whether to skip the instantiation.\n\n        \"\"\"\n        _is_disabled = self.get_config().get(\"_disabled_\", False)\n        return _is_disabled.lower().strip() == \"true\" if isinstance(_is_disabled, str) else bool(_is_disabled)\n\n    def instantiate(self, **kwargs: Any) -> object:\n        \"\"\"\n        Instantiate component based on ``self.config`` content.\n        The target component must be a `class` or a `function`, otherwise, return `None`.\n\n        Args:\n            kwargs: args to override / add the config args when instantiation.\n\n        \"\"\"\n        if not self.is_instantiable(self.get_config()) or self.is_disabled():\n            # if not a class or function or marked as `disabled`, skip parsing and return `None`\n            return None\n\n        modname = self.resolve_module_name()\n        mode = self.get_config().get(\"_mode_\", 
CompInitMode.DEFAULT)\n        args = self.resolve_args()\n        args.update(kwargs)\n        return instantiate(modname, mode, **args)\n\n\nclass ConfigExpression(ConfigItem):\n    \"\"\"\n    Subclass of :py:class:`monai.bundle.ConfigItem`, the `ConfigItem` represents an executable expression\n    (execute based on ``eval()``, or import the module to the `globals` if it's an import statement).\n\n    See also:\n\n        - https://docs.python.org/3/library/functions.html#eval.\n\n    For example:\n\n    .. code-block:: python\n\n        import monai\n        from monai.bundle import ConfigExpression\n\n        config = \"$monai.__version__\"\n        expression = ConfigExpression(config, id=\"test\", globals={\"monai\": monai})\n        print(expression.evaluate())\n\n    Args:\n        config: content of a config item.\n        id: name of current config item, defaults to empty string.\n        globals: additional global context to evaluate the string.\n\n    \"\"\"\n\n    prefix = EXPR_KEY\n    run_eval = run_eval\n\n    def __init__(self, config: Any, id: str = \"\", globals: dict | None = None) -> None:\n        super().__init__(config=config, id=id)\n        self.globals = globals if globals is not None else {}\n\n    def _parse_import_string(self, import_string: str) -> Any | None:\n        \"\"\"parse single import statement such as \"from monai.transforms import Resize\"\"\"\n        node = first(ast.iter_child_nodes(ast.parse(import_string)))\n        if not isinstance(node, (ast.Import, ast.ImportFrom)):\n            return None\n        if len(node.names) < 1:\n            return None\n        if len(node.names) > 1:\n            warnings.warn(f\"ignoring multiple import alias '{import_string}'.\")\n        name, asname = f\"{node.names[0].name}\", node.names[0].asname\n        asname = name if asname is None else f\"{asname}\"\n        if isinstance(node, ast.ImportFrom):\n            self.globals[asname], _ = optional_import(f\"{node.module}\", 
name=f\"{name}\")\n            return self.globals[asname]\n        if isinstance(node, ast.Import):\n            self.globals[asname], _ = optional_import(f\"{name}\")\n            return self.globals[asname]\n        return None\n\n    def evaluate(self, globals: dict | None = None, locals: dict | None = None) -> str | Any | None:\n        \"\"\"\n        Execute the current config content and return the result if it is expression, based on Python `eval()`.\n        For more details: https://docs.python.org/3/library/functions.html#eval.\n\n        Args:\n            globals: besides ``self.globals``, other global symbols used in the expression at runtime.\n            locals: besides ``globals``, may also have some local symbols used in the expression at runtime.\n\n        \"\"\"\n        value = self.get_config()\n        if not ConfigExpression.is_expression(value):\n            return None\n        optional_module = self._parse_import_string(value[len(self.prefix) :])\n        if optional_module is not None:\n            return optional_module\n        if not self.run_eval:\n            return f\"{value[len(self.prefix) :]}\"\n        globals_ = dict(self.globals)\n        if globals is not None:\n            for k, v in globals.items():\n                if k in globals_:\n                    warnings.warn(f\"the new global variable `{k}` conflicts with `self.globals`, override it.\")\n                globals_[k] = v\n        if not run_debug:\n            try:\n                return eval(value[len(self.prefix) :], globals_, locals)\n            except Exception as e:\n                raise RuntimeError(f\"Failed to evaluate {self}\") from e\n        warnings.warn(\n            f\"\\n\\npdb: value={value}\\n\"\n            f\"See also Debugger commands documentation: https://docs.python.org/3/library/pdb.html\\n\"\n        )\n        import pdb\n\n        pdb.run(value[len(self.prefix) :], globals_, locals)\n        return None\n\n    @classmethod\n    def 
is_expression(cls, config: dict | list | str) -> bool:\n        \"\"\"\n        Check whether the config is an executable expression string.\n        Currently, a string starts with ``\"$\"`` character is interpreted as an expression.\n\n        Args:\n            config: input config content to check.\n\n        \"\"\"\n        return isinstance(config, str) and config.startswith(cls.prefix)\n\n    @classmethod\n    def is_import_statement(cls, config: dict | list | str) -> bool:\n        \"\"\"\n        Check whether the config is an import statement (a special case of expression).\n\n        Args:\n            config: input config content to check.\n        \"\"\"\n        if not cls.is_expression(config):\n            return False\n        if \"import\" not in config:\n            return False\n        return isinstance(\n            first(ast.iter_child_nodes(ast.parse(f\"{config[len(cls.prefix) :]}\"))), (ast.Import, ast.ImportFrom)\n        )\n"
  },
  {
    "path": "monai/bundle/config_parser.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport re\nfrom collections.abc import Sequence\nfrom copy import deepcopy\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any\n\nfrom monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem\nfrom monai.bundle.reference_resolver import ReferenceResolver\nfrom monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY, merge_kv\nfrom monai.config import PathLike\nfrom monai.utils import ensure_tuple, look_up_option, optional_import\nfrom monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates\n\nif TYPE_CHECKING:\n    import yaml\nelse:\n    yaml, _ = optional_import(\"yaml\")\n\n__all__ = [\"ConfigParser\"]\n\n_default_globals = {\"monai\": \"monai\", \"torch\": \"torch\", \"np\": \"numpy\", \"numpy\": \"numpy\"}\n\n\nclass ConfigParser:\n    \"\"\"\n    The primary configuration parser. It traverses a structured config (in the form of nested Python dict or list),\n    creates ``ConfigItem``, and assign unique IDs according to the structures.\n\n    This class provides convenient access to the set of ``ConfigItem`` of the config by ID.\n    A typical workflow of config parsing is as follows:\n\n        - Initialize ``ConfigParser`` with the ``config`` source.\n        - Call ``get_parsed_content()`` to get expected component with `id`.\n\n    .. 
code-block:: python\n\n        from monai.bundle import ConfigParser\n\n        config = {\n            \"my_dims\": 2,\n            \"dims_1\": \"$@my_dims + 1\",\n            \"my_xform\": {\"_target_\": \"LoadImage\"},\n            \"my_net\": {\"_target_\": \"BasicUNet\", \"spatial_dims\": \"@dims_1\", \"in_channels\": 1, \"out_channels\": 4},\n            \"trainer\": {\"_target_\": \"SupervisedTrainer\", \"network\": \"@my_net\", \"preprocessing\": \"@my_xform\"}\n        }\n        # in the example $@my_dims + 1 is an expression, which adds 1 to the value of @my_dims\n        parser = ConfigParser(config)\n\n        # get/set configuration content, the set method should happen before calling parse()\n        print(parser[\"my_net\"][\"in_channels\"])  # original input channels 1\n        parser[\"my_net\"][\"in_channels\"] = 4  # change input channels to 4\n        print(parser[\"my_net\"][\"in_channels\"])\n\n        # instantiate the network component\n        parser.parse(True)\n        net = parser.get_parsed_content(\"my_net\", instantiate=True)\n        print(net)\n\n        # also support to get the configuration content of parsed `ConfigItem`\n        trainer = parser.get_parsed_content(\"trainer\", instantiate=False)\n        print(trainer)\n\n    Args:\n        config: input config source to parse.\n        excludes: when importing modules to instantiate components,\n            excluding components from modules specified in ``excludes``.\n        globals: pre-import packages as global variables to ``ConfigExpression``,\n            so that expressions, for example, ``\"$monai.data.list_data_collate\"`` can use ``monai`` modules.\n            The current supported globals and alias names are\n            ``{\"monai\": \"monai\", \"torch\": \"torch\", \"np\": \"numpy\", \"numpy\": \"numpy\"}``.\n            These are MONAI's minimal dependencies. 
Additional packages could be included with `globals={\"itk\": \"itk\"}`.\n            Set it to ``False`` to disable `self.globals` module importing.\n\n    See also:\n\n        - :py:class:`monai.bundle.ConfigItem`\n        - :py:class:`monai.bundle.scripts.run`\n\n    \"\"\"\n\n    suffixes = (\"json\", \"yaml\", \"yml\")\n    suffix_match = rf\".*\\.({'|'.join(suffixes)})\"\n    path_match = rf\"({suffix_match}$)\"\n    # match relative id names, e.g. \"@#data\", \"@##transform#1\"\n    relative_id_prefix = re.compile(rf\"(?:{ID_REF_KEY}|{MACRO_KEY}){ID_SEP_KEY}+\")\n    meta_key = \"_meta_\"  # field key to save metadata\n\n    def __init__(\n        self,\n        config: Any = None,\n        excludes: Sequence[str] | str | None = None,\n        globals: dict[str, Any] | None | bool = None,\n    ):\n        self.config: ConfigItem | None = None\n        self.globals: dict[str, Any] = {}\n        _globals = _default_globals.copy()\n        if isinstance(_globals, dict) and globals not in (None, False):\n            _globals.update(globals)  # type: ignore\n        if _globals is not None and globals is not False:\n            for k, v in _globals.items():\n                self.globals[k] = optional_import(v)[0] if isinstance(v, str) else v\n\n        self.locator = ComponentLocator(excludes=excludes)\n        self.ref_resolver = ReferenceResolver()\n        if config is None:\n            config = {self.meta_key: {}}\n        self.set(config=self.ref_resolver.normalize_meta_id(config))\n\n    def __repr__(self):\n        return f\"{self.config}\"\n\n    def __getattr__(self, id):\n        \"\"\"\n        Get the parsed result of ``ConfigItem`` with the specified ``id``\n        with default arguments (e.g. 
``lazy=True``, ``instantiate=True`` and ``eval_expr=True``).\n\n        Args:\n            id: id of the ``ConfigItem``.\n\n        See also:\n             :py:meth:`get_parsed_content`\n        \"\"\"\n        return self.get_parsed_content(id)\n\n    def __getitem__(self, id: str | int) -> Any:\n        \"\"\"\n        Get the config by id.\n\n        Args:\n            id: id of the ``ConfigItem``, ``\"::\"`` (or ``\"#\"``) in id are interpreted as special characters to\n                go one level further into the nested structures.\n                Use digits indexing from \"0\" for list or other strings for dict.\n                For example: ``\"xform::5\"``, ``\"net::channels\"``. ``\"\"`` indicates the entire ``self.config``.\n\n        \"\"\"\n        if id == \"\":\n            return self.config\n        config = self.config\n        for k in ReferenceResolver.split_id(id):\n            if not isinstance(config, (dict, list)):\n                raise ValueError(f\"config must be dict or list for key `{k}`, but got {type(config)}: {config}.\")\n            try:\n                config = (\n                    look_up_option(k, config, print_all_options=False) if isinstance(config, dict) else config[int(k)]\n                )\n            except ValueError as e:\n                raise KeyError(f\"query key: {k}\") from e\n        return config\n\n    def __setitem__(self, id: str | int, config: Any) -> None:\n        \"\"\"\n        Set config by ``id``.  Note that this method should be used before ``parse()`` or ``get_parsed_content()``\n        to ensure the updates are included in the parsed content.\n\n        Args:\n            id: id of the ``ConfigItem``, ``\"::\"`` (or ``\"#\"``) in id are interpreted as special characters to\n                go one level further into the nested structures.\n                Use digits indexing from \"0\" for list or other strings for dict.\n                For example: ``\"xform::5\"``, ``\"net::channels\"``. 
``\"\"`` indicates the entire ``self.config``.\n            config: config to set at location ``id``.\n\n        \"\"\"\n        if id == \"\":\n            self.config = config\n            self.ref_resolver.reset()\n            return\n        last_id, base_id = ReferenceResolver.split_id(id, last=True)\n        # get the last parent level config item and replace it\n        conf_ = self[last_id]\n\n        indexing = base_id if isinstance(conf_, dict) else int(base_id)\n        conf_[indexing] = config\n        self.ref_resolver.reset()\n        return\n\n    def get(self, id: str = \"\", default: Any | None = None) -> Any:\n        \"\"\"\n        Get the config by id.\n\n        Args:\n            id: id to specify the expected position. See also :py:meth:`__getitem__`.\n            default: default value to return if the specified ``id`` is invalid.\n\n        \"\"\"\n        try:\n            return self[id]\n        except (KeyError, IndexError, ValueError):  # Index error for integer indexing\n            return default\n\n    def set(self, config: Any, id: str = \"\", recursive: bool = True) -> None:\n        \"\"\"\n        Set config by ``id``.\n\n        Args:\n            config: config to set at location ``id``.\n            id: id to specify the expected position. See also :py:meth:`__setitem__`.\n            recursive: if the nested id doesn't exist, whether to recursively create the nested items in the config.\n                default to `True`. 
for the nested id, only support `dict` for the missing section.\n\n        \"\"\"\n        keys = ReferenceResolver.split_id(id)\n        conf_ = self.get()\n        if recursive:\n            if conf_ is None:\n                self.config = conf_ = {}  # type: ignore\n            for k in keys[:-1]:\n                if isinstance(conf_, dict) and k not in conf_:\n                    conf_[k] = {}\n                conf_ = conf_[k if isinstance(conf_, dict) else int(k)]\n        self[ReferenceResolver.normalize_id(id)] = self.ref_resolver.normalize_meta_id(config)\n\n    def update(self, pairs: dict[str, Any]) -> None:\n        \"\"\"\n        Set the ``id`` and the corresponding config content in pairs, see also :py:meth:`__setitem__`.\n        For example, ``parser.update({\"train::epoch\": 100, \"train::lr\": 0.02})``\n\n        Args:\n            pairs: dictionary of `id` and config pairs.\n\n        \"\"\"\n        for k, v in pairs.items():\n            self[k] = v\n\n    def __contains__(self, id: str | int) -> bool:\n        \"\"\"\n        Returns True if `id` is stored in this configuration.\n\n        Args:\n            id: id to specify the expected position. See also :py:meth:`__getitem__`.\n        \"\"\"\n        try:\n            _ = self[id]\n            return True\n        except (KeyError, IndexError, ValueError):  # Index error for integer indexing\n            return False\n\n    def parse(self, reset: bool = True) -> None:\n        \"\"\"\n        Recursively resolve `self.config` to replace the macro tokens with target content.\n        Then recursively parse the config source, add every item as ``ConfigItem`` to the reference resolver.\n\n        Args:\n            reset: whether to reset the ``reference_resolver`` before parsing. 
Defaults to `True`.\n\n        \"\"\"\n        if reset:\n            self.ref_resolver.reset()\n        self.resolve_macro_and_relative_ids()\n        self._do_parse(config=self.get())\n\n    def get_parsed_content(self, id: str = \"\", **kwargs: Any) -> Any:\n        \"\"\"\n        Get the parsed result of ``ConfigItem`` with the specified ``id``.\n\n            - If the item is ``ConfigComponent`` and ``instantiate=True``, the result is the instance.\n            - If the item is ``ConfigExpression`` and ``eval_expr=True``, the result is the evaluated output.\n            - Else, the result is the configuration content of `ConfigItem`.\n\n        Args:\n            id: id of the ``ConfigItem``, ``\"::\"`` (or ``\"#\"``) in id are interpreted as special characters to\n                go one level further into the nested structures.\n                Use digits indexing from \"0\" for list or other strings for dict.\n                For example: ``\"xform::5\"``, ``\"net::channels\"``. 
``\"\"`` indicates the entire ``self.config``.\n            kwargs: additional keyword arguments to be passed to ``_resolve_one_item``.\n                Currently support ``lazy`` (whether to retain the current config cache, default to `True`),\n                ``instantiate`` (whether to instantiate the `ConfigComponent`, default to `True`) and\n                ``eval_expr`` (whether to evaluate the `ConfigExpression`, default to `True`), ``default``\n                (the default config item if the `id` is not in the config content).\n\n        \"\"\"\n        if not self.ref_resolver.is_resolved():\n            # not parsed the config source yet, parse it\n            self.parse(reset=True)\n        elif not kwargs.get(\"lazy\", True):\n            self.parse(reset=not kwargs.get(\"lazy\", True))\n        return self.ref_resolver.get_resolved_content(id=id, **kwargs)\n\n    def read_meta(self, f: PathLike | Sequence[PathLike] | dict, **kwargs: Any) -> None:\n        \"\"\"\n        Read the metadata from specified JSON or YAML file.\n        The metadata as a dictionary will be stored at ``self.config[\"_meta_\"]``.\n\n        Args:\n            f: filepath of the metadata file, the content must be a dictionary,\n                if providing a list of files, will merge the content of them.\n                if providing a dictionary directly, use it as metadata.\n            kwargs: other arguments for ``json.load`` or ``yaml.safe_load``, depends on the file format.\n\n        \"\"\"\n        self.set(self.load_config_files(f, **kwargs), self.meta_key)\n\n    def read_config(self, f: PathLike | Sequence[PathLike] | dict, **kwargs: Any) -> None:\n        \"\"\"\n        Read the config from specified JSON/YAML file or a dictionary and\n        override the config content in the `self.config` dictionary.\n\n        Args:\n            f: filepath of the config file, the content must be a dictionary,\n                if providing a list of files, wil merge the content 
of them.\n                if providing a dictionary directly, use it as config.\n            kwargs: other arguments for ``json.load`` or ``yaml.safe_load``, depends on the file format.\n\n        \"\"\"\n        content = {self.meta_key: self.get(self.meta_key, {})}\n        content.update(self.load_config_files(f, **kwargs))\n        self.set(config=content)\n\n    def _do_resolve(self, config: Any, id: str = \"\") -> Any:\n        \"\"\"\n        Recursively resolve `self.config` to replace the relative ids with absolute ids, for example,\n        `@##A` means `A` in the upper level. and replace the macro tokens with target content,\n        The macro tokens start with \"%\", can be from another structured file, like:\n        ``\"%default_net\"``, ``\"%/data/config.json#net\"``.\n        Note that the macro replacement doesn't support recursive macro tokens.\n\n        Args:\n            config: input config file to resolve.\n            id: id of the ``ConfigItem``, ``\"::\"`` (or ``\"#\"``) in id are interpreted as special characters to\n                go one level further into the nested structures.\n                Use digits indexing from \"0\" for list or other strings for dict.\n                For example: ``\"xform::5\"``, ``\"net::channels\"``. 
``\"\"`` indicates the entire ``self.config``.\n\n        \"\"\"\n        if isinstance(config, (dict, list)):\n            for k, sub_id, v in self.ref_resolver.iter_subconfigs(id=id, config=config):\n                config[k] = self._do_resolve(v, sub_id)  # type: ignore\n        if isinstance(config, str):\n            config = self.resolve_relative_ids(id, config)\n            if config.startswith(MACRO_KEY):\n                path, ids = ConfigParser.split_path_id(config[len(MACRO_KEY) :])\n                parser = ConfigParser(config=self.get() if not path else ConfigParser.load_config_file(path))\n                # deepcopy to ensure the macro replacement is independent config content\n                return deepcopy(parser[ids])\n        return config\n\n    def resolve_macro_and_relative_ids(self):\n        \"\"\"\n        Recursively resolve `self.config` to replace the relative ids with absolute ids, for example,\n        `@##A` means `A` in the upper level. and replace the macro tokens with target content,\n        The macro tokens are marked as starting with \"%\", can be from another structured file, like:\n        ``\"%default_net\"``, ``\"%/data/config.json::net\"``.\n\n        \"\"\"\n        self.set(self._do_resolve(config=self.get()))\n\n    def _do_parse(self, config: Any, id: str = \"\") -> None:\n        \"\"\"\n        Recursively parse the nested data in config source, add every item as `ConfigItem` to the resolver.\n\n        Args:\n            config: config source to parse.\n            id: id of the ``ConfigItem``, ``\"::\"`` (or ``\"#\"``) in id are interpreted as special characters to\n                go one level further into the nested structures.\n                Use digits indexing from \"0\" for list or other strings for dict.\n                For example: ``\"xform::5\"``, ``\"net::channels\"``. 
``\"\"`` indicates the entire ``self.config``.\n\n        \"\"\"\n        if isinstance(config, (dict, list)):\n            for _, sub_id, v in self.ref_resolver.iter_subconfigs(id=id, config=config):\n                self._do_parse(config=v, id=sub_id)\n\n        if ConfigComponent.is_instantiable(config):\n            self.ref_resolver.add_item(ConfigComponent(config=config, id=id, locator=self.locator))\n        elif ConfigExpression.is_expression(config):\n            self.ref_resolver.add_item(ConfigExpression(config=config, id=id, globals=self.globals))\n        else:\n            self.ref_resolver.add_item(ConfigItem(config=config, id=id))\n\n    @classmethod\n    def load_config_file(cls, filepath: PathLike, **kwargs: Any) -> dict:\n        \"\"\"\n        Load a single config file with specified file path (currently support JSON and YAML files).\n\n        Args:\n            filepath: path of target file to load, supported postfixes: `.json`, `.yml`, `.yaml`.\n            kwargs: other arguments for ``json.load`` or ```yaml.safe_load``, depends on the file format.\n\n        \"\"\"\n        if not filepath:\n            return {}\n        _filepath: str = str(Path(filepath))\n        if not re.compile(cls.path_match, re.IGNORECASE).findall(_filepath):\n            raise ValueError(f'unknown file input: \"{filepath}\"')\n        with open(_filepath) as f:\n            if _filepath.lower().endswith(cls.suffixes[0]):\n                return json.load(f, object_pairs_hook=check_key_duplicates, **kwargs)  # type: ignore[no-any-return]\n            if _filepath.lower().endswith(cls.suffixes[1:]):\n                return yaml.load(f, CheckKeyDuplicatesYamlLoader, **kwargs)  # type: ignore[no-any-return]\n            raise ValueError(f\"only support JSON or YAML config file so far, got name {_filepath}.\")\n\n    @classmethod\n    def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs: Any) -> dict:\n        \"\"\"\n        Load multiple 
config files into a single config dict.\n        The latter config file in the list will override or add the former config file.\n        ``\"::\"`` (or ``\"#\"``) in the config keys are interpreted as special characters to go one level\n        further into the nested structures.\n\n        Args:\n            files: path of target files to load, supported postfixes: `.json`, `.yml`, `.yaml`.\n                if providing a list of files, will merge the content of them.\n                if providing a string with comma separated file paths, will merge the content of them.\n                if providing a dictionary, return it directly.\n            kwargs: other arguments for ``json.load`` or ```yaml.safe_load``, depends on the file format.\n        \"\"\"\n        if isinstance(files, dict):  # already a config dict\n            return files\n        parser = ConfigParser(config={})\n        if isinstance(files, str) and not Path(files).is_file() and \",\" in files:\n            files = files.split(\",\")\n        for i in ensure_tuple(files):\n            config_dict = cls.load_config_file(i, **kwargs)\n            for k, v in config_dict.items():\n                merge_kv(parser, k, v)\n\n        return parser.get()  # type: ignore\n\n    @classmethod\n    def export_config_file(cls, config: dict, filepath: PathLike, fmt: str = \"json\", **kwargs: Any) -> None:\n        \"\"\"\n        Export the config content to the specified file path (currently support JSON and YAML files).\n\n        Args:\n            config: source config content to export.\n            filepath: target file path to save.\n            fmt: format of config content, currently support ``\"json\"`` and ``\"yaml\"``.\n            kwargs: other arguments for ``json.dump`` or ``yaml.safe_dump``, depends on the file format.\n\n        \"\"\"\n        _filepath: str = str(Path(filepath))\n        writer = look_up_option(fmt.lower(), {\"json\", \"yaml\", \"yml\"})\n        with open(_filepath, 
\"w\") as f:\n            if writer == \"json\":\n                json.dump(config, f, **kwargs)\n                return\n            if writer == \"yaml\" or writer == \"yml\":\n                return yaml.safe_dump(config, f, **kwargs)\n            raise ValueError(f\"only support JSON or YAML config file so far, got {writer}.\")\n\n    @classmethod\n    def split_path_id(cls, src: str) -> tuple[str, str]:\n        \"\"\"\n        Split `src` string into two parts: a config file path and component id.\n        The file path should end with `(json|yaml|yml)`. The component id should be separated by `::` if it exists.\n        If no path or no id, return \"\".\n\n        Args:\n            src: source string to split.\n\n        \"\"\"\n        src = ReferenceResolver.normalize_id(src)\n        result = re.compile(rf\"({cls.suffix_match}(?=(?:{ID_SEP_KEY}.*)|$))\", re.IGNORECASE).findall(src)\n        if not result:\n            return \"\", src  # the src is a pure id\n        path_name = result[0][0]  # at most one path_name\n        _, ids = src.rsplit(path_name, 1)\n        return path_name, ids[len(ID_SEP_KEY) :] if ids.startswith(ID_SEP_KEY) else \"\"\n\n    @classmethod\n    def resolve_relative_ids(cls, id: str, value: str) -> str:\n        \"\"\"\n        To simplify the reference or macro tokens ID in the nested config content, it's available to use\n        relative ID name which starts with the `ID_SEP_KEY`, for example, \"@#A\" means `A` in the same level,\n        `@##A` means `A` in the upper level.\n        It resolves the relative ids to absolute ids. For example, if the input data is:\n\n        .. 
code-block:: python\n\n            {\n                \"A\": 1,\n                \"B\": {\"key\": \"@##A\", \"value1\": 2, \"value2\": \"%#value1\", \"value3\": [3, 4, \"@#1\"]},\n            }\n\n        It will resolve `B` to `{\"key\": \"@A\", \"value1\": 2, \"value2\": \"%B#value1\", \"value3\": [3, 4, \"@B#value3#1\"]}`.\n\n        Args:\n            id: id name for current config item to compute relative id.\n            value: input value to resolve relative ids.\n\n        \"\"\"\n        # get the prefixes like: \"@####\", \"%###\", \"@#\"\n        value = ReferenceResolver.normalize_id(value)\n        prefixes = sorted(set().union(cls.relative_id_prefix.findall(value)), reverse=True)\n        current_id = id.split(ID_SEP_KEY)\n\n        for p in prefixes:\n            sym = ID_REF_KEY if ID_REF_KEY in p else MACRO_KEY\n            length = p[len(sym) :].count(ID_SEP_KEY)\n            if length > len(current_id):\n                raise ValueError(f\"the relative id in `{value}` is out of the range of config content.\")\n            if length == len(current_id):\n                new = \"\"  # root id is `\"\"`\n            else:\n                new = ID_SEP_KEY.join(current_id[:-length]) + ID_SEP_KEY\n            value = value.replace(p, sym + new)\n        return value\n"
  },
  {
    "path": "monai/bundle/properties.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThe predefined properties for a bundle workflow, other applications can leverage the properties\nto interact with the bundle workflow.\nSome properties are required and some are optional, optional properties mean: if some component of the\nbundle workflow refer to the property, the property must be defined, otherwise, the property can be None.\nEvery item in this `TrainProperties` or `InferProperties` or `MetaProperties` dictionary is a property,\nthe key is the property name and the values include:\n1. description.\n2. whether it's a required property.\n3. config item ID name (only applicable when the bundle workflow is defined in config).\n4. 
reference config item ID name (only applicable when the bundle workflow is defined in config).\n\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom monai.bundle.utils import ID_SEP_KEY\nfrom monai.utils import BundleProperty, BundlePropertyConfig\n\nTrainProperties = {\n    \"bundle_root\": {\n        BundleProperty.DESC: \"root path of the bundle.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: \"bundle_root\",\n    },\n    \"device\": {\n        BundleProperty.DESC: \"target device to execute the bundle workflow.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: \"device\",\n    },\n    \"dataset_dir\": {\n        BundleProperty.DESC: \"directory path of the dataset.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: \"dataset_dir\",\n    },\n    \"trainer\": {\n        BundleProperty.DESC: \"training workflow engine.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: f\"train{ID_SEP_KEY}trainer\",\n    },\n    \"network_def\": {\n        BundleProperty.DESC: \"network module for the training.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: \"network_def\",\n    },\n    \"max_epochs\": {\n        BundleProperty.DESC: \"max number of epochs to execute the training.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: f\"train{ID_SEP_KEY}trainer{ID_SEP_KEY}max_epochs\",\n    },\n    \"train_dataset\": {\n        BundleProperty.DESC: \"PyTorch dataset object for the training logic.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: f\"train{ID_SEP_KEY}dataset\",\n    },\n    \"train_inferer\": {\n        BundleProperty.DESC: \"MONAI Inferer object to execute the model computation in training.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: f\"train{ID_SEP_KEY}inferer\",\n    },\n    \"train_dataset_data\": {\n        BundleProperty.DESC: \"data 
source for the training dataset.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"train{ID_SEP_KEY}dataset{ID_SEP_KEY}data\",\n        BundlePropertyConfig.REF_ID: None,  # no reference to this ID\n    },\n    \"train_handlers\": {\n        BundleProperty.DESC: \"event-handlers for the training logic.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"train{ID_SEP_KEY}handlers\",\n        BundlePropertyConfig.REF_ID: f\"train{ID_SEP_KEY}trainer{ID_SEP_KEY}train_handlers\",\n    },\n    \"train_preprocessing\": {\n        BundleProperty.DESC: \"preprocessing for the training input data.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"train{ID_SEP_KEY}preprocessing\",\n        BundlePropertyConfig.REF_ID: f\"train{ID_SEP_KEY}dataset{ID_SEP_KEY}transform\",\n    },\n    \"train_postprocessing\": {\n        BundleProperty.DESC: \"postprocessing for the training model output data.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"train{ID_SEP_KEY}postprocessing\",\n        BundlePropertyConfig.REF_ID: f\"train{ID_SEP_KEY}trainer{ID_SEP_KEY}postprocessing\",\n    },\n    \"train_key_metric\": {\n        BundleProperty.DESC: \"key metric to compute on the training data.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"train{ID_SEP_KEY}key_metric\",\n        BundlePropertyConfig.REF_ID: f\"train{ID_SEP_KEY}trainer{ID_SEP_KEY}key_train_metric\",\n    },\n    \"evaluator\": {\n        BundleProperty.DESC: \"validation workflow engine.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"validate{ID_SEP_KEY}evaluator\",\n        BundlePropertyConfig.REF_ID: \"validator\",  # this REF_ID is the arg name of `ValidationHandler`\n    },\n    \"val_interval\": {\n        BundleProperty.DESC: \"validation interval during the training.\",\n        BundleProperty.REQUIRED: False,\n        
BundlePropertyConfig.ID: \"val_interval\",\n        BundlePropertyConfig.REF_ID: \"interval\",  # this REF_ID is the arg name of `ValidationHandler`\n    },\n    \"val_handlers\": {\n        BundleProperty.DESC: \"event-handlers for the validation logic.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"validate{ID_SEP_KEY}handlers\",\n        BundlePropertyConfig.REF_ID: f\"validate{ID_SEP_KEY}evaluator{ID_SEP_KEY}val_handlers\",\n    },\n    \"val_dataset\": {\n        BundleProperty.DESC: \"PyTorch dataset object for the validation logic.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"validate{ID_SEP_KEY}dataset\",\n        BundlePropertyConfig.REF_ID: f\"validate{ID_SEP_KEY}dataloader{ID_SEP_KEY}dataset\",\n    },\n    \"val_dataset_data\": {\n        BundleProperty.DESC: \"data source for the validation dataset.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"validate{ID_SEP_KEY}dataset{ID_SEP_KEY}data\",\n        BundlePropertyConfig.REF_ID: None,  # no reference to this ID\n    },\n    \"val_inferer\": {\n        BundleProperty.DESC: \"MONAI Inferer object to execute the model computation in validation.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"validate{ID_SEP_KEY}inferer\",\n        BundlePropertyConfig.REF_ID: f\"validate{ID_SEP_KEY}evaluator{ID_SEP_KEY}inferer\",\n    },\n    \"val_preprocessing\": {\n        BundleProperty.DESC: \"preprocessing for the validation input data.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"validate{ID_SEP_KEY}preprocessing\",\n        BundlePropertyConfig.REF_ID: f\"validate{ID_SEP_KEY}dataset{ID_SEP_KEY}transform\",\n    },\n    \"val_postprocessing\": {\n        BundleProperty.DESC: \"postprocessing for the validation model output data.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: 
f\"validate{ID_SEP_KEY}postprocessing\",\n        BundlePropertyConfig.REF_ID: f\"validate{ID_SEP_KEY}evaluator{ID_SEP_KEY}postprocessing\",\n    },\n    \"val_key_metric\": {\n        BundleProperty.DESC: \"key metric to compute on the validation data.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"validate{ID_SEP_KEY}key_metric\",\n        BundlePropertyConfig.REF_ID: f\"validate{ID_SEP_KEY}evaluator{ID_SEP_KEY}key_val_metric\",\n    },\n}\n\nInferProperties = {\n    \"bundle_root\": {\n        BundleProperty.DESC: \"root path of the bundle.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: \"bundle_root\",\n    },\n    \"device\": {\n        BundleProperty.DESC: \"target device to execute the bundle workflow.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: \"device\",\n    },\n    \"dataset_dir\": {\n        BundleProperty.DESC: \"directory path of the dataset.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: \"dataset_dir\",\n    },\n    \"dataset\": {\n        BundleProperty.DESC: \"PyTorch dataset object for the inference / evaluation logic.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: \"dataset\",\n    },\n    \"evaluator\": {\n        BundleProperty.DESC: \"inference / evaluation workflow engine.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: \"evaluator\",\n    },\n    \"network_def\": {\n        BundleProperty.DESC: \"network module for the inference.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: \"network_def\",\n    },\n    \"inferer\": {\n        BundleProperty.DESC: \"MONAI Inferer object to execute the model computation in inference.\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: \"inferer\",\n    },\n    \"dataset_data\": {\n        BundleProperty.DESC: \"data source for the inference / evaluation dataset.\",\n   
     BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"dataset{ID_SEP_KEY}data\",\n        BundlePropertyConfig.REF_ID: None,  # no reference to this ID\n    },\n    \"handlers\": {\n        BundleProperty.DESC: \"event-handlers for the inference / evaluation logic.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: \"handlers\",\n        BundlePropertyConfig.REF_ID: f\"evaluator{ID_SEP_KEY}val_handlers\",\n    },\n    \"preprocessing\": {\n        BundleProperty.DESC: \"preprocessing for the input data.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: \"preprocessing\",\n        BundlePropertyConfig.REF_ID: f\"dataset{ID_SEP_KEY}transform\",\n    },\n    \"postprocessing\": {\n        BundleProperty.DESC: \"postprocessing for the model output data.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: \"postprocessing\",\n        BundlePropertyConfig.REF_ID: f\"evaluator{ID_SEP_KEY}postprocessing\",\n    },\n    \"key_metric\": {\n        BundleProperty.DESC: \"the key metric during evaluation.\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: \"key_metric\",\n        BundlePropertyConfig.REF_ID: f\"evaluator{ID_SEP_KEY}key_val_metric\",\n    },\n}\n\nMetaProperties = {\n    \"version\": {\n        BundleProperty.DESC: \"bundle version\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: f\"_meta_{ID_SEP_KEY}version\",\n    },\n    \"monai_version\": {\n        BundleProperty.DESC: \"required monai version used for bundle\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: f\"_meta_{ID_SEP_KEY}monai_version\",\n    },\n    \"pytorch_version\": {\n        BundleProperty.DESC: \"required pytorch version used for bundle\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: f\"_meta_{ID_SEP_KEY}pytorch_version\",\n    },\n    \"numpy_version\": {\n        
BundleProperty.DESC: \"required numpy version used for bundle\",\n        BundleProperty.REQUIRED: True,\n        BundlePropertyConfig.ID: f\"_meta_{ID_SEP_KEY}numpy_version\",\n    },\n    \"description\": {\n        BundleProperty.DESC: \"description for bundle\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"_meta_{ID_SEP_KEY}description\",\n    },\n    \"spatial_shape\": {\n        BundleProperty.DESC: \"spatial shape for the inputs\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"_meta_{ID_SEP_KEY}network_data_format{ID_SEP_KEY}inputs{ID_SEP_KEY}image\"\n        f\"{ID_SEP_KEY}spatial_shape\",\n    },\n    \"channel_def\": {\n        BundleProperty.DESC: \"channel definition for the prediction\",\n        BundleProperty.REQUIRED: False,\n        BundlePropertyConfig.ID: f\"_meta_{ID_SEP_KEY}network_data_format{ID_SEP_KEY}outputs{ID_SEP_KEY}pred{ID_SEP_KEY}channel_def\",\n    },\n}\n"
  },
  {
    "path": "monai/bundle/reference_resolver.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport re\nimport warnings\nfrom collections.abc import Iterator, Sequence\nfrom typing import Any\n\nfrom monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem\nfrom monai.bundle.utils import DEPRECATED_ID_MAPPING, ID_REF_KEY, ID_SEP_KEY\nfrom monai.utils import allow_missing_reference, look_up_option\n\n__all__ = [\"ReferenceResolver\"]\n\n\nclass ReferenceResolver:\n    \"\"\"\n    Utility class to manage a set of ``ConfigItem`` and resolve the references between them.\n\n    This class maintains a set of ``ConfigItem`` objects and their associated IDs.\n    The IDs must be unique within this set. A string in ``ConfigItem``\n    starting with ``@`` will be treated as a reference to other ``ConfigItem`` objects by ID.\n    Since ``ConfigItem`` may have a nested dictionary or list structure,\n    the reference string may also contain the separator ``::`` to refer to a substructure by\n    key indexing for a dictionary or integer indexing for a list.\n\n    In this class, resolving references is essentially substitution of the reference strings with the\n    corresponding python objects. 
A typical workflow of resolving references is as follows:\n\n        - Add multiple ``ConfigItem`` objects to the ``ReferenceResolver`` by ``add_item()``.\n        - Call ``get_resolved_content()`` to automatically resolve the references. This is done (recursively) by:\n            - Convert the items to objects, for those do not have references to other items.\n                - If it is instantiable, instantiate it and cache the class instance in ``resolved_content``.\n                - If it is an expression, evaluate it and save the value in ``resolved_content``.\n            - Substitute the reference strings with the corresponding objects.\n\n    Args:\n        items: ``ConfigItem``s to resolve, this could be added later with ``add_item()``.\n\n    \"\"\"\n\n    _vars = \"__local_refs\"\n    sep = ID_SEP_KEY  # separator for key indexing\n    ref = ID_REF_KEY  # reference prefix\n    # match a reference string, e.g. \"@id::key\", \"@id::key::0\", \"@_target_::key\"\n    id_matcher = re.compile(rf\"{ref}(?:\\w*)(?:{sep}\\w*)*\")\n    # if `allow_missing_reference` and can't find a reference ID, will just raise a warning and don't update the config\n    allow_missing_reference = allow_missing_reference\n\n    def __init__(self, items: Sequence[ConfigItem] | None = None):\n        # save the items in a dictionary with the `ConfigItem.id` as key\n        self.items: dict[str, ConfigItem] = {} if items is None else {i.get_id(): i for i in items}\n        self.resolved_content: dict[str, ConfigExpression | str | Any | None] = {}\n\n    def reset(self):\n        \"\"\"\n        Clear all the added `ConfigItem` and all the resolved content.\n\n        \"\"\"\n        self.items = {}\n        self.resolved_content = {}\n\n    def is_resolved(self) -> bool:\n        return bool(self.resolved_content)\n\n    def add_item(self, item: ConfigItem) -> None:\n        \"\"\"\n        Add a ``ConfigItem`` to the resolver.\n\n        Args:\n            item: a 
``ConfigItem``.\n\n        \"\"\"\n        id = item.get_id()\n        if id in self.items:\n            return\n        self.items[id] = item\n\n    def get_item(self, id: str, resolve: bool = False, **kwargs: Any) -> ConfigItem | None:\n        \"\"\"\n        Get the ``ConfigItem`` by id.\n\n        If ``resolve=True``, the returned item will be resolved, that is,\n        all the reference strings are substituted by the corresponding ``ConfigItem`` objects.\n\n        Args:\n            id: id of the expected config item.\n            resolve: whether to resolve the item if it is not resolved, default to False.\n            kwargs: keyword arguments to pass to ``_resolve_one_item()``.\n                Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True.\n        \"\"\"\n        id = self.normalize_id(id)\n        if resolve and id not in self.resolved_content:\n            self._resolve_one_item(id=id, **kwargs)\n        return self.items.get(id)\n\n    def _resolve_one_item(\n        self, id: str, waiting_list: set[str] | None = None, **kwargs: Any\n    ) -> ConfigExpression | str | Any | None:\n        \"\"\"\n        Resolve and return one ``ConfigItem`` of ``id``, cache the resolved result in ``resolved_content``.\n        If it has unresolved references, recursively resolve the referring items first.\n\n        Args:\n            id: id name of ``ConfigItem`` to be resolved.\n            waiting_list: set of ids pending to be resolved.\n                It's used to detect circular references such as:\n                `{\"name\": \"A\", \"dep\": \"@B\"}` and `{\"name\": \"B\", \"dep\": \"@A\"}`.\n            kwargs: keyword arguments to pass to ``_resolve_one_item()``.\n                Currently support ``instantiate``, ``eval_expr`` and ``default``.\n                `instantiate` and `eval_expr` are defaulting to True, `default` is the target config item\n                if the `id` is not in the config content, must be a 
`ConfigItem` object.\n\n        \"\"\"\n        id = self.normalize_id(id)\n        if id in self.resolved_content:\n            return self.resolved_content[id]\n        try:\n            item = look_up_option(id, self.items, print_all_options=False, default=kwargs.get(\"default\", \"no_default\"))\n        except ValueError as err:\n            raise KeyError(f\"id='{id}' is not found in the config resolver.\") from err\n        if not isinstance(item, ConfigItem):\n            return item\n        item_config = item.get_config()\n\n        if waiting_list is None:\n            waiting_list = set()\n        waiting_list.add(id)\n\n        for t, v in self.items.items():\n            if (\n                t not in self.resolved_content\n                and isinstance(v, ConfigExpression)\n                and v.is_import_statement(v.get_config())\n            ):\n                self.resolved_content[t] = v.evaluate() if kwargs.get(\"eval_expr\", True) else v\n        for d in self.find_refs_in_config(config=item_config, id=id).keys():\n            # if current item has reference already in the waiting list, that's circular references\n            if d in waiting_list:\n                raise ValueError(f\"detected circular references '{d}' for id='{id}' in the config content.\")\n            # check whether the component has any unresolved references\n            if d not in self.resolved_content:\n                # this referring item is not resolved\n                try:\n                    look_up_option(d, self.items, print_all_options=False)\n                except ValueError as err:\n                    msg = f\"the referring item `@{d}` is not defined in the config content.\"\n                    if not self.allow_missing_reference:\n                        raise ValueError(msg) from err\n                    warnings.warn(msg)\n                    continue\n                # recursively resolve the reference first\n                
self._resolve_one_item(id=d, waiting_list=waiting_list, **kwargs)\n                waiting_list.discard(d)\n\n        # all references are resolved, then try to resolve current config item\n        new_config = self.update_config_with_refs(config=item_config, id=id, refs=self.resolved_content)\n        item.update_config(config=new_config)\n        # save the resolved result into `resolved_content` to recursively resolve others\n        if isinstance(item, ConfigComponent):\n            self.resolved_content[id] = item.instantiate() if kwargs.get(\"instantiate\", True) else item\n        elif isinstance(item, ConfigExpression):\n            run_eval = kwargs.get(\"eval_expr\", True)\n            self.resolved_content[id] = (\n                item.evaluate(globals={f\"{self._vars}\": self.resolved_content}) if run_eval else item\n            )\n        else:\n            self.resolved_content[id] = new_config\n        return self.resolved_content[id]\n\n    def get_resolved_content(self, id: str, **kwargs: Any) -> ConfigExpression | str | Any | None:\n        \"\"\"\n        Get the resolved ``ConfigItem`` by id.\n\n        Args:\n            id: id name of the expected item.\n            kwargs: keyword arguments to pass to ``_resolve_one_item()``.\n                Currently support ``instantiate``, ``eval_expr`` and ``default``.\n                `instantiate` and `eval_expr` are defaulting to True, `default` is the target config item\n                if the `id` is not in the config content, must be a `ConfigItem` object.\n\n        \"\"\"\n        return self._resolve_one_item(id=id, **kwargs)\n\n    def remove_resolved_content(self, id: str) -> Any | None:\n        \"\"\"\n        Remove the resolved ``ConfigItem`` by id.\n\n        Args:\n            id: id name of the expected item.\n\n        \"\"\"\n        return self.resolved_content.pop(id) if id in self.resolved_content else None\n\n    @classmethod\n    def normalize_id(cls, id: str | int) -> str:\n     
   \"\"\"\n        Normalize the id string to consistently use `cls.sep`.\n\n        Args:\n            id: id string to be normalized.\n        \"\"\"\n        return str(id).replace(\"#\", cls.sep)  # backward compatibility `#` is the old separator\n\n    def normalize_meta_id(self, config: Any) -> Any:\n        \"\"\"\n        Update deprecated identifiers in `config` using `DEPRECATED_ID_MAPPING`.\n        This will replace names that are marked as deprecated with their replacement.\n\n        Args:\n            config: input config to be updated.\n        \"\"\"\n        if isinstance(config, dict):\n            for _id, _new_id in DEPRECATED_ID_MAPPING.items():\n                if _id in config.keys():\n                    warnings.warn(\n                        f\"Detected deprecated name '{_id}' in configuration file, replacing with '{_new_id}'.\"\n                    )\n                    config[_new_id] = config.pop(_id)\n        return config\n\n    @classmethod\n    def split_id(cls, id: str | int, last: bool = False) -> list[str]:\n        \"\"\"\n        Split the id string into a list of strings by `cls.sep`.\n\n        Args:\n            id: id string to be split.\n            last: whether to split the rightmost part of the id. 
default is False (split all parts).\n        \"\"\"\n        if not last:\n            return cls.normalize_id(id).split(cls.sep)\n        res = cls.normalize_id(id).rsplit(cls.sep, 1)\n        return [\"\".join(res[:-1]), res[-1]]\n\n    @classmethod\n    def iter_subconfigs(cls, id: str, config: Any) -> Iterator[tuple[str, str, Any]]:\n        \"\"\"\n        Iterate over the sub-configs of the input config, the output `sub_id` uses `cls.sep` to denote substructure.\n\n        Args:\n            id: id string of the current input config.\n            config: input config to be iterated.\n        \"\"\"\n        for k, v in config.items() if isinstance(config, dict) else enumerate(config):\n            sub_id = f\"{id}{cls.sep}{k}\" if id != \"\" else f\"{k}\"\n            yield k, sub_id, v\n\n    @classmethod\n    def match_refs_pattern(cls, value: str) -> dict[str, int]:\n        \"\"\"\n        Match regular expression for the input string to find the references.\n        The reference string starts with ``\"@\"``, like: ``\"@XXX::YYY::ZZZ\"``.\n\n        Args:\n            value: input value to match regular expression.\n\n        \"\"\"\n        refs: dict[str, int] = {}\n        # regular expression pattern to match \"@XXX\" or \"@XXX::YYY\"\n        value = cls.normalize_id(value)\n        result = cls.id_matcher.findall(value)\n        value_is_expr = ConfigExpression.is_expression(value)\n        for item in result:\n            if value_is_expr or value == item:\n                # only check when string starts with \"$\" or the whole content is \"@XXX\"\n                id = item[len(cls.ref) :]\n                refs[id] = refs.get(id, 0) + 1\n        return refs\n\n    @classmethod\n    def update_refs_pattern(cls, value: str, refs: dict) -> str:\n        \"\"\"\n        Match regular expression for the input string to update content with the references.\n        The reference part starts with ``\"@\"``, like: ``\"@XXX::YYY::ZZZ\"``.\n        
References dictionary must contain the referring IDs as keys.\n\n        Args:\n            value: input value to match regular expression.\n            refs: all the referring components with ids as keys, default to `None`.\n\n        \"\"\"\n        # regular expression pattern to match \"@XXX\" or \"@XXX::YYY\"\n        value = cls.normalize_id(value)\n        result = cls.id_matcher.findall(value)\n        # reversely sort the matched references by length\n        # and handle the longer first in case a reference item is substring of another longer item\n        result.sort(key=len, reverse=True)\n        value_is_expr = ConfigExpression.is_expression(value)\n        for item in result:\n            # only update reference when string starts with \"$\" or the whole content is \"@XXX\"\n            if value_is_expr or value == item:\n                ref_id = item[len(cls.ref) :]  # remove the ref prefix \"@\"\n                if ref_id not in refs:\n                    msg = f\"can not find expected ID '{ref_id}' in the references.\"\n                    if not cls.allow_missing_reference:\n                        raise KeyError(msg)\n                    warnings.warn(msg)\n                    continue\n                if value_is_expr:\n                    # replace with local code, `{\"__local_refs\": self.resolved_content}` will be added to\n                    # the `globals` argument of python `eval` in the `evaluate`\n                    value = value.replace(item, f\"{cls._vars}['{ref_id}']\")\n                elif value == item:\n                    # the whole content is \"@XXX\", it will avoid the case that regular string contains \"@\"\n                    value = refs[ref_id]\n        return value\n\n    @classmethod\n    def find_refs_in_config(cls, config: Any, id: str, refs: dict[str, int] | None = None) -> dict[str, int]:\n        \"\"\"\n        Recursively search all the content of input config item to get the ids of references.\n        
References mean: the IDs of other config items (``\"@XXX\"`` in this config item), or the\n        sub-item in the config is `instantiable`, or the sub-item in the config is `expression`.\n        For `dict` and `list`, recursively check the sub-items.\n\n        Args:\n            config: input config content to search.\n            id: ID name for the input config item.\n            refs: dict of the ID name and count of found references, default to `None`.\n\n        \"\"\"\n        refs_: dict[str, int] = refs or {}\n        if isinstance(config, str):\n            for id, count in cls.match_refs_pattern(value=config).items():  # ref count is not currently used\n                refs_[id] = refs_.get(id, 0) + count\n        if not isinstance(config, (list, dict)):\n            return refs_\n        for _, sub_id, v in cls.iter_subconfigs(id, config):\n            if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v) and sub_id not in refs_:\n                refs_[sub_id] = 1\n            refs_ = cls.find_refs_in_config(v, sub_id, refs_)\n        return refs_\n\n    @classmethod\n    def update_config_with_refs(cls, config: Any, id: str, refs: dict | None = None) -> Any:\n        \"\"\"\n        With all the references in ``refs``, update the input config content with references\n        and return the new config.\n\n        Args:\n            config: input config content to update.\n            id: ID name for the input config.\n            refs: all the referring content with ids, default to `None`.\n\n        \"\"\"\n        refs_: dict = refs or {}\n        if isinstance(config, str):\n            return cls.update_refs_pattern(config, refs_)\n        if not isinstance(config, (list, dict)):\n            return config\n        ret = type(config)()\n        for idx, sub_id, v in cls.iter_subconfigs(id, config):\n            if ConfigComponent.is_instantiable(v) or ConfigExpression.is_expression(v):\n                updated = 
refs_[sub_id]\n                if ConfigComponent.is_instantiable(v) and updated is None:\n                    # the component is disabled\n                    continue\n            else:\n                updated = cls.update_config_with_refs(v, sub_id, refs_)\n            ret.update({idx: updated}) if isinstance(ret, dict) else ret.append(updated)\n        return ret\n"
  },
  {
    "path": "monai/bundle/scripts.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport ast\nimport json\nimport os\nimport re\nimport urllib\nimport warnings\nfrom collections.abc import Mapping, Sequence\nfrom functools import partial\nfrom pathlib import Path\nfrom pydoc import locate\nfrom shutil import copyfile\nfrom textwrap import dedent\nfrom typing import Any, Callable\n\nimport torch\nfrom torch.cuda import is_available\n\nfrom monai._version import get_versions\nfrom monai.apps.utils import _basename, _extract_zip, download_url, extractall, get_logger\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv\nfrom monai.bundle.workflows import BundleWorkflow, ConfigWorkflow\nfrom monai.config import PathLike\nfrom monai.data import load_net_with_metadata, save_net_with_metadata\nfrom monai.networks import (\n    convert_to_onnx,\n    convert_to_torchscript,\n    convert_to_trt,\n    copy_model_state,\n    get_state_dict,\n    save_state,\n)\nfrom monai.utils import (\n    IgniteInfo,\n    check_parent_dir,\n    ensure_tuple,\n    get_equivalent_dtype,\n    min_version,\n    optional_import,\n    pprint_edges,\n)\n\nvalidate, _ = optional_import(\"jsonschema\", name=\"validate\")\nValidationError, _ = optional_import(\"jsonschema.exceptions\", name=\"ValidationError\")\nCheckpoint, has_ignite = optional_import(\"ignite.handlers\", 
IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Checkpoint\")\nrequests, has_requests = optional_import(\"requests\")\nonnx, _ = optional_import(\"onnx\")\nhuggingface_hub, _ = optional_import(\"huggingface_hub\")\n\nlogger = get_logger(module_name=__name__)\n\n# set BUNDLE_DOWNLOAD_SRC=\"ngc\" to use NGC source in default for bundle download\n# set BUNDLE_DOWNLOAD_SRC=\"github\" to use github source in default for bundle download\nDEFAULT_DOWNLOAD_SOURCE = os.environ.get(\"BUNDLE_DOWNLOAD_SRC\", \"monaihosting\")\nPPRINT_CONFIG_N = 5\n\nMONAI_HOSTING_BASE_URL = \"https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting\"\nNGC_BASE_URL = \"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit\"\n\n\ndef update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:\n    \"\"\"\n    Update the `args` dictionary with the input `kwargs`.\n    For dict data, recursively update the content based on the keys.\n\n    Example::\n\n        from monai.bundle import update_kwargs\n        update_kwargs({'exist': 1}, exist=2, new_arg=3)\n        # return {'exist': 2, 'new_arg': 3}\n\n    Args:\n        args: source `args` dictionary (or a json/yaml filename to read as dictionary) to update.\n        ignore_none: whether to ignore input args with None value, default to `True`.\n        kwargs: key=value pairs to be merged into `args`.\n\n    \"\"\"\n    args_: dict = args if isinstance(args, dict) else {}\n    if isinstance(args, str):\n        # args are defined in a structured file\n        args_ = ConfigParser.load_config_file(args)\n    if isinstance(args, (tuple, list)) and all(isinstance(x, str) for x in args):\n        primary, overrides = args\n        args_ = update_kwargs(primary, ignore_none, **update_kwargs(overrides, ignore_none, **kwargs))\n    if not isinstance(args_, dict):\n        return args_\n    # recursively update the default args with new args\n    for k, v in kwargs.items():\n        if ignore_none and v is None:\n      
      continue\n        if isinstance(v, dict) and isinstance(args_.get(k), dict):\n            args_[k] = update_kwargs(args_[k], ignore_none, **v)\n        else:\n            merge_kv(args_, k, v)\n    return args_\n\n\n_update_args = update_kwargs  # backward compatibility\n\n\ndef _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple:\n    \"\"\"\n    Pop args from the `src` dictionary based on specified keys in `args` and (key, default value) pairs in `kwargs`.\n\n    \"\"\"\n    return tuple([src.pop(i) for i in args] + [src.pop(k, v) for k, v in kwargs.items()])\n\n\ndef _log_input_summary(tag: str, args: dict) -> None:\n    logger.info(f\"--- input summary of monai.bundle.scripts.{tag} ---\")\n    for name, val in args.items():\n        logger.info(f\"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}\")\n    logger.info(\"---\\n\\n\")\n\n\ndef _get_var_names(expr: str) -> list[str]:\n    \"\"\"\n    Parse the expression and discover what variables are present in it based on ast module.\n\n    Args:\n        expr: source expression to parse.\n\n    \"\"\"\n    tree = ast.parse(expr)\n    return [m.id for m in ast.walk(tree) if isinstance(m, ast.Name)]\n\n\ndef _get_fake_spatial_shape(shape: Sequence[str | int], p: int = 1, n: int = 1, any: int = 1) -> tuple:\n    \"\"\"\n    Get spatial shape for fake data according to the specified shape pattern.\n    It supports `int` number and `string` with formats like: \"32\", \"32 * n\", \"32 ** p\", \"32 ** p *n\".\n\n    Args:\n        shape: specified pattern for the spatial shape.\n        p: power factor to generate fake data shape if dim of expected shape is \"x**p\", default to 1.\n        p: multiply factor to generate fake data shape if dim of expected shape is \"x*n\", default to 1.\n        any: specified size to generate fake data shape if dim of expected shape is \"*\", default to 1.\n\n    \"\"\"\n    ret = []\n    for i in shape:\n        if isinstance(i, int):\n            ret.append(i)\n        elif 
isinstance(i, str):\n            if i == \"*\":\n                ret.append(any)\n            else:\n                for c in _get_var_names(i):\n                    if c not in [\"p\", \"n\"]:\n                        raise ValueError(f\"only support variables 'p' and 'n' so far, but got: {c}.\")\n                ret.append(eval(i, {\"p\": p, \"n\": n}))\n        else:\n            raise ValueError(f\"spatial shape items must be int or string, but got: {type(i)} {i}.\")\n    return tuple(ret)\n\n\ndef _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filename: str) -> str:\n    return f\"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}\"\n\n\ndef _get_ngc_bundle_url(model_name: str, version: str) -> str:\n    return f\"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/files\"\n\n\ndef _get_ngc_private_base_url(repo: str) -> str:\n    return f\"https://api.ngc.nvidia.com/v2/{repo}/models\"\n\n\ndef _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str:\n    return f\"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip\"\n\n\ndef _get_monaihosting_bundle_url(model_name: str, version: str) -> str:\n    return f\"{MONAI_HOSTING_BASE_URL}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip\"\n\n\ndef _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None:\n    repo_owner, repo_name, tag_name = repo.split(\"/\")\n    if \".zip\" not in filename:\n        filename += \".zip\"\n    url = _get_git_release_url(repo_owner, repo_name, tag_name=tag_name, filename=filename)\n    filepath = download_path / f\"{filename}\"\n    download_url(url=url, filepath=filepath, hash_val=None, progress=progress)\n    extractall(filepath=filepath, output_dir=download_path, has_base=True)\n\n\ndef _download_from_monaihosting(download_path: Path, filename: str, version: str, progress: bool) -> None:\n    url = 
_get_monaihosting_bundle_url(model_name=filename, version=version)\n    filepath = download_path / f\"{filename}_v{version}.zip\"\n    download_url(url=url, filepath=filepath, hash_val=None, progress=progress)\n    extractall(filepath=filepath, output_dir=download_path, has_base=True)\n\n\ndef _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None:\n    bundle_info = get_bundle_info(bundle_name=filename, version=version)\n    if not bundle_info:\n        raise ValueError(f\"Bundle info not found for {filename} v{version}.\")\n    url = bundle_info[\"browser_download_url\"]\n    filepath = download_path / f\"{filename}_v{version}.zip\"\n    download_url(url=url, filepath=filepath, hash_val=None, progress=progress)\n    extractall(filepath=filepath, output_dir=download_path, has_base=True)\n\n\ndef _add_ngc_prefix(name: str, prefix: str = \"monai_\") -> str:\n    if name.startswith(prefix):\n        return name\n    return f\"{prefix}{name}\"\n\n\ndef _remove_ngc_prefix(name: str, prefix: str = \"monai_\") -> str:\n    if name.startswith(prefix):\n        return name[len(prefix) :]\n    return name\n\n\ndef _get_all_download_files(request_url: str, headers: dict | None = None) -> list[str]:\n    if not has_requests:\n        raise ValueError(\"requests package is required, please install it.\")\n    headers = {} if headers is None else headers\n    response = requests.get(request_url, headers=headers)\n    response.raise_for_status()\n    model_info = json.loads(response.text)\n\n    if not isinstance(model_info, dict) or \"modelFiles\" not in model_info:\n        raise ValueError(\"The data is not a dictionary or it does not have the key 'modelFiles'.\")\n\n    model_files = model_info[\"modelFiles\"]\n    return [f[\"path\"] for f in model_files]\n\n\ndef _download_from_ngc(\n    download_path: Path,\n    filename: str,\n    version: str,\n    prefix: str = \"monai_\",\n    remove_prefix: str | None = \"monai_\",\n    
progress: bool = True,\n) -> None:\n    # ensure prefix is contained\n    filename = _add_ngc_prefix(filename, prefix=prefix)\n    url = _get_ngc_bundle_url(model_name=filename, version=version)\n    if remove_prefix:\n        filename = _remove_ngc_prefix(filename, prefix=remove_prefix)\n    filepath = download_path / filename\n    filepath.mkdir(parents=True, exist_ok=True)\n    for file in _get_all_download_files(url):\n        download_url(url=f\"{url}/{file}\", filepath=f\"{filepath}/{file}\", hash_val=None, progress=progress)\n\n\ndef _download_from_ngc_private(\n    download_path: Path,\n    filename: str,\n    version: str,\n    repo: str,\n    prefix: str = \"monai_\",\n    remove_prefix: str | None = \"monai_\",\n    headers: dict | None = None,\n) -> None:\n    # ensure prefix is contained\n    filename = _add_ngc_prefix(filename, prefix=prefix)\n    request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)\n    if has_requests:\n        headers = {} if headers is None else headers\n        response = requests.get(request_url, headers=headers)\n        response.raise_for_status()\n    else:\n        raise ValueError(\"NGC API requires requests package. 
Please install it.\")\n\n    os.makedirs(download_path, exist_ok=True)\n    zip_path = download_path / f\"{filename}_v{version}.zip\"\n    with open(zip_path, \"wb\") as f:\n        f.write(response.content)\n    logger.info(f\"Downloading: {zip_path}.\")\n    if remove_prefix:\n        filename = _remove_ngc_prefix(filename, prefix=remove_prefix)\n    extract_path = download_path / f\"{filename}\"\n    _extract_zip(zip_path, extract_path)\n    logger.info(f\"Writing into directory: {extract_path}.\")\n\n\ndef _get_ngc_token(api_key, retry=0):\n    \"\"\"Try to connect to NGC.\"\"\"\n    url = \"https://authn.nvidia.com/token?service=ngc\"\n    headers = {\"Accept\": \"application/json\", \"Authorization\": \"ApiKey \" + api_key}\n    if has_requests:\n        response = requests.get(url, headers=headers)\n        if not response.ok:\n            # retry 3 times, if failed, raise an error.\n            if retry < 3:\n                logger.info(f\"Retrying {retry} time(s) to GET {url}.\")\n                return _get_ngc_token(url, retry + 1)\n            raise RuntimeError(\"NGC API response is not ok. Failed to get token.\")\n        else:\n            token = response.json()[\"token\"]\n        return token\n\n\ndef _examine_monai_version(monai_version: str) -> tuple[bool, str]:\n    \"\"\"Examine if the package version is compatible with the MONAI version in the metadata.\"\"\"\n    version_dict = get_versions()\n    package_version = version_dict.get(\"version\", \"0+unknown\")\n    if package_version == \"0+unknown\":\n        return False, \"Package version is not available. Skipping version check.\"\n    if monai_version == \"0+unknown\":\n        return False, \"MONAI version is not specified in the bundle. 
Skipping version check.\"\n    # treat rc versions as the same as the release version\n    package_version = re.sub(r\"rc\\d.*\", \"\", package_version)\n    monai_version = re.sub(r\"rc\\d.*\", \"\", monai_version)\n    if package_version < monai_version:\n        return (\n            False,\n            f\"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}.\",\n        )\n    return True, \"\"\n\n\ndef _check_monai_version(bundle_dir: PathLike, name: str) -> None:\n    \"\"\"Get the `monai_version` from the metadata.json and compare if it is smaller than the installed `monai` package version\"\"\"\n    metadata_file = Path(bundle_dir) / name / \"configs\" / \"metadata.json\"\n    if not metadata_file.exists():\n        logger.warning(f\"metadata file not found in {metadata_file}.\")\n        return\n    with open(metadata_file) as f:\n        metadata = json.load(f)\n    is_compatible, msg = _examine_monai_version(metadata.get(\"monai_version\", \"0+unknown\"))\n    if not is_compatible:\n        logger.warning(msg)\n\n\ndef _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]:\n    \"\"\"\n    Extract the latest versions from the data dictionary.\n\n    Args:\n        data: the data dictionary.\n        max_versions: the maximum number of versions to return.\n\n    Returns:\n        versions of the latest models in the reverse order of creation date, e.g. 
['1.0.0', '0.9.0', '0.8.0'].\n    \"\"\"\n    # Check if the data is a dictionary and it has the key 'modelVersions'\n    if not isinstance(data, dict) or \"modelVersions\" not in data:\n        raise ValueError(\"The data is not a dictionary or it does not have the key 'modelVersions'.\")\n\n    # Extract the list of model versions\n    model_versions = data[\"modelVersions\"]\n\n    if (\n        not isinstance(model_versions, list)\n        or len(model_versions) == 0\n        or \"createdDate\" not in model_versions[0]\n        or \"versionId\" not in model_versions[0]\n    ):\n        raise ValueError(\n            \"The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'.\"\n        )\n\n    # Sort the versions by the 'createdDate' in descending order\n    sorted_versions = sorted(model_versions, key=lambda x: x[\"createdDate\"], reverse=True)\n    return [v[\"versionId\"] for v in sorted_versions[:max_versions]]\n\n\ndef _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str:\n    base_url = _get_ngc_private_base_url(repo) if repo else NGC_BASE_URL\n    version_endpoint = base_url + f\"/{name.lower()}/versions/\"\n\n    if not has_requests:\n        raise ValueError(\"requests package is required, please install it.\")\n\n    version_header = {\"Accept-Encoding\": \"gzip, deflate\"}  # Excluding 'zstd' to fit NGC requirements\n    if headers:\n        version_header.update(headers)\n    resp = requests.get(version_endpoint, headers=version_header)\n    resp.raise_for_status()\n    model_info = json.loads(resp.text)\n    latest_versions = _list_latest_versions(model_info)\n\n    for version in latest_versions:\n        file_endpoint = base_url + f\"/{name.lower()}/versions/{version}/files/configs/metadata.json\"\n        resp = requests.get(file_endpoint, headers=headers)\n        metadata = json.loads(resp.text)\n        resp.raise_for_status()\n        # if 
the package version is not available or the model is compatible with the package version\n        is_compatible, _ = _examine_monai_version(metadata[\"monai_version\"])\n        if is_compatible:\n            if version != latest_versions[0]:\n                logger.info(f\"Latest version is {latest_versions[0]}, but the compatible version is {version}.\")\n            return version\n\n    # if no compatible version is found, return the latest version\n    return latest_versions[0]\n\n\ndef _get_latest_bundle_version(\n    source: str, name: str, repo: str, **kwargs: Any\n) -> dict[str, list[str] | str] | Any | None:\n    if source == \"ngc\":\n        name = _add_ngc_prefix(name)\n        return _get_latest_bundle_version_ngc(name)\n    elif source == \"monaihosting\":\n        return get_bundle_versions(name, repo=\"Project-MONAI/model-zoo\", tag=\"dev\")[\"latest_version\"]\n    elif source == \"ngc_private\":\n        headers = kwargs.pop(\"headers\", {})\n        name = _add_ngc_prefix(name)\n        return _get_latest_bundle_version_ngc(name, repo=repo, headers=headers)\n    elif source == \"github\":\n        repo_owner, repo_name, tag_name = repo.split(\"/\")\n        return get_bundle_versions(name, repo=f\"{repo_owner}/{repo_name}\", tag=tag_name)[\"latest_version\"]\n    elif source == \"huggingface_hub\":\n        refs = huggingface_hub.list_repo_refs(repo_id=repo)\n        if len(refs.tags) > 0:\n            all_versions = [t.name for t in refs.tags]  # git tags, not to be confused with `tag`\n            latest_version = [\"latest_version\" if \"latest_version\" in all_versions else all_versions[-1]][0]\n        else:\n            latest_version = [b.name for b in refs.branches][0]  # use the branch that was last updated\n        return latest_version\n    else:\n        raise ValueError(\n            f\"To get the latest bundle version, source should be 'github', 'monaihosting' or 'ngc', got {source}.\"\n        )\n\n\ndef 
_process_bundle_dir(bundle_dir: PathLike | None = None) -> Path:\n    if bundle_dir is None:\n        get_dir, has_home = optional_import(\"torch.hub\", name=\"get_dir\")\n        if has_home:\n            bundle_dir = Path(get_dir()) / \"bundle\"\n        else:\n            raise ValueError(\"bundle_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?\")\n    return Path(bundle_dir)\n\n\ndef download(\n    name: str | None = None,\n    version: str | None = None,\n    bundle_dir: PathLike | None = None,\n    source: str = DEFAULT_DOWNLOAD_SOURCE,\n    repo: str | None = None,\n    url: str | None = None,\n    remove_prefix: str | None = \"monai_\",\n    progress: bool = True,\n    args_file: str | None = None,\n) -> None:\n    \"\"\"\n    download bundle from the specified source or url. The bundle should be a zip file and it\n    will be extracted after downloading.\n    This function refers to:\n    https://pytorch.org/docs/stable/_modules/torch/hub.html\n\n    Typical usage examples:\n\n    .. 
code-block:: bash\n\n        # Execute this module as a CLI entry, and download bundle from the model-zoo repo:\n        python -m monai.bundle download --name <bundle_name> --version \"0.1.0\" --bundle_dir \"./\"\n\n        # Execute this module as a CLI entry, and download bundle from specified github repo:\n        python -m monai.bundle download --name <bundle_name> --source \"github\" --repo \"repo_owner/repo_name/release_tag\"\n\n        # Execute this module as a CLI entry, and download bundle from ngc with latest version:\n        python -m monai.bundle download --name <bundle_name> --source \"ngc\" --bundle_dir \"./\"\n\n        # Execute this module as a CLI entry, and download bundle from monaihosting with latest version:\n        python -m monai.bundle download --name <bundle_name> --source \"monaihosting\" --bundle_dir \"./\"\n\n        # Execute this module as a CLI entry, and download bundle from Hugging Face Hub:\n        python -m monai.bundle download --name \"bundle_name\" --source \"huggingface_hub\" --repo \"repo_owner/repo_name\"\n\n        # Execute this module as a CLI entry, and download bundle via URL:\n        python -m monai.bundle download --name <bundle_name> --url <url>\n\n        # Execute this module as a CLI entry, and download bundle from ngc_private with latest version:\n        python -m monai.bundle download --name <bundle_name> --source \"ngc_private\" --bundle_dir \"./\" --repo \"org/org_name\"\n\n        # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.\n        # Other args still can override the default args at runtime.\n        # The content of the JSON / YAML file is a dictionary. For example:\n        # {\"name\": \"spleen\", \"bundle_dir\": \"download\", \"source\": \"\"}\n        # then do the following command for downloading:\n        python -m monai.bundle download --args_file \"args.json\" --source \"github\"\n\n    Args:\n        name: bundle name. 
If `None` and `url` is `None`, it must be provided in `args_file`.\n            for example:\n            \"spleen_ct_segmentation\", \"prostate_mri_anatomy\" in model-zoo:\n            https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.\n            \"monai_brats_mri_segmentation\" in ngc:\n            https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.\n        version: version name of the target bundle to download, like: \"0.1.0\". If `None`, will download\n            the latest version (or the last commit to the `main` branch in the case of Hugging Face Hub).\n        bundle_dir: target directory to store the downloaded data.\n            Default is `bundle` subfolder under `torch.hub.get_dir()`.\n        source: storage location name. This argument is used when `url` is `None`.\n            In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and\n            it should be \"ngc\", \"monaihosting\", \"github\", \"ngc_private\", or \"huggingface_hub\".\n            If source is \"ngc_private\", you need specify the NGC_API_KEY in the environment variable.\n        repo: repo name. This argument is used when `url` is `None` and `source` is \"github\" or \"huggingface_hub\".\n            If `source` is \"github\", it should be in the form of \"repo_owner/repo_name/release_tag\".\n            If `source` is \"huggingface_hub\", it should be in the form of \"repo_owner/repo_name\". Please note that\n            bundles for \"monaihosting\" source are also hosted on Hugging Face Hub, but the \"repo_id\" is always in the form\n            of \"MONAI/bundle_name\", therefore, this argument is not required for \"monaihosting\" source.\n            If `source` is \"ngc_private\", it should be in the form of \"org/org_name\" or \"org/org_name/team/team_name\",\n            or you can specify the environment variable NGC_ORG and NGC_TEAM.\n        url: url to download the data. 
If not `None`, data will be downloaded directly\n            and `source` will not be checked.\n            If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.\n        remove_prefix: This argument is used when `source` is \"ngc\" or \"ngc_private\". Currently, all ngc bundles\n            have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to\n            maintain the consistency between these two sources, remove prefix is necessary.\n            Therefore, if specified, downloaded folder name will remove the prefix.\n        progress: whether to display a progress bar.\n        args_file: a JSON or YAML file to provide default values for all the args in this function.\n            so that the command line inputs can be simplified.\n\n    \"\"\"\n    _args = update_kwargs(\n        args=args_file,\n        name=name,\n        version=version,\n        bundle_dir=bundle_dir,\n        source=source,\n        repo=repo,\n        url=url,\n        remove_prefix=remove_prefix,\n        progress=progress,\n    )\n\n    _log_input_summary(tag=\"download\", args=_args)\n    source_, progress_, remove_prefix_, repo_, name_, version_, bundle_dir_, url_ = _pop_args(\n        _args, \"source\", \"progress\", remove_prefix=None, repo=None, name=None, version=None, bundle_dir=None, url=None\n    )\n\n    bundle_dir_ = _process_bundle_dir(bundle_dir_)\n    if repo_ is None:\n        org_ = os.getenv(\"NGC_ORG\", None)\n        team_ = os.getenv(\"NGC_TEAM\", None)\n        if org_ is not None and source_ == \"ngc_private\":\n            repo_ = f\"org/{org_}/team/{team_}\" if team_ is not None else f\"org/{org_}\"\n        else:\n            repo_ = \"Project-MONAI/model-zoo/hosting_storage_v1\"\n    if len(repo_.split(\"/\")) not in (2, 4) and source_ == \"ngc_private\":\n        raise ValueError(f\"repo should be in the form of `org/org_name/team/team_name` or `org/org_name`, got {repo_}.\")\n    if 
len(repo_.split(\"/\")) != 3 and source_ == \"github\":\n        raise ValueError(f\"repo should be in the form of `repo_owner/repo_name/release_tag`, got {repo_}.\")\n    elif len(repo_.split(\"/\")) != 2 and source_ == \"huggingface_hub\":\n        raise ValueError(f\"Hugging Face Hub repo should be in the form of `repo_owner/repo_name`, got {repo_}.\")\n    if url_ is not None:\n        if name_ is not None:\n            filepath = bundle_dir_ / f\"{name_}.zip\"\n        else:\n            filepath = bundle_dir_ / f\"{_basename(url_)}\"\n        download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)\n        extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)\n    else:\n        headers = {}\n        if name_ is None:\n            raise ValueError(f\"To download from source: {source_}, `name` must be provided.\")\n        if source == \"ngc_private\":\n            api_key = os.getenv(\"NGC_API_KEY\", None)\n            if api_key is None:\n                raise ValueError(\"API key is required for ngc_private source.\")\n            else:\n                token = _get_ngc_token(api_key)\n                headers = {\"Authorization\": f\"Bearer {token}\"}\n\n        if version_ is None:\n            version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers)\n        if source_ == \"github\":\n            name_ver = \"_v\".join([name_, version_]) if version_ is not None else name_\n            _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)\n        elif source_ == \"monaihosting\":\n            try:\n                extract_path = os.path.join(bundle_dir_, name_)\n                huggingface_hub.snapshot_download(repo_id=f\"MONAI/{name_}\", revision=version_, local_dir=extract_path)\n            except (huggingface_hub.errors.RevisionNotFoundError, huggingface_hub.errors.RepositoryNotFoundError):\n                # if bundle or version not 
found from huggingface, download from ngc monaihosting\n                _download_from_monaihosting(\n                    download_path=bundle_dir_, filename=name_, version=version_, progress=progress_\n                )\n            except urllib.error.HTTPError:\n                # if also cannot download from ngc monaihosting, download according to bundle_info\n                _download_from_bundle_info(\n                    download_path=bundle_dir_, filename=name_, version=version_, progress=progress_\n                )\n\n        elif source_ == \"ngc\":\n            _download_from_ngc(\n                download_path=bundle_dir_,\n                filename=name_,\n                version=version_,\n                remove_prefix=remove_prefix_,\n                progress=progress_,\n            )\n        elif source_ == \"ngc_private\":\n            _download_from_ngc_private(\n                download_path=bundle_dir_,\n                filename=name_,\n                version=version_,\n                remove_prefix=remove_prefix_,\n                repo=repo_,\n                headers=headers,\n            )\n        elif source_ == \"huggingface_hub\":\n            extract_path = os.path.join(bundle_dir_, name_)\n            huggingface_hub.snapshot_download(repo_id=repo_, revision=version_, local_dir=extract_path)\n        else:\n            raise NotImplementedError(\n                \"Currently only download from `url`, source 'github', 'monaihosting', 'huggingface_hub' or 'ngc' are implemented,\"\n                f\"got source: {source_}.\"\n            )\n\n    _check_monai_version(bundle_dir_, name_)\n\n\ndef load(\n    name: str,\n    model: torch.nn.Module | None = None,\n    version: str | None = None,\n    workflow_type: str = \"train\",\n    model_file: str | None = None,\n    load_ts_module: bool = False,\n    bundle_dir: PathLike | None = None,\n    source: str = DEFAULT_DOWNLOAD_SOURCE,\n    repo: str | None = None,\n    remove_prefix: str | None 
= \"monai_\",\n    progress: bool = True,\n    device: str | None = None,\n    key_in_ckpt: str | None = None,\n    config_files: Sequence[str] = (),\n    workflow_name: str | BundleWorkflow | None = None,\n    args_file: str | None = None,\n    copy_model_args: dict | None = None,\n    net_override: dict | None = None,\n) -> object | tuple[torch.nn.Module, dict, dict] | Any:\n    \"\"\"\n    Load model weights or TorchScript module of a bundle.\n\n    Args:\n        name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.\n            for example:\n            \"spleen_ct_segmentation\", \"prostate_mri_anatomy\" in model-zoo:\n            https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.\n            \"monai_brats_mri_segmentation\" in ngc:\n            https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.\n            \"mednist_gan\" in monaihosting:\n            https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/mednist_gan/versions/0.2.0/files/mednist_gan_v0.2.0.zip\n        model: a pytorch module to be updated. Default to None, using the \"network_def\" in the bundle.\n        version: version name of the target bundle to download, like: \"0.1.0\". If `None`, will download\n            the latest version. 
If `source` is \"huggingface_hub\", this argument is a Git revision id.\n        workflow_type: specifies the workflow type: \"train\" or \"training\" for a training workflow,\n            or \"infer\", \"inference\", \"eval\", \"evaluation\" for a inference workflow,\n            other unsupported string will raise a ValueError.\n            default to `train` for training workflow.\n        model_file: the relative path of the model weights or TorchScript module within bundle.\n            If `None`, \"models/model.pt\" or \"models/model.ts\" will be used.\n        load_ts_module: a flag to specify if loading the TorchScript module.\n        bundle_dir: directory the weights/TorchScript module will be loaded from.\n            Default is `bundle` subfolder under `torch.hub.get_dir()`.\n        source: storage location name. This argument is used when `model_file` is not existing locally and need to be\n            downloaded first.\n            In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and\n            it should be \"ngc\", \"monaihosting\", \"github\", or \"huggingface_hub\".\n        repo: repo name. This argument is used when `url` is `None` and `source` is \"github\" or \"huggingface_hub\".\n            If `source` is \"github\", it should be in the form of \"repo_owner/repo_name/release_tag\".\n            If `source` is \"huggingface_hub\", it should be in the form of \"repo_owner/repo_name\".\n        remove_prefix: This argument is used when `source` is \"ngc\". Currently, all ngc bundles\n            have the ``monai_`` prefix, which is not existing in their model zoo contrasts. 
In order to\n            maintain the consistency between these three sources, remove prefix is necessary.\n            Therefore, if specified, downloaded folder name will remove the prefix.\n        progress: whether to display a progress bar when downloading.\n        device: target device of returned weights or module, if `None`, prefer to \"cuda\" if existing.\n        key_in_ckpt: for nested checkpoint like `{\"model\": XXX, \"optimizer\": XXX, ...}`, specify the key of model\n            weights. if not nested checkpoint, no need to set.\n        config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module,\n            see `_extra_files` in `torch.jit.load` for more details.\n        workflow_name: specified bundle workflow name, should be a string or class, default to \"ConfigWorkflow\".\n        args_file: a JSON or YAML file to provide default values for all the args in \"download\" function.\n        copy_model_args: other arguments for the `monai.networks.copy_model_state` function.\n        net_override: id-value pairs to override the parameters in the network of the bundle, default to `None`.\n\n    Returns:\n        1. If `load_ts_module` is `False` and `model` is `None`,\n            return model weights if can't find \"network_def\" in the bundle,\n            else return an instantiated network that loaded the weights.\n        2. If `load_ts_module` is `False` and `model` is not `None`,\n            return an instantiated network that loaded the weights.\n        3. 
If `load_ts_module` is `True`, return a triple that include a TorchScript module,\n            the corresponding metadata dict, and extra files dict.\n            please check `monai.data.load_net_with_metadata` for more details.\n\n    \"\"\"\n    bundle_dir_ = _process_bundle_dir(bundle_dir)\n    net_override = {} if net_override is None else net_override\n    copy_model_args = {} if copy_model_args is None else copy_model_args\n\n    if device is None:\n        device = \"cuda:0\" if is_available() else \"cpu\"\n    if model_file is None:\n        model_file = os.path.join(\"models\", \"model.ts\" if load_ts_module is True else \"model.pt\")\n    if source == \"ngc\":\n        name = _add_ngc_prefix(name)\n        if remove_prefix:\n            name = _remove_ngc_prefix(name, prefix=remove_prefix)\n    full_path = os.path.join(bundle_dir_, name, model_file)\n    if not os.path.exists(full_path):\n        download(\n            name=name,\n            version=version,\n            bundle_dir=bundle_dir_,\n            source=source,\n            repo=repo,\n            remove_prefix=remove_prefix,\n            progress=progress,\n            args_file=args_file,\n        )\n\n    # loading with `torch.jit.load`\n    if load_ts_module is True:\n        return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)\n    # loading with `torch.load`\n    model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)\n\n    if not isinstance(model_dict, Mapping):\n        warnings.warn(f\"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.\")\n        model_dict = get_state_dict(model_dict)\n\n    _workflow = None\n    if model is None:\n        bundle_config_file = bundle_dir_ / name / \"configs\" / f\"{workflow_type}.json\"\n        if bundle_config_file.is_file():\n            _net_override = {f\"network_def#{key}\": value for key, value in 
net_override.items()}\n            _workflow = create_workflow(\n                workflow_name=workflow_name,\n                args_file=args_file,\n                config_file=str(bundle_config_file),\n                workflow_type=workflow_type,\n                **_net_override,\n            )\n        else:\n            warnings.warn(f\"Cannot find the config file: {bundle_config_file}, return state dict instead.\")\n            return model_dict\n        if _workflow is not None:\n            if not hasattr(_workflow, \"network_def\"):\n                warnings.warn(\"No available network definition in the bundle, return state dict instead.\")\n                return model_dict\n            else:\n                model = _workflow.network_def\n\n    model.to(device)  # type: ignore\n\n    copy_model_state(\n        dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args  # type: ignore\n    )\n\n    return model\n\n\ndef _get_all_bundles_info(\n    repo: str = \"Project-MONAI/model-zoo\", tag: str = \"dev\", auth_token: str | None = None\n) -> dict[str, dict[str, dict[str, Any]]]:\n    if has_requests:\n        if tag == \"hosting_storage_v1\":\n            request_url = f\"https://api.github.com/repos/{repo}/releases\"\n        else:\n            request_url = f\"https://raw.githubusercontent.com/{repo}/{tag}/models/model_info.json\"\n\n        if auth_token is not None:\n            headers = {\"Authorization\": f\"Bearer {auth_token}\"}\n            resp = requests.get(request_url, headers=headers)\n        else:\n            resp = requests.get(request_url)\n        resp.raise_for_status()\n    else:\n        raise ValueError(\"requests package is required, please install it.\")\n    releases_list = json.loads(resp.text)\n    bundle_name_pattern = re.compile(r\"_v\\d*.\")\n    bundles_info: dict[str, dict[str, dict[str, Any]]] = {}\n\n    if tag == \"hosting_storage_v1\":\n        for release in releases_list:\n      
      if release[\"tag_name\"] == tag:\n                for asset in release[\"assets\"]:\n                    asset_name = bundle_name_pattern.split(asset[\"name\"])[0]\n                    if asset_name not in bundles_info:\n                        bundles_info[asset_name] = {}\n                    asset_version = asset[\"name\"].split(f\"{asset_name}_v\")[-1].replace(\".zip\", \"\")\n                    bundles_info[asset_name][asset_version] = dict(asset)\n                return bundles_info\n    else:\n        for asset in releases_list.keys():\n            asset_name = bundle_name_pattern.split(asset)[0]\n            if asset_name not in bundles_info:\n                bundles_info[asset_name] = {}\n            asset_version = asset.split(f\"{asset_name}_v\")[-1]\n            bundles_info[asset_name][asset_version] = {\n                \"name\": asset,\n                \"browser_download_url\": releases_list[asset][\"source\"],\n            }\n    return bundles_info\n\n\ndef get_all_bundles_list(\n    repo: str = \"Project-MONAI/model-zoo\", tag: str = \"dev\", auth_token: str | None = None\n) -> list[tuple[str, str]]:\n    \"\"\"\n    Get all bundles names (and the latest versions) that are stored in the release of specified repository\n    with the provided tag. If tag is \"dev\", will get model information from\n    https://raw.githubusercontent.com/repo_owner/repo_name/dev/models/model_info.json.\n    The default values of arguments correspond to the release of MONAI model zoo. 
In order to increase the\n    rate limits of calling Github APIs, you can input your personal access token.\n    Please check the following link for more details about rate limiting:\n    https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting\n\n    The following link shows how to create your personal access token:\n    https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token\n\n    Args:\n        repo: it should be in the form of \"repo_owner/repo_name/\".\n        tag: the tag name of the release.\n        auth_token: github personal access token.\n\n    Returns:\n        a list of tuple in the form of (bundle name, latest version).\n\n    \"\"\"\n\n    bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token)\n    bundles_list = []\n    for bundle_name in bundles_info:\n        latest_version = sorted(bundles_info[bundle_name].keys())[-1]\n        bundles_list.append((bundle_name, latest_version))\n\n    return bundles_list\n\n\ndef get_bundle_versions(\n    bundle_name: str, repo: str = \"Project-MONAI/model-zoo\", tag: str = \"dev\", auth_token: str | None = None\n) -> dict[str, list[str] | str]:\n    \"\"\"\n    Get the latest version, as well as all existing versions of a bundle that is stored in the release of specified\n    repository with the provided tag. 
If tag is \"dev\", will get model information from\n    https://raw.githubusercontent.com/repo_owner/repo_name/dev/models/model_info.json.\n    In order to increase the rate limits of calling Github APIs, you can input your personal access token.\n    Please check the following link for more details about rate limiting:\n    https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting\n\n    The following link shows how to create your personal access token:\n    https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token\n\n    Args:\n        bundle_name: bundle name.\n        repo: it should be in the form of \"repo_owner/repo_name/\".\n        tag: the tag name of the release.\n        auth_token: github personal access token.\n\n    Returns:\n        a dictionary that contains the latest version and all versions of a bundle.\n\n    \"\"\"\n\n    bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token)\n    if bundle_name not in bundles_info:\n        raise ValueError(f\"bundle: {bundle_name} is not existing in repo: {repo}.\")\n    bundle_info = bundles_info[bundle_name]\n    all_versions = sorted(bundle_info.keys())\n\n    return {\"latest_version\": all_versions[-1], \"all_versions\": all_versions}\n\n\ndef get_bundle_info(\n    bundle_name: str,\n    version: str | None = None,\n    repo: str = \"Project-MONAI/model-zoo\",\n    tag: str = \"dev\",\n    auth_token: str | None = None,\n) -> dict[str, Any]:\n    \"\"\"\n    Get all information (include \"name\" and \"browser_download_url\") of a bundle\n    with the specified bundle name and version which is stored in the release of specified repository with the provided tag.\n    In order to increase the rate limits of calling Github APIs, you can input your personal access token.\n    Please check the following link for more details about rate limiting:\n    
https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting\n\n    The following link shows how to create your personal access token:\n    https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token\n\n    Args:\n        bundle_name: bundle name.\n        version: version name of the target bundle, if None, the latest version will be used.\n        repo: it should be in the form of \"repo_owner/repo_name/\".\n        tag: the tag name of the release.\n        auth_token: github personal access token.\n\n    Returns:\n        a dictionary that contains the bundle's information.\n\n    \"\"\"\n\n    bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token)\n    if bundle_name not in bundles_info:\n        raise ValueError(f\"bundle: {bundle_name} is not existing.\")\n    bundle_info = bundles_info[bundle_name]\n    if version is None:\n        version = sorted(bundle_info.keys())[-1]\n    if version not in bundle_info:\n        raise ValueError(f\"version: {version} of bundle: {bundle_name} is not existing.\")\n\n    return bundle_info[version]\n\n\ndef run(\n    run_id: str | None = None,\n    init_id: str | None = None,\n    final_id: str | None = None,\n    meta_file: str | Sequence[str] | None = None,\n    config_file: str | Sequence[str] | None = None,\n    logging_file: str | None = None,\n    tracking: str | dict | None = None,\n    args_file: str | None = None,\n    **override: Any,\n) -> None:\n    \"\"\"\n    Specify `config_file` to run monai bundle components and workflows.\n\n    Typical usage examples:\n\n    .. 
code-block:: bash\n\n        # Execute this module as a CLI entry:\n        python -m monai.bundle run --meta_file <meta path> --config_file <config path>\n\n        # Execute with specified `run_id=training`:\n        python -m monai.bundle run training --meta_file <meta path> --config_file <config path>\n\n        # Execute with all specified `run_id=runtest`, `init_id=inittest`, `final_id=finaltest`:\n        python -m monai.bundle run --run_id runtest --init_id inittest --final_id finaltest ...\n\n        # Override config values at runtime by specifying the component id and its new value:\n        python -m monai.bundle run --net#input_chns 1 ...\n\n        # Override config values with another config file `/path/to/another.json`:\n        python -m monai.bundle run --net %/path/to/another.json ...\n\n        # Override config values with part content of another config file:\n        python -m monai.bundle run --net %/data/other.json#net_arg ...\n\n        # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.\n        # Other args still can override the default args at runtime:\n        python -m monai.bundle run --args_file \"/workspace/data/args.json\" --config_file <config path>\n\n    Args:\n        run_id: ID name of the expected config expression to run, default to \"run\".\n            to run the config, the target config must contain this ID.\n        init_id: ID name of the expected config expression to initialize before running, default to \"initialize\".\n            it's optional for both configs and this `run` function.\n        final_id: ID name of the expected config expression to finalize after running, default to \"finalize\".\n            it's optional for both configs and this `run` function.\n        meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.\n            Default to None.\n        config_file: filepath of the config file, if `None`, 
must be provided in `args_file`.\n            if it is a list of file paths, the content of them will be merged.\n        logging_file: config file for `logging` module in the program. for more details:\n            https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.\n            Default to None.\n        tracking: if not None, enable the experiment tracking at runtime with optionally configurable and extensible.\n            If \"mlflow\", will add `MLFlowHandler` to the parsed bundle with default tracking settings where a set of\n            common parameters shown below will be added and can be passed through the `override` parameter of this method.\n\n            - ``\"output_dir\"``: the path to save mlflow tracking outputs locally, default to \"<bundle root>/eval\".\n            - ``\"tracking_uri\"``: uri to save mlflow tracking outputs, default to \"/output_dir/mlruns\".\n            - ``\"experiment_name\"``: experiment name for this run, default to \"monai_experiment\".\n            - ``\"run_name\"``: the name of current run.\n            - ``\"save_execute_config\"``: whether to save the executed config files. It can be `False`, `/path/to/artifacts`\n              or `True`. If set to `True`, will save to the default path \"<bundle_root>/eval\". Default to `True`.\n\n            If other string, treat it as file path to load the tracking settings.\n            If `dict`, treat it as tracking settings.\n            Will patch the target config content with `tracking handlers` and the top-level items of `configs`.\n            for detailed usage examples, please check the tutorial:\n            https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb.\n        args_file: a JSON or YAML file to provide default values for `run_id`, `meta_file`,\n            `config_file`, `logging`, and override pairs. 
so that the command line inputs can be simplified.\n        override: id-value pairs to override or add the corresponding config content.\n            e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``.\n\n    \"\"\"\n\n    workflow = create_workflow(\n        config_file=config_file,\n        args_file=args_file,\n        meta_file=meta_file,\n        logging_file=logging_file,\n        init_id=init_id,\n        run_id=run_id,\n        final_id=final_id,\n        tracking=tracking,\n        **override,\n    )\n    workflow.run()\n    workflow.finalize()\n\n\ndef run_workflow(\n    workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any\n) -> None:\n    \"\"\"\n    Specify `bundle workflow` to run monai bundle components and workflows.\n    The workflow should be subclass of `BundleWorkflow` and be available to import.\n    It can be MONAI existing bundle workflows or user customized workflows.\n\n    Typical usage examples:\n\n    .. 
code-block:: bash\n\n        # Execute this module as a CLI entry with default ConfigWorkflow:\n        python -m monai.bundle run_workflow --meta_file <meta path> --config_file <config path>\n\n        # Set the workflow to other customized BundleWorkflow subclass:\n        python -m monai.bundle run_workflow --workflow_name CustomizedWorkflow ...\n\n    Args:\n        workflow_name: specified bundle workflow name, should be a string or class, default to \"ConfigWorkflow\".\n        args_file: a JSON or YAML file to provide default values for this API.\n            so that the command line inputs can be simplified.\n        kwargs: arguments to instantiate the workflow class.\n\n    \"\"\"\n\n    workflow_ = create_workflow(workflow_name=workflow_name, args_file=args_file, **kwargs)\n    workflow_.run()\n    workflow_.finalize()\n\n\ndef verify_metadata(\n    meta_file: str | Sequence[str] | None = None,\n    filepath: PathLike | None = None,\n    create_dir: bool | None = None,\n    hash_val: str | None = None,\n    hash_type: str | None = None,\n    args_file: str | None = None,\n    **kwargs: Any,\n) -> None:\n    \"\"\"\n    Verify the provided `metadata` file based on the predefined `schema`.\n    `metadata` content must contain the `schema` field for the URL of schema file to download.\n    The schema standard follows: http://json-schema.org/.\n\n    Args:\n        meta_file: filepath of the metadata file to verify, if `None`, must be provided in `args_file`.\n            if it is a list of file paths, the content of them will be merged.\n        filepath: file path to store the downloaded schema.\n        create_dir: whether to create directories if not existing, default to `True`.\n        hash_val: if not None, define the hash value to verify the downloaded schema file.\n        hash_type: if not None, define the hash type to verify the downloaded schema file. 
Defaults to \"md5\".\n        args_file: a JSON or YAML file to provide default values for all the args in this function.\n            so that the command line inputs can be simplified.\n        kwargs: other arguments for `jsonschema.validate()`. for more details:\n            https://python-jsonschema.readthedocs.io/en/stable/validate/#jsonschema.validate.\n\n    \"\"\"\n\n    _args = update_kwargs(\n        args=args_file,\n        meta_file=meta_file,\n        filepath=filepath,\n        create_dir=create_dir,\n        hash_val=hash_val,\n        hash_type=hash_type,\n        **kwargs,\n    )\n    _log_input_summary(tag=\"verify_metadata\", args=_args)\n    filepath_, meta_file_, create_dir_, hash_val_, hash_type_ = _pop_args(\n        _args, \"filepath\", \"meta_file\", create_dir=True, hash_val=None, hash_type=\"md5\"\n    )\n\n    check_parent_dir(path=filepath_, create_dir=create_dir_)\n    metadata = ConfigParser.load_config_files(files=meta_file_)\n    url = metadata.get(\"schema\")\n    if url is None:\n        raise ValueError(\"must provide the `schema` field in the metadata for the URL of schema file.\")\n    download_url(url=url, filepath=filepath_, hash_val=hash_val_, hash_type=hash_type_, progress=True)\n    schema = ConfigParser.load_config_file(filepath=filepath_)\n\n    try:\n        # the rest key-values in the _args are for `validate` API\n        validate(instance=metadata, schema=schema, **_args)\n    except ValidationError as e:  # pylint: disable=E0712\n        # as the error message is very long, only extract the key information\n        raise ValueError(\n            re.compile(r\".*Failed validating\", re.S).findall(str(e))[0] + f\" against schema `{url}`.\"\n        ) from e\n    logger.info(\"metadata is verified with no error.\")\n\n\ndef _get_net_io_info(parser: ConfigParser | None = None, prefix: str = \"_meta_#network_data_format\") -> tuple:\n    \"\"\"\n    Get the input and output information defined in the metadata.\n\n    
Args:\n        parser: a ConfigParser of the given bundle.\n        prefix: a prefix for the input and output ID, which will be combined as `prefix#inputs` and\n            `prefix#outputs` to parse the input and output information in the `metadata.json` file of\n            a bundle, default to `meta_#network_data_format`.\n\n    Returns:\n        input_channels: the channel number of the `image` input.\n        input_spatial_shape: the spatial shape of the `image` input.\n        input_dtype: the data type of the `image` input.\n        output_channels: the channel number of the output.\n        output_dtype: the data type of the output.\n    \"\"\"\n    if not isinstance(parser, ConfigParser):\n        raise AttributeError(f\"Parameter parser should be a ConfigParser, got {type(parser)}.\")\n\n    prefix_key = f\"{prefix}#inputs\"\n    key = f\"{prefix_key}#image#num_channels\"\n    input_channels = parser.get(key)\n    key = f\"{prefix_key}#image#spatial_shape\"\n    input_spatial_shape = tuple(parser.get(key))\n    key = f\"{prefix_key}#image#dtype\"\n    input_dtype = get_equivalent_dtype(parser.get(key), torch.Tensor)\n\n    prefix_key = f\"{prefix}#outputs\"\n    key = f\"{prefix_key}#pred#num_channels\"\n    output_channels = parser.get(key)\n    key = f\"{prefix_key}#pred#dtype\"\n    output_dtype = get_equivalent_dtype(parser.get(key), torch.Tensor)\n\n    return input_channels, input_spatial_shape, input_dtype, output_channels, output_dtype\n\n\ndef _get_fake_input_shape(parser: ConfigParser) -> tuple:\n    \"\"\"\n    Get a fake input shape e.g. 
[N, C, H, W] or [N, C, H, W, D], whose batch size is 1, from the given parser.\n\n    Args:\n        parser: a ConfigParser which contains the i/o information of a bundle.\n    \"\"\"\n    input_channels, input_spatial_shape, _, _, _ = _get_net_io_info(parser=parser)\n    spatial_shape = _get_fake_spatial_shape(input_spatial_shape)\n    input_shape = (1, input_channels, *spatial_shape)\n    return input_shape\n\n\ndef verify_net_in_out(\n    net_id: str | None = None,\n    meta_file: str | Sequence[str] | None = None,\n    config_file: str | Sequence[str] | None = None,\n    device: str | None = None,\n    p: int | None = None,\n    n: int | None = None,\n    any: int | None = None,\n    extra_forward_args: dict | None = None,\n    args_file: str | None = None,\n    **override: Any,\n) -> None:\n    \"\"\"\n    Verify the input and output data shape and data type of network defined in the metadata.\n    Will test with fake Tensor data according to the required data shape in `metadata`.\n\n    Typical usage examples:\n\n    .. 
code-block:: bash\n\n        python -m monai.bundle verify_net_in_out network --meta_file <meta path> --config_file <config path>\n\n    Args:\n        net_id: ID name of the network component to verify, it must be `torch.nn.Module`.\n        meta_file: filepath of the metadata file to get network args, if `None`, must be provided in `args_file`.\n            if it is a list of file paths, the content of them will be merged.\n        config_file: filepath of the config file to get network definition, if `None`, must be provided in `args_file`.\n            if it is a list of file paths, the content of them will be merged.\n        device: target device to run the network forward computation, if None, prefer to \"cuda\" if existing.\n        p: power factor to generate fake data shape if dim of expected shape is \"x**p\", default to 1.\n        n: multiply factor to generate fake data shape if dim of expected shape is \"x*n\", default to 1.\n        any: specified size to generate fake data shape if dim of expected shape is \"*\", default to 1.\n        extra_forward_args: a dictionary that contains other args for the forward function of the network.\n            Default to an empty dictionary.\n        args_file: a JSON or YAML file to provide default values for `net_id`, `meta_file`, `config_file`,\n            `device`, `p`, `n`, `any`, and override pairs. so that the command line inputs can be simplified.\n        override: id-value pairs to override or add the corresponding config content.\n            e.g. 
``--_meta#network_data_format#inputs#image#num_channels 3``.\n\n    \"\"\"\n\n    _args = update_kwargs(\n        args=args_file,\n        net_id=net_id,\n        meta_file=meta_file,\n        config_file=config_file,\n        device=device,\n        p=p,\n        n=n,\n        any=any,\n        extra_forward_args=extra_forward_args,\n        **override,\n    )\n    _log_input_summary(tag=\"verify_net_in_out\", args=_args)\n    config_file_, meta_file_, net_id_, device_, p_, n_, any_, extra_forward_args_ = _pop_args(\n        _args,\n        \"config_file\",\n        \"meta_file\",\n        net_id=\"\",\n        device=\"cuda:0\" if is_available() else \"cpu\",\n        p=1,\n        n=1,\n        any=1,\n        extra_forward_args={},\n    )\n\n    parser = ConfigParser()\n    parser.read_config(f=config_file_)\n    parser.read_meta(f=meta_file_)\n\n    # the rest key-values in the _args are to override config content\n    for k, v in _args.items():\n        parser[k] = v\n\n    input_channels, input_spatial_shape, input_dtype, output_channels, output_dtype = _get_net_io_info(parser=parser)\n    try:\n        key: str = net_id_  # mark the full id when KeyError\n        net = parser.get_parsed_content(key).to(device_)\n    except KeyError as e:\n        raise KeyError(f\"Failed to verify due to missing expected key in the config: {key}.\") from e\n\n    net.eval()\n    with torch.no_grad():\n        spatial_shape = _get_fake_spatial_shape(input_spatial_shape, p=p_, n=n_, any=any_)\n        test_data = torch.rand(*(1, input_channels, *spatial_shape), dtype=input_dtype, device=device_)\n        if input_dtype == torch.float16:\n            # fp16 can only be executed in gpu mode\n            net.to(\"cuda\")\n\n            with torch.autocast(\"cuda\"):\n                output = net(test_data.cuda(), **extra_forward_args_)\n            net.to(device_)\n        else:\n            output = net(test_data, **extra_forward_args_)\n        if output.shape[1] != 
output_channels:\n            raise ValueError(f\"output channel number `{output.shape[1]}` doesn't match: `{output_channels}`.\")\n        if output.dtype != output_dtype:\n            raise ValueError(f\"dtype of output data `{output.dtype}` doesn't match: {output_dtype}.\")\n    logger.info(\"data shape of network is verified with no error.\")\n\n\ndef _export(\n    converter: Callable,\n    saver: Callable,\n    parser: ConfigParser,\n    net_id: str,\n    filepath: str,\n    ckpt_file: str,\n    config_file: str,\n    key_in_ckpt: str,\n    **kwargs: Any,\n) -> None:\n    \"\"\"\n    Export a model defined in the parser to a new one specified by the converter.\n\n    Args:\n        converter: a callable object that takes a torch.nn.module and kwargs as input and\n            converts the module to another type.\n        saver: a callable object that accepts the converted model to save, a filepath to save to, meta values\n            (extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input.\n        parser: a ConfigParser of the bundle to be converted.\n        net_id: ID name of the network component in the parser, it must be `torch.nn.Module`.\n        filepath: filepath to export, if filename has no extension, it becomes `.ts`.\n        ckpt_file: filepath of the model checkpoint to load.\n        config_file: filepath of the config file to save in the converted model,the saved key in the converted\n            model is the config filename without extension, and the saved config value is always serialized in\n            JSON format no matter the original file format is JSON or YAML. it can be a single file or a list\n            of files.\n        key_in_ckpt: for nested checkpoint like `{\"model\": XXX, \"optimizer\": XXX, ...}`, specify the key of model\n            weights. 
if not nested checkpoint, no need to set.\n        kwargs: key arguments for the converter.\n\n    \"\"\"\n    net = parser.get_parsed_content(net_id)\n    if has_ignite:\n        # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver\n        Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file)\n    else:\n        ckpt = torch.load(ckpt_file, weights_only=True)\n        copy_model_state(dst=net, src=ckpt if key_in_ckpt == \"\" else ckpt[key_in_ckpt])\n\n    # Use the given converter to convert a model and save with metadata, config content\n    net = converter(model=net, **kwargs)\n\n    extra_files: dict = {}\n    for i in ensure_tuple(config_file):\n        # split the filename and directory\n        filename = os.path.basename(i)\n        # remove extension\n        filename, _ = os.path.splitext(filename)\n        # because all files are stored as JSON their name parts without extension must be unique\n        if filename in extra_files:\n            raise ValueError(f\"Filename part '{filename}' is given multiple times in config file list.\")\n        # the file may be JSON or YAML but will get loaded and dumped out again as JSON\n        extra_files[filename] = json.dumps(ConfigParser.load_config_file(i)).encode()\n\n    # add .json extension to all extra files which are always encoded as JSON\n    extra_files = {k + \".json\": v for k, v in extra_files.items()}\n\n    meta_values = parser.get().pop(\"_meta_\", None)\n    saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files)\n\n    logger.info(f\"exported to file: {filepath}.\")\n\n\ndef onnx_export(\n    net_id: str | None = None,\n    filepath: PathLike | None = None,\n    ckpt_file: str | None = None,\n    meta_file: str | Sequence[str] | None = None,\n    config_file: str | Sequence[str] | None = None,\n    key_in_ckpt: str | None = None,\n    use_trace: bool | None = None,\n    input_shape: Sequence[int] | None = 
None,\n    args_file: str | None = None,\n    converter_kwargs: Mapping | None = None,\n    **override: Any,\n) -> None:\n    \"\"\"\n    Export the model checkpoint to an onnx model.\n\n    Typical usage examples:\n\n    .. code-block:: bash\n\n        python -m monai.bundle onnx_export network --filepath <export path> --ckpt_file <checkpoint path> ...\n\n    Args:\n        net_id: ID name of the network component in the config, it must be `torch.nn.Module`.\n        filepath: filepath where the onnx model is saved to.\n        ckpt_file: filepath of the model checkpoint to load.\n        meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.\n        config_file: filepath of the config file that contains extract network information,\n        key_in_ckpt: for nested checkpoint like `{\"model\": XXX, \"optimizer\": XXX, ...}`, specify the key of model\n            weights. if not nested checkpoint, no need to set.\n        use_trace: whether using `torch.jit.trace` to convert the pytorch model to torchscript model.\n        input_shape: a shape used to generate the random input of the network, when converting the model to an\n            onnx model. Should be a list like [N, C, H, W] or [N, C, H, W, D]. If not given, will try to parse from\n            the `metadata` config.\n        args_file: a JSON or YAML file to provide default values for all the parameters of this function, so that\n            the command line inputs can be simplified.\n        converter_kwargs: extra arguments that are needed by `convert_to_onnx`, except ones that already exist in the\n            input parameters.\n        override: id-value pairs to override or add the corresponding config content.\n            e.g. 
``--_meta#network_data_format#inputs#image#num_channels 3``.\n\n    \"\"\"\n    _args = update_kwargs(\n        args=args_file,\n        net_id=net_id,\n        filepath=filepath,\n        meta_file=meta_file,\n        config_file=config_file,\n        ckpt_file=ckpt_file,\n        key_in_ckpt=key_in_ckpt,\n        use_trace=use_trace,\n        input_shape=input_shape,\n        converter_kwargs=converter_kwargs,\n        **override,\n    )\n    _log_input_summary(tag=\"onnx_export\", args=_args)\n    (\n        filepath_,\n        ckpt_file_,\n        config_file_,\n        net_id_,\n        meta_file_,\n        key_in_ckpt_,\n        use_trace_,\n        input_shape_,\n        converter_kwargs_,\n    ) = _pop_args(\n        _args,\n        \"filepath\",\n        \"ckpt_file\",\n        \"config_file\",\n        net_id=\"\",\n        meta_file=None,\n        key_in_ckpt=\"\",\n        use_trace=False,\n        input_shape=None,\n        converter_kwargs={},\n    )\n\n    parser = ConfigParser()\n\n    parser.read_config(f=config_file_)\n    if meta_file_ is not None:\n        parser.read_meta(f=meta_file_)\n\n    # the rest key-values in the _args are to override config content\n    for k, v in _args.items():\n        parser[k] = v\n\n    # The convert_to_onnx must have an `inputs` as input, no matter what the `use_trace` is.\n    # If the `input_shape` is not provided, will try to parse it from the parser to generate a random `inputs`.\n    if not input_shape_:\n        input_shape_ = _get_fake_input_shape(parser=parser)\n\n    inputs_ = [torch.rand(input_shape_)]\n\n    converter_kwargs_.update({\"inputs\": inputs_, \"use_trace\": use_trace_})\n\n    def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None:\n        onnx.save(onnx_obj, filename_prefix_or_stream)\n\n    _export(\n        convert_to_onnx,\n        save_onnx,\n        parser,\n        net_id=net_id_,\n        filepath=filepath_,\n        ckpt_file=ckpt_file_,\n        
config_file=config_file_,\n        key_in_ckpt=key_in_ckpt_,\n        **converter_kwargs_,\n    )\n\n\ndef ckpt_export(\n    net_id: str | None = None,\n    filepath: PathLike | None = None,\n    ckpt_file: str | None = None,\n    meta_file: str | Sequence[str] | None = None,\n    config_file: str | Sequence[str] | None = None,\n    key_in_ckpt: str | None = None,\n    use_trace: bool | None = None,\n    input_shape: Sequence[int] | None = None,\n    args_file: str | None = None,\n    converter_kwargs: Mapping | None = None,\n    **override: Any,\n) -> None:\n    \"\"\"\n    Export the model checkpoint to the given filepath with metadata and config included as JSON files.\n\n    Typical usage examples:\n\n    .. code-block:: bash\n\n        python -m monai.bundle ckpt_export network --filepath <export path> --ckpt_file <checkpoint path> ...\n\n    Args:\n        net_id: ID name of the network component in the config, it must be `torch.nn.Module`.\n            Default to \"network_def\".\n        filepath: filepath to export, if filename has no extension it becomes `.ts`.\n            Default to \"models/model.ts\" under \"os.getcwd()\" if `bundle_root` is not specified.\n        ckpt_file: filepath of the model checkpoint to load.\n            Default to \"models/model.pt\" under \"os.getcwd()\" if `bundle_root` is not specified.\n        meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.\n            Default to \"configs/metadata.json\" under \"os.getcwd()\" if `bundle_root` is not specified.\n        config_file: filepath of the config file to save in TorchScript model and extract network information,\n            the saved key in the TorchScript model is the config filename without extension, and the saved config\n            value is always serialized in JSON format no matter the original file format is JSON or YAML.\n            it can be a single file or a list of files. 
if `None`, must be provided in `args_file`.\n        key_in_ckpt: for nested checkpoint like `{\"model\": XXX, \"optimizer\": XXX, ...}`, specify the key of model\n            weights. if not nested checkpoint, no need to set.\n        use_trace: whether using `torch.jit.trace` to convert the PyTorch model to TorchScript model.\n        input_shape: a shape used to generate the random input of the network, when converting the model to a\n            TorchScript model. Should be a list like [N, C, H, W] or [N, C, H, W, D]. If not given, will try to\n            parse from the `metadata` config.\n        args_file: a JSON or YAML file to provide default values for all the parameters of this function, so that\n            the command line inputs can be simplified.\n        converter_kwargs: extra arguments that are needed by `convert_to_torchscript`, except ones that already exist\n            in the input parameters.\n        override: id-value pairs to override or add the corresponding config content.\n            e.g. 
``--_meta#network_data_format#inputs#image#num_channels 3``.\n\n    \"\"\"\n    _args = update_kwargs(\n        args=args_file,\n        net_id=net_id,\n        filepath=filepath,\n        meta_file=meta_file,\n        config_file=config_file,\n        ckpt_file=ckpt_file,\n        key_in_ckpt=key_in_ckpt,\n        use_trace=use_trace,\n        input_shape=input_shape,\n        converter_kwargs=converter_kwargs,\n        **override,\n    )\n    _log_input_summary(tag=\"ckpt_export\", args=_args)\n    (\n        config_file_,\n        filepath_,\n        ckpt_file_,\n        net_id_,\n        meta_file_,\n        key_in_ckpt_,\n        use_trace_,\n        input_shape_,\n        converter_kwargs_,\n    ) = _pop_args(\n        _args,\n        \"config_file\",\n        filepath=None,\n        ckpt_file=None,\n        net_id=None,\n        meta_file=None,\n        key_in_ckpt=\"\",\n        use_trace=False,\n        input_shape=None,\n        converter_kwargs={},\n    )\n    bundle_root = _args.get(\"bundle_root\", os.getcwd())\n\n    parser = ConfigParser()\n    parser.read_config(f=config_file_)\n    meta_file_ = os.path.join(bundle_root, \"configs\", \"metadata.json\") if meta_file_ is None else meta_file_\n    if os.path.exists(meta_file_):\n        parser.read_meta(f=meta_file_)\n\n    # the rest key-values in the _args are to override config content\n    for k, v in _args.items():\n        parser[k] = v\n\n    filepath_ = os.path.join(bundle_root, \"models\", \"model.ts\") if filepath_ is None else filepath_\n    ckpt_file_ = os.path.join(bundle_root, \"models\", \"model.pt\") if ckpt_file_ is None else ckpt_file_\n    if not os.path.exists(ckpt_file_):\n        raise FileNotFoundError(f'Checkpoint file \"{ckpt_file_}\" not found, please specify it in argument \"ckpt_file\".')\n\n    net_id_ = \"network_def\" if net_id_ is None else net_id_\n    try:\n        parser.get_parsed_content(net_id_)\n    except ValueError as e:\n        raise ValueError(\n            
f'Network definition \"{net_id_}\" cannot be found in \"{config_file_}\", specify name with argument \"net_id\".'\n        ) from e\n\n    # When export through torch.jit.trace without providing input_shape, will try to parse one from the parser.\n    if (not input_shape_) and use_trace:\n        input_shape_ = _get_fake_input_shape(parser=parser)\n\n    inputs_: Sequence[Any] | None = [torch.rand(input_shape_)] if input_shape_ else None\n\n    converter_kwargs_.update({\"inputs\": inputs_, \"use_trace\": use_trace_})\n    # Use the given converter to convert a model and save with metadata, config content\n\n    save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)\n\n    _export(\n        convert_to_torchscript,\n        save_ts,\n        parser,\n        net_id=net_id_,\n        filepath=filepath_,\n        ckpt_file=ckpt_file_,\n        config_file=config_file_,\n        key_in_ckpt=key_in_ckpt_,\n        **converter_kwargs_,\n    )\n\n\ndef trt_export(\n    net_id: str | None = None,\n    filepath: PathLike | None = None,\n    ckpt_file: str | None = None,\n    meta_file: str | Sequence[str] | None = None,\n    config_file: str | Sequence[str] | None = None,\n    key_in_ckpt: str | None = None,\n    precision: str | None = None,\n    input_shape: Sequence[int] | None = None,\n    use_trace: bool | None = None,\n    dynamic_batchsize: Sequence[int] | None = None,\n    device: int | None = None,\n    use_onnx: bool | None = None,\n    onnx_input_names: Sequence[str] | None = None,\n    onnx_output_names: Sequence[str] | None = None,\n    args_file: str | None = None,\n    converter_kwargs: Mapping | None = None,\n    **override: Any,\n) -> None:\n    \"\"\"\n    Export the model checkpoint to the given filepath as a TensorRT engine-based TorchScript.\n    Currently, this API only supports converting models whose inputs are all tensors.\n    Note: NVIDIA Volta support (GPUs with compute capability 7.0) has been removed 
starting with TensorRT 10.5.\n    Review the TensorRT Support Matrix for which GPUs are supported.\n\n    There are two ways to export a model:\n    1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.\n    2, ONNX-TensorRT way: PyTorch module ---> TorchScript module ---> ONNX model ---> TensorRT engine --->\n    TensorRT engine-based TorchScript.\n\n    When exporting through the first way, some models suffer from the slowdown problem, since Torch-TensorRT\n    may only convert a little part of the PyTorch model to the TensorRT engine. However when exporting through\n    the second way, some Python data structures like `dict` are not supported. And some TorchScript models are\n    not supported by the ONNX if exported through `torch.jit.script`.\n\n    Typical usage examples:\n\n    .. code-block:: bash\n\n        python -m monai.bundle trt_export --net_id <network definition> --filepath <export path> \\\n            --ckpt_file <checkpoint path> --input_shape <input shape> --dynamic_batchsize <batch range> ...\n\n    Args:\n        net_id: ID name of the network component in the config, it must be `torch.nn.Module`.\n        filepath: filepath to export, if filename has no extension, it becomes `.ts`.\n        ckpt_file: filepath of the model checkpoint to load.\n        meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.\n        config_file: filepath of the config file to save in the TensorRT based TorchScript model and extract network\n            information, the saved key in the model is the config filename without extension, and the saved config\n            value is always serialized in JSON format no matter the original file format is JSON or YAML.\n            it can be a single file or a list of files. 
if `None`, must be provided in `args_file`.\n        key_in_ckpt: for nested checkpoint like `{\"model\": XXX, \"optimizer\": XXX, ...}`, specify the key of model\n            weights. if not nested checkpoint, no need to set.\n        precision: the weight precision of the converted TensorRT engine based TorchScript models. Should be 'fp32' or 'fp16'.\n        input_shape: the input shape that is used to convert the model. Should be a list like [N, C, H, W] or\n            [N, C, H, W, D]. If not given, will try to parse from the `metadata` config.\n        use_trace: whether using `torch.jit.trace` to convert the PyTorch model to a TorchScript model and then convert to\n            a TensorRT engine based TorchScript model or an ONNX model (if `use_onnx` is True).\n        dynamic_batchsize: a sequence with three elements to define the batch size range of the input for the model to be\n            converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. After converted, the batchsize of\n            model input should be between `MIN_BATCH` and `MAX_BATCH`, and `OPT_BATCH` is the best performance batchsize\n            that the TensorRT tries to fit. The `OPT_BATCH` should be the most frequently used input batchsize in\n            the application.\n        device: the target GPU index to convert and verify the model.\n        use_onnx: whether using the ONNX-TensorRT way to export the TensorRT engine-based TorchScript model.\n        onnx_input_names: optional input names of the ONNX model. This arg is only useful when `use_onnx` is True. Should be\n            a sequence like `['input_0', 'input_1', ..., 'input_N']` where N equals the number of the model inputs. If not\n            given, will use `['input_0']`, which supposes the model only has one input.\n        onnx_output_names: optional output names of the ONNX model. This arg is only useful when `use_onnx` is True. 
Should be\n            a sequence like `['output_0', 'output_1', ..., 'output_N']` where N equals to the number of the model outputs. If\n            not given, will use `['output_0']`, which supposes the model only has one output.\n        args_file: a JSON or YAML file to provide default values for all the parameters of this function, so that\n            the command line inputs can be simplified.\n        converter_kwargs: extra arguments that are needed by `convert_to_trt`, except ones that already exist in the\n            input parameters.\n        override: id-value pairs to override or add the corresponding config content.\n            e.g. ``--_meta#network_data_format#inputs#image#num_channels 3``.\n\n    \"\"\"\n    _args = update_kwargs(\n        args=args_file,\n        net_id=net_id,\n        filepath=filepath,\n        meta_file=meta_file,\n        config_file=config_file,\n        ckpt_file=ckpt_file,\n        key_in_ckpt=key_in_ckpt,\n        precision=precision,\n        input_shape=input_shape,\n        use_trace=use_trace,\n        dynamic_batchsize=dynamic_batchsize,\n        device=device,\n        use_onnx=use_onnx,\n        onnx_input_names=onnx_input_names,\n        onnx_output_names=onnx_output_names,\n        converter_kwargs=converter_kwargs,\n        **override,\n    )\n    _log_input_summary(tag=\"trt_export\", args=_args)\n    (\n        filepath_,\n        ckpt_file_,\n        config_file_,\n        net_id_,\n        meta_file_,\n        key_in_ckpt_,\n        precision_,\n        input_shape_,\n        use_trace_,\n        dynamic_batchsize_,\n        device_,\n        use_onnx_,\n        onnx_input_names_,\n        onnx_output_names_,\n        converter_kwargs_,\n    ) = _pop_args(\n        _args,\n        \"filepath\",\n        \"ckpt_file\",\n        \"config_file\",\n        net_id=\"\",\n        meta_file=None,\n        key_in_ckpt=\"\",\n        precision=\"fp32\",\n        input_shape=[],\n        use_trace=False,\n        
dynamic_batchsize=None,\n        device=None,\n        use_onnx=False,\n        onnx_input_names=[\"input_0\"],\n        onnx_output_names=[\"output_0\"],\n        converter_kwargs={},\n    )\n\n    parser = ConfigParser()\n\n    parser.read_config(f=config_file_)\n    if meta_file_ is not None:\n        parser.read_meta(f=meta_file_)\n\n    # the rest key-values in the _args are to override config content\n    for k, v in _args.items():\n        parser[k] = v\n\n    # The convert_to_trt must have an `input_shape_` as input, no matter what the `use_trace` is.\n    # If the `input_shape` is not provided, will try to parse it from the parser`.\n    if not input_shape_:\n        input_shape_ = _get_fake_input_shape(parser=parser)\n\n    trt_api_parameters = {\n        \"precision\": precision_,\n        \"input_shape\": input_shape_,\n        \"dynamic_batchsize\": dynamic_batchsize_,\n        \"use_trace\": use_trace_,\n        \"device\": device_,\n        \"use_onnx\": use_onnx_,\n        \"onnx_input_names\": onnx_input_names_,\n        \"onnx_output_names\": onnx_output_names_,\n    }\n    converter_kwargs_.update(trt_api_parameters)\n\n    save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)\n\n    _export(\n        convert_to_trt,\n        save_ts,\n        parser,\n        net_id=net_id_,\n        filepath=filepath_,\n        ckpt_file=ckpt_file_,\n        config_file=config_file_,\n        key_in_ckpt=key_in_ckpt_,\n        **converter_kwargs_,\n    )\n\n\ndef init_bundle(\n    bundle_dir: PathLike,\n    ckpt_file: PathLike | None = None,\n    network: torch.nn.Module | None = None,\n    dataset_license: bool = False,\n    metadata_str: dict | str | None = None,\n    inference_str: dict | str | None = None,\n) -> None:\n    \"\"\"\n    Initialise a new bundle directory with some default configuration files and optionally network weights.\n\n    Typical usage example:\n\n    .. 
code-block:: bash\n\n        python -m monai.bundle init_bundle /path/to/bundle_dir network_ckpt.pt\n\n    Args:\n        bundle_dir: directory name to create, must not exist but parent direct must exist\n        ckpt_file: optional checkpoint file to copy into bundle\n        network: if given instead of ckpt_file this network's weights will be stored in bundle\n        dataset_license: if `True`, a default license file called \"data_license.txt\" will be produced. This\n            file is required if there are any license conditions stated for data your bundle uses.\n        metadata_str: optional metadata string to write to bundle, if not given a default will be used.\n        inference_str: optional inference string to write to bundle, if not given a default will be used.\n    \"\"\"\n    if metadata_str is None:\n        metadata_str = DEFAULT_METADATA\n    if inference_str is None:\n        inference_str = DEFAULT_INFERENCE\n\n    bundle_dir = Path(bundle_dir).absolute()\n\n    if bundle_dir.exists():\n        raise ValueError(f\"Specified bundle directory '{str(bundle_dir)}' already exists\")\n\n    if not bundle_dir.parent.is_dir():\n        raise ValueError(f\"Parent directory of specified bundle directory '{str(bundle_dir)}' does not exist\")\n\n    configs_dir = bundle_dir / \"configs\"\n    models_dir = bundle_dir / \"models\"\n    docs_dir = bundle_dir / \"docs\"\n\n    bundle_dir.mkdir()\n    configs_dir.mkdir()\n    models_dir.mkdir()\n    docs_dir.mkdir()\n\n    if isinstance(metadata_str, dict):\n        metadata_str = json.dumps(metadata_str, indent=4)\n\n    if isinstance(inference_str, dict):\n        inference_str = json.dumps(inference_str, indent=4)\n\n    with open(str(configs_dir / \"metadata.json\"), \"w\") as o:\n        o.write(metadata_str)\n\n    with open(str(configs_dir / \"inference.json\"), \"w\") as o:\n        o.write(inference_str)\n\n    with open(str(docs_dir / \"README.md\"), \"w\") as o:\n        readme = \"\"\"\n        # 
Your Model Name\n\n        Describe your model here and how to run it, for example using `inference.json`:\n\n        ```\n        python -m monai.bundle run \\\n            --meta_file /path/to/bundle/configs/metadata.json \\\n            --config_file /path/to/bundle/configs/inference.json \\\n            --dataset_dir ./input \\\n            --bundle_root /path/to/bundle\n        ```\n        \"\"\"\n\n        o.write(dedent(readme))\n\n    with open(str(bundle_dir / \"LICENSE\"), \"w\") as o:\n        o.write(\"Select a license and place its terms here\\n\")\n\n    if dataset_license is True:\n        with open(str(docs_dir / \"data_license.txt\"), \"w\") as o:\n            o.write(\"Select a license for dataset and place its terms here\\n\")\n\n    if ckpt_file is not None:\n        copyfile(str(ckpt_file), str(models_dir / \"model.pt\"))\n    elif network is not None:\n        save_state(network, str(models_dir / \"model.pt\"))\n\n\ndef _add_model_card_metadata(new_modelcard_path):\n    # Extract license from LICENSE file\n    license_name = \"unknown\"\n    license_path = os.path.join(os.path.dirname(new_modelcard_path), \"LICENSE\")\n    if os.path.exists(license_path):\n        with open(license_path) as file:\n            content = file.read()\n        if \"Apache License\" in content and \"Version 2.0\" in content:\n            license_name = \"apache-2.0\"\n        elif \"MIT License\" in content:\n            license_name = \"mit\"\n    # Add relevant tags\n    tags = \"- monai\\n- medical\\nlibrary_name: monai\\n\"\n    # Create tag section\n    tag_content = f\"---\\ntags:\\n{tags}license: {license_name}\\n---\"\n\n    # Update model card\n    with open(new_modelcard_path) as file:\n        content = file.read()\n    new_content = tag_content + \"\\n\" + content\n    with open(new_modelcard_path, \"w\") as file:\n        file.write(new_content)\n\n\ndef push_to_hf_hub(\n    repo: str,\n    name: str,\n    bundle_dir: str,\n    token: str | None = 
None,\n    private: bool | None = True,\n    version: str | None = None,\n    tag_as_latest_version: bool | None = False,\n    **upload_folder_kwargs: Any,\n) -> Any:\n    \"\"\"\n    Push a MONAI bundle to the Hugging Face Hub.\n\n    Typical usage examples:\n\n    .. code-block:: bash\n\n        python -m monai.bundle push_to_hf_hub --repo <HF repository id> --name <bundle name> \\\n            --bundle_dir <bundle directory> --version <version> ...\n\n    Args:\n        repo: namespace (user or organization) and a repo name separated by a /, e.g. `hf_username/bundle_name`\n        name: name of the bundle directory to push.\n        bundle_dir: path to the bundle directory.\n        token: Hugging Face authentication token. Default is `None` (will default to the stored token).\n        private: Private visibility of the repository on Hugging Face. Default is `True`.\n        version: Name of the version tag to create. Default is `None` (no version tag is created).\n        tag_as_latest_version: Whether to tag the commit as `latest_version`.\n            This version will be downloaded by default when using `bundle.download()`. 
Default is `False`.\n        upload_folder_kwargs: Keyword arguments to pass to `HfApi.upload_folder`.\n\n    Returns:\n        repo_url: URL of the Hugging Face repo\n    \"\"\"\n    # Connect to API and create repo\n    hf_api = huggingface_hub.HfApi(token=token)\n    hf_api.create_repo(repo_id=repo, private=private, exist_ok=True)\n\n    # Create model card in bundle directory\n    new_modelcard_path = os.path.join(bundle_dir, name, \"README.md\")\n    modelcard_path = os.path.join(bundle_dir, name, \"docs\", \"README.md\")\n    if os.path.exists(modelcard_path):\n        # Copy README from old path if it exists\n        copyfile(modelcard_path, new_modelcard_path)\n        _add_model_card_metadata(new_modelcard_path)\n\n    # Upload bundle folder to repo\n    repo_url = hf_api.upload_folder(repo_id=repo, folder_path=os.path.join(bundle_dir, name), **upload_folder_kwargs)\n\n    # Create version tag if specified\n    if version is not None:\n        hf_api.create_tag(repo_id=repo, tag=version, exist_ok=True)\n\n    # Optionally tag as `latest_version`\n    if tag_as_latest_version:\n        hf_api.create_tag(repo_id=repo, tag=\"latest_version\", exist_ok=True)\n\n    return repo_url\n\n\ndef create_workflow(\n    workflow_name: str | BundleWorkflow | None = None,\n    config_file: str | Sequence[str] | None = None,\n    args_file: str | None = None,\n    **kwargs: Any,\n) -> Any:\n    \"\"\"\n    Specify `bundle workflow` to create monai bundle workflows.\n    The workflow should be subclass of `BundleWorkflow` and be available to import.\n    It can be MONAI existing bundle workflows or user customized workflows.\n\n    Typical usage examples:\n\n    .. 
code-block:: python\n\n        # Specify config_file path to create workflow:\n        workflow = create_workflow(config_file=\"/workspace/spleen_ct_segmentation/configs/train.json\", workflow_type=\"train\")\n\n        # Set the workflow to other customized BundleWorkflow subclass to create workflow:\n        workflow = create_workflow(workflow_name=CustomizedWorkflow)\n\n    Args:\n        workflow_name: specified bundle workflow name, should be a string or class, default to \"ConfigWorkflow\".\n        config_file: filepath of the config file, if it is a list of file paths, the content of them will be merged.\n        args_file: a JSON or YAML file to provide default values for this API.\n            so that the command line inputs can be simplified.\n        kwargs: arguments to instantiate the workflow class.\n\n    \"\"\"\n    _args = update_kwargs(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)\n    (workflow_name, config_file) = _pop_args(\n        _args, workflow_name=ConfigWorkflow, config_file=None\n    )  # the default workflow name is \"ConfigWorkflow\"\n    if isinstance(workflow_name, str):\n        workflow_class, has_built_in = optional_import(\"monai.bundle\", name=str(workflow_name))  # search built-in\n        if not has_built_in:\n            workflow_class = locate(str(workflow_name))  # search dotted path\n        if workflow_class is None:\n            raise ValueError(f\"cannot locate specified workflow class: {workflow_name}.\")\n    elif issubclass(workflow_name, BundleWorkflow):  # type: ignore\n        workflow_class = workflow_name\n    else:\n        raise ValueError(\n            \"Argument `workflow_name` must be a bundle workflow class name\"\n            f\"or subclass of BundleWorkflow, got: {workflow_name}.\"\n        )\n\n    if config_file is not None:\n        workflow_ = workflow_class(config_file=config_file, **_args)\n    else:\n        workflow_ = workflow_class(**_args)\n\n    
workflow_.initialize()\n    _log_input_summary(tag=\"run\", args=_args)\n    return workflow_\n\n\ndef download_large_files(bundle_path: str | None = None, large_file_name: str | None = None) -> None:\n    \"\"\"\n    This utility allows you to download large files from a bundle. It supports file suffixes like \".yml\", \".yaml\", and \".json\".\n    If you don't specify a `large_file_name`, it will automatically search for large files among the supported suffixes.\n\n    Typical usage examples:\n    .. code-block:: bash\n\n        # Execute this module as a CLI entry to download large files from a bundle path:\n        python -m monai.bundle download_large_files --bundle_path <bundle_path>\n\n        # Execute this module as a CLI entry to download large files from the bundle path with a specified `large_file_name`:\n        python -m monai.bundle download_large_files --bundle_path <bundle_path> --large_file_name large_files.yaml\n\n    Args:\n        bundle_path: (Optional) The path to the bundle where the files are located. 
Default is `os.getcwd()`.\n        large_file_name: (Optional) The name of the large file to be downloaded.\n\n    \"\"\"\n    bundle_path = os.getcwd() if bundle_path is None else bundle_path\n    if large_file_name is None:\n        large_file_path = list(Path(bundle_path).glob(\"large_files*\"))\n        large_file_path = list(filter(lambda x: x.suffix in [\".yml\", \".yaml\", \".json\"], large_file_path))\n        if len(large_file_path) == 0:\n            raise FileNotFoundError(f\"Cannot find the large_files.yml/yaml/json under {bundle_path}.\")\n\n    parser = ConfigParser()\n    parser.read_config(large_file_path)\n    large_files_list = parser.get()[\"large_files\"]\n    for lf_data in large_files_list:\n        lf_data[\"fuzzy\"] = True\n        if \"hash_val\" in lf_data and lf_data.get(\"hash_val\", \"\") == \"\":\n            lf_data.pop(\"hash_val\")\n        if \"hash_type\" in lf_data and lf_data.get(\"hash_type\", \"\") == \"\":\n            lf_data.pop(\"hash_type\")\n        lf_data[\"filepath\"] = os.path.join(bundle_path, lf_data[\"path\"])\n        lf_data.pop(\"path\")\n        download_url(**lf_data)\n"
  },
  {
    "path": "monai/bundle/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport warnings\nimport zipfile\nfrom typing import Any\n\nfrom monai.config.deviceconfig import get_config_values\nfrom monai.utils import optional_import\n\nyaml, _ = optional_import(\"yaml\")\n\n__all__ = [\n    \"ID_REF_KEY\",\n    \"ID_SEP_KEY\",\n    \"EXPR_KEY\",\n    \"MACRO_KEY\",\n    \"MERGE_KEY\",\n    \"DEFAULT_MLFLOW_SETTINGS\",\n    \"DEFAULT_EXP_MGMT_SETTINGS\",\n]\n\nID_REF_KEY = \"@\"  # start of a reference to a ConfigItem\nID_SEP_KEY = \"::\"  # separator for the ID of a ConfigItem\nEXPR_KEY = \"$\"  # start of a ConfigExpression\nMACRO_KEY = \"%\"  # start of a macro of a config\nMERGE_KEY = \"+\"  # prefix indicating merge instead of override in case of multiple configs.\n\n_conf_values = get_config_values()\n\nDEFAULT_METADATA = {\n    \"version\": \"0.0.1\",\n    \"changelog\": {\"0.0.1\": \"Initial version\"},\n    \"monai_version\": _conf_values[\"MONAI\"],\n    \"pytorch_version\": str(_conf_values[\"Pytorch\"]).split(\"+\")[0].split(\"a\")[0],  # 1.9.0a0+df837d0 or 1.13.0+cu117\n    \"numpy_version\": _conf_values[\"Numpy\"],\n    \"required_packages_version\": {},\n    \"task\": \"Describe what the network predicts\",\n    \"description\": \"A longer description of what the network does, use context, inputs, outputs, etc.\",\n    \"authors\": \"Your Name Here\",\n    \"copyright\": 
\"Copyright (c) Your Name Here\",\n    \"network_data_format\": {\"inputs\": {}, \"outputs\": {}},\n}\n\nDEFAULT_INFERENCE = {\n    \"imports\": [\"$import glob\"],\n    \"device\": \"$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\",\n    \"ckpt_path\": \"$@bundle_root + '/models/model.pt'\",\n    \"dataset_dir\": \"/workspace/data\",\n    \"datalist\": \"$list(sorted(glob.glob(@dataset_dir + '/*.jpeg')))\",\n    \"network_def\": {\"_target_\": \"???\", \"spatial_dims\": 2},\n    \"network\": \"$@network_def.to(@device)\",\n    \"preprocessing\": {\n        \"_target_\": \"Compose\",\n        \"transforms\": [\n            {\"_target_\": \"LoadImaged\", \"keys\": \"image\"},\n            {\"_target_\": \"EnsureChannelFirstd\", \"keys\": \"image\"},\n            {\"_target_\": \"ScaleIntensityd\", \"keys\": \"image\"},\n            {\"_target_\": \"EnsureTyped\", \"keys\": \"image\", \"device\": \"@device\"},\n        ],\n    },\n    \"dataset\": {\"_target_\": \"Dataset\", \"data\": \"$[{'image': i} for i in @datalist]\", \"transform\": \"@preprocessing\"},\n    \"dataloader\": {\n        \"_target_\": \"DataLoader\",\n        \"dataset\": \"@dataset\",\n        \"batch_size\": 1,\n        \"shuffle\": False,\n        \"num_workers\": 0,\n    },\n    \"inferer\": {\"_target_\": \"SimpleInferer\"},\n    \"postprocessing\": {\n        \"_target_\": \"Compose\",\n        \"transforms\": [\n            {\"_target_\": \"Activationsd\", \"keys\": \"pred\", \"softmax\": True},\n            {\"_target_\": \"AsDiscreted\", \"keys\": \"pred\", \"argmax\": True},\n        ],\n    },\n    \"handlers\": [\n        {\n            \"_target_\": \"CheckpointLoader\",\n            \"_disabled_\": \"$not os.path.exists(@ckpt_path)\",\n            \"load_path\": \"@ckpt_path\",\n            \"load_dict\": {\"model\": \"@network\"},\n        }\n    ],\n    \"evaluator\": {\n        \"_target_\": \"SupervisedEvaluator\",\n        \"device\": \"@device\",\n        
\"val_data_loader\": \"@dataloader\",\n        \"network\": \"@network\",\n        \"inferer\": \"@inferer\",\n        \"postprocessing\": \"@postprocessing\",\n        \"val_handlers\": \"@handlers\",\n    },\n    \"evaluating\": [\"$@evaluator.run()\"],\n}\n\nDEFAULT_HANDLERS_ID = {\n    \"trainer\": {\"id\": \"train#trainer\", \"handlers\": \"train#handlers\"},\n    \"validator\": {\"id\": \"validate#evaluator\", \"handlers\": \"validate#handlers\"},\n    \"evaluator\": {\"id\": \"evaluator\", \"handlers\": \"handlers\"},\n}\n\nDEFAULT_MLFLOW_SETTINGS = {\n    \"handlers_id\": DEFAULT_HANDLERS_ID,\n    \"configs\": {\n        # if no \"output_dir\" in the bundle config, default to \"<bundle root>/eval\"\n        \"output_dir\": \"$@bundle_root + '/eval'\",\n        # use URI to support linux, mac and windows os\n        \"tracking_uri\": \"$monai.utils.path_to_uri(@output_dir) + '/mlruns'\",\n        \"experiment_name\": \"monai_experiment\",\n        \"run_name\": None,\n        # may fill it at runtime\n        \"save_execute_config\": True,\n        \"is_not_rank0\": (\n            \"$torch.distributed.is_available() \\\n                and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0\"\n        ),\n        # MLFlowHandler config for the trainer\n        \"trainer\": {\n            \"_target_\": \"MLFlowHandler\",\n            \"_disabled_\": \"@is_not_rank0\",\n            \"tracking_uri\": \"@tracking_uri\",\n            \"experiment_name\": \"@experiment_name\",\n            \"run_name\": \"@run_name\",\n            \"artifacts\": \"@save_execute_config\",\n            \"iteration_log\": True,\n            \"epoch_log\": True,\n            \"tag_name\": \"train_loss\",\n            \"output_transform\": \"$monai.handlers.from_engine(['loss'], first=True)\",\n            \"close_on_complete\": True,\n        },\n        # MLFlowHandler config for the validator\n        \"validator\": {\n            \"_target_\": 
\"MLFlowHandler\",\n            \"_disabled_\": \"@is_not_rank0\",\n            \"tracking_uri\": \"@tracking_uri\",\n            \"experiment_name\": \"@experiment_name\",\n            \"run_name\": \"@run_name\",\n            \"iteration_log\": False,\n        },\n        # MLFlowHandler config for the evaluator\n        \"evaluator\": {\n            \"_target_\": \"MLFlowHandler\",\n            \"_disabled_\": \"@is_not_rank0\",\n            \"tracking_uri\": \"@tracking_uri\",\n            \"experiment_name\": \"@experiment_name\",\n            \"run_name\": \"@run_name\",\n            \"artifacts\": \"@save_execute_config\",\n            \"iteration_log\": False,\n            \"close_on_complete\": True,\n        },\n    },\n}\n\nDEFAULT_EXP_MGMT_SETTINGS = {\"mlflow\": DEFAULT_MLFLOW_SETTINGS}  # default experiment management settings\n\nDEPRECATED_ID_MAPPING = {\"optional_packages_version\": \"required_packages_version\"}\n\n\ndef load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any) -> Any:\n    \"\"\"\n    Load the metadata and nominated configuration files from a MONAI bundle without loading the network itself.\n\n    This function will load the information from the bundle, which can be a directory or a zip file containing a\n    directory or a Torchscript bundle, and return the parser object with the information. 
This saves having to load\n    the model if only the information is wanted, and can work on any sort of bundle format.\n\n    Args:\n        bundle_path: path to the bundle directory or zip file\n        config_names: names of configuration files with extensions to load, should not be full paths but just name+ext\n        load_kw_args: keyword arguments to pass to the ConfigParser object when loading\n\n    Returns:\n        ConfigParser object containing the parsed information\n    \"\"\"\n\n    from monai.bundle.config_parser import ConfigParser  # avoids circular import\n\n    parser = ConfigParser()\n\n    if not os.path.exists(bundle_path):\n        raise ValueError(f\"Cannot find bundle file/directory '{bundle_path}'\")\n\n    # bundle is a directory, read files directly\n    if os.path.isdir(bundle_path):\n        conf_data = []\n        parser.read_meta(f=os.path.join(bundle_path, \"configs\", \"metadata.json\"), **load_kw_args)\n\n        for cname in config_names:\n            cpath = os.path.join(bundle_path, \"configs\", cname)\n            if not os.path.exists(cpath):\n                raise ValueError(f\"Cannot find config file '{cpath}'\")\n\n            conf_data.append(cpath)\n\n        parser.read_config(f=conf_data, **load_kw_args)\n    else:\n        # bundle is a zip file which is either a zipped directory or a Torchscript archive\n\n        name, _ = os.path.splitext(os.path.basename(bundle_path))\n\n        archive = zipfile.ZipFile(bundle_path, \"r\")\n\n        all_files = archive.namelist()\n\n        zip_meta_name = f\"{name}/configs/metadata.json\"\n\n        if zip_meta_name in all_files:\n            prefix = f\"{name}/configs/\"  # zipped directory location for files\n        else:\n            zip_meta_name = f\"{name}/extra/metadata.json\"\n            prefix = f\"{name}/extra/\"  # Torchscript location for files\n\n        meta_json = json.loads(archive.read(zip_meta_name))\n        parser.read_meta(f=meta_json)\n\n        for 
cname in config_names:\n            full_cname = prefix + cname\n            if full_cname not in all_files:\n                raise ValueError(f\"Cannot find config file '{full_cname}'\")\n\n            ardata = archive.read(full_cname)\n            cdata = {}\n\n            if full_cname.lower().endswith(\"json\"):\n                cdata = json.loads(ardata, **load_kw_args)\n            elif full_cname.lower().endswith((\"yaml\", \"yml\")):\n                cdata = yaml.safe_load(ardata, **load_kw_args)\n\n            parser.read_config(f=cdata)\n\n    return parser\n\n\ndef merge_kv(args: dict | Any, k: str, v: Any) -> None:\n    \"\"\"\n    Update the `args` dict-like object with the key/value pair `k` and `v`.\n    \"\"\"\n    if k.startswith(MERGE_KEY):\n        \"\"\"\n        Both values associated with `+`-prefixed key pair must be of `dict` or `list` type.\n        `dict` values will be merged, `list` values - concatenated.\n        \"\"\"\n        id = k[1:]\n        if id in args:\n            if isinstance(v, dict) and isinstance(args[id], dict):\n                args[id].update(v)\n            elif isinstance(v, list) and isinstance(args[id], list):\n                args[id].extend(v)\n            else:\n                raise ValueError(ValueError(f\"config must be dict or list for key `{k}`, but got {type(v)}: {v}.\"))\n        else:\n            warnings.warn(f\"Can't merge entry ['{k}'], '{id}' is not in target dict - copying instead.\")\n            args[id] = v\n    else:\n        args[k] = v\n"
  },
  {
    "path": "monai/bundle/workflows.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport sys\nimport time\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Sequence\nfrom copy import copy\nfrom logging.config import fileConfig\nfrom pathlib import Path\nfrom typing import Any\n\nfrom monai.apps.utils import get_logger\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.bundle.properties import InferProperties, MetaProperties, TrainProperties\nfrom monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY\nfrom monai.config import PathLike\nfrom monai.utils import BundleProperty, BundlePropertyConfig, ensure_tuple\n\n__all__ = [\"BundleWorkflow\", \"ConfigWorkflow\"]\n\nlogger = get_logger(module_name=__name__)\n\n\nclass BundleWorkflow(ABC):\n    \"\"\"\n    Base class for the workflow specification in bundle, it can be a training, evaluation or inference workflow.\n    It defines the basic interfaces for the bundle workflow behavior: `initialize`, `run`, `finalize`, etc.\n    And also provides the interface to get / set public properties to interact with a bundle workflow.\n\n    Args:\n        workflow_type: specifies the workflow type: \"train\" or \"training\" for a training workflow,\n            or \"infer\", \"inference\", \"eval\", \"evaluation\" for a inference workflow,\n            other unsupported string will raise a 
ValueError.\n            default to `None` for only using meta properties.\n        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be\n            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,\n            properties will default to loading from \"meta\". If `properties_path` is None, default properties\n            will be sourced from \"monai/bundle/properties.py\" based on the workflow_type:\n            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.\n            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.\n            For workflow_type = None : only `MetaProperties` will be loaded.\n        meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.\n        logging_file: config file for `logging` module in the program. for more details:\n            https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.\n\n    \"\"\"\n\n    supported_train_type: tuple = (\"train\", \"training\")\n    supported_infer_type: tuple = (\"infer\", \"inference\", \"eval\", \"evaluation\")\n\n    def __init__(\n        self,\n        workflow_type: str | None = None,\n        properties_path: PathLike | None = None,\n        meta_file: str | Sequence[str] | None = None,\n        logging_file: str | None = None,\n    ):\n        if logging_file is not None:\n            if not os.path.isfile(logging_file):\n                raise FileNotFoundError(f\"Cannot find the logging config file: {logging_file}.\")\n            logger.info(f\"Setting logging properties based on config: {logging_file}.\")\n            fileConfig(logging_file, disable_existing_loggers=False)\n\n        if meta_file is not None:\n            if isinstance(meta_file, str) and not os.path.isfile(meta_file):\n                logger.error(\n   
                 f\"Cannot find the metadata config file: {meta_file}. \"\n                    \"Please see: https://monai.readthedocs.io/en/stable/mb_specification.html\"\n                )\n                meta_file = None\n            if isinstance(meta_file, list):\n                for f in meta_file:\n                    if not os.path.isfile(f):\n                        logger.error(\n                            f\"Cannot find the metadata config file: {f}. \"\n                            \"Please see: https://monai.readthedocs.io/en/stable/mb_specification.html\"\n                        )\n                        meta_file = None\n\n        if workflow_type is not None:\n            if workflow_type.lower() in self.supported_train_type:\n                workflow_type = \"train\"\n            elif workflow_type.lower() in self.supported_infer_type:\n                workflow_type = \"infer\"\n            else:\n                raise ValueError(f\"Unsupported workflow type: '{workflow_type}'.\")\n\n        if properties_path is not None:\n            properties_path = Path(properties_path)\n            if not properties_path.is_file():\n                raise ValueError(f\"Property file {properties_path} does not exist.\")\n            with open(properties_path) as json_file:\n                try:\n                    properties = json.load(json_file)\n                    self.properties: dict = {}\n                    if workflow_type is not None and workflow_type in properties:\n                        self.properties = properties[workflow_type]\n                        if \"meta\" in properties:\n                            self.properties.update(properties[\"meta\"])\n                    elif workflow_type is None:\n                        if \"meta\" in properties:\n                            self.properties = properties[\"meta\"]\n                            logger.info(\n                                \"No workflow type specified, default to load meta 
properties from property file.\"\n                            )\n                        else:\n                            logger.warning(\"No 'meta' key found in properties while workflow_type is None.\")\n                except KeyError as e:\n                    raise ValueError(f\"{workflow_type} not found in property file {properties_path}\") from e\n                except json.JSONDecodeError as e:\n                    raise ValueError(f\"Error decoding JSON from property file {properties_path}\") from e\n        else:\n            if workflow_type == \"train\":\n                self.properties = {**TrainProperties, **MetaProperties}\n            elif workflow_type == \"infer\":\n                self.properties = {**InferProperties, **MetaProperties}\n            elif workflow_type is None:\n                self.properties = copy(MetaProperties)\n                logger.info(\"No workflow type and property file specified, default to 'meta' properties.\")\n            else:\n                raise ValueError(f\"Unsupported workflow type: '{workflow_type}'.\")\n\n        self.workflow_type = workflow_type\n        self.meta_file = meta_file\n\n    @abstractmethod\n    def initialize(self, *args: Any, **kwargs: Any) -> Any:\n        \"\"\"\n        Initialize the bundle workflow before running.\n\n        \"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def run(self, *args: Any, **kwargs: Any) -> Any:\n        \"\"\"\n        Run the bundle workflow, it can be a training, evaluation or inference.\n\n        \"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def finalize(self, *args: Any, **kwargs: Any) -> Any:\n        \"\"\"\n        Finalize step after the running of bundle workflow.\n\n        \"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def _get_property(self, name: str, property: dict) -> Any:\n        \"\"\"\n        With specified property name and information, get the expected property 
value.\n\n        Args:\n            name: the name of target property.\n            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.\n\n        \"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def _set_property(self, name: str, property: dict, value: Any) -> Any:\n        \"\"\"\n        With specified property name and information, set value for the expected property.\n\n        Args:\n            name: the name of target property.\n            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.\n            value: value to set for the property.\n\n        \"\"\"\n        raise NotImplementedError()\n\n    def __getattr__(self, name):\n        if self.properties is not None and name in self.properties:\n            return self._get_property(name=name, property=self.properties[name])\n        else:\n            return self.__getattribute__(name)  # getting regular attribute\n\n    def __setattr__(self, name, value):\n        if name != \"properties\" and self.properties is not None and name in self.properties:\n            self._set_property(name=name, property=self.properties[name], value=value)\n        else:\n            super().__setattr__(name, value)  # setting regular attribute\n\n    def get_workflow_type(self):\n        \"\"\"\n        Get the workflow type, it can be `None`, \"train\", or \"infer\".\n\n        \"\"\"\n        return self.workflow_type\n\n    def get_meta_file(self):\n        \"\"\"\n        Get the meta file.\n\n        \"\"\"\n        return self.meta_file\n\n    def add_property(self, name: str, required: str, desc: str | None = None) -> None:\n        \"\"\"\n        Besides the default predefined properties, some 3rd party applications may need the bundle\n        definition to provide additional properties for the specific use cases, if the bundle can't\n        provide the property, means it can't work 
with the application.\n        This utility adds the property for the application requirements check and access.\n\n        Args:\n            name: the name of target property.\n            required: whether the property is \"must-have\".\n            desc: descriptions for the property.\n        \"\"\"\n        if self.properties is None:\n            self.properties = {}\n        if name in self.properties:\n            logger.warning(f\"property '{name}' already exists in the properties list, overriding it.\")\n        self.properties[name] = {BundleProperty.DESC: desc, BundleProperty.REQUIRED: required}\n\n    def check_properties(self) -> list[str] | None:\n        \"\"\"\n        Check whether the required properties are existing in the bundle workflow.\n        If no workflow type specified, return None, otherwise, return a list of required but missing properties.\n\n        \"\"\"\n        if self.properties is None:\n            return None\n        return [n for n, p in self.properties.items() if p.get(BundleProperty.REQUIRED, False) and not hasattr(self, n)]\n\n\nclass PythonicWorkflow(BundleWorkflow):\n    \"\"\"\n    Base class for the pythonic workflow specification in bundle, it can be a training, evaluation or inference workflow.\n    It defines the basic interfaces for the bundle workflow behavior: `initialize`, `finalize`, etc.\n    This also provides the interface to get / set public properties to interact with a bundle workflow through\n    defined `get_<property>` accessor methods or directly defining members of the object.\n    For how to set the properties, users can define the `_set_<property>` methods or directly set the members of the object.\n    The `initialize` method is called to set up the workflow before running. This method sets up internal state\n    and prepares properties. If properties are modified after the workflow has been initialized, `self._is_initialized`\n    is set to `False`. 
Before running the workflow again, `initialize` should be called to ensure that the workflow is\n    properly set up with the new property values.\n\n    Args:\n        workflow_type: specifies the workflow type: \"train\" or \"training\" for a training workflow,\n            or \"infer\", \"inference\", \"eval\", \"evaluation\" for a inference workflow,\n            other unsupported string will raise a ValueError.\n            default to `None` for only using meta properties.\n        workflow: specifies the workflow type: \"train\" or \"training\" for a training workflow,\n            or \"infer\", \"inference\", \"eval\", \"evaluation\" for a inference workflow,\n            other unsupported string will raise a ValueError.\n            default to `None` for common workflow.\n        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be\n            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,\n            properties will default to loading from \"meta\". If `properties_path` is None, default properties\n            will be sourced from \"monai/bundle/properties.py\" based on the workflow_type:\n            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.\n            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.\n            For workflow_type = None : only `MetaProperties` will be loaded.\n        config_file: path to the config file, typically used to store hyperparameters.\n        meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.\n        logging_file: config file for `logging` module in the program. 
for more details:\n            https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.\n\n    \"\"\"\n\n    supported_train_type: tuple = (\"train\", \"training\")\n    supported_infer_type: tuple = (\"infer\", \"inference\", \"eval\", \"evaluation\")\n\n    def __init__(\n        self,\n        workflow_type: str | None = None,\n        properties_path: PathLike | None = None,\n        config_file: str | Sequence[str] | None = None,\n        meta_file: str | Sequence[str] | None = None,\n        logging_file: str | None = None,\n        **override: Any,\n    ):\n        meta_file = str(Path(os.getcwd()) / \"metadata.json\") if meta_file is None else meta_file\n        super().__init__(\n            workflow_type=workflow_type, properties_path=properties_path, meta_file=meta_file, logging_file=logging_file\n        )\n        self._props_vals: dict = {}\n        self._set_props_vals: dict = {}\n        self.parser = ConfigParser()\n        if config_file is not None:\n            self.parser.read_config(f=config_file)\n        if self.meta_file is not None:\n            self.parser.read_meta(f=self.meta_file)\n\n        # the rest key-values in the _args are to override config content\n        self.parser.update(pairs=override)\n        self._is_initialized: bool = False\n\n    def initialize(self, *args: Any, **kwargs: Any) -> Any:\n        \"\"\"\n        Initialize the bundle workflow before running.\n        \"\"\"\n        self._props_vals = {}\n        self._is_initialized = True\n\n    def _get_property(self, name: str, property: dict) -> Any:\n        \"\"\"\n        With specified property name and information, get the expected property value.\n        If the property is already generated, return from the bucket directly.\n        If user explicitly set the property, return it directly.\n        Otherwise, generate the expected property as a class private property with prefix \"_\".\n\n        Args:\n            name: the name of 
target property.\n            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.\n        \"\"\"\n        if not self._is_initialized:\n            raise RuntimeError(\"Please execute 'initialize' before getting any properties.\")\n        value = None\n        if name in self._set_props_vals:\n            value = self._set_props_vals[name]\n        elif name in self._props_vals:\n            value = self._props_vals[name]\n        elif name in self.parser.config[self.parser.meta_key]:  # type: ignore[index]\n            id = self.properties.get(name, None).get(BundlePropertyConfig.ID, None)\n            value = self.parser[id]\n        else:\n            try:\n                value = getattr(self, f\"get_{name}\")()\n            except AttributeError as e:\n                if property[BundleProperty.REQUIRED]:\n                    raise ValueError(\n                        f\"unsupported property '{name}' is required in the bundle properties,\"\n                        f\"need to implement a method 'get_{name}' to provide the property.\"\n                    ) from e\n            self._props_vals[name] = value\n        return value\n\n    def _set_property(self, name: str, property: dict, value: Any) -> Any:\n        \"\"\"\n        With specified property name and information, set value for the expected property.\n        Stores user-reset initialized objects that should not be re-initialized and marks the workflow as not initialized.\n\n        Args:\n            name: the name of target property.\n            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.\n            value: value to set for the property.\n\n        \"\"\"\n        self._set_props_vals[name] = value\n        self._is_initialized = False\n\n\nclass ConfigWorkflow(BundleWorkflow):\n    \"\"\"\n    Specification for the config-based bundle workflow.\n    Standardized the `initialize`, 
`run`, `finalize` behavior in a config-based training, evaluation, or inference.\n    Before `run`, we add bundle root directory to Python search directories automatically.\n    For more information: https://monai.readthedocs.io/en/latest/mb_specification.html.\n\n    Args:\n        config_file: filepath of the config file, if this is a list of file paths, their contents will be merged in order.\n        meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.\n            If None, default to \"configs/metadata.json\", which is commonly used for bundles in MONAI model zoo.\n        logging_file: config file for `logging` module in the program. for more details:\n            https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.\n            If None, default to \"configs/logging.conf\", which is commonly used for bundles in MONAI model zoo.\n            If False, the logging logic for the bundle will not be modified.\n        init_id: ID name of the expected config expression to initialize before running, default to \"initialize\".\n            allow a config to have no `initialize` logic and the ID.\n        run_id: ID name of the expected config expression to run, default to \"run\".\n            to run the config, the target config must contain this ID.\n        final_id: ID name of the expected config expression to finalize after running, default to \"finalize\".\n            allow a config to have no `finalize` logic and the ID.\n        tracking: if not None, enable the experiment tracking at runtime with optionally configurable and extensible.\n            if \"mlflow\", will add `MLFlowHandler` to the parsed bundle with default tracking settings,\n            if other string, treat it as file path to load the tracking settings.\n            if `dict`, treat it as tracking settings.\n            will patch the target config content with `tracking handlers` and the top-level 
items of `configs`.\n            for detailed usage examples, please check the tutorial:\n            https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb.\n        workflow_type: specifies the workflow type: \"train\" or \"training\" for a training workflow,\n            or \"infer\", \"inference\", \"eval\", \"evaluation\" for a inference workflow,\n            other unsupported string will raise a ValueError.\n            default to `None` for common workflow.\n        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be\n            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,\n            properties will default to loading from \"train\". If `properties_path` is None, default properties\n            will be sourced from \"monai/bundle/properties.py\" based on the workflow_type:\n            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.\n            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.\n            For workflow_type = None : only `MetaProperties` will be loaded.\n        override: id-value pairs to override or add the corresponding config content.\n            e.g. 
``--net#input_chns 42``, ``--net %/data/other.json#net_arg``\n\n    \"\"\"\n\n    def __init__(\n        self,\n        config_file: str | Sequence[str],\n        meta_file: str | Sequence[str] | None = None,\n        logging_file: str | bool | None = None,\n        init_id: str = \"initialize\",\n        run_id: str = \"run\",\n        final_id: str = \"finalize\",\n        tracking: str | dict | None = None,\n        workflow_type: str | None = \"train\",\n        properties_path: PathLike | None = None,\n        **override: Any,\n    ) -> None:\n        if config_file is not None:\n            _config_files = ensure_tuple(config_file)\n            config_root_path = Path(_config_files[0]).parent\n            for _config_file in _config_files:\n                _config_file = Path(_config_file)\n                if _config_file.parent != config_root_path:\n                    logger.warning(\n                        f\"Not all config files are in {config_root_path}. If logging_file and meta_file are\"\n                        f\"not specified, {config_root_path} will be used as the default config root directory.\"\n                    )\n                if not _config_file.is_file():\n                    raise FileNotFoundError(f\"Cannot find the config file: {_config_file}.\")\n        else:\n            config_root_path = Path(\"configs\")\n        meta_file = str(config_root_path / \"metadata.json\") if meta_file is None else meta_file\n        super().__init__(workflow_type=workflow_type, meta_file=meta_file, properties_path=properties_path)\n        self.config_root_path = config_root_path\n        logging_file = str(self.config_root_path / \"logging.conf\") if logging_file is None else logging_file\n        if logging_file is False:\n            logger.warning(f\"Logging file is set to {logging_file}, skipping logging.\")\n        else:\n            if not os.path.isfile(logging_file):\n                if logging_file == str(self.config_root_path / 
\"logging.conf\"):\n                    logger.warning(f\"Default logging file in {logging_file} does not exist, skipping logging.\")\n                else:\n                    raise FileNotFoundError(f\"Cannot find the logging config file: {logging_file}.\")\n            else:\n                fileConfig(str(logging_file), disable_existing_loggers=False)\n                logger.info(f\"Setting logging properties based on config: {logging_file}.\")\n\n        self.parser = ConfigParser()\n        self.parser.read_config(f=config_file)\n        if self.meta_file is not None:\n            self.parser.read_meta(f=self.meta_file)\n        # the rest key-values in the _args are to override config content\n        self.parser.update(pairs=override)\n        self.init_id = init_id\n        self.run_id = run_id\n        self.final_id = final_id\n        # set tracking configs for experiment management\n        if tracking is not None:\n            if isinstance(tracking, str) and tracking in DEFAULT_EXP_MGMT_SETTINGS:\n                settings_ = DEFAULT_EXP_MGMT_SETTINGS[tracking]\n            else:\n                settings_ = ConfigParser.load_config_files(tracking)\n            self.patch_bundle_tracking(parser=self.parser, settings=settings_)\n        self._is_initialized: bool = False\n\n    def initialize(self) -> Any:\n        \"\"\"\n        Initialize the bundle workflow before running.\n\n        \"\"\"\n        # reset the \"reference_resolver\" buffer at initialization stage\n        self.parser.parse(reset=True)\n        self._is_initialized = True\n        return self._run_expr(id=self.init_id)\n\n    def run(self) -> Any:\n        \"\"\"\n        Run the bundle workflow, it can be a training, evaluation or inference.\n        Before run, we add bundle root directory to Python search directories automatically.\n\n        \"\"\"\n        _bundle_root_path = (\n            self.config_root_path.parent if self.config_root_path.name == \"configs\" else 
self.config_root_path\n        )\n        sys.path.insert(1, str(_bundle_root_path))\n        if self.run_id not in self.parser:\n            raise ValueError(f\"run ID '{self.run_id}' doesn't exist in the config file.\")\n        return self._run_expr(id=self.run_id)\n\n    def finalize(self) -> Any:\n        \"\"\"\n        Finalize step after the running of bundle workflow.\n\n        \"\"\"\n        return self._run_expr(id=self.final_id)\n\n    def check_properties(self) -> list[str] | None:\n        \"\"\"\n        Check whether the required properties are existing in the bundle workflow.\n        If the optional properties have reference in the config, will also check whether the properties are existing.\n        If no workflow type specified, return None, otherwise, return a list of required but missing properties.\n\n        \"\"\"\n        ret = super().check_properties()\n        if self.properties is None:\n            logger.warning(\"No available properties had been set, skipping check.\")\n            return None\n        if ret:\n            logger.warning(f\"Loaded bundle does not contain the following required properties: {ret}\")\n        # also check whether the optional properties use correct ID name if existing\n        wrong_props = []\n        for n, p in self.properties.items():\n            if not p.get(BundleProperty.REQUIRED, False) and not self._check_optional_id(name=n, property=p):\n                wrong_props.append(n)\n        if wrong_props:\n            logger.warning(f\"Loaded bundle defines the following optional properties with wrong ID: {wrong_props}\")\n        if ret is not None:\n            ret.extend(wrong_props)\n        return ret\n\n    def _run_expr(self, id: str, **kwargs: dict) -> list[Any]:\n        \"\"\"\n        Evaluate the expression or expression list given by `id`. The resolved values from the evaluations are not stored,\n        allowing this to be evaluated repeatedly (eg. 
in streaming applications) without restarting the hosting process.\n        \"\"\"\n        ret = []\n        if id in self.parser:\n            # suppose all the expressions are in a list, run and reset the expressions\n            if isinstance(self.parser[id], list):\n                for i in range(len(self.parser[id])):\n                    sub_id = f\"{id}{ID_SEP_KEY}{i}\"\n                    ret.append(self.parser.get_parsed_content(sub_id, **kwargs))\n                    self.parser.ref_resolver.remove_resolved_content(sub_id)\n            else:\n                ret.append(self.parser.get_parsed_content(id, **kwargs))\n                self.parser.ref_resolver.remove_resolved_content(id)\n        return ret\n\n    def _get_prop_id(self, name: str, property: dict) -> Any:\n        prop_id = property[BundlePropertyConfig.ID]\n        if prop_id not in self.parser:\n            if not property.get(BundleProperty.REQUIRED, False):\n                return None\n            else:\n                raise KeyError(f\"Property '{name}' with config ID '{prop_id}' not in the config.\")\n        return prop_id\n\n    def _get_property(self, name: str, property: dict) -> Any:\n        \"\"\"\n        With specified property name and information, get the parsed property value from config.\n\n        Args:\n            name: the name of target property.\n            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.\n\n        \"\"\"\n        if not self._is_initialized:\n            raise RuntimeError(\"Please execute 'initialize' before getting any parsed content.\")\n        prop_id = self._get_prop_id(name, property)\n        return self.parser.get_parsed_content(id=prop_id) if prop_id is not None else None\n\n    def _set_property(self, name: str, property: dict, value: Any) -> None:\n        \"\"\"\n        With specified property name and information, set value for the expected property.\n\n        Args:\n          
  name: the name of target property.\n            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.\n            value: value to set for the property.\n\n        \"\"\"\n        prop_id = self._get_prop_id(name, property)\n        if prop_id is not None:\n            self.parser[prop_id] = value\n            # must parse the config again after changing the content\n            self._is_initialized = False\n            self.parser.ref_resolver.reset()\n\n    def add_property(  # type: ignore[override]\n        self, name: str, required: str, config_id: str, desc: str | None = None\n    ) -> None:\n        \"\"\"\n        Besides the default predefined properties, some 3rd party applications may need the bundle\n        definition to provide additional properties for the specific use cases, if the bundle can't\n        provide the property, means it can't work with the application.\n        This utility adds the property for the application requirements check and access.\n\n        Args:\n            name: the name of target property.\n            required: whether the property is \"must-have\".\n            config_id: the config ID of target property in the bundle definition.\n            desc: descriptions for the property.\n\n        \"\"\"\n        super().add_property(name=name, required=required, desc=desc)\n        self.properties[name][BundlePropertyConfig.ID] = config_id\n\n    def _check_optional_id(self, name: str, property: dict) -> bool:\n        \"\"\"\n        If an optional property has reference in the config, check whether the property is existing.\n        If `ValidationHandler` is defined for a training workflow, will check whether the optional properties\n        \"evaluator\" and \"val_interval\" are existing.\n\n        Args:\n            name: the name of target property.\n            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.\n\n 
       \"\"\"\n        id = property.get(BundlePropertyConfig.ID, None)\n        ref_id = property.get(BundlePropertyConfig.REF_ID, None)\n        if ref_id is None:\n            # no ID of reference config item, skipping check for this optional property\n            return True\n        # check validation `validator` and `interval` properties as the handler index of ValidationHandler is unknown\n        ref: str | None = None\n        if name in (\"evaluator\", \"val_interval\"):\n            if f\"train{ID_SEP_KEY}handlers\" in self.parser:\n                for h in self.parser[f\"train{ID_SEP_KEY}handlers\"]:\n                    if h[\"_target_\"] == \"ValidationHandler\":\n                        ref = h.get(ref_id, None)\n        else:\n            ref = self.parser.get(ref_id, None)\n        # for reference IDs that not refer to a property directly but using expressions, skip the check\n        if ref is not None and not ref.startswith(EXPR_KEY) and ref != ID_REF_KEY + id:\n            return False\n        return True\n\n    @staticmethod\n    def patch_bundle_tracking(parser: ConfigParser, settings: dict) -> None:\n        \"\"\"\n        Patch the loaded bundle config with a new handler logic to enable experiment tracking features.\n\n        Args:\n            parser: loaded config content to patch the handler.\n            settings: settings for the experiment tracking, should follow the pattern of default settings.\n\n        \"\"\"\n        for k, v in settings[\"configs\"].items():\n            if k in settings[\"handlers_id\"]:\n                engine = parser.get(settings[\"handlers_id\"][k][\"id\"])\n                if engine is not None:\n                    handlers = parser.get(settings[\"handlers_id\"][k][\"handlers\"])\n                    if handlers is None:\n                        engine[\"train_handlers\" if k == \"trainer\" else \"val_handlers\"] = [v]\n                    else:\n                        handlers.append(v)\n            
elif k not in parser:\n                parser[k] = v\n        # save the executed config into file\n        default_name = f\"config_{time.strftime('%Y%m%d_%H%M%S')}.json\"\n        # Users can set the `save_execute_config` to `False`, `/path/to/artifacts` or `True`.\n        # If set to False, nothing will be recorded. If set to True, the default path will be logged.\n        # If set to a file path, the given path will be logged.\n        filepath = parser.get(\"save_execute_config\", True)\n        if filepath:\n            if isinstance(filepath, bool):\n                if \"output_dir\" not in parser:\n                    # if no \"output_dir\" in the bundle config, default to \"<bundle root>/eval\"\n                    parser[\"output_dir\"] = f\"{EXPR_KEY}{ID_REF_KEY}bundle_root + '/eval'\"\n                # experiment management tools can refer to this config item to track the config info\n                parser[\"save_execute_config\"] = parser[\"output_dir\"] + f\" + '/{default_name}'\"\n                filepath = os.path.join(parser.get_parsed_content(\"output_dir\"), default_name)\n            Path(filepath).parent.mkdir(parents=True, exist_ok=True)\n            parser.export_config_file(parser.get(), filepath)\n        else:\n            parser[\"save_execute_config\"] = None\n"
  },
  {
    "path": "monai/config/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .deviceconfig import (\n    USE_COMPILED,\n    USE_META_DICT,\n    IgniteInfo,\n    get_config_values,\n    get_gpu_info,\n    get_optional_config_values,\n    get_system_info,\n    print_config,\n    print_debug_info,\n    print_gpu_info,\n    print_system_info,\n)\nfrom .type_definitions import (\n    DtypeLike,\n    IndexSelection,\n    KeysCollection,\n    NdarrayOrTensor,\n    NdarrayTensor,\n    PathLike,\n    SequenceStr,\n    TensorOrList,\n)\n"
  },
  {
    "path": "monai/config/deviceconfig.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport getpass\nimport os\nimport platform\nimport re\nimport sys\nfrom collections import OrderedDict\nfrom typing import TextIO\n\nimport numpy as np\nimport torch\n\nimport monai\nfrom monai.utils.deprecate_utils import deprecated\nfrom monai.utils.enums import IgniteInfo as _IgniteInfo\nfrom monai.utils.module import OptionalImportError, get_package_version, optional_import\n\ntry:\n    _, HAS_EXT = optional_import(\"monai._C\")\n    USE_COMPILED = HAS_EXT and os.getenv(\"BUILD_MONAI\", \"0\") == \"1\"\nexcept (OptionalImportError, ImportError, AttributeError):\n    HAS_EXT = USE_COMPILED = False\n\nUSE_META_DICT = os.environ.get(\"USE_META_DICT\", \"0\") == \"1\"  # set to True for compatibility, use meta dict.\n\npsutil, has_psutil = optional_import(\"psutil\")\npsutil_version = psutil.__version__ if has_psutil else \"NOT INSTALLED or UNKNOWN VERSION.\"\n\n__all__ = [\n    \"print_config\",\n    \"get_system_info\",\n    \"print_system_info\",\n    \"get_gpu_info\",\n    \"print_gpu_info\",\n    \"print_debug_info\",\n    \"USE_COMPILED\",\n    \"USE_META_DICT\",\n    \"IgniteInfo\",\n]\n\n\ndef get_config_values():\n    \"\"\"\n    Read the package versions into a dictionary.\n    \"\"\"\n    output = OrderedDict()\n\n    output[\"MONAI\"] = monai.__version__\n    output[\"Numpy\"] = np.version.full_version\n    
output[\"Pytorch\"] = torch.__version__\n\n    return output\n\n\ndef get_optional_config_values():\n    \"\"\"\n    Read the optional package versions into a dictionary.\n    \"\"\"\n    output = OrderedDict()\n\n    output[\"Pytorch Ignite\"] = get_package_version(\"ignite\")\n    output[\"ITK\"] = get_package_version(\"itk\")\n    output[\"Nibabel\"] = get_package_version(\"nibabel\")\n    output[\"scikit-image\"] = get_package_version(\"skimage\")\n    output[\"scipy\"] = get_package_version(\"scipy\")\n    output[\"Pillow\"] = get_package_version(\"PIL\")\n    output[\"Tensorboard\"] = get_package_version(\"tensorboard\")\n    output[\"gdown\"] = get_package_version(\"gdown\")\n    output[\"TorchVision\"] = get_package_version(\"torchvision\")\n    output[\"tqdm\"] = get_package_version(\"tqdm\")\n    output[\"lmdb\"] = get_package_version(\"lmdb\")\n    output[\"psutil\"] = psutil_version\n    output[\"pandas\"] = get_package_version(\"pandas\")\n    output[\"einops\"] = get_package_version(\"einops\")\n    output[\"transformers\"] = get_package_version(\"transformers\")\n    output[\"mlflow\"] = get_package_version(\"mlflow\")\n    output[\"pynrrd\"] = get_package_version(\"nrrd\")\n    output[\"clearml\"] = get_package_version(\"clearml\")\n\n    return output\n\n\ndef print_config(file=sys.stdout):\n    \"\"\"\n    Print the package versions to `file`.\n\n    Args:\n        file: `print()` text stream file. 
Defaults to `sys.stdout`.\n    \"\"\"\n    for k, v in get_config_values().items():\n        print(f\"{k} version: {v}\", file=file, flush=True)\n    print(f\"MONAI flags: HAS_EXT = {HAS_EXT}, USE_COMPILED = {USE_COMPILED}, USE_META_DICT = {USE_META_DICT}\")\n    print(f\"MONAI rev id: {monai.__revision_id__}\")\n    username = getpass.getuser()\n    masked_file_path = re.sub(username, \"<username>\", monai.__file__)\n    print(f\"MONAI __file__: {masked_file_path}\", file=file, flush=True)\n    print(\"\\nOptional dependencies:\", file=file, flush=True)\n    for k, v in get_optional_config_values().items():\n        print(f\"{k} version: {v}\", file=file, flush=True)\n    print(\"\\nFor details about installing the optional dependencies, please visit:\", file=file, flush=True)\n    print(\n        \"    https://monai.readthedocs.io/en/latest/installation.html#installing-the-recommended-dependencies\\n\",\n        file=file,\n        flush=True,\n    )\n\n\ndef _dict_append(in_dict, key, fn):\n    try:\n        in_dict[key] = fn() if callable(fn) else fn\n    except BaseException:\n        in_dict[key] = \"UNKNOWN for given OS\"\n\n\ndef get_system_info() -> OrderedDict:\n    \"\"\"\n    Get system info as an ordered dictionary.\n    \"\"\"\n    output: OrderedDict = OrderedDict()\n\n    _dict_append(output, \"System\", platform.system)\n    if output[\"System\"] == \"Windows\":\n        _dict_append(output, \"Win32 version\", platform.win32_ver)\n        if hasattr(platform, \"win32_edition\"):\n            _dict_append(output, \"Win32 edition\", platform.win32_edition)\n\n    elif output[\"System\"] == \"Darwin\":\n        _dict_append(output, \"Mac version\", lambda: platform.mac_ver()[0])\n    else:\n        with open(\"/etc/os-release\") as rel_f:\n            linux_ver = re.search(r'PRETTY_NAME=\"(.*)\"', rel_f.read())\n        if linux_ver:\n            _dict_append(output, \"Linux version\", lambda: linux_ver.group(1))\n\n    _dict_append(output, 
\"Platform\", platform.platform)\n    _dict_append(output, \"Processor\", platform.processor)\n    _dict_append(output, \"Machine\", platform.machine)\n    _dict_append(output, \"Python version\", platform.python_version)\n\n    if not has_psutil:\n        _dict_append(output, \"`psutil` missing\", lambda: \"run `pip install monai[psutil]`\")\n    else:\n        p = psutil.Process()\n        with p.oneshot():\n            _dict_append(output, \"Process name\", p.name)\n            _dict_append(output, \"Command\", p.cmdline)\n            _dict_append(output, \"Open files\", p.open_files)\n            _dict_append(output, \"Num physical CPUs\", lambda: psutil.cpu_count(logical=False))\n            _dict_append(output, \"Num logical CPUs\", lambda: psutil.cpu_count(logical=True))\n            _dict_append(output, \"Num usable CPUs\", lambda: len(psutil.Process().cpu_affinity()))\n            _dict_append(output, \"CPU usage (%)\", lambda: psutil.cpu_percent(percpu=True))\n            _dict_append(output, \"CPU freq. (MHz)\", lambda: round(psutil.cpu_freq(percpu=False)[0]))\n            _dict_append(\n                output,\n                \"Load avg. in last 1, 5, 15 mins (%)\",\n                lambda: [round(x / psutil.cpu_count() * 100, 1) for x in psutil.getloadavg()],\n            )\n            _dict_append(output, \"Disk usage (%)\", lambda: psutil.disk_usage(os.getcwd()).percent)\n            _dict_append(\n                output,\n                \"Avg. sensor temp. 
(Celsius)\",\n                lambda: np.round(\n                    np.mean([item.current for sublist in psutil.sensors_temperatures().values() for item in sublist], 1)\n                ),\n            )\n            mem = psutil.virtual_memory()\n            _dict_append(output, \"Total physical memory (GB)\", lambda: round(mem.total / 1024**3, 1))\n            _dict_append(output, \"Available memory (GB)\", lambda: round(mem.available / 1024**3, 1))\n            _dict_append(output, \"Used memory (GB)\", lambda: round(mem.used / 1024**3, 1))\n\n    return output\n\n\ndef print_system_info(file: TextIO = sys.stdout) -> None:\n    \"\"\"\n    Print system info to `file`. Requires the optional library, `psutil`.\n\n    Args:\n        file: `print()` text stream file. Defaults to `sys.stdout`.\n    \"\"\"\n    if not has_psutil:\n        print(\"`psutil` required for `print_system_info`\", file=file, flush=True)\n    else:\n        for k, v in get_system_info().items():\n            print(f\"{k}: {v}\", file=file, flush=True)\n\n\ndef get_gpu_info() -> OrderedDict:\n    output: OrderedDict = OrderedDict()\n\n    num_gpus = torch.cuda.device_count()\n    _dict_append(output, \"Num GPUs\", lambda: num_gpus)\n\n    _dict_append(output, \"Has CUDA\", lambda: bool(torch.cuda.is_available()))\n\n    if output[\"Has CUDA\"]:\n        _dict_append(output, \"CUDA version\", lambda: torch.version.cuda)\n    cudnn_ver = torch.backends.cudnn.version()\n    _dict_append(output, \"cuDNN enabled\", lambda: bool(cudnn_ver))\n    _dict_append(output, \"NVIDIA_TF32_OVERRIDE\", os.environ.get(\"NVIDIA_TF32_OVERRIDE\"))\n    _dict_append(output, \"TORCH_ALLOW_TF32_CUBLAS_OVERRIDE\", os.environ.get(\"TORCH_ALLOW_TF32_CUBLAS_OVERRIDE\"))\n\n    if cudnn_ver:\n        _dict_append(output, \"cuDNN version\", lambda: cudnn_ver)\n\n    if num_gpus > 0:\n        _dict_append(output, \"Current device\", torch.cuda.current_device)\n        _dict_append(output, \"Library compiled for CUDA 
architectures\", torch.cuda.get_arch_list)\n\n    for gpu in range(num_gpus):\n        gpu_info = torch.cuda.get_device_properties(gpu)\n        _dict_append(output, f\"GPU {gpu} Name\", gpu_info.name)\n        _dict_append(output, f\"GPU {gpu} Is integrated\", bool(gpu_info.is_integrated))\n        _dict_append(output, f\"GPU {gpu} Is multi GPU board\", bool(gpu_info.is_multi_gpu_board))\n        _dict_append(output, f\"GPU {gpu} Multi processor count\", gpu_info.multi_processor_count)\n        _dict_append(output, f\"GPU {gpu} Total memory (GB)\", round(gpu_info.total_memory / 1024**3, 1))\n        _dict_append(output, f\"GPU {gpu} CUDA capability (maj.min)\", f\"{gpu_info.major}.{gpu_info.minor}\")\n\n    return output\n\n\ndef print_gpu_info(file: TextIO = sys.stdout) -> None:\n    \"\"\"\n    Print GPU info to `file`.\n\n    Args:\n        file: `print()` text stream file. Defaults to `sys.stdout`.\n    \"\"\"\n    for k, v in get_gpu_info().items():\n        print(f\"{k}: {v}\", file=file, flush=True)\n\n\ndef print_debug_info(file: TextIO = sys.stdout) -> None:\n    \"\"\"\n    Print config (installed dependencies, etc.) and system info for debugging.\n\n    Args:\n        file: `print()` text stream file. 
Defaults to `sys.stdout`.\n    \"\"\"\n    print(\"================================\", file=file, flush=True)\n    print(\"Printing MONAI config...\", file=file, flush=True)\n    print(\"================================\", file=file, flush=True)\n    print_config(file)\n    print(\"\\n================================\", file=file, flush=True)\n    print(\"Printing system config...\")\n    print(\"================================\", file=file, flush=True)\n    print_system_info(file)\n    print(\"\\n================================\", file=file, flush=True)\n    print(\"Printing GPU config...\")\n    print(\"================================\", file=file, flush=True)\n    print_gpu_info(file)\n\n\n@deprecated(since=\"1.4.0\", removed=\"1.6.0\", msg_suffix=\"Please use `monai.utils.enums.IgniteInfo` instead.\")\nclass IgniteInfo:\n    \"\"\"Deprecated Import of IgniteInfo enum, which was moved to `monai.utils.enums.IgniteInfo`.\"\"\"\n\n    OPT_IMPORT_VERSION = _IgniteInfo.OPT_IMPORT_VERSION\n\n\nif __name__ == \"__main__\":\n    print_debug_info()\n"
  },
  {
    "path": "monai/config/type_definitions.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nfrom collections.abc import Collection, Hashable, Iterable, Sequence\nfrom typing import TypeVar, Union\n\nimport numpy as np\nimport torch\n\n# Commonly used concepts\n# This module provides naming and type specifications for commonly used concepts\n# within the MONAI package. The intent is to explicitly identify information\n# that should be used consistently throughout the entire MONAI package.\n#\n# A type would be named as type_definitions.KeysCollection\n# which includes a meaningful name for the consent in the name itself. The\n# definitions in this file map context meaningful names to the underlying\n# object properties that define the expected API.\n#\n# A conceptual type is represented by a new type name but is also one which\n# can be different depending on an environment (i.e. differences for python 3.6 vs 3.9\n# may be implemented). Consistent use of the concept and recorded documentation of\n# the rationale and convention behind it lowers the learning curve for new\n# developers. 
For readability, short names are preferred.\n__all__ = [\n    \"KeysCollection\",\n    \"IndexSelection\",\n    \"DtypeLike\",\n    \"NdarrayTensor\",\n    \"NdarrayOrTensor\",\n    \"TensorOrList\",\n    \"PathLike\",\n    \"SequenceStr\",\n]\n\n#: KeysCollection\n#\n# The KeysCollection type is used for defining variables\n# that store a subset of keys to select items from a dictionary.\n# The container of keys must contain hashable elements.\n# NOTE:  `Hashable` is not a collection, but is provided as a\n#        convenience to end-users.  All supplied values will be\n#        internally converted to a tuple of `Hashable`'s before\n#        use\nKeysCollection = Union[Collection[Hashable], Hashable]\n\n#: IndexSelection\n#\n# The IndexSelection type is used for defining variables\n# that store a subset of indices to select items from a list- or array-like object.\n# The indices must be integers, and if a container of indices is specified, the\n# container must be iterable.\nIndexSelection = Union[Iterable[int], int]\n\n#: Type of datatypes: Adapted from https://github.com/numpy/numpy/blob/v1.21.4/numpy/typing/_dtype_like.py#L121\nDtypeLike = Union[np.dtype, type, str, None]\n\n#: NdarrayOrTensor: Union of numpy.ndarray and torch.Tensor to be used for typing\nNdarrayOrTensor = Union[np.ndarray, torch.Tensor]\n\n#: NdarrayTensor\n#\n# Generic type which can represent either a numpy.ndarray or a torch.Tensor\n# Unlike Union can create a dependence between parameter(s) / return(s)\nNdarrayTensor = TypeVar(\"NdarrayTensor\", bound=NdarrayOrTensor)\n\n#: TensorOrList: The TensorOrList type is used for defining `batch-first Tensor` or `list of channel-first Tensor`.\nTensorOrList = Union[torch.Tensor, Sequence[torch.Tensor]]\n\n#: PathLike: The PathLike type is used for defining a file path.\nPathLike = Union[str, os.PathLike]\n\n#: SequenceStr\n# string or a sequence of strings for `mode` types.\nSequenceStr = Union[Sequence[str], str]\n"
  },
  {
    "path": "monai/csrc/ext.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <torch/extension.h>\n\n#include \"filtering/filtering.h\"\n#include \"lltm/lltm.h\"\n#include \"resample/pushpull.h\"\n#include \"utils/resample_utils.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  // filtering\n  m.def(\"bilateral_filter\", &BilateralFilter, \"Bilateral Filter\");\n  m.def(\"phl_filter\", &PermutohedralFilter, \"Permutohedral Filter\");\n  m.def(\"tbf_forward\", &TrainableBilateralFilterForward, \"Trainable Bilateral Filter Forward\");\n  m.def(\"tbf_backward\", &TrainableBilateralFilterBackward, \"Trainable Bilateral Filter Backward\");\n  m.def(\"tjbf_forward\", &TrainableJointBilateralFilterForward, \"Trainable Joint Bilateral Filter Forward\");\n  m.def(\"tjbf_backward\", &TrainableJointBilateralFilterBackward, \"Trainable Joint Bilateral Filter Backward\");\n\n  // lltm\n  m.def(\"lltm_forward\", &lltm_forward, \"LLTM forward\");\n  m.def(\"lltm_backward\", &lltm_backward, \"LLTM backward\");\n\n  // resample bound mode\n  py::enum_<monai::BoundType>(m, \"BoundType\")\n      .value(\"replicate\", monai::BoundType::Replicate, \"a a a | a b c d | d d d\")\n      .value(\"nearest\", monai::BoundType::Replicate, \"a a a | a b c d | d d d\")\n      .value(\"border\", monai::BoundType::Replicate, \"a a a | a b c d | d d d\")\n      .value(\"dct1\", monai::BoundType::DCT1, \"d c b | a b c d | c b a\")\n      .value(\"mirror\", monai::BoundType::DCT1, 
\"d c b | a b c d | c b a\")\n      .value(\"dct2\", monai::BoundType::DCT2, \"c b a | a b c d | d c b\")\n      .value(\"reflect\", monai::BoundType::DCT2, \"c b a | a b c d | d c b\")\n      .value(\"dst1\", monai::BoundType::DST1, \"-b -a 0 | a b c d | 0 -d -c\")\n      .value(\"antimirror\", monai::BoundType::DST1, \"-b -a 0 | a b c d | 0 -d -c\")\n      .value(\"dst2\", monai::BoundType::DST2, \"-c -b -a | a b c d | -d -c -b\")\n      .value(\"antireflect\", monai::BoundType::DST2, \"-c -b -a | a b c d | -d -c -b\")\n      .value(\"dft\", monai::BoundType::DFT, \"b c d | a b c d | a b c\")\n      .value(\"wrap\", monai::BoundType::DFT, \"b c d | a b c d | a b c\")\n      //   .value(\"sliding\", monai::BoundType::Sliding)\n      .value(\"zero\", monai::BoundType::Zero, \"0 0 0 | a b c d | 0 0 0\")\n      .value(\"zeros\", monai::BoundType::Zero, \"0 0 0 | a b c d | 0 0 0\")\n      .export_values();\n\n  // resample interpolation mode\n  py::enum_<monai::InterpolationType>(m, \"InterpolationType\")\n      .value(\"nearest\", monai::InterpolationType::Nearest)\n      .value(\"linear\", monai::InterpolationType::Linear)\n      .value(\"quadratic\", monai::InterpolationType::Quadratic)\n      .value(\"cubic\", monai::InterpolationType::Cubic)\n      .value(\"fourth\", monai::InterpolationType::FourthOrder)\n      .value(\"fifth\", monai::InterpolationType::FifthOrder)\n      .value(\"sixth\", monai::InterpolationType::SixthOrder)\n      .value(\"seventh\", monai::InterpolationType::SeventhOrder)\n      .export_values();\n\n  // resample\n  m.def(\"grid_pull\", &monai::grid_pull, \"GridPull\");\n  m.def(\"grid_pull_backward\", &monai::grid_pull_backward, \"GridPull backward\");\n  m.def(\"grid_push\", &monai::grid_push, \"GridPush\");\n  m.def(\"grid_push_backward\", &monai::grid_push_backward, \"GridPush backward\");\n  m.def(\"grid_count\", &monai::grid_count, \"GridCount\");\n  m.def(\"grid_count_backward\", &monai::grid_count_backward, \"GridCount 
backward\");\n  m.def(\"grid_grad\", &monai::grid_grad, \"GridGrad\");\n  m.def(\"grid_grad_backward\", &monai::grid_grad_backward, \"GridGrad backward\");\n}\n"
  },
  {
    "path": "monai/csrc/filtering/bilateral/bilateral.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <torch/extension.h>\n#include <stdexcept>\n#include <string>\n\n#include \"bilateral.h\"\n#include \"utils/common_utils.h\"\n\ntorch::Tensor BilateralFilter(torch::Tensor input, float spatial_sigma, float color_sigma, bool usePHL) {\n  torch::Tensor (*filterFunction)(torch::Tensor, float, float);\n\n#ifdef WITH_CUDA\n\n  if (torch::cuda::is_available() && input.is_cuda()) {\n    CHECK_CONTIGUOUS_CUDA(input);\n\n    if (input.size(1) > BF_CUDA_MAX_CHANNELS) {\n      throw std::runtime_error(\n          \"Bilateral filtering not implemented for channel count > \" + std::to_string(BF_CUDA_MAX_CHANNELS));\n    }\n\n    if (input.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) {\n      throw std::runtime_error(\n          \"Bilateral filtering not implemented for spatial dimension > \" +\n          std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION));\n    }\n\n    filterFunction = usePHL ? &BilateralFilterPHLCuda : &BilateralFilterCuda;\n  } else {\n    filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu;\n  }\n#else\n  filterFunction = usePHL ? &BilateralFilterPHLCpu : &BilateralFilterCpu;\n#endif\n\n  return filterFunction(input, spatial_sigma, color_sigma);\n}\n"
  },
  {
    "path": "monai/csrc/filtering/bilateral/bilateral.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#pragma once\n\n#include <torch/extension.h>\n\n#define BF_CUDA_MAX_CHANNELS 16\n#define BF_CUDA_MAX_SPATIAL_DIMENSION 3\n\ntorch::Tensor BilateralFilterCpu(torch::Tensor input, float spatial_sigma, float color_sigma);\ntorch::Tensor BilateralFilterPHLCpu(torch::Tensor input, float spatial_sigma, float color_sigma);\n\n#ifdef WITH_CUDA\ntorch::Tensor BilateralFilterCuda(torch::Tensor input, float spatial_sigma, float color_sigma);\ntorch::Tensor BilateralFilterPHLCuda(torch::Tensor input, float spatial_sigma, float color_sigma);\n#endif\n\ntorch::Tensor BilateralFilter(torch::Tensor input, float spatial_sigma, float color_sigma, bool usePHL);\n"
  },
  {
    "path": "monai/csrc/filtering/bilateral/bilateralfilter_cpu.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <math.h>\n#include <torch/extension.h>\n\n#include \"utils/tensor_description.h\"\n#include \"utils/tensor_indexing.h\"\n\ntemplate <typename scalar_t>\nvoid BilateralFilterCpu(torch::Tensor inputTensor, torch::Tensor outputTensor, float spatialSigma, float colorSigma) {\n  // Getting tensor description.\n  TensorDescription desc = TensorDescription(inputTensor);\n\n  // Raw tensor data pointers.\n  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();\n  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();\n\n  // Pre-calculate common values\n  int windowSize = (int)ceil(5.0f * spatialSigma) | 1; // ORing last bit to ensure odd window size\n  int halfWindowSize = floor(0.5f * windowSize);\n  scalar_t spatialExpConstant = -1.0f / (2 * spatialSigma * spatialSigma);\n  scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);\n\n  // Kernel sizes.\n  int* kernelSizes = new int[desc.dimensions];\n\n  for (int i = 0; i < desc.dimensions; i++) {\n    kernelSizes[i] = windowSize;\n  }\n\n  // Pre-calculate gaussian kernel in 1D.\n  scalar_t* gaussianKernel = new scalar_t[windowSize];\n\n  for (int i = 0; i < windowSize; i++) {\n    int distance = i - halfWindowSize;\n    gaussianKernel[i] = exp(distance * distance * spatialExpConstant);\n  }\n\n  // Kernel aggregates used to calculate\n  // the output value.\n  scalar_t* valueSum = new 
scalar_t[desc.channelCount];\n  scalar_t weightSum = 0;\n\n  // Looping over the batches\n  for (int b = 0; b < desc.batchCount; b++) {\n    int batchOffset = b * desc.batchStride;\n\n    // Looping over all dimensions for the home element\n    Indexer homeIndex = Indexer(desc.dimensions, desc.sizes);\n    do // while(homeIndex++)\n    {\n      // Calculating indexing offset for the home element\n      int homeOffset = batchOffset;\n\n      for (int i = 0; i < desc.dimensions; i++) {\n        homeOffset += homeIndex[i] * desc.strides[i];\n      }\n\n      // Zero kernel aggregates.\n      for (int i = 0; i < desc.channelCount; i++) {\n        valueSum[i] = 0;\n      }\n\n      weightSum = 0.0f;\n\n      // Looping over all dimensions for the neighbour element\n      Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes);\n      do // while(kernelIndex++)\n      {\n        // Calculating buffer offset for the neighbour element\n        // Index is clamped to the border in each dimension.\n        int neighbourOffset = batchOffset;\n\n        for (int i = 0; i < desc.dimensions; i++) {\n          int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize;\n          int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex));\n          neighbourOffset += neighbourIndexClamped * desc.strides[i];\n        }\n\n        // Euclidean color distance.\n        scalar_t colorDistanceSquared = 0;\n\n        for (int i = 0; i < desc.channelCount; i++) {\n          scalar_t diff = inputTensorData[homeOffset + i * desc.channelStride] -\n              inputTensorData[neighbourOffset + i * desc.channelStride];\n          colorDistanceSquared += diff * diff;\n        }\n\n        // Calculating and combining the spatial\n        // and color weights.\n        scalar_t spatialWeight = 1;\n\n        for (int i = 0; i < desc.dimensions; i++) {\n          spatialWeight *= gaussianKernel[kernelIndex[i]];\n        }\n\n        scalar_t colorWeight 
= exp(colorDistanceSquared * colorExpConstant);\n        scalar_t totalWeight = spatialWeight * colorWeight;\n\n        // Aggregating values.\n        for (int i = 0; i < desc.channelCount; i++) {\n          valueSum[i] += inputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight;\n        }\n\n        weightSum += totalWeight;\n      } while (kernelIndex++);\n\n      for (int i = 0; i < desc.channelCount; i++) {\n        outputTensorData[homeOffset + i * desc.channelStride] = valueSum[i] / weightSum;\n      }\n    } while (homeIndex++);\n  }\n}\n\ntorch::Tensor BilateralFilterCpu(torch::Tensor inputTensor, float spatialSigma, float colorSigma) {\n  // Preparing output tensor.\n  torch::Tensor outputTensor = torch::zeros_like(inputTensor);\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), \"BilateralFilterCpu\", ([&] {\n                                        BilateralFilterCpu<scalar_t>(\n                                            inputTensor, outputTensor, spatialSigma, colorSigma);\n                                      }));\n\n  return outputTensor;\n}\n"
  },
  {
    "path": "monai/csrc/filtering/bilateral/bilateralfilter_cpu_phl.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <torch/extension.h>\n\n#include \"filtering/permutohedral/permutohedral.h\"\n#include \"utils/tensor_description.h\"\n\ntemplate <typename scalar_t>\nvoid BilateralFilterPHLCpu(\n    torch::Tensor inputTensor,\n    torch::Tensor outputTensor,\n    float spatialSigma,\n    float colorSigma) {\n  // Getting tensor description.\n  TensorDescription desc = TensorDescription(inputTensor);\n\n  int featureChannels = desc.channelCount + desc.dimensions;\n\n  // Preparing memory\n  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();\n  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();\n  scalar_t* data = new scalar_t[desc.channelStride * desc.channelCount];\n  scalar_t* features = new scalar_t[desc.channelStride * featureChannels];\n\n  // Precalculating inverse sigmas\n  float invSpatialSigma = 1.0f / spatialSigma;\n  float invColorSigma = 1.0f / colorSigma;\n\n  // Looping over batches\n  for (int b = 0; b < desc.batchCount; b++) {\n    int batchOffset = b * desc.batchStride;\n\n    // Creating features (also permuting input data to be channel last. 
Permutohedral\n    // implementation should be changed to channel first to avoid this)\n    for (int i = 0; i < desc.channelStride; i++) {\n      // Color features (and permutation)\n      for (int c = 0; c < desc.channelCount; c++) {\n        features[i * featureChannels + c] = invColorSigma * inputTensorData[batchOffset + i + c * desc.channelStride];\n        data[i * desc.channelCount + c] = inputTensorData[batchOffset + i + c * desc.channelStride];\n      }\n\n      // Spatial features\n      int offsetRemainder = i;\n\n      for (int d = 0; d < desc.dimensions; d++) {\n        int coord = offsetRemainder / desc.strides[d];\n        offsetRemainder -= coord * desc.strides[d];\n\n        features[i * featureChannels + desc.channelCount + d] = (scalar_t)invSpatialSigma * coord;\n      }\n    }\n\n    // Filtering data with respect to the features.\n    PermutohedralCPU<scalar_t>(data, features, desc.channelCount, featureChannels, desc.channelStride);\n\n    // Writing output tensor.\n    for (int i = 0; i < desc.channelStride; i++) {\n      for (int c = 0; c < desc.channelCount; c++) {\n        outputTensorData[batchOffset + i + c * desc.channelStride] = data[i * desc.channelCount + c];\n      }\n    }\n  }\n\n  delete[] data;\n  delete[] features;\n}\n\n// Function to choose template implementation based on dynamic, channels and dimensions\ntorch::Tensor BilateralFilterPHLCpu(torch::Tensor inputTensor, float spatialSigma, float colorSigma) {\n  torch::Tensor outputTensor = torch::zeros_like(inputTensor);\n\n  AT_DISPATCH_FLOATING_TYPES(inputTensor.scalar_type(), \"BilateralFilterPhlCpu\", ([&] {\n                               BilateralFilterPHLCpu<scalar_t>(inputTensor, outputTensor, spatialSigma, colorSigma);\n                             }));\n\n  return outputTensor;\n}\n"
  },
  {
    "path": "monai/csrc/filtering/bilateral/bilateralfilter_cuda.cu",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\n#include \"bilateral.h\"\n#include \"utils/meta_macros.h\"\n#include \"utils/tensor_description.h\"\n\n__constant__ int cBatchStride;\n__constant__ int cColorStride;\n\n__constant__ int cSizes[3];\n__constant__ int cStrides[3];\n\n__constant__ int cKernelSize;\n__constant__ float cKernel[256];\n\n__constant__ float cColorExponentFactor;\n\ntemplate <typename scalar_t, int C>\n__global__ void BilateralFilterCudaKernel1D(scalar_t* input, scalar_t* output) {\n  int kernelHalfSize = cKernelSize / 2;\n\n  int homeOffset = blockIdx.x * blockDim.x + threadIdx.x;\n  int batchOffset = blockIdx.y * cBatchStride;\n\n  if (homeOffset >= cColorStride)\n    return;\n\n  scalar_t weightSum = 0;\n\n  for (int kernelOffset = 0; kernelOffset < cKernelSize; kernelOffset++) {\n    int neighbourOffset = max(0, min(homeOffset + (kernelOffset - kernelHalfSize), cSizes[0] - 1));\n    scalar_t gaussian = cKernel[kernelOffset];\n\n    scalar_t distanceSquared = 0;\n\n#pragma unroll\n    for (int c = 0; c < C; c++) {\n      scalar_t a = input[batchOffset + homeOffset + c * cColorStride];\n      scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride];\n      scalar_t diff = a - b;\n      distanceSquared += diff * diff;\n    }\n\n    scalar_t spatialWeight = gaussian;\n    scalar_t colorWeight = exp(cColorExponentFactor * 
distanceSquared);\n    scalar_t totalWeight = spatialWeight * colorWeight;\n\n#pragma unroll\n    for (int c = 0; c < C; c++) {\n      scalar_t a = input[batchOffset + neighbourOffset + c * cColorStride];\n\n      output[batchOffset + homeOffset + c * cColorStride] += a * totalWeight;\n    }\n\n    weightSum += totalWeight;\n  }\n\n#pragma unroll\n  for (int c = 0; c < C; c++) {\n    output[batchOffset + homeOffset + c * cColorStride] /= weightSum;\n  }\n}\n\ntemplate <typename scalar_t, int C>\n__global__ void BilateralFilterCudaKernel2D(scalar_t* input, scalar_t* output) {\n  int kernelHalfSize = cKernelSize / 2;\n\n  int homeOffset = blockIdx.x * blockDim.x + threadIdx.x;\n  int batchOffset = blockIdx.y * cBatchStride;\n\n  if (homeOffset >= cColorStride)\n    return;\n\n  int homeX = homeOffset / cStrides[0];\n  int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1];\n\n  scalar_t weightSum = 0;\n\n  for (int kernelX = 0; kernelX < cKernelSize; kernelX++) {\n    int neighbourX = max(0, min(homeX + (kernelX - kernelHalfSize), cSizes[0] - 1));\n    scalar_t gaussianX = cKernel[kernelX];\n\n    for (int kernelY = 0; kernelY < cKernelSize; kernelY++) {\n      int neighbourY = max(0, min(homeY + (kernelY - kernelHalfSize), cSizes[1] - 1));\n      scalar_t gaussianY = cKernel[kernelY];\n\n      int neighbourOffset = neighbourX * cStrides[0] + neighbourY;\n\n      scalar_t distanceSquared = 0;\n\n#pragma unroll\n      for (int c = 0; c < C; c++) {\n        scalar_t a = input[batchOffset + homeOffset + c * cColorStride];\n        scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride];\n        scalar_t diff = a - b;\n        distanceSquared += diff * diff;\n      }\n\n      scalar_t spatialWeight = gaussianX * gaussianY;\n      scalar_t colorWeight = exp(cColorExponentFactor * distanceSquared);\n      scalar_t totalWeight = spatialWeight * colorWeight;\n\n#pragma unroll\n      for (int c = 0; c < C; c++) {\n        scalar_t a = input[batchOffset + 
neighbourOffset + c * cColorStride];\n\n        output[batchOffset + homeOffset + c * cColorStride] += a * totalWeight;\n      }\n\n      weightSum += totalWeight;\n    }\n  }\n\n#pragma unroll\n  for (int c = 0; c < C; c++) {\n    output[batchOffset + homeOffset + c * cColorStride] /= weightSum;\n  }\n}\n\ntemplate <typename scalar_t, int C>\n__global__ void BilateralFilterCudaKernel3D(scalar_t* input, scalar_t* output) {\n  int kernelHalfSize = cKernelSize / 2;\n\n  int homeOffset = blockIdx.x * blockDim.x + threadIdx.x;\n  int batchOffset = blockIdx.y * cBatchStride;\n\n  if (homeOffset >= cColorStride)\n    return;\n\n  int homeX = homeOffset / cStrides[0];\n  int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1];\n  int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2];\n\n  scalar_t weightSum = 0;\n\n  for (int kernelX = 0; kernelX < cKernelSize; kernelX++) {\n    int neighbourX = max(0, min(homeX + (kernelX - kernelHalfSize), cSizes[0] - 1));\n    scalar_t gaussianX = cKernel[kernelX];\n\n    for (int kernelY = 0; kernelY < cKernelSize; kernelY++) {\n      int neighbourY = max(0, min(homeY + (kernelY - kernelHalfSize), cSizes[1] - 1));\n      scalar_t gaussianY = cKernel[kernelY];\n\n      for (int kernelZ = 0; kernelZ < cKernelSize; kernelZ++) {\n        int neighbourZ = max(0, min(homeZ + (kernelZ - kernelHalfSize), cSizes[2] - 1));\n        scalar_t gaussianZ = cKernel[kernelZ];\n\n        int neighbourOffset = neighbourX * cStrides[0] + neighbourY * cStrides[1] + neighbourZ;\n\n        scalar_t distanceSquared = 0;\n\n#pragma unroll\n        for (int c = 0; c < C; c++) {\n          scalar_t a = input[batchOffset + homeOffset + c * cColorStride];\n          scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride];\n          scalar_t diff = a - b;\n          distanceSquared += diff * diff;\n        }\n\n        scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ;\n        scalar_t colorWeight = 
exp(cColorExponentFactor * distanceSquared);\n        scalar_t totalWeight = spatialWeight * colorWeight;\n\n#pragma unroll\n        for (int c = 0; c < C; c++) {\n          scalar_t a = input[batchOffset + neighbourOffset + c * cColorStride];\n          output[batchOffset + homeOffset + c * cColorStride] += a * totalWeight;\n        }\n\n        weightSum += totalWeight;\n      }\n    }\n  }\n\n#pragma unroll\n  for (int c = 0; c < C; c++) {\n    output[batchOffset + homeOffset + c * cColorStride] /= weightSum;\n  }\n}\n\ntemplate <int C, int D>\nvoid BilateralFilterCuda(torch::Tensor inputTensor, torch::Tensor outputTensor, float spatialSigma, float colorSigma) {\n  // Getting tensor description.\n  TensorDescription desc = TensorDescription(inputTensor);\n\n  // Pre-calculating exponent factors.\n  float spatialExponentFactor = -1.0f / (2 * spatialSigma * spatialSigma);\n  float colorExponentFactor = -1.0f / (2 * colorSigma * colorSigma);\n\n  // Pre-calculating gaussian kernel.\n  int kernelSize = (int)ceil(5.0f * spatialSigma) | 1; // ORing last bit to ensure odd window size\n  int kernelHalfSize = floor(0.5f * kernelSize);\n  float* kernel = new float[kernelSize];\n\n  for (int i = 0; i < kernelSize; i++) {\n    int distance = i - kernelHalfSize;\n    kernel[i] = exp(distance * distance * spatialExponentFactor);\n  }\n\n  // Writing constant memory.\n  cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int));\n  cudaMemcpyToSymbol(cColorStride, &desc.channelStride, sizeof(int));\n  cudaMemcpyToSymbol(cSizes, desc.sizes, sizeof(int) * D);\n  cudaMemcpyToSymbol(cStrides, desc.strides, sizeof(int) * D);\n  cudaMemcpyToSymbol(cKernelSize, &kernelSize, sizeof(int));\n  cudaMemcpyToSymbol(cKernel, kernel, sizeof(float) * kernelSize);\n  cudaMemcpyToSymbol(cColorExponentFactor, &colorExponentFactor, sizeof(float));\n\n#define BLOCK_SIZE 32\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n      inputTensor.scalar_type(), \"BilateralFilterCudaKernel\", ([&] {\n      
  // Dispatch kernel. (Partial template function specialisation not supported at present so using this switch\n        // instead)\n        switch (D) {\n          case (1):\n            BilateralFilterCudaKernel1D<scalar_t, C>\n                <<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(\n                    inputTensor.data_ptr<scalar_t>(), outputTensor.data_ptr<scalar_t>());\n            break;\n          case (2):\n            BilateralFilterCudaKernel2D<scalar_t, C>\n                <<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(\n                    inputTensor.data_ptr<scalar_t>(), outputTensor.data_ptr<scalar_t>());\n            break;\n          case (3):\n            BilateralFilterCudaKernel3D<scalar_t, C>\n                <<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(\n                    inputTensor.data_ptr<scalar_t>(), outputTensor.data_ptr<scalar_t>());\n            break;\n        }\n      }));\n\n  delete[] kernel;\n}\n\n// Function to choose template implementation based on dynamic, channels and dimensions\ntorch::Tensor BilateralFilterCuda(torch::Tensor inputTensor, float spatialSigma, float colorSigma) {\n  torch::Tensor outputTensor = torch::zeros_like(inputTensor);\n\n#define CASE(c, d) BilateralFilterCuda<c, d>(inputTensor, outputTensor, spatialSigma, colorSigma);\n  SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2);\n\n  return outputTensor;\n}\n"
  },
  {
    "path": "monai/csrc/filtering/bilateral/bilateralfilter_cuda_phl.cu",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\n#include \"bilateral.h\"\n#include \"filtering/permutohedral/permutohedral.h\"\n#include \"utils/meta_macros.h\"\n#include \"utils/tensor_description.h\"\n\n__constant__ int cBatchStride;\n__constant__ int cChannelStride;\n__constant__ int cSpatialStrides[3];\n__constant__ float cInvSpatialSigma;\n__constant__ float cInvColorSigma;\n\ntemplate <typename scalar_t, int C, int D>\n__global__ void FeatureCreation(const scalar_t* inputTensor, scalar_t* outputData, scalar_t* outputFeatures) {\n  int elementIndex = blockIdx.x * blockDim.x + threadIdx.x;\n  int batchIndex = blockIdx.y;\n\n  if (elementIndex >= cChannelStride)\n    return;\n\n  int dataBatchOffset = batchIndex * cBatchStride;\n  int featureBatchOffset = batchIndex * (D + C) * cChannelStride;\n\n#pragma unroll\n  for (int i = 0; i < C; i++) {\n    outputData[dataBatchOffset + elementIndex * C + i] =\n        inputTensor[dataBatchOffset + elementIndex + i * cChannelStride];\n    outputFeatures[featureBatchOffset + elementIndex * (C + D) + i] =\n        inputTensor[dataBatchOffset + elementIndex + i * cChannelStride] * cInvColorSigma;\n  }\n\n  int remainder = elementIndex;\n\n#pragma unroll\n  for (int i = 0; i < D; i++) {\n    int coord = remainder / cSpatialStrides[i];\n    remainder -= coord * cSpatialStrides[i];\n\n    
outputFeatures[featureBatchOffset + elementIndex * (C + D) + C + i] = coord * cInvSpatialSigma;\n  }\n}\n\ntemplate <typename scalar_t, int C>\n__global__ void WriteOutput(const scalar_t* data, scalar_t* outputTensor) {\n  int elementIndex = blockIdx.x * blockDim.x + threadIdx.x;\n  int batchIndex = blockIdx.y;\n\n  if (elementIndex >= cChannelStride)\n    return;\n\n  int batchOffset = batchIndex * cBatchStride;\n\n#pragma unroll\n  for (int i = 0; i < C; i++) {\n    outputTensor[batchOffset + elementIndex + i * cChannelStride] = data[batchOffset + elementIndex * C + i];\n  }\n}\n\ntemplate <typename scalar_t, int C, int D>\nvoid BilateralFilterPHLCuda(\n    torch::Tensor inputTensor,\n    torch::Tensor outputTensor,\n    float spatialSigma,\n    float colorSigma) {\n  // Getting tensor description.\n  TensorDescription desc = TensorDescription(inputTensor);\n\n  int featureChannelCount = desc.channelCount + desc.dimensions;\n\n  // Pre calculating inverse sigmas.\n  float invSpatialSigma = 1.0f / spatialSigma;\n  float invColorSigma = 1.0f / colorSigma;\n\n  // Preparing global memory\n  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();\n  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();\n\n  scalar_t* data;\n  scalar_t* features;\n  cudaMalloc(&data, desc.batchCount * desc.channelStride * desc.channelCount * sizeof(scalar_t));\n  cudaMalloc(&features, desc.batchCount * desc.channelStride * featureChannelCount * sizeof(scalar_t));\n\n  // Preparing constant memory\n  cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int));\n  cudaMemcpyToSymbol(cChannelStride, &desc.channelStride, sizeof(int));\n  cudaMemcpyToSymbol(cSpatialStrides, desc.strides, sizeof(int) * desc.dimensions);\n  cudaMemcpyToSymbol(cInvSpatialSigma, &invSpatialSigma, sizeof(float));\n  cudaMemcpyToSymbol(cInvColorSigma, &invColorSigma, sizeof(float));\n\n#define BLOCK_SIZE 32\n\n  // Creating features\n  FeatureCreation<scalar_t, C, D>\n      
<<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(\n          inputTensorData, data, features);\n\n  // Filtering data with respect to the features for each sample in batch\n  for (int batchIndex = 0; batchIndex < desc.batchCount; batchIndex++) {\n    scalar_t* offsetData = data + batchIndex * desc.batchStride;\n    scalar_t* offsetFeatures = features + batchIndex * featureChannelCount * desc.channelStride;\n\n    PermutohedralCuda<scalar_t, C, C + D>(offsetData, offsetFeatures, desc.channelStride, true);\n  }\n\n  // Writing output\n  WriteOutput<scalar_t, C><<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(\n      data, outputTensorData);\n\n  cudaFree(data);\n  cudaFree(features);\n}\n\n// Function to choose template implementation based on dynamic, channels and dimensions\ntorch::Tensor BilateralFilterPHLCuda(torch::Tensor inputTensor, float spatialSigma, float colorSigma) {\n  torch::Tensor outputTensor = torch::zeros_like(inputTensor);\n\n#define CASE(c, d)                                                                       \\\n  AT_DISPATCH_FLOATING_TYPES(inputTensor.scalar_type(), \"BilateralFilterCudaPHL\", ([&] { \\\n                               BilateralFilterPHLCuda<scalar_t, c, d>(                   \\\n                                   inputTensor, outputTensor, spatialSigma, colorSigma); \\\n                             }));\n\n  SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2);\n\n  return outputTensor;\n}\n"
  },
  {
    "path": "monai/csrc/filtering/filtering.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#pragma once\n\n#include \"bilateral/bilateral.h\"\n#include \"permutohedral/permutohedral.h\"\n#include \"trainable_bilateral/trainable_bilateral.h\"\n#include \"trainable_joint_bilateral/trainable_joint_bilateral.h\"\n"
  },
  {
    "path": "monai/csrc/filtering/permutohedral/hash_table.cuh",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <torch/extension.h>\n\n//#define USE_ADDITIVE_HASH\n\n// turn this on if you want to get slightly less memory consumption and slightly longer run times.\n//#define LINEAR_D_MEMORY\n\n#define USE_CUSTOM_MODULO\n\n__device__ __constant__ signed short* table_keys;\n__device__ __constant__ int* table_entries;\n__device__ __constant__ unsigned int table_capacity;\n__device__ __constant__ signed short* table_zeros;\n__device__ __constant__ char* table_rank;\n\n/*************************************************************/\n/* Fast computation of modulo operator with constant divisor */\n/*************************************************************/\n__device__ __constant__ unsigned int __div_m;\n__device__ __constant__ unsigned int __div_l;\n__device__ __constant__ unsigned int __div_c;\n\n#ifdef USE_CUSTOM_MODULO\n__device__ inline unsigned int modHash(unsigned int n) {\n  unsigned int t1 = __umulhi(__div_m, n);\n  return n - ((t1 + ((n - t1) >> 1)) >> (__div_l - 1)) * __div_c;\n}\n\n#else\n#define modHash(n) ((n) % (2 * table_capacity));\n#endif\n\n/*************************************************************/\n/* End modulo                                                */\n/*************************************************************/\n\n__device__ __constant__ static unsigned int hOffset[64];\n\ntemplate <typename scalar_t, int kd, int vd>\nstatic scalar_t* 
createHashTable(int capacity) {\n  scalar_t* values;\n  cudaMalloc(&values, capacity * vd * sizeof(scalar_t));\n  cudaMemset(values, 0, capacity * vd * sizeof(scalar_t));\n\n  int* entries;\n  cudaMalloc(&entries, capacity * 2 * sizeof(int));\n  cudaMemset(entries, -1, capacity * 2 * sizeof(int));\n\n  cudaMemcpyToSymbol(table_capacity, &capacity, sizeof(int));\n\n  cudaMemcpyToSymbol(table_entries, &entries, sizeof(int*));\n\n#ifdef LINEAR_D_MEMORY\n\n  char* ranks;\n  cudaMalloc(&ranks, capacity * sizeof(char));\n\n  signed short* zeros;\n  cudaMalloc(&zeros, capacity * sizeof(signed short));\n\n  cudaMemcpyToSymbol(table_rank, &ranks, sizeof(char*));\n  cudaMemcpyToSymbol(table_zeros, &zeros, sizeof(char*));\n\n#else\n\n  signed short* keys;\n  cudaMalloc(&keys, capacity * kd * sizeof(signed short));\n  cudaMemset(keys, 0, capacity * kd * sizeof(signed short));\n\n  cudaMemcpyToSymbol(table_keys, &keys, sizeof(unsigned int*));\n\n#endif\n\n  return values;\n}\n\ntemplate <typename scalar_t>\nstatic void destroyHashTable() {\n#ifndef LINEAR_D_MEMORY\n  signed short* keys;\n  cudaMemcpyFromSymbol(&keys, table_keys, sizeof(unsigned int*));\n  cudaFree(keys);\n#endif\n\n  int* entries;\n  cudaMemcpyFromSymbol(&entries, table_entries, sizeof(int*));\n  cudaFree(entries);\n}\n\ntemplate <int kd>\n__device__ __host__ static unsigned int hash(signed short* key) {\n  unsigned int k = 0;\n  for (int i = 0; i < kd; i++) {\n    k += key[i];\n    k = k * 2531011;\n  }\n  return k;\n}\n\ntemplate <int kd>\n__device__ __host__ static unsigned int hash(int* key) {\n  unsigned int k = 0;\n  for (int i = 0; i < kd; i++) {\n    k += key[i];\n    k = k * 2531011;\n  }\n  return k;\n}\n\ntemplate <int d>\n__device__ static bool matchKey(int idx, signed short* key) {\n  bool match = true;\n  int slot = idx / (d + 1), color = idx - slot * (d + 1);\n  char* rank = table_rank + slot * (d + 1);\n  signed short* zero = table_zeros + slot * (d + 1);\n\n  for (int i = 0; i < d && match; 
i++) {\n    match = (key[i] == zero[i] + color - (rank[i] > d - color ? (d + 1) : 0));\n  }\n\n  return match;\n}\n\ntemplate <int d>\n__device__ static void generateKey(int idx, signed short* key) {\n  int slot = idx / (d + 1), color = idx - slot * (d + 1);\n  char* rank = table_rank + slot * (d + 1);\n  signed short* zero = table_zeros + slot * (d + 1);\n\n  for (int i = 0; i < d; i++) {\n    key[i] = zero[i] + color - (rank[i] > d - color ? (d + 1) : 0);\n  }\n}\n\ntemplate <int kd>\n__device__ static int hashTableInsert(unsigned int fh, signed short* key, unsigned int slot) {\n  int h = modHash(fh);\n  while (1) {\n    int* e = &table_entries[h];\n\n    // If the cell is empty (-1), lock it (-2)\n    int contents = atomicCAS(e, -1, -2);\n\n    if (contents == -2) {\n      // If it was locked already, move on to the next cell\n    } else if (contents == -1) {\n      // If it was empty, we successfully locked it. Write our key.\n\n#ifndef LINEAR_D_MEMORY\n      for (int i = 0; i < kd; i++) {\n        table_keys[slot * kd + i] = key[i];\n      }\n#endif\n\n      // Unlock\n      atomicExch(e, slot);\n\n      return h;\n    } else {\n// The cell is unlocked and has a key in it, check if it matches\n#ifdef LINEAR_D_MEMORY\n      if (matchKey<kd>(contents, key))\n        return h;\n#else\n      bool match = true;\n\n      for (int i = 0; i < kd && match; i++) {\n        match = (table_keys[contents * kd + i] == key[i]);\n      }\n\n      if (match)\n        return h;\n#endif\n    }\n    // increment the bucket with wraparound\n    h++;\n\n    if (h == table_capacity * 2)\n      h = 0;\n  }\n}\n\ntemplate <int kd>\n__device__ static int hashTableInsert(signed short* key, unsigned int slot) {\n  unsigned int myHash = hash<kd>(key);\n  return hashTableInsert<kd>(myHash, key, slot);\n}\n\ntemplate <int kd>\n__device__ static int hashTableRetrieveWithHash(unsigned int fh, signed short* key) {\n  int h = modHash(fh);\n  while (1) {\n    int* e = table_entries + h;\n\n    
if (*e == -1)\n      return -1;\n\n#ifdef LINEAR_D_MEMORY\n    if (matchKey<kd>((*e), key))\n      return *e;\n#else\n    bool match = true;\n\n    for (int i = 0; i < kd && match; i++) {\n      match = (table_keys[(*e) * kd + i] == key[i]);\n    }\n\n    if (match)\n      return *e;\n#endif\n\n    h++;\n\n    if (h == table_capacity * 2)\n      h = 0;\n  }\n}\n\ntemplate <int kd>\n__device__ static int hashTableRetrieve(signed short* key) {\n  int h = modHash(hash<kd>(key));\n  while (1) {\n    int* e = table_entries + h;\n\n    if (*e == -1)\n      return -1;\n\n#ifdef LINEAR_D_MEMORY\n    if (matchKey<kd>((*e), key))\n      return *e;\n#else\n    bool match = true;\n\n    for (int i = 0; i < kd && match; i++) {\n      match = (table_keys[(*e) * kd + i] == key[i]);\n    }\n\n    if (match)\n      return *e;\n#endif\n\n    h++;\n\n    if (h == table_capacity * 2)\n      h = 0;\n  }\n}\n"
  },
  {
    "path": "monai/csrc/filtering/permutohedral/permutohedral.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <stdexcept>\n#include <string>\n\n#include \"utils/common_utils.h\"\n#include \"utils/meta_macros.h\"\n\n#include \"permutohedral.h\"\n\ntorch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features) {\n  input = input.contiguous();\n\n  int batchCount = input.size(0);\n  int batchStride = input.stride(0);\n  int elementCount = input.stride(1);\n  int channelCount = input.size(1);\n  int featureCount = features.size(1);\n\n// movedim not support in torch < 1.7.1\n#if MONAI_TORCH_VERSION >= 10701\n  torch::Tensor data = input.clone().movedim(1, -1).contiguous();\n  features = features.movedim(1, -1).contiguous();\n#else\n  torch::Tensor data = input.clone();\n  features = features;\n\n  for (int i = 1; i < input.dim() - 1; i++) {\n    data = data.transpose(i, i + 1);\n    features = features.transpose(i, i + 1);\n  }\n\n  data = data.contiguous();\n  features = features.contiguous();\n#endif\n\n#ifdef WITH_CUDA\n  if (torch::cuda::is_available() && data.is_cuda()) {\n    CHECK_CONTIGUOUS_CUDA(data);\n\n    if (channelCount > PHL_CUDA_MAX_CHANNELS) {\n      throw std::runtime_error(\n          \"PHL filtering not implemented for channel count > \" + std::to_string(PHL_CUDA_MAX_CHANNELS));\n    }\n\n    if (featureCount > PHL_CUDA_MAX_FEATURES) {\n      throw std::runtime_error(\n          \"PHL filtering not implemented for feature count > \" + 
std::to_string(PHL_CUDA_MAX_FEATURES));\n    }\n\n#define CASE(dc, fc)                                                                                                  \\\n  AT_DISPATCH_FLOATING_TYPES(data.scalar_type(), \"PermutohedralCuda\", ([&] {                                          \\\n                               for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) {                      \\\n                                 scalar_t* offsetData = data.data_ptr<scalar_t>() + batchIndex * batchStride;         \\\n                                 scalar_t* offsetFeatures =                                                           \\\n                                     features.data_ptr<scalar_t>() + batchIndex * fc * elementCount;                  \\\n                                 PermutohedralCuda<scalar_t, dc, fc>(offsetData, offsetFeatures, elementCount, true); \\\n                               }                                                                                      \\\n                             }));\n    SWITCH_AB(CASE, PHL_CUDA_MAX_CHANNELS, PHL_CUDA_MAX_FEATURES, channelCount, featureCount);\n\n  } else {\n#endif\n    AT_DISPATCH_FLOATING_TYPES(\n        data.scalar_type(), \"PermutohedralCPU\", ([&] {\n          for (int batchIndex = 0; batchIndex < batchCount; batchIndex++) {\n            scalar_t* offsetData = data.data_ptr<scalar_t>() + batchIndex * batchStride;\n            scalar_t* offsetFeatures = features.data_ptr<scalar_t>() + batchIndex * featureCount * elementCount;\n            PermutohedralCPU<scalar_t>(offsetData, offsetFeatures, channelCount, featureCount, elementCount);\n          }\n        }));\n#ifdef WITH_CUDA\n  }\n#endif\n\n// movedim not support in torch < 1.7.1\n#if MONAI_TORCH_VERSION >= 10701\n  data = data.movedim(-1, 1);\n#else\n  for (int i = input.dim() - 1; i > 1; i--) {\n    data = data.transpose(i - 1, i);\n  }\n#endif\n\n  return data;\n}\n"
  },
  {
    "path": "monai/csrc/filtering/permutohedral/permutohedral.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#pragma once\n\n#include <torch/extension.h>\n\n#define PHL_CUDA_MAX_CHANNELS 16\n#define PHL_CUDA_MAX_FEATURES 19\n\ntemplate <typename scalar_t>\nvoid PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount);\n#ifdef WITH_CUDA\ntemplate <typename scalar_t, int dc, int fc>\nvoid PermutohedralCuda(scalar_t* data, scalar_t* features, int elementCount, bool accurate);\n#endif\n\ntorch::Tensor PermutohedralFilter(torch::Tensor input, torch::Tensor features);\n"
  },
  {
    "path": "monai/csrc/filtering/permutohedral/permutohedral_cpu.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n/*\nAdapted from https://github.com/abadams/permutohedral\nwhich has the following license...\n\nMIT License\n\nCopyright (c) 2020 Andrew Adams\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n*/\n\n#include <math.h>\n#include <string.h>\n\n#include <torch/extension.h>\n\nusing namespace std;\n\n/***************************************************************/\n/* Hash table implementation for permutohedral lattice\n *\n * The lattice points are stored sparsely using a hash table.\n * The key for each point is its spatial location in the (d+1)-\n * dimensional space.\n */\n/***************************************************************/\ntemplate <typename scalar_t>\nclass HashTablePermutohedral {\n public:\n  /* Constructor\n   *  kd_: the dimensionality of the position vectors on the hyperplane.\n   *  vd_: the dimensionality of the value vectors\n   */\n  HashTablePermutohedral(int kd_, int vd_) : kd(kd_), vd(vd_) {\n    capacity = 1 << 15;\n    filled = 0;\n    entries = new Entry[capacity];\n    keys = new short[kd * capacity / 2];\n    values = new scalar_t[vd * capacity / 2];\n    memset(values, 0, sizeof(scalar_t) * vd * capacity / 2);\n  }\n\n  // Returns the number of vectors stored.\n  int size() {\n    return filled;\n  }\n\n  // Returns a pointer to the keys array.\n  short* getKeys() {\n    return keys;\n  }\n\n  // Returns a pointer to the values array.\n  scalar_t* getValues() {\n    return values;\n  }\n\n  /* Returns the index into the hash table for a given key.\n   *     key: a pointer to the position vector.\n   *       h: hash of the position vector.\n   *  create: a flag specifying whether an entry should be created,\n   *          should an entry with the given key not found.\n   */\n  int lookupOffset(short* key, size_t h, bool create = true) {\n    // Double hash table size if necessary\n    if (filled >= (capacity / 2) - 1) {\n      grow();\n    }\n\n    // Find the 
entry with the given key\n    while (1) {\n      Entry e = entries[h];\n      // check if the cell is empty\n      if (e.keyIdx == -1) {\n        if (!create)\n          return -1; // Return not found.\n        // need to create an entry. Store the given key.\n        for (int i = 0; i < kd; i++)\n          keys[filled * kd + i] = key[i];\n        e.keyIdx = filled * kd;\n        e.valueIdx = filled * vd;\n        entries[h] = e;\n        filled++;\n        return e.valueIdx;\n      }\n\n      // check if the cell has a matching key\n      bool match = true;\n      for (int i = 0; i < kd && match; i++)\n        match = keys[e.keyIdx + i] == key[i];\n      if (match)\n        return e.valueIdx;\n\n      // increment the bucket with wraparound\n      h++;\n      if (h == capacity)\n        h = 0;\n    }\n  }\n\n  /* Looks up the value vector associated with a given key vector.\n   *        k : pointer to the key vector to be looked up.\n   *   create : true if a non-existing key should be created.\n   */\n  scalar_t* lookup(short* k, bool create = true) {\n    size_t h = hash(k) % capacity;\n    int offset = lookupOffset(k, h, create);\n    if (offset < 0)\n      return NULL;\n    else\n      return values + offset;\n  };\n\n  /* Hash function used in this implementation. A simple base conversion. 
*/\n  size_t hash(const short* key) {\n    size_t k = 0;\n    for (int i = 0; i < kd; i++) {\n      k += key[i];\n      k *= 2531011;\n    }\n    return k;\n  }\n\n private:\n  /* Grows the size of the hash table */\n  void grow() {\n    size_t oldCapacity = capacity;\n    capacity *= 2;\n\n    // Migrate the value vectors.\n    scalar_t* newValues = new scalar_t[vd * capacity / 2];\n    memset(newValues, 0, sizeof(scalar_t) * vd * capacity / 2);\n    memcpy(newValues, values, sizeof(scalar_t) * vd * filled);\n    delete[] values;\n    values = newValues;\n\n    // Migrate the key vectors.\n    short* newKeys = new short[kd * capacity / 2];\n    memcpy(newKeys, keys, sizeof(short) * kd * filled);\n    delete[] keys;\n    keys = newKeys;\n\n    Entry* newEntries = new Entry[capacity];\n\n    // Migrate the table of indices.\n    for (size_t i = 0; i < oldCapacity; i++) {\n      if (entries[i].keyIdx == -1)\n        continue;\n      size_t h = hash(keys + entries[i].keyIdx) % capacity;\n      while (newEntries[h].keyIdx != -1) {\n        h++;\n        if (h == capacity)\n          h = 0;\n      }\n      newEntries[h] = entries[i];\n    }\n    delete[] entries;\n    entries = newEntries;\n  }\n\n  // Private struct for the hash table entries.\n  struct Entry {\n    Entry() : keyIdx(-1), valueIdx(-1) {}\n    int keyIdx;\n    int valueIdx;\n  };\n\n  short* keys;\n  scalar_t* values;\n  Entry* entries;\n  size_t capacity, filled;\n  int kd, vd;\n};\n\n/***************************************************************/\n/* The algorithm class that performs the filter\n *\n * PermutohedralLattice::filter(...) 
does all the work.\n *\n */\n/***************************************************************/\ntemplate <typename scalar_t>\nclass PermutohedralLattice {\n public:\n  /* Filters given image against a reference image.\n   *   im : image to be bilateral-filtered.\n   *  ref : reference image whose edges are to be respected.\n   */\n  static void filter(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) {\n    // Create lattice\n    PermutohedralLattice lattice(featureChannels, dataChannels + 1, elementCount);\n\n    // Splat into the lattice\n    scalar_t* col = new scalar_t[dataChannels + 1];\n    col[dataChannels] = 1; // homogeneous coordinate\n\n    for (int i = 0, e = 0; e < elementCount; e++) {\n      for (int c = 0; c < dataChannels; c++, i++) {\n        col[c] = data[i];\n      }\n\n      scalar_t* featureVec = features + e * featureChannels;\n      lattice.splat(featureVec, col);\n    }\n\n    // Blur the lattice\n    lattice.blur();\n\n    // Slice from the lattice\n    lattice.beginSlice();\n\n    for (int i = 0, e = 0; e < elementCount; e++) {\n      lattice.slice(col);\n\n      scalar_t scale = 1.0f / col[dataChannels];\n      for (int c = 0; c < dataChannels; c++, i++) {\n        data[i] = col[c] * scale;\n      }\n    }\n  }\n\n  /* Constructor\n   *     d_ : dimensionality of key vectors\n   *    vd_ : dimensionality of value vectors\n   * nData_ : number of points in the input\n   */\n  PermutohedralLattice(int d_, int vd_, int nData_) : d(d_), vd(vd_), nData(nData_), hashTable(d_, vd_) {\n    // Allocate storage for various arrays\n    elevated = new scalar_t[d + 1];\n    scaleFactor = new scalar_t[d];\n\n    greedy = new short[d + 1];\n    rank = new char[d + 1];\n    barycentric = new scalar_t[d + 2];\n    replay = new ReplayEntry[nData * (d + 1)];\n    nReplay = 0;\n    canonical = new short[(d + 1) * (d + 1)];\n    key = new short[d + 1];\n\n    // compute the coordinates of the canonical simplex, in 
which\n    // the difference between a contained point and the zero\n    // remainder vertex is always in ascending order. (See pg.4 of paper.)\n    for (int i = 0; i <= d; i++) {\n      for (int j = 0; j <= d - i; j++)\n        canonical[i * (d + 1) + j] = i;\n      for (int j = d - i + 1; j <= d; j++)\n        canonical[i * (d + 1) + j] = i - (d + 1);\n    }\n\n    // Compute parts of the rotation matrix E. (See pg.4-5 of paper.)\n    for (int i = 0; i < d; i++) {\n      // the diagonal entries for normalization\n      scaleFactor[i] = 1.0f / (sqrtf((scalar_t)(i + 1) * (i + 2)));\n\n      /* We presume that the user would like to do a Gaussian blur of standard deviation\n       * 1 in each dimension (or a total variance of d, summed over dimensions.)\n       * Because the total variance of the blur performed by this algorithm is not d,\n       * we must scale the space to offset this.\n       *\n       * The total variance of the algorithm is (See pg.6 and 10 of paper):\n       *  [variance of splatting] + [variance of blurring] + [variance of splatting]\n       *   = d(d+1)(d+1)/12 + d(d+1)(d+1)/2 + d(d+1)(d+1)/12\n       *   = 2d(d+1)(d+1)/3.\n       *\n       * So we need to scale the space by (d+1)sqrt(2/3).\n       */\n      scaleFactor[i] *= (d + 1) * sqrtf(2.0 / 3);\n    }\n  }\n\n  /* Performs splatting with given position and value vectors */\n  void splat(scalar_t* position, scalar_t* value) {\n    // first rotate position into the (d+1)-dimensional hyperplane\n    elevated[d] = -d * position[d - 1] * scaleFactor[d - 1];\n    for (int i = d - 1; i > 0; i--)\n      elevated[i] =\n          (elevated[i + 1] - i * position[i - 1] * scaleFactor[i - 1] + (i + 2) * position[i] * scaleFactor[i]);\n    elevated[0] = elevated[1] + 2 * position[0] * scaleFactor[0];\n\n    // prepare to find the closest lattice points\n    scalar_t scale = 1.0f / (d + 1);\n    char* myrank = rank;\n    short* mygreedy = greedy;\n\n    // greedily search for the closest 
zero-colored lattice point\n    int sum = 0;\n    for (int i = 0; i <= d; i++) {\n      scalar_t v = elevated[i] * scale;\n      scalar_t up = ceilf(v) * (d + 1);\n      scalar_t down = floorf(v) * (d + 1);\n\n      if (up - elevated[i] < elevated[i] - down)\n        mygreedy[i] = (short)up;\n      else\n        mygreedy[i] = (short)down;\n\n      sum += mygreedy[i];\n    }\n    sum /= d + 1;\n\n    // rank differential to find the permutation between this simplex and the canonical one.\n    // (See pg. 3-4 in paper.)\n    memset(myrank, 0, sizeof(char) * (d + 1));\n    for (int i = 0; i < d; i++)\n      for (int j = i + 1; j <= d; j++)\n        if (elevated[i] - mygreedy[i] < elevated[j] - mygreedy[j])\n          myrank[i]++;\n        else\n          myrank[j]++;\n\n    if (sum > 0) {\n      // sum too large - the point is off the hyperplane.\n      // need to bring down the ones with the smallest differential\n      for (int i = 0; i <= d; i++) {\n        if (myrank[i] >= d + 1 - sum) {\n          mygreedy[i] -= d + 1;\n          myrank[i] += sum - (d + 1);\n        } else\n          myrank[i] += sum;\n      }\n    } else if (sum < 0) {\n      // sum too small - the point is off the hyperplane\n      // need to bring up the ones with largest differential\n      for (int i = 0; i <= d; i++) {\n        if (myrank[i] < -sum) {\n          mygreedy[i] += d + 1;\n          myrank[i] += (d + 1) + sum;\n        } else\n          myrank[i] += sum;\n      }\n    }\n\n    // Compute barycentric coordinates (See pg.10 of paper.)\n    memset(barycentric, 0, sizeof(scalar_t) * (d + 2));\n    for (int i = 0; i <= d; i++) {\n      barycentric[d - myrank[i]] += (elevated[i] - mygreedy[i]) * scale;\n      barycentric[d + 1 - myrank[i]] -= (elevated[i] - mygreedy[i]) * scale;\n    }\n    barycentric[0] += 1.0f + barycentric[d + 1];\n\n    // Splat the value into each vertex of the simplex, with barycentric weights.\n    for (int remainder = 0; remainder <= d; remainder++) {\n      
// Compute the location of the lattice point explicitly (all but the last coordinate - it's redundant because they\n      // sum to zero)\n      for (int i = 0; i < d; i++)\n        key[i] = mygreedy[i] + canonical[remainder * (d + 1) + myrank[i]];\n\n      // Retrieve pointer to the value at this vertex.\n      scalar_t* val = hashTable.lookup(key, true);\n\n      // Accumulate values with barycentric weight.\n      for (int i = 0; i < vd; i++)\n        val[i] += barycentric[remainder] * value[i];\n\n      // Record this interaction to use later when slicing\n      replay[nReplay].offset = val - hashTable.getValues();\n      replay[nReplay].weight = barycentric[remainder];\n      nReplay++;\n    }\n  }\n\n  // Prepare for slicing\n  void beginSlice() {\n    nReplay = 0;\n  }\n\n  /* Performs slicing out of position vectors. Note that the barycentric weights and the simplex\n   * containing each position vector were calculated and stored in the splatting step.\n   * We may reuse this to accelerate the algorithm. (See pg. 6 in paper.)\n   */\n  void slice(scalar_t* col) {\n    scalar_t* base = hashTable.getValues();\n    for (int j = 0; j < vd; j++)\n      col[j] = 0;\n    for (int i = 0; i <= d; i++) {\n      ReplayEntry r = replay[nReplay++];\n      for (int j = 0; j < vd; j++) {\n        col[j] += r.weight * base[r.offset + j];\n      }\n    }\n  }\n\n  /* Performs a Gaussian blur along each projected axis in the hyperplane. 
*/\n  void blur() {\n    // Prepare arrays\n    short* neighbor1 = new short[d + 1];\n    short* neighbor2 = new short[d + 1];\n    scalar_t* newValue = new scalar_t[vd * hashTable.size()];\n    scalar_t* oldValue = hashTable.getValues();\n    scalar_t* hashTableBase = oldValue;\n\n    scalar_t* zero = new scalar_t[vd];\n    for (int k = 0; k < vd; k++)\n      zero[k] = 0;\n\n    // For each of d+1 axes,\n    for (int j = 0; j <= d; j++) {\n      // For each vertex in the lattice,\n      for (int i = 0; i < hashTable.size(); i++) { // blur point i in dimension j\n        short* key = hashTable.getKeys() + i * (d); // keys to current vertex\n        for (int k = 0; k < d; k++) {\n          neighbor1[k] = key[k] + 1;\n          neighbor2[k] = key[k] - 1;\n        }\n        neighbor1[j] = key[j] - d;\n        neighbor2[j] = key[j] + d; // keys to the neighbors along the given axis.\n\n        scalar_t* oldVal = oldValue + i * vd;\n        scalar_t* newVal = newValue + i * vd;\n\n        scalar_t *vm1, *vp1;\n\n        vm1 = hashTable.lookup(neighbor1, false); // look up first neighbor\n        if (vm1)\n          vm1 = vm1 - hashTableBase + oldValue;\n        else\n          vm1 = zero;\n\n        vp1 = hashTable.lookup(neighbor2, false); // look up second neighbor\n        if (vp1)\n          vp1 = vp1 - hashTableBase + oldValue;\n        else\n          vp1 = zero;\n\n        // Mix values of the three vertices\n        for (int k = 0; k < vd; k++)\n          newVal[k] = (0.25f * vm1[k] + 0.5f * oldVal[k] + 0.25f * vp1[k]);\n      }\n      scalar_t* tmp = newValue;\n      newValue = oldValue;\n      oldValue = tmp;\n      // the freshest data is now in oldValue, and newValue is ready to be written over\n    }\n\n    // depending where we ended up, we may have to copy data\n    if (oldValue != hashTableBase) {\n      memcpy(hashTableBase, oldValue, hashTable.size() * vd * sizeof(scalar_t));\n      delete[] oldValue;\n    } else {\n      delete[] newValue;\n    }\n\n 
   delete[] zero;\n    delete[] neighbor1;\n    delete[] neighbor2;\n  }\n\n private:\n  int d, vd, nData;\n  scalar_t *elevated, *scaleFactor, *barycentric;\n  short* canonical;\n  short* key;\n\n  // slicing is done by replaying splatting (ie storing the sparse matrix)\n  struct ReplayEntry {\n    int offset;\n    scalar_t weight;\n  } * replay;\n  int nReplay, nReplaySub;\n\n public:\n  char* rank;\n  short* greedy;\n  HashTablePermutohedral<scalar_t> hashTable;\n};\n\ntemplate <typename scalar_t>\nvoid PermutohedralCPU(scalar_t* data, scalar_t* features, int dataChannels, int featureChannels, int elementCount) {\n  PermutohedralLattice<scalar_t>::filter(data, features, dataChannels, featureChannels, elementCount);\n}\n\ntemplate void PermutohedralCPU(float* data, float* features, int dataChannels, int featureChannels, int elementCount);\ntemplate void PermutohedralCPU(double* data, double* features, int dataChannels, int featureChannels, int elementCount);\n"
  },
  {
    "path": "monai/csrc/filtering/permutohedral/permutohedral_cuda.cu",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n/*\nAdapted from https://github.com/abadams/permutohedral\nwhich has the following license...\n\nMIT License\n\nCopyright (c) 2020 Andrew Adams\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n*/\n\n#define BLOCK_SIZE 32\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <stdio.h>\n#include <torch/extension.h>\n#include <THC/THCAtomics.cuh>\n\n#include \"hash_table.cuh\"\n#include \"permutohedral.h\"\n#include \"utils/meta_macros.h\"\n\ntemplate <typename scalar_t>\nstruct MatrixEntry {\n  int index;\n  scalar_t weight;\n};\n\ntemplate <typename scalar_t, int pd>\n__global__ static void createMatrix(\n    const int elementCount,\n    const scalar_t* positions,\n    const scalar_t* values,\n    const scalar_t* scaleFactor,\n    MatrixEntry<scalar_t>* matrix) {\n  const int threadId = threadIdx.x;\n  const int idx = threadIdx.x + blockIdx.x * BLOCK_SIZE;\n  const bool outOfBounds = idx >= elementCount;\n\n  scalar_t myElevated[pd + 1];\n  const scalar_t* myPosition = positions + idx * pd;\n\n  int myGreedy[pd + 1];\n  int myRank[pd + 1];\n\n  scalar_t myBarycentric[pd + 2];\n  __shared__ short keys[pd * BLOCK_SIZE];\n  short* myKey = keys + threadId * pd;\n\n  if (!outOfBounds) {\n    myElevated[pd] = -pd * myPosition[pd - 1] * scaleFactor[pd - 1];\n\n    for (int i = pd - 1; i > 0; i--) {\n      myElevated[i] =\n          myElevated[i + 1] - i * (myPosition[i - 1]) * scaleFactor[i - 1] + (i + 2) * myPosition[i] * scaleFactor[i];\n    }\n\n    myElevated[0] = myElevated[1] + 2 * myPosition[0] * scaleFactor[0];\n\n    // find the closest zero-colored lattice point\n\n    // greedily search for the closest zero-colored lattice point\n    signed short sum = 0;\n\n    for (int i = 0; i <= pd; i++) {\n      scalar_t v = myElevated[i] * (1.0f / (pd + 1));\n      scalar_t up = ceilf(v) * (pd + 1);\n      scalar_t down = floorf(v) * (pd + 1);\n\n      myGreedy[i] = (signed short)(up - myElevated[i] 
< myElevated[i] - down ? up : down);\n      sum += myGreedy[i];\n    }\n\n    sum /= pd + 1;\n\n    // sort differential to find the permutation between this simplex and the canonical one\n    for (int i = 0; i <= pd; i++) {\n      myRank[i] = 0;\n\n      for (int j = 0; j <= pd; j++) {\n        scalar_t iDiff = myElevated[i] - myGreedy[i];\n        scalar_t jDiff = myElevated[j] - myGreedy[j];\n\n        if (iDiff < jDiff || (iDiff == jDiff && i > j)) {\n          myRank[i]++;\n        }\n      }\n    }\n\n    if (sum > 0) // sum too large, need to bring down the ones with the smallest differential\n    {\n      for (int i = 0; i <= pd; i++) {\n        if (myRank[i] >= pd + 1 - sum) {\n          myGreedy[i] -= (pd + 1);\n          myRank[i] += sum - (pd + 1);\n        } else {\n          myRank[i] += sum;\n        }\n      }\n    } else if (sum < 0) // sum too small, need to bring up the ones with largest differential\n    {\n      for (int i = 0; i <= pd; i++) {\n        if (myRank[i] < -sum) {\n          myGreedy[i] += (pd + 1);\n          myRank[i] += sum + (pd + 1);\n        } else {\n          myRank[i] += sum;\n        }\n      }\n    }\n\n#ifdef LINEAR_D_MEMORY\n    for (int i = 0; i <= pd; i++) {\n      table_zeros[idx * (pd + 1) + i] = myGreedy[i];\n      table_rank[idx * (pd + 1) + i] = myRank[i];\n    }\n#endif\n\n    // turn delta into barycentric coords\n    for (int i = 0; i <= pd + 1; i++) {\n      myBarycentric[i] = 0;\n    }\n\n    for (int i = 0; i <= pd; i++) {\n      scalar_t delta = (myElevated[i] - myGreedy[i]) * (1.0f / (pd + 1));\n      myBarycentric[pd - myRank[i]] += delta;\n      myBarycentric[pd + 1 - myRank[i]] -= delta;\n    }\n\n    myBarycentric[0] += 1.0f + myBarycentric[pd + 1];\n  }\n\n#ifdef USE_ADDITIVE_HASH\n  unsigned int cumulative_hash = hash<pd>(myGreedy);\n#endif\n\n  for (int color = 0; color <= pd; color++) {\n    // Compute the location of the lattice point explicitly (all but\n    // the last coordinate - it's 
redundant because they sum to zero)\n    if (!outOfBounds) {\n      for (int i = 0; i < pd; i++) {\n        myKey[i] = myGreedy[i] + color;\n\n        if (myRank[i] > pd - color) {\n          myKey[i] -= (pd + 1);\n        }\n      }\n    }\n\n#ifdef USE_ADDITIVE_HASH\n    for (int i = 0; i < pd; i++) {\n      if (myRank[i] == pd - color) {\n        cumulative_hash += hOffset[i];\n      }\n    }\n#endif\n\n    if (!outOfBounds) {\n      MatrixEntry<scalar_t> r;\n\n#ifdef USE_ADDITIVE_HASH\n      r.index = hashTableInsert<pd>(cumulative_hash, myKey, idx * (pd + 1) + color);\n#else\n      r.index = hashTableInsert<pd>(myKey, idx * (pd + 1) + color);\n#endif\n\n      r.weight = myBarycentric[color];\n      matrix[idx * (pd + 1) + color] = r;\n    }\n  }\n}\n\ntemplate <typename scalar_t, int kd>\n__global__ static void cleanHashTable(const int elementCount, MatrixEntry<scalar_t>* matrix) {\n  const int idx = threadIdx.x + blockIdx.x * blockDim.x;\n\n  if (idx >= elementCount)\n    return;\n\n  // find my hash table entry\n  int* e = table_entries + idx;\n\n  // Check if I created my own key in the previous phase\n  if (*e >= 0) {\n    // Rehash my key and reset the pointer in order to merge with\n    // any other pixel that created a different entry under the\n    // same key. If the computation was serial this would never\n    // happen, but sometimes race conditions can make the same key\n    // be inserted twice. 
hashTableRetrieve always returns the\n    // earlier, so it's no problem as long as we rehash now.\n\n#ifdef LINEAR_D_MEMORY\n    // Get my key\n    short myKey[kd];\n    generateKey<kd>(*e, myKey);\n    *e = hashTableRetrieve<kd>(myKey);\n#else\n    *e = hashTableRetrieve<kd>(table_keys + *e * kd);\n#endif\n  }\n}\n\ntemplate <typename scalar_t, int pd, int vd>\n__global__ static void splat(\n    const int elementCount,\n    scalar_t* values,\n    MatrixEntry<scalar_t>* matrix,\n    scalar_t* table_values) {\n  const int color = threadIdx.y;\n  const int idx = threadIdx.x + blockIdx.x * blockDim.x;\n\n  const bool outOfBounds = idx >= elementCount;\n\n  if (outOfBounds) {\n    return;\n  }\n\n  scalar_t* myValue = values + idx * vd;\n\n  MatrixEntry<scalar_t> r = matrix[idx * (pd + 1) + color];\n\n  matrix[idx * (pd + 1) + color].index = r.index = table_entries[r.index];\n  scalar_t* val = table_values + r.index * (vd + 1);\n\n  for (int j = 0; j < vd; j++) {\n    gpuAtomicAdd(val + j, myValue[j] * r.weight);\n  }\n\n  gpuAtomicAdd(val + vd, r.weight);\n}\n\n// splat splits by color, so extend the y coordinate to our blocks to represent that\n// dim3 oldblocks((w-1)/8+1, (h-1)/8+1, 1);\n// dim3 oldblockSize(8, 8, 1);\n// oldblocks.y *= pd+1;\n// splatCache<pd, vd><<<oldblocks, oldblockSize>>>(w, h, values, matrix);\n\n// int blockCount = (elementCount + 1) / BLOCK_SIZE + 1;\n// int blockSize = BLOCK_SIZE;\n\n// splatCache<pd, vd><<<dim3(blockCount, 1), dim3(blockSize, pd+1)>>>(elementCount, values, matrix);\n\ntemplate <typename scalar_t, int pd, int vd>\n__global__ static void splatCache(\n    const int elementCount,\n    scalar_t* values,\n    MatrixEntry<scalar_t>* matrix,\n    scalar_t* table_values) {\n  // const int x = threadIdx.x + blockIdx.x * blockDim.x;\n  // const int y = threadIdx.y + (blockIdx.y/(pd+1)) * blockDim.y;\n\n  // const int threadId = threadIdx.y*blockDim.x + threadIdx.x;\n  // const int color = blockIdx.y % (pd+1);\n  // const int idx = 
y*w + x;\n\n  const int threadId = threadIdx.x;\n  const int color = threadIdx.y;\n  const int idx = threadIdx.x + blockIdx.x * BLOCK_SIZE;\n\n  const bool outOfBounds = idx >= elementCount;\n\n  __shared__ int sharedOffsets[BLOCK_SIZE];\n  __shared__ scalar_t sharedValues[BLOCK_SIZE * (vd + 1)];\n\n  int myOffset = -1;\n  scalar_t* myValue = sharedValues + threadId * (vd + 1);\n\n  if (!outOfBounds) {\n    scalar_t* value = values + idx * vd;\n\n    MatrixEntry<scalar_t> r = matrix[idx * (pd + 1) + color];\n\n    // convert the matrix entry from a pointer into the entries array to a pointer into the keys/values array\n    matrix[idx * (pd + 1) + color].index = r.index = table_entries[r.index];\n    // record the offset into the keys/values array in shared space\n    myOffset = sharedOffsets[threadId] = r.index * (vd + 1);\n\n    for (int j = 0; j < vd; j++) {\n      myValue[j] = value[j] * r.weight;\n    }\n    myValue[vd] = r.weight;\n\n  } else {\n    sharedOffsets[threadId] = -1;\n  }\n\n  __syncthreads();\n\n  // am I the first thread in this block to care about this key?\n\n  if (outOfBounds)\n    return;\n\n  for (int i = 0; i < BLOCK_SIZE; i++) {\n    if (i < threadId) {\n      if (myOffset == sharedOffsets[i]) {\n        // somebody else with higher priority cares about this key\n        return;\n      }\n    } else if (i > threadId) {\n      if (myOffset == sharedOffsets[i]) {\n        // someone else with lower priority cares about this key, accumulate it into mine\n        for (int j = 0; j <= vd; j++) {\n          sharedValues[threadId * (vd + 1) + j] += sharedValues[i * (vd + 1) + j];\n        }\n      }\n    }\n  }\n\n  // only the threads with something to write to main memory are still going\n  scalar_t* val = table_values + myOffset;\n  for (int j = 0; j <= vd; j++) {\n    gpuAtomicAdd(val + j, myValue[j]);\n  }\n}\n\ntemplate <typename scalar_t, int pd, int vd>\n__global__ static void blur(\n    int n,\n    scalar_t* newValues,\n    
MatrixEntry<scalar_t>* matrix,\n    int color,\n    scalar_t* table_values) {\n  const int idx = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x * blockDim.y + threadIdx.x;\n\n  if (idx >= n)\n    return;\n\n  // Check if I'm valid\n  if (matrix[idx].index != idx)\n    return;\n\n  // find my key and the keys of my neighbours\n  short myKey[pd + 1];\n  short np[pd + 1];\n  short nm[pd + 1];\n\n#ifdef LINEAR_D_MEMORY\n  generateKey<pd>(idx, myKey);\n  for (int i = 0; i < pd; i++) {\n    np[i] = myKey[i] + 1;\n    nm[i] = myKey[i] - 1;\n  }\n#else\n  for (int i = 0; i < pd; i++) {\n    myKey[i] = table_keys[idx * pd + i];\n    np[i] = myKey[i] + 1;\n    nm[i] = myKey[i] - 1;\n  }\n#endif\n\n  np[color] -= pd + 1;\n  nm[color] += pd + 1;\n\n#ifdef USE_ADDITIVE_HASH\n  unsigned int hCurrent = hash<pd>(myKey);\n  int offNp = hashTableRetrieveWithHash<pd>(hCurrent + hOffset[color], np);\n  int offNm = hashTableRetrieveWithHash<pd>(hCurrent - hOffset[color], nm);\n#else\n  int offNp = hashTableRetrieve<pd>(np);\n  int offNm = hashTableRetrieve<pd>(nm);\n#endif\n\n  scalar_t* valMe = table_values + (vd + 1) * idx;\n  scalar_t* valNp = table_values + (vd + 1) * offNp;\n  scalar_t* valNm = table_values + (vd + 1) * offNm;\n  scalar_t* valOut = newValues + (vd + 1) * idx;\n\n  if (offNp >= 0 && offNm >= 0) {\n    for (int i = 0; i <= vd; i++) {\n      valOut[i] = (valNp[i] + (valMe[i] * 2) + valNm[i]) / 4;\n    }\n  } else if (offNp >= 0) {\n    for (int i = 0; i <= vd; i++) {\n      valOut[i] = (valNp[i] + (valMe[i] * 2)) / 4;\n    }\n  } else if (offNm >= 0) {\n    for (int i = 0; i <= vd; i++) {\n      valOut[i] = (valNm[i] + (valMe[i] * 2)) / 4;\n    }\n  } else {\n    for (int i = 0; i <= vd; i++) {\n      valOut[i] = valMe[i] * 2;\n    }\n  }\n}\n\ntemplate <typename scalar_t, int pd, int vd>\n__global__ static void slice(\n    const int elementCount,\n    scalar_t* values,\n    MatrixEntry<scalar_t>* matrix,\n    scalar_t* table_values) {\n  const int threadId = 
threadIdx.x;\n  const int idx = threadIdx.x + blockIdx.x * BLOCK_SIZE;\n  const bool outOfBounds = idx >= elementCount;\n\n  if (outOfBounds)\n    return;\n\n  __shared__ scalar_t localValue[BLOCK_SIZE * vd];\n\n  scalar_t* myValue = localValue + threadId * vd;\n  scalar_t myWeight = 0;\n\n  for (int i = 0; i < vd; i++) {\n    myValue[i] = 0;\n  }\n\n  for (int i = 0; i <= pd; i++) {\n    MatrixEntry<scalar_t> r = matrix[idx * (pd + 1) + i];\n    scalar_t* val = table_values + r.index * (vd + 1);\n\n    for (int j = 0; j < vd; j++) {\n      myValue[j] += r.weight * val[j];\n    }\n\n    myWeight += r.weight * val[vd];\n  }\n\n  myWeight = 1.0f / myWeight;\n\n  for (int j = 0; j < vd; j++) {\n    values[idx * vd + j] = myValue[j] * myWeight;\n  }\n}\n\ntemplate <typename scalar_t, int vd, int pd>\nvoid PermutohedralCuda(scalar_t* values, scalar_t* positions, int elementCount, bool accurate) {\n  scalar_t blurVariance = accurate ? 0.5 : 0;\n\n  scalar_t* scaleFactor;\n  cudaMalloc(&scaleFactor, pd * sizeof(scalar_t));\n\n  scalar_t scaleFactorHost[pd];\n  for (int i = 0; i < pd; i++) {\n    scaleFactorHost[i] = (pd + 1) * sqrtf((1.0 / 6 + blurVariance) / ((i + 1) * (i + 2)));\n  }\n\n  cudaMemcpy(scaleFactor, scaleFactorHost, pd * sizeof(scalar_t), cudaMemcpyHostToDevice);\n\n  MatrixEntry<scalar_t>* matrix;\n  cudaMalloc(&matrix, elementCount * (pd + 1) * sizeof(MatrixEntry<scalar_t>));\n\n  scalar_t* table_values = createHashTable<scalar_t, pd, vd + 1>(elementCount * (pd + 1));\n\n  // Populate constant memory for hash helpers\n  unsigned long long int __host_two32 = ((unsigned long long int)1) << 32;\n  unsigned int __host_div_c = 2 * (elementCount * (pd + 1));\n  unsigned int __host_div_l = ceilf(logf((float)__host_div_c) / logf(2.0f));\n  unsigned int __host_div_m = (__host_two32 << __host_div_l) / __host_div_c - __host_two32 + 1;\n  cudaMemcpyToSymbol(__div_c, &__host_div_c, sizeof(unsigned int));\n  cudaMemcpyToSymbol(__div_l, &__host_div_l, sizeof(unsigned 
int));\n  cudaMemcpyToSymbol(__div_m, &__host_div_m, sizeof(unsigned int));\n\n  // Populate constant memory with hash of offset vectors\n  unsigned int hOffset_host[pd + 1];\n  signed short offset[pd + 1];\n  for (int i = 0; i < pd; offset[i] = 1, i++)\n    ;\n  for (int i = 0; i <= pd; i++) {\n    offset[i] -= pd + 1;\n    hOffset_host[i] = hash<pd>(offset);\n    offset[i] += pd + 1;\n  }\n  cudaMemcpyToSymbol(hOffset, &hOffset_host, sizeof(unsigned int) * (pd + 1));\n\n  int blockCount = (elementCount + 1) / BLOCK_SIZE + 1;\n  int blockSize = BLOCK_SIZE;\n\n  createMatrix<scalar_t, pd><<<blockCount, blockSize>>>(elementCount, positions, values, scaleFactor, matrix);\n\n  // fix duplicate hash table entries\n  int tableSize = elementCount * 2 * (pd + 1);\n  int cleanBlockSize = 32;\n  int cleanBlocks = (tableSize - 1) / cleanBlockSize + 1;\n\n  cleanHashTable<scalar_t, pd><<<cleanBlocks, cleanBlockSize>>>(tableSize, matrix);\n\n  splat<scalar_t, pd, vd><<<dim3(blockCount, 1), dim3(blockSize, pd + 1)>>>(elementCount, values, matrix, table_values);\n\n  if (accurate) {\n    scalar_t* newValues;\n    cudaMalloc(&newValues, elementCount * (pd + 1) * (vd + 1) * sizeof(scalar_t));\n    cudaMemset(newValues, 0, elementCount * (pd + 1) * (vd + 1) * sizeof(scalar_t));\n\n    for (int color = 0; color <= pd; color++) {\n      blur<scalar_t, pd, vd>\n          <<<cleanBlocks, cleanBlockSize>>>(elementCount * (pd + 1), newValues, matrix, color, table_values);\n\n      scalar_t* swap = newValues;\n      newValues = table_values;\n      table_values = swap;\n    }\n\n    cudaFree(newValues);\n  }\n\n  slice<scalar_t, pd, vd><<<blockCount, blockSize>>>(elementCount, values, matrix, table_values);\n\n  destroyHashTable<scalar_t>();\n  cudaFree(table_values);\n  cudaFree(scaleFactor);\n  cudaFree(matrix);\n}\n\n#define DECLARATION(dc, fc)                                                                                         \\\n  template void PermutohedralCuda<float, dc, 
fc>(float* values, float* positions, int elementCount, bool accurate); \\\n  template void PermutohedralCuda<double, dc, fc>(double* values, double* positions, int elementCount, bool accurate);\nDO_FOR_AB(DECLARATION, 16, 19)\n"
  },
  {
    "path": "monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n=========================================================================\nAdapted from https://github.com/faebstn96/trainable-bilateral-filter-source\nwhich has the following license...\nhttps://github.com/faebstn96/trainable-bilateral-filter-source/blob/main/LICENSE.md\n\nCopyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"trainable_bilateral.h\"\n#include \"utils/tensor_description.h\"\n#include \"utils/tensor_indexing.h\"\n\ntemplate <typename scalar_t>\nvoid BilateralFilterCpuBackward_3d(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor gradientOutputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dx_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  // Getting 
tensor description.\n  TensorDescription desc = TensorDescription(gradientInputTensor);\n\n  // Raw tensor data pointers.\n  scalar_t* gradientInputTensorData = gradientInputTensor.data_ptr<scalar_t>();\n  scalar_t* gradientOutputTensorData = gradientOutputTensor.data_ptr<scalar_t>();\n  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();\n  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();\n  scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr<scalar_t>();\n  scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr<scalar_t>();\n\n  // Pre-calculate common values\n  int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size\n  int halfWindowSize_x = floor(0.5f * windowSize_x);\n  int halfWindowSize_y = floor(0.5f * windowSize_y);\n  int halfWindowSize_z = floor(0.5f * windowSize_z);\n  int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z};\n  scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x);\n  scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y);\n  scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z);\n  scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);\n\n  // Set kernel sizes with respect to the defined spatial sigmas.\n  int* kernelSizes = new int[desc.dimensions];\n\n  kernelSizes[0] = windowSize_x;\n  kernelSizes[1] = windowSize_y;\n  kernelSizes[2] = windowSize_z;\n\n  // Pre-calculate gaussian kernel and distance map in 1D.\n  scalar_t* gaussianKernel_x = new scalar_t[windowSize_x];\n  scalar_t* gaussianKernel_y = new scalar_t[windowSize_y];\n  scalar_t* gaussianKernel_z = new scalar_t[windowSize_z];\n  scalar_t* xDistanceSquared = new scalar_t[windowSize_x];\n  scalar_t* 
yDistanceSquared = new scalar_t[windowSize_y];\n  scalar_t* zDistanceSquared = new scalar_t[windowSize_z];\n\n  for (int i = 0; i < windowSize_x; i++) {\n    int distance = i - halfWindowSize_x;\n    gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x);\n    xDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_y; i++) {\n    int distance = i - halfWindowSize_y;\n    gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y);\n    yDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_z; i++) {\n    int distance = i - halfWindowSize_z;\n    gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z);\n    zDistanceSquared[i] = distance * distance;\n  }\n\n  // Looping over the batches\n  for (int b = 0; b < desc.batchCount; b++) {\n    int batchOffset = b * desc.batchStride;\n\n    // Looping over all dimensions for the home element\n    for (int z = 0; z < desc.sizes[2]; z++)\n#pragma omp parallel for\n      for (int y = 0; y < desc.sizes[1]; y++) {\n        for (int x = 0; x < desc.sizes[0]; x++) {\n          // Calculating indexing offset for the home element\n          int homeOffset = batchOffset;\n\n          int homeIndex[] = {x, y, z};\n          homeOffset += x * desc.strides[0];\n          homeOffset += y * desc.strides[1];\n          homeOffset += z * desc.strides[2];\n\n          // Zero kernel aggregates.\n          scalar_t filter_kernel = 0;\n          scalar_t valueSum = 0;\n\n          // Looping over all dimensions for the neighbour element\n          Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes);\n          do // while(kernelIndex++)\n          {\n            // Calculating buffer offset for the neighbour element\n            // Index is clamped to the border in each dimension.\n            int neighbourOffset = batchOffset;\n            bool flagNotClamped = true;\n\n            for (int i = 0; i < desc.dimensions; i++) {\n              int 
neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i];\n              int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex));\n              neighbourOffset += neighbourIndexClamped * desc.strides[i];\n              if (neighbourIndex != neighbourIndexClamped) {\n                flagNotClamped = false;\n              }\n            }\n\n            // Euclidean color distance.\n            scalar_t colorDistance = 0;\n            scalar_t colorDistanceSquared = 0;\n\n            for (int i = 0; i < desc.channelCount; i++) {\n              scalar_t diff = inputTensorData[neighbourOffset + i * desc.channelStride] -\n                  inputTensorData[homeOffset +\n                                  i * desc.channelStride]; // Be careful: Here it is (X_k - X_i) and not (X_i - X_q)\n              colorDistance += diff; // Do not take the absolute value here. Be careful with the signs.\n              colorDistanceSquared += diff * diff;\n            }\n\n            // Calculating and combining the spatial\n            // and color weights.\n            scalar_t spatialWeight = 1;\n\n            spatialWeight =\n                gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]];\n\n            scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant);\n            scalar_t totalWeight = spatialWeight * colorWeight;\n\n            // Aggregating values. 
Only do this if flagNotClamped: Pixels outside the image are disregarded.\n            if (flagNotClamped) {\n              for (int i = 0; i < desc.channelCount; i++) {\n                // Distinguish cases for k!=i (calculation is done here)\n                // and k==i (partial derivatives are precalculated).\n                // If statement replaces center element of neighborhood/kernel.\n                if (kernelIndex[0] != halfWindowSize_x || kernelIndex[1] != halfWindowSize_y ||\n                    kernelIndex[2] != halfWindowSize_z) {\n                  filter_kernel = -(1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) *\n                          outputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight * colorDistance /\n                          (colorSigma * colorSigma) +\n                      (1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * totalWeight *\n                          (1 +\n                           inputTensorData[homeOffset + i * desc.channelStride] * colorDistance /\n                               (colorSigma * colorSigma)); // inputTensorData[homeOffset] !!\n                } else {\n                  filter_kernel = dO_dx_kiData[homeOffset + i * desc.channelStride];\n                }\n\n                valueSum += gradientInputTensorData[neighbourOffset + i * desc.channelStride] * filter_kernel;\n              }\n            }\n          } while (kernelIndex++);\n\n          // Do the filtering and calculate the values for the backward pass.\n          for (int i = 0; i < desc.channelCount; i++) {\n            // Filtering:\n            gradientOutputTensorData[homeOffset + i * desc.channelStride] = valueSum;\n          }\n        }\n      }\n  }\n\n  delete[] kernelSizes;\n  delete[] gaussianKernel_x;\n  delete[] gaussianKernel_y;\n  delete[] gaussianKernel_z;\n  delete[] xDistanceSquared;\n  delete[] yDistanceSquared;\n  delete[] zDistanceSquared;\n}\n\ntorch::Tensor 
BilateralFilterCpuBackward(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dx_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  // Preparing output tensor.\n  torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor);\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradientInputTensor.scalar_type(), \"BilateralFilterCpuBackward_3d\", ([&] {\n                                        BilateralFilterCpuBackward_3d<scalar_t>(\n                                            gradientInputTensor,\n                                            gradientOutputTensor,\n                                            inputTensor,\n                                            outputTensor,\n                                            outputWeightsTensor,\n                                            dO_dx_ki,\n                                            sigma_x,\n                                            sigma_y,\n                                            sigma_z,\n                                            colorSigma);\n                                      }));\n\n  return gradientOutputTensor;\n}\n"
  },
  {
    "path": "monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_forward.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n=========================================================================\nAdapted from https://github.com/faebstn96/trainable-bilateral-filter-source\nwhich has the following license...\nhttps://github.com/faebstn96/trainable-bilateral-filter-source/blob/main/LICENSE.md\n\nCopyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"trainable_bilateral.h\"\n#include \"utils/tensor_description.h\"\n#include \"utils/tensor_indexing.h\"\n\ntemplate <typename scalar_t>\nvoid BilateralFilterCpuForward_3d(\n    torch::Tensor inputTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dx_ki,\n    torch::Tensor dO_dsig_r,\n    torch::Tensor dO_dsig_x,\n    torch::Tensor dO_dsig_y,\n    torch::Tensor dO_dsig_z,\n    float sigma_x,\n    float sigma_y,\n    float 
sigma_z,\n    float colorSigma) {\n  // Getting tensor description.\n  TensorDescription desc = TensorDescription(inputTensor);\n\n  // Raw tensor data pointers.\n  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();\n  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();\n  scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr<scalar_t>();\n  scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr<scalar_t>();\n  scalar_t* dO_dsig_rData = dO_dsig_r.data_ptr<scalar_t>();\n  scalar_t* dO_dsig_xData = dO_dsig_x.data_ptr<scalar_t>();\n  scalar_t* dO_dsig_yData = dO_dsig_y.data_ptr<scalar_t>();\n  scalar_t* dO_dsig_zData = dO_dsig_z.data_ptr<scalar_t>();\n\n  // Pre-calculate common values\n  int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size\n  int halfWindowSize_x = floor(0.5f * windowSize_x);\n  int halfWindowSize_y = floor(0.5f * windowSize_y);\n  int halfWindowSize_z = floor(0.5f * windowSize_z);\n  int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z};\n  scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x);\n  scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y);\n  scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z);\n  scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);\n\n  // Set kernel sizes with respect to the defined spatial sigmas.\n  int* kernelSizes = new int[desc.dimensions];\n\n  kernelSizes[0] = windowSize_x;\n  kernelSizes[1] = windowSize_y;\n  kernelSizes[2] = windowSize_z;\n\n  // Pre-calculate gaussian kernel and distance map in 1D.\n  scalar_t* gaussianKernel_x = new scalar_t[windowSize_x];\n  scalar_t* gaussianKernel_y = new scalar_t[windowSize_y];\n  scalar_t* 
gaussianKernel_z = new scalar_t[windowSize_z];\n  scalar_t* xDistanceSquared = new scalar_t[windowSize_x];\n  scalar_t* yDistanceSquared = new scalar_t[windowSize_y];\n  scalar_t* zDistanceSquared = new scalar_t[windowSize_z];\n\n  for (int i = 0; i < windowSize_x; i++) {\n    int distance = i - halfWindowSize_x;\n    gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x);\n    xDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_y; i++) {\n    int distance = i - halfWindowSize_y;\n    gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y);\n    yDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_z; i++) {\n    int distance = i - halfWindowSize_z;\n    gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z);\n    zDistanceSquared[i] = distance * distance;\n  }\n\n  // Looping over the batches\n  for (int b = 0; b < desc.batchCount; b++) {\n    int batchOffset = b * desc.batchStride;\n\n    // Looping over all dimensions for the home element\n    for (int z = 0; z < desc.sizes[2]; z++)\n#pragma omp parallel for\n      for (int y = 0; y < desc.sizes[1]; y++) {\n        for (int x = 0; x < desc.sizes[0]; x++) {\n          // Calculating indexing offset for the home element\n          int homeOffset = batchOffset;\n\n          int homeIndex[] = {x, y, z};\n          homeOffset += x * desc.strides[0];\n          homeOffset += y * desc.strides[1];\n          homeOffset += z * desc.strides[2];\n\n          // Zero kernel aggregates.\n          scalar_t valueSum = 0;\n          scalar_t dw_dx_ki = 0;\n          scalar_t dfilter_dx_ki = 0;\n          scalar_t colorSum_w = 0;\n          scalar_t colorSum_alpha = 0;\n          scalar_t xSum_w = 0;\n          scalar_t xSum_alpha = 0;\n          scalar_t ySum_w = 0;\n          scalar_t ySum_alpha = 0;\n          scalar_t zSum_w = 0;\n          scalar_t zSum_alpha = 0;\n\n          scalar_t weightSum = 0.0f;\n\n          // 
Looping over all dimensions for the neighbour element\n          Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes);\n          do // while(kernelIndex++)\n          {\n            // Calculating buffer offset for the neighbour element\n            // Index is clamped to the border in each dimension.\n            int neighbourOffset = batchOffset;\n            bool flagNotClamped = true;\n\n            for (int i = 0; i < desc.dimensions; i++) {\n              int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i];\n              int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex));\n              neighbourOffset += neighbourIndexClamped * desc.strides[i];\n              if (neighbourIndex != neighbourIndexClamped) {\n                flagNotClamped = false;\n              }\n            }\n\n            // Euclidean color distance.\n            scalar_t colorDistance = 0;\n            scalar_t colorDistanceSquared = 0;\n\n            for (int i = 0; i < desc.channelCount; i++) {\n              scalar_t diff = inputTensorData[homeOffset + i * desc.channelStride] -\n                  inputTensorData[neighbourOffset + i * desc.channelStride];\n              colorDistance += diff; // Do not take the absolute value here. Be careful with the signs.\n              colorDistanceSquared += diff * diff;\n            }\n\n            // Calculating and combining the spatial\n            // and color weights.\n            scalar_t spatialWeight = 1;\n\n            spatialWeight =\n                gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]];\n\n            scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant);\n            scalar_t totalWeight = spatialWeight * colorWeight;\n\n            // Aggregating values. 
Only do this if flagNotClamped: Pixels outside the image are disregarded.\n            if (flagNotClamped) {\n              for (int i = 0; i < desc.channelCount; i++) {\n                valueSum += inputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight;\n\n                // Derivative of weights with respect to X_i while i=k.\n                dw_dx_ki += (-1) * totalWeight * colorDistance / (colorSigma * colorSigma);\n                // Derivative of convolved image with respect to X_i while i=k.\n                dfilter_dx_ki += (-1) * totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *\n                    colorDistance /\n                    (colorSigma *\n                     colorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData\n\n                colorSum_w += totalWeight * colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma);\n                colorSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *\n                    colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma);\n\n                xSum_w += totalWeight * xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x);\n                xSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *\n                    xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x);\n\n                ySum_w += totalWeight * yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y);\n                ySum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *\n                    yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y);\n\n                zSum_w += totalWeight * zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z);\n                zSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *\n 
                   zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z);\n              }\n\n              weightSum += totalWeight;\n            }\n          } while (kernelIndex++);\n\n          // Do the filtering and calculate the values for the backward pass.\n          for (int i = 0; i < desc.channelCount; i++) {\n            // Filtering:\n            outputTensorData[homeOffset + i * desc.channelStride] = valueSum / weightSum;\n\n            // Pre-computations for the backward pass:\n            outputWeightsTensorData[homeOffset + i * desc.channelStride] = weightSum;\n            dO_dx_kiData[homeOffset + i * desc.channelStride] = -(1 / weightSum) * (valueSum / weightSum) * dw_dx_ki +\n                (1 / weightSum) * (dfilter_dx_ki + 1); // +1 for dfilter_dx_ki is added here\n            dO_dsig_rData[homeOffset + i * desc.channelStride] =\n                -(1 / weightSum) * (valueSum / weightSum) * colorSum_w + (1 / weightSum) * colorSum_alpha;\n            dO_dsig_xData[homeOffset + i * desc.channelStride] =\n                -(1 / weightSum) * (valueSum / weightSum) * xSum_w + (1 / weightSum) * xSum_alpha;\n            dO_dsig_yData[homeOffset + i * desc.channelStride] =\n                -(1 / weightSum) * (valueSum / weightSum) * ySum_w + (1 / weightSum) * ySum_alpha;\n            dO_dsig_zData[homeOffset + i * desc.channelStride] =\n                -(1 / weightSum) * (valueSum / weightSum) * zSum_w + (1 / weightSum) * zSum_alpha;\n          }\n        }\n      }\n  }\n\n  delete[] kernelSizes;\n  delete[] gaussianKernel_x;\n  delete[] gaussianKernel_y;\n  delete[] gaussianKernel_z;\n  delete[] xDistanceSquared;\n  delete[] yDistanceSquared;\n  delete[] zDistanceSquared;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nBilateralFilterCpuForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma) {\n  // Preparing output 
tensor.\n  torch::Tensor outputTensor = torch::zeros_like(inputTensor);\n  torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dx_ki = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor);\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), \"BilateralFilterCpuForward_3d\", ([&] {\n                                        BilateralFilterCpuForward_3d<scalar_t>(\n                                            inputTensor,\n                                            outputTensor,\n                                            outputWeightsTensor,\n                                            dO_dx_ki,\n                                            dO_dsig_r,\n                                            dO_dsig_x,\n                                            dO_dsig_y,\n                                            dO_dsig_z,\n                                            sigma_x,\n                                            sigma_y,\n                                            sigma_z,\n                                            colorSigma);\n                                      }));\n\n  return {outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z};\n}\n"
  },
  {
    "path": "monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_backward.cu",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n=========================================================================\nAdapted from https://github.com/faebstn96/trainable-bilateral-filter-source\nwhich has the following license...\nhttps://github.com/faebstn96/trainable-bilateral-filter-source/blob/main/LICENSE.md\n\nCopyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include \"trainable_bilateral.h\"\n//#include \"../utils/cuda_error_check.h\"\n#include \"utils/meta_macros.h\"\n#include \"utils/tensor_description.h\"\n\n__constant__ int cBatchStrideBack;\n__constant__ int cColorStrideBack;\n\n__constant__ int cSizesBack[3];\n__constant__ int cStridesBack[3];\n\n__constant__ int cKernelSizesBack[3];\n__constant__ int cHalfWindowSize_arrBack[3];\n__constant__ float cGaussianKernel_xBack[256];\n__constant__ 
float cGaussianKernel_yBack[256];\n__constant__ float cGaussianKernel_zBack[256];\n__constant__ float cXDistanceSquaredBack[256];\n__constant__ float cYDistanceSquaredBack[256];\n__constant__ float cZDistanceSquaredBack[256];\n__constant__ float cColorExponentConstantBack;\n__constant__ float cSigma_xBack;\n__constant__ float cSigma_yBack;\n__constant__ float cSigma_zBack;\n__constant__ float cColorSigmaBack;\n\ntemplate <typename scalar_t, int C>\n__global__ void BilateralFilterCudaKernel3DBackward(\n    scalar_t* gradientInputTensor,\n    scalar_t* gradientOutputTensor,\n    scalar_t* inputTensor,\n    scalar_t* outputTensor,\n    scalar_t* outputWeightsTensor,\n    scalar_t* dO_dx_ki) {\n  int homeOffset = blockIdx.x * blockDim.x + threadIdx.x;\n  int batchOffset = blockIdx.y * cBatchStrideBack;\n\n  if (homeOffset >= cColorStrideBack)\n    return;\n\n  int homeX = homeOffset / cStridesBack[0];\n  int homeY = (homeOffset - homeX * cStridesBack[0]) / cStridesBack[1];\n  int homeZ = (homeOffset - homeX * cStridesBack[0] - homeY * cStridesBack[1]) / cStridesBack[2];\n  int homeIndex[] = {homeX, homeY, homeZ};\n\n  // Zero kernel aggregates.\n  scalar_t valueSum = 0;\n\n  for (int kernelX = 0; kernelX < cKernelSizesBack[0]; kernelX++) {\n    int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arrBack[0]), cSizesBack[0] - 1));\n    scalar_t gaussianX = cGaussianKernel_xBack[kernelX];\n\n    for (int kernelY = 0; kernelY < cKernelSizesBack[1]; kernelY++) {\n      int neighbourY = max(0, min(homeY + (kernelY - cHalfWindowSize_arrBack[1]), cSizesBack[1] - 1));\n      scalar_t gaussianY = cGaussianKernel_yBack[kernelY];\n\n      for (int kernelZ = 0; kernelZ < cKernelSizesBack[2]; kernelZ++) {\n        int neighbourZ = max(0, min(homeZ + (kernelZ - cHalfWindowSize_arrBack[2]), cSizesBack[2] - 1));\n        scalar_t gaussianZ = cGaussianKernel_zBack[kernelZ];\n\n        int neighbourOffset = neighbourX * cStridesBack[0] + neighbourY * cStridesBack[1] + 
neighbourZ;\n\n        bool flagNotClamped = true;\n        int kernelIndex[] = {kernelX, kernelY, kernelZ};\n        int dimensions = 3; // Must equal the number of spatial dimensions.\n\n        for (int i = 0; i < dimensions; i++) {\n          int HalfWindowSizeBack = cHalfWindowSize_arrBack[i]; // Define constant memory as new variable here (!!),\n                                                               // otherwise: cudaErrorMisalignedAddress\n          int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack;\n          int neighbourIndexClamped = min(cSizesBack[i] - 1, max(0, neighbourIndex));\n          if (neighbourIndex != neighbourIndexClamped) {\n            flagNotClamped = false;\n          }\n        }\n\n        scalar_t colorDistance = 0;\n        scalar_t colorDistanceSquared = 0;\n\n#pragma unroll\n        for (int c = 0; c < C; c++) {\n          scalar_t a = inputTensor[batchOffset + neighbourOffset + c * cColorStrideBack];\n          scalar_t b = inputTensor[batchOffset + homeOffset + c * cColorStrideBack]; // Be careful: Here it is (X_k -\n                                                                                     // X_i) and not (X_i - X_q)\n          scalar_t diff = a - b;\n          colorDistance += diff; // Do not take the absolute value here. Be careful with the signs.\n          colorDistanceSquared += diff * diff;\n        }\n\n        scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ;\n        scalar_t colorWeight = exp(cColorExponentConstantBack * colorDistanceSquared);\n        scalar_t totalWeight = spatialWeight * colorWeight;\n\n        // Aggregating values. 
Only do this if flagNotClamped: Pixels outside the image are disregarded.\n        if (flagNotClamped) {\n          scalar_t filter_kernel_back;\n\n#pragma unroll\n          for (int c = 0; c < C; c++) {\n            // Distinguish cases for k!=i (calculation is done here)\n            // and k==i (partial derivatives are precalculated).\n            // If statement replaces center element of neighborhood/kernel.\n            if (kernelX != cHalfWindowSize_arrBack[0] || kernelY != cHalfWindowSize_arrBack[1] ||\n                kernelZ != cHalfWindowSize_arrBack[2]) {\n              filter_kernel_back = -(1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) *\n                      outputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * totalWeight * colorDistance /\n                      (cColorSigmaBack * cColorSigmaBack) +\n                  (1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * totalWeight *\n                      (1 +\n                       inputTensor[batchOffset + homeOffset + c * cColorStrideBack] * colorDistance /\n                           (cColorSigmaBack * cColorSigmaBack)); // inputTensorData[homeOffset] !!\n            } else {\n              filter_kernel_back = dO_dx_ki[batchOffset + homeOffset + c * cColorStrideBack];\n            }\n\n            valueSum += gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * filter_kernel_back;\n          }\n        }\n      }\n    }\n  }\n\n#pragma unroll\n  for (int c = 0; c < C; c++) {\n    gradientOutputTensor[batchOffset + homeOffset + c * cColorStrideBack] = valueSum;\n  }\n}\n\ntemplate <int C, int D>\nvoid BilateralFilterCudaBackwardFunction(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor gradientOutputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dx_ki,\n    float sigma_x,\n    float sigma_y,\n    
float sigma_z,\n    float colorSigma) {\n  // Getting tensor description.\n  TensorDescription desc = TensorDescription(inputTensor);\n\n  // Pre-calculating gaussian kernel.\n  int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size\n  int halfWindowSize_x = floor(0.5f * windowSize_x);\n  int halfWindowSize_y = floor(0.5f * windowSize_y);\n  int halfWindowSize_z = floor(0.5f * windowSize_z);\n  int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z};\n  float spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x);\n  float spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y);\n  float spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z);\n  float colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);\n\n  int* kernelSizes = new int[desc.dimensions];\n  kernelSizes[0] = windowSize_x;\n  kernelSizes[1] = windowSize_y;\n  kernelSizes[2] = windowSize_z;\n\n  auto* gaussianKernel_x = new float[windowSize_x];\n  auto* gaussianKernel_y = new float[windowSize_y];\n  auto* gaussianKernel_z = new float[windowSize_z];\n  auto* xDistanceSquared = new float[windowSize_x];\n  auto* yDistanceSquared = new float[windowSize_y];\n  auto* zDistanceSquared = new float[windowSize_z];\n\n  for (int i = 0; i < windowSize_x; i++) {\n    int distance = i - halfWindowSize_x;\n    gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x);\n    xDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_y; i++) {\n    int distance = i - halfWindowSize_y;\n    gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y);\n    yDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_z; i++) {\n    int distance = i - 
halfWindowSize_z;\n    gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z);\n    zDistanceSquared[i] = distance * distance;\n  }\n\n  // Writing constant memory.\n  cudaMemcpyToSymbol(cBatchStrideBack, &desc.batchStride, sizeof(int));\n  cudaMemcpyToSymbol(cColorStrideBack, &desc.channelStride, sizeof(int));\n  cudaMemcpyToSymbol(cSizesBack, desc.sizes, sizeof(int) * 3);\n  cudaMemcpyToSymbol(cStridesBack, desc.strides, sizeof(int) * 3);\n  cudaMemcpyToSymbol(cKernelSizesBack, kernelSizes, sizeof(int) * desc.dimensions);\n  cudaMemcpyToSymbol(cHalfWindowSize_arrBack, halfWindowSize_arr, sizeof(int) * desc.dimensions);\n  cudaMemcpyToSymbol(cGaussianKernel_xBack, gaussianKernel_x, sizeof(float) * windowSize_x);\n  cudaMemcpyToSymbol(cGaussianKernel_yBack, gaussianKernel_y, sizeof(float) * windowSize_y);\n  cudaMemcpyToSymbol(cGaussianKernel_zBack, gaussianKernel_z, sizeof(float) * windowSize_z);\n  cudaMemcpyToSymbol(cXDistanceSquaredBack, xDistanceSquared, sizeof(float) * windowSize_x);\n  cudaMemcpyToSymbol(cYDistanceSquaredBack, yDistanceSquared, sizeof(float) * windowSize_y);\n  cudaMemcpyToSymbol(cZDistanceSquaredBack, zDistanceSquared, sizeof(float) * windowSize_z);\n  cudaMemcpyToSymbol(cColorExponentConstantBack, &colorExpConstant, sizeof(float));\n  cudaMemcpyToSymbol(cSigma_xBack, &sigma_x, sizeof(float));\n  cudaMemcpyToSymbol(cSigma_yBack, &sigma_y, sizeof(float));\n  cudaMemcpyToSymbol(cSigma_zBack, &sigma_z, sizeof(float));\n  cudaMemcpyToSymbol(cColorSigmaBack, &colorSigma, sizeof(float));\n\n  //  cuda_error_check(\"Cuda check before kernel call.\");\n\n#define BLOCK_SIZE 32\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n      inputTensor.scalar_type(), \"BilateralFilterCudaKernel3DBackward\", ([&] {\n        BilateralFilterCudaKernel3DBackward<scalar_t, C>\n            <<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(\n                gradientInputTensor.data_ptr<scalar_t>(),\n                
gradientOutputTensor.data_ptr<scalar_t>(),\n                inputTensor.data_ptr<scalar_t>(),\n                outputTensor.data_ptr<scalar_t>(),\n                outputWeightsTensor.data_ptr<scalar_t>(),\n                dO_dx_ki.data_ptr<scalar_t>());\n      }));\n\n  //  cuda_error_check(\"Cuda check after kernel call.\");\n  //  delete[] kernel;\n  delete[] kernelSizes;\n  delete[] gaussianKernel_x;\n  delete[] gaussianKernel_y;\n  delete[] gaussianKernel_z;\n  delete[] xDistanceSquared;\n  delete[] yDistanceSquared;\n  delete[] zDistanceSquared;\n}\n\n// Function to choose template implementation based on dynamic, channels and dimensions\ntorch::Tensor BilateralFilterCudaBackward(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dx_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor);\n  //  cuda_error_check(\"beginning\");\n\n#define CASE(c, d)                           \\\n  BilateralFilterCudaBackwardFunction<c, d>( \\\n      gradientInputTensor,                   \\\n      gradientOutputTensor,                  \\\n      inputTensor,                           \\\n      outputTensor,                          \\\n      outputWeightsTensor,                   \\\n      dO_dx_ki,                              \\\n      sigma_x,                               \\\n      sigma_y,                               \\\n      sigma_z,                               \\\n      colorSigma);\n  SWITCH_AB(\n      CASE,\n      BF_CUDA_MAX_CHANNELS,\n      BF_CUDA_MAX_SPATIAL_DIMENSION,\n      gradientInputTensor.size(1),\n      gradientInputTensor.dim() - 2);\n\n  return gradientOutputTensor;\n}\n"
  },
  {
    "path": "monai/csrc/filtering/trainable_bilateral/bf_layer_gpu_forward.cu",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n=========================================================================\nAdapted from https://github.com/faebstn96/trainable-bilateral-filter-source\nwhich has the following license...\nhttps://github.com/faebstn96/trainable-bilateral-filter-source/blob/main/LICENSE.md\n\nCopyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include \"trainable_bilateral.h\"\n//#include \"../utils/cuda_error_check.h\"\n#include \"utils/meta_macros.h\"\n#include \"utils/tensor_description.h\"\n\n__constant__ int cBatchStride;\n__constant__ int cColorStride;\n\n__constant__ int cSizes[3];\n__constant__ int cStrides[3];\n\n__constant__ int cKernelSizes[3];\n__constant__ int cHalfWindowSize_arr[3];\n__constant__ float cGaussianKernel_x[256];\n__constant__ float 
cGaussianKernel_y[256];\n__constant__ float cGaussianKernel_z[256];\n__constant__ float cXDistanceSquared[256];\n__constant__ float cYDistanceSquared[256];\n__constant__ float cZDistanceSquared[256];\n__constant__ float cColorExponentConstant;\n__constant__ float cSigma_x;\n__constant__ float cSigma_y;\n__constant__ float cSigma_z;\n__constant__ float cColorSigma;\n\ntemplate <typename scalar_t, int C>\n__global__ void BilateralFilterCudaKernel3DForward(\n    scalar_t* input,\n    scalar_t* output,\n    scalar_t* outputWeightsTensor,\n    scalar_t* dO_dx_ki,\n    scalar_t* dO_dsig_r,\n    scalar_t* dO_dsig_x,\n    scalar_t* dO_dsig_y,\n    scalar_t* dO_dsig_z) {\n  int homeOffset = blockIdx.x * blockDim.x + threadIdx.x;\n  int batchOffset = blockIdx.y * cBatchStride;\n\n  if (homeOffset >= cColorStride)\n    return;\n\n  int homeX = homeOffset / cStrides[0];\n  int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1];\n  int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2];\n  int homeIndex[] = {homeX, homeY, homeZ};\n\n  // Zero kernel aggregates.\n  scalar_t valueSum = 0;\n  scalar_t dw_dx_ki = 0;\n  scalar_t dfilter_dx_ki = 0;\n  scalar_t colorSum_w = 0;\n  scalar_t colorSum_alpha = 0;\n  scalar_t xSum_w = 0;\n  scalar_t xSum_alpha = 0;\n  scalar_t ySum_w = 0;\n  scalar_t ySum_alpha = 0;\n  scalar_t zSum_w = 0;\n  scalar_t zSum_alpha = 0;\n  scalar_t weightSum = 0;\n\n  for (int kernelX = 0; kernelX < cKernelSizes[0]; kernelX++) {\n    int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arr[0]), cSizes[0] - 1));\n    scalar_t gaussianX = cGaussianKernel_x[kernelX];\n\n    for (int kernelY = 0; kernelY < cKernelSizes[1]; kernelY++) {\n      int neighbourY = max(0, min(homeY + (kernelY - cHalfWindowSize_arr[1]), cSizes[1] - 1));\n      scalar_t gaussianY = cGaussianKernel_y[kernelY];\n\n      for (int kernelZ = 0; kernelZ < cKernelSizes[2]; kernelZ++) {\n        int neighbourZ = max(0, min(homeZ + (kernelZ - 
cHalfWindowSize_arr[2]), cSizes[2] - 1));\n        scalar_t gaussianZ = cGaussianKernel_z[kernelZ];\n\n        int neighbourOffset = neighbourX * cStrides[0] + neighbourY * cStrides[1] + neighbourZ;\n\n        bool flagNotClamped = true;\n        int kernelIndex[] = {kernelX, kernelY, kernelZ};\n        int dimensions = 3; // Must equal the number of spatial dimensions.\n\n        for (int i = 0; i < dimensions; i++) {\n          int HalfWindowSizeBack = cHalfWindowSize_arr[i]; // Define constant memory as new variable here (!!),\n                                                           // otherwise: cudaErrorMisalignedAddress\n          int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack;\n          int neighbourIndexClamped = min(cSizes[i] - 1, max(0, neighbourIndex));\n          if (neighbourIndex != neighbourIndexClamped) {\n            flagNotClamped = false;\n          }\n        }\n\n        scalar_t colorDistance = 0;\n        scalar_t colorDistanceSquared = 0;\n\n#pragma unroll\n        for (int c = 0; c < C; c++) {\n          scalar_t a = input[batchOffset + homeOffset + c * cColorStride];\n          scalar_t b = input[batchOffset + neighbourOffset + c * cColorStride]; // Home - neighbor (!!) in backward the\n                                                                                // other way around !!\n          scalar_t diff = a - b;\n          colorDistance += diff; // Do not take the absolute value here. Be careful with the signs.\n          colorDistanceSquared += diff * diff;\n        }\n\n        scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ;\n        scalar_t colorWeight = exp(cColorExponentConstant * colorDistanceSquared);\n        scalar_t totalWeight = spatialWeight * colorWeight;\n\n        // Aggregating values. 
Only do this if flagNotClamped: Pixels outside the image are disregarded.\n        if (flagNotClamped) {\n#pragma unroll\n          for (int c = 0; c < C; c++) {\n            valueSum += input[batchOffset + neighbourOffset + c * cColorStride] * totalWeight;\n\n            // Derivative of weights with respect to X_i while i=k.\n            dw_dx_ki += (-1) * totalWeight * colorDistance / (cColorSigma * cColorSigma);\n            // Derivative of convolved image with respect to X_i while i=k.\n            dfilter_dx_ki += (-1) * totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] *\n                colorDistance /\n                (cColorSigma *\n                 cColorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData\n\n            colorSum_w += totalWeight * colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma);\n            colorSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] *\n                colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma);\n\n            xSum_w += totalWeight * cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x);\n            xSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] *\n                cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x);\n\n            ySum_w += totalWeight * cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y);\n            ySum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] *\n                cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y);\n\n            zSum_w += totalWeight * cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z);\n            zSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] *\n                cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z);\n          
}\n\n          weightSum += totalWeight;\n        }\n      }\n    }\n  }\n\n#pragma unroll\n  for (int c = 0; c < C; c++) {\n    //    output[batchOffset + homeOffset + c * cColorStride] /= weightSum;\n    output[batchOffset + homeOffset + c * cColorStride] = valueSum / weightSum;\n\n    // Pre-computations for the backward pass:\n    outputWeightsTensor[batchOffset + homeOffset + c * cColorStride] = weightSum;\n    dO_dx_ki[batchOffset + homeOffset + c * cColorStride] = -(1 / weightSum) * (valueSum / weightSum) * dw_dx_ki +\n        (1 / weightSum) * (dfilter_dx_ki + 1); // +1 for dfilter_dx_ki is added here\n    dO_dsig_r[batchOffset + homeOffset + c * cColorStride] =\n        -(1 / weightSum) * (valueSum / weightSum) * colorSum_w + (1 / weightSum) * colorSum_alpha;\n    dO_dsig_x[batchOffset + homeOffset + c * cColorStride] =\n        -(1 / weightSum) * (valueSum / weightSum) * xSum_w + (1 / weightSum) * xSum_alpha;\n    dO_dsig_y[batchOffset + homeOffset + c * cColorStride] =\n        -(1 / weightSum) * (valueSum / weightSum) * ySum_w + (1 / weightSum) * ySum_alpha;\n    dO_dsig_z[batchOffset + homeOffset + c * cColorStride] =\n        -(1 / weightSum) * (valueSum / weightSum) * zSum_w + (1 / weightSum) * zSum_alpha;\n  }\n}\n\ntemplate <int C, int D>\nvoid BilateralFilterCudaForwardFunction(\n    torch::Tensor inputTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dx_ki,\n    torch::Tensor dO_dsig_r,\n    torch::Tensor dO_dsig_x,\n    torch::Tensor dO_dsig_y,\n    torch::Tensor dO_dsig_z,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  // Getting tensor description.\n  TensorDescription desc = TensorDescription(inputTensor);\n\n  // Pre-calculating gaussian kernel.\n  int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure 
odd window size\n  int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size\n  int halfWindowSize_x = floor(0.5f * windowSize_x);\n  int halfWindowSize_y = floor(0.5f * windowSize_y);\n  int halfWindowSize_z = floor(0.5f * windowSize_z);\n  int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z};\n  float spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x);\n  float spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y);\n  float spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z);\n  float colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);\n\n  int* kernelSizes = new int[desc.dimensions];\n  kernelSizes[0] = windowSize_x;\n  kernelSizes[1] = windowSize_y;\n  kernelSizes[2] = windowSize_z;\n\n  auto* gaussianKernel_x = new float[windowSize_x];\n  auto* gaussianKernel_y = new float[windowSize_y];\n  auto* gaussianKernel_z = new float[windowSize_z];\n  auto* xDistanceSquared = new float[windowSize_x];\n  auto* yDistanceSquared = new float[windowSize_y];\n  auto* zDistanceSquared = new float[windowSize_z];\n\n  for (int i = 0; i < windowSize_x; i++) {\n    int distance = i - halfWindowSize_x;\n    gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x);\n    xDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_y; i++) {\n    int distance = i - halfWindowSize_y;\n    gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y);\n    yDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_z; i++) {\n    int distance = i - halfWindowSize_z;\n    gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z);\n    zDistanceSquared[i] = distance * distance;\n  }\n\n  // Writing constant memory.\n  cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int));\n  cudaMemcpyToSymbol(cColorStride, &desc.channelStride, sizeof(int));\n  cudaMemcpyToSymbol(cSizes, desc.sizes, sizeof(int) * 3);\n  
cudaMemcpyToSymbol(cStrides, desc.strides, sizeof(int) * 3);\n  cudaMemcpyToSymbol(cKernelSizes, kernelSizes, sizeof(int) * desc.dimensions);\n  cudaMemcpyToSymbol(cHalfWindowSize_arr, halfWindowSize_arr, sizeof(int) * desc.dimensions);\n  cudaMemcpyToSymbol(cGaussianKernel_x, gaussianKernel_x, sizeof(float) * windowSize_x);\n  cudaMemcpyToSymbol(cGaussianKernel_y, gaussianKernel_y, sizeof(float) * windowSize_y);\n  cudaMemcpyToSymbol(cGaussianKernel_z, gaussianKernel_z, sizeof(float) * windowSize_z);\n  cudaMemcpyToSymbol(cXDistanceSquared, xDistanceSquared, sizeof(float) * windowSize_x);\n  cudaMemcpyToSymbol(cYDistanceSquared, yDistanceSquared, sizeof(float) * windowSize_y);\n  cudaMemcpyToSymbol(cZDistanceSquared, zDistanceSquared, sizeof(float) * windowSize_z);\n  cudaMemcpyToSymbol(cColorExponentConstant, &colorExpConstant, sizeof(float));\n  cudaMemcpyToSymbol(cSigma_x, &sigma_x, sizeof(float));\n  cudaMemcpyToSymbol(cSigma_y, &sigma_y, sizeof(float));\n  cudaMemcpyToSymbol(cSigma_z, &sigma_z, sizeof(float));\n  cudaMemcpyToSymbol(cColorSigma, &colorSigma, sizeof(float));\n\n  //  cuda_error_check(\"Cuda check before kernel call.\");\n\n#define BLOCK_SIZE 32\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n      inputTensor.scalar_type(), \"BilateralFilterCudaKernel3DForward\", ([&] {\n        BilateralFilterCudaKernel3DForward<scalar_t, C>\n            <<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(\n                inputTensor.data_ptr<scalar_t>(),\n                outputTensor.data_ptr<scalar_t>(),\n                outputWeightsTensor.data_ptr<scalar_t>(),\n                dO_dx_ki.data_ptr<scalar_t>(),\n                dO_dsig_r.data_ptr<scalar_t>(),\n                dO_dsig_x.data_ptr<scalar_t>(),\n                dO_dsig_y.data_ptr<scalar_t>(),\n                dO_dsig_z.data_ptr<scalar_t>());\n      }));\n\n  //  cuda_error_check(\"Cuda check after kernel call.\");\n  //  delete[] kernel;\n  delete[] 
kernelSizes;\n  delete[] gaussianKernel_x;\n  delete[] gaussianKernel_y;\n  delete[] gaussianKernel_z;\n  delete[] xDistanceSquared;\n  delete[] yDistanceSquared;\n  delete[] zDistanceSquared;\n}\n\n// Function to choose template implementation based on dynamic, channels and dimensions\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nBilateralFilterCudaForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma) {\n  torch::Tensor outputTensor = torch::zeros_like(inputTensor);\n  torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dx_ki = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor);\n  //  cuda_error_check(\"beginning\");\n\n#define CASE(c, d)                          \\\n  BilateralFilterCudaForwardFunction<c, d>( \\\n      inputTensor,                          \\\n      outputTensor,                         \\\n      outputWeightsTensor,                  \\\n      dO_dx_ki,                             \\\n      dO_dsig_r,                            \\\n      dO_dsig_x,                            \\\n      dO_dsig_y,                            \\\n      dO_dsig_z,                            \\\n      sigma_x,                              \\\n      sigma_y,                              \\\n      sigma_z,                              \\\n      colorSigma);\n  SWITCH_AB(CASE, BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2);\n\n  return {outputTensor, outputWeightsTensor, dO_dx_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z};\n}\n"
  },
  {
    "path": "monai/csrc/filtering/trainable_bilateral/trainable_bilateral.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n=========================================================================\nAdapted from https://github.com/faebstn96/trainable-bilateral-filter-source\nwhich has the following license...\nhttps://github.com/faebstn96/trainable-bilateral-filter-source/blob/main/LICENSE.md\n\nCopyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <torch/extension.h>\n#include <stdexcept>\n#include <string>\n\n#include \"trainable_bilateral.h\"\n#include \"utils/common_utils.h\"\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nTrainableBilateralFilterForward(\n    torch::Tensor inputTensor,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, 
torch::Tensor, torch::Tensor, torch::Tensor> (\n      *filterFunction)(torch::Tensor, float, float, float, float);\n\n#ifdef WITH_CUDA\n\n  if (torch::cuda::is_available() && inputTensor.is_cuda()) {\n    CHECK_CONTIGUOUS_CUDA(inputTensor);\n\n    if (inputTensor.size(1) > BF_CUDA_MAX_CHANNELS) {\n      throw std::runtime_error(\n          \"Bilateral filtering not implemented for channel count > \" + std::to_string(BF_CUDA_MAX_CHANNELS));\n    }\n\n    if (inputTensor.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) {\n      throw std::runtime_error(\n          \"Bilateral filtering not implemented for spatial dimension > \" +\n          std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION));\n    }\n\n    filterFunction = &BilateralFilterCudaForward;\n  } else {\n    filterFunction = &BilateralFilterCpuForward;\n  }\n#else\n  filterFunction = &BilateralFilterCpuForward;\n#endif\n\n  return filterFunction(inputTensor, sigma_x, sigma_y, sigma_z, colorSigma);\n}\n\ntorch::Tensor TrainableBilateralFilterBackward(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dx_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  torch::Tensor (*filterFunction)(\n      torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, float, float, float, float);\n\n#ifdef WITH_CUDA\n\n  if (torch::cuda::is_available() && gradientInputTensor.is_cuda()) {\n    CHECK_CONTIGUOUS_CUDA(gradientInputTensor);\n\n    if (gradientInputTensor.size(1) > BF_CUDA_MAX_CHANNELS) {\n      throw std::runtime_error(\n          \"Bilateral filtering not implemented for channel count > \" + std::to_string(BF_CUDA_MAX_CHANNELS));\n    }\n\n    if (gradientInputTensor.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) {\n      throw std::runtime_error(\n          \"Bilateral filtering not implemented for spatial dimension > \" +\n          
std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION));\n    }\n\n    filterFunction = &BilateralFilterCudaBackward;\n  } else {\n    filterFunction = &BilateralFilterCpuBackward;\n  }\n#else\n  filterFunction = &BilateralFilterCpuBackward;\n#endif\n\n  return filterFunction(\n      gradientInputTensor,\n      inputTensor,\n      outputTensor,\n      outputWeightsTensor,\n      dO_dx_ki,\n      sigma_x,\n      sigma_y,\n      sigma_z,\n      colorSigma);\n}\n"
  },
  {
    "path": "monai/csrc/filtering/trainable_bilateral/trainable_bilateral.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n=========================================================================\nAdapted from https://github.com/faebstn96/trainable-bilateral-filter-source\nwhich has the following license...\nhttps://github.com/faebstn96/trainable-bilateral-filter-source/blob/main/LICENSE.md\n\nCopyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#pragma once\n\n#include <torch/extension.h>\n#include <algorithm>\n#include <iostream>\n#include <vector>\n#include \"utils/common_utils.h\"\n//#include \"utils/tensor_description.h\"\n\n#define BF_CUDA_MAX_CHANNELS 16\n#define BF_CUDA_MAX_SPATIAL_DIMENSION 3\n\n#ifdef WITH_CUDA\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nBilateralFilterCudaForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float 
sigma_z, float colorSigma);\ntorch::Tensor BilateralFilterCudaBackward(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dx_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma);\n#endif\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nBilateralFilterCpuForward(torch::Tensor inputTensor, float sigma_x, float sigma_y, float sigma_z, float colorSigma);\n\ntorch::Tensor BilateralFilterCpuBackward(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dx_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma);\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nTrainableBilateralFilterForward(\n    torch::Tensor inputTensor,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma);\n\ntorch::Tensor TrainableBilateralFilterBackward(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dx_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma);\n"
  },
  {
    "path": "monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_cpu_backward.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n=========================================================================\nAdapted from https://github.com/faebstn96/trainable-joint-bilateral-filter-source\nwhich has the following license...\nhttps://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/LICENSE\n\nCopyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"trainable_joint_bilateral.h\"\n#include \"utils/tensor_description.h\"\n#include \"utils/tensor_indexing.h\"\n\ntemplate <typename scalar_t>\nvoid JointBilateralFilterCpuBackward_3d(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor gradientGuidanceTensor,\n    torch::Tensor gradientOutputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor 
dO_dz_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  // Getting tensor description.\n  TensorDescription desc = TensorDescription(gradientInputTensor);\n\n  // Raw tensor data pointers.\n  scalar_t* gradientInputTensorData = gradientInputTensor.data_ptr<scalar_t>();\n  scalar_t* gradientGuidanceTensorData = gradientGuidanceTensor.data_ptr<scalar_t>();\n  scalar_t* gradientOutputTensorData = gradientOutputTensor.data_ptr<scalar_t>();\n  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();\n  scalar_t* guidanceTensorData = guidanceTensor.data_ptr<scalar_t>();\n  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();\n  scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr<scalar_t>();\n  scalar_t* dO_dz_kiData = dO_dz_ki.data_ptr<scalar_t>();\n  //    scalar_t* dw_dx_kiData = dw_dx_ki_Tensor.data_ptr<scalar_t>();\n  //    scalar_t* dfilter_dx_kiData = dfilter_dx_ki_Tensor.data_ptr<scalar_t>();\n\n  // Pre-calculate common values\n  int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size\n  int halfWindowSize_x = floor(0.5f * windowSize_x);\n  int halfWindowSize_y = floor(0.5f * windowSize_y);\n  int halfWindowSize_z = floor(0.5f * windowSize_z);\n  int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z};\n  scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x);\n  scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y);\n  scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z);\n  scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);\n\n  // Set kernel sizes with respect to the defined spatial sigmas.\n  int* kernelSizes = new int[desc.dimensions];\n\n  
kernelSizes[0] = windowSize_x;\n  kernelSizes[1] = windowSize_y;\n  kernelSizes[2] = windowSize_z;\n\n  // Pre-calculate gaussian kernel and distance map in 1D.\n  scalar_t* gaussianKernel_x = new scalar_t[windowSize_x];\n  scalar_t* gaussianKernel_y = new scalar_t[windowSize_y];\n  scalar_t* gaussianKernel_z = new scalar_t[windowSize_z];\n  scalar_t* xDistanceSquared = new scalar_t[windowSize_x];\n  scalar_t* yDistanceSquared = new scalar_t[windowSize_y];\n  scalar_t* zDistanceSquared = new scalar_t[windowSize_z];\n\n  for (int i = 0; i < windowSize_x; i++) {\n    int distance = i - halfWindowSize_x;\n    gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x);\n    xDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_y; i++) {\n    int distance = i - halfWindowSize_y;\n    gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y);\n    yDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_z; i++) {\n    int distance = i - halfWindowSize_z;\n    gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z);\n    zDistanceSquared[i] = distance * distance;\n  }\n\n  // Looping over the batches\n  for (int b = 0; b < desc.batchCount; b++) {\n    int batchOffset = b * desc.batchStride;\n\n    // Looping over all dimensions for the home element\n    for (int z = 0; z < desc.sizes[2]; z++)\n#pragma omp parallel for\n      for (int y = 0; y < desc.sizes[1]; y++) {\n        for (int x = 0; x < desc.sizes[0]; x++) {\n          // Calculating indexing offset for the home element\n          int homeOffset = batchOffset;\n\n          int homeIndex[] = {x, y, z};\n          homeOffset += x * desc.strides[0];\n          homeOffset += y * desc.strides[1];\n          homeOffset += z * desc.strides[2];\n\n          // Zero kernel aggregates.\n          scalar_t filter_kernel_guidance = 0;\n          scalar_t valueSumGuidance = 0;\n          scalar_t valueSumInput = 0;\n\n          // Looping 
over all dimensions for the neighbour element\n          Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes);\n          do // while(kernelIndex++)\n          {\n            // Calculating buffer offset for the neighbour element\n            // Index is clamped to the border in each dimension.\n            int neighbourOffset = batchOffset;\n            bool flagNotClamped = true;\n\n            for (int i = 0; i < desc.dimensions; i++) {\n              int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i];\n              int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex));\n              neighbourOffset += neighbourIndexClamped * desc.strides[i];\n              if (neighbourIndex != neighbourIndexClamped) {\n                flagNotClamped = false;\n              }\n            }\n\n            // Euclidean color distance.\n            scalar_t colorDistance = 0;\n            scalar_t colorDistanceSquared = 0;\n\n            for (int i = 0; i < desc.channelCount; i++) {\n              scalar_t diff = guidanceTensorData[neighbourOffset + i * desc.channelStride] -\n                  guidanceTensorData[homeOffset + i * desc.channelStride]; // Be careful: Here it is (Z_k - Z_i) and not\n                                                                           // (Z_i - Z_q)\n              colorDistance += diff; // Do not take the absolute value here. Be careful with the signs.\n              colorDistanceSquared += diff * diff;\n            }\n\n            // Calculating and combining the spatial\n            // and color weights.\n            scalar_t spatialWeight = 1;\n\n            spatialWeight =\n                gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]];\n\n            scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant);\n            scalar_t totalWeight = spatialWeight * colorWeight;\n\n            // Aggregating values. 
Only do this if flagNotClamped: Pixels outside the image are disregarded.\n            if (flagNotClamped) {\n              for (int i = 0; i < desc.channelCount; i++) {\n                // Distinguish cases for k!=i (calculation is done here)\n                // and k==i (partial derivatives are precalculated).\n                // If statement replaces center element of neighborhood/kernel.\n                if (kernelIndex[0] != halfWindowSize_x || kernelIndex[1] != halfWindowSize_y ||\n                    kernelIndex[2] != halfWindowSize_z) {\n                  filter_kernel_guidance = -(1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) *\n                          outputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight * colorDistance /\n                          (colorSigma * colorSigma) +\n                      (1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * totalWeight *\n                          (inputTensorData[homeOffset + i * desc.channelStride] * colorDistance /\n                           (colorSigma * colorSigma)); // inputTensorData[homeOffset] !!, no +1!!\n                } else {\n                  filter_kernel_guidance = dO_dz_kiData[homeOffset + i * desc.channelStride];\n                }\n\n                valueSumGuidance +=\n                    gradientInputTensorData[neighbourOffset + i * desc.channelStride] * filter_kernel_guidance;\n                valueSumInput += gradientInputTensorData[neighbourOffset + i * desc.channelStride] *\n                    (1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * totalWeight;\n              }\n            }\n          } while (kernelIndex++);\n\n          // Do the filtering and calculate the values for the backward pass.\n          for (int i = 0; i < desc.channelCount; i++) {\n            // Filtering:\n            gradientGuidanceTensorData[homeOffset + i * desc.channelStride] = valueSumGuidance;\n            
gradientOutputTensorData[homeOffset + i * desc.channelStride] = valueSumInput;\n          }\n        }\n      }\n  }\n\n  delete[] kernelSizes;\n  delete[] gaussianKernel_x;\n  delete[] gaussianKernel_y;\n  delete[] gaussianKernel_z;\n  delete[] xDistanceSquared;\n  delete[] yDistanceSquared;\n  delete[] zDistanceSquared;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> JointBilateralFilterCpuBackward(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dz_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  // Preparing output tensor.\n  torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor);\n  torch::Tensor gradientGuidanceTensor = torch::zeros_like(gradientInputTensor);\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradientInputTensor.scalar_type(), \"JointBilateralFilterCpuBackward_3d\", ([&] {\n                                        JointBilateralFilterCpuBackward_3d<scalar_t>(\n                                            gradientInputTensor,\n                                            gradientGuidanceTensor,\n                                            gradientOutputTensor,\n                                            inputTensor,\n                                            guidanceTensor,\n                                            outputTensor,\n                                            outputWeightsTensor,\n                                            dO_dz_ki,\n                                            sigma_x,\n                                            sigma_y,\n                                            sigma_z,\n                                            colorSigma);\n                                      }));\n\n  return {gradientOutputTensor, gradientGuidanceTensor};\n}\n"
  },
  {
    "path": "monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_cpu_forward.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n=========================================================================\nAdapted from https://github.com/faebstn96/trainable-joint-bilateral-filter-source\nwhich has the following license...\nhttps://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/LICENSE\n\nCopyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include \"trainable_joint_bilateral.h\"\n#include \"utils/tensor_description.h\"\n#include \"utils/tensor_indexing.h\"\n\ntemplate <typename scalar_t>\nvoid JointBilateralFilterCpuForward_3d(\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dz_ki,\n    torch::Tensor dO_dsig_r,\n    torch::Tensor dO_dsig_x,\n    torch::Tensor dO_dsig_y,\n    torch::Tensor dO_dsig_z,\n    
float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  // Getting tensor description.\n  TensorDescription desc = TensorDescription(inputTensor);\n\n  // Raw tensor data pointers.\n  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();\n  scalar_t* guidanceTensorData = guidanceTensor.data_ptr<scalar_t>();\n  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();\n  scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr<scalar_t>();\n  scalar_t* dO_dz_kiData = dO_dz_ki.data_ptr<scalar_t>();\n  scalar_t* dO_dsig_rData = dO_dsig_r.data_ptr<scalar_t>();\n  scalar_t* dO_dsig_xData = dO_dsig_x.data_ptr<scalar_t>();\n  scalar_t* dO_dsig_yData = dO_dsig_y.data_ptr<scalar_t>();\n  scalar_t* dO_dsig_zData = dO_dsig_z.data_ptr<scalar_t>();\n\n  // Pre-calculate common values\n  int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size\n  int halfWindowSize_x = floor(0.5f * windowSize_x);\n  int halfWindowSize_y = floor(0.5f * windowSize_y);\n  int halfWindowSize_z = floor(0.5f * windowSize_z);\n  int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z};\n  scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x);\n  scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y);\n  scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z);\n  scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);\n\n  // Set kernel sizes with respect to the defined spatial sigmas.\n  int* kernelSizes = new int[desc.dimensions];\n\n  kernelSizes[0] = windowSize_x;\n  kernelSizes[1] = windowSize_y;\n  kernelSizes[2] = windowSize_z;\n\n  // Pre-calculate gaussian kernel and distance map in 1D.\n  scalar_t* 
gaussianKernel_x = new scalar_t[windowSize_x];\n  scalar_t* gaussianKernel_y = new scalar_t[windowSize_y];\n  scalar_t* gaussianKernel_z = new scalar_t[windowSize_z];\n  scalar_t* xDistanceSquared = new scalar_t[windowSize_x];\n  scalar_t* yDistanceSquared = new scalar_t[windowSize_y];\n  scalar_t* zDistanceSquared = new scalar_t[windowSize_z];\n\n  for (int i = 0; i < windowSize_x; i++) {\n    int distance = i - halfWindowSize_x;\n    gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x);\n    xDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_y; i++) {\n    int distance = i - halfWindowSize_y;\n    gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y);\n    yDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_z; i++) {\n    int distance = i - halfWindowSize_z;\n    gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z);\n    zDistanceSquared[i] = distance * distance;\n  }\n\n  // Looping over the batches\n  for (int b = 0; b < desc.batchCount; b++) {\n    int batchOffset = b * desc.batchStride;\n\n    // Looping over all dimensions for the home element\n    for (int z = 0; z < desc.sizes[2]; z++)\n#pragma omp parallel for\n      for (int y = 0; y < desc.sizes[1]; y++) {\n        for (int x = 0; x < desc.sizes[0]; x++) {\n          // Calculating indexing offset for the home element\n          int homeOffset = batchOffset;\n\n          int homeIndex[] = {x, y, z};\n          homeOffset += x * desc.strides[0];\n          homeOffset += y * desc.strides[1];\n          homeOffset += z * desc.strides[2];\n\n          // Zero kernel aggregates.\n          scalar_t valueSum = 0;\n          scalar_t dw_dz_ki = 0;\n          scalar_t dfilter_dz_ki = 0;\n          scalar_t colorSum_w = 0;\n          scalar_t colorSum_alpha = 0;\n          scalar_t xSum_w = 0;\n          scalar_t xSum_alpha = 0;\n          scalar_t ySum_w = 0;\n          scalar_t ySum_alpha = 0;\n       
   scalar_t zSum_w = 0;\n          scalar_t zSum_alpha = 0;\n\n          scalar_t weightSum = 0.0f;\n\n          // Looping over all dimensions for the neighbour element\n          Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes);\n          do // while(kernelIndex++)\n          {\n            // Calculating buffer offset for the neighbour element\n            // Index is clamped to the border in each dimension.\n            int neighbourOffset = batchOffset;\n            bool flagNotClamped = true;\n\n            for (int i = 0; i < desc.dimensions; i++) {\n              int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i];\n              int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex));\n              neighbourOffset += neighbourIndexClamped * desc.strides[i];\n              if (neighbourIndex != neighbourIndexClamped) {\n                flagNotClamped = false;\n              }\n            }\n\n            // Euclidean color distance.\n            scalar_t colorDistance = 0;\n            scalar_t colorDistanceSquared = 0;\n\n            for (int i = 0; i < desc.channelCount; i++) {\n              scalar_t diff = guidanceTensorData[homeOffset + i * desc.channelStride] -\n                  guidanceTensorData[neighbourOffset + i * desc.channelStride];\n              colorDistance += diff; // Do not take the absolute value here. Be careful with the signs.\n              colorDistanceSquared += diff * diff;\n            }\n\n            // Calculating and combining the spatial\n            // and color weights.\n            scalar_t spatialWeight = 1;\n\n            spatialWeight =\n                gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]];\n\n            scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant);\n            scalar_t totalWeight = spatialWeight * colorWeight;\n\n            // Aggregating values. 
Only do this if flagNotClamped: Pixels outside the image are disregarded.\n            if (flagNotClamped) {\n              for (int i = 0; i < desc.channelCount; i++) {\n                valueSum += inputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight;\n\n                // Derivative of weights with respect to X_i while i=k.\n                dw_dz_ki += (-1) * totalWeight * colorDistance / (colorSigma * colorSigma);\n                // Derivative of convolved image with respect to X_i while i=k.\n                dfilter_dz_ki += (-1) * totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *\n                    colorDistance /\n                    (colorSigma *\n                     colorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData\n\n                colorSum_w += totalWeight * colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma);\n                colorSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *\n                    colorDistanceSquared / std::abs(colorSigma * colorSigma * colorSigma);\n\n                xSum_w += totalWeight * xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x);\n                xSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *\n                    xDistanceSquared[kernelIndex[0]] / std::abs(sigma_x * sigma_x * sigma_x);\n\n                ySum_w += totalWeight * yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y);\n                ySum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *\n                    yDistanceSquared[kernelIndex[1]] / std::abs(sigma_y * sigma_y * sigma_y);\n\n                zSum_w += totalWeight * zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z);\n                zSum_alpha += totalWeight * inputTensorData[neighbourOffset + i * desc.channelStride] *\n 
                   zDistanceSquared[kernelIndex[2]] / std::abs(sigma_z * sigma_z * sigma_z);\n              }\n\n              weightSum += totalWeight;\n            }\n          } while (kernelIndex++);\n\n          // Do the filtering and calculate the values for the backward pass.\n          for (int i = 0; i < desc.channelCount; i++) {\n            // Filtering:\n            outputTensorData[homeOffset + i * desc.channelStride] = valueSum / weightSum;\n\n            // Pre-computations for the backward pass:\n            outputWeightsTensorData[homeOffset + i * desc.channelStride] = weightSum;\n            dO_dz_kiData[homeOffset + i * desc.channelStride] = -(1 / weightSum) * (valueSum / weightSum) * dw_dz_ki +\n                (1 / weightSum) * (dfilter_dz_ki); // no +1 for dfilter_dz_ki for JBF added here!\n            dO_dsig_rData[homeOffset + i * desc.channelStride] =\n                -(1 / weightSum) * (valueSum / weightSum) * colorSum_w + (1 / weightSum) * colorSum_alpha;\n            dO_dsig_xData[homeOffset + i * desc.channelStride] =\n                -(1 / weightSum) * (valueSum / weightSum) * xSum_w + (1 / weightSum) * xSum_alpha;\n            dO_dsig_yData[homeOffset + i * desc.channelStride] =\n                -(1 / weightSum) * (valueSum / weightSum) * ySum_w + (1 / weightSum) * ySum_alpha;\n            dO_dsig_zData[homeOffset + i * desc.channelStride] =\n                -(1 / weightSum) * (valueSum / weightSum) * zSum_w + (1 / weightSum) * zSum_alpha;\n          }\n        }\n      }\n  }\n\n  delete[] kernelSizes;\n  delete[] gaussianKernel_x;\n  delete[] gaussianKernel_y;\n  delete[] gaussianKernel_z;\n  delete[] xDistanceSquared;\n  delete[] yDistanceSquared;\n  delete[] zDistanceSquared;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nJointBilateralFilterCpuForward(\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    float sigma_x,\n    float 
sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  // Preparing output tensor.\n  torch::Tensor outputTensor = torch::zeros_like(inputTensor);\n  torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dz_ki = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor);\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputTensor.scalar_type(), \"JointBilateralFilterCpuForward_3d\", ([&] {\n                                        JointBilateralFilterCpuForward_3d<scalar_t>(\n                                            inputTensor,\n                                            guidanceTensor,\n                                            outputTensor,\n                                            outputWeightsTensor,\n                                            dO_dz_ki,\n                                            dO_dsig_r,\n                                            dO_dsig_x,\n                                            dO_dsig_y,\n                                            dO_dsig_z,\n                                            sigma_x,\n                                            sigma_y,\n                                            sigma_z,\n                                            colorSigma);\n                                      }));\n\n  return {outputTensor, outputWeightsTensor, dO_dz_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z};\n}\n"
  },
  {
    "path": "monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_gpu_backward.cu",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n=========================================================================\nAdapted from https://github.com/faebstn96/trainable-joint-bilateral-filter-source\nwhich has the following license...\nhttps://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/LICENSE\n\nCopyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include \"trainable_joint_bilateral.h\"\n//#include \"../utils/cuda_error_check.h\"\n#include \"utils/meta_macros.h\"\n#include \"utils/tensor_description.h\"\n\n__constant__ int cBatchStrideBack;\n__constant__ int cColorStrideBack;\n\n__constant__ int cSizesBack[3];\n__constant__ int cStridesBack[3];\n\n__constant__ int cKernelSizesBack[3];\n__constant__ int cHalfWindowSize_arrBack[3];\n__constant__ float 
cGaussianKernel_xBack[256];\n__constant__ float cGaussianKernel_yBack[256];\n__constant__ float cGaussianKernel_zBack[256];\n__constant__ float cXDistanceSquaredBack[256];\n__constant__ float cYDistanceSquaredBack[256];\n__constant__ float cZDistanceSquaredBack[256];\n__constant__ float cColorExponentConstantBack;\n__constant__ float cSigma_xBack;\n__constant__ float cSigma_yBack;\n__constant__ float cSigma_zBack;\n__constant__ float cColorSigmaBack;\n\ntemplate <typename scalar_t, int C>\n__global__ void JointBilateralFilterCudaKernel3DBackward(\n    scalar_t* gradientInputTensor,\n    scalar_t* gradientGuidanceTensor,\n    scalar_t* gradientOutputTensor,\n    scalar_t* inputTensor,\n    scalar_t* guidanceTensor,\n    scalar_t* outputTensor,\n    scalar_t* outputWeightsTensor,\n    scalar_t* dO_dz_ki) {\n  int homeOffset = blockIdx.x * blockDim.x + threadIdx.x;\n  int batchOffset = blockIdx.y * cBatchStrideBack;\n\n  if (homeOffset >= cColorStrideBack)\n    return;\n\n  int homeX = homeOffset / cStridesBack[0];\n  int homeY = (homeOffset - homeX * cStridesBack[0]) / cStridesBack[1];\n  int homeZ = (homeOffset - homeX * cStridesBack[0] - homeY * cStridesBack[1]) / cStridesBack[2];\n  int homeIndex[] = {homeX, homeY, homeZ};\n\n  // Zero kernel aggregates.\n  scalar_t valueSumGuidance = 0;\n  scalar_t valueSumInput = 0;\n\n  for (int kernelX = 0; kernelX < cKernelSizesBack[0]; kernelX++) {\n    int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arrBack[0]), cSizesBack[0] - 1));\n    scalar_t gaussianX = cGaussianKernel_xBack[kernelX];\n\n    for (int kernelY = 0; kernelY < cKernelSizesBack[1]; kernelY++) {\n      int neighbourY = max(0, min(homeY + (kernelY - cHalfWindowSize_arrBack[1]), cSizesBack[1] - 1));\n      scalar_t gaussianY = cGaussianKernel_yBack[kernelY];\n\n      for (int kernelZ = 0; kernelZ < cKernelSizesBack[2]; kernelZ++) {\n        int neighbourZ = max(0, min(homeZ + (kernelZ - cHalfWindowSize_arrBack[2]), cSizesBack[2] - 1));\n        
scalar_t gaussianZ = cGaussianKernel_zBack[kernelZ];\n\n        int neighbourOffset = neighbourX * cStridesBack[0] + neighbourY * cStridesBack[1] + neighbourZ;\n\n        bool flagNotClamped = true;\n        int kernelIndex[] = {kernelX, kernelY, kernelZ};\n        int dimensions = 3; // Must equal the number of spatial dimensions.\n\n        for (int i = 0; i < dimensions; i++) {\n          int HalfWindowSizeBack = cHalfWindowSize_arrBack[i]; // Define constant memory as new variable here (!!),\n                                                               // otherwise: cudaErrorMisalignedAddress\n          int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack;\n          int neighbourIndexClamped = min(cSizesBack[i] - 1, max(0, neighbourIndex));\n          if (neighbourIndex != neighbourIndexClamped) {\n            flagNotClamped = false;\n          }\n        }\n\n        scalar_t colorDistance = 0;\n        scalar_t colorDistanceSquared = 0;\n\n#pragma unroll\n        for (int c = 0; c < C; c++) {\n          scalar_t a = guidanceTensor[batchOffset + neighbourOffset + c * cColorStrideBack];\n          scalar_t b = guidanceTensor[batchOffset + homeOffset + c * cColorStrideBack]; // Be careful: Here it is (Z_k -\n                                                                                        // Z_i) and not (Z_i - Z_q)\n          scalar_t diff = a - b;\n          colorDistance += diff; // Do not take the absolute value here. Be careful with the signs.\n          colorDistanceSquared += diff * diff;\n        }\n\n        scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ;\n        scalar_t colorWeight = exp(cColorExponentConstantBack * colorDistanceSquared);\n        scalar_t totalWeight = spatialWeight * colorWeight;\n\n        // Aggregating values. 
Only do this if flagNotClamped: Pixels outside the image are disregarded.\n        if (flagNotClamped) {\n          scalar_t filter_kernel_guidance_back;\n\n#pragma unroll\n          for (int c = 0; c < C; c++) {\n            // Distinguish cases for k!=i (calculation is done here)\n            // and k==i (partial derivatives are precalculated).\n            // If statement replaces center element of neighborhood/kernel.\n            if (kernelX != cHalfWindowSize_arrBack[0] || kernelY != cHalfWindowSize_arrBack[1] ||\n                kernelZ != cHalfWindowSize_arrBack[2]) {\n              filter_kernel_guidance_back =\n                  -(1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) *\n                      outputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * totalWeight * colorDistance /\n                      (cColorSigmaBack * cColorSigmaBack) +\n                  (1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * totalWeight *\n                      (inputTensor[batchOffset + homeOffset + c * cColorStrideBack] * colorDistance /\n                       (cColorSigmaBack * cColorSigmaBack)); // inputTensorData[homeOffset] !!, no +1!!\n            } else {\n              filter_kernel_guidance_back = dO_dz_ki[batchOffset + homeOffset + c * cColorStrideBack];\n            }\n\n            valueSumGuidance +=\n                gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] * filter_kernel_guidance_back;\n            valueSumInput += gradientInputTensor[batchOffset + neighbourOffset + c * cColorStrideBack] *\n                (1 / outputWeightsTensor[batchOffset + neighbourOffset + c * cColorStrideBack]) * totalWeight;\n          }\n        }\n      }\n    }\n  }\n\n#pragma unroll\n  for (int c = 0; c < C; c++) {\n    gradientGuidanceTensor[batchOffset + homeOffset + c * cColorStrideBack] = valueSumGuidance;\n    gradientOutputTensor[batchOffset + homeOffset + c 
* cColorStrideBack] = valueSumInput;\n  }\n}\n\ntemplate <int C, int D>\nvoid JointBilateralFilterCudaBackwardFunction(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor gradientGuidanceTensor,\n    torch::Tensor gradientOutputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dz_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  // Getting tensor description.\n  TensorDescription desc = TensorDescription(inputTensor);\n\n  // Pre-calculating gaussian kernel.\n  int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size\n  int halfWindowSize_x = floor(0.5f * windowSize_x);\n  int halfWindowSize_y = floor(0.5f * windowSize_y);\n  int halfWindowSize_z = floor(0.5f * windowSize_z);\n  int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z};\n  float spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x);\n  float spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y);\n  float spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z);\n  float colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);\n\n  int* kernelSizes = new int[desc.dimensions];\n  kernelSizes[0] = windowSize_x;\n  kernelSizes[1] = windowSize_y;\n  kernelSizes[2] = windowSize_z;\n\n  auto* gaussianKernel_x = new float[windowSize_x];\n  auto* gaussianKernel_y = new float[windowSize_y];\n  auto* gaussianKernel_z = new float[windowSize_z];\n  auto* xDistanceSquared = new float[windowSize_x];\n  auto* yDistanceSquared = new float[windowSize_y];\n  auto* zDistanceSquared = new float[windowSize_z];\n\n  for (int i = 0; i < windowSize_x; i++) {\n  
  int distance = i - halfWindowSize_x;\n    gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x);\n    xDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_y; i++) {\n    int distance = i - halfWindowSize_y;\n    gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y);\n    yDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_z; i++) {\n    int distance = i - halfWindowSize_z;\n    gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z);\n    zDistanceSquared[i] = distance * distance;\n  }\n\n  // Writing constant memory.\n  cudaMemcpyToSymbol(cBatchStrideBack, &desc.batchStride, sizeof(int));\n  cudaMemcpyToSymbol(cColorStrideBack, &desc.channelStride, sizeof(int));\n  cudaMemcpyToSymbol(cSizesBack, desc.sizes, sizeof(int) * 3);\n  cudaMemcpyToSymbol(cStridesBack, desc.strides, sizeof(int) * 3);\n  cudaMemcpyToSymbol(cKernelSizesBack, kernelSizes, sizeof(int) * desc.dimensions);\n  cudaMemcpyToSymbol(cHalfWindowSize_arrBack, halfWindowSize_arr, sizeof(int) * desc.dimensions);\n  cudaMemcpyToSymbol(cGaussianKernel_xBack, gaussianKernel_x, sizeof(float) * windowSize_x);\n  cudaMemcpyToSymbol(cGaussianKernel_yBack, gaussianKernel_y, sizeof(float) * windowSize_y);\n  cudaMemcpyToSymbol(cGaussianKernel_zBack, gaussianKernel_z, sizeof(float) * windowSize_z);\n  cudaMemcpyToSymbol(cXDistanceSquaredBack, xDistanceSquared, sizeof(float) * windowSize_x);\n  cudaMemcpyToSymbol(cYDistanceSquaredBack, yDistanceSquared, sizeof(float) * windowSize_y);\n  cudaMemcpyToSymbol(cZDistanceSquaredBack, zDistanceSquared, sizeof(float) * windowSize_z);\n  cudaMemcpyToSymbol(cColorExponentConstantBack, &colorExpConstant, sizeof(float));\n  cudaMemcpyToSymbol(cSigma_xBack, &sigma_x, sizeof(float));\n  cudaMemcpyToSymbol(cSigma_yBack, &sigma_y, sizeof(float));\n  cudaMemcpyToSymbol(cSigma_zBack, &sigma_z, sizeof(float));\n  cudaMemcpyToSymbol(cColorSigmaBack, &colorSigma, 
sizeof(float));\n\n  //  cuda_error_check(\"Cuda check before kernel call.\");\n\n#define BLOCK_SIZE 32\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n      inputTensor.scalar_type(), \"JointBilateralFilterCudaKernel3DBackward\", ([&] {\n        JointBilateralFilterCudaKernel3DBackward<scalar_t, C>\n            <<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(\n                gradientInputTensor.data_ptr<scalar_t>(),\n                gradientGuidanceTensor.data_ptr<scalar_t>(),\n                gradientOutputTensor.data_ptr<scalar_t>(),\n                inputTensor.data_ptr<scalar_t>(),\n                guidanceTensor.data_ptr<scalar_t>(),\n                outputTensor.data_ptr<scalar_t>(),\n                outputWeightsTensor.data_ptr<scalar_t>(),\n                dO_dz_ki.data_ptr<scalar_t>());\n      }));\n\n  //  cuda_error_check(\"Cuda check after kernel call.\");\n  //  delete[] kernel;\n  delete[] kernelSizes;\n  delete[] gaussianKernel_x;\n  delete[] gaussianKernel_y;\n  delete[] gaussianKernel_z;\n  delete[] xDistanceSquared;\n  delete[] yDistanceSquared;\n  delete[] zDistanceSquared;\n}\n\n// Function to choose template implementation based on dynamic, channels and dimensions\nstd::tuple<torch::Tensor, torch::Tensor> JointBilateralFilterCudaBackward(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dz_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor);\n  torch::Tensor gradientGuidanceTensor = torch::zeros_like(gradientInputTensor);\n  //  cuda_error_check(\"beginning\");\n\n#define CASE(c, d)                                \\\n  JointBilateralFilterCudaBackwardFunction<c, d>( \\\n      gradientInputTensor,                        \\\n      
gradientGuidanceTensor,                     \\\n      gradientOutputTensor,                       \\\n      inputTensor,                                \\\n      guidanceTensor,                             \\\n      outputTensor,                               \\\n      outputWeightsTensor,                        \\\n      dO_dz_ki,                                   \\\n      sigma_x,                                    \\\n      sigma_y,                                    \\\n      sigma_z,                                    \\\n      colorSigma);\n  SWITCH_AB(\n      CASE,\n      BF_CUDA_MAX_CHANNELS,\n      BF_CUDA_MAX_SPATIAL_DIMENSION,\n      gradientInputTensor.size(1),\n      gradientInputTensor.dim() - 2);\n\n  return {gradientOutputTensor, gradientGuidanceTensor};\n}\n"
  },
  {
    "path": "monai/csrc/filtering/trainable_joint_bilateral/jbf_layer_gpu_forward.cu",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n=========================================================================\nAdapted from https://github.com/faebstn96/trainable-joint-bilateral-filter-source\nwhich has the following license...\nhttps://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/LICENSE\n\nCopyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include \"trainable_joint_bilateral.h\"\n//#include \"../utils/cuda_error_check.h\"\n#include \"utils/meta_macros.h\"\n#include \"utils/tensor_description.h\"\n\n__constant__ int cBatchStride;\n__constant__ int cColorStride;\n\n__constant__ int cSizes[3];\n__constant__ int cStrides[3];\n\n__constant__ int cKernelSizes[3];\n__constant__ int cHalfWindowSize_arr[3];\n__constant__ float cGaussianKernel_x[256];\n__constant__ float 
cGaussianKernel_y[256];\n__constant__ float cGaussianKernel_z[256];\n__constant__ float cXDistanceSquared[256];\n__constant__ float cYDistanceSquared[256];\n__constant__ float cZDistanceSquared[256];\n__constant__ float cColorExponentConstant;\n__constant__ float cSigma_x;\n__constant__ float cSigma_y;\n__constant__ float cSigma_z;\n__constant__ float cColorSigma;\n\ntemplate <typename scalar_t, int C>\n__global__ void JointBilateralFilterCudaKernel3DForward(\n    scalar_t* input,\n    scalar_t* guidance,\n    scalar_t* output,\n    scalar_t* outputWeightsTensor,\n    scalar_t* dO_dz_ki,\n    scalar_t* dO_dsig_r,\n    scalar_t* dO_dsig_x,\n    scalar_t* dO_dsig_y,\n    scalar_t* dO_dsig_z) {\n  int homeOffset = blockIdx.x * blockDim.x + threadIdx.x;\n  int batchOffset = blockIdx.y * cBatchStride;\n\n  if (homeOffset >= cColorStride)\n    return;\n\n  int homeX = homeOffset / cStrides[0];\n  int homeY = (homeOffset - homeX * cStrides[0]) / cStrides[1];\n  int homeZ = (homeOffset - homeX * cStrides[0] - homeY * cStrides[1]) / cStrides[2];\n  int homeIndex[] = {homeX, homeY, homeZ};\n\n  // Zero kernel aggregates.\n  scalar_t valueSum = 0;\n  scalar_t dw_dz_ki = 0;\n  scalar_t dfilter_dz_ki = 0;\n  scalar_t colorSum_w = 0;\n  scalar_t colorSum_alpha = 0;\n  scalar_t xSum_w = 0;\n  scalar_t xSum_alpha = 0;\n  scalar_t ySum_w = 0;\n  scalar_t ySum_alpha = 0;\n  scalar_t zSum_w = 0;\n  scalar_t zSum_alpha = 0;\n  scalar_t weightSum = 0;\n\n  for (int kernelX = 0; kernelX < cKernelSizes[0]; kernelX++) {\n    int neighbourX = max(0, min(homeX + (kernelX - cHalfWindowSize_arr[0]), cSizes[0] - 1));\n    scalar_t gaussianX = cGaussianKernel_x[kernelX];\n\n    for (int kernelY = 0; kernelY < cKernelSizes[1]; kernelY++) {\n      int neighbourY = max(0, min(homeY + (kernelY - cHalfWindowSize_arr[1]), cSizes[1] - 1));\n      scalar_t gaussianY = cGaussianKernel_y[kernelY];\n\n      for (int kernelZ = 0; kernelZ < cKernelSizes[2]; kernelZ++) {\n        int neighbourZ = max(0, 
min(homeZ + (kernelZ - cHalfWindowSize_arr[2]), cSizes[2] - 1));\n        scalar_t gaussianZ = cGaussianKernel_z[kernelZ];\n\n        int neighbourOffset = neighbourX * cStrides[0] + neighbourY * cStrides[1] + neighbourZ;\n\n        bool flagNotClamped = true;\n        int kernelIndex[] = {kernelX, kernelY, kernelZ};\n        int dimensions = 3; // Must equal the number of spatial dimensions.\n\n        for (int i = 0; i < dimensions; i++) {\n          int HalfWindowSizeBack = cHalfWindowSize_arr[i]; // Define constant memory as new variable here (!!),\n                                                           // otherwise: cudaErrorMisalignedAddress\n          int neighbourIndex = homeIndex[i] + kernelIndex[i] - HalfWindowSizeBack;\n          int neighbourIndexClamped = min(cSizes[i] - 1, max(0, neighbourIndex));\n          if (neighbourIndex != neighbourIndexClamped) {\n            flagNotClamped = false;\n          }\n        }\n\n        scalar_t colorDistance = 0;\n        scalar_t colorDistanceSquared = 0;\n\n#pragma unroll\n        for (int c = 0; c < C; c++) {\n          scalar_t a = guidance[batchOffset + homeOffset + c * cColorStride];\n          scalar_t b = guidance[batchOffset + neighbourOffset + c * cColorStride]; // Home - neighbor (!!) in backward\n                                                                                   // the other way around !!\n          scalar_t diff = a - b;\n          colorDistance += diff; // Do not take the absolute value here. Be careful with the signs.\n          colorDistanceSquared += diff * diff;\n        }\n\n        scalar_t spatialWeight = gaussianX * gaussianY * gaussianZ;\n        scalar_t colorWeight = exp(cColorExponentConstant * colorDistanceSquared);\n        scalar_t totalWeight = spatialWeight * colorWeight;\n\n        // Aggregating values. 
Only do this if flagNotClamped: Pixels outside the image are disregarded.\n        if (flagNotClamped) {\n#pragma unroll\n          for (int c = 0; c < C; c++) {\n            valueSum += input[batchOffset + neighbourOffset + c * cColorStride] * totalWeight;\n\n            // Derivative of weights with respect to X_i while i=k.\n            dw_dz_ki += (-1) * totalWeight * colorDistance / (cColorSigma * cColorSigma);\n            // Derivative of convolved image with respect to X_i while i=k.\n            dfilter_dz_ki += (-1) * totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] *\n                colorDistance /\n                (cColorSigma *\n                 cColorSigma); // Be careful, the +1 is missing here -> Added before filling dfilter_dx_kiData\n\n            colorSum_w += totalWeight * colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma);\n            colorSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] *\n                colorDistanceSquared / std::abs(cColorSigma * cColorSigma * cColorSigma);\n\n            xSum_w += totalWeight * cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x);\n            xSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] *\n                cXDistanceSquared[kernelX] / std::abs(cSigma_x * cSigma_x * cSigma_x);\n\n            ySum_w += totalWeight * cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y);\n            ySum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] *\n                cYDistanceSquared[kernelY] / std::abs(cSigma_y * cSigma_y * cSigma_y);\n\n            zSum_w += totalWeight * cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z);\n            zSum_alpha += totalWeight * input[batchOffset + neighbourOffset + c * cColorStride] *\n                cZDistanceSquared[kernelZ] / std::abs(cSigma_z * cSigma_z * cSigma_z);\n          
}\n\n          weightSum += totalWeight;\n        }\n      }\n    }\n  }\n\n#pragma unroll\n  for (int c = 0; c < C; c++) {\n    //    output[batchOffset + homeOffset + c * cColorStride] /= weightSum;\n    output[batchOffset + homeOffset + c * cColorStride] = valueSum / weightSum;\n\n    // Pre-computations for the backward pass:\n    outputWeightsTensor[batchOffset + homeOffset + c * cColorStride] = weightSum;\n    dO_dz_ki[batchOffset + homeOffset + c * cColorStride] = -(1 / weightSum) * (valueSum / weightSum) * dw_dz_ki +\n        (1 / weightSum) * (dfilter_dz_ki); // no +1 for dfilter_dz_ki for JBF added here!\n    dO_dsig_r[batchOffset + homeOffset + c * cColorStride] =\n        -(1 / weightSum) * (valueSum / weightSum) * colorSum_w + (1 / weightSum) * colorSum_alpha;\n    dO_dsig_x[batchOffset + homeOffset + c * cColorStride] =\n        -(1 / weightSum) * (valueSum / weightSum) * xSum_w + (1 / weightSum) * xSum_alpha;\n    dO_dsig_y[batchOffset + homeOffset + c * cColorStride] =\n        -(1 / weightSum) * (valueSum / weightSum) * ySum_w + (1 / weightSum) * ySum_alpha;\n    dO_dsig_z[batchOffset + homeOffset + c * cColorStride] =\n        -(1 / weightSum) * (valueSum / weightSum) * zSum_w + (1 / weightSum) * zSum_alpha;\n  }\n}\n\ntemplate <int C, int D>\nvoid JointBilateralFilterCudaForwardFunction(\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dz_ki,\n    torch::Tensor dO_dsig_r,\n    torch::Tensor dO_dsig_x,\n    torch::Tensor dO_dsig_y,\n    torch::Tensor dO_dsig_z,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  // Getting tensor description.\n  TensorDescription desc = TensorDescription(inputTensor);\n\n  // Pre-calculating gaussian kernel.\n  int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_y = std::max(((int)ceil(5.0f * 
sigma_y) | 1), 5); // ORing last bit to ensure odd window size\n  int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size\n  int halfWindowSize_x = floor(0.5f * windowSize_x);\n  int halfWindowSize_y = floor(0.5f * windowSize_y);\n  int halfWindowSize_z = floor(0.5f * windowSize_z);\n  int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z};\n  float spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x);\n  float spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y);\n  float spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z);\n  float colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);\n\n  int* kernelSizes = new int[desc.dimensions];\n  kernelSizes[0] = windowSize_x;\n  kernelSizes[1] = windowSize_y;\n  kernelSizes[2] = windowSize_z;\n\n  auto* gaussianKernel_x = new float[windowSize_x];\n  auto* gaussianKernel_y = new float[windowSize_y];\n  auto* gaussianKernel_z = new float[windowSize_z];\n  auto* xDistanceSquared = new float[windowSize_x];\n  auto* yDistanceSquared = new float[windowSize_y];\n  auto* zDistanceSquared = new float[windowSize_z];\n\n  for (int i = 0; i < windowSize_x; i++) {\n    int distance = i - halfWindowSize_x;\n    gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x);\n    xDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_y; i++) {\n    int distance = i - halfWindowSize_y;\n    gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y);\n    yDistanceSquared[i] = distance * distance;\n  }\n  for (int i = 0; i < windowSize_z; i++) {\n    int distance = i - halfWindowSize_z;\n    gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z);\n    zDistanceSquared[i] = distance * distance;\n  }\n\n  // Writing constant memory.\n  cudaMemcpyToSymbol(cBatchStride, &desc.batchStride, sizeof(int));\n  cudaMemcpyToSymbol(cColorStride, &desc.channelStride, sizeof(int));\n  
cudaMemcpyToSymbol(cSizes, desc.sizes, sizeof(int) * 3);\n  cudaMemcpyToSymbol(cStrides, desc.strides, sizeof(int) * 3);\n  cudaMemcpyToSymbol(cKernelSizes, kernelSizes, sizeof(int) * desc.dimensions);\n  cudaMemcpyToSymbol(cHalfWindowSize_arr, halfWindowSize_arr, sizeof(int) * desc.dimensions);\n  cudaMemcpyToSymbol(cGaussianKernel_x, gaussianKernel_x, sizeof(float) * windowSize_x);\n  cudaMemcpyToSymbol(cGaussianKernel_y, gaussianKernel_y, sizeof(float) * windowSize_y);\n  cudaMemcpyToSymbol(cGaussianKernel_z, gaussianKernel_z, sizeof(float) * windowSize_z);\n  cudaMemcpyToSymbol(cXDistanceSquared, xDistanceSquared, sizeof(float) * windowSize_x);\n  cudaMemcpyToSymbol(cYDistanceSquared, yDistanceSquared, sizeof(float) * windowSize_y);\n  cudaMemcpyToSymbol(cZDistanceSquared, zDistanceSquared, sizeof(float) * windowSize_z);\n  cudaMemcpyToSymbol(cColorExponentConstant, &colorExpConstant, sizeof(float));\n  cudaMemcpyToSymbol(cSigma_x, &sigma_x, sizeof(float));\n  cudaMemcpyToSymbol(cSigma_y, &sigma_y, sizeof(float));\n  cudaMemcpyToSymbol(cSigma_z, &sigma_z, sizeof(float));\n  cudaMemcpyToSymbol(cColorSigma, &colorSigma, sizeof(float));\n\n  //  cuda_error_check(\"Cuda check before kernel call.\");\n\n#define BLOCK_SIZE 32\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n      inputTensor.scalar_type(), \"JointBilateralFilterCudaKernel3DForward\", ([&] {\n        JointBilateralFilterCudaKernel3DForward<scalar_t, C>\n            <<<dim3(int(desc.channelStride / BLOCK_SIZE) + 1, desc.batchCount), dim3(BLOCK_SIZE, 1)>>>(\n                inputTensor.data_ptr<scalar_t>(),\n                guidanceTensor.data_ptr<scalar_t>(),\n                outputTensor.data_ptr<scalar_t>(),\n                outputWeightsTensor.data_ptr<scalar_t>(),\n                dO_dz_ki.data_ptr<scalar_t>(),\n                dO_dsig_r.data_ptr<scalar_t>(),\n                dO_dsig_x.data_ptr<scalar_t>(),\n                dO_dsig_y.data_ptr<scalar_t>(),\n                
dO_dsig_z.data_ptr<scalar_t>());\n      }));\n\n  //  cuda_error_check(\"Cuda check after kernel call.\");\n  //  delete[] kernel;\n  delete[] kernelSizes;\n  delete[] gaussianKernel_x;\n  delete[] gaussianKernel_y;\n  delete[] gaussianKernel_z;\n  delete[] xDistanceSquared;\n  delete[] yDistanceSquared;\n  delete[] zDistanceSquared;\n}\n\n// Function to choose template implementation based on dynamic, channels and dimensions\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nJointBilateralFilterCudaForward(\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  torch::Tensor outputTensor = torch::zeros_like(inputTensor);\n  torch::Tensor outputWeightsTensor = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dz_ki = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_r = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_x = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_y = torch::zeros_like(inputTensor);\n  torch::Tensor dO_dsig_z = torch::zeros_like(inputTensor);\n  //  cuda_error_check(\"beginning\");\n\n#define CASE(c, d)                               \\\n  JointBilateralFilterCudaForwardFunction<c, d>( \\\n      inputTensor,                               \\\n      guidanceTensor,                            \\\n      outputTensor,                              \\\n      outputWeightsTensor,                       \\\n      dO_dz_ki,                                  \\\n      dO_dsig_r,                                 \\\n      dO_dsig_x,                                 \\\n      dO_dsig_y,                                 \\\n      dO_dsig_z,                                 \\\n      sigma_x,                                   \\\n      sigma_y,                                   \\\n      sigma_z,                                   \\\n      colorSigma);\n  SWITCH_AB(CASE, 
BF_CUDA_MAX_CHANNELS, BF_CUDA_MAX_SPATIAL_DIMENSION, inputTensor.size(1), inputTensor.dim() - 2);\n\n  return {outputTensor, outputWeightsTensor, dO_dz_ki, dO_dsig_r, dO_dsig_x, dO_dsig_y, dO_dsig_z};\n}\n"
  },
  {
    "path": "monai/csrc/filtering/trainable_joint_bilateral/trainable_joint_bilateral.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n=========================================================================\nAdapted from https://github.com/faebstn96/trainable-joint-bilateral-filter-source\nwhich has the following license...\nhttps://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/LICENSE\n\nCopyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <torch/extension.h>\n#include <stdexcept>\n#include <string>\n\n#include \"trainable_joint_bilateral.h\"\n#include \"utils/common_utils.h\"\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nTrainableJointBilateralFilterForward(\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  std::tuple<torch::Tensor, 
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> (\n      *filterFunction)(torch::Tensor, torch::Tensor, float, float, float, float);\n\n#ifdef WITH_CUDA\n\n  if (torch::cuda::is_available() && inputTensor.is_cuda()) {\n    CHECK_CONTIGUOUS_CUDA(inputTensor);\n\n    if (inputTensor.size(1) > BF_CUDA_MAX_CHANNELS) {\n      throw std::runtime_error(\n          \"Bilateral filtering not implemented for channel count > \" + std::to_string(BF_CUDA_MAX_CHANNELS));\n    }\n\n    if (inputTensor.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) {\n      throw std::runtime_error(\n          \"Bilateral filtering not implemented for spatial dimension > \" +\n          std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION));\n    }\n\n    filterFunction = &JointBilateralFilterCudaForward;\n  } else {\n    filterFunction = &JointBilateralFilterCpuForward;\n  }\n#else\n  filterFunction = &JointBilateralFilterCpuForward;\n#endif\n\n  return filterFunction(inputTensor, guidanceTensor, sigma_x, sigma_y, sigma_z, colorSigma);\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> TrainableJointBilateralFilterBackward(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dx_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma) {\n  std::tuple<torch::Tensor, torch::Tensor> (*filterFunction)(\n      torch::Tensor,\n      torch::Tensor,\n      torch::Tensor,\n      torch::Tensor,\n      torch::Tensor,\n      torch::Tensor,\n      float,\n      float,\n      float,\n      float);\n\n#ifdef WITH_CUDA\n\n  if (torch::cuda::is_available() && gradientInputTensor.is_cuda()) {\n    CHECK_CONTIGUOUS_CUDA(gradientInputTensor);\n\n    if (gradientInputTensor.size(1) > BF_CUDA_MAX_CHANNELS) {\n      throw std::runtime_error(\n          \"Bilateral filtering not implemented for channel count > \" + 
std::to_string(BF_CUDA_MAX_CHANNELS));\n    }\n\n    if (gradientInputTensor.dim() - 2 > BF_CUDA_MAX_SPATIAL_DIMENSION) {\n      throw std::runtime_error(\n          \"Bilateral filtering not implemented for spatial dimension > \" +\n          std::to_string(BF_CUDA_MAX_SPATIAL_DIMENSION));\n    }\n\n    filterFunction = &JointBilateralFilterCudaBackward;\n  } else {\n    filterFunction = &JointBilateralFilterCpuBackward;\n  }\n#else\n  filterFunction = &JointBilateralFilterCpuBackward;\n#endif\n\n  return filterFunction(\n      gradientInputTensor,\n      inputTensor,\n      guidanceTensor,\n      outputTensor,\n      outputWeightsTensor,\n      dO_dx_ki,\n      sigma_x,\n      sigma_y,\n      sigma_z,\n      colorSigma);\n}\n"
  },
  {
    "path": "monai/csrc/filtering/trainable_joint_bilateral/trainable_joint_bilateral.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n=========================================================================\nAdapted from https://github.com/faebstn96/trainable-joint-bilateral-filter-source\nwhich has the following license...\nhttps://github.com/faebstn96/trainable-joint-bilateral-filter-source/blob/main/LICENSE\n\nCopyright 2022 Fabian Wagner, Pattern Recognition Lab, FAU Erlangen-Nuernberg, Erlangen, Germany\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#pragma once\n\n#include <torch/extension.h>\n#include <algorithm>\n#include <iostream>\n#include <vector>\n#include \"utils/common_utils.h\"\n//#include \"utils/tensor_description.h\"\n\n#define BF_CUDA_MAX_CHANNELS 16\n#define BF_CUDA_MAX_SPATIAL_DIMENSION 3\n\n#ifdef WITH_CUDA\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nJointBilateralFilterCudaForward(\n    torch::Tensor inputTensor,\n    torch::Tensor 
guidanceTensor,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma);\nstd::tuple<torch::Tensor, torch::Tensor> JointBilateralFilterCudaBackward(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dz_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma);\n#endif\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nJointBilateralFilterCpuForward(\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma);\n\nstd::tuple<torch::Tensor, torch::Tensor> JointBilateralFilterCpuBackward(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dz_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma);\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>\nTrainableJointBilateralFilterForward(\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma);\n\nstd::tuple<torch::Tensor, torch::Tensor> TrainableJointBilateralFilterBackward(\n    torch::Tensor gradientInputTensor,\n    torch::Tensor inputTensor,\n    torch::Tensor guidanceTensor,\n    torch::Tensor outputTensor,\n    torch::Tensor outputWeightsTensor,\n    torch::Tensor dO_dx_ki,\n    float sigma_x,\n    float sigma_y,\n    float sigma_z,\n    float colorSigma);\n"
  },
  {
    "path": "monai/csrc/lltm/lltm.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#pragma once\n\n#include <torch/extension.h>\n#include <vector>\n#include \"utils/common_utils.h\"\n\n#ifdef WITH_CUDA\nstd::vector<torch::Tensor> lltm_cuda_forward(\n    torch::Tensor input,\n    torch::Tensor weights,\n    torch::Tensor bias,\n    torch::Tensor old_h,\n    torch::Tensor old_cell);\n\nstd::vector<torch::Tensor> lltm_cuda_backward(\n    torch::Tensor grad_h,\n    torch::Tensor grad_cell,\n    torch::Tensor new_cell,\n    torch::Tensor input_gate,\n    torch::Tensor output_gate,\n    torch::Tensor candidate_cell,\n    torch::Tensor X,\n    torch::Tensor gate_weights,\n    torch::Tensor weights);\n#endif\n\nstd::vector<torch::Tensor> lltm_cpu_forward(\n    torch::Tensor input,\n    torch::Tensor weights,\n    torch::Tensor bias,\n    torch::Tensor old_h,\n    torch::Tensor old_cell);\n\nstd::vector<torch::Tensor> lltm_cpu_backward(\n    torch::Tensor grad_h,\n    torch::Tensor grad_cell,\n    torch::Tensor new_cell,\n    torch::Tensor input_gate,\n    torch::Tensor output_gate,\n    torch::Tensor candidate_cell,\n    torch::Tensor X,\n    torch::Tensor gate_weights,\n    torch::Tensor weights);\n\nstd::vector<torch::Tensor> lltm_forward(\n    torch::Tensor input,\n    torch::Tensor weights,\n    torch::Tensor bias,\n    torch::Tensor old_h,\n    torch::Tensor old_cell) {\n  if (input.is_cuda()) {\n#ifdef WITH_CUDA\n    CHECK_CONTIGUOUS_CUDA(input);\n    
CHECK_CONTIGUOUS_CUDA(weights);\n    CHECK_CONTIGUOUS_CUDA(bias);\n    CHECK_CONTIGUOUS_CUDA(old_h);\n    CHECK_CONTIGUOUS_CUDA(old_cell);\n\n    return lltm_cuda_forward(input, weights, bias, old_h, old_cell);\n#else\n    AT_ERROR(\"Not compiled with GPU support.\");\n#endif\n  }\n  return lltm_cpu_forward(input, weights, bias, old_h, old_cell);\n}\n\nstd::vector<torch::Tensor> lltm_backward(\n    torch::Tensor grad_h,\n    torch::Tensor grad_cell,\n    torch::Tensor new_cell,\n    torch::Tensor input_gate,\n    torch::Tensor output_gate,\n    torch::Tensor candidate_cell,\n    torch::Tensor X,\n    torch::Tensor gate_weights,\n    torch::Tensor weights) {\n  if (X.is_cuda()) {\n#ifdef WITH_CUDA\n    CHECK_CONTIGUOUS_CUDA(grad_h);\n    CHECK_CONTIGUOUS_CUDA(grad_cell);\n    CHECK_CONTIGUOUS_CUDA(new_cell);\n    CHECK_CONTIGUOUS_CUDA(input_gate);\n    CHECK_CONTIGUOUS_CUDA(output_gate);\n    CHECK_CONTIGUOUS_CUDA(candidate_cell);\n    CHECK_CONTIGUOUS_CUDA(X);\n    CHECK_CONTIGUOUS_CUDA(gate_weights);\n    CHECK_CONTIGUOUS_CUDA(weights);\n\n    return lltm_cuda_backward(\n        grad_h, grad_cell, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights, weights);\n#else\n    AT_ERROR(\"Not compiled with GPU support.\");\n#endif\n  }\n  return lltm_cpu_backward(\n      grad_h, grad_cell, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights, weights);\n}\n"
  },
  {
    "path": "monai/csrc/lltm/lltm_cpu.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <torch/extension.h>\n#include <vector>\n\n// s'(z) = (1 - s(z)) * s(z)\ntorch::Tensor d_sigmoid(torch::Tensor z) {\n  auto s = torch::sigmoid(z);\n  return (1 - s) * s;\n}\n\n// tanh'(z) = 1 - tanh^2(z)\ntorch::Tensor d_tanh(torch::Tensor z) {\n  return 1 - z.tanh().pow(2);\n}\n\n// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}\ntorch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {\n  auto e = z.exp();\n  auto mask = (alpha * (e - 1)) < 0;\n  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);\n}\n\nstd::vector<torch::Tensor> lltm_cpu_forward(\n    torch::Tensor input,\n    torch::Tensor weights,\n    torch::Tensor bias,\n    torch::Tensor old_h,\n    torch::Tensor old_cell) {\n  auto X = torch::cat({old_h, input}, /*dim=*/1);\n\n  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));\n  auto gates = gate_weights.chunk(3, /*dim=*/1);\n\n  auto input_gate = torch::sigmoid(gates[0]);\n  auto output_gate = torch::sigmoid(gates[1]);\n  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);\n\n  auto new_cell = old_cell + candidate_cell * input_gate;\n  auto new_h = torch::tanh(new_cell) * output_gate;\n\n  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights};\n}\n\nstd::vector<torch::Tensor> lltm_cpu_backward(\n    torch::Tensor grad_h,\n    torch::Tensor grad_cell,\n    
torch::Tensor new_cell,\n    torch::Tensor input_gate,\n    torch::Tensor output_gate,\n    torch::Tensor candidate_cell,\n    torch::Tensor X,\n    torch::Tensor gate_weights,\n    torch::Tensor weights) {\n  auto d_output_gate = torch::tanh(new_cell) * grad_h;\n  auto d_tanh_new_cell = output_gate * grad_h;\n  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;\n\n  auto d_old_cell = d_new_cell;\n  auto d_candidate_cell = input_gate * d_new_cell;\n  auto d_input_gate = candidate_cell * d_new_cell;\n\n  auto gates = gate_weights.chunk(3, /*dim=*/1);\n  d_input_gate *= d_sigmoid(gates[0]);\n  d_output_gate *= d_sigmoid(gates[1]);\n  d_candidate_cell *= d_elu(gates[2]);\n\n  auto d_gates = torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);\n\n  auto d_weights = d_gates.t().mm(X);\n  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);\n\n  auto d_X = d_gates.mm(weights);\n  const auto state_size = grad_h.size(1);\n  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);\n  auto d_input = d_X.slice(/*dim=*/1, state_size);\n\n  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};\n}\n"
  },
  {
    "path": "monai/csrc/lltm/lltm_cuda.cu",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\n#include <vector>\n\nnamespace {\ntemplate <typename scalar_t>\n__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {\n  return 1.0 / (1.0 + exp(-z));\n}\n\ntemplate <typename scalar_t>\n__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {\n  const auto s = sigmoid(z);\n  return (1.0 - s) * s;\n}\n\ntemplate <typename scalar_t>\n__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {\n  const auto t = tanh(z);\n  return 1 - (t * t);\n}\n\ntemplate <typename scalar_t>\n__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {\n  return fmaxf(0.0, z) + fminf(0.0, alpha * (exp(z) - 1.0));\n}\n\ntemplate <typename scalar_t>\n__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {\n  const auto e = exp(z);\n  const auto d_relu = z < 0.0 ? 0.0 : 1.0;\n  return d_relu + (((alpha * (e - 1.0)) < 0.0) ? 
(alpha * e) : 0.0);\n}\n\ntemplate <typename scalar_t>\n__global__ void lltm_cuda_forward_kernel(\n    const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> gates,\n    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> old_cell,\n    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> new_h,\n    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> new_cell,\n    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> input_gate,\n    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> output_gate,\n    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> candidate_cell) {\n  // batch index\n  const int n = blockIdx.y;\n  // column index\n  const int c = blockIdx.x * blockDim.x + threadIdx.x;\n  if (c < gates.size(2)) {\n    input_gate[n][c] = sigmoid(gates[n][0][c]);\n    output_gate[n][c] = sigmoid(gates[n][1][c]);\n    candidate_cell[n][c] = elu(gates[n][2][c]);\n    new_cell[n][c] = old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];\n    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void lltm_cuda_backward_kernel(\n    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> d_old_cell,\n    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> d_gates,\n    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> grad_h,\n    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> grad_cell,\n    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> new_cell,\n    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> input_gate,\n    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> output_gate,\n    const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> candidate_cell,\n    const torch::PackedTensorAccessor32<scalar_t, 
3, torch::RestrictPtrTraits> gate_weights) {\n  // batch index\n  const int n = blockIdx.y;\n  // column index\n  const int c = blockIdx.x * blockDim.x + threadIdx.x;\n  if (c < d_gates.size(2)) {\n    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];\n    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];\n    const auto d_new_cell = d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];\n\n    d_old_cell[n][c] = d_new_cell;\n    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;\n    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;\n\n    d_gates[n][0][c] = d_input_gate * d_sigmoid(gate_weights[n][0][c]);\n    d_gates[n][1][c] = d_output_gate * d_sigmoid(gate_weights[n][1][c]);\n    d_gates[n][2][c] = d_candidate_cell * d_elu(gate_weights[n][2][c]);\n  }\n}\n} // namespace\n\nstd::vector<torch::Tensor> lltm_cuda_forward(\n    torch::Tensor input,\n    torch::Tensor weights,\n    torch::Tensor bias,\n    torch::Tensor old_h,\n    torch::Tensor old_cell) {\n  auto X = torch::cat({old_h, input}, /*dim=*/1);\n  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));\n\n  const auto batch_size = old_cell.size(0);\n  const auto state_size = old_cell.size(1);\n\n  auto gates = gate_weights.reshape({batch_size, 3, state_size});\n  auto new_h = torch::zeros_like(old_cell);\n  auto new_cell = torch::zeros_like(old_cell);\n  auto input_gate = torch::zeros_like(old_cell);\n  auto output_gate = torch::zeros_like(old_cell);\n  auto candidate_cell = torch::zeros_like(old_cell);\n\n  const int threads = 1024;\n  const dim3 blocks((state_size + threads - 1) / threads, batch_size);\n\n  AT_DISPATCH_FLOATING_TYPES(gates.scalar_type(), \"lltm_forward_cuda\", ([&] {\n                               lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(\n                                   gates.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),\n                                   
old_cell.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                                   new_h.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                                   new_cell.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                                   input_gate.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                                   output_gate.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                                   candidate_cell.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>());\n                             }));\n\n  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};\n}\n\nstd::vector<torch::Tensor> lltm_cuda_backward(\n    torch::Tensor grad_h,\n    torch::Tensor grad_cell,\n    torch::Tensor new_cell,\n    torch::Tensor input_gate,\n    torch::Tensor output_gate,\n    torch::Tensor candidate_cell,\n    torch::Tensor X,\n    torch::Tensor gates,\n    torch::Tensor weights) {\n  auto d_old_cell = torch::zeros_like(new_cell);\n  auto d_gates = torch::zeros_like(gates);\n\n  const auto batch_size = new_cell.size(0);\n  const auto state_size = new_cell.size(1);\n\n  const int threads = 1024;\n  const dim3 blocks((state_size + threads - 1) / threads, batch_size);\n\n  AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), \"lltm_forward_cuda\", ([&] {\n                               lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(\n                                   d_old_cell.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                                   d_gates.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),\n                                   grad_h.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                                   grad_cell.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                                   new_cell.packed_accessor32<scalar_t, 2, 
torch::RestrictPtrTraits>(),\n                                   input_gate.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                                   output_gate.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                                   candidate_cell.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                                   gates.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>());\n                             }));\n\n  auto d_gate_weights = d_gates.flatten(1, 2);\n  auto d_weights = d_gate_weights.t().mm(X);\n  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);\n\n  auto d_X = d_gate_weights.mm(weights);\n  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);\n  auto d_input = d_X.slice(/*dim=*/1, state_size);\n\n  return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};\n}\n"
  },
  {
    "path": "monai/csrc/resample/bounds_common.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// adapted from https://github.com/balbasty/nitorch\n\n#pragma once\n\n// This file contains static functions for handling out-of-bound indices.\n// They implement typical boundary conditions (those of standard discrete\n// transforms) + a few other cases (replicated border, zeros, sliding)\n// It also defines an enumerated types that encodes each boundary type.\n// The entry points are:\n// . monai::bound::index -> wrap out-of-bound indices\n// . monai::bound::sign  -> optional out-of-bound sign change (sine transforms)\n// . 
monai::BoundType    -> enumerated boundary type\n//\n// Everything in this file should have internal linkage (static) except\n// the BoundType/BoundVectorRef types.\n\n#include \"utils/resample_utils.h\"\n\nnamespace monai {\n\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n//                             INDEXING\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nnamespace _index {\n\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE size_t inbounds(size_t coord, size_t size) {\n  return coord;\n}\n\n// Boundary condition of a DCT-I (periodicity: (n-1)*2)\n// Indices are reflected about the centre of the border elements:\n//    -1 --> 1\n//     n --> n-2\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE size_t reflect1c(size_t coord, size_t size) {\n  if (size == 1)\n    return 0;\n  size_t size_twice = (size - 1) * 2;\n  coord = coord < 0 ? -coord : coord;\n  coord = coord % size_twice;\n  coord = coord >= size ? size_twice - coord : coord;\n  return coord;\n}\n\n// Boundary condition of a DST-I (periodicity: (n+1)*2)\n// Indices are reflected about the centre of the first out-of-bound\n// element:\n//    -1 --> undefined [0]\n//    -2 --> 0\n//     n --> undefined [n-1]\n//   n+1 --> n-1\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE size_t reflect1s(size_t coord, size_t size) {\n  if (size == 1)\n    return 0;\n  size_t size_twice = (size + 1) * 2;\n  coord = coord == -1 ? 0 : coord < 0 ? -coord - 2 : coord;\n  coord = coord % size_twice;\n  coord = coord == size ? size - 1 : coord > size ? size_twice - coord - 2 : coord;\n  return coord;\n}\n\n// Boundary condition of a DCT/DST-II (periodicity: n*2)\n// Indices are reflected about the edge of the border elements:\n//    -1 --> 0\n//     n --> n-1\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE size_t reflect2(size_t coord, size_t size) {\n  size_t size_twice = size * 2;\n  coord = coord < 0 ? 
size_twice - ((-coord - 1) % size_twice) - 1 : coord % size_twice;\n  coord = coord >= size ? size_twice - coord - 1 : coord;\n  return coord;\n}\n\n// Boundary condition of a DFT (periodicity: n)\n// Indices wrap about the edges:\n//    -1 --> n-1\n//     n --> 0\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE size_t circular(size_t coord, size_t size) {\n  coord = coord < 0 ? (size + coord % size) % size : coord % size;\n  return coord;\n}\n\n// Replicate edge values:\n//    -1 --> 0\n//    -2 --> 0\n//     n --> n-1\n//   n+1 --> n-1\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE size_t replicate(size_t coord, size_t size) {\n  coord = coord <= 0 ? 0 : coord >= size ? size - 1 : coord;\n  return coord;\n}\n\n} // namespace _index\n\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n//                          SIGN MODIFICATION\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nnamespace _sign {\n\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE int8_t inbounds(size_t coord, size_t size) {\n  return coord < 0 || coord >= size ? 0 : 1;\n}\n\n// Boundary condition of a DCT/DFT\n// No sign modification based on coordinates\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE int8_t constant(size_t coord, size_t size) {\n  return static_cast<int8_t>(1);\n}\n\n// Boundary condition of a DST-I\n// Periodic sign change based on coordinates\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE int8_t periodic1(size_t coord, size_t size) {\n  if (size == 1)\n    return 1;\n  size_t size_twice = (size + 1) * 2;\n  coord = coord < 0 ? 
size - coord - 1 : coord;\n  coord = coord % size_twice;\n  if (coord % (size + 1) == size)\n    return static_cast<int8_t>(0);\n  else if ((coord / (size + 1)) % 2)\n    return static_cast<int8_t>(-1);\n  else\n    return static_cast<int8_t>(1);\n}\n\n// Boundary condition of a DST-II\n// Periodic sign change based on coordinates\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE int8_t periodic2(size_t coord, size_t size) {\n  coord = (coord < 0 ? size - coord - 1 : coord);\n  return static_cast<int8_t>((coord / size) % 2 ? -1 : 1);\n}\n\n} // namespace _sign\n\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n//                                BOUND\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n// Check if coordinates within bounds\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE bool inbounds(size_t coord, size_t size) {\n  return coord >= 0 && coord < size;\n}\n\ntemplate <typename scalar_t, typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE bool inbounds(scalar_t coord, size_t size, scalar_t tol) {\n  return coord >= -tol && coord < (scalar_t)(size - 1) + tol;\n}\n\nnamespace bound {\n\ntemplate <typename scalar_t, typename offset_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t\nget(const scalar_t* ptr, offset_t offset, int8_t sign = static_cast<int8_t>(1)) {\n  if (sign == -1)\n    return -ptr[offset];\n  else if (sign)\n    return ptr[offset];\n  else\n    return static_cast<scalar_t>(0);\n}\n\ntemplate <typename scalar_t, typename offset_t>\nstatic MONAI_INLINE MONAI_DEVICE void add(\n    scalar_t* ptr,\n    offset_t offset,\n    scalar_t val,\n    int8_t sign = static_cast<int8_t>(1)) {\n  if (sign == -1)\n    MONAI_ATOMIC_ADD(ptr, offset, -val);\n  else if (sign)\n    MONAI_ATOMIC_ADD(ptr, offset, val);\n}\n\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE int64_t index(BoundType bound_type, size_t coord, size_t size) {\n  switch (bound_type) {\n    case 
BoundType::Replicate:\n      return _index::replicate(coord, size);\n    case BoundType::DCT1:\n      return _index::reflect1c(coord, size);\n    case BoundType::DCT2:\n      return _index::reflect2(coord, size);\n    case BoundType::DST1:\n      return _index::reflect1s(coord, size);\n    case BoundType::DST2:\n      return _index::reflect2(coord, size);\n    case BoundType::DFT:\n      return _index::circular(coord, size);\n    case BoundType::Zero:\n      return _index::inbounds(coord, size);\n    default:\n      return _index::inbounds(coord, size);\n  }\n}\n\ntemplate <typename size_t>\nstatic MONAI_INLINE MONAI_DEVICE int8_t sign(BoundType bound_type, size_t coord, size_t size) {\n  switch (bound_type) {\n    case BoundType::Replicate:\n      return _sign::constant(coord, size);\n    case BoundType::DCT1:\n      return _sign::constant(coord, size);\n    case BoundType::DCT2:\n      return _sign::constant(coord, size);\n    case BoundType::DST1:\n      return _sign::periodic1(coord, size);\n    case BoundType::DST2:\n      return _sign::periodic2(coord, size);\n    case BoundType::DFT:\n      return _sign::constant(coord, size);\n    case BoundType::Zero:\n      return _sign::inbounds(coord, size);\n    default:\n      return _sign::inbounds(coord, size);\n  }\n}\n\n} // namespace bound\n} // namespace monai\n"
  },
  {
    "path": "monai/csrc/resample/interpolation_common.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// adapted from https://github.com/balbasty/nitorch\n\n#pragma once\n\n// This file contains static functions for handling (0-7 order)\n// interpolation weights.\n// It also defines an enumerated types that encodes each boundary type.\n// The entry points are:\n// . monai::interpolation::weight     -> node weight based on distance\n// . monai::interpolation::fastweight -> same, assuming x lies in support\n// . monai::interpolation::grad       -> weight derivative // oriented distance\n// . monai::interpolation::fastgrad   -> same, assuming x lies in support\n// . monai::interpolation::hess       -> weight 2nd derivative // oriented distance\n// . monai::interpolation::fasthess   -> same, assuming x lies in support\n// . monai::interpolation::bounds     -> min/max nodes\n\n// NOTE:\n// 1st derivatives used to be implemented with a recursive call, e.g.:\n// scalar_t grad2(scalar_t x) {\n//   if (x < 0) return -grad2(-x);\n//   ...\n// }\n// However, this prevents nvcc to statically determine the stack size\n// and leads to memory errors (because the allocated stack is too small).\n// I now use a slightly less compact implementation that gets rid of\n// recursive calls.\n\n// TODO:\n// . second order derivatives [5/6/7]\n// ? 
other types of basis functions (gauss, sinc)\n\n#include \"utils/resample_utils.h\"\n\nnamespace monai {\n\nnamespace _interpolation {\n\n// --- order 0 -------------------------------------------------------\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t weight0(scalar_t x) {\n  x = std::fabs(x);\n  return x < 0.5 ? static_cast<scalar_t>(1) : static_cast<scalar_t>(0);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastweight0(scalar_t x) {\n  x = std::fabs(x);\n  return static_cast<scalar_t>(1);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t grad0(scalar_t x) {\n  return static_cast<scalar_t>(0);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastgrad0(scalar_t x) {\n  return static_cast<scalar_t>(0);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t hess0(scalar_t x) {\n  return static_cast<scalar_t>(0);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fasthess0(scalar_t x) {\n  return static_cast<scalar_t>(0);\n}\n\ntemplate <typename scalar_t, typename offset_t>\nstatic MONAI_INLINE MONAI_DEVICE void bounds0(scalar_t x, offset_t& low, offset_t& upp) {\n  low = static_cast<offset_t>(std::round(x));\n  upp = low;\n}\n\n// --- order 1 -------------------------------------------------------\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t weight1(scalar_t x) {\n  x = std::fabs(x);\n  return x < 1 ? 
static_cast<scalar_t>(1) - x : static_cast<scalar_t>(0);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastweight1(scalar_t x) {\n  return static_cast<scalar_t>(1) - std::fabs(x);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t grad1(scalar_t x) {\n  if (std::fabs(x) >= 1)\n    return static_cast<scalar_t>(0);\n  return fastgrad1(x);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastgrad1(scalar_t x) {\n  return x < static_cast<scalar_t>(0) ? static_cast<scalar_t>(1) : static_cast<scalar_t>(-1);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t hess1(scalar_t x) {\n  return static_cast<scalar_t>(0);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fasthess1(scalar_t x) {\n  return static_cast<scalar_t>(0);\n}\n\ntemplate <typename scalar_t, typename offset_t>\nstatic MONAI_INLINE MONAI_DEVICE void bounds1(scalar_t x, offset_t& low, offset_t& upp) {\n  low = static_cast<offset_t>(std::floor(x));\n  upp = low + 1;\n}\n\n// --- order 2 -------------------------------------------------------\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t weight2(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 0.5) {\n    return 0.75 - x * x;\n  } else if (x < 1.5) {\n    x = 1.5 - x;\n    return 0.5 * x * x;\n  } else {\n    return static_cast<scalar_t>(0);\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastweight2(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 0.5) {\n    return 0.75 - x * x;\n  } else {\n    x = 1.5 - x;\n    return 0.5 * x * x;\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t grad2(scalar_t x) {\n  bool neg = x < 0;\n  if (x < 0.5) {\n    x = -2. 
* x;\n  } else if (x < 1.5) {\n    x = x - 1.5;\n  } else {\n    return static_cast<scalar_t>(0);\n  }\n  if (neg)\n    x = -x;\n  return x;\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastgrad2(scalar_t x) {\n  bool neg = x < 0;\n  if (neg)\n    x = -x;\n  if (x < 0.5) {\n    x = -2. * x;\n  } else {\n    x = x - 1.5;\n  }\n  if (neg)\n    x = -x;\n  return x;\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t hess2(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 0.5) {\n    return static_cast<scalar_t>(-2.);\n  } else if (x < 1.5) {\n    return static_cast<scalar_t>(1.);\n  } else {\n    return static_cast<scalar_t>(0);\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fasthess2(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 0.5) {\n    return static_cast<scalar_t>(-2.);\n  } else {\n    return static_cast<scalar_t>(1.);\n  }\n}\n\ntemplate <typename scalar_t, typename offset_t>\nstatic MONAI_INLINE MONAI_DEVICE void bounds2(scalar_t x, offset_t& low, offset_t& upp) {\n  low = static_cast<offset_t>(std::floor(x - .5));\n  upp = low + 2;\n}\n\n// --- order 3 -------------------------------------------------------\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t weight3(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 1.) {\n    return (x * x * (x - 2.) * 3. + 4.) / 6.;\n  } else if (x < 2.) {\n    x = 2. - x;\n    return (x * x * x) / 6.;\n  } else {\n    return static_cast<scalar_t>(0);\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastweight3(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 1.) {\n    return (x * x * (x - 2.) * 3. + 4.) / 6.;\n  } else {\n    x = 2. - x;\n    return (x * x * x) / 6.;\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t grad3(scalar_t x) {\n  bool neg = x < 0;\n  if (neg)\n    x = -x;\n  if (x < 1.) {\n    x = x * (x * 1.5 - 2.);\n  } else if (x < 2.) 
{\n    x = 2. - x;\n    x = -(x * x) * 0.5;\n  } else {\n    return static_cast<scalar_t>(0);\n  }\n  if (neg)\n    x = -x;\n  return x;\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastgrad3(scalar_t x) {\n  bool neg = x < 0;\n  if (neg)\n    x = -x;\n  if (x < 1.) {\n    x = x * (x * 1.5 - 2.);\n  } else {\n    x = 2. - x;\n    x = -(x * x) * 0.5;\n  }\n  if (neg)\n    x = -x;\n  return x;\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t hess3(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 1.) {\n    return x * 3. - 2.;\n  } else if (x < 2.) {\n    return 2. - x;\n  } else {\n    return static_cast<scalar_t>(0);\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fasthess3(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 1.) {\n    return x * 3. - 2.;\n  } else {\n    return 2. - x;\n  }\n}\n\ntemplate <typename scalar_t, typename offset_t>\nstatic MONAI_INLINE MONAI_DEVICE void bounds3(scalar_t x, offset_t& low, offset_t& upp) {\n  low = static_cast<offset_t>(std::floor(x - 1.));\n  upp = low + 3;\n}\n\n// --- order 4 -------------------------------------------------------\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t weight4(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 0.5) {\n    x *= x;\n    return x * (x * 0.25 - 0.625) + 115. / 192.;\n  } else if (x < 1.5) {\n    return x * (x * (x * (5. - x) / 6. - 1.25) + 5. / 24.) + 55. / 96.;\n  } else if (x < 2.5) {\n    x -= 2.5;\n    x *= x;\n    return (x * x) / 24.;\n  } else {\n    return static_cast<scalar_t>(0);\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastweight4(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 0.5) {\n    x *= x;\n    return x * (x * 0.25 - 0.625) + 115. / 192.;\n  } else if (x < 1.5) {\n    return x * (x * (x * (5. - x) / 6. - 1.25) + 5. / 24.) + 55. 
/ 96.;\n  } else {\n    x -= 2.5;\n    x *= x;\n    return (x * x) / 24.;\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t grad4(scalar_t x) {\n  bool neg = x < 0;\n  if (neg)\n    x = -x;\n  if (x < 0.5) {\n    x = x * (x * x - 1.25);\n  } else if (x < 1.5) {\n    x = x * (x * (x * (-2. / 3.) + 2.5) - 2.5) + 5. / 24.;\n  } else if (x < 2.5) {\n    x = x * 2. - 5.;\n    x = (x * x * x) / 48.;\n  } else {\n    return static_cast<scalar_t>(0);\n  }\n  if (neg)\n    x = -x;\n  return x;\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastgrad4(scalar_t x) {\n  bool neg = x < 0;\n  if (neg)\n    x = -x;\n  if (x < 0.5) {\n    x = x * (x * x - 1.25);\n  } else if (x < 1.5) {\n    x = x * (x * (x * (-2. / 3.) + 2.5) - 2.5) + 5. / 24.;\n  } else {\n    x = x * 2. - 5.;\n    x = (x * x * x) / 48.;\n  }\n  if (neg)\n    x = -x;\n  return x;\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t hess4(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 0.5) {\n    return (x * x) * 3. - 1.25;\n  } else if (x < 1.5) {\n    return x * (x * (-2.) + 5.) - 2.5;\n  } else if (x < 2.5) {\n    x = x * 2. - 5.;\n    return (x * x) / 8.;\n  } else {\n    return static_cast<scalar_t>(0);\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fasthess4(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 0.5) {\n    return (x * x) * 3. - 1.25;\n  } else if (x < 1.5) {\n    return x * (x * (-2.) + 5.) - 2.5;\n  } else {\n    x = x * 2. 
- 5.;\n    return (x * x) / 8.;\n  }\n}\n\ntemplate <typename scalar_t, typename offset_t>\nstatic MONAI_INLINE MONAI_DEVICE void bounds4(scalar_t x, offset_t& low, offset_t& upp) {\n  low = static_cast<offset_t>(std::floor(x - 1.5));\n  upp = low + 4;\n}\n\n// --- order 5 -------------------------------------------------------\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t weight5(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 1.) {\n    scalar_t f = x * x;\n    return f * (f * (0.25 - x * (1. / 12.)) - 0.5) + 0.55;\n  } else if (x < 2.) {\n    return x * (x * (x * (x * (x * (1. / 24.) - 0.375) + 1.25) - 1.75) + 0.625) + 0.425;\n  } else if (x < 3.) {\n    scalar_t f = 3. - x;\n    x = f * f;\n    return f * x * x * (1. / 120.);\n  } else\n    return static_cast<scalar_t>(0);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastweight5(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 1.) {\n    scalar_t f = x * x;\n    return f * (f * (0.25 - x * (1. / 12.)) - 0.5) + 0.55;\n  } else if (x < 2.) {\n    return x * (x * (x * (x * (x * (1. / 24.) - 0.375) + 1.25) - 1.75) + 0.625) + 0.425;\n  } else {\n    scalar_t f = 3. - x;\n    x = f * f;\n    return f * x * x * (1. / 120.);\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t grad5(scalar_t x) {\n  bool neg = x < 0;\n  if (neg)\n    x = -x;\n  if (x < 1.) {\n    x = x * (x * (x * (x * (-5. / 12.) + 1.)) - 1.);\n  } else if (x < 2.) {\n    x = x * (x * (x * (x * (5. / 24.) - 1.5) + 3.75) - 3.5) + 0.625;\n  } else if (x < 3.) {\n    x -= 3.;\n    x *= x;\n    x = -(x * x) / 24.;\n  } else {\n    return static_cast<scalar_t>(0);\n  }\n  if (neg)\n    x = -x;\n  return x;\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastgrad5(scalar_t x) {\n  bool neg = x < 0;\n  if (neg)\n    x = -x;\n  if (x < 1.) {\n    x = x * (x * (x * (x * (-5. / 12.) + 1.)) - 1.);\n  } else if (x < 2.) 
{\n    x = x * (x * (x * (x * (5. / 24.) - 1.5) + 3.75) - 3.5) + 0.625;\n  } else {\n    x -= 3.;\n    x *= x;\n    x = -(x * x) / 24.;\n  }\n  if (neg)\n    x = -x;\n  return x;\n}\n\ntemplate <typename scalar_t, typename offset_t>\nstatic MONAI_INLINE MONAI_DEVICE void bounds5(scalar_t x, offset_t& low, offset_t& upp) {\n  low = static_cast<offset_t>(std::floor(x - 2.));\n  upp = low + 5;\n}\n\n// --- order 6 -------------------------------------------------------\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t weight6(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 0.5) {\n    x *= x;\n    return x * (x * (7. / 48. - x * (1. / 36.)) - 77. / 192.) + 5887. / 11520.0;\n  } else if (x < 1.5) {\n    return x * (x * (x * (x * (x * (x * (1. / 48.) - 7. / 48.) + 0.328125) - 35. / 288.) - 91. / 256.) - 7. / 768.) +\n        7861. / 15360.0;\n  } else if (x < 2.5) {\n    return x * (x * (x * (x * (x * (7. / 60. - x * (1. / 120.)) - 0.65625) + 133. / 72.) - 2.5703125) + 1267. / 960.) +\n        1379. / 7680.0;\n  } else if (x < 3.5) {\n    x -= 3.5;\n    x *= x * x;\n    return x * x * (1. / 720.);\n  } else\n    return static_cast<scalar_t>(0);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastweight6(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 0.5) {\n    x *= x;\n    return x * (x * (7. / 48. - x * (1. / 36.)) - 77. / 192.) + 5887. / 11520.0;\n  } else if (x < 1.5) {\n    return x * (x * (x * (x * (x * (x * (1. / 48.) - 7. / 48.) + 0.328125) - 35. / 288.) - 91. / 256.) - 7. / 768.) +\n        7861. / 15360.0;\n  } else if (x < 2.5) {\n    return x * (x * (x * (x * (x * (7. / 60. - x * (1. / 120.)) - 0.65625) + 133. / 72.) - 2.5703125) + 1267. / 960.) +\n        1379. / 7680.0;\n  } else {\n    x -= 3.5;\n    x *= x * x;\n    return x * x * (1. 
/ 720.);\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t grad6(scalar_t x) {\n  bool neg = x < 0;\n  if (neg)\n    x = -x;\n  if (x < .5) {\n    scalar_t x2 = x * x;\n    x = x * (x2 * (7. / 12.) - (x2 * x2) / 6. - 77. / 96.);\n  } else if (x < 1.5) {\n    x = x * (x * (x * (x * (x * 0.125 - 35. / 48.) + 1.3125) - 35. / 96.) - 0.7109375) - 7.0 / 768.0;\n  } else if (x < 2.5) {\n    x = x * (x * (x * (x * (x * (-1. / 20.) + 7. / 12.) - 2.625) + 133. / 24.) - 5.140625) + 1267. / 960.;\n  } else if (x < 3.5) {\n    x *= 2.;\n    x -= 7.;\n    scalar_t x2 = x * x;\n    x = (x2 * x2 * x) / 3840.;\n  } else {\n    return static_cast<scalar_t>(0);\n  }\n  if (neg)\n    x = -x;\n  return x;\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastgrad6(scalar_t x) {\n  bool neg = x < 0;\n  if (neg)\n    x = -x;\n  if (x < .5) {\n    scalar_t x2 = x * x;\n    x = x * (x2 * (7. / 12.) - (x2 * x2) / 6. - 77. / 96.);\n  } else if (x < 1.5) {\n    x = x * (x * (x * (x * (x * 0.125 - 35. / 48.) + 1.3125) - 35. / 96.) - 0.7109375) - 7.0 / 768.0;\n  } else if (x < 2.5) {\n    x = x * (x * (x * (x * (x * (-1. / 20.) + 7. / 12.) - 2.625) + 133. / 24.) - 5.140625) + 1267. / 960.;\n  } else {\n    x *= 2.;\n    x -= 7.;\n    scalar_t x2 = x * x;\n    x = (x2 * x2 * x) / 3840.;\n  }\n  if (neg)\n    x = -x;\n  return x;\n}\n\ntemplate <typename scalar_t, typename offset_t>\nstatic MONAI_INLINE MONAI_DEVICE void bounds6(scalar_t x, offset_t& low, offset_t& upp) {\n  low = static_cast<offset_t>(std::floor(x - 2.5));\n  upp = low + 6;\n}\n\n// --- order 7 -------------------------------------------------------\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t weight7(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 1.) {\n    scalar_t f = x * x;\n    return f * (f * (f * (x * (1. / 144.) - 1. / 36.) + 1. / 9.) - 1. / 3.) + 151. / 315.0;\n  } else if (x < 2.) 
{\n    return x * (x * (x * (x * (x * (x * (0.05 - x * (1. / 240.)) - 7. / 30.) + 0.5) - 7. / 18.) - 0.1) - 7. / 90.) +\n        103. / 210.0;\n  } else if (x < 3.) {\n    return x *\n        (x * (x * (x * (x * (x * (x * (1. / 720.) - 1. / 36.) + 7. / 30.) - 19. / 18.) + 49. / 18.) - 23. / 6.) +\n         217. / 90.) -\n        139. / 630.0;\n  } else if (x < 4.) {\n    scalar_t f = 4. - x;\n    x = f * f * f;\n    return (x * x * f) / 5040.;\n  } else\n    return static_cast<scalar_t>(0);\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastweight7(scalar_t x) {\n  x = std::fabs(x);\n  if (x < 1.) {\n    scalar_t f = x * x;\n    return f * (f * (f * (x * (1. / 144.) - 1. / 36.) + 1. / 9.) - 1. / 3.) + 151. / 315.0;\n  } else if (x < 2.) {\n    return x * (x * (x * (x * (x * (x * (0.05 - x * (1. / 240.)) - 7. / 30.) + 0.5) - 7. / 18.) - 0.1) - 7. / 90.) +\n        103. / 210.0;\n  } else if (x < 3.) {\n    return x *\n        (x * (x * (x * (x * (x * (x * (1. / 720.) - 1. / 36.) + 7. / 30.) - 19. / 18.) + 49. / 18.) - 23. / 6.) +\n         217. / 90.) -\n        139. / 630.0;\n  } else {\n    scalar_t f = 4. - x;\n    x = f * f * f;\n    return (x * x * f) / 5040.;\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t grad7(scalar_t x) {\n  bool neg = x < 0;\n  if (neg)\n    x = -x;\n  if (x < 1.) {\n    scalar_t x2 = x * x;\n    x = x * (x2 * (x2 * (x * (7. / 144.) - 1. / 6.) + 4. / 9.) - 2. / 3.);\n  } else if (x < 2.) {\n    x = x * (x * (x * (x * (x * (x * (-7. / 240.) + 3. / 10.) - 7. / 6.) + 2.) - 7. / 6.) - 1. / 5.) - 7. / 90.;\n  } else if (x < 3.) {\n    x = x * (x * (x * (x * (x * (x * (7. / 720.) - 1. / 6.) + 7. / 6.) - 38. / 9.) + 49. / 6.) - 23. / 3.) + 217. / 90.;\n  } else if (x < 4.) 
{\n    x -= 4;\n    x *= x * x;\n    x *= x;\n    x = -x / 720.;\n  } else {\n    return static_cast<scalar_t>(0);\n  }\n  if (neg)\n    x = -x;\n  return x;\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastgrad7(scalar_t x) {\n  bool neg = x < 0;\n  if (neg)\n    x = -x;\n  if (x < 1.) {\n    scalar_t x2 = x * x;\n    x = x * (x2 * (x2 * (x * (7. / 144.) - 1. / 6.) + 4. / 9.) - 2. / 3.);\n  } else if (x < 2.) {\n    x = x * (x * (x * (x * (x * (x * (-7. / 240.) + 3. / 10.) - 7. / 6.) + 2.) - 7. / 6.) - 1. / 5.) - 7. / 90.;\n  } else if (x < 3.) {\n    x = x * (x * (x * (x * (x * (x * (7. / 720.) - 1. / 6.) + 7. / 6.) - 38. / 9.) + 49. / 6.) - 23. / 3.) + 217. / 90.;\n  } else {\n    x -= 4;\n    x *= x * x;\n    x *= x;\n    x = -x / 720.;\n  }\n  if (neg)\n    x = -x;\n  return x;\n}\n\ntemplate <typename scalar_t, typename offset_t>\nstatic MONAI_INLINE MONAI_DEVICE void bounds7(scalar_t x, offset_t& low, offset_t& upp) {\n  low = static_cast<offset_t>(std::floor(x - 3.));\n  upp = low + 7;\n}\n\n} // namespace _interpolation\n\nnamespace interpolation {\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t weight(InterpolationType interpolation_type, scalar_t x) {\n  switch (interpolation_type) {\n    case InterpolationType::Nearest:\n      return _interpolation::weight0(x);\n    case InterpolationType::Linear:\n      return _interpolation::weight1(x);\n    case InterpolationType::Quadratic:\n      return _interpolation::weight2(x);\n    case InterpolationType::Cubic:\n      return _interpolation::weight3(x);\n    case InterpolationType::FourthOrder:\n      return _interpolation::weight4(x);\n    case InterpolationType::FifthOrder:\n      return _interpolation::weight5(x);\n    case InterpolationType::SixthOrder:\n      return _interpolation::weight6(x);\n    case InterpolationType::SeventhOrder:\n      return _interpolation::weight7(x);\n    default:\n      return _interpolation::weight1(x);\n  
}\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastweight(InterpolationType interpolation_type, scalar_t x) {\n  switch (interpolation_type) {\n    case InterpolationType::Nearest:\n      return _interpolation::fastweight0(x);\n    case InterpolationType::Linear:\n      return _interpolation::fastweight1(x);\n    case InterpolationType::Quadratic:\n      return _interpolation::fastweight2(x);\n    case InterpolationType::Cubic:\n      return _interpolation::fastweight3(x);\n    case InterpolationType::FourthOrder:\n      return _interpolation::fastweight4(x);\n    case InterpolationType::FifthOrder:\n      return _interpolation::fastweight5(x);\n    case InterpolationType::SixthOrder:\n      return _interpolation::fastweight6(x);\n    case InterpolationType::SeventhOrder:\n      return _interpolation::fastweight7(x);\n    default:\n      return _interpolation::fastweight1(x);\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t grad(InterpolationType interpolation_type, scalar_t x) {\n  switch (interpolation_type) {\n    case InterpolationType::Nearest:\n      return _interpolation::grad0(x);\n    case InterpolationType::Linear:\n      return _interpolation::grad1(x);\n    case InterpolationType::Quadratic:\n      return _interpolation::grad2(x);\n    case InterpolationType::Cubic:\n      return _interpolation::grad3(x);\n    case InterpolationType::FourthOrder:\n      return _interpolation::grad4(x);\n    case InterpolationType::FifthOrder:\n      return _interpolation::grad5(x);\n    case InterpolationType::SixthOrder:\n      return _interpolation::grad6(x);\n    case InterpolationType::SeventhOrder:\n      return _interpolation::grad7(x);\n    default:\n      return _interpolation::grad1(x);\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fastgrad(InterpolationType interpolation_type, scalar_t x) {\n  switch (interpolation_type) {\n    case 
InterpolationType::Nearest:\n      return _interpolation::fastgrad0(x);\n    case InterpolationType::Linear:\n      return _interpolation::fastgrad1(x);\n    case InterpolationType::Quadratic:\n      return _interpolation::fastgrad2(x);\n    case InterpolationType::Cubic:\n      return _interpolation::fastgrad3(x);\n    case InterpolationType::FourthOrder:\n      return _interpolation::fastgrad4(x);\n    case InterpolationType::FifthOrder:\n      return _interpolation::fastgrad5(x);\n    case InterpolationType::SixthOrder:\n      return _interpolation::fastgrad6(x);\n    case InterpolationType::SeventhOrder:\n      return _interpolation::fastgrad7(x);\n    default:\n      return _interpolation::fastgrad1(x);\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t hess(InterpolationType interpolation_type, scalar_t x) {\n  switch (interpolation_type) {\n    case InterpolationType::Nearest:\n      return _interpolation::hess0(x);\n    case InterpolationType::Linear:\n      return _interpolation::hess1(x);\n    case InterpolationType::Quadratic:\n      return _interpolation::hess2(x);\n    case InterpolationType::Cubic:\n      return _interpolation::hess3(x);\n    case InterpolationType::FourthOrder:\n      return _interpolation::hess4(x);\n    case InterpolationType::FifthOrder:\n      return _interpolation::hess0(x); // notimplemented\n    case InterpolationType::SixthOrder:\n      return _interpolation::hess0(x); // notimplemented\n    case InterpolationType::SeventhOrder:\n      return _interpolation::hess0(x); // notimplemented\n    default:\n      return _interpolation::grad1(x);\n  }\n}\n\ntemplate <typename scalar_t>\nstatic MONAI_INLINE MONAI_DEVICE scalar_t fasthess(InterpolationType interpolation_type, scalar_t x) {\n  switch (interpolation_type) {\n    case InterpolationType::Nearest:\n      return _interpolation::fasthess0(x);\n    case InterpolationType::Linear:\n      return _interpolation::fasthess1(x);\n    case 
InterpolationType::Quadratic:\n      return _interpolation::fasthess2(x);\n    case InterpolationType::Cubic:\n      return _interpolation::fasthess3(x);\n    case InterpolationType::FourthOrder:\n      return _interpolation::fasthess4(x);\n    case InterpolationType::FifthOrder:\n      return _interpolation::fasthess0(x); // notimplemented\n    case InterpolationType::SixthOrder:\n      return _interpolation::fasthess0(x); // notimplemented\n    case InterpolationType::SeventhOrder:\n      return _interpolation::fasthess0(x); // notimplemented\n    default:\n      return _interpolation::fasthess1(x);\n  }\n}\n\ntemplate <typename scalar_t, typename offset_t>\nstatic MONAI_INLINE MONAI_DEVICE void bounds(\n    InterpolationType interpolation_type,\n    scalar_t x,\n    offset_t& low,\n    offset_t& upp) {\n  switch (interpolation_type) {\n    case InterpolationType::Nearest:\n      return _interpolation::bounds0(x, low, upp);\n    case InterpolationType::Linear:\n      return _interpolation::bounds1(x, low, upp);\n    case InterpolationType::Quadratic:\n      return _interpolation::bounds2(x, low, upp);\n    case InterpolationType::Cubic:\n      return _interpolation::bounds3(x, low, upp);\n    case InterpolationType::FourthOrder:\n      return _interpolation::bounds4(x, low, upp);\n    case InterpolationType::FifthOrder:\n      return _interpolation::bounds5(x, low, upp);\n    case InterpolationType::SixthOrder:\n      return _interpolation::bounds6(x, low, upp);\n    case InterpolationType::SeventhOrder:\n      return _interpolation::bounds7(x, low, upp);\n    default:\n      return _interpolation::bounds1(x, low, upp);\n  }\n}\n\n} // namespace interpolation\n\n} // namespace monai\n"
  },
  {
    "path": "monai/csrc/resample/pushpull.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// adapted from https://github.com/balbasty/nitorch\n\n#include <ATen/ATen.h>\n#include <deque>\n#include <tuple>\n#include <vector>\n#include \"utils/common_utils.h\"\n#include \"utils/resample_utils.h\"\n\n#define MONAI_PUSHPULL_DECLARE(space)                                            \\\n  namespace space {                                                              \\\n  template <typename BoundType, typename InterpolationType, typename SourceType> \\\n  std::deque<at::Tensor> pushpull(                                               \\\n      const SourceType& source,                                                  \\\n      const at::Tensor& grid,                                                    \\\n      BoundType bound,                                                           \\\n      InterpolationType interpolation,                                           \\\n      bool extrapolate,                                                          \\\n      bool do_pull,                                                              \\\n      bool do_push,                                                              \\\n      bool do_count,                                                             \\\n      bool do_grad,                                                              \\\n      bool do_sgrad);                                                            \\\n  
template <typename BoundType, typename InterpolationType, typename SourceType> \\\n  std::deque<at::Tensor> pushpull(                                               \\\n      const SourceType& source,                                                  \\\n      const at::Tensor& grid,                                                    \\\n      const at::Tensor& target,                                                  \\\n      BoundType bound,                                                           \\\n      InterpolationType interpolation,                                           \\\n      bool extrapolate,                                                          \\\n      bool do_pull,                                                              \\\n      bool do_push,                                                              \\\n      bool do_count,                                                             \\\n      bool do_grad,                                                              \\\n      bool do_sgrad);                                                            \\\n  }\n\nnamespace monai {\n\nMONAI_PUSHPULL_DECLARE(cpu)\nMONAI_PUSHPULL_DECLARE(cuda)\n\n// PULL\nat::Tensor grid_pull(\n    const at::Tensor& input,\n    const at::Tensor& grid,\n    const std::vector<BoundType>& bound_mode,\n    const std::vector<InterpolationType>& interpolation_mode,\n    bool extrapolate) {\n  CHECK_DEFINED(input)\n  CHECK_DEFINED(grid)\n  auto input_opt = input.options();\n  auto grid_opt = grid.options();\n  CHECK_STRIDED(input_opt)\n  CHECK_STRIDED(grid_opt)\n  CHECK_SAME_DEVICE(input_opt, grid_opt)\n  CHECK_SAME_DTYPE(input_opt, grid_opt)\n  CHECK_SPATIAL_1D_2D_OR_3D(input)\n  CHECK_SPATIAL_1D_2D_OR_3D(grid)\n  CHECK_GRID_COMPONENT(grid, grid.dim())\n  CHECK_SPATIAL_NOT_EMPTY(input)\n  CHECK_SPATIAL_NOT_EMPTY(grid)\n  CHECK_VEC_NOT_EMPTY(bound_mode);\n  CHECK_VEC_NOT_EMPTY(interpolation_mode);\n\n  if (input.is_cuda())\n#ifdef WITH_CUDA\n    return 
cuda::pushpull(\n               input,\n               grid,\n               BoundVectorRef(bound_mode),\n               InterpolationVectorRef(interpolation_mode),\n               extrapolate,\n               true,\n               false,\n               false,\n               false,\n               false)\n        .front();\n#else\n    AT_ERROR(\"Not compiled with GPU support.\");\n#endif\n  else\n    return cpu::pushpull(\n               input,\n               grid,\n               BoundVectorRef(bound_mode),\n               InterpolationVectorRef(interpolation_mode),\n               extrapolate,\n               true,\n               false,\n               false,\n               false,\n               false)\n        .front();\n}\n\nstd::deque<at::Tensor> grid_pull_backward(\n    const at::Tensor& grad,\n    const at::Tensor& input,\n    const at::Tensor& grid,\n    const std::vector<BoundType>& bound_mode,\n    const std::vector<InterpolationType>& interpolation_mode,\n    bool extrapolate) {\n  if (input.is_cuda()) {\n#ifdef WITH_CUDA\n    return cuda::pushpull(\n        input,\n        grid,\n        grad,\n        BoundVectorRef(bound_mode),\n        InterpolationVectorRef(interpolation_mode),\n        extrapolate,\n        false,\n        input.requires_grad(),\n        false,\n        grid.requires_grad(),\n        false);\n#else\n    AT_ERROR(\"Not compiled with GPU support.\");\n#endif\n  } else {\n    return cpu::pushpull(\n        input,\n        grid,\n        grad,\n        BoundVectorRef(bound_mode),\n        InterpolationVectorRef(interpolation_mode),\n        extrapolate,\n        false,\n        input.requires_grad(),\n        false,\n        grid.requires_grad(),\n        false);\n  }\n}\n\n// PUSH\nat::Tensor grid_push(\n    const at::Tensor& input,\n    const at::Tensor& grid,\n    c10::IntArrayRef source_size,\n    const std::vector<BoundType>& bound_mode,\n    const std::vector<InterpolationType>& interpolation_mode,\n    bool extrapolate) 
{\n  CHECK_DEFINED(input)\n  CHECK_DEFINED(grid)\n  auto input_opt = input.options();\n  auto grid_opt = grid.options();\n  CHECK_STRIDED(input_opt)\n  CHECK_STRIDED(grid_opt)\n  CHECK_SAME_DEVICE(input_opt, grid_opt)\n  CHECK_SAME_DTYPE(input_opt, grid_opt)\n  CHECK_SPATIAL_1D_2D_OR_3D(input)\n  CHECK_SPATIAL_1D_2D_OR_3D(grid)\n  CHECK_GRID_COMPONENT(grid, grid.dim())\n  CHECK_SPATIAL_NOT_EMPTY(input)\n  CHECK_SPATIAL_NOT_EMPTY(grid)\n  CHECK_GRID_TARGET_COMPAT(grid, input)\n  CHECK_VEC_NOT_EMPTY(bound_mode);\n  CHECK_VEC_NOT_EMPTY(interpolation_mode);\n\n  if (source_size.empty()) {\n    auto size = c10::IntArrayRef(\n        {input.dim() >= 3 ? input.size(2) : 1,\n         input.dim() >= 4 ? input.size(3) : 1,\n         input.dim() >= 5 ? input.size(4) : 1});\n    if (input.is_cuda())\n#ifdef WITH_CUDA\n      return cuda::pushpull(\n                 size,\n                 grid,\n                 input,\n                 BoundVectorRef(bound_mode),\n                 InterpolationVectorRef(interpolation_mode),\n                 extrapolate,\n                 false,\n                 true,\n                 false,\n                 false,\n                 false)\n          .front();\n#else\n      AT_ERROR(\"Not compiled with GPU support.\");\n#endif\n    else\n      return cpu::pushpull(\n                 size,\n                 grid,\n                 input,\n                 BoundVectorRef(bound_mode),\n                 InterpolationVectorRef(interpolation_mode),\n                 extrapolate,\n                 false,\n                 true,\n                 false,\n                 false,\n                 false)\n          .front();\n  } else {\n    CHECK_SPATIAL_LENGTH(source_size, grid.dim())\n    if (input.is_cuda())\n#ifdef WITH_CUDA\n      return cuda::pushpull(\n                 source_size,\n                 grid,\n                 input,\n                 BoundVectorRef(bound_mode),\n                 InterpolationVectorRef(interpolation_mode),\n      
           extrapolate,\n                 false,\n                 true,\n                 false,\n                 false,\n                 false)\n          .front();\n#else\n      AT_ERROR(\"Not compiled with GPU support.\");\n#endif\n    else\n      return cpu::pushpull(\n                 source_size,\n                 grid,\n                 input,\n                 BoundVectorRef(bound_mode),\n                 InterpolationVectorRef(interpolation_mode),\n                 extrapolate,\n                 false,\n                 true,\n                 false,\n                 false,\n                 false)\n          .front();\n  }\n}\n\nstd::deque<at::Tensor> grid_push_backward(\n    const at::Tensor& grad,\n    const at::Tensor& input,\n    const at::Tensor& grid,\n    const std::vector<BoundType>& bound_mode,\n    const std::vector<InterpolationType>& interpolation_mode,\n    bool extrapolate) {\n  if (input.is_cuda()) {\n#ifdef WITH_CUDA\n    return cuda::pushpull(\n        grad,\n        grid,\n        input,\n        BoundVectorRef(bound_mode),\n        InterpolationVectorRef(interpolation_mode),\n        extrapolate,\n        input.requires_grad(),\n        false,\n        false,\n        grid.requires_grad(),\n        false);\n#else\n    AT_ERROR(\"Not compiled with GPU support.\");\n#endif\n  } else {\n    return cpu::pushpull(\n        grad,\n        grid,\n        input,\n        BoundVectorRef(bound_mode),\n        InterpolationVectorRef(interpolation_mode),\n        extrapolate,\n        input.requires_grad(),\n        false,\n        false,\n        grid.requires_grad(),\n        false);\n  }\n}\n\n// COUNT\nat::Tensor grid_count(\n    const at::Tensor& grid,\n    c10::IntArrayRef source_size,\n    const std::vector<BoundType>& bound_mode,\n    const std::vector<InterpolationType>& interpolation_mode,\n    bool extrapolate) {\n  CHECK_DEFINED(grid)\n  auto grid_opt = grid.options();\n  CHECK_STRIDED(grid_opt)\n  CHECK_SPATIAL_1D_2D_OR_3D(grid)\n  
CHECK_GRID_COMPONENT(grid, grid.dim())\n  CHECK_SPATIAL_NOT_EMPTY(grid)\n  CHECK_VEC_NOT_EMPTY(bound_mode);\n  CHECK_VEC_NOT_EMPTY(interpolation_mode);\n\n  if (source_size.empty()) {\n    auto size = c10::IntArrayRef(\n        {grid.dim() >= 3 ? grid.size(2) : 1, grid.dim() >= 4 ? grid.size(3) : 1, grid.dim() >= 5 ? grid.size(4) : 1});\n    if (grid.is_cuda())\n#ifdef WITH_CUDA\n      return cuda::pushpull(\n                 size,\n                 grid,\n                 BoundVectorRef(bound_mode),\n                 InterpolationVectorRef(interpolation_mode),\n                 extrapolate,\n                 false,\n                 false,\n                 true,\n                 false,\n                 false)\n          .front();\n#else\n      AT_ERROR(\"Not compiled with GPU support.\");\n#endif\n    else\n      return cpu::pushpull(\n                 size,\n                 grid,\n                 BoundVectorRef(bound_mode),\n                 InterpolationVectorRef(interpolation_mode),\n                 extrapolate,\n                 false,\n                 false,\n                 true,\n                 false,\n                 false)\n          .front();\n  } else {\n    CHECK_SPATIAL_LENGTH(source_size, grid.dim())\n    if (grid.is_cuda())\n#ifdef WITH_CUDA\n      return cuda::pushpull(\n                 source_size,\n                 grid,\n                 BoundVectorRef(bound_mode),\n                 InterpolationVectorRef(interpolation_mode),\n                 extrapolate,\n                 false,\n                 false,\n                 true,\n                 false,\n                 false)\n          .front();\n#else\n      AT_ERROR(\"Not compiled with GPU support.\");\n#endif\n    else\n      return cpu::pushpull(\n                 source_size,\n                 grid,\n                 BoundVectorRef(bound_mode),\n                 InterpolationVectorRef(interpolation_mode),\n                 extrapolate,\n                 false,\n               
  false,\n                 true,\n                 false,\n                 false)\n          .front();\n  }\n}\n\nat::Tensor grid_count_backward(\n    const at::Tensor& grad,\n    const at::Tensor& grid,\n    const std::vector<BoundType>& bound_mode,\n    const std::vector<InterpolationType>& interpolation_mode,\n    bool extrapolate) {\n  if (grid.is_cuda()) {\n#ifdef WITH_CUDA\n    return cuda::pushpull(\n               grad,\n               grid,\n               BoundVectorRef(bound_mode),\n               InterpolationVectorRef(interpolation_mode),\n               extrapolate,\n               false,\n               false,\n               false,\n               grid.requires_grad(),\n               false)\n        .front();\n#else\n    AT_ERROR(\"Not compiled with GPU support.\");\n#endif\n  } else {\n    return cpu::pushpull(\n               grad,\n               grid,\n               BoundVectorRef(bound_mode),\n               InterpolationVectorRef(interpolation_mode),\n               extrapolate,\n               false,\n               false,\n               false,\n               grid.requires_grad(),\n               false)\n        .front();\n  }\n}\n\n// PULL GRADIENTS\nat::Tensor grid_grad(\n    const at::Tensor& input,\n    const at::Tensor& grid,\n    const std::vector<BoundType>& bound_mode,\n    const std::vector<InterpolationType>& interpolation_mode,\n    bool extrapolate) {\n  CHECK_DEFINED(input)\n  CHECK_DEFINED(grid)\n  auto input_opt = input.options();\n  auto grid_opt = grid.options();\n  CHECK_STRIDED(input_opt)\n  CHECK_STRIDED(grid_opt)\n  CHECK_SAME_DEVICE(input_opt, grid_opt)\n  CHECK_SAME_DTYPE(input_opt, grid_opt)\n  CHECK_SPATIAL_1D_2D_OR_3D(input)\n  CHECK_SPATIAL_1D_2D_OR_3D(grid)\n  CHECK_GRID_COMPONENT(grid, grid.dim())\n  CHECK_SPATIAL_NOT_EMPTY(input)\n  CHECK_SPATIAL_NOT_EMPTY(grid)\n  CHECK_VEC_NOT_EMPTY(bound_mode);\n  CHECK_VEC_NOT_EMPTY(interpolation_mode);\n\n  if (input.is_cuda())\n#ifdef WITH_CUDA\n    return 
cuda::pushpull(\n               input,\n               grid,\n               BoundVectorRef(bound_mode),\n               InterpolationVectorRef(interpolation_mode),\n               extrapolate,\n               false,\n               false,\n               false,\n               false,\n               true)\n        .front();\n#else\n    AT_ERROR(\"Not compiled with GPU support.\");\n#endif\n  else\n    return cpu::pushpull(\n               input,\n               grid,\n               BoundVectorRef(bound_mode),\n               InterpolationVectorRef(interpolation_mode),\n               extrapolate,\n               false,\n               false,\n               false,\n               false,\n               true)\n        .front();\n}\n\nstd::deque<at::Tensor> grid_grad_backward(\n    const at::Tensor& grad,\n    const at::Tensor& input,\n    const at::Tensor& grid,\n    const std::vector<BoundType>& bound_mode,\n    const std::vector<InterpolationType>& interpolation_mode,\n    bool extrapolate) {\n  if (input.is_cuda()) {\n#ifdef WITH_CUDA\n    return cuda::pushpull(\n        input,\n        grid,\n        grad,\n        BoundVectorRef(bound_mode),\n        InterpolationVectorRef(interpolation_mode),\n        extrapolate,\n        false,\n        input.requires_grad(),\n        false,\n        grid.requires_grad(),\n        false);\n#else\n    AT_ERROR(\"Not compiled with GPU support.\");\n#endif\n  } else {\n    return cpu::pushpull(\n        input,\n        grid,\n        grad,\n        BoundVectorRef(bound_mode),\n        InterpolationVectorRef(interpolation_mode),\n        extrapolate,\n        false,\n        input.requires_grad(),\n        false,\n        grid.requires_grad(),\n        false);\n  }\n}\n\n} // namespace monai\n"
  },
  {
    "path": "monai/csrc/resample/pushpull_cpu.cpp",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// adapted from https://github.com/balbasty/nitorch\n\n// This file implements spline interpolation / sampling and its adjoint\n// operations. It corresponds loosely to torch's `GridSampler`.\n// It handles boundary conditions and interpolation orders defined in\n// `utils/resample_utils.h` and `utils/resample_utils.h`.\n// These parameters can be specified per dimension.\n// Isotropic 0-th and 1-st order interpolation have their own (faster)\n// implementations. Sliding boundary conditions are also implemented\n// separately.\n\n// TODO:\n// . [DONE] generic 3d\n// . [DONE] generic 2d\n// . [DONE] generic 1d\n// . sliding nearest 3d\n// . sliding nearest 2d\n// . sliding linear 3d\n// . sliding linear 2d\n// . sliding generic 3d\n// . sliding generic 2d\n// . [DONE] spatial gradient mode (without multiplication with output gradient)\n// . [DONE] second order gradients (backward pass for spatial gradients)\n// . performance tests\n// . 
input bound/inter are always vectors -> clean unused constructors\n\n#include <ATen/ATen.h>\n#include <limits>\n#include <tuple>\n#include \"bounds_common.h\"\n#include \"interpolation_common.h\"\n#include \"utils/resample_utils.h\"\n//#include <cstdio>\n\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n// CPU-specific parameters\n#include <ATen/Parallel.h>\nnamespace {\n// This parameter specifies the minimum number of voxels that should be\n// processed on a single processor in the parallel for loop .\nint64_t GRAIN_SIZE = static_cast<int64_t>(at::internal::GRAIN_SIZE);\n} // namespace\n\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n// maximum number of channels\n// > not used in mode isotropic nearest/linear\n#ifndef MONAI_MAX_NUM_CHANNELS\n#define MONAI_MAX_NUM_CHANNELS 1024\n#endif\n\n// This parameter allows for a little bit of tolerance when considering\n// a coordinate as \"out-of-bound\" (if !extrapolate)\n#define TINY 5e-2\n\nusing at::Tensor;\nusing at::TensorOptions;\nusing c10::IntArrayRef;\n\nnamespace monai {\nMONAI_NAMESPACE_DEVICE { // cpu\n\n  namespace { // anonymous namespace > everything inside has internal linkage\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                        INDEXING UTILS\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  // This class reads and sets all the parameters that will later be used\n  // by the algorithm in PushPullImpl. All of this is done outside of the\n  // implementation class so that we do not depend on generic types. The\n  // point is to pre-allocate all necessary tensors so that we can check\n  // if they're all compatible with 32 bit math. If it's the case, we can\n  // dispatch to a 32b cuda implementation, which might increase\n  // performance. 
Else, we use 64 bit math to compute offsets.\n  // (On CPU, we always use 64 bit offsets because it doesn't make a huge\n  // difference. It would be different if we had a vectorized\n  // implementation as in PyTorch).\n  class PushPullAllocator {\n   public:\n    static constexpr int64_t max_int32 = std::numeric_limits<int32_t>::max();\n\n    // ~~~ CONSTRUCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n    MONAI_HOST\n    PushPullAllocator(\n        int dim,\n        BoundVectorRef bound,\n        InterpolationVectorRef interpolation,\n        bool extrapolate,\n        bool do_pull,\n        bool do_push,\n        bool do_count,\n        bool do_grad,\n        bool do_sgrad)\n        : dim(dim),\n          bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate),\n          bound1(\n              bound.size() > 1       ? bound[1]\n                  : bound.size() > 0 ? bound[0]\n                                     : BoundType::Replicate),\n          bound2(\n              bound.size() > 2       ? bound[2]\n                  : bound.size() > 1 ? bound[1]\n                  : bound.size() > 0 ? bound[0]\n                                     : BoundType::Replicate),\n          interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),\n          interpolation1(\n              interpolation.size() > 1       ? interpolation[1]\n                  : interpolation.size() > 0 ? interpolation[0]\n                                             : InterpolationType::Linear),\n          interpolation2(\n              interpolation.size() > 2       ? interpolation[2]\n                  : interpolation.size() > 1 ? interpolation[1]\n                  : interpolation.size() > 0 ? 
interpolation[0]\n                                             : InterpolationType::Linear),\n          extrapolate(extrapolate),\n          do_pull(do_pull),\n          do_push(do_push),\n          do_count(do_count),\n          do_grad(do_grad),\n          do_sgrad(do_sgrad) {\n      iso = interpolation0 == interpolation1 && interpolation0 == interpolation2;\n    }\n\n    // ~~~ FUNCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n    // Usually used for pull:\n    // - do_pull  -> return source[grid]\n    // - do_push  -> fails\n    // - do_grad  -> return J(source)[grid]\n    // - do_sgrad -> return H(source)[grid]\n    MONAI_HOST void ioset(const Tensor& source, const Tensor& grid) {\n      init_all();\n      init_source(source);\n      init_grid(grid);\n      init_output();\n    }\n\n    // Usually used for pull_backward:\n    // - do_pull  -> return source[grid]\n    // - do_push  -> return push(target, grid, source.shape)\n    // - do_grad  -> return J(source)[grid]\n    // - do_sgrad -> return H(source)[grid]\n    MONAI_HOST void ioset(const Tensor& source, const Tensor& grid, const Tensor& target) {\n      init_all();\n      init_source(source);\n      init_grid(grid);\n      init_target(target);\n      init_output();\n    }\n\n    // Usually used for push:\n    // - do_pull  -> fails\n    // - do_push  -> return push(target, grid, source_size)\n    // - do_grad  -> fails\n    // - do_sgrad -> fails\n    MONAI_HOST void ioset(IntArrayRef source_size, const Tensor& grid, const Tensor& target) {\n      init_all();\n      init_source(source_size);\n      init_grid(grid);\n      init_target(target);\n      init_output();\n    }\n\n    // Usually used for count:\n    // - do_pull  -> fails\n    // - do_push  -> return push(ones, grid, source_size)\n    // - do_grad  -> fails\n    // - do_sgrad -> fails\n    MONAI_HOST void ioset(IntArrayRef source_size, const Tensor& grid) {\n      init_all();\n      init_source(source_size);\n      
init_grid(grid);\n      init_output();\n    }\n\n    // We just check that all tensors that we own are compatible with 32b math\n    bool canUse32BitIndexMath(int64_t max_elem = max_int32) const {\n      return src_32b_ok && trgt_32b_ok && grid_32b_ok && grad_32b_ok && out_32b_ok;\n    }\n\n   private:\n    // Copied from aten/src/ATen/native/IndexingUtils.cpp in PyTorch 1.6.\n    // It is used to decide to which pointer type we should dispatch to.\n    // Basically, we need to make sure that the \"furthest\" element we need\n    // to reach is less than max_elem away.\n    static bool tensorCanUse32BitIndexMath(const Tensor& t, int64_t max_elem = max_int32) {\n      int64_t elements = t.numel();\n      if (elements >= max_elem) {\n        return false;\n      }\n      if (elements == 0) {\n        return max_elem > 0;\n      }\n\n      int64_t offset = 0;\n      int64_t linearId = elements - 1;\n\n      // NOTE: Assumes all strides are positive, which is true for now\n      for (int i = t.dim() - 1; i >= 0; --i) {\n        int64_t curDimIndex = linearId % t.size(i);\n        int64_t curDimOffset = curDimIndex * t.stride(i);\n        offset += curDimOffset;\n        linearId /= t.size(i);\n      }\n\n      if (offset >= max_elem) {\n        return false;\n      }\n\n      return true;\n    }\n\n    // ~~~ COMPONENTS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    MONAI_HOST void init_all();\n    MONAI_HOST void init_source(const Tensor& source);\n    MONAI_HOST void init_source(IntArrayRef source_size);\n    MONAI_HOST void init_grid(const Tensor& grid);\n    MONAI_HOST void init_target(const Tensor& target);\n    MONAI_HOST void init_output();\n\n    // ~~~ OPTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    int dim; // dimensionality (2 or 3)\n    BoundType bound0; // boundary condition  // x|W\n    BoundType bound1; // boundary condition  // y|H\n    BoundType bound2; // boundary condition  // z|D\n    InterpolationType interpolation0; 
// interpolation order // x|W\n    InterpolationType interpolation1; // interpolation order // y|H\n    InterpolationType interpolation2; // interpolation order // z|D\n    bool iso; // isotropic interpolation?\n    bool extrapolate; // compute out-of-bound values\n    bool do_pull; // sample a volume\n    bool do_push; // splat a volume\n    bool do_count; // splatting weights (= jacobian determinant)\n    bool do_grad; // backprop: gradient of grid // pull\n    bool do_sgrad; // sample spatial gradients\n\n    // ~~~ NAVIGATORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    std::deque<Tensor> output;\n    TensorOptions src_opt;\n    TensorOptions grid_opt;\n    TensorOptions trgt_opt;\n    int64_t N;\n    int64_t C;\n    int64_t src_X;\n    int64_t src_Y;\n    int64_t src_Z;\n    int64_t trgt_X;\n    int64_t trgt_Y;\n    int64_t trgt_Z;\n    int64_t trgt_K;\n    int64_t src_sN;\n    int64_t src_sC;\n    int64_t src_sX;\n    int64_t src_sY;\n    int64_t src_sZ;\n    bool src_32b_ok;\n    void* src_ptr;\n    int64_t trgt_sN;\n    int64_t trgt_sC;\n    int64_t trgt_sX;\n    int64_t trgt_sY;\n    int64_t trgt_sZ;\n    int64_t trgt_sK;\n    bool trgt_32b_ok;\n    void* trgt_ptr;\n    int64_t grid_sN;\n    int64_t grid_sC;\n    int64_t grid_sX;\n    int64_t grid_sY;\n    int64_t grid_sZ;\n    bool grid_32b_ok;\n    void* grid_ptr;\n    int64_t out_sN;\n    int64_t out_sC;\n    int64_t out_sX;\n    int64_t out_sY;\n    int64_t out_sZ;\n    int64_t out_sK; // gradient dimension\n    bool out_32b_ok;\n    void* out_ptr;\n    int64_t grad_sN;\n    int64_t grad_sC;\n    int64_t grad_sX;\n    int64_t grad_sY;\n    int64_t grad_sZ;\n    bool grad_32b_ok;\n    void* grad_ptr;\n\n    // Allow PushPullImpl's constructor to access PushPullAllocator's\n    // private members.\n    template <typename scalar_t, typename offset_t>\n    friend class PushPullImpl;\n  };\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                        
  INITIALISATION\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  MONAI_HOST\n  void PushPullAllocator::init_all() {\n    src_opt = grid_opt = trgt_opt = TensorOptions();\n    N = C = 1L;\n    src_X = src_Y = src_Z = 1L;\n    trgt_X = trgt_Y = trgt_Z = 1L;\n    trgt_K = 0L;\n    src_sN = src_sC = src_sX = src_sY = src_sZ = 0L;\n    grid_sN = grid_sC = grid_sX = grid_sY = grid_sZ = 0L;\n    grad_sN = grad_sC = grad_sX = grad_sY = grad_sZ = 0L;\n    trgt_sN = trgt_sC = trgt_sX = trgt_sY = trgt_sZ = trgt_sK = 0L;\n    out_sN = out_sC = out_sX = out_sY = out_sZ = out_sK = 0L;\n    src_ptr = trgt_ptr = grid_ptr = out_ptr = grad_ptr = static_cast<float*>(0);\n    src_32b_ok = trgt_32b_ok = grid_32b_ok = out_32b_ok = grad_32b_ok = true;\n  }\n\n  MONAI_HOST\n  void PushPullAllocator::init_source(const Tensor& source) {\n    N = source.size(0);\n    C = source.size(1);\n    src_X = source.size(2);\n    src_Y = dim < 2 ? 1L : source.size(3);\n    src_Z = dim < 3 ? 1L : source.size(4);\n    src_sN = source.stride(0);\n    src_sC = source.stride(1);\n    src_sX = source.stride(2);\n    src_sY = dim < 2 ? 0L : source.stride(3);\n    src_sZ = dim < 3 ? 0L : source.stride(4);\n    src_ptr = source.data_ptr();\n    src_opt = source.options();\n    src_32b_ok = tensorCanUse32BitIndexMath(source);\n  }\n\n  MONAI_HOST\n  void PushPullAllocator::init_source(IntArrayRef source_size) {\n    src_X = source_size[0];\n    src_Y = dim < 2 ? 1L : source_size[1];\n    src_Z = dim < 3 ? 1L : source_size[2];\n  }\n\n  MONAI_HOST\n  void PushPullAllocator::init_grid(const Tensor& grid) {\n    N = grid.size(0);\n    trgt_X = grid.size(1);\n    trgt_Y = dim < 2 ? 1L : grid.size(2);\n    trgt_Z = dim < 3 ? 1L : grid.size(3);\n    grid_sN = grid.stride(0);\n    grid_sX = grid.stride(1);\n    grid_sY = dim < 2 ? 0L : grid.stride(2);\n    grid_sZ = dim < 3 ? 0L : grid.stride(3);\n    grid_sC = grid.stride(dim == 1 ? 2 : dim == 2 ? 
3 : 4);\n    grid_ptr = grid.data_ptr();\n    grid_opt = grid.options();\n    grid_32b_ok = tensorCanUse32BitIndexMath(grid);\n  }\n\n  MONAI_HOST\n  void PushPullAllocator::init_target(const Tensor& target) {\n    N = target.size(0);\n    C = target.size(1);\n    trgt_X = target.size(2);\n    trgt_Y = dim < 2 ? 1L : target.size(3);\n    trgt_Z = dim < 3 ? 1L : target.size(4);\n    trgt_K = target.dim() == dim + 3 ? target.size(dim == 1 ? 3 : dim == 2 ? 4 : 5) : 0L;\n    trgt_sN = target.stride(0);\n    trgt_sC = target.stride(1);\n    trgt_sX = target.stride(2);\n    trgt_sY = dim < 2 ? 0L : target.stride(3);\n    trgt_sZ = dim < 3 ? 0L : target.stride(4);\n    trgt_sK = target.dim() == dim + 3 ? target.stride(dim == 1 ? 3 : dim == 2 ? 4 : 5) : 0L;\n    trgt_ptr = target.data_ptr();\n    trgt_opt = target.options();\n    trgt_32b_ok = tensorCanUse32BitIndexMath(target);\n  }\n\n  MONAI_HOST\n  void PushPullAllocator::init_output() {\n    output.clear();\n    if (do_pull) {\n      if (dim == 1)\n        output.push_back(at::empty({N, C, trgt_X}, src_opt));\n      else if (dim == 2)\n        output.push_back(at::empty({N, C, trgt_X, trgt_Y}, src_opt));\n      else\n        output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z}, src_opt));\n      auto pull = output.back();\n      out_sN = pull.stride(0);\n      out_sC = pull.stride(1);\n      out_sX = pull.stride(2);\n      out_sY = dim < 2 ? 0L : pull.stride(3);\n      out_sZ = dim < 3 ? 
0L : pull.stride(4);\n      out_sK = 0L;\n      out_ptr = pull.data_ptr();\n      out_32b_ok = tensorCanUse32BitIndexMath(pull);\n    } else if (do_sgrad) {\n      if (dim == 1)\n        output.push_back(at::empty({N, C, trgt_X, 1}, src_opt));\n      else if (dim == 2)\n        output.push_back(at::empty({N, C, trgt_X, trgt_Y, 2}, src_opt));\n      else\n        output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z, 3}, src_opt));\n      auto sgrad = output.back();\n      out_sN = sgrad.stride(0);\n      out_sC = sgrad.stride(1);\n      out_sX = sgrad.stride(2);\n      out_sY = dim < 2 ? 0L : sgrad.stride(3);\n      out_sZ = dim < 3 ? 0L : sgrad.stride(4);\n      out_sK = sgrad.stride(dim == 1 ? 3 : dim == 2 ? 4 : 5);\n      out_ptr = sgrad.data_ptr();\n      out_32b_ok = tensorCanUse32BitIndexMath(sgrad);\n\n      if (iso && interpolation0 == InterpolationType::Nearest)\n        sgrad.zero_();\n      if (iso && interpolation0 == InterpolationType::Linear && dim == 1)\n        sgrad.zero_();\n    } else if (do_push) {\n      if (dim == 1)\n        output.push_back(at::zeros({N, C, src_X}, trgt_opt));\n      else if (dim == 2)\n        output.push_back(at::zeros({N, C, src_X, src_Y}, trgt_opt));\n      else\n        output.push_back(at::zeros({N, C, src_X, src_Y, src_Z}, trgt_opt));\n      auto push = output.back();\n      out_sN = push.stride(0);\n      out_sC = push.stride(1);\n      out_sX = push.stride(2);\n      out_sY = dim < 2 ? 0L : push.stride(3);\n      out_sZ = dim < 3 ? 
0L : push.stride(4);\n      out_sK = 0L;\n      out_ptr = push.data_ptr();\n      out_32b_ok = tensorCanUse32BitIndexMath(push);\n    } else if (do_count) {\n      if (dim == 1)\n        output.push_back(at::zeros({N, 1, src_X}, grid_opt));\n      else if (dim == 2)\n        output.push_back(at::zeros({N, 1, src_X, src_Y}, grid_opt));\n      else\n        output.push_back(at::zeros({N, 1, src_X, src_Y, src_Z}, grid_opt));\n      auto count = output.back();\n      out_sN = count.stride(0);\n      out_sC = count.stride(1);\n      out_sX = count.stride(2);\n      out_sY = dim < 2 ? 0L : count.stride(3);\n      out_sZ = dim < 3 ? 0L : count.stride(4);\n      out_sK = 0L;\n      out_ptr = count.data_ptr();\n      out_32b_ok = tensorCanUse32BitIndexMath(count);\n    }\n    if (do_grad) {\n      if (dim == 1)\n        output.push_back(at::zeros({N, trgt_X, 1}, grid_opt));\n      else if (dim == 2)\n        output.push_back(at::zeros({N, trgt_X, trgt_Y, 2}, grid_opt));\n      else\n        output.push_back(at::zeros({N, trgt_X, trgt_Y, trgt_Z, 3}, grid_opt));\n      auto grad = output.back();\n      grad_sN = grad.stride(0);\n      grad_sX = grad.stride(1);\n      grad_sY = dim < 2 ? 0L : grad.stride(2);\n      grad_sZ = dim < 3 ? 0L : grad.stride(3);\n      grad_sC = grad.stride(dim == 1 ? 2 : dim == 2 ? 
3 : 4);\n      grad_ptr = grad.data_ptr();\n      out_32b_ok = tensorCanUse32BitIndexMath(grad);\n\n      if (iso && interpolation0 == InterpolationType::Nearest)\n        grad.zero_();\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                        GENERIC PUSHPULL CLASS\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  // This class implements the bulk of the code.\n  // /!\\ No type and shape checking is performed here.\n\n  template <typename scalar_t, typename offset_t>\n  class PushPullImpl {\n   public:\n    // ~~~ CONSTRUCTOR ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    PushPullImpl(const PushPullAllocator& info)\n        : output(info.output),\n          dim(info.dim),\n          bound0(info.bound0),\n          bound1(info.bound1),\n          bound2(info.bound2),\n          interpolation0(info.interpolation0),\n          interpolation1(info.interpolation1),\n          interpolation2(info.interpolation1),\n          iso(info.iso),\n          extrapolate(info.extrapolate),\n          do_pull(info.do_pull),\n          do_push(info.do_push),\n          do_count(info.do_count),\n          do_grad(info.do_grad),\n          do_sgrad(info.do_sgrad),\n          N(static_cast<offset_t>(info.N)),\n          C(static_cast<offset_t>(info.C)),\n          src_X(static_cast<offset_t>(info.src_X)),\n          src_Y(static_cast<offset_t>(info.src_Y)),\n          src_Z(static_cast<offset_t>(info.src_Z)),\n          trgt_X(static_cast<offset_t>(info.trgt_X)),\n          trgt_Y(static_cast<offset_t>(info.trgt_Y)),\n          trgt_Z(static_cast<offset_t>(info.trgt_Z)),\n          trgt_K(static_cast<offset_t>(info.trgt_K)),\n          src_sN(static_cast<offset_t>(info.src_sN)),\n          src_sC(static_cast<offset_t>(info.src_sC)),\n          src_sX(static_cast<offset_t>(info.src_sX)),\n          src_sY(static_cast<offset_t>(info.src_sY)),\n          
src_sZ(static_cast<offset_t>(info.src_sZ)),\n          src_ptr(static_cast<scalar_t*>(info.src_ptr)),\n          trgt_sN(static_cast<offset_t>(info.trgt_sN)),\n          trgt_sC(static_cast<offset_t>(info.trgt_sC)),\n          trgt_sX(static_cast<offset_t>(info.trgt_sX)),\n          trgt_sY(static_cast<offset_t>(info.trgt_sY)),\n          trgt_sZ(static_cast<offset_t>(info.trgt_sZ)),\n          trgt_sK(static_cast<offset_t>(info.trgt_sK)),\n          trgt_ptr(static_cast<scalar_t*>(info.trgt_ptr)),\n          grid_sN(static_cast<offset_t>(info.grid_sN)),\n          grid_sC(static_cast<offset_t>(info.grid_sC)),\n          grid_sX(static_cast<offset_t>(info.grid_sX)),\n          grid_sY(static_cast<offset_t>(info.grid_sY)),\n          grid_sZ(static_cast<offset_t>(info.grid_sZ)),\n          grid_ptr(static_cast<scalar_t*>(info.grid_ptr)),\n          out_sN(static_cast<offset_t>(info.out_sN)),\n          out_sC(static_cast<offset_t>(info.out_sC)),\n          out_sX(static_cast<offset_t>(info.out_sX)),\n          out_sY(static_cast<offset_t>(info.out_sY)),\n          out_sZ(static_cast<offset_t>(info.out_sZ)),\n          out_sK(static_cast<offset_t>(info.out_sK)),\n          out_ptr(static_cast<scalar_t*>(info.out_ptr)),\n          grad_sN(static_cast<offset_t>(info.grad_sN)),\n          grad_sC(static_cast<offset_t>(info.grad_sC)),\n          grad_sX(static_cast<offset_t>(info.grad_sX)),\n          grad_sY(static_cast<offset_t>(info.grad_sY)),\n          grad_sZ(static_cast<offset_t>(info.grad_sZ)),\n          grad_ptr(static_cast<scalar_t*>(info.grad_ptr)) {}\n\n    // ~~~ PUBLIC VALUE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n    std::deque<Tensor> output;\n\n    // MONAI_HOST MONAI_DEVICE void printInfo() const {\n    //   printf(\"dim: %d\\n\", dim);\n    //   printf(\"do_pull:  %d\\n\", do_pull);\n    //   printf(\"do_push:  %d\\n\", do_push);\n    //   printf(\"do_count: %d\\n\", do_count);\n    //   printf(\"do_sgrad: %d\\n\", do_sgrad);\n    //   
printf(\"do_grad:  %d\\n\", do_grad);\n    //   printf(\"bound:         [%d %d %d]\\n\", static_cast<int>(bound0),\n    //     static_cast<int>(bound1), static_cast<int>(bound2));\n    //   printf(\"interpolation: [%d %d %d]\\n\", static_cast<int>(interpolation0),\n    //     static_cast<int>(interpolation1), static_cast<int>(interpolation2));\n    //   printf(\"src:  [%d %d %d]\\n\", src_Z, src_Y, src_X);\n    //   printf(\"trgt: [%d %d %d (%d)]\\n\", trgt_Z, trgt_Y, trgt_X, trgt_K);\n    //   printf(\"N: %d\\n\", N);\n    //   printf(\"C: %d\\n\", C);\n    //   printf(\"src  -> %lu\\n\", reinterpret_cast<std::uintptr_t>(src_ptr));\n    //   printf(\"trgt -> %lu\\n\", reinterpret_cast<std::uintptr_t>(trgt_ptr));\n    //   printf(\"grid -> %lu\\n\", reinterpret_cast<std::uintptr_t>(grid_ptr));\n    //   printf(\"out  -> %lu\\n\", reinterpret_cast<std::uintptr_t>(out_ptr));\n    //   printf(\"grad -> %lu\\n\", reinterpret_cast<std::uintptr_t>(grad_ptr));\n    // }\n\n    // ~~~ FUNCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n    // Loop over all voxels\n    void loop() const;\n\n    MONAI_HOST MONAI_DEVICE int64_t voxcount() const {\n      return N * trgt_X * trgt_Y * trgt_Z;\n    }\n\n   private:\n    // ~~~ COMPONENTS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    MONAI_DEVICE void check1d(offset_t w, offset_t n) const;\n    MONAI_DEVICE void check2d(offset_t w, offset_t h, offset_t n) const;\n    MONAI_DEVICE void check3d(offset_t w, offset_t h, offset_t d, offset_t n) const;\n    MONAI_DEVICE void interpolate1d(scalar_t x, offset_t w, offset_t n) const;\n    MONAI_DEVICE void interpolate1d_nearest(scalar_t x, offset_t w, offset_t n) const;\n    MONAI_DEVICE void interpolate1d_linear(scalar_t x, offset_t w, offset_t n) const;\n    MONAI_DEVICE void interpolate1d_sliding(scalar_t x, offset_t w, offset_t n) const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate1d_sliding_nearest(scalar_t x, offset_t w, offset_t n) const { 
/*TODO*/\n    }\n    MONAI_DEVICE void interpolate1d_sliding_linear(scalar_t x, offset_t w, offset_t n) const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate2d(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const;\n    MONAI_DEVICE void interpolate2d_nearest(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const;\n    MONAI_DEVICE void interpolate2d_bilinear(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const;\n    MONAI_DEVICE void interpolate2d_sliding(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate2d_sliding_nearest(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n)\n        const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate2d_sliding_bilinear(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n)\n        const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate3d(scalar_t x, scalar_t y, scalar_t z, offset_t w, offset_t h, offset_t d, offset_t n)\n        const;\n    MONAI_DEVICE void interpolate3d_nearest(\n        scalar_t x,\n        scalar_t y,\n        scalar_t z,\n        offset_t w,\n        offset_t h,\n        offset_t d,\n        offset_t n) const;\n    MONAI_DEVICE void interpolate3d_trilinear(\n        scalar_t x,\n        scalar_t y,\n        scalar_t z,\n        offset_t w,\n        offset_t h,\n        offset_t d,\n        offset_t n) const;\n    MONAI_DEVICE void interpolate3d_sliding(\n        scalar_t x,\n        scalar_t y,\n        scalar_t z,\n        offset_t w,\n        offset_t h,\n        offset_t d,\n        offset_t n) const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate3d_sliding_nearest(\n        scalar_t x,\n        scalar_t y,\n        scalar_t z,\n        offset_t w,\n        offset_t h,\n        offset_t d,\n        offset_t n) const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate3d_sliding_trilinear(\n        scalar_t x,\n        scalar_t y,\n        scalar_t z,\n        offset_t 
w,\n        offset_t h,\n        offset_t d,\n        offset_t n) const { /*TODO*/\n    }\n\n    // ~~~ OPTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    int dim; // dimensionality (2 or 3)\n    BoundType bound0; // boundary condition  // x|W\n    BoundType bound1; // boundary condition  // y|H\n    BoundType bound2; // boundary condition  // z|D\n    InterpolationType interpolation0; // interpolation order // x|W\n    InterpolationType interpolation1; // interpolation order // y|H\n    InterpolationType interpolation2; // interpolation order // z|D\n    bool iso; // isotropic interpolation?\n    bool extrapolate; // compute out-of-bound values\n    bool do_pull; // sample a volume\n    bool do_push; // splat a volume\n    bool do_count; // splatting weights (= jacobian determinant)\n    bool do_grad; // backprop: gradient of grid // pull\n    bool do_sgrad; // sample spatial gradients\n\n    // ~~~ NAVIGATORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    offset_t N;\n    offset_t C;\n    offset_t src_X;\n    offset_t src_Y;\n    offset_t src_Z;\n    offset_t trgt_X;\n    offset_t trgt_Y;\n    offset_t trgt_Z;\n    offset_t trgt_K;\n    offset_t src_sN;\n    offset_t src_sC;\n    offset_t src_sX;\n    offset_t src_sY;\n    offset_t src_sZ;\n    scalar_t* src_ptr;\n    offset_t trgt_sN;\n    offset_t trgt_sC;\n    offset_t trgt_sX;\n    offset_t trgt_sY;\n    offset_t trgt_sZ;\n    offset_t trgt_sK;\n    scalar_t* trgt_ptr;\n    offset_t grid_sN;\n    offset_t grid_sC;\n    offset_t grid_sX;\n    offset_t grid_sY;\n    offset_t grid_sZ;\n    scalar_t* grid_ptr;\n    offset_t out_sN;\n    offset_t out_sC;\n    offset_t out_sX;\n    offset_t out_sY;\n    offset_t out_sZ;\n    offset_t out_sK; // gradient dimension\n    scalar_t* out_ptr;\n    offset_t grad_sN;\n    offset_t grad_sC;\n    offset_t grad_sX;\n    offset_t grad_sY;\n    offset_t grad_sZ;\n    scalar_t* grad_ptr;\n  };\n\n  // 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                             LOOP\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  // This bit loops over all target voxels. We therefore need to\n  // convert linear indices to multivariate indices. The way I do it\n  // might not be optimal.\n  // Note that I parallelize across all voxels (whereas ATen's grid\n  // sampler is only parallelized across batches).\n  //\n  // TODO: check that the default grain size is optimal. We do quite a lot\n  // of compute per voxel, so a smaller value might be better suited.\n  template <typename scalar_t, typename offset_t>\n  MONAI_HOST void PushPullImpl<scalar_t, offset_t>::loop() const {\n#if !(AT_PARALLEL_OPENMP)\n    if (do_push) {\n      // I do not have access to atomic operations so I cannot\n      // parallelize across voxels.\n      at::parallel_for(0, N, 0, [&](offset_t start, offset_t end) {\n        for (offset_t n = start; n < end; ++n) {\n          if (dim == 1) {\n            for (offset_t w = 0; w < trgt_X; ++w)\n              check1d(w, n);\n          } else if (dim == 2) {\n            for (offset_t h = 0; h < trgt_Y; ++h)\n              for (offset_t w = 0; w < trgt_X; ++w)\n                check2d(w, h, n);\n          } else {\n            for (offset_t d = 0; d < trgt_Z; ++d)\n              for (offset_t h = 0; h < trgt_Y; ++h)\n                for (offset_t w = 0; w < trgt_X; ++w)\n                  check3d(w, h, d, n);\n          }\n        }\n      });\n      return;\n    }\n\n#endif\n    // Parallelize across voxels\n    offset_t trgt_NXYZ = trgt_Z * trgt_Y * trgt_X * N;\n    offset_t trgt_XYZ = trgt_Z * trgt_Y * trgt_X;\n    offset_t trgt_YZ = trgt_Z * trgt_Y;\n    at::parallel_for(0, trgt_NXYZ, GRAIN_SIZE, [&](offset_t start, offset_t end) {\n      offset_t n, w, h, d;\n      for (offset_t i = start; i < end; ++i) {\n        // Convert index: linear to sub\n        n = (i / trgt_XYZ);\n        
w = (i / trgt_YZ) % trgt_X;\n        h = (i / trgt_Z) % trgt_Y;\n        d = i % trgt_Z;\n\n        if (dim == 1)\n          check1d(w, n);\n        else if (dim == 2)\n          check2d(w, h, n);\n        else\n          check3d(w, h, d, n);\n      }\n    });\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                        CHECK OUT-OF-BOUND\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  // Here, we:\n  // 1) read the [x,y,z] source coordinate for the current target voxel\n  // 3) check if the source coordinate is in bounds\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::check3d(offset_t w, offset_t h, offset_t d, offset_t n) const {\n    // get the corresponding input x, y, z co-ordinates from grid\n    scalar_t* grid_ptr_NXYZ = grid_ptr + n * grid_sN + w * grid_sX + h * grid_sY + d * grid_sZ;\n    scalar_t x = *grid_ptr_NXYZ;\n    scalar_t y = grid_ptr_NXYZ[grid_sC];\n    scalar_t z = grid_ptr_NXYZ[grid_sC * 2];\n\n    // Check if out-of-bound\n    if (!(extrapolate ||\n          (inbounds(x, src_X, static_cast<scalar_t>(TINY)) && inbounds(y, src_Y, static_cast<scalar_t>(TINY)) &&\n           inbounds(z, src_Z, static_cast<scalar_t>(TINY))))) {\n      if (do_pull || do_sgrad) {\n        scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ;\n        for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC) {\n          *out_ptr_NCXYZ = static_cast<scalar_t>(0);\n          if (do_sgrad) {\n            out_ptr_NCXYZ[out_sK] = static_cast<scalar_t>(0);\n            out_ptr_NCXYZ[out_sK * 2] = static_cast<scalar_t>(0);\n          }\n        }\n      }\n      if (do_grad) {\n        scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY + d * grad_sZ;\n        (*grad_ptr_NXYZ) = static_cast<scalar_t>(0);\n        grad_ptr_NXYZ[grad_sC] = static_cast<scalar_t>(0);\n        
grad_ptr_NXYZ[grad_sC * 2] = static_cast<scalar_t>(0);\n      }\n      return;\n    }\n\n    // Next step\n    if (bound0 == BoundType::Sliding) {\n      if (iso)\n        switch (static_cast<int>(interpolation0)) {\n          case 0:\n            return interpolate3d_sliding_nearest(x, y, z, w, h, d, n);\n          case 1:\n            return interpolate3d_sliding_trilinear(x, y, z, w, h, d, n);\n        }\n      return interpolate3d_sliding(x, y, z, w, h, d, n);\n    } else {\n      if (iso)\n        switch (static_cast<int>(interpolation0)) {\n          case 0:\n            return interpolate3d_nearest(x, y, z, w, h, d, n);\n          case 1:\n            return interpolate3d_trilinear(x, y, z, w, h, d, n);\n        }\n      return interpolate3d(x, y, z, w, h, d, n);\n    }\n  }\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::check2d(offset_t w, offset_t h, offset_t n) const {\n    // get the corresponding input x, y, z co-ordinates from grid\n    scalar_t* grid_ptr_NXY = grid_ptr + n * grid_sN + w * grid_sX + h * grid_sY;\n    scalar_t x = *grid_ptr_NXY;\n    scalar_t y = grid_ptr_NXY[grid_sC];\n\n    // Check if out-of-bound\n    if (!(extrapolate ||\n          (inbounds(x, src_X, static_cast<scalar_t>(TINY)) && inbounds(y, src_Y, static_cast<scalar_t>(TINY))))) {\n      if (do_pull || do_sgrad) {\n        scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY;\n        for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC) {\n          *out_ptr_NCXY = static_cast<scalar_t>(0);\n          if (do_sgrad)\n            out_ptr_NCXY[out_sK] = static_cast<scalar_t>(0);\n        }\n      }\n      if (do_grad) {\n        scalar_t* grad_ptr_NXY = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY;\n        (*grad_ptr_NXY) = static_cast<scalar_t>(0);\n        grad_ptr_NXY[grad_sC] = static_cast<scalar_t>(0);\n      }\n      return;\n    }\n\n    // Next step\n    if (bound0 == 
BoundType::Sliding) {\n      if (iso)\n        switch (static_cast<int>(interpolation0)) {\n          case 0:\n            return interpolate2d_sliding_nearest(x, y, w, h, n);\n          case 1:\n            return interpolate2d_sliding_bilinear(x, y, w, h, n);\n        }\n      return interpolate2d_sliding(x, y, w, h, n);\n    } else {\n      if (iso)\n        switch (static_cast<int>(interpolation0)) {\n          case 0:\n            return interpolate2d_nearest(x, y, w, h, n);\n          case 1:\n            return interpolate2d_bilinear(x, y, w, h, n);\n        }\n      return interpolate2d(x, y, w, h, n);\n    }\n  }\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::check1d(offset_t w, offset_t n) const {\n    // get the corresponding input x, y, z co-ordinates from grid\n    scalar_t* grid_ptr_NX = grid_ptr + n * grid_sN + w * grid_sX;\n    scalar_t x = *grid_ptr_NX;\n\n    // Check if out-of-bound\n    if (!(extrapolate || inbounds(x, src_X, static_cast<scalar_t>(TINY)))) {\n      if (do_pull || do_sgrad) {\n        scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX;\n        for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC) {\n          *out_ptr_NCX = static_cast<scalar_t>(0);\n          if (do_sgrad)\n            out_ptr_NCX[out_sK] = static_cast<scalar_t>(0);\n        }\n      }\n      if (do_grad) {\n        scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX;\n        (*grad_ptr_NX) = static_cast<scalar_t>(0);\n        grad_ptr_NX[grad_sC] = static_cast<scalar_t>(0);\n      }\n      return;\n    }\n\n    // Next step\n    if (bound0 == BoundType::Sliding) {\n      if (iso)\n        switch (static_cast<int>(interpolation0)) {\n          case 0:\n            return interpolate1d_sliding_nearest(x, w, n);\n          case 1:\n            return interpolate1d_sliding_linear(x, w, n);\n        }\n      return interpolate1d_sliding(x, w, n);\n    } else {\n      if (iso)\n        
switch (static_cast<int>(interpolation0)) {\n          case 0:\n            return interpolate1d_nearest(x, w, n);\n          case 1:\n            return interpolate1d_linear(x, w, n);\n        }\n      return interpolate1d(x, w, n);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                     GENERIC INTERPOLATION 3D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate3d(\n      scalar_t x,\n      scalar_t y,\n      scalar_t z,\n      offset_t w,\n      offset_t h,\n      offset_t d,\n      offset_t n) const {\n    // Get corner pixel values from (x, y, z)\n    offset_t bx0, bx1, by0, by1, bz0, bz1;\n    interpolation::bounds(interpolation0, x, bx0, bx1);\n    interpolation::bounds(interpolation1, y, by0, by1);\n    interpolation::bounds(interpolation2, z, bz0, bz1);\n    offset_t dbx = bx1 - bx0;\n    offset_t dby = by1 - by0;\n    offset_t dbz = bz1 - bz0;\n\n    // Pre-compute offsets and target value\n    scalar_t* src_ptr_NC0 = src_ptr + n * src_sN;\n    scalar_t* out_ptr_NC0 = out_ptr + n * out_sN;\n    scalar_t* out_ptr_NCXYZ0 = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ;\n    scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ;\n    scalar_t target[3 * MONAI_MAX_NUM_CHANNELS];\n    if (trgt_ptr && (do_push || do_grad))\n      for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC) {\n        target[c] = *trgt_ptr_NCXYZ;\n        if (trgt_K > 0) {\n          target[c + C] = trgt_ptr_NCXYZ[trgt_sK];\n          target[c + C * 2] = trgt_ptr_NCXYZ[trgt_sK * 2];\n        }\n      }\n\n    // Initialize output\n    scalar_t* out_ptr_NCXYZ = out_ptr_NCXYZ0;\n    if (do_pull || do_sgrad) {\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC) {\n        *out_ptr_NCXYZ = static_cast<scalar_t>(0);\n    
    if (do_sgrad) {\n          out_ptr_NCXYZ[out_sK] = static_cast<scalar_t>(0);\n          out_ptr_NCXYZ[out_sK * 2] = static_cast<scalar_t>(0);\n        }\n      }\n    }\n\n    // Pre-compute indices/weights/grad\n    scalar_t wx[8], wy[8], wz[8]; // B-spline weights\n    scalar_t gx[8], gy[8], gz[8]; // B-spline derivatives\n    scalar_t hx[8], hy[8], hz[8]; // B-spline 2nd derivatives\n    offset_t ix[8], iy[8], iz[8]; // Warped indices\n    uint8_t sx[8], sy[8], sz[8]; // Warped indices\n\n    {\n      scalar_t *owz = static_cast<scalar_t*>(wz), *ogz = static_cast<scalar_t*>(gz), *ohz = static_cast<scalar_t*>(hz);\n      offset_t* oiz = static_cast<offset_t*>(iz);\n      uint8_t* osz = static_cast<uint8_t*>(sz);\n      for (offset_t bz = bz0; bz <= bz1; ++bz) {\n        scalar_t dz = z - bz;\n        *(owz++) = interpolation::fastweight(interpolation2, dz);\n        if (do_grad || do_sgrad)\n          *(ogz++) = interpolation::fastgrad(interpolation2, dz);\n        if (do_grad && trgt_sK > 1)\n          *(ohz++) = interpolation::fasthess(interpolation2, dz);\n        *(osz++) = bound::sign(bound2, bz, src_Z);\n        *(oiz++) = bound::index(bound2, bz, src_Z);\n      }\n    }\n    {\n      scalar_t *owy = static_cast<scalar_t*>(wy), *ogy = static_cast<scalar_t*>(gy), *ohy = static_cast<scalar_t*>(hy);\n      offset_t* oiy = static_cast<offset_t*>(iy);\n      uint8_t* osy = static_cast<uint8_t*>(sy);\n      for (offset_t by = by0; by <= by1; ++by) {\n        scalar_t dy = y - by;\n        *(owy++) = interpolation::fastweight(interpolation1, dy);\n        if (do_grad || do_sgrad)\n          *(ogy++) = interpolation::fastgrad(interpolation1, dy);\n        if (do_grad && trgt_sK > 1)\n          *(ohy++) = interpolation::fasthess(interpolation1, dy);\n        *(osy++) = bound::sign(bound1, by, src_Y);\n        *(oiy++) = bound::index(bound1, by, src_Y);\n      }\n    }\n    {\n      scalar_t *owx = static_cast<scalar_t*>(wx), *ogx = static_cast<scalar_t*>(gx), 
*ohx = static_cast<scalar_t*>(hx);\n      offset_t* oix = static_cast<offset_t*>(ix);\n      uint8_t* osx = static_cast<uint8_t*>(sx);\n      for (offset_t bx = bx0; bx <= bx1; ++bx) {\n        scalar_t dx = x - bx;\n        *(owx++) = interpolation::fastweight(interpolation0, dx);\n        if (do_grad || do_sgrad)\n          *(ogx++) = interpolation::fastgrad(interpolation0, dx);\n        if (do_grad && trgt_sK > 1)\n          *(ohx++) = interpolation::fasthess(interpolation0, dx);\n        *(osx++) = bound::sign(bound0, bx, src_X);\n        *(oix++) = bound::index(bound0, bx, src_X);\n      }\n    }\n\n    // Convolve coefficients with basis functions\n    scalar_t ogx, ogy, ogz;\n    ogx = ogy = ogz = static_cast<scalar_t>(0);\n    for (offset_t k = 0; k <= dbz; ++k) {\n      offset_t ooz = iz[k] * out_sZ;\n      offset_t osz = iz[k] * src_sZ;\n      uint8_t szz = sz[k];\n      scalar_t wzz = wz[k];\n      scalar_t gzz = gz[k];\n      scalar_t hzz = hz[k];\n      for (offset_t j = 0; j <= dby; ++j) {\n        offset_t ooyz = ooz + iy[j] * out_sY;\n        offset_t osyz = osz + iy[j] * src_sY;\n        uint8_t syz = szz * sy[j];\n        scalar_t wyy = wy[j];\n        scalar_t gyy = gy[j];\n        scalar_t hyy = hy[j];\n        for (offset_t i = 0; i <= dbx; ++i) {\n          offset_t ooxyz = ooyz + ix[i] * out_sX;\n          offset_t osxyz = osyz + ix[i] * src_sX;\n          uint8_t sxyz = syz * sx[i];\n          scalar_t wxx = wx[i];\n          scalar_t gxx = gx[i];\n          scalar_t hxx = hx[i];\n\n          // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n          if (do_pull) {\n            scalar_t* src_ptr_NC = src_ptr_NC0;\n            scalar_t* out_ptr_NCXYZ = out_ptr_NCXYZ0;\n            for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC)\n              *out_ptr_NCXYZ += bound::get(src_ptr_NC, osxyz, sxyz) * (wxx * wyy * wzz);\n          }\n\n          // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad 
~~~~~~~~~~~~~~~~~~~~~~~~~~~\n          else if (do_sgrad) {\n            scalar_t* src_ptr_NC = src_ptr_NC0;\n            scalar_t* out_ptr_NCXYZ = out_ptr_NCXYZ0;\n            for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) {\n              scalar_t src = bound::get(src_ptr_NC, osxyz, sxyz);\n              *out_ptr_NCXYZ += src * (gxx * wyy * wzz);\n              out_ptr_NCXYZ[out_sK] += src * (wxx * gyy * wzz);\n              out_ptr_NCXYZ[2 * out_sK] += src * (wxx * wyy * gzz);\n            }\n          }\n\n          // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n          else if (do_push) {\n            if (trgt_K == 0) {\n              // Diff w.r.t. push/pull\n              scalar_t* out_ptr_NC = out_ptr_NC0;\n              for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)\n                bound::add(out_ptr_NC, ooxyz, (wxx * wyy * wzz) * target[c], sxyz);\n            } else {\n              // Diff w.r.t. sgrad\n              scalar_t* out_ptr_NC = out_ptr_NC0;\n              for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) {\n                scalar_t val = (gxx * wyy * wzz) * target[c] + (wxx * gyy * wzz) * target[c + C] +\n                    (wxx * wyy * gzz) * target[c + C * 2];\n                bound::add(out_ptr_NC, ooxyz, val, sxyz);\n              }\n            }\n          }\n\n          // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Count ~~~~~~~~~~~~~~~~~~~~~~~~~~~\n          else if (do_count) {\n            bound::add(out_ptr_NC0, ooxyz, (wxx * wyy * wzz), sxyz);\n          }\n\n          // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n          if (do_grad) {\n            if (trgt_K == 0) {\n              // Diff w.r.t. 
pull/push\n              scalar_t* src_ptr_NC = src_ptr_NC0;\n              scalar_t dot = static_cast<scalar_t>(0);\n              for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {\n                scalar_t src = bound::get(src_ptr_NC, osxyz, sxyz);\n                dot += (trgt_ptr ? src * target[c] : src);\n                // trgt_ptr == 0 in the backward pass of 'count'\n              }\n              ogx += (gxx * wyy * wzz) * dot;\n              ogy += (wxx * gyy * wzz) * dot;\n              ogz += (wxx * wyy * gzz) * dot;\n            } else {\n              // Diff w.r.t. sgrad\n              scalar_t* src_ptr_NC = src_ptr_NC0;\n              scalar_t dot0, dot1, dot2;\n              dot0 = dot1 = dot2 = static_cast<scalar_t>(0);\n              for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {\n                scalar_t src = bound::get(src_ptr_NC, osxyz, sxyz);\n                dot0 += src * target[c];\n                dot1 += src * target[c + C];\n                dot2 += src * target[c + C * 2];\n              }\n              ogx += (hxx * wyy * wzz) * dot0 + (gxx * gyy * wzz) * dot1 + (gxx * wyy * gzz) * dot2;\n              ogy += (gxx * gyy * wzz) * dot0 + (wxx * hyy * wzz) * dot1 + (wxx * gyy * gzz) * dot2;\n              ogz += (gxx * wyy * gzz) * dot0 + (wxx * gyy * gzz) * dot1 + (wxx * wyy * hzz) * dot2;\n            }\n          }\n\n        } // x\n      } // y\n    } // z\n\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_grad) {\n      scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY + d * grad_sZ;\n      (*grad_ptr_NXYZ) = ogx;\n      grad_ptr_NXYZ[grad_sC] = ogy;\n      grad_ptr_NXYZ[grad_sC * 2] = ogz;\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                     GENERIC INTERPOLATION 2D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename 
offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate2d(\n      scalar_t x,\n      scalar_t y,\n      offset_t w,\n      offset_t h,\n      offset_t n) const {\n    // Get corner pixel values from (x, y)\n    offset_t bx0, bx1, by0, by1;\n    interpolation::bounds(interpolation0, x, bx0, bx1);\n    interpolation::bounds(interpolation1, y, by0, by1);\n    offset_t dbx = bx1 - bx0;\n    offset_t dby = by1 - by0;\n\n    // Pre-compute offsets and target value\n    scalar_t* src_ptr_NC0 = src_ptr + n * src_sN;\n    scalar_t* out_ptr_NC0 = out_ptr + n * out_sN;\n    scalar_t* out_ptr_NCXY0 = out_ptr + n * out_sN + w * out_sX + h * out_sY;\n    scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY;\n    scalar_t target[2 * MONAI_MAX_NUM_CHANNELS];\n    if (trgt_ptr && (do_push || do_grad))\n      for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC) {\n        target[c] = *trgt_ptr_NCXY;\n        if (trgt_K > 0) {\n          target[c + C] = trgt_ptr_NCXY[trgt_sK];\n        }\n      }\n\n    // Initialize output\n    scalar_t* out_ptr_NCXY = out_ptr_NCXY0;\n    if (do_pull || do_sgrad) {\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC) {\n        *out_ptr_NCXY = static_cast<scalar_t>(0);\n        if (do_sgrad) {\n          out_ptr_NCXY[out_sK] = static_cast<scalar_t>(0);\n        }\n      }\n    }\n\n    // Pre-compute indices/weights/grad\n    scalar_t wx[8], wy[8]; // B-spline weights\n    scalar_t gx[8], gy[8]; // B-spline derivatives\n    scalar_t hx[8], hy[8]; // B-spline 2nd derivatives\n    offset_t ix[8], iy[8]; // Warped indices\n    uint8_t sx[8], sy[8]; // Warped indices\n\n    {\n      scalar_t *owy = static_cast<scalar_t*>(wy), *ogy = static_cast<scalar_t*>(gy), *ohy = static_cast<scalar_t*>(hy);\n      offset_t* oiy = static_cast<offset_t*>(iy);\n      uint8_t* osy = static_cast<uint8_t*>(sy);\n      for (offset_t by = by0; by <= by1; ++by) {\n        scalar_t dy = y - by;\n        *(owy++) = 
interpolation::fastweight(interpolation1, dy);\n        if (do_grad || do_sgrad)\n          *(ogy++) = interpolation::fastgrad(interpolation1, dy);\n        if (do_grad && trgt_sK > 1)\n          *(ohy++) = interpolation::fasthess(interpolation1, dy);\n        *(osy++) = bound::sign(bound1, by, src_Y);\n        *(oiy++) = bound::index(bound1, by, src_Y);\n      }\n    }\n    {\n      scalar_t *owx = static_cast<scalar_t*>(wx), *ogx = static_cast<scalar_t*>(gx), *ohx = static_cast<scalar_t*>(hx);\n      offset_t* oix = static_cast<offset_t*>(ix);\n      uint8_t* osx = static_cast<uint8_t*>(sx);\n      for (offset_t bx = bx0; bx <= bx1; ++bx) {\n        scalar_t dx = x - bx;\n        *(owx++) = interpolation::fastweight(interpolation0, dx);\n        if (do_grad || do_sgrad)\n          *(ogx++) = interpolation::fastgrad(interpolation0, dx);\n        if (do_grad && trgt_sK > 1)\n          *(ohx++) = interpolation::fasthess(interpolation0, dx);\n        *(osx++) = bound::sign(bound0, bx, src_X);\n        *(oix++) = bound::index(bound0, bx, src_X);\n      }\n    }\n\n    // Convolve coefficients with basis functions\n    scalar_t ogx, ogy;\n    ogx = ogy = static_cast<scalar_t>(0);\n    for (offset_t j = 0; j <= dby; ++j) {\n      offset_t ooy = iy[j] * out_sY;\n      offset_t osy = iy[j] * src_sY;\n      uint8_t syy = sy[j];\n      scalar_t wyy = wy[j];\n      scalar_t gyy = gy[j];\n      scalar_t hyy = hy[j];\n      for (offset_t i = 0; i <= dbx; ++i) {\n        offset_t ooxy = ooy + ix[i] * out_sX;\n        offset_t osxy = osy + ix[i] * src_sX;\n        uint8_t sxy = syy * sx[i];\n        scalar_t wxx = wx[i];\n        scalar_t gxx = gx[i];\n        scalar_t hxx = hx[i];\n\n        // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n        if (do_pull) {\n          scalar_t* src_ptr_NC = src_ptr_NC0;\n          scalar_t* out_ptr_NCXY = out_ptr_NCXY0;\n          for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC)\n            
*out_ptr_NCXY += bound::get(src_ptr_NC, osxy, sxy) * (wxx * wyy);\n        }\n\n        // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n        else if (do_sgrad) {\n          scalar_t* src_ptr_NC = src_ptr_NC0;\n          scalar_t* out_ptr_NCXY = out_ptr_NCXY0;\n          for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) {\n            scalar_t src = bound::get(src_ptr_NC, osxy, sxy);\n            *out_ptr_NCXY += src * (gxx * wyy);\n            out_ptr_NCXY[out_sK] += src * (wxx * gyy);\n          }\n        }\n\n        // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n        else if (do_push) {\n          if (trgt_K == 0) {\n            // Diff w.r.t. push/pull\n            scalar_t* out_ptr_NC = out_ptr_NC0;\n            for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)\n              bound::add(out_ptr_NC, ooxy, (wxx * wyy) * target[c], sxy);\n          } else {\n            // Diff w.r.t. sgrad\n            scalar_t* out_ptr_NC = out_ptr_NC0;\n            for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) {\n              scalar_t val = (gxx * wyy) * target[c] + (wxx * gyy) * target[c + C];\n              bound::add(out_ptr_NC, ooxy, val, sxy);\n            }\n          }\n        }\n\n        // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n        else if (do_count) {\n          bound::add(out_ptr_NC0, ooxy, (wxx * wyy), sxy);\n        }\n\n        // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n        if (do_grad) {\n          if (trgt_K == 0) {\n            // Diff w.r.t. pull/push\n            scalar_t* src_ptr_NC = src_ptr_NC0;\n            scalar_t dot = static_cast<scalar_t>(0);\n            for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {\n              scalar_t src = bound::get(src_ptr_NC, osxy, sxy);\n              dot += (trgt_ptr ? 
src * target[c] : src);\n              // trgt_ptr == 0 in the backward pass of 'count'\n            }\n            ogx += (gxx * wyy) * dot;\n            ogy += (wxx * gyy) * dot;\n          } else {\n            // Diff w.r.t. sgrad\n            scalar_t* src_ptr_NC = src_ptr_NC0;\n            scalar_t dot0, dot1;\n            dot0 = dot1 = static_cast<scalar_t>(0);\n            for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {\n              scalar_t src = bound::get(src_ptr_NC, osxy, sxy);\n              dot0 += src * target[c];\n              dot1 += src * target[c + C];\n            }\n            ogx += (hxx * wyy) * dot0 + (gxx * gyy) * dot1;\n            ogy += (gxx * gyy) * dot0 + (wxx * hyy) * dot1;\n          }\n        }\n\n      } // x\n    } // y\n\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_grad) {\n      scalar_t* grad_ptr_NXY = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY;\n      (*grad_ptr_NXY) = ogx;\n      grad_ptr_NXY[grad_sC] = ogy;\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                     GENERIC INTERPOLATION 1D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate1d(scalar_t x, offset_t w, offset_t n) const {\n    // Get corner pixel values from (x)\n    offset_t bx0, bx1;\n    interpolation::bounds(interpolation0, x, bx0, bx1);\n    offset_t dbx = bx1 - bx0;\n\n    // Pre-compute offsets and target value\n    scalar_t* src_ptr_NC0 = src_ptr + n * src_sN;\n    scalar_t* out_ptr_NC0 = out_ptr + n * out_sN;\n    scalar_t* out_ptr_NCX0 = out_ptr + n * out_sN + w * out_sX;\n    scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX;\n    scalar_t target[2 * MONAI_MAX_NUM_CHANNELS];\n    if (trgt_ptr && (do_push || do_grad))\n      for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += 
trgt_sC) {\n        target[c] = *trgt_ptr_NCX;\n        if (trgt_K > 0) {\n          target[c + C] = trgt_ptr_NCX[trgt_sK];\n        }\n      }\n\n    // Initialize output\n    scalar_t* out_ptr_NCX = out_ptr_NCX0;\n    if (do_pull || do_sgrad) {\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC) {\n        *out_ptr_NCX = static_cast<scalar_t>(0);\n        if (do_sgrad) {\n          out_ptr_NCX[out_sK] = static_cast<scalar_t>(0);\n        }\n      }\n    }\n\n    // Pre-compute indices/weights/grad\n    scalar_t wx[8]; // B-spline weights\n    scalar_t gx[8]; // B-spline derivatives\n    scalar_t hx[8]; // B-spline 2nd derivatives\n    offset_t ix[8]; // Warped indices\n    uint8_t sx[8]; // Warped indices\n\n    {\n      scalar_t *owx = static_cast<scalar_t*>(wx), *ogx = static_cast<scalar_t*>(gx), *ohx = static_cast<scalar_t*>(hx);\n      offset_t* oix = static_cast<offset_t*>(ix);\n      uint8_t* osx = static_cast<uint8_t*>(sx);\n      for (offset_t bx = bx0; bx <= bx1; ++bx) {\n        scalar_t dx = x - bx;\n        *(owx++) = interpolation::fastweight(interpolation0, dx);\n        if (do_grad || do_sgrad)\n          *(ogx++) = interpolation::fastgrad(interpolation0, dx);\n        if (do_grad && trgt_sK > 1)\n          *(ohx++) = interpolation::fasthess(interpolation0, dx);\n        *(osx++) = bound::sign(bound0, bx, src_X);\n        *(oix++) = bound::index(bound0, bx, src_X);\n      }\n    }\n\n    // Convolve coefficients with basis functions\n    scalar_t ogx;\n    ogx = static_cast<scalar_t>(0);\n    for (offset_t i = 0; i <= dbx; ++i) {\n      offset_t oox = ix[i] * out_sX;\n      offset_t osx = ix[i] * src_sX;\n      uint8_t sxx = sx[i];\n      scalar_t wxx = wx[i];\n      scalar_t gxx = gx[i];\n      scalar_t hxx = hx[i];\n\n      // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n      if (do_pull) {\n        scalar_t* src_ptr_NC = src_ptr_NC0;\n        scalar_t* out_ptr_NCX = out_ptr_NCX0;\n        for (offset_t c = 0; c < 
C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC)\n          *out_ptr_NCX += bound::get(src_ptr_NC, osx, sxx) * wxx;\n      }\n\n      // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n      else if (do_sgrad) {\n        scalar_t* src_ptr_NC = src_ptr_NC0;\n        scalar_t* out_ptr_NCX = out_ptr_NCX0;\n        for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) {\n          scalar_t src = bound::get(src_ptr_NC, osx, sxx);\n          *out_ptr_NCX += src * gxx;\n        }\n      }\n\n      // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n      else if (do_push) {\n        if (trgt_K == 0) {\n          // Diff w.r.t. push/pull\n          scalar_t* out_ptr_NC = out_ptr_NC0;\n          for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)\n            bound::add(out_ptr_NC, oox, wxx * target[c], sxx);\n        } else {\n          // Diff w.r.t. sgrad\n          scalar_t* out_ptr_NC = out_ptr_NC0;\n          for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) {\n            scalar_t val = gxx * target[c];\n            bound::add(out_ptr_NC, oox, val, sxx);\n          }\n        }\n      }\n\n      // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n      else if (do_count) {\n        bound::add(out_ptr_NC0, oox, wxx, sxx);\n      }\n\n      // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n      if (do_grad) {\n        if (trgt_K == 0) {\n          // Diff w.r.t. pull/push\n          scalar_t* src_ptr_NC = src_ptr_NC0;\n          scalar_t dot = static_cast<scalar_t>(0);\n          for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {\n            scalar_t src = bound::get(src_ptr_NC, osx, sxx);\n            dot += (trgt_ptr ? src * target[c] : src);\n            // trgt_ptr == 0 in the backward pass of 'count'\n          }\n          ogx += gxx * dot;\n        } else {\n          // Diff w.r.t. 
sgrad\n          scalar_t* src_ptr_NC = src_ptr_NC0;\n          scalar_t dot;\n          dot = static_cast<scalar_t>(0);\n          for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {\n            scalar_t src = bound::get(src_ptr_NC, osx, sxx);\n            dot += src * target[c];\n          }\n          ogx += hxx * dot;\n        }\n      }\n\n    } // x\n\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_grad) {\n      scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX;\n      (*grad_ptr_NX) = ogx;\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                     LINEAR INTERPOLATION 3D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate3d_trilinear(\n      scalar_t x,\n      scalar_t y,\n      scalar_t z,\n      offset_t w,\n      offset_t h,\n      offset_t d,\n      offset_t n) const {\n    // Get corner pixel values from (x, y, z)\n    offset_t ix0 = static_cast<offset_t>(std::floor(x));\n    offset_t iy0 = static_cast<offset_t>(std::floor(y));\n    offset_t iz0 = static_cast<offset_t>(std::floor(z));\n\n    // Interpolation weights (inversely proportional to distance)\n    scalar_t dx1 = x - ix0;\n    scalar_t dy1 = y - iy0;\n    scalar_t dz1 = z - iz0;\n    scalar_t dx0 = 1. - dx1;\n    scalar_t dy0 = 1. - dy1;\n    scalar_t dz0 = 1. 
- dz1;\n    scalar_t w000 = dx0 * dy0 * dz0;\n    scalar_t w100 = dx1 * dy0 * dz0;\n    scalar_t w010 = dx0 * dy1 * dz0;\n    scalar_t w001 = dx0 * dy0 * dz1;\n    scalar_t w110 = dx1 * dy1 * dz0;\n    scalar_t w011 = dx0 * dy1 * dz1;\n    scalar_t w101 = dx1 * dy0 * dz1;\n    scalar_t w111 = dx1 * dy1 * dz1;\n\n    // Sign (/!\\ compute sign before warping indices)\n    int8_t sx1 = bound::sign(bound0, ix0 + 1, src_X);\n    int8_t sy1 = bound::sign(bound1, iy0 + 1, src_Y);\n    int8_t sz1 = bound::sign(bound2, iz0 + 1, src_Z);\n    int8_t sx0 = bound::sign(bound0, ix0, src_X);\n    int8_t sy0 = bound::sign(bound1, iy0, src_Y);\n    int8_t sz0 = bound::sign(bound2, iz0, src_Z);\n    int8_t s000 = sx0 * sy0 * sz0;\n    int8_t s100 = sx1 * sy0 * sz0;\n    int8_t s010 = sx0 * sy1 * sz0;\n    int8_t s001 = sx0 * sy0 * sz1;\n    int8_t s110 = sx1 * sy1 * sz0;\n    int8_t s011 = sx0 * sy1 * sz1;\n    int8_t s101 = sx1 * sy0 * sz1;\n    int8_t s111 = sx1 * sy1 * sz1;\n\n    // Warp indices\n    offset_t ix1, iy1, iz1;\n    ix1 = bound::index(bound0, ix0 + 1, src_X);\n    iy1 = bound::index(bound1, iy0 + 1, src_Y);\n    iz1 = bound::index(bound2, iz0 + 1, src_Z);\n    ix0 = bound::index(bound0, ix0, src_X);\n    iy0 = bound::index(bound1, iy0, src_Y);\n    iz0 = bound::index(bound2, iz0, src_Z);\n\n    offset_t o000, o100, o010, o001, o110, o011, o101, o111;\n\n    if (do_pull || do_grad || do_sgrad) {\n      // Offsets into source volume\n      o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ;\n      o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ;\n      o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ;\n      o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ;\n      o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ;\n      o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ;\n      o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ;\n      o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ;\n    } else if (!(do_push || do_count)) {\n      o000 = o100 = o010 = o001 = 
o110 = o011 = o101 = o111 = 0;\n    }\n\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_grad) {\n      scalar_t gx = static_cast<scalar_t>(0);\n      scalar_t gy = static_cast<scalar_t>(0);\n      scalar_t gz = static_cast<scalar_t>(0);\n      scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n\n      if (trgt_K == 0) {\n        // backward w.r.t. push/pull\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, src_ptr_NC += src_sC) {\n          scalar_t src;\n          scalar_t trgt = trgt_ptr ? *trgt_ptr_NCXYZ : static_cast<scalar_t>(1);\n          // ^ trgt_ptr == 0 during the backward pass of count\n          src = bound::get(src_ptr_NC, o000, s000);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= dy0 * dz0 * src;\n          gy -= dx0 * dz0 * src;\n          gz -= dx0 * dy0 * src;\n          src = bound::get(src_ptr_NC, o100, s100);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += dy0 * dz0 * src;\n          gy -= dx1 * dz0 * src;\n          gz -= dx1 * dy0 * src;\n          src = bound::get(src_ptr_NC, o010, s010);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= dy1 * dz0 * src;\n          gy += dx0 * dz0 * src;\n          gz -= dx0 * dy1 * src;\n          src = bound::get(src_ptr_NC, o110, s110);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += dy1 * dz0 * src;\n          gy += dx1 * dz0 * src;\n          gz -= dx1 * dy1 * src;\n          src = bound::get(src_ptr_NC, o001, s001);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= dy0 * dz1 * src;\n          gy -= dx0 * dz1 * src;\n          gz += dx0 * dy0 * src;\n          src = bound::get(src_ptr_NC, o101, s101);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += dy0 * dz1 * src;\n          gy -= dx1 * dz1 * src;\n          gz += dx1 * dy0 * src;\n    
      src = bound::get(src_ptr_NC, o011, s011);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= dy1 * dz1 * src;\n          gy += dx0 * dz1 * src;\n          gz += dx0 * dy1 * src;\n          src = bound::get(src_ptr_NC, o111, s111);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += dy1 * dz1 * src;\n          gy += dx1 * dz1 * src;\n          gz += dx1 * dy1 * src;\n        }\n      } else {\n        // backward w.r.t. sgrad\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, src_ptr_NC += src_sC) {\n          scalar_t src;\n          scalar_t trgt0 = *trgt_ptr_NCXYZ, trgt1 = trgt_ptr_NCXYZ[trgt_sK], trgt2 = trgt_ptr_NCXYZ[trgt_sK * 2];\n          src = bound::get(src_ptr_NC, o000, s000);\n          gx += (dz0 * trgt1 + dy0 * trgt2) * src;\n          gy += (dz0 * trgt0 + dx0 * trgt2) * src;\n          gz += (dy0 * trgt0 + dx0 * trgt1) * src;\n          src = bound::get(src_ptr_NC, o100, s100);\n          gx += (-dz0 * trgt1 - dy0 * trgt2) * src;\n          gy += (-dz0 * trgt0 + dx1 * trgt2) * src;\n          gz += (-dy0 * trgt0 + dx1 * trgt1) * src;\n          src = bound::get(src_ptr_NC, o010, s010);\n          gx += (-dz0 * trgt1 + dy1 * trgt2) * src;\n          gy += (-dz0 * trgt0 - dx0 * trgt2) * src;\n          gz += (dy1 * trgt0 - dx0 * trgt1) * src;\n          src = bound::get(src_ptr_NC, o110, s110);\n          gx += (dz0 * trgt1 - dy1 * trgt2) * src;\n          gy += (dz0 * trgt0 - dx1 * trgt2) * src;\n          gz += (-dy1 * trgt0 - dx1 * trgt1) * src;\n          src = bound::get(src_ptr_NC, o001, s001);\n          gx += (dz1 * trgt1 - dy0 * trgt2) * src;\n          gy += (dz1 * trgt0 - dx0 * trgt2) * src;\n          gz += (-dy0 * trgt0 - dx0 * trgt1) * src;\n          src = bound::get(src_ptr_NC, o101, s101);\n          gx += (-dz1 * trgt1 + dy0 * trgt2) * src;\n          gy += (-dz1 * trgt0 - dx1 * trgt2) * src;\n          gz += (dy0 * trgt0 - dx1 * trgt1) * src;\n          src = 
bound::get(src_ptr_NC, o011, s011);\n          gx += (-dz1 * trgt1 - dy1 * trgt2) * src;\n          gy += (-dz1 * trgt0 + dx0 * trgt2) * src;\n          gz += (-dy1 * trgt0 + dx0 * trgt1) * src;\n          src = bound::get(src_ptr_NC, o111, s111);\n          gx += (dz1 * trgt1 + dy1 * trgt2) * src;\n          gy += (dz1 * trgt0 + dx1 * trgt2) * src;\n          gz += (dy1 * trgt0 + dx1 * trgt1) * src;\n        }\n      }\n\n      scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY + d * grad_sZ;\n      (*grad_ptr_NXYZ) = gx;\n      grad_ptr_NXYZ[grad_sC] = gy;\n      grad_ptr_NXYZ[grad_sC * 2] = gz;\n    }\n    if (do_push || do_count) {\n      // Offsets into 'push' volume\n      o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ;\n      o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ;\n      o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ;\n      o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ;\n      o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ;\n      o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ;\n      o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ;\n      o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ;\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_pull) {\n      scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) {\n        *out_ptr_NCXYZ = bound::get(src_ptr_NC, o000, s000) * w000 + bound::get(src_ptr_NC, o100, s100) * w100 +\n            bound::get(src_ptr_NC, o010, s010) * w010 + bound::get(src_ptr_NC, o110, s110) * w110 +\n            bound::get(src_ptr_NC, o001, s001) * w001 + bound::get(src_ptr_NC, o101, s101) * w101 +\n            bound::get(src_ptr_NC, o011, s011) * w011 + bound::get(src_ptr_NC, o111, s111) * w111;\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad 
~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~\n    else if (do_sgrad) {\n      scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) {\n        scalar_t src000 = bound::get(src_ptr_NC, o000, s000);\n        scalar_t src100 = bound::get(src_ptr_NC, o100, s100);\n        scalar_t src010 = bound::get(src_ptr_NC, o010, s010);\n        scalar_t src110 = bound::get(src_ptr_NC, o110, s110);\n        scalar_t src001 = bound::get(src_ptr_NC, o001, s001);\n        scalar_t src101 = bound::get(src_ptr_NC, o101, s101);\n        scalar_t src011 = bound::get(src_ptr_NC, o011, s011);\n        scalar_t src111 = bound::get(src_ptr_NC, o111, s111);\n        *out_ptr_NCXYZ = -dy0 * dz0 * src000 + dy0 * dz0 * src100 - dy1 * dz0 * src010 + dy1 * dz0 * src110 -\n            dy0 * dz1 * src001 + dy0 * dz1 * src101 - dy1 * dz1 * src011 + dy1 * dz1 * src111;\n        out_ptr_NCXYZ[out_sK] = -dx0 * dz0 * src000 - dx1 * dz0 * src100 + dx0 * dz0 * src010 + dx1 * dz0 * src110 -\n            dx0 * dz1 * src001 - dx1 * dz1 * src101 + dx0 * dz1 * src011 + dx1 * dz1 * src111;\n        out_ptr_NCXYZ[out_sK * 2] = -dx0 * dy0 * src000 - dx1 * dy0 * src100 - dx0 * dy1 * src010 - dx1 * dy1 * src110 +\n            dx0 * dy0 * src001 + dx1 * dy0 * src101 + dx0 * dy1 * src011 + dx1 * dy1 * src111;\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_push) {\n      // Offsets into 'push' volume\n      o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ;\n      o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ;\n      o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ;\n      o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ;\n      o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ;\n      o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ;\n      o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ;\n 
     o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ;\n      scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      if (trgt_K == 0) {\n        // Diff w.r.t. push/pull\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, out_ptr_NC += out_sC) {\n          scalar_t trgt = *trgt_ptr_NCXYZ;\n          bound::add(out_ptr_NC, o000, w000 * trgt, s000);\n          bound::add(out_ptr_NC, o100, w100 * trgt, s100);\n          bound::add(out_ptr_NC, o010, w010 * trgt, s010);\n          bound::add(out_ptr_NC, o110, w110 * trgt, s110);\n          bound::add(out_ptr_NC, o001, w001 * trgt, s001);\n          bound::add(out_ptr_NC, o101, w101 * trgt, s101);\n          bound::add(out_ptr_NC, o011, w011 * trgt, s011);\n          bound::add(out_ptr_NC, o111, w111 * trgt, s111);\n        }\n      } else {\n        // Diff w.r.t. sgrad\n        scalar_t val;\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, out_ptr_NC += out_sC) {\n          scalar_t trgt0 = *trgt_ptr_NCXYZ, trgt1 = trgt_ptr_NCXYZ[trgt_sK], trgt2 = trgt_ptr_NCXYZ[trgt_sK * 2];\n          val = -dy0 * dz0 * trgt0 - dx0 * dz0 * trgt1 - dx0 * dy0 * trgt2;\n          bound::add(out_ptr_NC, o000, val, s000);\n          val = dy0 * dz0 * trgt0 - dx1 * dz0 * trgt1 - dx1 * dy0 * trgt2;\n          bound::add(out_ptr_NC, o100, val, s100);\n          val = -dy1 * dz0 * trgt0 + dx0 * dz0 * trgt1 - dx0 * dy1 * trgt2;\n          bound::add(out_ptr_NC, o010, val, s010);\n          val = dy1 * dz0 * trgt0 + dx1 * dz0 * trgt1 - dx1 * dy1 * trgt2;\n          bound::add(out_ptr_NC, o110, val, s110);\n          val = -dy0 * dz1 * trgt0 - dx0 * dz1 * trgt1 + dx0 * dy0 * trgt2;\n          bound::add(out_ptr_NC, o001, val, s001);\n          val = dy0 * dz1 * trgt0 - dx1 * dz1 * trgt1 + dx1 * dy0 * trgt2;\n          bound::add(out_ptr_NC, o101, val, s101);\n          val = -dy1 * dz1 * trgt0 + dx0 * dz1 * trgt1 
+ dx0 * dy1 * trgt2;\n          bound::add(out_ptr_NC, o011, val, s011);\n          val = dy1 * dz1 * trgt0 + dx1 * dz1 * trgt1 + dx1 * dy1 * trgt2;\n          bound::add(out_ptr_NC, o111, val, s111);\n        }\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_count) {\n      scalar_t* out_ptr_N = out_ptr + n * out_sN;\n      bound::add(out_ptr_N, o000, w000, s000);\n      bound::add(out_ptr_N, o100, w100, s100);\n      bound::add(out_ptr_N, o010, w010, s010);\n      bound::add(out_ptr_N, o110, w110, s110);\n      bound::add(out_ptr_N, o001, w001, s001);\n      bound::add(out_ptr_N, o101, w101, s101);\n      bound::add(out_ptr_N, o011, w011, s011);\n      bound::add(out_ptr_N, o111, w111, s111);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                     LINEAR INTERPOLATION 2D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate2d_bilinear(\n      scalar_t x,\n      scalar_t y,\n      offset_t w,\n      offset_t h,\n      offset_t n) const {\n    // Get corner pixel values from (x, y, z)\n    offset_t ix0 = static_cast<offset_t>(std::floor(x));\n    offset_t iy0 = static_cast<offset_t>(std::floor(y));\n\n    // Interpolation weights (inversely proportional to distance)\n    scalar_t dx1 = x - ix0;\n    scalar_t dy1 = y - iy0;\n    scalar_t dx0 = 1. - dx1;\n    scalar_t dy0 = 1. 
- dy1;\n    scalar_t w00 = dx0 * dy0;\n    scalar_t w10 = dx1 * dy0;\n    scalar_t w01 = dx0 * dy1;\n    scalar_t w11 = dx1 * dy1;\n\n    // Sign (/!\\ compute sign before warping indices)\n    int8_t sx1 = bound::sign(bound0, ix0 + 1, src_X);\n    int8_t sy1 = bound::sign(bound1, iy0 + 1, src_Y);\n    int8_t sx0 = bound::sign(bound0, ix0, src_X);\n    int8_t sy0 = bound::sign(bound1, iy0, src_Y);\n    int8_t s00 = sx0 * sy0;\n    int8_t s10 = sx1 * sy0;\n    int8_t s01 = sx0 * sy1;\n    int8_t s11 = sx1 * sy1;\n\n    // Warp indices\n    offset_t ix1, iy1;\n    ix1 = bound::index(bound0, ix0 + 1, src_X);\n    iy1 = bound::index(bound1, iy0 + 1, src_Y);\n    ix0 = bound::index(bound0, ix0, src_X);\n    iy0 = bound::index(bound1, iy0, src_Y);\n\n    offset_t o00, o10, o01, o11;\n    if (do_pull || do_grad || do_sgrad) {\n      // Offsets into source volume\n      o00 = ix0 * src_sX + iy0 * src_sY;\n      o10 = ix1 * src_sX + iy0 * src_sY;\n      o01 = ix0 * src_sX + iy1 * src_sY;\n      o11 = ix1 * src_sX + iy1 * src_sY;\n    } else if (!(do_push || do_count)) {\n      o00 = o10 = o01 = o11 = 0;\n    }\n\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_grad) {\n      scalar_t gx = static_cast<scalar_t>(0);\n      scalar_t gy = static_cast<scalar_t>(0);\n      scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n\n      if (trgt_K == 0) {\n        // backward w.r.t. push/pull\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, src_ptr_NC += src_sC) {\n          scalar_t src;\n          scalar_t trgt = trgt_ptr ? 
*trgt_ptr_NCXY : static_cast<scalar_t>(1);\n          // ^ trgt_ptr == 0 during the backward pass of count\n          src = bound::get(src_ptr_NC, o00, s00);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= dy0 * src;\n          gy -= dx0 * src;\n          src = bound::get(src_ptr_NC, o10, s10);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += dy0 * src;\n          gy -= dx1 * src;\n          src = bound::get(src_ptr_NC, o01, s01);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= dy1 * src;\n          gy += dx0 * src;\n          src = bound::get(src_ptr_NC, o11, s11);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += dy1 * src;\n          gy += dx1 * src;\n        }\n      } else {\n        // backward w.r.t. sgrad\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, src_ptr_NC += src_sC) {\n          scalar_t src;\n          scalar_t trgt0 = *trgt_ptr_NCXY, trgt1 = trgt_ptr_NCXY[trgt_sK];\n          src = bound::get(src_ptr_NC, o00, s00);\n          gx += trgt1 * src;\n          gy += trgt0 * src;\n          src = bound::get(src_ptr_NC, o10, s10);\n          gx -= trgt1 * src;\n          gy -= trgt0 * src;\n          src = bound::get(src_ptr_NC, o01, s01);\n          gx -= trgt1 * src;\n          gy -= trgt0 * src;\n          src = bound::get(src_ptr_NC, o11, s11);\n          gx += trgt1 * src;\n          gy += trgt0 * src;\n        }\n      }\n\n      scalar_t* grad_ptr_NXY = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY;\n      (*grad_ptr_NXY) = gx;\n      grad_ptr_NXY[grad_sC] = gy;\n    }\n    if (do_push || do_count) {\n      // Offsets into 'push' volume\n      o00 = ix0 * out_sX + iy0 * out_sY;\n      o10 = ix1 * out_sX + iy0 * out_sY;\n      o01 = ix0 * out_sX + iy1 * out_sY;\n      o11 = ix1 * out_sX + iy1 * out_sY;\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_pull) {\n      scalar_t* out_ptr_NCXY = out_ptr + n * 
out_sN + w * out_sX + h * out_sY;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) {\n        *out_ptr_NCXY = bound::get(src_ptr_NC, o00, s00) * w00 + bound::get(src_ptr_NC, o10, s10) * w10 +\n            bound::get(src_ptr_NC, o01, s01) * w01 + bound::get(src_ptr_NC, o11, s11) * w11;\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_sgrad) {\n      scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) {\n        scalar_t src00 = bound::get(src_ptr_NC, o00, s00);\n        scalar_t src10 = bound::get(src_ptr_NC, o10, s10);\n        scalar_t src01 = bound::get(src_ptr_NC, o01, s01);\n        scalar_t src11 = bound::get(src_ptr_NC, o11, s11);\n        *out_ptr_NCXY = -dy0 * src00 + dy0 * src10 - dy1 * src01 + dy1 * src11;\n        out_ptr_NCXY[out_sK] = -dx0 * src00 - dx1 * src10 + dx0 * src01 + dx1 * src11;\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_push) {\n      scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      if (trgt_K == 0) {\n        // Diff w.r.t. push/pull\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, out_ptr_NC += out_sC) {\n          scalar_t trgt = *trgt_ptr_NCXY;\n          bound::add(out_ptr_NC, o00, w00 * trgt, s00);\n          bound::add(out_ptr_NC, o10, w10 * trgt, s10);\n          bound::add(out_ptr_NC, o01, w01 * trgt, s01);\n          bound::add(out_ptr_NC, o11, w11 * trgt, s11);\n        }\n      } else {\n        // Diff w.r.t. 
sgrad\n        scalar_t val;\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, out_ptr_NC += out_sC) {\n          scalar_t trgt0 = *trgt_ptr_NCXY, trgt1 = trgt_ptr_NCXY[trgt_sK];\n          val = -dy0 * trgt0 - dx0 * trgt1;\n          bound::add(out_ptr_NC, o00, val, s00);\n          val = dy0 * trgt0 - dx1 * trgt1;\n          bound::add(out_ptr_NC, o10, val, s10);\n          val = -dy1 * trgt0 + dx0 * trgt1;\n          bound::add(out_ptr_NC, o01, val, s01);\n          val = dy1 * trgt0 + dx1 * trgt1;\n          bound::add(out_ptr_NC, o11, val, s11);\n        }\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_count) {\n      scalar_t* out_ptr_N = out_ptr + n * out_sN;\n      bound::add(out_ptr_N, o00, w00, s00);\n      bound::add(out_ptr_N, o10, w10, s10);\n      bound::add(out_ptr_N, o01, w01, s01);\n      bound::add(out_ptr_N, o11, w11, s11);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                     LINEAR INTERPOLATION 1D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate1d_linear(scalar_t x, offset_t w, offset_t n) const {\n    // Get corner pixel values from (x)\n    offset_t ix0 = static_cast<offset_t>(std::floor(x));\n\n    // Interpolation weights (inversely proportional to distance)\n    scalar_t w1 = x - ix0;\n    scalar_t w0 = 1. 
- w1;\n\n    // Sign (/!\\ compute sign before warping indices)\n    int8_t s1 = bound::sign(bound0, ix0 + 1, src_X);\n    int8_t s0 = bound::sign(bound0, ix0, src_X);\n\n    // Warp indices\n    offset_t ix1;\n    ix1 = bound::index(bound0, ix0 + 1, src_X);\n    ix0 = bound::index(bound0, ix0, src_X);\n\n    offset_t o0, o1;\n    if (do_pull || do_grad || do_sgrad) {\n      // Offsets into source volume\n      o0 = ix0 * src_sX;\n      o1 = ix1 * src_sX;\n    } else if (!(do_push || do_count)) {\n      o0 = o1 = 0;\n    }\n\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_grad) {\n      if (trgt_K == 0) {\n        // backward w.r.t. push/pull\n        scalar_t gx = static_cast<scalar_t>(0);\n        scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX;\n        scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, src_ptr_NC += src_sC) {\n          scalar_t src;\n          scalar_t trgt = trgt_ptr ? *trgt_ptr_NCX : static_cast<scalar_t>(1);\n          // ^ trgt_ptr == 0 during the backward pass of count\n          src = bound::get(src_ptr_NC, o0, s0);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= src;\n          src = bound::get(src_ptr_NC, o1, s1);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += src;\n        }\n\n        scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX;\n        (*grad_ptr_NX) = gx;\n      } else {\n        // backward w.r.t. 
sgrad\n        // -> zero (make sure this is done at initialization)\n      }\n    }\n    if (do_push || do_count) {\n      // Offsets into 'push' volume\n      o0 = ix0 * out_sX;\n      o1 = ix1 * out_sX;\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_pull) {\n      scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) {\n        *out_ptr_NCX = bound::get(src_ptr_NC, o0, s0) * w0 + bound::get(src_ptr_NC, o1, s1) * w1;\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_sgrad) {\n      scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) {\n        *out_ptr_NCX = bound::get(src_ptr_NC, o1, s1) - bound::get(src_ptr_NC, o0, s0);\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_push) {\n      scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      if (trgt_K == 0) {\n        // Diff w.r.t. push/pull\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC) {\n          scalar_t trgt = *trgt_ptr_NCX;\n          bound::add(out_ptr_NC, o0, w0 * trgt, s0);\n          bound::add(out_ptr_NC, o1, w1 * trgt, s1);\n        }\n      } else {\n        // Diff w.r.t. 
sgrad\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC) {\n          scalar_t trgt0 = *trgt_ptr_NCX;\n          bound::add(out_ptr_NC, o0, -trgt0, s0);\n          bound::add(out_ptr_NC, o1, trgt0, s1);\n        }\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_count) {\n      scalar_t* out_ptr_N = out_ptr + n * out_sN;\n      bound::add(out_ptr_N, o0, w0, s0);\n      bound::add(out_ptr_N, o1, w1, s1);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                  NEAREST NEIGHBOR INTERPOLATION 3D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate3d_nearest(\n      scalar_t x,\n      scalar_t y,\n      scalar_t z,\n      offset_t w,\n      offset_t h,\n      offset_t d,\n      offset_t n) const {\n    offset_t ix = static_cast<offset_t>(std::round(x));\n    offset_t iy = static_cast<offset_t>(std::round(y));\n    offset_t iz = static_cast<offset_t>(std::round(z));\n\n    // Boundary condition (/!\\ compute sign before warping indices)\n    int8_t sx = bound::sign(bound0, ix, src_X);\n    int8_t sy = bound::sign(bound1, iy, src_Y);\n    int8_t sz = bound::sign(bound2, iz, src_Z);\n    ix = bound::index(bound0, ix, src_X);\n    iy = bound::index(bound1, iy, src_Y);\n    iz = bound::index(bound2, iz, src_Z);\n\n    // Sign\n    int8_t s = sz * sy * sx;\n\n    if (do_pull) {\n      offset_t o = iz * src_sZ + iy * src_sY + ix * src_sX;\n      scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC)\n        *out_ptr_NCXYZ = bound::get(src_ptr_NC, o, s);\n    } else if (do_push && trgt_K == 0) {\n      offset_t o = 
iz * out_sZ + iy * out_sY + ix * out_sX;\n      scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, out_ptr_NC += out_sC)\n        bound::add(out_ptr_NC, o, *trgt_ptr_NCXYZ, s);\n    } else if (do_count) {\n      offset_t o = iz * out_sZ + iy * out_sY + ix * out_sX;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)\n        bound::add(out_ptr_NC, o, static_cast<scalar_t>(1), s);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                  NEAREST NEIGHBOR INTERPOLATION 2D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate2d_nearest(\n      scalar_t x,\n      scalar_t y,\n      offset_t w,\n      offset_t h,\n      offset_t n) const {\n    offset_t ix = static_cast<offset_t>(std::round(x));\n    offset_t iy = static_cast<offset_t>(std::round(y));\n\n    // Boundary condition (/!\\ compute sign before warping indices)\n    int8_t sx = bound::sign(bound0, ix, src_X);\n    int8_t sy = bound::sign(bound1, iy, src_Y);\n    ix = bound::index(bound0, ix, src_X);\n    iy = bound::index(bound1, iy, src_Y);\n\n    // Sign\n    int8_t s = sy * sx;\n\n    if (do_pull) {\n      offset_t o = iy * src_sY + ix * src_sX;\n      scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC)\n        *out_ptr_NCXY = bound::get(src_ptr_NC, o, s);\n    } else if (do_push && trgt_K == 0) {\n      offset_t o = iy * out_sY + ix * out_sX;\n      scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY;\n      
scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, out_ptr_NC += out_sC)\n        bound::add(out_ptr_NC, o, *trgt_ptr_NCXY, s);\n    } else if (do_count) {\n      offset_t o = iy * out_sY + ix * out_sX;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)\n        bound::add(out_ptr_NC, o, static_cast<scalar_t>(1), s);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                  NEAREST NEIGHBOR INTERPOLATION 1D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate1d_nearest(scalar_t x, offset_t w, offset_t n) const {\n    offset_t i = static_cast<offset_t>(std::round(x));\n\n    // Boundary condition (/!\\ compute sign before warping indices)\n    int8_t s = bound::sign(bound0, i, src_X);\n    i = bound::index(bound0, i, src_X);\n\n    if (do_pull) {\n      offset_t o = i * src_sX;\n      scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC)\n        *out_ptr_NCX = bound::get(src_ptr_NC, o, s);\n    } else if (do_push && trgt_K == 0) {\n      offset_t o = i * out_sX;\n      scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC)\n        bound::add(out_ptr_NC, o, *trgt_ptr_NCX, s);\n    } else if (do_count) {\n      offset_t o = i * out_sX;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)\n        bound::add(out_ptr_NC, o, static_cast<scalar_t>(1), s);\n    }\n  }\n\n  // 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //            LINEAR INTERPOLATION 3D + SLIDING BOUNDARY\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  // TODO\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                  CUDA KERNEL (MUST BE OUT OF CLASS)\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  } // namespace\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                    FUNCTIONAL FORM WITH DISPATCH\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n#define PUSHPULL_INSTANTIATE3(BoundType0, InterpolationType0, SourceType0) \\\n  template std::deque<Tensor> pushpull(                                    \\\n      const SourceType0&,                                                  \\\n      const Tensor&,                                                       \\\n      const Tensor&,                                                       \\\n      BoundType0,                                                          \\\n      InterpolationType0,                                                  \\\n      bool,                                                                \\\n      bool,                                                                \\\n      bool,                                                                \\\n      bool,                                                                \\\n      bool,                                                                \\\n      bool);                                                               \\\n  template std::deque<Tensor> pushpull(                                    \\\n      const SourceType0&, const Tensor&, BoundType0, InterpolationType0, bool, bool, bool, bool, bool, bool)\n#define PUSHPULL_INSTANTIATE2(BoundType0, InterpolationType0)         \\\n  PUSHPULL_INSTANTIATE3(BoundType0, 
InterpolationType0, IntArrayRef); \\\n  PUSHPULL_INSTANTIATE3(BoundType0, InterpolationType0, Tensor)\n#define PUSHPULL_INSTANTIATE1(BoundType0)               \\\n  PUSHPULL_INSTANTIATE2(BoundType0, InterpolationType); \\\n  PUSHPULL_INSTANTIATE2(BoundType0, InterpolationVectorRef)\n#define PUSHPULL_INSTANTIATE        \\\n  PUSHPULL_INSTANTIATE1(BoundType); \\\n  PUSHPULL_INSTANTIATE1(BoundVectorRef)\n\n  // Two arguments (source, grid)\n  // > `bound` and `interpolation` can be single arguments or vectors.\n  template <typename BoundType, typename InterpolationType, typename SourceType>\n  MONAI_HOST std::deque<Tensor> pushpull(\n      const SourceType& source,\n      const Tensor& grid,\n      BoundType bound,\n      InterpolationType interpolation,\n      bool extrapolate,\n      bool do_pull,\n      bool do_push,\n      bool do_count,\n      bool do_grad,\n      bool do_sgrad) {\n    PushPullAllocator info(\n        grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad);\n    info.ioset(source, grid);\n\n    return AT_DISPATCH_FLOATING_TYPES(grid.scalar_type(), \"pushpull\", [&] {\n      PushPullImpl<scalar_t, int64_t> algo(info);\n      algo.loop();\n      return algo.output;\n    });\n  }\n\n  // Three arguments (source, grid, target)\n  // > `bound` and `interpolation` can be single arguments or vectors.\n  // > `source` can be a tensor or a vector of dimensions.\n  template <typename BoundType, typename InterpolationType, typename SourceType>\n  MONAI_HOST std::deque<Tensor> pushpull(\n      const SourceType& source,\n      const Tensor& grid,\n      const Tensor& target,\n      BoundType bound,\n      InterpolationType interpolation,\n      bool extrapolate,\n      bool do_pull,\n      bool do_push,\n      bool do_count,\n      bool do_grad,\n      bool do_sgrad) {\n    PushPullAllocator info(\n        grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad);\n    
info.ioset(source, grid, target);\n\n    return AT_DISPATCH_FLOATING_TYPES(grid.scalar_type(), \"pushpull\", [&] {\n      PushPullImpl<scalar_t, int64_t> algo(info);\n      algo.loop();\n      return algo.output;\n    });\n  }\n\n  PUSHPULL_INSTANTIATE;\n\n} // namespace cpu\n} // namespace monai\n"
  },
  {
    "path": "monai/csrc/resample/pushpull_cuda.cu",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n// adapted from https://github.com/balbasty/nitorch\n\n// This file implements spline interpolation / sampling and its adjoint\n// operations. It corresponds loosely to torch's `GridSampler`.\n// It handles boundary conditions and interpolation orders defined in\n// `utils/resample_utils.h` and `utils/resample_utils.h`.\n// These parameters can be specified per dimension.\n// Isotropic 0-th and 1-st order interpolation have their own (faster)\n// implementations. Sliding boundary conditions are also implemented\n// separately.\n\n// TODO:\n// . [DONE] generic 3d\n// . [DONE] generic 2d\n// . [DONE] generic 1d\n// . sliding nearest 3d\n// . sliding nearest 2d\n// . sliding linear 3d\n// . sliding linear 2d\n// . sliding generic 3d\n// . sliding generic 2d\n// . [DONE] spatial gradient mode (without multiplication with output gradient)\n// . [DONE] second order gradients (backward pass for spatial gradients)\n// . performance tests\n// . 
input bound/inter are always vectors -> clean unused constructors\n\n#include <ATen/ATen.h>\n#include <limits>\n#include <tuple>\n#include \"bounds_common.h\"\n#include \"interpolation_common.h\"\n#include \"utils/resample_utils.h\"\n//#include <cstdio>\n\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n// GPU-specific parameters\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/detail/KernelUtils.h>\n#include <c10/macros/Macros.h>\nusing namespace at::cuda::detail;\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n// maximum number of channels\n// > not used in mode isotropic nearest/linear\n#ifndef MONAI_MAX_NUM_CHANNELS\n#define MONAI_MAX_NUM_CHANNELS 1024\n#endif\n\n// This parameter allows for a little bit of tolerance when considering\n// a coordinate as \"out-of-bound\" (if !extrapolate)\n#define TINY 5e-2\n\nusing at::Tensor;\nusing at::TensorOptions;\nusing c10::IntArrayRef;\n\nnamespace monai {\nMONAI_NAMESPACE_DEVICE { // cuda\n\n  namespace { // anonymous namespace > everything inside has internal linkage\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                        INDEXING UTILS\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  // This class reads and sets all the parameters that will later be used\n  // by the algorithm in PushPullImpl. All of this is done outside of the\n  // implementation class so that we do not depend on generic types. The\n  // point is to pre-allocate all necessary tensors so that we can check\n  // if they're all compatible with 32 bit math. If it's the case, we can\n  // dispatch to a 32b cuda implementation, which might increase\n  // performance. Else, we use 64 bit math to compute offsets.\n  // (On CPU, we always use 64 bit offsets because it doesn't make a huge\n  // difference. 
It would be different if we had a vectorized\n  // implementation as in PyTorch).\n  class PushPullAllocator {\n   public:\n    static constexpr int64_t max_int32 = std::numeric_limits<int32_t>::max();\n\n    // ~~~ CONSTRUCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n    MONAI_HOST\n    PushPullAllocator(\n        int dim,\n        BoundVectorRef bound,\n        InterpolationVectorRef interpolation,\n        bool extrapolate,\n        bool do_pull,\n        bool do_push,\n        bool do_count,\n        bool do_grad,\n        bool do_sgrad)\n        : dim(dim),\n          bound0(bound.size() > 0 ? bound[0] : BoundType::Replicate),\n          bound1(\n              bound.size() > 1       ? bound[1]\n                  : bound.size() > 0 ? bound[0]\n                                     : BoundType::Replicate),\n          bound2(\n              bound.size() > 2       ? bound[2]\n                  : bound.size() > 1 ? bound[1]\n                  : bound.size() > 0 ? bound[0]\n                                     : BoundType::Replicate),\n          interpolation0(interpolation.size() > 0 ? interpolation[0] : InterpolationType::Linear),\n          interpolation1(\n              interpolation.size() > 1       ? interpolation[1]\n                  : interpolation.size() > 0 ? interpolation[0]\n                                             : InterpolationType::Linear),\n          interpolation2(\n              interpolation.size() > 2       ? interpolation[2]\n                  : interpolation.size() > 1 ? interpolation[1]\n                  : interpolation.size() > 0 ? 
interpolation[0]\n                                             : InterpolationType::Linear),\n          extrapolate(extrapolate),\n          do_pull(do_pull),\n          do_push(do_push),\n          do_count(do_count),\n          do_grad(do_grad),\n          do_sgrad(do_sgrad) {\n      iso = interpolation0 == interpolation1 && interpolation0 == interpolation2;\n    }\n\n    // ~~~ FUNCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n    // Usually used for pull:\n    // - do_pull  -> return source[grid]\n    // - do_push  -> fails\n    // - do_grad  -> return J(source)[grid]\n    // - do_sgrad -> return H(source)[grid]\n    MONAI_HOST void ioset(const Tensor& source, const Tensor& grid) {\n      init_all();\n      init_source(source);\n      init_grid(grid);\n      init_output();\n    }\n\n    // Usually used for pull_backward:\n    // - do_pull  -> return source[grid]\n    // - do_push  -> return push(target, grid, source.shape)\n    // - do_grad  -> return J(source)[grid]\n    // - do_sgrad -> return H(source)[grid]\n    MONAI_HOST void ioset(const Tensor& source, const Tensor& grid, const Tensor& target) {\n      init_all();\n      init_source(source);\n      init_grid(grid);\n      init_target(target);\n      init_output();\n    }\n\n    // Usually used for push:\n    // - do_pull  -> fails\n    // - do_push  -> return push(target, grid, source_size)\n    // - do_grad  -> fails\n    // - do_sgrad -> fails\n    MONAI_HOST void ioset(IntArrayRef source_size, const Tensor& grid, const Tensor& target) {\n      init_all();\n      init_source(source_size);\n      init_grid(grid);\n      init_target(target);\n      init_output();\n    }\n\n    // Usually used for count:\n    // - do_pull  -> fails\n    // - do_push  -> return push(ones, grid, source_size)\n    // - do_grad  -> fails\n    // - do_sgrad -> fails\n    MONAI_HOST void ioset(IntArrayRef source_size, const Tensor& grid) {\n      init_all();\n      init_source(source_size);\n      
init_grid(grid);\n      init_output();\n    }\n\n    // We just check that all tensors that we own are compatible with 32b math\n    bool canUse32BitIndexMath(int64_t max_elem = max_int32) const {\n      return src_32b_ok && trgt_32b_ok && grid_32b_ok && grad_32b_ok && out_32b_ok;\n    }\n\n   private:\n    // Copied from aten/src/ATen/native/IndexingUtils.cpp in PyTorch 1.6.\n    // It is used to decide to which pointer type we should dispatch to.\n    // Basically, we need to make sure that the \"furthest\" element we need\n    // to reach is less than max_elem away.\n    static bool tensorCanUse32BitIndexMath(const Tensor& t, int64_t max_elem = max_int32) {\n      int64_t elements = t.numel();\n      if (elements >= max_elem) {\n        return false;\n      }\n      if (elements == 0) {\n        return max_elem > 0;\n      }\n\n      int64_t offset = 0;\n      int64_t linearId = elements - 1;\n\n      // NOTE: Assumes all strides are positive, which is true for now\n      for (int i = t.dim() - 1; i >= 0; --i) {\n        int64_t curDimIndex = linearId % t.size(i);\n        int64_t curDimOffset = curDimIndex * t.stride(i);\n        offset += curDimOffset;\n        linearId /= t.size(i);\n      }\n\n      if (offset >= max_elem) {\n        return false;\n      }\n\n      return true;\n    }\n\n    // ~~~ COMPONENTS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    MONAI_HOST void init_all();\n    MONAI_HOST void init_source(const Tensor& source);\n    MONAI_HOST void init_source(IntArrayRef source_size);\n    MONAI_HOST void init_grid(const Tensor& grid);\n    MONAI_HOST void init_target(const Tensor& target);\n    MONAI_HOST void init_output();\n\n    // ~~~ OPTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    int dim; // dimensionality (2 or 3)\n    BoundType bound0; // boundary condition  // x|W\n    BoundType bound1; // boundary condition  // y|H\n    BoundType bound2; // boundary condition  // z|D\n    InterpolationType interpolation0; 
// interpolation order // x|W\n    InterpolationType interpolation1; // interpolation order // y|H\n    InterpolationType interpolation2; // interpolation order // z|D\n    bool iso; // isotropic interpolation?\n    bool extrapolate; // compute out-of-bound values\n    bool do_pull; // sample a volume\n    bool do_push; // splat a volume\n    bool do_count; // splatting weights (= jacobian determinant)\n    bool do_grad; // backprop: gradient of grid // pull\n    bool do_sgrad; // sample spatial gradients\n\n    // ~~~ NAVIGATORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    std::deque<Tensor> output;\n    TensorOptions src_opt;\n    TensorOptions grid_opt;\n    TensorOptions trgt_opt;\n    int64_t N;\n    int64_t C;\n    int64_t src_X;\n    int64_t src_Y;\n    int64_t src_Z;\n    int64_t trgt_X;\n    int64_t trgt_Y;\n    int64_t trgt_Z;\n    int64_t trgt_K;\n    int64_t src_sN;\n    int64_t src_sC;\n    int64_t src_sX;\n    int64_t src_sY;\n    int64_t src_sZ;\n    bool src_32b_ok;\n    void* src_ptr;\n    int64_t trgt_sN;\n    int64_t trgt_sC;\n    int64_t trgt_sX;\n    int64_t trgt_sY;\n    int64_t trgt_sZ;\n    int64_t trgt_sK;\n    bool trgt_32b_ok;\n    void* trgt_ptr;\n    int64_t grid_sN;\n    int64_t grid_sC;\n    int64_t grid_sX;\n    int64_t grid_sY;\n    int64_t grid_sZ;\n    bool grid_32b_ok;\n    void* grid_ptr;\n    int64_t out_sN;\n    int64_t out_sC;\n    int64_t out_sX;\n    int64_t out_sY;\n    int64_t out_sZ;\n    int64_t out_sK; // gradient dimension\n    bool out_32b_ok;\n    void* out_ptr;\n    int64_t grad_sN;\n    int64_t grad_sC;\n    int64_t grad_sX;\n    int64_t grad_sY;\n    int64_t grad_sZ;\n    bool grad_32b_ok;\n    void* grad_ptr;\n\n    // Allow PushPullImpl's constructor to access PushPullAllocator's\n    // private members.\n    template <typename scalar_t, typename offset_t>\n    friend class PushPullImpl;\n  };\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                        
  INITIALISATION\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  MONAI_HOST\n  void PushPullAllocator::init_all() {\n    src_opt = grid_opt = trgt_opt = TensorOptions();\n    N = C = 1L;\n    src_X = src_Y = src_Z = 1L;\n    trgt_X = trgt_Y = trgt_Z = 1L;\n    trgt_K = 0L;\n    src_sN = src_sC = src_sX = src_sY = src_sZ = 0L;\n    grid_sN = grid_sC = grid_sX = grid_sY = grid_sZ = 0L;\n    grad_sN = grad_sC = grad_sX = grad_sY = grad_sZ = 0L;\n    trgt_sN = trgt_sC = trgt_sX = trgt_sY = trgt_sZ = trgt_sK = 0L;\n    out_sN = out_sC = out_sX = out_sY = out_sZ = out_sK = 0L;\n    src_ptr = trgt_ptr = grid_ptr = out_ptr = grad_ptr = static_cast<float*>(0);\n    src_32b_ok = trgt_32b_ok = grid_32b_ok = out_32b_ok = grad_32b_ok = true;\n  }\n\n  MONAI_HOST\n  void PushPullAllocator::init_source(const Tensor& source) {\n    N = source.size(0);\n    C = source.size(1);\n    src_X = source.size(2);\n    src_Y = dim < 2 ? 1L : source.size(3);\n    src_Z = dim < 3 ? 1L : source.size(4);\n    src_sN = source.stride(0);\n    src_sC = source.stride(1);\n    src_sX = source.stride(2);\n    src_sY = dim < 2 ? 0L : source.stride(3);\n    src_sZ = dim < 3 ? 0L : source.stride(4);\n    src_ptr = source.data_ptr();\n    src_opt = source.options();\n    src_32b_ok = tensorCanUse32BitIndexMath(source);\n  }\n\n  MONAI_HOST\n  void PushPullAllocator::init_source(IntArrayRef source_size) {\n    src_X = source_size[0];\n    src_Y = dim < 2 ? 1L : source_size[1];\n    src_Z = dim < 3 ? 1L : source_size[2];\n  }\n\n  MONAI_HOST\n  void PushPullAllocator::init_grid(const Tensor& grid) {\n    N = grid.size(0);\n    trgt_X = grid.size(1);\n    trgt_Y = dim < 2 ? 1L : grid.size(2);\n    trgt_Z = dim < 3 ? 1L : grid.size(3);\n    grid_sN = grid.stride(0);\n    grid_sX = grid.stride(1);\n    grid_sY = dim < 2 ? 0L : grid.stride(2);\n    grid_sZ = dim < 3 ? 0L : grid.stride(3);\n    grid_sC = grid.stride(dim == 1 ? 2 : dim == 2 ? 
3 : 4);\n    grid_ptr = grid.data_ptr();\n    grid_opt = grid.options();\n    grid_32b_ok = tensorCanUse32BitIndexMath(grid);\n  }\n\n  MONAI_HOST\n  void PushPullAllocator::init_target(const Tensor& target) {\n    N = target.size(0);\n    C = target.size(1);\n    trgt_X = target.size(2);\n    trgt_Y = dim < 2 ? 1L : target.size(3);\n    trgt_Z = dim < 3 ? 1L : target.size(4);\n    trgt_K = target.dim() == dim + 3 ? target.size(dim == 1 ? 3 : dim == 2 ? 4 : 5) : 0L;\n    trgt_sN = target.stride(0);\n    trgt_sC = target.stride(1);\n    trgt_sX = target.stride(2);\n    trgt_sY = dim < 2 ? 0L : target.stride(3);\n    trgt_sZ = dim < 3 ? 0L : target.stride(4);\n    trgt_sK = target.dim() == dim + 3 ? target.stride(dim == 1 ? 3 : dim == 2 ? 4 : 5) : 0L;\n    trgt_ptr = target.data_ptr();\n    trgt_opt = target.options();\n    trgt_32b_ok = tensorCanUse32BitIndexMath(target);\n  }\n\n  MONAI_HOST\n  void PushPullAllocator::init_output() {\n    output.clear();\n    if (do_pull) {\n      if (dim == 1)\n        output.push_back(at::empty({N, C, trgt_X}, src_opt));\n      else if (dim == 2)\n        output.push_back(at::empty({N, C, trgt_X, trgt_Y}, src_opt));\n      else\n        output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z}, src_opt));\n      auto pull = output.back();\n      out_sN = pull.stride(0);\n      out_sC = pull.stride(1);\n      out_sX = pull.stride(2);\n      out_sY = dim < 2 ? 0L : pull.stride(3);\n      out_sZ = dim < 3 ? 
0L : pull.stride(4);\n      out_sK = 0L;\n      out_ptr = pull.data_ptr();\n      out_32b_ok = tensorCanUse32BitIndexMath(pull);\n    } else if (do_sgrad) {\n      if (dim == 1)\n        output.push_back(at::empty({N, C, trgt_X, 1}, src_opt));\n      else if (dim == 2)\n        output.push_back(at::empty({N, C, trgt_X, trgt_Y, 2}, src_opt));\n      else\n        output.push_back(at::empty({N, C, trgt_X, trgt_Y, trgt_Z, 3}, src_opt));\n      auto sgrad = output.back();\n      out_sN = sgrad.stride(0);\n      out_sC = sgrad.stride(1);\n      out_sX = sgrad.stride(2);\n      out_sY = dim < 2 ? 0L : sgrad.stride(3);\n      out_sZ = dim < 3 ? 0L : sgrad.stride(4);\n      out_sK = sgrad.stride(dim == 1 ? 3 : dim == 2 ? 4 : 5);\n      out_ptr = sgrad.data_ptr();\n      out_32b_ok = tensorCanUse32BitIndexMath(sgrad);\n\n      if (iso && interpolation0 == InterpolationType::Nearest)\n        sgrad.zero_();\n      if (iso && interpolation0 == InterpolationType::Linear && dim == 1)\n        sgrad.zero_();\n    } else if (do_push) {\n      if (dim == 1)\n        output.push_back(at::zeros({N, C, src_X}, trgt_opt));\n      else if (dim == 2)\n        output.push_back(at::zeros({N, C, src_X, src_Y}, trgt_opt));\n      else\n        output.push_back(at::zeros({N, C, src_X, src_Y, src_Z}, trgt_opt));\n      auto push = output.back();\n      out_sN = push.stride(0);\n      out_sC = push.stride(1);\n      out_sX = push.stride(2);\n      out_sY = dim < 2 ? 0L : push.stride(3);\n      out_sZ = dim < 3 ? 
0L : push.stride(4);\n      out_sK = 0L;\n      out_ptr = push.data_ptr();\n      out_32b_ok = tensorCanUse32BitIndexMath(push);\n    } else if (do_count) {\n      if (dim == 1)\n        output.push_back(at::zeros({N, 1, src_X}, grid_opt));\n      else if (dim == 2)\n        output.push_back(at::zeros({N, 1, src_X, src_Y}, grid_opt));\n      else\n        output.push_back(at::zeros({N, 1, src_X, src_Y, src_Z}, grid_opt));\n      auto count = output.back();\n      out_sN = count.stride(0);\n      out_sC = count.stride(1);\n      out_sX = count.stride(2);\n      out_sY = dim < 2 ? 0L : count.stride(3);\n      out_sZ = dim < 3 ? 0L : count.stride(4);\n      out_sK = 0L;\n      out_ptr = count.data_ptr();\n      out_32b_ok = tensorCanUse32BitIndexMath(count);\n    }\n    if (do_grad) {\n      if (dim == 1)\n        output.push_back(at::zeros({N, trgt_X, 1}, grid_opt));\n      else if (dim == 2)\n        output.push_back(at::zeros({N, trgt_X, trgt_Y, 2}, grid_opt));\n      else\n        output.push_back(at::zeros({N, trgt_X, trgt_Y, trgt_Z, 3}, grid_opt));\n      auto grad = output.back();\n      grad_sN = grad.stride(0);\n      grad_sX = grad.stride(1);\n      grad_sY = dim < 2 ? 0L : grad.stride(2);\n      grad_sZ = dim < 3 ? 0L : grad.stride(3);\n      grad_sC = grad.stride(dim == 1 ? 2 : dim == 2 ? 
3 : 4);\n      grad_ptr = grad.data_ptr();\n      out_32b_ok = tensorCanUse32BitIndexMath(grad);\n\n      if (iso && interpolation0 == InterpolationType::Nearest)\n        grad.zero_();\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                        GENERIC PUSHPULL CLASS\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  // This class implements the bulk of the code.\n  // /!\\ No type and shape checking is performed here.\n\n  template <typename scalar_t, typename offset_t>\n  class PushPullImpl {\n   public:\n    // ~~~ CONSTRUCTOR ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    PushPullImpl(const PushPullAllocator& info)\n        : output(info.output),\n          dim(info.dim),\n          bound0(info.bound0),\n          bound1(info.bound1),\n          bound2(info.bound2),\n          interpolation0(info.interpolation0),\n          interpolation1(info.interpolation1),\n          interpolation2(info.interpolation1),\n          iso(info.iso),\n          extrapolate(info.extrapolate),\n          do_pull(info.do_pull),\n          do_push(info.do_push),\n          do_count(info.do_count),\n          do_grad(info.do_grad),\n          do_sgrad(info.do_sgrad),\n          N(static_cast<offset_t>(info.N)),\n          C(static_cast<offset_t>(info.C)),\n          src_X(static_cast<offset_t>(info.src_X)),\n          src_Y(static_cast<offset_t>(info.src_Y)),\n          src_Z(static_cast<offset_t>(info.src_Z)),\n          trgt_X(static_cast<offset_t>(info.trgt_X)),\n          trgt_Y(static_cast<offset_t>(info.trgt_Y)),\n          trgt_Z(static_cast<offset_t>(info.trgt_Z)),\n          trgt_K(static_cast<offset_t>(info.trgt_K)),\n          src_sN(static_cast<offset_t>(info.src_sN)),\n          src_sC(static_cast<offset_t>(info.src_sC)),\n          src_sX(static_cast<offset_t>(info.src_sX)),\n          src_sY(static_cast<offset_t>(info.src_sY)),\n          
src_sZ(static_cast<offset_t>(info.src_sZ)),\n          src_ptr(static_cast<scalar_t*>(info.src_ptr)),\n          trgt_sN(static_cast<offset_t>(info.trgt_sN)),\n          trgt_sC(static_cast<offset_t>(info.trgt_sC)),\n          trgt_sX(static_cast<offset_t>(info.trgt_sX)),\n          trgt_sY(static_cast<offset_t>(info.trgt_sY)),\n          trgt_sZ(static_cast<offset_t>(info.trgt_sZ)),\n          trgt_sK(static_cast<offset_t>(info.trgt_sK)),\n          trgt_ptr(static_cast<scalar_t*>(info.trgt_ptr)),\n          grid_sN(static_cast<offset_t>(info.grid_sN)),\n          grid_sC(static_cast<offset_t>(info.grid_sC)),\n          grid_sX(static_cast<offset_t>(info.grid_sX)),\n          grid_sY(static_cast<offset_t>(info.grid_sY)),\n          grid_sZ(static_cast<offset_t>(info.grid_sZ)),\n          grid_ptr(static_cast<scalar_t*>(info.grid_ptr)),\n          out_sN(static_cast<offset_t>(info.out_sN)),\n          out_sC(static_cast<offset_t>(info.out_sC)),\n          out_sX(static_cast<offset_t>(info.out_sX)),\n          out_sY(static_cast<offset_t>(info.out_sY)),\n          out_sZ(static_cast<offset_t>(info.out_sZ)),\n          out_sK(static_cast<offset_t>(info.out_sK)),\n          out_ptr(static_cast<scalar_t*>(info.out_ptr)),\n          grad_sN(static_cast<offset_t>(info.grad_sN)),\n          grad_sC(static_cast<offset_t>(info.grad_sC)),\n          grad_sX(static_cast<offset_t>(info.grad_sX)),\n          grad_sY(static_cast<offset_t>(info.grad_sY)),\n          grad_sZ(static_cast<offset_t>(info.grad_sZ)),\n          grad_ptr(static_cast<scalar_t*>(info.grad_ptr)) {}\n\n    // ~~~ PUBLIC VALUE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n    std::deque<Tensor> output;\n\n    // MONAI_HOST MONAI_DEVICE void printInfo() const {\n    //   printf(\"dim: %d\\n\", dim);\n    //   printf(\"do_pull:  %d\\n\", do_pull);\n    //   printf(\"do_push:  %d\\n\", do_push);\n    //   printf(\"do_count: %d\\n\", do_count);\n    //   printf(\"do_sgrad: %d\\n\", do_sgrad);\n    //   
printf(\"do_grad:  %d\\n\", do_grad);\n    //   printf(\"bound:         [%d %d %d]\\n\", static_cast<int>(bound0),\n    //     static_cast<int>(bound1), static_cast<int>(bound2));\n    //   printf(\"interpolation: [%d %d %d]\\n\", static_cast<int>(interpolation0),\n    //     static_cast<int>(interpolation1), static_cast<int>(interpolation2));\n    //   printf(\"src:  [%d %d %d]\\n\", src_Z, src_Y, src_X);\n    //   printf(\"trgt: [%d %d %d (%d)]\\n\", trgt_Z, trgt_Y, trgt_X, trgt_K);\n    //   printf(\"N: %d\\n\", N);\n    //   printf(\"C: %d\\n\", C);\n    //   printf(\"src  -> %lu\\n\", reinterpret_cast<std::uintptr_t>(src_ptr));\n    //   printf(\"trgt -> %lu\\n\", reinterpret_cast<std::uintptr_t>(trgt_ptr));\n    //   printf(\"grid -> %lu\\n\", reinterpret_cast<std::uintptr_t>(grid_ptr));\n    //   printf(\"out  -> %lu\\n\", reinterpret_cast<std::uintptr_t>(out_ptr));\n    //   printf(\"grad -> %lu\\n\", reinterpret_cast<std::uintptr_t>(grad_ptr));\n    // }\n\n    // ~~~ FUNCTORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n    // Loop over voxels that belong to one CUDA block\n    // This function is called by the CUDA kernel\n    MONAI_DEVICE void loop(int threadIdx, int blockIdx, int blockDim, int gridDim) const;\n\n    MONAI_HOST MONAI_DEVICE int64_t voxcount() const {\n      return N * trgt_X * trgt_Y * trgt_Z;\n    }\n\n   private:\n    // ~~~ COMPONENTS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    MONAI_DEVICE void check1d(offset_t w, offset_t n) const;\n    MONAI_DEVICE void check2d(offset_t w, offset_t h, offset_t n) const;\n    MONAI_DEVICE void check3d(offset_t w, offset_t h, offset_t d, offset_t n) const;\n    MONAI_DEVICE void interpolate1d(scalar_t x, offset_t w, offset_t n) const;\n    MONAI_DEVICE void interpolate1d_nearest(scalar_t x, offset_t w, offset_t n) const;\n    MONAI_DEVICE void interpolate1d_linear(scalar_t x, offset_t w, offset_t n) const;\n    MONAI_DEVICE void interpolate1d_sliding(scalar_t x, offset_t 
w, offset_t n) const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate1d_sliding_nearest(scalar_t x, offset_t w, offset_t n) const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate1d_sliding_linear(scalar_t x, offset_t w, offset_t n) const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate2d(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const;\n    MONAI_DEVICE void interpolate2d_nearest(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const;\n    MONAI_DEVICE void interpolate2d_bilinear(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const;\n    MONAI_DEVICE void interpolate2d_sliding(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n) const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate2d_sliding_nearest(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n)\n        const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate2d_sliding_bilinear(scalar_t x, scalar_t y, offset_t w, offset_t h, offset_t n)\n        const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate3d(scalar_t x, scalar_t y, scalar_t z, offset_t w, offset_t h, offset_t d, offset_t n)\n        const;\n    MONAI_DEVICE void interpolate3d_nearest(\n        scalar_t x,\n        scalar_t y,\n        scalar_t z,\n        offset_t w,\n        offset_t h,\n        offset_t d,\n        offset_t n) const;\n    MONAI_DEVICE void interpolate3d_trilinear(\n        scalar_t x,\n        scalar_t y,\n        scalar_t z,\n        offset_t w,\n        offset_t h,\n        offset_t d,\n        offset_t n) const;\n    MONAI_DEVICE void interpolate3d_sliding(\n        scalar_t x,\n        scalar_t y,\n        scalar_t z,\n        offset_t w,\n        offset_t h,\n        offset_t d,\n        offset_t n) const { /*TODO*/\n    }\n    MONAI_DEVICE void interpolate3d_sliding_nearest(\n        scalar_t x,\n        scalar_t y,\n        scalar_t z,\n        offset_t w,\n        offset_t h,\n        offset_t d,\n        offset_t n) const { /*TODO*/\n    }\n   
 MONAI_DEVICE void interpolate3d_sliding_trilinear(\n        scalar_t x,\n        scalar_t y,\n        scalar_t z,\n        offset_t w,\n        offset_t h,\n        offset_t d,\n        offset_t n) const { /*TODO*/\n    }\n\n    // ~~~ OPTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    int dim; // dimensionality (2 or 3)\n    BoundType bound0; // boundary condition  // x|W\n    BoundType bound1; // boundary condition  // y|H\n    BoundType bound2; // boundary condition  // z|D\n    InterpolationType interpolation0; // interpolation order // x|W\n    InterpolationType interpolation1; // interpolation order // y|H\n    InterpolationType interpolation2; // interpolation order // z|D\n    bool iso; // isotropic interpolation?\n    bool extrapolate; // compute out-of-bound values\n    bool do_pull; // sample a volume\n    bool do_push; // splat a volume\n    bool do_count; // splatting weights (= jacobian determinant)\n    bool do_grad; // backprop: gradient of grid // pull\n    bool do_sgrad; // sample spatial gradients\n\n    // ~~~ NAVIGATORS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    offset_t N;\n    offset_t C;\n    offset_t src_X;\n    offset_t src_Y;\n    offset_t src_Z;\n    offset_t trgt_X;\n    offset_t trgt_Y;\n    offset_t trgt_Z;\n    offset_t trgt_K;\n    offset_t src_sN;\n    offset_t src_sC;\n    offset_t src_sX;\n    offset_t src_sY;\n    offset_t src_sZ;\n    scalar_t* src_ptr;\n    offset_t trgt_sN;\n    offset_t trgt_sC;\n    offset_t trgt_sX;\n    offset_t trgt_sY;\n    offset_t trgt_sZ;\n    offset_t trgt_sK;\n    scalar_t* trgt_ptr;\n    offset_t grid_sN;\n    offset_t grid_sC;\n    offset_t grid_sX;\n    offset_t grid_sY;\n    offset_t grid_sZ;\n    scalar_t* grid_ptr;\n    offset_t out_sN;\n    offset_t out_sC;\n    offset_t out_sX;\n    offset_t out_sY;\n    offset_t out_sZ;\n    offset_t out_sK; // gradient dimension\n    scalar_t* out_ptr;\n    offset_t grad_sN;\n    offset_t grad_sC;\n    offset_t grad_sX;\n 
   offset_t grad_sY;\n    offset_t grad_sZ;\n    scalar_t* grad_ptr;\n  };\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                             LOOP\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::loop(int threadIdx, int blockIdx, int blockDim, int gridDim)\n      const {\n    int64_t index = blockIdx * blockDim + threadIdx;\n    int64_t nthreads = voxcount();\n    offset_t trgt_XYZ = trgt_Z * trgt_Y * trgt_X;\n    offset_t trgt_YZ = trgt_Z * trgt_Y;\n    offset_t n, w, h, d;\n    for (offset_t i = index; index < nthreads; index += blockDim * gridDim, i = index) {\n      // Convert index: linear to sub\n      n = (i / trgt_XYZ);\n      w = (i / trgt_YZ) % trgt_X;\n      h = (i / trgt_Z) % trgt_Y;\n      d = i % trgt_Z;\n\n      if (dim == 1)\n        check1d(w, n);\n      else if (dim == 2)\n        check2d(w, h, n);\n      else\n        check3d(w, h, d, n);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                        CHECK OUT-OF-BOUND\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  // Here, we:\n  // 1) read the [x,y,z] source coordinate for the current target voxel\n  // 3) check if the source coordinate is in bounds\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::check3d(offset_t w, offset_t h, offset_t d, offset_t n) const {\n    // get the corresponding input x, y, z co-ordinates from grid\n    scalar_t* grid_ptr_NXYZ = grid_ptr + n * grid_sN + w * grid_sX + h * grid_sY + d * grid_sZ;\n    scalar_t x = *grid_ptr_NXYZ;\n    scalar_t y = grid_ptr_NXYZ[grid_sC];\n    scalar_t z = grid_ptr_NXYZ[grid_sC * 2];\n\n    // Check if out-of-bound\n    if (!(extrapolate ||\n          (inbounds(x, src_X, static_cast<scalar_t>(TINY)) && inbounds(y, 
src_Y, static_cast<scalar_t>(TINY)) &&\n           inbounds(z, src_Z, static_cast<scalar_t>(TINY))))) {\n      if (do_pull || do_sgrad) {\n        scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ;\n        for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC) {\n          *out_ptr_NCXYZ = static_cast<scalar_t>(0);\n          if (do_sgrad) {\n            out_ptr_NCXYZ[out_sK] = static_cast<scalar_t>(0);\n            out_ptr_NCXYZ[out_sK * 2] = static_cast<scalar_t>(0);\n          }\n        }\n      }\n      if (do_grad) {\n        scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY + d * grad_sZ;\n        (*grad_ptr_NXYZ) = static_cast<scalar_t>(0);\n        grad_ptr_NXYZ[grad_sC] = static_cast<scalar_t>(0);\n        grad_ptr_NXYZ[grad_sC * 2] = static_cast<scalar_t>(0);\n      }\n      return;\n    }\n\n    // Next step\n    if (bound0 == BoundType::Sliding) {\n      if (iso)\n        switch (static_cast<int>(interpolation0)) {\n          case 0:\n            return interpolate3d_sliding_nearest(x, y, z, w, h, d, n);\n          case 1:\n            return interpolate3d_sliding_trilinear(x, y, z, w, h, d, n);\n        }\n      return interpolate3d_sliding(x, y, z, w, h, d, n);\n    } else {\n      if (iso)\n        switch (static_cast<int>(interpolation0)) {\n          case 0:\n            return interpolate3d_nearest(x, y, z, w, h, d, n);\n          case 1:\n            return interpolate3d_trilinear(x, y, z, w, h, d, n);\n        }\n      return interpolate3d(x, y, z, w, h, d, n);\n    }\n  }\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::check2d(offset_t w, offset_t h, offset_t n) const {\n    // get the corresponding input x, y, z co-ordinates from grid\n    scalar_t* grid_ptr_NXY = grid_ptr + n * grid_sN + w * grid_sX + h * grid_sY;\n    scalar_t x = *grid_ptr_NXY;\n    scalar_t y = grid_ptr_NXY[grid_sC];\n\n    // Check if 
out-of-bound\n    if (!(extrapolate ||\n          (inbounds(x, src_X, static_cast<scalar_t>(TINY)) && inbounds(y, src_Y, static_cast<scalar_t>(TINY))))) {\n      if (do_pull || do_sgrad) {\n        scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY;\n        for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC) {\n          *out_ptr_NCXY = static_cast<scalar_t>(0);\n          if (do_sgrad)\n            out_ptr_NCXY[out_sK] = static_cast<scalar_t>(0);\n        }\n      }\n      if (do_grad) {\n        scalar_t* grad_ptr_NXY = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY;\n        (*grad_ptr_NXY) = static_cast<scalar_t>(0);\n        grad_ptr_NXY[grad_sC] = static_cast<scalar_t>(0);\n      }\n      return;\n    }\n\n    // Next step\n    if (bound0 == BoundType::Sliding) {\n      if (iso)\n        switch (static_cast<int>(interpolation0)) {\n          case 0:\n            return interpolate2d_sliding_nearest(x, y, w, h, n);\n          case 1:\n            return interpolate2d_sliding_bilinear(x, y, w, h, n);\n        }\n      return interpolate2d_sliding(x, y, w, h, n);\n    } else {\n      if (iso)\n        switch (static_cast<int>(interpolation0)) {\n          case 0:\n            return interpolate2d_nearest(x, y, w, h, n);\n          case 1:\n            return interpolate2d_bilinear(x, y, w, h, n);\n        }\n      return interpolate2d(x, y, w, h, n);\n    }\n  }\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::check1d(offset_t w, offset_t n) const {\n    // get the corresponding input x, y, z co-ordinates from grid\n    scalar_t* grid_ptr_NX = grid_ptr + n * grid_sN + w * grid_sX;\n    scalar_t x = *grid_ptr_NX;\n\n    // Check if out-of-bound\n    if (!(extrapolate || inbounds(x, src_X, static_cast<scalar_t>(TINY)))) {\n      if (do_pull || do_sgrad) {\n        scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX;\n        for (offset_t c = 0; c < C; ++c, out_ptr_NCX 
+= out_sC) {\n          *out_ptr_NCX = static_cast<scalar_t>(0);\n          if (do_sgrad)\n            out_ptr_NCX[out_sK] = static_cast<scalar_t>(0);\n        }\n      }\n      if (do_grad) {\n        scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX;\n        (*grad_ptr_NX) = static_cast<scalar_t>(0);\n        grad_ptr_NX[grad_sC] = static_cast<scalar_t>(0);\n      }\n      return;\n    }\n\n    // Next step\n    if (bound0 == BoundType::Sliding) {\n      if (iso)\n        switch (static_cast<int>(interpolation0)) {\n          case 0:\n            return interpolate1d_sliding_nearest(x, w, n);\n          case 1:\n            return interpolate1d_sliding_linear(x, w, n);\n        }\n      return interpolate1d_sliding(x, w, n);\n    } else {\n      if (iso)\n        switch (static_cast<int>(interpolation0)) {\n          case 0:\n            return interpolate1d_nearest(x, w, n);\n          case 1:\n            return interpolate1d_linear(x, w, n);\n        }\n      return interpolate1d(x, w, n);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                     GENERIC INTERPOLATION 3D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate3d(\n      scalar_t x,\n      scalar_t y,\n      scalar_t z,\n      offset_t w,\n      offset_t h,\n      offset_t d,\n      offset_t n) const {\n    // Get corner pixel values from (x, y, z)\n    offset_t bx0, bx1, by0, by1, bz0, bz1;\n    interpolation::bounds(interpolation0, x, bx0, bx1);\n    interpolation::bounds(interpolation1, y, by0, by1);\n    interpolation::bounds(interpolation2, z, bz0, bz1);\n    offset_t dbx = bx1 - bx0;\n    offset_t dby = by1 - by0;\n    offset_t dbz = bz1 - bz0;\n\n    // Pre-compute offsets and target value\n    scalar_t* src_ptr_NC0 = src_ptr + n * src_sN;\n    scalar_t* out_ptr_NC0 = out_ptr + n * 
out_sN;\n    scalar_t* out_ptr_NCXYZ0 = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ;\n    scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ;\n    scalar_t target[3 * MONAI_MAX_NUM_CHANNELS];\n    if (trgt_ptr && (do_push || do_grad))\n      for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC) {\n        target[c] = *trgt_ptr_NCXYZ;\n        if (trgt_K > 0) {\n          target[c + C] = trgt_ptr_NCXYZ[trgt_sK];\n          target[c + C * 2] = trgt_ptr_NCXYZ[trgt_sK * 2];\n        }\n      }\n\n    // Initialize output\n    scalar_t* out_ptr_NCXYZ = out_ptr_NCXYZ0;\n    if (do_pull || do_sgrad) {\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC) {\n        *out_ptr_NCXYZ = static_cast<scalar_t>(0);\n        if (do_sgrad) {\n          out_ptr_NCXYZ[out_sK] = static_cast<scalar_t>(0);\n          out_ptr_NCXYZ[out_sK * 2] = static_cast<scalar_t>(0);\n        }\n      }\n    }\n\n    // Pre-compute indices/weights/grad\n    scalar_t wx[8], wy[8], wz[8]; // B-spline weights\n    scalar_t gx[8], gy[8], gz[8]; // B-spline derivatives\n    scalar_t hx[8], hy[8], hz[8]; // B-spline 2nd derivatives\n    offset_t ix[8], iy[8], iz[8]; // Warped indices\n    uint8_t sx[8], sy[8], sz[8]; // Warped indices\n\n    {\n      scalar_t *owz = static_cast<scalar_t*>(wz), *ogz = static_cast<scalar_t*>(gz), *ohz = static_cast<scalar_t*>(hz);\n      offset_t* oiz = static_cast<offset_t*>(iz);\n      uint8_t* osz = static_cast<uint8_t*>(sz);\n      for (offset_t bz = bz0; bz <= bz1; ++bz) {\n        scalar_t dz = z - bz;\n        *(owz++) = interpolation::fastweight(interpolation2, dz);\n        if (do_grad || do_sgrad)\n          *(ogz++) = interpolation::fastgrad(interpolation2, dz);\n        if (do_grad && trgt_sK > 1)\n          *(ohz++) = interpolation::fasthess(interpolation2, dz);\n        *(osz++) = bound::sign(bound2, bz, src_Z);\n        *(oiz++) = bound::index(bound2, bz, src_Z);\n      }\n    }\n    {\n 
     scalar_t *owy = static_cast<scalar_t*>(wy), *ogy = static_cast<scalar_t*>(gy), *ohy = static_cast<scalar_t*>(hy);\n      offset_t* oiy = static_cast<offset_t*>(iy);\n      uint8_t* osy = static_cast<uint8_t*>(sy);\n      for (offset_t by = by0; by <= by1; ++by) {\n        scalar_t dy = y - by;\n        *(owy++) = interpolation::fastweight(interpolation1, dy);\n        if (do_grad || do_sgrad)\n          *(ogy++) = interpolation::fastgrad(interpolation1, dy);\n        if (do_grad && trgt_sK > 1)\n          *(ohy++) = interpolation::fasthess(interpolation1, dy);\n        *(osy++) = bound::sign(bound1, by, src_Y);\n        *(oiy++) = bound::index(bound1, by, src_Y);\n      }\n    }\n    {\n      scalar_t *owx = static_cast<scalar_t*>(wx), *ogx = static_cast<scalar_t*>(gx), *ohx = static_cast<scalar_t*>(hx);\n      offset_t* oix = static_cast<offset_t*>(ix);\n      uint8_t* osx = static_cast<uint8_t*>(sx);\n      for (offset_t bx = bx0; bx <= bx1; ++bx) {\n        scalar_t dx = x - bx;\n        *(owx++) = interpolation::fastweight(interpolation0, dx);\n        if (do_grad || do_sgrad)\n          *(ogx++) = interpolation::fastgrad(interpolation0, dx);\n        if (do_grad && trgt_sK > 1)\n          *(ohx++) = interpolation::fasthess(interpolation0, dx);\n        *(osx++) = bound::sign(bound0, bx, src_X);\n        *(oix++) = bound::index(bound0, bx, src_X);\n      }\n    }\n\n    // Convolve coefficients with basis functions\n    scalar_t ogx, ogy, ogz;\n    ogx = ogy = ogz = static_cast<scalar_t>(0);\n    for (offset_t k = 0; k <= dbz; ++k) {\n      offset_t ooz = iz[k] * out_sZ;\n      offset_t osz = iz[k] * src_sZ;\n      uint8_t szz = sz[k];\n      scalar_t wzz = wz[k];\n      scalar_t gzz = gz[k];\n      scalar_t hzz = hz[k];\n      for (offset_t j = 0; j <= dby; ++j) {\n        offset_t ooyz = ooz + iy[j] * out_sY;\n        offset_t osyz = osz + iy[j] * src_sY;\n        uint8_t syz = szz * sy[j];\n        scalar_t wyy = wy[j];\n        scalar_t gyy = gy[j];\n  
      scalar_t hyy = hy[j];\n        for (offset_t i = 0; i <= dbx; ++i) {\n          offset_t ooxyz = ooyz + ix[i] * out_sX;\n          offset_t osxyz = osyz + ix[i] * src_sX;\n          uint8_t sxyz = syz * sx[i];\n          scalar_t wxx = wx[i];\n          scalar_t gxx = gx[i];\n          scalar_t hxx = hx[i];\n\n          // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n          if (do_pull) {\n            scalar_t* src_ptr_NC = src_ptr_NC0;\n            scalar_t* out_ptr_NCXYZ = out_ptr_NCXYZ0;\n            for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC)\n              *out_ptr_NCXYZ += bound::get(src_ptr_NC, osxyz, sxyz) * (wxx * wyy * wzz);\n          }\n\n          // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~\n          else if (do_sgrad) {\n            scalar_t* src_ptr_NC = src_ptr_NC0;\n            scalar_t* out_ptr_NCXYZ = out_ptr_NCXYZ0;\n            for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) {\n              scalar_t src = bound::get(src_ptr_NC, osxyz, sxyz);\n              *out_ptr_NCXYZ += src * (gxx * wyy * wzz);\n              out_ptr_NCXYZ[out_sK] += src * (wxx * gyy * wzz);\n              out_ptr_NCXYZ[2 * out_sK] += src * (wxx * wyy * gzz);\n            }\n          }\n\n          // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n          else if (do_push) {\n            if (trgt_K == 0) {\n              // Diff w.r.t. push/pull\n              scalar_t* out_ptr_NC = out_ptr_NC0;\n              for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)\n                bound::add(out_ptr_NC, ooxyz, (wxx * wyy * wzz) * target[c], sxyz);\n            } else {\n              // Diff w.r.t. 
sgrad\n              scalar_t* out_ptr_NC = out_ptr_NC0;\n              for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) {\n                scalar_t val = (gxx * wyy * wzz) * target[c] + (wxx * gyy * wzz) * target[c + C] +\n                    (wxx * wyy * gzz) * target[c + C * 2];\n                bound::add(out_ptr_NC, ooxyz, val, sxyz);\n              }\n            }\n          }\n\n          // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Count ~~~~~~~~~~~~~~~~~~~~~~~~~~~\n          else if (do_count) {\n            bound::add(out_ptr_NC0, ooxyz, (wxx * wyy * wzz), sxyz);\n          }\n\n          // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n          if (do_grad) {\n            if (trgt_K == 0) {\n              // Diff w.r.t. pull/push\n              scalar_t* src_ptr_NC = src_ptr_NC0;\n              scalar_t dot = static_cast<scalar_t>(0);\n              for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {\n                scalar_t src = bound::get(src_ptr_NC, osxyz, sxyz);\n                dot += (trgt_ptr ? src * target[c] : src);\n                // trgt_ptr == 0 in the backward pass of 'count'\n              }\n              ogx += (gxx * wyy * wzz) * dot;\n              ogy += (wxx * gyy * wzz) * dot;\n              ogz += (wxx * wyy * gzz) * dot;\n            } else {\n              // Diff w.r.t. 
sgrad\n              scalar_t* src_ptr_NC = src_ptr_NC0;\n              scalar_t dot0, dot1, dot2;\n              dot0 = dot1 = dot2 = static_cast<scalar_t>(0);\n              for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {\n                scalar_t src = bound::get(src_ptr_NC, osxyz, sxyz);\n                dot0 += src * target[c];\n                dot1 += src * target[c + C];\n                dot2 += src * target[c + C * 2];\n              }\n              ogx += (hxx * wyy * wzz) * dot0 + (gxx * gyy * wzz) * dot1 + (gxx * wyy * gzz) * dot2;\n              ogy += (gxx * gyy * wzz) * dot0 + (wxx * hyy * wzz) * dot1 + (wxx * gyy * gzz) * dot2;\n              ogz += (gxx * wyy * gzz) * dot0 + (wxx * gyy * gzz) * dot1 + (wxx * wyy * hzz) * dot2;\n            }\n          }\n\n        } // x\n      } // y\n    } // z\n\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_grad) {\n      scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY + d * grad_sZ;\n      (*grad_ptr_NXYZ) = ogx;\n      grad_ptr_NXYZ[grad_sC] = ogy;\n      grad_ptr_NXYZ[grad_sC * 2] = ogz;\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                     GENERIC INTERPOLATION 2D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate2d(\n      scalar_t x,\n      scalar_t y,\n      offset_t w,\n      offset_t h,\n      offset_t n) const {\n    // Get corner pixel values from (x, y)\n    offset_t bx0, bx1, by0, by1;\n    interpolation::bounds(interpolation0, x, bx0, bx1);\n    interpolation::bounds(interpolation1, y, by0, by1);\n    offset_t dbx = bx1 - bx0;\n    offset_t dby = by1 - by0;\n\n    // Pre-compute offsets and target value\n    scalar_t* src_ptr_NC0 = src_ptr + n * src_sN;\n    scalar_t* out_ptr_NC0 = out_ptr + n * out_sN;\n    
scalar_t* out_ptr_NCXY0 = out_ptr + n * out_sN + w * out_sX + h * out_sY;\n    scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY;\n    scalar_t target[2 * MONAI_MAX_NUM_CHANNELS];\n    if (trgt_ptr && (do_push || do_grad))\n      for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC) {\n        target[c] = *trgt_ptr_NCXY;\n        if (trgt_K > 0) {\n          target[c + C] = trgt_ptr_NCXY[trgt_sK];\n        }\n      }\n\n    // Initialize output\n    scalar_t* out_ptr_NCXY = out_ptr_NCXY0;\n    if (do_pull || do_sgrad) {\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC) {\n        *out_ptr_NCXY = static_cast<scalar_t>(0);\n        if (do_sgrad) {\n          out_ptr_NCXY[out_sK] = static_cast<scalar_t>(0);\n        }\n      }\n    }\n\n    // Pre-compute indices/weights/grad\n    scalar_t wx[8], wy[8]; // B-spline weights\n    scalar_t gx[8], gy[8]; // B-spline derivatives\n    scalar_t hx[8], hy[8]; // B-spline 2nd derivatives\n    offset_t ix[8], iy[8]; // Warped indices\n    uint8_t sx[8], sy[8]; // Warped indices\n\n    {\n      scalar_t *owy = static_cast<scalar_t*>(wy), *ogy = static_cast<scalar_t*>(gy), *ohy = static_cast<scalar_t*>(hy);\n      offset_t* oiy = static_cast<offset_t*>(iy);\n      uint8_t* osy = static_cast<uint8_t*>(sy);\n      for (offset_t by = by0; by <= by1; ++by) {\n        scalar_t dy = y - by;\n        *(owy++) = interpolation::fastweight(interpolation1, dy);\n        if (do_grad || do_sgrad)\n          *(ogy++) = interpolation::fastgrad(interpolation1, dy);\n        if (do_grad && trgt_sK > 1)\n          *(ohy++) = interpolation::fasthess(interpolation1, dy);\n        *(osy++) = bound::sign(bound1, by, src_Y);\n        *(oiy++) = bound::index(bound1, by, src_Y);\n      }\n    }\n    {\n      scalar_t *owx = static_cast<scalar_t*>(wx), *ogx = static_cast<scalar_t*>(gx), *ohx = static_cast<scalar_t*>(hx);\n      offset_t* oix = static_cast<offset_t*>(ix);\n      uint8_t* osx = 
static_cast<uint8_t*>(sx);\n      for (offset_t bx = bx0; bx <= bx1; ++bx) {\n        scalar_t dx = x - bx;\n        *(owx++) = interpolation::fastweight(interpolation0, dx);\n        if (do_grad || do_sgrad)\n          *(ogx++) = interpolation::fastgrad(interpolation0, dx);\n        if (do_grad && trgt_sK > 1)\n          *(ohx++) = interpolation::fasthess(interpolation0, dx);\n        *(osx++) = bound::sign(bound0, bx, src_X);\n        *(oix++) = bound::index(bound0, bx, src_X);\n      }\n    }\n\n    // Convolve coefficients with basis functions\n    scalar_t ogx, ogy;\n    ogx = ogy = static_cast<scalar_t>(0);\n    for (offset_t j = 0; j <= dby; ++j) {\n      offset_t ooy = iy[j] * out_sY;\n      offset_t osy = iy[j] * src_sY;\n      uint8_t syy = sy[j];\n      scalar_t wyy = wy[j];\n      scalar_t gyy = gy[j];\n      scalar_t hyy = hy[j];\n      for (offset_t i = 0; i <= dbx; ++i) {\n        offset_t ooxy = ooy + ix[i] * out_sX;\n        offset_t osxy = osy + ix[i] * src_sX;\n        uint8_t sxy = syy * sx[i];\n        scalar_t wxx = wx[i];\n        scalar_t gxx = gx[i];\n        scalar_t hxx = hx[i];\n\n        // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n        if (do_pull) {\n          scalar_t* src_ptr_NC = src_ptr_NC0;\n          scalar_t* out_ptr_NCXY = out_ptr_NCXY0;\n          for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC)\n            *out_ptr_NCXY += bound::get(src_ptr_NC, osxy, sxy) * (wxx * wyy);\n        }\n\n        // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n        else if (do_sgrad) {\n          scalar_t* src_ptr_NC = src_ptr_NC0;\n          scalar_t* out_ptr_NCXY = out_ptr_NCXY0;\n          for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) {\n            scalar_t src = bound::get(src_ptr_NC, osxy, sxy);\n            *out_ptr_NCXY += src * (gxx * wyy);\n            out_ptr_NCXY[out_sK] += src * (wxx * gyy);\n          }\n        }\n\n    
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n        else if (do_push) {\n          if (trgt_K == 0) {\n            // Diff w.r.t. push/pull\n            scalar_t* out_ptr_NC = out_ptr_NC0;\n            for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)\n              bound::add(out_ptr_NC, ooxy, (wxx * wyy) * target[c], sxy);\n          } else {\n            // Diff w.r.t. sgrad\n            scalar_t* out_ptr_NC = out_ptr_NC0;\n            for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) {\n              scalar_t val = (gxx * wyy) * target[c] + (wxx * gyy) * target[c + C];\n              bound::add(out_ptr_NC, ooxy, val, sxy);\n            }\n          }\n        }\n\n        // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n        else if (do_count) {\n          bound::add(out_ptr_NC0, ooxy, (wxx * wyy), sxy);\n        }\n\n        // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n        if (do_grad) {\n          if (trgt_K == 0) {\n            // Diff w.r.t. pull/push\n            scalar_t* src_ptr_NC = src_ptr_NC0;\n            scalar_t dot = static_cast<scalar_t>(0);\n            for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {\n              scalar_t src = bound::get(src_ptr_NC, osxy, sxy);\n              dot += (trgt_ptr ? src * target[c] : src);\n              // trgt_ptr == 0 in the backward pass of 'count'\n            }\n            ogx += (gxx * wyy) * dot;\n            ogy += (wxx * gyy) * dot;\n          } else {\n            // Diff w.r.t. 
sgrad\n            scalar_t* src_ptr_NC = src_ptr_NC0;\n            scalar_t dot0, dot1;\n            dot0 = dot1 = static_cast<scalar_t>(0);\n            for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {\n              scalar_t src = bound::get(src_ptr_NC, osxy, sxy);\n              dot0 += src * target[c];\n              dot1 += src * target[c + C];\n            }\n            ogx += (hxx * wyy) * dot0 + (gxx * gyy) * dot1;\n            ogy += (gxx * gyy) * dot0 + (wxx * hyy) * dot1;\n          }\n        }\n\n      } // x\n    } // y\n\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_grad) {\n      scalar_t* grad_ptr_NXY = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY;\n      (*grad_ptr_NXY) = ogx;\n      grad_ptr_NXY[grad_sC] = ogy;\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                     GENERIC INTERPOLATION 1D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate1d(scalar_t x, offset_t w, offset_t n) const {\n    // Get corner pixel values from (x, y)\n    offset_t bx0, bx1;\n    interpolation::bounds(interpolation0, x, bx0, bx1);\n    offset_t dbx = bx1 - bx0;\n\n    // Pre-compute offsets and target value\n    scalar_t* src_ptr_NC0 = src_ptr + n * src_sN;\n    scalar_t* out_ptr_NC0 = out_ptr + n * out_sN;\n    scalar_t* out_ptr_NCX0 = out_ptr + n * out_sN + w * out_sX;\n    scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX;\n    scalar_t target[2 * MONAI_MAX_NUM_CHANNELS];\n    if (trgt_ptr && (do_push || do_grad))\n      for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC) {\n        target[c] = *trgt_ptr_NCX;\n        if (trgt_K > 0) {\n          target[c + C] = trgt_ptr_NCX[trgt_sK];\n        }\n      }\n\n    // Initialize output\n    scalar_t* out_ptr_NCX = out_ptr_NCX0;\n    if (do_pull 
|| do_sgrad) {\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC) {\n        *out_ptr_NCX = static_cast<scalar_t>(0);\n        if (do_sgrad) {\n          out_ptr_NCX[out_sK] = static_cast<scalar_t>(0);\n        }\n      }\n    }\n\n    // Pre-compute indices/weights/grad\n    scalar_t wx[8]; // B-spline weights\n    scalar_t gx[8]; // B-spline derivatives\n    scalar_t hx[8]; // B-spline 2nd derivatives\n    offset_t ix[8]; // Warped indices\n    uint8_t sx[8]; // Boundary signs of the warped indices\n\n    {\n      scalar_t *owx = static_cast<scalar_t*>(wx), *ogx = static_cast<scalar_t*>(gx), *ohx = static_cast<scalar_t*>(hx);\n      offset_t* oix = static_cast<offset_t*>(ix);\n      uint8_t* osx = static_cast<uint8_t*>(sx);\n      for (offset_t bx = bx0; bx <= bx1; ++bx) {\n        scalar_t dx = x - bx;\n        *(owx++) = interpolation::fastweight(interpolation0, dx);\n        if (do_grad || do_sgrad)\n          *(ogx++) = interpolation::fastgrad(interpolation0, dx);\n        if (do_grad && trgt_sK > 1)\n          *(ohx++) = interpolation::fasthess(interpolation0, dx);\n        *(osx++) = bound::sign(bound0, bx, src_X);\n        *(oix++) = bound::index(bound0, bx, src_X);\n      }\n    }\n\n    // Convolve coefficients with basis functions\n    scalar_t ogx;\n    ogx = static_cast<scalar_t>(0);\n    for (offset_t i = 0; i <= dbx; ++i) {\n      offset_t oox = ix[i] * out_sX;\n      offset_t osx = ix[i] * src_sX;\n      uint8_t sxx = sx[i];\n      scalar_t wxx = wx[i];\n      scalar_t gxx = gx[i];\n      scalar_t hxx = hx[i];\n\n      // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n      if (do_pull) {\n        scalar_t* src_ptr_NC = src_ptr_NC0;\n        scalar_t* out_ptr_NCX = out_ptr_NCX0;\n        for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC)\n          *out_ptr_NCX += bound::get(src_ptr_NC, osx, sxx) * wxx;\n      }\n\n      // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n      else if (do_sgrad) 
{\n        scalar_t* src_ptr_NC = src_ptr_NC0;\n        scalar_t* out_ptr_NCX = out_ptr_NCX0;\n        for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) {\n          scalar_t src = bound::get(src_ptr_NC, osx, sxx);\n          *out_ptr_NCX += src * gxx;\n        }\n      }\n\n      // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n      else if (do_push) {\n        if (trgt_K == 0) {\n          // Diff w.r.t. push/pull\n          scalar_t* out_ptr_NC = out_ptr_NC0;\n          for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)\n            bound::add(out_ptr_NC, oox, wxx * target[c], sxx);\n        } else {\n          // Diff w.r.t. sgrad\n          scalar_t* out_ptr_NC = out_ptr_NC0;\n          for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC) {\n            scalar_t val = gxx * target[c];\n            bound::add(out_ptr_NC, oox, val, sxx);\n          }\n        }\n      }\n\n      // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n      else if (do_count) {\n        bound::add(out_ptr_NC0, oox, wxx, sxx);\n      }\n\n      // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n      if (do_grad) {\n        if (trgt_K == 0) {\n          // Diff w.r.t. pull/push\n          scalar_t* src_ptr_NC = src_ptr_NC0;\n          scalar_t dot = static_cast<scalar_t>(0);\n          for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {\n            scalar_t src = bound::get(src_ptr_NC, osx, sxx);\n            dot += (trgt_ptr ? src * target[c] : src);\n            // trgt_ptr == 0 in the backward pass of 'count'\n          }\n          ogx += gxx * dot;\n        } else {\n          // Diff w.r.t. 
sgrad\n          scalar_t* src_ptr_NC = src_ptr_NC0;\n          scalar_t dot;\n          dot = static_cast<scalar_t>(0);\n          for (offset_t c = 0; c < C; ++c, src_ptr_NC += src_sC) {\n            scalar_t src = bound::get(src_ptr_NC, osx, sxx);\n            dot += src * target[c];\n          }\n          ogx += hxx * dot;\n        }\n      }\n\n    } // x\n\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Grad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_grad) {\n      scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX;\n      (*grad_ptr_NX) = ogx;\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                     LINEAR INTERPOLATION 3D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate3d_trilinear(\n      scalar_t x,\n      scalar_t y,\n      scalar_t z,\n      offset_t w,\n      offset_t h,\n      offset_t d,\n      offset_t n) const {\n    // Get corner pixel values from (x, y, z)\n    offset_t ix0 = static_cast<offset_t>(std::floor(x));\n    offset_t iy0 = static_cast<offset_t>(std::floor(y));\n    offset_t iz0 = static_cast<offset_t>(std::floor(z));\n\n    // Interpolation weights (inversely proportional to distance)\n    scalar_t dx1 = x - ix0;\n    scalar_t dy1 = y - iy0;\n    scalar_t dz1 = z - iz0;\n    scalar_t dx0 = 1. - dx1;\n    scalar_t dy0 = 1. - dy1;\n    scalar_t dz0 = 1. 
- dz1;\n    scalar_t w000 = dx0 * dy0 * dz0;\n    scalar_t w100 = dx1 * dy0 * dz0;\n    scalar_t w010 = dx0 * dy1 * dz0;\n    scalar_t w001 = dx0 * dy0 * dz1;\n    scalar_t w110 = dx1 * dy1 * dz0;\n    scalar_t w011 = dx0 * dy1 * dz1;\n    scalar_t w101 = dx1 * dy0 * dz1;\n    scalar_t w111 = dx1 * dy1 * dz1;\n\n    // Sign (/!\\ compute sign before warping indices)\n    int8_t sx1 = bound::sign(bound0, ix0 + 1, src_X);\n    int8_t sy1 = bound::sign(bound1, iy0 + 1, src_Y);\n    int8_t sz1 = bound::sign(bound2, iz0 + 1, src_Z);\n    int8_t sx0 = bound::sign(bound0, ix0, src_X);\n    int8_t sy0 = bound::sign(bound1, iy0, src_Y);\n    int8_t sz0 = bound::sign(bound2, iz0, src_Z);\n    int8_t s000 = sx0 * sy0 * sz0;\n    int8_t s100 = sx1 * sy0 * sz0;\n    int8_t s010 = sx0 * sy1 * sz0;\n    int8_t s001 = sx0 * sy0 * sz1;\n    int8_t s110 = sx1 * sy1 * sz0;\n    int8_t s011 = sx0 * sy1 * sz1;\n    int8_t s101 = sx1 * sy0 * sz1;\n    int8_t s111 = sx1 * sy1 * sz1;\n\n    // Warp indices\n    offset_t ix1, iy1, iz1;\n    ix1 = bound::index(bound0, ix0 + 1, src_X);\n    iy1 = bound::index(bound1, iy0 + 1, src_Y);\n    iz1 = bound::index(bound2, iz0 + 1, src_Z);\n    ix0 = bound::index(bound0, ix0, src_X);\n    iy0 = bound::index(bound1, iy0, src_Y);\n    iz0 = bound::index(bound2, iz0, src_Z);\n\n    offset_t o000, o100, o010, o001, o110, o011, o101, o111;\n\n    if (do_pull || do_grad || do_sgrad) {\n      // Offsets into source volume\n      o000 = ix0 * src_sX + iy0 * src_sY + iz0 * src_sZ;\n      o100 = ix1 * src_sX + iy0 * src_sY + iz0 * src_sZ;\n      o010 = ix0 * src_sX + iy1 * src_sY + iz0 * src_sZ;\n      o001 = ix0 * src_sX + iy0 * src_sY + iz1 * src_sZ;\n      o110 = ix1 * src_sX + iy1 * src_sY + iz0 * src_sZ;\n      o011 = ix0 * src_sX + iy1 * src_sY + iz1 * src_sZ;\n      o101 = ix1 * src_sX + iy0 * src_sY + iz1 * src_sZ;\n      o111 = ix1 * src_sX + iy1 * src_sY + iz1 * src_sZ;\n    } else if (!(do_push || do_count)) {\n      o000 = o100 = o010 = o001 = 
o110 = o011 = o101 = o111 = 0;\n    }\n\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_grad) {\n      scalar_t gx = static_cast<scalar_t>(0);\n      scalar_t gy = static_cast<scalar_t>(0);\n      scalar_t gz = static_cast<scalar_t>(0);\n      scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n\n      if (trgt_K == 0) {\n        // backward w.r.t. push/pull\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, src_ptr_NC += src_sC) {\n          scalar_t src;\n          scalar_t trgt = trgt_ptr ? *trgt_ptr_NCXYZ : static_cast<scalar_t>(1);\n          // ^ trgt_ptr == 0 during the backward pass of count\n          src = bound::get(src_ptr_NC, o000, s000);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= dy0 * dz0 * src;\n          gy -= dx0 * dz0 * src;\n          gz -= dx0 * dy0 * src;\n          src = bound::get(src_ptr_NC, o100, s100);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += dy0 * dz0 * src;\n          gy -= dx1 * dz0 * src;\n          gz -= dx1 * dy0 * src;\n          src = bound::get(src_ptr_NC, o010, s010);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= dy1 * dz0 * src;\n          gy += dx0 * dz0 * src;\n          gz -= dx0 * dy1 * src;\n          src = bound::get(src_ptr_NC, o110, s110);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += dy1 * dz0 * src;\n          gy += dx1 * dz0 * src;\n          gz -= dx1 * dy1 * src;\n          src = bound::get(src_ptr_NC, o001, s001);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= dy0 * dz1 * src;\n          gy -= dx0 * dz1 * src;\n          gz += dx0 * dy0 * src;\n          src = bound::get(src_ptr_NC, o101, s101);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += dy0 * dz1 * src;\n          gy -= dx1 * dz1 * src;\n          gz += dx1 * dy0 * src;\n    
      src = bound::get(src_ptr_NC, o011, s011);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= dy1 * dz1 * src;\n          gy += dx0 * dz1 * src;\n          gz += dx0 * dy1 * src;\n          src = bound::get(src_ptr_NC, o111, s111);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += dy1 * dz1 * src;\n          gy += dx1 * dz1 * src;\n          gz += dx1 * dy1 * src;\n        }\n      } else {\n        // backward w.r.t. sgrad\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, src_ptr_NC += src_sC) {\n          scalar_t src;\n          scalar_t trgt0 = *trgt_ptr_NCXYZ, trgt1 = trgt_ptr_NCXYZ[trgt_sK], trgt2 = trgt_ptr_NCXYZ[trgt_sK * 2];\n          src = bound::get(src_ptr_NC, o000, s000);\n          gx += (dz0 * trgt1 + dy0 * trgt2) * src;\n          gy += (dz0 * trgt0 + dx0 * trgt2) * src;\n          gz += (dy0 * trgt0 + dx0 * trgt1) * src;\n          src = bound::get(src_ptr_NC, o100, s100);\n          gx += (-dz0 * trgt1 - dy0 * trgt2) * src;\n          gy += (-dz0 * trgt0 + dx1 * trgt2) * src;\n          gz += (-dy0 * trgt0 + dx1 * trgt1) * src;\n          src = bound::get(src_ptr_NC, o010, s010);\n          gx += (-dz0 * trgt1 + dy1 * trgt2) * src;\n          gy += (-dz0 * trgt0 - dx0 * trgt2) * src;\n          gz += (dy1 * trgt0 - dx0 * trgt1) * src;\n          src = bound::get(src_ptr_NC, o110, s110);\n          gx += (dz0 * trgt1 - dy1 * trgt2) * src;\n          gy += (dz0 * trgt0 - dx1 * trgt2) * src;\n          gz += (-dy1 * trgt0 - dx1 * trgt1) * src;\n          src = bound::get(src_ptr_NC, o001, s001);\n          gx += (dz1 * trgt1 - dy0 * trgt2) * src;\n          gy += (dz1 * trgt0 - dx0 * trgt2) * src;\n          gz += (-dy0 * trgt0 - dx0 * trgt1) * src;\n          src = bound::get(src_ptr_NC, o101, s101);\n          gx += (-dz1 * trgt1 + dy0 * trgt2) * src;\n          gy += (-dz1 * trgt0 - dx1 * trgt2) * src;\n          gz += (dy0 * trgt0 - dx1 * trgt1) * src;\n          src = 
bound::get(src_ptr_NC, o011, s011);\n          gx += (-dz1 * trgt1 - dy1 * trgt2) * src;\n          gy += (-dz1 * trgt0 + dx0 * trgt2) * src;\n          gz += (-dy1 * trgt0 + dx0 * trgt1) * src;\n          src = bound::get(src_ptr_NC, o111, s111);\n          gx += (dz1 * trgt1 + dy1 * trgt2) * src;\n          gy += (dz1 * trgt0 + dx1 * trgt2) * src;\n          gz += (dy1 * trgt0 + dx1 * trgt1) * src;\n        }\n      }\n\n      scalar_t* grad_ptr_NXYZ = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY + d * grad_sZ;\n      (*grad_ptr_NXYZ) = gx;\n      grad_ptr_NXYZ[grad_sC] = gy;\n      grad_ptr_NXYZ[grad_sC * 2] = gz;\n    }\n    if (do_push || do_count) {\n      // Offsets into 'push' volume\n      o000 = ix0 * out_sX + iy0 * out_sY + iz0 * out_sZ;\n      o100 = ix1 * out_sX + iy0 * out_sY + iz0 * out_sZ;\n      o010 = ix0 * out_sX + iy1 * out_sY + iz0 * out_sZ;\n      o001 = ix0 * out_sX + iy0 * out_sY + iz1 * out_sZ;\n      o110 = ix1 * out_sX + iy1 * out_sY + iz0 * out_sZ;\n      o011 = ix0 * out_sX + iy1 * out_sY + iz1 * out_sZ;\n      o101 = ix1 * out_sX + iy0 * out_sY + iz1 * out_sZ;\n      o111 = ix1 * out_sX + iy1 * out_sY + iz1 * out_sZ;\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_pull) {\n      scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) {\n        *out_ptr_NCXYZ = bound::get(src_ptr_NC, o000, s000) * w000 + bound::get(src_ptr_NC, o100, s100) * w100 +\n            bound::get(src_ptr_NC, o010, s010) * w010 + bound::get(src_ptr_NC, o110, s110) * w110 +\n            bound::get(src_ptr_NC, o001, s001) * w001 + bound::get(src_ptr_NC, o101, s101) * w101 +\n            bound::get(src_ptr_NC, o011, s011) * w011 + bound::get(src_ptr_NC, o111, s111) * w111;\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad 
~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~\n    else if (do_sgrad) {\n      scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC) {\n        scalar_t src000 = bound::get(src_ptr_NC, o000, s000);\n        scalar_t src100 = bound::get(src_ptr_NC, o100, s100);\n        scalar_t src010 = bound::get(src_ptr_NC, o010, s010);\n        scalar_t src110 = bound::get(src_ptr_NC, o110, s110);\n        scalar_t src001 = bound::get(src_ptr_NC, o001, s001);\n        scalar_t src101 = bound::get(src_ptr_NC, o101, s101);\n        scalar_t src011 = bound::get(src_ptr_NC, o011, s011);\n        scalar_t src111 = bound::get(src_ptr_NC, o111, s111);\n        *out_ptr_NCXYZ = -dy0 * dz0 * src000 + dy0 * dz0 * src100 - dy1 * dz0 * src010 + dy1 * dz0 * src110 -\n            dy0 * dz1 * src001 + dy0 * dz1 * src101 - dy1 * dz1 * src011 + dy1 * dz1 * src111;\n        out_ptr_NCXYZ[out_sK] = -dx0 * dz0 * src000 - dx1 * dz0 * src100 + dx0 * dz0 * src010 + dx1 * dz0 * src110 -\n            dx0 * dz1 * src001 - dx1 * dz1 * src101 + dx0 * dz1 * src011 + dx1 * dz1 * src111;\n        out_ptr_NCXYZ[out_sK * 2] = -dx0 * dy0 * src000 - dx1 * dy0 * src100 - dx0 * dy1 * src010 - dx1 * dy1 * src110 +\n            dx0 * dy0 * src001 + dx1 * dy0 * src101 + dx0 * dy1 * src011 + dx1 * dy1 * src111;\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_push) {\n      scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      if (trgt_K == 0) {\n        // Diff w.r.t. 
push/pull\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, out_ptr_NC += out_sC) {\n          scalar_t trgt = *trgt_ptr_NCXYZ;\n          bound::add(out_ptr_NC, o000, w000 * trgt, s000);\n          bound::add(out_ptr_NC, o100, w100 * trgt, s100);\n          bound::add(out_ptr_NC, o010, w010 * trgt, s010);\n          bound::add(out_ptr_NC, o110, w110 * trgt, s110);\n          bound::add(out_ptr_NC, o001, w001 * trgt, s001);\n          bound::add(out_ptr_NC, o101, w101 * trgt, s101);\n          bound::add(out_ptr_NC, o011, w011 * trgt, s011);\n          bound::add(out_ptr_NC, o111, w111 * trgt, s111);\n        }\n      } else {\n        // Diff w.r.t. sgrad\n        scalar_t val;\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, out_ptr_NC += out_sC) {\n          scalar_t trgt0 = *trgt_ptr_NCXYZ, trgt1 = trgt_ptr_NCXYZ[trgt_sK], trgt2 = trgt_ptr_NCXYZ[trgt_sK * 2];\n          val = -dy0 * dz0 * trgt0 - dx0 * dz0 * trgt1 - dx0 * dy0 * trgt2;\n          bound::add(out_ptr_NC, o000, val, s000);\n          val = dy0 * dz0 * trgt0 - dx1 * dz0 * trgt1 - dx1 * dy0 * trgt2;\n          bound::add(out_ptr_NC, o100, val, s100);\n          val = -dy1 * dz0 * trgt0 + dx0 * dz0 * trgt1 - dx0 * dy1 * trgt2;\n          bound::add(out_ptr_NC, o010, val, s010);\n          val = dy1 * dz0 * trgt0 + dx1 * dz0 * trgt1 - dx1 * dy1 * trgt2;\n          bound::add(out_ptr_NC, o110, val, s110);\n          val = -dy0 * dz1 * trgt0 - dx0 * dz1 * trgt1 + dx0 * dy0 * trgt2;\n          bound::add(out_ptr_NC, o001, val, s001);\n          val = dy0 * dz1 * trgt0 - dx1 * dz1 * trgt1 + dx1 * dy0 * trgt2;\n          bound::add(out_ptr_NC, o101, val, s101);\n          val = -dy1 * dz1 * trgt0 + dx0 * dz1 * trgt1 + dx0 * dy1 * trgt2;\n          bound::add(out_ptr_NC, o011, val, s011);\n          val = dy1 * dz1 * trgt0 + dx1 * dz1 * trgt1 + dx1 * dy1 * trgt2;\n          bound::add(out_ptr_NC, o111, val, s111);\n        }\n      }\n    }\n    // 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_count) {\n      scalar_t* out_ptr_N = out_ptr + n * out_sN;\n      bound::add(out_ptr_N, o000, w000, s000);\n      bound::add(out_ptr_N, o100, w100, s100);\n      bound::add(out_ptr_N, o010, w010, s010);\n      bound::add(out_ptr_N, o110, w110, s110);\n      bound::add(out_ptr_N, o001, w001, s001);\n      bound::add(out_ptr_N, o101, w101, s101);\n      bound::add(out_ptr_N, o011, w011, s011);\n      bound::add(out_ptr_N, o111, w111, s111);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                     LINEAR INTERPOLATION 2D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate2d_bilinear(\n      scalar_t x,\n      scalar_t y,\n      offset_t w,\n      offset_t h,\n      offset_t n) const {\n    // Get corner pixel values from (x, y)\n    offset_t ix0 = static_cast<offset_t>(std::floor(x));\n    offset_t iy0 = static_cast<offset_t>(std::floor(y));\n\n    // Interpolation weights (inversely proportional to distance)\n    scalar_t dx1 = x - ix0;\n    scalar_t dy1 = y - iy0;\n    scalar_t dx0 = 1. - dx1;\n    scalar_t dy0 = 1. 
- dy1;\n    scalar_t w00 = dx0 * dy0;\n    scalar_t w10 = dx1 * dy0;\n    scalar_t w01 = dx0 * dy1;\n    scalar_t w11 = dx1 * dy1;\n\n    // Sign (/!\\ compute sign before warping indices)\n    int8_t sx1 = bound::sign(bound0, ix0 + 1, src_X);\n    int8_t sy1 = bound::sign(bound1, iy0 + 1, src_Y);\n    int8_t sx0 = bound::sign(bound0, ix0, src_X);\n    int8_t sy0 = bound::sign(bound1, iy0, src_Y);\n    int8_t s00 = sx0 * sy0;\n    int8_t s10 = sx1 * sy0;\n    int8_t s01 = sx0 * sy1;\n    int8_t s11 = sx1 * sy1;\n\n    // Warp indices\n    offset_t ix1, iy1;\n    ix1 = bound::index(bound0, ix0 + 1, src_X);\n    iy1 = bound::index(bound1, iy0 + 1, src_Y);\n    ix0 = bound::index(bound0, ix0, src_X);\n    iy0 = bound::index(bound1, iy0, src_Y);\n\n    offset_t o00, o10, o01, o11;\n    if (do_pull || do_grad || do_sgrad) {\n      // Offsets into source volume\n      o00 = ix0 * src_sX + iy0 * src_sY;\n      o10 = ix1 * src_sX + iy0 * src_sY;\n      o01 = ix0 * src_sX + iy1 * src_sY;\n      o11 = ix1 * src_sX + iy1 * src_sY;\n    } else if (!(do_push || do_count)) {\n      o00 = o10 = o01 = o11 = 0;\n    }\n\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_grad) {\n      scalar_t gx = static_cast<scalar_t>(0);\n      scalar_t gy = static_cast<scalar_t>(0);\n      scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n\n      if (trgt_K == 0) {\n        // backward w.r.t. push/pull\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, src_ptr_NC += src_sC) {\n          scalar_t src;\n          scalar_t trgt = trgt_ptr ? 
*trgt_ptr_NCXY : static_cast<scalar_t>(1);\n          // ^ trgt_ptr == 0 during the backward pass of count\n          src = bound::get(src_ptr_NC, o00, s00);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= dy0 * src;\n          gy -= dx0 * src;\n          src = bound::get(src_ptr_NC, o10, s10);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += dy0 * src;\n          gy -= dx1 * src;\n          src = bound::get(src_ptr_NC, o01, s01);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= dy1 * src;\n          gy += dx0 * src;\n          src = bound::get(src_ptr_NC, o11, s11);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += dy1 * src;\n          gy += dx1 * src;\n        }\n      } else {\n        // backward w.r.t. sgrad\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, src_ptr_NC += src_sC) {\n          scalar_t src;\n          scalar_t trgt0 = *trgt_ptr_NCXY, trgt1 = trgt_ptr_NCXY[trgt_sK];\n          src = bound::get(src_ptr_NC, o00, s00);\n          gx += trgt1 * src;\n          gy += trgt0 * src;\n          src = bound::get(src_ptr_NC, o10, s10);\n          gx -= trgt1 * src;\n          gy -= trgt0 * src;\n          src = bound::get(src_ptr_NC, o01, s01);\n          gx -= trgt1 * src;\n          gy -= trgt0 * src;\n          src = bound::get(src_ptr_NC, o11, s11);\n          gx += trgt1 * src;\n          gy += trgt0 * src;\n        }\n      }\n\n      scalar_t* grad_ptr_NXY = grad_ptr + n * grad_sN + w * grad_sX + h * grad_sY;\n      (*grad_ptr_NXY) = gx;\n      grad_ptr_NXY[grad_sC] = gy;\n    }\n    if (do_push || do_count) {\n      // Offsets into 'push' volume\n      o00 = ix0 * out_sX + iy0 * out_sY;\n      o10 = ix1 * out_sX + iy0 * out_sY;\n      o01 = ix0 * out_sX + iy1 * out_sY;\n      o11 = ix1 * out_sX + iy1 * out_sY;\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_pull) {\n      scalar_t* out_ptr_NCXY = out_ptr + n * 
out_sN + w * out_sX + h * out_sY;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) {\n        *out_ptr_NCXY = bound::get(src_ptr_NC, o00, s00) * w00 + bound::get(src_ptr_NC, o10, s10) * w10 +\n            bound::get(src_ptr_NC, o01, s01) * w01 + bound::get(src_ptr_NC, o11, s11) * w11;\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_sgrad) {\n      scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC) {\n        scalar_t src00 = bound::get(src_ptr_NC, o00, s00);\n        scalar_t src10 = bound::get(src_ptr_NC, o10, s10);\n        scalar_t src01 = bound::get(src_ptr_NC, o01, s01);\n        scalar_t src11 = bound::get(src_ptr_NC, o11, s11);\n        *out_ptr_NCXY = -dy0 * src00 + dy0 * src10 - dy1 * src01 + dy1 * src11;\n        out_ptr_NCXY[out_sK] = -dx0 * src00 - dx1 * src10 + dx0 * src01 + dx1 * src11;\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_push) {\n      scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      if (trgt_K == 0) {\n        // Diff w.r.t. push/pull\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, out_ptr_NC += out_sC) {\n          scalar_t trgt = *trgt_ptr_NCXY;\n          bound::add(out_ptr_NC, o00, w00 * trgt, s00);\n          bound::add(out_ptr_NC, o10, w10 * trgt, s10);\n          bound::add(out_ptr_NC, o01, w01 * trgt, s01);\n          bound::add(out_ptr_NC, o11, w11 * trgt, s11);\n        }\n      } else {\n        // Diff w.r.t. 
sgrad\n        scalar_t val;\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, out_ptr_NC += out_sC) {\n          scalar_t trgt0 = *trgt_ptr_NCXY, trgt1 = trgt_ptr_NCXY[trgt_sK];\n          val = -dy0 * trgt0 - dx0 * trgt1;\n          bound::add(out_ptr_NC, o00, val, s00);\n          val = dy0 * trgt0 - dx1 * trgt1;\n          bound::add(out_ptr_NC, o10, val, s10);\n          val = -dy1 * trgt0 + dx0 * trgt1;\n          bound::add(out_ptr_NC, o01, val, s01);\n          val = dy1 * trgt0 + dx1 * trgt1;\n          bound::add(out_ptr_NC, o11, val, s11);\n        }\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_count) {\n      scalar_t* out_ptr_N = out_ptr + n * out_sN;\n      bound::add(out_ptr_N, o00, w00, s00);\n      bound::add(out_ptr_N, o10, w10, s10);\n      bound::add(out_ptr_N, o01, w01, s01);\n      bound::add(out_ptr_N, o11, w11, s11);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                     LINEAR INTERPOLATION 1D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate1d_linear(scalar_t x, offset_t w, offset_t n) const {\n    // Get corner pixel values from (x)\n    offset_t ix0 = static_cast<offset_t>(std::floor(x));\n\n    // Interpolation weights (inversely proportional to distance)\n    scalar_t w1 = x - ix0;\n    scalar_t w0 = 1. 
- w1;\n\n    // Sign (/!\\ compute sign before warping indices)\n    int8_t s1 = bound::sign(bound0, ix0 + 1, src_X);\n    int8_t s0 = bound::sign(bound0, ix0, src_X);\n\n    // Warp indices\n    offset_t ix1;\n    ix1 = bound::index(bound0, ix0 + 1, src_X);\n    ix0 = bound::index(bound0, ix0, src_X);\n\n    // Offsets into source volume\n    offset_t o0, o1;\n    if (do_pull || do_grad || do_sgrad) {\n      o0 = ix0 * src_sX;\n      o1 = ix1 * src_sX;\n    } else if (!(do_push || do_count)) {\n      o0 = o1 = 0;\n    }\n\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Grid gradient ~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_grad) {\n      if (trgt_K == 0) {\n        // backward w.r.t. push/pull\n        scalar_t gx = static_cast<scalar_t>(0);\n        scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX;\n        scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, src_ptr_NC += src_sC) {\n          scalar_t src;\n          scalar_t trgt = trgt_ptr ? *trgt_ptr_NCX : static_cast<scalar_t>(1);\n          // ^ trgt_ptr == 0 during the backward pass of count\n          src = bound::get(src_ptr_NC, o0, s0);\n          if (trgt_ptr)\n            src *= trgt;\n          gx -= src;\n          src = bound::get(src_ptr_NC, o1, s1);\n          if (trgt_ptr)\n            src *= trgt;\n          gx += src;\n        }\n\n        scalar_t* grad_ptr_NX = grad_ptr + n * grad_sN + w * grad_sX;\n        (*grad_ptr_NX) = gx;\n      } else {\n        // backward w.r.t. 
sgrad\n        // -> zero (make sure this is done at initialization)\n      }\n    }\n    if (do_push || do_count) {\n      // Offsets into 'push' volume\n      o0 = ix0 * out_sX;\n      o1 = ix1 * out_sX;\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Pull ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    if (do_pull) {\n      scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) {\n        *out_ptr_NCX = bound::get(src_ptr_NC, o0, s0) * w0 + bound::get(src_ptr_NC, o1, s1) * w1;\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SGrad ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_sgrad) {\n      scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC) {\n        *out_ptr_NCX = bound::get(src_ptr_NC, o1, s1) - bound::get(src_ptr_NC, o0, s0);\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Push ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_push) {\n      scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      if (trgt_K == 0) {\n        // Diff w.r.t. push/pull\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC) {\n          scalar_t trgt = *trgt_ptr_NCX;\n          bound::add(out_ptr_NC, o0, w0 * trgt, s0);\n          bound::add(out_ptr_NC, o1, w1 * trgt, s1);\n        }\n      } else {\n        // Diff w.r.t. 
sgrad\n        for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC) {\n          scalar_t trgt0 = *trgt_ptr_NCX;\n          bound::add(out_ptr_NC, o0, -trgt0, s0);\n          bound::add(out_ptr_NC, o1, trgt0, s1);\n        }\n      }\n    }\n    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n    else if (do_count) {\n      scalar_t* out_ptr_N = out_ptr + n * out_sN;\n      bound::add(out_ptr_N, o0, w0, s0);\n      bound::add(out_ptr_N, o1, w1, s1);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                  NEAREST NEIGHBOR INTERPOLATION 3D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate3d_nearest(\n      scalar_t x,\n      scalar_t y,\n      scalar_t z,\n      offset_t w,\n      offset_t h,\n      offset_t d,\n      offset_t n) const {\n    offset_t ix = static_cast<offset_t>(std::round(x));\n    offset_t iy = static_cast<offset_t>(std::round(y));\n    offset_t iz = static_cast<offset_t>(std::round(z));\n\n    // Boundary condition (/!\\ compute sign before warping indices)\n    int8_t sx = bound::sign(bound0, ix, src_X);\n    int8_t sy = bound::sign(bound1, iy, src_Y);\n    int8_t sz = bound::sign(bound2, iz, src_Z);\n    ix = bound::index(bound0, ix, src_X);\n    iy = bound::index(bound1, iy, src_Y);\n    iz = bound::index(bound2, iz, src_Z);\n\n    // Sign\n    int8_t s = sz * sy * sx;\n\n    if (do_pull) {\n      offset_t o = iz * src_sZ + iy * src_sY + ix * src_sX;\n      scalar_t* out_ptr_NCXYZ = out_ptr + n * out_sN + w * out_sX + h * out_sY + d * out_sZ;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXYZ += out_sC, src_ptr_NC += src_sC)\n        *out_ptr_NCXYZ = bound::get(src_ptr_NC, o, s);\n    } else if (do_push && trgt_K == 0) {\n      offset_t o = 
iz * out_sZ + iy * out_sY + ix * out_sX;\n      scalar_t* trgt_ptr_NCXYZ = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY + d * trgt_sZ;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXYZ += trgt_sC, out_ptr_NC += out_sC)\n        bound::add(out_ptr_NC, o, *trgt_ptr_NCXYZ, s);\n    } else if (do_count) {\n      offset_t o = iz * out_sZ + iy * out_sY + ix * out_sX;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)\n        bound::add(out_ptr_NC, o, static_cast<scalar_t>(1), s);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                  NEAREST NEIGHBOR INTERPOLATION 2D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate2d_nearest(\n      scalar_t x,\n      scalar_t y,\n      offset_t w,\n      offset_t h,\n      offset_t n) const {\n    offset_t ix = static_cast<offset_t>(std::round(x));\n    offset_t iy = static_cast<offset_t>(std::round(y));\n\n    // Boundary condition (/!\\ compute sign before warping indices)\n    int8_t sx = bound::sign(bound0, ix, src_X);\n    int8_t sy = bound::sign(bound1, iy, src_Y);\n    ix = bound::index(bound0, ix, src_X);\n    iy = bound::index(bound1, iy, src_Y);\n\n    // Sign\n    int8_t s = sy * sx;\n\n    if (do_pull) {\n      offset_t o = iy * src_sY + ix * src_sX;\n      scalar_t* out_ptr_NCXY = out_ptr + n * out_sN + w * out_sX + h * out_sY;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCXY += out_sC, src_ptr_NC += src_sC)\n        *out_ptr_NCXY = bound::get(src_ptr_NC, o, s);\n    } else if (do_push && trgt_K == 0) {\n      offset_t o = iy * out_sY + ix * out_sX;\n      scalar_t* trgt_ptr_NCXY = trgt_ptr + n * trgt_sN + w * trgt_sX + h * trgt_sY;\n      
scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      for (offset_t c = 0; c < C; ++c, trgt_ptr_NCXY += trgt_sC, out_ptr_NC += out_sC)\n        bound::add(out_ptr_NC, o, *trgt_ptr_NCXY, s);\n    } else if (do_count) {\n      offset_t o = iy * out_sY + ix * out_sX;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)\n        bound::add(out_ptr_NC, o, static_cast<scalar_t>(1), s);\n    }\n  }\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                  NEAREST NEIGHBOR INTERPOLATION 1D\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  template <typename scalar_t, typename offset_t>\n  MONAI_DEVICE void PushPullImpl<scalar_t, offset_t>::interpolate1d_nearest(scalar_t x, offset_t w, offset_t n) const {\n    offset_t i = static_cast<offset_t>(std::round(x));\n\n    // Boundary condition (/!\\ compute sign before warping indices)\n    int8_t s = bound::sign(bound0, i, src_X);\n    i = bound::index(bound0, i, src_X);\n\n    if (do_pull) {\n      offset_t o = i * src_sX;\n      scalar_t* out_ptr_NCX = out_ptr + n * out_sN + w * out_sX;\n      scalar_t* src_ptr_NC = src_ptr + n * src_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NCX += out_sC, src_ptr_NC += src_sC)\n        *out_ptr_NCX = bound::get(src_ptr_NC, o, s);\n    } else if (do_push && trgt_K == 0) {\n      offset_t o = i * out_sX;\n      scalar_t* trgt_ptr_NCX = trgt_ptr + n * trgt_sN + w * trgt_sX;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      for (offset_t c = 0; c < C; ++c, trgt_ptr_NCX += trgt_sC, out_ptr_NC += out_sC)\n        bound::add(out_ptr_NC, o, *trgt_ptr_NCX, s);\n    } else if (do_count) {\n      offset_t o = i * out_sX;\n      scalar_t* out_ptr_NC = out_ptr + n * out_sN;\n      for (offset_t c = 0; c < C; ++c, out_ptr_NC += out_sC)\n        bound::add(out_ptr_NC, o, static_cast<scalar_t>(1), s);\n    }\n  }\n\n  // 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //            LINEAR INTERPOLATION 3D + SLIDING BOUNDARY\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  // TODO\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                  CUDA KERNEL (MUST BE OUT OF CLASS)\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n  // CUDA Kernel\n  template <typename scalar_t, typename offset_t>\n  C10_LAUNCH_BOUNDS_1(1024)\n  __global__ void pushpull_kernel(PushPullImpl<scalar_t, offset_t> f) {\n    f.loop(threadIdx.x, blockIdx.x, blockDim.x, gridDim.x);\n  }\n\n  } // namespace\n\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n  //                    FUNCTIONAL FORM WITH DISPATCH\n  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n#define PUSHPULL_INSTANTIATE3(BoundType0, InterpolationType0, SourceType0) \\\n  template std::deque<Tensor> pushpull(                                    \\\n      const SourceType0&,                                                  \\\n      const Tensor&,                                                       \\\n      const Tensor&,                                                       \\\n      BoundType0,                                                          \\\n      InterpolationType0,                                                  \\\n      bool,                                                                \\\n      bool,                                                                \\\n      bool,                                                                \\\n      bool,                                                                \\\n      bool,                                                                \\\n      bool);                                                               \\\n  template std::deque<Tensor> pushpull(                                  
  \\\n      const SourceType0&, const Tensor&, BoundType0, InterpolationType0, bool, bool, bool, bool, bool, bool)\n#define PUSHPULL_INSTANTIATE2(BoundType0, InterpolationType0)         \\\n  PUSHPULL_INSTANTIATE3(BoundType0, InterpolationType0, IntArrayRef); \\\n  PUSHPULL_INSTANTIATE3(BoundType0, InterpolationType0, Tensor)\n#define PUSHPULL_INSTANTIATE1(BoundType0)               \\\n  PUSHPULL_INSTANTIATE2(BoundType0, InterpolationType); \\\n  PUSHPULL_INSTANTIATE2(BoundType0, InterpolationVectorRef)\n#define PUSHPULL_INSTANTIATE        \\\n  PUSHPULL_INSTANTIATE1(BoundType); \\\n  PUSHPULL_INSTANTIATE1(BoundVectorRef)\n\n  // Two arguments (source, grid)\n  // > `bound` and `interpolation` can be single arguments or vectors.\n  template <typename BoundType, typename InterpolationType, typename SourceType>\n  MONAI_HOST std::deque<Tensor> pushpull(\n      const SourceType& source,\n      const Tensor& grid,\n      BoundType bound,\n      InterpolationType interpolation,\n      bool extrapolate,\n      bool do_pull,\n      bool do_push,\n      bool do_count,\n      bool do_grad,\n      bool do_sgrad) {\n    PushPullAllocator info(\n        grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad);\n    info.ioset(source, grid);\n\n    return AT_DISPATCH_FLOATING_TYPES_AND_HALF(grid.scalar_type(), \"pushpull\", [&] {\n      if (info.canUse32BitIndexMath()) {\n        PushPullImpl<scalar_t, int32_t> algo(info);\n        pushpull_kernel<<<GET_BLOCKS(algo.voxcount()), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(algo);\n        return algo.output;\n      } else {\n        PushPullImpl<scalar_t, int64_t> algo(info);\n        pushpull_kernel<<<GET_BLOCKS(algo.voxcount()), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(algo);\n        return algo.output;\n      }\n    });\n  }\n\n  // Three arguments (source, grid, target)\n  // > `bound` and `interpolation` can be single arguments or vectors.\n  // > 
`source` can be a tensor or a vector of dimensions.\n  template <typename BoundType, typename InterpolationType, typename SourceType>\n  MONAI_HOST std::deque<Tensor> pushpull(\n      const SourceType& source,\n      const Tensor& grid,\n      const Tensor& target,\n      BoundType bound,\n      InterpolationType interpolation,\n      bool extrapolate,\n      bool do_pull,\n      bool do_push,\n      bool do_count,\n      bool do_grad,\n      bool do_sgrad) {\n    PushPullAllocator info(\n        grid.dim() - 2, bound, interpolation, extrapolate, do_pull, do_push, do_count, do_grad, do_sgrad);\n    info.ioset(source, grid, target);\n\n    return AT_DISPATCH_FLOATING_TYPES_AND_HALF(grid.scalar_type(), \"pushpull\", [&] {\n      if (info.canUse32BitIndexMath()) {\n        PushPullImpl<scalar_t, int32_t> algo(info);\n        pushpull_kernel<<<GET_BLOCKS(algo.voxcount()), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(algo);\n        return algo.output;\n      } else {\n        PushPullImpl<scalar_t, int64_t> algo(info);\n        pushpull_kernel<<<GET_BLOCKS(algo.voxcount()), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(algo);\n        return algo.output;\n      }\n    });\n  }\n\n  PUSHPULL_INSTANTIATE;\n\n} // namespace gpu\n} // namespace monai\n"
  },
  {
    "path": "monai/csrc/utils/common_utils.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#pragma once\n#include <torch/extension.h>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor.\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous.\")\n#define CHECK_CONTIGUOUS_CUDA(x) \\\n  CHECK_CUDA(x);                 \\\n  CHECK_CONTIGUOUS(x)\n#define CHECK_DEFINED(value) \\\n  TORCH_CHECK(value.defined(), \"(): expected \" #value \" not be undefined, but it is \", value);\n#define CHECK_STRIDED(value)                                              \\\n  TORCH_CHECK(                                                            \\\n      value.layout() == at::kStrided,                                     \\\n      \"(): expected \" #value \"to have torch.strided layout, but it has \", \\\n      value.layout());\n#define CHECK_SPATIAL_1D_2D_OR_3D(value)                                \\\n  TORCH_CHECK(                                                          \\\n      (value.dim() == 3 || value.dim() == 4 || value.dim() == 5),       \\\n      \"(): expected 3D, 4D or 5D \" #value \" but got input with sizes \", \\\n      value.sizes());\n#define CHECK_GRID_COMPONENT(value, dim)           \\\n  TORCH_CHECK(                                     \\\n      value.size(-1) == dim - 2,                   \\\n      \"(): expected \" #value \" to have size \",     \\\n      dim - 2,                                     \\\n      \" in 
last \"                                  \\\n      \"dimension, but got \" #value \" with sizes \", \\\n      value.sizes());\n#define CHECK_SAME_DEVICE(value1, value2)     \\\n  TORCH_CHECK(                                \\\n      value1.device() == value2.device(),     \\\n      \"(): expected \" #value1 \" and \" #value2 \\\n      \" to be on same device, \"               \\\n      \"but \" #value1 \" is on \",               \\\n      value1.device(),                        \\\n      \" and \" #value2 \" is on \",              \\\n      value2.device());\n#define CHECK_SAME_DTYPE(value1, value2)      \\\n  TORCH_CHECK(                                \\\n      value1.dtype() == value2.dtype(),       \\\n      \"(): expected \" #value1 \" and \" #value2 \\\n      \" to have the same dtype, \"             \\\n      \"but \" #value1 \" has \",                 \\\n      value1.dtype(),                         \\\n      \" and \" #value2 \" has \",                \\\n      value2.dtype());\n#define CHECK_SPATIAL_NOT_EMPTY(value)                                                        \\\n  for (int64_t i = 2; i < value.dim(); i++) {                                                 \\\n    TORCH_CHECK(                                                                              \\\n        value.size(i) > 0,                                                                    \\\n        \"(): expected \" #value \" to have non-empty spatial dimensions, but input has sizes \", \\\n        value.sizes(),                                                                        \\\n        \" with dimension \",                                                                   \\\n        i,                                                                                    \\\n        \" being empty\");                                                                      \\\n  }\n#define CHECK_GRID_TARGET_COMPAT(value1, value2)                                                     
     \\\n  TORCH_CHECK(                                                                                            \\\n      value2.size(0) == value1.size(0) && (value2.dim() <= 2 || value2.size(2) == value1.size(1)) &&      \\\n          (value2.dim() <= 3 || value2.size(3) == value1.size(2)) &&                                      \\\n          (value2.dim() <= 4 || value2.size(4) == value1.size(3)),                                        \\\n      \"(): expected \" #value2 \" and \" #value1                                                             \\\n      \" to have same batch, width, height and (optionally) depth sizes, but got \" #value2 \" with sizes \", \\\n      value2.sizes(),                                                                                     \\\n      \" and \" #value1 \" with sizes \",                                                                     \\\n      value1.sizes());\n#define CHECK_SPATIAL_LENGTH(value, dim) \\\n  TORCH_CHECK(((int64_t)(value.size()) == dim - 2), \"(): expected \", dim, #value \" elements but got \", value.size());\n#define CHECK_VEC_NOT_EMPTY(value) TORCH_CHECK(!value.empty(), \"(): expected nonempty value \" #value);\n"
  },
  {
    "path": "monai/csrc/utils/meta_macros.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#pragma once\n\n// Helper Macros: for internal use (see below)\n#define _DO_1(TARGET) TARGET(1)\n#define _DO_2(TARGET) TARGET(2) _DO_1(TARGET)\n#define _DO_3(TARGET) TARGET(3) _DO_2(TARGET)\n#define _DO_4(TARGET) TARGET(4) _DO_3(TARGET)\n#define _DO_5(TARGET) TARGET(5) _DO_4(TARGET)\n#define _DO_6(TARGET) TARGET(6) _DO_5(TARGET)\n#define _DO_7(TARGET) TARGET(7) _DO_6(TARGET)\n#define _DO_8(TARGET) TARGET(8) _DO_7(TARGET)\n#define _DO_9(TARGET) TARGET(9) _DO_8(TARGET)\n#define _DO_10(TARGET) TARGET(10) _DO_9(TARGET)\n#define _DO_11(TARGET) TARGET(11) _DO_10(TARGET)\n#define _DO_12(TARGET) TARGET(12) _DO_11(TARGET)\n#define _DO_13(TARGET) TARGET(13) _DO_12(TARGET)\n#define _DO_14(TARGET) TARGET(14) _DO_13(TARGET)\n#define _DO_15(TARGET) TARGET(15) _DO_14(TARGET)\n#define _DO_16(TARGET) TARGET(16) _DO_15(TARGET)\n#define _DO_17(TARGET) TARGET(17) _DO_16(TARGET)\n#define _DO_18(TARGET) TARGET(18) _DO_17(TARGET)\n#define _DO_19(TARGET) TARGET(19) _DO_18(TARGET)\n#define _DO_20(TARGET) TARGET(20) _DO_19(TARGET)\n#define _DO_21(TARGET) TARGET(21) _DO_20(TARGET)\n#define _DO_22(TARGET) TARGET(22) _DO_21(TARGET)\n#define _DO_23(TARGET) TARGET(23) _DO_22(TARGET)\n#define _DO_24(TARGET) TARGET(24) _DO_23(TARGET)\n#define _DO_25(TARGET) TARGET(25) _DO_24(TARGET)\n#define _DO_26(TARGET) TARGET(26) _DO_25(TARGET)\n#define _DO_27(TARGET) TARGET(27) _DO_26(TARGET)\n#define _DO_28(TARGET) TARGET(28) 
_DO_27(TARGET)\n#define _DO_29(TARGET) TARGET(29) _DO_28(TARGET)\n#define _DO_30(TARGET) TARGET(30) _DO_29(TARGET)\n#define _DO_31(TARGET) TARGET(31) _DO_30(TARGET)\n#define _DO_32(TARGET) TARGET(32) _DO_31(TARGET)\n\n#define _DO_A_1(TARGET, A) TARGET(A, 1)\n#define _DO_A_2(TARGET, A) TARGET(A, 2) _DO_A_1(TARGET, A)\n#define _DO_A_3(TARGET, A) TARGET(A, 3) _DO_A_2(TARGET, A)\n#define _DO_A_4(TARGET, A) TARGET(A, 4) _DO_A_3(TARGET, A)\n#define _DO_A_5(TARGET, A) TARGET(A, 5) _DO_A_4(TARGET, A)\n#define _DO_A_6(TARGET, A) TARGET(A, 6) _DO_A_5(TARGET, A)\n#define _DO_A_7(TARGET, A) TARGET(A, 7) _DO_A_6(TARGET, A)\n#define _DO_A_8(TARGET, A) TARGET(A, 8) _DO_A_7(TARGET, A)\n#define _DO_A_9(TARGET, A) TARGET(A, 9) _DO_A_8(TARGET, A)\n#define _DO_A_10(TARGET, A) TARGET(A, 10) _DO_A_9(TARGET, A)\n#define _DO_A_11(TARGET, A) TARGET(A, 11) _DO_A_10(TARGET, A)\n#define _DO_A_12(TARGET, A) TARGET(A, 12) _DO_A_11(TARGET, A)\n#define _DO_A_13(TARGET, A) TARGET(A, 13) _DO_A_12(TARGET, A)\n#define _DO_A_14(TARGET, A) TARGET(A, 14) _DO_A_13(TARGET, A)\n#define _DO_A_15(TARGET, A) TARGET(A, 15) _DO_A_14(TARGET, A)\n#define _DO_A_16(TARGET, A) TARGET(A, 16) _DO_A_15(TARGET, A)\n#define _DO_A_17(TARGET, A) TARGET(A, 17) _DO_A_16(TARGET, A)\n#define _DO_A_18(TARGET, A) TARGET(A, 18) _DO_A_17(TARGET, A)\n#define _DO_A_19(TARGET, A) TARGET(A, 19) _DO_A_18(TARGET, A)\n#define _DO_A_20(TARGET, A) TARGET(A, 20) _DO_A_19(TARGET, A)\n#define _DO_A_21(TARGET, A) TARGET(A, 21) _DO_A_20(TARGET, A)\n#define _DO_A_22(TARGET, A) TARGET(A, 22) _DO_A_21(TARGET, A)\n#define _DO_A_23(TARGET, A) TARGET(A, 23) _DO_A_22(TARGET, A)\n#define _DO_A_24(TARGET, A) TARGET(A, 24) _DO_A_23(TARGET, A)\n#define _DO_A_25(TARGET, A) TARGET(A, 25) _DO_A_24(TARGET, A)\n#define _DO_A_26(TARGET, A) TARGET(A, 26) _DO_A_25(TARGET, A)\n#define _DO_A_27(TARGET, A) TARGET(A, 27) _DO_A_26(TARGET, A)\n#define _DO_A_28(TARGET, A) TARGET(A, 28) _DO_A_27(TARGET, A)\n#define _DO_A_29(TARGET, A) TARGET(A, 29) _DO_A_28(TARGET, 
A)\n#define _DO_A_30(TARGET, A) TARGET(A, 30) _DO_A_29(TARGET, A)\n#define _DO_A_31(TARGET, A) TARGET(A, 31) _DO_A_30(TARGET, A)\n#define _DO_A_32(TARGET, A) TARGET(A, 32) _DO_A_31(TARGET, A)\n\n#define _DO_1_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 1)\n#define _DO_2_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 2) _DO_1_B(TARGET, B_RANGE)\n#define _DO_3_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 3) _DO_2_B(TARGET, B_RANGE)\n#define _DO_4_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 4) _DO_3_B(TARGET, B_RANGE)\n#define _DO_5_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 5) _DO_4_B(TARGET, B_RANGE)\n#define _DO_6_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 6) _DO_5_B(TARGET, B_RANGE)\n#define _DO_7_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 7) _DO_6_B(TARGET, B_RANGE)\n#define _DO_8_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 8) _DO_7_B(TARGET, B_RANGE)\n#define _DO_9_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 9) _DO_8_B(TARGET, B_RANGE)\n#define _DO_10_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 10) _DO_9_B(TARGET, B_RANGE)\n#define _DO_11_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 11) _DO_10_B(TARGET, B_RANGE)\n#define _DO_12_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 12) _DO_11_B(TARGET, B_RANGE)\n#define _DO_13_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 13) _DO_12_B(TARGET, B_RANGE)\n#define _DO_14_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 14) _DO_13_B(TARGET, B_RANGE)\n#define _DO_15_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 15) _DO_14_B(TARGET, B_RANGE)\n#define _DO_16_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 16) _DO_15_B(TARGET, B_RANGE)\n#define _DO_17_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 17) _DO_16_B(TARGET, B_RANGE)\n#define _DO_18_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 18) _DO_17_B(TARGET, B_RANGE)\n#define _DO_19_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 19) _DO_18_B(TARGET, B_RANGE)\n#define _DO_20_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 20) _DO_19_B(TARGET, B_RANGE)\n#define _DO_21_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 21) 
_DO_20_B(TARGET, B_RANGE)\n#define _DO_22_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 22) _DO_21_B(TARGET, B_RANGE)\n#define _DO_23_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 23) _DO_22_B(TARGET, B_RANGE)\n#define _DO_24_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 24) _DO_23_B(TARGET, B_RANGE)\n#define _DO_25_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 25) _DO_24_B(TARGET, B_RANGE)\n#define _DO_26_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 26) _DO_25_B(TARGET, B_RANGE)\n#define _DO_27_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 27) _DO_26_B(TARGET, B_RANGE)\n#define _DO_28_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 28) _DO_27_B(TARGET, B_RANGE)\n#define _DO_29_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 29) _DO_28_B(TARGET, B_RANGE)\n#define _DO_30_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 30) _DO_29_B(TARGET, B_RANGE)\n#define _DO_31_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 31) _DO_30_B(TARGET, B_RANGE)\n#define _DO_32_B(TARGET, B_RANGE) _DO_A_##B_RANGE(TARGET, 32) _DO_31_B(TARGET, B_RANGE)\n\n#define _CASE_A(A) \\\n  case (A):        \\\n    CASE(A) break;\n#define _CASE_AB(A, B) \\\n  case (A * 100 + B):  \\\n    CASE(A, B) break;\n\n// Preprocessor For Loops\n#define DO_FOR_A(TARGET, A_RANGE) _DO_##A_RANGE(TARGET)\n#define DO_FOR_AB(TARGET, A_RANGE, B_RANGE) _DO_##A_RANGE##_B(TARGET, B_RANGE)\n\n// Preprocessor Switch Statement Generators\n#define SWITCH_A(CASE, A_RANGE, A) \\\n  switch (A) { DO_FOR_A(_CASE_A, A_RANGE) }\n#define SWITCH_AB(CALL, A_RANGE, B_RANGE, A, B) \\\n  switch (A * 100 + B) { DO_FOR_AB(_CASE_AB, A_RANGE, B_RANGE) }\n"
  },
  {
    "path": "monai/csrc/utils/resample_utils.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#pragma once\n\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n// We need to define AT_PARALLEL_OPENMP (even if -fopenmp is\n// not used) so that at::parallel_for is defined somewhere.\n// This must be done before <ATen/Parallel.h> is included.\n//\n// Note that if AT_PARALLEL_OPENMP = 1 but compilation does not use\n// -fopenmp, omp pragmas will be ignored. In that case, the code will\n// be effectively sequential, and we don't have to worry about\n// operations being atomic.\n#if !(AT_PARALLEL_OPENMP)\n#if !(AT_PARALLEL_NATIVE)\n#if !(AT_PARALLEL_NATIVE_TBB)\n#error No parallel backend specified\n#endif\n#endif\n#endif\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n// These are defines that help writing generic code for both GPU and CPU\n#ifdef __CUDACC__\n#include <ATen/cuda/CUDAApplyUtils.cuh>\n#include <THC/THCAtomics.cuh>\n#define MONAI_INLINE __forceinline__\n#define MONAI_DEVICE __device__\n#define MONAI_HOST __host__\n#define MONAI_ATOMIC_ADD monai::gpuAtomicAdd\n#define MONAI_NAMESPACE_DEVICE namespace cuda\nnamespace monai {\n// atomicAdd API changed between pytorch 1.4 and 1.5.\ntemplate <typename scalar_t, typename offset_t>\nstatic __forceinline__ __device__ void gpuAtomicAdd(scalar_t* ptr, offset_t offset, scalar_t value) {\n#if 
MONAI_TORCH_VERSION >= 10500\n  ::gpuAtomicAdd(ptr + offset, value);\n#else\n  ::atomicAdd(ptr + offset, value);\n#endif\n}\n} // namespace monai\n#else\n#define MONAI_INLINE inline\n#define MONAI_DEVICE\n#define MONAI_HOST\n#define MONAI_ATOMIC_ADD monai::cpuAtomicAdd\n#define MONAI_NAMESPACE_DEVICE namespace cpu\nnamespace monai {\ntemplate <typename scalar_t, typename offset_t>\nstatic inline void cpuAtomicAdd(scalar_t* ptr, offset_t offset, scalar_t value) {\n#if AT_PARALLEL_OPENMP\n#if _OPENMP\n#pragma omp atomic\n#endif\n#endif\n  ptr[offset] += value;\n}\n} // namespace monai\n#endif\n// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n#include <ATen/ATen.h>\n\nnamespace monai {\n\nenum class BoundType : int64_t {\n  Replicate, // Replicate last inbound value = clip coordinates\n  DCT1, // Symmetric w.r.t. center of the last inbound voxel\n  DCT2, // Symmetric w.r.t. edge of the last inbound voxel (=Neumann)\n  DST1, // Asymmetric w.r.t. center of the last inbound voxel\n  DST2, // Asymmetric w.r.t. 
edge of the last inbound voxel (=Dirichlet)\n  DFT, // Circular / Wrap around the FOV\n  Sliding, // For deformation-fields only: mixture of DCT2 and DST2\n  Zero, // Zero outside of the FOV\n  NoCheck // /!\\ Checks disabled: assume coordinates are inbound\n};\n\nusing BoundVectorRef = c10::ArrayRef<BoundType>;\n\nenum class InterpolationType : int64_t {\n  Nearest,\n  Linear,\n  Quadratic,\n  Cubic,\n  FourthOrder,\n  FifthOrder,\n  SixthOrder,\n  SeventhOrder\n};\nusing InterpolationVectorRef = c10::ArrayRef<InterpolationType>;\n\nstatic MONAI_INLINE MONAI_HOST std::ostream& operator<<(std::ostream& os, const BoundType& bound) {\n  switch (bound) {\n    case BoundType::Replicate:\n      return os << \"Replicate\";\n    case BoundType::DCT1:\n      return os << \"DCT1\";\n    case BoundType::DCT2:\n      return os << \"DCT2\";\n    case BoundType::DST1:\n      return os << \"DST1\";\n    case BoundType::DST2:\n      return os << \"DST2\";\n    case BoundType::DFT:\n      return os << \"DFT\";\n    case BoundType::Zero:\n      return os << \"Zero\";\n    case BoundType::Sliding:\n      return os << \"Sliding\";\n    case BoundType::NoCheck:\n      return os << \"NoCheck\";\n  }\n  return os << \"Unknown bound\";\n}\n\nstatic MONAI_INLINE MONAI_HOST std::ostream& operator<<(std::ostream& os, const InterpolationType& itp) {\n  switch (itp) {\n    case InterpolationType::Nearest:\n      return os << \"Nearest\";\n    case InterpolationType::Linear:\n      return os << \"Linear\";\n    case InterpolationType::Quadratic:\n      return os << \"Quadratic\";\n    case InterpolationType::Cubic:\n      return os << \"Cubic\";\n    case InterpolationType::FourthOrder:\n      return os << \"FourthOrder\";\n    case InterpolationType::FifthOrder:\n      return os << \"FifthOrder\";\n    case InterpolationType::SixthOrder:\n      return os << \"SixthOrder\";\n    case InterpolationType::SeventhOrder:\n      return os << \"SeventhOrder\";\n  }\n  return os << \"Unknown 
interpolation order\";\n}\n\n} // namespace monai\n"
  },
  {
    "path": "monai/csrc/utils/tensor_description.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <torch/extension.h>\n\n// Struct to easily cache descriptive information about a tensor.\n// This is helpful as regular calls to the size and stride member\n// functions of tensors appear to cause memory issues.\nstruct TensorDescription {\n public:\n  TensorDescription(torch::Tensor tensor) {\n    batchCount = tensor.size(0);\n    batchStride = tensor.stride(0);\n\n    channelCount = tensor.size(1);\n    channelStride = tensor.stride(1);\n\n    dimensions = tensor.dim() - 2;\n    sizes = new int[dimensions];\n    strides = new int[dimensions];\n\n    for (int i = 0; i < dimensions; i++) {\n      sizes[i] = tensor.size(i + 2);\n      strides[i] = tensor.stride(i + 2);\n    }\n  }\n\n  ~TensorDescription() {\n    delete[] sizes;\n    delete[] strides;\n  }\n\n  int batchCount;\n  int batchStride;\n\n  int channelCount;\n  int channelStride;\n\n  int dimensions;\n  int* sizes;\n  int* strides;\n};\n"
  },
  {
    "path": "monai/csrc/utils/tensor_indexing.h",
    "content": "/*\nCopyright (c) MONAI Consortium\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n    http://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n*/\n\n#include <torch/extension.h>\n\n// Struct to easily index input tensors.\nstruct Indexer {\n public:\n  Indexer(int dimensions, int* sizes) {\n    m_dimensions = dimensions;\n    m_sizes = sizes;\n    m_index = new int[dimensions]{0};\n  }\n  ~Indexer() {\n    delete[] m_index;\n  }\n\n  bool operator++(int) {\n    for (int i = 0; i < m_dimensions; i++) {\n      m_index[i] += 1;\n\n      if (m_index[i] < m_sizes[i]) {\n        return true;\n      } else {\n        m_index[i] = 0;\n      }\n    }\n\n    return false;\n  }\n\n  int& operator[](int dimensionIndex) {\n    return m_index[dimensionIndex];\n  }\n\n private:\n  int m_dimensions;\n  int* m_sizes;\n  int* m_index;\n};\n"
  },
  {
    "path": "monai/data/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport contextlib\n\nfrom .box_utils import (\n    box_area,\n    box_centers,\n    box_giou,\n    box_iou,\n    box_pair_giou,\n    boxes_center_distance,\n    centers_in_boxes,\n    convert_box_mode,\n    convert_box_to_standard_mode,\n)\nfrom .csv_saver import CSVSaver\nfrom .dataloader import DataLoader\nfrom .dataset import (\n    ArrayDataset,\n    CacheDataset,\n    CacheNTransDataset,\n    CSVDataset,\n    Dataset,\n    DatasetFunc,\n    GDSDataset,\n    LMDBDataset,\n    NPZDictItemDataset,\n    PersistentDataset,\n    SmartCacheDataset,\n    ZipDataset,\n)\nfrom .dataset_summary import DatasetSummary\nfrom .decathlon_datalist import (\n    check_missing_files,\n    create_cross_validation_datalist,\n    load_decathlon_datalist,\n    load_decathlon_properties,\n)\nfrom .folder_layout import FolderLayout, FolderLayoutBase\nfrom .grid_dataset import GridPatchDataset, PatchDataset, PatchIter, PatchIterd\nfrom .image_dataset import ImageDataset\nfrom .image_reader import ImageReader, ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader, PydicomReader\nfrom .image_writer import (\n    SUPPORTED_WRITERS,\n    ImageWriter,\n    ITKWriter,\n    NibabelWriter,\n    PILWriter,\n    logger,\n    register_writer,\n    resolve_writer,\n)\nfrom .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer\nfrom 
.itk_torch_bridge import (\n    get_itk_image_center,\n    itk_image_to_metatensor,\n    itk_to_monai_affine,\n    metatensor_to_itk_image,\n    monai_to_itk_affine,\n    monai_to_itk_ddf,\n)\nfrom .meta_obj import MetaObj, get_track_meta, set_track_meta\nfrom .meta_tensor import MetaTensor\nfrom .samplers import DistributedSampler, DistributedWeightedRandomSampler\nfrom .synthetic import create_test_image_2d, create_test_image_3d\nfrom .test_time_augmentation import TestTimeAugmentation\nfrom .thread_buffer import ThreadBuffer, ThreadDataLoader\nfrom .torchscript_utils import load_net_with_metadata, save_net_with_metadata\nfrom .utils import (\n    affine_to_spacing,\n    compute_importance_map,\n    compute_shape_offset,\n    convert_tables_to_dicts,\n    correct_nifti_header_if_necessary,\n    create_file_basename,\n    decollate_batch,\n    dense_patch_slices,\n    get_extra_metadata_keys,\n    get_random_patch,\n    get_valid_patch_size,\n    is_supported_format,\n    iter_patch,\n    iter_patch_position,\n    iter_patch_slices,\n    json_hashing,\n    list_data_collate,\n    orientation_ras_lps,\n    pad_list_data_collate,\n    partition_dataset,\n    partition_dataset_classes,\n    pickle_hashing,\n    rectify_header_sform_qform,\n    remove_extra_metadata,\n    remove_keys,\n    reorient_spatial_axes,\n    resample_datalist,\n    select_cross_validation_folds,\n    set_rnd,\n    sorted_dict,\n    to_affine_nd,\n    worker_init_fn,\n    zoom_affine,\n)\n\n# FIXME: workaround for https://github.com/Project-MONAI/MONAI/issues/5291\n# from .video_dataset import CameraDataset, VideoDataset, VideoFileDataset\nfrom .wsi_datasets import MaskedPatchWSIDataset, PatchWSIDataset, SlidingPatchWSIDataset\nfrom .wsi_reader import BaseWSIReader, CuCIMWSIReader, OpenSlideWSIReader, TiffFileWSIReader, WSIReader\n\nwith contextlib.suppress(BaseException):\n    from multiprocessing.reduction import ForkingPickler\n\n    def _rebuild_meta(cls, storage, dtype, metadata):\n       
 storage_offset, size, stride, requires_grad, meta_dict = metadata\n        storage = storage._untyped_storage if hasattr(storage, \"_untyped_storage\") else storage\n        t = cls([], dtype=dtype, device=storage.device)\n        t.set_(storage, storage_offset, size, stride)\n        t.requires_grad = requires_grad\n        t.__dict__ = meta_dict\n        return t\n\n    def reduce_meta_tensor(meta_tensor):\n        if hasattr(meta_tensor, \"untyped_storage\"):\n            storage = meta_tensor.untyped_storage()\n        elif hasattr(meta_tensor, \"_typed_storage\"):  # gh pytorch 44dac51/torch/_tensor.py#L231-L233\n            storage = meta_tensor._typed_storage()\n        else:\n            storage = meta_tensor.storage()\n        dtype = meta_tensor.dtype\n        if meta_tensor.is_cuda:\n            raise NotImplementedError(\"sharing CUDA metatensor across processes not implemented\")\n        metadata = (\n            meta_tensor.storage_offset(),\n            meta_tensor.size(),\n            meta_tensor.stride(),\n            meta_tensor.requires_grad,\n            meta_tensor.__dict__,\n        )\n        return _rebuild_meta, (type(meta_tensor), storage, dtype, metadata)\n\n    ForkingPickler.register(MetaTensor, reduce_meta_tensor)\n\nfrom .ultrasound_confidence_map import UltrasoundConfidenceMap\n"
  },
  {
    "path": "monai/data/box_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis utility module mainly supports rectangular bounding boxes with a few\ndifferent parameterizations and methods for converting between them. It\nprovides reliable access to the spatial coordinates of the box vertices in the\n\"canonical ordering\":\n[xmin, ymin, xmax, ymax] for 2D and [xmin, ymin, zmin, xmax, ymax, zmax] for 3D.\nWe currently define this ordering as `monai.data.box_utils.StandardMode` and\nthe rest of the detection pipelines mainly assumes boxes in `StandardMode`.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport inspect\nimport warnings\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Callable, Sequence\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\n\nfrom monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor\nfrom monai.utils import look_up_option\nfrom monai.utils.enums import BoxModeName\nfrom monai.utils.type_conversion import convert_data_type, convert_to_dst_type\n\n# We support 2-D or 3-D bounding boxes\nSUPPORTED_SPATIAL_DIMS = [2, 3]\n\n# TO_REMOVE = 0.0 if the bottom-right corner pixel/voxel is not included in the boxes,\n#      i.e., when xmin=1., xmax=2., we have w = 1.\n# TO_REMOVE = 1.0  if the bottom-right corner pixel/voxel is included in the boxes,\n#       i.e., when xmin=1., xmax=2., we have w = 2.\n# Currently, only `TO_REMOVE = 0.0` is supported\nTO_REMOVE = 0.0  # 
xmax-xmin = w -TO_REMOVE.\n\n# Some torch functions do not support half precision.\n# We therefore compute those functions under COMPUTE_DTYPE\nCOMPUTE_DTYPE = torch.float32\n\n\nclass BoxMode(ABC):\n    \"\"\"\n    An abstract class of a ``BoxMode``.\n\n    A ``BoxMode`` is callable that converts box mode of ``boxes``, which are Nx4 (2D) or Nx6 (3D) torch tensor or ndarray.\n    ``BoxMode`` has several subclasses that represents different box modes, including\n\n    - :class:`~monai.data.box_utils.CornerCornerModeTypeA`:\n      represents [xmin, ymin, xmax, ymax] for 2D and [xmin, ymin, zmin, xmax, ymax, zmax] for 3D\n    - :class:`~monai.data.box_utils.CornerCornerModeTypeB`:\n      represents [xmin, xmax, ymin, ymax] for 2D and [xmin, xmax, ymin, ymax, zmin, zmax] for 3D\n    - :class:`~monai.data.box_utils.CornerCornerModeTypeC`:\n      represents [xmin, ymin, xmax, ymax] for 2D and [xmin, ymin, xmax, ymax, zmin, zmax] for 3D\n    - :class:`~monai.data.box_utils.CornerSizeMode`:\n      represents [xmin, ymin, xsize, ysize] for 2D and [xmin, ymin, zmin, xsize, ysize, zsize] for 3D\n    - :class:`~monai.data.box_utils.CenterSizeMode`:\n      represents [xcenter, ycenter, xsize, ysize] for 2D and [xcenter, ycenter, zcenter, xsize, ysize, zsize] for 3D\n\n    We currently define ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,\n    and monai detection pipelines mainly assume ``boxes`` are in ``StandardMode``.\n\n    The implementation should be aware of:\n\n    - remember to define class variable ``name``,\n      a dictionary that maps ``spatial_dims`` to :class:`~monai.utils.enums.BoxModeName`.\n    - :func:`~monai.data.box_utils.BoxMode.boxes_to_corners` and :func:`~monai.data.box_utils.BoxMode.corners_to_boxes`\n      should not modify inputs in place.\n    \"\"\"\n\n    # a dictionary that maps spatial_dims to monai.utils.enums.BoxModeName.\n    name: dict[int, BoxModeName] = {}\n\n    @classmethod\n    def get_name(cls, spatial_dims: 
int) -> str:\n        \"\"\"\n        Get the mode name for the given spatial dimension using class variable ``name``.\n\n        Args:\n            spatial_dims: number of spatial dimensions of the bounding boxes.\n\n        Returns:\n            ``str``: mode string name\n        \"\"\"\n        return cls.name[spatial_dims].value\n\n    @abstractmethod\n    def boxes_to_corners(self, boxes: torch.Tensor) -> tuple:\n        \"\"\"\n        Convert the bounding boxes of the current mode to corners.\n\n        Args:\n            boxes: bounding boxes, Nx4 or Nx6 torch tensor\n\n        Returns:\n            ``tuple``: corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor.\n            It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)\n\n        Example:\n            .. code-block:: python\n\n                boxes = torch.ones(10,6)\n                boxmode = BoxMode()\n                boxmode.boxes_to_corners(boxes) # will return a 6-element tuple, each element is a 10x1 tensor\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    @abstractmethod\n    def corners_to_boxes(self, corners: Sequence) -> torch.Tensor:\n        \"\"\"\n        Convert the given box corners to the bounding boxes of the current mode.\n\n        Args:\n            corners: corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor.\n                It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)\n\n        Returns:\n            ``Tensor``: bounding boxes, Nx4 or Nx6 torch tensor\n\n        Example:\n            .. 
code-block:: python\n\n                corners = (torch.ones(10,1), torch.ones(10,1), torch.ones(10,1), torch.ones(10,1))\n                boxmode = BoxMode()\n                boxmode.corners_to_boxes(corners) # will return a 10x4 tensor\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n\nclass CornerCornerModeTypeA(BoxMode):\n    \"\"\"\n    A subclass of ``BoxMode``.\n\n    Also represented as \"xyxy\" or \"xyzxyz\", with format of\n    [xmin, ymin, xmax, ymax] or [xmin, ymin, zmin, xmax, ymax, zmax].\n\n    Example:\n        .. code-block:: python\n\n            CornerCornerModeTypeA.get_name(spatial_dims=2) # will return \"xyxy\"\n            CornerCornerModeTypeA.get_name(spatial_dims=3) # will return \"xyzxyz\"\n    \"\"\"\n\n    name = {2: BoxModeName.XYXY, 3: BoxModeName.XYZXYZ}\n\n    def boxes_to_corners(self, boxes: torch.Tensor) -> tuple:\n        corners: tuple\n        corners = boxes.split(1, dim=-1)\n        return corners\n\n    def corners_to_boxes(self, corners: Sequence) -> torch.Tensor:\n        boxes: torch.Tensor\n        boxes = torch.cat(tuple(corners), dim=-1)\n        return boxes\n\n\nclass CornerCornerModeTypeB(BoxMode):\n    \"\"\"\n    A subclass of ``BoxMode``.\n\n    Also represented as \"xxyy\" or \"xxyyzz\", with format of\n    [xmin, xmax, ymin, ymax] or [xmin, xmax, ymin, ymax, zmin, zmax].\n\n    Example:\n        .. 
code-block:: python\n\n            CornerCornerModeTypeB.get_name(spatial_dims=2) # will return \"xxyy\"\n            CornerCornerModeTypeB.get_name(spatial_dims=3) # will return \"xxyyzz\"\n    \"\"\"\n\n    name = {2: BoxModeName.XXYY, 3: BoxModeName.XXYYZZ}\n\n    def boxes_to_corners(self, boxes: torch.Tensor) -> tuple:\n        corners: tuple\n        spatial_dims = get_spatial_dims(boxes=boxes)\n        if spatial_dims == 3:\n            xmin, xmax, ymin, ymax, zmin, zmax = boxes.split(1, dim=-1)\n            corners = xmin, ymin, zmin, xmax, ymax, zmax\n        elif spatial_dims == 2:\n            xmin, xmax, ymin, ymax = boxes.split(1, dim=-1)\n            corners = xmin, ymin, xmax, ymax\n        return corners\n\n    def corners_to_boxes(self, corners: Sequence) -> torch.Tensor:\n        boxes: torch.Tensor\n        spatial_dims = get_spatial_dims(corners=corners)\n        if spatial_dims == 3:\n            boxes = torch.cat((corners[0], corners[3], corners[1], corners[4], corners[2], corners[5]), dim=-1)\n        elif spatial_dims == 2:\n            boxes = torch.cat((corners[0], corners[2], corners[1], corners[3]), dim=-1)\n        return boxes\n\n\nclass CornerCornerModeTypeC(BoxMode):\n    \"\"\"\n    A subclass of ``BoxMode``.\n\n    Also represented as \"xyxy\" or \"xyxyzz\", with format of\n    [xmin, ymin, xmax, ymax] or [xmin, ymin, xmax, ymax, zmin, zmax].\n\n    Example:\n        .. 
code-block:: python\n\n            CornerCornerModeTypeC.get_name(spatial_dims=2) # will return \"xyxy\"\n            CornerCornerModeTypeC.get_name(spatial_dims=3) # will return \"xyxyzz\"\n    \"\"\"\n\n    name = {2: BoxModeName.XYXY, 3: BoxModeName.XYXYZZ}\n\n    def boxes_to_corners(self, boxes: torch.Tensor) -> tuple:\n        corners: tuple\n        spatial_dims = get_spatial_dims(boxes=boxes)\n        if spatial_dims == 3:\n            xmin, ymin, xmax, ymax, zmin, zmax = boxes.split(1, dim=-1)\n            corners = xmin, ymin, zmin, xmax, ymax, zmax\n        elif spatial_dims == 2:\n            corners = boxes.split(1, dim=-1)\n        return corners\n\n    def corners_to_boxes(self, corners: Sequence) -> torch.Tensor:\n        boxes: torch.Tensor\n        spatial_dims = get_spatial_dims(corners=corners)\n        if spatial_dims == 3:\n            boxes = torch.cat((corners[0], corners[1], corners[3], corners[4], corners[2], corners[5]), dim=-1)\n        elif spatial_dims == 2:\n            boxes = torch.cat(tuple(corners), dim=-1)\n        return boxes\n\n\nclass CornerSizeMode(BoxMode):\n    \"\"\"\n    A subclass of ``BoxMode``.\n\n    Also represented as \"xywh\" or \"xyzwhd\", with format of\n    [xmin, ymin, xsize, ysize] or [xmin, ymin, zmin, xsize, ysize, zsize].\n\n    Example:\n        .. 
code-block:: python\n\n            CornerSizeMode.get_name(spatial_dims=2) # will return \"xywh\"\n            CornerSizeMode.get_name(spatial_dims=3) # will return \"xyzwhd\"\n    \"\"\"\n\n    name = {2: BoxModeName.XYWH, 3: BoxModeName.XYZWHD}\n\n    def boxes_to_corners(self, boxes: torch.Tensor) -> tuple:\n        corners: tuple\n        # convert to float32 when computing torch.clamp, which does not support float16\n        box_dtype = boxes.dtype\n\n        spatial_dims = get_spatial_dims(boxes=boxes)\n        if spatial_dims == 3:\n            xmin, ymin, zmin, w, h, d = boxes.split(1, dim=-1)\n            xmax = xmin + (w - TO_REMOVE).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            ymax = ymin + (h - TO_REMOVE).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            zmax = zmin + (d - TO_REMOVE).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            corners = xmin, ymin, zmin, xmax, ymax, zmax\n        elif spatial_dims == 2:\n            xmin, ymin, w, h = boxes.split(1, dim=-1)\n            xmax = xmin + (w - TO_REMOVE).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            ymax = ymin + (h - TO_REMOVE).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            corners = xmin, ymin, xmax, ymax\n        return corners\n\n    def corners_to_boxes(self, corners: Sequence) -> torch.Tensor:\n        boxes: torch.Tensor\n        spatial_dims = get_spatial_dims(corners=corners)\n        if spatial_dims == 3:\n            xmin, ymin, zmin, xmax, ymax, zmax = corners[0], corners[1], corners[2], corners[3], corners[4], corners[5]\n            boxes = torch.cat(\n                (xmin, ymin, zmin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE, zmax - zmin + TO_REMOVE), dim=-1\n            )\n        elif spatial_dims == 2:\n            xmin, ymin, xmax, ymax = corners[0], corners[1], corners[2], corners[3]\n            boxes = torch.cat((xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin 
+ TO_REMOVE), dim=-1)\n        return boxes\n\n\nclass CenterSizeMode(BoxMode):\n    \"\"\"\n    A subclass of ``BoxMode``.\n\n    Also represented as \"ccwh\" or \"cccwhd\", with format of\n    [xcenter, ycenter, xsize, ysize] or [xcenter, ycenter, zcenter, xsize, ysize, zsize].\n\n    Example:\n        .. code-block:: python\n\n            CenterSizeMode.get_name(spatial_dims=2) # will return \"ccwh\"\n            CenterSizeMode.get_name(spatial_dims=3) # will return \"cccwhd\"\n    \"\"\"\n\n    name = {2: BoxModeName.CCWH, 3: BoxModeName.CCCWHD}\n\n    def boxes_to_corners(self, boxes: torch.Tensor) -> tuple:\n        corners: tuple\n        # convert to float32 when computing torch.clamp, which does not support float16\n        box_dtype = boxes.dtype\n\n        spatial_dims = get_spatial_dims(boxes=boxes)\n        if spatial_dims == 3:\n            xc, yc, zc, w, h, d = boxes.split(1, dim=-1)\n            xmin = xc - ((w - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            xmax = xc + ((w - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            ymin = yc - ((h - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            ymax = yc + ((h - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            zmin = zc - ((d - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            zmax = zc + ((d - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            corners = xmin, ymin, zmin, xmax, ymax, zmax\n        elif spatial_dims == 2:\n            xc, yc, w, h = boxes.split(1, dim=-1)\n            xmin = xc - ((w - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            xmax = xc + ((w - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            ymin = yc - ((h - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n      
      ymax = yc + ((h - TO_REMOVE) / 2.0).to(dtype=COMPUTE_DTYPE).clamp(min=0).to(dtype=box_dtype)\n            corners = xmin, ymin, xmax, ymax\n        return corners\n\n    def corners_to_boxes(self, corners: Sequence) -> torch.Tensor:\n        boxes: torch.Tensor\n        spatial_dims = get_spatial_dims(corners=corners)\n        if spatial_dims == 3:\n            xmin, ymin, zmin, xmax, ymax, zmax = corners[0], corners[1], corners[2], corners[3], corners[4], corners[5]\n            boxes = torch.cat(\n                (\n                    (xmin + xmax + TO_REMOVE) / 2.0,\n                    (ymin + ymax + TO_REMOVE) / 2.0,\n                    (zmin + zmax + TO_REMOVE) / 2.0,\n                    xmax - xmin + TO_REMOVE,\n                    ymax - ymin + TO_REMOVE,\n                    zmax - zmin + TO_REMOVE,\n                ),\n                dim=-1,\n            )\n        elif spatial_dims == 2:\n            xmin, ymin, xmax, ymax = corners[0], corners[1], corners[2], corners[3]\n            boxes = torch.cat(\n                (\n                    (xmin + xmax + TO_REMOVE) / 2.0,\n                    (ymin + ymax + TO_REMOVE) / 2.0,\n                    xmax - xmin + TO_REMOVE,\n                    ymax - ymin + TO_REMOVE,\n                ),\n                dim=-1,\n            )\n        return boxes\n\n\n# We support the conversion between several box modes, i.e., representation of a bounding boxes\nSUPPORTED_MODES = [CornerCornerModeTypeA, CornerCornerModeTypeB, CornerCornerModeTypeC, CornerSizeMode, CenterSizeMode]\n# The standard box mode we use in all the box util functions\nStandardMode = CornerCornerModeTypeA\n\n\ndef get_spatial_dims(\n    boxes: torch.Tensor | np.ndarray | None = None,\n    points: torch.Tensor | np.ndarray | None = None,\n    corners: Sequence | None = None,\n    spatial_size: Sequence[int] | torch.Tensor | np.ndarray | None = None,\n) -> int:\n    \"\"\"\n    Get spatial dimension for the giving setting and check the 
validity of them.\n    Missing input is allowed. But at least one of the input value should be given.\n    It raises ValueError if the dimensions of multiple inputs do not match with each other.\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray\n        points: point coordinates, [x, y] or [x, y, z], Nx2 or Nx3 torch tensor or ndarray\n        corners: corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor or ndarray\n        spatial_size: The spatial size of the image where the boxes are attached.\n                len(spatial_size) should be in [2, 3].\n\n    Returns:\n        ``int``: spatial_dims, number of spatial dimensions of the bounding boxes.\n\n    Example:\n        .. code-block:: python\n\n            boxes = torch.ones(10,6)\n            get_spatial_dims(boxes, spatial_size=[100,200,200]) # will return 3\n            get_spatial_dims(boxes, spatial_size=[100,200]) # will raise ValueError\n            get_spatial_dims(boxes) # will return 3\n    \"\"\"\n    spatial_dims_set = set()\n\n    # Check the validity of each input and add its corresponding spatial_dims to spatial_dims_set\n    if boxes is not None:\n        if len(boxes.shape) != 2:\n            if boxes.shape[0] == 0:\n                raise ValueError(\n                    f\"Currently we support only boxes with shape [N,4] or [N,6], \"\n                    f\"got boxes with shape {boxes.shape}. 
\"\n                    f\"Please reshape it with boxes = torch.reshape(boxes, [0, 4]) or torch.reshape(boxes, [0, 6]).\"\n                )\n            else:\n                raise ValueError(\n                    f\"Currently we support only boxes with shape [N,4] or [N,6], got boxes with shape {boxes.shape}.\"\n                )\n        if int(boxes.shape[1] / 2) not in SUPPORTED_SPATIAL_DIMS:\n            raise ValueError(\n                f\"Currently we support only boxes with shape [N,4] or [N,6], got boxes with shape {boxes.shape}.\"\n            )\n        spatial_dims_set.add(int(boxes.shape[1] / 2))\n    if points is not None:\n        if len(points.shape) != 2:\n            if points.shape[0] == 0:\n                raise ValueError(\n                    f\"Currently we support only points with shape [N,2] or [N,3], \"\n                    f\"got points with shape {points.shape}. \"\n                    f\"Please reshape it with points = torch.reshape(points, [0, 2]) or torch.reshape(points, [0, 3]).\"\n                )\n            else:\n                raise ValueError(\n                    f\"Currently we support only points with shape [N,2] or [N,3], got points with shape {points.shape}.\"\n                )\n        if int(points.shape[1]) not in SUPPORTED_SPATIAL_DIMS:\n            raise ValueError(\n                f\"Currently we support only points with shape [N,2] or [N,3], got points with shape {points.shape}.\"\n            )\n        spatial_dims_set.add(int(points.shape[1]))\n    if corners is not None:\n        if len(corners) // 2 not in SUPPORTED_SPATIAL_DIMS:\n            raise ValueError(\n                f\"Currently we support only boxes with shape [N,4] or [N,6], got box corner tuple with length {len(corners)}.\"\n            )\n        spatial_dims_set.add(len(corners) // 2)\n    if spatial_size is not None:\n        if len(spatial_size) not in SUPPORTED_SPATIAL_DIMS:\n            raise ValueError(\n                f\"Currently 
we support only boxes on 2-D and 3-D images, got image spatial_size {spatial_size}.\"\n            )\n        spatial_dims_set.add(len(spatial_size))\n\n    # Get spatial_dims from spatial_dims_set, which contains only unique values\n    spatial_dims_list = list(spatial_dims_set)\n    if len(spatial_dims_list) == 0:\n        raise ValueError(\"At least one of the inputs needs to be non-empty.\")\n\n    if len(spatial_dims_list) == 1:\n        spatial_dims = int(spatial_dims_list[0])\n        spatial_dims = look_up_option(spatial_dims, supported=[2, 3])\n        return int(spatial_dims)\n\n    raise ValueError(\"The dimensions of multiple inputs should match with each other.\")\n\n\ndef get_boxmode(mode: str | BoxMode | type[BoxMode] | None = None, *args, **kwargs) -> BoxMode:\n    \"\"\"\n    This function returns a :class:`~monai.data.box_utils.BoxMode` object given a representation of box mode.\n\n    Args:\n        mode: a representation of box mode. If it is not given, this func will assume it is ``StandardMode()``.\n\n    Note:\n        ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,\n        also represented as \"xyxy\" for 2D and \"xyzxyz\" for 3D.\n\n        mode can be:\n            #. 
str: choose from :class:`~monai.utils.enums.BoxModeName`, for example,\n                - \"xyxy\": boxes has format [xmin, ymin, xmax, ymax]\n                - \"xyzxyz\": boxes has format [xmin, ymin, zmin, xmax, ymax, zmax]\n                - \"xxyy\": boxes has format [xmin, xmax, ymin, ymax]\n                - \"xxyyzz\": boxes has format [xmin, xmax, ymin, ymax, zmin, zmax]\n                - \"xyxyzz\": boxes has format [xmin, ymin, xmax, ymax, zmin, zmax]\n                - \"xywh\": boxes has format [xmin, ymin, xsize, ysize]\n                - \"xyzwhd\": boxes has format [xmin, ymin, zmin, xsize, ysize, zsize]\n                - \"ccwh\": boxes has format [xcenter, ycenter, xsize, ysize]\n                - \"cccwhd\": boxes has format [xcenter, ycenter, zcenter, xsize, ysize, zsize]\n            #. BoxMode class: choose from the subclasses of :class:`~monai.data.box_utils.BoxMode`, for example,\n                - CornerCornerModeTypeA: equivalent to \"xyxy\" or \"xyzxyz\"\n                - CornerCornerModeTypeB: equivalent to \"xxyy\" or \"xxyyzz\"\n                - CornerCornerModeTypeC: equivalent to \"xyxy\" or \"xyxyzz\"\n                - CornerSizeMode: equivalent to \"xywh\" or \"xyzwhd\"\n                - CenterSizeMode: equivalent to \"ccwh\" or \"cccwhd\"\n            #. BoxMode object: choose from the subclasses of :class:`~monai.data.box_utils.BoxMode`, for example,\n                - CornerCornerModeTypeA(): equivalent to \"xyxy\" or \"xyzxyz\"\n                - CornerCornerModeTypeB(): equivalent to \"xxyy\" or \"xxyyzz\"\n                - CornerCornerModeTypeC(): equivalent to \"xyxy\" or \"xyxyzz\"\n                - CornerSizeMode(): equivalent to \"xywh\" or \"xyzwhd\"\n                - CenterSizeMode(): equivalent to \"ccwh\" or \"cccwhd\"\n            #. None: will assume mode is ``StandardMode()``\n\n    Returns:\n        BoxMode object\n\n    Example:\n        .. 
code-block:: python\n\n            mode = \"xyzxyz\"\n            get_boxmode(mode) # will return CornerCornerModeTypeA()\n    \"\"\"\n    if isinstance(mode, BoxMode):\n        return mode\n\n    if inspect.isclass(mode) and issubclass(mode, BoxMode):\n        return mode(*args, **kwargs)\n\n    if isinstance(mode, str):\n        for m in SUPPORTED_MODES:\n            for n in SUPPORTED_SPATIAL_DIMS:\n                if inspect.isclass(m) and issubclass(m, BoxMode) and m.get_name(n) == mode:\n                    return m(*args, **kwargs)\n\n    if mode is not None:\n        raise ValueError(f\"Unsupported box mode: {mode}.\")\n    return StandardMode(*args, **kwargs)\n\n\ndef standardize_empty_box(boxes: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:\n    \"\"\"\n    When boxes are empty, this function standardizes them to a shape of (0,4) or (0,6).\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 or empty torch tensor or ndarray\n        spatial_dims: number of spatial dimensions of the bounding boxes.\n\n    Returns:\n        bounding boxes with shape (N,4) or (N,6), N can be 0.\n\n    Example:\n        .. 
code-block:: python\n\n            boxes = torch.ones(0,)\n            standardize_empty_box(boxes, 3)\n    \"\"\"\n    # convert numpy to tensor if needed\n    boxes_t, *_ = convert_data_type(boxes, torch.Tensor)\n    # handle empty box\n    if boxes_t.shape[0] == 0:\n        boxes_t = torch.reshape(boxes_t, [0, spatial_dims * 2])\n    # convert tensor back to numpy if needed\n    boxes_dst, *_ = convert_to_dst_type(src=boxes_t, dst=boxes)\n    return boxes_dst\n\n\ndef convert_box_mode(\n    boxes: NdarrayOrTensor,\n    src_mode: str | BoxMode | type[BoxMode] | None = None,\n    dst_mode: str | BoxMode | type[BoxMode] | None = None,\n) -> NdarrayOrTensor:\n    \"\"\"\n    This function converts the boxes in src_mode to the dst_mode.\n\n    Args:\n        boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray.\n        src_mode: source box mode. If it is not given, this func will assume it is ``StandardMode()``.\n            It follows the same format with ``mode`` in :func:`~monai.data.box_utils.get_boxmode`.\n        dst_mode: target box mode. If it is not given, this func will assume it is ``StandardMode()``.\n            It follows the same format with ``mode`` in :func:`~monai.data.box_utils.get_boxmode`.\n\n    Returns:\n        bounding boxes with target mode, with same data type as ``boxes``, does not share memory with ``boxes``\n\n    Example:\n        .. 
code-block:: python\n\n            boxes = torch.ones(10,4)\n            # The following three lines are equivalent\n            # They convert boxes with format [xmin, ymin, xmax, ymax] to [xcenter, ycenter, xsize, ysize].\n            convert_box_mode(boxes=boxes, src_mode=\"xyxy\", dst_mode=\"ccwh\")\n            convert_box_mode(boxes=boxes, src_mode=\"xyxy\", dst_mode=monai.data.box_utils.CenterSizeMode)\n            convert_box_mode(boxes=boxes, src_mode=\"xyxy\", dst_mode=monai.data.box_utils.CenterSizeMode())\n    \"\"\"\n    # handle empty box\n    if boxes.shape[0] == 0:\n        return boxes\n\n    src_boxmode = get_boxmode(src_mode)\n    dst_boxmode = get_boxmode(dst_mode)\n\n    # if mode not changed, deepcopy the original boxes\n    if isinstance(src_boxmode, type(dst_boxmode)):\n        return deepcopy(boxes)\n\n    # convert box mode\n    # convert numpy to tensor if needed\n    boxes_t, *_ = convert_data_type(boxes, torch.Tensor)\n\n    # convert boxes to corners\n    corners = src_boxmode.boxes_to_corners(boxes_t)\n\n    # check validity of corners\n    spatial_dims = get_spatial_dims(boxes=boxes_t)\n    for axis in range(spatial_dims):\n        if (corners[spatial_dims + axis] < corners[axis]).sum() > 0:\n            warnings.warn(\"Given boxes has invalid values. 
The box size must be non-negative.\")\n\n    # convert corners to boxes\n    boxes_t_dst = dst_boxmode.corners_to_boxes(corners)\n\n    # convert tensor back to numpy if needed\n    boxes_dst, *_ = convert_to_dst_type(src=boxes_t_dst, dst=boxes)\n    return boxes_dst\n\n\ndef convert_box_to_standard_mode(\n    boxes: NdarrayOrTensor, mode: str | BoxMode | type[BoxMode] | None = None\n) -> NdarrayOrTensor:\n    \"\"\"\n    Convert given boxes to standard mode.\n    Standard mode is \"xyxy\" or \"xyzxyz\",\n    representing box format of [xmin, ymin, xmax, ymax] or [xmin, ymin, zmin, xmax, ymax, zmax].\n\n    Args:\n        boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray.\n        mode: source box mode. If it is not given, this func will assume it is ``StandardMode()``.\n            It follows the same format with ``mode`` in :func:`~monai.data.box_utils.get_boxmode`.\n\n    Returns:\n        bounding boxes with standard mode, with same data type as ``boxes``, does not share memory with ``boxes``\n\n    Example:\n        .. code-block:: python\n\n            boxes = torch.ones(10,6)\n            # The following two lines are equivalent\n            # They convert boxes with format [xmin, xmax, ymin, ymax, zmin, zmax] to [xmin, ymin, zmin, xmax, ymax, zmax]\n            convert_box_to_standard_mode(boxes=boxes, mode=\"xxyyzz\")\n            convert_box_mode(boxes=boxes, src_mode=\"xxyyzz\", dst_mode=\"xyzxyz\")\n    \"\"\"\n    return convert_box_mode(boxes=boxes, src_mode=mode, dst_mode=StandardMode())\n\n\ndef box_centers(boxes: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"\n    Compute center points of boxes\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is assumed to be ``StandardMode``\n\n    Returns:\n        center points with size of (N, spatial_dims)\n\n    \"\"\"\n    spatial_dims = get_spatial_dims(boxes=boxes)\n    return convert_box_mode(boxes=boxes, src_mode=StandardMode, dst_mode=CenterSizeMode)[:, :spatial_dims]\n\n\ndef centers_in_boxes(centers: NdarrayOrTensor, boxes: NdarrayOrTensor, eps: float = 0.01) -> NdarrayOrTensor:\n    \"\"\"\n    Checks which center points are within boxes\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``.\n        centers: center points, Nx2 or Nx3 torch tensor or ndarray.\n        eps: minimum distance to border of boxes.\n\n    Returns:\n        boolean array indicating which center points are within the boxes, sized (N,).\n\n    Reference:\n        https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/ops.py\n\n    \"\"\"\n    spatial_dims = get_spatial_dims(boxes=boxes)\n\n    # compute relative position of centers compared to borders\n    # should be non-negative if centers are within boxes\n    center_to_border = [centers[:, axis] - boxes[:, axis] for axis in range(spatial_dims)] + [\n        boxes[:, axis + spatial_dims] - centers[:, axis] for axis in range(spatial_dims)\n    ]\n\n    if isinstance(boxes, np.ndarray):\n        min_center_to_border: np.ndarray = np.stack(center_to_border, axis=1).min(axis=1)\n        return min_center_to_border > eps  # array[bool]\n\n    return torch.stack(center_to_border, dim=1).to(COMPUTE_DTYPE).min(dim=1)[0] > eps  # type: ignore\n\n\ndef boxes_center_distance(\n    boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor, euclidean: bool = True\n) -> tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]:\n    \"\"\"\n    Distance of center points between two sets of boxes\n\n    Args:\n        boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is assumed to be ``StandardMode``\n        boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n        euclidean: computed the euclidean distance otherwise it uses the l1 distance\n\n    Returns:\n        - The pairwise distances for every element in boxes1 and boxes2,\n          with size of (N,M) and same data type as ``boxes1``.\n        - Center points of boxes1, with size of (N,spatial_dims) and same data type as ``boxes1``.\n        - Center points of boxes2, with size of (M,spatial_dims) and same data type as ``boxes1``.\n\n    Reference:\n        https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/ops.py\n\n    \"\"\"\n\n    if not isinstance(boxes1, type(boxes2)):\n        warnings.warn(f\"boxes1 is {type(boxes1)}, while boxes2 is {type(boxes2)}. The result will be {type(boxes1)}.\")\n\n    # convert numpy to tensor if needed\n    boxes1_t, *_ = convert_data_type(boxes1, torch.Tensor)\n    boxes2_t, *_ = convert_data_type(boxes2, torch.Tensor)\n\n    center1 = box_centers(boxes1_t.to(COMPUTE_DTYPE))  # (N, spatial_dims)\n    center2 = box_centers(boxes2_t.to(COMPUTE_DTYPE))  # (M, spatial_dims)\n\n    if euclidean:\n        dists = (center1[:, None] - center2[None]).pow(2).sum(-1).sqrt()  # type: ignore\n    else:\n        # before sum: (N, M, spatial_dims)\n        dists = (center1[:, None] - center2[None]).sum(-1)\n\n    # convert tensor back to numpy if needed\n    (dists, center1, center2), *_ = convert_to_dst_type(src=(dists, center1, center2), dst=boxes1)\n    return dists, center1, center2\n\n\ndef is_valid_box_values(boxes: NdarrayOrTensor) -> bool:\n    \"\"\"\n    This function checks whether the box size is non-negative.\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is assumed to be ``StandardMode``\n\n    Returns:\n        whether ``boxes`` is valid\n    \"\"\"\n    spatial_dims = get_spatial_dims(boxes=boxes)\n    for axis in range(spatial_dims):\n        if (boxes[:, spatial_dims + axis] < boxes[:, axis]).sum() > 0:\n            return False\n    return True\n\n\ndef box_area(boxes: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"\n    This function computes the area (2D) or volume (3D) of each box.\n    Half precision is not recommended for this function as it may cause overflow, especially for 3D images.\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n\n    Returns:\n        area (2D) or volume (3D) of boxes, with size of (N,).\n\n    Example:\n        .. code-block:: python\n\n            boxes = torch.ones(10,6)\n            # we do computation with torch.float32 to avoid overflow\n            compute_dtype = torch.float32\n            area = box_area(boxes=boxes.to(dtype=compute_dtype))  # torch.float32, size of (10,)\n    \"\"\"\n\n    if not is_valid_box_values(boxes):\n        raise ValueError(\"Given boxes has invalid values. The box size must be non-negative.\")\n\n    spatial_dims = get_spatial_dims(boxes=boxes)\n\n    area = boxes[:, spatial_dims] - boxes[:, 0] + TO_REMOVE\n    for axis in range(1, spatial_dims):\n        area = area * (boxes[:, axis + spatial_dims] - boxes[:, axis] + TO_REMOVE)\n\n    # convert numpy to tensor if needed\n    area_t, *_ = convert_data_type(area, torch.Tensor)\n\n    # check if NaN or Inf, especially for half precision\n    if area_t.isnan().any() or area_t.isinf().any():\n        if area_t.dtype is torch.float16:\n            raise ValueError(\"Box area is NaN or Inf. boxes is float16. 
Please change to float32 and test it again.\")\n        else:\n            raise ValueError(\"Box area is NaN or Inf.\")\n\n    return area\n\n\ndef _box_inter_union(\n    boxes1_t: torch.Tensor, boxes2_t: torch.Tensor, compute_dtype: torch.dtype = torch.float32\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    This internal function computes the intersection and union area of two set of boxes.\n\n    Args:\n        boxes1: bounding boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``\n        boxes2: bounding boxes, Mx4 or Mx6 torch tensor. The box mode is assumed to be ``StandardMode``\n        compute_dtype: default torch.float32, dtype with which the results will be computed\n\n    Returns:\n        inter, with size of (N,M) and dtype of ``compute_dtype``.\n        union, with size of (N,M) and dtype of ``compute_dtype``.\n\n    \"\"\"\n    spatial_dims = get_spatial_dims(boxes=boxes1_t)\n\n    # compute area with float32\n    area1 = box_area(boxes=boxes1_t.to(dtype=compute_dtype))  # (N,)\n    area2 = box_area(boxes=boxes2_t.to(dtype=compute_dtype))  # (M,)\n\n    # get the left top and right bottom points for the NxM combinations\n    lt = torch.max(boxes1_t[:, None, :spatial_dims], boxes2_t[:, :spatial_dims]).to(\n        dtype=compute_dtype\n    )  # (N,M,spatial_dims) left top\n    rb = torch.min(boxes1_t[:, None, spatial_dims:], boxes2_t[:, spatial_dims:]).to(\n        dtype=compute_dtype\n    )  # (N,M,spatial_dims) right bottom\n\n    # compute size for the intersection region for the NxM combinations\n    wh = (rb - lt + TO_REMOVE).clamp(min=0)  # (N,M,spatial_dims)\n    inter: torch.Tensor = torch.prod(wh, dim=-1, keepdim=False)  # (N,M)\n\n    union: torch.Tensor = area1[:, None] + area2 - inter  # type: ignore\n    return inter, union\n\n\ndef box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"\n    Compute the intersection over union (IoU) of two set of boxes.\n\n    Args:\n   
     boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n        boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n\n    Returns:\n        An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always\n        floating-point with size ``(N, M)``:\n        - if ``boxes1`` has a floating-point dtype, the same dtype is used.\n        - if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.\n\n    \"\"\"\n\n    if not isinstance(boxes1, type(boxes2)):\n        warnings.warn(f\"boxes1 is {type(boxes1)}, while boxes2 is {type(boxes2)}. The result will be {type(boxes1)}.\")\n\n    # convert numpy to tensor if needed\n    boxes1_t, *_ = convert_data_type(boxes1, torch.Tensor)\n    boxes2_t, *_ = convert_data_type(boxes2, torch.Tensor)\n\n    # we do computation with compute_dtype to avoid overflow\n    box_dtype = boxes1_t.dtype\n\n    inter, union = _box_inter_union(boxes1_t, boxes2_t, compute_dtype=COMPUTE_DTYPE)\n\n    # compute IoU and convert back to original box_dtype or torch.float32\n    iou_t = inter / (union + torch.finfo(COMPUTE_DTYPE).eps)  # (N,M)\n    if not box_dtype.is_floating_point:\n        box_dtype = COMPUTE_DTYPE\n    iou_t = iou_t.to(dtype=box_dtype)\n\n    # check if NaN or Inf\n    if torch.isnan(iou_t).any() or torch.isinf(iou_t).any():\n        raise ValueError(\"Box IoU is NaN or Inf.\")\n\n    # convert tensor back to numpy if needed\n    iou, *_ = convert_to_dst_type(src=iou_t, dst=boxes1, dtype=box_dtype)\n    return iou\n\n\ndef box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"\n    Compute the generalized intersection over union (GIoU) of two sets of boxes.\n    The two inputs can have different shapes and the func return an NxM matrix,\n    (in contrary to :func:`~monai.data.box_utils.box_pair_giou` , which requires the inputs 
to have the same\n    shape and returns ``N`` values).\n\n    Args:\n        boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n        boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n\n    Returns:\n        An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always\n        floating-point with size ``(N, M)``:\n        - if ``boxes1`` has a floating-point dtype, the same dtype is used.\n        - if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.\n\n    Reference:\n        https://giou.stanford.edu/GIoU.pdf\n\n    \"\"\"\n\n    if not isinstance(boxes1, type(boxes2)):\n        warnings.warn(f\"boxes1 is {type(boxes1)}, while boxes2 is {type(boxes2)}. The result will be {type(boxes1)}.\")\n\n    # convert numpy to tensor if needed\n    boxes1_t, *_ = convert_data_type(boxes1, torch.Tensor)\n    boxes2_t, *_ = convert_data_type(boxes2, torch.Tensor)\n\n    spatial_dims = get_spatial_dims(boxes=boxes1_t)\n\n    # we do computation with compute_dtype to avoid overflow\n    box_dtype = boxes1_t.dtype\n\n    inter, union = _box_inter_union(boxes1_t, boxes2_t, compute_dtype=COMPUTE_DTYPE)\n    iou = inter / (union + torch.finfo(COMPUTE_DTYPE).eps)  # (N,M)\n\n    # Enclosure\n    # get the left top and right bottom points for the NxM combinations\n    lt = torch.min(boxes1_t[:, None, :spatial_dims], boxes2_t[:, :spatial_dims]).to(\n        dtype=COMPUTE_DTYPE\n    )  # (N,M,spatial_dims) left top\n    rb = torch.max(boxes1_t[:, None, spatial_dims:], boxes2_t[:, spatial_dims:]).to(\n        dtype=COMPUTE_DTYPE\n    )  # (N,M,spatial_dims) right bottom\n\n    # compute size for the enclosure region for the NxM combinations\n    wh = (rb - lt + TO_REMOVE).clamp(min=0)  # (N,M,spatial_dims)\n    enclosure = torch.prod(wh, dim=-1, keepdim=False)  # (N,M)\n\n    # GIoU\n    giou_t = iou - 
(enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps)\n    if not box_dtype.is_floating_point:\n        box_dtype = COMPUTE_DTYPE\n    giou_t = giou_t.to(dtype=box_dtype)\n\n    if torch.isnan(giou_t).any() or torch.isinf(giou_t).any():\n        raise ValueError(\"Box GIoU is NaN or Inf.\")\n\n    # convert tensor back to numpy if needed\n    giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1, dtype=box_dtype)\n    return giou\n\n\ndef box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"\n    Compute the generalized intersection over union (GIoU) of a pair of boxes.\n    The two inputs should have the same shape and the func return an (N,) array,\n    (in contrary to :func:`~monai.data.box_utils.box_giou` , which does not require the inputs to have the same\n    shape and returns ``NxM`` matrix).\n\n    Args:\n        boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n        boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be ``StandardMode``\n\n    Returns:\n        An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always\n        floating-point with size ``(N, )``:\n        - if ``boxes1`` has a floating-point dtype, the same dtype is used.\n        - if ``boxes1`` has an integer dtype, the result is returned as ``torch.float32``.\n\n    Reference:\n        https://giou.stanford.edu/GIoU.pdf\n\n    \"\"\"\n\n    if not isinstance(boxes1, type(boxes2)):\n        warnings.warn(f\"boxes1 is {type(boxes1)}, while boxes2 is {type(boxes2)}. 
The result will be {type(boxes1)}.\")\n\n    # convert numpy to tensor if needed\n    boxes1_t, *_ = convert_data_type(boxes1, torch.Tensor)\n    boxes2_t, *_ = convert_data_type(boxes2, torch.Tensor)\n\n    if boxes1_t.shape != boxes2_t.shape:\n        raise ValueError(\"boxes1 and boxes2 should be paired and have same shape.\")\n\n    spatial_dims = get_spatial_dims(boxes=boxes1_t)\n\n    # we do computation with compute_dtype to avoid overflow\n    box_dtype = boxes1_t.dtype\n\n    # compute area\n    area1 = box_area(boxes=boxes1_t.to(dtype=COMPUTE_DTYPE))  # (N,)\n    area2 = box_area(boxes=boxes2_t.to(dtype=COMPUTE_DTYPE))  # (N,)\n\n    # Intersection\n    # get the left top and right bottom points for the boxes pair\n    lt = torch.max(boxes1_t[:, :spatial_dims], boxes2_t[:, :spatial_dims]).to(\n        dtype=COMPUTE_DTYPE\n    )  # (N,spatial_dims) left top\n    rb = torch.min(boxes1_t[:, spatial_dims:], boxes2_t[:, spatial_dims:]).to(\n        dtype=COMPUTE_DTYPE\n    )  # (N,spatial_dims) right bottom\n\n    # compute size for the intersection region for the boxes pair\n    wh = (rb - lt + TO_REMOVE).clamp(min=0)  # (N,spatial_dims)\n    inter = torch.prod(wh, dim=-1, keepdim=False)  # (N,)\n\n    # compute IoU and convert back to original box_dtype\n    union = area1 + area2 - inter\n    iou = inter / (union + torch.finfo(COMPUTE_DTYPE).eps)  # (N,)\n\n    # Enclosure\n    # get the left top and right bottom points for the boxes pair\n    lt = torch.min(boxes1_t[:, :spatial_dims], boxes2_t[:, :spatial_dims]).to(\n        dtype=COMPUTE_DTYPE\n    )  # (N,spatial_dims) left top\n    rb = torch.max(boxes1_t[:, spatial_dims:], boxes2_t[:, spatial_dims:]).to(\n        dtype=COMPUTE_DTYPE\n    )  # (N,spatial_dims) right bottom\n\n    # compute size for the enclose region for the boxes pair\n    wh = (rb - lt + TO_REMOVE).clamp(min=0)  # (N,spatial_dims)\n    enclosure = torch.prod(wh, dim=-1, keepdim=False)  # (N,)\n\n    giou_t: torch.Tensor = iou - 
(enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps)  # type: ignore\n    if not box_dtype.is_floating_point:\n        box_dtype = COMPUTE_DTYPE\n    giou_t = giou_t.to(dtype=box_dtype)  # (N,spatial_dims)\n\n    if torch.isnan(giou_t).any() or torch.isinf(giou_t).any():\n        raise ValueError(\"Box GIoU is NaN or Inf.\")\n\n    # convert tensor back to numpy if needed\n    giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1, dtype=box_dtype)\n    return giou\n\n\ndef spatial_crop_boxes(\n    boxes: NdarrayTensor,\n    roi_start: Sequence[int] | NdarrayOrTensor,\n    roi_end: Sequence[int] | NdarrayOrTensor,\n    remove_empty: bool = True,\n) -> tuple[NdarrayTensor, NdarrayOrTensor]:\n    \"\"\"\n    This function generate the new boxes when the corresponding image is cropped to the given ROI.\n    When ``remove_empty=True``, it makes sure the bounding boxes are within the new cropped image.\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. 
The box mode is assumed to be ``StandardMode``\n        roi_start: voxel coordinates for start of the crop ROI, negative values allowed.\n        roi_end: voxel coordinates for end of the crop ROI, negative values allowed.\n        remove_empty: whether to remove the boxes that are actually empty\n\n    Returns:\n        - cropped boxes, boxes[keep], does not share memory with original boxes\n        - ``keep``, it indicates whether each box in ``boxes`` are kept when ``remove_empty=True``.\n    \"\"\"\n\n    # convert numpy to tensor if needed\n    boxes_t = convert_data_type(boxes, torch.Tensor)[0].clone()\n\n    # convert to float32 since torch.clamp_ does not support float16\n    boxes_t = boxes_t.to(dtype=COMPUTE_DTYPE)\n\n    roi_start_t = convert_to_dst_type(src=roi_start, dst=boxes_t, wrap_sequence=True)[0].to(torch.int16)\n    roi_end_t = convert_to_dst_type(src=roi_end, dst=boxes_t, wrap_sequence=True)[0].to(torch.int16)\n    roi_end_t = torch.maximum(roi_end_t, roi_start_t)\n\n    # makes sure the bounding boxes are within the patch\n    spatial_dims = get_spatial_dims(boxes=boxes, spatial_size=roi_end)\n    for axis in range(spatial_dims):\n        boxes_t[:, axis] = boxes_t[:, axis].clamp(min=roi_start_t[axis], max=roi_end_t[axis] - TO_REMOVE)\n        boxes_t[:, axis + spatial_dims] = boxes_t[:, axis + spatial_dims].clamp(\n            min=roi_start_t[axis], max=roi_end_t[axis] - TO_REMOVE\n        )\n        boxes_t[:, axis] -= roi_start_t[axis]\n        boxes_t[:, axis + spatial_dims] -= roi_start_t[axis]\n\n    # remove the boxes that are actually empty\n    if remove_empty:\n        keep_t = boxes_t[:, spatial_dims] >= boxes_t[:, 0] + 1 - TO_REMOVE\n        for axis in range(1, spatial_dims):\n            keep_t = keep_t & (boxes_t[:, axis + spatial_dims] >= boxes_t[:, axis] + 1 - TO_REMOVE)\n        boxes_t = boxes_t[keep_t]\n    else:\n        keep_t = torch.full_like(boxes_t[:, 0], fill_value=True, dtype=torch.bool)\n\n    # convert tensor back 
to numpy if needed\n    boxes_keep, *_ = convert_to_dst_type(src=boxes_t, dst=boxes)\n    keep, *_ = convert_to_dst_type(src=keep_t, dst=boxes, dtype=keep_t.dtype)\n\n    return boxes_keep, keep\n\n\ndef clip_boxes_to_image(\n    boxes: NdarrayOrTensor, spatial_size: Sequence[int] | NdarrayOrTensor, remove_empty: bool = True\n) -> tuple[NdarrayOrTensor, NdarrayOrTensor]:\n    \"\"\"\n    This function clips the ``boxes`` to makes sure the bounding boxes are within the image.\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n        spatial_size: The spatial size of the image where the boxes are attached. len(spatial_size) should be in [2, 3].\n        remove_empty: whether to remove the boxes that are actually empty\n\n    Returns:\n        - clipped boxes, boxes[keep], does not share memory with original boxes\n        - ``keep``, it indicates whether each box in ``boxes`` are kept when ``remove_empty=True``.\n    \"\"\"\n    spatial_dims = get_spatial_dims(boxes=boxes, spatial_size=spatial_size)\n    return spatial_crop_boxes(boxes, roi_start=[0] * spatial_dims, roi_end=spatial_size, remove_empty=remove_empty)\n\n\ndef non_max_suppression(\n    boxes: NdarrayOrTensor,\n    scores: NdarrayOrTensor,\n    nms_thresh: float,\n    max_proposals: int = -1,\n    box_overlap_metric: Callable = box_iou,\n) -> NdarrayOrTensor:\n    \"\"\"\n    Non-maximum suppression (NMS).\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n        scores: prediction scores of the boxes, sized (N,). This function keeps boxes with higher scores.\n        nms_thresh: threshold of NMS. 
Discards all overlapping boxes with box_overlap > nms_thresh.\n        max_proposals: maximum number of boxes it keeps.\n            If ``max_proposals`` = -1, there is no limit on the number of boxes that are kept.\n        box_overlap_metric: the metric to compute overlap between boxes.\n\n    Returns:\n        Indexes of ``boxes`` that are kept after NMS.\n\n    Example:\n        .. code-block:: python\n\n            boxes = torch.ones(10,6)\n            scores = torch.ones(10)\n            keep = non_max_suppression(boxes, scores, num_thresh=0.1)\n            boxes_after_nms = boxes[keep]\n    \"\"\"\n\n    # returns empty array if boxes is empty\n    if boxes.shape[0] == 0:\n        return convert_to_dst_type(src=np.array([]), dst=boxes, dtype=torch.long)[0]\n\n    if boxes.shape[0] != scores.shape[0]:\n        raise ValueError(\n            f\"boxes and scores should have same length, got boxes shape {boxes.shape}, scores shape {scores.shape}\"\n        )\n\n    # convert numpy to tensor if needed\n    boxes_t, *_ = convert_data_type(boxes, torch.Tensor)\n    scores_t, *_ = convert_to_dst_type(scores, boxes_t)\n\n    # sort boxes in descending order according to the scores\n    # use stable=True to ensure deterministic ordering when scores are equal\n    sort_idxs = torch.argsort(scores_t, dim=0, descending=True, stable=True)\n    boxes_sort = deepcopy(boxes_t)[sort_idxs, :]\n\n    # initialize the list of picked indexes\n    pick = []\n    idxs = torch.Tensor(list(range(boxes_sort.shape[0]))).to(device=boxes_t.device, dtype=torch.long)\n\n    # keep looping while some indexes still remain in the indexes list\n    while len(idxs) > 0:\n        # pick the first index in the indexes list and add the index value to the list of picked indexes\n        i = int(idxs[0].item())\n        pick.append(i)\n        if len(pick) >= max_proposals >= 1:\n            break\n\n        # compute the IoU between the rest of the boxes and the box just picked\n        box_overlap 
= box_overlap_metric(boxes_sort[idxs, :], boxes_sort[i : i + 1, :])\n\n        # keep only indexes from the index list that have overlap < nms_thresh\n        to_keep_idx = (box_overlap <= nms_thresh).flatten()\n        to_keep_idx[0] = False  # always remove idxs[0]\n        idxs = idxs[to_keep_idx]\n\n    # return only the bounding boxes that were picked using the integer data type\n    pick_idx = sort_idxs[pick]\n\n    # convert tensor back to numpy if needed\n    return convert_to_dst_type(src=pick_idx, dst=boxes, dtype=pick_idx.dtype)[0]\n\n\ndef batched_nms(\n    boxes: NdarrayOrTensor,\n    scores: NdarrayOrTensor,\n    labels: NdarrayOrTensor,\n    nms_thresh: float,\n    max_proposals: int = -1,\n    box_overlap_metric: Callable = box_iou,\n) -> NdarrayOrTensor:\n    \"\"\"\n    Performs non-maximum suppression in a batched fashion.\n    Each labels value correspond to a category, and NMS will not be applied between elements of different categories.\n\n    Adapted from https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/nms.py\n\n    Args:\n        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``\n        scores: prediction scores of the boxes, sized (N,). This function keeps boxes with higher scores.\n        labels: indices of the categories for each one of the boxes. sized(N,), value range is (0, num_classes)\n        nms_thresh: threshold of NMS. 
Discards all overlapping boxes with box_overlap > nms_thresh.\n        max_proposals: maximum number of boxes it keeps.\n            If ``max_proposals`` = -1, there is no limit on the number of boxes that are kept.\n        box_overlap_metric: the metric to compute overlap between boxes.\n\n    Returns:\n        Indexes of ``boxes`` that are kept after NMS.\n    \"\"\"\n    # returns empty array if boxes is empty\n    if boxes.shape[0] == 0:\n        return convert_to_dst_type(src=np.array([]), dst=boxes, dtype=torch.long)[0]\n\n    # convert numpy to tensor if needed\n    boxes_t, *_ = convert_data_type(boxes, torch.Tensor, dtype=torch.float32)\n    scores_t, *_ = convert_to_dst_type(scores, boxes_t)\n    labels_t, *_ = convert_to_dst_type(labels, boxes_t, dtype=torch.long)\n\n    # strategy: in order to perform NMS independently per class.\n    # we add an offset to all the boxes. The offset is dependent\n    # only on the class idx, and is large enough so that boxes\n    # from different classes do not overlap\n    max_coordinate = boxes_t.max()\n    offsets = labels_t.to(boxes_t) * (max_coordinate + 1)\n    boxes_for_nms = boxes + offsets[:, None]\n    keep = non_max_suppression(boxes_for_nms, scores_t, nms_thresh, max_proposals, box_overlap_metric)\n\n    # convert tensor back to numpy if needed\n    return convert_to_dst_type(src=keep, dst=boxes, dtype=keep.dtype)[0]\n"
  },
  {
    "path": "monai/data/csv_saver.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport warnings\nfrom collections import OrderedDict\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\n\nfrom monai.config.type_definitions import PathLike\nfrom monai.utils import ImageMetaKey as Key\n\n\nclass CSVSaver:\n    \"\"\"\n    Save the data in a dictionary format cache, and write to a CSV file finally.\n    Typically, the data can be classification predictions, call `save` for single data\n    or call `save_batch` to save a batch of data together, and call `finalize` to write\n    the cached data into CSV file. 
If no metadata provided, use index from 0 to save data.\n    Note that this saver can't support multi-processing because it reads / writes single\n    CSV file and can't guarantee the data order in multi-processing situation.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        output_dir: PathLike = \"./\",\n        filename: str = \"predictions.csv\",\n        overwrite: bool = True,\n        flush: bool = False,\n        delimiter: str = \",\",\n    ) -> None:\n        \"\"\"\n        Args:\n            output_dir: output CSV file directory.\n            filename: name of the saved CSV file name.\n            overwrite: whether to overwriting existing CSV file content, if True, will clear the file before saving.\n                otherwise, will append new content to the CSV file.\n            flush: whether to write the cache data to CSV file immediately when `save_batch` and clear the cache.\n                default to False.\n            delimiter: the delimiter character in the saved file, default to \",\" as the default output type is `csv`.\n                to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.\n\n        \"\"\"\n        self.output_dir = Path(output_dir)\n        self._cache_dict: OrderedDict = OrderedDict()\n        if not (isinstance(filename, str) and filename[-4:] == \".csv\"):\n            warnings.warn(\"CSV filename is not a string ends with '.csv'.\")\n        self._filepath = self.output_dir / filename\n        if self._filepath.exists() and overwrite:\n            os.remove(self._filepath)\n\n        self.flush = flush\n        self.delimiter = delimiter\n        self._data_index = 0\n\n    def finalize(self) -> None:\n        \"\"\"\n        Writes the cached dict to a csv\n\n        \"\"\"\n        if not self.output_dir.exists():\n            self.output_dir.mkdir(parents=True, exist_ok=True)\n        with open(self._filepath, \"a\") as f:\n            for k, v in 
self._cache_dict.items():\n                f.write(k)\n                for result in v.flatten():\n                    f.write(self.delimiter + str(result))\n                f.write(\"\\n\")\n        # clear cache content after writing\n        self.reset_cache()\n\n    def save(self, data: torch.Tensor | np.ndarray, meta_data: dict | None = None) -> None:\n        \"\"\"Save data into the cache dictionary. The metadata should have the following key:\n            - ``'filename_or_obj'`` -- save the data corresponding to file name or object.\n        If meta_data is None, use the default index from 0 to save data instead.\n\n        Args:\n            data: target data content that save into cache.\n            meta_data: the metadata information corresponding to the data.\n\n        \"\"\"\n        save_key = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)\n        self._data_index += 1\n        if isinstance(data, torch.Tensor):\n            data = data.detach().cpu().numpy()\n        self._cache_dict[save_key] = np.asarray(data, dtype=float)\n\n    def save_batch(self, batch_data: torch.Tensor | np.ndarray, meta_data: dict | None = None) -> None:\n        \"\"\"Save a batch of data into the cache dictionary.\n\n        Args:\n            batch_data: target batch data content that save into cache.\n            meta_data: every key-value in the meta_data is corresponding to 1 batch of data.\n\n        \"\"\"\n        for i, data in enumerate(batch_data):  # save a batch of files\n            self.save(data, {k: meta_data[k][i] for k in meta_data} if meta_data else None)\n\n        if self.flush:\n            self.finalize()\n\n    def get_cache(self) -> OrderedDict:\n        \"\"\"Get the cache dictionary, key is filename and value is the corresponding data\"\"\"\n\n        return self._cache_dict\n\n    def reset_cache(self) -> None:\n        \"\"\"Clear the cache dictionary content\"\"\"\n        self._cache_dict.clear()\n"
  },
  {
    "path": "monai/data/dataloader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\n\nimport torch\nfrom torch.utils.data import DataLoader as _TorchDataLoader\nfrom torch.utils.data import Dataset\n\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.data.utils import list_data_collate, set_rnd, worker_init_fn\n\n__all__ = [\"DataLoader\"]\n\n\nclass DataLoader(_TorchDataLoader):\n    \"\"\"\n    Provides an iterable over the given `dataset`.  It inherits the PyTorch\n    DataLoader and adds enhanced `collate_fn` and `worker_fn` by default.\n\n    Although this class could be configured to be the same as\n    `torch.utils.data.DataLoader`, its default configuration is\n    recommended, mainly for the following extra features:\n\n        - It handles MONAI randomizable objects with appropriate random state\n          managements for deterministic behaviour.\n        - It is aware of the patch-based transform (such as\n          :py:class:`monai.transforms.RandSpatialCropSamplesDict`) samples for\n          preprocessing with enhanced data collating behaviour.\n          See: :py:class:`monai.transforms.Compose`.\n\n    For more details about :py:class:`torch.utils.data.DataLoader`, please see:\n    https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader.\n\n    For example, to construct a randomized dataset and iterate with the data loader:\n\n    .. 
code-block:: python\n\n        import torch\n\n        from monai.data import DataLoader\n        from monai.transforms import Randomizable\n\n\n        class RandomDataset(torch.utils.data.Dataset, Randomizable):\n            def __getitem__(self, index):\n                return self.R.randint(0, 1000, (1,))\n\n            def __len__(self):\n                return 16\n\n\n        dataset = RandomDataset()\n        dataloader = DataLoader(dataset, batch_size=2, num_workers=4)\n        for epoch in range(2):\n            for i, batch in enumerate(dataloader):\n                print(epoch, i, batch.data.numpy().flatten().tolist())\n\n    Args:\n        dataset: dataset from which to load the data.\n        num_workers: how many subprocesses to use for data\n            loading. ``0`` means that the data will be loaded in the main process.\n            (default: ``0``)\n        collate_fn: default to :py:func:`monai.data.utils.list_data_collate`.\n        worker_init_fn: default to :py:func:`monai.data.utils.worker_init_fn`.\n        kwargs: other parameters for PyTorch DataLoader.\n    \"\"\"\n\n    def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:\n        if num_workers == 0:\n            # when num_workers > 0, random states are determined by worker_init_fn\n            # this is to make the behavior consistent when num_workers == 0\n            # torch.int64 doesn't work well on some versions of windows\n            _g = torch.random.default_generator if kwargs.get(\"generator\") is None else kwargs[\"generator\"]\n            init_seed = _g.initial_seed()\n            _seed = torch.empty((), dtype=torch.int64).random_(generator=_g).item()\n            set_rnd(dataset, int(_seed))\n            _g.manual_seed(init_seed)\n        if \"collate_fn\" not in kwargs:\n            kwargs[\"collate_fn\"] = list_data_collate\n        if \"worker_init_fn\" not in kwargs:\n            kwargs[\"worker_init_fn\"] = worker_init_fn\n\n        if (\n  
          \"multiprocessing_context\" in kwargs\n            and kwargs[\"multiprocessing_context\"] == \"spawn\"\n            and not get_track_meta()\n        ):\n            warnings.warn(\n                \"Please be aware: Return type of the dataloader will not be a Tensor as expected but\"\n                \" a MetaTensor instead! This is because 'spawn' creates a new process where _TRACK_META\"\n                \" is initialized to True again. Context:_TRACK_META is set to False and\"\n                \" multiprocessing_context to spawn\"\n            )\n\n        super().__init__(dataset=dataset, num_workers=num_workers, **kwargs)\n"
  },
  {
    "path": "monai/data/dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport collections.abc\nimport math\nimport shutil\nimport sys\nimport tempfile\nimport threading\nimport time\nimport warnings\nfrom collections.abc import Callable, Sequence\nfrom copy import copy, deepcopy\nfrom io import BytesIO\nfrom multiprocessing.managers import ListProxy\nfrom multiprocessing.pool import ThreadPool\nfrom pathlib import Path\nfrom pickle import UnpicklingError\nfrom typing import IO, TYPE_CHECKING, Any, cast\n\nimport numpy as np\nimport torch\nfrom torch.multiprocessing import Manager\nfrom torch.serialization import DEFAULT_PROTOCOL\nfrom torch.utils.data import Dataset as _TorchDataset\nfrom torch.utils.data import Subset\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing\nfrom monai.transforms import Compose, Randomizable, RandomizableTrait, Transform, convert_to_contiguous, reset_ops_id\nfrom monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import\nfrom monai.utils.misc import first\n\nif TYPE_CHECKING:\n    from tqdm import tqdm\n\n    has_tqdm = True\nelse:\n    tqdm, has_tqdm = optional_import(\"tqdm\", \"4.47.0\", min_version, \"tqdm\")\n\ncp, _ = optional_import(\"cupy\")\nlmdb, _ = optional_import(\"lmdb\")\npd, _ = optional_import(\"pandas\")\nkvikio_numpy, _ = 
optional_import(\"kvikio.numpy\")\n\n\nclass Dataset(_TorchDataset):\n    \"\"\"\n    A generic dataset with a length property and an optional callable data transform\n    when fetching a data sample.\n    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,\n    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset\n\n    For example, typical input data can be a list of dictionaries::\n\n        [{                            {                            {\n             'img': 'image1.nii.gz',      'img': 'image2.nii.gz',      'img': 'image3.nii.gz',\n             'seg': 'label1.nii.gz',      'seg': 'label2.nii.gz',      'seg': 'label3.nii.gz',\n             'extra': 123                 'extra': 456                 'extra': 789\n         },                           },                           }]\n    \"\"\"\n\n    def __init__(self, data: Sequence, transform: Sequence[Callable] | Callable | None = None) -> None:\n        \"\"\"\n        Args:\n            data: input data to load and transform to generate dataset for model.\n            transform: a callable, sequence of callables or None. If transform is not\n            a `Compose` instance, it will be wrapped in a `Compose` instance. 
Sequences\n            of callables are applied in order and if `None` is passed, the data is returned as is.\n        \"\"\"\n        self.data = data\n        try:\n            self.transform = Compose(transform) if not isinstance(transform, Compose) else transform\n        except Exception as e:\n            raise ValueError(\"`transform` must be a callable or a list of callables that is Composable\") from e\n\n    def __len__(self) -> int:\n        return len(self.data)\n\n    def _transform(self, index: int):\n        \"\"\"\n        Fetch single data item from `self.data`.\n        \"\"\"\n        data_i = self.data[index]\n        return self.transform(data_i)\n\n    def __getitem__(self, index: int | slice | Sequence[int]):\n        \"\"\"\n        Returns a `Subset` if `index` is a slice or Sequence, a data item otherwise.\n        \"\"\"\n        if isinstance(index, slice):\n            # dataset[:42]\n            start, stop, step = index.indices(len(self))\n            indices = range(start, stop, step)\n            return Subset(dataset=self, indices=indices)\n        if isinstance(index, collections.abc.Sequence):\n            # dataset[[1, 3, 4]]\n            return Subset(dataset=self, indices=index)\n        return self._transform(index)\n\n\nclass DatasetFunc(Dataset):\n    \"\"\"\n    Execute function on the input dataset and leverage the output to act as a new Dataset.\n    It can be used to load / fetch the basic dataset items, like the list of `image, label` paths.\n    Or chain together to execute more complicated logic, like `partition_dataset`, `resample_datalist`, etc.\n    The `data` arg of `Dataset` will be applied to the first arg of callable `func`.\n    Usage example::\n\n        data_list = DatasetFunc(\n            data=\"path to file\",\n            func=monai.data.load_decathlon_datalist,\n            data_list_key=\"validation\",\n            base_dir=\"path to base dir\",\n        )\n        # partition dataset for every rank\n 
       data_partition = DatasetFunc(\n            data=data_list,\n            func=lambda **kwargs: monai.data.partition_dataset(**kwargs)[torch.distributed.get_rank()],\n            num_partitions=torch.distributed.get_world_size(),\n        )\n        dataset = Dataset(data=data_partition, transform=transforms)\n\n    Args:\n        data: input data for the func to process, will apply to `func` as the first arg.\n        func: callable function to generate dataset items.\n        kwargs: other arguments for the `func` except for the first arg.\n\n    \"\"\"\n\n    def __init__(self, data: Any, func: Callable, **kwargs) -> None:\n        super().__init__(data=None, transform=None)  # type:ignore\n        self.src = data\n        self.func = func\n        self.kwargs = kwargs\n        self.reset()\n\n    def reset(self, data: Any | None = None, func: Callable | None = None, **kwargs):\n        \"\"\"\n        Reset the dataset items with specified `func`.\n\n        Args:\n            data: if not None, execute `func` on it, default to `self.src`.\n            func: if not None, execute the `func` with specified `kwargs`, default to `self.func`.\n            kwargs: other arguments for the `func` except for the first arg.\n\n        \"\"\"\n        src = self.src if data is None else data\n        self.data = self.func(src, **self.kwargs) if func is None else func(src, **kwargs)\n\n\nclass PersistentDataset(Dataset):\n    \"\"\"\n    Persistent storage of pre-computed values to efficiently manage larger than memory dictionary format data,\n    it can operate transforms for specific fields.  
Results from the non-random transform components are computed\n    when first used, and stored in the `cache_dir` for rapid retrieval on subsequent uses.\n    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,\n    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset\n\n    The transforms which are supposed to be cached must implement the `monai.transforms.Transform`\n    interface and should not be `Randomizable`. This dataset will cache the outcomes before the first\n    `Randomizable` `Transform` within a `Compose` instance.\n\n    For example, typical input data can be a list of dictionaries::\n\n        [{                            {                            {\n            'image': 'image1.nii.gz',    'image': 'image2.nii.gz',    'image': 'image3.nii.gz',\n            'label': 'label1.nii.gz',    'label': 'label2.nii.gz',    'label': 'label3.nii.gz',\n            'extra': 123                 'extra': 456                 'extra': 789\n        },                           },                           }]\n\n    For a composite transform like\n\n    .. 
code-block:: python\n\n        [ LoadImaged(keys=['image', 'label']),\n        Orientationd(keys=['image', 'label'], axcodes='RAS'),\n        ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n        RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', spatial_size=(96, 96, 96),\n                                pos=1, neg=1, num_samples=4, image_key='image', image_threshold=0),\n        ToTensord(keys=['image', 'label'])]\n\n    Upon first use a filename based dataset will be processed by the transform for the\n    [LoadImaged, Orientationd, ScaleIntensityRanged] and the resulting tensor written to\n    the `cache_dir` before applying the remaining random dependant transforms\n    [RandCropByPosNegLabeld, ToTensord] elements for use in the analysis.\n\n    Subsequent uses of a dataset directly read pre-processed results from `cache_dir`\n    followed by applying the random dependant parts of transform processing.\n\n    During training call `set_data()` to update input data and recompute cache content.\n\n    Note:\n        The input data must be a list of file paths and will hash them as cache keys.\n\n        The filenames of the cached files also try to contain the hash of the transforms. In this\n        fashion, `PersistentDataset` should be robust to changes in transforms. This, however, is\n        not guaranteed, so caution should be used when modifying transforms to avoid unexpected\n        errors. If in doubt, it is advisable to clear the cache directory.\n\n        Cached data is expected to be tensors, primitives, or dictionaries keying to these values. 
Numpy arrays will\n        be converted to tensors, however any other object type returned by transforms will not be loadable since\n        `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.\n        Legacy cache files may not be loadable and may need to be recomputed.\n\n    Lazy Resampling:\n        If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to\n        its documentation to familiarize yourself with the interaction between `PersistentDataset` and\n        lazy resampling.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        data: Sequence,\n        transform: Sequence[Callable] | Callable,\n        cache_dir: Path | str | None,\n        hash_func: Callable[..., bytes] = pickle_hashing,\n        pickle_module: str = \"pickle\",\n        pickle_protocol: int = DEFAULT_PROTOCOL,\n        hash_transform: Callable[..., bytes] | None = None,\n        reset_ops_id: bool = True,\n        track_meta: bool = False,\n        weights_only: bool = True,\n    ) -> None:\n        \"\"\"\n        Args:\n            data: input data file paths to load and transform to generate dataset for model.\n                `PersistentDataset` expects input data to be a list of serializable\n                and hashes them as cache keys using `hash_func`.\n            transform: transforms to execute operations on input data.\n            cache_dir: If specified, this is the location for persistent storage\n                of pre-computed transformed data tensors. The cache_dir is computed once, and\n                persists on disk until explicitly removed.  
Different runs, programs, experiments\n                may share a common cache dir provided that the transforms pre-processing is consistent.\n                If `cache_dir` doesn't exist, will automatically create it.\n                If `cache_dir` is `None`, there is effectively no caching.\n            hash_func: a callable to compute hash from data items to be cached.\n                defaults to `monai.data.utils.pickle_hashing`.\n            pickle_module: string representing the module used for pickling metadata and objects,\n                default to `\"pickle\"`. due to the pickle limitation in multi-processing of Dataloader,\n                we can't use `pickle` as arg directly, so here we use a string name instead.\n                if want to use other pickle module at runtime, just register like:\n                >>> from monai.data import utils\n                >>> utils.SUPPORTED_PICKLE_MOD[\"test\"] = other_pickle\n                this arg is used by `torch.save`, for more details, please check:\n                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,\n                and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.\n            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.\n                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:\n                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.\n            hash_transform: a callable to compute hash from the transform information when caching.\n                This may reduce errors due to transforms changing during experiments. 
Default to None (no hash).\n                Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.\n            reset_ops_id: whether to set `TraceKeys.ID` to ``Tracekys.NONE``, defaults to ``True``.\n                When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.\n                This is useful for skipping the transform instance checks when inverting applied operations\n                using the cached content and with re-created transform instances.\n            track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.\n                default to `False`. Cannot be used with `weights_only=True`.\n            weights_only: keyword argument passed to `torch.load` when reading cached files.\n                default to `True`. When set to `True`, `torch.load` restricts loading to tensors and\n                other safe objects. Setting this to `False` is required for loading `MetaTensor`\n                objects saved with `track_meta=True`, however this creates the possibility of remote\n                code execution through `torch.load` so be aware of the security implications of doing so.\n\n        Raises:\n            ValueError: When both `track_meta=True` and `weights_only=True`, since this combination\n                prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration.\n        \"\"\"\n        super().__init__(data=data, transform=transform)\n        self.cache_dir = Path(cache_dir) if cache_dir is not None else None\n        self.hash_func = hash_func\n        self.pickle_module = pickle_module\n        self.pickle_protocol = pickle_protocol\n        if self.cache_dir is not None:\n            if not self.cache_dir.exists():\n                self.cache_dir.mkdir(parents=True, exist_ok=True)\n            if not self.cache_dir.is_dir():\n                raise ValueError(\"cache_dir must be a 
directory.\")\n        self.transform_hash: str = \"\"\n        if hash_transform is not None:\n            self.set_transform_hash(hash_transform)\n        self.reset_ops_id = reset_ops_id\n        if track_meta and weights_only:\n            raise ValueError(\n                \"Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. \"\n                \"To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`.\"\n            )\n        self.track_meta = track_meta\n        self.weights_only = weights_only\n\n    def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):\n        \"\"\"Get hashable transforms, and then hash them. Hashable transforms\n        are deterministic transforms that inherit from `Transform`. We stop\n        at the first non-deterministic transform, or first that does not\n        inherit from MONAI's `Transform` class.\"\"\"\n        hashable_transforms = []\n        for _tr in self.transform.flatten().transforms:\n            if isinstance(_tr, RandomizableTrait) or not isinstance(_tr, Transform):\n                break\n            hashable_transforms.append(_tr)\n        # Try to hash. 
Fall back to a hash of their names\n        try:\n            transform_hash = hash_xform_func(hashable_transforms)\n        except TypeError as te:\n            if \"is not JSON serializable\" not in str(te):\n                raise te\n            names = \"\".join(tr.__class__.__name__ for tr in hashable_transforms)\n            transform_hash = hash_xform_func(names)\n        self.transform_hash = transform_hash.decode(\"utf-8\")\n\n    def set_data(self, data: Sequence):\n        \"\"\"\n        Set the input data and delete all the out-dated cache content.\n\n        \"\"\"\n        self.data = data\n        if self.cache_dir is not None and self.cache_dir.exists():\n            shutil.rmtree(self.cache_dir, ignore_errors=True)\n            self.cache_dir.mkdir(parents=True, exist_ok=True)\n\n    def _pre_transform(self, item_transformed):\n        \"\"\"\n        Process the data from original state up to the first random element.\n\n        Args:\n            item_transformed: The data to be transformed\n\n        Returns:\n            the transformed element up to the first identified\n            random transform object\n\n        \"\"\"\n        first_random = self.transform.get_index_of_first(\n            lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)\n        )\n        item_transformed = self.transform(item_transformed, end=first_random, threading=True)\n\n        if self.reset_ops_id:\n            reset_ops_id(item_transformed)\n        return item_transformed\n\n    def _post_transform(self, item_transformed):\n        \"\"\"\n        Process the data from before the first random transform to the final state ready for evaluation.\n\n        Args:\n            item_transformed: The data to be transformed (already processed up to the first random transform)\n\n        Returns:\n            the transformed element through the random transforms\n\n        \"\"\"\n        first_random = self.transform.get_index_of_first(\n     
       lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)\n        )\n        if first_random is not None:\n            item_transformed = self.transform(item_transformed, start=first_random)\n        return item_transformed\n\n    def _cachecheck(self, item_transformed):\n        \"\"\"\n        A function to cache the expensive input data transform operations\n        so that huge data sets (larger than computer memory) can be processed\n        on the fly as needed, and intermediate results written to disk for\n        future use.\n\n        Args:\n            item_transformed: The current data element to be mutated into transformed representation\n\n        Returns:\n            The transformed data_element, either from cache, or explicitly computing it.\n\n        Warning:\n            The current implementation does not encode transform information as part of the\n            hashing mechanism used for generating cache names when `hash_transform` is None.\n            If the transforms applied are changed in any way, the objects in the cache dir will be invalid.\n\n        \"\"\"\n        hashfile = None\n        if self.cache_dir is not None:\n            data_item_md5 = self.hash_func(item_transformed).decode(\"utf-8\")\n            data_item_md5 += self.transform_hash\n            hashfile = self.cache_dir / f\"{data_item_md5}.pt\"\n\n        if hashfile is not None and hashfile.is_file():  # cache hit\n            try:\n                return torch.load(hashfile, weights_only=self.weights_only)\n            except PermissionError as e:\n                if sys.platform != \"win32\":\n                    raise e\n            except (UnpicklingError, RuntimeError) as e:  # corrupt or unloadable cached files are recomputed\n                if \"Invalid magic number; corrupt file\" in str(e) or isinstance(e, UnpicklingError):\n                    warnings.warn(f\"Corrupt cache file detected: {hashfile}. 
Deleting and recomputing.\")\n                    hashfile.unlink()\n                else:\n                    raise e\n\n        _item_transformed = self._pre_transform(deepcopy(item_transformed))  # keep the original hashed\n        if hashfile is None:\n            return _item_transformed\n        try:\n            # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation\n            #       to make the cache more robust to manual killing of parent process\n            #       which may leave partially written cache files in an incomplete state\n            with tempfile.TemporaryDirectory() as tmpdirname:\n                temp_hash_file = Path(tmpdirname) / hashfile.name\n                torch.save(\n                    obj=convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta),\n                    f=temp_hash_file,\n                    pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),\n                    pickle_protocol=self.pickle_protocol,\n                )\n                if temp_hash_file.is_file() and not hashfile.is_file():\n                    # On Unix, if target exists and is a file, it will be replaced silently if the user has permission.\n                    # for more details: https://docs.python.org/3/library/shutil.html#shutil.move.\n                    try:\n                        shutil.move(str(temp_hash_file), hashfile)\n                    except FileExistsError:\n                        pass\n        except PermissionError:  # project-monai/monai issue #3613\n            pass\n        return _item_transformed\n\n    def _transform(self, index: int):\n        pre_random_item = self._cachecheck(self.data[index])\n        return self._post_transform(pre_random_item)\n\n\nclass CacheNTransDataset(PersistentDataset):\n    \"\"\"\n    Extension of `PersistentDataset`, it can also cache the result of first N transforms, no matter it's random or not.\n\n   
 \"\"\"\n\n    def __init__(\n        self,\n        data: Sequence,\n        transform: Sequence[Callable] | Callable,\n        cache_n_trans: int,\n        cache_dir: Path | str | None,\n        hash_func: Callable[..., bytes] = pickle_hashing,\n        pickle_module: str = \"pickle\",\n        pickle_protocol: int = DEFAULT_PROTOCOL,\n        hash_transform: Callable[..., bytes] | None = None,\n        reset_ops_id: bool = True,\n    ) -> None:\n        \"\"\"\n        Args:\n            data: input data file paths to load and transform to generate dataset for model.\n                `PersistentDataset` expects input data to be a list of serializable\n                and hashes them as cache keys using `hash_func`.\n            transform: transforms to execute operations on input data.\n            cache_n_trans: cache the result of first N transforms.\n            cache_dir: If specified, this is the location for persistent storage\n                of pre-computed transformed data tensors. The cache_dir is computed once, and\n                persists on disk until explicitly removed.  Different runs, programs, experiments\n                may share a common cache dir provided that the transforms pre-processing is consistent.\n                If `cache_dir` doesn't exist, will automatically create it.\n                If `cache_dir` is `None`, there is effectively no caching.\n            hash_func: a callable to compute hash from data items to be cached.\n                defaults to `monai.data.utils.pickle_hashing`.\n            pickle_module: string representing the module used for pickling metadata and objects,\n                default to `\"pickle\"`. 
due to the pickle limitation in multi-processing of Dataloader,\n                we can't use `pickle` as arg directly, so here we use a string name instead.\n                if want to use other pickle module at runtime, just register like:\n                >>> from monai.data import utils\n                >>> utils.SUPPORTED_PICKLE_MOD[\"test\"] = other_pickle\n                this arg is used by `torch.save`, for more details, please check:\n                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,\n                and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.\n            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.\n                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:\n                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.\n            hash_transform: a callable to compute hash from the transform information when caching.\n                This may reduce errors due to transforms changing during experiments. 
Default to None (no hash).\n                Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.\n            reset_ops_id: whether to set `TraceKeys.ID` to ``Tracekys.NONE``, defaults to ``True``.\n                When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.\n                This is useful for skipping the transform instance checks when inverting applied operations\n                using the cached content and with re-created transform instances.\n\n        \"\"\"\n        super().__init__(\n            data=data,\n            transform=transform,\n            cache_dir=cache_dir,\n            hash_func=hash_func,\n            pickle_module=pickle_module,\n            pickle_protocol=pickle_protocol,\n            hash_transform=hash_transform,\n            reset_ops_id=reset_ops_id,\n        )\n        self.cache_n_trans = cache_n_trans\n\n    def _pre_transform(self, item_transformed):\n        \"\"\"\n        Process the data from original state up to the N element.\n\n        Args:\n            item_transformed: The data to be transformed\n\n        Returns:\n            the transformed element up to the N transform object\n        \"\"\"\n        item_transformed = self.transform(item_transformed, end=self.cache_n_trans, threading=True)\n\n        reset_ops_id(item_transformed)\n        return item_transformed\n\n    def _post_transform(self, item_transformed):\n        \"\"\"\n        Process the data from before the N + 1 transform to the final state ready for evaluation.\n\n        Args:\n            item_transformed: The data to be transformed (already processed up to the first N transform)\n\n        Returns:\n            the final transformed result\n        \"\"\"\n        return self.transform(item_transformed, start=self.cache_n_trans)\n\n\nclass LMDBDataset(PersistentDataset):\n    \"\"\"\n    Extension of `PersistentDataset` using LMDB as the backend.\n\n    
See Also:\n        :py:class:`monai.data.PersistentDataset`\n\n    Examples:\n\n        >>> items = [{\"data\": i} for i in range(5)]\n        # [{'data': 0}, {'data': 1}, {'data': 2}, {'data': 3}, {'data': 4}]\n        >>> lmdb_ds = monai.data.LMDBDataset(items, transform=monai.transforms.SimulateDelayd(\"data\", delay_time=1))\n        >>> print(list(lmdb_ds))  # using the cached results\n\n    \"\"\"\n\n    def __init__(\n        self,\n        data: Sequence,\n        transform: Sequence[Callable] | Callable,\n        cache_dir: Path | str = \"cache\",\n        hash_func: Callable[..., bytes] = pickle_hashing,\n        db_name: str = \"monai_cache\",\n        progress: bool = True,\n        pickle_protocol=DEFAULT_PROTOCOL,\n        hash_transform: Callable[..., bytes] | None = None,\n        reset_ops_id: bool = True,\n        lmdb_kwargs: dict | None = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            data: input data file paths to load and transform to generate dataset for model.\n                `LMDBDataset` expects input data to be a list of serializable\n                and hashes them as cache keys using `hash_func`.\n            transform: transforms to execute operations on input data.\n            cache_dir: if specified, this is the location for persistent storage\n                of pre-computed transformed data tensors. The cache_dir is computed once, and\n                persists on disk until explicitly removed.  Different runs, programs, experiments\n                may share a common cache dir provided that the transforms pre-processing is consistent.\n                If the cache_dir doesn't exist, will automatically create it. Defaults to \"./cache\".\n            hash_func: a callable to compute hash from data items to be cached.\n                defaults to `monai.data.utils.pickle_hashing`.\n            db_name: lmdb database file name. 
Defaults to \"monai_cache\".\n            progress: whether to display a progress bar.\n            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.\n                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:\n                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.\n            hash_transform: a callable to compute hash from the transform information when caching.\n                This may reduce errors due to transforms changing during experiments. Default to None (no hash).\n                Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.\n            reset_ops_id: whether to set `TraceKeys.ID` to ``Tracekeys.NONE``, defaults to ``True``.\n                When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.\n                This is useful for skipping the transform instance checks when inverting applied operations\n                using the cached content and with re-created transform instances.\n            lmdb_kwargs: additional keyword arguments to the lmdb environment.\n                for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class\n        \"\"\"\n        super().__init__(\n            data=data,\n            transform=transform,\n            cache_dir=cache_dir,\n            hash_func=hash_func,\n            pickle_protocol=pickle_protocol,\n            hash_transform=hash_transform,\n            reset_ops_id=reset_ops_id,\n        )\n        self.progress = progress\n        if not self.cache_dir:\n            raise ValueError(\"cache_dir must be specified.\")\n        self.db_file = self.cache_dir / f\"{db_name}.lmdb\"\n        self.lmdb_kwargs = lmdb_kwargs or {}\n        if not self.lmdb_kwargs.get(\"map_size\", 0):\n            self.lmdb_kwargs[\"map_size\"] = 1024**4  # default map_size\n        # lmdb is single-writer 
multi-reader by default\n        # the cache is created without multi-threading\n        self._read_env: Any | None = None\n        # this runs on the primary thread/process\n        self._fill_cache_start_reader(show_progress=self.progress)\n        print(f\"Accessing lmdb file: {self.db_file.absolute()}.\")\n\n    def set_data(self, data: Sequence):\n        \"\"\"\n        Set the input data and delete all the out-dated cache content.\n\n        \"\"\"\n        super().set_data(data=data)\n        self._read_env = self._fill_cache_start_reader(show_progress=self.progress)\n\n    def _safe_serialize(self, val):\n        out = BytesIO()\n        torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol)\n        out.seek(0)\n        return out.read()\n\n    def _safe_deserialize(self, val):\n        return torch.load(BytesIO(val), map_location=\"cpu\", weights_only=True)\n\n    def _fill_cache_start_reader(self, show_progress=True):\n        \"\"\"\n        Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.\n        This method can be used with multiple processes, but it may have a negative impact on the performance.\n\n        Args:\n            show_progress: whether to show the progress bar if possible.\n        \"\"\"\n        # create cache\n        self.lmdb_kwargs[\"readonly\"] = False\n        env = lmdb.open(path=f\"{self.db_file}\", subdir=False, **self.lmdb_kwargs)\n        if show_progress and not has_tqdm:\n            warnings.warn(\"LMDBDataset: tqdm is not installed. 
not displaying the caching progress.\")\n        with env.begin(write=False) as search_txn:\n            for item in tqdm(self.data) if has_tqdm and show_progress else self.data:\n                key = self.hash_func(item)\n                done, retry, val = False, 5, None\n                while not done and retry > 0:\n                    try:\n                        with search_txn.cursor() as cursor:\n                            done = cursor.set_key(key)\n                        if done:\n                            continue\n                        if val is None:\n                            val = self._pre_transform(deepcopy(item))  # keep the original hashed\n                            # val = pickle.dumps(val, protocol=self.pickle_protocol)\n                            val = self._safe_serialize(val)\n                        with env.begin(write=True) as txn:\n                            txn.put(key, val)\n                        done = True\n                    except lmdb.MapFullError:\n                        done, retry = False, retry - 1\n                        size = env.info()[\"map_size\"]\n                        new_size = size * 2\n                        warnings.warn(\n                            f\"Resizing the cache database from {int(size) >> 20}MB\" f\" to {int(new_size) >> 20}MB.\"\n                        )\n                        env.set_mapsize(new_size)\n                    except lmdb.MapResizedError:\n                        # the mapsize is increased by another process\n                        # set_mapsize with a size of 0 to adopt the new size\n                        env.set_mapsize(0)\n                if not done:  # still has the map full error\n                    size = env.info()[\"map_size\"]\n                    env.close()\n                    raise ValueError(f\"LMDB map size reached, increase size above current size of {size}.\")\n        size = env.info()[\"map_size\"]\n        env.close()\n        # read-only 
database env\n        self.lmdb_kwargs[\"readonly\"] = True\n        self.lmdb_kwargs[\"map_size\"] = size\n        if self.lmdb_kwargs.get(\"lock\", None) is None:\n            self.lmdb_kwargs[\"lock\"] = False\n        if self.lmdb_kwargs.get(\"readahead\", None) is None:\n            self.lmdb_kwargs[\"readahead\"] = False\n        return lmdb.open(path=f\"{self.db_file}\", subdir=False, **self.lmdb_kwargs)\n\n    def _cachecheck(self, item_transformed):\n        \"\"\"\n        if the item is not found in the lmdb file, resolves to the persistent cache default behaviour.\n\n        \"\"\"\n        if self._read_env is None:\n            # this runs on multiple processes, each one should have its own env.\n            self._read_env = self._fill_cache_start_reader(show_progress=False)\n        with self._read_env.begin(write=False) as txn:\n            data = txn.get(self.hash_func(item_transformed))\n        if data is None:\n            warnings.warn(\"LMDBDataset: cache key not found, running fallback caching.\")\n            return super()._cachecheck(item_transformed)\n        try:\n            # return pickle.loads(data)\n            return self._safe_deserialize(data)\n        except Exception as err:\n            raise RuntimeError(\"Invalid cache value, corrupted lmdb file?\") from err\n\n    def info(self):\n        \"\"\"\n        Returns: dataset info dictionary.\n\n        \"\"\"\n        if self._read_env is None:\n            self._read_env = self._fill_cache_start_reader()\n        out = dict(self._read_env.info())\n        out[\"size\"] = len(self.data)\n        out[\"filename\"] = f\"{self.db_file.absolute()}\"\n        return out\n\n\nclass CacheDataset(Dataset):\n    \"\"\"\n    Dataset with cache mechanism that can load data and cache deterministic transforms' result during training.\n\n    By caching the results of non-random preprocessing transforms, it accelerates the training data pipeline.\n    If the requested data is not in the 
cache, all transforms will run normally\n    (see also :py:class:`monai.data.dataset.Dataset`).\n\n    Users can set the cache rate or number of items to cache.\n    It is recommended to experiment with different `cache_num` or `cache_rate` to identify the best training speed.\n\n    The transforms which are supposed to be cached must implement the `monai.transforms.Transform`\n    interface and should not be `Randomizable`. This dataset will cache the outcomes before the first\n    `Randomizable` `Transform` within a `Compose` instance.\n    So to improve the caching efficiency, please always put as many as possible non-random transforms\n    before the randomized ones when composing the chain of transforms.\n    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,\n    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset\n\n    For example, if the transform is a `Compose` of::\n\n        transforms = Compose([\n            LoadImaged(),\n            EnsureChannelFirstd(),\n            Spacingd(),\n            Orientationd(),\n            ScaleIntensityRanged(),\n            RandCropByPosNegLabeld(),\n            ToTensord()\n        ])\n\n    when `transforms` is used in a multi-epoch training pipeline, before the first training epoch,\n    this dataset will cache the results up to ``ScaleIntensityRanged``, as\n    all non-random transforms `LoadImaged`, `EnsureChannelFirstd`, `Spacingd`, `Orientationd`, `ScaleIntensityRanged`\n    can be cached. 
During training, the dataset will load the cached results and run\n    ``RandCropByPosNegLabeld`` and ``ToTensord``, as ``RandCropByPosNegLabeld`` is a randomized transform\n    and the outcome not cached.\n\n    During training call `set_data()` to update input data and recompute cache content, note that it requires\n    `persistent_workers=False` in the PyTorch DataLoader.\n\n    Note:\n        `CacheDataset` executes non-random transforms and prepares cache content in the main process before\n        the first epoch, then all the subprocesses of DataLoader will read the same cache content in the main process\n        during training. it may take a long time to prepare cache content according to the size of expected cache data.\n        So to debug or verify the program before real training, users can set `cache_rate=0.0` or `cache_num=0` to\n        temporarily skip caching.\n\n    Lazy Resampling:\n        If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to\n        its documentation to familiarize yourself with the interaction between `CacheDataset` and\n        lazy resampling.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        data: Sequence,\n        transform: Sequence[Callable] | Callable | None = None,\n        cache_num: int = sys.maxsize,\n        cache_rate: float = 1.0,\n        num_workers: int | None = 1,\n        progress: bool = True,\n        copy_cache: bool = True,\n        as_contiguous: bool = True,\n        hash_as_key: bool = False,\n        hash_func: Callable[..., bytes] = pickle_hashing,\n        runtime_cache: bool | str | list | ListProxy = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            data: input data to load and transform to generate dataset for model.\n            transform: transforms to execute operations on input data.\n            cache_num: number of items to be cached. 
Default is `sys.maxsize`.\n                will take the minimum of (cache_num, data_length x cache_rate, data_length).\n            cache_rate: percentage of cached data in total, default is 1.0 (cache all).\n                will take the minimum of (cache_num, data_length x cache_rate, data_length).\n            num_workers: the number of worker threads if computing cache in the initialization.\n                If num_workers is None then the number returned by os.cpu_count() is used.\n                If a value less than 1 is specified, 1 will be used instead.\n            progress: whether to display a progress bar.\n            copy_cache: whether to `deepcopy` the cache content before applying the random transforms,\n                default to `True`. if the random transforms don't modify the cached content\n                (for example, randomly crop from the cached image and deepcopy the crop region)\n                or if every cache item is only used once in a `multi-processing` environment,\n                may set `copy=False` for better performance.\n            as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.\n                it may help improve the performance of following logic.\n            hash_as_key: whether to compute hash value of input data as the key to save cache,\n                if key exists, avoid saving duplicated content. it can help save memory when\n                the dataset has duplicated items or augmented dataset.\n            hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached.\n                defaults to `monai.data.utils.pickle_hashing`.\n            runtime_cache: mode of cache at the runtime. 
Default to `False` to prepare\n                the cache content for the entire ``data`` during initialization, this potentially largely increase the\n                time required between the constructor called and first mini-batch generated.\n                Three options are provided to compute the cache on the fly after the dataset initialization:\n\n                1. ``\"threads\"`` or ``True``: use a regular ``list`` to store the cache items.\n                2. ``\"processes\"``: use a ListProxy to store the cache items, it can be shared among processes.\n                3. A list-like object: a users-provided container to be used to store the cache items.\n\n                For `thread-based` caching (typically for caching cuda tensors), option 1 is recommended.\n                For single process workflows with multiprocessing data loading, option 2 is recommended.\n                For multiprocessing workflows (typically for distributed training),\n                where this class is initialized in subprocesses, option 3 is recommended,\n                and the list-like object should be prepared in the main process and passed to all subprocesses.\n                Not following these recommendations may lead to runtime errors or duplicated cache across processes.\n\n        \"\"\"\n        super().__init__(data=data, transform=transform)\n        self.set_num = cache_num  # tracking the user-provided `cache_num` option\n        self.set_rate = cache_rate  # tracking the user-provided `cache_rate` option\n        self.progress = progress\n        self.copy_cache = copy_cache\n        self.as_contiguous = as_contiguous\n        self.hash_as_key = hash_as_key\n        self.hash_func = hash_func\n        self.num_workers = num_workers\n        if self.num_workers is not None:\n            self.num_workers = max(int(self.num_workers), 1)\n        self.runtime_cache = runtime_cache\n        self.cache_num = 0\n        self._cache: list | ListProxy = []\n       
 self._hash_keys: list = []\n        self.set_data(data)\n\n    def set_data(self, data: Sequence) -> None:\n        \"\"\"\n        Set the input data and run deterministic transforms to generate cache content.\n\n        Note: should call this func after an entire epoch and must set `persistent_workers=False`\n        in PyTorch DataLoader, because it needs to create new worker processes based on new\n        generated cache content.\n\n        \"\"\"\n        self.data = data\n\n        def _compute_cache_num(data_len: int):\n            self.cache_num = min(int(self.set_num), int(data_len * self.set_rate), data_len)\n\n        if self.hash_as_key:\n            # only compute cache for the unique items of dataset, and record the last index for duplicated items\n            mapping = {self.hash_func(v): i for i, v in enumerate(self.data)}\n            _compute_cache_num(len(mapping))\n            self._hash_keys = list(mapping)[: self.cache_num]\n            indices = list(mapping.values())[: self.cache_num]\n        else:\n            _compute_cache_num(len(self.data))\n            indices = list(range(self.cache_num))\n\n        if self.runtime_cache in (False, None):  # prepare cache content immediately\n            self._cache = self._fill_cache(indices)\n            return\n        if isinstance(self.runtime_cache, str) and \"process\" in self.runtime_cache:\n            # this must be in the main process, not in dataloader's workers\n            self._cache = Manager().list([None] * self.cache_num)\n            return\n        if (self.runtime_cache is True) or (isinstance(self.runtime_cache, str) and \"thread\" in self.runtime_cache):\n            self._cache = [None] * self.cache_num\n            return\n        self._cache = self.runtime_cache  # type: ignore\n        return\n\n    def _fill_cache(self, indices=None) -> list:\n        \"\"\"\n        Compute and fill the cache content from data source.\n\n        Args:\n            indices: target 
indices in the `self.data` source to compute cache.\n                if None, use the first `cache_num` items.\n\n        \"\"\"\n        if self.cache_num <= 0:\n            return []\n        if indices is None:\n            indices = list(range(self.cache_num))\n        if self.progress and not has_tqdm:\n            warnings.warn(\"tqdm is not installed, will not show the caching progress bar.\")\n        with ThreadPool(self.num_workers) as p:\n            if self.progress and has_tqdm:\n                return list(tqdm(p.imap(self._load_cache_item, indices), total=len(indices), desc=\"Loading dataset\"))\n            return list(p.imap(self._load_cache_item, indices))\n\n    def _load_cache_item(self, idx: int):\n        \"\"\"\n        Args:\n            idx: the index of the input data sequence.\n        \"\"\"\n        item = self.data[idx]\n\n        first_random = self.transform.get_index_of_first(\n            lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)\n        )\n        item = self.transform(item, end=first_random, threading=True)\n\n        if self.as_contiguous:\n            item = convert_to_contiguous(item, memory_format=torch.contiguous_format)\n        return item\n\n    def _transform(self, index: int):\n        cache_index = None\n        if self.hash_as_key:\n            key = self.hash_func(self.data[index])\n            if key in self._hash_keys:\n                # if existing in cache, try to get the index in cache\n                cache_index = self._hash_keys.index(key)\n        elif index % len(self) < self.cache_num:  # support negative index\n            cache_index = index\n\n        if cache_index is None:\n            # no cache for this index, execute all the transforms directly\n            return super()._transform(index)\n\n        if self._cache is None:\n            raise RuntimeError(\"cache buffer is not initialized, please call `set_data()` first.\")\n        data = 
self._cache[cache_index]\n        # runtime cache computation\n        if data is None:\n            data = self._cache[cache_index] = self._load_cache_item(cache_index)\n\n        # load data from cache and execute from the first random transform\n        if not isinstance(self.transform, Compose):\n            raise ValueError(\"transform must be an instance of monai.transforms.Compose.\")\n\n        first_random = self.transform.get_index_of_first(\n            lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)\n        )\n        if first_random is not None:\n            data = deepcopy(data) if self.copy_cache is True else data\n            data = self.transform(data, start=first_random)\n\n        return data\n\n\nclass SmartCacheDataset(Randomizable, CacheDataset):\n    \"\"\"\n    Re-implementation of the SmartCache mechanism in NVIDIA Clara-train SDK.\n    At any time, the cache pool only keeps a subset of the whole dataset. In each epoch, only the items\n    in the cache are used for training. This ensures that data needed for training is readily available,\n    keeping GPU resources busy. Note that cached items may still have to go through a non-deterministic\n    transform sequence before being fed to GPU. At the same time, another thread is preparing replacement\n    items by applying the transform sequence to items not in cache. 
Once one epoch is completed, Smart\n    Cache replaces the same number of items with replacement items.\n    Smart Cache uses a simple `running window` algorithm to determine the cache content and replacement items.\n    Let N be the configured number of objects in cache; and R be the number of replacement objects (R = ceil(N * r),\n    where r is the configured replace rate).\n    For more details, please refer to:\n    https://docs.nvidia.com/clara/clara-train-archive/3.1/nvmidl/additional_features/smart_cache.html\n    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,\n    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset\n\n    For example, if we have 5 images: `[image1, image2, image3, image4, image5]`, and `cache_num=4`, `replace_rate=0.25`.\n    so the actual training images cached and replaced for every epoch are as below::\n\n        epoch 1: [image1, image2, image3, image4]\n        epoch 2: [image2, image3, image4, image5]\n        epoch 3: [image3, image4, image5, image1]\n        epoch 3: [image4, image5, image1, image2]\n        epoch N: [image[N % 5] ...]\n\n    The usage of `SmartCacheDataset` contains 4 steps:\n\n        1. Initialize `SmartCacheDataset` object and cache for the first epoch.\n        2. Call `start()` to run replacement thread in background.\n        3. Call `update_cache()` before every epoch to replace training items.\n        4. Call `shutdown()` when training ends.\n\n    During training call `set_data()` to update input data and recompute cache content, note to call\n    `shutdown()` to stop first, then update data and call `start()` to restart.\n\n    Note:\n        This replacement will not work for below cases:\n        1. Set the `multiprocessing_context` of DataLoader to `spawn`.\n        2. Launch distributed data parallel with `torch.multiprocessing.spawn`.\n        3. 
Run on windows(the default multiprocessing method is `spawn`) with `num_workers` greater than 0.\n        4. Set the `persistent_workers` of DataLoader to `True` with `num_workers` greater than 0.\n\n        If using MONAI workflows, please add `SmartCacheHandler` to the handler list of trainer,\n        otherwise, please make sure to call `start()`, `update_cache()`, `shutdown()` during training.\n\n    Args:\n        data: input data to load and transform to generate dataset for model.\n        transform: transforms to execute operations on input data.\n        replace_rate: percentage of the cached items to be replaced in every epoch (default to 0.1).\n        cache_num: number of items to be cached. Default is `sys.maxsize`.\n            will take the minimum of (cache_num, data_length x cache_rate, data_length).\n        cache_rate: percentage of cached data in total, default is 1.0 (cache all).\n            will take the minimum of (cache_num, data_length x cache_rate, data_length).\n        num_init_workers: the number of worker threads to initialize the cache for first epoch.\n            If num_init_workers is None then the number returned by os.cpu_count() is used.\n            If a value less than 1 is specified, 1 will be used instead.\n        num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch.\n            If num_replace_workers is None then the number returned by os.cpu_count() is used.\n            If a value less than 1 is specified, 1 will be used instead.\n        progress: whether to display a progress bar when caching for the first epoch.\n        shuffle: whether to shuffle the whole data list before preparing the cache content for first epoch.\n            it will not modify the original input data sequence in-place.\n        seed: random seed if shuffle is `True`, default to `0`.\n        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,\n            
default to `True`. if the random transforms don't modify the cache content\n            or every cache item is only used once in a `multi-processing` environment,\n            may set `copy=False` for better performance.\n        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.\n            it may help improve the performance of following logic.\n        runtime_cache: Default to `False`, other options are not implemented yet.\n    \"\"\"\n\n    def __init__(\n        self,\n        data: Sequence,\n        transform: Sequence[Callable] | Callable | None = None,\n        replace_rate: float = 0.1,\n        cache_num: int = sys.maxsize,\n        cache_rate: float = 1.0,\n        num_init_workers: int | None = 1,\n        num_replace_workers: int | None = 1,\n        progress: bool = True,\n        shuffle: bool = True,\n        seed: int = 0,\n        copy_cache: bool = True,\n        as_contiguous: bool = True,\n        runtime_cache=False,\n    ) -> None:\n        if shuffle:\n            self.set_random_state(seed=seed)\n        self.shuffle = shuffle\n\n        self._start_pos: int = 0\n        self._update_lock: threading.Lock = threading.Lock()\n        self._round: int = 1\n        self._replace_done: bool = False\n        self._replace_mgr: threading.Thread | None = None\n        if runtime_cache is not False:\n            raise NotImplementedError(\"Options other than `runtime_cache=False` is not implemented yet.\")\n\n        super().__init__(\n            data=data,\n            transform=transform,\n            cache_num=cache_num,\n            cache_rate=cache_rate,\n            num_workers=num_init_workers,\n            progress=progress,\n            copy_cache=copy_cache,\n            as_contiguous=as_contiguous,\n            runtime_cache=False,\n        )\n        if self._cache is None:\n            self._cache = self._fill_cache()\n        if self.cache_num >= len(data):\n            warnings.warn(\n    
            \"cache_num is greater or equal than dataset length, fall back to regular monai.data.CacheDataset.\"\n            )\n        if replace_rate <= 0:\n            raise ValueError(\"replace_rate must be greater than 0, otherwise, please use monai.data.CacheDataset.\")\n\n        self.num_replace_workers: int | None = num_replace_workers\n        if self.num_replace_workers is not None:\n            self.num_replace_workers = max(int(self.num_replace_workers), 1)\n\n        self._total_num: int = len(data)\n        self._replace_num: int = min(math.ceil(self.cache_num * replace_rate), len(data) - self.cache_num)\n        self._replacements: list[Any] = [None for _ in range(self._replace_num)]\n        self._replace_data_idx: list[int] = list(range(self._replace_num))\n        self._compute_data_idx()\n\n    def set_data(self, data: Sequence):\n        \"\"\"\n        Set the input data and run deterministic transforms to generate cache content.\n\n        Note: should call `shutdown()` before calling this func.\n\n        \"\"\"\n        if self.is_started():\n            warnings.warn(\"SmartCacheDataset is not shutdown yet, shutdown it directly.\")\n            self.shutdown()\n\n        if self.shuffle:\n            data = copy(data)\n            self.randomize(data)\n        super().set_data(data)\n\n    def randomize(self, data: Sequence) -> None:\n        try:\n            self.R.shuffle(data)\n        except TypeError as e:\n            warnings.warn(f\"input data can't be shuffled in SmartCacheDataset with numpy.random.shuffle(): {e}.\")\n\n    def _compute_data_idx(self) -> None:\n        \"\"\"\n        Update the replacement data position in the total data.\n\n        \"\"\"\n        for i in range(self._replace_num):\n            pos: int = self._start_pos + self.cache_num + i\n            if pos >= self._total_num:\n                pos -= self._total_num\n            self._replace_data_idx[i] = pos\n\n    def is_started(self):\n        \"\"\"\n 
       Check whether the replacement thread is already started.\n\n        \"\"\"\n        return False if self._replace_mgr is None else self._replace_mgr.is_alive()\n\n    def start(self):\n        \"\"\"\n        Start the background thread to replace training items for every epoch.\n\n        \"\"\"\n        if not self.is_started():\n            self._restart()\n\n    def _restart(self):\n        \"\"\"\n        Restart background thread if killed for some reason.\n\n        \"\"\"\n        self._round = 1\n        self._replace_mgr = threading.Thread(target=self.manage_replacement, daemon=True)\n        self._replace_mgr.start()\n\n    def _try_update_cache(self):\n        \"\"\"\n        Update the cache items with new replacement for current epoch.\n\n        \"\"\"\n        with self._update_lock:\n            if not self._replace_done:\n                return False\n\n            del self._cache[: self._replace_num]\n            self._cache.extend(self._replacements)\n\n            self._start_pos += self._replace_num\n            if self._start_pos >= self._total_num:\n                self._start_pos -= self._total_num\n\n            self._compute_data_idx()\n\n            # ready for next round\n            self._round += 1\n            self._replace_done = False\n            return True\n\n    def update_cache(self):\n        \"\"\"\n        Update cache items for current epoch, need to call this function before every epoch.\n        If the cache has been shutdown before, need to restart the `_replace_mgr` thread.\n\n        \"\"\"\n        self.start()\n\n        # make sure update is done\n        while not self._try_update_cache():\n            time.sleep(0.01)\n\n    def _try_shutdown(self):\n        \"\"\"\n        Wait for thread lock to shut down the background thread.\n\n        \"\"\"\n        with self._update_lock:\n            if self._replace_done:\n                self._round = 0\n                self._start_pos = 0\n                
self._compute_data_idx()\n                self._replace_done = False\n                return True\n            return False\n\n    def shutdown(self):\n        \"\"\"\n        Shut down the background thread for replacement.\n\n        \"\"\"\n        if not self.is_started():\n            return\n\n        # wait until replace mgr is done the current round\n        while not self._try_shutdown():\n            time.sleep(0.01)\n        if self._replace_mgr is not None:\n            self._replace_mgr.join(300)\n\n    def _replace_cache_thread(self, index: int):\n        \"\"\"\n        Execute deterministic transforms on the new data for replacement.\n\n        \"\"\"\n        pos: int = self._replace_data_idx[index]\n        self._replacements[index] = self._load_cache_item(pos)\n\n    def _compute_replacements(self):\n        \"\"\"\n        Compute expected items for the replacement of next epoch, execute deterministic transforms.\n        It can support multi-threads to accelerate the computation progress.\n\n        \"\"\"\n        with ThreadPool(self.num_replace_workers) as p:\n            p.map(self._replace_cache_thread, list(range(self._replace_num)))\n\n        self._replace_done = True\n\n    def _try_manage_replacement(self, check_round):\n        \"\"\"\n        Wait thread lock and replace training items in the background thread.\n\n        \"\"\"\n        with self._update_lock:\n            if self._round <= 0:\n                # shutdown replacement\n                self._replace_done = True\n                return True, -1\n\n            if self._round != check_round:\n                self._compute_replacements()\n            return False, self._round\n\n    def manage_replacement(self) -> None:\n        \"\"\"\n        Background thread for replacement.\n\n        \"\"\"\n        check_round: int = -1\n        done = False\n        while not done:\n            done, check_round = self._try_manage_replacement(check_round)\n            
time.sleep(0.01)\n\n    def __len__(self):\n        \"\"\"\n        The dataset length is given by cache_num instead of len(data).\n\n        \"\"\"\n        return self.cache_num\n\n\nclass ZipDataset(Dataset):\n    \"\"\"\n    Zip several PyTorch datasets and output data(with the same index) together in a tuple.\n    If the output of single dataset is already a tuple, flatten it and extend to the result.\n    For example: if datasetA returns (img, imgmeta), datasetB returns (seg, segmeta),\n    finally return (img, imgmeta, seg, segmeta).\n    And if the datasets don't have same length, use the minimum length of them as the length\n    of ZipDataset.\n    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,\n    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset\n\n    Examples::\n\n        >>> zip_data = ZipDataset([[1, 2, 3], [4, 5]])\n        >>> print(len(zip_data))\n        2\n        >>> for item in zip_data:\n        >>>    print(item)\n        [1, 4]\n        [2, 5]\n\n    \"\"\"\n\n    def __init__(self, datasets: Sequence, transform: Callable | None = None) -> None:\n        \"\"\"\n        Args:\n            datasets: list of datasets to zip together.\n            transform: a callable data transform operates on the zipped item from `datasets`.\n        \"\"\"\n        super().__init__(list(datasets), transform=transform)\n\n    def __len__(self) -> int:\n        return min(len(dataset) for dataset in self.data)\n\n    def _transform(self, index: int):\n\n        def to_list(x):\n            return list(x) if isinstance(x, (tuple, list)) else [x]\n\n        data = []\n        for dataset in self.data:\n            data.extend(to_list(dataset[index]))\n\n        if self.transform is not None:\n            self.transform.map_items = False  # Compose object map_items to false so transform is applied to list\n            data = self.transform(data)\n        # 
use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists\n        return tuple(data)\n\n\nclass ArrayDataset(Randomizable, _TorchDataset):\n    \"\"\"\n    Dataset for segmentation and classification tasks based on array format input data and transforms.\n    It ensures the same random seeds in the randomized transforms defined for image, segmentation and label.\n    The `transform` can be :py:class:`monai.transforms.Compose` or any other callable object.\n    For example:\n    If train based on Nifti format images without metadata, all transforms can be composed::\n\n        img_transform = Compose(\n            [\n                LoadImage(image_only=True),\n                EnsureChannelFirst(),\n                RandAdjustContrast()\n            ]\n        )\n        ArrayDataset(img_file_list, img_transform=img_transform)\n\n    If training based on images and the metadata, the array transforms can not be composed\n    because several transforms receives multiple parameters or return multiple values. 
Then Users need\n    to define their own callable method to parse metadata from `LoadImage` or set `affine` matrix\n    to `Spacing` transform::\n\n        class TestCompose(Compose):\n            def __call__(self, input_):\n                img, metadata = self.transforms[0](input_)\n                img = self.transforms[1](img)\n                img, _, _ = self.transforms[2](img, metadata[\"affine\"])\n                return self.transforms[3](img), metadata\n        img_transform = TestCompose(\n            [\n                LoadImage(image_only=False),\n                EnsureChannelFirst(),\n                Spacing(pixdim=(1.5, 1.5, 3.0)),\n                RandAdjustContrast()\n            ]\n        )\n        ArrayDataset(img_file_list, img_transform=img_transform)\n\n    Examples::\n\n        >>> ds = ArrayDataset([1, 2, 3, 4], lambda x: x + 0.1)\n        >>> print(ds[0])\n        1.1\n\n        >>> ds = ArrayDataset(img=[1, 2, 3, 4], seg=[5, 6, 7, 8])\n        >>> print(ds[0])\n        [1, 5]\n\n    \"\"\"\n\n    def __init__(\n        self,\n        img: Sequence,\n        img_transform: Callable | None = None,\n        seg: Sequence | None = None,\n        seg_transform: Callable | None = None,\n        labels: Sequence | None = None,\n        label_transform: Callable | None = None,\n    ) -> None:\n        \"\"\"\n        Initializes the dataset with the filename lists. 
The transform `img_transform` is applied\n        to the images and `seg_transform` to the segmentations.\n\n        Args:\n            img: sequence of images.\n            img_transform: transform to apply to each element in `img`.\n            seg: sequence of segmentations.\n            seg_transform: transform to apply to each element in `seg`.\n            labels: sequence of labels.\n            label_transform: transform to apply to each element in `labels`.\n\n        \"\"\"\n        items = [(img, img_transform), (seg, seg_transform), (labels, label_transform)]\n        self.set_random_state(seed=get_seed())\n        datasets = [Dataset(x[0], x[1]) for x in items if x[0] is not None]\n        self.dataset = datasets[0] if len(datasets) == 1 else ZipDataset(datasets)\n\n        self._seed = 0  # transform synchronization seed\n\n    def __len__(self) -> int:\n        return len(self.dataset)\n\n    def randomize(self, data: Any | None = None) -> None:\n        self._seed = int(self.R.randint(MAX_SEED, dtype=\"uint32\"))\n\n    def __getitem__(self, index: int):\n        self.randomize()\n        if isinstance(self.dataset, ZipDataset):\n            # set transforms of each zip component\n            for dataset in self.dataset.data:\n                transform = getattr(dataset, \"transform\", None)\n                if isinstance(transform, Randomizable):\n                    transform.set_random_state(seed=self._seed)\n        transform = getattr(self.dataset, \"transform\", None)\n        if isinstance(transform, Randomizable):\n            transform.set_random_state(seed=self._seed)\n        return self.dataset[index]\n\n\nclass NPZDictItemDataset(Dataset):\n    \"\"\"\n    Represents a dataset from a loaded NPZ file. The members of the file to load are named in the keys of `keys` and\n    stored under the keyed name. All loaded arrays must have the same 0-dimension (batch) size. 
Items are always dicts\n    mapping names to an item extracted from the loaded arrays.\n    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,\n    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset\n\n    Args:\n        npzfile: Path to .npz file or stream containing .npz file data\n        keys: Maps keys to load from file to name to store in dataset\n        transform: Transform to apply to batch dict\n        other_keys: secondary data to load from file and store in dict `other_keys`, not returned by __getitem__\n    \"\"\"\n\n    def __init__(\n        self,\n        npzfile: str | IO,\n        keys: dict[str, str],\n        transform: Callable[..., dict[str, Any]] | None = None,\n        other_keys: Sequence[str] | None = (),\n    ):\n        self.npzfile: str | IO = npzfile if isinstance(npzfile, str) else \"STREAM\"\n        self.keys: dict[str, str] = dict(keys)\n        dat = np.load(npzfile)\n\n        self.arrays = {storedk: dat[datak] for datak, storedk in self.keys.items()}\n        self.length = self.arrays[cast(str, first(self.keys.values()))].shape[0]\n\n        self.other_keys = {} if other_keys is None else {k: dat[k] for k in other_keys}\n\n        for k, v in self.arrays.items():\n            if v.shape[0] != self.length:\n                raise ValueError(\n                    \"All loaded arrays must have the same first dimension \"\n                    f\"size {self.length}, array `{k}` has size {v.shape[0]}\"\n                )\n\n        super().__init__([], transform)\n\n    def __len__(self):\n        return self.length\n\n    def _transform(self, index: int):\n        data = {k: v[index] for k, v in self.arrays.items()}\n        result = self.transform(data) if self.transform is not None else data\n\n        if isinstance(result, dict) or (isinstance(result, list) and isinstance(result[0], dict)):\n            return result\n        raise 
AssertionError(\"With a dict supplied to Compose, should return a dict or a list of dicts.\")\n\n\nclass CSVDataset(Dataset):\n    \"\"\"\n    Dataset to load data from CSV files and generate a list of dictionaries,\n    every dictionary maps to a row of the CSV file, and the keys of dictionary\n    map to the column names of the CSV file.\n\n    It can load multiple CSV files and join the tables with additional `kwargs` arg.\n    Support to only load specific rows and columns.\n    And it can also group several loaded columns to generate a new column, for example,\n    set `col_groups={\"meta\": [\"meta_0\", \"meta_1\", \"meta_2\"]}`, output can be::\n\n        [\n            {\"image\": \"./image0.nii\", \"meta_0\": 11, \"meta_1\": 12, \"meta_2\": 13, \"meta\": [11, 12, 13]},\n            {\"image\": \"./image1.nii\", \"meta_0\": 21, \"meta_1\": 22, \"meta_2\": 23, \"meta\": [21, 22, 23]},\n        ]\n\n    Args:\n        src: if provided the filename of CSV file, it can be a str, URL, path object or file-like object to load.\n            also support to provide pandas `DataFrame` directly, will skip loading from filename.\n            if provided a list of filenames or pandas `DataFrame`, it will join the tables.\n        row_indices: indices of the expected rows to load. it should be a list,\n            every item can be a int number or a range `[start, end)` for the indices.\n            for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,\n            load all the rows in the file.\n        col_names: names of the expected columns to load. 
if None, load all the columns.\n        col_types: `type` and `default value` to convert the loaded columns, if None, use original data.\n            it should be a dictionary, every item maps to an expected column, the `key` is the column\n            name and the `value` is None or a dictionary to define the default value and data type.\n            the supported keys in dictionary are: [\"type\", \"default\"]. for example::\n\n                col_types = {\n                    \"subject_id\": {\"type\": str},\n                    \"label\": {\"type\": int, \"default\": 0},\n                    \"ehr_0\": {\"type\": float, \"default\": 0.0},\n                    \"ehr_1\": {\"type\": float, \"default\": 0.0},\n                    \"image\": {\"type\": str, \"default\": None},\n                }\n\n        col_groups: args to group the loaded columns to generate a new column,\n            it should be a dictionary, every item maps to a group, the `key` will\n            be the new column name, the `value` is the names of columns to combine. 
for example:\n            `col_groups={\"ehr\": [f\"ehr_{i}\" for i in range(10)], \"meta\": [\"meta_1\", \"meta_2\"]}`\n        transform: transform to apply on the loaded items of a dictionary data.\n        kwargs_read_csv: dictionary args to pass to pandas `read_csv` function.\n        kwargs: additional arguments for `pandas.merge()` API to join tables.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        src: str | Sequence[str] | None = None,  # also can be `DataFrame` or a sequence of `DataFrame`\n        row_indices: Sequence[int | str] | None = None,\n        col_names: Sequence[str] | None = None,\n        col_types: dict[str, dict[str, Any] | None] | None = None,\n        col_groups: dict[str, Sequence[str]] | None = None,\n        transform: Callable | None = None,\n        kwargs_read_csv: dict | None = None,\n        **kwargs,\n    ):\n        srcs = (src,) if not isinstance(src, (tuple, list)) else src\n        dfs: list = []\n        for i in srcs:\n            if isinstance(i, str):\n                dfs.append(pd.read_csv(i, **kwargs_read_csv) if kwargs_read_csv else pd.read_csv(i))\n            elif isinstance(i, pd.DataFrame):\n                dfs.append(i)\n            else:\n                raise ValueError(\"`src` must be file path or pandas `DataFrame`.\")\n\n        data = convert_tables_to_dicts(\n            dfs=dfs, row_indices=row_indices, col_names=col_names, col_types=col_types, col_groups=col_groups, **kwargs\n        )\n        super().__init__(data=data, transform=transform)\n\n\nclass GDSDataset(PersistentDataset):\n    \"\"\"\n    An extension of the PersistentDataset using direct memory access(DMA) data path between\n    GPU memory and storage, thus avoiding a bounce buffer through the CPU. 
This direct path can increase system\n    bandwidth while decreasing latency and utilization load on the CPU and GPU.\n\n    A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/main/modules/GDS_dataset.ipynb.\n\n    See also: https://github.com/rapidsai/kvikio\n    \"\"\"\n\n    def __init__(\n        self,\n        data: Sequence,\n        transform: Sequence[Callable] | Callable,\n        cache_dir: Path | str | None,\n        device: int,\n        hash_func: Callable[..., bytes] = pickle_hashing,\n        hash_transform: Callable[..., bytes] | None = None,\n        reset_ops_id: bool = True,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"\n        Args:\n            data: input data file paths to load and transform to generate dataset for model.\n                `GDSDataset` expects input data to be a list of serializable\n                and hashes them as cache keys using `hash_func`.\n            transform: transforms to execute operations on input data.\n            cache_dir: If specified, this is the location for gpu direct storage\n                of pre-computed transformed data tensors. The cache_dir is computed once, and\n                persists on disk until explicitly removed.  Different runs, programs, experiments\n                may share a common cache dir provided that the transforms pre-processing is consistent.\n                If `cache_dir` doesn't exist, will automatically create it.\n                If `cache_dir` is `None`, there is effectively no caching.\n            device: target device to put the output Tensor data. 
Note that only int can be used to\n                specify the gpu to be used.\n            hash_func: a callable to compute hash from data items to be cached.\n                defaults to `monai.data.utils.pickle_hashing`.\n            hash_transform: a callable to compute hash from the transform information when caching.\n                This may reduce errors due to transforms changing during experiments. Default to None (no hash).\n                Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.\n            reset_ops_id: whether to set `TraceKeys.ID` to ``Tracekys.NONE``, defaults to ``True``.\n                When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.\n                This is useful for skipping the transform instance checks when inverting applied operations\n                using the cached content and with re-created transform instances.\n\n        \"\"\"\n        super().__init__(\n            data=data,\n            transform=transform,\n            cache_dir=cache_dir,\n            hash_func=hash_func,\n            hash_transform=hash_transform,\n            reset_ops_id=reset_ops_id,\n            **kwargs,\n        )\n        self.device = device\n        self._meta_cache: dict[Any, dict[Any, Any]] = {}\n\n    def _cachecheck(self, item_transformed):\n        \"\"\"\n        In order to enable direct storage to the GPU when loading the hashfile, rewritten this function.\n        Note that in this function, it will always return `torch.Tensor` when load data from cache.\n\n        Args:\n            item_transformed: The current data element to be mutated into transformed representation\n\n        Returns:\n            The transformed data_element, either from cache, or explicitly computing it.\n\n        Warning:\n            The current implementation does not encode transform information as part of the\n            hashing mechanism used for generating 
cache names when `hash_transform` is None.\n            If the transforms applied are changed in any way, the objects in the cache dir will be invalid.\n\n        \"\"\"\n        hashfile = None\n        # compute a cache id\n        if self.cache_dir is not None:\n            data_item_md5 = self.hash_func(item_transformed).decode(\"utf-8\")\n            data_item_md5 += self.transform_hash\n            hashfile = self.cache_dir / f\"{data_item_md5}.pt\"\n\n        if hashfile is not None and hashfile.is_file():  # cache hit\n            with cp.cuda.Device(self.device):\n                if isinstance(item_transformed, dict):\n                    item: dict[Any, Any] = {}\n                    for k in item_transformed:\n                        meta_k = self._load_meta_cache(meta_hash_file_name=f\"{hashfile.name}-{k}-meta\")\n                        item[k] = kvikio_numpy.fromfile(f\"{hashfile}-{k}\", dtype=meta_k[\"dtype\"], like=cp.empty(()))\n                        item[k] = convert_to_tensor(item[k].reshape(meta_k[\"shape\"]), device=f\"cuda:{self.device}\")\n                        item[f\"{k}_meta_dict\"] = meta_k\n                    return item\n                elif isinstance(item_transformed, (np.ndarray, torch.Tensor)):\n                    _meta = self._load_meta_cache(meta_hash_file_name=f\"{hashfile.name}-meta\")\n                    _data = kvikio_numpy.fromfile(f\"{hashfile}\", dtype=_meta[\"dtype\"], like=cp.empty(()))\n                    _data = convert_to_tensor(_data.reshape(_meta[\"shape\"]), device=f\"cuda:{self.device}\")\n                    filtered_keys = list(filter(lambda key: key not in [\"dtype\", \"shape\"], _meta.keys()))\n                    if bool(filtered_keys):\n                        return (_data, _meta)\n                    return _data\n                else:\n                    item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))]  # type:ignore\n                    for i, _item in 
enumerate(item_transformed):\n                        for k in _item:\n                            meta_i_k = self._load_meta_cache(meta_hash_file_name=f\"{hashfile.name}-{k}-meta-{i}\")\n                            item_k = kvikio_numpy.fromfile(\n                                f\"{hashfile}-{k}-{i}\", dtype=meta_i_k[\"dtype\"], like=cp.empty(())\n                            )\n                            item_k = convert_to_tensor(item[i].reshape(meta_i_k[\"shape\"]), device=f\"cuda:{self.device}\")\n                            item[i].update({k: item_k, f\"{k}_meta_dict\": meta_i_k})\n                    return item\n\n        # create new cache\n        _item_transformed = self._pre_transform(deepcopy(item_transformed))  # keep the original hashed\n        if hashfile is None:\n            return _item_transformed\n        if isinstance(_item_transformed, dict):\n            for k in _item_transformed:\n                data_hashfile = f\"{hashfile}-{k}\"\n                meta_hash_file_name = f\"{hashfile.name}-{k}-meta\"\n                if isinstance(_item_transformed[k], (np.ndarray, torch.Tensor)):\n                    self._create_new_cache(_item_transformed[k], data_hashfile, meta_hash_file_name)\n                else:\n                    return _item_transformed\n        elif isinstance(_item_transformed, (np.ndarray, torch.Tensor)):\n            data_hashfile = f\"{hashfile}\"\n            meta_hash_file_name = f\"{hashfile.name}-meta\"\n            self._create_new_cache(_item_transformed, data_hashfile, meta_hash_file_name)\n        else:\n            for i, _item in enumerate(_item_transformed):\n                for k in _item:\n                    data_hashfile = f\"{hashfile}-{k}-{i}\"\n                    meta_hash_file_name = f\"{hashfile.name}-{k}-meta-{i}\"\n                    self._create_new_cache(_item, data_hashfile, meta_hash_file_name)\n        open(hashfile, \"a\").close()  # store cacheid\n        return _item_transformed\n\n    def 
_create_new_cache(self, data, data_hashfile, meta_hash_file_name):\n        self._meta_cache[meta_hash_file_name] = copy(data.meta) if isinstance(data, MetaTensor) else {}\n        _item_transformed_data = data.array if isinstance(data, MetaTensor) else data\n        if isinstance(_item_transformed_data, torch.Tensor):\n            _item_transformed_data = _item_transformed_data.numpy()\n        self._meta_cache[meta_hash_file_name][\"shape\"] = _item_transformed_data.shape\n        self._meta_cache[meta_hash_file_name][\"dtype\"] = str(_item_transformed_data.dtype)\n        kvikio_numpy.tofile(_item_transformed_data, data_hashfile)\n        try:\n            # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation\n            #       to make the cache more robust to manual killing of parent process\n            #       which may leave partially written cache files in an incomplete state\n            with tempfile.TemporaryDirectory() as tmpdirname:\n                meta_hash_file = self.cache_dir / meta_hash_file_name\n                temp_hash_file = Path(tmpdirname) / meta_hash_file_name\n                torch.save(\n                    obj=convert_to_tensor(self._meta_cache[meta_hash_file_name], convert_numeric=False),\n                    f=temp_hash_file,\n                    pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),\n                    pickle_protocol=self.pickle_protocol,\n                )\n                if temp_hash_file.is_file() and not meta_hash_file.is_file():\n                    # On Unix, if target exists and is a file, it will be replaced silently if the\n                    # user has permission.\n                    # for more details: https://docs.python.org/3/library/shutil.html#shutil.move.\n                    try:\n                        shutil.move(str(temp_hash_file), meta_hash_file)\n                    except FileExistsError:\n                        pass\n        except 
PermissionError:  # project-monai/monai issue #3613\n            pass\n\n    def _load_meta_cache(self, meta_hash_file_name):\n        if meta_hash_file_name in self._meta_cache:\n            return self._meta_cache[meta_hash_file_name]\n        else:\n            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=True)\n"
  },
  {
    "path": "monai/data/dataset_summary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom itertools import chain\n\nimport numpy as np\nimport torch\n\nfrom monai.config import KeysCollection\nfrom monai.data.dataloader import DataLoader\nfrom monai.data.dataset import Dataset\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import affine_to_spacing\nfrom monai.transforms import concatenate\nfrom monai.utils import PostFix, convert_data_type, convert_to_tensor\n\nDEFAULT_POST_FIX = PostFix.meta()\n\n\nclass DatasetSummary:\n    \"\"\"\n    This class provides a way to calculate a reasonable output voxel spacing according to\n    the input dataset. 
The achieved values can used to resample the input in 3d segmentation tasks\n    (like using as the `pixdim` parameter in `monai.transforms.Spacingd`).\n    In addition, it also supports to compute the mean, std, min and max intensities of the input,\n    and these statistics are helpful for image normalization\n    (as parameters of `monai.transforms.ScaleIntensityRanged` and `monai.transforms.NormalizeIntensityd`).\n\n    The algorithm for calculation refers to:\n    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset: Dataset,\n        image_key: str | None = \"image\",\n        label_key: str | None = \"label\",\n        meta_key: KeysCollection | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        num_workers: int = 0,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            dataset: dataset from which to load the data.\n            image_key: key name of images (default: ``image``).\n            label_key: key name of labels (default: ``label``).\n            meta_key: explicitly indicate the key of the corresponding metadata dictionary.\n                for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n                the metadata is a dictionary object which contains: filename, affine, original_shape, etc.\n                if None, will try to construct meta_keys by `{image_key}_{meta_key_postfix}`.\n                This is not required if `data[image_key]` is a MetaTensor.\n            meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the metadata from dict,\n                the metadata is a dictionary object (default: ``meta_dict``).\n            num_workers: how many subprocesses to use for data loading.\n                ``0`` means that the data will be loaded in the main process (default: ``0``).\n            kwargs: other parameters 
(except `batch_size` and `num_workers`) for DataLoader,\n                this class forces to use ``batch_size=1``.\n\n        \"\"\"\n\n        self.data_loader = DataLoader(dataset=dataset, batch_size=1, num_workers=num_workers, **kwargs)\n\n        self.image_key = image_key\n        self.label_key = label_key\n        self.meta_key = meta_key or f\"{image_key}_{meta_key_postfix}\"\n        self.all_meta_data: list = []\n\n    def collect_meta_data(self):\n        \"\"\"\n        This function is used to collect the metadata for all images of the dataset.\n        \"\"\"\n\n        for data in self.data_loader:\n            meta_dict = {}\n            if isinstance(data[self.image_key], MetaTensor):\n                meta_dict = data[self.image_key].meta\n            elif self.meta_key in data:\n                meta_dict = data[self.meta_key]\n            else:\n                warnings.warn(f\"To collect metadata for the dataset, `{self.meta_key}` or `data.meta` must exist.\")\n            self.all_meta_data.append(meta_dict)\n\n    def get_target_spacing(self, spacing_key: str = \"affine\", anisotropic_threshold: int = 3, percentile: float = 10.0):\n        \"\"\"\n        Calculate the target spacing according to all spacings.\n        If the target spacing is very anisotropic,\n        decrease the spacing value of the maximum axis according to percentile.\n        The spacing is computed from `affine_to_spacing(data[spacing_key][0], 3)` if `data[spacing_key]` is a matrix,\n        otherwise, the `data[spacing_key]` must be a vector of pixdim values.\n\n        Args:\n            spacing_key: key of the affine used to compute spacing in metadata (default: ``affine``).\n            anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``).\n            percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to\n                replace that axis.\n\n        \"\"\"\n        if 
len(self.all_meta_data) == 0:\n            self.collect_meta_data()\n        if spacing_key not in self.all_meta_data[0]:\n            raise ValueError(\"The provided spacing_key is not in self.all_meta_data.\")\n        spacings = []\n        for data in self.all_meta_data:\n            spacing_vals = convert_to_tensor(data[spacing_key][0], track_meta=False, wrap_sequence=True)\n            if spacing_vals.ndim == 1:  # vector\n                spacings.append(spacing_vals[:3][None])\n            elif spacing_vals.ndim == 2:  # matrix\n                spacings.append(affine_to_spacing(spacing_vals, 3)[None])\n            else:\n                raise ValueError(\"data[spacing_key] must be a vector or a matrix.\")\n        all_spacings = concatenate(to_cat=spacings, axis=0)\n        all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True)\n\n        target_spacing = np.median(all_spacings, axis=0)\n        if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:\n            largest_axis = np.argmax(target_spacing)\n            target_spacing[largest_axis] = np.percentile(all_spacings[:, largest_axis], percentile)\n\n        output = list(target_spacing)\n\n        return tuple(output)\n\n    def calculate_statistics(self, foreground_threshold: int = 0):\n        \"\"\"\n        This function is used to calculate the maximum, minimum, mean and standard deviation of intensities of\n        the input dataset.\n\n        Args:\n            foreground_threshold: the threshold to distinguish if a voxel belongs to foreground, this parameter\n                is used to select the foreground of images for calculation. 
Normally, `label > 0` means the corresponding\n                voxel belongs to foreground, thus if you need to calculate the statistics for whole images, you can set\n                the threshold to ``-1`` (default: ``0``).\n\n        \"\"\"\n        voxel_sum = torch.as_tensor(0.0)\n        voxel_square_sum = torch.as_tensor(0.0)\n        voxel_max, voxel_min = [], []\n        voxel_ct = 0\n\n        for data in self.data_loader:\n            if self.image_key and self.label_key:\n                image, label = data[self.image_key], data[self.label_key]\n            else:\n                image, label = data\n            image, *_ = convert_data_type(data=image, output_type=torch.Tensor)\n            label, *_ = convert_data_type(data=label, output_type=torch.Tensor)\n\n            image_foreground = image[torch.where(label > foreground_threshold)]\n\n            voxel_max.append(image_foreground.max().item())\n            voxel_min.append(image_foreground.min().item())\n            voxel_ct += len(image_foreground)\n            voxel_sum += image_foreground.sum()\n            voxel_square_sum += torch.square(image_foreground).sum()\n\n        self.data_max, self.data_min = max(voxel_max), min(voxel_min)\n        self.data_mean = (voxel_sum / voxel_ct).item()\n        self.data_std = (torch.sqrt(voxel_square_sum / voxel_ct - self.data_mean**2)).item()\n\n    def calculate_percentiles(\n        self,\n        foreground_threshold: int = 0,\n        sampling_flag: bool = True,\n        interval: int = 10,\n        min_percentile: float = 0.5,\n        max_percentile: float = 99.5,\n    ):\n        \"\"\"\n        This function is used to calculate the percentiles of intensities (and median) of the input dataset. To get\n        the required values, all voxels need to be accumulated. 
To reduce the memory used, this function can be set\n        to accumulate only a part of the voxels.\n\n        Args:\n            foreground_threshold: the threshold to distinguish if a voxel belongs to foreground, this parameter\n                is used to select the foreground of images for calculation. Normally, `label > 0` means the corresponding\n                voxel belongs to foreground, thus if you need to calculate the statistics for whole images, you can set\n                the threshold to ``-1`` (default: ``0``).\n            sampling_flag: whether to sample only a part of the voxels (default: ``True``).\n            interval: the sampling interval for accumulating voxels (default: ``10``).\n            min_percentile: minimal percentile (default: ``0.5``).\n            max_percentile: maximal percentile (default: ``99.5``).\n\n        \"\"\"\n        all_intensities = []\n        for data in self.data_loader:\n            if self.image_key and self.label_key:\n                image, label = data[self.image_key], data[self.label_key]\n            else:\n                image, label = data\n            image, *_ = convert_data_type(data=image, output_type=torch.Tensor)\n            label, *_ = convert_data_type(data=label, output_type=torch.Tensor)\n\n            intensities = image[torch.where(label > foreground_threshold)].tolist()\n            if sampling_flag:\n                intensities = intensities[::interval]\n            all_intensities.append(intensities)\n\n        all_intensities = list(chain(*all_intensities))\n        self.data_min_percentile, self.data_max_percentile = np.percentile(\n            all_intensities, [min_percentile, max_percentile]\n        )\n        self.data_median = np.median(all_intensities)\n"
  },
  {
    "path": "monai/data/decathlon_datalist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport warnings\nfrom collections.abc import Sequence\nfrom pathlib import Path\nfrom typing import overload\n\nfrom monai.config import KeysCollection, PathLike\nfrom monai.data.utils import partition_dataset, select_cross_validation_folds\nfrom monai.utils import ensure_tuple\n\n\n@overload\ndef _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ...\n\n\n@overload\ndef _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: ...\n\n\ndef _compute_path(base_dir, element, check_path=False):\n    \"\"\"\n    Args:\n        base_dir: the base directory of the dataset.\n        element: file path(s) to append to directory.\n        check_path: if `True`, only compute when the result is an existing path.\n\n    Raises:\n        TypeError: When ``element`` contains a non ``str``.\n        TypeError: When ``element`` type is not in ``Union[list, str]``.\n\n    \"\"\"\n\n    def _join_path(base_dir: PathLike, item: PathLike):\n        result = os.path.normpath(os.path.join(base_dir, item))\n        if check_path and not os.path.exists(result):\n            # if not an existing path, don't join with base dir\n            return f\"{item}\"\n        return f\"{result}\"\n\n    if isinstance(element, (str, os.PathLike)):\n        return 
_join_path(base_dir, element)\n    if isinstance(element, list):\n        for e in element:\n            if not isinstance(e, (str, os.PathLike)):\n                return element\n        return [_join_path(base_dir, e) for e in element]\n    return element\n\n\ndef _append_paths(base_dir: PathLike, is_segmentation: bool, items: list[dict]) -> list[dict]:\n    \"\"\"\n    Args:\n        base_dir: the base directory of the dataset.\n        is_segmentation: whether the datalist is for segmentation task.\n        items: list of data items, each of which is a dict keyed by element names.\n\n    Raises:\n        TypeError: When ``items`` contains a non ``dict``.\n\n    \"\"\"\n    for item in items:\n        if not isinstance(item, dict):\n            raise TypeError(f\"Every item in items must be a dict but got {type(item).__name__}.\")\n        for k, v in item.items():\n            if k == \"image\" or is_segmentation and k == \"label\":\n                item[k] = _compute_path(base_dir, v, check_path=False)\n            else:\n                # for other items, auto detect whether it's a valid path\n                item[k] = _compute_path(base_dir, v, check_path=True)\n    return items\n\n\ndef load_decathlon_datalist(\n    data_list_file_path: PathLike,\n    is_segmentation: bool = True,\n    data_list_key: str = \"training\",\n    base_dir: PathLike | None = None,\n) -> list[dict]:\n    \"\"\"Load image/label paths of decathlon challenge from JSON file\n\n    JSON file should follow the format of the Medical Segmentation Decathlon\n    datalist.json files, see http://medicaldecathlon.com.\n    The files are structured as follows:\n\n    .. 
code-block:: python\n\n        {\n            \"metadata_key_0\": \"metadata_value_0\",\n            \"metadata_key_1\": \"metadata_value_1\",\n            ...,\n            \"training\": [\n                {\"image\": \"path/to/image_1.nii.gz\", \"label\": \"path/to/label_1.nii.gz\"},\n                {\"image\": \"path/to/image_2.nii.gz\", \"label\": \"path/to/label_2.nii.gz\"},\n                ...\n            ],\n            \"test\": [\n                \"path/to/image_3.nii.gz\",\n                \"path/to/image_4.nii.gz\",\n                ...\n            ]\n        }\n\n\n    The metadata keys are optional for loading the datalist, but include:\n        - some string items: ``name``, ``description``, ``reference``, ``licence``, ``release``, ``tensorImageSize``\n        - two dict items: ``modality`` (keyed by channel index), and ``labels`` (keyed by label index)\n        - and two integer items: ``numTraining`` and ``numTest``, with the number of items.\n\n    The ``training`` key contains a list of dictionaries, each of which has at least\n    the ``image`` and ``label`` keys.\n    The image and label are loaded by :py:func:`monai.transforms.LoadImaged`, so both can be either\n    a single file path or a list of file paths, in which case they are loaded as multi-channel images.\n    Each item can also include a ``fold`` key for cross-validation purposes.\n    The \"test\" key contains a list of image paths, without labels, MONAI also supports a \"validation\" list\n    with the same format as the \"training\" list.\n\n\n    Args:\n        data_list_file_path: the path to the json file of datalist.\n        is_segmentation: whether the datalist is for segmentation task, default is True.\n        data_list_key: the key to get a list of dictionary to be used, default is \"training\".\n        base_dir: the base directory of the dataset, if None, use the datalist directory.\n\n    Raises:\n        ValueError: When ``data_list_file_path`` does not point to a 
file.\n        ValueError: When ``data_list_key`` is not specified in the data list file.\n\n    Returns a list of data items, each of which is a dict keyed by element names, for example:\n\n    .. code-block:: python\n\n        [\n            {'image': '/workspace/data/chest_19.nii.gz',  'label': '/workspace/labels/chest_19.nii.gz'},\n            {'image': '/workspace/data/chest_31.nii.gz',  'label': '/workspace/labels/chest_31.nii.gz'},\n        ]\n\n    \"\"\"\n    data_list_file_path = Path(data_list_file_path)\n    if not data_list_file_path.is_file():\n        raise ValueError(f\"Data list file {data_list_file_path} does not exist.\")\n    with open(data_list_file_path) as json_file:\n        json_data = json.load(json_file)\n    if data_list_key not in json_data:\n        raise ValueError(f'Data list {data_list_key} not specified in \"{data_list_file_path}\".')\n    expected_data = json_data[data_list_key]\n    if data_list_key == \"test\" and not isinstance(expected_data[0], dict):\n        # decathlon datalist may save the test images in a list directly instead of dict\n        expected_data = [{\"image\": i} for i in expected_data]\n\n    if base_dir is None:\n        base_dir = data_list_file_path.parent\n\n    return _append_paths(base_dir, is_segmentation, expected_data)\n\n\ndef load_decathlon_properties(data_property_file_path: PathLike, property_keys: Sequence[str] | str) -> dict:\n    \"\"\"Extract the properties with the specified keys from the Decathlon JSON file.\n    See under `load_decathlon_datalist` for the expected keys in the Decathlon challenge.\n\n    Args:\n        data_property_file_path: the path to the JSON file of data properties.\n        property_keys: expected keys to load from the JSON file, for example, we have these keys\n            in the decathlon challenge:\n            `name`, `description`, `reference`, `licence`, `tensorImageSize`,\n            `modality`, `labels`, `numTraining`, `numTest`, etc.\n\n    \"\"\"\n    
data_property_file_path = Path(data_property_file_path)\n    if not data_property_file_path.is_file():\n        raise ValueError(f\"Data property file {data_property_file_path} does not exist.\")\n    with open(data_property_file_path) as json_file:\n        json_data = json.load(json_file)\n\n    properties = {}\n    for key in ensure_tuple(property_keys):\n        if key not in json_data:\n            raise KeyError(f\"key {key} is not in the data property file.\")\n        properties[key] = json_data[key]\n    return properties\n\n\ndef check_missing_files(\n    datalist: list[dict], keys: KeysCollection, root_dir: PathLike | None = None, allow_missing_keys: bool = False\n):\n    \"\"\"Checks whether some files in the Decathlon datalist are missing.\n    It would be helpful to check missing files before a heavy training run.\n\n    Args:\n        datalist: a list of data items, every item is a dictionary.\n            usually generated by `load_decathlon_datalist` API.\n        keys: expected keys to check in the datalist.\n        root_dir: if not None, provides the root dir for the relative file paths in `datalist`.\n        allow_missing_keys: whether allow missing keys in the datalist items.\n            if False, raise exception if missing. 
default to False.\n\n    Returns:\n        A list of missing filenames.\n\n    \"\"\"\n    missing_files = []\n    for item in datalist:\n        for k in ensure_tuple(keys):\n            if k not in item:\n                if not allow_missing_keys:\n                    raise ValueError(f\"key `{k}` is missing in the datalist item: {item}\")\n                continue\n\n            for f in ensure_tuple(item[k]):\n                if not isinstance(f, (str, os.PathLike)):\n                    raise ValueError(f\"filepath of key `{k}` must be a string or a list of strings, but got: {f}.\")\n                f = Path(f)\n                if isinstance(root_dir, (str, os.PathLike)):\n                    f = Path(root_dir).joinpath(f)\n                if not f.exists():\n                    missing_files.append(f)\n\n    return missing_files\n\n\ndef create_cross_validation_datalist(\n    datalist: list[dict],\n    nfolds: int,\n    train_folds: Sequence[int] | int,\n    val_folds: Sequence[int] | int,\n    train_key: str = \"training\",\n    val_key: str = \"validation\",\n    filename: Path | str | None = None,\n    shuffle: bool = True,\n    seed: int = 0,\n    check_missing: bool = False,\n    keys: KeysCollection | None = None,\n    root_dir: str | None = None,\n    allow_missing_keys: bool = False,\n    raise_error: bool = True,\n):\n    \"\"\"\n    Utility to create new Decathlon style datalist based on cross validation partition.\n\n    Args:\n        datalist: loaded list of dictionaries for all the items to partition.\n        nfolds: number of the kfold split.\n        train_folds: indices of folds for training part.\n        val_folds: indices of folds for validation part.\n        train_key: the key of train part in the new datalist, defaults to \"training\".\n        val_key: the key of validation part in the new datalist, defaults to \"validation\".\n        filename: if not None and ends with \".json\", save the new datalist into JSON file.\n        
shuffle: whether to shuffle the datalist before partition, defaults to `True`.\n        seed: if `shuffle` is True, set the random seed, defaults to `0`.\n        check_missing: whether to check that all the files specified by `keys` exist.\n        keys: if not None and `check_missing` is True, the expected keys to check in the datalist.\n        root_dir: if not None, provides the root dir for the relative file paths in `datalist`.\n        allow_missing_keys: if `check_missing` is `True`, whether to allow missing keys in the datalist items.\n            if False, raise exception if missing. default to False.\n        raise_error: when found missing files, if `True`, raise exception and stop, if `False`, print warning.\n\n    \"\"\"\n    if check_missing and keys is not None:\n        files = check_missing_files(datalist, keys, root_dir, allow_missing_keys)\n        if files:\n            msg = f\"some files of the datalist are missing: {files}\"\n            if raise_error:\n                raise ValueError(msg)\n            warnings.warn(msg)\n\n    data = partition_dataset(data=datalist, num_partitions=nfolds, shuffle=shuffle, seed=seed)\n    train_list = select_cross_validation_folds(partitions=data, folds=train_folds)\n    val_list = select_cross_validation_folds(partitions=data, folds=val_folds)\n    ret = {train_key: train_list, val_key: val_list}\n    if isinstance(filename, (str, Path)):\n        with open(filename, \"w\") as f:\n            json.dump(ret, f, indent=4)\n\n    return ret\n"
  },
  {
    "path": "monai/data/fft_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.networks.blocks.fft_utils_t import fftn_centered_t, ifftn_centered_t\nfrom monai.utils.type_conversion import convert_data_type, convert_to_dst_type\n\n\ndef ifftn_centered(ksp: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> NdarrayOrTensor:\n    \"\"\"\n    Pytorch-based ifft for spatial_dims-dim signals. \"centered\" means this function automatically takes care\n    of the required ifft and fft shifts. This function calls monai.networks.blocks.fft_utils_t.ifftn_centered_t.\n    This is equivalent to do fft in numpy based on numpy.fft.ifftn, numpy.fft.fftshift, and numpy.fft.ifftshift\n\n    Args:\n        ksp: k-space data that can be\n            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or\n            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.\n        spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume)\n        is_complex: if True, then the last dimension of the input ksp is expected to be 2 (representing real and imaginary channels)\n\n    Returns:\n        \"out\" which is the output image (inverse fourier of ksp)\n\n    Example:\n\n        .. 
code-block:: python\n\n            import torch\n            ksp = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts\n            # output1 and output2 will be identical\n            output1 = torch.fft.ifftn(torch.view_as_complex(torch.fft.ifftshift(ksp,dim=(-3,-2))), dim=(-2,-1), norm=\"ortho\")\n            output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) )\n\n            output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True)\n    \"\"\"\n    # handle numpy format\n    ksp_t, *_ = convert_data_type(ksp, torch.Tensor)\n\n    # compute ifftn\n    out_t = ifftn_centered_t(ksp_t, spatial_dims=spatial_dims, is_complex=is_complex)\n\n    # handle numpy format\n    out, *_ = convert_to_dst_type(src=out_t, dst=ksp)\n    return out\n\n\ndef fftn_centered(im: NdarrayOrTensor, spatial_dims: int, is_complex: bool = True) -> NdarrayOrTensor:\n    \"\"\"\n    Pytorch-based fft for spatial_dims-dim signals. \"centered\" means this function automatically takes care\n    of the required ifft and fft shifts. This function calls monai.networks.blocks.fft_utils_t.fftn_centered_t.\n    This is equivalent to doing fft in numpy based on numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifftshift\n\n    Args:\n        im: image that can be\n            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or\n            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.\n        spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume)\n        is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels)\n\n    Returns:\n        \"out\" which is the output kspace (fourier of im)\n\n    Example:\n\n        .. 
code-block:: python\n\n            import torch\n            im = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts\n            # output1 and output2 will be identical\n            output1 = torch.fft.fftn(torch.view_as_complex(torch.fft.ifftshift(im,dim=(-3,-2))), dim=(-2,-1), norm=\"ortho\")\n            output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) )\n\n            output2 = fftn_centered(im, spatial_dims=2, is_complex=True)\n    \"\"\"\n    # handle numpy format\n    im_t, *_ = convert_data_type(im, torch.Tensor)\n\n    # compute fftn\n    out_t = fftn_centered_t(im_t, spatial_dims=spatial_dims, is_complex=is_complex)\n\n    # handle numpy format\n    out, *_ = convert_to_dst_type(src=out_t, dst=im)\n    return out\n"
  },
  {
    "path": "monai/data/folder_layout.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom abc import ABC, abstractmethod\n\nimport monai\nfrom monai.config import PathLike\nfrom monai.data.utils import create_file_basename\n\n__all__ = [\"FolderLayoutBase\", \"FolderLayout\", \"default_name_formatter\"]\n\n\ndef default_name_formatter(metadict: dict, saver: monai.transforms.Transform) -> dict:\n    \"\"\"Returns a kwargs dict for :py:meth:`FolderLayout.filename`,\n    according to the input metadata and SaveImage transform.\"\"\"\n    subject = (\n        metadict.get(monai.utils.ImageMetaKey.FILENAME_OR_OBJ, getattr(saver, \"_data_index\", 0))\n        if metadict\n        else getattr(saver, \"_data_index\", 0)\n    )\n    patch_index = metadict.get(monai.utils.ImageMetaKey.PATCH_INDEX, None) if metadict else None\n    return {\"subject\": f\"{subject}\", \"idx\": patch_index}\n\n\nclass FolderLayoutBase(ABC):\n    \"\"\"\n    Abstract base class to define a common interface for FolderLayout and derived classes\n    Mainly, defines the ``filename(**kwargs) -> PathLike`` function, which must be defined\n    by the deriving class.\n\n    Example:\n\n    .. 
code-block:: python\n\n        from pathlib import Path\n\n        from monai.data import FolderLayoutBase\n\n        class MyFolderLayout(FolderLayoutBase):\n            def __init__(\n                self,\n                basepath: Path,\n                extension: str = \"\",\n                makedirs: bool = False\n            ):\n                self.basepath = basepath\n                if not extension:\n                    self.extension = \"\"\n                elif extension.startswith(\".\"):\n                    self.extension = extension\n                else:\n                    self.extension = f\".{extension}\"\n                self.makedirs = makedirs\n\n            def filename(self, patient_no: int, image_name: str, **kwargs) -> Path:\n                sub_path = self.basepath / str(patient_no)\n                if not sub_path.exists():\n                    sub_path.mkdir(parents=True)\n\n                file = image_name\n                for k, v in kwargs.items():\n                    file += f\"_{k}-{v}\"\n\n                file += self.extension\n                return sub_path / file\n\n    \"\"\"\n\n    @abstractmethod\n    def filename(self, **kwargs) -> PathLike:\n        \"\"\"\n        Create a filename with path based on the input kwargs.\n        Abstract method, implement your own.\n        \"\"\"\n        raise NotImplementedError\n\n\nclass FolderLayout(FolderLayoutBase):\n    \"\"\"\n    A utility class to create organized filenames within ``output_dir``. The\n    ``filename`` method could be used to create a filename following the folder structure.\n\n    Example:\n\n    .. 
code-block:: python\n\n        from monai.data import FolderLayout\n\n        layout = FolderLayout(\n            output_dir=\"/test_run_1/\",\n            postfix=\"seg\",\n            extension=\"nii\",\n            makedirs=False)\n        layout.filename(subject=\"Sub-A\", idx=\"00\", modality=\"T1\")\n        # return value: \"/test_run_1/Sub-A_seg_00_modality-T1.nii\"\n\n    The output filename is a string starting with a ``subject`` ID, and\n    includes additional information about a customized index and image\n    modality.  This utility class doesn't alter the underlying image data, but\n    provides a convenient way to create filenames.\n    \"\"\"\n\n    def __init__(\n        self,\n        output_dir: PathLike,\n        postfix: str = \"\",\n        extension: str = \"\",\n        parent: bool = False,\n        makedirs: bool = False,\n        data_root_dir: PathLike = \"\",\n    ):\n        \"\"\"\n        Args:\n            output_dir: output directory.\n            postfix: a postfix string for output file name appended to ``subject``.\n            extension: output file extension to be appended to the end of an output filename.\n            parent: whether to add a level of parent folder to contain each image to the output filename.\n            makedirs: whether to create the output parent directories if they do not exist.\n            data_root_dir: an optional `PathLike` object to preserve the folder structure of the input `subject`.\n                Please see :py:func:`monai.data.utils.create_file_basename` for more details.\n        \"\"\"\n        self.output_dir = output_dir\n        self.postfix = postfix\n        self.ext = extension\n        self.parent = parent\n        self.makedirs = makedirs\n        self.data_root_dir = data_root_dir\n\n    def filename(self, subject: PathLike = \"subject\", idx=None, **kwargs) -> PathLike:\n        \"\"\"\n        Create a filename based on the input ``subject`` and ``idx``.\n\n        The output 
filename is formed as:\n\n            ``output_dir/[subject/]subject[_postfix][_idx][_key-value][ext]``\n\n        Args:\n            subject: subject name, used as the primary id of the output filename.\n                When a `PathLike` object is provided, the base filename will be used as the subject name,\n                the extension name of `subject` will be ignored, in favor of ``extension``\n                from this class's constructor.\n            idx: additional index name of the image.\n            kwargs: additional keyword arguments to be used to form the output filename.\n                The key-value pairs will be appended to the output filename as ``f\"_{k}-{v}\"``.\n        \"\"\"\n        full_name = create_file_basename(\n            postfix=self.postfix,\n            input_file_name=subject,\n            folder_path=self.output_dir,\n            data_root_dir=self.data_root_dir,\n            separate_folder=self.parent,\n            patch_index=idx,\n            makedirs=self.makedirs,\n        )\n        for k, v in kwargs.items():\n            full_name += f\"_{k}-{v}\"\n        if self.ext is not None:\n            ext = f\"{self.ext}\"\n            full_name += f\".{ext}\" if ext and not ext.startswith(\".\") else f\"{ext}\"\n        return full_name\n"
  },
  {
    "path": "monai/data/grid_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport warnings\nfrom collections.abc import Callable, Generator, Hashable, Iterable, Iterator, Mapping, Sequence\nfrom copy import deepcopy\nfrom multiprocessing.managers import ListProxy\nfrom multiprocessing.pool import ThreadPool\nfrom typing import TYPE_CHECKING\n\nimport numpy as np\nimport torch\n\nfrom monai.config import KeysCollection\nfrom monai.config.type_definitions import NdarrayTensor\nfrom monai.data.iterable_dataset import IterableDataset\nfrom monai.data.utils import iter_patch, pickle_hashing\nfrom monai.transforms import Compose, RandomizableTrait, Transform, apply_transform, convert_to_contiguous\nfrom monai.utils import NumpyPadMode, ensure_tuple, first, min_version, optional_import\n\nif TYPE_CHECKING:\n    from tqdm import tqdm\n\n    has_tqdm = True\nelse:\n    tqdm, has_tqdm = optional_import(\"tqdm\", \"4.47.0\", min_version, \"tqdm\")\n\n__all__ = [\"PatchDataset\", \"GridPatchDataset\", \"PatchIter\", \"PatchIterd\"]\n\n\nclass PatchIter:\n    \"\"\"\n    Return a patch generator with predefined properties such as `patch_size`.\n    Typically used with :py:class:`monai.data.GridPatchDataset`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        patch_size: Sequence[int],\n        start_pos: Sequence[int] = (),\n        mode: str | None = NumpyPadMode.WRAP,\n        **pad_opts: dict,\n    
):\n        \"\"\"\n\n        Args:\n            patch_size: size of patches to generate slices for, 0/None selects whole dimension\n            start_pos: starting position in the array, default is 0 for each dimension\n            mode: available modes: (Numpy) {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n                ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n                (PyTorch) {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n                One of the listed string values or a user supplied function.\n                If None, no wrapping is performed. Defaults to ``\"wrap\"``.\n                See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html\n                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n                requires pytorch >= 1.10 for best compatibility.\n            pad_opts: other arguments for the `np.pad` function.\n                note that `np.pad` treats channel dimension as the first dimension.\n\n        Note:\n            The `patch_size` is the size of the\n            patch to sample from the input arrays. It is assumed the arrays first dimension is the channel dimension which\n            will be yielded in its entirety so this should not be specified in `patch_size`. 
For example, for an input 3D\n            array with 1 channel of size (1, 20, 20, 20) a regular grid sampling of eight patches (1, 10, 10, 10) would be\n            specified by a `patch_size` of (10, 10, 10).\n\n        \"\"\"\n        self.patch_size = (None,) + tuple(patch_size)  # expand to have the channel dim\n        self.start_pos = ensure_tuple(start_pos)\n        self.mode = mode\n        self.pad_opts = pad_opts\n\n    def __call__(self, array: NdarrayTensor) -> Generator[tuple[NdarrayTensor, np.ndarray], None, None]:\n        \"\"\"\n        Args:\n            array: the image to generate patches from.\n\n        \"\"\"\n        yield from iter_patch(\n            array,\n            patch_size=self.patch_size,  # type: ignore\n            start_pos=self.start_pos,\n            overlap=0.0,\n            copy_back=False,\n            mode=self.mode,\n            **self.pad_opts,\n        )\n\n\nclass PatchIterd:\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.data.PatchIter`.\n    Return a patch generator for dictionary data and the coordinate, Typically used\n    with :py:class:`monai.data.GridPatchDataset`.\n    Suppose all the expected fields specified by `keys` have same shape.\n\n    Args:\n        keys: keys of the corresponding items to iterate patches.\n        patch_size: size of patches to generate slices for, 0/None selects whole dimension\n        start_pos: starting position in the array, default is 0 for each dimension\n        mode: available modes: (Numpy) {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            (PyTorch) {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function.\n            If None, no wrapping is performed. 
Defaults to ``\"wrap\"``.\n            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            requires pytorch >= 1.10 for best compatibility.\n        pad_opts: other arguments for the `np.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    \"\"\"\n\n    coords_key = \"patch_coords\"\n    original_spatial_shape_key = \"original_spatial_shape\"\n    start_pos_key = \"start_pos\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        patch_size: Sequence[int],\n        start_pos: Sequence[int] = (),\n        mode: str | None = NumpyPadMode.WRAP,\n        **pad_opts,\n    ):\n        self.keys = ensure_tuple(keys)\n        self.patch_iter = PatchIter(patch_size=patch_size, start_pos=start_pos, mode=mode, **pad_opts)\n\n    def __call__(\n        self, data: Mapping[Hashable, NdarrayTensor]\n    ) -> Generator[tuple[Mapping[Hashable, NdarrayTensor], np.ndarray], None, None]:\n        d = dict(data)\n        original_spatial_shape = d[first(self.keys)].shape[1:]\n\n        for patch in zip(*[self.patch_iter(d[key]) for key in self.keys]):\n            coords = patch[0][1]  # use the coordinate of the first item\n            ret = {k: v[0] for k, v in zip(self.keys, patch)}\n            # fill in the extra keys with unmodified data\n            for k in set(d.keys()).difference(set(self.keys)):\n                ret[k] = deepcopy(d[k])\n            # also store the `coordinate`, `spatial shape of original image`, `start position` in the dictionary\n            ret[self.coords_key] = coords\n            ret[self.original_spatial_shape_key] = original_spatial_shape\n            ret[self.start_pos_key] = self.patch_iter.start_pos\n            yield ret, coords\n\n\nclass GridPatchDataset(IterableDataset):\n    \"\"\"\n    Yields patches from data read from an image dataset.\n    Typically 
used with `PatchIter` or `PatchIterd` so that the patches are chosen in a contiguous grid sampling scheme.\n\n     .. code-block:: python\n\n        import numpy as np\n\n        from monai.data import GridPatchDataset, DataLoader, PatchIter\n        from monai.transforms import RandShiftIntensity\n\n        # image-level dataset\n        images = [np.arange(16, dtype=float).reshape(1, 4, 4),\n                  np.arange(16, dtype=float).reshape(1, 4, 4)]\n        # image-level patch generator, \"grid sampling\"\n        patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))\n        # patch-level intensity shifts\n        patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)\n\n        # construct the dataset\n        ds = GridPatchDataset(data=images,\n                              patch_iter=patch_iter,\n                              transform=patch_intensity)\n        # use the grid patch dataset\n        for item in DataLoader(ds, batch_size=2, num_workers=2):\n            print(\"patch size:\", item[0].shape)\n            print(\"coordinates:\", item[1])\n\n        # >>> patch size: torch.Size([2, 1, 2, 2])\n        #     coordinates: tensor([[[0, 1], [0, 2], [0, 2]],\n        #                          [[0, 1], [2, 4], [0, 2]]])\n\n    Args:\n        data: the data source to read image data from.\n        patch_iter: converts an input image (item from dataset) into an iterable of image patches.\n            `patch_iter(dataset[idx])` must yield a tuple: (patches, coordinates).\n            see also: :py:class:`monai.data.PatchIter` or :py:class:`monai.data.PatchIterd`.\n        transform: a callable data transform that operates on the patches.\n        with_coordinates: whether to yield the coordinates of each patch, default to `True`.\n        cache: whether to use the cache mechanism, default to `False`.\n            see also: :py:class:`monai.data.CacheDataset`.\n        cache_num: number of items to be cached. 
Default is `sys.maxsize`.\n            will take the minimum of (cache_num, data_length x cache_rate, data_length).\n        cache_rate: percentage of cached data in total, default is 1.0 (cache all).\n            will take the minimum of (cache_num, data_length x cache_rate, data_length).\n        num_workers: the number of worker threads if computing cache in the initialization.\n            If num_workers is None then the number returned by os.cpu_count() is used.\n            If a value less than 1 is specified, 1 will be used instead.\n        progress: whether to display a progress bar.\n        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,\n            default to `True`. if the random transforms don't modify the cached content\n            (for example, randomly crop from the cached image and deepcopy the crop region)\n            or if every cache item is only used once in a `multi-processing` environment,\n            may set `copy=False` for better performance.\n        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.\n            it may help improve the performance of following logic.\n        hash_func: a callable to compute hash from data items to be cached.\n            defaults to `monai.data.utils.pickle_hashing`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        data: Iterable | Sequence,\n        patch_iter: Callable,\n        transform: Callable | None = None,\n        with_coordinates: bool = True,\n        cache: bool = False,\n        cache_num: int = sys.maxsize,\n        cache_rate: float = 1.0,\n        num_workers: int | None = 1,\n        progress: bool = True,\n        copy_cache: bool = True,\n        as_contiguous: bool = True,\n        hash_func: Callable[..., bytes] = pickle_hashing,\n    ) -> None:\n        super().__init__(data=data, transform=None)\n        if transform is not None and not isinstance(transform, Compose):\n            
transform = Compose(transform)\n        self.patch_iter = patch_iter\n        self.patch_transform = transform\n        self.with_coordinates = with_coordinates\n        self.set_num = cache_num\n        self.set_rate = cache_rate\n        self.progress = progress\n        self.copy_cache = copy_cache\n        self.as_contiguous = as_contiguous\n        self.hash_func = hash_func\n        self.num_workers = num_workers\n        if self.num_workers is not None:\n            self.num_workers = max(int(self.num_workers), 1)\n        self._cache: list | ListProxy = []\n        self._cache_other: list | ListProxy = []\n        self.cache = cache\n        self.first_random: int | None = None\n        if self.patch_transform is not None:\n            self.first_random = self.patch_transform.get_index_of_first(\n                lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)\n            )\n\n        if self.cache:\n            if isinstance(data, Iterator):\n                raise TypeError(\"Data can not be iterator when cache is True\")\n            self.set_data(data)  # type: ignore\n\n    def set_data(self, data: Sequence) -> None:\n        \"\"\"\n        Set the input data and run deterministic transforms to generate cache content.\n\n        Note: should call this func after an entire epoch and must set `persistent_workers=False`\n        in PyTorch DataLoader, because it needs to create new worker processes based on new\n        generated cache content.\n\n        \"\"\"\n        self.data = data\n\n        # only compute cache for the unique items of dataset, and record the last index for duplicated items\n        mapping = {self.hash_func(v): i for i, v in enumerate(self.data)}\n        self.cache_num = min(int(self.set_num), int(len(mapping) * self.set_rate), len(mapping))\n        self._hash_keys = list(mapping)[: self.cache_num]\n        indices = list(mapping.values())[: self.cache_num]\n        self._cache, self._cache_other = 
zip(*self._fill_cache(indices))  # type: ignore\n\n    def _fill_cache(self, indices=None) -> list:\n        \"\"\"\n        Compute and fill the cache content from data source.\n\n        Args:\n            indices: target indices in the `self.data` source to compute cache.\n                if None, use the first `cache_num` items.\n\n        \"\"\"\n        if self.cache_num <= 0:\n            return []\n        if indices is None:\n            indices = list(range(self.cache_num))\n        if self.progress and not has_tqdm:\n            warnings.warn(\"tqdm is not installed, will not show the caching progress bar.\")\n\n        pfunc = tqdm if self.progress and has_tqdm else (lambda v, **_: v)\n        with ThreadPool(self.num_workers) as p:\n            return list(pfunc(p.imap(self._load_cache_item, indices), total=len(indices), desc=\"Loading dataset\"))\n\n    def _load_cache_item(self, idx: int):\n        \"\"\"\n        Args:\n            idx: the index of the input data sequence.\n        \"\"\"\n        item = self.data[idx]  # type: ignore\n        patch_cache, other_cache = [], []\n        for patch, *others in self.patch_iter(item):\n            if self.first_random is not None:\n                patch = self.patch_transform(patch, end=self.first_random, threading=True)  # type: ignore\n\n            if self.as_contiguous:\n                patch = convert_to_contiguous(patch, memory_format=torch.contiguous_format)\n            if self.with_coordinates and len(others) > 0:  # patch_iter to yield at least 2 items: patch, coords\n                other_cache.append(others[0])\n            patch_cache.append(patch)\n        return patch_cache, other_cache\n\n    def _generate_patches(self, src, **apply_args):\n        \"\"\"\n        yield patches optionally post-processed by transform.\n\n        Args:\n            src: a iterable of image patches.\n            apply_args: other args for `self.patch_transform`.\n\n        \"\"\"\n        for patch, *others 
in src:\n            out_patch = patch\n            if self.patch_transform is not None:\n                out_patch = self.patch_transform(patch, **apply_args)\n            if self.with_coordinates and len(others) > 0:  # patch_iter to yield at least 2 items: patch, coords\n                yield out_patch, others[0]\n            else:\n                yield out_patch\n\n    def __iter__(self):\n        if self.cache:\n            cache_index = None\n            for image in super().__iter__():\n                key = self.hash_func(image)\n                if key in self._hash_keys:\n                    # if existing in cache, try to get the index in cache\n                    cache_index = self._hash_keys.index(key)\n                if cache_index is None:\n                    # no cache for this index, execute all the transforms directly\n                    yield from self._generate_patches(self.patch_iter(image))\n                else:\n                    if self._cache is None:\n                        raise RuntimeError(\n                            \"Cache buffer is not initialized, please call `set_data()` before epoch begins.\"\n                        )\n                    data = self._cache[cache_index]\n                    other = self._cache_other[cache_index]\n\n                    # load data from cache and execute from the first random transform\n                    data = deepcopy(data) if self.copy_cache else data\n                    yield from self._generate_patches(zip(data, other), start=self.first_random)\n        else:\n            for image in super().__iter__():\n                yield from self._generate_patches(self.patch_iter(image))\n\n\nclass PatchDataset(IterableDataset):\n    \"\"\"\n    Yields patches from data read from an image dataset.\n    The patches are generated by a user-specified callable `patch_func`,\n    and are optionally post-processed by `transform`.\n    For example, to generate random patch samples from an image 
dataset:\n\n    .. code-block:: python\n\n        import numpy as np\n\n        from monai.data import PatchDataset, DataLoader\n        from monai.transforms import RandSpatialCropSamples, RandShiftIntensity\n\n        # image dataset\n        images = [np.arange(16, dtype=float).reshape(1, 4, 4),\n                  np.arange(16, dtype=float).reshape(1, 4, 4)]\n        # image patch sampler\n        n_samples = 5\n        sampler = RandSpatialCropSamples(roi_size=(3, 3), num_samples=n_samples,\n                                         random_center=True, random_size=False)\n        # patch-level intensity shifts\n        patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)\n        # construct the patch dataset\n        ds = PatchDataset(dataset=images,\n                          patch_func=sampler,\n                          samples_per_image=n_samples,\n                          transform=patch_intensity)\n\n        # use the patch dataset, length: len(images) x samplers_per_image\n        print(len(ds))\n\n        >>> 10\n\n        for item in DataLoader(ds, batch_size=2, shuffle=True, num_workers=2):\n            print(item.shape)\n\n        >>> torch.Size([2, 1, 3, 3])\n\n    \"\"\"\n\n    def __init__(\n        self, data: Sequence, patch_func: Callable, samples_per_image: int = 1, transform: Callable | None = None\n    ) -> None:\n        \"\"\"\n        Args:\n            data: an image dataset to extract patches from.\n            patch_func: converts an input image (item from dataset) into a sequence of image patches.\n                patch_func(dataset[idx]) must return a sequence of patches (length `samples_per_image`).\n            samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements.\n            transform: transform applied to each patch.\n        \"\"\"\n        super().__init__(data=data, transform=None)\n\n        self.patch_func = patch_func\n        if samples_per_image <= 0:\n            raise 
ValueError(\"sampler_per_image must be a positive integer.\")\n        self.samples_per_image = int(samples_per_image)\n        self.patch_transform = transform\n\n    def __len__(self) -> int:\n        return len(self.data) * self.samples_per_image  # type: ignore\n\n    def __iter__(self):\n        for image in super().__iter__():\n            patches = self.patch_func(image)\n            if len(patches) != self.samples_per_image:\n                raise RuntimeWarning(\n                    f\"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}.\"\n                )\n            for patch in patches:\n                out_patch = patch\n                if self.patch_transform is not None:\n                    out_patch = apply_transform(self.patch_transform, patch, map_items=False)\n                yield out_patch\n"
  },
  {
    "path": "monai/data/image_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Sequence\nfrom typing import Any\n\nimport numpy as np\nfrom torch.utils.data import Dataset\n\nfrom monai.config import DtypeLike\nfrom monai.data.image_reader import ImageReader\nfrom monai.transforms import LoadImage, Randomizable, apply_transform\nfrom monai.utils import MAX_SEED, get_seed\n\n\nclass ImageDataset(Dataset, Randomizable):\n    \"\"\"\n    Loads image/segmentation pairs of files from the given filename lists. 
Transformations can be specified\n    for the image and segmentation arrays separately.\n    The difference between this dataset and `ArrayDataset` is that this dataset can apply transform chain to images\n    and segs and return both the images and metadata, and no need to specify transform to load images from files.\n    For more information, please see the image_dataset demo in the MONAI tutorial repo,\n    https://github.com/Project-MONAI/tutorials/blob/master/modules/image_dataset.ipynb\n    \"\"\"\n\n    def __init__(\n        self,\n        image_files: Sequence[str],\n        seg_files: Sequence[str] | None = None,\n        labels: Sequence[float] | None = None,\n        transform: Callable | None = None,\n        seg_transform: Callable | None = None,\n        label_transform: Callable | None = None,\n        image_only: bool = True,\n        transform_with_metadata: bool = False,\n        dtype: DtypeLike = np.float32,\n        reader: ImageReader | str | None = None,\n        *args,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Initializes the dataset with the image and segmentation filename lists. 
The transform `transform` is applied\n        to the images and `seg_transform` to the segmentations.\n\n        Args:\n            image_files: list of image filenames.\n            seg_files: if in segmentation task, list of segmentation filenames.\n            labels: if in classification task, list of classification labels.\n            transform: transform to apply to image arrays.\n            seg_transform: transform to apply to segmentation arrays.\n            label_transform: transform to apply to the label data.\n            image_only: if True return only the image volume, otherwise, return image volume and the metadata.\n            transform_with_metadata: if True, the metadata will be passed to the transforms whenever possible.\n            dtype: if not None convert the loaded image to this data type.\n            reader: register reader to load image file and metadata, if None, will use the default readers.\n                If a string of reader name provided, will construct a reader object with the `*args` and `**kwargs`\n                parameters, supported reader name: \"NibabelReader\", \"PILReader\", \"ITKReader\", \"NumpyReader\"\n            args: additional parameters for reader if providing a reader name.\n            kwargs: additional parameters for reader if providing a reader name.\n\n        Raises:\n            ValueError: When ``seg_files`` length differs from ``image_files``\n\n        \"\"\"\n\n        if seg_files is not None and len(image_files) != len(seg_files):\n            raise ValueError(\n                \"Must have same the number of segmentation as image files: \"\n                f\"images={len(image_files)}, segmentations={len(seg_files)}.\"\n            )\n\n        self.image_files = image_files\n        self.seg_files = seg_files\n        self.labels = labels\n        self.transform = transform\n        self.seg_transform = seg_transform\n        self.label_transform = label_transform\n        if image_only and 
transform_with_metadata:\n            raise ValueError(\"transform_with_metadata=True requires image_only=False.\")\n        self.image_only = image_only\n        self.transform_with_metadata = transform_with_metadata\n        self.loader = LoadImage(reader, image_only, dtype, *args, **kwargs)\n        self.set_random_state(seed=get_seed())\n        self._seed = 0  # transform synchronization seed\n\n    def __len__(self) -> int:\n        return len(self.image_files)\n\n    def randomize(self, data: Any | None = None) -> None:\n        self._seed = int(self.R.randint(MAX_SEED, dtype=\"uint32\"))\n\n    def __getitem__(self, index: int):\n        self.randomize()\n        meta_data, seg_meta_data, seg, label = None, None, None, None\n\n        # load data and optionally meta\n        if self.image_only:\n            img = self.loader(self.image_files[index])\n            if self.seg_files is not None:\n                seg = self.loader(self.seg_files[index])\n        else:\n            img, meta_data = self.loader(self.image_files[index])\n            if self.seg_files is not None:\n                seg, seg_meta_data = self.loader(self.seg_files[index])\n\n        # apply the transforms\n        if self.transform is not None:\n            if isinstance(self.transform, Randomizable):\n                self.transform.set_random_state(seed=self._seed)\n\n            if self.transform_with_metadata:\n                img, meta_data = apply_transform(self.transform, (img, meta_data), map_items=False, unpack_items=True)\n            else:\n                img = apply_transform(self.transform, img, map_items=False)\n\n        if self.seg_files is not None and self.seg_transform is not None:\n            if isinstance(self.seg_transform, Randomizable):\n                self.seg_transform.set_random_state(seed=self._seed)\n\n            if self.transform_with_metadata:\n                seg, seg_meta_data = apply_transform(\n                    self.seg_transform, (seg, 
seg_meta_data), map_items=False, unpack_items=True\n                )\n            else:\n                seg = apply_transform(self.seg_transform, seg, map_items=False)\n\n        if self.labels is not None:\n            label = self.labels[index]\n            if self.label_transform is not None:\n                label = apply_transform(self.label_transform, label, map_items=False)  # type: ignore\n\n        # construct outputs\n        data = [img]\n        if seg is not None:\n            data.append(seg)\n        if label is not None:\n            data.append(label)\n        if not self.image_only and meta_data is not None:\n            data.append(meta_data)\n        if not self.image_only and seg_meta_data is not None:\n            data.append(seg_meta_data)\n        if len(data) == 1:\n            return data[0]\n        # use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists\n        return tuple(data)\n"
  },
  {
    "path": "monai/data/image_reader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport glob\nimport gzip\nimport io\nimport os\nimport re\nimport tempfile\nimport warnings\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Callable, Iterable, Iterator, Sequence\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any, Union\n\nimport numpy as np\nfrom torch.utils.data._utils.collate import np_str_obj_array_pattern\n\nfrom monai.config import KeysCollection, PathLike\nfrom monai.data.utils import (\n    affine_to_spacing,\n    correct_nifti_header_if_necessary,\n    is_no_channel,\n    is_supported_format,\n    orientation_ras_lps,\n)\nfrom monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg\n\nif TYPE_CHECKING:\n    import itk\n    import nibabel as nib\n    import nrrd\n    import pydicom\n    from nibabel.nifti1 import Nifti1Image\n    from PIL import Image as PILImage\n\n    has_nrrd = has_itk = has_nib = has_pil = has_pydicom = True\nelse:\n    itk, has_itk = optional_import(\"itk\", allow_namespace_pkg=True)\n    nib, has_nib = optional_import(\"nibabel\")\n    Nifti1Image, _ = optional_import(\"nibabel.nifti1\", name=\"Nifti1Image\")\n    PILImage, has_pil = optional_import(\"PIL.Image\")\n    pydicom, has_pydicom = optional_import(\"pydicom\")\n    nrrd, has_nrrd = optional_import(\"nrrd\", 
allow_namespace_pkg=True)\n\ncp, has_cp = optional_import(\"cupy\")\nkvikio, has_kvikio = optional_import(\"kvikio\")\n\nif TYPE_CHECKING:\n    import cupy\n\n    NdarrayOrCupy = Union[np.ndarray, cupy.ndarray]\nelse:\n    NdarrayOrCupy = Any\n\n__all__ = [\"ImageReader\", \"ITKReader\", \"NibabelReader\", \"NumpyReader\", \"PILReader\", \"PydicomReader\", \"NrrdReader\"]\n\n\nclass ImageReader(ABC):\n    \"\"\"\n    An abstract class defines APIs to load image files.\n\n    Typical usage of an implementation of this class is:\n\n    .. code-block:: python\n\n        image_reader = MyImageReader()\n        img_obj = image_reader.read(path_to_image)\n        img_data, meta_data = image_reader.get_data(img_obj)\n\n    - The `read` call converts image filenames into image objects,\n    - The `get_data` call fetches the image data, as well as metadata.\n    - A reader should implement `verify_suffix` with the logic of checking the input filename\n      by the filename extensions.\n\n    \"\"\"\n\n    @abstractmethod\n    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:\n        \"\"\"\n        Verify whether the specified `filename` is supported by the current reader.\n        This method should return True if the reader is able to read the format suggested by the\n        `filename`.\n\n        Args:\n            filename: file name or a list of file names to read.\n                if a list of files, verify all the suffixes.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    @abstractmethod\n    def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any:\n        \"\"\"\n        Read image data from specified file or files.\n        Note that it returns a data object or a sequence of data objects.\n\n        Args:\n            data: file name or a list of file names to read.\n            kwargs: additional args for actual `read` API of 3rd 
party libs.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    @abstractmethod\n    def get_data(self, img) -> tuple[np.ndarray, dict]:\n        \"\"\"\n        Extract data array and metadata from loaded image and return them.\n        This function must return two objects, the first is a numpy array of image data,\n        the second is a dictionary of metadata.\n\n        Args:\n            img: an image object loaded from an image file or a list of image objects.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n\ndef _copy_compatible_dict(from_dict: dict, to_dict: dict):\n    if not isinstance(to_dict, dict):\n        raise ValueError(f\"to_dict must be a Dict, got {type(to_dict)}.\")\n    if not to_dict:\n        for key in from_dict:\n            datum = from_dict[key]\n            if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None:\n                continue\n            to_dict[key] = str(TraceKeys.NONE) if datum is None else datum  # NoneType to string for default_collate\n    else:\n        affine_key, shape_key = MetaKeys.AFFINE, MetaKeys.SPATIAL_SHAPE\n        if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]):\n            raise RuntimeError(\n                \"affine matrix of all images should be the same for channel-wise concatenation. \"\n                f\"Got {from_dict[affine_key]} and {to_dict[affine_key]}.\"\n            )\n        if shape_key in from_dict and not np.allclose(from_dict[shape_key], to_dict[shape_key]):\n            raise RuntimeError(\n                \"spatial_shape of all images should be the same for channel-wise concatenation. 
\"\n                f\"Got {from_dict[shape_key]} and {to_dict[shape_key]}.\"\n            )\n\n\ndef _stack_images(image_list: list, meta_dict: dict, to_cupy: bool = False):\n    if len(image_list) <= 1:\n        return image_list[0]\n    if not is_no_channel(meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None)):\n        channel_dim = int(meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM])\n        if to_cupy and has_cp:\n            return cp.concatenate(image_list, axis=channel_dim)\n        return np.concatenate(image_list, axis=channel_dim)\n    # stack at a new first dim as the channel dim, if `'original_channel_dim'` is unspecified\n    meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0\n    if to_cupy and has_cp:\n        return cp.stack(image_list, axis=0)\n    return np.stack(image_list, axis=0)\n\n\n@require_pkg(pkg_name=\"itk\")\nclass ITKReader(ImageReader):\n    \"\"\"\n    Load medical images based on ITK library.\n    All the supported image formats can be found at:\n    https://github.com/InsightSoftwareConsortium/ITK/tree/master/Modules/IO\n    The loaded data array will be in C order, for example, a 3D image NumPy\n    array index order will be `CDWH`.\n\n    Args:\n        channel_dim: the channel dimension of the input image, default is None.\n            This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.\n            If None, `original_channel_dim` will be either `no_channel` or `-1`.\n\n                - Nifti file is usually \"channel last\", so there is no need to specify this argument.\n                - PNG file usually has `GetNumberOfComponentsPerPixel()==3`, so there is no need to specify this argument.\n\n        series_name: the name of the DICOM series if there are multiple ones.\n            used when loading DICOM series.\n        reverse_indexing: whether to use a reversed spatial indexing convention for the returned data array.\n            If ``False``, the spatial indexing convention is reversed to 
be compatible with ITK;\n            otherwise, the spatial indexing follows the numpy convention. Default is ``False``.\n            This option does not affect the metadata.\n        series_meta: whether to load the metadata of the DICOM series (using the metadata from the first slice).\n            This flag is checked only when loading DICOM series. Default is ``False``.\n        affine_lps_to_ras: whether to convert the affine matrix from \"LPS\" to \"RAS\". Defaults to ``True``.\n            Set to ``True`` to be consistent with ``NibabelReader``, otherwise the affine matrix remains in the ITK convention.\n        kwargs: additional args for `itk.imread` API. more details about available args:\n            https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py\n\n    \"\"\"\n\n    def __init__(\n        self,\n        channel_dim: str | int | None = None,\n        series_name: str = \"\",\n        reverse_indexing: bool = False,\n        series_meta: bool = False,\n        affine_lps_to_ras: bool = True,\n        **kwargs,\n    ):\n        super().__init__()\n        self.kwargs = kwargs\n        self.channel_dim = float(\"nan\") if channel_dim == \"no_channel\" else channel_dim\n        self.series_name = series_name\n        self.reverse_indexing = reverse_indexing\n        self.series_meta = series_meta\n        self.affine_lps_to_ras = affine_lps_to_ras\n\n    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:\n        \"\"\"\n        Verify whether the specified file or files format is supported by ITK reader.\n\n        Args:\n            filename: file name or a list of file names to read.\n                if a list of files, verify all the suffixes.\n\n        \"\"\"\n        return has_itk\n\n    def read(self, data: Sequence[PathLike] | PathLike, **kwargs):\n        \"\"\"\n        Read image data from specified file or files, it can read a list of images\n        and 
stack them together as multi-channel data in `get_data()`.\n        If passing directory path instead of file path, will treat it as DICOM images series and read.\n        Note that the returned object is ITK image object or list of ITK image objects.\n\n        Args:\n            data: file name or a list of file names to read,\n            kwargs: additional args for `itk.imread` API, will override `self.kwargs` for existing keys.\n                More details about available args:\n                https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py\n\n        \"\"\"\n        img_ = []\n\n        filenames: Sequence[PathLike] = ensure_tuple(data)\n        kwargs_ = self.kwargs.copy()\n        kwargs_.update(kwargs)\n        for name in filenames:\n            name = f\"{name}\"\n            if Path(name).is_dir():\n                # read DICOM series\n                # https://examples.itk.org/src/io/gdcm/readdicomseriesandwrite3dimage/documentation\n                names_generator = itk.GDCMSeriesFileNames.New()\n                names_generator.SetUseSeriesDetails(True)\n                names_generator.AddSeriesRestriction(\"0008|0021\")  # Series Date\n                names_generator.SetDirectory(name)\n                series_uid = names_generator.GetSeriesUIDs()\n\n                if len(series_uid) < 1:\n                    raise FileNotFoundError(f\"no DICOMs in: {name}.\")\n                if len(series_uid) > 1:\n                    warnings.warn(f\"the directory: {name} contains more than one DICOM series.\")\n                series_identifier = series_uid[0] if not self.series_name else self.series_name\n                name = names_generator.GetFileNames(series_identifier)\n\n                name = name[0] if len(name) == 1 else name  # type: ignore\n                _obj = itk.imread(name, **kwargs_)\n                if self.series_meta:\n                    _reader = 
itk.ImageSeriesReader.New(FileNames=name)\n                    _reader.Update()\n                    _meta = _reader.GetMetaDataDictionaryArray()\n                    if len(_meta) > 0:\n                        # TODO: using the first slice's meta. this could be improved to filter unnecessary tags.\n                        _obj.SetMetaDataDictionary(_meta[0])\n                img_.append(_obj)\n            else:\n                img_.append(itk.imread(name, **kwargs_))\n        return img_ if len(filenames) > 1 else img_[0]\n\n    def get_data(self, img) -> tuple[np.ndarray, dict]:\n        \"\"\"\n        Extract data array and metadata from loaded image and return them.\n        This function returns two objects, first is numpy array of image data, second is dict of metadata.\n        It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.\n        When loading a list of files, they are stacked together at a new dimension as the first dimension,\n        and the metadata of the first image is used to represent the output metadata.\n\n        Args:\n            img: an ITK image object loaded from an image file or a list of ITK image objects.\n\n        \"\"\"\n        img_array: list[np.ndarray] = []\n        compatible_meta: dict = {}\n\n        for i in ensure_tuple(img):\n            data = self._get_array_data(i)\n            img_array.append(data)\n            header = self._get_meta_dict(i)\n            header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i, self.affine_lps_to_ras)\n            header[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS\n            header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy()\n            header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)\n            if self.channel_dim is None:  # default to \"no_channel\" or -1\n                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (\n                    float(\"nan\") if len(data.shape) == 
len(header[MetaKeys.SPATIAL_SHAPE]) else -1\n                )\n            else:\n                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim\n            _copy_compatible_dict(header, compatible_meta)\n\n        return _stack_images(img_array, compatible_meta), compatible_meta\n\n    def _get_meta_dict(self, img) -> dict:\n        \"\"\"\n        Get all the metadata of the image and convert to dict type.\n\n        Args:\n            img: an ITK image object loaded from an image file.\n\n        \"\"\"\n        img_meta_dict = img.GetMetaDataDictionary()\n        meta_dict = {}\n        for key in img_meta_dict.GetKeys():\n            if key.startswith(\"ITK_\"):\n                continue\n            val = img_meta_dict[key]\n            meta_dict[key] = np.asarray(val) if type(val).__name__.startswith(\"itk\") else val\n\n        meta_dict[\"spacing\"] = np.asarray(img.GetSpacing())\n        return meta_dict\n\n    def _get_affine(self, img, lps_to_ras: bool = True):\n        \"\"\"\n        Get or construct the affine matrix of the image, it can be used to correct\n        spacing, orientation or execute spatial transforms.\n\n        Args:\n            img: an ITK image object loaded from an image file.\n            lps_to_ras: whether to convert the affine matrix from \"LPS\" to \"RAS\". 
Defaults to True.\n\n        \"\"\"\n        direction = itk.array_from_matrix(img.GetDirection())\n        spacing = np.asarray(img.GetSpacing())\n        origin = np.asarray(img.GetOrigin())\n\n        direction = np.asarray(direction)\n        sr = min(max(direction.shape[0], 1), 3)\n        affine: np.ndarray = np.eye(sr + 1)\n        affine[:sr, :sr] = direction[:sr, :sr] @ np.diag(spacing[:sr])\n        affine[:sr, -1] = origin[:sr]\n        if lps_to_ras:\n            affine = orientation_ras_lps(affine)\n        return affine\n\n    def _get_spatial_shape(self, img):\n        \"\"\"\n        Get the spatial shape of `img`.\n\n        Args:\n            img: an ITK image object loaded from an image file.\n\n        \"\"\"\n        sr = itk.array_from_matrix(img.GetDirection()).shape[0]\n        sr = max(min(sr, 3), 1)\n        _size = list(itk.size(img))\n        if isinstance(self.channel_dim, int):\n            _size.pop(self.channel_dim)\n        return np.asarray(_size[:sr])\n\n    def _get_array_data(self, img):\n        \"\"\"\n        Get the raw array data of the image, converted to Numpy array.\n\n        Following PyTorch conventions, the returned array data has contiguous channels,\n        e.g. 
for an RGB image, all red channel image pixels are contiguous in memory.\n        The last axis of the returned array is the channel axis.\n\n        See also:\n\n            - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Modules/Bridge/NumPy/wrapping/PyBuffer.i.in\n\n        Args:\n            img: an ITK image object loaded from an image file.\n\n        \"\"\"\n        np_img = itk.array_view_from_image(img, keep_axes=False)\n        if img.GetNumberOfComponentsPerPixel() == 1:  # handling spatial images\n            return np_img if self.reverse_indexing else np_img.T\n        # handling multi-channel images\n        return np_img if self.reverse_indexing else np.moveaxis(np_img.T, 0, -1)\n\n\n@require_pkg(pkg_name=\"pydicom\")\nclass PydicomReader(ImageReader):\n    \"\"\"\n    Load medical images based on Pydicom library.\n    All the supported image formats can be found at:\n    https://dicom.nema.org/medical/dicom/current/output/chtml/part10/chapter_7.html\n\n    PydicomReader is also able to load segmentations, if a dicom file contains tag: `SegmentSequence`, the reader\n    will consider it as segmentation data, and to load it successfully, `PerFrameFunctionalGroupsSequence` is required\n    for dicom file, and for each frame of dicom file, `SegmentIdentificationSequence` is required.\n    This method refers to the Highdicom library.\n\n    This class refers to:\n    https://nipy.org/nibabel/dicom/dicom_orientation.html#dicom-affine-formula\n    https://github.com/pydicom/contrib-pydicom/blob/master/input-output/pydicom_series.py\n    https://highdicom.readthedocs.io/en/latest/usage.html#parsing-segmentation-seg-images\n\n    Args:\n        channel_dim: the channel dimension of the input image, default is None.\n            This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.\n            If None, `original_channel_dim` will be either `no_channel` or `-1`.\n        affine_lps_to_ras: whether to 
convert the affine matrix from \"LPS\" to \"RAS\". Defaults to ``True``.\n            Set to ``True`` to be consistent with ``NibabelReader``,\n            otherwise the affine matrix remains in the Dicom convention.\n        swap_ij: whether to swap the first two spatial axes. Default to ``True``, so that the outputs\n            are consistent with the other readers.\n        prune_metadata: whether to prune the saved information in metadata. This argument is used for\n            `get_data` function. If True, only items that are related to the affine matrix will be saved.\n            Default to ``True``.\n        label_dict: label of the dicom data. If provided, it will be used when loading segmentation data.\n            Keys of the dict are the classes, and values are the corresponding class number. For example:\n            for TCIA collection \"C4KC-KiTS\", it can be: {\"Kidney\": 0, \"Renal Tumor\": 1}.\n        fname_regex: a regular expression to match the file names when the input is a folder.\n            If provided, only the matched files will be included. For example, to include the file name\n            \"image_0001.dcm\", the regular expression could be `\".*image_(\\\\d+).dcm\"`. Default to `\"\"`.\n            Set it to `None` to use `pydicom.misc.is_dicom` to match valid files.\n        to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.\n            Default is False. CuPy and Kvikio are required for this option.\n            In practical use, it's recommended to add a warm up call before the actual loading.\n            A related tutorial will be prepared in the future, and the document will be updated accordingly.\n        kwargs: additional args for `pydicom.dcmread` API. 
more details about available args:\n            https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html\n            If the `get_data` function will be called\n            (for example, when using this reader with `monai.transforms.LoadImage`), please ensure that the argument\n            `stop_before_pixels` is `True`, and `specific_tags` covers all necessary tags, such as `PixelSpacing`,\n            `ImagePositionPatient`, `ImageOrientationPatient` and all `pixel_array` related tags.\n    \"\"\"\n\n    def __init__(\n        self,\n        channel_dim: str | int | None = None,\n        affine_lps_to_ras: bool = True,\n        swap_ij: bool = True,\n        prune_metadata: bool = True,\n        label_dict: dict | None = None,\n        fname_regex: str = \"\",\n        to_gpu: bool = False,\n        **kwargs,\n    ):\n        super().__init__()\n        self.kwargs = kwargs\n        self.channel_dim = float(\"nan\") if channel_dim == \"no_channel\" else channel_dim\n        self.affine_lps_to_ras = affine_lps_to_ras\n        self.swap_ij = swap_ij\n        self.prune_metadata = prune_metadata\n        self.label_dict = label_dict\n        self.fname_regex = fname_regex\n        if to_gpu and (not has_cp or not has_kvikio):\n            warnings.warn(\n                \"PydicomReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading.\"\n            )\n            to_gpu = False\n\n        if to_gpu:\n            self.warmup_kvikio()\n\n        self.to_gpu = to_gpu\n\n    def warmup_kvikio(self):\n        \"\"\"\n        Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.\n        This can accelerate the data loading process when `to_gpu` is set to True.\n        \"\"\"\n        if has_cp and has_kvikio:\n            a = cp.arange(100)\n            with tempfile.NamedTemporaryFile() as tmp_file:\n                tmp_file_name = tmp_file.name\n                f = 
kvikio.CuFile(tmp_file_name, \"w\")\n                f.write(a)\n                f.close()\n\n                b = cp.empty_like(a)\n                f = kvikio.CuFile(tmp_file_name, \"r\")\n                f.read(b)\n\n    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:\n        \"\"\"\n        Verify whether the specified file or files format is supported by Pydicom reader.\n\n        Args:\n            filename: file name or a list of file names to read.\n                if a list of files, verify all the suffixes.\n\n        \"\"\"\n        return has_pydicom\n\n    def read(self, data: Sequence[PathLike] | PathLike, **kwargs):\n        \"\"\"\n        Read image data from specified file or files, it can read a list of images\n        and stack them together as multi-channel data in `get_data()`.\n        If passing directory path instead of file path, will treat it as DICOM images series and read.\n\n        Args:\n            data: file name or a list of file names to read,\n            kwargs: additional args for `pydicom.dcmread` API, will override `self.kwargs` for existing keys.\n\n        Returns:\n            If `data` represents a filename: return a pydicom dataset object.\n            If `data` represents a list of filenames or a directory: return a list of pydicom dataset object.\n            If `data` represents a list of directories: return a list of list of pydicom dataset object.\n\n        \"\"\"\n        img_ = []\n\n        filenames: Sequence[PathLike] = ensure_tuple(data)\n        self.filenames = list(filenames)\n        kwargs_ = self.kwargs.copy()\n        if self.to_gpu:\n            kwargs[\"defer_size\"] = \"100 KB\"\n        kwargs_.update(kwargs)\n\n        self.has_series = False\n\n        for i, name in enumerate(filenames):\n            name = f\"{name}\"\n            if Path(name).is_dir():\n                # read DICOM series\n                if self.fname_regex is not None:\n                    
series_slcs = [slc for slc in glob.glob(os.path.join(name, \"*\")) if re.match(self.fname_regex, slc)]\n                else:\n                    series_slcs = [slc for slc in glob.glob(os.path.join(name, \"*\")) if pydicom.misc.is_dicom(slc)]\n                slices = []\n                loaded_slc_names = []\n                for slc in series_slcs:\n                    try:\n                        slices.append(pydicom.dcmread(fp=slc, **kwargs_))\n                        loaded_slc_names.append(slc)\n                    except pydicom.errors.InvalidDicomError as e:\n                        warnings.warn(f\"Failed to read {slc} with exception: \\n{e}.\", stacklevel=2)\n                if len(slices) > 1:\n                    self.has_series = True\n                    img_.append(slices)\n                    self.filenames[i] = loaded_slc_names  # type: ignore\n                else:\n                    img_.append(slices[0])  # type: ignore\n                    self.filenames[i] = loaded_slc_names[0]  # type: ignore\n            else:\n                ds = pydicom.dcmread(fp=name, **kwargs_)\n                img_.append(ds)  # type: ignore\n        if len(filenames) == 1:\n            return img_[0]\n        return img_\n\n    def _combine_dicom_series(self, data: Iterable, filenames: Sequence[PathLike]):\n        \"\"\"\n        Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new\n        dimension as the last dimension.\n\n        The stack order depends on Instance Number. 
The metadata will be based on the\n        first slice's metadata, and some new items will be added:\n\n        \"spacing\": the new spacing of the stacked slices.\n        \"lastImagePositionPatient\": `ImagePositionPatient` for the last slice, it will be used to achieve the affine\n            matrix.\n        \"spatial_shape\": the spatial shape of the stacked slices.\n\n        Args:\n            data: a list of pydicom dataset objects.\n        Returns:\n            a tuple that consisted with data array and metadata.\n        \"\"\"\n        slices: list = []\n        # for a dicom series\n        for slc_ds, filename in zip(data, filenames):\n            if hasattr(slc_ds, \"InstanceNumber\"):\n                slices.append((slc_ds, filename))\n            else:\n                warnings.warn(f\"slice: {filename} does not have InstanceNumber tag, skip it.\")\n        slices = sorted(slices, key=lambda s: s[0].InstanceNumber)\n        if len(slices) == 0:\n            raise ValueError(\"the input does not have valid slices.\")\n\n        first_slice, first_filename = slices[0]\n        average_distance = 0.0\n        first_array = self._get_array_data(first_slice, first_filename)\n        shape = first_array.shape\n        spacing = getattr(first_slice, \"PixelSpacing\", [1.0] * len(shape))\n        prev_pos = getattr(first_slice, \"ImagePositionPatient\", (0.0, 0.0, 0.0))[2]\n        stack_array_list: list = [first_array]\n        for idx in range(1, len(slices)):\n            slc_array = self._get_array_data(slices[idx][0], slices[idx][1])\n            slc_shape = slc_array.shape\n            slc_spacing = getattr(slices[idx][0], \"PixelSpacing\", [1.0] * len(shape))\n            slc_pos = getattr(slices[idx][0], \"ImagePositionPatient\", (0.0, 0.0, float(idx)))[2]\n            if not np.allclose(slc_spacing, spacing):\n                warnings.warn(f\"the list contains slices that have different spacings {spacing} and {slc_spacing}.\")\n            if 
shape != slc_shape:\n                warnings.warn(f\"the list contains slices that have different shapes {shape} and {slc_shape}.\")\n            average_distance += abs(prev_pos - slc_pos)\n            prev_pos = slc_pos\n            stack_array_list.append(slc_array)\n\n        if len(slices) > 1:\n            average_distance /= len(slices) - 1\n            spacing.append(average_distance)\n            if self.to_gpu:\n                stack_array = cp.stack(stack_array_list, axis=-1)\n            else:\n                stack_array = np.stack(stack_array_list, axis=-1)\n\n            del stack_array_list[:]\n            stack_metadata = self._get_meta_dict(first_slice)\n            stack_metadata[\"spacing\"] = np.asarray(spacing)\n            if hasattr(slices[-1][0], \"ImagePositionPatient\"):\n                stack_metadata[\"lastImagePositionPatient\"] = np.asarray(slices[-1][0].ImagePositionPatient)\n            stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape + (len(slices),)\n        else:\n            stack_array = stack_array_list[0]\n            stack_metadata = self._get_meta_dict(first_slice)\n            stack_metadata[\"spacing\"] = np.asarray(spacing)\n            stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape\n\n        return stack_array, stack_metadata\n\n    def get_data(self, data) -> tuple[np.ndarray, dict]:\n        \"\"\"\n        Extract data array and metadata from loaded image and return them.\n        This function returns two objects, first is numpy array of image data, second is dict of metadata.\n        It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.\n        For dicom series within the input, all slices will be stacked first,\n        When loading a list of files (dicom file, or stacked dicom series), they are stacked together at a new\n        dimension as the first dimension, and the metadata of the first image is used to represent the output metadata.\n\n        To use this function, 
all pydicom dataset objects (if not segmentation data) should contain:\n        `pixel_array`, `PixelSpacing`, `ImagePositionPatient` and `ImageOrientationPatient`.\n\n        For segmentation data, we assume that the input is not a dicom series, and the object should contain\n        `SegmentSequence` in order to identify it.\n        In addition, tags (5200, 9229) and (5200, 9230) are required to achieve\n        `PixelSpacing`, `ImageOrientationPatient` and `ImagePositionPatient`.\n\n        Args:\n            data: a pydicom dataset object, or a list of pydicom dataset objects, or a list of list of\n                pydicom dataset objects.\n\n        \"\"\"\n\n        dicom_data = []\n        # combine dicom series if exists\n        if self.has_series is True:\n            # a list, all objects within a list belong to one dicom series\n            if not isinstance(data[0], list):\n                # input is a dir, self.filenames is a list of list of filenames\n                dicom_data.append(self._combine_dicom_series(data, self.filenames[0]))  # type: ignore\n            # a list of list, each inner list represents a dicom series\n            else:\n                for i, series in enumerate(data):\n                    dicom_data.append(self._combine_dicom_series(series, self.filenames[i]))  # type: ignore\n        else:\n            # a single pydicom dataset object\n            if not isinstance(data, list):\n                data = [data]\n            for i, d in enumerate(data):\n                if hasattr(d, \"SegmentSequence\"):\n                    data_array, metadata = self._get_seg_data(d, self.filenames[i])\n                else:\n                    data_array = self._get_array_data(d, self.filenames[i])\n                    metadata = self._get_meta_dict(d)\n                    metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape\n                dicom_data.append((data_array, metadata))\n\n        img_array: list[NdarrayOrCupy] = []\n        
compatible_meta: dict = {}\n\n        for data_array, metadata in ensure_tuple(dicom_data):\n            if self.swap_ij:\n                data_array = cp.swapaxes(data_array, 0, 1) if self.to_gpu else np.swapaxes(data_array, 0, 1)\n            img_array.append(cp.ascontiguousarray(data_array) if self.to_gpu else np.ascontiguousarray(data_array))\n            affine = self._get_affine(metadata, self.affine_lps_to_ras)\n            metadata[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS\n            if self.swap_ij:\n                affine = affine @ np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])\n                sp_size = list(metadata[MetaKeys.SPATIAL_SHAPE])\n                sp_size[0], sp_size[1] = sp_size[1], sp_size[0]\n                metadata[MetaKeys.SPATIAL_SHAPE] = ensure_tuple(sp_size)\n            metadata[MetaKeys.ORIGINAL_AFFINE] = affine\n            metadata[MetaKeys.AFFINE] = affine.copy()\n            if self.channel_dim is None:  # default to \"no_channel\" or -1\n                metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = (\n                    float(\"nan\") if len(data_array.shape) == len(metadata[MetaKeys.SPATIAL_SHAPE]) else -1\n                )\n            else:\n                metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim\n            metadata[\"spacing\"] = affine_to_spacing(\n                metadata[MetaKeys.ORIGINAL_AFFINE], r=len(metadata[MetaKeys.SPATIAL_SHAPE])\n            )\n\n            _copy_compatible_dict(metadata, compatible_meta)\n\n        return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta\n\n    def _get_meta_dict(self, img) -> dict:\n        \"\"\"\n        Get all the metadata of the image and convert to dict type.\n\n        Args:\n            img: a Pydicom dataset object.\n\n        \"\"\"\n\n        metadata = img.to_json_dict(suppress_invalid_tags=True)\n\n        if self.prune_metadata:\n            prune_metadata = 
{}\n            for key in [\"00200037\", \"00200032\", \"00280030\", \"52009229\", \"52009230\"]:\n                if key in metadata.keys():\n                    prune_metadata[key] = metadata[key]\n            return prune_metadata\n\n        # always remove Pixel Data \"7FE00008\" or \"7FE00009\" or \"7FE00010\"\n        # always remove Data Set Trailing Padding \"FFFCFFFC\"\n        for key in [\"7FE00008\", \"7FE00009\", \"7FE00010\", \"FFFCFFFC\"]:\n            if key in metadata.keys():\n                metadata.pop(key)\n\n        return metadata  # type: ignore\n\n    def _get_affine(self, metadata: dict, lps_to_ras: bool = True):\n        \"\"\"\n        Get or construct the affine matrix of the image, it can be used to correct\n        spacing, orientation or execute spatial transforms.\n\n        Args:\n            metadata: metadata with dict type.\n            lps_to_ras: whether to convert the affine matrix from \"LPS\" to \"RAS\". Defaults to True.\n\n        \"\"\"\n        affine: np.ndarray = np.eye(4)\n        if not (\"00200037\" in metadata and \"00200032\" in metadata):\n            return affine\n        # \"00200037\" is the tag of `ImageOrientationPatient`\n        rx, ry, rz, cx, cy, cz = metadata[\"00200037\"][\"Value\"]\n        # \"00200032\" is the tag of `ImagePositionPatient`\n        sx, sy, sz = metadata[\"00200032\"][\"Value\"]\n        # \"00280030\" is the tag of `PixelSpacing`\n        spacing = metadata[\"00280030\"][\"Value\"] if \"00280030\" in metadata else (1.0, 1.0)\n        dr, dc = metadata.get(\"spacing\", spacing)[:2]\n        affine[0, 0] = cx * dr\n        affine[0, 1] = rx * dc\n        affine[0, 3] = sx\n        affine[1, 0] = cy * dr\n        affine[1, 1] = ry * dc\n        affine[1, 3] = sy\n        affine[2, 0] = cz * dr\n        affine[2, 1] = rz * dc\n        affine[2, 2] = 1.0\n        affine[2, 3] = sz\n\n        # 3d\n        if \"lastImagePositionPatient\" in metadata:\n            t1n, t2n, t3n = 
metadata[\"lastImagePositionPatient\"]\n            n = metadata[MetaKeys.SPATIAL_SHAPE][-1]\n            k1, k2, k3 = (t1n - sx) / (n - 1), (t2n - sy) / (n - 1), (t3n - sz) / (n - 1)\n            affine[0, 2] = k1\n            affine[1, 2] = k2\n            affine[2, 2] = k3\n\n        if lps_to_ras:\n            affine = orientation_ras_lps(affine)\n        return affine\n\n    def _get_frame_data(self, img, filename, array_data) -> Iterator:\n        \"\"\"\n        yield frames and description from the segmentation image.\n        This function is adapted from Highdicom:\n        https://github.com/herrmannlab/highdicom/blob/v0.18.2/src/highdicom/seg/utils.py\n\n        which has the following license...\n\n        # =========================================================================\n        # https://github.com/herrmannlab/highdicom/blob/v0.18.2/LICENSE\n        #\n        # Copyright 2020 MGH Computational Pathology\n        # Permission is hereby granted, free of charge, to any person obtaining a\n        # copy of this software and associated documentation files (the\n        # \"Software\"), to deal in the Software without restriction, including\n        # without limitation the rights to use, copy, modify, merge, publish,\n        # distribute, sublicense, and/or sell copies of the Software, and to\n        # permit persons to whom the Software is furnished to do so, subject to\n        # the following conditions:\n        # The above copyright notice and this permission notice shall be included\n        # in all copies or substantial portions of the Software.\n        # THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS\n        # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\n        # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.\n        # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY\n        # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF 
CONTRACT,\n        # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE\n        # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n        # =========================================================================\n\n        (https://github.com/herrmannlab/highdicom/issues/188)\n\n        Args:\n            img: a Pydicom dataset object that has attribute \"SegmentSequence\".\n\n        \"\"\"\n\n        if not hasattr(img, \"PerFrameFunctionalGroupsSequence\"):\n            raise NotImplementedError(f\"To read dicom seg: {filename}, 'PerFrameFunctionalGroupsSequence' is required.\")\n\n        frame_seg_nums = []\n        for f in img.PerFrameFunctionalGroupsSequence:\n            if not hasattr(f, \"SegmentIdentificationSequence\"):\n                raise NotImplementedError(\n                    f\"To read dicom seg: {filename}, 'SegmentIdentificationSequence' is required for each frame.\"\n                )\n            frame_seg_nums.append(int(f.SegmentIdentificationSequence[0].ReferencedSegmentNumber))\n\n        frame_seg_nums_arr = cp.array(frame_seg_nums) if self.to_gpu else np.array(frame_seg_nums)\n\n        seg_descriptions = {int(f.SegmentNumber): f for f in img.SegmentSequence}\n\n        for i in np.unique(frame_seg_nums_arr) if not self.to_gpu else cp.unique(frame_seg_nums_arr):\n            indices = np.where(frame_seg_nums_arr == i)[0] if not self.to_gpu else cp.where(frame_seg_nums_arr == i)[0]\n            yield (array_data[indices, ...], seg_descriptions[i])\n\n    def _get_seg_data(self, img, filename):\n        \"\"\"\n        Get the array data and metadata of the segmentation image.\n\n        Aegs:\n            img: a Pydicom dataset object that has attribute \"SegmentSequence\".\n            filename: the file path of the image.\n\n        \"\"\"\n\n        metadata = self._get_meta_dict(img)\n        n_classes = len(img.SegmentSequence)\n        array_data = self._get_array_data(img, filename)\n        
spatial_shape = list(array_data.shape)\n        spatial_shape[0] = spatial_shape[0] // n_classes\n\n        if self.label_dict is not None:\n            metadata[\"labels\"] = self.label_dict\n            if self.to_gpu:\n                all_segs = cp.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)\n            else:\n                all_segs = np.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)\n        else:\n            metadata[\"labels\"] = {}\n            if self.to_gpu:\n                all_segs = cp.zeros([*spatial_shape, n_classes], dtype=array_data.dtype)\n            else:\n                all_segs = np.zeros([*spatial_shape, n_classes], dtype=array_data.dtype)\n\n        for i, (frames, description) in enumerate(self._get_frame_data(img, filename, array_data)):\n            segment_label = getattr(description, \"SegmentLabel\", f\"label_{i}\")\n            class_name = getattr(description, \"SegmentDescription\", segment_label)\n            if class_name not in metadata[\"labels\"].keys():\n                metadata[\"labels\"][class_name] = i\n            class_num = metadata[\"labels\"][class_name]\n            all_segs[..., class_num] = frames\n\n        all_segs = all_segs.transpose([1, 2, 0, 3])\n        metadata[MetaKeys.SPATIAL_SHAPE] = all_segs.shape[:-1]\n\n        if \"52009229\" in metadata.keys():\n            shared_func_group_seq = metadata[\"52009229\"][\"Value\"][0]\n\n            # get `ImageOrientationPatient`\n            if \"00209116\" in shared_func_group_seq.keys():\n                plane_orient_seq = shared_func_group_seq[\"00209116\"][\"Value\"][0]\n                if \"00200037\" in plane_orient_seq.keys():\n                    metadata[\"00200037\"] = plane_orient_seq[\"00200037\"]\n\n            # get `PixelSpacing`\n            if \"00289110\" in shared_func_group_seq.keys():\n                pixel_measure_seq = shared_func_group_seq[\"00289110\"][\"Value\"][0]\n\n                if 
\"00280030\" in pixel_measure_seq.keys():\n                    pixel_spacing = pixel_measure_seq[\"00280030\"][\"Value\"]\n                    metadata[\"spacing\"] = pixel_spacing\n                    if \"00180050\" in pixel_measure_seq.keys():\n                        metadata[\"spacing\"] += pixel_measure_seq[\"00180050\"][\"Value\"]\n\n            if self.prune_metadata:\n                metadata.pop(\"52009229\")\n\n        # get `ImagePositionPatient`\n        if \"52009230\" in metadata.keys():\n            first_frame_func_group_seq = metadata[\"52009230\"][\"Value\"][0]\n            if \"00209113\" in first_frame_func_group_seq.keys():\n                plane_position_seq = first_frame_func_group_seq[\"00209113\"][\"Value\"][0]\n                if \"00200032\" in plane_position_seq.keys():\n                    metadata[\"00200032\"] = plane_position_seq[\"00200032\"]\n                    metadata[\"lastImagePositionPatient\"] = metadata[\"52009230\"][\"Value\"][-1][\"00209113\"][\"Value\"][0][\n                        \"00200032\"\n                    ][\"Value\"]\n            if self.prune_metadata:\n                metadata.pop(\"52009230\")\n\n        return all_segs, metadata\n\n    def _get_array_data_from_gpu(self, img, filename):\n        \"\"\"\n        Get the raw array data of the image. 
This function is used when `to_gpu` is set to True.\n\n        Args:\n            img: a Pydicom dataset object.\n            filename: the file path of the image.\n\n        \"\"\"\n        rows = getattr(img, \"Rows\", None)\n        columns = getattr(img, \"Columns\", None)\n        bits_allocated = getattr(img, \"BitsAllocated\", None)\n        samples_per_pixel = getattr(img, \"SamplesPerPixel\", 1)\n        number_of_frames = getattr(img, \"NumberOfFrames\", 1)\n        pixel_representation = getattr(img, \"PixelRepresentation\", 1)\n\n        if rows is None or columns is None or bits_allocated is None:\n            warnings.warn(\n                f\"dicom data: {filename} does not have Rows, Columns or BitsAllocated, falling back to CPU loading.\"\n            )\n\n            if not hasattr(img, \"pixel_array\"):\n                raise ValueError(f\"dicom data: {filename} does not have pixel_array.\")\n            data = img.pixel_array\n\n            return data\n\n        if bits_allocated == 8:\n            dtype = cp.int8 if pixel_representation == 1 else cp.uint8\n        elif bits_allocated == 16:\n            dtype = cp.int16 if pixel_representation == 1 else cp.uint16\n        elif bits_allocated == 32:\n            dtype = cp.int32 if pixel_representation == 1 else cp.uint32\n        else:\n            raise ValueError(\"Unsupported BitsAllocated value\")\n\n        bytes_per_pixel = bits_allocated // 8\n        total_pixels = rows * columns * samples_per_pixel * number_of_frames\n        expected_pixel_data_length = total_pixels * bytes_per_pixel\n\n        pixel_data_tag = pydicom.tag.Tag(0x7FE0, 0x0010)\n        if pixel_data_tag not in img:\n            raise ValueError(f\"dicom data: {filename} does not have pixel data.\")\n\n        offset = img.get_item(pixel_data_tag, keep_deferred=True).value_tell\n\n        with kvikio.CuFile(filename, \"r\") as f:\n            buffer = cp.empty(expected_pixel_data_length, dtype=cp.int8)\n            
f.read(buffer, expected_pixel_data_length, offset)\n\n        new_shape = (number_of_frames, rows, columns) if number_of_frames > 1 else (rows, columns)\n        data = buffer.view(dtype).reshape(new_shape)\n\n        return data\n\n    def _get_array_data(self, img, filename):\n        \"\"\"\n        Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data\n        will be rescaled. The output data has the dtype float32 if the rescaling is applied.\n\n        Args:\n            img: a Pydicom dataset object.\n            filename: the file path of the image.\n\n        \"\"\"\n        # process Dicom series\n\n        if self.to_gpu:\n            data = self._get_array_data_from_gpu(img, filename)\n        else:\n            if not hasattr(img, \"pixel_array\"):\n                raise ValueError(f\"dicom data: {filename} does not have pixel_array.\")\n            data = img.pixel_array\n\n        slope, offset = 1.0, 0.0\n        rescale_flag = False\n        if hasattr(img, \"RescaleSlope\"):\n            slope = img.RescaleSlope\n            rescale_flag = True\n        if hasattr(img, \"RescaleIntercept\"):\n            offset = img.RescaleIntercept\n            rescale_flag = True\n\n        if rescale_flag:\n            if self.to_gpu:\n                slope = cp.asarray(slope, dtype=cp.float32)\n                offset = cp.asarray(offset, dtype=cp.float32)\n                data = data.astype(cp.float32) * slope + offset\n            else:\n                data = data.astype(np.float32) * slope + offset\n\n        return data\n\n\n@require_pkg(pkg_name=\"nibabel\")\nclass NibabelReader(ImageReader):\n    \"\"\"\n    Load NIfTI format images based on Nibabel library.\n\n    Args:\n        channel_dim: the channel dimension of the input image, default is None.\n            this is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.\n            if None, 
`original_channel_dim` will be either `no_channel` or `-1`.\n            most Nifti files are usually \"channel last\", no need to specify this argument for them.\n        as_closest_canonical: if True, load the image as closest to canonical axis format.\n        squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3)\n        to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.\n            Default is False. CuPy and Kvikio are required for this option.\n            Note: For compressed NIfTI files, some operations may still be performed on CPU memory,\n            and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.\n        kwargs: additional args for `nibabel.load` API. more details about available args:\n            https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py\n\n    \"\"\"\n\n    def __init__(\n        self,\n        channel_dim: str | int | None = None,\n        as_closest_canonical: bool = False,\n        squeeze_non_spatial_dims: bool = False,\n        to_gpu: bool = False,\n        **kwargs,\n    ):\n        super().__init__()\n        self.channel_dim = float(\"nan\") if channel_dim == \"no_channel\" else channel_dim\n        self.as_closest_canonical = as_closest_canonical\n        self.squeeze_non_spatial_dims = squeeze_non_spatial_dims\n        if to_gpu and (not has_cp or not has_kvikio):\n            warnings.warn(\n                \"NibabelReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading.\"\n            )\n            to_gpu = False\n\n        if to_gpu:\n            self.warmup_kvikio()\n\n        self.to_gpu = to_gpu\n        self.kwargs = kwargs\n\n    def warmup_kvikio(self):\n        \"\"\"\n        Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.\n        This can accelerate the data loading process when 
`to_gpu` is set to True.\n        \"\"\"\n        if has_cp and has_kvikio:\n            a = cp.arange(100)\n            with tempfile.NamedTemporaryFile() as tmp_file:\n                tmp_file_name = tmp_file.name\n                f = kvikio.CuFile(tmp_file_name, \"w\")\n                f.write(a)\n                f.close()\n\n                b = cp.empty_like(a)\n                f = kvikio.CuFile(tmp_file_name, \"r\")\n                f.read(b)\n\n    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:\n        \"\"\"\n        Verify whether the specified file or files format is supported by Nibabel reader.\n\n        Args:\n            filename: file name or a list of file names to read.\n                if a list of files, verify all the suffixes.\n\n        \"\"\"\n        suffixes: Sequence[str] = [\"nii\", \"nii.gz\"]\n        return has_nib and is_supported_format(filename, suffixes)\n\n    def read(self, data: Sequence[PathLike] | PathLike, **kwargs):\n        \"\"\"\n        Read image data from specified file or files, it can read a list of images\n        and stack them together as multi-channel data in `get_data()`.\n        Note that the returned object is Nibabel image object or list of Nibabel image objects.\n\n        Args:\n            data: file name or a list of file names to read.\n            kwargs: additional args for `nibabel.load` API, will override `self.kwargs` for existing keys.\n                More details about available args:\n                https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py\n\n        \"\"\"\n        img_: list[Nifti1Image] = []\n\n        filenames: Sequence[PathLike] = ensure_tuple(data)\n        self.filenames = filenames\n        kwargs_ = self.kwargs.copy()\n        kwargs_.update(kwargs)\n        for name in filenames:\n            img = nib.load(name, **kwargs_)\n            img = correct_nifti_header_if_necessary(img)\n            img_.append(img)  # type: ignore\n       
 return img_ if len(filenames) > 1 else img_[0]\n\n    def get_data(self, img) -> tuple[np.ndarray, dict]:\n        \"\"\"\n        Extract data array and metadata from loaded image and return them.\n        This function returns two objects, first is numpy array of image data, second is dict of metadata.\n        It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.\n        When loading a list of files, they are stacked together at a new dimension as the first dimension,\n        and the metadata of the first image is used to present the output metadata.\n\n        Args:\n            img: a Nibabel image object loaded from an image file or a list of Nibabel image objects.\n\n        \"\"\"\n        img_array: list[NdarrayOrCupy] = []\n        compatible_meta: dict = {}\n\n        for i, filename in zip(ensure_tuple(img), self.filenames):\n            header = self._get_meta_dict(i)\n            if MetaKeys.PIXDIM in header:\n                header[MetaKeys.ORIGINAL_PIXDIM] = np.array(header[MetaKeys.PIXDIM], copy=True)\n            header[MetaKeys.AFFINE] = self._get_affine(i)\n            header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(i)\n            header[\"as_closest_canonical\"] = self.as_closest_canonical\n            if self.as_closest_canonical:\n                i = nib.as_closest_canonical(i)\n                header[MetaKeys.AFFINE] = self._get_affine(i)\n            header[MetaKeys.SPATIAL_SHAPE] = self._get_spatial_shape(i)\n            header[MetaKeys.SPACE] = SpaceKeys.RAS\n            data = self._get_array_data(i, filename)\n            if self.squeeze_non_spatial_dims:\n                for d in range(len(data.shape), len(header[MetaKeys.SPATIAL_SHAPE]), -1):\n                    if data.shape[d - 1] == 1:\n                        data = data.squeeze(axis=d - 1)\n            img_array.append(data)\n            if self.channel_dim is None:  # default to \"no_channel\" or -1\n                
header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (\n                    float(\"nan\") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1\n                )\n            else:\n                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim\n            _copy_compatible_dict(header, compatible_meta)\n\n        return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta\n\n    def _get_meta_dict(self, img) -> dict:\n        \"\"\"\n        Get the all the metadata of the image and convert to dict type.\n\n        Args:\n            img: a Nibabel image object loaded from an image file.\n\n        \"\"\"\n        # swap to little endian as PyTorch doesn't support big endian\n        try:\n            header = img.header.as_byteswapped(\"<\")\n        except ValueError:\n            header = img.header\n        return dict(header)\n\n    def _get_affine(self, img):\n        \"\"\"\n        Get the affine matrix of the image, it can be used to correct\n        spacing, orientation or execute spatial transforms.\n\n        Args:\n            img: a Nibabel image object loaded from an image file.\n\n        \"\"\"\n        return np.array(img.affine, copy=True)\n\n    def _get_spatial_shape(self, img):\n        \"\"\"\n        Get the spatial shape of image data, it doesn't contain the channel dim.\n\n        Args:\n            img: a Nibabel image object loaded from an image file.\n\n        \"\"\"\n        # swap to little endian as PyTorch doesn't support big endian\n        try:\n            header = img.header.as_byteswapped(\"<\")\n        except ValueError:\n            header = img.header\n        dim = header.get(\"dim\", None)\n        if dim is None:\n            dim = header.get(\"dims\")  # mgh format?\n            dim = np.insert(dim, 0, 3)\n        ndim = dim[0]\n        size = list(dim[1:])\n        if not is_no_channel(self.channel_dim):\n            size.pop(int(self.channel_dim))  # type: ignore\n        
spatial_rank = max(min(ndim, 3), 1)\n        return np.asarray(size[:spatial_rank])\n\n    def _get_array_data(self, img, filename):\n        \"\"\"\n        Get the raw array data of the image, converted to Numpy array.\n\n        Args:\n            img: a Nibabel image object loaded from an image file.\n            filename: file name of the image.\n\n        \"\"\"\n        if self.to_gpu:\n            file_size = os.path.getsize(filename)\n            image = cp.empty(file_size, dtype=cp.uint8)\n            with kvikio.CuFile(filename, \"r\") as f:\n                f.read(image)\n            if filename.endswith(\".nii.gz\"):\n                # for compressed data, have to tansfer to CPU to decompress\n                # and then transfer back to GPU. It is not efficient compared to .nii file\n                # and may be slower than CPU loading in some cases.\n                warnings.warn(\"Loading compressed NIfTI file into GPU may not be efficient.\")\n                compressed_data = cp.asnumpy(image)\n                with gzip.GzipFile(fileobj=io.BytesIO(compressed_data)) as gz_file:\n                    decompressed_data = gz_file.read()\n\n                image = cp.frombuffer(decompressed_data, dtype=cp.uint8)\n            data_shape = img.shape\n            data_offset = img.dataobj.offset\n            data_dtype = img.dataobj.dtype\n            return image[data_offset:].view(data_dtype).reshape(data_shape, order=\"F\")\n        return np.asanyarray(img.dataobj, order=\"C\")\n\n\nclass NumpyReader(ImageReader):\n    \"\"\"\n    Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects.\n    A typical usage is to load the `mask` data for classification task.\n    It can load part of the npz file with specified `npz_keys`.\n\n    Args:\n        npz_keys: if loading npz file, only load the specified keys, if None, load all the items.\n            stack the loaded items together to construct a new first dimension.\n        
channel_dim: if not None, explicitly specify the channel dim, otherwise, treat the array as no channel.\n        kwargs: additional args for `numpy.load` API except `allow_pickle`. more details about available args:\n            https://numpy.org/doc/stable/reference/generated/numpy.load.html\n\n    \"\"\"\n\n    def __init__(self, npz_keys: KeysCollection | None = None, channel_dim: str | int | None = None, **kwargs):\n        super().__init__()\n        if npz_keys is not None:\n            npz_keys = ensure_tuple(npz_keys)\n        self.npz_keys = npz_keys\n        self.channel_dim = float(\"nan\") if channel_dim == \"no_channel\" else channel_dim\n        self.kwargs = kwargs\n\n    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:\n        \"\"\"\n        Verify whether the specified file or files format is supported by Numpy reader.\n\n        Args:\n            filename: file name or a list of file names to read.\n                if a list of files, verify all the suffixes.\n        \"\"\"\n        suffixes: Sequence[str] = [\"npz\", \"npy\"]\n        return is_supported_format(filename, suffixes)\n\n    def read(self, data: Sequence[PathLike] | PathLike, **kwargs):\n        \"\"\"\n        Read image data from specified file or files, it can read a list of data files\n        and stack them together as multi-channel data in `get_data()`.\n        Note that the returned object is Numpy array or list of Numpy arrays.\n\n        Args:\n            data: file name or a list of file names to read.\n            kwargs: additional args for `numpy.load` API except `allow_pickle`, will override `self.kwargs` for existing keys.\n                More details about available args:\n                https://numpy.org/doc/stable/reference/generated/numpy.load.html\n\n        \"\"\"\n        img_: list[Nifti1Image] = []\n\n        filenames: Sequence[PathLike] = ensure_tuple(data)\n        kwargs_ = self.kwargs.copy()\n        
kwargs_.update(kwargs)\n        for name in filenames:\n            img = np.load(name, allow_pickle=True, **kwargs_)\n            if Path(name).name.endswith(\".npz\"):\n                # load expected items from NPZ file\n                npz_keys = list(img.keys()) if self.npz_keys is None else self.npz_keys\n                for k in npz_keys:\n                    img_.append(img[k])\n            else:\n                img_.append(img)\n\n        return img_ if len(img_) > 1 else img_[0]\n\n    def get_data(self, img) -> tuple[np.ndarray, dict]:\n        \"\"\"\n        Extract data array and metadata from loaded image and return them.\n        This function returns two objects, first is numpy array of image data, second is dict of metadata.\n        It constructs `affine`, `original_affine`, and `spatial_shape` and stores them in meta dict.\n        When loading a list of files, they are stacked together at a new dimension as the first dimension,\n        and the metadata of the first image is used to represent the output metadata.\n\n        Args:\n            img: a Numpy array loaded from a file or a list of Numpy arrays.\n\n        \"\"\"\n        img_array: list[np.ndarray] = []\n        compatible_meta: dict = {}\n        if isinstance(img, np.ndarray):\n            img = (img,)\n\n        for i in ensure_tuple(img):\n            header: dict[MetaKeys, Any] = {}\n            if isinstance(i, np.ndarray):\n                # if `channel_dim` is None, can not detect the channel dim, use all the dims as spatial_shape\n                spatial_shape = np.asarray(i.shape)\n                if isinstance(self.channel_dim, int):\n                    spatial_shape = np.delete(spatial_shape, self.channel_dim)\n                header[MetaKeys.SPATIAL_SHAPE] = spatial_shape\n                header[MetaKeys.SPACE] = SpaceKeys.RAS\n            img_array.append(i)\n            header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (\n                self.channel_dim if 
isinstance(self.channel_dim, int) else float(\"nan\")\n            )\n            _copy_compatible_dict(header, compatible_meta)\n\n        return _stack_images(img_array, compatible_meta), compatible_meta\n\n\n@require_pkg(pkg_name=\"PIL\")\nclass PILReader(ImageReader):\n    \"\"\"\n    Load common 2D image format (supports PNG, JPG, BMP) file or files from provided path.\n\n    Args:\n        converter: additional function to convert the image data after `read()`.\n            for example, use `converter=lambda image: image.convert(\"LA\")` to convert image format.\n        reverse_indexing: whether to swap axis 0 and 1 after loading the array, this is enabled by default,\n            so that output of the reader is consistent with the other readers. Set this option to ``False`` to use\n            the PIL backend's original spatial axes convention.\n        kwargs: additional args for `Image.open` API in `read()`, mode details about available args:\n            https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open\n    \"\"\"\n\n    def __init__(self, converter: Callable | None = None, reverse_indexing: bool = True, **kwargs):\n        super().__init__()\n        self.converter = converter\n        self.reverse_indexing = reverse_indexing\n        self.kwargs = kwargs\n\n    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:\n        \"\"\"\n        Verify whether the specified file or files format is supported by PIL reader.\n\n        Args:\n            filename: file name or a list of file names to read.\n                if a list of files, verify all the suffixes.\n        \"\"\"\n        suffixes: Sequence[str] = [\"png\", \"jpg\", \"jpeg\", \"bmp\"]\n        return has_pil and is_supported_format(filename, suffixes)\n\n    def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs):\n        \"\"\"\n        Read image data from specified file or files, it can read a list of images\n        and 
stack them together as multi-channel data in `get_data()`.\n        Note that the returned object is PIL image or list of PIL image.\n\n        Args:\n            data: file name or a list of file names to read.\n            kwargs: additional args for `Image.open` API in `read()`, will override `self.kwargs` for existing keys.\n                Mode details about available args:\n                https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.open\n\n        \"\"\"\n        img_: list[PILImage.Image] = []\n\n        filenames: Sequence[PathLike] = ensure_tuple(data)\n        kwargs_ = self.kwargs.copy()\n        kwargs_.update(kwargs)\n        for name in filenames:\n            img = PILImage.open(name, **kwargs_)\n            if callable(self.converter):\n                img = self.converter(img)\n            img_.append(img)\n\n        return img_ if len(filenames) > 1 else img_[0]\n\n    def get_data(self, img) -> tuple[np.ndarray, dict]:\n        \"\"\"\n        Extract data array and metadata from loaded image and return them.\n        This function returns two objects, first is numpy array of image data, second is dict of metadata.\n        It computes `spatial_shape` and stores it in meta dict.\n        When loading a list of files, they are stacked together at a new dimension as the first dimension,\n        and the metadata of the first image is used to represent the output metadata.\n        Note that by default `self.reverse_indexing` is set to ``True``, which swaps axis 0 and 1 after loading\n        the array because the spatial axes definition in PIL is different from other common medical packages.\n\n        Args:\n            img: a PIL Image object loaded from a file or a list of PIL Image objects.\n\n        \"\"\"\n        img_array: list[np.ndarray] = []\n        compatible_meta: dict = {}\n\n        for i in ensure_tuple(img):\n            header = self._get_meta_dict(i)\n            header[MetaKeys.SPATIAL_SHAPE] = 
self._get_spatial_shape(i)\n            data = np.moveaxis(np.asarray(i), 0, 1) if self.reverse_indexing else np.asarray(i)\n            img_array.append(data)\n            header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (\n                float(\"nan\") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else -1\n            )\n            _copy_compatible_dict(header, compatible_meta)\n\n        return _stack_images(img_array, compatible_meta), compatible_meta\n\n    def _get_meta_dict(self, img) -> dict:\n        \"\"\"\n        Get the all the metadata of the image and convert to dict type.\n        Args:\n            img: a PIL Image object loaded from an image file.\n\n        \"\"\"\n        return {\"format\": img.format, \"mode\": img.mode, \"width\": img.width, \"height\": img.height}\n\n    def _get_spatial_shape(self, img):\n        \"\"\"\n        Get the spatial shape of image data, it doesn't contain the channel dim.\n        Args:\n            img: a PIL Image object loaded from an image file.\n        \"\"\"\n        return np.asarray((img.width, img.height))\n\n\n@dataclass\nclass NrrdImage:\n    \"\"\"Class to wrap nrrd image array and metadata header\"\"\"\n\n    array: np.ndarray\n    header: dict\n\n\n@require_pkg(pkg_name=\"nrrd\")\nclass NrrdReader(ImageReader):\n    \"\"\"\n    Load NRRD format images based on pynrrd library.\n\n    Args:\n        channel_dim: the channel dimension of the input image, default is None.\n            This is used to set original_channel_dim in the metadata, EnsureChannelFirstD reads this field.\n            If None, `original_channel_dim` will be either `no_channel` or `0`.\n            NRRD files are usually \"channel first\".\n        dtype: dtype of the data array when loading image.\n        index_order: Specify whether the returned data array should be in C-order (‘C’) or Fortran-order (‘F’).\n            Numpy is usually in C-order, but default on the NRRD header is F\n        affine_lps_to_ras: whether to 
convert the affine matrix from \"LPS\" to \"RAS\". Defaults to ``True``.\n            Set to ``True`` to be consistent with ``NibabelReader``, otherwise the affine matrix is unmodified.\n\n        kwargs: additional args for `nrrd.read` API. more details about available args:\n            https://github.com/mhe/pynrrd/blob/master/nrrd/reader.py\n\n    \"\"\"\n\n    def __init__(\n        self,\n        channel_dim: str | int | None = None,\n        dtype: np.dtype | type | str | None = np.float32,\n        index_order: str = \"F\",\n        affine_lps_to_ras: bool = True,\n        **kwargs,\n    ):\n        self.channel_dim = float(\"nan\") if channel_dim == \"no_channel\" else channel_dim\n        self.dtype = dtype\n        self.index_order = index_order\n        self.affine_lps_to_ras = affine_lps_to_ras\n        self.kwargs = kwargs\n\n    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:\n        \"\"\"\n        Verify whether the specified `filename` is supported by pynrrd reader.\n\n        Args:\n            filename: file name or a list of file names to read.\n                if a list of files, verify all the suffixes.\n\n        \"\"\"\n        suffixes: Sequence[str] = [\"nrrd\", \"seg.nrrd\"]\n        return has_nrrd and is_supported_format(filename, suffixes)\n\n    def read(self, data: Sequence[PathLike] | PathLike, **kwargs) -> Sequence[Any] | Any:\n        \"\"\"\n        Read image data from specified file or files.\n        Note that it returns a data object or a sequence of data objects.\n\n        Args:\n            data: file name or a list of file names to read.\n            kwargs: additional args for actual `read` API of 3rd party libs.\n\n        \"\"\"\n        img_: list = []\n        filenames: Sequence[PathLike] = ensure_tuple(data)\n        kwargs_ = self.kwargs.copy()\n        kwargs_.update(kwargs)\n        for name in filenames:\n            nrrd_image = NrrdImage(*nrrd.read(name, 
index_order=self.index_order, **kwargs_))\n            img_.append(nrrd_image)\n        return img_ if len(filenames) > 1 else img_[0]\n\n    def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]:\n        \"\"\"\n        Extract data array and metadata from loaded image and return them.\n        This function must return two objects, the first is a numpy array of image data,\n        the second is a dictionary of metadata.\n\n        Args:\n            img: a `NrrdImage` loaded from an image file or a list of image objects.\n\n        \"\"\"\n        img_array: list[np.ndarray] = []\n        compatible_meta: dict = {}\n\n        for i in ensure_tuple(img):\n            data = i.array.astype(self.dtype)\n            img_array.append(data)\n            header = dict(i.header)\n            if self.index_order == \"C\":\n                header = self._convert_f_to_c_order(header)\n            header[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(header)\n\n            if self.affine_lps_to_ras:\n                header = self._switch_lps_ras(header)\n            if header.get(MetaKeys.SPACE, \"left-posterior-superior\") == \"left-posterior-superior\":\n                header[MetaKeys.SPACE] = SpaceKeys.LPS  # assuming LPS if not specified\n\n            header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy()\n            header[MetaKeys.SPATIAL_SHAPE] = header[\"sizes\"].copy()\n            [header.pop(k) for k in (\"sizes\", \"space origin\", \"space directions\")]  # rm duplicated data in header\n\n            if self.channel_dim is None:  # default to \"no_channel\" or -1\n                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = (\n                    float(\"nan\") if len(data.shape) == len(header[MetaKeys.SPATIAL_SHAPE]) else 0\n                )\n            else:\n                header[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim\n            _copy_compatible_dict(header, compatible_meta)\n\n        return 
_stack_images(img_array, compatible_meta), compatible_meta\n\n    def _get_affine(self, header: dict) -> np.ndarray:\n        \"\"\"\n        Get the affine matrix of the image, it can be used to correct\n        spacing, orientation or execute spatial transforms.\n\n        Args:\n            img: A `NrrdImage` loaded from image file\n\n        \"\"\"\n        direction = header[\"space directions\"]\n        origin = header[\"space origin\"]\n\n        x, y = direction.shape\n        affine_diam = min(x, y) + 1\n        affine: np.ndarray = np.eye(affine_diam)\n        affine[:x, :y] = direction.T\n        affine[: (affine_diam - 1), -1] = origin  # len origin is always affine_diam - 1\n        return affine\n\n    def _switch_lps_ras(self, header: dict) -> dict:\n        \"\"\"\n        For compatibility with nibabel, switch from LPS to RAS. Adapt affine matrix and\n        `space` argument in header accordingly. If no information of space is given in the header,\n        LPS is assumed and thus converted to RAS. 
If information about space is given,\n        but is not LPS, the unchanged header is returned.\n\n        Args:\n            header: The image metadata as dict\n\n        \"\"\"\n        if \"space\" not in header or header[\"space\"] == \"left-posterior-superior\":\n            header[MetaKeys.ORIGINAL_AFFINE] = orientation_ras_lps(header[MetaKeys.ORIGINAL_AFFINE])\n            header[MetaKeys.SPACE] = SpaceKeys.RAS\n        return header\n\n    def _convert_f_to_c_order(self, header: dict) -> dict:\n        \"\"\"\n        All header fields of a NRRD are specified in `F` (Fortran) order, even if the image was read as C-ordered array.\n        1D arrays of header['space origin'] and header['sizes'] become inverted, e.g, [1,2,3] -> [3,2,1]\n        The 2D Array for header['space directions'] is transposed: [[1,0,0],[0,2,0],[0,0,3]] -> [[3,0,0],[0,2,0],[0,0,1]]\n        For more details refer to: https://pynrrd.readthedocs.io/en/latest/user-guide.html#index-ordering\n\n        Args:\n            header: The image metadata as dict\n\n        \"\"\"\n\n        header[\"space directions\"] = np.rot90(np.flip(header[\"space directions\"], 0))\n        header[\"space origin\"] = header[\"space origin\"][::-1]\n        header[\"sizes\"] = header[\"sizes\"][::-1]\n        return header\n"
  },
  {
    "path": "monai/data/image_writer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Mapping, Sequence\nfrom typing import TYPE_CHECKING, Any, cast\n\nimport numpy as np\n\nfrom monai.apps.utils import get_logger\nfrom monai.config import DtypeLike, NdarrayOrTensor, PathLike\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import affine_to_spacing, ensure_tuple, ensure_tuple_rep, orientation_ras_lps, to_affine_nd\nfrom monai.transforms.spatial.array import Resize, SpatialResample\nfrom monai.transforms.utils_pytorch_numpy_unification import ascontiguousarray, moveaxis\nfrom monai.utils import (\n    GridSampleMode,\n    GridSamplePadMode,\n    InterpolateMode,\n    MetaKeys,\n    OptionalImportError,\n    SpaceKeys,\n    convert_data_type,\n    convert_to_tensor,\n    get_equivalent_dtype,\n    look_up_option,\n    optional_import,\n    require_pkg,\n)\n\nDEFAULT_FMT = \"%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s\"\nEXT_WILDCARD = \"*\"\nlogger = get_logger(module_name=__name__, fmt=DEFAULT_FMT)\n\nif TYPE_CHECKING:\n    import itk\n    import nibabel as nib\n    from PIL import Image as PILImage\nelse:\n    itk, _ = optional_import(\"itk\", allow_namespace_pkg=True)\n    nib, _ = optional_import(\"nibabel\")\n    PILImage, _ = optional_import(\"PIL.Image\")\n\n__all__ = [\n    \"ImageWriter\",\n    \"ITKWriter\",\n    \"NibabelWriter\",\n    
\"PILWriter\",\n    \"SUPPORTED_WRITERS\",\n    \"register_writer\",\n    \"resolve_writer\",\n    \"logger\",\n]\n\nSUPPORTED_WRITERS: dict = {}\n\n\ndef register_writer(ext_name, *im_writers):\n    \"\"\"\n    Register ``ImageWriter``, so that writing a file with filename extension ``ext_name``\n    could be resolved to a tuple of potentially appropriate ``ImageWriter``.\n    The customised writers could be registered by:\n\n    .. code-block:: python\n\n        from monai.data import register_writer\n        # `MyWriter` must implement `ImageWriter` interface\n        register_writer(\"nii\", MyWriter)\n\n    Args:\n        ext_name: the filename extension of the image.\n            As an indexing key, it will be converted to a lower case string.\n        im_writers: one or multiple ImageWriter classes with high priority ones first.\n    \"\"\"\n    fmt = f\"{ext_name}\".lower()\n    if fmt.startswith(\".\"):\n        fmt = fmt[1:]\n    existing = look_up_option(fmt, SUPPORTED_WRITERS, default=())\n    all_writers = im_writers + existing\n    SUPPORTED_WRITERS[fmt] = all_writers\n\n\ndef resolve_writer(ext_name, error_if_not_found=True) -> Sequence:\n    \"\"\"\n    Resolves to a tuple of available ``ImageWriter`` in ``SUPPORTED_WRITERS``\n    according to the filename extension key ``ext_name``.\n\n    Args:\n        ext_name: the filename extension of the image.\n            As an indexing key it will be converted to a lower case string.\n        error_if_not_found: whether to raise an error if no suitable image writer is found.\n            if True , raise an ``OptionalImportError``, otherwise return an empty tuple. 
Default is ``True``.\n    \"\"\"\n    if not SUPPORTED_WRITERS:\n        init()\n    fmt = f\"{ext_name}\".lower()\n    if fmt.startswith(\".\"):\n        fmt = fmt[1:]\n    avail_writers = []\n    default_writers = SUPPORTED_WRITERS.get(EXT_WILDCARD, ())\n    for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=default_writers):\n        try:\n            _writer()  # this triggers `monai.utils.module.require_pkg` to check the system availability\n            avail_writers.append(_writer)\n        except OptionalImportError:\n            continue\n        except Exception:  # other writer init errors indicating it exists\n            avail_writers.append(_writer)\n    if not avail_writers and error_if_not_found:\n        raise OptionalImportError(f\"No ImageWriter backend found for {fmt}.\")\n    writer_tuple = ensure_tuple(avail_writers)\n    SUPPORTED_WRITERS[fmt] = writer_tuple\n    return writer_tuple\n\n\nclass ImageWriter:\n    \"\"\"\n    The class is a collection of utilities to write images to disk.\n\n    Main aspects to be considered are:\n\n        - dimensionality of the data array, arrangements of spatial dimensions and channel/time dimensions\n            - ``convert_to_channel_last()``\n        - metadata of the current affine and output affine, the data array should be converted accordingly\n            - ``get_meta_info()``\n            - ``resample_if_needed()``\n        - data type handling of the output image (as part of ``resample_if_needed()``)\n\n    Subclasses of this class should implement the backend-specific functions:\n\n        - ``set_data_array()`` to set the data array (input must be numpy array or torch tensor)\n            - this method sets the backend object's data part\n        - ``set_metadata()`` to set the metadata and output affine\n            - this method sets the metadata including affine handling and image resampling\n        - backend-specific data object ``create_backend_obj()``\n        - backend-specific 
writing function ``write()``\n\n    The primary usage of subclasses of ``ImageWriter`` is:\n\n    .. code-block:: python\n\n        writer = MyWriter()  # subclass of ImageWriter\n        writer.set_data_array(data_array)\n        writer.set_metadata(meta_dict)\n        writer.write(filename)\n\n    This creates an image writer object based on ``data_array`` and ``meta_dict`` and write to ``filename``.\n\n    It supports up to three spatial dimensions (with the resampling step supports for both 2D and 3D).\n    When saving multiple time steps or multiple channels `data_array`, time\n    and/or modality axes should be the at the `channel_dim`. For example,\n    the shape of a 2D eight-class and ``channel_dim=0``, the segmentation\n    probabilities to be saved could be `(8, 64, 64)`; in this case\n    ``data_array`` will be converted to `(64, 64, 1, 8)` (the third\n    dimension is reserved as a spatial dimension).\n\n    The ``metadata`` could optionally have the following keys:\n\n        - ``'original_affine'``: for data original affine, it will be the\n            affine of the output object, defaulting to an identity matrix.\n        - ``'affine'``: it should specify the current data affine, defaulting to an identity matrix.\n        - ``'spatial_shape'``: for data output spatial shape.\n\n    When ``metadata`` is specified, the saver will may resample data from the space defined by\n    `\"affine\"` to the space defined by `\"original_affine\"`, for more details, please refer to the\n    ``resample_if_needed`` method.\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        \"\"\"\n        The constructor supports adding new instance members.\n        The current member in the base class is ``self.data_obj``, the subclasses can add more members,\n        so that necessary meta information can be stored in the object and shared among the class methods.\n        \"\"\"\n        self.data_obj: Any | NdarrayOrTensor = None\n        for k, v in kwargs.items():\n   
         setattr(self, k, v)\n\n    def set_data_array(self, data_array, **kwargs):\n        raise NotImplementedError(f\"Subclasses of {self.__class__.__name__} must implement this method.\")\n\n    def set_metadata(self, meta_dict: Mapping | None, **options):\n        raise NotImplementedError(f\"Subclasses of {self.__class__.__name__} must implement this method.\")\n\n    def write(self, filename: PathLike, verbose: bool = True, **kwargs):\n        \"\"\"subclass should implement this method to call the backend-specific writing APIs.\"\"\"\n        if verbose:\n            logger.info(f\"writing: {filename}\")\n\n    @classmethod\n    def create_backend_obj(cls, data_array: NdarrayOrTensor, **kwargs) -> np.ndarray:\n        \"\"\"\n        Subclass should implement this method to return a backend-specific data representation object.\n        This method is used by ``cls.write`` and the input ``data_array`` is assumed 'channel-last'.\n        \"\"\"\n        return convert_data_type(data_array, np.ndarray)[0]\n\n    @classmethod\n    def resample_if_needed(\n        cls,\n        data_array: NdarrayOrTensor,\n        affine: NdarrayOrTensor | None = None,\n        target_affine: NdarrayOrTensor | None = None,\n        output_spatial_shape: Sequence[int] | int | None = None,\n        mode: str = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.BORDER,\n        align_corners: bool = False,\n        dtype: DtypeLike = np.float64,\n    ):\n        \"\"\"\n        Convert the ``data_array`` into the coordinate system specified by\n        ``target_affine``, from the current coordinate definition of ``affine``.\n\n        If the transform between ``affine`` and ``target_affine`` could be\n        achieved by simply transposing and flipping ``data_array``, no resampling\n        will happen.  
Otherwise, this function resamples ``data_array`` using the\n        transformation computed from ``affine`` and ``target_affine``.\n\n        This function assumes the NIfTI dimension notations. Spatially it\n        supports up to three dimensions, that is, H, HW, HWD for 1D, 2D, 3D\n        respectively. When saving multiple time steps or multiple channels,\n        time and/or modality axes should be appended after the first three\n        dimensions. For example, shape of 2D eight-class segmentation\n        probabilities to be saved could be `(64, 64, 1, 8)`. Also, data in\n        shape `(64, 64, 8)` or `(64, 64, 8, 1)` will be considered as a\n        single-channel 3D image. The ``convert_to_channel_last`` method can be\n        used to convert the data to the format described here.\n\n        Note that the shape of the resampled ``data_array`` may subject to some\n        rounding errors. For example, resampling a 20x20 pixel image from pixel\n        size (1.5, 1.5)-mm to (3.0, 3.0)-mm space will return a 10x10-pixel\n        image. However, resampling a 20x20-pixel image from pixel size (2.0,\n        2.0)-mm to (3.0, 3.0)-mm space will output a 14x14-pixel image, where\n        the image shape is rounded from 13.333x13.333 pixels. In this case\n        ``output_spatial_shape`` could be specified so that this function\n        writes image data to a designated shape.\n\n        Args:\n            data_array: input data array to be converted.\n            affine: the current affine of ``data_array``. 
Defaults to identity\n            target_affine: the designated affine of ``data_array``.\n                The actual output affine might be different from this value due to precision changes.\n            output_spatial_shape: spatial shape of the output image.\n                This option is used when resampling is needed.\n            mode: available options are {``\"bilinear\"``, ``\"nearest\"``, ``\"bicubic\"``}.\n                This option is used when resampling is needed.\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample\n            padding_mode: available options are {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}.\n                This option is used when resampling is needed.\n                Padding mode for outside grid values. Defaults to ``\"border\"``.\n                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample\n            align_corners: boolean option of ``grid_sample`` to handle the corner convention.\n                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample\n            dtype: data type for resampling computation. Defaults to\n                ``np.float64`` for best precision. 
If ``None``, use the data type of input data.\n                The output data type of this method is always ``np.float32``.\n        \"\"\"\n        orig_type = type(data_array)\n        data_array = convert_to_tensor(data_array, track_meta=True)\n        if affine is not None:\n            data_array.affine = convert_to_tensor(affine, track_meta=False)  # type: ignore\n        resampler = SpatialResample(mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype)\n        output_array = resampler(\n            data_array[None], dst_affine=target_affine, spatial_size=output_spatial_shape  # type: ignore\n        )\n        # convert back at the end\n        if isinstance(output_array, MetaTensor):\n            output_array.applied_operations = []\n        data_array, *_ = convert_data_type(output_array, output_type=orig_type)\n        affine, *_ = convert_data_type(output_array.affine, output_type=orig_type)  # type: ignore\n        return data_array[0], affine\n\n    @classmethod\n    def convert_to_channel_last(\n        cls,\n        data: NdarrayOrTensor,\n        channel_dim: None | int | Sequence[int] = 0,\n        squeeze_end_dims: bool = True,\n        spatial_ndim: int | None = 3,\n        contiguous: bool = False,\n    ):\n        \"\"\"\n        Rearrange the data array axes to make the `channel_dim`-th dim the last\n        dimension and ensure there are ``spatial_ndim`` number of spatial\n        dimensions.\n\n        When ``squeeze_end_dims`` is ``True``, a postprocessing step will be\n        applied to remove any trailing singleton dimensions.\n\n        Args:\n            data: input data to be converted to \"channel-last\" format.\n            channel_dim: specifies the channel axes of the data array to move to the last.\n                ``None`` indicates no channel dimension, a new axis will be appended as the channel dimension.\n                a sequence of integers indicates multiple non-spatial dimensions.\n            
squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed (after the channel\n                has been moved to the end). So if input is `(H,W,D,C)` and C==1, then it will be saved as `(H,W,D)`.\n                If D is also 1, it will be saved as `(H,W)`. If ``False``, image will always be saved as `(H,W,D,C)`.\n            spatial_ndim: modifying the spatial dims if needed, so that output to have at least\n                this number of spatial dims. If ``None``, the output will have the same number of\n                spatial dimensions as the input.\n            contiguous: if ``True``, the output will be contiguous.\n        \"\"\"\n        # change data to \"channel last\" format\n        if channel_dim is not None:\n            _chns = ensure_tuple(channel_dim)\n            data = moveaxis(data, _chns, tuple(range(-len(_chns), 0)))\n        else:  # adds a channel dimension\n            data = data[..., None]\n        # To ensure at least ``spatial_ndim`` number of spatial dims\n        if spatial_ndim:\n            while len(data.shape) < spatial_ndim + 1:  # assuming the data has spatial + channel dims\n                data = data[..., None, :]\n            while len(data.shape) > spatial_ndim + 1:\n                data = data[..., 0, :]\n        # if desired, remove trailing singleton dimensions\n        while squeeze_end_dims and data.shape[-1] == 1:\n            data = np.squeeze(data, -1)\n        if contiguous:\n            data = ascontiguousarray(data)\n        return data\n\n    @classmethod\n    def get_meta_info(cls, metadata: Mapping | None = None):\n        \"\"\"\n        Extracts relevant meta information from the metadata object (using ``.get``).\n        Optional keys are ``\"spatial_shape\"``, ``MetaKeys.AFFINE``, ``\"original_affine\"``.\n        \"\"\"\n        if not metadata:\n            metadata = {\"original_affine\": None, MetaKeys.AFFINE: None, MetaKeys.SPATIAL_SHAPE: None}\n        original_affine = 
metadata.get(\"original_affine\")\n        affine = metadata.get(MetaKeys.AFFINE)\n        spatial_shape = metadata.get(MetaKeys.SPATIAL_SHAPE)\n        return original_affine, affine, spatial_shape\n\n\n@require_pkg(pkg_name=\"itk\")\nclass ITKWriter(ImageWriter):\n    \"\"\"\n    Write data and metadata into files on disk using ITK-python.\n\n    .. code-block:: python\n\n        import numpy as np\n        from monai.data import ITKWriter\n\n        np_data = np.arange(48).reshape(3, 4, 4)\n\n        # write as 3d spatial image no channel\n        writer = ITKWriter(output_dtype=np.float32)\n        writer.set_data_array(np_data, channel_dim=None)\n        # optionally set metadata affine\n        writer.set_metadata({\"affine\": np.eye(4), \"original_affine\": -1 * np.eye(4)})\n        writer.write(\"test1.nii.gz\")\n\n        # write as 2d image, channel-first\n        writer = ITKWriter(output_dtype=np.uint8)\n        writer.set_data_array(np_data, channel_dim=0)\n        writer.set_metadata({\"spatial_shape\": (5, 5)})\n        writer.write(\"test1.png\")\n\n    \"\"\"\n\n    output_dtype: DtypeLike = None\n    channel_dim: int | None\n\n    def __init__(self, output_dtype: DtypeLike = np.float32, affine_lps_to_ras: bool | None = True, **kwargs):\n        \"\"\"\n        Args:\n            output_dtype: output data type.\n            affine_lps_to_ras: whether to convert the affine matrix from \"LPS\" to \"RAS\". 
Defaults to ``True``.\n                Set to ``True`` to be consistent with ``NibabelWriter``,\n                otherwise the affine matrix is assumed already in the ITK convention.\n                Set to ``None`` to use ``data_array.meta[MetaKeys.SPACE]`` to determine the flag.\n            kwargs: keyword arguments passed to ``ImageWriter``.\n\n        The constructor will create ``self.output_dtype`` internally.\n        ``affine`` and ``channel_dim`` are initialized as instance members (default ``None``, ``0``):\n\n            - user-specified ``affine`` should be set in ``set_metadata``,\n            - user-specified ``channel_dim`` should be set in ``set_data_array``.\n        \"\"\"\n        super().__init__(\n            output_dtype=output_dtype, affine_lps_to_ras=affine_lps_to_ras, affine=None, channel_dim=0, **kwargs\n        )\n\n    def set_data_array(\n        self, data_array: NdarrayOrTensor, channel_dim: int | None = 0, squeeze_end_dims: bool = True, **kwargs\n    ):\n        \"\"\"\n        Convert ``data_array`` into 'channel-last' numpy ndarray.\n\n        Args:\n            data_array: input data array with the channel dimension specified by ``channel_dim``.\n            channel_dim: channel dimension of the data array. 
Defaults to 0.\n                ``None`` indicates data without any channel dimension.\n            squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed.\n            kwargs: keyword arguments passed to ``self.convert_to_channel_last``,\n                currently support ``spatial_ndim`` and ``contiguous``, defaulting to ``3`` and ``False`` respectively.\n        \"\"\"\n        n_chns = data_array.shape[channel_dim] if channel_dim is not None else 0\n        self.data_obj = self.convert_to_channel_last(\n            data=data_array,\n            channel_dim=channel_dim,\n            squeeze_end_dims=squeeze_end_dims,\n            spatial_ndim=kwargs.pop(\"spatial_ndim\", 3),\n            contiguous=kwargs.pop(\"contiguous\", True),\n        )\n        self.channel_dim = -1  # in most cases, the data is set to channel last\n        if squeeze_end_dims and n_chns <= 1:  # num_channel==1 squeezed\n            self.channel_dim = None\n        if not squeeze_end_dims and n_chns < 1:  # originally no channel and convert_to_channel_last added a channel\n            self.channel_dim = None\n            self.data_obj = self.data_obj[..., 0]\n\n    def set_metadata(self, meta_dict: Mapping | None = None, resample: bool = True, **options):\n        \"\"\"\n        Resample ``self.data_obj`` if needed.  
This method assumes ``self.data_obj`` is a 'channel-last' ndarray.\n\n        Args:\n            meta_dict: a metadata dictionary for affine, original affine and spatial shape information.\n                Optional keys are ``\"spatial_shape\"``, ``\"affine\"``, ``\"original_affine\"``.\n            resample: if ``True``, the data will be resampled to the original affine (specified in ``meta_dict``).\n            options: keyword arguments passed to ``self.resample_if_needed``,\n                currently support ``mode``, ``padding_mode``, ``align_corners``, and ``dtype``,\n                defaulting to ``bilinear``, ``border``, ``False``, and ``np.float64`` respectively.\n        \"\"\"\n        original_affine, affine, spatial_shape = self.get_meta_info(meta_dict)\n        if self.output_dtype is None and hasattr(self.data_obj, \"dtype\"):  # pylint: disable=E0203\n            self.output_dtype = self.data_obj.dtype  # type: ignore\n        self.data_obj, self.affine = self.resample_if_needed(\n            data_array=cast(NdarrayOrTensor, self.data_obj),\n            affine=affine,\n            target_affine=original_affine if resample else None,\n            output_spatial_shape=spatial_shape if resample else None,\n            mode=options.pop(\"mode\", GridSampleMode.BILINEAR),\n            padding_mode=options.pop(\"padding_mode\", GridSamplePadMode.BORDER),\n            align_corners=options.pop(\"align_corners\", False),\n            dtype=options.pop(\"dtype\", np.float64),\n        )\n\n    def write(self, filename: PathLike, verbose: bool = False, **kwargs):\n        \"\"\"\n        Create an ITK object from ``self.create_backend_obj(self.obj, ...)`` and call ``itk.imwrite``.\n\n        Args:\n            filename: filename or PathLike object.\n            verbose: if ``True``, log the progress.\n            kwargs: keyword arguments passed to ``itk.imwrite``,\n                currently support ``compression`` and ``imageio``.\n\n        See also:\n\n    
        - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L809\n        \"\"\"\n        super().write(filename, verbose=verbose)\n        self.data_obj = self.create_backend_obj(\n            cast(NdarrayOrTensor, self.data_obj),\n            channel_dim=self.channel_dim,\n            affine=self.affine,\n            dtype=self.output_dtype,\n            affine_lps_to_ras=self.affine_lps_to_ras,  # type: ignore\n            **kwargs,\n        )\n        itk.imwrite(\n            self.data_obj, filename, compression=kwargs.pop(\"compression\", False), imageio=kwargs.pop(\"imageio\", None)\n        )\n\n    @classmethod\n    def create_backend_obj(\n        cls,\n        data_array: NdarrayOrTensor,\n        channel_dim: int | None = 0,\n        affine: NdarrayOrTensor | None = None,\n        dtype: DtypeLike = np.float32,\n        affine_lps_to_ras: bool | None = True,\n        **kwargs,\n    ):\n        \"\"\"\n        Create an ITK object from ``data_array``. This method assumes a 'channel-last' ``data_array``.\n\n        Args:\n            data_array: input data array.\n            channel_dim: channel dimension of the data array. This is used to create a Vector Image if it is not ``None``.\n            affine: affine matrix of the data array. This is used to compute `spacing`, `direction` and `origin`.\n            dtype: output data type.\n            affine_lps_to_ras: whether to convert the affine matrix from \"LPS\" to \"RAS\". Defaults to ``True``.\n                Set to ``True`` to be consistent with ``NibabelWriter``,\n                otherwise the affine matrix is assumed already in the ITK convention.\n                Set to ``None`` to use ``data_array.meta[MetaKeys.SPACE]`` to determine the flag.\n            kwargs: keyword arguments. 
Current `itk.GetImageFromArray` will read ``ttype`` from this dictionary.\n\n        see also:\n\n            - https://github.com/InsightSoftwareConsortium/ITK/blob/v5.2.1/Wrapping/Generators/Python/itk/support/extras.py#L389\n\n        \"\"\"\n        if isinstance(data_array, MetaTensor) and affine_lps_to_ras is None:\n            affine_lps_to_ras = (\n                data_array.meta.get(MetaKeys.SPACE, SpaceKeys.LPS) != SpaceKeys.LPS\n            )  # do the converting from LPS to RAS only if the space type is currently LPS.\n        data_array = super().create_backend_obj(data_array)\n        _is_vec = channel_dim is not None\n        if _is_vec:\n            data_array = np.moveaxis(data_array, -1, 0)  # from channel last to channel first\n        data_array = data_array.T.astype(get_equivalent_dtype(dtype, np.ndarray), copy=True, order=\"C\")\n        itk_obj = itk.GetImageFromArray(data_array, is_vector=_is_vec, ttype=kwargs.pop(\"ttype\", None))\n\n        d = len(itk.size(itk_obj))\n        if affine is None:\n            affine = np.eye(d + 1, dtype=np.float64)\n        _affine = convert_data_type(affine, np.ndarray)[0]\n        if affine_lps_to_ras:\n            _affine = orientation_ras_lps(to_affine_nd(d, _affine))\n        spacing = affine_to_spacing(_affine, r=d)\n        _direction: np.ndarray = np.diag(1 / spacing)\n        _direction = _affine[:d, :d] @ _direction\n        itk_obj.SetSpacing(spacing.tolist())\n        itk_obj.SetOrigin(_affine[:d, -1].tolist())\n        itk_obj.SetDirection(itk.GetMatrixFromArray(_direction))\n        return itk_obj\n\n\n@require_pkg(pkg_name=\"nibabel\")\nclass NibabelWriter(ImageWriter):\n    \"\"\"\n    Write data and metadata into files on disk using Nibabel.\n\n    .. 
code-block:: python\n\n        import numpy as np\n        from monai.data import NibabelWriter\n\n        np_data = np.arange(48).reshape(3, 4, 4)\n        writer = NibabelWriter()\n        writer.set_data_array(np_data, channel_dim=None)\n        writer.set_metadata({\"affine\": np.eye(4), \"original_affine\": np.eye(4)})\n        writer.write(\"test1.nii.gz\", verbose=True)\n\n    \"\"\"\n\n    output_dtype: DtypeLike\n    affine: Any\n\n    def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs):\n        \"\"\"\n        Args:\n            output_dtype: output data type.\n            kwargs: keyword arguments passed to ``ImageWriter``.\n\n        The constructor will create ``self.output_dtype`` internally.\n        ``affine`` is initialized as instance members (default ``None``),\n        user-specified ``affine`` should be set in ``set_metadata``.\n        \"\"\"\n        super().__init__(output_dtype=output_dtype, affine=None, **kwargs)\n\n    def set_data_array(\n        self, data_array: NdarrayOrTensor, channel_dim: int | None = 0, squeeze_end_dims: bool = True, **kwargs\n    ):\n        \"\"\"\n        Convert ``data_array`` into 'channel-last' numpy ndarray.\n\n        Args:\n            data_array: input data array with the channel dimension specified by ``channel_dim``.\n            channel_dim: channel dimension of the data array. 
Defaults to 0.\n                ``None`` indicates data without any channel dimension.\n            squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed.\n            kwargs: keyword arguments passed to ``self.convert_to_channel_last``,\n                currently support ``spatial_ndim``, defaulting to ``3``.\n        \"\"\"\n        self.data_obj = self.convert_to_channel_last(\n            data=data_array,\n            channel_dim=channel_dim,\n            squeeze_end_dims=squeeze_end_dims,\n            spatial_ndim=kwargs.pop(\"spatial_ndim\", 3),\n        )\n\n    def set_metadata(self, meta_dict: Mapping | None, resample: bool = True, **options):\n        \"\"\"\n        Resample ``self.data_obj`` if needed.  This method assumes ``self.data_obj`` is a 'channel-last' ndarray.\n\n        Args:\n            meta_dict: a metadata dictionary for affine, original affine and spatial shape information.\n                Optional keys are ``\"spatial_shape\"``, ``\"affine\"``, ``\"original_affine\"``.\n            resample: if ``True``, the data will be resampled to the original affine (specified in ``meta_dict``).\n            options: keyword arguments passed to ``self.resample_if_needed``,\n                currently support ``mode``, ``padding_mode``, ``align_corners``, and ``dtype``,\n                defaulting to ``bilinear``, ``border``, ``False``, and ``np.float64`` respectively.\n        \"\"\"\n        original_affine, affine, spatial_shape = self.get_meta_info(meta_dict)\n        if (\n            self.output_dtype is None and self.data_obj is not None and hasattr(self.data_obj, \"dtype\")\n        ):  # pylint: disable=E0203\n            self.output_dtype = self.data_obj.dtype  # type: ignore\n        self.data_obj, self.affine = self.resample_if_needed(\n            data_array=cast(NdarrayOrTensor, self.data_obj),\n            affine=affine,\n            target_affine=original_affine if resample else None,\n            
output_spatial_shape=spatial_shape if resample else None,\n            mode=options.pop(\"mode\", GridSampleMode.BILINEAR),\n            padding_mode=options.pop(\"padding_mode\", GridSamplePadMode.BORDER),\n            align_corners=options.pop(\"align_corners\", False),\n            dtype=options.pop(\"dtype\", np.float64),\n        )\n\n    def write(self, filename: PathLike, verbose: bool = False, **obj_kwargs):\n        \"\"\"\n        Create a Nibabel object from ``self.create_backend_obj(self.obj, ...)`` and call ``nib.save``.\n\n        Args:\n            filename: filename or PathLike object.\n            verbose: if ``True``, log the progress.\n            obj_kwargs: keyword arguments passed to ``self.create_backend_obj``,\n\n        See also:\n\n            - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.save\n        \"\"\"\n        super().write(filename, verbose=verbose)\n        self.data_obj = self.create_backend_obj(\n            cast(NdarrayOrTensor, self.data_obj), affine=self.affine, dtype=self.output_dtype, **obj_kwargs\n        )\n        if self.affine is None:\n            self.affine = np.eye(4)\n        # ITK v5.2.1/Modules/IO/NIFTI/src/itkNiftiImageIO.cxx#L2175-L2176\n        _affine = to_affine_nd(r=3, affine=convert_data_type(self.affine, np.ndarray)[0])\n        self.data_obj.set_sform(_affine, code=1)\n        self.data_obj.set_qform(_affine, code=1)\n        nib.save(self.data_obj, filename)\n\n    @classmethod\n    def create_backend_obj(\n        cls, data_array: NdarrayOrTensor, affine: NdarrayOrTensor | None = None, dtype: DtypeLike = None, **kwargs\n    ):\n        \"\"\"\n        Create an Nifti1Image object from ``data_array``. This method assumes a 'channel-last' ``data_array``.\n\n        Args:\n            data_array: input data array.\n            affine: affine matrix of the data array.\n            dtype: output data type.\n            kwargs: keyword arguments. 
Current ``nib.nifti1.Nifti1Image`` will read\n                ``header``, ``extra``, ``file_map`` from this dictionary.\n\n        See also:\n\n            - https://nipy.org/nibabel/reference/nibabel.nifti1.html#nibabel.nifti1.Nifti1Image\n        \"\"\"\n        data_array = super().create_backend_obj(data_array)\n        if dtype is not None:\n            data_array = data_array.astype(get_equivalent_dtype(dtype, np.ndarray), copy=False)\n        affine = convert_data_type(affine, np.ndarray)[0]\n        if affine is None:\n            affine = np.eye(4)\n        affine = to_affine_nd(r=3, affine=affine)\n        return nib.nifti1.Nifti1Image(\n            data_array,\n            affine,\n            header=kwargs.pop(\"header\", None),\n            extra=kwargs.pop(\"extra\", None),\n            file_map=kwargs.pop(\"file_map\", None),\n        )\n\n\n@require_pkg(pkg_name=\"PIL\")\nclass PILWriter(ImageWriter):\n    \"\"\"\n    Write image data into files on disk using pillow.\n\n    It's based on the Image module in PIL library:\n    https://pillow.readthedocs.io/en/stable/reference/Image.html\n\n    .. code-block:: python\n\n        import numpy as np\n        from monai.data import PILWriter\n\n        np_data = np.arange(48).reshape(3, 4, 4)\n        writer = PILWriter(np.uint8)\n        writer.set_data_array(np_data, channel_dim=0)\n        writer.write(\"test1.png\", verbose=True)\n    \"\"\"\n\n    output_dtype: DtypeLike\n    channel_dim: int | None\n    scale: int | None\n\n    def __init__(\n        self, output_dtype: DtypeLike = np.float32, channel_dim: int | None = 0, scale: int | None = 255, **kwargs\n    ):\n        \"\"\"\n        Args:\n            output_dtype: output data type.\n            channel_dim: channel dimension of the data array. 
Defaults to 0.\n                ``None`` indicates data without any channel dimension.\n            scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling\n                [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.\n            kwargs: keyword arguments passed to ``ImageWriter``.\n        \"\"\"\n        super().__init__(output_dtype=output_dtype, channel_dim=channel_dim, scale=scale, **kwargs)\n\n    def set_data_array(\n        self,\n        data_array: NdarrayOrTensor,\n        channel_dim: int | None = 0,\n        squeeze_end_dims: bool = True,\n        contiguous: bool = False,\n        **kwargs,\n    ):\n        \"\"\"\n        Convert ``data_array`` into 'channel-last' numpy ndarray.\n\n        Args:\n            data_array: input data array with the channel dimension specified by ``channel_dim``.\n            channel_dim: channel dimension of the data array. Defaults to 0.\n                ``None`` indicates data without any channel dimension.\n            squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed.\n            contiguous: if ``True``, the data array will be converted to a contiguous array. Default is ``False``.\n            kwargs: keyword arguments passed to ``self.convert_to_channel_last``,\n                currently support ``spatial_ndim``, defaulting to ``2``.\n        \"\"\"\n        self.data_obj = self.convert_to_channel_last(\n            data=data_array,\n            channel_dim=channel_dim,\n            squeeze_end_dims=squeeze_end_dims,\n            spatial_ndim=kwargs.pop(\"spatial_ndim\", 2),\n            contiguous=contiguous,\n        )\n\n    def set_metadata(self, meta_dict: Mapping | None = None, resample: bool = True, **options):\n        \"\"\"\n        Resample ``self.data_obj`` if needed.  
This method assumes ``self.data_obj`` is a 'channel-last' ndarray.\n\n        Args:\n            meta_dict: a metadata dictionary for affine, original affine and spatial shape information.\n                Optional key is ``\"spatial_shape\"``.\n            resample: if ``True``, the data will be resampled to the spatial shape specified in ``meta_dict``.\n            options: keyword arguments passed to ``self.resample_if_needed``,\n                currently support ``mode``, defaulting to ``bicubic``.\n        \"\"\"\n        spatial_shape = self.get_meta_info(meta_dict)\n        if self.output_dtype is None and hasattr(self.data_obj, \"dtype\"):  # pylint: disable=E0203\n            self.output_dtype = self.data_obj.dtype  # type: ignore\n        self.data_obj = self.resample_and_clip(\n            data_array=self.data_obj,\n            output_spatial_shape=spatial_shape if resample else None,\n            mode=options.pop(\"mode\", InterpolateMode.BICUBIC),\n        )\n\n    def write(self, filename: PathLike, verbose: bool = False, **kwargs):\n        \"\"\"\n        Create a PIL image object from ``self.create_backend_obj(self.obj, ...)`` and call ``save``.\n\n        Args:\n            filename: filename or PathLike object.\n            verbose: if ``True``, log the progress.\n            kwargs: optional keyword arguments passed to ``self.create_backend_obj``\n                currently support ``reverse_indexing``, ``image_mode``, defaulting to ``True``, ``None`` respectively.\n\n        See also:\n\n            - https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.save\n        \"\"\"\n        super().write(filename, verbose=verbose)\n        self.data_obj = self.create_backend_obj(\n            data_array=self.data_obj,\n            dtype=self.output_dtype,\n            reverse_indexing=kwargs.pop(\"reverse_indexing\", True),\n            image_mode=kwargs.pop(\"image_mode\", None),\n            scale=self.scale,\n            
**kwargs,\n        )\n        self.data_obj.save(filename, **kwargs)\n\n    @classmethod\n    def get_meta_info(cls, metadata: Mapping | None = None):\n        return None if not metadata else metadata.get(MetaKeys.SPATIAL_SHAPE)\n\n    @classmethod\n    def resample_and_clip(\n        cls,\n        data_array: NdarrayOrTensor,\n        output_spatial_shape: Sequence[int] | None = None,\n        mode: str = InterpolateMode.BICUBIC,\n    ) -> np.ndarray:\n        \"\"\"\n        Resample ``data_array`` to ``output_spatial_shape`` if needed.\n        Args:\n            data_array: input data array. This method assumes the 'channel-last' format.\n            output_spatial_shape: output spatial shape.\n            mode: interpolation mode, default is ``InterpolateMode.BICUBIC``.\n        \"\"\"\n\n        data: np.ndarray = convert_data_type(data_array, np.ndarray)[0]\n        if output_spatial_shape is not None:\n            output_spatial_shape_ = ensure_tuple_rep(output_spatial_shape, 2)\n            mode = look_up_option(mode, InterpolateMode)\n            align_corners = None if mode in (InterpolateMode.NEAREST, InterpolateMode.AREA) else False\n            xform = Resize(spatial_size=output_spatial_shape_, mode=mode, align_corners=align_corners)\n            _min, _max = np.min(data), np.max(data)\n            if len(data.shape) == 3:\n                data = np.moveaxis(data, -1, 0)  # to channel first\n                data = convert_data_type(xform(data), np.ndarray)[0]  # type: ignore\n                data = np.moveaxis(data, 0, -1)\n            else:  # (H, W)\n                data = np.expand_dims(data, 0)  # make a channel\n                data = convert_data_type(xform(data), np.ndarray)[0][0]  # type: ignore\n            if mode != InterpolateMode.NEAREST:\n                data = np.clip(data, _min, _max)\n        return data\n\n    @classmethod\n    def create_backend_obj(\n        cls,\n        data_array: NdarrayOrTensor,\n        dtype: DtypeLike = 
None,\n        scale: int | None = 255,\n        reverse_indexing: bool = True,\n        **kwargs,\n    ):\n        \"\"\"\n        Create a PIL object from ``data_array``.\n\n        Args:\n            data_array: input data array.\n            dtype: output data type.\n            scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling\n                [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling.\n            reverse_indexing: if ``True``, the data array's first two dimensions will be swapped.\n            kwargs: keyword arguments. Currently ``PILImage.fromarray`` will read\n                ``image_mode`` from this dictionary, defaults to ``None``.\n\n        See also:\n\n            - https://pillow.readthedocs.io/en/stable/reference/Image.html\n        \"\"\"\n        data: np.ndarray = super().create_backend_obj(data_array)\n        if scale:\n            # scale the data to be in an integer range\n            data = np.clip(data, 0.0, 1.0)  # png writer only can scale data in range [0, 1]\n\n            if scale == np.iinfo(np.uint8).max:\n                data = (scale * data).astype(np.uint8, copy=False)\n            elif scale == np.iinfo(np.uint16).max:\n                data = (scale * data).astype(np.uint16, copy=False)\n            else:\n                raise ValueError(f\"Unsupported scale: {scale}, available options are [255, 65535].\")\n        if dtype is not None:\n            data = data.astype(get_equivalent_dtype(dtype, np.ndarray), copy=False)\n        if reverse_indexing:\n            data = np.moveaxis(data, 0, 1)\n\n        return PILImage.fromarray(data, mode=kwargs.pop(\"image_mode\", None))\n\n\ndef init():\n    \"\"\"\n    Initialize the image writer modules according to the filename extension.\n    \"\"\"\n    for ext in (\"png\", \"jpg\", \"jpeg\", \"bmp\", \"tiff\", \"tif\"):\n        register_writer(ext, PILWriter)  # TODO: test 16-bit\n    for ext in (\"nii.gz\", \"nii\"):\n    
    register_writer(ext, NibabelWriter, ITKWriter)\n    register_writer(\"nrrd\", ITKWriter, NibabelWriter)\n    register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter)\n"
  },
  {
    "path": "monai/data/iterable_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Iterable, Iterator, Sequence\nfrom typing import Any\n\nfrom torch.utils.data import IterableDataset as _TorchIterableDataset\nfrom torch.utils.data import get_worker_info\n\nfrom monai.data.utils import convert_tables_to_dicts\nfrom monai.transforms import apply_transform\nfrom monai.transforms.transform import Randomizable\nfrom monai.utils import optional_import\n\npd, _ = optional_import(\"pandas\")\n\n\nclass IterableDataset(_TorchIterableDataset):\n    \"\"\"\n    A generic dataset for iterable data source and an optional callable data transform\n    when fetching a data sample. 
Inherit from PyTorch IterableDataset:\n    https://pytorch.org/docs/stable/data.html?highlight=iterabledataset#torch.utils.data.IterableDataset.\n    For example, typical input data can be web data stream which can support multi-process access.\n\n    To accelerate the loading process, it can support multi-processing based on PyTorch DataLoader workers,\n    every process executes transforms on part of every loaded data.\n    Note that the order of output data may not match data source in multi-processing mode.\n    And each worker process will have a different copy of the dataset object, need to guarantee\n    process-safe from data source or DataLoader.\n\n    \"\"\"\n\n    def __init__(self, data: Iterable[Any], transform: Callable | None = None) -> None:\n        \"\"\"\n        Args:\n            data: input data source to load and transform to generate dataset for model.\n            transform: a callable data transform on input data.\n        \"\"\"\n        self.data = data\n        self.transform = transform\n        self.source: Iterator[Any] | None = None\n\n    def __iter__(self):\n        info = get_worker_info()\n        num_workers = info.num_workers if info is not None else 1\n        id = info.id if info is not None else 0\n\n        self.source = iter(self.data)\n        for i, item in enumerate(self.source):\n            if i % num_workers == id:\n                if self.transform is not None:\n                    item = apply_transform(self.transform, item)\n                yield item\n\n\nclass ShuffleBuffer(Randomizable, IterableDataset):\n    \"\"\"\n    Extend the IterableDataset with a buffer and randomly pop items.\n\n    Args:\n        data: input data source to load and transform to generate dataset for model.\n        transform: a callable data transform on input data.\n        buffer_size: size of the buffer to store items and randomly pop, default to 512.\n        seed: random seed to initialize the random state of all workers, set 
`seed += 1` in\n            every iter() call, refer to the PyTorch idea:\n            https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98.\n        epochs: number of epochs to iterate over the dataset, default to 1, -1 means infinite epochs.\n\n    Note:\n        Both ``monai.data.DataLoader`` and ``torch.utils.data.DataLoader`` do not seed this class (as a subclass of\n        ``IterableDataset``) at run time. ``persistent_workers=True`` flag (and pytorch>1.8) is therefore required\n        for multiple epochs of loading when ``num_workers>0``. For example::\n\n            import monai\n\n            def run():\n                dss = monai.data.ShuffleBuffer([1, 2, 3, 4], buffer_size=30, seed=42)\n\n                dataloader = monai.data.DataLoader(\n                    dss, batch_size=1, num_workers=2, persistent_workers=True)\n                for epoch in range(3):\n                    for item in dataloader:\n                        print(f\"epoch: {epoch} item: {item}.\")\n\n            if __name__ == '__main__':\n                run()\n\n    \"\"\"\n\n    def __init__(self, data, transform=None, buffer_size: int = 512, seed: int = 0, epochs: int = 1) -> None:\n        super().__init__(data=data, transform=transform)\n        self.size = buffer_size\n        self.seed = seed\n        self.epochs = epochs\n        self._idx = 0\n\n    def randomized_pop(self, buffer):\n        \"\"\"Return the item at a randomized location `self._idx` in `buffer`.\"\"\"\n        self.randomize(len(buffer))\n        ret, buffer[self._idx] = buffer[self._idx], buffer[-1]\n        buffer.pop()\n        return ret\n\n    def generate_item(self):\n        \"\"\"Fill a `buffer` list up to `self.size`, then generate randomly popped items.\"\"\"\n        buffer: list[Any] = []\n        for item in iter(self.data):\n            if len(buffer) >= self.size:\n                yield self.randomized_pop(buffer)\n            buffer.append(item)\n        while 
buffer:\n            yield self.randomized_pop(buffer)\n\n    def __iter__(self):\n        \"\"\"\n        Randomly pop buffered items from `self.data`.\n        Multiple dataloader workers sharing this dataset will generate identical item sequences.\n        \"\"\"\n        self.seed += 1\n        super().set_random_state(seed=self.seed)  # make all workers in sync\n        for _ in range(self.epochs) if self.epochs >= 0 else iter(int, 1):\n            yield from IterableDataset(self.generate_item(), transform=self.transform)\n\n    def randomize(self, size: int) -> None:\n        self._idx = self.R.randint(size)\n\n\nclass CSVIterableDataset(IterableDataset):\n    \"\"\"\n    Iterable dataset to load CSV files and generate dictionary data.\n    It is particularly useful when data come from a stream, inherits from PyTorch IterableDataset:\n    https://pytorch.org/docs/stable/data.html?highlight=iterabledataset#torch.utils.data.IterableDataset.\n\n    It also can be helpful when loading extremely big CSV files that can't read into memory directly,\n    just treat the big CSV file as stream input, call `reset()` of `CSVIterableDataset` for every epoch.\n    Note that as a stream input, it can't get the length of dataset.\n\n    To effectively shuffle the data in the big dataset, users can set a big buffer to continuously store\n    the loaded data, then randomly pick data from the buffer for following tasks.\n\n    To accelerate the loading process, it can support multi-processing based on PyTorch DataLoader workers,\n    every process executes transforms on part of every loaded data.\n    Note: the order of output data may not match data source in multi-processing mode.\n\n    It can load data from multiple CSV files and join the tables with additional `kwargs` arg.\n    Support to only load specific columns.\n    And it can also group several loaded columns to generate a new column, for example,\n    set `col_groups={\"meta\": [\"meta_0\", \"meta_1\", 
\"meta_2\"]}`, output can be::\n\n        [\n            {\"image\": \"./image0.nii\", \"meta_0\": 11, \"meta_1\": 12, \"meta_2\": 13, \"meta\": [11, 12, 13]},\n            {\"image\": \"./image1.nii\", \"meta_0\": 21, \"meta_1\": 22, \"meta_2\": 23, \"meta\": [21, 22, 23]},\n        ]\n\n    Args:\n        src: if provided the filename of CSV file, it can be a str, URL, path object or file-like object to load.\n            also support to provide iter for stream input directly, will skip loading from filename.\n            if provided a list of filenames or iters, it will join the tables.\n        chunksize: rows of a chunk when loading iterable data from CSV files, default to 1000. more details:\n            https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html.\n        buffer_size: size of the buffer to store the loaded chunks, if None, set to `2 x chunksize`.\n        col_names: names of the expected columns to load. if None, load all the columns.\n        col_types: `type` and `default value` to convert the loaded columns, if None, use original data.\n            it should be a dictionary, every item maps to an expected column, the `key` is the column\n            name and the `value` is None or a dictionary to define the default value and data type.\n            the supported keys in dictionary are: [\"type\", \"default\"]. 
for example::\n\n                col_types = {\n                    \"subject_id\": {\"type\": str},\n                    \"label\": {\"type\": int, \"default\": 0},\n                    \"ehr_0\": {\"type\": float, \"default\": 0.0},\n                    \"ehr_1\": {\"type\": float, \"default\": 0.0},\n                    \"image\": {\"type\": str, \"default\": None},\n                }\n\n        col_groups: args to group the loaded columns to generate a new column,\n            it should be a dictionary, every item maps to a group, the `key` will\n            be the new column name, the `value` is the names of columns to combine. for example:\n            `col_groups={\"ehr\": [f\"ehr_{i}\" for i in range(10)], \"meta\": [\"meta_1\", \"meta_2\"]}`\n        transform: transform to apply on the loaded items of a dictionary data.\n        shuffle: whether to shuffle all the data in the buffer every time a new chunk loaded.\n        seed: random seed to initialize the random state for all the workers if `shuffle` is True,\n            set `seed += 1` in every iter() call, refer to the PyTorch idea:\n            https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98.\n        kwargs_read_csv: dictionary args to pass to pandas `read_csv` function. 
Default to ``{\"chunksize\": chunksize}``.\n        kwargs: additional arguments for `pandas.merge()` API to join tables.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        src: str | Sequence[str] | Iterable | Sequence[Iterable],\n        chunksize: int = 1000,\n        buffer_size: int | None = None,\n        col_names: Sequence[str] | None = None,\n        col_types: dict[str, dict[str, Any] | None] | None = None,\n        col_groups: dict[str, Sequence[str]] | None = None,\n        transform: Callable | None = None,\n        shuffle: bool = False,\n        seed: int = 0,\n        kwargs_read_csv: dict | None = None,\n        **kwargs,\n    ):\n        self.src = src\n        self.chunksize = chunksize\n        self.buffer_size = 2 * chunksize if buffer_size is None else buffer_size\n        self.col_names = col_names\n        self.col_types = col_types\n        self.col_groups = col_groups\n        self.shuffle = shuffle\n        self.seed = seed\n        self.kwargs_read_csv = kwargs_read_csv or {\"chunksize\": chunksize}\n        self.kwargs = kwargs\n\n        self.iters: list[Iterable] = self.reset()\n        super().__init__(data=None, transform=transform)  # type: ignore\n\n    def reset(self, src: str | Sequence[str] | Iterable | Sequence[Iterable] | None = None):\n        \"\"\"\n        Reset the pandas `TextFileReader` iterable object to read data. For more details, please check:\n        https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration.\n\n        Args:\n            src: if not None and provided the filename of CSV file, it can be a str, URL, path object\n                or file-like object to load. also support to provide iter for stream input directly,\n                will skip loading from filename. 
if provided a list of filenames or iters, it will join the tables.\n                default to `self.src`.\n\n        \"\"\"\n        src = self.src if src is None else src\n        srcs = (src,) if not isinstance(src, (tuple, list)) else src\n        self.iters = []\n        for i in srcs:\n            if isinstance(i, str):\n                self.iters.append(pd.read_csv(i, **self.kwargs_read_csv))\n            elif isinstance(i, Iterable):\n                self.iters.append(i)\n            else:\n                raise ValueError(\"`src` must be file path or iterable object.\")\n        return self.iters\n\n    def close(self):\n        \"\"\"\n        Close the pandas `TextFileReader` iterable objects.\n        If the input src is file path, TextFileReader was created internally, need to close it.\n        If the input src is iterable object, depends on users requirements whether to close it in this function.\n        For more details, please check:\n        https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration.\n\n        \"\"\"\n        for i in self.iters:\n            i.close()  # type: ignore\n\n    def _flattened(self):\n        for chunks in zip(*self.iters):\n            yield from convert_tables_to_dicts(\n                dfs=chunks,\n                col_names=self.col_names,\n                col_types=self.col_types,\n                col_groups=self.col_groups,\n                **self.kwargs,\n            )\n\n    def __iter__(self):\n        if self.shuffle:\n            self.seed += 1\n            buffer = ShuffleBuffer(\n                data=self._flattened(), transform=self.transform, buffer_size=self.buffer_size, seed=self.seed\n            )\n            yield from buffer\n        yield from IterableDataset(data=self._flattened(), transform=self.transform)\n"
  },
  {
    "path": "monai/data/itk_torch_bridge.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, cast\n\nimport numpy as np\nimport torch\n\nfrom monai.config.type_definitions import DtypeLike\nfrom monai.data import ITKReader, ITKWriter\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import orientation_ras_lps\nfrom monai.transforms import EnsureChannelFirst\nfrom monai.utils import MetaKeys, SpaceKeys, convert_to_dst_type, optional_import\n\nif TYPE_CHECKING:\n    import itk\n\n    has_itk = True\nelse:\n    itk, has_itk = optional_import(\"itk\")\n\n__all__ = [\n    \"itk_image_to_metatensor\",\n    \"metatensor_to_itk_image\",\n    \"itk_to_monai_affine\",\n    \"monai_to_itk_affine\",\n    \"get_itk_image_center\",\n    \"monai_to_itk_ddf\",\n]\n\n\ndef itk_image_to_metatensor(\n    image, channel_dim: str | int | None = None, dtype: DtypeLike | torch.dtype = float\n) -> MetaTensor:\n    \"\"\"\n    Converts an ITK image to a MetaTensor object.\n\n    Args:\n        image: The ITK image to be converted.\n        channel_dim: the channel dimension of the input image, default is None.\n            This is used to set original_channel_dim in the metadata, EnsureChannelFirst reads this field.\n            If None, the channel_dim is inferred automatically.\n            If the input array doesn't have a channel dim, this value should be ``'no_channel'``.\n        dtype: 
output dtype, defaults to the Python built-in `float`.\n\n    Returns:\n        A MetaTensor object containing the array data and metadata in ChannelFirst format.\n    \"\"\"\n    reader = ITKReader(affine_lps_to_ras=False, channel_dim=channel_dim)\n    image_array, meta_data = reader.get_data(image)\n    image_array = convert_to_dst_type(image_array, dst=image_array, dtype=dtype)[0]\n    metatensor = MetaTensor.ensure_torch_and_prune_meta(image_array, meta_data)\n    metatensor = EnsureChannelFirst(channel_dim=channel_dim)(metatensor)\n\n    return cast(MetaTensor, metatensor)\n\n\ndef metatensor_to_itk_image(\n    meta_tensor: MetaTensor, channel_dim: int | None = 0, dtype: DtypeLike = np.float32, **kwargs\n):\n    \"\"\"\n    Converts a MetaTensor object to an ITK image. Expects the MetaTensor to be in ChannelFirst format.\n\n    Args:\n        meta_tensor: The MetaTensor to be converted.\n        channel_dim: channel dimension of the data array, defaults to ``0`` (Channel-first).\n            ``None`` indicates no channel dimension. This is used to create a Vector Image if it is not ``None``.\n        dtype: output data type, defaults to `np.float32`.\n        kwargs: additional keyword arguments. 
Currently `itk.GetImageFromArray` will get ``ttype`` from this dictionary.\n\n    Returns:\n        The ITK image.\n\n    See also: :py:func:`ITKWriter.create_backend_obj`\n    \"\"\"\n    if meta_tensor.meta.get(MetaKeys.SPACE, SpaceKeys.LPS) == SpaceKeys.RAS:\n        _meta_tensor = meta_tensor.clone()\n        _meta_tensor.affine = orientation_ras_lps(meta_tensor.affine)\n        _meta_tensor.meta[MetaKeys.SPACE] = SpaceKeys.LPS\n    else:\n        _meta_tensor = meta_tensor\n    writer = ITKWriter(output_dtype=dtype, affine_lps_to_ras=False)\n    writer.set_data_array(data_array=meta_tensor.data, channel_dim=channel_dim, squeeze_end_dims=True)\n    return writer.create_backend_obj(\n        writer.data_obj,\n        channel_dim=writer.channel_dim,\n        affine=_meta_tensor.affine,\n        affine_lps_to_ras=False,  # False if the affine is in itk convention\n        dtype=writer.output_dtype,\n        kwargs=kwargs,\n    )\n\n\ndef itk_to_monai_affine(image, matrix, translation, center_of_rotation=None, reference_image=None) -> torch.Tensor:\n    \"\"\"\n    Converts an ITK affine matrix (2x2 for 2D or 3x3 for 3D matrix and translation vector) to a MONAI affine matrix.\n\n    Args:\n        image: The ITK image object. This is used to extract the spacing and direction information.\n        matrix: The 2x2 or 3x3 ITK affine matrix.\n        translation: The 2-element or 3-element ITK affine translation vector.\n        center_of_rotation: The center of rotation. If provided, the affine\n                            matrix will be adjusted to account for the difference\n                            between the center of the image and the center of rotation.\n        reference_image: The coordinate space that matrix and translation were defined\n                         in respect to. 
If not supplied, the coordinate space of image\n                         is used.\n\n    Returns:\n        A 4x4 MONAI affine matrix.\n    \"\"\"\n\n    _assert_itk_regions_match_array(image)\n    ndim = image.ndim\n    # If there is a reference image, compute an affine matrix that maps the image space to the\n    # reference image space.\n    if reference_image:\n        reference_affine_matrix = _compute_reference_space_affine_matrix(image, reference_image)\n    else:\n        reference_affine_matrix = torch.eye(ndim + 1, dtype=torch.float64)\n\n    # Create affine matrix that includes translation\n    affine_matrix = torch.eye(ndim + 1, dtype=torch.float64)\n    affine_matrix[:ndim, :ndim] = torch.tensor(matrix, dtype=torch.float64)\n    affine_matrix[:ndim, ndim] = torch.tensor(translation, dtype=torch.float64)\n\n    # Adjust offset when center of rotation is different from center of the image\n    if center_of_rotation:\n        offset_matrix, inverse_offset_matrix = _compute_offset_matrix(image, center_of_rotation)\n        affine_matrix = inverse_offset_matrix @ affine_matrix @ offset_matrix\n\n    # Adjust direction\n    direction_matrix, inverse_direction_matrix = _compute_direction_matrix(image)\n    affine_matrix = inverse_direction_matrix @ affine_matrix @ direction_matrix\n\n    # Adjust based on spacing. It is required because MONAI does not update the\n    # pixel array according to the spacing after a transformation. 
For example,\n    # a rotation of 90deg for an image with different spacing along the two axis\n    # will just rotate the image array by 90deg without also scaling accordingly.\n    spacing_matrix, inverse_spacing_matrix = _compute_spacing_matrix(image)\n    affine_matrix = inverse_spacing_matrix @ affine_matrix @ spacing_matrix\n\n    return affine_matrix @ reference_affine_matrix\n\n\ndef monai_to_itk_affine(image, affine_matrix, center_of_rotation=None):\n    \"\"\"\n    Converts a MONAI affine matrix to an ITK affine matrix (2x2 for 2D or 3x3 for\n    3D matrix and translation vector). See also 'itk_to_monai_affine'.\n\n    Args:\n        image: The ITK image object. This is used to extract the spacing and direction information.\n        affine_matrix: The 3x3 for 2D or 4x4 for 3D MONAI affine matrix.\n        center_of_rotation: The center of rotation. If provided, the affine\n                            matrix will be adjusted to account for the difference\n                            between the center of the image and the center of rotation.\n\n    Returns:\n        The ITK matrix and the translation vector.\n    \"\"\"\n    _assert_itk_regions_match_array(image)\n\n    # Adjust spacing\n    spacing_matrix, inverse_spacing_matrix = _compute_spacing_matrix(image)\n    affine_matrix = spacing_matrix @ affine_matrix @ inverse_spacing_matrix\n\n    # Adjust direction\n    direction_matrix, inverse_direction_matrix = _compute_direction_matrix(image)\n    affine_matrix = direction_matrix @ affine_matrix @ inverse_direction_matrix\n\n    # Adjust offset when center of rotation is different from center of the image\n    if center_of_rotation:\n        offset_matrix, inverse_offset_matrix = _compute_offset_matrix(image, center_of_rotation)\n        affine_matrix = offset_matrix @ affine_matrix @ inverse_offset_matrix\n\n    ndim = image.ndim\n    matrix = affine_matrix[:ndim, :ndim].numpy()\n    translation = affine_matrix[:ndim, ndim].tolist()\n\n    return 
matrix, translation\n\n\ndef get_itk_image_center(image):\n    \"\"\"\n    Calculates the center of the ITK image based on its origin, size, and spacing.\n    This center is equivalent to the implicit image center that MONAI uses.\n\n    Args:\n        image: The ITK image.\n\n    Returns:\n        The center of the image as a list of coordinates.\n    \"\"\"\n    image_size = np.asarray(image.GetLargestPossibleRegion().GetSize(), np.float32)\n    spacing = np.asarray(image.GetSpacing())\n    origin = np.asarray(image.GetOrigin())\n    center = image.GetDirection() @ ((image_size / 2 - 0.5) * spacing) + origin\n\n    return center.tolist()\n\n\ndef _assert_itk_regions_match_array(image):\n    # Note: Make it more compact? Also, are there redundant checks?\n    largest_region = image.GetLargestPossibleRegion()\n    buffered_region = image.GetBufferedRegion()\n    requested_region = image.GetRequestedRegion()\n\n    largest_region_size = np.array(largest_region.GetSize())\n    buffered_region_size = np.array(buffered_region.GetSize())\n    requested_region_size = np.array(requested_region.GetSize())\n    array_size = np.array(image.shape)[::-1]\n\n    largest_region_index = np.array(largest_region.GetIndex())\n    buffered_region_index = np.array(buffered_region.GetIndex())\n    requested_region_index = np.array(requested_region.GetIndex())\n\n    indices_are_zeros = (\n        np.all(largest_region_index == 0) and np.all(buffered_region_index == 0) and np.all(requested_region_index == 0)\n    )\n\n    sizes_match = (\n        np.array_equal(array_size, largest_region_size)\n        and np.array_equal(largest_region_size, buffered_region_size)\n        and np.array_equal(buffered_region_size, requested_region_size)\n    )\n\n    if not indices_are_zeros:\n        raise AssertionError(\"ITK-MONAI bridge: non-zero ITK region indices encountered\")\n    if not sizes_match:\n        raise AssertionError(\"ITK-MONAI bridge: ITK regions should be of the same 
shape\")\n\n\ndef _compute_offset_matrix(image, center_of_rotation) -> tuple[torch.Tensor, torch.Tensor]:\n    ndim = image.ndim\n    offset = np.asarray(get_itk_image_center(image)) - np.asarray(center_of_rotation)\n    offset_matrix = torch.eye(ndim + 1, dtype=torch.float64)\n    offset_matrix[:ndim, ndim] = torch.tensor(offset, dtype=torch.float64)\n    inverse_offset_matrix = torch.eye(ndim + 1, dtype=torch.float64)\n    inverse_offset_matrix[:ndim, ndim] = -torch.tensor(offset, dtype=torch.float64)\n\n    return offset_matrix, inverse_offset_matrix\n\n\ndef _compute_spacing_matrix(image) -> tuple[torch.Tensor, torch.Tensor]:\n    ndim = image.ndim\n    spacing = np.asarray(image.GetSpacing(), dtype=np.float64)\n    spacing_matrix = torch.eye(ndim + 1, dtype=torch.float64)\n    inverse_spacing_matrix = torch.eye(ndim + 1, dtype=torch.float64)\n    for i, e in enumerate(spacing):\n        spacing_matrix[i, i] = e\n        inverse_spacing_matrix[i, i] = 1 / e\n\n    return spacing_matrix, inverse_spacing_matrix\n\n\ndef _compute_direction_matrix(image) -> tuple[torch.Tensor, torch.Tensor]:\n    ndim = image.ndim\n    direction = itk.array_from_matrix(image.GetDirection())\n    direction_matrix = torch.eye(ndim + 1, dtype=torch.float64)\n    direction_matrix[:ndim, :ndim] = torch.tensor(direction, dtype=torch.float64)\n    inverse_direction = itk.array_from_matrix(image.GetInverseDirection())\n    inverse_direction_matrix = torch.eye(ndim + 1, dtype=torch.float64)\n    inverse_direction_matrix[:ndim, :ndim] = torch.tensor(inverse_direction, dtype=torch.float64)\n\n    return direction_matrix, inverse_direction_matrix\n\n\ndef _compute_reference_space_affine_matrix(image, ref_image) -> torch.Tensor:\n    ndim = ref_image.ndim\n\n    # Spacing and direction as matrices\n    spacing_matrix, inv_spacing_matrix = (m[:ndim, :ndim].numpy() for m in _compute_spacing_matrix(image))\n    ref_spacing_matrix, ref_inv_spacing_matrix = (m[:ndim, :ndim].numpy() for m in 
_compute_spacing_matrix(ref_image))\n\n    direction_matrix, inv_direction_matrix = (m[:ndim, :ndim].numpy() for m in _compute_direction_matrix(image))\n    ref_direction_matrix, ref_inv_direction_matrix = (\n        m[:ndim, :ndim].numpy() for m in _compute_direction_matrix(ref_image)\n    )\n\n    # Matrix calculation\n    matrix = ref_direction_matrix @ ref_spacing_matrix @ inv_spacing_matrix @ inv_direction_matrix\n\n    # Offset calculation\n    pixel_offset = -1\n    image_size = np.asarray(ref_image.GetLargestPossibleRegion().GetSize(), np.float32)\n    translation = (\n        (ref_direction_matrix @ ref_spacing_matrix - direction_matrix @ spacing_matrix)\n        @ (image_size + pixel_offset)\n        / 2\n    )\n    translation += np.asarray(ref_image.GetOrigin()) - np.asarray(image.GetOrigin())\n\n    # Convert matrix ITK matrix and translation to MONAI affine matrix\n    ref_affine_matrix = itk_to_monai_affine(image, matrix=matrix, translation=translation)\n\n    return ref_affine_matrix\n\n\ndef monai_to_itk_ddf(image, ddf):\n    \"\"\"\n    converting the dense displacement field from the MONAI space to the ITK\n    Args:\n        image: itk image of array shape 2D: (H, W) or 3D: (D, H, W)\n        ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W)\n    Returns:\n        displacement_field: itk image of the corresponding displacement field\n\n    \"\"\"\n    # 3, D, H, W -> D, H, W, 3\n    ndim = image.ndim\n    ddf = ddf.transpose(tuple(list(range(1, ndim + 1)) + [0]))\n    # x, y, z -> z, x, y\n    ddf = ddf[..., ::-1]\n\n    # Correct for spacing\n    spacing = np.asarray(image.GetSpacing(), dtype=np.float64)\n    ddf *= np.array(spacing, ndmin=ndim + 1)\n\n    # Correct for direction\n    direction = np.asarray(image.GetDirection(), dtype=np.float64)\n    ddf = np.einsum(\"ij,...j->...i\", direction, ddf, dtype=np.float64).astype(np.float32)\n\n    # initialise displacement field -\n    vector_component_type = itk.F\n    vector_pixel_type 
= itk.Vector[vector_component_type, ndim]\n    displacement_field_type = itk.Image[vector_pixel_type, ndim]\n    displacement_field = itk.GetImageFromArray(ddf, ttype=displacement_field_type)\n\n    # Set image metadata\n    displacement_field.SetSpacing(image.GetSpacing())\n    displacement_field.SetOrigin(image.GetOrigin())\n    displacement_field.SetDirection(image.GetDirection())\n\n    return displacement_field\n"
  },
  {
    "path": "monai/data/meta_obj.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport itertools\nimport pprint\nfrom collections.abc import Iterable\nfrom copy import deepcopy\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.utils import TraceKeys, first, is_immutable\n\n_TRACK_META = True\n\n__all__ = [\"get_track_meta\", \"set_track_meta\", \"MetaObj\"]\n\n\ndef set_track_meta(val: bool) -> None:\n    \"\"\"\n    Boolean to set whether metadata is tracked. If `True`, metadata will be associated\n    its data by using subclasses of `MetaObj`. If `False`, then data will be returned\n    with empty metadata.\n\n    If `set_track_meta` is `False`, then standard data objects will be returned (e.g.,\n    `torch.Tensor` and `np.ndarray`) as opposed to MONAI's enhanced objects.\n\n    By default, this is `True`, and most users will want to leave it this way. However,\n    if you are experiencing any problems regarding metadata, and aren't interested in\n    preserving metadata, then you can disable it.\n    \"\"\"\n    global _TRACK_META\n    _TRACK_META = val\n\n\ndef get_track_meta() -> bool:\n    \"\"\"\n    Return the boolean as to whether metadata is tracked. If `True`, metadata will be\n    associated its data by using subclasses of `MetaObj`. 
If `False`, then data will be\n    returned with empty metadata.\n\n    If `set_track_meta` is `False`, then standard data objects will be returned (e.g.,\n    `torch.Tensor` and `np.ndarray`) as opposed to MONAI's enhanced objects.\n\n    By default, this is `True`, and most users will want to leave it this way. However,\n    if you are experiencing any problems regarding metadata, and aren't interested in\n    preserving metadata, then you can disable it.\n    \"\"\"\n    return _TRACK_META\n\n\nclass MetaObj:\n    \"\"\"\n    Abstract base class that stores data as well as any extra metadata.\n\n    This allows for subclassing `torch.Tensor` and `np.ndarray` through multiple inheritance.\n\n    Metadata is stored in the form of a dictionary.\n\n    Behavior should be the same as extended class (e.g., `torch.Tensor` or `np.ndarray`)\n    aside from the extended meta functionality.\n\n    Copying of information:\n\n        * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the\n          first instance of `MetaObj` if `a.is_batch` is False\n          (For batched data, the metadata will be shallow copied for efficiency purposes).\n\n    \"\"\"\n\n    def __init__(self) -> None:\n        self._meta: dict = MetaObj.get_default_meta()\n        self._applied_operations: list = MetaObj.get_default_applied_operations()\n        self._pending_operations: list = MetaObj.get_default_applied_operations()  # the same default as applied_ops\n        self._is_batch: bool = False\n\n    @staticmethod\n    def flatten_meta_objs(*args: Iterable):\n        \"\"\"\n        Recursively flatten input and yield all instances of `MetaObj`.\n        This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and\n        their numpy equivalents), we return `[a, b]` if both `a` and `b` are of type\n        `MetaObj`.\n\n        Args:\n            args: Iterables of inputs to be flattened.\n        Returns:\n            list of nested `MetaObj` from 
input.\n        \"\"\"\n        for a in itertools.chain(*args):\n            if isinstance(a, (list, tuple)):\n                yield from MetaObj.flatten_meta_objs(a)\n            elif isinstance(a, MetaObj):\n                yield a\n\n    @staticmethod\n    def copy_items(data):\n        \"\"\"returns a copy of the data. list and dict are shallow copied for efficiency purposes.\"\"\"\n        if is_immutable(data):\n            return data\n        if isinstance(data, (list, dict, np.ndarray)):\n            return data.copy()\n        if isinstance(data, torch.Tensor):\n            return data.detach().clone()\n        return deepcopy(data)\n\n    def copy_meta_from(self, input_objs, copy_attr=True, keys=None):\n        \"\"\"\n        Copy metadata from a `MetaObj` or an iterable of `MetaObj` instances.\n\n        Args:\n            input_objs: list of `MetaObj` to copy data from.\n            copy_attr: whether to copy each attribute with `MetaObj.copy_item`.\n                note that if the attribute is a nested list or dict, only a shallow copy will be done.\n            keys: the keys of attributes to copy from the ``input_objs``.\n                If None, all keys from the input_objs will be copied.\n        \"\"\"\n        first_meta = input_objs if isinstance(input_objs, MetaObj) else first(input_objs, default=self)\n        if not hasattr(first_meta, \"__dict__\"):\n            return self\n        first_meta = first_meta.__dict__\n        keys = first_meta.keys() if keys is None else keys\n        if not copy_attr:\n            self.__dict__ = {a: first_meta[a] for a in keys if a in first_meta}  # shallow copy for performance\n        else:\n            self.__dict__.update({a: MetaObj.copy_items(first_meta[a]) for a in keys if a in first_meta})\n        return self\n\n    @staticmethod\n    def get_default_meta() -> dict:\n        \"\"\"Get the default meta.\n\n        Returns:\n            default metadata.\n        \"\"\"\n        return {}\n\n    
@staticmethod\n    def get_default_applied_operations() -> list:\n        \"\"\"Get the default applied operations.\n\n        Returns:\n            default applied operations.\n        \"\"\"\n        return []\n\n    def __repr__(self) -> str:\n        \"\"\"String representation of class.\"\"\"\n        out: str = \"\\nMetadata\\n\"\n        if self.meta is not None:\n            out += \"\".join(f\"\\t{k}: {v}\\n\" for k, v in self.meta.items())\n        else:\n            out += \"None\"\n\n        out += \"\\nApplied operations\\n\"\n        if self.applied_operations is not None:\n            out += pprint.pformat(self.applied_operations, indent=2, compact=True, width=120)\n        else:\n            out += \"None\"\n\n        out += f\"\\nIs batch?: {self.is_batch}\"\n\n        return out\n\n    @property\n    def meta(self) -> dict:\n        \"\"\"Get the meta. Defaults to ``{}``.\"\"\"\n        return self._meta if hasattr(self, \"_meta\") else MetaObj.get_default_meta()\n\n    @meta.setter\n    def meta(self, d) -> None:\n        \"\"\"Set the meta.\"\"\"\n        if d == TraceKeys.NONE:\n            self._meta = MetaObj.get_default_meta()\n        else:\n            self._meta = d\n\n    @property\n    def applied_operations(self) -> list[dict]:\n        \"\"\"Get the applied operations. 
Defaults to ``[]``.\"\"\"\n        if hasattr(self, \"_applied_operations\"):\n            return self._applied_operations\n        return MetaObj.get_default_applied_operations()\n\n    @applied_operations.setter\n    def applied_operations(self, t) -> None:\n        \"\"\"Set the applied operations.\"\"\"\n        if t == TraceKeys.NONE:\n            # received no operations when decollating a batch\n            self._applied_operations = MetaObj.get_default_applied_operations()\n            return\n        self._applied_operations = t\n\n    def push_applied_operation(self, t: Any) -> None:\n        self._applied_operations.append(t)\n\n    def pop_applied_operation(self) -> Any:\n        return self._applied_operations.pop()\n\n    @property\n    def pending_operations(self) -> list[dict]:\n        \"\"\"Get the pending operations. Defaults to ``[]``.\"\"\"\n        if hasattr(self, \"_pending_operations\"):\n            return self._pending_operations\n        return MetaObj.get_default_applied_operations()  # the same default as applied_ops\n\n    @property\n    def has_pending_operations(self) -> bool:\n        \"\"\"\n        Determine whether there are pending operations.\n        Returns:\n            True if there are pending operations; False if not\n        \"\"\"\n        return self.pending_operations is not None and len(self.pending_operations) > 0\n\n    def push_pending_operation(self, t: Any) -> None:\n        self._pending_operations.append(t)\n\n    def pop_pending_operation(self) -> Any:\n        return self._pending_operations.pop()\n\n    def clear_pending_operations(self) -> Any:\n        self._pending_operations = MetaObj.get_default_applied_operations()\n\n    @property\n    def is_batch(self) -> bool:\n        \"\"\"Return whether object is part of batch or not.\"\"\"\n        return self._is_batch if hasattr(self, \"_is_batch\") else False\n\n    @is_batch.setter\n    def is_batch(self, val: bool) -> None:\n        \"\"\"Set whether 
object is part of batch or not.\"\"\"\n        self._is_batch = val\n"
  },
  {
    "path": "monai/data/meta_tensor.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport functools\nimport warnings\nfrom collections.abc import Sequence\nfrom copy import deepcopy\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nimport monai\nfrom monai.config.type_definitions import NdarrayTensor\nfrom monai.data.meta_obj import MetaObj, get_track_meta\nfrom monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata\nfrom monai.utils import look_up_option\nfrom monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys\nfrom monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor\n\n__all__ = [\"MetaTensor\"]\n\n\n@functools.lru_cache(None)\ndef _get_named_tuple_like_type(func):\n    if (\n        hasattr(torch, \"return_types\")\n        and hasattr(func, \"__name__\")\n        and hasattr(torch.return_types, func.__name__)\n        and isinstance(getattr(torch.return_types, func.__name__), type)\n    ):\n        return getattr(torch.return_types, func.__name__)\n    return None\n\n\ndef _not_requiring_metadata(ret):\n    return isinstance(ret, (int, str, bytes, torch.Size, torch.dtype, torch.device, np.ndarray)) or not (\n        isinstance(ret, MetaTensor) or (isinstance(ret, Sequence) and any(isinstance(x, MetaTensor) for x in ret))\n    )\n\n\nclass MetaTensor(MetaObj, torch.Tensor):\n    
\"\"\"\n    Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for metadata.\n\n    Metadata is stored in the form of a dictionary. Nested, an affine matrix will be\n    stored. This should be in the form of `torch.Tensor`.\n\n    Behavior should be the same as `torch.Tensor` aside from the extended\n    meta functionality.\n\n    Copying of information:\n\n        * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the\n          first instance of `MetaTensor` if `a.is_batch` is False\n          (For batched data, the metadata will be shallow copied for efficiency purposes).\n\n    Example:\n        .. code-block:: python\n\n            import torch\n            from monai.data import MetaTensor\n\n            t = torch.tensor([1,2,3])\n            affine = torch.as_tensor([[2,0,0,0],\n                                      [0,2,0,0],\n                                      [0,0,2,0],\n                                      [0,0,0,1]], dtype=torch.float64)\n            meta = {\"some\": \"info\"}\n            m = MetaTensor(t, affine=affine, meta=meta)\n            m2 = m + m\n            assert isinstance(m2, MetaTensor)\n            assert m2.meta[\"some\"] == \"info\"\n            assert torch.all(m2.affine == affine)\n\n    Notes:\n        - Requires pytorch 1.9 or newer for full compatibility.\n        - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may\n          not work if `im` is of type `MetaTensor`. 
This can be resolved with\n          `torch.jit.trace(net, im.as_tensor())`.\n        - For pytorch < 1.8, sharing `MetaTensor` instances across processes may not be supported.\n        - For pytorch < 1.9, next(iter(meta_tensor)) returns a torch.Tensor.\n          see: https://github.com/pytorch/pytorch/issues/54457\n        - A warning will be raised if in the constructor `affine` is not `None` and\n          `meta` already contains the key `affine`.\n        - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute.\n        - With a batch of data, `batch[0]` will return the 0th image\n          with the 0th metadata. When the batch dimension is non-singleton, e.g.,\n          `batch[:, 0]`, `batch[..., -1]` and `batch[1:3]`, then all (or a subset in the\n          last example) of the metadata will be returned, and `is_batch` will return `True`.\n        - When creating a batch with this class, use `monai.data.DataLoader` as opposed\n          to `torch.utils.data.DataLoader`, as this will take care of collating the\n          metadata properly.\n    \"\"\"\n\n    @staticmethod\n    def __new__(\n        cls,\n        x,\n        affine: torch.Tensor | None = None,\n        meta: dict | None = None,\n        applied_operations: list | None = None,\n        *args,\n        **kwargs,\n    ) -> MetaTensor:\n        _kwargs = {\"device\": kwargs.pop(\"device\", None), \"dtype\": kwargs.pop(\"dtype\", None)} if kwargs else {}\n        return torch.as_tensor(x, *args, **_kwargs).as_subclass(cls)\n\n    def __init__(\n        self,\n        x,\n        affine: torch.Tensor | None = None,\n        meta: dict | None = None,\n        applied_operations: list | None = None,\n        *_args,\n        **_kwargs,\n    ) -> None:\n        \"\"\"\n        Args:\n            x: initial array for the MetaTensor. 
Can be a list, tuple, NumPy ndarray, scalar, and other types.\n            affine: optional 4x4 array.\n            meta: dictionary of metadata.\n            applied_operations: list of previously applied operations on the MetaTensor,\n                the list is typically maintained by `monai.transforms.TraceableTransform`.\n                See also: :py:class:`monai.transforms.TraceableTransform`\n            _args: additional args (currently not in use in this constructor).\n            _kwargs: additional kwargs (currently not in use in this constructor).\n\n        Note:\n            If a `meta` dictionary is given, use it. Else, if `meta` exists in the input tensor `x`, use it.\n            Else, use the default value. Similar for the affine, except this could come from\n            four places, priority: `affine`, `meta[\"affine\"]`, `x.affine`, `get_default_affine`.\n\n        \"\"\"\n        super().__init__()\n        # set meta\n        if meta is not None:\n            self.meta = meta\n        elif isinstance(x, MetaObj):\n            self.__dict__ = deepcopy(x.__dict__)\n        # set the affine\n        if affine is not None:\n            if MetaKeys.AFFINE in self.meta:\n                warnings.warn(\"Setting affine, but the applied meta contains an affine. 
This will be overwritten.\")\n            self.affine = affine\n        elif MetaKeys.AFFINE in self.meta:\n            # by using the setter function, we ensure it is converted to torch.Tensor if not already\n            self.affine = self.meta[MetaKeys.AFFINE]\n        else:\n            self.affine = self.get_default_affine()\n        # applied_operations\n        if applied_operations is not None:\n            self.applied_operations = applied_operations\n        else:\n            self.applied_operations = MetaObj.get_default_applied_operations()\n\n        # if we are creating a new MetaTensor, then deep copy attributes\n        if isinstance(x, torch.Tensor) and not isinstance(x, MetaTensor):\n            self.copy_meta_from(self)\n\n        if MetaKeys.SPACE not in self.meta:\n            self.meta[MetaKeys.SPACE] = SpaceKeys.RAS  # defaulting to the right-anterior-superior space\n\n    @staticmethod\n    def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:\n        \"\"\"\n        Update the metadata from the output of `MetaTensor.__torch_function__`.\n\n        The output of `torch.Tensor.__torch_function__` could be a single object or a\n        sequence of them. Hence, in `MetaTensor.__torch_function__` we convert them to a\n        list of not already, and then we loop across each element, processing metadata\n        as necessary. For each element, if not of type `MetaTensor`, then nothing to do.\n\n        Args:\n            rets: the output from `torch.Tensor.__torch_function__`, which has been\n                converted to a list in `MetaTensor.__torch_function__` if it wasn't\n                already a `Sequence`.\n            func: the torch function that was applied. Examples might be `torch.squeeze`\n                or `torch.Tensor.__add__`. We need this since the metadata need to be\n                treated differently if a batch of data is considered. 
For example,\n                slicing (`torch.Tensor.__getitem__`) the ith element of the 0th\n                dimension of a batch of data should return a ith tensor with the ith\n                metadata.\n            args: positional arguments that were passed to `func`.\n            kwargs: keyword arguments that were passed to `func`.\n\n        Returns:\n            A sequence with the same number of elements as `rets`. For each element, if\n            the input type was not `MetaTensor`, then no modifications will have been\n            made. If global parameters have been set to false (e.g.,\n            `not get_track_meta()`), then any `MetaTensor` will be converted to\n            `torch.Tensor`. Else, metadata will be propagated as necessary (see\n            :py:func:`MetaTensor._copy_meta`).\n        \"\"\"\n        out = []\n        metas = None  # optional output metadicts for each of the return value in `rets`\n        is_batch = any(x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, \"is_batch\"))\n        for idx, ret in enumerate(rets):\n            # if not `MetaTensor`, nothing to do.\n            if not isinstance(ret, MetaTensor):\n                pass\n            # if not tracking, convert to `torch.Tensor`.\n            elif not get_track_meta():\n                ret = ret.as_tensor()\n            # else, handle the `MetaTensor` metadata.\n            else:\n                meta_args = MetaObj.flatten_meta_objs(args, kwargs.values())\n                ret.is_batch = is_batch\n                ret.copy_meta_from(meta_args, copy_attr=not is_batch)\n                # the following is not implemented but the network arch may run into this case:\n                # if func == torch.cat and any(m.is_batch if hasattr(m, \"is_batch\") else False for m in meta_args):\n                #     raise NotImplementedError(\"torch.cat is not implemented for batch of MetaTensors.\")\n                if is_batch:\n              
      ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs)\n            out.append(ret)\n        # if the input was a tuple, then return it as a tuple\n        return tuple(out) if isinstance(rets, tuple) else out\n\n    @classmethod\n    def _handle_batched(cls, ret, idx, metas, func, args, kwargs):\n        \"\"\"utility function to handle batched MetaTensors.\"\"\"\n        # If we have a batch of data, then we need to be careful if a slice of\n        # the data is returned. Depending on how the data are indexed, we return\n        # some or all of the metadata, and the return object may or may not be a\n        # batch of data (e.g., `batch[:,-1]` versus `batch[0]`).\n        # if indexing e.g., `batch[0]`\n        if func == torch.Tensor.__getitem__:\n            if idx > 0 or len(args) < 2 or len(args[0]) < 1:\n                return ret\n            batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1]\n            # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the\n            # first element will be `slice(None, None, None)` and `Ellipsis`,\n            # respectively. Don't need to do anything with the metadata.\n            if batch_idx in (slice(None, None, None), Ellipsis, None) or isinstance(batch_idx, torch.Tensor):\n                return ret\n            dec_batch = decollate_batch(args[0], detach=False)\n            ret_meta = dec_batch[batch_idx]\n            if isinstance(ret_meta, list) and ret_meta:  # e.g. 
batch[0:2], re-collate\n                try:\n                    ret_meta = list_data_collate(ret_meta)\n                except (TypeError, ValueError, RuntimeError, IndexError) as e:\n                    raise ValueError(\n                        \"Inconsistent batched metadata dicts when slicing a batch of MetaTensors, \"\n                        \"please consider converting it into a torch Tensor using `x.as_tensor()` or \"\n                        \"a numpy array using `x.array`.\"\n                    ) from e\n            elif isinstance(ret_meta, MetaObj):  # e.g. `batch[0]` or `batch[0, 1]`, batch_idx is int\n                ret_meta.is_batch = False\n            if hasattr(ret_meta, \"__dict__\"):\n                ret.__dict__ = ret_meta.__dict__.copy()\n        # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.\n        # But we only want to split the batch if the `unbind` is along the 0th dimension.\n        elif func == torch.Tensor.unbind:\n            if len(args) > 1:\n                dim = args[1]\n            elif \"dim\" in kwargs:\n                dim = kwargs[\"dim\"]\n            else:\n                dim = 0\n            if dim == 0:\n                if metas is None:\n                    metas = decollate_batch(args[0], detach=False)\n                if hasattr(metas[idx], \"__dict__\"):\n                    ret.__dict__ = metas[idx].__dict__.copy()\n                ret.is_batch = False\n        return ret\n\n    @classmethod\n    def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:\n        \"\"\"Wraps all torch functions.\"\"\"\n        if kwargs is None:\n            kwargs = {}\n        ret = super().__torch_function__(func, types, args, kwargs)\n        # if `out` has been used as argument, metadata is not copied, nothing to do.\n        # if \"out\" in kwargs:\n        #     return ret\n        if _not_requiring_metadata(ret):\n            return ret\n        if _get_named_tuple_like_type(func) 
is not None and isinstance(ret, _get_named_tuple_like_type(func)):\n            # for torch.max(torch.tensor(1.0), dim=0), the return type is named-tuple like\n            out_items = MetaTensor.update_meta(ret, func, args, kwargs)\n            for idx in range(ret.n_fields):\n                ret[idx].meta = out_items[idx].meta\n                ret[idx].applied_operations = out_items[idx].applied_operations\n            return ret\n        # we might have 1 or multiple outputs. Might be MetaTensor, might be something\n        # else (e.g., `__repr__` returns a string).\n        # Convert to list (if necessary), process, and at end remove list if one was added.\n        if not isinstance(ret, Sequence):\n            ret = [ret]\n            unpack = True\n        else:\n            unpack = False\n        ret = MetaTensor.update_meta(ret, func, args, kwargs)\n        return ret[0] if unpack else ret\n\n    @staticmethod\n    def _convert(x):\n        if isinstance(x, (MetaTensor, torch.Tensor, tuple, list)):\n            return convert_data_type(x, output_type=np.ndarray, wrap_sequence=False)[0]\n        return x\n\n    def __array_function__(self, func, types, args, kwargs):\n        \"\"\"for numpy Interoperability, so that we can compute ``np.sum(MetaTensor([1.0]))``.\"\"\"\n        try:\n            if not func.__module__.startswith(\"numpy\"):\n                return NotImplemented\n        except AttributeError:\n            return NotImplemented\n        _args = list(map(MetaTensor._convert, args))\n        _kwargs = {k: MetaTensor._convert(v) for k, v in kwargs.items()}\n        return func(*_args, **_kwargs)\n\n    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):\n        \"\"\"\n        For numpy interoperability, so that we can compute ``MetaTensor([1.0]) >= np.asarray([1.0])``.\n        This is for pytorch > 1.8.\n        \"\"\"\n        try:\n            if not type(ufunc).__module__.startswith(\"numpy\"):\n                return 
NotImplemented\n        except AttributeError:\n            return NotImplemented\n        if method != \"__call__\":\n            return NotImplemented\n        _inputs = map(MetaTensor._convert, inputs)\n        _kwargs = {k: MetaTensor._convert(v) for k, v in kwargs.items()}\n        if \"out\" in _kwargs:\n            return NotImplemented  # not supported\n        try:\n            return getattr(ufunc, method)(*_inputs, **_kwargs)\n        except AttributeError:\n            return NotImplemented\n\n    @staticmethod\n    def get_default_affine(dtype=torch.float64) -> torch.Tensor:\n        return torch.eye(4, device=torch.device(\"cpu\"), dtype=dtype)\n\n    def as_tensor(self) -> torch.Tensor:\n        \"\"\"\n        Return the `MetaTensor` as a `torch.Tensor`.\n        It is OS dependent as to whether this will be a deep copy or not.\n        \"\"\"\n        return self.as_subclass(torch.Tensor)\n\n    def get_array(self, output_type=np.ndarray, dtype=None, device=None, *_args, **_kwargs):\n        \"\"\"\n        Returns a new array in `output_type`, the array shares the same underlying storage when the output is a\n        numpy array. Changes to self tensor will be reflected in the ndarray and vice versa.\n\n        Args:\n            output_type: output type, see also: :py:func:`monai.utils.convert_data_type`.\n            dtype: dtype of output data. 
Converted to correct library type (e.g.,\n                `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).\n                If left blank, it remains unchanged.\n            device: if the output is a `torch.Tensor`, select device (if `None`, unchanged).\n            _args: currently unused parameters.\n            _kwargs: currently unused parameters.\n        \"\"\"\n        return convert_data_type(self, output_type=output_type, dtype=dtype, device=device, wrap_sequence=True)[0]\n\n    def set_array(self, src, non_blocking: bool = False, *_args, **_kwargs):\n        \"\"\"\n        Copies the elements from src into self tensor and returns self.\n        The src tensor must be broadcastable with the self tensor.\n        It may be of a different data type or reside on a different device.\n\n        See also: `https://pytorch.org/docs/stable/generated/torch.Tensor.copy_.html`\n\n        Args:\n            src: the source tensor to copy from.\n            non_blocking: if True and this copy is between CPU and GPU, the copy may occur\n                asynchronously with respect to the host. For other cases, this argument has no effect.\n            _args: currently unused parameters.\n            _kwargs:  currently unused parameters.\n        \"\"\"\n        converted: torch.Tensor = convert_to_tensor(src, track_meta=False, wrap_sequence=True)\n        try:\n            return self.copy_(converted, non_blocking=non_blocking)\n        except RuntimeError:  # skip the shape checking\n            self.data = converted\n            return self\n\n    @property\n    def array(self):\n        \"\"\"\n        Returns a numpy array of ``self``. 
The array and ``self`` shares the same underlying storage if self is on cpu.\n        Changes to ``self`` (it's a subclass of torch.Tensor) will be reflected in the ndarray and vice versa.\n        If ``self`` is not on cpu, the call will move the array to cpu and then the storage is not shared.\n\n        :getter: see also: :py:func:`MetaTensor.get_array()`\n        :setter: see also: :py:func:`MetaTensor.set_array()`\n        \"\"\"\n        return self.get_array()\n\n    @array.setter\n    def array(self, src) -> None:\n        \"\"\"A default setter using ``self.set_array()``\"\"\"\n        self.set_array(src)\n\n    def as_dict(self, key: str, output_type=torch.Tensor, dtype=None) -> dict:\n        \"\"\"\n        Get the object as a dictionary for backwards compatibility.\n        This method does not make a deep copy of the objects.\n\n        Args:\n            key: Base key to store main data. The key for the metadata will be determined using `PostFix`.\n            output_type: `torch.Tensor` or `np.ndarray` for the main data.\n            dtype: dtype of output data. 
Converted to correct library type (e.g.,\n                `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).\n                If left blank, it remains unchanged.\n\n        Return:\n            A dictionary consisting of three keys, the main data (stored under `key`) and the metadata.\n        \"\"\"\n        if output_type not in (torch.Tensor, np.ndarray):\n            raise ValueError(f\"output_type must be torch.Tensor or np.ndarray, got {output_type}.\")\n        return {\n            key: self.get_array(output_type=output_type, dtype=dtype),\n            PostFix.meta(key): self.meta,\n            PostFix.transforms(key): self.applied_operations,\n        }\n\n    def astype(self, dtype, device=None, *_args, **_kwargs):\n        \"\"\"\n        Cast to ``dtype``, sharing data whenever possible.\n\n        Args:\n            dtype: dtypes such as np.float32, torch.float, \"np.float32\", float.\n            device: the device if `dtype` is a torch data type.\n            _args: additional args (currently unused).\n            _kwargs: additional kwargs (currently unused).\n\n        Returns:\n            data array instance\n        \"\"\"\n        if isinstance(dtype, str):\n            mod_str, *dtype = dtype.split(\".\", 1)\n            dtype = mod_str if not dtype else dtype[0]\n        else:\n            mod_str = getattr(dtype, \"__module__\", \"torch\")\n        mod_str = look_up_option(mod_str, {\"torch\", \"numpy\", \"np\"}, default=\"numpy\")\n\n        out_type: type[torch.Tensor] | type[np.ndarray] | None\n        if mod_str == \"torch\":\n            out_type = torch.Tensor\n        elif mod_str in (\"numpy\", \"np\"):\n            out_type = np.ndarray\n        else:\n            out_type = None\n        return self.get_array(output_type=out_type, dtype=dtype, device=device)\n\n    @property\n    def affine(self) -> torch.Tensor:\n        \"\"\"Get the affine. 
Defaults to ``torch.eye(4, dtype=torch.float64)``\"\"\"\n        return self.meta.get(MetaKeys.AFFINE, self.get_default_affine())  # type: ignore\n\n    @affine.setter\n    def affine(self, d: NdarrayTensor) -> None:\n        \"\"\"Set the affine.\"\"\"\n        self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device(\"cpu\"), dtype=torch.float64)\n\n    @property\n    def pixdim(self):\n        \"\"\"Get the spacing\"\"\"\n        if self.is_batch:\n            return [affine_to_spacing(a) for a in self.affine]\n        return affine_to_spacing(self.affine)\n\n    def peek_pending_shape(self):\n        \"\"\"\n        Get the currently expected spatial shape as if all the pending operations are executed.\n        For tensors that have more than 3 spatial dimensions, only the shapes of the top 3 dimensions will be returned.\n        \"\"\"\n        res = None\n        if self.pending_operations:\n            res = self.pending_operations[-1].get(LazyAttr.SHAPE, None)\n        # default to spatial shape (assuming channel-first input)\n        return tuple(convert_to_numpy(self.shape, wrap_sequence=True).tolist()[1:]) if res is None else res\n\n    def peek_pending_affine(self):\n        res = self.affine\n        r = len(res) - 1\n        if r not in (2, 3):\n            warnings.warn(f\"Only 2d and 3d affine are supported, got {r}d input.\")\n        for p in self.pending_operations:\n            next_matrix = convert_to_tensor(p.get(LazyAttr.AFFINE), dtype=torch.float64)\n            if next_matrix is None:\n                continue\n            res = convert_to_dst_type(res, next_matrix)[0]\n            next_matrix = monai.data.utils.to_affine_nd(r, next_matrix)\n            res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix)\n        return res\n\n    def peek_pending_rank(self):\n        a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine\n        return 1 if a is None else 
int(max(1, len(a) - 1))\n\n    def new_empty(self, size, dtype=None, device=None, requires_grad=False):  # type: ignore[override]\n        \"\"\"\n        must be defined for deepcopy to work\n\n        See:\n            - https://pytorch.org/docs/stable/generated/torch.Tensor.new_empty.html#torch-tensor-new-empty\n        \"\"\"\n        return type(self)(\n            self.as_tensor().new_empty(size=size, dtype=dtype, device=device, requires_grad=requires_grad)\n        )\n\n    def clone(self, **kwargs):\n        \"\"\"\n        Returns a copy of the MetaTensor instance.\n\n        Args:\n            kwargs: additional keyword arguments to `torch.clone`.\n\n        See also: https://pytorch.org/docs/stable/generated/torch.clone.html\n        \"\"\"\n        new_inst = MetaTensor(self.as_tensor().clone(**kwargs))\n        new_inst.__dict__ = deepcopy(self.__dict__)\n        return new_inst\n\n    @staticmethod\n    def ensure_torch_and_prune_meta(\n        im: NdarrayTensor, meta: dict | None, simple_keys: bool = False, pattern: str | None = None, sep: str = \".\"\n    ):\n        \"\"\"\n        Convert the image to MetaTensor (when meta is not None). If `affine` is in the `meta` dictionary,\n        convert that to `torch.Tensor`, too. Remove any superfluous metadata.\n\n        Args:\n            im: Input image (`np.ndarray` or `torch.Tensor`)\n            meta: Metadata dictionary. When it's None, the metadata is not tracked, this method returns a torch.Tensor.\n            simple_keys: whether to keep only a simple subset of metadata keys.\n            pattern: combined with `sep`, a regular expression used to match and prune keys\n                in the metadata (nested dictionary), default to None, no key deletion.\n            sep: combined with `pattern`, used to match and delete keys in the metadata (nested dictionary).\n                default is \".\", see also :py:class:`monai.transforms.DeleteItemsd`.\n                e.g. 
``pattern=\".*_code$\", sep=\" \"`` removes any meta keys that ends with ``\"_code\"``.\n\n        Returns:\n            By default, a `MetaTensor` is returned.\n            However, if `get_track_meta()` is `False` or meta=None, a `torch.Tensor` is returned.\n        \"\"\"\n        img = convert_to_tensor(im, track_meta=get_track_meta() and meta is not None)  # potentially ascontiguousarray\n        # if not tracking metadata, return `torch.Tensor`\n        if not isinstance(img, MetaTensor):\n            return img\n\n        if meta is None:\n            meta = {}\n\n        # remove any superfluous metadata.\n        if simple_keys:\n            # ensure affine is of type `torch.Tensor`\n            if MetaKeys.AFFINE in meta:\n                meta[MetaKeys.AFFINE] = convert_to_tensor(meta[MetaKeys.AFFINE])  # bc-breaking\n            remove_extra_metadata(meta)  # bc-breaking\n\n        if pattern is not None:\n            meta = monai.transforms.DeleteItemsd(keys=pattern, sep=sep, use_re=True)(meta)\n\n        # return the `MetaTensor`\n        if meta is None:\n            meta = {}\n        img.meta = meta\n        if MetaKeys.AFFINE in meta:\n            img.affine = meta[MetaKeys.AFFINE]  # this uses the affine property setter\n        else:\n            img.affine = MetaTensor.get_default_affine()\n        return img\n\n    def __repr__(self):  # type: ignore[override]\n        \"\"\"\n        Prints a representation of the tensor.\n        Prepends \"meta\" to ``torch.Tensor.__repr__``.\n        Use ``print_verbose`` for associated metadata.\n        \"\"\"\n        return f\"meta{self.as_tensor().__repr__()}\"\n\n    def __str__(self):\n        \"\"\"\n        Prints a representation of the tensor.\n        Prepends \"meta\" to ``torch.Tensor.__str__``.\n        Use ``print_verbose`` for associated metadata.\n        \"\"\"\n        return f\"meta{str(self.as_tensor())}\"\n\n    def __format__(self, format_spec):\n        \"\"\"\n        returns the 
output of pytorch tensor's ``__format__`` method.\n        \"\"\"\n        return self.as_tensor().__format__(format_spec)\n\n    def print_verbose(self) -> None:\n        \"\"\"Verbose print with meta data.\"\"\"\n        print(self)\n        if self.meta is not None:\n            print(self.meta.__repr__())\n\n\n# needed in later versions of Pytorch to indicate the class is safe for serialisation\nif hasattr(torch.serialization, \"add_safe_globals\"):\n    torch.serialization.add_safe_globals([MetaObj, MetaTensor, MetaKeys, SpaceKeys])\n"
  },
  {
    "path": "monai/data/samplers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nfrom torch.utils.data import Dataset\nfrom torch.utils.data import DistributedSampler as _TorchDistributedSampler\n\n__all__ = [\"DistributedSampler\", \"DistributedWeightedRandomSampler\"]\n\n\nclass DistributedSampler(_TorchDistributedSampler):\n    \"\"\"\n    Enhance PyTorch DistributedSampler to support non-evenly divisible sampling.\n\n    Args:\n        dataset: Dataset used for sampling.\n        even_divisible: if False, different ranks can have different data length.\n            for example, input data: [1, 2, 3, 4, 5], rank 0: [1, 3, 5], rank 1: [2, 4].\n        num_replicas: number of processes participating in distributed training.\n            by default, `world_size` is retrieved from the current distributed group.\n        rank: rank of the current process within `num_replicas`. 
by default,\n            `rank` is retrieved from the current distributed group.\n        shuffle: if `True`, sampler will shuffle the indices, default to True.\n        kwargs: additional arguments for `DistributedSampler` super class, can be `seed` and `drop_last`.\n\n    More information about DistributedSampler, please check:\n    https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset: Dataset,\n        even_divisible: bool = True,\n        num_replicas: int | None = None,\n        rank: int | None = None,\n        shuffle: bool = True,\n        **kwargs,\n    ):\n        super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, **kwargs)\n\n        if not even_divisible:\n            data_len = len(dataset)  # type: ignore\n            if data_len < self.num_replicas:\n                raise ValueError(\"the dataset length is less than the number of participating ranks.\")\n            extra_size = self.total_size - data_len\n            if self.rank + extra_size >= self.num_replicas:\n                self.num_samples -= 1\n            self.total_size = data_len\n\n\nclass DistributedWeightedRandomSampler(DistributedSampler):\n    \"\"\"\n    Extend the `DistributedSampler` to support weighted sampling.\n    Refer to `torch.utils.data.WeightedRandomSampler`, for more details please check:\n    https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler.\n\n    Args:\n        dataset: Dataset used for sampling.\n        weights: a sequence of weights, not necessary summing up to one, length should exactly\n            match the full dataset.\n        num_samples_per_rank: number of samples to draw for every rank, sample from\n            the distributed subset of dataset.\n            if None, default to the length of dataset split by DistributedSampler.\n        generator: PyTorch Generator used in 
sampling.\n        even_divisible: if False, different ranks can have different data length.\n            for example, input data: [1, 2, 3, 4, 5], rank 0: [1, 3, 5], rank 1: [2, 4].\n        num_replicas: number of processes participating in distributed training.\n            by default, `world_size` is retrieved from the current distributed group.\n        rank: rank of the current process within `num_replicas`. by default,\n            `rank` is retrieved from the current distributed group.\n        kwargs: additional arguments for `DistributedSampler` super class, can be `seed` and `drop_last`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset: Dataset,\n        weights: Sequence[float],\n        num_samples_per_rank: int | None = None,\n        generator: torch.Generator | None = None,\n        even_divisible: bool = True,\n        num_replicas: int | None = None,\n        rank: int | None = None,\n        **kwargs,\n    ):\n        kwargs.setdefault(\"shuffle\", True)\n        super().__init__(dataset=dataset, even_divisible=even_divisible, num_replicas=num_replicas, rank=rank, **kwargs)\n        self.weights = weights\n        self.num_samples_per_rank = num_samples_per_rank if num_samples_per_rank is not None else self.num_samples\n        self.generator = generator\n\n    def __iter__(self):\n        indices = list(super().__iter__())\n        weights = torch.as_tensor([self.weights[i] for i in indices], dtype=torch.double)\n        # sample based on the provided weights\n        rand_tensor = torch.multinomial(weights, self.num_samples_per_rank, True, generator=self.generator)\n\n        for i in rand_tensor:\n            yield indices[i]\n\n    def __len__(self):\n        return self.num_samples_per_rank\n"
  },
  {
    "path": "monai/data/synthetic.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport numpy as np\n\nfrom monai.transforms.utils import rescale_array\n\n__all__ = [\"create_test_image_2d\", \"create_test_image_3d\"]\n\n\ndef create_test_image_2d(\n    height: int,\n    width: int,\n    num_objs: int = 12,\n    rad_max: int = 30,\n    rad_min: int = 5,\n    noise_max: float = 0.0,\n    num_seg_classes: int = 5,\n    channel_dim: int | None = None,\n    random_state: np.random.RandomState | None = None,\n) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"\n    Return a noisy 2D image with `num_objs` circles and a 2D mask image. The maximum and minimum radii of the circles\n    are given as `rad_max` and `rad_min`. The mask will have `num_seg_classes` number of classes for segmentations labeled\n    sequentially from 1, plus a background class represented as 0. If `noise_max` is greater than 0 then noise will be\n    added to the image taken from the uniform distribution on range `[0,noise_max)`. If `channel_dim` is None, will create\n    an image without channel dimension, otherwise create an image with channel dimension as first dim or last dim.\n\n    Args:\n        height: height of the image. The value should be larger than `2 * rad_max`.\n        width: width of the image. The value should be larger than `2 * rad_max`.\n        num_objs: number of circles to generate. 
Defaults to `12`.\n        rad_max: maximum circle radius. Defaults to `30`.\n        rad_min: minimum circle radius. Defaults to `5`.\n        noise_max: if greater than 0 then noise will be added to the image taken from\n            the uniform distribution on range `[0,noise_max)`. Defaults to `0`.\n        num_seg_classes: number of classes for segmentations. Defaults to `5`.\n        channel_dim: if None, create an image without channel dimension, otherwise create\n            an image with channel dimension as first dim or last dim. Defaults to `None`.\n        random_state: the random generator to use. Defaults to `np.random`.\n\n    Returns:\n        Randomised Numpy array with shape (`height`, `width`)\n    \"\"\"\n\n    if rad_max <= rad_min:\n        raise ValueError(f\"`rad_min` {rad_min} should be less than `rad_max` {rad_max}.\")\n    if rad_min < 1:\n        raise ValueError(f\"`rad_min` {rad_min} should be no less than 1.\")\n    min_size = min(height, width)\n    if min_size <= 2 * rad_max:\n        raise ValueError(f\"the minimal size {min_size} of the image should be larger than `2 * rad_max` 2x{rad_max}.\")\n\n    image = np.zeros((height, width))\n    rs: np.random.RandomState = np.random.random.__self__ if random_state is None else random_state  # type: ignore\n\n    for _ in range(num_objs):\n        x = rs.randint(rad_max, height - rad_max)\n        y = rs.randint(rad_max, width - rad_max)\n        rad = rs.randint(rad_min, rad_max)\n        spy, spx = np.ogrid[-x : height - x, -y : width - y]\n        circle = (spx * spx + spy * spy) <= rad * rad\n\n        if num_seg_classes > 1:\n            image[circle] = np.ceil(rs.random() * num_seg_classes)\n        else:\n            image[circle] = rs.random() * 0.5 + 0.5\n\n    labels = np.ceil(image).astype(np.int32, copy=False)\n\n    norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape)\n    noisyimage: np.ndarray = rescale_array(np.maximum(image, norm))  # type: ignore\n\n    if 
channel_dim is not None:\n        if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 2)):\n            raise AssertionError(\"invalid channel dim.\")\n        if channel_dim == 0:\n            noisyimage = noisyimage[None]\n            labels = labels[None]\n        else:\n            noisyimage = noisyimage[..., None]\n            labels = labels[..., None]\n\n    return noisyimage, labels\n\n\ndef create_test_image_3d(\n    height: int,\n    width: int,\n    depth: int,\n    num_objs: int = 12,\n    rad_max: int = 30,\n    rad_min: int = 5,\n    noise_max: float = 0.0,\n    num_seg_classes: int = 5,\n    channel_dim: int | None = None,\n    random_state: np.random.RandomState | None = None,\n) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"\n    Return a noisy 3D image and segmentation.\n\n    Args:\n        height: height of the image. The value should be larger than `2 * rad_max`.\n        width: width of the image. The value should be larger than `2 * rad_max`.\n        depth: depth of the image. The value should be larger than `2 * rad_max`.\n        num_objs: number of circles to generate. Defaults to `12`.\n        rad_max: maximum circle radius. Defaults to `30`.\n        rad_min: minimum circle radius. Defaults to `5`.\n        noise_max: if greater than 0 then noise will be added to the image taken from\n            the uniform distribution on range `[0,noise_max)`. Defaults to `0`.\n        num_seg_classes: number of classes for segmentations. Defaults to `5`.\n        channel_dim: if None, create an image without channel dimension, otherwise create\n            an image with channel dimension as first dim or last dim. Defaults to `None`.\n        random_state: the random generator to use. 
Defaults to `np.random`.\n\n    Returns:\n        Randomised Numpy array with shape (`height`, `width`, `depth`)\n\n    See also:\n        :py:meth:`~create_test_image_2d`\n    \"\"\"\n\n    if rad_max <= rad_min:\n        raise ValueError(f\"`rad_min` {rad_min} should be less than `rad_max` {rad_max}.\")\n    if rad_min < 1:\n        raise ValueError(\"f`rad_min` {rad_min} should be no less than 1.\")\n    min_size = min(height, width, depth)\n    if min_size <= 2 * rad_max:\n        raise ValueError(f\"the minimal size {min_size} of the image should be larger than `2 * rad_max` 2x{rad_max}.\")\n\n    image = np.zeros((height, width, depth))\n    rs: np.random.RandomState = np.random.random.__self__ if random_state is None else random_state  # type: ignore\n\n    for _ in range(num_objs):\n        x = rs.randint(rad_max, height - rad_max)\n        y = rs.randint(rad_max, width - rad_max)\n        z = rs.randint(rad_max, depth - rad_max)\n        rad = rs.randint(rad_min, rad_max)\n        spy, spx, spz = np.ogrid[-x : height - x, -y : width - y, -z : depth - z]\n        circle = (spx * spx + spy * spy + spz * spz) <= rad * rad\n\n        if num_seg_classes > 1:\n            image[circle] = np.ceil(rs.random() * num_seg_classes)\n        else:\n            image[circle] = rs.random() * 0.5 + 0.5\n\n    labels = np.ceil(image).astype(np.int32, copy=False)\n\n    norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape)\n    noisyimage: np.ndarray = rescale_array(np.maximum(image, norm))  # type: ignore\n\n    if channel_dim is not None:\n        if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 3)):\n            raise AssertionError(\"invalid channel dim.\")\n        noisyimage, labels = (\n            (noisyimage[None], labels[None]) if channel_dim == 0 else (noisyimage[..., None], labels[..., None])\n        )\n\n    return noisyimage, labels\n"
  },
  {
    "path": "monai/data/test_time_augmentation.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable\nfrom copy import deepcopy\nfrom typing import TYPE_CHECKING, Any\n\nimport torch\n\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.dataloader import DataLoader\nfrom monai.data.dataset import Dataset\nfrom monai.data.utils import decollate_batch, pad_list_data_collate\nfrom monai.transforms.compose import Compose\nfrom monai.transforms.croppad.batch import PadListDataCollate\nfrom monai.transforms.inverse import InvertibleTransform\nfrom monai.transforms.post.dictionary import Invertd\nfrom monai.transforms.transform import Randomizable\nfrom monai.transforms.utils_pytorch_numpy_unification import mode, stack\nfrom monai.utils import CommonKeys, PostFix, optional_import\n\nif TYPE_CHECKING:\n    from tqdm import tqdm\n\n    has_tqdm = True\nelse:\n    tqdm, has_tqdm = optional_import(\"tqdm\", name=\"tqdm\")\n\n__all__ = [\"TestTimeAugmentation\"]\n\nDEFAULT_POST_FIX = PostFix.meta()\n\n\ndef _identity(x):\n    return x\n\n\nclass TestTimeAugmentation:\n    \"\"\"\n    Class for performing test time augmentations. This will pass the same image through the network multiple times.\n\n    The user passes transform(s) to be applied to each realization, and provided that at least one of those transforms\n    is random, the network's output will vary. 
Provided that inverse transformations exist for all supplied spatial\n    transforms, the inverse can be applied to each realization of the network's output. Once in the same spatial\n    reference, the results can then be combined and metrics computed.\n\n    Test time augmentations are a useful feature for computing network uncertainty, as well as observing the network's\n    dependency on the applied random transforms.\n\n    Reference:\n        Wang et al.,\n        Aleatoric uncertainty estimation with test-time augmentation for medical image segmentation with convolutional\n        neural networks,\n        https://doi.org/10.1016/j.neucom.2019.01.103\n\n    Args:\n        transform: transform (or composed) to be applied to each realization. At least one transform must be of type\n        `RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`).\n        When `apply_inverse_to_pred` is True, all random transforms must be of type `InvertibleTransform`.\n        batch_size: number of realizations to infer at once.\n        num_workers: how many subprocesses to use for data.\n        inferrer_fn: function to use to perform inference.\n        device: device on which to perform inference.\n        image_key: key used to extract image from input dictionary.\n        orig_key: the key of the original input data in the dict. 
will get the applied transform information\n            for this input data, then invert them for the expected data with `image_key`.\n        orig_meta_keys: the key of the metadata of original input data, will get the `affine`, `data_shape`, etc.\n            the metadata is a dictionary object which contains: filename, original_shape, etc.\n            if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`.\n        meta_key_postfix: use `key_{postfix}` to fetch the metadata according to the key data,\n            default is `meta_dict`, the metadata is a dictionary object.\n            For example, to handle key `image`,  read/write affine matrices from the\n            metadata `image_meta_dict` dictionary's `affine` field.\n            this arg only works when `meta_keys=None`.\n        to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.\n        output_device: if converted the inverted data to Tensor, move the inverted results to target device\n            before `post_func`, default to \"cpu\".\n        post_func: post processing for the inverted data, should be a callable function.\n        return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True`\n            will return the full data. Dimensions will be same size as when passing a single image through\n            `inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.\n        progress: whether to display a progress bar.\n        apply_inverse_to_pred: whether to apply inverse transformations to the predictions.\n            If the model's prediction is spatial (e.g. segmentation), this should be `True` to map the predictions\n            back to the original spatial reference.\n            If the prediction is non-spatial (e.g. classification label or score), this should be `False` to\n            aggregate the raw predictions directly. 
Defaults to `True`.\n\n    Example:\n        .. code-block:: python\n\n            model = UNet(...).to(device)\n            transform = Compose([RandAffined(keys, ...), ...])\n            transform.set_random_state(seed=123)  # ensure deterministic evaluation\n\n            tt_aug = TestTimeAugmentation(\n                transform, batch_size=5, num_workers=0, inferrer_fn=model, device=device\n            )\n            mode, mean, std, vvc = tt_aug(test_data)\n    \"\"\"\n\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def __init__(\n        self,\n        transform: InvertibleTransform,\n        batch_size: int,\n        num_workers: int = 0,\n        inferrer_fn: Callable = _identity,\n        device: str | torch.device = \"cpu\",\n        image_key=CommonKeys.IMAGE,\n        orig_key=CommonKeys.LABEL,\n        nearest_interp: bool = True,\n        orig_meta_keys: str | None = None,\n        meta_key_postfix=DEFAULT_POST_FIX,\n        to_tensor: bool = True,\n        output_device: str | torch.device = \"cpu\",\n        post_func: Callable = _identity,\n        return_full_data: bool = False,\n        progress: bool = True,\n        apply_inverse_to_pred: bool = True,\n    ) -> None:\n        self.transform = transform\n        self.batch_size = batch_size\n        self.num_workers = num_workers\n        self.inferrer_fn = inferrer_fn\n        self.device = device\n        self.image_key = image_key\n        self.return_full_data = return_full_data\n        self.progress = progress\n        self.apply_inverse_to_pred = apply_inverse_to_pred\n        self._pred_key = CommonKeys.PRED\n        self.inverter = Invertd(\n            keys=self._pred_key,\n            transform=transform,\n            orig_keys=orig_key,\n            orig_meta_keys=orig_meta_keys,\n            meta_key_postfix=meta_key_postfix,\n            nearest_interp=nearest_interp,\n            to_tensor=to_tensor,\n            
device=output_device,\n            post_func=post_func,\n        )\n\n        # check that the transform has at least one random component, and that all random transforms are invertible\n        self._check_transforms()\n\n    def _check_transforms(self):\n        \"\"\"Should be at least 1 random transform, and all random transforms should be invertible.\"\"\"\n        transforms = [self.transform] if not isinstance(self.transform, Compose) else self.transform.transforms\n        warns = []\n        randoms = []\n\n        for idx, t in enumerate(transforms):\n            if isinstance(t, Randomizable):\n                randoms.append(t)\n                if self.apply_inverse_to_pred and not isinstance(t, InvertibleTransform):\n                    warns.append(f\"Transform #{idx} (type {type(t).__name__}) is random but not invertible.\")\n\n        if len(randoms) == 0:\n            warns.append(\"TTA usually requires at least one `Randomizable` transform in the given transform sequence.\")\n\n        if len(warns) > 0:\n            warnings.warn(\n                \"TTA has encountered issues with the given transforms:\\n  \" + \"\\n  \".join(warns), stacklevel=2\n            )\n\n    def __call__(\n        self, data: dict[str, Any], num_examples: int = 10\n    ) -> tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float] | NdarrayOrTensor:\n        \"\"\"\n        Args:\n            data: dictionary data to be processed.\n            num_examples: number of realizations to be processed and results combined.\n\n        Returns:\n            - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are\n                calculated across `num_examples` outputs at each voxel. The volume variation coefficient (VVC)\n                is `std/mean` across the whole output, including `num_examples`. 
See original paper for clarification.\n            - if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then\n                concatenating across the first dimension containing `num_examples`. This allows the user to perform\n                their own analysis if desired.\n        \"\"\"\n        d = dict(data)\n\n        # check num examples is multiple of batch size\n        if num_examples % self.batch_size != 0:\n            raise ValueError(\"num_examples should be multiple of batch size.\")\n\n        # generate batch of data of size == batch_size, dataset and dataloader\n        data_in = [deepcopy(d) for _ in range(num_examples)]\n        ds = Dataset(data_in, self.transform)\n        dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate)\n\n        outs: list = []\n\n        for b in tqdm(dl) if has_tqdm and self.progress else dl:\n            # do model forward pass\n            b[self._pred_key] = self.inferrer_fn(b[self.image_key].to(self.device))\n            if self.apply_inverse_to_pred:\n                outs.extend([self.inverter(PadListDataCollate.inverse(i))[self._pred_key] for i in decollate_batch(b)])\n            else:\n                outs.extend([i[self._pred_key] for i in decollate_batch(b)])\n\n        output: NdarrayOrTensor = stack(outs, 0)\n\n        if self.return_full_data:\n            return output\n\n        # calculate metrics\n        _mode = mode(output, dim=0)\n        mean = output.mean(0)\n        std = output.std(0)\n        vvc = (output.std() / output.mean()).item()\n\n        return _mode, mean, std, vvc\n"
  },
  {
    "path": "monai/data/thread_buffer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom multiprocessing.context import SpawnContext\nfrom queue import Empty, Full, Queue\nfrom threading import Thread\n\nimport torch\n\nfrom monai.data import DataLoader, Dataset\n\n\nclass ThreadBuffer:\n    \"\"\"\n    Iterates over values from self.src in a separate thread but yielding them in the current thread. This allows values\n    to be queued up asynchronously. The internal thread will continue running so long as the source has values or until\n    the stop() method is called.\n\n    One issue raised by using a thread in this way is that during the lifetime of the thread the source object is being\n    iterated over, so if the thread hasn't finished another attempt to iterate over it will raise an exception or yield\n    unexpected results. 
To ensure the thread releases the iteration and proper cleanup is done the stop() method must\n    be called which will join with the thread.\n\n    Args:\n        src: Source data iterable\n        buffer_size: Number of items to buffer from the source\n        timeout: Time to wait for an item from the buffer, or to wait while the buffer is full when adding items\n    \"\"\"\n\n    def __init__(self, src, buffer_size: int = 1, timeout: float = 0.01):\n        self.src = src\n        self.buffer_size = buffer_size\n        self.timeout = timeout\n        self.buffer: Queue = Queue(self.buffer_size)\n        self.gen_thread: Thread | None = None\n        self.is_running = False\n\n    def enqueue_values(self):\n        for src_val in self.src:\n            while self.is_running:\n                try:\n                    self.buffer.put(src_val, timeout=self.timeout)\n                except Full:\n                    pass  # try to add the item again\n                else:\n                    break  # successfully added the item, quit trying\n            else:  # quit the thread cleanly when requested to stop\n                break\n\n    def stop(self):\n        self.is_running = False  # signal the thread to exit\n\n        if self.gen_thread is not None:\n            self.gen_thread.join()\n\n        self.gen_thread = None\n\n    def __iter__(self):\n        self.is_running = True\n        self.gen_thread = Thread(target=self.enqueue_values, daemon=True)\n        self.gen_thread.start()\n\n        try:\n            while self.is_running and (self.gen_thread.is_alive() or not self.buffer.empty()):\n                try:\n                    yield self.buffer.get(timeout=self.timeout)\n                except Empty:\n                    pass  # queue was empty this time, try again\n        finally:\n            self.stop()  # ensure thread completion\n\n\ndef buffer_iterator(src, buffer_size: int = 1, timeout: float = 0.01, repeats: int = 1):\n    \"\"\"\n    
Create a ThreadBuffer object using the `src`, `buffer_size`, and `timeout` parameters given for the constructor\n    arguments of the same names, and yield each generated object `repeats` number of times successively.\n\n    Args:\n        src: Source data iterable\n        buffer_size: Number of items to buffer from the source\n        timeout: Time to wait for an item from the buffer, or to wait while the buffer is full when adding items\n        repeats: Number of repeat generations to perform which is asynchronous from the generation of the next value\n\n    Returns:\n        Generator yield (repeated) values from `src` asynchronously\n    \"\"\"\n    buffer = ThreadBuffer(src=src, buffer_size=buffer_size, timeout=timeout)\n\n    for batch in buffer:\n        for _ in range(repeats):\n            yield batch\n\n\nclass _ProcessThread(Thread):\n    \"\"\"Shim class to make a thread look like a process to the DataLoader class.\"\"\"\n\n    @property\n    def pid(self):\n        return id(self)\n\n    def run(self):\n        try:\n            super().run()\n        finally:\n            torch.utils.data._utils.worker._worker_info = None  # clean up global data used for processes\n\n\nclass _ProcessQueue(Queue):\n    \"\"\"Shim class to make a thread queue look like a process queue to the DataLoader class.\"\"\"\n\n    def close(self):\n        pass\n\n    def cancel_join_thread(self):\n        pass\n\n\nclass _ProcessThreadContext(SpawnContext):\n    _name = \"processthread\"\n\n    # threads will be created which looks like processes\n    Process = _ProcessThread  # type: ignore\n    # thread queue used in place of process queue to avoid some weird cleanup errors\n    Queue = _ProcessQueue  # type: ignore\n\n\nclass ThreadDataLoader(DataLoader):\n    \"\"\"\n    Subclass of `DataLoader` using a `ThreadBuffer` object to implement `__iter__` method asynchronously. 
This will\n    iterate over data from the loader as expected however the data is generated on a separate thread. Use this class\n    where a `DataLoader` instance is required and not just an iterable object.\n\n    The default behaviour with `repeats` set to 1 is to yield each batch as it is generated, however with a higher\n    value the generated batch is yielded that many times while underlying dataset asynchronously generates the next.\n    Typically not all relevant information is learned from a batch in a single iteration so training multiple times\n    on the same batch will still produce good training with minimal short-term overfitting while allowing a slow batch\n    generation process more time to produce a result. This duplication is done by simply yielding the same object many\n    times and not by regenerating the data.\n\n    Another typical usage is to accelerate light-weight preprocessing (usually cached all the deterministic transforms\n    and no IO operations), because it leverages the separate thread to execute preprocessing to avoid unnecessary IPC\n    between multiple workers of DataLoader. And as CUDA may not work well with the multi-processing of DataLoader,\n    `ThreadDataLoader` can be useful for GPU transforms. For more details:\n    https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md.\n\n    The `use_thread_workers` will cause workers to be created as threads rather than processes although everything else\n    in terms of how the class works is unchanged. This allows multiple workers to be used in Windows for example, or in\n    any other situation where thread semantics is desired. Please note that some MONAI components like several datasets\n    and random transforms are not thread-safe and can't work as expected with `thread workers`, need to check all the\n    preprocessing components carefully before enabling `use_thread_workers`.\n\n    See:\n        * Fischetti et al. 
\"Faster SGD training by minibatch persistency.\" ArXiv (2018) https://arxiv.org/abs/1806.07353\n        * Dami et al., \"Faster Neural Network Training with Data Echoing\" ArXiv (2020) https://arxiv.org/abs/1907.05550\n        * Ramezani et al. \"GCN meets GPU: Decoupling \"When to Sample\" from \"How to Sample\".\" NeurIPS (2020).\n          https://proceedings.neurips.cc/paper/2020/file/d714d2c5a796d5814c565d78dd16188d-Paper.pdf\n\n    Args:\n        dataset: input dataset.\n        buffer_size: number of items to buffer from the data source.\n        buffer_timeout: time to wait for an item from the buffer, or to wait while the buffer is full when adding items.\n        repeats: number of times to yield the same batch.\n        use_thread_workers: if True and num_workers > 0 the workers are created as threads instead of processes\n        kwargs: other arguments for `DataLoader` except for `dataset`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset: Dataset,\n        buffer_size: int = 1,\n        buffer_timeout: float = 0.01,\n        repeats: int = 1,\n        use_thread_workers: bool = False,\n        **kwargs,\n    ):\n        # if workers should be threads, create a new multiprocessing context with the process and queue types\n        # substituted with the shim types given above\n        if use_thread_workers and kwargs.get(\"num_workers\", 0) > 0:\n            kwargs[\"multiprocessing_context\"] = _ProcessThreadContext()\n            kwargs[\"persistent_workers\"] = False\n\n        super().__init__(dataset, **kwargs)\n        self.buffer_size = buffer_size\n        self.buffer_timeout = buffer_timeout\n        self.repeats = repeats\n\n    def __iter__(self):\n        yield from buffer_iterator(super().__iter__(), self.buffer_size, self.buffer_timeout, self.repeats)\n"
  },
  {
    "path": "monai/data/torchscript_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport datetime\nimport json\nimport os\nfrom collections.abc import Mapping, Sequence\nfrom typing import IO, Any\n\nimport torch\n\nfrom monai.config import get_config_values\nfrom monai.utils import JITMetadataKeys\n\nMETADATA_FILENAME = \"metadata.json\"\n\n\ndef save_net_with_metadata(\n    jit_obj: torch.nn.Module,\n    filename_prefix_or_stream: str | IO[Any],\n    include_config_vals: bool = True,\n    append_timestamp: bool = False,\n    meta_values: Mapping[str, Any] | None = None,\n    more_extra_files: Mapping[str, bytes] | None = None,\n) -> None:\n    \"\"\"\n    Save the JIT object (script or trace produced object) `jit_obj` to the given file or stream with metadata\n    included as a JSON file. The Torchscript format is a zip file which can contain extra file data which is used\n    here as a mechanism for storing metadata about the network being saved. The data in `meta_values` should be\n    compatible with conversion to JSON using the standard library function `dumps`. The intent is this metadata will\n    include information about the network applicable to some use case, such as describing the input and output format,\n    a network name and version, a plain language description of what the network does, and other relevant scientific\n    information. 
Clients can use this information to determine automatically how to use the network, and users can\n    read what the network does and keep track of versions.\n\n    Examples::\n\n        net = torch.jit.script(monai.networks.nets.UNet(2, 1, 1, [8, 16], [2]))\n\n        meta = {\n            \"name\": \"Test UNet\",\n            \"used_for\": \"demonstration purposes\",\n            \"input_dims\": 2,\n            \"output_dims\": 2\n        }\n\n        # save the Torchscript bundle with the above dictionary stored as an extra file\n        save_net_with_metadata(m, \"test\", meta_values=meta)\n\n        # load the network back, `loaded_meta` has same data as `meta` plus version information\n        loaded_net, loaded_meta, _ = load_net_with_metadata(\"test.ts\")\n\n\n    Args:\n        jit_obj: object to save, should be generated by `script` or `trace`.\n        filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.ts`.\n        include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata.\n        append_timestamp: if True, a timestamp for \"now\" is appended to the file's name before the extension.\n        meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`.\n        more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`.\n    \"\"\"\n\n    now = datetime.datetime.now()\n    metadict = {}\n\n    if include_config_vals:\n        metadict.update(get_config_values())\n        metadict[JITMetadataKeys.TIMESTAMP.value] = now.astimezone().isoformat()\n\n    if meta_values is not None:\n        metadict.update(meta_values)\n\n    json_data = json.dumps(metadict)\n\n    extra_files = {METADATA_FILENAME: json_data.encode()}\n\n    if more_extra_files is not None:\n        extra_files.update(more_extra_files)\n\n    if isinstance(filename_prefix_or_stream, str):\n        filename_no_ext, ext 
= os.path.splitext(filename_prefix_or_stream)\n        if ext == \"\":\n            ext = \".ts\"\n\n        if append_timestamp:\n            filename_prefix_or_stream = now.strftime(f\"{filename_no_ext}_%Y%m%d%H%M%S{ext}\")\n        else:\n            filename_prefix_or_stream = filename_no_ext + ext\n\n    torch.jit.save(jit_obj, filename_prefix_or_stream, extra_files)\n\n\ndef load_net_with_metadata(\n    filename_prefix_or_stream: str | IO[Any],\n    map_location: torch.device | None = None,\n    more_extra_files: Sequence[str] = (),\n) -> tuple[torch.nn.Module, dict, dict]:\n    \"\"\"\n    Load the module object from the given Torchscript filename or stream, and convert the stored JSON metadata\n    back to a dict object. This will produce an empty dict if the metadata file is not present.\n\n    Args:\n        filename_prefix_or_stream: filename or file-like stream object.\n        map_location: network map location as in `torch.jit.load`.\n        more_extra_files: other extra file data names to load from bundle, see `_extra_files` of `torch.jit.load`.\n    Returns:\n        Triple containing loaded object, metadata dict, and extra files dict containing other file data if present\n    \"\"\"\n    extra_files = dict.fromkeys(more_extra_files, \"\")\n    extra_files[METADATA_FILENAME] = \"\"\n\n    jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files)\n\n    extra_files = dict(extra_files.items())  # compatibility with ExtraFilesMap\n\n    if METADATA_FILENAME in extra_files:\n        json_data = extra_files[METADATA_FILENAME]\n        del extra_files[METADATA_FILENAME]\n    else:\n        json_data = \"{}\"\n\n    json_data_dict = json.loads(json_data)\n\n    return jit_obj, json_data_dict, extra_files\n"
  },
  {
    "path": "monai/data/ultrasound_confidence_map.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport numpy as np\nfrom numpy.typing import NDArray\n\nfrom monai.utils import min_version, optional_import\n\n__all__ = [\"UltrasoundConfidenceMap\"]\n\ncv2, _ = optional_import(\"cv2\")\ncsc_matrix, _ = optional_import(\"scipy.sparse\", \"1.12.0\", min_version, \"csc_matrix\")\nspsolve, _ = optional_import(\"scipy.sparse.linalg\", \"1.12.0\", min_version, \"spsolve\")\ncg, _ = optional_import(\"scipy.sparse.linalg\", \"1.12.0\", min_version, \"cg\")\nhilbert, _ = optional_import(\"scipy.signal\", \"1.12.0\", min_version, \"hilbert\")\nruge_stuben_solver, _ = optional_import(\"pyamg\", \"5.0.0\", min_version, \"ruge_stuben_solver\")\n\n\nclass UltrasoundConfidenceMap:\n    \"\"\"Compute confidence map from an ultrasound image.\n    This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005.\n    It generates a confidence map by setting source and sink points in the image and computing the probability\n    for random walks to reach the source for each pixel.\n\n    The official code is available at:\n    https://campar.in.tum.de/Main/AthanasiosKaramalisCode\n\n    Args:\n        alpha (float, optional): Alpha parameter. Defaults to 2.0.\n        beta (float, optional): Beta parameter. Defaults to 90.0.\n        gamma (float, optional): Gamma parameter. 
Defaults to 0.05.\n        mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'.\n        sink_mode (str, optional): Sink mode. Defaults to 'all'. If 'mask' is selected, a mask must be when calling\n            the transform. Can be 'all', 'mid', 'min', or 'mask'.\n        use_cg (bool, optional): Use Conjugate Gradient method for solving the linear system. Defaults to False.\n        cg_tol (float, optional): Tolerance for the Conjugate Gradient method. Defaults to 1e-6.\n            Will be used only if `use_cg` is True.\n        cg_maxiter (int, optional): Maximum number of iterations for the Conjugate Gradient method. Defaults to 200.\n            Will be used only if `use_cg` is True.\n    \"\"\"\n\n    def __init__(\n        self,\n        alpha: float = 2.0,\n        beta: float = 90.0,\n        gamma: float = 0.05,\n        mode=\"B\",\n        sink_mode=\"all\",\n        use_cg=False,\n        cg_tol=1e-6,\n        cg_maxiter=200,\n    ):\n        # The hyperparameters for confidence map estimation\n        self.alpha = alpha\n        self.beta = beta\n        self.gamma = gamma\n        self.mode = mode\n        self.sink_mode = sink_mode\n        self.use_cg = use_cg\n        self.cg_tol = cg_tol\n        self.cg_maxiter = cg_maxiter\n\n        # The precision to use for all computations\n        self.eps = np.finfo(\"float64\").eps\n\n        # Store sink indices for external use\n        self._sink_indices = np.array([], dtype=\"float64\")\n\n    def sub2ind(self, size: tuple[int, ...], rows: NDArray, cols: NDArray) -> NDArray:\n        \"\"\"Converts row and column subscripts into linear indices,\n        basically the copy of the MATLAB function of the same name.\n        https://www.mathworks.com/help/matlab/ref/sub2ind.html\n\n        This function is Pythonic so the indices start at 0.\n\n        Args:\n            size Tuple[int]: Size of the matrix\n            rows (NDArray): Row indices\n            cols (NDArray): Column indices\n\n   
     Returns:\n            indices (NDArray): 1-D array of linear indices\n        \"\"\"\n        indices: NDArray = rows + cols * size[0]\n        return indices\n\n    def get_seed_and_labels(\n        self, data: NDArray, sink_mode: str = \"all\", sink_mask: NDArray | None = None\n    ) -> tuple[NDArray, NDArray]:\n        \"\"\"Get the seed and label arrays for the max-flow algorithm\n\n        Args:\n            data: Input array\n            sink_mode (str, optional): Sink mode. Defaults to 'all'.\n            sink_mask (NDArray, optional): Sink mask. Defaults to None.\n\n        Returns:\n            Tuple[NDArray, NDArray]: Seed and label arrays\n        \"\"\"\n\n        # Seeds and labels (boundary conditions)\n        seeds = np.array([], dtype=\"float64\")\n        labels = np.array([], dtype=\"float64\")\n\n        # Indices for all columns\n        sc = np.arange(data.shape[1], dtype=\"float64\")\n\n        # SOURCE ELEMENTS - 1st matrix row\n        # Indices for 1st row, it will be broadcasted with sc\n        sr_up = np.array([0])\n        seed = self.sub2ind(data.shape, sr_up, sc).astype(\"float64\")\n        seed = np.unique(seed)\n        seeds = np.concatenate((seeds, seed))\n\n        # Label 1\n        label = np.ones_like(seed)\n        labels = np.concatenate((labels, label))\n\n        # Create seeds for sink elements\n\n        if sink_mode == \"all\":\n            # All elements in the last row\n            sr_down = np.ones_like(sc) * (data.shape[0] - 1)\n            self._sink_indices = np.array([sr_down, sc], dtype=\"int32\")\n            seed = self.sub2ind(data.shape, sr_down, sc).astype(\"float64\")\n\n        elif sink_mode == \"mid\":\n            # Middle element in the last row\n            sc_down = np.array([data.shape[1] // 2])\n            sr_down = np.ones_like(sc_down) * (data.shape[0] - 1)\n            self._sink_indices = np.array([sr_down, sc_down], dtype=\"int32\")\n            seed = self.sub2ind(data.shape, 
sr_down, sc_down).astype(\"float64\")\n\n        elif sink_mode == \"min\":\n            # Minimum element in the last row (excluding 10% from the edges)\n            ten_percent = int(data.shape[1] * 0.1)\n            min_val = np.min(data[-1, ten_percent:-ten_percent])\n            min_idxs = np.where(data[-1, ten_percent:-ten_percent] == min_val)[0] + ten_percent\n            sc_down = min_idxs\n            sr_down = np.ones_like(sc_down) * (data.shape[0] - 1)\n            self._sink_indices = np.array([sr_down, sc_down], dtype=\"int32\")\n            seed = self.sub2ind(data.shape, sr_down, sc_down).astype(\"float64\")\n\n        elif sink_mode == \"mask\":\n            # All elements in the mask\n            coords = np.where(sink_mask != 0)\n            sr_down = coords[0]\n            sc_down = coords[1]\n            self._sink_indices = np.array([sr_down, sc_down], dtype=\"int32\")\n            seed = self.sub2ind(data.shape, sr_down, sc_down).astype(\"float64\")\n\n        seed = np.unique(seed)\n        seeds = np.concatenate((seeds, seed))\n\n        # Label 2\n        label = np.ones_like(seed) * 2\n        labels = np.concatenate((labels, label))\n\n        return seeds, labels\n\n    def normalize(self, inp: NDArray) -> NDArray:\n        \"\"\"Normalize an array to [0, 1]\"\"\"\n        normalized_array: NDArray = (inp - np.min(inp)) / (np.ptp(inp) + self.eps)\n        return normalized_array\n\n    def attenuation_weighting(self, img: NDArray, alpha: float) -> NDArray:\n        \"\"\"Compute attenuation weighting\n\n        Args:\n            img (NDArray): Image\n            alpha: Attenuation coefficient (see publication)\n\n        Returns:\n            w (NDArray): Weighting expressing depth-dependent attenuation\n        \"\"\"\n\n        # Create depth vector and repeat it for each column\n        dw = np.linspace(0, 1, img.shape[0], dtype=\"float64\")\n        dw = np.tile(dw.reshape(-1, 1), (1, img.shape[1]))\n\n        w: NDArray = 1.0 - 
np.exp(-alpha * dw)  # Compute exp inline\n\n        return w\n\n    def confidence_laplacian(self, padded_index: NDArray, padded_image: NDArray, beta: float, gamma: float):\n        \"\"\"Compute 6-Connected Laplacian for confidence estimation problem\n\n        Args:\n            padded_index (NDArray): The index matrix of the image with boundary padding.\n            padded_image (NDArray): The padded image.\n            beta (float): Random walks parameter that defines the sensitivity of the Gaussian weighting function.\n            gamma (float): Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian.\n\n        Returns:\n            L (csc_matrix): The 6-connected Laplacian matrix used for confidence map estimation.\n        \"\"\"\n\n        m, _ = padded_index.shape\n\n        padded_index = padded_index.T.flatten()\n        padded_image = padded_image.T.flatten()\n\n        p = np.where(padded_index > 0)[0]\n\n        i = padded_index[p] - 1  # Index vector\n        j = padded_index[p] - 1  # Index vector\n        # Entries vector, initially for diagonal\n        s = np.zeros_like(p, dtype=\"float64\")\n\n        edge_templates = [\n            -1,  # Vertical edges\n            1,\n            m - 1,  # Diagonal edges\n            m + 1,\n            -m - 1,\n            -m + 1,\n            m,  # Horizontal edges\n            -m,\n        ]\n\n        vertical_end = None\n\n        for iter_idx, k in enumerate(edge_templates):\n            neigh_idxs = padded_index[p + k]\n\n            q = np.where(neigh_idxs > 0)[0]\n\n            ii = padded_index[p[q]] - 1\n            i = np.concatenate((i, ii))\n            jj = neigh_idxs[q] - 1\n            j = np.concatenate((j, jj))\n            w = np.abs(padded_image[p[ii]] - padded_image[p[jj]])  # Intensity derived weight\n            s = np.concatenate((s, w))\n\n            if iter_idx == 1:\n                vertical_end = s.shape[0]  # Vertical edges length\n            
elif iter_idx == 5:\n                s.shape[0]  # Diagonal edges length\n\n        # Normalize weights\n        s = self.normalize(s)\n\n        # Horizontal penalty\n        s[vertical_end:] += gamma\n        # Here there is a difference between the official MATLAB code and the paper\n        # on the edge penalty. We directly implement what the official code does.\n\n        # Normalize differences\n        s = self.normalize(s)\n\n        # Gaussian weighting function\n        s = -(\n            (np.exp(-beta * s, dtype=\"float64\")) + 1e-5\n        )  # --> This epsilon changes results drastically default: 10e-6\n        # Please notice that it is not 1e-6, it is 10e-6 which is actually different.\n\n        # Create Laplacian, diagonal missing\n        lap = csc_matrix((s, (i, j)))\n\n        # Reset diagonal weights to zero for summing\n        # up the weighted edge degree in the next step\n        lap.setdiag(0)\n\n        # Weighted edge degree\n        diag = np.abs(lap.sum(axis=0).A)[0]\n\n        # Finalize Laplacian by completing the diagonal\n        lap.setdiag(diag)\n\n        return lap\n\n    def _solve_linear_system(self, lap, rhs):\n\n        if self.use_cg:\n            lap_sparse = lap.tocsr()\n            ml = ruge_stuben_solver(lap_sparse, coarse_solver=\"pinv\")\n            m = ml.aspreconditioner(cycle=\"V\")\n            x, _ = cg(lap, rhs, rtol=self.cg_tol, maxiter=self.cg_maxiter, M=m)\n        else:\n            x = spsolve(lap, rhs)\n\n        return x\n\n    def confidence_estimation(self, img, seeds, labels, beta, gamma):\n        \"\"\"Compute confidence map\n\n        Args:\n            img (NDArray): Processed image.\n            seeds (NDArray): Seeds for the random walks framework. These are indices of the source and sink nodes.\n            labels (NDArray): Labels for the random walks framework. 
These represent the classes or groups of the seeds.\n            beta: Random walks parameter that defines the sensitivity of the Gaussian weighting function.\n            gamma: Horizontal penalty factor that adjusts the weight of horizontal edges in the Laplacian.\n\n        Returns:\n            map: Confidence map which shows the probability of each pixel belonging to the source or sink group.\n        \"\"\"\n\n        # Index matrix with boundary padding\n        idx = np.arange(1, img.shape[0] * img.shape[1] + 1).reshape(img.shape[1], img.shape[0]).T\n        pad = 1\n\n        padded_idx = np.pad(idx, (pad, pad), \"constant\", constant_values=(0, 0))\n        padded_img = np.pad(img, (pad, pad), \"constant\", constant_values=(0, 0))\n\n        # Laplacian\n        lap = self.confidence_laplacian(padded_idx, padded_img, beta, gamma)\n\n        # Select marked columns from Laplacian to create L_M and B^T\n        b = lap[:, seeds]\n\n        # Select marked nodes to create B^T\n        n = np.sum(padded_idx > 0).item()\n        i_u = np.setdiff1d(np.arange(n), seeds.astype(int))  # Index of unmarked nodes\n        b = b[i_u, :]\n\n        # Remove marked nodes from Laplacian by deleting rows and cols\n        keep_indices = np.setdiff1d(np.arange(lap.shape[0]), seeds)\n        lap = csc_matrix(lap[keep_indices, :][:, keep_indices])\n\n        # Define M matrix\n        m = np.zeros((seeds.shape[0], 1), dtype=\"float64\")\n        m[:, 0] = labels == 1\n\n        # Right-handside (-B^T*M)\n        rhs = -b @ m\n\n        # Solve linear system\n        x = self._solve_linear_system(lap, rhs)\n\n        # Prepare output\n        probabilities = np.zeros((n,), dtype=\"float64\")\n        # Probabilities for unmarked nodes\n        probabilities[i_u] = x\n        # Max probability for marked node\n        probabilities[seeds[labels == 1].astype(int)] = 1.0\n\n        # Final reshape with same size as input image (no padding)\n        probabilities = 
probabilities.reshape((img.shape[1], img.shape[0])).T\n\n        return probabilities\n\n    def __call__(self, data: NDArray, sink_mask: NDArray | None = None) -> NDArray:\n        \"\"\"Compute the confidence map\n\n        Args:\n            data (NDArray): RF ultrasound data (one scanline per column) [H x W] 2D array\n\n        Returns:\n            map (NDArray): Confidence map [H x W] 2D array\n        \"\"\"\n\n        # Normalize data\n        data = data.astype(\"float64\")\n        data = self.normalize(data)\n\n        if self.mode == \"RF\":\n            # MATLAB hilbert applies the Hilbert transform to columns\n            data = np.abs(hilbert(data, axis=0)).astype(\"float64\")\n\n        seeds, labels = self.get_seed_and_labels(data, self.sink_mode, sink_mask)\n\n        # Attenuation with Beer-Lambert\n        w = self.attenuation_weighting(data, self.alpha)\n\n        # Apply weighting directly to image\n        # Same as applying it individually during the formation of the\n        # Laplacian\n        data = data * w\n\n        # Find confidence values\n        map_: NDArray = self.confidence_estimation(data, seeds, labels, self.beta, self.gamma)\n\n        return map_\n"
  },
  {
    "path": "monai/data/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport hashlib\nimport json\nimport logging\nimport math\nimport os\nimport pickle\nimport sys\nfrom collections import abc, defaultdict\nfrom collections.abc import Generator, Iterable, Mapping, Sequence, Sized\nfrom copy import deepcopy\nfrom functools import reduce\nfrom itertools import product, starmap, zip_longest\nfrom pathlib import PurePath\nfrom typing import Any\n\nimport numpy as np\nimport torch\nfrom torch.utils.data._utils.collate import default_collate\n\nfrom monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike\nfrom monai.data.meta_obj import MetaObj\nfrom monai.utils import (\n    MAX_SEED,\n    BlendMode,\n    Method,\n    NumpyPadMode,\n    TraceKeys,\n    convert_data_type,\n    convert_to_dst_type,\n    ensure_tuple,\n    ensure_tuple_rep,\n    ensure_tuple_size,\n    fall_back_tuple,\n    first,\n    get_equivalent_dtype,\n    issequenceiterable,\n    look_up_option,\n    optional_import,\n)\n\npd, _ = optional_import(\"pandas\")\nDataFrame, _ = optional_import(\"pandas\", name=\"DataFrame\")\nnib, _ = optional_import(\"nibabel\")\n\n__all__ = [\n    \"AFFINE_TOL\",\n    \"SUPPORTED_PICKLE_MOD\",\n    \"affine_to_spacing\",\n    \"compute_importance_map\",\n    \"compute_shape_offset\",\n    \"convert_tables_to_dicts\",\n    \"correct_nifti_header_if_necessary\",\n    
\"create_file_basename\",\n    \"decollate_batch\",\n    \"dense_patch_slices\",\n    \"get_random_patch\",\n    \"get_valid_patch_size\",\n    \"is_supported_format\",\n    \"iter_patch\",\n    \"iter_patch_position\",\n    \"iter_patch_slices\",\n    \"json_hashing\",\n    \"list_data_collate\",\n    \"no_collation\",\n    \"orientation_ras_lps\",\n    \"pad_list_data_collate\",\n    \"partition_dataset\",\n    \"partition_dataset_classes\",\n    \"pickle_hashing\",\n    \"rectify_header_sform_qform\",\n    \"reorient_spatial_axes\",\n    \"resample_datalist\",\n    \"select_cross_validation_folds\",\n    \"set_rnd\",\n    \"sorted_dict\",\n    \"to_affine_nd\",\n    \"worker_init_fn\",\n    \"zoom_affine\",\n    \"remove_keys\",\n    \"remove_extra_metadata\",\n    \"get_extra_metadata_keys\",\n    \"is_no_channel\",\n]\n\n# module to be used by `torch.save`\nSUPPORTED_PICKLE_MOD = {\"pickle\": pickle}\n\n# tolerance for affine matrix computation\nAFFINE_TOL = 1e-3\n\n\ndef get_random_patch(\n    dims: Sequence[int], patch_size: Sequence[int], rand_state: np.random.RandomState | None = None\n) -> tuple[slice, ...]:\n    \"\"\"\n    Returns a tuple of slices to define a random patch in an array of shape `dims` with size `patch_size` or the as\n    close to it as possible within the given dimension. 
It is expected that `patch_size` is a valid patch for a source\n    of shape `dims` as returned by `get_valid_patch_size`.\n\n    Args:\n        dims: shape of source array\n        patch_size: shape of patch size to generate\n        rand_state: a random state object to generate random numbers from\n\n    Returns:\n        (tuple of slice): a tuple of slice objects defining the patch\n    \"\"\"\n\n    # choose the minimal corner of the patch\n    rand_int = np.random.randint if rand_state is None else rand_state.randint\n    min_corner = tuple(rand_int(0, ms - ps + 1) if ms > ps else 0 for ms, ps in zip(dims, patch_size))\n\n    # create the slices for each dimension which define the patch in the source array\n    return tuple(slice(mc, mc + ps) for mc, ps in zip(min_corner, patch_size))\n\n\ndef iter_patch_slices(\n    image_size: Sequence[int],\n    patch_size: Sequence[int] | int,\n    start_pos: Sequence[int] = (),\n    overlap: Sequence[float] | float = 0.0,\n    padded: bool = True,\n) -> Generator[tuple[slice, ...], None, None]:\n    \"\"\"\n    Yield successive tuples of slices defining patches of size `patch_size` from an array of dimensions `image_size`.\n    The iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each\n    patch is chosen in a contiguous grid using a row-major ordering.\n\n    Args:\n        image_size: dimensions of array to iterate over\n        patch_size: size of patches to generate slices for, 0 or None selects whole dimension\n        start_pos: starting position in the array, default is 0 for each dimension\n        overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).\n            If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.\n        padded: if the image is padded so the patches can go beyond the borders. 
Defaults to False.\n\n    Yields:\n        Tuples of slice objects defining each patch\n    \"\"\"\n\n    # ensure patch_size has the right length\n    patch_size_ = get_valid_patch_size(image_size, patch_size)\n\n    # create slices based on start position of each patch\n    for position in iter_patch_position(\n        image_size=image_size, patch_size=patch_size_, start_pos=start_pos, overlap=overlap, padded=padded\n    ):\n        yield tuple(slice(s, s + p) for s, p in zip(position, patch_size_))\n\n\ndef dense_patch_slices(\n    image_size: Sequence[int], patch_size: Sequence[int], scan_interval: Sequence[int], return_slice: bool = True\n) -> list[tuple[slice, ...]]:\n    \"\"\"\n    Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image.\n\n    Args:\n        image_size: dimensions of image to iterate over\n        patch_size: size of patches to generate slices\n        scan_interval: dense patch sampling interval\n        return_slice: whether to return a list of slices (or tuples of indices), defaults to True\n\n    Returns:\n        a list of slice objects defining each patch\n\n    \"\"\"\n    num_spatial_dims = len(image_size)\n    patch_size = get_valid_patch_size(image_size, patch_size)\n    scan_interval = ensure_tuple_size(scan_interval, num_spatial_dims)\n\n    scan_num = []\n    for i in range(num_spatial_dims):\n        if scan_interval[i] == 0:\n            scan_num.append(1)\n        else:\n            num = int(math.ceil(float(image_size[i]) / scan_interval[i]))\n            scan_dim = first(d for d in range(num) if d * scan_interval[i] + patch_size[i] >= image_size[i])\n            scan_num.append(scan_dim + 1 if scan_dim is not None else 1)\n\n    starts = []\n    for dim in range(num_spatial_dims):\n        dim_starts = []\n        for idx in range(scan_num[dim]):\n            start_idx = idx * scan_interval[dim]\n            start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0)\n         
   dim_starts.append(start_idx)\n        starts.append(dim_starts)\n    out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing=\"ij\")]).T\n    if return_slice:\n        return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]\n    return [tuple((s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]  # type: ignore\n\n\ndef iter_patch_position(\n    image_size: Sequence[int],\n    patch_size: Sequence[int] | int | np.ndarray,\n    start_pos: Sequence[int] = (),\n    overlap: Sequence[float] | float | Sequence[int] | int = 0.0,\n    padded: bool = False,\n):\n    \"\"\"\n    Yield successive tuples of upper left corner of patches of size `patch_size` from an array of dimensions `image_size`.\n    The iteration starts from position `start_pos` in the array, or starting at the origin if this isn't provided. Each\n    patch is chosen in a contiguous grid using a rwo-major ordering.\n\n    Args:\n        image_size: dimensions of array to iterate over\n        patch_size: size of patches to generate slices for, 0 or None selects whole dimension\n        start_pos: starting position in the array, default is 0 for each dimension\n        overlap: the amount of overlap of neighboring patches in each dimension.\n            Either a float or list of floats between 0.0 and 1.0 to define relative overlap to patch size, or\n            an int or list of ints to define number of pixels for overlap.\n            If only one float/int number is given, it will be applied to all dimensions. Defaults to 0.0.\n        padded: if the image is padded so the patches can go beyond the borders. 
Defaults to False.\n\n    Yields:\n        Tuples of positions defining the upper left corner of each patch\n    \"\"\"\n\n    # ensure patchSize and startPos are the right length\n    ndim = len(image_size)\n    patch_size_ = get_valid_patch_size(image_size, patch_size)\n    start_pos = ensure_tuple_size(start_pos, ndim)\n    overlap = ensure_tuple_rep(overlap, ndim)\n\n    # calculate steps, which depends on the amount of overlap\n    if isinstance(overlap[0], float):\n        steps = tuple(round(p * (1.0 - o)) for p, o in zip(patch_size_, overlap))\n    else:\n        steps = tuple(p - o for p, o in zip(patch_size_, overlap))\n\n    # calculate the last starting location (depending on the padding)\n    end_pos = image_size if padded else tuple(s - round(p) + 1 for s, p in zip(image_size, patch_size_))\n\n    # collect the ranges to step over each dimension\n    ranges = starmap(range, zip(start_pos, end_pos, steps))\n\n    # choose patches by applying product to the ranges\n    return product(*ranges)\n\n\ndef iter_patch(\n    arr: NdarrayOrTensor,\n    patch_size: Sequence[int] | int = 0,\n    start_pos: Sequence[int] = (),\n    overlap: Sequence[float] | float = 0.0,\n    copy_back: bool = True,\n    mode: str | None = NumpyPadMode.WRAP,\n    **pad_opts: dict,\n) -> Generator[tuple[NdarrayOrTensor, np.ndarray], None, None]:\n    \"\"\"\n    Yield successive patches from `arr` of size `patch_size`. The iteration can start from position `start_pos` in `arr`\n    but drawing from a padded array extended by the `patch_size` in each dimension (so these coordinates can be negative\n    to start in the padded region). 
If `copy_back` is True the values from each patch are written back to `arr`.\n\n    Args:\n        arr: array to iterate over\n        patch_size: size of patches to generate slices for, 0 or None selects whole dimension.\n            For 0 or None, padding and overlap ratio of the corresponding dimension will be 0.\n        start_pos: starting position in the array, default is 0 for each dimension\n        overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).\n            If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.\n        copy_back: if True data from the yielded patches is copied back to `arr` once the generator completes\n        mode: available modes: (Numpy) {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            (PyTorch) {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function.\n            If None, no wrapping is performed. 
Defaults to ``\"wrap\"``.\n            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            requires pytorch >= 1.10 for best compatibility.\n        pad_opts: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    Yields:\n        Patches of array data from `arr` which are views into a padded array which can be modified, if `copy_back` is\n        True these changes will be reflected in `arr` once the iteration completes.\n\n    Note:\n        coordinate format is:\n\n            [1st_dim_start, 1st_dim_end,\n             2nd_dim_start, 2nd_dim_end,\n             ...,\n             Nth_dim_start, Nth_dim_end]]\n\n    \"\"\"\n\n    from monai.transforms.croppad.functional import pad_nd  # needs to be here to avoid circular import\n\n    # ensure patchSize and startPos are the right length\n    patch_size_ = get_valid_patch_size(arr.shape, patch_size)\n    start_pos = ensure_tuple_size(start_pos, arr.ndim)\n\n    # set padded flag to false if pad mode is None\n    padded = bool(mode)\n    is_v = [bool(p) for p in ensure_tuple_size(patch_size, arr.ndim)]  # whether a valid patch size provided\n    _pad_size = tuple(p if v and padded else 0 for p, v in zip(patch_size_, is_v))  # pad p if v else 0\n    _overlap = [op if v else 0.0 for op, v in zip(ensure_tuple_rep(overlap, arr.ndim), is_v)]  # overlap if v else 0.0\n    # pad image by maximum values needed to ensure patches are taken from inside an image\n    if padded:\n        arrpad = pad_nd(arr, to_pad=[(p, p) for p in _pad_size], mode=mode, **pad_opts)  # type: ignore\n        # choose a start position in the padded image\n        start_pos_padded = tuple(s + p for s, p in zip(start_pos, _pad_size))\n\n        # choose a size to iterate over which is smaller than the actual padded image to prevent producing\n 
       # patches which are only in the padded regions\n        iter_size = tuple(s + p for s, p in zip(arr.shape, _pad_size))\n    else:\n        arrpad = arr\n        start_pos_padded = start_pos\n        iter_size = arr.shape\n\n    for slices in iter_patch_slices(iter_size, patch_size_, start_pos_padded, _overlap, padded=padded):\n        # compensate original image padding\n        if padded:\n            coords_no_pad = tuple((coord.start - p, coord.stop - p) for coord, p in zip(slices, _pad_size))\n        else:\n            coords_no_pad = tuple((coord.start, coord.stop) for coord in slices)\n        yield arrpad[slices], np.asarray(coords_no_pad)  # data and coords (in numpy; works with torch loader)\n\n    # copy back data from the padded image if required\n    if copy_back:\n        slices = tuple(slice(p, p + s) for p, s in zip(_pad_size, arr.shape))\n        arr[...] = arrpad[slices]  # type: ignore\n\n\ndef get_valid_patch_size(image_size: Sequence[int], patch_size: Sequence[int] | int | np.ndarray) -> tuple[int, ...]:\n    \"\"\"\n    Given an image of dimensions `image_size`, return a patch size tuple taking the dimension from `patch_size` if this is\n    not 0/None. Otherwise, or if `patch_size` is shorter than `image_size`, the dimension from `image_size` is taken. This ensures\n    the returned patch size is within the bounds of `image_size`. 
If `patch_size` is a single number this is interpreted as a\n    patch of the same dimensionality of `image_size` with that size in each dimension.\n    \"\"\"\n    ndim = len(image_size)\n    patch_size_ = ensure_tuple_size(patch_size, ndim)\n\n    # ensure patch size dimensions are not larger than image dimension, if a dimension is None or 0 use whole dimension\n    return tuple(min(ms, ps or ms) for ms, ps in zip(image_size, patch_size_))\n\n\ndef dev_collate(batch, level: int = 1, logger_name: str = \"dev_collate\"):\n    \"\"\"\n    Recursively run collate logic and provide detailed loggings for debugging purposes.\n    It reports results at the 'critical' level, is therefore suitable in the context of exception handling.\n\n    Args:\n        batch: batch input to collate\n        level: current level of recursion for logging purposes\n        logger_name: name of logger to use for logging\n\n    See also: https://pytorch.org/docs/stable/data.html#working-with-collate-fn\n    \"\"\"\n    elem = batch[0]\n    elem_type = type(elem)\n    l_str = \">\" * level\n    batch_str = f\"{batch[:10]}{' ... 
' if len(batch) > 10 else ''}\"\n    if isinstance(elem, torch.Tensor):\n        try:\n            logging.getLogger(logger_name).critical(f\"{l_str} collate/stack a list of tensors\")\n            return torch.stack(batch, 0)\n        except TypeError as e:\n            logging.getLogger(logger_name).critical(\n                f\"{l_str} E: {e}, type {[type(elem).__name__ for elem in batch]} in collate({batch_str})\"\n            )\n            return\n        except RuntimeError as e:\n            logging.getLogger(logger_name).critical(\n                f\"{l_str} E: {e}, shape {[elem.shape for elem in batch]} in collate({batch_str})\"\n            )\n            return\n    elif elem_type.__module__ == \"numpy\" and elem_type.__name__ != \"str_\" and elem_type.__name__ != \"string_\":\n        if elem_type.__name__ in [\"ndarray\", \"memmap\"]:\n            logging.getLogger(logger_name).critical(f\"{l_str} collate/stack a list of numpy arrays\")\n            return dev_collate([torch.as_tensor(b) for b in batch], level=level, logger_name=logger_name)\n        elif elem.shape == ():  # scalars\n            return batch\n    elif isinstance(elem, (float, int, str, bytes)):\n        return batch\n    elif isinstance(elem, abc.Mapping):\n        out = {}\n        for key in elem:\n            logging.getLogger(logger_name).critical(f'{l_str} collate dict key \"{key}\" out of {len(elem)} keys')\n            out[key] = dev_collate([d[key] for d in batch], level=level + 1, logger_name=logger_name)\n        return out\n    elif isinstance(elem, abc.Sequence):\n        it = iter(batch)\n        els = list(it)\n        try:\n            sizes = [len(elem) for elem in els]  # may not have `len`\n        except TypeError:\n            types = [type(elem).__name__ for elem in els]\n            logging.getLogger(logger_name).critical(f\"{l_str} E: type {types} in collate({batch_str})\")\n            return\n        logging.getLogger(logger_name).critical(f\"{l_str} collate 
list of sizes: {sizes}.\")\n        if any(s != sizes[0] for s in sizes):\n            logging.getLogger(logger_name).critical(\n                f\"{l_str} collate list inconsistent sizes, got size: {sizes}, in collate({batch_str})\"\n            )\n        transposed = zip(*batch)\n        return [dev_collate(samples, level=level + 1, logger_name=logger_name) for samples in transposed]\n    logging.getLogger(logger_name).critical(f\"{l_str} E: unsupported type in collate {batch_str}.\")\n    return\n\n\ndef collate_meta_tensor_fn(batch, *, collate_fn_map=None):\n    \"\"\"\n    Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`\n    and so should not be used as a collate function directly in dataloaders.\n    \"\"\"\n    from torch.utils.data._utils.collate import collate_tensor_fn  # imported here for pylint/mypy issues\n\n    collated = collate_tensor_fn(batch)\n\n    meta_dicts = [i.meta or TraceKeys.NONE for i in batch]\n    common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])\n    if common_:\n        meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts]\n    collated.meta = default_collate(meta_dicts)\n    collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]\n    collated.is_batch = True\n    return collated\n\n\ndef collate_meta_tensor(batch):\n    \"\"\"collate a sequence of meta tensor sequences/dictionaries into\n    a single batched metatensor or a dictionary of batched metatensor\"\"\"\n    if not isinstance(batch, Sequence):\n        raise NotImplementedError()\n    elem_0 = first(batch)\n    if isinstance(elem_0, MetaObj):\n        return collate_meta_tensor_fn(batch)\n    if isinstance(elem_0, Mapping):\n        return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0}\n    if isinstance(elem_0, (tuple, list)):\n        return [collate_meta_tensor([d[i] for d in 
batch]) for i in range(len(elem_0))]\n\n    # no more recursive search for MetaTensor\n    return default_collate(batch)\n\n\ndef list_data_collate(batch: Sequence):\n    \"\"\"\n    Enhancement for PyTorch DataLoader default collate.\n    If dataset already returns a list of batch data that generated in transforms, need to merge all data to 1 list.\n    Then it's same as the default collate behavior.\n\n    Note:\n        Need to use this collate if apply some transforms that can generate batch data.\n\n    \"\"\"\n    from torch.utils.data._utils.collate import default_collate_fn_map\n\n    from monai.data.meta_tensor import MetaTensor\n\n    default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn})\n    elem = batch[0]\n    data = [i for k in batch for i in k] if isinstance(elem, list) else batch\n    key = None\n    collate_fn = default_collate\n    try:\n        # if config.USE_META_DICT:\n        # data = pickle_operations(data)  # bc 0.9.0\n        if isinstance(elem, Mapping):\n            ret = {}\n            for k in elem:\n                key = k\n                data_for_batch = [d[key] for d in data]\n                ret[key] = collate_fn(data_for_batch)\n        else:\n            ret = collate_fn(data)\n        return ret\n    except RuntimeError as re:\n        re_str = str(re)\n        if \"equal size\" in re_str:\n            if key is not None:\n                re_str += f\"\\nCollate error on the key '{key}' of dictionary data.\"\n            re_str += (\n                \"\\n\\nMONAI hint: if your transforms intentionally create images of different shapes, creating your \"\n                + \"`DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its \"\n                + \"documentation).\"\n            )\n        _ = dev_collate(data)\n        raise RuntimeError(re_str) from re\n    except TypeError as re:\n        re_str = str(re)\n        if \"numpy\" in re_str and \"Tensor\" in re_str:\n            
if key is not None:\n                re_str += f\"\\nCollate error on the key '{key}' of dictionary data.\"\n            re_str += (\n                \"\\n\\nMONAI hint: if your transforms intentionally create mixtures of torch Tensor and numpy ndarray, \"\n                + \"creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem \"\n                + \"(check its documentation).\"\n            )\n        _ = dev_collate(data)\n        raise TypeError(re_str) from re\n\n\ndef _non_zipping_check(batch_data: Mapping | Iterable, detach: bool, pad: bool, fill_value):\n    \"\"\"\n    Utility function based on `decollate_batch`, to identify the largest batch size from the collated data.\n    returns batch_size, the list of non-iterable items, and the dictionary or list with their items decollated.\n\n    See `decollate_batch` for more details.\n    \"\"\"\n    _deco: Mapping | Sequence\n    if isinstance(batch_data, Mapping):\n        _deco = {key: decollate_batch(batch_data[key], detach, pad=pad, fill_value=fill_value) for key in batch_data}\n    elif isinstance(batch_data, Iterable):\n        _deco = [decollate_batch(b, detach, pad=pad, fill_value=fill_value) for b in batch_data]\n    else:\n        raise NotImplementedError(f\"Unable to de-collate: {batch_data}, type: {type(batch_data)}.\")\n    batch_size, non_iterable = 0, []\n    for k, v in _deco.items() if isinstance(_deco, Mapping) else enumerate(_deco):\n        if not isinstance(v, Iterable) or isinstance(v, (str, bytes)) or (isinstance(v, torch.Tensor) and v.ndim == 0):\n            # Not running the usual list decollate here:\n            # don't decollate ['test', 'test'] into [['t', 't'], ['e', 'e'], ['s', 's'], ['t', 't']]\n            # torch.tensor(0) is iterable but iter(torch.tensor(0)) raises TypeError: iteration over a 0-d tensor\n            non_iterable.append(k)\n        elif isinstance(v, Sized):\n            batch_size = max(batch_size, len(v))\n    return 
batch_size, non_iterable, _deco\n\n\ndef decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):\n    \"\"\"De-collate a batch of data (for example, as produced by a `DataLoader`).\n\n    Returns a list of structures with the original tensor's 0-th dimension sliced into elements using `torch.unbind`.\n\n    Images originally stored as (B,C,H,W,[D]) will be returned as (C,H,W,[D]). Other information,\n    such as metadata, may have been stored in a list (or a list inside nested dictionaries). In\n    this case we return the element of the list corresponding to the batch idx.\n\n    Return types aren't guaranteed to be the same as the original, since numpy arrays will have been\n    converted to torch.Tensor, sequences may be converted to lists of tensors,\n    mappings may be converted into dictionaries.\n\n    For example:\n\n    .. code-block:: python\n\n        batch_data = {\n            \"image\": torch.rand((2,1,10,10)),\n            DictPostFix.meta(\"image\"): {\"scl_slope\": torch.Tensor([0.0, 0.0])}\n        }\n        out = decollate_batch(batch_data)\n        print(len(out))\n        >>> 2\n\n        print(out[0])\n        >>> {'image': tensor([[[4.3549e-01...43e-01]]]), DictPostFix.meta(\"image\"): {'scl_slope': 0.0}}\n\n        batch_data = [torch.rand((2,1,10,10)), torch.rand((2,3,5,5))]\n        out = decollate_batch(batch_data)\n        print(out[0])\n        >>> [tensor([[[4.3549e-01...43e-01]]], tensor([[[5.3435e-01...45e-01]]])]\n\n        batch_data = torch.rand((2,1,10,10))\n        out = decollate_batch(batch_data)\n        print(out[0])\n        >>> tensor([[[4.3549e-01...43e-01]]])\n\n        batch_data = {\n            \"image\": [1, 2, 3], \"meta\": [4, 5],  # undetermined batch size\n        }\n        out = decollate_batch(batch_data, pad=True, fill_value=0)\n        print(out)\n        >>> [{'image': 1, 'meta': 4}, {'image': 2, 'meta': 5}, {'image': 3, 'meta': 0}]\n        out = decollate_batch(batch_data, pad=False)\n   
     print(out)\n        >>> [{'image': 1, 'meta': 4}, {'image': 2, 'meta': 5}]\n\n    Args:\n        batch: data to be de-collated.\n        detach: whether to detach the tensors. Scalars tensors will be detached into number types\n            instead of torch tensors.\n        pad: when the items in a batch indicate different batch size, whether to pad all the sequences to the longest.\n            If False, the batch size will be the length of the shortest sequence.\n        fill_value: when `pad` is True, the `fillvalue` to use when padding, defaults to `None`.\n    \"\"\"\n    if batch is None:\n        return batch\n    if isinstance(batch, (float, int, str, bytes)) or (\n        type(batch).__module__ == \"numpy\" and not isinstance(batch, Iterable)\n    ):\n        return batch\n    # if scalar tensor/array, return the item itself.\n    if getattr(batch, \"ndim\", -1) == 0 and hasattr(batch, \"item\"):\n        return batch.item() if detach else batch\n    if isinstance(batch, torch.Tensor):\n        if detach:\n            batch = batch.detach()\n        out_list = torch.unbind(batch, dim=0)\n        # if of type MetaObj, decollate the metadata\n        if isinstance(batch, MetaObj):\n            for t, m in zip(out_list, decollate_batch(batch.meta)):\n                if isinstance(t, MetaObj):\n                    t.meta = m\n                    t.is_batch = False\n            for t, m in zip(out_list, batch.applied_operations):\n                if isinstance(t, MetaObj):\n                    t.applied_operations = m\n                    t.is_batch = False\n        if out_list[0].ndim == 0 and detach:\n            return [t.item() for t in out_list]\n        return list(out_list)\n\n    b, non_iterable, deco = _non_zipping_check(batch, detach, pad, fill_value)\n    if b <= 0:  # all non-iterable, single item \"batch\"? 
{\"image\": 1, \"label\": 1}\n        return deco\n    if pad:  # duplicate non-iterable items to the longest batch\n        for k in non_iterable:\n            deco[k] = [deepcopy(deco[k]) for _ in range(b)]\n    if isinstance(deco, Mapping):\n        _gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values())\n        ret = [dict(zip(deco, item)) for item in _gen]\n        # if not config.USE_META_DICT:\n        # return ret\n        # return pickle_operations(ret, is_encode=False)  # bc 0.9.0\n        return ret\n    if isinstance(deco, Iterable):\n        _gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco)\n        ret_list = [list(item) for item in _gen]\n        # if not config.USE_META_DICT:\n        # return ret_list\n        # return pickle_operations(ret_list, is_encode=False)  # bc 0.9.0\n        return ret_list\n    raise NotImplementedError(f\"Unable to de-collate: {batch}, type: {type(batch)}.\")\n\n\ndef pad_list_data_collate(batch: Sequence, method: str = Method.SYMMETRIC, mode: str = NumpyPadMode.CONSTANT, **kwargs):\n    \"\"\"\n    Function version of :py:class:`monai.transforms.croppad.batch.PadListDataCollate`.\n\n    Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest\n    tensor in each dimension. 
This transform is useful if some of the applied transforms generate batch data of\n    different sizes.\n\n    This can be used on both list and dictionary data.\n    Note that in the case of the dictionary data, this decollate function may add the transform information of\n    `PadListDataCollate` to the list of invertible transforms if input batch have different spatial shape, so need to\n    call static method: `monai.transforms.croppad.batch.PadListDataCollate.inverse` before inverting other transforms.\n\n    Args:\n        batch: batch of data to pad-collate\n        method: padding method (see :py:class:`monai.transforms.SpatialPad`)\n        mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    \"\"\"\n    from monai.transforms.croppad.batch import PadListDataCollate  # needs to be here to avoid circular import\n\n    return PadListDataCollate(method=method, mode=mode, **kwargs)(batch)\n\n\ndef no_collation(x):\n    \"\"\"\n    No any collation operation.\n    \"\"\"\n    return x\n\n\ndef worker_init_fn(worker_id: int) -> None:\n    \"\"\"\n    Callback function for PyTorch DataLoader `worker_init_fn`.\n    It can set different random seed for the transforms in different workers.\n\n    \"\"\"\n    worker_info = torch.utils.data.get_worker_info()\n    set_rnd(worker_info.dataset, seed=worker_info.seed)  # type: ignore[union-attr]\n\n\ndef set_rnd(obj, seed: int) -> int:\n    \"\"\"\n    Set seed or random state for all randomizable properties of obj.\n\n    Args:\n        obj: object to set seed or random state for.\n        seed: set the random state with an integer seed.\n    \"\"\"\n    if isinstance(obj, (tuple, list)):  # ZipDataset.data is a list\n        _seed = seed\n        for item in obj:\n            _seed = set_rnd(item, seed=seed)\n        return seed if _seed == seed 
else seed + 1  # return a different seed if there are randomizable items\n    if not hasattr(obj, \"__dict__\"):\n        return seed  # no attribute\n    if hasattr(obj, \"set_random_state\"):\n        obj.set_random_state(seed=seed % MAX_SEED)\n        return seed + 1  # a different seed for the next component\n    for key in obj.__dict__:\n        if key.startswith(\"__\"):  # skip the private methods\n            continue\n        seed = set_rnd(obj.__dict__[key], seed=seed)\n    return seed\n\n\ndef affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_zeros: bool = True) -> NdarrayTensor:\n    \"\"\"\n    Computing the current spacing from the affine matrix.\n\n    Args:\n        affine: a d x d affine matrix.\n        r: indexing based on the spatial rank, spacing is computed from `affine[:r, :r]`.\n        dtype: data type of the output.\n        suppress_zeros: whether to suppress the zeros with ones.\n\n    Returns:\n        an `r` dimensional vector of spacing.\n    \"\"\"\n    if len(affine.shape) != 2 or affine.shape[0] != affine.shape[1]:\n        raise ValueError(f\"affine must be a square matrix, got {affine.shape}.\")\n    _affine, *_ = convert_to_dst_type(affine[:r, :r], dst=affine, dtype=dtype)\n    if isinstance(_affine, torch.Tensor):\n        spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0))\n    else:\n        spacing = np.sqrt(np.sum(_affine * _affine, axis=0))  # type: ignore[operator]\n    if suppress_zeros:\n        spacing[spacing == 0] = 1.0\n    spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype)\n    return spacing_\n\n\ndef correct_nifti_header_if_necessary(img_nii):\n    \"\"\"\n    Check nifti object header's format, update the header if needed.\n    In the updated image pixdim matches the affine.\n\n    Args:\n        img_nii: nifti image object\n    \"\"\"\n    if img_nii.header.get(\"dim\") is None:\n        return img_nii  # not nifti?\n    dim = img_nii.header[\"dim\"][0]\n    if 
dim >= 5:\n        return img_nii  # do nothing for high-dimensional array\n    # check that affine matches zooms\n    pixdim = np.asarray(img_nii.header.get_zooms())[:dim]\n    norm_affine = affine_to_spacing(img_nii.affine, r=dim)\n    if np.allclose(pixdim, norm_affine):\n        return img_nii\n    if hasattr(img_nii, \"get_sform\"):\n        return rectify_header_sform_qform(img_nii)\n    return img_nii\n\n\ndef rectify_header_sform_qform(img_nii):\n    \"\"\"\n    Look at the sform and qform of the nifti object and correct them if there are any\n    incompatibilities with pixel dimensions\n\n    Adapted from https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/io/misc_io.py\n\n    Args:\n        img_nii: nifti image object\n    \"\"\"\n    d = img_nii.header[\"dim\"][0]\n    pixdim = np.asarray(img_nii.header.get_zooms())[:d]\n    sform, qform = img_nii.get_sform(), img_nii.get_qform()\n    norm_sform = affine_to_spacing(sform, r=d)\n    norm_qform = affine_to_spacing(qform, r=d)\n    sform_mismatch = not np.allclose(norm_sform, pixdim)\n    qform_mismatch = not np.allclose(norm_qform, pixdim)\n\n    if img_nii.header[\"sform_code\"] != 0:\n        if not sform_mismatch:\n            return img_nii\n        if not qform_mismatch:\n            img_nii.set_sform(img_nii.get_qform())\n            return img_nii\n    if img_nii.header[\"qform_code\"] != 0:\n        if not qform_mismatch:\n            return img_nii\n        if not sform_mismatch:\n            img_nii.set_qform(img_nii.get_sform())\n            return img_nii\n\n    norm = affine_to_spacing(img_nii.affine, r=d)\n\n    img_nii.header.set_zooms(norm)\n    return img_nii\n\n\ndef zoom_affine(affine: np.ndarray, scale: np.ndarray | Sequence[float], diagonal: bool = True):\n    \"\"\"\n    To make column norm of `affine` the same as `scale`.  
If diagonal is False,\n    returns an affine that combines orthogonal rotation and the new scale.\n    This is done by first decomposing `affine`, then setting the zoom factors to\n    `scale`, and composing a new affine; the shearing factors are removed.  If\n    diagonal is True, returns a diagonal matrix, the scaling factors are set\n    to the diagonal elements.  This function always returns an affine with zero\n    translations.\n\n    Args:\n        affine (nxn matrix): a square matrix.\n        scale: new scaling factor along each dimension. if the components of the `scale` are non-positive values,\n            will use the corresponding components of the original pixdim, which is computed from the `affine`.\n        diagonal: whether to return a diagonal scaling matrix.\n            Defaults to True.\n\n    Raises:\n        ValueError: When ``affine`` is not a square matrix.\n        ValueError: When ``scale`` contains a nonpositive scalar.\n\n    Returns:\n        the updated `n x n` affine.\n\n    \"\"\"\n\n    affine = np.array(affine, dtype=float, copy=True)\n    if len(affine) != len(affine[0]):\n        raise ValueError(f\"affine must be n x n, got {len(affine)} x {len(affine[0])}.\")\n    scale_np = np.array(scale, dtype=float, copy=True)\n\n    d = len(affine) - 1\n    # compute original pixdim\n    norm = affine_to_spacing(affine, r=d)\n    if len(scale_np) < d:  # defaults based on affine\n        scale_np = np.append(scale_np, norm[len(scale_np) :])\n    scale_np = scale_np[:d]\n    scale_np = np.asarray(fall_back_tuple(scale_np, norm))\n\n    scale_np[scale_np == 0] = 1.0\n    if diagonal:\n        return np.diag(np.append(scale_np, [1.0]))\n    rzs = affine[:-1, :-1]  # rotation zoom scale\n    zs = np.linalg.cholesky(rzs.T @ rzs).T\n    rotation = rzs @ np.linalg.inv(zs)\n    s = np.sign(np.diag(zs)) * np.abs(scale_np)\n    # construct new affine with rotation and zoom\n    new_affine = np.eye(len(affine))\n    new_affine[:-1, :-1] = rotation @ 
np.diag(s)\n    return new_affine\n\n\ndef compute_shape_offset(\n    spatial_shape: np.ndarray | Sequence[int],\n    in_affine: NdarrayOrTensor,\n    out_affine: NdarrayOrTensor,\n    scale_extent: bool = False,\n) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"\n    Given input and output affine, compute appropriate shapes\n    in the output space based on the input array's shape.\n    This function also returns the offset to put the shape\n    in a good position with respect to the world coordinate system.\n\n    Args:\n        spatial_shape: input array's shape\n        in_affine (matrix): 2D affine matrix\n        out_affine (matrix): 2D affine matrix\n        scale_extent: whether the scale is computed based on the spacing or the full extent of voxels, for example, for\n            a factor of 0.5 scaling:\n\n            option 1, \"o\" represents a voxel, scaling the distance between voxels::\n\n                o--o--o\n                o-----o\n\n            option 2, each voxel has a physical extent, scaling the full voxel extent::\n\n                | voxel 1 | voxel 2 | voxel 3 | voxel 4 |\n                |      voxel 1      |      voxel 2      |\n\n            Option 1 may reduce the number of locations that require interpolation. 
Option 2 is more resolution\n            agnostic, that is, resampling coordinates depend on the scaling factor, not on the number of voxels.\n            Default is False, using option 1 to compute the shape and offset.\n\n    \"\"\"\n    shape = np.array(spatial_shape, copy=True, dtype=float)\n    sr = len(shape)\n    in_affine_ = convert_data_type(to_affine_nd(sr, in_affine), np.ndarray)[0]\n    out_affine_ = convert_data_type(to_affine_nd(sr, out_affine), np.ndarray)[0]\n    in_coords = [(-0.5, dim - 0.5) if scale_extent else (0.0, dim - 1.0) for dim in shape]\n    corners: np.ndarray = np.asarray(np.meshgrid(*in_coords, indexing=\"ij\")).reshape((len(shape), -1))\n    corners = np.concatenate((corners, np.ones_like(corners[:1])))\n    try:\n        corners_out = np.linalg.solve(out_affine_, in_affine_) @ corners\n    except np.linalg.LinAlgError as e:\n        raise ValueError(f\"Affine {out_affine_} is not invertible\") from e\n    corners = in_affine_ @ corners\n    all_dist = corners_out[:-1].copy()\n    corners_out = corners_out[:-1] / corners_out[-1]\n    out_shape = np.round(np.ptp(corners_out, axis=1)) if scale_extent else np.round(np.ptp(corners_out, axis=1) + 1.0)\n    offset = None\n    for i in range(corners.shape[1]):\n        min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1)\n        if np.allclose(min_corner, 0.0, rtol=AFFINE_TOL):\n            offset = corners[:-1, i]  # corner is the smallest, shift the corner to origin\n            break\n    if offset is None:  # otherwise make output image center aligned with the input image center\n        offset = in_affine_[:-1, :-1] @ (shape / 2.0) + in_affine_[:-1, -1] - out_affine_[:-1, :-1] @ (out_shape / 2.0)\n    if scale_extent:\n        in_offset = np.append(0.5 * (shape / out_shape - 1.0), 1.0)\n        offset = np.abs((in_affine_ @ in_offset / in_offset[-1])[:-1]) * np.sign(offset)\n    return out_shape.astype(int, copy=False), offset  # type: ignore\n\n\ndef to_affine_nd(r: np.ndarray | 
int, affine: NdarrayTensor, dtype=np.float64) -> NdarrayTensor:\n    \"\"\"\n    Using elements from affine, to create a new affine matrix by\n    assigning the rotation/zoom/scaling matrix and the translation vector.\n\n    When ``r`` is an integer, output is an (r+1)x(r+1) matrix,\n    where the top left kxk elements are copied from ``affine``,\n    the last column of the output affine is copied from ``affine``'s last column.\n    `k` is determined by `min(r, len(affine) - 1)`.\n\n    When ``r`` is an affine matrix, the output has the same shape as ``r``,\n    and the top left kxk elements are copied from ``affine``,\n    the last column of the output affine is copied from ``affine``'s last column.\n    `k` is determined by `min(len(r) - 1, len(affine) - 1)`.\n\n    Args:\n        r (int or matrix): number of spatial dimensions or an output affine to be filled.\n        affine (matrix): 2D affine matrix\n        dtype: data type of the output array.\n\n    Raises:\n        ValueError: When ``affine`` dimensions is not 2.\n        ValueError: When ``r`` is nonpositive.\n\n    Returns:\n        an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type)\n\n    \"\"\"\n    dtype = get_equivalent_dtype(dtype, np.ndarray)\n    affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0]\n    affine_np = affine_np.copy()\n    if affine_np.ndim != 2:\n        raise ValueError(f\"affine must have 2 dimensions, got {affine_np.ndim}.\")\n    new_affine = np.array(r, dtype=dtype, copy=True)\n    if new_affine.ndim == 0:\n        sr: int = int(new_affine.astype(np.uint))\n        if not np.isfinite(sr) or sr < 0:\n            raise ValueError(f\"r must be positive, got {sr}.\")\n        new_affine = np.eye(sr + 1, dtype=dtype)\n    d = max(min(len(new_affine) - 1, len(affine_np) - 1), 1)\n    new_affine[:d, :d] = affine_np[:d, :d]\n    if d > 1:\n        new_affine[:d, -1] = affine_np[:d, -1]\n    output, *_ = 
convert_to_dst_type(new_affine, affine, dtype=dtype)\n    return output\n\n\ndef reorient_spatial_axes(\n    data_shape: Sequence[int], init_affine: NdarrayOrTensor, target_affine: NdarrayOrTensor\n) -> tuple[np.ndarray, NdarrayOrTensor]:\n    \"\"\"\n    Given the input ``init_affine``, compute the orientation transform between\n    it and ``target_affine`` by rearranging/flipping the axes.\n\n    Returns the orientation transform and the updated affine (tensor or ndarray\n    depends on the input ``affine`` data type).\n    Note that this function requires external module ``nibabel.orientations``.\n    \"\"\"\n    init_affine_, *_ = convert_data_type(init_affine, np.ndarray)\n    target_affine_, *_ = convert_data_type(target_affine, np.ndarray)\n    start_ornt = nib.orientations.io_orientation(init_affine_)\n    target_ornt = nib.orientations.io_orientation(target_affine_)\n    try:\n        ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt)\n    except ValueError as e:\n        raise ValueError(f\"The input affine {init_affine} and target affine {target_affine} are not compatible.\") from e\n    new_affine = init_affine_ @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape)\n    new_affine, *_ = convert_to_dst_type(new_affine, init_affine)\n    return ornt_transform, new_affine\n\n\ndef create_file_basename(\n    postfix: str,\n    input_file_name: PathLike,\n    folder_path: PathLike,\n    data_root_dir: PathLike = \"\",\n    separate_folder: bool = True,\n    patch_index=None,\n    makedirs: bool = True,\n) -> str:\n    \"\"\"\n    Utility function to create the path to the output file based on the input\n    filename (file name extension is not added by this function).\n    When ``data_root_dir`` is not specified, the output file name is:\n\n        `folder_path/input_file_name (no ext.) 
/input_file_name (no ext.)[_postfix][_patch_index]`\n\n    otherwise the relative path with respect to ``data_root_dir`` will be inserted, for example:\n\n    .. code-block:: python\n\n        from monai.data import create_file_basename\n        create_file_basename(\n            postfix=\"seg\",\n            input_file_name=\"/foo/bar/test1/image.png\",\n            folder_path=\"/output\",\n            data_root_dir=\"/foo/bar\",\n            separate_folder=True,\n            makedirs=False)\n        # output: /output/test1/image/image_seg\n\n    Args:\n        postfix: output name's postfix\n        input_file_name: path to the input image file.\n        folder_path: path for the output file\n        data_root_dir: if not empty, it specifies the beginning parts of the input file's\n            absolute path. This is used to compute `input_file_rel_path`, the relative path to the file from\n            `data_root_dir` to preserve folder structure when saving in case there are files in different\n            folders with the same file names.\n        separate_folder: whether to save every file in a separate folder, for example: if input filename is\n            `image.nii`, postfix is `seg` and folder_path is `output`, if `True`, save as:\n            `output/image/image_seg.nii`, if `False`, save as `output/image_seg.nii`. 
default to `True`.\n        patch_index: if not None, append the patch index to filename.\n        makedirs: whether to create the folder if it does not exist.\n    \"\"\"\n\n    # get the filename and directory\n    filedir, filename = os.path.split(input_file_name)\n    # remove extension\n    filename, ext = os.path.splitext(filename)\n    if ext == \".gz\":\n        filename, ext = os.path.splitext(filename)\n    # use data_root_dir to find relative path to file\n    filedir_rel_path = \"\"\n    if data_root_dir and filedir:\n        filedir_rel_path = os.path.relpath(filedir, data_root_dir)\n\n    # output folder path will be original name without the extension\n    output = os.path.join(folder_path, filedir_rel_path)\n\n    if separate_folder:\n        output = os.path.join(output, filename)\n\n    if makedirs:\n        # create target folder if no existing\n        os.makedirs(output, exist_ok=True)\n\n    # add the sub-folder plus the postfix name to become the file basename in the output path\n    output = os.path.join(output, filename + \"_\" + postfix if postfix != \"\" else filename)\n\n    if patch_index is not None:\n        output += f\"_{patch_index}\"\n\n    return os.path.normpath(output)\n\n\ndef compute_importance_map(\n    patch_size: tuple[int, ...],\n    mode: BlendMode | str = BlendMode.CONSTANT,\n    sigma_scale: Sequence[float] | float = 0.125,\n    device: torch.device | int | str = \"cpu\",\n    dtype: torch.dtype | str | None = torch.float32,\n) -> torch.Tensor:\n    \"\"\"Get importance map for different weight modes.\n\n    Args:\n        patch_size: Size of the required importance map. This should be either H, W [,D].\n        mode: {``\"constant\"``, ``\"gaussian\"``}\n            How to blend output of overlapping windows. 
Defaults to ``\"constant\"``.\n\n            - ``\"constant``\": gives equal weight to all predictions.\n            - ``\"gaussian``\": gives less weight to predictions on edges of windows.\n\n        sigma_scale: Sigma_scale to calculate sigma for each dimension\n            (sigma = sigma_scale * dim_size). Used for gaussian mode only.\n        device: Device to put importance map on.\n        dtype: Data type of the output importance map.\n\n    Raises:\n        ValueError: When ``mode`` is not one of [\"constant\", \"gaussian\"].\n\n    Returns:\n        Tensor of size patch_size.\n\n    \"\"\"\n    mode = look_up_option(mode, BlendMode)\n    device = torch.device(device)\n    if mode == BlendMode.CONSTANT:\n        importance_map = torch.ones(patch_size, device=device, dtype=torch.float)\n    elif mode == BlendMode.GAUSSIAN:\n        sigma_scale = ensure_tuple_rep(sigma_scale, len(patch_size))\n        sigmas = [i * sigma_s for i, sigma_s in zip(patch_size, sigma_scale)]\n\n        for i in range(len(patch_size)):\n            x = torch.arange(\n                start=-(patch_size[i] - 1) / 2.0, end=(patch_size[i] - 1) / 2.0 + 1, dtype=torch.float, device=device\n            )\n            x = torch.exp(x**2 / (-2 * sigmas[i] ** 2))  # 1D gaussian\n            importance_map = importance_map.unsqueeze(-1) * x[(None,) * i] if i > 0 else x\n    else:\n        raise ValueError(\n            f\"Unsupported mode: {mode}, available options are [{BlendMode.CONSTANT}, {BlendMode.CONSTANT}].\"\n        )\n    # handle non-positive weights\n    min_non_zero = max(torch.min(importance_map).item(), 1e-3)\n    importance_map = torch.clamp_(importance_map.to(torch.float), min=min_non_zero).to(dtype)\n    return importance_map\n\n\ndef is_supported_format(filename: Sequence[PathLike] | PathLike, suffixes: Sequence[str]) -> bool:\n    \"\"\"\n    Verify whether the specified file or files format match supported suffixes.\n    If supported suffixes is None, skip the 
verification and return True.\n\n    Args:\n        filename: file name or a list of file names to read.\n            if a list of files, verify all the suffixes.\n        suffixes: all the supported image suffixes of current reader, must be a list of lower case suffixes.\n\n    \"\"\"\n    filenames: Sequence[PathLike] = ensure_tuple(filename)\n    for name in filenames:\n        full_suffix = \"\".join(map(str.lower, PurePath(name).suffixes))\n        if all(f\".{s.lower()}\" not in full_suffix for s in suffixes):\n            return False\n\n    return True\n\n\ndef partition_dataset(\n    data: Sequence,\n    ratios: Sequence[float] | None = None,\n    num_partitions: int | None = None,\n    shuffle: bool = False,\n    seed: int = 0,\n    drop_last: bool = False,\n    even_divisible: bool = False,\n):\n    \"\"\"\n    Split the dataset into N partitions. It can support shuffle based on specified random seed.\n    Will return a set of datasets, every dataset contains 1 partition of original dataset.\n    And it can split the dataset based on specified ratios or evenly split into `num_partitions`.\n    Refer to: https://pytorch.org/docs/stable/distributed.html#module-torch.distributed.launch.\n\n    Note:\n        It also can be used to partition dataset for ranks in distributed training.\n        For example, partition dataset before training and use `CacheDataset`, every rank trains with its own data.\n        It can avoid duplicated caching content in each rank, but will not do global shuffle before every epoch:\n\n        .. 
code-block:: python\n\n            data_partition = partition_dataset(\n                data=train_files,\n                num_partitions=dist.get_world_size(),\n                shuffle=True,\n                even_divisible=True,\n            )[dist.get_rank()]\n\n            train_ds = SmartCacheDataset(\n                data=data_partition,\n                transform=train_transforms,\n                replace_rate=0.2,\n                cache_num=15,\n            )\n\n    Args:\n        data: input dataset to split, expect a list of data.\n        ratios: a list of ratio number to split the dataset, like [8, 1, 1].\n        num_partitions: expected number of the partitions to evenly split, only works when `ratios` not specified.\n        shuffle: whether to shuffle the original dataset before splitting.\n        seed: random seed to shuffle the dataset, only works when `shuffle` is True.\n        drop_last: only works when `even_divisible` is False and no ratios specified.\n            if True, will drop the tail of the data to make it evenly divisible across partitions.\n            if False, will add extra indices to make the data evenly divisible across partitions.\n        even_divisible: if True, guarantee every partition has same length.\n\n    Examples::\n\n        >>> data = [1, 2, 3, 4, 5]\n        >>> partition_dataset(data, ratios=[0.6, 0.2, 0.2], shuffle=False)\n        [[1, 2, 3], [4], [5]]\n        >>> partition_dataset(data, num_partitions=2, shuffle=False)\n        [[1, 3, 5], [2, 4]]\n        >>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=True)\n        [[1, 3], [2, 4]]\n        >>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=True, drop_last=False)\n        [[1, 3, 5], [2, 4, 1]]\n        >>> partition_dataset(data, num_partitions=2, shuffle=False, even_divisible=False, drop_last=False)\n        [[1, 3, 5], [2, 4]]\n\n    \"\"\"\n    data_len = len(data)\n    datasets = 
[]\n\n    indices = list(range(data_len))\n    if shuffle:\n        # deterministically shuffle based on fixed seed for every process\n        rs = np.random.RandomState(seed)\n        rs.shuffle(indices)\n\n    if ratios:\n        next_idx = 0\n        rsum = sum(ratios)\n        for r in ratios:\n            start_idx = next_idx\n            next_idx = min(start_idx + int(r / rsum * data_len + 0.5), data_len)\n            datasets.append([data[i] for i in indices[start_idx:next_idx]])\n        return datasets\n\n    if not num_partitions:\n        raise ValueError(\"must specify number of partitions or ratios.\")\n    # evenly split the data without ratios\n    if not even_divisible and drop_last:\n        raise RuntimeError(\"drop_last only works when even_divisible is True.\")\n    if data_len < num_partitions:\n        raise RuntimeError(f\"there is no enough data to be split into {num_partitions} partitions.\")\n\n    if drop_last and data_len % num_partitions != 0:\n        # split to nearest available length that is evenly divisible\n        num_samples = math.ceil((data_len - num_partitions) / num_partitions)\n    else:\n        num_samples = math.ceil(data_len / num_partitions)\n    # use original data length if not even divisible\n    total_size = num_samples * num_partitions if even_divisible else data_len\n\n    if not drop_last and total_size - data_len > 0:\n        # add extra samples to make it evenly divisible\n        indices += indices[: (total_size - data_len)]\n    else:\n        # remove tail of data to make it evenly divisible\n        indices = indices[:total_size]\n\n    for i in range(num_partitions):\n        _indices = indices[i:total_size:num_partitions]\n        datasets.append([data[j] for j in _indices])\n\n    return datasets\n\n\ndef partition_dataset_classes(\n    data: Sequence,\n    classes: Sequence[int],\n    ratios: Sequence[float] | None = None,\n    num_partitions: int | None = None,\n    shuffle: bool = False,\n    seed: 
int = 0,\n    drop_last: bool = False,\n    even_divisible: bool = False,\n):\n    \"\"\"\n    Split the dataset into N partitions based on the given class labels.\n    It can make sure the same ratio of classes in every partition.\n    Others are same as :py:class:`monai.data.partition_dataset`.\n\n    Args:\n        data: input dataset to split, expect a list of data.\n        classes: a list of labels to help split the data, the length must match the length of data.\n        ratios: a list of ratio number to split the dataset, like [8, 1, 1].\n        num_partitions: expected number of the partitions to evenly split, only works when no `ratios`.\n        shuffle: whether to shuffle the original dataset before splitting.\n        seed: random seed to shuffle the dataset, only works when `shuffle` is True.\n        drop_last: only works when `even_divisible` is False and no ratios specified.\n            if True, will drop the tail of the data to make it evenly divisible across partitions.\n            if False, will add extra indices to make the data evenly divisible across partitions.\n        even_divisible: if True, guarantee every partition has same length.\n\n    Examples::\n\n        >>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]\n        >>> classes = [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]\n        >>> partition_dataset_classes(data, classes, shuffle=False, ratios=[2, 1])\n        [[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]]\n\n    \"\"\"\n    if not issequenceiterable(classes) or len(classes) != len(data):\n        raise ValueError(f\"length of classes {classes} must match the dataset length {len(data)}.\")\n    datasets = []\n    class_indices = defaultdict(list)\n    for i, c in enumerate(classes):\n        class_indices[c].append(i)\n\n    class_partition_indices: list[Sequence] = []\n    for _, per_class_indices in sorted(class_indices.items()):\n        per_class_partition_indices = partition_dataset(\n            
data=per_class_indices,\n            ratios=ratios,\n            num_partitions=num_partitions,\n            shuffle=shuffle,\n            seed=seed,\n            drop_last=drop_last,\n            even_divisible=even_divisible,\n        )\n        if not class_partition_indices:\n            class_partition_indices = per_class_partition_indices\n        else:\n            for part, data_indices in zip(class_partition_indices, per_class_partition_indices):\n                part += data_indices\n\n    rs = np.random.RandomState(seed)\n    for indices in class_partition_indices:\n        if shuffle:\n            rs.shuffle(indices)\n        datasets.append([data[j] for j in indices])\n\n    return datasets\n\n\ndef resample_datalist(data: Sequence, factor: float, random_pick: bool = False, seed: int = 0):\n    \"\"\"\n    Utility function to resample the loaded datalist for training, for example:\n    If factor < 1.0, randomly pick part of the datalist and set to Dataset, useful to quickly test the program.\n    If factor > 1.0, repeat the datalist to enhance the Dataset.\n\n    Args:\n        data: original datalist to scale.\n        factor: scale factor for the datalist, for example, factor=4.5, repeat the datalist 4 times and plus\n            50% of the original datalist.\n        random_pick: whether to randomly pick data if scale factor has decimal part.\n        seed: random seed to randomly pick data.\n\n    \"\"\"\n    scale, repeats = math.modf(factor)\n    ret: list = list()\n\n    for _ in range(int(repeats)):\n        ret.extend(list(deepcopy(data)))\n    if scale > 1e-6:\n        ret.extend(partition_dataset(data=data, ratios=[scale, 1 - scale], shuffle=random_pick, seed=seed)[0])\n\n    return ret\n\n\ndef select_cross_validation_folds(partitions: Sequence[Iterable], folds: Sequence[int] | int) -> list:\n    \"\"\"\n    Select cross validation data based on data partitions and specified fold index.\n    if a list of fold indices is provided, 
concatenate the partitions of these folds.\n\n    Args:\n        partitions: a sequence of datasets, each item is a iterable\n        folds: the indices of the partitions to be combined.\n\n    Returns:\n        A list of combined datasets.\n\n    Example::\n\n        >>> partitions = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]\n        >>> select_cross_validation_folds(partitions, 2)\n        [5, 6]\n        >>> select_cross_validation_folds(partitions, [1, 2])\n        [3, 4, 5, 6]\n        >>> select_cross_validation_folds(partitions, [-1, 2])\n        [9, 10, 5, 6]\n    \"\"\"\n    return [data_item for fold_id in ensure_tuple(folds) for data_item in partitions[fold_id]]\n\n\ndef json_hashing(item) -> bytes:\n    \"\"\"\n\n    Args:\n        item: data item to be hashed\n\n    Returns: the corresponding hash key\n\n    \"\"\"\n    # TODO: Find way to hash transforms content as part of the cache\n    cache_key = \"\"\n    if sys.version_info.minor < 9:\n        cache_key = hashlib.md5(json.dumps(item, sort_keys=True).encode(\"utf-8\")).hexdigest()\n    else:\n        cache_key = hashlib.md5(\n            json.dumps(item, sort_keys=True).encode(\"utf-8\"), usedforsecurity=False  # type: ignore\n        ).hexdigest()\n    return f\"{cache_key}\".encode()\n\n\ndef pickle_hashing(item, protocol=pickle.HIGHEST_PROTOCOL) -> bytes:\n    \"\"\"\n\n    Args:\n        item: data item to be hashed\n        protocol: protocol version used for pickling,\n            defaults to `pickle.HIGHEST_PROTOCOL`.\n\n    Returns: the corresponding hash key\n\n    \"\"\"\n    cache_key = \"\"\n    if sys.version_info.minor < 9:\n        cache_key = hashlib.md5(pickle.dumps(sorted_dict(item), protocol=protocol)).hexdigest()\n    else:\n        cache_key = hashlib.md5(\n            pickle.dumps(sorted_dict(item), protocol=protocol), usedforsecurity=False  # type: ignore\n        ).hexdigest()\n    return f\"{cache_key}\".encode()\n\n\ndef sorted_dict(item, key=None, reverse=False):\n    
\"\"\"Return a new sorted dictionary from the `item`.\"\"\"\n    if not isinstance(item, dict):\n        return item\n    return {k: sorted_dict(v) if isinstance(v, dict) else v for k, v in sorted(item.items(), key=key, reverse=reverse)}\n\n\ndef convert_tables_to_dicts(\n    dfs,\n    row_indices: Sequence[int | str] | None = None,\n    col_names: Sequence[str] | None = None,\n    col_types: dict[str, dict[str, Any] | None] | None = None,\n    col_groups: dict[str, Sequence[str]] | None = None,\n    **kwargs,\n) -> list[dict[str, Any]]:\n    \"\"\"\n    Utility to join pandas tables, select rows, columns and generate groups.\n    Will return a list of dictionaries, every dictionary maps to a row of data in tables.\n\n    Args:\n        dfs: data table in pandas Dataframe format. if providing a list of tables, will join them.\n        row_indices: indices of the expected rows to load. it should be a list,\n            every item can be a int number or a range `[start, end)` for the indices.\n            for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,\n            load all the rows in the file.\n        col_names: names of the expected columns to load. if None, load all the columns.\n        col_types: `type` and `default value` to convert the loaded columns, if None, use original data.\n            it should be a dictionary, every item maps to an expected column, the `key` is the column\n            name and the `value` is None or a dictionary to define the default value and data type.\n            the supported keys in dictionary are: [\"type\", \"default\"], and note that the value of `default`\n            should not be `None`. 
for example::\n\n                col_types = {\n                    \"subject_id\": {\"type\": str},\n                    \"label\": {\"type\": int, \"default\": 0},\n                    \"ehr_0\": {\"type\": float, \"default\": 0.0},\n                    \"ehr_1\": {\"type\": float, \"default\": 0.0},\n                }\n\n        col_groups: args to group the loaded columns to generate a new column,\n            it should be a dictionary, every item maps to a group, the `key` will\n            be the new column name, the `value` is the names of columns to combine. for example:\n            `col_groups={\"ehr\": [f\"ehr_{i}\" for i in range(10)], \"meta\": [\"meta_1\", \"meta_2\"]}`\n        kwargs: additional arguments for `pandas.merge()` API to join tables.\n\n    \"\"\"\n    df = reduce(lambda l, r: pd.merge(l, r, **kwargs), ensure_tuple(dfs))\n    # parse row indices\n    rows: list[int | str] = []\n    if row_indices is None:\n        rows = df.index.tolist()\n    else:\n        for i in row_indices:\n            if isinstance(i, (tuple, list)):\n                if len(i) != 2:\n                    raise ValueError(\"range of row indices must contain 2 values: start and end.\")\n                rows.extend(list(range(i[0], i[1])))\n            else:\n                rows.append(i)\n\n    # convert to a list of dictionaries corresponding to every row\n    data_ = df.loc[rows] if col_names is None else df.loc[rows, col_names]\n    if isinstance(col_types, dict):\n        # fill default values for NaN\n        defaults = {k: v[\"default\"] for k, v in col_types.items() if v is not None and v.get(\"default\") is not None}\n        if defaults:\n            data_ = data_.fillna(value=defaults)\n        # convert data types\n        types = {k: v[\"type\"] for k, v in col_types.items() if v is not None and \"type\" in v}\n        if types:\n            data_ = data_.astype(dtype=types, copy=False)\n    data: list[dict] = data_.to_dict(orient=\"records\")\n\n    # 
group columns to generate new column\n    if col_groups is not None:\n        groups: dict[str, list] = {}\n        for name, cols in col_groups.items():\n            groups[name] = df.loc[rows, cols].values\n        # invert items of groups to every row of data\n        data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)]\n\n    return data\n\n\ndef orientation_ras_lps(affine: NdarrayTensor) -> NdarrayTensor:\n    \"\"\"\n    Convert the ``affine`` between the `RAS` and `LPS` orientation\n    by flipping the first two spatial dimensions.\n\n    Args:\n        affine: a 2D affine matrix.\n    \"\"\"\n    sr = max(affine.shape[0] - 1, 1)  # spatial rank is at least 1\n    flip_d = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]]\n    flip_diag = flip_d[min(sr - 1, 2)] + [1] * (sr - 3)\n    if isinstance(affine, torch.Tensor):\n        return torch.diag(torch.as_tensor(flip_diag).to(affine)) @ affine  # type: ignore\n    return np.diag(flip_diag).astype(affine.dtype) @ affine  # type: ignore\n\n\ndef remove_keys(data: dict, keys: list[str]) -> None:\n    \"\"\"\n    Remove keys from a dictionary. Operates in-place so nothing is returned.\n\n    Args:\n        data: dictionary to be modified.\n        keys: keys to be deleted from dictionary.\n\n    Returns:\n        `None`\n    \"\"\"\n    for k in keys:\n        _ = data.pop(k, None)\n\n\ndef remove_extra_metadata(meta: dict) -> None:\n    \"\"\"\n    Remove extra metadata from the dictionary. 
Operates in-place so nothing is returned.\n\n    Args:\n        meta: dictionary containing metadata to be modified.\n\n    Returns:\n        `None`\n    \"\"\"\n    keys = get_extra_metadata_keys()\n    remove_keys(data=meta, keys=keys)\n\n\ndef get_extra_metadata_keys() -> list[str]:\n    \"\"\"\n    Get a list of unnecessary keys for metadata that can be removed.\n\n    Returns:\n        List of keys to be removed.\n    \"\"\"\n    keys = [\n        \"srow_x\",\n        \"srow_y\",\n        \"srow_z\",\n        \"quatern_b\",\n        \"quatern_c\",\n        \"quatern_d\",\n        \"qoffset_x\",\n        \"qoffset_y\",\n        \"qoffset_z\",\n        \"dim\",\n        \"pixdim\",\n        *[f\"dim[{i}]\" for i in range(8)],\n        *[f\"pixdim[{i}]\" for i in range(8)],\n    ]\n\n    # TODO: it would be good to remove these, but they are currently being used in the\n    # codebase.\n    # keys += [\n    #     \"original_affine\",\n    #     \"spatial_shape\",\n    #     \"spacing\",\n    # ]\n\n    return keys\n\n\ndef is_no_channel(val) -> bool:\n    \"\"\"Returns whether `val` indicates \"no_channel\", for MetaKeys.ORIGINAL_CHANNEL_DIM.\"\"\"\n    if isinstance(val, torch.Tensor):\n        return bool(torch.isnan(val))\n    if isinstance(val, str):\n        return val == \"no_channel\"\n    if np.isscalar(val):\n        return bool(np.isnan(val))\n    return val is None\n"
  },
  {
    "path": "monai/data/video_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport sys\nimport tempfile\nfrom collections.abc import Callable\nfrom typing import TYPE_CHECKING, Any\n\nimport numpy as np\nfrom torch.utils.data import Dataset, IterableDataset\n\nfrom monai.utils.enums import ColorOrder\nfrom monai.utils.module import optional_import\n\n__all__ = [\"VideoDataset\", \"VideoFileDataset\", \"CameraDataset\"]\n\nif TYPE_CHECKING:\n    import cv2\n\n    has_cv2 = True\nelse:\n    cv2, has_cv2 = None, None\n\n\ndef import_cv():\n    \"\"\"Import cv2. Put it inside a function to avoid webcam lights blinking on ``import monai``.\"\"\"\n    global cv2\n    global has_cv2\n    cv2, has_cv2 = optional_import(\"cv2\")\n\n\nclass SuppressStderr:\n    \"\"\"Suppress stderr. 
Useful as OpenCV (and dependencies) can produce a lot of output.\"\"\"\n\n    def __enter__(self):\n        self.errnull_file = open(os.devnull, \"w\")\n        self.old_stderr_fileno_undup = sys.stderr.fileno()\n        self.old_stderr_fileno = os.dup(sys.stderr.fileno())\n        self.old_stderr = sys.stderr\n        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)\n        sys.stderr = self.errnull_file\n        return self\n\n    def __exit__(self, *_):\n        sys.stderr = self.old_stderr\n        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)\n        os.close(self.old_stderr_fileno)\n        self.errnull_file.close()\n\n\nclass VideoDataset:\n    # import inside class to avoid webcam blinking on ``import monai``.\n    import_cv()\n\n    def __init__(\n        self,\n        video_source: str | int,\n        transform: Callable | None = None,\n        max_num_frames: int | None = None,\n        color_order: str = ColorOrder.RGB,\n        multiprocessing: bool = False,\n        channel_dim: int = 0,\n    ) -> None:\n        \"\"\"\n        Base video dataset.\n\n        Args:\n            video_source: filename of video.\n            transform: transform to be applied to each frame.\n            max_num_frames: Max number of frames to iterate across. If `None` is passed,\n                then the dataset will iterate until the end of the file.\n            color_order: Color order to return frame. Default is RGB.\n            multiprocessing: If `True`, open the video source on the fly. This makes\n                things process-safe, which is useful when combined with a DataLoader\n                with `num_workers>0`. 
However, when using with `num_workers==0`, it\n                makes sense to use `multiprocessing=False`, as the source will then\n                only be opened once, at construction, which will be faster in those\n                circumstances.\n            channel_dim: OpenCV reads with the channel as the last dimension. Use this\n                flag to move it elsewhere. By default this is zero, so the channel\n                dimension is moved to the front.\n\n        Raises:\n            RuntimeError: OpenCV not installed.\n            NotImplementedError: Unknown color order.\n        \"\"\"\n        if not has_cv2:\n            raise RuntimeError(\"OpenCV not installed.\")\n        if color_order not in ColorOrder:\n            raise NotImplementedError\n\n        self.color_order = color_order\n        self.channel_dim = channel_dim\n        self.video_source = video_source\n        self.multiprocessing = multiprocessing\n        if not multiprocessing:\n            self.cap = self.open_video(video_source)\n        self.transform = transform\n        self.max_num_frames = max_num_frames\n\n    @staticmethod\n    def open_video(video_source: str | int):\n        \"\"\"\n        Use OpenCV to open a video source from either file or capture device.\n\n        Args:\n            video_source: filename or index referring to capture device.\n\n        Raises:\n            RuntimeError: Source is a file but file not found.\n            RuntimeError: Failed to open source.\n        \"\"\"\n        if isinstance(video_source, str) and not os.path.isfile(video_source):\n            raise RuntimeError(\"Video file does not exist: \" + video_source)\n        with SuppressStderr():\n            cap = cv2.VideoCapture(video_source)\n        if not cap.isOpened():\n            raise RuntimeError(f\"Failed to open video: {video_source}\")\n        return cap\n\n    def _get_cap(self):\n        \"\"\"Return the cap. If multiprocessing, create a new one. 
Else return the one from construction time.\"\"\"\n        return self.open_video(self.video_source) if self.multiprocessing else self.cap\n\n    def get_fps(self) -> int:\n        \"\"\"Get the FPS of the capture device.\"\"\"\n        return self._get_cap().get(cv2.CAP_PROP_FPS)  # type: ignore\n\n    def get_frame(self) -> Any:\n        \"\"\"Get next frame. For a file, this will be the next frame, whereas for a camera\n        source, it will be the next available frame.\"\"\"\n        ret, frame = self._get_cap().read()\n        if not ret:\n            raise RuntimeError(\"Failed to read frame.\")\n        # Switch color order if desired\n        if self.color_order == ColorOrder.RGB:\n            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n        # move channel dim\n        frame = np.moveaxis(frame, -1, self.channel_dim)\n        return self.transform(frame) if self.transform is not None else frame\n\n\nclass VideoFileDataset(Dataset, VideoDataset):\n    \"\"\"\n    Video dataset from file.\n\n    This class requires that OpenCV be installed.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs) -> None:\n        VideoDataset.__init__(self, *args, **kwargs)\n        num_frames = self.get_num_frames()\n        if self.max_num_frames is None or num_frames < self.max_num_frames:\n            self.max_num_frames = num_frames\n\n    @staticmethod\n    def get_available_codecs() -> dict[str, str]:\n        \"\"\"Try different codecs, see which are available.\n        Returns a dictionary with of available codecs with codecs as keys and file extensions as values.\"\"\"\n        if not has_cv2:\n            return {}\n        all_codecs = {\"mp4v\": \".mp4\", \"X264\": \".avi\", \"H264\": \".mp4\", \"MP42\": \".mp4\", \"MJPG\": \".mjpeg\", \"DIVX\": \".avi\"}\n        codecs = {}\n        with SuppressStderr():\n            with tempfile.TemporaryDirectory() as tmp_dir:\n                for codec, ext in all_codecs.items():\n                    writer = 
cv2.VideoWriter()\n                    fname = os.path.join(tmp_dir, f\"test{ext}\")\n                    fourcc = cv2.VideoWriter_fourcc(*codec)  # type: ignore[attr-defined]\n                    noviderr = writer.open(fname, fourcc, 1, (10, 10))\n                    if noviderr:\n                        codecs[codec] = ext\n                    writer.release()\n        return codecs\n\n    def get_num_frames(self) -> int:\n        \"\"\"\n        Return the number of frames in a video file.\n\n        Raises:\n            RuntimeError: no frames found.\n        \"\"\"\n        num_frames = int(self._get_cap().get(cv2.CAP_PROP_FRAME_COUNT))\n        if num_frames == 0:\n            raise RuntimeError(\"0 frames found\")\n        return num_frames\n\n    def __len__(self):\n        return self.max_num_frames\n\n    def __getitem__(self, index: int) -> Any:\n        \"\"\"\n        Fetch single data item from index.\n        \"\"\"\n        if self.max_num_frames is not None and index >= self.max_num_frames:\n            raise IndexError\n        self._get_cap().set(cv2.CAP_PROP_POS_FRAMES, index)\n        return self.get_frame()\n\n\nclass CameraDataset(IterableDataset, VideoDataset):\n    \"\"\"\n    Video dataset from a capture device (e.g., webcam).\n\n    This class requires that OpenCV be installed.\n\n    Args:\n        video_source: index of capture device.\n            `get_num_devices` can be used to determine possible devices.\n        transform: transform to be applied to each frame.\n        max_num_frames: Max number of frames to iterate across. 
If `None` is passed,\n            then the dataset will iterate infinitely.\n\n    Raises:\n        RuntimeError: OpenCV not installed.\n    \"\"\"\n\n    @staticmethod\n    def get_num_devices() -> int:\n        \"\"\"Get number of possible devices detected by OpenCV that can be used for capture.\"\"\"\n        if not has_cv2:\n            return 0\n        num_devices = 0\n        while True:\n            cap = cv2.VideoCapture(num_devices)\n            if not cap.read()[0]:\n                break\n            num_devices += 1\n            cap.release()\n        return num_devices\n\n    def __iter__(self):\n        frame_count = 0\n        while True:\n            frame = self.get_frame()\n            frame_count += 1\n            yield frame\n            if self.max_num_frames is not None:\n                if frame_count == self.max_num_frames:\n                    break\n"
  },
  {
    "path": "monai/data/wsi_datasets.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport inspect\nimport os\nfrom collections.abc import Callable, Sequence\n\nimport numpy as np\nimport torch\n\nfrom monai.data import Dataset\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import iter_patch_position\nfrom monai.data.wsi_reader import BaseWSIReader, WSIReader\nfrom monai.transforms import ForegroundMask, Randomizable, apply_transform\nfrom monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep\nfrom monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys\n\n__all__ = [\"PatchWSIDataset\", \"SlidingPatchWSIDataset\", \"MaskedPatchWSIDataset\"]\n\n\nclass PatchWSIDataset(Dataset):\n    \"\"\"\n    This dataset extracts patches from whole slide images (without loading the whole image)\n    It also reads labels for each patch and provides each patch with its associated class labels.\n\n    Args:\n        data: the list of input samples including image, location, and label (see the note below for more details).\n        patch_size: the size of patch to be extracted from the whole slide image.\n        patch_level: the level at which the patches to be extracted (default to 0).\n        transform: transforms to be executed on input data.\n        include_label: whether to load and include labels in the output\n        center_location: whether the input location information is 
the position of the center of the patch\n        additional_meta_keys: the list of keys for items to be copied to the output metadata from the input data\n        reader: the module to be used for loading whole slide imaging. If `reader` is\n\n            - a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.\n            - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.\n            - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.\n\n        kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class\n\n    Returns:\n        dict: a dictionary of loaded image (in MetaTensor format) along with the labels (if requested).\n        {\"image\": MetaTensor, \"label\": torch.Tensor}\n\n    Note:\n        The input data has the following form as an example:\n\n        .. code-block:: python\n\n            [\n                {\"image\": \"path/to/image1.tiff\", \"location\": [200, 500], \"label\": 0},\n                {\"image\": \"path/to/image2.tiff\", \"location\": [100, 700], \"patch_size\": [20, 20], \"patch_level\": 2, \"label\": 1}\n            ]\n\n    \"\"\"\n\n    def __init__(\n        self,\n        data: Sequence,\n        patch_size: int | tuple[int, int] | None = None,\n        patch_level: int | None = None,\n        transform: Callable | None = None,\n        include_label: bool = True,\n        center_location: bool = True,\n        additional_meta_keys: Sequence[str] | None = None,\n        reader=\"cuCIM\",\n        **kwargs,\n    ):\n        super().__init__(data, transform)\n\n        # Ensure patch size is a two dimensional tuple\n        if patch_size is None:\n            self.patch_size = None\n        else:\n            self.patch_size = ensure_tuple_rep(patch_size, 2)\n\n        # Create a default level that override all levels if it is not None\n        self.patch_level = patch_level\n        # Set the default 
WSIReader's level to 0 if level is not provided\n        if patch_level is None:\n            patch_level = 0\n\n        # Setup the WSI reader\n        self.wsi_reader: WSIReader | BaseWSIReader\n        if isinstance(reader, str):\n            self.wsi_reader = WSIReader(backend=reader, level=patch_level, **kwargs)\n        elif inspect.isclass(reader) and issubclass(reader, BaseWSIReader):\n            self.wsi_reader = reader(level=patch_level, **kwargs)\n        elif isinstance(reader, BaseWSIReader):\n            self.wsi_reader = reader\n        else:\n            raise ValueError(f\"Unsupported reader type: {reader}.\")\n        self.backend = self.wsi_reader.backend\n\n        self.include_label = include_label\n        self.center_location = center_location\n        self.additional_meta_keys = additional_meta_keys or []\n\n        # Initialized an empty whole slide image object dict\n        self.wsi_object_dict: dict = {}\n\n    def _get_wsi_object(self, sample: dict):\n        image_path = sample[CommonKeys.IMAGE]\n        if image_path not in self.wsi_object_dict:\n            self.wsi_object_dict[image_path] = self.wsi_reader.read(image_path)\n        return self.wsi_object_dict[image_path]\n\n    def _get_label(self, sample: dict):\n        return torch.tensor(sample[CommonKeys.LABEL], dtype=torch.float32)\n\n    def _get_location(self, sample: dict):\n        if self.center_location:\n            size = self._get_size(sample)\n            return ensure_tuple(sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size)))\n        else:\n            return ensure_tuple(sample[WSIPatchKeys.LOCATION])\n\n    def _get_level(self, sample: dict):\n        if self.patch_level is None:\n            return sample.get(WSIPatchKeys.LEVEL, 0)\n        return self.patch_level\n\n    def _get_size(self, sample: dict):\n        if self.patch_size is None:\n            return ensure_tuple_rep(sample.get(WSIPatchKeys.SIZE), 2)\n        return 
self.patch_size\n\n    def _get_data(self, sample: dict):\n        # Don't store OpenSlide objects to avoid issues with OpenSlide internal cache\n        if self.backend == \"openslide\":\n            self.wsi_object_dict = {}\n        wsi_obj = self._get_wsi_object(sample)\n        location = self._get_location(sample)\n        level = self._get_level(sample)\n        size = self._get_size(sample)\n        return self.wsi_reader.get_data(wsi=wsi_obj, location=location, size=size, level=level)\n\n    def _transform(self, index: int):\n        # Get a single entry of data\n        sample: dict = self.data[index]\n\n        # Extract patch image and associated metadata\n        image, metadata = self._get_data(sample)\n\n        # Add additional metadata from sample\n        for key in self.additional_meta_keys:\n            metadata[key] = sample[key]\n\n        # Create MetaTensor output for image\n        output = {CommonKeys.IMAGE: MetaTensor(image, meta=metadata)}\n\n        # Include label in the output\n        if self.include_label:\n            output[CommonKeys.LABEL] = self._get_label(sample)\n\n        # Apply transforms and return it\n        return apply_transform(self.transform, output) if self.transform else output\n\n\nclass SlidingPatchWSIDataset(Randomizable, PatchWSIDataset):\n    \"\"\"\n    This dataset extracts patches in sliding-window manner from whole slide images (without loading the whole image).\n    It also reads labels for each patch and provides each patch with its associated class labels.\n\n    Args:\n        data: the list of input samples including image, location, and label (see the note below for more details).\n        patch_size: the size of patch to be extracted from the whole slide image.\n        patch_level: the level at which the patches to be extracted (default to 0).\n        mask_level: the resolution level at which the mask/map is created (for `ProbMapProducer` for instance).\n        overlap: the amount of overlap of 
neighboring patches in each dimension (a value between 0.0 and 1.0).\n            If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.\n        offset: the offset of image to extract patches (the starting position of the upper left patch).\n        offset_limits: if offset is set to \"random\", a tuple of integers defining the lower and upper limit of the\n            random offset for all dimensions, or a tuple of tuples that defines the limits for each dimension.\n        transform: transforms to be executed on input data.\n        include_label: whether to load and include labels in the output\n        center_location: whether the input location information is the position of the center of the patch\n        additional_meta_keys: the list of keys for items to be copied to the output metadata from the input data\n        reader: the module to be used for loading whole slide imaging. Defaults to cuCIM. If `reader` is\n\n            - a string, it defines the backend of `monai.data.WSIReader`.\n            - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader,\n            - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.\n\n        seed: random seed to randomly generate offsets. Defaults to 0.\n        kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class\n\n    Note:\n        The input data has the following form as an example:\n\n        .. 
code-block:: python\n\n            [\n                {\"image\": \"path/to/image1.tiff\"},\n                {\"image\": \"path/to/image2.tiff\", \"patch_size\": [20, 20], \"patch_level\": 2}\n            ]\n\n        Unlike `MaskedPatchWSIDataset`, this dataset does not filter any patches.\n    \"\"\"\n\n    def __init__(\n        self,\n        data: Sequence,\n        patch_size: int | tuple[int, int] | None = None,\n        patch_level: int | None = None,\n        mask_level: int = 0,\n        overlap: tuple[float, float] | float = 0.0,\n        offset: tuple[int, int] | int | str = (0, 0),\n        offset_limits: tuple[tuple[int, int], tuple[int, int]] | tuple[int, int] | None = None,\n        transform: Callable | None = None,\n        include_label: bool = False,\n        center_location: bool = False,\n        additional_meta_keys: Sequence[str] = (ProbMapKeys.LOCATION, ProbMapKeys.SIZE, ProbMapKeys.COUNT),\n        reader=\"cuCIM\",\n        seed: int = 0,\n        **kwargs,\n    ):\n        super().__init__(\n            data=[],\n            patch_size=patch_size,\n            patch_level=patch_level,\n            transform=transform,\n            include_label=include_label,\n            center_location=center_location,\n            additional_meta_keys=additional_meta_keys,\n            reader=reader,\n            **kwargs,\n        )\n        self.overlap = overlap\n        self.set_random_state(seed)\n        # Set the offset config\n        self.random_offset = False\n        if isinstance(offset, str):\n            if offset == \"random\":\n                self.random_offset = True\n                self.offset_limits: tuple[tuple[int, int], tuple[int, int]] | None\n                if offset_limits is None:\n                    self.offset_limits = None\n                elif isinstance(offset_limits, tuple):\n                    if isinstance(offset_limits[0], int):\n                        self.offset_limits = (offset_limits, offset_limits)\n       
             elif isinstance(offset_limits[0], tuple):\n                        self.offset_limits = offset_limits\n                    else:\n                        raise ValueError(\n                            \"The offset limits should be either a tuple of integers or tuple of tuple of integers.\"\n                        )\n                else:\n                    raise ValueError(\"The offset limits should be a tuple.\")\n            else:\n                raise ValueError(\n                    f'Invalid string for offset \"{offset}\". It should be either \"random\" as a string,'\n                    \"an integer, or a tuple of integers defining the offset.\"\n                )\n        else:\n            self.offset = ensure_tuple_rep(offset, 2)\n\n        self.mask_level = mask_level\n        # Create single sample for each patch (in a sliding window manner)\n        self.data: list\n        self.image_data = list(data)\n        for sample in self.image_data:\n            patch_samples = self._evaluate_patch_locations(sample)\n            self.data.extend(patch_samples)\n\n    def _get_offset(self, sample):\n        if self.random_offset:\n            if self.offset_limits is None:\n                offset_limits = tuple((-s, s) for s in self._get_size(sample))\n            else:\n                offset_limits = self.offset_limits\n            return tuple(self.R.randint(low, high) for low, high in offset_limits)\n        return self.offset\n\n    def _evaluate_patch_locations(self, sample):\n        \"\"\"Calculate the location for each patch in a sliding-window manner\"\"\"\n        patch_size = self._get_size(sample)\n        patch_level = self._get_level(sample)\n        wsi_obj = self._get_wsi_object(sample)\n\n        # calculate the locations\n        wsi_size = self.wsi_reader.get_size(wsi_obj, 0)\n        mask_ratio = self.wsi_reader.get_downsample_ratio(wsi_obj, self.mask_level)\n        patch_ratio = 
self.wsi_reader.get_downsample_ratio(wsi_obj, patch_level)\n        patch_size_0 = np.array([p * patch_ratio for p in patch_size])  # patch size at level 0\n        offset = self._get_offset(sample)\n        patch_locations = np.array(\n            list(\n                iter_patch_position(\n                    image_size=wsi_size, patch_size=patch_size_0, start_pos=offset, overlap=self.overlap, padded=False\n                )\n            )\n        )\n        # convert locations to mask_location\n        mask_locations = np.round((patch_locations + patch_size_0 // 2) / float(mask_ratio))\n\n        # fill out samples with location and metadata\n        sample[WSIPatchKeys.SIZE.value] = patch_size\n        sample[WSIPatchKeys.LEVEL.value] = patch_level\n        sample[ProbMapKeys.NAME.value] = os.path.basename(sample[CommonKeys.IMAGE])\n        sample[ProbMapKeys.COUNT.value] = len(patch_locations)\n        sample[ProbMapKeys.SIZE.value] = np.array(self.wsi_reader.get_size(wsi_obj, self.mask_level))\n        return [\n            {**sample, WSIPatchKeys.LOCATION.value: np.array(loc), ProbMapKeys.LOCATION.value: mask_loc}\n            for loc, mask_loc in zip(patch_locations, mask_locations)\n        ]\n\n\nclass MaskedPatchWSIDataset(PatchWSIDataset):\n    \"\"\"\n    This dataset extracts patches from whole slide images at the locations where foreground mask\n    at a given level is non-zero.\n\n    Args:\n        data: the list of input samples including image, location, and label (see the note below for more details).\n        patch_size: the size of patch to be extracted from the whole slide image.\n        patch_level: the level at which the patches to be extracted (default to 0).\n        mask_level: the resolution level at which the mask is created.\n        transform: transforms to be executed on input data.\n        include_label: whether to load and include labels in the output\n        center_location: whether the input location information is the 
position of the center of the patch\n        additional_meta_keys: the list of keys for items to be copied to the output metadata from the input data\n        reader: the module to be used for loading whole slide imaging. Defaults to cuCIM. If `reader` is\n\n            - a string, it defines the backend of `monai.data.WSIReader`.\n            - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader,\n            - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.\n\n        kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class\n\n    Note:\n        The input data has the following form as an example:\n\n        .. code-block:: python\n\n            [\n                {\"image\": \"path/to/image1.tiff\"},\n                {\"image\": \"path/to/image2.tiff\", \"size\": [20, 20], \"level\": 2}\n            ]\n\n    \"\"\"\n\n    def __init__(\n        self,\n        data: Sequence,\n        patch_size: int | tuple[int, int] | None = None,\n        patch_level: int | None = None,\n        mask_level: int = 7,\n        transform: Callable | None = None,\n        include_label: bool = False,\n        center_location: bool = False,\n        additional_meta_keys: Sequence[str] = (ProbMapKeys.LOCATION, ProbMapKeys.NAME),\n        reader=\"cuCIM\",\n        **kwargs,\n    ):\n        super().__init__(\n            data=[],\n            patch_size=patch_size,\n            patch_level=patch_level,\n            transform=transform,\n            include_label=include_label,\n            center_location=center_location,\n            additional_meta_keys=additional_meta_keys,\n            reader=reader,\n            **kwargs,\n        )\n\n        self.mask_level = mask_level\n        # Create single sample for each patch (in a sliding window manner)\n        self.data: list\n        self.image_data = list(data)\n        for sample in self.image_data:\n            patch_samples = 
self._evaluate_patch_locations(sample)\n            self.data.extend(patch_samples)\n\n    def _evaluate_patch_locations(self, sample):\n        \"\"\"Calculate the location for each patch based on the mask at different resolution level\"\"\"\n        patch_size = self._get_size(sample)\n        patch_level = self._get_level(sample)\n        wsi_obj = self._get_wsi_object(sample)\n\n        # load the entire image at level=mask_level\n        wsi, _ = self.wsi_reader.get_data(wsi_obj, level=self.mask_level)\n\n        # create the foreground tissue mask and get all indices for non-zero pixels\n        mask = np.squeeze(convert_to_dst_type(ForegroundMask(hsv_threshold={\"S\": \"otsu\"})(wsi), dst=wsi)[0])\n        mask_locations = np.vstack(mask.nonzero()).T\n\n        # convert mask locations to image locations at level=0\n        mask_ratio = self.wsi_reader.get_downsample_ratio(wsi_obj, self.mask_level)\n        patch_ratio = self.wsi_reader.get_downsample_ratio(wsi_obj, patch_level)\n        patch_size_0 = np.array([p * patch_ratio for p in patch_size])  # patch size at level 0\n        patch_locations = np.round((mask_locations + 0.5) * float(mask_ratio) - patch_size_0 // 2).astype(int)\n\n        # fill out samples with location and metadata\n        sample[WSIPatchKeys.SIZE.value] = patch_size\n        sample[WSIPatchKeys.LEVEL.value] = patch_level\n        sample[ProbMapKeys.NAME.value] = os.path.basename(sample[CommonKeys.IMAGE])\n        sample[ProbMapKeys.COUNT.value] = len(patch_locations)\n        sample[ProbMapKeys.SIZE.value] = mask.shape\n        return [\n            {**sample, WSIPatchKeys.LOCATION.value: np.array(loc), ProbMapKeys.LOCATION.value: mask_loc}\n            for loc, mask_loc in zip(patch_locations, mask_locations)\n        ]\n"
  },
  {
    "path": "monai/data/wsi_reader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom abc import abstractmethod\nfrom collections.abc import Sequence\nfrom os.path import abspath\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.config import DtypeLike, NdarrayOrTensor, PathLike\nfrom monai.data.image_reader import ImageReader, _stack_images\nfrom monai.data.utils import is_supported_format\nfrom monai.utils import (\n    WSIPatchKeys,\n    dtype_numpy_to_torch,\n    dtype_torch_to_numpy,\n    ensure_tuple,\n    ensure_tuple_rep,\n    optional_import,\n    require_pkg,\n)\nfrom monai.utils.misc import ConvertUnits\n\nOpenSlide, _ = optional_import(\"openslide\", name=\"OpenSlide\")\nTiffFile, _ = optional_import(\"tifffile\", name=\"TiffFile\")\n\n__all__ = [\"BaseWSIReader\", \"WSIReader\", \"CuCIMWSIReader\", \"OpenSlideWSIReader\", \"TiffFileWSIReader\"]\n\n\nclass BaseWSIReader(ImageReader):\n    \"\"\"\n    An abstract class that defines APIs to load patches from whole slide image files.\n\n    Args:\n        level: the whole slide image level at which the patches are extracted.\n        mpp: the resolution in micron per pixel at which the patches are extracted.\n        mpp_rtol: the acceptable relative tolerance for resolution in micro per pixel.\n        mpp_atol: the acceptable absolute tolerance for resolution in micro per pixel.\n        power: the objective 
power at which the patches are extracted.\n        power_rtol: the acceptable relative tolerance for objective power.\n        power_atol: the acceptable absolute tolerance for objective power.\n        channel_dim: the desired dimension for color channel.\n        dtype: the data type of output image.\n        device: target device to put the extracted patch. Note that if device is \"cuda\"\",\n            the output will be converted to torch tenor and sent to the gpu even if the dtype is numpy.\n        mode: the output image color mode, e.g., \"RGB\" or \"RGBA\".\n        kwargs: additional args for the reader\n\n        Notes:\n            Only one of resolution parameters, `level`, `mpp`, or `power`, should be provided.\n            If such parameters are provided in `get_data` method, those will override the values provided here.\n            If none of them are provided here or in `get_data`, `level=0` will be used.\n\n    Typical usage of a concrete implementation of this class is:\n\n    .. 
code-block:: python\n\n        image_reader = MyWSIReader()\n        wsi = image_reader.read(filepath, **kwargs)\n        img_data, meta_data = image_reader.get_data(wsi)\n\n    - The `read` call converts an image filename into whole slide image object,\n    - The `get_data` call fetches the image data, as well as metadata.\n\n    The following methods needs to be implemented for any concrete implementation of this class:\n\n    - `read` reads a whole slide image object from a given file\n    - `get_size` returns the size of the whole slide image of a given wsi object at a given level.\n    - `get_level_count` returns the number of levels in the whole slide image\n    - `_get_patch` extracts and returns a patch image form the whole slide image\n    - `_get_metadata` extracts and returns metadata for a whole slide image and a specific patch.\n\n\n    \"\"\"\n\n    supported_suffixes: list[str] = []\n    backend = \"\"\n\n    def __init__(\n        self,\n        level: int | None = None,\n        mpp: float | tuple[float, float] | None = None,\n        mpp_rtol: float = 0.05,\n        mpp_atol: float = 0.0,\n        power: int | None = None,\n        power_rtol: float = 0.05,\n        power_atol: float = 0.0,\n        channel_dim: int = 0,\n        dtype: DtypeLike | torch.dtype = np.uint8,\n        device: torch.device | str | None = None,\n        mode: str = \"RGB\",\n        **kwargs,\n    ):\n        super().__init__()\n        self.level = level\n        self.channel_dim = channel_dim\n        self.set_dtype(dtype)\n        self.set_device(device)\n        self.mode = mode\n        self.kwargs = kwargs\n        self.mpp: tuple[float, float] | None = ensure_tuple_rep(mpp, 2) if mpp is not None else None\n        self.power = power\n        self.mpp_rtol = mpp_rtol\n        self.mpp_atol = mpp_atol\n        self.power_rtol = power_rtol\n        self.power_atol = power_atol\n        self.metadata: dict[Any, Any] = {}\n\n    def set_dtype(self, dtype):\n        
self.dtype: torch.dtype | np.dtype\n        if isinstance(dtype, torch.dtype):\n            self.dtype = dtype\n        else:\n            self.dtype = np.dtype(dtype)\n\n    def set_device(self, device):\n        if device is None or isinstance(device, (torch.device, str)):\n            self.device = device\n        else:\n            raise ValueError(f\"`device` must be `torch.device`, `str` or `None` but {type(device)} is given.\")\n\n    @abstractmethod\n    def get_size(self, wsi, level: int) -> tuple[int, int]:\n        \"\"\"\n        Returns the size (height, width) of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the size is calculated.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    def _find_closest_level(\n        self, name: str, value: tuple, value_at_levels: Sequence[tuple], atol: float, rtol: float\n    ) -> int:\n        \"\"\"Find the level corresponding to the value of the quantity in the list of values at each level.\n        Args:\n            name: the name of the requested quantity\n            value: the value of requested quantity\n            value_at_levels: list of value of the quantity at each level\n            atol: the tolerance for the value\n            rtol: relative tolerance for the value\n        \"\"\"\n        if value in value_at_levels:\n            return value_at_levels.index(value)\n\n        closest_value = min(value_at_levels, key=lambda a_value: sum([abs(x - y) for x, y in zip(a_value, value)]))\n        for i in range(len(value)):\n            if abs(closest_value[i] - value[i]) > atol + rtol * abs(value[i]):\n                raise ValueError(\n                    f\"The requested {name} < {value} > does not exist in this whole slide image \"\n                    f\"(with {name}_rtol={rtol} and {name}_atol={atol}). 
\"\n                    f\"Here is the list of available {name}: {value_at_levels}. \"\n                    f\"The closest matching available {name} is {closest_value}.\"\n                    f\"Please consider changing the tolerances or use another {name}.\"\n                )\n        return value_at_levels.index(closest_value)\n\n    def get_valid_level(\n        self, wsi, level: int | None, mpp: float | tuple[float, float] | None, power: int | None\n    ) -> int:\n        \"\"\"\n        Returns the level associated to the resolution parameters in the whole slide image.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number.\n            mpp: the micron-per-pixel resolution.\n            power: the objective power.\n\n        \"\"\"\n\n        # Try instance parameters if no resolution is provided\n        if mpp is None and power is None and level is None:\n            mpp = self.mpp\n            power = self.power\n            level = self.level\n\n        # Ensure that at most one resolution parameter is provided.\n        resolution = [val[0] for val in [(\"level\", level), (\"mpp\", mpp), (\"power\", power)] if val[1] is not None]\n        if len(resolution) > 1:\n            raise ValueError(f\"Only one of `level`, `mpp`, or `power` should be provided. 
{resolution} are provided.\")\n\n        n_levels = self.get_level_count(wsi)\n\n        if mpp is not None:\n            mpp_ = ensure_tuple_rep(mpp, 2)\n            available_mpps = [self.get_mpp(wsi, level) for level in range(n_levels)]\n            level = self._find_closest_level(\"mpp\", mpp_, available_mpps, self.mpp_atol, self.mpp_rtol)\n        elif power is not None:\n            power_ = ensure_tuple(power)\n            available_powers = [(self.get_power(wsi, level),) for level in range(n_levels)]\n            level = self._find_closest_level(\"power\", power_, available_powers, self.power_atol, self.power_rtol)\n        else:\n            if level is None:\n                # Set the default value if no resolution parameter is provided.\n                level = 0\n            if level >= n_levels:\n                raise ValueError(\n                    f\"The maximum level of this image is {n_levels - 1} while level={level} is requested)!\"\n                )\n\n        return level\n\n    @abstractmethod\n    def get_level_count(self, wsi) -> int:\n        \"\"\"\n        Returns the number of levels in the whole slide image.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    @abstractmethod\n    def get_downsample_ratio(self, wsi, level: int) -> float:\n        \"\"\"\n        Returns the down-sampling ratio of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the downsample ratio is calculated.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    @abstractmethod\n    def get_file_path(self, wsi) -> str:\n        \"\"\"Return the file path for the WSI object\"\"\"\n        raise 
NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    @abstractmethod\n    def get_mpp(self, wsi, level: int) -> tuple[float, float]:\n        \"\"\"\n        Returns the micro-per-pixel resolution of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the mpp is calculated.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    @abstractmethod\n    def get_power(self, wsi, level: int) -> float:\n        \"\"\"\n        Returns the objective power of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the objective power is calculated.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    @abstractmethod\n    def _get_patch(\n        self, wsi, location: tuple[int, int], size: tuple[int, int], level: int, dtype: DtypeLike, mode: str\n    ) -> np.ndarray:\n        \"\"\"\n        Extracts and returns a patch image form the whole slide image.\n\n        Args:\n            wsi: a whole slide image object loaded from a file or a lis of such objects\n            location: (top, left) tuple giving the top left pixel in the level 0 reference frame. 
Defaults to (0, 0).\n            size: (height, width) tuple giving the patch size at the given level (`level`).\n                If None, it is set to the full image size at the given level.\n            level: the level number.\n            dtype: the data type of output image.\n            mode: the output image mode, 'RGB' or 'RGBA'.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    def _get_metadata(\n        self, wsi, patch: NdarrayOrTensor, location: tuple[int, int], size: tuple[int, int], level: int\n    ) -> dict:\n        \"\"\"\n        Returns metadata of the extracted patch from the whole slide image.\n\n        Args:\n            wsi: the whole slide image object, from which the patch is loaded.\n            patch: extracted patch from whole slide image.\n            location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).\n            size: (height, width) tuple giving the patch size at the given level (`level`).\n                If None, it is set to the full image size at the given level.\n            level: the level number.\n\n        \"\"\"\n        if self.channel_dim >= len(patch.shape) or self.channel_dim < -len(patch.shape):\n            raise ValueError(\n                f\"The desired channel_dim ({self.channel_dim}) is out of bound for image shape: {patch.shape}\"\n            )\n        channel_dim: int = self.channel_dim + (len(patch.shape) if self.channel_dim < 0 else 0)\n        metadata: dict = {\n            \"backend\": self.backend,\n            \"original_channel_dim\": channel_dim,\n            \"spatial_shape\": np.array(patch.shape[:channel_dim] + patch.shape[channel_dim + 1 :]),\n            WSIPatchKeys.COUNT.value: 1,\n            WSIPatchKeys.PATH.value: self.get_file_path(wsi),\n            WSIPatchKeys.LOCATION.value: np.asarray(location),\n            WSIPatchKeys.SIZE.value: 
np.asarray(size),\n            WSIPatchKeys.LEVEL.value: level,\n        }\n        return metadata\n\n    def get_data(\n        self,\n        wsi,\n        location: tuple[int, int] = (0, 0),\n        size: tuple[int, int] | None = None,\n        level: int | None = None,\n        mpp: float | tuple[float, float] | None = None,\n        power: int | None = None,\n        mode: str | None = None,\n    ) -> tuple[np.ndarray, dict]:\n        \"\"\"\n        Verifies inputs, extracts patches from WSI image and generates metadata.\n\n        Args:\n            wsi: a whole slide image object loaded from a file or a list of such objects.\n            location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).\n            size: (height, width) tuple giving the patch size at the given level (`level`).\n                If not provided or None, it is set to the full image size at the given level.\n            level: the whole slide image level at which the patches are extracted.\n            mpp: the resolution in micron per pixel at which the patches are extracted.\n            power: the objective power at which the patches are extracted.\n            dtype: the data type of output image.\n            mode: the output image mode, 'RGB' or 'RGBA'.\n\n        Returns:\n            a tuples, where the first element is an image patch [CxHxW] or stack of patches,\n                and second element is a dictionary of metadata.\n\n        Notes:\n            Only one of resolution parameters, `level`, `mpp`, or `power`, should be provided.\n            If none of them are provided, it uses the defaults that are set during class instantiation.\n            If none of them are set here or during class instantiation, `level=0` will be used.\n        \"\"\"\n        if mode is None:\n            mode = self.mode\n        patch_list: list = []\n        metadata_list: list = []\n\n        # CuImage object is iterable, so ensure_tuple 
won't work on single object\n        if not isinstance(wsi, (list, tuple)):\n            wsi = (wsi,)\n        for each_wsi in ensure_tuple(wsi):\n            # get the valid level based on resolution info\n            level = self.get_valid_level(each_wsi, level, mpp, power)\n\n            # Verify location\n            if location is None:\n                location = (0, 0)\n            wsi_size = self.get_size(each_wsi, 0)\n            if location[0] > wsi_size[0] or location[1] > wsi_size[1]:\n                raise ValueError(f\"Location is outside of the image: location={location}, image size={wsi_size}\")\n\n            # Verify size\n            if size is None:\n                if location != (0, 0):\n                    raise ValueError(\"Patch size should be defined to extract patches.\")\n                size = self.get_size(each_wsi, level)\n            else:\n                if size[0] <= 0 or size[1] <= 0:\n                    raise ValueError(f\"Patch size should be greater than zero, provided: patch size = {size}\")\n\n            # Get numpy dtype if it is not already.\n            dtype_np = dtype_torch_to_numpy(self.dtype) if isinstance(self.dtype, torch.dtype) else self.dtype\n            # Extract a patch or the entire image\n            patch: NdarrayOrTensor\n            patch = self._get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype_np, mode=mode)\n\n            # Convert the patch to torch.Tensor if dtype is torch\n            if isinstance(self.dtype, torch.dtype) or (\n                self.device is not None and torch.device(self.device).type == \"cuda\"\n            ):\n                # Ensure dtype is torch.dtype if the device is not \"cpu\"\n                dtype_torch = (\n                    dtype_numpy_to_torch(self.dtype) if not isinstance(self.dtype, torch.dtype) else self.dtype\n                )\n                # Copy the numpy array if it is not writable\n                if 
patch.flags[\"WRITEABLE\"]:\n                    patch = torch.as_tensor(patch, dtype=dtype_torch, device=self.device)\n                else:\n                    patch = torch.tensor(patch, dtype=dtype_torch, device=self.device)\n\n            # check if the image has three dimensions (2D + color)\n            if patch.ndim != 3:\n                raise ValueError(\n                    f\"The image dimension should be 3 but has {patch.ndim}. \"\n                    \"`WSIReader` is designed to work only with 2D images with color channel.\"\n                )\n            # Check if there are four color channels for RGBA\n            if mode == \"RGBA\":\n                if patch.shape[self.channel_dim] != 4:\n                    raise ValueError(\n                        f\"The image is expected to have four color channels in '{mode}' mode but has \"\n                        f\"{patch.shape[self.channel_dim]}.\"\n                    )\n            # Check if there are three color channels for RGB\n            elif mode in \"RGB\" and patch.shape[self.channel_dim] != 3:\n                raise ValueError(\n                    f\"The image is expected to have three color channels in '{mode}' mode but has \"\n                    f\"{patch.shape[self.channel_dim]}. 
\"\n                )\n            # Get patch-related metadata\n            metadata: dict = self._get_metadata(wsi=each_wsi, patch=patch, location=location, size=size, level=level)\n            # Create a list of patches and metadata\n            patch_list.append(patch)\n            metadata_list.append(metadata)\n        if len(wsi) > 1:\n            if len({m[\"original_channel_dim\"] for m in metadata_list}) > 1:\n                raise ValueError(\"original_channel_dim is not consistent across wsi objects.\")\n            if len({tuple(m[\"spatial_shape\"]) for m in metadata_list}) > 1:\n                raise ValueError(\"spatial_shape is not consistent across wsi objects.\")\n            for key in WSIPatchKeys:\n                metadata[key] = [m[key] for m in metadata_list]\n        return _stack_images(patch_list, metadata), metadata\n\n    def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:\n        \"\"\"\n        Verify whether the specified file or files format is supported by WSI reader.\n\n        The list of supported suffixes are read from `self.supported_suffixes`.\n\n        Args:\n            filename: filename or a list of filenames to read.\n\n        \"\"\"\n        return is_supported_format(filename, self.supported_suffixes)\n\n\nclass WSIReader(BaseWSIReader):\n    \"\"\"\n    Read whole slide images and extract patches using different backend libraries\n\n    Args:\n        backend: the name of backend whole slide image reader library, the default is cuCIM.\n        level: the whole slide image level at which the patches are extracted.\n        mpp: the resolution in micron per pixel at which the patches are extracted.\n        mpp_rtol: the acceptable relative tolerance for resolution in micro per pixel.\n        mpp_atol: the acceptable absolute tolerance for resolution in micro per pixel.\n        power: the objective power at which the patches are extracted.\n        power_rtol: the acceptable relative tolerance 
for objective power.\n        power_atol: the acceptable absolute tolerance for objective power.\n        channel_dim: the desired dimension for color channel. Default to 0 (channel first).\n        dtype: the data type of output image. Defaults to `np.uint8`.\n        device: target device to put the extracted patch. Note that if device is \"cuda\"\",\n            the output will be converted to torch tenor and sent to the gpu even if the dtype is numpy.\n        mode: the output image color mode, \"RGB\" or \"RGBA\". Defaults to \"RGB\".\n        num_workers: number of workers for multi-thread image loading (cucim backend only).\n        kwargs: additional arguments to be passed to the backend library\n\n        Notes:\n            Only one of resolution parameters, `level`, `mpp`, or `power`, should be provided.\n            If such parameters are provided in `get_data` method, those will override the values provided here.\n            If none of them are provided here or in `get_data`, `level=0` will be used.\n\n    \"\"\"\n\n    supported_backends = [\"cucim\", \"openslide\", \"tifffile\"]\n\n    def __init__(\n        self,\n        backend=\"cucim\",\n        level: int | None = None,\n        mpp: float | tuple[float, float] | None = None,\n        mpp_rtol: float = 0.05,\n        mpp_atol: float = 0.0,\n        power: int | None = None,\n        power_rtol: float = 0.05,\n        power_atol: float = 0.0,\n        channel_dim: int = 0,\n        dtype: DtypeLike | torch.dtype = np.uint8,\n        device: torch.device | str | None = None,\n        mode: str = \"RGB\",\n        **kwargs,\n    ):\n        self.backend = backend.lower()\n        self.reader: CuCIMWSIReader | OpenSlideWSIReader | TiffFileWSIReader\n        if self.backend == \"cucim\":\n            self.reader = CuCIMWSIReader(\n                level=level,\n                mpp=mpp,\n                mpp_rtol=mpp_rtol,\n                mpp_atol=mpp_atol,\n                power=power,\n             
   power_rtol=power_rtol,\n                power_atol=power_atol,\n                channel_dim=channel_dim,\n                dtype=dtype,\n                device=device,\n                mode=mode,\n                **kwargs,\n            )\n        elif self.backend == \"openslide\":\n            self.reader = OpenSlideWSIReader(\n                level=level,\n                mpp=mpp,\n                mpp_rtol=mpp_rtol,\n                mpp_atol=mpp_atol,\n                power=power,\n                power_rtol=power_rtol,\n                power_atol=power_atol,\n                channel_dim=channel_dim,\n                dtype=dtype,\n                device=device,\n                mode=mode,\n                **kwargs,\n            )\n        elif self.backend == \"tifffile\":\n            self.reader = TiffFileWSIReader(\n                level=level,\n                mpp=mpp,\n                mpp_rtol=mpp_rtol,\n                mpp_atol=mpp_atol,\n                power=power,\n                power_rtol=power_rtol,\n                power_atol=power_atol,\n                channel_dim=channel_dim,\n                dtype=dtype,\n                device=device,\n                mode=mode,\n                **kwargs,\n            )\n        else:\n            raise ValueError(\n                f\"The supported backends are cucim, openslide, and tifffile but '{self.backend}' was given.\"\n            )\n        self.supported_suffixes = self.reader.supported_suffixes\n        self.level = self.reader.level\n        self.mpp_rtol = self.reader.mpp_rtol\n        self.mpp_atol = self.reader.mpp_atol\n        self.power = self.reader.power\n        self.power_rtol = self.reader.power_rtol\n        self.power_atol = self.reader.power_atol\n        self.channel_dim = self.reader.channel_dim\n        self.dtype = self.reader.dtype\n        self.device = self.reader.device\n        self.mode = self.reader.mode\n        self.kwargs = self.reader.kwargs\n        self.metadata = 
self.reader.metadata\n        self.mpp = self.reader.mpp\n\n    def get_level_count(self, wsi) -> int:\n        \"\"\"\n        Returns the number of levels in the whole slide image.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n\n        \"\"\"\n        return self.reader.get_level_count(wsi)\n\n    def get_size(self, wsi, level: int) -> tuple[int, int]:\n        \"\"\"\n        Returns the size (height, width) of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the size is calculated.\n\n        \"\"\"\n        return self.reader.get_size(wsi, level)\n\n    def get_downsample_ratio(self, wsi, level: int) -> float:\n        \"\"\"\n        Returns the down-sampling ratio of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the downsample ratio is calculated.\n\n        \"\"\"\n        return self.reader.get_downsample_ratio(wsi, level)\n\n    def get_file_path(self, wsi) -> str:\n        \"\"\"Return the file path for the WSI object\"\"\"\n        return self.reader.get_file_path(wsi)\n\n    def get_mpp(self, wsi, level: int) -> tuple[float, float]:\n        \"\"\"\n        Returns the micro-per-pixel resolution of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the mpp is calculated.\n\n        \"\"\"\n        return self.reader.get_mpp(wsi, level)\n\n    def get_power(self, wsi, level: int) -> float:\n        \"\"\"\n        Returns the micro-per-pixel resolution of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the objective power is calculated.\n\n        \"\"\"\n    
    return self.reader.get_power(wsi, level)\n\n    def _get_patch(\n        self, wsi, location: tuple[int, int], size: tuple[int, int], level: int, dtype: DtypeLike, mode: str\n    ) -> np.ndarray:\n        \"\"\"\n        Extracts and returns a patch image form the whole slide image.\n\n        Args:\n            wsi: a whole slide image object loaded from a file or a lis of such objects.\n            location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).\n            size: (height, width) tuple giving the patch size at the given level (`level`).\n                If None, it is set to the full image size at the given level.\n            level: the level number.\n            dtype: the data type of output image\n            mode: the output image mode, 'RGB' or 'RGBA'.\n\n        \"\"\"\n        return self.reader._get_patch(wsi=wsi, location=location, size=size, level=level, dtype=dtype, mode=mode)\n\n    def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs):\n        \"\"\"\n        Read whole slide image objects from given file or list of files.\n\n        Args:\n            data: file name or a list of file names to read.\n            kwargs: additional args for the reader module (overrides `self.kwargs` for existing keys).\n\n        Returns:\n            whole slide image object or list of such objects.\n\n        \"\"\"\n        return self.reader.read(data=data, **kwargs)\n\n\n@require_pkg(pkg_name=\"cucim\")\nclass CuCIMWSIReader(BaseWSIReader):\n    \"\"\"\n    Read whole slide images and extract patches using cuCIM library.\n\n    Args:\n        level: the whole slide image level at which the patches are extracted.\n        mpp: the resolution in micron per pixel at which the patches are extracted.\n        mpp_rtol: the acceptable relative tolerance for resolution in micro per pixel.\n        mpp_atol: the acceptable absolute tolerance for resolution in micro per pixel.\n        
power: the objective power at which the patches are extracted.\n        power_rtol: the acceptable relative tolerance for objective power.\n        power_atol: the acceptable absolute tolerance for objective power.\n        channel_dim: the desired dimension for color channel. Default to 0 (channel first).\n        dtype: the data type of output image. Defaults to `np.uint8`.\n        device: target device to put the extracted patch. Note that if device is \"cuda\"\",\n            the output will be converted to torch tenor and sent to the gpu even if the dtype is numpy.\n        mode: the output image color mode, \"RGB\" or \"RGBA\". Defaults to \"RGB\".\n        num_workers: number of workers for multi-thread image loading.\n        kwargs: additional args for `cucim.CuImage` module:\n            https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h\n\n        Notes:\n            Only one of resolution parameters, `level`, `mpp`, or `power`, should be provided.\n            If such parameters are provided in `get_data` method, those will override the values provided here.\n            If none of them are provided here or in `get_data`, `level=0` will be used.\n\n    \"\"\"\n\n    supported_suffixes = [\"tif\", \"tiff\", \"svs\"]\n    backend = \"cucim\"\n\n    def __init__(self, num_workers: int = 0, **kwargs):\n        super().__init__(**kwargs)\n        self.num_workers = num_workers\n\n    @staticmethod\n    def get_level_count(wsi) -> int:\n        \"\"\"\n        Returns the number of levels in the whole slide image.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n\n        \"\"\"\n        return wsi.resolutions[\"level_count\"]  # type: ignore\n\n    def get_size(self, wsi, level: int) -> tuple[int, int]:\n        \"\"\"\n        Returns the size (height, width) of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: 
the level number where the size is calculated.\n\n        \"\"\"\n        return (wsi.resolutions[\"level_dimensions\"][level][1], wsi.resolutions[\"level_dimensions\"][level][0])\n\n    def get_downsample_ratio(self, wsi, level: int) -> float:\n        \"\"\"\n        Returns the down-sampling ratio of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the downsample ratio is calculated.\n\n        \"\"\"\n        return float(wsi.resolutions[\"level_downsamples\"][level])\n\n    @staticmethod\n    def get_file_path(wsi) -> str:\n        \"\"\"Return the file path for the WSI object\"\"\"\n        return str(abspath(wsi.path))\n\n    def get_mpp(self, wsi, level: int) -> tuple[float, float]:\n        \"\"\"\n        Returns the micro-per-pixel resolution of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the mpp is calculated.\n\n        \"\"\"\n        downsample_ratio = self.get_downsample_ratio(wsi, level)\n\n        if \"aperio\" in wsi.metadata:\n            mpp_ = wsi.metadata[\"aperio\"].get(\"MPP\")\n            if mpp_:\n                return (downsample_ratio * float(mpp_),) * 2\n        if \"cucim\" in wsi.metadata:\n            mpp_ = wsi.metadata[\"cucim\"].get(\"spacing\")\n            if mpp_ and isinstance(mpp_, Sequence) and len(mpp_) >= 2:\n                if mpp_[0] and mpp_[1]:\n                    return (downsample_ratio * mpp_[1], downsample_ratio * mpp_[0])\n\n        raise ValueError(\"`mpp` cannot be obtained for this file. 
Please use `level` instead.\")\n\n    def get_power(self, wsi, level: int) -> float:\n        \"\"\"\n        Returns the objective power of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the objective power is calculated.\n\n        \"\"\"\n        if \"aperio\" in wsi.metadata:\n            objective_power = wsi.metadata[\"aperio\"].get(\"AppMag\")\n            if objective_power:\n                downsample_ratio = self.get_downsample_ratio(wsi, level)\n                return float(objective_power) / downsample_ratio\n\n        raise ValueError(\n            \"Currently, cuCIM backend can obtain the objective power only for Aperio images. \"\n            \"Please use `level` (or `mpp`) instead, or try OpenSlide backend.\"\n        )\n\n    def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs):\n        \"\"\"\n        Read whole slide image objects from given file or list of files.\n\n        Args:\n            data: file name or a list of file names to read.\n            kwargs: additional args that overrides `self.kwargs` for existing keys.\n                For more details look at https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h\n\n        Returns:\n            whole slide image object or list of such objects.\n\n        \"\"\"\n        cuimage_cls, _ = optional_import(\"cucim\", name=\"CuImage\")\n        wsi_list: list = []\n\n        filenames: Sequence[PathLike] = ensure_tuple(data)\n        kwargs_ = self.kwargs.copy()\n        kwargs_.update(kwargs)\n        for filename in filenames:\n            wsi = cuimage_cls(filename, **kwargs_)\n            wsi_list.append(wsi)\n\n        return wsi_list if len(filenames) > 1 else wsi_list[0]\n\n    def _get_patch(\n        self, wsi, location: tuple[int, int], size: tuple[int, int], level: int, dtype: DtypeLike, mode: str\n    ) -> np.ndarray:\n        
\"\"\"\n        Extracts and returns a patch image form the whole slide image.\n\n        Args:\n            wsi: a whole slide image object loaded from a file or a lis of such objects.\n            location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).\n            size: (height, width) tuple giving the patch size at the given level (`level`).\n                If None, it is set to the full image size at the given level.\n            level: the level number.\n            dtype: the data type of output image.\n            mode: the output image mode, 'RGB' or 'RGBA'.\n\n        \"\"\"\n        # Extract a patch or the entire image\n        # (reverse the order of location and size to become WxH for cuCIM)\n        patch: np.ndarray = wsi.read_region(\n            location=location[::-1], size=size[::-1], level=level, num_workers=self.num_workers\n        )\n\n        # Convert to numpy\n        patch = np.asarray(patch, dtype=dtype)\n\n        # Make the channel to desired dimensions\n        patch = np.moveaxis(patch, -1, self.channel_dim)\n\n        # Check if the color channel is 3 (RGB) or 4 (RGBA)\n        if mode in \"RGB\":\n            if patch.shape[self.channel_dim] not in [3, 4]:\n                raise ValueError(\n                    f\"The image is expected to have three or four color channels in '{mode}' mode but has \"\n                    f\"{patch.shape[self.channel_dim]}. 
\"\n                )\n            patch = np.take(patch, [0, 1, 2], self.channel_dim)\n\n        return patch\n\n\n@require_pkg(pkg_name=\"openslide\")\nclass OpenSlideWSIReader(BaseWSIReader):\n    \"\"\"\n    Read whole slide images and extract patches using OpenSlide library.\n\n    Args:\n        level: the whole slide image level at which the patches are extracted.\n        mpp: the resolution in micron per pixel at which the patches are extracted.\n        mpp_rtol: the acceptable relative tolerance for resolution in micro per pixel.\n        mpp_atol: the acceptable absolute tolerance for resolution in micro per pixel.\n        power: the objective power at which the patches are extracted.\n        power_rtol: the acceptable relative tolerance for objective power.\n        power_atol: the acceptable absolute tolerance for objective power.\n        channel_dim: the desired dimension for color channel. Default to 0 (channel first).\n        dtype: the data type of output image. Defaults to `np.uint8`.\n        device: target device to put the extracted patch. Note that if device is \"cuda\"\",\n            the output will be converted to torch tenor and sent to the gpu even if the dtype is numpy.\n        mode: the output image color mode, \"RGB\" or \"RGBA\". 
Defaults to \"RGB\".\n        kwargs: additional args for `openslide.OpenSlide` module.\n\n        Notes:\n            Only one of resolution parameters, `level`, `mpp`, or `power`, should be provided.\n            If such parameters are provided in `get_data` method, those will override the values provided here.\n            If none of them are provided here or in `get_data`, `level=0` will be used.\n\n    \"\"\"\n\n    supported_suffixes = [\"tif\", \"tiff\", \"svs\"]\n    backend = \"openslide\"\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n    @staticmethod\n    def get_level_count(wsi) -> int:\n        \"\"\"\n        Returns the number of levels in the whole slide image.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n\n        \"\"\"\n        return wsi.level_count  # type: ignore\n\n    def get_size(self, wsi, level: int) -> tuple[int, int]:\n        \"\"\"\n        Returns the size (height, width) of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the size is calculated.\n\n        \"\"\"\n        return (wsi.level_dimensions[level][1], wsi.level_dimensions[level][0])\n\n    def get_downsample_ratio(self, wsi, level: int) -> float:\n        \"\"\"\n        Returns the down-sampling ratio of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the downsample ratio is calculated.\n\n        \"\"\"\n        return wsi.level_downsamples[level]  # type: ignore\n\n    @staticmethod\n    def get_file_path(wsi) -> str:\n        \"\"\"Return the file path for the WSI object\"\"\"\n        return str(abspath(wsi._filename))\n\n    def get_mpp(self, wsi, level: int) -> tuple[float, float]:\n        \"\"\"\n        Returns the micro-per-pixel resolution of the whole slide image 
at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the mpp is calculated.\n\n        \"\"\"\n        downsample_ratio = self.get_downsample_ratio(wsi, level)\n        if (\n            \"openslide.mpp-x\" in wsi.properties\n            and \"openslide.mpp-y\" in wsi.properties\n            and wsi.properties[\"openslide.mpp-y\"]\n            and wsi.properties[\"openslide.mpp-x\"]\n        ):\n            return (\n                downsample_ratio * float(wsi.properties[\"openslide.mpp-y\"]),\n                downsample_ratio * float(wsi.properties[\"openslide.mpp-x\"]),\n            )\n\n        if (\n            \"tiff.XResolution\" in wsi.properties\n            and \"tiff.YResolution\" in wsi.properties\n            and wsi.properties[\"tiff.YResolution\"]\n            and wsi.properties[\"tiff.XResolution\"]\n        ):\n            unit = wsi.properties.get(\"tiff.ResolutionUnit\")\n            if unit is None:\n                warnings.warn(\"The resolution unit is missing, `micrometer` will be used as default.\")\n                unit = \"micrometer\"\n\n            convert_to_micron = ConvertUnits(unit, \"micrometer\")\n            return (\n                convert_to_micron(downsample_ratio / float(wsi.properties[\"tiff.YResolution\"])),\n                convert_to_micron(downsample_ratio / float(wsi.properties[\"tiff.XResolution\"])),\n            )\n\n        raise ValueError(\"`mpp` cannot be obtained for this file. 
Please use `level` instead.\")\n\n    def get_power(self, wsi, level: int) -> float:\n        \"\"\"\n        Returns the objective power of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the objective power is calculated.\n\n        \"\"\"\n        objective_power = wsi.properties.get(\"openslide.objective-power\")\n        if objective_power:\n            downsample_ratio = self.get_downsample_ratio(wsi, level)\n            return float(objective_power) / downsample_ratio\n\n        raise ValueError(\"Objective `power` cannot be obtained for this file. Please use `level` (or `mpp`) instead.\")\n\n    def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs):\n        \"\"\"\n        Read whole slide image objects from given file or list of files.\n\n        Args:\n            data: file name or a list of file names to read.\n            kwargs: additional args that overrides `self.kwargs` for existing keys.\n\n        Returns:\n            whole slide image object or list of such objects.\n\n        \"\"\"\n        wsi_list: list = []\n\n        filenames: Sequence[PathLike] = ensure_tuple(data)\n        kwargs_ = self.kwargs.copy()\n        kwargs_.update(kwargs)\n        for filename in filenames:\n            wsi = OpenSlide(filename, **kwargs_)\n            wsi_list.append(wsi)\n\n        return wsi_list if len(filenames) > 1 else wsi_list[0]\n\n    def _get_patch(\n        self, wsi, location: tuple[int, int], size: tuple[int, int], level: int, dtype: DtypeLike, mode: str\n    ) -> np.ndarray:\n        \"\"\"\n        Extracts and returns a patch image form the whole slide image.\n\n        Args:\n            wsi: a whole slide image object loaded from a file or a lis of such objects\n            location: (top, left) tuple giving the top left pixel in the level 0 reference frame. 
Defaults to (0, 0).\n            size: (height, width) tuple giving the patch size at the given level (`level`).\n                If None, it is set to the full image size at the given level.\n            level: the level number.\n            dtype: the data type of output image.\n            mode: the output image mode, 'RGB' or 'RGBA'.\n\n        \"\"\"\n        # Extract a patch or the entire image\n        # (reverse the order of location and size to become WxH for OpenSlide)\n        pil_patch = wsi.read_region(location=location[::-1], size=size[::-1], level=level)\n\n        # convert to RGB/RGBA\n        pil_patch = pil_patch.convert(mode)\n\n        # Convert to numpy\n        patch = np.asarray(pil_patch, dtype=dtype)\n\n        # Make the channel to desired dimensions\n        patch = np.moveaxis(patch, -1, self.channel_dim)\n\n        return patch\n\n\n@require_pkg(pkg_name=\"tifffile\")\nclass TiffFileWSIReader(BaseWSIReader):\n    \"\"\"\n    Read whole slide images and extract patches using TiffFile library.\n\n    Args:\n        level: the whole slide image level at which the patches are extracted.\n        mpp: the resolution in micron per pixel at which the patches are extracted.\n        mpp_rtol: the acceptable relative tolerance for resolution in micro per pixel.\n        mpp_atol: the acceptable absolute tolerance for resolution in micro per pixel.\n        channel_dim: the desired dimension for color channel. Default to 0 (channel first).\n        dtype: the data type of output image. Defaults to `np.uint8`.\n        device: target device to put the extracted patch. Note that if device is \"cuda\"\",\n            the output will be converted to torch tenor and sent to the gpu even if the dtype is numpy.\n        mode: the output image color mode, \"RGB\" or \"RGBA\". 
Defaults to \"RGB\".\n        kwargs: additional args for `tifffile.TiffFile` module.\n\n        Notes:\n            - Objective power cannot be obtained via TiffFile backend.\n            - Only one of resolution parameters, `level` or `mpp`, should be provided.\n                If such parameters are provided in `get_data` method, those will override the values provided here.\n                If none of them are provided here or in `get_data`, `level=0` will be used.\n\n    \"\"\"\n\n    supported_suffixes = [\"tif\", \"tiff\", \"svs\"]\n    backend = \"tifffile\"\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n    @staticmethod\n    def get_level_count(wsi) -> int:\n        \"\"\"\n        Returns the number of levels in the whole slide image.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n\n        \"\"\"\n        return len(wsi.pages)\n\n    def get_size(self, wsi, level: int) -> tuple[int, int]:\n        \"\"\"\n        Returns the size (height, width) of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the size is calculated.\n\n        \"\"\"\n        return (wsi.pages[level].imagelength, wsi.pages[level].imagewidth)\n\n    def get_downsample_ratio(self, wsi, level: int) -> float:\n        \"\"\"\n        Returns the down-sampling ratio of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the downsample ratio is calculated.\n\n        \"\"\"\n        return float(wsi.pages[0].imagelength) / float(wsi.pages[level].imagelength)\n\n    @staticmethod\n    def get_file_path(wsi) -> str:\n        \"\"\"Return the file path for the WSI object\"\"\"\n        return str(abspath(wsi.filehandle.path))\n\n    def get_mpp(self, wsi, level: int) -> tuple[float, float]:\n        
\"\"\"\n        Returns the micro-per-pixel resolution of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the mpp is calculated.\n\n        \"\"\"\n        if (\n            \"XResolution\" in wsi.pages[level].tags\n            and \"YResolution\" in wsi.pages[level].tags\n            and wsi.pages[level].tags[\"XResolution\"].value\n            and wsi.pages[level].tags[\"YResolution\"].value\n        ):\n            unit = wsi.pages[level].tags.get(\"ResolutionUnit\")\n            if unit is not None:\n                unit = str(unit.value.name)\n            if unit is None or len(unit) == 0:\n                warnings.warn(\"The resolution unit is missing. `micrometer` will be used as default.\")\n                unit = \"micrometer\"\n\n            convert_to_micron = ConvertUnits(unit, \"micrometer\")\n            # Here x and y resolutions are rational numbers so each of them is represented by a tuple.\n            yres = wsi.pages[level].tags[\"YResolution\"].value\n            xres = wsi.pages[level].tags[\"XResolution\"].value\n            return convert_to_micron(yres[1] / yres[0]), convert_to_micron(xres[1] / xres[0])\n\n        raise ValueError(\"`mpp`  cannot be obtained for this file. 
Please use `level` instead.\")\n\n    def get_power(self, wsi, level: int) -> float:\n        \"\"\"\n        Returns the objective power of the whole slide image at a given level.\n\n        Args:\n            wsi: a whole slide image object loaded from a file.\n            level: the level number where the objective power is calculated.\n\n        \"\"\"\n        raise ValueError(\n            \"Currently, TiffFile does not provide a general API to obtain objective power.\"\n            \"Please use `level` (or `mpp`) instead, or try other backends.\"\n        )\n\n    def read(self, data: Sequence[PathLike] | PathLike | np.ndarray, **kwargs):\n        \"\"\"\n        Read whole slide image objects from given file or list of files.\n\n        Args:\n            data: file name or a list of file names to read.\n            kwargs: additional args that overrides `self.kwargs` for existing keys.\n\n        Returns:\n            whole slide image object or list of such objects.\n\n        \"\"\"\n        wsi_list: list = []\n\n        filenames: Sequence[PathLike] = ensure_tuple(data)\n        kwargs_ = self.kwargs.copy()\n        kwargs_.update(kwargs)\n        for filename in filenames:\n            wsi = TiffFile(filename, **kwargs_)\n            wsi_list.append(wsi)\n\n        return wsi_list if len(filenames) > 1 else wsi_list[0]\n\n    def _get_patch(\n        self, wsi, location: tuple[int, int], size: tuple[int, int], level: int, dtype: DtypeLike, mode: str\n    ) -> np.ndarray:\n        \"\"\"\n        Extracts and returns a patch image form the whole slide image.\n\n        Args:\n            wsi: a whole slide image object loaded from a file or a lis of such objects\n            location: (top, left) tuple giving the top left pixel in the level 0 reference frame. 
Defaults to (0, 0).\n            size: (height, width) tuple giving the patch size at the given level (`level`).\n                If None, it is set to the full image size at the given level.\n            level: the level number.\n            dtype: the data type of output image.\n            mode: the output image mode, 'RGB' or 'RGBA'.\n\n        \"\"\"\n        # Load the entire image\n        wsi_image: np.ndarray = wsi.asarray(level=level).astype(dtype)\n        if len(wsi_image.shape) < 3:\n            wsi_image = wsi_image[..., None]\n\n        # Extract patch\n        downsampling_ratio = self.get_downsample_ratio(wsi=wsi, level=level)\n        location_ = [round(location[i] / downsampling_ratio) for i in range(len(location))]\n        patch = wsi_image[location_[0] : location_[0] + size[0], location_[1] : location_[1] + size[1], :]\n\n        # Make the channel to desired dimensions\n        patch = np.moveaxis(patch, -1, self.channel_dim)\n\n        # Check if the color channel is 3 (RGB) or 4 (RGBA)\n        if mode in \"RGB\":\n            if patch.shape[self.channel_dim] not in [3, 4]:\n                raise ValueError(\n                    f\"The image is expected to have three or four color channels in '{mode}' mode but has \"\n                    f\"{patch.shape[self.channel_dim]}. \"\n                )\n            patch = np.take(patch, [0, 1, 2], self.channel_dim)\n\n        return patch\n"
  },
  {
    "path": "monai/engines/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator\nfrom .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer\nfrom .utils import (\n    DiffusionPrepareBatch,\n    IterationEvents,\n    PrepareBatch,\n    PrepareBatchDefault,\n    PrepareBatchExtraInput,\n    VPredictionPrepareBatch,\n    default_make_latent,\n    default_metric_cmp_fn,\n    default_prepare_batch,\n    engine_apply_transform,\n    get_devices_spec,\n)\nfrom .workflow import Workflow\n"
  },
  {
    "path": "monai/engines/evaluator.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Iterable, Sequence\nfrom typing import TYPE_CHECKING, Any, Callable\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom monai.config import KeysCollection\nfrom monai.data import MetaTensor\nfrom monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch\nfrom monai.engines.workflow import Workflow\nfrom monai.inferers import Inferer, SimpleInferer\nfrom monai.networks.utils import eval_mode, train_mode\nfrom monai.transforms import Transform\nfrom monai.utils import ForwardMode, IgniteInfo, ensure_tuple, min_version, optional_import\nfrom monai.utils.enums import CommonKeys as Keys\nfrom monai.utils.enums import EngineStatsKeys as ESKeys\nfrom monai.utils.module import look_up_option\n\nif TYPE_CHECKING:\n    from ignite.engine import Engine, EventEnum\n    from ignite.metrics import Metric\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n    Metric, _ = optional_import(\"ignite.metrics\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Metric\")\n    EventEnum, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"EventEnum\")\n\n__all__ = [\"Evaluator\", \"SupervisedEvaluator\", \"EnsembleEvaluator\"]\n\n\nclass Evaluator(Workflow):\n    \"\"\"\n    
Base class for all kinds of evaluators, inherits from Workflow.\n\n    Args:\n        device: an object representing the device on which to run.\n        val_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.\n        epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.\n        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously\n            with respect to the host. For other cases, this argument has no effect.\n        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)\n            from `engine.state.batch` for every iteration, for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.\n        iteration_update: the callable function for every iteration, expect to accept `engine`\n            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.\n            if not provided, use `self._iteration()` instead. for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.\n        postprocessing: execute additional transformation for the model output data.\n            Typically, several Tensor based transforms composed by `Compose`.\n        key_val_metric: compute metric when every iteration completed, and save average value to\n            engine.state.metrics when epoch completed. 
key_val_metric is the main metric to compare and save the\n            checkpoint into files.\n        additional_metrics: more Ignite metrics that also attach to Ignite Engine.\n        metric_cmp_fn: function to compare current key metric with previous best key metric value,\n            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update\n            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.\n        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:\n            CheckpointHandler, StatsHandler, etc.\n        amp: whether to enable auto-mixed-precision evaluation, default is False.\n        mode: model forward mode during evaluation, should be 'eval' or 'train',\n            which maps to `model.eval()` or `model.train()`, default to 'eval'.\n        event_names: additional custom ignite events that will register to the engine.\n            new events can be a list of str or `ignite.engine.events.EventEnum`.\n        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.\n            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html\n            #ignite.engine.engine.Engine.register_events.\n        decollate: whether to decollate the batch-first data to a list of data after model computation,\n            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.\n            default to `True`.\n        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for\n            `device`, `non_blocking`.\n        amp_kwargs: dict of the args for `torch.autocast(\"cuda\")` API, for more details:\n            https://pytorch.org/docs/stable/amp.html#torch.autocast.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        device: torch.device | str,\n        
val_data_loader: Iterable | DataLoader,\n        epoch_length: int | None = None,\n        non_blocking: bool = False,\n        prepare_batch: Callable = default_prepare_batch,\n        iteration_update: Callable[[Engine, Any], Any] | None = None,\n        postprocessing: Transform | None = None,\n        key_val_metric: dict[str, Metric] | None = None,\n        additional_metrics: dict[str, Metric] | None = None,\n        metric_cmp_fn: Callable = default_metric_cmp_fn,\n        val_handlers: Sequence | None = None,\n        amp: bool = False,\n        mode: ForwardMode | str = ForwardMode.EVAL,\n        event_names: list[str | EventEnum | type[EventEnum]] | None = None,\n        event_to_attr: dict | None = None,\n        decollate: bool = True,\n        to_kwargs: dict | None = None,\n        amp_kwargs: dict | None = None,\n    ) -> None:\n        super().__init__(\n            device=device,\n            max_epochs=1,\n            data_loader=val_data_loader,\n            epoch_length=epoch_length,\n            non_blocking=non_blocking,\n            prepare_batch=prepare_batch,\n            iteration_update=iteration_update,\n            postprocessing=postprocessing,\n            key_metric=key_val_metric,\n            additional_metrics=additional_metrics,\n            metric_cmp_fn=metric_cmp_fn,\n            handlers=val_handlers,\n            amp=amp,\n            event_names=event_names,\n            event_to_attr=event_to_attr,\n            decollate=decollate,\n            to_kwargs=to_kwargs,\n            amp_kwargs=amp_kwargs,\n        )\n        mode = look_up_option(mode, ForwardMode)\n        if mode == ForwardMode.EVAL:\n            self.mode = eval_mode\n        elif mode == ForwardMode.TRAIN:\n            self.mode = train_mode\n        else:\n            raise ValueError(f\"unsupported mode: {mode}, should be 'eval' or 'train'.\")\n\n    def run(self, global_epoch: int = 1) -> None:  # type: ignore[override]\n        \"\"\"\n        Execute 
validation/evaluation based on Ignite Engine.\n\n        Args:\n            global_epoch: the overall epoch if during a training. evaluator engine can get it from trainer.\n\n        \"\"\"\n        # init env value for current validation process\n        self.state.max_epochs = max(global_epoch, 1)  # at least one epoch of validation\n        self.state.epoch = global_epoch - 1\n        self.state.iteration = 0\n        super().run()\n\n    def get_stats(self, *vars):\n        \"\"\"\n        Get the statistics information of the validation process.\n        Default to return the `rank`, `best_validation_epoch` and `best_validation_metric`.\n\n        Args:\n            vars: except for the default stats, other variables name in the `self.state` to return,\n                will use the variable name as the key and the state content as the value.\n                if the variable doesn't exist, default value is `None`.\n\n        \"\"\"\n        stats = {\n            ESKeys.RANK: self.state.rank,\n            ESKeys.BEST_VALIDATION_EPOCH: self.state.best_metric_epoch,\n            ESKeys.BEST_VALIDATION_METRIC: self.state.best_metric,\n        }\n        for k in vars:\n            stats[k] = getattr(self.state, k, None)\n        return stats\n\n\nclass SupervisedEvaluator(Evaluator):\n    \"\"\"\n    Standard supervised evaluation method with image and label(optional), inherits from evaluator and Workflow.\n\n    Args:\n        device: an object representing the device on which to run.\n        val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader.\n        network: network to evaluate in the evaluator, should be regular PyTorch `torch.nn.Module`.\n        epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.\n        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously\n            with respect to the host. 
For other cases, this argument has no effect.\n        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)\n            from `engine.state.batch` for every iteration, for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.\n        iteration_update: the callable function for every iteration, expect to accept `engine`\n            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.\n            if not provided, use `self._iteration()` instead. for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.\n        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.\n        postprocessing: execute additional transformation for the model output data.\n            Typically, several Tensor based transforms composed by `Compose`.\n        key_val_metric: compute metric when every iteration completed, and save average value to\n            engine.state.metrics when epoch completed. 
key_val_metric is the main metric to compare and save the\n            checkpoint into files.\n        additional_metrics: more Ignite metrics that also attach to Ignite Engine.\n        metric_cmp_fn: function to compare current key metric with previous best key metric value,\n            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update\n            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.\n        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:\n            CheckpointHandler, StatsHandler, etc.\n        amp: whether to enable auto-mixed-precision evaluation, default is False.\n        mode: model forward mode during evaluation, should be 'eval' or 'train',\n            which maps to `model.eval()` or `model.train()`, default to 'eval'.\n        event_names: additional custom ignite events that will register to the engine.\n            new events can be a list of str or `ignite.engine.events.EventEnum`.\n        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.\n            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html\n            #ignite.engine.engine.Engine.register_events.\n        decollate: whether to decollate the batch-first data to a list of data after model computation,\n            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.\n            default to `True`.\n        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for\n            `device`, `non_blocking`.\n        amp_kwargs: dict of the args for `torch.autocast(\"cuda\")` API, for more details:\n            https://pytorch.org/docs/stable/amp.html#torch.autocast.\n        compile: whether to use `torch.compile`, default is False. 
If True, MetaTensor inputs will be converted to\n            `torch.Tensor` before forward pass,  then converted back afterward with copied meta information.\n        compile_kwargs: dict of the args for `torch.compile()` API, for more details:\n            https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        device: torch.device,\n        val_data_loader: Iterable | DataLoader,\n        network: torch.nn.Module,\n        epoch_length: int | None = None,\n        non_blocking: bool = False,\n        prepare_batch: Callable = default_prepare_batch,\n        iteration_update: Callable[[Engine, Any], Any] | None = None,\n        inferer: Inferer | None = None,\n        postprocessing: Transform | None = None,\n        key_val_metric: dict[str, Metric] | None = None,\n        additional_metrics: dict[str, Metric] | None = None,\n        metric_cmp_fn: Callable = default_metric_cmp_fn,\n        val_handlers: Sequence | None = None,\n        amp: bool = False,\n        mode: ForwardMode | str = ForwardMode.EVAL,\n        event_names: list[str | EventEnum | type[EventEnum]] | None = None,\n        event_to_attr: dict | None = None,\n        decollate: bool = True,\n        to_kwargs: dict | None = None,\n        amp_kwargs: dict | None = None,\n        compile: bool = False,\n        compile_kwargs: dict | None = None,\n    ) -> None:\n        super().__init__(\n            device=device,\n            val_data_loader=val_data_loader,\n            epoch_length=epoch_length,\n            non_blocking=non_blocking,\n            prepare_batch=prepare_batch,\n            iteration_update=iteration_update,\n            postprocessing=postprocessing,\n            key_val_metric=key_val_metric,\n            additional_metrics=additional_metrics,\n            metric_cmp_fn=metric_cmp_fn,\n            val_handlers=val_handlers,\n            amp=amp,\n            mode=mode,\n            
event_names=event_names,\n            event_to_attr=event_to_attr,\n            decollate=decollate,\n            to_kwargs=to_kwargs,\n            amp_kwargs=amp_kwargs,\n        )\n        if compile:\n            compile_kwargs = {} if compile_kwargs is None else compile_kwargs\n            network = torch.compile(network, **compile_kwargs)  # type: ignore[assignment]\n        self.network = network\n        self.compile = compile\n        self.inferer = SimpleInferer() if inferer is None else inferer\n\n    def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict:\n        \"\"\"\n        callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.\n        Return below items in a dictionary:\n            - IMAGE: image Tensor data for model input, already moved to device.\n            - LABEL: label Tensor data corresponding to the image, already moved to device.\n            - PRED: prediction result of model.\n\n        Args:\n            engine: `SupervisedEvaluator` to execute operation for an iteration.\n            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.\n\n        Raises:\n            ValueError: When ``batchdata`` is None.\n\n        \"\"\"\n        if batchdata is None:\n            raise ValueError(\"Must provide batch data for current iteration.\")\n        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)\n        if len(batch) == 2:\n            inputs, targets = batch\n            args: tuple = ()\n            kwargs: dict = {}\n        else:\n            inputs, targets, args, kwargs = batch\n        # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026\n        if self.compile:\n            inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None\n            if isinstance(inputs, MetaTensor):\n        
        warnings.warn(\n                    \"Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass.\"\n                )\n                inputs, inputs_meta, inputs_applied_operations = (\n                    inputs.as_tensor(),\n                    inputs.meta,\n                    inputs.applied_operations,\n                )\n            if isinstance(targets, MetaTensor):\n                targets, targets_meta, targets_applied_operations = (\n                    targets.as_tensor(),\n                    targets.meta,\n                    targets.applied_operations,\n                )\n\n        # put iteration outputs into engine.state\n        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}\n        # execute forward computation\n        with engine.mode(engine.network):\n            if engine.amp:\n                with torch.autocast(\"cuda\", **engine.amp_kwargs):\n                    engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)\n            else:\n                engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)\n        # copy back meta info\n        if self.compile:\n            if inputs_meta is not None:\n                engine.state.output[Keys.IMAGE] = MetaTensor(\n                    inputs, meta=inputs_meta, applied_operations=inputs_applied_operations\n                )\n                engine.state.output[Keys.PRED] = MetaTensor(\n                    engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations\n                )\n            if targets_meta is not None:\n                engine.state.output[Keys.LABEL] = MetaTensor(\n                    targets, meta=targets_meta, applied_operations=targets_applied_operations\n                )\n        engine.fire_event(IterationEvents.FORWARD_COMPLETED)\n        engine.fire_event(IterationEvents.MODEL_COMPLETED)\n\n 
       return engine.state.output\n\n\nclass EnsembleEvaluator(Evaluator):\n    \"\"\"\n    Ensemble evaluation for multiple models, inherits from evaluator and Workflow.\n    It accepts a list of models for inference and outputs a list of predictions for further operations.\n\n    Args:\n        device: an object representing the device on which to run.\n        val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader.\n        epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.\n        networks: networks to evaluate in order in the evaluator, should be regular PyTorch `torch.nn.Module`.\n        pred_keys: the keys to store every prediction data.\n            the length must exactly match the number of networks.\n            if None, use \"pred_{index}\" as key corresponding to N networks, index from `0` to `N-1`.\n        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously\n            with respect to the host. For other cases, this argument has no effect.\n        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)\n            from `engine.state.batch` for every iteration, for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.\n        iteration_update: the callable function for every iteration, expect to accept `engine`\n            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.\n            if not provided, use `self._iteration()` instead. 
for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.\n        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.\n        postprocessing: execute additional transformation for the model output data.\n            Typically, several Tensor based transforms composed by `Compose`.\n        key_val_metric: compute metric when every iteration completed, and save average value to\n            engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the\n            checkpoint into files.\n        additional_metrics: more Ignite metrics that also attach to Ignite Engine.\n        metric_cmp_fn: function to compare current key metric with previous best key metric value,\n            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update\n            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.\n        val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:\n            CheckpointHandler, StatsHandler, etc.\n        amp: whether to enable auto-mixed-precision evaluation, default is False.\n        mode: model forward mode during evaluation, should be 'eval' or 'train',\n            which maps to `model.eval()` or `model.train()`, default to 'eval'.\n        event_names: additional custom ignite events that will register to the engine.\n            new events can be a list of str or `ignite.engine.events.EventEnum`.\n        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.\n            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html\n            #ignite.engine.engine.Engine.register_events.\n        decollate: whether to decollate the batch-first data to a list of data after model computation,\n       
     recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.\n            default to `True`.\n        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for\n            `device`, `non_blocking`.\n        amp_kwargs: dict of the args for `torch.autocast(\"cuda\")` API, for more details:\n            https://pytorch.org/docs/stable/amp.html#torch.autocast.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        device: torch.device,\n        val_data_loader: Iterable | DataLoader,\n        networks: Sequence[torch.nn.Module],\n        pred_keys: KeysCollection | None = None,\n        epoch_length: int | None = None,\n        non_blocking: bool = False,\n        prepare_batch: Callable = default_prepare_batch,\n        iteration_update: Callable[[Engine, Any], Any] | None = None,\n        inferer: Inferer | None = None,\n        postprocessing: Transform | None = None,\n        key_val_metric: dict[str, Metric] | None = None,\n        additional_metrics: dict[str, Metric] | None = None,\n        metric_cmp_fn: Callable = default_metric_cmp_fn,\n        val_handlers: Sequence | None = None,\n        amp: bool = False,\n        mode: ForwardMode | str = ForwardMode.EVAL,\n        event_names: list[str | EventEnum | type[EventEnum]] | None = None,\n        event_to_attr: dict | None = None,\n        decollate: bool = True,\n        to_kwargs: dict | None = None,\n        amp_kwargs: dict | None = None,\n    ) -> None:\n        super().__init__(\n            device=device,\n            val_data_loader=val_data_loader,\n            epoch_length=epoch_length,\n            non_blocking=non_blocking,\n            prepare_batch=prepare_batch,\n            iteration_update=iteration_update,\n            postprocessing=postprocessing,\n            key_val_metric=key_val_metric,\n            additional_metrics=additional_metrics,\n            metric_cmp_fn=metric_cmp_fn,\n            
val_handlers=val_handlers,\n            amp=amp,\n            mode=mode,\n            event_names=event_names,\n            event_to_attr=event_to_attr,\n            decollate=decollate,\n            to_kwargs=to_kwargs,\n            amp_kwargs=amp_kwargs,\n        )\n\n        self.networks = ensure_tuple(networks)\n        self.pred_keys = (\n            [f\"{Keys.PRED}_{i}\" for i in range(len(self.networks))] if pred_keys is None else ensure_tuple(pred_keys)\n        )\n        if len(self.pred_keys) != len(self.networks):\n            raise ValueError(\"length of `pred_keys` must be same as the length of `networks`.\")\n        self.inferer = SimpleInferer() if inferer is None else inferer\n\n    def _iteration(self, engine: EnsembleEvaluator, batchdata: dict[str, torch.Tensor]) -> dict:\n        \"\"\"\n        callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.\n        Return below items in a dictionary:\n            - IMAGE: image Tensor data for model input, already moved to device.\n            - LABEL: label Tensor data corresponding to the image, already moved to device.\n            - pred_keys[0]: prediction result of network 0.\n            - pred_keys[1]: prediction result of network 1.\n            - ... 
...\n            - pred_keys[N]: prediction result of network N.\n\n        Args:\n            engine: `EnsembleEvaluator` to execute operation for an iteration.\n            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.\n\n        Raises:\n            ValueError: When ``batchdata`` is None.\n\n        \"\"\"\n        if batchdata is None:\n            raise ValueError(\"Must provide batch data for current iteration.\")\n        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)\n        if len(batch) == 2:\n            inputs, targets = batch\n            args: tuple = ()\n            kwargs: dict = {}\n        else:\n            inputs, targets, args, kwargs = batch\n\n        # put iteration outputs into engine.state\n        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}\n\n        for idx, network in enumerate(engine.networks):\n            with engine.mode(network):\n                if engine.amp:\n                    with torch.autocast(\"cuda\", **engine.amp_kwargs):\n                        if isinstance(engine.state.output, dict):\n                            engine.state.output.update(\n                                {engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}\n                            )\n                else:\n                    if isinstance(engine.state.output, dict):\n                        engine.state.output.update(\n                            {engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}\n                        )\n        engine.fire_event(IterationEvents.FORWARD_COMPLETED)\n        engine.fire_event(IterationEvents.MODEL_COMPLETED)\n\n        return engine.state.output\n"
  },
  {
    "path": "monai/engines/trainer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Iterable, Sequence\nfrom typing import TYPE_CHECKING, Any, Callable\n\nimport torch\nfrom torch.optim.optimizer import Optimizer\nfrom torch.utils.data import DataLoader\n\nfrom monai.data import MetaTensor\nfrom monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch\nfrom monai.engines.workflow import Workflow\nfrom monai.inferers import Inferer, SimpleInferer\nfrom monai.transforms import Transform\nfrom monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, IgniteInfo, min_version, optional_import\nfrom monai.utils.enums import CommonKeys as Keys\nfrom monai.utils.enums import EngineStatsKeys as ESKeys\n\nif TYPE_CHECKING:\n    from ignite.engine import Engine, EventEnum\n    from ignite.metrics import Metric\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n    Metric, _ = optional_import(\"ignite.metrics\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Metric\")\n    EventEnum, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"EventEnum\")\n\n__all__ = [\"Trainer\", \"SupervisedTrainer\", \"GanTrainer\", \"AdversarialTrainer\"]\n\n\nclass Trainer(Workflow):\n    \"\"\"\n    Base class for all kinds of trainers, 
inherits from Workflow.\n\n    \"\"\"\n\n    def run(self) -> None:  # type: ignore[override]\n        \"\"\"\n        Execute training based on Ignite Engine.\n        If call this function multiple times, it will continuously run from the previous state.\n\n        \"\"\"\n        self.scaler = torch.cuda.amp.GradScaler() if self.amp else None\n        super().run()\n\n    def get_stats(self, *vars):\n        \"\"\"\n        Get the statistics information of the training process.\n        Default to return the `rank`, `current_epoch`, `current_iteration`, `total_epochs`, `total_iterations`.\n\n        Args:\n            vars: except for the default stats, other variables name in the `self.state` to return,\n                will use the variable name as the key and the state content as the value.\n                if the variable doesn't exist, default value is `None`.\n\n        \"\"\"\n        stats = {\n            ESKeys.RANK: self.state.rank,\n            ESKeys.CURRENT_EPOCH: self.state.epoch,\n            ESKeys.CURRENT_ITERATION: self.state.iteration,\n            ESKeys.TOTAL_EPOCHS: self.state.max_epochs,\n            ESKeys.TOTAL_ITERATIONS: self.state.epoch_length,\n        }\n        for k in vars:\n            stats[k] = getattr(self.state, k, None)\n        return stats\n\n\nclass SupervisedTrainer(Trainer):\n    \"\"\"\n    Standard supervised training method with image and label, inherits from ``Trainer`` and ``Workflow``.\n\n    Args:\n        device: an object representing the device on which to run.\n        max_epochs: the total epoch number for trainer to run.\n        train_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.\n        network: network to train in the trainer, should be regular PyTorch `torch.nn.Module`.\n        optimizer: the optimizer associated to the network, should be regular PyTorch optimizer from `torch.optim`\n            or its subclass.\n        loss_function: the loss function 
associated to the optimizer, should be regular PyTorch loss,\n            which inherit from `torch.nn.modules.loss`.\n        epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.\n        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously\n            with respect to the host. For other cases, this argument has no effect.\n        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)\n            from `engine.state.batch` for every iteration, for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.\n        iteration_update: the callable function for every iteration, expect to accept `engine`\n            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.\n            if not provided, use `self._iteration()` instead. for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.\n        inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.\n        postprocessing: execute additional transformation for the model output data.\n            Typically, several Tensor based transforms composed by `Compose`.\n        key_train_metric: compute metric when every iteration completed, and save average value to\n            engine.state.metrics when epoch completed. 
key_train_metric is the main metric to compare and save the\n            checkpoint into files.\n        additional_metrics: more Ignite metrics that also attach to Ignite Engine.\n        metric_cmp_fn: function to compare current key metric with previous best key metric value,\n            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update\n            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.\n        train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:\n            CheckpointHandler, StatsHandler, etc.\n        amp: whether to enable auto-mixed-precision training, default is False.\n        event_names: additional custom ignite events that will register to the engine.\n            new events can be a list of str or `ignite.engine.events.EventEnum`.\n        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.\n            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html\n            #ignite.engine.engine.Engine.register_events.\n        decollate: whether to decollate the batch-first data to a list of data after model computation,\n            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.\n            default to `True`.\n        optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.\n            more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.\n        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for\n            `device`, `non_blocking`.\n        amp_kwargs: dict of the args for `torch.autocast(\"cuda\")` API, for more details:\n            https://pytorch.org/docs/stable/amp.html#torch.autocast.\n        compile: whether to use 
`torch.compile`, default is False. If True, MetaTensor inputs will be converted to\n            `torch.Tensor` before forward pass,  then converted back afterward with copied meta information.\n        compile_kwargs: dict of the args for `torch.compile()` API, for more details:\n            https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile.\n    \"\"\"\n\n    def __init__(\n        self,\n        device: str | torch.device,\n        max_epochs: int,\n        train_data_loader: Iterable | DataLoader,\n        network: torch.nn.Module,\n        optimizer: Optimizer,\n        loss_function: Callable,\n        epoch_length: int | None = None,\n        non_blocking: bool = False,\n        prepare_batch: Callable = default_prepare_batch,\n        iteration_update: Callable[[Engine, Any], Any] | None = None,\n        inferer: Inferer | None = None,\n        postprocessing: Transform | None = None,\n        key_train_metric: dict[str, Metric] | None = None,\n        additional_metrics: dict[str, Metric] | None = None,\n        metric_cmp_fn: Callable = default_metric_cmp_fn,\n        train_handlers: Sequence | None = None,\n        amp: bool = False,\n        event_names: list[str | EventEnum | type[EventEnum]] | None = None,\n        event_to_attr: dict | None = None,\n        decollate: bool = True,\n        optim_set_to_none: bool = False,\n        to_kwargs: dict | None = None,\n        amp_kwargs: dict | None = None,\n        compile: bool = False,\n        compile_kwargs: dict | None = None,\n    ) -> None:\n        super().__init__(\n            device=device,\n            max_epochs=max_epochs,\n            data_loader=train_data_loader,\n            epoch_length=epoch_length,\n            non_blocking=non_blocking,\n            prepare_batch=prepare_batch,\n            iteration_update=iteration_update,\n            postprocessing=postprocessing,\n            key_metric=key_train_metric,\n            
additional_metrics=additional_metrics,\n            metric_cmp_fn=metric_cmp_fn,\n            handlers=train_handlers,\n            amp=amp,\n            event_names=event_names,\n            event_to_attr=event_to_attr,\n            decollate=decollate,\n            to_kwargs=to_kwargs,\n            amp_kwargs=amp_kwargs,\n        )\n        if compile:\n            compile_kwargs = {} if compile_kwargs is None else compile_kwargs\n            network = torch.compile(network, **compile_kwargs)  # type: ignore[assignment]\n        self.network = network\n        self.compile = compile\n        self.optimizer = optimizer\n        self.loss_function = loss_function\n        self.inferer = SimpleInferer() if inferer is None else inferer\n        self.optim_set_to_none = optim_set_to_none\n\n    def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]) -> dict:\n        \"\"\"\n        Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.\n        Return below items in a dictionary:\n            - IMAGE: image Tensor data for model input, already moved to device.\n            - LABEL: label Tensor data corresponding to the image, already moved to device.\n            - PRED: prediction result of model.\n            - LOSS: loss value computed by loss function.\n\n        Args:\n            engine: `SupervisedTrainer` to execute operation for an iteration.\n            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.\n\n        Raises:\n            ValueError: When ``batchdata`` is None.\n\n        \"\"\"\n        if batchdata is None:\n            raise ValueError(\"Must provide batch data for current iteration.\")\n        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)\n        if len(batch) == 2:\n            inputs, targets = batch\n            args: tuple = ()\n            kwargs: dict = {}\n        
else:\n            inputs, targets, args, kwargs = batch\n        # FIXME: workaround for https://github.com/pytorch/pytorch/issues/117026\n        if self.compile:\n            inputs_meta, targets_meta, inputs_applied_operations, targets_applied_operations = None, None, None, None\n            if isinstance(inputs, MetaTensor):\n                warnings.warn(\n                    \"Will convert to PyTorch Tensor if using compile, and casting back to MetaTensor after the forward pass.\"\n                )\n                inputs, inputs_meta, inputs_applied_operations = (\n                    inputs.as_tensor(),\n                    inputs.meta,\n                    inputs.applied_operations,\n                )\n            if isinstance(targets, MetaTensor):\n                targets, targets_meta, targets_applied_operations = (\n                    targets.as_tensor(),\n                    targets.meta,\n                    targets.applied_operations,\n                )\n\n        # put iteration outputs into engine.state\n        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}\n\n        def _compute_pred_loss():\n            engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)\n            engine.fire_event(IterationEvents.FORWARD_COMPLETED)\n            engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()\n            engine.fire_event(IterationEvents.LOSS_COMPLETED)\n\n        engine.network.train()\n        engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)\n\n        if engine.amp and engine.scaler is not None:\n            with torch.autocast(\"cuda\", **engine.amp_kwargs):\n                _compute_pred_loss()\n            engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()\n            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)\n            engine.scaler.step(engine.optimizer)\n            engine.scaler.update()\n       
 else:\n            _compute_pred_loss()\n            engine.state.output[Keys.LOSS].backward()\n            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)\n            engine.optimizer.step()\n        # copy back meta info\n        if self.compile:\n            if inputs_meta is not None:\n                engine.state.output[Keys.IMAGE] = MetaTensor(\n                    inputs, meta=inputs_meta, applied_operations=inputs_applied_operations\n                )\n                engine.state.output[Keys.PRED] = MetaTensor(\n                    engine.state.output[Keys.PRED], meta=inputs_meta, applied_operations=inputs_applied_operations\n                )\n            if targets_meta is not None:\n                engine.state.output[Keys.LABEL] = MetaTensor(\n                    targets, meta=targets_meta, applied_operations=targets_applied_operations\n                )\n        engine.fire_event(IterationEvents.MODEL_COMPLETED)\n\n        return engine.state.output\n\n\nclass GanTrainer(Trainer):\n    \"\"\"\n    Generative adversarial network training based on Goodfellow et al. 2014 https://arxiv.org/abs/1406.266,\n    inherits from ``Trainer`` and ``Workflow``.\n\n    Training Loop: for each batch of data size `m`\n        1. Generate `m` fakes from random latent codes.\n        2. Update discriminator with these fakes and current batch reals, repeated d_train_steps times.\n        3. If g_update_latents, generate `m` fakes from new random latent codes.\n        4. 
Update generator with these fakes using discriminator feedback.\n\n    Args:\n        device: an object representing the device on which to run.\n        max_epochs: the total epoch number for engine to run.\n        train_data_loader: Core ignite engines uses `DataLoader` for training loop batchdata.\n        g_network: generator (G) network architecture.\n        g_optimizer: G optimizer function.\n        g_loss_function: G loss function for optimizer.\n        d_network: discriminator (D) network architecture.\n        d_optimizer: D optimizer function.\n        d_loss_function: D loss function for optimizer.\n        epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.\n        g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``.\n        d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``.\n        d_train_steps: number of times to update D with real data minibatch. Defaults to ``1``.\n        latent_shape: size of G input latent code. Defaults to ``64``.\n        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously\n            with respect to the host. For other cases, this argument has no effect.\n        d_prepare_batch: callback function to prepare batchdata for D inferer.\n            Defaults to return ``GanKeys.REALS`` in batchdata dict. for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.\n        g_prepare_batch: callback function to create batch of latent input for G inferer.\n            Defaults to return random latents. for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.\n        g_update_latents: Calculate G loss with new latent codes. 
Defaults to ``True``.\n        iteration_update: the callable function for every iteration, expect to accept `engine`\n            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.\n            if not provided, use `self._iteration()` instead. for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.\n        postprocessing: execute additional transformation for the model output data.\n            Typically, several Tensor based transforms composed by `Compose`.\n        key_train_metric: compute metric when every iteration completed, and save average value to\n            engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the\n            checkpoint into files.\n        additional_metrics: more Ignite metrics that also attach to Ignite Engine.\n        metric_cmp_fn: function to compare current key metric with previous best key metric value,\n            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update\n            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.\n        train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:\n            CheckpointHandler, StatsHandler, etc.\n        decollate: whether to decollate the batch-first data to a list of data after model computation,\n            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.\n            default to `True`.\n        optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.\n            more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.\n        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for\n            `device`, `non_blocking`.\n   
     amp_kwargs: dict of the args for `torch.autocast(\"cuda\")` API, for more details:\n            https://pytorch.org/docs/stable/amp.html#torch.autocast.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        device: str | torch.device,\n        max_epochs: int,\n        train_data_loader: DataLoader,\n        g_network: torch.nn.Module,\n        g_optimizer: Optimizer,\n        g_loss_function: Callable,\n        d_network: torch.nn.Module,\n        d_optimizer: Optimizer,\n        d_loss_function: Callable,\n        epoch_length: int | None = None,\n        g_inferer: Inferer | None = None,\n        d_inferer: Inferer | None = None,\n        d_train_steps: int = 1,\n        latent_shape: int = 64,\n        non_blocking: bool = False,\n        d_prepare_batch: Callable = default_prepare_batch,\n        g_prepare_batch: Callable = default_make_latent,\n        g_update_latents: bool = True,\n        iteration_update: Callable[[Engine, Any], Any] | None = None,\n        postprocessing: Transform | None = None,\n        key_train_metric: dict[str, Metric] | None = None,\n        additional_metrics: dict[str, Metric] | None = None,\n        metric_cmp_fn: Callable = default_metric_cmp_fn,\n        train_handlers: Sequence | None = None,\n        decollate: bool = True,\n        optim_set_to_none: bool = False,\n        to_kwargs: dict | None = None,\n        amp_kwargs: dict | None = None,\n    ):\n        if not isinstance(train_data_loader, DataLoader):\n            raise ValueError(\"train_data_loader must be PyTorch DataLoader.\")\n\n        # set up Ignite engine and environments\n        super().__init__(\n            device=device,\n            max_epochs=max_epochs,\n            data_loader=train_data_loader,\n            epoch_length=epoch_length,\n            non_blocking=non_blocking,\n            prepare_batch=d_prepare_batch,\n            iteration_update=iteration_update,\n            key_metric=key_train_metric,\n            
additional_metrics=additional_metrics,\n            metric_cmp_fn=metric_cmp_fn,\n            handlers=train_handlers,\n            postprocessing=postprocessing,\n            decollate=decollate,\n            to_kwargs=to_kwargs,\n            amp_kwargs=amp_kwargs,\n        )\n        self.g_network = g_network\n        self.g_optimizer = g_optimizer\n        self.g_loss_function = g_loss_function\n        self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer\n        self.d_network = d_network\n        self.d_optimizer = d_optimizer\n        self.d_loss_function = d_loss_function\n        self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer\n        self.d_train_steps = d_train_steps\n        self.latent_shape = latent_shape\n        self.g_prepare_batch = g_prepare_batch\n        self.g_update_latents = g_update_latents\n        self.optim_set_to_none = optim_set_to_none\n\n    def _iteration(\n        self, engine: GanTrainer, batchdata: dict | Sequence\n    ) -> dict[str, torch.Tensor | int | float | bool]:\n        \"\"\"\n        Callback function for Adversarial Training processing logic of 1 iteration in Ignite Engine.\n\n        Args:\n            engine: `GanTrainer` to execute operation for an iteration.\n            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.\n\n        Raises:\n            ValueError: must provide batch data for current iteration.\n\n        \"\"\"\n        if batchdata is None:\n            raise ValueError(\"must provide batch data for current iteration.\")\n\n        d_input = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)\n        batch_size = engine.data_loader.batch_size  # type: ignore\n        g_input = engine.g_prepare_batch(\n            num_latents=batch_size,\n            latent_size=engine.latent_shape,\n            device=engine.state.device,\n            non_blocking=engine.non_blocking,\n       
     **engine.to_kwargs,\n        )\n        g_output = engine.g_inferer(g_input, engine.g_network)\n\n        # Train Discriminator\n        d_total_loss = torch.zeros(1)\n        for _ in range(engine.d_train_steps):\n            engine.d_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)\n            dloss = engine.d_loss_function(g_output, d_input)\n            dloss.backward()\n            engine.d_optimizer.step()\n            d_total_loss += dloss.item()\n\n        # Train Generator\n        if engine.g_update_latents:\n            g_input = engine.g_prepare_batch(\n                num_latents=batch_size,\n                latent_size=engine.latent_shape,\n                device=engine.state.device,\n                non_blocking=engine.non_blocking,\n                **engine.to_kwargs,\n            )\n        g_output = engine.g_inferer(g_input, engine.g_network)\n        engine.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)\n        g_loss = engine.g_loss_function(g_output)\n        g_loss.backward()\n        engine.g_optimizer.step()\n\n        return {\n            GanKeys.REALS: d_input,\n            GanKeys.FAKES: g_output,\n            GanKeys.LATENTS: g_input,\n            GanKeys.GLOSS: g_loss.item(),\n            GanKeys.DLOSS: d_total_loss.item(),\n        }\n\n\nclass AdversarialTrainer(Trainer):\n    \"\"\"\n    Standard supervised training workflow for adversarial loss enabled neural networks.\n\n    Args:\n        device: an object representing the device on which to run.\n        max_epochs: the total epoch number for engine to run.\n        train_data_loader: Core ignite engines uses `DataLoader` for training loop batchdata.\n        g_network: ''generator'' (G) network architecture.\n        g_optimizer: G optimizer function.\n        g_loss_function: G loss function for adversarial training.\n        recon_loss_function: G loss function for reconstructions.\n        d_network: discriminator (D) network architecture.\n     
   d_optimizer: D optimizer function.\n        d_loss_function: D loss function for adversarial training..\n        epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.\n        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to\n            the host. For other cases, this argument has no effect.\n        prepare_batch: function to parse image and label for current iteration.\n        iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input\n            parameters. if not provided, use `self._iteration()` instead.\n        g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``.\n        d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``.\n        postprocessing: execute additional transformation for the model output data. Typically, several Tensor based\n            transforms composed by `Compose`. Defaults to None\n        key_train_metric: compute metric when every iteration completed, and save average value to engine.state.metrics\n            when epoch completed. 
key_train_metric is the main metric to compare and save the checkpoint into files.\n        additional_metrics: more Ignite metrics that also attach to Ignite Engine.\n        metric_cmp_fn: function to compare current key metric with previous best key metric value, it must accept 2 args\n            (current_metric, previous_best) and return a bool result: if `True`, will update 'best_metric` and\n            `best_metric_epoch` with current metric and epoch, default to `greater than`.\n        train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:\n            CheckpointHandler, StatsHandler, etc.\n        amp: whether to enable auto-mixed-precision training, default is False.\n        event_names: additional custom ignite events that will register to the engine.\n            new events can be a list of str or `ignite.engine.events.EventEnum`.\n        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.\n            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html\n            #ignite.engine.engine.Engine.register_events.\n        decollate: whether to decollate the batch-first data to a list of data after model computation, recommend\n            `decollate=True` when `postprocessing` uses components from `monai.transforms`. 
default to `True`.\n        optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.\n            more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.\n        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for\n            `device`, `non_blocking`.\n        amp_kwargs: dict of the args for `torch.autocast(\"cuda\")` API, for more details:\n            https://pytorch.org/docs/stable/amp.html#torch.autocast.\n    \"\"\"\n\n    def __init__(\n        self,\n        device: torch.device | str,\n        max_epochs: int,\n        train_data_loader: Iterable | DataLoader,\n        g_network: torch.nn.Module,\n        g_optimizer: Optimizer,\n        g_loss_function: Callable,\n        recon_loss_function: Callable,\n        d_network: torch.nn.Module,\n        d_optimizer: Optimizer,\n        d_loss_function: Callable,\n        epoch_length: int | None = None,\n        non_blocking: bool = False,\n        prepare_batch: Callable = default_prepare_batch,\n        iteration_update: Callable | None = None,\n        g_inferer: Inferer | None = None,\n        d_inferer: Inferer | None = None,\n        postprocessing: Transform | None = None,\n        key_train_metric: dict[str, Metric] | None = None,\n        additional_metrics: dict[str, Metric] | None = None,\n        metric_cmp_fn: Callable = default_metric_cmp_fn,\n        train_handlers: Sequence | None = None,\n        amp: bool = False,\n        event_names: list[str | EventEnum | type[EventEnum]] | None = None,\n        event_to_attr: dict | None = None,\n        decollate: bool = True,\n        optim_set_to_none: bool = False,\n        to_kwargs: dict | None = None,\n        amp_kwargs: dict | None = None,\n    ):\n        super().__init__(\n            device=device,\n            max_epochs=max_epochs,\n            data_loader=train_data_loader,\n            
epoch_length=epoch_length,\n            non_blocking=non_blocking,\n            prepare_batch=prepare_batch,\n            iteration_update=iteration_update,\n            postprocessing=postprocessing,\n            key_metric=key_train_metric,\n            additional_metrics=additional_metrics,\n            metric_cmp_fn=metric_cmp_fn,\n            handlers=train_handlers,\n            amp=amp,\n            event_names=event_names,\n            event_to_attr=event_to_attr,\n            decollate=decollate,\n            to_kwargs=to_kwargs,\n            amp_kwargs=amp_kwargs,\n        )\n\n        self.register_events(*AdversarialIterationEvents)\n\n        self.state.g_network = g_network\n        self.state.g_optimizer = g_optimizer\n        self.state.g_loss_function = g_loss_function\n        self.state.recon_loss_function = recon_loss_function\n\n        self.state.d_network = d_network\n        self.state.d_optimizer = d_optimizer\n        self.state.d_loss_function = d_loss_function\n\n        self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer\n        self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer\n\n        self.state.g_scaler = torch.cuda.amp.GradScaler() if self.amp else None\n        self.state.d_scaler = torch.cuda.amp.GradScaler() if self.amp else None\n\n        self.optim_set_to_none = optim_set_to_none\n        self._complete_state_dict_user_keys()\n\n    def _complete_state_dict_user_keys(self) -> None:\n        \"\"\"\n        This method appends to the _state_dict_user_keys AdversarialTrainer's elements that are required for\n        checkpoint saving.\n\n        Follows the example found at:\n            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.state_dict\n        \"\"\"\n        self._state_dict_user_keys.extend(\n            [\"g_network\", \"g_optimizer\", \"d_network\", \"d_optimizer\", \"g_scaler\", \"d_scaler\"]\n        )\n\n        
g_loss_state_dict = getattr(self.state.g_loss_function, \"state_dict\", None)\n        if callable(g_loss_state_dict):\n            self._state_dict_user_keys.append(\"g_loss_function\")\n\n        d_loss_state_dict = getattr(self.state.d_loss_function, \"state_dict\", None)\n        if callable(d_loss_state_dict):\n            self._state_dict_user_keys.append(\"d_loss_function\")\n\n        recon_loss_state_dict = getattr(self.state.recon_loss_function, \"state_dict\", None)\n        if callable(recon_loss_state_dict):\n            self._state_dict_user_keys.append(\"recon_loss_function\")\n\n    def _iteration(\n        self, engine: AdversarialTrainer, batchdata: dict[str, torch.Tensor]\n    ) -> dict[str, torch.Tensor | int | float | bool]:\n        \"\"\"\n        Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine.\n        Return below items in a dictionary:\n            - IMAGE: image Tensor data for model input, already moved to device.\n            - LABEL: label Tensor data corresponding to the image, already moved to device. In case of Unsupervised\n                Learning this is equal to IMAGE.\n            - PRED: prediction result of model.\n            - LOSS: loss value computed by loss functions of the generator (reconstruction and adversarial summed up).\n            - AdversarialKeys.REALS: real images from the batch. Are the same as IMAGE.\n            - AdversarialKeys.FAKES: fake images generated by the generator. Are the same as PRED.\n            - AdversarialKeys.REAL_LOGITS: logits of the discriminator for the real images.\n            - AdversarialKeys.FAKE_LOGITS: logits of the discriminator for the fake images.\n            - AdversarialKeys.RECONSTRUCTION_LOSS: loss value computed by the reconstruction loss function.\n            - AdversarialKeys.GENERATOR_LOSS: loss value computed by the generator loss function. It is the\n                discriminator loss for the fake images. 
That is backpropagated through the generator only.\n            - AdversarialKeys.DISCRIMINATOR_LOSS: loss value computed by the discriminator loss function. It is the\n                discriminator loss for the real images and the fake images. That is backpropagated through the\n                discriminator only.\n\n        Args:\n            engine: `AdversarialTrainer` to execute operation for an iteration.\n            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.\n\n        Raises:\n            ValueError: must provide batch data for current iteration.\n\n        \"\"\"\n\n        if batchdata is None:\n            raise ValueError(\"Must provide batch data for current iteration.\")\n        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)\n\n        if len(batch) == 2:\n            inputs, targets = batch\n            args: tuple = ()\n            kwargs: dict = {}\n        else:\n            inputs, targets, args, kwargs = batch\n\n        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets, AdversarialKeys.REALS: inputs}\n\n        def _compute_generator_loss() -> None:\n            engine.state.output[AdversarialKeys.FAKES] = engine.g_inferer(\n                inputs, engine.state.g_network, *args, **kwargs\n            )\n            engine.state.output[Keys.PRED] = engine.state.output[AdversarialKeys.FAKES]\n            engine.fire_event(AdversarialIterationEvents.GENERATOR_FORWARD_COMPLETED)\n\n            engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer(\n                engine.state.output[AdversarialKeys.FAKES].float().contiguous(), engine.state.d_network, *args, **kwargs\n            )\n            engine.fire_event(AdversarialIterationEvents.GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED)\n\n            engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] = engine.state.recon_loss_function(\n                
engine.state.output[AdversarialKeys.FAKES], targets\n            ).mean()\n            engine.fire_event(AdversarialIterationEvents.RECONSTRUCTION_LOSS_COMPLETED)\n\n            engine.state.output[AdversarialKeys.GENERATOR_LOSS] = engine.state.g_loss_function(\n                engine.state.output[AdversarialKeys.FAKE_LOGITS]\n            ).mean()\n            engine.fire_event(AdversarialIterationEvents.GENERATOR_LOSS_COMPLETED)\n\n        # Train Generator\n        engine.state.g_network.train()\n        engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)\n\n        if engine.amp and engine.state.g_scaler is not None:\n            with torch.autocast(\"cuda\", **engine.amp_kwargs):\n                _compute_generator_loss()\n\n            engine.state.output[Keys.LOSS] = (\n                engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS]\n                + engine.state.output[AdversarialKeys.GENERATOR_LOSS]\n            )\n            engine.state.g_scaler.scale(engine.state.output[Keys.LOSS]).backward()\n            engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED)\n            engine.state.g_scaler.step(engine.state.g_optimizer)\n            engine.state.g_scaler.update()\n        else:\n            _compute_generator_loss()\n            (\n                engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS]\n                + engine.state.output[AdversarialKeys.GENERATOR_LOSS]\n            ).backward()\n            engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED)\n            engine.state.g_optimizer.step()\n        engine.fire_event(AdversarialIterationEvents.GENERATOR_MODEL_COMPLETED)\n\n        def _compute_discriminator_loss() -> None:\n            engine.state.output[AdversarialKeys.REAL_LOGITS] = engine.d_inferer(\n                engine.state.output[AdversarialKeys.REALS].contiguous().detach(),\n                engine.state.d_network,\n                *args,\n            
    **kwargs,\n            )\n            engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_REALS_FORWARD_COMPLETED)\n\n            engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer(\n                engine.state.output[AdversarialKeys.FAKES].contiguous().detach(),\n                engine.state.d_network,\n                *args,\n                **kwargs,\n            )\n            engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_FAKES_FORWARD_COMPLETED)\n\n            engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS] = engine.state.d_loss_function(\n                engine.state.output[AdversarialKeys.REAL_LOGITS], engine.state.output[AdversarialKeys.FAKE_LOGITS]\n            ).mean()\n            engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_LOSS_COMPLETED)\n\n        # Train Discriminator\n        engine.state.d_network.train()\n        engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)\n\n        if engine.amp and engine.state.d_scaler is not None:\n            with torch.autocast(\"cuda\", **engine.amp_kwargs):\n                _compute_discriminator_loss()\n\n            engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()\n            engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_BACKWARD_COMPLETED)\n            engine.state.d_scaler.step(engine.state.d_optimizer)\n            engine.state.d_scaler.update()\n        else:\n            _compute_discriminator_loss()\n            engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS].backward()\n            engine.state.d_optimizer.step()\n\n        return engine.state.output\n"
  },
  {
    "path": "monai/engines/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Callable, Mapping, Sequence\nfrom typing import TYPE_CHECKING, Any, cast\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.transforms import apply_transform\nfrom monai.utils import IgniteInfo, ensure_tuple, min_version, optional_import\nfrom monai.utils.enums import CommonKeys, GanKeys\n\nif TYPE_CHECKING:\n    from ignite.engine import EventEnum\nelse:\n    EventEnum, _ = optional_import(\n        \"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"EventEnum\", as_type=\"base\"\n    )\n\n__all__ = [\n    \"IterationEvents\",\n    \"get_devices_spec\",\n    \"default_prepare_batch\",\n    \"PrepareBatch\",\n    \"PrepareBatchDefault\",\n    \"PrepareBatchExtraInput\",\n    \"DiffusionPrepareBatch\",\n    \"VPredictionPrepareBatch\",\n    \"default_make_latent\",\n    \"engine_apply_transform\",\n    \"default_metric_cmp_fn\",\n]\n\n\nclass IterationEvents(EventEnum):\n    \"\"\"\n    Additional Events engine can register and trigger in the iteration process.\n    Refer to the example in ignite: https://pytorch.org/ignite/generated/ignite.engine.events.EventEnum.html.\n    These Events can be triggered during training iteration:\n    `FORWARD_COMPLETED` is the Event when `network(image, label)` completed.\n    `LOSS_COMPLETED` is the Event when 
`loss(pred, label)` completed.\n    `BACKWARD_COMPLETED` is the Event when `loss.backward()` completed.\n    `MODEL_COMPLETED` is the Event when all the model related operations completed.\n    `INNER_ITERATION_STARTED` is the Event when the iteration has an inner loop and the loop is started.\n    `INNER_ITERATION_COMPLETED` is the Event when the iteration has an inner loop and the loop is completed.\n    \"\"\"\n\n    FORWARD_COMPLETED = \"forward_completed\"\n    LOSS_COMPLETED = \"loss_completed\"\n    BACKWARD_COMPLETED = \"backward_completed\"\n    MODEL_COMPLETED = \"model_completed\"\n    INNER_ITERATION_STARTED = \"inner_iteration_started\"\n    INNER_ITERATION_COMPLETED = \"inner_iteration_completed\"\n\n\ndef get_devices_spec(devices: Sequence[torch.device | str] | None = None) -> list[torch.device]:\n    \"\"\"\n    Get a valid specification for one or more devices. If `devices` is None get devices for all CUDA devices available.\n    If `devices` is a zero-length structure a single CPU compute device is returned. 
In any other case `devices` is\n    returned unchanged.\n\n    Args:\n        devices: list of devices to request, None for all GPU devices, [] for CPU.\n\n    Raises:\n        RuntimeError: When all GPUs are selected (``devices=None``) but no GPUs are available.\n\n    Returns:\n        list of torch.device: list of devices.\n\n    \"\"\"\n    if devices is None:\n        devices = [torch.device(f\"cuda:{d:d}\") for d in range(torch.cuda.device_count())]\n\n        if not devices:\n            raise RuntimeError(\"No GPU devices available.\")\n\n    elif len(devices) == 0:\n        devices = [torch.device(\"cpu\")]\n\n    else:\n        devices = list(devices)\n\n    devices = [torch.device(d) if isinstance(d, str) else d for d in devices]\n    return devices  # type: ignore\n\n\ndef default_prepare_batch(\n    batchdata: dict[str, torch.Tensor] | torch.Tensor | Sequence[torch.Tensor],\n    device: str | torch.device | None = None,\n    non_blocking: bool = False,\n    **kwargs: Any,\n) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor:\n    \"\"\"\n    Default function to prepare the data for current iteration.\n\n    The input `batchdata` is either a single tensor, a pair of tensors, or a dictionary of data. In the first case the\n    return value is the tensor and None, in the second case the return value is the two tensors, and in the dictionary\n    case the return value depends on what keys are present. if `CommonKeys.IMAGE` and `CommonKeys.LABEL` are present\n    then the tensors they key to are returned, if only `CommonKeys.IMAGE` is present that tensor and None is returned.\n    If `GanKeys.REALS` is present this tensor alone is returned. 
All returned tensors are moved to the given device\n    using the given non-blocking argument before being returned.\n\n    This function implements the expected API for a `prepare_batch` callable in Ignite:\n    https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html\n\n    Args:\n        batchdata: input batch data which is either a single tensor, a pair, or a dictionary\n        device: device to move every returned tensor to\n        non_blocking: equivalent argument for `Tensor.to`\n        kwargs: further arguments for `Tensor.to`\n\n    Returns:\n        image, label(optional).\n    \"\"\"\n    if not isinstance(batchdata, dict):\n        if isinstance(batchdata, torch.Tensor):\n            return batchdata.to(device=device, non_blocking=non_blocking, **kwargs), None\n        elif len(batchdata) == 2:\n            image, label = batchdata\n            return (\n                image.to(device=device, non_blocking=non_blocking, **kwargs),\n                label.to(device=device, non_blocking=non_blocking, **kwargs),\n            )\n\n        raise AssertionError(\"Default prepare_batch expects a single tensor, a tensor pair, or dictionary input data.\")\n\n    if isinstance(batchdata.get(CommonKeys.LABEL), torch.Tensor):\n        return (\n            batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking, **kwargs),\n            batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking, **kwargs),\n        )\n\n    if GanKeys.REALS in batchdata:\n        return batchdata[GanKeys.REALS].to(device=device, non_blocking=non_blocking, **kwargs)\n\n    return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking, **kwargs), None\n\n\nclass PrepareBatch(ABC):\n    \"\"\"\n    Interface of customized prepare_batch in the trainer or evaluator workflows.\n    It takes the data of current batch, target device and non_blocking flag as input.\n    Args `batchdata`, `device`, `non_blocking` 
refer to the ignite API:\n    https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.\n    `kwargs` supports other args for `Tensor.to()` API.\n    \"\"\"\n\n    @abstractmethod\n    def __call__(\n        self,\n        batchdata: dict[str, torch.Tensor],\n        device: str | torch.device | None = None,\n        non_blocking: bool = False,\n        **kwargs: Any,\n    ) -> Any:\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n\nclass PrepareBatchDefault(PrepareBatch):\n    \"\"\"\n    This wraps `default_prepare_batch` to return `image` and `label` only, so is consistent with its API.\n    \"\"\"\n\n    def __call__(\n        self,\n        batchdata: dict[str, torch.Tensor] | torch.Tensor | Sequence[torch.Tensor],\n        device: str | torch.device | None = None,\n        non_blocking: bool = False,\n        **kwargs: Any,\n    ) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor:\n        \"\"\"\n        Args `batchdata`, `device`, `non_blocking` refer to the ignite API:\n        https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.\n        `kwargs` supports other args for `Tensor.to()` API.\n\n        \"\"\"\n        return default_prepare_batch(batchdata, device, non_blocking, **kwargs)\n\n\nclass PrepareBatchExtraInput(PrepareBatch):\n    \"\"\"\n    Customized prepare batch callable for trainers or evaluators which support extra input data for the network.\n    Extra items are specified by the `extra_keys` parameter and are extracted from the input dictionary (ie. the batch).\n    This uses `default_prepare_batch` but requires dictionary inputs.\n\n    Args:\n        extra_keys: If a string or sequence of strings is provided, values from the input dictionary are extracted from\n            those keys and passed to the network as extra positional arguments. 
If a dictionary is provided, every pair\n            `(k, v)` in that dictionary will become a new keyword argument assigning to `k` the value in the input\n            dictionary keyed to `v`.\n    \"\"\"\n\n    def __init__(self, extra_keys: str | Sequence[str] | dict[str, str]) -> None:\n        self.extra_keys = extra_keys\n\n    def __call__(\n        self,\n        batchdata: dict[str, torch.Tensor],\n        device: str | torch.device | None = None,\n        non_blocking: bool = False,\n        **kwargs: Any,\n    ) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]:\n        \"\"\"\n        Args `batchdata`, `device`, `non_blocking` refer to the ignite API:\n        https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.\n        `kwargs` supports other args for `Tensor.to()` API.\n        \"\"\"\n        image, label = default_prepare_batch(batchdata, device, non_blocking, **kwargs)\n        args_ = list()\n        kwargs_ = dict()\n\n        def _get_data(key: str) -> torch.Tensor:\n            data = batchdata[key]\n\n            if isinstance(data, torch.Tensor):\n                data = data.to(device=device, non_blocking=non_blocking, **kwargs)\n\n            return data\n\n        if isinstance(self.extra_keys, (str, list, tuple)):\n            for k in ensure_tuple(self.extra_keys):\n                args_.append(_get_data(k))\n        elif isinstance(self.extra_keys, dict):\n            for k, v in self.extra_keys.items():\n                kwargs_.update({k: _get_data(v)})\n\n        return cast(torch.Tensor, image), cast(torch.Tensor, label), tuple(args_), kwargs_\n\n\nclass DiffusionPrepareBatch(PrepareBatch):\n    \"\"\"\n    This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.\n\n    Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and\n    return the image and noise field as the image/target pair 
plus the noise field in the kwargs under the key \"noise\".\n    This assumes the inferer being used in conjunction with this class expects a \"noise\" parameter to be provided.\n\n    If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition\n    field to be passed to the inferer. This will appear in the keyword arguments under the key \"condition\".\n\n    \"\"\"\n\n    def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None:\n        self.condition_name = condition_name\n        self.num_train_timesteps = num_train_timesteps\n\n    def get_noise(self, images: torch.Tensor) -> torch.Tensor:\n        \"\"\"Returns the noise tensor for input tensor `images`, override this for different noise distributions.\"\"\"\n        return torch.randn_like(images)\n\n    def get_timesteps(self, images: torch.Tensor) -> torch.Tensor:\n        \"\"\"Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`.\"\"\"\n        return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()\n\n    def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:\n        \"\"\"Return the target for the loss function, this is the `noise` value by default.\"\"\"\n        return noise\n\n    def __call__(\n        self,\n        batchdata: dict[str, torch.Tensor],\n        device: str | torch.device | None = None,\n        non_blocking: bool = False,\n        **kwargs: Any,\n    ) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]:\n        images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)\n        noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)\n        timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)\n\n        target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, 
**kwargs)\n        infer_kwargs = {\"noise\": noise, \"timesteps\": timesteps}\n\n        if self.condition_name is not None and isinstance(batchdata, Mapping):\n            infer_kwargs[\"condition\"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)\n\n        # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value\n        return images, target, (), infer_kwargs\n\n\nclass VPredictionPrepareBatch(DiffusionPrepareBatch):\n    \"\"\"\n    This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.\n\n    Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and\n    from this compute the velocity using the provided scheduler. This value is used as the target in place of the\n    noise field itself although the noise field is in the kwargs under the key \"noise\". This assumes the inferer\n    being used in conjunction with this class expects a \"noise\" parameter to be provided.\n\n    If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition\n    field to be passed to the inferer. 
This will appear in the keyword arguments under the key \"condition\".\n\n    \"\"\"\n\n    def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: str | None = None) -> None:\n        super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name)\n        self.scheduler = scheduler\n\n    def get_target(self, images, noise, timesteps):\n        return self.scheduler.get_velocity(images, noise, timesteps)  # type: ignore[operator]\n\n\ndef default_make_latent(\n    num_latents: int,\n    latent_size: int,\n    device: str | torch.device | None = None,\n    non_blocking: bool = False,\n    **kwargs: Any,\n) -> torch.Tensor:\n    return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking, **kwargs)\n\n\ndef engine_apply_transform(batch: Any, output: Any, transform: Callable[..., dict]) -> tuple[Any, Any]:\n    \"\"\"\n    Apply transform on `batch` and `output`.\n    If `batch` and `output` are dictionaries, temporarily combine them for the transform,\n    otherwise, apply the transform for `output` data only.\n\n    \"\"\"\n    if isinstance(batch, dict) and isinstance(output, dict):\n        data = dict(batch)\n        data.update(output)\n        transformed_data = apply_transform(transform, data)\n\n        if not isinstance(transformed_data, dict):\n            raise AssertionError(\"With a dict supplied to apply_transform a single dict return is expected.\")\n\n        for k, v in transformed_data.items():\n            # split the output data of post transforms into `output` and `batch`,\n            # `batch` should be read-only, so save the generated key-value into `output`\n            if k in output or k not in batch:\n                output[k] = v\n            else:\n                batch[k] = v\n    else:\n        output = apply_transform(transform, output)\n\n    return batch, output\n\n\ndef default_metric_cmp_fn(current_metric: float, prev_best: float) -> bool:\n    
\"\"\"\n    The default function to compare metric values between current metric and previous best metric.\n\n    Args:\n        current_metric: metric value of current round computation.\n        prev_best: the best metric value of previous rounds to compare with.\n\n    \"\"\"\n    return current_metric > prev_best\n"
  },
  {
    "path": "monai/engines/workflow.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable, Iterable, Sequence, Sized\nfrom typing import TYPE_CHECKING, Any\n\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.distributed import DistributedSampler\n\nfrom monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch\nfrom monai.transforms import Decollated\nfrom monai.utils import IgniteInfo, ensure_tuple, is_scalar, min_version, optional_import\n\nfrom .utils import engine_apply_transform\n\nState, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"State\")\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\n\nif TYPE_CHECKING:\n    from ignite.engine import Engine, EventEnum\n    from ignite.metrics import Metric\nelse:\n    Engine, _ = optional_import(\n        \"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\", as_type=\"decorator\"\n    )\n    Metric, _ = optional_import(\n        \"ignite.metrics\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Metric\", as_type=\"decorator\"\n    )\n    EventEnum, _ = optional_import(\n        \"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"EventEnum\", as_type=\"decorator\"\n    )\n\n\nclass Workflow(Engine):\n    
\"\"\"\n    Workflow defines the core work process inheriting from Ignite engine.\n    All trainer, validator and evaluator share this same workflow as base class,\n    because they all can be treated as same Ignite engine loops.\n    It initializes all the sharable data in Ignite engine.state.\n    And attach additional processing logics to Ignite engine based on Event-Handler mechanism.\n\n    Users should consider inheriting from `trainer` or `evaluator` to develop more trainers or evaluators.\n\n    Args:\n        device: an object representing the device on which to run.\n        max_epochs: the total epoch number for engine to run, validator and evaluator have only 1 epoch.\n        data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.\n        epoch_length: number of iterations for one epoch, default to `len(data_loader)`.\n        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously\n            with respect to the host. For other cases, this argument has no effect.\n        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)\n            from `engine.state.batch` for every iteration, for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.\n        iteration_update: the callable function for every iteration, expect to accept `engine`\n            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.\n            if not provided, use `self._iteration()` instead. 
for more details please refer to:\n            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.\n        postprocessing: execute additional transformation for the model output data.\n            Typically, several Tensor based transforms composed by `Compose`.\n        key_metric: compute metric when every iteration completed, and save average value to\n            engine.state.metrics when epoch completed. key_metric is the main metric to compare and save the\n            checkpoint into files.\n        additional_metrics: more Ignite metrics that also attach to Ignite Engine.\n        metric_cmp_fn: function to compare current key metric with previous best key metric value,\n            it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update\n            `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.\n        handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:\n            CheckpointHandler, StatsHandler, etc.\n        amp: whether to enable auto-mixed-precision training or inference, default is False.\n        event_names: additional custom ignite events that will register to the engine.\n            new events can be a list of str or `ignite.engine.events.EventEnum`.\n        event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.\n            for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html\n            #ignite.engine.engine.Engine.register_events.\n        decollate: whether to decollate the batch-first data to a list of data after model computation,\n            recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.\n            default to `True`.\n        to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for\n            `device`, 
`non_blocking`.\n        amp_kwargs: dict of the args for `torch.autocast(\"cuda\")` API, for more details:\n            https://pytorch.org/docs/stable/amp.html#torch.autocast.\n\n    Raises:\n        TypeError: When ``data_loader`` is not a ``torch.utils.data.DataLoader``.\n        TypeError: When ``key_metric`` is not a ``Optional[dict]``.\n        TypeError: When ``additional_metrics`` is not a ``Optional[dict]``.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        device: torch.device | str,\n        max_epochs: int,\n        data_loader: Iterable | DataLoader,\n        epoch_length: int | None = None,\n        non_blocking: bool = False,\n        prepare_batch: Callable = default_prepare_batch,\n        iteration_update: Callable[[Engine, Any], Any] | None = None,\n        postprocessing: Callable | None = None,\n        key_metric: dict[str, Metric] | None = None,\n        additional_metrics: dict[str, Metric] | None = None,\n        metric_cmp_fn: Callable = default_metric_cmp_fn,\n        handlers: Sequence | None = None,\n        amp: bool = False,\n        event_names: list[str | EventEnum | type[EventEnum]] | None = None,\n        event_to_attr: dict | None = None,\n        decollate: bool = True,\n        to_kwargs: dict | None = None,\n        amp_kwargs: dict | None = None,\n    ) -> None:\n        super().__init__(self._iteration if iteration_update is None else iteration_update)\n\n        if isinstance(data_loader, DataLoader):\n            sampler = getattr(data_loader, \"sampler\", None)\n\n            # set the epoch value for DistributedSampler objects when an epoch starts\n            if isinstance(sampler, DistributedSampler):\n\n                @self.on(Events.EPOCH_STARTED)\n                def set_sampler_epoch(engine: Engine) -> None:\n                    sampler.set_epoch(engine.state.epoch)\n\n        # if the epoch_length isn't given, attempt to get it from the length of the data loader\n        if epoch_length is None and 
isinstance(data_loader, Sized):\n            try:\n                epoch_length = len(data_loader)\n            except TypeError:  # raised when data_loader has an iterable dataset with no length, or is some other type\n                pass  # deliberately leave epoch_length as None\n\n        # set all sharable data for the workflow based on Ignite engine.state\n        self.state: Any = State(\n            rank=dist.get_rank() if dist.is_available() and dist.is_initialized() else 0,\n            seed=0,\n            iteration=0,\n            epoch=0,\n            max_epochs=max_epochs,\n            epoch_length=epoch_length,  # None when the dataset is iterable and so has no length\n            output=None,\n            batch=None,\n            metrics={},\n            metric_details={},\n            dataloader=None,\n            device=device if isinstance(device, torch.device) or device is None else torch.device(device),\n            key_metric_name=None,  # we can set many metrics, only use key_metric to compare and save the best model\n            best_metric=-1,\n            best_metric_epoch=-1,\n        )\n        self.data_loader = data_loader\n        self.non_blocking = non_blocking\n        self.prepare_batch = prepare_batch\n        self.metric_cmp_fn = metric_cmp_fn\n        self.amp = amp\n        self.to_kwargs = {} if to_kwargs is None else to_kwargs\n        self.amp_kwargs = {} if amp_kwargs is None else amp_kwargs\n        self.scaler: torch.cuda.amp.GradScaler | None = None\n\n        if event_names is None:\n            event_names = [IterationEvents]\n        else:\n            if not isinstance(event_names, list):\n                raise ValueError(\"`event_names` must be a list of strings or EventEnums.\")\n            event_names += [IterationEvents]\n        for name in event_names:\n            if isinstance(name, (str, EventEnum)):\n                self.register_events(name, event_to_attr=event_to_attr)  # type: ignore[arg-type]\n       
     elif issubclass(name, EventEnum):\n                self.register_events(*name, event_to_attr=event_to_attr)\n            else:\n                raise ValueError(\"`event_names` must be a list of strings or EventEnums.\")\n\n        if decollate:\n            self._register_decollate()\n\n        if postprocessing is not None:\n            # tips: if `decollate=False` and `postprocessing` is MONAI transforms, it may not work well\n            # because all the MONAI transforms expect `channel-first` data\n            self._register_postprocessing(postprocessing)\n        if key_metric is not None:\n            self._register_metrics(key_metric, additional_metrics)\n        if handlers is not None:\n            self._register_handlers(handlers)\n\n    def _register_decollate(self):\n        \"\"\"\n        Register the decollate operation for batch data, will execute after model forward and loss forward.\n\n        \"\"\"\n\n        @self.on(IterationEvents.MODEL_COMPLETED)\n        def _decollate_data(engine: Engine) -> None:\n            # replicate the scalar values to make sure all the items have batch dimension, then decollate\n            transform = Decollated(keys=None, detach=True)\n            if isinstance(engine.state.batch, (list, dict)):\n                engine.state.batch = transform(engine.state.batch)\n            if isinstance(engine.state.output, (list, dict)):\n                engine.state.output = transform(engine.state.output)\n\n    def _register_postprocessing(self, posttrans: Callable) -> None:\n        \"\"\"\n        Register the postprocessing logic to the engine, will execute them as a chain when iteration completed.\n\n        \"\"\"\n\n        @self.on(IterationEvents.MODEL_COMPLETED)\n        def _run_postprocessing(engine: Engine) -> None:\n            if not isinstance(engine.state.batch, list) or not isinstance(engine.state.output, list):\n                engine.state.batch, engine.state.output = engine_apply_transform(\n       
             batch=engine.state.batch, output=engine.state.output, transform=posttrans\n                )\n            else:\n                for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)):\n                    engine.state.batch[i], engine.state.output[i] = engine_apply_transform(b, o, posttrans)\n\n    def _register_metrics(self, k_metric: dict, add_metrics: dict | None = None) -> None:\n        \"\"\"\n        Register the key metric and additional metrics to the engine, supports ignite Metrics.\n\n        \"\"\"\n        if not isinstance(k_metric, dict):\n            raise TypeError(f\"`key_metric` must be None or a dict but is {type(k_metric).__name__}.\")\n        self.state.key_metric_name = list(k_metric.keys())[0]\n        metrics = dict(k_metric)\n        if add_metrics is not None and len(add_metrics) > 0:\n            if not isinstance(add_metrics, dict):\n                raise TypeError(f\"Additional metrics must be None or a dict but is {type(add_metrics).__name__}.\")\n            metrics.update(add_metrics)\n        for name, metric in metrics.items():\n            metric.attach(self, name)\n\n        @self.on(Events.EPOCH_COMPLETED)\n        def _compare_metrics(engine: Workflow) -> None:\n            key_metric_name = engine.state.key_metric_name\n            if key_metric_name is not None:\n                current_val_metric = engine.state.metrics[key_metric_name]\n                if not is_scalar(current_val_metric):\n                    warnings.warn(\n                        \"Key metric is not a scalar value, skip the metric comparison with the current best metric.\"\n                        \"Please set other metrics as the key metric, or change the `reduction` mode to 'mean'.\"\n                    )\n                    return\n\n                if engine.state.best_metric_epoch == -1 or self.metric_cmp_fn(\n                    current_val_metric, engine.state.best_metric\n                ):\n                    
self.logger.info(f\"Got new best metric of {key_metric_name}: {current_val_metric}\")\n                    engine.state.best_metric = current_val_metric\n                    engine.state.best_metric_epoch = engine.state.epoch\n\n    def _register_handlers(self, handlers: Sequence) -> None:\n        \"\"\"\n        Register the handlers to the engine, supports ignite Handlers with `attach` API.\n\n        \"\"\"\n        handlers_ = ensure_tuple(handlers)\n        for handler in handlers_:\n            handler.attach(self)\n\n    def run(self) -> None:  # type: ignore[override]\n        \"\"\"\n        Execute training, validation or evaluation based on Ignite Engine.\n        \"\"\"\n        if self.state.epoch_length == 0:\n            warnings.warn(\n                \"`dataloader` is empty or the specified `epoch_length` is 0, skip the `run`.\"\n                \" If running distributed training, the program may hang in `all-gather`, `all-reduce`, etc.\"\n                \" because not all the ranks run the same computation logic.\"\n            )\n            return\n        super().run(data=self.data_loader, max_epochs=self.state.max_epochs)\n\n    def _iteration(self, engine: Any, batchdata: dict[str, torch.Tensor]) -> dict:\n        \"\"\"\n        Abstract callback function for the processing logic of 1 iteration in Ignite Engine.\n        Need subclass to implement different logics, like SupervisedTrainer/Evaluator, GANTrainer, etc.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.\n\n        Raises:\n            NotImplementedError: When the subclass does not override this method.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    def get_stats(self, *vars):\n        \"\"\"\n        Get the statistics information of the workflow 
process.\n\n        Args:\n            vars: variables name in the `self.state`, will use the variable name as the key\n                and the state content as the value. if the variable doesn't exist, default value is `None`.\n\n        \"\"\"\n        return {k: getattr(self.state, k, None) for k in vars}\n"
  },
  {
    "path": "monai/fl/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/fl/client/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .client_algo import BaseClient, ClientAlgo, ClientAlgoStats\nfrom .monai_algo import MonaiAlgo, MonaiAlgoStats\n"
  },
  {
    "path": "monai/fl/client/client_algo.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom monai.fl.utils.exchange_object import ExchangeObject\n\n\nclass BaseClient:\n    \"\"\"\n    Provide an abstract base class to allow the client to return summary statistics of the data.\n\n    To define a new stats script, subclass this class and implement the\n    following abstract methods::\n\n        - self.get_data_stats()\n\n    initialize(), abort(), and finalize() -- inherited from `ClientAlgoStats`; can be optionally be implemented\n    to help with lifecycle management of the class object.\n    \"\"\"\n\n    def initialize(self, extra: dict | None = None) -> None:\n        \"\"\"\n        Call to initialize the ClientAlgo class.\n\n        Args:\n            extra: optional extra information, e.g. 
dict of `ExtraItems.CLIENT_NAME` and/or `ExtraItems.APP_ROOT`.\n        \"\"\"\n\n    def finalize(self, extra: dict | None = None) -> None:\n        \"\"\"\n        Call to finalize the ClientAlgo class.\n\n        Args:\n            extra: Dict with additional information that can be provided by the FL system.\n        \"\"\"\n\n    def abort(self, extra: dict | None = None) -> None:\n        \"\"\"\n        Call to abort the ClientAlgo training or evaluation.\n\n        Args:\n            extra: Dict with additional information that can be provided by the FL system.\n        \"\"\"\n\n\nclass ClientAlgoStats(BaseClient):\n\n    def get_data_stats(self, extra: dict | None = None) -> ExchangeObject:\n        \"\"\"\n        Get summary statistics about the local data.\n\n        Args:\n            extra: Dict with additional information that can be provided by the FL system.\n                For example, requested statistics.\n\n        Returns:\n\n            ExchangeObject: summary statistics.\n\n        Extra dict example::\n\n            requested_stats = {\n                FlStatistics.STATISTICS: metrics,\n                FlStatistics.NUM_OF_BINS: num_of_bins,\n                FlStatistics.BIN_RANGES: bin_ranges\n            }\n\n        Returned ExchangeObject example::\n\n            ExchangeObject(\n                statistics = {...}\n            )\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n\nclass ClientAlgo(ClientAlgoStats):\n    \"\"\"\n    Provide an abstract base class for defining algo to run on any platform.\n    To define a new algo script, subclass this class and implement the\n    following abstract methods:\n\n        - self.train()\n        - self.get_weights()\n        - self.evaluate()\n        - self.get_data_stats() (optional, inherited from `ClientAlgoStats`)\n\n    initialize(), abort(), and finalize() - inherited from `ClientAlgoStats` - can be optionally 
be implemented\n    to help with lifecycle management of the class object.\n    \"\"\"\n\n    def train(self, data: ExchangeObject, extra: dict | None = None) -> None:\n        \"\"\"\n        Train network and produce new network from train data.\n\n        Args:\n            data: ExchangeObject containing current network weights to base training on.\n            extra: Dict with additional information that can be provided by the FL system.\n\n        Returns:\n            None\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    def get_weights(self, extra: dict | None = None) -> ExchangeObject:\n        \"\"\"\n        Get current local weights or weight differences.\n\n        Args:\n            extra: Dict with additional information that can be provided by the FL system.\n\n        Returns:\n            ExchangeObject: current local weights or weight differences.\n\n        `ExchangeObject` example:\n\n        .. code-block:: python\n\n            ExchangeObject(\n                weights = self.trainer.network.state_dict(),\n                optim = None,  # could be self.optimizer.state_dict()\n                weight_type = WeightType.WEIGHTS\n            )\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    def evaluate(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject:\n        \"\"\"\n        Get evaluation metrics on test data.\n\n        Args:\n            data: ExchangeObject with network weights to use for evaluation.\n            extra: Dict with additional information that can be provided by the FL system.\n\n        Returns:\n            metrics: ExchangeObject with evaluation metrics.\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n"
  },
  {
    "path": "monai/fl/client/monai_algo.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport time\nfrom collections.abc import Mapping, MutableMapping\nfrom typing import Any, cast\n\nimport torch\nimport torch.distributed as dist\n\nfrom monai.apps.auto3dseg.data_analyzer import DataAnalyzer\nfrom monai.apps.utils import get_logger\nfrom monai.auto3dseg import SegSummarizer\nfrom monai.bundle import BundleWorkflow, ConfigComponent, ConfigItem, ConfigParser, ConfigWorkflow\nfrom monai.engines import SupervisedEvaluator, SupervisedTrainer, Trainer\nfrom monai.fl.client import ClientAlgo, ClientAlgoStats\nfrom monai.fl.utils.constants import ExtraItems, FiltersType, FlPhase, FlStatistics, ModelType, WeightType\nfrom monai.fl.utils.exchange_object import ExchangeObject\nfrom monai.networks.utils import copy_model_state, get_state_dict\nfrom monai.utils import min_version, require_pkg\nfrom monai.utils.enums import DataStatsKeys\n\nlogger = get_logger(__name__)\n\n\ndef convert_global_weights(global_weights: Mapping, local_var_dict: MutableMapping) -> tuple[MutableMapping, int]:\n    \"\"\"Helper function to convert global weights to local weights format\"\"\"\n    # Before loading weights, tensors might need to be reshaped to support HE for secure aggregation.\n    model_keys = global_weights.keys()\n    n_converted = 0\n    for var_name in local_var_dict:\n        if var_name in model_keys:\n            
weights = global_weights[var_name]\n            try:\n                # reshape global weights to compute difference later on\n                weights = torch.reshape(torch.as_tensor(weights), local_var_dict[var_name].shape)\n                # update the local dict\n                local_var_dict[var_name] = weights\n                n_converted += 1\n            except Exception as e:\n                raise ValueError(f\"Convert weight from {var_name} failed.\") from e\n    return local_var_dict, n_converted\n\n\ndef compute_weight_diff(global_weights, local_var_dict):\n    if global_weights is None:\n        raise ValueError(\"Cannot compute weight differences if `global_weights` is None!\")\n    if local_var_dict is None:\n        raise ValueError(\"Cannot compute weight differences if `local_var_dict` is None!\")\n    # compute delta model, global model has the primary key set\n    weight_diff = {}\n    n_diff = 0\n    for name in global_weights:\n        if name not in local_var_dict:\n            continue\n        # returned weight diff will be on the cpu\n        weight_diff[name] = local_var_dict[name].cpu() - global_weights[name].cpu()\n        n_diff += 1\n        if torch.any(torch.isnan(weight_diff[name])):\n            raise ValueError(f\"Weights for {name} became NaN...\")\n    if n_diff == 0:\n        raise RuntimeError(\"No weight differences computed!\")\n    return weight_diff\n\n\ndef disable_ckpt_loaders(parser: ConfigParser) -> None:\n    if \"validate#handlers\" in parser:\n        for h in parser[\"validate#handlers\"]:\n            if ConfigComponent.is_instantiable(h):\n                if \"CheckpointLoader\" in h[\"_target_\"]:\n                    h[\"_disabled_\"] = True\n\n\nclass MonaiAlgoStats(ClientAlgoStats):\n    \"\"\"\n    Implementation of ``ClientAlgoStats`` to allow federated learning with MONAI bundle configurations.\n\n    Args:\n        bundle_root: directory path of the bundle.\n        config_train_filename: bundle 
training config path relative to bundle_root. Can be a list of files;\n            defaults to \"configs/train.json\". only useful when `workflow` is None.\n        config_filters_filename: filter configuration file. Can be a list of files; defaults to `None`.\n        data_stats_transform_list: transforms to apply for the data stats result.\n        histogram_only: whether to only compute histograms. Defaults to False.\n        workflow: the bundle workflow to execute, usually it's training, evaluation or inference.\n            if None, will create an `ConfigWorkflow` internally based on `config_train_filename`.\n    \"\"\"\n\n    def __init__(\n        self,\n        bundle_root: str,\n        config_train_filename: str | list | None = \"configs/train.json\",\n        config_filters_filename: str | list | None = None,\n        data_stats_transform_list: list | None = None,\n        histogram_only: bool = False,\n        workflow: BundleWorkflow | None = None,\n    ):\n        self.logger = logger\n        self.bundle_root = bundle_root\n        self.config_train_filename = config_train_filename\n        self.config_filters_filename = config_filters_filename\n        self.train_data_key = \"train\"\n        self.eval_data_key = \"eval\"\n        self.data_stats_transform_list = data_stats_transform_list\n        self.histogram_only = histogram_only\n        self.workflow = None\n        if workflow is not None:\n            if not isinstance(workflow, BundleWorkflow):\n                raise ValueError(\"workflow must be a subclass of BundleWorkflow.\")\n            if workflow.get_workflow_type() is None:\n                raise ValueError(\"workflow doesn't specify the type.\")\n            self.workflow = workflow\n\n        self.client_name: str | None = None\n        self.app_root: str = \"\"\n        self.post_statistics_filters: Any = None\n        self.phase = FlPhase.IDLE\n        self.dataset_root: Any = None\n\n    def initialize(self, extra=None):\n     
   \"\"\"\n        Initialize routine to parse configuration files and extract main components such as trainer, evaluator, and filters.\n\n        Args:\n            extra: Dict with additional information that should be provided by FL system,\n                i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`.\n                You can diable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False.\n\n        \"\"\"\n        if extra is None:\n            extra = {}\n        self.client_name = extra.get(ExtraItems.CLIENT_NAME, \"noname\")\n        logging_file = extra.get(ExtraItems.LOGGING_FILE, None)\n        self.logger.info(f\"Initializing {self.client_name} ...\")\n\n        # FL platform needs to provide filepath to configuration files\n        self.app_root = extra.get(ExtraItems.APP_ROOT, \"\")\n        self.bundle_root = os.path.join(self.app_root, self.bundle_root)\n\n        if self.workflow is None:\n            config_train_files = self._add_config_files(self.config_train_filename)\n            self.workflow = ConfigWorkflow(\n                config_file=config_train_files, meta_file=None, logging_file=logging_file, workflow_type=\"train\"\n            )\n        self.workflow.initialize()\n        self.workflow.bundle_root = self.bundle_root\n        # initialize the workflow as the content changed\n        self.workflow.initialize()\n\n        config_filter_files = self._add_config_files(self.config_filters_filename)\n        filter_parser = ConfigParser()\n        if len(config_filter_files) > 0:\n            filter_parser.read_config(config_filter_files)\n            # Get filters\n            self.post_statistics_filters = filter_parser.get_parsed_content(\n                FiltersType.POST_STATISTICS_FILTERS, default=ConfigItem(None, FiltersType.POST_STATISTICS_FILTERS)\n            )\n        self.logger.info(f\"Initialized {self.client_name}.\")\n\n    def get_data_stats(self, extra: 
dict | None = None) -> ExchangeObject:\n        \"\"\"\n        Returns summary statistics about the local data.\n\n        Args:\n            extra: Dict with additional information that can be provided by the FL system.\n                    Both FlStatistics.HIST_BINS and FlStatistics.HIST_RANGE must be provided.\n\n        Returns:\n            stats: ExchangeObject with summary statistics.\n\n        \"\"\"\n        if extra is None:\n            raise ValueError(\"`extra` has to be set\")\n\n        if self.workflow.dataset_dir:  # type: ignore\n            self.phase = FlPhase.GET_DATA_STATS\n            self.logger.info(f\"Computing statistics on {self.workflow.dataset_dir}\")  # type: ignore\n\n            if FlStatistics.HIST_BINS not in extra:\n                raise ValueError(\"FlStatistics.HIST_BINS not specified in `extra`\")\n            else:\n                hist_bins = extra[FlStatistics.HIST_BINS]\n            if FlStatistics.HIST_RANGE not in extra:\n                raise ValueError(\"FlStatistics.HIST_RANGE not specified in `extra`\")\n            else:\n                hist_range = extra[FlStatistics.HIST_RANGE]\n\n            stats_dict = {}\n\n            # train data stats\n            train_summary_stats, train_case_stats = self._get_data_key_stats(\n                data=self.workflow.train_dataset_data,  # type: ignore\n                data_key=self.train_data_key,\n                hist_bins=hist_bins,\n                hist_range=hist_range,\n                output_path=os.path.join(self.app_root, \"train_data_stats.yaml\"),\n            )\n            if train_case_stats:\n                # Only return summary statistics to FL server\n                stats_dict.update({self.train_data_key: train_summary_stats})\n\n            # eval data stats\n            eval_summary_stats = None\n            eval_case_stats = None\n            if self.workflow.val_dataset_data is not None:  # type: ignore\n                eval_summary_stats, 
eval_case_stats = self._get_data_key_stats(\n                    data=self.workflow.val_dataset_data,  # type: ignore\n                    data_key=self.eval_data_key,\n                    hist_bins=hist_bins,\n                    hist_range=hist_range,\n                    output_path=os.path.join(self.app_root, \"eval_data_stats.yaml\"),\n                )\n            else:\n                self.logger.warning(\"the datalist doesn't contain validation section.\")\n            if eval_summary_stats:\n                # Only return summary statistics to FL server\n                stats_dict.update({self.eval_data_key: eval_summary_stats})\n\n            # total stats\n            if train_case_stats and eval_case_stats:\n                # Compute total summary\n                total_summary_stats = self._compute_total_stats(\n                    [train_case_stats, eval_case_stats], hist_bins, hist_range\n                )\n                stats_dict.update({FlStatistics.TOTAL_DATA: total_summary_stats})\n\n            # optional filter of data stats\n            stats = ExchangeObject(statistics=stats_dict)\n            if self.post_statistics_filters is not None:\n                for _filter in self.post_statistics_filters:\n                    stats = _filter(stats, extra)\n\n            return stats\n        else:\n            raise ValueError(\"data_root not set!\")\n\n    def _get_data_key_stats(self, data, data_key, hist_bins, hist_range, output_path=None):\n        analyzer = DataAnalyzer(\n            datalist={data_key: data},\n            dataroot=self.workflow.dataset_dir,  # type: ignore\n            hist_bins=hist_bins,\n            hist_range=hist_range,\n            output_path=output_path,\n            histogram_only=self.histogram_only,\n        )\n\n        self.logger.info(f\"{self.client_name} compute data statistics on {data_key}...\")\n        all_stats = analyzer.get_all_case_stats(transform_list=self.data_stats_transform_list, 
key=data_key)\n\n        case_stats = all_stats[DataStatsKeys.BY_CASE]\n\n        summary_stats = {\n            FlStatistics.DATA_STATS: all_stats[DataStatsKeys.SUMMARY],\n            FlStatistics.DATA_COUNT: len(data),\n            FlStatistics.FAIL_COUNT: len(data) - len(case_stats),\n            # TODO: add shapes, voxels sizes, etc.\n        }\n\n        return summary_stats, case_stats\n\n    @staticmethod\n    def _compute_total_stats(case_stats_lists, hist_bins, hist_range):\n        # Compute total summary\n        total_case_stats = []\n        for case_stats_list in case_stats_lists:\n            total_case_stats += case_stats_list\n\n        summarizer = SegSummarizer(\n            \"image\", \"label\", average=True, do_ccp=True, hist_bins=hist_bins, hist_range=hist_range\n        )\n        total_summary_stats = summarizer.summarize(total_case_stats)\n\n        summary_stats = {\n            FlStatistics.DATA_STATS: total_summary_stats,\n            FlStatistics.DATA_COUNT: len(total_case_stats),\n            FlStatistics.FAIL_COUNT: 0,\n        }\n\n        return summary_stats\n\n    def _add_config_files(self, config_files):\n        files = []\n        if config_files:\n            if isinstance(config_files, str):\n                files.append(os.path.join(self.bundle_root, config_files))\n            elif isinstance(config_files, list):\n                for file in config_files:\n                    if isinstance(file, str):\n                        files.append(os.path.join(self.bundle_root, file))\n                    else:\n                        raise ValueError(f\"Expected config file to be of type str but got {type(file)}: {file}\")\n            else:\n                raise ValueError(\n                    f\"Expected config files to be of type str or list but got {type(config_files)}: {config_files}\"\n                )\n        return files\n\n\n@require_pkg(pkg_name=\"ignite\", version=\"0.4.10\", version_checker=min_version)\nclass 
MonaiAlgo(ClientAlgo, MonaiAlgoStats):\n    \"\"\"\n    Implementation of ``ClientAlgo`` to allow federated learning with MONAI bundle configurations.\n\n    Args:\n        bundle_root: directory path of the bundle.\n        local_epochs: number of local epochs to execute during each round of local training; defaults to 1.\n        send_weight_diff: whether to send weight differences rather than full weights; defaults to `True`.\n        config_train_filename: bundle training config path relative to bundle_root. can be a list of files.\n            defaults to \"configs/train.json\". only useful when `train_workflow` is None.\n        train_kwargs: other args of the `ConfigWorkflow` of train, except for `config_file`, `meta_file`,\n            `logging_file`, `workflow_type`. only useful when `train_workflow` is None.\n        config_evaluate_filename: bundle evaluation config path relative to bundle_root. can be a list of files.\n            if \"default\", [\"configs/train.json\", \"configs/evaluate.json\"] will be used.\n            this arg is only useful when `eval_workflow` is None.\n        eval_kwargs: other args of the `ConfigWorkflow` of evaluation, except for `config_file`, `meta_file`,\n            `logging_file`, `workflow_type`. only useful when `eval_workflow` is None.\n        config_filters_filename: filter configuration file. 
Can be a list of files; defaults to `None`.\n        disable_ckpt_loading: do not use any CheckpointLoader if defined in train/evaluate configs; defaults to `True`.\n        best_model_filepath: location of best model checkpoint; defaults \"models/model.pt\" relative to `bundle_root`.\n        final_model_filepath: location of final model checkpoint; defaults \"models/model_final.pt\" relative to `bundle_root`.\n        save_dict_key: If a model checkpoint contains several state dicts,\n            the one defined by `save_dict_key` will be returned by `get_weights`; defaults to \"model\".\n            If all state dicts should be returned, set `save_dict_key` to None.\n        data_stats_transform_list: transforms to apply for the data stats result.\n        eval_workflow_name: the workflow name corresponding to the \"config_evaluate_filename\", default to \"train\"\n            as the default \"config_evaluate_filename\" overrides the train workflow config.\n            this arg is only useful when `eval_workflow` is None.\n        train_workflow: the bundle workflow to execute training, if None, will create a `ConfigWorkflow` internally\n            based on `config_train_filename` and `train_kwargs`.\n        eval_workflow: the bundle workflow to execute evaluation, if None, will create a `ConfigWorkflow` internally\n            based on `config_evaluate_filename`, `eval_kwargs`, `eval_workflow_name`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        bundle_root: str,\n        local_epochs: int = 1,\n        send_weight_diff: bool = True,\n        config_train_filename: str | list | None = \"configs/train.json\",\n        train_kwargs: dict | None = None,\n        config_evaluate_filename: str | list | None = \"default\",\n        eval_kwargs: dict | None = None,\n        config_filters_filename: str | list | None = None,\n        disable_ckpt_loading: bool = True,\n        best_model_filepath: str | None = \"models/model.pt\",\n        
final_model_filepath: str | None = \"models/model_final.pt\",\n        save_dict_key: str | None = \"model\",\n        data_stats_transform_list: list | None = None,\n        eval_workflow_name: str = \"train\",\n        train_workflow: BundleWorkflow | None = None,\n        eval_workflow: BundleWorkflow | None = None,\n    ):\n        self.logger = logger\n        self.bundle_root = bundle_root\n        self.local_epochs = local_epochs\n        self.send_weight_diff = send_weight_diff\n        self.config_train_filename = config_train_filename\n        self.train_kwargs = {} if train_kwargs is None else train_kwargs\n        if config_evaluate_filename == \"default\":\n            # by default, evaluator needs both training and evaluate to be instantiated\n            config_evaluate_filename = [\"configs/train.json\", \"configs/evaluate.json\"]\n        self.config_evaluate_filename = config_evaluate_filename\n        self.eval_kwargs = {} if eval_kwargs is None else eval_kwargs\n        self.config_filters_filename = config_filters_filename\n        self.disable_ckpt_loading = disable_ckpt_loading\n        self.model_filepaths = {ModelType.BEST_MODEL: best_model_filepath, ModelType.FINAL_MODEL: final_model_filepath}\n        self.save_dict_key = save_dict_key\n        self.data_stats_transform_list = data_stats_transform_list\n        self.eval_workflow_name = eval_workflow_name\n        self.train_workflow = None\n        self.eval_workflow = None\n        if train_workflow is not None:\n            if not isinstance(train_workflow, BundleWorkflow) or train_workflow.get_workflow_type() != \"train\":\n                raise ValueError(\n                    f\"train workflow must be BundleWorkflow and set type in {BundleWorkflow.supported_train_type}.\"\n                )\n            self.train_workflow = train_workflow\n        if eval_workflow is not None:\n            # evaluation workflow can be \"train\" type or \"infer\" type\n            if not 
isinstance(eval_workflow, BundleWorkflow) or eval_workflow.get_workflow_type() is None:\n                raise ValueError(\"eval workflow must be BundleWorkflow and set type.\")\n            self.eval_workflow = eval_workflow\n        self.stats_sender = None\n\n        self.app_root = \"\"\n        self.filter_parser: ConfigParser | None = None\n        self.trainer: SupervisedTrainer | None = None\n        self.evaluator: SupervisedEvaluator | None = None\n        self.pre_filters = None\n        self.post_weight_filters = None\n        self.post_evaluate_filters = None\n        self.iter_of_start_time = 0\n        self.global_weights: Mapping | None = None\n\n        self.phase = FlPhase.IDLE\n        self.client_name = None\n        self.dataset_root = None\n\n    def initialize(self, extra=None):\n        \"\"\"\n        Initialize routine to parse configuration files and extract main components such as trainer, evaluator, and filters.\n\n        Args:\n            extra: Dict with additional information that should be provided by FL system,\n                i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`.\n                You can disable the logging logic in the monai bundle by setting `ExtraItems.LOGGING_FILE` to False.\n\n        \"\"\"\n        self._set_cuda_device()\n        if extra is None:\n            extra = {}\n        self.client_name = extra.get(ExtraItems.CLIENT_NAME, \"noname\")\n        logging_file = extra.get(ExtraItems.LOGGING_FILE, None)\n        timestamp = time.strftime(\"%Y%m%d_%H%M%S\")\n        self.logger.info(f\"Initializing {self.client_name} ...\")\n        # FL platform needs to provide filepath to configuration files\n        self.app_root = extra.get(ExtraItems.APP_ROOT, \"\")\n        self.bundle_root = os.path.join(self.app_root, self.bundle_root)\n\n        if self.train_workflow is None and self.config_train_filename is not None:\n            config_train_files = 
self._add_config_files(self.config_train_filename)\n            # if enabled experiment tracking, set the run name to the FL client name and timestamp,\n            # expect the tracking settings use \"run_name\" to define the run name\n            if \"run_name\" not in self.train_kwargs:\n                self.train_kwargs[\"run_name\"] = f\"{self.client_name}_{timestamp}\"\n            self.train_workflow = ConfigWorkflow(\n                config_file=config_train_files,\n                meta_file=None,\n                logging_file=logging_file,\n                workflow_type=\"train\",\n                **self.train_kwargs,\n            )\n        if self.train_workflow is not None:\n            self.train_workflow.initialize()\n            self.train_workflow.bundle_root = self.bundle_root\n            self.train_workflow.max_epochs = self.local_epochs\n            if self.disable_ckpt_loading and isinstance(self.train_workflow, ConfigWorkflow):\n                disable_ckpt_loaders(parser=self.train_workflow.parser)\n            # initialize the workflow as the content changed\n            self.train_workflow.initialize()\n            self.trainer = self.train_workflow.trainer\n            if not isinstance(self.trainer, SupervisedTrainer):\n                raise ValueError(f\"trainer must be SupervisedTrainer, but got: {type(self.trainer)}.\")\n\n        if self.eval_workflow is None and self.config_evaluate_filename is not None:\n            config_eval_files = self._add_config_files(self.config_evaluate_filename)\n            # if enabled experiment tracking, set the run name to the FL client name and timestamp,\n            # expect the tracking settings use \"run_name\" to define the run name\n            if \"run_name\" not in self.eval_kwargs:\n                self.eval_kwargs[\"run_name\"] = f\"{self.client_name}_{timestamp}\"\n            self.eval_workflow = ConfigWorkflow(\n                config_file=config_eval_files,\n                
meta_file=None,\n                logging_file=logging_file,\n                workflow_type=self.eval_workflow_name,\n                **self.eval_kwargs,\n            )\n        if self.eval_workflow is not None:\n            self.eval_workflow.initialize()\n            self.eval_workflow.bundle_root = self.bundle_root\n            if self.disable_ckpt_loading and isinstance(self.eval_workflow, ConfigWorkflow):\n                disable_ckpt_loaders(parser=self.eval_workflow.parser)\n            # initialize the workflow as the content changed\n            self.eval_workflow.initialize()\n            self.evaluator = self.eval_workflow.evaluator\n            if not isinstance(self.evaluator, SupervisedEvaluator):\n                raise ValueError(f\"evaluator must be SupervisedEvaluator, but got: {type(self.evaluator)}.\")\n\n        config_filter_files = self._add_config_files(self.config_filters_filename)\n        self.filter_parser = ConfigParser()\n        if len(config_filter_files) > 0:\n            self.filter_parser.read_config(config_filter_files)\n\n        # set stats sender for nvflare\n        self.stats_sender = extra.get(ExtraItems.STATS_SENDER, self.stats_sender)\n        if self.stats_sender is not None:\n            self.stats_sender.attach(self.trainer)\n            self.stats_sender.attach(self.evaluator)\n\n        # Get filters\n        self.pre_filters = self.filter_parser.get_parsed_content(\n            FiltersType.PRE_FILTERS, default=ConfigItem(None, FiltersType.PRE_FILTERS)\n        )\n        self.post_weight_filters = self.filter_parser.get_parsed_content(\n            FiltersType.POST_WEIGHT_FILTERS, default=ConfigItem(None, FiltersType.POST_WEIGHT_FILTERS)\n        )\n        self.post_evaluate_filters = self.filter_parser.get_parsed_content(\n            FiltersType.POST_EVALUATE_FILTERS, default=ConfigItem(None, FiltersType.POST_EVALUATE_FILTERS)\n        )\n        self.post_statistics_filters = 
self.filter_parser.get_parsed_content(\n            FiltersType.POST_STATISTICS_FILTERS, default=ConfigItem(None, FiltersType.POST_STATISTICS_FILTERS)\n        )\n        self.logger.info(f\"Initialized {self.client_name}.\")\n\n    def train(self, data: ExchangeObject, extra: dict | None = None) -> None:\n        \"\"\"\n        Train on client's local data.\n\n        Args:\n            data: `ExchangeObject` containing the current global model weights.\n            extra: Dict with additional information that can be provided by the FL system.\n\n        \"\"\"\n\n        self._set_cuda_device()\n        if extra is None:\n            extra = {}\n        if not isinstance(data, ExchangeObject):\n            raise ValueError(f\"expected data to be ExchangeObject but received {type(data)}\")\n\n        if self.trainer is None:\n            raise ValueError(\"self.trainer should not be None.\")\n        if self.pre_filters is not None:\n            for _filter in self.pre_filters:\n                data = _filter(data, extra)\n        self.phase = FlPhase.TRAIN\n        self.logger.info(f\"Load {self.client_name} weights...\")\n        local_var_dict = get_state_dict(self.trainer.network)\n        self.global_weights, n_converted = convert_global_weights(\n            global_weights=cast(dict, data.weights), local_var_dict=local_var_dict\n        )\n        self._check_converted(data.weights, local_var_dict, n_converted)\n\n        # set engine state max epochs.\n        self.trainer.state.max_epochs = self.trainer.state.epoch + self.local_epochs\n        # get current iteration when a round starts\n        self.iter_of_start_time = self.trainer.state.iteration\n\n        _, updated_keys, _ = copy_model_state(src=cast(Mapping, self.global_weights), dst=self.trainer.network)\n        if len(updated_keys) == 0:\n            self.logger.warning(\"No weights loaded!\")\n        self.logger.info(f\"Start {self.client_name} training...\")\n        self.trainer.run()\n\n    
def get_weights(self, extra=None):\n        \"\"\"\n        Returns the current weights of the model.\n\n        Args:\n            extra: Dict with additional information that can be provided by the FL system.\n\n        Returns:\n            return_weights: `ExchangeObject` containing current weights (default)\n                or load requested model type from disk (`ModelType.BEST_MODEL` or `ModelType.FINAL_MODEL`).\n\n        \"\"\"\n\n        self._set_cuda_device()\n        if extra is None:\n            extra = {}\n\n        # by default return current weights, return best if requested via model type.\n        self.phase = FlPhase.GET_WEIGHTS\n\n        if ExtraItems.MODEL_TYPE in extra:\n            model_type = extra.get(ExtraItems.MODEL_TYPE)\n            if not isinstance(model_type, ModelType):\n                raise ValueError(\n                    f\"Expected requested model type to be of type `ModelType` but received {type(model_type)}\"\n                )\n            if model_type in self.model_filepaths:\n                model_path = os.path.join(self.bundle_root, cast(str, self.model_filepaths[model_type]))\n                if not os.path.isfile(model_path):\n                    raise ValueError(f\"No best model checkpoint exists at {model_path}\")\n                weights = torch.load(model_path, map_location=\"cpu\", weights_only=True)\n                # if weights contain several state dicts, use the one defined by `save_dict_key`\n                if isinstance(weights, dict) and self.save_dict_key in weights:\n                    weights = weights.get(self.save_dict_key)\n                weigh_type: WeightType | None = WeightType.WEIGHTS\n                stats: dict = {}\n                self.logger.info(f\"Returning {model_type} checkpoint weights from {model_path}.\")\n            else:\n                raise ValueError(\n                    f\"Requested model type {model_type} not specified in `model_filepaths`: {self.model_filepaths}\"\n  
              )\n        else:\n            if self.trainer:\n                weights = get_state_dict(self.trainer.network)\n                # returned weights will be on the cpu\n                for k in weights.keys():\n                    weights[k] = weights[k].cpu()\n                weigh_type = WeightType.WEIGHTS\n                stats = self.trainer.get_stats()\n                # calculate current iteration and epoch data after training.\n                stats[FlStatistics.NUM_EXECUTED_ITERATIONS] = self.trainer.state.iteration - self.iter_of_start_time\n                # compute weight differences\n                if self.send_weight_diff:\n                    weights = compute_weight_diff(global_weights=self.global_weights, local_var_dict=weights)\n                    weigh_type = WeightType.WEIGHT_DIFF\n                    self.logger.info(\"Returning current weight differences.\")\n                else:\n                    self.logger.info(\"Returning current weights.\")\n            else:\n                weights = None\n                weigh_type = None\n                stats = dict()\n\n        if not isinstance(stats, dict):\n            raise ValueError(f\"stats is not a dict, {stats}\")\n        return_weights = ExchangeObject(\n            weights=weights,\n            optim=None,  # could be self.optimizer.state_dict()\n            weight_type=weigh_type,\n            statistics=stats,\n        )\n\n        # filter weights if needed (use to apply differential privacy, encryption, compression, etc.)\n        if self.post_weight_filters is not None:\n            for _filter in self.post_weight_filters:\n                return_weights = _filter(return_weights, extra)\n\n        return return_weights\n\n    def evaluate(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject:\n        \"\"\"\n        Evaluate on client's local data.\n\n        Args:\n            data: `ExchangeObject` containing the current global model weights.\n 
           extra: Dict with additional information that can be provided by the FL system.\n\n        Returns:\n            return_metrics: `ExchangeObject` containing evaluation metrics.\n\n        \"\"\"\n\n        self._set_cuda_device()\n        if extra is None:\n            extra = {}\n        if not isinstance(data, ExchangeObject):\n            raise ValueError(f\"expected data to be ExchangeObject but received {type(data)}\")\n\n        if self.evaluator is None:\n            raise ValueError(\"self.evaluator should not be None.\")\n        if self.pre_filters is not None:\n            for _filter in self.pre_filters:\n                data = _filter(data, extra)\n\n        self.phase = FlPhase.EVALUATE\n        self.logger.info(f\"Load {self.client_name} weights...\")\n        local_var_dict = get_state_dict(self.evaluator.network)\n        global_weights, n_converted = convert_global_weights(\n            global_weights=cast(dict, data.weights), local_var_dict=local_var_dict\n        )\n        self._check_converted(data.weights, local_var_dict, n_converted)\n\n        _, updated_keys, _ = copy_model_state(src=global_weights, dst=self.evaluator.network)\n        if len(updated_keys) == 0:\n            self.logger.warning(\"No weights loaded!\")\n        self.logger.info(f\"Start {self.client_name} evaluating...\")\n        if isinstance(self.trainer, Trainer):\n            self.evaluator.run(self.trainer.state.epoch + 1)\n        else:\n            self.evaluator.run()\n        return_metrics = ExchangeObject(metrics=self.evaluator.state.metrics)\n\n        if self.post_evaluate_filters is not None:\n            for _filter in self.post_evaluate_filters:\n                return_metrics = _filter(return_metrics, extra)\n        return return_metrics\n\n    def abort(self, extra=None):\n        \"\"\"\n        Abort the training or evaluation.\n        Args:\n            extra: Dict with additional information that can be provided by the FL system.\n        
\"\"\"\n        self.logger.info(f\"Aborting {self.client_name} during {self.phase} phase.\")\n        if isinstance(self.trainer, Trainer):\n            self.logger.info(f\"Aborting {self.client_name} trainer...\")\n            self.trainer.interrupt()\n        if isinstance(self.evaluator, Trainer):\n            self.logger.info(f\"Aborting {self.client_name} evaluator...\")\n            self.evaluator.interrupt()\n\n    def finalize(self, extra: dict | None = None) -> None:\n        \"\"\"\n        Finalize the training or evaluation.\n        Args:\n            extra: Dict with additional information that can be provided by the FL system.\n        \"\"\"\n        self.logger.info(f\"Terminating {self.client_name} during {self.phase} phase.\")\n        if isinstance(self.trainer, Trainer):\n            self.logger.info(f\"Terminating {self.client_name} trainer...\")\n            self.trainer.terminate()\n        if isinstance(self.evaluator, Trainer):\n            self.logger.info(f\"Terminating {self.client_name} evaluator...\")\n            self.evaluator.terminate()\n        if self.train_workflow is not None:\n            self.train_workflow.finalize()\n        if self.eval_workflow is not None:\n            self.eval_workflow.finalize()\n\n    def _check_converted(self, global_weights, local_var_dict, n_converted):\n        if n_converted == 0:\n            raise RuntimeError(\n                f\"No global weights converted! Received weight dict keys are {list(global_weights.keys())}\"\n            )\n        else:\n            self.logger.info(\n                f\"Converted {n_converted} global variables to match {len(local_var_dict)} local variables.\"\n            )\n\n    def _set_cuda_device(self):\n        if dist.is_initialized():\n            self.rank = int(os.environ[\"LOCAL_RANK\"])\n            torch.cuda.set_device(self.rank)\n"
  },
  {
    "path": "monai/fl/utils/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/fl/utils/constants.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom monai.utils.enums import StrEnum\n\n\nclass WeightType(StrEnum):\n    WEIGHTS = \"fl_weights_full\"\n    WEIGHT_DIFF = \"fl_weight_diff\"\n\n\nclass ModelType(StrEnum):\n    BEST_MODEL = \"fl_best_model\"\n    FINAL_MODEL = \"fl_final_model\"\n\n\nclass ExtraItems(StrEnum):\n    ABORT = \"fl_abort\"\n    MODEL_TYPE = \"fl_model_type\"\n    CLIENT_NAME = \"fl_client_name\"\n    APP_ROOT = \"fl_app_root\"\n    STATS_SENDER = \"fl_stats_sender\"\n    LOGGING_FILE = \"logging_file\"\n\n\nclass FlPhase(StrEnum):\n    IDLE = \"fl_idle\"\n    TRAIN = \"fl_train\"\n    EVALUATE = \"fl_evaluate\"\n    GET_WEIGHTS = \"fl_get_weights\"\n    GET_DATA_STATS = \"fl_get_data_stats\"\n\n\nclass FlStatistics(StrEnum):\n    NUM_EXECUTED_ITERATIONS = \"num_executed_iterations\"\n    STATISTICS = \"statistics\"\n    HIST_BINS = \"hist_bins\"\n    HIST_RANGE = \"hist_range\"\n    DATA_STATS = \"data_stats\"\n    DATA_COUNT = \"data_count\"\n    FAIL_COUNT = \"fail_count\"\n    TOTAL_DATA = \"total_data\"\n    FEATURE_NAMES = \"feature_names\"\n\n\nclass FiltersType(StrEnum):\n    PRE_FILTERS = \"pre_filters\"\n    POST_WEIGHT_FILTERS = \"post_weight_filters\"\n    POST_EVALUATE_FILTERS = \"post_evaluate_filters\"\n    POST_STATISTICS_FILTERS = \"post_statistics_filters\"\n"
  },
  {
    "path": "monai/fl/utils/exchange_object.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom monai.fl.utils.constants import WeightType\n\n\nclass ExchangeObject(dict):\n    \"\"\"\n    Contains the information shared between client and server.\n\n    Args:\n        weights: model weights.\n        optim: optimizer weights.\n        metrics: evaluation metrics.\n        weight_type: type of weights (see monai.fl.utils.constants.WeightType).\n        statistics: training statistics, i.e. 
number of executed iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        weights: dict | None = None,\n        optim: dict | None = None,\n        metrics: dict | None = None,\n        weight_type: WeightType | None = None,\n        statistics: dict | None = None,\n    ):\n        super().__init__()\n        self.weights = weights\n        self.optim = optim\n        self.metrics = metrics\n        self.weight_type = weight_type\n        self.statistics = statistics\n        self._summary: dict = {}\n\n    @property\n    def metrics(self):\n        return self._metrics\n\n    @metrics.setter\n    def metrics(self, metrics):\n        if metrics is not None:\n            if not isinstance(metrics, dict):\n                raise ValueError(f\"Expected metrics to be of type dict but received {type(metrics)}\")\n        self._metrics = metrics\n\n    @property\n    def statistics(self):\n        return self._statistics\n\n    @statistics.setter\n    def statistics(self, statistics):\n        if statistics is not None:\n            if not isinstance(statistics, dict):\n                raise ValueError(f\"Expected statistics to be of type dict but received {type(statistics)}\")\n        self._statistics = statistics\n\n    @property\n    def weight_type(self):\n        return self._weight_type\n\n    @weight_type.setter\n    def weight_type(self, weight_type):\n        if weight_type is not None:\n            if weight_type not in [WeightType.WEIGHTS, WeightType.WEIGHT_DIFF]:\n                raise ValueError(f\"Expected weight type to be either {WeightType.WEIGHTS} or {WeightType.WEIGHT_DIFF}\")\n        self._weight_type = weight_type\n\n    def is_valid_weights(self):\n        if not self.weights:\n            return False\n        if not self.weight_type:\n            return False\n        return True\n\n    def _add_to_summary(self, key, value):\n        if value:\n            if isinstance(value, dict):\n                self._summary[key] = len(value)\n   
         elif isinstance(value, WeightType):\n                self._summary[key] = value\n            else:\n                self._summary[key] = type(value)\n\n    def summary(self):\n        self._summary.update(self)\n        for k, v in zip(\n            [\"weights\", \"optim\", \"metrics\", \"weight_type\", \"statistics\"],\n            [self.weights, self.optim, self.metrics, self.weight_type, self.statistics],\n        ):\n            self._add_to_summary(k, v)\n        return self._summary\n\n    def __repr__(self):\n        return str(self.summary())\n\n    def __str__(self):\n        return str(self.summary())\n"
  },
  {
    "path": "monai/fl/utils/filters.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport abc\n\nfrom monai.fl.utils.exchange_object import ExchangeObject\n\n\nclass Filter(abc.ABC):\n    \"\"\"\n    Used to apply filter to content of ExchangeObject.\n    \"\"\"\n\n    @abc.abstractmethod\n    def __call__(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject:\n        \"\"\"\n        Run the filtering.\n\n        Arguments:\n            data: ExchangeObject containing some data.\n\n        Returns:\n            ExchangeObject: filtered data.\n        \"\"\"\n\n        raise NotImplementedError\n\n\nclass SummaryFilter(Filter):\n    \"\"\"\n    Summary filter to show content of ExchangeObject.\n    \"\"\"\n\n    def __call__(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject:\n        \"\"\"\n        Example filter that doesn't filter anything but only prints data summary.\n\n        Arguments:\n            data: ExchangeObject containing some data.\n\n        Returns:\n            ExchangeObject: filtered data.\n        \"\"\"\n\n        print(f\"Summary of ExchangeObject: {data.summary()}\")\n\n        return data\n"
  },
  {
    "path": "monai/handlers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .average_precision import AveragePrecision\nfrom .calibration import CalibrationError\nfrom .checkpoint_loader import CheckpointLoader\nfrom .checkpoint_saver import CheckpointSaver\nfrom .classification_saver import ClassificationSaver\nfrom .clearml_handlers import ClearMLHandler, ClearMLImageHandler, ClearMLStatsHandler\nfrom .confusion_matrix import ConfusionMatrix\nfrom .decollate_batch import DecollateBatch\nfrom .earlystop_handler import EarlyStopHandler\nfrom .garbage_collector import GarbageCollector\nfrom .hausdorff_distance import HausdorffDistance\nfrom .ignite_metric import IgniteMetricHandler\nfrom .logfile_handler import LogfileHandler\nfrom .lr_schedule_handler import LrScheduleHandler\nfrom .mean_dice import MeanDice\nfrom .mean_iou import MeanIoUHandler\nfrom .metric_logger import MetricLogger, MetricLoggerKeys\nfrom .metrics_reloaded_handler import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler\nfrom .metrics_saver import MetricsSaver\nfrom .mlflow_handler import MLFlowHandler\nfrom .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler\nfrom .panoptic_quality import PanopticQuality\nfrom .parameter_scheduler import ParamSchedulerHandler\nfrom .postprocessing import PostProcessing\nfrom .probability_maps import ProbMapProducer\nfrom .regression_metrics import 
MeanAbsoluteError, MeanSquaredError, PeakSignalToNoiseRatio, RootMeanSquaredError\nfrom .roc_auc import ROCAUC\nfrom .smartcache_handler import SmartCacheHandler\nfrom .stats_handler import StatsHandler\nfrom .surface_distance import SurfaceDistance\nfrom .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler\nfrom .trt_handler import TrtHandler\nfrom .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports\nfrom .validation_handler import ValidationHandler\n"
  },
  {
    "path": "monai/handlers/average_precision.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nfrom monai.handlers.ignite_metric import IgniteMetricHandler\nfrom monai.metrics import AveragePrecisionMetric\nfrom monai.utils import Average\n\n\nclass AveragePrecision(IgniteMetricHandler):\n    \"\"\"\n    Computes Average Precision (AP).\n    accumulating predictions and the ground-truth during an epoch and applying `compute_average_precision`.\n\n    Args:\n        average: {``\"macro\"``, ``\"weighted\"``, ``\"micro\"``, ``\"none\"``}\n            Type of averaging performed if not binary classification. 
Defaults to ``\"macro\"``.\n\n            - ``\"macro\"``: calculate metrics for each label, and find their unweighted mean.\n                This does not take label imbalance into account.\n            - ``\"weighted\"``: calculate metrics for each label, and find their average,\n                weighted by support (the number of true instances for each label).\n            - ``\"micro\"``: calculate metrics globally by considering each element of the label\n                indicator matrix as a label.\n            - ``\"none\"``: the scores for each class are returned.\n\n        output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n            construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n            lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.\n            `engine.state` and `output_transform` inherit from the ignite concept:\n            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n\n    Note:\n        Average Precision expects y to be comprised of 0's and 1's.\n        y_pred must either be probability estimates or confidence values.\n\n    \"\"\"\n\n    def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None:\n        metric_fn = AveragePrecisionMetric(average=Average(average))\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)\n"
  },
  {
    "path": "monai/handlers/calibration.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nfrom monai.handlers.ignite_metric import IgniteMetricHandler\nfrom monai.metrics import CalibrationErrorMetric, CalibrationReduction\nfrom monai.utils import MetricReduction\n\n__all__ = [\"CalibrationError\"]\n\n\nclass CalibrationError(IgniteMetricHandler):\n    \"\"\"\n    Ignite handler to compute Calibration Error during training or evaluation.\n\n    **Why Calibration Matters:**\n\n    A well-calibrated model produces probability estimates that match the true likelihood of correctness.\n    For example, predictions with 80% confidence should be correct approximately 80% of the time.\n    Modern neural networks often exhibit poor calibration (typically overconfident), which can be\n    problematic in medical imaging where probability estimates may inform clinical decisions.\n\n    This handler wraps :py:class:`~monai.metrics.CalibrationErrorMetric` for use with PyTorch Ignite\n    engines, automatically computing and aggregating calibration errors across iterations.\n\n    **Supported Calibration Metrics:**\n\n    - **Expected Calibration Error (ECE)**: Weighted average of per-bin errors (most common).\n    - **Average Calibration Error (ACE)**: Unweighted average across bins.\n    - **Maximum Calibration Error (MCE)**: Worst-case calibration error.\n\n    Args:\n        num_bins: Number of 
equally-spaced bins for calibration computation. Defaults to 20.\n        include_background: Whether to include the first channel (index 0) in computation.\n            Set to ``False`` to exclude background in segmentation tasks. Defaults to ``True``.\n        calibration_reduction: Calibration error reduction mode. Options: ``\"expected\"`` (ECE),\n            ``\"average\"`` (ACE), ``\"maximum\"`` (MCE). Defaults to ``\"expected\"``.\n        metric_reduction: Reduction across batch/channel after computing per-sample errors.\n            Options: ``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``. Defaults to ``\"mean\"``.\n        output_transform: Callable to extract ``(y_pred, y)`` from ``engine.state.output``.\n            See `Ignite concepts <https://pytorch.org/ignite/concepts.html#state>`_ and\n            the batch output transform tutorial in the MONAI tutorials repository.\n        save_details: If ``True``, saves per-sample/per-channel metric values to\n            ``engine.state.metric_details[name]``. Defaults to ``True``.\n\n    References:\n        - Guo, C., et al. \"On Calibration of Modern Neural Networks.\" ICML 2017.\n          https://proceedings.mlr.press/v70/guo17a.html\n        - Barfoot, T., et al. \"Average Calibration Losses for Reliable Uncertainty in\n          Medical Image Segmentation.\" arXiv:2506.03942v3, 2025.\n          https://arxiv.org/abs/2506.03942v3\n\n    See Also:\n        - :py:class:`~monai.metrics.CalibrationErrorMetric`: The underlying metric class.\n        - :py:func:`~monai.metrics.calibration_binning`: Low-level binning for reliability diagrams.\n\n    Example:\n        >>> from monai.handlers import CalibrationError, from_engine\n        >>> from ignite.engine import Engine\n        >>>\n        >>> def evaluation_step(engine, batch):\n        ...     
# Returns dict with \"pred\" (probabilities) and \"label\" (one-hot)\n        ...     return {\"pred\": model(batch[\"image\"]), \"label\": batch[\"label\"]}\n        >>>\n        >>> evaluator = Engine(evaluation_step)\n        >>>\n        >>> # Attach calibration error handler\n        >>> CalibrationError(\n        ...     num_bins=15,\n        ...     include_background=False,\n        ...     calibration_reduction=\"expected\",\n        ...     output_transform=from_engine([\"pred\", \"label\"]),\n        ... ).attach(evaluator, name=\"ECE\")\n        >>>\n        >>> # After evaluation, access results\n        >>> evaluator.run(val_loader)\n        >>> ece = evaluator.state.metrics[\"ECE\"]\n        >>> print(f\"Expected Calibration Error: {ece:.4f}\")\n    \"\"\"\n\n    def __init__(\n        self,\n        num_bins: int = 20,\n        include_background: bool = True,\n        calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED,\n        metric_reduction: MetricReduction | str = MetricReduction.MEAN,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n    ) -> None:\n        metric_fn = CalibrationErrorMetric(\n            num_bins=num_bins,\n            include_background=include_background,\n            calibration_reduction=calibration_reduction,\n            metric_reduction=metric_reduction,\n        )\n\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n"
  },
  {
    "path": "monai/handlers/checkpoint_loader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport warnings\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom monai.networks.utils import copy_model_state\nfrom monai.utils import IgniteInfo, min_version, optional_import\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nCheckpoint, _ = optional_import(\"ignite.handlers\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Checkpoint\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n\n\nclass CheckpointLoader:\n    \"\"\"\n    CheckpointLoader acts as an Ignite handler to load checkpoint data from file.\n    It can load variables for network, optimizer, lr_scheduler, etc.\n    If saving checkpoint after `torch.nn.DataParallel`, need to save `model.module` instead\n    as PyTorch recommended and then use this loader to load the model.\n\n    Usage example::\n\n        trainer = SupervisedTrainer(...)\n        save_dict = {\n            \"trainer\": trainer,\n            \"net\": network,\n            \"opt\": optimizer,\n            \"lr\": lr_scheduler,\n        }\n\n        map_location = \"cuda:0\"\n        # checkpoint needs to have same save_dict for this to work\n        handler = 
CheckpointLoader(load_path=\"/test/checkpoint.pt\", load_dict=save_dict, map_location=map_location, strict=True)\n        handler(trainer)\n        # Trainer now has the same state as stored, including the number of epochs and iterations completed\n        # so you can resume an interrupted training at the place where it left\n\n    Args:\n        load_path: the file path of checkpoint, it should be a PyTorch `pth` file.\n        load_dict: target objects that load checkpoint to. examples::\n\n            {'network': net, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}\n\n        name: identifier of logging.logger to use, if None, defaulting to ``engine.logger``.\n        map_location: when loading the module for distributed training/evaluation,\n            need to provide an appropriate map_location argument to prevent a process\n            to step into others’ devices. If map_location is missing, torch.load will\n            first load the module to CPU and then copy each parameter to where it was\n            saved, which would result in all processes on the same machine using the\n            same set of devices.\n        strict: whether to strictly enforce that the keys and data shape in the `state_dict` of every item\n            of `load_dict` match the `state_dict` of the corresponding items of checkpoint, default to `True`.\n        strict_shape: whether to enforce the data shape of the matched layers in the checkpoint,\n            if `False`, it will skip the layers that have different data shape with checkpoint content,\n            and ignore the `strict` arg. this can be a useful advanced feature for transfer learning.\n            users should totally understand which layers will have different shape. default to `True`.\n\n    Note: if `strict_shape=False`, will only load checkpoint for `torch.nn.Module` and skip other\n        items in the `load_dict`. 
For example, if the shape of some layers in current model can't\n        match the checkpoint, the `parameter_group` of current optimizer may also can't match the\n        checkpoint, so skip loading checkpoint for optimizer.\n\n        For more details about loading checkpoint, please refer to:\n        https://pytorch.org/ignite/v0.4.5/generated/ignite.handlers.checkpoint.Checkpoint.html\n        #ignite.handlers.checkpoint.Checkpoint.load_objects.\n        https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        load_path: str,\n        load_dict: dict,\n        name: str | None = None,\n        map_location: dict | None = None,\n        strict: bool = True,\n        strict_shape: bool = True,\n    ) -> None:\n        if load_path is None:\n            raise AssertionError(\"must provide clear path to load checkpoint.\")\n        self.load_path = load_path\n        if load_dict is None or len(load_dict) <= 0:\n            raise AssertionError(\"must provide target objects to load.\")\n        self.logger = logging.getLogger(name)\n        self.load_dict = load_dict\n        self._name = name\n        self.map_location = map_location\n        if strict and not strict_shape:\n            warnings.warn(\"as `strict_shape` is already False, change `strict` to False.\")\n            strict = False\n        self.strict = strict\n        self.strict_shape = strict_shape\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if self._name is None:\n            self.logger = engine.logger\n        engine.add_event_handler(Events.STARTED, self)\n\n    def __call__(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        checkpoint = 
torch.load(self.load_path, map_location=self.map_location, weights_only=True)\n\n        k, _ = list(self.load_dict.items())[0]\n        # single object and checkpoint is directly a state_dict\n        if len(self.load_dict) == 1 and k not in checkpoint:\n            checkpoint = {k: checkpoint}\n\n        if not self.strict_shape:\n            pop_items: list[str] = []\n            for k, obj in self.load_dict.items():\n                if isinstance(obj, torch.nn.Module):\n                    # skip items that don't match key name or data shape\n                    checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0]\n                else:\n                    warnings.warn(\"`strict_shape` is False, load checkpoint for model, skip others in `load_dict`.\")\n                    pop_items.append(k)\n            for i in pop_items:\n                self.load_dict.pop(i)\n\n        # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint\n        prior_max_epochs = engine.state.max_epochs\n        Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict)\n        if prior_max_epochs is not None and engine.state.epoch > prior_max_epochs:\n            raise ValueError(\n                f\"Epoch count ({engine.state.epoch}) in checkpoint is larger than \"\n                f\"the `engine.state.max_epochs` ({prior_max_epochs}) of engine. To further train from checkpoint, \"\n                \"construct trainer with `max_epochs` larger than checkpoint's epoch count. \"\n                \"To use checkpoint for inference, no need to load state_dict for the engine.\"\n            )\n        engine.state.max_epochs = prior_max_epochs\n\n        self.logger.info(f\"Restored all variables from {self.load_path}\")\n"
  },
  {
    "path": "monai/handlers/checkpoint_saver.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport warnings\nfrom collections.abc import Mapping\nfrom typing import TYPE_CHECKING, Any\n\nfrom monai.utils import IgniteInfo, is_scalar, min_version, optional_import\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\n\nif TYPE_CHECKING:\n    from ignite.engine import Engine\n    from ignite.handlers import Checkpoint, DiskSaver\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n    DiskSaver, _ = optional_import(\"ignite.handlers\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"DiskSaver\")\n    Checkpoint, _ = optional_import(\"ignite.handlers\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Checkpoint\")\n\n\nclass CheckpointSaver:\n    \"\"\"\n    CheckpointSaver acts as an Ignite handler to save checkpoint data into files.\n    It supports to save according to metrics result, epoch number, iteration number\n    and last model or exception.\n\n    Args:\n        save_dir: the target directory to save the checkpoints.\n        save_dict: source objects that save to the checkpoint. 
examples::\n\n            {'network': net, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}\n\n        name: identifier of logging.logger to use, if None, defaulting to ``engine.logger``.\n        file_prefix: prefix for the filenames to which objects will be saved.\n        save_final: whether to save checkpoint or session at final iteration or exception.\n            If checkpoints are to be saved when an exception is raised, put this handler before\n            `StatsHandler` in the handler list, because the logic with Ignite can only trigger\n            the first attached handler for `EXCEPTION_RAISED` event.\n        final_filename: set a fixed filename to save the final model if `save_final=True`.\n            If None, default to `checkpoint_final_iteration=N.pt`.\n        save_key_metric: whether to save checkpoint or session when the value of key_metric is\n            higher than all the previous values during training. keep 4 decimal places of metric,\n            checkpoint name is: {file_prefix}_key_metric=0.XXXX.pt.\n        key_metric_name: the name of key_metric in ignite metrics dictionary.\n            If None, use `engine.state.key_metric_name` instead.\n        key_metric_n_saved: save top N checkpoints or sessions, sorted by the value of key\n            metric in descending order.\n        key_metric_filename: set a fixed filename to set the best metric model, if not None,\n            `key_metric_n_saved` should be 1 and only keep the best metric model.\n        key_metric_save_state: whether to save the tracking list of key metric in the checkpoint file.\n            if `True`, then will save an object in the checkpoint file with key `checkpointer` to be\n            consistent with the `include_self` arg of `Checkpoint` in ignite:\n            https://pytorch.org/ignite/v0.4.5/generated/ignite.handlers.checkpoint.Checkpoint.html.\n            typically, it's used to resume training and compare current metric with previous N values.\n        
key_metric_greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise,\n            save the first equally scored model. default to `False`.\n        key_metric_negative_sign: whether adding a negative sign to the metric score to compare metrics,\n            because for error-like metrics, smaller is better(objects with larger score are retained).\n            default to `False`.\n        epoch_level: save checkpoint during training for every N epochs or every N iterations.\n            `True` is epoch level, `False` is iteration level.\n        save_interval: save checkpoint every N epochs, default is 0 to save no checkpoint.\n        n_saved: save latest N checkpoints of epoch level or iteration level, 'None' is to save all.\n\n    Note:\n        CheckpointHandler can be used during training, validation or evaluation.\n        example of saved files:\n\n            - checkpoint_iteration=400.pt\n            - checkpoint_iteration=800.pt\n            - checkpoint_epoch=1.pt\n            - checkpoint_final_iteration=1000.pt\n            - checkpoint_key_metric=0.9387.pt\n\n    \"\"\"\n\n    def __init__(\n        self,\n        save_dir: str,\n        save_dict: dict,\n        name: str | None = None,\n        file_prefix: str = \"\",\n        save_final: bool = False,\n        final_filename: str | None = None,\n        save_key_metric: bool = False,\n        key_metric_name: str | None = None,\n        key_metric_n_saved: int = 1,\n        key_metric_filename: str | None = None,\n        key_metric_save_state: bool = False,\n        key_metric_greater_or_equal: bool = False,\n        key_metric_negative_sign: bool = False,\n        epoch_level: bool = True,\n        save_interval: int = 0,\n        n_saved: int | None = None,\n    ) -> None:\n        if save_dir is None:\n            raise AssertionError(\"must provide directory to save the checkpoints.\")\n        self.save_dir = save_dir\n        if not (save_dict is not None and 
len(save_dict) > 0):\n            raise AssertionError(\"must provide source objects to save.\")\n        self.save_dict = save_dict\n        self.logger = logging.getLogger(name)\n        self.epoch_level = epoch_level\n        self.save_interval = save_interval\n        self._final_checkpoint: Checkpoint | None = None\n        self._key_metric_checkpoint: Checkpoint | None = None\n        self._interval_checkpoint: Checkpoint | None = None\n        self._name = name\n        self._final_filename = final_filename\n\n        class _DiskSaver(DiskSaver):\n            \"\"\"\n            Enhance the DiskSaver to support fixed filename.\n\n            \"\"\"\n\n            def __init__(self, dirname: str, filename: str | None = None):\n                # set `atomic=False` as `atomic=True` only gives read/write permission to the user who saved the file,\n                # without group/others read permission\n                super().__init__(dirname=dirname, require_empty=False, atomic=False)\n                self.filename = filename\n\n            def __call__(self, checkpoint: Mapping, filename: str, metadata: Mapping | None = None) -> None:\n                if self.filename is not None:\n                    filename = self.filename\n                super().__call__(checkpoint=checkpoint, filename=filename, metadata=metadata)\n\n            def remove(self, filename: str) -> None:\n                if self.filename is not None:\n                    filename = self.filename\n                super().remove(filename=filename)\n\n        if save_final:\n\n            def _final_func(engine: Engine) -> Any:\n                return engine.state.iteration\n\n            self._final_checkpoint = Checkpoint(\n                to_save=self.save_dict,\n                save_handler=_DiskSaver(dirname=self.save_dir, filename=self._final_filename),\n                filename_prefix=file_prefix,\n                score_function=_final_func,\n                
score_name=\"final_iteration\",\n            )\n\n        if save_key_metric:\n\n            def _score_func(engine: Engine) -> Any:\n                if isinstance(key_metric_name, str):\n                    metric_name = key_metric_name\n                elif hasattr(engine.state, \"key_metric_name\"):\n                    metric_name = engine.state.key_metric_name\n                else:\n                    raise ValueError(\n                        f\"Incompatible values: save_key_metric=True and key_metric_name={key_metric_name}.\"\n                    )\n                metric = engine.state.metrics[metric_name]\n                if not is_scalar(metric):\n                    warnings.warn(\n                        \"key metric is not a scalar value, skip metric comparison and don't save a model.\"\n                        \"please use other metrics as key metric, or change the `reduction` mode to 'mean'.\"\n                        f\"got metric: {metric_name}={metric}.\"\n                    )\n                    return -1\n                return (-1 if key_metric_negative_sign else 1) * metric\n\n            if key_metric_filename is not None and key_metric_n_saved > 1:\n                raise ValueError(\"if using fixed filename to save the best metric model, we should only save 1 model.\")\n\n            self._key_metric_checkpoint = Checkpoint(\n                to_save=self.save_dict,\n                save_handler=_DiskSaver(dirname=self.save_dir, filename=key_metric_filename),\n                filename_prefix=file_prefix,\n                score_function=_score_func,\n                score_name=\"key_metric\",\n                n_saved=key_metric_n_saved,\n                include_self=key_metric_save_state,\n                greater_or_equal=key_metric_greater_or_equal,\n            )\n\n        if save_interval > 0:\n\n            def _interval_func(engine: Engine) -> Any:\n                return engine.state.epoch if self.epoch_level else 
engine.state.iteration\n\n            self._interval_checkpoint = Checkpoint(\n                to_save=self.save_dict,\n                save_handler=_DiskSaver(dirname=self.save_dir),\n                filename_prefix=file_prefix,\n                score_function=_interval_func,\n                score_name=\"epoch\" if self.epoch_level else \"iteration\",\n                n_saved=n_saved,\n            )\n\n    def load_state_dict(self, state_dict: dict) -> None:\n        \"\"\"\n        Utility to resume the internal state of key metric tracking list if configured to save\n        checkpoints based on the key metric value.\n        Note to set `key_metric_save_state=True` when saving the previous checkpoint.\n\n        Example::\n\n            CheckpointSaver(\n                ...\n                save_key_metric=True,\n                key_metric_save_state=True,  # config to also save the state of this saver\n            ).attach(engine)\n            engine.run(...)\n\n            # resumed training with a new CheckpointSaver\n            saver = CheckpointSaver(save_key_metric=True, ...)\n            # load the previous key metric tracking list into saver\n            CheckpointLoader(\"/test/model.pt\", {\"checkpointer\": saver}).attach(engine)\n\n        \"\"\"\n        if self._key_metric_checkpoint is not None:\n            self._key_metric_checkpoint.load_state_dict(state_dict)\n        else:\n            warnings.warn(\"no key metric checkpoint saver to resume the key metric tracking list.\")\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if self._name is None:\n            self.logger = engine.logger\n        if self._final_checkpoint is not None:\n            engine.add_event_handler(Events.COMPLETED, self.completed)\n            engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised)\n        if 
self._key_metric_checkpoint is not None:\n            engine.add_event_handler(Events.EPOCH_COMPLETED, self.metrics_completed)\n        if self._interval_checkpoint is not None:\n            if self.epoch_level:\n                engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.save_interval), self.interval_completed)\n            else:\n                engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.save_interval), self.interval_completed)\n\n    def _delete_previous_final_ckpt(self):\n        if self._final_checkpoint is not None:\n            saved = self._final_checkpoint._saved\n            if len(saved) > 0:\n                item = saved.pop(0)\n                self._final_checkpoint.save_handler.remove(item.filename)\n                self.logger.info(f\"Deleted previous saved final checkpoint: {item.filename}\")\n\n    def completed(self, engine: Engine) -> None:\n        \"\"\"Callback for train or validation/evaluation completed Event.\n        Save final checkpoint if configure save_final is True.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if not callable(self._final_checkpoint):\n            raise AssertionError(\"Error: _final_checkpoint function not specified.\")\n        # delete previous saved final checkpoint if existing\n        self._delete_previous_final_ckpt()\n        self._final_checkpoint(engine)\n        if self.logger is None:\n            raise AssertionError\n        if not hasattr(self.logger, \"info\"):\n            raise AssertionError(\"Error, provided logger has not info attribute.\")\n        if self._final_filename is not None:\n            _final_checkpoint_path = os.path.join(self.save_dir, self._final_filename)\n        else:\n            _final_checkpoint_path = self._final_checkpoint.last_checkpoint  # type: ignore[assignment]\n        self.logger.info(f\"Train completed, saved final checkpoint: 
{_final_checkpoint_path}\")\n\n    def exception_raised(self, engine: Engine, e: Exception) -> None:\n        \"\"\"Callback for train or validation/evaluation exception raised Event.\n        Save current data as final checkpoint if configure save_final is True. This callback may be skipped\n        because the logic with Ignite can only trigger the first attached handler for `EXCEPTION_RAISED` event.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n            e: the exception caught in Ignite during engine.run().\n        \"\"\"\n        if not callable(self._final_checkpoint):\n            raise AssertionError(\"Error: _final_checkpoint function not specified.\")\n        # delete previous saved final checkpoint if existing\n        self._delete_previous_final_ckpt()\n        self._final_checkpoint(engine)\n        if self.logger is None:\n            raise AssertionError\n        if not hasattr(self.logger, \"info\"):\n            raise AssertionError(\"Error, provided logger has not info attribute.\")\n        if self._final_filename is not None:\n            _final_checkpoint_path = os.path.join(self.save_dir, self._final_filename)\n        else:\n            _final_checkpoint_path = self._final_checkpoint.last_checkpoint  # type: ignore[assignment]\n        self.logger.info(f\"Exception raised, saved the last checkpoint: {_final_checkpoint_path}\")\n        raise e\n\n    def metrics_completed(self, engine: Engine) -> None:\n        \"\"\"Callback to compare metrics and save models in train or validation when epoch completed.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if not callable(self._key_metric_checkpoint):\n            raise AssertionError(\"Error: _key_metric_checkpoint function not specified.\")\n        self._key_metric_checkpoint(engine)\n\n    def interval_completed(self, engine: Engine) -> None:\n        
\"\"\"Callback for train epoch/iteration completed Event.\n        Save checkpoint if configure save_interval = N\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if not callable(self._interval_checkpoint):\n            raise AssertionError(\"Error: _interval_checkpoint function not specified.\")\n        self._interval_checkpoint(engine)\n        if self.logger is None:\n            raise AssertionError\n        if not hasattr(self.logger, \"info\"):\n            raise AssertionError(\"Error, provided logger has not info attribute.\")\n        if self.epoch_level:\n            self.logger.info(f\"Saved checkpoint at epoch: {engine.state.epoch}\")\n        else:\n            self.logger.info(f\"Saved checkpoint at iteration: {engine.state.iteration}\")\n"
  },
  {
    "path": "monai/handlers/classification_saver.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport warnings\nfrom collections.abc import Callable\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom monai.data import CSVSaver, decollate_batch\nfrom monai.utils import IgniteInfo\nfrom monai.utils import ImageMetaKey as Key\nfrom monai.utils import evenly_divisible_all_gather, min_version, optional_import, string_list_all_gather\n\nidist, _ = optional_import(\"ignite\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"distributed\")\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n\n\nclass ClassificationSaver:\n    \"\"\"\n    Event handler triggered on completing every iteration to save the classification predictions as CSV file.\n    If running in distributed data parallel, only saves CSV file in the specified rank.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        output_dir: str = \"./\",\n        filename: str = \"predictions.csv\",\n        delimiter: str = \",\",\n        overwrite: bool = True,\n        batch_transform: Callable = lambda x: x,\n        output_transform: Callable = lambda x: x,\n        name: str | None = None,\n        save_rank: int = 0,\n        
saver: CSVSaver | None = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            output_dir: if `saver=None`, output CSV file directory.\n            filename: if `saver=None`, name of the saved CSV file.\n            delimiter: the delimiter character in the saved file, default to \",\" as the default output type is `csv`.\n                to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.\n            overwrite: if `saver=None`, whether to overwrite existing file content, if True,\n                will clear the file before saving. otherwise, will append new content to the file.\n            batch_transform: a callable that is used to extract the `meta_data` dictionary of\n                the input images from `ignite.engine.state.batch`. the purpose is to get the input\n                filenames from the `meta_data` and store with classification results together.\n                `engine.state` and `batch_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            output_transform: a callable that is used to extract the model prediction data from\n                `ignite.engine.state.output`. the first dimension of its output will be treated as\n                the batch dimension. 
each item in the batch will be saved individually.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            name: identifier of logging.logger to use, defaulting to `engine.logger`.\n            save_rank: only the handler on specified rank will save to CSV file in multi-gpus validation,\n                default to 0.\n            saver: the saver instance to save classification results, if None, create a CSVSaver internally.\n                the saver must provide `save_batch(batch_data, meta_data)` and `finalize()` APIs.\n\n        \"\"\"\n        self.save_rank = save_rank\n        self.output_dir = output_dir\n        self.filename = filename\n        self.delimiter = delimiter\n        self.overwrite = overwrite\n        self.batch_transform = batch_transform\n        self.output_transform = output_transform\n        self.saver = saver\n\n        self.logger = logging.getLogger(name)\n        self._name = name\n        self._outputs: list[torch.Tensor] = []\n        self._filenames: list[str] = []\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if self._name is None:\n            self.logger = engine.logger\n        if not engine.has_event_handler(self._started, Events.EPOCH_STARTED):\n            engine.add_event_handler(Events.EPOCH_STARTED, self._started)\n        if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):\n            engine.add_event_handler(Events.ITERATION_COMPLETED, self)\n        if not engine.has_event_handler(self._finalize, Events.EPOCH_COMPLETED):\n            engine.add_event_handler(Events.EPOCH_COMPLETED, 
self._finalize)\n\n    def _started(self, _engine: Engine) -> None:\n        \"\"\"\n        Initialize internal buffers.\n\n        Args:\n            _engine: Ignite Engine, unused argument.\n\n        \"\"\"\n        self._outputs = []\n        self._filenames = []\n\n    def __call__(self, engine: Engine) -> None:\n        \"\"\"\n        This method assumes self.batch_transform will extract metadata from the input batch.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        meta_data = self.batch_transform(engine.state.batch)\n        if isinstance(meta_data, dict):\n            # decollate the `dictionary of list` to `list of dictionaries`\n            meta_data = decollate_batch(meta_data)\n        engine_output = self.output_transform(engine.state.output)\n        for m, o in zip(meta_data, engine_output):\n            self._filenames.append(f\"{m.get(Key.FILENAME_OR_OBJ)}\")\n            if isinstance(o, torch.Tensor):\n                o = o.detach()\n            self._outputs.append(o)\n\n    def _finalize(self, _engine: Engine) -> None:\n        \"\"\"\n        All gather classification results from ranks and save to CSV file.\n\n        Args:\n            _engine: Ignite Engine, unused argument.\n        \"\"\"\n        ws = idist.get_world_size()\n        if self.save_rank >= ws:\n            raise ValueError(\"target save rank is greater than the distributed group size.\")\n\n        outputs = torch.stack(self._outputs, dim=0)\n        filenames = self._filenames\n        if ws > 1:\n            outputs = evenly_divisible_all_gather(outputs, concat=True)\n            filenames = string_list_all_gather(filenames)\n\n        if len(filenames) == 0:\n            meta_dict = None\n        else:\n            if len(filenames) != len(outputs):\n                warnings.warn(f\"filenames length: {len(filenames)} doesn't match outputs length: {len(outputs)}.\")\n            meta_dict = 
{Key.FILENAME_OR_OBJ: filenames}\n\n        # save to CSV file only in the expected rank\n        if idist.get_rank() == self.save_rank:\n            saver = self.saver or CSVSaver(\n                output_dir=self.output_dir, filename=self.filename, overwrite=self.overwrite, delimiter=self.delimiter\n            )\n            saver.save_batch(outputs, meta_dict)\n            saver.finalize()\n"
  },
  {
    "path": "monai/handlers/clearml_handlers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Mapping, Sequence\nfrom typing import TYPE_CHECKING, Any\n\nfrom monai.utils import optional_import\n\nfrom .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler\n\n\nclass ClearMLHandler:\n    \"\"\"\n    Base class for the handlers to log everything to ClearML.\n    For more details of ClearML usage, please refer to:\n    https://clear.ml/docs/latest/docs/references/sdk/task\n\n    Usage example is available in the tutorial:\n    https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        project_name: str | None,\n        task_name: str | None,\n        output_uri: str | bool,\n        tags: Sequence[str] | None,\n        reuse_last_task_id: bool,\n        continue_last_task: bool,\n        auto_connect_frameworks: bool | Mapping[str, bool | str | list],\n        auto_connect_arg_parser: bool | Mapping[str, bool],\n    ) -> None:\n        \"\"\"\n        Args:\n            project_name: ClearML project name, default to 'MONAI'.\n            task_name: ClearML task name, default to 'monai_experiment'.\n            output_uri: The default location for output models and other artifacts, default to 'True'.\n            tags: Add a list of tags (str) to the created Task, default to 
'None'.\n            reuse_last_task_id: Force a new Task (experiment) with a previously used Task ID, default to 'True'.\n            continue_last_task: Continue the execution of a previously executed Task (experiment), default to 'False'.\n            auto_connect_frameworks: Automatically connect frameworks, default to 'True'.\n            auto_connect_arg_parser: Automatically connect an argparse object to the Task, default to 'True'.\n\n        \"\"\"\n\n        if TYPE_CHECKING:\n            import clearml\n        else:\n            clearml, _ = optional_import(\"clearml\")\n\n        # Always check whether the user has already called `Task.init` beforehand;\n        # if so, use that task, otherwise create a new one.\n        if clearml.Task.current_task():\n            self.clearml_task = clearml.Task.current_task()\n        else:\n            self.clearml_task = clearml.Task.init(\n                project_name=project_name,\n                task_name=task_name,\n                output_uri=output_uri,\n                tags=tags,\n                reuse_last_task_id=reuse_last_task_id,\n                continue_last_task=continue_last_task,\n                auto_connect_frameworks=auto_connect_frameworks,\n                auto_connect_arg_parser=auto_connect_arg_parser,\n            )\n\n\nclass ClearMLStatsHandler(ClearMLHandler, TensorBoardStatsHandler):\n    \"\"\"\n\n    Class to write tensorboard stats by inheriting TensorBoardStatsHandler class.\n    Everything from Tensorboard is logged automatically to ClearML.\n\n    Usage example is available in the tutorial:\n    https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        project_name: str | None = \"MONAI\",\n        task_name: str | None = \"monai_experiment\",\n        output_uri: str | bool = True,\n        tags: Sequence[str] | None = None,\n        reuse_last_task_id: bool = True,\n        
continue_last_task: bool = False,\n        auto_connect_frameworks: bool | Mapping[str, bool | str | list] = True,\n        auto_connect_arg_parser: bool | Mapping[str, bool] = True,\n        *args: Any,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"\n        Args:\n            project_name: ClearML project name, default to 'MONAI'.\n            task_name: ClearML task name, default to 'monai_experiment'.\n            output_uri: The default location for output models and other artifacts, default to 'True'.\n            tags: Add a list of tags (str) to the created Task, default to 'None'.\n            reuse_last_task_id: Force a new Task (experiment) with a previously used Task ID, default to 'True'.\n            continue_last_task: Continue the execution of a previously executed Task (experiment), default to 'False'.\n            auto_connect_frameworks: Automatically connect frameworks, default to 'True'.\n            auto_connect_arg_parser: Automatically connect an argparse object to the Task, default to 'True'.\n\n        \"\"\"\n\n        ClearMLHandler.__init__(\n            self,\n            project_name=project_name,\n            task_name=task_name,\n            output_uri=output_uri,\n            tags=tags,\n            reuse_last_task_id=reuse_last_task_id,\n            continue_last_task=continue_last_task,\n            auto_connect_frameworks=auto_connect_frameworks,\n            auto_connect_arg_parser=auto_connect_arg_parser,\n        )\n        TensorBoardStatsHandler.__init__(self, *args, **kwargs)\n\n\nclass ClearMLImageHandler(ClearMLHandler, TensorBoardImageHandler):\n    \"\"\"\n\n    This class inherits all functionality from TensorBoardImageHandler class.\n    Everything from Tensorboard is logged automatically to ClearML.\n\n    Usage example is available in the tutorial:\n    https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb.\n\n    \"\"\"\n\n    def __init__(\n        
self,\n        project_name: str | None = \"MONAI\",\n        task_name: str | None = \"monai_experiment\",\n        output_uri: str | bool = True,\n        tags: Sequence[str] | None = None,\n        reuse_last_task_id: bool = True,\n        continue_last_task: bool = False,\n        auto_connect_frameworks: bool | Mapping[str, bool | str | list] = True,\n        auto_connect_arg_parser: bool | Mapping[str, bool] = True,\n        *args: Any,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"\n        Args:\n            project_name: ClearML project name, default to 'MONAI'.\n            task_name: ClearML task name, default to 'monai_experiment'.\n            output_uri: The default location for output models and other artifacts, default to 'True'.\n            tags: Add a list of tags (str) to the created Task, default to 'None'.\n            reuse_last_task_id: Force a new Task (experiment) with a previously used Task ID, default to 'True'.\n            continue_last_task: Continue the execution of a previously executed Task (experiment), default to 'False'.\n            auto_connect_frameworks: Automatically connect frameworks, default to 'True'.\n            auto_connect_arg_parser: Automatically connect an argparse object to the Task, default to 'True'.\n\n        \"\"\"\n\n        ClearMLHandler.__init__(\n            self,\n            project_name=project_name,\n            task_name=task_name,\n            output_uri=output_uri,\n            tags=tags,\n            reuse_last_task_id=reuse_last_task_id,\n            continue_last_task=continue_last_task,\n            auto_connect_frameworks=auto_connect_frameworks,\n            auto_connect_arg_parser=auto_connect_arg_parser,\n        )\n\n        TensorBoardImageHandler.__init__(self, *args, **kwargs)\n"
  },
  {
    "path": "monai/handlers/confusion_matrix.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nfrom monai.handlers.ignite_metric import IgniteMetricHandler\nfrom monai.metrics import ConfusionMatrixMetric\nfrom monai.utils.enums import MetricReduction\n\n\nclass ConfusionMatrix(IgniteMetricHandler):\n    \"\"\"\n    Compute confusion matrix related metrics from full size Tensor and collects average over batch, class-channels, iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        metric_name: str = \"hit_rate\",\n        compute_sample: bool = False,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            include_background: whether to include metric computation on the first channel of\n                the predicted output. 
Defaults to True.\n            metric_name: [``\"sensitivity\"``, ``\"specificity\"``, ``\"precision\"``, ``\"negative predictive value\"``,\n                ``\"miss rate\"``, ``\"fall out\"``, ``\"false discovery rate\"``, ``\"false omission rate\"``,\n                ``\"prevalence threshold\"``, ``\"threat score\"``, ``\"accuracy\"``, ``\"balanced accuracy\"``,\n                ``\"f1 score\"``, ``\"matthews correlation coefficient\"``, ``\"fowlkes mallows index\"``,\n                ``\"informedness\"``, ``\"markedness\"``]\n                Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned),\n                and you can also input those names instead.\n            compute_sample: when reducing, if ``True``, each sample's metric will be computed based on each confusion matrix first.\n                if ``False``, compute reduction on the confusion matrices first, defaults to ``False``.\n            reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n            output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n                construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n                lists of `channel-first` Tensors. 
the form of `(y_pred, y)` is required by the `update()`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image.\n                default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n\n        See also:\n            :py:meth:`monai.metrics.confusion_matrix`\n        \"\"\"\n        metric_fn = ConfusionMatrixMetric(\n            include_background=include_background,\n            metric_name=metric_name,\n            compute_sample=compute_sample,\n            reduction=reduction,\n        )\n        self.metric_name = metric_name\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n"
  },
  {
    "path": "monai/handlers/decollate_batch.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom monai.config import KeysCollection\nfrom monai.engines.utils import IterationEvents\nfrom monai.transforms import Decollated\nfrom monai.utils import IgniteInfo, min_version, optional_import\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n\n\nclass DecollateBatch:\n    \"\"\"\n    Ignite handler to execute the `decollate batch` logic for `engine.state.batch` and `engine.state.output`.\n    Typical usage is to set `decollate=False` in the engine and execute some postprocessing logic first\n    then decollate the batch, otherwise, engine will decollate batch before the postprocessing.\n\n    Args:\n        event: expected EVENT to attach the handler, should be \"MODEL_COMPLETED\" or \"ITERATION_COMPLETED\".\n            default to \"MODEL_COMPLETED\".\n        detach: whether to detach the tensors. 
scalars tensors will be detached into number types\n            instead of torch tensors.\n        decollate_batch: whether to decollate `engine.state.batch` of ignite engine.\n        batch_keys: if `decollate_batch=True`, specify the keys of the corresponding items to decollate\n            in `engine.state.batch`, note that it will delete other keys not specified. if None,\n            will decollate all the keys. it replicates the scalar values to every item of the decollated list.\n        decollate_output: whether to decollate `engine.state.output` of ignite engine.\n        output_keys: if `decollate_output=True`, specify the keys of the corresponding items to decollate\n            in `engine.state.output`, note that it will delete other keys not specified. if None,\n            will decollate all the keys. it replicates the scalar values to every item of the decollated list.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        event: str = \"MODEL_COMPLETED\",\n        detach: bool = True,\n        decollate_batch: bool = True,\n        batch_keys: KeysCollection | None = None,\n        decollate_output: bool = True,\n        output_keys: KeysCollection | None = None,\n        allow_missing_keys: bool = False,\n    ):\n        event = event.upper()\n        if event not in (\"MODEL_COMPLETED\", \"ITERATION_COMPLETED\"):\n            raise ValueError(\"event should be `MODEL_COMPLETED` or `ITERATION_COMPLETED`.\")\n        self.event = event\n\n        self.batch_transform = (\n            Decollated(keys=batch_keys, detach=detach, allow_missing_keys=allow_missing_keys)\n            if decollate_batch\n            else None\n        )\n\n        self.output_transform = (\n            Decollated(keys=output_keys, detach=detach, allow_missing_keys=allow_missing_keys)\n            if decollate_output\n            else None\n        )\n\n    def attach(self, engine: Engine) -> None:\n   
     \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if self.event == \"MODEL_COMPLETED\":\n            engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self)\n        else:\n            engine.add_event_handler(Events.ITERATION_COMPLETED, self)\n\n    def __call__(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if self.batch_transform is not None and isinstance(engine.state.batch, (list, dict)):\n            engine.state.batch = self.batch_transform(engine.state.batch)\n        if self.output_transform is not None and isinstance(engine.state.output, (list, dict)):\n            engine.state.output = self.output_transform(engine.state.output)\n"
  },
  {
    "path": "monai/handlers/earlystop_handler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\nfrom typing import TYPE_CHECKING\n\nfrom monai.utils import IgniteInfo, min_version, optional_import\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nEarlyStopping, _ = optional_import(\"ignite.handlers\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"EarlyStopping\")\n\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\n        \"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\", as_type=\"decorator\"\n    )\n\n\nclass EarlyStopHandler:\n    \"\"\"\n    EarlyStopHandler acts as an Ignite handler to stop training if no improvement after a given number of events.\n    It‘s based on the `EarlyStopping` handler in ignite.\n\n    Args:\n        patience: number of events to wait if no improvement and then stop the training.\n        score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine`\n            object that the handler attached, can be a trainer or validator, and return a score `float`.\n            an improvement is considered if the score is higher.\n        trainer: trainer engine to stop the run if no improvement, if None, must call `set_trainer()` before training.\n        min_delta: a minimum increase in the score 
to qualify as an improvement,\n            i.e. an increase of less than or equal to `min_delta`, will count as no improvement.\n        cumulative_delta: if True, `min_delta` defines an increase since the last `patience` reset, otherwise,\n            it defines an increase after the last event, default to False.\n        epoch_level: check early stopping for every epoch or every iteration of the attached engine,\n            `True` is epoch level, `False` is iteration level, default to epoch level.\n\n    Note:\n        If in distributed training and uses loss value of every iteration to detect early stopping,\n        the values may be different in different ranks. When using this handler with distributed training,\n        please also note that to prevent \"dist.destroy_process_group()\" hangs, you can use an \"all_reduce\" operation\n        to synchronize the stop signal across all ranks. The mechanism can be implemented in the `score_function`. The following\n        is an example:\n\n        .. 
code-block:: python\n\n            import os\n\n            import torch\n            import torch.distributed as dist\n\n\n            def score_function(engine):\n                val_metric = engine.state.metrics[\"val_mean_dice\"]\n                if dist.is_initialized():\n                    device = torch.device(\"cuda:\" + os.environ[\"LOCAL_RANK\"])\n                    val_metric = torch.tensor([val_metric]).to(device)\n                    dist.all_reduce(val_metric, op=dist.ReduceOp.SUM)\n                    val_metric /= dist.get_world_size()\n                    return val_metric.item()\n                return val_metric\n\n\n        User may attach this handler to validator engine to detect validation metrics and stop the training,\n        in this case, the `score_function` is executed on validator engine and `trainer` is the trainer engine.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        patience: int,\n        score_function: Callable,\n        trainer: Engine | None = None,\n        min_delta: float = 0.0,\n        cumulative_delta: bool = False,\n        epoch_level: bool = True,\n    ) -> None:\n        self.patience = patience\n        self.score_function = score_function\n        self.min_delta = min_delta\n        self.cumulative_delta = cumulative_delta\n        self.epoch_level = epoch_level\n        self._handler = None\n\n        if trainer is not None:\n            self.set_trainer(trainer=trainer)\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if self.epoch_level:\n            engine.add_event_handler(Events.EPOCH_COMPLETED, self)\n        else:\n            engine.add_event_handler(Events.ITERATION_COMPLETED, self)\n\n    def set_trainer(self, trainer: Engine) -> None:\n        \"\"\"\n        Set trainer to execute early stop if not setting properly in `__init__()`.\n        \"\"\"\n      
  self._handler = EarlyStopping(\n            patience=self.patience,\n            score_function=self.score_function,\n            trainer=trainer,\n            min_delta=self.min_delta,\n            cumulative_delta=self.cumulative_delta,\n        )\n\n    def __call__(self, engine: Engine) -> None:\n        if self._handler is None:\n            raise RuntimeError(\"please set trainer in __init__() or call set_trainer() before training.\")\n        self._handler(engine)\n"
  },
  {
    "path": "monai/handlers/garbage_collector.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport gc\nfrom typing import TYPE_CHECKING\n\nfrom monai.utils import IgniteInfo, min_version, optional_import\n\nif TYPE_CHECKING:\n    from ignite.engine import Engine, Events\n    from ignite.engine.events import CallableEventWithFilter\nelse:\n    CallableEventWithFilter, _ = optional_import(\n        \"ignite.engine.events\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"CallableEventWithFilter\"\n    )\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n    Events, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\n\n\nclass GarbageCollector:\n    \"\"\"\n    Run garbage collector after each epoch\n\n    Args:\n        trigger_event: the event that trigger a call to this handler.\n            - \"epoch\", after completion of each epoch (equivalent of ignite.engine.Events.EPOCH_COMPLETED)\n            - \"iteration\", after completion of each iteration (equivalent of ignite.engine.Events.ITERATION_COMPLETED)\n            - any ignite built-in event from ignite.engine.Events.\n            Defaults to \"epoch\".\n        log_level: log level (integer) for some garbage collection information as below. 
Defaults to 10 (DEBUG).\n            - 50 (CRITICAL)\n            - 40 (ERROR)\n            - 30 (WARNING)\n            - 20 (INFO)\n            - 10 (DEBUG)\n            - 0 (NOTSET)\n    \"\"\"\n\n    def __init__(self, trigger_event: str | Events | CallableEventWithFilter = \"epoch\", log_level: int = 10):\n        self.trigger_event: Events | CallableEventWithFilter\n        if isinstance(trigger_event, (Events, CallableEventWithFilter)):\n            self.trigger_event = trigger_event\n        elif trigger_event.lower() == \"epoch\":\n            self.trigger_event = Events.EPOCH_COMPLETED\n        elif trigger_event.lower() == \"iteration\":\n            self.trigger_event = Events.ITERATION_COMPLETED\n        else:\n            raise ValueError(\n                f\"'trigger_event' should be either epoch, iteration, or an ignite built-in event from\"\n                f\" ignite.engine.Events, '{trigger_event}' was given.\"\n            )\n\n        self.log_level = log_level\n\n    def attach(self, engine: Engine) -> None:\n        if not engine.has_event_handler(self, self.trigger_event):\n            engine.add_event_handler(self.trigger_event, self)\n\n    def __call__(self, engine: Engine) -> None:\n        \"\"\"\n        This method calls python garbage collector.\n\n        Args:\n            engine: Ignite Engine, it should be either a trainer or validator.\n        \"\"\"\n        # get count before garbage collection\n        pre_count = gc.get_count()\n        # first call to garbage collector\n        gc.collect()\n        # second call to garbage collector\n        unreachable = gc.collect()\n        # get count after garbage collection\n        after_count = gc.get_count()\n        engine.logger.log(\n            self.log_level,\n            f\"Garbage Count: [before: {pre_count}] -> [after: {after_count}] (unreachable : {unreachable})\",\n        )\n"
  },
  {
    "path": "monai/handlers/hausdorff_distance.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nfrom monai.handlers.ignite_metric import IgniteMetricHandler\nfrom monai.metrics import HausdorffDistanceMetric\nfrom monai.utils import MetricReduction\n\n\nclass HausdorffDistance(IgniteMetricHandler):\n    \"\"\"\n    Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = False,\n        distance_metric: str = \"euclidean\",\n        percentile: float | None = None,\n        directed: bool = False,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            include_background: whether to include distance computation on the first channel of the predicted output.\n                Defaults to ``False``.\n            distance_metric: : [``\"euclidean\"``, ``\"chessboard\"``, ``\"taxicab\"``]\n                the metric used to compute surface distance. Defaults to ``\"euclidean\"``.\n            percentile: an optional float number between 0 and 100. 
If specified, the corresponding\n                percentile of the Hausdorff Distance rather than the maximum result will be achieved.\n                Defaults to ``None``.\n            directed: whether to calculate directed Hausdorff distance. Defaults to ``False``.\n            reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n            output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n                construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n                lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            save_details: whether to save metric computation details per image, for example: hausdorff distance\n                of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n\n        \"\"\"\n        metric_fn = HausdorffDistanceMetric(\n            include_background=include_background,\n            distance_metric=distance_metric,\n            percentile=percentile,\n            directed=directed,\n            reduction=reduction,\n        )\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n"
  },
  {
    "path": "monai/handlers/ignite_metric.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable, Sequence\nfrom typing import TYPE_CHECKING, Any, cast\n\nimport torch\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.metrics import CumulativeIterationMetric, LossMetric\nfrom monai.utils import IgniteInfo, MetricReduction, min_version, optional_import\n\nidist, _ = optional_import(\"ignite\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"distributed\")\n\nif TYPE_CHECKING:\n    try:\n        _, has_ignite = optional_import(\"ignite\")\n        from ignite.engine import Engine\n        from ignite.metrics import Metric\n        from ignite.metrics.metric import reinit__is_reduced\n    except ImportError:\n        has_ignite = False\n\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n    Metric, _ = optional_import(\"ignite.metrics\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Metric\", as_type=\"base\")\n    reinit__is_reduced, _ = optional_import(\n        \"ignite.metrics.metric\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"reinit__is_reduced\", as_type=\"decorator\"\n    )\n\n\nclass IgniteMetricHandler(Metric):\n    \"\"\"\n    Base Metric class based on ignite event handler mechanism.\n    The input `prediction` or `label` data can be a PyTorch Tensor or numpy array with batch dim and channel 
dim,\n    or a list of PyTorch Tensor or numpy array without batch dim.\n\n    Args:\n        metric_fn: callable function or class to compute raw metric results after every iteration.\n            expect to return a Tensor with shape (batch, channel, ...) or tuple (Tensor, not_nans).\n        loss_fn: A torch _Loss function which is used to generate the LossMetric\n        output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n            construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n            lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.\n            `engine.state` and `output_transform` inherit from the ignite concept:\n            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n        save_details: whether to save metric computation details per image, for example: mean_dice of every image.\n            default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n        reduction: Argument for the LossMetric, look there for details\n        get_not_nans: Argument for the LossMetric, look there for details\n\n    \"\"\"\n\n    def __init__(\n        self,\n        metric_fn: CumulativeIterationMetric | None = None,\n        loss_fn: _Loss | None = None,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n    ) -> None:\n        self._is_reduced: bool = False\n        self.metric_fn: CumulativeIterationMetric = cast(CumulativeIterationMetric, metric_fn)\n        self.loss_fn = loss_fn\n        self.save_details = save_details\n        self._scores: list = []\n        self._engine: Engine | None = None\n        
self._name: str | None = None\n\n        if self.metric_fn is None and self.loss_fn is None:\n            raise ValueError(\"Either metric_fn or loss_fn have to be passed.\")\n        if self.metric_fn is not None and self.loss_fn is not None:\n            raise ValueError(\"Either metric_fn or loss_fn have to be passed, but not both.\")\n        if self.loss_fn:\n            self.metric_fn = LossMetric(loss_fn=self.loss_fn, reduction=reduction, get_not_nans=get_not_nans)\n\n        super().__init__(output_transform)\n\n    @reinit__is_reduced\n    def reset(self) -> None:\n        self.metric_fn.reset()\n\n    @reinit__is_reduced\n    def update(self, output: Sequence[torch.Tensor]) -> None:\n        \"\"\"\n        Args:\n            output: sequence with contents [y_pred, y].\n\n        Raises:\n            ValueError: When ``output`` length is not 2. metric_fn can only support y_pred and y.\n\n        \"\"\"\n        if len(output) != 2:\n            raise ValueError(f\"output must have length 2, got {len(output)}.\")\n\n        y_pred, y = output\n\n        self.metric_fn(y_pred, y)\n\n    def compute(self) -> Any:\n        \"\"\"\n        Raises:\n            NotComputableError: When ``compute`` is called before an ``update`` occurs.\n\n        \"\"\"\n        result = self.metric_fn.aggregate()\n        if isinstance(result, (tuple, list)):\n            if len(result) > 1:\n                warnings.warn(\"metric handler can only record the first value of result list.\")\n            result = result[0]\n\n        self._is_reduced = True\n\n        # save score of every image into engine.state for other components\n        if self.save_details:\n            if self._engine is None or self._name is None:\n                raise RuntimeError(\"please call the attach() function to connect expected engine first.\")\n            self._engine.state.metric_details[self._name] = self.metric_fn.get_buffer()  # type: ignore\n\n        if isinstance(result, 
torch.Tensor):\n            result = result.squeeze()\n            if result.ndim == 0:\n                result = result.item()\n        return result\n\n    def attach(self, engine: Engine, name: str) -> None:  # type: ignore[override]\n        \"\"\"\n        Attaches the current metric to the provided engine. At the end of the engine's run,\n        the `engine.state.metrics` dictionary will contain the computed metric's value under the provided name.\n\n        Args:\n            engine: the engine to which the metric must be attached.\n            name: the name of the metric to attach.\n\n        \"\"\"\n        super().attach(engine=engine, name=name)\n        # FIXME: record engine for communication, ignite will support it in the future version soon\n        self._engine = engine\n        self._name = name\n        if self.save_details and not hasattr(engine.state, \"metric_details\"):\n            engine.state.metric_details = {}  # type: ignore\n"
  },
  {
    "path": "monai/handlers/logfile_handler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nfrom typing import TYPE_CHECKING\n\nfrom monai.utils import IgniteInfo, min_version, optional_import\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n\n__all__ = [\"LogfileHandler\"]\n\n\nclass LogfileHandler:\n    \"\"\"\n    Adds a `logging.FileHandler` to the attached engine's logger when the start event occurs and removes it again when\n    the completed event occurs.\n\n    A handler is needed to remove `FileHandler` object when the complete event occurs so that further runs of different\n    engines write only to the log files they should, rather than previous files. Multiple handlers can write to the same\n    file which allows output from train and evaluation engine objects to be condensed in one file. If the given output\n    directory doesn't exist it will by default be created when the start event occurs. This can be used in conjunction\n    with `CheckpointSaver` to save a log file to the same destination as the saved checkpoints. 
Since the handler is\n    added possibly after other logging events during initialisation, not all logging data will be retained.\n\n    Args:\n        output_dir: directory to save the log file to\n        filename: name of the file to save log to\n        loglevel: log level for the handler\n        formatter: format string for the `logging.Formatter` set for the handler\n        create_dir: if True, create `output_dir` if it doesn't exist\n    \"\"\"\n\n    def __init__(\n        self,\n        output_dir: str,\n        filename: str = \"log.txt\",\n        loglevel: int = logging.INFO,\n        formatter: str = \"%(asctime)s %(name)s %(levelname)s: %(message)s\",\n        create_dir: bool = True,\n    ):\n        self.output_dir: str = output_dir\n        self.filename: str = filename\n        self.loglevel: int = loglevel\n        self.formatter: str = formatter\n        self.create_dir: bool = create_dir\n        self.logger: logging.Logger | None = None\n        self.handler: logging.FileHandler | None = None\n\n    def attach(self, engine: Engine) -> None:\n        self.logger = engine.logger\n        engine.add_event_handler(Events.STARTED, self._start)\n        engine.add_event_handler(Events.COMPLETED, self._completed)\n\n    def _start(self, engine: Engine) -> None:\n        if self.create_dir and not os.path.exists(self.output_dir):\n            os.makedirs(self.output_dir, exist_ok=True)\n\n        self.handler = logging.FileHandler(os.path.join(self.output_dir, self.filename))\n        self.handler.setLevel(self.loglevel)\n        self.handler.setFormatter(logging.Formatter(self.formatter))\n\n        if self.logger is not None:\n            self.logger.addHandler(self.handler)\n        else:\n            raise AttributeError(\"`self.logger` must not be None in start event\")\n\n    def _completed(self, engine: Engine) -> None:\n        if self.logger is not None and self.handler is not None:\n            self.logger.removeHandler(self.handler)\n     
       self.handler.close()\n        else:\n            raise AttributeError(\"`self.logger` and `self.handler` must not be None in complete event\")\n\n        self.handler = None\n"
  },
  {
    "path": "monai/handlers/lr_schedule_handler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nfrom collections.abc import Callable\nfrom typing import TYPE_CHECKING, Any\n\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler\n\nfrom monai.utils import IgniteInfo, ensure_tuple, min_version, optional_import\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\n        \"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\", as_type=\"decorator\"\n    )\n\n\nclass LrScheduleHandler:\n    \"\"\"\n    Ignite handler to update the Learning Rate based on PyTorch LR scheduler.\n    \"\"\"\n\n    def __init__(\n        self,\n        lr_scheduler: _LRScheduler | ReduceLROnPlateau,\n        print_lr: bool = True,\n        name: str | None = None,\n        epoch_level: bool = True,\n        step_transform: Callable[[Engine], Any] = lambda engine: (),\n    ) -> None:\n        \"\"\"\n        Args:\n            lr_scheduler: typically, lr_scheduler should be PyTorch\n                lr_scheduler object. 
If customized version, must have `step` and `get_last_lr` methods.\n            print_lr: whether to print out the latest learning rate with logging.\n            name: identifier of logging.logger to use, if None, defaulting to ``engine.logger``.\n            epoch_level: execute lr_scheduler.step() after every epoch or every iteration.\n                `True` is epoch level, `False` is iteration level.\n            step_transform: a callable that is used to transform the information from `engine`\n                to expected input data of lr_scheduler.step() function if necessary.\n\n        Raises:\n            TypeError: When ``step_transform`` is not ``callable``.\n\n        \"\"\"\n        self.lr_scheduler = lr_scheduler\n        self.print_lr = print_lr\n        self.logger = logging.getLogger(name)\n        self.epoch_level = epoch_level\n        if not callable(step_transform):\n            raise TypeError(f\"step_transform must be callable but is {type(step_transform).__name__}.\")\n        self.step_transform = step_transform\n\n        self._name = name\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if self._name is None:\n            self.logger = engine.logger\n        if self.epoch_level:\n            engine.add_event_handler(Events.EPOCH_COMPLETED, self)\n        else:\n            engine.add_event_handler(Events.ITERATION_COMPLETED, self)\n\n    def __call__(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        args = ensure_tuple(self.step_transform(engine))\n        self.lr_scheduler.step(*args)\n        if self.print_lr:\n            self.logger.info(f\"Current learning rate: {self.lr_scheduler._last_lr[0]}\")  # type: ignore[union-attr]\n"
  },
  {
    "path": "monai/handlers/mean_dice.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nfrom monai.handlers.ignite_metric import IgniteMetricHandler\nfrom monai.metrics import DiceMetric\nfrom monai.utils import MetricReduction\n\n\nclass MeanDice(IgniteMetricHandler):\n    \"\"\"\n    Computes Dice score metric from full size Tensor and collects average over batch, class-channels, iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        num_classes: int | None = None,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n        return_with_label: bool | list[str] = False,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            include_background: whether to include dice computation on the first channel of the predicted output.\n                Defaults to True.\n            reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n            num_classes: number of input channels (always including the background). 
When this is None,\n                ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are\n                single-channel class indices and the number of classes is not automatically inferred from data.\n            output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n                construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n                lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            save_details: whether to save metric computation details per image, for example: mean dice of every image.\n                default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n            return_with_label: whether to return the metrics with label, only works when reduction is \"mean_batch\".\n                If `True`, use \"label_{index}\" as the key corresponding to C channels; if 'include_background' is True,\n                the index begins at \"0\", otherwise at \"1\". It can also take a list of label names.\n                The outcome will then be returned as a dictionary.\n\n        See also:\n            :py:meth:`monai.metrics.meandice.compute_dice`\n        \"\"\"\n        metric_fn = DiceMetric(\n            include_background=include_background,\n            reduction=reduction,\n            num_classes=num_classes,\n            return_with_label=return_with_label,\n        )\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n"
  },
  {
    "path": "monai/handlers/mean_iou.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nfrom monai.handlers.ignite_metric import IgniteMetricHandler\nfrom monai.metrics import MeanIoU\nfrom monai.utils import MetricReduction\n\n\nclass MeanIoUHandler(IgniteMetricHandler):\n    \"\"\"\n    Computes IoU score metric from full size Tensor and collects average over batch, class-channels, iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            include_background: whether to include iou computation on the first channel of the predicted output.\n                Defaults to True.\n            reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n            output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n                construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n                lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            save_details: whether to save metric computation details per image, for example: mean iou of every image.\n                default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n\n        See also:\n            :py:meth:`monai.metrics.meaniou.compute_iou`\n        \"\"\"\n        metric_fn = MeanIoU(include_background=include_background, reduction=reduction)\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n"
  },
  {
    "path": "monai/handlers/metric_logger.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections import defaultdict\nfrom collections.abc import Callable, Mapping, Sequence\nfrom enum import Enum\nfrom threading import RLock\nfrom typing import TYPE_CHECKING, Any\n\nfrom monai.utils import IgniteInfo, min_version, optional_import\nfrom monai.utils.enums import CommonKeys\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\n        \"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\", as_type=\"decorator\"\n    )\n\n\ndef _get_loss_from_output(output: Sequence[Mapping[str, Any]], loss_key: str = CommonKeys.LOSS) -> Any:\n    return output[0][loss_key]\n\n\nclass MetricLoggerKeys(Enum):\n    METRICS = \"Metrics\"\n    LOSS = \"Loss\"\n\n\nclass MetricLogger:\n    \"\"\"\n    Collect per-iteration metrics and loss value from the attached trainer. This will also collect metric values from\n    a given evaluator object which is expected to perform evaluation at the end of training epochs. 
This class is\n    useful for collecting loss and metric values in one place for storage with checkpoint savers (`state_dict` and\n    `load_state_dict` methods provided as expected by Pytorch and Ignite) and for graphing during training.\n\n    Example::\n        # construct an evaluator saving mean dice metric values in the key \"val_mean_dice\"\n        evaluator = SupervisedEvaluator(..., key_val_metric={\"val_mean_dice\": MeanDice(...)})\n\n        # construct the logger and associate with evaluator to extract metric values from\n        logger = MetricLogger(evaluator=evaluator)\n\n        # construct the trainer with the logger passed in as a handler so that it logs loss values\n        trainer = SupervisedTrainer(..., train_handlers=[logger, ValidationHandler(1, evaluator)])\n\n        # run training, logger.loss will be a list of (iteration, loss) values, logger.metrics a dict with key\n        # \"val_mean_dice\" storing a list of (iteration, metric) values\n        trainer.run()\n\n    Args:\n        loss_transform: Converts the `output` value from the trainer's state into a loss value\n            `engine.state` and `loss_transform` inherit from the ignite concept:\n            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n        metric_transform: Converts the metric value coming from the trainer/evaluator's state into a storable value\n        evaluator: Optional evaluator to consume metric results from at the end of its evaluation run\n    \"\"\"\n\n    def __init__(\n        self,\n        loss_transform: Callable = _get_loss_from_output,\n        metric_transform: Callable = lambda x: x,\n        evaluator: Engine | None = None,\n    ) -> None:\n        self.loss_transform = loss_transform\n        self.metric_transform = metric_transform\n        self.loss: list = []\n        self.metrics: 
defaultdict = defaultdict(list)\n        self.iteration = 0\n        self.lock = RLock()\n\n        if evaluator is not None:\n            self.attach_evaluator(evaluator)\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        engine.add_event_handler(Events.ITERATION_COMPLETED, self)\n\n    def attach_evaluator(self, evaluator: Engine) -> None:\n        \"\"\"\n        Attach event handlers to the given evaluator to log metric values from it.\n\n        Args:\n            evaluator: Ignite Engine implementing network evaluation\n        \"\"\"\n        evaluator.add_event_handler(Events.COMPLETED, self.log_metrics)\n\n    def __call__(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        with self.lock:\n            self.iteration = engine.state.iteration\n            lossval = self.loss_transform(engine.state.output)\n\n            self.loss.append((self.iteration, lossval))\n            self.log_metrics(engine)\n\n    def log_metrics(self, engine: Engine) -> None:\n        \"\"\"\n        Log metrics from the given Engine's state member.\n\n        Args:\n            engine: Ignite Engine to log from\n        \"\"\"\n        with self.lock:\n            for m, v in engine.state.metrics.items():\n                v = self.metric_transform(v)\n                self.metrics[m].append((self.iteration, v))\n\n    def state_dict(self):\n        return {MetricLoggerKeys.LOSS: self.loss, MetricLoggerKeys.METRICS: self.metrics}\n\n    def load_state_dict(self, state_dict):\n        self.loss[:] = state_dict[MetricLoggerKeys.LOSS]\n        self.metrics.clear()\n        self.metrics.update(state_dict[MetricLoggerKeys.METRICS])\n\n\nmetriclogger = MetricLogger\n"
  },
  {
    "path": "monai/handlers/metrics_reloaded_handler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nfrom monai.handlers.ignite_metric import IgniteMetricHandler\nfrom monai.metrics import MetricsReloadedBinary, MetricsReloadedCategorical\nfrom monai.utils.enums import MetricReduction\n\n\nclass MetricsReloadedBinaryHandler(IgniteMetricHandler):\n    \"\"\"\n    Handler of MetricsReloadedBinary, which wraps the binary pairwise metrics of MetricsReloaded.\n    \"\"\"\n\n    def __init__(\n        self,\n        metric_name: str,\n        include_background: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            metric_name: Name of a binary metric from the MetricsReloaded package.\n            include_background: whether to include computation on the first channel of\n                the predicted output. Defaults to ``True``.\n            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n            get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n                Here `not_nans` count the number of not nans for the metric,\n                thus its shape equals to the shape of the metric.\n            output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n                construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n                lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image.\n                default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n\n        See also:\n            :py:meth:`monai.metrics.wrapper`\n        \"\"\"\n        metric_fn = MetricsReloadedBinary(\n            metric_name=metric_name,\n            include_background=include_background,\n            reduction=reduction,\n            get_not_nans=get_not_nans,\n        )\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n\n\nclass MetricsReloadedCategoricalHandler(IgniteMetricHandler):\n    \"\"\"\n    Handler of MetricsReloadedCategorical, which wraps the categorical pairwise metrics of MetricsReloaded.\n    \"\"\"\n\n    def __init__(\n        self,\n        metric_name: str,\n        include_background: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n        smooth_dr: 
float = 1e-5,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            metric_name: Name of a categorical metric from the MetricsReloaded package.\n            include_background: whether to include computation on the first channel of\n                the predicted output. Defaults to ``True``.\n            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n            get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n                Here `not_nans` count the number of not nans for the metric,\n                thus its shape equals to the shape of the metric.\n            smooth_dr: a small constant added to the denominator to avoid nan. OBS: should be greater than zero.\n            output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n                construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n                lists of `channel-first` Tensors. 
the form of `(y_pred, y)` is required by the `update()`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            save_details: whether to save metric computation details per image, for example: TP/TN/FP/FN of every image.\n                default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n\n        See also:\n            :py:meth:`monai.metrics.wrapper`\n        \"\"\"\n        metric_fn = MetricsReloadedCategorical(\n            metric_name=metric_name,\n            include_background=include_background,\n            reduction=reduction,\n            get_not_nans=get_not_nans,\n            smooth_dr=smooth_dr,\n        )\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n"
  },
  {
    "path": "monai/handlers/metrics_saver.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Sequence\nfrom typing import TYPE_CHECKING\n\nfrom monai.data import decollate_batch\nfrom monai.handlers.utils import write_metrics_reports\nfrom monai.utils import IgniteInfo\nfrom monai.utils import ImageMetaKey as Key\nfrom monai.utils import ensure_tuple, min_version, optional_import, string_list_all_gather\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nidist, _ = optional_import(\"ignite\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"distributed\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n\n\nclass MetricsSaver:\n    \"\"\"\n    ignite handler to save metrics values and details into expected files.\n\n    Args:\n        save_dir: directory to save the metrics and metric details.\n        metrics: expected final metrics to save into files, can be: None, \"*\" or list of strings.\n            None - don't save any metrics into files.\n            \"*\" - save all the existing metrics in `engine.state.metrics` dict into separate files.\n            list of strings - specify the expected metrics to save.\n            default to \"*\" to save all the metrics into `metrics.csv`.\n        metric_details: 
expected metric details to save into files, the data comes from\n            `engine.state.metric_details`, which should be provided by different `Metrics`,\n            typically, it's some intermediate values in metric computation.\n            for example: mean dice of every channel of every image in the validation dataset.\n            it must contain at least 2 dims: (batch, classes, ...),\n            if not, will unsqueeze to 2 dims.\n            this arg can be: None, \"*\" or list of strings.\n            None - don't save any metric_details into files.\n            \"*\" - save all the existing metric_details in `engine.state.metric_details` dict into separate files.\n            list of strings - specify the metric_details of expected metrics to save.\n            if not None, every metric_details array will save a separate `{metric name}_raw.csv` file.\n        batch_transform: a callable that is used to extract the `meta_data` dictionary of\n            the input images from `ignite.engine.state.batch` if saving metric details. 
the purpose is to get the\n            input filenames from the `meta_data` and store with metric details together.\n            `engine.state` and `batch_transform` inherit from the ignite concept:\n            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n        summary_ops: expected computation operations to generate the summary report.\n            it can be: None, \"*\" or list of strings, default to None.\n            None - don't generate summary report for every expected metric_details.\n            \"*\" - generate summary report for every metric_details with all the supported operations.\n            list of strings - generate summary report for every metric_details with specified operations, they\n            should be within list: [\"mean\", \"median\", \"max\", \"min\", \"<int>percentile\", \"std\", \"notnans\"].\n            the number in \"<int>percentile\" should be [0, 100], like: \"15percentile\". default: \"90percentile\".\n            for more details, please check: https://numpy.org/doc/stable/reference/generated/numpy.nanpercentile.html.\n            note that: for the overall summary, it computes `nanmean` of all classes for each image first,\n            then compute summary. 
example of the generated summary report::\n\n                class    mean    median    max    5percentile 95percentile  notnans\n                class0  6.0000   6.0000   7.0000   5.1000      6.9000       2.0000\n                class1  6.0000   6.0000   6.0000   6.0000      6.0000       1.0000\n                mean    6.2500   6.2500   7.0000   5.5750      6.9250       2.0000\n\n        save_rank: only the handler on specified rank will save to files in multi-gpus validation, default to 0.\n        delimiter: the delimiter character in the saved file, default to \",\" as the default output type is `csv`.\n            to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.\n        output_type: expected output file type, supported types: [\"csv\"], default to \"csv\".\n\n    \"\"\"\n\n    def __init__(\n        self,\n        save_dir: str,\n        metrics: str | Sequence[str] | None = \"*\",\n        metric_details: str | Sequence[str] | None = None,\n        batch_transform: Callable = lambda x: x,\n        summary_ops: str | Sequence[str] | None = None,\n        save_rank: int = 0,\n        delimiter: str = \",\",\n        output_type: str = \"csv\",\n    ) -> None:\n        self.save_dir = save_dir\n        self.metrics = ensure_tuple(metrics) if metrics is not None else None\n        self.metric_details = ensure_tuple(metric_details) if metric_details is not None else None\n        self.batch_transform = batch_transform\n        self.summary_ops = ensure_tuple(summary_ops) if summary_ops is not None else None\n        self.save_rank = save_rank\n        self.deli = delimiter\n        self.output_type = output_type\n        self._filenames: list[str] = []\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        engine.add_event_handler(Events.EPOCH_STARTED, self._started)\n        
engine.add_event_handler(Events.ITERATION_COMPLETED, self._get_filenames)\n        engine.add_event_handler(Events.EPOCH_COMPLETED, self)\n\n    def _started(self, _engine: Engine) -> None:\n        \"\"\"\n        Initialize internal buffers.\n\n        Args:\n            _engine: Ignite Engine, unused argument.\n\n        \"\"\"\n        self._filenames = []\n\n    def _get_filenames(self, engine: Engine) -> None:\n        if self.metric_details is not None:\n            meta_data = self.batch_transform(engine.state.batch)\n            if isinstance(meta_data, dict):\n                # decollate the `dictionary of list` to `list of dictionaries`\n                meta_data = decollate_batch(meta_data)\n            for m in meta_data:\n                self._filenames.append(f\"{m.get(Key.FILENAME_OR_OBJ)}\")\n\n    def __call__(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        ws = idist.get_world_size()\n        if self.save_rank >= ws:\n            raise ValueError(\"target save rank is greater than the distributed group size.\")\n\n        # all gather file names across ranks\n        _images = string_list_all_gather(strings=self._filenames) if ws > 1 else self._filenames\n\n        # only save metrics to file in specified rank\n        if idist.get_rank() == self.save_rank:\n            _metrics = {}\n            if self.metrics is not None and len(engine.state.metrics) > 0:\n                _metrics = {k: v for k, v in engine.state.metrics.items() if k in self.metrics or \"*\" in self.metrics}\n            _metric_details = {}\n            if hasattr(engine.state, \"metric_details\"):\n                details = engine.state.metric_details\n                if self.metric_details is not None and len(details) > 0:\n                    for k, v in details.items():\n                        if k in self.metric_details or \"*\" in 
self.metric_details:\n                            _metric_details[k] = v\n\n            write_metrics_reports(\n                save_dir=self.save_dir,\n                images=None if len(_images) == 0 else _images,\n                metrics=_metrics,\n                metric_details=_metric_details,\n                summary_ops=self.summary_ops,\n                deli=self.deli,\n                output_type=self.output_type,\n            )\n"
  },
  {
    "path": "monai/handlers/mlflow_handler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport time\nimport warnings\nfrom collections.abc import Callable, Mapping, Sequence\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any\n\nimport torch\nfrom torch.utils.data import Dataset\n\nfrom monai.apps.utils import get_logger\nfrom monai.utils import CommonKeys, IgniteInfo, ensure_tuple, flatten_dict, min_version, optional_import\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nmlflow, _ = optional_import(\"mlflow\", descriptor=\"Please install mlflow before using MLFlowHandler.\")\nmlflow.entities, _ = optional_import(\n    \"mlflow.entities\", descriptor=\"Please install mlflow.entities before using MLFlowHandler.\"\n)\nMlflowException, _ = optional_import(\n    \"mlflow.exceptions\", name=\"MlflowException\", descriptor=\"Please install mlflow before using MLFlowHandler.\"\n)\npandas, _ = optional_import(\"pandas\", descriptor=\"Please install pandas for recording the dataset.\")\ntqdm, _ = optional_import(\"tqdm\", \"4.47.0\", min_version, \"tqdm\")\n\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\n        \"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\", as_type=\"decorator\"\n    )\n\nDEFAULT_TAG = \"Loss\"\n\nlogger = get_logger(module_name=__name__)\n\n\nclass 
MLFlowHandler:\n    \"\"\"\n    MLFlowHandler defines a set of Ignite Event-handlers for the MLFlow tracking logics.\n    It can be used for any Ignite Engine(trainer, validator and evaluator).\n    And it can track both epoch level and iteration level logging, then MLFlow can store\n    the data and visualize.\n    The expected data source is Ignite ``engine.state.output`` and ``engine.state.metrics``.\n\n    Default behaviors:\n        - When EPOCH_COMPLETED, track each dictionary item in\n          ``engine.state.metrics`` in MLFlow.\n        - When ITERATION_COMPLETED, track expected item in\n          ``self.output_transform(engine.state.output)`` in MLFlow, default to `Loss`.\n\n    Usage example is available in the tutorial:\n    https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb.\n\n    Args:\n        tracking_uri: connects to a tracking URI. can also set the `MLFLOW_TRACKING_URI` environment\n            variable to have MLflow find a URI from there. in both cases, the URI can either be\n            an HTTP/HTTPS URI for a remote server, a database connection string, or a local path\n            to log data to a directory. 
The URI defaults to path `mlruns`.\n            for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri.\n        iteration_log: whether to log data to MLFlow when iteration completed, default to `True`.\n            ``iteration_log`` can be also a function and it will be interpreted as an event filter\n            (see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).\n            Event filter function accepts as input engine and event value (iteration) and should return True/False.\n        epoch_log: whether to log data to MLFlow when epoch completed, default to `True`.\n            ``epoch_log`` can be also a function and it will be interpreted as an event filter.\n            See ``iteration_log`` argument for more details.\n        epoch_logger: customized callable logger for epoch level logging with MLFlow.\n            Must accept parameter \"engine\", use default logger if None.\n        iteration_logger: customized callable logger for iteration level logging with MLFlow.\n            Must accept parameter \"engine\", use default logger if None.\n        dataset_logger: customized callable logger to log the dataset information with MLFlow.\n            Must accept parameter \"dataset_dict\", use default logger if None.\n        dataset_dict: a dictionary in which the key is the name of the dataset and the value is a PyTorch\n            dataset, that needs to be recorded. 
This arg is only useful when MLFlow version >= 2.4.0.\n            For more details about how to log data with MLFlow, please go to the website:\n            https://mlflow.org/docs/latest/python_api/mlflow.data.html.\n        dataset_keys: a key or a collection of keys to indicate contents in the dataset that\n            need to be stored by MLFlow.\n        output_transform: a callable that is used to transform the\n            ``ignite.engine.state.output`` into a scalar to track, or a dictionary of {key: scalar}.\n            By default this value logging happens when every iteration completed.\n            The default behavior is to track loss from output[0] as output is a decollated list\n            and we replicated loss value for every item of the decollated list.\n            `engine.state` and `output_transform` inherit from the ignite concept:\n            https://pytorch-ignite.ai/concepts/03-state/, explanation and usage example are in the tutorial:\n            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n        global_epoch_transform: a callable that is used to customize global epoch number.\n            For example, in evaluation, the evaluator engine might want to track synced epoch number\n            with the trainer engine.\n        state_attributes: expected attributes from `engine.state`, if provided, will extract them\n            when epoch completed.\n        tag_name: when iteration output is a scalar, `tag_name` is used to track, defaults to `'Loss'`.\n        experiment_name: the experiment name of MLflow, default to `'monai_experiment'`. An experiment can be\n            used to record several runs.\n        run_name: the run name in an experiment. 
A run can be used to record information about a workflow,\n            like the loss, metrics and so on.\n        experiment_param: a dict recording parameters which will not change through the whole workflow,\n            like torch version, cuda version and so on.\n        artifacts: paths to images that need to be recorded after running the workflow.\n        optimizer_param_names: parameter names in the optimizer that need to be recorded during running the\n            workflow, default to `'lr'`.\n        close_on_complete: whether to close the mlflow run in `complete` phase in workflow, default to False.\n\n    For more details of MLFlow usage, please refer to: https://mlflow.org/docs/latest/index.html.\n\n    \"\"\"\n\n    # parameters that are logged at the start of training\n    default_tracking_params = [\"max_epochs\", \"epoch_length\"]\n\n    def __init__(\n        self,\n        tracking_uri: str | None = None,\n        iteration_log: bool | Callable[[Engine, int], bool] = True,\n        epoch_log: bool | Callable[[Engine, int], bool] = True,\n        epoch_logger: Callable[[Engine], Any] | None = None,\n        iteration_logger: Callable[[Engine], Any] | None = None,\n        dataset_logger: Callable[[Mapping[str, Dataset]], Any] | None = None,\n        dataset_dict: Mapping[str, Dataset] | None = None,\n        dataset_keys: str = CommonKeys.IMAGE,\n        output_transform: Callable = lambda x: x[0],\n        global_epoch_transform: Callable = lambda x: x,\n        state_attributes: Sequence[str] | None = None,\n        tag_name: str = DEFAULT_TAG,\n        experiment_name: str = \"monai_experiment\",\n        run_name: str | None = None,\n        experiment_param: dict | None = None,\n        artifacts: str | Sequence[Path] | None = None,\n        optimizer_param_names: str | Sequence[str] = \"lr\",\n        close_on_complete: bool = False,\n    ) -> None:\n        self.iteration_log = iteration_log\n        self.epoch_log = epoch_log\n        
self.epoch_logger = epoch_logger\n        self.iteration_logger = iteration_logger\n        self.dataset_logger = dataset_logger\n        self.output_transform = output_transform\n        self.global_epoch_transform = global_epoch_transform\n        self.state_attributes = state_attributes\n        self.tag_name = tag_name\n        self.experiment_name = experiment_name\n        self.run_name = run_name\n        self.experiment_param = experiment_param\n        self.artifacts = ensure_tuple(artifacts)\n        self.optimizer_param_names = ensure_tuple(optimizer_param_names)\n        self.client = mlflow.MlflowClient(tracking_uri=tracking_uri if tracking_uri else None)\n        self.run_finish_status = mlflow.entities.RunStatus.to_string(mlflow.entities.RunStatus.FINISHED)\n        self.close_on_complete = close_on_complete\n        self.experiment = None\n        self.cur_run = None\n        self.dataset_dict = dataset_dict\n        self.dataset_keys = ensure_tuple(dataset_keys)\n\n    def _delete_exist_param_in_dict(self, param_dict: dict) -> None:\n        \"\"\"\n        Delete parameters in given dict, if they are already logged by current mlflow run.\n\n        Args:\n            param_dict: parameter dict to be logged to mlflow.\n        \"\"\"\n        if self.cur_run is None:\n            return\n\n        key_list = list(param_dict.keys())\n        log_data = self.client.get_run(self.cur_run.info.run_id).data\n        log_param_dict = log_data.params\n        for key in key_list:\n            if key in log_param_dict:\n                del param_dict[key]\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Register a set of Ignite Event-Handlers to a specified Ignite engine.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        if not engine.has_event_handler(self.start, Events.STARTED):\n            engine.add_event_handler(Events.STARTED, self.start)\n        if 
self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):\n            event = Events.ITERATION_COMPLETED\n            if callable(self.iteration_log):  # substitute event with new one using filter callable\n                event = event(event_filter=self.iteration_log)\n            engine.add_event_handler(event, self.iteration_completed)\n        if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):\n            event = Events.EPOCH_COMPLETED\n            if callable(self.epoch_log):  # substitute event with new one using filter callable\n                event = event(event_filter=self.epoch_log)\n            engine.add_event_handler(event, self.epoch_completed)\n        if not engine.has_event_handler(self.complete, Events.COMPLETED):\n            engine.add_event_handler(Events.COMPLETED, self.complete)\n        if self.close_on_complete and (not engine.has_event_handler(self.close, Events.COMPLETED)):\n            engine.add_event_handler(Events.COMPLETED, self.close)\n\n    def start(self, engine: Engine) -> None:\n        \"\"\"\n        Check MLFlow status and start if not active.\n\n        \"\"\"\n        self._set_experiment()\n        if not self.experiment:\n            raise ValueError(f\"Failed to set experiment '{self.experiment_name}' as the active experiment\")\n\n        if not self.cur_run:\n            run_name = f\"run_{time.strftime('%Y%m%d_%H%M%S')}\" if self.run_name is None else self.run_name\n            runs = self.client.search_runs(self.experiment.experiment_id)\n            runs = [r for r in runs if r.info.run_name == run_name or not self.run_name]\n            # runs marked as finish should not record info any more\n            runs = [r for r in runs if r.info.status != self.run_finish_status]\n            if runs:\n                self.cur_run = self.client.get_run(runs[-1].info.run_id)  # pick latest active run\n            else:\n        
        self.cur_run = self.client.create_run(experiment_id=self.experiment.experiment_id, run_name=run_name)\n\n        if self.experiment_param:\n            self._log_params(self.experiment_param)\n\n        attrs = {attr: getattr(engine.state, attr, None) for attr in self.default_tracking_params}\n        self._delete_exist_param_in_dict(attrs)\n        self._log_params(attrs)\n\n        if self.dataset_logger:\n            self.dataset_logger(self.dataset_dict)\n        else:\n            self._default_dataset_log(self.dataset_dict)\n\n    def _set_experiment(self):\n        experiment = self.experiment\n        if not experiment:\n            for _retry_time in range(3):\n                try:\n                    experiment = self.client.get_experiment_by_name(self.experiment_name)\n                    if not experiment:\n                        experiment_id = self.client.create_experiment(self.experiment_name)\n                        experiment = self.client.get_experiment(experiment_id)\n                    break\n                except MlflowException as e:\n                    if \"RESOURCE_ALREADY_EXISTS\" in str(e):\n                        logger.warning(\"Experiment already exists; delaying before retrying.\")\n                        time.sleep(1)\n                        if _retry_time == 2:\n                            raise e\n                    else:\n                        raise e\n\n        if experiment.lifecycle_stage != mlflow.entities.LifecycleStage.ACTIVE:\n            raise ValueError(f\"Cannot set a deleted experiment '{self.experiment_name}' as the active experiment\")\n        self.experiment = experiment\n\n    @staticmethod\n    def _get_pandas_dataset_info(pandas_dataset):\n        dataset_name = pandas_dataset.name\n        return {\n            f\"{dataset_name}_digest\": pandas_dataset.digest,\n            f\"{dataset_name}_samples\": pandas_dataset.profile[\"num_rows\"],\n        }\n\n    def _log_dataset(self, sample_dict: 
dict[str, Any], context: str = \"train\") -> None:\n        if not self.cur_run:\n            raise ValueError(\"Current Run is not Active to log the dataset\")\n\n        # Need to update the self.cur_run to sync the dataset log, otherwise the `inputs` info will be out-of-date.\n        self.cur_run = self.client.get_run(self.cur_run.info.run_id)\n        logged_set = [x for x in self.cur_run.inputs.dataset_inputs if x.dataset.name.startswith(context)]\n        # In case there are datasets with the same name.\n        dataset_count = str(len(logged_set))\n        dataset_name = f\"{context}_dataset_{dataset_count}\"\n        sample_df = pandas.DataFrame(sample_dict)\n        dataset = mlflow.data.from_pandas(sample_df, name=dataset_name)\n        exist_dataset_list = list(\n            filter(lambda x: x.dataset.digest == dataset.digest, self.cur_run.inputs.dataset_inputs)\n        )\n\n        if not len(exist_dataset_list):\n            datasets = [mlflow.entities.DatasetInput(dataset._to_mlflow_entity())]\n            self.client.log_inputs(run_id=self.cur_run.info.run_id, datasets=datasets)\n            dataset_info = MLFlowHandler._get_pandas_dataset_info(dataset)\n            self._log_params(dataset_info)\n\n    def _log_params(self, params: dict[str, Any]) -> None:\n        if not self.cur_run:\n            raise ValueError(\"Current Run is not Active to log params\")\n        params_arr = [mlflow.entities.Param(key, str(value)) for key, value in params.items()]\n        self.client.log_batch(run_id=self.cur_run.info.run_id, metrics=[], params=params_arr, tags=[])\n\n    def _log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None:\n        if not self.cur_run:\n            raise ValueError(\"Current Run is not Active to log metrics\")\n\n        run_id = self.cur_run.info.run_id\n        timestamp = int(time.time() * 1000)\n        metrics_arr = [\n            mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in 
flatten_dict(metrics).items()\n        ]\n        self.client.log_batch(run_id=run_id, metrics=metrics_arr, params=[], tags=[])\n\n    def _parse_artifacts(self):\n        \"\"\"\n        Log artifacts to mlflow. Given a path, all files in the path will be logged recursively.\n        Given a file, it will be logged to mlflow.\n        \"\"\"\n        artifact_list = []\n        for path_name in self.artifacts:\n            # in case the input is (None,) by default\n            if not path_name:\n                continue\n            if os.path.isfile(path_name):\n                artifact_list.append(path_name)\n            else:\n                for root, _, filenames in os.walk(path_name):\n                    for filename in filenames:\n                        file_path = os.path.join(root, filename)\n                        artifact_list.append(file_path)\n        return artifact_list\n\n    def complete(self) -> None:\n        \"\"\"\n        Handler for train or validation/evaluation completed Event.\n        \"\"\"\n        if self.artifacts and self.cur_run:\n            artifact_list = self._parse_artifacts()\n            for artifact in artifact_list:\n                self.client.log_artifact(self.cur_run.info.run_id, artifact)\n\n    def close(self) -> None:\n        \"\"\"\n        Stop current running logger of MLFlow.\n\n        \"\"\"\n        if self.cur_run:\n            self.client.set_terminated(self.cur_run.info.run_id, self.run_finish_status)\n            self.cur_run = None\n\n    def epoch_completed(self, engine: Engine) -> None:\n        \"\"\"\n        Handler for train or validation/evaluation epoch completed Event.\n        Track epoch level log, default values are from Ignite `engine.state.metrics` dict.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        if self.epoch_logger is not None:\n            self.epoch_logger(engine)\n        else:\n            
self._default_epoch_log(engine)\n\n    def iteration_completed(self, engine: Engine) -> None:\n        \"\"\"\n        Handler for train or validation/evaluation iteration completed Event.\n        Track iteration level log.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        if self.iteration_logger is not None:\n            self.iteration_logger(engine)\n        else:\n            self._default_iteration_log(engine)\n\n    def _default_epoch_log(self, engine: Engine) -> None:\n        \"\"\"\n        Execute epoch level log operation.\n        Default to track the values from Ignite `engine.state.metrics` dict and\n        track the values of specified attributes of `engine.state`.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        log_dict = engine.state.metrics\n        if not log_dict:\n            return\n\n        current_epoch = self.global_epoch_transform(engine.state.epoch)\n        self._log_metrics(log_dict, step=current_epoch)\n\n        if self.state_attributes is not None:\n            attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes}\n            self._log_metrics(attrs, step=current_epoch)\n\n    def _default_iteration_log(self, engine: Engine) -> None:\n        \"\"\"\n        Execute iteration log operation based on Ignite `engine.state.output` data.\n        Log the values from `self.output_transform(engine.state.output)`.\n        Since `engine.state.output` is a decollated list and we replicated the loss value for every item\n        of the decollated list, the default behavior is to track the loss from `output[0]`.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        loss = self.output_transform(engine.state.output)\n        if loss is None:\n            return\n\n        if not 
isinstance(loss, dict):\n            loss = {self.tag_name: loss.item() if isinstance(loss, torch.Tensor) else loss}\n\n        self._log_metrics(loss, step=engine.state.iteration)\n\n        # If there is optimizer attr in engine, then record parameters specified in init function.\n        if hasattr(engine, \"optimizer\"):\n            cur_optimizer = engine.optimizer\n            for param_name in self.optimizer_param_names:\n                params = {\n                    f\"{param_name}_group_{i}\": float(param_group[param_name])\n                    for i, param_group in enumerate(cur_optimizer.param_groups)\n                }\n                self._log_metrics(params, step=engine.state.iteration)\n\n    def _default_dataset_log(self, dataset_dict: Mapping[str, Dataset] | None) -> None:\n        \"\"\"\n        Execute dataset log operation based on the input dataset_dict. The dataset_dict should have a format\n        like:\n            {\n                \"dataset_name0\": dataset0,\n                \"dataset_name1\": dataset1,\n                ......\n            }\n        The keys stand for names of datasets, which will be logged as prefixes of dataset names in MLFlow.\n        The values are PyTorch datasets from which sample names are abstracted to build a Pandas DataFrame.\n        If the input dataset_dict is None, this function will directly return and do nothing.\n\n        To use this function, every sample in the input datasets must contain keys specified by the `dataset_keys`\n        parameter.\n        This function will log a PandasDataset to MLFlow inputs, generated from the Pandas DataFrame.\n        For more details about PandasDataset, please refer to this link:\n        https://mlflow.org/docs/latest/python_api/mlflow.data.html#mlflow.data.pandas_dataset.PandasDataset\n\n        Please note that it may take a while to record the dataset if it has too many samples.\n\n        Args:\n            dataset_dict: a dictionary in which the key 
is the name of the dataset and the value is a PyTorch\n                dataset, that needs to be recorded.\n\n        \"\"\"\n\n        if dataset_dict is None:\n            return\n        elif len(dataset_dict) == 0:\n            warnings.warn(\"There is no dataset to log!\")\n\n        # Log datasets to MLFlow one by one.\n        for dataset_type, dataset in dataset_dict.items():\n            if dataset is None:\n                raise AttributeError(f\"The {dataset_type} dataset of is None. Cannot record it by MLFlow.\")\n\n            sample_dict: dict[str, list[str]] = {}\n            dataset_samples = getattr(dataset, \"data\", [])\n            for sample in tqdm(dataset_samples, f\"Recording the {dataset_type} dataset\"):\n                for key in self.dataset_keys:\n                    if key not in sample_dict:\n                        sample_dict[key] = []\n\n                    if key in sample:\n                        value_to_log = sample[key]\n                    else:\n                        raise KeyError(f\"Unexpect key '{key}' in the sample.\")\n\n                    if not isinstance(value_to_log, str):\n                        warnings.warn(\n                            f\"Expected type string, got type {type(value_to_log)} of the {key} name.\"\n                            \"May log an empty dataset in MLFlow\"\n                        )\n                    else:\n                        sample_dict[key].append(value_to_log)\n            self._log_dataset(sample_dict, dataset_type)\n"
  },
  {
    "path": "monai/handlers/nvtx_handlers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nWrapper around NVIDIA Tools Extension for profiling MONAI ignite workflow\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom monai.utils import IgniteInfo, ensure_tuple, min_version, optional_import\n\n_nvtx, _ = optional_import(\"torch._C._nvtx\", descriptor=\"NVTX is not installed. Are you sure you have a CUDA build?\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine, Events\nelse:\n    Engine, _ = optional_import(\n        \"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\", as_type=\"decorator\"\n    )\n    Events, _ = optional_import(\n        \"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\", as_type=\"decorator\"\n    )\n\n__all__ = [\"RangeHandler\", \"RangePushHandler\", \"RangePopHandler\", \"MarkHandler\"]\n\n\nclass RangeHandler:\n    \"\"\"\n    Attach a NVTX range to a pair of Ignite events.\n    It pushes an NVTX range at the first event and pops it at the second event.\n    Stores zero-based depth of the range that is started.\n\n    Args:\n        events: a string, pair of Ignite events, pair of Ignite event literals, or pair of Ignite events and literals.\n            If a single string is provided, it should  describe the base name of a pair of default Ignite events\n            with _STARTED and _COMPLETED postfix (like \"EPOCH\" for Events.EPOCH_STARTED 
and Events.EPOCH_COMPLETED).\n            The accepted events are: BATCH, ITERATION, EPOCH, and ENGINE.\n            If pair of literals, each should be the literal equivalent of an Ignite event, for instance:\n            (\"EPOCH_STARTED\" and \"EPOCH_COMPLETED\").\n            One can combine events and literals, like (Events.EPOCH_STARTED and \"EPOCH_COMPLETED\").\n            For the complete list of Events,\n            check https://pytorch.org/ignite/generated/ignite.engine.events.Events.html.\n\n        msg: ASCII message to associate with range.\n            If not provided and `events` is a single string, that string is assigned to the NVTX range,\n            otherwise the names of the two events joined by `/` are used.\n    \"\"\"\n\n    def __init__(self, events: str | tuple[str | Events, str | Events], msg: str | None = None) -> None:\n        self.events = self.resolve_events(events)\n        if msg is None:\n            if isinstance(events, str):\n                # assign the prefix of the events\n                msg = events\n            else:\n                # combine events' names\n                msg = \"/\".join([e.name for e in self.events])\n        self.msg = msg\n        self.depth = None\n\n    def resolve_events(self, events: str | tuple) -> tuple[Events, Events]:\n        \"\"\"\n        Resolve the input events to create a pair of Ignite events\n        \"\"\"\n        events = ensure_tuple(events)\n        if len(events) == 1:\n            return self.create_paired_events(events[0])\n        if len(events) == 2:\n            return self.get_event(events[0]), self.get_event(events[1])\n        raise ValueError(f\"Exactly two Ignite events should be provided [received {len(events)}].\")\n\n    def create_paired_events(self, event: str) -> tuple[Events, Events]:\n        \"\"\"\n        Create pair of Ignite events from an event prefix name\n        \"\"\"\n        event = event.upper()\n        event_prefix = {\"\": \"\", \"ENGINE\": \"\", \"EPOCH\": \"EPOCH_\", \"ITERATION\": \"ITERATION_\", \"BATCH\": \"GET_BATCH_\"}\n        return self.get_event(event_prefix[event] + \"STARTED\"), self.get_event(event_prefix[event] + \"COMPLETED\")\n\n    def get_event(self, event: str | Events) -> Events:\n        return Events[event.upper()] if isinstance(event, str) else event\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Attach an NVTX Range to specific Ignite events\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        engine.add_event_handler(self.events[0], self.range_push)\n        engine.add_event_handler(self.events[1], self.range_pop)\n\n    def range_push(self):\n        self.depth = _nvtx.rangePushA(self.msg)\n\n    def range_pop(self):\n        _nvtx.rangePop()\n\n\nclass RangePushHandler:\n    \"\"\"\n    At a specific event, pushes a range onto a stack of nested range span.\n    Stores zero-based depth of the range that is started.\n\n    Args:\n        event: Ignite event, or the name of an Ignite event, at which the range is pushed.\n        msg: ASCII message to associate with range.\n            If not provided, the name of the event is used.\n    \"\"\"\n\n    def __init__(self, event: str | Events, msg: str | None = None) -> None:\n        self.event = Events[event.upper()] if isinstance(event, str) else event\n        if msg is None:\n            msg = self.event.name\n        self.msg = msg\n        self.depth = None\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Push an NVTX range at a specific Ignite event\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        engine.add_event_handler(self.event, self.range_push)\n\n    def range_push(self):\n        self.depth = _nvtx.rangePushA(self.msg)\n\n\nclass RangePopHandler:\n    \"\"\"\n    At a specific event, pop a previously pushed range.\n\n    Args:\n        event: Ignite event, or the name of an Ignite event, at which the range is popped.\n    \"\"\"\n\n    def __init__(self, event: str | Events) -> None:\n        self.event = Events[event.upper()] if isinstance(event, 
str) else event\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Pop an NVTX range at a specific Ignite event\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        engine.add_event_handler(self.event, self.range_pop)\n\n    def range_pop(self):\n        _nvtx.rangePop()\n\n\nclass MarkHandler:\n    \"\"\"\n    Mark an instantaneous event that occurred at some point.\n\n    Args:\n        msg: ASCII message to associate with range\n    \"\"\"\n\n    def __init__(self, event: str | Events, msg: str | None = None) -> None:\n        self.event = Events[event.upper()] if isinstance(event, str) else event\n        if msg is None:\n            msg = self.event.name\n        self.msg = msg\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Add an NVTX mark to a specific Ignite event\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        engine.add_event_handler(self.event, self.mark)\n\n    def mark(self):\n        _nvtx.markA(self.msg)\n"
  },
  {
    "path": "monai/handlers/panoptic_quality.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nfrom monai.handlers.ignite_metric import IgniteMetricHandler\nfrom monai.metrics import PanopticQualityMetric\nfrom monai.utils import MetricReduction\n\n\nclass PanopticQuality(IgniteMetricHandler):\n    \"\"\"\n    Computes Panoptic quality from full size Tensor and collects average over batch, class-channels, iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_classes: int,\n        metric_name: str = \"pq\",\n        reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,\n        match_iou_threshold: float = 0.5,\n        smooth_numerator: float = 1e-6,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            num_classes: number of classes. The number should not count the background.\n            metric_name: output metric. The value can be \"pq\", \"sq\" or \"rq\".\n            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to `self.reduction`. 
if \"none\", will not do reduction.\n            match_iou_threshold: IOU threshold to determine the pairing between `y_pred` and `y`. Usually,\n                it should >= 0.5, the pairing between instances of `y_pred` and `y` are identical.\n                If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the\n                maximal amount of unique pairing.\n            smooth_numerator: a small constant added to the numerator to avoid zero.\n            output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n                construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n                lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            save_details: whether to save metric computation details per image, for example: panoptic quality of\n                every image.\n                default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n\n        See also:\n            :py:meth:`monai.metrics.panoptic_quality.compute_panoptic_quality`\n        \"\"\"\n        metric_fn = PanopticQualityMetric(\n            num_classes=num_classes,\n            metric_name=metric_name,\n            reduction=reduction,\n            match_iou_threshold=match_iou_threshold,\n            smooth_numerator=smooth_numerator,\n        )\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n"
  },
  {
    "path": "monai/handlers/parameter_scheduler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nfrom bisect import bisect_right\nfrom collections.abc import Callable\nfrom typing import TYPE_CHECKING\n\nfrom monai.utils import IgniteInfo, min_version, optional_import\n\nif TYPE_CHECKING:\n    from ignite.engine import Engine, Events\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n    Events, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\n\n\nclass ParamSchedulerHandler:\n    \"\"\"\n    General purpose scheduler for parameters values. By default it can schedule in a linear, exponential, step or\n    multistep function. One can also pass Callables to have customized scheduling logic.\n\n    Args:\n        parameter_setter (Callable): Function that sets the required parameter\n        value_calculator (Union[str,Callable]): Either a string ('linear', 'exponential', 'step' or 'multistep')\n         or Callable for custom logic.\n        vc_kwargs (Dict): Dictionary that stores the required parameters for the value_calculator.\n        epoch_level (bool): Whether the step is based on epoch or iteration. Defaults to False.\n        name (Optional[str]): Identifier of logging.logger to use, if None, defaulting to ``engine.logger``.\n        event (Optional[str]): Event to which the handler attaches. 
Defaults to Events.ITERATION_COMPLETED.\n    \"\"\"\n\n    def __init__(\n        self,\n        parameter_setter: Callable,\n        value_calculator: str | Callable,\n        vc_kwargs: dict,\n        epoch_level: bool = False,\n        name: str | None = None,\n        event: str | None = None,\n    ):\n        self.epoch_level = epoch_level\n        self.event = event if event is not None else Events.ITERATION_COMPLETED\n\n        self._calculators = {\n            \"linear\": self._linear,\n            \"exponential\": self._exponential,\n            \"step\": self._step,\n            \"multistep\": self._multistep,\n        }\n\n        self._parameter_setter = parameter_setter\n        self._vc_kwargs = vc_kwargs\n        self._value_calculator = self._get_value_calculator(value_calculator=value_calculator)\n\n        self.logger = logging.getLogger(name)\n        self._name = name\n\n    def _get_value_calculator(self, value_calculator):\n        if isinstance(value_calculator, str):\n            return self._calculators[value_calculator]\n        if callable(value_calculator):\n            return value_calculator\n        raise ValueError(\n            f\"value_calculator must be either a string from {list(self._calculators.keys())} or a Callable.\"\n        )\n\n    def __call__(self, engine: Engine) -> None:\n        if self.epoch_level:\n            self._vc_kwargs[\"current_step\"] = engine.state.epoch\n        else:\n            self._vc_kwargs[\"current_step\"] = engine.state.iteration\n\n        new_value = self._value_calculator(**self._vc_kwargs)\n        self._parameter_setter(new_value)\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine that is used for training.\n        \"\"\"\n        if self._name is None:\n            self.logger = engine.logger\n        engine.add_event_handler(self.event, self)\n\n    @staticmethod\n    def _linear(\n        initial_value: float, 
step_constant: int, step_max_value: int, max_value: float, current_step: int\n    ) -> float:\n        \"\"\"\n        Keeps the parameter value to zero until step_zero steps passed and then linearly increases it to 1 until an\n        additional step_one steps passed. Continues the trend until it reaches max_value.\n\n        Args:\n            initial_value (float): Starting value of the parameter.\n            step_constant (int): Step index until parameter's value is kept constant.\n            step_max_value (int): Step index at which parameter's value becomes max_value.\n            max_value (float): Max parameter value.\n            current_step (int): Current step index.\n\n        Returns:\n            float: new parameter value\n        \"\"\"\n        if current_step <= step_constant:\n            delta = 0.0\n        elif current_step > step_max_value:\n            delta = max_value - initial_value\n        else:\n            delta = (max_value - initial_value) / (step_max_value - step_constant) * (current_step - step_constant)\n\n        return initial_value + delta\n\n    @staticmethod\n    def _exponential(initial_value: float, gamma: float, current_step: int) -> float:\n        \"\"\"\n        Decays the parameter value by gamma every step.\n\n        Based on the closed form of ExponentialLR from Pytorch:\n        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ExponentialLR.html.\n\n        Args:\n            initial_value (float): Starting value of the parameter.\n            gamma (float): Multiplicative factor of parameter value decay.\n            current_step (int): Current step index.\n\n        Returns:\n            float: new parameter value\n        \"\"\"\n        return initial_value * gamma**current_step\n\n    @staticmethod\n    def _step(initial_value: float, gamma: float, step_size: int, current_step: int) -> float:\n        \"\"\"\n        Decays the parameter value by gamma every step_size.\n\n        Based on 
StepLR from Pytorch:\n        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html.\n\n        Args:\n            initial_value (float): Starting value of the parameter.\n            gamma (float): Multiplicative factor of parameter value decay.\n            step_size (int): Period of parameter value decay.\n            current_step (int): Current step index.\n\n        Returns:\n            float: new parameter value\n        \"\"\"\n        return initial_value * gamma ** (current_step // step_size)\n\n    @staticmethod\n    def _multistep(initial_value: float, gamma: float, milestones: list[int], current_step: int) -> float:\n        \"\"\"\n        Decays the parameter value by gamma once the number of steps reaches one of the milestones.\n\n        Based on MultiStepLR from Pytorch:\n        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html.\n\n        Args:\n            initial_value (float): Starting value of the parameter.\n            gamma (float): Multiplicative factor of parameter value decay.\n            milestones (List[int]): List of step indices. Must be increasing.\n            current_step (int): Current step index.\n\n        Returns:\n            float: new parameter value\n        \"\"\"\n        return initial_value * gamma ** bisect_right(milestones, current_step)\n"
  },
  {
    "path": "monai/handlers/postprocessing.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\nfrom typing import TYPE_CHECKING\n\nfrom monai.engines.utils import IterationEvents, engine_apply_transform\nfrom monai.utils import IgniteInfo, min_version, optional_import\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n\n\nclass PostProcessing:\n    \"\"\"\n    Ignite handler to execute additional post processing after the post processing in engines.\n    So users can insert other handlers between engine postprocessing and this post processing handler.\n    If using components from `monai.transforms` as the `transform`, recommend to decollate `engine.state.batch`\n    and `engine.state.output` in the engine(set `decollate=True`) or in the `DecollateBatch` handler first.\n\n    \"\"\"\n\n    def __init__(self, transform: Callable, event: str = \"MODEL_COMPLETED\") -> None:\n        \"\"\"\n        Args:\n            transform: callable function to execute on the `engine.state.batch` and `engine.state.output`.\n                can also be composed transforms.\n            event: expected EVENT to attach the handler, should be \"MODEL_COMPLETED\" or 
\"ITERATION_COMPLETED\".\n                default to \"MODEL_COMPLETED\".\n\n        \"\"\"\n        self.transform = transform\n        event = event.upper()\n        if event not in (\"MODEL_COMPLETED\", \"ITERATION_COMPLETED\"):\n            raise ValueError(\"event should be `MODEL_COMPLETED` or `ITERATION_COMPLETED`.\")\n        self.event = event\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if self.event == \"MODEL_COMPLETED\":\n            engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self)\n        else:\n            engine.add_event_handler(Events.ITERATION_COMPLETED, self)\n\n    def __call__(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if not isinstance(engine.state.batch, list) or not isinstance(engine.state.output, list):\n            engine.state.batch, engine.state.output = engine_apply_transform(\n                batch=engine.state.batch, output=engine.state.output, transform=self.transform\n            )\n        else:\n            for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)):\n                engine.state.batch[i], engine.state.output[i] = engine_apply_transform(b, o, self.transform)\n"
  },
  {
    "path": "monai/handlers/probability_maps.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport threading\nfrom typing import TYPE_CHECKING\n\nimport numpy as np\n\nfrom monai.config import DtypeLike\nfrom monai.data.folder_layout import FolderLayout\nfrom monai.utils import ProbMapKeys, min_version, optional_import\nfrom monai.utils.enums import CommonKeys, IgniteInfo\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n\n\nclass ProbMapProducer:\n    \"\"\"\n    Event handler triggered on completing every iteration to calculate and save the probability map.\n    This handler use metadata from MetaTensor to create the probability map. 
This can be simply achieved by using\n    `monai.data.SlidingPatchWSIDataset` or `monai.data.MaskedPatchWSIDataset` as the dataset.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        output_dir: str = \"./\",\n        output_postfix: str = \"\",\n        prob_key: str = \"pred\",\n        dtype: DtypeLike = np.float64,\n        name: str | None = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            output_dir: output directory to save probability maps.\n            output_postfix: a string appended to all output file names.\n            prob_key: the key associated to the probability output of the model\n            dtype: the data type in which the probability map is stored. Default np.float64.\n            name: identifier of logging.logger to use, defaulting to `engine.logger`.\n\n        \"\"\"\n        self.folder_layout = FolderLayout(\n            output_dir=output_dir,\n            postfix=output_postfix,\n            extension=\".npy\",\n            parent=False,\n            makedirs=True,\n            data_root_dir=\"\",\n        )\n\n        self.logger = logging.getLogger(name)\n        self._name = name\n        self.prob_key = prob_key\n        self.dtype = dtype\n        self.prob_map: dict[str, np.ndarray] = {}\n        self.counter: dict[str, int] = {}\n        self.num_done_images: int = 0\n        self.num_images: int = 0\n        self.lock = threading.Lock()\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n\n        image_data = engine.data_loader.dataset.image_data  # type: ignore\n        self.num_images = len(image_data)\n\n        # Initialized probability maps for all the images\n        for sample in image_data:\n            name = sample[ProbMapKeys.NAME]\n            self.counter[name] = sample[ProbMapKeys.COUNT]\n            self.prob_map[name] = np.zeros(sample[ProbMapKeys.SIZE], 
dtype=self.dtype)\n\n        if self._name is None:\n            self.logger = engine.logger\n        if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):\n            engine.add_event_handler(Events.ITERATION_COMPLETED, self)\n        if not engine.has_event_handler(self.finalize, Events.COMPLETED):\n            engine.add_event_handler(Events.COMPLETED, self.finalize)\n\n    def __call__(self, engine: Engine) -> None:\n        \"\"\"\n        This method assumes self.batch_transform will extract metadata from the input batch.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if not isinstance(engine.state.batch, dict) or not isinstance(engine.state.output, dict):\n            raise ValueError(\"engine.state.batch and engine.state.output must be dictionaries.\")\n        names = engine.state.batch[CommonKeys.IMAGE].meta[ProbMapKeys.NAME]\n        locs = engine.state.batch[CommonKeys.IMAGE].meta[ProbMapKeys.LOCATION]\n        probs = engine.state.output[self.prob_key]\n        for name, loc, prob in zip(names, locs, probs):\n            self.prob_map[name][tuple(loc)] = prob\n            with self.lock:\n                self.counter[name] -= 1\n                if self.counter[name] == 0:\n                    self.save_prob_map(name)\n\n    def save_prob_map(self, name: str) -> None:\n        \"\"\"\n        This method save the probability map for an image, when its inference is finished,\n        and delete that probability map from memory.\n\n        Args:\n            name: the name of image to be saved.\n        \"\"\"\n        file_path = self.folder_layout.filename(name)\n        np.save(file_path, self.prob_map[name])\n\n        self.num_done_images += 1\n        self.logger.info(f\"Inference of '{name}' is done [{self.num_done_images}/{self.num_images}]!\")\n        del self.prob_map[name]\n        del self.counter[name]\n\n    def finalize(self, engine: Engine) -> 
None:\n        self.logger.info(f\"Probability map is created for {self.num_done_images}/{self.num_images} images!\")\n"
  },
  {
    "path": "monai/handlers/regression_metrics.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nfrom monai.handlers.ignite_metric import IgniteMetricHandler\nfrom monai.metrics import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric\nfrom monai.utils import MetricReduction\n\n\nclass MeanSquaredError(IgniteMetricHandler):\n    \"\"\"\n    Computes Mean Squared Error from full size Tensor and collects average over batch, iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n            output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n                construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n                lists of `channel-first` Tensors. 
the form of `(y_pred, y)` is required by the `update()`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            save_details: whether to save metric computation details per image, for example: mean squared error of every image.\n                default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n\n        See also:\n            :py:class:`monai.metrics.MSEMetric`\n        \"\"\"\n        metric_fn = MSEMetric(reduction=reduction)\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n\n\nclass MeanAbsoluteError(IgniteMetricHandler):\n    \"\"\"\n    Computes Mean Absolute Error from full size Tensor and collects average over batch, iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n            output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n                construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n                lists of `channel-first` Tensors. 
the form of `(y_pred, y)` is required by the `update()`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            save_details: whether to save metric computation details per image, for example: mean absolute error of every image.\n                default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n\n        See also:\n            :py:class:`monai.metrics.MAEMetric`\n        \"\"\"\n        metric_fn = MAEMetric(reduction=reduction)\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n\n\nclass RootMeanSquaredError(IgniteMetricHandler):\n    \"\"\"\n    Computes Root Mean Squared Error from full size Tensor and collects average over batch, iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n            output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n                construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n                lists of `channel-first` Tensors. 
the form of `(y_pred, y)` is required by the `update()`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            save_details: whether to save metric computation details per image, for example: root mean squared error of every image.\n                default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n\n        See also:\n            :py:class:`monai.metrics.RMSEMetric`\n        \"\"\"\n        metric_fn = RMSEMetric(reduction=reduction)\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n\n\nclass PeakSignalToNoiseRatio(IgniteMetricHandler):\n    \"\"\"\n    Computes Peak Signal to Noise Ratio from full size Tensor and collects average over batch, iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        max_val: int | float,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            max_val: The dynamic range of the images/volumes (i.e., the difference between the\n                maximum and the minimum allowed values e.g. 255 for a uint8 image).\n            reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n            output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n                construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n                lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            save_details: whether to save metric computation details per image, for example: peak signal to noise ratio of every image.\n                default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n\n        See also:\n            :py:class:`monai.metrics.PSNRMetric`\n        \"\"\"\n        metric_fn = PSNRMetric(max_val=max_val, reduction=reduction)\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n"
  },
  {
    "path": "monai/handlers/roc_auc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nfrom monai.handlers.ignite_metric import IgniteMetricHandler\nfrom monai.metrics import ROCAUCMetric\nfrom monai.utils import Average\n\n\nclass ROCAUC(IgniteMetricHandler):\n    \"\"\"\n    Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC).\n    accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`.\n\n    Args:\n        average: {``\"macro\"``, ``\"weighted\"``, ``\"micro\"``, ``\"none\"``}\n            Type of averaging performed if not binary classification. 
Defaults to ``\"macro\"``.\n\n            - ``\"macro\"``: calculate metrics for each label, and find their unweighted mean.\n                This does not take label imbalance into account.\n            - ``\"weighted\"``: calculate metrics for each label, and find their average,\n                weighted by support (the number of true instances for each label).\n            - ``\"micro\"``: calculate metrics globally by considering each element of the label\n                indicator matrix as a label.\n            - ``\"none\"``: the scores for each class are returned.\n\n        output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n            construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n            lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.\n            `engine.state` and `output_transform` inherit from the ignite concept:\n            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n\n    Note:\n        ROCAUC expects y to be comprised of 0's and 1's.\n        y_pred must either be probability estimates or confidence values.\n\n    \"\"\"\n\n    def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None:\n        metric_fn = ROCAUCMetric(average=Average(average))\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)\n"
  },
  {
    "path": "monai/handlers/smartcache_handler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom monai.data import SmartCacheDataset\nfrom monai.utils import IgniteInfo, min_version, optional_import\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n\n\nclass SmartCacheHandler:\n    \"\"\"\n    Attach SmartCache logic to the engine in Ignite.\n    Mainly include the `start`, `update_cache`, and `shutdown` functions of SmartCacheDataset.\n\n    \"\"\"\n\n    def __init__(self, smartcacher: SmartCacheDataset) -> None:\n        \"\"\"\n        Args:\n            smartcacher: predefined SmartCacheDataset, will attach it to the engine.\n\n        Raises:\n            TypeError: When ``smartcacher`` is not a ``monai.data.SmartCacheDataset``.\n\n        \"\"\"\n        if not isinstance(smartcacher, SmartCacheDataset):\n            raise TypeError(\"smartcacher must be a monai.data.SmartCacheDataset.\")\n        self.smartcacher = smartcacher\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        engine.add_event_handler(Events.STARTED, 
self.started)\n        engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)\n        engine.add_event_handler(Events.COMPLETED, self.completed)\n\n    def started(self, engine: Engine) -> None:\n        \"\"\"Callback for train or validation/evaluation started Event.\n        Start the replacement thread of SmartCacheDataset.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        self.smartcacher.start()\n\n    def epoch_completed(self, engine: Engine) -> None:\n        \"\"\"Callback for train or validation/evaluation epoch completed Event.\n        Update cache content with replacement data.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        self.smartcacher.update_cache()\n\n    def completed(self, engine: Engine) -> None:\n        \"\"\"Callback for train or validation/evaluation completed Event.\n        Stop the replacement thread of SmartCacheDataset.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        self.smartcacher.shutdown()\n"
  },
  {
    "path": "monai/handlers/stats_handler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport warnings\nfrom collections.abc import Callable, Sequence\nfrom typing import TYPE_CHECKING, Any\n\nimport torch\n\nfrom monai.apps import get_logger\nfrom monai.utils import IgniteInfo, flatten_dict, is_scalar, min_version, optional_import\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\n        \"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\", as_type=\"decorator\"\n    )\n\nDEFAULT_KEY_VAL_FORMAT = \"{}: {:.4f} \"\nDEFAULT_TAG = \"Loss\"\n\n\nclass StatsHandler:\n    \"\"\"\n    StatsHandler defines a set of Ignite Event-handlers for all the log printing logics.\n    It can be used for any Ignite Engine(trainer, validator and evaluator).\n    And it can support logging for epoch level and iteration level with pre-defined loggers.\n\n    Note that if ``name`` is None, this class will leverage `engine.logger` as the logger, otherwise,\n    ``logging.getLogger(name)`` is used. In both cases, it's important to make sure that the logging level is at least\n    ``INFO``. 
To change the level of logging, please call ``import ignite; ignite.utils.setup_logger(name)``\n    (when ``name`` is not None) or ``engine.logger = ignite.utils.setup_logger(engine.logger.name, reset=True)``\n    (when ``name`` is None) before running the engine with this handler attached.\n\n    Default behaviors:\n        - When EPOCH_COMPLETED, logs ``engine.state.metrics`` using ``self.logger``.\n        - When ITERATION_COMPLETED, logs\n          ``self.output_transform(engine.state.output)`` using ``self.logger``.\n\n    Usage example::\n\n        import ignite\n        import monai\n\n        trainer = ignite.engine.Engine(lambda x, y: [0.0])  # an example trainer\n        monai.handlers.StatsHandler(name=\"train_stats\").attach(trainer)\n\n        trainer.run(range(3), max_epochs=4)\n\n    More details of example is available in the tutorial:\n    https://github.com/Project-MONAI/tutorials/blob/master/modules/engines/unet_training_dict.py.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        iteration_log: bool | Callable[[Engine, int], bool] = True,\n        epoch_log: bool | Callable[[Engine, int], bool] = True,\n        epoch_print_logger: Callable[[Engine], Any] | None = None,\n        iteration_print_logger: Callable[[Engine], Any] | None = None,\n        output_transform: Callable = lambda x: x[0],\n        global_epoch_transform: Callable = lambda x: x,\n        state_attributes: Sequence[str] | None = None,\n        name: str | None = \"monai.handlers.StatsHandler\",\n        tag_name: str = DEFAULT_TAG,\n        key_var_format: str = DEFAULT_KEY_VAL_FORMAT,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            iteration_log: whether to log data when iteration completed, default to `True`. 
``iteration_log`` can\n                be also a function and it will be interpreted as an event filter\n                (see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).\n                Event filter function accepts as input engine and event value (iteration) and should return True/False.\n                Event filtering can be helpful to customize iteration logging frequency.\n            epoch_log: whether to log data when epoch completed, default to `True`. ``epoch_log`` can be\n                also a function and it will be interpreted as an event filter. See ``iteration_log`` argument for more\n                details.\n            epoch_print_logger: customized callable printer for epoch level logging.\n                Must accept parameter \"engine\", use default printer if None.\n            iteration_print_logger: customized callable printer for iteration level logging.\n                Must accept parameter \"engine\", use default printer if None.\n            output_transform: a callable that is used to transform the\n                ``ignite.engine.state.output`` into a scalar to print, or a dictionary of {key: scalar}.\n                In the latter case, the output string will be formatted as key: value.\n                By default this value logging happens when every iteration completed.\n                The default behavior is to print loss from output[0] as output is a decollated list\n                and we replicated loss value for every item of the decollated list.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            global_epoch_transform: a callable that is used to customize global epoch number.\n                For example, in 
evaluation, the evaluator engine might want to print synced epoch number\n                with the trainer engine.\n            state_attributes: expected attributes from `engine.state`, if provided, will extract them\n                when epoch completed.\n            name: identifier of `logging.logger` to use, if None, defaulting to ``engine.logger``.\n            tag_name: when iteration output is a scalar, tag_name is used to print\n                tag_name: scalar_value to logger. Defaults to ``'Loss'``.\n            key_var_format: a formatting string to control the output string format of key: value.\n\n        \"\"\"\n\n        self.iteration_log = iteration_log\n        self.epoch_log = epoch_log\n        self.epoch_print_logger = epoch_print_logger\n        self.iteration_print_logger = iteration_print_logger\n        self.output_transform = output_transform\n        self.global_epoch_transform = global_epoch_transform\n        self.state_attributes = state_attributes\n        self.tag_name = tag_name\n        self.key_var_format = key_var_format\n        self.logger = get_logger(name)  # type: ignore\n        self.name = name\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Register a set of Ignite Event-Handlers to a specified Ignite engine.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        if self.name is None:\n            self.logger = engine.logger\n        if self.logger.getEffectiveLevel() > logging.INFO:\n            suggested = f\"\\n\\nimport ignite\\nignite.utils.setup_logger('{self.logger.name}', reset=True)\"\n            if self.logger.name != engine.logger.name:\n                suggested += f\"\\nignite.utils.setup_logger('{engine.logger.name}', reset=True)\"\n            suggested += \"\\n\\n\"\n            warnings.warn(\n                f\"the effective log level of {self.logger.name} is higher than INFO, StatsHandler may not output 
logs,\"\n                f\"\\nplease use the following code before running the engine to enable it: {suggested}\"\n            )\n        if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):\n            event = Events.ITERATION_COMPLETED\n            if callable(self.iteration_log):  # substitute event with new one using filter callable\n                event = event(event_filter=self.iteration_log)\n            engine.add_event_handler(event, self.iteration_completed)\n        if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):\n            event = Events.EPOCH_COMPLETED\n            if callable(self.epoch_log):  # substitute event with new one using filter callable\n                event = event(event_filter=self.epoch_log)\n            engine.add_event_handler(event, self.epoch_completed)\n        if not engine.has_event_handler(self.exception_raised, Events.EXCEPTION_RAISED):\n            engine.add_event_handler(Events.EXCEPTION_RAISED, self.exception_raised)\n\n    def epoch_completed(self, engine: Engine) -> None:\n        \"\"\"\n        Handler for train or validation/evaluation epoch completed Event.\n        Print epoch level log, default values are from Ignite `engine.state.metrics` dict.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        if self.epoch_print_logger is not None:\n            self.epoch_print_logger(engine)\n        else:\n            self._default_epoch_print(engine)\n\n    def iteration_completed(self, engine: Engine) -> None:\n        \"\"\"\n        Handler for train or validation/evaluation iteration completed Event.\n        Print iteration level log, default values are from Ignite `engine.state.output`.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        if self.iteration_print_logger 
is not None:\n            self.iteration_print_logger(engine)\n        else:\n            self._default_iteration_print(engine)\n\n    def exception_raised(self, _engine: Engine, e: Exception) -> None:\n        \"\"\"\n        Handler for train or validation/evaluation exception raised Event.\n        Print the exception information and traceback. This callback may be skipped because the logic\n        with Ignite can only trigger the first attached handler for `EXCEPTION_RAISED` event.\n\n        Args:\n            _engine: Ignite Engine, unused argument.\n            e: the exception caught in Ignite during engine.run().\n\n        \"\"\"\n        self.logger.exception(f\"Exception: {e}\")\n        raise e\n\n    def _default_epoch_print(self, engine: Engine) -> None:\n        \"\"\"\n        Execute epoch level log operation.\n        Default to print the values from Ignite `engine.state.metrics` dict and\n        print the values of specified attributes of `engine.state`.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        current_epoch = self.global_epoch_transform(engine.state.epoch)\n        prints_dict = flatten_dict(engine.state.metrics)\n        if prints_dict is not None and len(prints_dict) > 0:\n            out_str = f\"Epoch[{current_epoch}] Metrics -- \"\n            for name in sorted(prints_dict):\n                value = prints_dict[name]\n                out_str += self.key_var_format.format(name, value) if is_scalar(value) else f\"{name}: {str(value)}\"\n            self.logger.info(out_str)\n\n        if (\n            hasattr(engine.state, \"key_metric_name\")\n            and hasattr(engine.state, \"best_metric\")\n            and hasattr(engine.state, \"best_metric_epoch\")\n            and engine.state.key_metric_name is not None\n        ):\n            out_str = f\"Key metric: {engine.state.key_metric_name} \"\n            out_str += f\"best value: 
{engine.state.best_metric} \"\n            out_str += f\"at epoch: {engine.state.best_metric_epoch}\"\n            self.logger.info(out_str)\n\n        if self.state_attributes is not None and len(self.state_attributes) > 0:\n            out_str = \"State values: \"\n            for attr in self.state_attributes:\n                out_str += f\"{attr}: {getattr(engine.state, attr, None)} \"\n            self.logger.info(out_str)\n\n    def _default_iteration_print(self, engine: Engine) -> None:\n        \"\"\"\n        Execute iteration log operation based on Ignite `engine.state.output` data.\n        Print the values from `self.output_transform(engine.state.output)`.\n        Since `engine.state.output` is a decollated list and we replicated the loss value for every item\n        of the decollated list, the default behavior is to print the loss from `output[0]`.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        loss = self.output_transform(engine.state.output)\n        if loss is None:\n            return  # no printing if the output is empty\n\n        out_str = \"\"\n        if isinstance(loss, dict):  # print dictionary items\n            for name in sorted(loss):\n                value = loss[name]\n                if not is_scalar(value):\n                    warnings.warn(\n                        \"ignoring non-scalar output in StatsHandler,\"\n                        \" make sure `output_transform(engine.state.output)` returns\"\n                        \" a scalar or dictionary of key and scalar pairs to avoid this warning.\"\n                        f\" {name}:{type(value)}\"\n                    )\n                    continue  # not printing multi dimensional output\n                out_str += self.key_var_format.format(name, value.item() if isinstance(value, torch.Tensor) else value)\n        elif is_scalar(loss):  # not printing multi dimensional output\n            out_str += 
self.key_var_format.format(\n                self.tag_name, loss.item() if isinstance(loss, torch.Tensor) else loss\n            )\n        else:\n            warnings.warn(\n                \"ignoring non-scalar output in StatsHandler,\"\n                \" make sure `output_transform(engine.state.output)` returns\"\n                \" a scalar or a dictionary of key and scalar pairs to avoid this warning.\"\n                f\" {type(loss)}\"\n            )\n\n        if not out_str:\n            return  # no value to print\n\n        num_iterations = engine.state.epoch_length\n        current_iteration = engine.state.iteration\n        if num_iterations is not None:\n            current_iteration = (current_iteration - 1) % num_iterations + 1\n        current_epoch = engine.state.epoch\n        num_epochs = engine.state.max_epochs\n\n        base_str = f\"Epoch: {current_epoch}/{num_epochs}, Iter: {current_iteration}/{num_iterations} --\"\n\n        self.logger.info(\" \".join([base_str, out_str]))\n"
  },
  {
    "path": "monai/handlers/surface_distance.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nfrom monai.handlers.ignite_metric import IgniteMetricHandler\nfrom monai.metrics import SurfaceDistanceMetric\nfrom monai.utils import MetricReduction\n\n\nclass SurfaceDistance(IgniteMetricHandler):\n    \"\"\"\n    Computes surface distance from full size Tensor and collects average over batch, class-channels, iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = False,\n        symmetric: bool = False,\n        distance_metric: str = \"euclidean\",\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        output_transform: Callable = lambda x: x,\n        save_details: bool = True,\n    ) -> None:\n        \"\"\"\n\n        Args:\n            include_background: whether to include distance computation on the first channel of the predicted output.\n                Defaults to ``False``.\n            symmetric: whether to calculate the symmetric average surface distance between\n                `seg_pred` and `seg_gt`. Defaults to ``False``.\n            distance_metric: : [``\"euclidean\"``, ``\"chessboard\"``, ``\"taxicab\"``]\n                the metric used to compute surface distance. 
Defaults to ``\"euclidean\"``.\n            reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n            output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then\n                construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or\n                lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            save_details: whether to save metric computation details per image, for example: surface distance\n                of every image. default to True, will save to `engine.state.metric_details` dict with the metric name as key.\n\n        \"\"\"\n        metric_fn = SurfaceDistanceMetric(\n            include_background=include_background,\n            symmetric=symmetric,\n            distance_metric=distance_metric,\n            reduction=reduction,\n        )\n        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)\n"
  },
  {
    "path": "monai/handlers/tensorboard_handlers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable, Sequence\nfrom typing import TYPE_CHECKING, Any\n\nimport numpy as np\nimport torch\n\nfrom monai.utils import IgniteInfo, is_scalar, min_version, optional_import\nfrom monai.visualize import plot_2d_or_3d_image\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\n\nif TYPE_CHECKING:\n    from ignite.engine import Engine\n    from tensorboardX import SummaryWriter as SummaryWriterX\n    from torch.utils.tensorboard import SummaryWriter\nelse:\n    Engine, _ = optional_import(\n        \"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\", as_type=\"decorator\"\n    )\n    SummaryWriter, _ = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n    SummaryWriterX, _ = optional_import(\"tensorboardX\", name=\"SummaryWriter\")\n\nDEFAULT_TAG = \"Loss\"\n\n\nclass TensorBoardHandler:\n    \"\"\"\n    Base class for the handlers to write data into TensorBoard.\n\n    Args:\n        summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter,\n            default to create a new TensorBoard writer.\n        log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`.\n\n    \"\"\"\n\n    def __init__(self, summary_writer: SummaryWriter | SummaryWriterX 
| None = None, log_dir: str = \"./runs\"):\n        if summary_writer is None:\n            self._writer = SummaryWriter(log_dir=log_dir)\n            self.internal_writer = True\n        else:\n            self._writer = summary_writer\n            self.internal_writer = False\n\n    def attach(self, engine: Engine) -> None:\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    def close(self):\n        \"\"\"\n        Close the summary writer if created in this TensorBoard handler.\n\n        \"\"\"\n        if self.internal_writer:\n            self._writer.close()\n\n\nclass TensorBoardStatsHandler(TensorBoardHandler):\n    \"\"\"\n    TensorBoardStatsHandler defines a set of Ignite Event-handlers for all the TensorBoard logics.\n    It can be used for any Ignite Engine(trainer, validator and evaluator).\n    And it can support both epoch level and iteration level with pre-defined TensorBoard event writer.\n    The expected data source is Ignite ``engine.state.output`` and ``engine.state.metrics``.\n\n    Default behaviors:\n        - When EPOCH_COMPLETED, write each dictionary item in\n          ``engine.state.metrics`` to TensorBoard.\n        - When ITERATION_COMPLETED, write each dictionary item in\n          ``self.output_transform(engine.state.output)`` to TensorBoard.\n\n    Usage example is available in the tutorial:\n    https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        summary_writer: SummaryWriter | SummaryWriterX | None = None,\n        log_dir: str = \"./runs\",\n        iteration_log: bool | Callable[[Engine, int], bool] | int = True,\n        epoch_log: bool | Callable[[Engine, int], bool] | int = True,\n        epoch_event_writer: Callable[[Engine, Any], Any] | None = None,\n        iteration_event_writer: Callable[[Engine, Any], Any] | None = None,\n        output_transform: 
Callable = lambda x: x[0],\n        global_epoch_transform: Callable = lambda x: x,\n        state_attributes: Sequence[str] | None = None,\n        tag_name: str = DEFAULT_TAG,\n    ) -> None:\n        \"\"\"\n        Args:\n            summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter,\n                default to create a new TensorBoard writer.\n            log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`.\n            iteration_log: whether to write data to TensorBoard when iteration completed, default to `True`.\n                ``iteration_log`` can be also a function or int. If it is an int, it will be interpreted as the iteration interval\n                at which the iteration_event_writer is called. If it is a function, it will be interpreted as an event filter\n                (see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).\n                Event filter function accepts as input engine and event value (iteration) and should return True/False.\n            epoch_log: whether to write data to TensorBoard when epoch completed, default to `True`.\n                ``epoch_log`` can be also a function or int. If it is an int, it will be interpreted as the epoch interval\n                at which the epoch_event_writer is called. 
If it is a function, it will be interpreted as an event filter.\n                See ``iteration_log`` argument for more details.\n            epoch_event_writer: customized callable TensorBoard writer for epoch level.\n                Must accept parameter \"engine\" and \"summary_writer\", use default event writer if None.\n            iteration_event_writer: customized callable TensorBoard writer for iteration level.\n                Must accept parameter \"engine\" and \"summary_writer\", use default event writer if None.\n            output_transform: a callable that is used to transform the\n                ``ignite.engine.state.output`` into a scalar to plot, or a dictionary of {key: scalar}.\n                In the latter case, the output string will be formatted as key: value.\n                By default this value plotting happens when every iteration completed.\n                The default behavior is to print loss from output[0] as output is a decollated list\n                and we replicated loss value for every item of the decollated list.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            global_epoch_transform: a callable that is used to customize global epoch number.\n                For example, in evaluation, the evaluator engine might want to use trainer engines epoch number\n                when plotting epoch vs metric curves.\n            state_attributes: expected attributes from `engine.state`, if provided, will extract them\n                when epoch completed.\n            tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``.\n        \"\"\"\n\n        super().__init__(summary_writer=summary_writer, log_dir=log_dir)\n        
self.iteration_log = iteration_log\n        self.epoch_log = epoch_log\n        self.epoch_event_writer = epoch_event_writer\n        self.iteration_event_writer = iteration_event_writer\n        self.output_transform = output_transform\n        self.global_epoch_transform = global_epoch_transform\n        self.state_attributes = state_attributes\n        self.tag_name = tag_name\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Register a set of Ignite Event-Handlers to a specified Ignite engine.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):\n            event = Events.ITERATION_COMPLETED\n            if callable(self.iteration_log):  # substitute event with new one using filter callable\n                event = event(event_filter=self.iteration_log)\n            elif self.iteration_log > 1:\n                event = event(every=self.iteration_log)\n            engine.add_event_handler(event, self.iteration_completed)\n        if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):\n            event = Events.EPOCH_COMPLETED\n            if callable(self.epoch_log):  # substitute event with new one using filter callable\n                event = event(event_filter=self.epoch_log)\n            elif self.epoch_log > 1:\n                event = event(every=self.epoch_log)\n            engine.add_event_handler(event, self.epoch_completed)\n\n    def epoch_completed(self, engine: Engine) -> None:\n        \"\"\"\n        Handler for train or validation/evaluation epoch completed Event.\n        Write epoch level events, default values are from Ignite `engine.state.metrics` dict.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        if 
self.epoch_event_writer is not None:\n            self.epoch_event_writer(engine, self._writer)\n        else:\n            self._default_epoch_writer(engine, self._writer)\n\n    def iteration_completed(self, engine: Engine) -> None:\n        \"\"\"\n        Handler for train or validation/evaluation iteration completed Event.\n        Write iteration level events, default values are from Ignite `engine.state.output`.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        \"\"\"\n        if self.iteration_event_writer is not None:\n            self.iteration_event_writer(engine, self._writer)\n        else:\n            self._default_iteration_writer(engine, self._writer)\n\n    def _write_scalar(\n        self, _engine: Engine, writer: SummaryWriter | SummaryWriterX, tag: str, value: Any, step: int\n    ) -> None:\n        \"\"\"\n        Write scale value into TensorBoard.\n        Default to call `SummaryWriter.add_scalar()`.\n\n        Args:\n            _engine: Ignite Engine, unused argument.\n            writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler.\n            tag: tag name in the TensorBoard.\n            value: value of the scalar data for current step.\n            step: index of current step.\n\n        \"\"\"\n        writer.add_scalar(tag, value, step)\n\n    def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter | SummaryWriterX) -> None:\n        \"\"\"\n        Execute epoch level event write operation.\n        Default to write the values from Ignite `engine.state.metrics` dict and\n        write the values of specified attributes of `engine.state`.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n            writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler.\n\n        \"\"\"\n        current_epoch = self.global_epoch_transform(engine.state.epoch)\n        
summary_dict = engine.state.metrics\n        for name, value in summary_dict.items():\n            if is_scalar(value):\n                self._write_scalar(engine, writer, name, value, current_epoch)\n\n        if self.state_attributes is not None:\n            for attr in self.state_attributes:\n                self._write_scalar(engine, writer, attr, getattr(engine.state, attr, None), current_epoch)\n        writer.flush()\n\n    def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter | SummaryWriterX) -> None:\n        \"\"\"\n        Execute iteration level event write operation based on Ignite `engine.state.output` data.\n        Extract the values from `self.output_transform(engine.state.output)`.\n        Since `engine.state.output` is a decollated list and we replicated the loss value for every item\n        of the decollated list, the default behavior is to track the loss from `output[0]`.\n\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n            writer: TensorBoard  or TensorBoardX writer, passed or created in TensorBoardHandler.\n\n        \"\"\"\n        loss = self.output_transform(engine.state.output)\n        if loss is None:\n            return  # do nothing if output is empty\n        if isinstance(loss, dict):\n            for name in sorted(loss):\n                value = loss[name]\n                if not is_scalar(value):\n                    warnings.warn(\n                        \"ignoring non-scalar output in TensorBoardStatsHandler,\"\n                        \" make sure `output_transform(engine.state.output)` returns\"\n                        \" a scalar or dictionary of key and scalar pairs to avoid this warning.\"\n                        f\" {name}:{type(value)}\"\n                    )\n                    continue  # not plot multi dimensional output\n                self._write_scalar(\n                    _engine=engine,\n                    writer=writer,\n   
                 tag=name,\n                    value=value.item() if isinstance(value, torch.Tensor) else value,\n                    step=engine.state.iteration,\n                )\n        elif is_scalar(loss):  # not printing multi dimensional output\n            self._write_scalar(\n                _engine=engine,\n                writer=writer,\n                tag=self.tag_name,\n                value=loss.item() if isinstance(loss, torch.Tensor) else loss,\n                step=engine.state.iteration,\n            )\n        else:\n            warnings.warn(\n                \"ignoring non-scalar output in TensorBoardStatsHandler,\"\n                \" make sure `output_transform(engine.state.output)` returns\"\n                \" a scalar or a dictionary of key and scalar pairs to avoid this warning.\"\n                f\" {type(loss)}\"\n            )\n        writer.flush()\n\n\nclass TensorBoardImageHandler(TensorBoardHandler):\n    \"\"\"\n    TensorBoardImageHandler is an Ignite Event handler that can visualize images, labels and outputs as 2D/3D images.\n    2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch,\n    for 3D to ND output (shape in Batch, channel, H, W, D) input, each of ``self.max_channels`` number of images'\n    last three dimensions will be shown as animated GIF along the last axis (typically Depth).\n    And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video.\n\n    It can be used for any Ignite Engine (trainer, validator and evaluator).\n    User can easily add it to engine for any expected Event, for example: ``EPOCH_COMPLETED``,\n    ``ITERATION_COMPLETED``. 
The expected data source is ignite's ``engine.state.batch`` and ``engine.state.output``.\n\n    Default behavior:\n        - Show y_pred as images (GIF for 3D) on TensorBoard when Event triggered,\n        - Need to use ``batch_transform`` and ``output_transform`` to specify\n          how many images to show and show which channel.\n        - Expects ``batch_transform(engine.state.batch)`` to return data\n          format: (image[N, channel, ...], label[N, channel, ...]).\n        - Expects ``output_transform(engine.state.output)`` to return a torch\n          tensor in format (y_pred[N, channel, ...], loss).\n\n    Usage example is available in the tutorial:\n    https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        summary_writer: SummaryWriter | SummaryWriterX | None = None,\n        log_dir: str = \"./runs\",\n        interval: int = 1,\n        epoch_level: bool = True,\n        batch_transform: Callable = lambda x: x,\n        output_transform: Callable = lambda x: x,\n        global_iter_transform: Callable = lambda x: x,\n        index: int = 0,\n        max_channels: int = 1,\n        frame_dim: int = -3,\n        max_frames: int = 64,\n    ) -> None:\n        \"\"\"\n        Args:\n            summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter,\n                default to create a new TensorBoard writer.\n            log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`.\n            interval: plot content from engine.state every N epochs or every N iterations, default is 1.\n            epoch_level: plot content from engine.state every N epochs or N iterations. 
`True` is epoch level,\n                `False` is iteration level.\n            batch_transform: a callable that is used to extract `image` and `label` from `ignite.engine.state.batch`,\n                then construct `(image, label)` pair. for example: if `ignite.engine.state.batch` is `{\"image\": xxx,\n                \"label\": xxx, \"other\": xxx}`, `batch_transform` can be `lambda x: (x[\"image\"], x[\"label\"])`.\n                will use the result to plot image from `result[0][index]` and plot label from `result[1][index]`.\n                `engine.state` and `batch_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            output_transform: a callable that is used to extract the `predictions` data from\n                `ignite.engine.state.output`, will use the result to plot output from `result[index]`.\n                `engine.state` and `output_transform` inherit from the ignite concept:\n                https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:\n                https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.\n            global_iter_transform: a callable that is used to customize global step number for TensorBoard.\n                For example, in evaluation, the evaluator engine needs to know current epoch from trainer.\n            index: plot which element in a data batch, default is the first element.\n            max_channels: number of channels to plot.\n            frame_dim: if plotting 3D image as GIF, specify the dimension used as frames,\n                expect input data shape as `NCHWD`, default to `-3` (the first spatial dim)\n            max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to 
`max_frames`.\n        \"\"\"\n        super().__init__(summary_writer=summary_writer, log_dir=log_dir)\n        self.interval = interval\n        self.epoch_level = epoch_level\n        self.batch_transform = batch_transform\n        self.output_transform = output_transform\n        self.global_iter_transform = global_iter_transform\n        self.index = index\n        self.frame_dim = frame_dim\n        self.max_frames = max_frames\n        self.max_channels = max_channels\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if self.epoch_level:\n            engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.interval), self)\n        else:\n            engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self)\n\n    def __call__(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n\n        Raises:\n            TypeError: When ``output_transform(engine.state.output)[0]`` type is not in\n                ``Optional[Union[numpy.ndarray, torch.Tensor]]``.\n            TypeError: When ``batch_transform(engine.state.batch)[1]`` type is not in\n                ``Optional[Union[numpy.ndarray, torch.Tensor]]``.\n            TypeError: When ``output_transform(engine.state.output)`` type is not in\n                ``Optional[Union[numpy.ndarray, torch.Tensor]]``.\n\n        \"\"\"\n        step = self.global_iter_transform(engine.state.epoch if self.epoch_level else engine.state.iteration)\n        show_images = self.batch_transform(engine.state.batch)[0][self.index]\n        if isinstance(show_images, torch.Tensor):\n            show_images = show_images.detach().cpu().numpy()\n        if show_images is not None:\n            if not isinstance(show_images, np.ndarray):\n                raise TypeError(\n                    
\"output_transform(engine.state.output)[0] must be None or one of \"\n                    f\"(numpy.ndarray, torch.Tensor) but is {type(show_images).__name__}.\"\n                )\n            plot_2d_or_3d_image(\n                # add batch dim and plot the first item\n                data=show_images[None],\n                step=step,\n                writer=self._writer,\n                index=0,\n                max_channels=self.max_channels,\n                frame_dim=self.frame_dim,\n                max_frames=self.max_frames,\n                tag=\"input_0\",\n            )\n\n        show_labels = self.batch_transform(engine.state.batch)[1][self.index]\n        if isinstance(show_labels, torch.Tensor):\n            show_labels = show_labels.detach().cpu().numpy()\n        if show_labels is not None:\n            if not isinstance(show_labels, np.ndarray):\n                raise TypeError(\n                    \"batch_transform(engine.state.batch)[1] must be None or one of \"\n                    f\"(numpy.ndarray, torch.Tensor) but is {type(show_labels).__name__}.\"\n                )\n            plot_2d_or_3d_image(\n                data=show_labels[None],\n                step=step,\n                writer=self._writer,\n                index=0,\n                max_channels=self.max_channels,\n                frame_dim=self.frame_dim,\n                max_frames=self.max_frames,\n                tag=\"input_1\",\n            )\n\n        show_outputs = self.output_transform(engine.state.output)[self.index]\n        if isinstance(show_outputs, torch.Tensor):\n            show_outputs = show_outputs.detach().cpu().numpy()\n        if show_outputs is not None:\n            if not isinstance(show_outputs, np.ndarray):\n                raise TypeError(\n                    \"output_transform(engine.state.output) must be None or one of \"\n                    f\"(numpy.ndarray, torch.Tensor) but is {type(show_outputs).__name__}.\"\n                )\n      
      plot_2d_or_3d_image(\n                data=show_outputs[None],\n                step=step,\n                writer=self._writer,\n                index=0,\n                max_channels=self.max_channels,\n                frame_dim=self.frame_dim,\n                max_frames=self.max_frames,\n                tag=\"output\",\n            )\n\n        self._writer.flush()\n"
  },
  {
    "path": "monai/handlers/trt_handler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom monai.networks import trt_compile\nfrom monai.utils import IgniteInfo, min_version, optional_import\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n\n\nclass TrtHandler:\n    \"\"\"\n    TrtHandler acts as an Ignite handler to apply TRT acceleration to the model.\n    Usage example::\n        handler = TrtHandler(model=model, base_path=\"/test/checkpoint.pt\", args={\"precision\": \"fp16\"})\n        handler.attach(engine)\n        engine.run()\n    \"\"\"\n\n    def __init__(self, model, base_path, args=None, submodule=None):\n        \"\"\"\n        Args:\n            base_path: TRT path basename. TRT plan(s) saved to \"base_path[.submodule].plan\"\n            args: passed to trt_compile(). See trt_compile() for details.\n            submodule : Hierarchical ids of submodules to convert, e.g. 
'image_decoder.decoder'\n        \"\"\"\n        self.model = model\n        self.base_path = base_path\n        self.args = args\n        self.submodule = submodule\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        self.logger = engine.logger\n        engine.add_event_handler(Events.STARTED, self)\n\n    def __call__(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        trt_compile(self.model, self.base_path, args=self.args, submodule=self.submodule, logger=self.logger)\n"
  },
  {
    "path": "monai/handlers/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nfrom collections import OrderedDict\nfrom collections.abc import Callable, Sequence\nfrom typing import TYPE_CHECKING, Any\n\nimport numpy as np\nimport torch\n\nfrom monai.config import KeysCollection, PathLike\nfrom monai.utils import IgniteInfo, ensure_tuple, look_up_option, min_version, optional_import\n\nidist, _ = optional_import(\"ignite\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"distributed\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n\n__all__ = [\"stopping_fn_from_metric\", \"stopping_fn_from_loss\", \"write_metrics_reports\", \"from_engine\"]\n\n\ndef stopping_fn_from_metric(metric_name: str) -> Callable[[Engine], Any]:\n    \"\"\"\n    Returns a stopping function for ignite.handlers.EarlyStopping using the given metric name.\n    \"\"\"\n\n    def stopping_fn(engine: Engine) -> Any:\n        return engine.state.metrics[metric_name]\n\n    return stopping_fn\n\n\ndef stopping_fn_from_loss() -> Callable[[Engine], Any]:\n    \"\"\"\n    Returns a stopping function for ignite.handlers.EarlyStopping using the loss value.\n    \"\"\"\n\n    def stopping_fn(engine: Engine) -> Any:\n        return -engine.state.output  # type:ignore\n\n    return stopping_fn\n\n\ndef 
write_metrics_reports(\n    save_dir: PathLike,\n    images: Sequence[str] | None,\n    metrics: dict[str, torch.Tensor | np.ndarray] | None,\n    metric_details: dict[str, torch.Tensor | np.ndarray] | None,\n    summary_ops: str | Sequence[str] | None,\n    deli: str = \",\",\n    output_type: str = \"csv\",\n    class_labels: list[str] | None = None,\n) -> None:\n    \"\"\"\n    Utility function to write the metrics into files, contains 3 parts:\n    1. if `metrics` dict is not None, write overall metrics into file, every line is a metric name and value pair.\n    2. if `metric_details` dict is not None,  write raw metric data of every image into file, every line for 1 image.\n    3. if `summary_ops` is not None, compute summary based on operations on `metric_details` and write to file.\n\n    Args:\n        save_dir: directory to save all the metrics reports.\n        images: name or path of every input image corresponding to the metric_details data.\n            if None, will use index number as the filename of every input image.\n        metrics: a dictionary of (metric name, metric value) pairs.\n        metric_details: a dictionary of (metric name, metric raw values) pairs, usually, it comes from metrics\n            computation, for example, the raw value can be the mean_dice of every channel of every input image.\n        summary_ops: expected computation operations to generate the summary report.\n            it can be: None, \"*\" or list of strings, default to None.\n            None - don't generate summary report for every expected metric_details.\n            \"*\" - generate summary report for every metric_details with all the supported operations.\n            list of strings - generate summary report for every metric_details with specified operations, they\n            should be within list: [\"mean\", \"median\", \"max\", \"min\", \"<int>percentile\", \"std\", \"notnans\"].\n            the number in \"<int>percentile\" should be [0, 100], like: 
\"15percentile\". default: \"90percentile\".\n            for more details, please check: https://numpy.org/doc/stable/reference/generated/numpy.nanpercentile.html.\n            note that: for the overall summary, it computes `nanmean` of all classes for each image first,\n            then compute summary. example of the generated summary report::\n\n                class    mean    median    max    5percentile 95percentile  notnans\n                class0  6.0000   6.0000   7.0000   5.1000      6.9000       2.0000\n                class1  6.0000   6.0000   6.0000   6.0000      6.0000       1.0000\n                mean    6.2500   6.2500   7.0000   5.5750      6.9250       2.0000\n\n        deli: the delimiter character in the saved file, default to \",\" as the default output type is `csv`.\n            to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.\n        output_type: expected output file type, supported types: [\"csv\"], default to \"csv\".\n        class_labels: list of class names used to name the classes in the output report, if None,\n            \"class0\", ..., \"classn\" are used, default to None.\n\n    \"\"\"\n    if output_type.lower() != \"csv\":\n        raise ValueError(f\"unsupported output type: {output_type}.\")\n\n    if not os.path.exists(save_dir):\n        os.makedirs(save_dir)\n\n    if metrics is not None and len(metrics) > 0:\n        with open(os.path.join(save_dir, \"metrics.csv\"), \"w\") as f:\n            for k, v in metrics.items():\n                f.write(f\"{k}{deli}{str(v)}\\n\")\n    if metric_details is not None and len(metric_details) > 0:\n        for k, v in metric_details.items():\n            if isinstance(v, torch.Tensor):\n                v = v.cpu().numpy()\n            if v.ndim == 0:\n                # reshape to [1, 1] if no batch and class dims\n                v = v.reshape((1, 1))\n            elif v.ndim == 1:\n                # reshape to [N, 1] if no class dim\n       
         v = v.reshape((-1, 1))\n\n            # add the average value of all classes to v\n            if class_labels is None:\n                class_labels = [\"class\" + str(i) for i in range(v.shape[1])]\n            else:\n                class_labels = [str(i) for i in class_labels]  # ensure to have a list of str\n\n            class_labels += [\"mean\"]\n            v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1)\n\n            with open(os.path.join(save_dir, f\"{k}_raw.csv\"), \"w\") as f:\n                f.write(f\"filename{deli}{deli.join(class_labels)}\\n\")\n                for i, b in enumerate(v):\n                    f.write(\n                        f\"{images[i] if images is not None else str(i)}{deli}\"\n                        f\"{deli.join([f'{c:.4f}' if isinstance(c, (int, float)) else str(c) for c in b])}\\n\"\n                    )\n\n            if summary_ops is not None:\n                supported_ops = OrderedDict(\n                    {\n                        \"mean\": np.nanmean,\n                        \"median\": np.nanmedian,\n                        \"max\": np.nanmax,\n                        \"min\": np.nanmin,\n                        \"90percentile\": lambda x: np.nanpercentile(x[0], x[1]),\n                        \"std\": np.nanstd,\n                        \"notnans\": lambda x: (~np.isnan(x)).sum(),\n                    }\n                )\n                ops = ensure_tuple(summary_ops)\n                if \"*\" in ops:\n                    ops = tuple(supported_ops.keys())\n\n                def _compute_op(op: str, d: np.ndarray) -> Any:\n                    if not op.endswith(\"percentile\"):\n                        c_op = look_up_option(op, supported_ops)\n                        return c_op(d)\n\n                    threshold = int(op.split(\"percentile\")[0])\n                    return supported_ops[\"90percentile\"]((d, threshold))  # type: ignore\n\n                with 
open(os.path.join(save_dir, f\"{k}_summary.csv\"), \"w\") as f:\n                    f.write(f\"class{deli}{deli.join(ops)}\\n\")\n                    for i, c in enumerate(np.transpose(v)):\n                        f.write(f\"{class_labels[i]}{deli}{deli.join([f'{_compute_op(k, c):.4f}' for k in ops])}\\n\")\n\n\ndef from_engine(keys: KeysCollection, first: bool = False) -> Callable:\n    \"\"\"\n    Utility function to simplify the `batch_transform` or `output_transform` args of ignite components\n    when handling dictionary or list of dictionaries(for example: `engine.state.batch` or `engine.state.output`).\n    Users only need to set the expected keys, then it will return a callable function to extract data from\n    dictionary and construct a tuple respectively.\n\n    If data is a list of dictionaries after decollating, extract expected keys and construct lists respectively,\n    for example, if data is `[{\"A\": 1, \"B\": 2}, {\"A\": 3, \"B\": 4}]`, from_engine([\"A\", \"B\"]): `([1, 3], [2, 4])`.\n\n    It can help avoid a complicated `lambda` function and make the arg of metrics more straight-forward.\n    For example, set the first key as the prediction and the second key as label to get the expected data\n    from `engine.state.output` for a metric::\n\n        from monai.handlers import MeanDice, from_engine\n\n        metric = MeanDice(\n            include_background=False,\n            output_transform=from_engine([\"pred\", \"label\"])\n        )\n\n    Args:\n        keys: specified keys to extract data from dictionary or decollated list of dictionaries.\n        first: whether only extract specified keys from the first item if input data is a list of dictionaries,\n            it's used to extract the scalar data which doesn't have batch dim and was replicated into every\n            dictionary when decollating, like `loss`, etc.\n\n\n    \"\"\"\n    _keys = ensure_tuple(keys)\n\n    def _wrapper(data):\n        if isinstance(data, dict):\n       
     return tuple(data[k] for k in _keys)\n        if isinstance(data, list) and isinstance(data[0], dict):\n            # if data is a list of dictionaries, extract expected keys and construct lists,\n            # if `first=True`, only extract keys from the first item of the list\n            ret = [data[0][k] if first else [i[k] for i in data] for k in _keys]\n            return tuple(ret) if len(ret) > 1 else ret[0]\n\n    return _wrapper\n\n\ndef ignore_data(x: Any) -> None:\n    \"\"\"\n    Always return `None` for any input data.\n    A typical usage is to avoid logging the engine output of every iteration during evaluation.\n\n    \"\"\"\n    return None\n"
  },
  {
    "path": "monai/handlers/validation_handler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom monai.engines.evaluator import Evaluator\nfrom monai.utils import IgniteInfo, min_version, optional_import\n\nEvents, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\nif TYPE_CHECKING:\n    from ignite.engine import Engine\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n\n\nclass ValidationHandler:\n    \"\"\"\n    Attach validator to the trainer engine in Ignite.\n    It can support to execute validation every N epochs or every N iterations.\n\n    \"\"\"\n\n    def __init__(\n        self, interval: int, validator: Evaluator | None = None, epoch_level: bool = True, exec_at_start: bool = False\n    ) -> None:\n        \"\"\"\n        Args:\n            interval: do validation every N epochs or every N iterations during training.\n            validator: run the validator when trigger validation, suppose to be Evaluator.\n                if None, should call `set_validator()` before training.\n            epoch_level: execute validation every N epochs or N iterations.\n                `True` is epoch level, `False` is iteration level.\n            exec_at_start: whether to execute a validation first when starting the training.\n                default to `False`. 
It can be useful especially for some transfer-learning cases\n                to validate the initial model before training.\n\n        Raises:\n            TypeError: When ``validator`` is not a ``monai.engines.evaluator.Evaluator``.\n\n        \"\"\"\n        if validator is not None and not isinstance(validator, Evaluator):\n            raise TypeError(f\"validator must be a monai.engines.evaluator.Evaluator but is {type(validator).__name__}.\")\n        self.validator = validator\n        self.interval = interval\n        self.epoch_level = epoch_level\n        self.exec_at_start = exec_at_start\n\n    def set_validator(self, validator: Evaluator) -> None:\n        \"\"\"\n        Set validator if not setting in the __init__().\n        \"\"\"\n        if not isinstance(validator, Evaluator):\n            raise TypeError(f\"validator must be a monai.engines.evaluator.Evaluator but is {type(validator).__name__}.\")\n        self.validator = validator\n\n    def attach(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if self.epoch_level:\n            engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.interval), self)\n        else:\n            engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.interval), self)\n        if self.exec_at_start:\n            engine.add_event_handler(Events.STARTED, self)\n\n    def __call__(self, engine: Engine) -> None:\n        \"\"\"\n        Args:\n            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n        \"\"\"\n        if self.validator is None:\n            raise RuntimeError(\"please set validator in __init__() or call `set_validator()` before training.\")\n        self.validator.run(engine.state.epoch)\n"
  },
  {
    "path": "monai/inferers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .inferer import (\n    ControlNetDiffusionInferer,\n    ControlNetLatentDiffusionInferer,\n    DiffusionInferer,\n    Inferer,\n    LatentDiffusionInferer,\n    PatchInferer,\n    SaliencyInferer,\n    SimpleInferer,\n    SliceInferer,\n    SlidingWindowInferer,\n    SlidingWindowInfererAdapt,\n    VQVAETransformerInferer,\n)\nfrom .merger import AvgMerger, Merger, ZarrAvgMerger\nfrom .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter\nfrom .utils import sliding_window_inference\n"
  },
  {
    "path": "monai/inferers/inferer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\nimport warnings\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Callable, Iterable, Iterator, Mapping, Sequence\nfrom functools import partial\nfrom pydoc import locate\nfrom typing import Any\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.apps.utils import get_logger\nfrom monai.data import decollate_batch\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.thread_buffer import ThreadBuffer\nfrom monai.inferers.merger import AvgMerger, Merger\nfrom monai.inferers.splitter import Splitter\nfrom monai.inferers.utils import compute_importance_map, sliding_window_inference\nfrom monai.networks.nets import (\n    VQVAE,\n    AutoencoderKL,\n    ControlNet,\n    DecoderOnlyTransformer,\n    DiffusionModelUNet,\n    SPADEAutoencoderKL,\n    SPADEDiffusionModelUNet,\n)\nfrom monai.networks.schedulers import RFlowScheduler, Scheduler\nfrom monai.transforms import CenterSpatialCrop, SpatialPad\nfrom monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import\nfrom monai.visualize import CAM, GradCAM, GradCAMpp\n\ntqdm, has_tqdm = optional_import(\"tqdm\", name=\"tqdm\")\n\nlogger = get_logger(__name__)\n\n__all__ = [\n    \"Inferer\",\n    \"PatchInferer\",\n    \"SimpleInferer\",\n    \"SlidingWindowInferer\",\n    
\"SaliencyInferer\",\n    \"SliceInferer\",\n    \"SlidingWindowInfererAdapt\",\n]\n\n\nclass Inferer(ABC):\n    \"\"\"\n    A base class for model inference.\n    Extend this class to support operations during inference, e.g. a sliding window method.\n\n    Example code::\n\n        device = torch.device(\"cuda:0\")\n        transform = Compose([ToTensor(), LoadImage(image_only=True)])\n        data = transform(img_path).to(device)\n        model = UNet(...).to(device)\n        inferer = SlidingWindowInferer(...)\n\n        model.eval()\n        with torch.no_grad():\n            pred = inferer(inputs=data, network=model)\n        ...\n\n    \"\"\"\n\n    @abstractmethod\n    def __call__(self, inputs: torch.Tensor, network: Callable, *args: Any, **kwargs: Any) -> Any:\n        \"\"\"\n        Run inference on `inputs` with the `network` model.\n\n        Args:\n            inputs: input of the model inference.\n            network: model for inference.\n            args: optional args to be passed to ``network``.\n            kwargs: optional keyword args to be passed to ``network``.\n\n        Raises:\n            NotImplementedError: When the subclass does not override this method.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n\nclass PatchInferer(Inferer):\n    \"\"\"\n    Inference on patches instead of the whole image based on Splitter and Merger.\n    This splits the input image into patches and then merge the resulted patches.\n\n    Args:\n        splitter: a `Splitter` object that split the inputs into patches. 
Defaults to None.\n            If not provided or None, the inputs are considered to be already split into patches.\n            In this case, the output `merged_shape` and the optional `cropped_shape` cannot be inferred\n            and should be explicitly provided.\n        merger_cls: a `Merger` subclass that can be instantiated to merges patch outputs.\n            It can also be a string that matches the name of a class inherited from `Merger` class.\n            Defaults to `AvgMerger`.\n        batch_size: batch size for patches. If the input tensor is already batched [BxCxWxH],\n            this adds additional batching [(Bp*B)xCxWpxHp] for inference on patches.\n            Defaults to 1.\n        preprocessing: a callable that process patches before the being fed to the network.\n            Defaults to None.\n        postprocessing: a callable that process the output of the network.\n            Defaults to None.\n        output_keys: if the network output is a dictionary, this defines the keys of\n            the output dictionary to be used for merging.\n            Defaults to None, where all the keys are used.\n        match_spatial_shape: whether to crop the output to match the input shape. Defaults to True.\n        buffer_size: number of patches to be held in the buffer with a separate thread for batch sampling. 
Defaults to 0.\n        merger_kwargs: arguments to be passed to `merger_cls` for instantiation.\n            `merged_shape` is calculated automatically based on the input shape and\n            the output patch shape unless it is passed here.\n    \"\"\"\n\n    def __init__(\n        self,\n        splitter: Splitter | None = None,\n        merger_cls: type[Merger] | str = AvgMerger,\n        batch_size: int = 1,\n        preprocessing: Callable | None = None,\n        postprocessing: Callable | None = None,\n        output_keys: Sequence | None = None,\n        match_spatial_shape: bool = True,\n        buffer_size: int = 0,\n        **merger_kwargs: Any,\n    ) -> None:\n        Inferer.__init__(self)\n        # splitter\n        if not isinstance(splitter, (Splitter, type(None))):\n            if not isinstance(splitter, Splitter):\n                raise TypeError(\n                    f\"'splitter' should be a `Splitter` object that returns: \"\n                    \"an iterable of pairs of (patch, location) or a MetaTensor that has `PatchKeys.LOCATION` metadata).\"\n                    f\"{type(splitter)} is given.\"\n                )\n        self.splitter = splitter\n\n        # merger\n        if isinstance(merger_cls, str):\n            valid_merger_cls: type[Merger]\n            # search amongst implemented mergers in MONAI\n            valid_merger_cls, merger_found = optional_import(\"monai.inferers.merger\", name=merger_cls)\n            if not merger_found:\n                # try to locate the requested merger class (with dotted path)\n                valid_merger_cls = locate(merger_cls)  # type: ignore\n            if valid_merger_cls is None:\n                raise ValueError(f\"The requested `merger_cls` ['{merger_cls}'] does not exist.\")\n            merger_cls = valid_merger_cls\n        if not issubclass(merger_cls, Merger):\n            raise TypeError(f\"'merger' should be a subclass of `Merger`, {merger_cls} is given.\")\n        
self.merger_cls = merger_cls\n        self.merger_kwargs = merger_kwargs\n\n        # pre-processor (process patch before the network)\n        if preprocessing is not None and not callable(preprocessing):\n            raise TypeError(f\"'preprocessing' should be a callable object, {type(preprocessing)} is given.\")\n        self.preprocessing = preprocessing\n\n        # post-processor (process the output of the network)\n        if postprocessing is not None and not callable(postprocessing):\n            raise TypeError(f\"'postprocessing' should be a callable object, {type(postprocessing)} is given.\")\n        self.postprocessing = postprocessing\n\n        # batch size for patches\n        if batch_size < 1:\n            raise ValueError(f\"`batch_size` must be a positive number, {batch_size} is given.\")\n        self.batch_size = batch_size\n\n        # model output keys\n        self.output_keys = output_keys\n\n        # whether to crop the output to match the input shape\n        self.match_spatial_shape = match_spatial_shape\n\n        # buffer size for multithreaded batch sampling\n        self.buffer_size = buffer_size\n\n    def _batch_sampler(\n        self, patches: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor\n    ) -> Iterator[tuple[torch.Tensor, Sequence, int]]:\n        \"\"\"Generate batch of patches and locations\n\n        Args:\n            patches: a tensor or list of tensors\n\n        Yields:\n            A batch of patches (torch.Tensor or MetaTensor), a sequence of location tuples, and the batch size\n        \"\"\"\n        if isinstance(patches, MetaTensor):\n            total_size = len(patches)\n            for i in range(0, total_size, self.batch_size):\n                batch_size = min(self.batch_size, total_size - i)\n                yield patches[i : i + batch_size], patches[i : i + batch_size].meta[PatchKeys.LOCATION], batch_size  # type: ignore\n        else:\n            buffer: Iterable | ThreadBuffer\n          
  if self.buffer_size > 0:\n                # Use multi-threading to sample patches with a buffer\n                buffer = ThreadBuffer(patches, buffer_size=self.buffer_size, timeout=0.1)\n            else:\n                buffer = patches\n            patch_batch: list[Any] = [None] * self.batch_size\n            location_batch: list[Any] = [None] * self.batch_size\n            idx_in_batch = 0\n            for sample in buffer:\n                patch_batch[idx_in_batch] = sample[0]\n                location_batch[idx_in_batch] = sample[1]\n                idx_in_batch += 1\n                if idx_in_batch == self.batch_size:\n                    # concatenate batch of patches to create a tensor\n                    yield torch.cat(patch_batch), location_batch, idx_in_batch\n                    patch_batch = [None] * self.batch_size\n                    location_batch = [None] * self.batch_size\n                    idx_in_batch = 0\n            if idx_in_batch > 0:\n                # concatenate batch of patches to create a tensor\n                yield torch.cat(patch_batch[:idx_in_batch]), location_batch, idx_in_batch\n\n    def _ensure_tuple_outputs(self, outputs: Any) -> tuple:\n        if isinstance(outputs, dict):\n            if self.output_keys is None:\n                self.output_keys = list(outputs.keys())  # model's output keys\n            return tuple(outputs[k] for k in self.output_keys)\n        return ensure_tuple(outputs, wrap_array=True)\n\n    def _run_inference(self, network: Callable, patch: torch.Tensor, *args: Any, **kwargs: Any) -> tuple:\n        # pre-process\n        if self.preprocessing:\n            patch = self.preprocessing(patch)\n        # inference\n        outputs = network(patch, *args, **kwargs)\n        # post-process\n        if self.postprocessing:\n            outputs = self.postprocessing(outputs)\n        # ensure we have a tuple of model outputs to support multiple outputs\n        return 
self._ensure_tuple_outputs(outputs)\n\n    def _initialize_mergers(self, inputs, outputs, patches, batch_size):\n        in_patch = torch.chunk(patches, batch_size)[0]\n        mergers = []\n        ratios = []\n        for out_patch_batch in outputs:\n            out_patch = torch.chunk(out_patch_batch, batch_size)[0]\n            # calculate the ratio of input and output patch sizes\n            ratio = tuple(op / ip for ip, op in zip(in_patch.shape[2:], out_patch.shape[2:]))\n\n            # calculate merged_shape and cropped_shape\n            merger_kwargs = self.merger_kwargs.copy()\n            cropped_shape, merged_shape = self._get_merged_shapes(inputs, out_patch, ratio)\n            if \"merged_shape\" not in merger_kwargs:\n                merger_kwargs[\"merged_shape\"] = merged_shape\n                if merger_kwargs[\"merged_shape\"] is None:\n                    raise ValueError(\"`merged_shape` cannot be `None`.\")\n            if \"cropped_shape\" not in merger_kwargs:\n                merger_kwargs[\"cropped_shape\"] = cropped_shape\n\n            # initialize the merger\n            merger = self.merger_cls(**merger_kwargs)\n\n            # store mergers and input/output ratios\n            mergers.append(merger)\n            ratios.append(ratio)\n\n        return mergers, ratios\n\n    def _aggregate(self, outputs, locations, batch_size, mergers, ratios):\n        for output_patches, merger, ratio in zip(outputs, mergers, ratios):\n            # split batched output into individual patches and then aggregate\n            for in_loc, out_patch in zip(locations, torch.chunk(output_patches, batch_size)):\n                out_loc = [round(l * r) for l, r in zip(in_loc, ratio)]\n                merger.aggregate(out_patch, out_loc)\n\n    def _get_merged_shapes(self, inputs, out_patch, ratio):\n        \"\"\"Define the shape of merged tensors (non-padded and padded)\"\"\"\n        if self.splitter is None:\n            return None, None\n\n        # 
input spatial shapes\n        original_spatial_shape = self.splitter.get_input_shape(inputs)\n        padded_spatial_shape = self.splitter.get_padded_shape(inputs)\n\n        # output spatial shapes\n        output_spatial_shape = tuple(round(s * r) for s, r in zip(original_spatial_shape, ratio))\n        padded_output_spatial_shape = tuple(round(s * r) for s, r in zip(padded_spatial_shape, ratio))\n\n        # output shapes\n        cropped_shape = out_patch.shape[:2] + output_spatial_shape\n        merged_shape = out_patch.shape[:2] + padded_output_spatial_shape\n\n        if not self.match_spatial_shape:\n            cropped_shape = merged_shape\n\n        return cropped_shape, merged_shape\n\n    def __call__(\n        self,\n        inputs: torch.Tensor,\n        network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],\n        *args: Any,\n        **kwargs: Any,\n    ) -> Any:\n        \"\"\"\n        Args:\n            inputs: input data for inference, a torch.Tensor, representing an image or batch of images.\n                However if the data is already split, it can be fed by providing a list of tuple (patch, location),\n                or a MetaTensor that has metadata for `PatchKeys.LOCATION`. 
In both cases no splitter should be provided.\n            network: target model to execute inference.\n                supports callables such as ``lambda x: my_torch_model(x, additional_config)``\n            args: optional args to be passed to ``network``.\n            kwargs: optional keyword args to be passed to ``network``.\n            condition (torch.Tensor, optional): If provided via `**kwargs`,\n                this tensor must match the shape of `inputs` and will be sliced, patched, or windowed alongside the inputs.\n                The resulting segments will be passed to the model together with the corresponding input segments.\n\n        \"\"\"\n        # check if there is a conditioning signal\n        condition = kwargs.pop(\"condition\", None)\n        # shape check for condition\n        if condition is not None:\n            if isinstance(inputs, torch.Tensor) and isinstance(condition, torch.Tensor):\n                if condition.shape != inputs.shape:\n                    raise ValueError(\n                        f\"`condition` must match shape of `inputs` ({inputs.shape}), but got {condition.shape}\"\n                    )\n            elif isinstance(inputs, list) and isinstance(condition, list):\n                if len(inputs) != len(condition):\n                    raise ValueError(\n                        f\"Length of `condition` must match `inputs`. Got {len(inputs)} and {len(condition)}.\"\n                    )\n                for (in_patch, _), (cond_patch, _) in zip(inputs, condition):\n                    if cond_patch.shape != in_patch.shape:\n                        raise ValueError(\n                            \"Each `condition` patch must match the shape of the corresponding input patch. 
\"\n                            f\"Got {cond_patch.shape} and {in_patch.shape}.\"\n                        )\n            else:\n                raise ValueError(\n                    \"`condition` and `inputs` must be of the same type (both Tensor or both list of patches).\"\n                )\n\n        patches_locations: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor\n        if self.splitter is None:\n            # handle situations where the splitter is not provided\n            if isinstance(inputs, torch.Tensor):\n                if isinstance(inputs, MetaTensor):\n                    if PatchKeys.LOCATION not in inputs.meta:\n                        raise ValueError(\n                            \"`PatchKey.LOCATION` does not exists in `inputs.meta`. \"\n                            \"If the inputs are already split into patches, the location of patches needs to be \"\n                            \"provided as `PatchKey.LOCATION` metadata in a MetaTensor. \"\n                            \"If the input is not already split, please provide `splitter`.\"\n                        )\n                else:\n                    raise ValueError(\n                        \"`splitter` should be set if the input is not already split into patches. \"\n                        \"For inputs that are split, the location of patches needs to be provided as \"\n                        \"(image, location) pairs, or as `PatchKey.LOCATION` metadata in a MetaTensor. 
\"\n                        f\"The provided inputs type is {type(inputs)}.\"\n                    )\n            patches_locations = inputs\n            if condition is not None:\n                condition_locations = condition\n        else:\n            # apply splitter\n            patches_locations = self.splitter(inputs)\n            if condition is not None:\n                # apply splitter to condition\n                condition_locations = self.splitter(condition)\n\n        ratios: list[float] = []\n        mergers: list[Merger] = []\n        if condition is not None:\n            for (patches, locations, batch_size), (condition_patches, _, _) in zip(\n                self._batch_sampler(patches_locations), self._batch_sampler(condition_locations)\n            ):\n                # add patched condition to kwargs\n                kwargs[\"condition\"] = condition_patches\n                # run inference\n                outputs = self._run_inference(network, patches, *args, **kwargs)\n                # initialize the mergers\n                if not mergers:\n                    mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)\n                # aggregate outputs\n                self._aggregate(outputs, locations, batch_size, mergers, ratios)\n        else:\n            for patches, locations, batch_size in self._batch_sampler(patches_locations):\n                # run inference\n                outputs = self._run_inference(network, patches, *args, **kwargs)\n                # initialize the mergers\n                if not mergers:\n                    mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)\n                # aggregate outputs\n                self._aggregate(outputs, locations, batch_size, mergers, ratios)\n\n        # finalize the mergers and get the results\n        merged_outputs = [merger.finalize() for merger in mergers]\n\n        # return according to the model output\n     
   if self.output_keys:\n            return dict(zip(self.output_keys, merged_outputs))\n        if len(merged_outputs) == 1:\n            return merged_outputs[0]\n        return merged_outputs\n\n\nclass SimpleInferer(Inferer):\n    \"\"\"\n    SimpleInferer is the normal inference method that run model forward() directly.\n    Usage example can be found in the :py:class:`monai.inferers.Inferer` base class.\n\n    \"\"\"\n\n    def __init__(self) -> None:\n        Inferer.__init__(self)\n\n    def __call__(\n        self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any\n    ) -> torch.Tensor:\n        \"\"\"Unified callable function API of Inferers.\n\n        Args:\n            inputs: model input data for inference.\n            network: target model to execute inference.\n                supports callables such as ``lambda x: my_torch_model(x, additional_config)``\n            args: optional args to be passed to ``network``.\n            kwargs: optional keyword args to be passed to ``network``.\n\n        \"\"\"\n        return network(inputs, *args, **kwargs)\n\n\nclass SlidingWindowInferer(Inferer):\n    \"\"\"\n    Sliding window method for model inference,\n    with `sw_batch_size` windows for every model.forward().\n    Usage example can be found in the :py:class:`monai.inferers.Inferer` base class.\n\n    Args:\n        roi_size: the window size to execute SlidingWindow evaluation.\n            If it has non-positive components, the corresponding `inputs` size will be used.\n            if the components of the `roi_size` are non-positive values, the transform will use the\n            corresponding components of img size. 
For example, `roi_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        sw_batch_size: the batch size to run window slices.\n        overlap: Amount of overlap between scans along each spatial dimension, defaults to ``0.25``.\n        mode: {``\"constant\"``, ``\"gaussian\"``}\n            How to blend output of overlapping windows. Defaults to ``\"constant\"``.\n\n            - ``\"constant``\": gives equal weight to all predictions.\n            - ``\"gaussian``\": gives less weight to predictions on edges of windows.\n\n        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``\"gaussian\"``.\n            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.\n            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding\n            spatial dimensions.\n        padding_mode: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}\n            Padding mode when ``roi_size`` is larger than inputs. Defaults to ``\"constant\"``\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        cval: fill value for 'constant' padding mode. Default: 0\n        sw_device: device for the window data.\n            By default the device (and accordingly the memory) of the `inputs` is used.\n            Normally `sw_device` should be consistent with the device where `predictor` is defined.\n        device: device for the stitched output prediction.\n            By default the device (and accordingly the memory) of the `inputs` is used. If for example\n            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the\n            `inputs` and `roi_size`. 
Output is on the `device`.\n        progress: whether to print a tqdm progress bar.\n        cache_roi_weight_map: whether to precompute the ROI weight map.\n        cpu_thresh: when provided, dynamically switch to stitching on cpu (to save gpu memory)\n            when input image volume is larger than this threshold (in pixels/voxels).\n            Otherwise use ``\"device\"``. Thus, the output may end-up on either cpu or gpu.\n        buffer_steps: the number of sliding window iterations along the ``buffer_dim``\n            to be buffered on ``sw_device`` before writing to ``device``.\n            (Typically, ``sw_device`` is ``cuda`` and ``device`` is ``cpu``.)\n            default is None, no buffering. For the buffer dim, when spatial size is divisible by buffer_steps*roi_size,\n            (i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency.\n        buffer_dim: the spatial dimension along which the buffers are created.\n            0 indicates the first spatial dimension. Default is -1, the last spatial dimension.\n        with_coord: whether to pass the window coordinates to ``network``. 
Defaults to False.\n            If True, the ``network``'s 2nd input argument should accept the window coordinates.\n\n    Note:\n        ``sw_batch_size`` denotes the max number of windows per network inference iteration,\n        not the batch size of inputs.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        roi_size: Sequence[int] | int,\n        sw_batch_size: int = 1,\n        overlap: Sequence[float] | float = 0.25,\n        mode: BlendMode | str = BlendMode.CONSTANT,\n        sigma_scale: Sequence[float] | float = 0.125,\n        padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT,\n        cval: float = 0.0,\n        sw_device: torch.device | str | None = None,\n        device: torch.device | str | None = None,\n        progress: bool = False,\n        cache_roi_weight_map: bool = False,\n        cpu_thresh: int | None = None,\n        buffer_steps: int | None = None,\n        buffer_dim: int = -1,\n        with_coord: bool = False,\n    ) -> None:\n        super().__init__()\n        self.roi_size = roi_size\n        self.sw_batch_size = sw_batch_size\n        self.overlap = overlap\n        self.mode: BlendMode = BlendMode(mode)\n        self.sigma_scale = sigma_scale\n        self.padding_mode = padding_mode\n        self.cval = cval\n        self.sw_device = sw_device\n        self.device = device\n        self.progress = progress\n        self.cpu_thresh = cpu_thresh\n        self.buffer_steps = buffer_steps\n        self.buffer_dim = buffer_dim\n        self.with_coord = with_coord\n\n        # compute_importance_map takes long time when computing on cpu. 
We thus\n        # compute it once if it's static and then save it for future usage\n        self.roi_weight_map = None\n        try:\n            if cache_roi_weight_map and isinstance(roi_size, Sequence) and min(roi_size) > 0:  # non-dynamic roi size\n                if device is None:\n                    device = \"cpu\"\n                self.roi_weight_map = compute_importance_map(\n                    ensure_tuple(self.roi_size), mode=mode, sigma_scale=sigma_scale, device=device\n                )\n            if cache_roi_weight_map and self.roi_weight_map is None:\n                warnings.warn(\"cache_roi_weight_map=True, but cache is not created. (dynamic roi_size?)\")\n        except BaseException as e:\n            raise RuntimeError(\n                f\"roi size {self.roi_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\\n\"\n                \"Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'.\"\n            ) from e\n\n    def __call__(\n        self,\n        inputs: torch.Tensor,\n        network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],\n        *args: Any,\n        **kwargs: Any,\n    ) -> torch.Tensor | tuple[torch.Tensor, ...] 
| dict[Any, torch.Tensor]:\n        \"\"\"\n\n        Args:\n            inputs: model input data for inference.\n            network: target model to execute inference.\n                supports callables such as ``lambda x: my_torch_model(x, additional_config)``\n            args: optional args to be passed to ``network``.\n            kwargs: optional keyword args to be passed to ``network``.\n            condition (torch.Tensor, optional): If provided via `**kwargs`,\n                this tensor must match the shape of `inputs` and will be sliced, patched, or windowed alongside the inputs.\n                The resulting segments will be passed to the model together with the corresponding input segments.\n        \"\"\"\n        # shape check for condition\n        condition = kwargs.get(\"condition\", None)\n        if condition is not None and condition.shape != inputs.shape:\n            raise ValueError(f\"`condition` must match shape of `inputs` ({inputs.shape}), but got {condition.shape}\")\n\n        device = kwargs.pop(\"device\", self.device)\n        buffer_steps = kwargs.pop(\"buffer_steps\", self.buffer_steps)\n        buffer_dim = kwargs.pop(\"buffer_dim\", self.buffer_dim)\n\n        if device is None and self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh:\n            device = \"cpu\"  # stitch in cpu memory if image is too large\n\n        return sliding_window_inference(\n            inputs,\n            self.roi_size,\n            self.sw_batch_size,\n            network,\n            self.overlap,\n            self.mode,\n            self.sigma_scale,\n            self.padding_mode,\n            self.cval,\n            self.sw_device,\n            device,\n            self.progress,\n            self.roi_weight_map,\n            None,\n            buffer_steps,\n            buffer_dim,\n            self.with_coord,\n            *args,\n            **kwargs,\n        )\n\n\nclass 
SlidingWindowInfererAdapt(SlidingWindowInferer):\n    \"\"\"\n    SlidingWindowInfererAdapt extends SlidingWindowInferer to automatically switch to buffered and then to CPU stitching,\n    when OOM on GPU. It also records a size of such large images to automatically\n    try CPU stitching for the next large image of a similar size.  If the stitching 'device' input parameter is provided,\n    automatic adaptation won't be attempted, please keep the default option device = None for adaptive behavior.\n    Note: the output might be on CPU (even if the input was on GPU), if the GPU memory was not sufficient.\n\n    \"\"\"\n\n    def __call__(\n        self,\n        inputs: torch.Tensor,\n        network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],\n        *args: Any,\n        **kwargs: Any,\n    ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:\n        \"\"\"\n\n        Args:\n            inputs: model input data for inference.\n            network: target model to execute inference.\n                supports callables such as ``lambda x: my_torch_model(x, additional_config)``\n            args: optional args to be passed to ``network``.\n            kwargs: optional keyword args to be passed to ``network``.\n\n        \"\"\"\n\n        # if device is provided, use without any adaptations\n        if self.device is not None:\n            return super().__call__(inputs, network, *args, **kwargs)\n\n        skip_buffer = self.buffer_steps is not None and self.buffer_steps <= 0\n        cpu_cond = self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh\n        gpu_stitching = inputs.is_cuda and not cpu_cond\n        buffered_stitching = inputs.is_cuda and cpu_cond and not skip_buffer\n        buffer_steps = max(1, self.buffer_steps) if self.buffer_steps is not None else 1\n        buffer_dim = -1\n\n        sh = list(inputs.shape[2:])\n        max_dim = sh.index(max(sh))\n        if 
inputs.shape[max_dim + 2] / inputs.shape[-1] >= 2:\n            buffer_dim = max_dim\n\n        for _ in range(10):  # at most 10 trials\n            try:\n                return super().__call__(\n                    inputs,\n                    network,\n                    *args,\n                    device=inputs.device if gpu_stitching else torch.device(\"cpu\"),\n                    buffer_steps=buffer_steps if buffered_stitching else None,\n                    buffer_dim=buffer_dim,\n                    **kwargs,\n                )\n            except RuntimeError as e:\n                if not gpu_stitching and not buffered_stitching or \"OutOfMemoryError\" not in str(type(e).__name__):\n                    raise e\n\n                logger.info(e)\n\n                if gpu_stitching:  # if failed on gpu\n                    gpu_stitching = False\n                    self.cpu_thresh = inputs.shape[2:].numel() - 1  # update thresh\n\n                    if skip_buffer:\n                        buffered_stitching = False\n                        logger.warning(f\"GPU stitching failed, attempting on CPU, image dim {inputs.shape}.\")\n\n                    else:\n                        buffered_stitching = True\n                        self.buffer_steps = buffer_steps\n                        logger.warning(\n                            f\"GPU stitching failed, buffer {buffer_steps} dim {buffer_dim}, image dim {inputs.shape}.\"\n                        )\n                elif buffer_steps > 1:\n                    buffer_steps = max(1, buffer_steps // 2)\n                    self.buffer_steps = buffer_steps\n                    logger.warning(\n                        f\"GPU buffered stitching failed, image dim {inputs.shape} reducing buffer to {buffer_steps}.\"\n                    )\n                else:\n                    buffered_stitching = False\n                    logger.warning(f\"GPU buffered stitching failed, attempting on CPU, image dim 
{inputs.shape}.\")\n        raise RuntimeError(  # not possible to finish after the trials\n            f\"SlidingWindowInfererAdapt {skip_buffer} {cpu_cond} {gpu_stitching} {buffered_stitching} {buffer_steps}\"\n        )\n\n\nclass SaliencyInferer(Inferer):\n    \"\"\"\n    SaliencyInferer is inference with activation maps.\n\n    Args:\n        cam_name: expected CAM method name, should be: \"CAM\", \"GradCAM\" or \"GradCAMpp\".\n        target_layers: name of the model layer to generate the feature map.\n        class_idx: index of the class to be visualized. if None, default to argmax(logits).\n        args: other optional args to be passed to the `__init__` of cam.\n        kwargs: other optional keyword args to be passed to `__init__` of cam.\n\n    \"\"\"\n\n    def __init__(\n        self, cam_name: str, target_layers: str, class_idx: int | None = None, *args: Any, **kwargs: Any\n    ) -> None:\n        Inferer.__init__(self)\n        if cam_name.lower() not in (\"cam\", \"gradcam\", \"gradcampp\"):\n            raise ValueError(\"cam_name should be: 'CAM', 'GradCAM' or 'GradCAMpp'.\")\n        self.cam_name = cam_name.lower()\n        self.target_layers = target_layers\n        self.class_idx = class_idx\n        self.args = args\n        self.kwargs = kwargs\n\n    def __call__(self, inputs: torch.Tensor, network: nn.Module, *args: Any, **kwargs: Any):  # type: ignore\n        \"\"\"Unified callable function API of Inferers.\n\n        Args:\n            inputs: model input data for inference.\n            network: target model to execute inference.\n                supports callables such as ``lambda x: my_torch_model(x, additional_config)``\n            args: other optional args to be passed to the `__call__` of cam.\n            kwargs: other optional keyword args to be passed to `__call__` of cam.\n\n        \"\"\"\n        cam: CAM | GradCAM | GradCAMpp\n        if self.cam_name == \"cam\":\n            cam = CAM(network, self.target_layers, 
*self.args, **self.kwargs)\n        elif self.cam_name == \"gradcam\":\n            cam = GradCAM(network, self.target_layers, *self.args, **self.kwargs)\n        else:\n            cam = GradCAMpp(network, self.target_layers, *self.args, **self.kwargs)\n\n        return cam(inputs, self.class_idx, *args, **kwargs)\n\n\nclass SliceInferer(SlidingWindowInferer):\n    \"\"\"\n    SliceInferer extends SlidingWindowInferer to provide slice-by-slice (2D) inference when provided a 3D volume.\n    A typical use case could be a 2D model (like 2D segmentation UNet) operates on the slices from a 3D volume,\n    and the output is a 3D volume with 2D slices aggregated. Example::\n\n        # sliding over the `spatial_dim`\n        inferer = SliceInferer(roi_size=(64, 256), sw_batch_size=1, spatial_dim=1)\n        output = inferer(input_volume, net)\n\n    Args:\n        spatial_dim: Spatial dimension over which the slice-by-slice inference runs on the 3D volume.\n            For example ``0`` could slide over axial slices. ``1`` over coronal slices and ``2`` over sagittal slices.\n        args: other optional args to be passed to the `__init__` of base class SlidingWindowInferer.\n        kwargs: other optional keyword args to be passed to `__init__` of base class SlidingWindowInferer.\n\n    Note:\n        ``roi_size`` in SliceInferer is expected to be a 2D tuple when a 3D volume is provided. 
This allows\n        sliding across slices along the 3D volume using a selected ``spatial_dim``.\n\n    \"\"\"\n\n    def __init__(self, spatial_dim: int = 0, *args: Any, **kwargs: Any) -> None:\n        self.spatial_dim = spatial_dim\n        super().__init__(*args, **kwargs)\n        self.orig_roi_size = ensure_tuple(self.roi_size)\n\n    def __call__(\n        self,\n        inputs: torch.Tensor,\n        network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],\n        *args: Any,\n        **kwargs: Any,\n    ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:\n        \"\"\"\n        Args:\n            inputs: 3D input for inference\n            network: 2D model to execute inference on slices in the 3D input\n            args: optional args to be passed to ``network``.\n            kwargs: optional keyword args to be passed to ``network``.\n            condition (torch.Tensor, optional): If provided via `**kwargs`,\n                this tensor must match the shape of `inputs` and will be sliced, patched, or windowed alongside the inputs.\n                The resulting segments will be passed to the model together with the corresponding input segments.\"\"\"\n        if self.spatial_dim > 2:\n            raise ValueError(\"`spatial_dim` can only be `0, 1, 2` with `[H, W, D]` respectively.\")\n\n        # Check if ``roi_size`` tuple is 2D and ``inputs`` tensor is 3D\n        self.roi_size = ensure_tuple(self.roi_size)\n        if len(self.orig_roi_size) == 2 and len(inputs.shape[2:]) == 3:\n            self.roi_size = list(self.orig_roi_size)\n            self.roi_size.insert(self.spatial_dim, 1)\n        else:\n            raise RuntimeError(\n                f\"Currently, only 2D `roi_size` ({self.orig_roi_size}) with 3D `inputs` tensor (shape={inputs.shape}) is supported.\"\n            )\n\n        # shape check for condition\n        condition = kwargs.get(\"condition\", None)\n        if condition is not 
None and condition.shape != inputs.shape:\n            raise ValueError(f\"`condition` must match shape of `inputs` ({inputs.shape}), but got {condition.shape}\")\n\n        # check if there is a conditioning signal\n        if condition is not None:\n            return super().__call__(\n                inputs=inputs,\n                network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs),\n                condition=condition,\n            )\n        else:\n            return super().__call__(\n                inputs=inputs, network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs)\n            )\n\n    def network_wrapper(\n        self,\n        network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],\n        x: torch.Tensor,\n        condition: torch.Tensor | None = None,\n        *args: Any,\n        **kwargs: Any,\n    ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:\n        \"\"\"\n        Wrapper handles inference for 2D models over 3D volume inputs.\n        \"\"\"\n        #  Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.\n        x = x.squeeze(dim=self.spatial_dim + 2)\n\n        if condition is not None:\n            condition = condition.squeeze(dim=self.spatial_dim + 2)\n            out = network(x, condition, *args, **kwargs)\n        else:\n            out = network(x, *args, **kwargs)\n\n        #  Unsqueeze the network output so it is [N, C, D, H, W] as expected by\n        # the default SlidingWindowInferer class\n        if isinstance(out, torch.Tensor):\n            return out.unsqueeze(dim=self.spatial_dim + 2)\n\n        if isinstance(out, Mapping):\n            for k in out.keys():\n                out[k] = out[k].unsqueeze(dim=self.spatial_dim + 2)\n            return out\n\n        return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out)\n\n\nclass DiffusionInferer(Inferer):\n    
\"\"\"\n    DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass\n    for a training iteration, and sample from the model.\n\n    Args:\n        scheduler: diffusion scheduler.\n    \"\"\"\n\n    def __init__(self, scheduler: Scheduler) -> None:  # type: ignore[override]\n        super().__init__()\n\n        self.scheduler = scheduler\n\n    def __call__(  # type: ignore[override]\n        self,\n        inputs: torch.Tensor,\n        diffusion_model: DiffusionModelUNet,\n        noise: torch.Tensor,\n        timesteps: torch.Tensor,\n        condition: torch.Tensor | None = None,\n        mode: str = \"crossattn\",\n        seg: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Implements the forward pass for a supervised training iteration.\n\n        Args:\n            inputs: Input image to which noise is added.\n            diffusion_model: diffusion model.\n            noise: random noise, of the same shape as the input.\n            timesteps: random timesteps.\n            condition: Conditioning for network input.\n            mode: Conditioning mode for the network.\n            seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be\n            provided on the forward (for SPADE-like AE or SPADE-like DM)\n        \"\"\"\n        if mode not in [\"crossattn\", \"concat\"]:\n            raise NotImplementedError(f\"{mode} condition is not supported\")\n\n        noisy_image: torch.Tensor = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)\n        if mode == \"concat\":\n            if condition is None:\n                raise ValueError(\"Conditioning is required for concat condition\")\n            else:\n                noisy_image = torch.cat([noisy_image, condition], dim=1)\n                condition = None\n        diffusion_model = (\n            partial(diffusion_model, seg=seg)\n            if 
isinstance(diffusion_model, SPADEDiffusionModelUNet)\n            else diffusion_model\n        )\n        prediction: torch.Tensor = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition)\n\n        return prediction\n\n    @torch.no_grad()\n    def sample(\n        self,\n        input_noise: torch.Tensor,\n        diffusion_model: DiffusionModelUNet,\n        scheduler: Scheduler | None = None,\n        save_intermediates: bool | None = False,\n        intermediate_steps: int | None = 100,\n        conditioning: torch.Tensor | None = None,\n        mode: str = \"crossattn\",\n        verbose: bool = True,\n        seg: torch.Tensor | None = None,\n        cfg: float | None = None,\n        cfg_fill_value: float = -1.0,\n    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:\n        \"\"\"\n        Args:\n            input_noise: random noise, of the same shape as the desired sample.\n            diffusion_model: model to sample from.\n            scheduler: diffusion scheduler. 
If none provided will use the class attribute scheduler\n            save_intermediates: whether to return intermediates along the sampling change\n            intermediate_steps: if save_intermediates is True, saves every n steps\n            conditioning: Conditioning for network input.\n            mode: Conditioning mode for the network.\n            verbose: if true, prints the progression bar of the sampling process.\n            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.\n            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.\n            cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.\n        \"\"\"\n        if mode not in [\"crossattn\", \"concat\"]:\n            raise NotImplementedError(f\"{mode} condition is not supported\")\n        if mode == \"concat\" and conditioning is None:\n            raise ValueError(\"Conditioning must be supplied for if condition mode is concat.\")\n        if not scheduler:\n            scheduler = self.scheduler\n        image = input_noise\n\n        all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))\n        if verbose and has_tqdm:\n            progress_bar = tqdm(\n                zip(scheduler.timesteps, all_next_timesteps),\n                total=min(len(scheduler.timesteps), len(all_next_timesteps)),\n            )\n        else:\n            progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))\n        intermediates = []\n\n        for t, next_t in progress_bar:\n            # 1. 
predict noise model_output\n            diffusion_model = (\n                partial(diffusion_model, seg=seg)\n                if isinstance(diffusion_model, SPADEDiffusionModelUNet)\n                else diffusion_model\n            )\n            if (\n                cfg is not None\n            ):  # if classifier-free guidance is used, a conditioned and unconditioned bit is generated.\n                model_input = torch.cat([image] * 2, dim=0)\n                if conditioning is not None:\n                    uncondition = torch.ones_like(conditioning)\n                    uncondition.fill_(cfg_fill_value)\n                    conditioning_input = torch.cat([uncondition, conditioning], dim=0)\n                else:\n                    conditioning_input = None\n            else:\n                model_input = image\n                conditioning_input = conditioning\n            if mode == \"concat\" and conditioning_input is not None:\n                model_input = torch.cat([model_input, conditioning_input], dim=1)\n                model_output = diffusion_model(\n                    model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None\n                )\n            else:\n                model_output = diffusion_model(\n                    model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning_input\n                )\n            if cfg is not None:\n                model_output_uncond, model_output_cond = model_output.chunk(2)\n                model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)\n\n            # 2. 
compute previous image: x_t -> x_t-1\n            if not isinstance(scheduler, RFlowScheduler):\n                image, _ = scheduler.step(model_output, t, image)  # type: ignore\n            else:\n                image, _ = scheduler.step(model_output, t, image, next_t)  # type: ignore\n            if save_intermediates and t % intermediate_steps == 0:\n                intermediates.append(image)\n\n        if save_intermediates:\n            return image, intermediates\n        else:\n            return image\n\n    @torch.no_grad()\n    def get_likelihood(\n        self,\n        inputs: torch.Tensor,\n        diffusion_model: DiffusionModelUNet,\n        scheduler: Scheduler | None = None,\n        save_intermediates: bool | None = False,\n        conditioning: torch.Tensor | None = None,\n        mode: str = \"crossattn\",\n        original_input_range: tuple = (0, 255),\n        scaled_input_range: tuple = (0, 1),\n        verbose: bool = True,\n        seg: torch.Tensor | None = None,\n    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:\n        \"\"\"\n        Computes the log-likelihoods for an input.\n\n        Args:\n            inputs: input images, NxCxHxW[xD]\n            diffusion_model: model to compute likelihood from\n            scheduler: diffusion scheduler. 
If none provided will use the class attribute scheduler.\n            save_intermediates: save the intermediate spatial KL maps\n            conditioning: Conditioning for network input.\n            mode: Conditioning mode for the network.\n            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.\n            scaled_input_range: the [min,max] intensity range of the input data after scaling.\n            verbose: if true, prints the progression bar of the sampling process.\n            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.\n        \"\"\"\n\n        if not scheduler:\n            scheduler = self.scheduler\n        if scheduler._get_name() != \"DDPMScheduler\":\n            raise NotImplementedError(\n                f\"Likelihood computation is only compatible with DDPMScheduler,\"\n                f\" you are using {scheduler._get_name()}\"\n            )\n        if mode not in [\"crossattn\", \"concat\"]:\n            raise NotImplementedError(f\"{mode} condition is not supported\")\n        if mode == \"concat\" and conditioning is None:\n            raise ValueError(\"Conditioning must be supplied for if condition mode is concat.\")\n        if verbose and has_tqdm:\n            progress_bar = tqdm(scheduler.timesteps)\n        else:\n            progress_bar = iter(scheduler.timesteps)\n        intermediates = []\n        noise = torch.randn_like(inputs).to(inputs.device)\n        total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)\n        for t in progress_bar:\n            timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()\n            noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)\n            diffusion_model = (\n                partial(diffusion_model, seg=seg)\n                if isinstance(diffusion_model, SPADEDiffusionModelUNet)\n                else 
diffusion_model\n            )\n            if mode == \"concat\" and conditioning is not None:\n                noisy_image = torch.cat([noisy_image, conditioning], dim=1)\n                model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None)\n            else:\n                model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)\n\n            # get the model's predicted mean,  and variance if it is predicted\n            if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in [\"learned\", \"learned_range\"]:\n                model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)\n            else:\n                predicted_variance = None\n\n            # 1. compute alphas, betas\n            alpha_prod_t = scheduler.alphas_cumprod[t]\n            alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one\n            beta_prod_t = 1 - alpha_prod_t\n            beta_prod_t_prev = 1 - alpha_prod_t_prev\n\n            # 2. compute predicted original sample from predicted noise also called\n            # \"predicted x_0\" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf\n            if scheduler.prediction_type == \"epsilon\":\n                pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n            elif scheduler.prediction_type == \"sample\":\n                pred_original_sample = model_output\n            elif scheduler.prediction_type == \"v_prediction\":\n                pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output\n            # 3. Clip \"predicted x_0\"\n            if scheduler.clip_sample:\n                pred_original_sample = torch.clamp(pred_original_sample, -1, 1)\n\n            # 4. 
Compute coefficients for pred_original_sample x_0 and current sample x_t\n            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf\n            pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t\n            current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t\n\n            # 5. Compute predicted previous sample µ_t\n            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf\n            predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image\n\n            # get the posterior mean and variance\n            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)  # type: ignore[operator]\n            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)  # type: ignore[operator]\n\n            log_posterior_variance = torch.log(posterior_variance)\n            log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance\n\n            if t == 0:\n                # compute -log p(x_0|x_1)\n                kl = -self._get_decoder_log_likelihood(\n                    inputs=inputs,\n                    means=predicted_mean,\n                    log_scales=0.5 * log_predicted_variance,\n                    original_input_range=original_input_range,\n                    scaled_input_range=scaled_input_range,\n                )\n            else:\n                # compute kl between two normals\n                kl = 0.5 * (\n                    -1.0\n                    + log_predicted_variance\n                    - log_posterior_variance\n                    + torch.exp(log_posterior_variance - log_predicted_variance)\n                    + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)\n                )\n            total_kl += kl.view(kl.shape[0], -1).mean(dim=1)\n            if 
save_intermediates:\n                intermediates.append(kl.cpu())\n\n        if save_intermediates:\n            return total_kl, intermediates\n        else:\n            return total_kl\n\n    def _approx_standard_normal_cdf(self, x):\n        \"\"\"\n        A fast approximation of the cumulative distribution function of the\n        standard normal. Code adapted from https://github.com/openai/improved-diffusion.\n        \"\"\"\n\n        return 0.5 * (\n            1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3)))\n        )\n\n    def _get_decoder_log_likelihood(\n        self,\n        inputs: torch.Tensor,\n        means: torch.Tensor,\n        log_scales: torch.Tensor,\n        original_input_range: tuple = (0, 255),\n        scaled_input_range: tuple = (0, 1),\n    ) -> torch.Tensor:\n        \"\"\"\n        Compute the log-likelihood of a Gaussian distribution discretizing to a\n        given image. Code adapted from https://github.com/openai/improved-diffusion.\n\n        Args:\n            input: the target images. 
It is assumed that this was uint8 values,\n                      rescaled to the range [-1, 1].\n            means: the Gaussian mean Tensor.\n            log_scales: the Gaussian log stddev Tensor.\n            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.\n            scaled_input_range: the [min,max] intensity range of the input data after scaling.\n        \"\"\"\n        if inputs.shape != means.shape:\n            raise ValueError(f\"Inputs and means must have the same shape, got {inputs.shape} and {means.shape}\")\n        bin_width = (scaled_input_range[1] - scaled_input_range[0]) / (\n            original_input_range[1] - original_input_range[0]\n        )\n        centered_x = inputs - means\n        inv_stdv = torch.exp(-log_scales)\n        plus_in = inv_stdv * (centered_x + bin_width / 2)\n        cdf_plus = self._approx_standard_normal_cdf(plus_in)\n        min_in = inv_stdv * (centered_x - bin_width / 2)\n        cdf_min = self._approx_standard_normal_cdf(min_in)\n        log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))\n        log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))\n        cdf_delta = cdf_plus - cdf_min\n        log_probs = torch.where(\n            inputs < -0.999,\n            log_cdf_plus,\n            torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),\n        )\n        return log_probs\n\n\nclass LatentDiffusionInferer(DiffusionInferer):\n    \"\"\"\n    LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can\n    be used to perform a signal forward pass for a training iteration, and sample from the model.\n\n    Args:\n        scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.\n        scale_factor: scale factor to multiply the values of the latent representation before processing it by the\n            second 
stage.\n        ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape.\n        autoencoder_latent_shape:  autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a\n             difference between the autoencoder's latent shape and the DM shape.\n    \"\"\"\n\n    def __init__(\n        self,\n        scheduler: Scheduler,\n        scale_factor: float = 1.0,\n        ldm_latent_shape: list | None = None,\n        autoencoder_latent_shape: list | None = None,\n    ) -> None:\n        super().__init__(scheduler=scheduler)\n        self.scale_factor = scale_factor\n        if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None):\n            raise ValueError(\"If ldm_latent_shape is None, autoencoder_latent_shape must be None, and vice versa.\")\n        self.ldm_latent_shape = ldm_latent_shape\n        self.autoencoder_latent_shape = autoencoder_latent_shape\n        if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None:\n            self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)\n            self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)\n\n    def __call__(  # type: ignore[override]\n        self,\n        inputs: torch.Tensor,\n        autoencoder_model: AutoencoderKL | VQVAE,\n        diffusion_model: DiffusionModelUNet,\n        noise: torch.Tensor,\n        timesteps: torch.Tensor,\n        condition: torch.Tensor | None = None,\n        mode: str = \"crossattn\",\n        seg: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Implements the forward pass for a supervised training iteration.\n\n        Args:\n            inputs: input image to which the latent representation will be extracted and noise is added.\n            autoencoder_model: first stage model.\n            diffusion_model: diffusion model.\n            noise: random noise, of 
the same shape as the latent representation.\n            timesteps: random timesteps.\n            condition: conditioning for network input.\n            mode: Conditioning mode for the network.\n            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.\n        \"\"\"\n        with torch.no_grad():\n            latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor\n\n        if self.ldm_latent_shape is not None:\n            latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)\n\n        prediction: torch.Tensor = super().__call__(\n            inputs=latent,\n            diffusion_model=diffusion_model,\n            noise=noise,\n            timesteps=timesteps,\n            condition=condition,\n            mode=mode,\n            seg=seg,\n        )\n        return prediction\n\n    @torch.no_grad()\n    def sample(  # type: ignore[override]\n        self,\n        input_noise: torch.Tensor,\n        autoencoder_model: AutoencoderKL | VQVAE,\n        diffusion_model: DiffusionModelUNet,\n        scheduler: Scheduler | None = None,\n        save_intermediates: bool | None = False,\n        intermediate_steps: int | None = 100,\n        conditioning: torch.Tensor | None = None,\n        mode: str = \"crossattn\",\n        verbose: bool = True,\n        seg: torch.Tensor | None = None,\n        cfg: float | None = None,\n        cfg_fill_value: float = -1.0,\n    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:\n        \"\"\"\n        Args:\n            input_noise: random noise, of the same shape as the desired latent representation.\n            autoencoder_model: first stage model.\n            diffusion_model: model to sample from.\n            scheduler: diffusion scheduler. 
If none provided will use the class attribute scheduler.\n            save_intermediates: whether to return intermediates along the sampling change\n            intermediate_steps: if save_intermediates is True, saves every n steps\n            conditioning: Conditioning for network input.\n            mode: Conditioning mode for the network.\n            verbose: if true, prints the progression bar of the sampling process.\n            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model\n             is instance of SPADEAutoencoderKL, segmentation must be provided.\n            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.\n            cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.\n        \"\"\"\n\n        if (\n            isinstance(autoencoder_model, SPADEAutoencoderKL)\n            and isinstance(diffusion_model, SPADEDiffusionModelUNet)\n            and autoencoder_model.decoder.label_nc != diffusion_model.label_nc\n        ):\n            raise ValueError(\n                f\"If both autoencoder_model and diffusion_model implement SPADE, the number of semantic\"\n                f\"labels for each must be compatible, but got {autoencoder_model.decoder.label_nc} and\"\n                f\"{diffusion_model.label_nc}\"\n            )\n\n        outputs = super().sample(\n            input_noise=input_noise,\n            diffusion_model=diffusion_model,\n            scheduler=scheduler,\n            save_intermediates=save_intermediates,\n            intermediate_steps=intermediate_steps,\n            conditioning=conditioning,\n            mode=mode,\n            verbose=verbose,\n            seg=seg,\n            cfg=cfg,\n            cfg_fill_value=cfg_fill_value,\n        )\n\n        if save_intermediates:\n            latent, latent_intermediates = outputs\n        else:\n            latent = outputs\n\n        if 
self.autoencoder_latent_shape is not None:\n            latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)\n            if save_intermediates:\n                latent_intermediates = [\n                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)\n                    for l in latent_intermediates\n                ]\n\n        decode = autoencoder_model.decode_stage_2_outputs\n        if isinstance(autoencoder_model, SPADEAutoencoderKL):\n            decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)\n        image = decode(latent / self.scale_factor)\n        if save_intermediates:\n            intermediates = []\n            for latent_intermediate in latent_intermediates:\n                decode = autoencoder_model.decode_stage_2_outputs\n                if isinstance(autoencoder_model, SPADEAutoencoderKL):\n                    decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)\n                intermediates.append(decode(latent_intermediate / self.scale_factor))\n            return image, intermediates\n\n        else:\n            return image\n\n    @torch.no_grad()\n    def get_likelihood(  # type: ignore[override]\n        self,\n        inputs: torch.Tensor,\n        autoencoder_model: AutoencoderKL | VQVAE,\n        diffusion_model: DiffusionModelUNet,\n        scheduler: Scheduler | None = None,\n        save_intermediates: bool | None = False,\n        conditioning: torch.Tensor | None = None,\n        mode: str = \"crossattn\",\n        original_input_range: tuple | None = (0, 255),\n        scaled_input_range: tuple | None = (0, 1),\n        verbose: bool = True,\n        resample_latent_likelihoods: bool = False,\n        resample_interpolation_mode: str = \"nearest\",\n        seg: torch.Tensor | None = None,\n    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:\n        \"\"\"\n        Computes the log-likelihoods of the latent 
representations of the input.\n\n        Args:\n            inputs: input images, NxCxHxW[xD]\n            autoencoder_model: first stage model.\n            diffusion_model: model to compute likelihood from\n            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler\n            save_intermediates: save the intermediate spatial KL maps\n            conditioning: Conditioning for network input.\n            mode: Conditioning mode for the network.\n            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.\n            scaled_input_range: the [min,max] intensity range of the input data after scaling.\n            verbose: if true, prints the progression bar of the sampling process.\n            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial\n                dimension as the input images.\n            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',\n                or 'trilinear;\n            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model\n             is instance of SPADEAutoencoderKL, segmentation must be provided.\n        \"\"\"\n        if resample_latent_likelihoods and resample_interpolation_mode not in (\"nearest\", \"bilinear\", \"trilinear\"):\n            raise ValueError(\n                f\"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}\"\n            )\n        latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor\n\n        if self.ldm_latent_shape is not None:\n            latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)\n\n        outputs = super().get_likelihood(\n            inputs=latents,\n            diffusion_model=diffusion_model,\n            scheduler=scheduler,\n            
save_intermediates=save_intermediates,\n            conditioning=conditioning,\n            mode=mode,\n            verbose=verbose,\n            seg=seg,\n        )\n\n        if save_intermediates and resample_latent_likelihoods:\n            intermediates = outputs[1]\n            resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)\n            intermediates = [resizer(x) for x in intermediates]\n            outputs = (outputs[0], intermediates)\n        return outputs\n\n\nclass ControlNetDiffusionInferer(DiffusionInferer):\n    \"\"\"\n    ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a single\n    forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning.\n\n    Args:\n        scheduler: diffusion scheduler.\n    \"\"\"\n\n    def __init__(self, scheduler: Scheduler) -> None:\n        Inferer.__init__(self)\n        self.scheduler = scheduler\n\n    def __call__(  # type: ignore[override]\n        self,\n        inputs: torch.Tensor,\n        diffusion_model: DiffusionModelUNet,\n        controlnet: ControlNet,\n        noise: torch.Tensor,\n        timesteps: torch.Tensor,\n        cn_cond: torch.Tensor,\n        condition: torch.Tensor | None = None,\n        mode: str = \"crossattn\",\n        seg: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Implements the forward pass for a supervised training iteration.\n\n        Args:\n            inputs: Input image to which noise is added.\n            diffusion_model: diffusion model.\n            controlnet: controlnet sub-network.\n            noise: random noise, of the same shape as the input.\n            timesteps: random timesteps.\n            cn_cond: conditioning image for the ControlNet.\n            condition: Conditioning for network input.\n            mode: Conditioning mode for the network.\n            seg: if model is instance of 
SPADEDiffusionModelUnet, segmentation must be\n            provided on the forward (for SPADE-like AE or SPADE-like DM)\n        \"\"\"\n        if mode not in [\"crossattn\", \"concat\"]:\n            raise NotImplementedError(f\"{mode} condition is not supported\")\n\n        noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)\n\n        if mode == \"concat\" and condition is not None:\n            noisy_image = torch.cat([noisy_image, condition], dim=1)\n            condition = None\n\n        down_block_res_samples, mid_block_res_sample = controlnet(\n            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition\n        )\n\n        diffuse = diffusion_model\n        if isinstance(diffusion_model, SPADEDiffusionModelUNet):\n            diffuse = partial(diffusion_model, seg=seg)\n\n        prediction: torch.Tensor = diffuse(\n            x=noisy_image,\n            timesteps=timesteps,\n            context=condition,\n            down_block_additional_residuals=down_block_res_samples,\n            mid_block_additional_residual=mid_block_res_sample,\n        )\n\n        return prediction\n\n    @torch.no_grad()\n    def sample(  # type: ignore[override]\n        self,\n        input_noise: torch.Tensor,\n        diffusion_model: DiffusionModelUNet,\n        controlnet: ControlNet,\n        cn_cond: torch.Tensor,\n        scheduler: Scheduler | None = None,\n        save_intermediates: bool | None = False,\n        intermediate_steps: int | None = 100,\n        conditioning: torch.Tensor | None = None,\n        mode: str = \"crossattn\",\n        verbose: bool = True,\n        seg: torch.Tensor | None = None,\n        cfg: float | None = None,\n        cfg_fill_value: float = -1.0,\n    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:\n        \"\"\"\n        Args:\n            input_noise: random noise, of the same shape as the desired sample.\n            diffusion_model: 
model to sample from.\n            controlnet: controlnet sub-network.\n            cn_cond: conditioning image for the ControlNet.\n            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler\n            save_intermediates: whether to return intermediates along the sampling chain\n            intermediate_steps: if save_intermediates is True, saves every n steps\n            conditioning: Conditioning for network input.\n            mode: Conditioning mode for the network.\n            verbose: if true, prints the progression bar of the sampling process.\n            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.\n            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.\n            cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.\n        \"\"\"\n        if mode not in [\"crossattn\", \"concat\"]:\n            raise NotImplementedError(f\"{mode} condition is not supported\")\n\n        if not scheduler:\n            scheduler = self.scheduler\n        image = input_noise\n\n        all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))\n        if verbose and has_tqdm:\n            progress_bar = tqdm(\n                zip(scheduler.timesteps, all_next_timesteps),\n                total=min(len(scheduler.timesteps), len(all_next_timesteps)),\n            )\n        else:\n            progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))\n        intermediates = []\n\n        if cfg is not None:\n            cn_cond = torch.cat([cn_cond] * 2, dim=0)\n\n        for t, next_t in progress_bar:\n            # Controlnet prediction\n            if cfg is not None:\n                model_input = torch.cat([image] * 2, dim=0)\n                if conditioning is not None:\n                    uncondition = 
torch.ones_like(conditioning)\n                    uncondition.fill_(cfg_fill_value)\n                    conditioning_input = torch.cat([uncondition, conditioning], dim=0)\n                else:\n                    conditioning_input = None\n            else:\n                model_input = image\n                conditioning_input = conditioning\n\n            # Diffusion model prediction\n            diffuse = diffusion_model\n            if isinstance(diffusion_model, SPADEDiffusionModelUNet):\n                diffuse = partial(diffusion_model, seg=seg)\n\n            if mode == \"concat\" and conditioning_input is not None:\n                # 1. Conditioning\n                model_input = torch.cat([model_input, conditioning_input], dim=1)\n                # 2. ControlNet forward\n                down_block_res_samples, mid_block_res_sample = controlnet(\n                    x=model_input,\n                    timesteps=torch.Tensor((t,)).to(input_noise.device),\n                    controlnet_cond=cn_cond,\n                    context=None,\n                )\n                # 3. predict noise model_output\n                model_output = diffuse(\n                    model_input,\n                    timesteps=torch.Tensor((t,)).to(input_noise.device),\n                    context=None,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                )\n            else:\n                # 1. Controlnet forward\n                down_block_res_samples, mid_block_res_sample = controlnet(\n                    x=model_input,\n                    timesteps=torch.Tensor((t,)).to(input_noise.device),\n                    controlnet_cond=cn_cond,\n                    context=conditioning_input,\n                )\n                # 2. 
predict noise model_output\n                model_output = diffuse(\n                    model_input,\n                    timesteps=torch.Tensor((t,)).to(input_noise.device),\n                    context=conditioning_input,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                )\n\n            # If classifier-free guidance isn't None, we split and compute the weighting between\n            # conditioned and unconditioned output.\n            if cfg is not None:\n                model_output_uncond, model_output_cond = model_output.chunk(2)\n                model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)\n\n            # 3. compute previous image: x_t -> x_t-1\n            if not isinstance(scheduler, RFlowScheduler):\n                image, _ = scheduler.step(model_output, t, image)  # type: ignore\n            else:\n                image, _ = scheduler.step(model_output, t, image, next_t)  # type: ignore\n\n            if save_intermediates and t % intermediate_steps == 0:\n                intermediates.append(image)\n        if save_intermediates:\n            return image, intermediates\n        else:\n            return image\n\n    @torch.no_grad()\n    def get_likelihood(  # type: ignore[override]\n        self,\n        inputs: torch.Tensor,\n        diffusion_model: DiffusionModelUNet,\n        controlnet: ControlNet,\n        cn_cond: torch.Tensor,\n        scheduler: Scheduler | None = None,\n        save_intermediates: bool | None = False,\n        conditioning: torch.Tensor | None = None,\n        mode: str = \"crossattn\",\n        original_input_range: tuple = (0, 255),\n        scaled_input_range: tuple = (0, 1),\n        verbose: bool = True,\n        seg: torch.Tensor | None = None,\n    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:\n        \"\"\"\n        Computes the 
log-likelihoods for an input.\n\n        Args:\n            inputs: input images, NxCxHxW[xD]\n            diffusion_model: model to compute likelihood from\n            controlnet: controlnet sub-network.\n            cn_cond: conditioning image for the ControlNet.\n            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.\n            save_intermediates: save the intermediate spatial KL maps\n            conditioning: Conditioning for network input.\n            mode: Conditioning mode for the network.\n            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.\n            scaled_input_range: the [min,max] intensity range of the input data after scaling.\n            verbose: if true, prints the progression bar of the sampling process.\n            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.\n        \"\"\"\n\n        if not scheduler:\n            scheduler = self.scheduler\n        if scheduler._get_name() != \"DDPMScheduler\":\n            raise NotImplementedError(\n                f\"Likelihood computation is only compatible with DDPMScheduler,\"\n                f\" you are using {scheduler._get_name()}\"\n            )\n        if mode not in [\"crossattn\", \"concat\"]:\n            raise NotImplementedError(f\"{mode} condition is not supported\")\n        if verbose and has_tqdm:\n            progress_bar = tqdm(scheduler.timesteps)\n        else:\n            progress_bar = iter(scheduler.timesteps)\n        intermediates = []\n        noise = torch.randn_like(inputs).to(inputs.device)\n        total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)\n        for t in progress_bar:\n            timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()\n            noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)\n\n            diffuse = 
diffusion_model\n            if isinstance(diffusion_model, SPADEDiffusionModelUNet):\n                diffuse = partial(diffusion_model, seg=seg)\n\n            if mode == \"concat\" and conditioning is not None:\n                noisy_image = torch.cat([noisy_image, conditioning], dim=1)\n                down_block_res_samples, mid_block_res_sample = controlnet(\n                    x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None\n                )\n                model_output = diffuse(\n                    noisy_image,\n                    timesteps=timesteps,\n                    context=None,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                )\n            else:\n                down_block_res_samples, mid_block_res_sample = controlnet(\n                    x=noisy_image,\n                    timesteps=torch.Tensor((t,)).to(inputs.device),\n                    controlnet_cond=cn_cond,\n                    context=conditioning,\n                )\n                model_output = diffuse(\n                    x=noisy_image,\n                    timesteps=timesteps,\n                    context=conditioning,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                )\n            # get the model's predicted mean,  and variance if it is predicted\n            if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in [\"learned\", \"learned_range\"]:\n                model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)\n            else:\n                predicted_variance = None\n\n            # 1. 
compute alphas, betas\n            alpha_prod_t = scheduler.alphas_cumprod[t]\n            alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one\n            beta_prod_t = 1 - alpha_prod_t\n            beta_prod_t_prev = 1 - alpha_prod_t_prev\n\n            # 2. compute predicted original sample from predicted noise also called\n            # \"predicted x_0\" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf\n            if scheduler.prediction_type == \"epsilon\":\n                pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n            elif scheduler.prediction_type == \"sample\":\n                pred_original_sample = model_output\n            elif scheduler.prediction_type == \"v_prediction\":\n                pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output\n            # 3. Clip \"predicted x_0\"\n            if scheduler.clip_sample:\n                pred_original_sample = torch.clamp(pred_original_sample, -1, 1)\n\n            # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t\n            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf\n            pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t\n            current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t\n\n            # 5. 
Compute predicted previous sample µ_t\n            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf\n            predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image\n\n            # get the posterior mean and variance\n            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)  # type: ignore[operator]\n            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)  # type: ignore[operator]\n\n            log_posterior_variance = torch.log(posterior_variance)\n            log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance\n\n            if t == 0:\n                # compute -log p(x_0|x_1)\n                kl = -super()._get_decoder_log_likelihood(\n                    inputs=inputs,\n                    means=predicted_mean,\n                    log_scales=0.5 * log_predicted_variance,\n                    original_input_range=original_input_range,\n                    scaled_input_range=scaled_input_range,\n                )\n            else:\n                # compute kl between two normals\n                kl = 0.5 * (\n                    -1.0\n                    + log_predicted_variance\n                    - log_posterior_variance\n                    + torch.exp(log_posterior_variance - log_predicted_variance)\n                    + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)\n                )\n            total_kl += kl.view(kl.shape[0], -1).mean(dim=1)\n            if save_intermediates:\n                intermediates.append(kl.cpu())\n\n        if save_intermediates:\n            return total_kl, intermediates\n        else:\n            return total_kl\n\n\nclass ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):\n    \"\"\"\n    ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), 
diffusion model, controlnet,\n    and a scheduler, and can be used to perform a single forward pass for a training iteration, and sample from\n    the model.\n\n    Args:\n        scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.\n        scale_factor: scale factor to multiply the values of the latent representation before processing it by the\n            second stage.\n        ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape.\n        autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a\n             difference between the autoencoder's latent shape and the DM shape.\n    \"\"\"\n\n    def __init__(\n        self,\n        scheduler: Scheduler,\n        scale_factor: float = 1.0,\n        ldm_latent_shape: list | None = None,\n        autoencoder_latent_shape: list | None = None,\n    ) -> None:\n        super().__init__(scheduler=scheduler)\n        self.scale_factor = scale_factor\n        if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None):\n            raise ValueError(\"If ldm_latent_shape is None, autoencoder_latent_shape must be None\" \"and vice versa.\")\n        self.ldm_latent_shape = ldm_latent_shape\n        self.autoencoder_latent_shape = autoencoder_latent_shape\n        if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None:\n            self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)\n            self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)\n\n    def __call__(  # type: ignore[override]\n        self,\n        inputs: torch.Tensor,\n        autoencoder_model: AutoencoderKL | VQVAE,\n        diffusion_model: DiffusionModelUNet,\n        controlnet: ControlNet,\n        noise: torch.Tensor,\n        timesteps: torch.Tensor,\n        cn_cond: torch.Tensor,\n        
condition: torch.Tensor | None = None,\n        mode: str = \"crossattn\",\n        seg: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Implements the forward pass for a supervised training iteration.\n\n        Args:\n            inputs: input image to which the latent representation will be extracted and noise is added.\n            autoencoder_model: first stage model.\n            diffusion_model: diffusion model.\n            controlnet: instance of ControlNet model\n            noise: random noise, of the same shape as the latent representation.\n            timesteps: random timesteps.\n            cn_cond: conditioning tensor for the ControlNet network\n            condition: conditioning for network input.\n            mode: Conditioning mode for the network.\n            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.\n        \"\"\"\n        with torch.no_grad():\n            latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor\n\n        if self.ldm_latent_shape is not None:\n            latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)\n\n        if cn_cond.shape[2:] != latent.shape[2:]:\n            cn_cond = F.interpolate(cn_cond, latent.shape[2:])\n\n        prediction = super().__call__(\n            inputs=latent,\n            diffusion_model=diffusion_model,\n            controlnet=controlnet,\n            noise=noise,\n            timesteps=timesteps,\n            cn_cond=cn_cond,\n            condition=condition,\n            mode=mode,\n            seg=seg,\n        )\n\n        return prediction\n\n    @torch.no_grad()\n    def sample(  # type: ignore[override]\n        self,\n        input_noise: torch.Tensor,\n        autoencoder_model: AutoencoderKL | VQVAE,\n        diffusion_model: DiffusionModelUNet,\n        controlnet: ControlNet,\n        cn_cond: torch.Tensor,\n        scheduler: Scheduler | None = None,\n   
     save_intermediates: bool | None = False,\n        intermediate_steps: int | None = 100,\n        conditioning: torch.Tensor | None = None,\n        mode: str = \"crossattn\",\n        verbose: bool = True,\n        seg: torch.Tensor | None = None,\n        cfg: float | None = None,\n        cfg_fill_value: float = -1.0,\n    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:\n        \"\"\"\n        Args:\n            input_noise: random noise, of the same shape as the desired latent representation.\n            autoencoder_model: first stage model.\n            diffusion_model: model to sample from.\n            controlnet: instance of ControlNet model.\n            cn_cond: conditioning tensor for the ControlNet network.\n            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.\n            save_intermediates: whether to return intermediates along the sampling change\n            intermediate_steps: if save_intermediates is True, saves every n steps\n            conditioning: Conditioning for network input.\n            mode: Conditioning mode for the network.\n            verbose: if true, prints the progression bar of the sampling process.\n            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model\n             is instance of SPADEAutoencoderKL, segmentation must be provided.\n            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.\n            cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.\n        \"\"\"\n\n        if (\n            isinstance(autoencoder_model, SPADEAutoencoderKL)\n            and isinstance(diffusion_model, SPADEDiffusionModelUNet)\n            and autoencoder_model.decoder.label_nc != diffusion_model.label_nc\n        ):\n            raise ValueError(\n                \"If both autoencoder_model and diffusion_model implement SPADE, the 
number of semantic\"\n                \"labels for each must be compatible. Got {autoencoder_model.decoder.label_nc} and {diffusion_model.label_nc}\"\n            )\n\n        if cn_cond.shape[2:] != input_noise.shape[2:]:\n            cn_cond = F.interpolate(cn_cond, input_noise.shape[2:])\n\n        outputs = super().sample(\n            input_noise=input_noise,\n            diffusion_model=diffusion_model,\n            controlnet=controlnet,\n            cn_cond=cn_cond,\n            scheduler=scheduler,\n            save_intermediates=save_intermediates,\n            intermediate_steps=intermediate_steps,\n            conditioning=conditioning,\n            mode=mode,\n            verbose=verbose,\n            seg=seg,\n            cfg=cfg,\n            cfg_fill_value=cfg_fill_value,\n        )\n\n        if save_intermediates:\n            latent, latent_intermediates = outputs\n        else:\n            latent = outputs\n\n        if self.autoencoder_latent_shape is not None:\n            latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)\n            if save_intermediates:\n                latent_intermediates = [\n                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)\n                    for l in latent_intermediates\n                ]\n\n        decode = autoencoder_model.decode_stage_2_outputs\n        if isinstance(autoencoder_model, SPADEAutoencoderKL):\n            decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)\n\n        image = decode(latent / self.scale_factor)\n\n        if save_intermediates:\n            intermediates = []\n            for latent_intermediate in latent_intermediates:\n                decode = autoencoder_model.decode_stage_2_outputs\n                if isinstance(autoencoder_model, SPADEAutoencoderKL):\n                    decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)\n                
intermediates.append(decode(latent_intermediate / self.scale_factor))\n            return image, intermediates\n\n        else:\n            return image\n\n    @torch.no_grad()\n    def get_likelihood(  # type: ignore[override]\n        self,\n        inputs: torch.Tensor,\n        autoencoder_model: AutoencoderKL | VQVAE,\n        diffusion_model: DiffusionModelUNet,\n        controlnet: ControlNet,\n        cn_cond: torch.Tensor,\n        scheduler: Scheduler | None = None,\n        save_intermediates: bool | None = False,\n        conditioning: torch.Tensor | None = None,\n        mode: str = \"crossattn\",\n        original_input_range: tuple | None = (0, 255),\n        scaled_input_range: tuple | None = (0, 1),\n        verbose: bool = True,\n        resample_latent_likelihoods: bool = False,\n        resample_interpolation_mode: str = \"nearest\",\n        seg: torch.Tensor | None = None,\n    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:\n        \"\"\"\n        Computes the log-likelihoods of the latent representations of the input.\n\n        Args:\n            inputs: input images, NxCxHxW[xD]\n            autoencoder_model: first stage model.\n            diffusion_model: model to compute likelihood from\n            controlnet: instance of ControlNet model.\n            cn_cond: conditioning tensor for the ControlNet network.\n            scheduler: diffusion scheduler. 
If none provided will use the class attribute scheduler\n            save_intermediates: save the intermediate spatial KL maps\n            conditioning: Conditioning for network input.\n            mode: Conditioning mode for the network.\n            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.\n            scaled_input_range: the [min,max] intensity range of the input data after scaling.\n            verbose: if true, prints the progression bar of the sampling process.\n            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial\n                dimension as the input images.\n            resample_interpolation_mode: if using resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',\n                or 'trilinear'.\n            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model\n             is instance of SPADEAutoencoderKL, segmentation must be provided.\n        \"\"\"\n        if resample_latent_likelihoods and resample_interpolation_mode not in (\"nearest\", \"bilinear\", \"trilinear\"):\n            raise ValueError(\n                f\"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}\"\n            )\n\n        latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor\n\n        if cn_cond.shape[2:] != latents.shape[2:]:\n            cn_cond = F.interpolate(cn_cond, latents.shape[2:])\n\n        if self.ldm_latent_shape is not None:\n            latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)\n\n        outputs = super().get_likelihood(\n            inputs=latents,\n            diffusion_model=diffusion_model,\n            controlnet=controlnet,\n            cn_cond=cn_cond,\n            scheduler=scheduler,\n            save_intermediates=save_intermediates,\n            
conditioning=conditioning,\n            mode=mode,\n            verbose=verbose,\n            seg=seg,\n        )\n\n        if save_intermediates and resample_latent_likelihoods:\n            intermediates = outputs[1]\n            resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)\n            intermediates = [resizer(x) for x in intermediates]\n            outputs = (outputs[0], intermediates)\n        return outputs\n\n\nclass VQVAETransformerInferer(nn.Module):\n    \"\"\"\n    Class to perform inference with a VQVAE + Transformer model.\n    \"\"\"\n\n    def __init__(self) -> None:\n        Inferer.__init__(self)\n\n    def __call__(\n        self,\n        inputs: torch.Tensor,\n        vqvae_model: VQVAE,\n        transformer_model: DecoderOnlyTransformer,\n        ordering: Ordering,\n        condition: torch.Tensor | None = None,\n        return_latent: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]:\n        \"\"\"\n        Implements the forward pass for a supervised training iteration.\n\n        Args:\n            inputs: input image to which the latent representation will be extracted.\n            vqvae_model: first stage model.\n            transformer_model: autoregressive transformer model.\n            ordering: ordering of the quantised latent representation.\n            return_latent: also return latent sequence and spatial dim of the latent.\n            condition: conditioning for network input.\n        \"\"\"\n        with torch.no_grad():\n            latent = vqvae_model.index_quantize(inputs)\n\n        latent_spatial_dim = tuple(latent.shape[1:])\n        latent = latent.reshape(latent.shape[0], -1)\n        latent = latent[:, ordering.get_sequence_ordering()]\n\n        # get the targets for the loss\n        target = latent.clone()\n        # Use the value from vqvae_model's num_embeddings as the starting token, the \"Begin Of Sentence\" (BOS) token.\n        # Note the 
transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.\n        latent = F.pad(latent, (1, 0), \"constant\", vqvae_model.num_embeddings)\n        # crop the last token as we do not need the probability of the token that follows it\n        latent = latent[:, :-1]\n        latent = latent.long()\n\n        # train on a part of the sequence if it is longer than max_seq_length\n        seq_len = latent.shape[1]\n        max_seq_len = transformer_model.max_seq_len\n        if max_seq_len < seq_len:\n            start = int(torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item())\n        else:\n            start = 0\n        prediction: torch.Tensor = transformer_model(x=latent[:, start : start + max_seq_len], context=condition)\n        if return_latent:\n            return prediction, target[:, start : start + max_seq_len], latent_spatial_dim\n        else:\n            return prediction\n\n    @torch.no_grad()\n    def sample(\n        self,\n        latent_spatial_dim: tuple[int, int, int] | tuple[int, int],\n        starting_tokens: torch.Tensor,\n        vqvae_model: VQVAE,\n        transformer_model: DecoderOnlyTransformer,\n        ordering: Ordering,\n        conditioning: torch.Tensor | None = None,\n        temperature: float = 1.0,\n        top_k: int | None = None,\n        verbose: bool = True,\n    ) -> torch.Tensor:\n        \"\"\"\n        Sampling function for the VQVAE + Transformer model.\n\n        Args:\n            latent_spatial_dim: shape of the sampled image.\n            starting_tokens: starting tokens for the sampling. 
It must be vqvae_model.num_embeddings value.\n            vqvae_model: first stage model.\n            transformer_model: model to sample from.\n            conditioning: Conditioning for network input.\n            temperature: temperature for sampling.\n            top_k: top k sampling.\n            verbose: if true, prints the progression bar of the sampling process.\n        \"\"\"\n        seq_len = math.prod(latent_spatial_dim)\n\n        if verbose and has_tqdm:\n            progress_bar = tqdm(range(seq_len))\n        else:\n            progress_bar = iter(range(seq_len))\n\n        latent_seq = starting_tokens.long()\n        for _ in progress_bar:\n            # if the sequence context is growing too long we must crop it at block_size\n            if latent_seq.size(1) <= transformer_model.max_seq_len:\n                idx_cond = latent_seq\n            else:\n                idx_cond = latent_seq[:, -transformer_model.max_seq_len :]\n\n            # forward the model to get the logits for the index in the sequence\n            logits = transformer_model(x=idx_cond, context=conditioning)\n            # pluck the logits at the final step and scale by desired temperature\n            logits = logits[:, -1, :] / temperature\n            # optionally crop the logits to only the top k options\n            if top_k is not None:\n                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n                logits[logits < v[:, [-1]]] = -float(\"Inf\")\n            # apply softmax to convert logits to (normalized) probabilities\n            probs = F.softmax(logits, dim=-1)\n            # remove the chance to be sampled the BOS token\n            probs[:, vqvae_model.num_embeddings] = 0\n            # sample from the distribution\n            idx_next = torch.multinomial(probs, num_samples=1)\n            # append sampled index to the running sequence and continue\n            latent_seq = torch.cat((latent_seq, idx_next), dim=1)\n\n        latent_seq 
= latent_seq[:, 1:]\n        latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()]\n        latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim)\n\n        return vqvae_model.decode_samples(latent)\n\n    @torch.no_grad()\n    def get_likelihood(\n        self,\n        inputs: torch.Tensor,\n        vqvae_model: VQVAE,\n        transformer_model: DecoderOnlyTransformer,\n        ordering: Ordering,\n        condition: torch.Tensor | None = None,\n        resample_latent_likelihoods: bool = False,\n        resample_interpolation_mode: str = \"nearest\",\n        verbose: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"\n        Computes the log-likelihoods of the latent representations of the input.\n\n        Args:\n            inputs: input images, NxCxHxW[xD]\n            vqvae_model: first stage model.\n            transformer_model: autoregressive transformer model.\n            ordering: ordering of the quantised latent representation.\n            condition: conditioning for network input.\n            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial\n                dimension as the input images.\n            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',\n                or 'trilinear;\n            verbose: if true, prints the progression bar of the sampling process.\n\n        \"\"\"\n        if resample_latent_likelihoods and resample_interpolation_mode not in (\"nearest\", \"bilinear\", \"trilinear\"):\n            raise ValueError(\n                f\"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}\"\n            )\n\n        with torch.no_grad():\n            latent = vqvae_model.index_quantize(inputs)\n\n        latent_spatial_dim = tuple(latent.shape[1:])\n        latent = latent.reshape(latent.shape[0], -1)\n        latent = 
latent[:, ordering.get_sequence_ordering()]\n        seq_len = math.prod(latent_spatial_dim)\n\n        # Use the value from vqvae_model's num_embeddings as the starting token, the \"Begin Of Sentence\" (BOS) token.\n        # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.\n        latent = F.pad(latent, (1, 0), \"constant\", vqvae_model.num_embeddings)\n        latent = latent.long()\n\n        # get the first batch, up to max_seq_length, efficiently\n        logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition)\n        probs = F.softmax(logits, dim=-1)\n        # target token for each set of logits is the next token along\n        target = latent[:, 1:]\n        probs = torch.gather(probs, 2, target[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2)\n\n        # if we have not covered the full sequence we continue with inefficient looping\n        if probs.shape[1] < target.shape[1]:\n            if verbose and has_tqdm:\n                progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len))\n            else:\n                progress_bar = iter(range(transformer_model.max_seq_len, seq_len))\n\n            for i in progress_bar:\n                idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1]\n                # forward the model to get the logits for the index in the sequence\n                logits = transformer_model(x=idx_cond, context=condition)\n                # pluck the logits at the final step\n                logits = logits[:, -1, :]\n                # apply softmax to convert logits to (normalized) probabilities\n                p = F.softmax(logits, dim=-1)\n                # select correct values and append\n                p = torch.gather(p, 1, target[:, i].unsqueeze(1))\n\n                probs = torch.cat((probs, p), dim=1)\n\n        # convert to log-likelihood\n        probs = torch.log(probs)\n\n        # reshape\n     
   probs = probs[:, ordering.get_revert_sequence_ordering()]\n        probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim)\n        if resample_latent_likelihoods:\n            resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)\n            probs_reshaped = resizer(probs_reshaped[:, None, ...])\n\n        return probs_reshaped\n"
  },
  {
    "path": "monai/inferers/merger.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport threading\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Sequence\nfrom contextlib import nullcontext\nfrom tempfile import TemporaryDirectory\nfrom typing import TYPE_CHECKING, Any\n\nimport numpy as np\nimport torch\n\nfrom monai.utils import (\n    deprecated_arg,\n    ensure_tuple_size,\n    get_package_version,\n    optional_import,\n    require_pkg,\n    version_geq,\n)\n\nif TYPE_CHECKING:\n    import zarr\nelse:\n    zarr, _ = optional_import(\"zarr\")\n\n__all__ = [\"Merger\", \"AvgMerger\", \"ZarrAvgMerger\"]\n\n\nclass Merger(ABC):\n    \"\"\"\n    A base class for merging patches.\n    Extend this class to support operations for `PatchInference`.\n    There are two methods that must be implemented in the concrete classes:\n\n        - aggregate: aggregate the values at their corresponding locations\n        - finalize: perform any final process and return the merged output\n\n    Args:\n        merged_shape: the shape of the tensor required to merge the patches.\n        cropped_shape: the shape of the final merged output tensor.\n            If not provided, it will be the same as `merged_shape`.\n        device: the device where Merger tensors should reside.\n    \"\"\"\n\n    def __init__(\n        self,\n        merged_shape: Sequence[int],\n        cropped_shape: Sequence[int] | None = 
None,\n        device: torch.device | str | None = None,\n    ) -> None:\n        if merged_shape is None:\n            raise ValueError(\"Argument `merged_shape` must be provided\")\n\n        self.merged_shape: tuple[int, ...] = tuple(merged_shape)\n        self.cropped_shape: tuple[int, ...] = tuple(self.merged_shape if cropped_shape is None else cropped_shape)\n        self.device = device\n        self.is_finalized = False\n\n    @abstractmethod\n    def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None:\n        \"\"\"\n        Aggregate values for merging.\n        This method is being called in a loop and should add values to their corresponding location in the merged output results.\n\n        Args:\n            values: a tensor of shape BCHW[D], representing the values of inference output.\n            location: a tuple/list giving the top left location of the patch in the output.\n\n        Raises:\n            NotImplementedError: When the subclass does not override this method.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    @abstractmethod\n    def finalize(self) -> Any:\n        \"\"\"\n        Perform final operations for merging patches and return the final merged output.\n\n        Returns:\n            The results of merged patches, which is commonly a torch.Tensor representing the merged result, or\n                a string representing the filepath to the merged results on disk.\n\n        Raises:\n            NotImplementedError: When the subclass does not override this method.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n\nclass AvgMerger(Merger):\n    \"\"\"Merge patches by taking average of the overlapping area\n\n    Args:\n        merged_shape: the shape of the tensor required to merge the patches.\n        cropped_shape: the shape of the final merged output 
tensor.\n            If not provided, it will be the same as `merged_shape`.\n        device: the device for aggregator tensors and final results.\n        value_dtype: the dtype for value aggregating tensor and the final result.\n        count_dtype: the dtype for sample counting tensor.\n    \"\"\"\n\n    def __init__(\n        self,\n        merged_shape: Sequence[int],\n        cropped_shape: Sequence[int] | None = None,\n        value_dtype: torch.dtype = torch.float32,\n        count_dtype: torch.dtype = torch.uint8,\n        device: torch.device | str = \"cpu\",\n    ) -> None:\n        super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape, device=device)\n        if not self.merged_shape:\n            raise ValueError(f\"`merged_shape` must be provided for `AvgMerger`. {self.merged_shape} is give.\")\n        self.value_dtype = value_dtype\n        self.count_dtype = count_dtype\n        self.values = torch.zeros(self.merged_shape, dtype=self.value_dtype, device=self.device)\n        self.counts = torch.zeros(self.merged_shape, dtype=self.count_dtype, device=self.device)\n\n    def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None:\n        \"\"\"\n        Aggregate values for merging.\n\n        Args:\n            values: a tensor of shape BCHW[D], representing the values of inference output.\n            location: a tuple/list giving the top left location of the patch in the original image.\n\n        Raises:\n            NotImplementedError: When the subclass does not override this method.\n\n        \"\"\"\n        if self.is_finalized:\n            raise ValueError(\"`AvgMerger` is already finalized. 
Please instantiate a new object to aggregate.\")\n        patch_size = values.shape[2:]\n        map_slice = tuple(slice(loc, loc + size) for loc, size in zip(location, patch_size))\n        map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)\n        self.values[map_slice] += values\n        self.counts[map_slice] += 1\n\n    def finalize(self) -> torch.Tensor:\n        \"\"\"\n        Finalize merging by dividing values by counts and return the merged tensor.\n\n        Notes:\n            To avoid creating a new tensor for the final results (to save memory space),\n            after this method is called, `get_values()` method will return the \"final\" averaged values,\n            and not the accumulating values. Also calling `finalize()` multiple times does not have any effect.\n\n        Returns:\n            torch.tensor: a tensor of merged patches\n        \"\"\"\n        # guard against multiple call to finalize\n        if not self.is_finalized:\n            # use in-place division to save space\n            self.values.div_(self.counts)\n            # finalize the shape\n            self.values = self.values[tuple(slice(0, end) for end in self.cropped_shape)]\n            # set finalize flag to protect performing in-place division again\n            self.is_finalized = True\n\n        return self.values\n\n    def get_output(self) -> torch.Tensor:\n        \"\"\"\n        Get the final merged output.\n\n        Returns:\n            torch.Tensor: merged output.\n        \"\"\"\n        return self.finalize()\n\n    def get_values(self) -> torch.Tensor:\n        \"\"\"\n        Get the accumulated values during aggregation or final averaged values after it is finalized.\n\n        Returns:\n            torch.tensor: aggregated values.\n\n        Notes:\n            - If called before calling `finalize()`, this method returns the accumulating values.\n            - If called after calling `finalize()`, this 
method returns the final merged [and averaged] values.\n        \"\"\"\n        return self.values\n\n    def get_counts(self) -> torch.Tensor:\n        \"\"\"\n        Get the aggregator tensor for number of samples.\n\n        Returns:\n            torch.Tensor: number of accumulated samples at each location.\n        \"\"\"\n        return self.counts\n\n\n@require_pkg(pkg_name=\"zarr\")\nclass ZarrAvgMerger(Merger):\n    \"\"\"Merge patches by taking average of the overlapping area and store the results in zarr array.\n\n    Zarr is a format for the storage of chunked, compressed, N-dimensional arrays.\n    Zarr data can be stored in any storage system that can be represented as a key-value store,\n    like POSIX file systems, cloud object storage, zip files, and relational and document databases.\n    See https://zarr.readthedocs.io/en/stable/ for more details.\n    It is particularly useful for storing N-dimensional arrays too large to fit into memory.\n    One specific use case of this class is to merge patches extracted from whole slide images (WSI),\n    where the merged results do not fit into memory and need to be stored on a file system.\n\n    Args:\n        merged_shape: the shape of the tensor required to merge the patches.\n        cropped_shape: the shape of the final merged output tensor.\n            If not provided, it will be the same as `merged_shape`.\n        dtype: the dtype for the final merged result. Default is `float32`.\n        value_dtype: the dtype for value aggregating tensor and the final result. Default is `float32`.\n        count_dtype: the dtype for sample counting tensor. Default is `uint8`.\n        store: the zarr store to save the final results. Default is \"merged.zarr\".\n        value_store: the zarr store to save the value aggregating tensor. Default is a temporary store.\n        count_store: the zarr store to save the sample counting tensor. 
Default is a temporary store.\n        compressor: the compressor for final merged zarr array. Default is None.\n            Deprecated since 1.5.0 and will be removed in 1.7.0. Use codecs instead.\n        value_compressor: the compressor for value aggregating zarr array. Default is None.\n            Deprecated since 1.5.0 and will be removed in 1.7.0. Use value_codecs instead.\n        count_compressor: the compressor for sample counting zarr array. Default is None.\n            Deprecated since 1.5.0 and will be removed in 1.7.0. Use count_codecs instead.\n        codecs: the codecs for final merged zarr array. Default is None.\n            For zarr v3, this is a list of codec configurations. See zarr documentation for details.\n        value_codecs: the codecs for value aggregating zarr array. Default is None.\n            For zarr v3, this is a list of codec configurations. See zarr documentation for details.\n        count_codecs: the codecs for sample counting zarr array. Default is None.\n            For zarr v3, this is a list of codec configurations. See zarr documentation for details.\n        chunks : int or tuple of ints that defines the chunk shape, or boolean. 
Default is True.\n            If True, chunk shape will be guessed from `shape` and `dtype`.\n            If False, it will be set to `shape`, i.e., single chunk for the whole array.\n            If an int, the chunk size in each dimension will be given by the value of `chunks`.\n    \"\"\"\n\n    @deprecated_arg(\n        name=\"compressor\", since=\"1.5.0\", removed=\"1.7.0\", new_name=\"codecs\", msg_suffix=\"Please use 'codecs' instead.\"\n    )\n    @deprecated_arg(\n        name=\"value_compressor\",\n        since=\"1.5.0\",\n        removed=\"1.7.0\",\n        new_name=\"value_codecs\",\n        msg_suffix=\"Please use 'value_codecs' instead.\",\n    )\n    @deprecated_arg(\n        name=\"count_compressor\",\n        since=\"1.5.0\",\n        removed=\"1.7.0\",\n        new_name=\"count_codecs\",\n        msg_suffix=\"Please use 'count_codecs' instead.\",\n    )\n    def __init__(\n        self,\n        merged_shape: Sequence[int],\n        cropped_shape: Sequence[int] | None = None,\n        dtype: np.dtype | str = \"float32\",\n        value_dtype: np.dtype | str = \"float32\",\n        count_dtype: np.dtype | str = \"uint8\",\n        store: zarr.storage.Store | str = \"merged.zarr\",  # type: ignore\n        value_store: zarr.storage.Store | str | None = None,  # type: ignore\n        count_store: zarr.storage.Store | str | None = None,  # type: ignore\n        compressor: str | None = None,\n        value_compressor: str | None = None,\n        count_compressor: str | None = None,\n        codecs: list | None = None,\n        value_codecs: list | None = None,\n        count_codecs: list | None = None,\n        chunks: Sequence[int] | bool = True,\n        thread_locking: bool = True,\n    ) -> None:\n        super().__init__(merged_shape=merged_shape, cropped_shape=cropped_shape)\n        if not self.merged_shape:\n            raise ValueError(f\"`merged_shape` must be provided for `ZarrAvgMerger`. 
{self.merged_shape} is give.\")\n        self.output_dtype = dtype\n        self.value_dtype = value_dtype\n        self.count_dtype = count_dtype\n        self.store = store\n        self.tmpdir: TemporaryDirectory | None\n\n        # Handle zarr v3 vs older versions\n        is_zarr_v3 = version_geq(get_package_version(\"zarr\"), \"3.0.0\")\n\n        if is_zarr_v3:\n            if value_store is None:\n                self.tmpdir = TemporaryDirectory()\n                self.value_store = zarr.storage.LocalStore(self.tmpdir.name)  # type: ignore\n            else:\n                self.value_store = value_store  # type: ignore\n            if count_store is None:\n                self.tmpdir = TemporaryDirectory()\n                self.count_store = zarr.storage.LocalStore(self.tmpdir.name)  # type: ignore\n            else:\n                self.count_store = count_store  # type: ignore\n        else:\n            self.tmpdir = None\n            self.value_store = zarr.storage.TempStore() if value_store is None else value_store  # type: ignore\n            self.count_store = zarr.storage.TempStore() if count_store is None else count_store  # type: ignore\n\n        self.chunks = chunks\n\n        # Handle compressor/codecs based on zarr version\n        is_zarr_v3 = version_geq(get_package_version(\"zarr\"), \"3.0.0\")\n\n        # Initialize codecs/compressor attributes with proper types\n        self.codecs: list | None = None\n        self.value_codecs: list | None = None\n        self.count_codecs: list | None = None\n\n        if is_zarr_v3:\n            # For zarr v3, use codecs or convert compressor to codecs\n            if codecs is not None:\n                self.codecs = codecs\n            elif compressor is not None:\n                # Convert compressor to codec format\n                if isinstance(compressor, (list, tuple)):\n                    self.codecs = compressor\n                else:\n                    self.codecs = [compressor]\n      
      else:\n                self.codecs = None\n\n            if value_codecs is not None:\n                self.value_codecs = value_codecs\n            elif value_compressor is not None:\n                if isinstance(value_compressor, (list, tuple)):\n                    self.value_codecs = value_compressor\n                else:\n                    self.value_codecs = [value_compressor]\n            else:\n                self.value_codecs = None\n\n            if count_codecs is not None:\n                self.count_codecs = count_codecs\n            elif count_compressor is not None:\n                if isinstance(count_compressor, (list, tuple)):\n                    self.count_codecs = count_compressor\n                else:\n                    self.count_codecs = [count_compressor]\n            else:\n                self.count_codecs = None\n        else:\n            # For zarr v2, use compressors\n            if codecs is not None:\n                # If codecs are specified in v2, use the first codec as compressor\n                self.codecs = codecs[0] if isinstance(codecs, (list, tuple)) else codecs\n            else:\n                self.codecs = compressor  # type: ignore[assignment]\n\n            if value_codecs is not None:\n                self.value_codecs = value_codecs[0] if isinstance(value_codecs, (list, tuple)) else value_codecs\n            else:\n                self.value_codecs = value_compressor  # type: ignore[assignment]\n\n            if count_codecs is not None:\n                self.count_codecs = count_codecs[0] if isinstance(count_codecs, (list, tuple)) else count_codecs\n            else:\n                self.count_codecs = count_compressor  # type: ignore[assignment]\n\n        # Create zarr arrays with appropriate parameters based on version\n        if is_zarr_v3:\n            self.output = zarr.empty(\n                shape=self.merged_shape,\n                chunks=self.chunks,\n                
dtype=self.output_dtype,\n                codecs=self.codecs,\n                store=self.store,\n                overwrite=True,\n            )\n            self.values = zarr.zeros(\n                shape=self.merged_shape,\n                chunks=self.chunks,\n                dtype=self.value_dtype,\n                codecs=self.value_codecs,\n                store=self.value_store,\n                overwrite=True,\n            )\n            self.counts = zarr.zeros(\n                shape=self.merged_shape,\n                chunks=self.chunks,\n                dtype=self.count_dtype,\n                codecs=self.count_codecs,\n                store=self.count_store,\n                overwrite=True,\n            )\n        else:\n            self.output = zarr.empty(\n                shape=self.merged_shape,\n                chunks=self.chunks,\n                dtype=self.output_dtype,\n                compressor=self.codecs,\n                store=self.store,\n                overwrite=True,\n            )\n            self.values = zarr.zeros(\n                shape=self.merged_shape,\n                chunks=self.chunks,\n                dtype=self.value_dtype,\n                compressor=self.value_codecs,\n                store=self.value_store,\n                overwrite=True,\n            )\n            self.counts = zarr.zeros(\n                shape=self.merged_shape,\n                chunks=self.chunks,\n                dtype=self.count_dtype,\n                compressor=self.count_codecs,\n                store=self.count_store,\n                overwrite=True,\n            )\n\n        self.lock: threading.Lock | nullcontext\n        if thread_locking:\n            # use lock to protect the in-place addition during aggregation\n            self.lock = threading.Lock()\n        else:\n            # use nullcontext to avoid the locking if not needed\n            self.lock = nullcontext()\n\n    def aggregate(self, values: torch.Tensor, location: 
Sequence[int]) -> None:\n        \"\"\"\n        Aggregate values for merging.\n\n        Args:\n            values: a tensor of shape BCHW[D], representing the values of inference output.\n            location: a tuple/list giving the top left location of the patch in the original image.\n        \"\"\"\n        if self.is_finalized:\n            raise ValueError(\"`ZarrAvgMerger` is already finalized. Please instantiate a new object to aggregate.\")\n        patch_size = values.shape[2:]\n        map_slice = tuple(slice(loc, loc + size) for loc, size in zip(location, patch_size))\n        map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True)\n        with self.lock:\n            self.values[map_slice] += values.numpy()\n            self.counts[map_slice] += 1  # type: ignore[operator]\n\n    def finalize(self) -> zarr.Array:\n        \"\"\"\n        Finalize merging by dividing values by counts and return the merged tensor.\n\n        Notes:\n            To avoid creating a new tensor for the final results (to save memory space),\n            after this method is called, `get_values()` method will return the \"final\" averaged values,\n            and not the accumulating values. 
Also calling `finalize()` multiple times does not have any effect.\n\n        Returns:\n            zarr.Array: a zarr array of of merged patches\n        \"\"\"\n        # guard against multiple calls to finalize\n        if not self.is_finalized:\n            # use chunks for division to fit into memory\n            for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape):\n                self.output[chunk] = self.values[chunk] / self.counts[chunk]  # type: ignore[operator]\n            # finalize the shape\n            self.output.resize(self.cropped_shape)\n            # set finalize flag to protect performing in-place division again\n            self.is_finalized = True\n\n        return self.output\n\n    def get_output(self) -> zarr.Array:\n        \"\"\"\n        Get the final merged output.\n\n        Returns:\n            zarr.Array: Merged (averaged) output tensor.\n        \"\"\"\n        return self.output\n\n    def get_values(self) -> zarr.Array:\n        \"\"\"\n        Get the accumulated values during aggregation\n\n        Returns:\n            zarr.Array: aggregated values.\n\n        \"\"\"\n        return self.values\n\n    def get_counts(self) -> zarr.Array:\n        \"\"\"\n        Get the aggregator tensor for number of samples.\n\n        Returns:\n            zarr.Array: Number of accumulated samples at each location.\n        \"\"\"\n        return self.counts\n\n\ndef iterate_over_chunks(chunks, cdata_shape, slice_tuple=()):\n    \"\"\"\n    Iterate over chunks of a given shape.\n\n    Args:\n        chunks: the chunk shape\n        cdata_shape: the shape of the data in chunks\n        slice_tuple: the slice tuple to be used for indexing\n\n    Raises:\n        ValueError: When the length of chunks and cdata_shape are not the same.\n\n    Yields:\n        slices of the data\n    \"\"\"\n    if len(chunks) != len(cdata_shape):\n        raise ValueError(\"chunks and cdata_shape must have the same length\")\n    if 
len(chunks) == 1:\n        for i in range(cdata_shape[0]):\n            yield slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),)\n    else:\n        for i in range(cdata_shape[0]):\n            yield from iterate_over_chunks(\n                chunks[1:], cdata_shape[1:], slice_tuple + (slice(i * chunks[0], (i + 1) * chunks[0]),)\n            )\n"
  },
  {
    "path": "monai/inferers/splitter.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport warnings\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Callable, Iterable, Sequence\nfrom inspect import _empty, isclass, signature\nfrom typing import Any\n\nimport torch\n\nfrom monai.data.utils import iter_patch_position\nfrom monai.data.wsi_reader import BaseWSIReader, WSIReader\nfrom monai.transforms.utility.array import convert_to_tensor\nfrom monai.utils.misc import PathLike, ensure_tuple, ensure_tuple_rep\n\n__all__ = [\"Splitter\", \"SlidingWindowSplitter\", \"WSISlidingWindowSplitter\"]\n\n\nclass Splitter(ABC):\n    \"\"\"\n    A base class for splitting the inputs into iterable tuple of patches and locations\n    Extend this class to support operations for `PatchInference`, e.g. 
SlidingPatchSplitter.\n\n    Args:\n        patch_size: the size of patches to be generated.\n        device: the device where the patches are generated.\n    \"\"\"\n\n    def __init__(self, patch_size: Sequence[int] | int, device: torch.device | str | None = None) -> None:\n        self.patch_size = patch_size\n        self.device = device\n\n    @abstractmethod\n    def get_input_shape(self, inputs: Any) -> tuple:\n        \"\"\"\n        Return the input spatial shape.\n\n        Args:\n            inputs: either a tensor of shape BCHW[D], representing a batch of images,\n                or a filename (str) or list of filenames to the image(s).\n\n        Raises:\n            NotImplementedError: When the subclass does not override this method.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    @abstractmethod\n    def get_padded_shape(self, inputs: Any) -> tuple:\n        \"\"\"\n        Return the actual spatial shape covered by the output split patches.\n        For instance, if the input image is padded, the actual spatial shape will be enlarged\n        and not the same as input spatial shape.\n\n        Args:\n            inputs: either a tensor of shape BCHW[D], representing a batch of images,\n                or a filename (str) or list of filenames to the image(s).\n\n        Raises:\n            NotImplementedError: When the subclass does not override this method.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    @abstractmethod\n    def __call__(self, inputs: Any) -> Iterable[tuple[torch.Tensor, Sequence[int]]]:\n        \"\"\"\n        Split the input image (or batch of images) into patches and return pairs of (patch, location).\n        Where location is the coordinate of top left [front] corner of a patch.\n\n        Args:\n            inputs: either a tensor of shape BCHW[D], representing a 
batch of images,\n                or a filename (str) or list of filenames to the image(s).\n\n        Raises:\n            NotImplementedError: When the subclass does not override this method.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n\nclass SlidingWindowSplitter(Splitter):\n    \"\"\"\n    Splits the input into patches with sliding window strategy and a possible overlap.\n    It also allows offsetting the starting position and filtering the patches.\n\n    Args:\n        patch_size : the size of the patches to be generated.\n        offset: the amount of offset for the patches with respect to the original input.  Defaults to 0.\n        overlap: the amount of overlap between patches in each dimension. It can be either a float in\n            the range of [0.0, 1.0) that defines relative overlap to the patch size, or it can be a non-negative int\n            that defines number of pixels for overlap. Defaults to 0.0.\n        filter_fn: a callable to filter patches. It should accepts exactly two parameters (patch, location), and\n            return True for a patch to keep. Defaults to no filtering.\n        pad_mode: string define the mode for `torch.nn.functional.pad`. The acceptable values are\n            `\"constant\"`, `\"reflect\"`, `\"replicate\"`, `\"circular\"` or `None`. Default to `\"constant\"`.\n            If None, no padding will be applied, so it will drop the patches crossing the border of\n            the image (either when the offset is negative or the image is non-divisible by the patch_size).\n        pad_value: the value for `\"constant\"` padding. Defaults to 0.\n        device: the device where the patches are generated. 
Defaults to the device of inputs.\n\n    Note:\n        When a scaler value is provided for `patch_size`, `offset`, or `overlap`,\n            it is broadcasted to all the spatial dimensions.\n    \"\"\"\n\n    def __init__(\n        self,\n        patch_size: Sequence[int] | int,\n        overlap: Sequence[float] | float | Sequence[int] | int = 0.0,\n        offset: Sequence[int] | int = 0,\n        filter_fn: Callable | None = None,\n        pad_mode: str | None = \"constant\",\n        pad_value: float | int = 0,\n        device: torch.device | str | None = None,\n    ) -> None:\n        super().__init__(patch_size=patch_size, device=device)\n        self.offset = offset\n        # check if fraction overlaps are within the range of [0, 1)\n        if isinstance(ensure_tuple(overlap)[0], float) and any(ov < 0.0 or ov >= 1.0 for ov in ensure_tuple(overlap)):\n            raise ValueError(\n                f\"Relative overlap must be between 0.0 and 1.0 but {overlap} is given. \"\n                \"If you wish to use number of pixels as overlap, please provide integer numbers.\"\n            )\n        elif any(ov < 0 for ov in ensure_tuple(overlap)):\n            raise ValueError(f\"Number of pixels for overlap cannot be negative. {overlap} is given. 
\")\n\n        self.overlap = overlap\n        self.filter_fn = self._validate_filter_fn(filter_fn)\n        # padding\n        self.pad_mode = pad_mode\n        self.pad_value = pad_value\n        # check a valid padding mode is provided if there is any negative offset.\n        if not self.pad_mode and any(off < 0 for off in ensure_tuple(offset)):\n            raise ValueError(f\"Negative `offset`requires a valid padding mode but `mode` is set to {self.pad_mode}.\")\n\n    @staticmethod\n    def _validate_filter_fn(filter_fn):\n        if callable(filter_fn):\n            sig = signature(filter_fn)\n            n_params = len(sig.parameters)\n            num_pos_params = len([v for v in sig.parameters.values() if v.default is _empty])\n            if n_params < 2:\n                raise ValueError(\n                    f\"`filter_fn` requires to accept at least two parameters (patch, location).\"\n                    f\"The provided callable ({filter_fn}) has {n_params} parameters.\"\n                )\n            elif num_pos_params > 2:\n                raise ValueError(\n                    f\"`filter_fn` can have at most two positional parameters (patch, location).\"\n                    f\"The provided callable ({filter_fn}) has {num_pos_params} positional parameters.\"\n                )\n        elif filter_fn is not None:\n            raise ValueError(\n                \"`filter_fn` should be a callable with two input parameters (patch, location). 
\"\n                f\"{type(filter_fn)} is given.\"\n            )\n        return filter_fn\n\n    def _calculate_pad_size(self, spatial_shape, spatial_ndim, patch_size, offset, overlap):\n        # initialize with zero\n        pad_size = [0] * 2 * spatial_ndim\n        if not self.pad_mode:\n            return pad_size, False\n        # set the starting pad size only if the offset is negative\n        pad_size[1::2] = (-min(off, 0) for off in offset)\n        # set the ending pad size only if it is not divisible by the patch size\n        end_padding = []\n        for sh, off, ps, ov in zip(spatial_shape, offset, patch_size, overlap):\n            if ps == 0:\n                pad_amount = 0\n            else:\n                if isinstance(ov, float):\n                    pad_amount = (off - sh + ps) % round(ps - (ps * ov))\n                else:\n                    pad_amount = (off - sh + ps) % round(ps - ov)\n            end_padding.append(pad_amount)\n\n        pad_size[::2] = end_padding\n        return pad_size, any(pad_size[1::2])\n\n    def _get_valid_shape_parameters(\n        self, spatial_shape: Sequence[int]\n    ) -> tuple[tuple[int, ...], tuple[float, ...] 
| tuple[int, ...], tuple[int, ...]]:\n        spatial_ndim = len(spatial_shape)\n        # patch_size\n        patch_size = ensure_tuple_rep(self.patch_size, spatial_ndim)\n        # overlap\n        overlap = ensure_tuple_rep(self.overlap, spatial_ndim)\n        overlap = tuple(o if p else type(overlap[0])(0) for o, p in zip(overlap, patch_size))\n        if any(ov > ps for ov, ps in zip(overlap, patch_size)):\n            raise ValueError(f\"`overlap` ({overlap}) cannot be larger than patch size ({patch_size}).\")\n        # offset\n        offset = ensure_tuple_rep(self.offset, spatial_ndim)\n        for off, ps, sh in zip(offset, patch_size, spatial_shape):\n            if off < -ps:\n                raise ValueError(f\"Negative `offset` ({off}) cannot be larger than `patch_size` ({ps}) in magnitude.\")\n            if off >= sh:\n                raise ValueError(f\"`offset` ({off}) cannot be larger than inputs size ({sh}).\")\n        return patch_size, overlap, offset\n\n    def _get_patch(self, inputs: Any, location: tuple[int, ...], patch_size: tuple[int, ...]) -> Any:\n        slices = (slice(None),) * 2 + tuple(slice(loc, loc + ps) for loc, ps in zip(location, patch_size))\n        return inputs[slices]\n\n    def get_input_shape(self, inputs: Any) -> tuple:\n        \"\"\"\n        Return the input spatial shape.\n\n        Args:\n            inputs: either a tensor of shape BCHW[D], representing a batch of images,\n                or a filename (str) or list of filenames to the image(s).\n\n        Returns:\n            spatial_shape\n        \"\"\"\n        return tuple(inputs.shape[2:])\n\n    def get_padded_shape(self, inputs: Any) -> tuple:\n        \"\"\"\n        Return the actual spatial shape covered by the output split patches.\n        For instance, if the input image is padded, the actual spatial shape will be enlarged\n        and not the same as input spatial shape.\n\n        Args:\n            inputs: either a tensor of shape BCHW[D], 
representing a batch of images,\n                or a filename (str) or list of filenames to the image(s).\n\n        Returns:\n            padded_spatial_shape\n\n        \"\"\"\n        spatial_shape = self.get_input_shape(inputs)\n        if not self.pad_mode:\n            return spatial_shape\n        spatial_ndim = len(spatial_shape)\n        patch_size, overlap, offset = self._get_valid_shape_parameters(spatial_shape)\n        pad_size, _ = self._calculate_pad_size(spatial_shape, spatial_ndim, patch_size, offset, overlap)\n        padded_spatial_shape = tuple(ss + ps + pe for ss, ps, pe in zip(spatial_shape, pad_size[1::2], pad_size[::2]))\n\n        return padded_spatial_shape\n\n    def __call__(self, inputs: Any) -> Iterable[tuple[torch.Tensor, Sequence[int]]]:\n        \"\"\"Split the input tensor into patches and return patches and locations.\n\n        Args:\n            inputs: either a torch.Tensor with BCHW[D] dimensions, representing an image or a batch of images\n\n        Yields:\n            tuple[torch.Tensor, Sequence[int]]: yields tuple of patch and location\n        \"\"\"\n\n        if not isinstance(inputs, torch.Tensor):\n            raise ValueError(f\"The input should be a tensor. 
{type(inputs)} is given.\")\n\n        spatial_shape = inputs.shape[2:]\n        spatial_ndim = len(spatial_shape)\n        patch_size, overlap, offset = self._get_valid_shape_parameters(spatial_shape)\n        pad_size, is_start_padded = self._calculate_pad_size(spatial_shape, spatial_ndim, patch_size, offset, overlap)\n\n        # Padding\n        if self.pad_mode and any(pad_size):\n            # pad the inputs\n            inputs = torch.nn.functional.pad(inputs, pad_size[::-1], mode=self.pad_mode, value=self.pad_value)\n            # update spatial shape\n            spatial_shape = inputs.shape[2:]\n            # correct the offset with respect to the padded image\n            if is_start_padded:\n                offset = tuple(off + p for off, p in zip(offset, pad_size[1::2]))\n\n        # Splitting\n        for location in iter_patch_position(spatial_shape, patch_size, offset, overlap, False):\n            patch = self._get_patch(inputs, location, patch_size)\n            patch = convert_to_tensor(patch, device=self.device)\n            # correct the location with respect to original inputs (remove starting pads)\n            if is_start_padded:\n                location = tuple(loc - p for loc, p in zip(location, pad_size[1::2]))\n            # filter patch and yield\n            if self.filter_fn is None or self.filter_fn(patch, location):\n                yield patch, location\n\n\nclass WSISlidingWindowSplitter(SlidingWindowSplitter):\n    \"\"\"\n    Splits the whole slide image input into patches with sliding window strategy and a possible overlap.\n    This extracts patches from file without loading the entire slide into memory.\n    It also allows offsetting the starting position and filtering the patches.\n\n    Args:\n        patch_size : the size of the patches to be generated.\n        offset: the amount of offset for the patches with respect to the original input.  
Defaults to 0.\n        overlap: the amount of overlap between patches in each dimension. It can be either a float in\n            the range of [0.0, 1.0) that defines relative overlap to the patch size, or it can be a non-negative int\n            that defines number of pixels for overlap. Defaults to 0.0.\n        filter_fn: a callable to filter patches. It should accepts exactly two parameters (patch, location), and\n            return True for a patch to keep. Defaults to no filtering.\n        pad_mode: define the mode for padding. Either \"constant\" or None. Default to \"constant\".\n            Padding is only supported with \"OpenSlide\" or \"cuCIM\" backend, and the filling value is 256.\n        device: the device where the patches are generated. Defaults to the device of inputs.\n        reader: the module to be used for loading whole slide imaging. If `reader` is\n\n            - a string, it defines the backend of `monai.data.WSIReader`. Defaults to \"OpenSlide\".\n            - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.\n            - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.\n\n            To obtain an optimized performance please use either \"cuCIM\" or \"OpenSlide\" backend.\n        reader_kwargs: the arguments to pass to `WSIReader` or the provided whole slide reader class.\n            For instance, level=2, dtype=torch.float32, etc.\n            Note that if `level` is not provided, `level=0` is assumed.\n\n    Note:\n        When a scaler value is provided for `patch_size`, `offset`, or `overlap`,\n        it is broadcasted to all the spatial dimensions.\n    \"\"\"\n\n    def __init__(\n        self,\n        patch_size: Sequence[int] | int,\n        overlap: Sequence[float] | float | Sequence[int] | int = 0.0,\n        offset: Sequence[int] | int = 0,\n        filter_fn: Callable | None = None,\n        pad_mode: str | None = \"constant\",\n        device: 
torch.device | str | None = None,\n        reader: str | BaseWSIReader | type[BaseWSIReader] | None = \"OpenSlide\",\n        **reader_kwargs: dict,\n    ) -> None:\n        if pad_mode and pad_mode != \"constant\":\n            raise ValueError(\n                f\"The underlying wsi readers only support for constant padding. pad_mod={pad_mode} is given.\"\n            )\n\n        super().__init__(\n            patch_size=patch_size, overlap=overlap, offset=offset, filter_fn=filter_fn, device=device, pad_mode=pad_mode\n        )\n        # Set WSI reader\n        self._set_reader(reader, reader_kwargs)\n        if self.reader.backend.lower() not in [\"openslide\", \"cucim\"]:\n            warnings.warn(\n                f\"WSIReader with {self.reader.backend.lower()} backend is not supported for efficiently loading patches. \"\n                \"This may cause an significant slow down and a large memory foot print. \"\n                \"Please use other backends such as 'OpenSlide' or 'cuCIM' instead.\"\n            )\n\n    def _set_reader(self, reader: str | BaseWSIReader | type[BaseWSIReader] | None, reader_kwargs: dict) -> None:\n        \"\"\"\n        Set the WSI reader object based on the input reader\n\n        Args:\n            reader: the module to be used for loading whole slide imaging. If `reader` is\n\n                - a string, it defines the backend of `monai.data.WSIReader`. 
Defaults to cuCIM.\n                - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.\n                - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.\n        \"\"\"\n        self.reader: WSIReader | BaseWSIReader\n        self.reader_kwargs = reader_kwargs\n        if isinstance(reader, str):\n            self.reader = WSIReader(backend=reader, **self.reader_kwargs)\n        elif isclass(reader) and issubclass(reader, BaseWSIReader):\n            self.reader = reader(**self.reader_kwargs)\n        elif isinstance(reader, BaseWSIReader):\n            self.reader = reader\n        else:\n            raise ValueError(f\"Unsupported reader type: {reader}.\")\n\n    def _get_patch(self, inputs: Any, location: tuple[int, ...], patch_size: tuple[int, ...]) -> Any:\n        patch, _ = self.reader.get_data(wsi=inputs, location=location, size=patch_size)  # type: ignore\n        return patch[None]\n\n    def get_input_shape(self, inputs: Any) -> tuple:\n        \"\"\"\n        Return the input spatial shape.\n\n        Args:\n            inputs: either a tensor of shape BCHW[D], representing a batch of images,\n                or a filename (str) or list of filenames to the image(s).\n\n        Returns:\n            spatial_shape\n\n        \"\"\"\n        wsi = self.reader.read(inputs)\n        level = self.reader_kwargs.get(\"level\", 0)\n        return self.reader.get_size(wsi, level)\n\n    def __call__(self, inputs: PathLike | Sequence[PathLike]) -> Iterable[tuple[torch.Tensor, Sequence[int]]]:\n        \"\"\"Split the input tensor into patches and return patches and locations.\n\n        Args:\n            inputs: the file path to a whole slide image.\n\n        Yields:\n            tuple[torch.Tensor, Sequence[int]]: yields tuple of patch and location\n        \"\"\"\n        # Handle if the input file paths are batched\n        if not isinstance(inputs, str) and isinstance(inputs, Sequence):\n 
           if len(inputs) > 1:\n                raise ValueError(\"Only batch size of one would work for wsi image. Please provide one path at a time.\")\n            inputs = inputs[0]\n\n        # Check if the input is a sting or path like\n        if not isinstance(inputs, (str, os.PathLike)):\n            raise ValueError(f\"The input should be the path to the whole slide image. {type(inputs)} is given.\")\n\n        wsi = self.reader.read(inputs)\n        level = self.reader_kwargs.get(\"level\", 0)\n        downsample_ratio = self.reader.get_downsample_ratio(wsi, level)\n        spatial_shape: tuple = self.reader.get_size(wsi, level)\n        spatial_ndim = len(spatial_shape)\n        if spatial_ndim != 2:\n            raise ValueError(f\"WSIReader only support 2D images. {spatial_ndim} spatial dimension is provided.\")\n        patch_size, overlap, offset = self._get_valid_shape_parameters(spatial_shape)\n        pad_size, is_start_padded = self._calculate_pad_size(spatial_shape, spatial_ndim, patch_size, offset, overlap)\n\n        # Padding (extend the spatial shape)\n        if any(pad_size):\n            spatial_shape = tuple(ss + ps + pe for ss, ps, pe in zip(spatial_shape, pad_size[1::2], pad_size[::2]))\n            # correct the offset with respect to the padded image\n            if is_start_padded:\n                offset = tuple(off + p for off, p in zip(offset, pad_size[1::2]))\n\n        # Splitting (extracting patches)\n        for location in iter_patch_position(spatial_shape, patch_size, offset, overlap, False):\n            location_ = tuple(round(loc * downsample_ratio) for loc in location)\n            patch = self._get_patch(wsi, location_, patch_size)\n            patch = convert_to_tensor(patch, device=self.device)\n            # correct the location with respect to original inputs (remove starting pads)\n            if is_start_padded:\n                location = tuple(loc - p for loc, p in zip(location, pad_size[1::2]))\n            # 
filter patch and yield\n            if self.filter_fn is None or self.filter_fn(patch, location):\n                yield patch, location\n"
  },
  {
    "path": "monai/inferers/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport itertools\nfrom collections.abc import Callable, Iterable, Mapping, Sequence\nfrom typing import Any\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size\nfrom monai.utils import (\n    BlendMode,\n    PytorchPadMode,\n    convert_data_type,\n    convert_to_dst_type,\n    ensure_tuple,\n    ensure_tuple_rep,\n    fall_back_tuple,\n    look_up_option,\n    optional_import,\n)\n\ntqdm, _ = optional_import(\"tqdm\", name=\"tqdm\")\n_nearest_mode = \"nearest-exact\"\n\n__all__ = [\"sliding_window_inference\"]\n\n\ndef sliding_window_inference(\n    inputs: torch.Tensor | MetaTensor,\n    roi_size: Sequence[int] | int,\n    sw_batch_size: int,\n    predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],\n    overlap: Sequence[float] | float = 0.25,\n    mode: BlendMode | str = BlendMode.CONSTANT,\n    sigma_scale: Sequence[float] | float = 0.125,\n    padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT,\n    cval: float = 0.0,\n    sw_device: torch.device | str | None = None,\n    device: torch.device | str | None = None,\n    progress: bool = False,\n    roi_weight_map: torch.Tensor | None = None,\n    process_fn: Callable | None = 
None,\n    buffer_steps: int | None = None,\n    buffer_dim: int = -1,\n    with_coord: bool = False,\n    *args: Any,\n    **kwargs: Any,\n) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:\n    \"\"\"\n    Sliding window inference on `inputs` with `predictor`.\n\n    The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.\n    Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.\n    e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes\n    could be ([128,64,256], [64,32,128]).\n    In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still\n    an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters\n    so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).\n\n    When roi_size is larger than the inputs' spatial size, the input image are padded during inference.\n    To maintain the same spatial sizes, the output image will be cropped to the original input size.\n\n    Args:\n        inputs: input image to be processed (assuming NCHW[D])\n        roi_size: the spatial window size for inferences, this must be a single value or a tuple with values\n            for each spatial dimension (eg. 2 for 2D, 3 for 3D).\n            When its components have None or non-positives, the corresponding inputs dimension will be used.\n            if the components of the `roi_size` are non-positive values, the transform will use the\n            corresponding components of img size. 
For example, `roi_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        sw_batch_size: the batch size to run window slices.\n        predictor: given input tensor ``patch_data`` in shape NCHW[D],\n            The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary\n            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];\n            where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,\n            N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),\n            the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).\n            In this case, the parameter `overlap` and `roi_size` need to be carefully chosen\n            to ensure the scaled output ROI sizes are still integers.\n            If the `predictor`'s input and output spatial sizes are different,\n            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.\n        overlap: Amount of overlap between scans along each spatial dimension, defaults to ``0.25``.\n        mode: {``\"constant\"``, ``\"gaussian\"``}\n            How to blend output of overlapping windows. Defaults to ``\"constant\"``.\n\n            - ``\"constant``\": gives equal weight to all predictions.\n            - ``\"gaussian``\": gives less weight to predictions on edges of windows.\n\n        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``\"gaussian\"``.\n            Default: 0.125. 
Actual window sigma is ``sigma_scale`` * ``dim_size``.\n            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding\n            spatial dimensions.\n        padding_mode: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}\n            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``\"constant\"``\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        cval: fill value for 'constant' padding mode. Default: 0\n        sw_device: device for the window data.\n            By default the device (and accordingly the memory) of the `inputs` is used.\n            Normally `sw_device` should be consistent with the device where `predictor` is defined.\n        device: device for the stitched output prediction.\n            By default the device (and accordingly the memory) of the `inputs` is used. If for example\n            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the\n            `inputs` and `roi_size`. Output is on the `device`.\n        progress: whether to print a `tqdm` progress bar.\n        roi_weight_map: pre-computed (non-negative) weight map for each ROI.\n            If not given, and ``mode`` is not `constant`, this map will be computed on the fly.\n        process_fn: process inference output and adjust the importance map per window\n        buffer_steps: the number of sliding window iterations along the ``buffer_dim``\n            to be buffered on ``sw_device`` before writing to ``device``.\n            (Typically, ``sw_device`` is ``cuda`` and ``device`` is ``cpu``.)\n            default is None, no buffering. For the buffer dim, when spatial size is divisible by buffer_steps*roi_size,\n            (i.e. 
no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency.\n        buffer_dim: the spatial dimension along which the buffers are created.\n            0 indicates the first spatial dimension. Default is -1, the last spatial dimension.\n        with_coord: whether to pass the window coordinates to ``predictor``. Default is False.\n            If True, the signature of ``predictor`` should be ``predictor(patch_data, patch_coord, ...)``.\n        args: optional args to be passed to ``predictor``.\n        kwargs: optional keyword args to be passed to ``predictor``.\n\n    Note:\n        - Inputs must be channel-first and have a batch dim (NCHW / NCDHW).\n        - If your data is NHWC/NDHWC, please apply `EnsureChannelFirst` / `EnsureChannelFirstd` upstream.\n\n    Raises:\n        ValueError: When the input dimensions do not match the expected dimensions based on ``roi_size``.\n\n    \"\"\"\n    num_spatial_dims = len(inputs.shape) - 2\n\n    # Only perform strict shape validation if roi_size is a sequence (explicit dimensions).\n    # If roi_size is an integer, it is broadcast to all dimensions, so we cannot\n    # infer the expected dimensionality to enforce a strict check here.\n    if isinstance(roi_size, Sequence):\n        roi_dims = len(roi_size)\n        if num_spatial_dims != roi_dims:\n            raise ValueError(\n                f\"Inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size \"\n                f\"(Batch, Channel, {', '.join(['Spatial'] * roi_dims)}), \"\n                f\"but got inputs shape {inputs.shape}.\\n\"\n                \"If you have channel-last data (e.g. 
B, D, H, W, C), please use \"\n                \"monai.transforms.EnsureChannelFirst or EnsureChannelFirstd upstream.\"\n            )\n    # -----------------------------------------------------------------\n    buffered = buffer_steps is not None and buffer_steps > 0\n    if buffered:\n        if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims:\n            raise ValueError(f\"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.\")\n        if buffer_dim < 0:\n            buffer_dim += num_spatial_dims\n    overlap = ensure_tuple_rep(overlap, num_spatial_dims)\n    for o in overlap:\n        if o < 0 or o >= 1:\n            raise ValueError(f\"overlap must be >= 0 and < 1, got {overlap}.\")\n    compute_dtype = inputs.dtype\n\n    # determine image spatial size and batch size\n    # Note: all input images must have the same image size and batch size\n    batch_size, _, *image_size_ = inputs.shape\n    device = device or inputs.device\n    sw_device = sw_device or inputs.device\n\n    condition = kwargs.pop(\"condition\", None)\n\n    temp_meta = None\n    if isinstance(inputs, MetaTensor):\n        temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False)\n    inputs = convert_data_type(inputs, torch.Tensor, wrap_sequence=True)[0]\n    roi_size = fall_back_tuple(roi_size, image_size_)\n\n    # in case that image size is smaller than roi size\n    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))\n    pad_size = []\n    for k in range(len(inputs.shape) - 1, 1, -1):\n        diff = max(roi_size[k - 2] - inputs.shape[k], 0)\n        half = diff // 2\n        pad_size.extend([half, diff - half])\n    if any(pad_size):\n        inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)\n        if condition is not None:\n            condition = F.pad(condition, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), 
value=cval)\n\n    # Store all slices\n    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)\n    slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=not buffered)\n\n    num_win = len(slices)  # number of windows per image\n    total_slices = num_win * batch_size  # total number of windows\n    windows_range: Iterable\n    if not buffered:\n        non_blocking = False\n        windows_range = range(0, total_slices, sw_batch_size)\n    else:\n        slices, n_per_batch, b_slices, windows_range = _create_buffered_slices(\n            slices, batch_size, sw_batch_size, buffer_dim, buffer_steps\n        )\n        non_blocking, _ss = torch.cuda.is_available(), -1\n        for x in b_slices[:n_per_batch]:\n            if x[1] < _ss:  # detect overlapping slices\n                non_blocking = False\n                break\n            _ss = x[2]\n\n    # Create window-level importance map\n    valid_patch_size = get_valid_patch_size(image_size, roi_size)\n    if valid_patch_size == roi_size and (roi_weight_map is not None):\n        importance_map_ = roi_weight_map\n    else:\n        try:\n            valid_p_size = ensure_tuple(valid_patch_size)\n            importance_map_ = compute_importance_map(\n                valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype\n            )\n            if len(importance_map_.shape) == num_spatial_dims and not process_fn:\n                importance_map_ = importance_map_[None, None]  # adds batch, channel dimensions\n        except Exception as e:\n            raise RuntimeError(\n                f\"patch size {valid_p_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\\n\"\n                \"Seems to be OOM. 
Please try smaller patch size or mode='constant' instead of mode='gaussian'.\"\n            ) from e\n    importance_map_ = convert_data_type(importance_map_, torch.Tensor, device=sw_device, dtype=compute_dtype)[0]\n\n    # stores output and count map\n    output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0  # type: ignore\n    # for each patch\n    for slice_g in tqdm(windows_range) if progress else windows_range:\n        slice_range = range(slice_g, min(slice_g + sw_batch_size, b_slices[b_s][0] if buffered else total_slices))\n        unravel_slice = [\n            [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win])\n            for idx in slice_range\n        ]\n        if sw_batch_size > 1:\n            win_data = torch.cat([inputs[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to(sw_device)\n            if condition is not None:\n                win_condition = torch.cat([condition[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to(\n                    sw_device\n                )\n                kwargs[\"condition\"] = win_condition\n        else:\n            s0 = unravel_slice[0]\n            s0_idx = ensure_tuple(s0)\n\n            win_data = inputs[s0_idx].to(sw_device)\n            if condition is not None:\n                win_condition = condition[s0_idx].to(sw_device)\n                kwargs[\"condition\"] = win_condition\n\n        if with_coord:\n            seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)\n        else:\n            seg_prob_out = predictor(win_data, *args, **kwargs)\n        # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.\n        dict_keys, seg_tuple = _flatten_struct(seg_prob_out)\n        if process_fn:\n            seg_tuple, w_t = process_fn(seg_tuple, win_data, importance_map_)\n        else:\n            w_t = importance_map_\n        if len(w_t.shape) == num_spatial_dims:\n            w_t 
= w_t[None, None]\n        w_t = w_t.to(dtype=compute_dtype, device=sw_device)\n        if buffered:\n            c_start, c_end = b_slices[b_s][1:]\n            if not sw_device_buffer:\n                k = seg_tuple[0].shape[1]  # len(seg_tuple) > 1 is currently ignored\n                sp_size = list(image_size)\n                sp_size[buffer_dim] = c_end - c_start\n                sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)]\n            for p, s in zip(seg_tuple[0], unravel_slice):\n                offset = s[buffer_dim + 2].start - c_start\n                s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim])\n                s[0] = slice(0, 1)\n                sw_device_buffer[0][ensure_tuple(s)] += p * w_t\n            b_i += len(unravel_slice)\n            if b_i < b_slices[b_s][0]:\n                continue\n        else:\n            sw_device_buffer = list(seg_tuple)\n\n        for ss in range(len(sw_device_buffer)):\n            b_shape = sw_device_buffer[ss].shape\n            seg_chns, seg_shape = b_shape[1], b_shape[2:]\n            z_scale = None\n            if not buffered and seg_shape != roi_size:\n                z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)]\n                w_t = F.interpolate(w_t, seg_shape, mode=_nearest_mode)\n            if len(output_image_list) <= ss:\n                output_shape = [batch_size, seg_chns]\n                output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size)\n                # allocate memory to store the full output and the count for overlapping parts\n                new_tensor: Callable = torch.empty if non_blocking else torch.zeros  # type: ignore\n                output_image_list.append(new_tensor(output_shape, dtype=compute_dtype, device=device))\n                count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, 
device=device))\n                w_t_ = w_t.to(device)\n                for __s in slices:\n                    if z_scale is not None:\n                        __s = tuple(slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale))\n                    count_map_list[-1][(slice(None), slice(None), *__s)] += w_t_\n            if buffered:\n                o_slice = [slice(None)] * len(inputs.shape)\n                o_slice[buffer_dim + 2] = slice(c_start, c_end)\n                img_b = b_s // n_per_batch  # image batch index\n                o_slice[0] = slice(img_b, img_b + 1)\n                o_slice_idx = ensure_tuple(o_slice)\n                if non_blocking:\n                    output_image_list[0][o_slice_idx].copy_(sw_device_buffer[0], non_blocking=non_blocking)\n                else:\n                    output_image_list[0][o_slice_idx] += sw_device_buffer[0].to(device=device)\n            else:\n                sw_device_buffer[ss] *= w_t\n                sw_device_buffer[ss] = sw_device_buffer[ss].to(device)\n                _compute_coords(unravel_slice, z_scale, output_image_list[ss], sw_device_buffer[ss])\n        sw_device_buffer = []\n        if buffered:\n            b_s += 1\n\n    if non_blocking:\n        torch.cuda.current_stream().synchronize()\n\n    # account for any overlapping sections\n    for ss in range(len(output_image_list)):\n        output_image_list[ss] /= count_map_list.pop(0)\n\n    # remove padding if image_size smaller than roi_size\n    if any(pad_size):\n        kwargs.update({\"pad_size\": pad_size})\n        for ss, output_i in enumerate(output_image_list):\n            zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]\n            final_slicing: list[slice] = []\n            for sp in range(num_spatial_dims):\n                si = num_spatial_dims - sp - 1\n                slice_dim = slice(\n                    int(round(pad_size[sp * 2] * 
zoom_scale[si])),\n                    int(round((pad_size[sp * 2] + image_size_[si]) * zoom_scale[si])),\n                )\n                final_slicing.insert(0, slice_dim)\n            output_image_list[ss] = output_i[(slice(None), slice(None), *final_slicing)]\n\n    final_output = _pack_struct(output_image_list, dict_keys)\n    if temp_meta is not None:\n        final_output = convert_to_dst_type(final_output, temp_meta, device=device)[0]\n    else:\n        final_output = convert_to_dst_type(final_output, inputs, device=device)[0]\n\n    return final_output  # type: ignore\n\n\ndef _create_buffered_slices(slices, batch_size, sw_batch_size, buffer_dim, buffer_steps):\n    \"\"\"rearrange slices for buffering\"\"\"\n    slices_np = np.asarray(slices)\n    slices_np = slices_np[np.argsort(slices_np[:, buffer_dim, 0], kind=\"mergesort\")]\n    slices = [tuple(slice(c[0], c[1]) for c in i) for i in slices_np]\n    slices_np = slices_np[:, buffer_dim]\n\n    _, _, _b_lens = np.unique(slices_np[:, 0], return_counts=True, return_index=True)\n    b_ends = np.cumsum(_b_lens).tolist()  # possible buffer flush boundaries\n    x = [0, *b_ends][:: min(len(b_ends), int(buffer_steps))]\n    if x[-1] < b_ends[-1]:\n        x.append(b_ends[-1])\n    n_per_batch = len(x) - 1\n    windows_range = [\n        range(b * x[-1] + x[i], b * x[-1] + x[i + 1], sw_batch_size)\n        for b in range(batch_size)\n        for i in range(n_per_batch)\n    ]\n    b_slices = []\n    for _s, _r in enumerate(windows_range):\n        s_s = slices_np[windows_range[_s - 1].stop % len(slices) if _s > 0 else 0, 0]\n        s_e = slices_np[(_r.stop - 1) % len(slices), 1]\n        b_slices.append((_r.stop, s_s, s_e))  # buffer index, slice start, slice end\n    windows_range = itertools.chain(*windows_range)  # type: ignore\n    return slices, n_per_batch, b_slices, windows_range\n\n\ndef _compute_coords(coords, z_scale, out, patch):\n    \"\"\"sliding window batch spatial scaling indexing for 
multi-resolution outputs.\"\"\"\n    for original_idx, p in zip(coords, patch):\n        idx_zm = list(original_idx)  # 4D for 2D image, 5D for 3D image\n        if z_scale:\n            for axis in range(2, len(idx_zm)):\n                idx_zm[axis] = slice(\n                    int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * z_scale[axis - 2])\n                )\n        out[ensure_tuple(idx_zm)] += p\n\n\ndef _get_scan_interval(\n    image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: Sequence[float]\n) -> tuple[int, ...]:\n    \"\"\"\n    Compute scan interval according to the image size, roi size and overlap.\n    Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,\n    use 1 instead to make sure sliding window works.\n\n    \"\"\"\n    if len(image_size) != num_spatial_dims:\n        raise ValueError(f\"len(image_size) {len(image_size)} different from spatial dims {num_spatial_dims}.\")\n    if len(roi_size) != num_spatial_dims:\n        raise ValueError(f\"len(roi_size) {len(roi_size)} different from spatial dims {num_spatial_dims}.\")\n\n    scan_interval = []\n    for i, o in zip(range(num_spatial_dims), overlap):\n        if roi_size[i] == image_size[i]:\n            scan_interval.append(int(roi_size[i]))\n        else:\n            interval = int(roi_size[i] * (1 - o))\n            scan_interval.append(interval if interval > 0 else 1)\n    return tuple(scan_interval)\n\n\ndef _flatten_struct(seg_out):\n    dict_keys = None\n    seg_probs: tuple[torch.Tensor, ...]\n    if isinstance(seg_out, torch.Tensor):\n        seg_probs = (seg_out,)\n    elif isinstance(seg_out, Mapping):\n        dict_keys = sorted(seg_out.keys())  # track predictor's output keys\n        seg_probs = tuple(seg_out[k] for k in dict_keys)\n    else:\n        seg_probs = ensure_tuple(seg_out)\n    return dict_keys, seg_probs\n\n\ndef _pack_struct(seg_out, dict_keys=None):\n    if dict_keys is not 
None:\n        return dict(zip(dict_keys, seg_out))\n    if isinstance(seg_out, (list, tuple)) and len(seg_out) == 1:\n        return seg_out[0]\n    return ensure_tuple(seg_out)\n"
  },
  {
    "path": "monai/losses/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .adversarial_loss import PatchAdversarialLoss\nfrom .barlow_twins import BarlowTwinsLoss\nfrom .cldice import SoftclDiceLoss, SoftDiceclDiceLoss\nfrom .contrastive import ContrastiveLoss\nfrom .deform import BendingEnergyLoss, DiffusionLoss\nfrom .dice import (\n    Dice,\n    DiceCELoss,\n    DiceFocalLoss,\n    DiceLoss,\n    GeneralizedDiceFocalLoss,\n    GeneralizedDiceLoss,\n    GeneralizedWassersteinDiceLoss,\n    MaskedDiceLoss,\n    dice_ce,\n    dice_focal,\n    generalized_dice,\n    generalized_dice_focal,\n    generalized_wasserstein_dice,\n)\nfrom .ds_loss import DeepSupervisionLoss\nfrom .focal_loss import FocalLoss\nfrom .giou_loss import BoxGIoULoss, giou\nfrom .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss\nfrom .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss\nfrom .multi_scale import MultiScaleLoss\nfrom .nacl_loss import NACLLoss\nfrom .perceptual import PerceptualLoss\nfrom .spatial_mask import MaskedLoss\nfrom .spectral_loss import JukeboxLoss\nfrom .ssim_loss import SSIMLoss\nfrom .sure_loss import SURELoss\nfrom .tversky import TverskyLoss\nfrom .unified_focal_loss import AsymmetricUnifiedFocalLoss\n"
  },
  {
    "path": "monai/losses/adversarial_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\n\nimport torch\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.networks.layers.utils import get_act_layer\nfrom monai.utils import LossReduction\nfrom monai.utils.enums import StrEnum\n\n\nclass AdversarialCriterions(StrEnum):\n    BCE = \"bce\"\n    HINGE = \"hinge\"\n    LEAST_SQUARE = \"least_squares\"\n\n\nclass PatchAdversarialLoss(_Loss):\n    \"\"\"\n    Calculates an adversarial loss on a Patch Discriminator or a Multi-scale Patch Discriminator.\n    Warning: due to the possibility of using different criterions, the output of the discrimination\n    mustn't be passed to a final activation layer. That is taken care of internally within the loss.\n\n    Args:\n        reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n            Specifies the reduction to apply to the output. Defaults to ``\"mean\"``.\n\n            - ``\"none\"``: no reduction will be applied.\n            - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n            - ``\"sum\"``: the output will be summed.\n\n        criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators outputs.\n            Depending on the criterion, a different activation layer will be used. 
Make sure you don't run the outputs\n            through an activation layer prior to calling the loss.\n        no_activation_leastsq: if True, the activation layer in the case of least-squares is removed.\n    \"\"\"\n\n    def __init__(\n        self,\n        reduction: LossReduction | str = LossReduction.MEAN,\n        criterion: str = AdversarialCriterions.LEAST_SQUARE,\n        no_activation_leastsq: bool = False,\n    ) -> None:\n        super().__init__(reduction=LossReduction(reduction))\n\n        if criterion.lower() not in list(AdversarialCriterions):\n            raise ValueError(\n                f\"Unrecognised criterion entered for Adversarial Loss. Must be one in: {', '.join(AdversarialCriterions)}\"\n            )\n\n        # Depending on the criterion, a different activation layer is used.\n        self.real_label = 1.0\n        self.fake_label = 0.0\n        self.loss_fct: _Loss\n        if criterion == AdversarialCriterions.BCE:\n            self.activation = get_act_layer(\"SIGMOID\")\n            self.loss_fct = torch.nn.BCELoss(reduction=reduction)\n        elif criterion == AdversarialCriterions.HINGE:\n            self.activation = get_act_layer(\"TANH\")\n            self.fake_label = -1.0\n        elif criterion == AdversarialCriterions.LEAST_SQUARE:\n            if no_activation_leastsq:\n                self.activation = None\n            else:\n                self.activation = get_act_layer(name=(\"LEAKYRELU\", {\"negative_slope\": 0.05}))\n            self.loss_fct = torch.nn.MSELoss(reduction=reduction)\n\n        self.criterion = criterion\n        self.reduction = reduction\n\n    def get_target_tensor(self, input: torch.Tensor, target_is_real: bool) -> torch.Tensor:\n        \"\"\"\n        Gets the ground truth tensor for the discriminator depending on whether the input is real or fake.\n\n        Args:\n            input: input tensor from the discriminator (output of discriminator, or output of one of the multi-scale\n      
      discriminator). This is used to match the shape.\n            target_is_real: whether the input is real or wannabe-real (1s) or fake (0s).\n        Returns:\n        \"\"\"\n        filling_label = self.real_label if target_is_real else self.fake_label\n        label_tensor = torch.tensor(1).fill_(filling_label).type(input.type()).to(input[0].device)\n        label_tensor.requires_grad_(False)\n        return label_tensor.expand_as(input)\n\n    def get_zero_tensor(self, input: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Gets a zero tensor.\n\n        Args:\n            input: tensor which shape you want the zeros tensor to correspond to.\n        Returns:\n        \"\"\"\n\n        zero_label_tensor = torch.tensor(0).type(input[0].type()).to(input[0].device)\n        zero_label_tensor.requires_grad_(False)\n        return zero_label_tensor.expand_as(input)\n\n    def forward(\n        self, input: torch.Tensor | list, target_is_real: bool, for_discriminator: bool\n    ) -> torch.Tensor | list[torch.Tensor]:\n        \"\"\"\n\n        Args:\n            input: output of Multi-Scale Patch Discriminator or Patch Discriminator; being a list of tensors\n                or a tensor; they shouldn't have gone through an activation layer.\n            target_is_real: whereas the input corresponds to discriminator output for real or fake images\n            for_discriminator: whereas this is being calculated for discriminator or generator loss. In the last\n                case, target_is_real is set to True, as the generator wants the input to be dimmed as real.\n        Returns: if reduction is None, returns a list with the loss tensors of each discriminator if multi-scale\n            discriminator is active, or the loss tensor if there is just one discriminator. 
Otherwise, it returns the\n            summed or mean loss over the tensor and discriminator/s.\n\n        \"\"\"\n\n        if not for_discriminator and not target_is_real:\n            target_is_real = True  # With generator, we always want this to be true!\n            warnings.warn(\n                \"Variable target_is_real has been set to False, but for_discriminator is set\"\n                \"to False. To optimise a generator, target_is_real must be set to True.\"\n            )\n\n        if not isinstance(input, list):\n            input = [input]\n        target_ = []\n        for _, disc_out in enumerate(input):\n            if self.criterion != AdversarialCriterions.HINGE:\n                target_.append(self.get_target_tensor(disc_out, target_is_real))\n            else:\n                target_.append(self.get_zero_tensor(disc_out))\n\n        # Loss calculation\n        loss_list = []\n        for disc_ind, disc_out in enumerate(input):\n            if self.activation is not None:\n                disc_out = self.activation(disc_out)\n            if self.criterion == AdversarialCriterions.HINGE and not target_is_real:\n                loss_ = self._forward_single(-disc_out, target_[disc_ind])\n            else:\n                loss_ = self._forward_single(disc_out, target_[disc_ind])\n            loss_list.append(loss_)\n\n        loss: torch.Tensor | list[torch.Tensor]\n        if loss_list is not None:\n            if self.reduction == LossReduction.MEAN:\n                loss = torch.mean(torch.stack(loss_list))\n            elif self.reduction == LossReduction.SUM:\n                loss = torch.sum(torch.stack(loss_list))\n            else:\n                loss = loss_list\n        return loss\n\n    def _forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        forward: torch.Tensor\n        if self.criterion == AdversarialCriterions.BCE or self.criterion == AdversarialCriterions.LEAST_SQUARE:\n            
forward = self.loss_fct(input, target)\n        elif self.criterion == AdversarialCriterions.HINGE:\n            minval = torch.min(input - 1, self.get_zero_tensor(input))\n            forward = -torch.mean(minval)\n        return forward\n"
  },
  {
    "path": "monai/losses/barlow_twins.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch.nn.modules.loss import _Loss\n\n\nclass BarlowTwinsLoss(_Loss):\n    \"\"\"\n    The Barlow Twins cost function takes the representations extracted by a neural network from two\n    distorted views and seeks to make the cross-correlation matrix of the two representations tend\n    towards identity. This encourages the neural network to learn similar representations with the least\n    amount of redundancy. This cost function can be used in particular in multimodal learning to work on\n    representations from two modalities. The most common use case is for unsupervised learning, where data\n    augmentations are used to generate 2 distorted views of the same sample to force the encoder to\n    extract useful features for downstream tasks.\n\n    Zbontar, Jure, et al. \"Barlow Twins: Self-Supervised Learning via Redundancy Reduction\" International\n    conference on machine learning. PMLR, 2020. (http://proceedings.mlr.press/v139/zbontar21a/zbontar21a.pdf)\n\n    Adapted from:\n        https://github.com/facebookresearch/barlowtwins\n\n    \"\"\"\n\n    def __init__(self, lambd: float = 5e-3) -> None:\n        \"\"\"\n        Args:\n            lamb: Can be any float to handle the informativeness and invariance trade-off. 
Ideally set to 5e-3.\n\n        Raises:\n            ValueError: When an input of dimension length > 2 is passed\n            ValueError: When input and target are of different shapes\n            ValueError: When batch size is less than or equal to 1\n\n        \"\"\"\n        super().__init__()\n        self.lambd = lambd\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be B[F].\n            target: the shape should be B[F].\n        \"\"\"\n        if len(target.shape) > 2 or len(input.shape) > 2:\n            raise ValueError(\n                f\"Either target or input has dimensions greater than 2 where target \"\n                f\"shape is ({target.shape}) and input shape is ({input.shape})\"\n            )\n\n        if target.shape != input.shape:\n            raise ValueError(f\"ground truth has differing shape ({target.shape}) from input ({input.shape})\")\n\n        if target.size(0) <= 1:\n            raise ValueError(\n                f\"Batch size must be greater than 1 to compute Barlow Twins Loss, but got {target.size(0)}\"\n            )\n\n        lambd_tensor = torch.as_tensor(self.lambd).to(input.device)\n        batch_size = input.shape[0]\n\n        # normalize input and target\n        input_norm = (input - input.mean(0)) / input.std(0).add(1e-6)\n        target_norm = (target - target.mean(0)) / target.std(0).add(1e-6)\n\n        # cross-correlation matrix\n        c = torch.mm(input_norm.t(), target_norm) / batch_size  # input_norm.t() is FxB, target_norm is BxF so c is FxF\n\n        # loss\n        c_diff = (c - torch.eye(c.size(0), device=c.device)).pow_(2)  # FxF\n        c_diff[~torch.eye(c.size(0), device=c.device).bool()] *= lambd_tensor\n\n        return c_diff.sum()\n"
  },
  {
    "path": "monai/losses/cldice.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn.modules.loss import _Loss\n\n\ndef soft_erode(img: torch.Tensor) -> torch.Tensor:  # type: ignore\n    \"\"\"\n    Perform soft erosion on the input image\n\n    Args:\n        img: the shape should be BCH(WD)\n\n    Adapted from:\n        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L6\n    \"\"\"\n    if len(img.shape) == 4:\n        p1 = -(F.max_pool2d(-img, (3, 1), (1, 1), (1, 0)))\n        p2 = -(F.max_pool2d(-img, (1, 3), (1, 1), (0, 1)))\n        return torch.min(p1, p2)\n    elif len(img.shape) == 5:\n        p1 = -(F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0)))\n        p2 = -(F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0)))\n        p3 = -(F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1)))\n        return torch.min(torch.min(p1, p2), p3)\n\n\ndef soft_dilate(img: torch.Tensor) -> torch.Tensor:  # type: ignore\n    \"\"\"\n    Perform soft dilation on the input image\n\n    Args:\n        img: the shape should be BCH(WD)\n\n    Adapted from:\n        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L18\n    \"\"\"\n    if len(img.shape) == 4:\n        return F.max_pool2d(img, (3, 3), (1, 1), (1, 1))\n    elif len(img.shape) == 5:\n        return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), 
(1, 1, 1))\n\n\ndef soft_open(img: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Wrapper function to perform soft opening on the input image\n\n    Args:\n        img: the shape should be BCH(WD)\n\n    Adapted from:\n        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L25\n    \"\"\"\n    eroded_image = soft_erode(img)\n    dilated_image = soft_dilate(eroded_image)\n    return dilated_image\n\n\ndef soft_skel(img: torch.Tensor, iter_: int) -> torch.Tensor:\n    \"\"\"\n    Perform soft skeletonization on the input image\n\n    Adapted from:\n       https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/soft_skeleton.py#L29\n\n    Args:\n        img: the shape should be BCH(WD)\n        iter_: number of iterations for skeletonization\n\n    Returns:\n        skeletonized image\n    \"\"\"\n    img1 = soft_open(img)\n    skel = F.relu(img - img1)\n    for _ in range(iter_):\n        img = soft_erode(img)\n        img1 = soft_open(img)\n        delta = F.relu(img - img1)\n        skel = skel + F.relu(delta - skel * delta)\n    return skel\n\n\ndef soft_dice(y_true: torch.Tensor, y_pred: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:\n    \"\"\"\n    Function to compute soft dice loss\n\n    Adapted from:\n        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L22\n\n    Args:\n        y_true: the shape should be BCH(WD)\n        y_pred: the shape should be BCH(WD)\n\n    Returns:\n        dice loss\n    \"\"\"\n    intersection = torch.sum((y_true * y_pred)[:, 1:, ...])\n    coeff = (2.0 * intersection + smooth) / (torch.sum(y_true[:, 1:, ...]) + torch.sum(y_pred[:, 1:, ...]) + smooth)\n    soft_dice: torch.Tensor = 1.0 - coeff\n    return soft_dice\n\n\nclass SoftclDiceLoss(_Loss):\n    \"\"\"\n    Compute the Soft clDice loss defined in:\n\n        Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function\n        for Tubular Structure Segmentation. 
(https://arxiv.org/abs/2003.07311)\n\n    Adapted from:\n        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L7\n    \"\"\"\n\n    def __init__(self, iter_: int = 3, smooth: float = 1.0) -> None:\n        \"\"\"\n        Args:\n            iter_: Number of iterations for skeletonization. Defaults to 3.\n            smooth: Smoothing parameter. Defaults to 1.0.\n        \"\"\"\n        super().__init__()\n        self.iter = iter_\n        self.smooth = smooth\n\n    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:\n        skel_pred = soft_skel(y_pred, self.iter)\n        skel_true = soft_skel(y_true, self.iter)\n        tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (\n            torch.sum(skel_pred[:, 1:, ...]) + self.smooth\n        )\n        tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (\n            torch.sum(skel_true[:, 1:, ...]) + self.smooth\n        )\n        cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)\n        return cl_dice\n\n\nclass SoftDiceclDiceLoss(_Loss):\n    \"\"\"\n    Compute the Soft clDice loss defined in:\n\n        Shit et al. (2021) clDice -- A Novel Topology-Preserving Loss Function\n        for Tubular Structure Segmentation. (https://arxiv.org/abs/2003.07311)\n\n    Adapted from:\n        https://github.com/jocpae/clDice/blob/master/cldice_loss/pytorch/cldice.py#L38\n    \"\"\"\n\n    def __init__(self, iter_: int = 3, alpha: float = 0.5, smooth: float = 1.0) -> None:\n        \"\"\"\n        Args:\n            iter_: Number of iterations for skeletonization. Defaults to 3.\n            alpha: Weighing factor for cldice. Defaults to 0.5.\n            smooth: Smoothing parameter. 
Defaults to 1.0.\n        \"\"\"\n        super().__init__()\n        self.iter = iter_\n        self.smooth = smooth\n        self.alpha = alpha\n\n    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:\n        dice = soft_dice(y_true, y_pred, self.smooth)\n        skel_pred = soft_skel(y_pred, self.iter)\n        skel_true = soft_skel(y_true, self.iter)\n        tprec = (torch.sum(torch.multiply(skel_pred, y_true)[:, 1:, ...]) + self.smooth) / (\n            torch.sum(skel_pred[:, 1:, ...]) + self.smooth\n        )\n        tsens = (torch.sum(torch.multiply(skel_true, y_pred)[:, 1:, ...]) + self.smooth) / (\n            torch.sum(skel_true[:, 1:, ...]) + self.smooth\n        )\n        cl_dice = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens)\n        total_loss: torch.Tensor = (1.0 - self.alpha) * dice + self.alpha * cl_dice\n        return total_loss\n"
  },
  {
    "path": "monai/losses/contrastive.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom warnings import warn\n\nimport torch\nfrom torch.nn import functional as F\nfrom torch.nn.modules.loss import _Loss\n\n\nclass ContrastiveLoss(_Loss):\n    \"\"\"\n    Compute the Contrastive loss defined in:\n\n        Chen, Ting, et al. \"A simple framework for contrastive learning of visual representations.\" International\n        conference on machine learning. PMLR, 2020. (http://proceedings.mlr.press/v119/chen20j.html)\n\n    Adapted from:\n        https://github.com/Sara-Ahmed/SiT/blob/1aacd6adcd39b71efc903d16b4e9095b97dda76f/losses.py#L5\n\n    \"\"\"\n\n    def __init__(self, temperature: float = 0.5, batch_size: int = -1) -> None:\n        \"\"\"\n        Args:\n            temperature: Can be scaled between 0 and 1 for learning from negative samples, ideally set to 0.5.\n\n        Raises:\n            ValueError: When an input of dimension length > 2 is passed\n            ValueError: When input and target are of different shapes\n\n        \"\"\"\n        super().__init__()\n        self.temperature = temperature\n\n        if batch_size != -1:\n            warn(\"batch_size is no longer required to be set. 
It will be estimated dynamically in the forward call\")\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be B[F].\n            target: the shape should be B[F].\n        \"\"\"\n        if len(target.shape) > 2 or len(input.shape) > 2:\n            raise ValueError(\n                f\"Either target or input has dimensions greater than 2 where target \"\n                f\"shape is ({target.shape}) and input shape is ({input.shape})\"\n            )\n\n        if target.shape != input.shape:\n            raise ValueError(f\"ground truth has differing shape ({target.shape}) from input ({input.shape})\")\n\n        temperature_tensor = torch.as_tensor(self.temperature).to(input.device)\n        batch_size = input.shape[0]\n\n        negatives_mask = ~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool)\n        negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device)\n\n        repr = torch.cat([input, target], dim=0)\n        sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2)\n        sim_ij = torch.diag(sim_matrix, batch_size)\n        sim_ji = torch.diag(sim_matrix, -batch_size)\n\n        positives = torch.cat([sim_ij, sim_ji], dim=0)\n        nominator = torch.exp(positives / temperature_tensor)\n        denominator = negatives_mask * torch.exp(sim_matrix / temperature_tensor)\n\n        loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))\n\n        return torch.sum(loss_partial) / (2 * batch_size)\n"
  },
  {
    "path": "monai/losses/deform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.utils import LossReduction\n\n\ndef spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor:\n    \"\"\"\n    Calculate gradients on single dimension of a tensor using central finite difference.\n    It moves the tensor along the dimension to calculate the approximate gradient\n    dx[i] = (x[i+1] - x[i-1]) / 2.\n    Adapted from:\n        DeepReg (https://github.com/DeepRegNet/DeepReg)\n\n    Args:\n        x: the shape should be BCH(WD).\n        dim: dimension to calculate gradient along.\n    Returns:\n        gradient_dx: the shape should be BCH(WD)\n    \"\"\"\n    slice_1 = slice(1, -1)\n    slice_2_s = slice(2, None)\n    slice_2_e = slice(None, -2)\n    slice_all = slice(None)\n    slicing_s, slicing_e = [slice_all, slice_all], [slice_all, slice_all]\n    while len(slicing_s) < x.ndim:\n        slicing_s = slicing_s + [slice_1]\n        slicing_e = slicing_e + [slice_1]\n    slicing_s[dim] = slice_2_s\n    slicing_e[dim] = slice_2_e\n    return (x[slicing_s] - x[slicing_e]) / 2.0\n\n\nclass BendingEnergyLoss(_Loss):\n    \"\"\"\n    Calculate the bending energy based on second-order differentiation of ``pred`` using central finite difference.\n\n    For more information,\n    see 
https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb.\n\n    Adapted from:\n        DeepReg (https://github.com/DeepRegNet/DeepReg)\n    \"\"\"\n\n    def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None:\n        \"\"\"\n        Args:\n            normalize:\n                Whether to divide out spatial sizes in order to make the computation roughly\n                invariant to image scale (i.e. vector field sampling resolution). Defaults to False.\n            reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n                Specifies the reduction to apply to the output. Defaults to ``\"mean\"``.\n\n                - ``\"none\"``: no reduction will be applied.\n                - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n                - ``\"sum\"``: the output will be summed.\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n        self.normalize = normalize\n\n    def forward(self, pred: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            pred: the shape should be BCH(WD)\n\n        Raises:\n            ValueError: When ``self.reduction`` is not one of [\"mean\", \"sum\", \"none\"].\n            ValueError: When ``pred`` is not 3-d, 4-d or 5-d.\n            ValueError: When any spatial dimension of ``pred`` has size less than or equal to 4.\n            ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions.\n\n        \"\"\"\n        if pred.ndim not in [3, 4, 5]:\n            raise ValueError(f\"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}\")\n        for i in range(pred.ndim - 2):\n            if pred.shape[-i - 1] <= 4:\n                raise ValueError(f\"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}\")\n        if pred.shape[1] != pred.ndim 
- 2:\n            raise ValueError(\n                f\"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, \"\n                f\"does not match number of spatial dimensions, {pred.ndim - 2}\"\n            )\n\n        # first order gradient\n        first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)]\n\n        # spatial dimensions in a shape suited for broadcasting below\n        if self.normalize:\n            spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,))\n\n        energy = torch.tensor(0)\n        for dim_1, g in enumerate(first_order_gradient):\n            dim_1 += 2\n            if self.normalize:\n                g *= pred.shape[dim_1] / spatial_dims\n                energy = energy + (spatial_gradient(g, dim_1) * pred.shape[dim_1]) ** 2\n            else:\n                energy = energy + spatial_gradient(g, dim_1) ** 2\n            for dim_2 in range(dim_1 + 1, pred.ndim):\n                if self.normalize:\n                    energy = energy + 2 * (spatial_gradient(g, dim_2) * pred.shape[dim_2]) ** 2\n                else:\n                    energy = energy + 2 * spatial_gradient(g, dim_2) ** 2\n\n        if self.reduction == LossReduction.MEAN.value:\n            energy = torch.mean(energy)  # the batch and channel average\n        elif self.reduction == LossReduction.SUM.value:\n            energy = torch.sum(energy)  # sum over the batch and channel dims\n        elif self.reduction != LossReduction.NONE.value:\n            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n\n        return energy\n\n\nclass DiffusionLoss(_Loss):\n    \"\"\"\n    Calculate the diffusion based on first-order differentiation of ``pred`` using central finite difference.\n    For the original paper, please refer to\n    VoxelMorph: A Learning Framework for Deformable Medical 
Image Registration,\n    Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca\n    IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231.\n\n    For more information,\n    see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb.\n\n    Adapted from:\n        VoxelMorph (https://github.com/voxelmorph/voxelmorph)\n    \"\"\"\n\n    def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None:\n        \"\"\"\n        Args:\n            normalize:\n                Whether to divide out spatial sizes in order to make the computation roughly\n                invariant to image scale (i.e. vector field sampling resolution). Defaults to False.\n            reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n                Specifies the reduction to apply to the output. Defaults to ``\"mean\"``.\n\n                - ``\"none\"``: no reduction will be applied.\n                - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n                - ``\"sum\"``: the output will be summed.\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n        self.normalize = normalize\n\n    def forward(self, pred: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            pred:\n                Predicted dense displacement field (DDF) with shape BCH[WD],\n                where C is the number of spatial dimensions.\n                Note that diffusion loss can only be calculated\n                when the sizes of the DDF along all spatial dimensions are greater than 2.\n\n        Raises:\n            ValueError: When ``self.reduction`` is not one of [\"mean\", \"sum\", \"none\"].\n            ValueError: When ``pred`` is not 3-d, 4-d or 5-d.\n            ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2.\n            
ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions.\n\n        \"\"\"\n        if pred.ndim not in [3, 4, 5]:\n            raise ValueError(f\"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}\")\n        for i in range(pred.ndim - 2):\n            if pred.shape[-i - 1] <= 2:\n                raise ValueError(f\"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}\")\n        if pred.shape[1] != pred.ndim - 2:\n            raise ValueError(\n                f\"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, \"\n                f\"does not match number of spatial dimensions, {pred.ndim - 2}\"\n            )\n\n        # first order gradient\n        first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)]\n\n        # spatial dimensions in a shape suited for broadcasting below\n        if self.normalize:\n            spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,))\n\n        diffusion = torch.tensor(0)\n        for dim_1, g in enumerate(first_order_gradient):\n            dim_1 += 2\n            if self.normalize:\n                # We divide the partial derivative for each vector component at each voxel by the spatial size\n                # corresponding to that component relative to the spatial size of the vector component with respect\n                # to which the partial derivative is taken.\n                g *= pred.shape[dim_1] / spatial_dims\n            diffusion = diffusion + g**2\n\n        if self.reduction == LossReduction.MEAN.value:\n            diffusion = torch.mean(diffusion)  # the batch and channel average\n        elif self.reduction == LossReduction.SUM.value:\n            diffusion = torch.sum(diffusion)  # sum over the batch and channel dims\n        elif self.reduction != LossReduction.NONE.value:\n            raise 
ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n\n        return diffusion\n"
  },
  {
    "path": "monai/losses/dice.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable, Sequence\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.losses.focal_loss import FocalLoss\nfrom monai.losses.spatial_mask import MaskedLoss\nfrom monai.losses.utils import compute_tp_fp_fn\nfrom monai.networks import one_hot\nfrom monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option\n\n\nclass DiceLoss(_Loss):\n    \"\"\"\n    Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks.\n    The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).\n\n    Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input,\n    must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target`\n    can be 1 or N (one-hot format).\n\n    The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of\n    the inter-over-union calculation to smooth results respectively, these values should be small.\n\n    The original papers:\n\n        Milletari, F. et. al. 
(2016) V-Net: Fully Convolutional Neural Networks for Volumetric\n        Medical Image Segmentation. 3DV 2016.\n\n        Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with\n        Soft Labels. NeurIPS 2023.\n\n        Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with\n        Soft Labels. MICCAI 2023.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        to_onehot_y: bool = False,\n        sigmoid: bool = False,\n        softmax: bool = False,\n        other_act: Callable | None = None,\n        squared_pred: bool = False,\n        jaccard: bool = False,\n        reduction: LossReduction | str = LossReduction.MEAN,\n        smooth_nr: float = 1e-5,\n        smooth_dr: float = 1e-5,\n        batch: bool = False,\n        weight: Sequence[float] | float | int | torch.Tensor | None = None,\n        soft_label: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            include_background: if False, channel index 0 (background category) is excluded from the calculation.\n                if the non-background segmentations are small compared to the total image size they can get overwhelmed\n                by the signal from the background so excluding it in such cases helps convergence.\n            to_onehot_y: whether to convert the ``target`` into the one-hot format,\n                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.\n            sigmoid: if True, apply a sigmoid function to the prediction.\n            softmax: if True, apply a softmax function to the prediction.\n            other_act: callable function to execute other activation layers, Defaults to ``None``. 
for example:\n                ``other_act = torch.tanh``.\n            squared_pred: use squared versions of targets and predictions in the denominator or not.\n            jaccard: compute Jaccard Index (soft IoU) instead of dice or not.\n            reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n                Specifies the reduction to apply to the output. Defaults to ``\"mean\"``.\n\n                - ``\"none\"``: no reduction will be applied.\n                - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n                - ``\"sum\"``: the output will be summed.\n\n            smooth_nr: a small constant added to the numerator to avoid zero.\n            smooth_dr: a small constant added to the denominator to avoid nan.\n            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.\n                Defaults to False, a Dice loss value is computed independently from each item in the batch\n                before any `reduction`.\n            weight: weights to apply to the voxels of each class. If None no weights are applied.\n                The input can be a single value (same weight for all classes), a sequence of values (the length\n                of the sequence should be the same as the number of classes. If not ``include_background``,\n                the number of classes should not include the background category class 0).\n                The value/values should be no less than 0. 
Defaults to None.\n            soft_label: whether the target contains non-binary values (soft labels) or not.\n                If True a soft label formulation of the loss will be used.\n\n        Raises:\n            TypeError: When ``other_act`` is not an ``Optional[Callable]``.\n            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].\n                Incompatible values.\n\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n        if other_act is not None and not callable(other_act):\n            raise TypeError(f\"other_act must be None or callable but is {type(other_act).__name__}.\")\n        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:\n            raise ValueError(\"Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].\")\n        self.include_background = include_background\n        self.to_onehot_y = to_onehot_y\n        self.sigmoid = sigmoid\n        self.softmax = softmax\n        self.other_act = other_act\n        self.squared_pred = squared_pred\n        self.jaccard = jaccard\n        self.smooth_nr = float(smooth_nr)\n        self.smooth_dr = float(smooth_dr)\n        self.batch = batch\n        weight = torch.as_tensor(weight) if weight is not None else None\n        self.register_buffer(\"class_weight\", weight)\n        self.class_weight: None | torch.Tensor\n        self.soft_label = soft_label\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be BNH[WD], where N is the number of classes.\n            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.\n\n        Raises:\n            AssertionError: When input and target (after one hot transform if set)\n                have different shapes.\n            ValueError: When ``self.reduction`` is not one of [\"mean\", 
\"sum\", \"none\"].\n\n        Example:\n            >>> from monai.losses.dice import *  # NOQA\n            >>> import torch\n            >>> from monai.losses.dice import DiceLoss\n            >>> B, C, H, W = 7, 5, 3, 2\n            >>> input = torch.rand(B, C, H, W)\n            >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()\n            >>> target = one_hot(target_idx[:, None, ...], num_classes=C)\n            >>> self = DiceLoss(reduction='none')\n            >>> loss = self(input, target)\n            >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape\n        \"\"\"\n        if self.sigmoid:\n            input = torch.sigmoid(input)\n\n        n_pred_ch = input.shape[1]\n        if self.softmax:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `softmax=True` ignored.\")\n            else:\n                input = torch.softmax(input, 1)\n\n        if self.other_act is not None:\n            input = self.other_act(input)\n\n        if self.to_onehot_y:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `to_onehot_y=True` ignored.\")\n            else:\n                target = one_hot(target, num_classes=n_pred_ch)\n\n        if not self.include_background:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `include_background=False` ignored.\")\n            else:\n                # if skipping background, removing first channel\n                target = target[:, 1:]\n                input = input[:, 1:]\n\n        if target.shape != input.shape:\n            raise AssertionError(f\"ground truth has different shape ({target.shape}) from input ({input.shape})\")\n\n        # reducing only spatial dimensions (not batch nor channels)\n        reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()\n        if self.batch:\n            # reducing spatial dimensions and batch\n      
      reduce_axis = [0] + reduce_axis\n\n        ord = 2 if self.squared_pred else 1\n        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label)\n        if not self.jaccard:\n            fp *= 0.5\n            fn *= 0.5\n        numerator = 2 * tp + self.smooth_nr\n        denominator = 2 * (tp + fp + fn) + self.smooth_dr\n\n        f: torch.Tensor = 1 - numerator / denominator\n\n        num_of_classes = target.shape[1]\n        if self.class_weight is not None and num_of_classes != 1:\n            # make sure the lengths of weights are equal to the number of classes\n            if self.class_weight.ndim == 0:\n                self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)\n            else:\n                if self.class_weight.shape[0] != num_of_classes:\n                    raise ValueError(\n                        \"\"\"the length of the `weight` sequence should be the same as the number of classes.\n                        If `include_background=False`, the weight should not include\n                        the background category class 0.\"\"\"\n                    )\n            if self.class_weight.min() < 0:\n                raise ValueError(\"the value/values of the `weight` should be no less than 0.\")\n            # apply class_weight to loss\n            f = f * self.class_weight.to(f)\n\n        if self.reduction == LossReduction.MEAN.value:\n            f = torch.mean(f)  # the batch and channel average\n        elif self.reduction == LossReduction.SUM.value:\n            f = torch.sum(f)  # sum over the batch and channel dims\n        elif self.reduction == LossReduction.NONE.value:\n            # If we are not computing voxelwise loss components at least\n            # make sure a none reduction maintains a broadcastable shape\n            broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)\n            f = f.view(broadcast_shape)\n        else:\n            raise 
ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n\n        return f\n\n\nclass MaskedDiceLoss(DiceLoss):\n    \"\"\"\n    Add an additional `masking` process before `DiceLoss`, accept a binary mask ([0, 1]) indicating a region,\n    `input` and `target` will be masked by the region: region with mask `1` will keep the original value,\n    region with `0` mask will be converted to `0`. Then feed `input` and `target` to normal `DiceLoss` computation.\n    This has the effect of ensuring only the masked region contributes to the loss computation and\n    hence gradient calculation.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        to_onehot_y: bool = False,\n        sigmoid: bool = False,\n        softmax: bool = False,\n        other_act: Callable | None = None,\n        squared_pred: bool = False,\n        jaccard: bool = False,\n        reduction: LossReduction | str = LossReduction.MEAN,\n        smooth_nr: float = 1e-5,\n        smooth_dr: float = 1e-5,\n        batch: bool = False,\n        weight: Sequence[float] | float | int | torch.Tensor | None = None,\n        soft_label: bool = False,\n    ) -> None:\n        \"\"\"\n        Args follow :py:class:`monai.losses.DiceLoss`.\n        \"\"\"\n        if other_act is not None and not callable(other_act):\n            raise TypeError(f\"other_act must be None or callable but is {type(other_act).__name__}.\")\n        if sigmoid and softmax:\n            raise ValueError(\"Incompatible values: sigmoid=True and softmax=True.\")\n        if other_act is not None and (sigmoid or softmax):\n            raise ValueError(\"Incompatible values: other_act is not None and sigmoid=True or softmax=True.\")\n\n        self.pre_sigmoid = sigmoid\n        self.pre_softmax = softmax\n        self.pre_other_act = other_act\n\n        super().__init__(\n            include_background=include_background,\n            
to_onehot_y=to_onehot_y,\n            sigmoid=False,\n            softmax=False,\n            other_act=None,\n            squared_pred=squared_pred,\n            jaccard=jaccard,\n            reduction=reduction,\n            smooth_nr=smooth_nr,\n            smooth_dr=smooth_dr,\n            batch=batch,\n            weight=weight,\n            soft_label=soft_label,\n        )\n\n        self.spatial_weighted = MaskedLoss(loss=super().forward)\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be BNH[WD].\n            target: the shape should be BNH[WD].\n            mask: the shape should B1H[WD] or 11H[WD].\n        \"\"\"\n\n        if self.pre_sigmoid:\n            input = torch.sigmoid(input)\n\n        n_pred_ch = input.shape[1]\n        if self.pre_softmax:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `softmax=True` ignored.\", stacklevel=2)\n            else:\n                input = torch.softmax(input, 1)\n\n        if self.pre_other_act is not None:\n            input = self.pre_other_act(input)\n        return self.spatial_weighted(input=input, target=target, mask=mask)  # type: ignore[no-any-return]\n\n\nclass GeneralizedDiceLoss(_Loss):\n    \"\"\"\n    Compute the generalised Dice loss defined in:\n\n        Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning\n        loss function for highly unbalanced segmentations. 
DLMIA 2017.\n\n    Adapted from:\n        https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L279\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        to_onehot_y: bool = False,\n        sigmoid: bool = False,\n        softmax: bool = False,\n        other_act: Callable | None = None,\n        w_type: Weight | str = Weight.SQUARE,\n        reduction: LossReduction | str = LossReduction.MEAN,\n        smooth_nr: float = 1e-5,\n        smooth_dr: float = 1e-5,\n        batch: bool = False,\n        soft_label: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            include_background: If False channel index 0 (background category) is excluded from the calculation.\n            to_onehot_y: whether to convert the ``target`` into the one-hot format,\n                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.\n            sigmoid: If True, apply a sigmoid function to the prediction.\n            softmax: If True, apply a softmax function to the prediction.\n            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:\n                ``other_act = torch.tanh``.\n            w_type: {``\"square\"``, ``\"simple\"``, ``\"uniform\"``}\n                Type of function to transform ground truth volume to a weight factor. Defaults to ``\"square\"``.\n            reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n                Specifies the reduction to apply to the output. 
Defaults to ``\"mean\"``.\n\n                - ``\"none\"``: no reduction will be applied.\n                - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n                - ``\"sum\"``: the output will be summed.\n            smooth_nr: a small constant added to the numerator to avoid zero.\n            smooth_dr: a small constant added to the denominator to avoid nan.\n            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.\n                Defaults to False, intersection over union is computed from each item in the batch.\n                If True, the class-weighted intersection and union areas are first summed across the batches.\n            soft_label: whether the target contains non-binary values (soft labels) or not.\n                If True a soft label formulation of the loss will be used.\n\n        Raises:\n            TypeError: When ``other_act`` is not an ``Optional[Callable]``.\n            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].\n                Incompatible values.\n\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n        if other_act is not None and not callable(other_act):\n            raise TypeError(f\"other_act must be None or callable but is {type(other_act).__name__}.\")\n        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:\n            raise ValueError(\"Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].\")\n\n        self.include_background = include_background\n        self.to_onehot_y = to_onehot_y\n        self.sigmoid = sigmoid\n        self.softmax = softmax\n        self.other_act = other_act\n\n        self.w_type = look_up_option(w_type, Weight)\n\n        self.smooth_nr = float(smooth_nr)\n        self.smooth_dr = float(smooth_dr)\n        self.batch = batch\n        
self.soft_label = soft_label\n\n    def w_func(self, grnd):\n        if self.w_type == str(Weight.SIMPLE):\n            return torch.reciprocal(grnd)\n        if self.w_type == str(Weight.SQUARE):\n            return torch.reciprocal(grnd * grnd)\n        return torch.ones_like(grnd)\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be BNH[WD].\n            target: the shape should be BNH[WD].\n\n        Raises:\n            ValueError: When ``self.reduction`` is not one of [\"mean\", \"sum\", \"none\"].\n\n        \"\"\"\n        if self.sigmoid:\n            input = torch.sigmoid(input)\n        n_pred_ch = input.shape[1]\n        if self.softmax:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `softmax=True` ignored.\")\n            else:\n                input = torch.softmax(input, 1)\n\n        if self.other_act is not None:\n            input = self.other_act(input)\n\n        if self.to_onehot_y:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `to_onehot_y=True` ignored.\")\n            else:\n                target = one_hot(target, num_classes=n_pred_ch)\n\n        if not self.include_background:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `include_background=False` ignored.\")\n            else:\n                # if skipping background, removing first channel\n                target = target[:, 1:]\n                input = input[:, 1:]\n\n        if target.shape != input.shape:\n            raise AssertionError(f\"ground truth has differing shape ({target.shape}) from input ({input.shape})\")\n\n        # reducing only spatial dimensions (not batch nor channels)\n        reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()\n        if self.batch:\n            reduce_axis = [0] + reduce_axis\n\n        
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label)\n        fp *= 0.5\n        fn *= 0.5\n        denominator = 2 * (tp + fp + fn)\n\n        ground_o = torch.sum(target, reduce_axis)\n        w = self.w_func(ground_o.float())\n        infs = torch.isinf(w)\n        if self.batch:\n            w[infs] = 0.0\n            w = w + infs * torch.max(w)\n        else:\n            w[infs] = 0.0\n            max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1)\n            w = w + infs * max_values\n\n        final_reduce_dim = 0 if self.batch else 1\n        numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr\n        denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr\n        f: torch.Tensor = 1.0 - (numer / denom)\n\n        if self.reduction == LossReduction.MEAN.value:\n            f = torch.mean(f)  # the batch and channel average\n        elif self.reduction == LossReduction.SUM.value:\n            f = torch.sum(f)  # sum over the batch and channel dims\n        elif self.reduction == LossReduction.NONE.value:\n            # If we are not computing voxelwise loss components at least\n            # make sure a none reduction maintains a broadcastable shape\n            broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)\n            f = f.view(broadcast_shape)\n        else:\n            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n\n        return f\n\n\nclass GeneralizedWassersteinDiceLoss(_Loss):\n    \"\"\"\n    Compute the generalized Wasserstein Dice Loss defined in:\n\n        Fidon L. et al. (2017) Generalised Wasserstein Dice Score for Imbalanced Multi-class\n        Segmentation using Holistic Convolutional Networks. BrainLes 2017.\n\n    Or its variant (use the option weighting_mode=\"GDL\") defined in the Appendix of:\n\n        Tilborghs, S. et al. 
(2020) Comparative study of deep learning methods for the automatic\n        segmentation of lung, lesion and lesion type in CT scans of COVID-19 patients.\n        arXiv preprint arXiv:2007.15546\n\n    Adapted from:\n        https://github.com/LucasFidon/GeneralizedWassersteinDiceLoss\n    \"\"\"\n\n    def __init__(\n        self,\n        dist_matrix: np.ndarray | torch.Tensor,\n        weighting_mode: str = \"default\",\n        reduction: LossReduction | str = LossReduction.MEAN,\n        smooth_nr: float = 1e-5,\n        smooth_dr: float = 1e-5,\n    ) -> None:\n        \"\"\"\n        Args:\n            dist_matrix: 2d tensor or 2d numpy array; matrix of distances between the classes.\n            It must have dimension C x C where C is the number of classes.\n            weighting_mode: {``\"default\"``, ``\"GDL\"``}\n                Specifies how to weight the class-specific sum of errors.\n                Default to ``\"default\"``.\n\n                - ``\"default\"``: (recommended) use the original weighting method as in:\n                    Fidon L. et al. (2017) Generalised Wasserstein Dice Score for Imbalanced Multi-class\n                    Segmentation using Holistic Convolutional Networks. BrainLes 2017.\n                - ``\"GDL\"``: use a GDL-like weighting method as in the Appendix of:\n                    Tilborghs, S. et al. (2020) Comparative study of deep learning methods for the automatic\n                    segmentation of lung, lesion and lesion type in CT scans of COVID-19 patients.\n                    arXiv preprint arXiv:2007.15546\n            reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n                Specifies the reduction to apply to the output. 
Defaults to ``\"mean\"``.\n\n                - ``\"none\"``: no reduction will be applied.\n                - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n                - ``\"sum\"``: the output will be summed.\n            smooth_nr: a small constant added to the numerator to avoid zero.\n            smooth_dr: a small constant added to the denominator to avoid nan.\n\n        Raises:\n            ValueError: When ``dist_matrix`` is not a square matrix.\n\n        Example:\n            .. code-block:: python\n\n                import torch\n                import numpy as np\n                from monai.losses import GeneralizedWassersteinDiceLoss\n\n                # Example with 3 classes (including the background: label 0).\n                # The distance between the background class (label 0) and the other classes is the maximum, equal to 1.\n                # The distance between class 1 and class 2 is 0.5.\n                dist_mat = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.5], [1.0, 0.5, 0.0]], dtype=np.float32)\n                wass_loss = GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)\n\n                pred_score = torch.tensor([[1000, 0, 0], [0, 1000, 0], [0, 0, 1000]], dtype=torch.float32)\n                grnd = torch.tensor([0, 1, 2], dtype=torch.int64)\n                wass_loss(pred_score, grnd)  # 0\n\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n\n        if dist_matrix.shape[0] != dist_matrix.shape[1]:\n            raise ValueError(f\"dist_matrix must be C x C, got {dist_matrix.shape[0]} x {dist_matrix.shape[1]}.\")\n\n        if weighting_mode not in [\"default\", \"GDL\"]:\n            raise ValueError(f\"weighting_mode must be either 'default' or 'GDL', got {weighting_mode}.\")\n\n        self.m = dist_matrix\n        if isinstance(self.m, np.ndarray):\n            self.m = torch.from_numpy(self.m)\n        if torch.max(self.m) != 1:\n            
self.m = self.m / torch.max(self.m)\n        self.alpha_mode = weighting_mode\n        self.num_classes = self.m.size(0)\n        self.smooth_nr = float(smooth_nr)\n        self.smooth_dr = float(smooth_dr)\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be BNH[WD].\n            target: the shape should be BNH[WD].\n\n        \"\"\"\n        # Aggregate spatial dimensions\n        flat_input = input.reshape(input.size(0), input.size(1), -1)\n        flat_target = target.reshape(target.size(0), -1).long()\n\n        # Apply the softmax to the input scores map\n        probs = F.softmax(flat_input, dim=1)\n\n        # Compute the Wasserstein distance map\n        wass_dist_map = self.wasserstein_distance_map(probs, flat_target)\n\n        # Compute the values of alpha to use\n        alpha = self._compute_alpha_generalized_true_positives(flat_target)\n\n        # Compute the numerator and denominator of the generalized Wasserstein Dice loss\n        if self.alpha_mode == \"GDL\":\n            # use GDL-style alpha weights (i.e. normalize by the volume of each class)\n            # contrary to the original definition we also use alpha in the \"generalized all error\".\n            true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map)\n            denom = self._compute_denominator(alpha, flat_target, wass_dist_map)\n        else:  # default: as in the original paper\n            # (i.e. 
alpha=1 for all foreground classes and 0 for the background).\n            # Compute the generalised number of true positives\n            true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map)\n            all_error = torch.sum(wass_dist_map, dim=1)\n            denom = 2 * true_pos + all_error\n\n        # Compute the final loss\n        wass_dice: torch.Tensor = (2.0 * true_pos + self.smooth_nr) / (denom + self.smooth_dr)\n        wass_dice_loss: torch.Tensor = 1.0 - wass_dice\n\n        if self.reduction == LossReduction.MEAN.value:\n            wass_dice_loss = torch.mean(wass_dice_loss)  # the batch and channel average\n        elif self.reduction == LossReduction.SUM.value:\n            wass_dice_loss = torch.sum(wass_dice_loss)  # sum over the batch and channel dims\n        elif self.reduction == LossReduction.NONE.value:\n            # GWDL aggregates over classes internally, so wass_dice_loss has shape (B,)\n            pass\n        else:\n            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n\n        return wass_dice_loss\n\n    def wasserstein_distance_map(self, flat_proba: torch.Tensor, flat_target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Compute the voxel-wise Wasserstein distance between the\n        flattened prediction and the flattened labels (ground_truth) with respect\n        to the distance matrix on the label space M.\n        This corresponds to eq. 6 in:\n\n            Fidon L. et al. (2017) Generalised Wasserstein Dice Score for Imbalanced Multi-class\n            Segmentation using Holistic Convolutional Networks. 
BrainLes 2017.\n\n        Args:\n            flat_proba: the probabilities of input(predicted) tensor.\n            flat_target: the target tensor.\n        \"\"\"\n        # Turn the distance matrix to a map of identical matrix\n        m = torch.clone(torch.as_tensor(self.m)).to(flat_proba.device)\n        m_extended = torch.unsqueeze(m, dim=0)\n        m_extended = torch.unsqueeze(m_extended, dim=3)\n        m_extended = m_extended.expand((flat_proba.size(0), m_extended.size(1), m_extended.size(2), flat_proba.size(2)))\n\n        # Expand the feature dimensions of the target\n        flat_target_extended = torch.unsqueeze(flat_target, dim=1)\n        flat_target_extended = flat_target_extended.expand(\n            (flat_target.size(0), m_extended.size(1), flat_target.size(1))\n        )\n        flat_target_extended = torch.unsqueeze(flat_target_extended, dim=1)\n\n        # Extract the vector of class distances for the ground-truth label at each voxel\n        m_extended = torch.gather(m_extended, dim=1, index=flat_target_extended)\n        m_extended = torch.squeeze(m_extended, dim=1)\n\n        # Compute the wasserstein distance map\n        wasserstein_map = m_extended * flat_proba\n\n        # Sum over the classes\n        wasserstein_map = torch.sum(wasserstein_map, dim=1)\n        return wasserstein_map\n\n    def _compute_generalized_true_positive(\n        self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            alpha: generalised number of true positives of target class.\n            flat_target: the target tensor.\n            wasserstein_distance_map: the map obtained from the above function.\n        \"\"\"\n        # Extend alpha to a map and select value at each voxel according to flat_target\n        alpha_extended = torch.unsqueeze(alpha, dim=2)\n        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, 
flat_target.size(1)))\n        flat_target_extended = torch.unsqueeze(flat_target, dim=1)\n        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)\n        alpha_extended = torch.squeeze(alpha_extended, dim=1)\n\n        return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=1)\n\n    def _compute_denominator(\n        self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            alpha: generalised number of true positives of target class.\n            flat_target: the target tensor.\n            wasserstein_distance_map: the map obtained from the above function.\n        \"\"\"\n        # Extend alpha to a map and select value at each voxel according to flat_target\n        alpha_extended = torch.unsqueeze(alpha, dim=2)\n        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))\n        flat_target_extended = torch.unsqueeze(flat_target, dim=1)\n        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)\n        alpha_extended = torch.squeeze(alpha_extended, dim=1)\n\n        return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=1)\n\n    def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            flat_target: the target tensor.\n        \"\"\"\n        alpha: torch.Tensor = torch.ones((flat_target.size(0), self.num_classes)).float().to(flat_target.device)\n        if self.alpha_mode == \"GDL\":  # GDL style\n            # Define alpha like in the generalized dice loss\n            # i.e. 
the inverse of the volume of each class.\n            one_hot_f = F.one_hot(flat_target, num_classes=self.num_classes).permute(0, 2, 1).float()\n            volumes = torch.sum(one_hot_f, dim=2)\n            alpha = 1.0 / (volumes + 1.0)\n        else:  # default, i.e. like in the original paper\n            # alpha weights are 0 for the background and 1 the other classes\n            alpha[:, 0] = 0.0\n        return alpha\n\n\nclass DiceCELoss(_Loss):\n    \"\"\"\n    Compute both Dice loss and Cross Entropy Loss, and return the weighted sum of these two losses.\n    The details of Dice loss is shown in ``monai.losses.DiceLoss``.\n    The details of Cross Entropy Loss is shown in ``torch.nn.CrossEntropyLoss`` and ``torch.nn.BCEWithLogitsLoss()``.\n    In this implementation, two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are\n    not supported.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        to_onehot_y: bool = False,\n        sigmoid: bool = False,\n        softmax: bool = False,\n        other_act: Callable | None = None,\n        squared_pred: bool = False,\n        jaccard: bool = False,\n        reduction: str = \"mean\",\n        smooth_nr: float = 1e-5,\n        smooth_dr: float = 1e-5,\n        batch: bool = False,\n        weight: torch.Tensor | None = None,\n        lambda_dice: float = 1.0,\n        lambda_ce: float = 1.0,\n        label_smoothing: float = 0.0,\n    ) -> None:\n        \"\"\"\n        Args:\n            ``lambda_ce`` are only used for cross entropy loss.\n            ``reduction`` and ``weight`` is used for both losses and other parameters are only used for dice loss.\n\n            include_background: if False channel index 0 (background category) is excluded from the calculation.\n            to_onehot_y: whether to convert the ``target`` into the one-hot format,\n                using the number of classes inferred from `input` 
(``input.shape[1]``). Defaults to False.\n            sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,\n                don't need to specify activation function for `CrossEntropyLoss` and `BCEWithLogitsLoss`.\n            softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,\n                don't need to specify activation function for `CrossEntropyLoss` and `BCEWithLogitsLoss`.\n            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:\n                ``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss` and `BCEWithLogitsLoss`.\n            squared_pred: use squared versions of targets and predictions in the denominator or not.\n            jaccard: compute Jaccard Index (soft IoU) instead of dice or not.\n            reduction: {``\"mean\"``, ``\"sum\"``}\n                Specifies the reduction to apply to the output. Defaults to ``\"mean\"``. 
The dice loss should\n                at least reduce the spatial dimensions, which is different from cross entropy loss, thus here\n                the ``none`` option cannot be used.\n\n                - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n                - ``\"sum\"``: the output will be summed.\n\n            smooth_nr: a small constant added to the numerator to avoid zero.\n            smooth_dr: a small constant added to the denominator to avoid nan.\n            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.\n                Defaults to False, a Dice loss value is computed independently from each item in the batch\n                before any `reduction`.\n            weight: a rescaling weight given to each class for cross entropy loss for `CrossEntropyLoss`.\n                or a weight of positive examples to be broadcasted with target used as `pos_weight` for `BCEWithLogitsLoss`.\n                See ``torch.nn.CrossEntropyLoss()`` or ``torch.nn.BCEWithLogitsLoss()`` for more information.\n                The weight is also used in `DiceLoss`.\n            lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.\n                Defaults to 1.0.\n            lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.\n                Defaults to 1.0.\n            label_smoothing: a value in [0, 1] range. 
If > 0, the labels are smoothed\n                by the given factor to reduce overfitting.\n                Defaults to 0.0.\n\n        \"\"\"\n        super().__init__()\n        reduction = look_up_option(reduction, DiceCEReduction).value\n        dice_weight: torch.Tensor | None\n        if weight is not None and not include_background:\n            dice_weight = weight[1:]\n        else:\n            dice_weight = weight\n        self.dice = DiceLoss(\n            include_background=include_background,\n            to_onehot_y=to_onehot_y,\n            sigmoid=sigmoid,\n            softmax=softmax,\n            other_act=other_act,\n            squared_pred=squared_pred,\n            jaccard=jaccard,\n            reduction=reduction,\n            smooth_nr=smooth_nr,\n            smooth_dr=smooth_dr,\n            batch=batch,\n            weight=dice_weight,\n        )\n        self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction, label_smoothing=label_smoothing)\n        self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)\n        if lambda_dice < 0.0:\n            raise ValueError(\"lambda_dice should be no less than 0.0.\")\n        if lambda_ce < 0.0:\n            raise ValueError(\"lambda_ce should be no less than 0.0.\")\n        self.lambda_dice = lambda_dice\n        self.lambda_ce = lambda_ce\n\n    def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Compute CrossEntropy loss for the input logits and target.\n        Will remove the channel dim according to PyTorch CrossEntropyLoss:\n        https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss.\n\n        \"\"\"\n        n_pred_ch, n_target_ch = input.shape[1], target.shape[1]\n        if n_pred_ch != n_target_ch and n_target_ch == 1:\n            target = torch.squeeze(target, dim=1)\n            target = target.long()\n        elif not 
torch.is_floating_point(target):\n            target = target.to(dtype=input.dtype)\n\n        return self.cross_entropy(input, target)  # type: ignore[no-any-return]\n\n    def bce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Compute Binary CrossEntropy loss for the input logits and target in one single class.\n\n        \"\"\"\n        if not torch.is_floating_point(target):\n            target = target.to(dtype=input.dtype)\n\n        return self.binary_cross_entropy(input, target)  # type: ignore[no-any-return]\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be BNH[WD].\n            target: the shape should be BNH[WD] or B1H[WD].\n\n        Raises:\n            ValueError: When number of dimensions for input and target are different.\n            ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.\n\n        Returns:\n            torch.Tensor: value of the loss.\n\n        \"\"\"\n        if input.dim() != target.dim():\n            raise ValueError(\n                \"the number of dimensions for input and target should be the same, \"\n                f\"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). 
\"\n                \"if target is not one-hot encoded, please provide a tensor with shape B1H[WD].\"\n            )\n\n        if target.shape[1] != 1 and target.shape[1] != input.shape[1]:\n            raise ValueError(\n                \"number of channels for target is neither 1 (without one-hot encoding) nor the same as input, \"\n                f\"got shape {input.shape} and {target.shape}.\"\n            )\n\n        dice_loss = self.dice(input, target)\n        ce_loss = self.ce(input, target) if input.shape[1] != 1 else self.bce(input, target)\n        total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss\n\n        return total_loss\n\n\nclass DiceFocalLoss(_Loss):\n    \"\"\"\n    Compute both Dice loss and Focal Loss, and return the weighted sum of these two losses.\n    The details of Dice loss is shown in ``monai.losses.DiceLoss``.\n    The details of Focal Loss is shown in ``monai.losses.FocalLoss``.\n\n    ``gamma`` and ``lambda_focal`` are only used for the focal loss.\n    ``include_background``, ``weight``, ``reduction``, and ``alpha`` are used for both losses,\n    and other parameters are only used for dice loss.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        to_onehot_y: bool = False,\n        sigmoid: bool = False,\n        softmax: bool = False,\n        other_act: Callable | None = None,\n        squared_pred: bool = False,\n        jaccard: bool = False,\n        reduction: str = \"mean\",\n        smooth_nr: float = 1e-5,\n        smooth_dr: float = 1e-5,\n        batch: bool = False,\n        gamma: float = 2.0,\n        weight: Sequence[float] | float | int | torch.Tensor | None = None,\n        lambda_dice: float = 1.0,\n        lambda_focal: float = 1.0,\n        alpha: float | None = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            include_background: if False channel index 0 (background category) is excluded from the calculation.\n   
         to_onehot_y: whether to convert the ``target`` into the one-hot format,\n                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.\n            sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,\n                don't need to specify activation function for `FocalLoss`.\n            softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,\n                don't need to specify activation function for `FocalLoss`.\n            other_act: callable function to execute other activation layers, Defaults to ``None``.\n                for example: `other_act = torch.tanh`. only used by the `DiceLoss`, not for `FocalLoss`.\n            squared_pred: use squared versions of targets and predictions in the denominator or not.\n            jaccard: compute Jaccard Index (soft IoU) instead of dice or not.\n            reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n                Specifies the reduction to apply to the output. Defaults to ``\"mean\"``.\n\n                - ``\"none\"``: no reduction will be applied.\n                - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n                - ``\"sum\"``: the output will be summed.\n\n            smooth_nr: a small constant added to the numerator to avoid zero.\n            smooth_dr: a small constant added to the denominator to avoid nan.\n            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.\n                Defaults to False, a Dice loss value is computed independently from each item in the batch\n                before any `reduction`.\n            gamma: value of the exponent gamma in the definition of the Focal loss.\n            weight: weights to apply to the voxels of each class. 
If None no weights are applied.\n                The input can be a single value (same weight for all classes), a sequence of values (the length\n                of the sequence should be the same as the number of classes).\n            lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.\n                Defaults to 1.0.\n            lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.\n                Defaults to 1.0.\n            alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in\n                [0, 1]. Defaults to None.\n        \"\"\"\n        super().__init__()\n        self.dice = DiceLoss(\n            include_background=include_background,\n            to_onehot_y=False,\n            sigmoid=sigmoid,\n            softmax=softmax,\n            other_act=other_act,\n            squared_pred=squared_pred,\n            jaccard=jaccard,\n            reduction=reduction,\n            smooth_nr=smooth_nr,\n            smooth_dr=smooth_dr,\n            batch=batch,\n            weight=weight,\n        )\n        self.focal = FocalLoss(\n            include_background=include_background,\n            to_onehot_y=False,\n            gamma=gamma,\n            weight=weight,\n            alpha=alpha,\n            reduction=reduction,\n        )\n        if lambda_dice < 0.0:\n            raise ValueError(\"lambda_dice should be no less than 0.0.\")\n        if lambda_focal < 0.0:\n            raise ValueError(\"lambda_focal should be no less than 0.0.\")\n        self.lambda_dice = lambda_dice\n        self.lambda_focal = lambda_focal\n        self.to_onehot_y = to_onehot_y\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be BNH[WD]. 
The input should be the original logits\n                due to the restriction of ``monai.losses.FocalLoss``.\n            target: the shape should be BNH[WD] or B1H[WD].\n\n        Raises:\n            ValueError: When number of dimensions for input and target are different.\n            ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.\n\n        Returns:\n            torch.Tensor: value of the loss.\n        \"\"\"\n        if input.dim() != target.dim():\n            raise ValueError(\n                \"the number of dimensions for input and target should be the same, \"\n                f\"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). \"\n                \"if target is not one-hot encoded, please provide a tensor with shape B1H[WD].\"\n            )\n\n        if target.shape[1] != 1 and target.shape[1] != input.shape[1]:\n            raise ValueError(\n                \"number of channels for target is neither 1 (without one-hot encoding) nor the same as input, \"\n                f\"got shape {input.shape} and {target.shape}.\"\n            )\n\n        if self.to_onehot_y:\n            n_pred_ch = input.shape[1]\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `to_onehot_y=True` ignored.\")\n            else:\n                target = one_hot(target, num_classes=n_pred_ch)\n        dice_loss = self.dice(input, target)\n        focal_loss = self.focal(input, target)\n        total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss\n        return total_loss\n\n\nclass GeneralizedDiceFocalLoss(_Loss):\n    \"\"\"Compute both Generalized Dice Loss and Focal Loss, and return their weighted average. 
The details of Generalized Dice Loss\n    and Focal Loss are available at ``monai.losses.GeneralizedDiceLoss`` and ``monai.losses.FocalLoss``.\n\n    Args:\n        include_background (bool, optional): if False channel index 0 (background category) is excluded from the calculation.\n            Defaults to True.\n        to_onehot_y: whether to convert the ``target`` into the one-hot format,\n            using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.\n        sigmoid (bool, optional): if True, apply a sigmoid function to the prediction. Defaults to False.\n        softmax (bool, optional): if True, apply a softmax function to the prediction. Defaults to False.\n        other_act (Optional[Callable], optional): callable function to execute other activation layers,\n            Defaults to ``None``. for example: `other_act = torch.tanh`.\n            only used by the `GeneralizedDiceLoss`, not for the `FocalLoss`.\n        w_type (Union[Weight, str], optional): {``\"square\"``, ``\"simple\"``, ``\"uniform\"``}. Type of function to transform\n            ground-truth volume to a weight factor. Defaults to ``\"square\"``.\n        reduction (Union[LossReduction, str], optional): {``\"none\"``, ``\"mean\"``, ``\"sum\"``}. Specifies the reduction to\n            apply to the output. Defaults to ``\"mean\"``.\n            - ``\"none\"``: no reduction will be applied.\n            - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n            - ``\"sum\"``: the output will be summed.\n        smooth_nr (float, optional): a small constant added to the numerator to avoid zero. Defaults to 1e-5.\n        smooth_dr (float, optional): a small constant added to the denominator to avoid nan. 
Defaults to 1e-5.\n        batch (bool, optional): whether to sum the intersection and union areas over the batch dimension before the dividing.\n            Defaults to False, i.e., the areas are computed for each item in the batch.\n        gamma (float, optional): value of the exponent gamma in the definition of the Focal loss. Defaults to 2.0.\n        weight (Optional[Union[Sequence[float], float, int, torch.Tensor]], optional): weights to apply to\n            the voxels of each class. If None no weights are applied. The input can be a single value\n            (same weight for all classes), a sequence of values (the length of the sequence should be the same as\n            the number of classes). Defaults to None.\n        lambda_gdl (float, optional): the trade-off weight value for Generalized Dice Loss. The value should be\n            no less than 0.0. Defaults to 1.0.\n        lambda_focal (float, optional): the trade-off weight value for Focal Loss. The value should be no less\n            than 0.0. 
Defaults to 1.0.\n\n    Raises:\n        ValueError: if either `lambda_gdl` or `lambda_focal` is less than 0.\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        to_onehot_y: bool = False,\n        sigmoid: bool = False,\n        softmax: bool = False,\n        other_act: Callable | None = None,\n        w_type: Weight | str = Weight.SQUARE,\n        reduction: LossReduction | str = LossReduction.MEAN,\n        smooth_nr: float = 1e-5,\n        smooth_dr: float = 1e-5,\n        batch: bool = False,\n        gamma: float = 2.0,\n        weight: Sequence[float] | float | int | torch.Tensor | None = None,\n        lambda_gdl: float = 1.0,\n        lambda_focal: float = 1.0,\n    ) -> None:\n        super().__init__()\n        self.generalized_dice = GeneralizedDiceLoss(\n            include_background=include_background,\n            to_onehot_y=to_onehot_y,\n            sigmoid=sigmoid,\n            softmax=softmax,\n            other_act=other_act,\n            w_type=w_type,\n            reduction=reduction,\n            smooth_nr=smooth_nr,\n            smooth_dr=smooth_dr,\n            batch=batch,\n        )\n        self.focal = FocalLoss(\n            include_background=include_background,\n            to_onehot_y=to_onehot_y,\n            gamma=gamma,\n            weight=weight,\n            reduction=reduction,\n        )\n        if lambda_gdl < 0.0:\n            raise ValueError(\"lambda_gdl should be no less than 0.0.\")\n        if lambda_focal < 0.0:\n            raise ValueError(\"lambda_focal should be no less than 0.0.\")\n        self.lambda_gdl = lambda_gdl\n        self.lambda_focal = lambda_focal\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input (torch.Tensor): the shape should be BNH[WD]. 
The input should be the original logits\n                due to the restriction of ``monai.losses.FocalLoss``.\n            target (torch.Tensor): the shape should be BNH[WD] or B1H[WD].\n\n        Raises:\n            ValueError: When number of dimensions for input and target are different.\n            ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.\n\n        Returns:\n            torch.Tensor: value of the loss.\n        \"\"\"\n        if input.dim() != target.dim():\n            raise ValueError(\n                \"the number of dimensions for input and target should be the same, \"\n                f\"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). \"\n                \"if target is not one-hot encoded, please provide a tensor with shape B1H[WD].\"\n            )\n\n        if target.shape[1] != 1 and target.shape[1] != input.shape[1]:\n            raise ValueError(\n                \"number of channels for target is neither 1 (without one-hot encoding) nor the same as input, \"\n                f\"got shape {input.shape} and {target.shape}.\"\n            )\n\n        gdl_loss = self.generalized_dice(input, target)\n        focal_loss = self.focal(input, target)\n        total_loss: torch.Tensor = self.lambda_gdl * gdl_loss + self.lambda_focal * focal_loss\n        return total_loss\n\n\nDice = DiceLoss\ndice_ce = DiceCELoss\ndice_focal = DiceFocalLoss\ngeneralized_dice = GeneralizedDiceLoss\ngeneralized_dice_focal = GeneralizedDiceFocalLoss\ngeneralized_wasserstein_dice = GeneralizedWassersteinDiceLoss\n"
  },
  {
    "path": "monai/losses/ds_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn.modules.loss import _Loss\n\n\nclass DeepSupervisionLoss(_Loss):\n    \"\"\"\n    Wrapper class around the main loss function to accept a list of tensors returned from a deeply\n    supervised networks. The final loss is computed as the sum of weighted losses for each of deep supervision levels.\n    \"\"\"\n\n    def __init__(self, loss: _Loss, weight_mode: str = \"exp\", weights: list[float] | None = None) -> None:\n        \"\"\"\n        Args:\n            loss: main loss instance, e.g DiceLoss().\n            weight_mode: {``\"same\"``, ``\"exp\"``, ``\"two\"``}\n                Specifies the weights calculation for each image level. 
Defaults to ``\"exp\"``.\n                - ``\"same\"``: all weights are equal to 1.\n                - ``\"exp\"``: exponentially decreasing weights by a power of 2: 1, 0.5, 0.25, 0.125, etc .\n                - ``\"two\"``: equal smaller weights for lower levels: 1, 0.5, 0.5, 0.5, 0.5, etc\n            weights: a list of weights to apply to each deeply supervised sub-loss, if provided, this will be used\n                regardless of the weight_mode\n        \"\"\"\n        super().__init__()\n        self.loss = loss\n        self.weight_mode = weight_mode\n        self.weights = weights\n        self.interp_mode = \"nearest-exact\"\n\n    def get_weights(self, levels: int = 1) -> list[float]:\n        \"\"\"\n        Calculates weights for a given number of scale levels\n        \"\"\"\n        levels = max(1, levels)\n        if self.weights is not None and len(self.weights) >= levels:\n            weights = self.weights[:levels]\n        elif self.weight_mode == \"same\":\n            weights = [1.0] * levels\n        elif self.weight_mode == \"exp\":\n            weights = [max(0.5**l, 0.0625) for l in range(levels)]\n        elif self.weight_mode == \"two\":\n            weights = [1.0 if l == 0 else 0.5 for l in range(levels)]\n        else:\n            weights = [1.0] * levels\n\n        return weights\n\n    def get_loss(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Calculates a loss output accounting for differences in shapes,\n        and downsizing targets if necessary (using nearest neighbor interpolation)\n        Generally downsizing occurs for all level, except for the first (level==0)\n        \"\"\"\n        if input.shape[2:] != target.shape[2:]:\n            target = F.interpolate(target, size=input.shape[2:], mode=self.interp_mode)\n        return self.loss(input, target)  # type: ignore[no-any-return]\n\n    def forward(self, input: None | torch.Tensor | list[torch.Tensor], target: torch.Tensor) 
-> torch.Tensor:\n        if isinstance(input, (list, tuple)):\n            weights = self.get_weights(levels=len(input))\n            loss = torch.tensor(0, dtype=torch.float, device=target.device)\n            for l in range(len(input)):\n                loss += weights[l] * self.get_loss(input[l].float(), target)\n            return loss\n        if input is None:\n            raise ValueError(\"input shouldn't be None.\")\n\n        return self.loss(input.float(), target)  # type: ignore[no-any-return]\n\n\nds_loss = DeepSupervisionLoss\n"
  },
  {
    "path": "monai/losses/focal_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.networks import one_hot\nfrom monai.utils import LossReduction\n\n\nclass FocalLoss(_Loss):\n    \"\"\"\n    FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from\n    high confidence correct predictions.\n\n    Reimplementation of the Focal Loss described in:\n\n        - [\"Focal Loss for Dense Object Detection\"](https://arxiv.org/abs/1708.02002), T. 
Lin et al., ICCV 2017\n        - \"AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy\",\n          Zhu et al., Medical Physics 2018\n\n    Example:\n        >>> import torch\n        >>> from monai.losses import FocalLoss\n        >>> from torch.nn import BCEWithLogitsLoss\n        >>> shape = B, N, *DIMS = 2, 3, 5, 7, 11\n        >>> input = torch.rand(*shape)\n        >>> target = torch.rand(*shape)\n        >>> # Demonstrate equivalence to BCE when gamma=0\n        >>> fl_g0_criterion = FocalLoss(reduction='none', gamma=0)\n        >>> fl_g0_loss = fl_g0_criterion(input, target)\n        >>> bce_criterion = BCEWithLogitsLoss(reduction='none')\n        >>> bce_loss = bce_criterion(input, target)\n        >>> assert torch.allclose(fl_g0_loss, bce_loss)\n        >>> # Demonstrate \"focus\" by setting gamma > 0.\n        >>> fl_g2_criterion = FocalLoss(reduction='none', gamma=2)\n        >>> fl_g2_loss = fl_g2_criterion(input, target)\n        >>> # Mark easy and hard cases\n        >>> is_easy = (target > 0.7) & (input > 0.7)\n        >>> is_hard = (target > 0.7) & (input < 0.3)\n        >>> easy_loss_g0 = fl_g0_loss[is_easy].mean()\n        >>> hard_loss_g0 = fl_g0_loss[is_hard].mean()\n        >>> easy_loss_g2 = fl_g2_loss[is_easy].mean()\n        >>> hard_loss_g2 = fl_g2_loss[is_hard].mean()\n        >>> # Gamma > 0 causes the loss function to \"focus\" on the hard\n        >>> # cases.  
IE, easy cases are downweighted, so hard cases\n        >>> # receive a higher proportion of the loss.\n        >>> hard_to_easy_ratio_g2 = hard_loss_g2 / easy_loss_g2\n        >>> hard_to_easy_ratio_g0 = hard_loss_g0 / easy_loss_g0\n        >>> assert hard_to_easy_ratio_g2 > hard_to_easy_ratio_g0\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        to_onehot_y: bool = False,\n        gamma: float = 2.0,\n        alpha: float | Sequence[float] | None = None,\n        weight: Sequence[float] | float | int | torch.Tensor | None = None,\n        reduction: LossReduction | str = LossReduction.MEAN,\n        use_softmax: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            include_background: if False, channel index 0 (background category) is excluded from the loss calculation.\n                If False, `alpha` is invalid when using softmax unless `alpha` is a sequence (explicit class weights).\n            to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.\n            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.\n            alpha: value of the alpha in the definition of the alpha-balanced Focal loss.\n                The value should be in [0, 1].\n                If a sequence is provided, its length must match the number of classes\n                (excluding the background class if `include_background=False`).\n                Defaults to None.\n            weight: weights to apply to the voxels of each class. If None no weights are applied.\n                The input can be a single value (same weight for all classes), a sequence of values (the length\n                of the sequence should be the same as the number of classes. If not ``include_background``,\n                the number of classes should not include the background category class 0).\n                The value/values should be no less than 0. 
Defaults to None.\n            reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n                Specifies the reduction to apply to the output. Defaults to ``\"mean\"``.\n\n                - ``\"none\"``: no reduction will be applied.\n                - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n                - ``\"sum\"``: the output will be summed.\n\n            use_softmax: whether to use softmax to transform the original logits into probabilities.\n                If True, softmax is used. If False, sigmoid is used. Defaults to False.\n\n        Example:\n            >>> import torch\n            >>> from monai.losses import FocalLoss\n            >>> pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32)\n            >>> grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64)\n            >>> fl = FocalLoss(to_onehot_y=True)\n            >>> fl(pred, grnd)\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n        self.include_background = include_background\n        self.to_onehot_y = to_onehot_y\n        self.gamma = gamma\n        self.weight = weight\n        self.use_softmax = use_softmax\n        self.alpha: float | torch.Tensor | None\n        if alpha is None:\n            self.alpha = None\n        elif isinstance(alpha, (float, int)):\n            self.alpha = float(alpha)\n        else:\n            self.alpha = torch.as_tensor(alpha)\n        weight = torch.as_tensor(weight) if weight is not None else None\n        self.register_buffer(\"class_weight\", weight)\n        self.class_weight: None | torch.Tensor\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be BNH[WD], where N is the number of classes.\n                The input should be the original logits since it will be transformed by\n                a sigmoid/softmax in the forward 
function.\n            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.\n\n        Raises:\n            ValueError: When input and target (after one hot transform if set)\n                have different shapes.\n            ValueError: When ``self.reduction`` is not one of [\"mean\", \"sum\", \"none\"].\n            ValueError: When ``self.weight`` is a sequence and the length is not equal to the\n                number of classes.\n            ValueError: When ``self.weight`` is/contains a value that is less than 0.\n\n        \"\"\"\n        n_pred_ch = input.shape[1]\n\n        if self.to_onehot_y:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `to_onehot_y=True` ignored.\")\n            else:\n                target = one_hot(target, num_classes=n_pred_ch)\n\n        if not self.include_background:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `include_background=False` ignored.\")\n            else:\n                # if skipping background, removing first channel\n                target = target[:, 1:]\n                input = input[:, 1:]\n\n        if target.shape != input.shape:\n            raise ValueError(f\"ground truth has different shape ({target.shape}) from input ({input.shape})\")\n\n        loss: torch.Tensor | None = None\n        input = input.float()\n        target = target.float()\n        alpha_arg = self.alpha\n        if self.use_softmax:\n            if not self.include_background and self.alpha is not None:\n                if isinstance(self.alpha, (float, int)):\n                    alpha_arg = None\n                    warnings.warn(\n                        \"`include_background=False`, scalar `alpha` ignored when using softmax.\", stacklevel=2\n                    )\n            loss = softmax_focal_loss(input, target, self.gamma, alpha_arg)\n        else:\n            loss = sigmoid_focal_loss(input, target, 
self.gamma, alpha_arg)\n\n        num_of_classes = target.shape[1]\n        if self.class_weight is not None and num_of_classes != 1:\n            # make sure the lengths of weights are equal to the number of classes\n            if self.class_weight.ndim == 0:\n                self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)\n            else:\n                if self.class_weight.shape[0] != num_of_classes:\n                    raise ValueError(\n                        \"\"\"the length of the `weight` sequence should be the same as the number of classes.\n                        If `include_background=False`, the weight should not include\n                        the background category class 0.\"\"\"\n                    )\n            if self.class_weight.min() < 0:\n                raise ValueError(\"the value/values of the `weight` should be no less than 0.\")\n            # apply class_weight to loss\n            self.class_weight = self.class_weight.to(loss)\n            broadcast_dims = [-1] + [1] * len(target.shape[2:])\n            self.class_weight = self.class_weight.view(broadcast_dims)\n            loss = self.class_weight * loss\n\n        if self.reduction == LossReduction.SUM.value:\n            # Previously there was a mean over the last dimension, which did not\n            # return a compatible BCE loss. To maintain backwards compatible\n            # behavior we have a flag that performs this extra step, disable or\n            # parameterize if necessary. 
(Or justify why the mean should be there)\n            average_spatial_dims = True\n            if average_spatial_dims:\n                loss = loss.mean(dim=list(range(2, len(target.shape))))\n            loss = loss.sum()\n        elif self.reduction == LossReduction.MEAN.value:\n            loss = loss.mean()\n        elif self.reduction == LossReduction.NONE.value:\n            pass\n        else:\n            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n        return loss\n\n\ndef softmax_focal_loss(\n    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None\n) -> torch.Tensor:\n    \"\"\"\n    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)\n\n    where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and\n    s_j is the unnormalized score for class j.\n    \"\"\"\n    input_ls = input.log_softmax(1)\n    loss: torch.Tensor = -(1 - input_ls.exp()).pow(gamma) * input_ls * target\n\n    if alpha is not None:\n        if isinstance(alpha, torch.Tensor):\n            alpha_t = alpha.to(device=input.device, dtype=input.dtype)\n        else:\n            alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)\n\n        if alpha_t.ndim == 0:  # scalar\n            alpha_val = alpha_t.item()\n            # (1-alpha) for the background class and alpha for the other classes\n            alpha_fac = torch.tensor([1 - alpha_val] + [alpha_val] * (target.shape[1] - 1)).to(loss)\n        else:  # tensor (sequence)\n            if alpha_t.shape[0] != target.shape[1]:\n                raise ValueError(\n                    f\"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]}).\"\n                )\n            alpha_fac = alpha_t\n\n        broadcast_dims = [-1] + [1] * len(target.shape[2:])\n        alpha_fac = alpha_fac.view(broadcast_dims)\n        loss = alpha_fac * loss\n\n  
  return loss\n\n\ndef sigmoid_focal_loss(\n    input: torch.Tensor, target: torch.Tensor, gamma: float = 2.0, alpha: float | torch.Tensor | None = None\n) -> torch.Tensor:\n    \"\"\"\n    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)\n\n    where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0\n    \"\"\"\n    # computing binary cross entropy with logits\n    # equivalent to F.binary_cross_entropy_with_logits(input, target, reduction='none')\n    # see also https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Loss.cpp#L363\n    loss: torch.Tensor = input - input * target - F.logsigmoid(input)\n\n    # sigmoid(-i) if t==1; sigmoid(i) if t==0 <=>\n    # 1-sigmoid(i) if t==1; sigmoid(i) if t==0 <=>\n    # 1-p if t==1; p if t==0 <=>\n    # pfac, that is, the term (1 - pt)\n    invprobs = F.logsigmoid(-input * (target * 2 - 1))  # reduced chance of overflow\n    # (pfac.log() * gamma).exp() <=>\n    # pfac.log().exp() ^ gamma <=>\n    # pfac ^ gamma\n    loss = (invprobs * gamma).exp() * loss\n\n    if alpha is not None:\n        if isinstance(alpha, torch.Tensor):\n            alpha_t = alpha.to(device=input.device, dtype=input.dtype)\n        else:\n            alpha_t = torch.tensor(alpha, device=input.device, dtype=input.dtype)\n\n        if alpha_t.ndim == 0:  # scalar\n            # alpha if t==1; (1-alpha) if t==0\n            alpha_factor = target * alpha_t + (1 - target) * (1 - alpha_t)\n        else:  # tensor (sequence)\n            if alpha_t.shape[0] != target.shape[1]:\n                raise ValueError(\n                    f\"The length of alpha ({alpha_t.shape[0]}) must match the number of classes ({target.shape[1]}).\"\n                )\n            # Reshape alpha for broadcasting: (1, C, 1, 1...)\n            broadcast_dims = [-1] + [1] * len(target.shape[2:])\n            alpha_t = alpha_t.view(broadcast_dims)\n            # Apply per-class weight only to positive samples\n            # For positive samples (target==1): 
multiply by alpha[c]\n            # For negative samples (target==0): keep weight as 1.0\n            alpha_factor = torch.where(target == 1, alpha_t, torch.ones_like(alpha_t))\n\n        loss = alpha_factor * loss\n\n    return loss\n"
  },
  {
    "path": "monai/losses/giou_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.data.box_utils import COMPUTE_DTYPE, box_pair_giou\nfrom monai.utils import LossReduction\n\n\nclass BoxGIoULoss(_Loss):\n    \"\"\"\n    Compute the generalized intersection over union (GIoU) loss of a pair of boxes.\n    The two inputs should have the same shape. giou_loss = 1.0 - giou\n\n    The range of GIoU is (-1.0, 1.0]. Thus the range of GIoU loss is [0.0, 2.0).\n\n    Args:\n        reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n            Specifies the reduction to apply to the output. Defaults to ``\"mean\"``.\n            - ``\"none\"``: no reduction will be applied.\n            - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n            - ``\"sum\"``: the output will be summed.\n    \"\"\"\n\n    def __init__(self, reduction: LossReduction | str = LossReduction.MEAN) -> None:\n        super().__init__(reduction=LossReduction(reduction).value)\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: predicted bounding boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``\n            target: GT bounding boxes, Nx4 or Nx6 torch tensor. 
The box mode is assumed to be ``StandardMode``\n\n        Raises:\n            ValueError: When the two inputs have different shape.\n        \"\"\"\n        if target.shape != input.shape:\n            raise ValueError(f\"ground truth has different shape ({target.shape}) from input ({input.shape})\")\n\n        box_dtype = input.dtype\n        giou: torch.Tensor = box_pair_giou(  # type: ignore\n            target.to(dtype=COMPUTE_DTYPE), input.to(dtype=COMPUTE_DTYPE)\n        )\n        loss: torch.Tensor = 1.0 - giou\n        if self.reduction == LossReduction.MEAN.value:\n            loss = loss.mean()\n        elif self.reduction == LossReduction.SUM.value:\n            loss = loss.sum()\n        elif self.reduction == LossReduction.NONE.value:\n            pass\n        else:\n            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n        return loss.to(box_dtype)\n\n\ngiou = BoxGIoULoss\n"
  },
  {
    "path": "monai/losses/hausdorff_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Hausdorff loss implementation based on paper:\n# https://arxiv.org/pdf/1904.10030.pdf\n\n# Repo: https://github.com/PatRyg99/HausdorffLoss\n\nfrom __future__ import annotations\n\nimport warnings\nfrom typing import Callable\n\nimport torch\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.networks import one_hot\nfrom monai.transforms.utils import distance_transform_edt\nfrom monai.utils import LossReduction\n\n\nclass HausdorffDTLoss(_Loss):\n    \"\"\"\n    Compute channel-wise binary Hausdorff loss based on distance transform. It can support both multi-classes and\n    multi-labels tasks. The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target`\n    (BNHW[D]).\n\n    Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input,\n    must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target`\n    can be 1 or N (one-hot format).\n\n    The original paper: Karimi, D. et. al. 
(2019) Reducing the Hausdorff Distance in Medical Image Segmentation with\n    Convolutional Neural Networks, IEEE Transactions on medical imaging, 39(2), 499-513\n    \"\"\"\n\n    def __init__(\n        self,\n        alpha: float = 2.0,\n        include_background: bool = False,\n        to_onehot_y: bool = False,\n        sigmoid: bool = False,\n        softmax: bool = False,\n        other_act: Callable | None = None,\n        reduction: LossReduction | str = LossReduction.MEAN,\n        batch: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            alpha: the exponent to transform the distance when computing the loss. Defaults to 2.0.\n            include_background: if False, channel index 0 (background category) is excluded from the calculation.\n                if the non-background segmentations are small compared to the total image size they can get overwhelmed\n                by the signal from the background so excluding it in such cases helps convergence.\n            to_onehot_y: whether to convert the ``target`` into the one-hot format,\n                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.\n            sigmoid: if True, apply a sigmoid function to the prediction.\n            softmax: if True, apply a softmax function to the prediction.\n            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:\n                ``other_act = torch.tanh``.\n            reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n                Specifies the reduction to apply to the output. 
Defaults to ``\"mean\"``.\n\n                - ``\"none\"``: no reduction will be applied.\n                - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n                - ``\"sum\"``: the output will be summed.\n            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.\n                Defaults to False, a loss value is computed independently from each item in the batch\n                before any `reduction`.\n\n        Raises:\n            TypeError: When ``other_act`` is not an ``Optional[Callable]``.\n            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].\n                Incompatible values.\n\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n        if other_act is not None and not callable(other_act):\n            raise TypeError(f\"other_act must be None or callable but is {type(other_act).__name__}.\")\n        if int(sigmoid) + int(softmax) > 1:\n            raise ValueError(\"Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].\")\n\n        self.alpha = alpha\n        self.include_background = include_background\n        self.to_onehot_y = to_onehot_y\n        self.sigmoid = sigmoid\n        self.softmax = softmax\n        self.other_act = other_act\n        self.batch = batch\n\n    @torch.no_grad()\n    def distance_field(self, img: torch.Tensor) -> torch.Tensor:\n        \"\"\"Generate distance transform.\n\n        Args:\n            img (np.ndarray): input mask as NCHWD or NCHW.\n\n        Returns:\n            np.ndarray: Distance field.\n        \"\"\"\n        field = torch.zeros_like(img)\n\n        for batch_idx in range(len(img)):\n            fg_mask = img[batch_idx] > 0.5\n\n            # For cases where the mask is entirely background or entirely foreground\n            # the distance transform is not well 
defined for all 1s,\n            # which always would happen on either foreground or background, so skip\n            if fg_mask.any() and not fg_mask.all():\n                fg_dist: torch.Tensor = distance_transform_edt(fg_mask)  # type: ignore\n                bg_mask = ~fg_mask\n                bg_dist: torch.Tensor = distance_transform_edt(bg_mask)  # type: ignore\n\n                field[batch_idx] = fg_dist + bg_dist\n\n        return field\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be BNHW[D], where N is the number of classes.\n            target: the shape should be BNHW[D] or B1HW[D], where N is the number of classes.\n\n        Raises:\n            ValueError: If the input is not 2D (NCHW) or 3D (NCHWD).\n            AssertionError: When input and target (after one hot transform if set)\n                have different shapes.\n            ValueError: When ``self.reduction`` is not one of [\"mean\", \"sum\", \"none\"].\n\n        Example:\n            >>> import torch\n            >>> from monai.losses.hausdorff_loss import HausdorffDTLoss\n            >>> from monai.networks.utils import one_hot\n            >>> B, C, H, W = 7, 5, 3, 2\n            >>> input = torch.rand(B, C, H, W)\n            >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()\n            >>> target = one_hot(target_idx[:, None, ...], num_classes=C)\n            >>> self = HausdorffDTLoss(reduction='none')\n            >>> loss = self(input, target)\n            >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape\n        \"\"\"\n        if input.dim() != 4 and input.dim() != 5:\n            raise ValueError(\"Only 2D (NCHW) and 3D (NCHWD) supported\")\n\n        if self.sigmoid:\n            input = torch.sigmoid(input)\n\n        n_pred_ch = input.shape[1]\n        if self.softmax:\n            if n_pred_ch == 1:\n                
warnings.warn(\"single channel prediction, `softmax=True` ignored.\")\n            else:\n                input = torch.softmax(input, 1)\n\n        if self.other_act is not None:\n            input = self.other_act(input)\n\n        if self.to_onehot_y:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `to_onehot_y=True` ignored.\")\n            else:\n                target = one_hot(target, num_classes=n_pred_ch)\n\n        if not self.include_background:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `include_background=False` ignored.\")\n            else:\n                # If skipping background, removing first channel\n                target = target[:, 1:]\n                input = input[:, 1:]\n\n        if target.shape != input.shape:\n            raise AssertionError(f\"ground truth has different shape ({target.shape}) from input ({input.shape})\")\n\n        device = input.device\n        all_f = []\n        for i in range(input.shape[1]):\n            ch_input = input[:, [i]]\n            ch_target = target[:, [i]]\n            pred_dt = self.distance_field(ch_input.detach()).float()\n            target_dt = self.distance_field(ch_target.detach()).float()\n\n            pred_error = (ch_input - ch_target) ** 2\n            distance = pred_dt**self.alpha + target_dt**self.alpha\n\n            running_f = pred_error * distance.to(device)\n            reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()\n            if self.batch:\n                # reducing spatial dimensions and batch\n                reduce_axis = [0] + reduce_axis\n            all_f.append(running_f.mean(dim=reduce_axis, keepdim=True))\n        f = torch.cat(all_f, dim=1)\n        if self.reduction == LossReduction.MEAN.value:\n            f = torch.mean(f)  # the batch and channel average\n        elif self.reduction == LossReduction.SUM.value:\n            f = torch.sum(f)  # sum 
over the batch and channel dims\n        elif self.reduction == LossReduction.NONE.value:\n            # If we are not computing voxelwise loss components at least make sure a none reduction maintains a\n            # broadcastable shape\n            broadcast_shape = list(f.shape[0:2]) + [1] * (len(ch_input.shape) - 2)\n            f = f.view(broadcast_shape)\n        else:\n            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n\n        return f\n\n\nclass LogHausdorffDTLoss(HausdorffDTLoss):\n    \"\"\"\n    Compute the logarithm of the Hausdorff Distance Transform Loss.\n\n    This class computes the logarithm of the Hausdorff Distance Transform Loss, which is based on the distance transform.\n    The logarithm is computed to potentially stabilize and scale the loss values, especially when the original loss\n    values are very small.\n\n    The formula for the loss is given by:\n        log_loss = log(HausdorffDTLoss + 1)\n\n    Inherits from the HausdorffDTLoss class to utilize its distance transform computation.\n    \"\"\"\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Compute the logarithm of the Hausdorff Distance Transform Loss.\n\n        Args:\n            input (torch.Tensor): The shape should be BNHW[D], where N is the number of classes.\n            target (torch.Tensor): The shape should be BNHW[D] or B1HW[D], where N is the number of classes.\n\n        Returns:\n            torch.Tensor: The computed Log Hausdorff Distance Transform Loss for the given input and target.\n\n        Raises:\n            Any exceptions raised by the parent class HausdorffDTLoss.\n        \"\"\"\n        log_loss: torch.Tensor = torch.log(super().forward(input, target) + 1)\n        return log_loss\n"
  },
  {
    "path": "monai/losses/image_dissimilarity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch.nn import functional as F\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.networks.layers import gaussian_1d, separable_filtering\nfrom monai.utils import LossReduction\nfrom monai.utils.module import look_up_option\n\n\ndef make_rectangular_kernel(kernel_size: int) -> torch.Tensor:\n    return torch.ones(kernel_size)\n\n\ndef make_triangular_kernel(kernel_size: int) -> torch.Tensor:\n    fsize = (kernel_size + 1) // 2\n    if fsize % 2 == 0:\n        fsize -= 1\n    f = torch.ones((1, 1, fsize), dtype=torch.float).div(fsize)\n    padding = (kernel_size - fsize) // 2 + fsize // 2\n    return F.conv1d(f, f, padding=padding).reshape(-1)\n\n\ndef make_gaussian_kernel(kernel_size: int) -> torch.Tensor:\n    sigma = torch.tensor(kernel_size / 3.0)\n    kernel = gaussian_1d(sigma=sigma, truncated=kernel_size // 2, approx=\"sampled\", normalize=False) * (\n        2.5066282 * sigma\n    )\n    return kernel[:kernel_size]\n\n\nkernel_dict = {\n    \"rectangular\": make_rectangular_kernel,\n    \"triangular\": make_triangular_kernel,\n    \"gaussian\": make_gaussian_kernel,\n}\n\n\nclass LocalNormalizedCrossCorrelationLoss(_Loss):\n    \"\"\"\n    Local squared zero-normalized cross-correlation.\n\n    The loss is based on a moving kernel/window over the y_true/y_pred,\n    within the window the square of zncc 
is calculated.\n    The kernel can be a rectangular / triangular / gaussian window.\n    The final loss is the averaged loss over all windows.\n\n    Adapted from:\n        https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py\n        DeepReg (https://github.com/DeepRegNet/DeepReg)\n\n    Args:\n        spatial_dims: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3.\n        kernel_size: kernel spatial size, must be odd.\n        kernel_type: {``\"rectangular\"``, ``\"triangular\"``, ``\"gaussian\"``}. Defaults to ``\"rectangular\"``.\n        reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n            Specifies the reduction to apply to the output. Defaults to ``\"mean\"``.\n\n            - ``\"none\"``: no reduction will be applied.\n            - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n            - ``\"sum\"``: the output will be summed.\n        smooth_nr: a small constant added to the numerator to avoid nan.\n        smooth_dr: a small constant added to the denominator to avoid nan.\n\n    Returns:\n        torch.Tensor: The computed loss value. 
The output range is approximately [-1, 0], where:\n            - Values closer to -1 indicate higher correlation (better match)\n            - Values closer to 0 indicate lower correlation (worse match)\n            - This loss should be **minimized** during optimization\n\n    Note:\n        The implementation computes the squared normalized cross-correlation coefficient\n        and then negates it, transforming the correlation maximization problem into a\n        loss minimization problem suitable for standard PyTorch optimizers.\n\n        Interpretation:\n            - Loss ≈ -1: Perfect correlation between images\n            - Loss ≈ 0: No correlation between images\n            - Lower (more negative) values indicate better alignment\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int = 3,\n        kernel_size: int = 3,\n        kernel_type: str = \"rectangular\",\n        reduction: LossReduction | str = LossReduction.MEAN,\n        smooth_nr: float = 0.0,\n        smooth_dr: float = 1e-5,\n    ) -> None:\n        super().__init__(reduction=LossReduction(reduction).value)\n\n        self.ndim = spatial_dims\n        if self.ndim not in {1, 2, 3}:\n            raise ValueError(f\"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported\")\n\n        self.kernel_size = kernel_size\n        if self.kernel_size % 2 == 0:\n            raise ValueError(f\"kernel_size must be odd, got {self.kernel_size}\")\n\n        _kernel = look_up_option(kernel_type, kernel_dict)\n        self.kernel = _kernel(self.kernel_size)\n        self.kernel.require_grads = False\n        self.kernel_vol = self.get_kernel_vol()\n\n        self.smooth_nr = float(smooth_nr)\n        self.smooth_dr = float(smooth_dr)\n\n    def get_kernel_vol(self):\n        vol = self.kernel\n        for _ in range(self.ndim - 1):\n            vol = torch.matmul(vol.unsqueeze(-1), self.kernel.unsqueeze(0))\n        return torch.sum(vol)\n\n    def forward(self, 
pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            pred: the shape should be BNH[WD].\n            target: the shape should be BNH[WD].\n        Raises:\n            ValueError: When ``self.reduction`` is not one of [\"mean\", \"sum\", \"none\"].\n        \"\"\"\n        if pred.ndim - 2 != self.ndim:\n            raise ValueError(f\"expecting pred with {self.ndim} spatial dimensions, got pred of shape {pred.shape}\")\n        if target.shape != pred.shape:\n            raise ValueError(f\"ground truth has differing shape ({target.shape}) from pred ({pred.shape})\")\n\n        t2, p2, tp = target * target, pred * pred, target * pred\n        kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred)\n        kernels = [kernel] * self.ndim\n        # sum over kernel\n        t_sum = separable_filtering(target, kernels=kernels)\n        p_sum = separable_filtering(pred, kernels=kernels)\n        t2_sum = separable_filtering(t2, kernels=kernels)\n        p2_sum = separable_filtering(p2, kernels=kernels)\n        tp_sum = separable_filtering(tp, kernels=kernels)\n\n        # average over kernel\n        t_avg = t_sum / kernel_vol\n        p_avg = p_sum / kernel_vol\n\n        # normalized cross correlation between t and p\n        # sum[(t - mean[t]) * (p - mean[p])] / std[t] / std[p]\n        # denoted by num / denom\n        # assume we sum over N values\n        # num = sum[t * p - mean[t] * p - t * mean[p] + mean[t] * mean[p]]\n        #     = sum[t*p] - sum[t] * sum[p] / N * 2 + sum[t] * sum[p] / N\n        #     = sum[t*p] - sum[t] * sum[p] / N\n        #     = sum[t*p] - sum[t] * mean[p] = cross\n        # the following is actually squared ncc\n        cross = tp_sum - p_avg * t_sum\n        t_var = torch.max(\n            t2_sum - t_avg * t_sum, torch.as_tensor(self.smooth_dr, dtype=t2_sum.dtype, device=t2_sum.device)\n        )\n        p_var = torch.max(\n            p2_sum - p_avg * p_sum, 
torch.as_tensor(self.smooth_dr, dtype=p2_sum.dtype, device=p2_sum.device)\n        )\n        ncc: torch.Tensor = (cross * cross + self.smooth_nr) / (t_var * p_var)\n\n        if self.reduction == LossReduction.SUM.value:\n            return torch.sum(ncc).neg()  # sum over the batch, channel and spatial ndims\n        if self.reduction == LossReduction.NONE.value:\n            return ncc.neg()\n        if self.reduction == LossReduction.MEAN.value:\n            return torch.mean(ncc).neg()  # average over the batch, channel and spatial ndims\n        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n\n\nclass GlobalMutualInformationLoss(_Loss):\n    \"\"\"\n    Differentiable global mutual information loss via Parzen windowing method.\n\n    Reference:\n        https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_type: str = \"gaussian\",\n        num_bins: int = 23,\n        sigma_ratio: float = 0.5,\n        reduction: LossReduction | str = LossReduction.MEAN,\n        smooth_nr: float = 1e-7,\n        smooth_dr: float = 1e-7,\n    ) -> None:\n        \"\"\"\n        Args:\n            kernel_type: {``\"gaussian\"``, ``\"b-spline\"``}\n                ``\"gaussian\"``: adapted from DeepReg\n                Reference: https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1.\n                ``\"b-spline\"``: based on the method of Mattes et al [1,2] and adapted from ITK\n                References:\n                  [1] \"Nonrigid multimodality image registration\"\n                      D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank\n                      Medical Imaging 2001: Image Processing, 2001, pp. 1609-1620.\n                  [2] \"PET-CT Image Registration in the Chest Using Free-form Deformations\"\n                      D. Mattes, D. R. 
Haynor, H. Vesselle, T. Lewellen and W. Eubank\n                      IEEE Transactions in Medical Imaging. Vol.22, No.1,\n                      January 2003. pp.120-128.\n\n            num_bins: number of bins for intensity\n            sigma_ratio: a hyper param for gaussian function\n            reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n                Specifies the reduction to apply to the output. Defaults to ``\"mean\"``.\n\n                - ``\"none\"``: no reduction will be applied.\n                - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n                - ``\"sum\"``: the output will be summed.\n            smooth_nr: a small constant added to the numerator to avoid nan.\n            smooth_dr: a small constant added to the denominator to avoid nan.\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n        if num_bins <= 0:\n            raise ValueError(f\"num_bins must > 0, got {num_bins}\")\n        bin_centers = torch.linspace(0.0, 1.0, num_bins)  # (num_bins,)\n        sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio\n        self.kernel_type = look_up_option(kernel_type, [\"gaussian\", \"b-spline\"])\n        self.num_bins = num_bins\n        self.kernel_type = kernel_type\n        if self.kernel_type == \"gaussian\":\n            self.preterm = 1 / (2 * sigma**2)\n            self.bin_centers = bin_centers[None, None, ...]\n        self.smooth_nr = float(smooth_nr)\n        self.smooth_dr = float(smooth_dr)\n\n    def parzen_windowing(\n        self, pred: torch.Tensor, target: torch.Tensor\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        if self.kernel_type == \"gaussian\":\n            pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)\n            target_weight, target_probability = self.parzen_windowing_gaussian(target)\n        elif self.kernel_type == \"b-spline\":\n       
     # a third order BSpline kernel is used for the pred image intensity PDF.\n            pred_weight, pred_probability = self.parzen_windowing_b_spline(pred, order=3)\n            # a zero order (box car) BSpline kernel is used for the target image intensity PDF.\n            target_weight, target_probability = self.parzen_windowing_b_spline(target, order=0)\n        else:\n            raise ValueError\n        return pred_weight, pred_probability, target_weight, target_probability\n\n    def parzen_windowing_b_spline(self, img: torch.Tensor, order: int) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Parzen windowing with b-spline kernel (adapted from ITK)\n\n        Args:\n            img: the shape should be B[NDHW].\n            order: int.\n        \"\"\"\n\n        # Compute binsize for the histograms.\n        #\n        # The binsize for the image intensities needs to be adjusted so that\n        # we can avoid dealing with boundary conditions using the cubic\n        # spline as the Parzen window.  We do this by increasing the size\n        # of the bins so that the joint histogram becomes \"padded\" at the\n        # borders. 
Because we are changing the binsize,\n        # we also need to shift the minimum by the padded amount in order to\n        # avoid minimum values filling in our padded region.\n        #\n        # Note that there can still be non-zero bin values in the padded region,\n        # it's just that these bins will never be a central bin for the Parzen\n        # window.\n        _max, _min = torch.max(img), torch.min(img)\n        padding = 2\n        bin_size = (_max - _min) / (self.num_bins - 2 * padding)\n        norm_min = torch.div(_min, bin_size) - padding\n\n        # assign bin/window index to each voxel\n        window_term = torch.div(img, bin_size) - norm_min  # B[NDHW]\n        # make sure the extreme values are in valid (non-padded) bins\n        window_term = torch.clamp(window_term, padding, self.num_bins - padding - 1)  # B[NDHW]\n        window_term = window_term.reshape(window_term.shape[0], -1, 1)  # (batch, num_sample, 1)\n        bins = torch.arange(self.num_bins, device=window_term.device).reshape(1, 1, -1)  # (1, 1, num_bins)\n        sample_bin_matrix = torch.abs(bins - window_term)  # (batch, num_sample, num_bins)\n\n        # b-spleen kernel\n        # (4 - 6 * abs ** 2 + 3 * abs ** 3) / 6 when 0 <= abs < 1\n        # (2 - abs) ** 3 / 6 when 1 <= abs < 2\n        weight = torch.zeros_like(sample_bin_matrix, dtype=torch.float)  # (batch, num_sample, num_bins)\n        if order == 0:\n            weight = weight + (sample_bin_matrix < 0.5) + (sample_bin_matrix == 0.5) * 0.5\n        elif order == 3:\n            weight = weight + (4 - 6 * sample_bin_matrix**2 + 3 * sample_bin_matrix**3) * (sample_bin_matrix < 1) / 6\n            weight = weight + (2 - sample_bin_matrix) ** 3 * (sample_bin_matrix >= 1) * (sample_bin_matrix < 2) / 6\n        else:\n            raise ValueError(f\"Do not support b-spline {order}-order parzen windowing\")\n\n        weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # (batch, num_sample, num_bins)\n        
probability = torch.mean(weight, dim=-2, keepdim=True)  # (batch, 1, num_bins)\n        return weight, probability\n\n    def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Parzen windowing with gaussian kernel (adapted from DeepReg implementation)\n        Note: the input is expected to range between 0 and 1\n        Args:\n            img: the shape should be B[NDHW].\n        \"\"\"\n        img = torch.clamp(img, 0, 1)\n        img = img.reshape(img.shape[0], -1, 1)  # (batch, num_sample, 1)\n        weight = torch.exp(\n            -self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2\n        )  # (batch, num_sample, num_bin)\n        weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # (batch, num_sample, num_bin)\n        probability = torch.mean(weight, dim=-2, keepdim=True)  # (batch, 1, num_bin)\n        return weight, probability\n\n    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            pred: the shape should be B[NDHW].\n            target: the shape should be same as the pred shape.\n        Raises:\n            ValueError: When ``self.reduction`` is not one of [\"mean\", \"sum\", \"none\"].\n        \"\"\"\n        if target.shape != pred.shape:\n            raise ValueError(f\"ground truth has differing shape ({target.shape}) from pred ({pred.shape})\")\n        wa, pa, wb, pb = self.parzen_windowing(pred, target)  # (batch, num_sample, num_bin), (batch, 1, num_bin)\n\n        pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1])  # (batch, num_bins, num_bins)\n        papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa))  # (batch, num_bins, num_bins)\n        mi = torch.sum(\n            pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)\n        )  # (batch)\n\n        if self.reduction == LossReduction.SUM.value:\n            return 
torch.sum(mi).neg()  # sum over the batch and channel ndims\n        if self.reduction == LossReduction.NONE.value:\n            return mi.neg()\n        if self.reduction == LossReduction.MEAN.value:\n            return torch.mean(mi).neg()  # average over the batch and channel ndims\n        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n"
  },
  {
    "path": "monai/losses/multi_scale.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.networks.layers import gaussian_1d, separable_filtering\nfrom monai.utils import LossReduction\n\n\ndef make_gaussian_kernel(sigma: int) -> torch.Tensor:\n    if sigma <= 0:\n        raise ValueError(f\"expecting positive sigma, got sigma={sigma}\")\n    return gaussian_1d(sigma=torch.tensor(sigma), truncated=3, approx=\"sampled\", normalize=False)\n\n\ndef make_cauchy_kernel(sigma: int) -> torch.Tensor:\n    if sigma <= 0:\n        raise ValueError(f\"expecting positive sigma, got sigma={sigma}\")\n    tail = int(sigma * 5)\n    k = torch.tensor([((x / sigma) ** 2 + 1) for x in range(-tail, tail + 1)])\n    k = torch.reciprocal(k)\n    k = k / torch.sum(k)\n    return k\n\n\nkernel_fn_dict = {\"gaussian\": make_gaussian_kernel, \"cauchy\": make_cauchy_kernel}\n\n\nclass MultiScaleLoss(_Loss):\n    \"\"\"\n    This is a wrapper class.\n    It smooths the input and target at different scales before passing them into the wrapped loss function.\n\n    Adapted from:\n        DeepReg (https://github.com/DeepRegNet/DeepReg)\n    \"\"\"\n\n    def __init__(\n        self,\n        loss: _Loss,\n        scales: list | None = None,\n        kernel: str = \"gaussian\",\n        reduction: LossReduction | str = LossReduction.MEAN,\n    ) -> None:\n        \"\"\"\n        Args:\n 
           loss: loss function to be wrapped\n            scales: list of scalars or None, if None, do not apply any scaling.\n            kernel: gaussian or cauchy.\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n        if kernel not in kernel_fn_dict:\n            raise ValueError(f\"got unsupported kernel type: {kernel}\", \"only support gaussian and cauchy\")\n        self.kernel_fn = kernel_fn_dict[kernel]\n        self.loss = loss\n        self.scales = scales\n\n    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:\n        if self.scales is None:\n            loss: torch.Tensor = self.loss(y_pred, y_true)\n        else:\n            loss_list = []\n            for s in self.scales:\n                if s == 0:\n                    # no smoothing\n                    loss_list.append(self.loss(y_pred, y_true))\n                else:\n                    loss_list.append(\n                        self.loss(\n                            separable_filtering(y_pred, [self.kernel_fn(s).to(y_pred)] * (y_true.ndim - 2)),\n                            separable_filtering(y_true, [self.kernel_fn(s).to(y_pred)] * (y_true.ndim - 2)),\n                        )\n                    )\n            loss = torch.stack(loss_list, dim=0)\n\n        if self.reduction == LossReduction.MEAN.value:\n            loss = torch.mean(loss)  # the batch and channel average\n        elif self.reduction == LossReduction.SUM.value:\n            loss = torch.sum(loss)  # sum over the batch and channel dims\n        elif self.reduction != LossReduction.NONE.value:\n            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n\n        return loss\n"
  },
  {
    "path": "monai/losses/nacl_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.networks.layers import GaussianFilter, MeanFilter\n\n\nclass NACLLoss(_Loss):\n    \"\"\"\n    Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation.\n    NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions\n    to match a soft class proportion of surrounding pixel.\n\n    Murugesan, Balamurali, et al.\n    \"Trust your neighbours: Penalty-based constraints for model calibration.\"\n    International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023.\n    https://arxiv.org/abs/2303.06268\n\n    Murugesan, Balamurali, et al.\n    \"Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints.\"\n    https://arxiv.org/abs/2401.14487\n    \"\"\"\n\n    def __init__(\n        self,\n        classes: int,\n        dim: int,\n        kernel_size: int = 3,\n        kernel_ops: str = \"mean\",\n        distance_type: str = \"l1\",\n        alpha: float = 0.1,\n        sigma: float = 1.0,\n    ) -> None:\n        \"\"\"\n        Args:\n            classes: number of classes\n            dim: dimension of data (supports 2d and 
3d)\n            kernel_size: size of the spatial kernel\n            distance_type: l1/l2 distance between spatial kernel and predicted logits\n            alpha: weightage between cross entropy and logit constraint\n            sigma: sigma of gaussian\n        \"\"\"\n\n        super().__init__()\n\n        if kernel_ops not in [\"mean\", \"gaussian\"]:\n            raise ValueError(\"Kernel ops must be either mean or gaussian\")\n\n        if dim not in [2, 3]:\n            raise ValueError(f\"Support 2d and 3d, got dim={dim}.\")\n\n        if distance_type not in [\"l1\", \"l2\"]:\n            raise ValueError(f\"Distance type must be either L1 or L2, got {distance_type}\")\n\n        self.nc = classes\n        self.dim = dim\n        self.cross_entropy = nn.CrossEntropyLoss()\n        self.distance_type = distance_type\n        self.alpha = alpha\n        self.ks = kernel_size\n        self.svls_layer: Any\n\n        if kernel_ops == \"mean\":\n            self.svls_layer = MeanFilter(spatial_dims=dim, size=kernel_size)\n            self.svls_layer.filter = self.svls_layer.filter / (kernel_size**dim)\n        if kernel_ops == \"gaussian\":\n            self.svls_layer = GaussianFilter(spatial_dims=dim, sigma=sigma)\n\n    def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Converts the mask to one hot representation and is smoothened with the selected spatial filter.\n\n        Args:\n            mask: the shape should be BH[WD].\n\n        Returns:\n            torch.Tensor: the shape would be BNH[WD], N being number of classes.\n        \"\"\"\n        rmask: torch.Tensor\n\n        if self.dim == 2:\n            oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 3, 1, 2).contiguous().float()\n            rmask = self.svls_layer(oh_labels)\n\n        if self.dim == 3:\n            oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 4, 1, 2, 3).contiguous().float()\n       
     rmask = self.svls_layer(oh_labels)\n\n        return rmask\n\n    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Computes standard cross-entropy loss and constraints it neighbor aware logit penalty.\n\n        Args:\n            inputs: the shape should be BNH[WD], where N is the number of classes.\n            targets: the shape should be BH[WD].\n\n        Returns:\n            torch.Tensor: value of the loss.\n\n        Example:\n            >>> import torch\n            >>> from monai.losses import NACLLoss\n            >>> B, N, H, W = 8, 3, 64, 64\n            >>> input = torch.rand(B, N, H, W)\n            >>> target = torch.randint(0, N, (B, H, W))\n            >>> criterion = NACLLoss(classes = N, dim = 2)\n            >>> loss = criterion(input, target)\n        \"\"\"\n\n        loss_ce = self.cross_entropy(inputs, targets)\n\n        utargets = self.get_constr_target(targets)\n\n        if self.distance_type == \"l1\":\n            loss_conf = utargets.sub(inputs).abs_().mean()\n        elif self.distance_type == \"l2\":\n            loss_conf = utargets.sub(inputs).pow_(2).abs_().mean()\n\n        loss: torch.Tensor = loss_ce + self.alpha * loss_conf\n\n        return loss\n"
  },
  {
    "path": "monai/losses/perceptual.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.utils import optional_import\nfrom monai.utils.enums import StrEnum\n\nLPIPS, _ = optional_import(\"lpips\", name=\"LPIPS\")\ntorchvision, _ = optional_import(\"torchvision\")\n\n\nclass PercetualNetworkType(StrEnum):\n    alex = \"alex\"\n    vgg = \"vgg\"\n    squeeze = \"squeeze\"\n    radimagenet_resnet50 = \"radimagenet_resnet50\"\n    medicalnet_resnet10_23datasets = \"medicalnet_resnet10_23datasets\"\n    medicalnet_resnet50_23datasets = \"medicalnet_resnet50_23datasets\"\n    resnet50 = \"resnet50\"\n\n\nclass PerceptualLoss(nn.Module):\n    \"\"\"\n    Perceptual loss using features from pretrained deep neural networks trained. The function supports networks\n    pretrained on: ImageNet that use the LPIPS approach from Zhang, et al. \"The unreasonable effectiveness of deep\n    features as a perceptual metric.\" https://arxiv.org/abs/1801.03924 ; RadImagenet from Mei, et al. \"RadImageNet: An\n    Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning\"\n    https://pubs.rsna.org/doi/full/10.1148/ryai.210315 ; MedicalNet from Chen et al. 
\"Med3D: Transfer Learning for\n    3D Medical Image Analysis\" https://arxiv.org/abs/1904.00625 ;\n    and ResNet50 from Torchvision: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html .\n\n    The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices from all\n    three axes and average. The full 3D approach uses a 3D network to calculate the perceptual loss.\n    MedicalNet networks are only compatible with 3D inputs and support channel-wise loss.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        network_type: type of network for perceptual loss. One of:\n            - \"alex\"\n            - \"vgg\"\n            - \"squeeze\"\n            - \"radimagenet_resnet50\"\n            - \"medicalnet_resnet10_23datasets\"\n            - \"medicalnet_resnet50_23datasets\"\n            - \"resnet50\"\n        is_fake_3d: if True use 2.5D approach for a 3D perceptual loss.\n        fake_3d_ratio: ratio of how many slices per axis are used in the 2.5D approach.\n        cache_dir: path to cache directory to save the pretrained network weights.\n        pretrained: whether to load pretrained weights. This argument only works when using networks from\n            LIPIS or Torchvision. Defaults to ``True``.\n        pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded\n            via using this argument. This argument only works when ``\"network_type\"`` is \"resnet50\".\n            Defaults to `None`.\n        pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to\n            extract the expected state dict. This argument only works when ``\"network_type\"`` is \"resnet50\".\n            Defaults to `None`.\n        channel_wise: if True, the loss is returned per channel. 
Otherwise the loss is averaged over the channels.\n            Defaults to ``False``.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        network_type: str = PercetualNetworkType.alex,\n        is_fake_3d: bool = True,\n        fake_3d_ratio: float = 0.5,\n        cache_dir: str | None = None,\n        pretrained: bool = True,\n        pretrained_path: str | None = None,\n        pretrained_state_dict_key: str | None = None,\n        channel_wise: bool = False,\n    ):\n        super().__init__()\n\n        if spatial_dims not in [2, 3]:\n            raise NotImplementedError(\"Perceptual loss is implemented only in 2D and 3D.\")\n\n        if (spatial_dims == 2 or is_fake_3d) and \"medicalnet_\" in network_type:\n            raise ValueError(\n                \"MedicalNet networks are only compatible with ``spatial_dims=3``.\"\n                \"Argument is_fake_3d must be set to False.\"\n            )\n\n        if channel_wise and \"medicalnet_\" not in network_type:\n            raise ValueError(\"Channel-wise loss is only compatible with MedicalNet networks.\")\n\n        if network_type.lower() not in list(PercetualNetworkType):\n            raise ValueError(\n                f\"Unrecognised criterion entered for Perceptual Loss. 
Must be one in: {', '.join(PercetualNetworkType)}\"\n            )\n        if cache_dir:\n            torch.hub.set_dir(cache_dir)\n            # raise a warning that this may change the default cache dir for all torch.hub calls\n            warnings.warn(\n                f\"Setting cache_dir to {cache_dir}, this may change the default cache dir for all torch.hub calls.\"\n            )\n\n        self.spatial_dims = spatial_dims\n        self.perceptual_function: nn.Module\n        if spatial_dims == 3 and is_fake_3d is False:\n            self.perceptual_function = MedicalNetPerceptualSimilarity(\n                net=network_type, verbose=False, channel_wise=channel_wise\n            )\n        elif \"radimagenet_\" in network_type:\n            self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)\n        elif network_type == \"resnet50\":\n            self.perceptual_function = TorchvisionModelPerceptualSimilarity(\n                net=network_type,\n                pretrained=pretrained,\n                pretrained_path=pretrained_path,\n                pretrained_state_dict_key=pretrained_state_dict_key,\n            )\n        else:\n            self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)\n        self.is_fake_3d = is_fake_3d\n        self.fake_3d_ratio = fake_3d_ratio\n        self.channel_wise = channel_wise\n\n    def _calculate_axis_loss(self, input: torch.Tensor, target: torch.Tensor, spatial_axis: int) -> torch.Tensor:\n        \"\"\"\n        Calculate perceptual loss in one of the axis used in the 2.5D approach. After the slices of one spatial axis\n        is transformed into different instances in the batch, we compute the loss using the 2D approach.\n\n        Args:\n            input: input 5D tensor. BNHWD\n            target: target 5D tensor. 
BNHWD\n            spatial_axis: spatial axis to obtain the 2D slices.\n        \"\"\"\n\n        def batchify_axis(x: torch.Tensor, fake_3d_perm: tuple) -> torch.Tensor:\n            \"\"\"\n            Transform slices from one spatial axis into different instances in the batch.\n            \"\"\"\n            slices = x.float().permute((0,) + fake_3d_perm).contiguous()\n            slices = slices.view(-1, x.shape[fake_3d_perm[1]], x.shape[fake_3d_perm[2]], x.shape[fake_3d_perm[3]])\n\n            return slices\n\n        preserved_axes = [2, 3, 4]\n        preserved_axes.remove(spatial_axis)\n\n        channel_axis = 1\n        input_slices = batchify_axis(x=input, fake_3d_perm=(spatial_axis, channel_axis) + tuple(preserved_axes))\n        indices = torch.randperm(input_slices.shape[0])[: int(input_slices.shape[0] * self.fake_3d_ratio)].to(\n            input_slices.device\n        )\n        input_slices = torch.index_select(input_slices, dim=0, index=indices)\n        target_slices = batchify_axis(x=target, fake_3d_perm=(spatial_axis, channel_axis) + tuple(preserved_axes))\n        target_slices = torch.index_select(target_slices, dim=0, index=indices)\n\n        axis_loss = torch.mean(self.perceptual_function(input_slices, target_slices))\n\n        return axis_loss\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be BNHW[D].\n            target: the shape should be BNHW[D].\n        \"\"\"\n        if target.shape != input.shape:\n            raise ValueError(f\"ground truth has differing shape ({target.shape}) from input ({input.shape})\")\n\n        if self.spatial_dims == 3 and self.is_fake_3d:\n            # Compute 2.5D approach\n            loss_sagittal = self._calculate_axis_loss(input, target, spatial_axis=2)\n            loss_coronal = self._calculate_axis_loss(input, target, spatial_axis=3)\n            loss_axial = 
self._calculate_axis_loss(input, target, spatial_axis=4)\n            loss = loss_sagittal + loss_axial + loss_coronal\n        else:\n            # 2D and real 3D cases\n            loss = self.perceptual_function(input, target)\n\n        if self.channel_wise:\n            loss = torch.mean(loss.squeeze(), dim=0)\n        else:\n            loss = torch.mean(loss)\n\n        return loss\n\n\nclass MedicalNetPerceptualSimilarity(nn.Module):\n    \"\"\"\n    Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. \"Med3D: Transfer\n    Learning for 3D Medical Image Analysis\". This class uses torch Hub to download the networks from\n    \"Warvito/MedicalNet-models\".\n\n    Args:\n        net: {``\"medicalnet_resnet10_23datasets\"``, ``\"medicalnet_resnet50_23datasets\"``}\n            Specifies the network architecture to use. Defaults to ``\"medicalnet_resnet10_23datasets\"``.\n        verbose: if false, mute messages from torch Hub load function.\n        channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.\n                Defaults to ``False``.\n    \"\"\"\n\n    def __init__(\n        self, net: str = \"medicalnet_resnet10_23datasets\", verbose: bool = False, channel_wise: bool = False\n    ) -> None:\n        super().__init__()\n        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True\n        self.model = torch.hub.load(\"warvito/MedicalNet-models\", model=net, verbose=verbose, trust_repo=True)\n        self.eval()\n\n        self.channel_wise = channel_wise\n\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Compute perceptual loss using MedicalNet 3D networks. The input and target tensors are inputted in the\n        pre-trained MedicalNet that is used for feature extraction. 
Then, these extracted features are normalised across\n        the channels. Finally, we compute the difference between the input and target features and calculate the mean\n        value from the spatial dimensions to obtain the perceptual loss.\n\n        Args:\n            input: 3D input tensor with shape BCDHW.\n            target: 3D target tensor with shape BCDHW.\n\n        \"\"\"\n        input = medicalnet_intensity_normalisation(input)\n        target = medicalnet_intensity_normalisation(target)\n\n        # Get model outputs\n        feats_per_ch = 0\n        for ch_idx in range(input.shape[1]):\n            input_channel = input[:, ch_idx, ...].unsqueeze(1)\n            target_channel = target[:, ch_idx, ...].unsqueeze(1)\n\n            if ch_idx == 0:\n                outs_input = self.model.forward(input_channel)\n                outs_target = self.model.forward(target_channel)\n                feats_per_ch = outs_input.shape[1]\n            else:\n                outs_input = torch.cat([outs_input, self.model.forward(input_channel)], dim=1)\n                outs_target = torch.cat([outs_target, self.model.forward(target_channel)], dim=1)\n\n        # Normalise through the channels\n        feats_input = normalize_tensor(outs_input)\n        feats_target = normalize_tensor(outs_target)\n\n        feats_diff: torch.Tensor = (feats_input - feats_target) ** 2\n        if self.channel_wise:\n            results = torch.zeros(\n                feats_diff.shape[0], input.shape[1], feats_diff.shape[2], feats_diff.shape[3], feats_diff.shape[4]\n            )\n            for i in range(input.shape[1]):\n                l_idx = i * feats_per_ch\n                r_idx = (i + 1) * feats_per_ch\n                results[:, i, ...] 
= feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1)\n        else:\n            results = feats_diff.sum(dim=1, keepdim=True)\n\n        results = spatial_average_3d(results, keepdim=True)\n\n        return results\n\n\ndef spatial_average_3d(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:\n    return x.mean([2, 3, 4], keepdim=keepdim)\n\n\ndef normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:\n    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))\n    return x / (norm_factor + eps)\n\n\ndef medicalnet_intensity_normalisation(volume):\n    \"\"\"Based on https://github.com/Tencent/MedicalNet/blob/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b/datasets/brains18.py#L133\"\"\"\n    mean = volume.mean()\n    std = volume.std()\n    return (volume - mean) / std\n\n\nclass RadImageNetPerceptualSimilarity(nn.Module):\n    \"\"\"\n    Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et\n    al. \"RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning\"). This class\n    uses torch Hub to download the networks from \"Warvito/radimagenet-models\".\n\n    Args:\n        net: {``\"radimagenet_resnet50\"``}\n            Specifies the network architecture to use. Defaults to ``\"radimagenet_resnet50\"``.\n        verbose: if false, mute messages from torch Hub load function.\n    \"\"\"\n\n    def __init__(self, net: str = \"radimagenet_resnet50\", verbose: bool = False) -> None:\n        super().__init__()\n        self.model = torch.hub.load(\"Warvito/radimagenet-models\", model=net, verbose=verbose, trust_repo=True)\n        self.eval()\n\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        We expect that the input is normalised between [0, 1]. 
Given the preprocessing performed during the training at\n        https://github.com/BMEII-AI/RadImageNet, we make sure that the input and target have 3 channels, reorder it from\n         'RGB' to 'BGR', and then remove the mean components of each input data channel. The outputs are normalised\n        across the channels, and we obtain the mean from the spatial dimensions (similar approach to the lpips package).\n        \"\"\"\n        # If input has just 1 channel, repeat channel to have 3 channels\n        if input.shape[1] == 1 and target.shape[1] == 1:\n            input = input.repeat(1, 3, 1, 1)\n            target = target.repeat(1, 3, 1, 1)\n\n        # Change order from 'RGB' to 'BGR'\n        input = input[:, [2, 1, 0], ...]\n        target = target[:, [2, 1, 0], ...]\n\n        # Subtract mean used during training\n        input = subtract_mean(input)\n        target = subtract_mean(target)\n\n        # Get model outputs\n        outs_input = self.model.forward(input)\n        outs_target = self.model.forward(target)\n\n        # Normalise through the channels\n        feats_input = normalize_tensor(outs_input)\n        feats_target = normalize_tensor(outs_target)\n\n        results: torch.Tensor = (feats_input - feats_target) ** 2\n        results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)\n\n        return results\n\n\nclass TorchvisionModelPerceptualSimilarity(nn.Module):\n    \"\"\"\n    Component to perform the perceptual evaluation with TorchVision models.\n    Currently, only ResNet50 is supported. The network structure is based on:\n    https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html\n\n    Args:\n        net: {``\"resnet50\"``}\n            Specifies the network architecture to use. Defaults to ``\"resnet50\"``.\n        pretrained: whether to load pretrained weights. 
Defaults to `True`.\n        pretrained_path: if `pretrained` is `True`, users can specify a weights file to be loaded\n            via using this argument. Defaults to `None`.\n        pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to\n            extract the expected state dict. Defaults to `None`.\n    \"\"\"\n\n    def __init__(\n        self,\n        net: str = \"resnet50\",\n        pretrained: bool = True,\n        pretrained_path: str | None = None,\n        pretrained_state_dict_key: str | None = None,\n    ) -> None:\n        super().__init__()\n        supported_networks = [\"resnet50\"]\n        if net not in supported_networks:\n            raise NotImplementedError(\n                f\"'net' {net} is not supported, please select a network from {supported_networks}.\"\n            )\n\n        if pretrained_path is None:\n            network = torchvision.models.resnet50(\n                weights=torchvision.models.ResNet50_Weights.DEFAULT if pretrained else None\n            )\n        else:\n            network = torchvision.models.resnet50(weights=None)\n            if pretrained is True:\n                state_dict = torch.load(pretrained_path, weights_only=True)\n                if pretrained_state_dict_key is not None:\n                    state_dict = state_dict[pretrained_state_dict_key]\n                network.load_state_dict(state_dict)\n        self.final_layer = \"layer4.2.relu_2\"\n        self.model = torchvision.models.feature_extraction.create_feature_extractor(network, [self.final_layer])\n        self.eval()\n\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        We expect that the input is normalised between [0, 1]. 
Given the preprocessing performed during the training at\n        https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights,\n        we make sure that the input and target have 3 channels, and then do Z-Score normalization.\n        The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar\n        approach to the lpips package).\n        \"\"\"\n        # If input has just 1 channel, repeat channel to have 3 channels\n        if input.shape[1] == 1 and target.shape[1] == 1:\n            input = input.repeat(1, 3, 1, 1)\n            target = target.repeat(1, 3, 1, 1)\n\n        # Input normalization\n        input = torchvision_zscore_norm(input)\n        target = torchvision_zscore_norm(target)\n\n        # Get model outputs\n        outs_input = self.model.forward(input)[self.final_layer]\n        outs_target = self.model.forward(target)[self.final_layer]\n\n        # Normalise through the channels\n        feats_input = normalize_tensor(outs_input)\n        feats_target = normalize_tensor(outs_target)\n\n        results: torch.Tensor = (feats_input - feats_target) ** 2\n        results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)\n\n        return results\n\n\ndef spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:\n    return x.mean([2, 3], keepdim=keepdim)\n\n\ndef torchvision_zscore_norm(x: torch.Tensor) -> torch.Tensor:\n    mean = [0.485, 0.456, 0.406]\n    std = [0.229, 0.224, 0.225]\n    x[:, 0, :, :] = (x[:, 0, :, :] - mean[0]) / std[0]\n    x[:, 1, :, :] = (x[:, 1, :, :] - mean[1]) / std[1]\n    x[:, 2, :, :] = (x[:, 2, :, :] - mean[2]) / std[2]\n    return x\n\n\ndef subtract_mean(x: torch.Tensor) -> torch.Tensor:\n    mean = [0.406, 0.456, 0.485]\n    x[:, 0, :, :] -= mean[0]\n    x[:, 1, :, :] -= mean[1]\n    x[:, 2, :, :] -= mean[2]\n    return x\n"
  },
  {
    "path": "monai/losses/spatial_mask.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport inspect\nimport warnings\nfrom collections.abc import Callable\nfrom typing import Any\n\nimport torch\nfrom torch.nn.modules.loss import _Loss\n\n__all__ = [\"MaskedLoss\"]\n\n\nclass MaskedLoss(_Loss):\n    \"\"\"\n    This is a wrapper class for the loss functions.  It allows for additional\n    weighting masks to be applied to both input and target.\n\n    See Also:\n        - :py:class:`monai.losses.MaskedDiceLoss`\n    \"\"\"\n\n    def __init__(\n        self, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | _Loss, *loss_args: Any, **loss_kwargs: Any\n    ) -> None:\n        \"\"\"\n        Args:\n            loss: loss function to be wrapped, this could be a loss class or an instance of a loss class.\n            loss_args: arguments to the loss function's constructor if `loss` is a class.\n            loss_kwargs: keyword arguments to the loss function's constructor if `loss` is a class.\n        \"\"\"\n        super().__init__()\n        self.loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = (\n            loss(*loss_args, **loss_kwargs) if inspect.isclass(loss) else loss\n        )\n        if not callable(self.loss):\n            raise ValueError(\"The loss function is not callable.\")\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> 
torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be BNH[WD].\n            target: the shape should be BNH[WD].\n            mask: the shape should be B1H[WD] or 11H[WD].\n        \"\"\"\n        if mask is None:\n            warnings.warn(\"No mask value specified for the MaskedLoss.\")\n            return self.loss(input, target)\n\n        if input.dim() != mask.dim():\n            warnings.warn(f\"Dim of input ({input.shape}) is different from mask ({mask.shape}).\")\n        if input.shape[0] != mask.shape[0] and mask.shape[0] != 1:\n            raise ValueError(f\"Batch size of mask ({mask.shape}) must be one or equal to input ({input.shape}).\")\n        if target.dim() > 1:\n            if mask.shape[1] != 1:\n                raise ValueError(f\"Mask ({mask.shape}) must have only one channel.\")\n            if input.shape[2:] != mask.shape[2:]:\n                warnings.warn(f\"Spatial size of input ({input.shape}) is different from mask ({mask.shape}).\")\n        return self.loss(input * mask, target * mask)\n"
  },
  {
    "path": "monai/losses/spectral_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.fft import fftn\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.utils import LossReduction\n\n\nclass JukeboxLoss(_Loss):\n    \"\"\"\n    Calculate spectral component based on the magnitude of Fast Fourier Transform (FFT).\n\n    Based on:\n        Dhariwal, et al. 'Jukebox: A generative model for music.' https://arxiv.org/abs/2005.00341\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        fft_signal_size: signal size in the transformed dimensions. See torch.fft.fftn() for more information.\n        fft_norm: {``\"forward\"``, ``\"backward\"``, ``\"ortho\"``} Specifies the normalization mode in the fft. See\n            torch.fft.fftn() for more information.\n\n        reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n            Specifies the reduction to apply to the output. 
Defaults to ``\"mean\"``.\n\n            - ``\"none\"``: no reduction will be applied.\n            - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n            - ``\"sum\"``: the output will be summed.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        fft_signal_size: tuple[int] | None = None,\n        fft_norm: str = \"ortho\",\n        reduction: LossReduction | str = LossReduction.MEAN,\n    ) -> None:\n        super().__init__(reduction=LossReduction(reduction).value)\n\n        self.spatial_dims = spatial_dims\n        self.fft_signal_size = fft_signal_size\n        self.fft_dim = tuple(range(1, spatial_dims + 2))\n        self.fft_norm = fft_norm\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        input_amplitude = self._get_fft_amplitude(target)\n        target_amplitude = self._get_fft_amplitude(input)\n\n        # Compute distance between amplitude of frequency components\n        # See Section 3.3 from https://arxiv.org/abs/2005.00341\n        loss = F.mse_loss(target_amplitude, input_amplitude, reduction=\"none\")\n\n        if self.reduction == LossReduction.MEAN.value:\n            loss = loss.mean()\n        elif self.reduction == LossReduction.SUM.value:\n            loss = loss.sum()\n        elif self.reduction == LossReduction.NONE.value:\n            pass\n\n        return loss\n\n    def _get_fft_amplitude(self, images: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Calculate the amplitude of the fourier transformations representation of the images\n\n        Args:\n            images: Images that are to undergo fftn\n\n        Returns:\n            fourier transformation amplitude\n        \"\"\"\n        img_fft = fftn(images, s=self.fft_signal_size, dim=self.fft_dim, norm=self.fft_norm)\n\n        amplitude = torch.sqrt(torch.real(img_fft) ** 2 + torch.imag(img_fft) ** 2)\n\n        return amplitude\n"
  },
  {
    "path": "monai/losses/ssim_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.metrics.regression import KernelType, SSIMMetric\nfrom monai.utils import LossReduction, ensure_tuple_rep\n\n\nclass SSIMLoss(_Loss):\n    \"\"\"\n    Compute the loss function based on the Structural Similarity Index Measure (SSIM) Metric.\n\n    For more info, visit\n        https://vicuesoft.com/glossary/term/ssim-ms-ssim/\n\n    SSIM reference paper:\n        Wang, Zhou, et al. \"Image quality assessment: from error visibility to structural\n        similarity.\" IEEE transactions on image processing 13.4 (2004): 600-612.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        data_range: float = 1.0,\n        kernel_type: KernelType | str = KernelType.GAUSSIAN,\n        win_size: int | Sequence[int] = 11,\n        kernel_sigma: float | Sequence[float] = 1.5,\n        k1: float = 0.01,\n        k2: float = 0.03,\n        reduction: LossReduction | str = LossReduction.MEAN,\n    ):\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions of the input images.\n            data_range: value range of input images. 
(usually 1.0 or 255)\n            kernel_type: type of kernel, can be \"gaussian\" or \"uniform\".\n            win_size: window size of kernel\n            kernel_sigma: standard deviation for Gaussian kernel.\n            k1: stability constant used in the luminance denominator\n            k2: stability constant used in the contrast denominator\n            reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n                Specifies the reduction to apply to the output. Defaults to ``\"mean\"``.\n                - ``\"none\"``: no reduction will be applied.\n                - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n                - ``\"sum\"``: the output will be summed.\n\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n        self.spatial_dims = spatial_dims\n        self._data_range = data_range\n        self.kernel_type = kernel_type\n\n        if not isinstance(win_size, Sequence):\n            win_size = ensure_tuple_rep(win_size, spatial_dims)\n        self.kernel_size = win_size\n\n        if not isinstance(kernel_sigma, Sequence):\n            kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims)\n        self.kernel_sigma = kernel_sigma\n\n        self.k1 = k1\n        self.k2 = k2\n\n        self.ssim_metric = SSIMMetric(\n            spatial_dims=self.spatial_dims,\n            data_range=self._data_range,\n            kernel_type=self.kernel_type,\n            win_size=self.kernel_size,\n            kernel_sigma=self.kernel_sigma,\n            k1=self.k1,\n            k2=self.k2,\n        )\n\n    @property\n    def data_range(self) -> float:\n        return self._data_range\n\n    @data_range.setter\n    def data_range(self, value: float) -> None:\n        self._data_range = value\n        self.ssim_metric.data_range = value\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n          
  input: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])\n            target: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])\n\n        Returns:\n            1 minus the ssim index (recall this is meant to be a loss function)\n\n        Example:\n            .. code-block:: python\n\n                import torch\n\n                # 2D data\n                x = torch.ones([1,1,10,10])/2\n                y = torch.ones([1,1,10,10])/2\n                print(1-SSIMLoss(spatial_dims=2)(x,y))\n\n                # pseudo-3D data\n                x = torch.ones([1,5,10,10])/2  # 5 could represent number of slices\n                y = torch.ones([1,5,10,10])/2\n                print(1-SSIMLoss(spatial_dims=2)(x,y))\n\n                # 3D data\n                x = torch.ones([1,1,10,10,10])/2\n                y = torch.ones([1,1,10,10,10])/2\n                print(1-SSIMLoss(spatial_dims=3)(x,y))\n        \"\"\"\n        ssim_value = self.ssim_metric._compute_tensor(input, target).view(-1, 1)\n        loss: torch.Tensor = 1 - ssim_value\n\n        if self.reduction == LossReduction.MEAN.value:\n            loss = torch.mean(loss)  # the batch average\n        elif self.reduction == LossReduction.SUM.value:\n            loss = torch.sum(loss)  # sum over the batch\n\n        return loss\n"
  },
  {
    "path": "monai/losses/sure_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import Callable\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.modules.loss import _Loss\n\n\ndef complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    First compute the difference in the complex domain,\n    then get the absolute value and take the mse\n\n    Args:\n        x, y - B, 2, H, W real valued tensors representing complex numbers\n                or  B,1,H,W complex valued tensors\n    Returns:\n        l2_loss - scalar\n    \"\"\"\n    if not x.is_complex():\n        x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous())\n    if not y.is_complex():\n        y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous())\n\n    diff = torch.abs(x - y)\n    return nn.functional.mse_loss(diff, torch.zeros_like(diff), reduction=\"mean\")\n\n\ndef sure_loss_function(\n    operator: Callable,\n    x: torch.Tensor,\n    y_pseudo_gt: torch.Tensor,\n    y_ref: torch.Tensor | None = None,\n    eps: float = -1.0,\n    perturb_noise: torch.Tensor | None = None,\n    complex_input: bool | None = False,\n) -> torch.Tensor:\n    \"\"\"\n    Args:\n        operator (function): The operator function that takes in an input\n        tensor x and returns an output tensor y. We will use this to compute\n        the divergence. 
More specifically, we will perturb the input x by a\n        small amount and compute the divergence between the perturbed output\n        and the reference output\n\n        x (torch.Tensor): The input tensor of shape (B, C, H, W) to the\n        operator.  For complex input, the shape is (B, 2, H, W) aka C=2 real.\n        For real input, the shape is (B, 1, H, W) real.\n\n        y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape\n        (B, C, H, W) used to compute the L2 loss.  For complex input, the shape is\n        (B, 2, H, W) aka C=2 real.  For real input, the shape is (B, 1, H, W)\n        real.\n\n        y_ref (torch.Tensor, optional): The reference output tensor of shape\n        (B, C, H, W) used to compute the divergence. Defaults to None.  For\n        complex input, the shape is (B, 2, H, W) aka C=2 real.  For real input,\n        the shape is (B, 1, H, W) real.\n\n        eps (float, optional): The perturbation scalar. Set to -1 to set it\n        automatically estimated based on y_pseudo_gtk\n\n        perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W).\n        Defaults to None.  
For complex input, the shape is (B, 2, H, W) aka C=2 real.\n        For real input, the shape is (B, 1, H, W) real.\n\n        complex_input(bool, optional): Whether the input is complex or not.\n        Defaults to False.\n\n    Returns:\n        sure_loss (torch.Tensor): The SURE loss scalar.\n    \"\"\"\n    # perturb input\n    if perturb_noise is None:\n        perturb_noise = torch.randn_like(x)\n    if eps == -1.0:\n        eps = float(torch.abs(y_pseudo_gt.max())) / 1000\n    # get y_ref if not provided\n    if y_ref is None:\n        y_ref = operator(x)\n\n    # get perturbed output\n    x_perturbed = x + eps * perturb_noise  # type: ignore\n    y_perturbed = operator(x_perturbed)\n    # divergence\n    divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref))  # type: ignore\n    # l2 loss between y_ref, y_pseudo_gt\n    if complex_input:\n        l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt)\n    else:\n        # real input\n        l2_loss = nn.functional.mse_loss(y_ref, y_pseudo_gt, reduction=\"mean\")\n\n    # sure loss\n    sure_loss = l2_loss * divergence / (x.shape[0] * x.shape[2] * x.shape[3])\n    return sure_loss\n\n\nclass SURELoss(_Loss):\n    \"\"\"\n    Calculate the Stein's Unbiased Risk Estimator (SURE) loss for a given operator.\n\n    This is a differentiable loss function that can be used to train/guide an\n    operator (e.g. neural network), where the pseudo ground truth is available\n    but the reference ground truth is not. For example, in the MRI\n    reconstruction, the pseudo ground truth is the zero-filled reconstruction\n    and the reference ground truth is the fully sampled reconstruction.  Often,\n    the reference ground truth is not available due to the lack of fully sampled\n    data.\n\n    The original SURE loss is proposed in [1]. 
The SURE loss used for guiding\n    the diffusion model based MRI reconstruction is proposed in [2].\n\n    Reference\n\n    [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics\n\n    [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models.\n    (https://arxiv.org/pdf/2310.01799.pdf)\n    \"\"\"\n\n    def __init__(self, perturb_noise: torch.Tensor | None = None, eps: float | None = None) -> None:\n        \"\"\"\n        Args:\n            perturb_noise (torch.Tensor, optional): The noise vector of shape\n            (B, C, H, W). Defaults to None.  For complex input, the shape is (B, 2, H, W) aka C=2 real.\n            For real input, the shape is (B, 1, H, W) real.\n\n            eps (float, optional): The perturbation scalar. Defaults to None.\n        \"\"\"\n        super().__init__()\n        self.perturb_noise = perturb_noise\n        self.eps = eps\n\n    def forward(\n        self,\n        operator: Callable,\n        x: torch.Tensor,\n        y_pseudo_gt: torch.Tensor,\n        y_ref: torch.Tensor | None = None,\n        complex_input: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            operator (function): The operator function that takes in an input\n            tensor x and returns an output tensor y. We will use this to compute\n            the divergence. More specifically, we will perturb the input x by a\n            small amount and compute the divergence between the perturbed output\n            and the reference output\n\n            x (torch.Tensor): The input tensor of shape (B, C, H, W) to the\n            operator. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka\n            C=2 real.  For real input, the shape is (B, 1, H, W) real.\n\n            y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape\n            (B, C, H, W) used to compute the L2 loss. 
C=1 or 2: For complex\n            input, the shape is (B, 2, H, W) aka C=2 real.  For real input, the\n            shape is (B, 1, H, W) real.\n\n            y_ref (torch.Tensor, optional): The reference output tensor of the\n            same shape as y_pseudo_gt\n\n        Returns:\n            sure_loss (torch.Tensor): The SURE loss scalar.\n        \"\"\"\n\n        # check inputs shapes\n        if x.dim() != 4:\n            raise ValueError(f\"Input tensor x should be 4D, got {x.dim()}.\")\n        if y_pseudo_gt.dim() != 4:\n            raise ValueError(f\"Input tensor y_pseudo_gt should be 4D, but got {y_pseudo_gt.dim()}.\")\n        if y_ref is not None and y_ref.dim() != 4:\n            raise ValueError(f\"Input tensor y_ref should be 4D, but got {y_ref.dim()}.\")\n        if x.shape != y_pseudo_gt.shape:\n            raise ValueError(\n                f\"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, \"\n                f\"y_pseudo_gt shape {y_pseudo_gt.shape}.\"\n            )\n        if y_ref is not None and y_pseudo_gt.shape != y_ref.shape:\n            raise ValueError(\n                f\"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape {y_pseudo_gt.shape}, \"\n                f\"y_ref shape {y_ref.shape}.\"\n            )\n\n        # compute loss\n        eps = self.eps if self.eps is not None else -1.0\n        loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, eps, self.perturb_noise, complex_input)\n\n        return loss\n"
  },
  {
    "path": "monai/losses/tversky.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable\n\nimport torch\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.losses.utils import compute_tp_fp_fn\nfrom monai.networks import one_hot\nfrom monai.utils import LossReduction\n\n\nclass TverskyLoss(_Loss):\n    \"\"\"\n    Compute the Tversky loss defined in:\n\n        Sadegh et al. (2017) Tversky loss function for image segmentation\n        using 3D fully convolutional deep networks. (https://arxiv.org/abs/1706.05721)\n\n        Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with\n        Soft Labels. 
MICCAI 2023.\n\n    Adapted from:\n        https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L631\n\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        to_onehot_y: bool = False,\n        sigmoid: bool = False,\n        softmax: bool = False,\n        other_act: Callable | None = None,\n        alpha: float = 0.5,\n        beta: float = 0.5,\n        reduction: LossReduction | str = LossReduction.MEAN,\n        smooth_nr: float = 1e-5,\n        smooth_dr: float = 1e-5,\n        batch: bool = False,\n        soft_label: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            include_background: If False channel index 0 (background category) is excluded from the calculation.\n            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.\n            sigmoid: If True, apply a sigmoid function to the prediction.\n            softmax: If True, apply a softmax function to the prediction.\n            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute\n                other activation layers, Defaults to ``None``. for example:\n                `other_act = torch.tanh`.\n            alpha: weight of false positives\n            beta: weight of false negatives\n            reduction: {``\"none\"``, ``\"mean\"``, ``\"sum\"``}\n                Specifies the reduction to apply to the output. 
Defaults to ``\"mean\"``.\n\n                - ``\"none\"``: no reduction will be applied.\n                - ``\"mean\"``: the sum of the output will be divided by the number of elements in the output.\n                - ``\"sum\"``: the output will be summed.\n\n            smooth_nr: a small constant added to the numerator to avoid zero.\n            smooth_dr: a small constant added to the denominator to avoid nan.\n            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.\n                Defaults to False, a Dice loss value is computed independently from each item in the batch\n                before any `reduction`.\n            soft_label: whether the target contains non-binary values (soft labels) or not.\n                If True a soft label formulation of the loss will be used.\n\n        Raises:\n            TypeError: When ``other_act`` is not an ``Optional[Callable]``.\n            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].\n                Incompatible values.\n\n        \"\"\"\n\n        super().__init__(reduction=LossReduction(reduction).value)\n        if other_act is not None and not callable(other_act):\n            raise TypeError(f\"other_act must be None or callable but is {type(other_act).__name__}.\")\n        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:\n            raise ValueError(\"Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].\")\n        self.include_background = include_background\n        self.to_onehot_y = to_onehot_y\n        self.sigmoid = sigmoid\n        self.softmax = softmax\n        self.other_act = other_act\n        self.alpha = alpha\n        self.beta = beta\n        self.smooth_nr = float(smooth_nr)\n        self.smooth_dr = float(smooth_dr)\n        self.batch = batch\n        self.soft_label = soft_label\n\n    def forward(self, input: torch.Tensor, 
target: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            input: the shape should be BNH[WD].\n            target: the shape should be BNH[WD].\n\n        Raises:\n            ValueError: When ``self.reduction`` is not one of [\"mean\", \"sum\", \"none\"].\n\n        \"\"\"\n        if self.sigmoid:\n            input = torch.sigmoid(input)\n\n        n_pred_ch = input.shape[1]\n        if self.softmax:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `softmax=True` ignored.\")\n            else:\n                input = torch.softmax(input, 1)\n\n        if self.other_act is not None:\n            input = self.other_act(input)\n\n        if self.to_onehot_y:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `to_onehot_y=True` ignored.\")\n            else:\n                target = one_hot(target, num_classes=n_pred_ch)\n\n        if not self.include_background:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `include_background=False` ignored.\")\n            else:\n                # if skipping background, removing first channel\n                target = target[:, 1:]\n                input = input[:, 1:]\n\n        if target.shape != input.shape:\n            raise AssertionError(f\"ground truth has differing shape ({target.shape}) from input ({input.shape})\")\n\n        # reducing only spatial dimensions (not batch nor channels)\n        reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()\n        if self.batch:\n            # reducing spatial dimensions and batch\n            reduce_axis = [0] + reduce_axis\n\n        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False)\n        fp *= self.alpha\n        fn *= self.beta\n        numerator = tp + self.smooth_nr\n        denominator = tp + fp + fn + self.smooth_dr\n\n        score: torch.Tensor = 1.0 - numerator 
/ denominator\n\n        if self.reduction == LossReduction.SUM.value:\n            return torch.sum(score)  # sum over the batch and channel dims\n        if self.reduction == LossReduction.NONE.value:\n            return score  # returns [N, num_classes] losses\n        if self.reduction == LossReduction.MEAN.value:\n            return torch.mean(score)\n        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n"
  },
  {
    "path": "monai/losses/unified_focal_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\n\nimport torch\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.networks import one_hot\nfrom monai.utils import LossReduction\n\n\nclass AsymmetricFocalTverskyLoss(_Loss):\n    \"\"\"\n    AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.\n\n    Actually, it's only supported for binary image segmentation now.\n\n    Reimplementation of the Asymmetric Focal Tversky Loss described in:\n\n    - \"Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation\",\n    Michael Yeung, Computerized Medical Imaging and Graphics\n    \"\"\"\n\n    def __init__(\n        self,\n        to_onehot_y: bool = False,\n        delta: float = 0.7,\n        gamma: float = 0.75,\n        epsilon: float = 1e-7,\n        reduction: LossReduction | str = LossReduction.MEAN,\n    ) -> None:\n        \"\"\"\n        Args:\n            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.\n            delta : weight of the background. Defaults to 0.7.\n            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.\n            epsilon : it defines a very small number each time. similarly smooth value. 
Defaults to 1e-7.\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n        self.to_onehot_y = to_onehot_y\n        self.delta = delta\n        self.gamma = gamma\n        self.epsilon = epsilon\n\n    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:\n        n_pred_ch = y_pred.shape[1]\n\n        if self.to_onehot_y:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `to_onehot_y=True` ignored.\")\n            else:\n                y_true = one_hot(y_true, num_classes=n_pred_ch)\n\n        if y_true.shape != y_pred.shape:\n            raise ValueError(f\"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})\")\n\n        # clip the prediction to avoid NaN\n        y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)\n        axis = list(range(2, len(y_pred.shape)))\n\n        # Calculate true positives (tp), false negatives (fn) and false positives (fp)\n        tp = torch.sum(y_true * y_pred, dim=axis)\n        fn = torch.sum(y_true * (1 - y_pred), dim=axis)\n        fp = torch.sum((1 - y_true) * y_pred, dim=axis)\n        dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)\n\n        # Calculate losses separately for each class, enhancing both classes\n        back_dice = 1 - dice_class[:, 0]\n        fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)\n\n        # Average class scores\n        loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))\n        return loss\n\n\nclass AsymmetricFocalLoss(_Loss):\n    \"\"\"\n    AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.\n\n    Actually, it's only supported for binary image segmentation now.\n\n    Reimplementation of the Asymmetric Focal Loss described in:\n\n    - \"Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle 
Class Imbalanced Medical Image Segmentation\",\n    Michael Yeung, Computerized Medical Imaging and Graphics\n    \"\"\"\n\n    def __init__(\n        self,\n        to_onehot_y: bool = False,\n        delta: float = 0.7,\n        gamma: float = 2,\n        epsilon: float = 1e-7,\n        reduction: LossReduction | str = LossReduction.MEAN,\n    ):\n        \"\"\"\n        Args:\n            to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.\n            delta : weight of the background. Defaults to 0.7.\n            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.\n            epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n        self.to_onehot_y = to_onehot_y\n        self.delta = delta\n        self.gamma = gamma\n        self.epsilon = epsilon\n\n    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:\n        n_pred_ch = y_pred.shape[1]\n\n        if self.to_onehot_y:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `to_onehot_y=True` ignored.\")\n            else:\n                y_true = one_hot(y_true, num_classes=n_pred_ch)\n\n        if y_true.shape != y_pred.shape:\n            raise ValueError(f\"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})\")\n\n        y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)\n        cross_entropy = -y_true * torch.log(y_pred)\n\n        back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]\n        back_ce = (1 - self.delta) * back_ce\n\n        fore_ce = cross_entropy[:, 1]\n        fore_ce = self.delta * fore_ce\n\n        loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))\n        return loss\n\n\nclass AsymmetricUnifiedFocalLoss(_Loss):\n    \"\"\"\n    
AsymmetricUnifiedFocalLoss is a variant of Focal Loss.\n\n    Actually, it's only supported for binary image segmentation now\n\n    Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:\n\n    - \"Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation\",\n    Michael Yeung, Computerized Medical Imaging and Graphics\n    \"\"\"\n\n    def __init__(\n        self,\n        to_onehot_y: bool = False,\n        num_classes: int = 2,\n        weight: float = 0.5,\n        gamma: float = 0.5,\n        delta: float = 0.7,\n        reduction: LossReduction | str = LossReduction.MEAN,\n    ):\n        \"\"\"\n        Args:\n            to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.\n            num_classes : number of classes, it only supports 2 now. Defaults to 2.\n            weight : weight for each loss function. Defaults to 0.5.\n            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.\n            delta : weight of the background. 
Defaults to 0.7.\n\n\n\n        Example:\n            >>> import torch\n            >>> from monai.losses import AsymmetricUnifiedFocalLoss\n            >>> pred = torch.ones((1,1,32,32), dtype=torch.float32)\n            >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)\n            >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)\n            >>> fl(pred, grnd)\n        \"\"\"\n        super().__init__(reduction=LossReduction(reduction).value)\n        self.to_onehot_y = to_onehot_y\n        self.num_classes = num_classes\n        self.gamma = gamma\n        self.delta = delta\n        self.weight: float = weight\n        self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)\n        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)\n\n    # TODO: Implement this  function to support multiple classes segmentation\n    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            y_pred : the shape should be BNH[WD], where N is the number of classes.\n                It only supports binary segmentation.\n                The input should be the original logits since it will be transformed by\n                    a sigmoid in the forward function.\n            y_true : the shape should be BNH[WD], where N is the number of classes.\n                It only supports binary segmentation.\n\n        Raises:\n            ValueError: When input and target are different shape\n            ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5\n            ValueError: When num_classes\n            ValueError: When the number of classes entered does not match the expected number\n        \"\"\"\n        if y_pred.shape != y_true.shape:\n            raise ValueError(f\"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})\")\n\n        if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:\n            
raise ValueError(f\"input shape must be 4 or 5, but got {y_pred.shape}\")\n\n        if y_pred.shape[1] == 1:\n            y_pred = one_hot(y_pred, num_classes=self.num_classes)\n            y_true = one_hot(y_true, num_classes=self.num_classes)\n\n        if torch.max(y_true) != self.num_classes - 1:\n            raise ValueError(f\"Please make sure the number of classes is {self.num_classes - 1}\")\n\n        n_pred_ch = y_pred.shape[1]\n        if self.to_onehot_y:\n            if n_pred_ch == 1:\n                warnings.warn(\"single channel prediction, `to_onehot_y=True` ignored.\")\n            else:\n                y_true = one_hot(y_true, num_classes=n_pred_ch)\n\n        asy_focal_loss = self.asy_focal_loss(y_pred, y_true)\n        asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)\n\n        loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss\n\n        if self.reduction == LossReduction.SUM.value:\n            return torch.sum(loss)  # sum over the batch and channel dims\n        if self.reduction == LossReduction.NONE.value:\n            return loss  # returns [N, num_classes] losses\n        if self.reduction == LossReduction.MEAN.value:\n            return torch.mean(loss)\n        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are [\"mean\", \"sum\", \"none\"].')\n"
  },
  {
    "path": "monai/losses/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.linalg as LA\n\n\ndef compute_tp_fp_fn(\n    input: torch.Tensor,\n    target: torch.Tensor,\n    reduce_axis: list[int],\n    ord: int,\n    soft_label: bool,\n    decoupled: bool = True,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Args:\n        input: the shape should be BNH[WD], where N is the number of classes.\n        target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.\n        reduce_axis: the axis to be reduced.\n        ord: the order of the vector norm.\n        soft_label: whether the target contains non-binary values (soft labels) or not.\n            If True a soft label formulation of the loss will be used.\n        decoupled: whether the input and the target should be decoupled when computing fp and fn.\n            Only for the original implementation when soft_label is False.\n\n    Adapted from:\n        https://github.com/zifuwanggg/JDTLosses\n    \"\"\"\n\n    # the original implementation that is erroneous with soft labels\n    if ord == 1 and not soft_label:\n        tp = torch.sum(input * target, dim=reduce_axis)\n        # the original implementation of Dice and Jaccard loss\n        if decoupled:\n            fp = torch.sum(input, dim=reduce_axis) - tp\n            fn = torch.sum(target, dim=reduce_axis) - tp\n        # the 
original implementation of Tversky loss\n        else:\n            fp = torch.sum(input * (1 - target), dim=reduce_axis)\n            fn = torch.sum((1 - input) * target, dim=reduce_axis)\n    # the new implementation that is correct with soft labels\n    # and it is identical to the original implementation with hard labels\n    else:\n        pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis)\n        ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)\n        difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis)\n\n        if ord > 1:\n            pred_o = torch.pow(pred_o, exponent=ord)\n            ground_o = torch.pow(ground_o, exponent=ord)\n            difference = torch.pow(difference, exponent=ord)\n\n        tp = (pred_o + ground_o - difference) / 2\n        fp = pred_o - tp\n        fn = ground_o - tp\n\n    return tp, fp, fn\n"
  },
  {
    "path": "monai/metrics/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score\nfrom .average_precision import AveragePrecisionMetric, compute_average_precision\nfrom .calibration import CalibrationErrorMetric, CalibrationReduction, calibration_binning\nfrom .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix\nfrom .cumulative_average import CumulativeAverage\nfrom .f_beta_score import FBetaScore\nfrom .fid import FIDMetric, compute_frechet_distance\nfrom .froc import compute_fp_tp_probs, compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score\nfrom .generalized_dice import GeneralizedDiceScore, compute_generalized_dice\nfrom .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance\nfrom .loss_metric import LossMetric\nfrom .meandice import DiceHelper, DiceMetric, compute_dice\nfrom .meaniou import MeanIoU, compute_iou\nfrom .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric\nfrom .mmd import MMDMetric, compute_mmd\nfrom .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality\nfrom .regression import (\n    MAEMetric,\n    MAPEMetric,\n    MSEMetric,\n    MultiScaleSSIMMetric,\n    PSNRMetric,\n    RMSEMetric,\n    SSIMMetric,\n    compute_ms_ssim,\n    
compute_ssim_and_cs,\n)\nfrom .rocauc import ROCAUCMetric, compute_roc_auc\nfrom .surface_dice import SurfaceDiceMetric, compute_surface_dice\nfrom .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance\nfrom .utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background, is_binary_tensor\nfrom .wrapper import MetricsReloadedBinary, MetricsReloadedCategorical\n"
  },
  {
    "path": "monai/metrics/active_learning_metrics.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom typing import Any\n\nimport torch\n\nfrom monai.metrics.utils import ignore_background\nfrom monai.utils import MetricReduction\n\nfrom .metric import Metric\n\n\nclass VarianceMetric(Metric):\n    \"\"\"\n    Compute the Variance of a given T-repeats N-dimensional array/tensor. The primary usage is as an uncertainty based\n    metric for Active Learning.\n\n    It can return the spatial variance/uncertainty map based on user choice or a single scalar value via mean/sum of the\n    variance for scoring purposes\n\n    Args:\n        include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector\n        spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image dimensions\n        scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used\n        threshold: To avoid NaN's a threshold is used to replace zero's\n\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        spatial_map: bool = False,\n        scalar_reduction: str = \"sum\",\n        threshold: float = 0.0005,\n    ) -> None:\n        super().__init__()\n        self.include_background = include_background\n        self.spatial_map = spatial_map\n        self.scalar_reduction = scalar_reduction\n    
    self.threshold = threshold\n\n    def __call__(self, y_pred: Any) -> Any:\n        \"\"\"\n        Args:\n            y_pred: Predicted segmentation, typically segmentation model output.\n                It must be N-repeats, repeat-first tensor [N,C,H,W,D].\n\n        Returns:\n            Pytorch tensor of scalar value of variance as uncertainty or a spatial map of uncertainty\n\n        \"\"\"\n        return compute_variance(\n            y_pred=y_pred,\n            include_background=self.include_background,\n            spatial_map=self.spatial_map,\n            scalar_reduction=self.scalar_reduction,\n            threshold=self.threshold,\n        )\n\n\nclass LabelQualityScore(Metric):\n    \"\"\"\n    The assumption is that the DL model makes better predictions than the provided label quality, hence the difference\n    can be treated as a label quality score\n\n    It can be combined with variance/uncertainty for active learning frameworks to factor in the quality of label along\n    with uncertainty\n    Args:\n        include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector\n        spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image\n        dimensions\n        scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used\n\n    \"\"\"\n\n    def __init__(self, include_background: bool = True, scalar_reduction: str = \"sum\") -> None:\n        super().__init__()\n        self.include_background = include_background\n        self.scalar_reduction = scalar_reduction\n\n    def __call__(self, y_pred: Any, y: Any) -> torch.Tensor | None:\n        \"\"\"\n        Args:\n            y_pred: Predicted segmentation, typically segmentation model output.\n                It must be N-repeats, repeat-first tensor [N,C,H,W,D].\n\n        Returns:\n            Pytorch tensor of scalar value of variance as uncertainty or a spatial map of 
uncertainty\n\n        \"\"\"\n        return label_quality_score(\n            y_pred=y_pred, y=y, include_background=self.include_background, scalar_reduction=self.scalar_reduction\n        )\n\n\ndef compute_variance(\n    y_pred: torch.Tensor,\n    include_background: bool = True,\n    spatial_map: bool = False,\n    scalar_reduction: str = \"mean\",\n    threshold: float = 0.0005,\n) -> torch.Tensor | None:\n    \"\"\"\n    Args:\n        y_pred: [N, C, H, W, D] or [N, C, H, W] or [N, C, H] where N is repeats, C is channels and H, W, D stand for\n            Height, Width & Depth\n        include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector\n        spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image\n            dimensions\n        scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used\n        threshold: To avoid NaN's a threshold is used to replace zero's\n    Returns:\n        A single scalar uncertainty/variance value or the spatial map of uncertainty/variance\n    \"\"\"\n\n    # The background utils is only applicable here because instead of Batch-dimension we have repeats here\n    y_pred = y_pred.float()\n\n    if not include_background:\n        y = y_pred\n        # TODO If this utils is made to be optional for 'y' it would be nice\n        y_pred, y = ignore_background(y_pred=y_pred, y=y)\n\n    # Set any values below 0 to threshold\n    y_pred[y_pred <= 0] = threshold\n\n    n_len = len(y_pred.shape)\n\n    if n_len < 4 and spatial_map:\n        warnings.warn(\"Spatial map requires a 2D/3D image with N-repeats and C-channels\")\n        return None\n\n    # Create new shape list\n    # The N-repeats are multiplied by channels\n    n_shape = y_pred.shape\n    new_shape = [n_shape[0] * n_shape[1]]\n    for each_dim_idx in range(2, n_len):\n        new_shape.append(n_shape[each_dim_idx])\n\n    y_reshaped = 
torch.reshape(y_pred, new_shape)\n    variance = torch.var(y_reshaped, dim=0, unbiased=False)\n\n    if spatial_map:\n        return variance\n\n    if scalar_reduction == MetricReduction.MEAN:\n        return torch.mean(variance)\n    if scalar_reduction == MetricReduction.SUM:\n        return torch.sum(variance)\n    raise ValueError(f\"scalar_reduction={scalar_reduction} not supported.\")\n\n\ndef label_quality_score(\n    y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, scalar_reduction: str = \"mean\"\n) -> torch.Tensor | None:\n    \"\"\"\n    The assumption is that the DL model makes better predictions than the provided label quality, hence the difference\n    can be treated as a label quality score\n\n    Args:\n        y_pred: Input data of dimension [B, C, H, W, D] or [B, C, H, W] or [B, C, H] where B is Batch-size, C is\n            channels and H, W, D stand for Height, Width & Depth\n        y: Ground Truth of dimension [B, C, H, W, D] or [B, C, H, W] or [B, C, H] where B is Batch-size, C is channels\n            and H, W, D stand for Height, Width & Depth\n        include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector\n        scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used to retrieve a single scalar\n            value, if set to 'none' a spatial map will be returned\n\n    Returns:\n        A single scalar absolute difference value as score with a reduction based on sum/mean or the spatial map of\n        absolute difference\n    \"\"\"\n\n    # The background utils is only applicable here because instead of Batch-dimension we have repeats here\n    y_pred = y_pred.float()\n    y = y.float()\n\n    if not include_background:\n        y_pred, y = ignore_background(y_pred=y_pred, y=y)\n\n    n_len = len(y_pred.shape)\n    if n_len < 4 and scalar_reduction == \"none\":\n        warnings.warn(\"Reduction set to None, Spatial map return requires a 
2D/3D image of B-Batchsize and C-channels\")\n        return None\n\n    abs_diff_map = torch.abs(y_pred - y)\n\n    if scalar_reduction == MetricReduction.NONE:\n        return abs_diff_map\n\n    if scalar_reduction == MetricReduction.MEAN:\n        return torch.mean(abs_diff_map, dim=list(range(1, n_len)))\n    if scalar_reduction == MetricReduction.SUM:\n        return torch.sum(abs_diff_map, dim=list(range(1, n_len)))\n    raise ValueError(f\"scalar_reduction={scalar_reduction} not supported.\")\n"
  },
  {
    "path": "monai/metrics/average_precision.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom typing import TYPE_CHECKING, cast\n\nimport numpy as np\n\nif TYPE_CHECKING:\n    import numpy.typing as npt\n\nimport torch\n\nfrom monai.utils import Average, look_up_option\n\nfrom .metric import CumulativeIterationMetric\n\n\nclass AveragePrecisionMetric(CumulativeIterationMetric):\n    \"\"\"\n    Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are\n    imbalanced. It can take values between 0.0 and 1.0, 1.0 being the best possible score.\n    It summarizes a Precision-Recall curve as the weighted mean of precisions achieved at each\n    threshold, with the increase in recall from the previous threshold used as the weight:\n\n    .. 
math::\n        \\\\text{AP} = \\\\sum_n (R_n - R_{n-1}) P_n\n        :label: ap\n\n    where :math:`P_n` and :math:`R_n` are the precision and recall at the :math:`n^{th}` threshold.\n\n    Referring to: `sklearn.metrics.average_precision_score\n    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.\n\n    The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor.\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        average: {``\"macro\"``, ``\"weighted\"``, ``\"micro\"``, ``\"none\"``}\n            Type of averaging performed if not binary classification.\n            Defaults to ``\"macro\"``.\n\n            - ``\"macro\"``: calculate metrics for each label, and find their unweighted mean.\n                This does not take label imbalance into account.\n            - ``\"weighted\"``: calculate metrics for each label, and find their average,\n                weighted by support (the number of true instances for each label).\n            - ``\"micro\"``: calculate metrics globally by considering each element of the label\n                indicator matrix as a label.\n            - ``\"none\"``: the scores for each class are returned.\n\n    \"\"\"\n\n    def __init__(self, average: Average | str = Average.MACRO) -> None:\n        super().__init__()\n        self.average = average\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]\n        return y_pred, y\n\n    def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike:\n        \"\"\"\n        Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration,\n        This function reads the buffers and computes the Average Precision.\n\n        Args:\n            average: {``\"macro\"``, ``\"weighted\"``, 
``\"micro\"``, ``\"none\"``}\n                Type of averaging performed if not binary classification. Defaults to `self.average`.\n\n        \"\"\"\n        y_pred, y = self.get_buffer()\n        # compute final value and do metric reduction\n        if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):\n            raise ValueError(\"y_pred and y must be PyTorch Tensor.\")\n\n        return compute_average_precision(y_pred=y_pred, y=y, average=average or self.average)\n\n\ndef _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:\n    if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)):\n        raise AssertionError(\"y and y_pred must be 1 dimension data with same length.\")\n    y_unique = y.unique()\n    if len(y_unique) == 1:\n        warnings.warn(f\"y values can not be all {y_unique.item()}, skip AP computation and return `Nan`.\")\n        return float(\"nan\")\n    if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)):\n        warnings.warn(f\"y values must be 0 or 1, but in {y_unique.tolist()}, skip AP computation and return `Nan`.\")\n        return float(\"nan\")\n\n    n = len(y)\n    indices = y_pred.argsort(descending=True)\n    y = y[indices].cpu().numpy()  # type: ignore[assignment]\n    y_pred = y_pred[indices].cpu().numpy()  # type: ignore[assignment]\n    npos = ap = tmp_pos = 0.0\n\n    for i in range(n):\n        y_i = cast(float, y[i])\n        if i + 1 < n and y_pred[i] == y_pred[i + 1]:\n            tmp_pos += y_i\n        else:\n            tmp_pos += y_i\n            npos += tmp_pos\n            ap += tmp_pos * npos / (i + 1)\n            tmp_pos = 0\n\n    return ap / npos\n\n\ndef compute_average_precision(\n    y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO\n) -> np.ndarray | float | npt.ArrayLike:\n    \"\"\"Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are\n    imbalanced. 
It summarizes a Precision-Recall according to equation :eq:`ap`.\n    Referring to: `sklearn.metrics.average_precision_score\n    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.\n\n    Args:\n        y_pred: input data to compute, typical classification model output.\n            the first dim must be batch, if multi-classes, it must be in One-Hot format.\n            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.\n        y: ground truth to compute AP metric, the first dim must be batch.\n            if multi-classes, it must be in One-Hot format.\n            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.\n        average: {``\"macro\"``, ``\"weighted\"``, ``\"micro\"``, ``\"none\"``}\n            Type of averaging performed if not binary classification.\n            Defaults to ``\"macro\"``.\n\n            - ``\"macro\"``: calculate metrics for each label, and find their unweighted mean.\n                This does not take label imbalance into account.\n            - ``\"weighted\"``: calculate metrics for each label, and find their average,\n                weighted by support (the number of true instances for each label).\n            - ``\"micro\"``: calculate metrics globally by considering each element of the label\n                indicator matrix as a label.\n            - ``\"none\"``: the scores for each class are returned.\n\n    Raises:\n        ValueError: When ``y_pred`` dimension is not one of [1, 2].\n        ValueError: When ``y`` dimension is not one of [1, 2].\n        ValueError: When ``average`` is not one of [\"macro\", \"weighted\", \"micro\", \"none\"].\n\n    Note:\n        Average Precision expects y to be comprised of 0's and 1's. `y_pred` must be either prob. 
estimates or confidence values.\n\n    \"\"\"\n    y_pred_ndim = y_pred.ndimension()\n    y_ndim = y.ndimension()\n    if y_pred_ndim not in (1, 2):\n        raise ValueError(\n            f\"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}.\"\n        )\n    if y_ndim not in (1, 2):\n        raise ValueError(f\"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.\")\n    if y_pred_ndim == 2 and y_pred.shape[1] == 1:\n        y_pred = y_pred.squeeze(dim=-1)\n        y_pred_ndim = 1\n    if y_ndim == 2 and y.shape[1] == 1:\n        y = y.squeeze(dim=-1)\n\n    if y_pred_ndim == 1:\n        return _calculate(y_pred, y)\n\n    if y.shape != y_pred.shape:\n        raise ValueError(f\"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.\")\n\n    average = look_up_option(average, Average)\n    if average == Average.MICRO:\n        return _calculate(y_pred.flatten(), y.flatten())\n    y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)\n    ap_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)]\n    if average == Average.NONE:\n        return ap_values\n    if average == Average.MACRO:\n        return np.mean(ap_values)\n    if average == Average.WEIGHTED:\n        weights = [sum(y_) for y_ in y]\n        return np.average(ap_values, weights=weights)  # type: ignore[no-any-return]\n    raise ValueError(f'Unsupported average: {average}, available options are [\"macro\", \"weighted\", \"micro\", \"none\"].')\n"
  },
  {
    "path": "monai/metrics/calibration.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nimport torch\n\nfrom monai.metrics.metric import CumulativeIterationMetric\nfrom monai.metrics.utils import do_metric_reduction, ignore_background\nfrom monai.utils import MetricReduction\nfrom monai.utils.enums import StrEnum\n\n__all__ = [\"CalibrationErrorMetric\", \"CalibrationReduction\", \"calibration_binning\"]\n\n\ndef calibration_binning(\n    y_pred: torch.Tensor, y: torch.Tensor, num_bins: int = 20, right: bool = False\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute calibration bins for predicted probabilities and ground truth labels.\n\n    This function implements hard binning for calibration analysis, grouping predictions\n    into bins based on their confidence values and computing statistics for each bin.\n    These statistics can be used to assess model calibration or plot reliability diagrams.\n\n    A well-calibrated model should have predicted probabilities that match empirical accuracy.\n    For example, among all predictions with 80% confidence, approximately 80% should be correct.\n    This function provides the per-bin statistics needed to evaluate this property.\n\n    The function operates on input and target tensors with batch and channel dimensions,\n    handling each batch and channel separately. 
For bins that do not contain any elements,\n    the mean predicted values and mean ground truth values are set to NaN.\n\n    Args:\n        y_pred: Predicted probabilities with shape ``(B, C, spatial...)``, where B is batch size,\n            C is number of classes/channels, and spatial can be any number of dimensions (H, W, D, etc.).\n            Values should be in the range [0, 1].\n        y: Ground truth tensor with the same shape as ``y_pred``. Should be one-hot encoded\n            or contain binary values (0 or 1) indicating the true class membership.\n        num_bins: Number of equally-spaced bins to divide the [0, 1] probability range into.\n            Defaults to 20. Must be >= 1.\n        right: Determines bin boundary inclusion. If False (default), bins include the left\n            boundary and exclude the right (i.e., [left, right)). If True, bins exclude the\n            left boundary and include the right (i.e., (left, right]).\n\n    Returns:\n        A tuple of three tensors, each with shape ``(B, C, num_bins)``:\n            - **mean_p_per_bin**: Mean predicted probability for samples in each bin.\n            - **mean_gt_per_bin**: Mean ground truth value (empirical accuracy) for samples in each bin.\n            - **bin_counts**: Number of samples falling into each bin.\n\n        Bins with no samples have NaN values for mean_p_per_bin and mean_gt_per_bin.\n\n    Raises:\n        ValueError: If ``y_pred`` and ``y`` have different shapes, if input has fewer than\n            3 dimensions, or if ``num_bins < 1``.\n\n    References:\n        - Guo, C., et al. \"On Calibration of Modern Neural Networks.\" ICML 2017.\n          https://proceedings.mlr.press/v70/guo17a.html\n        - Barfoot, T., et al. 
\"Average Calibration Losses for Reliable Uncertainty in\n          Medical Image Segmentation.\" arXiv:2506.03942v3, 2025.\n          https://arxiv.org/abs/2506.03942v3\n\n    Note:\n        This function uses nested loops over batch and channel dimensions for binning operations.\n        For reliability diagram visualization, use the returned statistics to plot mean predicted\n        probability vs. empirical accuracy for each bin.\n\n    Example:\n        >>> import torch\n        >>> # Binary segmentation: batch=1, channels=2, spatial=4x4\n        >>> y_pred = torch.rand(1, 2, 4, 4)  # predicted probabilities\n        >>> y = torch.randint(0, 2, (1, 2, 4, 4)).float()  # one-hot ground truth\n        >>> mean_p, mean_gt, counts = calibration_binning(y_pred, y, num_bins=10)\n        >>> # mean_p, mean_gt, counts each have shape (1, 2, 10)\n    \"\"\"\n    # Input validation\n    if y_pred.shape != y.shape:\n        raise ValueError(f\"y_pred and y must have the same shape, got {y_pred.shape} and {y.shape}.\")\n    if y_pred.ndim < 3:\n        raise ValueError(f\"y_pred must have shape (B, C, spatial...), got ndim={y_pred.ndim}.\")\n    if num_bins < 1:\n        raise ValueError(f\"num_bins must be >= 1, got {num_bins}.\")\n\n    batch_size, num_channels = y_pred.shape[:2]\n    boundaries = torch.linspace(\n        start=0.0, end=1.0 + torch.finfo(torch.float32).eps, steps=num_bins + 1, device=y_pred.device\n    )\n\n    mean_p_per_bin = torch.zeros(batch_size, num_channels, num_bins, device=y_pred.device)\n    mean_gt_per_bin = torch.zeros_like(mean_p_per_bin)\n    bin_counts = torch.zeros_like(mean_p_per_bin)\n\n    y_pred_flat = y_pred.flatten(start_dim=2).float()\n    y_flat = y.flatten(start_dim=2).float()\n\n    for b in range(batch_size):\n        for c in range(num_channels):\n            values_p = y_pred_flat[b, c, :]\n            values_gt = y_flat[b, c, :]\n\n            # Compute bin indices and clamp to valid range to handle out-of-range values\n    
        bin_idx = torch.bucketize(values_p, boundaries[1:], right=right)\n            bin_idx = bin_idx.clamp(max=num_bins - 1)\n\n            # Compute bin counts using scatter_add\n            counts = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32)\n            counts.scatter_add_(0, bin_idx, torch.ones_like(values_p))\n            bin_counts[b, c, :] = counts\n\n            # Compute sums for mean calculation using scatter_add (more compatible than scatter_reduce)\n            sum_p = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32)\n            sum_p.scatter_add_(0, bin_idx, values_p)\n\n            sum_gt = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32)\n            sum_gt.scatter_add_(0, bin_idx, values_gt)\n\n            # Compute means, avoiding division by zero\n            safe_counts = counts.clamp(min=1)\n            mean_p_per_bin[b, c, :] = sum_p / safe_counts\n            mean_gt_per_bin[b, c, :] = sum_gt / safe_counts\n\n    # Set empty bins to NaN\n    mean_p_per_bin[bin_counts == 0] = torch.nan\n    mean_gt_per_bin[bin_counts == 0] = torch.nan\n\n    return mean_p_per_bin, mean_gt_per_bin, bin_counts\n\n\nclass CalibrationReduction(StrEnum):\n    \"\"\"\n    Enumeration of calibration error reduction methods for aggregating per-bin calibration errors.\n\n    - **EXPECTED**: Expected Calibration Error (ECE) - weighted average of per-bin errors by bin count.\n      This is the most commonly used calibration metric, giving more weight to bins with more samples.\n    - **AVERAGE**: Average Calibration Error (ACE) - unweighted mean of per-bin errors.\n      Treats all bins equally regardless of sample count.\n    - **MAXIMUM**: Maximum Calibration Error (MCE) - worst-case calibration error across all bins.\n      Useful for identifying the confidence range with poorest calibration.\n\n    References:\n        - Naeini, M.P., et al. 
\"Obtaining Well Calibrated Probabilities Using Bayesian Binning.\" AAAI 2015.\n        - Guo, C., et al. \"On Calibration of Modern Neural Networks.\" ICML 2017.\n    \"\"\"\n\n    EXPECTED = \"expected\"\n    AVERAGE = \"average\"\n    MAXIMUM = \"maximum\"\n\n\nclass CalibrationErrorMetric(CumulativeIterationMetric):\n    \"\"\"\n    Compute the Calibration Error between predicted probabilities and ground truth labels.\n\n    **Why Calibration Matters:**\n\n    A well-calibrated classifier produces probability estimates that reflect true correctness likelihood.\n    For instance, if a model predicts 80% probability for class A, a well calibrated and reliable model\n    should be correct approximately 80% of the time among all such predictions.\n    Modern neural networks, despite high accuracy, are often poorly calibrated, as they tend to be\n    overconfident in their predictions.\n    This is particularly important in medical imaging where probability estimates may inform clinical decisions.\n\n    **How It Works:**\n\n    This metric uses a binning approach: predictions are grouped into bins based on their confidence\n    (predicted probability), and for each bin, the average confidence is compared to the empirical\n    accuracy (fraction of correct predictions). The calibration error measures the discrepancy between\n    these values across all bins.\n\n    Three reduction modes are supported:\n\n    - **Expected Calibration Error (ECE)**: Weighted average of per-bin errors, where weights are\n      proportional to the number of samples in each bin. 
Most commonly used metric.\n    - **Average Calibration Error (ACE)**: Simple unweighted average across bins.\n    - **Maximum Calibration Error (MCE)**: The largest calibration error among all bins.\n\n    The metric supports both single-channel and multi-channel data in the format ``(B, C, H, W[, D])``,\n    where B is batch size, C is number of classes, and H, W, D are spatial dimensions.\n\n    Args:\n        num_bins: Number of equally-spaced bins to divide the [0, 1] probability range into.\n            Defaults to 20.\n        include_background: Whether to include the first channel (index 0) in the computation.\n            Set to ``False`` to exclude background class, which is useful in segmentation tasks\n            where background may dominate and skew calibration results. Defaults to ``True``.\n        calibration_reduction: Method for calculating calibration error from binned data.\n            Available modes: ``\"expected\"`` (ECE), ``\"average\"`` (ACE), ``\"maximum\"`` (MCE).\n            Defaults to ``\"expected\"``.\n        metric_reduction: Reduction mode to apply across batch/channel dimensions after computing\n            per-sample calibration errors. Available modes: ``\"none\"``, ``\"mean\"``, ``\"sum\"``,\n            ``\"mean_batch\"``, ``\"sum_batch\"``, ``\"mean_channel\"``, ``\"sum_channel\"``.\n            Defaults to ``\"mean\"``.\n        get_not_nans: If ``True``, ``aggregate()`` returns a tuple ``(metric, not_nans)`` where\n            ``not_nans`` is the count of non-NaN values. Defaults to ``False``.\n        right: Bin boundary inclusion rule. If ``False`` (default), bins are ``[left, right)``.\n            If ``True``, bins are ``(left, right]``.\n\n    References:\n        - Guo, C., et al. \"On Calibration of Modern Neural Networks.\" ICML 2017.\n          https://proceedings.mlr.press/v70/guo17a.html\n        - Barfoot, T., et al. 
\"Average Calibration Losses for Reliable Uncertainty in\n          Medical Image Segmentation.\" arXiv:2506.03942v3, 2025.\n          https://arxiv.org/abs/2506.03942v3\n\n    See Also:\n        - :py:class:`monai.handlers.CalibrationError`: Ignite handler wrapper for this metric.\n        - :py:func:`calibration_binning`: Low-level binning function for reliability diagrams.\n\n    Example:\n        Typical execution steps follow :py:class:`monai.metrics.metric.Cumulative`.\n\n        >>> import torch\n        >>> from monai.metrics import CalibrationErrorMetric\n        >>> from monai.transforms import Activations, AsDiscrete\n        >>>\n        >>> # Setup transforms for probability conversion\n        >>> num_classes = 3\n        >>> softmax = Activations(softmax=True)  # convert logits to probabilities\n        >>> to_onehot = AsDiscrete(to_onehot=num_classes)  # convert labels to one-hot\n        >>>\n        >>> # Create metric (Expected Calibration Error, excluding background)\n        >>> metric = CalibrationErrorMetric(\n        ...     num_bins=15,\n        ...     include_background=False,\n        ...     calibration_reduction=\"expected\"\n        ... )\n        >>>\n        >>> # Evaluation loop\n        >>> for batch_data in dataloader:\n        ...     logits, labels = model(batch_data)\n        ...     preds = softmax(logits)  # shape: (B, C, H, W) with values in [0, 1]\n        ...     labels_onehot = to_onehot(labels)  # shape: (B, C, H, W) with values 0 or 1\n        ...     
metric(y_pred=preds, y=labels_onehot)\n        >>>\n        >>> # Get final calibration error\n        >>> ece = metric.aggregate()\n        >>> print(f\"Expected Calibration Error: {ece:.4f}\")\n    \"\"\"\n\n    def __init__(\n        self,\n        num_bins: int = 20,\n        include_background: bool = True,\n        calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED,\n        metric_reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n        right: bool = False,\n    ) -> None:\n        super().__init__()\n        self.num_bins = num_bins\n        self.include_background = include_background\n        self.calibration_reduction = CalibrationReduction(calibration_reduction)\n        self.metric_reduction = metric_reduction\n        self.get_not_nans = get_not_nans\n        self.right = right\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor:  # type: ignore[override]\n        \"\"\"\n        Compute calibration error for the given predictions and ground truth.\n\n        Args:\n            y_pred: input data to compute. It should be in the format of (batch, channel, spatial...).\n                    It represents probability predictions of the model.\n            y: ground truth in one-hot format. 
It should be in the format of (batch, channel, spatial...).\n               The values should be binarized.\n            **kwargs: additional keyword arguments (unused, for API compatibility).\n\n        Returns:\n            Calibration error tensor with shape (batch, channel).\n        \"\"\"\n        if not self.include_background:\n            y_pred, y = ignore_background(y_pred=y_pred, y=y)\n\n        mean_p_per_bin, mean_gt_per_bin, bin_counts = calibration_binning(\n            y_pred=y_pred, y=y, num_bins=self.num_bins, right=self.right\n        )\n\n        # Calculate the absolute differences, ignoring nan values\n        abs_diff = torch.abs(mean_p_per_bin - mean_gt_per_bin)\n\n        if self.calibration_reduction == CalibrationReduction.EXPECTED:\n            # Calculate the weighted sum of absolute differences\n            # Handle zero denominator case (all bins empty) by returning NaN\n            denom = torch.sum(bin_counts, dim=-1)\n            zero_mask = denom == 0\n            safe_denom = torch.where(zero_mask, torch.ones_like(denom), denom)\n            result = torch.nansum(abs_diff * bin_counts, dim=-1) / safe_denom\n            result = torch.where(zero_mask, torch.full_like(result, float(\"nan\")), result)\n            return result\n        elif self.calibration_reduction == CalibrationReduction.AVERAGE:\n            return torch.nanmean(abs_diff, dim=-1)  # Average across all dimensions, ignoring nan\n        elif self.calibration_reduction == CalibrationReduction.MAXIMUM:\n            # Replace NaN with -inf for max computation, then restore NaN for all-NaN cases\n            abs_diff_for_max = torch.nan_to_num(abs_diff, nan=float(\"-inf\"))\n            max_vals = torch.max(abs_diff_for_max, dim=-1).values\n            # Restore NaN where all bins were empty (max is -inf)\n            max_vals = torch.where(max_vals == float(\"-inf\"), torch.full_like(max_vals, float(\"nan\")), max_vals)\n            return max_vals\n        else:\n 
           raise ValueError(f\"Unsupported calibration reduction: {self.calibration_reduction}\")\n\n    def aggregate(\n        self, reduction: MetricReduction | str | None = None\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Execute reduction logic for the output of `_compute_tensor`.\n\n        Args:\n            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to `self.metric_reduction`. if \"none\", will not\n                do reduction.\n\n        Returns:\n            If `get_not_nans` is True, returns a tuple (metric, not_nans), otherwise returns only the metric.\n        \"\"\"\n        data = self.get_buffer()\n        if not isinstance(data, torch.Tensor):\n            raise ValueError(\"the data to aggregate must be PyTorch Tensor.\")\n\n        # do metric reduction\n        f, not_nans = do_metric_reduction(data, reduction or self.metric_reduction)\n        return (f, not_nans) if self.get_not_nans else f\n"
  },
  {
    "path": "monai/metrics/confusion_matrix.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Sequence\n\nimport torch\n\nfrom monai.metrics.utils import do_metric_reduction, ignore_background\nfrom monai.utils import MetricReduction, ensure_tuple\n\nfrom .metric import CumulativeIterationMetric\n\n\nclass ConfusionMatrixMetric(CumulativeIterationMetric):\n    \"\"\"\n    Compute confusion matrix related metrics. This function supports to calculate all metrics mentioned in:\n    `Confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix>`_.\n    It can support both multi-classes and multi-labels classification and segmentation tasks.\n    `y_preds` is expected to have binarized predictions and `y` should be in one-hot format. You can use suitable transforms\n    in ``monai.transforms.post`` first to achieve binarized values.\n    The `include_background` parameter can be set to ``False`` for an instance to exclude\n    the first category (channel index 0) which is by convention assumed to be background. 
If the non-background\n    segmentations are small compared to the total image size they can get overwhelmed by the signal from the\n    background.\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        include_background: whether to include metric computation on the first channel of\n            the predicted output. Defaults to True.\n        metric_name: [``\"sensitivity\"``, ``\"specificity\"``, ``\"precision\"``, ``\"negative predictive value\"``,\n            ``\"miss rate\"``, ``\"fall out\"``, ``\"false discovery rate\"``, ``\"false omission rate\"``,\n            ``\"prevalence threshold\"``, ``\"threat score\"``, ``\"accuracy\"``, ``\"balanced accuracy\"``,\n            ``\"f1 score\"``, ``\"matthews correlation coefficient\"``, ``\"fowlkes mallows index\"``,\n            ``\"informedness\"``, ``\"markedness\"``]\n            Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned),\n            and you can also input those names instead.\n            Except for input only one metric, multiple metrics are also supported via input a sequence of metric names, such as\n            (\"sensitivity\", \"precision\", \"recall\"), if ``compute_sample`` is ``True``, multiple ``f`` and ``not_nans`` will be\n            returned with the same order as input names when calling the class.\n        compute_sample: when reducing, if ``True``, each sample's metric will be computed based on each confusion matrix first.\n            if ``False``, compute reduction on the confusion matrices first, defaults to ``False``.\n        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns [(metric, not_nans), ...]. If False,\n            aggregate() returns [metric, ...].\n            Here `not_nans` count the number of not nans for True Positive, False Positive, True Negative and False Negative.\n            Its shape depends on the shape of the metric, and it has one more dimension with size 4. For example, if the shape\n            of the metric is [3, 3], `not_nans` has the shape [3, 3, 4].\n\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        metric_name: Sequence[str] | str = \"hit_rate\",\n        compute_sample: bool = False,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n    ) -> None:\n        super().__init__()\n        self.include_background = include_background\n        self.metric_name = ensure_tuple(metric_name)\n        self.compute_sample = compute_sample\n        self.reduction = reduction\n        self.get_not_nans = get_not_nans\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        \"\"\"\n        Args:\n            y_pred: input data to compute. It must be one-hot format and first dim is batch.\n                The values should be binarized.\n            y: ground truth to compute the metric. 
It must be one-hot format and first dim is batch.\n                The values should be binarized.\n        Raises:\n            ValueError: when `y_pred` has less than two dimensions.\n        \"\"\"\n        # check dimension\n        dims = y_pred.ndimension()\n        if dims < 2:\n            raise ValueError(\"y_pred should have at least two dimensions.\")\n        if dims == 2 or (dims == 3 and y_pred.shape[-1] == 1):\n            if self.compute_sample:\n                warnings.warn(\"As for classification task, compute_sample should be False.\")\n                self.compute_sample = False\n\n        return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background)\n\n    def aggregate(\n        self, compute_sample: bool = False, reduction: MetricReduction | str | None = None\n    ) -> list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"\n        Execute reduction for the confusion matrix values.\n\n        Args:\n            compute_sample: when reducing, if ``True``, each sample's metric will be computed based on each confusion matrix first.\n                if ``False``, compute reduction on the confusion matrices first, defaults to ``False``.\n            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to `self.reduction`. 
if \"none\", will not do reduction.\n\n        \"\"\"\n        data = self.get_buffer()\n        if not isinstance(data, torch.Tensor):\n            raise ValueError(\"the data to aggregate must be PyTorch Tensor.\")\n\n        results: list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]] = []\n        for metric_name in self.metric_name:\n            if compute_sample or self.compute_sample:\n                sub_confusion_matrix = compute_confusion_matrix_metric(metric_name, data)\n                f, not_nans = do_metric_reduction(sub_confusion_matrix, reduction or self.reduction)\n            else:\n                f, not_nans = do_metric_reduction(data, reduction or self.reduction)\n                f = compute_confusion_matrix_metric(metric_name, f)\n            if self.get_not_nans:\n                results.append((f, not_nans))\n            else:\n                results.append(f)\n        return results\n\n\ndef get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor:\n    \"\"\"\n    Compute confusion matrix. A tensor with the shape [BC4] will be returned. Where, the third dimension\n    represents the number of true positive, false positive, true negative and false negative values for\n    each channel of each sample within the input batch. Where, B equals to the batch size and C equals to\n    the number of classes that need to be computed.\n\n    Args:\n        y_pred: input data to compute. It must be one-hot format and first dim is batch.\n            The values should be binarized.\n        y: ground truth to compute the metric. It must be one-hot format and first dim is batch.\n            The values should be binarized.\n        include_background: whether to include metric computation on the first channel of\n            the predicted output. 
Defaults to True.\n\n    Raises:\n        ValueError: when `y_pred` and `y` have different shapes.\n    \"\"\"\n\n    if not include_background:\n        y_pred, y = ignore_background(y_pred=y_pred, y=y)\n\n    if y.shape != y_pred.shape:\n        raise ValueError(f\"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.\")\n\n    # get confusion matrix related metric\n    batch_size, n_class = y_pred.shape[:2]\n    # convert to [BNS], where S is the number of pixels for one sample.\n    # As for classification tasks, S equals to 1.\n    y_pred = y_pred.reshape(batch_size, n_class, -1)\n    y = y.reshape(batch_size, n_class, -1)\n    tp = (y_pred + y) == 2\n    tn = (y_pred + y) == 0\n\n    tp = tp.sum(dim=[2]).float()\n    tn = tn.sum(dim=[2]).float()\n    p = y.sum(dim=[2]).float()\n    n = y.shape[-1] - p\n\n    fn = p - tp\n    fp = n - tn\n\n    return torch.stack([tp, fp, tn, fn], dim=-1)\n\n\ndef compute_confusion_matrix_metric(metric_name: str, confusion_matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    This function is used to compute confusion matrix related metric.\n\n    Args:\n        metric_name: [``\"sensitivity\"``, ``\"specificity\"``, ``\"precision\"``, ``\"negative predictive value\"``,\n            ``\"miss rate\"``, ``\"fall out\"``, ``\"false discovery rate\"``, ``\"false omission rate\"``,\n            ``\"prevalence threshold\"``, ``\"threat score\"``, ``\"accuracy\"``, ``\"balanced accuracy\"``,\n            ``\"f1 score\"``, ``\"matthews correlation coefficient\"``, ``\"fowlkes mallows index\"``,\n            ``\"informedness\"``, ``\"markedness\"``]\n            Some of the metrics have multiple aliases (as shown in the wikipedia page aforementioned),\n            and you can also input those names instead.\n        confusion_matrix: Please see the doc string of the function ``get_confusion_matrix`` for more details.\n\n    Raises:\n        ValueError: when the size of the last dimension of confusion_matrix is not 
4.\n        NotImplementedError: when specify a not implemented metric_name.\n\n    \"\"\"\n\n    metric = check_confusion_matrix_metric_name(metric_name)\n\n    input_dim = confusion_matrix.ndimension()\n    if input_dim == 1:\n        confusion_matrix = confusion_matrix.unsqueeze(dim=0)\n    if confusion_matrix.shape[-1] != 4:\n        raise ValueError(\"the size of the last dimension of confusion_matrix should be 4.\")\n\n    tp = confusion_matrix[..., 0]\n    fp = confusion_matrix[..., 1]\n    tn = confusion_matrix[..., 2]\n    fn = confusion_matrix[..., 3]\n    p = tp + fn\n    n = fp + tn\n    # calculate metric\n    numerator: torch.Tensor\n    denominator: torch.Tensor | float\n    nan_tensor = torch.tensor(float(\"nan\"), device=confusion_matrix.device)\n    if metric == \"tpr\":\n        numerator, denominator = tp, p\n    elif metric == \"tnr\":\n        numerator, denominator = tn, n\n    elif metric == \"ppv\":\n        numerator, denominator = tp, (tp + fp)\n    elif metric == \"npv\":\n        numerator, denominator = tn, (tn + fn)\n    elif metric == \"fnr\":\n        numerator, denominator = fn, p\n    elif metric == \"fpr\":\n        numerator, denominator = fp, n\n    elif metric == \"fdr\":\n        numerator, denominator = fp, (fp + tp)\n    elif metric == \"for\":\n        numerator, denominator = fn, (fn + tn)\n    elif metric == \"pt\":\n        tpr = torch.where(p > 0, tp / p, nan_tensor)\n        tnr = torch.where(n > 0, tn / n, nan_tensor)\n        numerator = torch.sqrt(tpr * (1.0 - tnr)) + tnr - 1.0\n        denominator = tpr + tnr - 1.0\n    elif metric == \"ts\":\n        numerator, denominator = tp, (tp + fn + fp)\n    elif metric == \"acc\":\n        numerator, denominator = (tp + tn), (p + n)\n    elif metric == \"ba\":\n        tpr = torch.where(p > 0, tp / p, nan_tensor)\n        tnr = torch.where(n > 0, tn / n, nan_tensor)\n        numerator, denominator = (tpr + tnr), 2.0\n    elif metric == \"f1\":\n        numerator, 
denominator = tp * 2.0, (tp * 2.0 + fn + fp)\n    elif metric == \"mcc\":\n        numerator = tp * tn - fp * fn\n        denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))\n    elif metric == \"fm\":\n        tpr = torch.where(p > 0, tp / p, nan_tensor)\n        ppv = torch.where((tp + fp) > 0, tp / (tp + fp), nan_tensor)\n        numerator = torch.sqrt(ppv * tpr)\n        denominator = 1.0\n    elif metric == \"bm\":\n        tpr = torch.where(p > 0, tp / p, nan_tensor)\n        tnr = torch.where(n > 0, tn / n, nan_tensor)\n        numerator = tpr + tnr - 1.0\n        denominator = 1.0\n    elif metric == \"mk\":\n        ppv = torch.where((tp + fp) > 0, tp / (tp + fp), nan_tensor)\n        npv = torch.where((tn + fn) > 0, tn / (tn + fn), nan_tensor)\n        numerator = ppv + npv - 1.0\n        denominator = 1.0\n    else:\n        raise NotImplementedError(\"the metric is not implemented.\")\n\n    if isinstance(denominator, torch.Tensor):\n        return torch.where(denominator != 0, numerator / denominator, nan_tensor)\n    return numerator / denominator\n\n\ndef check_confusion_matrix_metric_name(metric_name: str) -> str:\n    \"\"\"\n    There are many metrics related to confusion matrix, and some of the metrics have\n    more than one names. 
In addition, some of the names are very long.\n    Therefore, this function is used to check and simplify the name.\n\n    Returns:\n        Simplified metric name.\n\n    Raises:\n        NotImplementedError: when the metric is not implemented.\n    \"\"\"\n    metric_name = metric_name.replace(\" \", \"_\")\n    metric_name = metric_name.lower()\n    if metric_name in [\"sensitivity\", \"recall\", \"hit_rate\", \"true_positive_rate\", \"tpr\"]:\n        return \"tpr\"\n    if metric_name in [\"specificity\", \"selectivity\", \"true_negative_rate\", \"tnr\"]:\n        return \"tnr\"\n    if metric_name in [\"precision\", \"positive_predictive_value\", \"ppv\"]:\n        return \"ppv\"\n    if metric_name in [\"negative_predictive_value\", \"npv\"]:\n        return \"npv\"\n    if metric_name in [\"miss_rate\", \"false_negative_rate\", \"fnr\"]:\n        return \"fnr\"\n    if metric_name in [\"fall_out\", \"false_positive_rate\", \"fpr\"]:\n        return \"fpr\"\n    if metric_name in [\"false_discovery_rate\", \"fdr\"]:\n        return \"fdr\"\n    if metric_name in [\"false_omission_rate\", \"for\"]:\n        return \"for\"\n    if metric_name in [\"prevalence_threshold\", \"pt\"]:\n        return \"pt\"\n    if metric_name in [\"threat_score\", \"critical_success_index\", \"ts\", \"csi\"]:\n        return \"ts\"\n    if metric_name in [\"accuracy\", \"acc\"]:\n        return \"acc\"\n    if metric_name in [\"balanced_accuracy\", \"ba\"]:\n        return \"ba\"\n    if metric_name in [\"f1_score\", \"f1\"]:\n        return \"f1\"\n    if metric_name in [\"matthews_correlation_coefficient\", \"mcc\"]:\n        return \"mcc\"\n    if metric_name in [\"fowlkes_mallows_index\", \"fm\"]:\n        return \"fm\"\n    if metric_name in [\"informedness\", \"bookmaker_informedness\", \"bm\", \"youden_index\", \"youden\"]:\n        return \"bm\"\n    if metric_name in [\"markedness\", \"deltap\", \"mk\"]:\n        return \"mk\"\n    raise NotImplementedError(\"the metric 
is not implemented.\")\n"
  },
  {
    "path": "monai/metrics/cumulative_average.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom typing import Any\n\nimport torch\nimport torch.distributed as dist\n\nfrom monai.config import NdarrayOrTensor\n\n\nclass CumulativeAverage:\n    \"\"\"\n    A utility class to keep track of average values. For example during training/validation loop,\n    we need to accumulate the per-batch metrics and calculate the final average value for the whole dataset.\n    When training in multi-gpu environment, with DistributedDataParallel, it will average across the processes.\n\n    Example:\n\n    .. 
code-block:: python\n\n        from monai.metrics import CumulativeAverage\n\n        run_avg = CumulativeAverage()\n        batch_size = 8\n        for i in range(len(train_set)):\n            ...\n            val = calc_metric(x,y) #some metric value\n            run_avg.append(val, count=batch_size)\n\n        val_avg = run_avg.aggregate() #average value\n\n    \"\"\"\n\n    def __init__(self) -> None:\n        self.reset()\n\n    def reset(self) -> None:\n        \"\"\"\n        Reset all  stats\n        \"\"\"\n        self.val: torch.Tensor = None  # type: ignore\n        self.sum = torch.tensor(0, dtype=torch.float)\n        self.count = torch.tensor(0, dtype=torch.float)\n        self.is_distributed = dist.is_available() and dist.is_initialized()\n\n    def get_current(self, to_numpy: bool = True) -> NdarrayOrTensor:\n        \"\"\"\n        returns the most recent value (averaged across processes)\n\n        Args:\n            to_numpy: whether to convert to numpy array. Defaults to True\n        \"\"\"\n        if self.val is None:\n            return 0\n\n        val: NdarrayOrTensor\n        val = self.val.clone()\n        val[~torch.isfinite(val)] = 0\n\n        if self.is_distributed:\n            val = val / dist.get_world_size()\n            dist.all_reduce(val)\n\n        if to_numpy:\n            val = val.cpu().numpy()\n\n        return val\n\n    def aggregate(self, to_numpy: bool = True) -> NdarrayOrTensor:\n        \"\"\"\n        returns the total average value (averaged across processes)\n\n        Args:\n            to_numpy: whether to convert to numpy array. 
Defaults to True\n        \"\"\"\n        if self.val is None:\n            return 0\n\n        sum = self.sum\n        count = self.count\n\n        if self.is_distributed:\n            sum = sum.to(self.val, copy=True)\n            count = count.to(self.val, copy=True)\n            dist.all_reduce(sum)\n            dist.all_reduce(count)\n\n        val: NdarrayOrTensor\n        val = torch.where(count > 0, sum / count, sum)\n\n        if to_numpy:\n            val = val.cpu().numpy()\n        return val\n\n    def append(self, val: Any, count: Any | None = 1) -> None:\n        \"\"\"\n        Append with a new value, and an optional count. Any data type is supported that is convertable\n            with torch.as_tensor() e.g. number, list, numpy array, or Tensor.\n\n        Args:\n            val: value (e.g. number, list, numpy array or Tensor) to keep track of\n            count: count (e.g. number, list, numpy array or Tensor), to update the contribution count\n\n        For example:\n            # a simple constant tracking\n            avg = CumulativeAverage()\n            avg.append(0.6)\n            avg.append(0.8)\n            print(avg.aggregate()) #prints 0.7\n\n            # an array tracking, e.g. metrics from 3 classes\n            avg= CumulativeAverage()\n            avg.append([0.2, 0.4, 0.4])\n            avg.append([0.4, 0.6, 0.4])\n            print(avg.aggregate()) #prints [0.3, 0.5. 
0.4]\n\n            # different contributions / counts\n            avg= CumulativeAverage()\n            avg.append(1, count=4) #avg metric 1 coming from a batch of 4\n            avg.append(2, count=6) #avg metric 2 coming from a batch of 6\n            print(avg.aggregate()) #prints 1.6 == (1*4 +2*6)/(4+6)\n\n            # different contributions / counts\n            avg= CumulativeAverage()\n            avg.append([0.5, 0.5, 0], count=[1, 1, 0]) # last elements count is zero to ignore it\n            avg.append([0.5, 0.5, 0.5], count=[1, 1, 1]) #\n            print(avg.aggregate()) #prints [0.5, 0.5, 0,5] == ([0.5, 0.5, 0] + [0.5, 0.5, 0.5]) / ([1, 1, 0] + [1, 1, 1])\n\n        \"\"\"\n        self.val = torch.as_tensor(val, dtype=torch.float)\n        if self.val.requires_grad:\n            self.val = self.val.detach().clone()\n\n        count = torch.as_tensor(count, dtype=torch.float, device=\"cpu\")\n        if count.ndim > 0 and count.shape != self.val.shape:\n            raise ValueError(\n                f\"Count shape must match val shape, unless count is a single number: {count} val {self.val.cpu()}\"\n            )\n\n        val = count * self.val.cpu()\n\n        # account for possible non-finite numbers in val and replace them with 0s\n        nfin = torch.isfinite(val)\n        if not torch.all(nfin):\n            warnings.warn(f\"non-finite inputs received: val: {val}, count: {count}\")\n            count = torch.where(nfin, count, torch.zeros_like(count))\n            val = torch.where(nfin, val, torch.zeros_like(val))\n\n        self.count = self.count + count\n        self.sum = self.sum + val\n"
  },
  {
    "path": "monai/metrics/f_beta_score.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\n\nfrom monai.metrics.utils import do_metric_reduction, ignore_background\nfrom monai.utils import MetricReduction\n\nfrom .metric import CumulativeIterationMetric\n\n\nclass FBetaScore(CumulativeIterationMetric):\n\n    def __init__(\n        self,\n        beta: float = 1.0,\n        include_background: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n    ) -> None:\n        super().__init__()\n        self.beta = beta\n        self.include_background = include_background\n        self.reduction = reduction\n        self.get_not_nans = get_not_nans\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        if y_pred.ndimension() < 2:\n            raise ValueError(\"y_pred should have at least two dimensions.\")\n\n        return get_f_beta_score(y_pred=y_pred, y=y, include_background=self.include_background)\n\n    def aggregate(\n        self, compute_sample: bool = False, reduction: MetricReduction | str | None = None\n    ) -> Sequence[torch.Tensor | tuple[torch.Tensor, torch.Tensor]]:\n        data = self.get_buffer()\n        if not isinstance(data, torch.Tensor):\n            raise ValueError(\"the data to aggregate must be PyTorch 
Tensor.\")\n\n        results: list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]] = []\n        f, not_nans = do_metric_reduction(data, reduction or self.reduction)\n        f = compute_f_beta_score(f, self.beta)\n        if self.get_not_nans:\n            results.append((f, not_nans))\n        else:\n            results.append(f)\n\n        return results\n\n\ndef get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True) -> torch.Tensor:\n    if not include_background:\n        y_pred, y = ignore_background(y_pred=y_pred, y=y)\n\n    if y.shape != y_pred.shape:\n        raise ValueError(f\"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.\")\n\n    # get confusion matrix related metric\n    batch_size, n_class = y_pred.shape[:2]\n    # convert to [BNS], where S is the number of pixels for one sample.\n    # As for classification tasks, S equals to 1.\n    y_pred = y_pred.view(batch_size, n_class, -1)\n    y = y.view(batch_size, n_class, -1)\n    tp = (y_pred + y) == 2\n    tn = (y_pred + y) == 0\n\n    tp = tp.sum(dim=[2]).float()\n    tn = tn.sum(dim=[2]).float()\n    p = y.sum(dim=[2]).float()\n    n = y.shape[-1] - p\n\n    fn = p - tp\n    fp = n - tn\n\n    return torch.stack([tp, fp, tn, fn], dim=-1)\n\n\ndef compute_f_beta_score(confusion_matrix: torch.Tensor, beta: float) -> torch.Tensor:\n    input_dim = confusion_matrix.ndimension()\n    if input_dim == 1:\n        confusion_matrix = confusion_matrix.unsqueeze(dim=0)\n    if confusion_matrix.shape[-1] != 4:\n        raise ValueError(\"the size of the last dimension of confusion_matrix should be 4.\")\n\n    tp = confusion_matrix[..., 0]\n    fp = confusion_matrix[..., 1]\n    # tn = confusion_matrix[..., 2]\n    fn = confusion_matrix[..., 3]\n\n    nan_tensor = torch.tensor(float(\"nan\"), device=confusion_matrix.device)\n    numerator, denominator = (1.0 + beta**2) * tp, ((1.0 + beta**2) * tp + beta**2 * fn + fp)\n\n    if isinstance(denominator, 
torch.Tensor):\n        return torch.where(denominator != 0, numerator / denominator, nan_tensor)\n    return numerator / denominator\n"
  },
  {
    "path": "monai/metrics/fid.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport numpy as np\nimport torch\n\nfrom monai.metrics.metric import Metric\nfrom monai.utils import optional_import\n\nscipy, _ = optional_import(\"scipy\")\n\n\nclass FIDMetric(Metric):\n    \"\"\"\n    Frechet Inception Distance (FID). The FID calculates the distance between two distributions of feature vectors.\n    Based on: Heusel M. et al. \"Gans trained by a two time-scale update rule converge to a local nash equilibrium.\"\n    https://arxiv.org/abs/1706.08500. The inputs for this metric should be two groups of feature vectors (with format\n    (number images, number of features)) extracted from a pretrained network.\n\n    Originally, it was proposed to use the activations of the pool_3 layer of an Inception v3 pretrained with Imagenet.\n    However, others networks pretrained on medical datasets can be used as well (for example, RadImageNwt for 2D and\n    MedicalNet for 3D images). 
If the chosen model output is not a scalar, a global spatia average pooling should be\n    used.\n    \"\"\"\n\n    def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n        return get_fid_score(y_pred, y)\n\n\ndef get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n    \"\"\"Computes the FID score metric on a batch of feature vectors.\n\n    Args:\n        y_pred: feature vectors extracted from a pretrained network run on generated images.\n        y: feature vectors extracted from a pretrained network run on images from the real data distribution.\n    \"\"\"\n    y = y.double()\n    y_pred = y_pred.double()\n\n    if y.ndimension() > 2:\n        raise ValueError(\"Inputs should have (number images, number of features) shape.\")\n\n    mu_y_pred = torch.mean(y_pred, dim=0)\n    sigma_y_pred = _cov(y_pred, rowvar=False)\n    mu_y = torch.mean(y, dim=0)\n    sigma_y = _cov(y, rowvar=False)\n\n    return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y)\n\n\ndef _cov(input_data: torch.Tensor, rowvar: bool = True) -> torch.Tensor:\n    \"\"\"\n    Estimate a covariance matrix of the variables.\n\n    Args:\n        input_data: A 1-D or 2-D array containing multiple variables and observations. 
Each row of `m` represents a variable,\n            and each column a single observation of all those variables.\n        rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns.\n            Otherwise, the relationship is transposed: each column represents a variable, while the rows contain\n            observations.\n    \"\"\"\n    if input_data.dim() < 2:\n        input_data = input_data.view(1, -1)\n\n    if not rowvar and input_data.size(0) != 1:\n        input_data = input_data.t()\n\n    factor = 1.0 / (input_data.size(1) - 1)\n    input_data = input_data - torch.mean(input_data, dim=1, keepdim=True)\n    return factor * input_data.matmul(input_data.t()).squeeze()\n\n\ndef _sqrtm(input_data: torch.Tensor) -> torch.Tensor:\n    \"\"\"Compute the square root of a matrix.\"\"\"\n    scipy_res, _ = scipy.linalg.sqrtm(input_data.detach().cpu().numpy().astype(np.float64), disp=False)\n    return torch.from_numpy(scipy_res)\n\n\ndef compute_frechet_distance(\n    mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, epsilon: float = 1e-6\n) -> torch.Tensor:\n    \"\"\"The Frechet distance between multivariate normal distributions.\"\"\"\n    diff = mu_x - mu_y\n\n    covmean = _sqrtm(sigma_x.mm(sigma_y))\n\n    # Product might be almost singular\n    if not torch.isfinite(covmean).all():\n        print(f\"FID calculation produces singular product; adding {epsilon} to diagonal of covariance estimates\")\n        offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon\n        covmean = _sqrtm((sigma_x + offset).mm(sigma_y + offset))\n\n    # Numerical error might give slight imaginary component\n    if torch.is_complex(covmean):\n        if not torch.allclose(torch.diagonal(covmean).imag, torch.tensor(0, dtype=torch.double), atol=1e-3):\n            raise ValueError(f\"Imaginary component {torch.max(torch.abs(covmean.imag))} too high.\")\n        covmean = 
covmean.real\n\n    tr_covmean = torch.trace(covmean)\n    return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean\n"
  },
  {
    "path": "monai/metrics/froc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import Any, cast\n\nimport numpy as np\nimport torch\n\nfrom monai.config import NdarrayOrTensor\n\n\ndef compute_fp_tp_probs_nd(\n    probs: NdarrayOrTensor,\n    coords: NdarrayOrTensor,\n    evaluation_mask: NdarrayOrTensor,\n    labels_to_exclude: list | None = None,\n) -> tuple[NdarrayOrTensor, NdarrayOrTensor, int]:\n    \"\"\"\n    This function is modified from the official evaluation code of\n    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to distinguish\n    true positive and false positive predictions. 
A true positive prediction is defined when\n    the detection point is within the annotated ground truth region.\n\n    Args:\n        probs: an array with shape (n,) that represents the probabilities of the detections.\n            Where, n is the number of predicted detections.\n        coords: an array with shape (n, n_dim) that represents the coordinates of the detections.\n            The dimensions must be in the same order as in `evaluation_mask`.\n        evaluation_mask: the ground truth mask for evaluation.\n        labels_to_exclude: labels in this list will not be counted for metric calculation.\n\n    Returns:\n        fp_probs: an array that contains the probabilities of the false positive detections.\n        tp_probs: an array that contains the probabilities of the True positive detections.\n        num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.\n\n    \"\"\"\n    if not (len(probs) == len(coords)):\n        raise ValueError(f\"the length of probs {probs.shape}, should be the same as of coords {coords.shape}.\")\n    if not (len(coords.shape) > 1 and coords.shape[1] == len(evaluation_mask.shape)):\n        raise ValueError(\n            f\"coords {coords.shape} need to represent the same number of dimensions as mask {evaluation_mask.shape}.\"\n        )\n\n    if isinstance(probs, torch.Tensor):\n        probs = probs.detach().cpu().numpy()\n    if isinstance(coords, torch.Tensor):\n        coords = coords.detach().cpu().numpy()\n    if isinstance(evaluation_mask, torch.Tensor):\n        evaluation_mask = evaluation_mask.detach().cpu().numpy()\n\n    if labels_to_exclude is None:\n        labels_to_exclude = []\n\n    max_label = np.max(evaluation_mask)\n    tp_probs = np.zeros((max_label,), dtype=np.float32)\n\n    hittedlabel = evaluation_mask[tuple(coords.T)]\n    fp_probs = probs[np.where(hittedlabel == 0)]\n    for i in range(1, max_label + 1):\n        if i not in labels_to_exclude and 
i in hittedlabel:\n            tp_probs[i - 1] = probs[np.where(hittedlabel == i)].max()\n\n    num_targets = max_label - len(labels_to_exclude)\n    return fp_probs, tp_probs, cast(int, num_targets)\n\n\ndef compute_fp_tp_probs(\n    probs: NdarrayOrTensor,\n    y_coord: NdarrayOrTensor,\n    x_coord: NdarrayOrTensor,\n    evaluation_mask: NdarrayOrTensor,\n    labels_to_exclude: list | None = None,\n    resolution_level: int = 0,\n) -> tuple[NdarrayOrTensor, NdarrayOrTensor, int]:\n    \"\"\"\n    This function is modified from the official evaluation code of\n    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to distinguish\n    true positive and false positive predictions. A true positive prediction is defined when\n    the detection point is within the annotated ground truth region.\n\n    Args:\n        probs: an array with shape (n,) that represents the probabilities of the detections.\n            Where, n is the number of predicted detections.\n        y_coord: an array with shape (n,) that represents the Y-coordinates of the detections.\n        x_coord: an array with shape (n,) that represents the X-coordinates of the detections.\n        evaluation_mask: the ground truth mask for evaluation.\n        labels_to_exclude: labels in this list will not be counted for metric calculation.\n        resolution_level: the level at which the evaluation mask is made.\n\n    Returns:\n        fp_probs: an array that contains the probabilities of the false positive detections.\n        tp_probs: an array that contains the probabilities of the True positive detections.\n        num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.\n\n    \"\"\"\n    if isinstance(y_coord, torch.Tensor):\n        y_coord = y_coord.detach().cpu().numpy()\n    if isinstance(x_coord, torch.Tensor):\n        x_coord = x_coord.detach().cpu().numpy()\n\n    y_coord = (y_coord / pow(2, 
resolution_level)).astype(int)\n    x_coord = (x_coord / pow(2, resolution_level)).astype(int)\n\n    stacked = np.stack([y_coord, x_coord], axis=1)\n\n    return compute_fp_tp_probs_nd(\n        probs=probs, coords=stacked, evaluation_mask=evaluation_mask, labels_to_exclude=labels_to_exclude\n    )\n\n\ndef compute_froc_curve_data(\n    fp_probs: np.ndarray | torch.Tensor, tp_probs: np.ndarray | torch.Tensor, num_targets: int, num_images: int\n) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"\n    This function is modified from the official evaluation code of\n    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to compute\n    the required data for plotting the Free Response Operating Characteristic (FROC) curve.\n\n    Args:\n        fp_probs: an array that contains the probabilities of the false positive detections for all\n            images under evaluation.\n        tp_probs: an array that contains the probabilities of the True positive detections for all\n            images under evaluation.\n        num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.\n        num_images: the number of images under evaluation.\n\n    \"\"\"\n    if not isinstance(fp_probs, type(tp_probs)):\n        raise AssertionError(\"fp and tp probs should have same type.\")\n    if isinstance(fp_probs, torch.Tensor):\n        fp_probs = fp_probs.detach().cpu().numpy()\n    if isinstance(tp_probs, torch.Tensor):\n        tp_probs = tp_probs.detach().cpu().numpy()\n\n    total_fps, total_tps = [], []\n    all_probs = sorted(set(list(fp_probs) + list(tp_probs)))\n    for thresh in all_probs[1:]:\n        total_fps.append((fp_probs >= thresh).sum())\n        total_tps.append((tp_probs >= thresh).sum())\n    total_fps.append(0)\n    total_tps.append(0)\n    fps_per_image = np.asarray(total_fps) / float(num_images)\n    total_sensitivity = np.asarray(total_tps) / float(num_targets)\n    return fps_per_image, 
total_sensitivity\n\n\ndef compute_froc_score(\n    fps_per_image: np.ndarray, total_sensitivity: np.ndarray, eval_thresholds: tuple = (0.25, 0.5, 1, 2, 4, 8)\n) -> Any:\n    \"\"\"\n    This function is modified from the official evaluation code of\n    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to compute\n    the challenge's second evaluation metric, which is defined as the average sensitivity at\n    the predefined false positive rates per whole slide image.\n\n    Args:\n        fps_per_image: the average number of false positives per image for different thresholds.\n        total_sensitivity: sensitivities (true positive rates) for different thresholds.\n        eval_thresholds: the false positive rates for calculating the average sensitivity. Defaults\n            to (0.25, 0.5, 1, 2, 4, 8) which is the same as the CAMELYON 16 Challenge.\n\n    \"\"\"\n    interp_sens = np.interp(eval_thresholds, fps_per_image[::-1], total_sensitivity[::-1])\n    return np.mean(interp_sens)\n"
  },
  {
    "path": "monai/metrics/generalized_dice.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom monai.metrics.utils import do_metric_reduction, ignore_background\nfrom monai.utils import MetricReduction, Weight, deprecated_arg, look_up_option\n\nfrom .metric import CumulativeIterationMetric\n\n\nclass GeneralizedDiceScore(CumulativeIterationMetric):\n    \"\"\"\n    Compute the Generalized Dice Score metric between tensors.\n\n    This metric is the complement of the Generalized Dice Loss defined in:\n    Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning\n    loss function for highly unbalanced segmentations. DLMIA 2017.\n\n    The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        include_background: Whether to include the background class (assumed to be in channel 0) in the\n            score computation. Defaults to True.\n        reduction: Define mode of reduction to the metrics. Available reduction modes:\n            {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n            Default value is changed from `MetricReduction.MEAN_BATCH` to `MetricReduction.MEAN` in v1.5.0.\n            Old versions computed `mean` when `mean_batch` was provided due to bug in reduction.\n        weight_type: {``\"square\"``, ``\"simple\"``, ``\"uniform\"``}. Type of function to transform\n            ground truth volume into a weight factor. Defaults to ``\"square\"``.\n\n    Raises:\n        ValueError: When the `reduction` is not one of MetricReduction enum.\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        weight_type: Weight | str = Weight.SQUARE,\n    ) -> None:\n        super().__init__()\n        self.include_background = include_background\n        self.reduction = look_up_option(reduction, MetricReduction)\n        self.weight_type = look_up_option(weight_type, Weight)\n        self.sum_over_classes = self.reduction in {\n            MetricReduction.SUM,\n            MetricReduction.MEAN,\n            MetricReduction.MEAN_CHANNEL,\n            MetricReduction.SUM_CHANNEL,\n        }\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        \"\"\"\n        Computes the Generalized Dice Score and returns a tensor with its per image values.\n\n        Args:\n            y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,\n                where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.\n            y (torch.Tensor): Binarized ground-truth. 
It must be in one-hot format and have the same shape as `y_pred`.\n\n        Returns:\n            torch.Tensor: Generalized Dice Score averaged across batch and class\n\n        Raises:\n            ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.\n        \"\"\"\n        return compute_generalized_dice(\n            y_pred=y_pred,\n            y=y,\n            include_background=self.include_background,\n            weight_type=self.weight_type,\n            sum_over_classes=self.sum_over_classes,\n        )\n\n    @deprecated_arg(\n        \"reduction\",\n        since=\"1.3.3\",\n        removed=\"1.7.0\",\n        msg_suffix=\"Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute\",\n    )\n    def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor:\n        \"\"\"\n        Execute reduction logic for the output of `compute_generalized_dice`.\n\n        Returns:\n            torch.Tensor: Aggregated metric value.\n\n        Raises:\n            ValueError: If the data to aggregate is not a PyTorch Tensor.\n        \"\"\"\n        data = self.get_buffer()\n        if not isinstance(data, torch.Tensor):\n            raise ValueError(\"The data to aggregate must be a PyTorch Tensor.\")\n\n        # Do metric reduction and return\n        f, _ = do_metric_reduction(data, self.reduction)\n\n        return f\n\n\ndef compute_generalized_dice(\n    y_pred: torch.Tensor,\n    y: torch.Tensor,\n    include_background: bool = True,\n    weight_type: Weight | str = Weight.SQUARE,\n    sum_over_classes: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Computes the Generalized Dice Score and returns a tensor with its per image values.\n\n    Args:\n        y_pred (torch.Tensor): Binarized segmentation model output. 
It should be binarized, in one-hot format\n            and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the\n            remaining are the spatial dimensions.\n        y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.\n        include_background: Whether to include score computation on the first channel of the\n            predicted output. Defaults to True.\n        weight_type (Union[Weight, str], optional): {``\"square\"``, ``\"simple\"``, ``\"uniform\"``}. Type of function to\n            transform ground truth volume into a weight factor. Defaults to ``\"square\"``.\n        sum_over_classes (bool): Whether to sum the numerator and denominator across all classes before the final\n            computation. Defaults to False.\n\n    Returns:\n        torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes],\n            or with the shape [batch_size, 1] when `sum_over_classes` is True.\n\n    Raises:\n        ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,\n            or `y_pred` and `y` don't have the same shape.\n    \"\"\"\n    # Ensure tensors have at least 3 dimensions and have the same shape\n    dims = y_pred.dim()\n    if dims < 3:\n        raise ValueError(f\"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.\")\n    if y.shape != y_pred.shape:\n        raise ValueError(f\"y_pred - {y_pred.shape} - and y - {y.shape} - should have the same shapes.\")\n\n    # Ignore background, if needed\n    if not include_background:\n        y_pred, y = ignore_background(y_pred=y_pred, y=y)\n\n    # Reducing only spatial dimensions (not batch nor channels), compute the intersection and non-weighted denominator\n    reduce_axis = list(range(2, y_pred.dim()))\n    intersection = torch.sum(y * y_pred, dim=reduce_axis)\n    y_o = torch.sum(y, dim=reduce_axis)\n    y_pred_o = torch.sum(y_pred, dim=reduce_axis)\n    denominator = y_o + 
y_pred_o\n\n    # Set the class weights\n    weight_type = look_up_option(weight_type, Weight)\n    if weight_type == Weight.SIMPLE:\n        w = torch.reciprocal(y_o.float())\n    elif weight_type == Weight.SQUARE:\n        w = torch.reciprocal(y_o.float() * y_o.float())\n    else:\n        w = torch.ones_like(y_o.float())\n\n    # Replace infinite values for non-appearing classes by the maximum weight\n    for b in w:\n        infs = torch.isinf(b)\n        b[infs] = 0\n        b[infs] = torch.max(b)\n\n    # Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True\n    if sum_over_classes:\n        numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True)\n        denom = (denominator * w).sum(dim=1, keepdim=True)\n        y_pred_o = y_pred_o.sum(dim=-1, keepdim=True)\n    else:\n        numer = 2.0 * (intersection * w)\n        denom = denominator * w\n        y_pred_o = y_pred_o\n\n    # Compute the score\n    generalized_dice_score = numer / denom\n\n    # Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.\n    # Where denom == 0 but the prediction volume is not 0, score is 0\n    denom_zeros = denom == 0\n    generalized_dice_score[denom_zeros] = torch.where(\n        (y_pred_o == 0)[denom_zeros],\n        torch.tensor(1.0, device=generalized_dice_score.device),\n        torch.tensor(0.0, device=generalized_dice_score.device),\n    )\n\n    return generalized_dice_score\n"
  },
  {
    "path": "monai/metrics/hausdorff_distance.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing\nfrom monai.utils import MetricReduction, convert_data_type\n\nfrom .metric import CumulativeIterationMetric\n\n__all__ = [\"HausdorffDistanceMetric\", \"compute_hausdorff_distance\"]\n\n\nclass HausdorffDistanceMetric(CumulativeIterationMetric):\n    \"\"\"\n    Compute Hausdorff Distance between two tensors. It can support both multi-classes and multi-labels tasks.\n    It supports both directed and non-directed Hausdorff distance calculation. In addition, specify the `percentile`\n    parameter can get the percentile of the distance. 
Input `y_pred` is compared with ground truth `y`.\n    `y_pred` is expected to have binarized predictions and `y` should be in one-hot format.\n    You can use suitable transforms in ``monai.transforms.post`` first to achieve binarized values.\n    `y_pred` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]).\n    The implementation refers to `DeepMind's implementation <https://github.com/deepmind/surface-distance>`_.\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        include_background: whether to include distance computation on the first channel of\n            the predicted output. Defaults to ``False``.\n        distance_metric: [``\"euclidean\"``, ``\"chessboard\"``, ``\"taxicab\"``]\n            the metric used to compute surface distance. Defaults to ``\"euclidean\"``.\n        percentile: an optional float number between 0 and 100. If specified, the corresponding\n            percentile of the Hausdorff Distance rather than the maximum result will be returned.\n            Defaults to ``None``.\n        directed: whether to calculate directed Hausdorff distance. Defaults to ``False``.\n        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n            Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = False,\n        distance_metric: str = \"euclidean\",\n        percentile: float | None = None,\n        directed: bool = False,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n    ) -> None:\n        super().__init__()\n        self.include_background = include_background\n        self.distance_metric = distance_metric\n        self.percentile = percentile\n        self.directed = directed\n        self.reduction = reduction\n        self.get_not_nans = get_not_nans\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor:  # type: ignore[override]\n        \"\"\"\n        Args:\n            y_pred: input data to compute, typical segmentation model output.\n                It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values\n                should be binarized.\n            y: ground truth to compute the distance. It must be one-hot format and first dim is batch.\n                The values should be binarized.\n            kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric.\n                ``spacing``: spacing of pixel (or voxel). This parameter is relevant only\n                if ``distance_metric`` is set to ``\"euclidean\"``.\n                If a single number, isotropic spacing with that value is used for all images in the batch. 
If a sequence of numbers,\n                the length of the sequence must be equal to the image dimensions.\n                This spacing will be used for all images in the batch.\n                If a sequence of sequences, the length of the outer sequence must be equal to the batch size.\n                If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,\n                else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used\n                for all images in batch. Defaults to ``None``.\n\n        Raises:\n            ValueError: when `y_pred` has less than three dimensions.\n        \"\"\"\n        dims = y_pred.ndimension()\n        if dims < 3:\n            raise ValueError(\"y_pred should have at least three dimensions.\")\n\n        # compute (BxC) for each channel for each batch\n        return compute_hausdorff_distance(\n            y_pred=y_pred,\n            y=y,\n            include_background=self.include_background,\n            distance_metric=self.distance_metric,\n            percentile=self.percentile,\n            directed=self.directed,\n            spacing=kwargs.get(\"spacing\"),\n        )\n\n    def aggregate(\n        self, reduction: MetricReduction | str | None = None\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Execute reduction logic for the output of `compute_hausdorff_distance`.\n\n        Args:\n            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to `self.reduction`. 
if \"none\", will not do reduction.\n\n        \"\"\"\n        data = self.get_buffer()\n        if not isinstance(data, torch.Tensor):\n            raise ValueError(\"the data to aggregate must be PyTorch Tensor.\")\n\n        # do metric reduction\n        f, not_nans = do_metric_reduction(data, reduction or self.reduction)\n        return (f, not_nans) if self.get_not_nans else f\n\n\ndef compute_hausdorff_distance(\n    y_pred: np.ndarray | torch.Tensor,\n    y: np.ndarray | torch.Tensor,\n    include_background: bool = False,\n    distance_metric: str = \"euclidean\",\n    percentile: float | None = None,\n    directed: bool = False,\n    spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None,\n) -> torch.Tensor:\n    \"\"\"\n    Compute the Hausdorff distance.\n\n    Args:\n        y_pred: input data to compute, typical segmentation model output.\n            It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values\n            should be binarized.\n        y: ground truth to compute the distance. It must be one-hot format and first dim is batch.\n            The values should be binarized.\n        include_background: whether to include distance computation on the first channel of\n            the predicted output. Defaults to ``False``.\n        distance_metric: [``\"euclidean\"``, ``\"chessboard\"``, ``\"taxicab\"``]\n            the metric used to compute surface distance. Defaults to ``\"euclidean\"``.\n        percentile: an optional float number between 0 and 100. If specified, the corresponding\n            percentile of the Hausdorff Distance rather than the maximum result will be returned.\n            Defaults to ``None``.\n        directed: whether to calculate directed Hausdorff distance. Defaults to ``False``.\n        spacing: spacing of pixel (or voxel). 
This parameter is relevant only if ``distance_metric`` is set to ``\"euclidean\"``.\n            If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers,\n            the length of the sequence must be equal to the image dimensions. This spacing will be used for all images in the batch.\n            If a sequence of sequences, the length of the outer sequence must be equal to the batch size.\n            If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,\n            else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used\n            for all images in batch. Defaults to ``None``.\n    \"\"\"\n\n    if not include_background:\n        y_pred, y = ignore_background(y_pred=y_pred, y=y)\n    y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0]\n    y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0]\n\n    if y.shape != y_pred.shape:\n        raise ValueError(f\"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.\")\n\n    batch_size, n_class = y_pred.shape[:2]\n    hd = torch.empty((batch_size, n_class), dtype=torch.float, device=y_pred.device)\n\n    img_dim = y_pred.ndim - 2\n    spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)\n\n    for b, c in np.ndindex(batch_size, n_class):\n        _, distances, _ = get_edge_surface_distance(\n            y_pred[b, c],\n            y[b, c],\n            distance_metric=distance_metric,\n            spacing=spacing_list[b],\n            symmetric=not directed,\n            class_index=c,\n        )\n        percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances]\n        max_distance = torch.max(torch.stack(percentile_distances))\n        hd[b, c] = max_distance\n    return hd\n\n\ndef 
_compute_percentile_hausdorff_distance(\n    surface_distance: torch.Tensor, percentile: float | None = None\n) -> torch.Tensor:\n    \"\"\"\n    Reduce a tensor of surface distances to a single Hausdorff distance value.\n\n    Returns NaN for an empty tensor, the maximum distance when `percentile` is ``None`` or 0,\n    and otherwise the `percentile`-th quantile of the distances.\n    \"\"\"\n\n    # neither pred nor gt has foreground\n    if surface_distance.shape == (0,):\n        return torch.tensor(np.nan, dtype=torch.float, device=surface_distance.device)\n\n    if not percentile:\n        return surface_distance.max()\n\n    if 0 <= percentile <= 100:\n        return torch.quantile(surface_distance, percentile / 100)\n    raise ValueError(f\"percentile should be a value between 0 and 100, get {percentile}.\")\n"
  },
  {
    "path": "monai/metrics/loss_metric.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nimport torch\nfrom torch.nn.modules.loss import _Loss\n\nfrom monai.config import TensorOrList\nfrom monai.metrics.utils import do_metric_reduction\nfrom monai.utils import MetricReduction\n\nfrom .metric import CumulativeIterationMetric\n\n\nclass LossMetric(CumulativeIterationMetric):\n    \"\"\"\n    A wrapper to make ``loss_fn`` available as a cumulative metric. That is, the loss values computed from\n    mini-batches can be combined in the ``reduction`` mode across multiple iterations, as a quantitative measurement\n    of a model.\n\n    Example:\n\n    .. 
code-block:: python\n\n        import torch\n        from monai.losses import DiceLoss\n        from monai.metrics import LossMetric\n\n        dice_loss = DiceLoss(include_background=True)\n        loss_metric = LossMetric(loss_fn=dice_loss)\n\n        # first iteration\n        y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])  # shape [batch=1, channel=1, 2, 2]\n        y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])  # shape [batch=1, channel=1, 2, 2]\n        loss_metric(y_pred, y)\n\n        # second iteration\n        y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])  # shape [batch=1, channel=1, 2, 2]\n        y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])  # shape [batch=1, channel=1, 2, 2]\n        loss_metric(y_pred, y)\n\n        # aggregate\n        print(loss_metric.aggregate(reduction=\"none\"))  # tensor([[0.2000], [0.5000]]) (shape [batch=2, channel=1])\n\n        # reset\n        loss_metric.reset()\n        print(loss_metric.aggregate())\n\n\n    Args:\n        loss_fn: a callable function that takes ``y_pred`` and optionally ``y`` as input (in the \"batch-first\" format),\n            returns a \"batch-first\" tensor of loss values.\n        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n            Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.\n\n    \"\"\"\n\n    def __init__(\n        self, loss_fn: _Loss, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False\n    ) -> None:\n        super().__init__()\n        self.loss_fn = loss_fn\n        self.reduction = reduction\n        self.get_not_nans = get_not_nans\n\n    def aggregate(\n        self, reduction: MetricReduction | str | None = None\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Returns the aggregated loss value across multiple iterations.\n\n        Args:\n            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to `self.reduction`. 
if \"none\", will not do reduction.\n        \"\"\"\n        data = self.get_buffer()\n        if data is None:\n            return (torch.tensor(0.0), torch.tensor(0.0)) if self.get_not_nans else torch.tensor(0.0)\n        f, not_nans = do_metric_reduction(data, reduction or self.reduction)\n        return (f, not_nans) if self.get_not_nans else f\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None, **kwargs: Any) -> TensorOrList:\n        \"\"\"\n        Input `y_pred` is compared with ground truth `y`.\n        Both `y_pred` and `y` are expected to be a batch-first Tensor (BC[HWD]).\n\n        Returns:\n             a tensor with shape (BC[HWD]), or a list of tensors, each tensor with shape (C[HWD]).\n        \"\"\"\n        iter_loss: TensorOrList = self.loss_fn(y_pred) if y is None else self.loss_fn(y_pred, y)\n        if isinstance(iter_loss, torch.Tensor):\n            while iter_loss.dim() < 2:\n                iter_loss = iter_loss[None]\n        # to be compatible with `Cumulative`, iter_loss should at least have a batch dim.\n        # to be compatible with `do_metric_reduction`, iter_loss should at least have a batch and a channel dim.\n        return iter_loss\n"
  },
  {
    "path": "monai/metrics/meandice.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom monai.metrics.utils import do_metric_reduction\nfrom monai.utils import MetricReduction, deprecated_arg\n\nfrom .metric import CumulativeIterationMetric\n\n__all__ = [\"DiceMetric\", \"compute_dice\", \"DiceHelper\"]\n\n\nclass DiceMetric(CumulativeIterationMetric):\n    \"\"\"\n    Computes Dice score for a set of pairs of prediction-groundtruth labels. It supports single-channel label maps\n    or multi-channel images with class segmentations per channel. This allows the computation for both multi-class\n    and multi-label tasks.\n\n    If either prediction ``y_pred`` or ground truth ``y`` have shape BCHW[D], it is expected that these represent one-\n    hot segmentations for C number of classes. If either shape is B1HW[D], it is expected that these are label maps\n    and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs,\n    this metric applies no activations and so non-binary values will produce unexpected results if this metric is used\n    for binary overlap measurement (ie. either was expected to be one-hot formatted). Soft labels are thus permitted by\n    this metric. Typically this implies that raw predictions from a network must first be activated and possibly made\n    into label maps, eg. 
for a multi-class prediction tensor softmax and then argmax should be applied over the channel\n    dimensions to produce a label map.\n\n    The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which\n    is by convention assumed to be background. If the non-background segmentations are small compared to the total\n    image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction\n    and ground truth is BCHW[D].\n\n    The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Further information can be found in the official\n    `MONAI Dice Overview <https://github.com/Project-MONAI/tutorials/blob/main/modules/dice_loss_metric_notes.ipynb>`.\n\n    Example:\n\n    .. code-block:: python\n\n        import torch\n        from monai.metrics import DiceMetric\n        from monai.losses import DiceLoss\n        from monai.networks import one_hot\n\n        batch_size, n_classes, h, w = 7, 5, 128, 128\n\n        y_pred = torch.rand(batch_size, n_classes, h, w)  # network predictions\n        y_pred = torch.argmax(y_pred, 1, True)  # convert to label map\n\n        # ground truth as label map\n        y = torch.randint(0, n_classes, size=(batch_size, 1, h, w))\n\n        dm = DiceMetric(\n            reduction=\"mean_batch\", return_with_label=True, num_classes=n_classes\n        )\n\n        raw_scores = dm(y_pred, y)\n        print(dm.aggregate())\n\n        # now compute the Dice loss which should be the same as 1 - raw_scores\n        dl = DiceLoss(to_onehot_y=True, reduction=\"none\")\n        loss = dl(one_hot(y_pred, n_classes), y).squeeze()\n\n        print(1.0 - loss)  # same as raw_scores\n\n\n    Args:\n        include_background: whether to include Dice computation on the first channel/category of the prediction and\n            ground truth. 
Defaults to ``True``, use ``False`` to exclude the background class.\n        reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The\n            available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If \"none\", is\n            selected, the metric will not do reduction.\n        get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns `(metric, not_nans)` where\n            `not_nans` counts the number of valid values in the result, and will have the same shape.\n        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be\n            set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases\n            are also empty.\n        num_classes: number of input channels (always including the background). When this is ``None``,\n            ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are\n            single-channel class indices and the number of classes is not automatically inferred from data.\n        return_with_label: whether to return the metrics with label, only works when reduction is \"mean_batch\".\n            If `True`, use \"label_{index}\" as the key corresponding to C channels; if ``include_background`` is True,\n            the index begins at \"0\", otherwise at \"1\". 
It can also take a list of label names.\n            The outcome will then be returned as a dictionary.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n        ignore_empty: bool = True,\n        num_classes: int | None = None,\n        return_with_label: bool | list[str] = False,\n    ) -> None:\n        super().__init__()\n        self.include_background = include_background\n        self.reduction = reduction\n        self.get_not_nans = get_not_nans\n        self.ignore_empty = ignore_empty\n        self.num_classes = num_classes\n        self.return_with_label = return_with_label\n        self.dice_helper = DiceHelper(\n            include_background=self.include_background,\n            reduction=MetricReduction.NONE,\n            get_not_nans=False,\n            apply_argmax=False,\n            ignore_empty=self.ignore_empty,\n            num_classes=self.num_classes,\n        )\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        \"\"\"\n        Compute the dice value using ``DiceHelper``.\n\n        Args:\n            y_pred: prediction value, see class docstring for format definition.\n            y: ground truth label.\n\n        Raises:\n            ValueError: when `y_pred` has fewer than three dimensions.\n        \"\"\"\n        dims = y_pred.ndimension()\n        if dims < 3:\n            raise ValueError(f\"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.\")\n        # compute dice (BxC) for each channel for each batch\n        return self.dice_helper(y_pred=y_pred, y=y)  # type: ignore\n\n    def aggregate(\n        self, reduction: MetricReduction | str | None = None\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Execute reduction and aggregation logic for the 
output of `compute_dice`.\n\n        Args:\n            reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`.\n                If ``None``, the ``reduction`` set at construction (``self.reduction``) is used.\n        \"\"\"\n        data = self.get_buffer()\n        if not isinstance(data, torch.Tensor):\n            raise ValueError(f\"the data to aggregate must be PyTorch Tensor, got {type(data)}.\")\n\n        # do metric reduction\n        f, not_nans = do_metric_reduction(data, reduction or self.reduction)\n        if self.reduction == MetricReduction.MEAN_BATCH and self.return_with_label:\n            _f = {}\n            if isinstance(self.return_with_label, bool):\n                for i, v in enumerate(f):\n                    _label_key = f\"label_{i + 1}\" if not self.include_background else f\"label_{i}\"\n                    _f[_label_key] = round(v.item(), 4)\n            else:\n                for key, v in zip(self.return_with_label, f):\n                    _f[key] = round(v.item(), 4)\n            f = _f\n        return (f, not_nans) if self.get_not_nans else f\n\n\ndef compute_dice(\n    y_pred: torch.Tensor,\n    y: torch.Tensor,\n    include_background: bool = True,\n    ignore_empty: bool = True,\n    num_classes: int | None = None,\n) -> torch.Tensor:\n    \"\"\"\n    Computes Dice score metric for a batch of predictions. This performs the same computation as\n    :py:class:`monai.metrics.DiceMetric`, which is preferable to use over this function. For input formats, see the\n    documentation for that class.\n\n    Args:\n        y_pred: input data to compute, typical segmentation model output.\n        y: ground truth to compute mean dice metric.\n        include_background: whether to include Dice computation on the first channel/category of the prediction and\n            ground truth. 
Defaults to ``True``, use ``False`` to exclude the background class.\n        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be\n            set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases\n            are also empty.\n        num_classes: number of input channels (always including the background). When this is ``None``,\n            ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are\n            single-channel class indices and the number of classes is not automatically inferred from data.\n\n    Returns:\n        Dice scores per batch and per class, (shape: [batch_size, num_classes]).\n\n    \"\"\"\n    return DiceHelper(  # type: ignore\n        include_background=include_background,\n        reduction=MetricReduction.NONE,\n        get_not_nans=False,\n        apply_argmax=False,\n        ignore_empty=ignore_empty,\n        num_classes=num_classes,\n    )(y_pred=y_pred, y=y)\n\n\nclass DiceHelper:\n    \"\"\"\n    Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`,\n    see the documentation for that class for input formats.\n\n    Example:\n\n    .. code-block:: python\n\n        import torch\n        from monai.metrics import DiceHelper\n\n        n_classes, batch_size = 5, 16\n        spatial_shape = (128, 128, 128)\n\n        y_pred = torch.rand(batch_size, n_classes, *spatial_shape).float()  # predictions\n        y = torch.randint(0, n_classes, size=(batch_size, 1, *spatial_shape)).long()  # ground truth\n\n        score, not_nans = DiceHelper(include_background=False, sigmoid=True, softmax=True)(y_pred, y)\n        print(score, not_nans)\n\n    Args:\n        include_background: whether to include Dice computation on the first channel/category of the prediction and\n            ground truth. 
Defaults to ``True``, use ``False`` to exclude the background class.\n        threshold: if ``True``, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False.\n        apply_argmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to\n            get the discrete prediction. Defaults to the value of ``not threshold``.\n        activate: if this and ``threshold`` are ``True``, sigmoid activation is applied to ``y_pred`` before\n            thresholding. Defaults to False.\n        get_not_nans: whether to return the number of not-nan values.\n        reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The\n            available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If \"none\" is\n            selected, the metric will not do reduction.\n        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be\n            set for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases\n            are also empty.\n        num_classes: number of input channels (always including the background). When this is ``None``,\n            ``y_pred.shape[1]`` will be used. 
This option is useful when both ``y_pred`` and ``y`` are\n            single-channel class indices and the number of classes is not automatically inferred from data.\n    \"\"\"\n\n    @deprecated_arg(\"softmax\", \"1.5\", \"1.7\", \"Use `apply_argmax` instead.\", new_name=\"apply_argmax\")\n    @deprecated_arg(\"sigmoid\", \"1.5\", \"1.7\", \"Use `threshold` instead.\", new_name=\"threshold\")\n    def __init__(\n        self,\n        include_background: bool | None = None,\n        threshold: bool = False,\n        apply_argmax: bool | None = None,\n        activate: bool = False,\n        get_not_nans: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,\n        ignore_empty: bool = True,\n        num_classes: int | None = None,\n        sigmoid: bool | None = None,\n        softmax: bool | None = None,\n    ) -> None:\n        # handling deprecated arguments\n        if sigmoid is not None:\n            threshold = sigmoid\n        if softmax is not None:\n            apply_argmax = softmax\n\n        self.threshold = threshold\n        self.reduction = reduction\n        self.get_not_nans = get_not_nans\n        self.include_background = threshold if include_background is None else include_background\n        self.apply_argmax = not threshold if apply_argmax is None else apply_argmax\n        self.activate = activate\n        self.ignore_empty = ignore_empty\n        self.num_classes = num_classes\n\n    def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Compute the dice metric for binary inputs which have only spatial dimensions. 
This method is called separately\n        for each batch item and for each channel of those items.\n\n        Args:\n            y_pred: input predictions with shape HW[D].\n            y: ground truth with shape HW[D].\n        \"\"\"\n        y_o = torch.sum(y)\n        if y_o > 0:\n            return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred))\n        if self.ignore_empty:\n            return torch.tensor(float(\"nan\"), device=y_o.device)\n        denorm = y_o + torch.sum(y_pred)\n        if denorm <= 0:\n            return torch.tensor(1.0, device=y_o.device)\n        return torch.tensor(0.0, device=y_o.device)\n\n    def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Compute the metric for the given prediction and ground truth.\n\n        Args:\n            y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).\n                the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.\n            y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).\n        \"\"\"\n        _apply_argmax, _threshold = self.apply_argmax, self.threshold\n        if self.num_classes is None:\n            n_pred_ch = y_pred.shape[1]  # y_pred is in one-hot format or multi-channel scores\n        else:\n            n_pred_ch = self.num_classes\n            if y_pred.shape[1] == 1 and self.num_classes > 1:  # y_pred is single-channel class indices\n                _apply_argmax = _threshold = False\n\n        if _apply_argmax and n_pred_ch > 1:\n            y_pred = torch.argmax(y_pred, dim=1, keepdim=True)\n\n        elif _threshold:\n            if self.activate:\n                y_pred = torch.sigmoid(y_pred)\n            y_pred = y_pred > 0.5\n\n        first_ch = 0 if self.include_background else 1\n        data = []\n        for b in range(y_pred.shape[0]):\n            
c_list = []\n            for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:\n                x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()\n                x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]\n                c_list.append(self.compute_channel(x_pred, x))\n            data.append(torch.stack(c_list))\n        data = torch.stack(data, dim=0).contiguous()  # type: ignore\n\n        f, not_nans = do_metric_reduction(data, self.reduction)  # type: ignore\n        return (f, not_nans) if self.get_not_nans else f\n"
  },
  {
    "path": "monai/metrics/meaniou.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom monai.metrics.utils import do_metric_reduction, ignore_background\nfrom monai.utils import MetricReduction\n\nfrom .metric import CumulativeIterationMetric\n\n\nclass MeanIoU(CumulativeIterationMetric):\n    \"\"\"\n    Compute average Intersection over Union (IoU) score between two tensors.\n    It supports both multi-classes and multi-labels tasks.\n    Input `y_pred` is compared with ground truth `y`.\n    `y_pred` is expected to have binarized predictions and `y` should be in one-hot format. You can use suitable transforms\n    in ``monai.transforms.post`` first to achieve binarized values.\n    The `include_background` parameter can be set to ``False`` to exclude\n    the first category (channel index 0) which is by convention assumed to be background. If the non-background\n    segmentations are small compared to the total image size they can get overwhelmed by the signal from the\n    background.\n    `y_pred` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]).\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        include_background: whether to include IoU computation on the first channel of\n            the predicted output. 
Defaults to ``True``.\n        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n            Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.\n        ignore_empty: whether to ignore empty ground truth cases during calculation.\n            If `True`, NaN value will be set for empty ground truth cases.\n            If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n        ignore_empty: bool = True,\n    ) -> None:\n        super().__init__()\n        self.include_background = include_background\n        self.reduction = reduction\n        self.get_not_nans = get_not_nans\n        self.ignore_empty = ignore_empty\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        \"\"\"\n        Args:\n            y_pred: input data to compute, typical segmentation model output.\n                It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values\n                should be binarized.\n            y: ground truth to compute mean IoU metric. 
It must be one-hot format and first dim is batch.\n                The values should be binarized.\n\n        Raises:\n            ValueError: when `y_pred` has less than three dimensions.\n        \"\"\"\n        dims = y_pred.ndimension()\n        if dims < 3:\n            raise ValueError(f\"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.\")\n        # compute IoU (BxC) for each channel for each batch\n        return compute_iou(\n            y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty\n        )\n\n    def aggregate(\n        self, reduction: MetricReduction | str | None = None\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Execute reduction logic for the output of `compute_iou`.\n\n        Args:\n            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to `self.reduction`. if \"none\", will not do reduction.\n\n        \"\"\"\n        data = self.get_buffer()\n        if not isinstance(data, torch.Tensor):\n            raise ValueError(\"the data to aggregate must be PyTorch Tensor.\")\n\n        # do metric reduction\n        f, not_nans = do_metric_reduction(data, reduction or self.reduction)\n        return (f, not_nans) if self.get_not_nans else f\n\n\ndef compute_iou(\n    y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True\n) -> torch.Tensor:\n    \"\"\"Computes Intersection over Union (IoU) score metric from a batch of predictions.\n\n    Args:\n        y_pred: input data to compute, typical segmentation model output.\n            It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. 
The values\n            should be binarized.\n        y: ground truth to compute mean IoU metric. It must be one-hot format and first dim is batch.\n            The values should be binarized.\n        include_background: whether to include IoU computation on the first channel of\n            the predicted output. Defaults to True.\n        ignore_empty: whether to ignore empty ground truth cases during calculation.\n            If `True`, NaN value will be set for empty ground truth cases.\n            If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.\n\n    Returns:\n        IoU scores per batch and per class, (shape [batch_size, num_classes]).\n\n    Raises:\n        ValueError: when `y_pred` and `y` have different shapes.\n\n    \"\"\"\n\n    if not include_background:\n        y_pred, y = ignore_background(y_pred=y_pred, y=y)\n\n    if y.shape != y_pred.shape:\n        raise ValueError(f\"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.\")\n\n    # reducing only spatial dimensions (not batch nor channels)\n    n_len = len(y_pred.shape)\n    reduce_axis = list(range(2, n_len))\n    intersection = torch.sum(y * y_pred, dim=reduce_axis)\n\n    y_o = torch.sum(y, reduce_axis)\n    y_pred_o = torch.sum(y_pred, dim=reduce_axis)\n    union = y_o + y_pred_o - intersection\n\n    if ignore_empty:\n        return torch.where(y_o > 0, (intersection) / union, torch.tensor(float(\"nan\"), device=y_o.device))\n    return torch.where(union > 0, (intersection) / union, torch.tensor(1.0, device=y_o.device))\n"
  },
  {
    "path": "monai/metrics/metric.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport torch\n\nfrom monai.config import TensorOrList\nfrom monai.utils import convert_data_type, evenly_divisible_all_gather\n\n__all__ = [\"Metric\", \"IterationMetric\", \"Cumulative\", \"CumulativeIterationMetric\"]\n\n\nclass Metric(ABC):\n    \"\"\"\n    Base class for metric computation for evaluating the performance of a model.\n    `__call__` is designed to execute the computation.\n\n    \"\"\"\n\n    @abstractmethod\n    def __call__(self, *args: Any, **kwargs: Any) -> Any:\n        \"\"\"\n        This method should take raw model outputs as inputs, and return values that measure the models' quality.\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    def __str__(self):\n        return self.__class__.__name__\n\n\nclass IterationMetric(Metric):\n    \"\"\"\n    Base class for metrics computation at the iteration level, that is, on a min-batch of samples\n    usually using the model outcome of one iteration.\n\n    `__call__` is designed to handle `y_pred` and `y` (optional) in torch tensors or a list/tuple of tensors.\n\n    Subclasses typically implement the `_compute_tensor` function for the actual tensor computation logic.\n    \"\"\"\n\n    def 
__call__(\n        self, y_pred: TensorOrList, y: TensorOrList | None = None, **kwargs: Any\n    ) -> torch.Tensor | Sequence[torch.Tensor | Sequence[torch.Tensor]]:\n        \"\"\"\n        Execute basic computation for model prediction `y_pred` and ground truth `y` (optional).\n        It supports inputs of a list of \"channel-first\" Tensor and a \"batch-first\" Tensor.\n\n        Args:\n            y_pred: the raw model prediction data at one iteration, must be a list of `channel-first` Tensor\n                or a `batch-first` Tensor.\n            y: the ground truth to compute, must be a list of `channel-first` Tensor\n                or a `batch-first` Tensor.\n            kwargs: additional parameters for specific metric computation logic (e.g. ``spacing`` for SurfaceDistanceMetric, etc.).\n\n        Returns:\n            The computed metric values at the iteration level.\n            The output shape could be a `batch-first` tensor or a list of `batch-first` tensors.\n            When it's a list of tensors, each item in the list can represent a specific type of metric.\n\n        \"\"\"\n        # handling a list of channel-first data\n        if isinstance(y_pred, (list, tuple)) or isinstance(y, (list, tuple)):\n            return self._compute_list(y_pred, y, **kwargs)\n        # handling a single batch-first data\n        if isinstance(y_pred, torch.Tensor):\n            y_ = y.detach() if isinstance(y, torch.Tensor) else None\n            return self._compute_tensor(y_pred.detach(), y_, **kwargs)\n        raise ValueError(\"y_pred or y must be a list/tuple of `channel-first` Tensors or a `batch-first` Tensor.\")\n\n    def _compute_list(\n        self, y_pred: TensorOrList, y: TensorOrList | None = None, **kwargs: Any\n    ) -> torch.Tensor | list[torch.Tensor | Sequence[torch.Tensor]]:\n        \"\"\"\n        Execute the metric computation for `y_pred` and `y` in a list of \"channel-first\" tensors.\n\n        The return value is a \"batch-first\" 
tensor, or a list of \"batch-first\" tensors.\n        When it's a list of tensors, each item in the list can represent a specific type of metric values.\n\n        For example, `self._compute_tensor` may be implemented as returning a list of `batch_size` items,\n        where each item is a tuple of three values `tp`, `fp`, `fn` for true positives, false positives,\n        and false negatives respectively. This function will return a list of three items,\n        (`tp_batched`, `fp_batched`, `fn_batched`), where each item is a `batch_size`-length tensor.\n\n        Note: subclass may enhance the operation to have multi-thread support.\n        \"\"\"\n        if y is not None:\n            ret = [\n                self._compute_tensor(p.detach().unsqueeze(0), y_.detach().unsqueeze(0), **kwargs)\n                for p, y_ in zip(y_pred, y)\n            ]\n        else:\n            ret = [self._compute_tensor(p_.detach().unsqueeze(0), None, **kwargs) for p_ in y_pred]\n\n        # concat the list of results (e.g. a batch of evaluation scores)\n        if isinstance(ret[0], torch.Tensor):\n            return torch.cat(ret, dim=0)  # type: ignore[arg-type]\n        # the result is a list of sequence of tensors (e.g. 
a batch of multi-class results)\n        if isinstance(ret[0], (list, tuple)) and all(isinstance(i, torch.Tensor) for i in ret[0]):\n            return [torch.cat(batch_i, dim=0) for batch_i in zip(*ret)]\n        return ret\n\n    @abstractmethod\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None, **kwargs: Any) -> TensorOrList:\n        \"\"\"\n        Computation logic for `y_pred` and `y` of an iteration, the data should be \"batch-first\" Tensors.\n        A subclass should implement its own computation logic.\n        The return value is usually a \"batch_first\" tensor, or a list of \"batch_first\" tensors.\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n\nclass Cumulative:\n    \"\"\"\n    Utility class for the typical cumulative computation process based on PyTorch Tensors.\n    It provides interfaces to accumulate values in the local buffers, synchronize buffers across distributed nodes,\n    and aggregate the buffered values.\n\n    In multi-processing, PyTorch programs usually distribute data to multiple nodes. Each node runs with a subset\n    of the data, adds values to its local buffers. Calling `get_buffer` could gather all the results and\n    `aggregate` can further handle the results to generate the final outcomes.\n\n    Users can implement their own `aggregate` method to handle the results,\n    using `get_buffer` to get the buffered contents.\n\n    Note: the data list should have the same length every time calling `add()` in a round,\n    it will automatically create buffers according to the length of data list.\n\n    Typically, this class is expected to execute the following steps:\n\n    .. 
code-block:: python\n\n        from monai.metrics import Cumulative\n\n        c = Cumulative()\n        c.append(1)  # adds a value\n        c.extend([2, 3])  # adds a batch of values\n        c.extend([4, 5, 6])  # adds a batch of values\n        print(c.get_buffer())  # tensor([1, 2, 3, 4, 5, 6])\n        print(len(c))  # 6\n        c.reset()\n        print(len(c))  # 0\n\n    The following is an example of maintaining two internal buffers:\n\n    .. code-block:: python\n\n        from monai.metrics import Cumulative\n\n        c = Cumulative()\n        c.append(1, 2)  # adds a value to two buffers respectively\n        c.extend([3, 4], [5, 6])  # adds batches of values\n        print(c.get_buffer())  # [tensor([1, 3, 4]), tensor([2, 5, 6])]\n        print(len(c))\n\n    The following is an example of extending with variable length data:\n\n    .. code-block:: python\n\n        import torch\n        from monai.metrics import Cumulative\n\n        c = Cumulative()\n        c.extend(torch.zeros((8, 2)), torch.zeros((6, 2)))  # adds batches\n        c.append(torch.zeros((2, )))  # adds a value\n        print(c.get_buffer())  # [torch.zeros((9, 2)), torch.zeros((6, 2))]\n        print(len(c))\n\n    \"\"\"\n\n    def __init__(self) -> None:\n        \"\"\"\n        Initialize the internal buffers.\n        `self._buffers` are local buffers, they are not usually used directly.\n        `self._synced_tensors` are the buffers with all the results across all the nodes.\n        \"\"\"\n        self._buffers: list[list[torch.Tensor]] | None = None\n        self._synced_tensors: list[torch.Tensor | None] | None = None\n        self._synced: bool = False\n        self.reset()\n\n    def reset(self):\n        \"\"\"\n        Reset the buffers for cumulative tensors and the synced results.\n\n        \"\"\"\n        self._buffers = None\n        self._synced_tensors = None\n        self._synced = False\n\n    def extend(self, *data: Any) -> None:\n        \"\"\"\n        
Extend the local buffers with new (\"batch-first\") data.\n        A buffer will be allocated for each `data` item.\n        Compared with `self.append`, this method adds a \"batch\" of data to the local buffers.\n\n        Args:\n            data: each item can be a \"batch-first\" tensor or a list of \"channel-first\" tensors.\n                they will be concatenated at the 0-th dimension when `get_buffer()` is called.\n        \"\"\"\n        if self._buffers is None:\n            self._buffers = [[] for _ in data]\n        for b, d in zip(self._buffers, data):\n            # converting to pytorch tensors so that we can use the distributed API\n            d_t, *_ = convert_data_type(d, output_type=torch.Tensor, wrap_sequence=True)\n            try:  # d_t must be a mini-batch of values\n                b.extend([x[0] for x in torch.split(d_t, 1, dim=0)])\n            except (AttributeError, IndexError, RuntimeError) as e:\n                raise TypeError(\n                    f\"{e}. 
`data` should be a batch-first tensor or\"\n                    f\" a list of channel-first tensors, got {type(d_t)}\"\n                ) from e\n        self._synced = False\n\n    def append(self, *data: Any) -> None:\n        \"\"\"\n        Add samples to the local cumulative buffers.\n        A buffer will be allocated for each `data` item.\n        Compared with `self.extend`, this method adds a single sample (instead\n        of a \"batch\") to the local buffers.\n\n        Args:\n            data: each item will be converted into a torch tensor.\n                they will be stacked at the 0-th dim with a new dimension when `get_buffer()` is called.\n\n        \"\"\"\n        if self._buffers is None:\n            self._buffers = [[] for _ in data]\n        for b, d in zip(self._buffers, data):\n            # converting to pytorch tensors so that we can use the distributed API\n            d_t, *_ = convert_data_type(d, output_type=torch.Tensor, wrap_sequence=True)\n            b.append(d_t)\n        self._synced = False\n\n    @abstractmethod\n    def aggregate(self, *args: Any, **kwargs: Any) -> Any:\n        \"\"\"\n        Aggregate final results based on the gathered buffers.\n        This method is expected to use `get_buffer` to gather the local buffer contents.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    def _sync(self):\n        \"\"\"\n        All gather the buffers across distributed ranks for aggregating.\n        Each buffer will be concatenated as a PyTorch Tensor.\n\n        \"\"\"\n        if self._synced or self._buffers is None:\n            return\n        try:\n            self._synced_tensors = [\n                evenly_divisible_all_gather(torch.stack(b, dim=0), concat=True) for b in self._buffers\n            ]\n        except (RuntimeError, TypeError, ValueError) as e:\n            raise TypeError(f\"{e}. 
unable to sync buffer contents: {self._buffers}.\") from e\n        self._synced = True\n\n    def __len__(self):\n        \"\"\"\n        Return the length of the largest buffer.\n        Note that the method will trigger synchronization of the local buffers.\n        \"\"\"\n        self._sync()\n        if self._synced_tensors is None:\n            return 0\n        return max(len(x) for x in self._synced_tensors if x is not None)\n\n    def get_buffer(self):\n        \"\"\"\n        Get the synchronized list of buffers.\n        A typical usage is to generate the metrics report based on the raw metric details.\n        Each buffer is a PyTorch Tensor.\n\n        \"\"\"\n        self._sync()\n        if self._synced_tensors is None:\n            return self._synced_tensors\n        buffers = [x.detach().clone() if isinstance(x, torch.Tensor) else x for x in self._synced_tensors]\n        return buffers[0] if len(buffers) == 1 else buffers\n\n\nclass CumulativeIterationMetric(Cumulative, IterationMetric):\n    \"\"\"\n    Base class of cumulative metric which collects metrics on each mini-batch data at the iteration level.\n\n    Typically, it computes some intermediate results for each iteration, adds them to the buffers,\n    then the buffer contents could be gathered and aggregated for the final result when the epoch completes.\n    Currently, ``Cumulative.aggregate()`` and ``IterationMetric._compute_tensor()`` are expected to be implemented.\n\n    For example, `DiceMetric` inherits this class and the usage is as follows:\n\n    .. 
code-block:: python\n\n        dice_metric = DiceMetric(include_background=True, reduction=\"mean\")\n\n        for val_data in val_loader:\n            val_outputs = model(val_data[\"img\"])\n            val_outputs = [postprocessing_transform(i) for i in decollate_batch(val_outputs)]\n            # compute metric for current iteration\n            dice_metric(y_pred=val_outputs, y=val_data[\"seg\"])  # callable to add metric to the buffer\n\n        # aggregate the final mean dice result\n        metric = dice_metric.aggregate().item()\n\n        # reset the status for next computation round\n        dice_metric.reset()\n\n    And to load `predictions` and `labels` from files, then compute metrics with multi-processing, please refer to:\n    https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py.\n\n    \"\"\"\n\n    def __call__(\n        self, y_pred: TensorOrList, y: TensorOrList | None = None, **kwargs: Any\n    ) -> torch.Tensor | Sequence[torch.Tensor | Sequence[torch.Tensor]]:\n        \"\"\"\n        Execute basic computation for model prediction and ground truth.\n        It can support  both `list of channel-first Tensor` and `batch-first Tensor`.\n        Users call this API to execute computation on every batch of data, then accumulate the results,\n        or accumulate the original `y_pred` and `y`, then execute on the accumulated data.\n\n        Args:\n            y_pred: the model prediction data to compute, must be a list of `channel-first` Tensor\n                or a `batch-first` Tensor.\n            y: the ground truth to compute, must be a list of `channel-first` Tensor\n                or a `batch-first` Tensor.\n            kwargs: additional parameters for specific metric computation logic (e.g. ``spacing`` for SurfaceDistanceMetric, etc.).\n\n        Returns:\n            The computed metric values at the iteration level. 
The output shape should be\n            a `batch-first` tensor (BC[HWD]) or a list of `batch-first` tensors.\n        \"\"\"\n        ret = super().__call__(y_pred=y_pred, y=y, **kwargs)\n        if isinstance(ret, (tuple, list)):\n            self.extend(*ret)\n        else:\n            self.extend(ret)\n\n        return ret\n"
  },
  {
    "path": "monai/metrics/mmd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable\n\nimport torch\n\nfrom monai.metrics.metric import Metric\n\n\nclass MMDMetric(Metric):\n    \"\"\"\n    Unbiased Maximum Mean Discrepancy (MMD) is a kernel-based method for measuring the similarity between two\n    distributions. It is a non-negative metric where a smaller value indicates a closer match between the two\n    distributions.\n\n    Gretton, A., et al., 2012. A kernel two-sample test. The Journal of Machine Learning Research, 13(1), pp.723-773.\n\n    Args:\n        y_mapping: Callable to transform the y tensors before computing the metric. It is usually a Gaussian or Laplace\n            filter, but it can be any function that takes a tensor as input and returns a tensor as output such as a\n            feature extractor or an Identity function, e.g. `y_mapping = lambda x: x.square()`.\n    \"\"\"\n\n    def __init__(self, y_mapping: Callable | None = None) -> None:\n        super().__init__()\n        self.y_mapping = y_mapping\n\n    def __call__(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:\n        return compute_mmd(y, y_pred, self.y_mapping)\n\n\ndef compute_mmd(y: torch.Tensor, y_pred: torch.Tensor, y_mapping: Callable | None) -> torch.Tensor:\n    \"\"\"\n    Args:\n        y: first sample (e.g., the reference image). 
Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D.\n        y_pred: second sample (e.g., the reconstructed image). It must have the same shape as y after `y_mapping` is applied.\n        y_mapping: Callable to transform the y tensors before computing the metric.\n    \"\"\"\n    if y_pred.shape[0] == 1 or y.shape[0] == 1:\n        raise ValueError(\"MMD metric requires at least two samples in y and y_pred.\")\n\n    if y_mapping is not None:\n        y = y_mapping(y)\n        y_pred = y_mapping(y_pred)\n\n    if y_pred.shape != y.shape:\n        raise ValueError(\n            \"y_pred and y shapes dont match after being processed \"\n            f\"by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}\"\n        )\n\n    for d in range(len(y.shape) - 1, 1, -1):\n        y = y.squeeze(dim=d)\n        y_pred = y_pred.squeeze(dim=d)\n\n    y = y.view(y.shape[0], -1)\n    y_pred = y_pred.view(y_pred.shape[0], -1)\n\n    y_y = torch.mm(y, y.t())\n    y_pred_y_pred = torch.mm(y_pred, y_pred.t())\n    y_pred_y = torch.mm(y_pred, y.t())\n\n    m = y.shape[0]\n    n = y_pred.shape[0]\n\n    # Gretton et al. (2012), Eq. 3 (found under Lemma 6)\n    # term 1\n    c1 = 1 / (m * (m - 1))\n    a = torch.sum(y_y - torch.diag(torch.diagonal(y_y)))\n\n    # term 2\n    c2 = 1 / (n * (n - 1))\n    b = torch.sum(y_pred_y_pred - torch.diag(torch.diagonal(y_pred_y_pred)))\n\n    # term 3\n    c3 = 2 / (m * n)\n    c = torch.sum(y_pred_y)\n\n    mmd = c1 * a + c2 * b - c3 * c\n    return mmd\n"
  },
  {
    "path": "monai/metrics/panoptic_quality.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\n\nfrom monai.metrics.metric import CumulativeIterationMetric\nfrom monai.metrics.utils import do_metric_reduction, remap_instance_id\nfrom monai.utils import MetricReduction, ensure_tuple, optional_import\n\nlinear_sum_assignment, _ = optional_import(\"scipy.optimize\", name=\"linear_sum_assignment\")\n\n__all__ = [\"PanopticQualityMetric\", \"compute_panoptic_quality\", \"compute_mean_iou\"]\n\n\nclass PanopticQualityMetric(CumulativeIterationMetric):\n    \"\"\"\n    Compute Panoptic Quality between two instance segmentation masks. If specifying `metric_name` to \"SQ\" or \"RQ\",\n    Segmentation Quality (SQ) or Recognition Quality (RQ) will be returned instead.\n\n    Panoptic Quality is a metric used in panoptic segmentation tasks. This task unifies the typically distinct tasks\n    of semantic segmentation (assign a class label to each pixel) and\n    instance segmentation (detect and segment each object instance). 
Compared with semantic segmentation, panoptic\n    segmentation distinguishes different instances that belong to the same class.\n    Compared with instance segmentation, panoptic segmentation does not allow overlap and only one semantic label and\n    one instance id can be assigned to each pixel.\n    Please refer to the following paper for more details:\n    https://openaccess.thecvf.com/content_CVPR_2019/papers/Kirillov_Panoptic_Segmentation_CVPR_2019_paper.pdf\n\n    This class also refers to the following implementation:\n    https://github.com/TissueImageAnalytics/CoNIC\n\n    Args:\n        num_classes: number of classes. The number should not count the background.\n        metric_name: output metric. The value can be \"pq\", \"sq\" or \"rq\".\n            Besides a single metric name, multiple metrics are also supported by passing a sequence of metric names\n            such as (\"pq\", \"sq\", \"rq\"). If a sequence is given, a list of results in the same order\n            as the input names will be returned.\n        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to `self.reduction`. if \"none\", will not do reduction.\n        match_iou_threshold: IOU threshold to determine the pairing between `y_pred` and `y`. Usually,\n            it should be >= 0.5, in which case the pairing between instances of `y_pred` and `y` is unique.\n            If `match_iou_threshold` is set < 0.5, this function uses Munkres assignment to find the\n            maximal number of unique pairings.\n        smooth_numerator: a small constant added to the denominator to avoid division by zero.\n        return_confusion_matrix: if True, returns raw confusion matrix values (tp, fp, fn, iou_sum)\n            instead of computed metrics. 
Default is False.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        num_classes: int,\n        metric_name: Sequence[str] | str = \"pq\",\n        reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,\n        match_iou_threshold: float = 0.5,\n        smooth_numerator: float = 1e-6,\n        return_confusion_matrix: bool = False,\n    ) -> None:\n        super().__init__()\n        self.num_classes = num_classes\n        self.reduction = reduction\n        self.match_iou_threshold = match_iou_threshold\n        self.smooth_numerator = smooth_numerator\n        self.metric_name = ensure_tuple(metric_name)\n        self.return_confusion_matrix = return_confusion_matrix\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        \"\"\"\n        Args:\n            y_pred: Predictions. It must be in the form of B2HW (2D) or B2HWD (3D) and have integer type.\n                The first channel and the second channel represent the instance predictions and classification\n                predictions respectively.\n            y: ground truth. It must have the same shape as `y_pred` and have integer type. 
The first channel and the\n                second channel represent the instance labels and classification labels respectively.\n                Values in the second channel of `y_pred` and `y` should be in the range of 0 to `self.num_classes`,\n                where 0 represents the background.\n\n        Raises:\n            ValueError: when `y_pred` and `y` have different shapes.\n            ValueError: when `y_pred` and `y` have != 2 channels.\n            ValueError: when `y_pred` and `y` have != 4 or 5 dimensions.\n\n        \"\"\"\n        if y_pred.shape != y.shape:\n            raise ValueError(f\"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.\")\n\n        if y_pred.shape[1] != 2:\n            raise ValueError(\n                f\"for panoptic quality calculation, only 2 channels input is supported, got {y_pred.shape[1]}.\"\n            )\n\n        dims = y_pred.ndimension()\n        if dims not in (4, 5):\n            raise ValueError(\n                f\"y_pred should have 4 dimensions (batch, 2, h, w) or 5 dimensions (batch, 2, h, w, d), got {dims}.\"\n            )\n\n        batch_size = y_pred.shape[0]\n\n        outputs = torch.zeros([batch_size, self.num_classes, 4], device=y_pred.device)\n\n        for b in range(batch_size):\n            true_instance, pred_instance = y[b, 0], y_pred[b, 0]\n            true_class, pred_class = y[b, 1], y_pred[b, 1]\n            for c in range(self.num_classes):\n                pred_instance_c = (pred_class == c + 1) * pred_instance\n                true_instance_c = (true_class == c + 1) * true_instance\n\n                outputs[b, c] = compute_panoptic_quality(\n                    pred=pred_instance_c,\n                    gt=true_instance_c,\n                    remap=True,\n                    match_iou_threshold=self.match_iou_threshold,\n                    output_confusion_matrix=True,\n                )\n\n        return outputs\n\n    def aggregate(self, reduction: 
MetricReduction | str | None = None) -> torch.Tensor | list[torch.Tensor]:\n        \"\"\"\n        Execute reduction logic for the output of `compute_panoptic_quality`.\n\n        Args:\n            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to `self.reduction`. if \"none\", will not do reduction.\n\n        Returns:\n            If `return_confusion_matrix` is True, returns the raw confusion matrix [tp, fp, fn, iou_sum].\n            Otherwise, returns the computed metric(s) based on `metric_name`.\n\n        \"\"\"\n        data = self.get_buffer()\n        if not isinstance(data, torch.Tensor):\n            raise ValueError(\"the data to aggregate must be PyTorch Tensor.\")\n\n        # do metric reduction\n        f, _ = do_metric_reduction(data, reduction or self.reduction)\n\n        if self.return_confusion_matrix:\n            # Return raw confusion matrix values\n            return f\n\n        tp, fp, fn, iou_sum = f[..., 0], f[..., 1], f[..., 2], f[..., 3]\n        results = []\n        for metric_name in self.metric_name:\n            metric_name = _check_panoptic_metric_name(metric_name)\n            if metric_name == \"rq\":\n                results.append(tp / (tp + 0.5 * fp + 0.5 * fn + self.smooth_numerator))\n            elif metric_name == \"sq\":\n                results.append(iou_sum / (tp + self.smooth_numerator))\n            else:\n                results.append(iou_sum / (tp + 0.5 * fp + 0.5 * fn + self.smooth_numerator))\n\n        return results[0] if len(results) == 1 else results\n\n\ndef compute_panoptic_quality(\n    pred: torch.Tensor,\n    gt: torch.Tensor,\n    metric_name: str = \"pq\",\n    remap: bool = True,\n    match_iou_threshold: float = 0.5,\n    smooth_numerator: float = 
1e-6,\n    output_confusion_matrix: bool = False,\n) -> torch.Tensor:\n    \"\"\"Computes Panoptic Quality (PQ). If specifying `metric_name` to \"SQ\" or \"RQ\",\n    Segmentation Quality (SQ) or Recognition Quality (RQ) will be returned instead.\n\n    In addition, if `output_confusion_matrix` is True, the function will return a tensor with shape 4, which\n    represents the true positive, false positive, false negative and the sum of iou. These four values are used to\n    calculate PQ, and returning them directly enables further calculation over all images.\n\n    Args:\n        pred: input data to compute, it must be in the form of HW (2D) or HWD (3D) and have integer type.\n        gt: ground truth. It must have the same shape as `pred` and have integer type.\n        metric_name: output metric. The value can be \"pq\", \"sq\" or \"rq\".\n        remap: whether to remap `pred` and `gt` to ensure contiguous ordering of instance id.\n        match_iou_threshold: IOU threshold to determine the pairing between `pred` and `gt`. 
Usually,\n            it should >= 0.5, the pairing between instances of `pred` and `gt` are identical.\n            If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the\n            maximal amount of unique pairing.\n        smooth_numerator: a small constant added to the numerator to avoid zero.\n\n    Raises:\n        ValueError: when `pred` and `gt` have different shapes.\n        ValueError: when `match_iou_threshold` <= 0.0 or > 1.0.\n\n    \"\"\"\n\n    if gt.shape != pred.shape:\n        raise ValueError(f\"pred and gt should have same shapes, got {pred.shape} and {gt.shape}.\")\n    if match_iou_threshold <= 0.0 or match_iou_threshold > 1.0:\n        raise ValueError(f\"'match_iou_threshold' should be within (0, 1], got: {match_iou_threshold}.\")\n\n    gt = gt.int()\n    pred = pred.int()\n\n    if remap is True:\n        gt = remap_instance_id(gt)\n        pred = remap_instance_id(pred)\n\n    pairwise_iou, true_id_list, pred_id_list = _get_pairwise_iou(pred, gt, device=pred.device)\n    paired_iou, paired_true, paired_pred = _get_paired_iou(\n        pairwise_iou, match_iou_threshold, device=pairwise_iou.device\n    )\n\n    unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true]\n    unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred]\n\n    tp, fp, fn = len(paired_true), len(unpaired_pred), len(unpaired_true)\n    iou_sum = paired_iou.sum()\n\n    if output_confusion_matrix:\n        return torch.as_tensor([tp, fp, fn, iou_sum], device=pred.device)\n\n    metric_name = _check_panoptic_metric_name(metric_name)\n    if metric_name == \"rq\":\n        return torch.as_tensor(tp / (tp + 0.5 * fp + 0.5 * fn + smooth_numerator), device=pred.device)\n    if metric_name == \"sq\":\n        return torch.as_tensor(iou_sum / (tp + smooth_numerator), device=pred.device)\n    return torch.as_tensor(iou_sum / (tp + 0.5 * fp + 0.5 * fn + smooth_numerator), device=pred.device)\n\n\ndef 
_get_id_list(gt: torch.Tensor) -> list[torch.Tensor]:\n    id_list = list(gt.unique())\n    # ensure id 0 is included\n    if 0 not in id_list:\n        id_list.insert(0, torch.tensor(0).int())\n\n    return id_list\n\n\ndef _get_pairwise_iou(\n    pred: torch.Tensor, gt: torch.Tensor, device: str | torch.device = \"cpu\"\n) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:\n    pred_id_list = _get_id_list(pred)\n    true_id_list = _get_id_list(gt)\n\n    pairwise_iou = torch.zeros([len(true_id_list) - 1, len(pred_id_list) - 1], dtype=torch.float, device=device)\n    true_masks: list[torch.Tensor] = []\n    pred_masks: list[torch.Tensor] = []\n\n    for t in true_id_list[1:]:\n        t_mask = torch.as_tensor(gt == t, device=device).int()\n        true_masks.append(t_mask)\n\n    for p in pred_id_list[1:]:\n        p_mask = torch.as_tensor(pred == p, device=device).int()\n        pred_masks.append(p_mask)\n\n    for true_id in range(1, len(true_id_list)):\n        t_mask = true_masks[true_id - 1]\n        pred_true_overlap = pred[t_mask > 0]\n        pred_true_overlap_id = list(pred_true_overlap.unique())\n        for pred_id in pred_true_overlap_id:\n            if pred_id == 0:\n                continue\n            p_mask = pred_masks[pred_id - 1]\n            total = (t_mask + p_mask).sum()\n            inter = (t_mask * p_mask).sum()\n            iou = inter / (total - inter)\n            pairwise_iou[true_id - 1, pred_id - 1] = iou\n\n    return pairwise_iou, true_id_list, pred_id_list\n\n\ndef _get_paired_iou(\n    pairwise_iou: torch.Tensor, match_iou_threshold: float = 0.5, device: str | torch.device = \"cpu\"\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    if match_iou_threshold >= 0.5:\n        pairwise_iou[pairwise_iou <= match_iou_threshold] = 0.0\n        paired_true, paired_pred = torch.nonzero(pairwise_iou)[:, 0], torch.nonzero(pairwise_iou)[:, 1]\n        paired_iou = pairwise_iou[paired_true, paired_pred]\n        
paired_true += 1\n        paired_pred += 1\n\n        return paired_iou, paired_true, paired_pred\n\n    pairwise_iou = pairwise_iou.cpu().numpy()  # type: ignore[assignment]\n    paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)\n    paired_iou = pairwise_iou[paired_true, paired_pred]\n    paired_true = torch.as_tensor(list(paired_true[paired_iou > match_iou_threshold] + 1), device=device)\n    paired_pred = torch.as_tensor(list(paired_pred[paired_iou > match_iou_threshold] + 1), device=device)\n    paired_iou = paired_iou[paired_iou > match_iou_threshold]\n\n    return paired_iou, paired_true, paired_pred\n\n\ndef _check_panoptic_metric_name(metric_name: str) -> str:\n    metric_name = metric_name.replace(\" \", \"_\")\n    metric_name = metric_name.lower()\n    if metric_name in [\"panoptic_quality\", \"pq\"]:\n        return \"pq\"\n    if metric_name in [\"segmentation_quality\", \"sq\"]:\n        return \"sq\"\n    if metric_name in [\"recognition_quality\", \"rq\"]:\n        return \"rq\"\n    raise ValueError(f\"metric name: {metric_name} is wrong, please use 'pq', 'sq' or 'rq'.\")\n\n\ndef compute_mean_iou(confusion_matrix: torch.Tensor, smooth_numerator: float = 1e-6) -> torch.Tensor:\n    \"\"\"Compute mean IoU from confusion matrix values.\n\n    Args:\n        confusion_matrix: tensor with shape (..., 4) where the last dimension contains\n            [tp, fp, fn, iou_sum] as returned by `compute_panoptic_quality` with `output_confusion_matrix=True`.\n        smooth_numerator: a small constant added to the numerator to avoid zero.\n\n    Returns:\n        Mean IoU computed as iou_sum / (tp + smooth_numerator).\n\n    \"\"\"\n    if confusion_matrix.shape[-1] != 4:\n        raise ValueError(\n            f\"confusion_matrix should have shape (..., 4) with [tp, fp, fn, iou_sum], \"\n            f\"got shape {confusion_matrix.shape}.\"\n        )\n    tp, iou_sum = confusion_matrix[..., 0], confusion_matrix[..., 3]\n    return iou_sum / (tp + 
smooth_numerator)\n"
  },
  {
    "path": "monai/metrics/regression.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\nfrom abc import abstractmethod\nfrom collections.abc import Callable, Sequence\nfrom functools import partial\nfrom typing import Any\n\nimport torch\nimport torch.nn.functional as F\n\nfrom monai.metrics.utils import do_metric_reduction\nfrom monai.utils import MetricReduction, StrEnum, convert_data_type, ensure_tuple_rep\nfrom monai.utils.type_conversion import convert_to_dst_type\n\nfrom .metric import CumulativeIterationMetric\n\n\nclass RegressionMetric(CumulativeIterationMetric):\n    \"\"\"\n    Base class for regression metrics.\n    Input `y_pred` is compared with ground truth `y`.\n    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.\n    `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]).\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n            Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.\n\n    \"\"\"\n\n    def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False) -> None:\n        super().__init__()\n        self.reduction = reduction\n        self.get_not_nans = get_not_nans\n\n    def aggregate(\n        self, reduction: MetricReduction | str | None = None\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Args:\n            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to `self.reduction`. if \"none\", will not do reduction.\n        \"\"\"\n        data = self.get_buffer()\n        if not isinstance(data, torch.Tensor):\n            raise ValueError(\"the data to aggregate must be PyTorch Tensor.\")\n\n        f, not_nans = do_metric_reduction(data, reduction or self.reduction)\n        return (f, not_nans) if self.get_not_nans else f\n\n    def _check_shape(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:\n        if y_pred.shape != y.shape:\n            raise ValueError(f\"y_pred and y shapes dont match, received y_pred: [{y_pred.shape}] and y: [{y.shape}]\")\n\n        # also check if there is atleast one non-batch dimension i.e. 
num_dims >= 2\n        if len(y_pred.shape) < 2:\n            raise ValueError(\"either channel or spatial dimensions required, found only batch dimension\")\n\n    @abstractmethod\n    def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):\n            raise ValueError(\"y_pred and y must be PyTorch Tensor.\")\n        self._check_shape(y_pred, y)\n        return self._compute_metric(y_pred, y)\n\n\nclass MSEMetric(RegressionMetric):\n    r\"\"\"Compute Mean Squared Error between two tensors using function:\n\n    .. math::\n        \\operatorname {MSE}\\left(Y, \\hat{Y}\\right) =\\frac {1}{n}\\sum _{i=1}^{n}\\left(y_i-\\hat{y_i} \\right)^{2}.\n\n    More info: https://en.wikipedia.org/wiki/Mean_squared_error\n\n    Input `y_pred` is compared with ground truth `y`.\n    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n\n    \"\"\"\n\n    def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False) -> None:\n        super().__init__(reduction=reduction, get_not_nans=get_not_nans)\n        self.sq_func = partial(torch.pow, exponent=2.0)\n\n    def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n        return compute_mean_error_metrics(y_pred, y, func=self.sq_func)\n\n\nclass MAEMetric(RegressionMetric):\n    r\"\"\"Compute Mean Absolute Error between two tensors using function:\n\n    .. math::\n        \\operatorname {MAE}\\left(Y, \\hat{Y}\\right) =\\frac {1}{n}\\sum _{i=1}^{n}\\left|y_i-\\hat{y_i}\\right|.\n\n    More info: https://en.wikipedia.org/wiki/Mean_absolute_error\n\n    Input `y_pred` is compared with ground truth `y`.\n    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n\n    \"\"\"\n\n    def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False) -> None:\n        super().__init__(reduction=reduction, get_not_nans=get_not_nans)\n        self.abs_func = torch.abs\n\n    def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n        return compute_mean_error_metrics(y_pred, y, func=self.abs_func)\n\n\nclass MAPEMetric(RegressionMetric):\n    r\"\"\"Compute Mean Absolute Percentage Error between two tensors using function:\n\n    .. math::\n        \\operatorname {MAPE}\\left(Y, \\hat{Y}\\right) =\\frac {100}{n}\\sum _{i=1}^{n}\\left|\\frac{y_i-\\hat{y_i}}{y_i}\\right|.\n\n    More info: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error\n\n    Input `y_pred` is compared with ground truth `y`.\n    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.\n    Note: Tackling the undefined error, a tiny epsilon value is added to the denominator part.\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n        epsilon: float. 
Defaults to 1e-7.\n\n    \"\"\"\n\n    def __init__(\n        self, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, epsilon: float = 1e-7\n    ) -> None:\n        super().__init__(reduction=reduction, get_not_nans=get_not_nans)\n        self.epsilon = epsilon\n\n    def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n        return compute_mape_metric(y_pred, y, epsilon=self.epsilon)\n\n\nclass RMSEMetric(RegressionMetric):\n    r\"\"\"Compute Root Mean Squared Error between two tensors using function:\n\n    .. math::\n        \\operatorname {RMSE}\\left(Y, \\hat{Y}\\right) ={ \\sqrt{ \\frac {1}{n}\\sum _{i=1}^{n}\\left(y_i-\\hat{y_i}\\right)^2 } } \\\n        = \\sqrt {\\operatorname{MSE}\\left(Y, \\hat{Y}\\right)}.\n\n    More info: https://en.wikipedia.org/wiki/Root-mean-square_deviation\n\n    Input `y_pred` is compared with ground truth `y`.\n    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n\n    \"\"\"\n\n    def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False) -> None:\n        super().__init__(reduction=reduction, get_not_nans=get_not_nans)\n        self.sq_func = partial(torch.pow, exponent=2.0)\n\n    def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n        mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func)\n        return torch.sqrt(mse_out)\n\n\nclass PSNRMetric(RegressionMetric):\n    r\"\"\"Compute Peak Signal To Noise Ratio between two tensors using function:\n\n    .. math::\n        \\operatorname{PSNR}\\left(Y, \\hat{Y}\\right) = 20 \\cdot \\log_{10} \\left({\\mathit{MAX}}_Y\\right) \\\n        -10 \\cdot \\log_{10}\\left(\\operatorname{MSE\\left(Y, \\hat{Y}\\right)}\\right)\n\n    More info: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio\n\n    Help taken from:\n    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/image_ops_impl.py line 4139\n\n    Input `y_pred` is compared with ground truth `y`.\n    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        max_val: The dynamic range of the images/volumes (i.e., the difference between the\n            maximum and the minimum allowed values e.g. 255 for a uint8 image).\n        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n\n    \"\"\"\n\n    def __init__(\n        self, max_val: int | float, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False\n    ) -> None:\n        super().__init__(reduction=reduction, get_not_nans=get_not_nans)\n        self.max_val = max_val\n        self.sq_func = partial(torch.pow, exponent=2.0)\n\n    def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> Any:\n        mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func)\n        return 20 * math.log10(self.max_val) - 10 * torch.log10(mse_out)\n\n\ndef compute_mean_error_metrics(y_pred: torch.Tensor, y: torch.Tensor, func: Callable) -> torch.Tensor:\n    # reducing in only channel + spatial dimensions (not batch)\n    # reduction of batch handled inside __call__() using do_metric_reduction() in respective calling class\n    flt = partial(torch.flatten, start_dim=1)\n    return torch.mean(flt(func(y - y_pred)), dim=-1, keepdim=True)\n\n\ndef compute_mape_metric(y_pred: torch.Tensor, y: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:\n    \"\"\"\n    Compute Mean Absolute Percentage Error.\n\n    Args:\n        y_pred: predicted values\n        y: ground truth values\n        epsilon: small value to avoid division by zero\n\n    Returns:\n        MAPE value as percentage\n    \"\"\"\n    flt = partial(torch.flatten, start_dim=1)\n    percentage_error = torch.abs(y - y_pred) / torch.clamp(torch.abs(y), min=epsilon) * 100.0\n    return torch.mean(flt(percentage_error), dim=-1, keepdim=True)\n\n\nclass KernelType(StrEnum):\n    GAUSSIAN = \"gaussian\"\n    UNIFORM = \"uniform\"\n\n\nclass SSIMMetric(RegressionMetric):\n    r\"\"\"\n    Computes the Structural Similarity Index Measure (SSIM).\n\n    .. 
math::\n        \\operatorname {SSIM}(x,y) =\\frac {(2 \\mu_x \\mu_y + c_1)(2 \\sigma_{xy} + c_2)}{((\\mu_x^2 + \\\n                \\mu_y^2 + c_1)(\\sigma_x^2 + \\sigma_y^2 + c_2))}\n\n    For more info, visit\n        https://vicuesoft.com/glossary/term/ssim-ms-ssim/\n\n    SSIM reference paper:\n        Wang, Zhou, et al. \"Image quality assessment: from error visibility to structural\n        similarity.\" IEEE transactions on image processing 13.4 (2004): 600-612.\n\n    Args:\n        spatial_dims: number of spatial dimensions of the input images.\n        data_range: value range of input images. (usually 1.0 or 255)\n        kernel_type: type of kernel, can be \"gaussian\" or \"uniform\".\n        win_size: window size of kernel\n        kernel_sigma: standard deviation for Gaussian kernel.\n        k1: stability constant used in the luminance denominator\n        k2: stability constant used in the contrast denominator\n        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans)\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        data_range: float = 1.0,\n        kernel_type: KernelType | str = KernelType.GAUSSIAN,\n        win_size: int | Sequence[int] = 11,\n        kernel_sigma: float | Sequence[float] = 1.5,\n        k1: float = 0.01,\n        k2: float = 0.03,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n    ) -> None:\n        super().__init__(reduction=reduction, get_not_nans=get_not_nans)\n\n        self.spatial_dims = spatial_dims\n        self.data_range = data_range\n        self.kernel_type = kernel_type\n\n        if not isinstance(win_size, Sequence):\n            win_size = ensure_tuple_rep(win_size, spatial_dims)\n        self.kernel_size = win_size\n\n        if not isinstance(kernel_sigma, Sequence):\n            kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims)\n        self.kernel_sigma = kernel_sigma\n\n        self.k1 = k1\n        self.k2 = k2\n\n    def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            y_pred: Predicted image.\n                It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].\n            y: Reference image.\n                It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].\n\n        Raises:\n            ValueError: when `y_pred` is not a 2D or 3D image.\n        \"\"\"\n        dims = y_pred.ndimension()\n        if self.spatial_dims == 2 and dims != 4:\n            raise ValueError(\n                f\"y_pred should have 4 dimensions (batch, channel, height, width) when using {self.spatial_dims} \"\n                f\"spatial dimensions, got {dims}.\"\n            )\n\n        if self.spatial_dims == 3 and dims != 5:\n            raise ValueError(\n 
               f\"y_pred should have 5 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}\"\n                f\" spatial dimensions, got {dims}.\"\n            )\n\n        ssim_value_full_image, _ = compute_ssim_and_cs(\n            y_pred=y_pred,\n            y=y,\n            spatial_dims=self.spatial_dims,\n            data_range=self.data_range,\n            kernel_type=self.kernel_type,\n            kernel_size=self.kernel_size,\n            kernel_sigma=self.kernel_sigma,\n            k1=self.k1,\n            k2=self.k2,\n        )\n\n        ssim_per_batch: torch.Tensor = ssim_value_full_image.view(ssim_value_full_image.shape[0], -1).mean(\n            1, keepdim=True\n        )\n\n        return ssim_per_batch\n\n\ndef _gaussian_kernel(\n    spatial_dims: int, num_channels: int, kernel_size: Sequence[int], kernel_sigma: Sequence[float]\n) -> torch.Tensor:\n    \"\"\"Computes 2D or 3D gaussian kernel.\n\n    Args:\n        spatial_dims: number of spatial dimensions of the input images.\n        num_channels: number of channels in the image\n        kernel_size: size of kernel\n        kernel_sigma: standard deviation for Gaussian kernel.\n    \"\"\"\n\n    def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor:\n        \"\"\"Computes 1D gaussian kernel.\n\n        Args:\n            kernel_size: size of the gaussian kernel\n            sigma: Standard deviation of the gaussian kernel\n        \"\"\"\n        dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1)\n        gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)\n        return (gauss / gauss.sum()).unsqueeze(dim=0)\n\n    gaussian_kernel_x = gaussian_1d(kernel_size[0], kernel_sigma[0])\n    gaussian_kernel_y = gaussian_1d(kernel_size[1], kernel_sigma[1])\n    kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y)  # (kernel_size, 1) * (1, kernel_size)\n\n    kernel_dimensions: tuple[int, ...] 
= (num_channels, 1, kernel_size[0], kernel_size[1])\n\n    if spatial_dims == 3:\n        gaussian_kernel_z = gaussian_1d(kernel_size[2], kernel_sigma[2])[None,]\n        kernel = torch.mul(\n            kernel.unsqueeze(-1).repeat(1, 1, kernel_size[2]),\n            gaussian_kernel_z.expand(kernel_size[0], kernel_size[1], kernel_size[2]),\n        )\n        kernel_dimensions = (num_channels, 1, kernel_size[0], kernel_size[1], kernel_size[2])\n\n    return kernel.expand(kernel_dimensions)\n\n\ndef compute_ssim_and_cs(\n    y_pred: torch.Tensor,\n    y: torch.Tensor,\n    spatial_dims: int,\n    kernel_size: Sequence[int],\n    kernel_sigma: Sequence[float],\n    data_range: float = 1.0,\n    kernel_type: KernelType | str = KernelType.GAUSSIAN,\n    k1: float = 0.01,\n    k2: float = 0.03,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Function to compute the Structural Similarity Index Measure (SSIM) and Contrast Sensitivity (CS) for a batch\n    of images.\n\n    Args:\n        y_pred: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])\n        y: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])\n        kernel_size: the size of the kernel to use for the SSIM computation.\n        kernel_sigma: the standard deviation of the kernel to use for the SSIM computation.\n        spatial_dims: number of spatial dimensions of the images (2, 3)\n        data_range: the data range of the images.\n        kernel_type: the type of kernel to use for the SSIM computation. 
Can be either \"gaussian\" or \"uniform\".\n        k1: the first stability constant.\n        k2: the second stability constant.\n\n    Returns:\n        ssim: the Structural Similarity Index Measure score for the batch of images.\n        cs: the Contrast Sensitivity for the batch of images.\n    \"\"\"\n    if y.shape != y_pred.shape:\n        raise ValueError(f\"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.\")\n\n    y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0]\n    y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0]\n\n    num_channels = y_pred.size(1)\n\n    if kernel_type == KernelType.GAUSSIAN:\n        kernel = _gaussian_kernel(spatial_dims, num_channels, kernel_size, kernel_sigma)\n    elif kernel_type == KernelType.UNIFORM:\n        kernel = torch.ones((num_channels, 1, *kernel_size)) / torch.prod(torch.tensor(kernel_size))\n\n    kernel = convert_to_dst_type(src=kernel, dst=y_pred)[0]\n\n    c1 = (k1 * data_range) ** 2  # stability constant for luminance\n    c2 = (k2 * data_range) ** 2  # stability constant for contrast\n\n    conv_fn = getattr(F, f\"conv{spatial_dims}d\")\n    mu_x = conv_fn(y_pred, kernel, groups=num_channels)\n    mu_y = conv_fn(y, kernel, groups=num_channels)\n    mu_xx = conv_fn(y_pred * y_pred, kernel, groups=num_channels)\n    mu_yy = conv_fn(y * y, kernel, groups=num_channels)\n    mu_xy = conv_fn(y_pred * y, kernel, groups=num_channels)\n\n    sigma_x = mu_xx - mu_x * mu_x\n    sigma_y = mu_yy - mu_y * mu_y\n    sigma_xy = mu_xy - mu_x * mu_y\n\n    contrast_sensitivity = (2 * sigma_xy + c2) / (sigma_x + sigma_y + c2)\n    ssim_value_full_image = ((2 * mu_x * mu_y + c1) / (mu_x**2 + mu_y**2 + c1)) * contrast_sensitivity\n\n    return ssim_value_full_image, contrast_sensitivity\n\n\nclass MultiScaleSSIMMetric(RegressionMetric):\n    \"\"\"\n    Computes the Multi-Scale Structural Similarity Index Measure (MS-SSIM).\n\n    MS-SSIM reference 
paper:\n        Wang, Z., Simoncelli, E.P. and Bovik, A.C., 2003, November. \"Multiscale structural\n        similarity for image quality assessment.\" In The Thirty-Seventh Asilomar Conference\n        on Signals, Systems & Computers, 2003 (Vol. 2, pp. 1398-1402). IEEE\n\n    Args:\n        spatial_dims: number of spatial dimensions of the input images.\n        data_range: value range of input images. (usually 1.0 or 255)\n        kernel_type: type of kernel, can be \"gaussian\" or \"uniform\".\n        kernel_size: size of kernel\n        kernel_sigma: standard deviation for Gaussian kernel.\n        k1: stability constant used in the luminance denominator\n        k2: stability constant used in the contrast denominator\n        weights: parameters for image similarity and contrast sensitivity at different resolution scores.\n        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans)\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        data_range: float = 1.0,\n        kernel_type: KernelType | str = KernelType.GAUSSIAN,\n        kernel_size: int | Sequence[int] = 11,\n        kernel_sigma: float | Sequence[float] = 1.5,\n        k1: float = 0.01,\n        k2: float = 0.03,\n        weights: Sequence[float] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n    ) -> None:\n        super().__init__(reduction=reduction, get_not_nans=get_not_nans)\n\n        self.spatial_dims = spatial_dims\n        self.data_range = data_range\n        self.kernel_type = kernel_type\n\n        if not isinstance(kernel_size, Sequence):\n            kernel_size = ensure_tuple_rep(kernel_size, spatial_dims)\n        self.kernel_size = kernel_size\n\n        if not isinstance(kernel_sigma, Sequence):\n            kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims)\n        self.kernel_sigma = kernel_sigma\n\n        self.k1 = k1\n        self.k2 = k2\n        self.weights = weights\n\n    def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n        return compute_ms_ssim(\n            y_pred=y_pred,\n            y=y,\n            spatial_dims=self.spatial_dims,\n            data_range=self.data_range,\n            kernel_type=self.kernel_type,\n            kernel_size=self.kernel_size,\n            kernel_sigma=self.kernel_sigma,\n            k1=self.k1,\n            k2=self.k2,\n            weights=self.weights,\n        )\n\n\ndef compute_ms_ssim(\n    y_pred: torch.Tensor,\n    y: torch.Tensor,\n    spatial_dims: int,\n    data_range: float = 1.0,\n    kernel_type: KernelType | str = KernelType.GAUSSIAN,\n    kernel_size: int | Sequence[int] = 11,\n    
kernel_sigma: float | Sequence[float] = 1.5,\n    k1: float = 0.01,\n    k2: float = 0.03,\n    weights: Sequence[float] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),\n) -> torch.Tensor:\n    \"\"\"\n    Args:\n        y_pred: Predicted image.\n            It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].\n        y: Reference image.\n            It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].\n        spatial_dims: number of spatial dimensions of the input images.\n        data_range: value range of input images. (usually 1.0 or 255)\n        kernel_type: type of kernel, can be \"gaussian\" or \"uniform\".\n        kernel_size: size of kernel\n        kernel_sigma: standard deviation for Gaussian kernel.\n        k1: stability constant used in the luminance denominator\n        k2: stability constant used in the contrast denominator\n        weights: parameters for image similarity and contrast sensitivity at different resolution scores.\n    Raises:\n        ValueError: when `y_pred` is not a 2D or 3D image.\n    \"\"\"\n    dims = y_pred.ndimension()\n    if spatial_dims == 2 and dims != 4:\n        raise ValueError(\n            f\"y_pred should have 4 dimensions (batch, channel, height, width) when using {spatial_dims} \"\n            f\"spatial dimensions, got {dims}.\"\n        )\n\n    if spatial_dims == 3 and dims != 5:\n        raise ValueError(\n            f\"y_pred should have 5 dimensions (batch, channel, height, width, depth) when using {spatial_dims}\"\n            f\" spatial dimensions, got {dims}.\"\n        )\n\n    if not isinstance(kernel_size, Sequence):\n        kernel_size = ensure_tuple_rep(kernel_size, spatial_dims)\n\n    if not isinstance(kernel_sigma, Sequence):\n        kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims)\n    # check if image have enough size for the number of downsamplings and the size of the kernel\n    weights_div = max(1, (len(weights) - 1)) ** 2\n    y_pred_spatial_dims = 
y_pred.shape[2:]\n    for i in range(len(y_pred_spatial_dims)):\n        if y_pred_spatial_dims[i] // weights_div <= kernel_size[i] - 1:\n            raise ValueError(\n                f\"For a given number of `weights` parameters {len(weights)} and kernel size \"\n                f\"{kernel_size[i]}, the image height must be larger than \"\n                f\"{(kernel_size[i] - 1) * weights_div}.\"\n            )\n\n    weights_tensor = torch.tensor(weights, device=y_pred.device, dtype=torch.float)\n\n    avg_pool = getattr(F, f\"avg_pool{spatial_dims}d\")\n\n    multiscale_list: list[torch.Tensor] = []\n    for _ in range(len(weights_tensor)):\n        ssim, cs = compute_ssim_and_cs(\n            y_pred=y_pred,\n            y=y,\n            spatial_dims=spatial_dims,\n            data_range=data_range,\n            kernel_type=kernel_type,\n            kernel_size=kernel_size,\n            kernel_sigma=kernel_sigma,\n            k1=k1,\n            k2=k2,\n        )\n\n        cs_per_batch = cs.view(cs.shape[0], -1).mean(1)\n\n        multiscale_list.append(torch.relu(cs_per_batch))\n        y_pred = avg_pool(y_pred, kernel_size=2)\n        y = avg_pool(y, kernel_size=2)\n\n    ssim = ssim.view(ssim.shape[0], -1).mean(1)\n    multiscale_list[-1] = torch.relu(ssim)\n    multiscale_list_tensor = torch.stack(multiscale_list)\n\n    ms_ssim_value_full_image = torch.prod(multiscale_list_tensor ** weights_tensor.view(-1, 1), dim=0)\n\n    ms_ssim_per_batch: torch.Tensor = ms_ssim_value_full_image.view(ms_ssim_value_full_image.shape[0], -1).mean(\n        1, keepdim=True\n    )\n\n    return ms_ssim_per_batch\n"
  },
  {
    "path": "monai/metrics/rocauc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom typing import TYPE_CHECKING, cast\n\nimport numpy as np\n\nif TYPE_CHECKING:\n    import numpy.typing as npt\n\nimport torch\n\nfrom monai.utils import Average, look_up_option\n\nfrom .metric import CumulativeIterationMetric\n\n\nclass ROCAUCMetric(CumulativeIterationMetric):\n    \"\"\"\n    Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). 
Referring to:\n    `sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/\n    sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.\n    The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor.\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        average: {``\"macro\"``, ``\"weighted\"``, ``\"micro\"``, ``\"none\"``}\n            Type of averaging performed if not binary classification.\n            Defaults to ``\"macro\"``.\n\n            - ``\"macro\"``: calculate metrics for each label, and find their unweighted mean.\n                This does not take label imbalance into account.\n            - ``\"weighted\"``: calculate metrics for each label, and find their average,\n                weighted by support (the number of true instances for each label).\n            - ``\"micro\"``: calculate metrics globally by considering each element of the label\n                indicator matrix as a label.\n            - ``\"none\"``: the scores for each class are returned.\n\n    \"\"\"\n\n    def __init__(self, average: Average | str = Average.MACRO) -> None:\n        super().__init__()\n        self.average = average\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]\n        return y_pred, y\n\n    def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike:\n        \"\"\"\n        Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration,\n        This function reads the buffers and computes the area under the ROC.\n\n        Args:\n            average: {``\"macro\"``, ``\"weighted\"``, ``\"micro\"``, ``\"none\"``}\n                Type of averaging performed if not binary classification. 
Defaults to `self.average`.\n\n        \"\"\"\n        y_pred, y = self.get_buffer()\n        # compute final value and do metric reduction\n        if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):\n            raise ValueError(\"y_pred and y must be PyTorch Tensor.\")\n\n        return compute_roc_auc(y_pred=y_pred, y=y, average=average or self.average)\n\n\ndef _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:\n    if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)):\n        raise AssertionError(\"y and y_pred must be 1 dimension data with same length.\")\n    y_unique = y.unique()\n    if len(y_unique) == 1:\n        warnings.warn(f\"y values can not be all {y_unique.item()}, skip AUC computation and return `Nan`.\")\n        return float(\"nan\")\n    if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)):\n        warnings.warn(f\"y values must be 0 or 1, but in {y_unique.tolist()}, skip AUC computation and return `Nan`.\")\n        return float(\"nan\")\n\n    n = len(y)\n    indices = y_pred.argsort()\n    y = y[indices].cpu().numpy()  # type: ignore[assignment]\n    y_pred = y_pred[indices].cpu().numpy()  # type: ignore[assignment]\n    nneg = auc = tmp_pos = tmp_neg = 0.0\n\n    for i in range(n):\n        y_i = cast(float, y[i])\n        if i + 1 < n and y_pred[i] == y_pred[i + 1]:\n            tmp_pos += y_i\n            tmp_neg += 1 - y_i\n            continue\n        if tmp_pos + tmp_neg > 0:\n            tmp_pos += y_i\n            tmp_neg += 1 - y_i\n            nneg += tmp_neg\n            auc += tmp_pos * (nneg - tmp_neg / 2)\n            tmp_pos = tmp_neg = 0\n            continue\n        if y_i == 1:\n            auc += nneg\n        else:\n            nneg += 1\n    return auc / (nneg * (n - nneg))\n\n\ndef compute_roc_auc(\n    y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO\n) -> np.ndarray | float | npt.ArrayLike:\n    
\"\"\"Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to:\n    `sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/\n    sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.\n\n    Args:\n        y_pred: input data to compute, typical classification model output.\n            the first dim must be batch, if multi-classes, it must be in One-Hot format.\n            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.\n        y: ground truth to compute ROC AUC metric, the first dim must be batch.\n            if multi-classes, it must be in One-Hot format.\n            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.\n        average: {``\"macro\"``, ``\"weighted\"``, ``\"micro\"``, ``\"none\"``}\n            Type of averaging performed if not binary classification.\n            Defaults to ``\"macro\"``.\n\n            - ``\"macro\"``: calculate metrics for each label, and find their unweighted mean.\n                This does not take label imbalance into account.\n            - ``\"weighted\"``: calculate metrics for each label, and find their average,\n                weighted by support (the number of true instances for each label).\n            - ``\"micro\"``: calculate metrics globally by considering each element of the label\n                indicator matrix as a label.\n            - ``\"none\"``: the scores for each class are returned.\n\n    Raises:\n        ValueError: When ``y_pred`` dimension is not one of [1, 2].\n        ValueError: When ``y`` dimension is not one of [1, 2].\n        ValueError: When ``average`` is not one of [\"macro\", \"weighted\", \"micro\", \"none\"].\n\n    Note:\n        ROCAUC expects y to be comprised of 0's and 1's. `y_pred` must be either prob. 
estimates or confidence values.\n\n    \"\"\"\n    y_pred_ndim = y_pred.ndimension()\n    y_ndim = y.ndimension()\n    if y_pred_ndim not in (1, 2):\n        raise ValueError(\n            f\"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}.\"\n        )\n    if y_ndim not in (1, 2):\n        raise ValueError(f\"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.\")\n    if y_pred_ndim == 2 and y_pred.shape[1] == 1:\n        y_pred = y_pred.squeeze(dim=-1)\n        y_pred_ndim = 1\n    if y_ndim == 2 and y.shape[1] == 1:\n        y = y.squeeze(dim=-1)\n\n    if y_pred_ndim == 1:\n        return _calculate(y_pred, y)\n\n    if y.shape != y_pred.shape:\n        raise ValueError(f\"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.\")\n\n    average = look_up_option(average, Average)\n    if average == Average.MICRO:\n        return _calculate(y_pred.flatten(), y.flatten())\n    y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)\n    auc_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)]\n    if average == Average.NONE:\n        return auc_values\n    if average == Average.MACRO:\n        return np.mean(auc_values)\n    if average == Average.WEIGHTED:\n        weights = [sum(y_) for y_ in y]\n        return np.average(auc_values, weights=weights)  # type: ignore[no-any-return]\n    raise ValueError(f'Unsupported average: {average}, available options are [\"macro\", \"weighted\", \"micro\", \"none\"].')\n"
  },
  {
    "path": "monai/metrics/surface_dice.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing\nfrom monai.utils import MetricReduction\n\nfrom .metric import CumulativeIterationMetric\n\n\nclass SurfaceDiceMetric(CumulativeIterationMetric):\n    \"\"\"\n    Computes the Normalized Surface Dice (NSD) for each batch sample and class of\n    predicted segmentations `y_pred` and corresponding reference segmentations `y` according to equation :eq:`nsd`.\n    This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images.\n    Be aware that by default (`use_subvoxels=False`), the computation of boundaries is different from DeepMind's\n    implementation https://github.com/deepmind/surface-distance.\n    In this implementation, the length/area of a segmentation boundary is\n    interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary\n    depends on the local neighborhood (cf. 
https://arxiv.org/abs/1809.04430).\n    This issue is discussed here: https://github.com/Project-MONAI/MONAI/issues/4103.\n\n    The class- and batch sample-wise NSD values can be aggregated with the function `aggregate`.\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        class_thresholds: List of class-specific thresholds.\n            The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels.\n            Each threshold needs to be a finite, non-negative number.\n        include_background: Whether to include NSD computation on the first channel of the predicted output.\n            Defaults to ``False``.\n        distance_metric: The metric used to compute surface distances.\n            One of [``\"euclidean\"``, ``\"chessboard\"``, ``\"taxicab\"``].\n            Defaults to ``\"euclidean\"``.\n        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count.\n            Defaults to ``False``.\n            `not_nans` is the number of batch samples for which not all class-specific NSD values were nan values.\n            If set to ``True``, the function `aggregate` will return both the aggregated NSD and the `not_nans` count.\n            If set to ``False``, `aggregate` will only return the aggregated NSD.\n        use_subvoxels: Whether to use subvoxel distances. 
Defaults to ``False``.\n    \"\"\"\n\n    def __init__(\n        self,\n        class_thresholds: list[float],\n        include_background: bool = False,\n        distance_metric: str = \"euclidean\",\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n        use_subvoxels: bool = False,\n    ) -> None:\n        super().__init__()\n        self.class_thresholds = class_thresholds\n        self.include_background = include_background\n        self.distance_metric = distance_metric\n        self.reduction = reduction\n        self.get_not_nans = get_not_nans\n        self.use_subvoxels = use_subvoxels\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor:  # type: ignore[override]\n        r\"\"\"\n        Args:\n            y_pred: Predicted segmentation, typically segmentation model output.\n                It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].\n            y: Reference segmentation.\n                It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].\n            kwargs: additional parameters: ``spacing`` should be passed to correctly compute the metric.\n                ``spacing``: spacing of pixel (or voxel). This parameter is relevant only\n                if ``distance_metric`` is set to ``\"euclidean\"``.\n                If a single number, isotropic spacing with that value is used for all images in the batch. 
If a sequence of numbers,\n                the length of the sequence must be equal to the image dimensions.\n                This spacing will be used for all images in the batch.\n                If a sequence of sequences, the length of the outer sequence must be equal to the batch size.\n                If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,\n                else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used\n                for all images in batch. Defaults to ``None``.\n                use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``.\n\n\n        Returns:\n            Pytorch Tensor of shape [B,C], containing the NSD values :math:`\\operatorname {NSD}_{b,c}` for each batch\n            index :math:`b` and class :math:`c`.\n        \"\"\"\n        return compute_surface_dice(\n            y_pred=y_pred,\n            y=y,\n            class_thresholds=self.class_thresholds,\n            include_background=self.include_background,\n            distance_metric=self.distance_metric,\n            spacing=kwargs.get(\"spacing\"),\n            use_subvoxels=self.use_subvoxels,\n        )\n\n    def aggregate(\n        self, reduction: MetricReduction | str | None = None\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        r\"\"\"\n        Aggregates the output of `_compute_tensor`.\n\n        Args:\n            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to `self.reduction`. 
if \"none\", will not do reduction.\n\n        Returns:\n            If `get_not_nans` is set to ``True``, this function returns the aggregated NSD and the `not_nans` count.\n            If `get_not_nans` is set to ``False``, this function returns only the aggregated NSD.\n        \"\"\"\n        data = self.get_buffer()\n        if not isinstance(data, torch.Tensor):\n            raise ValueError(\"the data to aggregate must be PyTorch Tensor.\")\n\n        # do metric reduction\n        f, not_nans = do_metric_reduction(data, reduction or self.reduction)\n        return (f, not_nans) if self.get_not_nans else f\n\n\ndef compute_surface_dice(\n    y_pred: torch.Tensor,\n    y: torch.Tensor,\n    class_thresholds: list[float],\n    include_background: bool = False,\n    distance_metric: str = \"euclidean\",\n    spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None,\n    use_subvoxels: bool = False,\n) -> torch.Tensor:\n    r\"\"\"\n    This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as\n    :math:`\\hat{Y}`) and `y` (referred to as :math:`Y`). This metric determines which fraction of a segmentation\n    boundary is correctly predicted. A boundary element is considered correctly predicted if the closest distance to the\n    reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation\n    in pixels. The NSD is bounded between 0 and 1.\n\n    This implementation supports multi-class tasks with an individual threshold :math:`\\tau_c` for each class :math:`c`.\n    The class-specific NSD for batch index :math:`b`, :math:`\\operatorname {NSD}_{b,c}`, is computed using the function:\n\n    .. 
math::\n        \\operatorname {NSD}_{b,c} \\left(Y_{b,c}, \\hat{Y}_{b,c}\\right) = \\frac{\\left|\\mathcal{D}_{Y_{b,c}}^{'}\\right| +\n        \\left| \\mathcal{D}_{\\hat{Y}_{b,c}}^{'} \\right|}{\\left|\\mathcal{D}_{Y_{b,c}}\\right| +\n        \\left|\\mathcal{D}_{\\hat{Y}_{b,c}}\\right|}\n        :label: nsd\n\n    with :math:`\\mathcal{D}_{Y_{b,c}}` and :math:`\\mathcal{D}_{\\hat{Y}_{b,c}}` being two sets of nearest-neighbor\n    distances. :math:`\\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference\n    segmentation boundary and vice-versa for :math:`\\mathcal{D}_{\\hat{Y}_{b,c}}`. :math:`\\mathcal{D}_{Y_{b,c}}^{'}` and\n    :math:`\\mathcal{D}_{\\hat{Y}_{b,c}}^{'}` refer to the subsets of distances that are smaller or equal to the\n    acceptable distance :math:`\\tau_c`:\n\n    .. math::\n        \\mathcal{D}_{Y_{b,c}}^{'} = \\{ d \\in \\mathcal{D}_{Y_{b,c}} \\, | \\, d \\leq \\tau_c \\}.\n\n\n    In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation,\n    a nan value will be returned for this class. 
In the case of a class being present in only one of predicted\n    segmentation or reference segmentation, the class NSD will be 0.\n\n    This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images.\n    The computation of boundaries follows DeepMind's implementation\n    https://github.com/deepmind/surface-distance when `use_subvoxels=True`; Otherwise the length of a segmentation\n    boundary is interpreted as the number of its edge pixels.\n\n    Args:\n        y_pred: Predicted segmentation, typically segmentation model output.\n            It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].\n        y: Reference segmentation.\n            It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].\n        class_thresholds: List of class-specific thresholds.\n            The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels.\n            Each threshold needs to be a finite, non-negative number.\n        include_background: Whether to include the surface dice computation on the first channel of\n            the predicted output. Defaults to ``False``.\n        distance_metric: The metric used to compute surface distances.\n            One of [``\"euclidean\"``, ``\"chessboard\"``, ``\"taxicab\"``].\n            Defaults to ``\"euclidean\"``.\n        spacing: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``\"euclidean\"``.\n            If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers,\n            the length of the sequence must be equal to the image dimensions. 
This spacing will be used for all images in the batch.\n            If a sequence of sequences, the length of the outer sequence must be equal to the batch size.\n            If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,\n            else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used\n            for all images in batch. Defaults to ``None``.\n        use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``.\n\n    Raises:\n        ValueError: If `y_pred` and/or `y` are not PyTorch tensors.\n        ValueError: If `y_pred` and/or `y` do not have four or five dimensions.\n        ValueError: If `y_pred` and/or `y` have different shapes.\n        ValueError: If `y_pred` and/or `y` are not one-hot encoded.\n        ValueError: If the number of channels of `y_pred` and/or `y` is different from the number of class thresholds.\n        ValueError: If any class threshold is not finite.\n        ValueError: If any class threshold is negative.\n\n    Returns:\n        Pytorch Tensor of shape [B,C], containing the NSD values :math:`\\operatorname {NSD}_{b,c}` for each batch index\n        :math:`b` and class :math:`c`.\n    \"\"\"\n\n    if not include_background:\n        y_pred, y = ignore_background(y_pred=y_pred, y=y)\n\n    if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):\n        raise ValueError(\"y_pred and y must be PyTorch Tensor.\")\n\n    if y_pred.ndimension() not in (4, 5) or y.ndimension() not in (4, 5):\n        raise ValueError(\"y_pred and y should be one-hot encoded: [B,C,H,W] or [B,C,H,W,D].\")\n\n    if y_pred.shape != y.shape:\n        raise ValueError(\n            f\"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y).\"\n        )\n\n    batch_size, n_class = y_pred.shape[:2]\n\n    if n_class != len(class_thresholds):\n        raise ValueError(\n          
  f\"number of classes ({n_class}) does not match number of class thresholds ({len(class_thresholds)}).\"\n        )\n\n    if any(~np.isfinite(class_thresholds)):\n        raise ValueError(\"All class thresholds need to be finite.\")\n\n    if any(np.array(class_thresholds) < 0):\n        raise ValueError(\"All class thresholds need to be >= 0.\")\n\n    nsd = torch.empty((batch_size, n_class), device=y_pred.device, dtype=torch.float)\n\n    img_dim = y_pred.ndim - 2\n    spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)\n\n    for b, c in np.ndindex(batch_size, n_class):\n        (edges_pred, edges_gt), (distances_pred_gt, distances_gt_pred), areas = get_edge_surface_distance(  # type: ignore\n            y_pred[b, c],\n            y[b, c],\n            distance_metric=distance_metric,\n            spacing=spacing_list[b],\n            use_subvoxels=use_subvoxels,\n            symmetric=True,\n            class_index=c,\n        )\n        boundary_correct: int | torch.Tensor | float\n        boundary_complete: int | torch.Tensor | float\n        if not use_subvoxels:\n            boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)\n            boundary_correct = torch.sum(distances_pred_gt <= class_thresholds[c]) + torch.sum(\n                distances_gt_pred <= class_thresholds[c]\n            )\n        else:\n            areas_pred, areas_gt = areas  # type: ignore\n            areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred]\n            boundary_complete = areas_gt.sum() + areas_pred.sum()\n            gt_true = areas_gt[distances_gt_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0\n            pred_true = areas_pred[distances_pred_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0\n            boundary_correct = gt_true + pred_true\n        if boundary_complete == 0:\n            # the class is neither present in the prediction, nor in the reference 
segmentation\n            nsd[b, c] = torch.tensor(np.nan)\n        else:\n            nsd[b, c] = boundary_correct / boundary_complete\n\n    return nsd\n"
  },
  {
    "path": "monai/metrics/surface_distance.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing\nfrom monai.utils import MetricReduction, convert_data_type\n\nfrom .metric import CumulativeIterationMetric\n\n\nclass SurfaceDistanceMetric(CumulativeIterationMetric):\n    \"\"\"\n    Compute Surface Distance between two tensors. It can support both multi-classes and multi-labels tasks.\n    It supports both symmetric and asymmetric surface distance calculation.\n    Input `y_pred` is compared with ground truth `y`.\n    `y_preds` is expected to have binarized predictions and `y` should be in one-hot format.\n    You can use suitable transforms in ``monai.transforms.post`` first to achieve binarized values.\n    `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]).\n\n    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.\n\n    Args:\n        include_background: whether to include distance computation on the first channel of\n            the predicted output. Defaults to ``False``.\n        symmetric: whether to calculate the symmetric average surface distance between\n            `seg_pred` and `seg_gt`. 
Defaults to ``False``.\n        distance_metric: : [``\"euclidean\"``, ``\"chessboard\"``, ``\"taxicab\"``]\n            the metric used to compute surface distance. Defaults to ``\"euclidean\"``.\n        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n            Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        include_background: bool = False,\n        symmetric: bool = False,\n        distance_metric: str = \"euclidean\",\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n    ) -> None:\n        super().__init__()\n        self.include_background = include_background\n        self.distance_metric = distance_metric\n        self.symmetric = symmetric\n        self.reduction = reduction\n        self.get_not_nans = get_not_nans\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor:  # type: ignore[override]\n        \"\"\"\n        Args:\n            y_pred: input data to compute, typical segmentation model output.\n                It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values\n                should be binarized.\n            y: ground truth to compute the distance. It must be one-hot format and first dim is batch.\n                The values should be binarized.\n            kwargs: additional parameters, e.g. 
``spacing`` should be passed to correctly compute the metric.\n                ``spacing``: spacing of pixel (or voxel). This parameter is relevant only\n                if ``distance_metric`` is set to ``\"euclidean\"``.\n                If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers,\n                the length of the sequence must be equal to the image dimensions.\n                This spacing will be used for all images in the batch.\n                If a sequence of sequences, the length of the outer sequence must be equal to the batch size.\n                If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,\n                else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used\n                for all images in batch. Defaults to ``None``.\n\n        Raises:\n            ValueError: when `y_pred` has less than three dimensions.\n        \"\"\"\n        if y_pred.dim() < 3:\n            raise ValueError(\"y_pred should have at least three dimensions.\")\n\n        # compute (BxC) for each channel for each batch\n        return compute_average_surface_distance(\n            y_pred=y_pred,\n            y=y,\n            include_background=self.include_background,\n            symmetric=self.symmetric,\n            distance_metric=self.distance_metric,\n            spacing=kwargs.get(\"spacing\"),\n        )\n\n    def aggregate(\n        self, reduction: MetricReduction | str | None = None\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Execute reduction logic for the output of `compute_average_surface_distance`.\n\n        Args:\n            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n                available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, 
``\"sum_batch\"``,\n                ``\"mean_channel\"``, ``\"sum_channel\"``}, default to `self.reduction`. if \"none\", will not do reduction.\n\n        \"\"\"\n        data = self.get_buffer()\n        if not isinstance(data, torch.Tensor):\n            raise ValueError(\"the data to aggregate must be PyTorch Tensor.\")\n\n        # do metric reduction\n        f, not_nans = do_metric_reduction(data, reduction or self.reduction)\n        return (f, not_nans) if self.get_not_nans else f\n\n\ndef compute_average_surface_distance(\n    y_pred: np.ndarray | torch.Tensor,\n    y: np.ndarray | torch.Tensor,\n    include_background: bool = False,\n    symmetric: bool = False,\n    distance_metric: str = \"euclidean\",\n    spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None,\n) -> torch.Tensor:\n    \"\"\"\n    This function is used to compute the Average Surface Distance from `y_pred` to `y`\n    under the default setting.\n    In addition, if ``symmetric = True`` is set, the average symmetric surface distance between\n    these two inputs will be returned.\n    The implementation refers to `DeepMind's implementation <https://github.com/deepmind/surface-distance>`_.\n\n    Args:\n        y_pred: input data to compute, typical segmentation model output.\n            It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values\n            should be binarized.\n        y: ground truth to compute the mean distance. It must be one-hot format and first dim is batch.\n            The values should be binarized.\n        include_background: whether to include distance computation on the first channel of\n            the predicted output. Defaults to ``False``.\n        symmetric: whether to calculate the symmetric average surface distance between\n            `seg_pred` and `seg_gt`. 
Defaults to ``False``.\n        distance_metric: : [``\"euclidean\"``, ``\"chessboard\"``, ``\"taxicab\"``]\n            the metric used to compute surface distance. Defaults to ``\"euclidean\"``.\n        spacing: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``\"euclidean\"``.\n            If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers,\n            the length of the sequence must be equal to the image dimensions. This spacing will be used for all images in the batch.\n            If a sequence of sequences, the length of the outer sequence must be equal to the batch size.\n            If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,\n            else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used\n            for all images in batch. Defaults to ``None``.\n    \"\"\"\n\n    if not include_background:\n        y_pred, y = ignore_background(y_pred=y_pred, y=y)\n\n    y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0]\n    y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0]\n\n    if y.shape != y_pred.shape:\n        raise ValueError(f\"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.\")\n\n    batch_size, n_class = y_pred.shape[:2]\n    asd = torch.empty((batch_size, n_class), dtype=torch.float32, device=y_pred.device)\n\n    img_dim = y_pred.ndim - 2\n    spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)\n\n    for b, c in np.ndindex(batch_size, n_class):\n        _, distances, _ = get_edge_surface_distance(\n            y_pred[b, c],\n            y[b, c],\n            distance_metric=distance_metric,\n            spacing=spacing_list[b],\n            symmetric=symmetric,\n            class_index=c,\n        )\n        
surface_distance = torch.cat(distances)\n        asd[b, c] = torch.tensor(np.nan) if surface_distance.shape == (0,) else surface_distance.mean()\n\n    return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]\n"
  },
  {
    "path": "monai/metrics/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Iterable, Sequence\nfrom functools import cache, partial\nfrom types import ModuleType\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.config import NdarrayOrTensor, NdarrayTensor\nfrom monai.transforms.croppad.dictionary import CropForegroundD\nfrom monai.transforms.utils import distance_transform_edt as monai_distance_transform_edt\nfrom monai.utils import (\n    MetricReduction,\n    convert_to_cupy,\n    convert_to_dst_type,\n    convert_to_numpy,\n    convert_to_tensor,\n    deprecated_arg,\n    ensure_tuple_rep,\n    look_up_option,\n    optional_import,\n)\n\nbinary_erosion, _ = optional_import(\"scipy.ndimage\", name=\"binary_erosion\")\ndistance_transform_edt, _ = optional_import(\"scipy.ndimage\", name=\"distance_transform_edt\")\ndistance_transform_cdt, _ = optional_import(\"scipy.ndimage\", name=\"distance_transform_cdt\")\n\n__all__ = [\n    \"ignore_background\",\n    \"do_metric_reduction\",\n    \"get_mask_edges\",\n    \"get_surface_distance\",\n    \"is_binary_tensor\",\n    \"remap_instance_id\",\n    \"prepare_spacing\",\n    \"get_code_to_measure_table\",\n]\n\n\ndef ignore_background(y_pred: NdarrayTensor, y: NdarrayTensor) -> tuple[NdarrayTensor, NdarrayTensor]:\n    \"\"\"\n    This function is used to remove background (the first channel) 
for `y_pred` and `y`.\n\n    Args:\n        y_pred: predictions. As for classification tasks,\n            `y_pred` should have the shape [BN] where N is larger than 1. As for segmentation tasks,\n            the shape should be [BNHW] or [BNHWD].\n        y: ground truth, the first dim is batch.\n\n    \"\"\"\n\n    y = y[:, 1:] if y.shape[1] > 1 else y  # type: ignore[assignment]\n    y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred  # type: ignore[assignment]\n    return y_pred, y\n\n\ndef do_metric_reduction(\n    f: torch.Tensor, reduction: MetricReduction | str = MetricReduction.MEAN\n) -> tuple[torch.Tensor | Any, torch.Tensor]:\n    \"\"\"\n    This function is to do the metric reduction for calculated `not-nan` metrics of each sample's each class.\n    The function also returns `not_nans`, which counts the number of not nans for the metric.\n\n    Args:\n        f: a tensor that contains the calculated metric scores per batch and\n            per class. The first two dims should be batch and class.\n        reduction: define the mode to reduce metrics, will only apply reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``.\n            if \"none\", return the input f tensor and not_nans.\n\n    Raises:\n        ValueError: When ``reduction`` is not one of\n            [\"mean\", \"sum\", \"mean_batch\", \"sum_batch\", \"mean_channel\", \"sum_channel\", \"none\"].\n    \"\"\"\n\n    # some elements might be Nan (if ground truth y was missing (zeros))\n    # we need to account for it\n    nans = torch.isnan(f)\n    not_nans = ~nans\n\n    t_zero = torch.zeros(1, device=f.device, dtype=torch.float)\n    reduction = look_up_option(reduction, MetricReduction)\n    if reduction == MetricReduction.NONE:\n        return f, not_nans.float()\n\n    f[nans] = 0\n    if reduction == 
MetricReduction.MEAN:\n        # 2 steps, first, mean by channel (accounting for nans), then by batch\n        not_nans = not_nans.sum(dim=1).float()\n        f = torch.where(not_nans > 0, f.sum(dim=1).float() / not_nans, t_zero)  # channel average\n\n        not_nans = (not_nans > 0).sum(dim=0).float()\n        f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero)  # batch average\n\n    elif reduction == MetricReduction.SUM:\n        not_nans = not_nans.sum(dim=[0, 1]).float()\n        f = torch.sum(f, dim=[0, 1])  # sum over the batch and channel dims\n    elif reduction == MetricReduction.MEAN_BATCH:\n        not_nans = not_nans.sum(dim=0).float()\n        f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero)  # batch average\n    elif reduction == MetricReduction.SUM_BATCH:\n        not_nans = not_nans.sum(dim=0).float()\n        f = f.sum(dim=0).float()  # the batch sum\n    elif reduction == MetricReduction.MEAN_CHANNEL:\n        not_nans = not_nans.sum(dim=1).float()\n        f = torch.where(not_nans > 0, f.sum(dim=1).float() / not_nans, t_zero)  # channel average\n    elif reduction == MetricReduction.SUM_CHANNEL:\n        not_nans = not_nans.sum(dim=1).float()\n        f = f.sum(dim=1).float()  # the channel sum\n    elif reduction != MetricReduction.NONE:\n        raise ValueError(\n            f\"Unsupported reduction: {reduction}, available options are \"\n            '[\"mean\", \"sum\", \"mean_batch\", \"sum_batch\", \"mean_channel\", \"sum_channel\" \"none\"].'\n        )\n    return f, not_nans\n\n\n@deprecated_arg(\n    name=\"always_return_as_numpy\",\n    since=\"1.5.0\",\n    removed=\"1.7.0\",\n    msg_suffix=\"The option is removed and the return type will always be equal to the input type.\",\n)\ndef get_mask_edges(\n    seg_pred: NdarrayOrTensor,\n    seg_gt: NdarrayOrTensor,\n    label_idx: int = 1,\n    crop: bool = True,\n    spacing: Sequence | None = None,\n    always_return_as_numpy: bool = False,\n) 
-> tuple[NdarrayTensor, NdarrayTensor]:\n    \"\"\"\n    Compute edges from binary segmentation masks. This\n    function is helpful to further calculate metrics such as Average Surface\n    Distance and Hausdorff Distance.\n    The input images can be binary or labelfield images. If labelfield images\n    are supplied, they are converted to binary images using `label_idx`.\n\n    In order to improve the computing efficiency, before getting the edges,\n    the images can be cropped and only keep the foreground if not specifies\n    ``crop = False``.\n\n    We require that images are the same size, and assume that they occupy the\n    same space (spacing, orientation, etc.).\n\n    Args:\n        seg_pred: the predicted binary or labelfield image.\n        seg_gt: the actual binary or labelfield image.\n        label_idx: for labelfield images, convert to binary with\n            `seg_pred = seg_pred == label_idx`.\n        crop: crop input images and only keep the foregrounds. In order to\n            maintain two inputs' shapes, here the bounding box is achieved\n            by ``(seg_pred | seg_gt)`` which represents the union set of two\n            images. Defaults to ``True``.\n        spacing: the input spacing. 
If not None, the subvoxel edges and areas will be computed.\n            Otherwise `scipy`'s binary erosion is used to calculate the edges.\n        always_return_as_numpy: whether to return a numpy array regardless of the input type.\n            If False, return the same type as inputs.\n            The default value is changed from `True` to `False` in v1.5.0.\n    \"\"\"\n    # move in the function to avoid using all the GPUs\n    cucim_binary_erosion, has_cucim_binary_erosion = optional_import(\"cucim.skimage.morphology\", name=\"binary_erosion\")\n    if seg_pred.shape != seg_gt.shape:\n        raise ValueError(f\"seg_pred and seg_gt should have same shapes, got {seg_pred.shape} and {seg_gt.shape}.\")\n    converter: Any\n    lib: ModuleType\n    if isinstance(seg_pred, torch.Tensor) and not always_return_as_numpy:\n        converter = partial(convert_to_tensor, device=seg_pred.device)\n        lib = torch\n    else:\n        converter = convert_to_numpy\n        lib = np\n    use_cucim = (\n        spacing is None\n        and has_cucim_binary_erosion\n        and isinstance(seg_pred, torch.Tensor)\n        and seg_pred.device.type == \"cuda\"\n    )\n\n    # If not binary images, convert them\n    if seg_pred.dtype not in (bool, torch.bool):\n        seg_pred = seg_pred == label_idx\n    if seg_gt.dtype not in (bool, torch.bool):\n        seg_gt = seg_gt == label_idx\n    if crop:\n        or_vol = seg_pred | seg_gt\n        if not or_vol.any():\n            pred, gt = lib.zeros(seg_pred.shape, dtype=bool), lib.zeros(seg_gt.shape, dtype=bool)\n            return (pred, gt) if spacing is None else (pred, gt, pred, gt)\n        channel_first = [seg_pred[None], seg_gt[None], or_vol[None]]\n        if spacing is None and not use_cucim:  # cpu only erosion\n            seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, device=\"cpu\", dtype=bool)\n        else:  # pytorch subvoxel, maybe on gpu, but croppad boolean values on GPU is not supported\n            
seg_pred, seg_gt, or_vol = convert_to_tensor(channel_first, dtype=torch.float16)\n        cropper = CropForegroundD(\n            [\"pred\", \"gt\"], source_key=\"src\", margin=1, allow_smaller=False, start_coord_key=None, end_coord_key=None\n        )\n        cropped = cropper({\"pred\": seg_pred, \"gt\": seg_gt, \"src\": or_vol})  # type: ignore\n        seg_pred, seg_gt = cropped[\"pred\"][0], cropped[\"gt\"][0]\n\n    if spacing is None:  # Do binary erosion and use XOR to get edges\n        if not use_cucim:\n            seg_pred, seg_gt = convert_to_numpy([seg_pred, seg_gt], dtype=bool)\n            edges_pred = binary_erosion(seg_pred) ^ seg_pred\n            edges_gt = binary_erosion(seg_gt) ^ seg_gt\n        else:\n            seg_pred, seg_gt = convert_to_cupy([seg_pred, seg_gt], dtype=bool)  # type: ignore[arg-type]\n            edges_pred = cucim_binary_erosion(seg_pred) ^ seg_pred\n            edges_gt = cucim_binary_erosion(seg_gt) ^ seg_gt\n        return converter((edges_pred, edges_gt), dtype=bool)  # type: ignore\n    code_to_area_table, k = get_code_to_measure_table(spacing, device=seg_pred.device)  # type: ignore\n    spatial_dims = len(spacing)\n    conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d\n    vol = torch.stack([seg_pred[None], seg_gt[None]], dim=0).float()  # type: ignore\n    code_pred, code_gt = conv(vol, k.to(vol))\n    # edges\n    all_ones = len(code_to_area_table) - 1\n    edges_pred = (code_pred != 0) & (code_pred != all_ones)\n    edges_gt = (code_gt != 0) & (code_gt != all_ones)\n    # areas of edges\n    areas_pred = torch.index_select(code_to_area_table, 0, code_pred.view(-1).int()).reshape(code_pred.shape)\n    areas_gt = torch.index_select(code_to_area_table, 0, code_gt.view(-1).int()).reshape(code_gt.shape)\n    ret = (edges_pred[0], edges_gt[0], areas_pred[0], areas_gt[0])\n    return converter(ret, wrap_sequence=False)  # type: ignore\n\n\ndef get_surface_distance(\n    seg_pred: 
NdarrayOrTensor,\n    seg_gt: NdarrayOrTensor,\n    distance_metric: str = \"euclidean\",\n    spacing: int | float | np.ndarray | Sequence[int | float] | None = None,\n) -> NdarrayOrTensor:\n    \"\"\"\n    This function is used to compute the surface distances from `seg_pred` to `seg_gt`.\n\n    Args:\n        seg_pred: the edge of the predictions.\n        seg_gt: the edge of the ground truth.\n        distance_metric: : [``\"euclidean\"``, ``\"chessboard\"``, ``\"taxicab\"``]\n            the metric used to compute surface distance. Defaults to ``\"euclidean\"``.\n\n            - ``\"euclidean\"``, uses Exact Euclidean distance transform.\n            - ``\"chessboard\"``, uses `chessboard` metric in chamfer type of transform.\n            - ``\"taxicab\"``, uses `taxicab` metric in chamfer type of transform.\n        spacing: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``\"euclidean\"``.\n            Several input options are allowed:\n            (1) If a single number, isotropic spacing with that value is used.\n            (2) If a sequence of numbers, the length of the sequence must be equal to the image dimensions.\n            (3) If ``None``, spacing of unity is used. 
Defaults to ``None``.\n\n    Note:\n        If seg_pred or seg_gt is all 0, may result in nan/inf distance.\n\n    \"\"\"\n    lib: ModuleType = torch if isinstance(seg_pred, torch.Tensor) else np\n    if not seg_gt.any():\n        dis = np.inf * lib.ones_like(seg_gt, dtype=lib.float32)\n    else:\n        if not lib.any(seg_pred):\n            dis = np.inf * lib.ones_like(seg_gt, dtype=lib.float32)\n            dis = dis[seg_gt]\n            return convert_to_dst_type(dis, seg_pred, dtype=dis.dtype)[0]\n        if distance_metric == \"euclidean\":\n            dis = monai_distance_transform_edt((~seg_gt)[None, ...], sampling=spacing)[0]  # type: ignore\n        elif distance_metric in {\"chessboard\", \"taxicab\"}:\n            dis = distance_transform_cdt(convert_to_numpy(~seg_gt), metric=distance_metric)\n        else:\n            raise ValueError(f\"distance_metric {distance_metric} is not implemented.\")\n    dis = convert_to_dst_type(dis, seg_pred, dtype=lib.float32)[0]\n    return dis[seg_pred]  # type: ignore\n\n\ndef get_edge_surface_distance(\n    y_pred: torch.Tensor,\n    y: torch.Tensor,\n    distance_metric: str = \"euclidean\",\n    spacing: int | float | np.ndarray | Sequence[int | float] | None = None,\n    use_subvoxels: bool = False,\n    symmetric: bool = False,\n    class_index: int = -1,\n) -> tuple[\n    tuple[torch.Tensor, torch.Tensor],\n    tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor],\n    tuple[torch.Tensor, torch.Tensor] | tuple[()],\n]:\n    \"\"\"\n    This function is used to compute the surface distance from `y_pred` to `y` using the edges of the masks.\n\n    Args:\n        y_pred: the predicted binary or labelfield image. Expected to be in format (H, W[, D]).\n        y: the actual binary or labelfield image. 
Expected to be in format (H, W[, D]).\n        distance_metric: : [``\"euclidean\"``, ``\"chessboard\"``, ``\"taxicab\"``]\n            See :py:func:`monai.metrics.utils.get_surface_distance`.\n        spacing: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``\"euclidean\"``.\n            See :py:func:`monai.metrics.utils.get_surface_distance`.\n        use_subvoxels: whether to use subvoxel resolution (using the spacing).\n            This will return the areas of the edges.\n        symmetric: whether to compute the surface distance from `y_pred` to `y` and from `y` to `y_pred`.\n        class_index: The class-index used for context when warning about empty ground truth or prediction.\n\n    Returns:\n        (edges_pred, edges_gt), (distances_pred_to_gt, [distances_gt_to_pred]), (areas_pred, areas_gt) | tuple()\n\n    \"\"\"\n    edges_spacing = None\n    if use_subvoxels:\n        edges_spacing = spacing if spacing is not None else ([1] * len(y_pred.shape))\n    (edges_pred, edges_gt, *areas) = get_mask_edges(\n        y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False\n    )\n    if not edges_gt.any():\n        warnings.warn(\n            f\"the ground truth of class {class_index if class_index != -1 else 'Unknown'} is all 0,\"\n            \" this may result in nan/inf distance.\"\n        )\n    if not edges_pred.any():\n        warnings.warn(\n            f\"the prediction of class {class_index if class_index != -1 else 'Unknown'} is all 0,\"\n            \" this may result in nan/inf distance.\"\n        )\n    distances: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]\n    if symmetric:\n        distances = (\n            get_surface_distance(edges_pred, edges_gt, distance_metric, spacing),\n            get_surface_distance(edges_gt, edges_pred, distance_metric, spacing),\n        )  # type: ignore\n    else:\n        distances = (get_surface_distance(edges_pred, edges_gt, 
distance_metric, spacing),)  # type: ignore\n    return convert_to_tensor(((edges_pred, edges_gt), distances, tuple(areas)), device=y_pred.device)  # type: ignore[no-any-return]\n\n\ndef is_binary_tensor(input: torch.Tensor, name: str) -> None:\n    \"\"\"Determines whether the input tensor is torch binary tensor or not.\n\n    Args:\n        input (torch.Tensor): tensor to validate.\n        name (str): name of the tensor being checked.\n\n    Raises:\n        ValueError: if `input` is not a PyTorch Tensor.\n\n    Note:\n        A warning message is printed, if the tensor is not binary.\n    \"\"\"\n    if not isinstance(input, torch.Tensor):\n        raise ValueError(f\"{name} must be of type PyTorch Tensor.\")\n    if not torch.all(input.byte() == input) or input.max() > 1 or input.min() < 0:\n        warnings.warn(f\"{name} should be a binarized tensor.\")\n\n\ndef remap_instance_id(pred: torch.Tensor, by_size: bool = False) -> torch.Tensor:\n    \"\"\"\n    This function is used to rename all instance id of `pred`, so that the id is\n    contiguous.\n    For example: all ids of the input can be [0, 1, 2] rather than [0, 2, 5].\n    This function is helpful for calculating metrics like Panoptic Quality (PQ).\n    The implementation refers to:\n\n    https://github.com/vqdang/hover_net\n\n    Args:\n        pred: segmentation predictions in the form of torch tensor. 
Each\n            value of the tensor should be an integer, and represents the prediction of its corresponding instance id.\n        by_size: if True, largest instance will be assigned a smaller id.\n\n    \"\"\"\n    pred_id: Iterable[Any] = list(pred.unique())\n    # the original implementation has the limitation that if there is no 0 in pred, error will happen\n    pred_id = [i for i in pred_id if i != 0]\n\n    if not pred_id:\n        return pred\n    if by_size:\n        instance_size = [(pred == instance_id).sum() for instance_id in pred_id]\n        pair_data = zip(pred_id, instance_size)\n        pair_list = sorted(pair_data, key=lambda x: x[1], reverse=True)\n        pred_id, _ = zip(*pair_list)\n\n    new_pred = torch.zeros_like(pred, dtype=torch.int)\n    for idx, instance_id in enumerate(pred_id):\n        new_pred[pred == instance_id] = idx + 1\n    return new_pred\n\n\ndef prepare_spacing(\n    spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None,\n    batch_size: int,\n    img_dim: int,\n) -> Sequence[None | int | float | np.ndarray | Sequence[int | float]]:\n    \"\"\"\n    This function is used to prepare the `spacing` parameter to include batch dimension for the computation of\n    surface distance, hausdorff distance or surface dice.\n\n    An example with batch_size = 4 and img_dim = 3:\n    input spacing = None -> output spacing = [None, None, None, None]\n    input spacing = 0.8 -> output spacing = [0.8, 0.8, 0.8, 0.8]\n    input spacing = [0.8, 0.5, 0.9] -> output spacing = [[0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9]]\n    input spacing = [0.8, 0.7, 1.2, 0.8] -> output spacing = [0.8, 0.7, 1.2, 0.8] (same as input)\n\n    An example with batch_size = 3 and img_dim = 3:\n    input spacing = [0.8, 0.5, 0.9] ->\n    output spacing = [[0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9], [0.8, 0.5, 0.9]]\n\n    Args:\n        spacing: can be a float, a sequence of length 
`img_dim`, or a sequence with length `batch_size`\n        that includes floats or sequences of length `img_dim`.\n\n    Raises:\n        ValueError: when `spacing` is a sequence of sequence, where the outer sequence length does not\n        equal `batch_size` or inner sequence length does not equal `img_dim`.\n\n    Returns:\n        spacing: a sequence with length `batch_size` that includes integers, floats or sequences of length `img_dim`.\n    \"\"\"\n    if spacing is None or isinstance(spacing, (int, float)):\n        return list([spacing] * batch_size)\n    if isinstance(spacing, (Sequence, np.ndarray)):\n        if any(not isinstance(s, type(spacing[0])) for s in list(spacing)):\n            raise ValueError(f\"if `spacing` is a sequence, its elements should be of same type, got {spacing}.\")\n        if isinstance(spacing[0], (Sequence, np.ndarray)):\n            if len(spacing) != batch_size:\n                raise ValueError(\n                    \"if `spacing` is a sequence of sequences, \"\n                    f\"the outer sequence should have same length as batch size ({batch_size}), got {spacing}.\"\n                )\n            if any(len(s) != img_dim for s in list(spacing)):\n                raise ValueError(\n                    \"each element of `spacing` list should either have same length as\"\n                    f\"image dim ({img_dim}), got {spacing}.\"\n                )\n            if not all(isinstance(i, (int, float)) for s in list(spacing) for i in list(s)):\n                raise ValueError(\n                    f\"if `spacing` is a sequence of sequences or 2D np.ndarray, \"\n                    f\"the elements should be integers or floats, got {spacing}.\"\n                )\n            return list(spacing)\n        if isinstance(spacing[0], (int, float)):\n            if len(spacing) != img_dim:\n                raise ValueError(\n                    f\"if `spacing` is a sequence of numbers, \"\n                    f\"it should 
have same length as image dim ({img_dim}), got {spacing}.\"\n                )\n            return [spacing for _ in range(batch_size)]  # type: ignore\n        raise ValueError(f\"`spacing` is a sequence of elements with unsupported type: {type(spacing[0])}\")\n    raise ValueError(\n        f\"`spacing` should either be a number, a sequence of numbers or a sequence of sequences, got {spacing}.\"\n    )\n\n\nENCODING_KERNEL = {2: [[8, 4], [2, 1]], 3: [[[128, 64], [32, 16]], [[8, 4], [2, 1]]]}\n\n\n@cache\ndef _get_neighbour_code_to_normals_table(device=None):\n    \"\"\"\n    returns a lookup table. For every binary neighbour code (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes)\n    it contains the surface normals of the triangles. The length of the normal vector encodes the surfel area.\n    Adapted from https://github.com/deepmind/surface-distance\n\n    created using the marching_cube algorithm see e.g. https://en.wikipedia.org/wiki/Marching_cubes\n\n    Args:\n        device: torch device to use for the table.\n    \"\"\"\n    zeros = [0.0, 0.0, 0.0]\n    ret = [\n        [zeros, zeros, zeros, zeros],\n        [[0.125, 0.125, 0.125], zeros, zeros, zeros],\n        [[-0.125, -0.125, 0.125], zeros, zeros, zeros],\n        [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros, zeros],\n        [[0.125, -0.125, 0.125], zeros, zeros, zeros],\n        [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros, zeros],\n        [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],\n        [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros],\n        [[-0.125, 0.125, 0.125], zeros, zeros, zeros],\n        [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros],\n        [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros, zeros],\n        [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],\n        [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros, zeros],\n        [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], 
zeros],\n        [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],\n        [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0], zeros, zeros],\n        [[0.125, -0.125, -0.125], zeros, zeros, zeros],\n        [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], zeros, zeros],\n        [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],\n        [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros],\n        [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],\n        [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125], zeros],\n        [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros],\n        [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],\n        [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],\n        [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],\n        [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125], zeros],\n        [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]],\n        [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros],\n        [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],\n        [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]],\n        [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], zeros],\n        [[0.125, -0.125, 0.125], zeros, zeros, zeros],\n        [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],\n        [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros, zeros],\n        [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], zeros],\n        [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],\n        [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros],\n        [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125], zeros],\n       
 [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]],\n        [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],\n        [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros],\n        [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],\n        [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]],\n        [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros],\n        [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]],\n        [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]],\n        [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],\n        [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros, zeros],\n        [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125], zeros],\n        [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25], zeros],\n        [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0], zeros, zeros],\n        [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125], zeros],\n        [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],\n        [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],\n        [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125], zeros],\n        [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros],\n        [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n        [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],\n        [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros],\n        [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],\n        [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125], 
zeros],\n        [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros],\n        [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros, zeros],\n        [[-0.125, -0.125, 0.125], zeros, zeros, zeros],\n        [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],\n        [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],\n        [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros],\n        [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], zeros, zeros],\n        [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros],\n        [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],\n        [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]],\n        [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros],\n        [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros],\n        [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros],\n        [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n        [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros],\n        [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]],\n        [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]],\n        [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros],\n        [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros, zeros],\n        [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],\n        [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros],\n        [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]],\n        [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros],\n        [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5], zeros, zeros],\n        [[0.125, 0.125, 0.125], 
[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],\n        [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros],\n        [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125], zeros],\n        [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n        [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],\n        [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros],\n        [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]],\n        [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros],\n        [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros],\n        [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros, zeros],\n        [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],\n        [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros],\n        [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros],\n        [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],\n        [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125], zeros],\n        [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n        [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],\n        [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125], zeros],\n        [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros],\n        [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],\n        [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n        [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros],\n        [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n 
       [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros],\n        [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125], zeros],\n        [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],\n        [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125], zeros],\n        [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],\n        [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],\n        [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],\n        [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]],\n        [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros],\n        [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros],\n        [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros, zeros],\n        [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],\n        [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125], zeros],\n        [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125], zeros],\n        [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],\n        [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125], zeros],\n        [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],\n        [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], zeros, zeros],\n        [[0.125, 0.125, 0.125], zeros, zeros, zeros],\n        [[0.125, 0.125, 0.125], zeros, zeros, zeros],\n        [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], zeros, zeros],\n        [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],\n        [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125], zeros],\n        [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],\n        [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125], zeros],\n        
[[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125], zeros],\n        [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],\n        [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros, zeros],\n        [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], zeros],\n        [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros],\n        [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0], [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]],\n        [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],\n        [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],\n        [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],\n        [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125], zeros],\n        [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],\n        [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125], zeros],\n        [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros],\n        [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n        [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros],\n        [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n        [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],\n        [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros],\n        [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125], zeros],\n        [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],\n        [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n        [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125], zeros],\n        [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], 
[0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],\n        [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros],\n        [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros],\n        [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],\n        [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros, zeros],\n        [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros],\n        [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros],\n        [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375], [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]],\n        [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], zeros],\n        [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],\n        [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],\n        [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125], zeros],\n        [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5], zeros],\n        [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],\n        [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5], zeros, zeros],\n        [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros],\n        [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]],\n        [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros],\n        [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],\n        [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], zeros, zeros],\n        [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros],\n        [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375], [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]],\n        [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25], [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]],\n        [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125], zeros],\n        [[0.5, 
0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n        [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros],\n        [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros],\n        [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros],\n        [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25], [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]],\n        [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],\n        [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125], zeros],\n        [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], zeros, zeros],\n        [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros],\n        [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],\n        [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],\n        [[-0.125, -0.125, 0.125], zeros, zeros, zeros],\n        [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros, zeros],\n        [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], zeros],\n        [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125], zeros],\n        [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],\n        [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], zeros],\n        [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0], [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],\n        [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],\n        [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros],\n        [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125], zeros],\n        [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],\n        [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],\n        [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 
0.125], zeros],\n        [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0], zeros, zeros],\n        [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25], zeros],\n        [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125], zeros],\n        [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], zeros, zeros],\n        [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],\n        [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375], [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]],\n        [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]],\n        [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros],\n        [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25], [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]],\n        [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],\n        [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125], zeros],\n        [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],\n        [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25], [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]],\n        [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125], zeros],\n        [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros],\n        [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],\n        [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25], zeros],\n        [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], zeros, zeros],\n        [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], zeros, zeros],\n        [[0.125, -0.125, 0.125], zeros, zeros, zeros],\n        [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], zeros],\n        [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]],\n        [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25], [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],\n        [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros],\n    
    [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375], [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]],\n        [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125], zeros],\n        [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],\n        [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],\n        [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],\n        [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros],\n        [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125], zeros],\n        [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],\n        [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros],\n        [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125], zeros, zeros],\n        [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], zeros, zeros],\n        [[0.125, -0.125, -0.125], zeros, zeros, zeros],\n        [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0], zeros, zeros],\n        [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125], zeros],\n        [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125], zeros],\n        [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0], zeros, zeros],\n        [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125], zeros],\n        [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], zeros, zeros],\n        [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125], zeros, zeros],\n        [[-0.125, 0.125, 0.125], zeros, zeros, zeros],\n        [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125], zeros],\n        [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], zeros, zeros],\n        [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], zeros, zeros],\n        [[0.125, 0.125, 0.125], zeros, zeros, zeros],\n        [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], zeros, zeros],\n        [[0.125, 0.125, 0.125], zeros, zeros, zeros],\n        [[0.125, 0.125, 0.125], zeros, zeros, zeros],\n        [zeros, zeros, zeros, 
zeros],\n    ]\n    return torch.as_tensor(ret, device=device)\n\n\ndef create_table_neighbour_code_to_surface_area(spacing_mm, device=None):\n    \"\"\"\n    Returns an array mapping neighbourhood code to the surface elements area.\n    Adapted from https://github.com/deepmind/surface-distance\n\n    Note that the normals encode the initial surface area. This function computes\n    the area corresponding to the given `spacing`.\n\n    Args:\n        spacing_mm: a sequence of 3 numbers. Voxel spacing along the first 3 spatial axes.\n        device: device to put the table on.\n\n    Returns:\n        An array of size 256, mapping neighbourhood code to the surface area.\n        ENCODING_KERNEL[3] which is the kernel used to compute the neighbourhood code.\n    \"\"\"\n    spacing_mm = ensure_tuple_rep(spacing_mm, 3)\n    # compute the area for all 256 possible surface elements given a 2x2x2 neighbourhood according to the spacing_mm\n    c = _get_neighbour_code_to_normals_table(device)\n    s = torch.as_tensor(\n        [[[spacing_mm[1] * spacing_mm[2], spacing_mm[0] * spacing_mm[2], spacing_mm[0] * spacing_mm[1]]]],\n        device=device,\n        dtype=c.dtype,\n    )\n    norm = torch.linalg.norm(c * s, dim=-1)\n    neighbour_code_to_surface_area = norm.sum(-1)\n    return neighbour_code_to_surface_area, torch.as_tensor([[ENCODING_KERNEL[3]]], device=device)\n\n\ndef create_table_neighbour_code_to_contour_length(spacing_mm, device=None):\n    \"\"\"\n    Returns an array mapping neighbourhood code to the contour length.\n    Adapted from https://github.com/deepmind/surface-distance\n\n    In 2D, each point has 4 neighbors. Thus, are 16 configurations. A\n    configuration is encoded with '1' meaning \"inside the object\" and '0' \"outside\n    the object\". 
For example,\n    \"0101\" and \"1010\" both encode an edge along the first spatial axis with length spacing[0] mm;\n    \"0011\" and \"1100\" both encode an edge along the second spatial axis with length spacing[1] mm.\n\n    Args:\n        spacing_mm: 2-element list-like structure. Pixel spacing along the 1st and 2nd spatial axes.\n        device: device to put the table on.\n\n    Returns:\n        A 16-element array mapping neighbourhood code to the contour length.\n        ENCODING_KERNEL[2] which is the kernel used to compute the neighbourhood code.\n    \"\"\"\n    spacing_mm = ensure_tuple_rep(spacing_mm, 2)\n    first, second = spacing_mm  # spacing along the first and second spatial dimension respectively\n    diag = 0.5 * np.linalg.norm(spacing_mm)\n\n    neighbour_code_to_contour_length = np.zeros([16], dtype=diag.dtype)\n    neighbour_code_to_contour_length[int(\"0001\", 2)] = diag\n    neighbour_code_to_contour_length[int(\"0010\", 2)] = diag\n    neighbour_code_to_contour_length[int(\"0011\", 2)] = second\n    neighbour_code_to_contour_length[int(\"0100\", 2)] = diag\n    neighbour_code_to_contour_length[int(\"0101\", 2)] = first\n    neighbour_code_to_contour_length[int(\"0110\", 2)] = 2 * diag\n    neighbour_code_to_contour_length[int(\"0111\", 2)] = diag\n    neighbour_code_to_contour_length[int(\"1000\", 2)] = diag\n    neighbour_code_to_contour_length[int(\"1001\", 2)] = 2 * diag\n    neighbour_code_to_contour_length[int(\"1010\", 2)] = first\n    neighbour_code_to_contour_length[int(\"1011\", 2)] = diag\n    neighbour_code_to_contour_length[int(\"1100\", 2)] = second\n    neighbour_code_to_contour_length[int(\"1101\", 2)] = diag\n    neighbour_code_to_contour_length[int(\"1110\", 2)] = diag\n    neighbour_code_to_contour_length = convert_to_tensor(neighbour_code_to_contour_length, device=device)\n    return neighbour_code_to_contour_length, torch.as_tensor([[ENCODING_KERNEL[2]]], device=device)\n\n\ndef get_code_to_measure_table(spacing, 
device=None):\n    \"\"\"\n    returns a table mapping neighbourhood code to the surface area or contour length.\n\n    Args:\n        spacing: a sequence of 2 or 3 numbers, indicating the spacing in the spatial dimensions.\n        device: device to put the table on.\n    \"\"\"\n    spatial_dims = len(spacing)\n    spacing = ensure_tuple_rep(spacing, look_up_option(spatial_dims, (2, 3)))\n    if spatial_dims == 2:\n        return create_table_neighbour_code_to_contour_length(spacing, device)\n    return create_table_neighbour_code_to_surface_area(spacing, device)\n"
  },
  {
    "path": "monai/metrics/wrapper.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import cast\n\nimport torch\n\nfrom monai.metrics.utils import do_metric_reduction, ignore_background\nfrom monai.utils import MetricReduction, convert_to_numpy, convert_to_tensor, optional_import\n\nfrom .metric import CumulativeIterationMetric\n\nBinaryPairwiseMeasures, _ = optional_import(\"MetricsReloaded.metrics.pairwise_measures\", name=\"BinaryPairwiseMeasures\")\nMultiClassPairwiseMeasures, _ = optional_import(\n    \"MetricsReloaded.metrics.pairwise_measures\", name=\"MultiClassPairwiseMeasures\"\n)\n\n__all__ = [\"MetricsReloadedBinary\", \"MetricsReloadedCategorical\"]\n\n\nclass MetricsReloadedWrapper(CumulativeIterationMetric):\n    \"\"\"Base class for defining MetricsReloaded metrics as a CumulativeIterationMetric.\n\n    Args:\n        metric_name: Name of a metric from the MetricsReloaded package.\n        include_background: whether to include computation on the first channel of\n            the predicted output. Defaults to ``True``.\n        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. 
if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n            Here `not_nans` count the number of not nans for the metric,\n            thus its shape equals to the shape of the metric.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        metric_name: str,\n        include_background: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n    ) -> None:\n        super().__init__()\n        self.metric_name = metric_name\n        self.include_background = include_background\n        self.reduction = reduction\n        self.get_not_nans = get_not_nans\n\n    def aggregate(\n        self, reduction: MetricReduction | str | None = None\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        data = self.get_buffer()\n        if not isinstance(data, torch.Tensor):\n            raise ValueError(\"the data to aggregate must be PyTorch Tensor.\")\n        # do metric reduction\n        f, not_nans = do_metric_reduction(data, reduction or self.reduction)\n        return (f, not_nans) if self.get_not_nans else f\n\n    def prepare_onehot(self, y_pred, y):\n        \"\"\"Prepares onehot encoded input for metric call.\"\"\"\n        y = y.float()\n        y_pred = y_pred.float()\n        if not self.include_background:\n            y_pred, y = ignore_background(y_pred=y_pred, y=y)\n        return y_pred, y, y_pred.device\n\n\nclass MetricsReloadedBinary(MetricsReloadedWrapper):\n    \"\"\"\n    Wraps the binary pairwise metrics of MetricsReloaded.\n\n    Args:\n        metric_name: Name of a binary metric from the MetricsReloaded package.\n        include_background: whether to include computation on the first channel of\n            the predicted output. 
Defaults to ``True``.\n        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n            Here `not_nans` count the number of not nans for the metric,\n            thus its shape equals to the shape of the metric.\n\n    Example:\n\n    .. code-block:: python\n\n        import torch\n        from monai.metrics import MetricsReloadedBinary\n\n        metric_name = \"Cohens Kappa\"\n        metric = MetricsReloadedBinary(metric_name=metric_name)\n\n        # first iteration\n        # shape [batch=1, channel=1, 2, 2]\n        y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])\n        y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])\n        print(metric(y_pred, y))\n\n        # second iteration\n        # shape [batch=1, channel=1, 2, 2]\n        y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])\n        y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])\n        print(metric(y_pred, y))\n\n        # aggregate\n        # shape ([batch=2, channel=1])\n        print(metric.aggregate(reduction=\"none\"))  # tensor([[0.5], [0.2]])\n\n        # reset\n        metric.reset()\n\n    \"\"\"\n\n    def __init__(\n        self,\n        metric_name: str,\n        include_background: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = False,\n    ) -> None:\n        super().__init__(\n            metric_name=metric_name,\n            include_background=include_background,\n            reduction=reduction,\n            get_not_nans=get_not_nans,\n        )\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) 
-> torch.Tensor:  # type: ignore[override]\n        \"\"\"Computes a binary (single-class) MetricsReloaded metric from a batch of\n        predictions and references.\n\n        Args:\n            y_pred: Prediction with dimensions (batch, channel, *spatial), where channel=1.\n                The values should be binarized.\n            y: Ground-truth with dimensions (batch, channel, *spatial), where channel=1.\n                The values should be binarized.\n\n        Raises:\n            ValueError: when `y_pred` has less than three dimensions.\n            ValueError: when the second (channel) dimension of `y_pred` or `y` is not 1.\n\n        \"\"\"\n        # Preprocess\n        y_pred, y, device = self.prepare_onehot(y_pred, y)\n\n        # Sanity check\n        dims = y_pred.ndimension()\n        if dims < 3:\n            raise ValueError(f\"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.\")\n        if y_pred.shape[1] != 1 or y.shape[1] != 1:\n            raise ValueError(f\"y_pred.shape[1]={y_pred.shape[1]} and y.shape[1]={y.shape[1]} should be one.\")\n\n        # To numpy array\n        y_pred = convert_to_numpy(y_pred)\n        y = convert_to_numpy(y)\n\n        # Create binary pairwise metric object\n        bpm = BinaryPairwiseMeasures(y_pred, y, axis=tuple(range(2, dims)), smooth_dr=1e-5)\n\n        # Is requested metric available?\n        if self.metric_name not in bpm.metrics:\n            raise ValueError(f\"Unsupported metric: {self.metric_name}\")\n\n        # Compute metric\n        metric = bpm.metrics[self.metric_name]()\n\n        # Return metric as tensor\n        return convert_to_tensor(metric, device=device)  # type: ignore[no-any-return]\n\n\nclass MetricsReloadedCategorical(MetricsReloadedWrapper):\n    \"\"\"\n    Wraps the categorical pairwise metrics of MetricsReloaded.\n\n\n    Args:\n        metric_name: Name of a categorical metric from the MetricsReloaded package.\n        include_background: whether to include computation on the 
first channel of\n            the predicted output. Defaults to ``True``.\n        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,\n            available reduction modes: {``\"none\"``, ``\"mean\"``, ``\"sum\"``, ``\"mean_batch\"``, ``\"sum_batch\"``,\n            ``\"mean_channel\"``, ``\"sum_channel\"``}, default to ``\"mean\"``. if \"none\", will not do reduction.\n        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).\n            Here `not_nans` count the number of not nans for the metric,\n            thus its shape equals to the shape of the metric.\n        smooth_dr: a small constant added to the denominator to avoid nan. OBS: should be greater than zero.\n\n    Example:\n\n    .. code-block:: python\n\n        import torch\n        from monai.metrics import MetricsReloadedCategorical\n\n        metric_name = \"Weighted Cohens Kappa\"\n        metric = MetricsReloadedCategorical(metric_name=metric_name)\n\n        # first iteration\n        # shape [batch=1, channel=3, 2, 2]\n        y_pred = torch.tensor([[[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]])\n        y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]])\n        print(metric(y_pred, y))\n\n        # second iteration\n        # shape [batch=1, channel=3, 2, 2]\n        y_pred = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, 0], [0, 0]]]])\n        y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]])\n        print(metric(y_pred, y))\n\n        # aggregate\n        # shape ([batch=2, channel=1])\n        print(metric.aggregate(reduction=\"none\"))  # tensor([[0.2727], [0.6000]])\n\n        # reset\n        metric.reset()\n\n    \"\"\"\n\n    def __init__(\n        self,\n        metric_name: str,\n        include_background: bool = True,\n        reduction: MetricReduction | str = MetricReduction.MEAN,\n        get_not_nans: bool = 
False,\n        smooth_dr: float = 1e-5,\n    ) -> None:\n        super().__init__(\n            metric_name=metric_name,\n            include_background=include_background,\n            reduction=reduction,\n            get_not_nans=get_not_nans,\n        )\n        self.smooth_dr = smooth_dr\n\n    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        \"\"\"Computes a categorical (multi-class) MetricsReloaded metric from a batch of\n        predictions and references.\n\n        Args:\n            y_pred: Prediction with dimensions (batch, channel, *spatial). The values should be\n                one-hot encoded and binarized.\n            y: Ground-truth with dimensions (batch, channel, *spatial). The values should be\n                one-hot encoded and binarized.\n\n        Raises:\n            ValueError: when `y_pred` has less than three dimensions.\n\n        \"\"\"\n        # Preprocess\n        y_pred, y, device = self.prepare_onehot(y_pred, y)\n\n        # Sanity check\n        dims = y_pred.ndimension()\n        if dims < 3:\n            raise ValueError(f\"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.\")\n\n        num_classes = y_pred.shape[1]\n\n        # Reshape and permute for compatible dimension with MetricsReloaded\n        y_pred = y_pred.reshape(y_pred.shape[0], y_pred.shape[1], -1)\n        y_pred = y_pred.permute((0, 2, 1))\n        y = y.reshape(y.shape[0], y.shape[1], -1)\n        y = y.permute((0, 2, 1))\n        dims = y_pred.ndimension()\n\n        # To numpy array\n        y_pred = convert_to_numpy(y_pred)\n        y = convert_to_numpy(y)\n\n        # Create categorical pairwise metric object\n        bpm = MultiClassPairwiseMeasures(\n            y_pred,\n            y,\n            axis=tuple(range(1, dims)),\n            smooth_dr=self.smooth_dr,\n            list_values=list(range(num_classes)),\n            is_onehot=True,\n        
)\n\n        # Is requested metric available?\n        if self.metric_name not in bpm.metrics:\n            raise ValueError(f\"Unsupported metric: {self.metric_name}\")\n\n        # Compute metric\n        metric = bpm.metrics[self.metric_name]()\n\n        # Put back singleton channel dimension\n        metric = metric[..., None]\n\n        # Return metric as tensor\n        return cast(torch.Tensor, convert_to_tensor(metric, device=device))\n"
  },
  {
    "path": "monai/networks/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .trt_compiler import trt_compile\nfrom .utils import (\n    add_casts_around_norms,\n    convert_to_onnx,\n    convert_to_torchscript,\n    convert_to_trt,\n    copy_model_state,\n    eval_mode,\n    get_state_dict,\n    icnr_init,\n    look_up_named_module,\n    normal_init,\n    normalize_transform,\n    one_hot,\n    pixelshuffle,\n    predict_segmentation,\n    replace_modules,\n    replace_modules_temp,\n    save_state,\n    set_named_module,\n    to_norm_affine,\n    train_mode,\n)\n"
  },
  {
    "path": "monai/networks/blocks/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .acti_norm import ADN\nfrom .activation import GEGLU, MemoryEfficientSwish, Mish, Swish\nfrom .aspp import SimpleASPP\nfrom .backbone_fpn_utils import BackboneWithFPN\nfrom .cablock import CABlock, FeedForward\nfrom .convolutions import Convolution, ResidualUnit\nfrom .crf import CRF\nfrom .crossattention import CrossAttentionBlock\nfrom .denseblock import ConvDenseBlock, DenseBlock\nfrom .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock\nfrom .downsample import DownSample, Downsample, MaxAvgPool, SubpixelDownsample, SubpixelDownSample, Subpixeldownsample\nfrom .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding\nfrom .encoder import BaseEncoder\nfrom .fcn import FCN, GCN, MCFCN, Refine\nfrom .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7\nfrom .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock\nfrom .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock\nfrom .mlp import MLPBlock\nfrom .patchembedding import PatchEmbed, PatchEmbeddingBlock\nfrom .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock\nfrom .segresnet_block 
import ResBlock\nfrom .selfattention import SABlock\nfrom .spade_norm import SPADE\nfrom .spatialattention import SpatialAttentionBlock\nfrom .squeeze_and_excitation import (\n    ChannelSELayer,\n    ResidualSELayer,\n    SEBlock,\n    SEBottleneck,\n    SEResNetBottleneck,\n    SEResNeXtBottleneck,\n)\nfrom .transformerblock import TransformerBlock\nfrom .unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock\nfrom .upsample import SubpixelUpsample, Subpixelupsample, SubpixelUpSample, Upsample, UpSample\nfrom .warp import DVF2DDF, Warp\n"
  },
  {
    "path": "monai/networks/blocks/acti_norm.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch.nn as nn\n\nfrom monai.networks.layers.utils import get_act_layer, get_dropout_layer, get_norm_layer\n\n\nclass ADN(nn.Sequential):\n    \"\"\"\n    Constructs a sequential module of optional activation (A), dropout (D), and normalization (N) layers\n    with an arbitrary order::\n\n        -- (Norm) -- (Dropout) -- (Acti) --\n\n    Args:\n        ordering: a string representing the ordering of activation, dropout, and normalization. Defaults to \"NDA\".\n        in_channels: `C` from an expected input of size (N, C, H[, W, D]).\n        act: activation type and arguments. Defaults to ReLU.\n        norm: feature normalization type and arguments. Defaults to None (no normalization layer).\n        norm_dim: determine the spatial dimensions of the normalization layer.\n            defaults to `dropout_dim` if unspecified.\n        dropout: dropout ratio. 
Defaults to no dropout.\n        dropout_dim: determine the spatial dimensions of dropout.\n            defaults to `norm_dim` if unspecified.\n\n            - When dropout_dim = 1, randomly zeroes some of the elements for each channel.\n            - When dropout_dim = 2, Randomly zeroes out entire channels (a channel is a 2D feature map).\n            - When dropout_dim = 3, Randomly zeroes out entire channels (a channel is a 3D feature map).\n\n    Examples::\n\n        # activation, group norm, dropout\n        >>> norm_params = (\"GROUP\", {\"num_groups\": 1, \"affine\": False})\n        >>> ADN(norm=norm_params, in_channels=1, dropout_dim=1, dropout=0.8, ordering=\"AND\")\n        ADN(\n            (A): ReLU()\n            (N): GroupNorm(1, 1, eps=1e-05, affine=False)\n            (D): Dropout(p=0.8, inplace=False)\n        )\n\n        # LeakyReLU, dropout\n        >>> act_params = (\"leakyrelu\", {\"negative_slope\": 0.1, \"inplace\": True})\n        >>> ADN(act=act_params, in_channels=1, dropout_dim=1, dropout=0.8, ordering=\"AD\")\n        ADN(\n            (A): LeakyReLU(negative_slope=0.1, inplace=True)\n            (D): Dropout(p=0.8, inplace=False)\n        )\n\n    See also:\n\n        :py:class:`monai.networks.layers.Dropout`\n        :py:class:`monai.networks.layers.Act`\n        :py:class:`monai.networks.layers.Norm`\n        :py:class:`monai.networks.layers.split_args`\n\n    \"\"\"\n\n    def __init__(\n        self,\n        ordering: str = \"NDA\",\n        in_channels: int | None = None,\n        act: tuple | str | None = \"RELU\",\n        norm: tuple | str | None = None,\n        norm_dim: int | None = None,\n        dropout: tuple | str | float | None = None,\n        dropout_dim: int | None = None,\n    ) -> None:\n        super().__init__()\n\n        op_dict = {\"A\": None, \"D\": None, \"N\": None}\n        # define the normalization type and the arguments to the constructor\n        if norm is not None:\n            if norm_dim is 
None and dropout_dim is None:\n                raise ValueError(\"norm_dim or dropout_dim needs to be specified.\")\n            op_dict[\"N\"] = get_norm_layer(name=norm, spatial_dims=norm_dim or dropout_dim, channels=in_channels)\n\n        # define the activation type and the arguments to the constructor\n        if act is not None:\n            op_dict[\"A\"] = get_act_layer(act)\n\n        if dropout is not None:\n            if norm_dim is None and dropout_dim is None:\n                raise ValueError(\"norm_dim or dropout_dim needs to be specified.\")\n            op_dict[\"D\"] = get_dropout_layer(name=dropout, dropout_dim=dropout_dim or norm_dim)\n\n        for item in ordering.upper():\n            if item not in op_dict:\n                raise ValueError(f\"ordering must be a string of {op_dict}, got {item} in it.\")\n            if op_dict[item] is not None:\n                self.add_module(item, op_dict[item])\n"
  },
  {
    "path": "monai/networks/blocks/activation.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch import nn\n\nfrom monai.utils import optional_import\n\nif optional_import(\"torch.nn.functional\", name=\"mish\")[1]:\n\n    def monai_mish(x, inplace: bool = False):\n        return torch.nn.functional.mish(x, inplace=inplace)\n\nelse:\n\n    def monai_mish(x, inplace: bool = False):\n        return x * torch.tanh(torch.nn.functional.softplus(x))\n\n\nif optional_import(\"torch.nn.functional\", name=\"silu\")[1]:\n\n    def monai_swish(x, inplace: bool = False):\n        return torch.nn.functional.silu(x, inplace=inplace)\n\nelse:\n\n    def monai_swish(x, inplace: bool = False):\n        return SwishImplementation.apply(x)\n\n\nclass Swish(nn.Module):\n    r\"\"\"Applies the element-wise function:\n\n    .. 
math::\n        \\text{Swish}(x) = x * \\text{Sigmoid}(\\alpha * x) ~~~~\\text{for constant value}~ \\alpha.\n\n    Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.\n\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n\n    Examples::\n\n        >>> import torch\n        >>> from monai.networks.layers.factories import Act\n        >>> m = Act['swish']()\n        >>> input = torch.randn(2)\n        >>> output = m(input)\n    \"\"\"\n\n    def __init__(self, alpha=1.0):\n        super().__init__()\n        self.alpha = alpha\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return input * torch.sigmoid(self.alpha * input)\n\n\nclass SwishImplementation(torch.autograd.Function):\n    r\"\"\"Memory efficient implementation for training\n    Follows recommendation from:\n    https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853\n\n    Results in ~ 30% memory saving during training as compared to Swish()\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, input):\n        result = input * torch.sigmoid(input)\n        ctx.save_for_backward(input)\n        return result\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input = ctx.saved_tensors[0]\n        sigmoid_input = torch.sigmoid(input)\n        return grad_output * (sigmoid_input * (1 + input * (1 - sigmoid_input)))\n\n\nclass MemoryEfficientSwish(nn.Module):\n    r\"\"\"Applies the element-wise function:\n\n    .. 
math::\n        \\text{Swish}(x) = x * \\text{Sigmoid}(\\alpha * x) ~~~~\\text{for constant value}~ \\alpha=1.\n\n    Memory efficient implementation for training following recommendation from:\n    https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853\n\n    Results in ~ 30% memory saving during training as compared to Swish()\n\n    Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.\n\n    From Pytorch 1.7.0+, the optimized version of `Swish` named `SiLU` is implemented,\n    this class will utilize `torch.nn.functional.silu` to do the calculation if meets the version.\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional\n          dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n\n    Examples::\n\n        >>> import torch\n        >>> from monai.networks.layers.factories import Act\n        >>> m = Act['memswish']()\n        >>> input = torch.randn(2)\n        >>> output = m(input)\n    \"\"\"\n\n    def __init__(self, inplace: bool = False):\n        super().__init__()\n        # inplace only works when using torch.nn.functional.silu\n        self.inplace = inplace\n\n    def forward(self, input: torch.Tensor):\n        return monai_swish(input, self.inplace)\n\n\nclass Mish(nn.Module):\n    r\"\"\"Applies the element-wise function:\n\n    .. 
math::\n        \\text{Mish}(x) = x * tanh(\\text{softplus}(x)).\n\n    Citation: Mish: A Self Regularized Non-Monotonic Activation Function, Diganta Misra, 2019, https://arxiv.org/abs/1908.08681.\n\n    From Pytorch 1.9.0+, the optimized version of `Mish` is implemented,\n    this class will utilize `torch.nn.functional.mish` to do the calculation if meets the version.\n\n    Shape:\n        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions\n        - Output: :math:`(N, *)`, same shape as the input\n\n\n    Examples::\n\n        >>> import torch\n        >>> from monai.networks.layers.factories import Act\n        >>> m = Act['mish']()\n        >>> input = torch.randn(2)\n        >>> output = m(input)\n    \"\"\"\n\n    def __init__(self, inplace: bool = False):\n        super().__init__()\n        # inplace only works when using torch.nn.functional.mish\n        self.inplace = inplace\n\n    def forward(self, input: torch.Tensor):\n        return monai_mish(input, self.inplace)\n\n\nclass GEGLU(nn.Module):\n    r\"\"\"Applies the element-wise function:\n\n    .. math::\n        \\text{GEGLU}(x) = x_1 * \\text{GELU}(x_2)\n\n    where :math:`x_1` and :math:`x_2` are split from the input tensor along the last dimension.\n\n    Citation: GLU Variants Improve Transformer, Noam Shazeer, 2020, https://arxiv.org/abs/2002.05202.\n\n    Shape:\n        - Input: :math:`(N, *, 2 * D)`\n        - Output: :math:`(N, *, D)`, where `*` means, any number of additional dimensions\n\n    Examples::\n\n        >>> import torch\n        >>> from monai.networks.layers.factories import Act\n        >>> m = Act['geglu']()\n        >>> input = torch.randn(2, 8)  # last dim must be even\n        >>> output = m(input)\n    \"\"\"\n\n    def forward(self, input: torch.Tensor):\n        x, gate = input.chunk(2, dim=-1)\n        return x * nn.functional.gelu(gate)\n"
  },
  {
    "path": "monai/networks/blocks/activation_checkpointing.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import cast\n\nimport torch\nimport torch.nn as nn\nfrom torch.utils.checkpoint import checkpoint\n\n\nclass ActivationCheckpointWrapper(nn.Module):\n    \"\"\"Wrapper applying activation checkpointing to a module during training.\n\n    Args:\n        module: The module to wrap with activation checkpointing.\n    \"\"\"\n\n    def __init__(self, module: nn.Module) -> None:\n        super().__init__()\n        self.module = module\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Forward pass with optional activation checkpointing.\n\n        Args:\n            x: Input tensor.\n\n        Returns:\n            Output tensor from the wrapped module.\n        \"\"\"\n        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))\n"
  },
  {
    "path": "monai/networks/blocks/aspp.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.convolutions import Convolution\nfrom monai.networks.layers import same_padding\nfrom monai.networks.layers.factories import Conv\n\n\nclass SimpleASPP(nn.Module):\n    \"\"\"\n    A simplified version of the atrous spatial pyramid pooling (ASPP) module.\n\n    Chen et al., Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation.\n    https://arxiv.org/abs/1802.02611\n\n    Wang et al., A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions\n    from CT Images. 
https://ieeexplore.ieee.org/document/9109297\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        conv_out_channels: int,\n        kernel_sizes: Sequence[int] = (1, 3, 3, 3),\n        dilations: Sequence[int] = (1, 2, 4, 6),\n        norm_type: tuple | str | None = \"BATCH\",\n        acti_type: tuple | str | None = \"LEAKYRELU\",\n        bias: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.\n            in_channels: number of input channels.\n            conv_out_channels: number of output channels of each atrous conv.\n                The final number of output channels is conv_out_channels * len(kernel_sizes).\n            kernel_sizes: a sequence of four convolutional kernel sizes.\n                Defaults to (1, 3, 3, 3) for four (dilated) convolutions.\n            dilations: a sequence of four convolutional dilation parameters.\n                Defaults to (1, 2, 4, 6) for four (dilated) convolutions.\n            norm_type: final kernel-size-one convolution normalization type.\n                Defaults to batch norm.\n            acti_type: final kernel-size-one convolution activation type.\n                Defaults to leaky ReLU.\n            bias: whether to have a bias term in convolution blocks. 
Defaults to False.\n                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,\n                if a conv layer is directly followed by a batch norm layer, bias should be False.\n\n        Raises:\n            ValueError: When ``kernel_sizes`` length differs from ``dilations``.\n\n        See also:\n\n            :py:class:`monai.networks.layers.Act`\n            :py:class:`monai.networks.layers.Conv`\n            :py:class:`monai.networks.layers.Norm`\n\n        \"\"\"\n        super().__init__()\n        if len(kernel_sizes) != len(dilations):\n            raise ValueError(\n                \"kernel_sizes and dilations length must match, \"\n                f\"got kernel_sizes={len(kernel_sizes)} dilations={len(dilations)}.\"\n            )\n        pads = tuple(same_padding(k, d) for k, d in zip(kernel_sizes, dilations))\n\n        self.convs = nn.ModuleList()\n        for k, d, p in zip(kernel_sizes, dilations, pads):\n            _conv = Conv[Conv.CONV, spatial_dims](\n                in_channels=in_channels, out_channels=conv_out_channels, kernel_size=k, dilation=d, padding=p\n            )\n            self.convs.append(_conv)\n\n        out_channels = conv_out_channels * len(pads)  # final conv. output channels\n        self.conv_k1 = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=out_channels,\n            out_channels=out_channels,\n            kernel_size=1,\n            act=acti_type,\n            norm=norm_type,\n            bias=bias,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: in shape (batch, channel, spatial_1[, spatial_2, ...]).\n        \"\"\"\n        x_out = torch.cat([conv(x) for conv in self.convs], dim=1)\n        x_out = self.conv_k1(x_out)\n        return x_out\n"
  },
  {
    "path": "monai/networks/blocks/attention_utils.py",
    "content": "# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n\ndef get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Get relative positional embeddings according to the relative positions of\n    query and key sizes.\n\n    Args:\n        q_size (int): size of query q.\n        k_size (int): size of key k.\n        rel_pos (Tensor): relative position embeddings (L, C).\n\n    Returns:\n        Extracted positional embeddings according to relative positions.\n    \"\"\"\n    rel_pos_resized: torch.Tensor = torch.Tensor()\n    max_rel_dist = int(2 * max(q_size, k_size) - 1)\n    # Interpolate rel pos if needed.\n    if rel_pos.shape[0] != max_rel_dist:\n        # Interpolate rel pos.\n        rel_pos_resized = F.interpolate(\n            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode=\"linear\"\n        )\n        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)\n    else:\n        rel_pos_resized = rel_pos\n\n    # Scale the coords with short length if shapes for q and k are different.\n    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)\n    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)\n    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)\n\n    return rel_pos_resized[relative_coords.long()]\n\n\ndef add_decomposed_rel_pos(\n    attn: 
torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: tuple, k_size: tuple\n) -> torch.Tensor:\n    r\"\"\"\n    Calculate decomposed Relative Positional Embeddings from mvitv2 implementation:\n    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py\n\n    Only 2D and 3D are supported.\n\n    Encoding the relative position of tokens in the attention matrix: tokens spaced a distance\n    `d` apart will have the same embedding value (unlike absolute positional embedding).\n\n    .. math::\n        Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale\n\n    where\n\n    .. math::\n        E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)}\n\n    with :math:`R_{p(i), p(j)} \\in R^{dim}` and :math:`p(i), p(j)`,\n    respectively spatial positions of element :math:`i` and :math:`j`\n\n    When using \"decomposed\" relative positional embedding, positional embedding is defined (\"decomposed\") as follow:\n\n    .. math::\n        R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}\n\n    with :math:`n = 1...dim`\n\n    Decomposed relative positional embedding reduces the complexity from :math:`\\mathcal{O}(d1*...*dn)` to\n    :math:`\\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding.\n\n    Args:\n        attn (Tensor): attention map.\n        q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... 
* s_dim_n, C).\n        rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis.\n        q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n).\n        k_size (Tuple): spatial sequence size of key k with (k_dim_1, ...,  k_dim_n).\n\n    Returns:\n        attn (Tensor): attention logits with added relative positional embeddings.\n    \"\"\"\n    rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])\n    rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])\n\n    batch, _, dim = q.shape\n\n    if len(rel_pos_lst) == 2:\n        q_h, q_w = q_size[:2]\n        k_h, k_w = k_size[:2]\n        r_q = q.reshape(batch, q_h, q_w, dim)\n        rel_h = torch.einsum(\"bhwc,hkc->bhwk\", r_q, rh)\n        rel_w = torch.einsum(\"bhwc,wkc->bhwk\", r_q, rw)\n\n        attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(\n            batch, q_h * q_w, k_h * k_w\n        )\n    elif len(rel_pos_lst) == 3:\n        q_h, q_w, q_d = q_size[:3]\n        k_h, k_w, k_d = k_size[:3]\n\n        rd = get_rel_pos(q_d, k_d, rel_pos_lst[2])\n\n        r_q = q.reshape(batch, q_h, q_w, q_d, dim)\n        rel_h = torch.einsum(\"bhwdc,hkc->bhwdk\", r_q, rh)\n        rel_w = torch.einsum(\"bhwdc,wkc->bhwdk\", r_q, rw)\n        rel_d = torch.einsum(\"bhwdc,wkc->bhwdk\", r_q, rd)\n\n        attn = (\n            attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d)\n            + rel_h[:, :, :, :, None, None]\n            + rel_w[:, :, :, None, :, None]\n            + rel_d[:, :, :, None, None, :]\n        ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)\n\n    return attn\n"
  },
  {
    "path": "monai/networks/blocks/backbone_fpn_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py\n# which has the following license...\n# https://github.com/pytorch/vision/blob/main/LICENSE\n#\n# BSD 3-Clause License\n\n# Copyright (c) Soumith Chintala 2016,\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# * Redistributions of source code must retain the above copyright notice, this\n#   list of conditions and the following disclaimer.\n\n# * Redistributions in binary form must reproduce the above copyright notice,\n#   this list of conditions and the following disclaimer in the documentation\n#   and/or other materials provided with the distribution.\n\n# * Neither the name of the copyright holder nor the names of its\n#   contributors may be used to endorse or promote products derived from\n#   this software without specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\"\"\"\nThis script is modified from from torchvision to support N-D images,\nby overriding the definition of convolutional layers and pooling layers.\n\nhttps://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom torch import Tensor, nn\n\nfrom monai.networks.nets import resnet\nfrom monai.utils import optional_import\n\nfrom .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool\n\ntorchvision_models, _ = optional_import(\"torchvision.models\")\n\n__all__ = [\"BackboneWithFPN\"]\n\n\nclass BackboneWithFPN(nn.Module):\n    \"\"\"\n    Adds an FPN on top of a model.\n    Internally, it uses torchvision.models._utils.IntermediateLayerGetter to\n    extract a submodel that returns the feature maps specified in return_layers.\n    The same limitations of IntermediateLayerGetter apply here.\n\n    Same code as https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py\n    Except that this class uses spatial_dims\n\n    Args:\n        backbone: backbone network\n        return_layers: a dict containing the names\n            of the modules for which the activations will be returned as\n            the key of the dict, and the value of the dict is the name\n            of the returned activation (which the user can specify).\n        in_channels_list: number of channels for each 
feature map\n            that is returned, in the order they are present in the OrderedDict\n        out_channels: number of channels in the FPN.\n        spatial_dims: 2D or 3D images\n    \"\"\"\n\n    def __init__(\n        self,\n        backbone: nn.Module,\n        return_layers: dict[str, str],\n        in_channels_list: list[int],\n        out_channels: int,\n        spatial_dims: int | None = None,\n        extra_blocks: ExtraFPNBlock | None = None,\n    ) -> None:\n        super().__init__()\n\n        # if spatial_dims is not specified, try to find it from backbone.\n        if spatial_dims is None:\n            if hasattr(backbone, \"spatial_dims\") and isinstance(backbone.spatial_dims, int):\n                spatial_dims = backbone.spatial_dims\n            elif isinstance(backbone.conv1, nn.Conv2d):\n                spatial_dims = 2\n            elif isinstance(backbone.conv1, nn.Conv3d):\n                spatial_dims = 3\n            else:\n                raise ValueError(\"Could not find spatial_dims of backbone, please specify it.\")\n\n        if extra_blocks is None:\n            extra_blocks = LastLevelMaxPool(spatial_dims)\n\n        self.body = torchvision_models._utils.IntermediateLayerGetter(backbone, return_layers=return_layers)\n        self.fpn = FeaturePyramidNetwork(\n            spatial_dims=spatial_dims,\n            in_channels_list=in_channels_list,\n            out_channels=out_channels,\n            extra_blocks=extra_blocks,\n        )\n        self.out_channels = out_channels\n\n    def forward(self, x: Tensor) -> dict[str, Tensor]:\n        \"\"\"\n        Computes the resulted feature maps of the network.\n\n        Args:\n            x: input images\n\n        Returns:\n            feature maps after FPN layers. 
They are ordered from highest resolution first.\n        \"\"\"\n        x = self.body(x)  # backbone\n        y: dict[str, Tensor] = self.fpn(x)  # FPN\n        return y\n\n\ndef _resnet_fpn_extractor(\n    backbone: resnet.ResNet,\n    spatial_dims: int,\n    trainable_layers: int = 5,\n    returned_layers: list[int] | None = None,\n    extra_blocks: ExtraFPNBlock | None = None,\n) -> BackboneWithFPN:\n    \"\"\"\n    Same code as https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py\n    Except that ``in_channels_stage2 = backbone.in_planes // 8`` instead of ``in_channels_stage2 = backbone.inplanes // 8``,\n    and it requires spatial_dims: 2D or 3D images.\n    \"\"\"\n\n    # select layers that wont be frozen\n    if trainable_layers < 0 or trainable_layers > 5:\n        raise ValueError(f\"Trainable layers should be in the range [0,5], got {trainable_layers}\")\n    layers_to_train = [\"layer4\", \"layer3\", \"layer2\", \"layer1\", \"conv1\"][:trainable_layers]\n    if trainable_layers == 5:\n        layers_to_train.append(\"bn1\")\n    for name, parameter in backbone.named_parameters():\n        if all(not name.startswith(layer) for layer in layers_to_train):\n            parameter.requires_grad_(False)\n\n    if extra_blocks is None:\n        extra_blocks = LastLevelMaxPool(spatial_dims)\n\n    if returned_layers is None:\n        returned_layers = [1, 2, 3, 4]\n    if min(returned_layers) <= 0 or max(returned_layers) >= 5:\n        raise ValueError(f\"Each returned layer should be in the range [1,4]. 
Got {returned_layers}\")\n    return_layers = {f\"layer{k}\": str(v) for v, k in enumerate(returned_layers)}\n\n    in_channels_stage2 = backbone.in_planes // 8\n    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]\n    out_channels = 256\n    return BackboneWithFPN(\n        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, spatial_dims=spatial_dims\n    )\n"
  },
  {
    "path": "monai/networks/blocks/cablock.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nfrom typing import cast\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.blocks.convolutions import Convolution\nfrom monai.utils import optional_import\n\nrearrange, _ = optional_import(\"einops\", name=\"rearrange\")\n\n__all__ = [\"FeedForward\", \"CABlock\"]\n\n\nclass FeedForward(nn.Module):\n    \"\"\"Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism.\n    Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.\n\n    Args:\n        spatial_dims: Number of spatial dimensions (2D or 3D)\n        dim: Number of input channels\n        ffn_expansion_factor: Factor to expand hidden features dimension\n        bias: Whether to use bias in convolution layers\n    \"\"\"\n\n    def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool):\n        super().__init__()\n        hidden_features = int(dim * ffn_expansion_factor)\n\n        self.project_in = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=dim,\n            out_channels=hidden_features * 2,\n            kernel_size=1,\n            bias=bias,\n            conv_only=True,\n        )\n\n        self.dwconv = Convolution(\n            spatial_dims=spatial_dims,\n            
in_channels=hidden_features * 2,\n            out_channels=hidden_features * 2,\n            kernel_size=3,\n            strides=1,\n            padding=1,\n            groups=hidden_features * 2,\n            bias=bias,\n            conv_only=True,\n        )\n\n        self.project_out = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=hidden_features,\n            out_channels=dim,\n            kernel_size=1,\n            bias=bias,\n            conv_only=True,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.project_in(x)\n        x1, x2 = self.dwconv(x).chunk(2, dim=1)\n        return cast(torch.Tensor, self.project_out(F.gelu(x1) * x2))\n\n\nclass CABlock(nn.Module):\n    \"\"\"Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention\n    by operating on feature channels instead of spatial dimensions. Incorporates depth-wise\n    convolutions for local mixing before attention, achieving linear complexity vs quadratic\n    in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>\n\n    Args:\n        spatial_dims: Number of spatial dimensions (2D or 3D)\n        dim: Number of input channels\n        num_heads: Number of attention heads\n        bias: Whether to use bias in convolution layers\n        flash_attention: Whether to use flash attention optimization. Defaults to False.\n\n    Raises:\n        ValueError: If flash attention is not available in current PyTorch version\n        ValueError: If spatial_dims is greater than 3\n    \"\"\"\n\n    def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):\n        super().__init__()\n        if flash_attention and not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ValueError(\"Flash attention not available\")\n        if spatial_dims > 3:\n            raise ValueError(f\"Only 2D and 3D inputs are supported. 
Got spatial_dims={spatial_dims}\")\n        self.spatial_dims = spatial_dims\n        self.num_heads = num_heads\n        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))\n        self.flash_attention = flash_attention\n\n        self.qkv = Convolution(\n            spatial_dims=spatial_dims, in_channels=dim, out_channels=dim * 3, kernel_size=1, bias=bias, conv_only=True\n        )\n\n        self.qkv_dwconv = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=dim * 3,\n            out_channels=dim * 3,\n            kernel_size=3,\n            strides=1,\n            padding=1,\n            groups=dim * 3,\n            bias=bias,\n            conv_only=True,\n        )\n\n        self.project_out = Convolution(\n            spatial_dims=spatial_dims, in_channels=dim, out_channels=dim, kernel_size=1, bias=bias, conv_only=True\n        )\n\n        self._attention_fn = self._get_attention_fn()\n\n    def _get_attention_fn(self):\n        if self.flash_attention:\n            return self._flash_attention\n        return self._normal_attention\n\n    def _flash_attention(self, q, k, v):\n        \"\"\"Flash attention implementation using scaled dot-product attention.\"\"\"\n        scale = float(self.temperature.mean())\n        out = F.scaled_dot_product_attention(q, k, v, scale=scale, dropout_p=0.0, is_causal=False)\n        return out\n\n    def _normal_attention(self, q, k, v):\n        \"\"\"Attention matrix multiplication with depth-wise convolutions.\"\"\"\n        attn = (q @ k.transpose(-2, -1)) * self.temperature\n        attn = attn.softmax(dim=-1)\n        return attn @ v\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Forward pass for MDTA attention.\n        1. Apply depth-wise convolutions to Q, K, V\n        2. Reshape Q, K, V for multi-head attention\n        3. Compute attention matrix using flash or normal attention\n        4. 
Reshape and project out attention output\"\"\"\n        spatial_dims = x.shape[2:]\n\n        # Project and mix\n        qkv = self.qkv_dwconv(self.qkv(x))\n        q, k, v = qkv.chunk(3, dim=1)\n\n        # Select attention\n        if self.spatial_dims == 2:\n            qkv_to_multihead = \"b (head c) h w -> b head c (h w)\"\n            multihead_to_qkv = \"b head c (h w) -> b (head c) h w\"\n        else:  # dims == 3\n            qkv_to_multihead = \"b (head c) d h w -> b head c (d h w)\"\n            multihead_to_qkv = \"b head c (d h w) -> b (head c) d h w\"\n\n        # Reconstruct and project feature map\n        q = rearrange(q, qkv_to_multihead, head=self.num_heads)\n        k = rearrange(k, qkv_to_multihead, head=self.num_heads)\n        v = rearrange(v, qkv_to_multihead, head=self.num_heads)\n\n        q = torch.nn.functional.normalize(q, dim=-1)\n        k = torch.nn.functional.normalize(k, dim=-1)\n\n        out = self._attention_fn(q, k, v)\n        out = rearrange(\n            out,\n            multihead_to_qkv,\n            head=self.num_heads,\n            **dict(zip([\"h\", \"w\"] if self.spatial_dims == 2 else [\"d\", \"h\", \"w\"], spatial_dims)),\n        )\n\n        return cast(torch.Tensor, self.project_out(out))\n"
  },
  {
    "path": "monai/networks/blocks/convolutions.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import ADN\nfrom monai.networks.layers.convutils import same_padding, stride_minus_kernel_padding\nfrom monai.networks.layers.factories import Conv\n\n\nclass Convolution(nn.Sequential):\n    \"\"\"\n    Constructs a convolution with normalization, optional dropout, and optional activation layers::\n\n        -- (Conv|ConvTrans) -- (Norm -- Dropout -- Acti) --\n\n    if ``conv_only`` set to ``True``::\n\n        -- (Conv|ConvTrans) --\n\n    For example:\n\n    .. 
code-block:: python\n\n        from monai.networks.blocks import Convolution\n\n        conv = Convolution(\n            spatial_dims=3,\n            in_channels=1,\n            out_channels=1,\n            adn_ordering=\"ADN\",\n            act=(\"prelu\", {\"init\": 0.2}),\n            dropout=0.1,\n            norm=(\"layer\", {\"normalized_shape\": (10, 10, 10)}),\n        )\n        print(conv)\n\n    output::\n\n        Convolution(\n          (conv): Conv3d(1, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n          (adn): ADN(\n            (A): PReLU(num_parameters=1)\n            (D): Dropout(p=0.1, inplace=False)\n            (N): LayerNorm((10, 10, 10), eps=1e-05, elementwise_affine=True)\n          )\n        )\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        strides: convolution stride. Defaults to 1.\n        kernel_size: convolution kernel size. Defaults to 3.\n        adn_ordering: a string representing the ordering of activation, normalization, and dropout.\n            Defaults to \"NDA\".\n        act: activation type and arguments. Defaults to PReLU.\n        norm: feature normalization type and arguments. Defaults to instance norm.\n        dropout: dropout ratio. Defaults to no dropout.\n        dropout_dim: determine the spatial dimensions of dropout. Defaults to 1.\n\n            - When dropout_dim = 1, randomly zeroes some of the elements for each channel.\n            - When dropout_dim = 2, Randomly zeroes out entire channels (a channel is a 2D feature map).\n            - When dropout_dim = 3, Randomly zeroes out entire channels (a channel is a 3D feature map).\n\n            The value of dropout_dim should be no larger than the value of `spatial_dims`.\n        dilation: dilation rate. Defaults to 1.\n        groups: controls the connections between inputs and outputs. 
Defaults to 1.\n        bias: whether to have a bias term. Defaults to True.\n        conv_only: whether to use the convolutional layer only. Defaults to False.\n        is_transposed: if True uses ConvTrans instead of Conv. Defaults to False.\n        padding: controls the amount of implicit zero-paddings on both sides for padding number of points\n            for each dimension. Defaults to None.\n        output_padding: controls the additional size added to one side of the output shape.\n            Defaults to None.\n\n    See also:\n\n        :py:class:`monai.networks.layers.Conv`\n        :py:class:`monai.networks.blocks.ADN`\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        strides: Sequence[int] | int = 1,\n        kernel_size: Sequence[int] | int = 3,\n        adn_ordering: str = \"NDA\",\n        act: tuple | str | None = \"PRELU\",\n        norm: tuple | str | None = \"INSTANCE\",\n        dropout: tuple | str | float | None = None,\n        dropout_dim: int | None = 1,\n        dilation: Sequence[int] | int = 1,\n        groups: int = 1,\n        bias: bool = True,\n        conv_only: bool = False,\n        is_transposed: bool = False,\n        padding: Sequence[int] | int | None = None,\n        output_padding: Sequence[int] | int | None = None,\n    ) -> None:\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.is_transposed = is_transposed\n        if padding is None:\n            padding = same_padding(kernel_size, dilation)\n        conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, self.spatial_dims]\n\n        conv: nn.Module\n        if is_transposed:\n            if output_padding is None:\n                output_padding = stride_minus_kernel_padding(1, strides)\n            conv = conv_type(\n                in_channels,\n      
          out_channels,\n                kernel_size=kernel_size,\n                stride=strides,\n                padding=padding,\n                output_padding=output_padding,\n                groups=groups,\n                bias=bias,\n                dilation=dilation,\n            )\n        else:\n            conv = conv_type(\n                in_channels,\n                out_channels,\n                kernel_size=kernel_size,\n                stride=strides,\n                padding=padding,\n                dilation=dilation,\n                groups=groups,\n                bias=bias,\n            )\n\n        self.add_module(\"conv\", conv)\n\n        if conv_only:\n            return\n        if act is None and norm is None and dropout is None:\n            return\n        self.add_module(\n            \"adn\",\n            ADN(\n                ordering=adn_ordering,\n                in_channels=out_channels,\n                act=act,\n                norm=norm,\n                norm_dim=self.spatial_dims,\n                dropout=dropout,\n                dropout_dim=dropout_dim,\n            ),\n        )\n\n\nclass ResidualUnit(nn.Module):\n    \"\"\"\n    Residual module with multiple convolutions and a residual connection.\n\n    For example:\n\n    .. 
code-block:: python\n\n        from monai.networks.blocks import ResidualUnit\n\n        convs = ResidualUnit(\n            spatial_dims=3,\n            in_channels=1,\n            out_channels=1,\n            adn_ordering=\"AN\",\n            act=(\"prelu\", {\"init\": 0.2}),\n            norm=(\"layer\", {\"normalized_shape\": (10, 10, 10)}),\n        )\n        print(convs)\n\n    output::\n\n        ResidualUnit(\n          (conv): Sequential(\n            (unit0): Convolution(\n              (conv): Conv3d(1, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n              (adn): ADN(\n                (A): PReLU(num_parameters=1)\n                (N): LayerNorm((10, 10, 10), eps=1e-05, elementwise_affine=True)\n              )\n            )\n            (unit1): Convolution(\n              (conv): Conv3d(1, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))\n              (adn): ADN(\n                (A): PReLU(num_parameters=1)\n                (N): LayerNorm((10, 10, 10), eps=1e-05, elementwise_affine=True)\n              )\n            )\n          )\n          (residual): Identity()\n        )\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        strides: convolution stride. Defaults to 1.\n        kernel_size: convolution kernel size. Defaults to 3.\n        subunits: number of convolutions. Defaults to 2.\n        adn_ordering: a string representing the ordering of activation, normalization, and dropout.\n            Defaults to \"NDA\".\n        act: activation type and arguments. Defaults to PReLU.\n        norm: feature normalization type and arguments. Defaults to instance norm.\n        dropout: dropout ratio. Defaults to no dropout.\n        dropout_dim: determine the dimensions of dropout. 
Defaults to 1.\n\n            - When dropout_dim = 1, randomly zeroes some of the elements for each channel.\n            - When dropout_dim = 2, Randomly zero out entire channels (a channel is a 2D feature map).\n            - When dropout_dim = 3, Randomly zero out entire channels (a channel is a 3D feature map).\n\n            The value of dropout_dim should be no larger than the value of `dimensions`.\n        dilation: dilation rate. Defaults to 1.\n        bias: whether to have a bias term. Defaults to True.\n        last_conv_only: for the last subunit, whether to use the convolutional layer only.\n            Defaults to False.\n        padding: controls the amount of implicit zero-paddings on both sides for padding number of points\n            for each dimension. Defaults to None.\n\n    See also:\n\n        :py:class:`monai.networks.blocks.Convolution`\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        strides: Sequence[int] | int = 1,\n        kernel_size: Sequence[int] | int = 3,\n        subunits: int = 2,\n        adn_ordering: str = \"NDA\",\n        act: tuple | str | None = \"PRELU\",\n        norm: tuple | str | None = \"INSTANCE\",\n        dropout: tuple | str | float | None = None,\n        dropout_dim: int | None = 1,\n        dilation: Sequence[int] | int = 1,\n        bias: bool = True,\n        last_conv_only: bool = False,\n        padding: Sequence[int] | int | None = None,\n    ) -> None:\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.conv = nn.Sequential()\n        self.residual = nn.Identity()\n        if not padding:\n            padding = same_padding(kernel_size, dilation)\n        schannels = in_channels\n        sstrides = strides\n        subunits = max(1, subunits)\n\n        for su in range(subunits):\n            
conv_only = last_conv_only and su == (subunits - 1)\n            unit = Convolution(\n                self.spatial_dims,\n                schannels,\n                out_channels,\n                strides=sstrides,\n                kernel_size=kernel_size,\n                adn_ordering=adn_ordering,\n                act=act,\n                norm=norm,\n                dropout=dropout,\n                dropout_dim=dropout_dim,\n                dilation=dilation,\n                bias=bias,\n                conv_only=conv_only,\n                padding=padding,\n            )\n\n            self.conv.add_module(f\"unit{su:d}\", unit)\n\n            # after first loop set channels and strides to what they should be for subsequent units\n            schannels = out_channels\n            sstrides = 1\n\n        # apply convolution to input to change number of output channels and size to match that coming from self.conv\n        if np.prod(strides) != 1 or in_channels != out_channels:\n            rkernel_size = kernel_size\n            rpadding = padding\n\n            if np.prod(strides) == 1:  # if only adapting number of channels a 1x1 kernel is used with no padding\n                rkernel_size = 1\n                rpadding = 0\n\n            conv_type = Conv[Conv.CONV, self.spatial_dims]\n            self.residual = conv_type(in_channels, out_channels, rkernel_size, strides, rpadding, bias=bias)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        res: torch.Tensor = self.residual(x)  # create the additive residual from x\n        cx: torch.Tensor = self.conv(x)  # apply x to sequence of operations\n        return cx + res  # add the residual to the output\n"
  },
  {
    "path": "monai/networks/blocks/crf.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch.nn.functional import softmax\n\nfrom monai.networks.layers.filtering import PHLFilter\nfrom monai.networks.utils import meshgrid_ij\n\n__all__ = [\"CRF\"]\n\n\nclass CRF(torch.nn.Module):\n    \"\"\"\n    Conditional Random Field: Combines message passing with a class\n    compatibility convolution into an iterative process designed\n    to successively minimise the energy of the class labeling.\n\n    In this implementation, the message passing step is a weighted\n    combination of a gaussian filter and a bilateral filter.\n    The bilateral term is included to respect existing structure\n    within the reference tensor.\n\n    See:\n        https://arxiv.org/abs/1502.03240\n    \"\"\"\n\n    def __init__(\n        self,\n        iterations: int = 5,\n        bilateral_weight: float = 1.0,\n        gaussian_weight: float = 1.0,\n        bilateral_spatial_sigma: float = 5.0,\n        bilateral_color_sigma: float = 0.5,\n        gaussian_spatial_sigma: float = 5.0,\n        update_factor: float = 3.0,\n        compatibility_matrix: torch.Tensor | None = None,\n    ):\n        \"\"\"\n        Args:\n            iterations: the number of iterations.\n            bilateral_weight: the weighting of the bilateral term in the message passing step.\n            gaussian_weight: the weighting of the gaussian term in 
the message passing step.\n            bilateral_spatial_sigma: standard deviation in spatial coordinates for the bilateral term.\n            bilateral_color_sigma: standard deviation in color space for the bilateral term.\n            gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term.\n            update_factor: determines the magnitude of each update.\n            compatibility_matrix: a matrix describing class compatibility,\n                should be NxN where N is the number of classes.\n        \"\"\"\n        super().__init__()\n        self.iterations = iterations\n        self.bilateral_weight = bilateral_weight\n        self.gaussian_weight = gaussian_weight\n        self.bilateral_spatial_sigma = bilateral_spatial_sigma\n        self.bilateral_color_sigma = bilateral_color_sigma\n        self.gaussian_spatial_sigma = gaussian_spatial_sigma\n        self.update_factor = update_factor\n        self.compatibility_matrix = compatibility_matrix\n\n    def forward(self, input_tensor: torch.Tensor, reference_tensor: torch.Tensor):\n        \"\"\"\n        Args:\n            input_tensor: tensor containing initial class logits.\n            reference_tensor: the reference tensor used to guide the message passing.\n\n        Returns:\n            output (torch.Tensor): output tensor.\n        \"\"\"\n\n        # constructing spatial feature tensor\n        spatial_features = _create_coordinate_tensor(reference_tensor)\n\n        # constructing final feature tensors for bilateral and gaussian kernel\n        bilateral_features = torch.cat(\n            [spatial_features / self.bilateral_spatial_sigma, reference_tensor / self.bilateral_color_sigma], dim=1\n        )\n        gaussian_features = spatial_features / self.gaussian_spatial_sigma\n\n        # setting up output tensor\n        output_tensor = softmax(input_tensor, dim=1)\n\n        # mean field loop\n        for _ in range(self.iterations):\n            # message 
passing step for both kernels\n            bilateral_output = PHLFilter.apply(output_tensor, bilateral_features)\n            gaussian_output = PHLFilter.apply(output_tensor, gaussian_features)\n\n            # combining filter outputs\n            combined_output = self.bilateral_weight * bilateral_output + self.gaussian_weight * gaussian_output\n\n            # optionally running a compatibility transform\n            if self.compatibility_matrix is not None:\n                flat = combined_output.flatten(start_dim=2).permute(0, 2, 1)\n                flat = torch.matmul(flat, self.compatibility_matrix)\n                combined_output = flat.permute(0, 2, 1).reshape(combined_output.shape)\n\n            # update and normalize\n            output_tensor = softmax(input_tensor + self.update_factor * combined_output, dim=1)\n\n        return output_tensor\n\n\n# helper methods\ndef _create_coordinate_tensor(tensor):\n    axes = [torch.arange(tensor.size(i)) for i in range(2, tensor.dim())]\n    grids = meshgrid_ij(axes)\n    coords = torch.stack(grids).to(device=tensor.device, dtype=tensor.dtype)\n    return torch.stack(tensor.size(0) * [coords], dim=0)\n"
  },
  {
    "path": "monai/networks/blocks/crossattention.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.layers.utils import get_rel_pos_embedding_layer\nfrom monai.utils import optional_import\n\nRearrange, _ = optional_import(\"einops.layers.torch\", name=\"Rearrange\")\n\n\nclass CrossAttentionBlock(nn.Module):\n    \"\"\"\n    A cross-attention block, based on: \"Dosovitskiy et al.,\n    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>\"\n    One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        num_heads: int,\n        dropout_rate: float = 0.0,\n        hidden_input_size: int | None = None,\n        context_input_size: int | None = None,\n        dim_head: int | None = None,\n        qkv_bias: bool = False,\n        save_attn: bool = False,\n        causal: bool = False,\n        sequence_length: int | None = None,\n        rel_pos_embedding: str | None = None,\n        input_size: tuple | None = None,\n        attention_dtype: torch.dtype | None = None,\n        use_flash_attention: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            hidden_size (int): dimension of hidden layer.\n            num_heads (int): number of attention heads.\n            dropout_rate (float, 
optional): fraction of the input units to drop. Defaults to 0.0.\n            hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.\n            context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size.\n            dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.\n            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.\n            save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.\n            causal (bool, optional): whether to use causal attention.\n            sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.\n            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only\n                \"decomposed\" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.\n            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional\n                parameter size.\n            attention_dtype: cast attention operations to this dtype.\n            use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n        \"\"\"\n\n        super().__init__()\n\n        if not (0 <= dropout_rate <= 1):\n            raise ValueError(\"dropout_rate should be between 0 and 1.\")\n\n        if dim_head:\n            inner_size = num_heads * dim_head\n            self.head_dim = dim_head\n        else:\n            if hidden_size % num_heads != 0:\n                raise ValueError(\"hidden size should be divisible by num_heads.\")\n            inner_size = hidden_size\n            self.head_dim = hidden_size // num_heads\n\n        if causal and sequence_length is 
None:\n            raise ValueError(\"sequence_length is necessary for causal attention.\")\n\n        if use_flash_attention and save_attn:\n            raise ValueError(\n                \"save_attn has been set to True, but use_flash_attention is also set\"\n                \"to True. save_attn can only be used if use_flash_attention is False\"\n            )\n\n        if use_flash_attention and rel_pos_embedding is not None:\n            raise ValueError(\"rel_pos_embedding must be None if you are using flash_attention.\")\n\n        self.num_heads = num_heads\n        self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size\n        self.context_input_size = context_input_size if context_input_size else hidden_size\n        self.out_proj = nn.Linear(inner_size, self.hidden_input_size)\n        # key, query, value projections\n        self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias)\n        self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)\n        self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)\n        self.input_rearrange = Rearrange(\"b h (l d) -> b l h d\", l=num_heads)\n\n        self.out_rearrange = Rearrange(\"b l h d -> b h (l d)\")\n        self.drop_output = nn.Dropout(dropout_rate)\n        self.drop_weights = nn.Dropout(dropout_rate)\n        self.dropout_rate = dropout_rate\n\n        self.scale = self.head_dim**-0.5\n        self.save_attn = save_attn\n        self.attention_dtype = attention_dtype\n\n        self.causal = causal\n        self.sequence_length = sequence_length\n        self.use_flash_attention = use_flash_attention\n\n        if causal and sequence_length is not None:\n            # causal mask to ensure that attention is only applied to the left in the input sequence\n            self.register_buffer(\n                \"causal_mask\",\n                torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, 
sequence_length, sequence_length),\n            )\n            self.causal_mask: torch.Tensor\n        else:\n            self.causal_mask = torch.Tensor()\n\n        self.att_mat = torch.Tensor()\n        self.rel_positional_embedding = (\n            get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads)\n            if rel_pos_embedding is not None\n            else None\n        )\n        self.input_size = input_size\n\n    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None):\n        \"\"\"\n        Args:\n            x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C\n            context (torch.Tensor, optional): context tensor. B x (s_dim_1 * ... * s_dim_n) x C\n\n        Return:\n            torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C\n        \"\"\"\n        # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n        b, t, c = x.size()  # batch size, sequence length, embedding dimensionality (hidden_size)\n\n        q = self.input_rearrange(self.to_q(x))\n        kv = context if context is not None else x\n        _, kv_t, _ = kv.size()\n        k = self.input_rearrange(self.to_k(kv))\n        v = self.input_rearrange(self.to_v(kv))\n\n        if self.attention_dtype is not None:\n            q = q.to(self.attention_dtype)\n            k = k.to(self.attention_dtype)\n\n        if self.use_flash_attention:\n            x = torch.nn.functional.scaled_dot_product_attention(\n                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal\n            )\n        else:\n            att_mat = torch.einsum(\"blxd,blyd->blxy\", q, k) * self.scale\n            # apply relative positional embedding if defined\n            if self.rel_positional_embedding is not None:\n                att_mat = self.rel_positional_embedding(x, att_mat, q)\n\n            if self.causal:\n                att_mat = 
att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float(\"-inf\"))\n\n            att_mat = att_mat.softmax(dim=-1)\n\n            if self.save_attn:\n                # no gradients and new tensor;\n                # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html\n                self.att_mat = att_mat.detach()\n\n            att_mat = self.drop_weights(att_mat)\n            x = torch.einsum(\"bhxy,bhyd->bhxd\", att_mat, v)\n\n        x = self.out_rearrange(x)\n        x = self.out_proj(x)\n        x = self.drop_output(x)\n        return x\n"
  },
  {
    "path": "monai/networks/blocks/denseblock.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import Convolution, ResidualUnit\nfrom monai.networks.layers.factories import Act, Norm\n\n__ALL__ = [\"DenseBlock\", \"ConvDenseBlock\"]\n\n\nclass DenseBlock(nn.Sequential):\n    \"\"\"\n    A DenseBlock is a sequence of layers where each layer's outputs are concatenated with their inputs. This has the\n    effect of accumulating outputs from previous layers as inputs to later ones and as the final output of the block.\n\n    Args:\n        layers: sequence of nn.Module objects to define the individual layers of the dense block\n    \"\"\"\n\n    def __init__(self, layers: Sequence[nn.Module]):\n        super().__init__()\n        for i, l in enumerate(layers):\n            self.add_module(f\"layers{i}\", l)\n\n    def forward(self, x):\n        for l in self.children():\n            result = l(x)\n            x = torch.cat([x, result], 1)\n\n        return x\n\n\nclass ConvDenseBlock(DenseBlock):\n    \"\"\"\n    This dense block is defined as a sequence of `Convolution` or `ResidualUnit` blocks. 
The `_get_layer` method returns\n    an object for each layer and can be overridden to change the composition of the block.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        channels: output channels for each layer.\n        dilations: dilation value for each layer.\n        kernel_size: convolution kernel size. Defaults to 3.\n        num_res_units: number of convolutions. Defaults to 2.\n        adn_ordering: a string representing the ordering of activation, normalization, and dropout. Defaults to \"NDA\".\n        act: activation type and arguments. Defaults to PReLU.\n        norm: feature normalization type and arguments. Defaults to instance norm.\n        dropout: dropout ratio. Defaults to no dropout.\n        bias: whether to have a bias term. Defaults to True.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        channels: Sequence[int],\n        dilations: Sequence[int] | None = None,\n        kernel_size: Sequence[int] | int = 3,\n        num_res_units: int = 0,\n        adn_ordering: str = \"NDA\",\n        act: tuple | str | None = Act.PRELU,\n        norm: tuple | str | None = Norm.INSTANCE,\n        dropout: tuple | str | float | None = None,\n        bias: bool = True,\n    ):\n        self.spatial_dims = spatial_dims\n        self.kernel_size = kernel_size\n        self.num_res_units = num_res_units\n        self.adn_ordering = adn_ordering\n        self.act = act\n        self.norm = norm\n        self.dropout = dropout\n        self.bias = bias\n\n        l_channels = in_channels\n        dilations = dilations if dilations is not None else ([1] * len(channels))\n        layers = []\n\n        if len(channels) != len(dilations):\n            raise ValueError(\"Length of `channels` and `dilations` must match\")\n\n        for c, d in zip(channels, dilations):\n            layer = self._get_layer(l_channels, c, d)\n     
       layers.append(layer)\n            l_channels += c\n\n        super().__init__(layers)\n\n    def _get_layer(self, in_channels, out_channels, dilation):\n        if self.num_res_units > 0:\n            return ResidualUnit(\n                spatial_dims=self.spatial_dims,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                strides=1,\n                kernel_size=self.kernel_size,\n                subunits=self.num_res_units,\n                adn_ordering=self.adn_ordering,\n                act=self.act,\n                norm=self.norm,\n                dropout=self.dropout,\n                dilation=dilation,\n                bias=self.bias,\n            )\n        else:\n            return Convolution(\n                spatial_dims=self.spatial_dims,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                strides=1,\n                kernel_size=self.kernel_size,\n                act=self.act,\n                norm=self.norm,\n                dropout=self.dropout,\n                dilation=dilation,\n                bias=self.bias,\n            )\n"
  },
  {
    "path": "monai/networks/blocks/dints_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom monai.networks.layers.factories import Conv\nfrom monai.networks.layers.utils import get_act_layer, get_norm_layer\n\n__all__ = [\"FactorizedIncreaseBlock\", \"FactorizedReduceBlock\", \"P3DActiConvNormBlock\", \"ActiConvNormBlock\"]\n\n\nclass FactorizedIncreaseBlock(torch.nn.Sequential):\n    \"\"\"\n    Up-sampling the features by two using linear interpolation and convolutions.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channel: int,\n        out_channel: int,\n        spatial_dims: int = 3,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n    ):\n        \"\"\"\n        Args:\n            in_channel: number of input channels\n            out_channel: number of output channels\n            spatial_dims: number of spatial dimensions\n            act_name: activation layer type and arguments.\n            norm_name: feature normalization type and arguments.\n        \"\"\"\n        super().__init__()\n        self._in_channel = in_channel\n        self._out_channel = out_channel\n        self._spatial_dims = spatial_dims\n        if self._spatial_dims not in (2, 3):\n            raise ValueError(\"spatial_dims must be 2 or 3.\")\n\n        conv_type = Conv[Conv.CONV, self._spatial_dims]\n        mode = \"trilinear\" if 
self._spatial_dims == 3 else \"bilinear\"\n        self.add_module(\"up\", torch.nn.Upsample(scale_factor=2, mode=mode, align_corners=True))\n        self.add_module(\"acti\", get_act_layer(name=act_name))\n        self.add_module(\n            \"conv\",\n            conv_type(\n                in_channels=self._in_channel,\n                out_channels=self._out_channel,\n                kernel_size=1,\n                stride=1,\n                padding=0,\n                groups=1,\n                bias=False,\n                dilation=1,\n            ),\n        )\n        self.add_module(\n            \"norm\", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel)\n        )\n\n\nclass FactorizedReduceBlock(torch.nn.Module):\n    \"\"\"\n    Down-sampling the feature by 2 using stride.\n    The length along each spatial dimension must be a multiple of 2.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channel: int,\n        out_channel: int,\n        spatial_dims: int = 3,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n    ):\n        \"\"\"\n        Args:\n            in_channel: number of input channels\n            out_channel: number of output channels.\n            spatial_dims: number of spatial dimensions.\n            act_name: activation layer type and arguments.\n            norm_name: feature normalization type and arguments.\n        \"\"\"\n        super().__init__()\n        self._in_channel = in_channel\n        self._out_channel = out_channel\n        self._spatial_dims = spatial_dims\n        if self._spatial_dims not in (2, 3):\n            raise ValueError(\"spatial_dims must be 2 or 3.\")\n\n        conv_type = Conv[Conv.CONV, self._spatial_dims]\n\n        self.act = get_act_layer(name=act_name)\n        self.conv_1 = conv_type(\n            in_channels=self._in_channel,\n            out_channels=self._out_channel // 2,\n     
       kernel_size=1,\n            stride=2,\n            padding=0,\n            groups=1,\n            bias=False,\n            dilation=1,\n        )\n        self.conv_2 = conv_type(\n            in_channels=self._in_channel,\n            out_channels=self._out_channel - self._out_channel // 2,\n            kernel_size=1,\n            stride=2,\n            padding=0,\n            groups=1,\n            bias=False,\n            dilation=1,\n        )\n        self.norm = get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        The length along each spatial dimension must be a multiple of 2.\n        \"\"\"\n        x = self.act(x)\n        if self._spatial_dims == 3:\n            out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:, 1:])], dim=1)\n        else:\n            out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)\n        out = self.norm(out)\n        return out\n\n\nclass P3DActiConvNormBlock(torch.nn.Sequential):\n    \"\"\"\n    -- (act) -- (conv) -- (norm) --\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channel: int,\n        out_channel: int,\n        kernel_size: int,\n        padding: int,\n        mode: int = 0,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n    ):\n        \"\"\"\n        Args:\n            in_channel: number of input channels.\n            out_channel: number of output channels.\n            kernel_size: kernel size to be expanded to 3D.\n            padding: padding size to be expanded to 3D.\n            mode: mode for the anisotropic kernels:\n\n                - 0: ``(k, k, 1)``, ``(1, 1, k)``,\n                - 1: ``(k, 1, k)``, ``(1, k, 1)``,\n                - 2: ``(1, k, k)``. 
``(k, 1, 1)``.\n\n            act_name: activation layer type and arguments.\n            norm_name: feature normalization type and arguments.\n        \"\"\"\n        super().__init__()\n        self._in_channel = in_channel\n        self._out_channel = out_channel\n        self._p3dmode = int(mode)\n\n        conv_type = Conv[Conv.CONV, 3]\n\n        if self._p3dmode == 0:  # (k, k, 1), (1, 1, k)\n            kernel_size0 = (kernel_size, kernel_size, 1)\n            kernel_size1 = (1, 1, kernel_size)\n            padding0 = (padding, padding, 0)\n            padding1 = (0, 0, padding)\n        elif self._p3dmode == 1:  # (k, 1, k), (1, k, 1)\n            kernel_size0 = (kernel_size, 1, kernel_size)\n            kernel_size1 = (1, kernel_size, 1)\n            padding0 = (padding, 0, padding)\n            padding1 = (0, padding, 0)\n        elif self._p3dmode == 2:  # (1, k, k), (k, 1, 1)\n            kernel_size0 = (1, kernel_size, kernel_size)\n            kernel_size1 = (kernel_size, 1, 1)\n            padding0 = (0, padding, padding)\n            padding1 = (padding, 0, 0)\n        else:\n            raise ValueError(\"`mode` must be 0, 1, or 2.\")\n\n        self.add_module(\"acti\", get_act_layer(name=act_name))\n        self.add_module(\n            \"conv\",\n            conv_type(\n                in_channels=self._in_channel,\n                out_channels=self._in_channel,\n                kernel_size=kernel_size0,\n                stride=1,\n                padding=padding0,\n                groups=1,\n                bias=False,\n                dilation=1,\n            ),\n        )\n        self.add_module(\n            \"conv_1\",\n            conv_type(\n                in_channels=self._in_channel,\n                out_channels=self._out_channel,\n                kernel_size=kernel_size1,\n                stride=1,\n                padding=padding1,\n                groups=1,\n                bias=False,\n                dilation=1,\n            
),\n        )\n        self.add_module(\"norm\", get_norm_layer(name=norm_name, spatial_dims=3, channels=self._out_channel))\n\n\nclass ActiConvNormBlock(torch.nn.Sequential):\n    \"\"\"\n    -- (Acti) -- (Conv) -- (Norm) --\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channel: int,\n        out_channel: int,\n        kernel_size: int = 3,\n        padding: int = 1,\n        spatial_dims: int = 3,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n    ):\n        \"\"\"\n        Args:\n            in_channel: number of input channels.\n            out_channel: number of output channels.\n            kernel_size: kernel size of the convolution.\n            padding: padding size of the convolution.\n            spatial_dims: number of spatial dimensions.\n            act_name: activation layer type and arguments.\n            norm_name: feature normalization type and arguments.\n        \"\"\"\n        super().__init__()\n        self._in_channel = in_channel\n        self._out_channel = out_channel\n        self._spatial_dims = spatial_dims\n\n        conv_type = Conv[Conv.CONV, self._spatial_dims]\n        self.add_module(\"acti\", get_act_layer(name=act_name))\n        self.add_module(\n            \"conv\",\n            conv_type(\n                in_channels=self._in_channel,\n                out_channels=self._out_channel,\n                kernel_size=kernel_size,\n                stride=1,\n                padding=padding,\n                groups=1,\n                bias=False,\n                dilation=1,\n            ),\n        )\n        self.add_module(\n            \"norm\", get_norm_layer(name=norm_name, spatial_dims=self._spatial_dims, channels=self._out_channel)\n        )\n"
  },
  {
    "path": "monai/networks/blocks/downsample.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.layers.factories import Conv, Pool\nfrom monai.networks.utils import pixelunshuffle\nfrom monai.utils import DownsampleMode, ensure_tuple_rep, look_up_option\n\n__all__ = [\"MaxAvgPool\", \"DownSample\", \"Downsample\", \"SubpixelDownsample\", \"SubpixelDownSample\", \"Subpixeldownsample\"]\n\n\nclass MaxAvgPool(nn.Module):\n    \"\"\"\n    Downsample with both maxpooling and avgpooling,\n    double the channel size by concatenating the downsampled feature maps.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        kernel_size: Sequence[int] | int,\n        stride: Sequence[int] | int | None = None,\n        padding: Sequence[int] | int = 0,\n        ceil_mode: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions of the input image.\n            kernel_size: the kernel size of both pooling operations.\n            stride: the stride of the window. 
Default value is `kernel_size`.\n            padding: implicit zero padding to be added to both pooling operations.\n            ceil_mode: when True, will use ceil instead of floor to compute the output shape.\n        \"\"\"\n        super().__init__()\n        _params = {\n            \"kernel_size\": ensure_tuple_rep(kernel_size, spatial_dims),\n            \"stride\": None if stride is None else ensure_tuple_rep(stride, spatial_dims),\n            \"padding\": ensure_tuple_rep(padding, spatial_dims),\n            \"ceil_mode\": ceil_mode,\n        }\n        self.max_pool = Pool[Pool.MAX, spatial_dims](**_params)\n        self.avg_pool = Pool[Pool.AVG, spatial_dims](**_params)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...]).\n\n        Returns:\n            Tensor in shape (batch, 2*channel, spatial_1[, spatial_2, ...]).\n        \"\"\"\n        return torch.cat([self.max_pool(x), self.avg_pool(x)], dim=1)\n\n\nclass DownSample(nn.Sequential):\n    \"\"\"\n    Downsamples data by `scale_factor`.\n\n    Supported modes are:\n\n    - \"conv\": uses a strided convolution for learnable downsampling.\n    - \"convgroup\": uses a grouped strided convolution for efficient feature reduction.\n    - \"nontrainable\": uses :py:class:`torch.nn.Upsample` with inverse scale factor.\n    - \"pixelunshuffle\": uses :py:class:`monai.networks.blocks.PixelUnshuffle` for channel-space rearrangement.\n\n    This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``.\n    Please check the link below for more details:\n    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms\n\n    This module can optionally take a pre-convolution\n    (often used to map the number of features from `in_channels` to `out_channels`).\n    \"\"\"\n\n    def __init__(\n        
self,\n        spatial_dims: int,\n        in_channels: int | None = None,\n        out_channels: int | None = None,\n        scale_factor: Sequence[float] | float = 2,\n        kernel_size: Sequence[float] | float | None = None,\n        mode: DownsampleMode | str = DownsampleMode.CONV,\n        pre_conv: nn.Module | str | None = \"default\",\n        post_conv: nn.Module | None = None,\n        bias: bool = True,\n    ) -> None:\n        \"\"\"\n        Downsamples data by `scale_factor`.\n        Supported modes are:\n\n            - DownsampleMode.CONV: uses a strided convolution for learnable downsampling.\n            - DownsampleMode.CONVGROUP: uses a grouped strided convolution for efficient feature reduction.\n            - DownsampleMode.MAXPOOL: uses maxpooling for non-learnable downsampling.\n            - DownsampleMode.AVGPOOL: uses average pooling for non-learnable downsampling.\n            - DownsampleMode.PIXELUNSHUFFLE: uses :py:class:`monai.networks.blocks.SubpixelDownsample`.\n\n        This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``.\n        Please check the link below for more details:\n        https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms\n\n        This module can optionally take a pre-convolution and post-convolution\n        (often used to map the number of features from `in_channels` to `out_channels`).\n\n        Args:\n            spatial_dims: number of spatial dimensions of the input image.\n            in_channels: number of channels of the input image.\n            out_channels: number of channels of the output image. Defaults to `in_channels`.\n            scale_factor: multiplier for spatial size reduction. Has to match input size if it is a tuple. Defaults to 2.\n            kernel_size: kernel size used during convolutions. 
Defaults to `scale_factor`.\n            mode: {``DownsampleMode.CONV``, ``DownsampleMode.CONVGROUP``, ``DownsampleMode.MAXPOOL``, ``DownsampleMode.AVGPOOL``,\n                ``DownsampleMode.PIXELUNSHUFFLE``}. Defaults to ``DownsampleMode.CONV``.\n            pre_conv: a conv block applied before downsampling. Defaults to \"default\".\n                When ``conv_block`` is ``\"default\"``, one reserved conv layer will be utilized.\n                Only used in the \"maxpool\", \"avgpool\" or \"pixelunshuffle\" modes.\n            post_conv: a conv block applied after downsampling. Defaults to None. Only used in the \"maxpool\" and \"avgpool\" modes.\n            bias: whether to have a bias term in the default preconv and conv layers. Defaults to True.\n        \"\"\"\n        super().__init__()\n\n        scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims)\n        down_mode = look_up_option(mode, DownsampleMode)\n\n        if not kernel_size:\n            kernel_size_ = scale_factor_\n            padding = ensure_tuple_rep(0, spatial_dims)\n        else:\n            kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims)\n            padding = tuple((k - 1) // 2 for k in kernel_size_)\n\n        if down_mode == DownsampleMode.CONV:\n            if not in_channels:\n                raise ValueError(\"in_channels needs to be specified in conv mode\")\n            self.add_module(\n                \"conv\",\n                Conv[Conv.CONV, spatial_dims](\n                    in_channels=in_channels,\n                    out_channels=out_channels or in_channels,\n                    kernel_size=kernel_size_,\n                    stride=scale_factor_,\n                    padding=padding,\n                    bias=bias,\n                ),\n            )\n        elif down_mode == DownsampleMode.CONVGROUP:\n            if not in_channels:\n                raise ValueError(\"in_channels needs to be specified\")\n            if out_channels is None:\n    
            out_channels = in_channels\n            groups = in_channels if out_channels % in_channels == 0 else 1\n            self.add_module(\n                \"convgroup\",\n                Conv[Conv.CONV, spatial_dims](\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    kernel_size=kernel_size_,\n                    stride=scale_factor_,\n                    padding=padding,\n                    groups=groups,\n                    bias=bias,\n                ),\n            )\n        elif down_mode == DownsampleMode.MAXPOOL:\n            if pre_conv == \"default\" and (out_channels != in_channels):\n                if not in_channels:\n                    raise ValueError(\"in_channels needs to be specified\")\n                self.add_module(\n                    \"preconv\",\n                    Conv[Conv.CONV, spatial_dims](\n                        in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias\n                    ),\n                )\n            self.add_module(\n                \"maxpool\", Pool[Pool.MAX, spatial_dims](kernel_size=kernel_size_, stride=scale_factor_, padding=padding)\n            )\n            if post_conv:\n                self.add_module(\"postconv\", post_conv)\n\n        elif down_mode == DownsampleMode.AVGPOOL:\n            if pre_conv == \"default\" and (out_channels != in_channels):\n                if not in_channels:\n                    raise ValueError(\"in_channels needs to be specified\")\n                self.add_module(\n                    \"preconv\",\n                    Conv[Conv.CONV, spatial_dims](\n                        in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias\n                    ),\n                )\n            self.add_module(\n                \"avgpool\", Pool[Pool.AVG, spatial_dims](kernel_size=kernel_size_, stride=scale_factor_, 
padding=padding)\n            )\n            if post_conv:\n                self.add_module(\"postconv\", post_conv)\n\n        elif down_mode == DownsampleMode.PIXELUNSHUFFLE:\n            self.add_module(\n                \"pixelunshuffle\",\n                SubpixelDownsample(\n                    spatial_dims=spatial_dims,\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    scale_factor=scale_factor_[0],\n                    conv_block=pre_conv,\n                    bias=bias,\n                ),\n            )\n\n\nclass SubpixelDownsample(nn.Module):\n    \"\"\"\n    Downsample via using a subpixel CNN. This module supports 1D, 2D and 3D input images.\n    The module consists of two parts. First, a convolutional layer is employed\n    to adjust the number of channels. Secondly, a pixel unshuffle manipulation\n    rearranges the spatial information into channel space, effectively reducing\n    spatial dimensions while increasing channel depth.\n\n    The pixel unshuffle operation is the inverse of pixel shuffle, rearranging dimensions\n    from (B, C, H*r, W*r) to (B, C*r², H, W) for 2D images or from (B, C, H*r, W*r, D*r) to (B, C*r³, H, W, D) in 3D case.\n\n    Example: (1, 1, 4, 4) with r=2 becomes (1, 4, 2, 2).\n\n    See: Shi et al., 2016, \"Real-Time Single Image and Video Super-Resolution\n    Using an Efficient Sub-Pixel Convolutional Neural Network.\"\n\n    The pixel unshuffle mechanism is the inverse operation of:\n    https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/upsample.py\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int | None,\n        out_channels: int | None = None,\n        scale_factor: int = 2,\n        conv_block: nn.Module | str | None = \"default\",\n        bias: bool = True,\n    ) -> None:\n        \"\"\"\n        Downsamples data by rearranging spatial information into channel space.\n        
This reduces spatial dimensions while increasing channel depth.\n\n        Args:\n            spatial_dims: number of spatial dimensions of the input image.\n            in_channels: number of channels of the input image.\n            out_channels: optional number of channels of the output image.\n            scale_factor: factor to reduce the spatial dimensions by. Defaults to 2.\n            conv_block: a conv block to adjust channels before downsampling. Defaults to None.\n                When ``conv_block`` is ``\"default\"``, one reserved conv layer will be utilized.\n                When ``conv_block`` is an ``nn.module``,\n                please ensure the input number of channels matches requirements.\n            bias: whether to have a bias term in the default conv_block. Defaults to True.\n        \"\"\"\n        super().__init__()\n\n        if scale_factor <= 0:\n            raise ValueError(f\"The `scale_factor` multiplier must be an integer greater than 0, got {scale_factor}.\")\n\n        self.dimensions = spatial_dims\n        self.scale_factor = scale_factor\n\n        if conv_block == \"default\":\n            if not in_channels:\n                raise ValueError(\"in_channels need to be specified.\")\n            out_channels = out_channels or in_channels\n            self.conv_block = Conv[Conv.CONV, self.dimensions](\n                in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=bias\n            )\n        elif conv_block is None:\n            self.conv_block = nn.Identity()\n        else:\n            self.conv_block = conv_block\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...).\n        Returns:\n            Tensor with reduced spatial dimensions and increased channel depth.\n        \"\"\"\n        x = self.conv_block(x)\n        if not all(d % self.scale_factor == 0 for d in 
x.shape[2:]):\n            raise ValueError(\n                f\"All spatial dimensions {x.shape[2:]} must be evenly \" f\"divisible by scale_factor {self.scale_factor}\"\n            )\n        x = pixelunshuffle(x, self.dimensions, self.scale_factor)\n        return x\n\n\nDownsample = DownSample\nSubpixelDownSample = Subpixeldownsample = SubpixelDownsample\n"
  },
  {
    "path": "monai/networks/blocks/dynunet_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.convolutions import Convolution\nfrom monai.networks.layers.factories import Act, Norm\nfrom monai.networks.layers.utils import get_act_layer, get_norm_layer\n\n\nclass UnetResBlock(nn.Module):\n    \"\"\"\n    A skip-connection based module that can be used for DynUNet, based on:\n    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.\n    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        kernel_size: convolution kernel size.\n        stride: convolution stride.\n        norm_name: feature normalization type and arguments.\n        act_name: activation layer type and arguments.\n        dropout: dropout probability.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Sequence[int] | int,\n        stride: Sequence[int] | int,\n        norm_name: tuple | str,\n        act_name: tuple | str = (\"leakyrelu\", 
{\"inplace\": True, \"negative_slope\": 0.01}),\n        dropout: tuple | str | float | None = None,\n    ):\n        super().__init__()\n        self.conv1 = get_conv_layer(\n            spatial_dims,\n            in_channels,\n            out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            dropout=dropout,\n            act=None,\n            norm=None,\n            conv_only=False,\n        )\n        self.conv2 = get_conv_layer(\n            spatial_dims,\n            out_channels,\n            out_channels,\n            kernel_size=kernel_size,\n            stride=1,\n            dropout=dropout,\n            act=None,\n            norm=None,\n            conv_only=False,\n        )\n        self.lrelu = get_act_layer(name=act_name)\n        self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)\n        self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)\n        self.downsample = in_channels != out_channels\n        stride_np = np.atleast_1d(stride)\n        if not np.all(stride_np == 1):\n            self.downsample = True\n        if self.downsample:\n            self.conv3 = get_conv_layer(\n                spatial_dims,\n                in_channels,\n                out_channels,\n                kernel_size=1,\n                stride=stride,\n                dropout=dropout,\n                act=None,\n                norm=None,\n                conv_only=False,\n            )\n            self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)\n\n    def forward(self, inp):\n        residual = inp\n        out = self.conv1(inp)\n        out = self.norm1(out)\n        out = self.lrelu(out)\n        out = self.conv2(out)\n        out = self.norm2(out)\n        if hasattr(self, \"conv3\"):\n            residual = self.conv3(residual)\n        if hasattr(self, \"norm3\"):\n            residual = 
self.norm3(residual)\n        out += residual\n        out = self.lrelu(out)\n        return out\n\n\nclass UnetBasicBlock(nn.Module):\n    \"\"\"\n    A CNN module that can be used for DynUNet, based on:\n    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.\n    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        kernel_size: convolution kernel size.\n        stride: convolution stride.\n        norm_name: feature normalization type and arguments.\n        act_name: activation layer type and arguments.\n        dropout: dropout probability.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Sequence[int] | int,\n        stride: Sequence[int] | int,\n        norm_name: tuple | str,\n        act_name: tuple | str = (\"leakyrelu\", {\"inplace\": True, \"negative_slope\": 0.01}),\n        dropout: tuple | str | float | None = None,\n    ):\n        super().__init__()\n        self.conv1 = get_conv_layer(\n            spatial_dims,\n            in_channels,\n            out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            dropout=dropout,\n            act=None,\n            norm=None,\n            conv_only=False,\n        )\n        self.conv2 = get_conv_layer(\n            spatial_dims,\n            out_channels,\n            out_channels,\n            kernel_size=kernel_size,\n            stride=1,\n            dropout=dropout,\n            act=None,\n            norm=None,\n            conv_only=False,\n        )\n        self.lrelu = get_act_layer(name=act_name)\n        self.norm1 = get_norm_layer(name=norm_name, 
spatial_dims=spatial_dims, channels=out_channels)\n        self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)\n\n    def forward(self, inp):\n        out = self.conv1(inp)\n        out = self.norm1(out)\n        out = self.lrelu(out)\n        out = self.conv2(out)\n        out = self.norm2(out)\n        out = self.lrelu(out)\n        return out\n\n\nclass UnetUpBlock(nn.Module):\n    \"\"\"\n    An upsampling module that can be used for DynUNet, based on:\n    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.\n    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        kernel_size: convolution kernel size.\n        stride: convolution stride.\n        upsample_kernel_size: convolution kernel size for transposed convolution layers.\n        norm_name: feature normalization type and arguments.\n        act_name: activation layer type and arguments.\n        dropout: dropout probability.\n        trans_bias: transposed convolution bias.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Sequence[int] | int,\n        stride: Sequence[int] | int,\n        upsample_kernel_size: Sequence[int] | int,\n        norm_name: tuple | str,\n        act_name: tuple | str = (\"leakyrelu\", {\"inplace\": True, \"negative_slope\": 0.01}),\n        dropout: tuple | str | float | None = None,\n        trans_bias: bool = False,\n    ):\n        super().__init__()\n        upsample_stride = upsample_kernel_size\n        self.transp_conv = get_conv_layer(\n            spatial_dims,\n            in_channels,\n            out_channels,\n            
kernel_size=upsample_kernel_size,\n            stride=upsample_stride,\n            dropout=dropout,\n            bias=trans_bias,\n            act=None,\n            norm=None,\n            conv_only=False,\n            is_transposed=True,\n        )\n        self.conv_block = UnetBasicBlock(\n            spatial_dims,\n            out_channels + out_channels,\n            out_channels,\n            kernel_size=kernel_size,\n            stride=1,\n            dropout=dropout,\n            norm_name=norm_name,\n            act_name=act_name,\n        )\n\n    def forward(self, inp, skip):\n        # number of channels for skip should equals to out_channels\n        out = self.transp_conv(inp)\n        out = torch.cat((out, skip), dim=1)\n        out = self.conv_block(out)\n        return out\n\n\nclass UnetOutBlock(nn.Module):\n\n    def __init__(\n        self, spatial_dims: int, in_channels: int, out_channels: int, dropout: tuple | str | float | None = None\n    ):\n        super().__init__()\n        self.conv = get_conv_layer(\n            spatial_dims,\n            in_channels,\n            out_channels,\n            kernel_size=1,\n            stride=1,\n            dropout=dropout,\n            bias=True,\n            act=None,\n            norm=None,\n            conv_only=False,\n        )\n\n    def forward(self, inp):\n        return self.conv(inp)\n\n\ndef get_conv_layer(\n    spatial_dims: int,\n    in_channels: int,\n    out_channels: int,\n    kernel_size: Sequence[int] | int = 3,\n    stride: Sequence[int] | int = 1,\n    act: tuple | str | None = Act.PRELU,\n    norm: tuple | str | None = Norm.INSTANCE,\n    dropout: tuple | str | float | None = None,\n    bias: bool = False,\n    conv_only: bool = True,\n    is_transposed: bool = False,\n):\n    padding = get_padding(kernel_size, stride)\n    output_padding = None\n    if is_transposed:\n        output_padding = get_output_padding(kernel_size, stride, padding)\n    return Convolution(\n        
spatial_dims,\n        in_channels,\n        out_channels,\n        strides=stride,\n        kernel_size=kernel_size,\n        act=act,\n        norm=norm,\n        dropout=dropout,\n        bias=bias,\n        conv_only=conv_only,\n        is_transposed=is_transposed,\n        padding=padding,\n        output_padding=output_padding,\n    )\n\n\ndef get_padding(kernel_size: Sequence[int] | int, stride: Sequence[int] | int) -> tuple[int, ...] | int:\n    kernel_size_np = np.atleast_1d(kernel_size)\n    stride_np = np.atleast_1d(stride)\n    padding_np = (kernel_size_np - stride_np + 1) / 2\n    if np.min(padding_np) < 0:\n        raise AssertionError(\"padding value should not be negative, please change the kernel size and/or stride.\")\n    padding = tuple(int(p) for p in padding_np)\n\n    return padding if len(padding) > 1 else padding[0]\n\n\ndef get_output_padding(\n    kernel_size: Sequence[int] | int, stride: Sequence[int] | int, padding: Sequence[int] | int\n) -> tuple[int, ...] | int:\n    kernel_size_np = np.atleast_1d(kernel_size)\n    stride_np = np.atleast_1d(stride)\n    padding_np = np.atleast_1d(padding)\n\n    out_padding_np = 2 * padding_np + stride_np - kernel_size_np\n    if np.min(out_padding_np) < 0:\n        raise AssertionError(\"out_padding value should not be negative, please change the kernel size and/or stride.\")\n    out_padding = tuple(int(p) for p in out_padding_np)\n\n    return out_padding if len(out_padding) > 1 else out_padding[0]\n"
  },
  {
    "path": "monai/networks/blocks/encoder.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom abc import ABCMeta, abstractmethod\n\n__all__ = [\"BaseEncoder\"]\n\n\nclass BaseEncoder(metaclass=ABCMeta):\n    \"\"\"\n    Abstract class defines interface of encoders in flexible unet.\n    Encoders in flexible unet must derive from this class. Each interface method\n    should return a list containing relative information about a series of networks\n    defined by encoder. For example, the efficient-net encoder implement 10 basic\n    network structures in one encoder. When calling `get_encoder_name_string_list`\n    function, a string list like [\"efficientnet-b0\", \"efficientnet-b1\" ... \"efficientnet-l2\"]\n    should be returned.\n    \"\"\"\n\n    @classmethod\n    @abstractmethod\n    def get_encoder_parameters(cls) -> list[dict]:\n        \"\"\"\n        Get parameter list to initialize encoder networks.\n        Each parameter dict must have `spatial_dims`, `in_channels`\n        and `pretrained` parameters.\n        The reason that this function should return a list is that a\n        series of encoders can be implemented by one encoder class\n        given different initialization parameters. 
Each parameter dict\n        in return list should be able to initialize a unique encoder.\n        \"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    @abstractmethod\n    def num_channels_per_output(cls) -> list[tuple[int, ...]]:\n        \"\"\"\n        Get number of output features' channels.\n        The reason that this function should return a list is that a\n        series of encoders can be implemented by one encoder class\n        given different initialization parameters. And it is possible\n        that different encoders have different output feature map\n        channels. Therefore a list of output feature map channel tuples\n        corresponding to each encoder should be returned by this method.\n        \"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    @abstractmethod\n    def num_outputs(cls) -> list[int]:\n        \"\"\"\n        Get number of outputs of encoder.\n        The reason that this function should return a list is that a\n        series of encoders can be implemented by one encoder class\n        given different initialization parameters. And it is possible\n        that different encoders have different output feature numbers.\n        Therefore a list of output feature numbers corresponding to\n        each encoder should be returned by this method.\n        \"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    @abstractmethod\n    def get_encoder_names(cls) -> list[str]:\n        \"\"\"\n        Get the name string of encoders which will be used to initialize\n        flexible unet.\n        The reason that this function should return a list is that a\n        series of encoders can be implemented by one encoder class\n        given different initialization parameters. 
And a name string is\n        the key to each encoder in flexible unet backbone registry.\n        Therefore this method should return every encoder name that needs\n        to be registered in flexible unet.\n        \"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "monai/networks/blocks/fcn.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.blocks.convolutions import Convolution\nfrom monai.networks.blocks.upsample import UpSample\nfrom monai.networks.layers.factories import Act, Conv, Norm\nfrom monai.utils import optional_import\n\nmodels, _ = optional_import(\"torchvision\", name=\"models\")\n\n\nclass GCN(nn.Module):\n    \"\"\"\n    The Global Convolutional Network module using large 1D\n    Kx1 and 1xK kernels to represent 2D kernels.\n    \"\"\"\n\n    def __init__(self, inplanes: int, planes: int, ks: int = 7):\n        \"\"\"\n        Args:\n            inplanes: number of input channels.\n            planes: number of output channels.\n            ks: kernel size for one dimension. 
Defaults to 7.\n        \"\"\"\n        super().__init__()\n\n        conv2d_type: type[nn.Conv2d] = Conv[Conv.CONV, 2]\n        self.conv_l1 = conv2d_type(in_channels=inplanes, out_channels=planes, kernel_size=(ks, 1), padding=(ks // 2, 0))\n        self.conv_l2 = conv2d_type(in_channels=planes, out_channels=planes, kernel_size=(1, ks), padding=(0, ks // 2))\n        self.conv_r1 = conv2d_type(in_channels=inplanes, out_channels=planes, kernel_size=(1, ks), padding=(0, ks // 2))\n        self.conv_r2 = conv2d_type(in_channels=planes, out_channels=planes, kernel_size=(ks, 1), padding=(ks // 2, 0))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: in shape (batch, inplanes, spatial_1, spatial_2).\n        \"\"\"\n        x_l = self.conv_l1(x)\n        x_l = self.conv_l2(x_l)\n        x_r = self.conv_r1(x)\n        x_r = self.conv_r2(x_r)\n        x = x_l + x_r\n        return x\n\n\nclass Refine(nn.Module):\n    \"\"\"\n    Simple residual block to refine the details of the activation maps.\n    \"\"\"\n\n    def __init__(self, planes: int):\n        \"\"\"\n        Args:\n            planes: number of input channels.\n        \"\"\"\n        super().__init__()\n\n        relu_type: type[nn.ReLU] = Act[Act.RELU]\n        conv2d_type: type[nn.Conv2d] = Conv[Conv.CONV, 2]\n        norm2d_type: type[nn.BatchNorm2d] = Norm[Norm.BATCH, 2]\n\n        self.bn = norm2d_type(num_features=planes)\n        self.relu = relu_type(inplace=True)\n        self.conv1 = conv2d_type(in_channels=planes, out_channels=planes, kernel_size=3, padding=1)\n        self.conv2 = conv2d_type(in_channels=planes, out_channels=planes, kernel_size=3, padding=1)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: in shape (batch, planes, spatial_1, spatial_2).\n        \"\"\"\n        residual = x\n        x = self.bn(x)\n        x = self.relu(x)\n        x = self.conv1(x)\n        x = 
self.bn(x)\n        x = self.relu(x)\n        x = self.conv2(x)\n\n        return residual + x\n\n\nclass FCN(nn.Module):\n    \"\"\"\n    2D FCN network with 3 input channels. The small decoder is built\n    with the GCN and Refine modules.\n    The code is adapted from `lsqshr's official 2D code <https://github.com/lsqshr/AH-Net/blob/master/net2d.py>`_.\n\n    Args:\n        out_channels: number of output channels. Defaults to 1.\n        upsample_mode: [``\"transpose\"``, ``\"bilinear\"``]\n            The mode of upsampling manipulations.\n            Using the second mode cannot guarantee the model's reproducibility. Defaults to ``bilinear``.\n\n            - ``transpose``, uses transposed convolution layers.\n            - ``bilinear``, uses bilinear interpolation.\n\n        pretrained: If True, returns a model pre-trained on ImageNet\n        progress: If True, displays a progress bar of the download to stderr.\n    \"\"\"\n\n    def __init__(\n        self, out_channels: int = 1, upsample_mode: str = \"bilinear\", pretrained: bool = True, progress: bool = True\n    ):\n        super().__init__()\n\n        conv2d_type: type[nn.Conv2d] = Conv[Conv.CONV, 2]\n\n        self.upsample_mode = upsample_mode\n        self.conv2d_type = conv2d_type\n        self.out_channels = out_channels\n        resnet = models.resnet50(\n            progress=progress, weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None\n        )\n\n        self.conv1 = resnet.conv1\n        self.bn0 = resnet.bn1\n        self.relu = resnet.relu\n        self.maxpool = resnet.maxpool\n\n        self.layer1 = resnet.layer1\n        self.layer2 = resnet.layer2\n        self.layer3 = resnet.layer3\n        self.layer4 = resnet.layer4\n\n        self.gcn1 = GCN(2048, self.out_channels)\n        self.gcn2 = GCN(1024, self.out_channels)\n        self.gcn3 = GCN(512, self.out_channels)\n        self.gcn4 = GCN(64, self.out_channels)\n        self.gcn5 = GCN(64, self.out_channels)\n\n 
       self.refine1 = Refine(self.out_channels)\n        self.refine2 = Refine(self.out_channels)\n        self.refine3 = Refine(self.out_channels)\n        self.refine4 = Refine(self.out_channels)\n        self.refine5 = Refine(self.out_channels)\n        self.refine6 = Refine(self.out_channels)\n        self.refine7 = Refine(self.out_channels)\n        self.refine8 = Refine(self.out_channels)\n        self.refine9 = Refine(self.out_channels)\n        self.refine10 = Refine(self.out_channels)\n        self.transformer = self.conv2d_type(in_channels=256, out_channels=64, kernel_size=1)\n\n        if self.upsample_mode == \"transpose\":\n            self.up_conv = UpSample(spatial_dims=2, in_channels=self.out_channels, scale_factor=2, mode=\"deconv\")\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"\n        Args:\n            x: in shape (batch, 3, spatial_1, spatial_2).\n        \"\"\"\n        org_input = x\n        x = self.conv1(x)\n        x = self.bn0(x)\n        x = self.relu(x)\n        conv_x = x\n        x = self.maxpool(x)\n        pool_x = x\n\n        fm1 = self.layer1(x)\n        fm2 = self.layer2(fm1)\n        fm3 = self.layer3(fm2)\n        fm4 = self.layer4(fm3)\n\n        gcfm1 = self.refine1(self.gcn1(fm4))\n        gcfm2 = self.refine2(self.gcn2(fm3))\n        gcfm3 = self.refine3(self.gcn3(fm2))\n        gcfm4 = self.refine4(self.gcn4(pool_x))\n        gcfm5 = self.refine5(self.gcn5(conv_x))\n\n        if self.upsample_mode == \"transpose\":\n            fs1 = self.refine6(self.up_conv(gcfm1) + gcfm2)\n            fs2 = self.refine7(self.up_conv(fs1) + gcfm3)\n            fs3 = self.refine8(self.up_conv(fs2) + gcfm4)\n            fs4 = self.refine9(self.up_conv(fs3) + gcfm5)\n            return self.refine10(self.up_conv(fs4))\n        fs1 = self.refine6(F.interpolate(gcfm1, fm3.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm2)\n        fs2 = self.refine7(F.interpolate(fs1, fm2.size()[2:], mode=self.upsample_mode, 
align_corners=True) + gcfm3)\n        fs3 = self.refine8(F.interpolate(fs2, pool_x.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm4)\n        fs4 = self.refine9(F.interpolate(fs3, conv_x.size()[2:], mode=self.upsample_mode, align_corners=True) + gcfm5)\n        return self.refine10(F.interpolate(fs4, org_input.size()[2:], mode=self.upsample_mode, align_corners=True))\n\n\nclass MCFCN(FCN):\n    \"\"\"\n    The multi-channel version of the 2D FCN module.\n    Adds a projection layer to take arbitrary number of inputs.\n\n    Args:\n        in_channels: number of input channels. Defaults to 3.\n        out_channels: number of output channels. Defaults to 1.\n        upsample_mode: [``\"transpose\"``, ``\"bilinear\"``]\n            The mode of upsampling manipulations.\n            Using the second mode cannot guarantee the model's reproducibility. Defaults to ``bilinear``.\n\n            - ``transpose``, uses transposed convolution layers.\n            - ``bilinear``, uses bilinear interpolate.\n        pretrained: If True, returns a model pre-trained on ImageNet\n        progress: If True, displays a progress bar of the download to stderr.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int = 3,\n        out_channels: int = 1,\n        upsample_mode: str = \"bilinear\",\n        pretrained: bool = True,\n        progress: bool = True,\n    ):\n        super().__init__(\n            out_channels=out_channels, upsample_mode=upsample_mode, pretrained=pretrained, progress=progress\n        )\n\n        self.init_proj = Convolution(\n            spatial_dims=2,\n            in_channels=in_channels,\n            out_channels=3,\n            kernel_size=1,\n            act=(\"relu\", {\"inplace\": True}),\n            norm=Norm.BATCH,\n            bias=False,\n        )\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"\n        Args:\n            x: in shape (batch, in_channels, spatial_1, spatial_2).\n        \"\"\"\n        
x = self.init_proj(x)\n        return super().forward(x)\n"
  },
  {
    "path": "monai/networks/blocks/feature_pyramid_network.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/pytorch/vision/blob/release/0.12/torchvision/ops/feature_pyramid_network.py\n# which has the following license...\n# https://github.com/pytorch/vision/blob/main/LICENSE\n#\n# BSD 3-Clause License\n\n# Copyright (c) Soumith Chintala 2016,\n# All rights reserved.\n\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n\n# * Redistributions of source code must retain the above copyright notice, this\n#   list of conditions and the following disclaimer.\n\n# * Redistributions in binary form must reproduce the above copyright notice,\n#   this list of conditions and the following disclaimer in the documentation\n#   and/or other materials provided with the distribution.\n\n# * Neither the name of the copyright holder nor the names of its\n#   contributors may be used to endorse or promote products derived from\n#   this software without specific prior written permission.\n\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\"\"\"\nThis script is modified from torchvision to support N-D images,\nby overriding the definition of convolutional layers and pooling layers.\n\nhttps://github.com/pytorch/vision/blob/release/0.12/torchvision/ops/feature_pyramid_network.py\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections import OrderedDict\nfrom collections.abc import Callable\nfrom typing import cast\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor, nn\n\nfrom monai.networks.layers.factories import Conv, Pool\n\n__all__ = [\"ExtraFPNBlock\", \"LastLevelMaxPool\", \"LastLevelP6P7\", \"FeaturePyramidNetwork\"]\n\n\nclass ExtraFPNBlock(nn.Module):\n    \"\"\"\n    Base class for the extra block in the FPN.\n\n    Same code as https://github.com/pytorch/vision/blob/release/0.12/torchvision/ops/feature_pyramid_network.py\n    \"\"\"\n\n    def forward(self, results: list[Tensor], x: list[Tensor], names: list[str]):\n        \"\"\"\n        Compute extended set of results of the FPN and their names.\n\n        Args:\n            results: the result of the FPN\n            x: the original feature maps\n            names: the names for each one of the original feature maps\n\n        Returns:\n            - the extended set of results of the FPN\n            - the extended set of names for the results\n        \"\"\"\n\n\nclass LastLevelMaxPool(ExtraFPNBlock):\n    \"\"\"\n    Applies a max_pool2d or max_pool3d on top of the last 
feature map. Serves as an ``extra_blocks``\n    in :class:`~monai.networks.blocks.feature_pyramid_network.FeaturePyramidNetwork` .\n    \"\"\"\n\n    def __init__(self, spatial_dims: int):\n        super().__init__()\n        pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]\n        self.maxpool = pool_type(kernel_size=1, stride=2, padding=0)\n\n    def forward(self, results: list[Tensor], x: list[Tensor], names: list[str]) -> tuple[list[Tensor], list[str]]:\n        names.append(\"pool\")\n        results.append(self.maxpool(results[-1]))\n        return results, names\n\n\nclass LastLevelP6P7(ExtraFPNBlock):\n    \"\"\"\n    This module is used in RetinaNet to generate extra layers, P6 and P7.\n    Serves as an ``extra_blocks``\n    in :class:`~monai.networks.blocks.feature_pyramid_network.FeaturePyramidNetwork` .\n    \"\"\"\n\n    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int):\n        super().__init__()\n        conv_type: Callable = Conv[Conv.CONV, spatial_dims]\n        self.p6 = conv_type(in_channels, out_channels, kernel_size=3, stride=2, padding=1)\n        self.p7 = conv_type(out_channels, out_channels, kernel_size=3, stride=2, padding=1)\n        for module in [self.p6, self.p7]:\n            nn.init.kaiming_uniform_(module.weight, a=1)\n            nn.init.constant_(module.bias, 0)\n        self.use_P5 = in_channels == out_channels\n\n    def forward(self, results: list[Tensor], x: list[Tensor], names: list[str]) -> tuple[list[Tensor], list[str]]:\n        p5, c5 = results[-1], x[-1]\n        x5 = p5 if self.use_P5 else c5\n        p6 = self.p6(x5)\n        p7 = self.p7(F.relu(p6))\n        results.extend([p6, p7])\n        names.extend([\"p6\", \"p7\"])\n        return results, names\n\n\nclass FeaturePyramidNetwork(nn.Module):\n    \"\"\"\n    Module that adds a FPN from on top of a set of feature maps. 
This is based on\n    `\"Feature Pyramid Network for Object Detection\" <https://arxiv.org/abs/1612.03144>`_.\n\n    The feature maps are currently supposed to be in increasing depth\n    order.\n\n    The input to the model is expected to be an OrderedDict[Tensor], containing\n    the feature maps on top of which the FPN will be added.\n\n    Args:\n        spatial_dims: 2D or 3D images\n        in_channels_list: number of channels for each feature map that\n            is passed to the module\n        out_channels: number of channels of the FPN representation\n        extra_blocks: if provided, extra operations will\n            be performed. It is expected to take the fpn features, the original\n            features and the names of the original features as input, and returns\n            a new list of feature maps and their corresponding names\n\n    Examples::\n\n        >>> m = FeaturePyramidNetwork(2, [10, 20, 30], 5)\n        >>> # get some dummy data\n        >>> x = OrderedDict()\n        >>> x['feat0'] = torch.rand(1, 10, 64, 64)\n        >>> x['feat2'] = torch.rand(1, 20, 16, 16)\n        >>> x['feat3'] = torch.rand(1, 30, 8, 8)\n        >>> # compute the FPN on top of x\n        >>> output = m(x)\n        >>> print([(k, v.shape) for k, v in output.items()])\n        >>> # returns\n        >>>   [('feat0', torch.Size([1, 5, 64, 64])),\n        >>>    ('feat2', torch.Size([1, 5, 16, 16])),\n        >>>    ('feat3', torch.Size([1, 5, 8, 8]))]\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels_list: list[int],\n        out_channels: int,\n        extra_blocks: ExtraFPNBlock | None = None,\n    ):\n        super().__init__()\n\n        conv_type: Callable = Conv[Conv.CONV, spatial_dims]\n\n        self.inner_blocks = nn.ModuleList()\n        self.layer_blocks = nn.ModuleList()\n        for in_channels in in_channels_list:\n            if in_channels == 0:\n                raise ValueError(\"in_channels=0 is 
currently not supported\")\n            inner_block_module = conv_type(in_channels, out_channels, 1)\n            layer_block_module = conv_type(out_channels, out_channels, 3, padding=1)\n            self.inner_blocks.append(inner_block_module)\n            self.layer_blocks.append(layer_block_module)\n\n        # initialize parameters now to avoid modifying the initialization of top_blocks\n        conv_type_: type[nn.Module] = Conv[Conv.CONV, spatial_dims]\n        for m in self.modules():\n            if isinstance(m, conv_type_):\n                nn.init.kaiming_uniform_(cast(torch.Tensor, m.weight), a=1)\n                nn.init.constant_(cast(torch.Tensor, m.bias), 0.0)\n\n        if extra_blocks is not None:\n            if not isinstance(extra_blocks, ExtraFPNBlock):\n                raise AssertionError\n        self.extra_blocks = extra_blocks\n\n    def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:\n        \"\"\"\n        This is equivalent to self.inner_blocks[idx](x),\n        but torchscript doesn't support this yet\n        \"\"\"\n        num_blocks = len(self.inner_blocks)\n        if idx < 0:\n            idx += num_blocks\n        out = x\n        for i, module in enumerate(self.inner_blocks):\n            if i == idx:\n                out = module(x)\n        return out\n\n    def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:\n        \"\"\"\n        This is equivalent to self.layer_blocks[idx](x),\n        but torchscript doesn't support this yet\n        \"\"\"\n        num_blocks = len(self.layer_blocks)\n        if idx < 0:\n            idx += num_blocks\n        out = x\n        for i, module in enumerate(self.layer_blocks):\n            if i == idx:\n                out = module(x)\n        return out\n\n    def forward(self, x: dict[str, Tensor]) -> dict[str, Tensor]:\n        \"\"\"\n        Computes the FPN for a set of feature maps.\n\n        Args:\n            x: feature maps for each 
feature level.\n\n        Returns:\n            feature maps after FPN layers. They are ordered from highest resolution first.\n        \"\"\"\n        # unpack OrderedDict into two lists for easier handling\n        names = list(x.keys())\n        x_values: list[Tensor] = list(x.values())\n\n        last_inner = self.get_result_from_inner_blocks(x_values[-1], -1)\n        results = []\n        results.append(self.get_result_from_layer_blocks(last_inner, -1))\n\n        for idx in range(len(x_values) - 2, -1, -1):\n            inner_lateral = self.get_result_from_inner_blocks(x_values[idx], idx)\n            feat_shape = inner_lateral.shape[2:]\n            inner_top_down = F.interpolate(last_inner, size=feat_shape, mode=\"nearest\")\n            last_inner = inner_lateral + inner_top_down\n            results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))\n\n        if self.extra_blocks is not None:\n            results, names = self.extra_blocks(results, x_values, names)\n\n        # make it back an OrderedDict\n        out = OrderedDict(list(zip(names, results)))\n\n        return out\n"
  },
  {
    "path": "monai/networks/blocks/fft_utils_t.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch import Tensor\n\n\ndef roll_1d(x: Tensor, shift: int, shift_dim: int) -> Tensor:\n    \"\"\"\n    Similar to roll but for only one dim.\n\n    Args:\n        x: input data (k-space or image) that can be\n            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or\n            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.\n        shift: the amount of shift along each of shift_dims dimension\n        shift_dim: the dimension over which the shift is applied\n\n    Returns:\n        1d-shifted version of x\n    \"\"\"\n    shift = shift % x.size(shift_dim)\n    if shift == 0:\n        return x\n\n    left = x.narrow(shift_dim, 0, x.size(shift_dim) - shift)\n    right = x.narrow(shift_dim, x.size(shift_dim) - shift, shift)\n\n    return torch.cat((right, left), dim=shift_dim)\n\n\ndef roll(x: Tensor, shift: list[int], shift_dims: list[int]) -> Tensor:\n    \"\"\"\n    Similar to np.roll but applies to PyTorch Tensors\n\n    Args:\n        x: input data (k-space or image) that can be\n            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or\n            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. 
C is the number of channels.\n        shift: the amount of shift along each of shift_dims dimensions\n        shift_dims: dimensions over which the shift is applied\n\n    Returns:\n        shifted version of x\n    \"\"\"\n    if len(shift) != len(shift_dims):\n        raise ValueError(f\"len(shift) != len(shift_dims), got f{len(shift)} and f{len(shift_dims)}.\")\n    for s, d in zip(shift, shift_dims):\n        x = roll_1d(x, s, d)\n    return x\n\n\ndef fftshift(x: Tensor, shift_dims: list[int]) -> Tensor:\n    \"\"\"\n    Similar to np.fft.fftshift but applies to PyTorch Tensors\n\n    Args:\n        x: input data (k-space or image) that can be\n            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or\n            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.\n        shift_dims: dimensions over which the shift is applied\n\n    Returns:\n        fft-shifted version of x\n    \"\"\"\n    shift = [0] * len(shift_dims)\n    for i, dim_num in enumerate(shift_dims):\n        shift[i] = x.shape[dim_num] // 2\n    return roll(x, shift, shift_dims)\n\n\ndef ifftshift(x: Tensor, shift_dims: list[int]) -> Tensor:\n    \"\"\"\n    Similar to np.fft.ifftshift but applies to PyTorch Tensors\n\n    Args:\n        x: input data (k-space or image) that can be\n            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or\n            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. 
C is the number of channels.\n        shift_dims: dimensions over which the shift is applied\n\n    Returns:\n        ifft-shifted version of x\n    \"\"\"\n    shift = [0] * len(shift_dims)\n    for i, dim_num in enumerate(shift_dims):\n        shift[i] = (x.shape[dim_num] + 1) // 2\n    return roll(x, shift, shift_dims)\n\n\ndef ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) -> Tensor:\n    \"\"\"\n    Pytorch-based ifft for spatial_dims-dim signals. \"centered\" means this function automatically takes care\n    of the required ifft and fft shifts.\n    This is equivalent to doing an ifft in numpy based on numpy.fft.ifftn, numpy.fft.fftshift, and numpy.fft.ifftshift\n\n    Args:\n        ksp: k-space data that can be\n            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or\n            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.\n        spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume)\n        is_complex: if True, then the last dimension of the input ksp is expected to be 2 (representing real and imaginary channels)\n\n    Returns:\n        \"out\" which is the output image (inverse fourier of ksp)\n\n    Example:\n\n        .. 
code-block:: python\n\n            import torch\n            ksp = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts\n            # output1 and output2 will be identical\n            output1 = torch.fft.ifftn(torch.view_as_complex(torch.fft.ifftshift(ksp,dim=(-3,-2))), dim=(-2,-1), norm=\"ortho\")\n            output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) )\n\n            output2 = ifftn_centered_t(ksp, spatial_dims=2, is_complex=True)\n    \"\"\"\n    # define spatial dims to perform ifftshift, fftshift, and ifft\n    dims = list(range(-spatial_dims, 0))\n    if is_complex:\n        if ksp.shape[-1] != 2:\n            raise ValueError(f\"ksp.shape[-1] is not 2 ({ksp.shape[-1]}).\")\n        x = torch.view_as_complex(ifftshift(ksp, [d - 1 for d in dims]))\n    else:\n        x = ifftshift(ksp, dims)\n\n    x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm=\"ortho\"))\n\n    out: Tensor = fftshift(x, [d - 1 for d in dims])\n\n    return out\n\n\ndef fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> Tensor:\n    \"\"\"\n    Pytorch-based fft for spatial_dims-dim signals. \"centered\" means this function automatically takes care\n    of the required ifft and fft shifts.\n    This is equivalent to doing an fft in numpy based on numpy.fft.fftn, numpy.fft.fftshift, and numpy.fft.ifftshift\n\n    Args:\n        im: image that can be\n            1) real-valued: the shape is (C,H,W) for 2D spatial inputs and (C,H,W,D) for 3D, or\n            2) complex-valued: the shape is (C,H,W,2) for 2D spatial data and (C,H,W,D,2) for 3D. C is the number of channels.\n        spatial_dims: number of spatial dimensions (e.g., is 2 for an image, and is 3 for a volume)\n        is_complex: if True, then the last dimension of the input im is expected to be 2 (representing real and imaginary channels)\n\n    Returns:\n        \"out\" which is the output kspace (fourier of im)\n\n    Example:\n\n        .. 
code-block:: python\n\n            import torch\n            im = torch.ones(1,3,3,2) # the last dim belongs to real/imaginary parts\n            # output1 and output2 will be identical\n            output1 = torch.fft.fftn(torch.view_as_complex(torch.fft.ifftshift(im,dim=(-3,-2))), dim=(-2,-1), norm=\"ortho\")\n            output1 = torch.fft.fftshift( torch.view_as_real(output1), dim=(-3,-2) )\n\n            output2 = fftn_centered_t(im, spatial_dims=2, is_complex=True)\n    \"\"\"\n    # define spatial dims to perform ifftshift, fftshift, and fft\n    dims = list(range(-spatial_dims, 0))\n    if is_complex:\n        if im.shape[-1] != 2:\n            raise ValueError(f\"img.shape[-1] is not 2 ({im.shape[-1]}).\")\n        x = torch.view_as_complex(ifftshift(im, [d - 1 for d in dims]))\n    else:\n        x = ifftshift(im, dims)\n\n    x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm=\"ortho\"))\n\n    out: Tensor = fftshift(x, [d - 1 for d in dims])\n\n    return out\n"
  },
  {
    "path": "monai/networks/blocks/localnet_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom monai.networks.blocks import Convolution\nfrom monai.networks.layers import same_padding\nfrom monai.networks.layers.factories import Conv, Norm, Pool\n\n\ndef get_conv_block(\n    spatial_dims: int,\n    in_channels: int,\n    out_channels: int,\n    kernel_size: Sequence[int] | int = 3,\n    act: tuple | str | None = \"RELU\",\n    norm: tuple | str | None = \"BATCH\",\n) -> nn.Module:\n    padding = same_padding(kernel_size)\n    mod: nn.Module = Convolution(\n        spatial_dims,\n        in_channels,\n        out_channels,\n        kernel_size=kernel_size,\n        act=act,\n        norm=norm,\n        bias=False,\n        conv_only=False,\n        padding=padding,\n    )\n    return mod\n\n\ndef get_conv_layer(\n    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Sequence[int] | int = 3\n) -> nn.Module:\n    padding = same_padding(kernel_size)\n    mod: nn.Module = Convolution(\n        spatial_dims, in_channels, out_channels, kernel_size=kernel_size, bias=False, conv_only=True, padding=padding\n    )\n    return mod\n\n\ndef get_deconv_block(spatial_dims: int, in_channels: int, out_channels: int) -> nn.Module:\n    mod: nn.Module = Convolution(\n        spatial_dims=spatial_dims,\n       
 in_channels=in_channels,\n        out_channels=out_channels,\n        strides=2,\n        act=\"RELU\",\n        norm=\"BATCH\",\n        bias=False,\n        is_transposed=True,\n        padding=1,\n        output_padding=1,\n    )\n    return mod\n\n\nclass ResidualBlock(nn.Module):\n\n    def __init__(\n        self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Sequence[int] | int\n    ) -> None:\n        super().__init__()\n        if in_channels != out_channels:\n            raise ValueError(\n                f\"expecting in_channels == out_channels, \" f\"got in_channels={in_channels}, out_channels={out_channels}\"\n            )\n        self.conv_block = get_conv_block(\n            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size\n        )\n        self.conv = get_conv_layer(\n            spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size\n        )\n        self.norm = Norm[Norm.BATCH, spatial_dims](out_channels)\n        self.relu = nn.ReLU()\n\n    def forward(self, x) -> torch.Tensor:\n        out: torch.Tensor = self.relu(self.norm(self.conv(self.conv_block(x))) + x)\n        return out\n\n\nclass LocalNetResidualBlock(nn.Module):\n\n    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:\n        super().__init__()\n        if in_channels != out_channels:\n            raise ValueError(\n                f\"expecting in_channels == out_channels, \" f\"got in_channels={in_channels}, out_channels={out_channels}\"\n            )\n        self.conv_layer = get_conv_layer(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels)\n        self.norm = Norm[Norm.BATCH, spatial_dims](out_channels)\n        self.relu = nn.ReLU()\n\n    def forward(self, x, mid) -> torch.Tensor:\n        out: torch.Tensor = self.relu(self.norm(self.conv_layer(x)) + mid)\n        return 
out\n\n\nclass LocalNetDownSampleBlock(nn.Module):\n    \"\"\"\n    A down-sample module that can be used for LocalNet, based on:\n    `Weakly-supervised convolutional neural networks for multimodal image registration\n    <https://doi.org/10.1016/j.media.2018.07.002>`_.\n    `Label-driven weakly-supervised learning for multimodal deformable image registration\n    <https://arxiv.org/abs/1711.01666>`_.\n\n    Adapted from:\n        DeepReg (https://github.com/DeepRegNet/DeepReg)\n    \"\"\"\n\n    def __init__(\n        self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Sequence[int] | int\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions.\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            kernel_size: convolution kernel size.\n        Raises:\n            NotImplementedError: when ``kernel_size`` is even\n        \"\"\"\n        super().__init__()\n        self.conv_block = get_conv_block(\n            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size\n        )\n        self.residual_block = ResidualBlock(\n            spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size\n        )\n        self.max_pool = Pool[Pool.MAX, spatial_dims](kernel_size=2)\n\n    def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Halves the spatial dimensions.\n        A tuple of (x, mid) is returned:\n\n            -  x is the downsample result, in shape (batch, ``out_channels``, insize_1 / 2, insize_2 / 2, [insize_3 / 2]),\n            -  mid is the mid-level feature, in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3])\n\n        Args:\n            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])\n\n        Raises:\n            ValueError: when input spatial 
dimensions are not even.\n        \"\"\"\n        for i in x.shape[2:]:\n            if i % 2 != 0:\n                raise ValueError(\"expecting x spatial dimensions be even, \" f\"got x of shape {x.shape}\")\n        x = self.conv_block(x)\n        mid = self.residual_block(x)\n        x = self.max_pool(mid)\n        return x, mid\n\n\nclass LocalNetUpSampleBlock(nn.Module):\n    \"\"\"\n    An up-sample module that can be used for LocalNet, based on:\n    `Weakly-supervised convolutional neural networks for multimodal image registration\n    <https://doi.org/10.1016/j.media.2018.07.002>`_.\n    `Label-driven weakly-supervised learning for multimodal deformable image registration\n    <https://arxiv.org/abs/1711.01666>`_.\n\n    Adapted from:\n        DeepReg (https://github.com/DeepRegNet/DeepReg)\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        mode: str = \"nearest\",\n        align_corners: bool | None = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions.\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            mode: interpolation mode of the additive upsampling, default to 'nearest'.\n            align_corners: whether to align corners for the additive upsampling, default to None.\n        Raises:\n            ValueError: when ``in_channels != 2 * out_channels``\n        \"\"\"\n        super().__init__()\n        self.deconv_block = get_deconv_block(\n            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels\n        )\n        self.conv_block = get_conv_block(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels)\n        self.residual_block = LocalNetResidualBlock(\n            spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels\n        )\n        if in_channels / 
out_channels != 2:\n            raise ValueError(\n                f\"expecting in_channels == 2 * out_channels, \"\n                f\"got in_channels={in_channels}, out_channels={out_channels}\"\n            )\n        self.out_channels = out_channels\n        self.mode = mode\n        self.align_corners = align_corners\n\n    def additive_upsampling(self, x, mid) -> torch.Tensor:\n        x = F.interpolate(x, mid.shape[2:], mode=self.mode, align_corners=self.align_corners)\n        # [(batch, out_channels, ...), (batch, out_channels, ...)]\n        x = x.split(split_size=int(self.out_channels), dim=1)\n        # (batch, out_channels, ...)\n        out: torch.Tensor = torch.sum(torch.stack(x, dim=-1), dim=-1)\n        return out\n\n    def forward(self, x, mid) -> torch.Tensor:\n        \"\"\"\n        Halves the channel and doubles the spatial dimensions.\n\n        Args:\n            x: feature to be up-sampled, in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])\n            mid: mid-level feature saved during down-sampling,\n                in shape (batch, ``out_channels``, midsize_1, midsize_2, [midsize_3])\n\n        Raises:\n            ValueError: when ``midsize != insize * 2``\n        \"\"\"\n        for i, j in zip(x.shape[2:], mid.shape[2:]):\n            if j != 2 * i:\n                raise ValueError(\n                    \"expecting mid spatial dimensions be exactly the double of x spatial dimensions, \"\n                    f\"got x of shape {x.shape}, mid of shape {mid.shape}\"\n                )\n        h0 = self.deconv_block(x) + self.additive_upsampling(x, mid)\n        r1 = h0 + mid\n        r2 = self.conv_block(h0)\n        out: torch.Tensor = self.residual_block(r2, r1)\n        return out\n\n\nclass LocalNetFeatureExtractorBlock(nn.Module):\n    \"\"\"\n    A feature-extraction module that can be used for LocalNet, based on:\n    `Weakly-supervised convolutional neural networks for multimodal image registration\n    
<https://doi.org/10.1016/j.media.2018.07.002>`_.\n    `Label-driven weakly-supervised learning for multimodal deformable image registration\n    <https://arxiv.org/abs/1711.01666>`_.\n\n    Adapted from:\n        DeepReg (https://github.com/DeepRegNet/DeepReg)\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        act: tuple | str | None = \"RELU\",\n        initializer: str = \"kaiming_uniform\",\n    ) -> None:\n        \"\"\"\n        Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        act: activation type and arguments. Defaults to ReLU.\n        initializer: kernel initializer, either ``\"kaiming_uniform\"`` or ``\"zeros\"``. Defaults to ``\"kaiming_uniform\"``.\n        \"\"\"\n        super().__init__()\n        self.conv_block = get_conv_block(\n            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, act=act, norm=None\n        )\n        conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]\n        for m in self.conv_block.modules():\n            if isinstance(m, conv_type):\n                if initializer == \"kaiming_uniform\":\n                    nn.init.kaiming_normal_(torch.as_tensor(m.weight))\n                elif initializer == \"zeros\":\n                    nn.init.zeros_(torch.as_tensor(m.weight))\n                else:\n                    raise ValueError(\n                        f\"initializer {initializer} is not supported, \" \"currently supporting kaiming_uniform and zeros\"\n                    )\n\n    def forward(self, x) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])\n        \"\"\"\n        out: torch.Tensor = self.conv_block(x)\n        return out\n"
  },
  {
    "path": "monai/networks/blocks/mednext_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Portions of this code are derived from the original repository at:\n# https://github.com/MIC-DKFZ/MedNeXt\n# and are used under the terms of the Apache License, Version 2.0.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\n\nall = [\"MedNeXtBlock\", \"MedNeXtDownBlock\", \"MedNeXtUpBlock\", \"MedNeXtOutBlock\"]\n\n\ndef get_conv_layer(spatial_dim: int = 3, transpose: bool = False):\n    if spatial_dim == 2:\n        return nn.ConvTranspose2d if transpose else nn.Conv2d\n    else:  # spatial_dim == 3\n        return nn.ConvTranspose3d if transpose else nn.Conv3d\n\n\nclass MedNeXtBlock(nn.Module):\n    \"\"\"\n    MedNeXtBlock class for the MedNeXt model.\n\n    Args:\n        in_channels (int): Number of input channels.\n        out_channels (int): Number of output channels.\n        expansion_ratio (int): Expansion ratio for the block. Defaults to 4.\n        kernel_size (int): Kernel size for convolutions. Defaults to 7.\n        use_residual_connection (int): Whether to use residual connection. Defaults to True.\n        norm_type (str): Type of normalization to use. Defaults to \"group\".\n        dim (str): Dimension of the input. Can be \"2d\" or \"3d\". Defaults to \"3d\".\n        global_resp_norm (bool): Whether to use global response normalization. 
Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        expansion_ratio: int = 4,\n        kernel_size: int = 7,\n        use_residual_connection: int = True,\n        norm_type: str = \"group\",\n        dim=\"3d\",\n        global_resp_norm=False,\n    ):\n\n        super().__init__()\n\n        self.do_res = use_residual_connection\n\n        self.dim = dim\n        conv = get_conv_layer(spatial_dim=2 if dim == \"2d\" else 3)\n        global_resp_norm_param_shape = (1,) * (2 if dim == \"2d\" else 3)\n        # First convolution layer with DepthWise Convolutions\n        self.conv1 = conv(\n            in_channels=in_channels,\n            out_channels=in_channels,\n            kernel_size=kernel_size,\n            stride=1,\n            padding=kernel_size // 2,\n            groups=in_channels,\n        )\n\n        # Normalization Layer. GroupNorm is used by default.\n        if norm_type == \"group\":\n            self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels)  # type: ignore\n        elif norm_type == \"layer\":\n            self.norm = nn.LayerNorm(\n                normalized_shape=[in_channels] + [kernel_size] * (2 if dim == \"2d\" else 3)  # type: ignore\n            )\n        # Second convolution (Expansion) layer with Conv3D 1x1x1\n        self.conv2 = conv(\n            in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0\n        )\n\n        # GeLU activations\n        self.act = nn.GELU()\n\n        # Third convolution (Compression) layer with Conv3D 1x1x1\n        self.conv3 = conv(\n            in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0\n        )\n\n        self.global_resp_norm = global_resp_norm\n        if self.global_resp_norm:\n            global_resp_norm_param_shape = (1, expansion_ratio * in_channels) + 
global_resp_norm_param_shape\n            self.global_resp_beta = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True)\n            self.global_resp_gamma = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True)\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the MedNeXtBlock.\n\n        Args:\n            x (torch.Tensor): Input tensor.\n\n        Returns:\n            torch.Tensor: Output tensor.\n        \"\"\"\n        x1 = x\n        x1 = self.conv1(x1)\n        x1 = self.act(self.conv2(self.norm(x1)))\n\n        if self.global_resp_norm:\n            # gamma, beta: learnable affine transform parameters\n            # X: input of shape (N,C,H,W,D)\n            if self.dim == \"2d\":\n                gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True)\n            else:\n                gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True)\n            nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6)\n            x1 = self.global_resp_gamma * (x1 * nx) + self.global_resp_beta + x1\n        x1 = self.conv3(x1)\n        if self.do_res:\n            x1 = x + x1\n        return x1\n\n\nclass MedNeXtDownBlock(MedNeXtBlock):\n    \"\"\"\n    MedNeXtDownBlock class for downsampling in the MedNeXt model.\n\n    Args:\n        in_channels (int): Number of input channels.\n        out_channels (int): Number of output channels.\n        expansion_ratio (int): Expansion ratio for the block. Defaults to 4.\n        kernel_size (int): Kernel size for convolutions. Defaults to 7.\n        use_residual_connection (bool): Whether to use residual connection. Defaults to False.\n        norm_type (str): Type of normalization to use. Defaults to \"group\".\n        dim (str): Dimension of the input. Can be \"2d\" or \"3d\". Defaults to \"3d\".\n        global_resp_norm (bool): Whether to use global response normalization. 
Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        expansion_ratio: int = 4,\n        kernel_size: int = 7,\n        use_residual_connection: bool = False,\n        norm_type: str = \"group\",\n        dim: str = \"3d\",\n        global_resp_norm: bool = False,\n    ):\n\n        super().__init__(\n            in_channels,\n            out_channels,\n            expansion_ratio,\n            kernel_size,\n            use_residual_connection=False,\n            norm_type=norm_type,\n            dim=dim,\n            global_resp_norm=global_resp_norm,\n        )\n\n        conv = get_conv_layer(spatial_dim=2 if dim == \"2d\" else 3)\n        self.resample_do_res = use_residual_connection\n        if use_residual_connection:\n            self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)\n\n        self.conv1 = conv(\n            in_channels=in_channels,\n            out_channels=in_channels,\n            kernel_size=kernel_size,\n            stride=2,\n            padding=kernel_size // 2,\n            groups=in_channels,\n        )\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the MedNeXtDownBlock.\n\n        Args:\n            x (torch.Tensor): Input tensor.\n\n        Returns:\n            torch.Tensor: Output tensor.\n        \"\"\"\n        x1 = super().forward(x)\n\n        if self.resample_do_res:\n            res = self.res_conv(x)\n            x1 = x1 + res\n\n        return x1\n\n\nclass MedNeXtUpBlock(MedNeXtBlock):\n    \"\"\"\n    MedNeXtUpBlock class for upsampling in the MedNeXt model.\n\n    Args:\n        in_channels (int): Number of input channels.\n        out_channels (int): Number of output channels.\n        expansion_ratio (int): Expansion ratio for the block. Defaults to 4.\n        kernel_size (int): Kernel size for convolutions. 
Defaults to 7.\n        use_residual_connection (bool): Whether to use residual connection. Defaults to False.\n        norm_type (str): Type of normalization to use. Defaults to \"group\".\n        dim (str): Dimension of the input. Can be \"2d\" or \"3d\". Defaults to \"3d\".\n        global_resp_norm (bool): Whether to use global response normalization. Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        expansion_ratio: int = 4,\n        kernel_size: int = 7,\n        use_residual_connection: bool = False,\n        norm_type: str = \"group\",\n        dim: str = \"3d\",\n        global_resp_norm: bool = False,\n    ):\n        super().__init__(\n            in_channels,\n            out_channels,\n            expansion_ratio,\n            kernel_size,\n            use_residual_connection=False,\n            norm_type=norm_type,\n            dim=dim,\n            global_resp_norm=global_resp_norm,\n        )\n\n        self.resample_do_res = use_residual_connection\n\n        self.dim = dim\n        conv = get_conv_layer(spatial_dim=2 if dim == \"2d\" else 3, transpose=True)\n        if use_residual_connection:\n            self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)\n\n        self.conv1 = conv(\n            in_channels=in_channels,\n            out_channels=in_channels,\n            kernel_size=kernel_size,\n            stride=2,\n            padding=kernel_size // 2,\n            groups=in_channels,\n        )\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the MedNeXtUpBlock.\n\n        Args:\n            x (torch.Tensor): Input tensor.\n\n        Returns:\n            torch.Tensor: Output tensor.\n        \"\"\"\n        x1 = super().forward(x)\n        # Asymmetry but necessary to match shape\n\n        if self.dim == \"2d\":\n            x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0))\n        else:\n    
        x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0, 1, 0))\n\n        if self.resample_do_res:\n            res = self.res_conv(x)\n            if self.dim == \"2d\":\n                res = torch.nn.functional.pad(res, (1, 0, 1, 0))\n            else:\n                res = torch.nn.functional.pad(res, (1, 0, 1, 0, 1, 0))\n            x1 = x1 + res\n\n        return x1\n\n\nclass MedNeXtOutBlock(nn.Module):\n    \"\"\"\n    MedNeXtOutBlock class for the output block in the MedNeXt model.\n\n    Args:\n        in_channels (int): Number of input channels.\n        n_classes (int): Number of output classes.\n        dim (str): Dimension of the input. Can be \"2d\" or \"3d\".\n    \"\"\"\n\n    def __init__(self, in_channels, n_classes, dim):\n        super().__init__()\n\n        conv = get_conv_layer(spatial_dim=2 if dim == \"2d\" else 3, transpose=True)\n        self.conv_out = conv(in_channels, n_classes, kernel_size=1)\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the MedNeXtOutBlock.\n\n        Args:\n            x (torch.Tensor): Input tensor.\n\n        Returns:\n            torch.Tensor: Output tensor.\n        \"\"\"\n        return self.conv_out(x)\n"
  },
  {
    "path": "monai/networks/blocks/mlp.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch.nn as nn\n\nfrom monai.networks.layers import get_act_layer\nfrom monai.networks.layers.factories import split_args\nfrom monai.utils import look_up_option\n\nSUPPORTED_DROPOUT_MODE = {\"vit\", \"swin\", \"vista3d\"}\n\n\nclass MLPBlock(nn.Module):\n    \"\"\"\n    A multi-layer perceptron block, based on: \"Dosovitskiy et al.,\n    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>\"\n    \"\"\"\n\n    def __init__(\n        self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0, act: tuple | str = \"GELU\", dropout_mode=\"vit\"\n    ) -> None:\n        \"\"\"\n        Args:\n            hidden_size: dimension of hidden layer.\n            mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.\n            dropout_rate: fraction of the input units to drop.\n            act: activation type and arguments. Defaults to GELU. 
Also supports \"GEGLU\" and others.\n            dropout_mode: dropout mode, can be \"vit\" or \"swin\".\n                \"vit\" mode uses two dropout instances as implemented in\n                https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87\n                \"swin\" corresponds to one instance as implemented in\n                https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23\n                \"vista3d\" mode does not use dropout.\n\n        \"\"\"\n\n        super().__init__()\n\n        if not (0 <= dropout_rate <= 1):\n            raise ValueError(\"dropout_rate should be between 0 and 1.\")\n        mlp_dim = mlp_dim or hidden_size\n        act_name, _ = split_args(act)\n        self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != \"GEGLU\" else nn.Linear(hidden_size, mlp_dim * 2)\n        self.linear2 = nn.Linear(mlp_dim, hidden_size)\n        self.fn = get_act_layer(act)\n        # Use Union[nn.Dropout, nn.Identity] for type annotations\n        self.drop1: nn.Dropout | nn.Identity\n        self.drop2: nn.Dropout | nn.Identity\n\n        dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)\n        if dropout_opt == \"vit\":\n            self.drop1 = nn.Dropout(dropout_rate)\n            self.drop2 = nn.Dropout(dropout_rate)\n        elif dropout_opt == \"swin\":\n            self.drop1 = nn.Dropout(dropout_rate)\n            self.drop2 = self.drop1\n        elif dropout_opt == \"vista3d\":\n            self.drop1 = nn.Identity()\n            self.drop2 = nn.Identity()\n        else:\n            raise ValueError(f\"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}\")\n\n    def forward(self, x):\n        x = self.fn(self.linear1(x))\n        x = self.drop1(x)\n        x = self.linear2(x)\n        x = self.drop2(x)\n        return x\n"
  },
  {
    "path": "monai/networks/blocks/patchembedding.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import LayerNorm\n\nfrom monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding, build_sincos_position_embedding\nfrom monai.networks.layers import Conv, trunc_normal_\nfrom monai.utils import ensure_tuple_rep, optional_import\nfrom monai.utils.module import look_up_option\n\nRearrange, _ = optional_import(\"einops.layers.torch\", name=\"Rearrange\")\nSUPPORTED_PATCH_EMBEDDING_TYPES = {\"conv\", \"perceptron\"}\nSUPPORTED_POS_EMBEDDING_TYPES = {\"none\", \"learnable\", \"sincos\", \"fourier\"}\n\n\nclass PatchEmbeddingBlock(nn.Module):\n    \"\"\"\n    A patch embedding block, based on: \"Dosovitskiy et al.,\n    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>\"\n\n    Example::\n\n        >>> from monai.networks.blocks import PatchEmbeddingBlock\n        >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4,\n        >>>                     proj_type=\"conv\", pos_embed_type=\"sincos\")\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        img_size: Sequence[int] | int,\n        patch_size: Sequence[int] | int,\n        hidden_size: 
int,\n        num_heads: int,\n        proj_type: str = \"conv\",\n        pos_embed_type: str = \"learnable\",\n        dropout_rate: float = 0.0,\n        spatial_dims: int = 3,\n        pos_embed_kwargs: dict | None = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            in_channels: dimension of input channels.\n            img_size: dimension of input image.\n            patch_size: dimension of patch size.\n            hidden_size: dimension of hidden layer.\n            num_heads: number of attention heads.\n            proj_type: patch embedding layer type.\n            pos_embed_type: position embedding layer type.\n            dropout_rate: fraction of the input units to drop.\n            spatial_dims: number of spatial dimensions.\n            pos_embed_kwargs: additional arguments for position embedding. For `sincos`, it can contain\n                              `temperature` and for fourier it can contain `scales`.\n        \"\"\"\n\n        super().__init__()\n\n        if not (0 <= dropout_rate <= 1):\n            raise ValueError(f\"dropout_rate {dropout_rate} should be between 0 and 1.\")\n\n        if hidden_size % num_heads != 0:\n            raise ValueError(f\"hidden size {hidden_size} should be divisible by num_heads {num_heads}.\")\n\n        self.proj_type = look_up_option(proj_type, SUPPORTED_PATCH_EMBEDDING_TYPES)\n        self.pos_embed_type = look_up_option(pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)\n\n        img_size = ensure_tuple_rep(img_size, spatial_dims)\n        patch_size = ensure_tuple_rep(patch_size, spatial_dims)\n        for m, p in zip(img_size, patch_size):\n            if m < p:\n                raise ValueError(\"patch_size should be smaller than img_size.\")\n            if self.proj_type == \"perceptron\" and m % p != 0:\n                raise ValueError(\"patch_size should be divisible by img_size for perceptron.\")\n        self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, 
patch_size)])\n        self.patch_dim = int(in_channels * np.prod(patch_size))\n\n        self.patch_embeddings: nn.Module\n        if self.proj_type == \"conv\":\n            self.patch_embeddings = Conv[Conv.CONV, spatial_dims](\n                in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size\n            )\n        elif self.proj_type == \"perceptron\":\n            # for 3d: \"b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)\"\n            chars = ((\"h\", \"p1\"), (\"w\", \"p2\"), (\"d\", \"p3\"))[:spatial_dims]\n            from_chars = \"b c \" + \" \".join(f\"({k} {v})\" for k, v in chars)\n            to_chars = f\"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)\"\n            axes_len = {f\"p{i + 1}\": p for i, p in enumerate(patch_size)}\n            self.patch_embeddings = nn.Sequential(\n                Rearrange(f\"{from_chars} -> {to_chars}\", **axes_len), nn.Linear(self.patch_dim, hidden_size)\n            )\n        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))\n        self.dropout = nn.Dropout(dropout_rate)\n\n        pos_embed_kwargs = {} if pos_embed_kwargs is None else pos_embed_kwargs\n\n        if self.pos_embed_type == \"none\":\n            pass\n        elif self.pos_embed_type == \"learnable\":\n            trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)\n        elif self.pos_embed_type == \"sincos\":\n            grid_size = []\n            for in_size, pa_size in zip(img_size, patch_size):\n                grid_size.append(in_size // pa_size)\n\n            self.position_embeddings = build_sincos_position_embedding(\n                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs\n            )\n        elif self.pos_embed_type == \"fourier\":\n            grid_size = []\n            for in_size, pa_size in zip(img_size, patch_size):\n                grid_size.append(in_size // 
pa_size)\n\n            self.position_embeddings = build_fourier_position_embedding(\n                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs\n            )\n        else:\n            raise ValueError(f\"pos_embed_type {self.pos_embed_type} not supported.\")\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def forward(self, x):\n        x = self.patch_embeddings(x)\n        if self.proj_type == \"conv\":\n            x = x.flatten(2).transpose(-1, -2)\n        embeddings = x + self.position_embeddings\n        embeddings = self.dropout(embeddings)\n        return embeddings\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"\n    Patch embedding block based on: \"Liu et al.,\n    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows\n    <https://arxiv.org/abs/2103.14030>\"\n    https://github.com/microsoft/Swin-Transformer\n\n    Unlike ViT patch embedding block: (1) input is padded to satisfy window size requirements (2) normalized if\n    specified (3) position embedding is not used.\n\n    Example::\n\n        >>> from monai.networks.blocks import PatchEmbed\n        >>> PatchEmbed(patch_size=2, in_chans=1, embed_dim=48, norm_layer=nn.LayerNorm, spatial_dims=3)\n    \"\"\"\n\n    def __init__(\n        self,\n        patch_size: Sequence[int] | int = 2,\n        in_chans: int = 1,\n        embed_dim: int = 48,\n        norm_layer: type[LayerNorm] = nn.LayerNorm,\n        spatial_dims: int = 3,\n    ) -> None:\n        \"\"\"\n        Args:\n            patch_size: dimension of patch size.\n            in_chans: dimension of input channels.\n            
embed_dim: number of linear projection output channels.\n            norm_layer: normalization layer.\n            spatial_dims: spatial dimension.\n        \"\"\"\n\n        super().__init__()\n\n        if spatial_dims not in (2, 3):\n            raise ValueError(\"spatial dimension should be 2 or 3.\")\n\n        patch_size = ensure_tuple_rep(patch_size, spatial_dims)\n        self.patch_size = patch_size\n        self.embed_dim = embed_dim\n        self.proj = Conv[Conv.CONV, spatial_dims](\n            in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size\n        )\n        if norm_layer is not None:\n            self.norm = norm_layer(embed_dim)\n        else:\n            self.norm = None\n\n    def forward(self, x):\n        x_shape = x.size()\n        if len(x_shape) == 5:\n            _, _, d, h, w = x_shape\n            if w % self.patch_size[2] != 0:\n                x = F.pad(x, (0, self.patch_size[2] - w % self.patch_size[2]))\n            if h % self.patch_size[1] != 0:\n                x = F.pad(x, (0, 0, 0, self.patch_size[1] - h % self.patch_size[1]))\n            if d % self.patch_size[0] != 0:\n                x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - d % self.patch_size[0]))\n\n        elif len(x_shape) == 4:\n            _, _, h, w = x_shape\n            if w % self.patch_size[1] != 0:\n                x = F.pad(x, (0, self.patch_size[1] - w % self.patch_size[1]))\n            if h % self.patch_size[0] != 0:\n                x = F.pad(x, (0, 0, 0, self.patch_size[0] - h % self.patch_size[0]))\n\n        x = self.proj(x)\n        if self.norm is not None:\n            x_shape = x.size()\n            x = x.flatten(2).transpose(1, 2)\n            x = self.norm(x)\n            if len(x_shape) == 5:\n                d, wh, ww = x_shape[2], x_shape[3], x_shape[4]\n                x = x.transpose(1, 2).view(-1, self.embed_dim, d, wh, ww)\n            elif len(x_shape) == 4:\n                wh, ww = 
x_shape[2], x_shape[3]\n                x = x.transpose(1, 2).view(-1, self.embed_dim, wh, ww)\n        return x\n"
  },
  {
    "path": "monai/networks/blocks/pos_embed_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport collections.abc\nfrom itertools import repeat\n\nimport torch\nimport torch.nn as nn\n\n__all__ = [\"build_fourier_position_embedding\", \"build_sincos_position_embedding\"]\n\n\n# From PyTorch internals\ndef _ntuple(n):\n\n    def parse(x):\n        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):\n            return tuple(x)\n        return tuple(repeat(x, n))\n\n    return parse\n\n\ndef build_fourier_position_embedding(\n    grid_size: int | list[int], embed_dim: int, spatial_dims: int = 3, scales: float | list[float] = 1.0\n) -> torch.nn.Parameter:\n    \"\"\"\n    Builds a (Anistropic) Fourier feature position embedding based on the given grid size, embed dimension,\n    spatial dimensions, and scales. The scales control the variance of the Fourier features, higher values make distant\n    points more distinguishable.\n    Position embedding is made anistropic by allowing setting different scales for each spatial dimension.\n        Reference: https://arxiv.org/abs/2509.02488\n\n    Args:\n        grid_size (int | List[int]): The size of the grid in each spatial dimension.\n        embed_dim (int): The dimension of the embedding.\n        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).\n        scales (float | List[float]): The scale for every spatial dimension. 
If a single float is provided,\n                              the same scale is used for all dimensions.\n\n    Returns:\n        pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.\n    \"\"\"\n\n    to_tuple = _ntuple(spatial_dims)\n    grid_size_t = to_tuple(grid_size)\n    if len(grid_size_t) != spatial_dims:\n        raise ValueError(f\"Length of grid_size ({len(grid_size_t)}) must be the same as spatial_dims.\")\n\n    if embed_dim % 2 != 0:\n        raise ValueError(\"embed_dim must be even for Fourier position embedding\")\n\n    # Ensure scales is a tensor of shape (spatial_dims,)\n    if isinstance(scales, float):\n        scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float)\n    elif isinstance(scales, (list, tuple)):\n        if len(scales) != spatial_dims:\n            raise ValueError(f\"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}\")\n        scales_tensor = torch.tensor(scales, dtype=torch.float)\n    else:\n        raise TypeError(f\"scales must be float or list of floats, got {type(scales)}\")\n\n    gaussians = torch.randn(embed_dim // 2, spatial_dims, dtype=torch.float32) * scales_tensor\n\n    position_indices = [torch.linspace(0, 1, x, dtype=torch.float32) for x in grid_size_t]\n    positions = torch.stack(torch.meshgrid(*position_indices, indexing=\"ij\"), dim=-1)\n    positions = positions.flatten(end_dim=-2)\n\n    x_proj = (2.0 * torch.pi * positions) @ gaussians.T\n\n    pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)\n    pos_emb = nn.Parameter(pos_emb[None, :, :], requires_grad=False)\n\n    return pos_emb\n\n\ndef build_sincos_position_embedding(\n    grid_size: int | list[int], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0\n) -> torch.nn.Parameter:\n    \"\"\"\n    Builds a sin-cos position embedding based on the given grid size, embed dimension, spatial dimensions, and temperature.\n    Reference: 
https://github.com/cvlab-stonybrook/SelfMedMAE/blob/68d191dfcc1c7d0145db93a6a570362de29e3b30/lib/models/mae3d.py\n\n    Args:\n        grid_size (List[int]): The size of the grid in each spatial dimension.\n        embed_dim (int): The dimension of the embedding.\n        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).\n        temperature (float): The temperature for the sin-cos position embedding.\n\n    Returns:\n        pos_embed (nn.Parameter): The sin-cos position embedding as a fixed parameter.\n    \"\"\"\n\n    if spatial_dims == 2:\n        to_2tuple = _ntuple(2)\n        grid_size_t = to_2tuple(grid_size)\n        h, w = grid_size_t\n        grid_h = torch.arange(h, dtype=torch.float32)\n        grid_w = torch.arange(w, dtype=torch.float32)\n\n        grid_h, grid_w = torch.meshgrid(grid_h, grid_w)\n\n        if embed_dim % 4 != 0:\n            raise AssertionError(\"Embed dimension must be divisible by 4 for 2D sin-cos position embedding\")\n\n        pos_dim = embed_dim // 4\n        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim\n        omega = 1.0 / (temperature**omega)\n        out_h = torch.einsum(\"m,d->md\", [grid_h.flatten(), omega])\n        out_w = torch.einsum(\"m,d->md\", [grid_w.flatten(), omega])\n        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]\n    elif spatial_dims == 3:\n        to_3tuple = _ntuple(3)\n        grid_size_t = to_3tuple(grid_size)\n        h, w, d = grid_size_t\n        grid_h = torch.arange(h, dtype=torch.float32)\n        grid_w = torch.arange(w, dtype=torch.float32)\n        grid_d = torch.arange(d, dtype=torch.float32)\n\n        grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d)\n\n        if embed_dim % 6 != 0:\n            raise AssertionError(\"Embed dimension must be divisible by 6 for 3D sin-cos position embedding\")\n\n        pos_dim = embed_dim // 6\n        omega = 
torch.arange(pos_dim, dtype=torch.float32) / pos_dim\n        omega = 1.0 / (temperature**omega)\n        out_h = torch.einsum(\"m,d->md\", [grid_h.flatten(), omega])\n        out_w = torch.einsum(\"m,d->md\", [grid_w.flatten(), omega])\n        out_d = torch.einsum(\"m,d->md\", [grid_d.flatten(), omega])\n        pos_emb = torch.cat(\n            [\n                torch.sin(out_w),\n                torch.cos(out_w),\n                torch.sin(out_h),\n                torch.cos(out_h),\n                torch.sin(out_d),\n                torch.cos(out_d),\n            ],\n            dim=1,\n        )[None, :, :]\n    else:\n        raise NotImplementedError(\"Spatial Dimension Size {spatial_dims} Not Implemented!\")\n\n    pos_embed = nn.Parameter(pos_emb)\n    pos_embed.requires_grad = False\n\n    return pos_embed\n"
  },
  {
    "path": "monai/networks/blocks/regunet_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom monai.networks.blocks import Convolution\nfrom monai.networks.layers import Conv, Norm, Pool, same_padding\n\n\ndef get_conv_block(\n    spatial_dims: int,\n    in_channels: int,\n    out_channels: int,\n    kernel_size: Sequence[int] | int = 3,\n    strides: int = 1,\n    padding: tuple[int, ...] 
| int | None = None,\n    act: tuple | str | None = \"RELU\",\n    norm: tuple | str | None = \"BATCH\",\n    initializer: str | None = \"kaiming_uniform\",\n) -> nn.Module:\n    if padding is None:\n        padding = same_padding(kernel_size)\n    conv_block: nn.Module = Convolution(\n        spatial_dims,\n        in_channels,\n        out_channels,\n        kernel_size=kernel_size,\n        strides=strides,\n        act=act,\n        norm=norm,\n        bias=False,\n        conv_only=False,\n        padding=padding,\n    )\n    conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]\n    for m in conv_block.modules():\n        if isinstance(m, conv_type):\n            if initializer == \"kaiming_uniform\":\n                nn.init.kaiming_normal_(torch.as_tensor(m.weight))\n            elif initializer == \"zeros\":\n                nn.init.zeros_(torch.as_tensor(m.weight))\n            else:\n                raise ValueError(\n                    f\"initializer {initializer} is not supported, \" \"currently supporting kaiming_uniform and zeros\"\n                )\n    return conv_block\n\n\ndef get_conv_layer(\n    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Sequence[int] | int = 3\n) -> nn.Module:\n    padding = same_padding(kernel_size)\n    mod: nn.Module = Convolution(\n        spatial_dims, in_channels, out_channels, kernel_size=kernel_size, bias=False, conv_only=True, padding=padding\n    )\n    return mod\n\n\nclass RegistrationResidualConvBlock(nn.Module):\n    \"\"\"\n    A block with skip links and layer - norm - activation.\n    Only changes the number of channels, the spatial size is kept same.\n    \"\"\"\n\n    def __init__(\n        self, spatial_dims: int, in_channels: int, out_channels: int, num_layers: int = 2, kernel_size: int = 3\n    ):\n        \"\"\"\n\n        Args:\n            spatial_dims: number of spatial dimensions\n            in_channels: number of input channels\n            
out_channels: number of output channels\n            num_layers: number of layers inside the block\n            kernel_size: kernel_size\n        \"\"\"\n        super().__init__()\n        self.num_layers = num_layers\n        self.layers = nn.ModuleList(\n            [\n                get_conv_layer(\n                    spatial_dims=spatial_dims,\n                    in_channels=in_channels if i == 0 else out_channels,\n                    out_channels=out_channels,\n                    kernel_size=kernel_size,\n                )\n                for i in range(num_layers)\n            ]\n        )\n        self.norms = nn.ModuleList([Norm[Norm.BATCH, spatial_dims](out_channels) for _ in range(num_layers)])\n        self.acts = nn.ModuleList([nn.ReLU() for _ in range(num_layers)])\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n\n        Args:\n            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])\n\n        Returns:\n            Tensor in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3]),\n            with the same spatial size as ``x``\n        \"\"\"\n        skip = x\n        for i, (conv, norm, act) in enumerate(zip(self.layers, self.norms, self.acts)):\n            x = conv(x)\n            x = norm(x)\n            if i == self.num_layers - 1:\n                # last block\n                x = x + skip\n            x = act(x)\n        return x\n\n\nclass RegistrationDownSampleBlock(nn.Module):\n    \"\"\"\n    A down-sample module used in RegUNet to half the spatial size.\n    The number of channels is kept same.\n\n    Adapted from:\n        DeepReg (https://github.com/DeepRegNet/DeepReg)\n    \"\"\"\n\n    def __init__(self, spatial_dims: int, channels: int, pooling: bool) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions.\n            channels: channels\n            pooling: use MaxPool if True, strided conv if False\n        
\"\"\"\n        super().__init__()\n        if pooling:\n            self.layer = Pool[Pool.MAX, spatial_dims](kernel_size=2)\n        else:\n            self.layer = get_conv_block(\n                spatial_dims=spatial_dims,\n                in_channels=channels,\n                out_channels=channels,\n                kernel_size=2,\n                strides=2,\n                padding=0,\n            )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Halves the spatial dimensions and keeps the same channel.\n        output in shape (batch, ``channels``, insize_1 / 2, insize_2 / 2, [insize_3 / 2]),\n\n        Args:\n            x: Tensor in shape (batch, ``channels``, insize_1, insize_2, [insize_3])\n\n        Raises:\n            ValueError: when input spatial dimensions are not even.\n        \"\"\"\n        for i in x.shape[2:]:\n            if i % 2 != 0:\n                raise ValueError(\"expecting x spatial dimensions be even, \" f\"got x of shape {x.shape}\")\n        out: torch.Tensor = self.layer(x)\n        return out\n\n\ndef get_deconv_block(spatial_dims: int, in_channels: int, out_channels: int) -> nn.Module:\n    mod: nn.Module = Convolution(\n        spatial_dims=spatial_dims,\n        in_channels=in_channels,\n        out_channels=out_channels,\n        strides=2,\n        act=\"RELU\",\n        norm=\"BATCH\",\n        bias=False,\n        is_transposed=True,\n        padding=1,\n        output_padding=1,\n    )\n    return mod\n\n\nclass RegistrationExtractionBlock(nn.Module):\n    \"\"\"\n    The Extraction Block used in RegUNet.\n    Extracts feature from each ``extract_levels`` and takes the average.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        extract_levels: tuple[int],\n        num_channels: tuple[int] | list[int],\n        out_channels: int,\n        kernel_initializer: str | None = \"kaiming_uniform\",\n        activation: str | None = None,\n        mode: str = 
\"nearest\",\n        align_corners: bool | None = None,\n    ):\n        \"\"\"\n\n        Args:\n            spatial_dims: number of spatial dimensions\n            extract_levels: spatial levels to extract feature from, 0 refers to the input scale\n            num_channels: number of channels at each scale level,\n                List or Tuple of length equals to `depth` of the RegNet\n            out_channels: number of output channels\n            kernel_initializer: kernel initializer\n            activation: kernel activation function\n            mode: feature map interpolation mode, default to \"nearest\".\n            align_corners: whether to align corners for feature map interpolation.\n        \"\"\"\n        super().__init__()\n        self.extract_levels = extract_levels\n        self.max_level = max(extract_levels)\n        self.layers = nn.ModuleList(\n            [\n                get_conv_block(\n                    spatial_dims=spatial_dims,\n                    in_channels=num_channels[d],\n                    out_channels=out_channels,\n                    norm=None,\n                    act=activation,\n                    initializer=kernel_initializer,\n                )\n                for d in extract_levels\n            ]\n        )\n        self.mode = mode\n        self.align_corners = align_corners\n\n    def forward(self, x: list[torch.Tensor], image_size: list[int]) -> torch.Tensor:\n        \"\"\"\n\n        Args:\n            x: Decoded feature at different spatial levels, sorted from deep to shallow\n            image_size: output image size\n\n        Returns:\n            Tensor of shape (batch, `out_channels`, size1, size2, size3), where (size1, size2, size3) = ``image_size``\n        \"\"\"\n        feature_list = [\n            F.interpolate(\n                layer(x[self.max_level - level]), size=image_size, mode=self.mode, align_corners=self.align_corners\n            )\n            for layer, level in zip(self.layers, 
self.extract_levels)\n        ]\n        out: torch.Tensor = torch.mean(torch.stack(feature_list, dim=0), dim=0)\n        return out\n"
  },
  {
    "path": "monai/networks/blocks/rel_pos_embedding.py",
    "content": "# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Iterable\n\nimport torch\nfrom torch import nn\n\nfrom monai.networks.blocks.attention_utils import add_decomposed_rel_pos\nfrom monai.utils.misc import ensure_tuple_size\n\n\nclass DecomposedRelativePosEmbedding(nn.Module):\n    def __init__(self, s_input_dims: tuple[int, int] | tuple[int, int, int], c_dim: int, num_heads: int) -> None:\n        \"\"\"\n        Args:\n            s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D)\n            c_dim (int): channel dimension\n            num_heads(int): number of attention heads\n        \"\"\"\n        super().__init__()\n\n        # validate inputs\n        if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]:\n            raise ValueError(\"s_input_dims must be set as follows: (H, W) or (H, W, D)\")\n\n        self.s_input_dims = s_input_dims\n        self.c_dim = c_dim\n        self.num_heads = num_heads\n        self.rel_pos_arr = nn.ParameterList(\n            [nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims]\n        )\n\n    def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor:\n        \"\"\"\"\"\"\n        batch = x.shape[0]\n        h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1)\n\n        att_mat = add_decomposed_rel_pos(\n            att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d),\n            
q.contiguous().view(batch * self.num_heads, h * w * d, -1),\n            self.rel_pos_arr,\n            (h, w) if d == 1 else (h, w, d),\n            (h, w) if d == 1 else (h, w, d),\n        )\n\n        att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d)\n        return att_mat\n"
  },
  {
    "path": "monai/networks/blocks/segresnet_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch.nn as nn\n\nfrom monai.networks.blocks.convolutions import Convolution\nfrom monai.networks.blocks.upsample import UpSample\nfrom monai.networks.layers.utils import get_act_layer, get_norm_layer\nfrom monai.utils import InterpolateMode, UpsampleMode\n\n\ndef get_conv_layer(\n    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, bias: bool = False\n):\n    return Convolution(\n        spatial_dims, in_channels, out_channels, strides=stride, kernel_size=kernel_size, bias=bias, conv_only=True\n    )\n\n\ndef get_upsample_layer(\n    spatial_dims: int, in_channels: int, upsample_mode: UpsampleMode | str = \"nontrainable\", scale_factor: int = 2\n):\n    return UpSample(\n        spatial_dims=spatial_dims,\n        in_channels=in_channels,\n        out_channels=in_channels,\n        scale_factor=scale_factor,\n        mode=upsample_mode,\n        interp_mode=InterpolateMode.LINEAR,\n        align_corners=False,\n    )\n\n\nclass ResBlock(nn.Module):\n    \"\"\"\n    ResBlock employs skip connection and two convolution blocks and is used\n    in SegResNet based on `3D MRI brain tumor segmentation using autoencoder regularization\n    <https://arxiv.org/pdf/1810.11654.pdf>`_.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n       
 norm: tuple | str,\n        kernel_size: int = 3,\n        act: tuple | str = (\"RELU\", {\"inplace\": True}),\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions, could be 1, 2 or 3.\n            in_channels: number of input channels.\n            norm: feature normalization type and arguments.\n            kernel_size: convolution kernel size, the value should be an odd number. Defaults to 3.\n            act: activation type and arguments. Defaults to ``RELU``.\n        \"\"\"\n\n        super().__init__()\n\n        if kernel_size % 2 != 1:\n            raise AssertionError(\"kernel_size should be an odd number.\")\n\n        self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)\n        self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)\n        self.act = get_act_layer(act)\n        self.conv1 = get_conv_layer(\n            spatial_dims, in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size\n        )\n        self.conv2 = get_conv_layer(\n            spatial_dims, in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size\n        )\n\n    def forward(self, x):\n        identity = x\n\n        x = self.norm1(x)\n        x = self.act(x)\n        x = self.conv1(x)\n\n        x = self.norm2(x)\n        x = self.act(x)\n        x = self.conv2(x)\n\n        x += identity\n\n        return x\n"
  },
  {
    "path": "monai/networks/blocks/selfattention.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.layers.utils import get_rel_pos_embedding_layer\nfrom monai.utils import optional_import\n\nRearrange, _ = optional_import(\"einops.layers.torch\", name=\"Rearrange\")\n\n\nclass SABlock(nn.Module):\n    \"\"\"\n    A self-attention block, based on: \"Dosovitskiy et al.,\n    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>\"\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        num_heads: int,\n        dropout_rate: float = 0.0,\n        qkv_bias: bool = False,\n        save_attn: bool = False,\n        dim_head: int | None = None,\n        hidden_input_size: int | None = None,\n        causal: bool = False,\n        sequence_length: int | None = None,\n        rel_pos_embedding: str | None = None,\n        input_size: tuple | None = None,\n        attention_dtype: torch.dtype | None = None,\n        include_fc: bool = True,\n        use_combined_linear: bool = True,\n        use_flash_attention: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            hidden_size (int): dimension of hidden layer.\n            num_heads (int): number of attention heads.\n            dropout_rate (float, optional): fraction of the input units to drop. 
Defaults to 0.0.\n            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.\n            save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.\n            dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.\n            hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.\n            causal: whether to use causal attention (see https://arxiv.org/abs/1706.03762).\n            sequence_length: if causal is True, it is necessary to specify the sequence length.\n            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.\n                For now only \"decomposed\" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.\n            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative\n                positional parameter size.\n            attention_dtype: cast attention operations to this dtype.\n            include_fc: whether to include the final linear layer. 
Default to True.\n            use_combined_linear: whether to use a single linear layer for qkv projection, default to True.\n            use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n\n        \"\"\"\n\n        super().__init__()\n\n        if not (0 <= dropout_rate <= 1):\n            raise ValueError(\"dropout_rate should be between 0 and 1.\")\n\n        if hidden_size % num_heads != 0:\n            raise ValueError(\"hidden size should be divisible by num_heads.\")\n\n        if dim_head:\n            self.inner_dim = num_heads * dim_head\n            self.dim_head = dim_head\n        else:\n            if hidden_size % num_heads != 0:\n                raise ValueError(\"hidden size should be divisible by num_heads.\")\n            self.inner_dim = hidden_size\n            self.dim_head = hidden_size // num_heads\n\n        if causal and sequence_length is None:\n            raise ValueError(\"sequence_length is necessary for causal attention.\")\n\n        if use_flash_attention and save_attn:\n            raise ValueError(\n                \"save_attn has been set to True, but use_flash_attention is also set\"\n                \"to True. 
save_attn can only be used if use_flash_attention is False.\"\n            )\n\n        if use_flash_attention and rel_pos_embedding is not None:\n            raise ValueError(\"rel_pos_embedding must be None if you are using flash_attention.\")\n\n        self.num_heads = num_heads\n        self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size\n        self.out_proj: nn.Linear | nn.Identity\n        if include_fc:\n            self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)\n        else:\n            self.out_proj = nn.Identity()\n\n        self.qkv: nn.Linear | nn.Identity\n        self.to_q: nn.Linear | nn.Identity\n        self.to_k: nn.Linear | nn.Identity\n        self.to_v: nn.Linear | nn.Identity\n\n        if use_combined_linear:\n            self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)\n            self.to_q = self.to_k = self.to_v = nn.Identity()  # add to enable torchscript\n            self.input_rearrange = Rearrange(\"b h (qkv l d) -> qkv b l h d\", qkv=3, l=num_heads)\n        else:\n            self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)\n            self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)\n            self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)\n            self.qkv = nn.Identity()  # add to enable torchscript\n            self.input_rearrange = Rearrange(\"b h (l d) -> b l h d\", l=num_heads)\n        self.out_rearrange = Rearrange(\"b l h d -> b h (l d)\")\n        self.drop_output = nn.Dropout(dropout_rate)\n        self.drop_weights = nn.Dropout(dropout_rate)\n        self.dropout_rate = dropout_rate\n        self.scale = self.dim_head**-0.5\n        self.save_attn = save_attn\n        self.att_mat = torch.Tensor()\n        self.attention_dtype = attention_dtype\n        self.causal = causal\n        self.sequence_length = sequence_length\n        self.include_fc 
= include_fc\n        self.use_combined_linear = use_combined_linear\n        self.use_flash_attention = use_flash_attention\n\n        if causal and sequence_length is not None:\n            # causal mask to ensure that attention is only applied to the left in the input sequence\n            self.register_buffer(\n                \"causal_mask\",\n                torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),\n            )\n            self.causal_mask: torch.Tensor\n        else:\n            self.causal_mask = torch.Tensor()\n\n        self.rel_positional_embedding = (\n            get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.dim_head, self.num_heads)\n            if rel_pos_embedding is not None\n            else None\n        )\n        self.input_size = input_size\n\n    def forward(self, x, attn_mask: torch.Tensor | None = None):\n        \"\"\"\n        Args:\n            x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C\n            attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.\n            B x (s_dim_1 * ... * s_dim_n). Defaults to None.\n\n        Return:\n            torch.Tensor: B x (s_dim_1 * ... 
* s_dim_n) x C\n        \"\"\"\n        if self.use_combined_linear:\n            output = self.input_rearrange(self.qkv(x))\n            q, k, v = output[0], output[1], output[2]\n        else:\n            q = self.input_rearrange(self.to_q(x))\n            k = self.input_rearrange(self.to_k(x))\n            v = self.input_rearrange(self.to_v(x))\n\n        if self.attention_dtype is not None:\n            q = q.to(self.attention_dtype)\n            k = k.to(self.attention_dtype)\n\n        if self.use_flash_attention:\n            x = F.scaled_dot_product_attention(\n                query=q,\n                key=k,\n                value=v,\n                attn_mask=attn_mask,\n                scale=self.scale,\n                dropout_p=self.dropout_rate,\n                is_causal=self.causal,\n            )\n        else:\n            att_mat = torch.einsum(\"blxd,blyd->blxy\", q, k) * self.scale\n\n            # apply relative positional embedding if defined\n            if self.rel_positional_embedding is not None:\n                att_mat = self.rel_positional_embedding(x, att_mat, q)\n\n            if self.causal:\n                if attn_mask is not None:\n                    raise ValueError(\"Causal attention does not support attention masks.\")\n                att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float(\"-inf\"))\n\n            if attn_mask is not None:\n                attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)\n                attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)\n                att_mat = att_mat.masked_fill(attn_mask == 0, float(\"-inf\"))\n\n            att_mat = att_mat.softmax(dim=-1)\n            if self.save_attn:\n                # no gradients and new tensor;\n                # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html\n                self.att_mat = att_mat.detach()\n\n            att_mat = self.drop_weights(att_mat)\n            x = 
torch.einsum(\"bhxy,bhyd->bhxd\", att_mat, v)\n\n        x = self.out_rearrange(x)\n        if self.include_fc:\n            x = self.out_proj(x)\n        x = self.drop_output(x)\n        return x\n"
  },
  {
    "path": "monai/networks/blocks/spade_norm.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.blocks import Convolution\nfrom monai.networks.layers.utils import get_norm_layer\n\n\nclass SPADE(nn.Module):\n    \"\"\"\n    Spatially Adaptive Normalization (SPADE) block, allowing for normalization of activations conditioned on a\n    semantic map. This block is used in SPADE-based image-to-image translation models, as described in\n    Semantic Image Synthesis with Spatially-Adaptive Normalization (https://arxiv.org/abs/1903.07291).\n\n    Args:\n        label_nc: number of semantic labels\n        norm_nc: number of output channels\n        kernel_size: kernel size\n        spatial_dims: number of spatial dimensions\n        hidden_channels: number of channels in the intermediate gamma and beta layers\n        norm: type of base normalisation used before applying the SPADE normalisation\n        norm_params: parameters for the base normalisation\n    \"\"\"\n\n    def __init__(\n        self,\n        label_nc: int,\n        norm_nc: int,\n        kernel_size: int = 3,\n        spatial_dims: int = 2,\n        hidden_channels: int = 64,\n        norm: str | tuple = \"INSTANCE\",\n        norm_params: dict | None = None,\n    ) -> None:\n        super().__init__()\n\n        if norm_params is None:\n            norm_params = {}\n        if 
len(norm_params) != 0:\n            norm = (norm, norm_params)\n        self.param_free_norm = get_norm_layer(norm, spatial_dims=spatial_dims, channels=norm_nc)\n        self.mlp_shared = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=label_nc,\n            out_channels=hidden_channels,\n            kernel_size=kernel_size,\n            norm=None,\n            act=\"LEAKYRELU\",\n        )\n        self.mlp_gamma = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=hidden_channels,\n            out_channels=norm_nc,\n            kernel_size=kernel_size,\n            act=None,\n        )\n        self.mlp_beta = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=hidden_channels,\n            out_channels=norm_nc,\n            kernel_size=kernel_size,\n            act=None,\n        )\n\n    def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: input tensor with shape (B, C, [spatial-dimensions]) where C is the number of semantic channels.\n            segmap: input segmentation map (B, C, [spatial-dimensions]) where C is the number of semantic channels.\n            The map will be interpolated to the dimension of x internally.\n        \"\"\"\n\n        # Part 1. generate parameter-free normalized activations\n        normalized = self.param_free_norm(x.contiguous())\n\n        # Part 2. produce scaling and bias conditioned on semantic map\n        segmap = F.interpolate(segmap, size=x.size()[2:], mode=\"nearest\")\n        actv = self.mlp_shared(segmap)\n        gamma = self.mlp_gamma(actv)\n        beta = self.mlp_beta(actv)\n        out: torch.Tensor = normalized * (1 + gamma) + beta\n        return out\n"
  },
  {
    "path": "monai/networks/blocks/spatialattention.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import SABlock\n\n\nclass SpatialAttentionBlock(nn.Module):\n    \"\"\"Perform spatial self-attention on the input tensor.\n\n    The input tensor is reshaped to B x (x_dim * y_dim [ * z_dim]) x C, where C is the number of channels, and then\n    self-attention is performed on the reshaped tensor. The output tensor is reshaped back to the original shape.\n\n    Args:\n        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.\n        num_channels: number of input channels. Must be divisible by num_head_channels.\n        num_head_channels: number of channels per head.\n        norm_num_groups: Number of groups for the group norm layer.\n        norm_eps: Epsilon for the normalization.\n        attention_dtype: cast attention operations to this dtype.\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        num_channels: int,\n        num_head_channels: int | None = None,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        attention_dtype: torch.dtype | None = None,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n\n        self.spatial_dims = spatial_dims\n        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True)\n        # check num_head_channels is divisible by num_channels\n        if num_head_channels is not None and num_channels % num_head_channels != 0:\n            raise ValueError(\"num_channels must be divisible by num_head_channels\")\n        num_heads = num_channels // num_head_channels if num_head_channels is not None else 1\n        self.attn = SABlock(\n            hidden_size=num_channels,\n            num_heads=num_heads,\n            qkv_bias=True,\n            attention_dtype=attention_dtype,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n\n    def forward(self, x: torch.Tensor):\n        residual = x\n        shape = x.shape\n        x = self.norm(x)\n        x = x.reshape(*shape[:2], -1).transpose(1, 2)  # \"b c h w d -> b (h w d) c\"\n        x = self.attn(x)\n        x = x.transpose(1, 2).reshape(shape)  # \"b (h w d) c -> b c h w d\"\n        x = x + residual\n        return x\n"
  },
  {
    "path": "monai/networks/blocks/squeeze_and_excitation.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import Convolution\nfrom monai.networks.layers.factories import Act, Conv, Norm, Pool, split_args\n\n\nclass ChannelSELayer(nn.Module):\n    \"\"\"\n    Re-implementation of the Squeeze-and-Excitation block based on:\n    \"Hu et al., Squeeze-and-Excitation Networks, https://arxiv.org/abs/1709.01507\".\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        r: int = 2,\n        acti_type_1: tuple[str, dict] | str = (\"relu\", {\"inplace\": True}),\n        acti_type_2: tuple[str, dict] | str = \"sigmoid\",\n        add_residual: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.\n            in_channels: number of input channels.\n            r: the reduction ratio r in the paper. Defaults to 2.\n            acti_type_1: activation type of the hidden squeeze layer. Defaults to ``(\"relu\", {\"inplace\": True})``.\n            acti_type_2: activation type of the output squeeze layer. 
Defaults to \"sigmoid\".\n\n        Raises:\n            ValueError: When ``r`` is nonpositive or larger than ``in_channels``.\n\n        See also:\n\n            :py:class:`monai.networks.layers.Act`\n\n        \"\"\"\n        super().__init__()\n\n        self.add_residual = add_residual\n\n        pool_type = Pool[Pool.ADAPTIVEAVG, spatial_dims]\n        self.avg_pool = pool_type(1)  # spatial size (1, 1, ...)\n\n        channels = int(in_channels // r)\n        if channels <= 0:\n            raise ValueError(f\"r must be positive and smaller than in_channels, got r={r} in_channels={in_channels}.\")\n\n        act_1, act_1_args = split_args(acti_type_1)\n        act_2, act_2_args = split_args(acti_type_2)\n        self.fc = nn.Sequential(\n            nn.Linear(in_channels, channels, bias=True),\n            Act[act_1](**act_1_args),\n            nn.Linear(channels, in_channels, bias=True),\n            Act[act_2](**act_2_args),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).\n        \"\"\"\n        b, c = x.shape[:2]\n        y: torch.Tensor = self.avg_pool(x).view(b, c)\n        y = self.fc(y).view([b, c] + [1] * (x.ndim - 2))\n        result = x * y\n\n        # Residual connection is moved here instead of providing an override of forward in ResidualSELayer since\n        # Torchscript has an issue with using super().\n        if self.add_residual:\n            result += x\n\n        return result\n\n\nclass ResidualSELayer(ChannelSELayer):\n    \"\"\"\n    A \"squeeze-and-excitation\"-like layer with a residual connection::\n\n        --+-- SE --o--\n          |        |\n          +--------+\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        r: int = 2,\n        acti_type_1: tuple[str, dict] | str = \"leakyrelu\",\n        acti_type_2: tuple[str, dict] | str = 
\"relu\",\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.\n            in_channels: number of input channels.\n            r: the reduction ratio r in the paper. Defaults to 2.\n            acti_type_1: defaults to \"leakyrelu\".\n            acti_type_2: defaults to \"relu\".\n\n        See also:\n            :py:class:`monai.networks.blocks.ChannelSELayer`\n        \"\"\"\n        super().__init__(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            r=r,\n            acti_type_1=acti_type_1,\n            acti_type_2=acti_type_2,\n            add_residual=True,\n        )\n\n\nclass SEBlock(nn.Module):\n    \"\"\"\n    Residual module enhanced with Squeeze-and-Excitation::\n\n        ----+- conv1 --  conv2 -- conv3 -- SE -o---\n            |                                  |\n            +---(channel project if needed)----+\n\n    Re-implementation of the SE-Resnet block based on:\n    \"Hu et al., Squeeze-and-Excitation Networks, https://arxiv.org/abs/1709.01507\".\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        n_chns_1: int,\n        n_chns_2: int,\n        n_chns_3: int,\n        conv_param_1: dict | None = None,\n        conv_param_2: dict | None = None,\n        conv_param_3: dict | None = None,\n        project: Convolution | None = None,\n        r: int = 2,\n        acti_type_1: tuple[str, dict] | str = (\"relu\", {\"inplace\": True}),\n        acti_type_2: tuple[str, dict] | str = \"sigmoid\",\n        acti_type_final: tuple[str, dict] | str | None = (\"relu\", {\"inplace\": True}),\n    ):\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.\n            in_channels: number of input channels.\n            n_chns_1: number of output channels in the 1st convolution.\n            n_chns_2: number of output channels in the 
2nd convolution.\n            n_chns_3: number of output channels in the 3rd convolution.\n            conv_param_1: additional parameters to the 1st convolution.\n                Defaults to ``{\"kernel_size\": 1, \"norm\": Norm.BATCH, \"act\": (\"relu\", {\"inplace\": True})}``\n            conv_param_2: additional parameters to the 2nd convolution.\n                Defaults to ``{\"kernel_size\": 3, \"norm\": Norm.BATCH, \"act\": (\"relu\", {\"inplace\": True})}``\n            conv_param_3: additional parameters to the 3rd convolution.\n                Defaults to ``{\"kernel_size\": 1, \"norm\": Norm.BATCH, \"act\": None}``\n            project: when the numbers of residual chns and output chns don't match, a project\n                (Conv) layer/block is used to adjust the number of chns. In SENET, it\n                consists of a Conv layer as well as a Norm layer.\n                Defaults to None, in which case an identity mapping is used when the chns match,\n                or a Conv layer with kernel size 1 otherwise.\n            r: the reduction ratio r in the paper. Defaults to 2.\n            acti_type_1: activation type of the hidden squeeze layer. Defaults to \"relu\".\n            acti_type_2: activation type of the output squeeze layer. Defaults to \"sigmoid\".\n            acti_type_final: activation type of the end of the block. 
Defaults to \"relu\".\n\n        See also:\n\n            :py:class:`monai.networks.blocks.ChannelSELayer`\n\n        \"\"\"\n        super().__init__()\n\n        if not conv_param_1:\n            conv_param_1 = {\"kernel_size\": 1, \"norm\": Norm.BATCH, \"act\": (\"relu\", {\"inplace\": True})}\n        self.conv1 = Convolution(\n            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=n_chns_1, **conv_param_1\n        )\n\n        if not conv_param_2:\n            conv_param_2 = {\"kernel_size\": 3, \"norm\": Norm.BATCH, \"act\": (\"relu\", {\"inplace\": True})}\n        self.conv2 = Convolution(spatial_dims=spatial_dims, in_channels=n_chns_1, out_channels=n_chns_2, **conv_param_2)\n\n        if not conv_param_3:\n            conv_param_3 = {\"kernel_size\": 1, \"norm\": Norm.BATCH, \"act\": None}\n        self.conv3 = Convolution(spatial_dims=spatial_dims, in_channels=n_chns_2, out_channels=n_chns_3, **conv_param_3)\n\n        self.se_layer = ChannelSELayer(\n            spatial_dims=spatial_dims, in_channels=n_chns_3, r=r, acti_type_1=acti_type_1, acti_type_2=acti_type_2\n        )\n\n        if project is None and in_channels != n_chns_3:\n            self.project = Conv[Conv.CONV, spatial_dims](in_channels, n_chns_3, kernel_size=1)\n        elif project is None:\n            self.project = nn.Identity()\n        else:\n            self.project = project\n\n        if acti_type_final is not None:\n            act_final, act_final_args = split_args(acti_type_final)\n            self.act = Act[act_final](**act_final_args)\n        else:\n            self.act = nn.Identity()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).\n        \"\"\"\n        residual = self.project(x)\n        x = self.conv1(x)\n        x = self.conv2(x)\n        x = self.conv3(x)\n        x = self.se_layer(x)\n        x += residual\n        x = 
self.act(x)\n        return x\n\n\nclass SEBottleneck(SEBlock):\n    \"\"\"\n    Bottleneck for SENet154.\n    \"\"\"\n\n    expansion = 4\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        inplanes: int,\n        planes: int,\n        groups: int,\n        reduction: int,\n        stride: int = 1,\n        downsample: Convolution | None = None,\n    ) -> None:\n        conv_param_1 = {\n            \"strides\": 1,\n            \"kernel_size\": 1,\n            \"act\": (\"relu\", {\"inplace\": True}),\n            \"norm\": Norm.BATCH,\n            \"bias\": False,\n        }\n        conv_param_2 = {\n            \"strides\": stride,\n            \"kernel_size\": 3,\n            \"act\": (\"relu\", {\"inplace\": True}),\n            \"norm\": Norm.BATCH,\n            \"groups\": groups,\n            \"bias\": False,\n        }\n        conv_param_3 = {\"strides\": 1, \"kernel_size\": 1, \"act\": None, \"norm\": Norm.BATCH, \"bias\": False}\n\n        super().__init__(\n            spatial_dims=spatial_dims,\n            in_channels=inplanes,\n            n_chns_1=planes * 2,\n            n_chns_2=planes * 4,\n            n_chns_3=planes * 4,\n            conv_param_1=conv_param_1,\n            conv_param_2=conv_param_2,\n            conv_param_3=conv_param_3,\n            project=downsample,\n            r=reduction,\n        )\n\n\nclass SEResNetBottleneck(SEBlock):\n    \"\"\"\n    ResNet bottleneck with a Squeeze-and-Excitation module. 
It follows Caffe\n    implementation and uses `strides=stride` in `conv1` and not in `conv2`\n    (the latter is used in the torchvision implementation of ResNet).\n    \"\"\"\n\n    expansion = 4\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        inplanes: int,\n        planes: int,\n        groups: int,\n        reduction: int,\n        stride: int = 1,\n        downsample: Convolution | None = None,\n    ) -> None:\n        conv_param_1 = {\n            \"strides\": stride,\n            \"kernel_size\": 1,\n            \"act\": (\"relu\", {\"inplace\": True}),\n            \"norm\": Norm.BATCH,\n            \"bias\": False,\n        }\n        conv_param_2 = {\n            \"strides\": 1,\n            \"kernel_size\": 3,\n            \"act\": (\"relu\", {\"inplace\": True}),\n            \"norm\": Norm.BATCH,\n            \"groups\": groups,\n            \"bias\": False,\n        }\n        conv_param_3 = {\"strides\": 1, \"kernel_size\": 1, \"act\": None, \"norm\": Norm.BATCH, \"bias\": False}\n\n        super().__init__(\n            spatial_dims=spatial_dims,\n            in_channels=inplanes,\n            n_chns_1=planes,\n            n_chns_2=planes,\n            n_chns_3=planes * 4,\n            conv_param_1=conv_param_1,\n            conv_param_2=conv_param_2,\n            conv_param_3=conv_param_3,\n            project=downsample,\n            r=reduction,\n        )\n\n\nclass SEResNeXtBottleneck(SEBlock):\n    \"\"\"\n    ResNeXt bottleneck type C with a Squeeze-and-Excitation module.\n    \"\"\"\n\n    expansion = 4\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        inplanes: int,\n        planes: int,\n        groups: int,\n        reduction: int,\n        stride: int = 1,\n        downsample: Convolution | None = None,\n        base_width: int = 4,\n    ) -> None:\n        conv_param_1 = {\n            \"strides\": 1,\n            \"kernel_size\": 1,\n            \"act\": (\"relu\", {\"inplace\": True}),\n 
           \"norm\": Norm.BATCH,\n            \"bias\": False,\n        }\n        conv_param_2 = {\n            \"strides\": stride,\n            \"kernel_size\": 3,\n            \"act\": (\"relu\", {\"inplace\": True}),\n            \"norm\": Norm.BATCH,\n            \"groups\": groups,\n            \"bias\": False,\n        }\n        conv_param_3 = {\"strides\": 1, \"kernel_size\": 1, \"act\": None, \"norm\": Norm.BATCH, \"bias\": False}\n        width = math.floor(planes * (base_width / 64)) * groups\n\n        super().__init__(\n            spatial_dims=spatial_dims,\n            in_channels=inplanes,\n            n_chns_1=width,\n            n_chns_2=width,\n            n_chns_3=planes * 4,\n            conv_param_1=conv_param_1,\n            conv_param_2=conv_param_2,\n            conv_param_3=conv_param_3,\n            project=downsample,\n            r=reduction,\n        )\n"
  },
  {
    "path": "monai/networks/blocks/text_embedding.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch import nn\nfrom torch.utils import model_zoo\n\nurl_map = {\n    \"clip_encoding_universal_model_32\": (\n        \"https://github.com/Project-MONAI/MONAI-extra-test-data/\"\n        \"releases/download/0.8.1/clip_encoding_universal_model.pth\"\n    )\n}\n\n\nclass TextEncoder(nn.Module):\n    \"\"\"\n    Text to vision encoding by Contrastive Language-Image Pre-training (CLIP) or random embedding.\n    The text to vision encoder loads the pre-trained or random initialized weights with connection to 2D/3D vision models.\n\n    Contrastive Language-Image Pre-training (CLIP), based on: \"Radford et al.,\n    Learning Transferable Visual Models From Natural Language Supervision <https://arxiv.org/abs/2103.00020>\"\n\n    Connecting text and medical 3D image, based on: \"Liu et al.,\n    CLIP-Driven Universal Model for Organ Segmentation and Tumor Detection <https://arxiv.org/pdf/2301.00785.pdf>\"\n    \"\"\"\n\n    def __init__(\n        self,\n        out_channels: int,\n        spatial_dims: int = 3,\n        text_dim: int = 512,\n        hidden_size: int = 256,\n        encoding: str = \"clip_encoding_universal_model_32\",\n        pretrained: bool = True,\n    ) -> None:\n        \"\"\"\n        Args:\n            out_channels: number of output channels, to control text-based embedding for classes.\n          
  spatial_dims: number of spatial dims.\n            text_dim: dimension of text embeddings.\n            hidden_size: dimension of hidden features, compatible to different vision feature dimensions.\n            encoding: the text embedding type, default to use clip text pretrained weights.\n            pretrained: whether to load pretrained weights (e.g., from CLIP) to initialize text embeddings, defaults to True.\n        \"\"\"\n        super().__init__()\n        self.encoding = encoding\n\n        self.spatial_dims = spatial_dims\n        if spatial_dims not in (2, 3):\n            raise ValueError(\"spatial dimension should be 2 or 3.\")\n\n        if self.encoding == \"rand_embedding\":\n            self.text_embedding = nn.Embedding(out_channels, hidden_size)\n        else:\n            self.register_buffer(\"text_embedding\", torch.randn(out_channels, text_dim))\n\n            if pretrained:\n                model_url = url_map[self.encoding]\n                pretrain_state_dict = model_zoo.load_url(model_url, map_location=\"cpu\")\n                self.text_embedding.data = pretrain_state_dict.float()  # type: ignore\n            else:\n                print(f\"{self.encoding} is not implemented, and can not be downloaded, please load your own\")\n\n            self.text_to_vision = nn.Linear(text_dim, hidden_size)\n\n    def forward(self):\n        if self.encoding == \"rand_embedding\":\n            # text embedding as random initialized 'rand_embedding'\n            text_embedding = self.text_embedding.weight\n        else:\n            print(self.text_embedding)\n            text_embedding = nn.functional.relu(self.text_to_vision(self.text_embedding))\n\n        if self.spatial_dims == 3:\n            text_embedding = text_embedding.unsqueeze(2).unsqueeze(2).unsqueeze(2)\n        elif self.spatial_dims == 2:\n            text_embedding = text_embedding.unsqueeze(2).unsqueeze(2)\n\n        return text_embedding\n"
  },
  {
    "path": "monai/networks/blocks/transformerblock.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import CrossAttentionBlock, MLPBlock, SABlock\n\n\nclass TransformerBlock(nn.Module):\n    \"\"\"\n    A transformer block, based on: \"Dosovitskiy et al.,\n    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>\"\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        mlp_dim: int,\n        num_heads: int,\n        dropout_rate: float = 0.0,\n        qkv_bias: bool = False,\n        save_attn: bool = False,\n        causal: bool = False,\n        sequence_length: int | None = None,\n        with_cross_attention: bool = False,\n        use_flash_attention: bool = False,\n        include_fc: bool = True,\n        use_combined_linear: bool = True,\n    ) -> None:\n        \"\"\"\n        Args:\n            hidden_size (int): dimension of hidden layer.\n            mlp_dim (int): dimension of feedforward layer.\n            num_heads (int): number of attention heads.\n            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.\n            qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.\n            save_attn (bool, optional): to make accessible the attention matrix. 
Defaults to False.\n            use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n            include_fc: whether to include the final linear layer. Default to True.\n            use_combined_linear: whether to use a single linear layer for qkv projection, default to True.\n\n        \"\"\"\n\n        super().__init__()\n\n        if not (0 <= dropout_rate <= 1):\n            raise ValueError(\"dropout_rate should be between 0 and 1.\")\n\n        if hidden_size % num_heads != 0:\n            raise ValueError(\"hidden_size should be divisible by num_heads.\")\n\n        self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)\n        self.norm1 = nn.LayerNorm(hidden_size)\n        self.attn = SABlock(\n            hidden_size,\n            num_heads,\n            dropout_rate,\n            qkv_bias=qkv_bias,\n            save_attn=save_attn,\n            causal=causal,\n            sequence_length=sequence_length,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n        self.norm2 = nn.LayerNorm(hidden_size)\n        self.with_cross_attention = with_cross_attention\n\n        self.norm_cross_attn = nn.LayerNorm(hidden_size)\n        self.cross_attn = CrossAttentionBlock(\n            hidden_size=hidden_size,\n            num_heads=num_heads,\n            dropout_rate=dropout_rate,\n            qkv_bias=qkv_bias,\n            causal=False,\n            use_flash_attention=use_flash_attention,\n        )\n\n    def forward(\n        self, x: torch.Tensor, context: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None\n    ) -> torch.Tensor:\n        x = x + self.attn(self.norm1(x), attn_mask=attn_mask)\n        if self.with_cross_attention:\n            x = x + 
self.cross_attn(self.norm_cross_attn(x), context=context)\n        x = x + self.mlp(self.norm2(x))\n        return x\n"
  },
  {
    "path": "monai/networks/blocks/unetr_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer\n\n\nclass UnetrUpBlock(nn.Module):\n    \"\"\"\n    An upsampling module that can be used for UNETR: \"Hatamizadeh et al.,\n    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>\"\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Sequence[int] | int,\n        upsample_kernel_size: Sequence[int] | int,\n        norm_name: tuple | str,\n        res_block: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions.\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            kernel_size: convolution kernel size.\n            upsample_kernel_size: convolution kernel size for transposed convolution layers.\n            norm_name: feature normalization type and arguments.\n            res_block: bool argument to determine if residual block is used.\n\n        \"\"\"\n\n        super().__init__()\n        upsample_stride = upsample_kernel_size\n        self.transp_conv = get_conv_layer(\n            spatial_dims,\n     
       in_channels,\n            out_channels,\n            kernel_size=upsample_kernel_size,\n            stride=upsample_stride,\n            conv_only=True,\n            is_transposed=True,\n        )\n\n        if res_block:\n            self.conv_block = UnetResBlock(\n                spatial_dims,\n                out_channels + out_channels,\n                out_channels,\n                kernel_size=kernel_size,\n                stride=1,\n                norm_name=norm_name,\n            )\n        else:\n            self.conv_block = UnetBasicBlock(  # type: ignore\n                spatial_dims,\n                out_channels + out_channels,\n                out_channels,\n                kernel_size=kernel_size,\n                stride=1,\n                norm_name=norm_name,\n            )\n\n    def forward(self, inp, skip):\n        # number of channels for skip should equals to out_channels\n        out = self.transp_conv(inp)\n        out = torch.cat((out, skip), dim=1)\n        out = self.conv_block(out)\n        return out\n\n\nclass UnetrPrUpBlock(nn.Module):\n    \"\"\"\n    A projection upsampling module that can be used for UNETR: \"Hatamizadeh et al.,\n    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>\"\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        num_layer: int,\n        kernel_size: Sequence[int] | int,\n        stride: Sequence[int] | int,\n        upsample_kernel_size: Sequence[int] | int,\n        norm_name: tuple | str,\n        conv_block: bool = False,\n        res_block: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions.\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            num_layer: number of upsampling blocks.\n            kernel_size: convolution kernel size.\n      
      stride: convolution stride.\n            upsample_kernel_size: convolution kernel size for transposed convolution layers.\n            norm_name: feature normalization type and arguments.\n            conv_block: bool argument to determine if convolutional block is used.\n            res_block: bool argument to determine if residual block is used.\n\n        \"\"\"\n\n        super().__init__()\n\n        upsample_stride = upsample_kernel_size\n        self.transp_conv_init = get_conv_layer(\n            spatial_dims,\n            in_channels,\n            out_channels,\n            kernel_size=upsample_kernel_size,\n            stride=upsample_stride,\n            conv_only=True,\n            is_transposed=True,\n        )\n        if conv_block:\n            if res_block:\n                self.blocks = nn.ModuleList(\n                    [\n                        nn.Sequential(\n                            get_conv_layer(\n                                spatial_dims,\n                                out_channels,\n                                out_channels,\n                                kernel_size=upsample_kernel_size,\n                                stride=upsample_stride,\n                                conv_only=True,\n                                is_transposed=True,\n                            ),\n                            UnetResBlock(\n                                spatial_dims=spatial_dims,\n                                in_channels=out_channels,\n                                out_channels=out_channels,\n                                kernel_size=kernel_size,\n                                stride=stride,\n                                norm_name=norm_name,\n                            ),\n                        )\n                        for i in range(num_layer)\n                    ]\n                )\n            else:\n                self.blocks = nn.ModuleList(\n                    [\n                        
nn.Sequential(\n                            get_conv_layer(\n                                spatial_dims,\n                                out_channels,\n                                out_channels,\n                                kernel_size=upsample_kernel_size,\n                                stride=upsample_stride,\n                                conv_only=True,\n                                is_transposed=True,\n                            ),\n                            UnetBasicBlock(\n                                spatial_dims=spatial_dims,\n                                in_channels=out_channels,\n                                out_channels=out_channels,\n                                kernel_size=kernel_size,\n                                stride=stride,\n                                norm_name=norm_name,\n                            ),\n                        )\n                        for i in range(num_layer)\n                    ]\n                )\n        else:\n            self.blocks = nn.ModuleList(\n                [\n                    get_conv_layer(\n                        spatial_dims,\n                        out_channels,\n                        out_channels,\n                        kernel_size=upsample_kernel_size,\n                        stride=upsample_stride,\n                        conv_only=True,\n                        is_transposed=True,\n                    )\n                    for i in range(num_layer)\n                ]\n            )\n\n    def forward(self, x):\n        x = self.transp_conv_init(x)\n        for blk in self.blocks:\n            x = blk(x)\n        return x\n\n\nclass UnetrBasicBlock(nn.Module):\n    \"\"\"\n    A CNN module that can be used for UNETR, based on: \"Hatamizadeh et al.,\n    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>\"\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n      
  out_channels: int,\n        kernel_size: Sequence[int] | int,\n        stride: Sequence[int] | int,\n        norm_name: tuple | str,\n        res_block: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions.\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            kernel_size: convolution kernel size.\n            stride: convolution stride.\n            norm_name: feature normalization type and arguments.\n            res_block: bool argument to determine if residual block is used.\n\n        \"\"\"\n\n        super().__init__()\n\n        if res_block:\n            self.layer = UnetResBlock(\n                spatial_dims=spatial_dims,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n                stride=stride,\n                norm_name=norm_name,\n            )\n        else:\n            self.layer = UnetBasicBlock(  # type: ignore\n                spatial_dims=spatial_dims,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n                stride=stride,\n                norm_name=norm_name,\n            )\n\n    def forward(self, inp):\n        return self.layer(inp)\n"
  },
  {
    "path": "monai/networks/blocks/upsample.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.layers.factories import Conv, Pad, Pool\nfrom monai.networks.utils import icnr_init, pixelshuffle\nfrom monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option\n\n__all__ = [\"Upsample\", \"UpSample\", \"SubpixelUpsample\", \"Subpixelupsample\", \"SubpixelUpSample\"]\n\n\nclass UpSample(nn.Sequential):\n    \"\"\"\n    Upsamples data by `scale_factor`.\n    Supported modes are:\n\n        - \"deconv\": uses a transposed convolution.\n        - \"deconvgroup\": uses a transposed group convolution.\n        - \"nontrainable\": uses :py:class:`torch.nn.Upsample`.\n        - \"pixelshuffle\": uses :py:class:`monai.networks.blocks.SubpixelUpsample`.\n\n    This operation will cause non-deterministic when ``mode`` is ``UpsampleMode.NONTRAINABLE``.\n    Please check the link below for more details:\n    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms\n    This module can optionally take a pre-convolution\n    (often used to map the number of features from `in_channels` to `out_channels`).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int | None = None,\n        out_channels: int | None = None,\n     
   scale_factor: Sequence[float] | float = 2,\n        kernel_size: Sequence[float] | float | None = None,\n        size: tuple[int] | int | None = None,\n        mode: UpsampleMode | str = UpsampleMode.DECONV,\n        pre_conv: nn.Module | str | None = \"default\",\n        post_conv: nn.Module | None = None,\n        interp_mode: str = InterpolateMode.LINEAR,\n        align_corners: bool | None = True,\n        bias: bool = True,\n        apply_pad_pool: bool = True,\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions of the input image.\n            in_channels: number of channels of the input image.\n            out_channels: number of channels of the output image. Defaults to `in_channels`.\n            scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. Defaults to 2.\n            kernel_size: kernel size used during transposed convolutions. Defaults to `scale_factor`.\n            size: spatial size of the output image.\n                Only used when ``mode`` is ``UpsampleMode.NONTRAINABLE``.\n                In torch.nn.functional.interpolate, only one of `size` or `scale_factor` should be defined,\n                thus if size is defined, `scale_factor` will not be used.\n                Defaults to None.\n            mode: {``\"deconv\"``, ``\"deconvgroup\"``, ``\"nontrainable\"``, ``\"pixelshuffle\"``}. Defaults to ``\"deconv\"``.\n            pre_conv: a conv block applied before upsampling. Defaults to \"default\".\n                When ``pre_conv`` is ``\"default\"``, one reserved conv layer will be utilized; in the\n                \"nontrainable\" mode this happens only when ``out_channels`` differs from ``in_channels``.\n                Only used in the \"nontrainable\" or \"pixelshuffle\" mode.\n            post_conv: a conv block applied after upsampling. Defaults to None. 
Only used in the \"nontrainable\"  mode.\n            interp_mode: {``\"nearest\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``}\n                Only used in the \"nontrainable\" mode.\n                If ends with ``\"linear\"`` will use ``spatial dims`` to determine the correct interpolation.\n                This corresponds to linear, bilinear, trilinear for 1D, 2D, and 3D respectively.\n                The interpolation mode. Defaults to ``\"linear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html\n            align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True.\n                Only used in the \"nontrainable\" mode.\n            bias: whether to have a bias term in the default preconv and deconv layers. Defaults to True.\n            apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the\n                size of `scale_factor` with a stride of 1. 
See also: :py:class:`monai.networks.blocks.SubpixelUpsample`.\n                Only used in the \"pixelshuffle\" mode.\n\n        \"\"\"\n        super().__init__()\n        scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims)\n        up_mode = look_up_option(mode, UpsampleMode)\n\n        if not kernel_size:\n            kernel_size_ = scale_factor_\n            output_padding = padding = 0\n        else:\n            kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims)\n            padding = tuple((k - 1) // 2 for k in kernel_size_)  # type: ignore\n            output_padding = tuple(s - 1 - (k - 1) % 2 for k, s in zip(kernel_size_, scale_factor_))  # type: ignore\n\n        if up_mode == UpsampleMode.DECONV:\n            if not in_channels:\n                raise ValueError(f\"in_channels needs to be specified in the '{mode}' mode.\")\n            self.add_module(\n                \"deconv\",\n                Conv[Conv.CONVTRANS, spatial_dims](\n                    in_channels=in_channels,\n                    out_channels=out_channels or in_channels,\n                    kernel_size=kernel_size_,\n                    stride=scale_factor_,\n                    padding=padding,\n                    output_padding=output_padding,\n                    bias=bias,\n                ),\n            )\n        elif up_mode == UpsampleMode.DECONVGROUP:\n            if not in_channels:\n                raise ValueError(f\"in_channels needs to be specified in the '{mode}' mode.\")\n\n            if out_channels is None:\n                out_channels = in_channels\n            groups = out_channels if in_channels % out_channels == 0 else 1\n\n            self.add_module(\n                \"deconvgroup\",\n                Conv[Conv.CONVTRANS, spatial_dims](\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    kernel_size=kernel_size_,\n                    stride=scale_factor_,\n                    
padding=padding,\n                    output_padding=output_padding,\n                    groups=groups,\n                    bias=bias,\n                ),\n            )\n        elif up_mode == UpsampleMode.NONTRAINABLE:\n            if pre_conv == \"default\" and (out_channels != in_channels):  # defaults to no conv if out_chns==in_chns\n                if not in_channels:\n                    raise ValueError(f\"in_channels needs to be specified in the '{mode}' mode.\")\n                self.add_module(\n                    \"preconv\",\n                    Conv[Conv.CONV, spatial_dims](\n                        in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias\n                    ),\n                )\n            elif pre_conv is not None and pre_conv != \"default\":\n                self.add_module(\"preconv\", pre_conv)  # type: ignore\n            elif pre_conv is None and (out_channels != in_channels):\n                raise ValueError(\n                    \"in the nontrainable mode, if not setting pre_conv, out_channels should equal to in_channels.\"\n                )\n\n            interp_mode = InterpolateMode(interp_mode)\n            linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR]\n            if interp_mode in linear_mode:  # choose mode based on dimensions\n                interp_mode = linear_mode[spatial_dims - 1]\n\n            upsample = nn.Upsample(\n                size=size,\n                scale_factor=None if size else scale_factor_,\n                mode=interp_mode.value,\n                align_corners=align_corners,\n            )\n\n            self.add_module(\"upsample_non_trainable\", upsample)\n            if post_conv:\n                self.add_module(\"postconv\", post_conv)\n        elif up_mode == UpsampleMode.PIXELSHUFFLE:\n            self.add_module(\n                \"pixelshuffle\",\n                SubpixelUpsample(\n                
    spatial_dims=spatial_dims,\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    scale_factor=scale_factor_[0],  # isotropic\n                    conv_block=pre_conv,\n                    apply_pad_pool=apply_pad_pool,\n                    bias=bias,\n                ),\n            )\n        else:\n            raise NotImplementedError(f\"Unsupported upsampling mode {mode}.\")\n\n\nclass SubpixelUpsample(nn.Module):\n    \"\"\"\n    Upsample via using a subpixel CNN. This module supports 1D, 2D and 3D input images.\n    The module is consisted with two parts. First of all, a convolutional layer is employed\n    to increase the number of channels into: ``in_channels * (scale_factor ** dimensions)``.\n    Secondly, a pixel shuffle manipulation is utilized to aggregates the feature maps from\n    low resolution space and build the super resolution space.\n    The first part of the module is not fixed, a sequential layers can be used to replace the\n    default single layer.\n\n    See: Shi et al., 2016, \"Real-Time Single Image and Video Super-Resolution\n    Using an Efficient Sub-Pixel Convolutional Neural Network.\"\n\n    See: Aitken et al., 2017, \"Checkerboard artifact free sub-pixel convolution\".\n\n    The idea comes from:\n    https://arxiv.org/abs/1609.05158\n\n    The pixel shuffle mechanism refers to:\n    https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#torch.nn.PixelShuffle.\n    and:\n    https://github.com/pytorch/pytorch/pull/6340.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int | None,\n        out_channels: int | None = None,\n        scale_factor: int = 2,\n        conv_block: nn.Module | str | None = \"default\",\n        apply_pad_pool: bool = True,\n        bias: bool = True,\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions of the input image.\n       
     in_channels: number of channels of the input image.\n            out_channels: optional number of channels of the output image.\n            scale_factor: multiplier for spatial size. Defaults to 2.\n            conv_block: a conv block to extract feature maps before upsampling. Defaults to None.\n\n                - When ``conv_block`` is ``\"default\"``, one reserved conv layer will be utilized.\n                - When ``conv_block`` is an ``nn.module``,\n                  please ensure the output number of channels is divisible ``(scale_factor ** dimensions)``.\n\n            apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the\n                size of `scale_factor` with a stride of 1. This implements the nearest neighbour resize convolution\n                component of subpixel convolutions described in Aitken et al.\n            bias: whether to have a bias term in the default conv_block. Defaults to True.\n\n        \"\"\"\n        super().__init__()\n\n        if scale_factor <= 0:\n            raise ValueError(f\"The `scale_factor` multiplier must be an integer greater than 0, got {scale_factor}.\")\n\n        self.dimensions = spatial_dims\n        self.scale_factor = scale_factor\n\n        if conv_block == \"default\":\n            out_channels = out_channels or in_channels\n            if not out_channels:\n                raise ValueError(\"in_channels need to be specified.\")\n            conv_out_channels = out_channels * (scale_factor**self.dimensions)\n            self.conv_block = Conv[Conv.CONV, self.dimensions](\n                in_channels=in_channels, out_channels=conv_out_channels, kernel_size=3, stride=1, padding=1, bias=bias\n            )\n\n            icnr_init(self.conv_block, self.scale_factor)\n        elif conv_block is None:\n            self.conv_block = nn.Identity()\n        else:\n            self.conv_block = conv_block\n\n        self.pad_pool: nn.Module = 
nn.Identity()\n\n        if apply_pad_pool:\n            pool_type = Pool[Pool.AVG, self.dimensions]\n            pad_type = Pad[Pad.CONSTANTPAD, self.dimensions]\n\n            self.pad_pool = nn.Sequential(\n                pad_type(padding=(self.scale_factor - 1, 0) * self.dimensions, value=0.0),\n                pool_type(kernel_size=self.scale_factor, stride=1),\n            )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...).\n        \"\"\"\n        x = self.conv_block(x)\n        if x.shape[1] % (self.scale_factor**self.dimensions) != 0:\n            raise ValueError(\n                f\"Number of channels after `conv_block` ({x.shape[1]}) must be evenly \"\n                \"divisible by scale_factor ** dimensions \"\n                f\"({self.scale_factor}^{self.dimensions}={self.scale_factor**self.dimensions}).\"\n            )\n        x = pixelshuffle(x, self.dimensions, self.scale_factor)\n        x = self.pad_pool(x)\n        return x\n\n\nUpsample = UpSample\nSubpixelupsample = SubpixelUpSample = SubpixelUpsample\n"
  },
  {
    "path": "monai/networks/blocks/warp.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom monai.config.deviceconfig import USE_COMPILED\nfrom monai.networks.layers.spatial_transforms import grid_pull\nfrom monai.networks.utils import meshgrid_ij\nfrom monai.utils import GridSampleMode, GridSamplePadMode, optional_import\n\n_C, _ = optional_import(\"monai._C\")\n\n__all__ = [\"Warp\", \"DVF2DDF\"]\n\n\nclass Warp(nn.Module):\n    \"\"\"\n    Warp an image with given dense displacement field (DDF).\n    \"\"\"\n\n    def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.BORDER.value, jitter=False):\n        \"\"\"\n        For pytorch native APIs, the possible values are:\n\n            - mode: ``\"nearest\"``, ``\"bilinear\"``, ``\"bicubic\"``.\n            - padding_mode: ``\"zeros\"``, ``\"border\"``, ``\"reflection\"``\n\n        See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n\n        For MONAI C++/CUDA extensions, the possible values are:\n\n            - mode: ``\"nearest\"``, ``\"bilinear\"``, ``\"bicubic\"``, 0, 1, ...\n            - padding_mode: ``\"zeros\"``, ``\"border\"``, ``\"reflection\"``, 0, 1, ...\n\n        See also: :py:class:`monai.networks.layers.grid_pull`\n\n        - jitter: bool, default=False\n            Define reference 
grid on non-integer values\n            Reference: B. Likar and F. Pernus. A heirarchical approach to elastic registration\n            based on mutual information. Image and Vision Computing, 19:33-44, 2001.\n\n        Note that using ``mode=\"nearest\"`` makes the warping operation effectively non-differentiable:\n        gradients are zero almost everywhere, which can block gradient flow during training.\n        For learning-based registration, use ``\"bilinear\"`` (2D) or ``\"trilinear\"`` (3D) interpolation instead.\n\n        See https://github.com/Project-MONAI/tutorials/blob/main/3d_registration/learn2reg_oasis_unpaired_brain_mr.ipynb\n        for examples of semi-supervised registration using segmentations.\n        \"\"\"\n        super().__init__()\n        # resolves _interp_mode for different methods\n\n        if USE_COMPILED:\n            if mode in (inter.value for inter in GridSampleMode):\n                mode = GridSampleMode(mode)\n                if mode == GridSampleMode.BILINEAR:\n                    mode = 1\n                elif mode == GridSampleMode.NEAREST:\n                    mode = 0\n                elif mode == GridSampleMode.BICUBIC:\n                    mode = 3\n                else:\n                    mode = 1  # default to linear\n            self._interp_mode = mode\n        else:\n            warnings.warn(\"monai.networks.blocks.Warp: Using PyTorch native grid_sample.\")\n            self._interp_mode = GridSampleMode(mode).value\n\n        # resolves _padding_mode for different methods\n        if USE_COMPILED:\n            if padding_mode in (pad.value for pad in GridSamplePadMode):\n                padding_mode = GridSamplePadMode(padding_mode)\n                if padding_mode == GridSamplePadMode.ZEROS:\n                    padding_mode = 7\n                elif padding_mode == GridSamplePadMode.BORDER:\n                    padding_mode = 0\n                elif padding_mode == GridSamplePadMode.REFLECTION:\n          
          padding_mode = 1\n                else:\n                    padding_mode = 0  # default to nearest\n            self._padding_mode = padding_mode\n        else:\n            self._padding_mode = GridSamplePadMode(padding_mode).value\n\n        self.ref_grid = None\n        self.jitter = jitter\n\n    def get_reference_grid(self, ddf: torch.Tensor, jitter: bool = False, seed: int = 0) -> torch.Tensor:\n        if (\n            self.ref_grid is not None\n            and self.ref_grid.shape[0] == ddf.shape[0]\n            and self.ref_grid.shape[1:] == ddf.shape[2:]\n        ):\n            return self.ref_grid  # type: ignore\n        mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]]\n        grid = torch.stack(meshgrid_ij(*mesh_points), dim=0)  # (spatial_dims, ...)\n        grid = torch.stack([grid] * ddf.shape[0], dim=0)  # (batch, spatial_dims, ...)\n        self.ref_grid = grid.to(ddf)\n        if jitter:\n            # Define reference grid on non-integer values\n            with torch.random.fork_rng(enabled=seed):\n                torch.random.manual_seed(seed)\n                grid += torch.rand_like(grid)\n        self.ref_grid.requires_grad = False\n        return self.ref_grid\n\n    def forward(self, image: torch.Tensor, ddf: torch.Tensor):\n        \"\"\"\n        Args:\n            image: Tensor in shape (batch, num_channels, H, W[, D])\n            ddf: Tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, H, W[, D])\n\n        Returns:\n            warped_image in the same shape as image (batch, num_channels, H, W[, D])\n        \"\"\"\n        spatial_dims = len(image.shape) - 2\n        if spatial_dims not in (2, 3):\n            raise NotImplementedError(f\"got unsupported spatial_dims={spatial_dims}, currently support 2 or 3.\")\n        ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:])\n        if ddf.shape != ddf_shape:\n            raise ValueError(\n                f\"Given 
input {spatial_dims}-d image shape {image.shape}, the input DDF shape must be {ddf_shape}, \"\n                f\"Got {ddf.shape} instead.\"\n            )\n        grid = self.get_reference_grid(ddf, jitter=self.jitter) + ddf\n        grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1])  # (batch, ..., spatial_dims)\n\n        if not USE_COMPILED:  # pytorch native grid_sample\n            for i, dim in enumerate(grid.shape[1:-1]):\n                grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1\n            index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1))\n            grid = grid[..., index_ordering]  # z, y, x -> x, y, z\n            return F.grid_sample(\n                image, grid, mode=self._interp_mode, padding_mode=f\"{self._padding_mode}\", align_corners=True\n            )\n\n        # using csrc resampling\n        return grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode)\n\n\nclass DVF2DDF(nn.Module):\n    \"\"\"\n    Layer calculates a dense displacement field (DDF) from a dense velocity field (DVF)\n    with scaling and squaring.\n\n    Adapted from:\n        DeepReg (https://github.com/DeepRegNet/DeepReg)\n\n    \"\"\"\n\n    def __init__(\n        self, num_steps: int = 7, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.ZEROS.value\n    ):\n        super().__init__()\n        if num_steps <= 0:\n            raise ValueError(f\"expecting positive num_steps, got {num_steps}\")\n        self.num_steps = num_steps\n        self.warp_layer = Warp(mode=mode, padding_mode=padding_mode)\n\n    def forward(self, dvf: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            dvf: dvf to be transformed, in shape (batch, ``spatial_dims``, H, W[,D])\n\n        Returns:\n            a dense displacement field\n        \"\"\"\n        ddf: torch.Tensor = dvf / (2**self.num_steps)\n        for _ in range(self.num_steps):\n            ddf = ddf + 
self.warp_layer(image=ddf, ddf=ddf)\n        return ddf\n"
  },
  {
    "path": "monai/networks/layers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .conjugate_gradient import ConjugateGradient\nfrom .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding\nfrom .drop_path import DropPath\nfrom .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, RelPosEmbedding, split_args\nfrom .filtering import BilateralFilter, PHLFilter, TrainableBilateralFilter, TrainableJointBilateralFilter\nfrom .gmm import GaussianMixtureModel\nfrom .simplelayers import (\n    LLTM,\n    ApplyFilter,\n    ChannelPad,\n    EllipticalFilter,\n    Flatten,\n    GaussianFilter,\n    HilbertTransform,\n    LaplaceFilter,\n    MeanFilter,\n    MedianFilter,\n    Reshape,\n    SavitzkyGolayFilter,\n    SharpenFilter,\n    SkipConnection,\n    apply_filter,\n    median_filter,\n    separable_filtering,\n)\nfrom .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push\nfrom .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer\nfrom .vector_quantizer import EMAQuantizer, VectorQuantizer\nfrom .weight_init import _no_grad_trunc_normal_, trunc_normal_\n"
  },
  {
    "path": "monai/networks/layers/conjugate_gradient.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import Callable\n\nimport torch\nfrom torch import nn\n\n\ndef _zdot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Complex dot product between tensors x1 and x2: sum(x1.*x2)\n    \"\"\"\n    if torch.is_complex(x1):\n        assert torch.is_complex(x2), \"x1 and x2 must both be complex\"\n        return torch.sum(x1.conj() * x2)\n    else:\n        return torch.sum(x1 * x2)\n\n\ndef _zdot_single(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Complex dot product between tensor x and itself\n    \"\"\"\n    res = _zdot(x, x)\n    if torch.is_complex(res):\n        return res.real\n    else:\n        return res\n\n\nclass ConjugateGradient(nn.Module):\n    \"\"\"\n    Congugate Gradient (CG) solver for linear systems Ax = y.\n\n    For linear_op that is positive definite and self-adjoint, CG is\n    guaranteed to converge CG is often used to solve linear systems of the form\n    Ax = y, where A is too large to store explicitly, but can be computed via a\n    linear operator.\n\n    As a result, here we won't set A explicitly as a matrix, but rather as a\n    linear operator. 
For example, A could be a FFT/IFFT operation\n    \"\"\"\n\n    def __init__(self, linear_op: Callable, num_iter: int):\n        \"\"\"\n        Args:\n            linear_op: Linear operator\n            num_iter: Number of iterations to run CG\n        \"\"\"\n        super().__init__()\n\n        self.linear_op = linear_op\n        self.num_iter = num_iter\n\n    def update(\n        self, x: torch.Tensor, p: torch.Tensor, r: torch.Tensor, rsold: torch.Tensor\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"\n        perform one iteration of the CG method. It takes the current solution x,\n        the current search direction p, the current residual r, and the old\n        residual norm rsold as inputs. Then it computes the new solution, search\n        direction, residual, and residual norm, and returns them.\n        \"\"\"\n\n        dy = self.linear_op(p)\n        p_dot_dy = _zdot(p, dy)\n        alpha = rsold / p_dot_dy\n        x = x + alpha * p\n        r = r - alpha * dy\n        rsnew = _zdot_single(r)\n        beta = rsnew / rsold\n        rsold = rsnew\n        p = beta * p + r\n        return x, p, r, rsold\n\n    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        run conjugate gradient for num_iter iterations to solve Ax = y\n\n        Args:\n            x: tensor (real or complex); Initial guess for linear system Ax = y.\n            The size of x should be applicable to the linear operator. For\n            example, if the linear operator is FFT, then x is HCHW; if the\n            linear operator is a matrix multiplication, then x is a vector\n\n            y: tensor (real or complex); Measurement. 
Same size as x\n\n        Returns:\n            x: Solution to Ax = y\n        \"\"\"\n        # Compute residual\n        r = y - self.linear_op(x)\n        rsold = _zdot_single(r)\n        p = r\n\n        # Update\n        for _i in range(self.num_iter):\n            x, p, r, rsold = self.update(x, p, r, rsold)\n            if rsold < 1e-10:\n                break\n        return x\n"
  },
  {
    "path": "monai/networks/layers/convutils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport numpy as np\nimport torch\n\n__all__ = [\"same_padding\", \"stride_minus_kernel_padding\", \"calculate_out_shape\", \"gaussian_1d\", \"polyval\"]\n\n\ndef same_padding(kernel_size: Sequence[int] | int, dilation: Sequence[int] | int = 1) -> tuple[int, ...] | int:\n    \"\"\"\n    Return the padding value needed to ensure a convolution using the given kernel size produces an output of the same\n    shape as the input for a stride of 1, otherwise ensure a shape of the input divided by the stride rounded down.\n\n    Raises:\n        NotImplementedError: When ``np.any((kernel_size - 1) * dilation % 2 == 1)``.\n\n    \"\"\"\n\n    kernel_size_np = np.atleast_1d(kernel_size)\n    dilation_np = np.atleast_1d(dilation)\n\n    if np.any((kernel_size_np - 1) * dilation % 2 == 1):\n        raise NotImplementedError(\n            f\"Same padding not available for kernel_size={kernel_size_np} and dilation={dilation_np}.\"\n        )\n\n    padding_np = (kernel_size_np - 1) / 2 * dilation_np\n    padding = tuple(int(p) for p in padding_np)\n\n    return padding if len(padding) > 1 else padding[0]\n\n\ndef stride_minus_kernel_padding(kernel_size: Sequence[int] | int, stride: Sequence[int] | int) -> tuple[int, ...] 
| int:\n    kernel_size_np = np.atleast_1d(kernel_size)\n    stride_np = np.atleast_1d(stride)\n\n    out_padding_np = stride_np - kernel_size_np\n    out_padding = tuple(int(p) for p in out_padding_np)\n\n    return out_padding if len(out_padding) > 1 else out_padding[0]\n\n\ndef calculate_out_shape(\n    in_shape: Sequence[int] | int | np.ndarray,\n    kernel_size: Sequence[int] | int,\n    stride: Sequence[int] | int,\n    padding: Sequence[int] | int,\n) -> tuple[int, ...] | int:\n    \"\"\"\n    Calculate the output tensor shape when applying a convolution to a tensor of shape `in_shape` with kernel size\n    `kernel_size`, stride value `stride`, and input padding value `padding`. All arguments can be scalars or multiple\n    values; the return value is always a tuple of ints, even when all inputs are scalars.\n    \"\"\"\n    in_shape_np = np.atleast_1d(in_shape)\n    kernel_size_np = np.atleast_1d(kernel_size)\n    stride_np = np.atleast_1d(stride)\n    padding_np = np.atleast_1d(padding)\n\n    out_shape_np = ((in_shape_np - kernel_size_np + padding_np + padding_np) // stride_np) + 1\n    out_shape = tuple(int(s) for s in out_shape_np)\n\n    return out_shape\n\n\ndef gaussian_1d(\n    sigma: torch.Tensor, truncated: float = 4.0, approx: str = \"erf\", normalize: bool = False\n) -> torch.Tensor:\n    \"\"\"\n    one dimensional Gaussian kernel.\n\n    Args:\n        sigma: std of the kernel\n        truncated: tail length\n        approx: discrete Gaussian kernel type, available options are \"erf\", \"sampled\", and \"scalespace\".\n\n            - ``erf`` approximation interpolates the error function;\n            - ``sampled`` uses a sampled Gaussian kernel;\n            - ``scalespace`` corresponds to\n              https://en.wikipedia.org/wiki/Scale_space_implementation#The_discrete_Gaussian_kernel\n              based on the modified Bessel functions.\n\n        normalize: whether to normalize the kernel with `kernel.sum()`.\n\n    Raises:\n        ValueError: When 
``truncated`` is non-positive.\n\n    Returns:\n        1D torch tensor\n\n    \"\"\"\n    sigma = torch.as_tensor(sigma, dtype=torch.float, device=sigma.device if isinstance(sigma, torch.Tensor) else None)\n    device = sigma.device\n    if truncated <= 0.0:\n        raise ValueError(f\"truncated must be positive, got {truncated}.\")\n    tail = int(max(float(sigma) * truncated, 0.5) + 0.5)\n    if approx.lower() == \"erf\":\n        x = torch.arange(-tail, tail + 1, dtype=torch.float, device=device)\n        t = 0.70710678 / torch.abs(sigma)\n        out = 0.5 * ((t * (x + 0.5)).erf() - (t * (x - 0.5)).erf())\n        out = out.clamp(min=0)\n    elif approx.lower() == \"sampled\":\n        x = torch.arange(-tail, tail + 1, dtype=torch.float, device=sigma.device)\n        out = torch.exp(-0.5 / (sigma * sigma) * x**2)\n        if not normalize:  # compute the normalizer\n            out = out / (2.5066282 * sigma)\n    elif approx.lower() == \"scalespace\":\n        sigma2 = sigma * sigma\n        out_pos: list[torch.Tensor | None] = [None] * (tail + 1)\n        out_pos[0] = _modified_bessel_0(sigma2)\n        out_pos[1] = _modified_bessel_1(sigma2)\n        for k in range(2, len(out_pos)):\n            out_pos[k] = _modified_bessel_i(k, sigma2)\n        out = out_pos[:0:-1]\n        out.extend(out_pos)\n        out = torch.stack(out) * torch.exp(-sigma2)\n    else:\n        raise NotImplementedError(f\"Unsupported option: approx='{approx}'.\")\n    return out / out.sum() if normalize else out  # type: ignore\n\n\ndef polyval(coef, x) -> torch.Tensor:\n    \"\"\"\n    Evaluates the polynomial defined by `coef` at `x`.\n\n    For a 1D sequence of coef (length n), evaluate::\n\n        y = coef[n-1] + x * (coef[n-2] + ... 
+ x * (coef[1] + x * coef[0]))\n\n    Args:\n        coef: a sequence of floats representing the coefficients of the polynomial\n        x: float or a sequence of floats representing the variable of the polynomial\n\n    Returns:\n        1D torch tensor\n    \"\"\"\n    device = x.device if isinstance(x, torch.Tensor) else None\n    coef = torch.as_tensor(coef, dtype=torch.float, device=device)\n    if coef.ndim == 0 or (len(coef) < 1):\n        return torch.zeros(x.shape)\n    x = torch.as_tensor(x, dtype=torch.float, device=device)\n    ans = coef[0]\n    for c in coef[1:]:\n        ans = ans * x + c\n    return ans  # type: ignore\n\n\ndef _modified_bessel_0(x: torch.Tensor) -> torch.Tensor:\n    x = torch.as_tensor(x, dtype=torch.float, device=x.device if isinstance(x, torch.Tensor) else None)\n    if torch.abs(x) < 3.75:\n        y = x * x / 14.0625\n        return polyval([0.45813e-2, 0.360768e-1, 0.2659732, 1.2067492, 3.0899424, 3.5156229, 1.0], y)\n    ax = torch.abs(x)\n    y = 3.75 / ax\n    _coef = [\n        0.392377e-2,\n        -0.1647633e-1,\n        0.2635537e-1,\n        -0.2057706e-1,\n        0.916281e-2,\n        -0.157565e-2,\n        0.225319e-2,\n        0.1328592e-1,\n        0.39894228,\n    ]\n    return polyval(_coef, y) * torch.exp(ax) / torch.sqrt(ax)\n\n\ndef _modified_bessel_1(x: torch.Tensor) -> torch.Tensor:\n    x = torch.as_tensor(x, dtype=torch.float, device=x.device if isinstance(x, torch.Tensor) else None)\n    if torch.abs(x) < 3.75:\n        y = x * x / 14.0625\n        _coef = [0.32411e-3, 0.301532e-2, 0.2658733e-1, 0.15084934, 0.51498869, 0.87890594, 0.5]\n        return torch.abs(x) * polyval(_coef, y)\n    ax = torch.abs(x)\n    y = 3.75 / ax\n    _coef = [\n        -0.420059e-2,\n        0.1787654e-1,\n        -0.2895312e-1,\n        0.2282967e-1,\n        -0.1031555e-1,\n        0.163801e-2,\n        -0.362018e-2,\n        -0.3988024e-1,\n        0.39894228,\n    ]\n    ans = polyval(_coef, y) * torch.exp(ax) / 
torch.sqrt(ax)\n    return -ans if x < 0.0 else ans\n\n\ndef _modified_bessel_i(n: int, x: torch.Tensor) -> torch.Tensor:\n    if n < 2:\n        raise ValueError(f\"n must be greater than 1, got n={n}.\")\n    x = torch.as_tensor(x, dtype=torch.float, device=x.device if isinstance(x, torch.Tensor) else None)\n    if x == 0.0:\n        return x\n    device = x.device\n    tox = 2.0 / torch.abs(x)\n    ans, bip, bi = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)\n    m = int(2 * (n + np.floor(np.sqrt(40.0 * n))))\n    for j in range(m, 0, -1):\n        bim = bip + float(j) * tox * bi\n        bip = bi\n        bi = bim\n        if abs(bi) > 1.0e10:\n            ans = ans * 1.0e-10\n            bi = bi * 1.0e-10\n            bip = bip * 1.0e-10\n        if j == n:\n            ans = bip\n    ans = ans * _modified_bessel_0(x) / bi\n    return -ans if x < 0.0 and (n % 2) == 1 else ans\n"
  },
  {
    "path": "monai/networks/layers/drop_path.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch.nn as nn\n\n\nclass DropPath(nn.Module):\n    \"\"\"Stochastic drop paths per sample for residual blocks.\n    Based on:\n    https://github.com/rwightman/pytorch-image-models\n    \"\"\"\n\n    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None:\n        \"\"\"\n        Args:\n            drop_prob: drop path probability.\n            scale_by_keep: scaling by non-dropped probability.\n        \"\"\"\n        super().__init__()\n        self.drop_prob = drop_prob\n        self.scale_by_keep = scale_by_keep\n\n        if not (0 <= drop_prob <= 1):\n            raise ValueError(\"Drop path prob should be between 0 and 1.\")\n\n    def drop_path(self, x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):\n        if drop_prob == 0.0 or not training:\n            return x\n        keep_prob = 1 - drop_prob\n        shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n        if keep_prob > 0.0 and scale_by_keep:\n            random_tensor.div_(keep_prob)\n        return x * random_tensor\n\n    def forward(self, x):\n        return self.drop_path(x, self.drop_prob, self.training, self.scale_by_keep)\n"
  },
  {
    "path": "monai/networks/layers/factories.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nDefines factories for creating layers in generic, extensible, and dimensionally independent ways. A separate factory\nobject is created for each type of layer, and factory functions keyed to names are added to these objects. Whenever\na layer is requested the factory name and any necessary arguments are passed to the factory object. The return value\nis typically a type but can be any callable producing a layer object.\n\nThe factory objects contain functions keyed to names converted to upper case, these names can be referred to as members\nof the factory so that they can function as constant identifiers. eg. instance normalization is named `Norm.INSTANCE`.\n\nFor example, to get a transpose convolution layer the name is needed and then a dimension argument is provided which is\npassed to the factory function:\n\n.. code-block:: python\n\n    dimension = 3\n    name = Conv.CONVTRANS\n    conv = Conv[name, dimension]\n\nThis allows the `dimension` value to be set in the constructor, for example so that the dimensionality of a network is\nparameterizable. Not all factories require arguments after the name, the caller must be aware which are required.\n\nDefining new factories involves creating the object then associating it with factory functions:\n\n.. 
code-block:: python\n\n    fact = LayerFactory(name='test', description='example factory')\n\n    @fact.factory_function('test')\n    def make_something(x, y):\n        # do something with x and y to choose which layer type to return\n        return SomeLayerType\n    ...\n\n    # request object from factory TEST with 1 and 2 as values for x and y\n    layer = fact[fact.TEST, 1, 2]\n\nTypically the caller of a factory would know what arguments to pass (ie. the dimensionality of the requested type) but\ncan be parameterized with the factory name and the arguments to pass to the created type at instantiation time:\n\n.. code-block:: python\n\n    def use_factory(fact_args):\n        fact_name, type_args = split_args(fact_args)\n        layer_type = fact[fact_name, 1, 2]\n        return layer_type(**type_args)\n    ...\n\n    kw_args = {'arg0':0, 'arg1':True}\n    layer = use_factory( (fact.TEST, kw_args) )\n\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable\nfrom typing import Any\n\nimport torch.nn as nn\n\nfrom monai.networks.utils import has_nvfuser_instance_norm\nfrom monai.utils import ComponentStore, look_up_option, optional_import\n\n__all__ = [\"LayerFactory\", \"Dropout\", \"Norm\", \"Act\", \"Conv\", \"Pool\", \"Pad\", \"RelPosEmbedding\", \"split_args\"]\n\n\nclass LayerFactory(ComponentStore):\n    \"\"\"\n    Factory object for creating layers, this uses given factory functions to actually produce the types or constructing\n    callables. 
These functions are referred to by name and can be added at any time.\n    \"\"\"\n\n    def __init__(self, name: str, description: str) -> None:\n        super().__init__(name, description)\n        self.__doc__ = (\n            f\"Layer Factory '{name}': {description}\\n\".strip()\n            + \"\\nPlease see :py:class:`monai.networks.layers.split_args` for additional args parsing.\"\n            + \"\\n\\nThe supported members are:\"\n        )\n\n    def add_factory_callable(self, name: str, func: Callable, desc: str | None = None) -> None:\n        \"\"\"\n        Add the factory function to this object under the given name, with optional description.\n        \"\"\"\n        description: str = desc or func.__doc__ or \"\"\n        self.add(name.upper(), description, func)\n        # append name to the docstring\n        assert self.__doc__ is not None\n        self.__doc__ += f\"{', ' if len(self.names) > 1 else ' '}``{name}``\"\n\n    def add_factory_class(self, name: str, cls: type, desc: str | None = None) -> None:\n        \"\"\"\n        Adds a factory function which returns the supplied class under the given name, with optional description.\n        \"\"\"\n        self.add_factory_callable(name, lambda x=None: cls, desc)\n\n    def factory_function(self, name: str) -> Callable:\n        \"\"\"\n        Decorator for adding a factory function with the given name.\n        \"\"\"\n\n        def _add(func: Callable) -> Callable:\n            self.add_factory_callable(name, func)\n            return func\n\n        return _add\n\n    def get_constructor(self, factory_name: str, *args) -> Any:\n        \"\"\"\n        Get the constructor for the given factory name and arguments.\n\n        Raises:\n            TypeError: When ``factory_name`` is not a ``str``.\n\n        \"\"\"\n\n        if not isinstance(factory_name, str):\n            raise TypeError(f\"factory_name must a str but is {type(factory_name).__name__}.\")\n\n        component = 
look_up_option(factory_name.upper(), self.components)\n\n        return component.value(*args)\n\n    def __getitem__(self, args) -> Any:\n        \"\"\"\n        Get the given name or name/arguments pair. If `args` is a callable it is assumed to be the constructor\n        itself and is returned, otherwise it should be the factory name or a pair containing the name and arguments.\n        \"\"\"\n\n        # `args[0]` is actually a type or constructor\n        if callable(args):\n            return args\n\n        # `args` is a factory name or a name with arguments\n        if isinstance(args, str):\n            name_obj, args = args, ()\n        else:\n            name_obj, *args = args\n\n        return self.get_constructor(name_obj, *args)\n\n    def __getattr__(self, key):\n        \"\"\"\n        If `key` is a factory name, return it, otherwise behave as inherited. This allows referring to factory names\n        as if they were constants, eg. `Fact.FOO` for a factory Fact with factory function foo.\n        \"\"\"\n\n        if key in self.components:\n            return key\n\n        return super().__getattribute__(key)\n\n\ndef split_args(args):\n    \"\"\"\n    Split arguments in a way to be suitable for using with the factory types. 
If `args` is a string it's interpreted as\n    the type name.\n\n    Args:\n        args (str or a tuple of object name and kwarg dict): input arguments to be parsed.\n\n    Raises:\n        TypeError: When ``args`` type is not in ``Union[str, Tuple[Union[str, Callable], dict]]``.\n\n    Examples::\n\n        >>> act_type, args = split_args(\"PRELU\")\n        >>> monai.networks.layers.Act[act_type]\n        <class 'torch.nn.modules.activation.PReLU'>\n\n        >>> act_type, args = split_args((\"PRELU\", {\"num_parameters\": 1, \"init\": 0.25}))\n        >>> monai.networks.layers.Act[act_type](**args)\n        PReLU(num_parameters=1)\n\n    \"\"\"\n\n    if isinstance(args, str):\n        return args, {}\n    name_obj, name_args = args\n\n    if not (isinstance(name_obj, str) or callable(name_obj)) or not isinstance(name_args, dict):\n        msg = \"Layer specifiers must be single strings or pairs of the form (name/object-types, argument dict)\"\n        raise TypeError(msg)\n\n    return name_obj, name_args\n\n\n# Define factories for these layer types\nDropout = LayerFactory(name=\"Dropout layers\", description=\"Factory for creating dropout layers.\")\nNorm = LayerFactory(name=\"Normalization layers\", description=\"Factory for creating normalization layers.\")\nAct = LayerFactory(name=\"Activation layers\", description=\"Factory for creating activation layers.\")\nConv = LayerFactory(name=\"Convolution layers\", description=\"Factory for creating convolution layers.\")\nPool = LayerFactory(name=\"Pooling layers\", description=\"Factory for creating pooling layers.\")\nPad = LayerFactory(name=\"Padding layers\", description=\"Factory for creating padding layers.\")\nRelPosEmbedding = LayerFactory(\n    name=\"Relative positional embedding layers\",\n    description=\"Factory for creating relative positional embedding factory\",\n)\n\n\n@Dropout.factory_function(\"dropout\")\ndef dropout_factory(dim: int) -> type[nn.Dropout | nn.Dropout2d | nn.Dropout3d]:\n    
\"\"\"\n    Dropout layers in 1,2,3 dimensions.\n\n    Args:\n        dim: desired dimension of the dropout layer\n\n    Returns:\n        Dropout[dim]d\n    \"\"\"\n    types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)\n    return types[dim - 1]\n\n\nDropout.add_factory_class(\"alphadropout\", nn.AlphaDropout)\n\n\n@Norm.factory_function(\"instance\")\ndef instance_factory(dim: int) -> type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]:\n    \"\"\"\n    Instance normalization layers in 1,2,3 dimensions.\n\n    Args:\n        dim: desired dimension of the instance normalization layer\n\n    Returns:\n        InstanceNorm[dim]d\n    \"\"\"\n    types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)\n    return types[dim - 1]\n\n\n@Norm.factory_function(\"batch\")\ndef batch_factory(dim: int) -> type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]:\n    \"\"\"\n    Batch normalization layers in 1,2,3 dimensions.\n\n    Args:\n        dim: desired dimension of the batch normalization layer\n\n    Returns:\n        BatchNorm[dim]d\n    \"\"\"\n    types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)\n    return types[dim - 1]\n\n\n@Norm.factory_function(\"instance_nvfuser\")\ndef instance_nvfuser_factory(dim):\n    \"\"\"\n    `InstanceNorm3dNVFuser` is a faster version of InstanceNorm layer and implemented in `apex`.\n    It only supports 3d tensors as the input. 
It also requires to use with CUDA and non-Windows OS.\n    In this function, if the required library `apex.normalization.InstanceNorm3dNVFuser` does not exist,\n    `nn.InstanceNorm3d` will be returned instead.\n    This layer is based on a customized autograd function, which is not supported in TorchScript currently.\n    Please switch to use `nn.InstanceNorm3d` if TorchScript is necessary.\n\n    Please check the following link for more details about how to install `apex`:\n    https://github.com/NVIDIA/apex#installation\n\n    \"\"\"\n\n    if dim != 3:\n        types = (nn.InstanceNorm1d, nn.InstanceNorm2d)\n        warnings.warn(f\"`InstanceNorm3dNVFuser` only supports 3d cases, use {types[dim - 1]} instead.\")\n        return types[dim - 1]\n\n    if not has_nvfuser_instance_norm():\n        warnings.warn(\n            \"`apex.normalization.InstanceNorm3dNVFuser` is not installed properly, use nn.InstanceNorm3d instead.\"\n        )\n        return nn.InstanceNorm3d\n    return optional_import(\"apex.normalization\", name=\"InstanceNorm3dNVFuser\")[0]\n\n\nNorm.add_factory_class(\"group\", nn.GroupNorm)\nNorm.add_factory_class(\"layer\", nn.LayerNorm)\nNorm.add_factory_class(\"localresponse\", nn.LocalResponseNorm)\nNorm.add_factory_class(\"syncbatch\", nn.SyncBatchNorm)\n\nAct.add_factory_class(\"elu\", nn.modules.ELU)\nAct.add_factory_class(\"relu\", nn.modules.ReLU)\nAct.add_factory_class(\"leakyrelu\", nn.modules.LeakyReLU)\nAct.add_factory_class(\"prelu\", nn.modules.PReLU)\nAct.add_factory_class(\"relu6\", nn.modules.ReLU6)\nAct.add_factory_class(\"selu\", nn.modules.SELU)\nAct.add_factory_class(\"celu\", nn.modules.CELU)\nAct.add_factory_class(\"gelu\", nn.modules.GELU)\nAct.add_factory_class(\"sigmoid\", nn.modules.Sigmoid)\nAct.add_factory_class(\"tanh\", nn.modules.Tanh)\nAct.add_factory_class(\"softmax\", nn.modules.Softmax)\nAct.add_factory_class(\"logsoftmax\", nn.modules.LogSoftmax)\n\n\n@Act.factory_function(\"swish\")\ndef swish_factory():\n   
 \"\"\"\n    Swish activation layer.\n\n    Returns:\n        Swish\n    \"\"\"\n    from monai.networks.blocks.activation import Swish\n\n    return Swish\n\n\n@Act.factory_function(\"memswish\")\ndef memswish_factory():\n    \"\"\"\n    Memory efficient swish activation layer.\n\n    Returns:\n        MemoryEfficientSwish\n    \"\"\"\n    from monai.networks.blocks.activation import MemoryEfficientSwish\n\n    return MemoryEfficientSwish\n\n\n@Act.factory_function(\"mish\")\ndef mish_factory():\n    \"\"\"\n    Mish activation layer.\n\n    Returns:\n        Mish\n    \"\"\"\n    from monai.networks.blocks.activation import Mish\n\n    return Mish\n\n\n@Act.factory_function(\"geglu\")\ndef geglu_factory():\n    \"\"\"\n    GEGLU activation layer.\n\n    Returns:\n        GEGLU\n    \"\"\"\n    from monai.networks.blocks.activation import GEGLU\n\n    return GEGLU\n\n\n@Conv.factory_function(\"conv\")\ndef conv_factory(dim: int) -> type[nn.Conv1d | nn.Conv2d | nn.Conv3d]:\n    \"\"\"\n    Convolutional layers in 1,2,3 dimensions.\n\n    Args:\n        dim: desired dimension of the convolutional layer\n\n    Returns:\n        Conv[dim]d\n    \"\"\"\n    types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)\n    return types[dim - 1]\n\n\n@Conv.factory_function(\"convtrans\")\ndef convtrans_factory(dim: int) -> type[nn.ConvTranspose1d | nn.ConvTranspose2d | nn.ConvTranspose3d]:\n    \"\"\"\n    Transposed convolutional layers in 1,2,3 dimensions.\n\n    Args:\n        dim: desired dimension of the transposed convolutional layer\n\n    Returns:\n        ConvTranspose[dim]d\n    \"\"\"\n    types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)\n    return types[dim - 1]\n\n\n@Pool.factory_function(\"max\")\ndef maxpooling_factory(dim: int) -> type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d]:\n    \"\"\"\n    Max pooling layers in 1,2,3 dimensions.\n\n    Args:\n        dim: desired dimension of the max pooling layer\n\n    Returns:\n        MaxPool[dim]d\n    
\"\"\"\n    types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)\n    return types[dim - 1]\n\n\n@Pool.factory_function(\"adaptivemax\")\ndef adaptive_maxpooling_factory(dim: int) -> type[nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d]:\n    \"\"\"\n    Adaptive max pooling layers in 1,2,3 dimensions.\n\n    Args:\n        dim: desired dimension of the adaptive max pooling layer\n\n    Returns:\n        AdaptiveMaxPool[dim]d\n    \"\"\"\n    types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d)\n    return types[dim - 1]\n\n\n@Pool.factory_function(\"avg\")\ndef avgpooling_factory(dim: int) -> type[nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d]:\n    \"\"\"\n    Average pooling layers in 1,2,3 dimensions.\n\n    Args:\n        dim: desired dimension of the average pooling layer\n\n    Returns:\n        AvgPool[dim]d\n    \"\"\"\n    types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)\n    return types[dim - 1]\n\n\n@Pool.factory_function(\"adaptiveavg\")\ndef adaptive_avgpooling_factory(dim: int) -> type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d]:\n    \"\"\"\n    Adaptive average pooling layers in 1,2,3 dimensions.\n\n    Args:\n        dim: desired dimension of the adaptive average pooling layer\n\n    Returns:\n        AdaptiveAvgPool[dim]d\n    \"\"\"\n    types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d)\n    return types[dim - 1]\n\n\n@Pad.factory_function(\"replicationpad\")\ndef replication_pad_factory(dim: int) -> type[nn.ReplicationPad1d | nn.ReplicationPad2d | nn.ReplicationPad3d]:\n    \"\"\"\n    Replication padding layers in 1,2,3 dimensions.\n\n    Args:\n        dim: desired dimension of the replication padding layer\n\n    Returns:\n        ReplicationPad[dim]d\n    \"\"\"\n    types = (nn.ReplicationPad1d, nn.ReplicationPad2d, nn.ReplicationPad3d)\n    return types[dim - 1]\n\n\n@Pad.factory_function(\"constantpad\")\ndef constant_pad_factory(dim: int) -> 
type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]:\n    \"\"\"\n    Constant padding layers in 1,2,3 dimensions.\n\n    Args:\n        dim: desired dimension of the constant padding layer\n\n    Returns:\n        ConstantPad[dim]d\n    \"\"\"\n    types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)\n    return types[dim - 1]\n\n\n@RelPosEmbedding.factory_function(\"decomposed\")\ndef decomposed_rel_pos_embedding() -> type[nn.Module]:\n    from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding\n\n    return DecomposedRelativePosEmbedding\n"
  },
  {
    "path": "monai/networks/layers/filtering.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom monai.utils.module import optional_import\n\n_C, _ = optional_import(\"monai._C\")\n\n__all__ = [\"BilateralFilter\", \"PHLFilter\", \"TrainableBilateralFilter\", \"TrainableJointBilateralFilter\"]\n\n\nclass BilateralFilter(torch.autograd.Function):\n    \"\"\"\n    Blurs the input tensor spatially whilst preserving edges. Can run on 1D, 2D, or 3D,\n    tensors (on top of Batch and Channel dimensions). Two implementations are provided,\n    an exact solution and a much faster approximation which uses a permutohedral lattice.\n\n    See:\n        https://en.wikipedia.org/wiki/Bilateral_filter\n        https://graphics.stanford.edu/papers/permutohedral/\n\n    Args:\n        input: input tensor.\n        spatial_sigma: the standard deviation of the spatial blur. Higher values can\n            hurt performance when not using the approximate method (see fast approx).\n        color_sigma: the standard deviation of the color blur. Lower values preserve\n            edges better whilst higher values tend to a simple gaussian spatial blur.\n        fast approx: This flag chooses between two implementations. 
The approximate method may\n            produce artifacts in some scenarios whereas the exact solution may be intolerably\n            slow for high spatial standard deviations.\n\n    Returns:\n        output (torch.Tensor): output tensor.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True):\n        \"\"\"autograd forward\"\"\"\n        ctx.ss = spatial_sigma\n        ctx.cs = color_sigma\n        ctx.fa = fast_approx\n        output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        return output_data\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        \"\"\"autograd backward\"\"\"\n        spatial_sigma, color_sigma, fast_approx = ctx.ss, ctx.cs, ctx.fa\n        grad_input = _C.bilateral_filter(grad_output, spatial_sigma, color_sigma, fast_approx)\n        return grad_input, None, None, None\n\n\nclass PHLFilter(torch.autograd.Function):\n    \"\"\"\n    Filters input based on arbitrary feature vectors. Uses a permutohedral\n    lattice data structure to efficiently approximate n-dimensional gaussian\n    filtering. Complexity is broadly independent of kernel size. Most applicable\n    to higher filter dimensions and larger kernel sizes.\n\n    See:\n        https://graphics.stanford.edu/papers/permutohedral/\n\n    Args:\n        input: input tensor to be filtered.\n        features: feature tensor used to filter the input.\n        sigmas: the standard deviations of each feature in the filter.\n\n    Returns:\n        output (torch.Tensor): output tensor.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, input, features, sigmas=None):\n        scaled_features = features\n        if sigmas is not None:\n            for i in range(features.size(1)):\n                scaled_features[:, i, ...] 
/= sigmas[i]\n\n        ctx.save_for_backward(scaled_features)\n        output_data = _C.phl_filter(input, scaled_features)\n        return output_data\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        raise NotImplementedError(\"PHLFilter does not currently support Backpropagation\")\n        # scaled_features, = ctx.saved_variables\n        # grad_input = _C.phl_filter(grad_output, scaled_features)\n        # return grad_input\n\n\nclass TrainableBilateralFilterFunction(torch.autograd.Function):\n    \"\"\"\n    torch.autograd.Function for the TrainableBilateralFilter layer.\n\n    See:\n        F. Wagner, et al., Ultralow-parameter denoising: Trainable bilateral filter layers in\n        computed tomography, Medical Physics (2022), https://doi.org/10.1002/mp.15718\n\n    Args:\n        input: input tensor to be filtered.\n        sigma_x: trainable standard deviation of the spatial filter kernel in x direction.\n        sigma_y: trainable standard deviation of the spatial filter kernel in y direction.\n        sigma_z: trainable standard deviation of the spatial filter kernel in z direction.\n        color_sigma: trainable standard deviation of the intensity range kernel. 
This filter\n            parameter determines the degree of edge preservation.\n\n    Returns:\n        output (torch.Tensor): filtered tensor.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, input_img, sigma_x, sigma_y, sigma_z, color_sigma):\n        output_tensor, output_weights_tensor, do_dx_ki, do_dsig_r, do_dsig_x, do_dsig_y, do_dsig_z = _C.tbf_forward(\n            input_img, sigma_x, sigma_y, sigma_z, color_sigma\n        )\n\n        ctx.save_for_backward(\n            input_img,\n            sigma_x,\n            sigma_y,\n            sigma_z,\n            color_sigma,\n            output_tensor,\n            output_weights_tensor,\n            do_dx_ki,\n            do_dsig_r,\n            do_dsig_x,\n            do_dsig_y,\n            do_dsig_z,\n        )\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        return output_tensor\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input_img = ctx.saved_tensors[0]  # input image\n        sigma_x = ctx.saved_tensors[1]\n        sigma_y = ctx.saved_tensors[2]\n        sigma_z = ctx.saved_tensors[3]\n        color_sigma = ctx.saved_tensors[4]\n        output_tensor = ctx.saved_tensors[5]  # filtered image\n        output_weights_tensor = ctx.saved_tensors[6]  # weights\n        do_dx_ki = ctx.saved_tensors[7]  # derivative of output with respect to input, while k==i\n        do_dsig_r = ctx.saved_tensors[8]  # derivative of output with respect to range sigma\n        do_dsig_x = ctx.saved_tensors[9]  # derivative of output with respect to sigma x\n        do_dsig_y = ctx.saved_tensors[10]  # derivative of output with respect to sigma y\n        do_dsig_z = ctx.saved_tensors[11]  # derivative of output with respect to sigma z\n\n        # calculate gradient with respect to the sigmas\n        grad_color_sigma = torch.sum(grad_output * do_dsig_r)\n        grad_sig_x = torch.sum(grad_output * do_dsig_x)\n        grad_sig_y = torch.sum(grad_output * 
do_dsig_y)\n        grad_sig_z = torch.sum(grad_output * do_dsig_z)\n\n        grad_output_tensor = _C.tbf_backward(\n            grad_output,\n            input_img,\n            output_tensor,\n            output_weights_tensor,\n            do_dx_ki,\n            sigma_x,\n            sigma_y,\n            sigma_z,\n            color_sigma,\n        )\n\n        return grad_output_tensor, grad_sig_x, grad_sig_y, grad_sig_z, grad_color_sigma\n\n\nclass TrainableBilateralFilter(torch.nn.Module):\n    \"\"\"\n    Implementation of a trainable bilateral filter layer as proposed in the corresponding publication.\n    All filter parameters can be trained data-driven. The spatial filter kernels x, y, and z determine\n    image smoothing whereas the color parameter specifies the amount of edge preservation.\n    Can run on 1D, 2D, or 3D tensors (on top of Batch and Channel dimensions).\n\n    See:\n        F. Wagner, et al., Ultralow-parameter denoising: Trainable bilateral filter layers in\n        computed tomography, Medical Physics (2022), https://doi.org/10.1002/mp.15718\n\n    Args:\n        input: input tensor to be filtered.\n        spatial_sigma: tuple (sigma_x, sigma_y, sigma_z) initializing the trainable standard\n            deviations of the spatial filter kernels. Tuple length must equal the number of\n            spatial input dimensions.\n        color_sigma: trainable standard deviation of the intensity range kernel. 
This filter\n            parameter determines the degree of edge preservation.\n\n    Returns:\n        output (torch.Tensor): filtered tensor.\n    \"\"\"\n\n    def __init__(self, spatial_sigma, color_sigma):\n        super().__init__()\n\n        if isinstance(spatial_sigma, float):\n            spatial_sigma = [spatial_sigma, spatial_sigma, spatial_sigma]\n            self.len_spatial_sigma = 3\n        elif len(spatial_sigma) == 1:\n            spatial_sigma = [spatial_sigma[0], 0.01, 0.01]\n            self.len_spatial_sigma = 1\n        elif len(spatial_sigma) == 2:\n            spatial_sigma = [spatial_sigma[0], spatial_sigma[1], 0.01]\n            self.len_spatial_sigma = 2\n        elif len(spatial_sigma) == 3:\n            spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]]\n            self.len_spatial_sigma = 3\n        else:\n            raise ValueError(\n                f\"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}.\"\n            )\n\n        # Register sigmas as trainable parameters.\n        self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0]))\n        self.sigma_y = torch.nn.Parameter(torch.tensor(spatial_sigma[1]))\n        self.sigma_z = torch.nn.Parameter(torch.tensor(spatial_sigma[2]))\n        self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma))\n\n    def forward(self, input_tensor):\n        if input_tensor.shape[1] != 1:\n            raise ValueError(\n                f\"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. 
\"\n                \"Please use multiple parallel filter layers if you want \"\n                \"to filter multiple channels.\"\n            )\n\n        len_input = len(input_tensor.shape)\n\n        # C++ extension so far only supports 5-dim inputs.\n        if len_input == 3:\n            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)\n        elif len_input == 4:\n            input_tensor = input_tensor.unsqueeze(4)\n\n        if self.len_spatial_sigma != len_input:\n            raise ValueError(f\"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).\")\n\n        prediction = TrainableBilateralFilterFunction.apply(\n            input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color\n        )\n\n        # Make sure to return tensor of the same shape as the input.\n        if len_input == 3:\n            prediction = prediction.squeeze(4).squeeze(3)\n        elif len_input == 4:\n            prediction = prediction.squeeze(4)\n\n        return prediction\n\n\nclass TrainableJointBilateralFilterFunction(torch.autograd.Function):\n    \"\"\"\n    torch.autograd.Function for the TrainableJointBilateralFilter layer.\n\n    See:\n        F. Wagner, et al., Trainable joint bilateral filters for enhanced prediction stability in\n        low-dose CT, Scientific Reports (2022), https://doi.org/10.1038/s41598-022-22530-4\n\n    Args:\n        input: input tensor to be filtered.\n        guide: guidance image tensor to be used during filtering.\n        sigma x: trainable standard deviation of the spatial filter kernel in x direction.\n        sigma y: trainable standard deviation of the spatial filter kernel in y direction.\n        sigma z: trainable standard deviation of the spatial filter kernel in z direction.\n        color sigma: trainable standard deviation of the intensity range kernel. 
This filter\n            parameter determines the degree of edge preservation.\n\n    Returns:\n        output (torch.Tensor): filtered tensor.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma):\n        output_tensor, output_weights_tensor, do_dx_ki, do_dsig_r, do_dsig_x, do_dsig_y, do_dsig_z = _C.tjbf_forward(\n            input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma\n        )\n\n        ctx.save_for_backward(\n            input_img,\n            sigma_x,\n            sigma_y,\n            sigma_z,\n            color_sigma,\n            output_tensor,\n            output_weights_tensor,\n            do_dx_ki,\n            do_dsig_r,\n            do_dsig_x,\n            do_dsig_y,\n            do_dsig_z,\n            guidance_img,\n        )\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        return output_tensor\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input_img = ctx.saved_tensors[0]  # input image\n        sigma_x = ctx.saved_tensors[1]\n        sigma_y = ctx.saved_tensors[2]\n        sigma_z = ctx.saved_tensors[3]\n        color_sigma = ctx.saved_tensors[4]\n        output_tensor = ctx.saved_tensors[5]  # filtered image\n        output_weights_tensor = ctx.saved_tensors[6]  # weights\n        do_dx_ki = ctx.saved_tensors[7]  # derivative of output with respect to input, while k==i\n        do_dsig_r = ctx.saved_tensors[8]  # derivative of output with respect to range sigma\n        do_dsig_x = ctx.saved_tensors[9]  # derivative of output with respect to sigma x\n        do_dsig_y = ctx.saved_tensors[10]  # derivative of output with respect to sigma y\n        do_dsig_z = ctx.saved_tensors[11]  # derivative of output with respect to sigma z\n        guidance_img = ctx.saved_tensors[12]  # guidance image\n\n        # calculate gradient with respect to the sigmas\n        grad_color_sigma = torch.sum(grad_output * 
do_dsig_r)\n        grad_sig_x = torch.sum(grad_output * do_dsig_x)\n        grad_sig_y = torch.sum(grad_output * do_dsig_y)\n        grad_sig_z = torch.sum(grad_output * do_dsig_z)\n\n        grad_output_tensor, grad_guidance_tensor = _C.tjbf_backward(\n            grad_output,\n            input_img,\n            guidance_img,\n            output_tensor,\n            output_weights_tensor,\n            do_dx_ki,\n            sigma_x,\n            sigma_y,\n            sigma_z,\n            color_sigma,\n        )\n\n        return grad_output_tensor, grad_guidance_tensor, grad_sig_x, grad_sig_y, grad_sig_z, grad_color_sigma\n\n\nclass TrainableJointBilateralFilter(torch.nn.Module):\n    \"\"\"\n    Implementation of a trainable joint bilateral filter layer as proposed in the corresponding publication.\n    The guidance image is used as additional (edge) information during filtering. All filter parameters and the\n    guidance image can be trained data-driven. The spatial filter kernels x, y, and z determine\n    image smoothing whereas the color parameter specifies the amount of edge preservation.\n    Can run on 1D, 2D, or 3D tensors (on top of Batch and Channel dimensions). Input tensor shape must match\n    guidance tensor shape.\n\n    See:\n        F. Wagner, et al., Trainable joint bilateral filters for enhanced prediction stability in\n        low-dose CT, Scientific Reports (2022), https://doi.org/10.1038/s41598-022-22530-4\n\n    Args:\n        input: input tensor to be filtered.\n        guide: guidance image tensor to be used during filtering.\n        spatial_sigma: tuple (sigma_x, sigma_y, sigma_z) initializing the trainable standard\n            deviations of the spatial filter kernels. Tuple length must equal the number of\n            spatial input dimensions.\n        color_sigma: trainable standard deviation of the intensity range kernel. 
This filter\n            parameter determines the degree of edge preservation.\n\n    Returns:\n        output (torch.Tensor): filtered tensor.\n    \"\"\"\n\n    def __init__(self, spatial_sigma, color_sigma):\n        super().__init__()\n\n        if isinstance(spatial_sigma, float):\n            spatial_sigma = [spatial_sigma, spatial_sigma, spatial_sigma]\n            self.len_spatial_sigma = 3\n        elif len(spatial_sigma) == 1:\n            spatial_sigma = [spatial_sigma[0], 0.01, 0.01]\n            self.len_spatial_sigma = 1\n        elif len(spatial_sigma) == 2:\n            spatial_sigma = [spatial_sigma[0], spatial_sigma[1], 0.01]\n            self.len_spatial_sigma = 2\n        elif len(spatial_sigma) == 3:\n            spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]]\n            self.len_spatial_sigma = 3\n        else:\n            raise ValueError(\n                f\"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}.\"\n            )\n\n        # Register sigmas as trainable parameters.\n        self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0]))\n        self.sigma_y = torch.nn.Parameter(torch.tensor(spatial_sigma[1]))\n        self.sigma_z = torch.nn.Parameter(torch.tensor(spatial_sigma[2]))\n        self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma))\n\n    def forward(self, input_tensor, guidance_tensor):\n        if input_tensor.shape[1] != 1:\n            raise ValueError(\n                f\"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. 
\"\n                \"Please use multiple parallel filter layers if you want \"\n                \"to filter multiple channels.\"\n            )\n        if input_tensor.shape != guidance_tensor.shape:\n            raise ValueError(\n                \"Shape of input image must equal shape of guidance image.\"\n                f\"Got {input_tensor.shape} and {guidance_tensor.shape}.\"\n            )\n\n        len_input = len(input_tensor.shape)\n\n        # C++ extension so far only supports 5-dim inputs.\n        if len_input == 3:\n            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)\n            guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4)\n        elif len_input == 4:\n            input_tensor = input_tensor.unsqueeze(4)\n            guidance_tensor = guidance_tensor.unsqueeze(4)\n\n        if self.len_spatial_sigma != len_input:\n            raise ValueError(f\"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).\")\n\n        prediction = TrainableJointBilateralFilterFunction.apply(\n            input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color\n        )\n\n        # Make sure to return tensor of the same shape as the input.\n        if len_input == 3:\n            prediction = prediction.squeeze(4).squeeze(3)\n        elif len_input == 4:\n            prediction = prediction.squeeze(4)\n\n        return prediction\n"
  },
  {
    "path": "monai/networks/layers/gmm.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom monai._extensions.loader import load_module\n\n__all__ = [\"GaussianMixtureModel\"]\n\n\nclass GaussianMixtureModel:\n    \"\"\"\n    Takes an initial labeling and uses a mixture of Gaussians to approximate each classes\n    distribution in the feature space. Each unlabeled element is then assigned a probability\n    of belonging to each class based on it's fit to each classes approximated distribution.\n\n    See:\n        https://en.wikipedia.org/wiki/Mixture_model\n    \"\"\"\n\n    def __init__(self, channel_count: int, mixture_count: int, mixture_size: int, verbose_build: bool = False):\n        \"\"\"\n        Args:\n            channel_count: The number of features per element.\n            mixture_count: The number of class distributions.\n            mixture_size: The number Gaussian components per class distribution.\n            verbose_build: If ``True``, turns on verbose logging of load steps.\n        \"\"\"\n        if not torch.cuda.is_available():\n            raise NotImplementedError(\"GaussianMixtureModel is currently implemented for CUDA.\")\n        self.channel_count = channel_count\n        self.mixture_count = mixture_count\n        self.mixture_size = mixture_size\n        self.compiled_extension = load_module(\n            \"gmm\",\n            {\"CHANNEL_COUNT\": channel_count, 
\"MIXTURE_COUNT\": mixture_count, \"MIXTURE_SIZE\": mixture_size},\n            verbose_build=verbose_build,\n        )\n        self.params, self.scratch = self.compiled_extension.init()\n\n    def reset(self):\n        \"\"\"\n        Resets the parameters of the model.\n        \"\"\"\n        self.params, self.scratch = self.compiled_extension.init()\n\n    def learn(self, features, labels):\n        \"\"\"\n        Learns, from scratch, the distribution of each class from the provided labels.\n\n        Args:\n            features (torch.Tensor): features for each element.\n            labels (torch.Tensor): initial labeling for each element.\n        \"\"\"\n        self.compiled_extension.learn(self.params, self.scratch, features, labels)\n\n    def apply(self, features):\n        \"\"\"\n        Applies the current model to a set of feature vectors.\n\n        Args:\n            features (torch.Tensor): feature vectors for each element.\n\n        Returns:\n            output (torch.Tensor): class assignment probabilities for each element.\n        \"\"\"\n        return _ApplyFunc.apply(self.params, features, self.compiled_extension)\n\n\nclass _ApplyFunc(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, params, features, compiled_extension):\n        return compiled_extension.apply(params, features)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        raise NotImplementedError(\"GMM does not support backpropagation\")\n"
  },
  {
    "path": "monai/networks/layers/simplelayers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\nfrom collections.abc import Sequence\nfrom copy import deepcopy\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.autograd import Function\n\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.networks.layers.convutils import gaussian_1d\nfrom monai.networks.layers.factories import Conv\nfrom monai.utils import (\n    ChannelMatching,\n    SkipMode,\n    convert_to_tensor,\n    ensure_tuple_rep,\n    issequenceiterable,\n    look_up_option,\n    optional_import,\n)\n\n_C, _ = optional_import(\"monai._C\")\nfft, _ = optional_import(\"torch.fft\")\n\n__all__ = [\n    \"ChannelPad\",\n    \"Flatten\",\n    \"GaussianFilter\",\n    \"HilbertTransform\",\n    \"LLTM\",\n    \"MedianFilter\",\n    \"Reshape\",\n    \"SavitzkyGolayFilter\",\n    \"SkipConnection\",\n    \"apply_filter\",\n    \"median_filter\",\n    \"separable_filtering\",\n]\n\n\nclass ChannelPad(nn.Module):\n    \"\"\"\n    Expand the input tensor's channel dimension from length `in_channels` to `out_channels`,\n    by padding or a projection.\n    \"\"\"\n\n    def __init__(\n        self, spatial_dims: int, in_channels: int, out_channels: int, mode: ChannelMatching | str = ChannelMatching.PAD\n    ):\n        \"\"\"\n\n        Args:\n            spatial_dims: number of spatial dimensions of the input 
image.\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            mode: {``\"pad\"``, ``\"project\"``}\n                Specifies handling residual branch and conv branch channel mismatches. Defaults to ``\"pad\"``.\n\n                - ``\"pad\"``: with zero padding.\n                - ``\"project\"``: with a trainable conv with kernel size one.\n        \"\"\"\n        super().__init__()\n        self.project = None\n        self.pad = None\n        if in_channels == out_channels:\n            return\n        mode = look_up_option(mode, ChannelMatching)\n        if mode == ChannelMatching.PROJECT:\n            conv_type = Conv[Conv.CONV, spatial_dims]\n            self.project = conv_type(in_channels, out_channels, kernel_size=1)\n            return\n        if mode == ChannelMatching.PAD:\n            if in_channels > out_channels:\n                raise ValueError('Incompatible values: channel_matching=\"pad\" and in_channels > out_channels.')\n            pad_1 = (out_channels - in_channels) // 2\n            pad_2 = out_channels - in_channels - pad_1\n            pad = [0, 0] * spatial_dims + [pad_1, pad_2] + [0, 0]\n            self.pad = tuple(pad)\n            return\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.project is not None:\n            return torch.as_tensor(self.project(x))  # as_tensor used to get around mypy typing bug\n        if self.pad is not None:\n            return F.pad(x, self.pad)\n        return x\n\n\nclass SkipConnection(nn.Module):\n    \"\"\"\n    Combine the forward pass input with the result from the given submodule::\n\n        --+--submodule--o--\n          |_____________|\n\n    The available modes are ``\"cat\"``, ``\"add\"``, ``\"mul\"``.\n    \"\"\"\n\n    def __init__(self, submodule, dim: int = 1, mode: str | SkipMode = \"cat\") -> None:\n        \"\"\"\n\n        Args:\n            submodule: the module defines the trainable 
branch.\n            dim: the dimension over which the tensors are concatenated.\n                Used when mode is ``\"cat\"``.\n            mode: ``\"cat\"``, ``\"add\"``, ``\"mul\"``. defaults to ``\"cat\"``.\n        \"\"\"\n        super().__init__()\n        self.submodule = submodule\n        self.dim = dim\n        self.mode = look_up_option(mode, SkipMode).value\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        y = self.submodule(x)\n\n        if self.mode == \"cat\":\n            return torch.cat([x, y], dim=self.dim)\n        if self.mode == \"add\":\n            return torch.add(x, y)\n        if self.mode == \"mul\":\n            return torch.mul(x, y)\n        raise NotImplementedError(f\"Unsupported mode {self.mode}.\")\n\n\nclass Flatten(nn.Module):\n    \"\"\"\n    Flattens the given input in the forward pass to be [B,-1] in shape.\n    \"\"\"\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return x.view(x.size(0), -1)\n\n\nclass Reshape(nn.Module):\n    \"\"\"\n    Reshapes input tensors to the given shape (minus batch dimension), retaining original batch size.\n    \"\"\"\n\n    def __init__(self, *shape: int) -> None:\n        \"\"\"\n        Given a shape list/tuple `shape` of integers (s0, s1, ... , sn), this layer will reshape input tensors of\n        shape (batch, s0 * s1 * ... * sn) to shape (batch, s0, s1, ... 
, sn).\n\n        Args:\n            shape: list/tuple of integer shape dimensions\n        \"\"\"\n        super().__init__()\n        self.shape = (1,) + tuple(shape)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        shape = list(self.shape)\n        shape[0] = x.shape[0]  # done this way for Torchscript\n        return x.reshape(shape)\n\n\ndef _separable_filtering_conv(\n    input_: torch.Tensor,\n    kernels: list[torch.Tensor],\n    pad_mode: str,\n    d: int,\n    spatial_dims: int,\n    paddings: list[int],\n    num_channels: int,\n) -> torch.Tensor:\n    if d < 0:\n        return input_\n\n    s = [1] * len(input_.shape)\n    s[d + 2] = -1\n    _kernel = kernels[d].reshape(s)\n\n    # if filter kernel is unity, don't convolve\n    if _kernel.numel() == 1 and _kernel[0] == 1:\n        return _separable_filtering_conv(input_, kernels, pad_mode, d - 1, spatial_dims, paddings, num_channels)\n\n    _kernel = _kernel.repeat([num_channels, 1] + [1] * spatial_dims)\n    _padding = [0] * spatial_dims\n    _padding[d] = paddings[d]\n    conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1]\n\n    # translate padding for input to torch.nn.functional.pad\n    _reversed_padding_repeated_twice: list[list[int]] = [[p, p] for p in reversed(_padding)]\n    _sum_reversed_padding_repeated_twice: list[int] = sum(_reversed_padding_repeated_twice, [])\n    padded_input = F.pad(input_, _sum_reversed_padding_repeated_twice, mode=pad_mode)\n\n    return conv_type(\n        input=_separable_filtering_conv(padded_input, kernels, pad_mode, d - 1, spatial_dims, paddings, num_channels),\n        weight=_kernel,\n        groups=num_channels,\n    )\n\n\ndef separable_filtering(x: torch.Tensor, kernels: list[torch.Tensor], mode: str = \"zeros\") -> torch.Tensor:\n    \"\"\"\n    Apply 1-D convolutions along each spatial dimension of `x`.\n\n    Args:\n        x: the input image. 
must have shape (batch, channels, H[, W, ...]).\n        kernels: kernel along each spatial dimension.\n            could be a single kernel (duplicated for all spatial dimensions), or\n            a list of `spatial_dims` number of kernels.\n        mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``\n            or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information.\n\n    Raises:\n        TypeError: When ``x`` is not a ``torch.Tensor``.\n\n    Examples:\n\n    .. code-block:: python\n\n        >>> import torch\n        >>> from monai.networks.layers import separable_filtering\n        >>> img = torch.randn(2, 4, 32, 32)  # batch_size 2, channels 4, 32x32 2D images\n        # applying a [-1, 0, 1] filter along each of the spatial dimensions.\n        # the output shape is the same as the input shape.\n        >>> out = separable_filtering(img, torch.tensor((-1., 0., 1.)))\n        # applying `[-1, 0, 1]`, `[1, 0, -1]` filters along two spatial dimensions respectively.\n        # the output shape is the same as the input shape.\n        >>> out = separable_filtering(img, [torch.tensor((-1., 0., 1.)), torch.tensor((1., 0., -1.))])\n\n    \"\"\"\n\n    if not isinstance(x, torch.Tensor):\n        raise TypeError(f\"x must be a torch.Tensor but is {type(x).__name__}.\")\n\n    spatial_dims = len(x.shape) - 2\n    if isinstance(kernels, torch.Tensor):\n        kernels = [kernels] * spatial_dims\n    _kernels = [s.to(x) for s in kernels]\n    _paddings = [(k.shape[0] - 1) // 2 for k in _kernels]\n    n_chs = x.shape[1]\n    pad_mode = \"constant\" if mode == \"zeros\" else mode\n\n    return _separable_filtering_conv(x, _kernels, pad_mode, spatial_dims - 1, spatial_dims, _paddings, n_chs)\n\n\ndef apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tensor:\n    \"\"\"\n    Filtering `x` with `kernel` independently for each batch and channel respectively.\n\n  
  Args:\n        x: the input image, must have shape (batch, channels, H[, W, D]).\n        kernel: `kernel` must at least have the spatial shape (H_k[, W_k, D_k]).\n            `kernel` shape must be broadcastable to the `batch` and `channels` dimensions of `x`.\n        kwargs: keyword arguments passed to `conv*d()` functions.\n\n    Returns:\n        The filtered `x`.\n\n    Examples:\n\n    .. code-block:: python\n\n        >>> import torch\n        >>> from monai.networks.layers import apply_filter\n        >>> img = torch.rand(2, 5, 10, 10)  # batch_size 2, channels 5, 10x10 2D images\n        >>> out = apply_filter(img, torch.rand(3, 3))   # spatial kernel\n        >>> out = apply_filter(img, torch.rand(5, 3, 3))  # channel-wise kernels\n        >>> out = apply_filter(img, torch.rand(2, 5, 3, 3))  # batch-, channel-wise kernels\n\n    \"\"\"\n    if not isinstance(x, torch.Tensor):\n        raise TypeError(f\"x must be a torch.Tensor but is {type(x).__name__}.\")\n    batch, chns, *spatials = x.shape\n    n_spatial = len(spatials)\n    if n_spatial > 3:\n        raise NotImplementedError(f\"Only spatial dimensions up to 3 are supported but got {n_spatial}.\")\n    k_size = len(kernel.shape)\n    if k_size < n_spatial or k_size > n_spatial + 2:\n        raise ValueError(\n            f\"kernel must have {n_spatial} ~ {n_spatial + 2} dimensions to match the input shape {x.shape}.\"\n        )\n    kernel = kernel.to(x)\n    # broadcast kernel size to (batch chns, spatial_kernel_size)\n    kernel = kernel.expand(batch, chns, *kernel.shape[(k_size - n_spatial) :])\n    kernel = kernel.reshape(-1, 1, *kernel.shape[2:])  # group=1\n    x = x.view(1, kernel.shape[0], *spatials)\n    conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1]\n    if \"padding\" not in kwargs:\n        kwargs[\"padding\"] = \"same\"\n\n    if \"stride\" not in kwargs:\n        kwargs[\"stride\"] = 1\n    output = conv(x, kernel, groups=kernel.shape[0], bias=None, **kwargs)\n    return 
output.view(batch, chns, *output.shape[2:])\n\n\nclass SavitzkyGolayFilter(nn.Module):\n    \"\"\"\n    Convolve a Tensor along a particular axis with a Savitzky-Golay kernel.\n\n    Args:\n        window_length: Length of the filter window, must be a positive odd integer.\n        order: Order of the polynomial to fit to each window, must be less than ``window_length``.\n        axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension).\n        mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or\n        ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information.\n    \"\"\"\n\n    def __init__(self, window_length: int, order: int, axis: int = 2, mode: str = \"zeros\"):\n        super().__init__()\n        if order >= window_length:\n            raise ValueError(\"order must be less than window_length.\")\n\n        self.axis = axis\n        self.mode = mode\n        self.coeffs = self._make_coeffs(window_length, order)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: Tensor or array-like to filter. 
Must be real, in shape ``[Batch, chns, spatial1, spatial2, ...]`` and\n                have a device type of ``'cpu'``.\n        Returns:\n            torch.Tensor: ``x`` filtered by Savitzky-Golay kernel with window length ``self.window_length`` using\n            polynomials of order ``self.order``, along axis specified in ``self.axis``.\n        \"\"\"\n\n        # Make input a real tensor on the CPU\n        x = torch.as_tensor(x, device=x.device if isinstance(x, torch.Tensor) else None)\n        if torch.is_complex(x):\n            raise ValueError(\"x must be real.\")\n        x = x.to(dtype=torch.float)\n\n        if (self.axis < 0) or (self.axis > len(x.shape) - 1):\n            raise ValueError(f\"Invalid axis for shape of x, got axis {self.axis} and shape {x.shape}.\")\n\n        # Create list of filter kernels (1 per spatial dimension). The kernel for self.axis will be the savgol coeffs,\n        # while the other kernels will be set to [1].\n        n_spatial_dims = len(x.shape) - 2\n        spatial_processing_axis = self.axis - 2\n        new_dims_before = spatial_processing_axis\n        new_dims_after = n_spatial_dims - spatial_processing_axis - 1\n        kernel_list = [self.coeffs.to(device=x.device, dtype=x.dtype)]\n        for _ in range(new_dims_before):\n            kernel_list.insert(0, torch.ones(1, device=x.device, dtype=x.dtype))\n        for _ in range(new_dims_after):\n            kernel_list.append(torch.ones(1, device=x.device, dtype=x.dtype))\n\n        return separable_filtering(x, kernel_list, mode=self.mode)\n\n    @staticmethod\n    def _make_coeffs(window_length, order):\n        half_length, rem = divmod(window_length, 2)\n        if rem == 0:\n            raise ValueError(\"window_length must be odd.\")\n\n        idx = torch.arange(window_length - half_length - 1, -half_length - 1, -1, dtype=torch.float, device=\"cpu\")\n        a = idx ** torch.arange(order + 1, dtype=torch.float, device=\"cpu\").reshape(-1, 1)\n        y = 
torch.zeros(order + 1, dtype=torch.float, device=\"cpu\")\n        y[0] = 1.0\n        return torch.linalg.lstsq(a, y).solution.squeeze()\n\n\nclass HilbertTransform(nn.Module):\n    \"\"\"\n    Determine the analytical signal of a Tensor along a particular axis.\n\n    Args:\n        axis: Axis along which to apply Hilbert transform. Default 2 (first spatial dimension).\n        n: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``.\n    \"\"\"\n\n    def __init__(self, axis: int = 2, n: int | None = None) -> None:\n        super().__init__()\n        self.axis = axis\n        self.n = n\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: Tensor or array-like to transform. Must be real and in shape ``[Batch, chns, spatial1, spatial2, ...]``.\n        Returns:\n            torch.Tensor: Analytical signal of ``x``, transformed along axis specified in ``self.axis`` using\n            FFT of size ``self.N``. The absolute value of ``x_ht`` relates to the envelope of ``x`` along axis ``self.axis``.\n        \"\"\"\n\n        # Make input a real tensor\n        x = torch.as_tensor(x, device=x.device if isinstance(x, torch.Tensor) else None)\n        if torch.is_complex(x):\n            raise ValueError(\"x must be real.\")\n        x = x.to(dtype=torch.float)\n\n        if (self.axis < 0) or (self.axis > len(x.shape) - 1):\n            raise ValueError(f\"Invalid axis for shape of x, got axis {self.axis} and shape {x.shape}.\")\n\n        n = x.shape[self.axis] if self.n is None else self.n\n        if n <= 0:\n            raise ValueError(\"N must be positive.\")\n        x = torch.as_tensor(x, dtype=torch.complex64)\n        # Create frequency axis\n        f = torch.cat(\n            [\n                torch.true_divide(torch.arange(0, (n - 1) // 2 + 1, device=x.device), float(n)),\n                torch.true_divide(torch.arange(-(n // 2), 0, device=x.device), float(n)),\n            ]\n       
 )\n        xf = fft.fft(x, n=n, dim=self.axis)\n        # Create step function\n        u = torch.heaviside(f, torch.tensor([0.5], device=f.device))\n        u = torch.as_tensor(u, dtype=x.dtype, device=u.device)\n        new_dims_before = self.axis\n        new_dims_after = len(xf.shape) - self.axis - 1\n        for _ in range(new_dims_before):\n            u.unsqueeze_(0)\n        for _ in range(new_dims_after):\n            u.unsqueeze_(-1)\n\n        ht = fft.ifft(xf * 2 * u, dim=self.axis)\n\n        # Apply transform\n        return torch.as_tensor(ht, device=ht.device, dtype=ht.dtype)\n\n\ndef get_binary_kernel(window_size: Sequence[int], dtype=torch.float, device=None) -> torch.Tensor:\n    \"\"\"\n    Create a binary kernel to extract the patches.\n    The window size HxWxD will create a (H*W*D)xHxWxD kernel.\n    \"\"\"\n    win_size = convert_to_tensor(window_size, int, wrap_sequence=True)\n    prod = torch.prod(win_size)\n    s = [prod, 1, *win_size]\n    return torch.diag(torch.ones(prod, dtype=dtype, device=device)).view(s)  # type: ignore\n\n\ndef median_filter(\n    in_tensor: torch.Tensor,\n    kernel_size: Sequence[int] | int = (3, 3, 3),\n    spatial_dims: int = 3,\n    kernel: torch.Tensor | None = None,\n    **kwargs,\n) -> torch.Tensor:\n    \"\"\"\n    Apply median filter to an image.\n\n    Args:\n        in_tensor: input tensor; median filtering will be applied to the last `spatial_dims` dimensions.\n        kernel_size: the convolution kernel size.\n        spatial_dims: number of spatial dimensions to apply median filtering.\n        kernel: an optional customized kernel.\n        kwargs: additional parameters to the `conv`.\n\n    Returns:\n        the filtered input tensor, shape remains the same as ``in_tensor``\n\n    Example::\n\n        >>> from monai.networks.layers import median_filter\n        >>> import torch\n        >>> x = torch.rand(4, 5, 7, 6)\n        >>> output = median_filter(x, (3, 3, 3))\n        >>> output.shape\n    
    torch.Size([4, 5, 7, 6])\n\n    \"\"\"\n    if not isinstance(in_tensor, torch.Tensor):\n        raise TypeError(f\"Input type is not a torch.Tensor. Got {type(in_tensor)}\")\n\n    original_shape = in_tensor.shape\n    oshape, sshape = original_shape[: len(original_shape) - spatial_dims], original_shape[-spatial_dims:]\n    oprod = torch.prod(convert_to_tensor(oshape, int, wrap_sequence=True))\n    # prepare kernel\n    if kernel is None:\n        kernel_size = ensure_tuple_rep(kernel_size, spatial_dims)\n        kernel = get_binary_kernel(kernel_size, in_tensor.dtype, in_tensor.device)\n    else:\n        kernel = kernel.to(in_tensor)\n    # map the local window to single vector\n    conv = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1]\n    reshaped_input: torch.Tensor = in_tensor.reshape(oprod, 1, *sshape)  # type: ignore\n\n    # even-sized kernels are not supported\n    padding = [(k - 1) // 2 for k in reversed(kernel.shape[2:]) for _ in range(2)]\n    padded_input: torch.Tensor = F.pad(reshaped_input, pad=padding, mode=\"replicate\")\n    features: torch.Tensor = conv(padded_input, kernel, padding=0, stride=1, **kwargs)\n\n    features = features.view(oprod, -1, *sshape)  # type: ignore\n\n    # compute the median along the feature axis\n    median: torch.Tensor = torch.median(features, dim=1)[0]\n    median = median.reshape(original_shape)\n\n    return median\n\n\nclass MedianFilter(nn.Module):\n    \"\"\"\n    Apply median filter to an image.\n\n    Args:\n        radius: the blurring kernel radius (radius of 1 corresponds to 3x3x3 kernel when spatial_dims=3).\n\n    Returns:\n        filtered input tensor.\n\n    Example::\n\n        >>> from monai.networks.layers import MedianFilter\n        >>> import torch\n        >>> in_tensor = torch.rand(4, 5, 7, 6)\n        >>> blur = MedianFilter([1, 1, 1])  # 3x3x3 kernel\n        >>> output = blur(in_tensor)\n        >>> output.shape\n        torch.Size([4, 5, 7, 6])\n\n    \"\"\"\n\n    def 
__init__(self, radius: Sequence[int] | int, spatial_dims: int = 3, device=\"cpu\") -> None:\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.radius: Sequence[int] = ensure_tuple_rep(radius, spatial_dims)\n        self.window: Sequence[int] = [1 + 2 * deepcopy(r) for r in self.radius]\n        self.kernel = get_binary_kernel(self.window, device=device)\n\n    def forward(self, in_tensor: torch.Tensor, number_of_passes=1) -> torch.Tensor:\n        \"\"\"\n        Args:\n            in_tensor: input tensor, median filtering will be applied to the last `spatial_dims` dimensions.\n            number_of_passes: median filtering will be repeated this many times\n        \"\"\"\n        x = in_tensor\n        for _ in range(number_of_passes):\n            x = median_filter(x, kernel=self.kernel, spatial_dims=self.spatial_dims)\n        return x\n\n\nclass GaussianFilter(nn.Module):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        sigma: Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor,\n        truncated: float = 4.0,\n        approx: str = \"erf\",\n        requires_grad: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions of the input image.\n                must have shape (Batch, channels, H[, W, ...]).\n            sigma: std. 
could be a single value, or `spatial_dims` number of values.\n            truncated: spreads how many stds.\n            approx: discrete Gaussian kernel type, available options are \"erf\", \"sampled\", and \"scalespace\".\n\n                - ``erf`` approximation interpolates the error function;\n                - ``sampled`` uses a sampled Gaussian kernel;\n                - ``scalespace`` corresponds to\n                  https://en.wikipedia.org/wiki/Scale_space_implementation#The_discrete_Gaussian_kernel\n                  based on the modified Bessel functions.\n\n            requires_grad: whether to store the gradients for sigma.\n                if True, `sigma` will be the initial value of the parameters of this module\n                (for example `parameters()` iterator could be used to get the parameters);\n                otherwise this module will fix the kernels using `sigma` as the std.\n        \"\"\"\n        if issequenceiterable(sigma):\n            if len(sigma) != spatial_dims:  # type: ignore\n                raise ValueError\n        else:\n            sigma = [deepcopy(sigma) for _ in range(spatial_dims)]  # type: ignore\n        super().__init__()\n        self.sigma = [\n            torch.nn.Parameter(\n                torch.as_tensor(s, dtype=torch.float, device=s.device if isinstance(s, torch.Tensor) else None),\n                requires_grad=requires_grad,\n            )\n            for s in sigma  # type: ignore\n        ]\n        self.truncated = truncated\n        self.approx = approx\n        for idx, param in enumerate(self.sigma):\n            self.register_parameter(f\"kernel_sigma_{idx}\", param)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: in shape [Batch, chns, H, W, D].\n        \"\"\"\n        _kernel = [gaussian_1d(s, truncated=self.truncated, approx=self.approx) for s in self.sigma]\n        return separable_filtering(x=x, kernels=_kernel)\n\n\nclass 
LLTMFunction(Function):\n\n    @staticmethod\n    def forward(ctx, input, weights, bias, old_h, old_cell):\n        outputs = _C.lltm_forward(input, weights, bias, old_h, old_cell)\n        new_h, new_cell = outputs[:2]\n        variables = outputs[1:] + [weights]\n        ctx.save_for_backward(*variables)\n\n        return new_h, new_cell\n\n    @staticmethod\n    def backward(ctx, grad_h, grad_cell):\n        outputs = _C.lltm_backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)\n        d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs[:5]\n\n        return d_input, d_weights, d_bias, d_old_h, d_old_cell\n\n\nclass LLTM(nn.Module):\n    \"\"\"\n    This recurrent unit is similar to an LSTM, but differs in that it lacks a forget\n    gate and uses an Exponential Linear Unit (ELU) as its internal activation function.\n    Because this unit never forgets, call it LLTM, or Long-Long-Term-Memory unit.\n    It has both C++ and CUDA implementation, automatically switch according to the\n    target device where put this module to.\n\n    Args:\n        input_features: size of input feature data\n        state_size: size of the state of recurrent unit\n\n    Referring to: https://pytorch.org/tutorials/advanced/cpp_extension.html\n    \"\"\"\n\n    def __init__(self, input_features: int, state_size: int):\n        super().__init__()\n        self.input_features = input_features\n        self.state_size = state_size\n        self.weights = nn.Parameter(torch.empty(3 * state_size, input_features + state_size))\n        self.bias = nn.Parameter(torch.empty(1, 3 * state_size))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        stdv = 1.0 / math.sqrt(self.state_size)\n        for weight in self.parameters():\n            weight.data.uniform_(-stdv, +stdv)\n\n    def forward(self, input, state):\n        return LLTMFunction.apply(input, self.weights, self.bias, *state)\n\n\nclass ApplyFilter(nn.Module):\n    \"Wrapper class 
to apply a filter to an image.\"\n\n    def __init__(self, filter: NdarrayOrTensor) -> None:\n        super().__init__()\n\n        self.filter = convert_to_tensor(filter, dtype=torch.float32)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return apply_filter(x, self.filter)\n\n\nclass MeanFilter(ApplyFilter):\n    \"\"\"\n    Mean filtering can smooth edges and remove aliasing artifacts in a segmentation image.\n    The mean filter used is a `torch.Tensor` of all ones.\n    \"\"\"\n\n    def __init__(self, spatial_dims: int, size: int) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: `int` of either 2 for 2D images and 3 for 3D images\n            size: edge length of the filter\n        \"\"\"\n        filter = torch.ones([size] * spatial_dims)\n        filter = filter\n        super().__init__(filter=filter)\n\n\nclass LaplaceFilter(ApplyFilter):\n    \"\"\"\n    Laplacian filtering for outline detection in images. Can be used to transform labels to contours.\n    The laplace filter used is a `torch.Tensor` where all values are -1, except the center value\n    which is `size` ** `spatial_dims` - 1\n    \"\"\"\n\n    def __init__(self, spatial_dims: int, size: int) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: `int` of either 2 for 2D images and 3 for 3D images\n            size: edge length of the filter\n        \"\"\"\n        filter = torch.zeros([size] * spatial_dims).float() - 1  # make all -1\n        center_point = tuple([size // 2] * spatial_dims)\n        filter[center_point] = (size**spatial_dims) - 1\n        super().__init__(filter=filter)\n\n\nclass EllipticalFilter(ApplyFilter):\n    \"\"\"\n    Elliptical filter, can be used to dilate labels or label-contours.\n    The elliptical filter used here is a `torch.Tensor` with shape (size, ) * ndim containing a circle/sphere of `1`\n    \"\"\"\n\n    def __init__(self, spatial_dims: int, size: int) -> None:\n        \"\"\"\n        Args:\n      
      spatial_dims: `int` of either 2 for 2D images and 3 for 3D images\n            size: edge length of the filter\n        \"\"\"\n        radius = size // 2\n        grid = torch.meshgrid(*[torch.arange(0, size) for _ in range(spatial_dims)])\n        squared_distances = torch.stack([(axis - radius) ** 2 for axis in grid], 0).sum(0)\n        filter = squared_distances <= radius**2\n        super().__init__(filter=filter)\n\n\nclass SharpenFilter(EllipticalFilter):\n    \"\"\"\n    Convolutional filter to sharpen a 2D or 3D image.\n    The filter used contains a circle/sphere of `-1`, with the center value being\n    the absolute sum of all non-zero elements in the kernel\n    \"\"\"\n\n    def __init__(self, spatial_dims: int, size: int) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: `int` of either 2 for 2D images and 3 for 3D images\n            size: edge length of the filter\n        \"\"\"\n        super().__init__(spatial_dims=spatial_dims, size=size)\n        center_point = tuple([size // 2] * spatial_dims)\n        center_value = self.filter.sum()\n        self.filter *= -1\n        self.filter[center_point] = center_value\n"
  },
  {
    "path": "monai/networks/layers/spatial_transforms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nimport monai\nfrom monai.networks import to_norm_affine\nfrom monai.utils import (\n    GridSampleMode,\n    GridSamplePadMode,\n    convert_to_dst_type,\n    ensure_tuple,\n    look_up_option,\n    optional_import,\n)\n\n_C, _ = optional_import(\"monai._C\")\n\n__all__ = [\"AffineTransform\", \"grid_pull\", \"grid_push\", \"grid_count\", \"grid_grad\"]\n\n\nclass _GridPull(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, input, grid, interpolation, bound, extrapolate):\n        opt = (bound, interpolation, extrapolate)\n        output = _C.grid_pull(input, grid, *opt)\n        if input.requires_grad or grid.requires_grad:\n            ctx.opt = opt\n            ctx.save_for_backward(input, grid)\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad):\n        if not (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]):\n            return None, None, None, None, None\n        var = ctx.saved_tensors\n        opt = ctx.opt\n        grads = _C.grid_pull_backward(grad, *var, *opt)\n        if ctx.needs_input_grad[0]:\n            return grads[0], grads[1] if ctx.needs_input_grad[1] else None, None, None, None\n        if ctx.needs_input_grad[1]:\n            return None, grads[0], None, None, None\n\n\ndef grid_pull(\n  
  input: torch.Tensor, grid: torch.Tensor, interpolation=\"linear\", bound=\"zero\", extrapolate: bool = True\n) -> torch.Tensor:\n    \"\"\"\n    Sample an image with respect to a deformation field.\n\n    `interpolation` can be an int, a string or an InterpolationType.\n    Possible values are::\n\n        - 0 or 'nearest'    or InterpolationType.nearest\n        - 1 or 'linear'     or InterpolationType.linear\n        - 2 or 'quadratic'  or InterpolationType.quadratic\n        - 3 or 'cubic'      or InterpolationType.cubic\n        - 4 or 'fourth'     or InterpolationType.fourth\n        - 5 or 'fifth'      or InterpolationType.fifth\n        - 6 or 'sixth'      or InterpolationType.sixth\n        - 7 or 'seventh'    or InterpolationType.seventh\n\n    A list of values can be provided, in the order [W, H, D],\n    to specify dimension-specific interpolation orders.\n\n    `bound` can be an int, a string or a BoundType.\n    Possible values are::\n\n        - 0 or 'replicate' or 'nearest'      or BoundType.replicate or 'border'\n        - 1 or 'dct1'      or 'mirror'       or BoundType.dct1\n        - 2 or 'dct2'      or 'reflect'      or BoundType.dct2\n        - 3 or 'dst1'      or 'antimirror'   or BoundType.dst1\n        - 4 or 'dst2'      or 'antireflect'  or BoundType.dst2\n        - 5 or 'dft'       or 'wrap'         or BoundType.dft\n        - 7 or 'zero'      or 'zeros'        or BoundType.zero\n\n    A list of values can be provided, in the order [W, H, D],\n    to specify dimension-specific boundary conditions.\n    `sliding` is a specific condition that only applies to flow fields\n    (with as many channels as dimensions). 
It cannot be dimension-specific.\n    Note that:\n\n        - `dft` corresponds to circular padding\n        - `dct2` corresponds to Neumann boundary conditions (symmetric)\n        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)\n\n    See Also:\n        - https://en.wikipedia.org/wiki/Discrete_cosine_transform\n        - https://en.wikipedia.org/wiki/Discrete_sine_transform\n        - ``help(monai._C.BoundType)``\n        - ``help(monai._C.InterpolationType)``\n\n    Args:\n        input: Input image. `(B, C, Wi, Hi, Di)`.\n        grid: Deformation field. `(B, Wo, Ho, Do, 1|2|3)`.\n        interpolation (int or list[int] , optional): Interpolation order.\n            Defaults to `'linear'`.\n        bound (BoundType, or list[BoundType], optional): Boundary conditions.\n            Defaults to `'zero'`.\n        extrapolate: Extrapolate out-of-bound data.\n            Defaults to `True`.\n\n    Returns:\n        output (torch.Tensor): Deformed image `(B, C, Wo, Ho, Do)`.\n\n    \"\"\"\n    # Convert parameters\n    bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)]\n    interpolation = [\n        _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i)\n        for i in ensure_tuple(interpolation)\n    ]\n    out: torch.Tensor\n    out = _GridPull.apply(input, grid, interpolation, bound, extrapolate)\n    if isinstance(input, monai.data.MetaTensor):\n        out = convert_to_dst_type(out, dst=input)[0]\n    return out\n\n\nclass _GridPush(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):\n        opt = (bound, interpolation, extrapolate)\n        output = _C.grid_push(input, grid, shape, *opt)\n        if input.requires_grad or grid.requires_grad:\n            ctx.opt = opt\n            ctx.save_for_backward(input, grid)\n\n        return output\n\n    @staticmethod\n    def 
backward(ctx, grad):\n        if not (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]):\n            return None, None, None, None, None, None\n        var = ctx.saved_tensors\n        opt = ctx.opt\n        grads = _C.grid_push_backward(grad, *var, *opt)\n        if ctx.needs_input_grad[0]:\n            return grads[0], grads[1] if ctx.needs_input_grad[1] else None, None, None, None, None\n        if ctx.needs_input_grad[1]:\n            return None, grads[0], None, None, None, None\n\n\ndef grid_push(\n    input: torch.Tensor, grid: torch.Tensor, shape=None, interpolation=\"linear\", bound=\"zero\", extrapolate: bool = True\n):\n    \"\"\"\n    Splat an image with respect to a deformation field (pull adjoint).\n\n    `interpolation` can be an int, a string or an InterpolationType.\n    Possible values are::\n\n        - 0 or 'nearest'    or InterpolationType.nearest\n        - 1 or 'linear'     or InterpolationType.linear\n        - 2 or 'quadratic'  or InterpolationType.quadratic\n        - 3 or 'cubic'      or InterpolationType.cubic\n        - 4 or 'fourth'     or InterpolationType.fourth\n        - 5 or 'fifth'      or InterpolationType.fifth\n        - 6 or 'sixth'      or InterpolationType.sixth\n        - 7 or 'seventh'    or InterpolationType.seventh\n\n    A list of values can be provided, in the order `[W, H, D]`,\n    to specify dimension-specific interpolation orders.\n\n    `bound` can be an int, a string or a BoundType.\n    Possible values are::\n\n        - 0 or 'replicate' or 'nearest'      or BoundType.replicate\n        - 1 or 'dct1'      or 'mirror'       or BoundType.dct1\n        - 2 or 'dct2'      or 'reflect'      or BoundType.dct2\n        - 3 or 'dst1'      or 'antimirror'   or BoundType.dst1\n        - 4 or 'dst2'      or 'antireflect'  or BoundType.dst2\n        - 5 or 'dft'       or 'wrap'         or BoundType.dft\n        - 7 or 'zero'                        or BoundType.zero\n\n    A list of values can be provided, in the order 
`[W, H, D]`,\n    to specify dimension-specific boundary conditions.\n    `sliding` is a specific condition that only applies to flow fields\n    (with as many channels as dimensions). It cannot be dimension-specific.\n    Note that:\n\n        - `dft` corresponds to circular padding\n        - `dct2` corresponds to Neumann boundary conditions (symmetric)\n        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)\n\n    See Also:\n\n        - https://en.wikipedia.org/wiki/Discrete_cosine_transform\n        - https://en.wikipedia.org/wiki/Discrete_sine_transform\n        - ``help(monai._C.BoundType)``\n        - ``help(monai._C.InterpolationType)``\n\n    Args:\n        input: Input image `(B, C, Wi, Hi, Di)`.\n        grid: Deformation field `(B, Wi, Hi, Di, 1|2|3)`.\n        shape: Shape of the source image.\n        interpolation (int or list[int] , optional): Interpolation order.\n            Defaults to `'linear'`.\n        bound (BoundType, or list[BoundType], optional): Boundary conditions.\n            Defaults to `'zero'`.\n        extrapolate: Extrapolate out-of-bound data.\n            Defaults to `True`.\n\n    Returns:\n        output (torch.Tensor): Splatted image `(B, C, Wo, Ho, Do)`.\n\n    \"\"\"\n    # Convert parameters\n    bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)]\n    interpolation = [\n        _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i)\n        for i in ensure_tuple(interpolation)\n    ]\n\n    if shape is None:\n        shape = tuple(input.shape[2:])\n\n    out: torch.Tensor = _GridPush.apply(input, grid, shape, interpolation, bound, extrapolate)\n    if isinstance(input, monai.data.MetaTensor):\n        out = convert_to_dst_type(out, dst=input)[0]\n    return out\n\n\nclass _GridCount(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, grid, shape, interpolation, bound, extrapolate):\n        
opt = (bound, interpolation, extrapolate)\n        output = _C.grid_count(grid, shape, *opt)\n        if grid.requires_grad:\n            ctx.opt = opt\n            ctx.save_for_backward(grid)\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad):\n        if ctx.needs_input_grad[0]:\n            var = ctx.saved_tensors\n            opt = ctx.opt\n            return _C.grid_count_backward(grad, *var, *opt), None, None, None, None\n        return None, None, None, None, None\n\n\ndef grid_count(grid: torch.Tensor, shape=None, interpolation=\"linear\", bound=\"zero\", extrapolate: bool = True):\n    \"\"\"\n    Splatting weights with respect to a deformation field (pull adjoint).\n\n    This function is equivalent to applying grid_push to an image of ones.\n\n    `interpolation` can be an int, a string or an InterpolationType.\n    Possible values are::\n\n        - 0 or 'nearest'    or InterpolationType.nearest\n        - 1 or 'linear'     or InterpolationType.linear\n        - 2 or 'quadratic'  or InterpolationType.quadratic\n        - 3 or 'cubic'      or InterpolationType.cubic\n        - 4 or 'fourth'     or InterpolationType.fourth\n        - 5 or 'fifth'      or InterpolationType.fifth\n        - 6 or 'sixth'      or InterpolationType.sixth\n        - 7 or 'seventh'    or InterpolationType.seventh\n\n    A list of values can be provided, in the order [W, H, D],\n    to specify dimension-specific interpolation orders.\n\n    `bound` can be an int, a string or a BoundType.\n    Possible values are::\n\n        - 0 or 'replicate' or 'nearest'      or BoundType.replicate\n        - 1 or 'dct1'      or 'mirror'       or BoundType.dct1\n        - 2 or 'dct2'      or 'reflect'      or BoundType.dct2\n        - 3 or 'dst1'      or 'antimirror'   or BoundType.dst1\n        - 4 or 'dst2'      or 'antireflect'  or BoundType.dst2\n        - 5 or 'dft'       or 'wrap'         or BoundType.dft\n        - 7 or 'zero'                        or 
BoundType.zero\n\n    A list of values can be provided, in the order [W, H, D],\n    to specify dimension-specific boundary conditions.\n    `sliding` is a specific condition that only applies to flow fields\n    (with as many channels as dimensions). It cannot be dimension-specific.\n    Note that:\n\n        - `dft` corresponds to circular padding\n        - `dct2` corresponds to Neumann boundary conditions (symmetric)\n        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)\n\n    See Also:\n\n        - https://en.wikipedia.org/wiki/Discrete_cosine_transform\n        - https://en.wikipedia.org/wiki/Discrete_sine_transform\n        - ``help(monai._C.BoundType)``\n        - ``help(monai._C.InterpolationType)``\n\n    Args:\n        grid: Deformation field `(B, Wi, Hi, Di, 2|3)`.\n        shape: shape of the source image.\n        interpolation (int or list[int] , optional): Interpolation order.\n            Defaults to `'linear'`.\n        bound (BoundType, or list[BoundType], optional): Boundary conditions.\n            Defaults to `'zero'`.\n        extrapolate (bool, optional): Extrapolate out-of-bound data.\n            Defaults to `True`.\n\n    Returns:\n        output (torch.Tensor): Splat weights `(B, 1, Wo, Ho, Do)`.\n\n    \"\"\"\n    # Convert parameters\n    bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)]\n    interpolation = [\n        _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i)\n        for i in ensure_tuple(interpolation)\n    ]\n\n    if shape is None:\n        shape = tuple(grid.shape[2:])\n\n    out: torch.Tensor = _GridCount.apply(grid, shape, interpolation, bound, extrapolate)\n    if isinstance(input, monai.data.MetaTensor):\n        out = convert_to_dst_type(out, dst=input)[0]\n    return out\n\n\nclass _GridGrad(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, input, grid, interpolation, bound, 
extrapolate):\n        opt = (bound, interpolation, extrapolate)\n        output = _C.grid_grad(input, grid, *opt)\n        if input.requires_grad or grid.requires_grad:\n            ctx.opt = opt\n            ctx.save_for_backward(input, grid)\n\n        return output\n\n    @staticmethod\n    def backward(ctx, grad):\n        if not (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]):\n            return None, None, None, None, None\n        var = ctx.saved_tensors\n        opt = ctx.opt\n        grads = _C.grid_grad_backward(grad, *var, *opt)\n        if ctx.needs_input_grad[0]:\n            return grads[0], grads[1] if ctx.needs_input_grad[1] else None, None, None, None\n        if ctx.needs_input_grad[1]:\n            return None, grads[0], None, None, None\n\n\ndef grid_grad(input: torch.Tensor, grid: torch.Tensor, interpolation=\"linear\", bound=\"zero\", extrapolate: bool = True):\n    \"\"\"\n    Sample an image with respect to a deformation field.\n\n    `interpolation` can be an int, a string or an InterpolationType.\n    Possible values are::\n\n        - 0 or 'nearest'    or InterpolationType.nearest\n        - 1 or 'linear'     or InterpolationType.linear\n        - 2 or 'quadratic'  or InterpolationType.quadratic\n        - 3 or 'cubic'      or InterpolationType.cubic\n        - 4 or 'fourth'     or InterpolationType.fourth\n        - 5 or 'fifth'      or InterpolationType.fifth\n        - 6 or 'sixth'      or InterpolationType.sixth\n        - 7 or 'seventh'    or InterpolationType.seventh\n\n    A list of values can be provided, in the order [W, H, D],\n    to specify dimension-specific interpolation orders.\n\n    `bound` can be an int, a string or a BoundType.\n    Possible values are::\n\n        - 0 or 'replicate' or 'nearest'      or BoundType.replicate\n        - 1 or 'dct1'      or 'mirror'       or BoundType.dct1\n        - 2 or 'dct2'      or 'reflect'      or BoundType.dct2\n        - 3 or 'dst1'      or 'antimirror'   or 
BoundType.dst1\n        - 4 or 'dst2'      or 'antireflect'  or BoundType.dst2\n        - 5 or 'dft'       or 'wrap'         or BoundType.dft\n        - 7 or 'zero'                        or BoundType.zero\n\n    A list of values can be provided, in the order [W, H, D],\n    to specify dimension-specific boundary conditions.\n    `sliding` is a specific condition that only applies to flow fields\n    (with as many channels as dimensions). It cannot be dimension-specific.\n    Note that:\n\n        - `dft` corresponds to circular padding\n        - `dct2` corresponds to Neumann boundary conditions (symmetric)\n        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)\n\n    See Also:\n\n        - https://en.wikipedia.org/wiki/Discrete_cosine_transform\n        - https://en.wikipedia.org/wiki/Discrete_sine_transform\n        - ``help(monai._C.BoundType)``\n        - ``help(monai._C.InterpolationType)``\n\n\n    Args:\n        input: Input image. `(B, C, Wi, Hi, Di)`.\n        grid: Deformation field. `(B, Wo, Ho, Do, 2|3)`.\n        interpolation (int or list[int] , optional): Interpolation order.\n            Defaults to `'linear'`.\n        bound (BoundType, or list[BoundType], optional): Boundary conditions.\n            Defaults to `'zero'`.\n        extrapolate: Extrapolate out-of-bound data. 
Defaults to `True`.\n\n    Returns:\n        output (torch.Tensor): Sampled gradients (B, C, Wo, Ho, Do, 1|2|3).\n\n    \"\"\"\n    # Convert parameters\n    bound = [_C.BoundType.__members__[b] if isinstance(b, str) else _C.BoundType(b) for b in ensure_tuple(bound)]\n    interpolation = [\n        _C.InterpolationType.__members__[i] if isinstance(i, str) else _C.InterpolationType(i)\n        for i in ensure_tuple(interpolation)\n    ]\n\n    out: torch.Tensor = _GridGrad.apply(input, grid, interpolation, bound, extrapolate)\n    if isinstance(input, monai.data.MetaTensor):\n        out = convert_to_dst_type(out, dst=input)[0]\n    return out\n\n\nclass AffineTransform(nn.Module):\n\n    def __init__(\n        self,\n        spatial_size: Sequence[int] | int | None = None,\n        normalized: bool = False,\n        mode: str = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.ZEROS,\n        align_corners: bool = True,\n        reverse_indexing: bool = True,\n        zero_centered: bool | None = None,\n    ) -> None:\n        \"\"\"\n        Apply affine transformations with a batch of affine matrices.\n\n        When `normalized=False` and `reverse_indexing=True`,\n        it does the commonly used resampling in the 'pull' direction\n        following the ``scipy.ndimage.affine_transform`` convention.\n        In this case `theta` is equivalent to (ndim+1, ndim+1) input ``matrix`` of ``scipy.ndimage.affine_transform``,\n        operates on homogeneous coordinates.\n        See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html\n\n        When `normalized=True` and `reverse_indexing=False`,\n        it applies `theta` to the normalized coordinates (coords. 
in the range of [-1, 1]) directly.\n        This is often used with `align_corners=False` to achieve resolution-agnostic resampling,\n        thus useful as a part of trainable modules such as the spatial transformer networks.\n        See also: https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html\n\n        Args:\n            spatial_size: output spatial shape, the full output shape will be\n                `[N, C, *spatial_size]` where N and C are inferred from the `src` input of `self.forward`.\n            normalized: indicating whether the provided affine matrix `theta` is defined\n                for the normalized coordinates. If `normalized=False`, `theta` will be converted\n                to operate on normalized coordinates as pytorch affine_grid works with the normalized\n                coordinates.\n            mode: {``\"bilinear\"``, ``\"nearest\"``}\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``\"zeros\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            align_corners: see also https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html.\n            reverse_indexing: whether to reverse the spatial indexing of image and coordinates.\n                set to `False` if `theta` follows pytorch's default \"D, H, W\" convention.\n                set to `True` if `theta` follows `scipy.ndimage` default \"i, j, k\" convention.\n            zero_centered: whether the affine is applied to coordinates in a zero-centered value range.\n                With `zero_centered=True`, for example, the center of rotation will be the\n                spatial center of the input; with `zero_centered=False`, the center of rotation will be the\n                origin of the input. This option is only available when `normalized=False`,\n                where the default behaviour is `False` if unspecified.\n                See also: :py:func:`monai.networks.utils.normalize_transform`.\n        \"\"\"\n        super().__init__()\n        self.spatial_size = ensure_tuple(spatial_size) if spatial_size is not None else None\n        self.normalized = normalized\n        self.mode: str = look_up_option(mode, GridSampleMode)\n        self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)\n        self.align_corners = align_corners\n        self.reverse_indexing = reverse_indexing\n        if zero_centered is not None and self.normalized:\n            raise ValueError(\"`normalized=True` is not compatible with the `zero_centered` option.\")\n        self.zero_centered = zero_centered if zero_centered is not None else False\n\n    def forward(\n        self, src: torch.Tensor, theta: torch.Tensor, spatial_size: Sequence[int] | int | None = None\n    ) -> torch.Tensor:\n        \"\"\"\n        ``theta`` must be an affine transformation matrix with shape\n        3x3 or 
Nx3x3 or Nx2x3 or 2x3 for spatial 2D transforms,\n        4x4 or Nx4x4 or Nx3x4 or 3x4 for spatial 3D transforms,\n        where `N` is the batch size. `theta` will be converted into float Tensor for the computation.\n\n        Args:\n            src (array_like): image in spatial 2D or 3D (N, C, spatial_dims),\n                where N is the batch dim, C is the number of channels.\n            theta (array_like): Nx3x3, Nx2x3, 3x3, 2x3 for spatial 2D inputs,\n                Nx4x4, Nx3x4, 3x4, 4x4 for spatial 3D inputs. When the batch dimension is omitted,\n                `theta` will be repeated N times, N is the batch dim of `src`.\n            spatial_size: output spatial shape, the full output shape will be\n                `[N, C, *spatial_size]` where N and C are inferred from the `src`.\n\n        Raises:\n            TypeError: When ``theta`` is not a ``torch.Tensor``.\n            ValueError: When ``theta`` is not one of [Nxdxd, dxd].\n            ValueError: When ``theta`` is not one of [Nx3x3, Nx4x4].\n            TypeError: When ``src`` is not a ``torch.Tensor``.\n            ValueError: When ``src`` spatially is not one of [2D, 3D].\n            ValueError: When affine and image batch dimension differ.\n\n        \"\"\"\n        # validate `theta`\n        if not isinstance(theta, torch.Tensor):\n            raise TypeError(f\"theta must be torch.Tensor but is {type(theta).__name__}.\")\n        if theta.dim() not in (2, 3):\n            raise ValueError(f\"theta must be Nxdxd or dxd, got {theta.shape}.\")\n        if theta.dim() == 2:\n            theta = theta[None]  # adds a batch dim.\n        theta = theta.clone()  # no in-place change of theta\n        theta_shape = tuple(theta.shape[1:])\n        if theta_shape in ((2, 3), (3, 4)):  # needs padding to dxd\n            pad_affine = torch.tensor([0, 0, 1] if theta_shape[0] == 2 else [0, 0, 0, 1])\n            pad_affine = pad_affine.repeat(theta.shape[0], 1, 1).to(theta)\n            
pad_affine.requires_grad = False\n            theta = torch.cat([theta, pad_affine], dim=1)\n        if tuple(theta.shape[1:]) not in ((3, 3), (4, 4)):\n            raise ValueError(f\"theta must be Nx3x3 or Nx4x4, got {theta.shape}.\")\n        if not torch.is_floating_point(theta):\n            raise ValueError(f\"theta must be floating point data, got {theta.dtype}\")\n\n        # validate `src`\n        if not isinstance(src, torch.Tensor):\n            raise TypeError(f\"src must be torch.Tensor but is {type(src).__name__}.\")\n        sr = src.dim() - 2  # input spatial rank\n        if sr not in (2, 3):\n            raise ValueError(f\"Unsupported src dimension: {sr}, available options are [2, 3].\")\n\n        # set output shape\n        src_size = tuple(src.shape)\n        dst_size = src_size  # default to the src shape\n        if self.spatial_size is not None:\n            dst_size = src_size[:2] + self.spatial_size\n        if spatial_size is not None:\n            dst_size = src_size[:2] + ensure_tuple(spatial_size)\n\n        # reverse and normalize theta if needed\n        if not self.normalized:\n            theta = to_norm_affine(\n                affine=theta,\n                src_size=src_size[2:],\n                dst_size=dst_size[2:],\n                align_corners=False,\n                zero_centered=self.zero_centered,\n            )\n        if self.reverse_indexing:\n            rev_idx = torch.as_tensor(range(sr - 1, -1, -1), device=src.device)\n            theta[:, :sr] = theta[:, rev_idx]\n            theta[:, :, :sr] = theta[:, :, rev_idx]\n        if (theta.shape[0] == 1) and src_size[0] > 1:\n            # adds a batch dim to `theta` in order to match `src`\n            theta = theta.repeat(src_size[0], 1, 1)\n        if theta.shape[0] != src_size[0]:\n            raise ValueError(\n                f\"affine and image batch dimension must match, got affine={theta.shape[0]} image={src_size[0]}.\"\n            )\n\n        grid = 
nn.functional.affine_grid(theta=theta[:, :sr], size=list(dst_size), align_corners=self.align_corners)\n        dst = nn.functional.grid_sample(\n            input=src.contiguous(),\n            grid=grid,\n            mode=self.mode,\n            padding_mode=self.padding_mode,\n            align_corners=self.align_corners,\n        )\n        return dst\n"
  },
  {
    "path": "monai/networks/layers/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch.nn\n\nfrom monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args\nfrom monai.utils import has_option\n\n__all__ = [\"get_norm_layer\", \"get_act_layer\", \"get_dropout_layer\", \"get_pool_layer\"]\n\n\ndef get_norm_layer(name: tuple | str, spatial_dims: int | None = 1, channels: int | None = 1):\n    \"\"\"\n    Create a normalization layer instance.\n\n    For example, to create normalization layers:\n\n    .. 
code-block:: python\n\n        from monai.networks.layers import get_norm_layer\n\n        g_layer = get_norm_layer(name=(\"group\", {\"num_groups\": 1}))\n        n_layer = get_norm_layer(name=\"instance\", spatial_dims=2)\n\n    Args:\n        name: a normalization type string or a tuple of type string and parameters.\n        spatial_dims: number of spatial dimensions of the input.\n        channels: number of features/channels when the normalization layer requires this parameter\n            but it is not specified in the norm parameters.\n    \"\"\"\n    if name == \"\":\n        return torch.nn.Identity()\n    norm_name, norm_args = split_args(name)\n    norm_type = Norm[norm_name, spatial_dims]\n    kw_args = dict(norm_args)\n    if has_option(norm_type, \"num_features\") and \"num_features\" not in kw_args:\n        kw_args[\"num_features\"] = channels\n    if has_option(norm_type, \"num_channels\") and \"num_channels\" not in kw_args:\n        kw_args[\"num_channels\"] = channels\n    return norm_type(**kw_args)\n\n\ndef get_act_layer(name: tuple | str):\n    \"\"\"\n    Create an activation layer instance.\n\n    For example, to create activation layers:\n\n    .. code-block:: python\n\n        from monai.networks.layers import get_act_layer\n\n        s_layer = get_act_layer(name=\"swish\")\n        p_layer = get_act_layer(name=(\"prelu\", {\"num_parameters\": 1, \"init\": 0.25}))\n\n    Args:\n        name: an activation type string or a tuple of type string and parameters.\n    \"\"\"\n    if name == \"\":\n        return torch.nn.Identity()\n    act_name, act_args = split_args(name)\n    act_type = Act[act_name]\n    return act_type(**act_args)\n\n\ndef get_dropout_layer(name: tuple | str | float | int, dropout_dim: int | None = 1):\n    \"\"\"\n    Create a dropout layer instance.\n\n    For example, to create dropout layers:\n\n    .. 
code-block:: python\n\n        from monai.networks.layers import get_dropout_layer\n\n        d_layer = get_dropout_layer(name=\"dropout\")\n        a_layer = get_dropout_layer(name=(\"alphadropout\", {\"p\": 0.25}))\n\n    Args:\n        name: a dropout ratio or a tuple of dropout type and parameters.\n        dropout_dim: the spatial dimension of the dropout operation.\n    \"\"\"\n    if name == \"\":\n        return torch.nn.Identity()\n    if isinstance(name, (int, float)):\n        # if dropout was specified simply as a p value, use default name and make a keyword map with the value\n        drop_name = Dropout.DROPOUT\n        drop_args = {\"p\": float(name)}\n    else:\n        drop_name, drop_args = split_args(name)\n    drop_type = Dropout[drop_name, dropout_dim]\n    return drop_type(**drop_args)\n\n\ndef get_pool_layer(name: tuple | str, spatial_dims: int | None = 1):\n    \"\"\"\n    Create a pooling layer instance.\n\n    For example, to create adaptiveavg layer:\n\n    .. 
code-block:: python\n\n        from monai.networks.layers import get_pool_layer\n\n        pool_layer = get_pool_layer((\"adaptiveavg\", {\"output_size\": (1, 1, 1)}), spatial_dims=3)\n\n    Args:\n        name: a pooling type string or a tuple of type string and parameters.\n        spatial_dims: number of spatial dimensions of the input.\n\n    \"\"\"\n    if name == \"\":\n        return torch.nn.Identity()\n    pool_name, pool_args = split_args(name)\n    pool_type = Pool[pool_name, spatial_dims]\n    return pool_type(**pool_args)\n\n\ndef get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: tuple | None, c_dim: int, num_heads: int):\n    embedding_name, embedding_args = split_args(name)\n    embedding_type = RelPosEmbedding[embedding_name]\n    # create a dictionary with the default values which can be overridden by embedding_args\n    kw_args = {\"s_input_dims\": s_input_dims, \"c_dim\": c_dim, \"num_heads\": num_heads, **embedding_args}\n    # filter out unused argument names\n    kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)}\n\n    return embedding_type(**kw_args)\n"
  },
  {
    "path": "monai/networks/layers/vector_quantizer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nfrom torch import nn\n\n__all__ = [\"VectorQuantizer\", \"EMAQuantizer\"]\n\n\nclass EMAQuantizer(nn.Module):\n    \"\"\"\n    Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on  Neural\n    Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation\n    that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit\n    58d9a2746493717a7c9252938da7efa6006f3739.\n\n    This module is not compatible with TorchScript while working in a Distributed Data Parallelism Module. This is due\n    to lack of TorchScript support for torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353\n    on 22/10/2022. If you want to TorchScript your model, please turn set `ddp_sync` to False.\n\n    Args:\n        spatial_dims: number of spatial dimensions of the input.\n        num_embeddings: number of atomic elements in the codebook.\n        embedding_dim: number of channels of the input and atomic elements.\n        commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25.\n        decay: EMA decay. Defaults to 0.99.\n        epsilon: epsilon value. 
Defaults to 1e-5.\n        embedding_init: initialization method for the codebook. Defaults to \"normal\".\n        ddp_sync: whether to synchronize the codebook across processes. Defaults to True.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        num_embeddings: int,\n        embedding_dim: int,\n        commitment_cost: float = 0.25,\n        decay: float = 0.99,\n        epsilon: float = 1e-5,\n        embedding_init: str = \"normal\",\n        ddp_sync: bool = True,\n    ):\n        super().__init__()\n        self.spatial_dims: int = spatial_dims\n        self.embedding_dim: int = embedding_dim\n        self.num_embeddings: int = num_embeddings\n\n        assert self.spatial_dims in [2, 3], ValueError(\n            f\"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}.\"\n        )\n\n        self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim)\n        if embedding_init == \"normal\":\n            # Initialization is passed since the default one is normal inside the nn.Embedding\n            pass\n        elif embedding_init == \"kaiming_uniform\":\n            torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode=\"fan_in\", nonlinearity=\"linear\")\n        self.embedding.weight.requires_grad = False\n\n        self.commitment_cost: float = commitment_cost\n\n        self.register_buffer(\"ema_cluster_size\", torch.zeros(self.num_embeddings))\n        self.register_buffer(\"ema_w\", self.embedding.weight.data.clone())\n        # declare types for mypy\n        self.ema_cluster_size: torch.Tensor\n        self.ema_w: torch.Tensor\n        self.decay: float = decay\n        self.epsilon: float = epsilon\n\n        self.ddp_sync: bool = ddp_sync\n\n        # Precalculating required permutation shapes\n        self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1]\n        self.quantization_permutation: 
Sequence[int] = [0, self.spatial_dims + 1] + list(\n            range(1, self.spatial_dims + 1)\n        )\n\n    def quantize(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss.\n\n        Args:\n            inputs: Encoding space tensors of shape [B, C, H, W, D].\n\n        Returns:\n            torch.Tensor: Flatten version of the input of shape [B*H*W*D, C].\n            torch.Tensor: One-hot representation of the quantization indices of shape [B*H*W*D, self.num_embeddings].\n            torch.Tensor: Quantization indices of shape [B,H,W,D,1]\n\n        \"\"\"\n        with torch.autocast(\"cuda\", enabled=False):\n            encoding_indices_view = list(inputs.shape)\n            del encoding_indices_view[1]\n\n            inputs = inputs.float()\n\n            # Converting to channel last format\n            flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim)\n\n            # Calculate Euclidean distances\n            distances = (\n                (flat_input**2).sum(dim=1, keepdim=True)\n                + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True)\n                - 2 * torch.mm(flat_input, self.embedding.weight.t())\n            )\n\n            # Mapping distances to indexes\n            encoding_indices = torch.max(-distances, dim=1)[1]\n            encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float()\n\n            # Quantize and reshape\n            encoding_indices = encoding_indices.view(encoding_indices_view)\n\n            return flat_input, encodings, encoding_indices\n\n    def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space\n        [B, D, H, W, self.embedding_dim] and reshapes them 
to [B, self.embedding_dim, D, H, W] to be fed to the\n        decoder.\n\n        Args:\n            embedding_indices: Tensor in channel last format which holds indices referencing atomic\n                elements from self.embedding\n\n        Returns:\n            torch.Tensor: Quantize space representation of encoding_indices in channel first format.\n        \"\"\"\n        with torch.autocast(\"cuda\", enabled=False):\n            embedding: torch.Tensor = (\n                self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()\n            )\n            return embedding\n\n    def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None:\n        \"\"\"\n        TorchScript does not support torch.distributed.all_reduce. This function is a bypassing trick based on the\n        example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused\n\n        Args:\n            encodings_sum: The summation of one hot representation of what encoding was used for each\n                position.\n            dw: The multiplication of the one hot representation of what encoding was used for each\n                position with the flattened input.\n\n        Returns:\n            None\n        \"\"\"\n        if self.ddp_sync and torch.distributed.is_initialized():\n            torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM)\n            torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM)\n        else:\n            pass\n\n    def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        flat_input, encodings, encoding_indices = self.quantize(inputs)\n        quantized = self.embed(encoding_indices)\n\n        # Use EMA to update the embedding vectors\n        if self.training:\n            with torch.no_grad():\n                encodings_sum = encodings.sum(0)\n                dw = 
torch.mm(encodings.t(), flat_input)\n\n                if self.ddp_sync:\n                    self.distributed_synchronization(encodings_sum, dw)\n\n                self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay))\n\n                # Laplace smoothing of the cluster size\n                n = self.ema_cluster_size.sum()\n                weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n\n                self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay))\n                self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1))\n\n        # Encoding Loss\n        loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs)\n\n        # Straight Through Estimator\n        quantized = inputs + (quantized - inputs).detach()\n\n        return quantized, loss, encoding_indices\n\n\nclass VectorQuantizer(torch.nn.Module):\n    \"\"\"\n    Vector Quantization wrapper that is needed as a workaround for the AMP to isolate the non fp16 compatible parts of\n    the quantization in their own class.\n\n    Args:\n        quantizer (torch.nn.Module):  Quantizer module that needs to return its quantized representation, loss and index\n            based quantized representation.\n    \"\"\"\n\n    def __init__(self, quantizer: EMAQuantizer):\n        super().__init__()\n\n        self.quantizer: EMAQuantizer = quantizer\n\n        self.perplexity: torch.Tensor = torch.rand(1)\n\n    def forward(self, inputs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        quantized, loss, encoding_indices = self.quantizer(inputs)\n        # Perplexity calculations\n        avg_probs = (\n            torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings)\n            .float()\n            .div(encoding_indices.numel())\n        )\n\n        self.perplexity = torch.exp(-torch.sum(avg_probs * 
torch.log(avg_probs + 1e-10)))\n\n        return loss, quantized\n\n    def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:\n        return self.quantizer.embed(embedding_indices=embedding_indices)\n\n    def quantize(self, encodings: torch.Tensor) -> torch.Tensor:\n        output = self.quantizer(encodings)\n        encoding_indices: torch.Tensor = output[2]\n        return encoding_indices\n"
  },
  {
    "path": "monai/networks/layers/weight_init.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\n\nimport torch\n\n\ndef _no_grad_trunc_normal_(tensor, mean, std, a, b):\n    \"\"\"Tensor initialization with truncated normal distribution.\n    Based on:\n    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n    https://github.com/rwightman/pytorch-image-models\n\n    Args:\n       tensor: an n-dimensional `torch.Tensor`.\n       mean: the mean of the normal distribution.\n       std: the standard deviation of the normal distribution.\n       a: the minimum cutoff value.\n       b: the maximum cutoff value.\n    \"\"\"\n\n    def norm_cdf(x):\n        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0\n\n    with torch.no_grad():\n        l = norm_cdf((a - mean) / std)\n        u = norm_cdf((b - mean) / std)\n        tensor.uniform_(2 * l - 1, 2 * u - 1)\n        tensor.erfinv_()\n        tensor.mul_(std * math.sqrt(2.0))\n        tensor.add_(mean)\n        tensor.clamp_(min=a, max=b)\n        return tensor\n\n\ndef trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):\n    \"\"\"Tensor initialization with truncated normal distribution.\n    Based on:\n    https://github.com/rwightman/pytorch-image-models\n\n    Args:\n       tensor: an n-dimensional `torch.Tensor`\n       mean: the mean of the normal distribution\n       std: the standard deviation of the normal distribution\n       a: the 
minimum cutoff value\n       b: the maximum cutoff value\n    \"\"\"\n\n    if std <= 0:\n        raise ValueError(\"the standard deviation should be greater than zero.\")\n\n    if a >= b:\n        raise ValueError(\"minimum cutoff value (a) should be smaller than maximum cutoff value (b).\")\n\n    return _no_grad_trunc_normal_(tensor, mean, std, a, b)\n"
  },
  {
    "path": "monai/networks/nets/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .ahnet import AHnet, Ahnet, AHNet\nfrom .attentionunet import AttentionUnet\nfrom .autoencoder import AutoEncoder\nfrom .autoencoderkl import AutoencoderKL\nfrom .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet\nfrom .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus\nfrom .classifier import Classifier, Critic, Discriminator\nfrom .controlnet import ControlNet\nfrom .daf3d import DAF3D\nfrom .densenet import (\n    DenseNet,\n    Densenet,\n    DenseNet121,\n    Densenet121,\n    DenseNet169,\n    Densenet169,\n    DenseNet201,\n    Densenet201,\n    DenseNet264,\n    Densenet264,\n    densenet121,\n    densenet169,\n    densenet201,\n    densenet264,\n)\nfrom .diffusion_model_unet import DiffusionModelUNet\nfrom .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch\nfrom .dynunet import DynUNet, DynUnet, Dynunet\nfrom .efficientnet import (\n    BlockArgs,\n    EfficientNet,\n    EfficientNetBN,\n    EfficientNetBNFeatures,\n    EfficientNetEncoder,\n    drop_connect,\n    get_efficientnet_image_size,\n)\nfrom .flexible_unet import FLEXUNET_BACKBONE, FlexibleUNet, FlexUNet, FlexUNetEncoderRegister\nfrom .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet\nfrom .generator import Generator\nfrom .highresnet import HighResBlock, 
HighResNet\nfrom .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet\nfrom .masked_autoencoder_vit import MaskedAutoEncoderViT\nfrom .mednext import (\n    MedNeXt,\n    MedNext,\n    MedNextB,\n    MedNeXtB,\n    MedNextBase,\n    MedNextL,\n    MedNeXtL,\n    MedNeXtLarge,\n    MedNextLarge,\n    MedNextM,\n    MedNeXtM,\n    MedNeXtMedium,\n    MedNextMedium,\n    MedNextS,\n    MedNeXtS,\n    MedNeXtSmall,\n    MedNextSmall,\n)\nfrom .milmodel import MILModel\nfrom .netadapter import NetAdapter\nfrom .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator\nfrom .quicknat import Quicknat\nfrom .regressor import Regressor\nfrom .regunet import GlobalNet, LocalNet, RegUNet\nfrom .resnet import (\n    ResNet,\n    ResNetBlock,\n    ResNetBottleneck,\n    ResNetEncoder,\n    ResNetFeatures,\n    get_medicalnet_pretrained_resnet_args,\n    get_pretrained_resnet_medicalnet,\n    resnet10,\n    resnet18,\n    resnet34,\n    resnet50,\n    resnet101,\n    resnet152,\n    resnet200,\n)\nfrom .segresnet import SegResNet, SegResNetVAE\nfrom .segresnet_ds import SegResNetDS, SegResNetDS2\nfrom .senet import (\n    SENet,\n    SEnet,\n    Senet,\n    SENet154,\n    SEnet154,\n    Senet154,\n    SEResNet50,\n    SEresnet50,\n    Seresnet50,\n    SEResNet101,\n    SEresnet101,\n    Seresnet101,\n    SEResNet152,\n    SEresnet152,\n    Seresnet152,\n    SEResNext50,\n    SEResNeXt50,\n    SEresnext50,\n    Seresnext50,\n    SEResNext101,\n    SEResNeXt101,\n    SEresnext101,\n    Seresnext101,\n    senet154,\n    seresnet50,\n    seresnet101,\n    seresnet152,\n    seresnext50,\n    seresnext101,\n)\nfrom .spade_autoencoderkl import SPADEAutoencoderKL\nfrom .spade_diffusion_model_unet import SPADEDiffusionModelUNet\nfrom .spade_network import SPADENet\nfrom .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR\nfrom .torchvision_fc import TorchVisionFCModel\nfrom .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, 
MultiModal, Pooler, Transchex\nfrom .transformer import DecoderOnlyTransformer\nfrom .unet import UNet, Unet\nfrom .unetr import UNETR\nfrom .varautoencoder import VarAutoEncoder\nfrom .vista3d import VISTA3D, vista3d132\nfrom .vit import ViT\nfrom .vitautoenc import ViTAutoEnc\nfrom .vnet import VNet\nfrom .voxelmorph import VoxelMorph, VoxelMorphUNet\nfrom .vqvae import VQVAE\n"
  },
  {
    "path": "monai/networks/nets/ahnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.blocks.fcn import FCN\nfrom monai.networks.layers.factories import Act, Conv, Norm, Pool\n\n__all__ = [\"AHnet\", \"Ahnet\", \"AHNet\"]\n\n\nclass Bottleneck3x3x1(nn.Module):\n    expansion = 4\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        inplanes: int,\n        planes: int,\n        stride: Sequence[int] | int = 1,\n        downsample: nn.Sequential | None = None,\n    ) -> None:\n        super().__init__()\n\n        conv_type = Conv[Conv.CONV, spatial_dims]\n        norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]\n        pool_type: type[nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]\n        relu_type: type[nn.ReLU] = Act[Act.RELU]\n\n        self.conv1 = conv_type(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = norm_type(planes)\n        self.conv2 = conv_type(\n            planes,\n            planes,\n            kernel_size=(3, 3, 1)[-spatial_dims:],\n            stride=stride,\n            padding=(1, 1, 0)[-spatial_dims:],\n            bias=False,\n        )\n        self.bn2 = norm_type(planes)\n        self.conv3 = conv_type(planes, planes * 4, kernel_size=1, bias=False)\n        self.bn3 = 
norm_type(planes * 4)\n        self.relu = relu_type(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n        self.pool = pool_type(kernel_size=(1, 1, 2)[-spatial_dims:], stride=(1, 1, 2)[-spatial_dims:])\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n            if out.size() != residual.size():\n                out = self.pool(out)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass Projection(nn.Sequential):\n\n    def __init__(self, spatial_dims: int, num_input_features: int, num_output_features: int):\n        super().__init__()\n\n        conv_type = Conv[Conv.CONV, spatial_dims]\n        norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]\n        relu_type: type[nn.ReLU] = Act[Act.RELU]\n\n        self.add_module(\"norm\", norm_type(num_input_features))\n        self.add_module(\"relu\", relu_type(inplace=True))\n        self.add_module(\"conv\", conv_type(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))\n\n\nclass DenseBlock(nn.Sequential):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        num_layers: int,\n        num_input_features: int,\n        bn_size: int,\n        growth_rate: int,\n        dropout_prob: float,\n    ):\n        super().__init__()\n        for i in range(num_layers):\n            layer = Pseudo3DLayer(\n                spatial_dims, num_input_features + i * growth_rate, growth_rate, bn_size, dropout_prob\n            )\n            self.add_module(f\"denselayer{i + 1}\", layer)\n\n\nclass UpTransition(nn.Sequential):\n\n    def __init__(\n        self, 
spatial_dims: int, num_input_features: int, num_output_features: int, upsample_mode: str = \"transpose\"\n    ):\n        super().__init__()\n\n        conv_type = Conv[Conv.CONV, spatial_dims]\n        norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]\n        relu_type: type[nn.ReLU] = Act[Act.RELU]\n\n        self.add_module(\"norm\", norm_type(num_input_features))\n        self.add_module(\"relu\", relu_type(inplace=True))\n        self.add_module(\"conv\", conv_type(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))\n        if upsample_mode == \"transpose\":\n            conv_trans_type = Conv[Conv.CONVTRANS, spatial_dims]\n            self.add_module(\n                \"up\", conv_trans_type(num_output_features, num_output_features, kernel_size=2, stride=2, bias=False)\n            )\n        else:\n            align_corners: bool | None = None\n            if upsample_mode in [\"trilinear\", \"bilinear\"]:\n                align_corners = True\n            self.add_module(\"up\", nn.Upsample(scale_factor=2, mode=upsample_mode, align_corners=align_corners))\n\n\nclass Final(nn.Sequential):\n\n    def __init__(\n        self, spatial_dims: int, num_input_features: int, num_output_features: int, upsample_mode: str = \"transpose\"\n    ):\n        super().__init__()\n\n        conv_type = Conv[Conv.CONV, spatial_dims]\n        norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]\n        relu_type: type[nn.ReLU] = Act[Act.RELU]\n\n        self.add_module(\"norm\", norm_type(num_input_features))\n        self.add_module(\"relu\", relu_type(inplace=True))\n        self.add_module(\n            \"conv\",\n            conv_type(\n                num_input_features,\n                num_output_features,\n                kernel_size=(3, 3, 1)[-spatial_dims:],\n                stride=1,\n                padding=(1, 1, 0)[-spatial_dims:],\n                bias=False,\n            
),\n        )\n        if upsample_mode == \"transpose\":\n            conv_trans_type = Conv[Conv.CONVTRANS, spatial_dims]\n            self.add_module(\n                \"up\", conv_trans_type(num_output_features, num_output_features, kernel_size=2, stride=2, bias=False)\n            )\n        else:\n            align_corners: bool | None = None\n            if upsample_mode in [\"trilinear\", \"bilinear\"]:\n                align_corners = True\n            self.add_module(\"up\", nn.Upsample(scale_factor=2, mode=upsample_mode, align_corners=align_corners))\n\n\nclass Pseudo3DLayer(nn.Module):\n\n    def __init__(self, spatial_dims: int, num_input_features: int, growth_rate: int, bn_size: int, dropout_prob: float):\n        super().__init__()\n        # 1x1x1\n\n        conv_type = Conv[Conv.CONV, spatial_dims]\n        norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]\n        relu_type: type[nn.ReLU] = Act[Act.RELU]\n\n        self.bn1 = norm_type(num_input_features)\n        self.relu1 = relu_type(inplace=True)\n        self.conv1 = conv_type(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)\n        # 3x3x1\n        self.bn2 = norm_type(bn_size * growth_rate)\n        self.relu2 = relu_type(inplace=True)\n        self.conv2 = conv_type(\n            bn_size * growth_rate,\n            growth_rate,\n            kernel_size=(3, 3, 1)[-spatial_dims:],\n            stride=1,\n            padding=(1, 1, 0)[-spatial_dims:],\n            bias=False,\n        )\n        # 1x1x3\n        self.bn3 = norm_type(growth_rate)\n        self.relu3 = relu_type(inplace=True)\n        self.conv3 = conv_type(\n            growth_rate,\n            growth_rate,\n            kernel_size=(1, 1, 3)[-spatial_dims:],\n            stride=1,\n            padding=(0, 0, 1)[-spatial_dims:],\n            bias=False,\n        )\n        # 1x1x1\n        self.bn4 = norm_type(growth_rate)\n        self.relu4 = 
relu_type(inplace=True)\n        self.conv4 = conv_type(growth_rate, growth_rate, kernel_size=1, stride=1, bias=False)\n        self.dropout_prob = dropout_prob\n\n    def forward(self, x):\n        inx = x\n        x = self.bn1(x)\n        x = self.relu1(x)\n        x = self.conv1(x)\n\n        x = self.bn2(x)\n        x = self.relu2(x)\n        x3x3x1 = self.conv2(x)\n\n        x = self.bn3(x3x3x1)\n        x = self.relu3(x)\n        x1x1x3 = self.conv3(x)\n\n        x = x3x3x1 + x1x1x3\n        x = self.bn4(x)\n        x = self.relu4(x)\n        new_features = self.conv4(x)\n\n        self.dropout_prob = 0.0  # Dropout will make trouble!\n        # since we use the train mode for inference\n        if self.dropout_prob > 0.0:\n            new_features = F.dropout(new_features, p=self.dropout_prob, training=self.training)\n        return torch.cat([inx, new_features], 1)\n\n\nclass PSP(nn.Module):\n\n    def __init__(self, spatial_dims: int, psp_block_num: int, in_ch: int, upsample_mode: str = \"transpose\"):\n        super().__init__()\n        self.up_modules = nn.ModuleList()\n        conv_type = Conv[Conv.CONV, spatial_dims]\n        pool_type: type[nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]\n\n        self.pool_modules = nn.ModuleList()\n        self.project_modules = nn.ModuleList()\n\n        for i in range(psp_block_num):\n            size = (2 ** (i + 3), 2 ** (i + 3), 1)[-spatial_dims:]\n            self.pool_modules.append(pool_type(kernel_size=size, stride=size))\n            self.project_modules.append(\n                conv_type(in_ch, 1, kernel_size=(1, 1, 1)[-spatial_dims:], stride=1, padding=(1, 1, 0)[-spatial_dims:])\n            )\n\n        self.spatial_dims = spatial_dims\n        self.psp_block_num = psp_block_num\n        self.upsample_mode = upsample_mode\n\n        if self.upsample_mode == \"transpose\":\n            conv_trans_type = Conv[Conv.CONVTRANS, spatial_dims]\n            for i in range(psp_block_num):\n         
       size = (2 ** (i + 3), 2 ** (i + 3), 1)[-spatial_dims:]\n                pad_size = (2 ** (i + 3), 2 ** (i + 3), 0)[-spatial_dims:]\n                self.up_modules.append(conv_trans_type(1, 1, kernel_size=size, stride=size, padding=pad_size))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        outputs = []\n        if self.upsample_mode == \"transpose\":\n            for project_module, pool_module, up_module in zip(self.project_modules, self.pool_modules, self.up_modules):\n                output = up_module(project_module(pool_module(x)))\n                outputs.append(output)\n        else:\n            for project_module, pool_module in zip(self.project_modules, self.pool_modules):\n                interpolate_size = x.shape[2:]\n                align_corners: bool | None = None\n                if self.upsample_mode in [\"trilinear\", \"bilinear\"]:\n                    align_corners = True\n                output = F.interpolate(\n                    project_module(pool_module(x)),\n                    size=interpolate_size,\n                    mode=self.upsample_mode,\n                    align_corners=align_corners,\n                )\n                outputs.append(output)\n        x = torch.cat(outputs, dim=1)\n        return x\n\n\nclass AHNet(nn.Module):\n    \"\"\"\n    AHNet based on `Anisotropic Hybrid Network <https://arxiv.org/pdf/1711.08580.pdf>`_.\n    Adapted from `lsqshr's official code <https://github.com/lsqshr/AH-Net/blob/master/net3d.py>`_.\n    Except from the original network that supports 3D inputs, this implementation also supports 2D inputs.\n    According to the `tests for deconvolutions <https://github.com/Project-MONAI/MONAI/issues/1023>`_, using\n    ``\"transpose\"`` rather than linear interpolations is faster. 
Therefore, this implementation sets ``\"transpose\"``\n    as the default upsampling method.\n\n    To meet the requirements of the structure, the input size for each spatial dimension\n    (except the last one) should be: divisible by 2 ** (psp_block_num + 3) and no less than 32 in ``transpose`` mode,\n    and should be divisible by 32 and no less than 2 ** (psp_block_num + 3) in other upsample modes.\n    In addition, the input size for the last spatial dimension should be divisible by 32, and at least one spatial size\n    should be no less than 64.\n\n    Args:\n        layers: number of residual blocks for 4 layers of the network (layer1...layer4). Defaults to ``(3, 4, 6, 3)``.\n        spatial_dims: spatial dimension of the input data. Defaults to 3.\n        in_channels: number of input channels for the network. Default to 1.\n        out_channels: number of output channels for the network. Defaults to 1.\n        psp_block_num: the number of pyramid volumetric pooling modules used at the end of the network before the final\n            output layer for extracting multiscale features. The number should be an integer that belongs to [0,4]. Defaults\n            to 4.\n        upsample_mode: [``\"transpose\"``, ``\"bilinear\"``, ``\"trilinear\"``, ``nearest``]\n            The mode of upsampling manipulations.\n            Using the last two modes cannot guarantee the model's reproducibility. 
Defaults to ``transpose``.\n\n            - ``\"transpose\"``, uses transposed convolution layers.\n            - ``\"bilinear\"``, uses bilinear interpolate.\n            - ``\"trilinear\"``, uses trilinear interpolate.\n            - ``\"nearest\"``, uses nearest interpolate.\n        pretrained: whether to load pretrained weights from ResNet50 to initialize convolution layers, default to False.\n        progress: If True, displays a progress bar of the download of pretrained weights to stderr.\n    \"\"\"\n\n    def __init__(\n        self,\n        layers: tuple = (3, 4, 6, 3),\n        spatial_dims: int = 3,\n        in_channels: int = 1,\n        out_channels: int = 1,\n        psp_block_num: int = 4,\n        upsample_mode: str = \"transpose\",\n        pretrained: bool = False,\n        progress: bool = True,\n    ):\n        self.inplanes = 64\n        super().__init__()\n\n        conv_type = Conv[Conv.CONV, spatial_dims]\n        conv_trans_type = Conv[Conv.CONVTRANS, spatial_dims]\n        norm_type = Norm[Norm.BATCH, spatial_dims]\n        pool_type: type[nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]\n        relu_type: type[nn.ReLU] = Act[Act.RELU]\n        conv2d_type: type[nn.Conv2d] = Conv[Conv.CONV, 2]\n        norm2d_type: type[nn.BatchNorm2d] = Norm[Norm.BATCH, 2]\n\n        self.conv2d_type = conv2d_type\n        self.norm2d_type = norm2d_type\n        self.conv_type = conv_type\n        self.norm_type = norm_type\n        self.relu_type = relu_type\n        self.pool_type = pool_type\n        self.spatial_dims = spatial_dims\n        self.psp_block_num = psp_block_num\n        self.psp: PSP\n\n        if spatial_dims not in [2, 3]:\n            raise AssertionError(\"spatial_dims can only be 2 or 3.\")\n        if psp_block_num not in [0, 1, 2, 3, 4]:\n            raise AssertionError(\"psp_block_num should be an integer that belongs to [0, 4].\")\n\n        self.conv1 = conv_type(\n            in_channels,\n            64,\n     
       kernel_size=(7, 7, 3)[-spatial_dims:],\n            stride=(2, 2, 1)[-spatial_dims:],\n            padding=(3, 3, 1)[-spatial_dims:],\n            bias=False,\n        )\n        self.pool1 = pool_type(kernel_size=(1, 1, 2)[-spatial_dims:], stride=(1, 1, 2)[-spatial_dims:])\n        self.bn0 = norm_type(64)\n        self.relu = relu_type(inplace=True)\n        if upsample_mode in [\"transpose\", \"nearest\"]:\n            # To maintain the determinism, the value of kernel_size and stride should be the same.\n            # (you can check this link for reference: https://github.com/Project-MONAI/MONAI/pull/815 )\n            self.maxpool = pool_type(kernel_size=(2, 2, 2)[-spatial_dims:], stride=2)\n        else:\n            self.maxpool = pool_type(kernel_size=(3, 3, 3)[-spatial_dims:], stride=2, padding=1)\n\n        self.layer1 = self._make_layer(Bottleneck3x3x1, 64, layers[0], stride=1)\n        self.layer2 = self._make_layer(Bottleneck3x3x1, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(Bottleneck3x3x1, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(Bottleneck3x3x1, 512, layers[3], stride=2)\n\n        # Make the 3D dense decoder layers\n        densegrowth = 20\n        densebn = 4\n        ndenselayer = 3\n\n        num_init_features = 64\n        noutres1 = 256\n        noutres2 = 512\n        noutres3 = 1024\n        noutres4 = 2048\n\n        self.up0 = UpTransition(spatial_dims, noutres4, noutres3, upsample_mode)\n        self.dense0 = DenseBlock(spatial_dims, ndenselayer, noutres3, densebn, densegrowth, 0.0)\n        noutdense = noutres3 + ndenselayer * densegrowth\n\n        self.up1 = UpTransition(spatial_dims, noutdense, noutres2, upsample_mode)\n        self.dense1 = DenseBlock(spatial_dims, ndenselayer, noutres2, densebn, densegrowth, 0.0)\n        noutdense1 = noutres2 + ndenselayer * densegrowth\n\n        self.up2 = UpTransition(spatial_dims, noutdense1, noutres1, upsample_mode)\n        self.dense2 = 
DenseBlock(spatial_dims, ndenselayer, noutres1, densebn, densegrowth, 0.0)\n        noutdense2 = noutres1 + ndenselayer * densegrowth\n\n        self.trans1 = Projection(spatial_dims, noutdense2, num_init_features)\n        self.dense3 = DenseBlock(spatial_dims, ndenselayer, num_init_features, densebn, densegrowth, 0.0)\n        noutdense3 = num_init_features + densegrowth * ndenselayer\n\n        self.up3 = UpTransition(spatial_dims, noutdense3, num_init_features, upsample_mode)\n        self.dense4 = DenseBlock(spatial_dims, ndenselayer, num_init_features, densebn, densegrowth, 0.0)\n        noutdense4 = num_init_features + densegrowth * ndenselayer\n\n        self.psp = PSP(spatial_dims, psp_block_num, noutdense4, upsample_mode)\n        self.final = Final(spatial_dims, psp_block_num + noutdense4, out_channels, upsample_mode)\n\n        # Initialise parameters\n        for m in self.modules():\n            if isinstance(m, (conv_type, conv_trans_type)):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2.0 / n))\n            elif isinstance(m, norm_type):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n        if pretrained:\n            net2d = FCN(pretrained=True, progress=progress)\n            self.copy_from(net2d)\n\n    def _make_layer(self, block: type[Bottleneck3x3x1], planes: int, blocks: int, stride: int = 1) -> nn.Sequential:\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                self.conv_type(\n                    self.inplanes,\n                    planes * block.expansion,\n                    kernel_size=1,\n                    stride=(stride, stride, 1)[: self.spatial_dims],\n                    bias=False,\n                ),\n                self.pool_type(\n                    kernel_size=(1, 1, stride)[: self.spatial_dims], stride=(1, 1, 
stride)[: self.spatial_dims]\n                ),\n                self.norm_type(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(\n            block(self.spatial_dims, self.inplanes, planes, (stride, stride, 1)[: self.spatial_dims], downsample)\n        )\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.spatial_dims, self.inplanes, planes))\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.pool1(x)\n        x = self.bn0(x)\n        x = self.relu(x)\n        conv_x = x\n        x = self.maxpool(x)\n        pool_x = x\n\n        fm1 = self.layer1(x)\n        fm2 = self.layer2(fm1)\n        fm3 = self.layer3(fm2)\n        fm4 = self.layer4(fm3)\n\n        sum0 = self.up0(fm4) + fm3\n        d0 = self.dense0(sum0)\n\n        sum1 = self.up1(d0) + fm2\n        d1 = self.dense1(sum1)\n\n        sum2 = self.up2(d1) + fm1\n        d2 = self.dense2(sum2)\n\n        sum3 = self.trans1(d2) + pool_x\n        d3 = self.dense3(sum3)\n\n        sum4 = self.up3(d3) + conv_x\n        d4 = self.dense4(sum4)\n        if self.psp_block_num > 0:\n            psp = self.psp(d4)\n            x = torch.cat((psp, d4), dim=1)\n        else:\n            x = d4\n        return self.final(x)\n\n    def copy_from(self, net):\n        # This method only supports for 3D AHNet, the input channel should be 1.\n        p2d, p3d = next(net.conv1.parameters()), next(self.conv1.parameters())\n\n        # From 64x3x7x7 -> 64x3x7x7x1 -> 64x1x7x7x3\n        weights = p2d.data.unsqueeze(dim=4).permute(0, 4, 2, 3, 1).clone()\n        p3d.data = weights.repeat([1, p3d.shape[1], 1, 1, 1])\n\n        # Copy the initial module BN0\n        copy_bn_param(net.bn0, self.bn0)\n\n        # Copy layer1 to layer4\n        for i in range(1, 5):\n            layer_num = \"layer\" + str(i)\n\n            layer_2d = []\n            layer_3d = 
[]\n            for m1 in vars(net)[\"_modules\"][layer_num].modules():\n                if isinstance(m1, (self.norm2d_type, self.conv2d_type)):\n                    layer_2d.append(m1)\n            for m2 in vars(self)[\"_modules\"][layer_num].modules():\n                if isinstance(m2, (self.norm_type, self.conv_type)):\n                    layer_3d.append(m2)\n\n            for m1, m2 in zip(layer_2d, layer_3d):\n                if isinstance(m1, self.conv2d_type):\n                    copy_conv_param(m1, m2)\n                if isinstance(m1, self.norm2d_type):\n                    copy_bn_param(m1, m2)\n\n\ndef copy_conv_param(module2d, module3d):\n    for p2d, p3d in zip(module2d.parameters(), module3d.parameters()):\n        p3d.data[:] = p2d.data.unsqueeze(dim=4).clone()[:]\n\n\ndef copy_bn_param(module2d, module3d):\n    for p2d, p3d in zip(module2d.parameters(), module3d.parameters()):\n        p3d.data[:] = p2d.data[:]  # Two parameter gamma and beta\n\n\nAHnet = Ahnet = AHNet\n"
  },
  {
    "path": "monai/networks/nets/attentionunet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.convolutions import Convolution\nfrom monai.networks.layers.factories import Norm\n\n__all__ = [\"AttentionUnet\"]\n\n\nclass ConvBlock(nn.Module):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Sequence[int] | int = 3,\n        strides: int = 1,\n        dropout=0.0,\n    ):\n        super().__init__()\n        layers = [\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n                strides=strides,\n                padding=None,\n                adn_ordering=\"NDA\",\n                act=\"relu\",\n                norm=Norm.BATCH,\n                dropout=dropout,\n            ),\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=out_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n                strides=1,\n                padding=None,\n                adn_ordering=\"NDA\",\n                act=\"relu\",\n                norm=Norm.BATCH,\n                dropout=dropout,\n            ),\n  
      ]\n        self.conv = nn.Sequential(*layers)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x_c: torch.Tensor = self.conv(x)\n        return x_c\n\n\nclass UpConv(nn.Module):\n\n    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, kernel_size=3, strides=2, dropout=0.0):\n        super().__init__()\n        self.up = Convolution(\n            spatial_dims,\n            in_channels,\n            out_channels,\n            strides=strides,\n            kernel_size=kernel_size,\n            act=\"relu\",\n            adn_ordering=\"NDA\",\n            norm=Norm.BATCH,\n            dropout=dropout,\n            is_transposed=True,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x_u: torch.Tensor = self.up(x)\n        return x_u\n\n\nclass AttentionBlock(nn.Module):\n\n    def __init__(self, spatial_dims: int, f_int: int, f_g: int, f_l: int, dropout=0.0):\n        super().__init__()\n        self.W_g = nn.Sequential(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=f_g,\n                out_channels=f_int,\n                kernel_size=1,\n                strides=1,\n                padding=0,\n                dropout=dropout,\n                conv_only=True,\n            ),\n            Norm[Norm.BATCH, spatial_dims](f_int),\n        )\n\n        self.W_x = nn.Sequential(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=f_l,\n                out_channels=f_int,\n                kernel_size=1,\n                strides=1,\n                padding=0,\n                dropout=dropout,\n                conv_only=True,\n            ),\n            Norm[Norm.BATCH, spatial_dims](f_int),\n        )\n\n        self.psi = nn.Sequential(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=f_int,\n                out_channels=1,\n                
kernel_size=1,\n                strides=1,\n                padding=0,\n                dropout=dropout,\n                conv_only=True,\n            ),\n            Norm[Norm.BATCH, spatial_dims](1),\n            nn.Sigmoid(),\n        )\n\n        self.relu = nn.ReLU()\n\n    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:\n        g1 = self.W_g(g)\n        x1 = self.W_x(x)\n        psi: torch.Tensor = self.relu(g1 + x1)\n        psi = self.psi(psi)\n\n        return x * psi\n\n\nclass AttentionLayer(nn.Module):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        submodule: nn.Module,\n        up_kernel_size=3,\n        strides=2,\n        dropout=0.0,\n    ):\n        super().__init__()\n        self.attention = AttentionBlock(\n            spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2\n        )\n        self.upconv = UpConv(\n            spatial_dims=spatial_dims,\n            in_channels=out_channels,\n            out_channels=in_channels,\n            strides=strides,\n            kernel_size=up_kernel_size,\n        )\n        self.merge = Convolution(\n            spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout\n        )\n        self.submodule = submodule\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        fromlower = self.upconv(self.submodule(x))\n        att = self.attention(g=fromlower, x=x)\n        att_m: torch.Tensor = self.merge(torch.cat((att, fromlower), dim=1))\n        return att_m\n\n\nclass AttentionUnet(nn.Module):\n    \"\"\"\n    Attention Unet based on\n    Oktay et al. 
\"Attention U-Net: Learning Where to Look for the Pancreas\"\n    https://arxiv.org/abs/1804.03999\n\n    Args:\n        spatial_dims: number of spatial dimensions of the input image.\n        in_channels: number of the input channel.\n        out_channels: number of the output classes.\n        channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2.\n        strides (Sequence[int]): stride to use for convolutions.\n        kernel_size: convolution kernel size.\n        up_kernel_size: convolution kernel size for transposed convolution layers.\n        dropout: dropout ratio. Defaults to no dropout.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        channels: Sequence[int],\n        strides: Sequence[int],\n        kernel_size: Sequence[int] | int = 3,\n        up_kernel_size: Sequence[int] | int = 3,\n        dropout: float = 0.0,\n    ):\n        super().__init__()\n        self.dimensions = spatial_dims\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.channels = channels\n        self.strides = strides\n        self.kernel_size = kernel_size\n        self.dropout = dropout\n\n        head = ConvBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=channels[0],\n            dropout=dropout,\n            kernel_size=self.kernel_size,\n        )\n        reduce_channels = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=channels[0],\n            out_channels=out_channels,\n            kernel_size=1,\n            strides=1,\n            padding=0,\n            conv_only=True,\n        )\n        self.up_kernel_size = up_kernel_size\n\n        def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module:\n            if len(channels) > 2:\n                subblock = 
_create_block(channels[1:], strides[1:])\n                return AttentionLayer(\n                    spatial_dims=spatial_dims,\n                    in_channels=channels[0],\n                    out_channels=channels[1],\n                    submodule=nn.Sequential(\n                        ConvBlock(\n                            spatial_dims=spatial_dims,\n                            in_channels=channels[0],\n                            out_channels=channels[1],\n                            strides=strides[0],\n                            dropout=self.dropout,\n                            kernel_size=self.kernel_size,\n                        ),\n                        subblock,\n                    ),\n                    up_kernel_size=self.up_kernel_size,\n                    strides=strides[0],\n                    dropout=dropout,\n                )\n            else:\n                # the next layer is the bottom so stop recursion,\n                # create the bottom layer as the subblock for this layer\n                return self._get_bottom_layer(channels[0], channels[1], strides[0])\n\n        encdec = _create_block(self.channels, self.strides)\n        self.model = nn.Sequential(head, encdec, reduce_channels)\n\n    def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -> nn.Module:\n        return AttentionLayer(\n            spatial_dims=self.dimensions,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            submodule=ConvBlock(\n                spatial_dims=self.dimensions,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                strides=strides,\n                dropout=self.dropout,\n                kernel_size=self.kernel_size,\n            ),\n            up_kernel_size=self.up_kernel_size,\n            strides=strides,\n            dropout=self.dropout,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x_m: 
torch.Tensor = self.model(x)\n        return x_m\n"
  },
  {
    "path": "monai/networks/nets/autoencoder.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import Convolution, ResidualUnit\nfrom monai.networks.layers.factories import Act, Norm\n\n__all__ = [\"AutoEncoder\"]\n\n\nclass AutoEncoder(nn.Module):\n    \"\"\"\n    Simple definition of an autoencoder and base class for the architecture implementing\n    :py:class:`monai.networks.nets.VarAutoEncoder`. The network is composed of an encode sequence of blocks, followed\n    by an intermediary sequence of blocks, and finally a decode sequence of blocks. The encode and decode blocks are\n    default :py:class:`monai.networks.blocks.Convolution` instances with the encode blocks having the given stride\n    and the decode blocks having transpose convolutions with the same stride. If `num_res_units` is given residual\n    blocks are used instead.\n\n    By default the intermediary sequence is empty but if `inter_channels` is given to specify the output channels of\n    blocks then this will be become a sequence of Convolution blocks or of residual blocks if `num_inter_units` is\n    given. The optional parameter `inter_dilations` can be used to specify the dilation values of the convolutions in\n    these blocks, this allows a network to use dilated kernels in this  middle section. 
Since the intermediary section\n    isn't meant to change the size of the output the strides for all these kernels is 1.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        channels: sequence of channels. Top block first. The length of `channels` should be no less than 2.\n        strides: sequence of convolution strides. The length of `strides` should equal `len(channels)`.\n        kernel_size: convolution kernel size, the value(s) should be odd. If sequence,\n            its length should equal to dimensions. Defaults to 3.\n        up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence,\n            its length should equal to dimensions. Defaults to 3.\n        num_res_units: number of residual units. Defaults to 0.\n        inter_channels: sequence of channels defining the blocks in the intermediate layer between encode and decode.\n        inter_dilations: defines the dilation value for each block of the intermediate layer. Defaults to 1.\n        num_inter_units: number of residual units for each block of the intermediate layer. Defaults to 2.\n        act: activation type and arguments. Defaults to PReLU.\n        norm: feature normalization type and arguments. Defaults to instance norm.\n        dropout: dropout ratio. Defaults to no dropout.\n        bias: whether to have a bias term in convolution blocks. Defaults to True.\n            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,\n            if a conv layer is directly followed by a batch norm layer, bias should be False.\n        padding: controls the amount of implicit zero-paddings on both sides for padding number of points\n            for each dimension in convolution blocks. 
Defaults to None.\n\n    Examples::\n\n        from monai.networks.nets import AutoEncoder\n\n        # 3 layers each down/up sampling their inputs by a factor 2 with no intermediate layer\n        net = AutoEncoder(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            channels=(2, 4, 8),\n            strides=(2, 2, 2)\n        )\n\n        # 1 layer downsampling by 2, followed by a sequence of residual units with 2 convolutions defined by\n        # progressively increasing dilations, then final upsample layer\n        net = AutoEncoder(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                channels=(4,),\n                strides=(2,),\n                inter_channels=(8, 8, 8),\n                inter_dilations=(1, 2, 4),\n                num_inter_units=2\n            )\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        channels: Sequence[int],\n        strides: Sequence[int],\n        kernel_size: Sequence[int] | int = 3,\n        up_kernel_size: Sequence[int] | int = 3,\n        num_res_units: int = 0,\n        inter_channels: list | None = None,\n        inter_dilations: list | None = None,\n        num_inter_units: int = 2,\n        act: tuple | str | None = Act.PRELU,\n        norm: tuple | str = Norm.INSTANCE,\n        dropout: tuple | str | float | None = None,\n        bias: bool = True,\n        padding: Sequence[int] | int | None = None,\n    ) -> None:\n        super().__init__()\n        self.dimensions = spatial_dims\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.channels = list(channels)\n        self.strides = list(strides)\n        self.kernel_size = kernel_size\n        self.up_kernel_size = up_kernel_size\n        self.num_res_units = num_res_units\n        self.act = act\n        self.norm = norm\n      
  self.dropout = dropout\n        self.bias = bias\n        self.padding = padding\n        self.num_inter_units = num_inter_units\n        self.inter_channels = inter_channels if inter_channels is not None else []\n        self.inter_dilations = list(inter_dilations or [1] * len(self.inter_channels))\n\n        # The number of channels and strides should match\n        if len(channels) != len(strides):\n            raise ValueError(\"Autoencoder expects matching number of channels and strides\")\n\n        self.encoded_channels = in_channels\n        decode_channel_list = list(channels[-2::-1]) + [out_channels]\n\n        self.encode, self.encoded_channels = self._get_encode_module(self.encoded_channels, channels, strides)\n        self.intermediate, self.encoded_channels = self._get_intermediate_module(self.encoded_channels, num_inter_units)\n        self.decode, _ = self._get_decode_module(self.encoded_channels, decode_channel_list, strides[::-1] or [1])\n\n    def _get_encode_module(\n        self, in_channels: int, channels: Sequence[int], strides: Sequence[int]\n    ) -> tuple[nn.Sequential, int]:\n        \"\"\"\n        Returns the encode part of the network by building up a sequence of layers returned by `_get_encode_layer`.\n        \"\"\"\n        encode = nn.Sequential()\n        layer_channels = in_channels\n\n        for i, (c, s) in enumerate(zip(channels, strides)):\n            layer = self._get_encode_layer(layer_channels, c, s, False)\n            encode.add_module(f\"encode_{i}\", layer)\n            layer_channels = c\n\n        return encode, layer_channels\n\n    def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> tuple[nn.Module, int]:\n        \"\"\"\n        Returns the intermediate block of the network which accepts input from the encoder and whose output goes\n        to the decoder.\n        \"\"\"\n        # Define some types\n        intermediate: nn.Module\n        unit: nn.Module\n\n        intermediate = 
nn.Identity()\n        layer_channels = in_channels\n\n        if self.inter_channels:\n            intermediate = nn.Sequential()\n\n            for i, (dc, di) in enumerate(zip(self.inter_channels, self.inter_dilations)):\n                if self.num_inter_units > 0:\n                    unit = ResidualUnit(\n                        spatial_dims=self.dimensions,\n                        in_channels=layer_channels,\n                        out_channels=dc,\n                        strides=1,\n                        kernel_size=self.kernel_size,\n                        subunits=self.num_inter_units,\n                        act=self.act,\n                        norm=self.norm,\n                        dropout=self.dropout,\n                        dilation=di,\n                        bias=self.bias,\n                        padding=self.padding,\n                    )\n                else:\n                    unit = Convolution(\n                        spatial_dims=self.dimensions,\n                        in_channels=layer_channels,\n                        out_channels=dc,\n                        strides=1,\n                        kernel_size=self.kernel_size,\n                        act=self.act,\n                        norm=self.norm,\n                        dropout=self.dropout,\n                        dilation=di,\n                        bias=self.bias,\n                        padding=self.padding,\n                    )\n\n                intermediate.add_module(f\"inter_{i}\", unit)\n                layer_channels = dc\n\n        return intermediate, layer_channels\n\n    def _get_decode_module(\n        self, in_channels: int, channels: Sequence[int], strides: Sequence[int]\n    ) -> tuple[nn.Sequential, int]:\n        \"\"\"\n        Returns the decode part of the network by building up a sequence of layers returned by `_get_decode_layer`.\n        \"\"\"\n        decode = nn.Sequential()\n        layer_channels = in_channels\n\n        for 
i, (c, s) in enumerate(zip(channels, strides)):\n            layer = self._get_decode_layer(layer_channels, c, s, i == (len(strides) - 1))\n            decode.add_module(f\"decode_{i}\", layer)\n            layer_channels = c\n\n        return decode, layer_channels\n\n    def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Module:\n        \"\"\"\n        Returns a single layer of the encoder part of the network.\n        \"\"\"\n        mod: nn.Module\n        if self.num_res_units > 0:\n            mod = ResidualUnit(\n                spatial_dims=self.dimensions,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                strides=strides,\n                kernel_size=self.kernel_size,\n                subunits=self.num_res_units,\n                act=self.act,\n                norm=self.norm,\n                dropout=self.dropout,\n                bias=self.bias,\n                padding=self.padding,\n                last_conv_only=is_last,\n            )\n            return mod\n        mod = Convolution(\n            spatial_dims=self.dimensions,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            strides=strides,\n            kernel_size=self.kernel_size,\n            act=self.act,\n            norm=self.norm,\n            dropout=self.dropout,\n            bias=self.bias,\n            padding=self.padding,\n            conv_only=is_last,\n        )\n        return mod\n\n    def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, is_last: bool) -> nn.Sequential:\n        \"\"\"\n        Returns a single layer of the decoder part of the network.\n        \"\"\"\n        decode = nn.Sequential()\n\n        conv = Convolution(\n            spatial_dims=self.dimensions,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            strides=strides,\n            
kernel_size=self.up_kernel_size,\n            act=self.act,\n            norm=self.norm,\n            dropout=self.dropout,\n            bias=self.bias,\n            padding=self.padding,\n            conv_only=is_last and self.num_res_units == 0,\n            is_transposed=True,\n        )\n\n        decode.add_module(\"conv\", conv)\n\n        if self.num_res_units > 0:\n            ru = ResidualUnit(\n                spatial_dims=self.dimensions,\n                in_channels=out_channels,\n                out_channels=out_channels,\n                strides=1,\n                kernel_size=self.kernel_size,\n                subunits=1,\n                act=self.act,\n                norm=self.norm,\n                dropout=self.dropout,\n                bias=self.bias,\n                padding=self.padding,\n                last_conv_only=is_last,\n            )\n\n            decode.add_module(\"resunit\", ru)\n\n        return decode\n\n    def forward(self, x: torch.Tensor) -> Any:\n        x = self.encode(x)\n        x = self.intermediate(x)\n        x = self.decode(x)\n        return x\n"
  },
  {
    "path": "monai/networks/nets/autoencoderkl.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample\nfrom monai.utils import ensure_tuple_rep, optional_import\n\nRearrange, _ = optional_import(\"einops.layers.torch\", name=\"Rearrange\")\n\n__all__ = [\"AutoencoderKL\"]\n\n\nclass AsymmetricPad(nn.Module):\n    \"\"\"\n    Pad the input tensor asymmetrically along every spatial dimension.\n\n    Args:\n        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.\n    \"\"\"\n\n    def __init__(self, spatial_dims: int) -> None:\n        super().__init__()\n        self.pad = (0, 1) * spatial_dims\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = nn.functional.pad(x, self.pad, mode=\"constant\", value=0.0)\n        return x\n\n\nclass AEKLDownsample(nn.Module):\n    \"\"\"\n    Convolution-based downsampling layer.\n\n    Args:\n        spatial_dims: number of spatial dimensions (1D, 2D, 3D).\n        in_channels: number of input channels.\n    \"\"\"\n\n    def __init__(self, spatial_dims: int, in_channels: int) -> None:\n        super().__init__()\n        self.pad = AsymmetricPad(spatial_dims=spatial_dims)\n\n        self.conv = Convolution(\n            spatial_dims=spatial_dims,\n            
in_channels=in_channels,\n            out_channels=in_channels,\n            strides=2,\n            kernel_size=3,\n            padding=0,\n            conv_only=True,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.pad(x)\n        x = self.conv(x)\n        return x\n\n\nclass AEKLResBlock(nn.Module):\n    \"\"\"\n    Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a\n    residual connection between input and output.\n\n    Args:\n        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.\n        in_channels: input channels to the layer.\n        norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of\n            channels is divisible by this number.\n        norm_eps: epsilon for the normalisation.\n        out_channels: number of output channels.\n    \"\"\"\n\n    def __init__(\n        self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int\n    ) -> None:\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = in_channels if out_channels is None else out_channels\n\n        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)\n        self.conv1 = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=self.in_channels,\n            out_channels=self.out_channels,\n            strides=1,\n            kernel_size=3,\n            padding=1,\n            conv_only=True,\n        )\n        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True)\n        self.conv2 = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=self.out_channels,\n            out_channels=self.out_channels,\n            strides=1,\n            kernel_size=3,\n            padding=1,\n            
conv_only=True,\n        )\n\n        self.nin_shortcut: nn.Module\n        if self.in_channels != self.out_channels:\n            self.nin_shortcut = Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=self.in_channels,\n                out_channels=self.out_channels,\n                strides=1,\n                kernel_size=1,\n                padding=0,\n                conv_only=True,\n            )\n        else:\n            self.nin_shortcut = nn.Identity()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        h = x\n        h = self.norm1(h)\n        h = F.silu(h)\n        h = self.conv1(h)\n\n        h = self.norm2(h)\n        h = F.silu(h)\n        h = self.conv2(h)\n\n        x = self.nin_shortcut(x)\n\n        return x + h\n\n\nclass Encoder(nn.Module):\n    \"\"\"\n    Convolutional cascade that downsamples the image into a spatial latent space.\n\n    Args:\n        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.\n        in_channels: number of input channels.\n        channels: sequence of block output channels.\n        out_channels: number of channels in the bottom layer (latent space) of the autoencoder.\n        num_res_blocks: number of residual blocks (see _ResBlock) per level.\n        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.\n        norm_eps: epsilon for the normalization.\n        attention_levels: indicate which level from channels contain an attention block.\n        with_nonlocal_attn: if True use non-local attention block.\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        channels: Sequence[int],\n        out_channels: int,\n        num_res_blocks: Sequence[int],\n        norm_num_groups: int,\n        norm_eps: float,\n        attention_levels: Sequence[bool],\n        with_nonlocal_attn: bool = True,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.in_channels = in_channels\n        self.channels = channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.norm_num_groups = norm_num_groups\n        self.norm_eps = norm_eps\n        self.attention_levels = attention_levels\n\n        blocks: list[nn.Module] = []\n        # Initial convolution\n        blocks.append(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=in_channels,\n                out_channels=channels[0],\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n            )\n        )\n\n        # Residual and downsampling blocks\n        output_channel = channels[0]\n        for i in range(len(channels)):\n            input_channel = output_channel\n            output_channel = channels[i]\n            is_final_block = i == len(channels) - 1\n\n            for _ in range(self.num_res_blocks[i]):\n                blocks.append(\n                    AEKLResBlock(\n             
           spatial_dims=spatial_dims,\n                        in_channels=input_channel,\n                        norm_num_groups=norm_num_groups,\n                        norm_eps=norm_eps,\n                        out_channels=output_channel,\n                    )\n                )\n                input_channel = output_channel\n                if attention_levels[i]:\n                    blocks.append(\n                        SpatialAttentionBlock(\n                            spatial_dims=spatial_dims,\n                            num_channels=input_channel,\n                            norm_num_groups=norm_num_groups,\n                            norm_eps=norm_eps,\n                            include_fc=include_fc,\n                            use_combined_linear=use_combined_linear,\n                            use_flash_attention=use_flash_attention,\n                        )\n                    )\n\n            if not is_final_block:\n                blocks.append(AEKLDownsample(spatial_dims=spatial_dims, in_channels=input_channel))\n        # Non-local attention block\n        if with_nonlocal_attn is True:\n            blocks.append(\n                AEKLResBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=channels[-1],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    out_channels=channels[-1],\n                )\n            )\n\n            blocks.append(\n                SpatialAttentionBlock(\n                    spatial_dims=spatial_dims,\n                    num_channels=channels[-1],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    include_fc=include_fc,\n                    use_combined_linear=use_combined_linear,\n                    use_flash_attention=use_flash_attention,\n                )\n            )\n            blocks.append(\n                AEKLResBlock(\n      
              spatial_dims=spatial_dims,\n                    in_channels=channels[-1],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    out_channels=channels[-1],\n                )\n            )\n        # Normalise and convert to latent size\n        blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[-1], eps=norm_eps, affine=True))\n        blocks.append(\n            Convolution(\n                spatial_dims=self.spatial_dims,\n                in_channels=channels[-1],\n                out_channels=out_channels,\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n            )\n        )\n\n        self.blocks = nn.ModuleList(blocks)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for block in self.blocks:\n            x = block(x)\n        return x\n\n\nclass Decoder(nn.Module):\n    \"\"\"\n    Convolutional cascade upsampling from a spatial latent space into an image space.\n\n    Args:\n        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.\n        channels: sequence of block output channels.\n        in_channels: number of channels in the bottom layer (latent space) of the autoencoder.\n        out_channels: number of output channels.\n        num_res_blocks: number of residual blocks (see _ResBlock) per level.\n        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.\n        norm_eps: epsilon for the normalization.\n        attention_levels: indicate which level from channels contain an attention block.\n        with_nonlocal_attn: if True use non-local attention block.\n        use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        channels: Sequence[int],\n        in_channels: int,\n        out_channels: int,\n        num_res_blocks: Sequence[int],\n        norm_num_groups: int,\n        norm_eps: float,\n        attention_levels: Sequence[bool],\n        with_nonlocal_attn: bool = True,\n        use_convtranspose: bool = False,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.channels = channels\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.norm_num_groups = norm_num_groups\n        self.norm_eps = norm_eps\n        self.attention_levels = attention_levels\n\n        reversed_block_out_channels = list(reversed(channels))\n\n        blocks: list[nn.Module] = []\n\n        # Initial convolution\n        blocks.append(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=in_channels,\n                out_channels=reversed_block_out_channels[0],\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n            )\n        )\n\n        # Non-local attention block\n        if with_nonlocal_attn is True:\n            blocks.append(\n                AEKLResBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=reversed_block_out_channels[0],\n                  
  norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    out_channels=reversed_block_out_channels[0],\n                )\n            )\n            blocks.append(\n                SpatialAttentionBlock(\n                    spatial_dims=spatial_dims,\n                    num_channels=reversed_block_out_channels[0],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    include_fc=include_fc,\n                    use_combined_linear=use_combined_linear,\n                    use_flash_attention=use_flash_attention,\n                )\n            )\n            blocks.append(\n                AEKLResBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=reversed_block_out_channels[0],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    out_channels=reversed_block_out_channels[0],\n                )\n            )\n\n        reversed_attention_levels = list(reversed(attention_levels))\n        reversed_num_res_blocks = list(reversed(num_res_blocks))\n        block_out_ch = reversed_block_out_channels[0]\n        for i in range(len(reversed_block_out_channels)):\n            block_in_ch = block_out_ch\n            block_out_ch = reversed_block_out_channels[i]\n            is_final_block = i == len(channels) - 1\n\n            for _ in range(reversed_num_res_blocks[i]):\n                blocks.append(\n                    AEKLResBlock(\n                        spatial_dims=spatial_dims,\n                        in_channels=block_in_ch,\n                        norm_num_groups=norm_num_groups,\n                        norm_eps=norm_eps,\n                        out_channels=block_out_ch,\n                    )\n                )\n                block_in_ch = block_out_ch\n\n                if reversed_attention_levels[i]:\n                    blocks.append(\n                    
    SpatialAttentionBlock(\n                            spatial_dims=spatial_dims,\n                            num_channels=block_in_ch,\n                            norm_num_groups=norm_num_groups,\n                            norm_eps=norm_eps,\n                            include_fc=include_fc,\n                            use_combined_linear=use_combined_linear,\n                            use_flash_attention=use_flash_attention,\n                        )\n                    )\n\n            if not is_final_block:\n                if use_convtranspose:\n                    blocks.append(\n                        Upsample(\n                            spatial_dims=spatial_dims, mode=\"deconv\", in_channels=block_in_ch, out_channels=block_in_ch\n                        )\n                    )\n                else:\n                    post_conv = Convolution(\n                        spatial_dims=spatial_dims,\n                        in_channels=block_in_ch,\n                        out_channels=block_in_ch,\n                        strides=1,\n                        kernel_size=3,\n                        padding=1,\n                        conv_only=True,\n                    )\n                    blocks.append(\n                        Upsample(\n                            spatial_dims=spatial_dims,\n                            mode=\"nontrainable\",\n                            in_channels=block_in_ch,\n                            out_channels=block_in_ch,\n                            interp_mode=\"nearest\",\n                            scale_factor=2.0,\n                            post_conv=post_conv,\n                            align_corners=None,\n                        )\n                    )\n\n        blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))\n        blocks.append(\n            Convolution(\n                spatial_dims=spatial_dims,\n                
in_channels=block_in_ch,\n                out_channels=out_channels,\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n            )\n        )\n\n        self.blocks = nn.ModuleList(blocks)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for block in self.blocks:\n            x = block(x)\n        return x\n\n\nclass AutoencoderKL(nn.Module):\n    \"\"\"\n    Autoencoder model with KL-regularized latent space based on\n    Rombach et al. \"High-Resolution Image Synthesis with Latent Diffusion Models\" https://arxiv.org/abs/2112.10752\n    and Pinaya et al. \"Brain Imaging Generation with Latent Diffusion Models\" https://arxiv.org/abs/2209.07162\n\n    Args:\n        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        num_res_blocks: number of residual blocks (see _ResBlock) per level.\n        channels: number of output channels for each block.\n        attention_levels: sequence of levels to add attention.\n        latent_channels: latent embedding dimension.\n        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.\n        norm_eps: epsilon for the normalization.\n        with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.\n        with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.\n        use_checkpoint: if True, use activation checkpoint to save memory.\n        use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.\n        include_fc: whether to include the final linear layer in the attention block. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int = 1,\n        out_channels: int = 1,\n        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),\n        channels: Sequence[int] = (32, 64, 64, 64),\n        attention_levels: Sequence[bool] = (False, False, True, True),\n        latent_channels: int = 3,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        with_encoder_nonlocal_attn: bool = True,\n        with_decoder_nonlocal_attn: bool = True,\n        use_checkpoint: bool = False,\n        use_convtranspose: bool = False,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n\n        # All number of channels should be multiple of num_groups\n        if any((out_channel % norm_num_groups) != 0 for out_channel in channels):\n            raise ValueError(\"AutoencoderKL expects all channels being multiple of norm_num_groups\")\n\n        if len(channels) != len(attention_levels):\n            raise ValueError(\"AutoencoderKL expects channels being same size of attention_levels\")\n\n        if isinstance(num_res_blocks, int):\n            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))\n\n        if len(num_res_blocks) != len(channels):\n            raise ValueError(\n                \"`num_res_blocks` should be a single integer or a tuple of integers with the same length as \"\n                \"`channels`.\"\n            )\n\n        self.encoder: nn.Module = Encoder(\n            
spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            channels=channels,\n            out_channels=latent_channels,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            attention_levels=attention_levels,\n            with_nonlocal_attn=with_encoder_nonlocal_attn,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n        self.decoder: nn.Module = Decoder(\n            spatial_dims=spatial_dims,\n            channels=channels,\n            in_channels=latent_channels,\n            out_channels=out_channels,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            attention_levels=attention_levels,\n            with_nonlocal_attn=with_decoder_nonlocal_attn,\n            use_convtranspose=use_convtranspose,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n        self.quant_conv_mu = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=latent_channels,\n            out_channels=latent_channels,\n            strides=1,\n            kernel_size=1,\n            padding=0,\n            conv_only=True,\n        )\n        self.quant_conv_log_sigma = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=latent_channels,\n            out_channels=latent_channels,\n            strides=1,\n            kernel_size=1,\n            padding=0,\n            conv_only=True,\n        )\n        self.post_quant_conv = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=latent_channels,\n            out_channels=latent_channels,\n            strides=1,\n            kernel_size=1,\n            padding=0,\n            
conv_only=True,\n        )\n        self.latent_channels = latent_channels\n        self.use_checkpoint = use_checkpoint\n\n    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations.\n\n        Args:\n            x: BxCx[SPATIAL DIMS] tensor\n\n        \"\"\"\n        if self.use_checkpoint:\n            h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False)\n        else:\n            h = self.encoder(x)\n\n        z_mu = self.quant_conv_mu(h)\n        z_log_var = self.quant_conv_log_sigma(h)\n        z_log_var = torch.clamp(z_log_var, -30.0, 20.0)\n        z_sigma = torch.exp(z_log_var / 2)\n\n        return z_mu, z_sigma\n\n    def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        From the mean and sigma representations resulting of encoding an image through the latent space,\n        obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and\n        adding the mean.\n\n        Args:\n            z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image\n            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image\n\n        Returns:\n            sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]\n        \"\"\"\n        eps = torch.randn_like(z_sigma)\n        z_vae = z_mu + eps * z_sigma\n        return z_vae\n\n    def reconstruct(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Encodes and decodes an input image.\n\n        Args:\n            x: BxCx[SPATIAL DIMENSIONS] tensor.\n\n        Returns:\n            reconstructed image, of the same shape as input\n        \"\"\"\n        z_mu, _ = self.encode(x)\n        reconstruction = self.decode(z_mu)\n        return reconstruction\n\n    
def decode(self, z: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Based on a latent space sample, forwards it through the Decoder.\n\n        Args:\n            z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE]\n\n        Returns:\n            decoded image tensor\n        \"\"\"\n        z = self.post_quant_conv(z)\n        dec: torch.Tensor\n        if self.use_checkpoint:\n            dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False)\n        else:\n            dec = self.decoder(z)\n        return dec\n\n    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        z_mu, z_sigma = self.encode(x)\n        z = self.sampling(z_mu, z_sigma)\n        reconstruction = self.decode(z)\n        return reconstruction, z_mu, z_sigma\n\n    def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:\n        z_mu, z_sigma = self.encode(x)\n        z = self.sampling(z_mu, z_sigma)\n        return z\n\n    def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor:\n        image = self.decode(z)\n        return image\n\n    def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:\n        \"\"\"\n        Load a state dict from an AutoencoderKL trained with [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).\n\n        Args:\n            old_state_dict: state dict from the old AutoencoderKL model.\n        \"\"\"\n\n        new_state_dict = self.state_dict()\n        # if all keys match, just load the state dict\n        if all(k in new_state_dict for k in old_state_dict):\n            print(\"All keys match, loading state dict.\")\n            self.load_state_dict(old_state_dict)\n            return\n\n        if verbose:\n            # print all new_state_dict keys that are not in old_state_dict\n            for k in new_state_dict:\n                if k not in old_state_dict:\n                    print(f\"key {k} not found in old state dict\")\n            # 
and vice versa\n            print(\"----------------------------------------------\")\n            for k in old_state_dict:\n                if k not in new_state_dict:\n                    print(f\"key {k} not found in new state dict\")\n\n        # copy over all matching keys\n        for k in new_state_dict:\n            if k in old_state_dict:\n                new_state_dict[k] = old_state_dict.pop(k)\n\n        # fix the attention blocks\n        attention_blocks = [k.replace(\".attn.to_q.weight\", \"\") for k in new_state_dict if \"attn.to_q.weight\" in k]\n        for block in attention_blocks:\n            new_state_dict[f\"{block}.attn.to_q.weight\"] = old_state_dict.pop(f\"{block}.to_q.weight\")\n            new_state_dict[f\"{block}.attn.to_k.weight\"] = old_state_dict.pop(f\"{block}.to_k.weight\")\n            new_state_dict[f\"{block}.attn.to_v.weight\"] = old_state_dict.pop(f\"{block}.to_v.weight\")\n            new_state_dict[f\"{block}.attn.to_q.bias\"] = old_state_dict.pop(f\"{block}.to_q.bias\")\n            new_state_dict[f\"{block}.attn.to_k.bias\"] = old_state_dict.pop(f\"{block}.to_k.bias\")\n            new_state_dict[f\"{block}.attn.to_v.bias\"] = old_state_dict.pop(f\"{block}.to_v.bias\")\n\n            # old version did not have a projection so set these to the identity\n            new_state_dict[f\"{block}.attn.out_proj.weight\"] = torch.eye(\n                new_state_dict[f\"{block}.attn.out_proj.weight\"].shape[0]\n            )\n            new_state_dict[f\"{block}.attn.out_proj.bias\"] = torch.zeros(\n                new_state_dict[f\"{block}.attn.out_proj.bias\"].shape\n            )\n\n        # fix the upsample conv blocks which were renamed postconv\n        for k in new_state_dict:\n            if \"postconv\" in k:\n                old_name = k.replace(\"postconv\", \"conv\")\n                new_state_dict[k] = old_state_dict.pop(old_name)\n        if verbose:\n            # print all remaining keys in old_state_dict\n       
     print(\"remaining keys in old_state_dict:\", old_state_dict.keys())\n        self.load_state_dict(new_state_dict, strict=True)\n"
  },
  {
    "path": "monai/networks/nets/basic_unet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import Convolution, UpSample\nfrom monai.networks.layers.factories import Conv, Pool\nfrom monai.utils import ensure_tuple_rep\n\n__all__ = [\"BasicUnet\", \"Basicunet\", \"basicunet\", \"BasicUNet\"]\n\n\nclass TwoConv(nn.Sequential):\n    \"\"\"two convolutions.\"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_chns: int,\n        out_chns: int,\n        act: str | tuple,\n        norm: str | tuple,\n        bias: bool,\n        dropout: float | tuple = 0.0,\n    ):\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions.\n            in_chns: number of input channels.\n            out_chns: number of output channels.\n            act: activation type and arguments.\n            norm: feature normalization type and arguments.\n            bias: whether to have a bias term in convolution blocks.\n            dropout: dropout ratio. 
Defaults to no dropout.\n\n        \"\"\"\n        super().__init__()\n\n        conv_0 = Convolution(spatial_dims, in_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1)\n        conv_1 = Convolution(\n            spatial_dims, out_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1\n        )\n        self.add_module(\"conv_0\", conv_0)\n        self.add_module(\"conv_1\", conv_1)\n\n\nclass Down(nn.Sequential):\n    \"\"\"maxpooling downsampling and two convolutions.\"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_chns: int,\n        out_chns: int,\n        act: str | tuple,\n        norm: str | tuple,\n        bias: bool,\n        dropout: float | tuple = 0.0,\n    ):\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions.\n            in_chns: number of input channels.\n            out_chns: number of output channels.\n            act: activation type and arguments.\n            norm: feature normalization type and arguments.\n            bias: whether to have a bias term in convolution blocks.\n            dropout: dropout ratio. 
Defaults to no dropout.\n\n        \"\"\"\n        super().__init__()\n        max_pooling = Pool[\"MAX\", spatial_dims](kernel_size=2)\n        convs = TwoConv(spatial_dims, in_chns, out_chns, act, norm, bias, dropout)\n        self.add_module(\"max_pooling\", max_pooling)\n        self.add_module(\"convs\", convs)\n\n\nclass UpCat(nn.Module):\n    \"\"\"upsampling, concatenation with the encoder feature map, two convolutions\"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_chns: int,\n        cat_chns: int,\n        out_chns: int,\n        act: str | tuple,\n        norm: str | tuple,\n        bias: bool,\n        dropout: float | tuple = 0.0,\n        upsample: str = \"deconv\",\n        pre_conv: nn.Module | str | None = \"default\",\n        interp_mode: str = \"linear\",\n        align_corners: bool | None = True,\n        halves: bool = True,\n        is_pad: bool = True,\n    ):\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions.\n            in_chns: number of input channels to be upsampled.\n            cat_chns: number of channels from the encoder.\n            out_chns: number of output channels.\n            act: activation type and arguments.\n            norm: feature normalization type and arguments.\n            bias: whether to have a bias term in convolution blocks.\n            dropout: dropout ratio. Defaults to no dropout.\n            upsample: upsampling mode, available options are\n                ``\"deconv\"``, ``\"pixelshuffle\"``, ``\"nontrainable\"``.\n            pre_conv: a conv block applied before upsampling.\n                Only used in the \"nontrainable\" or \"pixelshuffle\" mode.\n            interp_mode: {``\"nearest\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``}\n                Only used in the \"nontrainable\" mode.\n            align_corners: set the align_corners parameter for upsample. 
Defaults to True.\n                Only used in the \"nontrainable\" mode.\n            halves: whether to halve the number of channels during upsampling.\n                This parameter does not work on ``nontrainable`` mode if ``pre_conv`` is `None`.\n            is_pad: whether to pad upsampling features to fit features from encoder. Defaults to True.\n\n        \"\"\"\n        super().__init__()\n        if upsample == \"nontrainable\" and pre_conv is None:\n            up_chns = in_chns\n        else:\n            up_chns = in_chns // 2 if halves else in_chns\n        self.upsample = UpSample(\n            spatial_dims,\n            in_chns,\n            up_chns,\n            2,\n            mode=upsample,\n            pre_conv=pre_conv,\n            interp_mode=interp_mode,\n            align_corners=align_corners,\n        )\n        self.convs = TwoConv(spatial_dims, cat_chns + up_chns, out_chns, act, norm, bias, dropout)\n        self.is_pad = is_pad\n\n    def forward(self, x: torch.Tensor, x_e: torch.Tensor | None):\n        \"\"\"\n\n        Args:\n            x: features to be upsampled.\n            x_e: optional features from the encoder, if None, this branch is not in use.\n        \"\"\"\n        x_0 = self.upsample(x)\n\n        if x_e is not None and torch.jit.isinstance(x_e, torch.Tensor):\n            if self.is_pad:\n                # handling spatial shapes due to the 2x maxpooling with odd edge lengths.\n                dimensions = len(x.shape) - 2\n                sp = [0] * (dimensions * 2)\n                for i in range(dimensions):\n                    if x_e.shape[-i - 1] != x_0.shape[-i - 1]:\n                        sp[i * 2 + 1] = 1\n                x_0 = torch.nn.functional.pad(x_0, sp, \"replicate\")\n            x = self.convs(torch.cat([x_e, x_0], dim=1))  # input channels: (cat_chns + up_chns)\n        else:\n            x = self.convs(x_0)\n\n        return x\n\n\nclass BasicUNet(nn.Module):\n\n    def __init__(\n        
self,\n        spatial_dims: int = 3,\n        in_channels: int = 1,\n        out_channels: int = 2,\n        features: Sequence[int] = (32, 32, 64, 128, 256, 32),\n        act: str | tuple = (\"LeakyReLU\", {\"negative_slope\": 0.1, \"inplace\": True}),\n        norm: str | tuple = (\"instance\", {\"affine\": True}),\n        bias: bool = True,\n        dropout: float | tuple = 0.0,\n        upsample: str = \"deconv\",\n    ):\n        \"\"\"\n        A UNet implementation with 1D/2D/3D supports.\n\n        Based on:\n\n            Falk et al. \"U-Net – Deep Learning for Cell Counting, Detection, and\n            Morphometry\". Nature Methods 16, 67–70 (2019), DOI:\n            http://dx.doi.org/10.1038/s41592-018-0261-2\n\n        Args:\n            spatial_dims: number of spatial dimensions. Defaults to 3 for spatial 3D inputs.\n            in_channels: number of input channels. Defaults to 1.\n            out_channels: number of output channels. Defaults to 2.\n            features: six integers as numbers of features.\n                Defaults to ``(32, 32, 64, 128, 256, 32)``,\n\n                - the first five values correspond to the five-level encoder feature sizes.\n                - the last value corresponds to the feature size after the last upsampling.\n\n            act: activation type and arguments. Defaults to LeakyReLU.\n            norm: feature normalization type and arguments. Defaults to instance norm.\n            bias: whether to have a bias term in convolution blocks. Defaults to True.\n                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,\n                if a conv layer is directly followed by a batch norm layer, bias should be False.\n            dropout: dropout ratio. 
Defaults to no dropout.\n            upsample: upsampling mode, available options are\n                ``\"deconv\"``, ``\"pixelshuffle\"``, ``\"nontrainable\"``.\n\n        Examples::\n\n            # for spatial 2D\n            >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128))\n\n            # for spatial 2D, with group norm\n            >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), norm=(\"group\", {\"num_groups\": 4}))\n\n            # for spatial 3D\n            >>> net = BasicUNet(spatial_dims=3, features=(32, 32, 64, 128, 256, 32))\n\n        See Also\n\n            - :py:class:`monai.networks.nets.DynUNet`\n            - :py:class:`monai.networks.nets.UNet`\n\n        \"\"\"\n        super().__init__()\n        fea = ensure_tuple_rep(features, 6)\n        print(f\"BasicUNet features: {fea}.\")\n\n        self.conv_0 = TwoConv(spatial_dims, in_channels, features[0], act, norm, bias, dropout)\n        self.down_1 = Down(spatial_dims, fea[0], fea[1], act, norm, bias, dropout)\n        self.down_2 = Down(spatial_dims, fea[1], fea[2], act, norm, bias, dropout)\n        self.down_3 = Down(spatial_dims, fea[2], fea[3], act, norm, bias, dropout)\n        self.down_4 = Down(spatial_dims, fea[3], fea[4], act, norm, bias, dropout)\n\n        self.upcat_4 = UpCat(spatial_dims, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample)\n        self.upcat_3 = UpCat(spatial_dims, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample)\n        self.upcat_2 = UpCat(spatial_dims, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample)\n        self.upcat_1 = UpCat(spatial_dims, fea[1], fea[0], fea[5], act, norm, bias, dropout, upsample, halves=False)\n\n        self.final_conv = Conv[\"conv\", spatial_dims](fea[5], out_channels, kernel_size=1)\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"\n        Args:\n            x: input should have spatially N dimensions\n                ``(Batch, 
in_channels, dim_0[, dim_1, ..., dim_N-1])``, N is defined by `spatial_dims`.\n                It is recommended to have ``dim_n % 16 == 0`` to ensure all maxpooling inputs have\n                even edge lengths.\n\n        Returns:\n            A torch Tensor of \"raw\" predictions in shape\n            ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N-1])``.\n        \"\"\"\n        x0 = self.conv_0(x)\n\n        x1 = self.down_1(x0)\n        x2 = self.down_2(x1)\n        x3 = self.down_3(x2)\n        x4 = self.down_4(x3)\n\n        u4 = self.upcat_4(x4, x3)\n        u3 = self.upcat_3(u4, x2)\n        u2 = self.upcat_2(u3, x1)\n        u1 = self.upcat_1(u2, x0)\n\n        logits = self.final_conv(u1)\n        return logits\n\n\nBasicUnet = Basicunet = basicunet = BasicUNet\n"
  },
  {
    "path": "monai/networks/nets/basic_unetplusplus.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.layers.factories import Conv\nfrom monai.networks.nets.basic_unet import Down, TwoConv, UpCat\nfrom monai.utils import ensure_tuple_rep\n\n__all__ = [\"BasicUnetPlusPlus\", \"BasicunetPlusPlus\", \"basicunetplusplus\", \"BasicUNetPlusPlus\"]\n\n\nclass BasicUNetPlusPlus(nn.Module):\n\n    def __init__(\n        self,\n        spatial_dims: int = 3,\n        in_channels: int = 1,\n        out_channels: int = 2,\n        features: Sequence[int] = (32, 32, 64, 128, 256, 32),\n        deep_supervision: bool = False,\n        act: str | tuple = (\"LeakyReLU\", {\"negative_slope\": 0.1, \"inplace\": True}),\n        norm: str | tuple = (\"instance\", {\"affine\": True}),\n        bias: bool = True,\n        dropout: float | tuple = 0.0,\n        upsample: str = \"deconv\",\n    ):\n        \"\"\"\n        A UNet++ implementation with 1D/2D/3D supports.\n\n        Based on:\n\n            Zhou et al. \"UNet++: A Nested U-Net Architecture for Medical Image\n            Segmentation\". 4th Deep Learning in Medical Image Analysis (DLMIA)\n            Workshop, DOI: https://doi.org/10.48550/arXiv.1807.10165\n\n\n        Args:\n            spatial_dims: number of spatial dimensions. 
Defaults to 3 for spatial 3D inputs.\n            in_channels: number of input channels. Defaults to 1.\n            out_channels: number of output channels. Defaults to 2.\n            features: six integers as numbers of features.\n                Defaults to ``(32, 32, 64, 128, 256, 32)``,\n\n                - the first five values correspond to the five-level encoder feature sizes.\n                - the last value corresponds to the feature size after the last upsampling.\n\n            deep_supervision: whether to prune the network at inference time. Defaults to False. If true, returns a list,\n                whose elements correspond to outputs at different nodes.\n            act: activation type and arguments. Defaults to LeakyReLU.\n            norm: feature normalization type and arguments. Defaults to instance norm.\n            bias: whether to have a bias term in convolution blocks. Defaults to True.\n                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,\n                if a conv layer is directly followed by a batch norm layer, bias should be False.\n            dropout: dropout ratio. 
Defaults to no dropout.\n            upsample: upsampling mode, available options are\n                ``\"deconv\"``, ``\"pixelshuffle\"``, ``\"nontrainable\"``.\n\n        Examples::\n\n            # for spatial 2D\n            >>> net = BasicUNetPlusPlus(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128))\n\n            # for spatial 2D, with deep supervision enabled\n            >>> net = BasicUNetPlusPlus(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), deep_supervision=True)\n\n            # for spatial 2D, with group norm\n            >>> net = BasicUNetPlusPlus(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), norm=(\"group\", {\"num_groups\": 4}))\n\n            # for spatial 3D\n            >>> net = BasicUNetPlusPlus(spatial_dims=3, features=(32, 32, 64, 128, 256, 32))\n\n        See Also\n            - :py:class:`monai.networks.nets.BasicUNet`\n            - :py:class:`monai.networks.nets.DynUNet`\n            - :py:class:`monai.networks.nets.UNet`\n\n        \"\"\"\n        super().__init__()\n\n        self.deep_supervision = deep_supervision\n\n        fea = ensure_tuple_rep(features, 6)\n        print(f\"BasicUNetPlusPlus features: {fea}.\")\n\n        self.conv_0_0 = TwoConv(spatial_dims, in_channels, fea[0], act, norm, bias, dropout)\n        self.conv_1_0 = Down(spatial_dims, fea[0], fea[1], act, norm, bias, dropout)\n        self.conv_2_0 = Down(spatial_dims, fea[1], fea[2], act, norm, bias, dropout)\n        self.conv_3_0 = Down(spatial_dims, fea[2], fea[3], act, norm, bias, dropout)\n        self.conv_4_0 = Down(spatial_dims, fea[3], fea[4], act, norm, bias, dropout)\n\n        self.upcat_0_1 = UpCat(spatial_dims, fea[1], fea[0], fea[0], act, norm, bias, dropout, upsample, halves=False)\n        self.upcat_1_1 = UpCat(spatial_dims, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample)\n        self.upcat_2_1 = UpCat(spatial_dims, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample)\n        self.upcat_3_1 = 
UpCat(spatial_dims, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample)\n\n        self.upcat_0_2 = UpCat(\n            spatial_dims, fea[1], fea[0] * 2, fea[0], act, norm, bias, dropout, upsample, halves=False\n        )\n        self.upcat_1_2 = UpCat(spatial_dims, fea[2], fea[1] * 2, fea[1], act, norm, bias, dropout, upsample)\n        self.upcat_2_2 = UpCat(spatial_dims, fea[3], fea[2] * 2, fea[2], act, norm, bias, dropout, upsample)\n\n        self.upcat_0_3 = UpCat(\n            spatial_dims, fea[1], fea[0] * 3, fea[0], act, norm, bias, dropout, upsample, halves=False\n        )\n        self.upcat_1_3 = UpCat(spatial_dims, fea[2], fea[1] * 3, fea[1], act, norm, bias, dropout, upsample)\n\n        self.upcat_0_4 = UpCat(\n            spatial_dims, fea[1], fea[0] * 4, fea[5], act, norm, bias, dropout, upsample, halves=False\n        )\n\n        self.final_conv_0_1 = Conv[\"conv\", spatial_dims](fea[0], out_channels, kernel_size=1)\n        self.final_conv_0_2 = Conv[\"conv\", spatial_dims](fea[0], out_channels, kernel_size=1)\n        self.final_conv_0_3 = Conv[\"conv\", spatial_dims](fea[0], out_channels, kernel_size=1)\n        self.final_conv_0_4 = Conv[\"conv\", spatial_dims](fea[5], out_channels, kernel_size=1)\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"\n        Args:\n            x: input should have spatially N dimensions\n                ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N-1])``, N is defined by `dimensions`.\n                It is recommended to have ``dim_n % 16 == 0`` to ensure all maxpooling inputs have\n                even edge lengths.\n\n        Returns:\n            A torch Tensor of \"raw\" predictions in shape\n            ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N-1])``.\n        \"\"\"\n        x_0_0 = self.conv_0_0(x)\n        x_1_0 = self.conv_1_0(x_0_0)\n        x_0_1 = self.upcat_0_1(x_1_0, x_0_0)\n\n        x_2_0 = self.conv_2_0(x_1_0)\n        x_1_1 = self.upcat_1_1(x_2_0, x_1_0)\n        
x_0_2 = self.upcat_0_2(x_1_1, torch.cat([x_0_0, x_0_1], dim=1))\n\n        x_3_0 = self.conv_3_0(x_2_0)\n        x_2_1 = self.upcat_2_1(x_3_0, x_2_0)\n        x_1_2 = self.upcat_1_2(x_2_1, torch.cat([x_1_0, x_1_1], dim=1))\n        x_0_3 = self.upcat_0_3(x_1_2, torch.cat([x_0_0, x_0_1, x_0_2], dim=1))\n\n        x_4_0 = self.conv_4_0(x_3_0)\n        x_3_1 = self.upcat_3_1(x_4_0, x_3_0)\n        x_2_2 = self.upcat_2_2(x_3_1, torch.cat([x_2_0, x_2_1], dim=1))\n        x_1_3 = self.upcat_1_3(x_2_2, torch.cat([x_1_0, x_1_1, x_1_2], dim=1))\n        x_0_4 = self.upcat_0_4(x_1_3, torch.cat([x_0_0, x_0_1, x_0_2, x_0_3], dim=1))\n\n        output_0_1 = self.final_conv_0_1(x_0_1)\n        output_0_2 = self.final_conv_0_2(x_0_2)\n        output_0_3 = self.final_conv_0_3(x_0_3)\n        output_0_4 = self.final_conv_0_4(x_0_4)\n\n        if self.deep_supervision:\n            output = [output_0_1, output_0_2, output_0_3, output_0_4]\n        else:\n            output = [output_0_4]\n\n        return output\n\n\nBasicUnetPlusPlus = BasicunetPlusPlus = basicunetplusplus = BasicUNetPlusPlus\n"
  },
  {
    "path": "monai/networks/nets/cell_sam_wrapper.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom monai.utils import optional_import\n\nbuild_sam_vit_b, has_sam = optional_import(\"segment_anything.build_sam\", name=\"build_sam_vit_b\")\n\n_all__ = [\"CellSamWrapper\"]\n\n\nclass CellSamWrapper(torch.nn.Module):\n    \"\"\"\n    CellSamWrapper is thin wrapper around SAM model https://github.com/facebookresearch/segment-anything\n    with an image only decoder, that can be used for segmentation tasks.\n\n\n    Args:\n        auto_resize_inputs: whether to resize inputs before passing to the network.\n            (usually they need be resized, unless they are already at the expected size)\n        network_resize_roi: expected input size for the network.\n            (currently SAM expects 1024x1024)\n        checkpoint: checkpoint file to load the SAM weights from.\n            (this can be downloaded from SAM repo https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)\n        return_features: whether to return features from SAM encoder\n            (without using decoder/upsampling to the original input size)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        auto_resize_inputs=True,\n        network_resize_roi=(1024, 1024),\n        checkpoint=\"sam_vit_b_01ec64.pth\",\n        return_features=False,\n        *args,\n        
**kwargs,\n    ) -> None:\n        super().__init__(*args, **kwargs)\n\n        self.network_resize_roi = network_resize_roi\n        self.auto_resize_inputs = auto_resize_inputs\n        self.return_features = return_features\n\n        if not has_sam:\n            raise ValueError(\n                \"SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git\"\n            )\n\n        model = build_sam_vit_b(checkpoint=checkpoint)\n\n        model.prompt_encoder = None\n        model.mask_decoder = None\n\n        model.mask_decoder = nn.Sequential(\n            nn.BatchNorm2d(num_features=256),\n            nn.ReLU(inplace=True),\n            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),\n            nn.BatchNorm2d(num_features=128),\n            nn.ReLU(inplace=True),\n            nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),\n        )\n\n        self.model = model\n\n    def forward(self, x):\n        sh = x.shape[2:]\n\n        if self.auto_resize_inputs:\n            x = F.interpolate(x, size=self.network_resize_roi, mode=\"bilinear\")\n\n        x = self.model.image_encoder(x)\n\n        if not self.return_features:\n            x = self.model.mask_decoder(x)\n            if self.auto_resize_inputs:\n                x = F.interpolate(x, size=sh, mode=\"bilinear\")\n\n        return x\n"
  },
  {
    "path": "monai/networks/nets/classifier.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.layers.factories import Act, Norm, split_args\nfrom monai.networks.nets.regressor import Regressor\n\n__all__ = [\"Classifier\", \"Discriminator\", \"Critic\"]\n\n\nclass Classifier(Regressor):\n    \"\"\"\n    Defines a classification network from Regressor by specifying the output shape as a single dimensional tensor\n    with size equal to the number of classes to predict. 
The final activation function can also be specified, eg.\n    softmax or sigmoid.\n\n    Args:\n        in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)\n        classes: integer stating the dimension of the final output tensor\n        channels: tuple of integers stating the output channels of each convolutional layer\n        strides: tuple of integers stating the stride (downscale factor) of each convolutional layer\n        kernel_size: integer or tuple of integers stating size of convolutional kernels\n        num_res_units: integer stating number of convolutions in residual units, 0 means no residual units\n        act: name or type defining activation layers\n        norm: name or type defining normalization layers\n        dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout\n        bias: boolean stating if convolution layers should have a bias component\n        last_act: name defining the last activation layer\n    \"\"\"\n\n    def __init__(\n        self,\n        in_shape: Sequence[int],\n        classes: int,\n        channels: Sequence[int],\n        strides: Sequence[int],\n        kernel_size: Sequence[int] | int = 3,\n        num_res_units: int = 2,\n        act=Act.PRELU,\n        norm=Norm.INSTANCE,\n        dropout: float | None = None,\n        bias: bool = True,\n        last_act: str | None = None,\n    ) -> None:\n        super().__init__(in_shape, (classes,), channels, strides, kernel_size, num_res_units, act, norm, dropout, bias)\n\n        if last_act is not None:\n            last_act_name, last_act_args = split_args(last_act)\n            last_act_type = Act[last_act_name]\n\n            self.final.add_module(\"lastact\", last_act_type(**last_act_args))\n\n\nclass Discriminator(Classifier):\n    \"\"\"\n    Defines a discriminator network from Classifier with a single output value and sigmoid activation by default. 
This\n    is meant for use with GANs or other applications requiring a generic discriminator network.\n\n    Args:\n        in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)\n        channels: tuple of integers stating the output channels of each convolutional layer\n        strides: tuple of integers stating the stride (downscale factor) of each convolutional layer\n        kernel_size: integer or tuple of integers stating size of convolutional kernels\n        num_res_units: integer stating number of convolutions in residual units, 0 means no residual units\n        act: name or type defining activation layers\n        norm: name or type defining normalization layers\n        dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout\n        bias: boolean stating if convolution layers should have a bias component\n        last_act: name defining the last activation layer\n    \"\"\"\n\n    def __init__(\n        self,\n        in_shape: Sequence[int],\n        channels: Sequence[int],\n        strides: Sequence[int],\n        kernel_size: Sequence[int] | int = 3,\n        num_res_units: int = 2,\n        act=Act.PRELU,\n        norm=Norm.INSTANCE,\n        dropout: float | None = 0.25,\n        bias: bool = True,\n        last_act=Act.SIGMOID,\n    ) -> None:\n        super().__init__(in_shape, 1, channels, strides, kernel_size, num_res_units, act, norm, dropout, bias, last_act)\n\n\nclass Critic(Classifier):\n    \"\"\"\n    Defines a critic network from Classifier with a single output value and no final activation. The final layer is\n    `nn.Flatten` instead of `nn.Linear`, the final result is computed as the mean over the first dimension. 
This is\n    meant to be used with Wasserstein GANs.\n\n    Args:\n        in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)\n        channels: tuple of integers stating the output channels of each convolutional layer\n        strides: tuple of integers stating the stride (downscale factor) of each convolutional layer\n        kernel_size: integer or tuple of integers stating size of convolutional kernels\n        num_res_units: integer stating number of convolutions in residual units, 0 means no residual units\n        act: name or type defining activation layers\n        norm: name or type defining normalization layers\n        dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout\n        bias: boolean stating if convolution layers should have a bias component\n    \"\"\"\n\n    def __init__(\n        self,\n        in_shape: Sequence[int],\n        channels: Sequence[int],\n        strides: Sequence[int],\n        kernel_size: Sequence[int] | int = 3,\n        num_res_units: int = 2,\n        act=Act.PRELU,\n        norm=Norm.INSTANCE,\n        dropout: float | None = 0.25,\n        bias: bool = True,\n    ) -> None:\n        super().__init__(in_shape, 1, channels, strides, kernel_size, num_res_units, act, norm, dropout, bias, None)\n\n    def _get_final_layer(self, in_shape: Sequence[int]):\n        return nn.Flatten()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.net(x)\n        x = self.final(x)\n        x = x.mean(1)\n        return x.view((x.shape[0], -1))\n"
  },
  {
    "path": "monai/networks/nets/controlnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# =========================================================================\n# Adapted from https://github.com/huggingface/diffusers\n# which has the following license:\n# https://github.com/huggingface/diffusers/blob/main/LICENSE\n#\n# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =========================================================================\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nfrom torch import nn\n\nfrom monai.networks.blocks import Convolution\nfrom monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding\nfrom monai.utils import ensure_tuple_rep\n\n\nclass ControlNetConditioningEmbedding(nn.Module):\n    \"\"\"\n    Network to encode the conditioning into a latent 
space.\n    \"\"\"\n\n    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int]):\n        super().__init__()\n\n        self.conv_in = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=channels[0],\n            strides=1,\n            kernel_size=3,\n            padding=1,\n            adn_ordering=\"A\",\n            act=\"SWISH\",\n        )\n\n        self.blocks = nn.ModuleList([])\n\n        for i in range(len(channels) - 1):\n            channel_in = channels[i]\n            channel_out = channels[i + 1]\n            self.blocks.append(\n                Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=channel_in,\n                    out_channels=channel_in,\n                    strides=1,\n                    kernel_size=3,\n                    padding=1,\n                    adn_ordering=\"A\",\n                    act=\"SWISH\",\n                )\n            )\n\n            self.blocks.append(\n                Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=channel_in,\n                    out_channels=channel_out,\n                    strides=2,\n                    kernel_size=3,\n                    padding=1,\n                    adn_ordering=\"A\",\n                    act=\"SWISH\",\n                )\n            )\n\n        self.conv_out = zero_module(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=channels[-1],\n                out_channels=out_channels,\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n            )\n        )\n\n    def forward(self, conditioning):\n        embedding = self.conv_in(conditioning)\n\n        for block in self.blocks:\n            embedding = block(embedding)\n\n        embedding = 
self.conv_out(embedding)\n\n        return embedding\n\n\ndef zero_module(module):\n    for p in module.parameters():\n        nn.init.zeros_(p)\n    return module\n\n\nclass ControlNet(nn.Module):\n    \"\"\"\n    Control network for diffusion models based on Zhang and Agrawala \"Adding Conditional Control to Text-to-Image\n    Diffusion Models\" (https://arxiv.org/abs/2302.05543)\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        num_res_blocks: number of residual blocks (see ResnetBlock) per level.\n        channels: tuple of block output channels.\n        attention_levels: list of levels to add attention.\n        norm_num_groups: number of groups for the normalization.\n        norm_eps: epsilon for the normalization.\n        resblock_updown: if True use residual blocks for up/downsampling.\n        num_head_channels: number of channels in each attention head.\n        with_conditioning: if True add spatial transformers to perform conditioning.\n        transformer_num_layers: number of layers of Transformer blocks to use.\n        cross_attention_dim: number of context dimensions to use.\n        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`\n            classes.\n        upcast_attention: if True, upcast attention operations to full precision.\n        conditioning_embedding_in_channels: number of input channels for the conditioning embedding.\n        conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to True.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),\n        channels: Sequence[int] = (32, 64, 64, 64),\n        attention_levels: Sequence[bool] = (False, False, True, True),\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        resblock_updown: bool = False,\n        num_head_channels: int | Sequence[int] = 8,\n        with_conditioning: bool = False,\n        transformer_num_layers: int = 1,\n        cross_attention_dim: int | None = None,\n        num_class_embeds: int | None = None,\n        upcast_attention: bool = False,\n        conditioning_embedding_in_channels: int = 1,\n        conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        if with_conditioning is True and cross_attention_dim is None:\n            raise ValueError(\n                \"ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) \"\n                \"to be specified when with_conditioning=True.\"\n            )\n        if cross_attention_dim is not None and with_conditioning is False:\n            raise ValueError(\"ControlNet expects with_conditioning=True when specifying the cross_attention_dim.\")\n\n        # All number of channels should be multiple of num_groups\n        if any((out_channel % norm_num_groups) != 0 for out_channel in channels):\n            raise ValueError(\n    
            f\"ControlNet expects all channels to be a multiple of norm_num_groups, but got\"\n                f\" channels={channels} and norm_num_groups={norm_num_groups}\"\n            )\n\n        if len(channels) != len(attention_levels):\n            raise ValueError(\n                f\"ControlNet expects channels to have the same length as attention_levels, but got \"\n                f\"channels={channels} and attention_levels={attention_levels}\"\n            )\n\n        if isinstance(num_head_channels, int):\n            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))\n\n        if len(num_head_channels) != len(attention_levels):\n            raise ValueError(\n                f\"num_head_channels should have the same length as attention_levels, but got channels={channels} and \"\n                f\"attention_levels={attention_levels} . For the i levels without attention,\"\n                \" i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored.\"\n            )\n\n        if isinstance(num_res_blocks, int):\n            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))\n\n        if len(num_res_blocks) != len(channels):\n            raise ValueError(\n                f\"`num_res_blocks` should be a single integer or a tuple of integers with the same length as \"\n                f\"`num_channels`, but got num_res_blocks={num_res_blocks} and channels={channels}.\"\n            )\n\n        self.in_channels = in_channels\n        self.block_out_channels = channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_levels = attention_levels\n        self.num_head_channels = num_head_channels\n        self.with_conditioning = with_conditioning\n\n        # input\n        self.conv_in = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=channels[0],\n            strides=1,\n            kernel_size=3,\n      
      padding=1,\n            conv_only=True,\n        )\n\n        # time\n        time_embed_dim = channels[0] * 4\n        self.time_embed = nn.Sequential(\n            nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)\n        )\n\n        # class embedding\n        self.num_class_embeds = num_class_embeds\n        if num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n\n        # control net conditioning embedding\n        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(\n            spatial_dims=spatial_dims,\n            in_channels=conditioning_embedding_in_channels,\n            channels=conditioning_embedding_num_channels,\n            out_channels=channels[0],\n        )\n\n        # down\n        self.down_blocks = nn.ModuleList([])\n        self.controlnet_down_blocks = nn.ModuleList([])\n        output_channel = channels[0]\n\n        controlnet_block = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=output_channel,\n            out_channels=output_channel,\n            strides=1,\n            kernel_size=1,\n            padding=0,\n            conv_only=True,\n        )\n        controlnet_block = zero_module(controlnet_block.conv)\n        self.controlnet_down_blocks.append(controlnet_block)\n\n        for i in range(len(channels)):\n            input_channel = output_channel\n            output_channel = channels[i]\n            is_final_block = i == len(channels) - 1\n\n            down_block = get_down_block(\n                spatial_dims=spatial_dims,\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=time_embed_dim,\n                num_res_blocks=num_res_blocks[i],\n                norm_num_groups=norm_num_groups,\n                norm_eps=norm_eps,\n                add_downsample=not is_final_block,\n                
resblock_updown=resblock_updown,\n                with_attn=(attention_levels[i] and not with_conditioning),\n                with_cross_attn=(attention_levels[i] and with_conditioning),\n                num_head_channels=num_head_channels[i],\n                transformer_num_layers=transformer_num_layers,\n                cross_attention_dim=cross_attention_dim,\n                upcast_attention=upcast_attention,\n                include_fc=include_fc,\n                use_combined_linear=use_combined_linear,\n                use_flash_attention=use_flash_attention,\n            )\n\n            self.down_blocks.append(down_block)\n\n            for _ in range(num_res_blocks[i]):\n                controlnet_block = Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=output_channel,\n                    out_channels=output_channel,\n                    strides=1,\n                    kernel_size=1,\n                    padding=0,\n                    conv_only=True,\n                )\n                controlnet_block = zero_module(controlnet_block)\n                self.controlnet_down_blocks.append(controlnet_block)\n            #\n            if not is_final_block:\n                controlnet_block = Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=output_channel,\n                    out_channels=output_channel,\n                    strides=1,\n                    kernel_size=1,\n                    padding=0,\n                    conv_only=True,\n                )\n                controlnet_block = zero_module(controlnet_block)\n                self.controlnet_down_blocks.append(controlnet_block)\n\n        # mid\n        mid_block_channel = channels[-1]\n\n        self.middle_block = get_mid_block(\n            spatial_dims=spatial_dims,\n            in_channels=mid_block_channel,\n            temb_channels=time_embed_dim,\n            norm_num_groups=norm_num_groups,\n    
        norm_eps=norm_eps,\n            with_conditioning=with_conditioning,\n            num_head_channels=num_head_channels[-1],\n            transformer_num_layers=transformer_num_layers,\n            cross_attention_dim=cross_attention_dim,\n            upcast_attention=upcast_attention,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n\n        controlnet_block = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=output_channel,\n            out_channels=output_channel,\n            strides=1,\n            kernel_size=1,\n            padding=0,\n            conv_only=True,\n        )\n        controlnet_block = zero_module(controlnet_block)\n        self.controlnet_mid_block = controlnet_block\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        timesteps: torch.Tensor,\n        controlnet_cond: torch.Tensor,\n        conditioning_scale: float = 1.0,\n        context: torch.Tensor | None = None,\n        class_labels: torch.Tensor | None = None,\n    ) -> tuple[list[torch.Tensor], torch.Tensor]:\n        \"\"\"\n        Args:\n            x: input tensor (N, C, H, W, [D]).\n            timesteps: timestep tensor (N,).\n            controlnet_cond: controlnet conditioning tensor (N, C, H, W, [D])\n            conditioning_scale: conditioning scale.\n            context: context tensor (N, 1, cross_attention_dim), where cross_attention_dim is specified in the model init.\n            class_labels: class label tensor (N, ).\n        \"\"\"\n        # 1. time\n        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. 
so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=x.dtype)\n        emb = self.time_embed(t_emb)\n\n        # 2. class\n        if self.num_class_embeds is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n            class_emb = self.class_embedding(class_labels)\n            class_emb = class_emb.to(dtype=x.dtype)\n            emb = emb + class_emb\n\n        # 3. initial convolution\n        h = self.conv_in(x)\n\n        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)\n\n        h += controlnet_cond\n\n        # 4. down\n        if context is not None and self.with_conditioning is False:\n            raise ValueError(\"model should have with_conditioning = True if context is provided\")\n        down_block_res_samples: list[torch.Tensor] = [h]\n        for downsample_block in self.down_blocks:\n            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)\n            for residual in res_samples:\n                down_block_res_samples.append(residual)\n\n        # 5. mid\n        h = self.middle_block(hidden_states=h, temb=emb, context=context)\n\n        # 6. Control net blocks\n        controlnet_down_block_res_samples = []\n\n        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):\n            down_block_res_sample = controlnet_block(down_block_res_sample)\n            controlnet_down_block_res_samples.append(down_block_res_sample)\n\n        down_block_res_samples = controlnet_down_block_res_samples\n\n        mid_block_res_sample: torch.Tensor = self.controlnet_mid_block(h)\n\n        # 6. 
scaling\n        down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples]\n        mid_block_res_sample *= conditioning_scale\n\n        return down_block_res_samples, mid_block_res_sample\n\n    def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:\n        \"\"\"\n        Load a state dict from a ControlNet trained with\n        [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).\n\n        Args:\n            old_state_dict: state dict from the old ControlNet model.\n        \"\"\"\n\n        new_state_dict = self.state_dict()\n        # if all keys match, just load the state dict\n        if all(k in new_state_dict for k in old_state_dict):\n            print(\"All keys match, loading state dict.\")\n            self.load_state_dict(old_state_dict)\n            return\n\n        if verbose:\n            # print all new_state_dict keys that are not in old_state_dict\n            for k in new_state_dict:\n                if k not in old_state_dict:\n                    print(f\"key {k} not found in old state dict\")\n            # and vice versa\n            print(\"----------------------------------------------\")\n            for k in old_state_dict:\n                if k not in new_state_dict:\n                    print(f\"key {k} not found in new state dict\")\n\n        # copy over all matching keys\n        for k in new_state_dict:\n            if k in old_state_dict:\n                new_state_dict[k] = old_state_dict.pop(k)\n\n        # fix the attention blocks\n        attention_blocks = [k.replace(\".out_proj.weight\", \"\") for k in new_state_dict if \"out_proj.weight\" in k]\n        for block in attention_blocks:\n            # projection\n            new_state_dict[f\"{block}.out_proj.weight\"] = old_state_dict.pop(f\"{block}.to_out.0.weight\")\n            new_state_dict[f\"{block}.out_proj.bias\"] = old_state_dict.pop(f\"{block}.to_out.0.bias\")\n\n        if verbose:\n            # 
print all remaining keys in old_state_dict\n            print(\"remaining keys in old_state_dict:\", old_state_dict.keys())\n        self.load_state_dict(new_state_dict)\n"
  },
  {
    "path": "monai/networks/nets/daf3d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections import OrderedDict\nfrom collections.abc import Callable, Sequence\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom monai.networks.blocks import ADN\nfrom monai.networks.blocks.aspp import SimpleASPP\nfrom monai.networks.blocks.backbone_fpn_utils import BackboneWithFPN\nfrom monai.networks.blocks.convolutions import Convolution\nfrom monai.networks.blocks.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork\nfrom monai.networks.layers.factories import Conv, Norm\nfrom monai.networks.layers.utils import get_norm_layer\nfrom monai.networks.nets.resnet import ResNet, ResNetBottleneck\n\n__all__ = [\n    \"AttentionModule\",\n    \"Daf3dASPP\",\n    \"Daf3dResNetBottleneck\",\n    \"Daf3dResNetDilatedBottleneck\",\n    \"Daf3dResNet\",\n    \"Daf3dBackbone\",\n    \"Daf3dFPN\",\n    \"Daf3dBackboneWithFPN\",\n    \"DAF3D\",\n]\n\n\nclass AttentionModule(nn.Module):\n    \"\"\"\n    Attention Module as described in 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound'\n    <https://arxiv.org/pdf/1907.01743.pdf>. 
Returns refined single layer feature (SLF) and attentive map\n\n    Args:\n        spatial_dims: dimension of inputs.\n        in_channels: number of input channels (channels of slf and mlf).\n        out_channels: number of output channels (channels of attentive map and refined slf).\n        norm: normalization type.\n        act: activation type.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims,\n        in_channels,\n        out_channels,\n        norm=(\"group\", {\"num_groups\": 32, \"num_channels\": 64}),\n        act=\"PRELU\",\n    ):\n        super().__init__()\n\n        self.attentive_map = nn.Sequential(\n            Convolution(spatial_dims, in_channels, out_channels, kernel_size=1, norm=norm, act=act),\n            Convolution(spatial_dims, out_channels, out_channels, kernel_size=3, padding=1, norm=norm, act=act),\n            Convolution(\n                spatial_dims, out_channels, out_channels, kernel_size=3, padding=1, adn_ordering=\"A\", act=\"SIGMOID\"\n            ),\n        )\n        self.refine = nn.Sequential(\n            Convolution(spatial_dims, in_channels, out_channels, kernel_size=1, norm=norm, act=act),\n            Convolution(spatial_dims, out_channels, out_channels, kernel_size=3, padding=1, norm=norm, act=act),\n            Convolution(spatial_dims, out_channels, out_channels, kernel_size=3, padding=1, norm=norm, act=act),\n        )\n\n    def forward(self, slf, mlf):\n        att = self.attentive_map(torch.cat((slf, mlf), 1))\n        out = self.refine(torch.cat((slf, att * mlf), 1))\n        return (out, att)\n\n\nclass Daf3dASPP(SimpleASPP):\n    \"\"\"\n    Atrous Spatial Pyramid Pooling module as used in 'Deep Attentive Features for Prostate Segmentation in\n    3D Transrectal Ultrasound' <https://arxiv.org/pdf/1907.01743.pdf>. Core functionality as in SimpleASPP, but after each\n    layerwise convolution a group normalization is added. 
Further weight initialization for convolutions is provided in\n    _init_weight(). Additional possibility to specify the number of final output channels.\n\n    Args:\n        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.\n        in_channels: number of input channels.\n        conv_out_channels: number of output channels of each atrous conv.\n        out_channels: number of output channels of final convolution.\n            If None, uses len(kernel_sizes) * conv_out_channels\n        kernel_sizes: a sequence of four convolutional kernel sizes.\n            Defaults to (1, 3, 3, 3) for four (dilated) convolutions.\n        dilations: a sequence of four convolutional dilation parameters.\n            Defaults to (1, 2, 4, 6) for four (dilated) convolutions.\n        norm_type: final kernel-size-one convolution normalization type.\n            Defaults to batch norm.\n        acti_type: final kernel-size-one convolution activation type.\n            Defaults to leaky ReLU.\n        bias: whether to have a bias term in convolution blocks. 
Defaults to False.\n            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,\n            if a conv layer is directly followed by a batch norm layer, bias should be False.\n\n    Raises:\n        ValueError: When ``kernel_sizes`` length differs from ``dilations``.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        conv_out_channels: int,\n        out_channels: int | None = None,\n        kernel_sizes: Sequence[int] = (1, 3, 3, 3),\n        dilations: Sequence[int] = (1, 2, 4, 6),\n        norm_type: tuple | str | None = \"BATCH\",\n        acti_type: tuple | str | None = \"LEAKYRELU\",\n        bias: bool = False,\n    ) -> None:\n        super().__init__(\n            spatial_dims, in_channels, conv_out_channels, kernel_sizes, dilations, norm_type, acti_type, bias\n        )\n\n        # add normalization after each atrous convolution, initializes weights\n        new_convs = nn.ModuleList()\n        for _conv in self.convs:\n            tmp_conv = Convolution(1, 1, 1)\n            tmp_conv.conv = _conv\n            tmp_conv.adn = ADN(ordering=\"N\", norm=norm_type, norm_dim=1)\n            tmp_conv = self._init_weight(tmp_conv)\n            new_convs.append(tmp_conv)\n        self.convs = new_convs\n\n        # change final convolution to different out_channels\n        if out_channels is None:\n            out_channels = len(kernel_sizes) * conv_out_channels\n\n        self.conv_k1 = Convolution(\n            spatial_dims=3,\n            in_channels=len(kernel_sizes) * conv_out_channels,\n            out_channels=out_channels,\n            kernel_size=1,\n            norm=norm_type,\n            act=acti_type,\n        )\n\n    def _init_weight(self, conv):\n        for m in conv.modules():\n            if isinstance(m, nn.Conv3d):  # true for conv.conv\n                torch.nn.init.kaiming_normal_(m.weight)\n        return conv\n\n\nclass 
Daf3dResNetBottleneck(ResNetBottleneck):\n    \"\"\"\n    ResNetBottleneck block as used in 'Deep Attentive Features for Prostate Segmentation in 3D\n    Transrectal Ultrasound' <https://arxiv.org/pdf/1907.01743.pdf>.\n    Instead of Batch Norm Group Norm is used, instead of ReLU PReLU activation is used.\n    Initial expansion is 2 instead of 4 and second convolution uses groups.\n\n    Args:\n        in_planes: number of input channels.\n        planes: number of output channels (taking expansion into account).\n        spatial_dims: number of spatial dimensions of the input image.\n        stride: stride to use for second conv layer.\n        downsample: which downsample layer to use.\n        norm: which normalization layer to use. Defaults to group.\n    \"\"\"\n\n    expansion = 2\n\n    def __init__(\n        self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=(\"group\", {\"num_groups\": 32})\n    ):\n        conv_type: Callable = Conv[Conv.CONV, spatial_dims]\n\n        norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)\n\n        # in case downsample uses batch norm, change to group norm\n        if isinstance(downsample, nn.Sequential):\n            downsample = nn.Sequential(\n                conv_type(in_planes, planes * self.expansion, kernel_size=1, stride=stride, bias=False),\n                norm_layer(channels=planes * self.expansion),\n            )\n\n        super().__init__(in_planes, planes, spatial_dims, stride, downsample)\n\n        # change norm from batch to group norm\n        self.bn1 = norm_layer(channels=planes)\n        self.bn2 = norm_layer(channels=planes)\n        self.bn3 = norm_layer(channels=planes * self.expansion)\n\n        # adapt second convolution to work with groups\n        self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, stride=stride, groups=32, bias=False)\n\n        # adapt activation function\n        self.relu = nn.PReLU()\n\n\nclass 
Daf3dResNetDilatedBottleneck(Daf3dResNetBottleneck):\n    \"\"\"\n    ResNetDilatedBottleneck as used in 'Deep Attentive Features for Prostate Segmentation in 3D\n    Transrectal Ultrasound' <https://arxiv.org/pdf/1907.01743.pdf>.\n    Same as Daf3dResNetBottleneck but dilation of 2 is used in second convolution.\n    Args:\n        in_planes: number of input channels.\n        planes: number of output channels (taking expansion into account).\n        spatial_dims: number of spatial dimensions of the input image.\n        stride: stride to use for second conv layer.\n        downsample: which downsample layer to use.\n    \"\"\"\n\n    def __init__(\n        self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=(\"group\", {\"num_groups\": 32})\n    ):\n        super().__init__(in_planes, planes, spatial_dims, stride, downsample, norm)\n\n        # add dilation in second convolution\n        conv_type: Callable = Conv[Conv.CONV, spatial_dims]\n        self.conv2 = conv_type(\n            planes, planes, kernel_size=3, stride=stride, padding=2, dilation=2, groups=32, bias=False\n        )\n\n\nclass Daf3dResNet(ResNet):\n    \"\"\"\n    ResNet as used in 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound'\n    <https://arxiv.org/pdf/1907.01743.pdf>.\n    Uses two Daf3dResNetBottleneck blocks followed by two Daf3dResNetDilatedBottleneck blocks.\n\n    Args:\n        layers: how many layers to use.\n        block_inplanes: determine the size of planes at each step. Also tunable with widen_factor.\n        spatial_dims: number of spatial dimensions of the input image.\n        n_input_channels: number of input channels for first convolutional layer.\n        conv1_t_size: size of first convolution layer, determines kernel and padding.\n        conv1_t_stride: stride of first convolution layer.\n        no_max_pool: bool argument to determine if to use maxpool layer.\n        shortcut_type: which downsample block to use. 
Options are 'A', 'B', default to 'B'.\n            - 'A': using `self._downsample_basic_block`.\n            - 'B': kernel_size 1 conv + norm.\n        widen_factor: widen output for each layer.\n        num_classes: number of output (classifications).\n        feed_forward: whether to add the FC layer for the output, default to `True`.\n        bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        layers: list[int],\n        block_inplanes: list[int],\n        spatial_dims: int = 3,\n        n_input_channels: int = 3,\n        conv1_t_size: tuple[int] | int = 7,\n        conv1_t_stride: tuple[int] | int = 1,\n        no_max_pool: bool = False,\n        shortcut_type: str = \"B\",\n        widen_factor: float = 1.0,\n        num_classes: int = 400,\n        feed_forward: bool = True,\n        bias_downsample: bool = True,  # for backwards compatibility (also see PR #5477)\n    ):\n        super().__init__(\n            ResNetBottleneck,\n            layers,\n            block_inplanes,\n            spatial_dims,\n            n_input_channels,\n            conv1_t_size,\n            conv1_t_stride,\n            no_max_pool,\n            shortcut_type,\n            widen_factor,\n            num_classes,\n            feed_forward,\n            bias_downsample,\n        )\n\n        self.in_planes = 64\n\n        conv_type: Callable = Conv[Conv.CONV, spatial_dims]\n        norm_type: Callable = Norm[Norm.GROUP, spatial_dims]\n\n        # adapt first convolution to work with new in_planes\n        self.conv1 = conv_type(\n            n_input_channels, self.in_planes, kernel_size=7, stride=(1, 2, 2), padding=(3, 3, 3), bias=False\n        )\n        self.bn1 = norm_type(32, 64)\n        self.relu = nn.PReLU()\n\n        # adapt layers to our needs\n        self.layer1 = self._make_layer(Daf3dResNetBottleneck, block_inplanes[0], layers[0], 
spatial_dims, shortcut_type)\n        self.layer2 = self._make_layer(\n            Daf3dResNetBottleneck,\n            block_inplanes[1],\n            layers[1],\n            spatial_dims,\n            shortcut_type,\n            stride=(1, 2, 2),  # type: ignore\n        )\n        self.layer3 = self._make_layer(\n            Daf3dResNetDilatedBottleneck, block_inplanes[2], layers[2], spatial_dims, shortcut_type, stride=1\n        )\n        self.layer4 = self._make_layer(\n            Daf3dResNetDilatedBottleneck, block_inplanes[3], layers[3], spatial_dims, shortcut_type, stride=1\n        )\n\n\nclass Daf3dBackbone(nn.Module):\n    \"\"\"\n    Backbone for 3D Feature Pyramid Network in DAF3D module based on 'Deep Attentive Features for Prostate Segmentation in\n    3D Transrectal Ultrasound' <https://arxiv.org/pdf/1907.01743.pdf>.\n\n    Args:\n        n_input_channels: number of input channels for the first convolution.\n    \"\"\"\n\n    def __init__(self, n_input_channels):\n        super().__init__()\n        net = Daf3dResNet(\n            layers=[3, 4, 6, 3],\n            block_inplanes=[128, 256, 512, 1024],\n            n_input_channels=n_input_channels,\n            num_classes=2,\n            bias_downsample=False,\n        )\n        net_modules = list(net.children())\n        self.layer0 = nn.Sequential(*net_modules[:3])\n        self.layer1 = nn.Sequential(*net_modules[3:5])\n        self.layer2 = net_modules[5]\n        self.layer3 = net_modules[6]\n        self.layer4 = net_modules[7]\n\n    def forward(self, x):\n        layer0 = self.layer0(x)\n        layer1 = self.layer1(layer0)\n        layer2 = self.layer2(layer1)\n        layer3 = self.layer3(layer2)\n        layer4 = self.layer4(layer3)\n        return layer4\n\n\nclass Daf3dFPN(FeaturePyramidNetwork):\n    \"\"\"\n    Feature Pyramid Network as used in 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound'\n    <https://arxiv.org/pdf/1907.01743.pdf>.\n    Omits 
3x3x3 convolution of layer_blocks and interpolates resulting feature maps to be the same size as\n    feature map with highest resolution.\n\n    Args:\n        spatial_dims: 2D or 3D images\n        in_channels_list: number of channels for each feature map that is passed to the module\n        out_channels: number of channels of the FPN representation\n        extra_blocks: if provided, extra operations will be performed.\n            It is expected to take the fpn features, the original\n            features and the names of the original features as input, and returns\n            a new list of feature maps and their corresponding names\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels_list: list[int],\n        out_channels: int,\n        extra_blocks: ExtraFPNBlock | None = None,\n    ):\n        super().__init__(spatial_dims, in_channels_list, out_channels, extra_blocks)\n\n        self.inner_blocks = nn.ModuleList()\n        for in_channels in in_channels_list:\n            if in_channels == 0:\n                raise ValueError(\"in_channels=0 is currently not supported\")\n            inner_block_module = Convolution(\n                spatial_dims,\n                in_channels,\n                out_channels,\n                kernel_size=1,\n                adn_ordering=\"NA\",\n                act=\"PRELU\",\n                norm=(\"group\", {\"num_groups\": 32, \"num_channels\": 128}),\n            )\n            self.inner_blocks.append(inner_block_module)\n\n    def forward(self, x: dict[str, Tensor]) -> dict[str, Tensor]:\n        # unpack OrderedDict into two lists for easier handling\n        names = list(x.keys())\n        x_values: list[Tensor] = list(x.values())\n\n        last_inner = self.get_result_from_inner_blocks(x_values[-1], -1)\n        results = []\n        results.append(last_inner)\n\n        for idx in range(len(x_values) - 2, -1, -1):\n            inner_lateral = 
self.get_result_from_inner_blocks(x_values[idx], idx)\n            feat_shape = inner_lateral.shape[2:]\n            inner_top_down = F.interpolate(last_inner, size=feat_shape, mode=\"trilinear\")\n            last_inner = inner_lateral + inner_top_down\n            results.insert(0, last_inner)\n\n        if self.extra_blocks is not None:\n            results, names = self.extra_blocks(results, x_values, names)\n\n        # bring all layers to same size\n        results = [results[0]] + [F.interpolate(l, size=x[\"feat1\"].size()[2:], mode=\"trilinear\") for l in results[1:]]\n        # make it back an OrderedDict\n        out = OrderedDict(list(zip(names, results)))\n\n        return out\n\n\nclass Daf3dBackboneWithFPN(BackboneWithFPN):\n    \"\"\"\n    Same as BackboneWithFPN but uses custom Daf3DFPN as feature pyramid network\n\n    Args:\n        backbone: backbone network\n        return_layers: a dict containing the names\n            of the modules for which the activations will be returned as\n            the key of the dict, and the value of the dict is the name\n            of the returned activation (which the user can specify).\n        in_channels_list: number of channels for each feature map\n            that is returned, in the order they are present in the OrderedDict\n        out_channels: number of channels in the FPN.\n        spatial_dims: 2D or 3D images\n        extra_blocks: if provided, extra operations will\n            be performed. 
It is expected to take the fpn features, the original\n            features and the names of the original features as input, and returns\n            a new list of feature maps and their corresponding names\n    \"\"\"\n\n    def __init__(\n        self,\n        backbone: nn.Module,\n        return_layers: dict[str, str],\n        in_channels_list: list[int],\n        out_channels: int,\n        spatial_dims: int | None = None,\n        extra_blocks: ExtraFPNBlock | None = None,\n    ) -> None:\n        super().__init__(backbone, return_layers, in_channels_list, out_channels, spatial_dims, extra_blocks)\n\n        if spatial_dims is None:\n            if hasattr(backbone, \"spatial_dims\") and isinstance(backbone.spatial_dims, int):\n                spatial_dims = backbone.spatial_dims\n            elif isinstance(backbone.conv1, nn.Conv2d):\n                spatial_dims = 2\n            elif isinstance(backbone.conv1, nn.Conv3d):\n                spatial_dims = 3\n            else:\n                raise ValueError(\n                    \"Could not determine value of  `spatial_dims` from backbone, please provide explicit value.\"\n                )\n\n        self.fpn = Daf3dFPN(spatial_dims, in_channels_list, out_channels, extra_blocks)\n\n\nclass DAF3D(nn.Module):\n    \"\"\"\n    DAF3D network based on 'Deep Attentive Features for Prostate Segmentation in 3D Transrectal Ultrasound'\n    <https://arxiv.org/pdf/1907.01743.pdf>.\n    The network consists of a 3D Feature Pyramid Network which is applied on the feature maps of a 3D ResNet,\n    followed by a custom Attention Module and an ASPP module.\n    During training the supervised signal consists of the outputs of the FPN (four Single Layer Features, SLFs),\n    the outputs of the attention module (four Attentive Features) and the final prediction.\n    They are individually compared to the ground truth, the final loss consists of a weighted sum of all\n    individual losses (see DAF3D tutorial for 
details).\n    There is an additional possiblity to return all supervised signals as well as the Attentive Maps in validation\n    mode to visualize inner functionality of the network.\n\n    Args:\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        visual_output: whether to return all SLFs, Attentive Maps, Refined SLFs in validation mode\n            can be used to visualize inner functionality of the network\n    \"\"\"\n\n    def __init__(self, in_channels, out_channels, visual_output=False):\n        super().__init__()\n        self.visual_output = visual_output\n        self.backbone_with_fpn = Daf3dBackboneWithFPN(\n            backbone=Daf3dBackbone(in_channels),\n            return_layers={\"layer1\": \"feat1\", \"layer2\": \"feat2\", \"layer3\": \"feat3\", \"layer4\": \"feat4\"},\n            in_channels_list=[256, 512, 1024, 2048],\n            out_channels=128,\n            spatial_dims=3,\n        )\n        self.predict1 = nn.Conv3d(128, out_channels, kernel_size=1)\n\n        group_norm = (\"group\", {\"num_groups\": 32, \"num_channels\": 64})\n        act_prelu = (\"prelu\", {\"num_parameters\": 1, \"init\": 0.25})\n        self.fuse = nn.Sequential(\n            Convolution(\n                spatial_dims=3,\n                in_channels=512,\n                out_channels=64,\n                kernel_size=1,\n                adn_ordering=\"NA\",\n                norm=group_norm,\n                act=act_prelu,\n            ),\n            Convolution(\n                spatial_dims=3,\n                in_channels=64,\n                out_channels=64,\n                kernel_size=3,\n                adn_ordering=\"NA\",\n                padding=1,\n                norm=group_norm,\n                act=act_prelu,\n            ),\n            Convolution(\n                spatial_dims=3,\n                in_channels=64,\n                out_channels=64,\n                kernel_size=3,\n                
adn_ordering=\"NA\",\n                padding=1,\n                norm=group_norm,\n                act=act_prelu,\n            ),\n        )\n        self.attention = AttentionModule(\n            spatial_dims=3, in_channels=192, out_channels=64, norm=group_norm, act=act_prelu\n        )\n\n        self.refine = Convolution(3, 256, 64, kernel_size=1, adn_ordering=\"NA\", norm=group_norm, act=act_prelu)\n        self.predict2 = nn.Conv3d(64, out_channels, kernel_size=1)\n        self.aspp = Daf3dASPP(\n            spatial_dims=3,\n            in_channels=64,\n            conv_out_channels=64,\n            out_channels=64,\n            kernel_sizes=(3, 3, 3, 3),\n            dilations=((1, 1, 1), (1, 6, 6), (1, 12, 12), (1, 18, 18)),  # type: ignore\n            norm_type=group_norm,\n            acti_type=None,\n            bias=True,\n        )\n\n    def forward(self, x):\n        # layers from 1 - 4\n        single_layer_features = list(self.backbone_with_fpn(x).values())\n\n        # first 4 supervised signals (SLFs 1 - 4)\n        supervised1 = [self.predict1(slf) for slf in single_layer_features]\n\n        mlf = self.fuse(torch.cat(single_layer_features, 1))\n\n        attentive_features_maps = [self.attention(slf, mlf) for slf in single_layer_features]\n        att_features, att_maps = tuple(zip(*attentive_features_maps))\n\n        # second 4 supervised signals (af 1 - 4)\n        supervised2 = [self.predict2(af) for af in att_features]\n\n        # attentive maps as optional additional output\n        supervised3 = [self.predict2(am) for am in att_maps]\n\n        attentive_mlf = self.refine(torch.cat(att_features, 1))\n\n        aspp = self.aspp(attentive_mlf)\n\n        supervised_final = self.predict2(aspp)\n\n        if self.training:\n            output = supervised1 + supervised2 + [supervised_final]\n            output = [F.interpolate(o, size=x.size()[2:], mode=\"trilinear\") for o in output]\n        else:\n            if self.visual_output:\n    
            supervised_final = F.interpolate(supervised_final, size=x.size()[2:], mode=\"trilinear\")\n                supervised_inner = [\n                    F.interpolate(o, size=x.size()[2:], mode=\"trilinear\")\n                    for o in supervised1 + supervised2 + supervised3\n                ]\n                output = [supervised_final] + supervised_inner\n            else:\n                output = F.interpolate(supervised_final, size=x.size()[2:], mode=\"trilinear\")\n        return output\n"
  },
  {
    "path": "monai/networks/nets/densenet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport re\nfrom collections import OrderedDict\nfrom collections.abc import Callable, Sequence\n\nimport torch\nimport torch.nn as nn\nfrom torch.hub import load_state_dict_from_url\n\nfrom monai.networks.layers.factories import Conv, Dropout, Pool\nfrom monai.networks.layers.utils import get_act_layer, get_norm_layer\nfrom monai.utils.module import look_up_option\n\n__all__ = [\n    \"DenseNet\",\n    \"Densenet\",\n    \"DenseNet121\",\n    \"densenet121\",\n    \"Densenet121\",\n    \"DenseNet169\",\n    \"densenet169\",\n    \"Densenet169\",\n    \"DenseNet201\",\n    \"densenet201\",\n    \"Densenet201\",\n    \"DenseNet264\",\n    \"densenet264\",\n    \"Densenet264\",\n]\n\n\nclass _DenseLayer(nn.Module):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        growth_rate: int,\n        bn_size: int,\n        dropout_prob: float,\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions of the input image.\n            in_channels: number of the input channel.\n            growth_rate: how many filters to add each layer (k in paper).\n            bn_size: multiplicative factor for number of bottle neck layers.\n                (i.e. 
bn_size * k features in the bottleneck layer)\n            dropout_prob: dropout rate after each dense layer.\n            act: activation type and arguments. Defaults to relu.\n            norm: feature normalization type and arguments. Defaults to batch norm.\n        \"\"\"\n        super().__init__()\n\n        out_channels = bn_size * growth_rate\n        conv_type: Callable = Conv[Conv.CONV, spatial_dims]\n        dropout_type: Callable = Dropout[Dropout.DROPOUT, spatial_dims]\n\n        self.layers = nn.Sequential()\n\n        self.layers.add_module(\"norm1\", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels))\n        self.layers.add_module(\"relu1\", get_act_layer(name=act))\n        self.layers.add_module(\"conv1\", conv_type(in_channels, out_channels, kernel_size=1, bias=False))\n\n        self.layers.add_module(\"norm2\", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=out_channels))\n        self.layers.add_module(\"relu2\", get_act_layer(name=act))\n        self.layers.add_module(\"conv2\", conv_type(out_channels, growth_rate, kernel_size=3, padding=1, bias=False))\n\n        if dropout_prob > 0:\n            self.layers.add_module(\"dropout\", dropout_type(dropout_prob))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        new_features = self.layers(x)\n        return torch.cat([x, new_features], 1)\n\n\nclass _DenseBlock(nn.Sequential):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        layers: int,\n        in_channels: int,\n        bn_size: int,\n        growth_rate: int,\n        dropout_prob: float,\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions of the input image.\n            layers: number of layers in the block.\n            in_channels: number of the input channel.\n            bn_size: multiplicative factor 
for number of bottle neck layers.\n                (i.e. bn_size * k features in the bottleneck layer)\n            growth_rate: how many filters to add each layer (k in paper).\n            dropout_prob: dropout rate after each dense layer.\n            act: activation type and arguments. Defaults to relu.\n            norm: feature normalization type and arguments. Defaults to batch norm.\n        \"\"\"\n        super().__init__()\n        for i in range(layers):\n            layer = _DenseLayer(spatial_dims, in_channels, growth_rate, bn_size, dropout_prob, act=act, norm=norm)\n            in_channels += growth_rate\n            self.add_module(f\"denselayer{i + 1}\", layer)\n\n\nclass _Transition(nn.Sequential):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions of the input image.\n            in_channels: number of the input channel.\n            out_channels: number of the output classes.\n            act: activation type and arguments. Defaults to relu.\n            norm: feature normalization type and arguments. 
Defaults to batch norm.\n        \"\"\"\n        super().__init__()\n\n        conv_type: Callable = Conv[Conv.CONV, spatial_dims]\n        pool_type: Callable = Pool[Pool.AVG, spatial_dims]\n\n        self.add_module(\"norm\", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels))\n        self.add_module(\"relu\", get_act_layer(name=act))\n        self.add_module(\"conv\", conv_type(in_channels, out_channels, kernel_size=1, bias=False))\n        self.add_module(\"pool\", pool_type(kernel_size=2, stride=2))\n\n\nclass DenseNet(nn.Module):\n    \"\"\"\n    Densenet based on: `Densely Connected Convolutional Networks <https://arxiv.org/pdf/1608.06993.pdf>`_.\n    Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.\n    This network is non-deterministic When `spatial_dims` is 3 and CUDA is enabled. Please check the link below\n    for more details:\n    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms\n\n    Args:\n        spatial_dims: number of spatial dimensions of the input image.\n        in_channels: number of the input channel.\n        out_channels: number of the output classes.\n        init_features: number of filters in the first convolution layer.\n        growth_rate: how many filters to add each layer (k in paper).\n        block_config: how many layers in each pooling block.\n        bn_size: multiplicative factor for number of bottle neck layers.\n            (i.e. bn_size * k features in the bottleneck layer)\n        act: activation type and arguments. Defaults to relu.\n        norm: feature normalization type and arguments. 
Defaults to batch norm.\n        dropout_prob: dropout rate after each dense layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        init_features: int = 64,\n        growth_rate: int = 32,\n        block_config: Sequence[int] = (6, 12, 24, 16),\n        bn_size: int = 4,\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n        dropout_prob: float = 0.0,\n    ) -> None:\n        super().__init__()\n\n        conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]\n        pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]\n        avg_pool_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[\n            Pool.ADAPTIVEAVG, spatial_dims\n        ]\n\n        self.features = nn.Sequential(\n            OrderedDict(\n                [\n                    (\"conv0\", conv_type(in_channels, init_features, kernel_size=7, stride=2, padding=3, bias=False)),\n                    (\"norm0\", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=init_features)),\n                    (\"relu0\", get_act_layer(name=act)),\n                    (\"pool0\", pool_type(kernel_size=3, stride=2, padding=1)),\n                ]\n            )\n        )\n\n        in_channels = init_features\n        for i, num_layers in enumerate(block_config):\n            block = _DenseBlock(\n                spatial_dims=spatial_dims,\n                layers=num_layers,\n                in_channels=in_channels,\n                bn_size=bn_size,\n                growth_rate=growth_rate,\n                dropout_prob=dropout_prob,\n                act=act,\n                norm=norm,\n            )\n            self.features.add_module(f\"denseblock{i + 1}\", block)\n            in_channels += num_layers * growth_rate\n            if i == 
len(block_config) - 1:\n                self.features.add_module(\n                    \"norm5\", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)\n                )\n            else:\n                _out_channels = in_channels // 2\n                trans = _Transition(\n                    spatial_dims, in_channels=in_channels, out_channels=_out_channels, act=act, norm=norm\n                )\n                self.features.add_module(f\"transition{i + 1}\", trans)\n                in_channels = _out_channels\n\n        # pooling and classification\n        self.class_layers = nn.Sequential(\n            OrderedDict(\n                [\n                    (\"relu\", get_act_layer(name=act)),\n                    (\"pool\", avg_pool_type(1)),\n                    (\"flatten\", nn.Flatten(1)),\n                    (\"out\", nn.Linear(in_channels, out_channels)),\n                ]\n            )\n        )\n\n        for m in self.modules():\n            if isinstance(m, conv_type):\n                nn.init.kaiming_normal_(torch.as_tensor(m.weight))\n            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):\n                nn.init.constant_(torch.as_tensor(m.weight), 1)\n                nn.init.constant_(torch.as_tensor(m.bias), 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.constant_(torch.as_tensor(m.bias), 0)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.features(x)\n        x = self.class_layers(x)\n        return x\n\n\ndef _load_state_dict(model: nn.Module, arch: str, progress: bool):\n    \"\"\"\n    This function is used to load pretrained models.\n    Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.\n\n    \"\"\"\n    model_urls = {\n        \"densenet121\": \"https://download.pytorch.org/models/densenet121-a639ec97.pth\",\n        \"densenet169\": 
\"https://download.pytorch.org/models/densenet169-b2777c0a.pth\",\n        \"densenet201\": \"https://download.pytorch.org/models/densenet201-c1103571.pth\",\n    }\n    model_url = look_up_option(arch, model_urls, None)\n    if model_url is None:\n        raise ValueError(\n            \"only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights.\"\n        )\n\n    pattern = re.compile(\n        r\"^(.*denselayer\\d+)(\\.(?:norm|relu|conv))\\.((?:[12])\\.(?:weight|bias|running_mean|running_var))$\"\n    )\n\n    state_dict = load_state_dict_from_url(model_url, progress=progress)\n    for key in list(state_dict.keys()):\n        res = pattern.match(key)\n        if res:\n            new_key = res.group(1) + \".layers\" + res.group(2) + res.group(3)\n            state_dict[new_key] = state_dict[key]\n            del state_dict[key]\n\n    model_dict = model.state_dict()\n    state_dict = {\n        k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)\n    }\n    model_dict.update(state_dict)\n    model.load_state_dict(model_dict)\n\n\nclass DenseNet121(DenseNet):\n    \"\"\"DenseNet121 with optional pretrained support when `spatial_dims` is 2.\"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        init_features: int = 64,\n        growth_rate: int = 32,\n        block_config: Sequence[int] = (6, 12, 24, 16),\n        pretrained: bool = False,\n        progress: bool = True,\n        **kwargs,\n    ) -> None:\n        super().__init__(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            init_features=init_features,\n            growth_rate=growth_rate,\n            block_config=block_config,\n            **kwargs,\n        )\n        if pretrained:\n            if spatial_dims > 2:\n                raise 
NotImplementedError(\n                    \"Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not\"\n                    \"provide pretrained models for more than two spatial dimensions.\"\n                )\n            _load_state_dict(self, \"densenet121\", progress)\n\n\nclass DenseNet169(DenseNet):\n    \"\"\"DenseNet169 with optional pretrained support when `spatial_dims` is 2.\"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        init_features: int = 64,\n        growth_rate: int = 32,\n        block_config: Sequence[int] = (6, 12, 32, 32),\n        pretrained: bool = False,\n        progress: bool = True,\n        **kwargs,\n    ) -> None:\n        super().__init__(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            init_features=init_features,\n            growth_rate=growth_rate,\n            block_config=block_config,\n            **kwargs,\n        )\n        if pretrained:\n            if spatial_dims > 2:\n                raise NotImplementedError(\n                    \"Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not\"\n                    \"provide pretrained models for more than two spatial dimensions.\"\n                )\n            _load_state_dict(self, \"densenet169\", progress)\n\n\nclass DenseNet201(DenseNet):\n    \"\"\"DenseNet201 with optional pretrained support when `spatial_dims` is 2.\"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        init_features: int = 64,\n        growth_rate: int = 32,\n        block_config: Sequence[int] = (6, 12, 48, 32),\n        pretrained: bool = False,\n        progress: bool = True,\n        **kwargs,\n    ) -> None:\n        super().__init__(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            
out_channels=out_channels,\n            init_features=init_features,\n            growth_rate=growth_rate,\n            block_config=block_config,\n            **kwargs,\n        )\n        if pretrained:\n            if spatial_dims > 2:\n                raise NotImplementedError(\n                    \"Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not\"\n                    \"provide pretrained models for more than two spatial dimensions.\"\n                )\n            _load_state_dict(self, \"densenet201\", progress)\n\n\nclass DenseNet264(DenseNet):\n    \"\"\"DenseNet264\"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        init_features: int = 64,\n        growth_rate: int = 32,\n        block_config: Sequence[int] = (6, 12, 64, 48),\n        pretrained: bool = False,\n        progress: bool = True,\n        **kwargs,\n    ) -> None:\n        super().__init__(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            init_features=init_features,\n            growth_rate=growth_rate,\n            block_config=block_config,\n            **kwargs,\n        )\n        if pretrained:\n            raise NotImplementedError(\"Currently PyTorch Hub does not provide densenet264 pretrained models.\")\n\n\nDensenet = DenseNet\nDensenet121 = densenet121 = DenseNet121\nDensenet169 = densenet169 = DenseNet169\nDensenet201 = densenet201 = DenseNet201\nDensenet264 = densenet264 = DenseNet264\n"
  },
  {
    "path": "monai/networks/nets/diffusion_model_unet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# =========================================================================\n# Adapted from https://github.com/huggingface/diffusers\n# which has the following license:\n# https://github.com/huggingface/diffusers/blob/main/LICENSE\n#\n# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =========================================================================\n\nfrom __future__ import annotations\n\nimport math\nfrom collections.abc import Sequence\nfrom functools import reduce\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nfrom monai.networks.blocks import Convolution, CrossAttentionBlock, MLPBlock, SABlock, SpatialAttentionBlock, Upsample\nfrom monai.networks.layers.factories import Pool\nfrom monai.utils import ensure_tuple_rep, optional_import\n\nRearrange, _ = 
optional_import(\"einops.layers.torch\", name=\"Rearrange\")\n\n__all__ = [\"DiffusionModelUNet\"]\n\n\ndef zero_module(module: nn.Module) -> nn.Module:\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\nclass DiffusionUNetTransformerBlock(nn.Module):\n    \"\"\"\n    A Transformer block that allows for the input dimension to differ from the hidden dimension.\n\n    Args:\n        num_channels: number of channels in the input and output.\n        num_attention_heads: number of heads to use for multi-head attention.\n        num_head_channels: number of channels in each attention head.\n        dropout: dropout probability to use.\n        cross_attention_dim: size of the context vector for cross attention.\n        upcast_attention: if True, upcast attention operations to full precision.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        num_channels: int,\n        num_attention_heads: int,\n        num_head_channels: int,\n        dropout: float = 0.0,\n        cross_attention_dim: int | None = None,\n        upcast_attention: bool = False,\n        use_flash_attention: bool = False,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n    ) -> None:\n        super().__init__()\n        self.attn1 = SABlock(\n            hidden_size=num_attention_heads * num_head_channels,\n            hidden_input_size=num_channels,\n            num_heads=num_attention_heads,\n            dim_head=num_head_channels,\n            dropout_rate=dropout,\n            attention_dtype=torch.float if upcast_attention else None,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n        self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act=\"GEGLU\", dropout_rate=dropout)\n        self.attn2 = CrossAttentionBlock(\n            hidden_size=num_attention_heads * num_head_channels,\n            num_heads=num_attention_heads,\n            hidden_input_size=num_channels,\n            context_input_size=cross_attention_dim,\n            dim_head=num_head_channels,\n            dropout_rate=dropout,\n            attention_dtype=torch.float if upcast_attention else None,\n            use_flash_attention=use_flash_attention,\n        )\n        self.norm1 = nn.LayerNorm(num_channels)\n        self.norm2 = nn.LayerNorm(num_channels)\n        self.norm3 = nn.LayerNorm(num_channels)\n\n    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:\n        # 1. Self-Attention\n        x = self.attn1(self.norm1(x)) + x\n\n        # 2. 
Cross-Attention\n        x = self.attn2(self.norm2(x), context=context) + x\n\n        # 3. Feed-forward\n        x = self.ff(self.norm3(x)) + x\n        return x\n\n\nclass SpatialTransformer(nn.Module):\n    \"\"\"\n    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply\n    standard transformer action. Finally, reshape to image.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of channels in the input and output.\n        num_attention_heads: number of heads to use for multi-head attention.\n        num_head_channels: number of channels in each attention head.\n        num_layers: number of layers of Transformer blocks to use.\n        dropout: dropout probability to use.\n        norm_num_groups: number of groups for the normalization.\n        norm_eps: epsilon for the normalization.\n        cross_attention_dim: number of context dimensions to use.\n        upcast_attention: if True, upcast attention operations to full precision.\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        num_attention_heads: int,\n        num_head_channels: int,\n        num_layers: int = 1,\n        dropout: float = 0.0,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        cross_attention_dim: int | None = None,\n        upcast_attention: bool = False,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.in_channels = in_channels\n        inner_dim = num_attention_heads * num_head_channels\n\n        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)\n\n        self.proj_in = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=inner_dim,\n            strides=1,\n            kernel_size=1,\n            padding=0,\n            conv_only=True,\n        )\n\n        self.transformer_blocks = nn.ModuleList(\n            [\n                DiffusionUNetTransformerBlock(\n                    num_channels=inner_dim,\n                    num_attention_heads=num_attention_heads,\n                    num_head_channels=num_head_channels,\n                    dropout=dropout,\n                    cross_attention_dim=cross_attention_dim,\n                    upcast_attention=upcast_attention,\n                    include_fc=include_fc,\n                    use_combined_linear=use_combined_linear,\n             
       use_flash_attention=use_flash_attention,\n                )\n                for _ in range(num_layers)\n            ]\n        )\n\n        self.proj_out = zero_module(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=inner_dim,\n                out_channels=in_channels,\n                strides=1,\n                kernel_size=1,\n                padding=0,\n                conv_only=True,\n            )\n        )\n\n    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:\n        # note: if no context is given, cross-attention defaults to self-attention\n        batch = channel = height = width = depth = -1\n        if self.spatial_dims == 2:\n            batch, channel, height, width = x.shape\n        if self.spatial_dims == 3:\n            batch, channel, height, width, depth = x.shape\n\n        residual = x\n        x = self.norm(x)\n        x = self.proj_in(x)\n\n        inner_dim = x.shape[1]\n\n        if self.spatial_dims == 2:\n            x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)\n        if self.spatial_dims == 3:\n            x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim)\n\n        for block in self.transformer_blocks:\n            x = block(x, context=context)\n\n        if self.spatial_dims == 2:\n            x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()\n        if self.spatial_dims == 3:\n            x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous()\n\n        x = self.proj_out(x)\n        return x + residual\n\n\ndef get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor:\n    \"\"\"\n    Create sinusoidal timestep embeddings following the implementation in Ho et al. 
\"Denoising Diffusion Probabilistic\n    Models\" https://arxiv.org/abs/2006.11239.\n\n    Args:\n        timesteps: a 1-D Tensor of N indices, one per batch element.\n        embedding_dim: the dimension of the output.\n        max_period: controls the minimum frequency of the embeddings.\n    \"\"\"\n    if timesteps.ndim != 1:\n        raise ValueError(\"Timesteps should be a 1d-array\")\n\n    half_dim = embedding_dim // 2\n    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)\n    freqs = torch.exp(exponent / half_dim)\n\n    args = timesteps[:, None].float() * freqs[None, :]\n    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n\n    # zero pad\n    if embedding_dim % 2 == 1:\n        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))\n\n    return embedding\n\n\nclass DiffusionUnetDownsample(nn.Module):\n    \"\"\"\n    Downsampling layer.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        num_channels: number of input channels.\n        use_conv: if True uses Convolution instead of Pool average to perform downsampling. 
In case that use_conv is\n            False, the number of output channels must be the same as the number of input channels.\n        out_channels: number of output channels.\n        padding: controls the amount of implicit zero-paddings on both sides for padding number of points\n            for each dimension.\n    \"\"\"\n\n    def __init__(\n        self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1\n    ) -> None:\n        super().__init__()\n        self.num_channels = num_channels\n        self.out_channels = out_channels or num_channels\n        self.use_conv = use_conv\n        if use_conv:\n            self.op = Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=self.num_channels,\n                out_channels=self.out_channels,\n                strides=2,\n                kernel_size=3,\n                padding=padding,\n                conv_only=True,\n            )\n        else:\n            if self.num_channels != self.out_channels:\n                raise ValueError(\"num_channels and out_channels must be equal when use_conv=False\")\n            self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2)\n\n    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:\n        del emb\n        if x.shape[1] != self.num_channels:\n            raise ValueError(\n                f\"Input number of channels ({x.shape[1]}) is not equal to expected number of channels \"\n                f\"({self.num_channels})\"\n            )\n        output: torch.Tensor = self.op(x)\n        return output\n\n\nclass WrappedUpsample(Upsample):\n    \"\"\"\n    Wraps MONAI upsample block to allow for calling with timestep embeddings.\n    \"\"\"\n\n    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:\n        del emb\n        upsampled: torch.Tensor = super().forward(x)\n        return upsampled\n\n\nclass 
DiffusionUNetResnetBlock(nn.Module):\n    \"\"\"\n    Residual block with timestep conditioning.\n\n    Args:\n        spatial_dims: The number of spatial dimensions.\n        in_channels: number of input channels.\n        temb_channels: number of timestep embedding  channels.\n        out_channels: number of output channels.\n        up: if True, performs upsampling.\n        down: if True, performs downsampling.\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        temb_channels: int,\n        out_channels: int | None = None,\n        up: bool = False,\n        down: bool = False,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n    ) -> None:\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.channels = in_channels\n        self.emb_channels = temb_channels\n        self.out_channels = out_channels or in_channels\n        self.up = up\n        self.down = down\n\n        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)\n        self.nonlinearity = nn.SiLU()\n        self.conv1 = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=self.out_channels,\n            strides=1,\n            kernel_size=3,\n            padding=1,\n            conv_only=True,\n        )\n\n        self.upsample = self.downsample = None\n        if self.up:\n            self.upsample = WrappedUpsample(\n                spatial_dims=spatial_dims,\n                mode=\"nontrainable\",\n                in_channels=in_channels,\n                out_channels=in_channels,\n                interp_mode=\"nearest\",\n                scale_factor=2.0,\n                align_corners=None,\n            )\n        elif down:\n            
self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False)\n\n        self.time_emb_proj = nn.Linear(temb_channels, self.out_channels)\n\n        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True)\n        self.conv2 = zero_module(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=self.out_channels,\n                out_channels=self.out_channels,\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n            )\n        )\n        self.skip_connection: nn.Module\n        if self.out_channels == in_channels:\n            self.skip_connection = nn.Identity()\n        else:\n            self.skip_connection = Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=in_channels,\n                out_channels=self.out_channels,\n                strides=1,\n                kernel_size=1,\n                padding=0,\n                conv_only=True,\n            )\n\n    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:\n        h = x\n        h = self.norm1(h)\n        h = self.nonlinearity(h)\n\n        if self.upsample is not None:\n            x = self.upsample(x)\n            h = self.upsample(h)\n        elif self.downsample is not None:\n            x = self.downsample(x)\n            h = self.downsample(h)\n\n        h = self.conv1(h)\n\n        if self.spatial_dims == 2:\n            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None]\n        else:\n            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None]\n        h = h + temb\n\n        h = self.norm2(h)\n        h = self.nonlinearity(h)\n        h = self.conv2(h)\n        output: torch.Tensor = self.skip_connection(x) + h\n        return output\n\n\nclass DownBlock(nn.Module):\n    \"\"\"\n    Unet's down block 
containing resnet and downsamplers blocks.\n\n    Args:\n        spatial_dims: The number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        temb_channels: number of timestep embedding channels.\n        num_res_blocks: number of residual blocks.\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n        add_downsample: if True add downsample block.\n        resblock_updown: if True use residual blocks for downsampling.\n        downsample_padding: padding used in the downsampling block.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        num_res_blocks: int = 1,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        add_downsample: bool = True,\n        resblock_updown: bool = False,\n        downsample_padding: int = 1,\n    ) -> None:\n        super().__init__()\n        self.resblock_updown = resblock_updown\n\n        resnets = []\n\n        for i in range(num_res_blocks):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsampler: nn.Module | None\n            if resblock_updown:\n                self.downsampler = DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    
out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    down=True,\n                )\n            else:\n                self.downsampler = DiffusionUnetDownsample(\n                    spatial_dims=spatial_dims,\n                    num_channels=out_channels,\n                    use_conv=True,\n                    out_channels=out_channels,\n                    padding=downsample_padding,\n                )\n        else:\n            self.downsampler = None\n\n    def forward(\n        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None\n    ) -> tuple[torch.Tensor, list[torch.Tensor]]:\n        del context\n        output_states = []\n\n        for resnet in self.resnets:\n            hidden_states = resnet(hidden_states, temb)\n            output_states.append(hidden_states)\n\n        if self.downsampler is not None:\n            hidden_states = self.downsampler(hidden_states, temb)\n            output_states.append(hidden_states)\n\n        return hidden_states, output_states\n\n\nclass AttnDownBlock(nn.Module):\n    \"\"\"\n    Unet's down block containing resnet, downsamplers and self-attention blocks.\n\n    Args:\n        spatial_dims: The number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        temb_channels: number of timestep embedding  channels.\n        num_res_blocks: number of residual blocks.\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n        add_downsample: if True add downsample block.\n        resblock_updown: if True use residual blocks for downsampling.\n        downsample_padding: padding used in the downsampling block.\n        num_head_channels: number of channels in each attention head.\n        
include_fc: whether to include the final linear layer. Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        num_res_blocks: int = 1,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        add_downsample: bool = True,\n        resblock_updown: bool = False,\n        downsample_padding: int = 1,\n        num_head_channels: int = 1,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        self.resblock_updown = resblock_updown\n\n        resnets = []\n        attentions = []\n\n        for i in range(num_res_blocks):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                )\n            )\n            attentions.append(\n                SpatialAttentionBlock(\n                    spatial_dims=spatial_dims,\n                    num_channels=out_channels,\n                    num_head_channels=num_head_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    include_fc=include_fc,\n                    
use_combined_linear=use_combined_linear,\n                    use_flash_attention=use_flash_attention,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.downsampler: nn.Module | None\n        if add_downsample:\n            if resblock_updown:\n                self.downsampler = DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    down=True,\n                )\n            else:\n                self.downsampler = DiffusionUnetDownsample(\n                    spatial_dims=spatial_dims,\n                    num_channels=out_channels,\n                    use_conv=True,\n                    out_channels=out_channels,\n                    padding=downsample_padding,\n                )\n        else:\n            self.downsampler = None\n\n    def forward(\n        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None\n    ) -> tuple[torch.Tensor, list[torch.Tensor]]:\n        del context\n        output_states = []\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            hidden_states = resnet(hidden_states, temb)\n            hidden_states = attn(hidden_states).contiguous()\n            output_states.append(hidden_states)\n\n        if self.downsampler is not None:\n            hidden_states = self.downsampler(hidden_states, temb)\n            output_states.append(hidden_states)\n\n        return hidden_states, output_states\n\n\nclass CrossAttnDownBlock(nn.Module):\n    \"\"\"\n    Unet's down block containing resnet, downsamplers and cross-attention blocks.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        
in_channels: number of input channels.\n        out_channels: number of output channels.\n        temb_channels: number of timestep embedding channels.\n        num_res_blocks: number of residual blocks.\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n        add_downsample: if True add downsample block.\n        resblock_updown: if True use residual blocks for downsampling.\n        downsample_padding: padding used in the downsampling block.\n        num_head_channels: number of channels in each attention head.\n        transformer_num_layers: number of layers of Transformer blocks to use.\n        cross_attention_dim: number of context dimensions to use.\n        upcast_attention: if True, upcast attention operations to full precision.\n        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        num_res_blocks: int = 1,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        add_downsample: bool = True,\n        resblock_updown: bool = False,\n        downsample_padding: int = 1,\n        num_head_channels: int = 1,\n        transformer_num_layers: int = 1,\n        cross_attention_dim: int | None = None,\n        upcast_attention: bool = False,\n        dropout_cattn: float = 0.0,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        self.resblock_updown = resblock_updown\n\n        resnets = []\n        attentions = []\n\n        for i in range(num_res_blocks):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                )\n            )\n\n            attentions.append(\n                SpatialTransformer(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    num_attention_heads=out_channels // num_head_channels,\n                    
num_head_channels=num_head_channels,\n                    num_layers=transformer_num_layers,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    cross_attention_dim=cross_attention_dim,\n                    upcast_attention=upcast_attention,\n                    dropout=dropout_cattn,\n                    include_fc=include_fc,\n                    use_combined_linear=use_combined_linear,\n                    use_flash_attention=use_flash_attention,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.downsampler: nn.Module | None\n        if add_downsample:\n            if resblock_updown:\n                self.downsampler = DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    down=True,\n                )\n            else:\n                self.downsampler = DiffusionUnetDownsample(\n                    spatial_dims=spatial_dims,\n                    num_channels=out_channels,\n                    use_conv=True,\n                    out_channels=out_channels,\n                    padding=downsample_padding,\n                )\n        else:\n            self.downsampler = None\n\n    def forward(\n        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None\n    ) -> tuple[torch.Tensor, list[torch.Tensor]]:\n        output_states = []\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            hidden_states = resnet(hidden_states, temb)\n            hidden_states = attn(hidden_states, context=context).contiguous()\n            output_states.append(hidden_states)\n\n      
  if self.downsampler is not None:\n            hidden_states = self.downsampler(hidden_states, temb)\n            output_states.append(hidden_states)\n\n        return hidden_states, output_states\n\n\nclass AttnMidBlock(nn.Module):\n    \"\"\"\n    Unet's mid block containing resnet and self-attention blocks.\n\n    Args:\n        spatial_dims: The number of spatial dimensions.\n        in_channels: number of input channels.\n        temb_channels: number of timestep embedding channels.\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n        num_head_channels: number of channels in each attention head.\n        include_fc: whether to include the final linear layer. Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        temb_channels: int,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        num_head_channels: int = 1,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n\n        self.resnet_1 = DiffusionUNetResnetBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=in_channels,\n            temb_channels=temb_channels,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n        )\n        self.attention = SpatialAttentionBlock(\n            spatial_dims=spatial_dims,\n            num_channels=in_channels,\n            num_head_channels=num_head_channels,\n    
        norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n\n        self.resnet_2 = DiffusionUNetResnetBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=in_channels,\n            temb_channels=temb_channels,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n        )\n\n    def forward(\n        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None\n    ) -> torch.Tensor:\n        del context\n        hidden_states = self.resnet_1(hidden_states, temb)\n        hidden_states = self.attention(hidden_states).contiguous()\n        hidden_states = self.resnet_2(hidden_states, temb)\n\n        return hidden_states\n\n\nclass CrossAttnMidBlock(nn.Module):\n    \"\"\"\n    Unet's mid block containing resnet and cross-attention blocks.\n\n    Args:\n        spatial_dims: The number of spatial dimensions.\n        in_channels: number of input channels.\n        temb_channels: number of timestep embedding channels\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n        num_head_channels: number of channels in each attention head.\n        transformer_num_layers: number of layers of Transformer blocks to use.\n        cross_attention_dim: number of context dimensions to use.\n        upcast_attention: if True, upcast attention operations to full precision.\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        temb_channels: int,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        num_head_channels: int = 1,\n        transformer_num_layers: int = 1,\n        cross_attention_dim: int | None = None,\n        upcast_attention: bool = False,\n        dropout_cattn: float = 0.0,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n\n        self.resnet_1 = DiffusionUNetResnetBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=in_channels,\n            temb_channels=temb_channels,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n        )\n        self.attention = SpatialTransformer(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            num_attention_heads=in_channels // num_head_channels,\n            num_head_channels=num_head_channels,\n            num_layers=transformer_num_layers,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            cross_attention_dim=cross_attention_dim,\n            upcast_attention=upcast_attention,\n            dropout=dropout_cattn,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n        self.resnet_2 = DiffusionUNetResnetBlock(\n            spatial_dims=spatial_dims,\n            
in_channels=in_channels,\n            out_channels=in_channels,\n            temb_channels=temb_channels,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n        )\n\n    def forward(\n        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None\n    ) -> torch.Tensor:\n        hidden_states = self.resnet_1(hidden_states, temb)\n        hidden_states = self.attention(hidden_states, context=context)\n        hidden_states = self.resnet_2(hidden_states, temb)\n\n        return hidden_states\n\n\nclass UpBlock(nn.Module):\n    \"\"\"\n    Unet's up block containing resnet and upsamplers blocks.\n\n    Args:\n        spatial_dims: The number of spatial dimensions.\n        in_channels: number of input channels.\n        prev_output_channel: number of channels from residual connection.\n        out_channels: number of output channels.\n        temb_channels: number of timestep embedding channels.\n        num_res_blocks: number of residual blocks.\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n        add_upsample: if True add downsample block.\n        resblock_updown: if True use residual blocks for upsampling.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        num_res_blocks: int = 1,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        add_upsample: bool = True,\n        resblock_updown: bool = False,\n    ) -> None:\n        super().__init__()\n        self.resblock_updown = resblock_updown\n        resnets = []\n\n        for i in range(num_res_blocks):\n            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            
resnets.append(\n                DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        self.upsampler: nn.Module | None\n        if add_upsample:\n            if resblock_updown:\n                self.upsampler = DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    up=True,\n                )\n            else:\n                post_conv = Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    strides=1,\n                    kernel_size=3,\n                    padding=1,\n                    conv_only=True,\n                )\n                self.upsampler = WrappedUpsample(\n                    spatial_dims=spatial_dims,\n                    mode=\"nontrainable\",\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    interp_mode=\"nearest\",\n                    scale_factor=2.0,\n                    post_conv=post_conv,\n                    align_corners=None,\n                )\n\n        else:\n            self.upsampler = None\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        res_hidden_states_list: list[torch.Tensor],\n        temb: torch.Tensor,\n        context: torch.Tensor | None = None,\n    ) 
-> torch.Tensor:\n        del context\n        for resnet in self.resnets:\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_list[-1]\n            res_hidden_states_list = res_hidden_states_list[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            hidden_states = resnet(hidden_states, temb)\n\n        if self.upsampler is not None:\n            hidden_states = self.upsampler(hidden_states, temb)\n\n        return hidden_states\n\n\nclass AttnUpBlock(nn.Module):\n    \"\"\"\n    Unet's up block containing resnet, upsamplers, and self-attention blocks.\n\n    Args:\n        spatial_dims: The number of spatial dimensions.\n        in_channels: number of input channels.\n        prev_output_channel: number of channels from residual connection.\n        out_channels: number of output channels.\n        temb_channels: number of timestep embedding channels.\n        num_res_blocks: number of residual blocks.\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n        add_upsample: if True add downsample block.\n        resblock_updown: if True use residual blocks for upsampling.\n        num_head_channels: number of channels in each attention head.\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        num_res_blocks: int = 1,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        add_upsample: bool = True,\n        resblock_updown: bool = False,\n        num_head_channels: int = 1,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        self.resblock_updown = resblock_updown\n\n        resnets = []\n        attentions = []\n\n        for i in range(num_res_blocks):\n            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                )\n            )\n            attentions.append(\n                SpatialAttentionBlock(\n                    spatial_dims=spatial_dims,\n                    num_channels=out_channels,\n                    num_head_channels=num_head_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n            
        include_fc=include_fc,\n                    use_combined_linear=use_combined_linear,\n                    use_flash_attention=use_flash_attention,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n        self.attentions = nn.ModuleList(attentions)\n\n        self.upsampler: nn.Module | None\n        if add_upsample:\n            if resblock_updown:\n                self.upsampler = DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    up=True,\n                )\n            else:\n\n                post_conv = Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    strides=1,\n                    kernel_size=3,\n                    padding=1,\n                    conv_only=True,\n                )\n                self.upsampler = WrappedUpsample(\n                    spatial_dims=spatial_dims,\n                    mode=\"nontrainable\",\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    interp_mode=\"nearest\",\n                    scale_factor=2.0,\n                    post_conv=post_conv,\n                    align_corners=None,\n                )\n        else:\n            self.upsampler = None\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        res_hidden_states_list: list[torch.Tensor],\n        temb: torch.Tensor,\n        context: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        del context\n        for resnet, attn in zip(self.resnets, self.attentions):\n            # pop res hidden states\n            
res_hidden_states = res_hidden_states_list[-1]\n            res_hidden_states_list = res_hidden_states_list[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            hidden_states = resnet(hidden_states, temb)\n            hidden_states = attn(hidden_states).contiguous()\n\n        if self.upsampler is not None:\n            hidden_states = self.upsampler(hidden_states, temb)\n\n        return hidden_states\n\n\nclass CrossAttnUpBlock(nn.Module):\n    \"\"\"\n    Unet's up block containing resnet, upsamplers, and self-attention blocks.\n\n    Args:\n        spatial_dims: The number of spatial dimensions.\n        in_channels: number of input channels.\n        prev_output_channel: number of channels from residual connection.\n        out_channels: number of output channels.\n        temb_channels: number of timestep embedding channels.\n        num_res_blocks: number of residual blocks.\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n        add_upsample: if True add upsample block.\n        resblock_updown: if True use residual blocks for upsampling.\n        num_head_channels: number of channels in each attention head.\n        transformer_num_layers: number of layers of Transformer blocks to use.\n        cross_attention_dim: number of context dimensions to use.\n        upcast_attention: if True, upcast attention operations to full precision.\n        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        num_res_blocks: int = 1,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        add_upsample: bool = True,\n        resblock_updown: bool = False,\n        num_head_channels: int = 1,\n        transformer_num_layers: int = 1,\n        cross_attention_dim: int | None = None,\n        upcast_attention: bool = False,\n        dropout_cattn: float = 0.0,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        self.resblock_updown = resblock_updown\n\n        resnets = []\n        attentions = []\n\n        for i in range(num_res_blocks):\n            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                )\n            )\n            attentions.append(\n                SpatialTransformer(\n                    spatial_dims=spatial_dims,\n                    
in_channels=out_channels,\n                    num_attention_heads=out_channels // num_head_channels,\n                    num_head_channels=num_head_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    num_layers=transformer_num_layers,\n                    cross_attention_dim=cross_attention_dim,\n                    upcast_attention=upcast_attention,\n                    dropout=dropout_cattn,\n                    include_fc=include_fc,\n                    use_combined_linear=use_combined_linear,\n                    use_flash_attention=use_flash_attention,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.upsampler: nn.Module | None\n        if add_upsample:\n            if resblock_updown:\n                self.upsampler = DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    up=True,\n                )\n            else:\n\n                post_conv = Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    strides=1,\n                    kernel_size=3,\n                    padding=1,\n                    conv_only=True,\n                )\n                self.upsampler = WrappedUpsample(\n                    spatial_dims=spatial_dims,\n                    mode=\"nontrainable\",\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    interp_mode=\"nearest\",\n                    scale_factor=2.0,\n                    
post_conv=post_conv,\n                    align_corners=None,\n                )\n        else:\n            self.upsampler = None\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        res_hidden_states_list: list[torch.Tensor],\n        temb: torch.Tensor,\n        context: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        for resnet, attn in zip(self.resnets, self.attentions):\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_list[-1]\n            res_hidden_states_list = res_hidden_states_list[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            hidden_states = resnet(hidden_states, temb)\n            hidden_states = attn(hidden_states, context=context)\n\n        if self.upsampler is not None:\n            hidden_states = self.upsampler(hidden_states, temb)\n\n        return hidden_states\n\n\ndef get_down_block(\n    spatial_dims: int,\n    in_channels: int,\n    out_channels: int,\n    temb_channels: int,\n    num_res_blocks: int,\n    norm_num_groups: int,\n    norm_eps: float,\n    add_downsample: bool,\n    resblock_updown: bool,\n    with_attn: bool,\n    with_cross_attn: bool,\n    num_head_channels: int,\n    transformer_num_layers: int,\n    cross_attention_dim: int | None,\n    upcast_attention: bool = False,\n    dropout_cattn: float = 0.0,\n    include_fc: bool = True,\n    use_combined_linear: bool = False,\n    use_flash_attention: bool = False,\n) -> nn.Module:\n    if with_attn:\n        return AttnDownBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            add_downsample=add_downsample,\n            resblock_updown=resblock_updown,\n            
num_head_channels=num_head_channels,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n    elif with_cross_attn:\n        return CrossAttnDownBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            add_downsample=add_downsample,\n            resblock_updown=resblock_updown,\n            num_head_channels=num_head_channels,\n            transformer_num_layers=transformer_num_layers,\n            cross_attention_dim=cross_attention_dim,\n            upcast_attention=upcast_attention,\n            dropout_cattn=dropout_cattn,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n    else:\n        return DownBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            add_downsample=add_downsample,\n            resblock_updown=resblock_updown,\n        )\n\n\ndef get_mid_block(\n    spatial_dims: int,\n    in_channels: int,\n    temb_channels: int,\n    norm_num_groups: int,\n    norm_eps: float,\n    with_conditioning: bool,\n    num_head_channels: int,\n    transformer_num_layers: int,\n    cross_attention_dim: int | None,\n    upcast_attention: bool = False,\n    dropout_cattn: float = 0.0,\n    include_fc: bool = True,\n    use_combined_linear: bool = False,\n    use_flash_attention: bool = False,\n) -> nn.Module:\n    if with_conditioning:\n        return CrossAttnMidBlock(\n       
     spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            temb_channels=temb_channels,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            num_head_channels=num_head_channels,\n            transformer_num_layers=transformer_num_layers,\n            cross_attention_dim=cross_attention_dim,\n            upcast_attention=upcast_attention,\n            dropout_cattn=dropout_cattn,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n    else:\n        return AttnMidBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            temb_channels=temb_channels,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            num_head_channels=num_head_channels,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n\n\ndef get_up_block(\n    spatial_dims: int,\n    in_channels: int,\n    prev_output_channel: int,\n    out_channels: int,\n    temb_channels: int,\n    num_res_blocks: int,\n    norm_num_groups: int,\n    norm_eps: float,\n    add_upsample: bool,\n    resblock_updown: bool,\n    with_attn: bool,\n    with_cross_attn: bool,\n    num_head_channels: int,\n    transformer_num_layers: int,\n    cross_attention_dim: int | None,\n    upcast_attention: bool = False,\n    dropout_cattn: float = 0.0,\n    include_fc: bool = True,\n    use_combined_linear: bool = False,\n    use_flash_attention: bool = False,\n) -> nn.Module:\n    if with_attn:\n        return AttnUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            prev_output_channel=prev_output_channel,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            num_res_blocks=num_res_blocks,\n            
norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            add_upsample=add_upsample,\n            resblock_updown=resblock_updown,\n            num_head_channels=num_head_channels,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n    elif with_cross_attn:\n        return CrossAttnUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            prev_output_channel=prev_output_channel,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            add_upsample=add_upsample,\n            resblock_updown=resblock_updown,\n            num_head_channels=num_head_channels,\n            transformer_num_layers=transformer_num_layers,\n            cross_attention_dim=cross_attention_dim,\n            upcast_attention=upcast_attention,\n            dropout_cattn=dropout_cattn,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n    else:\n        return UpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            prev_output_channel=prev_output_channel,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            add_upsample=add_upsample,\n            resblock_updown=resblock_updown,\n        )\n\n\nclass DiffusionModelUNet(nn.Module):\n    \"\"\"\n    Unet network with timestep embedding and attention mechanisms for conditioning based on\n    Rombach et al. 
\"High-Resolution Image Synthesis with Latent Diffusion Models\" https://arxiv.org/abs/2112.10752\n    and Pinaya et al. \"Brain Imaging Generation with Latent Diffusion Models\" https://arxiv.org/abs/2209.07162\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        num_res_blocks: number of residual blocks (see _ResnetBlock) per level.\n        channels: tuple of block output channels.\n        attention_levels: list of levels to add attention.\n        norm_num_groups: number of groups for the normalization.\n        norm_eps: epsilon for the normalization.\n        resblock_updown: if True use residual blocks for up/downsampling.\n        num_head_channels: number of channels in each attention head.\n        with_conditioning: if True add spatial transformers to perform conditioning.\n        transformer_num_layers: number of layers of Transformer blocks to use.\n        cross_attention_dim: number of context dimensions to use.\n        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`\n            classes.\n        upcast_attention: if True, upcast attention operations to full precision.\n        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html), default to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),\n        channels: Sequence[int] = (32, 64, 64, 64),\n        attention_levels: Sequence[bool] = (False, False, True, True),\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        resblock_updown: bool = False,\n        num_head_channels: int | Sequence[int] = 8,\n        with_conditioning: bool = False,\n        transformer_num_layers: int = 1,\n        cross_attention_dim: int | None = None,\n        num_class_embeds: int | None = None,\n        upcast_attention: bool = False,\n        dropout_cattn: float = 0.0,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        if with_conditioning is True and cross_attention_dim is None:\n            raise ValueError(\n                \"DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) \"\n                \"when using with_conditioning.\"\n            )\n        if cross_attention_dim is not None and with_conditioning is False:\n            raise ValueError(\n                \"DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim.\"\n            )\n        if dropout_cattn > 1.0 or dropout_cattn < 0.0:\n            raise ValueError(\"Dropout cannot be negative or >1.0!\")\n\n        # All number of channels should be multiple of num_groups\n        
if any((out_channel % norm_num_groups) != 0 for out_channel in channels):\n            raise ValueError(\"DiffusionModelUNet expects all num_channels being multiple of norm_num_groups\")\n\n        if len(channels) != len(attention_levels):\n            raise ValueError(\"DiffusionModelUNet expects num_channels being same size of attention_levels\")\n\n        if isinstance(num_head_channels, int):\n            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))\n\n        if len(num_head_channels) != len(attention_levels):\n            raise ValueError(\n                \"num_head_channels should have the same length as attention_levels. For the i levels without attention,\"\n                \" i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored.\"\n            )\n\n        if isinstance(num_res_blocks, int):\n            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))\n\n        if len(num_res_blocks) != len(channels):\n            raise ValueError(\n                \"`num_res_blocks` should be a single integer or a tuple of integers with the same length as \"\n                \"`num_channels`.\"\n            )\n\n        self.in_channels = in_channels\n        self.block_out_channels = channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_levels = attention_levels\n        self.num_head_channels = num_head_channels\n        self.with_conditioning = with_conditioning\n\n        # input\n        self.conv_in = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=channels[0],\n            strides=1,\n            kernel_size=3,\n            padding=1,\n            conv_only=True,\n        )\n\n        # time\n        time_embed_dim = channels[0] * 4\n        self.time_embed = nn.Sequential(\n            nn.Linear(channels[0], time_embed_dim), nn.SiLU(), 
nn.Linear(time_embed_dim, time_embed_dim)\n        )\n\n        # class embedding\n        self.num_class_embeds = num_class_embeds\n        if num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n\n        # down\n        self.down_blocks = nn.ModuleList([])\n        output_channel = channels[0]\n        for i in range(len(channels)):\n            input_channel = output_channel\n            output_channel = channels[i]\n            is_final_block = i == len(channels) - 1\n\n            down_block = get_down_block(\n                spatial_dims=spatial_dims,\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=time_embed_dim,\n                num_res_blocks=num_res_blocks[i],\n                norm_num_groups=norm_num_groups,\n                norm_eps=norm_eps,\n                add_downsample=not is_final_block,\n                resblock_updown=resblock_updown,\n                with_attn=(attention_levels[i] and not with_conditioning),\n                with_cross_attn=(attention_levels[i] and with_conditioning),\n                num_head_channels=num_head_channels[i],\n                transformer_num_layers=transformer_num_layers,\n                cross_attention_dim=cross_attention_dim,\n                upcast_attention=upcast_attention,\n                dropout_cattn=dropout_cattn,\n                include_fc=include_fc,\n                use_combined_linear=use_combined_linear,\n                use_flash_attention=use_flash_attention,\n            )\n\n            self.down_blocks.append(down_block)\n\n        # mid\n        self.middle_block = get_mid_block(\n            spatial_dims=spatial_dims,\n            in_channels=channels[-1],\n            temb_channels=time_embed_dim,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            with_conditioning=with_conditioning,\n            
num_head_channels=num_head_channels[-1],\n            transformer_num_layers=transformer_num_layers,\n            cross_attention_dim=cross_attention_dim,\n            upcast_attention=upcast_attention,\n            dropout_cattn=dropout_cattn,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n\n        # up\n        self.up_blocks = nn.ModuleList([])\n        reversed_block_out_channels = list(reversed(channels))\n        reversed_num_res_blocks = list(reversed(num_res_blocks))\n        reversed_attention_levels = list(reversed(attention_levels))\n        reversed_num_head_channels = list(reversed(num_head_channels))\n        output_channel = reversed_block_out_channels[0]\n        for i in range(len(reversed_block_out_channels)):\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)]\n\n            is_final_block = i == len(channels) - 1\n\n            up_block = get_up_block(\n                spatial_dims=spatial_dims,\n                in_channels=input_channel,\n                prev_output_channel=prev_output_channel,\n                out_channels=output_channel,\n                temb_channels=time_embed_dim,\n                num_res_blocks=reversed_num_res_blocks[i] + 1,\n                norm_num_groups=norm_num_groups,\n                norm_eps=norm_eps,\n                add_upsample=not is_final_block,\n                resblock_updown=resblock_updown,\n                with_attn=(reversed_attention_levels[i] and not with_conditioning),\n                with_cross_attn=(reversed_attention_levels[i] and with_conditioning),\n                num_head_channels=reversed_num_head_channels[i],\n                transformer_num_layers=transformer_num_layers,\n                cross_attention_dim=cross_attention_dim,\n        
        upcast_attention=upcast_attention,\n                dropout_cattn=dropout_cattn,\n                include_fc=include_fc,\n                use_combined_linear=use_combined_linear,\n                use_flash_attention=use_flash_attention,\n            )\n\n            self.up_blocks.append(up_block)\n\n        # out\n        self.out = nn.Sequential(\n            nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True),\n            nn.SiLU(),\n            zero_module(\n                Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=channels[0],\n                    out_channels=out_channels,\n                    strides=1,\n                    kernel_size=3,\n                    padding=1,\n                    conv_only=True,\n                )\n            ),\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        timesteps: torch.Tensor,\n        context: torch.Tensor | None = None,\n        class_labels: torch.Tensor | None = None,\n        down_block_additional_residuals: tuple[torch.Tensor] | None = None,\n        mid_block_additional_residual: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: input tensor (N, C, SpatialDims).\n            timesteps: timestep tensor (N,).\n            context: context tensor (N, 1, ContextDim).\n            class_labels: context tensor (N, ).\n            down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).\n            mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).\n        \"\"\"\n        # 1. time\n        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. 
so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=x.dtype)\n        emb = self.time_embed(t_emb)\n\n        # 2. class\n        if self.num_class_embeds is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n            class_emb = self.class_embedding(class_labels)\n            class_emb = class_emb.to(dtype=x.dtype)\n            emb = emb + class_emb\n\n        # 3. initial convolution\n        h = self.conv_in(x)\n\n        # 4. down\n        if context is not None and self.with_conditioning is False:\n            raise ValueError(\"model should have with_conditioning = True if context is provided\")\n        down_block_res_samples: list[torch.Tensor] = [h]\n        for downsample_block in self.down_blocks:\n            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)\n            for residual in res_samples:\n                down_block_res_samples.append(residual)\n\n        # Additional residual connections for Controlnets\n        if down_block_additional_residuals is not None:\n            new_down_block_res_samples: list[torch.Tensor] = []\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples += [down_block_res_sample]\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # 5. mid\n        h = self.middle_block(hidden_states=h, temb=emb, context=context)\n\n        # Additional residual connections for Controlnets\n        if mid_block_additional_residual is not None:\n            h = h + mid_block_additional_residual\n\n        # 6. 
up\n        for upsample_block in self.up_blocks:\n            idx: int = -len(upsample_block.resnets)  # type: ignore\n            res_samples = down_block_res_samples[idx:]\n            down_block_res_samples = down_block_res_samples[:idx]\n            h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)\n\n        # 7. output block\n        output: torch.Tensor = self.out(h)\n\n        return output\n\n    def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:\n        \"\"\"\n        Load a state dict from a DiffusionModelUNet trained with\n        [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).\n\n        Args:\n            old_state_dict: state dict from the old DecoderOnlyTransformer  model.\n        \"\"\"\n\n        new_state_dict = self.state_dict()\n        # if all keys match, just load the state dict\n        if all(k in new_state_dict for k in old_state_dict):\n            print(\"All keys match, loading state dict.\")\n            self.load_state_dict(old_state_dict)\n            return\n\n        if verbose:\n            # print all new_state_dict keys that are not in old_state_dict\n            for k in new_state_dict:\n                if k not in old_state_dict:\n                    print(f\"key {k} not found in old state dict\")\n            # and vice versa\n            print(\"----------------------------------------------\")\n            for k in old_state_dict:\n                if k not in new_state_dict:\n                    print(f\"key {k} not found in new state dict\")\n\n        # copy over all matching keys\n        for k in new_state_dict:\n            if k in old_state_dict:\n                new_state_dict[k] = old_state_dict.pop(k)\n\n        # fix the attention blocks\n        attention_blocks = [k.replace(\".attn.to_k.weight\", \"\") for k in new_state_dict if \"attn.to_k.weight\" in k]\n        for block in attention_blocks:\n            
new_state_dict[f\"{block}.attn.to_q.weight\"] = old_state_dict.pop(f\"{block}.to_q.weight\")\n            new_state_dict[f\"{block}.attn.to_k.weight\"] = old_state_dict.pop(f\"{block}.to_k.weight\")\n            new_state_dict[f\"{block}.attn.to_v.weight\"] = old_state_dict.pop(f\"{block}.to_v.weight\")\n            new_state_dict[f\"{block}.attn.to_q.bias\"] = old_state_dict.pop(f\"{block}.to_q.bias\")\n            new_state_dict[f\"{block}.attn.to_k.bias\"] = old_state_dict.pop(f\"{block}.to_k.bias\")\n            new_state_dict[f\"{block}.attn.to_v.bias\"] = old_state_dict.pop(f\"{block}.to_v.bias\")\n\n            # projection\n            if f\"{block}.attn.out_proj.weight\" in new_state_dict and f\"{block}.attn.out_proj.bias\" in new_state_dict:\n                new_state_dict[f\"{block}.attn.out_proj.weight\"] = old_state_dict.pop(f\"{block}.proj_attn.weight\")\n                new_state_dict[f\"{block}.attn.out_proj.bias\"] = old_state_dict.pop(f\"{block}.proj_attn.bias\")\n        # fix the cross attention blocks\n        cross_attention_blocks = [\n            k.replace(\".out_proj.weight\", \"\")\n            for k in new_state_dict\n            if \"out_proj.weight\" in k and \"transformer_blocks\" in k\n        ]\n        for block in cross_attention_blocks:\n            new_state_dict[f\"{block}.out_proj.weight\"] = old_state_dict.pop(f\"{block}.to_out.0.weight\")\n            new_state_dict[f\"{block}.out_proj.bias\"] = old_state_dict.pop(f\"{block}.to_out.0.bias\")\n\n        # fix the upsample conv blocks which were renamed postconv\n        for k in new_state_dict:\n            if \"postconv\" in k:\n                old_name = k.replace(\"postconv\", \"conv\")\n                new_state_dict[k] = old_state_dict.pop(old_name)\n        if verbose:\n            # print all remaining keys in old_state_dict\n            print(\"remaining keys in old_state_dict:\", old_state_dict.keys())\n        self.load_state_dict(new_state_dict)\n\n\nclass 
DiffusionModelEncoder(nn.Module):\n    \"\"\"\n    Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This network is based on\n    Wolleb et al. \"Diffusion Models for Medical Anomaly Detection\" (https://arxiv.org/abs/2203.04306).\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        input_shape: spatial shape of the input (without batch and channel dims).\n        num_res_blocks: number of residual blocks (see _ResnetBlock) per level.\n        channels: tuple of block output channels.\n        attention_levels: list of levels to add attention.\n        norm_num_groups: number of groups for the normalization.\n        norm_eps: epsilon for the normalization.\n        resblock_updown: if True use residual blocks for downsampling.\n        num_head_channels: number of channels in each attention head.\n        with_conditioning: if True add spatial transformers to perform conditioning.\n        transformer_num_layers: number of layers of Transformer blocks to use.\n        cross_attention_dim: number of context dimensions to use.\n        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.\n        upcast_attention: if True, upcast attention operations to full precision.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        input_shape: Sequence[int] = (64, 64),\n        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),\n        channels: Sequence[int] = (32, 64, 64, 64),\n        attention_levels: Sequence[bool] = (False, False, True, True),\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        resblock_updown: bool = False,\n        num_head_channels: int | Sequence[int] = 8,\n        with_conditioning: bool = False,\n  
      transformer_num_layers: int = 1,\n        cross_attention_dim: int | None = None,\n        num_class_embeds: int | None = None,\n        upcast_attention: bool = False,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        if with_conditioning is True and cross_attention_dim is None:\n            raise ValueError(\n                \"DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) \"\n                \"when using with_conditioning.\"\n            )\n        if cross_attention_dim is not None and with_conditioning is False:\n            raise ValueError(\n                \"DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim.\"\n            )\n\n        # All number of channels should be multiple of num_groups\n        if any((out_channel % norm_num_groups) != 0 for out_channel in channels):\n            raise ValueError(\"DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups\")\n        if len(channels) != len(attention_levels):\n            raise ValueError(\"DiffusionModelEncoder expects num_channels being same size of attention_levels\")\n\n        if isinstance(num_head_channels, int):\n            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))\n\n        if isinstance(num_res_blocks, int):\n            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))\n\n        if len(num_head_channels) != len(attention_levels):\n            raise ValueError(\n                \"num_head_channels should have the same length as attention_levels. For the i levels without attention,\"\n                \" i.e. 
`attention_level[i]=False`, the num_head_channels[i] will be ignored.\"\n            )\n\n        self.in_channels = in_channels\n        self.block_out_channels = channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_levels = attention_levels\n        self.num_head_channels = num_head_channels\n        self.with_conditioning = with_conditioning\n\n        # input\n        self.conv_in = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=channels[0],\n            strides=1,\n            kernel_size=3,\n            padding=1,\n            conv_only=True,\n        )\n\n        # time\n        time_embed_dim = channels[0] * 4\n        self.time_embed = nn.Sequential(\n            nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)\n        )\n\n        # class embedding\n        self.num_class_embeds = num_class_embeds\n        if num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n\n        # down\n        self.down_blocks = nn.ModuleList([])\n        output_channel = channels[0]\n        for i in range(len(channels)):\n            input_channel = output_channel\n            output_channel = channels[i]\n            is_final_block = i == len(channels)  # - 1\n\n            down_block = get_down_block(\n                spatial_dims=spatial_dims,\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=time_embed_dim,\n                num_res_blocks=num_res_blocks[i],\n                norm_num_groups=norm_num_groups,\n                norm_eps=norm_eps,\n                add_downsample=not is_final_block,\n                resblock_updown=resblock_updown,\n                with_attn=(attention_levels[i] and not with_conditioning),\n                
with_cross_attn=(attention_levels[i] and with_conditioning),\n                num_head_channels=num_head_channels[i],\n                transformer_num_layers=transformer_num_layers,\n                cross_attention_dim=cross_attention_dim,\n                upcast_attention=upcast_attention,\n                include_fc=include_fc,\n                use_combined_linear=use_combined_linear,\n                use_flash_attention=use_flash_attention,\n            )\n\n            self.down_blocks.append(down_block)\n\n        for _ in channels:\n            input_shape = [int(np.ceil(i_ / 2)) for i_ in input_shape]\n\n        last_dim_flattened = int(reduce(lambda x, y: x * y, input_shape) * channels[-1])\n\n        self.out: nn.Module = nn.Sequential(\n            nn.Linear(last_dim_flattened, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        timesteps: torch.Tensor,\n        context: torch.Tensor | None = None,\n        class_labels: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: input tensor (N, C, SpatialDims).\n            timesteps: timestep tensor (N,).\n            context: context tensor (N, 1, ContextDim).\n            class_labels: context tensor (N, ).\n        \"\"\"\n        # 1. time\n        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=x.dtype)\n        emb = self.time_embed(t_emb)\n\n        # 2. 
class\n        if self.num_class_embeds is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n            class_emb = self.class_embedding(class_labels)\n            class_emb = class_emb.to(dtype=x.dtype)\n            emb = emb + class_emb\n\n        # 3. initial convolution\n        h = self.conv_in(x)\n\n        # 4. down\n        if context is not None and self.with_conditioning is False:\n            raise ValueError(\"model should have with_conditioning = True if context is provided\")\n        for downsample_block in self.down_blocks:\n            h, _ = downsample_block(hidden_states=h, temb=emb, context=context)\n\n        h = h.reshape(h.shape[0], -1)\n\n        # 5. out\n        output: torch.Tensor = self.out(h)\n\n        return output\n"
  },
  {
    "path": "monai/networks/nets/dints.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport datetime\nimport warnings\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.blocks.dints_block import (\n    ActiConvNormBlock,\n    FactorizedIncreaseBlock,\n    FactorizedReduceBlock,\n    P3DActiConvNormBlock,\n)\nfrom monai.networks.layers.factories import Conv\nfrom monai.networks.layers.utils import get_act_layer, get_norm_layer\nfrom monai.utils import optional_import\n\n# solving shortest path problem\ncsr_matrix, _ = optional_import(\"scipy.sparse\", name=\"csr_matrix\")\ndijkstra, _ = optional_import(\"scipy.sparse.csgraph\", name=\"dijkstra\")\n\n__all__ = [\"DiNTS\", \"TopologyConstruction\", \"TopologyInstance\", \"TopologySearch\"]\n\n\n@torch.jit.interface\nclass CellInterface(torch.nn.Module):\n    \"\"\"interface for torchscriptable Cell\"\"\"\n\n    def forward(self, x: torch.Tensor, weight: torch.Tensor | None) -> torch.Tensor:  # type: ignore\n        pass\n\n\n@torch.jit.interface\nclass StemInterface(torch.nn.Module):\n    \"\"\"interface for torchscriptable Stem\"\"\"\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore\n        pass\n\n\nclass StemTS(StemInterface):\n    \"\"\"wrapper for torchscriptable Stem\"\"\"\n\n    def __init__(self, *mod):\n        super().__init__()\n        self.mod = 
torch.nn.Sequential(*mod)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.mod(x)  # type: ignore\n\n\ndef _dfs(node, paths):\n    \"\"\"use depth first search to find all path activation combination\"\"\"\n    if node == paths:\n        return [[0], [1]]\n    child = _dfs(node + 1, paths)\n    return [[0] + _ for _ in child] + [[1] + _ for _ in child]\n\n\nclass _IdentityWithRAMCost(nn.Identity):\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.ram_cost = 0\n\n\nclass _ActiConvNormBlockWithRAMCost(ActiConvNormBlock):\n    \"\"\"The class wraps monai layers with ram estimation. The ram_cost = total_ram/output_size is estimated.\n    Here is the estimation:\n     feature_size = output_size/out_channel\n     total_ram = ram_cost * output_size\n     total_ram = in_channel * feature_size (activation map) +\n                 in_channel * feature_size (convolution map) +\n                 out_channel * feature_size (normalization)\n               = (2*in_channel + out_channel) * output_size/out_channel\n     ram_cost = total_ram/output_size = 2 * in_channel/out_channel + 1\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channel: int,\n        out_channel: int,\n        kernel_size: int,\n        padding: int,\n        spatial_dims: int = 3,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n    ):\n        super().__init__(in_channel, out_channel, kernel_size, padding, spatial_dims, act_name, norm_name)\n        self.ram_cost = 1 + in_channel / out_channel * 2\n\n\nclass _P3DActiConvNormBlockWithRAMCost(P3DActiConvNormBlock):\n\n    def __init__(\n        self,\n        in_channel: int,\n        out_channel: int,\n        kernel_size: int,\n        padding: int,\n        p3dmode: int = 0,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n    ):\n        
super().__init__(in_channel, out_channel, kernel_size, padding, p3dmode, act_name, norm_name)\n        # 1 in_channel (activation) + 1 in_channel (convolution) +\n        # 1 out_channel (convolution) + 1 out_channel (normalization)\n        self.ram_cost = 2 + 2 * in_channel / out_channel\n\n\nclass _FactorizedIncreaseBlockWithRAMCost(FactorizedIncreaseBlock):\n\n    def __init__(\n        self,\n        in_channel: int,\n        out_channel: int,\n        spatial_dims: int = 3,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n    ):\n        super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name)\n        # s0 is upsampled 2x from s1, representing feature sizes at two resolutions.\n        # 2 * in_channel * s0 (upsample + activation) + 2 * out_channel * s0 (conv + normalization)\n        # s0 = output_size/out_channel\n        self.ram_cost = 2 * in_channel / out_channel + 2\n\n\nclass _FactorizedReduceBlockWithRAMCost(FactorizedReduceBlock):\n\n    def __init__(\n        self,\n        in_channel: int,\n        out_channel: int,\n        spatial_dims: int = 3,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n    ):\n        super().__init__(in_channel, out_channel, spatial_dims, act_name, norm_name)\n        # s0 is upsampled 2x from s1, representing feature sizes at two resolutions.\n        # in_channel * s0 (activation) + 3 * out_channel * s1 (convolution, concatenation, normalization)\n        # s0 = s1 * 2^(spatial_dims) = output_size / out_channel * 2^(spatial_dims)\n        self.ram_cost = in_channel / out_channel * 2**self._spatial_dims + 3\n\n\nclass MixedOp(nn.Module):\n    \"\"\"\n    The weighted averaging of cell operations.\n    Args:\n        c: number of output channels.\n        ops: a dictionary of operations. 
See also: ``Cell.OPS2D`` or ``Cell.OPS3D``.\n        arch_code_c: binary cell operation code. It represents the operation results added to the output.\n    \"\"\"\n\n    def __init__(self, c: int, ops: dict, arch_code_c=None):\n        super().__init__()\n        if arch_code_c is None:\n            arch_code_c = np.ones(len(ops))\n        self.ops = nn.ModuleList()\n        for arch_c, op_name in zip(arch_code_c, ops):\n            if arch_c > 0:\n                self.ops.append(ops[op_name](c))\n\n    def forward(self, x: torch.Tensor, weight: torch.Tensor | None = None):\n        \"\"\"\n        Args:\n            x: input tensor.\n            weight: learnable architecture weights for cell operations. arch_code_c are derived from it.\n        Return:\n            out: weighted average of the operation results.\n        \"\"\"\n        out = 0.0\n        if weight is not None:\n            weight = weight.to(x)\n        for idx, _op in enumerate(self.ops):\n            out = (out + _op(x)) if weight is None else out + _op(x) * weight[idx]\n        return out\n\n\nclass Cell(CellInterface):\n    \"\"\"\n    The basic class for cell operation search, which contains a preprocessing operation and a mixed cell operation.\n    Each cell is defined on a `path` in the topology search space.\n    Args:\n        c_prev: number of input channels\n        c: number of output channels\n        rate: resolution change rate. 
It represents the preprocessing operation before the mixed cell operation.\n            ``-1`` for 2x downsample, ``1`` for 2x upsample, ``0`` for no change of resolution.\n        arch_code_c: cell operation code\n    \"\"\"\n\n    DIRECTIONS = 3\n    # Possible output paths for `Cell`.\n    #\n    #       - UpSample\n    #      /\n    # +--+/\n    # |  |--- Identity or AlignChannels\n    # +--+\\\n    #      \\\n    #       - Downsample\n\n    # Define 2D operation set, parameterized by the number of channels\n    OPS2D = {\n        \"skip_connect\": lambda _c: _IdentityWithRAMCost(),\n        \"conv_3x3\": lambda c: _ActiConvNormBlockWithRAMCost(c, c, 3, padding=1, spatial_dims=2),\n    }\n\n    # Define 3D operation set, parameterized by the number of channels\n    OPS3D = {\n        \"skip_connect\": lambda _c: _IdentityWithRAMCost(),\n        \"conv_3x3x3\": lambda c: _ActiConvNormBlockWithRAMCost(c, c, 3, padding=1, spatial_dims=3),\n        \"conv_3x3x1\": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=0),\n        \"conv_3x1x3\": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=1),\n        \"conv_1x3x3\": lambda c: _P3DActiConvNormBlockWithRAMCost(c, c, 3, padding=1, p3dmode=2),\n    }\n\n    # Define connection operation set, parameterized by the number of channels\n    ConnOPS = {\n        \"up\": _FactorizedIncreaseBlockWithRAMCost,\n        \"down\": _FactorizedReduceBlockWithRAMCost,\n        \"identity\": _IdentityWithRAMCost,\n        \"align_channels\": _ActiConvNormBlockWithRAMCost,\n    }\n\n    def __init__(\n        self,\n        c_prev: int,\n        c: int,\n        rate: int,\n        arch_code_c=None,\n        spatial_dims: int = 3,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n    ):\n        super().__init__()\n        self._spatial_dims = spatial_dims\n        self._act_name = act_name\n        self._norm_name = 
norm_name\n\n        if rate == -1:  # downsample\n            self.preprocess = self.ConnOPS[\"down\"](\n                c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name\n            )\n        elif rate == 1:  # upsample\n            self.preprocess = self.ConnOPS[\"up\"](\n                c_prev, c, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name\n            )\n        else:\n            if c_prev == c:\n                self.preprocess = self.ConnOPS[\"identity\"]()\n            else:\n                self.preprocess = self.ConnOPS[\"align_channels\"](\n                    c_prev, c, 1, 0, spatial_dims=self._spatial_dims, act_name=self._act_name, norm_name=self._norm_name\n                )\n\n        # Define 2D operation set, parameterized by the number of channels\n        self.OPS2D = {\n            \"skip_connect\": lambda _c: _IdentityWithRAMCost(),\n            \"conv_3x3\": lambda c: _ActiConvNormBlockWithRAMCost(\n                c, c, 3, padding=1, spatial_dims=2, act_name=self._act_name, norm_name=self._norm_name\n            ),\n        }\n\n        # Define 3D operation set, parameterized by the number of channels\n        self.OPS3D = {\n            \"skip_connect\": lambda _c: _IdentityWithRAMCost(),\n            \"conv_3x3x3\": lambda c: _ActiConvNormBlockWithRAMCost(\n                c, c, 3, padding=1, spatial_dims=3, act_name=self._act_name, norm_name=self._norm_name\n            ),\n            \"conv_3x3x1\": lambda c: _P3DActiConvNormBlockWithRAMCost(\n                c, c, 3, padding=1, p3dmode=0, act_name=self._act_name, norm_name=self._norm_name\n            ),\n            \"conv_3x1x3\": lambda c: _P3DActiConvNormBlockWithRAMCost(\n                c, c, 3, padding=1, p3dmode=1, act_name=self._act_name, norm_name=self._norm_name\n            ),\n            \"conv_1x3x3\": lambda c: _P3DActiConvNormBlockWithRAMCost(\n                c, c, 3, padding=1, 
p3dmode=2, act_name=self._act_name, norm_name=self._norm_name\n            ),\n        }\n\n        self.OPS = {}\n        if self._spatial_dims == 2:\n            self.OPS = self.OPS2D\n        elif self._spatial_dims == 3:\n            self.OPS = self.OPS3D\n        else:\n            raise NotImplementedError(f\"Spatial dimensions {self._spatial_dims} is not supported.\")\n\n        self.op = MixedOp(c, self.OPS, arch_code_c)\n\n    def forward(self, x: torch.Tensor, weight: torch.Tensor | None) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: input tensor\n            weight: weights for different operations.\n        \"\"\"\n        x = self.preprocess(x)\n        x = self.op(x, weight)\n        return x\n\n\nclass DiNTS(nn.Module):\n    \"\"\"\n    Reimplementation of DiNTS based on\n    \"DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation\n    <https://arxiv.org/abs/2103.15954>\".\n\n    The model contains a pre-defined multi-resolution stem block (defined in this class) and a\n    DiNTS space (defined in :py:class:`monai.networks.nets.TopologyInstance` and\n    :py:class:`monai.networks.nets.TopologySearch`).\n\n    The stem block is for: 1) input downsample and 2) output upsample to original size.\n    The model downsamples the input image by 2 (if ``use_downsample=True``).\n    The downsampled image is downsampled by [1, 2, 4, 8] times (``num_depths=4``) and used as input to the\n    DiNTS search space (``TopologySearch``) or the DiNTS instance (``TopologyInstance``).\n\n        - ``TopologyInstance`` is the final searched model. 
The initialization requires the searched architecture codes.\n        - ``TopologySearch`` is a multi-path topology and cell operation search space.\n          The architecture codes will be initialized as one.\n        - ``TopologyConstruction`` is the parent class which constructs the instance and search space.\n\n    To meet the requirements of the structure, the input size for each spatial dimension should be:\n    divisible by 2 ** (num_depths + 1).\n\n    Args:\n        dints_space: DiNTS search space. The value should be instance of `TopologyInstance` or `TopologySearch`.\n        in_channels: number of input image channels.\n        num_classes: number of output segmentation classes.\n        act_name: activation name, default to 'RELU'.\n        norm_name: normalization used in convolution blocks. Default to `InstanceNorm`.\n        spatial_dims: spatial 2D or 3D inputs.\n        use_downsample: use downsample in the stem.\n            If ``False``, the search space will be in resolution [1, 1/2, 1/4, 1/8],\n            if ``True``, the search space will be in resolution [1/2, 1/4, 1/8, 1/16].\n        node_a: node activation numpy matrix. Its shape is `(num_depths, num_blocks + 1)`.\n            +1 for multi-resolution inputs.\n            In model searching stage, ``node_a`` can be None. 
In deployment stage, ``node_a`` cannot be None.\n    \"\"\"\n\n    def __init__(\n        self,\n        dints_space,\n        in_channels: int,\n        num_classes: int,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n        spatial_dims: int = 3,\n        use_downsample: bool = True,\n        node_a=None,\n    ):\n        super().__init__()\n\n        self.dints_space = dints_space\n        self.filter_nums = dints_space.filter_nums\n        self.num_blocks = dints_space.num_blocks\n        self.num_depths = dints_space.num_depths\n        if spatial_dims not in (2, 3):\n            raise NotImplementedError(f\"Spatial dimensions {spatial_dims} is not supported.\")\n        self._spatial_dims = spatial_dims\n        if node_a is None:\n            self.node_a = torch.ones((self.num_blocks + 1, self.num_depths))\n        else:\n            self.node_a = node_a\n\n        # define stem operations for every block\n        conv_type = Conv[Conv.CONV, spatial_dims]\n        self.stem_down = nn.ModuleDict()\n        self.stem_up = nn.ModuleDict()\n        self.stem_finals = nn.Sequential(\n            ActiConvNormBlock(\n                self.filter_nums[0],\n                self.filter_nums[0],\n                act_name=act_name,\n                norm_name=norm_name,\n                spatial_dims=spatial_dims,\n            ),\n            conv_type(\n                in_channels=self.filter_nums[0],\n                out_channels=num_classes,\n                kernel_size=1,\n                stride=1,\n                padding=0,\n                groups=1,\n                bias=True,\n                dilation=1,\n            ),\n        )\n        mode = \"trilinear\" if self._spatial_dims == 3 else \"bilinear\"\n        for res_idx in range(self.num_depths):\n            # define downsample stems before DiNTS search\n            if use_downsample:\n                self.stem_down[str(res_idx)] = StemTS(\n  
                  nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True),\n                    conv_type(\n                        in_channels=in_channels,\n                        out_channels=self.filter_nums[res_idx],\n                        kernel_size=3,\n                        stride=1,\n                        padding=1,\n                        groups=1,\n                        bias=False,\n                        dilation=1,\n                    ),\n                    get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]),\n                    get_act_layer(name=act_name),\n                    conv_type(\n                        in_channels=self.filter_nums[res_idx],\n                        out_channels=self.filter_nums[res_idx + 1],\n                        kernel_size=3,\n                        stride=2,\n                        padding=1,\n                        groups=1,\n                        bias=False,\n                        dilation=1,\n                    ),\n                    get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx + 1]),\n                )\n                self.stem_up[str(res_idx)] = StemTS(\n                    get_act_layer(name=act_name),\n                    conv_type(\n                        in_channels=self.filter_nums[res_idx + 1],\n                        out_channels=self.filter_nums[res_idx],\n                        kernel_size=3,\n                        stride=1,\n                        padding=1,\n                        groups=1,\n                        bias=False,\n                        dilation=1,\n                    ),\n                    get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]),\n                    nn.Upsample(scale_factor=2, mode=mode, align_corners=True),\n                )\n\n            else:\n                self.stem_down[str(res_idx)] = 
StemTS(\n                    nn.Upsample(scale_factor=1 / (2**res_idx), mode=mode, align_corners=True),\n                    conv_type(\n                        in_channels=in_channels,\n                        out_channels=self.filter_nums[res_idx],\n                        kernel_size=3,\n                        stride=1,\n                        padding=1,\n                        groups=1,\n                        bias=False,\n                        dilation=1,\n                    ),\n                    get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[res_idx]),\n                )\n                self.stem_up[str(res_idx)] = StemTS(\n                    get_act_layer(name=act_name),\n                    conv_type(\n                        in_channels=self.filter_nums[res_idx],\n                        out_channels=self.filter_nums[max(res_idx - 1, 0)],\n                        kernel_size=3,\n                        stride=1,\n                        padding=1,\n                        groups=1,\n                        bias=False,\n                        dilation=1,\n                    ),\n                    get_norm_layer(\n                        name=norm_name, spatial_dims=spatial_dims, channels=self.filter_nums[max(res_idx - 1, 0)]\n                    ),\n                    nn.Upsample(scale_factor=2 ** (res_idx != 0), mode=mode, align_corners=True),\n                )\n\n    def weight_parameters(self):\n        return [param for name, param in self.named_parameters()]\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"\n        Prediction based on dynamic arch_code.\n\n        Args:\n            x: input tensor.\n        \"\"\"\n        inputs = []\n        for d in range(self.num_depths):\n            # allow multi-resolution input\n            _mod_w: StemInterface = self.stem_down[str(d)]  # type: ignore[assignment]\n            x_out = _mod_w.forward(x)\n            if self.node_a[0][d]:\n           
     inputs.append(x_out)\n            else:\n                inputs.append(torch.zeros_like(x_out))\n\n        outputs = self.dints_space(inputs)\n\n        blk_idx = self.num_blocks - 1\n        start = False\n        _temp: torch.Tensor = torch.empty(0)\n        for res_idx in range(self.num_depths - 1, -1, -1):\n            _mod_up: StemInterface = self.stem_up[str(res_idx)]  # type: ignore[assignment]\n            if start:\n                _temp = _mod_up.forward(outputs[res_idx] + _temp)\n            elif self.node_a[blk_idx + 1][res_idx]:\n                start = True\n                _temp = _mod_up.forward(outputs[res_idx])\n        prediction = self.stem_finals(_temp)\n        return prediction\n\n\nclass TopologyConstruction(nn.Module):\n    \"\"\"\n    The base class for `TopologyInstance` and `TopologySearch`.\n\n    Args:\n        arch_code: `[arch_code_a, arch_code_c]`, numpy arrays. The architecture codes defining the model.\n            For example, for a ``num_depths=4, num_blocks=12`` search space:\n\n            - `arch_code_a` is a 12x10 (10 paths) binary matrix representing if a path is activated.\n            - `arch_code_c` is a 12x10x5 (5 operations) binary matrix representing if a cell operation is used.\n            - `arch_code` in ``__init__()`` is used for creating the network and remove unused network blocks. If None,\n\n            all paths and cells operations will be used, and must be in the searching stage (is_search=True).\n        channel_mul: adjust intermediate channel number, default is 1.\n        cell: operation of each node.\n        num_blocks: number of blocks (depth in the horizontal direction) of the DiNTS search space.\n        num_depths: number of image resolutions of the DiNTS search space: 1, 1/2, 1/4 ... in each dimension.\n        use_downsample: use downsample in the stem. 
If False, the search space will be in resolution [1, 1/2, 1/4, 1/8],\n            if True, the search space will be in resolution [1/2, 1/4, 1/8, 1/16].\n        device: `'cpu'`, `'cuda'`, or device ID.\n\n\n    Predefined variables:\n        `filter_nums`: default to 32. Double the number of channels after downsample.\n        topology related variables:\n\n            - `arch_code2in`: path activation to its incoming node index (resolution). For depth = 4,\n              arch_code2in = [0, 1, 0, 1, 2, 1, 2, 3, 2, 3]. The first path outputs from node 0 (top resolution),\n              the second path outputs from node 1 (second resolution in the search space),\n              the third path outputs from node 0, etc.\n            - `arch_code2ops`: path activation to operations of upsample 1, keep 0, downsample -1. For depth = 4,\n              arch_code2ops = [0, 1, -1, 0, 1, -1, 0, 1, -1, 0]. The first path does not change\n              resolution, the second path perform upsample, the third perform downsample, etc.\n            - `arch_code2out`: path activation to its output node index.\n              For depth = 4, arch_code2out = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],\n              the first and second paths connects to node 0 (top resolution), the 3,4,5 paths connects to node 1, etc.\n    \"\"\"\n\n    def __init__(\n        self,\n        arch_code: list | None = None,\n        channel_mul: float = 1.0,\n        cell=Cell,\n        num_blocks: int = 6,\n        num_depths: int = 3,\n        spatial_dims: int = 3,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n        use_downsample: bool = True,\n        device: str = \"cpu\",\n    ):\n        super().__init__()\n\n        n_feats = tuple([32 * (2**_i) for _i in range(num_depths + 1)])\n        self.filter_nums = [int(n_feat * channel_mul) for n_feat in n_feats]\n\n        self.num_blocks = num_blocks\n        self.num_depths = num_depths\n        
print(\n            f\"{datetime.datetime.now()}\"\n            f\" - Length of input patch is recommended to be a multiple of {2 ** (num_depths + int(use_downsample)):d}.\"\n        )\n\n        self._spatial_dims = spatial_dims\n        self._act_name = act_name\n        self._norm_name = norm_name\n        self.use_downsample = use_downsample\n        self.device = device\n        self.num_cell_ops = 0\n        if self._spatial_dims == 2:\n            self.num_cell_ops = len(cell.OPS2D)\n        elif self._spatial_dims == 3:\n            self.num_cell_ops = len(cell.OPS3D)\n\n        # Calculate predefined parameters for topology search and decoding\n        arch_code2in, arch_code2out = [], []\n        for i in range(Cell.DIRECTIONS * self.num_depths - 2):\n            arch_code2in.append((i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS)\n        arch_code2ops = ([-1, 0, 1] * self.num_depths)[1:-1]\n        for m in range(self.num_depths):\n            arch_code2out.extend([m, m, m])\n        arch_code2out = arch_code2out[1:-1]\n        self.arch_code2in = arch_code2in\n        self.arch_code2ops = arch_code2ops\n        self.arch_code2out = arch_code2out\n\n        # define NAS search space\n        if arch_code is None:\n            arch_code_a = torch.ones((self.num_blocks, len(self.arch_code2out))).to(self.device)\n            arch_code_c = torch.ones((self.num_blocks, len(self.arch_code2out), self.num_cell_ops)).to(self.device)\n        else:\n            arch_code_a = torch.from_numpy(arch_code[0]).to(self.device)\n            arch_code_c = F.one_hot(torch.from_numpy(arch_code[1]).to(torch.int64), self.num_cell_ops).to(self.device)\n\n        self.arch_code_a = arch_code_a\n        self.arch_code_c = arch_code_c\n        # define cell operation on each path\n        self.cell_tree = nn.ModuleDict()\n        for blk_idx in range(self.num_blocks):\n            for res_idx in range(len(self.arch_code2out)):\n                if 
self.arch_code_a[blk_idx, res_idx] == 1:\n                    self.cell_tree[str((blk_idx, res_idx))] = cell(\n                        self.filter_nums[self.arch_code2in[res_idx] + int(use_downsample)],\n                        self.filter_nums[self.arch_code2out[res_idx] + int(use_downsample)],\n                        self.arch_code2ops[res_idx],\n                        self.arch_code_c[blk_idx, res_idx],\n                        self._spatial_dims,\n                        self._act_name,\n                        self._norm_name,\n                    )\n\n    def forward(self, x):\n        \"\"\"This function to be implemented by the architecture instances or search spaces.\"\"\"\n\n\nclass TopologyInstance(TopologyConstruction):\n    \"\"\"\n    Instance of the final searched architecture. Only used in re-training/inference stage.\n    \"\"\"\n\n    def __init__(\n        self,\n        arch_code=None,\n        channel_mul: float = 1.0,\n        cell=Cell,\n        num_blocks: int = 6,\n        num_depths: int = 3,\n        spatial_dims: int = 3,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n        use_downsample: bool = True,\n        device: str = \"cpu\",\n    ):\n        \"\"\"\n        Initialize DiNTS topology search space of neural architectures.\n        \"\"\"\n        if arch_code is None:\n            warnings.warn(\"arch_code not provided when not searching.\")\n\n        super().__init__(\n            arch_code=arch_code,\n            channel_mul=channel_mul,\n            cell=cell,\n            num_blocks=num_blocks,\n            num_depths=num_depths,\n            spatial_dims=spatial_dims,\n            act_name=act_name,\n            norm_name=norm_name,\n            use_downsample=use_downsample,\n            device=device,\n        )\n\n    def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor]:\n        \"\"\"\n        Args:\n            x: input tensor.\n        
\"\"\"\n        # generate path activation probability\n        inputs = x\n        for blk_idx in range(self.num_blocks):\n            outputs = [torch.tensor(0.0, dtype=x[0].dtype, device=x[0].device)] * self.num_depths\n            for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data):\n                if activation:\n                    mod: CellInterface = self.cell_tree[str((blk_idx, res_idx))]  # type: ignore[assignment]\n                    _out = mod.forward(x=inputs[self.arch_code2in[res_idx]], weight=None)\n                    outputs[self.arch_code2out[res_idx]] = outputs[self.arch_code2out[res_idx]] + _out\n            inputs = outputs\n\n        return inputs\n\n\nclass TopologySearch(TopologyConstruction):\n    \"\"\"\n    DiNTS topology search space of neural architectures.\n\n    Examples:\n\n    .. code-block:: python\n\n        from monai.networks.nets.dints import TopologySearch\n\n        topology_search_space = TopologySearch(\n            channel_mul=0.5, num_blocks=8, num_depths=4, use_downsample=True, spatial_dims=3)\n        topology_search_space.get_ram_cost_usage(in_size=(2, 16, 80, 80, 80), full=True)\n        multi_res_images = [\n            torch.randn(2, 16, 80, 80, 80),\n            torch.randn(2, 32, 40, 40, 40),\n            torch.randn(2, 64, 20, 20, 20),\n            torch.randn(2, 128, 10, 10, 10)]\n        prediction = topology_search_space(image)\n        for x in prediction: print(x.shape)\n        # torch.Size([2, 16, 80, 80, 80])\n        # torch.Size([2, 32, 40, 40, 40])\n        # torch.Size([2, 64, 20, 20, 20])\n        # torch.Size([2, 128, 10, 10, 10])\n\n    Class method overview:\n\n        - ``get_prob_a()``: convert learnable architecture weights to path activation probabilities.\n        - ``get_ram_cost_usage()``: get estimated ram cost.\n        - ``get_topology_entropy()``: get topology entropy loss in searching stage.\n        - ``decode()``: get final binarized architecture code.\n        - 
``gen_mtx()``: generate variables needed for topology search.\n\n    Predefined variables:\n        - `tidx`: index used to convert path activation matrix T = (depth,depth) in transfer_mtx to\n          path activation arch_code (1,3*depth-2), for depth = 4, tidx = [0, 1, 4, 5, 6, 9, 10, 11, 14, 15],\n          A tidx (10 binary values) represents the path activation.\n        - `transfer_mtx`: feasible path activation matrix (denoted as T) given a node activation pattern.\n          It is used to convert path activation pattern (1, paths) to node activation (1, nodes)\n        - `node_act_list`: all node activation [2^num_depths-1, depth]. For depth = 4, there are 15 node activation\n          patterns, each of length 4. For example, [1,1,0,0] means nodes 0, 1 are activated (with input paths).\n        - `all_connect`: All possible path activations. For depth = 4,\n          all_connection has 1024 vectors of length 10 (10 paths).\n          The return value will exclude path activation of all 0.\n    \"\"\"\n\n    node2out: list[list]\n    node2in: list[list]\n\n    def __init__(\n        self,\n        channel_mul: float = 1.0,\n        cell=Cell,\n        arch_code: list | None = None,\n        num_blocks: int = 6,\n        num_depths: int = 3,\n        spatial_dims: int = 3,\n        act_name: tuple | str = \"RELU\",\n        norm_name: tuple | str = (\"INSTANCE\", {\"affine\": True}),\n        use_downsample: bool = True,\n        device: str = \"cpu\",\n    ):\n        \"\"\"\n        Initialize DiNTS topology search space of neural architectures.\n        \"\"\"\n        super().__init__(\n            arch_code=arch_code,\n            channel_mul=channel_mul,\n            cell=cell,\n            num_blocks=num_blocks,\n            num_depths=num_depths,\n            spatial_dims=spatial_dims,\n            act_name=act_name,\n            norm_name=norm_name,\n            use_downsample=use_downsample,\n            device=device,\n        )\n\n        tidx = 
[]\n        _d = Cell.DIRECTIONS\n        for i in range(_d * self.num_depths - 2):\n            tidx.append((i + 1) // _d * self.num_depths + (i + 1) // _d - 1 + (i + 1) % _d)\n        self.tidx = tidx\n        transfer_mtx, node_act_list, child_list = self.gen_mtx(num_depths)\n\n        self.node_act_list = np.asarray(node_act_list)\n        self.node_act_dict = {str(self.node_act_list[i]): i for i in range(len(self.node_act_list))}\n        self.transfer_mtx = transfer_mtx\n        self.child_list = np.asarray(child_list)\n\n        self.ram_cost = np.zeros((self.num_blocks, len(self.arch_code2out), self.num_cell_ops))\n        for blk_idx in range(self.num_blocks):\n            for res_idx in range(len(self.arch_code2out)):\n                if self.arch_code_a[blk_idx, res_idx] == 1:\n                    cell_inter: Cell = self.cell_tree[str((blk_idx, res_idx))]  # type: ignore\n                    self.ram_cost[blk_idx, res_idx] = np.array(\n                        [op.ram_cost + cell_inter.preprocess.ram_cost for op in cell_inter.op.ops[: self.num_cell_ops]]\n                    )  # type: ignore\n\n        # define cell and macro architecture probabilities\n        self.log_alpha_c = nn.Parameter(\n            torch.zeros(self.num_blocks, len(self.arch_code2out), self.num_cell_ops)\n            .normal_(1, 0.01)\n            .to(self.device)\n            .requires_grad_()\n        )\n        self.log_alpha_a = nn.Parameter(\n            torch.zeros(self.num_blocks, len(self.arch_code2out)).normal_(0, 0.01).to(self.device).requires_grad_()\n        )\n        self._arch_param_names = [\"log_alpha_a\", \"log_alpha_c\"]\n\n    def gen_mtx(self, depth: int):\n        \"\"\"\n        Generate elements needed in decoding and topology.\n\n            - `transfer_mtx`: feasible path activation matrix (denoted as T) given a node activation pattern.\n               It is used to convert path activation pattern (1, paths) to node activation (1, nodes)\n            - 
`node_act_list`: all node activation [2^num_depths-1, depth]. For depth = 4, there are 15 node activation\n               patterns, each of length 4. For example, [1,1,0,0] means nodes 0, 1 are activated (with input paths).\n            - `all_connect`: All possible path activations. For depth = 4,\n              all_connection has 1024 vectors of length 10 (10 paths).\n              The return value will exclude path activation of all 0.\n        \"\"\"\n        # total paths in a block, each node has three output paths,\n        # except the two nodes at the top and the bottom scales\n        paths = Cell.DIRECTIONS * depth - 2\n\n        # for 10 paths, all_connect has 1024 possible path activations. [1 0 0 0 0 0 0 0 0 0] means the top\n        # path is activated.\n        all_connect = _dfs(0, paths - 1)\n\n        # Save all possible connections in mtx (might be redundant and infeasible)\n        mtx = []\n        for m in all_connect:\n            # convert path activation [1,paths] to path activation matrix [depth, depth]\n            ma = np.zeros((depth, depth))\n            for i in range(paths):\n                ma[(i + 1) // Cell.DIRECTIONS, (i + 1) // Cell.DIRECTIONS - 1 + (i + 1) % Cell.DIRECTIONS] = m[i]\n            mtx.append(ma)\n\n        # define all possible node activation\n        node_act_list = _dfs(0, depth - 1)[1:]\n        transfer_mtx = {}\n        for arch_code in node_act_list:\n            # make sure each activated node has an active connection, inactivated node has no connection\n            arch_code_mtx = [_ for _ in mtx if ((np.sum(_, 0) > 0).astype(int) == np.array(arch_code)).all()]\n            transfer_mtx[str(np.array(arch_code))] = arch_code_mtx\n\n        return transfer_mtx, node_act_list, all_connect[1:]\n\n    def weight_parameters(self):\n        return [param for name, param in self.named_parameters() if name not in self._arch_param_names]\n\n    def get_prob_a(self, child: bool = False):\n        \"\"\"\n        
Get final path and child model probabilities from architecture weights `log_alpha_a`.\n        This is used in forward pass, getting training loss, and final decoding.\n\n        Args:\n            child: return child probability (used in decoding)\n        Return:\n            arch_code_prob_a: the path activation probability of size:\n                `[number of blocks, number of paths in each block]`.\n                For 12 blocks, 4 depths search space, the size is [12,10]\n            probs_a: The probability of all child models (size 1023x10). Each child model is a path activation pattern\n                 (1D vector of length 10 for 10 paths). In total 1023 child models (2^10 -1)\n        \"\"\"\n        _arch_code_prob_a = torch.sigmoid(self.log_alpha_a)\n        # remove the case where all path are zero, and re-normalize.\n        norm = 1 - (1 - _arch_code_prob_a).prod(-1)\n        arch_code_prob_a = _arch_code_prob_a / norm.unsqueeze(1)\n        if child:\n            path_activation = torch.from_numpy(self.child_list).to(self.device)\n            probs_a = [\n                (\n                    path_activation * _arch_code_prob_a[blk_idx]\n                    + (1 - path_activation) * (1 - _arch_code_prob_a[blk_idx])\n                ).prod(-1)\n                / norm[blk_idx]\n                for blk_idx in range(self.num_blocks)\n            ]\n            probs_a = torch.stack(probs_a)  # type: ignore\n            return probs_a, arch_code_prob_a\n        return None, arch_code_prob_a\n\n    def get_ram_cost_usage(self, in_size, full: bool = False):\n        \"\"\"\n        Get estimated output tensor size to approximate RAM consumption.\n\n        Args:\n            in_size: input image shape (4D/5D, ``[BCHW[D]]``) at the highest resolution level.\n            full: full ram cost usage with all probability of 1.\n        \"\"\"\n        # convert input image size to feature map size at each level\n        batch_size = in_size[0]\n        
image_size = np.array(in_size[-self._spatial_dims :])\n        sizes = []\n        for res_idx in range(self.num_depths):\n            sizes.append(batch_size * self.filter_nums[res_idx] * (image_size // (2**res_idx)).prod())\n        sizes = torch.tensor(sizes, dtype=torch.float32, device=self.device) / (2 ** (int(self.use_downsample)))\n        probs_a, arch_code_prob_a = self.get_prob_a(child=False)\n        cell_prob = F.softmax(self.log_alpha_c, dim=-1)\n        if full:\n            arch_code_prob_a = arch_code_prob_a.detach()\n            arch_code_prob_a.fill_(1)\n        ram_cost = torch.from_numpy(self.ram_cost).to(dtype=torch.float32, device=self.device)\n        usage = 0.0\n        for blk_idx in range(self.num_blocks):\n            # node activation for input\n            # cell operation\n            for path_idx in range(len(self.arch_code2out)):\n                usage += (\n                    arch_code_prob_a[blk_idx, path_idx]\n                    * (1 + (ram_cost[blk_idx, path_idx] * cell_prob[blk_idx, path_idx]).sum())\n                    * sizes[self.arch_code2out[path_idx]]\n                )\n        return usage * 32 / 8 / 1024**2\n\n    def get_topology_entropy(self, probs):\n        \"\"\"\n        Get topology entropy loss at searching stage.\n\n        Args:\n            probs: path activation probabilities\n        \"\"\"\n        if hasattr(self, \"node2in\"):\n            node2in = self.node2in  # pylint: disable=E0203\n            node2out = self.node2out  # pylint: disable=E0203\n        else:\n            # node activation index to feasible input child_idx\n            node2in = [[] for _ in range(len(self.node_act_list))]\n            # node activation index to feasible output child_idx\n            node2out = [[] for _ in range(len(self.node_act_list))]\n            for child_idx in range(len(self.child_list)):\n                _node_in, _node_out = np.zeros(self.num_depths), np.zeros(self.num_depths)\n                for 
res_idx in range(len(self.arch_code2out)):\n                    _node_out[self.arch_code2out[res_idx]] += self.child_list[child_idx][res_idx]\n                    _node_in[self.arch_code2in[res_idx]] += self.child_list[child_idx][res_idx]\n                _node_in = (_node_in >= 1).astype(int)\n                _node_out = (_node_out >= 1).astype(int)\n                node2in[self.node_act_dict[str(_node_out)]].append(child_idx)\n                node2out[self.node_act_dict[str(_node_in)]].append(child_idx)\n            self.node2in = node2in\n            self.node2out = node2out\n        # calculate entropy\n        ent = 0\n        for blk_idx in range(self.num_blocks - 1):\n            blk_ent = 0\n            # node activation probability\n            for node_idx in range(len(self.node_act_list)):\n                _node_p = probs[blk_idx, node2in[node_idx]].sum()\n                _out_probs = probs[blk_idx + 1, node2out[node_idx]].sum()\n                blk_ent += -(_node_p * torch.log(_out_probs + 1e-5) + (1 - _node_p) * torch.log(1 - _out_probs + 1e-5))\n            ent += blk_ent\n        return ent\n\n    def decode(self):\n        \"\"\"\n        Decode network log_alpha_a/log_alpha_c using dijkstra shortest path algorithm.\n\n        `[node_a, arch_code_a, arch_code_c, arch_code_a_max]` is decoded when using ``self.decode()``.\n\n        For example, for a ``num_depths=4``, ``num_blocks=12`` search space:\n\n            - ``node_a`` is a 4x13 binary matrix representing if a feature node is activated\n              (13 because of multi-resolution inputs).\n            - ``arch_code_a`` is a 12x10 (10 paths) binary matrix representing if a path is activated.\n            - ``arch_code_c`` is a 12x10x5 (5 operations) binary matrix representing if a cell operation is used.\n\n        Return:\n            arch_code with maximum probability\n        \"\"\"\n        probs, arch_code_prob_a = self.get_prob_a(child=True)\n        arch_code_a_max = 
self.child_list[torch.argmax(probs, -1).data.cpu().numpy()]\n        arch_code_c = torch.argmax(F.softmax(self.log_alpha_c, -1), -1).data.cpu().numpy()\n        probs = probs.data.cpu().numpy()\n\n        # define adjacency matrix\n        amtx = np.zeros(\n            (1 + len(self.child_list) * self.num_blocks + 1, 1 + len(self.child_list) * self.num_blocks + 1)\n        )\n\n        # build a path activation to child index searching dictionary\n        path2child = {str(self.child_list[i]): i for i in range(len(self.child_list))}\n\n        # build a submodel to submodel index\n        sub_amtx = np.zeros((len(self.child_list), len(self.child_list)))\n        for child_idx in range(len(self.child_list)):\n            _node_act = np.zeros(self.num_depths).astype(int)\n            for path_idx in range(len(self.child_list[child_idx])):\n                _node_act[self.arch_code2out[path_idx]] += self.child_list[child_idx][path_idx]\n            _node_act = (_node_act >= 1).astype(int)\n            for mtx in self.transfer_mtx[str(_node_act)]:\n                connect_child_idx = path2child[str(mtx.flatten()[self.tidx].astype(int))]\n                sub_amtx[child_idx, connect_child_idx] = 1\n\n        # fill in source to first block, add 1e-5/1e-3 to avoid log0 and negative edge weights\n        amtx[0, 1 : 1 + len(self.child_list)] = -np.log(probs[0] + 1e-5) + 0.001\n\n        # fill in the rest blocks\n        for blk_idx in range(1, self.num_blocks):\n            amtx[\n                1 + (blk_idx - 1) * len(self.child_list) : 1 + blk_idx * len(self.child_list),\n                1 + blk_idx * len(self.child_list) : 1 + (blk_idx + 1) * len(self.child_list),\n            ] = sub_amtx * np.tile(-np.log(probs[blk_idx] + 1e-5) + 0.001, (len(self.child_list), 1))\n\n        # fill in the last to the sink\n        amtx[1 + (self.num_blocks - 1) * len(self.child_list) : 1 + self.num_blocks * len(self.child_list), -1] = 0.001\n\n        graph = csr_matrix(amtx)\n        
dist_matrix, predecessors, sources = dijkstra(\n            csgraph=graph, directed=True, indices=0, min_only=True, return_predecessors=True\n        )\n        index, a_idx = -1, -1\n        arch_code_a = np.zeros((self.num_blocks, len(self.arch_code2out)))\n        node_a = np.zeros((self.num_blocks + 1, self.num_depths))\n\n        # decoding to paths\n        while True:\n            index = predecessors[index]\n            if index == 0:\n                break\n            child_idx = (index - 1) % len(self.child_list)\n            arch_code_a[a_idx, :] = self.child_list[child_idx]\n            for res_idx in range(len(self.arch_code2out)):\n                node_a[a_idx, self.arch_code2out[res_idx]] += arch_code_a[a_idx, res_idx]\n            a_idx -= 1\n        for res_idx in range(len(self.arch_code2out)):\n            node_a[a_idx, self.arch_code2in[res_idx]] += arch_code_a[0, res_idx]\n        node_a = (node_a >= 1).astype(int)\n        return node_a, arch_code_a, arch_code_c, arch_code_a_max\n\n    def forward(self, x):\n        \"\"\"\n        Prediction based on dynamic arch_code.\n\n        Args:\n            x: a list of `num_depths` input tensors as a multi-resolution input.\n                tensor is of shape `BCHW[D]` where `C` must match `self.filter_nums`.\n        \"\"\"\n        # generate path activation probability\n        probs_a, arch_code_prob_a = self.get_prob_a(child=False)\n        inputs = x\n        for blk_idx in range(self.num_blocks):\n            outputs = [0.0] * self.num_depths\n            for res_idx, activation in enumerate(self.arch_code_a[blk_idx].data.cpu().numpy()):\n                if activation:\n                    _w = F.softmax(self.log_alpha_c[blk_idx, res_idx], dim=-1)\n                    outputs[self.arch_code2out[res_idx]] += (\n                        self.cell_tree[str((blk_idx, res_idx))](inputs[self.arch_code2in[res_idx]], weight=_w)\n                        * arch_code_prob_a[blk_idx, res_idx]\n            
        )\n            inputs = outputs\n\n        return inputs\n"
  },
  {
    "path": "monai/networks/nets/dynunet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# isort: dont-add-import: from __future__ import annotations\n\nfrom collections.abc import Sequence\nfrom typing import Optional, Union\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.functional import interpolate\n\nfrom monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock\n\n__all__ = [\"DynUNet\", \"DynUnet\", \"Dynunet\"]\n\n\nclass DynUNetSkipLayer(nn.Module):\n    \"\"\"\n    Defines a layer in the UNet topology which combines the downsample and upsample pathways with the skip connection.\n    The member `next_layer` may refer to instances of this class or the final bottleneck layer at the bottom the UNet\n    structure. The purpose of using a recursive class like this is to get around the Torchscript restrictions on\n    looping over lists of layers and accumulating lists of output tensors which must be indexed. 
The `heads` list is\n    shared amongst all the instances of this class and is used to store the output from the supervision heads during\n    forward passes of the network.\n    \"\"\"\n\n    heads: Optional[list[torch.Tensor]]\n\n    def __init__(self, index, downsample, upsample, next_layer, heads=None, super_head=None):\n        super().__init__()\n        self.downsample = downsample\n        self.next_layer = next_layer\n        self.upsample = upsample\n        self.super_head = super_head\n        self.heads = heads\n        self.index = index\n\n    def forward(self, x):\n        downout = self.downsample(x)\n        nextout = self.next_layer(downout)\n        upout = self.upsample(nextout, downout)\n        if self.super_head is not None and self.heads is not None and self.index > 0:\n            self.heads[self.index - 1] = self.super_head(upout)\n\n        return upout\n\n\nclass DynUNet(nn.Module):\n    \"\"\"\n    This reimplementation of a dynamic UNet (DynUNet) is based on:\n    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.\n    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.\n    `Optimized U-Net for Brain Tumor Segmentation <https://arxiv.org/pdf/2110.03352.pdf>`_.\n\n    This model is more flexible compared with ``monai.networks.nets.UNet`` in three\n    places:\n\n        - Residual connection is supported in conv blocks.\n        - Anisotropic kernel sizes and strides can be used in each layers.\n        - Deep supervision heads can be added.\n\n    The model supports 2D or 3D inputs and is consisted with four kinds of blocks:\n    one input block, `n` downsample blocks, one bottleneck and `n+1` upsample blocks. 
Where, `n>0`.\n    The first and last kernel and stride values of the input sequences are used for input block and\n    bottleneck respectively, and the rest value(s) are used for downsample and upsample blocks.\n    Therefore, pleasure ensure that the length of input sequences (``kernel_size`` and ``strides``)\n    is no less than 3 in order to have at least one downsample and upsample blocks.\n\n    To meet the requirements of the structure, the input size for each spatial dimension should be divisible\n    by the product of all strides in the corresponding dimension. In addition, the minimal spatial size should have\n    at least one dimension that has twice the size of the product of all strides.\n    For example, if `strides=((1, 2, 4), 2, 2, 1)`, the spatial size should be divisible by `(4, 8, 16)`,\n    and the minimal spatial size is `(8, 8, 16)` or `(4, 16, 16)` or `(4, 8, 32)`.\n\n    The output size for each spatial dimension equals to the input size of the corresponding dimension divided by the\n    stride in strides[0].\n    For example, if `strides=((1, 2, 4), 2, 2, 1)` and the input size is `(64, 32, 32)`, the output size is `(64, 16, 8)`.\n\n    For backwards compatibility with old weights, please set `strict=False` when calling `load_state_dict`.\n\n    Usage example with medical segmentation decathlon dataset is available at:\n    https://github.com/Project-MONAI/tutorials/tree/master/modules/dynunet_pipeline.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        kernel_size: convolution kernel size.\n        strides: convolution strides for each blocks.\n        upsample_kernel_size: convolution kernel size for transposed convolution layers. The values should\n            equal to strides[1:].\n        filters: number of output channels for each blocks. 
Different from nnU-Net, in this implementation we add\n            this argument to make the network more flexible. As shown in the third reference, one way to determine\n            this argument is like:\n            ``[64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)]``.\n            The above way is used in the network that wins task 1 in the BraTS21 Challenge.\n            If not specified, the way which nnUNet used will be employed. Defaults to ``None``.\n        dropout: dropout ratio. Defaults to no dropout.\n        norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``.\n            `INSTANCE_NVFUSER` is a faster version of the instance norm layer, it can be used when:\n            1) `spatial_dims=3`, 2) CUDA device is available, 3) `apex` is installed and 4) non-Windows OS is used.\n        act_name: activation layer type and arguments. Defaults to ``leakyrelu``.\n        deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.\n            If ``True``, in training mode, the forward function will output not only the final feature map\n            (from `output_block`), but also the feature maps that come from the intermediate up sample layers.\n            In order to unify the return type (the restriction of TorchScript), all intermediate\n            feature maps are interpolated into the same size as the final feature map and stacked together\n            (with a new dimension in the first axis)into one single tensor.\n            For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and\n            (1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps\n            will be interpolated into (1, 2, 32, 24), and the stacked tensor will has the shape (1, 3, 2, 32, 24).\n            When calculating the loss, you can use torch.unbind to get all feature maps can compute the loss\n            one by one 
with the ground truth, then do a weighted average for all losses to achieve the final loss.\n        deep_supr_num: number of feature maps that will output during deep supervision head. The\n            value should be larger than 0 and less than the number of up sample layers.\n            Defaults to 1.\n        res_block: whether to use residual connection based convolution blocks during the network.\n            Defaults to ``False``.\n        trans_bias: whether to set the bias parameter in transposed convolution layers. Defaults to ``False``.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: Sequence[Union[Sequence[int], int]],\n        strides: Sequence[Union[Sequence[int], int]],\n        upsample_kernel_size: Sequence[Union[Sequence[int], int]],\n        filters: Optional[Sequence[int]] = None,\n        dropout: Optional[Union[tuple, str, float]] = None,\n        norm_name: Union[tuple, str] = (\"INSTANCE\", {\"affine\": True}),\n        act_name: Union[tuple, str] = (\"leakyrelu\", {\"inplace\": True, \"negative_slope\": 0.01}),\n        deep_supervision: bool = False,\n        deep_supr_num: int = 1,\n        res_block: bool = False,\n        trans_bias: bool = False,\n    ):\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.kernel_size = kernel_size\n        self.strides = strides\n        self.upsample_kernel_size = upsample_kernel_size\n        self.norm_name = norm_name\n        self.act_name = act_name\n        self.dropout = dropout\n        self.conv_block = UnetResBlock if res_block else UnetBasicBlock\n        self.trans_bias = trans_bias\n        if filters is not None:\n            self.filters = filters\n            self.check_filters()\n        else:\n            self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 
512) for i in range(len(strides))]\n        self.input_block = self.get_input_block()\n        self.downsamples = self.get_downsamples()\n        self.bottleneck = self.get_bottleneck()\n        self.upsamples = self.get_upsamples()\n        self.output_block = self.get_output_block(0)\n        self.deep_supervision = deep_supervision\n        self.deep_supr_num = deep_supr_num\n        # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on\n        self.heads: list[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num\n        if self.deep_supervision:\n            self.deep_supervision_heads = self.get_deep_supervision_heads()\n            self.check_deep_supr_num()\n\n        self.apply(self.initialize_weights)\n        self.check_kernel_stride()\n\n        def create_skips(index, downsamples, upsamples, bottleneck, superheads=None):\n            \"\"\"\n            Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is\n            done recursively from the top down since a recursive nn.Module subclass is being used to be compatible\n            with Torchscript. 
Initially the length of `downsamples` will be one more than that of `superheads`\n            since the `input_block` is passed to this function as the first item in `downsamples`, however this\n            shouldn't be associated with a supervision head.\n            \"\"\"\n\n            if len(downsamples) != len(upsamples):\n                raise ValueError(f\"{len(downsamples)} != {len(upsamples)}\")\n\n            if len(downsamples) == 0:  # bottom of the network, pass the bottleneck block\n                return bottleneck\n\n            if superheads is None:\n                next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck)\n                return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer)\n\n            super_head_flag = False\n            if index == 0:  # don't associate a supervision head with self.input_block\n                rest_heads = superheads\n            else:\n                if len(superheads) > 0:\n                    super_head_flag = True\n                    rest_heads = superheads[1:]\n                else:\n                    rest_heads = nn.ModuleList()\n\n            # create the next layer down, this will stop at the bottleneck layer\n            next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], bottleneck, superheads=rest_heads)\n            if super_head_flag:\n                return DynUNetSkipLayer(\n                    index,\n                    downsample=downsamples[0],\n                    upsample=upsamples[0],\n                    next_layer=next_layer,\n                    heads=self.heads,\n                    super_head=superheads[0],\n                )\n\n            return DynUNetSkipLayer(index, downsample=downsamples[0], upsample=upsamples[0], next_layer=next_layer)\n\n        if not self.deep_supervision:\n            self.skip_layers = create_skips(\n                0, [self.input_block] + 
list(self.downsamples), self.upsamples[::-1], self.bottleneck\n            )\n        else:\n            self.skip_layers = create_skips(\n                0,\n                [self.input_block] + list(self.downsamples),\n                self.upsamples[::-1],\n                self.bottleneck,\n                superheads=self.deep_supervision_heads,\n            )\n\n    def check_kernel_stride(self):\n        kernels, strides = self.kernel_size, self.strides\n        error_msg = \"length of kernel_size and strides should be the same, and no less than 3.\"\n        if len(kernels) != len(strides) or len(kernels) < 3:\n            raise ValueError(error_msg)\n\n        for idx, k_i in enumerate(kernels):\n            kernel, stride = k_i, strides[idx]\n            if not isinstance(kernel, int):\n                error_msg = f\"length of kernel_size in block {idx} should be the same as spatial_dims.\"\n                if len(kernel) != self.spatial_dims:\n                    raise ValueError(error_msg)\n            if not isinstance(stride, int):\n                error_msg = f\"length of stride in block {idx} should be the same as spatial_dims.\"\n                if len(stride) != self.spatial_dims:\n                    raise ValueError(error_msg)\n\n    def check_deep_supr_num(self):\n        deep_supr_num, strides = self.deep_supr_num, self.strides\n        num_up_layers = len(strides) - 1\n        if deep_supr_num >= num_up_layers:\n            raise ValueError(\"deep_supr_num should be less than the number of up sample layers.\")\n        if deep_supr_num < 1:\n            raise ValueError(\"deep_supr_num should be larger than 0.\")\n\n    def check_filters(self):\n        filters = self.filters\n        if len(filters) < len(self.strides):\n            raise ValueError(\"length of filters should be no less than the length of strides.\")\n        else:\n            self.filters = filters[: len(self.strides)]\n\n    def forward(self, x):\n        out = 
self.skip_layers(x)\n        out = self.output_block(out)\n        if self.training and self.deep_supervision:\n            out_all = [out]\n            for feature_map in self.heads:\n                out_all.append(interpolate(feature_map, out.shape[2:]))\n            return torch.stack(out_all, dim=1)\n        return out\n\n    def get_input_block(self):\n        return self.conv_block(\n            self.spatial_dims,\n            self.in_channels,\n            self.filters[0],\n            self.kernel_size[0],\n            self.strides[0],\n            self.norm_name,\n            self.act_name,\n            dropout=self.dropout,\n        )\n\n    def get_bottleneck(self):\n        return self.conv_block(\n            self.spatial_dims,\n            self.filters[-2],\n            self.filters[-1],\n            self.kernel_size[-1],\n            self.strides[-1],\n            self.norm_name,\n            self.act_name,\n            dropout=self.dropout,\n        )\n\n    def get_output_block(self, idx: int):\n        return UnetOutBlock(self.spatial_dims, self.filters[idx], self.out_channels, dropout=self.dropout)\n\n    def get_downsamples(self):\n        inp, out = self.filters[:-2], self.filters[1:-1]\n        strides, kernel_size = self.strides[1:-1], self.kernel_size[1:-1]\n        return self.get_module_list(inp, out, kernel_size, strides, self.conv_block)  # type: ignore\n\n    def get_upsamples(self):\n        inp, out = self.filters[1:][::-1], self.filters[:-1][::-1]\n        strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1]\n        upsample_kernel_size = self.upsample_kernel_size[::-1]\n        return self.get_module_list(\n            inp,  # type: ignore\n            out,  # type: ignore\n            kernel_size,\n            strides,\n            UnetUpBlock,  # type: ignore\n            upsample_kernel_size,\n            trans_bias=self.trans_bias,\n        )\n\n    def get_module_list(\n        self,\n        in_channels: 
list[int],\n        out_channels: list[int],\n        kernel_size: Sequence[Union[Sequence[int], int]],\n        strides: Sequence[Union[Sequence[int], int]],\n        conv_block: nn.Module,\n        upsample_kernel_size: Optional[Sequence[Union[Sequence[int], int]]] = None,\n        trans_bias: bool = False,\n    ):\n        layers = []\n        if upsample_kernel_size is not None:\n            for in_c, out_c, kernel, stride, up_kernel in zip(\n                in_channels, out_channels, kernel_size, strides, upsample_kernel_size\n            ):\n                params = {\n                    \"spatial_dims\": self.spatial_dims,\n                    \"in_channels\": in_c,\n                    \"out_channels\": out_c,\n                    \"kernel_size\": kernel,\n                    \"stride\": stride,\n                    \"norm_name\": self.norm_name,\n                    \"act_name\": self.act_name,\n                    \"dropout\": self.dropout,\n                    \"upsample_kernel_size\": up_kernel,\n                    \"trans_bias\": trans_bias,\n                }\n                layer = conv_block(**params)\n                layers.append(layer)\n        else:\n            for in_c, out_c, kernel, stride in zip(in_channels, out_channels, kernel_size, strides):\n                params = {\n                    \"spatial_dims\": self.spatial_dims,\n                    \"in_channels\": in_c,\n                    \"out_channels\": out_c,\n                    \"kernel_size\": kernel,\n                    \"stride\": stride,\n                    \"norm_name\": self.norm_name,\n                    \"act_name\": self.act_name,\n                    \"dropout\": self.dropout,\n                }\n                layer = conv_block(**params)\n                layers.append(layer)\n        return nn.ModuleList(layers)\n\n    def get_deep_supervision_heads(self):\n        return nn.ModuleList([self.get_output_block(i + 1) for i in range(self.deep_supr_num)])\n\n    
@staticmethod\n    def initialize_weights(module):\n        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):\n            module.weight = nn.init.kaiming_normal_(module.weight, a=0.01)\n            if module.bias is not None:\n                module.bias = nn.init.constant_(module.bias, 0)\n\n\nDynUnet = Dynunet = DynUNet\n"
  },
  {
    "path": "monai/networks/nets/efficientnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\nimport operator\nimport re\nfrom functools import reduce\nfrom typing import NamedTuple\n\nimport torch\nfrom torch import nn\nfrom torch.utils import model_zoo\n\nfrom monai.networks.blocks import BaseEncoder\nfrom monai.networks.layers.factories import Act, Conv, Pad, Pool\nfrom monai.networks.layers.utils import get_norm_layer\nfrom monai.utils.module import look_up_option\n\n__all__ = [\n    \"EfficientNet\",\n    \"EfficientNetBN\",\n    \"get_efficientnet_image_size\",\n    \"drop_connect\",\n    \"EfficientNetBNFeatures\",\n    \"BlockArgs\",\n    \"EfficientNetEncoder\",\n]\n\nefficientnet_params = {\n    # model_name: (width_mult, depth_mult, image_size, dropout_rate, dropconnect_rate)\n    \"efficientnet-b0\": (1.0, 1.0, 224, 0.2, 0.2),\n    \"efficientnet-b1\": (1.0, 1.1, 240, 0.2, 0.2),\n    \"efficientnet-b2\": (1.1, 1.2, 260, 0.3, 0.2),\n    \"efficientnet-b3\": (1.2, 1.4, 300, 0.3, 0.2),\n    \"efficientnet-b4\": (1.4, 1.8, 380, 0.4, 0.2),\n    \"efficientnet-b5\": (1.6, 2.2, 456, 0.4, 0.2),\n    \"efficientnet-b6\": (1.8, 2.6, 528, 0.5, 0.2),\n    \"efficientnet-b7\": (2.0, 3.1, 600, 0.5, 0.2),\n    \"efficientnet-b8\": (2.2, 3.6, 672, 0.5, 0.2),\n    \"efficientnet-l2\": (4.3, 5.3, 800, 0.5, 0.2),\n}\n\nurl_map = {\n    \"efficientnet-b0\": 
\"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth\",\n    \"efficientnet-b1\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth\",\n    \"efficientnet-b2\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth\",\n    \"efficientnet-b3\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth\",\n    \"efficientnet-b4\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth\",\n    \"efficientnet-b5\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth\",\n    \"efficientnet-b6\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth\",\n    \"efficientnet-b7\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth\",\n    # trained with adversarial examples, simplify the name to decrease string length\n    \"b0-ap\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth\",\n    \"b1-ap\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth\",\n    \"b2-ap\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth\",\n    \"b3-ap\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth\",\n    \"b4-ap\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth\",\n    \"b5-ap\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth\",\n    \"b6-ap\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth\",\n 
   \"b7-ap\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth\",\n    \"b8-ap\": \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth\",\n}\n\n\nclass MBConvBlock(nn.Module):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int,\n        stride: int,\n        image_size: list[int],\n        expand_ratio: int,\n        se_ratio: float | None,\n        id_skip: bool | None = True,\n        norm: str | tuple = (\"batch\", {\"eps\": 1e-3, \"momentum\": 0.01}),\n        drop_connect_rate: float | None = 0.2,\n    ) -> None:\n        \"\"\"\n        Mobile Inverted Residual Bottleneck Block.\n\n        Args:\n            spatial_dims: number of spatial dimensions.\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            kernel_size: size of the kernel for conv ops.\n            stride: stride to use for conv ops.\n            image_size: input image resolution.\n            expand_ratio: expansion ratio for inverted bottleneck.\n            se_ratio: squeeze-excitation ratio for se layers.\n            id_skip: whether to use skip connection.\n            norm: feature normalization type and arguments. 
Defaults to batch norm.\n            drop_connect_rate: dropconnect rate for drop connection (individual weights) layers.\n\n        References:\n            [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)\n            [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)\n            [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)\n        \"\"\"\n        super().__init__()\n\n        # select the type of N-Dimensional layers to use\n        # these are based on spatial dims and selected from MONAI factories\n        conv_type = Conv[\"conv\", spatial_dims]\n        adaptivepool_type = Pool[\"adaptiveavg\", spatial_dims]\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.id_skip = id_skip\n        self.stride = stride\n        self.expand_ratio = expand_ratio\n        self.drop_connect_rate = drop_connect_rate\n\n        if (se_ratio is not None) and (0.0 < se_ratio <= 1.0):\n            self.has_se = True\n            self.se_ratio = se_ratio\n        else:\n            self.has_se = False\n\n        # Expansion phase (Inverted Bottleneck)\n        inp = in_channels  # number of input channels\n        oup = in_channels * expand_ratio  # number of output channels\n        if self.expand_ratio != 1:\n            self._expand_conv = conv_type(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)\n            self._expand_conv_padding = _make_same_padder(self._expand_conv, image_size)\n\n            self._bn0 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=oup)\n        else:\n            # need to have the following to fix JIT error:\n            #   \"Module 'MBConvBlock' has no attribute '_expand_conv'\"\n\n            # FIXME: find a better way to bypass JIT error\n            self._expand_conv = nn.Identity()\n            self._expand_conv_padding = nn.Identity()\n            self._bn0 = nn.Identity()\n\n        # Depthwise convolution phase\n        self._depthwise_conv = conv_type(\n 
           in_channels=oup,\n            out_channels=oup,\n            groups=oup,  # groups makes it depthwise\n            kernel_size=kernel_size,\n            stride=self.stride,\n            bias=False,\n        )\n        self._depthwise_conv_padding = _make_same_padder(self._depthwise_conv, image_size)\n        self._bn1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=oup)\n        image_size = _calculate_output_image_size(image_size, self.stride)\n\n        # Squeeze and Excitation layer, if desired\n        if self.has_se:\n            self._se_adaptpool = adaptivepool_type(1)\n            num_squeezed_channels = max(1, int(in_channels * self.se_ratio))\n            self._se_reduce = conv_type(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)\n            self._se_reduce_padding = _make_same_padder(self._se_reduce, [1] * spatial_dims)\n            self._se_expand = conv_type(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)\n            self._se_expand_padding = _make_same_padder(self._se_expand, [1] * spatial_dims)\n\n        # Pointwise convolution phase\n        final_oup = out_channels\n        self._project_conv = conv_type(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)\n        self._project_conv_padding = _make_same_padder(self._project_conv, image_size)\n        self._bn2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=final_oup)\n\n        # swish activation to use - using memory efficient swish by default\n        # can be switched to normal swish using self.set_swish() function call\n        self._swish = Act[\"memswish\"](inplace=True)\n\n    def forward(self, inputs: torch.Tensor):\n        \"\"\"MBConvBlock\"s forward function.\n\n        Args:\n            inputs: Input tensor.\n\n        Returns:\n            Output of this block after processing.\n        \"\"\"\n        # Expansion and Depthwise Convolution\n        x = inputs\n        if 
self.expand_ratio != 1:\n            x = self._expand_conv(self._expand_conv_padding(x))\n            x = self._bn0(x)\n            x = self._swish(x)\n\n        x = self._depthwise_conv(self._depthwise_conv_padding(x))\n        x = self._bn1(x)\n        x = self._swish(x)\n\n        # Squeeze and Excitation\n        if self.has_se:\n            x_squeezed = self._se_adaptpool(x)\n            x_squeezed = self._se_reduce(self._se_reduce_padding(x_squeezed))\n            x_squeezed = self._swish(x_squeezed)\n            x_squeezed = self._se_expand(self._se_expand_padding(x_squeezed))\n            x = torch.sigmoid(x_squeezed) * x\n\n        # Pointwise Convolution\n        x = self._project_conv(self._project_conv_padding(x))\n        x = self._bn2(x)\n\n        # Skip connection and drop connect\n        if self.id_skip and self.stride == 1 and self.in_channels == self.out_channels:\n            # the combination of skip connection and drop connect brings about stochastic depth.\n            if self.drop_connect_rate:\n                x = drop_connect(x, p=self.drop_connect_rate, training=self.training)\n            x = x + inputs  # skip connection\n        return x\n\n    def set_swish(self, memory_efficient: bool = True) -> None:\n        \"\"\"Sets swish function as memory efficient (for training) or standard (for export).\n\n        Args:\n            memory_efficient (bool): Whether to use memory-efficient version of swish.\n        \"\"\"\n        self._swish = Act[\"memswish\"](inplace=True) if memory_efficient else Act[\"swish\"](alpha=1.0)\n\n\nclass EfficientNet(nn.Module):\n\n    def __init__(\n        self,\n        blocks_args_str: list[str],\n        spatial_dims: int = 2,\n        in_channels: int = 3,\n        num_classes: int = 1000,\n        width_coefficient: float = 1.0,\n        depth_coefficient: float = 1.0,\n        dropout_rate: float = 0.2,\n        image_size: int = 224,\n        norm: str | tuple = (\"batch\", {\"eps\": 1e-3, 
\"momentum\": 0.01}),\n        drop_connect_rate: float = 0.2,\n        depth_divisor: int = 8,\n    ) -> None:\n        \"\"\"\n        EfficientNet based on `Rethinking Model Scaling for Convolutional Neural Networks <https://arxiv.org/pdf/1905.11946.pdf>`_.\n        Adapted from `EfficientNet-PyTorch <https://github.com/lukemelas/EfficientNet-PyTorch>`_.\n\n        Args:\n            blocks_args_str: block definitions.\n            spatial_dims: number of spatial dimensions.\n            in_channels: number of input channels.\n            num_classes: number of output classes.\n            width_coefficient: width multiplier coefficient (w in paper).\n            depth_coefficient: depth multiplier coefficient (d in paper).\n            dropout_rate: dropout rate for dropout layers.\n            image_size: input image resolution.\n            norm: feature normalization type and arguments.\n            drop_connect_rate: dropconnect rate for drop connection (individual weights) layers.\n            depth_divisor: depth divisor for channel rounding.\n\n        \"\"\"\n        super().__init__()\n\n        if spatial_dims not in (1, 2, 3):\n            raise ValueError(\"spatial_dims can only be 1, 2 or 3.\")\n\n        # select the type of N-Dimensional layers to use\n        # these are based on spatial dims and selected from MONAI factories\n        conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[\"conv\", spatial_dims]\n        adaptivepool_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[\n            \"adaptiveavg\", spatial_dims\n        ]\n\n        # decode blocks args into arguments for MBConvBlock\n        blocks_args = [BlockArgs.from_string(s) for s in blocks_args_str]\n\n        # checks for successful decoding of blocks_args_str\n        if not isinstance(blocks_args, list):\n            raise ValueError(\"blocks_args must be a list\")\n\n        if blocks_args == []:\n            raise 
ValueError(\"block_args must be non-empty\")\n\n        self._blocks_args = blocks_args\n        self.num_classes = num_classes\n        self.in_channels = in_channels\n        self.drop_connect_rate = drop_connect_rate\n\n        # expand input image dimensions to list\n        current_image_size = [image_size] * spatial_dims\n\n        # Stem\n        stride = 2\n        out_channels = _round_filters(32, width_coefficient, depth_divisor)  # number of output channels\n        self._conv_stem = conv_type(self.in_channels, out_channels, kernel_size=3, stride=stride, bias=False)\n        self._conv_stem_padding = _make_same_padder(self._conv_stem, current_image_size)\n        self._bn0 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=out_channels)\n        current_image_size = _calculate_output_image_size(current_image_size, stride)\n\n        # build MBConv blocks\n        num_blocks = 0\n        self._blocks = nn.Sequential()\n\n        self.extract_stacks = []\n\n        # update baseline blocks to input/output filters and number of repeats based on width and depth multipliers.\n        for idx, block_args in enumerate(self._blocks_args):\n            block_args = block_args._replace(\n                input_filters=_round_filters(block_args.input_filters, width_coefficient, depth_divisor),\n                output_filters=_round_filters(block_args.output_filters, width_coefficient, depth_divisor),\n                num_repeat=_round_repeats(block_args.num_repeat, depth_coefficient),\n            )\n            self._blocks_args[idx] = block_args\n\n            # calculate the total number of blocks - needed for drop_connect estimation\n            num_blocks += block_args.num_repeat\n\n            if block_args.stride > 1:\n                self.extract_stacks.append(idx)\n\n        self.extract_stacks.append(len(self._blocks_args))\n\n        # create and add MBConvBlocks to self._blocks\n        idx = 0  # block index counter\n        for stack_idx, 
block_args in enumerate(self._blocks_args):\n            blk_drop_connect_rate = self.drop_connect_rate\n\n            # scale drop connect_rate\n            if blk_drop_connect_rate:\n                blk_drop_connect_rate *= float(idx) / num_blocks\n\n            sub_stack = nn.Sequential()\n            # the first block needs to take care of stride and filter size increase.\n            sub_stack.add_module(\n                str(idx),\n                MBConvBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=block_args.input_filters,\n                    out_channels=block_args.output_filters,\n                    kernel_size=block_args.kernel_size,\n                    stride=block_args.stride,\n                    image_size=current_image_size,\n                    expand_ratio=block_args.expand_ratio,\n                    se_ratio=block_args.se_ratio,\n                    id_skip=block_args.id_skip,\n                    norm=norm,\n                    drop_connect_rate=blk_drop_connect_rate,\n                ),\n            )\n            idx += 1  # increment blocks index counter\n\n            current_image_size = _calculate_output_image_size(current_image_size, block_args.stride)\n            if block_args.num_repeat > 1:  # modify block_args to keep same output size\n                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)\n\n            # add remaining block repeated num_repeat times\n            for _ in range(block_args.num_repeat - 1):\n                blk_drop_connect_rate = self.drop_connect_rate\n\n                # scale drop connect_rate\n                if blk_drop_connect_rate:\n                    blk_drop_connect_rate *= float(idx) / num_blocks\n\n                # add blocks\n                sub_stack.add_module(\n                    str(idx),\n                    MBConvBlock(\n                        spatial_dims=spatial_dims,\n                        
in_channels=block_args.input_filters,\n                        out_channels=block_args.output_filters,\n                        kernel_size=block_args.kernel_size,\n                        stride=block_args.stride,\n                        image_size=current_image_size,\n                        expand_ratio=block_args.expand_ratio,\n                        se_ratio=block_args.se_ratio,\n                        id_skip=block_args.id_skip,\n                        norm=norm,\n                        drop_connect_rate=blk_drop_connect_rate,\n                    ),\n                )\n                idx += 1  # increment blocks index counter\n\n            self._blocks.add_module(str(stack_idx), sub_stack)\n\n        # sanity check to see if len(self._blocks) equal expected num_blocks\n        if idx != num_blocks:\n            raise ValueError(\"total number of blocks created != num_blocks\")\n\n        # Head\n        head_in_channels = block_args.output_filters\n        out_channels = _round_filters(1280, width_coefficient, depth_divisor)\n        self._conv_head = conv_type(head_in_channels, out_channels, kernel_size=1, bias=False)\n        self._conv_head_padding = _make_same_padder(self._conv_head, current_image_size)\n        self._bn1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=out_channels)\n\n        # final linear layer\n        self._avg_pooling = adaptivepool_type(1)\n        self._dropout = nn.Dropout(dropout_rate)\n        self._fc = nn.Linear(out_channels, self.num_classes)\n\n        # swish activation to use - using memory efficient swish by default\n        # can be switched to normal swish using self.set_swish() function call\n        self._swish = Act[\"memswish\"]()\n\n        # initialize weights using Tensorflow's init method from official impl.\n        self._initialize_weights()\n\n    def set_swish(self, memory_efficient: bool = True) -> None:\n        \"\"\"\n        Sets swish function as memory efficient (for training) 
or standard (for JIT export).\n\n        Args:\n            memory_efficient: whether to use memory-efficient version of swish.\n\n        \"\"\"\n        self._swish = Act[\"memswish\"]() if memory_efficient else Act[\"swish\"](alpha=1.0)\n        sub_stack: nn.Sequential\n        block: MBConvBlock\n        for sub_stack in self._blocks:  # type: ignore[assignment]\n            for block in sub_stack:  # type: ignore[assignment]\n                block.set_swish(memory_efficient)\n\n    def forward(self, inputs: torch.Tensor):\n        \"\"\"\n        Args:\n            inputs: input should have spatially N dimensions\n            ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.\n\n        Returns:\n            a torch Tensor of classification prediction in shape ``(Batch, num_classes)``.\n        \"\"\"\n        # Stem\n        x = self._conv_stem(self._conv_stem_padding(inputs))\n        x = self._swish(self._bn0(x))\n        # Blocks\n        x = self._blocks(x)\n        # Head\n        x = self._conv_head(self._conv_head_padding(x))\n        x = self._swish(self._bn1(x))\n\n        # Pooling and final linear layer\n        x = self._avg_pooling(x)\n\n        x = x.flatten(start_dim=1)\n        x = self._dropout(x)\n        x = self._fc(x)\n        return x\n\n    def _initialize_weights(self) -> None:\n        \"\"\"\n        Args:\n            None, initializes weights for conv/linear/batchnorm layers\n            following weight init methods from\n            `official Tensorflow EfficientNet implementation\n            <https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py#L61>`_.\n            Adapted from `EfficientNet-PyTorch's init method\n            <https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/efficientnet_builder.py>`_.\n        \"\"\"\n        for _, m in self.named_modules():\n            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n 
               fan_out = reduce(operator.mul, m.kernel_size, 1) * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):\n                m.weight.data.fill_(1.0)\n                m.bias.data.zero_()\n            elif isinstance(m, nn.Linear):\n                fan_out = m.weight.size(0)\n                fan_in = 0\n                init_range = 1.0 / math.sqrt(fan_in + fan_out)\n                m.weight.data.uniform_(-init_range, init_range)\n                m.bias.data.zero_()\n\n\nclass EfficientNetBN(EfficientNet):\n\n    def __init__(\n        self,\n        model_name: str,\n        pretrained: bool = True,\n        progress: bool = True,\n        spatial_dims: int = 2,\n        in_channels: int = 3,\n        num_classes: int = 1000,\n        norm: str | tuple = (\"batch\", {\"eps\": 1e-3, \"momentum\": 0.01}),\n        adv_prop: bool = False,\n    ) -> None:\n        \"\"\"\n        Generic wrapper around EfficientNet, used to initialize EfficientNet-B0 to EfficientNet-B7 models\n        model_name is mandatory argument as there is no EfficientNetBN itself,\n        it needs the N in [0, 1, 2, 3, 4, 5, 6, 7, 8] to be a model\n\n        Args:\n            model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2].\n            pretrained: whether to initialize pretrained ImageNet weights, only available for spatial_dims=2 and batch\n                norm is used.\n            progress: whether to show download progress for pretrained weights download.\n            spatial_dims: number of spatial dimensions.\n            in_channels: number of input channels.\n            num_classes: number of output classes.\n            norm: feature normalization type and arguments.\n            adv_prop: whether to use weights trained 
with adversarial examples.\n                This argument only works when `pretrained` is `True`.\n\n        Examples::\n\n            # for pretrained spatial 2D ImageNet\n            >>> image_size = get_efficientnet_image_size(\"efficientnet-b0\")\n            >>> inputs = torch.rand(1, 3, image_size, image_size)\n            >>> model = EfficientNetBN(\"efficientnet-b0\", pretrained=True)\n            >>> model.eval()\n            >>> outputs = model(inputs)\n\n            # create spatial 2D\n            >>> model = EfficientNetBN(\"efficientnet-b0\", spatial_dims=2)\n\n            # create spatial 3D\n            >>> model = EfficientNetBN(\"efficientnet-b0\", spatial_dims=3)\n\n            # create EfficientNetB7 for spatial 2D\n            >>> model = EfficientNetBN(\"efficientnet-b7\", spatial_dims=2)\n\n        \"\"\"\n        # block args\n        blocks_args_str = [\n            \"r1_k3_s11_e1_i32_o16_se0.25\",\n            \"r2_k3_s22_e6_i16_o24_se0.25\",\n            \"r2_k5_s22_e6_i24_o40_se0.25\",\n            \"r3_k3_s22_e6_i40_o80_se0.25\",\n            \"r3_k5_s11_e6_i80_o112_se0.25\",\n            \"r4_k5_s22_e6_i112_o192_se0.25\",\n            \"r1_k3_s11_e6_i192_o320_se0.25\",\n        ]\n\n        # check if model_name is valid model\n        if model_name not in efficientnet_params:\n            model_name_string = \", \".join(efficientnet_params.keys())\n            raise ValueError(f\"invalid model_name {model_name} found, must be one of {model_name_string} \")\n\n        # get network parameters\n        weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name]\n\n        # create model and initialize random weights\n        super().__init__(\n            blocks_args_str=blocks_args_str,\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            num_classes=num_classes,\n            width_coefficient=weight_coeff,\n            depth_coefficient=depth_coeff,\n    
        dropout_rate=dropout_rate,\n            image_size=image_size,\n            drop_connect_rate=dropconnect_rate,\n            norm=norm,\n        )\n\n        # only pretrained for when `spatial_dims` is 2\n        if pretrained and (spatial_dims == 2):\n            _load_state_dict(self, model_name, progress, adv_prop)\n\n\nclass EfficientNetBNFeatures(EfficientNet):\n\n    def __init__(\n        self,\n        model_name: str,\n        pretrained: bool = True,\n        progress: bool = True,\n        spatial_dims: int = 2,\n        in_channels: int = 3,\n        num_classes: int = 1000,\n        norm: str | tuple = (\"batch\", {\"eps\": 1e-3, \"momentum\": 0.01}),\n        adv_prop: bool = False,\n    ) -> None:\n        \"\"\"\n        Initialize EfficientNet-B0 to EfficientNet-B7 models as a backbone, the backbone can\n        be used as an encoder for segmentation and object detection models.\n        Compared with the class `EfficientNetBN`, the only difference is the forward function.\n\n        This class refers to `PyTorch image models <https://github.com/rwightman/pytorch-image-models>`_.\n\n        \"\"\"\n        blocks_args_str = [\n            \"r1_k3_s11_e1_i32_o16_se0.25\",\n            \"r2_k3_s22_e6_i16_o24_se0.25\",\n            \"r2_k5_s22_e6_i24_o40_se0.25\",\n            \"r3_k3_s22_e6_i40_o80_se0.25\",\n            \"r3_k5_s11_e6_i80_o112_se0.25\",\n            \"r4_k5_s22_e6_i112_o192_se0.25\",\n            \"r1_k3_s11_e6_i192_o320_se0.25\",\n        ]\n\n        # check if model_name is valid model\n        if model_name not in efficientnet_params:\n            model_name_string = \", \".join(efficientnet_params.keys())\n            raise ValueError(f\"invalid model_name {model_name} found, must be one of {model_name_string} \")\n\n        # get network parameters\n        weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name]\n\n        # create model and initialize random weights\n     
   super().__init__(\n            blocks_args_str=blocks_args_str,\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            num_classes=num_classes,\n            width_coefficient=weight_coeff,\n            depth_coefficient=depth_coeff,\n            dropout_rate=dropout_rate,\n            image_size=image_size,\n            drop_connect_rate=dropconnect_rate,\n            norm=norm,\n        )\n\n        # only pretrained for when `spatial_dims` is 2\n        if pretrained and (spatial_dims == 2):\n            _load_state_dict(self, model_name, progress, adv_prop)\n\n    def forward(self, inputs: torch.Tensor):\n        \"\"\"\n        Args:\n            inputs: input should have spatially N dimensions\n            ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.\n\n        Returns:\n            a list of torch Tensors.\n        \"\"\"\n        # Stem\n        x = self._conv_stem(self._conv_stem_padding(inputs))\n        x = self._swish(self._bn0(x))\n\n        features = []\n        if 0 in self.extract_stacks:\n            features.append(x)\n        for i, block in enumerate(self._blocks):\n            x = block(x)\n            if i + 1 in self.extract_stacks:\n                features.append(x)\n        return features\n\n\nclass EfficientNetEncoder(EfficientNetBNFeatures, BaseEncoder):\n    \"\"\"\n    Wrap the original efficientnet to an encoder for flexible-unet.\n    \"\"\"\n\n    backbone_names = [\n        \"efficientnet-b0\",\n        \"efficientnet-b1\",\n        \"efficientnet-b2\",\n        \"efficientnet-b3\",\n        \"efficientnet-b4\",\n        \"efficientnet-b5\",\n        \"efficientnet-b6\",\n        \"efficientnet-b7\",\n        \"efficientnet-b8\",\n        \"efficientnet-l2\",\n    ]\n\n    @classmethod\n    def get_encoder_parameters(cls) -> list[dict]:\n        \"\"\"\n        Get the initialization parameter for efficientnet backbones.\n        \"\"\"\n        
parameter_list = []\n        for backbone_name in cls.backbone_names:\n            parameter_list.append(\n                {\n                    \"model_name\": backbone_name,\n                    \"pretrained\": True,\n                    \"progress\": True,\n                    \"spatial_dims\": 2,\n                    \"in_channels\": 3,\n                    \"num_classes\": 1000,\n                    \"norm\": (\"batch\", {\"eps\": 1e-3, \"momentum\": 0.01}),\n                    \"adv_prop\": \"ap\" in backbone_name,\n                }\n            )\n        return parameter_list\n\n    @classmethod\n    def num_channels_per_output(cls) -> list[tuple[int, ...]]:\n        \"\"\"\n        Get number of efficientnet backbone output feature maps' channel.\n        \"\"\"\n        return [\n            (16, 24, 40, 112, 320),\n            (16, 24, 40, 112, 320),\n            (16, 24, 48, 120, 352),\n            (24, 32, 48, 136, 384),\n            (24, 32, 56, 160, 448),\n            (24, 40, 64, 176, 512),\n            (32, 40, 72, 200, 576),\n            (32, 48, 80, 224, 640),\n            (32, 56, 88, 248, 704),\n            (72, 104, 176, 480, 1376),\n        ]\n\n    @classmethod\n    def num_outputs(cls) -> list[int]:\n        \"\"\"\n        Get number of efficientnet backbone output feature maps.\n        Since every backbone contains the same 5 output feature maps,\n        the number list should be `[5] * 10`.\n        \"\"\"\n        return [5] * 10\n\n    @classmethod\n    def get_encoder_names(cls) -> list[str]:\n        \"\"\"\n        Get names of efficient backbone.\n        \"\"\"\n        return cls.backbone_names\n\n\ndef get_efficientnet_image_size(model_name: str) -> int:\n    \"\"\"\n    Get the input image size for a given efficientnet model.\n\n    Args:\n        model_name: name of model to initialize, can be from [efficientnet-b0, ..., efficientnet-b7].\n\n    Returns:\n        Image size for single spatial dimension as integer.\n\n    
\"\"\"\n    # check if model_name is valid model\n    if model_name not in efficientnet_params:\n        model_name_string = \", \".join(efficientnet_params.keys())\n        raise ValueError(f\"invalid model_name {model_name} found, must be one of {model_name_string} \")\n\n    # return input image size (all dims equal so only need to return for one dim)\n    _, _, res, _, _ = efficientnet_params[model_name]\n    return res\n\n\ndef drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor:\n    \"\"\"\n    Drop connect layer that drops individual connections.\n    Differs from dropout as dropconnect drops connections instead of whole neurons as in dropout.\n\n    Based on `Deep Networks with Stochastic Depth <https://arxiv.org/pdf/1603.09382.pdf>`_.\n    Adapted from `Official Tensorflow EfficientNet utils\n    <https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/utils.py>`_.\n\n    This function is generalized for MONAI's N-Dimensional spatial activations\n    e.g. 
1D activations [B, C, H], 2D activations [B, C, H, W] and 3D activations [B, C, H, W, D]\n\n    Args:\n        inputs: input tensor with [B, C, dim_1, dim_2, ..., dim_N] where N=spatial_dims.\n        p: probability to use for dropping connections.\n        training: whether in training or evaluation mode.\n\n    Returns:\n        output: output tensor after applying drop connection.\n    \"\"\"\n    if p < 0.0 or p > 1.0:\n        raise ValueError(f\"p must be in range of [0, 1], found {p}\")\n\n    # eval mode: drop_connect is switched off - so return input without modifying\n    if not training:\n        return inputs\n\n    # train mode: calculate and apply drop_connect\n    batch_size: int = inputs.shape[0]\n    keep_prob: float = 1 - p\n    num_dims: int = len(inputs.shape) - 2\n\n    # build dimensions for random tensor, use num_dims to populate appropriate spatial dims\n    random_tensor_shape: list[int] = [batch_size, 1] + [1] * num_dims\n\n    # generate binary_tensor mask according to probability (p for 0, 1-p for 1)\n    random_tensor: torch.Tensor = torch.rand(random_tensor_shape, dtype=inputs.dtype, device=inputs.device)\n    random_tensor += keep_prob\n\n    # round to form binary tensor\n    binary_tensor: torch.Tensor = torch.floor(random_tensor)\n\n    # drop connect using binary tensor\n    output: torch.Tensor = inputs / keep_prob * binary_tensor\n    return output\n\n\ndef _load_state_dict(model: nn.Module, arch: str, progress: bool, adv_prop: bool) -> None:\n    if adv_prop:\n        arch = arch.split(\"efficientnet-\")[-1] + \"-ap\"\n    model_url = look_up_option(arch, url_map, None)\n    if model_url is None:\n        print(f\"pretrained weights of {arch} is not provided\")\n    else:\n        # load state dict from url\n        model_url = url_map[arch]\n        pretrain_state_dict = model_zoo.load_url(model_url, progress=progress)\n        model_state_dict = model.state_dict()\n\n        pattern = 
re.compile(r\"(.+)\\.\\d+(\\.\\d+\\..+)\")\n        for key, value in model_state_dict.items():\n            pretrain_key = re.sub(pattern, r\"\\1\\2\", key)\n            if pretrain_key in pretrain_state_dict and value.shape == pretrain_state_dict[pretrain_key].shape:\n                model_state_dict[key] = pretrain_state_dict[pretrain_key]\n\n        model.load_state_dict(model_state_dict)\n\n\ndef _get_same_padding_conv_nd(\n    image_size: list[int], kernel_size: tuple[int, ...], dilation: tuple[int, ...], stride: tuple[int, ...]\n) -> list[int]:\n    \"\"\"\n    Helper for getting padding (nn.ConstantPadNd) to be used to get SAME padding\n    conv operations similar to Tensorflow's SAME padding.\n\n    This function is generalized for MONAI's N-Dimensional spatial operations (e.g. Conv1D, Conv2D, Conv3D)\n\n    Args:\n        image_size: input image/feature spatial size.\n        kernel_size: conv kernel's spatial size.\n        dilation: conv dilation rate for Atrous conv.\n        stride: stride for conv operation.\n\n    Returns:\n        paddings for ConstantPadNd padder to be used on input tensor to conv op.\n    \"\"\"\n    # get number of spatial dimensions, corresponds to kernel size length\n    num_dims = len(kernel_size)\n\n    # additional checks to populate dilation and stride (in case they are single entry tuples)\n    if len(dilation) == 1:\n        dilation = dilation * num_dims\n\n    if len(stride) == 1:\n        stride = stride * num_dims\n\n    # equation to calculate (pad^+ + pad^-) size\n    _pad_size: list[int] = [\n        max((math.ceil(_i_s / _s) - 1) * _s + (_k_s - 1) * _d + 1 - _i_s, 0)\n        for _i_s, _k_s, _d, _s in zip(image_size, kernel_size, dilation, stride)\n    ]\n    # distribute paddings into pad^+ and pad^- following Tensorflow's same padding strategy\n    _paddings: list[tuple[int, int]] = [(_p // 2, _p - _p // 2) for _p in _pad_size]\n\n    # unroll list of tuples to tuples, and then to list\n    # reversed as 
nn.ConstantPadNd expects paddings starting with last dimension\n    _paddings_ret: list[int] = [outer for inner in reversed(_paddings) for outer in inner]\n    return _paddings_ret\n\n\ndef _make_same_padder(conv_op: nn.Conv1d | nn.Conv2d | nn.Conv3d, image_size: list[int]):\n    \"\"\"\n    Helper for initializing ConstantPadNd with SAME padding similar to Tensorflow.\n    Uses output of _get_same_padding_conv_nd() to get the padding size.\n\n    This function is generalized for MONAI's N-Dimensional spatial operations (e.g. Conv1D, Conv2D, Conv3D)\n\n    Args:\n        conv_op: nn.ConvNd operation to extract parameters for op from\n        image_size: input image/feature spatial size\n\n    Returns:\n        If padding is required, an nn.ConstantPadNd() padder initialized to the paddings; otherwise nn.Identity()\n    \"\"\"\n    # calculate padding required\n    padding: list[int] = _get_same_padding_conv_nd(image_size, conv_op.kernel_size, conv_op.dilation, conv_op.stride)\n\n    # initialize and return padder\n    padder = Pad[\"constantpad\", len(padding) // 2]\n    if sum(padding) > 0:\n        return padder(padding=padding, value=0.0)\n    return nn.Identity()\n\n\ndef _round_filters(filters: int, width_coefficient: float | None, depth_divisor: float) -> int:\n    \"\"\"\n    Calculate and round number of filters based on width coefficient multiplier and depth divisor.\n\n    Args:\n        filters: number of input filters.\n        width_coefficient: width coefficient for model.\n        depth_divisor: depth divisor to use.\n\n    Returns:\n        new_filters: new number of filters after calculation.\n    \"\"\"\n\n    if not width_coefficient:\n        return filters\n\n    multiplier: float = width_coefficient\n    divisor: float = depth_divisor\n    filters_float: float = filters * multiplier\n\n    # follow the formula transferred from official TensorFlow implementation\n    new_filters: float = max(divisor, int(filters_float + divisor / 2) // divisor * 
divisor)\n    if new_filters < 0.9 * filters_float:  # prevent rounding by more than 10%\n        new_filters += divisor\n    return int(new_filters)\n\n\ndef _round_repeats(repeats: int, depth_coefficient: float | None) -> int:\n    \"\"\"\n    Re-calculate module's repeat number of a block based on depth coefficient multiplier.\n\n    Args:\n        repeats: number of original repeats.\n        depth_coefficient: depth coefficient for model.\n\n    Returns:\n        new repeat: new number of repeat after calculating.\n    \"\"\"\n    if not depth_coefficient:\n        return repeats\n\n    # follow the formula transferred from official TensorFlow impl.\n    return int(math.ceil(depth_coefficient * repeats))\n\n\ndef _calculate_output_image_size(input_image_size: list[int], stride: int | tuple[int]):\n    \"\"\"\n    Calculates the output image size when using _make_same_padder with a stride.\n    Required for static padding.\n\n    Args:\n        input_image_size: input image/feature spatial size.\n        stride: Conv2d operation's stride.\n\n    Returns:\n        output_image_size: output image/feature spatial size.\n    \"\"\"\n\n    # checks to extract integer stride in case tuple was received\n    if isinstance(stride, tuple):\n        all_strides_equal = all(stride[0] == s for s in stride)\n        if not all_strides_equal:\n            raise ValueError(f\"unequal strides are not possible, got {stride}\")\n\n        stride = stride[0]\n\n    # return output image size\n    return [int(math.ceil(im_sz / stride)) for im_sz in input_image_size]\n\n\nclass BlockArgs(NamedTuple):\n    \"\"\"\n    BlockArgs object to assist in decoding string notation\n        of arguments for MBConvBlock definition.\n    \"\"\"\n\n    num_repeat: int\n    kernel_size: int\n    stride: int\n    expand_ratio: int\n    input_filters: int\n    output_filters: int\n    id_skip: bool\n    se_ratio: float | None = None\n\n    @staticmethod\n    def from_string(block_string: str):\n    
    \"\"\"\n        Get a BlockArgs object from a string notation of arguments.\n\n        Args:\n            block_string (str): A string notation of arguments.\n                                Examples: \"r1_k3_s11_e1_i32_o16_se0.25\".\n\n        Returns:\n            BlockArgs: namedtuple defined at the top of this function.\n        \"\"\"\n        ops = block_string.split(\"_\")\n        options = {}\n        for op in ops:\n            splits = re.split(r\"(\\d.*)\", op)\n            if len(splits) >= 2:\n                key, value = splits[:2]\n                options[key] = value\n\n        # check stride\n        stride_check = (\n            (\"s\" in options and len(options[\"s\"]) == 1)\n            or (len(options[\"s\"]) == 2 and options[\"s\"][0] == options[\"s\"][1])\n            or (len(options[\"s\"]) == 3 and options[\"s\"][0] == options[\"s\"][1] and options[\"s\"][0] == options[\"s\"][2])\n        )\n        if not stride_check:\n            raise ValueError(\"invalid stride option received\")\n\n        return BlockArgs(\n            num_repeat=int(options[\"r\"]),\n            kernel_size=int(options[\"k\"]),\n            stride=int(options[\"s\"][0]),\n            expand_ratio=int(options[\"e\"]),\n            input_filters=int(options[\"i\"]),\n            output_filters=int(options[\"o\"]),\n            id_skip=(\"noskip\" not in block_string),\n            se_ratio=float(options[\"se\"]) if \"se\" in options else None,\n        )\n\n    def to_string(self):\n        \"\"\"\n        Return a block string notation for current BlockArgs object\n\n        Returns:\n            A string notation of BlockArgs object arguments.\n                Example: \"r1_k3_s11_e1_i32_o16_se0.25_noskip\".\n        \"\"\"\n        string = (\n            f\"r{self.num_repeat}_k{self.kernel_size}_s{self.stride}{self.stride}\"\n            f\"_e{self.expand_ratio}_i{self.input_filters}_o{self.output_filters}\"\n            f\"_se{self.se_ratio}\"\n        )\n\n 
       if not self.id_skip:\n            string += \"_noskip\"\n        return string\n"
  },
  {
    "path": "monai/networks/nets/flexible_unet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Sequence\nfrom pydoc import locate\nfrom typing import Any\n\nimport torch\nfrom torch import nn\n\nfrom monai.networks.blocks import BaseEncoder, UpSample\nfrom monai.networks.layers.factories import Conv\nfrom monai.networks.layers.utils import get_act_layer\nfrom monai.networks.nets import EfficientNetEncoder\nfrom monai.networks.nets.basic_unet import UpCat\nfrom monai.networks.nets.resnet import ResNetEncoder\nfrom monai.utils import InterpolateMode, optional_import\n\n__all__ = [\"FlexibleUNet\", \"FlexUNet\", \"FLEXUNET_BACKBONE\", \"FlexUNetEncoderRegister\"]\n\n\nclass FlexUNetEncoderRegister:\n    \"\"\"\n    A register to regist backbones for the flexible unet. All backbones can be found in\n    register_dict. Please notice each output of backbone must be 2x downsample in spatial\n    dimension of last output. For example, if given a 512x256 2D image and a backbone with\n    4 outputs. Then spatial size of each encoder output should be 256x128, 128x64, 64x32\n    and 32x16.\n    \"\"\"\n\n    def __init__(self):\n        self.register_dict = {}\n\n    def register_class(self, name: type[Any] | str):\n        \"\"\"\n        Register a given class to the encoder dict. 
Please notice that input class must be a\n        subclass of BaseEncoder.\n        \"\"\"\n        if isinstance(name, str):\n            tmp_name, has_built_in = optional_import(\"monai.networks.nets\", name=f\"{name}\")  # search built-in\n            if not has_built_in:\n                tmp_name = locate(f\"{name}\")  # search dotted path\n            name = tmp_name\n            if not isinstance(name, type):\n                raise ValueError(f\"Cannot find {name} class.\")\n\n        if not issubclass(name, BaseEncoder):\n            warnings.warn(\n                f\"{name} would better be derived from monai.networks.blocks.BaseEncoder \"\n                \"or implement all interfaces specified by it.\"\n            )\n\n        name_string_list = name.get_encoder_names()\n        feature_number_list = name.num_outputs()\n        feature_channel_list = name.num_channels_per_output()\n        parameter_list = name.get_encoder_parameters()\n\n        assert len(name_string_list) == len(feature_number_list) == len(feature_channel_list) == len(parameter_list)\n        for cnt, name_string in enumerate(name_string_list):\n            cur_dict = {\n                \"type\": name,\n                \"feature_number\": feature_number_list[cnt],\n                \"feature_channel\": feature_channel_list[cnt],\n                \"parameter\": parameter_list[cnt],\n            }\n            self.register_dict[name_string] = cur_dict\n\n\nFLEXUNET_BACKBONE = FlexUNetEncoderRegister()\nFLEXUNET_BACKBONE.register_class(EfficientNetEncoder)\nFLEXUNET_BACKBONE.register_class(ResNetEncoder)\n\n\nclass UNetDecoder(nn.Module):\n    \"\"\"\n    UNet Decoder.\n    This class refers to `segmentation_models.pytorch\n    <https://github.com/qubvel/segmentation_models.pytorch>`_.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        encoder_channels: number of output channels for all feature maps in encoder.\n            `len(encoder_channels)` should be no less 
than 2.\n        decoder_channels: number of output channels for all feature maps in decoder.\n            `len(decoder_channels)` should equal to `len(encoder_channels) - 1`.\n        act: activation type and arguments.\n        norm: feature normalization type and arguments.\n        dropout: dropout ratio.\n        bias: whether to have a bias term in convolution blocks in this decoder.\n        upsample: upsampling mode, available options are\n            ``\"deconv\"``, ``\"pixelshuffle\"``, ``\"nontrainable\"``.\n        pre_conv: a conv block applied before upsampling.\n            Only used in the \"nontrainable\" or \"pixelshuffle\" mode.\n        interp_mode: {``\"nearest\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``}\n            Only used in the \"nontrainable\" mode.\n        align_corners: set the align_corners parameter for upsample. Defaults to True.\n            Only used in the \"nontrainable\" mode.\n        is_pad: whether to pad upsampling features to fit the encoder spatial dims.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        encoder_channels: Sequence[int],\n        decoder_channels: Sequence[int],\n        act: str | tuple,\n        norm: str | tuple,\n        dropout: float | tuple,\n        bias: bool,\n        upsample: str,\n        pre_conv: str | None,\n        interp_mode: str,\n        align_corners: bool | None,\n        is_pad: bool,\n    ):\n        super().__init__()\n        if len(encoder_channels) < 2:\n            raise ValueError(\"the length of `encoder_channels` should be no less than 2.\")\n        if len(decoder_channels) != len(encoder_channels) - 1:\n            raise ValueError(\"`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.\")\n\n        in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])\n        skip_channels = list(encoder_channels[1:-1][::-1]) + [0]\n        halves = [True] * (len(skip_channels) - 1)\n        
halves.append(False)\n        blocks = []\n        for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves):\n            blocks.append(\n                UpCat(\n                    spatial_dims=spatial_dims,\n                    in_chns=in_chn,\n                    cat_chns=skip_chn,\n                    out_chns=out_chn,\n                    act=act,\n                    norm=norm,\n                    dropout=dropout,\n                    bias=bias,\n                    upsample=upsample,\n                    pre_conv=pre_conv,\n                    interp_mode=interp_mode,\n                    align_corners=align_corners,\n                    halves=halve,\n                    is_pad=is_pad,\n                )\n            )\n        self.blocks = nn.ModuleList(blocks)\n\n    def forward(self, features: list[torch.Tensor], skip_connect: int = 4):\n        skips = features[:-1][::-1]\n        features = features[1:][::-1]\n\n        x = features[0]\n        for i, block in enumerate(self.blocks):\n            if i < skip_connect:\n                skip = skips[i]\n            else:\n                skip = None\n            x = block(x, skip)\n\n        return x\n\n\nclass SegmentationHead(nn.Sequential):\n    \"\"\"\n    Segmentation head.\n    This class refers to `segmentation_models.pytorch\n    <https://github.com/qubvel/segmentation_models.pytorch>`_.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels for the block.\n        out_channels: number of output channels for the block.\n        kernel_size: kernel size for the conv layer.\n        act: activation type and arguments.\n        scale_factor: multiplier for spatial size. 
Has to match input size if it is a tuple.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        kernel_size: int = 3,\n        act: tuple | str | None = None,\n        scale_factor: float = 1.0,\n    ):\n        conv_layer = Conv[Conv.CONV, spatial_dims](\n            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2\n        )\n        up_layer: nn.Module = nn.Identity()\n        if scale_factor > 1.0:\n            up_layer = UpSample(\n                spatial_dims=spatial_dims,\n                scale_factor=scale_factor,\n                mode=\"nontrainable\",\n                pre_conv=None,\n                interp_mode=InterpolateMode.LINEAR,\n            )\n        if act is not None:\n            act_layer = get_act_layer(act)\n        else:\n            act_layer = nn.Identity()\n        super().__init__(conv_layer, up_layer, act_layer)\n\n\nclass FlexibleUNet(nn.Module):\n    \"\"\"\n    A flexible implementation of UNet-like encoder-decoder architecture.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        backbone: str,\n        pretrained: bool = False,\n        decoder_channels: tuple = (256, 128, 64, 32, 16),\n        spatial_dims: int = 2,\n        norm: str | tuple = (\"batch\", {\"eps\": 1e-3, \"momentum\": 0.1}),\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        dropout: float | tuple = 0.0,\n        decoder_bias: bool = False,\n        upsample: str = \"nontrainable\",\n        pre_conv: str = \"default\",\n        interp_mode: str = \"nearest\",\n        is_pad: bool = True,\n    ) -> None:\n        \"\"\"\n        A flexible implement of UNet, in which the backbone/encoder can be replaced with\n        any efficient or residual network. 
Currently the input must have 2 or 3 spatial dimensions\n        and the spatial size of each dimension must be a multiple of 32 if is_pad parameter\n        is False.\n        Please notice each output of backbone must be 2x downsample in spatial dimension\n        of last output. For example, if given a 512x256 2D image and a backbone with 4 outputs.\n        Spatial size of each encoder output should be 256x128, 128x64, 64x32 and 32x16.\n\n        Args:\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            backbone: name of backbones to initialize, only support efficientnet and resnet right now,\n                can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2, resnet10, ..., resnet200].\n            pretrained: whether to initialize pretrained weights. ImageNet weights are available for efficient networks\n                if spatial_dims=2 and batch norm is used. MedicalNet weights are available for residual networks\n                if spatial_dims=3 and in_channels=1. Default to False.\n            decoder_channels: number of output channels for all feature maps in decoder.\n                `len(decoder_channels)` should equal to `len(encoder_channels) - 1`, default\n                to (256, 128, 64, 32, 16).\n            spatial_dims: number of spatial dimensions, default to 2.\n            norm: normalization type and arguments, default to (\"batch\", {\"eps\": 1e-3,\n                \"momentum\": 0.1}).\n            act: activation type and arguments, default to (\"relu\", {\"inplace\": True}).\n            dropout: dropout ratio, default to 0.0.\n            decoder_bias: whether to have a bias term in decoder's convolution blocks.\n            upsample: upsampling mode, available options are ``\"deconv\"``, ``\"pixelshuffle\"``,\n                ``\"nontrainable\"``.\n            pre_conv: a conv block applied before upsampling. 
Only used in the \"nontrainable\" or\n                \"pixelshuffle\" mode, default to `default`.\n            interp_mode: {``\"nearest\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``}\n                Only used in the \"nontrainable\" mode.\n            is_pad: whether to pad upsampling features to fit features from encoder. Default to True.\n                If this parameter is set to \"True\", the spatial dim of network input can be arbitrary\n                size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.\n        \"\"\"\n        super().__init__()\n\n        if backbone not in FLEXUNET_BACKBONE.register_dict:\n            raise ValueError(\n                f\"invalid model_name {backbone} found, must be one of {FLEXUNET_BACKBONE.register_dict.keys()}.\"\n            )\n\n        if spatial_dims not in (2, 3):\n            raise ValueError(\"spatial_dims can only be 2 or 3.\")\n\n        encoder = FLEXUNET_BACKBONE.register_dict[backbone]\n        self.backbone = backbone\n        self.spatial_dims = spatial_dims\n        encoder_parameters = encoder[\"parameter\"]\n        if not (\n            (\"spatial_dims\" in encoder_parameters)\n            and (\"in_channels\" in encoder_parameters)\n            and (\"pretrained\" in encoder_parameters)\n        ):\n            raise ValueError(\"The backbone init method must have spatial_dims, in_channels and pretrained parameters.\")\n        encoder_feature_num = encoder[\"feature_number\"]\n        if encoder_feature_num > 5:\n            raise ValueError(\"Flexible unet can only accept no more than 5 encoder feature maps.\")\n\n        decoder_channels = decoder_channels[:encoder_feature_num]\n        self.skip_connect = encoder_feature_num - 1\n        encoder_parameters.update({\"spatial_dims\": spatial_dims, \"in_channels\": in_channels, \"pretrained\": pretrained})\n        encoder_channels = tuple([in_channels] + 
list(encoder[\"feature_channel\"]))\n        encoder_type = encoder[\"type\"]\n        self.encoder = encoder_type(**encoder_parameters)\n\n        self.decoder = UNetDecoder(\n            spatial_dims=spatial_dims,\n            encoder_channels=encoder_channels,\n            decoder_channels=decoder_channels,\n            act=act,\n            norm=norm,\n            dropout=dropout,\n            bias=decoder_bias,\n            upsample=upsample,\n            interp_mode=interp_mode,\n            pre_conv=pre_conv,\n            align_corners=None,\n            is_pad=is_pad,\n        )\n        self.segmentation_head = SegmentationHead(\n            spatial_dims=spatial_dims,\n            in_channels=decoder_channels[-1],\n            out_channels=out_channels,\n            kernel_size=3,\n            act=None,\n        )\n\n    def forward(self, inputs: torch.Tensor):\n        \"\"\"\n        Do a typical encoder-decoder-header inference.\n\n        Args:\n            inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,\n                N is defined by `spatial_dims`.\n\n        Returns:\n            A torch Tensor of \"raw\" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.\n\n        \"\"\"\n        x = inputs\n        enc_out = self.encoder(x)\n        decoder_out = self.decoder(enc_out, self.skip_connect)\n        x_seg = self.segmentation_head(decoder_out)\n\n        return x_seg\n\n\nFlexUNet = FlexibleUNet\n"
  },
  {
    "path": "monai/networks/nets/fullyconnectednet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import ADN\nfrom monai.networks.layers.factories import Act\n\n__all__ = [\"FullyConnectedNet\", \"VarFullyConnectedNet\"]\n\n\ndef _get_adn_layer(act: tuple | str | None, dropout: tuple | str | float | None, ordering: str | None) -> ADN:\n    if ordering:\n        return ADN(act=act, dropout=dropout, dropout_dim=1, ordering=ordering)\n    return ADN(act=act, dropout=dropout, dropout_dim=1)\n\n\nclass FullyConnectedNet(nn.Sequential):\n    \"\"\"\n    Simple full-connected layer neural network composed of a sequence of linear layers with PReLU activation and\n    dropout.  The network accepts input with `in_channels` channels, has output with `out_channels` channels, and\n    hidden layer output channels given in `hidden_channels`. If `bias` is True then linear units have a bias term.\n\n    Args:\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        hidden_channels: number of output channels for each hidden layer.\n        dropout: dropout ratio. Defaults to no dropout.\n        act: activation type and arguments. Defaults to PReLU.\n        bias: whether to have a bias term in linear units. 
Defaults to True.\n        adn_ordering: order of operations in :py:class:`monai.networks.blocks.ADN`.\n\n    Examples::\n\n        # accepts 4 values and infers 3 values as output, has 3 hidden layers with 10, 20, 10 values as output\n        net = FullyConnectedNet(4, 3, [10, 20, 10], dropout=0.2)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        hidden_channels: Sequence[int],\n        dropout: tuple | str | float | None = None,\n        act: tuple | str | None = Act.PRELU,\n        bias: bool = True,\n        adn_ordering: str | None = None,\n    ) -> None:\n        \"\"\"\n        Defines a network accepting input with `in_channels` channels, output of `out_channels` channels, and hidden\n        layers with channels given in `hidden_channels`. If `bias` is True then linear units have a bias term.\n        \"\"\"\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.hidden_channels = list(hidden_channels)\n        self.act = act\n        self.dropout = dropout\n        self.adn_ordering = adn_ordering\n\n        self.add_module(\"flatten\", nn.Flatten())\n\n        prev_channels = self.in_channels\n        for i, c in enumerate(hidden_channels):\n            self.add_module(f\"hidden_{i}\", self._get_layer(prev_channels, c, bias))\n            prev_channels = c\n\n        self.add_module(\"output\", nn.Linear(prev_channels, out_channels, bias))\n\n    def _get_layer(self, in_channels: int, out_channels: int, bias: bool) -> nn.Sequential:\n        seq = nn.Sequential(\n            nn.Linear(in_channels, out_channels, bias), _get_adn_layer(self.act, self.dropout, self.adn_ordering)\n        )\n        return seq\n\n\nclass VarFullyConnectedNet(nn.Module):\n    \"\"\"\n    Variational fully-connected network. 
This is composed of an encode layer, reparameterization layer, and then a\n    decode layer.\n\n    Args:\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        latent_size: number of latent variables to use.\n        encode_channels: number of output channels for each hidden layer of the encode half.\n        decode_channels: number of output channels for each hidden layer of the decode half.\n        dropout: dropout ratio. Defaults to no dropout.\n        act: activation type and arguments. Defaults to PReLU.\n        bias: whether to have a bias term in linear units. Defaults to True.\n        adn_ordering: order of operations in :py:class:`monai.networks.blocks.ADN`.\n\n    Examples::\n\n        # accepts inputs with 4 values, uses a latent space of 2 variables, and produces outputs of 3 values\n        net = VarFullyConnectedNet(4, 3, 2, [5, 10], [10, 5])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        latent_size: int,\n        encode_channels: Sequence[int],\n        decode_channels: Sequence[int],\n        dropout: tuple | str | float | None = None,\n        act: tuple | str | None = Act.PRELU,\n        bias: bool = True,\n        adn_ordering: str | None = None,\n    ) -> None:\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.latent_size = latent_size\n\n        self.encode = nn.Sequential()\n        self.decode = nn.Sequential()\n        self.flatten = nn.Flatten()\n\n        self.adn_layer = _get_adn_layer(act, dropout, adn_ordering)\n\n        prev_channels = self.in_channels\n        for i, c in enumerate(encode_channels):\n            self.encode.add_module(f\"encode_{i}\", self._get_layer(prev_channels, c, bias))\n            prev_channels = c\n\n        self.mu = nn.Linear(prev_channels, self.latent_size)\n        self.logvar = nn.Linear(prev_channels, 
self.latent_size)\n        self.decodeL = nn.Linear(self.latent_size, prev_channels)\n\n        for i, c in enumerate(decode_channels):\n            self.decode.add_module(f\"decode{i}\", self._get_layer(prev_channels, c, bias))\n            prev_channels = c\n\n        self.decode.add_module(\"final\", nn.Linear(prev_channels, out_channels, bias))\n\n    def _get_layer(self, in_channels: int, out_channels: int, bias: bool) -> nn.Sequential:\n        seq = nn.Sequential(nn.Linear(in_channels, out_channels, bias))\n        seq.add_module(\"ADN\", self.adn_layer)\n        return seq\n\n    def encode_forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        x = self.encode(x)\n        x = self.flatten(x)\n        mu = self.mu(x)\n        logvar = self.logvar(x)\n        return mu, logvar\n\n    def decode_forward(self, z: torch.Tensor, use_sigmoid: bool = True) -> torch.Tensor:\n        x: torch.Tensor\n        x = self.decodeL(z)\n        x = torch.relu(x)\n        x = self.flatten(x)\n        x = self.decode(x)\n        if use_sigmoid:\n            x = torch.sigmoid(x)\n        return x\n\n    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:\n        std = torch.exp(0.5 * logvar)\n\n        if self.training:  # multiply random noise with std only during training\n            std = torch.randn_like(std).mul(std)\n\n        return std.add_(mu)\n\n    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        mu, logvar = self.encode_forward(x)\n        z = self.reparameterize(mu, logvar)\n        return self.decode_forward(z), mu, logvar, z\n"
  },
  {
    "path": "monai/networks/nets/generator.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import Convolution, ResidualUnit\nfrom monai.networks.layers.factories import Act, Norm\nfrom monai.networks.layers.simplelayers import Reshape\nfrom monai.utils import ensure_tuple, ensure_tuple_rep\n\n\nclass Generator(nn.Module):\n    \"\"\"\n    Defines a simple generator network accepting a latent vector and through a sequence of convolution layers\n    constructs an output tensor of greater size and high dimensionality. The method `_get_layer` is used to\n    create each of these layers, override this method to define layers beyond the default\n    :py:class:`monai.networks.blocks.Convolution` or :py:class:`monai.networks.blocks.ResidualUnit` layers.\n\n    The layers are constructed using the values in the `channels` and `strides` arguments, the number being defined by\n    the length of these (which must match). Input is first passed through a :py:class:`torch.nn.Linear` layer to\n    convert the input vector to an image tensor with dimensions `start_shape`. This passes through the convolution\n    layers and is progressively upsampled if the `strides` values are greater than 1 using transpose convolutions. 
The\n    size of the final output is defined by the `start_shape` dimension and the amount of upsampling done through\n    strides. In the default definition the size of the output's spatial dimensions will be that of `start_shape`\n    multiplied by the product of `strides`, thus the example network below upsamples a starting size of (64, 8, 8)\n    to (1, 64, 64) since its `strides` are (2, 2, 2).\n\n    Args:\n        latent_shape: tuple of integers stating the dimension of the input latent vector (minus batch dimension)\n        start_shape: tuple of integers stating the dimension of the tensor to pass to convolution subnetwork\n        channels: tuple of integers stating the output channels of each convolutional layer\n        strides: tuple of integers stating the stride (upscale factor) of each convolutional layer\n        kernel_size: integer or tuple of integers stating size of convolutional kernels\n        num_res_units: integer stating number of convolutions in residual units, 0 means no residual units\n        act: name or type defining activation layers\n        norm: name or type defining normalization layers\n        dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout\n        bias: boolean stating if convolution layers should have a bias component\n\n    Examples::\n\n        # 3 layers, latent input vector of shape (42, 24), output volume of shape (1, 64, 64)\n        net = Generator((42, 24), (64, 8, 8), (32, 16, 1), (2, 2, 2))\n\n    \"\"\"\n\n    def __init__(\n        self,\n        latent_shape: Sequence[int],\n        start_shape: Sequence[int],\n        channels: Sequence[int],\n        strides: Sequence[int],\n        kernel_size: Sequence[int] | int = 3,\n        num_res_units: int = 2,\n        act=Act.PRELU,\n        norm=Norm.INSTANCE,\n        dropout: float | None = None,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n\n        self.in_channels, 
*self.start_shape = ensure_tuple(start_shape)\n        self.dimensions = len(self.start_shape)\n\n        self.latent_shape = ensure_tuple(latent_shape)\n        self.channels = ensure_tuple(channels)\n        self.strides = ensure_tuple(strides)\n        self.kernel_size = ensure_tuple_rep(kernel_size, self.dimensions)\n        self.num_res_units = num_res_units\n        self.act = act\n        self.norm = norm\n        self.dropout = dropout\n        self.bias = bias\n\n        self.flatten = nn.Flatten()\n        self.linear = nn.Linear(int(np.prod(self.latent_shape)), int(np.prod(start_shape)))\n        self.reshape = Reshape(*start_shape)\n        self.conv = nn.Sequential()\n\n        echannel = self.in_channels\n\n        # transform tensor of shape `start_shape` into output shape through transposed convolutions and residual units\n        for i, (c, s) in enumerate(zip(channels, strides)):\n            is_last = i == len(channels) - 1\n            layer = self._get_layer(echannel, c, s, is_last)\n            self.conv.add_module(f\"layer_{i}\", layer)\n            echannel = c\n\n    def _get_layer(\n        self, in_channels: int, out_channels: int, strides: int, is_last: bool\n    ) -> Convolution | nn.Sequential:\n        \"\"\"\n        Returns a layer accepting inputs with `in_channels` number of channels and producing outputs of `out_channels`\n        number of channels. The `strides` indicates upsampling factor, i.e. transpose convolutional stride. 
If `is_last`\n        is True this is the final layer and is not expected to include activation and normalization layers.\n        \"\"\"\n\n        layer: Convolution | nn.Sequential\n\n        layer = Convolution(\n            in_channels=in_channels,\n            strides=strides,\n            is_transposed=True,\n            conv_only=is_last or self.num_res_units > 0,\n            spatial_dims=self.dimensions,\n            out_channels=out_channels,\n            kernel_size=self.kernel_size,\n            act=self.act,\n            norm=self.norm,\n            dropout=self.dropout,\n            bias=self.bias,\n        )\n\n        if self.num_res_units > 0:\n            ru = ResidualUnit(\n                in_channels=out_channels,\n                subunits=self.num_res_units,\n                last_conv_only=is_last,\n                spatial_dims=self.dimensions,\n                out_channels=out_channels,\n                kernel_size=self.kernel_size,\n                act=self.act,\n                norm=self.norm,\n                dropout=self.dropout,\n                bias=self.bias,\n            )\n\n            layer = nn.Sequential(layer, ru)\n\n        return layer\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.flatten(x)\n        x = self.linear(x)\n        x = self.reshape(x)\n        x = self.conv(x)\n        return x\n"
  },
  {
    "path": "monai/networks/nets/highresnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import ADN, Convolution\nfrom monai.networks.layers.simplelayers import ChannelPad\nfrom monai.utils import ChannelMatching\n\n__all__ = [\"HighResBlock\", \"HighResNet\"]\n\nDEFAULT_LAYER_PARAMS_3D = (\n    # initial conv layer\n    {\"name\": \"conv_0\", \"n_features\": 16, \"kernel_size\": 3},\n    # residual blocks\n    {\"name\": \"res_1\", \"n_features\": 16, \"kernels\": (3, 3), \"repeat\": 3},\n    {\"name\": \"res_2\", \"n_features\": 32, \"kernels\": (3, 3), \"repeat\": 3},\n    {\"name\": \"res_3\", \"n_features\": 64, \"kernels\": (3, 3), \"repeat\": 3},\n    # final conv layers\n    {\"name\": \"conv_1\", \"n_features\": 80, \"kernel_size\": 1},\n    {\"name\": \"conv_2\", \"kernel_size\": 1},\n)\n\n\nclass HighResBlock(nn.Module):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        kernels: Sequence[int] = (3, 3),\n        dilation: Sequence[int] | int = 1,\n        norm_type: tuple | str = (\"batch\", {\"affine\": True}),\n        acti_type: tuple | str = (\"relu\", {\"inplace\": True}),\n        bias: bool = False,\n        channel_matching: ChannelMatching | str = ChannelMatching.PAD,\n    ) -> None:\n        \"\"\"\n        Args:\n        
    spatial_dims: number of spatial dimensions of the input image.\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            kernels: each integer k in `kernels` corresponds to a convolution layer with kernel size k.\n            dilation: spacing between kernel elements.\n            norm_type: feature normalization type and arguments.\n                Defaults to ``(\"batch\", {\"affine\": True})``.\n            acti_type: {``\"relu\"``, ``\"prelu\"``, ``\"relu6\"``}\n                Non-linear activation using ReLU or PReLU. Defaults to ``\"relu\"``.\n            bias: whether to have a bias term in convolution blocks. Defaults to False.\n                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,\n                if a conv layer is directly followed by a batch norm layer, bias should be False.\n            channel_matching: {``\"pad\"``, ``\"project\"``}\n                Specifies handling residual branch and conv branch channel mismatches. Defaults to ``\"pad\"``.\n\n                - ``\"pad\"``: with zero padding.\n                - ``\"project\"``: with a trainable conv with kernel size one.\n\n        Raises:\n            ValueError: When ``channel_matching=pad`` and ``in_channels > out_channels``. 
Incompatible values.\n\n        \"\"\"\n        super().__init__()\n        self.chn_pad = ChannelPad(\n            spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, mode=channel_matching\n        )\n\n        layers = nn.ModuleList()\n        _in_chns, _out_chns = in_channels, out_channels\n\n        for kernel_size in kernels:\n            layers.append(\n                ADN(ordering=\"NA\", in_channels=_in_chns, act=acti_type, norm=norm_type, norm_dim=spatial_dims)\n            )\n            layers.append(\n                Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=_in_chns,\n                    out_channels=_out_chns,\n                    kernel_size=kernel_size,\n                    dilation=dilation,\n                    bias=bias,\n                    conv_only=True,\n                )\n            )\n            _in_chns = _out_chns\n\n        self.layers = nn.Sequential(*layers)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x_conv: torch.Tensor = self.layers(x)\n        return x_conv + torch.as_tensor(self.chn_pad(x))\n\n\nclass HighResNet(nn.Module):\n    \"\"\"\n    Reimplementation of highres3dnet based on\n    Li et al., \"On the compactness, efficiency, and representation of 3D\n    convolutional networks: Brain parcellation as a pretext task\", IPMI '17\n\n    Adapted from:\n    https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/network/highres3dnet.py\n    https://github.com/fepegar/highresnet\n\n    Args:\n        spatial_dims: number of spatial dimensions of the input image.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        norm_type: feature normalization type and arguments.\n            Defaults to ``(\"batch\", {\"affine\": True})``.\n        acti_type: activation type and arguments.\n            Defaults to ``(\"relu\", {\"inplace\": True})``.\n        dropout_prob: probability of 
the feature map to be zeroed\n            (only applies to the penultimate conv layer).\n        bias: whether to have a bias term in convolution blocks. Defaults to False.\n            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,\n            if a conv layer is directly followed by a batch norm layer, bias should be False.\n        layer_params: specifying key parameters of each layer/block.\n        channel_matching: {``\"pad\"``, ``\"project\"``}\n            Specifies handling residual branch and conv branch channel mismatches. Defaults to ``\"pad\"``.\n\n            - ``\"pad\"``: with zero padding.\n            - ``\"project\"``: with a trainable conv with kernel size one.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int = 3,\n        in_channels: int = 1,\n        out_channels: int = 1,\n        norm_type: str | tuple = (\"batch\", {\"affine\": True}),\n        acti_type: str | tuple = (\"relu\", {\"inplace\": True}),\n        dropout_prob: tuple | str | float | None = 0.0,\n        bias: bool = False,\n        layer_params: Sequence[dict] = DEFAULT_LAYER_PARAMS_3D,\n        channel_matching: ChannelMatching | str = ChannelMatching.PAD,\n    ) -> None:\n        super().__init__()\n        blocks = nn.ModuleList()\n\n        # initial conv layer\n        params = layer_params[0]\n        _in_chns, _out_chns = in_channels, params[\"n_features\"]\n        blocks.append(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=_in_chns,\n                out_channels=_out_chns,\n                kernel_size=params[\"kernel_size\"],\n                adn_ordering=\"NA\",\n                act=acti_type,\n                norm=norm_type,\n                bias=bias,\n            )\n        )\n\n        # residual blocks\n        for idx, params in enumerate(layer_params[1:-2]):  # res blocks except the 1st and last two conv layers.\n         
   _in_chns, _out_chns = _out_chns, params[\"n_features\"]\n            _dilation = 2**idx\n            for _ in range(params[\"repeat\"]):\n                blocks.append(\n                    HighResBlock(\n                        spatial_dims=spatial_dims,\n                        in_channels=_in_chns,\n                        out_channels=_out_chns,\n                        kernels=params[\"kernels\"],\n                        dilation=_dilation,\n                        norm_type=norm_type,\n                        acti_type=acti_type,\n                        bias=bias,\n                        channel_matching=channel_matching,\n                    )\n                )\n                _in_chns = _out_chns\n\n        # final conv layers\n        params = layer_params[-2]\n        _in_chns, _out_chns = _out_chns, params[\"n_features\"]\n        blocks.append(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=_in_chns,\n                out_channels=_out_chns,\n                kernel_size=params[\"kernel_size\"],\n                adn_ordering=\"NAD\",\n                act=acti_type,\n                norm=norm_type,\n                bias=bias,\n                dropout=dropout_prob,\n            )\n        )\n\n        params = layer_params[-1]\n        _in_chns = _out_chns\n        blocks.append(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=_in_chns,\n                out_channels=out_channels,\n                kernel_size=params[\"kernel_size\"],\n                adn_ordering=\"NAD\",\n                act=acti_type,\n                norm=norm_type,\n                bias=bias,\n                dropout=dropout_prob,\n            )\n        )\n\n        self.blocks = nn.Sequential(*blocks)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return torch.as_tensor(self.blocks(x))\n"
  },
  {
    "path": "monai/networks/nets/hovernet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# =========================================================================\n# Adapted from https://github.com/vqdang/hover_net\n# which has the following license:\n# https://github.com/vqdang/hover_net/blob/master/LICENSE\n# MIT License\n\n# Original publication:\n#  @article{graham2019hover,\n#    title={Hover-net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images},\n#    author={Graham, Simon and Vu, Quoc Dang and Raza, Shan E Ahmed and Azam, Ayesha and Tsang, Yee Wah and Kwak,\n#            Jin Tae and Rajpoot, Nasir},\n#    journal={Medical Image Analysis},\n#    pages={101563},\n#    year={2019},\n#    publisher={Elsevier}\n# }\n# =========================================================================\n\nfrom __future__ import annotations\n\nimport os\nimport re\nimport warnings\nfrom collections import OrderedDict\nfrom collections.abc import Callable, Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.apps.utils import download_url\nfrom monai.networks.blocks import UpSample\nfrom monai.networks.layers.factories import Conv, Dropout\nfrom monai.networks.layers.utils import get_act_layer, get_norm_layer\nfrom monai.utils.enums import HoVerNetBranch, HoVerNetMode, InterpolateMode, UpsampleMode\nfrom monai.utils.module import look_up_option\n\n__all__ = [\"HoVerNet\", \"Hovernet\", \"HoVernet\", 
\"HoVerNet\"]\n\n\nclass _DenseLayerDecoder(nn.Module):\n\n    def __init__(\n        self,\n        num_features: int,\n        in_channels: int,\n        out_channels: int,\n        dropout_prob: float = 0.0,\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n        kernel_size: int = 3,\n        padding: int = 0,\n    ) -> None:\n        \"\"\"\n        Args:\n            num_features: number of internal channels used for the layer\n            in_channels: number of the input channels.\n            out_channels: number of the output channels.\n            dropout_prob: dropout rate after each dense layer.\n            act: activation type and arguments. Defaults to relu.\n            norm: feature normalization type and arguments. Defaults to batch norm.\n            kernel_size: size of the kernel for >1 convolutions (dependent on mode)\n            padding: padding value for >1 convolutions.\n        \"\"\"\n        super().__init__()\n\n        conv_type: Callable = Conv[Conv.CONV, 2]\n        dropout_type: Callable = Dropout[Dropout.DROPOUT, 2]\n\n        self.layers = nn.Sequential()\n\n        self.layers.add_module(\"preact_bna/bn\", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels))\n        self.layers.add_module(\"preact_bna/relu\", get_act_layer(name=act))\n        self.layers.add_module(\"conv1\", conv_type(in_channels, num_features, kernel_size=1, bias=False))\n        self.layers.add_module(\"conv1/norm\", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))\n        self.layers.add_module(\"conv1/relu2\", get_act_layer(name=act))\n        self.layers.add_module(\n            \"conv2\",\n            conv_type(num_features, out_channels, kernel_size=kernel_size, padding=padding, groups=4, bias=False),\n        )\n\n        if dropout_prob > 0:\n            self.layers.add_module(\"dropout\", dropout_type(dropout_prob))\n\n    def forward(self, x: torch.Tensor) -> 
torch.Tensor:\n        x1 = self.layers(x)\n        if x1.shape[-1] != x.shape[-1]:\n            trim = (x.shape[-1] - x1.shape[-1]) // 2\n            x = x[:, :, trim:-trim, trim:-trim]\n\n        x = torch.cat([x, x1], 1)\n\n        return x\n\n\nclass _DecoderBlock(nn.Sequential):\n\n    def __init__(\n        self,\n        layers: int,\n        num_features: int,\n        in_channels: int,\n        out_channels: int,\n        dropout_prob: float = 0.0,\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n        kernel_size: int = 3,\n        same_padding: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            layers: number of layers in the block.\n            num_features: number of internal features used.\n            in_channels: number of the input channel.\n            out_channels: number of the output channel.\n            dropout_prob: dropout rate after each dense layer.\n            act: activation type and arguments. Defaults to relu.\n            norm: feature normalization type and arguments. 
Defaults to batch norm.\n            kernel_size: size of the kernel for >1 convolutions (dependent on mode)\n            same_padding: whether to do padding for >1 convolutions to ensure\n                the output size is the same as the input size.\n        \"\"\"\n        super().__init__()\n\n        conv_type: Callable = Conv[Conv.CONV, 2]\n\n        padding: int = kernel_size // 2 if same_padding else 0\n\n        self.add_module(\n            \"conva\", conv_type(in_channels, in_channels // 4, kernel_size=kernel_size, padding=padding, bias=False)\n        )\n\n        _in_channels = in_channels // 4\n        for i in range(layers):\n            layer = _DenseLayerDecoder(\n                num_features,\n                _in_channels,\n                out_channels,\n                dropout_prob,\n                act=act,\n                norm=norm,\n                kernel_size=kernel_size,\n                padding=padding,\n            )\n            _in_channels += out_channels\n            self.add_module(f\"denselayerdecoder{i + 1}\", layer)\n\n        trans = _Transition(_in_channels, act=act, norm=norm)\n        self.add_module(\"bna_block\", trans)\n        self.add_module(\"convf\", conv_type(_in_channels, _in_channels, kernel_size=1, bias=False))\n\n\nclass _DenseLayer(nn.Sequential):\n\n    def __init__(\n        self,\n        num_features: int,\n        in_channels: int,\n        out_channels: int,\n        dropout_prob: float = 0.0,\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n        drop_first_norm_relu: int = 0,\n        kernel_size: int = 3,\n    ) -> None:\n        \"\"\"Dense Convolutional Block.\n\n        References:\n            Huang, Gao, et al. \"Densely connected convolutional networks.\"\n            Proceedings of the IEEE conference on computer vision and\n            pattern recognition. 
2017.\n\n        Args:\n            num_features: number of internal channels used for the layer\n            in_channels: number of the input channels.\n            out_channels: number of the output channels.\n            dropout_prob: dropout rate after each dense layer.\n            act: activation type and arguments. Defaults to relu.\n            norm: feature normalization type and arguments. Defaults to batch norm.\n            drop_first_norm_relu: omits the first norm/relu for the first layer.\n            kernel_size: size of the kernel for >1 convolutions (dependent on mode)\n        \"\"\"\n        super().__init__()\n\n        self.layers = nn.Sequential()\n        conv_type: Callable = Conv[Conv.CONV, 2]\n        dropout_type: Callable = Dropout[Dropout.DROPOUT, 2]\n\n        if not drop_first_norm_relu:\n            self.layers.add_module(\"preact/bn\", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels))\n            self.layers.add_module(\"preact/relu\", get_act_layer(name=act))\n\n        self.layers.add_module(\"conv1\", conv_type(in_channels, num_features, kernel_size=1, padding=0, bias=False))\n        self.layers.add_module(\"conv1/bn\", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))\n        self.layers.add_module(\"conv1/relu\", get_act_layer(name=act))\n\n        if in_channels != 64 and drop_first_norm_relu:\n            self.layers.add_module(\n                \"conv2\", conv_type(num_features, num_features, kernel_size=kernel_size, stride=2, padding=2, bias=False)\n            )\n        else:\n            self.layers.add_module(\n                \"conv2\", conv_type(num_features, num_features, kernel_size=kernel_size, padding=1, bias=False)\n            )\n\n        self.layers.add_module(\"conv2/bn\", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))\n        self.layers.add_module(\"conv2/relu\", get_act_layer(name=act))\n        self.layers.add_module(\"conv3\", conv_type(num_features, 
out_channels, kernel_size=1, padding=0, bias=False))\n\n        if dropout_prob > 0:\n            self.layers.add_module(\"dropout\", dropout_type(dropout_prob))\n\n\nclass _Transition(nn.Sequential):\n\n    def __init__(\n        self, in_channels: int, act: str | tuple = (\"relu\", {\"inplace\": True}), norm: str | tuple = \"batch\"\n    ) -> None:\n        \"\"\"\n        Args:\n            in_channels: number of the input channel.\n            act: activation type and arguments. Defaults to relu.\n            norm: feature normalization type and arguments. Defaults to batch norm.\n        \"\"\"\n        super().__init__()\n\n        self.add_module(\"bn\", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels))\n        self.add_module(\"relu\", get_act_layer(name=act))\n\n\nclass _ResidualBlock(nn.Module):\n\n    def __init__(\n        self,\n        layers: int,\n        num_features: int,\n        in_channels: int,\n        out_channels: int,\n        dropout_prob: float = 0.0,\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n        freeze_dense_layer: bool = False,\n        freeze_block: bool = False,\n    ) -> None:\n        \"\"\"Residual block.\n\n        References:\n            He, Kaiming, et al. \"Deep residual learning for image\n            recognition.\" Proceedings of the IEEE conference on computer\n            vision and pattern recognition. 2016.\n\n        Args:\n            layers: number of layers in the block.\n            num_features: number of internal features used.\n            in_channels: number of the input channel.\n            out_channels: number of the output channel.\n            dropout_prob: dropout rate after each dense layer.\n            act: activation type and arguments. Defaults to relu.\n            norm: feature normalization type and arguments. 
Defaults to batch norm.\n            freeze_dense_layer: whether to freeze all dense layers within the block.\n            freeze_block: whether to freeze the whole block.\n\n        \"\"\"\n        super().__init__()\n\n        self.layers = nn.Sequential()\n        conv_type: Callable = Conv[Conv.CONV, 2]\n\n        if in_channels == 64:\n            self.shortcut = conv_type(in_channels, out_channels, kernel_size=1, bias=False)\n        else:\n            self.shortcut = conv_type(in_channels, out_channels, kernel_size=1, stride=2, padding=1, bias=False)\n\n        layer = _DenseLayer(\n            num_features, in_channels, out_channels, dropout_prob, act=act, norm=norm, drop_first_norm_relu=True\n        )\n        self.layers.add_module(\"denselayer_0\", layer)\n\n        for i in range(1, layers):\n            layer = _DenseLayer(num_features, out_channels, out_channels, dropout_prob, act=act, norm=norm)\n            self.layers.add_module(f\"denselayer_{i}\", layer)\n\n        self.bna_block = _Transition(out_channels, act=act, norm=norm)\n\n        if freeze_dense_layer:\n            self.layers.requires_grad_(False)\n        if freeze_block:\n            self.requires_grad_(False)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        sc = self.shortcut(x)\n\n        if self.shortcut.stride == (2, 2):\n            sc = sc[:, :, :-1, :-1]\n\n        for layer in self.layers:\n            x = layer.forward(x)\n            if x.shape[-2:] != sc.shape[-2:]:\n                x = x[:, :, :-1, :-1]\n\n            x = x + sc\n            sc = x\n\n        x = self.bna_block(x)\n\n        return x\n\n\nclass _DecoderBranch(nn.ModuleList):\n\n    def __init__(\n        self,\n        decode_config: Sequence[int] = (8, 4),\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n        dropout_prob: float = 0.0,\n        out_channels: int = 2,\n        kernel_size: int = 3,\n        same_padding: bool = 
False,\n    ) -> None:\n        \"\"\"\n        Args:\n            decode_config: number of layers for each block.\n            act: activation type and arguments. Defaults to relu.\n            norm: feature normalization type and arguments. Defaults to batch norm.\n            dropout_prob: dropout rate after each dense layer.\n            out_channels: number of the output channel.\n            kernel_size: size of the kernel for >1 convolutions (dependent on mode)\n            same_padding: whether to do padding for >1 convolutions to ensure\n                the output size is the same as the input size.\n        \"\"\"\n        super().__init__()\n        conv_type: Callable = Conv[Conv.CONV, 2]\n\n        # decode branches\n        _in_channels = 1024\n        _num_features = 128\n        _out_channels = 32\n\n        self.decoder_blocks = nn.Sequential()\n        for i, num_layers in enumerate(decode_config):\n            block = _DecoderBlock(\n                layers=num_layers,\n                num_features=_num_features,\n                in_channels=_in_channels,\n                out_channels=_out_channels,\n                dropout_prob=dropout_prob,\n                act=act,\n                norm=norm,\n                kernel_size=kernel_size,\n                same_padding=same_padding,\n            )\n            self.decoder_blocks.add_module(f\"decoderblock{i + 1}\", block)\n            _in_channels = 512\n\n        # output layers\n        self.output_features = nn.Sequential()\n        _i = len(decode_config)\n        _pad_size = (kernel_size - 1) // 2\n        _seq_block = nn.Sequential(\n            OrderedDict(\n                [(\"conva\", conv_type(256, 64, kernel_size=kernel_size, stride=1, bias=False, padding=_pad_size))]\n            )\n        )\n\n        self.output_features.add_module(f\"decoderblock{_i + 1}\", _seq_block)\n\n        _seq_block = nn.Sequential(\n            OrderedDict(\n                [\n                    (\"bn\", 
get_norm_layer(name=norm, spatial_dims=2, channels=64)),\n                    (\"relu\", get_act_layer(name=act)),\n                    (\"conv\", conv_type(64, out_channels, kernel_size=1, stride=1)),\n                ]\n            )\n        )\n\n        self.output_features.add_module(f\"decoderblock{_i + 2}\", _seq_block)\n\n        self.upsample = UpSample(\n            2, scale_factor=2, mode=UpsampleMode.NONTRAINABLE, interp_mode=InterpolateMode.BILINEAR, bias=False\n        )\n\n    def forward(self, xin: torch.Tensor, short_cuts: list[torch.Tensor]) -> torch.Tensor:\n        block_number = len(short_cuts) - 1\n        x = xin + short_cuts[block_number]\n\n        for block in self.decoder_blocks:\n            x = block(x)\n            x = self.upsample(x)\n            block_number -= 1\n            trim = (short_cuts[block_number].shape[-1] - x.shape[-1]) // 2\n            if trim > 0:\n                x += short_cuts[block_number][:, :, trim:-trim, trim:-trim]\n\n        for block in self.output_features:\n            x = block(x)\n\n        return x\n\n\nclass HoVerNet(nn.Module):\n    \"\"\"HoVerNet model\n\n    References:\n      Graham, Simon et al. Hover-net: Simultaneous segmentation\n      and classification of nuclei in multi-tissue histology images,\n      Medical Image Analysis 2019\n\n      https://github.com/vqdang/hover_net\n      https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html\n\n    This network is non-deterministic since it uses `torch.nn.Upsample` with ``UpsampleMode.NONTRAINABLE`` mode which\n    is implemented with torch.nn.functional.interpolate(). Please check the link below for more details:\n    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms\n\n    Args:\n        mode: use original implementation (`HoVerNetMODE.ORIGINAL` or \"original\") or\n          a faster implementation (`HoVerNetMODE.FAST` or \"fast\"). 
Defaults to `HoVerNetMODE.FAST`.\n        in_channels: number of the input channel.\n        np_out_channels: number of the output channel of the nucleus prediction branch.\n        out_classes: number of the nuclear type classes.\n        act: activation type and arguments. Defaults to relu.\n        norm: feature normalization type and arguments. Defaults to batch norm.\n        decoder_padding: whether to do padding on convolution layers in the decoders. In the conic branch\n            of the referred repository, the architecture is changed to do padding on convolution layers in order to\n            get the same output size as the input, and this changed version is used on CoNIC challenge.\n            Please note that to get consistent output size, `HoVerNetMode.FAST` mode should be employed.\n        dropout_prob: dropout rate after each dense layer.\n        pretrained_url: if specifying, will loaded the pretrained weights downloaded from the url.\n            There are two supported forms of weights:\n            1. preact-resnet50 weights coming from the referred hover_net\n            repository, each user is responsible for checking the content of model/datasets and the applicable licenses\n            and determining if suitable for the intended use. please check the following link for more details:\n            https://github.com/vqdang/hover_net#data-format\n            2. standard resnet50 weights of torchvision. Please check the following link for more details:\n            https://pytorch.org/vision/main/_modules/torchvision/models/resnet.html#ResNet50_Weights\n        adapt_standard_resnet: if the pretrained weights of the encoder follow the original format (preact-resnet50), this\n            value should be `False`. 
If using the pretrained weights that follow torchvision's standard resnet50 format,\n            this value should be `True`.\n        pretrained_state_dict_key: this arg is used when `pretrained_url` is provided and `adapt_standard_resnet` is True.\n            It is used to extract the expected state dict.\n        freeze_encoder: whether to freeze the encoder of the network.\n    \"\"\"\n\n    Mode = HoVerNetMode\n    Branch = HoVerNetBranch\n\n    def __init__(\n        self,\n        mode: HoVerNetMode | str = HoVerNetMode.FAST,\n        in_channels: int = 3,\n        np_out_channels: int = 2,\n        out_classes: int = 0,\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n        decoder_padding: bool = False,\n        dropout_prob: float = 0.0,\n        pretrained_url: str | None = None,\n        adapt_standard_resnet: bool = False,\n        pretrained_state_dict_key: str | None = None,\n        freeze_encoder: bool = False,\n    ) -> None:\n        super().__init__()\n\n        if isinstance(mode, str):\n            mode = mode.upper()\n        self.mode = look_up_option(mode, HoVerNetMode)\n\n        if self.mode == \"ORIGINAL\" and decoder_padding is True:\n            warnings.warn(\n                \"'decoder_padding=True' only works when mode is 'FAST', otherwise the output size may not equal to the input.\"\n            )\n\n        if out_classes > 128:\n            raise ValueError(\"Number of nuclear types classes exceeds maximum (128)\")\n        elif out_classes == 1:\n            raise ValueError(\"Number of nuclear type classes should either be None or >1\")\n\n        if dropout_prob > 1 or dropout_prob < 0:\n            raise ValueError(\"Dropout can only be in the range 0.0 to 1.0\")\n\n        # number of filters in the first convolution layer.\n        _init_features: int = 64\n        # number of layers in each pooling block.\n        _block_config: Sequence[int] = (3, 4, 6, 3)\n\n       
 if self.mode == HoVerNetMode.FAST:\n            _ksize = 3\n            _pad = 3\n        else:\n            _ksize = 5\n            _pad = 0\n\n        conv_type: type[nn.Conv2d] = Conv[Conv.CONV, 2]\n\n        self.conv0 = nn.Sequential(\n            OrderedDict(\n                [\n                    (\"conv\", conv_type(in_channels, _init_features, kernel_size=7, stride=1, padding=_pad, bias=False)),\n                    (\"bn\", get_norm_layer(name=norm, spatial_dims=2, channels=_init_features)),\n                    (\"relu\", get_act_layer(name=act)),\n                ]\n            )\n        )\n\n        _in_channels = _init_features\n        _out_channels = 256\n        _num_features = _init_features\n\n        self.res_blocks = nn.Sequential()\n\n        for i, num_layers in enumerate(_block_config):\n            freeze_dense_layer = False\n            freeze_block = False\n            if freeze_encoder:\n                if i == 0:\n                    freeze_dense_layer = True\n                else:\n                    freeze_block = True\n            block = _ResidualBlock(\n                layers=num_layers,\n                num_features=_num_features,\n                in_channels=_in_channels,\n                out_channels=_out_channels,\n                dropout_prob=dropout_prob,\n                act=act,\n                norm=norm,\n                freeze_dense_layer=freeze_dense_layer,\n                freeze_block=freeze_block,\n            )\n            self.res_blocks.add_module(f\"d{i}\", block)\n\n            _in_channels = _out_channels\n            _out_channels *= 2\n            _num_features *= 2\n\n        # bottleneck convolution\n        self.bottleneck = nn.Sequential()\n        self.bottleneck.add_module(\n            \"conv_bottleneck\", conv_type(_in_channels, _num_features, kernel_size=1, stride=1, padding=0, bias=False)\n        )\n        self.upsample = UpSample(\n            2, scale_factor=2, 
mode=UpsampleMode.NONTRAINABLE, interp_mode=InterpolateMode.BILINEAR, bias=False\n        )\n\n        # decode branches\n        self.nucleus_prediction = _DecoderBranch(\n            kernel_size=_ksize, same_padding=decoder_padding, out_channels=np_out_channels\n        )\n        self.horizontal_vertical = _DecoderBranch(kernel_size=_ksize, same_padding=decoder_padding)\n        self.type_prediction: _DecoderBranch | None = (\n            _DecoderBranch(out_channels=out_classes, kernel_size=_ksize, same_padding=decoder_padding)\n            if out_classes > 0\n            else None\n        )\n\n        for m in self.modules():\n            if isinstance(m, conv_type):\n                nn.init.kaiming_normal_(torch.as_tensor(m.weight))\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(torch.as_tensor(m.weight), 1)\n                nn.init.constant_(torch.as_tensor(m.bias), 0)\n\n        if pretrained_url is not None:\n            if adapt_standard_resnet:\n                weights = _remap_standard_resnet_model(pretrained_url, state_dict_key=pretrained_state_dict_key)\n            else:\n                weights = _remap_preact_resnet_model(pretrained_url)\n            _load_pretrained_encoder(self, weights)\n\n    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:\n        if self.mode == HoVerNetMode.ORIGINAL.value:\n            if x.shape[-1] != 270 or x.shape[-2] != 270:\n                raise ValueError(\"Input size should be 270 x 270 when using HoVerNetMode.ORIGINAL\")\n        else:\n            if x.shape[-1] != 256 or x.shape[-2] != 256:\n                raise ValueError(\"Input size should be 256 x 256 when using HoVerNetMode.FAST\")\n\n        x = self.conv0(x)\n        short_cuts = []\n\n        for i, block in enumerate(self.res_blocks):\n            x = block.forward(x)\n\n            if i <= 2:\n                short_cuts.append(x)\n\n        x = self.bottleneck(x)\n        x = self.upsample(x)\n\n      
  output = {\n            HoVerNetBranch.NP.value: self.nucleus_prediction(x, short_cuts),\n            HoVerNetBranch.HV.value: self.horizontal_vertical(x, short_cuts),\n        }\n        if self.type_prediction is not None:\n            output[HoVerNetBranch.NC.value] = self.type_prediction(x, short_cuts)\n\n        return output\n\n\ndef _load_pretrained_encoder(model: nn.Module, state_dict: OrderedDict | dict):\n    model_dict = model.state_dict()\n    state_dict = {\n        k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)\n    }\n\n    model_dict.update(state_dict)\n    model.load_state_dict(model_dict)\n    if len(state_dict.keys()) == 0:\n        warnings.warn(\n            \"no key will be updated. Please check if 'pretrained_url' or `pretrained_state_dict_key` is correct.\"\n        )\n    else:\n        print(f\"{len(state_dict)} out of {len(model_dict)} keys are updated with pretrained weights.\")\n\n\ndef _remap_preact_resnet_model(model_url: str):\n    pattern_conv0 = re.compile(r\"^(conv0\\.\\/)(.+)$\")\n    pattern_block = re.compile(r\"^(d\\d+)\\.(.+)$\")\n    pattern_layer = re.compile(r\"^(.+\\.d\\d+)\\.units\\.(\\d+)(.+)$\")\n    pattern_bna = re.compile(r\"^(.+\\.d\\d+)\\.blk_bna\\.(.+)\")\n    # download the pretrained weights into torch hub's default dir\n    weights_dir = os.path.join(torch.hub.get_dir(), \"preact-resnet50.pth\")\n    download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)\n    map_location = None if torch.cuda.is_available() else torch.device(\"cpu\")\n    state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)[\"desc\"]\n\n    for key in list(state_dict.keys()):\n        new_key = None\n        if pattern_conv0.match(key):\n            new_key = re.sub(pattern_conv0, r\"conv0.conv\\2\", key)\n        elif pattern_block.match(key):\n            new_key = re.sub(pattern_block, r\"res_blocks.\\1.\\2\", key)\n            if 
pattern_layer.match(new_key):\n                new_key = re.sub(pattern_layer, r\"\\1.layers.denselayer_\\2.layers\\3\", new_key)\n            elif pattern_bna.match(new_key):\n                new_key = re.sub(pattern_bna, r\"\\1.bna_block.\\2\", new_key)\n        if new_key:\n            state_dict[new_key] = state_dict[key]\n            del state_dict[key]\n        if \"upsample2x\" in key:\n            del state_dict[key]\n\n    return state_dict\n\n\ndef _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = None):\n    pattern_conv0 = re.compile(r\"^conv1\\.(.+)$\")\n    pattern_bn1 = re.compile(r\"^bn1\\.(.+)$\")\n    pattern_block = re.compile(r\"^layer(\\d+)\\.(\\d+)\\.(.+)$\")\n    # bn3 to next denselayer's preact/bn\n    pattern_block_bn3 = re.compile(r\"^(res_blocks.d\\d+\\.layers\\.denselayer_)(\\d+)\\.layers\\.bn3\\.(.+)$\")\n    # bn1, bn2 to conv1/bn, conv2/bn\n    pattern_block_bn = re.compile(r\"^(res_blocks.d\\d+\\.layers\\.denselayer_\\d+\\.layers)\\.bn(\\d+)\\.(.+)$\")\n    pattern_downsample0 = re.compile(r\"^(res_blocks.d\\d+).+\\.downsample\\.0\\.(.+)\")\n    pattern_downsample1 = re.compile(r\"^(res_blocks.d\\d+).+\\.downsample\\.1\\.(.+)\")\n    # download the pretrained weights into torch hub's default dir\n    weights_dir = os.path.join(torch.hub.get_dir(), \"resnet50.pth\")\n    download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)\n    map_location = None if torch.cuda.is_available() else torch.device(\"cpu\")\n    state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)\n    if state_dict_key is not None:\n        state_dict = state_dict[state_dict_key]\n\n    for key in list(state_dict.keys()):\n        new_key = None\n        if pattern_conv0.match(key):\n            new_key = re.sub(pattern_conv0, r\"conv0.conv.\\1\", key)\n        elif pattern_bn1.match(key):\n            new_key = re.sub(pattern_bn1, r\"conv0.bn.\\1\", key)\n        elif pattern_block.match(key):\n  
          new_key = re.sub(\n                pattern_block,\n                lambda s: \"res_blocks.d\"\n                + str(int(s.group(1)) - 1)\n                + \".layers.denselayer_\"\n                + s.group(2)\n                + \".layers.\"\n                + s.group(3),\n                key,\n            )\n            if pattern_block_bn3.match(new_key):\n                new_key = re.sub(\n                    pattern_block_bn3,\n                    lambda s: s.group(1) + str(int(s.group(2)) + 1) + \".layers.preact/bn.\" + s.group(3),\n                    new_key,\n                )\n            elif pattern_block_bn.match(new_key):\n                new_key = re.sub(pattern_block_bn, r\"\\1.conv\\2/bn.\\3\", new_key)\n            elif pattern_downsample0.match(new_key):\n                new_key = re.sub(pattern_downsample0, r\"\\1.shortcut.\\2\", new_key)\n            elif pattern_downsample1.match(new_key):\n                new_key = re.sub(pattern_downsample1, r\"\\1.bna_block.bn.\\2\", new_key)\n        if new_key:\n            state_dict[new_key] = state_dict[key]\n            del state_dict[key]\n\n    return state_dict\n\n\nHovernet = HoVernet = HoverNet = HoVerNet\n"
  },
  {
    "path": "monai/networks/nets/masked_autoencoder_vit.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.patchembedding import PatchEmbeddingBlock\nfrom monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding\nfrom monai.networks.blocks.transformerblock import TransformerBlock\nfrom monai.networks.layers import trunc_normal_\nfrom monai.utils import ensure_tuple_rep\nfrom monai.utils.module import look_up_option\n\nSUPPORTED_POS_EMBEDDING_TYPES = {\"none\", \"learnable\", \"sincos\"}\n\n__all__ = [\"MaskedAutoEncoderViT\"]\n\n\nclass MaskedAutoEncoderViT(nn.Module):\n    \"\"\"\n    Masked Autoencoder (ViT), based on: \"Kaiming et al.,\n    Masked Autoencoders Are Scalable Vision Learners <https://arxiv.org/abs/2111.06377>\"\n    Only a subset of the patches passes through the encoder. 
The decoder tries to reconstruct\n    the masked patches, resulting in improved training speed.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        img_size: Sequence[int] | int,\n        patch_size: Sequence[int] | int,\n        hidden_size: int = 768,\n        mlp_dim: int = 512,\n        num_layers: int = 12,\n        num_heads: int = 12,\n        masking_ratio: float = 0.75,\n        decoder_hidden_size: int = 384,\n        decoder_mlp_dim: int = 512,\n        decoder_num_layers: int = 4,\n        decoder_num_heads: int = 12,\n        proj_type: str = \"conv\",\n        pos_embed_type: str = \"sincos\",\n        decoder_pos_embed_type: str = \"sincos\",\n        dropout_rate: float = 0.0,\n        spatial_dims: int = 3,\n        qkv_bias: bool = False,\n        save_attn: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            in_channels: dimension of input channels or the number of channels for input.\n            img_size: dimension of input image.\n            patch_size: dimension of patch size\n            hidden_size: dimension of hidden layer. Defaults to 768.\n            mlp_dim: dimension of feedforward layer. Defaults to 512.\n            num_layers:  number of transformer blocks. Defaults to 12.\n            num_heads: number of attention heads. Defaults to 12.\n            masking_ratio: ratio of patches to be masked. Defaults to 0.75.\n            decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384.\n            decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512.\n            decoder_num_layers: number of transformer blocks for decoder. Defaults to 4.\n            decoder_num_heads: number of attention heads for decoder. Defaults to 12.\n            proj_type: position embedding layer type. Defaults to \"conv\".\n            pos_embed_type: position embedding layer type. 
Defaults to \"sincos\".\n            decoder_pos_embed_type: position embedding layer type for decoder. Defaults to \"sincos\".\n            dropout_rate: fraction of the input units to drop. Defaults to 0.0.\n            spatial_dims: number of spatial dimensions. Defaults to 3.\n            qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.\n            save_attn: to make accessible the attention in self attention block. Defaults to False.\n        Examples::\n            # for single channel input with image size of (96,96,96), and sin-cos positional encoding\n            >>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16),\n            pos_embed_type='sincos')\n            # for 3-channel with image size of (128,128,128) and a learnable positional encoding\n            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable')\n            # for 3-channel with image size of (224,224) and a masking ratio of 0.25\n            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25,\n            spatial_dims=2)\n        \"\"\"\n\n        super().__init__()\n\n        if not (0 <= dropout_rate <= 1):\n            raise ValueError(f\"dropout_rate should be between 0 and 1, got {dropout_rate}.\")\n\n        if hidden_size % num_heads != 0:\n            raise ValueError(\"hidden_size should be divisible by num_heads.\")\n\n        if decoder_hidden_size % decoder_num_heads != 0:\n            raise ValueError(\"decoder_hidden_size should be divisible by decoder_num_heads.\")\n\n        self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)\n        self.img_size = ensure_tuple_rep(img_size, spatial_dims)\n        self.spatial_dims = spatial_dims\n        for m, p in zip(self.img_size, self.patch_size):\n            if m % p != 0:\n                raise ValueError(f\"patch_size={patch_size} should 
be divisible by img_size={img_size}.\")\n\n        self.decoder_hidden_size = decoder_hidden_size\n\n        if masking_ratio <= 0 or masking_ratio >= 1:\n            raise ValueError(f\"masking_ratio should be in the range (0, 1), got {masking_ratio}.\")\n\n        self.masking_ratio = masking_ratio\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))\n\n        self.patch_embedding = PatchEmbeddingBlock(\n            in_channels=in_channels,\n            img_size=img_size,\n            patch_size=patch_size,\n            hidden_size=hidden_size,\n            num_heads=num_heads,\n            proj_type=proj_type,\n            pos_embed_type=pos_embed_type,\n            dropout_rate=dropout_rate,\n            spatial_dims=self.spatial_dims,\n        )\n        blocks = [\n            TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)\n            for _ in range(num_layers)\n        ]\n        self.blocks = nn.Sequential(*blocks, nn.LayerNorm(hidden_size))\n\n        # decoder\n        self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size)\n\n        self.mask_tokens = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))\n\n        self.decoder_pos_embed_type = look_up_option(decoder_pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)\n        self.decoder_pos_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches, decoder_hidden_size))\n\n        decoder_blocks = [\n            TransformerBlock(decoder_hidden_size, decoder_mlp_dim, decoder_num_heads, dropout_rate, qkv_bias, save_attn)\n            for _ in range(decoder_num_layers)\n        ]\n        self.decoder_blocks = nn.Sequential(*decoder_blocks, nn.LayerNorm(decoder_hidden_size))\n        self.decoder_pred = nn.Linear(decoder_hidden_size, int(np.prod(self.patch_size)) * in_channels)\n\n        self._init_weights()\n\n    def _init_weights(self):\n        \"\"\"\n        similar to monai/networks/blocks/patchembedding.py for the decoder 
positional encoding and for mask and\n        classification tokens\n        \"\"\"\n        if self.decoder_pos_embed_type == \"none\":\n            pass\n        elif self.decoder_pos_embed_type == \"learnable\":\n            trunc_normal_(self.decoder_pos_embedding, mean=0.0, std=0.02, a=-2.0, b=2.0)\n        elif self.decoder_pos_embed_type == \"sincos\":\n            grid_size = []\n            for in_size, pa_size in zip(self.img_size, self.patch_size):\n                grid_size.append(in_size // pa_size)\n\n            self.decoder_pos_embedding = build_sincos_position_embedding(\n                grid_size, self.decoder_hidden_size, self.spatial_dims\n            )\n\n        else:\n            raise ValueError(f\"decoder_pos_embed_type {self.decoder_pos_embed_type} not supported.\")\n\n        # initialize patch_embedding like nn.Linear (instead of nn.Conv2d)\n        trunc_normal_(self.mask_tokens, mean=0.0, std=0.02, a=-2.0, b=2.0)\n        trunc_normal_(self.cls_token, mean=0.0, std=0.02, a=-2.0, b=2.0)\n\n    def _masking(self, x, masking_ratio: float | None = None):\n        batch_size, num_tokens, _ = x.shape\n        percentage_to_keep = 1 - masking_ratio if masking_ratio is not None else 1 - self.masking_ratio\n        selected_indices = torch.multinomial(\n            torch.ones(batch_size, num_tokens), int(percentage_to_keep * num_tokens), replacement=False\n        )\n        x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices]  # gather the selected tokens\n        mask = torch.ones(batch_size, num_tokens, dtype=torch.int).to(x.device)\n        mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0\n\n        return x_masked, selected_indices, mask\n\n    def forward(self, x, masking_ratio: float | None = None):\n        x = self.patch_embedding(x)\n        x, selected_indices, mask = self._masking(x, masking_ratio=masking_ratio)\n\n        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)\n        x = 
torch.cat((cls_tokens, x), dim=1)\n\n        x = self.blocks(x)\n\n        # decoder\n        x = self.decoder_embed(x)\n\n        x_ = self.mask_tokens.repeat(x.shape[0], mask.shape[1], 1)\n        x_[torch.arange(x.shape[0]).unsqueeze(-1), selected_indices] = x[:, 1:, :]  # no cls token\n        x_ = x_ + self.decoder_pos_embedding\n        x = torch.cat([x[:, :1, :], x_], dim=1)\n        x = self.decoder_blocks(x)\n        x = self.decoder_pred(x)\n\n        x = x[:, 1:, :]\n        return x, mask\n"
  },
  {
    "path": "monai/networks/nets/mednext.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Portions of this code are derived from the original repository at:\n# https://github.com/MIC-DKFZ/MedNeXt\n# and are used under the terms of the Apache License, Version 2.0.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock\n\n__all__ = [\n    \"MedNeXt\",\n    \"MedNeXtSmall\",\n    \"MedNeXtBase\",\n    \"MedNeXtMedium\",\n    \"MedNeXtLarge\",\n    \"MedNext\",\n    \"MedNextS\",\n    \"MedNeXtS\",\n    \"MedNextSmall\",\n    \"MedNextB\",\n    \"MedNeXtB\",\n    \"MedNextBase\",\n    \"MedNextM\",\n    \"MedNeXtM\",\n    \"MedNextMedium\",\n    \"MedNextL\",\n    \"MedNeXtL\",\n    \"MedNextLarge\",\n]\n\n\nclass MedNeXt(nn.Module):\n    \"\"\"\n    MedNeXt model class from paper: https://arxiv.org/pdf/2303.09975\n\n    Args:\n        spatial_dims: spatial dimension of the input data. Defaults to 3.\n        init_filters: number of output channels for initial convolution layer. Defaults to 32.\n        in_channels: number of input channels for the network. Defaults to 1.\n        out_channels: number of output channels for the network. Defaults to 2.\n        encoder_expansion_ratio: expansion ratio for encoder blocks. 
Defaults to 2.\n        decoder_expansion_ratio: expansion ratio for decoder blocks. Defaults to 2.\n        bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2.\n        kernel_size: kernel size for convolutions. Defaults to 7.\n        deep_supervision: whether to use deep supervision. Defaults to False.\n        use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False.\n        blocks_down: number of blocks in each encoder stage. Defaults to [2, 2, 2, 2].\n        blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2.\n        blocks_up: number of blocks in each decoder stage. Defaults to [2, 2, 2, 2].\n        norm_type: type of normalization layer. Defaults to 'group'.\n        global_resp_norm: whether to use Global Response Normalization. Defaults to False. Refer: https://arxiv.org/abs/2301.00808\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int = 3,\n        init_filters: int = 32,\n        in_channels: int = 1,\n        out_channels: int = 2,\n        encoder_expansion_ratio: Sequence[int] | int = 2,\n        decoder_expansion_ratio: Sequence[int] | int = 2,\n        bottleneck_expansion_ratio: int = 2,\n        kernel_size: int = 7,\n        deep_supervision: bool = False,\n        use_residual_connection: bool = False,\n        blocks_down: Sequence[int] = (2, 2, 2, 2),\n        blocks_bottleneck: int = 2,\n        blocks_up: Sequence[int] = (2, 2, 2, 2),\n        norm_type: str = \"group\",\n        global_resp_norm: bool = False,\n    ):\n        \"\"\"\n        Initialize the MedNeXt model.\n\n        This method sets up the architecture of the model, including:\n        - Stem convolution\n        - Encoder stages and downsampling blocks\n        - Bottleneck blocks\n        - Decoder stages and upsampling blocks\n        - Output blocks for deep supervision (if enabled)\n        \"\"\"\n        super().__init__()\n\n        
self.do_ds = deep_supervision\n        assert spatial_dims in [2, 3], \"`spatial_dims` can only be 2 or 3.\"\n        spatial_dims_str = f\"{spatial_dims}d\"\n        enc_kernel_size = dec_kernel_size = kernel_size\n\n        if isinstance(encoder_expansion_ratio, int):\n            encoder_expansion_ratio = [encoder_expansion_ratio] * len(blocks_down)\n\n        if isinstance(decoder_expansion_ratio, int):\n            decoder_expansion_ratio = [decoder_expansion_ratio] * len(blocks_up)\n\n        conv = nn.Conv2d if spatial_dims_str == \"2d\" else nn.Conv3d\n\n        self.stem = conv(in_channels, init_filters, kernel_size=1)\n\n        enc_stages = []\n        down_blocks = []\n\n        for i, num_blocks in enumerate(blocks_down):\n            enc_stages.append(\n                nn.Sequential(\n                    *[\n                        MedNeXtBlock(\n                            in_channels=init_filters * (2**i),\n                            out_channels=init_filters * (2**i),\n                            expansion_ratio=encoder_expansion_ratio[i],\n                            kernel_size=enc_kernel_size,\n                            use_residual_connection=use_residual_connection,\n                            norm_type=norm_type,\n                            dim=spatial_dims_str,\n                            global_resp_norm=global_resp_norm,\n                        )\n                        for _ in range(num_blocks)\n                    ]\n                )\n            )\n\n            down_blocks.append(\n                MedNeXtDownBlock(\n                    in_channels=init_filters * (2**i),\n                    out_channels=init_filters * (2 ** (i + 1)),\n                    expansion_ratio=encoder_expansion_ratio[i],\n                    kernel_size=enc_kernel_size,\n                    use_residual_connection=use_residual_connection,\n                    norm_type=norm_type,\n                    dim=spatial_dims_str,\n                )\n        
    )\n\n        self.enc_stages = nn.ModuleList(enc_stages)\n        self.down_blocks = nn.ModuleList(down_blocks)\n\n        self.bottleneck = nn.Sequential(\n            *[\n                MedNeXtBlock(\n                    in_channels=init_filters * (2 ** len(blocks_down)),\n                    out_channels=init_filters * (2 ** len(blocks_down)),\n                    expansion_ratio=bottleneck_expansion_ratio,\n                    kernel_size=dec_kernel_size,\n                    use_residual_connection=use_residual_connection,\n                    norm_type=norm_type,\n                    dim=spatial_dims_str,\n                    global_resp_norm=global_resp_norm,\n                )\n                for _ in range(blocks_bottleneck)\n            ]\n        )\n\n        up_blocks = []\n        dec_stages = []\n        for i, num_blocks in enumerate(blocks_up):\n            up_blocks.append(\n                MedNeXtUpBlock(\n                    in_channels=init_filters * (2 ** (len(blocks_up) - i)),\n                    out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),\n                    expansion_ratio=decoder_expansion_ratio[i],\n                    kernel_size=dec_kernel_size,\n                    use_residual_connection=use_residual_connection,\n                    norm_type=norm_type,\n                    dim=spatial_dims_str,\n                    global_resp_norm=global_resp_norm,\n                )\n            )\n\n            dec_stages.append(\n                nn.Sequential(\n                    *[\n                        MedNeXtBlock(\n                            in_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),\n                            out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),\n                            expansion_ratio=decoder_expansion_ratio[i],\n                            kernel_size=dec_kernel_size,\n                            use_residual_connection=use_residual_connection,\n                    
        norm_type=norm_type,\n                            dim=spatial_dims_str,\n                            global_resp_norm=global_resp_norm,\n                        )\n                        for _ in range(num_blocks)\n                    ]\n                )\n            )\n\n        self.up_blocks = nn.ModuleList(up_blocks)\n        self.dec_stages = nn.ModuleList(dec_stages)\n\n        self.out_0 = MedNeXtOutBlock(in_channels=init_filters, n_classes=out_channels, dim=spatial_dims_str)\n\n        if deep_supervision:\n            out_blocks = [\n                MedNeXtOutBlock(in_channels=init_filters * (2**i), n_classes=out_channels, dim=spatial_dims_str)\n                for i in range(1, len(blocks_up) + 1)\n            ]\n\n            out_blocks.reverse()\n            self.out_blocks = nn.ModuleList(out_blocks)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]:\n        \"\"\"\n        Forward pass of the MedNeXt model.\n\n        This method performs the forward pass through the model, including:\n        - Stem convolution\n        - Encoder stages and downsampling\n        - Bottleneck blocks\n        - Decoder stages and upsampling with skip connections\n        - Output blocks for deep supervision (if enabled)\n\n        Args:\n            x (torch.Tensor): Input tensor.\n\n        Returns:\n            torch.Tensor or Sequence[torch.Tensor]: Output tensor(s).\n        \"\"\"\n        # Apply stem convolution\n        x = self.stem(x)\n\n        # Encoder forward pass\n        enc_outputs = []\n        for enc_stage, down_block in zip(self.enc_stages, self.down_blocks):\n            x = enc_stage(x)\n            enc_outputs.append(x)\n            x = down_block(x)\n\n        # Bottleneck forward pass\n        x = self.bottleneck(x)\n\n        # Initialize deep supervision outputs if enabled\n        if self.do_ds:\n            ds_outputs = []\n\n        # Decoder forward pass with skip connections\n        for i, 
(up_block, dec_stage) in enumerate(zip(self.up_blocks, self.dec_stages)):\n            if self.do_ds and i < len(self.out_blocks):\n                ds_outputs.append(self.out_blocks[i](x))\n\n            x = up_block(x)\n            x = x + enc_outputs[-(i + 1)]\n            x = dec_stage(x)\n\n        # Final output block\n        x = self.out_0(x)\n\n        # Return output(s)\n        if self.do_ds and self.training:\n            return (x, *ds_outputs[::-1])\n        else:\n            return x\n\n\n# Define the MedNeXt variants as reported in 10.48550/arXiv.2303.09975\ndef create_mednext(\n    variant: str,\n    spatial_dims: int = 3,\n    in_channels: int = 1,\n    out_channels: int = 2,\n    kernel_size: int = 3,\n    deep_supervision: bool = False,\n) -> MedNeXt:\n    \"\"\"\n    Factory method to create MedNeXt variants.\n\n    Args:\n        variant (str): The MedNeXt variant to create ('S', 'B', 'M', or 'L').\n        spatial_dims (int): Number of spatial dimensions. Defaults to 3.\n        in_channels (int): Number of input channels. Defaults to 1.\n        out_channels (int): Number of output channels. Defaults to 2.\n        kernel_size (int): Kernel size for convolutions. Defaults to 3.\n        deep_supervision (bool): Whether to use deep supervision. 
Defaults to False.\n\n    Returns:\n        MedNeXt: The specified MedNeXt variant.\n\n    Raises:\n        ValueError: If an invalid variant is specified.\n    \"\"\"\n    common_args = {\n        \"spatial_dims\": spatial_dims,\n        \"in_channels\": in_channels,\n        \"out_channels\": out_channels,\n        \"kernel_size\": kernel_size,\n        \"deep_supervision\": deep_supervision,\n        \"use_residual_connection\": True,\n        \"norm_type\": \"group\",\n        \"global_resp_norm\": False,\n        \"init_filters\": 32,\n    }\n\n    if variant.upper() == \"S\":\n        return MedNeXt(\n            encoder_expansion_ratio=2,\n            decoder_expansion_ratio=2,\n            bottleneck_expansion_ratio=2,\n            blocks_down=(2, 2, 2, 2),\n            blocks_bottleneck=2,\n            blocks_up=(2, 2, 2, 2),\n            **common_args,  # type: ignore\n        )\n    elif variant.upper() == \"B\":\n        return MedNeXt(\n            encoder_expansion_ratio=(2, 3, 4, 4),\n            decoder_expansion_ratio=(4, 4, 3, 2),\n            bottleneck_expansion_ratio=4,\n            blocks_down=(2, 2, 2, 2),\n            blocks_bottleneck=2,\n            blocks_up=(2, 2, 2, 2),\n            **common_args,  # type: ignore\n        )\n    elif variant.upper() == \"M\":\n        return MedNeXt(\n            encoder_expansion_ratio=(2, 3, 4, 4),\n            decoder_expansion_ratio=(4, 4, 3, 2),\n            bottleneck_expansion_ratio=4,\n            blocks_down=(3, 4, 4, 4),\n            blocks_bottleneck=4,\n            blocks_up=(4, 4, 4, 3),\n            **common_args,  # type: ignore\n        )\n    elif variant.upper() == \"L\":\n        return MedNeXt(\n            encoder_expansion_ratio=(3, 4, 8, 8),\n            decoder_expansion_ratio=(8, 8, 4, 3),\n            bottleneck_expansion_ratio=8,\n            blocks_down=(3, 4, 8, 8),\n            blocks_bottleneck=8,\n            blocks_up=(8, 8, 4, 3),\n            **common_args,  # type: 
ignore\n        )\n    else:\n        raise ValueError(f\"Invalid MedNeXt variant: {variant}\")\n\n\nMedNext = MedNeXt\nMedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall = lambda **kwargs: create_mednext(\"S\", **kwargs)\nMedNextB = MedNeXtB = MedNextBase = MedNeXtBase = lambda **kwargs: create_mednext(\"B\", **kwargs)\nMedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium = lambda **kwargs: create_mednext(\"M\", **kwargs)\nMedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge = lambda **kwargs: create_mednext(\"L\", **kwargs)\n"
  },
  {
    "path": "monai/networks/nets/milmodel.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import cast\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.utils import optional_import\n\nmodels, _ = optional_import(\"torchvision.models\")\n\n\nclass MILModel(nn.Module):\n    \"\"\"\n    Multiple Instance Learning (MIL) model, with a backbone classification model.\n    Currently, it only works for 2D images, a typical use case is for classification of the\n    digital pathology whole slide images. The expected shape of input data is `[B, N, C, H, W]`,\n    where `B` is the batch_size of PyTorch Dataloader and `N` is the number of instances\n    extracted from every original image in the batch. 
A tutorial example is available at:\n    https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning.\n\n    Args:\n        num_classes: number of output classes.\n        mil_mode: MIL algorithm, available values (Defaults to ``\"att\"``):\n\n            - ``\"mean\"`` - average features from all instances, equivalent to pure CNN (non MIL).\n            - ``\"max\"`` - retain only the instance with the max probability for loss calculation.\n            - ``\"att\"`` - attention based MIL https://arxiv.org/abs/1802.04712.\n            - ``\"att_trans\"`` - transformer MIL https://arxiv.org/abs/2111.01556.\n            - ``\"att_trans_pyramid\"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556.\n\n        pretrained: init backbone with pretrained weights, defaults to ``True``.\n        backbone: Backbone classifier CNN (either ``None``, a ``nn.Module`` that returns features,\n            or a string name of a torchvision model).\n            Defaults to ``None``, in which case ResNet50 is used.\n        backbone_num_features: Number of output features of the backbone CNN\n            Defaults to ``None`` (necessary only when using a custom backbone)\n        trans_blocks: number of the blocks in `TransformEncoder` layer.\n        trans_dropout: dropout rate in `TransformEncoder` layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_classes: int,\n        mil_mode: str = \"att\",\n        pretrained: bool = True,\n        backbone: str | nn.Module | None = None,\n        backbone_num_features: int | None = None,\n        trans_blocks: int = 4,\n        trans_dropout: float = 0.0,\n    ) -> None:\n        super().__init__()\n\n        if num_classes <= 0:\n            raise ValueError(\"Number of classes must be positive: \" + str(num_classes))\n\n        if mil_mode.lower() not in [\"mean\", \"max\", \"att\", \"att_trans\", \"att_trans_pyramid\"]:\n            raise ValueError(\"Unsupported mil_mode: \" + 
str(mil_mode))\n\n        self.mil_mode = mil_mode.lower()\n        self.attention = nn.Sequential()\n        self.transformer: nn.Module | None = None\n\n        if backbone is None:\n            net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None)\n            nfc = net.fc.in_features  # save the number of final features\n            net.fc = torch.nn.Identity()  # remove final linear layer\n\n            self.extra_outputs: dict[str, torch.Tensor] = {}\n\n            if mil_mode == \"att_trans_pyramid\":\n                # register hooks to capture outputs of intermediate layers\n                def forward_hook(layer_name):\n\n                    def hook(module, input, output):\n                        self.extra_outputs[layer_name] = output\n\n                    return hook\n\n                net.layer1.register_forward_hook(forward_hook(\"layer1\"))\n                net.layer2.register_forward_hook(forward_hook(\"layer2\"))\n                net.layer3.register_forward_hook(forward_hook(\"layer3\"))\n                net.layer4.register_forward_hook(forward_hook(\"layer4\"))\n\n        elif isinstance(backbone, str):\n            # assume torchvision model string is provided\n            torch_model = getattr(models, backbone, None)\n            if torch_model is None:\n                raise ValueError(\"Unknown torch vision model\" + str(backbone))\n            net = torch_model(weights=\"DEFAULT\" if pretrained else None)\n\n            if getattr(net, \"fc\", None) is not None:\n                nfc = net.fc.in_features  # save the number of final features\n                net.fc = torch.nn.Identity()  # remove final linear layer\n            else:\n                raise ValueError(\n                    \"Unable to detect FC layer for the torchvision model \" + str(backbone),\n                    \". 
Please initialize the backbone model manually.\",\n                )\n\n        elif isinstance(backbone, nn.Module):\n            # use a custom backbone\n            net = backbone\n            nfc = backbone_num_features\n\n            if backbone_num_features is None:\n                raise ValueError(\"Number of endencoder features must be provided for a custom backbone model\")\n\n        else:\n            raise ValueError(\"Unsupported backbone\")\n\n        if backbone is not None and mil_mode not in [\"mean\", \"max\", \"att\", \"att_trans\"]:\n            raise ValueError(\"Custom backbone is not supported for the mode:\" + str(mil_mode))\n\n        if self.mil_mode in [\"mean\", \"max\"]:\n            pass\n        elif self.mil_mode == \"att\":\n            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))\n\n        elif self.mil_mode == \"att_trans\":\n            transformer = nn.TransformerEncoderLayer(d_model=nfc, nhead=8, dropout=trans_dropout)\n            self.transformer = nn.TransformerEncoder(transformer, num_layers=trans_blocks)\n            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))\n\n        elif self.mil_mode == \"att_trans_pyramid\":\n            transformer_list = nn.ModuleList(\n                [\n                    nn.TransformerEncoder(\n                        nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout), num_layers=trans_blocks\n                    ),\n                    nn.Sequential(\n                        nn.Linear(768, 256),\n                        nn.TransformerEncoder(\n                            nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),\n                            num_layers=trans_blocks,\n                        ),\n                    ),\n                    nn.Sequential(\n                        nn.Linear(1280, 256),\n                        nn.TransformerEncoder(\n                 
           nn.TransformerEncoderLayer(d_model=256, nhead=8, dropout=trans_dropout),\n                            num_layers=trans_blocks,\n                        ),\n                    ),\n                    nn.TransformerEncoder(\n                        nn.TransformerEncoderLayer(d_model=2304, nhead=8, dropout=trans_dropout),\n                        num_layers=trans_blocks,\n                    ),\n                ]\n            )\n            self.transformer = transformer_list\n            nfc = nfc + 256\n            self.attention = nn.Sequential(nn.Linear(nfc, 2048), nn.Tanh(), nn.Linear(2048, 1))\n\n        else:\n            raise ValueError(\"Unsupported mil_mode: \" + str(mil_mode))\n\n        self.myfc = nn.Linear(nfc, num_classes)\n        self.net = net\n\n    def calc_head(self, x: torch.Tensor) -> torch.Tensor:\n        sh = x.shape\n\n        if self.mil_mode == \"mean\":\n            x = self.myfc(x)\n            x = torch.mean(x, dim=1)\n\n        elif self.mil_mode == \"max\":\n            x = self.myfc(x)\n            x, _ = torch.max(x, dim=1)\n\n        elif self.mil_mode == \"att\":\n            a = self.attention(x)\n            a = torch.softmax(a, dim=1)\n            x = torch.sum(x * a, dim=1)\n\n            x = self.myfc(x)\n\n        elif self.mil_mode == \"att_trans\" and self.transformer is not None:\n            x = x.permute(1, 0, 2)\n            x = self.transformer(x)\n            x = x.permute(1, 0, 2)\n\n            a = self.attention(x)\n            a = torch.softmax(a, dim=1)\n            x = torch.sum(x * a, dim=1)\n\n            x = self.myfc(x)\n\n        elif self.mil_mode == \"att_trans_pyramid\" and self.transformer is not None:\n            l1 = torch.mean(self.extra_outputs[\"layer1\"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)\n            l2 = torch.mean(self.extra_outputs[\"layer2\"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)\n            l3 = torch.mean(self.extra_outputs[\"layer3\"], 
dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)\n            l4 = torch.mean(self.extra_outputs[\"layer4\"], dim=(2, 3)).reshape(sh[0], sh[1], -1).permute(1, 0, 2)\n\n            transformer_list = cast(nn.ModuleList, self.transformer)\n\n            x = transformer_list[0](l1)\n            x = transformer_list[1](torch.cat((x, l2), dim=2))\n            x = transformer_list[2](torch.cat((x, l3), dim=2))\n            x = transformer_list[3](torch.cat((x, l4), dim=2))\n\n            x = x.permute(1, 0, 2)\n\n            a = self.attention(x)\n            a = torch.softmax(a, dim=1)\n            x = torch.sum(x * a, dim=1)\n\n            x = self.myfc(x)\n\n        else:\n            raise ValueError(\"Wrong model mode\" + str(self.mil_mode))\n\n        return x\n\n    def forward(self, x: torch.Tensor, no_head: bool = False) -> torch.Tensor:\n        sh = x.shape\n        x = x.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4])\n\n        x = self.net(x)\n        x = x.reshape(sh[0], sh[1], -1)\n\n        if not no_head:\n            x = self.calc_head(x)\n\n        return x\n"
  },
  {
    "path": "monai/networks/nets/netadapter.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nimport torch\n\nfrom monai.networks.layers import Conv, get_pool_layer\nfrom monai.networks.utils import look_up_named_module, set_named_module\nfrom monai.utils import look_up_option, optional_import\n\nget_graph_node_names, _has_utils = optional_import(\"torchvision.models.feature_extraction\", name=\"get_graph_node_names\")\ncreate_feature_extractor, _ = optional_import(\"torchvision.models.feature_extraction\", name=\"create_feature_extractor\")\n\n\nclass NetAdapter(torch.nn.Module):\n    \"\"\"\n    Wrapper to replace the last layer of model by convolutional layer or FC layer.\n\n    See also: :py:class:`monai.networks.nets.TorchVisionFCModel`\n\n    Args:\n        model: a PyTorch model, which can be both 2D and 3D models. typically, it can be a pretrained model\n            in Torchvision, like: ``resnet18``, ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``, etc.\n            more details: https://pytorch.org/vision/stable/models.html.\n        num_classes: number of classes for the last classification layer. Default to 1.\n        dim: number of supported spatial dimensions in the specified model, depends on the model implementation.\n            default to 2 as most Torchvision models are for 2D image processing.\n        in_channels: number of the input channels of last layer. 
if None, get it from `in_features` of last layer.\n        use_conv: whether to use convolutional layer to replace the last layer, default to False.\n        pool: parameters for the pooling layer, it should be a tuple, the first item is name of the pooling layer,\n            the second item is dictionary of the initialization args. if None, will not replace the `layers[-2]`.\n            default to `(\"avg\", {\"kernel_size\": 7, \"stride\": 1})`.\n        bias: the bias value when replacing the last layer. if False, the layer will not learn an additive bias,\n            default to True.\n        fc_name: the corresponding layer attribute of the last fully connected layer. Defaults to ``\"fc\"``.\n        node_name: the corresponding feature extractor node name of `model`.\n            Defaults to \"\", the extractor is not in use.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        model: torch.nn.Module,\n        num_classes: int = 1,\n        dim: int = 2,\n        in_channels: int | None = None,\n        use_conv: bool = False,\n        pool: tuple[str, dict[str, Any]] | None = (\"avg\", {\"kernel_size\": 7, \"stride\": 1}),\n        bias: bool = True,\n        fc_name: str = \"fc\",\n        node_name: str = \"\",\n    ):\n        super().__init__()\n        layers = list(model.children())\n        orig_fc = look_up_named_module(fc_name, model)\n        if orig_fc is None:\n            orig_fc = layers[-1]\n        # guess the number of input channels of the last fully connected layer\n        in_channels_: int\n        if in_channels is None:\n            if not hasattr(orig_fc, \"in_features\"):\n                raise ValueError(\"please specify input channels of the last fully connected layer with `in_channels`.\")\n            in_channels_ = orig_fc.in_features\n\n        else:\n            in_channels_ = in_channels\n\n        # modify the input model, depending on whether to replace the last pooling layer ``pool``\n        if pool is None:  # 
no modification of pooling\n            if node_name != \"\":\n                raise ValueError(\"`node_name` is not compatible with `pool=None`, please set `pool=''`.\")\n            # we just drop the model's fully connected layer or set it to identity\n            if look_up_named_module(fc_name, model):\n                self.features = set_named_module(model, fc_name, torch.nn.Identity())\n            else:\n                self.features = torch.nn.Sequential(*layers[:-1])  # assuming FC is the last and model is sequential\n            self.pool = None\n        else:\n            # user-specified new pooling layer, we drop both the pooling and FC layers from the model\n            if node_name and _has_utils:\n                node_name = look_up_option(node_name, get_graph_node_names(model)[0 if model.training else 1])\n                self.features = create_feature_extractor(model, [node_name])\n            else:\n                self.features = torch.nn.Sequential(*layers[:-2])  # assuming the last 2 layers are pooling&FC\n            self.pool = get_pool_layer(name=pool, spatial_dims=dim)\n\n        # create new fully connected layer or kernel size 1 convolutional layer\n        self.fc: torch.nn.Linear | torch.nn.Conv2d | torch.nn.Conv3d\n        if use_conv:\n            self.fc = Conv[Conv.CONV, dim](in_channels=in_channels_, out_channels=num_classes, kernel_size=1, bias=bias)\n        else:\n            self.fc = torch.nn.Linear(in_features=in_channels_, out_features=num_classes, bias=bias)\n        self.use_conv = use_conv\n        self.dim = dim\n        self.node_name = node_name\n\n    def forward(self, x):\n        x = self.features(x)\n        if isinstance(x, tuple):\n            x = x[0]  # it might be a namedtuple such as torchvision.model.InceptionOutputs\n        elif torch.jit.isinstance(x, dict[str, torch.Tensor]):\n            x = x[self.node_name]  # torchvision create_feature_extractor\n        if self.pool is not None:\n            x = 
self.pool(x)\n        if not self.use_conv:\n            x = torch.flatten(x, 1)\n        else:  # user specified `use_conv` but the pooling layer removed the spatial dims\n            while len(x.shape) < self.dim + 2:\n                x = x[..., None]\n        x = self.fc(x)\n\n        return x\n"
  },
  {
    "path": "monai/networks/nets/patchgan_discriminator.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import Convolution\nfrom monai.networks.layers import Act\nfrom monai.networks.utils import normal_init\n\n\nclass MultiScalePatchDiscriminator(nn.Sequential):\n    \"\"\"\n    Multi-scale Patch-GAN discriminator based on Pix2PixHD:\n    High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585)\n\n    The Multi-scale discriminator made up of several PatchGAN discriminators, that process the images\n    at different spatial scales.\n\n    Args:\n        num_d: number of discriminators\n        num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the first\n            discriminator. 
Each subsequent discriminator has one additional layer, meaning the output size is halved.\n        spatial_dims: number of spatial dimensions (1D, 2D etc.)\n        channels: number of filters in the first convolutional layer (doubled for each subsequent layer)\n        in_channels: number of input channels\n        out_channels: number of output channels in each discriminator\n        kernel_size: kernel size of the convolution layers\n        activation: activation layer type\n        norm: normalisation type\n        bias: introduction of layer bias\n        dropout: probability of dropout applied, defaults to 0.\n        minimum_size_im: minimum spatial size of the input image. Introduced to make sure the architecture\n            requested isn't going to downsample the input image beyond value of 1.\n        last_conv_kernel_size: kernel size of the last convolutional layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_d: int,\n        num_layers_d: int,\n        spatial_dims: int,\n        channels: int,\n        in_channels: int,\n        out_channels: int = 1,\n        kernel_size: int = 4,\n        activation: str | tuple = (Act.LEAKYRELU, {\"negative_slope\": 0.2}),\n        norm: str | tuple = \"BATCH\",\n        bias: bool = False,\n        dropout: float | tuple = 0.0,\n        minimum_size_im: int = 256,\n        last_conv_kernel_size: int = 1,\n    ) -> None:\n        super().__init__()\n        self.num_d = num_d\n        self.num_layers_d = num_layers_d\n        self.num_channels = channels\n        self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims)\n        for i_ in range(self.num_d):\n            num_layers_d_i = self.num_layers_d * (i_ + 1)\n            output_size = float(minimum_size_im) / (2**num_layers_d_i)\n            if output_size < 1:\n                raise AssertionError(\n                    f\"Your image size is too small to take in up to {i_} discriminators with num_layers = {num_layers_d_i}.\"\n     
               \"Please reduce num_layers, reduce num_D or enter bigger images.\"\n                )\n            subnet_d = PatchDiscriminator(\n                spatial_dims=spatial_dims,\n                channels=self.num_channels,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                num_layers_d=num_layers_d_i,\n                kernel_size=kernel_size,\n                activation=activation,\n                norm=norm,\n                bias=bias,\n                padding=self.padding,\n                dropout=dropout,\n                last_conv_kernel_size=last_conv_kernel_size,\n            )\n\n            self.add_module(f\"discriminator_{i_}\", subnet_d)\n\n    def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]:\n        \"\"\"\n        Args:\n            i: Input tensor\n\n        Returns:\n            list of outputs and another list of lists with the intermediate features\n            of each discriminator.\n        \"\"\"\n\n        out: list[torch.Tensor] = []\n        intermediate_features: list[list[torch.Tensor]] = []\n        for disc in self.children():\n            out_d: list[torch.Tensor] = disc(i)\n            out.append(out_d[-1])\n            intermediate_features.append(out_d[:-1])\n\n        return out, intermediate_features\n\n\nclass PatchDiscriminator(nn.Sequential):\n    \"\"\"\n    Patch-GAN discriminator based on Pix2PixHD:\n    High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585)\n\n\n    Args:\n        spatial_dims: number of spatial dimensions (1D, 2D etc.)\n        channels: number of filters in the first convolutional layer (doubled for each subsequent layer)\n        in_channels: number of input channels\n        out_channels: number of output channels\n        num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the discriminator.\n        
kernel_size: kernel size of the convolution layers\n        act: activation type and arguments. Defaults to LeakyReLU.\n        norm: feature normalization type and arguments. Defaults to batch norm.\n        bias: whether to have a bias term in convolution blocks. Defaults to False.\n        padding: padding to be applied to the convolutional layers\n        dropout: proportion of dropout applied, defaults to 0.\n        last_conv_kernel_size: kernel size of the last convolutional layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        channels: int,\n        in_channels: int,\n        out_channels: int = 1,\n        num_layers_d: int = 3,\n        kernel_size: int = 4,\n        activation: str | tuple = (Act.LEAKYRELU, {\"negative_slope\": 0.2}),\n        norm: str | tuple = \"BATCH\",\n        bias: bool = False,\n        padding: int | Sequence[int] = 1,\n        dropout: float | tuple = 0.0,\n        last_conv_kernel_size: int | None = None,\n    ) -> None:\n        super().__init__()\n        self.num_layers_d = num_layers_d\n        self.num_channels = channels\n        if last_conv_kernel_size is None:\n            last_conv_kernel_size = kernel_size\n\n        self.add_module(\n            \"initial_conv\",\n            Convolution(\n                spatial_dims=spatial_dims,\n                kernel_size=kernel_size,\n                in_channels=in_channels,\n                out_channels=channels,\n                act=activation,\n                bias=True,\n                norm=None,\n                dropout=dropout,\n                padding=padding,\n                strides=2,\n            ),\n        )\n\n        input_channels = channels\n        output_channels = channels * 2\n\n        # Initial Layer\n        for l_ in range(self.num_layers_d):\n            if l_ == self.num_layers_d - 1:\n                stride = 1\n            else:\n                stride = 2\n            layer = Convolution(\n                
spatial_dims=spatial_dims,\n                kernel_size=kernel_size,\n                in_channels=input_channels,\n                out_channels=output_channels,\n                act=activation,\n                bias=bias,\n                norm=norm,\n                dropout=dropout,\n                padding=padding,\n                strides=stride,\n            )\n            self.add_module(f\"{l_}\", layer)\n            input_channels = output_channels\n            output_channels = output_channels * 2\n\n        # Final layer\n        self.add_module(\n            \"final_conv\",\n            Convolution(\n                spatial_dims=spatial_dims,\n                kernel_size=last_conv_kernel_size,\n                in_channels=input_channels,\n                out_channels=out_channels,\n                bias=True,\n                conv_only=True,\n                padding=int((last_conv_kernel_size - 1) / 2),\n                dropout=0.0,\n                strides=1,\n            ),\n        )\n\n        self.apply(normal_init)\n\n    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:\n        \"\"\"\n        Args:\n            x: input tensor\n\n        Returns:\n            list of intermediate features, with the last element being the output.\n        \"\"\"\n        out = [x]\n        for submodel in self.children():\n            intermediate_output = submodel(out[-1])\n            out.append(intermediate_output)\n\n        return out[1:]\n"
  },
  {
    "path": "monai/networks/nets/quicknat.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.blocks import ConvDenseBlock, Convolution\nfrom monai.networks.blocks import squeeze_and_excitation as se\nfrom monai.networks.layers.factories import Act, Norm\nfrom monai.networks.layers.simplelayers import SkipConnection\nfrom monai.networks.layers.utils import get_dropout_layer, get_pool_layer\nfrom monai.utils import optional_import\n\n# Lazy import to avoid dependency\nse1, flag = optional_import(\"squeeze_and_excitation\")\n\n__all__ = [\"Quicknat\"]\n\n# QuickNAT specific Blocks\n\n\nclass SkipConnectionWithIdx(SkipConnection):\n    \"\"\"\n    Combine the forward pass input with the result from the given submodule::\n    --+--submodule--o--\n      |_____________|\n    The available modes are ``\"cat\"``, ``\"add\"``, ``\"mul\"``.\n    Defaults to \"cat\" and dimension 1.\n    Inherits from SkipConnection but provides the indizes with each forward pass.\n    \"\"\"\n\n    def forward(self, input, indices):  # type: ignore[override]\n        return super().forward(input), indices\n\n\nclass SequentialWithIdx(nn.Sequential):\n    \"\"\"\n    A sequential container.\n    Modules will be added to it in the order they are passed in the\n    constructor.\n    Own implementation to work with the new 
indices in the forward pass.\n    \"\"\"\n\n    def __init__(self, *args):\n        super().__init__(*args)\n\n    def forward(self, input, indices):  # type: ignore[override]\n        for module in self:\n            input, indices = module(input, indices)\n        return input, indices\n\n\nclass ClassifierBlock(Convolution):\n    \"\"\"\n    Returns a classifier block without an activation function at the top.\n    It consists of a 1 * 1 convolutional layer which maps the input to a num_class channel feature map.\n    The output is a probability map for each of the classes.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of classes to map to.\n        strides: convolution stride. Defaults to 1.\n        kernel_size: convolution kernel size. Defaults to 3.\n        adn_ordering: a string representing the ordering of activation, normalization, and dropout.\n        Defaults to \"NDA\".\n        act: activation type and arguments. Defaults to PReLU.\n\n    \"\"\"\n\n    def __init__(self, spatial_dims, in_channels, out_channels, strides, kernel_size, act=None, adn_ordering=\"A\"):\n        super().__init__(spatial_dims, in_channels, out_channels, strides, kernel_size, adn_ordering, act)\n\n    def forward(self, input: torch.Tensor, weights=None, indices=None):\n        _, channel, *dims = input.size()\n        if weights is not None:\n            weights, _ = torch.max(weights, dim=0)\n            weights = weights.view(1, channel, 1, 1)\n            # use weights to adapt how the classes are weighted.\n            if len(dims) == 2:\n                out_conv = F.conv2d(input, weights)\n            else:\n                raise ValueError(\"Quicknat is a 2D architecture, please check your dimension.\")\n        else:\n            out_conv = super().forward(input)\n        # no indices to return\n        return out_conv, None\n\n\n# Quicknat specific blocks. 
All blocks inherit from MONAI blocks but have adaptions to their structure\nclass ConvConcatDenseBlock(ConvDenseBlock):\n    \"\"\"\n    This dense block is defined as a sequence of 'Convolution' blocks. It overwrite the '_get_layer' methodto change the ordering of\n    Every convolutional layer is preceded by a batch-normalization layer and a Rectifier Linear Unit (ReLU) layer.\n    The first two convolutional layers are followed by a concatenation layer that concatenates\n    the input feature map with outputs of the current and previous convolutional blocks.\n    Kernel size of two convolutional layers kept small to limit number of paramters.\n    Appropriate padding is provided so that the size of feature maps before and after convolution remains constant.\n    The output channels for each convolution layer is set to 64, which acts as a bottle- neck for feature map selectivity.\n    The input channel size is variable, depending on the number of dense connections.\n    The third convolutional layer is also preceded by a batch normalization and ReLU,\n    but has a 1 * 1 kernel size to compress the feature map size to 64.\n    Args:\n        in_channles: variable depending on depth of the network\n        seLayer: Squeeze and Excite block to be included, defaults to None, valid options are {'NONE', 'CSE', 'SSE', 'CSSE'},\n        dropout_layer: Dropout block to be included, defaults to None.\n    :return: forward passed tensor\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        se_layer: nn.Module | None = None,\n        dropout_layer: nn.Dropout2d | None = None,\n        kernel_size: Sequence[int] | int = 5,\n        num_filters: int = 64,\n    ):\n        self.count = 0\n        super().__init__(\n            in_channels=in_channels,\n            spatial_dims=2,\n            # number of channels stay constant throughout the convolution layers\n            channels=[num_filters, num_filters, num_filters],\n            
norm=(\"instance\", {\"num_features\": in_channels}),\n            kernel_size=kernel_size,\n        )\n        self.se_layer = se_layer if se_layer is not None else nn.Identity()\n        self.dropout_layer = dropout_layer if dropout_layer is not None else nn.Identity()\n\n    def _get_layer(self, in_channels, out_channels, dilation):\n        \"\"\"\n        After ever convolutional layer the output is concatenated with the input and the layer before.\n        The concatenated output is used as input to the next convolutional layer.\n\n        Args:\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            strides: convolution stride.\n            is_top: True if this is the top block.\n        \"\"\"\n        kernelsize = self.kernel_size if self.count < 2 else (1, 1)\n        # padding = None if self.count < 2 else (0, 0)\n        self.count += 1\n        conv = Convolution(\n            spatial_dims=self.spatial_dims,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            strides=1,\n            kernel_size=kernelsize,\n            act=self.act,\n            norm=(\"instance\", {\"num_features\": in_channels}),\n        )\n        return nn.Sequential(conv.get_submodule(\"adn\"), conv.get_submodule(\"conv\"))\n\n    def forward(self, input, _):  # type: ignore[override]\n        i = 0\n        result = input\n        result1 = input  # this will not stay this value, needed here for pylint/mypy\n\n        for l in self.children():\n            # ignoring the max (un-)pool and droupout already added in the initial initialization step\n            if isinstance(l, (nn.MaxPool2d, nn.MaxUnpool2d, nn.Dropout2d)):\n                continue\n            # first convolutional forward\n            result = l(result)\n            if i == 0:\n                result1 = result\n                # concatenation with the input feature map\n                result = torch.cat((input, 
result), dim=1)\n\n            if i == 1:\n                # concatenation with input feature map and feature map from first convolution\n                result = torch.cat((result1, result, input), dim=1)\n            i = i + 1\n\n        # if SELayer or Dropout layer defined put output through layer before returning,\n        # else it just goes through nn.Identity and the output does not change\n        result = self.se_layer(result)\n        result = self.dropout_layer(result)\n\n        return result, None\n\n\nclass Encoder(ConvConcatDenseBlock):\n    \"\"\"\n    Returns a convolution dense block for the encoding (down) part of a layer of the network.\n    This Encoder block downpools the data with max_pool.\n    Its output is used as input to the next layer down.\n    New feature: it returns the indices of the max_pool to the decoder (up) path\n    at the same layer to upsample the input.\n\n    Args:\n        in_channels: number of input channels.\n        max_pool: predefined max_pool layer to downsample the data.\n        se_layer: Squeeze and Excite block to be included, defaults to None.\n        dropout: Dropout block to be included, defaults to None.\n        kernel_size : kernel size of the convolutional layers. Defaults to 5*5\n        num_filters : number of input channels to each convolution block. 
Defaults to 64\n    \"\"\"\n\n    def __init__(self, in_channels: int, max_pool, se_layer, dropout, kernel_size, num_filters):\n        super().__init__(in_channels, se_layer, dropout, kernel_size, num_filters)\n        self.max_pool = max_pool\n\n    def forward(self, input, indices=None):  # type: ignore[override]\n        input, indices = self.max_pool(input)\n\n        out_block, _ = super().forward(input, None)\n        # safe the indices for unpool on decoder side\n        return out_block, indices\n\n\nclass Decoder(ConvConcatDenseBlock):\n    \"\"\"\n    Returns a convolution dense block for the decoding (up) part of a layer of the network.\n    This will upsample data with an unpool block before the forward.\n    It uses the indices from corresponding encoder on it's level.\n    Its output is used as input to the next layer up.\n\n    Args:\n        in_channels: number of input channels.\n        un_pool: predefined unpool block.\n        se_layer: predefined SELayer. Defaults to None.\n        dropout: predefined dropout block. Defaults to None.\n        kernel_size: Kernel size of convolution layers. Defaults to 5*5.\n        num_filters: number of input channels to each convolution layer. 
Defaults to 64.\n    \"\"\"\n\n    def __init__(self, in_channels: int, un_pool, se_layer, dropout, kernel_size, num_filters):\n        super().__init__(in_channels, se_layer, dropout, kernel_size, num_filters)\n        self.un_pool = un_pool\n\n    def forward(self, input, indices):  # type: ignore[override]\n        out_block, _ = super().forward(input, None)\n        out_block = self.un_pool(out_block, indices)\n        return out_block, None\n\n\nclass Bottleneck(ConvConcatDenseBlock):\n    \"\"\"\n    Returns the bottom or bottleneck layer at the bottom of a network linking encoder to decoder halves.\n    It consists of a 5 * 5 convolutional layer and a batch normalization layer to separate\n    the encoder and decoder part of the network, restricting information flow between the encoder and decoder.\n\n    Args:\n        in_channels: number of input channels.\n        se_layer: predefined SELayer. Defaults to None.\n        dropout: predefined dropout block. Defaults to None.\n        un_pool: predefined unpool block.\n        max_pool: predefined maxpool block.\n        kernel_size: Kernel size of convolution layers. Defaults to 5*5.\n        num_filters: number of input channels to each convolution layer. 
Defaults to 64.\n    \"\"\"\n\n    def __init__(self, in_channels: int, se_layer, dropout, max_pool, un_pool, kernel_size, num_filters):\n        super().__init__(in_channels, se_layer, dropout, kernel_size, num_filters)\n        self.max_pool = max_pool\n        self.un_pool = un_pool\n\n    def forward(self, input, indices):  # type: ignore[override]\n        out_block, indices = self.max_pool(input)\n        out_block, _ = super().forward(out_block, None)\n        out_block = self.un_pool(out_block, indices)\n        return out_block, None\n\n\nclass Quicknat(nn.Module):\n    \"\"\"\n    Model for \"Quick segmentation of NeuroAnaTomy (QuickNAT) based on a deep fully convolutional neural network.\n    Refer to: \"QuickNAT: A Fully Convolutional Network for Quick and Accurate Segmentation of Neuroanatomy by\n    Abhijit Guha Roya, Sailesh Conjetib, Nassir Navabb, Christian Wachingera\"\n\n    QuickNAT has an encoder/decoder like 2D F-CNN architecture with 4 encoders and 4 decoders separated by a bottleneck layer.\n    The final layer is a classifier block with softmax.\n    The architecture includes skip connections between all encoder and decoder blocks of the same spatial resolution,\n    similar to the U-Net architecture.\n    All Encoder and Decoder consist of three convolutional layers all with a Batch Normalization and ReLU.\n    The first two convolutional layers are followed by a concatenation layer that concatenates\n    the input feature map with outputs of the current and previous convolutional blocks.\n    The kernel size of the first two convolutional layers is 5*5, the third convolutional layer has a kernel size of 1*1.\n\n    Data in the encode path is downsampled using max pooling layers instead of upsamling like UNet and in the decode path\n    upsampled using max un-pooling layers instead of transpose convolutions.\n    The pooling is done at the beginning of the block and the unpool afterwards.\n    The indices of the max pooling in the Encoder 
are forwarded through the layer to be available to the corresponding Decoder.\n\n    The bottleneck block consists of a 5 * 5 convolutional layer and a batch normalization layer\n    to separate the encoder and decoder part of the network,\n    restricting information flow between the encoder and decoder.\n\n    The output feature map from the last decoder block is passed to the classifier block,\n    which is a convolutional layer with 1 * 1 kernel size that maps the input to an N channel feature map,\n    where N is the number of segmentation classes.\n\n    To further explain this consider the first example network given below. This network has 3 layers with strides\n    of 2 for each of the middle layers (the last layer is the bottom connection which does not down/up sample). Input\n    data to this network is immediately reduced in the spatial dimensions by a factor of 2 by the first convolution of\n    the residual unit defining the first layer of the encode part. The last layer of the decode part will upsample its\n    input (data from the previous layer concatenated with data from the skip connection) in the first convolution. this\n    ensures the final output of the network has the same shape as the input.\n\n    The original QuickNAT implementation included a `enable_test_dropout()` mechanism for uncertainty estimation during\n    testing. As the dropout layers are the only stochastic components of this network calling the train() method instead\n    of eval() in testing or inference has the same effect.\n\n    Args:\n        num_classes: number of classes to segmentate (output channels).\n        num_channels: number of input channels.\n        num_filters: number of output channels for each convolutional layer in a Dense Block.\n        kernel_size: size of the kernel of each convolutional layer in a Dense Block.\n        kernel_c: convolution kernel size of classifier block kernel.\n        stride_convolution: convolution stride. 
Defaults to 1.\n        pool: kernel size of the pooling layer,\n        stride_pool: stride for the pooling layer.\n        se_block: Squeeze and Excite block type to be included, defaults to None. Valid options : NONE, CSE, SSE, CSSE,\n        droup_out: dropout ratio. Defaults to no dropout.\n        act: activation type and arguments. Defaults to PReLU.\n        norm: feature normalization type and arguments. Defaults to instance norm.\n        adn_ordering: a string representing the ordering of activation (A), normalization (N), and dropout (D).\n            Defaults to \"NA\". See also: :py:class:`monai.networks.blocks.ADN`.\n\n    Examples::\n\n        from monai.networks.nets import QuickNAT\n\n        # network with max pooling by a factor of 2 at each layer with no se_block.\n        net = QuickNAT(\n            num_classes=3,\n            num_channels=1,\n            num_filters=64,\n            pool = 2,\n            se_block = \"None\"\n        )\n\n    \"\"\"\n\n    def __init__(\n        self,\n        num_classes: int = 33,\n        num_channels: int = 1,\n        num_filters: int = 64,\n        kernel_size: Sequence[int] | int = 5,\n        kernel_c: int = 1,\n        stride_conv: int = 1,\n        pool: int = 2,\n        stride_pool: int = 2,\n        # Valid options : NONE, CSE, SSE, CSSE\n        se_block: str = \"None\",\n        drop_out: float = 0,\n        act: tuple | str = Act.PRELU,\n        norm: tuple | str = Norm.INSTANCE,\n        adn_ordering: str = \"NA\",\n    ) -> None:\n        self.act = act\n        self.norm = norm\n        self.adn_ordering = adn_ordering\n        super().__init__()\n        se_layer = self.get_selayer(num_filters, se_block)\n        dropout_layer = get_dropout_layer(name=(\"dropout\", {\"p\": drop_out}), dropout_dim=2)\n        max_pool = get_pool_layer(\n            name=(\"max\", {\"kernel_size\": pool, \"stride\": stride_pool, \"return_indices\": True, \"ceil_mode\": True}),\n            spatial_dims=2,\n 
       )\n        # for the unpooling layer there is currently no Monai implementation available, return to torch implementation\n        un_pool = nn.MaxUnpool2d(kernel_size=pool, stride=stride_pool)\n\n        # sequence of convolutional strides (like in UNet) not needed as they are always stride_conv. This defaults to 1.\n        def _create_model(layer: int) -> nn.Module:\n            \"\"\"\n            Builds the QuickNAT structure from the bottom up by recursing down to the bottelneck layer, then creating sequential\n            blocks containing the decoder, a skip connection around the previous block, and the encoder.\n            At the last layer a classifier block is added to the Sequential.\n\n            Args:\n                layer = inversproportional to the layers left to create\n            \"\"\"\n            subblock: nn.Module\n            if layer < 4:\n                subblock = _create_model(layer + 1)\n\n            else:\n                subblock = Bottleneck(num_filters, se_layer, dropout_layer, max_pool, un_pool, kernel_size, num_filters)\n\n            if layer == 1:\n                down = ConvConcatDenseBlock(num_channels, se_layer, dropout_layer, kernel_size, num_filters)\n                up = ConvConcatDenseBlock(num_filters * 2, se_layer, dropout_layer, kernel_size, num_filters)\n                classifier = ClassifierBlock(2, num_filters, num_classes, stride_conv, kernel_c)\n                return SequentialWithIdx(down, SkipConnectionWithIdx(subblock), up, classifier)\n            else:\n                up = Decoder(num_filters * 2, un_pool, se_layer, dropout_layer, kernel_size, num_filters)\n                down = Encoder(num_filters, max_pool, se_layer, dropout_layer, kernel_size, num_filters)\n                return SequentialWithIdx(down, SkipConnectionWithIdx(subblock), up)\n\n        self.model = _create_model(1)\n\n    def get_selayer(self, n_filters, se_block_type=\"None\"):\n        \"\"\"\n        Returns the SEBlock 
defined in the initialization of the QuickNAT model.\n\n        Args:\n            n_filters: encoding half of the layer\n            se_block_type: defaults to None. Valid options are None, CSE, SSE, CSSE\n        Returns: Appropriate SEBlock. SSE and CSSE not implemented in Monai yet.\n        \"\"\"\n        if se_block_type == \"CSE\":\n            return se.ChannelSELayer(2, n_filters)\n        # not implemented in squeeze_and_excitation in monai use other squeeze_and_excitation import:\n        elif se_block_type == \"SSE\" or se_block_type == \"CSSE\":\n            # Throw error if squeeze_and_excitation is not installed\n            if not flag:\n                raise ImportError(\"Please install squeeze_and_excitation locally to use SpatialSELayer\")\n            if se_block_type == \"SSE\":\n                return se1.SpatialSELayer(n_filters)\n            else:\n                return se1.ChannelSpatialSELayer(n_filters)\n        else:\n            return None\n\n    @property\n    def is_cuda(self):\n        \"\"\"\n        Check if model parameters are allocated on the GPU.\n        \"\"\"\n        return next(self.parameters()).is_cuda\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        input, _ = self.model(input, None)\n        return input\n"
  },
  {
    "path": "monai/networks/nets/regressor.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import Convolution, ResidualUnit\nfrom monai.networks.layers.convutils import calculate_out_shape, same_padding\nfrom monai.networks.layers.factories import Act, Norm\nfrom monai.networks.layers.simplelayers import Reshape\nfrom monai.utils import ensure_tuple, ensure_tuple_rep\n\n__all__ = [\"Regressor\"]\n\n\nclass Regressor(nn.Module):\n    \"\"\"\n    This defines a network for relating large-sized input tensors to small output tensors, ie. regressing large\n    values to a prediction. An output of a single dimension can be used as value regression or multi-label\n    classification prediction, an output of a single value can be used as a discriminator or critic prediction.\n\n    The network is constructed as a sequence of layers, either :py:class:`monai.networks.blocks.Convolution` or\n    :py:class:`monai.networks.blocks.ResidualUnit`, with a final fully-connected layer resizing the output from the\n    blocks to the final size. Each block is defined with a stride value typically used to downsample the input using\n    strided convolutions. 
In this way each block progressively condenses information from the input into a deep\n    representation the final fully-connected layer relates to a final result.\n\n    Args:\n        in_shape: tuple of integers stating the dimension of the input tensor (minus batch dimension)\n        out_shape: tuple of integers stating the dimension of the final output tensor (minus batch dimension)\n        channels: tuple of integers stating the output channels of each convolutional layer\n        strides: tuple of integers stating the stride (downscale factor) of each convolutional layer\n        kernel_size: integer or tuple of integers stating size of convolutional kernels\n        num_res_units: integer stating number of convolutions in residual units, 0 means no residual units\n        act: name or type defining activation layers\n        norm: name or type defining normalization layers\n        dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout\n        bias: boolean stating if convolution layers should have a bias component\n\n    Examples::\n\n        # infers a 2-value result (eg. 
a 2D cartesian coordinate) from a 64x64 image\n        net = Regressor((1, 64, 64), (2,), (2, 4, 8), (2, 2, 2))\n\n    \"\"\"\n\n    def __init__(\n        self,\n        in_shape: Sequence[int],\n        out_shape: Sequence[int],\n        channels: Sequence[int],\n        strides: Sequence[int],\n        kernel_size: Sequence[int] | int = 3,\n        num_res_units: int = 2,\n        act=Act.PRELU,\n        norm=Norm.INSTANCE,\n        dropout: float | None = None,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n\n        self.in_channels, *self.in_shape = ensure_tuple(in_shape)\n        self.dimensions = len(self.in_shape)\n        self.channels = ensure_tuple(channels)\n        self.strides = ensure_tuple(strides)\n        self.out_shape = ensure_tuple(out_shape)\n        self.kernel_size = ensure_tuple_rep(kernel_size, self.dimensions)\n        self.num_res_units = num_res_units\n        self.act = act\n        self.norm = norm\n        self.dropout = dropout\n        self.bias = bias\n        self.net = nn.Sequential()\n\n        echannel = self.in_channels\n\n        padding = same_padding(kernel_size)\n\n        self.final_size = np.asarray(self.in_shape, dtype=int)\n        self.reshape = Reshape(*self.out_shape)\n\n        # encode stage\n        for i, (c, s) in enumerate(zip(self.channels, self.strides)):\n            layer = self._get_layer(echannel, c, s, i == len(channels) - 1)\n            echannel = c  # use the output channel number as the input for the next loop\n            self.net.add_module(f\"layer_{i}\", layer)\n            self.final_size = calculate_out_shape(self.final_size, kernel_size, s, padding)  # type: ignore\n\n        self.final = self._get_final_layer((echannel,) + self.final_size)\n\n    def _get_layer(\n        self, in_channels: int, out_channels: int, strides: int, is_last: bool\n    ) -> ResidualUnit | Convolution:\n        \"\"\"\n        Returns a layer accepting inputs with `in_channels` number of 
channels and producing outputs of `out_channels`\n        number of channels. The `strides` indicates downsampling factor, ie. convolutional stride. If `is_last`\n        is True this is the final layer and is not expected to include activation and normalization layers.\n        \"\"\"\n\n        layer: ResidualUnit | Convolution\n\n        if self.num_res_units > 0:\n            layer = ResidualUnit(\n                subunits=self.num_res_units,\n                last_conv_only=is_last,\n                spatial_dims=self.dimensions,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                strides=strides,\n                kernel_size=self.kernel_size,\n                act=self.act,\n                norm=self.norm,\n                dropout=self.dropout,\n                bias=self.bias,\n            )\n        else:\n            layer = Convolution(\n                conv_only=is_last,\n                spatial_dims=self.dimensions,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                strides=strides,\n                kernel_size=self.kernel_size,\n                act=self.act,\n                norm=self.norm,\n                dropout=self.dropout,\n                bias=self.bias,\n            )\n\n        return layer\n\n    def _get_final_layer(self, in_shape: Sequence[int]):\n        linear = nn.Linear(int(np.prod(in_shape)), int(np.prod(self.out_shape)))\n        return nn.Sequential(nn.Flatten(), linear)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.net(x)\n        x = self.final(x)\n        x = self.reshape(x)\n        return x\n"
  },
  {
    "path": "monai/networks/nets/regunet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom monai.networks.blocks.regunet_block import (\n    RegistrationDownSampleBlock,\n    RegistrationExtractionBlock,\n    RegistrationResidualConvBlock,\n    get_conv_block,\n    get_deconv_block,\n)\nfrom monai.networks.utils import meshgrid_ij\n\n__all__ = [\"RegUNet\", \"AffineHead\", \"GlobalNet\", \"LocalNet\"]\n\n\nclass RegUNet(nn.Module):\n    \"\"\"\n    Class that implements an adapted UNet. This class also serve as the parent class of LocalNet and GlobalNet\n\n    Reference:\n        O. Ronneberger, P. Fischer, and T. Brox,\n        “U-net: Convolutional networks for biomedical image segmentation,”,\n        Lecture Notes in Computer Science, 2015, vol. 9351, pp. 
234–241.\n        https://arxiv.org/abs/1505.04597\n\n    Adapted from:\n        DeepReg (https://github.com/DeepRegNet/DeepReg)\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        num_channel_initial: int,\n        depth: int,\n        out_kernel_initializer: str | None = \"kaiming_uniform\",\n        out_activation: str | None = None,\n        out_channels: int = 3,\n        extract_levels: tuple[int] | None = None,\n        pooling: bool = True,\n        concat_skip: bool = False,\n        encode_kernel_sizes: int | list[int] = 3,\n    ):\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dims\n            in_channels: number of input channels\n            num_channel_initial: number of initial channels\n            depth: input is at level 0, bottom is at level depth.\n            out_kernel_initializer: kernel initializer for the last layer\n            out_activation: activation at the last layer\n            out_channels: number of channels for the output\n            extract_levels: list, which levels from net to extract. 
The maximum level must equal to ``depth``\n            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv\n            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition\n            encode_kernel_sizes: kernel size for down-sampling\n        \"\"\"\n        super().__init__()\n        if not extract_levels:\n            extract_levels = (depth,)\n        if max(extract_levels) != depth:\n            raise AssertionError\n\n        # save parameters\n        self.spatial_dims = spatial_dims\n        self.in_channels = in_channels\n        self.num_channel_initial = num_channel_initial\n        self.depth = depth\n        self.out_kernel_initializer = out_kernel_initializer\n        self.out_activation = out_activation\n        self.out_channels = out_channels\n        self.extract_levels = extract_levels\n        self.pooling = pooling\n        self.concat_skip = concat_skip\n\n        if isinstance(encode_kernel_sizes, int):\n            encode_kernel_sizes = [encode_kernel_sizes] * (self.depth + 1)\n        if len(encode_kernel_sizes) != self.depth + 1:\n            raise AssertionError\n        self.encode_kernel_sizes: list[int] = encode_kernel_sizes\n\n        self.num_channels = [self.num_channel_initial * (2**d) for d in range(self.depth + 1)]\n        self.min_extract_level = min(self.extract_levels)\n\n        # init layers\n        # all lists start with d = 0\n        self.encode_convs: nn.ModuleList\n        self.encode_pools: nn.ModuleList\n        self.bottom_block: nn.Sequential\n        self.decode_deconvs: nn.ModuleList\n        self.decode_convs: nn.ModuleList\n        self.output_block: nn.Module\n\n        # build layers\n        self.build_layers()\n\n    def build_layers(self):\n        self.build_encode_layers()\n        self.build_decode_layers()\n\n    def build_encode_layers(self):\n        # encoding / down-sampling\n        self.encode_convs = nn.ModuleList(\n      
      [\n                self.build_conv_block(\n                    in_channels=self.in_channels if d == 0 else self.num_channels[d - 1],\n                    out_channels=self.num_channels[d],\n                    kernel_size=self.encode_kernel_sizes[d],\n                )\n                for d in range(self.depth)\n            ]\n        )\n        self.encode_pools = nn.ModuleList(\n            [self.build_down_sampling_block(channels=self.num_channels[d]) for d in range(self.depth)]\n        )\n        self.bottom_block = self.build_bottom_block(\n            in_channels=self.num_channels[-2], out_channels=self.num_channels[-1]\n        )\n\n    def build_conv_block(self, in_channels, out_channels, kernel_size):\n        return nn.Sequential(\n            get_conv_block(\n                spatial_dims=self.spatial_dims,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n            ),\n            RegistrationResidualConvBlock(\n                spatial_dims=self.spatial_dims,\n                in_channels=out_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n            ),\n        )\n\n    def build_down_sampling_block(self, channels: int):\n        return RegistrationDownSampleBlock(spatial_dims=self.spatial_dims, channels=channels, pooling=self.pooling)\n\n    def build_bottom_block(self, in_channels: int, out_channels: int):\n        kernel_size = self.encode_kernel_sizes[self.depth]\n        return nn.Sequential(\n            get_conv_block(\n                spatial_dims=self.spatial_dims,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n            ),\n            RegistrationResidualConvBlock(\n                spatial_dims=self.spatial_dims,\n                in_channels=out_channels,\n                out_channels=out_channels,\n                
kernel_size=kernel_size,\n            ),\n        )\n\n    def build_decode_layers(self):\n        self.decode_deconvs = nn.ModuleList(\n            [\n                self.build_up_sampling_block(in_channels=self.num_channels[d + 1], out_channels=self.num_channels[d])\n                for d in range(self.depth - 1, self.min_extract_level - 1, -1)\n            ]\n        )\n        self.decode_convs = nn.ModuleList(\n            [\n                self.build_conv_block(\n                    in_channels=(2 * self.num_channels[d] if self.concat_skip else self.num_channels[d]),\n                    out_channels=self.num_channels[d],\n                    kernel_size=3,\n                )\n                for d in range(self.depth - 1, self.min_extract_level - 1, -1)\n            ]\n        )\n\n        # extraction\n        self.output_block = self.build_output_block()\n\n    def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module:\n        return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels)\n\n    def build_output_block(self) -> nn.Module:\n        return RegistrationExtractionBlock(\n            spatial_dims=self.spatial_dims,\n            extract_levels=self.extract_levels,\n            num_channels=self.num_channels,\n            out_channels=self.out_channels,\n            kernel_initializer=self.out_kernel_initializer,\n            activation=self.out_activation,\n        )\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])\n\n        Returns:\n            Tensor in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3]), with the same spatial size as ``x``\n        \"\"\"\n        image_size = x.shape[2:]\n        skips = []  # [0, ..., depth - 1]\n        encoded = x\n        for encode_conv, encode_pool in zip(self.encode_convs, self.encode_pools):\n            skip = 
encode_conv(encoded)\n            encoded = encode_pool(skip)\n            skips.append(skip)\n        decoded = self.bottom_block(encoded)\n\n        outs = [decoded]\n\n        for i, (decode_deconv, decode_conv) in enumerate(zip(self.decode_deconvs, self.decode_convs)):\n            decoded = decode_deconv(decoded)\n            if self.concat_skip:\n                decoded = torch.cat([decoded, skips[-i - 1]], dim=1)\n            else:\n                decoded = decoded + skips[-i - 1]\n            decoded = decode_conv(decoded)\n            outs.append(decoded)\n\n        out = self.output_block(outs, image_size=image_size)\n        return out\n\n\nclass AffineHead(nn.Module):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        image_size: list[int],\n        decode_size: list[int],\n        in_channels: int,\n        save_theta: bool = False,\n    ):\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions\n            image_size: output spatial size\n            decode_size: input spatial size (two or three integers depending on ``spatial_dims``)\n            in_channels: number of input channels\n            save_theta: whether to save the theta matrix estimation\n        \"\"\"\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        if spatial_dims == 2:\n            in_features = in_channels * decode_size[0] * decode_size[1]\n            out_features = 6\n            out_init = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)\n        elif spatial_dims == 3:\n            in_features = in_channels * decode_size[0] * decode_size[1] * decode_size[2]\n            out_features = 12\n            out_init = torch.tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], dtype=torch.float)\n        else:\n            raise ValueError(f\"only support 2D/3D operation, got spatial_dims={spatial_dims}\")\n\n        self.fc = nn.Linear(in_features=in_features, out_features=out_features)\n        self.grid 
= self.get_reference_grid(image_size)  # (spatial_dims, ...)\n\n        # init weight/bias\n        self.fc.weight.data.zero_()\n        self.fc.bias.data.copy_(out_init)\n\n        self.save_theta = save_theta\n        self.theta = torch.Tensor()\n\n    @staticmethod\n    def get_reference_grid(image_size: tuple[int] | list[int]) -> torch.Tensor:\n        mesh_points = [torch.arange(0, dim) for dim in image_size]\n        grid = torch.stack(meshgrid_ij(*mesh_points), dim=0)  # (spatial_dims, ...)\n        return grid.to(dtype=torch.float)\n\n    def affine_transform(self, theta: torch.Tensor):\n        # (spatial_dims, ...) -> (spatial_dims + 1, ...)\n        grid_padded = torch.cat([self.grid, torch.ones_like(self.grid[:1])])\n\n        # grid_warped[b,p,...] = sum_over_q(grid_padded[q,...] * theta[b,p,q]\n        if self.spatial_dims == 2:\n            grid_warped = torch.einsum(\"qij,bpq->bpij\", grid_padded, theta.reshape(-1, 2, 3))\n        elif self.spatial_dims == 3:\n            grid_warped = torch.einsum(\"qijk,bpq->bpijk\", grid_padded, theta.reshape(-1, 3, 4))\n        else:\n            raise ValueError(f\"do not support spatial_dims={self.spatial_dims}\")\n        return grid_warped\n\n    def forward(self, x: list[torch.Tensor], image_size: list[int]) -> torch.Tensor:\n        f = x[0]\n        self.grid = self.grid.to(device=f.device)\n        theta = self.fc(f.reshape(f.shape[0], -1))\n        if self.save_theta:\n            self.theta = theta.detach()\n        out: torch.Tensor = self.affine_transform(theta) - self.grid\n        return out\n\n\nclass GlobalNet(RegUNet):\n    \"\"\"\n    Build GlobalNet for image registration.\n\n    Reference:\n        Hu, Yipeng, et al.\n        \"Label-driven weakly-supervised learning\n        for multimodal deformable image registration,\"\n        https://arxiv.org/abs/1711.01666\n    \"\"\"\n\n    def __init__(\n        self,\n        image_size: list[int],\n        spatial_dims: int,\n        in_channels: 
int,\n        num_channel_initial: int,\n        depth: int,\n        out_kernel_initializer: str | None = \"kaiming_uniform\",\n        out_activation: str | None = None,\n        pooling: bool = True,\n        concat_skip: bool = False,\n        encode_kernel_sizes: int | list[int] = 3,\n        save_theta: bool = False,\n    ):\n        \"\"\"\n        Args:\n            image_size: output displacement field spatial size\n            spatial_dims: number of spatial dims\n            in_channels: number of input channels\n            num_channel_initial: number of initial channels\n            depth: input is at level 0, bottom is at level depth.\n            out_kernel_initializer: kernel initializer for the last layer\n            out_activation: activation at the last layer\n            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv\n            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition\n            encode_kernel_sizes: kernel size for down-sampling\n            save_theta: whether to save the theta matrix estimation\n        \"\"\"\n        for size in image_size:\n            if size % (2**depth) != 0:\n                raise ValueError(\n                    f\"given depth {depth}, \"\n                    f\"all input spatial dimension must be divisible by {2 ** depth}, \"\n                    f\"got input of size {image_size}\"\n                )\n        self.image_size = image_size\n        self.decode_size = [size // (2**depth) for size in image_size]\n        self.save_theta = save_theta\n        super().__init__(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            num_channel_initial=num_channel_initial,\n            depth=depth,\n            out_kernel_initializer=out_kernel_initializer,\n            out_activation=out_activation,\n            out_channels=spatial_dims,\n            pooling=pooling,\n            
concat_skip=concat_skip,\n            encode_kernel_sizes=encode_kernel_sizes,\n        )\n\n    def build_output_block(self):\n        return AffineHead(\n            spatial_dims=self.spatial_dims,\n            image_size=self.image_size,\n            decode_size=self.decode_size,\n            in_channels=self.num_channels[-1],\n            save_theta=self.save_theta,\n        )\n\n\nclass AdditiveUpSampleBlock(nn.Module):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        mode: str = \"nearest\",\n        align_corners: bool | None = None,\n    ):\n        super().__init__()\n        self.deconv = get_deconv_block(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels)\n        self.mode = mode\n        self.align_corners = align_corners\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        output_size = [size * 2 for size in x.shape[2:]]\n        deconved = self.deconv(x)\n        resized = F.interpolate(x, output_size, mode=self.mode, align_corners=self.align_corners)\n        resized = torch.sum(torch.stack(resized.split(split_size=resized.shape[1] // 2, dim=1), dim=-1), dim=-1)\n        out: torch.Tensor = deconved + resized\n        return out\n\n\nclass LocalNet(RegUNet):\n    \"\"\"\n    Reimplementation of LocalNet, based on:\n    `Weakly-supervised convolutional neural networks for multimodal image registration\n    <https://doi.org/10.1016/j.media.2018.07.002>`_.\n    `Label-driven weakly-supervised learning for multimodal deformable image registration\n    <https://arxiv.org/abs/1711.01666>`_.\n\n    Adapted from:\n        DeepReg (https://github.com/DeepRegNet/DeepReg)\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        num_channel_initial: int,\n        extract_levels: tuple[int],\n        out_kernel_initializer: str | None = \"kaiming_uniform\",\n        out_activation: str | 
None = None,\n        out_channels: int = 3,\n        pooling: bool = True,\n        use_additive_sampling: bool = True,\n        concat_skip: bool = False,\n        mode: str = \"nearest\",\n        align_corners: bool | None = None,\n    ):\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dims\n            in_channels: number of input channels\n            num_channel_initial: number of initial channels\n            out_kernel_initializer: kernel initializer for the last layer\n            out_activation: activation at the last layer\n            out_channels: number of channels for the output\n            extract_levels: list, which levels from net to extract. The maximum level must equal to ``depth``\n            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv3d\n            use_additive_sampling: whether use additive up-sampling layer for decoding.\n            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition\n            mode: mode for interpolation when use_additive_sampling, default is \"nearest\".\n            align_corners: align_corners for interpolation when use_additive_sampling, default is None.\n        \"\"\"\n        self.use_additive_upsampling = use_additive_sampling\n        self.mode = mode\n        self.align_corners = align_corners\n        super().__init__(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            num_channel_initial=num_channel_initial,\n            extract_levels=extract_levels,\n            depth=max(extract_levels),\n            out_kernel_initializer=out_kernel_initializer,\n            out_activation=out_activation,\n            out_channels=out_channels,\n            pooling=pooling,\n            concat_skip=concat_skip,\n            encode_kernel_sizes=[7] + [3] * max(extract_levels),\n        )\n\n    def build_bottom_block(self, in_channels: int, out_channels: int):\n        
kernel_size = self.encode_kernel_sizes[self.depth]\n        return get_conv_block(\n            spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size\n        )\n\n    def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module:\n        if self.use_additive_upsampling:\n            return AdditiveUpSampleBlock(\n                spatial_dims=self.spatial_dims,\n                in_channels=in_channels,\n                out_channels=out_channels,\n                mode=self.mode,\n                align_corners=self.align_corners,\n            )\n\n        return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels)\n"
  },
  {
    "path": "monai/networks/nets/resnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport re\nfrom collections.abc import Callable\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Any\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.encoder import BaseEncoder\nfrom monai.networks.layers.factories import Conv, Pool\nfrom monai.networks.layers.utils import get_act_layer, get_norm_layer, get_pool_layer\nfrom monai.utils import ensure_tuple_rep\nfrom monai.utils.module import look_up_option, optional_import\n\nhf_hub_download, _ = optional_import(\"huggingface_hub\", name=\"hf_hub_download\")\nEntryNotFoundError, _ = optional_import(\"huggingface_hub.utils._errors\", name=\"EntryNotFoundError\")\n\nMEDICALNET_HUGGINGFACE_REPO_BASENAME = \"TencentMedicalNet/MedicalNet-Resnet\"\nMEDICALNET_HUGGINGFACE_FILES_BASENAME = \"resnet_\"\n\n__all__ = [\n    \"ResNet\",\n    \"ResNetBlock\",\n    \"ResNetBottleneck\",\n    \"resnet10\",\n    \"resnet18\",\n    \"resnet34\",\n    \"resnet50\",\n    \"resnet101\",\n    \"resnet152\",\n    \"resnet200\",\n]\n\nresnet_params = {\n    # model_name: (block, layers, shortcut_type, bias_downsample, datasets23)\n    \"resnet10\": (\"basic\", [1, 1, 1, 1], \"B\", False, True),\n    \"resnet18\": (\"basic\", [2, 2, 2, 2], \"A\", True, True),\n    \"resnet34\": (\"basic\", [3, 4, 6, 3], \"A\", True, True),\n    \"resnet50\": 
(\"bottleneck\", [3, 4, 6, 3], \"B\", False, True),\n    \"resnet101\": (\"bottleneck\", [3, 4, 23, 3], \"B\", False, False),\n    \"resnet152\": (\"bottleneck\", [3, 8, 36, 3], \"B\", False, False),\n    \"resnet200\": (\"bottleneck\", [3, 24, 36, 3], \"B\", False, False),\n}\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_inplanes():\n    return [64, 128, 256, 512]\n\n\ndef get_avgpool():\n    return [0, 1, (1, 1), (1, 1, 1)]\n\n\nclass ResNetBlock(nn.Module):\n    expansion = 1\n\n    def __init__(\n        self,\n        in_planes: int,\n        planes: int,\n        spatial_dims: int = 3,\n        stride: int = 1,\n        downsample: nn.Module | partial | None = None,\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n    ) -> None:\n        \"\"\"\n        Args:\n            in_planes: number of input channels.\n            planes: number of output channels.\n            spatial_dims: number of spatial dimensions of the input image.\n            stride: stride to use for first conv layer.\n            downsample: which downsample layer to use.\n            act: activation type and arguments. Defaults to relu.\n            norm: feature normalization type and arguments. 
Defaults to batch norm.\n        \"\"\"\n        super().__init__()\n\n        conv_type: Callable = Conv[Conv.CONV, spatial_dims]\n\n        self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)\n        self.bn1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes)\n        self.act = get_act_layer(name=act)\n        self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)\n        self.bn2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        residual = x\n\n        out: torch.Tensor = self.conv1(x)\n        out = self.bn1(out)\n        out = self.act(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.act(out)\n\n        return out\n\n\nclass ResNetBottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(\n        self,\n        in_planes: int,\n        planes: int,\n        spatial_dims: int = 3,\n        stride: int = 1,\n        downsample: nn.Module | partial | None = None,\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n    ) -> None:\n        \"\"\"\n        Args:\n            in_planes: number of input channels.\n            planes: number of output channels (taking expansion into account).\n            spatial_dims: number of spatial dimensions of the input image.\n            stride: stride to use for second conv layer.\n            downsample: which downsample layer to use.\n            act: activation type and arguments. Defaults to relu.\n            norm: feature normalization type and arguments. 
Defaults to batch norm.\n        \"\"\"\n\n        super().__init__()\n\n        conv_type: Callable = Conv[Conv.CONV, spatial_dims]\n        norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)\n\n        self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False)\n        self.bn1 = norm_layer(channels=planes)\n        self.conv2 = conv_type(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)\n        self.bn2 = norm_layer(channels=planes)\n        self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False)\n        self.bn3 = norm_layer(channels=planes * self.expansion)\n        self.act = get_act_layer(name=act)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        residual = x\n\n        out: torch.Tensor = self.conv1(x)\n        out = self.bn1(out)\n        out = self.act(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.act(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.act(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n    \"\"\"\n    ResNet based on: `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`_\n    and `Can Spatiotemporal 3D CNNs Retrace the History of 2D CNNs and ImageNet? <https://arxiv.org/pdf/1711.09577.pdf>`_.\n    Adapted from `<https://github.com/kenshohara/3D-ResNets-PyTorch/tree/master/models>`_.\n\n    Args:\n        block: which ResNet block to use, either Basic or Bottleneck.\n            ResNet block class or str.\n            for Basic: ResNetBlock or 'basic'\n            for Bottleneck: ResNetBottleneck or 'bottleneck'\n        layers: how many layers to use.\n        block_inplanes: determine the size of planes at each step. 
Also tunable with widen_factor.\n        spatial_dims: number of spatial dimensions of the input image.\n        n_input_channels: number of input channels for first convolutional layer.\n        conv1_t_size: size of first convolution layer, determines kernel and padding.\n        conv1_t_stride: stride of first convolution layer.\n        no_max_pool: bool argument to determine if to use maxpool layer.\n        shortcut_type: which downsample block to use. Options are 'A', 'B', default to 'B'.\n            - 'A': using `self._downsample_basic_block`.\n            - 'B': kernel_size 1 conv + norm.\n        widen_factor: widen output for each layer.\n        num_classes: number of output (classifications).\n        feed_forward: whether to add the FC layer for the output, default to `True`.\n        bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.\n        act: activation type and arguments. Defaults to relu.\n        norm: feature normalization type and arguments. 
Defaults to batch norm.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        block: type[ResNetBlock | ResNetBottleneck] | str,\n        layers: list[int],\n        block_inplanes: list[int],\n        spatial_dims: int = 3,\n        n_input_channels: int = 3,\n        conv1_t_size: tuple[int] | int = 7,\n        conv1_t_stride: tuple[int] | int = 1,\n        no_max_pool: bool = False,\n        shortcut_type: str = \"B\",\n        widen_factor: float = 1.0,\n        num_classes: int = 400,\n        feed_forward: bool = True,\n        bias_downsample: bool = True,  # for backwards compatibility (also see PR #5477)\n        act: str | tuple = (\"relu\", {\"inplace\": True}),\n        norm: str | tuple = \"batch\",\n    ) -> None:\n        super().__init__()\n\n        if isinstance(block, str):\n            if block == \"basic\":\n                block = ResNetBlock\n            elif block == \"bottleneck\":\n                block = ResNetBottleneck\n            else:\n                raise ValueError(f\"Unknown block '{block}', use basic or bottleneck\")\n\n        conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]\n        pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]\n        avgp_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[\n            Pool.ADAPTIVEAVG, spatial_dims\n        ]\n\n        block_avgpool = get_avgpool()\n        block_inplanes = [int(x * widen_factor) for x in block_inplanes]\n\n        self.in_planes = block_inplanes[0]\n        self.no_max_pool = no_max_pool\n        self.bias_downsample = bias_downsample\n\n        conv1_kernel_size = ensure_tuple_rep(conv1_t_size, spatial_dims)\n        conv1_stride = ensure_tuple_rep(conv1_t_stride, spatial_dims)\n\n        self.conv1 = conv_type(\n            n_input_channels,\n            self.in_planes,\n            kernel_size=conv1_kernel_size,\n            stride=conv1_stride,\n     
       padding=tuple(k // 2 for k in conv1_kernel_size),\n            bias=False,\n        )\n\n        norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=self.in_planes)\n        self.bn1 = norm_layer\n        self.act = get_act_layer(name=act)\n        self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type)\n        self.layer2 = self._make_layer(block, block_inplanes[1], layers[1], spatial_dims, shortcut_type, stride=2)\n        self.layer3 = self._make_layer(block, block_inplanes[2], layers[2], spatial_dims, shortcut_type, stride=2)\n        self.layer4 = self._make_layer(block, block_inplanes[3], layers[3], spatial_dims, shortcut_type, stride=2)\n        self.avgpool = avgp_type(block_avgpool[spatial_dims])\n        self.fc = nn.Linear(block_inplanes[3] * block.expansion, num_classes) if feed_forward else None\n\n        for m in self.modules():\n            if isinstance(m, conv_type):\n                nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode=\"fan_out\", nonlinearity=\"relu\")\n            elif isinstance(m, type(norm_layer)):\n                nn.init.constant_(torch.as_tensor(m.weight), 1)\n                nn.init.constant_(torch.as_tensor(m.bias), 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.constant_(torch.as_tensor(m.bias), 0)\n\n    def _downsample_basic_block(self, x: torch.Tensor, planes: int, stride: int, spatial_dims: int = 3) -> torch.Tensor:\n        out: torch.Tensor = get_pool_layer((\"avg\", {\"kernel_size\": 1, \"stride\": stride}), spatial_dims=spatial_dims)(x)\n        zero_pads = torch.zeros(out.size(0), planes - out.size(1), *out.shape[2:], dtype=out.dtype, device=out.device)\n        out = torch.cat([out.data, zero_pads], dim=1)\n        return out\n\n    def _make_layer(\n        self,\n        block: type[ResNetBlock | ResNetBottleneck],\n        planes: int,\n   
     blocks: int,\n        spatial_dims: int,\n        shortcut_type: str,\n        stride: int = 1,\n        norm: str | tuple = \"batch\",\n    ) -> nn.Sequential:\n        conv_type: Callable = Conv[Conv.CONV, spatial_dims]\n\n        downsample: nn.Module | partial | None = None\n        if stride != 1 or self.in_planes != planes * block.expansion:\n            if look_up_option(shortcut_type, {\"A\", \"B\"}) == \"A\":\n                downsample = partial(\n                    self._downsample_basic_block,\n                    planes=planes * block.expansion,\n                    stride=stride,\n                    spatial_dims=spatial_dims,\n                )\n            else:\n                downsample = nn.Sequential(\n                    conv_type(\n                        self.in_planes,\n                        planes * block.expansion,\n                        kernel_size=1,\n                        stride=stride,\n                        bias=self.bias_downsample,\n                    ),\n                    get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes * block.expansion),\n                )\n\n        layers = [\n            block(\n                in_planes=self.in_planes,\n                planes=planes,\n                spatial_dims=spatial_dims,\n                stride=stride,\n                downsample=downsample,\n                norm=norm,\n            )\n        ]\n\n        self.in_planes = planes * block.expansion\n        for _i in range(1, blocks):\n            layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims, norm=norm))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.act(x)\n        if not self.no_max_pool:\n            x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = 
self.avgpool(x)\n\n        x = x.view(x.size(0), -1)\n        if self.fc is not None:\n            x = self.fc(x)\n\n        return x\n\n\nclass ResNetFeatures(ResNet):\n\n    def __init__(self, model_name: str, pretrained: bool = True, spatial_dims: int = 3, in_channels: int = 1) -> None:\n        \"\"\"Initialize resnet18 to resnet200 models as a backbone, the backbone can be used as an encoder for\n        segmentation and objection models.\n\n        Compared with the class `ResNet`, the only different place is the forward function.\n\n        Args:\n            model_name: name of model to initialize, can be from [resnet10, ..., resnet200].\n            pretrained: whether to initialize pretrained MedicalNet weights,\n                only available for spatial_dims=3 and in_channels=1.\n            spatial_dims: number of spatial dimensions of the input image.\n            in_channels: number of input channels for first convolutional layer.\n        \"\"\"\n        if model_name not in resnet_params:\n            model_name_string = \", \".join(resnet_params.keys())\n            raise ValueError(f\"invalid model_name {model_name} found, must be one of {model_name_string} \")\n\n        block, layers, shortcut_type, bias_downsample, datasets23 = resnet_params[model_name]\n\n        super().__init__(\n            block=block,\n            layers=layers,\n            block_inplanes=get_inplanes(),\n            spatial_dims=spatial_dims,\n            n_input_channels=in_channels,\n            conv1_t_stride=2,\n            shortcut_type=shortcut_type,\n            feed_forward=False,\n            bias_downsample=bias_downsample,\n        )\n        if pretrained:\n            if spatial_dims == 3 and in_channels == 1:\n                _load_state_dict(self, model_name, datasets23=datasets23)\n            else:\n                raise ValueError(\"Pretrained resnet models are only available for in_channels=1 and spatial_dims=3.\")\n\n    def forward(self, inputs: 
torch.Tensor):\n        \"\"\"\n        Args:\n            inputs: input should have spatially N dimensions\n            ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.\n\n        Returns:\n            a list of torch Tensors.\n        \"\"\"\n        x = self.conv1(inputs)\n        x = self.bn1(x)\n        x = self.act(x)\n\n        features = []\n        features.append(x)\n\n        if not self.no_max_pool:\n            x = self.maxpool(x)\n\n        x = self.layer1(x)\n        features.append(x)\n\n        x = self.layer2(x)\n        features.append(x)\n\n        x = self.layer3(x)\n        features.append(x)\n\n        x = self.layer4(x)\n        features.append(x)\n\n        return features\n\n\nclass ResNetEncoder(ResNetFeatures, BaseEncoder):\n    \"\"\"Wrap the original resnet to an encoder for flexible-unet.\"\"\"\n\n    backbone_names = [\"resnet10\", \"resnet18\", \"resnet34\", \"resnet50\", \"resnet101\", \"resnet152\", \"resnet200\"]\n\n    @classmethod\n    def get_encoder_parameters(cls) -> list[dict]:\n        \"\"\"Get the initialization parameter for resnet backbones.\"\"\"\n        parameter_list = []\n        for backbone_name in cls.backbone_names:\n            parameter_list.append(\n                {\"model_name\": backbone_name, \"pretrained\": True, \"spatial_dims\": 3, \"in_channels\": 1}\n            )\n        return parameter_list\n\n    @classmethod\n    def num_channels_per_output(cls) -> list[tuple[int, ...]]:\n        \"\"\"Get number of resnet backbone output feature maps channel.\"\"\"\n        return [\n            (64, 64, 128, 256, 512),\n            (64, 64, 128, 256, 512),\n            (64, 64, 128, 256, 512),\n            (64, 256, 512, 1024, 2048),\n            (64, 256, 512, 1024, 2048),\n            (64, 256, 512, 1024, 2048),\n            (64, 256, 512, 1024, 2048),\n        ]\n\n    @classmethod\n    def num_outputs(cls) -> list[int]:\n        \"\"\"Get number of resnet backbone 
output feature maps.\n\n        Since every backbone contains the same 5 output feature maps, the number list should be `[5] * 7`.\n        \"\"\"\n        return [5] * 7\n\n    @classmethod\n    def get_encoder_names(cls) -> list[str]:\n        \"\"\"Get names of resnet backbones.\"\"\"\n        return cls.backbone_names\n\n\ndef _resnet(\n    arch: str,\n    block: type[ResNetBlock | ResNetBottleneck],\n    layers: list[int],\n    block_inplanes: list[int],\n    pretrained: bool | str,\n    progress: bool,\n    **kwargs: Any,\n) -> ResNet:\n    model: ResNet = ResNet(block, layers, block_inplanes, **kwargs)\n    if pretrained:\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        if isinstance(pretrained, str):\n            if Path(pretrained).exists():\n                logger.info(f\"Loading weights from {pretrained}...\")\n                model_state_dict = torch.load(pretrained, map_location=device, weights_only=True)\n            else:\n                # Throw error\n                raise FileNotFoundError(\"The pretrained checkpoint file is not found\")\n        else:\n            # Also check bias downsample and shortcut.\n            if kwargs.get(\"spatial_dims\", 3) == 3:\n                if kwargs.get(\"n_input_channels\", 3) == 1 and kwargs.get(\"feed_forward\", True) is False:\n                    search_res = re.search(r\"resnet(\\d+)\", arch)\n                    if search_res:\n                        resnet_depth = int(search_res.group(1))\n                    else:\n                        raise ValueError(\"arch argument should be as 'resnet_{resnet_depth}\")\n\n                    # Check model bias_downsample and shortcut_type\n                    bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)\n                    if shortcut_type == kwargs.get(\"shortcut_type\", \"B\") and (\n                        bias_downsample == kwargs.get(\"bias_downsample\", True)\n                    ):\n  
                      # Download the MedicalNet pretrained model\n                        model_state_dict = get_pretrained_resnet_medicalnet(\n                            resnet_depth, device=device, datasets23=True\n                        )\n                    else:\n                        raise NotImplementedError(\n                            f\"Please set shortcut_type to {shortcut_type} and bias_downsample to {bias_downsample} \"\n                            f\"when using pretrained MedicalNet resnet{resnet_depth}\"\n                        )\n                else:\n                    raise NotImplementedError(\n                        \"Please set n_input_channels to 1\"\n                        \"and feed_forward to False in order to use MedicalNet pretrained weights\"\n                    )\n            else:\n                raise NotImplementedError(\"MedicalNet pretrained weights are only avalaible for 3D models\")\n        model_state_dict = {key.replace(\"module.\", \"\"): value for key, value in model_state_dict.items()}\n        model.load_state_dict(model_state_dict, strict=True)\n    return model\n\n\ndef resnet10(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    \"\"\"ResNet-10 with optional pretrained support when `spatial_dims` is 3.\n\n    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on 23 medical datasets\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\"resnet10\", ResNetBlock, [1, 1, 1, 1], get_inplanes(), pretrained, progress, **kwargs)\n\n\ndef resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    \"\"\"ResNet-18 with optional pretrained support when `spatial_dims` is 3.\n\n    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis 
<https://arxiv.org/pdf/1904.00625.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on 23 medical datasets\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\"resnet18\", ResNetBlock, [2, 2, 2, 2], get_inplanes(), pretrained, progress, **kwargs)\n\n\ndef resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    \"\"\"ResNet-34 with optional pretrained support when `spatial_dims` is 3.\n\n    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on 23 medical datasets\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\"resnet34\", ResNetBlock, [3, 4, 6, 3], get_inplanes(), pretrained, progress, **kwargs)\n\n\ndef resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    \"\"\"ResNet-50 with optional pretrained support when `spatial_dims` is 3.\n\n    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on 23 medical datasets\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\"resnet50\", ResNetBottleneck, [3, 4, 6, 3], get_inplanes(), pretrained, progress, **kwargs)\n\n\ndef resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    \"\"\"ResNet-101 with optional pretrained support when `spatial_dims` is 3.\n\n    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on 8 medical datasets\n        progress (bool): If True, displays 
a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\"resnet101\", ResNetBottleneck, [3, 4, 23, 3], get_inplanes(), pretrained, progress, **kwargs)\n\n\ndef resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    \"\"\"ResNet-152 with optional pretrained support when `spatial_dims` is 3.\n\n    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on 8 medical datasets\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\"resnet152\", ResNetBottleneck, [3, 8, 36, 3], get_inplanes(), pretrained, progress, **kwargs)\n\n\ndef resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    \"\"\"ResNet-200 with optional pretrained support when `spatial_dims` is 3.\n\n    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on 8 medical datasets\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet(\"resnet200\", ResNetBottleneck, [3, 24, 36, 3], get_inplanes(), pretrained, progress, **kwargs)\n\n\ndef get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = \"cpu\", datasets23: bool = True):\n    \"\"\"\n    Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet\n\n    Args:\n        resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200\n        device: device on which the returned state dict will be loaded. \"cpu\" or \"cuda\" for example.\n        datasets23: if True, get the weights trained on more datasets (23).\n                    Not all depths are available. 
If not, standard weights are returned.\n\n    Returns:\n        Pretrained state dict\n\n    Raises:\n        huggingface_hub.utils._errors.EntryNotFoundError: if pretrained weights are not found on huggingface hub\n        NotImplementedError: if `resnet_depth` is not supported\n    \"\"\"\n\n    medicalnet_huggingface_repo_basename = \"TencentMedicalNet/MedicalNet-Resnet\"\n    medicalnet_huggingface_files_basename = \"resnet_\"\n    supported_depth = [10, 18, 34, 50, 101, 152, 200]\n\n    logger.info(\n        f\"Loading MedicalNet pretrained model from https://huggingface.co/{medicalnet_huggingface_repo_basename}{resnet_depth}\"\n    )\n\n    if resnet_depth in supported_depth:\n        filename = (\n            f\"{medicalnet_huggingface_files_basename}{resnet_depth}.pth\"\n            if not datasets23\n            else f\"{medicalnet_huggingface_files_basename}{resnet_depth}_23dataset.pth\"\n        )\n        try:\n            pretrained_path = hf_hub_download(\n                repo_id=f\"{medicalnet_huggingface_repo_basename}{resnet_depth}\", filename=filename\n            )\n        except Exception:\n            if datasets23:\n                logger.info(f\"{filename} not available for resnet{resnet_depth}\")\n                filename = f\"{medicalnet_huggingface_files_basename}{resnet_depth}.pth\"\n                logger.info(f\"Trying with {filename}\")\n                pretrained_path = hf_hub_download(\n                    repo_id=f\"{medicalnet_huggingface_repo_basename}{resnet_depth}\", filename=filename\n                )\n            else:\n                raise EntryNotFoundError(\n                    f\"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}\"\n                ) from None\n        checkpoint = torch.load(pretrained_path, map_location=torch.device(device), weights_only=True)\n    else:\n        raise NotImplementedError(\"Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]\")\n    
logger.info(f\"{filename} downloaded\")\n    return checkpoint.get(\"state_dict\")\n\n\ndef get_medicalnet_pretrained_resnet_args(resnet_depth: int):\n    \"\"\"\n    Return correct shortcut_type and bias_downsample\n    for pretrained MedicalNet weights according to resnet depth.\n    \"\"\"\n    # After testing\n    # False: 10, 50, 101, 152, 200\n    # Any: 18, 34\n    bias_downsample = resnet_depth in (18, 34)\n    shortcut_type = \"A\" if resnet_depth in [18, 34] else \"B\"\n    return bias_downsample, shortcut_type\n\n\ndef _load_state_dict(model: nn.Module, model_name: str, datasets23: bool = True) -> None:\n    search_res = re.search(r\"resnet(\\d+)\", model_name)\n    if search_res:\n        resnet_depth = int(search_res.group(1))\n        datasets23 = model_name.endswith(\"_23datasets\")\n    else:\n        raise ValueError(\"model_name argument should contain resnet depth. Example: resnet18 or resnet18_23datasets.\")\n\n    model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=\"cpu\", datasets23=datasets23)\n    model_state_dict = {key.replace(\"module.\", \"\"): value for key, value in model_state_dict.items()}\n    model.load_state_dict(model_state_dict)\n"
  },
  {
    "path": "monai/networks/nets/restormer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.cablock import CABlock, FeedForward\nfrom monai.networks.blocks.convolutions import Convolution\nfrom monai.networks.blocks.downsample import DownSample\nfrom monai.networks.blocks.upsample import UpSample\nfrom monai.networks.layers.factories import Norm\nfrom monai.utils.enums import DownsampleMode, UpsampleMode\n\n\nclass MDTATransformerBlock(nn.Module):\n    \"\"\"Basic transformer unit combining MDTA and GDFN with skip connections.\n    Unlike standard transformers that use LayerNorm, this block uses Instance Norm\n    for better adaptation to image restoration tasks.\n\n    Args:\n        spatial_dims: Number of spatial dimensions (2D or 3D)\n        dim: Number of input channels\n        num_heads: Number of attention heads\n        ffn_expansion_factor: Expansion factor for feed-forward network\n        bias: Whether to use bias in attention layers\n        layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to False.\n        flash_attention: Whether to use flash attention optimization. 
Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        dim: int,\n        num_heads: int,\n        ffn_expansion_factor: float,\n        bias: bool,\n        layer_norm_use_bias: bool = False,\n        flash_attention: bool = False,\n    ):\n        super().__init__()\n        self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=layer_norm_use_bias)\n        self.attn = CABlock(spatial_dims, dim, num_heads, bias, flash_attention)\n        self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=layer_norm_use_bias)\n        self.ffn = FeedForward(spatial_dims, dim, ffn_expansion_factor, bias)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = x + self.attn(self.norm1(x))\n        x = x + self.ffn(self.norm2(x))\n        return x\n\n\nclass OverlapPatchEmbed(Convolution):\n    \"\"\"Initial feature extraction using overlapped convolutions.\n    Unlike standard patch embeddings that use non-overlapping patches,\n    this approach maintains spatial continuity through 3x3 convolutions.\n\n    Args:\n        spatial_dims: Number of spatial dimensions (2D or 3D)\n        in_channels: Number of input channels\n        embed_dim: Dimension of embedded features. Defaults to 48.\n        bias: Whether to use bias in convolution layer. 
Defaults to False.\n    \"\"\"\n\n    def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, bias: bool = False):\n        super().__init__(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=embed_dim,\n            kernel_size=3,\n            strides=1,\n            padding=1,\n            bias=bias,\n            conv_only=True,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = super().forward(x)\n        return x\n\n\nclass Restormer(nn.Module):\n    \"\"\"Restormer: Efficient Transformer for High-Resolution Image Restoration.\n\n    Implements a U-Net style architecture with transformer blocks, combining:\n    - Multi-scale feature processing through progressive down/upsampling\n    - Efficient attention via MDTA blocks\n    - Local feature mixing through GDFN\n    - Skip connections for preserving spatial details\n\n    Architecture:\n        - Encoder: Progressive feature downsampling with increasing channels\n        - Latent: Deep feature processing at lowest resolution\n        - Decoder: Progressive upsampling with skip connections\n        - Refinement: Final feature enhancement\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int = 2,\n        in_channels: int = 3,\n        out_channels: int = 3,\n        dim: int = 48,\n        num_blocks: tuple[int, ...] = (1, 1, 1, 1),\n        heads: tuple[int, ...] 
= (1, 1, 1, 1),\n        num_refinement_blocks: int = 4,\n        ffn_expansion_factor: float = 2.66,\n        bias: bool = False,\n        layer_norm_use_bias: bool = True,\n        dual_pixel_task: bool = False,\n        flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        \"\"\"Initialize Restormer model.\n\n        Args:\n            spatial_dims: Number of spatial dimensions (2D or 3D)\n            in_channels: Number of input image channels\n            out_channels: Number of output image channels\n            dim: Base feature dimension. Defaults to 48.\n            num_blocks: Number of transformer blocks at each scale. Defaults to (1,1,1,1).\n            heads: Number of attention heads at each scale. Defaults to (1,1,1,1).\n            num_refinement_blocks: Number of final refinement blocks. Defaults to 4.\n            ffn_expansion_factor: Expansion factor for feed-forward network. Defaults to 2.66.\n            bias: Whether to use bias in convolutions. Defaults to False.\n            layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to True.\n            dual_pixel_task: Enable dual-pixel specific processing. Defaults to False.\n            flash_attention: Use flash attention if available. 
Defaults to False.\n\n        Note:\n            The number of blocks must be greater than 1\n            The length of num_blocks and heads must be equal\n            All values in num_blocks must be greater than 0\n        \"\"\"\n        # Check input parameters\n        assert len(num_blocks) > 1, \"Number of blocks must be greater than 1\"\n        assert len(num_blocks) == len(heads), \"Number of blocks and heads must be equal\"\n        assert all(n > 0 for n in num_blocks), \"Number of blocks must be greater than 0\"\n\n        # Initial feature extraction\n        self.patch_embed = OverlapPatchEmbed(spatial_dims, in_channels, dim)\n        self.encoder_levels = nn.ModuleList()\n        self.downsamples = nn.ModuleList()\n        self.decoder_levels = nn.ModuleList()\n        self.upsamples = nn.ModuleList()\n        self.reduce_channels = nn.ModuleList()\n        num_steps = len(num_blocks) - 1\n        self.num_steps = num_steps\n        self.spatial_dims = spatial_dims\n        spatial_multiplier = 2 ** (spatial_dims - 1)\n\n        # Define encoder levels\n        for n in range(num_steps):\n            current_dim = dim * (2) ** (n)\n            next_dim = current_dim // spatial_multiplier\n            self.encoder_levels.append(\n                nn.Sequential(\n                    *[\n                        MDTATransformerBlock(\n                            spatial_dims=spatial_dims,\n                            dim=current_dim,\n                            num_heads=heads[n],\n                            ffn_expansion_factor=ffn_expansion_factor,\n                            bias=bias,\n                            layer_norm_use_bias=layer_norm_use_bias,\n                            flash_attention=flash_attention,\n                        )\n                        for _ in range(num_blocks[n])\n                    ]\n                )\n            )\n\n            self.downsamples.append(\n                DownSample(\n                    
spatial_dims=self.spatial_dims,\n                    in_channels=current_dim,\n                    out_channels=next_dim,\n                    mode=DownsampleMode.PIXELUNSHUFFLE,\n                    scale_factor=2,\n                    bias=bias,\n                )\n            )\n\n        # Define latent space\n        latent_dim = dim * (2) ** (num_steps)\n        self.latent = nn.Sequential(\n            *[\n                MDTATransformerBlock(\n                    spatial_dims=spatial_dims,\n                    dim=latent_dim,\n                    num_heads=heads[num_steps],\n                    ffn_expansion_factor=ffn_expansion_factor,\n                    bias=bias,\n                    layer_norm_use_bias=layer_norm_use_bias,\n                    flash_attention=flash_attention,\n                )\n                for _ in range(num_blocks[num_steps])\n            ]\n        )\n\n        # Define decoder levels\n        for n in reversed(range(num_steps)):\n            current_dim = dim * (2) ** (n)\n            next_dim = dim * (2) ** (n + 1)\n            self.upsamples.append(\n                UpSample(\n                    spatial_dims=self.spatial_dims,\n                    in_channels=next_dim,\n                    out_channels=(current_dim),\n                    mode=UpsampleMode.PIXELSHUFFLE,\n                    scale_factor=2,\n                    bias=bias,\n                    apply_pad_pool=False,\n                )\n            )\n\n            # Reduce channel layers to deal with skip connections\n            if n != 0:\n                self.reduce_channels.append(\n                    Convolution(\n                        spatial_dims=self.spatial_dims,\n                        in_channels=next_dim,\n                        out_channels=current_dim,\n                        kernel_size=1,\n                        bias=bias,\n                        conv_only=True,\n                    )\n                )\n                decoder_dim = 
current_dim\n            else:\n                decoder_dim = next_dim\n\n            self.decoder_levels.append(\n                nn.Sequential(\n                    *[\n                        MDTATransformerBlock(\n                            spatial_dims=spatial_dims,\n                            dim=decoder_dim,\n                            num_heads=heads[n],\n                            ffn_expansion_factor=ffn_expansion_factor,\n                            bias=bias,\n                            layer_norm_use_bias=layer_norm_use_bias,\n                            flash_attention=flash_attention,\n                        )\n                        for _ in range(num_blocks[n])\n                    ]\n                )\n            )\n\n        # Final refinement and output\n        self.refinement = nn.Sequential(\n            *[\n                MDTATransformerBlock(\n                    spatial_dims=spatial_dims,\n                    dim=decoder_dim,\n                    num_heads=heads[0],\n                    ffn_expansion_factor=ffn_expansion_factor,\n                    bias=bias,\n                    layer_norm_use_bias=layer_norm_use_bias,\n                    flash_attention=flash_attention,\n                )\n                for _ in range(num_refinement_blocks)\n            ]\n        )\n        self.dual_pixel_task = dual_pixel_task\n        if self.dual_pixel_task:\n            self.skip_conv = Convolution(\n                spatial_dims=self.spatial_dims,\n                in_channels=dim,\n                out_channels=dim * 2,\n                kernel_size=1,\n                bias=bias,\n                conv_only=True,\n            )\n        self.output = Convolution(\n            spatial_dims=self.spatial_dims,\n            in_channels=dim * 2,\n            out_channels=out_channels,\n            kernel_size=3,\n            strides=1,\n            padding=1,\n            bias=bias,\n            conv_only=True,\n        )\n\n    def 
forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Forward pass of Restormer.\n        Processes input through encoder-decoder architecture with skip connections.\n        Args:\n            inp_img: Input image tensor of shape (B, C, H, W, [D])\n\n        Returns:\n            Restored image tensor of shape (B, C, H, W, [D])\n        \"\"\"\n        assert all(\n            x.shape[-i] > 2**self.num_steps for i in range(1, self.spatial_dims + 1)\n        ), \"All spatial dimensions should be larger than 2^number_of_step\"\n\n        # Patch embedding\n        x = self.patch_embed(x)\n        skip_connections = []\n\n        # Encoding path\n        for _idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)):\n            x = encoder(x)\n            skip_connections.append(x)\n            x = downsample(x)\n\n        # Latent space\n        x = self.latent(x)\n\n        # Decoding path\n        for idx in range(len(self.decoder_levels)):\n            x = self.upsamples[idx](x)\n            x = torch.concat([x, skip_connections[-(idx + 1)]], 1)\n            if idx < len(self.decoder_levels) - 1:\n                x = self.reduce_channels[idx](x)\n            x = self.decoder_levels[idx](x)\n\n        # Final refinement\n        x = self.refinement(x)\n\n        if self.dual_pixel_task:\n            x = x + self.skip_conv(skip_connections[0])\n            x = self.output(x)\n        else:\n            x = self.output(x)\n\n        return x\n"
  },
  {
    "path": "monai/networks/nets/segresnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.blocks.segresnet_block import ResBlock, get_conv_layer, get_upsample_layer\nfrom monai.networks.layers.factories import Dropout\nfrom monai.networks.layers.utils import get_act_layer, get_norm_layer\nfrom monai.utils import UpsampleMode\n\n__all__ = [\"SegResNet\", \"SegResNetVAE\"]\n\n\nclass SegResNet(nn.Module):\n    \"\"\"\n    SegResNet based on `3D MRI brain tumor segmentation using autoencoder regularization\n    <https://arxiv.org/pdf/1810.11654.pdf>`_.\n    The module does not include the variational autoencoder (VAE).\n    The model supports 2D or 3D inputs.\n\n    Args:\n        spatial_dims: spatial dimension of the input data. Defaults to 3.\n        init_filters: number of output channels for initial convolution layer. Defaults to 8.\n        in_channels: number of input channels for the network. Defaults to 1.\n        out_channels: number of output channels for the network. Defaults to 2.\n        dropout_prob: probability of an element to be zero-ed. Defaults to ``None``.\n        act: activation type and arguments. Defaults to ``RELU``.\n        norm: feature normalization type and arguments. 
Defaults to ``GROUP``.\n        norm_name: deprecating option for feature normalization type.\n        num_groups: deprecating option for group norm. parameters.\n        use_conv_final: if add a final convolution block to output. Defaults to ``True``.\n        blocks_down: number of down sample blocks in each layer. Defaults to ``[1,2,2,4]``.\n        blocks_up: number of up sample blocks in each layer. Defaults to ``[1,1,1]``.\n        upsample_mode: [``\"deconv\"``, ``\"nontrainable\"``, ``\"pixelshuffle\"``]\n            The mode of upsampling manipulations.\n            Using the ``nontrainable`` modes cannot guarantee the model's reproducibility. Defaults to``nontrainable``.\n\n            - ``deconv``, uses transposed convolution layers.\n            - ``nontrainable``, uses non-trainable `linear` interpolation.\n            - ``pixelshuffle``, uses :py:class:`monai.networks.blocks.SubpixelUpsample`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int = 3,\n        init_filters: int = 8,\n        in_channels: int = 1,\n        out_channels: int = 2,\n        dropout_prob: float | None = None,\n        act: tuple | str = (\"RELU\", {\"inplace\": True}),\n        norm: tuple | str = (\"GROUP\", {\"num_groups\": 8}),\n        norm_name: str = \"\",\n        num_groups: int = 8,\n        use_conv_final: bool = True,\n        blocks_down: tuple = (1, 2, 2, 4),\n        blocks_up: tuple = (1, 1, 1),\n        upsample_mode: UpsampleMode | str = UpsampleMode.NONTRAINABLE,\n    ):\n        super().__init__()\n\n        if spatial_dims not in (2, 3):\n            raise ValueError(\"`spatial_dims` can only be 2 or 3.\")\n\n        self.spatial_dims = spatial_dims\n        self.init_filters = init_filters\n        self.in_channels = in_channels\n        self.blocks_down = blocks_down\n        self.blocks_up = blocks_up\n        self.dropout_prob = dropout_prob\n        self.act = act  # input options\n        self.act_mod = get_act_layer(act)\n  
      if norm_name:\n            if norm_name.lower() != \"group\":\n                raise ValueError(f\"Deprecating option 'norm_name={norm_name}', please use 'norm' instead.\")\n            norm = (\"group\", {\"num_groups\": num_groups})\n        self.norm = norm\n        self.upsample_mode = UpsampleMode(upsample_mode)\n        self.use_conv_final = use_conv_final\n        self.convInit = get_conv_layer(spatial_dims, in_channels, init_filters)\n        self.down_layers = self._make_down_layers()\n        self.up_layers, self.up_samples = self._make_up_layers()\n        self.conv_final = self._make_final_conv(out_channels)\n\n        if dropout_prob is not None:\n            self.dropout = Dropout[Dropout.DROPOUT, spatial_dims](dropout_prob)\n\n    def _make_down_layers(self):\n        down_layers = nn.ModuleList()\n        blocks_down, spatial_dims, filters, norm = (self.blocks_down, self.spatial_dims, self.init_filters, self.norm)\n        for i, item in enumerate(blocks_down):\n            layer_in_channels = filters * 2**i\n            pre_conv = (\n                get_conv_layer(spatial_dims, layer_in_channels // 2, layer_in_channels, stride=2)\n                if i > 0\n                else nn.Identity()\n            )\n            down_layer = nn.Sequential(\n                pre_conv, *[ResBlock(spatial_dims, layer_in_channels, norm=norm, act=self.act) for _ in range(item)]\n            )\n            down_layers.append(down_layer)\n        return down_layers\n\n    def _make_up_layers(self):\n        up_layers, up_samples = nn.ModuleList(), nn.ModuleList()\n        upsample_mode, blocks_up, spatial_dims, filters, norm = (\n            self.upsample_mode,\n            self.blocks_up,\n            self.spatial_dims,\n            self.init_filters,\n            self.norm,\n        )\n        n_up = len(blocks_up)\n        for i in range(n_up):\n            sample_in_channels = filters * 2 ** (n_up - i)\n            up_layers.append(\n                
nn.Sequential(\n                    *[\n                        ResBlock(spatial_dims, sample_in_channels // 2, norm=norm, act=self.act)\n                        for _ in range(blocks_up[i])\n                    ]\n                )\n            )\n            up_samples.append(\n                nn.Sequential(\n                    *[\n                        get_conv_layer(spatial_dims, sample_in_channels, sample_in_channels // 2, kernel_size=1),\n                        get_upsample_layer(spatial_dims, sample_in_channels // 2, upsample_mode=upsample_mode),\n                    ]\n                )\n            )\n        return up_layers, up_samples\n\n    def _make_final_conv(self, out_channels: int):\n        return nn.Sequential(\n            get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.init_filters),\n            self.act_mod,\n            get_conv_layer(self.spatial_dims, self.init_filters, out_channels, kernel_size=1, bias=True),\n        )\n\n    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:\n        x = self.convInit(x)\n        if self.dropout_prob is not None:\n            x = self.dropout(x)\n\n        down_x = []\n\n        for down in self.down_layers:\n            x = down(x)\n            down_x.append(x)\n\n        return x, down_x\n\n    def decode(self, x: torch.Tensor, down_x: list[torch.Tensor]) -> torch.Tensor:\n        for i, (up, upl) in enumerate(zip(self.up_samples, self.up_layers)):\n            x = up(x) + down_x[i + 1]\n            x = upl(x)\n\n        if self.use_conv_final:\n            x = self.conv_final(x)\n\n        return x\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x, down_x = self.encode(x)\n        down_x.reverse()\n\n        x = self.decode(x, down_x)\n        return x\n\n\nclass SegResNetVAE(SegResNet):\n    \"\"\"\n    SegResNetVAE based on `3D MRI brain tumor segmentation using autoencoder regularization\n    
<https://arxiv.org/pdf/1810.11654.pdf>`_.\n    The module contains the variational autoencoder (VAE).\n    The model supports 2D or 3D inputs.\n\n    Args:\n        input_image_size: the size of images to input into the network. It is used to\n            determine the in_features of the fc layer in VAE.\n        vae_estimate_std: whether to estimate the standard deviations in VAE. Defaults to ``False``.\n        vae_default_std: if not to estimate the std, use the default value. Defaults to 0.3.\n        vae_nz: number of latent variables in VAE. Defaults to 256.\n            Where, 128 to represent mean, and 128 to represent std.\n        spatial_dims: spatial dimension of the input data. Defaults to 3.\n        init_filters: number of output channels for initial convolution layer. Defaults to 8.\n        in_channels: number of input channels for the network. Defaults to 1.\n        out_channels: number of output channels for the network. Defaults to 2.\n        dropout_prob: probability of an element to be zero-ed. Defaults to ``None``.\n        act: activation type and arguments. Defaults to ``RELU``.\n        norm: feature normalization type and arguments. Defaults to ``GROUP``.\n        use_conv_final: if add a final convolution block to output. Defaults to ``True``.\n        blocks_down: number of down sample blocks in each layer. Defaults to ``[1,2,2,4]``.\n        blocks_up: number of up sample blocks in each layer. Defaults to ``[1,1,1]``.\n        upsample_mode: [``\"deconv\"``, ``\"nontrainable\"``, ``\"pixelshuffle\"``]\n            The mode of upsampling manipulations.\n            Using the ``nontrainable`` modes cannot guarantee the model's reproducibility. 
Defaults to ``nontrainable``.\n\n            - ``deconv``, uses transposed convolution layers.\n            - ``nontrainable``, uses non-trainable `linear` interpolation.\n            - ``pixelshuffle``, uses :py:class:`monai.networks.blocks.SubpixelUpsample`.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_image_size: Sequence[int],\n        vae_estimate_std: bool = False,\n        vae_default_std: float = 0.3,\n        vae_nz: int = 256,\n        spatial_dims: int = 3,\n        init_filters: int = 8,\n        in_channels: int = 1,\n        out_channels: int = 2,\n        dropout_prob: float | None = None,\n        act: str | tuple = (\"RELU\", {\"inplace\": True}),\n        norm: tuple | str = (\"GROUP\", {\"num_groups\": 8}),\n        use_conv_final: bool = True,\n        blocks_down: tuple = (1, 2, 2, 4),\n        blocks_up: tuple = (1, 1, 1),\n        upsample_mode: UpsampleMode | str = UpsampleMode.NONTRAINABLE,\n    ):\n        super().__init__(\n            spatial_dims=spatial_dims,\n            init_filters=init_filters,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            dropout_prob=dropout_prob,\n            act=act,\n            norm=norm,\n            use_conv_final=use_conv_final,\n            blocks_down=blocks_down,\n            blocks_up=blocks_up,\n            upsample_mode=upsample_mode,\n        )\n\n        self.input_image_size = input_image_size\n        self.smallest_filters = 16\n\n        zoom = 2 ** (len(self.blocks_down) - 1)\n        self.fc_insize = [s // (2 * zoom) for s in self.input_image_size]\n\n        self.vae_estimate_std = vae_estimate_std\n        self.vae_default_std = vae_default_std\n        self.vae_nz = vae_nz\n        self._prepare_vae_modules()\n        self.vae_conv_final = self._make_final_conv(in_channels)\n\n    def _prepare_vae_modules(self):\n        zoom = 2 ** (len(self.blocks_down) - 1)\n        v_filters = self.init_filters * zoom\n        total_elements 
= int(self.smallest_filters * np.prod(self.fc_insize))\n\n        self.vae_down = nn.Sequential(\n            get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=v_filters),\n            self.act_mod,\n            get_conv_layer(self.spatial_dims, v_filters, self.smallest_filters, stride=2, bias=True),\n            get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=self.smallest_filters),\n            self.act_mod,\n        )\n        self.vae_fc1 = nn.Linear(total_elements, self.vae_nz)\n        self.vae_fc2 = nn.Linear(total_elements, self.vae_nz)\n        self.vae_fc3 = nn.Linear(self.vae_nz, total_elements)\n\n        self.vae_fc_up_sample = nn.Sequential(\n            get_conv_layer(self.spatial_dims, self.smallest_filters, v_filters, kernel_size=1),\n            get_upsample_layer(self.spatial_dims, v_filters, upsample_mode=self.upsample_mode),\n            get_norm_layer(name=self.norm, spatial_dims=self.spatial_dims, channels=v_filters),\n            self.act_mod,\n        )\n\n    def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor):\n        \"\"\"\n        Args:\n            net_input: the original input of the network.\n            vae_input: the input of VAE module, which is also the output of the network's encoder.\n        \"\"\"\n        x_vae = self.vae_down(vae_input)\n        x_vae = x_vae.view(-1, self.vae_fc1.in_features)\n        z_mean = self.vae_fc1(x_vae)\n\n        z_mean_rand = torch.randn_like(z_mean)\n        z_mean_rand.requires_grad_(False)\n\n        if self.vae_estimate_std:\n            z_sigma = self.vae_fc2(x_vae)\n            z_sigma = F.softplus(z_sigma)\n            vae_reg_loss = 0.5 * torch.mean(z_mean**2 + z_sigma**2 - torch.log(1e-8 + z_sigma**2) - 1)\n\n            x_vae = z_mean + z_sigma * z_mean_rand\n        else:\n            z_sigma = self.vae_default_std\n            vae_reg_loss = torch.mean(z_mean**2)\n\n            x_vae = z_mean + z_sigma * 
z_mean_rand\n\n        x_vae = self.vae_fc3(x_vae)\n        x_vae = self.act_mod(x_vae)\n        x_vae = x_vae.view([-1, self.smallest_filters] + self.fc_insize)\n        x_vae = self.vae_fc_up_sample(x_vae)\n\n        for up, upl in zip(self.up_samples, self.up_layers):\n            x_vae = up(x_vae)\n            x_vae = upl(x_vae)\n\n        x_vae = self.vae_conv_final(x_vae)\n        vae_mse_loss = F.mse_loss(net_input, x_vae)\n        vae_loss = vae_reg_loss + vae_mse_loss\n        return vae_loss\n\n    def forward(self, x):\n        net_input = x\n        x, down_x = self.encode(x)\n        down_x.reverse()\n\n        vae_input = x\n        x = self.decode(x, down_x)\n\n        if self.training:\n            vae_loss = self._get_vae_loss(net_input, vae_input)\n            return x, vae_loss\n\n        return x, None\n"
  },
  {
    "path": "monai/networks/nets/segresnet_ds.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport copy\nfrom collections.abc import Callable\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.upsample import UpSample\nfrom monai.networks.layers.factories import Act, Conv, Norm, split_args\nfrom monai.networks.layers.utils import get_act_layer, get_norm_layer\nfrom monai.utils import UpsampleMode, has_option\n\n__all__ = [\"SegResNetDS\", \"SegResNetDS2\"]\n\n\ndef scales_for_resolution(resolution: tuple | list, n_stages: int | None = None):\n    \"\"\"\n    A helper function to compute a schedule of scale at different downsampling levels,\n    given the input resolution.\n\n    .. 
code-block:: python\n\n        scales_for_resolution(resolution=[1,1,5], n_stages=5)\n\n    Args:\n        resolution: input image resolution (in mm)\n        n_stages: optionally the number of stages of the network\n    \"\"\"\n\n    ndim = len(resolution)\n    res = np.array(resolution)\n    if not all(res > 0):\n        raise ValueError(\"Resolution must be positive\")\n\n    nl = np.floor(np.log2(np.max(res) / res)).astype(np.int32)\n    scales = [tuple(np.where(2**i >= 2**nl, 1, 2)) for i in range(max(nl))]\n    if n_stages and n_stages > max(nl):\n        scales = scales + [(2,) * ndim] * (n_stages - max(nl))\n    else:\n        scales = scales[:n_stages]\n    return scales\n\n\ndef aniso_kernel(scale: tuple | list):\n    \"\"\"\n    A helper function to compute kernel_size, padding and stride for the given scale\n\n    Args:\n        scale: scale from a current scale level\n    \"\"\"\n    kernel_size = [3 if scale[k] > 1 else 1 for k in range(len(scale))]\n    padding = [k // 2 for k in kernel_size]\n    return kernel_size, padding, scale\n\n\nclass SegResBlock(nn.Module):\n    \"\"\"\n    Residual network block used SegResNet based on `3D MRI brain tumor segmentation using autoencoder regularization\n    <https://arxiv.org/pdf/1810.11654.pdf>`_.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        norm: tuple | str,\n        kernel_size: tuple | int = 3,\n        act: tuple | str = \"relu\",\n    ) -> None:\n        \"\"\"\n        Args:\n            spatial_dims: number of spatial dimensions, could be 1, 2 or 3.\n            in_channels: number of input channels.\n            norm: feature normalization type and arguments.\n            kernel_size: convolution kernel size. Defaults to 3.\n            act: activation type and arguments. 
Defaults to ``RELU``.\n        \"\"\"\n        super().__init__()\n\n        if isinstance(kernel_size, (tuple, list)):\n            padding = tuple(k // 2 for k in kernel_size)\n        else:\n            padding = kernel_size // 2  # type: ignore\n\n        self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)\n        self.act1 = get_act_layer(act)\n        self.conv1 = Conv[Conv.CONV, spatial_dims](\n            in_channels=in_channels,\n            out_channels=in_channels,\n            kernel_size=kernel_size,\n            stride=1,\n            padding=padding,\n            bias=False,\n        )\n\n        self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)\n        self.act2 = get_act_layer(act)\n        self.conv2 = Conv[Conv.CONV, spatial_dims](\n            in_channels=in_channels,\n            out_channels=in_channels,\n            kernel_size=kernel_size,\n            stride=1,\n            padding=padding,\n            bias=False,\n        )\n\n    def forward(self, x):\n        identity = x\n        x = self.conv2(self.act2(self.norm2(self.conv1(self.act1(self.norm1(x))))))\n        x += identity\n        return x\n\n\nclass SegResEncoder(nn.Module):\n    \"\"\"\n    SegResEncoder based on the encoder structure in `3D MRI brain tumor segmentation using autoencoder regularization\n    <https://arxiv.org/pdf/1810.11654.pdf>`_.\n\n    Args:\n        spatial_dims: spatial dimension of the input data. Defaults to 3.\n        init_filters: number of output channels for initial convolution layer. Defaults to 32.\n        in_channels: number of input channels for the network. Defaults to 1.\n        out_channels: number of output channels for the network. Defaults to 2.\n        act: activation type and arguments. Defaults to ``RELU``.\n        norm: feature normalization type and arguments. Defaults to ``BATCH``.\n        blocks_down: number of downsample blocks in each layer. 
Defaults to ``[1,2,2,4]``.\n        head_module: optional callable module to apply to the final features.\n        anisotropic_scales: optional list of scale for each scale level.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int = 3,\n        init_filters: int = 32,\n        in_channels: int = 1,\n        act: tuple | str = \"relu\",\n        norm: tuple | str = \"batch\",\n        blocks_down: tuple = (1, 2, 2, 4),\n        head_module: nn.Module | None = None,\n        anisotropic_scales: tuple | None = None,\n    ):\n        super().__init__()\n\n        if spatial_dims not in (1, 2, 3):\n            raise ValueError(\"`spatial_dims` can only be 1, 2 or 3.\")\n\n        # ensure normalization has affine trainable parameters (if not specified)\n        norm = split_args(norm)\n        if has_option(Norm[norm[0], spatial_dims], \"affine\"):\n            norm[1].setdefault(\"affine\", True)  # type: ignore\n\n        # ensure activation is inplace (if not specified)\n        act = split_args(act)\n        if has_option(Act[act[0]], \"inplace\"):\n            act[1].setdefault(\"inplace\", True)  # type: ignore\n\n        filters = init_filters  # base number of features\n\n        kernel_size, padding, _ = aniso_kernel(anisotropic_scales[0]) if anisotropic_scales else (3, 1, 1)\n        self.conv_init = Conv[Conv.CONV, spatial_dims](\n            in_channels=in_channels,\n            out_channels=filters,\n            kernel_size=kernel_size,\n            padding=padding,\n            stride=1,\n            bias=False,\n        )\n        self.layers = nn.ModuleList()\n\n        for i in range(len(blocks_down)):\n            level = nn.ModuleDict()\n\n            kernel_size, padding, stride = aniso_kernel(anisotropic_scales[i]) if anisotropic_scales else (3, 1, 2)\n            blocks = [\n                SegResBlock(spatial_dims=spatial_dims, in_channels=filters, kernel_size=kernel_size, norm=norm, act=act)\n                for _ in 
range(blocks_down[i])\n            ]\n            level[\"blocks\"] = nn.Sequential(*blocks)\n\n            if i < len(blocks_down) - 1:\n                level[\"downsample\"] = Conv[Conv.CONV, spatial_dims](\n                    in_channels=filters,\n                    out_channels=2 * filters,\n                    bias=False,\n                    kernel_size=kernel_size,\n                    stride=stride,\n                    padding=padding,\n                )\n            else:\n                level[\"downsample\"] = nn.Identity()\n\n            self.layers.append(level)\n            filters *= 2\n\n        self.head_module = head_module\n        self.in_channels = in_channels\n        self.blocks_down = blocks_down\n        self.init_filters = init_filters\n        self.norm = norm\n        self.act = act\n        self.spatial_dims = spatial_dims\n\n    def _forward(self, x: torch.Tensor) -> list[torch.Tensor]:\n        outputs = []\n        x = self.conv_init(x)\n\n        for level in self.layers:\n            x = level[\"blocks\"](x)  # type: ignore\n            outputs.append(x)\n            x = level[\"downsample\"](x)  # type: ignore\n\n        if self.head_module is not None:\n            outputs = self.head_module(outputs)\n\n        return outputs\n\n    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:\n        return self._forward(x)\n\n\nclass SegResNetDS(nn.Module):\n    \"\"\"\n    SegResNetDS based on `3D MRI brain tumor segmentation using autoencoder regularization\n    <https://arxiv.org/pdf/1810.11654.pdf>`_.\n    It is similar to https://monai.readthedocs.io/en/stable/networks.html#segresnet, with several\n    improvements including deep supervision and non-isotropic kernel support.\n\n    Args:\n        spatial_dims: spatial dimension of the input data. Defaults to 3.\n        init_filters: number of output channels for initial convolution layer. Defaults to 32.\n        in_channels: number of input channels for the network. 
Defaults to 1.\n        out_channels: number of output channels for the network. Defaults to 2.\n        act: activation type and arguments. Defaults to ``RELU``.\n        norm: feature normalization type and arguments. Defaults to ``BATCH``.\n        blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.\n        blocks_up: number of upsample blocks (optional).\n        dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level.\n                 At dsdepth==1,only a single output is returned.\n        preprocess: optional callable function to apply before the model's forward pass\n        resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring\n                    image spacing into an approximately isotropic space.\n                    Otherwise, by default, the kernel size and downsampling is always isotropic.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int = 3,\n        init_filters: int = 32,\n        in_channels: int = 1,\n        out_channels: int = 2,\n        act: tuple | str = \"relu\",\n        norm: tuple | str = \"batch\",\n        blocks_down: tuple = (1, 2, 2, 4),\n        blocks_up: tuple | None = None,\n        dsdepth: int = 1,\n        preprocess: nn.Module | Callable | None = None,\n        upsample_mode: UpsampleMode | str = \"deconv\",\n        resolution: tuple | None = None,\n    ):\n        super().__init__()\n\n        if spatial_dims not in (1, 2, 3):\n            raise ValueError(\"`spatial_dims` can only be 1, 2 or 3.\")\n\n        self.spatial_dims = spatial_dims\n        self.init_filters = init_filters\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.act = act\n        self.norm = norm\n        self.blocks_down = blocks_down\n        self.dsdepth = max(dsdepth, 1)\n        self.resolution = resolution\n        
self.preprocess = preprocess\n\n        if resolution is not None:\n            if not isinstance(resolution, (list, tuple)):\n                raise TypeError(\"resolution must be a tuple\")\n            elif not all(r > 0 for r in resolution):\n                raise ValueError(\"resolution must be positive\")\n\n        # ensure normalization had affine trainable parameters (if not specified)\n        norm = split_args(norm)\n        if has_option(Norm[norm[0], spatial_dims], \"affine\"):\n            norm[1].setdefault(\"affine\", True)  # type: ignore\n\n        # ensure activation is inplace (if not specified)\n        act = split_args(act)\n        if has_option(Act[act[0]], \"inplace\"):\n            act[1].setdefault(\"inplace\", True)  # type: ignore\n\n        anisotropic_scales = None\n        if resolution:\n            anisotropic_scales = scales_for_resolution(resolution, n_stages=len(blocks_down))\n        self.anisotropic_scales = anisotropic_scales\n\n        self.encoder = SegResEncoder(\n            spatial_dims=spatial_dims,\n            init_filters=init_filters,\n            in_channels=in_channels,\n            act=act,\n            norm=norm,\n            blocks_down=blocks_down,\n            anisotropic_scales=anisotropic_scales,\n        )\n\n        n_up = len(blocks_down) - 1\n        if blocks_up is None:\n            blocks_up = (1,) * n_up  # assume 1 upsample block per level\n        self.blocks_up = blocks_up\n\n        filters = init_filters * 2**n_up\n        self.up_layers = nn.ModuleList()\n\n        for i in range(n_up):\n            filters = filters // 2\n            kernel_size, _, stride = (\n                aniso_kernel(anisotropic_scales[len(blocks_up) - i - 1]) if anisotropic_scales else (3, 1, 2)\n            )\n\n            level = nn.ModuleDict()\n            level[\"upsample\"] = UpSample(\n                mode=upsample_mode,\n                spatial_dims=spatial_dims,\n                in_channels=2 * filters,\n      
          out_channels=filters,\n                kernel_size=kernel_size,\n                scale_factor=stride,\n                bias=False,\n                align_corners=False,\n            )\n            blocks = [\n                SegResBlock(spatial_dims=spatial_dims, in_channels=filters, kernel_size=kernel_size, norm=norm, act=act)\n                for _ in range(blocks_up[i])\n            ]\n            level[\"blocks\"] = nn.Sequential(*blocks)\n\n            if len(blocks_up) - i <= dsdepth:  # deep supervision heads\n                level[\"head\"] = Conv[Conv.CONV, spatial_dims](\n                    in_channels=filters, out_channels=out_channels, kernel_size=1, bias=True\n                )\n            else:\n                level[\"head\"] = nn.Identity()\n\n            self.up_layers.append(level)\n\n        if n_up == 0:  # in a corner case of flat structure (no downsampling), attache a single head\n            level = nn.ModuleDict(\n                {\n                    \"upsample\": nn.Identity(),\n                    \"blocks\": nn.Identity(),\n                    \"head\": Conv[Conv.CONV, spatial_dims](\n                        in_channels=filters, out_channels=out_channels, kernel_size=1, bias=True\n                    ),\n                }\n            )\n            self.up_layers.append(level)\n\n    def shape_factor(self):\n        \"\"\"\n        Calculate the factors (divisors) that the input image shape must be divisible by\n        \"\"\"\n        if self.anisotropic_scales is None:\n            d = [2 ** (len(self.blocks_down) - 1)] * self.spatial_dims\n        else:\n            d = list(np.prod(np.array(self.anisotropic_scales[:-1]), axis=0))\n        return d\n\n    def is_valid_shape(self, x):\n        \"\"\"\n        Calculate if the input shape is divisible by the minimum factors for the current network configuration\n        \"\"\"\n        a = [i % j == 0 for i, j in zip(x.shape[2:], self.shape_factor())]\n        return 
all(a)\n\n    def _forward(self, x: torch.Tensor) -> None | torch.Tensor | list[torch.Tensor]:\n        if self.preprocess is not None:\n            x = self.preprocess(x)\n\n        if not self.is_valid_shape(x):\n            raise ValueError(f\"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}\")\n\n        x_down = self.encoder(x)\n\n        x_down.reverse()\n        x = x_down.pop(0)\n\n        if len(x_down) == 0:\n            x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)]\n\n        outputs: list[torch.Tensor] = []\n\n        i = 0\n        for level in self.up_layers:\n            x = level[\"upsample\"](x)  # type: ignore\n            x += x_down.pop(0)\n            x = level[\"blocks\"](x)  # type: ignore\n\n            if len(self.up_layers) - i <= self.dsdepth:\n                outputs.append(level[\"head\"](x))  # type: ignore\n            i = i + 1\n\n        outputs.reverse()\n\n        # in eval() mode, always return a single final output\n        if not self.training or len(outputs) == 1:\n            return outputs[0]\n\n        # return a list of DS outputs\n        return outputs\n\n    def forward(self, x: torch.Tensor) -> None | torch.Tensor | list[torch.Tensor]:\n        return self._forward(x)\n\n\nclass SegResNetDS2(SegResNetDS):\n    \"\"\"\n    SegResNetDS2 adds an additional decoder branch to SegResNetDS and is the image encoder of `VISTA3D\n    <https://arxiv.org/abs/2406.05285>`_.\n\n    Args:\n        spatial_dims: spatial dimension of the input data. Defaults to 3.\n        init_filters: number of output channels for initial convolution layer. Defaults to 32.\n        in_channels: number of input channels for the network. Defaults to 1.\n        out_channels: number of output channels for the network. Defaults to 2.\n        act: activation type and arguments. Defaults to ``RELU``.\n        norm: feature normalization type and arguments. 
Defaults to ``BATCH``.\n        blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.\n        blocks_up: number of upsample blocks (optional).\n        dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level.\n                 At dsdepth==1,only a single output is returned.\n        preprocess: optional callable function to apply before the model's forward pass\n        resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring\n                    image spacing into an approximately isotropic space.\n                    Otherwise, by default, the kernel size and downsampling is always isotropic.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int = 3,\n        init_filters: int = 32,\n        in_channels: int = 1,\n        out_channels: int = 2,\n        act: tuple | str = \"relu\",\n        norm: tuple | str = \"batch\",\n        blocks_down: tuple = (1, 2, 2, 4),\n        blocks_up: tuple | None = None,\n        dsdepth: int = 1,\n        preprocess: nn.Module | Callable | None = None,\n        upsample_mode: UpsampleMode | str = \"deconv\",\n        resolution: tuple | None = None,\n    ):\n        super().__init__(\n            spatial_dims=spatial_dims,\n            init_filters=init_filters,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            act=act,\n            norm=norm,\n            blocks_down=blocks_down,\n            blocks_up=blocks_up,\n            dsdepth=dsdepth,\n            preprocess=preprocess,\n            upsample_mode=upsample_mode,\n            resolution=resolution,\n        )\n\n        self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers])\n\n    def forward(  # type: ignore\n        self, x: torch.Tensor, with_point: bool = True, with_label: bool = True\n    ) -> tuple[None | torch.Tensor | 
list[torch.Tensor], None | torch.Tensor | list[torch.Tensor]]:\n        \"\"\"\n        Args:\n            x: input tensor.\n            with_point: if true, return the point branch output.\n            with_label: if true, return the label branch output.\n        \"\"\"\n        if self.preprocess is not None:\n            x = self.preprocess(x)\n\n        if not self.is_valid_shape(x):\n            raise ValueError(f\"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}\")\n\n        x_down = self.encoder(x)\n\n        x_down.reverse()\n        x = x_down.pop(0)\n\n        if len(x_down) == 0:\n            x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)]\n\n        outputs: list[torch.Tensor] = []\n        outputs_auto: list[torch.Tensor] = []\n        level: nn.ModuleDict\n        x_ = x\n        if with_point:\n            if with_label:\n                x_ = x.clone()\n            i = 0\n            for level in self.up_layers:  # type: ignore\n                x = level[\"upsample\"](x)\n                x = x + x_down[i]\n                x = level[\"blocks\"](x)\n\n                if len(self.up_layers) - i <= self.dsdepth:\n                    outputs.append(level[\"head\"](x))\n                i = i + 1\n\n            outputs.reverse()\n        x = x_\n        if with_label:\n            i = 0\n            for level in self.up_layers_auto:  # type: ignore\n                x = level[\"upsample\"](x)\n                x = x + x_down[i]\n                x = level[\"blocks\"](x)\n\n                if len(self.up_layers) - i <= self.dsdepth:\n                    outputs_auto.append(level[\"head\"](x))\n                i = i + 1\n\n            outputs_auto.reverse()\n\n        return outputs[0] if len(outputs) == 1 else outputs, outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto\n\n    def set_auto_grad(self, auto_freeze=False, point_freeze=False):\n        \"\"\"\n        Args:\n            auto_freeze: if true, freeze the 
image encoder and the auto-branch.\n            point_freeze: if true, freeze the image encoder and the point-branch.\n        \"\"\"\n        for param in self.encoder.parameters():\n            param.requires_grad = (not auto_freeze) and (not point_freeze)\n\n        for param in self.up_layers_auto.parameters():\n            param.requires_grad = not auto_freeze\n\n        for param in self.up_layers.parameters():\n            param.requires_grad = not point_freeze\n"
  },
  {
    "path": "monai/networks/nets/senet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport re\nfrom collections import OrderedDict\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport torch\nimport torch.nn as nn\nfrom torch.hub import load_state_dict_from_url\n\nfrom monai.apps.utils import download_url\nfrom monai.networks.blocks.convolutions import Convolution\nfrom monai.networks.blocks.squeeze_and_excitation import SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck\nfrom monai.networks.layers.factories import Act, Conv, Dropout, Norm, Pool\nfrom monai.utils.module import look_up_option\n\n__all__ = [\n    \"SENet\",\n    \"SENet154\",\n    \"SEResNet50\",\n    \"SEResNet101\",\n    \"SEResNet152\",\n    \"SEResNeXt50\",\n    \"SEResNext101\",\n    \"SE_NET_MODELS\",\n]\n\nSE_NET_MODELS = {\n    \"senet154\": \"http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth\",\n    \"se_resnet50\": \"http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth\",\n    \"se_resnet101\": \"http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth\",\n    \"se_resnet152\": \"http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth\",\n    \"se_resnext50_32x4d\": \"http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth\",\n    \"se_resnext101_32x4d\": 
\"http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth\",\n}\n\n\nclass SENet(nn.Module):\n    \"\"\"\n    SENet based on `Squeeze-and-Excitation Networks <https://arxiv.org/pdf/1709.01507.pdf>`_.\n    Adapted from `Cadene Hub 2D version\n    <https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py>`_.\n\n    Args:\n        spatial_dims: spatial dimension of the input data.\n        in_channels: channel number of the input data.\n        block: SEBlock class or str.\n            for SENet154: SEBottleneck or 'se_bottleneck'\n            for SE-ResNet models: SEResNetBottleneck or 'se_resnet_bottleneck'\n            for SE-ResNeXt models:  SEResNeXtBottleneck or 'se_resnetxt_bottleneck'\n        layers: number of residual blocks for 4 layers of the network (layer1...layer4).\n        groups: number of groups for the 3x3 convolution in each bottleneck block.\n            for SENet154: 64\n            for SE-ResNet models: 1\n            for SE-ResNeXt models:  32\n        reduction: reduction ratio for Squeeze-and-Excitation modules.\n            for all models: 16\n        dropout_prob: drop probability for the Dropout layer.\n            if `None` the Dropout layer is not used.\n            for SENet154: 0.2\n            for SE-ResNet models: None\n            for SE-ResNeXt models: None\n        dropout_dim: determine the dimensions of dropout. 
Defaults to 1.\n            When dropout_dim = 1, randomly zeroes some of the elements for each channel.\n            When dropout_dim = 2, Randomly zeroes out entire channels (a channel is a 2D feature map).\n            When dropout_dim = 3, Randomly zeroes out entire channels (a channel is a 3D feature map).\n        inplanes:  number of input channels for layer1.\n            for SENet154: 128\n            for SE-ResNet models: 64\n            for SE-ResNeXt models: 64\n        downsample_kernel_size: kernel size for downsampling convolutions in layer2, layer3 and layer4.\n            for SENet154: 3\n            for SE-ResNet models: 1\n            for SE-ResNeXt models: 1\n        input_3x3: If `True`, use three 3x3 convolutions instead of\n            a single 7x7 convolution in layer0.\n            - For SENet154: True\n            - For SE-ResNet models: False\n            - For SE-ResNeXt models: False\n        num_classes: number of outputs in `last_linear` layer.\n            for all models: 1000\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        block: type[SEBottleneck | SEResNetBottleneck | SEResNeXtBottleneck] | str,\n        layers: Sequence[int],\n        groups: int,\n        reduction: int,\n        dropout_prob: float | None = 0.2,\n        dropout_dim: int = 1,\n        inplanes: int = 128,\n        downsample_kernel_size: int = 3,\n        input_3x3: bool = True,\n        num_classes: int = 1000,\n    ) -> None:\n        super().__init__()\n\n        if isinstance(block, str):\n            if block == \"se_bottleneck\":\n                block = SEBottleneck\n            elif block == \"se_resnet_bottleneck\":\n                block = SEResNetBottleneck\n            elif block == \"se_resnetxt_bottleneck\":\n                block = SEResNeXtBottleneck\n            else:\n                raise ValueError(\n                    f\"Unknown block '{block}', use se_bottleneck, 
se_resnet_bottleneck or se_resnetxt_bottleneck\"\n                )\n\n        relu_type: type[nn.ReLU] = Act[Act.RELU]\n        conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]\n        pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]\n        norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]\n        dropout_type: type[nn.Dropout | nn.Dropout2d | nn.Dropout3d] = Dropout[Dropout.DROPOUT, dropout_dim]\n        avg_pool_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[\n            Pool.ADAPTIVEAVG, spatial_dims\n        ]\n\n        self.inplanes = inplanes\n        self.spatial_dims = spatial_dims\n\n        layer0_modules: list[tuple[str, Any]]\n\n        if input_3x3:\n            layer0_modules = [\n                (\n                    \"conv1\",\n                    conv_type(in_channels=in_channels, out_channels=64, kernel_size=3, stride=2, padding=1, bias=False),\n                ),\n                (\"bn1\", norm_type(num_features=64)),\n                (\"relu1\", relu_type(inplace=True)),\n                (\"conv2\", conv_type(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)),\n                (\"bn2\", norm_type(num_features=64)),\n                (\"relu2\", relu_type(inplace=True)),\n                (\n                    \"conv3\",\n                    conv_type(in_channels=64, out_channels=inplanes, kernel_size=3, stride=1, padding=1, bias=False),\n                ),\n                (\"bn3\", norm_type(num_features=inplanes)),\n                (\"relu3\", relu_type(inplace=True)),\n            ]\n        else:\n            layer0_modules = [\n                (\n                    \"conv1\",\n                    conv_type(\n                        in_channels=in_channels, out_channels=inplanes, kernel_size=7, stride=2, padding=3, bias=False\n                
    ),\n                ),\n                (\"bn1\", norm_type(num_features=inplanes)),\n                (\"relu1\", relu_type(inplace=True)),\n            ]\n\n        layer0_modules.append((\"pool\", pool_type(kernel_size=3, stride=2, ceil_mode=True)))\n        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))\n        self.layer1 = self._make_layer(\n            block, planes=64, blocks=layers[0], groups=groups, reduction=reduction, downsample_kernel_size=1\n        )\n        self.layer2 = self._make_layer(\n            block,\n            planes=128,\n            blocks=layers[1],\n            stride=2,\n            groups=groups,\n            reduction=reduction,\n            downsample_kernel_size=downsample_kernel_size,\n        )\n        self.layer3 = self._make_layer(\n            block,\n            planes=256,\n            blocks=layers[2],\n            stride=2,\n            groups=groups,\n            reduction=reduction,\n            downsample_kernel_size=downsample_kernel_size,\n        )\n        self.layer4 = self._make_layer(\n            block,\n            planes=512,\n            blocks=layers[3],\n            stride=2,\n            groups=groups,\n            reduction=reduction,\n            downsample_kernel_size=downsample_kernel_size,\n        )\n        self.adaptive_avg_pool = avg_pool_type(1)\n        self.dropout = dropout_type(dropout_prob) if dropout_prob is not None else None\n        self.last_linear = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, conv_type):\n                nn.init.kaiming_normal_(torch.as_tensor(m.weight))\n            elif isinstance(m, norm_type):\n                nn.init.constant_(torch.as_tensor(m.weight), 1)\n                nn.init.constant_(torch.as_tensor(m.bias), 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.constant_(torch.as_tensor(m.bias), 0)\n\n    def _make_layer(\n        self,\n        block: 
type[SEBottleneck | SEResNetBottleneck | SEResNeXtBottleneck],\n        planes: int,\n        blocks: int,\n        groups: int,\n        reduction: int,\n        stride: int = 1,\n        downsample_kernel_size: int = 1,\n    ) -> nn.Sequential:\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = Convolution(\n                spatial_dims=self.spatial_dims,\n                in_channels=self.inplanes,\n                out_channels=planes * block.expansion,\n                strides=stride,\n                kernel_size=downsample_kernel_size,\n                act=None,\n                norm=Norm.BATCH,\n                bias=False,\n            )\n\n        layers = []\n        layers.append(\n            block(\n                spatial_dims=self.spatial_dims,\n                inplanes=self.inplanes,\n                planes=planes,\n                groups=groups,\n                reduction=reduction,\n                stride=stride,\n                downsample=downsample,\n            )\n        )\n        self.inplanes = planes * block.expansion\n        for _num in range(1, blocks):\n            layers.append(\n                block(\n                    spatial_dims=self.spatial_dims,\n                    inplanes=self.inplanes,\n                    planes=planes,\n                    groups=groups,\n                    reduction=reduction,\n                )\n            )\n\n        return nn.Sequential(*layers)\n\n    def features(self, x: torch.Tensor):\n        x = self.layer0(x)\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        return x\n\n    def logits(self, x: torch.Tensor):\n        x = self.adaptive_avg_pool(x)\n        if self.dropout is not None:\n            x = self.dropout(x)\n        x = torch.flatten(x, 1)\n        x = self.last_linear(x)\n        return x\n\n    def forward(self, x: torch.Tensor) -> 
torch.Tensor:\n        x = self.features(x)\n        x = self.logits(x)\n        return x\n\n\ndef _load_state_dict(model: nn.Module, arch: str, progress: bool):\n    \"\"\"\n    This function is used to load pretrained models.\n    \"\"\"\n    model_url = look_up_option(arch, SE_NET_MODELS, None)\n    if model_url is None:\n        raise ValueError(\n            \"only 'senet154', 'se_resnet50', 'se_resnet101',  'se_resnet152', 'se_resnext50_32x4d', \"\n            + \"and se_resnext101_32x4d are supported to load pretrained weights.\"\n        )\n\n    pattern_conv = re.compile(r\"^(layer[1-4]\\.\\d\\.(?:conv)\\d\\.)(\\w*)$\")\n    pattern_bn = re.compile(r\"^(layer[1-4]\\.\\d\\.)(?:bn)(\\d\\.)(\\w*)$\")\n    pattern_se = re.compile(r\"^(layer[1-4]\\.\\d\\.)(?:se_module.fc1.)(\\w*)$\")\n    pattern_se2 = re.compile(r\"^(layer[1-4]\\.\\d\\.)(?:se_module.fc2.)(\\w*)$\")\n    pattern_down_conv = re.compile(r\"^(layer[1-4]\\.\\d\\.)(?:downsample.0.)(\\w*)$\")\n    pattern_down_bn = re.compile(r\"^(layer[1-4]\\.\\d\\.)(?:downsample.1.)(\\w*)$\")\n\n    if isinstance(model_url, dict):\n        download_url(model_url[\"url\"], filepath=model_url[\"filename\"])\n        state_dict = torch.load(model_url[\"filename\"], map_location=None, weights_only=True)\n    else:\n        state_dict = load_state_dict_from_url(model_url, progress=progress)\n    for key in list(state_dict.keys()):\n        new_key = None\n        if pattern_conv.match(key):\n            new_key = re.sub(pattern_conv, r\"\\1conv.\\2\", key)\n        elif pattern_bn.match(key):\n            new_key = re.sub(pattern_bn, r\"\\1conv\\2adn.N.\\3\", key)\n        elif pattern_se.match(key):\n            state_dict[key] = state_dict[key].squeeze()\n            new_key = re.sub(pattern_se, r\"\\1se_layer.fc.0.\\2\", key)\n        elif pattern_se2.match(key):\n            state_dict[key] = state_dict[key].squeeze()\n            new_key = re.sub(pattern_se2, r\"\\1se_layer.fc.2.\\2\", key)\n        elif 
pattern_down_conv.match(key):\n            new_key = re.sub(pattern_down_conv, r\"\\1project.conv.\\2\", key)\n        elif pattern_down_bn.match(key):\n            new_key = re.sub(pattern_down_bn, r\"\\1project.adn.N.\\2\", key)\n        if new_key:\n            state_dict[new_key] = state_dict[key]\n            del state_dict[key]\n\n    model_dict = model.state_dict()\n    state_dict = {\n        k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)\n    }\n    model_dict.update(state_dict)\n    model.load_state_dict(model_dict)\n\n\nclass SENet154(SENet):\n    \"\"\"SENet154 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.\"\"\"\n\n    def __init__(\n        self,\n        layers: Sequence[int] = (3, 8, 36, 3),\n        groups: int = 64,\n        reduction: int = 16,\n        pretrained: bool = False,\n        progress: bool = True,\n        **kwargs,\n    ) -> None:\n        super().__init__(block=SEBottleneck, layers=layers, groups=groups, reduction=reduction, **kwargs)\n        if pretrained:\n            # it only worked when `spatial_dims` is 2\n            _load_state_dict(self, \"senet154\", progress)\n\n\nclass SEResNet50(SENet):\n    \"\"\"SEResNet50 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.\"\"\"\n\n    def __init__(\n        self,\n        layers: Sequence[int] = (3, 4, 6, 3),\n        groups: int = 1,\n        reduction: int = 16,\n        dropout_prob: float | None = None,\n        inplanes: int = 64,\n        downsample_kernel_size: int = 1,\n        input_3x3: bool = False,\n        pretrained: bool = False,\n        progress: bool = True,\n        **kwargs,\n    ) -> None:\n        super().__init__(\n            block=SEResNetBottleneck,\n            layers=layers,\n            groups=groups,\n            reduction=reduction,\n            dropout_prob=dropout_prob,\n            
inplanes=inplanes,\n            downsample_kernel_size=downsample_kernel_size,\n            input_3x3=input_3x3,\n            **kwargs,\n        )\n        if pretrained:\n            # it only worked when `spatial_dims` is 2\n            _load_state_dict(self, \"se_resnet50\", progress)\n\n\nclass SEResNet101(SENet):\n    \"\"\"\n    SEResNet101 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.\n    \"\"\"\n\n    def __init__(\n        self,\n        layers: Sequence[int] = (3, 4, 23, 3),\n        groups: int = 1,\n        reduction: int = 16,\n        inplanes: int = 64,\n        downsample_kernel_size: int = 1,\n        input_3x3: bool = False,\n        pretrained: bool = False,\n        progress: bool = True,\n        **kwargs,\n    ) -> None:\n        super().__init__(\n            block=SEResNetBottleneck,\n            layers=layers,\n            groups=groups,\n            reduction=reduction,\n            inplanes=inplanes,\n            downsample_kernel_size=downsample_kernel_size,\n            input_3x3=input_3x3,\n            **kwargs,\n        )\n        if pretrained:\n            # it only worked when `spatial_dims` is 2\n            _load_state_dict(self, \"se_resnet101\", progress)\n\n\nclass SEResNet152(SENet):\n    \"\"\"\n    SEResNet152 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.\n    \"\"\"\n\n    def __init__(\n        self,\n        layers: Sequence[int] = (3, 8, 36, 3),\n        groups: int = 1,\n        reduction: int = 16,\n        inplanes: int = 64,\n        downsample_kernel_size: int = 1,\n        input_3x3: bool = False,\n        pretrained: bool = False,\n        progress: bool = True,\n        **kwargs,\n    ) -> None:\n        super().__init__(\n            block=SEResNetBottleneck,\n            layers=layers,\n            groups=groups,\n            reduction=reduction,\n            inplanes=inplanes,\n            
downsample_kernel_size=downsample_kernel_size,\n            input_3x3=input_3x3,\n            **kwargs,\n        )\n        if pretrained:\n            # it only worked when `spatial_dims` is 2\n            _load_state_dict(self, \"se_resnet152\", progress)\n\n\nclass SEResNext50(SENet):\n    \"\"\"\n    SEResNext50 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.\n    \"\"\"\n\n    def __init__(\n        self,\n        layers: Sequence[int] = (3, 4, 6, 3),\n        groups: int = 32,\n        reduction: int = 16,\n        dropout_prob: float | None = None,\n        inplanes: int = 64,\n        downsample_kernel_size: int = 1,\n        input_3x3: bool = False,\n        pretrained: bool = False,\n        progress: bool = True,\n        **kwargs,\n    ) -> None:\n        super().__init__(\n            block=SEResNeXtBottleneck,\n            layers=layers,\n            groups=groups,\n            dropout_prob=dropout_prob,\n            reduction=reduction,\n            inplanes=inplanes,\n            downsample_kernel_size=downsample_kernel_size,\n            input_3x3=input_3x3,\n            **kwargs,\n        )\n        if pretrained:\n            # it only worked when `spatial_dims` is 2\n            _load_state_dict(self, \"se_resnext50_32x4d\", progress)\n\n\nclass SEResNext101(SENet):\n    \"\"\"\n    SEResNext101 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.\n    \"\"\"\n\n    def __init__(\n        self,\n        layers: Sequence[int] = (3, 4, 23, 3),\n        groups: int = 32,\n        reduction: int = 16,\n        dropout_prob: float | None = None,\n        inplanes: int = 64,\n        downsample_kernel_size: int = 1,\n        input_3x3: bool = False,\n        pretrained: bool = False,\n        progress: bool = True,\n        **kwargs,\n    ) -> None:\n        super().__init__(\n            block=SEResNeXtBottleneck,\n            layers=layers,\n           
 groups=groups,\n            dropout_prob=dropout_prob,\n            reduction=reduction,\n            inplanes=inplanes,\n            downsample_kernel_size=downsample_kernel_size,\n            input_3x3=input_3x3,\n            **kwargs,\n        )\n        if pretrained:\n            # it only worked when `spatial_dims` is 2\n            _load_state_dict(self, \"se_resnext101_32x4d\", progress)\n\n\nSEnet = Senet = SENet\nSEnet154 = Senet154 = senet154 = SENet154\nSEresnet50 = Seresnet50 = seresnet50 = SEResNet50\nSEresnet101 = Seresnet101 = seresnet101 = SEResNet101\nSEresnet152 = Seresnet152 = seresnet152 = SEResNet152\nSEResNeXt50 = SEresnext50 = Seresnext50 = seresnext50 = SEResNext50\nSEResNeXt101 = SEresnext101 = Seresnext101 = seresnext101 = SEResNext101\n"
  },
  {
    "path": "monai/networks/nets/spade_autoencoderkl.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample\nfrom monai.networks.blocks.spade_norm import SPADE\nfrom monai.networks.nets.autoencoderkl import Encoder\nfrom monai.utils import ensure_tuple_rep\n\n__all__ = [\"SPADEAutoencoderKL\"]\n\n\nclass SPADEResBlock(nn.Module):\n    \"\"\"\n    Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a\n    residual connection between input and output.\n    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)\n\n    Args:\n        spatial_dims: number of spatial dimensions (1D, 2D, 3D).\n        in_channels: input channels to the layer.\n        norm_num_groups: number of groups involved for the group normalisation layer. 
Ensure that your number of\n            channels is divisible by this number.\n        norm_eps: epsilon for the normalisation.\n        out_channels: number of output channels.\n        label_nc: number of semantic channels for SPADE normalisation\n        spade_intermediate_channels: number of intermediate channels for SPADE block layer\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        norm_num_groups: int,\n        norm_eps: float,\n        out_channels: int,\n        label_nc: int,\n        spade_intermediate_channels: int,\n    ) -> None:\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = in_channels if out_channels is None else out_channels\n        self.norm1 = SPADE(\n            label_nc=label_nc,\n            norm_nc=in_channels,\n            norm=\"GROUP\",\n            norm_params={\"num_groups\": norm_num_groups, \"affine\": False, \"eps\": norm_eps},\n            hidden_channels=spade_intermediate_channels,\n            kernel_size=3,\n            spatial_dims=spatial_dims,\n        )\n        self.conv1 = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=self.in_channels,\n            out_channels=self.out_channels,\n            strides=1,\n            kernel_size=3,\n            padding=1,\n            conv_only=True,\n        )\n        self.norm2 = SPADE(\n            label_nc=label_nc,\n            norm_nc=out_channels,\n            norm=\"GROUP\",\n            norm_params={\"num_groups\": norm_num_groups, \"affine\": False, \"eps\": norm_eps},\n            hidden_channels=spade_intermediate_channels,\n            kernel_size=3,\n            spatial_dims=spatial_dims,\n        )\n        self.conv2 = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=self.out_channels,\n            out_channels=self.out_channels,\n            strides=1,\n            kernel_size=3,\n            
padding=1,\n            conv_only=True,\n        )\n\n        self.nin_shortcut: nn.Module\n        if self.in_channels != self.out_channels:\n            self.nin_shortcut = Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=self.in_channels,\n                out_channels=self.out_channels,\n                strides=1,\n                kernel_size=1,\n                padding=0,\n                conv_only=True,\n            )\n        else:\n            self.nin_shortcut = nn.Identity()\n\n    def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:\n        h = x\n        h = self.norm1(h, seg)\n        h = F.silu(h)\n        h = self.conv1(h)\n        h = self.norm2(h, seg)\n        h = F.silu(h)\n        h = self.conv2(h)\n\n        x = self.nin_shortcut(x)\n\n        return x + h\n\n\nclass SPADEDecoder(nn.Module):\n    \"\"\"\n    Convolutional cascade upsampling from a spatial latent space into an image space.\n    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)\n\n    Args:\n        spatial_dims: number of spatial dimensions (1D, 2D, 3D).\n        channels: sequence of block output channels.\n        in_channels: number of channels in the bottom layer (latent space) of the autoencoder.\n        out_channels: number of output channels.\n        num_res_blocks: number of residual blocks (see ResBlock) per level.\n        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.\n        norm_eps: epsilon for the normalization.\n        attention_levels: indicate which level from channels contain an attention block.\n        label_nc: number of semantic channels for SPADE normalisation.\n        with_nonlocal_attn: if True use non-local attention block.\n        spade_intermediate_channels: number of intermediate channels for SPADE block layer.\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        channels: Sequence[int],\n        in_channels: int,\n        out_channels: int,\n        num_res_blocks: Sequence[int],\n        norm_num_groups: int,\n        norm_eps: float,\n        attention_levels: Sequence[bool],\n        label_nc: int,\n        with_nonlocal_attn: bool = True,\n        spade_intermediate_channels: int = 128,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.channels = channels\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.norm_num_groups = norm_num_groups\n        self.norm_eps = norm_eps\n        self.attention_levels = attention_levels\n        self.label_nc = label_nc\n\n        reversed_block_out_channels = list(reversed(channels))\n\n        blocks: list[nn.Module] = []\n\n        # Initial convolution\n        blocks.append(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=in_channels,\n                out_channels=reversed_block_out_channels[0],\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n            )\n        )\n\n        # Non-local attention block\n        if with_nonlocal_attn is True:\n            blocks.append(\n                SPADEResBlock(\n                    spatial_dims=spatial_dims,\n                 
   in_channels=reversed_block_out_channels[0],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    out_channels=reversed_block_out_channels[0],\n                    label_nc=label_nc,\n                    spade_intermediate_channels=spade_intermediate_channels,\n                )\n            )\n            blocks.append(\n                SpatialAttentionBlock(\n                    spatial_dims=spatial_dims,\n                    num_channels=reversed_block_out_channels[0],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    include_fc=include_fc,\n                    use_combined_linear=use_combined_linear,\n                    use_flash_attention=use_flash_attention,\n                )\n            )\n            blocks.append(\n                SPADEResBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=reversed_block_out_channels[0],\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    out_channels=reversed_block_out_channels[0],\n                    label_nc=label_nc,\n                    spade_intermediate_channels=spade_intermediate_channels,\n                )\n            )\n\n        reversed_attention_levels = list(reversed(attention_levels))\n        reversed_num_res_blocks = list(reversed(num_res_blocks))\n        block_out_ch = reversed_block_out_channels[0]\n        for i in range(len(reversed_block_out_channels)):\n            block_in_ch = block_out_ch\n            block_out_ch = reversed_block_out_channels[i]\n            is_final_block = i == len(channels) - 1\n\n            for _ in range(reversed_num_res_blocks[i]):\n                blocks.append(\n                    SPADEResBlock(\n                        spatial_dims=spatial_dims,\n                        in_channels=block_in_ch,\n                        
norm_num_groups=norm_num_groups,\n                        norm_eps=norm_eps,\n                        out_channels=block_out_ch,\n                        label_nc=label_nc,\n                        spade_intermediate_channels=spade_intermediate_channels,\n                    )\n                )\n                block_in_ch = block_out_ch\n\n                if reversed_attention_levels[i]:\n                    blocks.append(\n                        SpatialAttentionBlock(\n                            spatial_dims=spatial_dims,\n                            num_channels=block_in_ch,\n                            norm_num_groups=norm_num_groups,\n                            norm_eps=norm_eps,\n                            include_fc=include_fc,\n                            use_combined_linear=use_combined_linear,\n                            use_flash_attention=use_flash_attention,\n                        )\n                    )\n\n            if not is_final_block:\n                post_conv = Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=block_in_ch,\n                    out_channels=block_in_ch,\n                    strides=1,\n                    kernel_size=3,\n                    padding=1,\n                    conv_only=True,\n                )\n                blocks.append(\n                    Upsample(\n                        spatial_dims=spatial_dims,\n                        mode=\"nontrainable\",\n                        in_channels=block_in_ch,\n                        out_channels=block_in_ch,\n                        interp_mode=\"nearest\",\n                        scale_factor=2.0,\n                        post_conv=post_conv,\n                        align_corners=None,\n                    )\n                )\n\n        blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))\n        blocks.append(\n            Convolution(\n                
spatial_dims=spatial_dims,\n                in_channels=block_in_ch,\n                out_channels=out_channels,\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n            )\n        )\n\n        self.blocks = nn.ModuleList(blocks)\n\n    def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:\n        for block in self.blocks:\n            if isinstance(block, SPADEResBlock):\n                x = block(x, seg)\n            else:\n                x = block(x)\n        return x\n\n\nclass SPADEAutoencoderKL(nn.Module):\n    \"\"\"\n    Autoencoder model with KL-regularized latent space based on\n    Rombach et al. \"High-Resolution Image Synthesis with Latent Diffusion Models\" https://arxiv.org/abs/2112.10752\n    and Pinaya et al. \"Brain Imaging Generation with Latent Diffusion Models\" https://arxiv.org/abs/2209.07162\n    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)\n\n    Args:\n        spatial_dims: number of spatial dimensions (1D, 2D, 3D).\n        label_nc: number of semantic channels for SPADE normalisation.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        num_res_blocks: number of residual blocks (see ResBlock) per level.\n        channels: sequence of block output channels.\n        attention_levels: sequence of levels to add attention.\n        latent_channels: latent embedding dimension.\n        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.\n        norm_eps: epsilon for the normalization.\n        with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.\n        with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.\n        spade_intermediate_channels: number of intermediate channels for SPADE block layer.\n    \"\"\"\n\n    def 
__init__(\n        self,\n        spatial_dims: int,\n        label_nc: int,\n        in_channels: int = 1,\n        out_channels: int = 1,\n        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),\n        channels: Sequence[int] = (32, 64, 64, 64),\n        attention_levels: Sequence[bool] = (False, False, True, True),\n        latent_channels: int = 3,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        with_encoder_nonlocal_attn: bool = True,\n        with_decoder_nonlocal_attn: bool = True,\n        spade_intermediate_channels: int = 128,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n\n        # All number of channels should be multiple of num_groups\n        if any((out_channel % norm_num_groups) != 0 for out_channel in channels):\n            raise ValueError(\"SPADEAutoencoderKL expects all channels being multiple of norm_num_groups\")\n\n        if len(channels) != len(attention_levels):\n            raise ValueError(\"SPADEAutoencoderKL expects channels being same size of attention_levels\")\n\n        if isinstance(num_res_blocks, int):\n            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))\n\n        if len(num_res_blocks) != len(channels):\n            raise ValueError(\n                \"`num_res_blocks` should be a single integer or a tuple of integers with the same length as \"\n                \"`channels`.\"\n            )\n\n        self.encoder = Encoder(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            channels=channels,\n            out_channels=latent_channels,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            attention_levels=attention_levels,\n            with_nonlocal_attn=with_encoder_nonlocal_attn,\n            include_fc=include_fc,\n         
   use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n        self.decoder = SPADEDecoder(\n            spatial_dims=spatial_dims,\n            channels=channels,\n            in_channels=latent_channels,\n            out_channels=out_channels,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            attention_levels=attention_levels,\n            label_nc=label_nc,\n            with_nonlocal_attn=with_decoder_nonlocal_attn,\n            spade_intermediate_channels=spade_intermediate_channels,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n        self.quant_conv_mu = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=latent_channels,\n            out_channels=latent_channels,\n            strides=1,\n            kernel_size=1,\n            padding=0,\n            conv_only=True,\n        )\n        self.quant_conv_log_sigma = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=latent_channels,\n            out_channels=latent_channels,\n            strides=1,\n            kernel_size=1,\n            padding=0,\n            conv_only=True,\n        )\n        self.post_quant_conv = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=latent_channels,\n            out_channels=latent_channels,\n            strides=1,\n            kernel_size=1,\n            padding=0,\n            conv_only=True,\n        )\n        self.latent_channels = latent_channels\n\n    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations.\n\n        Args:\n            x: BxCx[SPATIAL DIMS] tensor\n\n        \"\"\"\n        h = 
self.encoder(x)\n        z_mu = self.quant_conv_mu(h)\n        z_log_var = self.quant_conv_log_sigma(h)\n        z_log_var = torch.clamp(z_log_var, -30.0, 20.0)\n        z_sigma = torch.exp(z_log_var / 2)\n\n        return z_mu, z_sigma\n\n    def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        From the mean and sigma representations resulting of encoding an image through the latent space,\n        obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and\n        adding the mean.\n\n        Args:\n            z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image\n            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image\n\n        Returns:\n            sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]\n        \"\"\"\n        eps = torch.randn_like(z_sigma)\n        z_vae = z_mu + eps * z_sigma\n        return z_vae\n\n    def reconstruct(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Encodes and decodes an input image.\n\n        Args:\n            x: BxCx[SPATIAL DIMENSIONS] tensor.\n            seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.\n        Returns:\n            reconstructed image, of the same shape as input\n        \"\"\"\n        z_mu, _ = self.encode(x)\n        reconstruction = self.decode(z_mu, seg)\n        return reconstruction\n\n    def decode(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Based on a latent space sample, forwards it through the Decoder.\n\n        Args:\n            z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE]\n            seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.\n        Returns:\n            decoded image tensor\n        \"\"\"\n        z = self.post_quant_conv(z)\n        dec: 
torch.Tensor = self.decoder(z, seg)\n        return dec\n\n    def forward(self, x: torch.Tensor, seg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        z_mu, z_sigma = self.encode(x)\n        z = self.sampling(z_mu, z_sigma)\n        reconstruction = self.decode(z, seg)\n        return reconstruction, z_mu, z_sigma\n\n    def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:\n        z_mu, z_sigma = self.encode(x)\n        z = self.sampling(z_mu, z_sigma)\n        return z\n\n    def decode_stage_2_outputs(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:\n        image = self.decode(z, seg)\n        return image\n"
  },
  {
    "path": "monai/networks/nets/spade_diffusion_model_unet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# =========================================================================\n# Adapted from https://github.com/huggingface/diffusers\n# which has the following license:\n# https://github.com/huggingface/diffusers/blob/main/LICENSE\n#\n# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =========================================================================\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nfrom torch import nn\n\nfrom monai.networks.blocks import Convolution, SpatialAttentionBlock\nfrom monai.networks.blocks.spade_norm import SPADE\nfrom monai.networks.nets.diffusion_model_unet import (\n    DiffusionUnetDownsample,\n    DiffusionUNetResnetBlock,\n    SpatialTransformer,\n    WrappedUpsample,\n    get_down_block,\n    
get_mid_block,\n    get_timestep_embedding,\n    zero_module,\n)\nfrom monai.utils import ensure_tuple_rep\n\n__all__ = [\"SPADEDiffusionModelUNet\"]\n\n\nclass SPADEDiffResBlock(nn.Module):\n    \"\"\"\n    Residual block with timestep conditioning and SPADE norm.\n    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)\n\n    Args:\n        spatial_dims: The number of spatial dimensions.\n        in_channels: number of input channels.\n        temb_channels: number of timestep embedding  channels.\n        label_nc: number of semantic channels for SPADE normalisation.\n        out_channels: number of output channels.\n        up: if True, performs upsampling.\n        down: if True, performs downsampling.\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n        spade_intermediate_channels: number of intermediate channels for SPADE block layer\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        temb_channels: int,\n        label_nc: int,\n        out_channels: int | None = None,\n        up: bool = False,\n        down: bool = False,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        spade_intermediate_channels: int = 128,\n    ) -> None:\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.channels = in_channels\n        self.emb_channels = temb_channels\n        self.out_channels = out_channels or in_channels\n        self.up = up\n        self.down = down\n\n        self.norm1 = SPADE(\n            label_nc=label_nc,\n            norm_nc=in_channels,\n            norm=\"GROUP\",\n            norm_params={\"num_groups\": norm_num_groups, \"eps\": norm_eps, \"affine\": True},\n            hidden_channels=spade_intermediate_channels,\n            kernel_size=3,\n            spatial_dims=spatial_dims,\n        )\n\n 
       self.nonlinearity = nn.SiLU()\n        self.conv1 = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=self.out_channels,\n            strides=1,\n            kernel_size=3,\n            padding=1,\n            conv_only=True,\n        )\n\n        self.upsample = self.downsample = None\n        if self.up:\n            self.upsample = WrappedUpsample(\n                spatial_dims=spatial_dims,\n                mode=\"nontrainable\",\n                in_channels=in_channels,\n                out_channels=in_channels,\n                interp_mode=\"nearest\",\n                scale_factor=2.0,\n                align_corners=None,\n            )\n        elif down:\n            self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False)\n\n        self.time_emb_proj = nn.Linear(temb_channels, self.out_channels)\n\n        self.norm2 = SPADE(\n            label_nc=label_nc,\n            norm_nc=self.out_channels,\n            norm=\"GROUP\",\n            norm_params={\"num_groups\": norm_num_groups, \"eps\": norm_eps, \"affine\": True},\n            hidden_channels=spade_intermediate_channels,\n            kernel_size=3,\n            spatial_dims=spatial_dims,\n        )\n        self.conv2 = zero_module(\n            Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=self.out_channels,\n                out_channels=self.out_channels,\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n            )\n        )\n        self.skip_connection: nn.Module\n\n        if self.out_channels == in_channels:\n            self.skip_connection = nn.Identity()\n        else:\n            self.skip_connection = Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=in_channels,\n                out_channels=self.out_channels,\n                
strides=1,\n                kernel_size=1,\n                padding=0,\n                conv_only=True,\n            )\n\n    def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:\n        h = x\n        h = self.norm1(h, seg)\n        h = self.nonlinearity(h)\n\n        if self.upsample is not None:\n            x = self.upsample(x)\n            h = self.upsample(h)\n        elif self.downsample is not None:\n            x = self.downsample(x)\n            h = self.downsample(h)\n\n        h = self.conv1(h)\n\n        if self.spatial_dims == 2:\n            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None]\n        else:\n            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None]\n        h = h + temb\n\n        h = self.norm2(h, seg)\n        h = self.nonlinearity(h)\n        h = self.conv2(h)\n        output: torch.Tensor = self.skip_connection(x) + h\n        return output\n\n\nclass SPADEUpBlock(nn.Module):\n    \"\"\"\n    Unet's up block containing resnet and upsamplers blocks.\n    Enables SPADE normalisation for semantic conditioning (Park et. 
al (2019): https://github.com/NVlabs/SPADE)\n\n    Args:\n        spatial_dims: The number of spatial dimensions.\n        in_channels: number of input channels.\n        prev_output_channel: number of channels from residual connection.\n        out_channels: number of output channels.\n        temb_channels: number of timestep embedding channels.\n        label_nc: number of semantic channels for SPADE normalisation.\n        num_res_blocks: number of residual blocks.\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n        add_upsample: if True add downsample block.\n        resblock_updown: if True use residual blocks for upsampling.\n        spade_intermediate_channels: number of intermediate channels for SPADE block layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        label_nc: int,\n        num_res_blocks: int = 1,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        add_upsample: bool = True,\n        resblock_updown: bool = False,\n        spade_intermediate_channels: int = 128,\n    ) -> None:\n        super().__init__()\n        self.resblock_updown = resblock_updown\n        resnets = []\n\n        for i in range(num_res_blocks):\n            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                SPADEDiffResBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    label_nc=label_nc,\n                    norm_num_groups=norm_num_groups,\n                    
norm_eps=norm_eps,\n                    spade_intermediate_channels=spade_intermediate_channels,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        self.upsampler: nn.Module | None\n        if add_upsample:\n            if resblock_updown:\n                self.upsampler = DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    up=True,\n                )\n            else:\n                post_conv = Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    strides=1,\n                    kernel_size=3,\n                    padding=1,\n                    conv_only=True,\n                )\n                self.upsampler = WrappedUpsample(\n                    spatial_dims=spatial_dims,\n                    mode=\"nontrainable\",\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    interp_mode=\"nearest\",\n                    scale_factor=2.0,\n                    post_conv=post_conv,\n                    align_corners=None,\n                )\n        else:\n            self.upsampler = None\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        res_hidden_states_list: list[torch.Tensor],\n        temb: torch.Tensor,\n        seg: torch.Tensor,\n        context: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        del context\n        for resnet in self.resnets:\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_list[-1]\n            res_hidden_states_list = res_hidden_states_list[:-1]\n           
 hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n            hidden_states = resnet(hidden_states, temb, seg)\n\n        if self.upsampler is not None:\n            hidden_states = self.upsampler(hidden_states, temb)\n\n        return hidden_states\n\n\nclass SPADEAttnUpBlock(nn.Module):\n    \"\"\"\n    Unet's up block containing resnet, upsamplers, and self-attention blocks.\n    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)\n\n    Args:\n        spatial_dims: The number of spatial dimensions.\n        in_channels: number of input channels.\n        prev_output_channel: number of channels from residual connection.\n        out_channels: number of output channels.\n        temb_channels: number of timestep embedding channels.\n        label_nc: number of semantic channels for SPADE normalisation\n        num_res_blocks: number of residual blocks.\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n        add_upsample: if True add downsample block.\n        resblock_updown: if True use residual blocks for upsampling.\n        num_head_channels: number of channels in each attention head.\n        spade_intermediate_channels: number of intermediate channels for SPADE block layer\n        include_fc: whether to include the final linear layer. 
Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        label_nc: int,\n        num_res_blocks: int = 1,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        add_upsample: bool = True,\n        resblock_updown: bool = False,\n        num_head_channels: int = 1,\n        spade_intermediate_channels: int = 128,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        self.resblock_updown = resblock_updown\n        resnets = []\n        attentions = []\n\n        for i in range(num_res_blocks):\n            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                SPADEDiffResBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    label_nc=label_nc,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    spade_intermediate_channels=spade_intermediate_channels,\n                )\n            )\n            attentions.append(\n                SpatialAttentionBlock(\n                    spatial_dims=spatial_dims,\n                    
num_channels=out_channels,\n                    num_head_channels=num_head_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    include_fc=include_fc,\n                    use_combined_linear=use_combined_linear,\n                    use_flash_attention=use_flash_attention,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n        self.attentions = nn.ModuleList(attentions)\n\n        self.upsampler: nn.Module | None\n        if add_upsample:\n            if resblock_updown:\n                self.upsampler = DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    up=True,\n                )\n            else:\n                post_conv = Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    strides=1,\n                    kernel_size=3,\n                    padding=1,\n                    conv_only=True,\n                )\n                self.upsampler = WrappedUpsample(\n                    spatial_dims=spatial_dims,\n                    mode=\"nontrainable\",\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    interp_mode=\"nearest\",\n                    scale_factor=2.0,\n                    post_conv=post_conv,\n                    align_corners=None,\n                )\n        else:\n            self.upsampler = None\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        res_hidden_states_list: list[torch.Tensor],\n        temb: torch.Tensor,\n        seg: 
torch.Tensor,\n        context: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        del context\n        for resnet, attn in zip(self.resnets, self.attentions):\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_list[-1]\n            res_hidden_states_list = res_hidden_states_list[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n            hidden_states = resnet(hidden_states, temb, seg)\n            hidden_states = attn(hidden_states).contiguous()\n\n        if self.upsampler is not None:\n            hidden_states = self.upsampler(hidden_states, temb)\n\n        return hidden_states\n\n\nclass SPADECrossAttnUpBlock(nn.Module):\n    \"\"\"\n    Unet's up block containing resnet, upsamplers, and self-attention blocks.\n    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)\n\n    Args:\n        spatial_dims: The number of spatial dimensions.\n        in_channels: number of input channels.\n        prev_output_channel: number of channels from residual connection.\n        out_channels: number of output channels.\n        temb_channels: number of timestep embedding channels.\n        label_nc: number of semantic channels for SPADE normalisation.\n        num_res_blocks: number of residual blocks.\n        norm_num_groups: number of groups for the group normalization.\n        norm_eps: epsilon for the group normalization.\n        add_upsample: if True add downsample block.\n        resblock_updown: if True use residual blocks for upsampling.\n        num_head_channels: number of channels in each attention head.\n        transformer_num_layers: number of layers of Transformer blocks to use.\n        cross_attention_dim: number of context dimensions to use.\n        upcast_attention: if True, upcast attention operations to full precision.\n        spade_intermediate_channels: number of intermediate channels for SPADE block 
layer.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism.\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        label_nc: int,\n        num_res_blocks: int = 1,\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        add_upsample: bool = True,\n        resblock_updown: bool = False,\n        num_head_channels: int = 1,\n        transformer_num_layers: int = 1,\n        cross_attention_dim: int | None = None,\n        upcast_attention: bool = False,\n        spade_intermediate_channels: int = 128,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        self.resblock_updown = resblock_updown\n        resnets = []\n        attentions = []\n\n        for i in range(num_res_blocks):\n            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                SPADEDiffResBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    label_nc=label_nc,\n                    spade_intermediate_channels=spade_intermediate_channels,\n                )\n            )\n            attentions.append(\n                SpatialTransformer(\n                    spatial_dims=spatial_dims,\n               
     in_channels=out_channels,\n                    num_attention_heads=out_channels // num_head_channels,\n                    num_head_channels=num_head_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    num_layers=transformer_num_layers,\n                    cross_attention_dim=cross_attention_dim,\n                    upcast_attention=upcast_attention,\n                    include_fc=include_fc,\n                    use_combined_linear=use_combined_linear,\n                    use_flash_attention=use_flash_attention,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.upsampler: nn.Module | None\n        if add_upsample:\n            if resblock_updown:\n                self.upsampler = DiffusionUNetResnetBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    norm_num_groups=norm_num_groups,\n                    norm_eps=norm_eps,\n                    up=True,\n                )\n            else:\n                post_conv = Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    strides=1,\n                    kernel_size=3,\n                    padding=1,\n                    conv_only=True,\n                )\n                self.upsampler = WrappedUpsample(\n                    spatial_dims=spatial_dims,\n                    mode=\"nontrainable\",\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    interp_mode=\"nearest\",\n                    scale_factor=2.0,\n                    post_conv=post_conv,\n                    
align_corners=None,\n                )\n        else:\n            self.upsampler = None\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        res_hidden_states_list: list[torch.Tensor],\n        temb: torch.Tensor,\n        seg: torch.Tensor | None = None,\n        context: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        for resnet, attn in zip(self.resnets, self.attentions):\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_list[-1]\n            res_hidden_states_list = res_hidden_states_list[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n            hidden_states = resnet(hidden_states, temb, seg)\n            hidden_states = attn(hidden_states, context=context).contiguous()\n\n        if self.upsampler is not None:\n            hidden_states = self.upsampler(hidden_states, temb)\n\n        return hidden_states\n\n\ndef get_spade_up_block(\n    spatial_dims: int,\n    in_channels: int,\n    prev_output_channel: int,\n    out_channels: int,\n    temb_channels: int,\n    num_res_blocks: int,\n    norm_num_groups: int,\n    norm_eps: float,\n    add_upsample: bool,\n    resblock_updown: bool,\n    with_attn: bool,\n    with_cross_attn: bool,\n    num_head_channels: int,\n    transformer_num_layers: int,\n    label_nc: int,\n    cross_attention_dim: int | None,\n    upcast_attention: bool = False,\n    spade_intermediate_channels: int = 128,\n    include_fc: bool = True,\n    use_combined_linear: bool = False,\n    use_flash_attention: bool = False,\n) -> nn.Module:\n    if with_attn:\n        return SPADEAttnUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            prev_output_channel=prev_output_channel,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            label_nc=label_nc,\n            num_res_blocks=num_res_blocks,\n            
norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            add_upsample=add_upsample,\n            resblock_updown=resblock_updown,\n            num_head_channels=num_head_channels,\n            spade_intermediate_channels=spade_intermediate_channels,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n    elif with_cross_attn:\n        return SPADECrossAttnUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            prev_output_channel=prev_output_channel,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            label_nc=label_nc,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            add_upsample=add_upsample,\n            resblock_updown=resblock_updown,\n            num_head_channels=num_head_channels,\n            transformer_num_layers=transformer_num_layers,\n            cross_attention_dim=cross_attention_dim,\n            upcast_attention=upcast_attention,\n            spade_intermediate_channels=spade_intermediate_channels,\n            use_flash_attention=use_flash_attention,\n        )\n    else:\n        return SPADEUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            prev_output_channel=prev_output_channel,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            label_nc=label_nc,\n            num_res_blocks=num_res_blocks,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            add_upsample=add_upsample,\n            resblock_updown=resblock_updown,\n            spade_intermediate_channels=spade_intermediate_channels,\n        )\n\n\nclass SPADEDiffusionModelUNet(nn.Module):\n    \"\"\"\n    UNet network with timestep embedding and attention mechanisms for 
conditioning, with added SPADE normalization for\n    semantic conditioning (Park et.al (2019): https://github.com/NVlabs/SPADE). An example tutorial can be found at\n    https://github.com/Project-MONAI/GenerativeModels/tree/main/tutorials/generative/2d_spade_ldm\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        label_nc: number of semantic channels for SPADE normalisation.\n        num_res_blocks: number of residual blocks (see ResnetBlock) per level.\n        channels: tuple of block output channels.\n        attention_levels: list of levels to add attention.\n        norm_num_groups: number of groups for the normalization.\n        norm_eps: epsilon for the normalization.\n        resblock_updown: if True use residual blocks for up/downsampling.\n        num_head_channels: number of channels in each attention head.\n        with_conditioning: if True add spatial transformers to perform conditioning.\n        transformer_num_layers: number of layers of Transformer blocks to use.\n        cross_attention_dim: number of context dimensions to use.\n        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`\n            classes.\n        upcast_attention: if True, upcast attention operations to full precision.\n        spade_intermediate_channels: number of intermediate channels for SPADE block layer.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        label_nc: int,\n        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),\n        channels: Sequence[int] = (32, 64, 
64, 64),\n        attention_levels: Sequence[bool] = (False, False, True, True),\n        norm_num_groups: int = 32,\n        norm_eps: float = 1e-6,\n        resblock_updown: bool = False,\n        num_head_channels: int | Sequence[int] = 8,\n        with_conditioning: bool = False,\n        transformer_num_layers: int = 1,\n        cross_attention_dim: int | None = None,\n        num_class_embeds: int | None = None,\n        upcast_attention: bool = False,\n        spade_intermediate_channels: int = 128,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        if with_conditioning is True and cross_attention_dim is None:\n            raise ValueError(\n                \"SPADEDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) \"\n                \"when using with_conditioning.\"\n            )\n        if cross_attention_dim is not None and with_conditioning is False:\n            raise ValueError(\n                \"SPADEDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim.\"\n            )\n\n        # All number of channels should be multiple of num_groups\n        if any((out_channel % norm_num_groups) != 0 for out_channel in channels):\n            raise ValueError(\"SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups\")\n\n        if len(channels) != len(attention_levels):\n            raise ValueError(\"SPADEDiffusionModelUNet expects num_channels being same size of attention_levels\")\n\n        if isinstance(num_head_channels, int):\n            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))\n\n        if len(num_head_channels) != len(attention_levels):\n            raise ValueError(\n                \"num_head_channels should have the same length as attention_levels. 
For the i levels without attention,\"\n                \" i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored.\"\n            )\n\n        if isinstance(num_res_blocks, int):\n            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))\n\n        if len(num_res_blocks) != len(channels):\n            raise ValueError(\n                \"`num_res_blocks` should be a single integer or a tuple of integers with the same length as \"\n                \"`num_channels`.\"\n            )\n\n        self.in_channels = in_channels\n        self.block_out_channels = channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_levels = attention_levels\n        self.num_head_channels = num_head_channels\n        self.with_conditioning = with_conditioning\n        self.label_nc = label_nc\n\n        # input\n        self.conv_in = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=channels[0],\n            strides=1,\n            kernel_size=3,\n            padding=1,\n            conv_only=True,\n        )\n\n        # time\n        time_embed_dim = channels[0] * 4\n        self.time_embed = nn.Sequential(\n            nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)\n        )\n\n        # class embedding\n        self.num_class_embeds = num_class_embeds\n        if num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n\n        # down\n        self.down_blocks = nn.ModuleList([])\n        output_channel = channels[0]\n        for i in range(len(channels)):\n            input_channel = output_channel\n            output_channel = channels[i]\n            is_final_block = i == len(channels) - 1\n\n            down_block = get_down_block(\n                spatial_dims=spatial_dims,\n                
in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=time_embed_dim,\n                num_res_blocks=num_res_blocks[i],\n                norm_num_groups=norm_num_groups,\n                norm_eps=norm_eps,\n                add_downsample=not is_final_block,\n                resblock_updown=resblock_updown,\n                with_attn=(attention_levels[i] and not with_conditioning),\n                with_cross_attn=(attention_levels[i] and with_conditioning),\n                num_head_channels=num_head_channels[i],\n                transformer_num_layers=transformer_num_layers,\n                cross_attention_dim=cross_attention_dim,\n                upcast_attention=upcast_attention,\n                include_fc=include_fc,\n                use_combined_linear=use_combined_linear,\n                use_flash_attention=use_flash_attention,\n            )\n\n            self.down_blocks.append(down_block)\n\n        # mid\n        self.middle_block = get_mid_block(\n            spatial_dims=spatial_dims,\n            in_channels=channels[-1],\n            temb_channels=time_embed_dim,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            with_conditioning=with_conditioning,\n            num_head_channels=num_head_channels[-1],\n            transformer_num_layers=transformer_num_layers,\n            cross_attention_dim=cross_attention_dim,\n            upcast_attention=upcast_attention,\n            include_fc=include_fc,\n            use_combined_linear=use_combined_linear,\n            use_flash_attention=use_flash_attention,\n        )\n\n        # up\n        self.up_blocks = nn.ModuleList([])\n        reversed_block_out_channels = list(reversed(channels))\n        reversed_num_res_blocks = list(reversed(num_res_blocks))\n        reversed_attention_levels = list(reversed(attention_levels))\n        reversed_num_head_channels = list(reversed(num_head_channels))\n        
output_channel = reversed_block_out_channels[0]\n        for i in range(len(reversed_block_out_channels)):\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)]\n\n            is_final_block = i == len(channels) - 1\n\n            up_block = get_spade_up_block(\n                spatial_dims=spatial_dims,\n                in_channels=input_channel,\n                prev_output_channel=prev_output_channel,\n                out_channels=output_channel,\n                temb_channels=time_embed_dim,\n                num_res_blocks=reversed_num_res_blocks[i] + 1,\n                norm_num_groups=norm_num_groups,\n                norm_eps=norm_eps,\n                add_upsample=not is_final_block,\n                resblock_updown=resblock_updown,\n                with_attn=(reversed_attention_levels[i] and not with_conditioning),\n                with_cross_attn=(reversed_attention_levels[i] and with_conditioning),\n                num_head_channels=reversed_num_head_channels[i],\n                transformer_num_layers=transformer_num_layers,\n                cross_attention_dim=cross_attention_dim,\n                upcast_attention=upcast_attention,\n                label_nc=label_nc,\n                spade_intermediate_channels=spade_intermediate_channels,\n                use_flash_attention=use_flash_attention,\n            )\n\n            self.up_blocks.append(up_block)\n\n        # out\n        self.out = nn.Sequential(\n            nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True),\n            nn.SiLU(),\n            zero_module(\n                Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=channels[0],\n                    out_channels=out_channels,\n                    strides=1,\n                    kernel_size=3,\n           
         padding=1,\n                    conv_only=True,\n                )\n            ),\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        timesteps: torch.Tensor,\n        seg: torch.Tensor,\n        context: torch.Tensor | None = None,\n        class_labels: torch.Tensor | None = None,\n        down_block_additional_residuals: tuple[torch.Tensor] | None = None,\n        mid_block_additional_residual: torch.Tensor | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x: input tensor (N, C, SpatialDims).\n            timesteps: timestep tensor (N,).\n            seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.\n            context: context tensor (N, 1, ContextDim).\n            class_labels: context tensor (N, ).\n            down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).\n            mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).\n        \"\"\"\n        # 1. time\n        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=x.dtype)\n        emb = self.time_embed(t_emb)\n\n        # 2. class\n        if self.num_class_embeds is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n            class_emb = self.class_embedding(class_labels)\n            class_emb = class_emb.to(dtype=x.dtype)\n            emb = emb + class_emb\n\n        # 3. initial convolution\n        h = self.conv_in(x)\n\n        # 4. 
down\n        if context is not None and self.with_conditioning is False:\n            raise ValueError(\"model should have with_conditioning = True if context is provided\")\n        down_block_res_samples: list[torch.Tensor] = [h]\n        for downsample_block in self.down_blocks:\n            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)\n            for residual in res_samples:\n                down_block_res_samples.append(residual)\n\n        # Additional residual connections for Controlnets\n        if down_block_additional_residuals is not None:\n            new_down_block_res_samples: list[torch.Tensor] = [h]\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples.append(down_block_res_sample)\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # 5. mid\n        h = self.middle_block(hidden_states=h, temb=emb, context=context)\n\n        # Additional residual connections for Controlnets\n        if mid_block_additional_residual is not None:\n            h = h + mid_block_additional_residual\n\n        # 6. up\n        for upsample_block in self.up_blocks:\n            idx: int = -len(upsample_block.resnets)  # type: ignore\n            res_samples = down_block_res_samples[idx:]\n            down_block_res_samples = down_block_res_samples[:idx]\n            h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, seg=seg, temb=emb, context=context)\n\n        # 7. output block\n        output: torch.Tensor = self.out(h)\n\n        return output\n"
  },
  {
    "path": "monai/networks/nets/spade_network.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.blocks import Convolution\nfrom monai.networks.blocks.spade_norm import SPADE\nfrom monai.networks.layers import Act\nfrom monai.networks.layers.utils import get_act_layer\nfrom monai.utils.enums import StrEnum\n\n__all__ = [\"SPADENet\"]\n\n\nclass UpsamplingModes(StrEnum):\n    bicubic = \"bicubic\"\n    nearest = \"nearest\"\n    bilinear = \"bilinear\"\n\n\nclass SPADENetResBlock(nn.Module):\n    \"\"\"\n    Creates a Residual Block with SPADE normalisation.\n\n    Args:\n        spatial_dims: number of spatial dimensions\n        in_channels: number of input channels\n        out_channels: number of output channels\n        label_nc: number of semantic channels that will be taken into account in SPADE normalisation blocks\n        spade_intermediate_channels: number of intermediate channels in the middle conv. 
layers in SPADE normalisation blocks\n        norm: base normalisation type used on top of SPADE\n        kernel_size: convolutional kernel size\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        label_nc: int,\n        spade_intermediate_channels: int = 128,\n        norm: str | tuple = \"INSTANCE\",\n        act: str | tuple = (Act.LEAKYRELU, {\"negative_slope\": 0.2}),\n        kernel_size: int = 3,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.int_channels = min(self.in_channels, self.out_channels)\n        self.learned_shortcut = self.in_channels != self.out_channels\n        self.conv_0 = Convolution(\n            spatial_dims=spatial_dims, in_channels=self.in_channels, out_channels=self.int_channels, act=None, norm=None\n        )\n        self.conv_1 = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=self.int_channels,\n            out_channels=self.out_channels,\n            act=None,\n            norm=None,\n        )\n        self.activation = get_act_layer(act)\n        self.norm_0 = SPADE(\n            label_nc=label_nc,\n            norm_nc=self.in_channels,\n            kernel_size=kernel_size,\n            spatial_dims=spatial_dims,\n            hidden_channels=spade_intermediate_channels,\n            norm=norm,\n        )\n        self.norm_1 = SPADE(\n            label_nc=label_nc,\n            norm_nc=self.int_channels,\n            kernel_size=kernel_size,\n            spatial_dims=spatial_dims,\n            hidden_channels=spade_intermediate_channels,\n            norm=norm,\n        )\n\n        if self.learned_shortcut:\n            self.conv_s = Convolution(\n                spatial_dims=spatial_dims,\n                in_channels=self.in_channels,\n                out_channels=self.out_channels,\n                act=None,\n                
norm=None,\n                kernel_size=1,\n            )\n            self.norm_s = SPADE(\n                label_nc=label_nc,\n                norm_nc=self.in_channels,\n                kernel_size=kernel_size,\n                spatial_dims=spatial_dims,\n                hidden_channels=spade_intermediate_channels,\n                norm=norm,\n            )\n\n    def forward(self, x, seg):\n        x_s = self.shortcut(x, seg)\n        dx = self.conv_0(self.activation(self.norm_0(x, seg)))\n        dx = self.conv_1(self.activation(self.norm_1(dx, seg)))\n        out = x_s + dx\n        return out\n\n    def shortcut(self, x, seg):\n        if self.learned_shortcut:\n            x_s = self.conv_s(self.norm_s(x, seg))\n        else:\n            x_s = x\n        return x_s\n\n\nclass SPADEEncoder(nn.Module):\n    \"\"\"\n    Encoding branch of a VAE compatible with a SPADE-like generator\n\n    Args:\n        spatial_dims: number of spatial dimensions\n        in_channels: number of input channels\n        z_dim: latent space dimension of the VAE containing the image sytle information\n        channels: number of output after each downsampling block\n        input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers\n        of the autoencoder (HxWx[D])\n        kernel_size: convolutional kernel size\n        norm: normalisation layer type\n        act: activation type\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        z_dim: int,\n        channels: Sequence[int],\n        input_shape: Sequence[int],\n        kernel_size: int = 3,\n        norm: str | tuple = \"INSTANCE\",\n        act: str | tuple = (Act.LEAKYRELU, {\"negative_slope\": 0.2}),\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.z_dim = z_dim\n        self.channels = channels\n        if len(input_shape) != spatial_dims:\n            raise ValueError(f\"Length 
of parameter input shape must match spatial_dims; got {input_shape}\")\n        for s_ind, s_ in enumerate(input_shape):\n            if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)):\n                raise ValueError(\n                    \"Each dimension of your input must be divisible by 2 ** (autoencoder depth).\"\n                    f\"The shape in position {s_ind}, {s_} is not divisible by {len(channels)}. \"\n                )\n        self.input_shape = input_shape\n        self.latent_spatial_shape = [s_ // (2 ** len(self.channels)) for s_ in self.input_shape]\n        blocks = []\n        ch_init = self.in_channels\n        for _, ch_value in enumerate(channels):\n            blocks.append(\n                Convolution(\n                    spatial_dims=spatial_dims,\n                    in_channels=ch_init,\n                    out_channels=ch_value,\n                    strides=2,\n                    kernel_size=kernel_size,\n                    norm=norm,\n                    act=act,\n                )\n            )\n            ch_init = ch_value\n\n        self.blocks = nn.ModuleList(blocks)\n        self.fc_mu = nn.Linear(\n            in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim\n        )\n        self.fc_var = nn.Linear(\n            in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim\n        )\n\n    def forward(self, x):\n        for block in self.blocks:\n            x = block(x)\n        x = x.view(x.size(0), -1)\n        mu = self.fc_mu(x)\n        logvar = self.fc_var(x)\n        return mu, logvar\n\n    def encode(self, x):\n        for block in self.blocks:\n            x = block(x)\n        x = x.view(x.size(0), -1)\n        mu = self.fc_mu(x)\n        logvar = self.fc_var(x)\n        return self.reparameterize(mu, logvar)\n\n    def reparameterize(self, mu, logvar):\n        std = torch.exp(0.5 * logvar)\n        eps = 
torch.randn_like(std)\n        return eps.mul(std) + mu\n\n\nclass SPADEDecoder(nn.Module):\n    \"\"\"\n    Decoder branch of a SPADE-like generator. It can be used independently, without an encoding branch,\n    behaving like a GAN, or coupled to a SPADE encoder.\n\n    Args:\n        label_nc: number of semantic labels\n        spatial_dims: number of spatial dimensions\n        out_channels: number of output channels\n        label_nc: number of semantic channels used for the SPADE normalisation blocks\n        input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers\n        channels: number of output after each downsampling block\n        z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used)\n        is_vae: whether the decoder is going to be coupled to an autoencoder or not (true: yes, false: no)\n        spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks\n        norm: base normalisation type\n        act:  activation layer type\n        last_act: activation layer type for the last layer of the network (can differ from previous)\n        kernel_size: convolutional kernel size\n        upsampling_mode: upsampling mode (nearest, bilinear etc.)\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        out_channels: int,\n        label_nc: int,\n        input_shape: Sequence[int],\n        channels: list[int],\n        z_dim: int | None = None,\n        is_vae: bool = True,\n        spade_intermediate_channels: int = 128,\n        norm: str | tuple = \"INSTANCE\",\n        act: str | tuple = (Act.LEAKYRELU, {\"negative_slope\": 0.2}),\n        last_act: str | tuple | None = (Act.LEAKYRELU, {\"negative_slope\": 0.2}),\n        kernel_size: int = 3,\n        upsampling_mode: str = UpsamplingModes.nearest.value,\n    ):\n        super().__init__()\n        self.is_vae = is_vae\n    
    self.out_channels = out_channels\n        self.label_nc = label_nc\n        self.num_channels = channels\n        if len(input_shape) != spatial_dims:\n            raise ValueError(f\"Length of parameter input shape must match spatial_dims; got {input_shape}\")\n        for s_ind, s_ in enumerate(input_shape):\n            if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)):\n                raise ValueError(\n                    \"Each dimension of your input must be divisible by 2 ** (autoencoder depth).\"\n                    f\"The shape in position {s_ind}, {s_} is not divisible by {len(channels)}. \"\n                )\n        self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in input_shape]\n\n        if not self.is_vae:\n            self.conv_init = Convolution(\n                spatial_dims=spatial_dims, in_channels=label_nc, out_channels=channels[0], kernel_size=kernel_size\n            )\n        elif self.is_vae and z_dim is None:\n            raise ValueError(\n                \"If the network is used in VAE-GAN mode, parameter z_dim \"\n                \"(number of latent channels in the VAE) must be populated.\"\n            )\n        else:\n            self.fc = nn.Linear(z_dim, np.prod(self.latent_spatial_shape) * channels[0])\n\n        self.z_dim = z_dim\n        blocks = []\n        channels.append(self.out_channels)\n        self.upsampling = torch.nn.Upsample(scale_factor=2, mode=upsampling_mode)\n        for ch_ind, ch_value in enumerate(channels[:-1]):\n            blocks.append(\n                SPADENetResBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=ch_value,\n                    out_channels=channels[ch_ind + 1],\n                    label_nc=label_nc,\n                    spade_intermediate_channels=spade_intermediate_channels,\n                    norm=norm,\n                    kernel_size=kernel_size,\n                    act=act,\n                )\n        
    )\n\n        self.blocks = torch.nn.ModuleList(blocks)\n        self.last_conv = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=channels[-1],\n            out_channels=out_channels,\n            padding=(kernel_size - 1) // 2,\n            kernel_size=kernel_size,\n            norm=None,\n            act=last_act,\n        )\n\n    def forward(self, seg, z: torch.Tensor | None = None):\n        \"\"\"\n        Args:\n            seg: input BxCxHxW[xD] semantic map on which the output is conditioned on\n            z: latent vector output by the encoder if self.is_vae is True. When is_vae is\n            False, z is a random noise vector.\n\n        Returns:\n\n        \"\"\"\n        if not self.is_vae:\n            x = F.interpolate(seg, size=tuple(self.latent_spatial_shape))\n            x = self.conv_init(x)\n        else:\n            if (\n                z is None and self.z_dim is not None\n            ):  # Even though this network is a VAE (self.is_vae), you should be able to sample from noise as well.\n                z = torch.randn(seg.size(0), self.z_dim, dtype=torch.float32, device=seg.get_device())\n            x = self.fc(z)\n            x = x.view(*[-1, self.num_channels[0]] + self.latent_spatial_shape)\n\n        for res_block in self.blocks:\n            x = res_block(x, seg)\n            x = self.upsampling(x)\n\n        x = self.last_conv(x)\n        return x\n\n\nclass SPADENet(nn.Module):\n    \"\"\"\n    SPADE Network, implemented based on the code by Park, T et al. 
in\n    \"Semantic Image Synthesis with Spatially-Adaptive Normalization\"\n    (https://github.com/NVlabs/SPADE)\n\n    Args:\n        spatial_dims: number of spatial dimensions\n        in_channels: number of input channels\n        out_channels: number of output channels\n        label_nc: number of semantic channels used for the SPADE normalisation blocks\n        input_shape:  spatial input shape of the tensor, necessary to do the reshaping after the linear layers\n        channels: number of output after each downsampling block\n        z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used)\n        is_vae: whether the decoder is going to be coupled to an autoencoder (true) or not (false)\n        spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks\n        norm: base normalisation type\n        act: activation layer type\n        last_act: activation layer type for the last layer of the network (can differ from previous)\n        kernel_size: convolutional kernel size\n        upsampling_mode: upsampling mode (nearest, bilinear etc.)\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        label_nc: int,\n        input_shape: Sequence[int],\n        channels: list[int],\n        z_dim: int | None = None,\n        is_vae: bool = True,\n        spade_intermediate_channels: int = 128,\n        norm: str | tuple = \"INSTANCE\",\n        act: str | tuple = (Act.LEAKYRELU, {\"negative_slope\": 0.2}),\n        last_act: str | tuple | None = (Act.LEAKYRELU, {\"negative_slope\": 0.2}),\n        kernel_size: int = 3,\n        upsampling_mode: str = UpsamplingModes.nearest.value,\n    ):\n        super().__init__()\n        self.is_vae = is_vae\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.channels = channels\n        
self.label_nc = label_nc\n        self.input_shape = input_shape\n\n        if self.is_vae:\n            if z_dim is None:\n                ValueError(\"The latent space dimension mapped by parameter z_dim cannot be None is is_vae is True.\")\n            else:\n                self.encoder = SPADEEncoder(\n                    spatial_dims=spatial_dims,\n                    in_channels=in_channels,\n                    z_dim=z_dim,\n                    channels=channels,\n                    input_shape=input_shape,\n                    kernel_size=kernel_size,\n                    norm=norm,\n                    act=act,\n                )\n\n        decoder_channels = channels\n        decoder_channels.reverse()\n\n        self.decoder = SPADEDecoder(\n            spatial_dims=spatial_dims,\n            out_channels=out_channels,\n            label_nc=label_nc,\n            input_shape=input_shape,\n            channels=decoder_channels,\n            z_dim=z_dim,\n            is_vae=is_vae,\n            spade_intermediate_channels=spade_intermediate_channels,\n            norm=norm,\n            act=act,\n            last_act=last_act,\n            kernel_size=kernel_size,\n            upsampling_mode=upsampling_mode,\n        )\n\n    def forward(self, seg: torch.Tensor, x: torch.Tensor | None = None):\n        z = None\n        if self.is_vae:\n            z_mu, z_logvar = self.encoder(x)\n            z = self.encoder.reparameterize(z_mu, z_logvar)\n            return self.decoder(seg, z), z_mu, z_logvar\n        else:\n            return (self.decoder(seg, z),)\n\n    def encode(self, x: torch.Tensor):\n        if self.is_vae:\n            return self.encoder.encode(x)\n        else:\n            return None\n\n    def decode(self, seg: torch.Tensor, z: torch.Tensor | None = None):\n        return self.decoder(seg, z)\n"
  },
  {
    "path": "monai/networks/nets/swin_unetr.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport itertools\nfrom collections.abc import Sequence\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as checkpoint\nfrom torch.nn import LayerNorm\n\nfrom monai.networks.blocks import MLPBlock as Mlp\nfrom monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock\nfrom monai.networks.layers import DropPath, trunc_normal_\nfrom monai.utils import ensure_tuple_rep, look_up_option, optional_import\n\nrearrange, _ = optional_import(\"einops\", name=\"rearrange\")\n\n__all__ = [\n    \"SwinUNETR\",\n    \"window_partition\",\n    \"window_reverse\",\n    \"WindowAttention\",\n    \"SwinTransformerBlock\",\n    \"PatchMerging\",\n    \"PatchMergingV2\",\n    \"MERGING_MODE\",\n    \"BasicLayer\",\n    \"SwinTransformer\",\n]\n\n\nclass SwinUNETR(nn.Module):\n    \"\"\"\n    Swin UNETR based on: \"Hatamizadeh et al.,\n    Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images\n    <https://arxiv.org/abs/2201.01266>\"\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        patch_size: int = 2,\n        depths: Sequence[int] = (2, 2, 2, 2),\n        num_heads: Sequence[int] = (3, 6, 12, 24),\n        window_size: Sequence[int] | int = 7,\n        qkv_bias: 
bool = True,\n        mlp_ratio: float = 4.0,\n        feature_size: int = 24,\n        norm_name: tuple | str = \"instance\",\n        drop_rate: float = 0.0,\n        attn_drop_rate: float = 0.0,\n        dropout_path_rate: float = 0.0,\n        normalize: bool = True,\n        norm_layer: type[LayerNorm] = nn.LayerNorm,\n        patch_norm: bool = False,\n        use_checkpoint: bool = False,\n        spatial_dims: int = 3,\n        downsample: str | nn.Module = \"merging\",\n        use_v2: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            in_channels: dimension of input channels.\n            out_channels: dimension of output channels.\n            patch_size: size of the patch token.\n            feature_size: dimension of network feature size.\n            depths: number of layers in each stage.\n            num_heads: number of attention heads.\n            window_size: local window size.\n            qkv_bias: add a learnable bias to query, key, value.\n            mlp_ratio: ratio of mlp hidden dim to embedding dim.\n            norm_name: feature normalization type and arguments.\n            drop_rate: dropout rate.\n            attn_drop_rate: attention dropout rate.\n            dropout_path_rate: drop path rate.\n            normalize: normalize output intermediate features in each stage.\n            norm_layer: normalization layer.\n            patch_norm: whether to apply normalization to the patch embedding. 
Default is False.\n            use_checkpoint: use gradient checkpointing for reduced memory usage.\n            spatial_dims: number of spatial dims.\n            downsample: module used for downsampling, available options are `\"mergingv2\"`, `\"merging\"` and a\n                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.\n                The default is currently `\"merging\"` (the original version defined in v0.9.0).\n            use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage.\n\n        Examples::\n\n            # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.\n            >>> net = SwinUNETR(in_channels=1, out_channels=4, feature_size=48)\n\n            # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.\n            >>> net = SwinUNETR(in_channels=4, out_channels=3, depths=(2,4,2,2))\n\n            # for 2D 3-channel input with size (96,96), 2-channel output and gradient checkpointing.\n            >>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)\n\n        \"\"\"\n\n        super().__init__()\n\n        if spatial_dims not in (2, 3):\n            raise ValueError(\"spatial dimension should be 2 or 3.\")\n\n        self.patch_size = patch_size\n\n        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)\n        window_size = ensure_tuple_rep(window_size, spatial_dims)\n\n        if not (0 <= drop_rate <= 1):\n            raise ValueError(\"dropout rate should be between 0 and 1.\")\n\n        if not (0 <= attn_drop_rate <= 1):\n            raise ValueError(\"attention dropout rate should be between 0 and 1.\")\n\n        if not (0 <= dropout_path_rate <= 1):\n            raise ValueError(\"drop path rate should be between 0 and 1.\")\n\n        if feature_size % 12 != 0:\n            raise 
ValueError(\"feature_size should be divisible by 12.\")\n\n        self.normalize = normalize\n\n        self.swinViT = SwinTransformer(\n            in_chans=in_channels,\n            embed_dim=feature_size,\n            window_size=window_size,\n            patch_size=patch_sizes,\n            depths=depths,\n            num_heads=num_heads,\n            mlp_ratio=mlp_ratio,\n            qkv_bias=qkv_bias,\n            drop_rate=drop_rate,\n            attn_drop_rate=attn_drop_rate,\n            drop_path_rate=dropout_path_rate,\n            norm_layer=norm_layer,\n            patch_norm=patch_norm,\n            use_checkpoint=use_checkpoint,\n            spatial_dims=spatial_dims,\n            downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,\n            use_v2=use_v2,\n        )\n\n        self.encoder1 = UnetrBasicBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=feature_size,\n            kernel_size=3,\n            stride=1,\n            norm_name=norm_name,\n            res_block=True,\n        )\n\n        self.encoder2 = UnetrBasicBlock(\n            spatial_dims=spatial_dims,\n            in_channels=feature_size,\n            out_channels=feature_size,\n            kernel_size=3,\n            stride=1,\n            norm_name=norm_name,\n            res_block=True,\n        )\n\n        self.encoder3 = UnetrBasicBlock(\n            spatial_dims=spatial_dims,\n            in_channels=2 * feature_size,\n            out_channels=2 * feature_size,\n            kernel_size=3,\n            stride=1,\n            norm_name=norm_name,\n            res_block=True,\n        )\n\n        self.encoder4 = UnetrBasicBlock(\n            spatial_dims=spatial_dims,\n            in_channels=4 * feature_size,\n            out_channels=4 * feature_size,\n            kernel_size=3,\n            stride=1,\n            norm_name=norm_name,\n            
res_block=True,\n        )\n\n        self.encoder10 = UnetrBasicBlock(\n            spatial_dims=spatial_dims,\n            in_channels=16 * feature_size,\n            out_channels=16 * feature_size,\n            kernel_size=3,\n            stride=1,\n            norm_name=norm_name,\n            res_block=True,\n        )\n\n        self.decoder5 = UnetrUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=16 * feature_size,\n            out_channels=8 * feature_size,\n            kernel_size=3,\n            upsample_kernel_size=2,\n            norm_name=norm_name,\n            res_block=True,\n        )\n\n        self.decoder4 = UnetrUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=feature_size * 8,\n            out_channels=feature_size * 4,\n            kernel_size=3,\n            upsample_kernel_size=2,\n            norm_name=norm_name,\n            res_block=True,\n        )\n\n        self.decoder3 = UnetrUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=feature_size * 4,\n            out_channels=feature_size * 2,\n            kernel_size=3,\n            upsample_kernel_size=2,\n            norm_name=norm_name,\n            res_block=True,\n        )\n        self.decoder2 = UnetrUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=feature_size * 2,\n            out_channels=feature_size,\n            kernel_size=3,\n            upsample_kernel_size=2,\n            norm_name=norm_name,\n            res_block=True,\n        )\n\n        self.decoder1 = UnetrUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=feature_size,\n            out_channels=feature_size,\n            kernel_size=3,\n            upsample_kernel_size=2,\n            norm_name=norm_name,\n            res_block=True,\n        )\n\n        self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)\n\n    def load_from(self, 
weights):\n        layers1_0: BasicLayer = self.swinViT.layers1[0]  # type: ignore[assignment]\n        layers2_0: BasicLayer = self.swinViT.layers2[0]  # type: ignore[assignment]\n        layers3_0: BasicLayer = self.swinViT.layers3[0]  # type: ignore[assignment]\n        layers4_0: BasicLayer = self.swinViT.layers4[0]  # type: ignore[assignment]\n        wstate = weights[\"state_dict\"]\n\n        with torch.no_grad():\n            self.swinViT.patch_embed.proj.weight.copy_(wstate[\"module.patch_embed.proj.weight\"])\n            self.swinViT.patch_embed.proj.bias.copy_(wstate[\"module.patch_embed.proj.bias\"])\n            for bname, block in layers1_0.blocks.named_children():\n                block.load_from(weights, n_block=bname, layer=\"layers1\")  # type: ignore[operator]\n\n            if layers1_0.downsample is not None:\n                d = layers1_0.downsample\n                d.reduction.weight.copy_(wstate[\"module.layers1.0.downsample.reduction.weight\"])  # type: ignore\n                d.norm.weight.copy_(wstate[\"module.layers1.0.downsample.norm.weight\"])  # type: ignore\n                d.norm.bias.copy_(wstate[\"module.layers1.0.downsample.norm.bias\"])  # type: ignore\n\n            for bname, block in layers2_0.blocks.named_children():\n                block.load_from(weights, n_block=bname, layer=\"layers2\")  # type: ignore[operator]\n\n            if layers2_0.downsample is not None:\n                d = layers2_0.downsample\n                d.reduction.weight.copy_(wstate[\"module.layers2.0.downsample.reduction.weight\"])  # type: ignore\n                d.norm.weight.copy_(wstate[\"module.layers2.0.downsample.norm.weight\"])  # type: ignore\n                d.norm.bias.copy_(wstate[\"module.layers2.0.downsample.norm.bias\"])  # type: ignore\n\n            for bname, block in layers3_0.blocks.named_children():\n                block.load_from(weights, n_block=bname, layer=\"layers3\")  # type: ignore[operator]\n\n            if 
layers3_0.downsample is not None:\n                d = layers3_0.downsample\n                d.reduction.weight.copy_(wstate[\"module.layers3.0.downsample.reduction.weight\"])  # type: ignore\n                d.norm.weight.copy_(wstate[\"module.layers3.0.downsample.norm.weight\"])  # type: ignore\n                d.norm.bias.copy_(wstate[\"module.layers3.0.downsample.norm.bias\"])  # type: ignore\n\n            for bname, block in layers4_0.blocks.named_children():\n                block.load_from(weights, n_block=bname, layer=\"layers4\")  # type: ignore[operator]\n\n            if layers4_0.downsample is not None:\n                d = layers4_0.downsample\n                d.reduction.weight.copy_(wstate[\"module.layers4.0.downsample.reduction.weight\"])  # type: ignore\n                d.norm.weight.copy_(wstate[\"module.layers4.0.downsample.norm.weight\"])  # type: ignore\n                d.norm.bias.copy_(wstate[\"module.layers4.0.downsample.norm.bias\"])  # type: ignore\n\n    @torch.jit.unused\n    def _check_input_size(self, spatial_shape):\n        img_size = np.array(spatial_shape)\n        remainder = (img_size % np.power(self.patch_size, 5)) > 0\n        if remainder.any():\n            wrong_dims = (np.where(remainder)[0] + 2).tolist()\n            raise ValueError(\n                f\"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})\"\n                f\" must be divisible by {self.patch_size}**5.\"\n            )\n\n    def forward(self, x_in):\n        if not torch.jit.is_scripting() and not torch.jit.is_tracing():\n            self._check_input_size(x_in.shape[2:])\n        hidden_states_out = self.swinViT(x_in, self.normalize)\n        enc0 = self.encoder1(x_in)\n        enc1 = self.encoder2(hidden_states_out[0])\n        enc2 = self.encoder3(hidden_states_out[1])\n        enc3 = self.encoder4(hidden_states_out[2])\n        dec4 = self.encoder10(hidden_states_out[4])\n        dec3 = self.decoder5(dec4, 
hidden_states_out[3])\n        dec2 = self.decoder4(dec3, enc3)\n        dec1 = self.decoder3(dec2, enc2)\n        dec0 = self.decoder2(dec1, enc1)\n        out = self.decoder1(dec0, enc0)\n        logits = self.out(out)\n        return logits\n\n\ndef window_partition(x, window_size):\n    \"\"\"window partition operation based on: \"Liu et al.,\n    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows\n    <https://arxiv.org/abs/2103.14030>\"\n    https://github.com/microsoft/Swin-Transformer\n\n     Args:\n        x: input tensor.\n        window_size: local window size.\n    \"\"\"\n    x_shape = x.size()  # length 4 or 5 only\n    if len(x_shape) == 5:\n        b, d, h, w, c = x_shape\n        x = x.view(\n            b,\n            d // window_size[0],\n            window_size[0],\n            h // window_size[1],\n            window_size[1],\n            w // window_size[2],\n            window_size[2],\n            c,\n        )\n        windows = (\n            x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)\n        )\n    else:  # if len(x_shape) == 4:\n        b, h, w, c = x.shape\n        x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)\n        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)\n\n    return windows\n\n\ndef window_reverse(windows, window_size, dims):\n    \"\"\"window reverse operation based on: \"Liu et al.,\n    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows\n    <https://arxiv.org/abs/2103.14030>\"\n    https://github.com/microsoft/Swin-Transformer\n\n     Args:\n        windows: windows tensor.\n        window_size: local window size.\n        dims: dimension values.\n    \"\"\"\n    if len(dims) == 4:\n        b, d, h, w = dims\n        x = windows.view(\n            b,\n            d // window_size[0],\n            h // window_size[1],\n  
          w // window_size[2],\n            window_size[0],\n            window_size[1],\n            window_size[2],\n            -1,\n        )\n        x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)\n\n    elif len(dims) == 3:\n        b, h, w = dims\n        x = windows.view(b, h // window_size[0], w // window_size[1], window_size[0], window_size[1], -1)\n        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)\n    return x\n\n\ndef get_window_size(x_size, window_size, shift_size=None):\n    \"\"\"Computing window size based on: \"Liu et al.,\n    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows\n    <https://arxiv.org/abs/2103.14030>\"\n    https://github.com/microsoft/Swin-Transformer\n\n     Args:\n        x_size: input size.\n        window_size: local window size.\n        shift_size: window shifting size.\n    \"\"\"\n\n    use_window_size = list(window_size)\n    if shift_size is not None:\n        use_shift_size = list(shift_size)\n    for i in range(len(x_size)):\n        if x_size[i] <= window_size[i]:\n            use_window_size[i] = x_size[i]\n            if shift_size is not None:\n                use_shift_size[i] = 0\n\n    if shift_size is None:\n        return tuple(use_window_size)\n    else:\n        return tuple(use_window_size), tuple(use_shift_size)\n\n\nclass WindowAttention(nn.Module):\n    \"\"\"\n    Window based multi-head self attention module with relative position bias based on: \"Liu et al.,\n    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows\n    <https://arxiv.org/abs/2103.14030>\"\n    https://github.com/microsoft/Swin-Transformer\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        window_size: Sequence[int],\n        qkv_bias: bool = False,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n    ) -> None:\n        \"\"\"\n        Args:\n            dim: number of feature 
channels.\n            num_heads: number of attention heads.\n            window_size: local window size.\n            qkv_bias: add a learnable bias to query, key, value.\n            attn_drop: attention dropout rate.\n            proj_drop: dropout rate of output.\n        \"\"\"\n\n        super().__init__()\n        self.dim = dim\n        self.window_size = window_size\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim**-0.5\n        mesh_args = torch.meshgrid.__kwdefaults__\n\n        if len(self.window_size) == 3:\n            self.relative_position_bias_table = nn.Parameter(\n                torch.zeros(\n                    (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),\n                    num_heads,\n                )\n            )\n            coords_d = torch.arange(self.window_size[0])\n            coords_h = torch.arange(self.window_size[1])\n            coords_w = torch.arange(self.window_size[2])\n            if mesh_args is not None:\n                coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing=\"ij\"))\n            else:\n                coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))\n            coords_flatten = torch.flatten(coords, 1)\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()\n            relative_coords[:, :, 0] += self.window_size[0] - 1\n            relative_coords[:, :, 1] += self.window_size[1] - 1\n            relative_coords[:, :, 2] += self.window_size[2] - 1\n            relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)\n            relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1\n        elif len(self.window_size) == 2:\n            self.relative_position_bias_table = nn.Parameter(\n                torch.zeros((2 * 
window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)\n            )\n            coords_h = torch.arange(self.window_size[0])\n            coords_w = torch.arange(self.window_size[1])\n            if mesh_args is not None:\n                coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing=\"ij\"))\n            else:\n                coords = torch.stack(torch.meshgrid(coords_h, coords_w))\n            coords_flatten = torch.flatten(coords, 1)\n            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]\n            relative_coords = relative_coords.permute(1, 2, 0).contiguous()\n            relative_coords[:, :, 0] += self.window_size[0] - 1\n            relative_coords[:, :, 1] += self.window_size[1] - 1\n            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n\n        relative_position_index = relative_coords.sum(-1)\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        trunc_normal_(self.relative_position_bias_table, std=0.02)\n        self.softmax = nn.Softmax(dim=-1)\n\n    def forward(self, x, mask):\n        b, n, c = x.shape\n        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n        q = q * self.scale\n        attn = q @ k.transpose(-2, -1)\n        relative_position_bias = self.relative_position_bias_table[\n            self.relative_position_index.clone()[:n, :n].reshape(-1)  # type: ignore[operator]\n        ].reshape(n, n, -1)\n        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()\n        attn = attn + relative_position_bias.unsqueeze(0)\n        if mask is not None:\n            nw = mask.shape[0]\n            attn = attn.view(b // nw, 
nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, n, n)\n            attn = self.softmax(attn)\n        else:\n            attn = self.softmax(attn)\n\n        attn = self.attn_drop(attn).to(v.dtype)\n        x = (attn @ v).transpose(1, 2).reshape(b, n, c)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass SwinTransformerBlock(nn.Module):\n    \"\"\"\n    Swin Transformer block based on: \"Liu et al.,\n    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows\n    <https://arxiv.org/abs/2103.14030>\"\n    https://github.com/microsoft/Swin-Transformer\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        window_size: Sequence[int],\n        shift_size: Sequence[int],\n        mlp_ratio: float = 4.0,\n        qkv_bias: bool = True,\n        drop: float = 0.0,\n        attn_drop: float = 0.0,\n        drop_path: float = 0.0,\n        act_layer: str = \"GELU\",\n        norm_layer: type[LayerNorm] = nn.LayerNorm,\n        use_checkpoint: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            dim: number of feature channels.\n            num_heads: number of attention heads.\n            window_size: local window size.\n            shift_size: window shift size.\n            mlp_ratio: ratio of mlp hidden dim to embedding dim.\n            qkv_bias: add a learnable bias to query, key, value.\n            drop: dropout rate.\n            attn_drop: attention dropout rate.\n            drop_path: stochastic depth rate.\n            act_layer: activation layer.\n            norm_layer: normalization layer.\n            use_checkpoint: use gradient checkpointing for reduced memory usage.\n        \"\"\"\n\n        super().__init__()\n        self.dim = dim\n        self.num_heads = num_heads\n        self.window_size = window_size\n        self.shift_size = shift_size\n        self.mlp_ratio = 
mlp_ratio\n        self.use_checkpoint = use_checkpoint\n        self.norm1 = norm_layer(dim)\n        self.attn = WindowAttention(\n            dim,\n            window_size=self.window_size,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode=\"swin\")\n\n    def forward_part1(self, x, mask_matrix):\n        x_shape = x.size()\n        x = self.norm1(x)\n        if len(x_shape) == 5:\n            b, d, h, w, c = x.shape\n            window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)\n            pad_l = pad_t = pad_d0 = 0\n            pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]\n            pad_b = (window_size[1] - h % window_size[1]) % window_size[1]\n            pad_r = (window_size[2] - w % window_size[2]) % window_size[2]\n            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))\n            _, dp, hp, wp, _ = x.shape\n            dims = [b, dp, hp, wp]\n\n        else:  # elif len(x_shape) == 4\n            b, h, w, c = x.shape\n            window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)\n            pad_l = pad_t = 0\n            pad_b = (window_size[0] - h % window_size[0]) % window_size[0]\n            pad_r = (window_size[1] - w % window_size[1]) % window_size[1]\n            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))\n            _, hp, wp, _ = x.shape\n            dims = [b, hp, wp]\n\n        if any(i > 0 for i in shift_size):\n            if len(x_shape) == 5:\n                shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 
2, 3))\n            elif len(x_shape) == 4:\n                shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))\n            attn_mask = mask_matrix\n        else:\n            shifted_x = x\n            attn_mask = None\n        x_windows = window_partition(shifted_x, window_size)\n        attn_windows = self.attn(x_windows, mask=attn_mask)\n        attn_windows = attn_windows.view(-1, *(window_size + (c,)))\n        shifted_x = window_reverse(attn_windows, window_size, dims)\n        if any(i > 0 for i in shift_size):\n            if len(x_shape) == 5:\n                x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))\n            elif len(x_shape) == 4:\n                x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))\n        else:\n            x = shifted_x\n\n        if len(x_shape) == 5:\n            if pad_d1 > 0 or pad_r > 0 or pad_b > 0:\n                x = x[:, :d, :h, :w, :].contiguous()\n        elif len(x_shape) == 4:\n            if pad_r > 0 or pad_b > 0:\n                x = x[:, :h, :w, :].contiguous()\n\n        return x\n\n    def forward_part2(self, x):\n        return self.drop_path(self.mlp(self.norm2(x)))\n\n    def load_from(self, weights, n_block, layer):\n        root = f\"module.{layer}.0.blocks.{n_block}.\"\n        block_names = [\n            \"norm1.weight\",\n            \"norm1.bias\",\n            \"attn.relative_position_bias_table\",\n            \"attn.relative_position_index\",\n            \"attn.qkv.weight\",\n            \"attn.qkv.bias\",\n            \"attn.proj.weight\",\n            \"attn.proj.bias\",\n            \"norm2.weight\",\n            \"norm2.bias\",\n            \"mlp.fc1.weight\",\n            \"mlp.fc1.bias\",\n            \"mlp.fc2.weight\",\n            \"mlp.fc2.bias\",\n        ]\n        with torch.no_grad():\n            self.norm1.weight.copy_(weights[\"state_dict\"][root + block_names[0]])\n     
       self.norm1.bias.copy_(weights[\"state_dict\"][root + block_names[1]])\n            self.attn.relative_position_bias_table.copy_(weights[\"state_dict\"][root + block_names[2]])\n            self.attn.relative_position_index.copy_(weights[\"state_dict\"][root + block_names[3]])  # type: ignore[operator]\n            self.attn.qkv.weight.copy_(weights[\"state_dict\"][root + block_names[4]])\n            self.attn.qkv.bias.copy_(weights[\"state_dict\"][root + block_names[5]])\n            self.attn.proj.weight.copy_(weights[\"state_dict\"][root + block_names[6]])\n            self.attn.proj.bias.copy_(weights[\"state_dict\"][root + block_names[7]])\n            self.norm2.weight.copy_(weights[\"state_dict\"][root + block_names[8]])\n            self.norm2.bias.copy_(weights[\"state_dict\"][root + block_names[9]])\n            self.mlp.linear1.weight.copy_(weights[\"state_dict\"][root + block_names[10]])\n            self.mlp.linear1.bias.copy_(weights[\"state_dict\"][root + block_names[11]])\n            self.mlp.linear2.weight.copy_(weights[\"state_dict\"][root + block_names[12]])\n            self.mlp.linear2.bias.copy_(weights[\"state_dict\"][root + block_names[13]])\n\n    def forward(self, x, mask_matrix):\n        shortcut = x\n        if self.use_checkpoint:\n            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix, use_reentrant=False)\n        else:\n            x = self.forward_part1(x, mask_matrix)\n        x = shortcut + self.drop_path(x)\n        if self.use_checkpoint:\n            x = x + checkpoint.checkpoint(self.forward_part2, x, use_reentrant=False)\n        else:\n            x = x + self.forward_part2(x)\n        return x\n\n\nclass PatchMergingV2(nn.Module):\n    \"\"\"\n    Patch merging layer based on: \"Liu et al.,\n    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows\n    <https://arxiv.org/abs/2103.14030>\"\n    https://github.com/microsoft/Swin-Transformer\n    \"\"\"\n\n    def __init__(self, 
dim: int, norm_layer: type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3) -> None:\n        \"\"\"\n        Args:\n            dim: number of feature channels.\n            norm_layer: normalization layer.\n            spatial_dims: number of spatial dims.\n        \"\"\"\n\n        super().__init__()\n        self.dim = dim\n        if spatial_dims == 3:\n            self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)\n            self.norm = norm_layer(8 * dim)\n        elif spatial_dims == 2:\n            self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n            self.norm = norm_layer(4 * dim)\n\n    def forward(self, x):\n        x_shape = x.size()\n        if len(x_shape) == 5:\n            b, d, h, w, c = x_shape\n            pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)\n            if pad_input:\n                x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))\n            x = torch.cat(\n                [x[:, i::2, j::2, k::2, :] for i, j, k in itertools.product(range(2), range(2), range(2))], -1\n            )\n\n        elif len(x_shape) == 4:\n            b, h, w, c = x_shape\n            pad_input = (h % 2 == 1) or (w % 2 == 1)\n            if pad_input:\n                x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))\n            x = torch.cat([x[:, j::2, i::2, :] for i, j in itertools.product(range(2), range(2))], -1)\n\n        x = self.norm(x)\n        x = self.reduction(x)\n        return x\n\n\nclass PatchMerging(PatchMergingV2):\n    \"\"\"The `PatchMerging` module previously defined in v0.9.0.\"\"\"\n\n    def forward(self, x):\n        x_shape = x.size()\n        if len(x_shape) == 4:\n            return super().forward(x)\n        if len(x_shape) != 5:\n            raise ValueError(f\"expecting 5D x, got {x.shape}.\")\n        b, d, h, w, c = x_shape\n        pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)\n        if pad_input:\n            x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))\n        x0 = 
x[:, 0::2, 0::2, 0::2, :]\n        x1 = x[:, 1::2, 0::2, 0::2, :]\n        x2 = x[:, 0::2, 1::2, 0::2, :]\n        x3 = x[:, 0::2, 0::2, 1::2, :]\n        x4 = x[:, 1::2, 1::2, 0::2, :]\n        x5 = x[:, 1::2, 0::2, 1::2, :]\n        x6 = x[:, 0::2, 1::2, 1::2, :]\n        x7 = x[:, 1::2, 1::2, 1::2, :]\n        x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)\n        x = self.norm(x)\n        x = self.reduction(x)\n        return x\n\n\nMERGING_MODE = {\"merging\": PatchMerging, \"mergingv2\": PatchMergingV2}\n\n\ndef compute_mask(dims, window_size, shift_size, device):\n    \"\"\"Computing region masks based on: \"Liu et al.,\n    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows\n    <https://arxiv.org/abs/2103.14030>\"\n    https://github.com/microsoft/Swin-Transformer\n\n     Args:\n        dims: dimension values.\n        window_size: local window size.\n        shift_size: shift size.\n        device: device.\n    \"\"\"\n\n    cnt = 0\n\n    if len(dims) == 3:\n        d, h, w = dims\n        img_mask = torch.zeros((1, d, h, w, 1), device=device)\n        for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):\n            for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):\n                for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):\n                    img_mask[:, d, h, w, :] = cnt\n                    cnt += 1\n\n    elif len(dims) == 2:\n        h, w = dims\n        img_mask = torch.zeros((1, h, w, 1), device=device)\n        for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):\n            for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):\n                img_mask[:, h, w, :] = cnt\n                cnt += 1\n\n    mask_windows = window_partition(img_mask, window_size)\n    
mask_windows = mask_windows.squeeze(-1)\n    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n    attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)\n\n    return attn_mask\n\n\nclass BasicLayer(nn.Module):\n    \"\"\"\n    Basic Swin Transformer layer in one stage based on: \"Liu et al.,\n    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows\n    <https://arxiv.org/abs/2103.14030>\"\n    https://github.com/microsoft/Swin-Transformer\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        depth: int,\n        num_heads: int,\n        window_size: Sequence[int],\n        drop_path: list,\n        mlp_ratio: float = 4.0,\n        qkv_bias: bool = False,\n        drop: float = 0.0,\n        attn_drop: float = 0.0,\n        norm_layer: type[LayerNorm] = nn.LayerNorm,\n        downsample: nn.Module | None = None,\n        use_checkpoint: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            dim: number of feature channels.\n            depth: number of layers in each stage.\n            num_heads: number of attention heads.\n            window_size: local window size.\n            drop_path: stochastic depth rate.\n            mlp_ratio: ratio of mlp hidden dim to embedding dim.\n            qkv_bias: add a learnable bias to query, key, value.\n            drop: dropout rate.\n            attn_drop: attention dropout rate.\n            norm_layer: normalization layer.\n            downsample: an optional downsampling layer at the end of the layer.\n            use_checkpoint: use gradient checkpointing for reduced memory usage.\n        \"\"\"\n\n        super().__init__()\n        self.window_size = window_size\n        self.shift_size = tuple(i // 2 for i in window_size)\n        self.no_shift = tuple(0 for i in window_size)\n        self.depth = depth\n        self.use_checkpoint = use_checkpoint\n        self.blocks = nn.ModuleList(\n            
[\n                SwinTransformerBlock(\n                    dim=dim,\n                    num_heads=num_heads,\n                    window_size=self.window_size,\n                    shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    drop=drop,\n                    attn_drop=attn_drop,\n                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                    norm_layer=norm_layer,\n                    use_checkpoint=use_checkpoint,\n                )\n                for i in range(depth)\n            ]\n        )\n        self.downsample = downsample\n        if callable(self.downsample):\n            self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))\n\n    def forward(self, x):\n        x_shape = x.size()\n        if len(x_shape) == 5:\n            b, c, d, h, w = x_shape\n            window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)\n            x = rearrange(x, \"b c d h w -> b d h w c\")\n            dp = int(np.ceil(d / window_size[0])) * window_size[0]\n            hp = int(np.ceil(h / window_size[1])) * window_size[1]\n            wp = int(np.ceil(w / window_size[2])) * window_size[2]\n            attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)\n            for blk in self.blocks:\n                x = blk(x, attn_mask)\n            x = x.view(b, d, h, w, -1)\n            if self.downsample is not None:\n                x = self.downsample(x)\n            x = rearrange(x, \"b d h w c -> b c d h w\")\n\n        elif len(x_shape) == 4:\n            b, c, h, w = x_shape\n            window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)\n            x = rearrange(x, \"b c h w -> b h w c\")\n            hp = int(np.ceil(h / window_size[0])) * window_size[0]\n            
wp = int(np.ceil(w / window_size[1])) * window_size[1]\n            attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)\n            for blk in self.blocks:\n                x = blk(x, attn_mask)\n            x = x.view(b, h, w, -1)\n            if self.downsample is not None:\n                x = self.downsample(x)\n            x = rearrange(x, \"b h w c -> b c h w\")\n        return x\n\n\nclass SwinTransformer(nn.Module):\n    \"\"\"\n    Swin Transformer based on: \"Liu et al.,\n    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows\n    <https://arxiv.org/abs/2103.14030>\"\n    https://github.com/microsoft/Swin-Transformer\n    \"\"\"\n\n    def __init__(\n        self,\n        in_chans: int,\n        embed_dim: int,\n        window_size: Sequence[int],\n        patch_size: Sequence[int],\n        depths: Sequence[int],\n        num_heads: Sequence[int],\n        mlp_ratio: float = 4.0,\n        qkv_bias: bool = True,\n        drop_rate: float = 0.0,\n        attn_drop_rate: float = 0.0,\n        drop_path_rate: float = 0.0,\n        norm_layer: type[LayerNorm] = nn.LayerNorm,\n        patch_norm: bool = False,\n        use_checkpoint: bool = False,\n        spatial_dims: int = 3,\n        downsample=\"merging\",\n        use_v2=False,\n    ) -> None:\n        \"\"\"\n        Args:\n            in_chans: dimension of input channels.\n            embed_dim: number of linear projection output channels.\n            window_size: local window size.\n            patch_size: patch size.\n            depths: number of layers in each stage.\n            num_heads: number of attention heads.\n            mlp_ratio: ratio of mlp hidden dim to embedding dim.\n            qkv_bias: add a learnable bias to query, key, value.\n            drop_rate: dropout rate.\n            attn_drop_rate: attention dropout rate.\n            drop_path_rate: stochastic depth rate.\n            norm_layer: normalization layer.\n            
patch_norm: add normalization after patch embedding.\n            use_checkpoint: use gradient checkpointing for reduced memory usage.\n            spatial_dims: spatial dimension.\n            downsample: module used for downsampling, available options are `\"mergingv2\"`, `\"merging\"` and a\n                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.\n                The default is currently `\"merging\"` (the original version defined in v0.9.0).\n            use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage.\n        \"\"\"\n\n        super().__init__()\n        self.num_layers = len(depths)\n        self.embed_dim = embed_dim\n        self.patch_norm = patch_norm\n        self.window_size = window_size\n        self.patch_size = patch_size\n        self.patch_embed = PatchEmbed(\n            patch_size=self.patch_size,\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n            norm_layer=norm_layer if self.patch_norm else None,  # type: ignore\n            spatial_dims=spatial_dims,\n        )\n        self.pos_drop = nn.Dropout(p=drop_rate)\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]\n        self.use_v2 = use_v2\n        self.layers1 = nn.ModuleList()\n        self.layers2 = nn.ModuleList()\n        self.layers3 = nn.ModuleList()\n        self.layers4 = nn.ModuleList()\n        if self.use_v2:\n            self.layers1c = nn.ModuleList()\n            self.layers2c = nn.ModuleList()\n            self.layers3c = nn.ModuleList()\n            self.layers4c = nn.ModuleList()\n        down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample\n        for i_layer in range(self.num_layers):\n            layer = BasicLayer(\n                dim=int(embed_dim * 2**i_layer),\n                depth=depths[i_layer],\n                
num_heads=num_heads[i_layer],\n                window_size=self.window_size,\n                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                drop=drop_rate,\n                attn_drop=attn_drop_rate,\n                norm_layer=norm_layer,\n                downsample=down_sample_mod,\n                use_checkpoint=use_checkpoint,\n            )\n            if i_layer == 0:\n                self.layers1.append(layer)\n            elif i_layer == 1:\n                self.layers2.append(layer)\n            elif i_layer == 2:\n                self.layers3.append(layer)\n            elif i_layer == 3:\n                self.layers4.append(layer)\n            if self.use_v2:\n                layerc = UnetrBasicBlock(\n                    spatial_dims=spatial_dims,\n                    in_channels=embed_dim * 2**i_layer,\n                    out_channels=embed_dim * 2**i_layer,\n                    kernel_size=3,\n                    stride=1,\n                    norm_name=\"instance\",\n                    res_block=True,\n                )\n                if i_layer == 0:\n                    self.layers1c.append(layerc)\n                elif i_layer == 1:\n                    self.layers2c.append(layerc)\n                elif i_layer == 2:\n                    self.layers3c.append(layerc)\n                elif i_layer == 3:\n                    self.layers4c.append(layerc)\n\n        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n\n    def proj_out(self, x, normalize=False):\n        if normalize:\n            x_shape = x.shape\n            # Force trace() to generate a constant by casting to int\n            ch = int(x_shape[1])\n            if len(x_shape) == 5:\n                x = rearrange(x, \"n c d h w -> n d h w c\")\n                x = F.layer_norm(x, [ch])\n                x = rearrange(x, \"n d h w c -> n c d h w\")\n           
 elif len(x_shape) == 4:\n                x = rearrange(x, \"n c h w -> n h w c\")\n                x = F.layer_norm(x, [ch])\n                x = rearrange(x, \"n h w c -> n c h w\")\n        return x\n\n    def forward(self, x, normalize=True):\n        x0 = self.patch_embed(x)\n        x0 = self.pos_drop(x0)\n        x0_out = self.proj_out(x0, normalize)\n        if self.use_v2:\n            x0 = self.layers1c[0](x0.contiguous())\n        x1 = self.layers1[0](x0.contiguous())\n        x1_out = self.proj_out(x1, normalize)\n        if self.use_v2:\n            x1 = self.layers2c[0](x1.contiguous())\n        x2 = self.layers2[0](x1.contiguous())\n        x2_out = self.proj_out(x2, normalize)\n        if self.use_v2:\n            x2 = self.layers3c[0](x2.contiguous())\n        x3 = self.layers3[0](x2.contiguous())\n        x3_out = self.proj_out(x3, normalize)\n        if self.use_v2:\n            x3 = self.layers4c[0](x3.contiguous())\n        x4 = self.layers4[0](x3.contiguous())\n        x4_out = self.proj_out(x4, normalize)\n        return [x0_out, x1_out, x2_out, x3_out, x4_out]\n\n\ndef filter_swinunetr(key, value):\n    \"\"\"\n    A filter function used to filter the pretrained weights from [1], then the weights can be loaded into MONAI SwinUNETR Model.\n    This function is typically used with `monai.networks.copy_model_state`\n    [1] \"Valanarasu JM et al., Disruptive Autoencoders: Leveraging Low-level features for 3D Medical Image Pre-training\n    <https://arxiv.org/abs/2307.16896>\"\n\n    Args:\n        key: the key in the source state dict used for the update.\n        value: the value in the source state dict used for the update.\n\n    Examples::\n\n        import torch\n        from monai.apps import download_url\n        from monai.networks.utils import copy_model_state\n        from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr\n\n        model = SwinUNETR(in_channels=1, out_channels=3, feature_size=48)\n        resource = 
(\n            \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth\"\n        )\n        ssl_weights_path = \"./ssl_pretrained_weights.pth\"\n        download_url(resource, ssl_weights_path)\n        ssl_weights = torch.load(ssl_weights_path, weights_only=True)[\"model\"]\n\n        dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)\n\n    \"\"\"\n    if key in [\n        \"encoder.mask_token\",\n        \"encoder.norm.weight\",\n        \"encoder.norm.bias\",\n        \"out.conv.conv.weight\",\n        \"out.conv.conv.bias\",\n    ]:\n        return None\n\n    if key[:8] == \"encoder.\":\n        if key[8:19] == \"patch_embed\":\n            new_key = \"swinViT.\" + key[8:]\n        else:\n            new_key = \"swinViT.\" + key[8:18] + key[20:]\n\n        return new_key, value\n    else:\n        return None\n"
  },
  {
    "path": "monai/networks/nets/torchvision_fc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nfrom monai.networks.nets import NetAdapter\nfrom monai.utils import optional_import\n\nmodels, _ = optional_import(\"torchvision.models\")\n\n__all__ = [\"TorchVisionFCModel\"]\n\n\nclass TorchVisionFCModel(NetAdapter):\n    \"\"\"\n    Customize the fully connected layer of (pretrained) TorchVision model or replace it by convolutional layer.\n\n    This class supports two primary use cases:\n\n        - use ``pool=None`` to indicate no modification in the pooling layers. It should be used with ``fc_name``\n          to locate the target FC layer to modify:\n          In this case, the class will load a torchvision classification model,\n          replace the last fully connected (FC) layer with a new FC layer with ``num_classes`` outputs,\n          example input arguments: ``use_conv=False, pool=None, fc_name=\"heads.head\"``.\n          The ``heads.head`` specifies the target FC of the input model, could be found by ``model.named_modules()``,\n          for example::\n\n              from torchvision.models import vit_b_16\n              print([name[0] for name in vit_b_16().named_modules()])\n\n        - use ``pool=\"\"`` or set it to a tuple of pooling parameters to indicate modifications of both\n          the pooling and the FC layer. 
It should be used with ``node_name`` to locate the model feature outputs:\n          In this case, the class will load a torchvision model, remove the existing pooling and FC layers, and\n\n          - append an additional convolution layer:\n            ``use_conv=True, pool=\"\", node_name=\"permute\"``\n          - append an additional pooling and FC layers:\n            ``use_conv=False, pool=(\"avg\", {\"kernel_size\": 7, \"stride\": 1}), node_name=\"permute\"``\n          - append an additional pooling and convolution layers:\n            ``use_conv=True, pool=(\"avg\", {\"kernel_size\": 7, \"stride\": 1}), node_name=\"permute\"``\n\n          The ``permute`` in the example is the target feature extraction node of the input\n          `model_name`, could be found by using the torchvision feature extraction utilities, for example::\n\n              from torchvision.models.feature_extraction import get_graph_node_names\n              from torchvision.models import swin_t\n              print(get_graph_node_names(swin_t())[0])\n\n    Args:\n        model_name: name of any torchvision model with fully connected layer at the end.\n            ``resnet18`` (default), ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``,\n            ``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``, ``inception_v3``.\n            model details: https://pytorch.org/vision/stable/models.html.\n        num_classes: number of classes for the last classification layer. Default to 1.\n        dim: number of supported spatial dimensions in the specified model, depends on the model implementation.\n            default to 2 as most Torchvision models are for 2D image processing.\n        in_channels: number of the input channels of last layer. 
if None, get it from `in_features` of last layer.\n        use_conv: whether to use convolutional layer to replace the last layer, default to False.\n        pool: parameters for the pooling layer, when it's a tuple, the first item is name of the pooling layer,\n            the second item is dictionary of the initialization args. If None, will not replace the `layers[-2]`.\n            default to `(\"avg\", {\"kernel_size\": 7, \"stride\": 1})`. ``\"\"`` indicates not adding a pooling layer.\n        bias: the bias value when replacing the last layer. if False, the layer will not learn an additive bias,\n            default to True.\n        pretrained: whether to use the imagenet pretrained weights. Default to False.\n        fc_name: the corresponding layer attribute of the last fully connected layer. Defaults to ``\"fc\"``.\n        node_name: the corresponding feature extractor node name of `model`. Defaults to \"\", not in use.\n        weights: additional weights enum for the torchvision model.\n        kwargs: additional parameters for the torchvision model.\n\n    Example::\n\n        import torch\n        from torchvision.models.inception import Inception_V3_Weights\n\n        from monai.networks.nets import TorchVisionFCModel\n\n        model = TorchVisionFCModel(\n            \"inception_v3\",\n            num_classes=4,\n            weights=Inception_V3_Weights.IMAGENET1K_V1,\n            use_conv=False,\n            pool=None,\n        )\n        # model = TorchVisionFCModel(\"vit_b_16\", num_classes=4, pool=None, in_channels=768, fc_name=\"heads\")\n        output = model.forward(torch.randn(2, 3, 299, 299))\n        print(output.shape)  # torch.Size([2, 4])\n\n    \"\"\"\n\n    def __init__(\n        self,\n        model_name: str = \"resnet18\",\n        num_classes: int = 1,\n        dim: int = 2,\n        in_channels: int | None = None,\n        use_conv: bool = False,\n        pool: tuple[str, dict[str, Any]] | None = (\"avg\", {\"kernel_size\": 
7, \"stride\": 1}),\n        bias: bool = True,\n        pretrained: bool = False,\n        fc_name: str = \"fc\",\n        node_name: str = \"\",\n        weights=None,\n        **kwargs,\n    ):\n        # if pretrained is False, weights is a weight tensor or None for no pretrained loading\n        if pretrained and weights is None:\n            weights = \"DEFAULT\"\n\n        model = getattr(models, model_name)(weights=weights, **kwargs)\n\n        super().__init__(\n            model=model,\n            num_classes=num_classes,\n            dim=dim,\n            in_channels=in_channels,\n            use_conv=use_conv,\n            pool=pool,\n            bias=bias,\n            fc_name=fc_name,\n            node_name=node_name,\n        )\n"
  },
  {
    "path": "monai/networks/nets/transchex.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\nfrom collections.abc import Sequence\n\nimport torch\nfrom torch import nn\n\nfrom monai.config.type_definitions import PathLike\nfrom monai.utils import optional_import\n\ntransformers = optional_import(\"transformers\")\nload_tf_weights_in_bert = optional_import(\"transformers\", name=\"load_tf_weights_in_bert\")[0]\ncached_file = optional_import(\"transformers.utils\", name=\"cached_file\")[0]\nBertEmbeddings = optional_import(\"transformers.models.bert.modeling_bert\", name=\"BertEmbeddings\")[0]\nBertLayer = optional_import(\"transformers.models.bert.modeling_bert\", name=\"BertLayer\")[0]\n\n__all__ = [\"BertPreTrainedModel\", \"BertAttention\", \"BertOutput\", \"BertMixedLayer\", \"Pooler\", \"MultiModal\", \"Transchex\"]\n\n\nclass BertPreTrainedModel(nn.Module):\n    \"\"\"Module to load BERT pre-trained weights.\n    Based on:\n    LXMERT\n    https://github.com/airsplay/lxmert\n    BERT (pytorch-transformer)\n    https://github.com/huggingface/transformers\n    \"\"\"\n\n    def __init__(self, *inputs, **kwargs) -> None:\n        super().__init__()\n\n    def init_bert_weights(self, module):\n        if isinstance(module, (nn.Linear, nn.Embedding)):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)  # type: ignore[union-attr,arg-type]\n        elif isinstance(module, 
torch.nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n        if isinstance(module, nn.Linear) and module.bias is not None:\n            module.bias.data.zero_()\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        num_language_layers,\n        num_vision_layers,\n        num_mixed_layers,\n        bert_config,\n        state_dict=None,\n        cache_dir=None,\n        from_tf=False,\n        path_or_repo_id=\"bert-base-uncased\",\n        filename=\"pytorch_model.bin\",\n        *inputs,\n        **kwargs,\n    ):\n        weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir)\n        model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)\n        if state_dict is None and not from_tf:\n            map_location = \"cpu\" if not torch.cuda.is_available() else None\n            state_dict = torch.load(weights_path, map_location=map_location, weights_only=True)\n        if from_tf:\n            return load_tf_weights_in_bert(model, weights_path)\n        old_keys = []\n        new_keys = []\n        for key in state_dict.keys():\n            new_key = None\n            if \"gamma\" in key:\n                new_key = key.replace(\"gamma\", \"weight\")\n            if \"beta\" in key:\n                new_key = key.replace(\"beta\", \"bias\")\n            if new_key:\n                old_keys.append(key)\n                new_keys.append(new_key)\n        for old_key, new_key in zip(old_keys, new_keys):\n            state_dict[new_key] = state_dict.pop(old_key)\n        missing_keys: list = []\n        unexpected_keys: list = []\n        error_msgs: list = []\n        metadata = getattr(state_dict, \"_metadata\", None)\n        state_dict = state_dict.copy()\n        if metadata is not None:\n            state_dict._metadata = metadata\n\n        def load(module, prefix=\"\"):\n            local_metadata = {} if metadata is None else 
metadata.get(prefix[:-1], {})\n            module._load_from_state_dict(\n                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs\n            )\n            for name, child in module._modules.items():\n                if child is not None:\n                    load(child, prefix + name + \".\")\n\n        start_prefix = \"\"\n        if not hasattr(model, \"bert\") and any(s.startswith(\"bert.\") for s in state_dict.keys()):\n            start_prefix = \"bert.\"\n        load(model, prefix=start_prefix)\n        return model\n\n\nclass BertAttention(nn.Module):\n    \"\"\"BERT attention layer.\n    Based on: BERT (pytorch-transformer)\n    https://github.com/huggingface/transformers\n    \"\"\"\n\n    def __init__(self, config) -> None:\n        super().__init__()\n        self.num_attention_heads = config.num_attention_heads\n        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n        self.query = nn.Linear(config.hidden_size, self.all_head_size)\n        self.key = nn.Linear(config.hidden_size, self.all_head_size)\n        self.value = nn.Linear(config.hidden_size, self.all_head_size)\n        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)\n\n    def transpose_for_scores(self, x):\n        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)\n        x = x.view(*new_x_shape)\n        return x.permute(0, 2, 1, 3)\n\n    def forward(self, hidden_states, context):\n        mixed_query_layer = self.query(hidden_states)\n        mixed_key_layer = self.key(context)\n        mixed_value_layer = self.value(context)\n        query_layer = self.transpose_for_scores(mixed_query_layer)\n        key_layer = self.transpose_for_scores(mixed_key_layer)\n        value_layer = self.transpose_for_scores(mixed_value_layer)\n        attention_scores = torch.matmul(query_layer, 
key_layer.transpose(-1, -2))\n        attention_scores = attention_scores / math.sqrt(self.attention_head_size)\n        attention_probs = self.dropout(nn.Softmax(dim=-1)(attention_scores))\n        context_layer = torch.matmul(attention_probs, value_layer)\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)\n        context_layer = context_layer.view(*new_context_layer_shape)\n        return context_layer\n\n\nclass BertOutput(nn.Module):\n    \"\"\"BERT output layer.\n    Based on: BERT (pytorch-transformer)\n    https://github.com/huggingface/transformers\n    \"\"\"\n\n    def __init__(self, config) -> None:\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def forward(self, hidden_states, input_tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertMixedLayer(nn.Module):\n    \"\"\"BERT cross attention layer.\n    Based on: BERT (pytorch-transformer)\n    https://github.com/huggingface/transformers\n    \"\"\"\n\n    def __init__(self, config) -> None:\n        super().__init__()\n        self.att_x = BertAttention(config)\n        self.output_x = BertOutput(config)\n        self.att_y = BertAttention(config)\n        self.output_y = BertOutput(config)\n\n    def forward(self, x, y):\n        output_x = self.att_x(x, y)\n        output_y = self.att_y(y, x)\n        return self.output_x(output_x, x), self.output_y(output_y, y)\n\n\nclass Pooler(nn.Module):\n    \"\"\"BERT pooler layer.\n    Based on: BERT (pytorch-transformer)\n    https://github.com/huggingface/transformers\n    \"\"\"\n\n    
def __init__(self, hidden_size) -> None:\n        super().__init__()\n        self.dense = nn.Linear(hidden_size, hidden_size)\n        self.activation = nn.Tanh()\n\n    def forward(self, hidden_states):\n        first_token_tensor = hidden_states[:, 0]\n        pooled_output = self.dense(first_token_tensor)\n        pooled_output = self.activation(pooled_output)\n        return pooled_output\n\n\nclass MultiModal(BertPreTrainedModel):\n    \"\"\"\n    Multimodal Transformers From Pretrained BERT Weights\"\n    \"\"\"\n\n    def __init__(\n        self, num_language_layers: int, num_vision_layers: int, num_mixed_layers: int, bert_config: dict\n    ) -> None:\n        \"\"\"\n        Args:\n            num_language_layers: number of language transformer layers.\n            num_vision_layers: number of vision transformer layers.\n            bert_config: configuration for bert language transformer encoder.\n\n        \"\"\"\n        super().__init__()\n        self.config = type(\"obj\", (object,), bert_config)\n        self.embeddings = BertEmbeddings(self.config)\n        self.language_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_language_layers)])\n        self.vision_encoder = nn.ModuleList([BertLayer(self.config) for _ in range(num_vision_layers)])\n        self.mixed_encoder = nn.ModuleList([BertMixedLayer(self.config) for _ in range(num_mixed_layers)])\n        self.apply(self.init_bert_weights)\n\n    @staticmethod\n    def _get_hidden_states(layer_output):\n        \"\"\"Extract hidden states from BertLayer output.\n\n        Compatible with both older transformers (returns a tuple) and\n        newer transformers >=5.0 (may return a tensor directly).\n        \"\"\"\n        if isinstance(layer_output, torch.Tensor):\n            return layer_output\n        return layer_output[0]\n\n    def forward(self, input_ids, token_type_ids=None, vision_feats=None, attention_mask=None):\n        language_features = self.embeddings(input_ids, 
token_type_ids)\n        for layer in self.vision_encoder:\n            vision_feats = self._get_hidden_states(layer(vision_feats, None))\n        for layer in self.language_encoder:\n            language_features = self._get_hidden_states(layer(language_features, attention_mask))\n        for layer in self.mixed_encoder:\n            language_features, vision_feats = layer(language_features, vision_feats)\n        return language_features, vision_feats\n\n\nclass Transchex(torch.nn.Module):\n    \"\"\"\n    TransChex based on: \"Hatamizadeh et al.,TransCheX: Self-Supervised Pretraining of Vision-Language\n    Transformers for Chest X-ray Analysis\"\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        img_size: Sequence[int] | int,\n        patch_size: int | tuple[int, int],\n        num_classes: int,\n        num_language_layers: int,\n        num_vision_layers: int,\n        num_mixed_layers: int,\n        hidden_size: int = 768,\n        drop_out: float = 0.0,\n        attention_probs_dropout_prob: float = 0.1,\n        gradient_checkpointing: bool = False,\n        hidden_act: str = \"gelu\",\n        hidden_dropout_prob: float = 0.1,\n        initializer_range: float = 0.02,\n        intermediate_size: int = 3072,\n        layer_norm_eps: float = 1e-12,\n        max_position_embeddings: int = 512,\n        model_type: str = \"bert\",\n        num_attention_heads: int = 12,\n        num_hidden_layers: int = 12,\n        pad_token_id: int = 0,\n        position_embedding_type: str = \"absolute\",\n        transformers_version: str = \"4.10.2\",\n        type_vocab_size: int = 2,\n        use_cache: bool = True,\n        vocab_size: int = 30522,\n        chunk_size_feed_forward: int = 0,\n        is_decoder: bool = False,\n        add_cross_attention: bool = False,\n        path_or_repo_id: str | PathLike = \"bert-base-uncased\",\n        filename: str = \"pytorch_model.bin\",\n    ) -> None:\n        \"\"\"\n        Args:\n         
   in_channels: dimension of input channels.\n            img_size: dimension of input image.\n            patch_size: dimension of patch size.\n            num_classes: number of classes if classification is used.\n            num_language_layers: number of language transformer layers.\n            num_vision_layers: number of vision transformer layers.\n            num_mixed_layers: number of mixed transformer layers.\n            drop_out: fraction of the input units to drop.\n            path_or_repo_id: This can be either:\n                - a string, the *model id* of a model repo on huggingface.co.\n                - a path to a *directory* potentially containing the file.\n            filename: The name of the file to locate in `path_or_repo`.\n\n        The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.\n\n        Examples:\n\n        .. code-block:: python\n\n            # for 3-channel with image size of (224,224), patch size of (32,32), 3 classes, 2 language layers,\n            # 2 vision layers, 2 mixed modality layers and dropout of 0.2 in the classification head\n            net = Transchex(in_channels=3,\n                                 img_size=(224, 224),\n                                 num_classes=3,\n                                 num_language_layers=2,\n                                 num_vision_layers=2,\n                                 num_mixed_layers=2,\n                                 drop_out=0.2)\n\n        \"\"\"\n        super().__init__()\n        bert_config = {\n            \"attention_probs_dropout_prob\": attention_probs_dropout_prob,\n            \"classifier_dropout\": None,\n            \"gradient_checkpointing\": gradient_checkpointing,\n            \"hidden_act\": hidden_act,\n            \"hidden_dropout_prob\": hidden_dropout_prob,\n            \"hidden_size\": hidden_size,\n            \"initializer_range\": initializer_range,\n            \"intermediate_size\": 
intermediate_size,\n            \"layer_norm_eps\": layer_norm_eps,\n            \"max_position_embeddings\": max_position_embeddings,\n            \"model_type\": model_type,\n            \"num_attention_heads\": num_attention_heads,\n            \"num_hidden_layers\": num_hidden_layers,\n            \"pad_token_id\": pad_token_id,\n            \"position_embedding_type\": position_embedding_type,\n            \"transformers_version\": transformers_version,\n            \"type_vocab_size\": type_vocab_size,\n            \"use_cache\": use_cache,\n            \"vocab_size\": vocab_size,\n            \"chunk_size_feed_forward\": chunk_size_feed_forward,\n            \"is_decoder\": is_decoder,\n            \"add_cross_attention\": add_cross_attention,\n            \"_attn_implementation\": \"eager\",\n        }\n        if not (0 <= drop_out <= 1):\n            raise ValueError(\"dropout_rate should be between 0 and 1.\")\n\n        if (img_size[0] % patch_size[0] != 0) or (img_size[1] % patch_size[1] != 0):  # type: ignore\n            raise ValueError(\"img_size should be divisible by patch_size.\")\n\n        self.multimodal = MultiModal.from_pretrained(\n            num_language_layers=num_language_layers,\n            num_vision_layers=num_vision_layers,\n            num_mixed_layers=num_mixed_layers,\n            bert_config=bert_config,\n            path_or_repo_id=path_or_repo_id,\n            filename=filename,\n        )\n\n        self.patch_size = patch_size\n        self.num_patches = (img_size[0] // self.patch_size[0]) * (img_size[1] // self.patch_size[1])  # type: ignore\n        self.vision_proj = nn.Conv2d(\n            in_channels=in_channels, out_channels=hidden_size, kernel_size=self.patch_size, stride=self.patch_size\n        )\n        self.norm_vision_pos = nn.LayerNorm(hidden_size)\n        self.pos_embed_vis = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size))\n        self.pooler = Pooler(hidden_size=hidden_size)\n        self.drop 
= torch.nn.Dropout(drop_out)\n        self.cls_head = torch.nn.Linear(hidden_size, num_classes)\n\n    def forward(self, input_ids, token_type_ids=None, vision_feats=None):\n        attention_mask = torch.ones_like(input_ids).unsqueeze(1).unsqueeze(2)\n        attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)\n        attention_mask = (1.0 - attention_mask) * -10000.0\n        vision_feats = self.vision_proj(vision_feats).flatten(2).transpose(1, 2)\n        vision_feats = self.norm_vision_pos(vision_feats)\n        vision_feats = vision_feats + self.pos_embed_vis\n        hidden_state_lang, hidden_state_vis = self.multimodal(\n            input_ids=input_ids, token_type_ids=token_type_ids, vision_feats=vision_feats, attention_mask=attention_mask\n        )\n        pooled_features = self.pooler(hidden_state_lang)\n        logits = self.cls_head(self.drop(pooled_features))\n        return logits\n"
  },
  {
    "path": "monai/networks/nets/transformer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import TransformerBlock\n\n__all__ = [\"DecoderOnlyTransformer\"]\n\n\nclass AbsolutePositionalEmbedding(nn.Module):\n    \"\"\"Absolute positional embedding.\n\n    Args:\n        max_seq_len: Maximum sequence length.\n        embedding_dim: Dimensionality of the embedding.\n    \"\"\"\n\n    def __init__(self, max_seq_len: int, embedding_dim: int) -> None:\n        super().__init__()\n        self.max_seq_len = max_seq_len\n        self.embedding_dim = embedding_dim\n        self.embedding = nn.Embedding(max_seq_len, embedding_dim)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        batch_size, seq_len = x.size()\n        positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1)\n        embedding: torch.Tensor = self.embedding(positions)\n        return embedding\n\n\nclass DecoderOnlyTransformer(nn.Module):\n    \"\"\"Decoder-only (Autoregressive) Transformer model.\n\n    Args:\n        num_tokens: Number of tokens in the vocabulary.\n        max_seq_len: Maximum sequence length.\n        attn_layers_dim: Dimensionality of the attention layers.\n        attn_layers_depth: Number of attention layers.\n        attn_layers_heads: Number of attention heads.\n        with_cross_attention: Whether to use cross attention for conditioning.\n  
      embedding_dropout_rate: Dropout rate for the embedding.\n        include_fc: whether to include the final linear layer. Default to True.\n        use_combined_linear: whether to use a single linear layer for qkv projection, default to True.\n        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism\n            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).\n    \"\"\"\n\n    def __init__(\n        self,\n        num_tokens: int,\n        max_seq_len: int,\n        attn_layers_dim: int,\n        attn_layers_depth: int,\n        attn_layers_heads: int,\n        with_cross_attention: bool = False,\n        embedding_dropout_rate: float = 0.0,\n        include_fc: bool = True,\n        use_combined_linear: bool = False,\n        use_flash_attention: bool = False,\n    ) -> None:\n        super().__init__()\n        self.num_tokens = num_tokens\n        self.max_seq_len = max_seq_len\n        self.attn_layers_dim = attn_layers_dim\n        self.attn_layers_depth = attn_layers_depth\n        self.attn_layers_heads = attn_layers_heads\n        self.with_cross_attention = with_cross_attention\n\n        self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim)\n        self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim)\n        self.embedding_dropout = nn.Dropout(embedding_dropout_rate)\n\n        self.blocks = nn.ModuleList(\n            [\n                TransformerBlock(\n                    hidden_size=attn_layers_dim,\n                    mlp_dim=attn_layers_dim * 4,\n                    num_heads=attn_layers_heads,\n                    dropout_rate=0.0,\n                    qkv_bias=False,\n                    causal=True,\n                    sequence_length=max_seq_len,\n                    with_cross_attention=with_cross_attention,\n                    include_fc=include_fc,\n  
                  use_combined_linear=use_combined_linear,\n                    use_flash_attention=use_flash_attention,\n                )\n                for _ in range(attn_layers_depth)\n            ]\n        )\n\n        self.to_logits = nn.Linear(attn_layers_dim, num_tokens)\n\n    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:\n        tok_emb = self.token_embeddings(x)\n        pos_emb = self.position_embeddings(x)\n        x = self.embedding_dropout(tok_emb + pos_emb)\n\n        for block in self.blocks:\n            x = block(x, context=context)\n        logits: torch.Tensor = self.to_logits(x)\n        return logits\n\n    def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:\n        \"\"\"\n        Load a state dict from a DecoderOnlyTransformer trained with\n        [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).\n\n        Args:\n            old_state_dict: state dict from the old DecoderOnlyTransformer  model.\n        \"\"\"\n\n        new_state_dict = self.state_dict()\n        # if all keys match, just load the state dict\n        if all(k in new_state_dict for k in old_state_dict):\n            print(\"All keys match, loading state dict.\")\n            self.load_state_dict(old_state_dict)\n            return\n\n        if verbose:\n            # print all new_state_dict keys that are not in old_state_dict\n            for k in new_state_dict:\n                if k not in old_state_dict:\n                    print(f\"key {k} not found in old state dict\")\n            # and vice versa\n            print(\"----------------------------------------------\")\n            for k in old_state_dict:\n                if k not in new_state_dict:\n                    print(f\"key {k} not found in new state dict\")\n\n        # copy over all matching keys\n        for k in new_state_dict:\n            if k in old_state_dict:\n                new_state_dict[k] = 
old_state_dict.pop(k)\n\n        # fix the renamed norm blocks first  norm2 -> norm_cross_attention , norm3 -> norm2\n        for k in list(old_state_dict.keys()):\n            if \"norm2\" in k:\n                new_state_dict[k.replace(\"norm2\", \"norm_cross_attn\")] = old_state_dict.pop(k)\n            if \"norm3\" in k:\n                new_state_dict[k.replace(\"norm3\", \"norm2\")] = old_state_dict.pop(k)\n        if verbose:\n            # print all remaining keys in old_state_dict\n            print(\"remaining keys in old_state_dict:\", old_state_dict.keys())\n        self.load_state_dict(new_state_dict)\n"
  },
  {
    "path": "monai/networks/nets/unet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper\nfrom monai.networks.blocks.convolutions import Convolution, ResidualUnit\nfrom monai.networks.layers.factories import Act, Norm\nfrom monai.networks.layers.simplelayers import SkipConnection\n\n__all__ = [\"UNet\", \"Unet\", \"CheckpointUNet\"]\n\n\nclass UNet(nn.Module):\n    \"\"\"\n    Enhanced version of UNet which has residual units implemented with the ResidualUnit class.\n    The residual part uses a convolution to change the input dimensions to match the output dimensions\n    if this is necessary but will use nn.Identity if not.\n    Refer to: https://link.springer.com/chapter/10.1007/978-3-030-12029-0_40.\n\n    Each layer of the network has a encode and decode path with a skip connection between them. Data in the encode path\n    is downsampled using strided convolutions (if `strides` is given values greater than 1) and in the decode path\n    upsampled using strided transpose convolutions. These down or up sampling operations occur at the beginning of each\n    block rather than afterwards as is typical in UNet implementations.\n\n    To further explain this consider the first example network given below. 
This network has 3 layers with strides\n    of 2 for each of the middle layers (the last layer is the bottom connection which does not down/up sample). Input\n    data to this network is immediately reduced in the spatial dimensions by a factor of 2 by the first convolution of\n    the residual unit defining the first layer of the encode part. The last layer of the decode part will upsample its\n    input (data from the previous layer concatenated with data from the skip connection) in the first convolution. this\n    ensures the final output of the network has the same shape as the input.\n\n    Padding values for the convolutions are chosen to ensure output sizes are even divisors/multiples of the input\n    sizes if the `strides` value for a layer is a factor of the input sizes. A typical case is to use `strides` values\n    of 2 and inputs that are multiples of powers of 2. An input can thus be downsampled evenly however many times its\n    dimensions can be divided by 2, so for the example network inputs would have to have dimensions that are multiples\n    of 4. In the second example network given below the input to the bottom layer will have shape (1, 64, 15, 15) for\n    an input of shape (1, 1, 240, 240) demonstrating the input being reduced in size spatially by 2**4.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        channels: sequence of channels. Top block first. The length of `channels` should be no less than 2.\n        strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`.\n        kernel_size: convolution kernel size, the value(s) should be odd. If sequence,\n            its length should equal to dimensions. Defaults to 3.\n        up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence,\n            its length should equal to dimensions. 
Defaults to 3.\n        num_res_units: number of residual units. Defaults to 0.\n        act: activation type and arguments. Defaults to PReLU.\n        norm: feature normalization type and arguments. Defaults to instance norm.\n        dropout: dropout ratio. Defaults to no dropout.\n        bias: whether to have a bias term in convolution blocks. Defaults to True.\n            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,\n            if a conv layer is directly followed by a batch norm layer, bias should be False.\n        adn_ordering: a string representing the ordering of activation (A), normalization (N), and dropout (D).\n            Defaults to \"NDA\". See also: :py:class:`monai.networks.blocks.ADN`.\n\n    Examples::\n\n        from monai.networks.nets import UNet\n\n        # 3 layer network with down/upsampling by a factor of 2 at each layer with 2-convolution residual units\n        net = UNet(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            channels=(4, 8, 16),\n            strides=(2, 2),\n            num_res_units=2\n        )\n\n        # 5 layer network with simple convolution/normalization/dropout/activation blocks defining the layers\n        net=UNet(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            channels=(4, 8, 16, 32, 64),\n            strides=(2, 2, 2, 2),\n        )\n\n    Note: The acceptable spatial size of input data depends on the parameters of the network,\n        to set appropriate spatial size, please check the tutorial for more details:\n        https://github.com/Project-MONAI/tutorials/blob/master/modules/UNet_input_size_constrains.ipynb.\n        Typically, when using a stride of 2 in down / up sampling, the output dimensions are either half of the\n        input when downsampling, or twice when upsampling. 
In this case with N numbers of layers in the network,\n        the inputs must have spatial dimensions that are all multiples of 2^N.\n        Usually, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        channels: Sequence[int],\n        strides: Sequence[int],\n        kernel_size: Sequence[int] | int = 3,\n        up_kernel_size: Sequence[int] | int = 3,\n        num_res_units: int = 0,\n        act: tuple | str = Act.PRELU,\n        norm: tuple | str = Norm.INSTANCE,\n        dropout: float = 0.0,\n        bias: bool = True,\n        adn_ordering: str = \"NDA\",\n    ) -> None:\n        super().__init__()\n\n        if len(channels) < 2:\n            raise ValueError(\"the length of `channels` should be no less than 2.\")\n        delta = len(strides) - (len(channels) - 1)\n        if delta < 0:\n            raise ValueError(\"the length of `strides` should equal to `len(channels) - 1`.\")\n        if delta > 0:\n            warnings.warn(f\"`len(strides) > len(channels) - 1`, the last {delta} values of strides will not be used.\")\n        if isinstance(kernel_size, Sequence) and len(kernel_size) != spatial_dims:\n            raise ValueError(\"the length of `kernel_size` should equal to `dimensions`.\")\n        if isinstance(up_kernel_size, Sequence) and len(up_kernel_size) != spatial_dims:\n            raise ValueError(\"the length of `up_kernel_size` should equal to `dimensions`.\")\n\n        self.dimensions = spatial_dims\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.channels = channels\n        self.strides = strides\n        self.kernel_size = kernel_size\n        self.up_kernel_size = up_kernel_size\n        self.num_res_units = num_res_units\n        self.act = act\n        self.norm = norm\n        self.dropout = 
dropout\n        self.bias = bias\n        self.adn_ordering = adn_ordering\n\n        def _create_block(\n            inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool\n        ) -> nn.Module:\n            \"\"\"\n            Builds the UNet structure from the bottom up by recursing down to the bottom block, then creating sequential\n            blocks containing the downsample path, a skip connection around the previous block, and the upsample path.\n\n            Args:\n                inc: number of input channels.\n                outc: number of output channels.\n                channels: sequence of channels. Top block first.\n                strides: convolution stride.\n                is_top: True if this is the top block.\n            \"\"\"\n            c = channels[0]\n            s = strides[0]\n\n            subblock: nn.Module\n\n            if len(channels) > 2:\n                subblock = _create_block(c, c, channels[1:], strides[1:], False)  # continue recursion down\n                upc = c * 2\n            else:\n                # the next layer is the bottom so stop recursion, create the bottom layer as the sublock for this layer\n                subblock = self._get_bottom_layer(c, channels[1])\n                upc = c + channels[1]\n\n            down = self._get_down_layer(inc, c, s, is_top)  # create layer in downsampling path\n            up = self._get_up_layer(upc, outc, s, is_top)  # create layer in upsampling path\n\n            return self._get_connection_block(down, up, subblock)\n\n        self.model = _create_block(in_channels, out_channels, self.channels, self.strides, True)\n\n    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:\n        \"\"\"\n        Returns the block object defining a layer of the UNet structure including the implementation of the skip\n        between encoding (down) and decoding (up) sides of the network.\n\n       
 Args:\n            down_path: encoding half of the layer\n            up_path: decoding half of the layer\n            subblock: block defining the next layer in the network.\n        Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)`\n        \"\"\"\n        return nn.Sequential(down_path, SkipConnection(subblock), up_path)\n\n    def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:\n        \"\"\"\n        Returns the encoding (down) part of a layer of the network. This typically will downsample data at some point\n        in its structure. Its output is used as input to the next layer down and is concatenated with output from the\n        next layer to form the input for the decode (up) part of the layer.\n\n        Args:\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            strides: convolution stride.\n            is_top: True if this is the top block.\n        \"\"\"\n        mod: nn.Module\n        if self.num_res_units > 0:\n            mod = ResidualUnit(\n                self.dimensions,\n                in_channels,\n                out_channels,\n                strides=strides,\n                kernel_size=self.kernel_size,\n                subunits=self.num_res_units,\n                act=self.act,\n                norm=self.norm,\n                dropout=self.dropout,\n                bias=self.bias,\n                adn_ordering=self.adn_ordering,\n            )\n            return mod\n        mod = Convolution(\n            self.dimensions,\n            in_channels,\n            out_channels,\n            strides=strides,\n            kernel_size=self.kernel_size,\n            act=self.act,\n            norm=self.norm,\n            dropout=self.dropout,\n            bias=self.bias,\n            adn_ordering=self.adn_ordering,\n        )\n        return mod\n\n    def _get_bottom_layer(self, 
in_channels: int, out_channels: int) -> nn.Module:\n        \"\"\"\n        Returns the bottom or bottleneck layer at the bottom of the network linking encode to decode halves.\n\n        Args:\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n        \"\"\"\n        return self._get_down_layer(in_channels, out_channels, 1, False)\n\n    def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:\n        \"\"\"\n        Returns the decoding (up) part of a layer of the network. This typically will upsample data at some point\n        in its structure. Its output is used as input to the next layer up.\n\n        Args:\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            strides: convolution stride.\n            is_top: True if this is the top block.\n        \"\"\"\n        conv: Convolution | nn.Sequential\n\n        conv = Convolution(\n            self.dimensions,\n            in_channels,\n            out_channels,\n            strides=strides,\n            kernel_size=self.up_kernel_size,\n            act=self.act,\n            norm=self.norm,\n            dropout=self.dropout,\n            bias=self.bias,\n            conv_only=is_top and self.num_res_units == 0,\n            is_transposed=True,\n            adn_ordering=self.adn_ordering,\n        )\n\n        if self.num_res_units > 0:\n            ru = ResidualUnit(\n                self.dimensions,\n                out_channels,\n                out_channels,\n                strides=1,\n                kernel_size=self.kernel_size,\n                subunits=1,\n                act=self.act,\n                norm=self.norm,\n                dropout=self.dropout,\n                bias=self.bias,\n                last_conv_only=is_top,\n                adn_ordering=self.adn_ordering,\n            )\n            conv = nn.Sequential(conv, ru)\n\n 
       return conv\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.model(x)\n        return x\n\n\nclass CheckpointUNet(UNet):\n    \"\"\"UNet variant that wraps internal connection blocks with activation checkpointing.\n\n    See `UNet` for constructor arguments. During training with gradients enabled,\n    intermediate activations inside encoder-decoder connections are recomputed in\n    the backward pass to reduce peak memory usage at the cost of extra compute.\n    \"\"\"\n\n    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:\n        \"\"\"Returns connection block with activation checkpointing applied to all components.\n\n        Args:\n            down_path: encoding half of the layer (will be wrapped with checkpointing).\n            up_path: decoding half of the layer (will be wrapped with checkpointing).\n            subblock: block defining the next layer (will be wrapped with checkpointing).\n\n        Returns:\n            Connection block with all components wrapped for activation checkpointing.\n        \"\"\"\n        subblock = ActivationCheckpointWrapper(subblock)\n        down_path = ActivationCheckpointWrapper(down_path)\n        up_path = ActivationCheckpointWrapper(up_path)\n        return super()._get_connection_block(down_path, up_path, subblock)\n\n\nUnet = UNet\n"
  },
  {
    "path": "monai/networks/nets/unetr.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch.nn as nn\n\nfrom monai.networks.blocks.dynunet_block import UnetOutBlock\nfrom monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock\nfrom monai.networks.nets.vit import ViT\nfrom monai.utils import ensure_tuple_rep\n\n\nclass UNETR(nn.Module):\n    \"\"\"\n    UNETR based on: \"Hatamizadeh et al.,\n    UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>\"\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        img_size: Sequence[int] | int,\n        feature_size: int = 16,\n        hidden_size: int = 768,\n        mlp_dim: int = 3072,\n        num_heads: int = 12,\n        proj_type: str = \"conv\",\n        norm_name: tuple | str = \"instance\",\n        conv_block: bool = True,\n        res_block: bool = True,\n        dropout_rate: float = 0.0,\n        spatial_dims: int = 3,\n        qkv_bias: bool = False,\n        save_attn: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            in_channels: dimension of input channels.\n            out_channels: dimension of output channels.\n            img_size: dimension of input image.\n            feature_size: dimension of network feature size. 
Defaults to 16.\n            hidden_size: dimension of hidden layer. Defaults to 768.\n            mlp_dim: dimension of feedforward layer. Defaults to 3072.\n            num_heads: number of attention heads. Defaults to 12.\n            proj_type: patch embedding layer type. Defaults to \"conv\".\n            norm_name: feature normalization type and arguments. Defaults to \"instance\".\n            conv_block: if convolutional block is used. Defaults to True.\n            res_block: if residual block is used. Defaults to True.\n            dropout_rate: fraction of the input units to drop. Defaults to 0.0.\n            spatial_dims: number of spatial dims. Defaults to 3.\n            qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False.\n            save_attn: to make accessible the attention in self attention block. Defaults to False.\n\n        Examples::\n\n            # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm\n            >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')\n\n             # for single channel input 4-channel output with image size of (96,96), feature size of 32 and batch norm\n            >>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2)\n\n            # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm\n            >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), proj_type='conv', norm_name='instance')\n\n        \"\"\"\n\n        super().__init__()\n\n        if not (0 <= dropout_rate <= 1):\n            raise ValueError(\"dropout_rate should be between 0 and 1.\")\n\n        if hidden_size % num_heads != 0:\n            raise ValueError(\"hidden_size should be divisible by num_heads.\")\n\n        self.num_layers = 12\n        
img_size = ensure_tuple_rep(img_size, spatial_dims)\n        self.patch_size = ensure_tuple_rep(16, spatial_dims)\n        self.feat_size = tuple(img_d // p_d for img_d, p_d in zip(img_size, self.patch_size))\n        self.hidden_size = hidden_size\n        self.classification = False\n        self.vit = ViT(\n            in_channels=in_channels,\n            img_size=img_size,\n            patch_size=self.patch_size,\n            hidden_size=hidden_size,\n            mlp_dim=mlp_dim,\n            num_layers=self.num_layers,\n            num_heads=num_heads,\n            proj_type=proj_type,\n            classification=self.classification,\n            dropout_rate=dropout_rate,\n            spatial_dims=spatial_dims,\n            qkv_bias=qkv_bias,\n            save_attn=save_attn,\n        )\n        self.encoder1 = UnetrBasicBlock(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=feature_size,\n            kernel_size=3,\n            stride=1,\n            norm_name=norm_name,\n            res_block=res_block,\n        )\n        self.encoder2 = UnetrPrUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=hidden_size,\n            out_channels=feature_size * 2,\n            num_layer=2,\n            kernel_size=3,\n            stride=1,\n            upsample_kernel_size=2,\n            norm_name=norm_name,\n            conv_block=conv_block,\n            res_block=res_block,\n        )\n        self.encoder3 = UnetrPrUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=hidden_size,\n            out_channels=feature_size * 4,\n            num_layer=1,\n            kernel_size=3,\n            stride=1,\n            upsample_kernel_size=2,\n            norm_name=norm_name,\n            conv_block=conv_block,\n            res_block=res_block,\n        )\n        self.encoder4 = UnetrPrUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=hidden_size,\n 
           out_channels=feature_size * 8,\n            num_layer=0,\n            kernel_size=3,\n            stride=1,\n            upsample_kernel_size=2,\n            norm_name=norm_name,\n            conv_block=conv_block,\n            res_block=res_block,\n        )\n        self.decoder5 = UnetrUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=hidden_size,\n            out_channels=feature_size * 8,\n            kernel_size=3,\n            upsample_kernel_size=2,\n            norm_name=norm_name,\n            res_block=res_block,\n        )\n        self.decoder4 = UnetrUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=feature_size * 8,\n            out_channels=feature_size * 4,\n            kernel_size=3,\n            upsample_kernel_size=2,\n            norm_name=norm_name,\n            res_block=res_block,\n        )\n        self.decoder3 = UnetrUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=feature_size * 4,\n            out_channels=feature_size * 2,\n            kernel_size=3,\n            upsample_kernel_size=2,\n            norm_name=norm_name,\n            res_block=res_block,\n        )\n        self.decoder2 = UnetrUpBlock(\n            spatial_dims=spatial_dims,\n            in_channels=feature_size * 2,\n            out_channels=feature_size,\n            kernel_size=3,\n            upsample_kernel_size=2,\n            norm_name=norm_name,\n            res_block=res_block,\n        )\n        self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)\n        self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims))\n        self.proj_view_shape = list(self.feat_size) + [self.hidden_size]\n\n    def proj_feat(self, x):\n        new_view = [x.size(0)] + self.proj_view_shape\n        x = x.view(new_view)\n        x = x.permute(self.proj_axes).contiguous()\n        return x\n\n    def forward(self, 
x_in):\n        x, hidden_states_out = self.vit(x_in)\n        enc1 = self.encoder1(x_in)\n        x2 = hidden_states_out[3]\n        enc2 = self.encoder2(self.proj_feat(x2))\n        x3 = hidden_states_out[6]\n        enc3 = self.encoder3(self.proj_feat(x3))\n        x4 = hidden_states_out[9]\n        enc4 = self.encoder4(self.proj_feat(x4))\n        dec4 = self.proj_feat(x)\n        dec3 = self.decoder5(dec4, enc4)\n        dec2 = self.decoder4(dec3, enc3)\n        dec1 = self.decoder3(dec2, enc2)\n        out = self.decoder2(dec1, enc1)\n        return self.out(out)\n"
  },
  {
    "path": "monai/networks/nets/varautoencoder.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\nfrom monai.networks.layers.convutils import calculate_out_shape, same_padding\nfrom monai.networks.layers.factories import Act, Norm\nfrom monai.networks.nets import AutoEncoder\n\n__all__ = [\"VarAutoEncoder\"]\n\n\nclass VarAutoEncoder(AutoEncoder):\n    \"\"\"\n    Variational Autoencoder based on the paper - https://arxiv.org/abs/1312.6114\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_shape: shape of input data starting with channel dimension.\n        out_channels: number of output channels.\n        latent_size: size of the latent variable.\n        channels: sequence of channels. Top block first. The length of `channels` should be no less than 2.\n        strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`.\n        kernel_size: convolution kernel size, the value(s) should be odd. If sequence,\n            its length should equal to dimensions. Defaults to 3.\n        up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence,\n            its length should equal to dimensions. Defaults to 3.\n        num_res_units: number of residual units. 
Defaults to 0.\n        inter_channels: sequence of channels defining the blocks in the intermediate layer between encode and decode.\n        inter_dilations: defines the dilation value for each block of the intermediate layer. Defaults to 1.\n        num_inter_units: number of residual units for each block of the intermediate layer. Defaults to 2.\n        act: activation type and arguments. Defaults to PReLU.\n        norm: feature normalization type and arguments. Defaults to instance norm.\n        dropout: dropout ratio. Defaults to no dropout.\n        bias: whether to have a bias term in convolution blocks. Defaults to True.\n            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,\n            if a conv layer is directly followed by a batch norm layer, bias should be False.\n        use_sigmoid: whether to use the sigmoid function on final output. Defaults to True.\n\n    Examples::\n\n        from monai.networks.nets import VarAutoEncoder\n\n        # 3 layer network accepting images with dimensions (1, 32, 32) and using a latent vector with 2 values\n        model = VarAutoEncoder(\n            spatial_dims=2,\n            in_shape=(1, 32, 32),  # channel dimension followed by the image spatial shape\n            out_channels=1,\n            latent_size=2,\n            channels=(16, 32, 64),\n            strides=(1, 2, 2),\n        )\n\n    see also:\n        - Variational autoencoder network with MedNIST Dataset\n          https://github.com/Project-MONAI/tutorials/blob/master/modules/varautoencoder_mednist.ipynb\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_shape: Sequence[int],\n        out_channels: int,\n        latent_size: int,\n        channels: Sequence[int],\n        strides: Sequence[int],\n        kernel_size: Sequence[int] | int = 3,\n        up_kernel_size: Sequence[int] | int = 3,\n        num_res_units: int = 0,\n        inter_channels: list | None = None,\n        
inter_dilations: list | None = None,\n        num_inter_units: int = 2,\n        act: tuple | str | None = Act.PRELU,\n        norm: tuple | str = Norm.INSTANCE,\n        dropout: tuple | str | float | None = None,\n        bias: bool = True,\n        use_sigmoid: bool = True,\n    ) -> None:\n        self.in_channels, *self.in_shape = in_shape\n        self.use_sigmoid = use_sigmoid\n\n        self.latent_size = latent_size\n        self.final_size = np.asarray(self.in_shape, dtype=int)\n\n        super().__init__(\n            spatial_dims,\n            self.in_channels,\n            out_channels,\n            channels,\n            strides,\n            kernel_size,\n            up_kernel_size,\n            num_res_units,\n            inter_channels,\n            inter_dilations,\n            num_inter_units,\n            act,\n            norm,\n            dropout,\n            bias,\n        )\n\n        padding = same_padding(self.kernel_size)\n\n        for s in strides:\n            self.final_size = calculate_out_shape(self.final_size, self.kernel_size, s, padding)  # type: ignore\n\n        linear_size = int(np.prod(self.final_size)) * self.encoded_channels\n        self.mu = nn.Linear(linear_size, self.latent_size)\n        self.logvar = nn.Linear(linear_size, self.latent_size)\n        self.decodeL = nn.Linear(self.latent_size, linear_size)\n\n    def encode_forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        x = self.encode(x)\n        x = self.intermediate(x)\n        x = x.view(x.shape[0], -1)\n        mu = self.mu(x)\n        logvar = self.logvar(x)\n        return mu, logvar\n\n    def decode_forward(self, z: torch.Tensor, use_sigmoid: bool = True) -> torch.Tensor:\n        x = F.relu(self.decodeL(z))\n        x = x.view(x.shape[0], self.channels[-1], *self.final_size)\n        x = self.decode(x)\n        if use_sigmoid:\n            x = torch.sigmoid(x)\n        return x\n\n    def reparameterize(self, mu: torch.Tensor, 
logvar: torch.Tensor) -> torch.Tensor:\n        std = torch.exp(0.5 * logvar)\n\n        if self.training:  # multiply random noise with std only during training\n            std = torch.randn_like(std).mul(std)\n\n        return std.add_(mu)\n\n    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        mu, logvar = self.encode_forward(x)\n        z = self.reparameterize(mu, logvar)\n        return self.decode_forward(z, self.use_sigmoid), mu, logvar, z\n"
  },
  {
    "path": "monai/networks/nets/vista3d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\nfrom collections.abc import Sequence\nfrom typing import Any, Callable\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nimport monai\nfrom monai.networks.blocks import MLPBlock, UnetrBasicBlock\nfrom monai.networks.nets import SegResNetDS2\nfrom monai.transforms.utils import convert_points_to_disc\nfrom monai.transforms.utils import keep_merge_components_with_points as lcc\nfrom monai.transforms.utils import sample_points_from_label\nfrom monai.utils import optional_import, unsqueeze_left, unsqueeze_right\n\nrearrange, _ = optional_import(\"einops\", name=\"rearrange\")\n\n__all__ = [\"VISTA3D\", \"vista3d132\"]\n\n\ndef vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1):\n    \"\"\"\n    Exact VISTA3D network configuration used in https://arxiv.org/abs/2406.05285>`_.\n    The model treats class index larger than 132 as zero-shot.\n\n    Args:\n        encoder_embed_dim: hidden dimension for encoder.\n        in_channels: input channel number.\n    \"\"\"\n    segresnet = SegResNetDS2(\n        in_channels=in_channels,\n        blocks_down=(1, 2, 2, 4, 4),\n        norm=\"instance\",\n        out_channels=encoder_embed_dim,\n        init_filters=encoder_embed_dim,\n        dsdepth=1,\n    )\n    point_head = PointMappingSAM(feature_size=encoder_embed_dim, 
n_classes=512, last_supported=132)\n    class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True)\n    vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head)\n    return vista\n\n\nclass VISTA3D(nn.Module):\n    \"\"\"\n    VISTA3D based on:\n        `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography\n        <https://arxiv.org/abs/2406.05285>`_.\n\n    Args:\n        image_encoder: image encoder backbone for feature extraction.\n        class_head: class head used for class index based segmentation\n        point_head: point head used for interactive segmentation\n    \"\"\"\n\n    def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: nn.Module):\n        super().__init__()\n        self.image_encoder = image_encoder\n        self.class_head = class_head\n        self.point_head = point_head\n        self.image_embeddings = None\n        self.auto_freeze = False\n        self.point_freeze = False\n        self.NINF_VALUE = -9999\n        self.PINF_VALUE = 9999\n\n    def update_slidingwindow_padding(\n        self,\n        pad_size: list | None,\n        labels: torch.Tensor | None,\n        prev_mask: torch.Tensor | None,\n        point_coords: torch.Tensor | None,\n    ):\n        \"\"\"\n        Image has been padded by sliding window inferer.\n        The related padding needs to be performed outside of slidingwindow inferer.\n\n        Args:\n            pad_size: padding size passed from sliding window inferer.\n            labels: image label ground truth.\n            prev_mask: previous segmentation mask.\n            point_coords: point click coordinates.\n        \"\"\"\n        if pad_size is None:\n            return labels, prev_mask, point_coords\n        if labels is not None:\n            labels = F.pad(labels, pad=pad_size, mode=\"constant\", value=0)\n        if prev_mask is not None:\n            prev_mask = 
F.pad(prev_mask, pad=pad_size, mode=\"constant\", value=0)\n        if point_coords is not None:\n            point_coords = point_coords + torch.tensor(\n                [pad_size[-2], pad_size[-4], pad_size[-6]], device=point_coords.device\n            )\n        return labels, prev_mask, point_coords\n\n    def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int:\n        \"\"\"Get number of foreground classes based on class and point prompt.\"\"\"\n        if class_vector is None:\n            if point_coords is None:\n                raise ValueError(\"class_vector and point_coords cannot be both None.\")\n            return point_coords.shape[0]\n        else:\n            return class_vector.shape[0]\n\n    def convert_point_label(\n        self,\n        point_label: torch.Tensor,\n        label_set: Sequence[int] | None = None,\n        special_index: Sequence[int] = (23, 24, 25, 26, 27, 57, 128),\n    ):\n        \"\"\"\n        Convert point label based on its class prompt. For special classes defined in special index,\n        the positive/negative point label will be converted from 1/0 to 3/2. The purpose is to separate those\n        classes with ambiguous classes.\n\n        Args:\n            point_label: the point label tensor, [B, N].\n            label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,\n                this label_set should be global mapped index. If labels are not mapped to global index, e.g. 
in zero-shot\n                evaluation, this label_set should be the original index.\n            special_index: the special class index that needs to be converted.\n        \"\"\"\n        if label_set is None:\n            return point_label\n        if not point_label.shape[0] == len(label_set):\n            raise ValueError(\"point_label and label_set must have the same length.\")\n\n        for i in range(len(label_set)):\n            if label_set[i] in special_index:\n                for j in range(len(point_label[i])):\n                    point_label[i, j] = point_label[i, j] + 2 if point_label[i, j] > -1 else point_label[i, j]\n        return point_label\n\n    def sample_points_patch_val(\n        self,\n        labels: torch.Tensor,\n        patch_coords: Sequence[slice],\n        label_set: Sequence[int],\n        use_center: bool = True,\n        mapped_label_set: Sequence[int] | None = None,\n        max_ppoint: int = 1,\n        max_npoint: int = 0,\n    ):\n        \"\"\"\n        Sample points for patch during sliding window validation. 
Only used for point only validation.\n\n        Args:\n            labels: shape [1, 1, H, W, D].\n            patch_coords: a sequence of sliding window slice objects.\n            label_set: local index, must match values in labels.\n            use_center: sample points from the center.\n            mapped_label_set: global index, it is used to identify special classes and is the global index\n                for the sampled points.\n            max_ppoint/max_npoint: positive points and negative points to sample.\n        \"\"\"\n        point_coords, point_labels = sample_points_from_label(\n            labels[patch_coords],\n            label_set,\n            max_ppoint=max_ppoint,\n            max_npoint=max_npoint,\n            device=labels.device,\n            use_center=use_center,\n        )\n        point_labels = self.convert_point_label(point_labels, mapped_label_set)\n        return (point_coords, point_labels, torch.tensor(label_set).to(point_coords.device).unsqueeze(-1))\n\n    def update_point_to_patch(\n        self, patch_coords: Sequence[slice], point_coords: torch.Tensor, point_labels: torch.Tensor\n    ):\n        \"\"\"\n        Update point_coords with respect to patch coords.\n        If point is outside of the patch, remove the coordinates and set label to -1.\n\n        Args:\n            patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.\n                This value is passed from sliding_window_inferer.\n            point_coords: point coordinates, [B, N, 3].\n            point_labels: point labels, [B, N].\n        \"\"\"\n        patch_ends = [patch_coords[-3].stop, patch_coords[-2].stop, patch_coords[-1].stop]\n        patch_starts = [patch_coords[-3].start, patch_coords[-2].start, patch_coords[-1].start]\n        # update point coords\n        patch_starts_tensor = unsqueeze_left(torch.tensor(patch_starts, device=point_coords.device), 2)\n        
patch_ends_tensor = unsqueeze_left(torch.tensor(patch_ends, device=point_coords.device), 2)\n        # [1 N 1]\n        indices = torch.logical_and(\n            ((point_coords - patch_starts_tensor) > 0).all(2), ((patch_ends_tensor - point_coords) > 0).all(2)\n        )\n        # check if it's within patch coords\n        point_coords = point_coords.clone() - patch_starts_tensor\n        point_labels = point_labels.clone()\n        if indices.any():\n            point_labels[~indices] = -1\n            point_coords[~indices] = 0\n            # also remove padded points, mainly used for inference.\n            not_pad_indices = (point_labels != -1).any(0)\n            point_coords = point_coords[:, not_pad_indices]\n            point_labels = point_labels[:, not_pad_indices]\n            return point_coords, point_labels\n        return None, None\n\n    def connected_components_combine(\n        self,\n        logits: torch.Tensor,\n        point_logits: torch.Tensor,\n        point_coords: torch.Tensor,\n        point_labels: torch.Tensor,\n        mapping_index: torch.Tensor,\n        thred: float = 0.5,\n    ):\n        \"\"\"\n        Combine auto results with point click response. The auto results have shape [B, 1, H, W, D] which means B foreground masks\n        from a single image patch.\n        Out of those B foreground masks, user may add points to a subset of B1 foreground masks for editing.\n        mapping_index represents the correspondence between B and B1.\n        For mapping_index with point clicks, NaN values in logits will be replaced with point_logits. 
Meanwhile, the added/removed\n        region in point clicks must be updated by the lcc function.\n        Notice, if a positive point is within logits/prev_mask, the components containing the positive point will be added.\n\n        Args:\n            logits: automatic branch results, [B, 1, H, W, D].\n            point_logits: point branch results, [B1, 1, H, W, D].\n            point_coords: point coordinates, [B1, N, 3].\n            point_labels: point labels, [B1, N].\n            mapping_index: [B].\n            thred: the threshold to convert logits to binary.\n        \"\"\"\n        logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits\n        _logits = logits[mapping_index]\n        inside = []\n        for i in range(_logits.shape[0]):\n            p_coord = point_coords[i].cpu().numpy().round().astype(int)\n            inside_p = [_logits[i, 0, p[0], p[1], p[2]].item() > 0 for p in p_coord]\n            inside.append(int(np.any(inside_p)))  # convert to int to avoid typing problems with Numpy\n\n        inside_tensor = torch.tensor(inside).to(logits.device)\n        nan_mask = torch.isnan(_logits)\n        # _logits are converted to binary [B1, 1, H, W, D]\n        _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid()\n        pos_region = point_logits.sigmoid() > thred\n        diff_pos = torch.logical_and(torch.logical_or(_logits <= thred, unsqueeze_right(inside_tensor, 5)), pos_region)\n        diff_neg = torch.logical_and((_logits > thred), ~pos_region)\n        cc = lcc(diff_pos, diff_neg, point_coords=point_coords, point_labels=point_labels)\n        # cc is the region that can be updated by point_logits.\n        cc = cc.to(logits.device)\n        # Need to replace NaN with point_logits. 
diff_neg will never lie in nan_mask,\n        # only remove unconnected positive region.\n        uc_pos_region = torch.logical_and(pos_region, ~cc)\n        fill_mask = torch.logical_and(nan_mask, uc_pos_region)\n        if fill_mask.any():\n            # fill in the mean negative value\n            point_logits[fill_mask] = -1\n        # replace logits nan value and cc with point_logits\n        cc = torch.logical_or(nan_mask, cc).to(logits.dtype)\n        logits[mapping_index] *= 1 - cc\n        logits[mapping_index] += cc * point_logits\n        return logits\n\n    def gaussian_combine(\n        self,\n        logits: torch.Tensor,\n        point_logits: torch.Tensor,\n        point_coords: torch.Tensor,\n        point_labels: torch.Tensor,\n        mapping_index: torch.Tensor,\n        radius: int | None = None,\n    ):\n        \"\"\"\n        Combine point results with auto results using gaussian.\n\n        Args:\n            logits: automatic branch results, [B, 1, H, W, D].\n            point_logits: point branch results, [B1, 1, H, W, D].\n            point_coords: point coordinates, [B1, N, 3].\n            point_labels: point labels, [B1, N].\n            mapping_index: [B].\n            radius: gaussian ball radius.\n        \"\"\"\n        if radius is None:\n            radius = min(point_logits.shape[-3:]) // 5  # empirical value 5\n        weight = 1 - convert_points_to_disc(point_logits.shape[-3:], point_coords, point_labels, radius=radius).sum(\n            1, keepdims=True\n        )\n        weight[weight < 0] = 0\n        logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits\n        logits[mapping_index] *= weight\n        logits[mapping_index] += (1 - weight) * point_logits\n        return logits\n\n    def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):\n        \"\"\"\n        Freeze auto-branch or point-branch.\n\n        Args:\n            auto_freeze: whether to freeze the 
auto branch.\n            point_freeze: whether to freeze the point branch.\n        \"\"\"\n        if auto_freeze != self.auto_freeze:\n            if hasattr(self.image_encoder, \"set_auto_grad\"):\n                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)  # type: ignore[operator]\n            else:\n                for param in self.image_encoder.parameters():\n                    param.requires_grad = (not auto_freeze) and (not point_freeze)\n            for param in self.class_head.parameters():\n                param.requires_grad = not auto_freeze\n            self.auto_freeze = auto_freeze\n\n        if point_freeze != self.point_freeze:\n            if hasattr(self.image_encoder, \"set_auto_grad\"):\n                self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)  # type: ignore[operator]\n            else:\n                for param in self.image_encoder.parameters():\n                    param.requires_grad = (not auto_freeze) and (not point_freeze)\n            for param in self.point_head.parameters():\n                param.requires_grad = not point_freeze\n            self.point_freeze = point_freeze\n\n    def forward(\n        self,\n        input_images: torch.Tensor,\n        patch_coords: list[Sequence[slice]] | None = None,\n        point_coords: torch.Tensor | None = None,\n        point_labels: torch.Tensor | None = None,\n        class_vector: torch.Tensor | None = None,\n        prompt_class: torch.Tensor | None = None,\n        labels: torch.Tensor | None = None,\n        label_set: Sequence[int] | None = None,\n        prev_mask: torch.Tensor | None = None,\n        radius: int | None = None,\n        val_point_sampler: Callable | None = None,\n        transpose: bool = False,\n        **kwargs,\n    ):\n        \"\"\"\n        The forward function for VISTA3D. 
We only support single patch in training and inference.\n        One exception is allowing sliding window batch size > 1 for automatic segmentation only case.\n        B represents number of objects, N represents number of points for each object.\n\n        Args:\n            input_images: [1, 1, H, W, D]\n            point_coords: [B, N, 3]\n            point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class.\n                2/3 means negative/positive points for special supported class like tumor.\n            class_vector: [B, 1], the global class index.\n            prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if\n                the points are for zero-shot or supported class. When class_vector and point_coords are both\n                provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]\n                will be considered novel class.\n            patch_coords: a list of sequence of the python slice objects representing the patch coordinates during sliding window\n                inference. This value is passed from sliding_window_inferer.\n                This is an indicator for training phase or validation phase.\n                Notice for sliding window batch size > 1 (only supported by automatic segmentation), patch_coords will include\n                coordinates of multiple patches. If point prompts are included, the batch size can only be one and all the\n                functions using patch_coords will by default use patch_coords[0].\n            labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation\n            label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,\n                this label_set should be global mapped index. If labels are not mapped to global index, e.g. 
in zero-shot\n                evaluation, this label_set should be the original index.\n            prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize].\n                This is the transposed raw output from sliding_window_inferer before any postprocessing.\n                When user click points to perform auto-results correction, this can be the auto-results.\n            radius: single float value controlling the gaussian blur when combining point and auto results.\n                The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes.\n            val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation.\n            transpose: bool. If true, the output will be transposed to be [1, B, H, W, D]. Required to be true if calling from\n                sliding window inferer/point inferer.\n        \"\"\"\n        labels, prev_mask, point_coords = self.update_slidingwindow_padding(\n            kwargs.get(\"pad_size\", None), labels, prev_mask, point_coords\n        )\n        image_size = input_images.shape[-3:]\n        device = input_images.device\n        if point_coords is None and class_vector is None:\n            return self.NINF_VALUE + torch.zeros([1, 1, *image_size], device=device)\n\n        bs = self.get_foreground_class_count(class_vector, point_coords)\n        if patch_coords is not None:\n            # if during validation and perform enable based point-validation.\n            if labels is not None and label_set is not None:\n                # if labels is not None, sample from labels for each patch.\n                if val_point_sampler is None:\n                    # TODO: think about how to refactor this part.\n                    val_point_sampler = self.sample_points_patch_val\n                point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords[0], label_set)\n                if prompt_class[0].item() == 0:  # type: 
ignore\n                    point_labels[0] = -1  # type: ignore\n                labels, prev_mask = None, None\n            elif point_coords is not None:\n                # If not performing patch-based point only validation, use user provided click points for inference.\n                # the point clicks is in original image space, convert it to current patch-coordinate space.\n                point_coords, point_labels = self.update_point_to_patch(patch_coords[0], point_coords, point_labels)  # type: ignore\n\n        if point_coords is not None and point_labels is not None:\n            # remove points that used for padding purposes (point_label = -1)\n            mapping_index = ((point_labels != -1).sum(1) > 0).to(torch.bool)\n            if mapping_index.any():\n                point_coords = point_coords[mapping_index]\n                point_labels = point_labels[mapping_index]\n                if prompt_class is not None:\n                    prompt_class = prompt_class[mapping_index]\n            else:\n                if self.auto_freeze or (class_vector is None and patch_coords is None):\n                    # if auto_freeze, point prompt must exist to allow loss backward\n                    # in training, class_vector and point cannot both be None due to loss.backward()\n                    mapping_index.fill_(True)\n                else:\n                    point_coords, point_labels = None, None\n\n        if point_coords is None and class_vector is None:\n            logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)\n            if transpose:\n                logits = logits.transpose(1, 0)\n            return logits\n\n        if self.image_embeddings is not None and kwargs.get(\"keep_cache\", False) and class_vector is None:\n            out, out_auto = self.image_embeddings, None\n        else:\n            out, out_auto = self.image_encoder(\n                input_images, with_point=point_coords is not None, 
with_label=class_vector is not None\n            )\n        # release memory\n        input_images = None  # type: ignore\n\n        # force releasing memories that set to None\n        torch.cuda.empty_cache()\n        if class_vector is not None:\n            logits, _ = self.class_head(out_auto, class_vector)\n            if point_coords is not None:\n                point_logits = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)\n                if patch_coords is None:\n                    logits = self.gaussian_combine(\n                        logits, point_logits, point_coords, point_labels, mapping_index, radius  # type: ignore\n                    )\n                else:\n                    # during validation use largest component\n                    logits = self.connected_components_combine(\n                        logits, point_logits, point_coords, point_labels, mapping_index  # type: ignore\n                    )\n        else:\n            logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device, dtype=out.dtype)\n            logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)\n            if prev_mask is not None and patch_coords is not None:\n                logits = self.connected_components_combine(\n                    prev_mask[patch_coords[0]].transpose(1, 0).to(logits.device),\n                    logits[mapping_index],\n                    point_coords,  # type: ignore\n                    point_labels,  # type: ignore\n                    mapping_index,\n                )\n        if kwargs.get(\"keep_cache\", False) and class_vector is None:\n            self.image_embeddings = out.detach()\n        if transpose:\n            logits = logits.transpose(1, 0)\n        return logits\n\n\nclass PointMappingSAM(nn.Module):\n    def __init__(self, feature_size: int, max_prompt: int = 32, n_classes: int = 512, last_supported: int = 132):\n        
\"\"\"Interactive point head used for VISTA3D.\n        Adapted from segment anything:\n        `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`.\n\n        Args:\n            feature_size: feature channel from encoder.\n            max_prompt: max prompt number in each forward iteration.\n            n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings.\n            last_supported: number of classes the model support, this value should match the trained model weights.\n        \"\"\"\n        super().__init__()\n        transformer_dim = feature_size\n        self.max_prompt = max_prompt\n        self.feat_downsample = nn.Sequential(\n            nn.Conv3d(in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=2, padding=1),\n            nn.InstanceNorm3d(feature_size),\n            nn.GELU(),\n            nn.Conv3d(in_channels=feature_size, out_channels=transformer_dim, kernel_size=3, stride=1, padding=1),\n            nn.InstanceNorm3d(feature_size),\n        )\n\n        self.mask_downsample = nn.Conv3d(in_channels=2, out_channels=2, kernel_size=3, stride=2, padding=1)\n\n        self.transformer = TwoWayTransformer(depth=2, embedding_dim=transformer_dim, mlp_dim=512, num_heads=4)\n        self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2)\n        self.point_embeddings = nn.ModuleList([nn.Embedding(1, transformer_dim), nn.Embedding(1, transformer_dim)])\n        self.not_a_point_embed = nn.Embedding(1, transformer_dim)\n        self.special_class_embed = nn.Embedding(1, transformer_dim)\n        self.mask_tokens = nn.Embedding(1, transformer_dim)\n\n        self.output_upscaling = nn.Sequential(\n            nn.ConvTranspose3d(transformer_dim, transformer_dim, kernel_size=3, stride=2, padding=1, output_padding=1),\n            nn.InstanceNorm3d(transformer_dim),\n            nn.GELU(),\n            
nn.Conv3d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1),\n        )\n\n        self.output_hypernetworks_mlps = MLP(transformer_dim, transformer_dim, transformer_dim, 3)\n        # class embedding\n        self.n_classes = n_classes\n        self.last_supported = last_supported\n        self.class_embeddings = nn.Embedding(n_classes, feature_size)\n        self.zeroshot_embed = nn.Embedding(1, transformer_dim)\n        self.supported_embed = nn.Embedding(1, transformer_dim)\n\n    def forward(\n        self,\n        out: torch.Tensor,\n        point_coords: torch.Tensor,\n        point_labels: torch.Tensor,\n        class_vector: torch.Tensor | None = None,\n    ):\n        \"\"\"Args:\n        out: feature from encoder, [1, C, H, W, C]\n        point_coords: point coordinates, [B, N, 3]\n        point_labels: point labels, [B, N]\n        class_vector: class prompts, [B]\n        \"\"\"\n        # downsample out\n        out_low = self.feat_downsample(out)\n        out_shape = tuple(out.shape[-3:])\n        # release memory\n        out = None  # type: ignore\n        torch.cuda.empty_cache()\n        # embed points\n        points = point_coords + 0.5  # Shift to center of pixel\n        point_embedding = self.pe_layer.forward_with_coords(points, out_shape)  # type: ignore\n        point_embedding[point_labels == -1] = 0.0\n        point_embedding[point_labels == -1] += self.not_a_point_embed.weight\n        point_embedding[point_labels == 0] += self.point_embeddings[0].weight  # type: ignore[arg-type]\n        point_embedding[point_labels == 1] += self.point_embeddings[1].weight  # type: ignore[arg-type]\n        point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight  # type: ignore[operator]\n        point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight  # type: ignore[operator]\n        output_tokens = self.mask_tokens.weight\n\n        
output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1)\n        if class_vector is None:\n            tokens_all = torch.cat(\n                (\n                    output_tokens,\n                    point_embedding,\n                    self.supported_embed.weight.unsqueeze(0).expand(point_embedding.size(0), -1, -1),\n                ),\n                dim=1,\n            )\n            # tokens_all = torch.cat((output_tokens, point_embedding), dim=1)\n        else:\n            class_embeddings = []\n            for i in class_vector:\n                if i > self.last_supported:\n                    class_embeddings.append(self.zeroshot_embed.weight)\n                else:\n                    class_embeddings.append(self.supported_embed.weight)\n            tokens_all = torch.cat((output_tokens, point_embedding, torch.stack(class_embeddings)), dim=1)\n        # cross attention\n        masks = []\n        max_prompt = self.max_prompt\n        for i in range(int(np.ceil(tokens_all.shape[0] / max_prompt))):\n            # remove variables in previous for loops to save peak memory for self.transformer\n            src, upscaled_embedding, hyper_in = None, None, None\n            torch.cuda.empty_cache()\n            idx = (i * max_prompt, min((i + 1) * max_prompt, tokens_all.shape[0]))\n            tokens = tokens_all[idx[0] : idx[1]]\n            src = torch.repeat_interleave(out_low, tokens.shape[0], dim=0)\n            pos_src = torch.repeat_interleave(self.pe_layer(out_low.shape[-3:]).unsqueeze(0), tokens.shape[0], dim=0)\n            b, c, h, w, d = src.shape\n            hs, src = self.transformer(src, pos_src, tokens)\n            mask_tokens_out = hs[:, :1, :]\n            hyper_in = self.output_hypernetworks_mlps(mask_tokens_out)\n            src = src.transpose(1, 2).view(b, c, h, w, d)  # type: ignore\n            upscaled_embedding = self.output_upscaling(src)\n            b, c, h, w, d = upscaled_embedding.shape\n         
   mask = hyper_in @ upscaled_embedding.view(b, c, h * w * d)\n            masks.append(mask.view(-1, 1, h, w, d))\n\n        return torch.vstack(masks)\n\n\nclass ClassMappingClassify(nn.Module):\n    \"\"\"Class head that performs automatic segmentation based on class vector.\"\"\"\n\n    def __init__(self, n_classes: int, feature_size: int, use_mlp: bool = True):\n        \"\"\"Args:\n        n_classes: maximum number of class embedding.\n        feature_size: class embedding size.\n        use_mlp: use mlp to further map class embedding.\n        \"\"\"\n        super().__init__()\n        self.use_mlp = use_mlp\n        if use_mlp:\n            self.mlp = nn.Sequential(\n                nn.Linear(feature_size, feature_size),\n                nn.InstanceNorm1d(1),\n                nn.GELU(),\n                nn.Linear(feature_size, feature_size),\n            )\n        self.class_embeddings = nn.Embedding(n_classes, feature_size)\n        self.image_post_mapping = nn.Sequential(\n            UnetrBasicBlock(\n                spatial_dims=3,\n                in_channels=feature_size,\n                out_channels=feature_size,\n                kernel_size=3,\n                stride=1,\n                norm_name=\"instance\",\n                res_block=True,\n            ),\n            UnetrBasicBlock(\n                spatial_dims=3,\n                in_channels=feature_size,\n                out_channels=feature_size,\n                kernel_size=3,\n                stride=1,\n                norm_name=\"instance\",\n                res_block=True,\n            ),\n        )\n\n    def forward(self, src: torch.Tensor, class_vector: torch.Tensor):\n        b, c, h, w, d = src.shape\n        src = self.image_post_mapping(src)\n        class_embedding = self.class_embeddings(class_vector)\n        if self.use_mlp:\n            class_embedding = self.mlp(class_embedding)\n        # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch 
dimension.\n        masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d)\n        masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1)\n        return masks_embedding, class_embedding\n\n\nclass TwoWayTransformer(nn.Module):\n    def __init__(\n        self,\n        depth: int,\n        embedding_dim: int,\n        num_heads: int,\n        mlp_dim: int,\n        activation: tuple | str = \"relu\",\n        attention_downsample_rate: int = 2,\n    ) -> None:\n        \"\"\"\n        A transformer decoder that attends to an input image using\n        queries whose positional embedding is supplied.\n        Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.\n\n        Args:\n            depth: number of layers in the transformer.\n            embedding_dim: the channel dimension for the input embeddings.\n            num_heads: the number of heads for multihead attention. Must divide embedding_dim.\n            mlp_dim: the channel dimension internal to the MLP block.\n            activation: the activation to use in the MLP block.\n            attention_downsample_rate: the rate at which to downsample the image before projecting.\n        \"\"\"\n        super().__init__()\n        self.depth = depth\n        self.embedding_dim = embedding_dim\n        self.num_heads = num_heads\n        self.mlp_dim = mlp_dim\n        self.layers = nn.ModuleList()\n\n        for i in range(depth):\n            self.layers.append(\n                TwoWayAttentionBlock(\n                    embedding_dim=embedding_dim,\n                    num_heads=num_heads,\n                    mlp_dim=mlp_dim,\n                    activation=activation,\n                    attention_downsample_rate=attention_downsample_rate,\n                    skip_first_layer_pe=(i == 0),\n                )\n            )\n\n        self.final_attn_token_to_image = Attention(embedding_dim, num_heads, 
downsample_rate=attention_downsample_rate)\n        self.norm_final_attn = nn.LayerNorm(embedding_dim)\n\n    def forward(\n        self, image_embedding: torch.Tensor, image_pe: torch.Tensor, point_embedding: torch.Tensor\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Args:\n            image_embedding: image to attend to. Should be shape\n                B x embedding_dim x h x w for any h and w.\n            image_pe: the positional encoding to add to the image. Must\n                have the same shape as image_embedding.\n            point_embedding: the embedding to add to the query points.\n                Must have shape B x N_points x embedding_dim for any N_points.\n\n        Returns:\n            torch.Tensor: the processed point_embedding.\n            torch.Tensor: the processed image_embedding.\n        \"\"\"\n        # BxCxHxW -> BxHWxC == B x N_image_tokens x C\n        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)\n        image_pe = image_pe.flatten(2).permute(0, 2, 1)\n\n        # Prepare queries\n        queries = point_embedding\n        keys = image_embedding\n\n        # Apply transformer blocks and final layernorm\n        for layer in self.layers:\n            queries, keys = layer(queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe)\n\n        # Apply the final attention layer from the points to the image\n        q = queries + point_embedding\n        k = keys + image_pe\n        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)\n        queries = queries + attn_out\n        queries = self.norm_final_attn(queries)\n\n        return queries, keys\n\n\nclass TwoWayAttentionBlock(nn.Module):\n    def __init__(\n        self,\n        embedding_dim: int,\n        num_heads: int,\n        mlp_dim: int = 2048,\n        activation: tuple | str = \"relu\",\n        attention_downsample_rate: int = 2,\n        skip_first_layer_pe: bool = False,\n    ) -> None:\n        \"\"\"\n    
    A transformer block with four layers: (1) self-attention of sparse\n        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp\n        block on sparse inputs, and (4) cross attention of dense inputs to sparse\n        inputs.\n        Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.\n\n        Args:\n            embedding_dim: the channel dimension of the embeddings.\n            num_heads: the number of heads in the attention layers.\n            mlp_dim: the hidden dimension of the mlp block.\n            activation: the activation of the mlp block.\n            skip_first_layer_pe: skip the PE on the first layer.\n        \"\"\"\n        super().__init__()\n        self.self_attn = Attention(embedding_dim, num_heads)\n        self.norm1 = nn.LayerNorm(embedding_dim)\n\n        self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)\n        self.norm2 = nn.LayerNorm(embedding_dim)\n\n        self.mlp = MLPBlock(hidden_size=embedding_dim, mlp_dim=mlp_dim, act=activation, dropout_mode=\"vista3d\")\n        self.norm3 = nn.LayerNorm(embedding_dim)\n\n        self.norm4 = nn.LayerNorm(embedding_dim)\n        self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)\n\n        self.skip_first_layer_pe = skip_first_layer_pe\n\n    def forward(\n        self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        # Self attention block\n        if self.skip_first_layer_pe:\n            queries = self.self_attn(q=queries, k=queries, v=queries)\n        else:\n            q = queries + query_pe\n            attn_out = self.self_attn(q=q, k=q, v=queries)\n            queries = queries + attn_out\n        queries = self.norm1(queries)\n\n        # Cross attention block, tokens attending 
to image embedding\n        q = queries + query_pe\n        k = keys + key_pe\n        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)\n        queries = queries + attn_out\n        queries = self.norm2(queries)\n\n        # MLP block\n        mlp_out = self.mlp(queries)\n        queries = queries + mlp_out\n        queries = self.norm3(queries)\n\n        # Cross attention block, image embedding attending to tokens\n        q = queries + query_pe\n        k = keys + key_pe\n        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)\n        keys = keys + attn_out\n        keys = self.norm4(keys)\n\n        return queries, keys\n\n\nclass Attention(nn.Module):\n    \"\"\"\n    An attention layer that allows for downscaling the size of the embedding\n    after projection to queries, keys, and values.\n    Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.\n\n    Args:\n        embedding_dim: the channel dimension of the embeddings.\n        num_heads: the number of heads in the attention layers.\n        downsample_rate: the rate at which to downsample the image before projecting.\n    \"\"\"\n\n    def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None:\n        super().__init__()\n        self.embedding_dim = embedding_dim\n        self.internal_dim = embedding_dim // downsample_rate\n        self.num_heads = num_heads\n        if not self.internal_dim % num_heads == 0:\n            raise ValueError(\"num_heads must divide embedding_dim.\")\n\n        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)\n        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)\n        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)\n        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)\n\n    def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor:\n        b, n, c = x.shape\n        x = 
x.reshape(b, n, num_heads, c // num_heads)\n        # B x N_heads x N_tokens x C_per_head\n        return x.transpose(1, 2)\n\n    def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor:\n        b, n_heads, n_tokens, c_per_head = x.shape\n        x = x.transpose(1, 2)\n        # B x N_tokens x C\n        return x.reshape(b, n_tokens, n_heads * c_per_head)\n\n    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:\n        # Input projections\n        q = self.q_proj(q)\n        k = self.k_proj(k)\n        v = self.v_proj(v)\n\n        # Separate into heads\n        q = self._separate_heads(q, self.num_heads)\n        k = self._separate_heads(k, self.num_heads)\n        v = self._separate_heads(v, self.num_heads)\n\n        # Attention\n        _, _, _, c_per_head = q.shape\n        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens\n        attn = attn / math.sqrt(c_per_head)\n        attn = torch.softmax(attn, dim=-1)\n\n        # Get output\n        out = attn @ v\n        out = self._recombine_heads(out)\n        out = self.out_proj(out)\n\n        return out\n\n\nclass PositionEmbeddingRandom(nn.Module):\n    \"\"\"\n    Positional encoding using random spatial frequencies.\n    Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py`.\n\n    Args:\n        num_pos_feats: the number of positional encoding features.\n        scale: the scale of the positional encoding.\n    \"\"\"\n\n    def __init__(self, num_pos_feats: int = 64, scale: float | None = None) -> None:\n        super().__init__()\n        if scale is None or scale <= 0.0:\n            scale = 1.0\n        self.register_buffer(\"positional_encoding_gaussian_matrix\", scale * torch.randn((3, num_pos_feats)))\n\n    def _pe_encoding(self, coords: torch.torch.Tensor) -> torch.torch.Tensor:\n        \"\"\"Positionally encode points that are normalized to [0,1].\"\"\"\n        # 
assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape\n        coords = 2 * coords - 1\n        # [bs=1,N=2,2] @ [2,128]\n        # [bs=1, N=2, 128]\n        coords = coords @ self.positional_encoding_gaussian_matrix  # type: ignore[operator]\n        coords = 2 * np.pi * coords\n        # outputs d_1 x ... x d_n x C shape\n        # [bs=1, N=2, 128+128=256]\n        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)\n\n    def forward(self, size: tuple[int, int, int]) -> torch.torch.Tensor:\n        \"\"\"Generate positional encoding for a grid of the specified size.\"\"\"\n        h, w, d = size\n        device: Any = self.positional_encoding_gaussian_matrix.device\n        grid = torch.ones((h, w, d), device=device, dtype=torch.float32)\n        x_embed = grid.cumsum(dim=0) - 0.5\n        y_embed = grid.cumsum(dim=1) - 0.5\n        z_embed = grid.cumsum(dim=2) - 0.5\n        x_embed = x_embed / h\n        y_embed = y_embed / w\n        z_embed = z_embed / d\n        pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))\n        # C x H x W\n        return pe.permute(3, 0, 1, 2)\n\n    def forward_with_coords(\n        self, coords_input: torch.torch.Tensor, image_size: tuple[int, int, int]\n    ) -> torch.torch.Tensor:\n        \"\"\"Positionally encode points that are not normalized to [0,1].\"\"\"\n        coords = coords_input.clone()\n        coords[:, :, 0] = coords[:, :, 0] / image_size[0]\n        coords[:, :, 1] = coords[:, :, 1] / image_size[1]\n        coords[:, :, 2] = coords[:, :, 2] / image_size[2]\n        # B x N x C\n        return self._pe_encoding(coords.to(torch.float))\n\n\nclass MLP(nn.Module):\n    \"\"\"\n    Multi-layer perceptron. 
This class is only used for `PointMappingSAM`.\n    Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`.\n\n    Args:\n        input_dim: the input dimension.\n        hidden_dim: the hidden dimension.\n        output_dim: the output dimension.\n        num_layers: the number of layers.\n        sigmoid_output: whether to apply a sigmoid activation to the output.\n    \"\"\"\n\n    def __init__(\n        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False\n    ) -> None:\n        super().__init__()\n        self.num_layers = num_layers\n        h = [hidden_dim] * (num_layers - 1)\n        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))\n        self.sigmoid_output = sigmoid_output\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for i, layer in enumerate(self.layers):\n            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)\n        if self.sigmoid_output:\n            x = F.sigmoid(x)\n        return x\n"
  },
  {
    "path": "monai/networks/nets/vit.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.patchembedding import PatchEmbeddingBlock\nfrom monai.networks.blocks.transformerblock import TransformerBlock\n\n__all__ = [\"ViT\"]\n\n\nclass ViT(nn.Module):\n    \"\"\"\n    Vision Transformer (ViT), based on: \"Dosovitskiy et al.,\n    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>\"\n\n    ViT supports Torchscript but only works for Pytorch after 1.8.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        img_size: Sequence[int] | int,\n        patch_size: Sequence[int] | int,\n        hidden_size: int = 768,\n        mlp_dim: int = 3072,\n        num_layers: int = 12,\n        num_heads: int = 12,\n        proj_type: str = \"conv\",\n        pos_embed_type: str = \"learnable\",\n        classification: bool = False,\n        num_classes: int = 2,\n        dropout_rate: float = 0.0,\n        spatial_dims: int = 3,\n        post_activation=\"Tanh\",\n        qkv_bias: bool = False,\n        save_attn: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            in_channels (int): dimension of input channels.\n            img_size (Union[Sequence[int], int]): dimension of input image.\n            patch_size 
(Union[Sequence[int], int]): dimension of patch size.\n            hidden_size (int, optional): dimension of hidden layer. Defaults to 768.\n            mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.\n            num_layers (int, optional): number of transformer blocks. Defaults to 12.\n            num_heads (int, optional): number of attention heads. Defaults to 12.\n            proj_type (str, optional): patch embedding layer type. Defaults to \"conv\".\n            pos_embed_type (str, optional): position embedding type. Defaults to \"learnable\".\n            classification (bool, optional): bool argument to determine if classification is used. Defaults to False.\n            num_classes (int, optional): number of classes if classification is used. Defaults to 2.\n            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.\n            spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.\n            post_activation (str, optional): add a final activation function to the classification head\n                when `classification` is True. Defaults to \"Tanh\" for `nn.Tanh()`.\n                Set to other values to remove this function.\n            qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.\n            save_attn (bool, optional): to make accessible the attention in self attention block. 
Defaults to False.\n\n        Examples::\n\n            # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone\n            >>> net = ViT(in_channels=1, img_size=(96,96,96), proj_type='conv', pos_embed_type='sincos')\n\n            # for 3-channel with image size of (128,128,128), 24 layers and classification backbone\n            >>> net = ViT(in_channels=3, img_size=(128,128,128), proj_type='conv', pos_embed_type='sincos', classification=True)\n\n            # for 3-channel with image size of (224,224), 12 layers and classification backbone\n            >>> net = ViT(in_channels=3, img_size=(224,224), proj_type='conv', pos_embed_type='sincos', classification=True,\n            >>>           spatial_dims=2)\n\n        \"\"\"\n\n        super().__init__()\n\n        if not (0 <= dropout_rate <= 1):\n            raise ValueError(\"dropout_rate should be between 0 and 1.\")\n\n        if hidden_size % num_heads != 0:\n            raise ValueError(\"hidden_size should be divisible by num_heads.\")\n\n        self.classification = classification\n        self.patch_embedding = PatchEmbeddingBlock(\n            in_channels=in_channels,\n            img_size=img_size,\n            patch_size=patch_size,\n            hidden_size=hidden_size,\n            num_heads=num_heads,\n            proj_type=proj_type,\n            pos_embed_type=pos_embed_type,\n            dropout_rate=dropout_rate,\n            spatial_dims=spatial_dims,\n        )\n        self.blocks = nn.ModuleList(\n            [\n                TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)\n                for i in range(num_layers)\n            ]\n        )\n        self.norm = nn.LayerNorm(hidden_size)\n        if self.classification:\n            self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))\n            if post_activation == \"Tanh\":\n                self.classification_head = 
nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh())\n            else:\n                self.classification_head = nn.Linear(hidden_size, num_classes)  # type: ignore\n\n    def forward(self, x):\n        x = self.patch_embedding(x)\n        if hasattr(self, \"cls_token\"):\n            cls_token = self.cls_token.expand(x.shape[0], -1, -1)\n            x = torch.cat((cls_token, x), dim=1)\n        hidden_states_out = []\n        for blk in self.blocks:\n            x = blk(x)\n            hidden_states_out.append(x)\n        x = self.norm(x)\n        if hasattr(self, \"classification_head\"):\n            x = self.classification_head(x[:, 0])\n        return x, hidden_states_out\n"
  },
  {
    "path": "monai/networks/nets/vitautoenc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.patchembedding import PatchEmbeddingBlock\nfrom monai.networks.blocks.transformerblock import TransformerBlock\nfrom monai.networks.layers import Conv\nfrom monai.utils import ensure_tuple_rep, is_sqrt\n\n__all__ = [\"ViTAutoEnc\"]\n\n\nclass ViTAutoEnc(nn.Module):\n    \"\"\"\n    Vision Transformer (ViT), based on: \"Dosovitskiy et al.,\n    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>\"\n\n    Modified to also give same dimension outputs as the input size of the image\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int,\n        img_size: Sequence[int] | int,\n        patch_size: Sequence[int] | int,\n        out_channels: int = 1,\n        deconv_chns: int = 16,\n        hidden_size: int = 768,\n        mlp_dim: int = 3072,\n        num_layers: int = 12,\n        num_heads: int = 12,\n        proj_type: str = \"conv\",\n        dropout_rate: float = 0.0,\n        spatial_dims: int = 3,\n        qkv_bias: bool = False,\n        save_attn: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            in_channels: dimension of input channels or the number of channels for input.\n            img_size: dimension of input 
image.\n            patch_size: dimension of patch size\n            out_channels:  number of output channels. Defaults to 1.\n            deconv_chns: number of channels for the deconvolution layers. Defaults to 16.\n            hidden_size: dimension of hidden layer. Defaults to 768.\n            mlp_dim: dimension of feedforward layer. Defaults to 3072.\n            num_layers:  number of transformer blocks. Defaults to 12.\n            num_heads: number of attention heads. Defaults to 12.\n            proj_type: patch embedding layer type. Defaults to \"conv\".\n            dropout_rate: fraction of the input units to drop. Defaults to 0.0.\n            spatial_dims: number of spatial dimensions. Defaults to 3.\n            qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.\n            save_attn: to make accessible the attention in self attention block. Defaults to False.\n\n        Examples::\n\n            # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone\n            # It will provide an output of same size as that of the input\n            >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), proj_type='conv')\n\n            # for 3-channel with image size of (128,128,128), output will be same size as of input\n            >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), proj_type='conv')\n\n        \"\"\"\n\n        super().__init__()\n        if not is_sqrt(patch_size):\n            raise ValueError(f\"patch_size should be square number, got {patch_size}.\")\n        self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)\n        self.img_size = ensure_tuple_rep(img_size, spatial_dims)\n        self.spatial_dims = spatial_dims\n        for m, p in zip(self.img_size, self.patch_size):\n            if m % p != 0:\n                raise ValueError(f\"patch_size={patch_size} 
should be divisible by img_size={img_size}.\")\n\n        self.patch_embedding = PatchEmbeddingBlock(\n            in_channels=in_channels,\n            img_size=img_size,\n            patch_size=patch_size,\n            hidden_size=hidden_size,\n            num_heads=num_heads,\n            proj_type=proj_type,\n            dropout_rate=dropout_rate,\n            spatial_dims=self.spatial_dims,\n        )\n        self.blocks = nn.ModuleList(\n            [\n                TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)\n                for i in range(num_layers)\n            ]\n        )\n        self.norm = nn.LayerNorm(hidden_size)\n\n        conv_trans = Conv[Conv.CONVTRANS, self.spatial_dims]\n        # self.conv3d_transpose* is to be compatible with existing 3d model weights.\n        up_kernel_size = [int(math.sqrt(i)) for i in self.patch_size]\n        self.conv3d_transpose = conv_trans(hidden_size, deconv_chns, kernel_size=up_kernel_size, stride=up_kernel_size)\n        self.conv3d_transpose_1 = conv_trans(\n            in_channels=deconv_chns, out_channels=out_channels, kernel_size=up_kernel_size, stride=up_kernel_size\n        )\n\n    def forward(self, x):\n        \"\"\"\n        Args:\n            x: input tensor must have isotropic spatial dimensions,\n                such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.\n        \"\"\"\n        spatial_size = x.shape[2:]\n        x = self.patch_embedding(x)\n        hidden_states_out = []\n        for blk in self.blocks:\n            x = blk(x)\n            hidden_states_out.append(x)\n        x = self.norm(x)\n        x = x.transpose(1, 2)\n        d = [s // p for s, p in zip(spatial_size, self.patch_size)]\n        x = torch.reshape(x, [x.shape[0], x.shape[1], *d])\n        x = self.conv3d_transpose(x)\n        x = self.conv3d_transpose_1(x)\n        return x, hidden_states_out\n"
  },
  {
    "path": "monai/networks/nets/vnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks.convolutions import Convolution\nfrom monai.networks.layers.factories import Act, Conv, Dropout, Norm, split_args\nfrom monai.utils import deprecated_arg\n\n__all__ = [\"VNet\"]\n\n\ndef get_acti_layer(act: tuple[str, dict] | str, nchan: int = 0):\n    if act == \"prelu\":\n        act = (\"prelu\", {\"num_parameters\": nchan})\n    act_name, act_args = split_args(act)\n    act_type = Act[act_name]\n    return act_type(**act_args)\n\n\nclass LUConv(nn.Module):\n\n    def __init__(self, spatial_dims: int, nchan: int, act: tuple[str, dict] | str, bias: bool = False):\n        super().__init__()\n\n        self.act_function = get_acti_layer(act, nchan)\n        self.conv_block = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=nchan,\n            out_channels=nchan,\n            kernel_size=5,\n            act=None,\n            norm=Norm.BATCH,\n            bias=bias,\n        )\n\n    def forward(self, x):\n        out = self.conv_block(x)\n        out = self.act_function(out)\n        return out\n\n\ndef _make_nconv(spatial_dims: int, nchan: int, depth: int, act: tuple[str, dict] | str, bias: bool = False):\n    layers = []\n    for _ in range(depth):\n        layers.append(LUConv(spatial_dims, nchan, act, bias))\n    return 
nn.Sequential(*layers)\n\n\nclass InputTransition(nn.Module):\n\n    def __init__(\n        self, spatial_dims: int, in_channels: int, out_channels: int, act: tuple[str, dict] | str, bias: bool = False\n    ):\n        super().__init__()\n\n        if out_channels % in_channels != 0:\n            raise ValueError(\n                f\"out channels should be divisible by in_channels. Got in_channels={in_channels}, out_channels={out_channels}.\"\n            )\n\n        self.spatial_dims = spatial_dims\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.act_function = get_acti_layer(act, out_channels)\n        self.conv_block = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=5,\n            act=None,\n            norm=Norm.BATCH,\n            bias=bias,\n        )\n\n    def forward(self, x):\n        out = self.conv_block(x)\n        repeat_num = self.out_channels // self.in_channels\n        x16 = x.repeat([1, repeat_num, 1, 1, 1][: self.spatial_dims + 2])\n        out = self.act_function(torch.add(out, x16))\n        return out\n\n\nclass DownTransition(nn.Module):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        nconvs: int,\n        act: tuple[str, dict] | str,\n        dropout_prob: float | None = None,\n        dropout_dim: int = 3,\n        bias: bool = False,\n    ):\n        super().__init__()\n\n        conv_type: type[nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]\n        norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]\n        dropout_type: type[nn.Dropout | nn.Dropout2d | nn.Dropout3d] = Dropout[Dropout.DROPOUT, dropout_dim]\n\n        out_channels = 2 * in_channels\n        self.down_conv = conv_type(in_channels, out_channels, kernel_size=2, stride=2, bias=bias)\n        self.bn1 = norm_type(out_channels)\n     
   self.act_function1 = get_acti_layer(act, out_channels)\n        self.act_function2 = get_acti_layer(act, out_channels)\n        self.ops = _make_nconv(spatial_dims, out_channels, nconvs, act, bias)\n        self.dropout = dropout_type(dropout_prob) if dropout_prob is not None else None\n\n    def forward(self, x):\n        down = self.act_function1(self.bn1(self.down_conv(x)))\n        if self.dropout is not None:\n            out = self.dropout(down)\n        else:\n            out = down\n        out = self.ops(out)\n        out = self.act_function2(torch.add(out, down))\n        return out\n\n\nclass UpTransition(nn.Module):\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        nconvs: int,\n        act: tuple[str, dict] | str,\n        dropout_prob: tuple[float | None, float] = (None, 0.5),\n        dropout_dim: int = 3,\n    ):\n        super().__init__()\n\n        conv_trans_type: type[nn.ConvTranspose2d | nn.ConvTranspose3d] = Conv[Conv.CONVTRANS, spatial_dims]\n        norm_type: type[nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]\n        dropout_type: type[nn.Dropout | nn.Dropout2d | nn.Dropout3d] = Dropout[Dropout.DROPOUT, dropout_dim]\n\n        self.up_conv = conv_trans_type(in_channels, out_channels // 2, kernel_size=2, stride=2)\n        self.bn1 = norm_type(out_channels // 2)\n        self.dropout = dropout_type(dropout_prob[0]) if dropout_prob[0] is not None else None\n        self.dropout2 = dropout_type(dropout_prob[1])\n        self.act_function1 = get_acti_layer(act, out_channels // 2)\n        self.act_function2 = get_acti_layer(act, out_channels)\n        self.ops = _make_nconv(spatial_dims, out_channels, nconvs, act)\n\n    def forward(self, x, skipx):\n        if self.dropout is not None:\n            out = self.dropout(x)\n        else:\n            out = x\n        skipxdo = self.dropout2(skipx)\n        out = 
self.act_function1(self.bn1(self.up_conv(out)))\n        xcat = torch.cat((out, skipxdo), 1)\n        out = self.ops(xcat)\n        out = self.act_function2(torch.add(out, xcat))\n        return out\n\n\nclass OutputTransition(nn.Module):\n\n    def __init__(\n        self, spatial_dims: int, in_channels: int, out_channels: int, act: tuple[str, dict] | str, bias: bool = False\n    ):\n        super().__init__()\n\n        conv_type: type[nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]\n\n        self.act_function1 = get_acti_layer(act, out_channels)\n        self.conv_block = Convolution(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=5,\n            act=None,\n            norm=Norm.BATCH,\n            bias=bias,\n        )\n        self.conv2 = conv_type(out_channels, out_channels, kernel_size=1)\n\n    def forward(self, x):\n        # convolve 32 down to 2 channels\n        out = self.conv_block(x)\n        out = self.act_function1(out)\n        out = self.conv2(out)\n        return out\n\n\nclass VNet(nn.Module):\n    \"\"\"\n    V-Net based on `Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation\n    <https://arxiv.org/pdf/1606.04797.pdf>`_.\n    Adapted from `the official Caffe implementation\n    <https://github.com/faustomilletari/VNet>`_. and `another pytorch implementation\n    <https://github.com/mattmacy/vnet.pytorch/blob/master/vnet.py>`_.\n    The model supports 2D or 3D inputs.\n\n    Args:\n        spatial_dims: spatial dimension of the input data. Defaults to 3.\n        in_channels: number of input channels for the network. Defaults to 1.\n            The value should meet the condition that ``16 % in_channels == 0``.\n        out_channels: number of output channels for the network. Defaults to 1.\n        act: activation type in the network. 
Defaults to ``(\"elu\", {\"inplace\": True})``.\n        dropout_prob_down: dropout ratio for DownTransition blocks. Defaults to 0.5.\n        dropout_prob_up: dropout ratio for UpTransition blocks. Defaults to (0.5, 0.5).\n        dropout_dim: determine the dimensions of dropout. Defaults to 3.\n\n            - ``dropout_dim = 1``, randomly zeroes some of the elements for each channel.\n            - ``dropout_dim = 2``, Randomly zeroes out entire channels (a channel is a 2D feature map).\n            - ``dropout_dim = 3``, Randomly zeroes out entire channels (a channel is a 3D feature map).\n        bias: whether to have a bias term in convolution blocks. Defaults to False.\n            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,\n            if a conv layer is directly followed by a batch norm layer, bias should be False.\n\n    .. deprecated:: 1.2\n        ``dropout_prob`` is deprecated in favor of ``dropout_prob_down`` and ``dropout_prob_up``.\n\n    \"\"\"\n\n    @deprecated_arg(\n        name=\"dropout_prob\",\n        since=\"1.2\",\n        new_name=\"dropout_prob_down\",\n        msg_suffix=\"please use `dropout_prob_down` instead.\",\n    )\n    @deprecated_arg(\n        name=\"dropout_prob\", since=\"1.2\", new_name=\"dropout_prob_up\", msg_suffix=\"please use `dropout_prob_up` instead.\"\n    )\n    def __init__(\n        self,\n        spatial_dims: int = 3,\n        in_channels: int = 1,\n        out_channels: int = 1,\n        act: tuple[str, dict] | str = (\"elu\", {\"inplace\": True}),\n        dropout_prob: float | None = 0.5,  # deprecated\n        dropout_prob_down: float | None = 0.5,\n        dropout_prob_up: tuple[float | None, float] = (0.5, 0.5),\n        dropout_dim: int = 3,\n        bias: bool = False,\n    ):\n        super().__init__()\n\n        if spatial_dims not in (2, 3):\n            raise AssertionError(\"spatial_dims can only be 2 or 3.\")\n\n        
self.in_tr = InputTransition(spatial_dims, in_channels, 16, act, bias=bias)\n        self.down_tr32 = DownTransition(spatial_dims, 16, 1, act, bias=bias)\n        self.down_tr64 = DownTransition(spatial_dims, 32, 2, act, bias=bias)\n        self.down_tr128 = DownTransition(spatial_dims, 64, 3, act, dropout_prob=dropout_prob_down, bias=bias)\n        self.down_tr256 = DownTransition(spatial_dims, 128, 2, act, dropout_prob=dropout_prob_down, bias=bias)\n        self.up_tr256 = UpTransition(spatial_dims, 256, 256, 2, act, dropout_prob=dropout_prob_up)\n        self.up_tr128 = UpTransition(spatial_dims, 256, 128, 2, act, dropout_prob=dropout_prob_up)\n        self.up_tr64 = UpTransition(spatial_dims, 128, 64, 1, act)\n        self.up_tr32 = UpTransition(spatial_dims, 64, 32, 1, act)\n        self.out_tr = OutputTransition(spatial_dims, 32, out_channels, act, bias=bias)\n\n    def forward(self, x):\n        out16 = self.in_tr(x)\n        out32 = self.down_tr32(out16)\n        out64 = self.down_tr64(out32)\n        out128 = self.down_tr128(out64)\n        out256 = self.down_tr256(out128)\n        x = self.up_tr256(out256, out128)\n        x = self.up_tr128(x, out64)\n        x = self.up_tr64(x, out32)\n        x = self.up_tr32(x, out16)\n        x = self.out_tr(x)\n        return x\n"
  },
  {
    "path": "monai/networks/nets/voxelmorph.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.networks.blocks.convolutions import Convolution\nfrom monai.networks.blocks.upsample import UpSample\nfrom monai.networks.blocks.warp import DVF2DDF, Warp\nfrom monai.networks.layers.simplelayers import SkipConnection\n\n__all__ = [\"VoxelMorphUNet\", \"voxelmorphunet\", \"VoxelMorph\", \"voxelmorph\"]\n\n\nclass VoxelMorphUNet(nn.Module):\n    \"\"\"\n    The backbone network used in VoxelMorph. See :py:class:`monai.networks.nets.VoxelMorph` for more details.\n\n    A concatenated pair of images (moving and fixed) is first passed through a UNet. The output of the UNet is then\n    passed through a series of convolution blocks to produce the final prediction of the displacement field (DDF) or the\n    stationary velocity field (DVF).\n\n    In the original implementation, downsample is achieved through maxpooling, here one has the option to use either\n    maxpooling or strided convolution for downsampling. The default is to use maxpooling as it is consistent with the\n    original implementation. Note that for upsampling, the authors of VoxelMorph used nearest neighbor interpolation\n    instead of transposed convolution. 
In this implementation, only nearest neighbor interpolation is supported in order\n    to be consistent with the original implementation.\n\n    An instance of this class can be used as a backbone network for constructing a VoxelMorph network. See the\n    documentation of :py:class:`monai.networks.nets.VoxelMorph` for more details and an example on how to construct a\n    VoxelMorph network.\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of channels in the input volume after concatenation of moving and fixed images.\n        unet_out_channels: number of channels in the output of the UNet.\n        channels: number of channels in each layer of the UNet. See the following example for more details.\n        final_conv_channels: number of channels in each layer of the final convolution block.\n        final_conv_act: activation type for the final convolution block. Defaults to LeakyReLU.\n            Since VoxelMorph was originally implemented in tensorflow where the default negative slope for\n            LeakyReLU was 0.2, we use the same default value here.\n        kernel_size: kernel size for all convolution layers in the UNet. Defaults to 3.\n        up_kernel_size: kernel size for all convolution layers in the upsampling path of the UNet. Defaults to 3.\n        act: activation type for all convolution layers in the UNet. Defaults to LeakyReLU with negative slope 0.2.\n        norm: feature normalization type and arguments for all convolution layers in the UNet. Defaults to None.\n        dropout: dropout ratio for all convolution layers in the UNet. Defaults to 0.0 (no dropout).\n        bias: whether to use bias in all convolution layers in the UNet. Defaults to True.\n        use_maxpool: whether to use maxpooling in the downsampling path of the UNet. 
Defaults to True.\n            Using maxpooling is consistent with the original implementation of VoxelMorph.\n            But one can optionally use strided convolution instead (i.e. set `use_maxpool` to False).\n        adn_ordering: ordering of activation, dropout, and normalization. Defaults to \"NDA\".\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        unet_out_channels: int,\n        channels: Sequence[int],\n        final_conv_channels: Sequence[int],\n        final_conv_act: tuple | str | None = \"LEAKYRELU\",\n        kernel_size: Sequence[int] | int = 3,\n        up_kernel_size: Sequence[int] | int = 3,\n        act: tuple | str = \"LEAKYRELU\",\n        norm: tuple | str | None = None,\n        dropout: float = 0.0,\n        bias: bool = True,\n        use_maxpool: bool = True,\n        adn_ordering: str = \"NDA\",\n    ) -> None:\n        super().__init__()\n\n        if spatial_dims not in (2, 3):\n            raise ValueError(\"spatial_dims must be either 2 or 3.\")\n        if in_channels % 2 != 0:\n            raise ValueError(\"in_channels must be divisible by 2.\")\n        if len(channels) < 2:\n            raise ValueError(\"the length of `channels` should be no less than 2.\")\n        if len(channels) % 2 != 0:\n            raise ValueError(\"the elements of `channels` should be specified in pairs.\")\n        if isinstance(kernel_size, Sequence) and len(kernel_size) != spatial_dims:\n            raise ValueError(\"the length of `kernel_size` should equal to `dimensions`.\")\n        if isinstance(up_kernel_size, Sequence) and len(up_kernel_size) != spatial_dims:\n            raise ValueError(\"the length of `up_kernel_size` should equal to `dimensions`.\")\n\n        # UNet args\n        self.dimensions = spatial_dims\n        self.in_channels = in_channels\n        self.unet_out_channels = unet_out_channels\n        self.channels = channels\n        self.kernel_size = 
kernel_size\n        self.up_kernel_size = up_kernel_size\n        self.act = (\n            (\"leakyrelu\", {\"negative_slope\": 0.2, \"inplace\": True})\n            if isinstance(act, str) and act.upper() == \"LEAKYRELU\"\n            else act\n        )\n        self.norm = norm\n        self.dropout = dropout\n        self.bias = bias\n        self.use_maxpool = use_maxpool\n        self.adn_ordering = adn_ordering\n\n        # final convolutions args\n        self.final_conv_channels = final_conv_channels\n        self.final_conv_act = (\n            (\"leakyrelu\", {\"negative_slope\": 0.2, \"inplace\": True})\n            if isinstance(final_conv_act, str) and final_conv_act.upper() == \"LEAKYRELU\"\n            else final_conv_act\n        )\n\n        def _create_block(inc: int, outc: int, channels: Sequence[int], is_top: bool) -> nn.Module:\n            \"\"\"\n            Builds the UNet structure recursively.\n\n            Args:\n                inc: number of input channels.\n                outc: number of output channels.\n                channels: sequence of channels for each pair of down and up layers.\n                is_top: True if this is the top block.\n            \"\"\"\n\n            next_c_in, next_c_out = channels[0:2]\n            upc = next_c_in + next_c_out\n\n            subblock: nn.Module\n\n            if len(channels) > 2:\n                subblock = _create_block(next_c_in, next_c_out, channels[2:], is_top=False)  # continue recursion down\n            else:\n                # the next layer is the bottom so stop recursion, create the bottom layer as the sublock for this layer\n                subblock = self._get_bottom_layer(next_c_in, next_c_out)\n\n            down = self._get_down_layer(inc, next_c_in, is_top)  # create layer in downsampling path\n            up = self._get_up_layer(upc, outc, is_top)  # create layer in upsampling path\n\n            return self._get_connection_block(down, up, subblock)\n\n        def 
_create_final_conv(inc: int, outc: int, channels: Sequence[int]) -> nn.Module:\n            \"\"\"\n            Builds the final convolution blocks.\n\n            Args:\n                inc: number of input channels, should be the same as `unet_out_channels`.\n                outc: number of output channels, should be the same as `spatial_dims`.\n                channels: sequence of channels for each convolution layer.\n\n            Note: there is no activation after the last convolution layer as per the original implementation.\n            \"\"\"\n\n            mod: nn.Module = nn.Sequential()\n\n            for i, c in enumerate(channels):\n                mod.add_module(\n                    f\"final_conv_{i}\",\n                    Convolution(\n                        self.dimensions,\n                        inc,\n                        c,\n                        kernel_size=self.kernel_size,\n                        act=self.final_conv_act,\n                        norm=self.norm,\n                        dropout=self.dropout,\n                        bias=self.bias,\n                        adn_ordering=self.adn_ordering,\n                    ),\n                )\n                inc = c\n\n            mod.add_module(\n                \"final_conv_out\",\n                Convolution(\n                    self.dimensions,\n                    inc,\n                    outc,\n                    kernel_size=self.kernel_size,\n                    act=None,\n                    norm=self.norm,\n                    dropout=self.dropout,\n                    bias=self.bias,\n                    adn_ordering=self.adn_ordering,\n                ),\n            )\n\n            return mod\n\n        self.net = nn.Sequential(\n            _create_block(in_channels, unet_out_channels, self.channels, is_top=True),\n            _create_final_conv(unet_out_channels, self.dimensions, self.final_conv_channels),\n        )\n\n    def _get_connection_block(self, 
down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:\n        \"\"\"\n        Returns the block object defining a layer of the UNet structure including the implementation of the skip\n        between encoding (down) and decoding (up) sides of the network.\n\n        Args:\n            down_path: encoding half of the layer\n            up_path: decoding half of the layer\n            subblock: block defining the next layer in the network.\n\n        Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)`\n        \"\"\"\n\n        return nn.Sequential(down_path, SkipConnection(subblock), up_path)\n\n    def _get_down_layer(self, in_channels: int, out_channels: int, is_top: bool) -> nn.Module:\n        \"\"\"\n        In each down layer, the input is first downsampled using maxpooling,\n        then passed through a convolution block, unless this is the top layer\n        in which case the input is passed through a convolution block only\n        without maxpooling first.\n\n        Args:\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            is_top: True if this is the top block.\n        \"\"\"\n\n        mod: Convolution | nn.Sequential\n\n        strides = 1 if self.use_maxpool or is_top else 2\n\n        mod = Convolution(\n            self.dimensions,\n            in_channels,\n            out_channels,\n            strides=strides,\n            kernel_size=self.kernel_size,\n            act=self.act,\n            norm=self.norm,\n            dropout=self.dropout,\n            bias=self.bias,\n            adn_ordering=self.adn_ordering,\n        )\n\n        if self.use_maxpool and not is_top:\n            mod = (\n                nn.Sequential(nn.MaxPool3d(kernel_size=2, stride=2), mod)\n                if self.dimensions == 3\n                else nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), mod)\n            )\n\n        return 
mod\n\n    def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module:\n        \"\"\"\n        Bottom layer (bottleneck) in voxelmorph consists of a typical down layer followed by an upsample layer.\n\n        Args:\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n        \"\"\"\n\n        mod: nn.Module\n        upsample: nn.Module\n\n        mod = self._get_down_layer(in_channels, out_channels, is_top=False)\n\n        upsample = UpSample(\n            self.dimensions,\n            out_channels,\n            out_channels,\n            scale_factor=2,\n            mode=\"nontrainable\",\n            interp_mode=\"nearest\",\n            align_corners=None,  # required to use with interp_mode=\"nearest\"\n        )\n\n        return nn.Sequential(mod, upsample)\n\n    def _get_up_layer(self, in_channels: int, out_channels: int, is_top: bool) -> nn.Module:\n        \"\"\"\n        In each up layer, the input is passed through a convolution block before upsampled,\n        unless this is the top layer in which case the input is passed through a convolution block only\n        without upsampling.\n\n        Args:\n            in_channels: number of input channels.\n            out_channels: number of output channels.\n            is_top: True if this is the top block.\n        \"\"\"\n\n        mod: Convolution | nn.Sequential\n\n        strides = 1\n\n        mod = Convolution(\n            self.dimensions,\n            in_channels,\n            out_channels,\n            strides=strides,\n            kernel_size=self.up_kernel_size,\n            act=self.act,\n            norm=self.norm,\n            dropout=self.dropout,\n            bias=self.bias,\n            # conv_only=is_top,\n            is_transposed=False,\n            adn_ordering=self.adn_ordering,\n        )\n\n        if not is_top:\n            mod = nn.Sequential(\n                mod,\n                UpSample(\n           
         self.dimensions,\n                    out_channels,\n                    out_channels,\n                    scale_factor=2,\n                    mode=\"nontrainable\",\n                    interp_mode=\"nearest\",\n                    align_corners=None,  # required to use with interp_mode=\"nearest\"\n                ),\n            )\n\n        return mod\n\n    def forward(self, concatenated_pairs: torch.Tensor) -> torch.Tensor:\n        x = self.net(concatenated_pairs)\n        return x  # type: ignore\n\n\nvoxelmorphunet = VoxelMorphUNet\n\n\nclass VoxelMorph(nn.Module):\n    \"\"\"\n    A re-implementation of VoxelMorph framework for medical image registration as described in\n    https://arxiv.org/pdf/1809.05231.pdf. For more details, please refer to VoxelMorph: A Learning Framework for\n    Deformable Medical Image Registration, Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca\n    IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231.\n\n    This class is intended to be a general framework, based on which a deformable image registration\n    network can be built. Given a user-specified backbone network (e.g., UNet in the original VoxelMorph paper), this\n    class serves as a wrapper that concatenates the input pair of moving and fixed images, passes through the backbone\n    network, integrate the predicted stationary velocity field (DVF) from the backbone network to obtain the\n    displacement field (DDF), and, finally, warp the moving image using the DDF.\n\n    To construct a VoxelMorph network, one need to first construct a backbone network\n    (e.g., a :py:class:`monai.networks.nets.VoxelMorphUNet`) and pass it to the constructor of\n    :py:class:`monai.networks.nets.VoxelMorph`. 
The backbone network should be able to take a pair of moving and fixed\n    images as input and produce a DVF (or DDF, details to be discussed later) as output.\n\n    When `forward` is called, the input moving and fixed images are first concatenated along the channel dimension and\n    passed through the specified backbone network to produce the prediction of the displacement field (DDF) in the\n    non-diffeomorphic variant (i.e. when `integration_steps` is set to 0) or the stationary velocity field (DVF) in the\n    diffeomorphic variant (i.e. when `integration_steps` is set to a positive integer). The DVF is then integrated using\n    a scaling-and-squaring approach via a :py:class:`monai.networks.blocks.warp.DVF2DDF` module to produce the DDF.\n    Finally, the DDF is used to warp the moving image to the fixed image using a\n    :py:class:`monai.networks.blocks.warp.Warp` module. Optionally, the integration from DVF to DDF can be\n    performed on reduced resolution by specifying `half_res` to be True, in which case the output DVF from the backbone\n    network is first linearly interpolated to half resolution before integration. The output DDF is then linearly\n    interpolated again back to full resolution before being used to warp the moving image.\n\n    Args:\n        backbone: a backbone network.\n        integration_steps: number of integration steps used for obtaining DDF from DVF via scaling-and-squaring.\n            Defaults to 7. If set to 0, the network will be non-diffeomorphic.\n        half_res: whether to perform integration on half resolution. Defaults to False.\n        spatial_dims: number of spatial dimensions, defaults to 3.\n\n    Example::\n\n        from monai.networks.nets import VoxelMorphUNet, VoxelMorph\n\n        # The following example construct an instance of VoxelMorph that matches the original VoxelMorph paper\n        # https://arxiv.org/pdf/1809.05231.pdf\n\n        # First, a backbone network is constructed. 
In this case, we use a VoxelMorphUNet as the backbone network.\n        backbone = VoxelMorphUNet(\n            spatial_dims=3,\n            in_channels=2,\n            unet_out_channels=32,\n            channels=(16, 32, 32, 32, 32, 32),  # this indicates the down block at the top takes 16 channels as\n                                                # input, the corresponding up block at the top produces 32\n                                                # channels as output, the second down block takes 32 channels as\n                                                # input, and the corresponding up block at the same level\n                                                # produces 32 channels as output, etc.\n            final_conv_channels=(16, 16)\n        )\n\n        # Then, a full VoxelMorph network is constructed using the specified backbone network.\n        net = VoxelMorph(\n            backbone=backbone,\n            integration_steps=7,\n            half_res=False\n        )\n\n        # A forward pass through the network would look something like this\n        moving = torch.randn(1, 1, 160, 192, 224)\n        fixed = torch.randn(1, 1, 160, 192, 224)\n        warped, ddf = net(moving, fixed)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        backbone: VoxelMorphUNet | nn.Module | None = None,\n        integration_steps: int = 7,\n        half_res: bool = False,\n        spatial_dims: int = 3,\n    ) -> None:\n        super().__init__()\n\n        # specified backbone network\n        self.backbone = (\n            backbone\n            if backbone is not None\n            else VoxelMorphUNet(\n                spatial_dims=spatial_dims,\n                in_channels=2,\n                unet_out_channels=32,\n                channels=(16, 32, 32, 32, 32, 32),\n                final_conv_channels=(16, 16),\n            )\n        )\n\n        # helper attributes\n        self.spatial_dims = spatial_dims\n        self.half_res = half_res\n        
self.integration_steps = integration_steps\n        self.diffeomorphic = True if self.integration_steps > 0 else False\n\n        # create helpers\n        if self.diffeomorphic:\n            self.dvf2ddf = DVF2DDF(num_steps=self.integration_steps, mode=\"bilinear\", padding_mode=\"zeros\")\n        self.warp = Warp(mode=\"bilinear\", padding_mode=\"zeros\")\n\n    def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        if moving.shape != fixed.shape:\n            raise ValueError(\n                \"The spatial shape of the moving image should be the same as the spatial shape of the fixed image.\"\n                f\" Got {moving.shape} and {fixed.shape} instead.\"\n            )\n\n        x = self.backbone(torch.cat([moving, fixed], dim=1))\n\n        if x.shape[1] != self.spatial_dims:\n            raise ValueError(\n                \"The number of channels in the output of the backbone network should be equal to the\"\n                f\" number of spatial dimensions {self.spatial_dims}. Got {x.shape[1]} channels instead.\"\n            )\n\n        if x.shape[2:] != moving.shape[2:]:\n            raise ValueError(\n                \"The spatial shape of the output of the backbone network should be equal to the\"\n                f\" spatial shape of the input images. Got {x.shape[2:]} instead of {moving.shape[2:]}.\"\n            )\n\n        if self.half_res:\n            x = F.interpolate(x, scale_factor=0.5, mode=\"trilinear\", align_corners=True) * 2.0\n\n        if self.diffeomorphic:\n            x = self.dvf2ddf(x)\n\n        if self.half_res:\n            x = F.interpolate(x * 0.5, scale_factor=2.0, mode=\"trilinear\", align_corners=True)\n\n        return self.warp(moving, x), x\n\n\nvoxelmorph = VoxelMorph\n"
  },
  {
    "path": "monai/networks/nets/vqvae.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.networks.blocks import Convolution\nfrom monai.networks.layers import Act\nfrom monai.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer\nfrom monai.utils import ensure_tuple_rep\n\n__all__ = [\"VQVAE\"]\n\n\nclass VQVAEResidualUnit(nn.Module):\n    \"\"\"\n    Implementation of the ResidualLayer used in the VQVAE network as originally used in Morphology-preserving\n    Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf).\n\n    The original implementation that can be found at\n    https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L150.\n\n    Args:\n        spatial_dims: number of spatial spatial_dims of the input data.\n        in_channels: number of input channels.\n        num_res_channels: number of channels in the residual layers.\n        act: activation type and arguments. Defaults to RELU.\n        dropout: dropout ratio. Defaults to no dropout.\n        bias: whether to have a bias term. 
Defaults to True.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        num_res_channels: int,\n        act: tuple | str | None = Act.RELU,\n        dropout: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n\n        self.spatial_dims = spatial_dims\n        self.in_channels = in_channels\n        self.num_res_channels = num_res_channels\n        self.act = act\n        self.dropout = dropout\n        self.bias = bias\n\n        self.conv1 = Convolution(\n            spatial_dims=self.spatial_dims,\n            in_channels=self.in_channels,\n            out_channels=self.num_res_channels,\n            adn_ordering=\"DA\",\n            act=self.act,\n            dropout=self.dropout,\n            bias=self.bias,\n        )\n\n        self.conv2 = Convolution(\n            spatial_dims=self.spatial_dims,\n            in_channels=self.num_res_channels,\n            out_channels=self.in_channels,\n            bias=self.bias,\n            conv_only=True,\n        )\n\n    def forward(self, x):\n        return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True)\n\n\nclass Encoder(nn.Module):\n    \"\"\"\n    Encoder module for VQ-VAE.\n\n    Args:\n        spatial_dims: number of spatial spatial_dims.\n        in_channels: number of input channels.\n        out_channels: number of channels in the latent space (embedding_dim).\n        channels: sequence containing the number of channels at each level of the encoder.\n        num_res_layers: number of sequential residual layers at each level.\n        num_res_channels: number of channels in the residual layers at each level.\n        downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. 
Each Tuple should hold the\n            following information stride (int), kernel_size (int), dilation (int) and padding (int).\n        dropout: dropout ratio.\n        act: activation type and arguments.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        channels: Sequence[int],\n        num_res_layers: int,\n        num_res_channels: Sequence[int],\n        downsample_parameters: Sequence[tuple[int, int, int, int]],\n        dropout: float,\n        act: tuple | str | None,\n    ) -> None:\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.channels = channels\n        self.num_res_layers = num_res_layers\n        self.num_res_channels = num_res_channels\n        self.downsample_parameters = downsample_parameters\n        self.dropout = dropout\n        self.act = act\n\n        blocks: list[nn.Module] = []\n\n        for i in range(len(self.channels)):\n            blocks.append(\n                Convolution(\n                    spatial_dims=self.spatial_dims,\n                    in_channels=self.in_channels if i == 0 else self.channels[i - 1],\n                    out_channels=self.channels[i],\n                    strides=self.downsample_parameters[i][0],\n                    kernel_size=self.downsample_parameters[i][1],\n                    adn_ordering=\"DA\",\n                    act=self.act,\n                    dropout=None if i == 0 else self.dropout,\n                    dropout_dim=1,\n                    dilation=self.downsample_parameters[i][2],\n                    padding=self.downsample_parameters[i][3],\n                )\n            )\n\n            for _ in range(self.num_res_layers):\n                blocks.append(\n                    VQVAEResidualUnit(\n                        spatial_dims=self.spatial_dims,\n                        
in_channels=self.channels[i],\n                        num_res_channels=self.num_res_channels[i],\n                        act=self.act,\n                        dropout=self.dropout,\n                    )\n                )\n\n        blocks.append(\n            Convolution(\n                spatial_dims=self.spatial_dims,\n                in_channels=self.channels[len(self.channels) - 1],\n                out_channels=self.out_channels,\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n            )\n        )\n\n        self.blocks = nn.ModuleList(blocks)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for block in self.blocks:\n            x = block(x)\n        return x\n\n\nclass Decoder(nn.Module):\n    \"\"\"\n    Decoder module for VQ-VAE.\n\n    Args:\n        spatial_dims: number of spatial spatial_dims.\n        in_channels: number of channels in the latent space (embedding_dim).\n        out_channels: number of output channels.\n        channels: sequence containing the number of channels at each level of the decoder.\n        num_res_layers: number of sequential residual layers at each level.\n        num_res_channels: number of channels in the residual layers at each level.\n        upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. 
Each Tuple should hold the\n            following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).\n        dropout: dropout ratio.\n        act: activation type and arguments.\n        output_act: activation type and arguments for the output.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        channels: Sequence[int],\n        num_res_layers: int,\n        num_res_channels: Sequence[int],\n        upsample_parameters: Sequence[tuple[int, int, int, int, int]],\n        dropout: float,\n        act: tuple | str | None,\n        output_act: tuple | str | None,\n    ) -> None:\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.channels = channels\n        self.num_res_layers = num_res_layers\n        self.num_res_channels = num_res_channels\n        self.upsample_parameters = upsample_parameters\n        self.dropout = dropout\n        self.act = act\n        self.output_act = output_act\n\n        reversed_num_channels = list(reversed(self.channels))\n\n        blocks: list[nn.Module] = []\n        blocks.append(\n            Convolution(\n                spatial_dims=self.spatial_dims,\n                in_channels=self.in_channels,\n                out_channels=reversed_num_channels[0],\n                strides=1,\n                kernel_size=3,\n                padding=1,\n                conv_only=True,\n            )\n        )\n\n        reversed_num_res_channels = list(reversed(self.num_res_channels))\n        for i in range(len(self.channels)):\n            for _ in range(self.num_res_layers):\n                blocks.append(\n                    VQVAEResidualUnit(\n                        spatial_dims=self.spatial_dims,\n                        in_channels=reversed_num_channels[i],\n                        
num_res_channels=reversed_num_res_channels[i],\n                        act=self.act,\n                        dropout=self.dropout,\n                    )\n                )\n\n            blocks.append(\n                Convolution(\n                    spatial_dims=self.spatial_dims,\n                    in_channels=reversed_num_channels[i],\n                    out_channels=self.out_channels if i == len(self.channels) - 1 else reversed_num_channels[i + 1],\n                    strides=self.upsample_parameters[i][0],\n                    kernel_size=self.upsample_parameters[i][1],\n                    adn_ordering=\"DA\",\n                    act=self.act,\n                    dropout=self.dropout if i != len(self.channels) - 1 else None,\n                    norm=None,\n                    dilation=self.upsample_parameters[i][2],\n                    conv_only=i == len(self.channels) - 1,\n                    is_transposed=True,\n                    padding=self.upsample_parameters[i][3],\n                    output_padding=self.upsample_parameters[i][4],\n                )\n            )\n\n        if self.output_act:\n            blocks.append(Act[self.output_act]())\n\n        self.blocks = nn.ModuleList(blocks)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        for block in self.blocks:\n            x = block(x)\n        return x\n\n\nclass VQVAE(nn.Module):\n    \"\"\"\n    Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative\n    Modelling of the Brain by Tudosiu et al. 
(https://arxiv.org/pdf/2209.03177.pdf)\n\n    The original implementation can be found at\n    https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/\n\n    Args:\n        spatial_dims: number of spatial dimensions.\n        in_channels: number of input channels.\n        out_channels: number of output channels.\n        downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the\n            following information stride (int), kernel_size (int), dilation (int) and padding (int).\n        upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the\n            following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).\n        num_res_layers: number of sequential residual layers at each level.\n        channels: number of channels at each level.\n        num_res_channels: number of channels in the residual layers at each level.\n        num_embeddings: VectorQuantization number of atomic elements in the codebook.\n        embedding_dim: VectorQuantization number of channels of the input and atomic elements.\n        commitment_cost: VectorQuantization commitment_cost.\n        decay: VectorQuantization decay.\n        epsilon: VectorQuantization epsilon.\n        act: activation type and arguments.\n        dropout: dropout ratio.\n        output_act: activation type and arguments for the output.\n        ddp_sync: whether to synchronize the codebook across processes.\n        use_checkpointing: if True, use activation checkpointing to save memory.\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_dims: int,\n        in_channels: int,\n        out_channels: int,\n        channels: Sequence[int] = (96, 96, 192),\n        num_res_layers: int = 3,\n        num_res_channels: Sequence[int] | int = (96, 96, 192),\n        downsample_parameters: Sequence[tuple[int, int, int, 
int]] | tuple[int, int, int, int] = (\n            (2, 4, 1, 1),\n            (2, 4, 1, 1),\n            (2, 4, 1, 1),\n        ),\n        upsample_parameters: Sequence[tuple[int, int, int, int, int]] | tuple[int, int, int, int, int] = (\n            (2, 4, 1, 1, 0),\n            (2, 4, 1, 1, 0),\n            (2, 4, 1, 1, 0),\n        ),\n        num_embeddings: int = 32,\n        embedding_dim: int = 64,\n        embedding_init: str = \"normal\",\n        commitment_cost: float = 0.25,\n        decay: float = 0.5,\n        epsilon: float = 1e-5,\n        dropout: float = 0.0,\n        act: tuple | str | None = Act.RELU,\n        output_act: tuple | str | None = None,\n        ddp_sync: bool = True,\n        use_checkpointing: bool = False,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.spatial_dims = spatial_dims\n        self.channels = channels\n        self.num_embeddings = num_embeddings\n        self.embedding_dim = embedding_dim\n        self.use_checkpointing = use_checkpointing\n\n        if isinstance(num_res_channels, int):\n            num_res_channels = ensure_tuple_rep(num_res_channels, len(channels))\n\n        if len(num_res_channels) != len(channels):\n            raise ValueError(\n                \"`num_res_channels` should be a single integer or a tuple of integers with the same length as \"\n                \"`num_channls`.\"\n            )\n        if all(isinstance(values, int) for values in upsample_parameters):\n            upsample_parameters_tuple: Sequence = (upsample_parameters,) * len(channels)\n        else:\n            upsample_parameters_tuple = upsample_parameters\n\n        if all(isinstance(values, int) for values in downsample_parameters):\n            downsample_parameters_tuple: Sequence = (downsample_parameters,) * len(channels)\n        else:\n            downsample_parameters_tuple = downsample_parameters\n\n        if not 
all(all(isinstance(value, int) for value in sub_item) for sub_item in downsample_parameters_tuple):\n            raise ValueError(\"`downsample_parameters` should be a single tuple of integer or a tuple of tuples.\")\n\n        # check if upsample_parameters is a tuple of ints or a tuple of tuples of ints\n        if not all(all(isinstance(value, int) for value in sub_item) for sub_item in upsample_parameters_tuple):\n            raise ValueError(\"`upsample_parameters` should be a single tuple of integer or a tuple of tuples.\")\n\n        for parameter in downsample_parameters_tuple:\n            if len(parameter) != 4:\n                raise ValueError(\"`downsample_parameters` should be a tuple of tuples with 4 integers.\")\n\n        for parameter in upsample_parameters_tuple:\n            if len(parameter) != 5:\n                raise ValueError(\"`upsample_parameters` should be a tuple of tuples with 5 integers.\")\n\n        if len(downsample_parameters_tuple) != len(channels):\n            raise ValueError(\n                \"`downsample_parameters` should be a tuple of tuples with the same length as `num_channels`.\"\n            )\n\n        if len(upsample_parameters_tuple) != len(channels):\n            raise ValueError(\n                \"`upsample_parameters` should be a tuple of tuples with the same length as `num_channels`.\"\n            )\n\n        self.num_res_layers = num_res_layers\n        self.num_res_channels = num_res_channels\n\n        self.encoder = Encoder(\n            spatial_dims=spatial_dims,\n            in_channels=in_channels,\n            out_channels=embedding_dim,\n            channels=channels,\n            num_res_layers=num_res_layers,\n            num_res_channels=num_res_channels,\n            downsample_parameters=downsample_parameters_tuple,\n            dropout=dropout,\n            act=act,\n        )\n\n        self.decoder = Decoder(\n            spatial_dims=spatial_dims,\n            
in_channels=embedding_dim,\n            out_channels=out_channels,\n            channels=channels,\n            num_res_layers=num_res_layers,\n            num_res_channels=num_res_channels,\n            upsample_parameters=upsample_parameters_tuple,\n            dropout=dropout,\n            act=act,\n            output_act=output_act,\n        )\n\n        self.quantizer = VectorQuantizer(\n            quantizer=EMAQuantizer(\n                spatial_dims=spatial_dims,\n                num_embeddings=num_embeddings,\n                embedding_dim=embedding_dim,\n                commitment_cost=commitment_cost,\n                decay=decay,\n                epsilon=epsilon,\n                embedding_init=embedding_init,\n                ddp_sync=ddp_sync,\n            )\n        )\n\n    def encode(self, images: torch.Tensor) -> torch.Tensor:\n        output: torch.Tensor\n        if self.use_checkpointing:\n            output = torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False)\n        else:\n            output = self.encoder(images)\n        return output\n\n    def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        x_loss, x = self.quantizer(encodings)\n        return x, x_loss\n\n    def decode(self, quantizations: torch.Tensor) -> torch.Tensor:\n        output: torch.Tensor\n\n        if self.use_checkpointing:\n            output = torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False)\n        else:\n            output = self.decoder(quantizations)\n        return output\n\n    def index_quantize(self, images: torch.Tensor) -> torch.Tensor:\n        return self.quantizer.quantize(self.encode(images=images))\n\n    def decode_samples(self, embedding_indices: torch.Tensor) -> torch.Tensor:\n        return self.decode(self.quantizer.embed(embedding_indices))\n\n    def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:\n        quantizations, 
quantization_losses = self.quantize(self.encode(images))\n        reconstruction = self.decode(quantizations)\n\n        return reconstruction, quantization_losses\n\n    def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:\n        z = self.encode(x)\n        e, _ = self.quantize(z)\n        return e\n\n    def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor:\n        e, _ = self.quantize(z)\n        image = self.decode(e)\n        return image\n"
  },
  {
    "path": "monai/networks/schedulers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .ddim import DDIMScheduler\nfrom .ddpm import DDPMScheduler\nfrom .pndm import PNDMScheduler\nfrom .rectified_flow import RFlowScheduler\nfrom .scheduler import NoiseSchedules, Scheduler\n"
  },
  {
    "path": "monai/networks/schedulers/ddim.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# =========================================================================\n# Adapted from https://github.com/huggingface/diffusers\n# which has the following license:\n# https://github.com/huggingface/diffusers/blob/main/LICENSE\n#\n# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =========================================================================\n\nfrom __future__ import annotations\n\nimport numpy as np\nimport torch\n\nfrom .ddpm import DDPMPredictionType\nfrom .scheduler import Scheduler\n\nDDIMPredictionType = DDPMPredictionType\n\n\nclass DDIMScheduler(Scheduler):\n    \"\"\"\n    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising\n    diffusion probabilistic models (DDPMs) with non-Markovian guidance. 
Based on: Song et al. \"Denoising Diffusion\n    Implicit Models\" https://arxiv.org/abs/2010.02502\n\n    Args:\n        num_train_timesteps: number of diffusion steps used to train the model.\n        schedule: member of NoiseSchedules, name of noise schedule function in component store\n        clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.\n        set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.\n            For the final step there is no previous alpha. When this option is `True` the previous alpha product is\n            fixed to `1`, otherwise it uses the value of alpha at step 0.\n        steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and\n            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in\n            stable diffusion.\n        prediction_type: member of DDPMPredictionType\n        clip_sample_min: minimum clipping value when clip_sample equals True\n        clip_sample_max: maximum clipping value when clip_sample equals True\n        schedule_args: arguments to pass to the schedule function\n\n    \"\"\"\n\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        schedule: str = \"linear_beta\",\n        clip_sample: bool = True,\n        set_alpha_to_one: bool = True,\n        steps_offset: int = 0,\n        prediction_type: str = DDIMPredictionType.EPSILON,\n        clip_sample_min: float = -1.0,\n        clip_sample_max: float = 1.0,\n        **schedule_args,\n    ) -> None:\n        super().__init__(num_train_timesteps, schedule, **schedule_args)\n\n        if prediction_type not in DDIMPredictionType.__members__.values():\n            raise ValueError(\"Argument `prediction_type` must be a member of DDIMPredictionType\")\n\n        self.prediction_type = prediction_type\n\n        # At every step in ddim, we 
are looking into the previous alphas_cumprod\n        # For the final step, there is no previous alphas_cumprod because we are already at 0\n        # `set_alpha_to_one` decides whether we set this parameter simply to one or\n        # whether we use the final alpha of the \"non-previous\" one.\n        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]\n\n        # standard deviation of the initial noise distribution\n        self.init_noise_sigma = 1.0\n\n        self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))\n\n        self.clip_sample = clip_sample\n        self.clip_sample_values = [clip_sample_min, clip_sample_max]\n        self.steps_offset = steps_offset\n\n        # default the number of inference timesteps to the number of train steps\n        self.num_inference_steps: int\n        self.set_timesteps(self.num_train_timesteps)\n\n    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain. 
Supporting function to be run before inference.\n\n        Args:\n            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.\n            device: target device to put the data.\n        \"\"\"\n        if num_inference_steps > self.num_train_timesteps:\n            raise ValueError(\n                f\"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:\"\n                f\" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle\"\n                f\" maximal {self.num_train_timesteps} timesteps.\"\n            )\n\n        self.num_inference_steps = num_inference_steps\n        if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps:\n            raise ValueError(f\"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).\")\n\n        self.timesteps = (\n            torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device)\n            .round()\n            .long()\n        )\n        self.timesteps += self.steps_offset\n\n    def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod\n        beta_prod_t = 1 - alpha_prod_t\n        beta_prod_t_prev = 1 - alpha_prod_t_prev\n\n        variance: torch.Tensor = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)\n\n        return variance\n\n    def step(\n        self,\n        model_output: torch.Tensor,\n        timestep: int,\n        sample: torch.Tensor,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Predict the sample at the previous timestep by reversing the SDE. 
Core function to propagate the diffusion\n        process from the learned model outputs (most often the predicted noise).\n\n        Args:\n            model_output: direct output from learned diffusion model.\n            timestep: current discrete timestep in the diffusion chain.\n            sample: current instance of sample being created by diffusion process.\n            eta: weight of noise for added noise in diffusion step.\n            generator: random number generator.\n\n        Returns:\n            pred_prev_sample: Predicted previous sample\n            pred_original_sample: Predicted original sample\n        \"\"\"\n        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf\n        # Ideally, read DDIM paper in-detail understanding\n\n        # Notation (<variable name> -> <name in paper>\n        # - model_output -> e_theta(x_t, t)\n        # - pred_original_sample -> f_theta(x_t, t) or x_0\n        # - std_dev_t -> sigma_t\n        # - eta -> η\n        # - pred_sample_direction -> \"direction pointing to x_t\"\n        # - pred_prev_sample -> \"x_t-1\"\n\n        # 1. get previous step value (=t-1)\n        prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps\n\n        # 2. compute alphas, betas\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod\n\n        beta_prod_t = 1 - alpha_prod_t\n\n        # predefinitions satisfy pylint/mypy, these values won't be ultimately used\n        pred_original_sample = sample\n        pred_epsilon = model_output\n\n        # 3. 
compute predicted original sample from predicted noise also called\n        # \"predicted x_0\" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf\n        if self.prediction_type == DDIMPredictionType.EPSILON:\n            pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5)\n            pred_epsilon = model_output\n        elif self.prediction_type == DDIMPredictionType.SAMPLE:\n            pred_original_sample = model_output\n            pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5)\n        elif self.prediction_type == DDIMPredictionType.V_PREDICTION:\n            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output\n            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample\n\n        # 4. Clip \"predicted x_0\"\n        if self.clip_sample:\n            pred_original_sample = torch.clamp(\n                pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]\n            )\n\n        # 5. compute variance: \"sigma_t(η)\" -> see formula (16)\n        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)\n        variance = self._get_variance(timestep, prev_timestep)\n        std_dev_t = eta * variance**0.5\n\n        # 6. compute \"direction pointing to x_t\" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf\n        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon\n\n        # 7. 
compute x_t-1 without \"random noise\" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf\n        pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction\n\n        if eta > 0:\n            # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072\n            device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else \"cpu\")\n            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator, device=device)\n            variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise\n\n            pred_prev_sample = pred_prev_sample + variance\n\n        return pred_prev_sample, pred_original_sample\n\n    def reversed_step(\n        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion\n        process from the learned model outputs (most often the predicted noise).\n\n        Args:\n            model_output: direct output from learned diffusion model.\n            timestep: current discrete timestep in the diffusion chain.\n            sample: current instance of sample being created by diffusion process.\n\n        Returns:\n            pred_prev_sample: Predicted previous sample\n            pred_original_sample: Predicted original sample\n        \"\"\"\n        # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf\n\n        # Notation (<variable name> -> <name in paper>\n        # - model_output -> e_theta(x_t, t)\n        # - pred_original_sample -> f_theta(x_t, t) or x_0\n        # - std_dev_t -> sigma_t\n        # - eta -> η\n        # - pred_sample_direction -> \"direction pointing to x_t\"\n        # - pred_post_sample -> \"x_t+1\"\n\n        # 1. 
get previous step value (=t+1)\n        prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps\n\n        # 2. compute alphas, betas at timestep t+1\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod\n\n        beta_prod_t = 1 - alpha_prod_t\n\n        # predefinitions satisfy pylint/mypy, these values won't be ultimately used\n        pred_original_sample = sample\n        pred_epsilon = model_output\n\n        # 3. compute predicted original sample from predicted noise also called\n        # \"predicted x_0\" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf\n\n        if self.prediction_type == DDIMPredictionType.EPSILON:\n            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n            pred_epsilon = model_output\n        elif self.prediction_type == DDIMPredictionType.SAMPLE:\n            pred_original_sample = model_output\n            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)\n        elif self.prediction_type == DDIMPredictionType.V_PREDICTION:\n            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output\n            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample\n\n        # 4. Clip \"predicted x_0\"\n        if self.clip_sample:\n            pred_original_sample = torch.clamp(\n                pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]\n            )\n\n        # 5. compute \"direction pointing to x_t\" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf\n        pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon\n\n        # 6. 
compute x_t+1 without \"random noise\" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf\n        pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction\n\n        return pred_post_sample, pred_original_sample\n"
  },
  {
    "path": "monai/networks/schedulers/ddpm.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# =========================================================================\n# Adapted from https://github.com/huggingface/diffusers\n# which has the following license:\n# https://github.com/huggingface/diffusers/blob/main/LICENSE\n#\n# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =========================================================================\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom monai.utils import StrEnum\n\nfrom .scheduler import Scheduler\n\n\nclass DDPMVarianceType(StrEnum):\n    \"\"\"\n    Valid names for DDPM Scheduler's `variance_type` argument. 
Options to clip the variance used when adding noise\n    to the denoised sample.\n    \"\"\"\n\n    FIXED_SMALL = \"fixed_small\"\n    FIXED_LARGE = \"fixed_large\"\n    LEARNED = \"learned\"\n    LEARNED_RANGE = \"learned_range\"\n\n\nclass DDPMPredictionType(StrEnum):\n    \"\"\"\n    Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument.\n\n    epsilon: predicting the noise of the diffusion process\n    sample: directly predicting the denoised (original) sample x_0\n    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf\n    \"\"\"\n\n    EPSILON = \"epsilon\"\n    SAMPLE = \"sample\"\n    V_PREDICTION = \"v_prediction\"\n\n\nclass DDPMScheduler(Scheduler):\n    \"\"\"\n    Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and\n    Langevin dynamics sampling. Based on: Ho et al., \"Denoising Diffusion Probabilistic Models\"\n    https://arxiv.org/abs/2006.11239\n\n    Args:\n        num_train_timesteps: number of diffusion steps used to train the model.\n        schedule: member of NoiseSchedules, name of noise schedule function in component store\n        variance_type: member of DDPMVarianceType\n        clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.\n        prediction_type: member of DDPMPredictionType\n        clip_sample_min: minimum clipping value when clip_sample equals True\n        clip_sample_max: maximum clipping value when clip_sample equals True\n        schedule_args: arguments to pass to the schedule function\n    \"\"\"\n\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        schedule: str = \"linear_beta\",\n        variance_type: str = DDPMVarianceType.FIXED_SMALL,\n        clip_sample: bool = True,\n        prediction_type: str = DDPMPredictionType.EPSILON,\n        clip_sample_min: float = -1.0,\n        clip_sample_max: float = 1.0,\n        
**schedule_args,\n    ) -> None:\n        super().__init__(num_train_timesteps, schedule, **schedule_args)\n\n        if variance_type not in DDPMVarianceType.__members__.values():\n            raise ValueError(\"Argument `variance_type` must be a member of `DDPMVarianceType`\")\n\n        if prediction_type not in DDPMPredictionType.__members__.values():\n            raise ValueError(\"Argument `prediction_type` must be a member of `DDPMPredictionType`\")\n\n        self.clip_sample = clip_sample\n        self.clip_sample_values = [clip_sample_min, clip_sample_max]\n        self.variance_type = variance_type\n        self.prediction_type = prediction_type\n\n    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.\n\n        Args:\n            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.\n            device: target device to put the data.\n        \"\"\"\n        if num_inference_steps > self.num_train_timesteps:\n            raise ValueError(\n                f\"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:\"\n                f\" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle\"\n                f\" maximal {self.num_train_timesteps} timesteps.\"\n            )\n\n        self.num_inference_steps = num_inference_steps\n        self.timesteps = (\n            torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long()\n        )\n\n    def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Compute the mean of the posterior at timestep t.\n\n        Args:\n            timestep: current timestep.\n            x0: the noise-free input.\n         
   x_t: the input noised to timestep t.\n\n        Returns:\n            Returns the mean\n        \"\"\"\n        # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0),\n        # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf)\n        alpha_t = self.alphas[timestep]\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one\n\n        x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t)\n        x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)\n\n        mean: torch.Tensor = x_0_coefficient * x_0 + x_t_coefficient * x_t\n\n        return mean\n\n    def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor:\n        \"\"\"\n        Compute the variance of the posterior at timestep t.\n\n        Args:\n            timestep: current timestep.\n            predicted_variance: variance predicted by the model.\n\n        Returns:\n            Returns the variance\n        \"\"\"\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one\n\n        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)\n        # and sample from it to get previous sample\n        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample\n        variance: torch.Tensor = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]\n        # hacks - were probably added for training stability\n        if self.variance_type == DDPMVarianceType.FIXED_SMALL:\n            variance = torch.clamp(variance, min=1e-20)\n        elif self.variance_type == DDPMVarianceType.FIXED_LARGE:\n            variance = self.betas[timestep]\n        elif self.variance_type == DDPMVarianceType.LEARNED and 
predicted_variance is not None:\n            return predicted_variance\n        elif self.variance_type == DDPMVarianceType.LEARNED_RANGE and predicted_variance is not None:\n            min_log = variance\n            max_log = self.betas[timestep]\n            frac = (predicted_variance + 1) / 2\n            variance = frac * max_log + (1 - frac) * min_log\n\n        return variance\n\n    def step(\n        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion\n        process from the learned model outputs (most often the predicted noise).\n\n        Args:\n            model_output: direct output from learned diffusion model.\n            timestep: current discrete timestep in the diffusion chain.\n            sample: current instance of sample being created by diffusion process.\n            generator: random number generator.\n\n        Returns:\n            pred_prev_sample: Predicted previous sample\n        \"\"\"\n        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [\"learned\", \"learned_range\"]:\n            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)\n        else:\n            predicted_variance = None\n\n        # 1. compute alphas, betas\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one\n        beta_prod_t = 1 - alpha_prod_t\n        beta_prod_t_prev = 1 - alpha_prod_t_prev\n\n        # 2. 
compute predicted original sample from predicted noise also called\n        # \"predicted x_0\" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf\n        if self.prediction_type == DDPMPredictionType.EPSILON:\n            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n        elif self.prediction_type == DDPMPredictionType.SAMPLE:\n            pred_original_sample = model_output\n        elif self.prediction_type == DDPMPredictionType.V_PREDICTION:\n            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output\n\n        # 3. Clip \"predicted x_0\"\n        if self.clip_sample:\n            pred_original_sample = torch.clamp(\n                pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]\n            )\n\n        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t\n        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf\n        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t\n        current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t\n\n        # 5. Compute predicted previous sample µ_t\n        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf\n        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample\n\n        # 6. 
Add noise\n        variance: torch.Tensor = torch.tensor(0)\n        if timestep > 0:\n            noise = torch.randn(\n                model_output.size(),\n                dtype=model_output.dtype,\n                layout=model_output.layout,\n                generator=generator,\n                device=model_output.device,\n            )\n            variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise\n\n        pred_prev_sample = pred_prev_sample + variance\n\n        return pred_prev_sample, pred_original_sample\n"
  },
  {
    "path": "monai/networks/schedulers/pndm.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# =========================================================================\n# Adapted from https://github.com/huggingface/diffusers\n# which has the following license:\n# https://github.com/huggingface/diffusers/blob/main/LICENSE\n#\n# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =========================================================================\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.utils import StrEnum\n\nfrom .scheduler import Scheduler\n\n\nclass PNDMPredictionType(StrEnum):\n    \"\"\"\n    Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument.\n\n    epsilon: predicting the noise of the diffusion process\n    v_prediction: velocity prediction, see section 2.4 
https://imagen.research.google/video/paper.pdf\n    \"\"\"\n\n    EPSILON = \"epsilon\"\n    V_PREDICTION = \"v_prediction\"\n\n\nclass PNDMScheduler(Scheduler):\n    \"\"\"\n    Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,\n    namely Runge-Kutta method and a linear multi-step method. Based on: Liu et al.,\n    \"Pseudo Numerical Methods for Diffusion Models on Manifolds\"  https://arxiv.org/abs/2202.09778\n\n    Args:\n        num_train_timesteps: number of diffusion steps used to train the model.\n        schedule: member of NoiseSchedules, name of noise schedule function in component store\n        skip_prk_steps:\n            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required\n            before plms step.\n        set_alpha_to_one:\n            each diffusion step uses the value of alphas product at that step and at the previous one. For the final\n            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,\n            otherwise it uses the value of alpha at step 0.\n        prediction_type: member of PNDMPredictionType\n        steps_offset:\n            an offset added to the inference steps. 
You can use a combination of `steps_offset=1` and\n            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in\n            stable diffusion.\n        schedule_args: arguments to pass to the schedule function\n    \"\"\"\n\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        schedule: str = \"linear_beta\",\n        skip_prk_steps: bool = False,\n        set_alpha_to_one: bool = False,\n        prediction_type: str = PNDMPredictionType.EPSILON,\n        steps_offset: int = 0,\n        **schedule_args,\n    ) -> None:\n        super().__init__(num_train_timesteps, schedule, **schedule_args)\n\n        if prediction_type not in PNDMPredictionType.__members__.values():\n            raise ValueError(\"Argument `prediction_type` must be a member of PNDMPredictionType\")\n\n        self.prediction_type = prediction_type\n\n        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]\n\n        # standard deviation of the initial noise distribution\n        self.init_noise_sigma = 1.0\n\n        # For now we only support F-PNDM, i.e. the runge-kutta method\n        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf\n        # mainly at formula (9), (12), (13) and the Algorithm 2.\n        self.pndm_order = 4\n\n        self.skip_prk_steps = skip_prk_steps\n        self.steps_offset = steps_offset\n\n        # running values\n        self.cur_model_output = torch.Tensor()\n        self.counter = 0\n        self.cur_sample = torch.Tensor()\n        self.ets: list = []\n\n        # default the number of inference timesteps to the number of train steps\n        self.set_timesteps(num_train_timesteps)\n\n    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain. 
Supporting function to be run before inference.\n\n        Args:\n            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.\n            device: target device to put the data.\n        \"\"\"\n        if num_inference_steps > self.num_train_timesteps:\n            raise ValueError(\n                f\"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:\"\n                f\" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle\"\n                f\" maximal {self.num_train_timesteps} timesteps.\"\n            )\n\n        self.num_inference_steps = num_inference_steps\n        step_ratio = self.num_train_timesteps // self.num_inference_steps\n        # creates integer timesteps by multiplying by ratio\n        # casting to int to avoid issues when num_inference_step is power of 3\n        self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)\n        self._timesteps += self.steps_offset\n\n        if self.skip_prk_steps:\n            # for some models like stable diffusion the prk steps can/should be skipped to\n            # produce better results. 
When using PNDM with `self.skip_prk_steps` the implementation\n            # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51\n            self.prk_timesteps = np.array([])\n            self.plms_timesteps = self._timesteps[::-1]\n\n        else:\n            prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(\n                np.array([0, self.num_train_timesteps // num_inference_steps // 2]), self.pndm_order\n            )\n            self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()\n            self.plms_timesteps = self._timesteps[:-3][\n                ::-1\n            ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy\n\n        timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)\n        self.timesteps = torch.from_numpy(timesteps).to(device)\n        # update num_inference_steps - necessary if we use prk steps\n        self.num_inference_steps = len(self.timesteps)\n\n        self.ets = []\n        self.counter = 0\n\n    def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> tuple[torch.Tensor, Any]:\n        \"\"\"\n        Predict the sample at the previous timestep by reversing the SDE. 
Core function to propagate the diffusion\n        process from the learned model outputs (most often the predicted noise).\n        This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.\n\n        Args:\n            model_output: direct output from learned diffusion model.\n            timestep: current discrete timestep in the diffusion chain.\n            sample: current instance of sample being created by diffusion process.\n        Returns:\n            pred_prev_sample: Predicted previous sample\n        \"\"\"\n        # return a tuple for consistency with samplers that return (previous pred, original sample pred)\n\n        if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps:\n            return self.step_prk(model_output=model_output, timestep=timestep, sample=sample), None\n        else:\n            return self.step_plms(model_output=model_output, timestep=timestep, sample=sample), None\n\n    def step_prk(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Step function propagating the sample with the Runge-Kutta method. 
RK takes 4 forward passes to approximate the\n        solution to the differential equation.\n\n        Args:\n            model_output: direct output from learned diffusion model.\n            timestep: current discrete timestep in the diffusion chain.\n            sample: current instance of sample being created by diffusion process.\n\n        Returns:\n            pred_prev_sample: Predicted previous sample\n        \"\"\"\n        if self.num_inference_steps is None:\n            raise ValueError(\n                \"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler\"\n            )\n\n        diff_to_prev = 0 if self.counter % 2 else self.num_train_timesteps // self.num_inference_steps // 2\n        prev_timestep = timestep - diff_to_prev\n        timestep = self.prk_timesteps[self.counter // 4 * 4]\n\n        if self.counter % 4 == 0:\n            self.cur_model_output = 1 / 6 * model_output\n            self.ets.append(model_output)\n            self.cur_sample = sample\n        elif (self.counter - 1) % 4 == 0:\n            self.cur_model_output += 1 / 3 * model_output\n        elif (self.counter - 2) % 4 == 0:\n            self.cur_model_output += 1 / 3 * model_output\n        elif (self.counter - 3) % 4 == 0:\n            model_output = self.cur_model_output + 1 / 6 * model_output\n            self.cur_model_output = torch.Tensor()\n\n        # cur_sample should not be an empty torch.Tensor()\n        cur_sample = self.cur_sample if self.cur_sample.numel() != 0 else sample\n\n        prev_sample: torch.Tensor = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)\n        self.counter += 1\n\n        return prev_sample\n\n    def step_plms(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> Any:\n        \"\"\"\n        Step function propagating the sample with the linear multi-step method. 
This has one forward pass with multiple\n        times to approximate the solution.\n\n        Args:\n            model_output: direct output from learned diffusion model.\n            timestep: current discrete timestep in the diffusion chain.\n            sample: current instance of sample being created by diffusion process.\n\n        Returns:\n            pred_prev_sample: Predicted previous sample\n        \"\"\"\n        if self.num_inference_steps is None:\n            raise ValueError(\n                \"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler\"\n            )\n\n        if not self.skip_prk_steps and len(self.ets) < 3:\n            raise ValueError(\n                f\"{self.__class__} can only be run AFTER scheduler has been run \"\n                \"in 'prk' mode for at least 12 iterations \"\n            )\n\n        prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps\n\n        if self.counter != 1:\n            self.ets = self.ets[-3:]\n            self.ets.append(model_output)\n        else:\n            prev_timestep = timestep\n            timestep = timestep + self.num_train_timesteps // self.num_inference_steps\n\n        if len(self.ets) == 1 and self.counter == 0:\n            model_output = model_output\n            self.cur_sample = sample\n        elif len(self.ets) == 1 and self.counter == 1:\n            model_output = (model_output + self.ets[-1]) / 2\n            sample = self.cur_sample\n            self.cur_sample = torch.Tensor()\n        elif len(self.ets) == 2:\n            model_output = (3 * self.ets[-1] - self.ets[-2]) / 2\n        elif len(self.ets) == 3:\n            model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12\n        else:\n            model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])\n\n        prev_sample = self._get_prev_sample(sample, timestep, 
prev_timestep, model_output)\n        self.counter += 1\n\n        return prev_sample\n\n    def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor):\n        # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf\n        # this function computes x_(t−δ) using the formula of (9)\n        # Note that x_t needs to be added to both sides of the equation\n\n        # Notation (<variable name> -> <name in paper>\n        # alpha_prod_t -> α_t\n        # alpha_prod_t_prev -> α_(t−δ)\n        # beta_prod_t -> (1 - α_t)\n        # beta_prod_t_prev -> (1 - α_(t−δ))\n        # sample -> x_t\n        # model_output -> e_θ(x_t, t)\n        # prev_sample -> x_(t−δ)\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod\n        beta_prod_t = 1 - alpha_prod_t\n        beta_prod_t_prev = 1 - alpha_prod_t_prev\n\n        if self.prediction_type == PNDMPredictionType.V_PREDICTION:\n            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample\n\n        # corresponds to (α_(t−δ) - α_t) divided by\n        # denominator of x_t in formula (9) and plus 1\n        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) =\n        # sqrt(α_(t−δ)) / sqrt(α_t))\n        sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)\n\n        # corresponds to denominator of e_θ(x_t, t) in formula (9)\n        model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (\n            alpha_prod_t * beta_prod_t * alpha_prod_t_prev\n        ) ** (0.5)\n\n        # full formula (9)\n        prev_sample = (\n            sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff\n        )\n\n        return prev_sample\n"
  },
  {
    "path": "monai/networks/schedulers/rectified_flow.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# =========================================================================\n# Adapted from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py\n# which has the following license:\n# https://github.com/hpcaitech/Open-Sora/blob/main/LICENSE\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =========================================================================\n\nfrom __future__ import annotations\n\nimport numpy as np\nimport torch\nfrom torch.distributions import LogisticNormal\n\nfrom monai.utils import StrEnum\n\nfrom .ddpm import DDPMPredictionType\nfrom .scheduler import Scheduler\n\n\nclass RFlowPredictionType(StrEnum):\n    \"\"\"\n    Set of valid prediction type names for the RFlow scheduler's `prediction_type` argument.\n\n    v_prediction: velocity prediction, see section 2.4 
https://imagen.research.google/video/paper.pdf\n    \"\"\"\n\n    V_PREDICTION = DDPMPredictionType.V_PREDICTION\n\n\ndef timestep_transform(\n    t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3\n):\n    \"\"\"\n    Applies a transformation to the timestep based on image resolution scaling.\n\n    Args:\n        t (torch.Tensor): The original timestep(s).\n        input_img_size_numel (torch.Tensor): The input image's size (H * W * D).\n        base_img_size_numel (int): reference H*W*D size, usually smaller than input_img_size_numel.\n        scale (float): Scaling factor for the transformation.\n        num_train_timesteps (int): Total number of training timesteps.\n        spatial_dim (int): Number of spatial dimensions in the image.\n\n    Returns:\n        torch.Tensor: Transformed timestep(s).\n    \"\"\"\n    t = t / num_train_timesteps\n    ratio_space = (input_img_size_numel / base_img_size_numel) ** (1.0 / spatial_dim)\n\n    ratio = ratio_space * scale\n    new_t = ratio * t / (1 + (ratio - 1) * t)\n\n    new_t = new_t * num_train_timesteps\n    return new_t\n\n\nclass RFlowScheduler(Scheduler):\n    \"\"\"\n    A rectified flow scheduler for guiding the diffusion process in a generative model.\n\n    Supports uniform and logit-normal sampling methods, timestep transformation for\n    different resolutions, and noise addition during diffusion.\n\n    Args:\n        num_train_timesteps (int): Total number of training timesteps.\n        use_discrete_timesteps (bool): Whether to use discrete timesteps.\n        sample_method (str): Training time step sampling method ('uniform' or 'logit-normal').\n        loc (float): Location parameter for logit-normal distribution, used only if sample_method='logit-normal'.\n        scale (float): Scale parameter for logit-normal distribution, used only if sample_method='logit-normal'.\n        use_timestep_transform (bool): Whether to apply timestep 
transformation.\n            If true, there will be more inference timesteps at early(noisy) stages for larger image volumes.\n        transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True.\n        steps_offset (int): Offset added to computed timesteps, used only if use_timestep_transform=True.\n        base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True.\n        spatial_dim (int): 2 or 3, incidcating 2D or 3D images, used only if use_timestep_transform=True.\n\n    Example:\n\n        .. code-block:: python\n\n            # define a scheduler\n            noise_scheduler = RFlowScheduler(\n                num_train_timesteps = 1000,\n                use_discrete_timesteps = True,\n                sample_method = 'logit-normal',\n                use_timestep_transform = True,\n                base_img_size_numel = 32 * 32 * 32,\n                spatial_dim = 3\n            )\n\n            # during training\n            inputs = torch.ones(2,4,64,64,32)\n            noise = torch.randn_like(inputs)\n            timesteps = noise_scheduler.sample_timesteps(inputs)\n            noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)\n            predicted_velocity = diffusion_unet(\n                x=noisy_inputs,\n                timesteps=timesteps\n            )\n            loss = loss_l1(predicted_velocity, (inputs - noise))\n\n            # during inference\n            noisy_inputs = torch.randn(2,4,64,64,32)\n            input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:])\n            noise_scheduler.set_timesteps(\n                num_inference_steps=30, input_img_size_numel=input_img_size_numel)\n            )\n            all_next_timesteps = torch.cat(\n                (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))\n            )\n       
     for t, next_t in tqdm(\n                zip(noise_scheduler.timesteps, all_next_timesteps),\n                total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)),\n            ):\n                predicted_velocity = diffusion_unet(\n                    x=noisy_inputs,\n                    timesteps=timesteps\n                )\n                noisy_inputs, _ = noise_scheduler.step(predicted_velocity, t, noisy_inputs, next_t)\n            final_output = noisy_inputs\n    \"\"\"\n\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        use_discrete_timesteps: bool = True,\n        sample_method: str = \"uniform\",\n        loc: float = 0.0,\n        scale: float = 1.0,\n        use_timestep_transform: bool = False,\n        transform_scale: float = 1.0,\n        steps_offset: int = 0,\n        base_img_size_numel: int = 32 * 32 * 32,\n        spatial_dim: int = 3,\n    ):\n        # rectified flow only accepts velocity prediction\n        self.prediction_type = RFlowPredictionType.V_PREDICTION\n\n        self.num_train_timesteps = num_train_timesteps\n        self.use_discrete_timesteps = use_discrete_timesteps\n        self.base_img_size_numel = base_img_size_numel\n        self.spatial_dim = spatial_dim\n\n        # sample method\n        if sample_method not in [\"uniform\", \"logit-normal\"]:\n            raise ValueError(\n                f\"sample_method = {sample_method}, which has to be chosen from ['uniform', 'logit-normal'].\"\n            )\n        self.sample_method = sample_method\n        if sample_method == \"logit-normal\":\n            self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))\n            self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)\n\n        # timestep transform\n        self.use_timestep_transform = use_timestep_transform\n        self.transform_scale = transform_scale\n        self.steps_offset = steps_offset\n\n    def 
add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Add noise to the original samples.\n\n        Args:\n            original_samples: original samples\n            noise: noise to add to samples\n            timesteps: timesteps tensor with shape of (N,), indicating the timestep to be computed for each sample.\n\n        Returns:\n            noisy_samples: sample with added noise\n        \"\"\"\n        timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps\n        timepoints = 1 - timepoints  # [1,1/1000]\n\n        # expand timepoint to noise shape\n        if noise.ndim == 5:\n            timepoints = timepoints[..., None, None, None, None].expand(-1, *noise.shape[1:])\n        elif noise.ndim == 4:\n            timepoints = timepoints[..., None, None, None].expand(-1, *noise.shape[1:])\n        else:\n            raise ValueError(f\"noise tensor has to be 4D or 5D tensor, yet got shape of {noise.shape}\")\n\n        noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise\n\n        return noisy_samples\n\n    def set_timesteps(\n        self,\n        num_inference_steps: int,\n        device: str | torch.device | None = None,\n        input_img_size_numel: int | None = None,\n    ) -> None:\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain. 
Supporting function to be run before inference.\n\n        Args:\n            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.\n            device: target device to put the data.\n            input_img_size_numel: int, H*W*D of the image, used with self.use_timestep_transform is True.\n        \"\"\"\n        if num_inference_steps > self.num_train_timesteps or num_inference_steps < 1:\n            raise ValueError(\n                f\"`num_inference_steps`: {num_inference_steps} should be at least 1, \"\n                \"and cannot be larger than `self.num_train_timesteps`:\"\n                f\" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle\"\n                f\" maximal {self.num_train_timesteps} timesteps.\"\n            )\n\n        self.num_inference_steps = num_inference_steps\n        # prepare timesteps\n        timesteps = [\n            (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)\n        ]\n        if self.use_discrete_timesteps:\n            timesteps = [int(round(t)) for t in timesteps]\n        if self.use_timestep_transform:\n            timesteps = [\n                timestep_transform(\n                    t,\n                    input_img_size_numel=input_img_size_numel,\n                    base_img_size_numel=self.base_img_size_numel,\n                    num_train_timesteps=self.num_train_timesteps,\n                    spatial_dim=self.spatial_dim,\n                )\n                for t in timesteps\n            ]\n        timesteps_np = np.array(timesteps).astype(np.float16)\n        if self.use_discrete_timesteps:\n            timesteps_np = timesteps_np.astype(np.int64)\n        self.timesteps = torch.from_numpy(timesteps_np).to(device)\n        self.timesteps += self.steps_offset\n\n    def sample_timesteps(self, x_start):\n        \"\"\"\n        Randomly samples training 
timesteps using the chosen sampling method.\n\n        Args:\n            x_start (torch.Tensor): The input tensor for sampling.\n\n        Returns:\n            torch.Tensor: Sampled timesteps.\n        \"\"\"\n        if self.sample_method == \"uniform\":\n            t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps\n        elif self.sample_method == \"logit-normal\":\n            t = self.sample_t(x_start) * self.num_train_timesteps\n\n        if self.use_discrete_timesteps:\n            t = t.long()\n\n        if self.use_timestep_transform:\n            input_img_size_numel = torch.prod(torch.tensor(x_start.shape[2:]))\n            t = timestep_transform(\n                t,\n                input_img_size_numel=input_img_size_numel,\n                base_img_size_numel=self.base_img_size_numel,\n                num_train_timesteps=self.num_train_timesteps,\n                spatial_dim=len(x_start.shape) - 2,\n            )\n\n        return t\n\n    def step(\n        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: int | None = None\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Predicts the next sample in the diffusion process.\n\n        Args:\n            model_output (torch.Tensor): Output from the trained diffusion model.\n            timestep (int): Current timestep in the diffusion chain.\n            sample (torch.Tensor): Current sample in the process.\n            next_timestep (Union[int, None]): Optional next timestep.\n\n        Returns:\n            tuple[torch.Tensor, torch.Tensor]: Predicted sample at the next step and additional info.\n        \"\"\"\n        # Ensure num_inference_steps exists and is a valid integer\n        if not hasattr(self, \"num_inference_steps\") or not isinstance(self.num_inference_steps, int):\n            raise AttributeError(\n                \"num_inference_steps is missing or not an integer in the class.\"\n        
        \"Please run self.set_timesteps(num_inference_steps,device,input_img_size_numel) to set it.\"\n            )\n\n        v_pred = model_output\n\n        if next_timestep is not None:\n            next_timestep = int(next_timestep)\n            dt: float = (\n                float(timestep - next_timestep) / self.num_train_timesteps\n            )  # Now next_timestep is guaranteed to be int\n        else:\n            dt = (\n                1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0\n            )  # Avoid division by zero\n\n        pred_post_sample = sample + v_pred * dt\n        pred_original_sample = sample + v_pred * timestep / self.num_train_timesteps\n\n        return pred_post_sample, pred_original_sample\n"
  },
  {
    "path": "monai/networks/schedulers/scheduler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# =========================================================================\n# Adapted from https://github.com/huggingface/diffusers\n# which has the following license:\n# https://github.com/huggingface/diffusers/blob/main/LICENSE\n#\n# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =========================================================================\n\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\n\nfrom monai.utils import ComponentStore, unsqueeze_right\n\nNoiseSchedules = ComponentStore(\"NoiseSchedules\", \"Functions to generate noise schedules\")\n\n\n@NoiseSchedules.add_def(\"linear_beta\", \"Linear beta schedule\")\ndef _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):\n    \"\"\"\n    Linear beta noise 
schedule function.\n\n    Args:\n        num_train_timesteps: number of timesteps\n        beta_start: start of beta range, default 1e-4\n        beta_end: end of beta range, default 2e-2\n\n    Returns:\n        betas: beta schedule tensor\n    \"\"\"\n    return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)\n\n\n@NoiseSchedules.add_def(\"scaled_linear_beta\", \"Scaled linear beta schedule\")\ndef _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):\n    \"\"\"\n    Scaled linear beta noise schedule function.\n\n    Args:\n        num_train_timesteps: number of timesteps\n        beta_start: start of beta range, default 1e-4\n        beta_end: end of beta range, default 2e-2\n\n    Returns:\n        betas: beta schedule tensor\n    \"\"\"\n    return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2\n\n\n@NoiseSchedules.add_def(\"sigmoid_beta\", \"Sigmoid beta schedule\")\ndef _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6):\n    \"\"\"\n    Sigmoid beta noise schedule function.\n\n    Args:\n        num_train_timesteps: number of timesteps\n        beta_start: start of beta range, default 1e-4\n        beta_end: end of beta range, default 2e-2\n        sig_range: pos/neg range of sigmoid input, default 6\n\n    Returns:\n        betas: beta schedule tensor\n    \"\"\"\n    betas = torch.linspace(-sig_range, sig_range, num_train_timesteps)\n    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start\n\n\n@NoiseSchedules.add_def(\"cosine\", \"Cosine schedule\")\ndef _cosine_beta(num_train_timesteps: int, s: float = 8e-3):\n    \"\"\"\n    Cosine noise schedule, see https://arxiv.org/abs/2102.09672\n\n    Args:\n        num_train_timesteps: number of timesteps\n        s: smoothing factor, default 8e-3 (see referenced paper)\n\n    Returns:\n        (betas, alphas, 
alpha_cumprod) values\n    \"\"\"\n    x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1)\n    alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2\n    alphas_cumprod /= alphas_cumprod[0].item()\n    betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n    betas = torch.clip(betas, 0.0, 0.999)\n    alphas = 1.0 - betas\n    alphas_cumprod = torch.cumprod(alphas, dim=0)\n    return betas, alphas, alphas_cumprod\n\n\nclass Scheduler(nn.Module):\n    \"\"\"\n    Base class for other schedulers based on a noise schedule function.\n\n    This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here\n    the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`,\n    which is the name of a component in NoiseSchedules. These components must all be callables which return either\n    the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions\n    can be provided by using the NoiseSchedules.add_def, for example:\n\n    .. code-block:: python\n\n        from monai.networks.schedulers import NoiseSchedules, DDPMScheduler\n\n        @NoiseSchedules.add_def(\"my_beta_schedule\", \"Some description of your function\")\n        def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2):\n            return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)\n\n        scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"my_beta_schedule\")\n\n    All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of\n    timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through\n    the constructor's `schedule_args` value. 
To see what noise functions are available, print the object NoiseSchedules\n    to get a listing of stored objects with their docstring descriptions.\n\n    Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule\n    type, this now replaced with `schedule` and most names used with the previous argument now have \"_beta\" appended\n    to them, eg. 'schedule_beta=\"linear\"' -> 'schedule=\"linear_beta\"'. The `beta_start` and `beta_end` arguments are\n    still used for some schedules but these are provided as keyword arguments now.\n\n    Args:\n        num_train_timesteps: number of diffusion steps used to train the model.\n        schedule: member of NoiseSchedules,\n            a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple\n        schedule_args: arguments to pass to the schedule function\n    \"\"\"\n\n    def __init__(self, num_train_timesteps: int = 1000, schedule: str = \"linear_beta\", **schedule_args) -> None:\n        super().__init__()\n        schedule_args[\"num_train_timesteps\"] = num_train_timesteps\n        noise_sched = NoiseSchedules[schedule](**schedule_args)\n\n        # set betas, alphas, alphas_cumprod based off return value from noise function\n        if isinstance(noise_sched, tuple):\n            self.betas, self.alphas, self.alphas_cumprod = noise_sched\n        else:\n            self.betas = noise_sched\n            self.alphas = 1.0 - self.betas\n            self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)\n\n        self.num_train_timesteps = num_train_timesteps\n        self.one = torch.tensor(1.0)\n\n        # settable values\n        self.num_inference_steps: int | None = None\n        self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1)\n\n    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Add noise to the original samples.\n\n   
     Args:\n            original_samples: original samples\n            noise: noise to add to samples\n            timesteps: timesteps tensor indicating the timestep to be computed for each sample.\n\n        Returns:\n            noisy_samples: sample with added noise\n        \"\"\"\n        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples\n        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)\n        timesteps = timesteps.to(original_samples.device)\n\n        sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim)\n        sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right(\n            (1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim\n        )\n\n        noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise\n        return noisy_samples\n\n    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:\n        # Make sure alphas_cumprod and timestep have same device and dtype as sample\n        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)\n        timesteps = timesteps.to(sample.device)\n\n        sqrt_alpha_prod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim)\n        sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right(\n            (1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim\n        )\n\n        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample\n        return velocity\n"
  },
  {
    "path": "monai/networks/trt_compiler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport inspect\nimport os\nimport tempfile\nimport threading\nfrom collections import OrderedDict\nfrom pathlib import Path\nfrom types import MethodType\nfrom typing import Any\n\nimport torch\n\nfrom monai.apps.utils import get_logger\nfrom monai.networks.utils import add_casts_around_norms, convert_to_onnx, get_profile_shapes\nfrom monai.utils.module import optional_import\n\npolygraphy, polygraphy_imported = optional_import(\"polygraphy\")\nif polygraphy_imported:\n    from polygraphy.backend.common import bytes_from_path\n    from polygraphy.backend.trt import (\n        CreateConfig,\n        Profile,\n        engine_bytes_from_network,\n        engine_from_bytes,\n        network_from_onnx_path,\n    )\n\ntrt, trt_imported = optional_import(\"tensorrt\")\ntorch_tensorrt, _ = optional_import(\"torch_tensorrt\", \"1.4.0\")\ncudart, _cudart_imported = optional_import(\"cuda.bindings.runtime\")\nif not _cudart_imported:\n    cudart, _cudart_imported = optional_import(\"cuda.cudart\")\n\n\nlock_sm = threading.Lock()\n\n\n# Map of TRT dtype -> Torch dtype\ndef trt_to_torch_dtype_dict():\n    return {\n        trt.int32: torch.int32,\n        trt.float32: torch.float32,\n        trt.float16: torch.float16,\n        trt.bfloat16: torch.float16,\n        trt.int64: torch.int64,\n        trt.int8: torch.int8,\n        trt.bool: 
torch.bool,\n    }\n\n\ndef get_dynamic_axes(profiles):\n    \"\"\"\n    This method calculates dynamic_axes to use in onnx.export().\n    Args:\n       profiles: [[min,opt,max],...] list of profile dimensions\n    \"\"\"\n    dynamic_axes: dict[str, list[int]] = {}\n    if not profiles:\n        return dynamic_axes\n    for profile in profiles:\n        for key in profile:\n            axes = []\n            vals = profile[key]\n            for i in range(len(vals[0])):\n                if vals[0][i] != vals[2][i]:\n                    axes.append(i)\n            if len(axes) > 0:\n                dynamic_axes[key] = axes\n    return dynamic_axes\n\n\ndef cuassert(cuda_ret):\n    \"\"\"\n    Error reporting method for CUDA calls.\n    Args:\n     cuda_ret: CUDA return code.\n    \"\"\"\n    err = cuda_ret[0]\n    if err != 0:\n        raise RuntimeError(f\"CUDA ERROR: {err}\")\n    if len(cuda_ret) > 1:\n        return cuda_ret[1]\n    return None\n\n\nclass ShapeError(Exception):\n    \"\"\"\n    Exception class to report errors from setting TRT plan input shapes\n    \"\"\"\n\n\nclass TRTEngine:\n    \"\"\"\n    An auxiliary class to implement running of TRT optimized engines\n\n    \"\"\"\n\n    def __init__(self, plan_path, logger=None):\n        \"\"\"\n        Loads serialized engine, creates execution context and activates it\n        Args:\n          plan_path: path to serialized TRT engine.\n          logger: optional logger object\n        \"\"\"\n        self.plan_path = plan_path\n        self.logger = logger or get_logger(\"monai.networks.trt_compiler\")\n        self.logger.info(f\"Loading TensorRT engine: {self.plan_path}\")\n        self.engine = engine_from_bytes(bytes_from_path(self.plan_path))\n        self.tensors = OrderedDict()\n        self.cuda_graph_instance = None  # cuda graph\n        self.context = self.engine.create_execution_context()\n        self.input_names = []\n        self.output_names = []\n        self.dtypes = []\n        
self.cur_profile = 0\n        self.input_table = {}\n        dtype_dict = trt_to_torch_dtype_dict()\n        for idx in range(self.engine.num_io_tensors):\n            binding = self.engine[idx]\n            if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:\n                self.input_names.append(binding)\n            elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT:\n                self.output_names.append(binding)\n                dtype = dtype_dict[self.engine.get_tensor_dtype(binding)]\n                self.dtypes.append(dtype)\n        self.logger.info(\n            f\"Loaded TensorRT engine: {self.plan_path}.\\nInputs: {self.input_names}\\nOutputs: {self.output_names}\"\n        )\n\n    def allocate_buffers(self, device):\n        \"\"\"\n        Allocates outputs to run TRT engine\n        Args:\n            device: GPU device to allocate memory on\n        \"\"\"\n        ctx = self.context\n\n        for i, binding in enumerate(self.output_names):\n            shape = list(ctx.get_tensor_shape(binding))\n            if binding not in self.tensors or list(self.tensors[binding].shape) != shape:\n                t = torch.empty(shape, dtype=self.dtypes[i], device=device).contiguous()\n                self.tensors[binding] = t\n                ctx.set_tensor_address(binding, t.data_ptr())\n\n    def set_inputs(self, feed_dict, stream):\n        \"\"\"\n        Sets input bindings for TRT engine according to feed_dict\n        Args:\n           feed_dict: a dictionary [str->Tensor]\n           stream: CUDA stream to use\n        \"\"\"\n        e = self.engine\n        ctx = self.context\n\n        last_profile = self.cur_profile\n\n        def try_set_inputs():\n            for binding in self.input_names:\n                t = feed_dict.get(self.input_table[binding], None)\n                if t is not None:\n                    t = t.contiguous()\n                    shape = t.shape\n                    
ctx.set_input_shape(binding, shape)\n                    ctx.set_tensor_address(binding, t.data_ptr())\n\n        while True:\n            try:\n                try_set_inputs()\n                break\n            except ShapeError:\n                next_profile = (self.cur_profile + 1) % e.num_optimization_profiles\n                if next_profile == last_profile:\n                    raise\n                self.cur_profile = next_profile\n                ctx.set_optimization_profile_async(self.cur_profile, stream)\n            except Exception:\n                raise\n        left = ctx.infer_shapes()\n        assert len(left) == 0\n\n    def infer(self, stream, use_cuda_graph=False):\n        \"\"\"\n        Runs TRT engine.\n        Args:\n            stream: CUDA stream to run on\n            use_cuda_graph: use CUDA graph. Note: requires all inputs to be the same GPU memory between calls.\n        \"\"\"\n        if use_cuda_graph:\n            if self.cuda_graph_instance is not None:\n                cuassert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))\n                cuassert(cudart.cudaStreamSynchronize(stream))\n            else:\n                # do inference before CUDA graph capture\n                noerror = self.context.execute_async_v3(stream)\n                if not noerror:\n                    raise ValueError(\"ERROR: inference failed.\")\n                # capture cuda graph\n                cuassert(\n                    cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal)\n                )\n                self.context.execute_async_v3(stream)\n                graph = cuassert(cudart.cudaStreamEndCapture(stream))\n                self.cuda_graph_instance = cuassert(cudart.cudaGraphInstantiate(graph, 0))\n                self.logger.info(\"CUDA Graph captured!\")\n        else:\n            noerror = self.context.execute_async_v3(stream)\n            
cuassert(cudart.cudaStreamSynchronize(stream))\n            if not noerror:\n                raise ValueError(\"ERROR: inference failed.\")\n\n        return self.tensors\n\n\ndef make_tensor(d):\n    return d if isinstance(d, torch.Tensor) else torch.tensor(d).cuda()\n\n\ndef unroll_input(input_names, input_example):\n    # Simulate list/tuple unrolling during ONNX export\n    unrolled_input = {}\n    for name in input_names:\n        val = input_example[name]\n        if val is not None:\n            if isinstance(val, list) or isinstance(val, tuple):\n                for i in range(len(val)):\n                    unrolled_input[f\"{name}_{i}\"] = make_tensor(val[i])\n            else:\n                unrolled_input[name] = make_tensor(val)\n    return unrolled_input\n\n\ndef parse_groups(\n    ret: list[torch.Tensor], output_lists: list[list[int]]\n) -> tuple[torch.Tensor | list[torch.Tensor], ...]:\n    \"\"\"\n    Implements parsing of 'output_lists' arg of trt_compile().\n\n    Args:\n      ret: plain list of Tensors\n\n      output_lists: list of output group sizes: to form some Lists/Tuples out of 'ret' List, this will be a list\n                    of group dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list.\n        Format: [[group_n] | [], ...]\n          [] or group_n == 0 : next output from ret is a scalar\n          group_n > 0  :       next output from ret is a list of group_n length\n          group_n == -1:       next output is a dynamic list. This entry can be at any\n                               position in output_lists, but can appear only once.\n    Returns:\n       Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists\n\n    \"\"\"\n    groups: tuple[torch.Tensor | list[torch.Tensor], ...] 
= ()\n    cur = 0\n    for idx in range(len(output_lists)):\n        gl = output_lists[idx]\n        assert len(gl) == 0 or len(gl) == 1\n        if len(gl) == 0 or gl[0] == 0:\n            groups = (*groups, ret[cur])\n            cur = cur + 1\n        elif gl[0] > 0:\n            groups = (*groups, ret[cur : cur + gl[0]])\n            cur = cur + gl[0]\n        elif gl[0] == -1:\n            rev_groups: tuple[torch.Tensor | list[torch.Tensor], ...] = ()\n            rcur = len(ret)\n            for rl in range(len(output_lists) - 1, idx, -1):\n                rgl = output_lists[rl]\n                assert len(rgl) == 0 or len(rgl) == 1\n                if len(rgl) == 0 or rgl[0] == 0:\n                    rcur = rcur - 1\n                    rev_groups = (*rev_groups, ret[rcur])\n                elif rgl[0] > 0:\n                    rcur = rcur - rgl[0]\n                    rev_groups = (*rev_groups, ret[rcur : rcur + rgl[0]])\n                else:\n                    raise ValueError(\"Two -1 lists in output\")\n            groups = (*groups, ret[cur:rcur], *rev_groups[::-1])\n            break\n    return groups\n\n\nclass TrtCompiler:\n    \"\"\"\n    This class implements:\n      - TRT lazy persistent export\n      - Running TRT with optional fallback to Torch\n        (for TRT engines with limited profiles)\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        plan_path,\n        precision=\"fp16\",\n        method=\"onnx\",\n        input_names=None,\n        output_names=None,\n        output_lists=None,\n        export_args=None,\n        build_args=None,\n        input_profiles=None,\n        dynamic_batchsize=None,\n        use_cuda_graph=False,\n        timestamp=None,\n        fallback=False,\n        forward_override=None,\n        logger=None,\n    ):\n        \"\"\"\n        Initialization method:\n         Tries to load persistent serialized TRT engine\n         Saves its arguments for lazy TRT build on first forward() call\n 
       Args:\n            model: Model to \"wrap\".\n            plan_path : Path where to save persistent serialized TRT engine.\n            precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'.\n            method: One of 'onnx'|'torch_trt'.\n                    Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option.\n                    'torch_trt' may not work for some nets. Also AMP must be turned off for it to work.\n            input_names: Optional list of input names. If None, will be read from the function signature.\n            output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary.\n            output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list\n                          of their dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list.\n            export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details.\n            build_args: Optional args to pass to TRT builder. See polygraphy.Config for details.\n            input_profiles: Optional list of profiles for TRT builder and ONNX export.\n                            Each profile is a map of the form : {\"input id\" : [min_shape, opt_shape, max_shape], ...}.\n            dynamic_batchsize: A sequence with three elements to define the batch size range of the input for the model to be\n                               converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH].\n            [note]: If neither input_profiles nor dynamic_batchsize specified, static shapes will be used to build TRT engine.\n            use_cuda_graph: Use CUDA Graph for inference. Note: all inputs have to be the same GPU memory between calls!\n            timestamp: Optional timestamp to rebuild TRT engine (e.g. 
if config file changes).\n            fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile).\n        \"\"\"\n\n        method_vals = [\"onnx\", \"torch_trt\"]\n        if method not in method_vals:\n            raise ValueError(f\"trt_compile(): 'method' should be one of {method_vals}, got: {method}.\")\n        precision_vals = [\"fp32\", \"tf32\", \"fp16\", \"bf16\"]\n        if precision not in precision_vals:\n            raise ValueError(f\"trt_compile(): 'precision' should be one of {precision_vals}, got: {precision}.\")\n\n        self.plan_path = plan_path\n        self.precision = precision\n        self.method = method\n        self.return_dict = output_names is not None\n        self.output_names = output_names or []\n        self.output_lists = output_lists or []\n        self.profiles = input_profiles or []\n        self.dynamic_batchsize = dynamic_batchsize\n        self.export_args = export_args or {}\n        self.build_args = build_args or {}\n        self.engine: TRTEngine | None = None\n        self.use_cuda_graph = use_cuda_graph\n        self.fallback = fallback\n        self.disabled = False\n\n        self.logger = logger or get_logger(\"monai.networks.trt_compiler\")\n        self.argspec = inspect.getfullargspec(model.forward)\n\n        # Normally we read input_names from forward() but can be overridden\n        if input_names is None:\n            input_names = self.argspec.args[1:]\n        self.defaults = {}\n        if self.argspec.defaults is not None:\n            for i in range(len(self.argspec.defaults)):\n                d = self.argspec.defaults[-i - 1]\n                if d is not None:\n                    d = make_tensor(d)\n                self.defaults[self.argspec.args[-i - 1]] = d\n\n        self.input_names = input_names\n        self.old_forward = model.forward\n\n        # Force engine rebuild if older than the timestamp\n        if timestamp is not None and 
os.path.exists(self.plan_path) and os.path.getmtime(self.plan_path) < timestamp:\n            os.remove(self.plan_path)\n\n    def _inputs_to_dict(self, input_example):\n        trt_inputs = {}\n        for i, inp in enumerate(input_example):\n            input_name = self.input_names[i]\n            trt_inputs[input_name] = inp\n        return trt_inputs\n\n    def _load_engine(self):\n        \"\"\"\n        Loads TRT plan from disk and activates its execution context.\n        \"\"\"\n        try:\n            self.engine = TRTEngine(self.plan_path, self.logger)\n            # Make sure we have names correct\n            input_table = {}\n            for name in self.engine.input_names:\n                if name.startswith(\"__\") and name not in self.input_names:\n                    orig_name = name[2:]\n                else:\n                    orig_name = name\n                input_table[name] = orig_name\n            self.engine.input_table = input_table\n            self.logger.info(f\"Engine loaded, inputs:{self.engine.input_table}\")\n        except Exception as e:\n            self.logger.info(f\"Exception while loading the engine:\\n{e}\")\n\n    def forward(self, model, argv, kwargs):\n        \"\"\"\n        Main forward method:\n         Builds TRT engine if not available yet.\n         Tries to run TRT engine\n         If exception thrown and self.callback==True: falls back to original Pytorch\n\n        Args: Passing through whatever args wrapped module's forward() has\n        Returns: Passing through wrapped module's forward() return value(s)\n\n        \"\"\"\n        args = self.defaults\n        args.update(kwargs)\n        if len(argv) > 0:\n            args.update(self._inputs_to_dict(argv))\n\n        if self.engine is None and not self.disabled:\n            # Restore original forward for export\n            new_forward = model.forward\n            model.forward = self.old_forward\n            try:\n                self._load_engine()\n  
              if self.engine is None:\n                    build_args = args.copy()\n                    with torch.no_grad():\n                        self._build_and_save(model, build_args)\n                        # This will reassign input_names from the engine\n                    self._load_engine()\n                    assert self.engine is not None\n            except Exception as e:\n                if self.fallback:\n                    self.logger.info(f\"Failed to build engine: {e}\")\n                    self.disabled = True\n                else:\n                    raise e\n            if not self.disabled and not self.fallback:\n                # Delete all parameters\n                for param in model.parameters():\n                    del param\n                # Call empty_cache to release GPU memory\n                torch.cuda.empty_cache()\n            # restore TRT hook\n            model.forward = new_forward\n        # Run the engine\n        try:\n            if self.engine is not None:\n                # forward_trt is not thread safe as we do not use per-thread execution contexts\n                with lock_sm:\n                    device = torch.cuda.current_device()\n                    stream = torch.cuda.Stream(device=device)\n                    self.engine.set_inputs(unroll_input(self.input_names, args), stream.cuda_stream)\n                    self.engine.allocate_buffers(device=device)\n                    # Need this to synchronize with Torch stream\n                    stream.wait_stream(torch.cuda.current_stream())\n                    ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph)\n                    # if output_names is not None, return dictionary\n                    if not self.return_dict:\n                        ret = list(ret.values())\n                        if self.output_lists:\n                            ret = parse_groups(ret, self.output_lists)\n                        elif 
len(ret) == 1:\n                            ret = ret[0]\n                    return ret\n        except Exception as e:\n            if self.fallback:\n                self.logger.info(f\"Exception: {e}\\nFalling back to Pytorch ...\")\n            else:\n                raise e\n        return self.old_forward(*argv, **kwargs)\n\n    def _onnx_to_trt(self, onnx_path):\n        \"\"\"\n        Builds TRT engine from ONNX file at onnx_path and saves to self.plan_path\n        \"\"\"\n\n        profiles = []\n        for profile in self.profiles:\n            p = Profile()\n            for id, val in profile.items():\n                p.add(id, min=val[0], opt=val[1], max=val[2])\n            profiles.append(p)\n\n        build_args = self.build_args.copy()\n        build_args[\"tf32\"] = self.precision != \"fp32\"\n        if self.precision == \"fp16\":\n            build_args[\"fp16\"] = True\n        elif self.precision == \"bf16\":\n            build_args[\"bf16\"] = True\n\n        self.logger.info(f\"Building TensorRT engine for {onnx_path}: {self.plan_path}\")\n        network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])\n        return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args))\n\n    def _build_and_save(self, model, input_example):\n        \"\"\"\n        If TRT engine is not ready, exports model to ONNX,\n        builds TRT engine and saves serialized TRT engine to the disk.\n        Args:\n             input_example: passed to onnx.export()\n        \"\"\"\n\n        if self.engine is not None:\n            return\n\n        export_args = self.export_args\n        engine_bytes = None\n        add_casts_around_norms(model)\n\n        if self.method == \"torch_trt\":\n            enabled_precisions = [torch.float32]\n            if self.precision == \"fp16\":\n                enabled_precisions.append(torch.float16)\n            elif self.precision == \"bf16\":\n            
    enabled_precisions.append(torch.bfloat16)\n            inputs = list(input_example.values())\n\n            def get_torch_trt_input(input_shape, dynamic_batchsize):\n                min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)\n                return torch_tensorrt.Input(\n                    min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape\n                )\n\n            tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs]\n            engine_bytes = torch_tensorrt.convert_method_to_trt_engine(\n                model, \"forward\", arg_inputs=tt_inputs, enabled_precisions=enabled_precisions, **export_args\n            )\n        else:\n            dbs = self.dynamic_batchsize\n            if dbs:\n                if len(self.profiles) > 0:\n                    raise ValueError(\"ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!\")\n                if len(dbs) != 3:\n                    raise ValueError(\"dynamic_batchsize has to have len ==3 \")\n                profile = {}\n                for id, val in input_example.items():\n\n                    def add_profile(id, val):\n                        sh = val.shape\n                        if len(sh) > 0:\n                            sh = sh[1:]\n                            profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]\n\n                    if isinstance(val, list) or isinstance(val, tuple):\n                        for i in range(len(val)):\n                            add_profile(f\"{id}_{i}\", val[i])\n                    elif isinstance(val, torch.Tensor):\n                        add_profile(id, val)\n                self.profiles = [profile]\n\n            self.dynamic_axes = get_dynamic_axes(self.profiles)\n\n            if len(self.dynamic_axes) > 0:\n                export_args.update({\"dynamic_axes\": self.dynamic_axes})\n\n            # Use 
temporary directory for easy cleanup in case of external weights\n            with tempfile.TemporaryDirectory() as tmpdir:\n                unrolled_input = unroll_input(self.input_names, input_example)\n                onnx_path = str(Path(tmpdir) / \"model.onnx\")\n                self.logger.info(\n                    f\"Exporting to {onnx_path}:\\nunrolled_inputs={list(unrolled_input.keys())}\\n\"\n                    + f\"output_names={self.output_names}\\ninput_names={self.input_names}\\nexport args: {export_args}\"\n                )\n                convert_to_onnx(\n                    model,\n                    input_example,\n                    filename=onnx_path,\n                    input_names=list(unrolled_input.keys()),\n                    output_names=self.output_names,\n                    **export_args,\n                )\n                self.logger.info(\"Export to ONNX successful.\")\n                engine_bytes = self._onnx_to_trt(onnx_path)\n        if engine_bytes:\n            open(self.plan_path, \"wb\").write(engine_bytes)\n\n\ndef trt_forward(self, *argv, **kwargs):\n    \"\"\"\n    Patch function to replace original model's forward() with.\n    Redirects to TrtCompiler.forward()\n    \"\"\"\n    return self._trt_compiler.forward(self, argv, kwargs)\n\n\ndef trt_compile(\n    model: torch.nn.Module,\n    base_path: str,\n    args: dict[str, Any] | None = None,\n    submodule: str | list[str] | None = None,\n    logger: Any | None = None,\n) -> torch.nn.Module:\n    \"\"\"\n    Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook.\n    Note: TRT 10.3 is recommended for best performance. 
Some nets may even fail to work with TRT 8.x.\n        NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.\n        Review the TensorRT Support Matrix for which GPUs are supported.\n    Args:\n      model: module to patch with TrtCompiler object.\n      base_path: TRT plan(s) saved to f\"{base_path}[.{submodule}].plan\" path.\n                 dirname(base_path) must exist, base_path does not have to.\n                 If base_path does point to existing file (e.g. associated checkpoint),\n                 that file becomes a dependency - its mtime is added to args[\"timestamp\"].\n      args: Optional dict : unpacked and passed to TrtCompiler() - see TrtCompiler above for details.\n      submodule: Optional hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder']\n                  If None, TrtCompiler patch is applied to the whole model.\n                  Otherwise, submodule (or list of) is being patched.\n      logger: Optional logger for diagnostics.\n    Returns:\n      Always returns same model passed in as argument. This is for ease of use in configs.\n    \"\"\"\n\n    default_args: dict[str, Any] = {\n        \"method\": \"onnx\",\n        \"precision\": \"fp16\",\n        \"build_args\": {\"builder_optimization_level\": 5, \"precision_constraints\": \"obey\"},\n    }\n\n    default_args.update(args or {})\n    args = default_args\n\n    if trt_imported and polygraphy_imported and torch.cuda.is_available():\n        # if \"path\" filename point to existing file (e.g. 
checkpoint)\n        # it's also treated as dependency\n        if os.path.exists(base_path):\n            timestamp = int(os.path.getmtime(base_path))\n            if \"timestamp\" in args:\n                timestamp = max(int(args[\"timestamp\"]), timestamp)\n            args[\"timestamp\"] = timestamp\n\n        def wrap(model, path):\n            if not hasattr(model, \"_trt_compiler\"):\n                model.orig_forward = model.forward\n                wrapper = TrtCompiler(model, path + \".plan\", logger=logger, **args)\n                model._trt_compiler = wrapper\n                model.forward = MethodType(trt_forward, model)\n\n        def find_sub(parent, submodule):\n            idx = submodule.find(\".\")\n            # if there is \".\" in name, call recursively\n            if idx != -1:\n                parent_name = submodule[:idx]\n                parent = getattr(parent, parent_name)\n                submodule = submodule[idx + 1 :]\n                return find_sub(parent, submodule)\n            return parent, submodule\n\n        if submodule is not None:\n            if isinstance(submodule, str):\n                submodule = [submodule]\n            for s in submodule:\n                parent, sub = find_sub(model, s)\n                wrap(getattr(parent, sub), base_path + \".\" + s)\n        else:\n            wrap(model, base_path)\n    else:\n        logger = logger or get_logger(\"monai.networks.trt_compiler\")\n        logger.warning(\"TensorRT and/or polygraphy packages are not available! trt_compile() has no effect.\")\n\n    return model\n"
  },
  {
    "path": "monai/networks/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nUtilities and types for defining networks, these depend on PyTorch.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport io\nimport re\nimport tempfile\nimport warnings\nfrom collections import OrderedDict\nfrom collections.abc import Callable, Iterable, Mapping, Sequence\nfrom contextlib import contextmanager\nfrom copy import deepcopy\nfrom typing import Any\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom monai.apps.utils import get_logger\nfrom monai.config import PathLike\nfrom monai.utils.misc import ensure_tuple, save_obj, set_determinism\nfrom monai.utils.module import look_up_option, optional_import\nfrom monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor\n\nonnx, _ = optional_import(\"onnx\")\nonnxreference, _ = optional_import(\"onnx.reference\")\nonnxruntime, _ = optional_import(\"onnxruntime\")\npolygraphy, polygraphy_imported = optional_import(\"polygraphy\")\ntorch_tensorrt, _ = optional_import(\"torch_tensorrt\", \"1.4.0\")\n\n__all__ = [\n    \"one_hot\",\n    \"predict_segmentation\",\n    \"normalize_transform\",\n    \"to_norm_affine\",\n    \"CastTempType\",\n    \"normal_init\",\n    \"icnr_init\",\n    \"pixelshuffle\",\n    \"pixelunshuffle\",\n    \"eval_mode\",\n    \"train_mode\",\n    \"get_state_dict\",\n    \"copy_model_state\",\n    \"save_state\",\n    \"convert_to_onnx\",\n    
\"convert_to_torchscript\",\n    \"convert_to_trt\",\n    \"meshgrid_ij\",\n    \"meshgrid_xy\",\n    \"replace_modules\",\n    \"replace_modules_temp\",\n    \"look_up_named_module\",\n    \"set_named_module\",\n    \"has_nvfuser_instance_norm\",\n    \"get_profile_shapes\",\n]\n\nlogger = get_logger(module_name=__name__)\n\n_has_nvfuser = None\n\n\ndef get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None):\n    \"\"\"\n    Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize.\n    \"\"\"\n\n    def scale_batch_size(input_shape: Sequence[int], scale_num: int):\n        scale_shape = [*input_shape]\n        scale_shape[0] = scale_num\n        return scale_shape\n\n    # Use the dynamic batchsize range to generate the min, opt and max model input shape\n    if dynamic_batchsize:\n        min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])\n        opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])\n        max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])\n    else:\n        min_input_shape = opt_input_shape = max_input_shape = input_shape\n    return min_input_shape, opt_input_shape, max_input_shape\n\n\ndef has_nvfuser_instance_norm():\n    \"\"\"whether the current environment has InstanceNorm3dNVFuser\n    https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16\n    \"\"\"\n    global _has_nvfuser\n    if _has_nvfuser is not None:\n        return _has_nvfuser\n\n    _, _has_nvfuser = optional_import(\"apex.normalization\", name=\"InstanceNorm3dNVFuser\")\n    if not _has_nvfuser:\n        return False\n    try:\n        import importlib\n\n        importlib.import_module(\"instance_norm_nvfuser_cuda\")\n    except ImportError:\n        _has_nvfuser = False\n    return _has_nvfuser\n\n\ndef look_up_named_module(name: str, mod, print_all_options=False):\n    \"\"\"\n    get the named module in `mod` by 
the attribute name,\n    for example ``look_up_named_module(net, \"features.3.1.attn\")``\n\n    Args:\n        name: a string representing the module attribute.\n        mod: a pytorch module to be searched (in ``mod.named_modules()``).\n        print_all_options: whether to print all named modules when `name` is not found in `mod`. Defaults to False.\n\n    Returns:\n        the corresponding pytorch module's subcomponent such as ``net.features[3][1].attn``\n    \"\"\"\n    name_str = look_up_option(\n        name, {n[0] for n in mod.named_modules()}, default=None, print_all_options=print_all_options\n    )\n    if name_str is None:\n        return None\n    if name_str == \"\":\n        return mod\n    for n in name_str.split(\".\"):\n        if n.isdigit():\n            mod = mod[int(n)]\n        else:\n            n = look_up_option(n, {item[0] for item in mod.named_modules()}, default=None, print_all_options=False)\n            if n is None:\n                return None\n            mod = getattr(mod, n)\n    return mod\n\n\ndef set_named_module(mod, name: str, new_layer):\n    \"\"\"\n    look up `name` in `mod` and replace the layer with `new_layer`, return the updated `mod`.\n\n    Args:\n        mod: a pytorch module to be updated.\n        name: a string representing the target module attribute.\n        new_layer: a new module replacing the corresponding layer at ``mod.name``.\n\n    Returns:\n        an updated ``mod``\n\n    See also: :py:func:`monai.networks.utils.look_up_named_module`.\n    \"\"\"\n    mods_attr = name.rsplit(\".\", 1)\n    submods, attr = mods_attr if len(mods_attr) == 2 else (\"\", name)\n    if not attr:\n        return new_layer\n    _mod = look_up_named_module(submods, mod)\n    setattr(_mod, attr, new_layer)\n    return mod\n\n\ndef one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:\n    \"\"\"\n    For every value v in `labels`, the value in the output will be 
either 1 or 0. Each vector along the `dim`-th\n    dimension has the \"one-hot\" format, i.e., it has a total length of `num_classes`,\n    with a one and `num_class-1` zeros.\n    Note that this will include the background label, thus a binary mask should be treated as having two classes.\n\n    Args:\n        labels: input tensor of integers to be converted into the 'one-hot' format. Internally `labels` will be\n            converted into integers `labels.long()`.\n        num_classes: number of output channels, the corresponding length of `labels[dim]` will be converted to\n            `num_classes` from `1`.\n        dtype: the data type of the output one_hot label.\n        dim: the dimension to be converted to `num_classes` channels from `1` channel, should be non-negative number.\n\n    Example:\n\n    For a tensor `labels` of dimensions [B]1[spatial_dims], return a tensor of dimensions `[B]N[spatial_dims]`\n    when `num_classes=N` number of classes and `dim=1`.\n\n    .. code-block:: python\n\n        from monai.networks.utils import one_hot\n        import torch\n\n        a = torch.randint(0, 2, size=(1, 2, 2, 2))\n        out = one_hot(a, num_classes=2, dim=0)\n        print(out.shape)  # torch.Size([2, 2, 2, 2])\n\n        a = torch.randint(0, 2, size=(2, 1, 2, 2, 2))\n        out = one_hot(a, num_classes=2, dim=1)\n        print(out.shape)  # torch.Size([2, 2, 2, 2, 2])\n\n    \"\"\"\n\n    # if `dim` is bigger, add singleton dim at the end\n    if labels.ndim < dim + 1:\n        shape = list(labels.shape) + [1] * (dim + 1 - len(labels.shape))\n        labels = torch.reshape(labels, shape)\n\n    sh = list(labels.shape)\n\n    if sh[dim] != 1:\n        raise AssertionError(\"labels should have a channel with length equal to one.\")\n\n    sh[dim] = num_classes\n\n    o = torch.zeros(size=sh, dtype=dtype, device=labels.device)\n    labels = o.scatter_(dim=dim, index=labels.long(), value=1)\n\n    return labels\n\n\ndef predict_segmentation(logits: 
torch.Tensor, mutually_exclusive: bool = False, threshold: float = 0.0) -> Any:\n    \"\"\"\n    Given the logits from a network, computing the segmentation by thresholding all values above 0\n    if multi-labels task, computing the `argmax` along the channel axis if multi-classes task,\n    logits has shape `BCHW[D]`.\n\n    Args:\n        logits: raw data of model output.\n        mutually_exclusive: if True, `logits` will be converted into a binary matrix using\n            a combination of argmax, which is suitable for multi-classes task. Defaults to False.\n        threshold: thresholding the prediction values if multi-labels task.\n    \"\"\"\n    if not mutually_exclusive:\n        return (logits >= threshold).int()\n    if logits.shape[1] == 1:\n        warnings.warn(\"single channel prediction, `mutually_exclusive=True` ignored, use threshold instead.\")\n        return (logits >= threshold).int()\n    return logits.argmax(1, keepdim=True)\n\n\ndef normalize_transform(\n    shape,\n    device: torch.device | str | None = None,\n    dtype: torch.dtype | None = None,\n    align_corners: bool = False,\n    zero_centered: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Compute an affine matrix according to the input shape.\n    The transform normalizes the homogeneous image coordinates to the\n    range of `[-1, 1]`.  
Currently the following source coordinates are supported:\n\n        - `align_corners=False`, `zero_centered=False`, normalizing from ``[-0.5, d-0.5]``.\n        - `align_corners=True`, `zero_centered=False`, normalizing from ``[0, d-1]``.\n        - `align_corners=False`, `zero_centered=True`, normalizing from ``[-(d-1)/2, (d-1)/2]``.\n        - `align_corners=True`, `zero_centered=True`, normalizing from ``[-d/2, d/2]``.\n\n    Args:\n        shape: input spatial shape, a sequence of integers.\n        device: device on which the returned affine will be allocated.\n        dtype: data type of the returned affine\n        align_corners: if True, consider -1 and 1 to refer to the centers of the\n            corner pixels rather than the image corners.\n            See also: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample\n        zero_centered: whether the coordinates are normalized from a zero-centered range, default to `False`.\n            Setting this flag and `align_corners` will jointly specify the normalization source range.\n    \"\"\"\n    shape = convert_to_tensor(shape, torch.float64, device=device, wrap_sequence=True, track_meta=False)\n    norm = shape.clone().detach().to(dtype=torch.float64, device=device)  # no in-place change\n    if align_corners:\n        norm[norm <= 1.0] = 2.0\n        norm = 2.0 / (norm if zero_centered else norm - 1.0)\n        norm = torch.diag(torch.cat((norm, torch.ones((1,), dtype=torch.float64, device=device))))\n        if not zero_centered:  # else shift is 0\n            norm[:-1, -1] = -1.0\n    else:\n        norm[norm <= 0.0] = 2.0\n        norm = 2.0 / (norm - 1.0 if zero_centered else norm)\n        norm = torch.diag(torch.cat((norm, torch.ones((1,), dtype=torch.float64, device=device))))\n        if not zero_centered:\n            norm[:-1, -1] = 1.0 / shape - 1.0\n    norm = norm.unsqueeze(0).to(dtype=dtype)\n    norm.requires_grad = False\n    return norm  # type: 
ignore\n\n\ndef to_norm_affine(\n    affine: torch.Tensor,\n    src_size: Sequence[int],\n    dst_size: Sequence[int],\n    align_corners: bool = False,\n    zero_centered: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Given ``affine`` defined for coordinates in the pixel space, compute the corresponding affine\n    for the normalized coordinates.\n\n    Args:\n        affine: Nxdxd batched square matrix\n        src_size: source image spatial shape\n        dst_size: target image spatial shape\n        align_corners: if True, consider -1 and 1 to refer to the centers of the\n            corner pixels rather than the image corners.\n            See also: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample\n        zero_centered: whether the coordinates are normalized from a zero-centered range, default to `False`.\n            See also: :py:func:`monai.networks.utils.normalize_transform`.\n\n    Raises:\n        TypeError: When ``affine`` is not a ``torch.Tensor``.\n        ValueError: When ``affine`` is not Nxdxd.\n        ValueError: When ``src_size`` or ``dst_size`` dimensions differ from ``affine``.\n\n    \"\"\"\n    if not isinstance(affine, torch.Tensor):\n        raise TypeError(f\"affine must be a torch.Tensor but is {type(affine).__name__}.\")\n    if affine.ndimension() != 3 or affine.shape[1] != affine.shape[2]:\n        raise ValueError(f\"affine must be Nxdxd, got {tuple(affine.shape)}.\")\n    sr = affine.shape[1] - 1\n    if sr != len(src_size) or sr != len(dst_size):\n        raise ValueError(f\"affine suggests {sr}D, got src={len(src_size)}D, dst={len(dst_size)}D.\")\n\n    src_xform = normalize_transform(src_size, affine.device, affine.dtype, align_corners, zero_centered)\n    dst_xform = normalize_transform(dst_size, \"cpu\", affine.dtype, align_corners, zero_centered)\n    return src_xform @ affine @ convert_to_dst_type(np.linalg.inv(dst_xform.numpy()), dst=affine)[0]  # monai#5983\n\n\ndef normal_init(\n    
m, std: float = 0.02, normal_func: Callable[[torch.Tensor, float, float], Any] = torch.nn.init.normal_\n) -> None:\n    \"\"\"\n    Initialize the weight and bias tensors of `m' and its submodules to values from a normal distribution with a\n    stddev of `std'. Weight tensors of convolution and linear modules are initialized with a mean of 0, batch\n    norm modules with a mean of 1. The callable `normal_func', used to assign values, should have the same arguments\n    as its default normal_(). This can be used with `nn.Module.apply` to visit submodules of a network.\n    \"\"\"\n    cname = m.__class__.__name__\n\n    if getattr(m, \"weight\", None) is not None and (cname.find(\"Conv\") != -1 or cname.find(\"Linear\") != -1):\n        normal_func(m.weight.data, 0.0, std)\n        if getattr(m, \"bias\", None) is not None:\n            nn.init.constant_(m.bias.data, 0.0)\n\n    elif cname.find(\"BatchNorm\") != -1:\n        normal_func(m.weight.data, 1.0, std)\n        nn.init.constant_(m.bias.data, 0)\n\n\ndef icnr_init(conv, upsample_factor, init=nn.init.kaiming_normal_):\n    \"\"\"\n    ICNR initialization for 2D/3D kernels adapted from Aitken et al.,2017 , \"Checkerboard artifact free\n    sub-pixel convolution\".\n    \"\"\"\n    out_channels, in_channels, *dims = conv.weight.shape\n    scale_factor = upsample_factor ** len(dims)\n\n    oc2 = int(out_channels / scale_factor)\n\n    kernel = torch.zeros([oc2, in_channels] + dims)\n    kernel = init(kernel)\n    kernel = kernel.transpose(0, 1)\n    kernel = kernel.reshape(oc2, in_channels, -1)\n    kernel = kernel.repeat(1, 1, scale_factor)\n    kernel = kernel.reshape([in_channels, out_channels] + dims)\n    kernel = kernel.transpose(0, 1)\n    conv.weight.data.copy_(kernel)\n\n\ndef pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:\n    \"\"\"\n    Apply pixel shuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`.\n\n    See: Shi 
et al., 2016, \"Real-Time Single Image and Video Super-Resolution\n    Using an Efficient Sub-Pixel Convolutional Neural Network.\"\n\n    See: Aitken et al., 2017, \"Checkerboard artifact free sub-pixel convolution\".\n\n    Args:\n        x: Input tensor with shape BCHW[D]\n        spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D\n        scale_factor: factor to rescale the spatial dimensions by, must be >=1\n\n    Returns:\n        Reshuffled version of `x`.\n\n    Raises:\n        ValueError: When input channels of `x` are not divisible by (scale_factor ** spatial_dims)\n    \"\"\"\n    dim, factor = spatial_dims, scale_factor\n    input_size = list(x.size())\n    batch_size, channels = input_size[:2]\n    scale_divisor = factor**dim\n\n    if channels % scale_divisor != 0:\n        raise ValueError(\n            f\"Number of input channels ({channels}) must be evenly \"\n            f\"divisible by scale_factor ** dimensions ({factor}**{dim}={scale_divisor}).\"\n        )\n\n    org_channels = int(channels // scale_divisor)\n    output_size = [batch_size, org_channels] + [d * factor for d in input_size[2:]]\n\n    indices = list(range(2, 2 + 2 * dim))\n    indices = indices[dim:] + indices[:dim]\n    permute_indices = [0, 1]\n    for idx in range(dim):\n        permute_indices.extend(indices[idx::dim])\n\n    x = x.reshape([batch_size, org_channels] + [factor] * dim + input_size[2:])\n    x = x.permute(permute_indices).reshape(output_size)\n    return x\n\n\ndef pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:\n    \"\"\"\n    Apply pixel unshuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`.\n    Inverse operation of pixelshuffle.\n\n    See: Shi et al., 2016, \"Real-Time Single Image and Video Super-Resolution\n    Using an Efficient Sub-Pixel Convolutional Neural Network.\"\n\n    See: Aitken et al., 2017, \"Checkerboard artifact free sub-pixel 
convolution\".\n\n    Args:\n        x: Input tensor with shape BCHW[D]\n        spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D\n        scale_factor: factor to reduce the spatial dimensions by, must be >=1\n\n    Returns:\n        Unshuffled version of `x` with shape (B, C*(r**d), H/r, W/r) for 2D\n        or (B, C*(r**d), D/r, H/r, W/r) for 3D, where r is the scale_factor\n        and d is spatial_dims.\n\n    Raises:\n        ValueError: When spatial dimensions are not divisible by scale_factor\n    \"\"\"\n    dim, factor = spatial_dims, scale_factor\n    input_size = list(x.size())\n    batch_size, channels = input_size[:2]\n    scale_factor_mult = factor**dim\n    new_channels = channels * scale_factor_mult\n\n    if any(d % factor != 0 for d in input_size[2:]):\n        raise ValueError(\n            f\"All spatial dimensions must be divisible by factor {factor}. \" f\", spatial shape is: {input_size[2:]}\"\n        )\n    output_size = [batch_size, new_channels] + [d // factor for d in input_size[2:]]\n    reshaped_size = [batch_size, channels] + sum([[d // factor, factor] for d in input_size[2:]], [])\n\n    permute_indices = [0, 1] + [(2 * i + 3) for i in range(spatial_dims)] + [(2 * i + 2) for i in range(spatial_dims)]\n    x = x.reshape(reshaped_size).permute(permute_indices)\n    x = x.reshape(output_size)\n    return x\n\n\n@contextmanager\ndef eval_mode(*nets: nn.Module):\n    \"\"\"\n    Set network(s) to eval mode and then return to original state at the end.\n\n    Args:\n        nets: Input network(s)\n\n    Examples\n\n    .. 
code-block:: python\n\n        t=torch.rand(1,1,16,16)\n        p=torch.nn.Conv2d(1,1,3)\n        print(p.training)  # True\n        with eval_mode(p):\n            print(p.training)  # False\n            print(p(t).sum().backward())  # will correctly raise an exception as gradients are calculated\n    \"\"\"\n\n    # Get original state of network(s).\n    # Check the training attribute in case it's TensorRT based models which don't have this attribute.\n    training = [n for n in nets if hasattr(n, \"training\") and n.training]\n\n    try:\n        # set to eval mode\n        with torch.no_grad():\n            yield [n.eval() if hasattr(n, \"eval\") else n for n in nets]\n    finally:\n        # Return required networks to training\n        for n in training:\n            if hasattr(n, \"train\"):\n                n.train()\n\n\n@contextmanager\ndef train_mode(*nets: nn.Module):\n    \"\"\"\n    Set network(s) to train mode and then return to original state at the end.\n\n    Args:\n        nets: Input network(s)\n\n    Examples\n\n    .. 
code-block:: python\n\n        t=torch.rand(1,1,16,16)\n        p=torch.nn.Conv2d(1,1,3)\n        p.eval()\n        print(p.training)  # False\n        with train_mode(p):\n            print(p.training)  # True\n            print(p(t).sum().backward())  # No exception\n    \"\"\"\n\n    # Get original state of network(s)\n    # Check the training attribute in case it's TensorRT based models which don't have this attribute.\n    eval_list = [n for n in nets if hasattr(n, \"training\") and (not n.training)]\n\n    try:\n        # set to train mode\n        with torch.set_grad_enabled(True):\n            yield [n.train() if hasattr(n, \"train\") else n for n in nets]\n    finally:\n        # Return required networks to eval_list\n        for n in eval_list:\n            if hasattr(n, \"eval\"):\n                n.eval()\n\n\ndef get_state_dict(obj: torch.nn.Module | Mapping):\n    \"\"\"\n    Get the state dict of input object if has `state_dict`, otherwise, return object directly.\n    For data parallel model, automatically convert it to regular model first.\n\n    Args:\n        obj: input object to check and get the state_dict.\n\n    \"\"\"\n    if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):\n        obj = obj.module\n    return obj.state_dict() if hasattr(obj, \"state_dict\") else obj\n\n\ndef copy_model_state(\n    dst: torch.nn.Module | Mapping,\n    src: torch.nn.Module | Mapping,\n    dst_prefix=\"\",\n    mapping=None,\n    exclude_vars=None,\n    inplace=True,\n    filter_func=None,\n):\n    \"\"\"\n    Compute a module state_dict, of which the keys are the same as `dst`. The values of `dst` are overwritten\n    by the ones from `src` whenever their keys match. The method provides additional `dst_prefix` for\n    the `dst` key when matching them. 
`mapping` can be a `{\"src_key\": \"dst_key\"}` dict, indicating\n    `dst[dst_prefix + dst_key] = src[src_key]`.\n    This function is mainly to return a model state dict\n    for loading the `src` model state into the `dst` model, `src` and `dst` can have different dict keys, but\n    their corresponding values normally have the same shape.\n\n    Args:\n        dst: a pytorch module or state dict to be updated.\n        src: a pytorch module or state dict used to get the values used for the update.\n        dst_prefix: `dst` key prefix, so that `dst[dst_prefix + src_key]`\n            will be assigned to the value of `src[src_key]`.\n        mapping: a `{\"src_key\": \"dst_key\"}` dict, indicating that `dst[dst_prefix + dst_key]`\n            to be assigned to the value of `src[src_key]`.\n        exclude_vars: a regular expression to match the `dst` variable names,\n            so that their values are not overwritten by `src`.\n        inplace: whether to set the `dst` module with the updated `state_dict` via `load_state_dict`.\n            This option is only available when `dst` is a `torch.nn.Module`.\n        filter_func: a filter function used to filter the weights to be loaded.\n            See 'filter_swinunetr' in \"monai.networks.nets.swin_unetr.py\".\n\n    Examples:\n        .. 
code-block:: python\n\n            from monai.networks.nets import BasicUNet\n            from monai.networks.utils import copy_model_state\n\n            model_a = BasicUNet(in_channels=1, out_channels=4)\n            model_b = BasicUNet(in_channels=1, out_channels=2)\n            model_a_b, changed, unchanged = copy_model_state(\n                model_a, model_b, exclude_vars=\"conv_0.conv_0\", inplace=False)\n            # dst model updated: 76 of 82 variables.\n            model_a.load_state_dict(model_a_b)\n            # <All keys matched successfully>\n\n    Returns: an OrderedDict of the updated `dst` state, the changed, and unchanged keys.\n\n    \"\"\"\n    src_dict = get_state_dict(src)\n    dst_dict = OrderedDict(get_state_dict(dst))\n\n    to_skip = {s_key for s_key in src_dict if exclude_vars and re.compile(exclude_vars).search(s_key)}\n\n    # update dst with items from src\n    all_keys, updated_keys = list(dst_dict), list()\n    for s, val in src_dict.items():\n        dst_key = f\"{dst_prefix}{s}\"\n        if dst_key in dst_dict and dst_key not in to_skip and dst_dict[dst_key].shape == val.shape:\n            dst_dict[dst_key] = val\n            updated_keys.append(dst_key)\n    for s in mapping if mapping else {}:\n        dst_key = f\"{dst_prefix}{mapping[s]}\"\n        if dst_key in dst_dict and dst_key not in to_skip:\n            if dst_dict[dst_key].shape != src_dict[s].shape:\n                warnings.warn(f\"Param. 
shape changed from {dst_dict[dst_key].shape} to {src_dict[s].shape}.\")\n            dst_dict[dst_key] = src_dict[s]\n            updated_keys.append(dst_key)\n    if filter_func is not None:\n        for key, value in src_dict.items():\n            new_pair = filter_func(key, value)\n            if new_pair is not None and new_pair[0] not in to_skip:\n                dst_dict[new_pair[0]] = new_pair[1]\n                updated_keys.append(new_pair[0])\n\n    updated_keys = sorted(set(updated_keys))\n    unchanged_keys = sorted(set(all_keys).difference(updated_keys))\n    logger.info(f\"'dst' model updated: {len(updated_keys)} of {len(dst_dict)} variables.\")\n    if inplace and isinstance(dst, torch.nn.Module):\n        if isinstance(dst, (nn.DataParallel, nn.parallel.DistributedDataParallel)):\n            dst = dst.module\n        dst.load_state_dict(dst_dict)  # type: ignore\n    return dst_dict, updated_keys, unchanged_keys\n\n\ndef save_state(src: torch.nn.Module | dict, path: PathLike, **kwargs):\n    \"\"\"\n    Save the state dict of input source data with PyTorch `save`.\n    It can save `nn.Module`, `state_dict`, a dictionary of `nn.Module` or `state_dict`.\n    And automatically convert the data parallel module to regular module.\n    For example::\n\n        save_state(net, path)\n        save_state(net.state_dict(), path)\n        save_state({\"net\": net, \"opt\": opt}, path)\n        net_dp = torch.nn.DataParallel(net)\n        save_state(net_dp, path)\n\n    Refer to: https://pytorch.org/ignite/v0.4.8/generated/ignite.handlers.DiskSaver.html.\n\n    Args:\n        src: input data to save, can be `nn.Module`, `state_dict`, a dictionary of `nn.Module` or `state_dict`.\n        path: target file path to save the input object.\n        kwargs: other args for the `save_obj` except for the `obj` and `path`.\n            default `func` is `torch.save()`, details of the args:\n            https://pytorch.org/docs/stable/generated/torch.save.html.\n\n    
\"\"\"\n\n    ckpt: dict = {}\n    if isinstance(src, dict):\n        for k, v in src.items():\n            ckpt[k] = get_state_dict(v)\n    else:\n        ckpt = get_state_dict(src)\n\n    save_obj(obj=ckpt, path=path, **kwargs)\n\n\ndef convert_to_onnx(\n    model: nn.Module,\n    inputs: Sequence[Any],\n    input_names: Sequence[str] | None = None,\n    output_names: Sequence[str] | None = None,\n    opset_version: int | None = None,\n    dynamic_axes: Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None = None,\n    filename: Any | None = None,\n    verify: bool = False,\n    device: torch.device | None = None,\n    use_ort: bool = False,\n    ort_provider: Sequence[str] | None = None,\n    rtol: float = 1e-4,\n    atol: float = 0.0,\n    use_trace: bool = True,\n    do_constant_folding: bool = True,\n    constant_size_threshold: int = 16 * 1024 * 1024 * 1024,\n    **kwargs,\n):\n    \"\"\"\n    Utility to convert a model into ONNX model and optionally verify with ONNX or onnxruntime.\n    See also: https://pytorch.org/docs/stable/onnx.html for how to convert a PyTorch model to ONNX.\n\n    Args:\n        model: source PyTorch model to save.\n        inputs: input sample data used by pytorch.onnx.export. It is also used in ONNX model verification.\n        input_names: optional input names of the ONNX model.\n        output_names: optional output names of the ONNX model.\n        opset_version: version of the (ai.onnx) opset to target. Must be >= 7 and not exceed\n        the latest opset version supported by PyTorch, for more details:\n            https://github.com/onnx/onnx/blob/main/docs/Operators.md and\n            https://github.com/pytorch/pytorch/blob/master/torch/onnx/_constants.py\n        dynamic_axes: specifies axes of tensors as dynamic (i.e. known only at run-time). 
If set to None,\n            the exported model will have the shapes of all input and output tensors set to match given\n            ones, for more details: https://pytorch.org/docs/stable/onnx.html#torch.onnx.export.\n        filename: optional filename to save the ONNX model, if None, don't save the ONNX model.\n        verify: whether to verify the ONNX model with ONNX or onnxruntime.\n        device: target PyTorch device to verify the model, if None, use CUDA if available.\n        use_ort: whether to use onnxruntime to verify the model.\n        ort_provider: onnxruntime provider to use, default is [\"CPUExecutionProvider\"].\n        rtol: the relative tolerance when comparing the outputs of PyTorch model and ONNX model.\n        atol: the absolute tolerance when comparing the outputs of PyTorch model and ONNX model.\n        use_trace: whether to use `torch.jit.trace` to export the torchscript model.\n        do_constant_folding: passed to onnx.export(). If True, extra polygraphy folding pass is done.\n        constant_size_threshold: passed to polygraphy constant folding, default = 16 * 1024 * 1024 * 1024 (16 GiB).\n        kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export()\n            else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:\n            https://pytorch.org/docs/master/generated/torch.jit.script.html.\n\n    \"\"\"\n    model.eval()\n    with torch.no_grad():\n        torch_versioned_kwargs = {}\n        if use_trace:\n            # let torch.onnx.export to trace the model.\n            model_to_export = model\n            torch_versioned_kwargs = kwargs\n            if \"dynamo\" in kwargs and kwargs[\"dynamo\"] and verify:\n                torch_versioned_kwargs[\"verify\"] = verify\n                verify = False\n        else:\n            # The dynamo-based ONNX exporter (torch.export) does not support ScriptModule.\n            # PyTorch 2.6–2.8: dynamo is available but NOT the 
default; TorchScript exporter\n            #   remains the default, so we must still script the model here.\n            # PyTorch 2.9+: dynamo became the default exporter for torch.onnx.export;\n            #   pass the raw nn.Module directly—the exporter handles it via torch.export.\n            _pt_major_minor = tuple(int(x) for x in torch.__version__.split(\"+\")[0].split(\".\")[:2])\n            if _pt_major_minor >= (2, 9):\n                model_to_export = model\n            else:\n                model_to_export = torch.jit.script(model, **kwargs)\n\n        if torch.is_tensor(inputs) or isinstance(inputs, dict):\n            onnx_inputs = (inputs,)\n        else:\n            onnx_inputs = tuple(inputs)\n        temp_file = None\n        if filename is None:\n            temp_file = tempfile.NamedTemporaryFile()\n            f = temp_file.name\n        else:\n            f = filename\n        torch.onnx.export(\n            model_to_export,\n            onnx_inputs,\n            f=f,\n            input_names=input_names,\n            output_names=output_names or None,\n            dynamic_axes=dynamic_axes,\n            opset_version=opset_version,\n            do_constant_folding=do_constant_folding,\n            **torch_versioned_kwargs,\n        )\n        onnx_model = onnx.load(f)\n\n    if do_constant_folding and polygraphy_imported:\n        from polygraphy.backend.onnx.loader import fold_constants, save_onnx\n\n        onnx_model = fold_constants(onnx_model, size_threshold=constant_size_threshold)\n        save_onnx(onnx_model, f)\n\n    if verify:\n        if isinstance(inputs, dict):\n            inputs = list(inputs.values())\n\n        if device is None:\n            device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n        inputs = [i.to(device) if isinstance(i, torch.Tensor) else i for i in inputs]\n        model = model.to(device)\n\n        with torch.no_grad():\n            set_determinism(seed=0)\n            
torch_out = ensure_tuple(model(*inputs), True)\n\n        set_determinism(seed=0)\n        model_input_names = [i.name for i in onnx_model.graph.input]\n        input_dict = dict(zip(model_input_names, [i.cpu().numpy() for i in inputs]))\n        if use_ort:\n            ort_sess = onnxruntime.InferenceSession(\n                onnx_model.SerializeToString(), providers=ort_provider if ort_provider else [\"CPUExecutionProvider\"]\n            )\n            onnx_out = ort_sess.run(None, input_dict)\n        else:\n            sess = onnxreference.ReferenceEvaluator(onnx_model)\n            onnx_out = sess.run(None, input_dict)\n        set_determinism(seed=None)\n        # compare onnx/ort and PyTorch results\n        for r1, r2 in zip(torch_out, onnx_out):\n            if isinstance(r1, torch.Tensor):\n                torch.testing.assert_close(r1.cpu(), convert_to_tensor(r2, dtype=r1.dtype), rtol=rtol, atol=atol)  # type: ignore\n\n    return onnx_model\n\n\ndef convert_to_torchscript(\n    model: nn.Module,\n    filename_or_obj: Any | None = None,\n    extra_files: dict | None = None,\n    verify: bool = False,\n    inputs: Sequence[Any] | None = None,\n    device: torch.device | None = None,\n    rtol: float = 1e-4,\n    atol: float = 0.0,\n    use_trace: bool = False,\n    **kwargs,\n):\n    \"\"\"\n    Utility to convert a model into TorchScript model and save to file,\n    with optional input / output data verification.\n\n    Args:\n        model: source PyTorch model to save.\n        filename_or_obj: if not None, specify a file-like object (has to implement write and flush)\n            or a string containing a file path name to save the TorchScript model.\n        extra_files: map from filename to contents which will be stored as part of the save model file.\n            for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html.\n        verify: whether to verify the input and output of TorchScript model.\n            if 
`filename_or_obj` is not None, load the saved TorchScript model and verify.\n        inputs: input test data to verify model, should be a sequence of data, every item maps to an argument\n            of `model()` function.\n        device: target device to verify the model, if None, use CUDA if available.\n        rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.\n        atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.\n        use_trace: whether to use `torch.jit.trace` to export the TorchScript model.\n        kwargs: other arguments except `obj` for `torch.jit.script()` or `torch.jit.trace()` (if use_trace is True)\n            to convert model, for more details: https://pytorch.org/docs/master/generated/torch.jit.script.html.\n\n    \"\"\"\n    model.eval()\n    with torch.no_grad():\n        if use_trace:\n            if inputs is None:\n                raise ValueError(\"Missing input data for tracing convert.\")\n            script_module = torch.jit.trace(model, example_inputs=inputs, **kwargs)\n        else:\n            script_module = torch.jit.script(model, **kwargs)\n        if filename_or_obj is not None:\n            torch.jit.save(m=script_module, f=filename_or_obj, _extra_files=extra_files)\n\n    if verify:\n        if device is None:\n            device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        if inputs is None:\n            raise ValueError(\"Missing input data for verification.\")\n\n        inputs = [i.to(device) if isinstance(i, torch.Tensor) else i for i in inputs]\n        ts_model = torch.jit.load(filename_or_obj) if filename_or_obj is not None else script_module\n        ts_model.eval().to(device)\n        model = model.to(device)\n\n        with torch.no_grad():\n            set_determinism(seed=0)\n            torch_out = ensure_tuple(model(*inputs))\n            set_determinism(seed=0)\n            
torchscript_out = ensure_tuple(ts_model(*inputs))\n            set_determinism(seed=None)\n        # compare TorchScript and PyTorch results\n        for r1, r2 in zip(torch_out, torchscript_out):\n            if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor):\n                torch.testing.assert_close(r1, r2, rtol=rtol, atol=atol)  # type: ignore\n\n    return script_module\n\n\ndef _onnx_trt_compile(\n    onnx_model,\n    min_shape: Sequence[int],\n    opt_shape: Sequence[int],\n    max_shape: Sequence[int],\n    device: int,\n    precision: str,\n    input_names: Sequence[str] | None,\n    output_names: Sequence[str] | None,\n):\n    \"\"\"\n    This function takes an ONNX model as input, exports it to a TensorRT engine, wraps the TensorRT engine\n    to a TensorRT engine-based TorchScript model and return the TorchScript model.\n\n    Args:\n        onnx_model: the source ONNX model to compile.\n        min_shape: the minimum input shape of the converted TensorRT model.\n        opt_shape: the optimization input shape of the model, on which the TensorRT optimizes.\n        max_shape: the maximum input shape of the converted TensorRT model.\n        device: the target GPU index to convert and verify the model.\n        precision: the weight precision of the converted TensorRT engine-based TorchScript model.\n            Should be 'fp32' or 'fp16'.\n        input_names: optional input names of the ONNX model. Should be a sequence like\n            `['input_0', 'input_1', ..., 'input_N']` where N equals to the number of the\n            model inputs.\n        output_names: optional output names of the ONNX model. 
Should be a sequence like\n            `['output_0', 'output_1', ..., 'output_N']` where N equals to the number of\n            the model outputs.\n\n    \"\"\"\n    trt, _ = optional_import(\"tensorrt\", \"8.5.3\")\n\n    input_shapes = (min_shape, opt_shape, max_shape)\n    # default to an empty list to fit the `torch_tensorrt.ts.embed_engine_in_new_module` function.\n    input_names = [] if not input_names else input_names\n    output_names = [] if not output_names else output_names\n\n    # set up the TensorRT builder\n    torch.cuda.set_device(device)\n    logger = trt.Logger(trt.Logger.WARNING)\n    builder = trt.Builder(logger)\n    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))\n    profile = builder.create_optimization_profile()\n    if input_names:\n        profile.set_shape(input_names[0], *input_shapes)\n\n    # parse the ONNX model\n    parser = trt.OnnxParser(network, logger)\n    success = parser.parse(onnx_model.SerializeToString())\n    if not success:\n        parser_error_message = \"\"\n        for idx in range(parser.num_errors):\n            parser_error_message += parser.get_error(idx).desc() + \"\\n\"\n        raise Exception(f\"TensorRT cannot parse the ONNX model, due to:\\n{parser_error_message}\")\n\n    # set up the conversion configuration\n    config = builder.create_builder_config()\n    config.add_optimization_profile(profile)\n    if precision == \"fp16\":\n        config.set_flag(trt.BuilderFlag.FP16)\n    serialized_engine = builder.build_serialized_network(network, config)\n    f = io.BytesIO()\n    f.write(serialized_engine)\n\n    # wrap the serialized TensorRT engine back to a TorchScript module.\n    trt_model = torch_tensorrt.ts.embed_engine_in_new_module(\n        f.getvalue(),\n        device=torch_tensorrt.Device(f\"cuda:{device}\"),\n        input_binding_names=input_names,\n        output_binding_names=output_names,\n    )\n    return trt_model\n\n\ndef convert_to_trt(\n    
model: nn.Module,\n    precision: str,\n    input_shape: Sequence[int],\n    dynamic_batchsize: Sequence[int] | None = None,\n    use_trace: bool = False,\n    filename_or_obj: Any | None = None,\n    verify: bool = False,\n    device: int | None = None,\n    use_onnx: bool | None = False,\n    onnx_input_names: Sequence[str] | None = (\"input_0\",),\n    onnx_output_names: Sequence[str] | None = (\"output_0\",),\n    rtol: float = 1e-2,\n    atol: float = 0.0,\n    **kwargs,\n):\n    \"\"\"\n    Utility to export a model into a TensorRT engine-based TorchScript model with optional input / output data verification.\n\n    There are two ways to export a model:\n    1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.\n    2, ONNX-TensorRT way: PyTorch module ---> TorchScript module ---> ONNX model ---> TensorRT engine --->\n    TensorRT engine-based TorchScript.\n\n    When exporting through the first way, some models suffer from the slowdown problem, since Torch-TensorRT\n    may only convert a little part of the PyTorch model to the TensorRT engine. However when exporting through\n    the second way, some Python data structures like `dict` are not supported. And some TorchScript models are\n    not supported by the ONNX if exported through `torch.jit.script`.\n\n    Args:\n        model: a source PyTorch model to convert.\n        precision: the weight precision of the converted TensorRT engine based TorchScript models. Should be 'fp32' or 'fp16'.\n        input_shape: the input shape that is used to convert the model. Should be a list like [N, C, H, W] or\n            [N, C, H, W, D].\n        dynamic_batchsize: a sequence with three elements to define the batch size range of the input for the model to be\n            converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. 
After conversion, the batchsize of model\n            input should be between `MIN_BATCH` and `MAX_BATCH` and the `OPT_BATCH` is the best performance batchsize that the\n            TensorRT tries to fit. The `OPT_BATCH` should be the most frequently used input batchsize in the application,\n            default to None.\n        use_trace: whether using `torch.jit.trace` to convert the PyTorch model to a TorchScript model and then convert to\n            a TensorRT engine based TorchScript model or an ONNX model (if `use_onnx` is True), default to False.\n        filename_or_obj: if not None, specify a file-like object (has to implement write and flush) or a string containing a\n            file path name to load the TensorRT engine based TorchScript model for verifying.\n        verify: whether to verify the input and output of the TensorRT engine based TorchScript model.\n        device: the target GPU index to convert and verify the model. If None, use #0 GPU.\n        use_onnx: whether to use the ONNX-TensorRT way to export the TensorRT engine-based TorchScript model.\n        onnx_input_names: optional input names of the ONNX model. This arg is only useful when `use_onnx` is True. Should be\n            a sequence like `('input_0', 'input_1', ..., 'input_N')` where N equals to the number of the model inputs. If not\n            given, will use `('input_0',)`, which supposes the model only has one input.\n        onnx_output_names: optional output names of the ONNX model. This arg is only useful when `use_onnx` is True. Should be\n            a sequence like `('output_0', 'output_1', ..., 'output_N')` where N equals to the number of the model outputs. 
If\n            not given, will use `('output_0',)`, which supposes the model only has one output.\n        rtol: the relative tolerance when comparing the outputs between the PyTorch model and TensorRT model.\n        atol: the absolute tolerance when comparing the outputs between the PyTorch model and TensorRT model.\n        kwargs: other arguments except `module`, `inputs`, `enabled_precisions` and `device` for `torch_tensorrt.compile()`\n            to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py.\n    \"\"\"\n\n    if not torch.cuda.is_available():\n        raise Exception(\"Cannot find any GPU devices.\")\n\n    if not input_shape:\n        raise ValueError(\"Missing the input shape for model convert.\")\n\n    if not dynamic_batchsize:\n        warnings.warn(f\"There is no dynamic batch range. The converted model only takes {input_shape} shape input.\")\n\n    if (dynamic_batchsize is not None) and (len(dynamic_batchsize) != 3):\n        warnings.warn(f\"The dynamic batch range sequence should have 3 elements, but got {dynamic_batchsize} elements.\")\n\n    device = device if device else 0\n    target_device = torch.device(f\"cuda:{device}\")\n    convert_precision = torch.float32 if precision == \"fp32\" else torch.half\n    inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]\n\n    # convert the torch model to a TorchScript model on target device\n    model = model.eval().to(target_device)\n    min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)\n\n    if use_onnx:\n        # set the batch dim as dynamic\n        dynamic_axes = {k: {0: \"batchsize\"} for k in onnx_input_names} if onnx_input_names else {}\n        dynamic_axes.update({k: {0: \"batchsize\"} for k in onnx_output_names} if onnx_output_names else {})\n        ir_model = convert_to_onnx(\n            model, inputs, onnx_input_names, onnx_output_names, 
use_trace=use_trace, dynamic_axes=dynamic_axes\n        )\n        # convert the model through the ONNX-TensorRT way\n        trt_model = _onnx_trt_compile(\n            ir_model,\n            min_shape=min_input_shape,\n            opt_shape=opt_input_shape,\n            max_shape=max_input_shape,\n            device=device,\n            precision=precision,\n            input_names=onnx_input_names,\n            output_names=onnx_output_names,\n        )\n    else:\n        ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)\n        ir_model.eval()\n        # convert the model through the Torch-TensorRT way\n        ir_model.to(target_device)\n        with torch.no_grad():\n            with torch.cuda.device(device=device):\n                input_placeholder = [\n                    torch_tensorrt.Input(\n                        min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape\n                    )\n                ]\n                trt_model = torch_tensorrt.compile(\n                    ir_model,\n                    inputs=input_placeholder,\n                    enabled_precisions=convert_precision,\n                    device=torch_tensorrt.Device(f\"cuda:{device}\"),\n                    ir=\"torchscript\",\n                    **kwargs,\n                )\n\n    # verify the outputs between the TensorRT model and PyTorch model\n    if verify:\n        if inputs is None:\n            raise ValueError(\"Missing input data for verification.\")\n\n        trt_model = torch.jit.load(filename_or_obj) if filename_or_obj is not None else trt_model\n\n        with torch.no_grad():\n            set_determinism(seed=0)\n            torch_out = ensure_tuple(model(*inputs))\n            set_determinism(seed=0)\n            trt_out = ensure_tuple(trt_model(*inputs))\n            set_determinism(seed=None)\n        # compare TorchScript and PyTorch results\n        for r1, r2 in 
zip(torch_out, trt_out):\n            if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor):\n                torch.testing.assert_close(r1, r2, rtol=rtol, atol=atol)  # type: ignore\n\n    return trt_model\n\n\ndef meshgrid_ij(*tensors):\n    if torch.meshgrid.__kwdefaults__ is not None and \"indexing\" in torch.meshgrid.__kwdefaults__:\n        return torch.meshgrid(*tensors, indexing=\"ij\")  # new api pytorch after 1.10\n\n    return torch.meshgrid(*tensors)\n\n\ndef meshgrid_xy(*tensors):\n    if torch.meshgrid.__kwdefaults__ is not None and \"indexing\" in torch.meshgrid.__kwdefaults__:\n        return torch.meshgrid(*tensors, indexing=\"xy\")  # new api pytorch after 1.10\n\n    return torch.meshgrid(tensors[1], tensors[0], *tensors[2:])\n\n\ndef _replace_modules(\n    parent: torch.nn.Module,\n    name: str,\n    new_module: torch.nn.Module,\n    out: list[tuple[str, torch.nn.Module]],\n    strict_match: bool = True,\n    match_device: bool = True,\n) -> None:\n    \"\"\"\n    Helper function for :py:class:`monai.networks.utils.replace_modules`.\n    \"\"\"\n    if match_device:\n        devices = list({i.device for i in parent.parameters()})\n        # if only one device for whole of model\n        if len(devices) == 1:\n            new_module.to(devices[0])\n    idx = name.find(\".\")\n    # if there is \".\" in name, call recursively\n    if idx != -1:\n        parent_name = name[:idx]\n        parent = getattr(parent, parent_name)\n        name = name[idx + 1 :]\n        _out: list[tuple[str, torch.nn.Module]] = []\n        _replace_modules(parent, name, new_module, _out)\n        # prepend the parent name\n        out += [(f\"{parent_name}.{r[0]}\", r[1]) for r in _out]\n    # no \".\" in module name, do the actual replacing\n    else:\n        if strict_match:\n            old_module = getattr(parent, name)\n            setattr(parent, name, new_module)\n            out += [(name, old_module)]\n        else:\n            for mod_name, _ in 
parent.named_modules():\n                if name in mod_name:\n                    _replace_modules(parent, mod_name, deepcopy(new_module), out, strict_match=True)\n\n\ndef replace_modules(\n    parent: torch.nn.Module,\n    name: str,\n    new_module: torch.nn.Module,\n    strict_match: bool = True,\n    match_device: bool = True,\n) -> list[tuple[str, torch.nn.Module]]:\n    \"\"\"\n    Replace sub-module(s) in a parent module.\n\n    The name of the module to be replace can be nested e.g.,\n    `features.denseblock1.denselayer1.layers.relu1`. If this is the case (there are \".\"\n    in the module name), then this function will recursively call itself.\n\n    Args:\n        parent: module that contains the module to be replaced\n        name: name of module to be replaced. Can include \".\".\n        new_module: `torch.nn.Module` to be placed at position `name` inside `parent`. This will\n            be deep copied if `strict_match == False` multiple instances are independent.\n        strict_match: if `True`, module name must `== name`. If false then\n            `name in named_modules()` will be used. `True` can be used to change just\n            one module, whereas `False` can be used to replace all modules with similar\n            name (e.g., `relu`).\n        match_device: if `True`, the device of the new module will match the model. Requires all\n            of `parent` to be on the same device.\n\n    Returns:\n        List of tuples of replaced modules. 
Element 0 is module name, element 1 is the replaced module.\n\n    Raises:\n        AttributeError: if `strict_match` is `True` and `name` is not a named module in `parent`.\n    \"\"\"\n    out: list[tuple[str, torch.nn.Module]] = []\n    _replace_modules(parent, name, new_module, out, strict_match, match_device)\n    return out\n\n\n@contextmanager\ndef replace_modules_temp(\n    parent: torch.nn.Module,\n    name: str,\n    new_module: torch.nn.Module,\n    strict_match: bool = True,\n    match_device: bool = True,\n):\n    \"\"\"\n    Temporarily replace sub-module(s) in a parent module (context manager).\n\n    See :py:class:`monai.networks.utils.replace_modules`.\n    \"\"\"\n    replaced: list[tuple[str, torch.nn.Module]] = []\n    try:\n        # replace\n        _replace_modules(parent, name, new_module, replaced, strict_match, match_device)\n        yield\n    finally:\n        # revert\n        for name, module in replaced:\n            _replace_modules(parent, name, module, [], strict_match=True, match_device=match_device)\n\n\ndef freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None):\n    \"\"\"\n    A utility function to help freeze specific layers.\n\n    Args:\n        model: a source PyTorch model to freeze layer.\n        freeze_vars: a regular expression to match the `model` variable names,\n            so that their `requires_grad` will set to `False`.\n        exclude_vars: a regular expression to match the `model` variable names,\n            except for matched variable names, other `requires_grad` will set to `False`.\n\n    Raises:\n        ValueError: when freeze_vars and exclude_vars are both specified.\n\n    \"\"\"\n    if freeze_vars is not None and exclude_vars is not None:\n        raise ValueError(\"Incompatible values: freeze_vars and exclude_vars are both specified.\")\n    src_dict = get_state_dict(model)\n\n    frozen_keys = list()\n    if freeze_vars is not None:\n        to_freeze = {s_key for s_key in src_dict 
if freeze_vars and re.compile(freeze_vars).search(s_key)}\n        for name, param in model.named_parameters():\n            if name in to_freeze:\n                param.requires_grad = False\n                frozen_keys.append(name)\n            elif not param.requires_grad:\n                param.requires_grad = True\n                warnings.warn(\n                    f\"The freeze_vars does not include {param}, but requires_grad is False, change it to True.\"\n                )\n    if exclude_vars is not None:\n        to_exclude = {s_key for s_key in src_dict if exclude_vars and re.compile(exclude_vars).search(s_key)}\n        for name, param in model.named_parameters():\n            if name not in to_exclude:\n                param.requires_grad = False\n                frozen_keys.append(name)\n            elif not param.requires_grad:\n                param.requires_grad = True\n                warnings.warn(f\"The exclude_vars includes {param}, but requires_grad is False, change it to True.\")\n\n    logger.info(f\"{len(frozen_keys)} of {len(src_dict)} variables frozen.\")\n\n\nclass CastTempType(nn.Module):\n    \"\"\"\n    Cast the input tensor to a temporary type before applying the submodule, and then cast it back to the initial type.\n    \"\"\"\n\n    def __init__(self, initial_type, temporary_type, submodule):\n        super().__init__()\n        self.initial_type = initial_type\n        self.temporary_type = temporary_type\n        self.submodule = submodule\n\n    def forward(self, x):\n        dtype = x.dtype\n        if dtype == self.initial_type:\n            x = x.to(self.temporary_type)\n        x = self.submodule(x)\n        if dtype == self.initial_type:\n            x = x.to(self.initial_type)\n        return x\n\n\ndef cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):\n    \"\"\"\n    Utility function to cast a single tensor from from_dtype to to_dtype\n    \"\"\"\n    return x.to(dtype=to_dtype) if x.dtype == from_dtype 
else x\n\n\ndef cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):\n    \"\"\"\n    Utility function to cast all tensors in a tuple from from_dtype to to_dtype\n    \"\"\"\n    if isinstance(x, torch.Tensor):\n        return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)\n    else:\n        if isinstance(x, dict):\n            new_dict = {}\n            for k in x.keys():\n                new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)\n            return new_dict\n        elif isinstance(x, tuple):\n            return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)\n\n\nclass CastToFloat(torch.nn.Module):\n    \"\"\"\n    Class used to add autocast protection for ONNX export\n    for forward methods with single return value\n    \"\"\"\n\n    def __init__(self, mod):\n        super().__init__()\n        self.mod = mod\n\n    def forward(self, x):\n        dtype = x.dtype\n        with torch.autocast(\"cuda\", enabled=False):\n            ret = self.mod.forward(x.to(torch.float32)).to(dtype)\n        return ret\n\n\nclass CastToFloatAll(torch.nn.Module):\n    \"\"\"\n    Class used to add autocast protection for ONNX export\n    for forward methods with multiple return values\n    \"\"\"\n\n    def __init__(self, mod):\n        super().__init__()\n        self.mod = mod\n\n    def forward(self, *args):\n        from_dtype = args[0].dtype\n        with torch.autocast(\"cuda\", enabled=False):\n            ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))\n        return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)\n\n\ndef wrap_module(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:\n    \"\"\"\n    Generic function generator to replace base_t module with dest_t wrapper.\n    Args:\n        base_t : module type to replace\n        dest_t : destination module type\n    Returns:\n        swap function to 
replace base_t module with dest_t\n    \"\"\"\n\n    def expansion_fn(mod: nn.Module) -> nn.Module | None:\n        out = dest_t(mod)\n        return out\n\n    return expansion_fn\n\n\ndef simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:\n    \"\"\"\n    Generic function generator to replace base_t module with dest_t.\n    base_t and dest_t should have same atrributes. No weights are copied.\n    Args:\n        base_t : module type to replace\n        dest_t : destination module type\n    Returns:\n        swap function to replace base_t module with dest_t\n    \"\"\"\n\n    def expansion_fn(mod: nn.Module) -> nn.Module | None:\n        if not isinstance(mod, base_t):\n            return None\n        constants: Iterable = mod.__constants__  # type: ignore[assignment]\n        args = [getattr(mod, name, None) for name in constants]\n        out = dest_t(*args)\n        return out\n\n    return expansion_fn\n\n\ndef _swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module:\n    \"\"\"\n    This function swaps nested modules as specified by \"dot paths\" in mod with a desired replacement. 
This allows\n    for swapping nested modules through arbitrary levels if children\n\n    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.\n\n    \"\"\"\n    for path, new_mod in mapping.items():\n        expanded_path = path.split(\".\")\n        parent_mod = model\n        for sub_path in expanded_path[:-1]:\n            submod = parent_mod._modules[sub_path]\n            if submod is None:\n                break\n            else:\n                parent_mod = submod\n        parent_mod._modules[expanded_path[-1]] = new_mod\n\n    return model\n\n\ndef replace_modules_by_type(\n    model: nn.Module, expansions: dict[str, Callable[[nn.Module], nn.Module | None]]\n) -> nn.Module:\n    \"\"\"\n    Top-level function to replace modules in model, specified by class name with a desired replacement.\n    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.\n    Args:\n        model : top level module\n        expansions : replacement dictionary: module class name -> replacement function generator\n    Returns:\n        model, possibly modified in-place\n    \"\"\"\n    mapping: dict[str, nn.Module] = {}\n    for name, m in model.named_modules():\n        m_type = type(m).__name__\n        if m_type in expansions:\n            # print (f\"Found {m_type} in expansions ...\")\n            swapped = expansions[m_type](m)\n            if swapped:\n                mapping[name] = swapped\n\n    print(f\"Swapped {len(mapping)} modules\")\n    _swap_modules(model, mapping)\n    return model\n\n\ndef add_casts_around_norms(model: nn.Module) -> nn.Module:\n    \"\"\"\n    Top-level function to add cast wrappers around modules known to cause issues for FP16/autocast ONNX export\n    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.\n    Args:\n        model : top level module\n    Returns:\n        model, possibly modified in-place\n    \"\"\"\n    
print(\"Adding casts around norms...\")\n    cast_replacements = {\n        \"BatchNorm1d\": wrap_module(nn.BatchNorm1d, CastToFloat),\n        \"BatchNorm2d\": wrap_module(nn.BatchNorm2d, CastToFloat),\n        \"BatchNorm3d\": wrap_module(nn.BatchNorm2d, CastToFloat),\n        \"LayerNorm\": wrap_module(nn.LayerNorm, CastToFloat),\n        \"InstanceNorm1d\": wrap_module(nn.InstanceNorm1d, CastToFloat),\n        \"InstanceNorm3d\": wrap_module(nn.InstanceNorm3d, CastToFloat),\n    }\n    replace_modules_by_type(model, cast_replacements)\n    return model\n"
  },
  {
    "path": "monai/optimizers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .lr_finder import LearningRateFinder\nfrom .lr_scheduler import ExponentialLR, LinearLR, WarmupCosineSchedule\nfrom .novograd import Novograd\nfrom .utils import generate_param_groups\n"
  },
  {
    "path": "monai/optimizers/lr_finder.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport pickle\nimport types\nimport warnings\nfrom functools import partial\nfrom typing import TYPE_CHECKING, Any, Callable\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.optim import Optimizer\nfrom torch.serialization import DEFAULT_PROTOCOL\nfrom torch.utils.data import DataLoader\n\nfrom monai.networks.utils import eval_mode\nfrom monai.optimizers.lr_scheduler import ExponentialLR, LinearLR\nfrom monai.utils import StateCacher, copy_to_device, optional_import\n\nif TYPE_CHECKING:\n    import matplotlib.pyplot as plt\n\n    has_matplotlib = True\n    import tqdm\n\n    has_tqdm = True\nelse:\n    plt, has_matplotlib = optional_import(\"matplotlib.pyplot\")\n    tqdm, has_tqdm = optional_import(\"tqdm\")\n\n__all__ = [\"LearningRateFinder\"]\n\n\nclass DataLoaderIter:\n\n    def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:\n        if not isinstance(data_loader, DataLoader):\n            raise ValueError(\n                f\"Loader has unsupported type: {type(data_loader)}. 
Expected type was `torch.utils.data.DataLoader`\"\n            )\n        self.data_loader = data_loader\n        self._iterator = iter(data_loader)\n        self.image_extractor = image_extractor\n        self.label_extractor = label_extractor\n\n    @property\n    def dataset(self):\n        return self.data_loader.dataset\n\n    def inputs_labels_from_batch(self, batch_data):\n        images = self.image_extractor(batch_data)\n        labels = self.label_extractor(batch_data)\n        return images, labels\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        batch = next(self._iterator)\n        return self.inputs_labels_from_batch(batch)\n\n\nclass TrainDataLoaderIter(DataLoaderIter):\n\n    def __init__(\n        self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable, auto_reset: bool = True\n    ) -> None:\n        super().__init__(data_loader, image_extractor, label_extractor)\n        self.auto_reset = auto_reset\n\n    def __next__(self):\n        try:\n            batch = next(self._iterator)\n            inputs, labels = self.inputs_labels_from_batch(batch)\n        except StopIteration:\n            if not self.auto_reset:\n                raise\n            self._iterator = iter(self.data_loader)\n            batch = next(self._iterator)\n            inputs, labels = self.inputs_labels_from_batch(batch)\n\n        return inputs, labels\n\n\nclass ValDataLoaderIter(DataLoaderIter):\n    \"\"\"This iterator will reset itself **only** when it is acquired by\n    the syntax of normal `iterator`. That is, this iterator just works\n    like a `torch.data.DataLoader`. If you want to restart it, you\n    should use it like:\n\n        ```\n        loader_iter = ValDataLoaderIter(data_loader)\n        for batch in loader_iter:\n            ...\n\n        # `loader_iter` should run out of values now, you can restart it by:\n        # 1. 
the way we use a `torch.data.DataLoader`\n        for batch in loader_iter:        # __iter__ is called implicitly\n            ...\n\n        # 2. passing it into `iter()` manually\n        loader_iter = iter(loader_iter)  # __iter__ is called by `iter()`\n        ```\n    \"\"\"\n\n    def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:\n        super().__init__(data_loader, image_extractor, label_extractor)\n        self.run_limit = len(self.data_loader)\n        self.run_counter = 0\n\n    def __iter__(self):\n        if self.run_counter >= self.run_limit:\n            self._iterator = iter(self.data_loader)\n            self.run_counter = 0\n        return self\n\n    def __next__(self):\n        self.run_counter += 1\n        return super().__next__()\n\n\ndef default_image_extractor(x: Any) -> torch.Tensor:\n    \"\"\"Default callable for getting image from batch data.\"\"\"\n    out: torch.Tensor = x[\"image\"] if isinstance(x, dict) else x[0]\n    return out\n\n\ndef default_label_extractor(x: Any) -> torch.Tensor:\n    \"\"\"Default callable for getting label from batch data.\"\"\"\n    out: torch.Tensor = x[\"label\"] if isinstance(x, dict) else x[1]\n    return out\n\n\nclass LearningRateFinder:\n    \"\"\"Learning rate range test.\n\n    The learning rate range test increases the learning rate in a pre-training run\n    between two boundaries in a linear or exponential manner. 
It provides valuable\n    information on how well the network can be trained over a range of learning rates\n    and what is the optimal learning rate.\n\n    Example (fastai approach):\n    >>> lr_finder = LearningRateFinder(net, optimizer, criterion)\n    >>> lr_finder.range_test(data_loader, end_lr=100, num_iter=100)\n    >>> lr_finder.get_steepest_gradient()\n    >>> lr_finder.plot() # to inspect the loss-learning rate graph\n\n    Example (Leslie Smith's approach):\n    >>> lr_finder = LearningRateFinder(net, optimizer, criterion)\n    >>> lr_finder.range_test(train_loader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode=\"linear\")\n\n    Gradient accumulation is supported; example:\n    >>> train_data = ...    # prepared dataset\n    >>> desired_bs, real_bs = 32, 4         # batch size\n    >>> accumulation_steps = desired_bs // real_bs     # required steps for accumulation\n    >>> data_loader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True)\n    >>> acc_lr_finder = LearningRateFinder(net, optimizer, criterion)\n    >>> acc_lr_finder.range_test(data_loader, end_lr=10, num_iter=100, accumulation_steps=accumulation_steps)\n\n    By default, image will be extracted from data loader with x[\"image\"] and x[0], depending on whether\n    batch data is a dictionary or not (and similar behaviour for extracting the label). 
If your data loader\n    returns something other than this, pass a callable function to extract it, e.g.:\n    >>> image_extractor = lambda x: x[\"input\"]\n    >>> label_extractor = lambda x: x[100]\n    >>> lr_finder = LearningRateFinder(net, optimizer, criterion)\n    >>> lr_finder.range_test(train_loader, val_loader, image_extractor, label_extractor)\n\n    References:\n    Modified from: https://github.com/davidtvs/pytorch-lr-finder.\n    Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186\n    \"\"\"\n\n    def __init__(\n        self,\n        model: nn.Module,\n        optimizer: Optimizer,\n        criterion: torch.nn.Module,\n        device: str | torch.device | None = None,\n        memory_cache: bool = True,\n        cache_dir: str | None = None,\n        amp: bool = False,\n        pickle_module: types.ModuleType = pickle,\n        pickle_protocol: int = DEFAULT_PROTOCOL,\n        verbose: bool = True,\n    ) -> None:\n        \"\"\"Constructor.\n\n        Args:\n            model: wrapped model.\n            optimizer: wrapped optimizer.\n            criterion: wrapped loss function.\n            device: device on which to test. run a string (\"cpu\" or \"cuda\") with an\n                optional ordinal for the device type (e.g. \"cuda:X\", where is the ordinal).\n                Alternatively, can be an object representing the device on which the\n                computation will take place. Default: None, uses the same device as `model`.\n            memory_cache: if this flag is set to True, `state_dict` of\n                model and optimizer will be cached in memory. Otherwise, they will be saved\n                to files under the `cache_dir`.\n            cache_dir: path for storing temporary files. If no path is\n                specified, system-wide temporary directory is used. 
Notice that this\n                parameter will be ignored if `memory_cache` is True.\n            amp: use Automatic Mixed Precision\n            pickle_module: module used for pickling metadata and objects, default to `pickle`.\n                this arg is used by `torch.save`, for more details, please check:\n                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.\n            pickle_protocol: can be specified to override the default protocol, default to `2`.\n                this arg is used by `torch.save`, for more details, please check:\n                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.\n            verbose: verbose output\n        Returns:\n            None\n        \"\"\"\n        # Check if the optimizer is already attached to a scheduler\n        self.optimizer = optimizer\n        self._check_for_scheduler()\n\n        self.model = model\n        self.criterion = criterion\n        self.history: dict[str, list] = {\"lr\": [], \"loss\": []}\n        self.memory_cache = memory_cache\n        self.cache_dir = cache_dir\n        self.amp = amp\n        self.verbose = verbose\n\n        # Save the original state of the model and optimizer so they can be restored if\n        # needed\n        self.model_device = next(self.model.parameters()).device\n        self.state_cacher = StateCacher(\n            in_memory=memory_cache, cache_dir=cache_dir, pickle_module=pickle_module, pickle_protocol=pickle_protocol\n        )\n        self.state_cacher.store(\"model\", self.model.state_dict())\n        self.state_cacher.store(\"optimizer\", self.optimizer.state_dict())\n\n        # If device is None, use the same as the model\n        self.device = device if device else self.model_device\n\n    def reset(self) -> None:\n        \"\"\"Restores the model and optimizer to their initial states.\"\"\"\n\n        self.model.load_state_dict(self.state_cacher.retrieve(\"model\"))\n        
self.optimizer.load_state_dict(self.state_cacher.retrieve(\"optimizer\"))\n        self.model.to(self.model_device)\n\n    def range_test(\n        self,\n        train_loader: DataLoader,\n        val_loader: DataLoader | None = None,\n        image_extractor: Callable = default_image_extractor,\n        label_extractor: Callable = default_label_extractor,\n        start_lr: float | None = None,\n        end_lr: float = 10.0,\n        num_iter: int = 100,\n        step_mode: str = \"exp\",\n        smooth_f: float = 0.05,\n        diverge_th: int = 5,\n        accumulation_steps: int = 1,\n        non_blocking_transfer: bool = True,\n        auto_reset: bool = True,\n    ) -> None:\n        \"\"\"Performs the learning rate range test.\n\n        Args:\n            train_loader: training set data loader.\n            val_loader: validation data loader (if desired).\n            image_extractor: callable function to get the image from a batch of data.\n                Default: `x[\"image\"] if isinstance(x, dict) else x[0]`.\n            label_extractor: callable function to get the label from a batch of data.\n                Default: `x[\"label\"] if isinstance(x, dict) else x[1]`.\n            start_lr : the starting learning rate for the range test.\n                The default is the optimizer's learning rate.\n            end_lr: the maximum learning rate to test. The test may stop earlier than\n                this if the result starts diverging.\n            num_iter: the max number of iterations for test.\n            step_mode: schedule for increasing learning rate: (`linear` or `exp`).\n            smooth_f: the loss smoothing factor within the `[0, 1[` interval. Disabled\n                if set to `0`, otherwise loss is smoothed using exponential smoothing.\n            diverge_th: test is stopped when loss surpasses threshold:\n                `diverge_th * best_loss`.\n            accumulation_steps: steps for gradient accumulation. 
If set to `1`,\n                gradients are not accumulated.\n            non_blocking_transfer: when `True`, moves data to device asynchronously if\n                possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.\n            auto_reset: if `True`, returns model and optimizer to original states at end\n                of test.\n        Returns:\n            None\n        \"\"\"\n\n        # Reset test results\n        self.history = {\"lr\": [], \"loss\": []}\n        best_loss = -float(\"inf\")\n\n        # Move the model to the proper device\n        self.model.to(self.device)\n\n        # Check if the optimizer is already attached to a scheduler\n        self._check_for_scheduler()\n\n        # Set the starting learning rate\n        if start_lr:\n            self._set_learning_rate(start_lr)\n\n        # Check number of iterations\n        if num_iter <= 1:\n            raise ValueError(\"`num_iter` must be larger than 1\")\n\n        # Initialize the proper learning rate policy\n        lr_schedule: ExponentialLR | LinearLR\n        if step_mode.lower() == \"exp\":\n            lr_schedule = ExponentialLR(self.optimizer, end_lr, num_iter)\n        elif step_mode.lower() == \"linear\":\n            lr_schedule = LinearLR(self.optimizer, end_lr, num_iter)\n        else:\n            raise ValueError(f\"expected one of (exp, linear), got {step_mode}\")\n\n        if smooth_f < 0 or smooth_f >= 1:\n            raise ValueError(\"smooth_f is outside the range [0, 1[\")\n\n        # Create an iterator to get data batch by batch\n        train_iter = TrainDataLoaderIter(train_loader, image_extractor, label_extractor)\n        if val_loader:\n            val_iter = ValDataLoaderIter(val_loader, image_extractor, label_extractor)\n\n        trange: partial[tqdm.trange] | type[range]\n        if self.verbose and has_tqdm:\n            trange = partial(tqdm.trange, desc=\"Computing optimal learning rate\")\n            tprint = tqdm.tqdm.write\n  
      else:\n            trange = range\n            tprint = print\n\n        for iteration in trange(num_iter):\n            if self.verbose and not has_tqdm:\n                print(f\"Computing optimal learning rate, iteration {iteration + 1}/{num_iter}\")\n\n            # Train on batch and retrieve loss\n            loss = self._train_batch(train_iter, accumulation_steps, non_blocking_transfer=non_blocking_transfer)\n            if val_loader:\n                loss = self._validate(val_iter, non_blocking_transfer=non_blocking_transfer)\n\n            # Update the learning rate\n            self.history[\"lr\"].append(lr_schedule.get_lr()[0])\n            lr_schedule.step()\n\n            # Track the best loss and smooth it if smooth_f is specified\n            if iteration == 0:\n                best_loss = loss\n            else:\n                if smooth_f > 0:\n                    loss = smooth_f * loss + (1 - smooth_f) * self.history[\"loss\"][-1]\n                if loss < best_loss:\n                    best_loss = loss\n\n            # Check if the loss has diverged; if it has, stop the test\n            self.history[\"loss\"].append(loss)\n            if loss > diverge_th * best_loss:\n                if self.verbose:\n                    tprint(\"Stopping early, the loss has diverged\")\n                break\n\n        if auto_reset:\n            if self.verbose:\n                print(\"Resetting model and optimizer\")\n            self.reset()\n\n    def _set_learning_rate(self, new_lrs: float | list) -> None:\n        \"\"\"Set learning rate(s) for optimizer.\"\"\"\n        if not isinstance(new_lrs, list):\n            new_lrs = [new_lrs] * len(self.optimizer.param_groups)\n        if len(new_lrs) != len(self.optimizer.param_groups):\n            raise ValueError(\n                \"Length of `new_lrs` is not equal to the number of parameter groups \" + \"in the given optimizer\"\n            )\n\n        for param_group, new_lr in 
zip(self.optimizer.param_groups, new_lrs):\n            param_group[\"lr\"] = new_lr\n\n    def _check_for_scheduler(self):\n        \"\"\"Check optimizer doesn't already have scheduler.\"\"\"\n        for param_group in self.optimizer.param_groups:\n            if \"initial_lr\" in param_group:\n                raise RuntimeError(\"Optimizer already has a scheduler attached to it\")\n\n    def _train_batch(\n        self, train_iter: TrainDataLoaderIter, accumulation_steps: int, non_blocking_transfer: bool = True\n    ) -> float:\n        self.model.train()\n        total_loss = 0\n\n        self.optimizer.zero_grad()\n        for i in range(accumulation_steps):\n            inputs, labels = next(train_iter)\n            inputs, labels = copy_to_device([inputs, labels], device=self.device, non_blocking=non_blocking_transfer)\n\n            # Forward pass\n            outputs = self.model(inputs)\n            loss = self.criterion(outputs, labels)\n\n            # Loss should be averaged in each step\n            loss /= accumulation_steps\n\n            # Backward pass\n            if self.amp and hasattr(self.optimizer, \"_amp_stash\"):\n                # For minor performance optimization, see also:\n                # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations\n                delay_unscale = ((i + 1) % accumulation_steps) != 0\n\n                with torch.cuda.amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:  # type: ignore\n                    scaled_loss.backward()\n            else:\n                loss.backward()\n\n            total_loss += loss.item()\n\n        self.optimizer.step()\n\n        return total_loss\n\n    def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer: bool = True) -> float:\n        # Set model to evaluation mode and disable gradient computation\n        running_loss = 0\n        with eval_mode(self.model):\n            for inputs, labels in 
val_iter:\n                # Copy data to the correct device\n                inputs, labels = copy_to_device(\n                    [inputs, labels], device=self.device, non_blocking=non_blocking_transfer\n                )\n\n                # Forward pass and loss computation\n                outputs = self.model(inputs)\n                loss = self.criterion(outputs, labels)\n                running_loss += loss.item() * len(labels)\n\n        return running_loss / len(val_iter.dataset)\n\n    def get_lrs_and_losses(self, skip_start: int = 0, skip_end: int = 0) -> tuple[list, list]:\n        \"\"\"Get learning rates and their corresponding losses\n\n        Args:\n            skip_start: number of batches to trim from the start.\n            skip_end: number of batches to trim from the end.\n        \"\"\"\n        if skip_start < 0:\n            raise ValueError(\"skip_start cannot be negative\")\n        if skip_end < 0:\n            raise ValueError(\"skip_end cannot be negative\")\n\n        lrs = self.history[\"lr\"]\n        losses = self.history[\"loss\"]\n        end_idx = len(lrs) - skip_end - 1\n        lrs = lrs[skip_start:end_idx]\n        losses = losses[skip_start:end_idx]\n\n        return lrs, losses\n\n    def get_steepest_gradient(self, skip_start: int = 0, skip_end: int = 0) -> tuple[float, float] | tuple[None, None]:\n        \"\"\"Get learning rate which has steepest gradient and its corresponding loss\n\n        Args:\n            skip_start: number of batches to trim from the start.\n            skip_end: number of batches to trim from the end.\n\n        Returns:\n            Learning rate which has steepest gradient and its corresponding loss\n        \"\"\"\n        lrs, losses = self.get_lrs_and_losses(skip_start, skip_end)\n\n        try:\n            min_grad_idx = np.gradient(np.array(losses)).argmin()\n            return lrs[min_grad_idx], losses[min_grad_idx]\n        except ValueError:\n            print(\"Failed to compute the 
gradients, there might not be enough points.\")\n            return None, None\n\n    def plot(\n        self,\n        skip_start: int = 0,\n        skip_end: int = 0,\n        log_lr: bool = True,\n        ax: Any | None = None,\n        steepest_lr: bool = True,\n    ) -> Any | None:\n        \"\"\"Plots the learning rate range test.\n\n        Args:\n            skip_start: number of batches to trim from the start.\n            skip_end: number of batches to trim from the end.\n            log_lr: True to plot the learning rate in a logarithmic\n                scale; otherwise, plotted in a linear scale.\n            ax: the plot is created in the specified matplotlib axes object and the\n                figure is not shown. If `None`, then the figure and axes object are\n                created in this method and the figure is shown.\n            steepest_lr: plot the learning rate which had the steepest gradient.\n\n        Returns:\n            The `matplotlib.axes.Axes` object that contains the plot. 
Returns `None` if\n            `matplotlib` is not installed.\n        \"\"\"\n        if not has_matplotlib:\n            warnings.warn(\"Matplotlib is missing, can't plot result\")\n            return None\n\n        lrs, losses = self.get_lrs_and_losses(skip_start, skip_end)\n\n        # Create the figure and axes object if axes was not already given\n        fig = None\n        if ax is None:\n            fig, ax = plt.subplots()\n\n        # Plot loss as a function of the learning rate\n        ax.plot(lrs, losses)\n\n        # Plot the LR with steepest gradient\n        if steepest_lr:\n            lr_at_steepest_grad, loss_at_steepest_grad = self.get_steepest_gradient(skip_start, skip_end)\n            if lr_at_steepest_grad is not None and loss_at_steepest_grad is not None:\n                ax.scatter(\n                    lr_at_steepest_grad,\n                    loss_at_steepest_grad,\n                    s=75,\n                    marker=\"o\",\n                    color=\"red\",\n                    zorder=3,\n                    label=\"steepest gradient\",\n                )\n                ax.legend()\n\n        if log_lr:\n            ax.set_xscale(\"log\")\n        ax.set_xlabel(\"Learning rate\")\n        ax.set_ylabel(\"Loss\")\n\n        # Show only if the figure was created internally\n        if fig is not None:\n            plt.show()\n\n        return ax\n"
  },
  {
    "path": "monai/optimizers/lr_scheduler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\n\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LambdaLR, _LRScheduler\n\n__all__ = [\"LinearLR\", \"ExponentialLR\"]\n\n\nclass _LRSchedulerMONAI(_LRScheduler):\n    \"\"\"Base class for increasing the learning rate between two boundaries over a number\n    of iterations\"\"\"\n\n    def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None:\n        \"\"\"\n        Args:\n            optimizer: wrapped optimizer.\n            end_lr: the final learning rate.\n            num_iter: the number of iterations over which the test occurs.\n            last_epoch: the index of last epoch.\n        Returns:\n            None\n        \"\"\"\n        self.end_lr = end_lr\n        self.num_iter = num_iter\n        super().__init__(optimizer, last_epoch)\n\n\nclass LinearLR(_LRSchedulerMONAI):\n    \"\"\"Linearly increases the learning rate between two boundaries over a number of\n    iterations.\n    \"\"\"\n\n    def get_lr(self):\n        r = self.last_epoch / (self.num_iter - 1)\n        return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]\n\n\nclass ExponentialLR(_LRSchedulerMONAI):\n    \"\"\"Exponentially increases the learning rate between two boundaries over a number of\n    iterations.\n    \"\"\"\n\n    def get_lr(self):\n        
r = self.last_epoch / (self.num_iter - 1)\n        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]\n\n\nclass WarmupCosineSchedule(LambdaLR):\n    \"\"\"Linear warmup and then cosine decay.\n    Based on https://huggingface.co/ implementation.\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        warmup_steps: int,\n        t_total: int,\n        end_lr: float = 0.0,\n        cycles: float = 0.5,\n        last_epoch: int = -1,\n        warmup_multiplier: float = 0,\n    ) -> None:\n        \"\"\"\n        Args:\n            optimizer: wrapped optimizer.\n            warmup_steps: number of warmup iterations.\n            t_total: total number of training iterations.\n            end_lr: the final learning rate. Defaults to 0.0.\n            cycles: cosine cycles parameter.\n            last_epoch: the index of last epoch.\n            warmup_multiplier: if provided, starts the linear warmup from this fraction of the initial lr.\n                Must be in 0..1 interval. 
Defaults to 0\n        Returns:\n            None\n        \"\"\"\n        self.warmup_steps = min(max(warmup_steps, 0), t_total)\n        self.warmup_multiplier = warmup_multiplier\n        self.t_total = t_total\n        self.cycles = cycles\n        self.end_lr = end_lr\n        if warmup_multiplier < 0 or warmup_multiplier > 1:\n            raise ValueError(\"warmup_multiplier must be in 0..1 range\")\n        super().__init__(optimizer, self.lr_lambda, last_epoch)\n\n    def lr_lambda(self, step):\n        if step < self.warmup_steps:\n            f = float(step) / float(max(1.0, self.warmup_steps))\n            return self.warmup_multiplier + (1 - self.warmup_multiplier) * f\n        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))\n        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))\n\n    def get_lr(self):\n        current_lr = [base_lr * lmbda(self.last_epoch) for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]\n        if self.last_epoch < self.warmup_steps:\n            return current_lr\n        else:\n            return [max(self.end_lr, _current_lr) for _current_lr in current_lr]\n"
  },
  {
    "path": "monai/optimizers/novograd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Iterable\nfrom typing import TypeVar\n\nimport torch\nfrom torch.optim import Optimizer\n\nT = TypeVar(\"T\")\n\n\nclass Novograd(Optimizer):\n    \"\"\"\n    Novograd based on `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks\n    <https://arxiv.org/pdf/1905.11286.pdf>`_.\n    The code is adapted from the implementations in `Jasper for PyTorch\n    <https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper/common/optimizers.py>`_,\n    and `OpenSeq2Seq <https://github.com/NVIDIA/OpenSeq2Seq/blob/master/open_seq2seq/optimizers/novograd.py>`_.\n\n    Args:\n        params: iterable of parameters to optimize or dicts defining parameter groups.\n        lr: learning rate. Defaults to 1e-3.\n        betas: coefficients used for computing running averages of gradient and its square. Defaults to (0.9, 0.98).\n        eps: term added to the denominator to improve numerical stability. Defaults to 1e-8.\n        weight_decay: weight decay (L2 penalty). Defaults to 0.\n        grad_averaging: gradient averaging. Defaults to ``False``.\n        amsgrad: whether to use the AMSGrad variant of this algorithm from the paper\n            `On the Convergence of Adam and Beyond <https://arxiv.org/pdf/1904.09237.pdf>`_. 
Defaults to ``False``.\n    \"\"\"\n\n    def __init__(\n        self,\n        params: Iterable,\n        lr: float = 1e-3,\n        betas: tuple[float, float] = (0.9, 0.98),\n        eps: float = 1e-8,\n        weight_decay: float = 0,\n        grad_averaging: bool = False,\n        amsgrad: bool = False,\n    ):\n        if 0.0 > lr:\n            raise ValueError(f\"Invalid learning rate: {lr}\")\n        if 0.0 > eps:\n            raise ValueError(f\"Invalid epsilon value: {eps}\")\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(f\"Invalid beta parameter at index 0: {betas[0]}\")\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(f\"Invalid beta parameter at index 1: {betas[1]}\")\n        if 0.0 > weight_decay:\n            raise ValueError(f\"Invalid weight_decay value: {weight_decay}\")\n        defaults = dict(\n            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, amsgrad=amsgrad\n        )\n\n        super().__init__(params, defaults)\n\n    def __setstate__(self, state):\n        super().__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault(\"amsgrad\", False)\n\n    def step(self, closure: Callable[[], T] | None = None) -> T | None:  # type: ignore\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure: A closure that reevaluates the model and returns the loss. 
Defaults to ``None``.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data\n                if grad.is_sparse:\n                    raise RuntimeError(\"Sparse gradients are not supported.\")\n                amsgrad = group[\"amsgrad\"]\n\n                state = self.state[p]\n\n                # State initialization\n                if len(state) == 0:\n                    state[\"step\"] = 0\n                    # Exponential moving average of gradient values\n                    state[\"exp_avg\"] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state[\"exp_avg_sq\"] = torch.zeros([]).to(state[\"exp_avg\"].device)\n                    if amsgrad:\n                        # Maintains max of all exp. moving avg. of sq. grad. values\n                        state[\"max_exp_avg_sq\"] = torch.zeros([]).to(state[\"exp_avg\"].device)\n\n                exp_avg, exp_avg_sq = state[\"exp_avg\"], state[\"exp_avg_sq\"]\n                if amsgrad:\n                    max_exp_avg_sq = state[\"max_exp_avg_sq\"]\n                beta1, beta2 = group[\"betas\"]\n\n                state[\"step\"] += 1\n\n                norm = torch.sum(torch.pow(grad, 2))\n\n                if exp_avg_sq == 0:\n                    exp_avg_sq.copy_(norm)\n                else:\n                    exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)\n\n                if amsgrad:\n                    # Maintains the maximum of all 2nd moment running avg. till now\n                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)\n                    # Use the max. for normalizing running avg. 
of gradient\n                    denom = max_exp_avg_sq.sqrt().add_(group[\"eps\"])\n                else:\n                    denom = exp_avg_sq.sqrt().add_(group[\"eps\"])\n\n                grad.div_(denom)\n                if group[\"weight_decay\"] != 0:\n                    grad.add_(p.data, alpha=group[\"weight_decay\"])\n                if group[\"grad_averaging\"]:\n                    grad.mul_(1 - beta1)\n                exp_avg.mul_(beta1).add_(grad)\n\n                p.data.add_(exp_avg, alpha=-group[\"lr\"])\n\n        return loss\n"
  },
  {
    "path": "monai/optimizers/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Sequence\n\nimport torch\n\nfrom monai.utils import ensure_tuple, ensure_tuple_rep\n\n__all__ = [\"generate_param_groups\"]\n\n\ndef generate_param_groups(\n    network: torch.nn.Module,\n    layer_matches: Sequence[Callable],\n    match_types: Sequence[str],\n    lr_values: Sequence[float],\n    include_others: bool = True,\n) -> list[dict]:\n    \"\"\"\n    Utility function to generate parameter groups with different LR values for optimizer.\n    The output parameter groups have the same order as `layer_match` functions.\n\n    Args:\n        network: source network to generate parameter groups from.\n        layer_matches: a list of callable functions to select or filter out network layer groups,\n            for \"select\" type, the input will be the `network`, for \"filter\" type,\n            the input will be every item of `network.named_parameters()`.\n            for \"select\", the parameters will be\n            `select_func(network).parameters()`.\n            for \"filter\", the parameters will be\n            `(x[1] for x in filter(f, network.named_parameters()))`\n        match_types: a list of tags to identify the matching type corresponding to the `layer_matches` functions,\n            can be \"select\" or \"filter\".\n        lr_values: a list of LR values corresponding to the 
`layer_matches` functions.\n        include_others: whether to include the remaining layers as the last group, defaults to True.\n\n    It's mainly used to set different LR values for different network elements, for example:\n\n    .. code-block:: python\n\n        net = Unet(spatial_dims=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1])\n        print(net)  # print out network components to select expected items\n        print(net.named_parameters())  # print out all the named parameters to filter out expected items\n        params = generate_param_groups(\n            network=net,\n            layer_matches=[lambda x: x.model[0], lambda x: \"2.0.conv\" in x[0]],\n            match_types=[\"select\", \"filter\"],\n            lr_values=[1e-2, 1e-3],\n        )\n        # the groups will be a list of dictionaries:\n        # [{'params': <generator object Module.parameters at 0x7f9090a70bf8>, 'lr': 0.01},\n        #  {'params': <filter object at 0x7f9088fd0dd8>, 'lr': 0.001},\n        #  {'params': <filter object at 0x7f9088fd0da0>}]\n        optimizer = torch.optim.Adam(params, 1e-4)\n\n    \"\"\"\n    layer_matches = ensure_tuple(layer_matches)\n    match_types = ensure_tuple_rep(match_types, len(layer_matches))\n    lr_values = ensure_tuple_rep(lr_values, len(layer_matches))\n\n    def _get_select(f):\n\n        def _select():\n            return f(network).parameters()\n\n        return _select\n\n    def _get_filter(f):\n\n        def _filter():\n            # should eventually generate a list of network parameters\n            return (x[1] for x in filter(f, network.named_parameters()))\n\n        return _filter\n\n    params = []\n    _layers = []\n    for func, ty, lr in zip(layer_matches, match_types, lr_values):\n        if ty.lower() == \"select\":\n            layer_params = _get_select(func)\n        elif ty.lower() == \"filter\":\n            layer_params = _get_filter(func)\n        else:\n            raise ValueError(f\"unsupported 
layer match type: {ty}.\")\n\n        params.append({\"params\": layer_params(), \"lr\": lr})\n        _layers.extend([id(x) for x in layer_params()])\n\n    if include_others:\n        params.append({\"params\": filter(lambda p: id(p) not in _layers, network.parameters())})\n\n    return params\n"
  },
  {
    "path": "monai/py.typed",
    "content": ""
  },
  {
    "path": "monai/torch.patch",
    "content": "--- /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.py\t2024-10-31 06:09:21.139938791 +0000\n+++ /usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset14.bak\t2024-10-31 06:01:50.207462739 +0000\n@@ -150,6 +150,7 @@\n     ), \"is_causal and attn_mask cannot be set at the same time\"\n     assert not enable_gqa, \"conversion of scaled_dot_product_attention not implemented if enable_gqa is True\"\n\n+    scale = symbolic_helper._maybe_get_const(scale, \"f\")\n     if symbolic_helper._is_none(scale):\n         scale = _attention_scale(g, query)\n"
  },
  {
    "path": "monai/transforms/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs\nfrom .compose import Compose, OneOf, RandomOrder, SomeOf\nfrom .croppad.array import (\n    BorderPad,\n    BoundingRect,\n    CenterScaleCrop,\n    CenterSpatialCrop,\n    Crop,\n    CropForeground,\n    DivisiblePad,\n    Pad,\n    RandCropByLabelClasses,\n    RandCropByPosNegLabel,\n    RandScaleCrop,\n    RandSpatialCrop,\n    RandSpatialCropSamples,\n    RandWeightedCrop,\n    ResizeWithPadOrCrop,\n    SpatialCrop,\n    SpatialPad,\n)\nfrom .croppad.batch import PadListDataCollate\nfrom .croppad.dictionary import (\n    BorderPadd,\n    BorderPadD,\n    BorderPadDict,\n    BoundingRectd,\n    BoundingRectD,\n    BoundingRectDict,\n    CenterScaleCropd,\n    CenterScaleCropD,\n    CenterScaleCropDict,\n    CenterSpatialCropd,\n    CenterSpatialCropD,\n    CenterSpatialCropDict,\n    Cropd,\n    CropD,\n    CropDict,\n    CropForegroundd,\n    CropForegroundD,\n    CropForegroundDict,\n    DivisiblePadd,\n    DivisiblePadD,\n    DivisiblePadDict,\n    Padd,\n    PadD,\n    PadDict,\n    RandCropByLabelClassesd,\n    RandCropByLabelClassesD,\n    RandCropByLabelClassesDict,\n    RandCropByPosNegLabeld,\n    RandCropByPosNegLabelD,\n    RandCropByPosNegLabelDict,\n    RandCropd,\n    RandCropD,\n    RandCropDict,\n    RandScaleCropd,\n    RandScaleCropD,\n    
RandScaleCropDict,\n    RandSpatialCropd,\n    RandSpatialCropD,\n    RandSpatialCropDict,\n    RandSpatialCropSamplesd,\n    RandSpatialCropSamplesD,\n    RandSpatialCropSamplesDict,\n    RandWeightedCropd,\n    RandWeightedCropD,\n    RandWeightedCropDict,\n    ResizeWithPadOrCropd,\n    ResizeWithPadOrCropD,\n    ResizeWithPadOrCropDict,\n    SpatialCropd,\n    SpatialCropD,\n    SpatialCropDict,\n    SpatialPadd,\n    SpatialPadD,\n    SpatialPadDict,\n)\nfrom .croppad.functional import crop_func, crop_or_pad_nd, pad_func, pad_nd\nfrom .intensity.array import (\n    AdjustContrast,\n    ClipIntensityPercentiles,\n    ComputeHoVerMaps,\n    DetectEnvelope,\n    ForegroundMask,\n    GaussianSharpen,\n    GaussianSmooth,\n    GibbsNoise,\n    HistogramNormalize,\n    IntensityRemap,\n    KSpaceSpikeNoise,\n    MaskIntensity,\n    MedianSmooth,\n    NormalizeIntensity,\n    RandAdjustContrast,\n    RandBiasField,\n    RandCoarseDropout,\n    RandCoarseShuffle,\n    RandCoarseTransform,\n    RandGaussianNoise,\n    RandGaussianSharpen,\n    RandGaussianSmooth,\n    RandGibbsNoise,\n    RandHistogramShift,\n    RandIntensityRemap,\n    RandKSpaceSpikeNoise,\n    RandRicianNoise,\n    RandScaleIntensity,\n    RandScaleIntensityFixedMean,\n    RandShiftIntensity,\n    RandStdShiftIntensity,\n    SavitzkyGolaySmooth,\n    ScaleIntensity,\n    ScaleIntensityFixedMean,\n    ScaleIntensityRange,\n    ScaleIntensityRangePercentiles,\n    ShiftIntensity,\n    StdShiftIntensity,\n    ThresholdIntensity,\n    UltrasoundConfidenceMapTransform,\n)\nfrom .intensity.dictionary import (\n    AdjustContrastd,\n    AdjustContrastD,\n    AdjustContrastDict,\n    ClipIntensityPercentilesd,\n    ClipIntensityPercentilesD,\n    ClipIntensityPercentilesDict,\n    ComputeHoVerMapsd,\n    ComputeHoVerMapsD,\n    ComputeHoVerMapsDict,\n    ForegroundMaskd,\n    ForegroundMaskD,\n    ForegroundMaskDict,\n    GaussianSharpend,\n    GaussianSharpenD,\n    GaussianSharpenDict,\n    
GaussianSmoothd,\n    GaussianSmoothD,\n    GaussianSmoothDict,\n    GibbsNoised,\n    GibbsNoiseD,\n    GibbsNoiseDict,\n    HistogramNormalized,\n    HistogramNormalizeD,\n    HistogramNormalizeDict,\n    KSpaceSpikeNoised,\n    KSpaceSpikeNoiseD,\n    KSpaceSpikeNoiseDict,\n    MaskIntensityd,\n    MaskIntensityD,\n    MaskIntensityDict,\n    MedianSmoothd,\n    MedianSmoothD,\n    MedianSmoothDict,\n    NormalizeIntensityd,\n    NormalizeIntensityD,\n    NormalizeIntensityDict,\n    RandAdjustContrastd,\n    RandAdjustContrastD,\n    RandAdjustContrastDict,\n    RandBiasFieldd,\n    RandBiasFieldD,\n    RandBiasFieldDict,\n    RandCoarseDropoutd,\n    RandCoarseDropoutD,\n    RandCoarseDropoutDict,\n    RandCoarseShuffled,\n    RandCoarseShuffleD,\n    RandCoarseShuffleDict,\n    RandGaussianNoised,\n    RandGaussianNoiseD,\n    RandGaussianNoiseDict,\n    RandGaussianSharpend,\n    RandGaussianSharpenD,\n    RandGaussianSharpenDict,\n    RandGaussianSmoothd,\n    RandGaussianSmoothD,\n    RandGaussianSmoothDict,\n    RandGibbsNoised,\n    RandGibbsNoiseD,\n    RandGibbsNoiseDict,\n    RandHistogramShiftd,\n    RandHistogramShiftD,\n    RandHistogramShiftDict,\n    RandKSpaceSpikeNoised,\n    RandKSpaceSpikeNoiseD,\n    RandKSpaceSpikeNoiseDict,\n    RandRicianNoised,\n    RandRicianNoiseD,\n    RandRicianNoiseDict,\n    RandScaleIntensityd,\n    RandScaleIntensityD,\n    RandScaleIntensityDict,\n    RandScaleIntensityFixedMeand,\n    RandScaleIntensityFixedMeanD,\n    RandScaleIntensityFixedMeanDict,\n    RandShiftIntensityd,\n    RandShiftIntensityD,\n    RandShiftIntensityDict,\n    RandStdShiftIntensityd,\n    RandStdShiftIntensityD,\n    RandStdShiftIntensityDict,\n    SavitzkyGolaySmoothd,\n    SavitzkyGolaySmoothD,\n    SavitzkyGolaySmoothDict,\n    ScaleIntensityd,\n    ScaleIntensityD,\n    ScaleIntensityDict,\n    ScaleIntensityRanged,\n    ScaleIntensityRangeD,\n    ScaleIntensityRangeDict,\n    ScaleIntensityRangePercentilesd,\n    
ScaleIntensityRangePercentilesD,\n    ScaleIntensityRangePercentilesDict,\n    ShiftIntensityd,\n    ShiftIntensityD,\n    ShiftIntensityDict,\n    StdShiftIntensityd,\n    StdShiftIntensityD,\n    StdShiftIntensityDict,\n    ThresholdIntensityd,\n    ThresholdIntensityD,\n    ThresholdIntensityDict,\n)\nfrom .inverse import InvertibleTransform, TraceableTransform\nfrom .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict\nfrom .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping\nfrom .io.dictionary import (\n    LoadImaged,\n    LoadImageD,\n    LoadImageDict,\n    SaveImaged,\n    SaveImageD,\n    SaveImageDict,\n    WriteFileMappingd,\n    WriteFileMappingD,\n    WriteFileMappingDict,\n)\nfrom .lazy.array import ApplyPending\nfrom .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict\nfrom .lazy.functional import apply_pending\nfrom .lazy.utils import combine_transforms, resample\nfrom .meta_utility.dictionary import (\n    FromMetaTensord,\n    FromMetaTensorD,\n    FromMetaTensorDict,\n    ToMetaTensord,\n    ToMetaTensorD,\n    ToMetaTensorDict,\n)\nfrom .nvtx import (\n    Mark,\n    Markd,\n    MarkD,\n    MarkDict,\n    RandMark,\n    RandMarkd,\n    RandMarkD,\n    RandMarkDict,\n    RandRangePop,\n    RandRangePopd,\n    RandRangePopD,\n    RandRangePopDict,\n    RandRangePush,\n    RandRangePushd,\n    RandRangePushD,\n    RandRangePushDict,\n    RangePop,\n    RangePopd,\n    RangePopD,\n    RangePopDict,\n    RangePush,\n    RangePushd,\n    RangePushD,\n    RangePushDict,\n)\nfrom .post.array import (\n    Activations,\n    AsDiscrete,\n    DistanceTransformEDT,\n    FillHoles,\n    GenerateHeatmap,\n    Invert,\n    KeepLargestConnectedComponent,\n    LabelFilter,\n    LabelToContour,\n    MeanEnsemble,\n    ProbNMS,\n    RemoveSmallObjects,\n    SobelGradients,\n    VoteEnsemble,\n)\nfrom .post.dictionary import (\n    ActivationsD,\n    Activationsd,\n    
ActivationsDict,\n    AsDiscreteD,\n    AsDiscreted,\n    AsDiscreteDict,\n    DistanceTransformEDTd,\n    DistanceTransformEDTD,\n    DistanceTransformEDTDict,\n    Ensembled,\n    EnsembleD,\n    EnsembleDict,\n    FillHolesD,\n    FillHolesd,\n    FillHolesDict,\n    GenerateHeatmapd,\n    GenerateHeatmapD,\n    GenerateHeatmapDict,\n    InvertD,\n    Invertd,\n    InvertDict,\n    KeepLargestConnectedComponentD,\n    KeepLargestConnectedComponentd,\n    KeepLargestConnectedComponentDict,\n    LabelFilterD,\n    LabelFilterd,\n    LabelFilterDict,\n    LabelToContourD,\n    LabelToContourd,\n    LabelToContourDict,\n    MeanEnsembleD,\n    MeanEnsembled,\n    MeanEnsembleDict,\n    ProbNMSD,\n    ProbNMSd,\n    ProbNMSDict,\n    RemoveSmallObjectsD,\n    RemoveSmallObjectsd,\n    RemoveSmallObjectsDict,\n    SaveClassificationD,\n    SaveClassificationd,\n    SaveClassificationDict,\n    SobelGradientsd,\n    SobelGradientsD,\n    SobelGradientsDict,\n    VoteEnsembleD,\n    VoteEnsembled,\n    VoteEnsembleDict,\n)\nfrom .regularization.array import CutMix, CutOut, MixUp\nfrom .regularization.dictionary import (\n    CutMixd,\n    CutMixD,\n    CutMixDict,\n    CutOutd,\n    CutOutD,\n    CutOutDict,\n    MixUpd,\n    MixUpD,\n    MixUpDict,\n)\nfrom .signal.array import (\n    SignalContinuousWavelet,\n    SignalFillEmpty,\n    SignalRandAddGaussianNoise,\n    SignalRandAddSine,\n    SignalRandAddSinePartial,\n    SignalRandAddSquarePulse,\n    SignalRandAddSquarePulsePartial,\n    SignalRandDrop,\n    SignalRandScale,\n    SignalRandShift,\n    SignalRemoveFrequency,\n)\nfrom .signal.dictionary import SignalFillEmptyd, SignalFillEmptyD, SignalFillEmptyDict\nfrom .smooth_field.array import (\n    RandSmoothDeform,\n    RandSmoothFieldAdjustContrast,\n    RandSmoothFieldAdjustIntensity,\n    SmoothField,\n)\nfrom .smooth_field.dictionary import (\n    RandSmoothDeformd,\n    RandSmoothDeformD,\n    RandSmoothDeformDict,\n    RandSmoothFieldAdjustContrastd,\n    
RandSmoothFieldAdjustContrastD,\n    RandSmoothFieldAdjustContrastDict,\n    RandSmoothFieldAdjustIntensityd,\n    RandSmoothFieldAdjustIntensityD,\n    RandSmoothFieldAdjustIntensityDict,\n)\nfrom .spatial.array import (\n    Affine,\n    AffineGrid,\n    ConvertBoxToPoints,\n    ConvertPointsToBoxes,\n    Flip,\n    GridDistortion,\n    GridPatch,\n    GridSplit,\n    Orientation,\n    Rand2DElastic,\n    Rand3DElastic,\n    RandAffine,\n    RandAffineGrid,\n    RandAxisFlip,\n    RandDeformGrid,\n    RandFlip,\n    RandGridDistortion,\n    RandGridPatch,\n    RandRotate,\n    RandRotate90,\n    RandSimulateLowResolution,\n    RandZoom,\n    Resample,\n    ResampleToMatch,\n    Resize,\n    Rotate,\n    Rotate90,\n    Spacing,\n    SpatialResample,\n    Zoom,\n)\nfrom .spatial.dictionary import (\n    Affined,\n    AffineD,\n    AffineDict,\n    ConvertBoxToPointsd,\n    ConvertBoxToPointsD,\n    ConvertBoxToPointsDict,\n    ConvertPointsToBoxesd,\n    ConvertPointsToBoxesD,\n    ConvertPointsToBoxesDict,\n    Flipd,\n    FlipD,\n    FlipDict,\n    GridDistortiond,\n    GridDistortionD,\n    GridDistortionDict,\n    GridPatchd,\n    GridPatchD,\n    GridPatchDict,\n    GridSplitd,\n    GridSplitD,\n    GridSplitDict,\n    Orientationd,\n    OrientationD,\n    OrientationDict,\n    Rand2DElasticd,\n    Rand2DElasticD,\n    Rand2DElasticDict,\n    Rand3DElasticd,\n    Rand3DElasticD,\n    Rand3DElasticDict,\n    RandAffined,\n    RandAffineD,\n    RandAffineDict,\n    RandAxisFlipd,\n    RandAxisFlipD,\n    RandAxisFlipDict,\n    RandFlipd,\n    RandFlipD,\n    RandFlipDict,\n    RandGridDistortiond,\n    RandGridDistortionD,\n    RandGridDistortionDict,\n    RandGridPatchd,\n    RandGridPatchD,\n    RandGridPatchDict,\n    RandRotate90d,\n    RandRotate90D,\n    RandRotate90Dict,\n    RandRotated,\n    RandRotateD,\n    RandRotateDict,\n    RandSimulateLowResolutiond,\n    RandSimulateLowResolutionD,\n    RandSimulateLowResolutionDict,\n    RandZoomd,\n    
RandZoomD,\n    RandZoomDict,\n    ResampleToMatchd,\n    ResampleToMatchD,\n    ResampleToMatchDict,\n    Resized,\n    ResizeD,\n    ResizeDict,\n    Rotate90d,\n    Rotate90D,\n    Rotate90Dict,\n    Rotated,\n    RotateD,\n    RotateDict,\n    Spacingd,\n    SpacingD,\n    SpacingDict,\n    SpatialResampled,\n    SpatialResampleD,\n    SpatialResampleDict,\n    Zoomd,\n    ZoomD,\n    ZoomDict,\n)\nfrom .spatial.functional import spatial_resample\nfrom .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ReduceTrait, ThreadUnsafe\nfrom .transform import LazyTransform, MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform\nfrom .utility.array import (\n    AddCoordinateChannels,\n    AddExtremePointsChannel,\n    ApplyTransformToPoints,\n    AsChannelLast,\n    CastToType,\n    ClassesToIndices,\n    ConvertToMultiChannelBasedOnBratsClasses,\n    CuCIM,\n    DataStats,\n    EnsureChannelFirst,\n    EnsureType,\n    FgBgToIndices,\n    FlattenSequence,\n    Identity,\n    ImageFilter,\n    IntensityStats,\n    LabelToMask,\n    Lambda,\n    MapLabelValue,\n    RandCuCIM,\n    RandIdentity,\n    RandImageFilter,\n    RandLambda,\n    RandTorchIO,\n    RandTorchVision,\n    RemoveRepeatedChannel,\n    RepeatChannel,\n    SimulateDelay,\n    SplitDim,\n    SqueezeDim,\n    ToCupy,\n    ToDevice,\n    ToNumpy,\n    ToPIL,\n    TorchIO,\n    TorchVision,\n    ToTensor,\n    Transpose,\n)\nfrom .utility.dictionary import (\n    AddCoordinateChannelsd,\n    AddCoordinateChannelsD,\n    AddCoordinateChannelsDict,\n    AddExtremePointsChanneld,\n    AddExtremePointsChannelD,\n    AddExtremePointsChannelDict,\n    ApplyTransformToPointsd,\n    ApplyTransformToPointsD,\n    ApplyTransformToPointsDict,\n    AsChannelLastd,\n    AsChannelLastD,\n    AsChannelLastDict,\n    CastToTyped,\n    CastToTypeD,\n    CastToTypeDict,\n    ClassesToIndicesd,\n    ClassesToIndicesD,\n    ClassesToIndicesDict,\n    ConcatItemsd,\n    ConcatItemsD,\n    
ConcatItemsDict,\n    ConvertToMultiChannelBasedOnBratsClassesd,\n    ConvertToMultiChannelBasedOnBratsClassesD,\n    ConvertToMultiChannelBasedOnBratsClassesDict,\n    CopyItemsd,\n    CopyItemsD,\n    CopyItemsDict,\n    CuCIMd,\n    CuCIMD,\n    CuCIMDict,\n    DataStatsd,\n    DataStatsD,\n    DataStatsDict,\n    DeleteItemsd,\n    DeleteItemsD,\n    DeleteItemsDict,\n    EnsureChannelFirstd,\n    EnsureChannelFirstD,\n    EnsureChannelFirstDict,\n    EnsureTyped,\n    EnsureTypeD,\n    EnsureTypeDict,\n    FgBgToIndicesd,\n    FgBgToIndicesD,\n    FgBgToIndicesDict,\n    FlattenSequenced,\n    FlattenSequenceD,\n    FlattenSequenceDict,\n    FlattenSubKeysd,\n    FlattenSubKeysD,\n    FlattenSubKeysDict,\n    Identityd,\n    IdentityD,\n    IdentityDict,\n    ImageFilterd,\n    ImageFilterD,\n    ImageFilterDict,\n    IntensityStatsd,\n    IntensityStatsD,\n    IntensityStatsDict,\n    LabelToMaskd,\n    LabelToMaskD,\n    LabelToMaskDict,\n    Lambdad,\n    LambdaD,\n    LambdaDict,\n    MapLabelValued,\n    MapLabelValueD,\n    MapLabelValueDict,\n    RandCuCIMd,\n    RandCuCIMD,\n    RandCuCIMDict,\n    RandImageFilterd,\n    RandImageFilterD,\n    RandImageFilterDict,\n    RandLambdad,\n    RandLambdaD,\n    RandLambdaDict,\n    RandTorchIOd,\n    RandTorchIOD,\n    RandTorchIODict,\n    RandTorchVisiond,\n    RandTorchVisionD,\n    RandTorchVisionDict,\n    RemoveRepeatedChanneld,\n    RemoveRepeatedChannelD,\n    RemoveRepeatedChannelDict,\n    RepeatChanneld,\n    RepeatChannelD,\n    RepeatChannelDict,\n    SelectItemsd,\n    SelectItemsD,\n    SelectItemsDict,\n    SimulateDelayd,\n    SimulateDelayD,\n    SimulateDelayDict,\n    SplitDimd,\n    SplitDimD,\n    SplitDimDict,\n    SqueezeDimd,\n    SqueezeDimD,\n    SqueezeDimDict,\n    ToCupyd,\n    ToCupyD,\n    ToCupyDict,\n    ToDeviced,\n    ToDeviceD,\n    ToDeviceDict,\n    ToNumpyd,\n    ToNumpyD,\n    ToNumpyDict,\n    ToPILd,\n    ToPILD,\n    ToPILDict,\n    TorchIOd,\n    TorchIOD,\n    
TorchIODict,\n    TorchVisiond,\n    TorchVisionD,\n    TorchVisionDict,\n    ToTensord,\n    ToTensorD,\n    ToTensorDict,\n    Transposed,\n    TransposeD,\n    TransposeDict,\n)\nfrom .utils import (\n    Fourier,\n    allow_missing_keys_mode,\n    attach_hook,\n    check_non_lazy_pending_ops,\n    compute_divisible_spatial_size,\n    convert_applied_interp_mode,\n    convert_pad_mode,\n    convert_to_contiguous,\n    copypaste_arrays,\n    create_control_grid,\n    create_grid,\n    create_rotate,\n    create_scale,\n    create_shear,\n    create_translate,\n    distance_transform_edt,\n    equalize_hist,\n    extreme_points_to_image,\n    generate_label_classes_crop_centers,\n    generate_pos_neg_label_crop_centers,\n    generate_spatial_bounding_box,\n    get_extreme_points,\n    get_largest_connected_component_mask,\n    get_number_image_type_conversions,\n    get_transform_backends,\n    img_bounds,\n    in_bounds,\n    is_empty,\n    is_positive,\n    map_and_generate_sampling_centers,\n    map_binary_to_indices,\n    map_classes_to_indices,\n    map_spatial_axes,\n    print_transform_backends,\n    rand_choice,\n    remove_small_objects,\n    rescale_array,\n    rescale_array_int_max,\n    rescale_instance_array,\n    reset_ops_id,\n    resize_center,\n    resolves_modes,\n    sync_meta_info,\n    weighted_patch_samples,\n    zero_margins,\n)\nfrom .utils_morphological_ops import dilate, erode\nfrom .utils_pytorch_numpy_unification import (\n    allclose,\n    any_np_pt,\n    ascontiguousarray,\n    clip,\n    concatenate,\n    cumsum,\n    floor_divide,\n    in1d,\n    isfinite,\n    isnan,\n    maximum,\n    mode,\n    moveaxis,\n    nonzero,\n    percentile,\n    ravel,\n    repeat,\n    stack,\n    unravel_index,\n    where,\n)\n"
  },
  {
    "path": "monai/transforms/adaptors.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nHow to use the adaptor function\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe key to using 'adaptor' lies in understanding the function that want to\nadapt. The 'inputs' and 'outputs' parameters take either strings, lists/tuples\nof strings or a dictionary mapping strings, depending on call signature of the\nfunction being called.\n\nThe adaptor function is written to minimise the cognitive load on the caller.\nThere should be a minimal number of cases where the caller has to set anything\non the input parameter, and for functions that return a single value, it is\nonly necessary to name the dictionary keyword to which that value is assigned.\n\nUse of `outputs`\n----------------\n\n`outputs` can take either a string, a list/tuple of string or a dict of string\nto string, depending on what the transform being adapted returns:\n\n    - If the transform returns a single argument, then outputs can be supplied a\n      string that indicates what key to assign the return value to in the\n      dictionary\n    - If the transform returns a list/tuple of values, then outputs can be supplied\n      a list/tuple of the same length. 
The strings in outputs map the return value\n      at the corresponding position to a key in the dictionary\n    - If the transform returns a dictionary of values, then outputs must be supplied\n      a dictionary that maps keys in the function's return dictionary to the\n      dictionary being passed between functions\n\nNote, the caller is free to use a more complex way of specifying the outputs\nparameter than is required. The following are synonymous and will be treated\nidentically:\n\n.. code-block:: python\n\n   # single argument\n   adaptor(MyTransform(), 'image')\n   adaptor(MyTransform(), ['image'])\n   adaptor(MyTransform(), {'image': 'image'})\n\n   # multiple arguments\n   adaptor(MyTransform(), ['image', 'label'])\n   adaptor(MyTransform(), {'image': 'image', 'label': 'label'})\n\nUse of `inputs`\n---------------\n\n`inputs` can usually be omitted when using `adaptor`. It is only required when\nthe function's parameter names do not match the names in the dictionary that is\nused to chain transform calls.\n\n.. code-block:: python\n\n    class MyTransform1:\n        def __call__(self, image):\n            # do stuff to image\n            return image + 1\n\n\n    class MyTransform2:\n        def __call__(self, img_dict):\n            # do stuff to image\n            img_dict[\"image\"] += 1\n            return img_dict\n\n\n    xform = Compose([adaptor(MyTransform1(), \"image\"), MyTransform2()])\n    d = {\"image\": 1}\n    print(xform(d))\n\n    >>> {'image': 3}\n\n.. 
code-block:: python\n\n    class MyTransform3:\n        def __call__(self, img_dict):\n            # do stuff to image\n            img_dict[\"image\"] -= 1\n            img_dict[\"segment\"] = img_dict[\"image\"]\n            return img_dict\n\n\n    class MyTransform4:\n        def __call__(self, img, seg):\n            # do stuff to image\n            img -= 1\n            seg -= 1\n            return img, seg\n\n\n    xform = Compose([MyTransform3(), adaptor(MyTransform4(), [\"img\", \"seg\"], {\"image\": \"img\", \"segment\": \"seg\"})])\n    d = {\"image\": 1}\n    print(xform(d))\n\n    >>> {'image': 0, 'segment': 0, 'img': -1, 'seg': -1}\n\nInputs:\n\n- dictionary in: None | Name maps\n- params in (match): None | Name list | Name maps\n- params in (mismatch): Name maps\n- params & `**kwargs` (match) : None | Name maps\n- params & `**kwargs` (mismatch) : Name maps\n\nOutputs:\n\n- dictionary out: None | Name maps\n- list/tuple out: list/tuple\n- variable out: string\n\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Callable\n\n__all__ = [\"adaptor\", \"apply_alias\", \"to_kwargs\", \"FunctionSignature\"]\n\n\ndef adaptor(function, outputs, inputs=None):\n\n    def must_be_types_or_none(variable_name, variable, types):\n        if variable is not None:\n            if not isinstance(variable, types):\n                raise TypeError(f\"'{variable_name}' must be None or one of {types} but is {type(variable)}\")\n\n    def must_be_types(variable_name, variable, types):\n        if not isinstance(variable, types):\n            raise TypeError(f\"'{variable_name}' must be one of {types} but is {type(variable)}\")\n\n    def map_names(ditems, input_map):\n        return {input_map(k, k): v for k, v in ditems.items()}\n\n    def map_only_names(ditems, input_map):\n        return {v: ditems[k] for k, v in input_map.items()}\n\n    def _inner(ditems):\n        sig = FunctionSignature(function)\n\n        if sig.found_kwargs:\n            
must_be_types_or_none(\"inputs\", inputs, (dict,))\n            # we just forward all arguments unless we have been provided an input map\n            if inputs is None:\n                dinputs = dict(ditems)\n            else:\n                # dict\n                dinputs = map_names(ditems, inputs)\n\n        else:\n            # no **kwargs\n            # select only items from the method signature\n            dinputs = {k: v for k, v in ditems.items() if k in sig.non_var_parameters}\n            must_be_types_or_none(\"inputs\", inputs, (str, list, tuple, dict))\n            if inputs is None:\n                pass\n            elif isinstance(inputs, str):\n                if len(sig.non_var_parameters) != 1:\n                    raise ValueError(\"if 'inputs' is a string, function may only have a single non-variadic parameter\")\n                dinputs = {inputs: ditems[inputs]}\n            elif isinstance(inputs, (list, tuple)):\n                dinputs = {k: dinputs[k] for k in inputs}\n            else:\n                # dict\n                dinputs = map_only_names(ditems, inputs)\n\n        ret = function(**dinputs)\n\n        # now the mapping back to the output dictionary depends on outputs and what was returned from the function\n        op = outputs\n        if isinstance(ret, dict):\n            must_be_types_or_none(\"outputs\", op, (dict,))\n            if op is not None:\n                ret = {v: ret[k] for k, v in op.items()}\n        elif isinstance(ret, (list, tuple)):\n            if len(ret) == 1:\n                must_be_types(\"outputs\", op, (str, list, tuple))\n            else:\n                must_be_types(\"outputs\", op, (list, tuple))\n\n            if isinstance(op, str):\n                op = [op]\n\n            if len(ret) != len(outputs):\n                raise ValueError(\"'outputs' must have the same length as the number of elements that were returned\")\n\n            ret = dict(zip(op, ret))\n        else:\n       
     must_be_types(\"outputs\", op, (str, list, tuple))\n            if isinstance(op, (list, tuple)):\n                if len(op) != 1:\n                    raise ValueError(\"'outputs' must be of length one if it is a list or tuple\")\n                op = op[0]\n            ret = {op: ret}\n\n        ditems = dict(ditems)\n        for k, v in ret.items():\n            ditems[k] = v\n\n        return ditems\n\n    return _inner\n\n\ndef apply_alias(fn, name_map):\n\n    def _inner(data):\n        # map names\n        pre_call = dict(data)\n        for _from, _to in name_map.items():\n            pre_call[_to] = pre_call.pop(_from)\n\n        # execute\n        post_call = fn(pre_call)\n\n        # map names back\n        for _from, _to in name_map.items():\n            post_call[_from] = post_call.pop(_to)\n\n        return post_call\n\n    return _inner\n\n\ndef to_kwargs(fn):\n\n    def _inner(data):\n        return fn(**data)\n\n    return _inner\n\n\nclass FunctionSignature:\n\n    def __init__(self, function: Callable) -> None:\n        import inspect\n\n        sfn = inspect.signature(function)\n        self.found_args = False\n        self.found_kwargs = False\n        self.defaults = {}\n        self.non_var_parameters = set()\n        for p in sfn.parameters.values():\n            if p.kind is inspect.Parameter.VAR_POSITIONAL:\n                self.found_args = True\n            if p.kind is inspect.Parameter.VAR_KEYWORD:\n                self.found_kwargs = True\n            else:\n                self.non_var_parameters.add(p.name)\n                self.defaults[p.name] = p.default is not p.empty\n\n    def __repr__(self) -> str:\n        s = \"<class 'FunctionSignature': found_args={}, found_kwargs={}, defaults={}\"\n        return s.format(self.found_args, self.found_kwargs, self.defaults)\n\n    def __str__(self) -> str:\n        return self.__repr__()\n"
  },
  {
    "path": "monai/transforms/compose.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of generic interfaces for MONAI transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable, Mapping, Sequence\nfrom copy import deepcopy\nfrom typing import Any\n\nimport numpy as np\n\nimport monai\nfrom monai.apps.utils import get_logger\nfrom monai.config import NdarrayOrTensor\nfrom monai.transforms.inverse import InvertibleTransform\n\n# For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)\nfrom monai.transforms.lazy.functional import apply_pending_transforms\nfrom monai.transforms.traits import ThreadUnsafe\nfrom monai.transforms.transform import LazyTransform, Randomizable, apply_transform\nfrom monai.utils import MAX_SEED, TraceKeys, TraceStatusKeys, ensure_tuple, get_seed\n\nlogger = get_logger(__name__)\n\n__all__ = [\"Compose\", \"OneOf\", \"RandomOrder\", \"SomeOf\", \"execute_compose\"]\n\n\ndef execute_compose(\n    data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],\n    transforms: Sequence[Any],\n    map_items: bool | int = True,\n    unpack_items: bool = False,\n    start: int = 0,\n    end: int | None = None,\n    lazy: bool | None = False,\n    overrides: dict | None = None,\n    threading: bool = False,\n    log_stats: bool | str = False,\n) -> NdarrayOrTensor | Sequence[NdarrayOrTensor] | 
Mapping[Any, NdarrayOrTensor]:\n    \"\"\"\n    ``execute_compose`` provides the implementation that the ``Compose`` class uses to execute a sequence\n    of transforms. As well as being used by Compose, it can be used by subclasses of\n    Compose and by code that doesn't have a Compose instance but needs to execute a\n    sequence of transforms as if it were executed by Compose. It should only be used directly\n    when it is not possible to use ``Compose.__call__`` to achieve the same goal.\n    Args:\n        data: a tensor-like object to be transformed\n        transforms: a sequence of transforms to be carried out\n        map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,\n            it can behave as follows:\n            - Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied\n              to the first level of items in `data`.\n            - If an integer is provided, it specifies the maximum level of nesting to which the transformation\n              should be recursively applied. This allows treating multi-sample transforms applied after another\n              multi-sample transform while controlling how deep the mapping goes.\n        unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.\n            defaults to `False`.\n        start: the index of the first transform to be executed. If not set, this defaults to 0\n        end: the index after the last transform to be executed. If set, the transform at index-1\n            is the last transform that is executed. If this is not set, it defaults to len(transforms)\n        lazy: whether to enable :ref:`lazy evaluation<lazy_resampling>` for lazy transforms. If False, transforms will be\n            carried out on a transform by transform basis. 
If True, all lazy transforms will\n            be executed by accumulating changes and resampling as few times as possible.\n        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden\n            when executing a pipeline. Each parameter that is compatible with a given transform is then applied\n            to that transform before it is executed. Note that overrides are currently only applied when\n            :ref:`lazy evaluation<lazy_resampling>` is enabled for the pipeline or a given transform. If lazy is False\n            they are ignored. Currently supported args are:\n            {``\"mode\"``, ``\"padding_mode\"``, ``\"dtype\"``, ``\"align_corners\"``, ``\"resample_mode\"``, ``device``}.\n        threading: whether execution is happening in a threaded environment. If set, copies are made\n            of transforms that have the ``ThreadUnsafe`` interface.\n        log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.\n            Setting this to False disables logging. 
Setting it to True enables logging to the default loggers.\n            Setting a string overrides the logger name to which logging is performed.\n\n    Returns:\n        A tensorlike, sequence of tensorlikes or dict of tensorlikes containing the result of running\n        ``data`` through the sequence of ``transforms``.\n    \"\"\"\n    end_ = len(transforms) if end is None else end\n    if start is None:\n        raise ValueError(f\"'start' ({start}) cannot be None\")\n    if start < 0:\n        raise ValueError(f\"'start' ({start}) cannot be less than 0\")\n    if start > end_:\n        raise ValueError(f\"'start' ({start}) must be less than 'end' ({end_})\")\n    if end_ > len(transforms):\n        raise ValueError(f\"'end' ({end_}) must be less than or equal to the transform count ({len(transforms)}\")\n\n    # no-op if the range is empty\n    if start == end:\n        return data\n\n    for _transform in transforms[start:end]:\n        if threading:\n            _transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform\n        data = apply_transform(\n            _transform, data, map_items, unpack_items, lazy=lazy, overrides=overrides, log_stats=log_stats\n        )\n    data = apply_pending_transforms(data, None, overrides, logger_name=log_stats)\n    return data\n\n\nclass Compose(Randomizable, InvertibleTransform, LazyTransform):\n    \"\"\"\n    ``Compose`` provides the ability to chain a series of callables together in\n    a sequential manner. Each transform in the sequence must take a single\n    argument and return a single value.\n\n    ``Compose`` can be used in two ways:\n\n    #. With a series of transforms that accept and return a single\n       ndarray / tensor / tensor-like parameter.\n    #. With a series of transforms that accept and return a dictionary that\n       contains one or more parameters. 
Such transforms must have pass-through\n       semantics that unused values in the dictionary must be copied to the return\n       dictionary. It is required that the dictionary is copied between input\n       and output of each transform.\n\n    If some transform takes a data item dictionary as input, and returns a\n    sequence of data items in the transform chain, all following transforms\n    will be applied to each item of this list if `map_items` is `True` (the\n    default).  If `map_items` is `False`, the returned sequence is passed whole\n    to the next callable in the chain.\n\n    For example:\n\n    A `Compose([transformA, transformB, transformC],\n    map_items=True)(data_dict)` could achieve the following patch-based\n    transformation on the `data_dict` input:\n\n    #. transformA normalizes the intensity of 'img' field in the `data_dict`.\n    #. transformB crops out image patches from the 'img' and 'seg' of\n       `data_dict`, and return a list of three patch samples::\n\n        {'img': 3x100x100 data, 'seg': 1x100x100 data, 'shape': (100, 100)}\n                             applying transformB\n                                 ---------->\n        [{'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},\n         {'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},\n         {'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},]\n\n    #. transformC then randomly rotates or flips 'img' and 'seg' of\n       each dictionary item in the list returned by transformB.\n\n    The composed transforms will be set the same global random seed if user called\n    `set_determinism()`.\n\n    When using the pass-through dictionary operation, you can make use of\n    :class:`monai.transforms.adaptors.adaptor` to wrap transforms that don't conform\n    to the requirements. 
This approach allows you to use transforms from\n    otherwise incompatible libraries with minimal additional work.\n\n    Note:\n\n        In many cases, Compose is not the best way to create pre-processing\n        pipelines. Pre-processing is often not a strictly sequential series of\n        operations, and much of the complexity arises when a not-sequential\n        set of functions must be called as if it were a sequence.\n\n        Example: images and labels\n        Images typically require some kind of normalization that labels do not.\n        Both are then typically augmented through the use of random rotations,\n        flips, and deformations.\n        Compose can be used with a series of transforms that take a dictionary\n        that contains 'image' and 'label' entries. This might require wrapping\n        `torchvision` transforms before passing them to compose.\n        Alternatively, one can create a class with a `__call__` function that\n        calls your pre-processing functions taking into account that not all of\n        them are called on the labels.\n\n    Lazy resampling:\n\n        Lazy resampling is an experimental feature introduced in 1.2. Its purpose is\n        to reduce the number of resample operations that must be carried out when executing\n        a pipeline of transforms. This can provide significant performance improvements in\n        terms of pipeline executing speed and memory usage, and can also significantly\n        reduce the loss of information that occurs when performing a number of spatial\n        resamples in succession.\n\n        Lazy resampling can be enabled or disabled through the ``lazy`` parameter, either by\n        specifying it at initialisation time or overriding it at call time.\n\n        * False (default): Don't perform any lazy resampling\n        * None: Perform lazy resampling based on the 'lazy' properties of the transform instances.\n        * True: Always perform lazy resampling if possible. 
This will ignore the ``lazy`` properties\n          of the transform instances\n\n        Please see the :ref:`Lazy Resampling topic<lazy_resampling>` for more details of this feature\n        and examples of its use.\n\n    Args:\n        transforms: sequence of callables.\n        map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,\n            it can behave as follows:\n\n                - Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied\n                  to the first level of items in `data`.\n                - If an integer is provided, it specifies the maximum level of nesting to which the transformation\n                  should be recursively applied. This allows treating multi-sample transforms applied after another\n                  multi-sample transform while controlling how deep the mapping goes.\n        unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.\n            defaults to `False`.\n        log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.\n            Setting this to False disables logging. Setting it to True enables logging to the default loggers.\n            Setting a string overrides the logger name to which logging is performed.\n        lazy: whether to enable :ref:`Lazy Resampling<lazy_resampling>` for lazy transforms. If False, transforms will\n            be carried out on a transform by transform basis. If True, all lazy transforms will be executed by\n            accumulating changes and resampling as few times as possible. If lazy is None, `Compose` will\n            perform lazy execution on lazy transforms that have their `lazy` property set to True.\n        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden\n            when executing a pipeline. 
Each parameter that is compatible with a given transform is then applied\n            to that transform before it is executed. Note that overrides are currently only applied when\n            :ref:`Lazy Resampling<lazy_resampling>` is enabled for the pipeline or a given transform. If lazy is False\n            they are ignored. Currently supported args are:\n            {``\"mode\"``, ``\"padding_mode\"``, ``\"dtype\"``, ``\"align_corners\"``, ``\"resample_mode\"``, ``device``}.\n    \"\"\"\n\n    def __init__(\n        self,\n        transforms: Sequence[Callable] | Callable | None = None,\n        map_items: bool | int = True,\n        unpack_items: bool = False,\n        log_stats: bool | str = False,\n        lazy: bool | None = False,\n        overrides: dict | None = None,\n    ) -> None:\n        LazyTransform.__init__(self, lazy=lazy)\n\n        if transforms is None:\n            transforms = []\n\n        if not isinstance(map_items, (bool, int)):\n            raise ValueError(\n                f\"Argument 'map_items' should be boolean or int. 
Got {type(map_items)}.\"\n                \"Check brackets when passing a sequence of callables.\"\n            )\n\n        self.transforms = ensure_tuple(transforms)\n        self.map_items = map_items\n        self.unpack_items = unpack_items\n        self.log_stats = log_stats\n        self.set_random_state(seed=get_seed())\n        self.overrides = overrides\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool):\n        self._lazy = val\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Compose:\n        super().set_random_state(seed=seed, state=state)\n        for _transform in self.transforms:\n            if not isinstance(_transform, Randomizable):\n                continue\n            _transform.set_random_state(seed=int(self.R.randint(MAX_SEED, dtype=\"uint32\")))\n        return self\n\n    def randomize(self, data: Any | None = None) -> None:\n        for _transform in self.transforms:\n            if not isinstance(_transform, Randomizable):\n                continue\n            try:\n                _transform.randomize(data)\n            except TypeError as type_error:\n                tfm_name: str = type(_transform).__name__\n                warnings.warn(\n                    f\"Transform '{tfm_name}' in Compose not randomized\\n{tfm_name}.{type_error}.\", RuntimeWarning\n                )\n\n    def get_index_of_first(self, predicate):\n        \"\"\"\n        get_index_of_first takes a ``predicate`` and returns the index of the first transform that\n        satisfies the predicate (ie. makes the predicate return True). 
If it is unable to find\n        a transform that satisfies the ``predicate``, it returns None.\n\n        Example:\n            c = Compose([Flip(...), Rotate90(...), Zoom(...), RandRotate(...), Resize(...)])\n\n            print(c.get_index_of_first(lambda t: isinstance(t, RandomTrait)))\n            >>> 3\n            print(c.get_index_of_first(lambda t: isinstance(t, Compose)))\n            >>> None\n\n        Note:\n            This is only performed on the transforms directly held by this instance. If this\n            instance has nested ``Compose`` transforms or other transforms that contain transforms,\n            it does not iterate into them.\n\n\n        Args:\n            predicate: a callable that takes a single argument and returns a bool. When called\n            it is passed a transform from the sequence of transforms contained by this compose\n            instance.\n\n        Returns:\n            The index of the first transform in the sequence for which ``predicate`` returns\n            True. 
None if no transform satisfies the ``predicate``\n\n        \"\"\"\n        for i in range(len(self.transforms)):\n            if predicate(self.transforms[i]):\n                return i\n        return None\n\n    def flatten(self):\n        \"\"\"Return a Composition with a simple list of transforms, as opposed to any nested Compositions.\n\n        e.g., `t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]).flatten()`\n        will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`.\n\n        \"\"\"\n        new_transforms = []\n        for t in self.transforms:\n            if type(t) is Compose:  # nopep8\n                new_transforms += t.flatten().transforms\n            else:\n                new_transforms.append(t)\n\n        return Compose(new_transforms)\n\n    def __len__(self):\n        \"\"\"Return number of transformations.\"\"\"\n        return len(self.flatten().transforms)\n\n    def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None):\n        _lazy = self._lazy if lazy is None else lazy\n        result = execute_compose(\n            input_,\n            transforms=self.transforms,\n            start=start,\n            end=end,\n            map_items=self.map_items,\n            unpack_items=self.unpack_items,\n            lazy=_lazy,\n            overrides=self.overrides,\n            threading=threading,\n            log_stats=self.log_stats,\n        )\n\n        return result\n\n    def inverse(self, data):\n        self._raise_if_not_invertible(data)\n\n        invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)]\n        if not invertible_transforms:\n            warnings.warn(\"inverse has been called but no invertible transforms have been supplied\")\n\n        if self._lazy is True:\n            warnings.warn(\n                f\"'lazy' is set to {self._lazy} but lazy execution is not supported when inverting. 
\"\n                f\"'lazy' has been overridden to False for the call to inverse\"\n            )\n        # loop backwards over transforms\n        for t in reversed(invertible_transforms):\n            data = apply_transform(\n                t.inverse, data, self.map_items, self.unpack_items, lazy=False, log_stats=self.log_stats\n            )\n        return data\n\n    @staticmethod\n    def _raise_if_not_invertible(data: Any):\n        from monai.transforms.utils import has_status_keys\n\n        invertible, reasons = has_status_keys(\n            data, TraceStatusKeys.PENDING_DURING_APPLY, \"Pending operations while applying an operation\"\n        )\n\n        if invertible is False:\n            if reasons is not None:\n                reason_text = \"\\n\".join(reasons)\n                raise RuntimeError(f\"Unable to run inverse on 'data' for the following reasons:\\n{reason_text}\")\n            else:\n                raise RuntimeError(\"Unable to run inverse on 'data'; no reason logged in trace data\")\n\n\nclass OneOf(Compose):\n    \"\"\"\n    ``OneOf`` provides the ability to randomly choose one transform out of a\n    list of callables with pre-defined probabilities for each.\n\n    Args:\n        transforms: sequence of callables.\n        weights: probabilities corresponding to each callable in transforms.\n            Probabilities are normalized to sum to one.\n        map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,\n            it can behave as follows:\n\n                - Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied\n                  to the first level of items in `data`.\n                - If an integer is provided, it specifies the maximum level of nesting to which the transformation\n                  should be recursively applied. 
This allows treating multi-sample transforms applied after another\n                  multi-sample transform while controlling how deep the mapping goes.\n        unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.\n            defaults to `False`.\n        log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.\n            Setting this to False disables logging. Setting it to True enables logging to the default loggers.\n            Setting a string overrides the logger name to which logging is performed.\n        lazy: whether to enable :ref:`Lazy Resampling<lazy_resampling>` for lazy transforms. If False, transforms will\n            be carried out on a transform by transform basis. If True, all lazy transforms will be executed by\n            accumulating changes and resampling as few times as possible. If lazy is None, `Compose` will\n            perform lazy execution on lazy transforms that have their `lazy` property set to True.\n        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden\n            when executing a pipeline. These each parameter that is compatible with a given transform is then applied\n            to that transform before it is executed. Note that overrides are currently only applied when\n            :ref:`Lazy Resampling<lazy_resampling>` is enabled for the pipeline or a given transform. If lazy is False\n            they are ignored. 
Currently supported args are:\n            {``\"mode\"``, ``\"padding_mode\"``, ``\"dtype\"``, ``\"align_corners\"``, ``\"resample_mode\"``, ``device``}.\n    \"\"\"\n\n    def __init__(\n        self,\n        transforms: Sequence[Callable] | Callable | None = None,\n        weights: Sequence[float] | float | None = None,\n        map_items: bool | int = True,\n        unpack_items: bool = False,\n        log_stats: bool | str = False,\n        lazy: bool | None = False,\n        overrides: dict | None = None,\n    ) -> None:\n        super().__init__(transforms, map_items, unpack_items, log_stats, lazy, overrides)\n        if len(self.transforms) == 0:\n            weights = []\n        elif weights is None or isinstance(weights, float):\n            weights = [1.0 / len(self.transforms)] * len(self.transforms)\n        if len(weights) != len(self.transforms):\n            raise ValueError(\n                \"transforms and weights should be same size if both specified as sequences, \"\n                f\"got {len(weights)} and {len(self.transforms)}.\"\n            )\n        self.weights = ensure_tuple(self._normalize_probabilities(weights))\n        self.log_stats = log_stats\n\n    def _normalize_probabilities(self, weights):\n        if len(weights) == 0:\n            return weights\n        weights = np.array(weights)\n        if np.any(weights < 0):\n            raise ValueError(f\"Probabilities must be greater than or equal to zero, got {weights}.\")\n        if np.all(weights == 0):\n            raise ValueError(f\"At least one probability must be greater than zero, got {weights}.\")\n        weights = weights / weights.sum()\n        return list(weights)\n\n    def flatten(self):\n        transforms = []\n        weights = []\n        for t, w in zip(self.transforms, self.weights):\n            # if nested, probability is the current weight multiplied by the nested weights,\n            # and so on recursively\n            if isinstance(t, OneOf):\n     
           tr = t.flatten()\n                for t_, w_ in zip(tr.transforms, tr.weights):\n                    transforms.append(t_)\n                    weights.append(w_ * w)\n            else:\n                transforms.append(t)\n                weights.append(w)\n        return OneOf(transforms, weights, self.map_items, self.unpack_items)\n\n    def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = None):\n        if start != 0:\n            raise ValueError(f\"OneOf requires 'start' parameter to be 0 (start set to {start})\")\n        if end is not None:\n            raise ValueError(f\"OneOf requires 'end' parameter to be None (end set to {end}\")\n\n        if len(self.transforms) == 0:\n            return data\n\n        index = self.R.multinomial(1, self.weights).argmax()\n        _transform = self.transforms[index]\n        _lazy = self._lazy if lazy is None else lazy\n\n        data = execute_compose(\n            data,\n            [_transform],\n            start=start,\n            end=end,\n            map_items=self.map_items,\n            unpack_items=self.unpack_items,\n            lazy=_lazy,\n            overrides=self.overrides,\n            threading=threading,\n            log_stats=self.log_stats,\n        )\n\n        # if the data is a mapping (dictionary), append the OneOf transform to the end\n        if isinstance(data, monai.data.MetaTensor):\n            self.push_transform(data, extra_info={\"index\": index})\n        elif isinstance(data, Mapping):\n            for key in data:  # dictionary not change size during iteration\n                if isinstance(data[key], monai.data.MetaTensor):\n                    self.push_transform(data[key], extra_info={\"index\": index})\n        return data\n\n    def inverse(self, data):\n        if len(self.transforms) == 0:\n            return data\n\n        index = None\n        if isinstance(data, monai.data.MetaTensor):\n            index = 
self.pop_transform(data)[TraceKeys.EXTRA_INFO][\"index\"]\n        elif isinstance(data, Mapping):\n            for key in data:\n                if isinstance(data[key], monai.data.MetaTensor):\n                    index = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO][\"index\"]\n        else:\n            raise RuntimeError(\n                f\"Inverse only implemented for Mapping (dictionary) or MetaTensor data, got type {type(data)}.\"\n            )\n        if index is None:\n            # no invertible transforms have been applied\n            return data\n\n        _transform = self.transforms[index]\n        # apply the inverse\n        return _transform.inverse(data) if isinstance(_transform, InvertibleTransform) else data\n\n\nclass RandomOrder(Compose):\n    \"\"\"\n    ``RandomOrder`` provides the ability to apply a list of transformations in random order.\n\n    Args:\n        transforms: sequence of callables.\n        map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.\n            defaults to `True`.\n        unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.\n            defaults to `False`.\n        log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.\n            Setting this to False disables logging. Setting it to True enables logging to the default loggers.\n            Setting a string overrides the logger name to which logging is performed.\n        lazy: whether to enable :ref:`Lazy Resampling<lazy_resampling>` for lazy transforms. If False, transforms will\n            be carried out on a transform by transform basis. If True, all lazy transforms will be executed by\n            accumulating changes and resampling as few times as possible. 
If lazy is None, `Compose` will\n            perform lazy execution on lazy transforms that have their `lazy` property set to True.\n        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden\n            when executing a pipeline. These each parameter that is compatible with a given transform is then applied\n            to that transform before it is executed. Note that overrides are currently only applied when\n            :ref:`Lazy Resampling<lazy_resampling>` is enabled for the pipeline or a given transform. If lazy is False\n            they are ignored. Currently supported args are:\n            {``\"mode\"``, ``\"padding_mode\"``, ``\"dtype\"``, ``\"align_corners\"``, ``\"resample_mode\"``, ``device``}.\n    \"\"\"\n\n    def __init__(\n        self,\n        transforms: Sequence[Callable] | Callable | None = None,\n        map_items: bool = True,\n        unpack_items: bool = False,\n        log_stats: bool | str = False,\n        lazy: bool | None = False,\n        overrides: dict | None = None,\n    ) -> None:\n        super().__init__(transforms, map_items, unpack_items, log_stats, lazy, overrides)\n        self.log_stats = log_stats\n\n    def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None):\n        if start != 0:\n            raise ValueError(f\"RandomOrder requires 'start' parameter to be 0 (start set to {start})\")\n        if end is not None:\n            raise ValueError(f\"RandomOrder requires 'end' parameter to be None (end set to {end}\")\n\n        if len(self.transforms) == 0:\n            return input_\n\n        num = len(self.transforms)\n        applied_order = self.R.permutation(range(num))\n        _lazy = self._lazy if lazy is None else lazy\n\n        input_ = execute_compose(\n            input_,\n            [self.transforms[ind] for ind in applied_order],\n            start=start,\n            end=end,\n            
map_items=self.map_items,\n            unpack_items=self.unpack_items,\n            lazy=_lazy,\n            threading=threading,\n            log_stats=self.log_stats,\n        )\n\n        # if the data is a mapping (dictionary), append the RandomOrder transform to the end\n        if isinstance(input_, monai.data.MetaTensor):\n            self.push_transform(input_, extra_info={\"applied_order\": applied_order})\n        elif isinstance(input_, Mapping):\n            for key in input_:  # dictionary not change size during iteration\n                if isinstance(input_[key], monai.data.MetaTensor):\n                    self.push_transform(input_[key], extra_info={\"applied_order\": applied_order})\n        return input_\n\n    def inverse(self, data):\n        if len(self.transforms) == 0:\n            return data\n\n        applied_order = None\n        if isinstance(data, monai.data.MetaTensor):\n            applied_order = self.pop_transform(data)[TraceKeys.EXTRA_INFO][\"applied_order\"]\n        elif isinstance(data, Mapping):\n            for key in data:\n                if isinstance(data[key], monai.data.MetaTensor):\n                    applied_order = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO][\"applied_order\"]\n        else:\n            raise RuntimeError(\n                f\"Inverse only implemented for Mapping (dictionary) or MetaTensor data, got type {type(data)}.\"\n            )\n        if applied_order is None:\n            # no invertible transforms have been applied\n            return data\n\n        # loop backwards over transforms\n        for o in reversed(applied_order):\n            if isinstance(self.transforms[o], InvertibleTransform):\n                data = apply_transform(\n                    self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats\n                )\n        return data\n\n\nclass SomeOf(Compose):\n    \"\"\"\n    ``SomeOf`` samples a different sequence of 
transforms to apply each time it is called.\n\n    It can be configured to sample a fixed or varying number of transforms each time its called. Samples are drawn\n    uniformly, or from user supplied transform weights. When varying the number of transforms sampled per call,\n    the number of transforms to sample that call is sampled uniformly from a range supplied by the user.\n\n    Args:\n        transforms: list of callables.\n        map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.\n            Defaults to `True`.\n        unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.\n            Defaults to `False`.\n        log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.\n            Setting this to False disables logging. Setting it to True enables logging to the default loggers.\n            Setting a string overrides the logger name to which logging is performed.\n        num_transforms: a 2-tuple, int, or None. The 2-tuple specifies the minimum and maximum (inclusive) number of\n            transforms to sample at each iteration. If an int is given, the lower and upper bounds are set equal.\n            None sets it to `len(transforms)`. Default to `None`.\n        replace: whether to sample with replacement. Defaults to `False`.\n        weights: weights to use in for sampling transforms. Will be normalized to 1. Default: None (uniform).\n        lazy: whether to enable :ref:`Lazy Resampling<lazy_resampling>` for lazy transforms. If False, transforms will\n            be carried out on a transform by transform basis. If True, all lazy transforms will be executed by\n            accumulating changes and resampling as few times as possible. 
If lazy is None, `Compose` will\n            perform lazy execution on lazy transforms that have their `lazy` property set to True.\n        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden\n            when executing a pipeline. These each parameter that is compatible with a given transform is then applied\n            to that transform before it is executed. Note that overrides are currently only applied when\n            :ref:`Lazy Resampling<lazy_resampling>` is enabled for the pipeline or a given transform. If lazy is False\n            they are ignored. Currently supported args are:\n            {``\"mode\"``, ``\"padding_mode\"``, ``\"dtype\"``, ``\"align_corners\"``, ``\"resample_mode\"``, ``device``}.\n    \"\"\"\n\n    def __init__(\n        self,\n        transforms: Sequence[Callable] | Callable | None = None,\n        map_items: bool = True,\n        unpack_items: bool = False,\n        log_stats: bool | str = False,\n        num_transforms: int | tuple[int, int] | None = None,\n        replace: bool = False,\n        weights: list[int] | None = None,\n        lazy: bool | None = False,\n        overrides: dict | None = None,\n    ) -> None:\n        super().__init__(transforms, map_items, unpack_items, log_stats=log_stats, lazy=lazy, overrides=overrides)\n        self.min_num_transforms, self.max_num_transforms = self._ensure_valid_num_transforms(num_transforms)\n        self.replace = replace\n        self.weights = self._normalize_probabilities(weights)\n        self.log_stats = log_stats\n\n    def _ensure_valid_num_transforms(self, num_transforms: int | tuple[int, int] | None) -> tuple:\n        if (\n            not isinstance(num_transforms, tuple)\n            and not isinstance(num_transforms, list)\n            and not isinstance(num_transforms, int)\n            and num_transforms is not None\n        ):\n            raise ValueError(\n                f\"Expected num_transforms to 
be of type int, list, tuple or None, but it's {type(num_transforms)}\"\n            )\n\n        if num_transforms is None:\n            result = [len(self.transforms), len(self.transforms)]\n        elif isinstance(num_transforms, int):\n            n = min(len(self.transforms), num_transforms)\n            result = [n, n]\n        else:\n            if len(num_transforms) != 2:\n                raise ValueError(f\"Expected len(num_transforms)=2, but it was {len(num_transforms)}\")\n            if not isinstance(num_transforms[0], int) or not isinstance(num_transforms[1], int):\n                raise ValueError(\n                    f\"Expected (int,int), but received ({type(num_transforms[0])}, {type(num_transforms[1])})\"\n                )\n\n            result = [num_transforms[0], num_transforms[1]]\n\n        if result[0] < 0 or result[1] > len(self.transforms):\n            raise ValueError(f\"num_transforms={num_transforms} are out of the bounds [0, {len(self.transforms)}].\")\n\n        return ensure_tuple(result)\n\n    # Modified from OneOf\n    def _normalize_probabilities(self, weights):\n        if weights is None or len(self.transforms) == 0:\n            return None\n\n        weights = np.array(weights)\n\n        n_weights = len(weights)\n        if n_weights != len(self.transforms):\n            raise ValueError(f\"Expected len(weights)={len(self.transforms)}, got: {n_weights}.\")\n\n        if np.any(weights < 0):\n            raise ValueError(f\"Probabilities must be greater than or equal to zero, got {weights}.\")\n\n        if np.all(weights == 0):\n            raise ValueError(f\"At least one probability must be greater than zero, got {weights}.\")\n\n        weights = weights / weights.sum()\n\n        return ensure_tuple(list(weights))\n\n    def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = None):\n        if start != 0:\n            raise ValueError(f\"SomeOf requires 'start' parameter to be 0 (start set 
to {start})\")\n        if end is not None:\n            raise ValueError(f\"SomeOf requires 'end' parameter to be None (end set to {end}\")\n\n        if len(self.transforms) == 0:\n            return data\n\n        sample_size = self.R.randint(self.min_num_transforms, self.max_num_transforms + 1)\n        applied_order = self.R.choice(len(self.transforms), sample_size, replace=self.replace, p=self.weights).tolist()\n        _lazy = self._lazy if lazy is None else lazy\n\n        data = execute_compose(\n            data,\n            [self.transforms[a] for a in applied_order],\n            start=start,\n            end=end,\n            map_items=self.map_items,\n            unpack_items=self.unpack_items,\n            lazy=_lazy,\n            overrides=self.overrides,\n            threading=threading,\n            log_stats=self.log_stats,\n        )\n        if isinstance(data, monai.data.MetaTensor):\n            self.push_transform(data, extra_info={\"applied_order\": applied_order})\n        elif isinstance(data, Mapping):\n            for key in data:  # dictionary not change size during iteration\n                if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:\n                    self.push_transform(data, key, extra_info={\"applied_order\": applied_order})\n\n        return data\n\n    # From RandomOrder\n    def inverse(self, data):\n        if len(self.transforms) == 0:\n            return data\n\n        applied_order = None\n        if isinstance(data, monai.data.MetaTensor):\n            applied_order = self.pop_transform(data)[TraceKeys.EXTRA_INFO][\"applied_order\"]\n        elif isinstance(data, Mapping):\n            for key in data:\n                if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:\n                    applied_order = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO][\"applied_order\"]\n        else:\n            raise RuntimeError(\n                f\"Inverse 
only implemented for Mapping (dictionary) or MetaTensor data, got type {type(data)}.\"\n            )\n        if applied_order is None:\n            # no invertible transforms have been applied\n            return data\n\n        # loop backwards over transforms\n        for o in reversed(applied_order):\n            if isinstance(self.transforms[o], InvertibleTransform):\n                data = apply_transform(\n                    self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats\n                )\n\n        return data\n"
  },
  {
    "path": "monai/transforms/croppad/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/transforms/croppad/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of \"vanilla\" transforms for crop and pad operations.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable, Sequence\nfrom itertools import chain\nfrom math import ceil\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.config import IndexSelection\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import get_random_patch, get_valid_patch_size\nfrom monai.transforms.croppad.functional import crop_func, pad_func\nfrom monai.transforms.inverse import InvertibleTransform, TraceableTransform\nfrom monai.transforms.traits import MultiSampleTrait\nfrom monai.transforms.transform import LazyTransform, Randomizable, Transform\nfrom monai.transforms.utils import (\n    compute_divisible_spatial_size,\n    generate_label_classes_crop_centers,\n    generate_pos_neg_label_crop_centers,\n    generate_spatial_bounding_box,\n    is_positive,\n    map_binary_to_indices,\n    map_classes_to_indices,\n    weighted_patch_samples,\n)\nfrom monai.utils import ImageMetaKey as Key\nfrom monai.utils import (\n    LazyAttr,\n    Method,\n    PytorchPadMode,\n    TraceKeys,\n    TransformBackends,\n    convert_data_type,\n    convert_to_tensor,\n    ensure_tuple,\n    
ensure_tuple_rep,\n    fall_back_tuple,\n    look_up_option,\n)\n\n__all__ = [\n    \"Pad\",\n    \"SpatialPad\",\n    \"BorderPad\",\n    \"DivisiblePad\",\n    \"Crop\",\n    \"SpatialCrop\",\n    \"CenterSpatialCrop\",\n    \"CenterScaleCrop\",\n    \"RandSpatialCrop\",\n    \"RandScaleCrop\",\n    \"RandSpatialCropSamples\",\n    \"CropForeground\",\n    \"RandWeightedCrop\",\n    \"RandCropByPosNegLabel\",\n    \"RandCropByLabelClasses\",\n    \"ResizeWithPadOrCrop\",\n    \"BoundingRect\",\n]\n\n\nclass Pad(InvertibleTransform, LazyTransform):\n    \"\"\"\n    Perform padding for a given an amount of padding in each dimension.\n\n    `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch,\n    in which case `np.pad` will be used.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        to_pad: the amount to pad in each dimension (including the channel) [(low_H, high_H), (low_W, high_W), ...].\n            if None, must provide in the `__call__` at runtime.\n        mode: available modes: (Numpy) {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            (PyTorch) {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. Defaults to ``\"constant\"``.\n            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            requires pytorch >= 1.10 for best compatibility.\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        to_pad: tuple[tuple[int, int]] | None = None,\n        mode: str = PytorchPadMode.CONSTANT,\n        lazy: bool = False,\n        **kwargs,\n    ) -> None:\n        LazyTransform.__init__(self, lazy)\n        self.to_pad = to_pad\n        self.mode = mode\n        self.kwargs = kwargs\n\n    def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:\n        \"\"\"\n        dynamically compute the pad width according to the spatial shape.\n        the output is the amount of padding for all dimensions including the channel.\n\n        Args:\n            spatial_shape: spatial shape of the original image.\n\n        \"\"\"\n        raise NotImplementedError(f\"subclass {self.__class__.__name__} must implement this method.\")\n\n    def __call__(  # type: ignore[override]\n        self,\n        img: torch.Tensor,\n        to_pad: tuple[tuple[int, int]] | None = None,\n        mode: str | None = None,\n        lazy: bool | None = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.\n            to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].\n                default to `self.to_pad`.\n            mode: available modes: (Numpy) {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n                ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n                (PyTorch) {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n                One of the listed 
string values or a user supplied function. Defaults to ``\"constant\"``.\n                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            lazy: a flag to override the lazy behaviour for this call, if set. Defaults to None.\n            kwargs: other arguments for the `np.pad` or `torch.pad` function.\n                note that `np.pad` treats channel dimension as the first dimension.\n\n        \"\"\"\n        to_pad_ = self.to_pad if to_pad is None else to_pad\n        if to_pad_ is None:\n            spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n            to_pad_ = self.compute_pad_width(spatial_shape)\n        mode_ = self.mode if mode is None else mode\n        kwargs_ = dict(self.kwargs)\n        kwargs_.update(kwargs)\n\n        img_t = convert_to_tensor(data=img, track_meta=get_track_meta())\n        lazy_ = self.lazy if lazy is None else lazy\n        return pad_func(img_t, to_pad_, self.get_transform_info(), mode_, lazy_, **kwargs_)\n\n    def inverse(self, data: MetaTensor) -> MetaTensor:\n        transform = self.pop_transform(data)\n        padded = transform[TraceKeys.EXTRA_INFO][\"padded\"]\n        if padded[0][0] > 0 or padded[0][1] > 0:  # slicing the channel dimension\n            s = padded[0][0]\n            e = min(max(padded[0][1], s + 1), len(data))\n            data = data[s : len(data) - e]  # type: ignore\n        roi_start = [i[0] for i in padded[1:]]\n        roi_end = [i - j[1] for i, j in zip(data.shape[1:], padded[1:])]\n        cropper = SpatialCrop(roi_start=roi_start, roi_end=roi_end)\n        with cropper.trace_transform(False):\n            return cropper(data)  # type: ignore\n\n\nclass SpatialPad(Pad):\n    \"\"\"\n    Performs padding to the data, symmetric for all sides or all on one side for each dimension.\n\n    This transform is capable of lazy 
execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        spatial_size: the spatial size of output data after padding, if a dimension of the input\n            data size is larger than the pad size, will not pad that dimension.\n            If its components have non-positive values, the corresponding size of input image will be used\n            (no padding). for example: if the spatial size of input data is [30, 30, 30] and\n            `spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30].\n        method: {``\"symmetric\"``, ``\"end\"``}\n            Pad image symmetrically on every side or only pad at the end sides. Defaults to ``\"symmetric\"``.\n        mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. Defaults to ``\"constant\"``.\n            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        spatial_size: Sequence[int] | int | tuple[tuple[int, ...] 
| int, ...],\n        method: str = Method.SYMMETRIC,\n        mode: str = PytorchPadMode.CONSTANT,\n        lazy: bool = False,\n        **kwargs,\n    ) -> None:\n        self.spatial_size = spatial_size\n        self.method: Method = look_up_option(method, Method)\n        super().__init__(mode=mode, lazy=lazy, **kwargs)\n\n    def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:\n        \"\"\"\n        dynamically compute the pad width according to the spatial shape.\n\n        Args:\n            spatial_shape: spatial shape of the original image.\n\n        \"\"\"\n        spatial_size = fall_back_tuple(self.spatial_size, spatial_shape)\n        if self.method == Method.SYMMETRIC:\n            pad_width = []\n            for i, sp_i in enumerate(spatial_size):\n                width = max(sp_i - spatial_shape[i], 0)\n                pad_width.append((int(width // 2), int(width - (width // 2))))\n        else:\n            pad_width = [(0, int(max(sp_i - spatial_shape[i], 0))) for i, sp_i in enumerate(spatial_size)]\n        return tuple([(0, 0)] + pad_width)  # type: ignore\n\n\nclass BorderPad(Pad):\n    \"\"\"\n    Pad the input data by adding specified borders to every dimension.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        spatial_border: specified size for every spatial border. Any -ve values will be set to 0. 
It can be 3 shapes:\n\n            - single int number, pad all the borders with the same size.\n            - length equals the length of image shape, pad every spatial dimension separately.\n              for example, image shape(CHW) is [1, 4, 4], spatial_border is [2, 1],\n              pad every border of H dim with 2, pad every border of W dim with 1, result shape is [1, 8, 6].\n            - length equals 2 x (length of image shape), pad every border of every dimension separately.\n              for example, image shape(CHW) is [1, 4, 4], spatial_border is [1, 2, 3, 4], pad top of H dim with 1,\n              pad bottom of H dim with 2, pad left of W dim with 3, pad right of W dim with 4.\n              the result shape is [1, 7, 11].\n        mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. Defaults to ``\"constant\"``.\n            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    \"\"\"\n\n    def __init__(\n        self, spatial_border: Sequence[int] | int, mode: str = PytorchPadMode.CONSTANT, lazy: bool = False, **kwargs\n    ) -> None:\n        self.spatial_border = spatial_border\n        super().__init__(mode=mode, lazy=lazy, **kwargs)\n\n    def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:\n        spatial_border = ensure_tuple(self.spatial_border)\n        if not all(isinstance(b, int) for b in spatial_border):\n            raise ValueError(f\"self.spatial_border must contain only ints, got {spatial_border}.\")\n        spatial_border = tuple(max(0, b) for b in spatial_border)\n\n        if len(spatial_border) == 1:\n            data_pad_width = [(int(spatial_border[0]), int(spatial_border[0])) for _ in spatial_shape]\n        elif len(spatial_border) == len(spatial_shape):\n            data_pad_width = [(int(sp), int(sp)) for sp in spatial_border[: len(spatial_shape)]]\n        elif len(spatial_border) == len(spatial_shape) * 2:\n            data_pad_width = [\n                (int(spatial_border[2 * i]), int(spatial_border[2 * i + 1])) for i in range(len(spatial_shape))\n            ]\n        else:\n            raise ValueError(\n                f\"Unsupported spatial_border length: {len(spatial_border)}, available options are \"\n                f\"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2 * len(spatial_shape)}].\"\n            )\n        return tuple([(0, 0)] + data_pad_width)  # type: ignore\n\n\nclass DivisiblePad(Pad):\n    \"\"\"\n    Pad the input data, so that the spatial sizes are divisible by `k`.\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = SpatialPad.backend\n\n    def __init__(\n        self,\n        k: Sequence[int] | int,\n        mode: str = PytorchPadMode.CONSTANT,\n        method: str = Method.SYMMETRIC,\n        lazy: bool = False,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Args:\n            k: the target k for each spatial dimension.\n                if `k` is negative or 0, the original size is preserved.\n                if `k` is an int, the same `k` be applied to all the input spatial dimensions.\n            mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n                ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n                available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n                One of the listed string values or a user supplied function. Defaults to ``\"constant\"``.\n                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            method: {``\"symmetric\"``, ``\"end\"``}\n                Pad image symmetrically on every side or only pad at the end sides. Defaults to ``\"symmetric\"``.\n            lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n            kwargs: other arguments for the `np.pad` or `torch.pad` function.\n                note that `np.pad` treats channel dimension as the first dimension.\n\n        See also :py:class:`monai.transforms.SpatialPad`\n        \"\"\"\n        self.k = k\n        self.method: Method = Method(method)\n        super().__init__(mode=mode, lazy=lazy, **kwargs)\n\n    def compute_pad_width(self, spatial_shape: Sequence[int]) -> tuple[tuple[int, int]]:\n        new_size = compute_divisible_spatial_size(spatial_shape=spatial_shape, k=self.k)\n        spatial_pad = SpatialPad(spatial_size=new_size, method=self.method)\n        return spatial_pad.compute_pad_width(spatial_shape)\n\n\nclass Crop(InvertibleTransform, LazyTransform):\n    \"\"\"\n    Perform crop operations on the input image.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, lazy: bool = False):\n        LazyTransform.__init__(self, lazy)\n\n    @staticmethod\n    def compute_slices(\n        roi_center: Sequence[int] | int | NdarrayOrTensor | None = None,\n        roi_size: Sequence[int] | int | NdarrayOrTensor | None = None,\n        roi_start: Sequence[int] | int | NdarrayOrTensor | None = None,\n        roi_end: Sequence[int] | int | NdarrayOrTensor | None = None,\n        roi_slices: Sequence[slice] | None = None,\n    ) -> tuple[slice]:\n        \"\"\"\n        Compute the crop slices based on specified `center & size` or `start & end` or `slices`.\n\n        Args:\n            roi_center: voxel coordinates for center of the crop ROI.\n            roi_size: size of the crop ROI, if a dimension of ROI size is larger than image size,\n                will not crop that dimension of the image.\n            roi_start: voxel coordinates for start of the crop ROI.\n            roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image,\n                use the end coordinate of image.\n            roi_slices: list of slices for each of the spatial dimensions.\n\n        \"\"\"\n        roi_start_t: torch.Tensor\n\n        if roi_slices:\n            if not all(s.step is None or s.step == 1 for s in roi_slices):\n                raise ValueError(f\"only slice steps of 1/None are currently supported, got {roi_slices}.\")\n            return ensure_tuple(roi_slices)\n        else:\n            if roi_center is not None and roi_size is not None:\n                roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device=\"cpu\")\n                roi_size_t = convert_to_tensor(data=roi_size, dtype=torch.int16, wrap_sequence=True, device=\"cpu\")\n                _zeros = torch.zeros_like(roi_center_t)\n                half = torch.divide(roi_size_t, 2, rounding_mode=\"floor\")\n    
            roi_start_t = torch.maximum(roi_center_t - half, _zeros)\n                roi_end_t = torch.maximum(roi_start_t + roi_size_t, roi_start_t)\n            else:\n                if roi_start is None or roi_end is None:\n                    raise ValueError(\"please specify either roi_center, roi_size or roi_start, roi_end.\")\n                roi_start_t = convert_to_tensor(data=roi_start, dtype=torch.int16, wrap_sequence=True)\n                roi_start_t = torch.maximum(roi_start_t, torch.zeros_like(roi_start_t))\n                roi_end_t = convert_to_tensor(data=roi_end, dtype=torch.int16, wrap_sequence=True)\n                roi_end_t = torch.maximum(roi_end_t, roi_start_t)\n            # convert to slices (accounting for 1d)\n            if roi_start_t.numel() == 1:\n                return ensure_tuple([slice(int(roi_start_t.item()), int(roi_end_t.item()))])\n            return ensure_tuple([slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())])\n\n    def __call__(  # type: ignore[override]\n        self, img: torch.Tensor, slices: tuple[slice, ...], lazy: bool | None = None\n    ) -> torch.Tensor:\n        \"\"\"\n        Apply the transform to `img`, assuming `img` is channel-first and\n        slicing doesn't apply to the channel dim.\n\n        \"\"\"\n        slices_ = list(slices)\n        sd = len(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:])  # spatial dims\n        if len(slices_) < sd:\n            slices_ += [slice(None)] * (sd - len(slices_))\n        # Add in the channel (no cropping)\n        slices_ = list([slice(None)] + slices_[:sd])\n\n        img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta())\n        lazy_ = self.lazy if lazy is None else lazy\n        return crop_func(img_t, tuple(slices_), lazy_, self.get_transform_info())\n\n    def inverse(self, img: MetaTensor) -> MetaTensor:\n        transform = self.pop_transform(img)\n        cropped = 
transform[TraceKeys.EXTRA_INFO][\"cropped\"]\n        # the amount we pad is equal to the amount we cropped in each direction\n        inverse_transform = BorderPad(cropped)\n        # Apply inverse transform\n        with inverse_transform.trace_transform(False):\n            return inverse_transform(img)  # type: ignore\n\n\nclass SpatialCrop(Crop):\n    \"\"\"\n    General purpose cropper to produce sub-volume region of interest (ROI).\n    If a dimension of the expected ROI size is larger than the input image size, will not crop that dimension.\n    So the cropped result may be smaller than the expected ROI, and the cropped results of several images may\n    not have exactly the same shape.\n    It can support to crop ND spatial (channel-first) data.\n\n    The cropped region can be parameterised in various ways:\n        - a list of slices for each spatial dimension (allows for use of negative indexing and `None`)\n        - a spatial center and size\n        - the start and end coordinates of the ROI\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    def __init__(\n        self,\n        roi_center: Sequence[int] | int | NdarrayOrTensor | None = None,\n        roi_size: Sequence[int] | int | NdarrayOrTensor | None = None,\n        roi_start: Sequence[int] | int | NdarrayOrTensor | None = None,\n        roi_end: Sequence[int] | int | NdarrayOrTensor | None = None,\n        roi_slices: Sequence[slice] | None = None,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            roi_center: voxel coordinates for center of the crop ROI.\n            roi_size: size of the crop ROI, if a dimension of ROI size is larger than image size,\n                will not crop that dimension of the image.\n            roi_start: voxel coordinates for start of the crop ROI.\n            roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image,\n                use the end coordinate of image.\n            roi_slices: list of slices for each of the spatial dimensions.\n            lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n\n        \"\"\"\n        super().__init__(lazy)\n        self.slices = self.compute_slices(\n            roi_center=roi_center, roi_size=roi_size, roi_start=roi_start, roi_end=roi_end, roi_slices=roi_slices\n        )\n\n    def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:  # type: ignore[override]\n        \"\"\"\n        Apply the transform to `img`, assuming `img` is channel-first and\n        slicing doesn't apply to the channel dim.\n\n        \"\"\"\n        lazy_ = self.lazy if lazy is None else lazy\n        return super().__call__(img=img, slices=ensure_tuple(self.slices), lazy=lazy_)\n\n\nclass CenterSpatialCrop(Crop):\n    \"\"\"\n    Crop at the center of image with specified ROI size.\n    If a dimension of the expected ROI size is larger than the input image size, will not crop that dimension.\n    So the cropped result may be smaller than the expected ROI, and the cropped results of several images may\n    not have exactly the same shape.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        roi_size: the spatial size of the crop region e.g. [224,224,128]\n            if a dimension of ROI size is larger than image size, will not crop that dimension of the image.\n            If its components have non-positive values, the corresponding size of input image will be used.\n            for example: if the spatial size of input data is [40, 40, 40] and `roi_size=[32, 64, -1]`,\n            the spatial size of output data will be [32, 40, 40].\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n    \"\"\"\n\n    def __init__(self, roi_size: Sequence[int] | int, lazy: bool = False) -> None:\n        super().__init__(lazy=lazy)\n        self.roi_size = roi_size\n\n    def compute_slices(self, spatial_size: Sequence[int]) -> tuple[slice]:  # type: ignore[override]\n        roi_size = fall_back_tuple(self.roi_size, spatial_size)\n        roi_center = [i // 2 for i in spatial_size]\n        return super().compute_slices(roi_center=roi_center, roi_size=roi_size)\n\n    def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:  # type: ignore[override]\n        \"\"\"\n        Apply the transform to `img`, assuming `img` is channel-first and\n        slicing doesn't apply to the channel dim.\n\n        \"\"\"\n        lazy_ = self.lazy if lazy is None else lazy\n        return super().__call__(\n            img=img,\n            slices=self.compute_slices(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]),\n            lazy=lazy_,\n        )\n\n\nclass CenterScaleCrop(Crop):\n    \"\"\"\n    Crop at the center of image with specified scale of ROI size.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        roi_scale: specifies the expected scale of image size to crop. e.g. [0.3, 0.4, 0.5] or a number for all dims.\n            If its components have non-positive values, will use `1.0` instead, which means the input image size.\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n    \"\"\"\n\n    def __init__(self, roi_scale: Sequence[float] | float, lazy: bool = False):\n        super().__init__(lazy=lazy)\n        self.roi_scale = roi_scale\n\n    def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:  # type: ignore[override]\n        img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n        ndim = len(img_size)\n        roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)]\n        lazy_ = self.lazy if lazy is None else lazy\n        cropper = CenterSpatialCrop(roi_size=roi_size, lazy=lazy_)\n        return super().__call__(img=img, slices=cropper.compute_slices(img_size), lazy=lazy_)\n\n\nclass RandSpatialCrop(Randomizable, Crop):\n    \"\"\"\n    Crop image with random size or specific size ROI. It can crop at a random position as center\n    or at the image center. And allows to set the minimum and maximum size to limit the randomly generated ROI.\n\n    Note: even `random_size=False`, if a dimension of the expected ROI size is larger than the input image size,\n    will not crop that dimension. So the cropped result may be smaller than the expected ROI, and the cropped results\n    of several images may not have exactly the same shape.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        roi_size: if `random_size` is True, it specifies the minimum crop region.\n            if `random_size` is False, it specifies the expected ROI size to crop. e.g. 
[224, 224, 128]\n            if a dimension of ROI size is larger than image size, will not crop that dimension of the image.\n            If its components have non-positive values, the corresponding size of input image will be used.\n            for example: if the spatial size of input data is [40, 40, 40] and `roi_size=[32, 64, -1]`,\n            the spatial size of output data will be [32, 40, 40].\n        max_roi_size: if `random_size` is True and `roi_size` specifies the min crop region size, `max_roi_size`\n            can specify the max crop region size. if None, defaults to the input image size.\n            if its components have non-positive values, the corresponding size of input image will be used.\n        random_center: crop at random position as center or the image center.\n        random_size: crop with random size or specific size ROI.\n            if True, the actual size is sampled from `randint(roi_size, max_roi_size + 1)`.\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        roi_size: Sequence[int] | int,\n        max_roi_size: Sequence[int] | int | None = None,\n        random_center: bool = True,\n        random_size: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        super().__init__(lazy)\n        self.roi_size = roi_size\n        self.max_roi_size = max_roi_size\n        self.random_center = random_center\n        self.random_size = random_size\n        self._size: Sequence[int] | None = None\n        self._slices: tuple[slice, ...]\n\n    def randomize(self, img_size: Sequence[int]) -> None:\n        self._size = fall_back_tuple(self.roi_size, img_size)\n        if self.random_size:\n            max_size = img_size if self.max_roi_size is None else fall_back_tuple(self.max_roi_size, img_size)\n            if any(i > j for i, j in zip(self._size, max_size)):\n                raise ValueError(f\"min ROI size: {self._size} is larger than max ROI size: {max_size}.\")\n            self._size = tuple(self.R.randint(low=self._size[i], high=max_size[i] + 1) for i in range(len(img_size)))\n        if self.random_center:\n            valid_size = get_valid_patch_size(img_size, self._size)\n            self._slices = get_random_patch(img_size, valid_size, self.R)\n\n    def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor:  # type: ignore\n        \"\"\"\n        Apply the transform to `img`, assuming `img` is channel-first and\n        slicing doesn't apply to the channel dim.\n\n        \"\"\"\n        img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n        if randomize:\n            self.randomize(img_size)\n        if self._size is None:\n            raise RuntimeError(\"self._size not specified.\")\n        lazy_ = self.lazy if lazy is None else lazy\n        if self.random_center:\n            return super().__call__(img=img, slices=self._slices, 
lazy=lazy_)\n        cropper = CenterSpatialCrop(self._size, lazy=lazy_)\n        return super().__call__(img=img, slices=cropper.compute_slices(img_size), lazy=lazy_)\n\n\nclass RandScaleCrop(RandSpatialCrop):\n    \"\"\"\n    Subclass of :py:class:`monai.transforms.RandSpatialCrop`. Crop image with\n    random size or specific size ROI.  It can crop at a random position as\n    center or at the image center.  And allows to set the minimum and maximum\n    scale of image size to limit the randomly generated ROI.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        roi_scale: if `random_size` is True, it specifies the minimum crop size: `roi_scale * image spatial size`.\n            if `random_size` is False, it specifies the expected scale of image size to crop. e.g. [0.3, 0.4, 0.5].\n            If its components have non-positive values, will use `1.0` instead, which means the input image size.\n        max_roi_scale: if `random_size` is True and `roi_scale` specifies the min crop region size, `max_roi_scale`\n            can specify the max crop region size: `max_roi_scale * image spatial size`.\n            if None, defaults to the input image size. if its components have non-positive values,\n            will use `1.0` instead, which means the input image size.\n        random_center: crop at random position as center or the image center.\n        random_size: crop with random size or specified size ROI by `roi_scale * image spatial size`.\n            if True, the actual size is sampled from\n            `randint(roi_scale * image spatial size, max_roi_scale * image spatial size + 1)`.\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        roi_scale: Sequence[float] | float,\n        max_roi_scale: Sequence[float] | float | None = None,\n        random_center: bool = True,\n        random_size: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        super().__init__(\n            roi_size=-1, max_roi_size=None, random_center=random_center, random_size=random_size, lazy=lazy\n        )\n        self.roi_scale = roi_scale\n        self.max_roi_scale = max_roi_scale\n\n    def get_max_roi_size(self, img_size):\n        ndim = len(img_size)\n        self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)]\n        if self.max_roi_scale is not None:\n            self.max_roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.max_roi_scale, ndim), img_size)]\n        else:\n            self.max_roi_size = None\n\n    def randomize(self, img_size: Sequence[int]) -> None:\n        self.get_max_roi_size(img_size)\n        super().randomize(img_size)\n\n    def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor:  # type: ignore\n        \"\"\"\n        Apply the transform to `img`, assuming `img` is channel-first and\n        slicing doesn't apply to the channel dim.\n\n        \"\"\"\n        self.get_max_roi_size(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:])\n        lazy_ = self.lazy if lazy is None else lazy\n        return super().__call__(img=img, randomize=randomize, lazy=lazy_)\n\n\nclass RandSpatialCropSamples(Randomizable, TraceableTransform, LazyTransform, MultiSampleTrait):\n    \"\"\"\n    Crop image with random size or specific size ROI to generate a list of N samples.\n    It can crop at a random position as center or at the image center. 
And allows to set\n    the minimum size to limit the randomly generated ROI.\n    It will return a list of cropped images.\n\n    Note: even `random_size=False`, if a dimension of the expected ROI size is larger than the input image size,\n    will not crop that dimension. So the cropped result may be smaller than the expected ROI, and the cropped\n    results of several images may not have exactly the same shape.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        roi_size: if `random_size` is True, it specifies the minimum crop region.\n            if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128]\n            if a dimension of ROI size is larger than image size, will not crop that dimension of the image.\n            If its components have non-positive values, the corresponding size of input image will be used.\n            for example: if the spatial size of input data is [40, 40, 40] and `roi_size=[32, 64, -1]`,\n            the spatial size of output data will be [32, 40, 40].\n        num_samples: number of samples (crop regions) to take in the returned list.\n        max_roi_size: if `random_size` is True and `roi_size` specifies the min crop region size, `max_roi_size`\n            can specify the max crop region size. if None, defaults to the input image size.\n            if its components have non-positive values, the corresponding size of input image will be used.\n        random_center: crop at random position as center or the image center.\n        random_size: crop with random size or specific size ROI.\n            The actual size is sampled from `randint(roi_size, img_size)`.\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n\n    Raises:\n        ValueError: When ``num_samples`` is nonpositive.\n\n    \"\"\"\n\n    backend = RandSpatialCrop.backend\n\n    def __init__(\n        self,\n        roi_size: Sequence[int] | int,\n        num_samples: int,\n        max_roi_size: Sequence[int] | int | None = None,\n        random_center: bool = True,\n        random_size: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        LazyTransform.__init__(self, lazy)\n        if num_samples < 1:\n            raise ValueError(f\"num_samples must be positive, got {num_samples}.\")\n        self.num_samples = num_samples\n        self.cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size, lazy)\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandSpatialCropSamples:\n        super().set_random_state(seed, state)\n        self.cropper.set_random_state(seed, state)\n        return self\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, value: bool) -> None:\n        self._lazy = value\n        self.cropper.lazy = value\n\n    def randomize(self, data: Any | None = None) -> None:\n        pass\n\n    def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> list[torch.Tensor]:\n        \"\"\"\n        Apply the transform to `img`, assuming `img` is channel-first and\n        cropping doesn't change the channel dim.\n        \"\"\"\n        ret = []\n        lazy_ = self.lazy if lazy is None else lazy\n        for i in range(self.num_samples):\n            cropped = self.cropper(img, lazy=lazy_)\n            if get_track_meta():\n                cropped.meta[Key.PATCH_INDEX] = i  # type: ignore\n                self.push_transform(cropped, replace=True, lazy=lazy_)  # track as this class instead of RandSpatialCrop\n            ret.append(cropped)\n        return ret\n\n\nclass CropForeground(Crop):\n    \"\"\"\n    Crop an image using a bounding box. 
The bounding box is generated by selecting foreground using select_fn\n    at channels channel_indices. margin is added in each spatial dimension of the bounding box.\n    The typical usage is to help training and evaluation if the valid part is small in the whole medical image.\n    Users can define arbitrary function to select expected foreground from the whole image or specified channels.\n    And it can also add margin to every dim of the bounding box of foreground object.\n    For example:\n\n    .. code-block:: python\n\n        image = np.array(\n            [[[0, 0, 0, 0, 0],\n              [0, 1, 2, 1, 0],\n              [0, 1, 3, 2, 0],\n              [0, 1, 2, 1, 0],\n              [0, 0, 0, 0, 0]]])  # 1x5x5, single channel 5x5 image\n\n\n        def threshold_at_one(x):\n            # threshold at 1\n            return x > 1\n\n\n        cropper = CropForeground(select_fn=threshold_at_one, margin=0)\n        print(cropper(image))\n        [[[2, 1],\n          [3, 2],\n          [2, 1]]]\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        select_fn: Callable = is_positive,\n        channel_indices: IndexSelection | None = None,\n        margin: Sequence[int] | int = 0,\n        allow_smaller: bool = False,\n        return_coords: bool = False,\n        k_divisible: Sequence[int] | int = 1,\n        mode: str = PytorchPadMode.CONSTANT,\n        lazy: bool = False,\n        **pad_kwargs,\n    ) -> None:\n        \"\"\"\n        Args:\n            select_fn: function to select expected foreground, default is to select values > 0.\n            channel_indices: if defined, select foreground only on the specified channels\n                of image. 
if None, select foreground on the whole image.\n            margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.\n            allow_smaller: when computing box size with `margin`, whether to allow the image edges to be smaller than the\n                final box edges. If `False`, part of a padded output box might be outside of the original image, if `True`,\n                the image edges will be used as the box edges. Default to `False`.\n                The default value is changed from `True` to `False` in v1.5.0.\n            return_coords: whether return the coordinates of spatial bounding box for foreground.\n            k_divisible: make each spatial dimension to be divisible by k, default to 1.\n                if `k_divisible` is an int, the same `k` be applied to all the input spatial dimensions.\n            mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n                ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n                available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n                One of the listed string values or a user supplied function. Defaults to ``\"constant\"``.\n                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n            pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.\n                note that `np.pad` treats channel dimension as the first dimension.\n\n        \"\"\"\n        LazyTransform.__init__(self, lazy)\n        self.select_fn = select_fn\n        self.channel_indices = ensure_tuple(channel_indices) if channel_indices is not None else None\n        self.margin = margin\n        self.allow_smaller = allow_smaller\n        self.return_coords = return_coords\n        self.k_divisible = k_divisible\n        self.padder = Pad(mode=mode, lazy=lazy, **pad_kwargs)\n\n    @Crop.lazy.setter  # type: ignore\n    def lazy(self, _val: bool):\n        self._lazy = _val\n        self.padder.lazy = _val\n\n    @property\n    def requires_current_data(self):\n        return False\n\n    def compute_bounding_box(self, img: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:\n        \"\"\"\n        Compute the start points and end points of bounding box to crop.\n        And adjust bounding box coords to be divisible by `k`.\n\n        \"\"\"\n        box_start, box_end = generate_spatial_bounding_box(\n            img, self.select_fn, self.channel_indices, self.margin, self.allow_smaller\n        )\n        box_start_, *_ = convert_data_type(box_start, output_type=np.ndarray, dtype=np.int16, wrap_sequence=True)\n        box_end_, *_ = convert_data_type(box_end, output_type=np.ndarray, dtype=np.int16, wrap_sequence=True)\n        orig_spatial_size = box_end_ - box_start_\n        # make the spatial size divisible by `k`\n        spatial_size = np.asarray(compute_divisible_spatial_size(orig_spatial_size.tolist(), k=self.k_divisible))\n        # update box_start and box_end\n        box_start_ = box_start_ - np.floor_divide(np.asarray(spatial_size) - orig_spatial_size, 2)\n        box_end_ = box_start_ + spatial_size\n        return box_start_, box_end_\n\n    def crop_pad(\n        self,\n        img: torch.Tensor,\n        box_start: 
np.ndarray,\n        box_end: np.ndarray,\n        mode: str | None = None,\n        lazy: bool = False,\n        **pad_kwargs,\n    ) -> torch.Tensor:\n        \"\"\"\n        Crop and pad based on the bounding box.\n\n        \"\"\"\n        slices = self.compute_slices(roi_start=box_start, roi_end=box_end)\n        cropped = super().__call__(img=img, slices=slices, lazy=lazy)\n        pad_to_start = np.maximum(-box_start, 0)\n        pad_to_end = np.maximum(\n            box_end - np.asarray(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]), 0\n        )\n        pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist())))\n        pad_width = BorderPad(spatial_border=pad).compute_pad_width(\n            cropped.peek_pending_shape() if isinstance(cropped, MetaTensor) else cropped.shape[1:]\n        )\n        ret = self.padder.__call__(img=cropped, to_pad=pad_width, mode=mode, lazy=lazy, **pad_kwargs)\n        # combine the traced cropping and padding into one transformation\n        # by taking the padded info and placing it in a key inside the crop info.\n        if get_track_meta() and isinstance(ret, MetaTensor):\n            if not lazy:\n                ret.applied_operations[-1][TraceKeys.EXTRA_INFO][\"pad_info\"] = ret.applied_operations.pop()\n            else:\n                pad_info = ret.pending_operations.pop()\n                crop_info = ret.pending_operations.pop()\n                extra = crop_info[TraceKeys.EXTRA_INFO]\n                extra[\"pad_info\"] = pad_info\n                self.push_transform(\n                    ret,\n                    orig_size=crop_info.get(TraceKeys.ORIG_SIZE),\n                    sp_size=pad_info[LazyAttr.SHAPE],\n                    affine=crop_info[LazyAttr.AFFINE] @ pad_info[LazyAttr.AFFINE],\n                    lazy=lazy,\n                    extra_info=extra,\n                )\n        return ret\n\n    def __call__(  # type: ignore[override]\n        self, img: 
torch.Tensor, mode: str | None = None, lazy: bool | None = None, **pad_kwargs\n    ) -> torch.Tensor:\n        \"\"\"\n        Apply the transform to `img`, assuming `img` is channel-first and\n        slicing doesn't change the channel dim.\n        \"\"\"\n        box_start, box_end = self.compute_bounding_box(img)\n        lazy_ = self.lazy if lazy is None else lazy\n        cropped = self.crop_pad(img, box_start, box_end, mode, lazy=lazy_, **pad_kwargs)\n\n        if self.return_coords:\n            return cropped, box_start, box_end  # type: ignore[return-value]\n        return cropped\n\n    def inverse(self, img: MetaTensor) -> MetaTensor:\n        transform = self.get_most_recent_transform(img)\n        # we moved the padding info in the forward, so put it back for the inverse\n        pad_info = transform[TraceKeys.EXTRA_INFO].pop(\"pad_info\")\n        img.applied_operations.append(pad_info)\n        # first inverse the padder\n        inv = self.padder.inverse(img)\n        # and then inverse the cropper (self)\n        return super().inverse(inv)\n\n\nclass RandWeightedCrop(Randomizable, TraceableTransform, LazyTransform, MultiSampleTrait):\n    \"\"\"\n    Samples a list of `num_samples` image patches according to the provided `weight_map`.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        spatial_size: the spatial size of the image patch e.g. [224, 224, 128].\n            If its components have non-positive values, the corresponding size of `img` will be used.\n        num_samples: number of samples (image patches) to take in the returned list.\n        weight_map: weight map used to generate patch samples. The weights must be non-negative.\n            Each element denotes a sampling weight of the spatial location. 
0 indicates no sampling.\n            It should be a single-channel array in shape, for example, `(1, spatial_dim_0, spatial_dim_1, ...)`.\n        lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.\n    \"\"\"\n\n    backend = SpatialCrop.backend\n\n    def __init__(\n        self,\n        spatial_size: Sequence[int] | int,\n        num_samples: int = 1,\n        weight_map: NdarrayOrTensor | None = None,\n        lazy: bool = False,\n    ):\n        LazyTransform.__init__(self, lazy)\n        self.spatial_size = ensure_tuple(spatial_size)\n        self.num_samples = int(num_samples)\n        self.weight_map = weight_map\n        self.centers: list[np.ndarray] = []\n\n    def randomize(self, weight_map: NdarrayOrTensor) -> None:\n        self.centers = weighted_patch_samples(\n            spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R\n        )  # using only the first channel as weight map\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, _val: bool):\n        self._lazy = _val\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        weight_map: NdarrayOrTensor | None = None,\n        randomize: bool = True,\n        lazy: bool | None = None,\n    ) -> list[torch.Tensor]:\n        \"\"\"\n        Args:\n            img: input image to sample patches from. assuming `img` is a channel-first array.\n            weight_map: weight map used to generate patch samples. The weights must be non-negative.\n                Each element denotes a sampling weight of the spatial location. 0 indicates no sampling.\n                It should be a single-channel array in shape, for example, `(1, spatial_dim_0, spatial_dim_1, ...)`\n            randomize: whether to execute random operations, default to `True`.\n            lazy: a flag to override the lazy behaviour for this call, if set. 
Defaults to None.\n\n        Returns:\n            A list of image patches\n        \"\"\"\n        img_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n\n        if randomize:\n            if weight_map is None:\n                weight_map = self.weight_map\n            if weight_map is None:\n                raise ValueError(\"weight map must be provided for weighted patch sampling.\")\n            w_shape = weight_map.peek_pending_shape() if isinstance(weight_map, MetaTensor) else weight_map.shape[1:]\n            if img_shape != w_shape:\n                warnings.warn(f\"image and weight map spatial shape mismatch: {img_shape} vs {w_shape}.\")\n            self.randomize(weight_map)\n\n        _spatial_size = fall_back_tuple(self.spatial_size, img_shape)\n        results: list[torch.Tensor] = []\n        lazy_ = self.lazy if lazy is None else lazy\n        for i, center in enumerate(self.centers):\n            cropper = SpatialCrop(roi_center=center, roi_size=_spatial_size, lazy=lazy_)\n            cropped = cropper(img)\n            if get_track_meta():\n                ret_: MetaTensor = cropped  # type: ignore\n                ret_.meta[Key.PATCH_INDEX] = i\n                ret_.meta[\"crop_center\"] = center\n                self.push_transform(ret_, replace=True, lazy=lazy_)\n            results.append(cropped)\n        return results\n\n\nclass RandCropByPosNegLabel(Randomizable, TraceableTransform, LazyTransform, MultiSampleTrait):\n    \"\"\"\n    Crop random fixed sized regions with the center being a foreground or background voxel\n    based on the Pos Neg Ratio.\n    And will return a list of arrays for all the cropped images.\n    For example, crop two (3 x 3) arrays from (5 x 5) array with pos/neg=1::\n\n        [[[0, 0, 0, 0, 0],\n          [0, 1, 2, 1, 0],            [[0, 1, 2],     [[2, 1, 0],\n          [0, 1, 3, 0, 0],     -->     [0, 1, 3],      [3, 0, 0],\n          [0, 0, 0, 0, 0],             [0, 0, 0]]   
   [0, 0, 0]]\n          [0, 0, 0, 0, 0]]]\n\n    If a dimension of the expected spatial size is larger than the input image size,\n    will not crop that dimension. So the cropped result may be smaller than expected size, and the cropped\n    results of several images may not have exactly same shape.\n    And if the crop ROI is partly out of the image, will automatically adjust the crop center to ensure the\n    valid crop ROI.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        spatial_size: the spatial size of the crop region e.g. [224, 224, 128].\n            if a dimension of ROI size is larger than image size, will not crop that dimension of the image.\n            if its components have non-positive values, the corresponding size of `label` will be used.\n            for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`,\n            the spatial size of output data will be [32, 40, 40].\n        label: the label image that is used for finding foreground/background, if None, must set at\n            `self.__call__`.  Non-zero indicates foreground, zero indicates background.\n        pos: used with `neg` together to calculate the ratio ``pos / (pos + neg)`` for the probability\n            to pick a foreground voxel as a center rather than a background voxel.\n        neg: used with `pos` together to calculate the ratio ``pos / (pos + neg)`` for the probability\n            to pick a foreground voxel as a center rather than a background voxel.\n        num_samples: number of samples (crop regions) to take in each list.\n        image: optional image data to help select valid area, can be same as `img` or another image array.\n            if not None, use ``label == 0 & image > image_threshold`` to select the negative\n            sample (background) center. 
So the crop center will only come from the valid image areas.\n        image_threshold: if enabled `image`, use ``image > image_threshold`` to determine\n            the valid image content areas.\n        fg_indices: if provided pre-computed foreground indices of `label`, will ignore above `image` and\n            `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices`\n            and `bg_indices` together, expect to be 1 dim array of spatial indices after flattening.\n            a typical usage is to call `FgBgToIndices` transform first and cache the results.\n        bg_indices: if provided pre-computed background indices of `label`, will ignore above `image` and\n            `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices`\n            and `bg_indices` together, expect to be 1 dim array of spatial indices after flattening.\n            a typical usage is to call `FgBgToIndices` transform first and cache the results.\n        allow_smaller: if `False`, an exception will be raised if the image is smaller than\n            the requested ROI in any dimension. If `True`, any smaller dimensions will be set to\n            match the cropped size (i.e., no cropping in that dimension).\n        lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.\n\n    Raises:\n        ValueError: When ``pos`` or ``neg`` are negative.\n        ValueError: When ``pos=0`` and ``neg=0``. 
Incompatible values.\n\n    \"\"\"\n\n    backend = SpatialCrop.backend\n\n    def __init__(\n        self,\n        spatial_size: Sequence[int] | int,\n        label: torch.Tensor | None = None,\n        pos: float = 1.0,\n        neg: float = 1.0,\n        num_samples: int = 1,\n        image: torch.Tensor | None = None,\n        image_threshold: float = 0.0,\n        fg_indices: NdarrayOrTensor | None = None,\n        bg_indices: NdarrayOrTensor | None = None,\n        allow_smaller: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        LazyTransform.__init__(self, lazy)\n        self.spatial_size = spatial_size\n        self.label = label\n        if pos < 0 or neg < 0:\n            raise ValueError(f\"pos and neg must be nonnegative, got pos={pos} neg={neg}.\")\n        if pos + neg == 0:\n            raise ValueError(\"Incompatible values: pos=0 and neg=0.\")\n        self.pos_ratio = pos / (pos + neg)\n        self.num_samples = num_samples\n        self.image = image\n        self.image_threshold = image_threshold\n        self.centers: tuple[tuple] | None = None\n        self.fg_indices = fg_indices\n        self.bg_indices = bg_indices\n        self.allow_smaller = allow_smaller\n\n    def randomize(\n        self,\n        label: torch.Tensor | None = None,\n        fg_indices: NdarrayOrTensor | None = None,\n        bg_indices: NdarrayOrTensor | None = None,\n        image: torch.Tensor | None = None,\n    ) -> None:\n        fg_indices_ = self.fg_indices if fg_indices is None else fg_indices\n        bg_indices_ = self.bg_indices if bg_indices is None else bg_indices\n        if fg_indices_ is None or bg_indices_ is None:\n            if label is None:\n                raise ValueError(\"label must be provided.\")\n            fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)\n        _shape = None\n        if label is not None:\n            _shape = label.peek_pending_shape() if isinstance(label, 
MetaTensor) else label.shape[1:]\n        elif image is not None:\n            _shape = image.peek_pending_shape() if isinstance(image, MetaTensor) else image.shape[1:]\n        if _shape is None:\n            raise ValueError(\"label or image must be provided to get the spatial shape.\")\n        self.centers = generate_pos_neg_label_crop_centers(\n            self.spatial_size,\n            self.num_samples,\n            self.pos_ratio,\n            _shape,\n            fg_indices_,\n            bg_indices_,\n            self.R,\n            self.allow_smaller,\n        )\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, _val: bool):\n        self._lazy = _val\n\n    @property\n    def requires_current_data(self):\n        return False\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        label: torch.Tensor | None = None,\n        image: torch.Tensor | None = None,\n        fg_indices: NdarrayOrTensor | None = None,\n        bg_indices: NdarrayOrTensor | None = None,\n        randomize: bool = True,\n        lazy: bool | None = None,\n    ) -> list[torch.Tensor]:\n        \"\"\"\n        Args:\n            img: input data to crop samples from based on the pos/neg ratio of `label` and `image`.\n                Assumes `img` is a channel-first array.\n            label: the label image that is used for finding foreground/background, if None, use `self.label`.\n            image: optional image data to help select valid area, can be same as `img` or another image array.\n                use ``label == 0 & image > image_threshold`` to select the negative sample(background) center.\n                so the crop center will only exist on valid image area. 
if None, use `self.image`.\n            fg_indices: foreground indices to randomly select crop centers,\n                need to provide `fg_indices` and `bg_indices` together.\n            bg_indices: background indices to randomly select crop centers,\n                need to provide `fg_indices` and `bg_indices` together.\n            randomize: whether to execute the random operations, default to `True`.\n            lazy: a flag to override the lazy behaviour for this call, if set. Defaults to None.\n\n        \"\"\"\n        if image is None:\n            image = self.image\n        if randomize:\n            if label is None:\n                label = self.label\n            self.randomize(label, fg_indices, bg_indices, image)\n        results: list[torch.Tensor] = []\n        if self.centers is not None:\n            img_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n            roi_size = fall_back_tuple(self.spatial_size, default=img_shape)\n            lazy_ = self.lazy if lazy is None else lazy\n            for i, center in enumerate(self.centers):\n                cropper = SpatialCrop(roi_center=center, roi_size=roi_size, lazy=lazy_)\n                cropped = cropper(img)\n                if get_track_meta():\n                    ret_: MetaTensor = cropped  # type: ignore\n                    ret_.meta[Key.PATCH_INDEX] = i\n                    ret_.meta[\"crop_center\"] = center\n                    self.push_transform(ret_, replace=True, lazy=lazy_)\n                results.append(cropped)\n        return results\n\n\nclass RandCropByLabelClasses(Randomizable, TraceableTransform, LazyTransform, MultiSampleTrait):\n    \"\"\"\n    Crop random fixed sized regions with the center being a class based on the specified ratios of every class.\n    The label data can be One-Hot format array or Argmax data. And will return a list of arrays for all the\n    cropped images. 
For example, crop two (3 x 3) arrays from (5 x 5) array with `ratios=[1, 2, 3, 1]`::\n\n        image = np.array([\n            [[0.0, 0.3, 0.4, 0.2, 0.0],\n            [0.0, 0.1, 0.2, 0.1, 0.4],\n            [0.0, 0.3, 0.5, 0.2, 0.0],\n            [0.1, 0.2, 0.1, 0.1, 0.0],\n            [0.0, 0.1, 0.2, 0.1, 0.0]]\n        ])\n        label = np.array([\n            [[0, 0, 0, 0, 0],\n            [0, 1, 2, 1, 0],\n            [0, 1, 3, 0, 0],\n            [0, 0, 0, 0, 0],\n            [0, 0, 0, 0, 0]]\n        ])\n        cropper = RandCropByLabelClasses(\n            spatial_size=[3, 3],\n            ratios=[1, 2, 3, 1],\n            num_classes=4,\n            num_samples=2,\n        )\n        label_samples = cropper(img=label, label=label, image=image)\n\n        The 2 randomly cropped samples of `label` can be:\n        [[0, 1, 2],     [[0, 0, 0],\n         [0, 1, 3],      [1, 2, 1],\n         [0, 0, 0]]      [1, 3, 0]]\n\n    If a dimension of the expected spatial size is larger than the input image size,\n    will not crop that dimension. So the cropped result may be smaller than expected size, and the cropped\n    results of several images may not have exactly same shape.\n    And if the crop ROI is partly out of the image, will automatically adjust the crop center to ensure the\n    valid crop ROI.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        spatial_size: the spatial size of the crop region e.g. 
[224, 224, 128].\n            if a dimension of ROI size is larger than image size, will not crop that dimension of the image.\n            if its components have non-positive values, the corresponding size of `label` will be used.\n            for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`,\n            the spatial size of output data will be [32, 40, 40].\n        ratios: specified ratios of every class in the label to generate crop centers, including background class.\n            if None, every class will have the same ratio to generate crop centers.\n        label: the label image that is used for finding every class, if None, must set at `self.__call__`.\n        num_classes: number of classes for argmax label, not necessary for One-Hot label.\n        num_samples: number of samples (crop regions) to take in each list.\n        image: if image is not None, only return the indices of every class that are within the valid\n            region of the image (``image > image_threshold``).\n        image_threshold: if enabled `image`, use ``image > image_threshold`` to\n            determine the valid image content area and select class indices only in this area.\n        indices: if provided pre-computed indices of every class, will ignore above `image` and\n            `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array\n            of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first\n            and cache the results for better performance.\n        allow_smaller: if `False`, an exception will be raised if the image is smaller than\n            the requested ROI in any dimension. 
If `True`, any smaller dimensions will remain\n            unchanged.\n        warn: if `True` prints a warning if a class is not present in the label.\n        max_samples_per_class: maximum length of indices to sample in each class to reduce memory consumption.\n            Default is None, no subsampling.\n        lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.\n    \"\"\"\n\n    backend = SpatialCrop.backend\n\n    def __init__(\n        self,\n        spatial_size: Sequence[int] | int,\n        ratios: list[float | int] | None = None,\n        label: torch.Tensor | None = None,\n        num_classes: int | None = None,\n        num_samples: int = 1,\n        image: torch.Tensor | None = None,\n        image_threshold: float = 0.0,\n        indices: list[NdarrayOrTensor] | None = None,\n        allow_smaller: bool = False,\n        warn: bool = True,\n        max_samples_per_class: int | None = None,\n        lazy: bool = False,\n    ) -> None:\n        LazyTransform.__init__(self, lazy)\n        self.spatial_size = spatial_size\n        self.ratios = ratios\n        self.label = label\n        self.num_classes = num_classes\n        self.num_samples = num_samples\n        self.image = image\n        self.image_threshold = image_threshold\n        self.centers: tuple[tuple] | None = None\n        self.indices = indices\n        self.allow_smaller = allow_smaller\n        self.warn = warn\n        self.max_samples_per_class = max_samples_per_class\n\n    def randomize(\n        self,\n        label: torch.Tensor | None = None,\n        indices: list[NdarrayOrTensor] | None = None,\n        image: torch.Tensor | None = None,\n    ) -> None:\n        indices_ = self.indices if indices is None else indices\n        if indices_ is None:\n            if label is None:\n                raise ValueError(\"label must not be None.\")\n            indices_ = map_classes_to_indices(\n                label, self.num_classes, 
image, self.image_threshold, self.max_samples_per_class\n            )\n        _shape = None\n        if label is not None:\n            _shape = label.peek_pending_shape() if isinstance(label, MetaTensor) else label.shape[1:]\n        elif image is not None:\n            _shape = image.peek_pending_shape() if isinstance(image, MetaTensor) else image.shape[1:]\n        if _shape is None:\n            raise ValueError(\"label or image must be provided to infer the output spatial shape.\")\n        self.centers = generate_label_classes_crop_centers(\n            self.spatial_size, self.num_samples, _shape, indices_, self.ratios, self.R, self.allow_smaller, self.warn\n        )\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, _val: bool):\n        self._lazy = _val\n\n    @property\n    def requires_current_data(self):\n        return False\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        label: torch.Tensor | None = None,\n        image: torch.Tensor | None = None,\n        indices: list[NdarrayOrTensor] | None = None,\n        randomize: bool = True,\n        lazy: bool | None = None,\n    ) -> list[torch.Tensor]:\n        \"\"\"\n        Args:\n            img: input data to crop samples from based on the ratios of every class, assumes `img` is a\n                channel-first array.\n            label: the label image that is used for finding indices of every class, if None, use `self.label`.\n            image: optional image data to help select valid area, can be same as `img` or another image array.\n                use ``image > image_threshold`` to select the centers only in valid region. if None, use `self.image`.\n            indices: list of indices for every class in the image, used to randomly select crop centers.\n            randomize: whether to execute the random operations, default to `True`.\n            lazy: a flag to override the lazy behaviour for this call, if set. 
Defaults to None.\n        \"\"\"\n        if image is None:\n            image = self.image\n        if randomize:\n            if label is None:\n                label = self.label\n            self.randomize(label, indices, image)\n        results: list[torch.Tensor] = []\n        if self.centers is not None:\n            img_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n            roi_size = fall_back_tuple(self.spatial_size, default=img_shape)\n            lazy_ = self.lazy if lazy is None else lazy\n            for i, center in enumerate(self.centers):\n                cropper = SpatialCrop(roi_center=tuple(center), roi_size=roi_size, lazy=lazy_)\n                cropped = cropper(img)\n                if get_track_meta():\n                    ret_: MetaTensor = cropped  # type: ignore\n                    ret_.meta[Key.PATCH_INDEX] = i\n                    ret_.meta[\"crop_center\"] = center\n                    self.push_transform(ret_, replace=True, lazy=lazy_)\n                results.append(cropped)\n\n        return results\n\n\nclass ResizeWithPadOrCrop(InvertibleTransform, LazyTransform):\n    \"\"\"\n    Resize an image to a target spatial size by either centrally cropping the image or\n    padding it evenly with a user-specified mode.\n    When the dimension is smaller than the target size, do symmetric padding along that dim.\n    When the dimension is larger than the target size, do central cropping along that dim.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        spatial_size: the spatial size of output data after padding or crop.\n            If has non-positive values, the corresponding size of input image will be used (no padding).\n        method: {``\"symmetric\"``, ``\"end\"``}\n            Pad image symmetrically on every side or only pad at the end sides. 
Defaults to ``\"symmetric\"``.\n        mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. Defaults to ``\"constant\"``.\n            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n        lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.\n\n    \"\"\"\n\n    backend = list(set(SpatialPad.backend) & set(CenterSpatialCrop.backend))\n\n    def __init__(\n        self,\n        spatial_size: Sequence[int] | int,\n        method: str = Method.SYMMETRIC,\n        mode: str = PytorchPadMode.CONSTANT,\n        lazy: bool = False,\n        **pad_kwargs,\n    ):\n        LazyTransform.__init__(self, lazy)\n        self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, lazy=lazy, **pad_kwargs)\n        self.cropper = CenterSpatialCrop(roi_size=spatial_size, lazy=lazy)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool):\n        self.padder.lazy = val\n        self.cropper.lazy = val\n        self._lazy = val\n\n    def __call__(  # type: ignore[override]\n        self, img: torch.Tensor, mode: str | None = None, lazy: bool | None = None, **pad_kwargs\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: data to pad or crop, assuming `img` is channel-first and\n                padding or cropping doesn't apply to the 
channel dim.\n            mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n                ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n                available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n                One of the listed string values or a user supplied function. Defaults to ``\"constant\"``.\n                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            lazy: a flag to override the lazy behaviour for this call, if set. Defaults to None.\n            pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.\n                note that `np.pad` treats channel dimension as the first dimension.\n\n        \"\"\"\n        lazy_ = self.lazy if lazy is None else lazy\n        ret = self.padder(self.cropper(img, lazy_), mode=mode, lazy=lazy_, **pad_kwargs)\n        # remove the individual info and combine\n        if get_track_meta():\n            ret_: MetaTensor = ret  # type: ignore\n            if not lazy_:\n                pad_info = ret_.applied_operations.pop()\n                crop_info = ret_.applied_operations.pop()\n                orig_size = crop_info.get(TraceKeys.ORIG_SIZE)\n                self.push_transform(\n                    ret_, orig_size=orig_size, extra_info={\"pad_info\": pad_info, \"crop_info\": crop_info}, lazy=lazy_\n                )\n            else:\n                pad_info = ret_.pending_operations.pop()\n                crop_info = ret_.pending_operations.pop()\n                orig_size = crop_info.get(TraceKeys.ORIG_SIZE)\n                self.push_transform(\n                    ret_,\n                    orig_size=orig_size,\n                    sp_size=pad_info[LazyAttr.SHAPE],\n               
     affine=crop_info[LazyAttr.AFFINE] @ pad_info[LazyAttr.AFFINE],\n                    extra_info={\"pad_info\": pad_info, \"crop_info\": crop_info},\n                    lazy=lazy_,\n                )\n\n        return ret\n\n    def inverse(self, img: MetaTensor) -> MetaTensor:\n        transform = self.pop_transform(img)\n        return self.inverse_transform(img, transform)\n\n    def inverse_transform(self, img: MetaTensor, transform) -> MetaTensor:\n        # we joined the cropping and padding, so put them back before calling the inverse\n        crop_info = transform[TraceKeys.EXTRA_INFO].pop(\"crop_info\")\n        pad_info = transform[TraceKeys.EXTRA_INFO].pop(\"pad_info\")\n        img.applied_operations.append(crop_info)\n        img.applied_operations.append(pad_info)\n        # first inverse the padder\n        inv = self.padder.inverse(img)\n        # and then inverse the cropper (self)\n        return self.cropper.inverse(inv)\n\n\nclass BoundingRect(Transform):\n    \"\"\"\n    Compute coordinates of axis-aligned bounding rectangles from input image `img`.\n    The output format of the coordinates is (shape is [channel, 2 * spatial dims]):\n\n        [[1st_spatial_dim_start, 1st_spatial_dim_end,\n         2nd_spatial_dim_start, 2nd_spatial_dim_end,\n         ...,\n         Nth_spatial_dim_start, Nth_spatial_dim_end],\n\n         ...\n\n         [1st_spatial_dim_start, 1st_spatial_dim_end,\n         2nd_spatial_dim_start, 2nd_spatial_dim_end,\n         ...,\n         Nth_spatial_dim_start, Nth_spatial_dim_end]]\n\n    The bounding boxes edges are aligned with the input image edges.\n    This function returns [0, 0, ...] 
if there's no positive intensity.\n\n    Args:\n        select_fn: function to select expected foreground, default is to select values > 0.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, select_fn: Callable = is_positive) -> None:\n        self.select_fn = select_fn\n\n    def __call__(self, img: NdarrayOrTensor) -> np.ndarray:\n        \"\"\"\n        See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`.\n        \"\"\"\n        bbox = []\n\n        for channel in range(img.shape[0]):\n            start_, end_ = generate_spatial_bounding_box(img, select_fn=self.select_fn, channel_indices=channel)\n            bbox.append([i for k in zip(start_, end_) for i in k])\n\n        return np.stack(bbox, axis=0)\n"
  },
  {
    "path": "monai/transforms/croppad/batch.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of \"vanilla\" transforms for crop and pad operations acting on batches of data.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Hashable, Mapping\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import list_data_collate\nfrom monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad\nfrom monai.transforms.inverse import InvertibleTransform\nfrom monai.utils.enums import Method, PytorchPadMode, TraceKeys\n\n__all__ = [\"PadListDataCollate\"]\n\n\ndef replace_element(to_replace, batch, idx, key_or_idx):\n    # since tuple is immutable we'll have to recreate\n    if isinstance(batch[idx], tuple):\n        batch_idx_list = list(batch[idx])\n        batch_idx_list[key_or_idx] = to_replace\n        batch[idx] = tuple(batch_idx_list)\n    # else, replace\n    else:\n        batch[idx][key_or_idx] = to_replace\n    return batch\n\n\nclass PadListDataCollate(InvertibleTransform):\n    \"\"\"\n    Same as MONAI's ``list_data_collate``, except any tensors are centrally padded to match the shape of the biggest\n    tensor in each dimension. 
This transform is useful if some of the applied transforms generate batch data of\n    different sizes.\n\n    This can be used on both list and dictionary data.\n    Note that in the case of the dictionary data, it may add the transform information to the list of invertible transforms\n    if input batch have different spatial shape, so need to call static method: `inverse` before inverting other transforms.\n\n    Note that normally, a user won't explicitly use the `__call__` method. Rather this would be passed to the `DataLoader`.\n    This means that `__call__` handles data as it comes out of a `DataLoader`, containing batch dimension. However, the\n    `inverse` operates on dictionaries containing images of shape `C,H,W,[D]`. This asymmetry is necessary so that we can\n    pass the inverse through multiprocessing.\n\n    Args:\n        method: padding method (see :py:class:`monai.transforms.SpatialPad`)\n        mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    \"\"\"\n\n    def __init__(self, method: str = Method.SYMMETRIC, mode: str = PytorchPadMode.CONSTANT, **kwargs) -> None:\n        self.method = method\n        self.mode = mode\n        self.kwargs = kwargs\n\n    def __call__(self, batch: Any):\n        \"\"\"\n        Args:\n            batch: batch of data to pad-collate\n        \"\"\"\n        # data is either list of dicts or list of lists\n        is_list_of_dicts = isinstance(batch[0], dict)\n        # loop over items inside of each element in a batch\n        batch_item = tuple(batch[0].keys()) if is_list_of_dicts else range(len(batch[0]))\n        for key_or_idx in batch_item:\n            # calculate max size of each dimension\n            max_shapes = []\n            for elem in batch:\n                if not isinstance(elem[key_or_idx], (torch.Tensor, 
np.ndarray)):\n                    break\n                max_shapes.append(elem[key_or_idx].shape[1:])\n            # len > 0 if objects were arrays, else skip as no padding to be done\n            if not max_shapes:\n                continue\n            max_shape = np.array(max_shapes).max(axis=0)\n            # If all same size, skip\n            if np.all(np.array(max_shapes).min(axis=0) == max_shape):\n                continue\n\n            # Use `SpatialPad` to match sizes, Default params are central padding, padding with 0's\n            padder = SpatialPad(spatial_size=max_shape, method=self.method, mode=self.mode, **self.kwargs)\n            for idx, batch_i in enumerate(batch):\n                orig_size = batch_i[key_or_idx].shape[1:]\n                padded = padder(batch_i[key_or_idx])\n                batch = replace_element(padded, batch, idx, key_or_idx)\n\n                # If we have a dictionary of data, append to list\n                # padder transform info is re-added with self.push_transform to ensure one info dict per transform.\n                if is_list_of_dicts:\n                    self.push_transform(\n                        batch[idx],\n                        key_or_idx,\n                        orig_size=orig_size,\n                        extra_info=self.pop_transform(batch[idx], key_or_idx, check=False),\n                    )\n\n        # After padding, use default list collator\n        return list_data_collate(batch)\n\n    @staticmethod\n    def inverse(data: dict) -> dict[Hashable, np.ndarray]:\n        if not isinstance(data, Mapping):\n            raise RuntimeError(f\"Inverse can only currently be applied on dictionaries, got type {type(data)}.\")\n\n        d = dict(data)\n        for key in d:\n            transforms = None\n            if isinstance(d[key], MetaTensor):\n                transforms = d[key].applied_operations\n            else:\n                transform_key = InvertibleTransform.trace_key(key)\n      
          if transform_key in d:\n                    transforms = d[transform_key]\n            if not transforms or not isinstance(transforms[-1], dict):\n                continue\n            if transforms[-1].get(TraceKeys.CLASS_NAME) == PadListDataCollate.__name__:\n                xform = transforms.pop()\n                cropping = CenterSpatialCrop(xform.get(TraceKeys.ORIG_SIZE, -1))\n                with cropping.trace_transform(False):\n                    d[key] = cropping(d[key])  # fallback to image size\n        return d\n"
  },
  {
    "path": "monai/transforms/croppad/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of dictionary-based wrappers around the \"vanilla\" transforms for crop and pad operations\ndefined in :py:class:`monai.transforms.croppad.array`.\n\nClass names are ended with 'd' to denote dictionary-based transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Hashable, Mapping, Sequence\nfrom copy import deepcopy\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.config import IndexSelection, KeysCollection, SequenceStr\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms.croppad.array import (\n    BorderPad,\n    BoundingRect,\n    CenterScaleCrop,\n    CenterSpatialCrop,\n    Crop,\n    CropForeground,\n    DivisiblePad,\n    Pad,\n    RandCropByLabelClasses,\n    RandCropByPosNegLabel,\n    RandScaleCrop,\n    RandSpatialCrop,\n    RandSpatialCropSamples,\n    RandWeightedCrop,\n    ResizeWithPadOrCrop,\n    SpatialCrop,\n    SpatialPad,\n)\nfrom monai.transforms.inverse import InvertibleTransform\nfrom monai.transforms.traits import LazyTrait, MultiSampleTrait\nfrom monai.transforms.transform import LazyTransform, MapTransform, Randomizable\nfrom monai.transforms.utils import is_positive\nfrom monai.utils import MAX_SEED, Method, PytorchPadMode, ensure_tuple_rep\n\n__all__ = [\n    \"Padd\",\n    
\"SpatialPadd\",\n    \"BorderPadd\",\n    \"DivisiblePadd\",\n    \"Cropd\",\n    \"RandCropd\",\n    \"SpatialCropd\",\n    \"CenterSpatialCropd\",\n    \"CenterScaleCropd\",\n    \"RandScaleCropd\",\n    \"RandSpatialCropd\",\n    \"RandSpatialCropSamplesd\",\n    \"CropForegroundd\",\n    \"RandWeightedCropd\",\n    \"RandCropByPosNegLabeld\",\n    \"ResizeWithPadOrCropd\",\n    \"BoundingRectd\",\n    \"RandCropByLabelClassesd\",\n    \"PadD\",\n    \"PadDict\",\n    \"SpatialPadD\",\n    \"SpatialPadDict\",\n    \"BorderPadD\",\n    \"BorderPadDict\",\n    \"DivisiblePadD\",\n    \"DivisiblePadDict\",\n    \"CropD\",\n    \"CropDict\",\n    \"RandCropD\",\n    \"RandCropDict\",\n    \"SpatialCropD\",\n    \"SpatialCropDict\",\n    \"CenterSpatialCropD\",\n    \"CenterSpatialCropDict\",\n    \"CenterScaleCropD\",\n    \"CenterScaleCropDict\",\n    \"RandScaleCropD\",\n    \"RandScaleCropDict\",\n    \"RandSpatialCropD\",\n    \"RandSpatialCropDict\",\n    \"RandSpatialCropSamplesD\",\n    \"RandSpatialCropSamplesDict\",\n    \"CropForegroundD\",\n    \"CropForegroundDict\",\n    \"RandWeightedCropD\",\n    \"RandWeightedCropDict\",\n    \"RandCropByPosNegLabelD\",\n    \"RandCropByPosNegLabelDict\",\n    \"ResizeWithPadOrCropD\",\n    \"ResizeWithPadOrCropDict\",\n    \"BoundingRectD\",\n    \"BoundingRectDict\",\n    \"RandCropByLabelClassesD\",\n    \"RandCropByLabelClassesDict\",\n]\n\n\nclass Padd(MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Pad`.\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = Pad.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        padder: Pad,\n        mode: SequenceStr = PytorchPadMode.CONSTANT,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            padder: pad transform for the input image.\n            mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n                ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n                available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n                One of the listed string values or a user supplied function. Defaults to ``\"constant\"``.\n                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n                It also can be a sequence of string, each element corresponds to a key in ``keys``.\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy)\n        if lazy is True and not isinstance(padder, LazyTrait):\n            raise ValueError(\"'padder' must inherit LazyTrait if lazy is True \" f\"'padder' is of type({type(padder)})\")\n        self.padder = padder\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, value: bool) -> None:\n        self._lazy = value\n        if isinstance(self.padder, LazyTransform):\n            self.padder.lazy = value\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        lazy_ = self.lazy if lazy is None else lazy\n        if lazy_ is True and not isinstance(self.padder, LazyTrait):\n            raise ValueError(\n                \"'self.padder' must inherit LazyTrait if lazy is True \" f\"'self.padder' is of type({type(self.padder)}\"\n            )\n        for key, m in self.key_iterator(d, self.mode):\n            if isinstance(self.padder, LazyTrait):\n                d[key] = self.padder(d[key], mode=m, lazy=lazy_)\n            else:\n                d[key] = self.padder(d[key], mode=m)\n\n        return d\n\n    def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.padder.inverse(d[key])\n        return d\n\n\nclass SpatialPadd(Padd):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.SpatialPad`.\n    Performs padding to the data, symmetric for all sides or all on one side for each dimension.\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        spatial_size: Sequence[int] | int,\n        method: str = Method.SYMMETRIC,\n        mode: SequenceStr = PytorchPadMode.CONSTANT,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            spatial_size: the spatial size of output data after padding, if a dimension of the input\n                data size is larger than the pad size, will not pad that dimension.\n                If its components have non-positive values, the corresponding size of input image will be used.\n                for example: if the spatial size of input data is [30, 30, 30] and `spatial_size=[32, 25, -1]`,\n                the spatial size of output data will be [32, 30, 30].\n            method: {``\"symmetric\"``, ``\"end\"``}\n                Pad image symmetrically on every side or only pad at the end sides. Defaults to ``\"symmetric\"``.\n            mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n                ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n                available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n                One of the listed string values or a user supplied function. 
Defaults to ``\"constant\"``.\n                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n                It also can be a sequence of string, each element corresponds to a key in ``keys``.\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.\n            kwargs: other arguments for the `np.pad` or `torch.pad` function.\n                note that `np.pad` treats channel dimension as the first dimension.\n\n        \"\"\"\n        padder = SpatialPad(spatial_size, method, lazy=lazy, **kwargs)\n        Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy)\n\n\nclass BorderPadd(Padd):\n    \"\"\"\n    Pad the input data by adding specified borders to every dimension.\n    Dictionary-based wrapper of :py:class:`monai.transforms.BorderPad`.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = BorderPad.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        spatial_border: Sequence[int] | int,\n        mode: SequenceStr = PytorchPadMode.CONSTANT,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            spatial_border: specified size for every spatial border. 
it can be 3 shapes:\n\n                - single int number, pad all the borders with the same size.\n                - length equals the length of image shape, pad every spatial dimension separately.\n                  for example, image shape(CHW) is [1, 4, 4], spatial_border is [2, 1],\n                  pad every border of H dim with 2, pad every border of W dim with 1, result shape is [1, 8, 6].\n                - length equals 2 x (length of image shape), pad every border of every dimension separately.\n                  for example, image shape(CHW) is [1, 4, 4], spatial_border is [1, 2, 3, 4], pad top of H dim with 1,\n                  pad bottom of H dim with 2, pad left of W dim with 3, pad right of W dim with 4.\n                  the result shape is [1, 7, 11].\n\n            mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n                ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n                available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n                One of the listed string values or a user supplied function. Defaults to ``\"constant\"``.\n                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n                It also can be a sequence of string, each element corresponds to a key in ``keys``.\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n            kwargs: other arguments for the `np.pad` or `torch.pad` function.\n                note that `np.pad` treats channel dimension as the first dimension.\n\n        \"\"\"\n        padder = BorderPad(spatial_border=spatial_border, lazy=lazy, **kwargs)\n        Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy)\n\n\nclass DivisiblePadd(Padd):\n    \"\"\"\n    Pad the input data, so that the spatial sizes are divisible by `k`.\n    Dictionary-based wrapper of :py:class:`monai.transforms.DivisiblePad`.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = DivisiblePad.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        k: Sequence[int] | int,\n        mode: SequenceStr = PytorchPadMode.CONSTANT,\n        method: str = Method.SYMMETRIC,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            k: the target k for each spatial dimension.\n                if `k` is negative or 0, the original size is preserved.\n                if `k` is an int, the same `k` be applied to all the input spatial dimensions.\n            mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n                ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n                available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n                One of the listed string values or a user supplied function. 
Defaults to ``\"constant\"``.\n                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n                It also can be a sequence of string, each element corresponds to a key in ``keys``.\n            method: {``\"symmetric\"``, ``\"end\"``}\n                Pad image symmetrically on every side or only pad at the end sides. Defaults to ``\"symmetric\"``.\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.\n            kwargs: other arguments for the `np.pad` or `torch.pad` function.\n                note that `np.pad` treats channel dimension as the first dimension.\n\n        See also :py:class:`monai.transforms.SpatialPad`\n\n        \"\"\"\n        padder = DivisiblePad(k=k, method=method, lazy=lazy, **kwargs)\n        Padd.__init__(self, keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy)\n\n\nclass Cropd(MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based wrapper of abstract class :py:class:`monai.transforms.Crop`.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        cropper: crop transform for the input image.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n    \"\"\"\n\n    backend = Crop.backend\n\n    def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False, lazy: bool = False):\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy)\n        self.cropper = cropper\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, value: bool) -> None:\n        self._lazy = value\n        if isinstance(self.cropper, LazyTransform):\n            self.cropper.lazy = value\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        lazy_ = self.lazy if lazy is None else lazy\n        for key in self.key_iterator(d):\n            d[key] = self.cropper(d[key], lazy=lazy_)  # type: ignore\n        return d\n\n    def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.cropper.inverse(d[key])\n        return d\n\n\nclass RandCropd(Cropd, Randomizable):\n    \"\"\"\n    Base class for random crop transform.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        cropper: random crop transform for the input image.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n    \"\"\"\n\n    backend = Crop.backend\n\n    def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool = False, lazy: bool = False):\n        super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy)\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandCropd:\n        super().set_random_state(seed, state)\n        if isinstance(self.cropper, Randomizable):\n            self.cropper.set_random_state(seed, state)\n        return self\n\n    def randomize(self, img_size: Sequence[int]) -> None:\n        if isinstance(self.cropper, Randomizable):\n            self.cropper.randomize(img_size)\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        # the first key must exist to execute random operations\n        first_item = d[self.first_key(d)]\n        self.randomize(first_item.peek_pending_shape() if isinstance(first_item, MetaTensor) else first_item.shape[1:])\n        lazy_ = self.lazy if lazy is None else lazy\n        if lazy_ is True and not isinstance(self.cropper, LazyTrait):\n            raise ValueError(\n                \"'self.cropper' must inherit LazyTrait if lazy is True \"\n                f\"'self.cropper' is of type({type(self.cropper)}\"\n            )\n        for key in self.key_iterator(d):\n            kwargs = {\"randomize\": False} if isinstance(self.cropper, Randomizable) else {}\n            if isinstance(self.cropper, LazyTrait):\n                kwargs[\"lazy\"] = lazy_\n            d[key] = self.cropper(d[key], **kwargs)  # type: ignore\n        return d\n\n\nclass SpatialCropd(Cropd):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.SpatialCrop`.\n    General purpose cropper to produce sub-volume region of interest (ROI).\n    If a dimension of the expected ROI size is larger than 
the input image size, will not crop that dimension.\n    So the cropped result may be smaller than the expected ROI, and the cropped results of several images may\n    not have exactly the same shape.\n    It can support to crop ND spatial (channel-first) data.\n\n    The cropped region can be parameterised in various ways:\n        - a list of slices for each spatial dimension (allows for use of -ve indexing and `None`)\n        - a spatial center and size\n        - the start and end coordinates of the ROI\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        roi_center: Sequence[int] | int | None = None,\n        roi_size: Sequence[int] | int | None = None,\n        roi_start: Sequence[int] | int | None = None,\n        roi_end: Sequence[int] | int | None = None,\n        roi_slices: Sequence[slice] | None = None,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            roi_center: voxel coordinates for center of the crop ROI.\n            roi_size: size of the crop ROI, if a dimension of ROI size is larger than image size,\n                will not crop that dimension of the image.\n            roi_start: voxel coordinates for start of the crop ROI.\n            roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image,\n                use the end coordinate of image.\n            roi_slices: list of slices for each of the spatial dimensions.\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n        \"\"\"\n        cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices, lazy=lazy)\n        super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy)\n\n\nclass CenterSpatialCropd(Cropd):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.CenterSpatialCrop`.\n    If a dimension of the expected ROI size is larger than the input image size, will not crop that dimension.\n    So the cropped result may be smaller than the expected ROI, and the cropped results of several images may\n    not have exactly the same shape.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        roi_size: the size of the crop region e.g. [224,224,128]\n            if a dimension of ROI size is larger than image size, will not crop that dimension of the image.\n            If its components have non-positive values, the corresponding size of input image will be used.\n            for example: if the spatial size of input data is [40, 40, 40] and `roi_size=[32, 64, -1]`,\n            the spatial size of output data will be [32, 40, 40].\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self, keys: KeysCollection, roi_size: Sequence[int] | int, allow_missing_keys: bool = False, lazy: bool = False\n    ) -> None:\n        cropper = CenterSpatialCrop(roi_size, lazy=lazy)\n        super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy)\n\n\nclass CenterScaleCropd(Cropd):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.CenterScaleCrop`.\n    Note: as using the same scaled ROI to crop, all the input data specified by `keys` should have\n    the same spatial shape.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        roi_scale: specifies the expected scale of image size to crop. e.g. [0.3, 0.4, 0.5] or a number for all dims.\n            If its components have non-positive values, will use `1.0` instead, which means the input image size.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        roi_scale: Sequence[float] | float,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        cropper = CenterScaleCrop(roi_scale, lazy=lazy)\n        super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy)\n\n\nclass RandSpatialCropd(RandCropd):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandSpatialCrop`.\n    Crop image with random size or specific size ROI. It can crop at a random position as\n    center or at the image center. And allows to set the minimum and maximum size to limit the randomly\n    generated ROI. 
Suppose all the expected fields specified by `keys` have the same shape.\n\n    Note: even if `random_size=False`, if a dimension of the expected ROI size is larger than the input image size,\n    will not crop that dimension. So the cropped result may be smaller than the expected ROI, and the cropped\n    results of several images may not have exactly the same shape.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        roi_size: if `random_size` is True, it specifies the minimum crop region.\n            if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128]\n            if a dimension of ROI size is larger than image size, will not crop that dimension of the image.\n            If its components have non-positive values, the corresponding size of input image will be used.\n            for example: if the spatial size of input data is [40, 40, 40] and `roi_size=[32, 64, -1]`,\n            the spatial size of output data will be [32, 40, 40].\n        max_roi_size: if `random_size` is True and `roi_size` specifies the min crop region size, `max_roi_size`\n            can specify the max crop region size. if None, defaults to the input image size.\n            if its components have non-positive values, the corresponding size of input image will be used.\n        random_center: crop at random position as center or the image center.\n        random_size: crop with random size or specific size ROI.\n            if True, the actual size is sampled from:\n            `randint(roi_size, max_roi_size + 1)`.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        roi_size: Sequence[int] | int,\n        max_roi_size: Sequence[int] | int | None = None,\n        random_center: bool = True,\n        random_size: bool = False,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        cropper = RandSpatialCrop(roi_size, max_roi_size, random_center, random_size, lazy=lazy)\n        super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy)\n\n\nclass RandScaleCropd(RandCropd):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandScaleCrop`.\n    Crop image with random size or specific size ROI.\n    It can crop at a random position as center or at the image center.\n    And allows to set the minimum and maximum scale of image size to limit the randomly generated ROI.\n    Suppose all the expected fields specified by `keys` have same shape.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        roi_scale: if `random_size` is True, it specifies the minimum crop size: `roi_scale * image spatial size`.\n            if `random_size` is False, it specifies the expected scale of image size to crop. e.g. [0.3, 0.4, 0.5].\n            If its components have non-positive values, will use `1.0` instead, which means the input image size.\n        max_roi_scale: if `random_size` is True and `roi_scale` specifies the min crop region size, `max_roi_scale`\n            can specify the max crop region size: `max_roi_scale * image spatial size`.\n            if None, defaults to the input image size. 
if its components have non-positive values,\n            will use `1.0` instead, which means the input image size.\n        random_center: crop at random position as center or the image center.\n        random_size: crop with random size or specified size ROI by `roi_scale * image spatial size`.\n            if True, the actual size is sampled from:\n            `randint(roi_scale * image spatial size, max_roi_scale * image spatial size + 1)`.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        roi_scale: Sequence[float] | float,\n        max_roi_scale: Sequence[float] | float | None = None,\n        random_center: bool = True,\n        random_size: bool = False,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        cropper = RandScaleCrop(roi_scale, max_roi_scale, random_center, random_size, lazy=lazy)\n        super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy)\n\n\nclass RandSpatialCropSamplesd(Randomizable, MapTransform, LazyTransform, MultiSampleTrait):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandSpatialCropSamples`.\n    Crop image with random size or specific size ROI to generate a list of N samples.\n    It can crop at a random position as center or at the image center. And allows to set\n    the minimum size to limit the randomly generated ROI. Suppose all the expected fields\n    specified by `keys` have same shape, and add `patch_index` to the corresponding metadata.\n    It will return a list of dictionaries for all the cropped images.\n\n    Note: even `random_size=False`, if a dimension of the expected ROI size is larger than the input image size,\n    will not crop that dimension. 
So the cropped result may be smaller than the expected ROI, and the cropped\n    results of several images may not have exactly the same shape.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        roi_size: if `random_size` is True, it specifies the minimum crop region.\n            if `random_size` is False, it specifies the expected ROI size to crop. e.g. [224, 224, 128]\n            if a dimension of ROI size is larger than image size, will not crop that dimension of the image.\n            If its components have non-positive values, the corresponding size of input image will be used.\n            for example: if the spatial size of input data is [40, 40, 40] and `roi_size=[32, 64, -1]`,\n            the spatial size of output data will be [32, 40, 40].\n        num_samples: number of samples (crop regions) to take in the returned list.\n        max_roi_size: if `random_size` is True and `roi_size` specifies the min crop region size, `max_roi_size`\n            can specify the max crop region size. if None, defaults to the input image size.\n            if its components have non-positive values, the corresponding size of input image will be used.\n        random_center: crop at random position as center or the image center.\n        random_size: crop with random size or specific size ROI.\n            The actual size is sampled from `randint(roi_size, img_size)`.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n\n    Raises:\n        ValueError: When ``num_samples`` is nonpositive.\n\n    \"\"\"\n\n    backend = RandSpatialCropSamples.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        roi_size: Sequence[int] | int,\n        num_samples: int,\n        max_roi_size: Sequence[int] | int | None = None,\n        random_center: bool = True,\n        random_size: bool = False,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy)\n        self.cropper = RandSpatialCropSamples(\n            roi_size, num_samples, max_roi_size, random_center, random_size, lazy=lazy\n        )\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, value: bool) -> None:\n        self._lazy = value\n        self.cropper.lazy = value\n\n    def randomize(self, data: Any | None = None) -> None:\n        self.sub_seed = self.R.randint(MAX_SEED, dtype=\"uint32\")\n\n    def __call__(\n        self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None\n    ) -> list[dict[Hashable, torch.Tensor]]:\n        ret: list[dict[Hashable, torch.Tensor]] = [dict(data) for _ in range(self.cropper.num_samples)]\n        # deep copy all the unmodified data\n        for i in range(self.cropper.num_samples):\n            for key in set(data.keys()).difference(set(self.keys)):\n                ret[i][key] = deepcopy(data[key])\n\n        # for each key we reset the random state to ensure crops are the same\n        self.randomize()\n\n        lazy_ = self.lazy if lazy is None else lazy\n        for key in self.key_iterator(dict(data)):\n            self.cropper.set_random_state(seed=int(self.sub_seed))\n            for i, im in enumerate(self.cropper(data[key], lazy=lazy_)):\n                ret[i][key] = im\n        return ret\n\n\nclass CropForegroundd(Cropd):\n    \"\"\"\n    Dictionary-based version 
:py:class:`monai.transforms.CropForeground`.\n    Crop only the foreground object of the expected images.\n    The typical usage is to help training and evaluation if the valid part is small in the whole medical image.\n    The valid part can be determined by any field in the data with `source_key`, for example:\n    - Select values > 0 in image field as the foreground and crop on all fields specified by `keys`.\n    - Select label = 3 in label field as the foreground to crop on all fields specified by `keys`.\n    - Select label > 0 in the third channel of a One-Hot label field as the foreground to crop all `keys` fields.\n    Users can define arbitrary function to select expected foreground from the whole source image or specified\n    channels. And it can also add margin to every dim of the bounding box of foreground object.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        source_key: str,\n        select_fn: Callable = is_positive,\n        channel_indices: IndexSelection | None = None,\n        margin: Sequence[int] | int = 0,\n        allow_smaller: bool = False,\n        k_divisible: Sequence[int] | int = 1,\n        mode: SequenceStr = PytorchPadMode.CONSTANT,\n        start_coord_key: str | None = \"foreground_start_coord\",\n        end_coord_key: str | None = \"foreground_end_coord\",\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n        **pad_kwargs,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            source_key: data source to generate the bounding box of foreground, can be image or label, etc.\n            select_fn: function to select expected foreground, default is to select values > 0.\n            
channel_indices: if defined, select foreground only on the specified channels\n                of image. if None, select foreground on the whole image.\n            margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.\n            allow_smaller: when computing box size with `margin`, whether to allow the image edges to be smaller than the\n                final box edges. If `False`, part of a padded output box might be outside of the original image, if `True`,\n                the image edges will be used as the box edges. Defaults to `False`.\n                The default value is changed from `True` to `False` in v1.5.0.\n            k_divisible: make each spatial dimension to be divisible by k, default to 1.\n                if `k_divisible` is an int, the same `k` will be applied to all the input spatial dimensions.\n            mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n                ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n                available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n                One of the listed string values or a user supplied function. 
Defaults to ``\"constant\"``.\n                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n                it also can be a sequence of string, each element corresponds to a key in ``keys``.\n            start_coord_key: key to record the start coordinate of spatial bounding box for foreground.\n            end_coord_key: key to record the end coordinate of spatial bounding box for foreground.\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.\n            pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.\n                note that `np.pad` treats channel dimension as the first dimension.\n\n        \"\"\"\n        self.source_key = source_key\n        self.start_coord_key = start_coord_key\n        self.end_coord_key = end_coord_key\n        cropper = CropForeground(\n            select_fn=select_fn,\n            channel_indices=channel_indices,\n            margin=margin,\n            allow_smaller=allow_smaller,\n            k_divisible=k_divisible,\n            lazy=lazy,\n            **pad_kwargs,\n        )\n        super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy)\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, value: bool) -> None:\n        self._lazy = value\n        self.cropper.lazy = value\n\n    @property\n    def requires_current_data(self):\n        return True\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        self.cropper: CropForeground\n        box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key])\n        if self.start_coord_key is 
not None:\n            d[self.start_coord_key] = box_start  # type: ignore\n        if self.end_coord_key is not None:\n            d[self.end_coord_key] = box_end  # type: ignore\n\n        lazy_ = self.lazy if lazy is None else lazy\n        for key, m in self.key_iterator(d, self.mode):\n            d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m, lazy=lazy_)\n        return d\n\n\nclass RandWeightedCropd(Randomizable, MapTransform, LazyTransform, MultiSampleTrait):\n    \"\"\"\n    Samples a list of `num_samples` image patches according to the provided `weight_map`.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        w_key: key for the weight map. The corresponding value will be used as the sampling weights,\n            it should be a single-channel array in size, for example, `(1, spatial_dim_0, spatial_dim_1, ...)`\n        spatial_size: the spatial size of the image patch e.g. [224, 224, 128].\n            If its components have non-positive values, the corresponding size of `img` will be used.\n        num_samples: number of samples (image patches) to take in the returned list.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n\n    See Also:\n        :py:class:`monai.transforms.RandWeightedCrop`\n    \"\"\"\n\n    backend = SpatialCrop.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        w_key: str,\n        spatial_size: Sequence[int] | int,\n        num_samples: int = 1,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ):\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy)\n        self.w_key = w_key\n        self.cropper = RandWeightedCrop(spatial_size, num_samples, lazy=lazy)\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandWeightedCropd:\n        super().set_random_state(seed, state)\n        self.cropper.set_random_state(seed, state)\n        return self\n\n    def randomize(self, weight_map: NdarrayOrTensor) -> None:\n        self.cropper.randomize(weight_map)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, value: bool) -> None:\n        self._lazy = value\n        self.cropper.lazy = value\n\n    def __call__(\n        self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None\n    ) -> list[dict[Hashable, torch.Tensor]]:\n        # output starts as empty list of dictionaries\n        ret: list = [dict(data) for _ in range(self.cropper.num_samples)]\n        # deep copy all the unmodified data\n        for i in range(self.cropper.num_samples):\n            for key in set(data.keys()).difference(set(self.keys)):\n                ret[i][key] = deepcopy(data[key])\n\n        self.randomize(weight_map=data[self.w_key])\n        lazy_ = self.lazy if lazy is None else lazy\n        for key in self.key_iterator(data):\n            for i, im in enumerate(self.cropper(data[key], randomize=False, lazy=lazy_)):\n                ret[i][key] = im\n        return ret\n\n\nclass RandCropByPosNegLabeld(Randomizable, MapTransform, LazyTransform, 
MultiSampleTrait):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandCropByPosNegLabel`.\n    Crop random fixed sized regions with the center being a foreground or background voxel\n    based on the Pos Neg Ratio.\n    Suppose all the expected fields specified by `keys` have same shape,\n    and add `patch_index` to the corresponding metadata.\n    And will return a list of dictionaries for all the cropped images.\n\n    If a dimension of the expected spatial size is larger than the input image size,\n    will not crop that dimension. So the cropped result may be smaller than the expected size,\n    and the cropped results of several images may not have exactly the same shape.\n    And if the crop ROI is partly out of the image, will automatically adjust the crop center\n    to ensure the valid crop ROI.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        label_key: name of key for label image, this will be used for finding foreground/background.\n        spatial_size: the spatial size of the crop region e.g. 
[224, 224, 128].\n            if a dimension of ROI size is larger than image size, will not crop that dimension of the image.\n            if its components have non-positive values, the corresponding size of `data[label_key]` will be used.\n            for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`,\n            the spatial size of output data will be [32, 40, 40].\n        pos: used with `neg` together to calculate the ratio ``pos / (pos + neg)`` for the probability\n            to pick a foreground voxel as a center rather than a background voxel.\n        neg: used with `pos` together to calculate the ratio ``pos / (pos + neg)`` for the probability\n            to pick a foreground voxel as a center rather than a background voxel.\n        num_samples: number of samples (crop regions) to take in each list.\n        image_key: if image_key is not None, use ``label == 0 & image > image_threshold`` to select\n            the negative sample(background) center. 
so the crop center will only exist on valid image area.\n        image_threshold: if enabled image_key, use ``image > image_threshold`` to determine\n            the valid image content area.\n        fg_indices_key: if provided pre-computed foreground indices of `label`, will ignore above `image_key` and\n            `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices_key`\n            and `bg_indices_key` together, expect to be 1 dim array of spatial indices after flattening.\n            a typical usage is to call `FgBgToIndicesd` transform first and cache the results.\n        bg_indices_key: if provided pre-computed background indices of `label`, will ignore above `image_key` and\n            `image_threshold`, and randomly select crop centers based on them, need to provide `fg_indices_key`\n            and `bg_indices_key` together, expect to be 1 dim array of spatial indices after flattening.\n            a typical usage is to call `FgBgToIndicesd` transform first and cache the results.\n        allow_smaller: if `False`, an exception will be raised if the image is smaller than\n            the requested ROI in any dimension. If `True`, any smaller dimensions will be set to\n            match the cropped size (i.e., no cropping in that dimension).\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.\n\n    Raises:\n        ValueError: When ``pos`` or ``neg`` are negative.\n        ValueError: When ``pos=0`` and ``neg=0``. 
Incompatible values.\n\n    \"\"\"\n\n    backend = RandCropByPosNegLabel.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        label_key: str,\n        spatial_size: Sequence[int] | int,\n        pos: float = 1.0,\n        neg: float = 1.0,\n        num_samples: int = 1,\n        image_key: str | None = None,\n        image_threshold: float = 0.0,\n        fg_indices_key: str | None = None,\n        bg_indices_key: str | None = None,\n        allow_smaller: bool = False,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy)\n        self.label_key = label_key\n        self.image_key = image_key\n        self.fg_indices_key = fg_indices_key\n        self.bg_indices_key = bg_indices_key\n        self.cropper = RandCropByPosNegLabel(\n            spatial_size=spatial_size,\n            pos=pos,\n            neg=neg,\n            num_samples=num_samples,\n            image_threshold=image_threshold,\n            allow_smaller=allow_smaller,\n            lazy=lazy,\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandCropByPosNegLabeld:\n        super().set_random_state(seed, state)\n        self.cropper.set_random_state(seed, state)\n        return self\n\n    def randomize(\n        self,\n        label: torch.Tensor | None = None,\n        fg_indices: NdarrayOrTensor | None = None,\n        bg_indices: NdarrayOrTensor | None = None,\n        image: torch.Tensor | None = None,\n    ) -> None:\n        self.cropper.randomize(label=label, fg_indices=fg_indices, bg_indices=bg_indices, image=image)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, value: bool) -> None:\n        self._lazy = value\n        self.cropper.lazy = value\n\n    @property\n    def requires_current_data(self):\n        return 
True\n\n    def __call__(\n        self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None\n    ) -> list[dict[Hashable, torch.Tensor]]:\n        d = dict(data)\n        fg_indices = d.pop(self.fg_indices_key, None)\n        bg_indices = d.pop(self.bg_indices_key, None)\n\n        self.randomize(d.get(self.label_key), fg_indices, bg_indices, d.get(self.image_key))\n\n        # initialize returned list with shallow copy to preserve key ordering\n        ret: list = [dict(d) for _ in range(self.cropper.num_samples)]\n        # deep copy all the unmodified data\n        for i in range(self.cropper.num_samples):\n            for key in set(d.keys()).difference(set(self.keys)):\n                ret[i][key] = deepcopy(d[key])\n\n        lazy_ = self.lazy if lazy is None else lazy\n        for key in self.key_iterator(d):\n            for i, im in enumerate(self.cropper(d[key], randomize=False, lazy=lazy_)):\n                ret[i][key] = im\n        return ret\n\n\nclass RandCropByLabelClassesd(Randomizable, MapTransform, LazyTransform, MultiSampleTrait):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandCropByLabelClasses`.\n    Crop random fixed sized regions with the center being a class based on the specified ratios of every class.\n    The label data can be One-Hot format array or Argmax data. And will return a list of dictionaries for all the\n    cropped images. 
For example, crop two (3 x 3) arrays from (5 x 5) array with `ratios=[1, 2, 3, 1]`::\n\n        cropper = RandCropByLabelClassesd(\n            keys=[\"image\", \"label\"],\n            label_key=\"label\",\n            spatial_size=[3, 3],\n            ratios=[1, 2, 3, 1],\n            num_classes=4,\n            num_samples=2,\n        )\n        data = {\n            \"image\": np.array([\n                [[0.0, 0.3, 0.4, 0.2, 0.0],\n                [0.0, 0.1, 0.2, 0.1, 0.4],\n                [0.0, 0.3, 0.5, 0.2, 0.0],\n                [0.1, 0.2, 0.1, 0.1, 0.0],\n                [0.0, 0.1, 0.2, 0.1, 0.0]]\n            ]),\n            \"label\": np.array([\n                [[0, 0, 0, 0, 0],\n                [0, 1, 2, 1, 0],\n                [0, 1, 3, 0, 0],\n                [0, 0, 0, 0, 0],\n                [0, 0, 0, 0, 0]]\n            ]),\n        }\n        result = cropper(data)\n\n        The 2 randomly cropped samples of `label` can be:\n        [[0, 1, 2],     [[0, 0, 0],\n         [0, 1, 3],      [1, 2, 1],\n         [0, 0, 0]]      [1, 3, 0]]\n\n    If a dimension of the expected spatial size is larger than the input image size,\n    will not crop that dimension. So the cropped result may be smaller than expected size, and the cropped\n    results of several images may not have exactly same shape.\n    And if the crop ROI is partly out of the image, will automatically adjust the crop center to ensure the\n    valid crop ROI.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        label_key: name of key for label image, this will be used for finding indices of every class.\n        spatial_size: the spatial size of the crop region e.g. 
[224, 224, 128].\n            if a dimension of ROI size is larger than image size, will not crop that dimension of the image.\n            if its components have non-positive values, the corresponding size of `label` will be used.\n            for example: if the spatial size of input data is [40, 40, 40] and `spatial_size=[32, 64, -1]`,\n            the spatial size of output data will be [32, 40, 40].\n        ratios: specified ratios of every class in the label to generate crop centers, including background class.\n            if None, every class will have the same ratio to generate crop centers.\n        num_classes: number of classes for argmax label, not necessary for One-Hot label.\n        num_samples: number of samples (crop regions) to take in each list.\n        image_key: if image_key is not None, only return the indices of every class that are within the valid\n            region of the image (``image > image_threshold``).\n        image_threshold: if enabled `image_key`, use ``image > image_threshold`` to\n            determine the valid image content area and select class indices only in this area.\n        indices_key: if provided pre-computed indices of every class, will ignore above `image` and\n            `image_threshold`, and randomly select crop centers based on them, expect to be 1 dim array\n            of spatial indices after flattening. a typical usage is to call `ClassesToIndices` transform first\n            and cache the results for better performance.\n        allow_smaller: if `False`, an exception will be raised if the image is smaller than\n            the requested ROI in any dimension. 
If `True`, any smaller dimensions will remain\n            unchanged.\n        allow_missing_keys: don't raise exception if key is missing.\n        warn: if `True` prints a warning if a class is not present in the label.\n        max_samples_per_class: maximum length of indices in each class to reduce memory consumption.\n            Default is None, no subsampling.\n        lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.\n    \"\"\"\n\n    backend = RandCropByLabelClasses.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        label_key: str,\n        spatial_size: Sequence[int] | int,\n        ratios: list[float | int] | None = None,\n        num_classes: int | None = None,\n        num_samples: int = 1,\n        image_key: str | None = None,\n        image_threshold: float = 0.0,\n        indices_key: str | None = None,\n        allow_smaller: bool = False,\n        allow_missing_keys: bool = False,\n        warn: bool = True,\n        max_samples_per_class: int | None = None,\n        lazy: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy)\n        self.label_key = label_key\n        self.image_key = image_key\n        self.indices_key = indices_key\n        self.cropper = RandCropByLabelClasses(\n            spatial_size=spatial_size,\n            ratios=ratios,\n            num_classes=num_classes,\n            num_samples=num_samples,\n            image_threshold=image_threshold,\n            allow_smaller=allow_smaller,\n            warn=warn,\n            max_samples_per_class=max_samples_per_class,\n            lazy=lazy,\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandCropByLabelClassesd:\n        super().set_random_state(seed, state)\n        self.cropper.set_random_state(seed, state)\n        return 
self\n\n    def randomize(\n        self, label: torch.Tensor, indices: list[NdarrayOrTensor] | None = None, image: torch.Tensor | None = None\n    ) -> None:\n        self.cropper.randomize(label=label, indices=indices, image=image)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, value: bool) -> None:\n        self._lazy = value\n        self.cropper.lazy = value\n\n    @property\n    def requires_current_data(self):\n        return True\n\n    def __call__(self, data: Mapping[Hashable, Any], lazy: bool | None = None) -> list[dict[Hashable, torch.Tensor]]:\n        d = dict(data)\n        self.randomize(d.get(self.label_key), d.pop(self.indices_key, None), d.get(self.image_key))  # type: ignore\n\n        # initialize returned list with shallow copy to preserve key ordering\n        ret: list = [dict(d) for _ in range(self.cropper.num_samples)]\n        # deep copy all the unmodified data\n        for i in range(self.cropper.num_samples):\n            for key in set(d.keys()).difference(set(self.keys)):\n                ret[i][key] = deepcopy(d[key])\n\n        lazy_ = self.lazy if lazy is None else lazy\n        for key in self.key_iterator(d):\n            for i, im in enumerate(self.cropper(d[key], randomize=False, lazy=lazy_)):\n                ret[i][key] = im\n        return ret\n\n\nclass ResizeWithPadOrCropd(Padd):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ResizeWithPadOrCrop`.\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        spatial_size: the spatial size of output data after padding or crop.\n            If has non-positive values, the corresponding size of input image will be used (no padding).\n        mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. Defaults to ``\"constant\"``.\n            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            It also can be a sequence of string, each element corresponds to a key in ``keys``.\n        allow_missing_keys: don't raise exception if key is missing.\n        method: {``\"symmetric\"``, ``\"end\"``}\n            Pad image symmetrically on every side or only pad at the end sides. Defaults to ``\"symmetric\"``.\n        lazy: a flag to indicate whether this transform should execute lazily or not. 
Defaults to False.\n        pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        spatial_size: Sequence[int] | int,\n        mode: SequenceStr = PytorchPadMode.CONSTANT,\n        allow_missing_keys: bool = False,\n        method: str = Method.SYMMETRIC,\n        lazy: bool = False,\n        **pad_kwargs,\n    ) -> None:\n        padcropper = ResizeWithPadOrCrop(spatial_size=spatial_size, method=method, **pad_kwargs, lazy=lazy)\n        super().__init__(\n            keys, padder=padcropper, mode=mode, allow_missing_keys=allow_missing_keys, lazy=lazy  # type: ignore\n        )\n\n\nclass BoundingRectd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.BoundingRect`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        bbox_key_postfix: the output bounding box coordinates will be\n            written to the value of `{key}_{bbox_key_postfix}`.\n        select_fn: function to select expected foreground, default is to select values > 0.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = BoundingRect.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        bbox_key_postfix: str = \"bbox\",\n        select_fn: Callable = is_positive,\n        allow_missing_keys: bool = False,\n    ):\n        super().__init__(keys, allow_missing_keys)\n        self.bbox = BoundingRect(select_fn=select_fn)\n        self.bbox_key_postfix = bbox_key_postfix\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        \"\"\"\n        See also: :py:class:`monai.transforms.utils.generate_spatial_bounding_box`.\n        \"\"\"\n        d = dict(data)\n        for 
key in self.key_iterator(d):\n            bbox = self.bbox(d[key])\n            key_to_add = f\"{key}_{self.bbox_key_postfix}\"\n            if key_to_add in d:\n                raise KeyError(f\"Bounding box data with key {key_to_add} already exists.\")\n            d[key_to_add] = bbox\n        return d\n\n\nPadD = PadDict = Padd\nSpatialPadD = SpatialPadDict = SpatialPadd\nBorderPadD = BorderPadDict = BorderPadd\nDivisiblePadD = DivisiblePadDict = DivisiblePadd\nCropD = CropDict = Cropd\nRandCropD = RandCropDict = RandCropd\nSpatialCropD = SpatialCropDict = SpatialCropd\nCenterSpatialCropD = CenterSpatialCropDict = CenterSpatialCropd\nCenterScaleCropD = CenterScaleCropDict = CenterScaleCropd\nRandSpatialCropD = RandSpatialCropDict = RandSpatialCropd\nRandScaleCropD = RandScaleCropDict = RandScaleCropd\nRandSpatialCropSamplesD = RandSpatialCropSamplesDict = RandSpatialCropSamplesd\nCropForegroundD = CropForegroundDict = CropForegroundd\nRandWeightedCropD = RandWeightedCropDict = RandWeightedCropd\nRandCropByPosNegLabelD = RandCropByPosNegLabelDict = RandCropByPosNegLabeld\nRandCropByLabelClassesD = RandCropByLabelClassesDict = RandCropByLabelClassesd\nResizeWithPadOrCropD = ResizeWithPadOrCropDict = ResizeWithPadOrCropd\nBoundingRectD = BoundingRectDict = BoundingRectd\n"
  },
  {
    "path": "monai/transforms/croppad/functional.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of \"functional\" transforms for spatial operations.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\n\nimport numpy as np\nimport torch\nfrom torch.nn.functional import pad as pad_pt\n\nfrom monai.config.type_definitions import NdarrayTensor\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import to_affine_nd\nfrom monai.transforms.inverse import TraceableTransform\nfrom monai.transforms.utils import convert_pad_mode, create_translate\nfrom monai.utils import PytorchPadMode, convert_to_dst_type, convert_to_numpy, convert_to_tensor, ensure_tuple\n\n__all__ = [\"pad_nd\", \"pad_func\", \"crop_func\", \"crop_or_pad_nd\"]\n\n\ndef _convert_pt_pad_mode(padding_mode):\n    \"\"\"get the most similar mode of `pad` from ``padding_mode`` of the spatial resampling.\"\"\"\n    if padding_mode is None or padding_mode in (\"zeros\", \"constant\", \"grid-constant\"):\n        return PytorchPadMode.CONSTANT\n    elif padding_mode in (\"reflection\", \"reflect\", \"mirror\", \"grid-mirror\"):\n        return PytorchPadMode.REFLECT\n    elif padding_mode in (\"wrap\", \"grid-wrap\"):\n        return PytorchPadMode.CIRCULAR\n    return PytorchPadMode.REPLICATE  # \"nearest\", \"border\", and others\n\n\ndef _np_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, 
**kwargs) -> NdarrayTensor:\n    if isinstance(img, torch.Tensor):\n        if img.is_cuda:\n            warnings.warn(f\"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.\")\n        img_np = img.detach().cpu().numpy()\n    else:\n        img_np = np.asarray(img)\n    mode = convert_pad_mode(dst=img_np, mode=mode).value\n    if mode == \"constant\" and \"value\" in kwargs:\n        kwargs[\"constant_values\"] = kwargs.pop(\"value\")\n    img_np = np.pad(img_np, pad_width, mode=mode, **kwargs)  # type: ignore\n    return convert_to_dst_type(img_np, dst=img)[0]\n\n\ndef _pt_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> NdarrayTensor:\n    img_pt = torch.as_tensor(img)\n    mode = convert_pad_mode(dst=img_pt, mode=mode).value\n    if mode == \"constant\" and \"constant_values\" in kwargs:\n        _kwargs = kwargs.copy()\n        _kwargs[\"value\"] = _kwargs.pop(\"constant_values\")\n    else:\n        _kwargs = kwargs\n    pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1]\n    # torch.pad expects `[B, C, H, W, [D]]` shape\n    img_pt = pad_pt(img_pt.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0)\n    return convert_to_dst_type(img_pt, dst=img)[0]\n\n\ndef pad_nd(\n    img: NdarrayTensor, to_pad: list[tuple[int, int]], mode: str = PytorchPadMode.CONSTANT, **kwargs\n) -> NdarrayTensor:\n    \"\"\"\n    Pad `img` for a given an amount of padding in each dimension.\n\n    `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch,\n    in which case `np.pad` will be used.\n\n    Args:\n        img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.\n        to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].\n            default to `self.to_pad`.\n        mode: available modes: (Numpy) {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, 
``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            (PyTorch) {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. Defaults to ``\"constant\"``.\n            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n    Raises:\n        ValueError: If `value` is provided when `mode` is not ``\"constant\"``.\n    \"\"\"\n    if mode != \"constant\" and \"value\" in kwargs:\n        raise ValueError(\"'value' argument is only valid when mode='constant'\")\n    if mode in {\"linear_ramp\", \"maximum\", \"mean\", \"median\", \"minimum\", \"symmetric\", \"empty\"}:\n        return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)\n    try:\n        _pad = _np_pad\n        if mode in {\"constant\", \"reflect\", \"edge\", \"replicate\", \"wrap\", \"circular\"}:\n            # Try PyTorch pad for these modes; fallback to NumPy on error.\n            _pad = _pt_pad\n        return _pad(img, pad_width=to_pad, mode=mode, **kwargs)\n    except NotImplementedError:\n        # PyTorch does not support this combination, fall back to NumPy\n        return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)\n    except (ValueError, TypeError, RuntimeError) as err:\n        # PyTorch may raise generic errors for unsupported modes/dtypes or kwargs.\n        # Since there are no stable exception types for these cases, we fall back\n        # to NumPy by matching known error message patterns.\n        if any(k in str(err) for k in (\"supported\", \"unexpected keyword\", \"implemented\", \"value\")):\n            return _np_pad(img, 
pad_width=to_pad, mode=mode, **kwargs)\n        raise ValueError(\n            f\"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device if isinstance(img, torch.Tensor) else None}\"\n        ) from err\n\n\ndef crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int, ...], mode: str, **kwargs):\n    \"\"\"\n    Crop or pad using the translation matrix and spatial size. The translation coefficients are rounded\n    to the nearest integers. For a more generic implementation, please see :py:class:`monai.transforms.SpatialResample`.\n\n    Args:\n        img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.\n        translation_mat: the translation matrix to be applied to the image. A translation matrix generated by,\n            for example, :py:func:`monai.transforms.utils.create_translate`. The translation coefficients are rounded\n            to the nearest integers.\n        spatial_size: the spatial size of the output image.\n        mode: the padding mode.\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n    \"\"\"\n    ndim = len(img.shape) - 1\n    matrix_np = np.round(to_affine_nd(ndim, convert_to_numpy(translation_mat, wrap_sequence=True).copy()))\n    matrix_np = to_affine_nd(len(spatial_size), matrix_np)\n    cc = np.asarray(np.meshgrid(*[[0.5, x - 0.5] for x in spatial_size], indexing=\"ij\"))\n    cc = cc.reshape((len(spatial_size), -1))\n    src_cc = np.floor(matrix_np @ np.concatenate((cc, np.ones_like(cc[:1]))))\n    src_start, src_end = src_cc.min(axis=1), src_cc.max(axis=1)\n    to_pad, to_crop, do_pad, do_crop = [(0, 0)], [slice(None)], False, False\n    for s, e, sp in zip(src_start, src_end, img.shape[1:]):\n        do_pad, do_crop = do_pad or s < 0 or e > sp - 1, do_crop or s > 0 or e < sp - 1\n        to_pad += [(0 if s >= 0 else int(-s), 0 if e < sp - 1 else int(e - sp + 1))]\n        to_crop += [slice(int(max(s, 0)), int(e + 1 + 
to_pad[-1][0]))]\n    if do_pad:\n        _mode = _convert_pt_pad_mode(mode)\n        img = pad_nd(img, to_pad, mode=_mode, **kwargs)\n    if do_crop:\n        img = img[tuple(to_crop)]\n    return img\n\n\ndef pad_func(\n    img: torch.Tensor,\n    to_pad: tuple[tuple[int, int]],\n    transform_info: dict,\n    mode: str = PytorchPadMode.CONSTANT,\n    lazy: bool = False,\n    **kwargs,\n) -> torch.Tensor:\n    \"\"\"\n    Functional implementation of padding a MetaTensor. This function operates eagerly or lazily according\n    to ``lazy`` (default ``False``).\n\n    `torch.nn.functional.pad` is used unless the mode or kwargs are not available in torch,\n    in which case `np.pad` will be used.\n\n    Args:\n        img: data to be transformed, assuming `img` is channel-first and padding doesn't apply to the channel dim.\n        to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].\n            note that it including channel dimension.\n        transform_info: a dictionary with the relevant information pertaining to an applied transform.\n        mode: available modes: (Numpy) {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            (PyTorch) {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. 
Defaults to ``\"constant\"``.\n            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        lazy: a flag indicating whether the operation should be performed in a lazy fashion or not.\n        transform_info: a dictionary with the relevant information pertaining to an applied transform.\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n    \"\"\"\n    extra_info = {\"padded\": to_pad, \"mode\": f\"{mode}\"}\n    img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n    spatial_rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else 3\n    do_pad = np.asarray(to_pad).any()\n    if do_pad:\n        to_pad_list = [(int(p[0]), int(p[1])) for p in to_pad]\n        if len(to_pad_list) < len(img.shape):\n            to_pad_list += [(0, 0)] * (len(img.shape) - len(to_pad_list))\n        to_shift = [-s[0] for s in to_pad_list[1:]]  # skipping the channel pad\n        xform = create_translate(spatial_rank, to_shift)\n        shape = [d + s + e for d, (s, e) in zip(img_size, to_pad_list[1:])]\n    else:\n        shape = img_size\n        xform = torch.eye(int(spatial_rank) + 1, device=torch.device(\"cpu\"), dtype=torch.float64)\n    meta_info = TraceableTransform.track_transform_meta(\n        img,\n        sp_size=shape,\n        affine=xform,\n        extra_info=extra_info,\n        orig_size=img_size,\n        transform_info=transform_info,\n        lazy=lazy,\n    )\n    out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta())\n    if lazy:\n        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info  # type: ignore\n    out = pad_nd(out, to_pad_list, mode, **kwargs) if do_pad else out\n    out = convert_to_tensor(out, 
track_meta=get_track_meta())\n    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out  # type: ignore\n\n\ndef crop_func(img: torch.Tensor, slices: tuple[slice, ...], lazy: bool, transform_info: dict) -> torch.Tensor:\n    \"\"\"\n    Functional implementation of cropping a MetaTensor. This function operates eagerly or lazily according\n    to ``lazy`` (default ``False``).\n\n    Args:\n        img: data to be transformed, assuming `img` is channel-first and cropping doesn't apply to the channel dim.\n        slices: the crop slices computed based on specified `center & size` or `start & end` or `slices`.\n        lazy: a flag indicating whether the operation should be performed in a lazy fashion or not.\n        transform_info: a dictionary with the relevant information pertaining to an applied transform.\n    \"\"\"\n    img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n    spatial_rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else 3\n    cropped = np.asarray([[s.indices(o)[0], o - s.indices(o)[1]] for s, o in zip(slices[1:], img_size)])\n    extra_info = {\"cropped\": cropped.flatten().tolist()}\n    to_shift = []\n    for i, s in enumerate(ensure_tuple(slices)[1:]):\n        if s.start is not None:\n            to_shift.append(img_size[i] + s.start if s.start < 0 else s.start)\n        else:\n            to_shift.append(0)\n    shape = [s.indices(o)[1] - s.indices(o)[0] for s, o in zip(slices[1:], img_size)]\n    meta_info = TraceableTransform.track_transform_meta(\n        img,\n        sp_size=shape,\n        affine=create_translate(spatial_rank, to_shift),\n        extra_info=extra_info,\n        orig_size=img_size,\n        transform_info=transform_info,\n        lazy=lazy,\n    )\n    out = convert_to_tensor(img.as_tensor() if isinstance(img, MetaTensor) else img, track_meta=get_track_meta())\n    if lazy:\n        return out.copy_meta_from(meta_info) if isinstance(out, 
MetaTensor) else meta_info  # type: ignore\n    out = out[slices]\n    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out  # type: ignore\n"
  },
  {
    "path": "monai/transforms/intensity/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/transforms/intensity/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of \"vanilla\" transforms for intensity adjustment.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom abc import abstractmethod\nfrom collections.abc import Callable, Iterable, Sequence\nfrom functools import partial\nfrom typing import Any\nfrom warnings import warn\n\nimport numpy as np\nimport torch\n\nfrom monai.config import DtypeLike\nfrom monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap\nfrom monai.data.utils import get_random_patch, get_valid_patch_size\nfrom monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter\nfrom monai.transforms.transform import RandomizableTransform, Transform\nfrom monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array, soft_clip\nfrom monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where\nfrom monai.utils.enums import TransformBackends\nfrom monai.utils.misc import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple\nfrom monai.utils.module import min_version, optional_import\nfrom monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor, get_equivalent_dtype\n\nskimage, _ = optional_import(\"skimage\", \"0.19.0\", 
min_version)\n\n__all__ = [\n    \"RandGaussianNoise\",\n    \"RandRicianNoise\",\n    \"ShiftIntensity\",\n    \"RandShiftIntensity\",\n    \"StdShiftIntensity\",\n    \"RandStdShiftIntensity\",\n    \"RandBiasField\",\n    \"ScaleIntensity\",\n    \"RandScaleIntensity\",\n    \"ScaleIntensityFixedMean\",\n    \"RandScaleIntensityFixedMean\",\n    \"NormalizeIntensity\",\n    \"ThresholdIntensity\",\n    \"ScaleIntensityRange\",\n    \"ClipIntensityPercentiles\",\n    \"AdjustContrast\",\n    \"RandAdjustContrast\",\n    \"ScaleIntensityRangePercentiles\",\n    \"MaskIntensity\",\n    \"DetectEnvelope\",\n    \"SavitzkyGolaySmooth\",\n    \"MedianSmooth\",\n    \"GaussianSmooth\",\n    \"RandGaussianSmooth\",\n    \"GaussianSharpen\",\n    \"RandGaussianSharpen\",\n    \"RandHistogramShift\",\n    \"GibbsNoise\",\n    \"RandGibbsNoise\",\n    \"KSpaceSpikeNoise\",\n    \"RandKSpaceSpikeNoise\",\n    \"RandCoarseTransform\",\n    \"RandCoarseDropout\",\n    \"RandCoarseShuffle\",\n    \"HistogramNormalize\",\n    \"IntensityRemap\",\n    \"RandIntensityRemap\",\n    \"ForegroundMask\",\n    \"ComputeHoVerMaps\",\n    \"UltrasoundConfidenceMapTransform\",\n]\n\n\nclass RandGaussianNoise(RandomizableTransform):\n    \"\"\"\n    Add Gaussian noise to image.\n\n    Args:\n        prob: Probability to add Gaussian noise.\n        mean: Mean or “centre” of the distribution.\n        std: Standard deviation (spread) of distribution.\n        dtype: output data type, if None, same as input image. 
defaults to float32.\n        sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        prob: float = 0.1,\n        mean: float = 0.0,\n        std: float = 0.1,\n        dtype: DtypeLike = np.float32,\n        sample_std: bool = True,\n    ) -> None:\n        RandomizableTransform.__init__(self, prob)\n        self.mean = mean\n        self.std = std\n        self.dtype = dtype\n        self.noise: np.ndarray | None = None\n        self.sample_std = sample_std\n\n    def randomize(self, img: NdarrayOrTensor, mean: float | None = None) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        std = self.R.uniform(0, self.std) if self.sample_std else self.std\n        noise = self.R.normal(self.mean if mean is None else mean, std, size=img.shape)\n        # noise is float64 array, convert to the output dtype to save memory\n        self.noise, *_ = convert_data_type(noise, dtype=self.dtype)\n\n    def __call__(self, img: NdarrayOrTensor, mean: float | None = None, randomize: bool = True) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            self.randomize(img=img, mean=self.mean if mean is None else mean)\n\n        if not self._do_transform:\n            return img\n\n        if self.noise is None:\n            raise RuntimeError(\"please call the `randomize()` function first.\")\n        img, *_ = convert_data_type(img, dtype=self.dtype)\n        noise, *_ = convert_to_dst_type(self.noise, img)\n        return img + noise\n\n\nclass RandRicianNoise(RandomizableTransform):\n    \"\"\"\n    Add Rician noise to image.\n    Rician noise in MRI is the result of performing a magnitude operation on complex\n    data with 
Gaussian noise of the same variance in both channels, as described in\n    `Noise in Magnitude Magnetic Resonance Images <https://doi.org/10.1002/cmr.a.20124>`_.\n    This transform is adapted from `DIPY <https://github.com/dipy/dipy>`_.\n    See also: `The rician distribution of noisy mri data <https://doi.org/10.1002/mrm.1910340618>`_.\n\n    Args:\n        prob: Probability to add Rician noise.\n        mean: Mean or \"centre\" of the Gaussian distributions sampled to make up\n            the Rician noise.\n        std: Standard deviation (spread) of the Gaussian distributions sampled\n            to make up the Rician noise.\n        channel_wise: If True, treats each channel of the image separately.\n        relative: If True, the spread of the sampled Gaussian distributions will\n            be std times the standard deviation of the image or channel's intensity\n            histogram.\n        sample_std: If True, sample the spread of the Gaussian distributions\n            uniformly from 0 to std.\n        dtype: output data type, if None, same as input image. 
defaults to float32.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        prob: float = 0.1,\n        mean: Sequence[float] | float = 0.0,\n        std: Sequence[float] | float = 1.0,\n        channel_wise: bool = False,\n        relative: bool = False,\n        sample_std: bool = True,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        RandomizableTransform.__init__(self, prob)\n        self.prob = prob\n        self.mean = mean\n        self.std = std\n        self.channel_wise = channel_wise\n        self.relative = relative\n        self.sample_std = sample_std\n        self.dtype = dtype\n        self._noise1: NdarrayOrTensor\n        self._noise2: NdarrayOrTensor\n\n    def _add_noise(self, img: NdarrayOrTensor, mean: float, std: float):\n        dtype_np = get_equivalent_dtype(img.dtype, np.ndarray)\n        im_shape = img.shape\n        _std = self.R.uniform(0, std) if self.sample_std else std\n        self._noise1 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np, copy=False)\n        self._noise2 = self.R.normal(mean, _std, size=im_shape).astype(dtype_np, copy=False)\n        if isinstance(img, torch.Tensor):\n            n1 = torch.tensor(self._noise1, device=img.device)\n            n2 = torch.tensor(self._noise2, device=img.device)\n            return torch.sqrt((img + n1) ** 2 + n2**2)\n\n        return np.sqrt((img + self._noise1) ** 2 + self._noise2**2)\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta(), dtype=self.dtype)\n        if randomize:\n            super().randomize(None)\n\n        if not self._do_transform:\n            return img\n\n        if self.channel_wise:\n            _mean = ensure_tuple_rep(self.mean, len(img))\n            _std = ensure_tuple_rep(self.std, 
len(img))\n            for i, d in enumerate(img):\n                img[i] = self._add_noise(d, mean=_mean[i], std=_std[i] * d.std() if self.relative else _std[i])\n        else:\n            if not isinstance(self.mean, (int, float)):\n                raise RuntimeError(f\"If channel_wise is False, mean must be a float or int, got {type(self.mean)}.\")\n            if not isinstance(self.std, (int, float)):\n                raise RuntimeError(f\"If channel_wise is False, std must be a float or int, got {type(self.std)}.\")\n            std = self.std * img.std().item() if self.relative else self.std\n            if not isinstance(std, (int, float)):\n                raise RuntimeError(f\"std must be a float or int number, got {type(std)}.\")\n            img = self._add_noise(img, mean=self.mean, std=std)\n        return img\n\n\nclass ShiftIntensity(Transform):\n    \"\"\"\n    Shift intensity uniformly for the entire image with specified `offset`.\n\n    Args:\n        offset: offset value to shift the intensity of image.\n        safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.\n            E.g., `[256, -12]` -> `[array(0), array(244)]`. 
If `True`, then `[256, -12]` -> `[array(255), array(0)]`.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, offset: float, safe: bool = False) -> None:\n        self.offset = offset\n        self.safe = safe\n\n    def __call__(self, img: NdarrayOrTensor, offset: float | None = None) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        offset = self.offset if offset is None else offset\n        out = img + offset\n        out, *_ = convert_data_type(data=out, dtype=img.dtype, safe=self.safe)\n\n        return out\n\n\nclass RandShiftIntensity(RandomizableTransform):\n    \"\"\"\n    Randomly shift intensity with randomly picked offset.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self, offsets: tuple[float, float] | float, safe: bool = False, prob: float = 0.1, channel_wise: bool = False\n    ) -> None:\n        \"\"\"\n        Args:\n            offsets: offset range to randomly shift.\n                if single number, offset value is picked from (-offsets, offsets).\n            safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.\n                E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.\n            prob: probability of shift.\n            channel_wise: if True, shift intensity on each channel separately. 
For each channel, a random offset will be chosen.\n                Please ensure that the first dimension represents the channel of the image if True.\n        \"\"\"\n        RandomizableTransform.__init__(self, prob)\n        if isinstance(offsets, (int, float)):\n            self.offsets = (min(-offsets, offsets), max(-offsets, offsets))\n        elif len(offsets) != 2:\n            raise ValueError(f\"offsets should be a number or pair of numbers, got {offsets}.\")\n        else:\n            self.offsets = (min(offsets), max(offsets))\n        self._offset = self.offsets[0]\n        self.channel_wise = channel_wise\n        self._shifter = ShiftIntensity(self._offset, safe)\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        if self.channel_wise:\n            self._offset = [self.R.uniform(low=self.offsets[0], high=self.offsets[1]) for _ in range(data.shape[0])]  # type: ignore\n        else:\n            self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1])\n\n    def __call__(self, img: NdarrayOrTensor, factor: float | None = None, randomize: bool = True) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n\n        Args:\n            img: input image to shift intensity.\n            factor: a factor to multiply the random offset, then shift.\n                can be some image specific value at runtime, like: max(img), etc.\n\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            self.randomize(img)\n\n        if not self._do_transform:\n            return img\n\n        ret: NdarrayOrTensor\n        if self.channel_wise:\n            out = []\n            for i, d in enumerate(img):\n                out_channel = self._shifter(d, self._offset[i] if factor is None else self._offset[i] * factor)  # type: ignore\n                
out.append(out_channel)\n            ret = torch.stack(out)  # type: ignore\n        else:\n            ret = self._shifter(img, self._offset if factor is None else self._offset * factor)\n        return ret\n\n\nclass StdShiftIntensity(Transform):\n    \"\"\"\n    Shift intensity for the image with a factor and the standard deviation of the image\n    by: ``v = v + factor * std(v)``.\n    This transform can focus on only non-zero values or the entire image,\n    and can also calculate the std on each channel separately.\n\n    Args:\n        factor: factor shift by ``v = v + factor * std(v)``.\n        nonzero: whether only count non-zero values.\n        channel_wise: if True, calculate on each channel separately. Please ensure\n            that the first dimension represents the channel of the image if True.\n        dtype: output data type, if None, same as input image. defaults to float32.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self, factor: float, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32\n    ) -> None:\n        self.factor = factor\n        self.nonzero = nonzero\n        self.channel_wise = channel_wise\n        self.dtype = dtype\n\n    def _stdshift(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        ones: Callable\n        std: Callable\n        if isinstance(img, torch.Tensor):\n            ones = torch.ones\n            std = partial(torch.std, unbiased=False)\n        else:\n            ones = np.ones\n            std = np.std\n\n        slices = (img != 0) if self.nonzero else ones(img.shape, dtype=bool)\n        if slices.any():\n            offset = self.factor * std(img[slices])\n            img[slices] = img[slices] + offset\n        return img\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, 
track_meta=get_track_meta(), dtype=self.dtype)\n        if self.channel_wise:\n            for i, d in enumerate(img):\n                img[i] = self._stdshift(d)  # type: ignore\n        else:\n            img = self._stdshift(img)\n        return img\n\n\nclass RandStdShiftIntensity(RandomizableTransform):\n    \"\"\"\n    Shift intensity for the image with a factor and the standard deviation of the image\n    by: ``v = v + factor * std(v)`` where the `factor` is randomly picked.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        factors: tuple[float, float] | float,\n        prob: float = 0.1,\n        nonzero: bool = False,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        \"\"\"\n        Args:\n            factors: if tuple, the randomly picked range is (min(factors), max(factors)).\n                If single number, the range is (-factors, factors).\n            prob: probability of std shift.\n            nonzero: whether only count non-zero values.\n            channel_wise: if True, calculate on each channel separately.\n            dtype: output data type, if None, same as input image. 
defaults to float32.\n\n        \"\"\"\n        RandomizableTransform.__init__(self, prob)\n        if isinstance(factors, (int, float)):\n            self.factors = (min(-factors, factors), max(-factors, factors))\n        elif len(factors) != 2:\n            raise ValueError(f\"factors should be a number or pair of numbers, got {factors}.\")\n        else:\n            self.factors = (min(factors), max(factors))\n        self.factor = self.factors[0]\n        self.nonzero = nonzero\n        self.channel_wise = channel_wise\n        self.dtype = dtype\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta(), dtype=self.dtype)\n        if randomize:\n            self.randomize()\n\n        if not self._do_transform:\n            return img\n\n        shifter = StdShiftIntensity(\n            factor=self.factor, nonzero=self.nonzero, channel_wise=self.channel_wise, dtype=self.dtype\n        )\n        return shifter(img=img)\n\n\nclass ScaleIntensity(Transform):\n    \"\"\"\n    Scale the intensity of input image to the given value range (minv, maxv).\n    If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        minv: float | None = 0.0,\n        maxv: float | None = 1.0,\n        factor: float | None = None,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        \"\"\"\n        Args:\n            minv: minimum value of output data.\n            maxv: maximum 
value of output data.\n            factor: factor scale by ``v = v * (1 + factor)``. In order to use\n                this parameter, please set both `minv` and `maxv` into None.\n            channel_wise: if True, scale on each channel separately. Please ensure\n                that the first dimension represents the channel of the image if True.\n            dtype: output data type, if None, same as input image. defaults to float32.\n        \"\"\"\n        self.minv = minv\n        self.maxv = maxv\n        self.factor = factor\n        self.channel_wise = channel_wise\n        self.dtype = dtype\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n\n        Raises:\n            ValueError: When ``self.minv=None`` or ``self.maxv=None`` and ``self.factor=None``. Incompatible values.\n\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_t = convert_to_tensor(img, track_meta=False)\n        ret: NdarrayOrTensor\n        if self.minv is not None or self.maxv is not None:\n            if self.channel_wise:\n                out = [rescale_array(d, self.minv, self.maxv, dtype=self.dtype) for d in img_t]\n                ret = torch.stack(out)  # type: ignore\n            else:\n                ret = rescale_array(img_t, self.minv, self.maxv, dtype=self.dtype)\n        else:\n            ret = (img_t * (1 + self.factor)) if self.factor is not None else img_t\n        ret = convert_to_dst_type(ret, dst=img, dtype=self.dtype or img_t.dtype)[0]\n        return ret\n\n\nclass ScaleIntensityFixedMean(Transform):\n    \"\"\"\n    Scale the intensity of input image by ``v = v * (1 + factor)``, then shift the output so that the output image has the\n    same mean as the input.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        factor: float = 0,\n        preserve_range: bool = False,\n        
fixed_mean: bool = True,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        \"\"\"\n        Args:\n            factor: factor scale by ``v = v * (1 + factor)``.\n            preserve_range: clips the output array/tensor to the range of the input array/tensor\n            fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling\n                to ensure that the output has the same mean as the input.\n            channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied\n                on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the\n                channel of the image if True.\n            dtype: output data type, if None, same as input image. defaults to float32.\n        \"\"\"\n        self.factor = factor\n        self.preserve_range = preserve_range\n        self.fixed_mean = fixed_mean\n        self.channel_wise = channel_wise\n        self.dtype = dtype\n\n    def __call__(self, img: NdarrayOrTensor, factor=None) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        Args:\n            img: the input tensor/array\n            factor: factor scale by ``v = v * (1 + factor)``\n\n        \"\"\"\n\n        factor = factor if factor is not None else self.factor\n\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_t = convert_to_tensor(img, track_meta=False)\n        ret: NdarrayOrTensor\n        if self.channel_wise:\n            out = []\n            for d in img_t:\n                if self.preserve_range:\n                    clip_min = d.min()\n                    clip_max = d.max()\n\n                if self.fixed_mean:\n                    mn = d.mean()\n                    d = d - mn\n\n                out_channel = d * (1 + factor)\n\n                if self.fixed_mean:\n                  
  out_channel = out_channel + mn\n\n                if self.preserve_range:\n                    out_channel = clip(out_channel, clip_min, clip_max)\n\n                out.append(out_channel)\n            ret = torch.stack(out)\n        else:\n            if self.preserve_range:\n                clip_min = img_t.min()\n                clip_max = img_t.max()\n\n            if self.fixed_mean:\n                mn = img_t.mean()\n                img_t = img_t - mn\n\n            ret = img_t * (1 + factor)\n\n            if self.fixed_mean:\n                ret = ret + mn\n\n            if self.preserve_range:\n                ret = clip(ret, clip_min, clip_max)\n\n        ret = convert_to_dst_type(ret, dst=img, dtype=self.dtype or img_t.dtype)[0]\n        return ret\n\n\nclass RandScaleIntensityFixedMean(RandomizableTransform):\n    \"\"\"\n    Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`\n    is randomly picked. Subtract the mean intensity before scaling with `factor`, then add the same value after scaling\n    to ensure that the output has the same mean as the input.\n    \"\"\"\n\n    backend = ScaleIntensityFixedMean.backend\n\n    def __init__(\n        self,\n        prob: float = 0.1,\n        factors: Sequence[float] | float = 0,\n        fixed_mean: bool = True,\n        preserve_range: bool = False,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        \"\"\"\n        Args:\n            factors: factor range to randomly scale by ``v = v * (1 + factor)``.\n                if single number, factor value is picked from (-factors, factors).\n            preserve_range: clips the output array/tensor to the range of the input array/tensor\n            fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling\n                to ensure that the output has the same mean as the input.\n            channel_wise: if True, scale on each channel separately. 
`preserve_range` and `fixed_mean` are also applied\n            on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the\n            channel of the image if True.\n            dtype: output data type, if None, same as input image. defaults to float32.\n\n        \"\"\"\n        RandomizableTransform.__init__(self, prob)\n        if isinstance(factors, (int, float)):\n            self.factors = (min(-factors, factors), max(-factors, factors))\n        elif len(factors) != 2:\n            raise ValueError(\"factors should be a number or pair of numbers.\")\n        else:\n            self.factors = (min(factors), max(factors))\n        self.factor = self.factors[0]\n        self.fixed_mean = fixed_mean\n        self.preserve_range = preserve_range\n        self.dtype = dtype\n\n        self.scaler = ScaleIntensityFixedMean(\n            factor=self.factor, fixed_mean=self.fixed_mean, preserve_range=self.preserve_range, dtype=self.dtype\n        )\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            self.randomize()\n\n        if not self._do_transform:\n            return convert_data_type(img, dtype=self.dtype)[0]\n\n        return self.scaler(img, self.factor)\n\n\nclass RandScaleIntensity(RandomizableTransform):\n    \"\"\"\n    Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`\n    is randomly picked.\n    \"\"\"\n\n    backend = ScaleIntensity.backend\n\n    def __init__(\n        self,\n        factors: tuple[float, float] | 
float,\n        prob: float = 0.1,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        \"\"\"\n        Args:\n            factors: factor range to randomly scale by ``v = v * (1 + factor)``.\n                if single number, factor value is picked from (-factors, factors).\n            prob: probability of scale.\n            channel_wise: if True, scale on each channel separately. Please ensure\n                that the first dimension represents the channel of the image if True.\n            dtype: output data type, if None, same as input image. defaults to float32.\n\n        \"\"\"\n        RandomizableTransform.__init__(self, prob)\n        if isinstance(factors, (int, float)):\n            self.factors = (min(-factors, factors), max(-factors, factors))\n        elif len(factors) != 2:\n            raise ValueError(f\"factors should be a number or pair of numbers, got {factors}.\")\n        else:\n            self.factors = (min(factors), max(factors))\n        self.factor = self.factors[0]\n        self.channel_wise = channel_wise\n        self.dtype = dtype\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        if self.channel_wise:\n            self.factor = [self.R.uniform(low=self.factors[0], high=self.factors[1]) for _ in range(data.shape[0])]  # type: ignore\n        else:\n            self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            self.randomize(img)\n\n        if not self._do_transform:\n            return convert_data_type(img, dtype=self.dtype)[0]\n\n        ret: NdarrayOrTensor\n        if 
self.channel_wise:\n            out = []\n            for i, d in enumerate(img):\n                out_channel = ScaleIntensity(minv=None, maxv=None, factor=self.factor[i], dtype=self.dtype)(d)  # type: ignore\n                out.append(out_channel)\n            ret = torch.stack(out)  # type: ignore\n        else:\n            ret = ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img)\n        return ret\n\n\nclass RandBiasField(RandomizableTransform):\n    \"\"\"\n    Random bias field augmentation for MR images.\n    The bias field is considered as a linear combination of smoothly varying basis (polynomial)\n    functions, as described in `Automated Model-Based Tissue Classification of MR Images of the Brain\n    <https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=811270>`_.\n    This implementation adapted from `NiftyNet\n    <https://github.com/NifTK/NiftyNet>`_.\n    Referred to `Longitudinal segmentation of age-related white matter hyperintensities\n    <https://www.sciencedirect.com/science/article/pii/S1361841517300257?via%3Dihub>`_.\n\n    Args:\n        degree: degree of freedom of the polynomials. The value should be no less than 1.\n            Defaults to 3.\n        coeff_range: range of the random coefficients. Defaults to (0.0, 0.1).\n        dtype: output data type, if None, same as input image. 
defaults to float32.\n        prob: probability to do random bias field.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        degree: int = 3,\n        coeff_range: tuple[float, float] = (0.0, 0.1),\n        dtype: DtypeLike = np.float32,\n        prob: float = 0.1,\n    ) -> None:\n        RandomizableTransform.__init__(self, prob)\n        if degree < 1:\n            raise ValueError(f\"degree should be no less than 1, got {degree}.\")\n        self.degree = degree\n        self.coeff_range = coeff_range\n        self.dtype = dtype\n\n        self._coeff = [1.0]\n\n    def _generate_random_field(self, spatial_shape: Sequence[int], degree: int, coeff: Sequence[float]):\n        \"\"\"\n        products of polynomials as bias field estimations\n        \"\"\"\n        rank = len(spatial_shape)\n        coeff_mat = np.zeros((degree + 1,) * rank)\n        coords = [np.linspace(-1.0, 1.0, dim, dtype=np.float32) for dim in spatial_shape]\n        if rank == 2:\n            coeff_mat[np.tril_indices(degree + 1)] = coeff\n            return np.polynomial.legendre.leggrid2d(coords[0], coords[1], coeff_mat)\n        if rank == 3:\n            pts: list[list[int]] = [[0, 0, 0]]\n            for i in range(degree + 1):\n                for j in range(degree + 1 - i):\n                    for k in range(degree + 1 - i - j):\n                        pts.append([i, j, k])\n            if len(pts) > 1:\n                pts = pts[1:]\n            np_pts = np.stack(pts)\n            coeff_mat[np_pts[:, 0], np_pts[:, 1], np_pts[:, 2]] = coeff\n            return np.polynomial.legendre.leggrid3d(coords[0], coords[1], coords[2], coeff_mat)\n        raise NotImplementedError(\"only supports 2D or 3D fields\")\n\n    def randomize(self, img_size: Sequence[int]) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        n_coeff = int(np.prod([(self.degree + k) / k for k in range(1, 
len(img_size) + 1)]))\n        self._coeff = self.R.uniform(*self.coeff_range, n_coeff).tolist()\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            self.randomize(img_size=img.shape[1:])\n\n        if not self._do_transform:\n            return img\n\n        num_channels, *spatial_shape = img.shape\n        _bias_fields = np.stack(\n            [\n                self._generate_random_field(spatial_shape=spatial_shape, degree=self.degree, coeff=self._coeff)\n                for _ in range(num_channels)\n            ],\n            axis=0,\n        )\n        img_np, *_ = convert_data_type(img, np.ndarray)\n        out: NdarrayOrTensor = img_np * np.exp(_bias_fields)\n        out, *_ = convert_to_dst_type(src=out, dst=img, dtype=self.dtype or img.dtype)\n        return out\n\n\nclass NormalizeIntensity(Transform):\n    \"\"\"\n    Normalize input based on the `subtrahend` and `divisor`: `(img - subtrahend) / divisor`.\n    Use calculated mean or std value of the input image if no `subtrahend` or `divisor` provided.\n    This transform can normalize only non-zero values or entire image, and can also calculate\n    mean and std on each channel separately.\n    When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should\n    be the number of image channels if they are not None.\n    If the input is not of floating point type, it will be converted to float32\n\n    Args:\n        subtrahend: the amount to subtract by (usually the mean).\n        divisor: the amount to divide by (usually the standard deviation).\n        nonzero: whether only normalize non-zero values.\n        channel_wise: if True, calculate on each channel separately, otherwise, calculate on\n            the entire image directly. 
default to False.\n        dtype: output data type, if None, same as input image. defaults to float32.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        subtrahend: Sequence | NdarrayOrTensor | None = None,\n        divisor: Sequence | NdarrayOrTensor | None = None,\n        nonzero: bool = False,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        self.subtrahend = subtrahend\n        self.divisor = divisor\n        self.nonzero = nonzero\n        self.channel_wise = channel_wise\n        self.dtype = dtype\n\n    @staticmethod\n    def _mean(x):\n        if isinstance(x, np.ndarray):\n            return np.mean(x)\n        x = torch.mean(x.float())\n        return x.item() if x.numel() == 1 else x\n\n    @staticmethod\n    def _std(x):\n        if isinstance(x, np.ndarray):\n            return np.std(x)\n        x = torch.std(x.float(), unbiased=False)\n        return x.item() if x.numel() == 1 else x\n\n    def _normalize(self, img: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTensor:\n        img, *_ = convert_data_type(img, dtype=torch.float32)\n\n        if self.nonzero:\n            slices = img != 0\n            masked_img = img[slices]\n            if not slices.any():\n                return img\n        else:\n            slices = None\n            masked_img = img\n\n        _sub = sub if sub is not None else self._mean(masked_img)\n        if isinstance(_sub, (torch.Tensor, np.ndarray)):\n            _sub, *_ = convert_to_dst_type(_sub, img)\n            if slices is not None:\n                _sub = _sub[slices]\n\n        _div = div if div is not None else self._std(masked_img)\n        if np.isscalar(_div):\n            if _div == 0.0:\n                _div = 1.0\n        elif isinstance(_div, (torch.Tensor, np.ndarray)):\n            _div, *_ = convert_to_dst_type(_div, img)\n            if slices is not None:\n           
     _div = _div[slices]\n            _div[_div == 0.0] = 1.0\n\n        if slices is not None:\n            img[slices] = (masked_img - _sub) / _div\n        else:\n            img = (img - _sub) / _div\n        return img\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True,\n        \"\"\"\n        img_t: torch.Tensor = convert_to_tensor(img, track_meta=get_track_meta())  # type: ignore[assignment]\n        dtype = self.dtype or img.dtype\n        img_len = len(img_t)\n        if self.channel_wise:\n            if self.subtrahend is not None and len(self.subtrahend) != img_len:\n                raise ValueError(f\"img has {img_len} channels, but subtrahend has {len(self.subtrahend)} components.\")\n            if self.divisor is not None and len(self.divisor) != img_len:\n                raise ValueError(f\"img has {img_len} channels, but divisor has {len(self.divisor)} components.\")\n\n            if not img_t.dtype.is_floating_point:\n                img_t, *_ = convert_data_type(img_t, dtype=torch.float32)\n\n            for i, d in enumerate(img_t):\n                img_t[i] = self._normalize(  # type: ignore\n                    d,\n                    sub=self.subtrahend[i] if self.subtrahend is not None else None,\n                    div=self.divisor[i] if self.divisor is not None else None,\n                )\n        else:\n            img_t = self._normalize(img_t, self.subtrahend, self.divisor)  # type: ignore[assignment]\n\n        out = convert_to_dst_type(img_t, img_t, dtype=dtype)[0]\n        return out\n\n\nclass ThresholdIntensity(Transform):\n    \"\"\"\n    Filter the intensity values of whole image to below threshold or above threshold.\n    And fill the remaining parts of the image to the `cval` value.\n\n    Args:\n        threshold: the threshold to filter intensity values.\n        above: filter 
values above the threshold or below the threshold, default is True.\n        cval: value to fill the remaining parts of the image, default is 0.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> None:\n        if not isinstance(threshold, (int, float)):\n            raise ValueError(f\"threshold must be a float or int number, got {type(threshold)} {threshold}.\")\n        self.threshold = threshold\n        self.above = above\n        self.cval = cval\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        mask = img > self.threshold if self.above else img < self.threshold\n        res = where(mask, img, self.cval)\n        res, *_ = convert_data_type(res, dtype=img.dtype)\n        return res\n\n\nclass ScaleIntensityRange(Transform):\n    \"\"\"\n    Apply specific intensity scaling to the whole numpy array.\n    Scaling from [a_min, a_max] to [b_min, b_max] with clip option.\n\n    When `b_min` or `b_max` are `None`, `scaled_array * (b_max - b_min) + b_min` will be skipped.\n    If `clip=True`, when `b_min`/`b_max` is None, the clipping is not performed on the corresponding edge.\n\n    Args:\n        a_min: intensity original range min.\n        a_max: intensity original range max.\n        b_min: intensity target range min.\n        b_max: intensity target range max.\n        clip: whether to perform clip after scaling.\n        dtype: output data type, if None, same as input image. 
defaults to float32.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        a_min: float,\n        a_max: float,\n        b_min: float | None = None,\n        b_max: float | None = None,\n        clip: bool = False,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        self.a_min = a_min\n        self.a_max = a_max\n        self.b_min = b_min\n        self.b_max = b_max\n        self.clip = clip\n        self.dtype = dtype\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        dtype = self.dtype or img.dtype\n        if self.a_max - self.a_min == 0.0:\n            warn(\"Divide by zero (a_min == a_max)\", Warning)\n            if self.b_min is None:\n                return img - self.a_min\n            return img - self.a_min + self.b_min\n\n        img = (img - self.a_min) / (self.a_max - self.a_min)\n        if (self.b_min is not None) and (self.b_max is not None):\n            img = img * (self.b_max - self.b_min) + self.b_min\n        if self.clip:\n            img = clip(img, self.b_min, self.b_max)\n        ret: NdarrayOrTensor = convert_data_type(img, dtype=dtype)[0]\n\n        return ret\n\n\nclass ClipIntensityPercentiles(Transform):\n    \"\"\"\n    Apply clip based on the intensity distribution of input image.\n    If `sharpness_factor` is provided, the intensity values will be soft clipped according to\n    f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))\n    From https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291\n\n    Soft clipping preserves the order of the values and maintains the gradient everywhere.\n    For example:\n\n    .. 
code-block:: python\n        :emphasize-lines: 11, 22\n\n        image = torch.Tensor(\n            [[[1, 2, 3, 4, 5],\n              [1, 2, 3, 4, 5],\n              [1, 2, 3, 4, 5],\n              [1, 2, 3, 4, 5],\n              [1, 2, 3, 4, 5],\n              [1, 2, 3, 4, 5]]])\n\n        # Hard clipping from lower and upper image intensity percentiles\n        hard_clipper = ClipIntensityPercentiles(30, 70)\n        print(hard_clipper(image))\n        metatensor([[[2., 2., 3., 4., 4.],\n                [2., 2., 3., 4., 4.],\n                [2., 2., 3., 4., 4.],\n                [2., 2., 3., 4., 4.],\n                [2., 2., 3., 4., 4.],\n                [2., 2., 3., 4., 4.]]])\n\n\n        # Soft clipping from lower and upper image intensity percentiles\n        soft_clipper = ClipIntensityPercentiles(30, 70, 10.)\n        print(soft_clipper(image))\n        metatensor([[[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],\n         [2.0000, 2.0693, 3.0000, 3.9307, 4.0000],\n         [2.0000, 2.0693, 3.0000, 3.9307, 4.0000],\n         [2.0000, 2.0693, 3.0000, 3.9307, 4.0000],\n         [2.0000, 2.0693, 3.0000, 3.9307, 4.0000],\n         [2.0000, 2.0693, 3.0000, 3.9307, 4.0000]]])\n\n    See Also:\n\n        - :py:class:`monai.transforms.ScaleIntensityRangePercentiles`\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        lower: float | None,\n        upper: float | None,\n        sharpness_factor: float | None = None,\n        channel_wise: bool = False,\n        return_clipping_values: bool = False,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        \"\"\"\n        Args:\n            lower: lower intensity percentile. In the case of hard clipping, None will have the same effect as 0 by\n                not clipping the lowest input values. 
However, in the case of soft clipping, None and zero will have\n                two different effects: None will not apply clipping to low values, whereas zero will still transform\n                the lower values according to the soft clipping transformation. Please check for more details:\n                https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291.\n            upper: upper intensity percentile.  The same as for lower, but this time with the highest values. If we\n                are looking to perform soft clipping, if None then there will be no effect on this side whereas if set\n                to 100, the values will be passed via the corresponding clipping equation.\n            sharpness_factor: if not None, the intensity values will be soft clipped according to\n                f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv)).\n                defaults to None.\n            channel_wise: if True, compute intensity percentile and normalize every channel separately.\n                default to False.\n            return_clipping_values: whether to return the calculated percentiles in tensor meta information.\n                If soft clipping and requested percentile is None, return None as the corresponding clipping\n                values in meta information. Clipping values are stored in a list with each element corresponding\n                to a channel if channel_wise is set to True. defaults to False.\n            dtype: output data type, if None, same as input image. 
defaults to float32.\n        \"\"\"\n        if lower is None and upper is None:\n            raise ValueError(\"lower or upper percentiles must be provided\")\n        if lower is not None and (lower < 0.0 or lower > 100.0):\n            raise ValueError(\"Percentiles must be in the range [0, 100]\")\n        if upper is not None and (upper < 0.0 or upper > 100.0):\n            raise ValueError(\"Percentiles must be in the range [0, 100]\")\n        if upper is not None and lower is not None and upper < lower:\n            raise ValueError(\"upper must be greater than or equal to lower\")\n        if sharpness_factor is not None and sharpness_factor <= 0:\n            raise ValueError(\"sharpness_factor must be greater than 0\")\n\n        self.lower = lower\n        self.upper = upper\n        self.sharpness_factor = sharpness_factor\n        self.channel_wise = channel_wise\n        if return_clipping_values:\n            self.clipping_values: list[tuple[float | None, float | None]] = []\n        self.return_clipping_values = return_clipping_values\n        self.dtype = dtype\n\n    def _clip(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        if self.sharpness_factor is not None:\n            lower_percentile = percentile(img, self.lower) if self.lower is not None else None\n            upper_percentile = percentile(img, self.upper) if self.upper is not None else None\n            img = soft_clip(img, self.sharpness_factor, lower_percentile, upper_percentile, self.dtype)\n        else:\n            lower_percentile = percentile(img, self.lower) if self.lower is not None else percentile(img, 0)\n            upper_percentile = percentile(img, self.upper) if self.upper is not None else percentile(img, 100)\n            img = clip(img, lower_percentile, upper_percentile)\n\n        if self.return_clipping_values:\n            self.clipping_values.append(\n                (\n                    (\n                        lower_percentile\n                      
  if lower_percentile is None\n                        else lower_percentile.item() if hasattr(lower_percentile, \"item\") else lower_percentile\n                    ),\n                    (\n                        upper_percentile\n                        if upper_percentile is None\n                        else upper_percentile.item() if hasattr(upper_percentile, \"item\") else upper_percentile\n                    ),\n                )\n            )\n        img = convert_to_tensor(img, track_meta=False)\n        return img\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_t = convert_to_tensor(img, track_meta=False)\n        if self.channel_wise:\n            img_t = torch.stack([self._clip(img=d) for d in img_t])  # type: ignore\n        else:\n            img_t = self._clip(img=img_t)\n\n        img = convert_to_dst_type(img_t, dst=img)[0]\n        if self.return_clipping_values:\n            img.meta[\"clipping_values\"] = self.clipping_values  # type: ignore\n\n        return img\n\n\nclass AdjustContrast(Transform):\n    \"\"\"\n    Changes image intensity with gamma transform. Each pixel/voxel intensity is updated as::\n\n        x = ((x - min) / intensity_range) ^ gamma * intensity_range + min\n\n    Args:\n        gamma: gamma value to adjust the contrast as function.\n        invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity\n            values with -1 before the gamma transform and again after the gamma transform. 
This behaviour is mimicked\n            from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this\n            <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_\n            function.\n        retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to\n            ensure that the output intensity distribution has the same mean and standard deviation as the intensity\n            distribution of the input. This behaviour is mimicked from `nnU-Net\n            <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this\n            <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_\n            function.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, gamma: float, invert_image: bool = False, retain_stats: bool = False) -> None:\n        if not isinstance(gamma, (int, float)):\n            raise ValueError(f\"gamma must be a float or int number, got {type(gamma)} {gamma}.\")\n        self.gamma = gamma\n        self.invert_image = invert_image\n        self.retain_stats = retain_stats\n\n    def __call__(self, img: NdarrayOrTensor, gamma=None) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        gamma: gamma value to adjust the contrast as function.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        gamma = gamma if gamma is not None else self.gamma\n\n        if self.invert_image:\n            img = -img\n\n        if self.retain_stats:\n            mn = img.mean()\n            sd = img.std()\n\n        epsilon = 1e-7\n        img_min = img.min()\n        img_range = img.max() - img_min\n        ret: NdarrayOrTensor = ((img - img_min) / 
float(img_range + epsilon)) ** gamma * img_range + img_min\n\n        if self.retain_stats:\n            # zero mean and normalize\n            ret = ret - ret.mean()\n            ret = ret / (ret.std() + 1e-8)\n            # restore old mean and standard deviation\n            ret = sd * ret + mn\n\n        if self.invert_image:\n            ret = -ret\n\n        return ret\n\n\nclass RandAdjustContrast(RandomizableTransform):\n    \"\"\"\n    Randomly changes image intensity with gamma transform. Each pixel/voxel intensity is updated as:\n\n        x = ((x - min) / intensity_range) ^ gamma * intensity_range + min\n\n    Args:\n        prob: Probability of adjustment.\n        gamma: Range of gamma values.\n            If single number, value is picked from (0.5, gamma), default is (0.5, 4.5).\n        invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity\n            values with -1 before the gamma transform and again after the gamma transform. This behaviour is mimicked\n            from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this\n            <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_\n            function.\n        retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to\n            ensure that the output intensity distribution has the same mean and standard deviation as the intensity\n            distribution of the input. 
This behaviour is mimicked from `nnU-Net\n            <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this\n            <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_\n            function.\n    \"\"\"\n\n    backend = AdjustContrast.backend\n\n    def __init__(\n        self,\n        prob: float = 0.1,\n        gamma: Sequence[float] | float = (0.5, 4.5),\n        invert_image: bool = False,\n        retain_stats: bool = False,\n    ) -> None:\n        RandomizableTransform.__init__(self, prob)\n\n        if isinstance(gamma, (int, float)):\n            if gamma <= 0.5:\n                raise ValueError(\n                    f\"if gamma is a number, must greater than 0.5 and value is picked from (0.5, gamma), got {gamma}\"\n                )\n            self.gamma = (0.5, gamma)\n        elif len(gamma) != 2:\n            raise ValueError(\"gamma should be a number or pair of numbers.\")\n        else:\n            self.gamma = (min(gamma), max(gamma))\n\n        self.gamma_value: float = 1.0\n        self.invert_image: bool = invert_image\n        self.retain_stats: bool = retain_stats\n\n        self.adjust_contrast = AdjustContrast(\n            self.gamma_value, invert_image=self.invert_image, retain_stats=self.retain_stats\n        )\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        self.gamma_value = self.R.uniform(low=self.gamma[0], high=self.gamma[1])\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            self.randomize()\n\n        if not self._do_transform:\n            return img\n\n        if 
self.gamma_value is None:\n            raise RuntimeError(\"gamma_value is not set, please call `randomize` function first.\")\n\n        return self.adjust_contrast(img, self.gamma_value)\n\n\nclass ScaleIntensityRangePercentiles(Transform):\n    \"\"\"\n    Apply range scaling to a numpy array based on the intensity distribution of the input.\n\n    By default this transform will scale from [lower_intensity_percentile, upper_intensity_percentile] to\n    `[b_min, b_max]`, where {lower,upper}_intensity_percentile are the intensity values at the corresponding\n    percentiles of ``img``.\n\n    The ``relative`` parameter can also be set to scale from [lower_intensity_percentile, upper_intensity_percentile]\n    to the lower and upper percentiles of the output range [b_min, b_max].\n\n    For example:\n\n    .. code-block:: python\n        :emphasize-lines: 11, 22\n\n        image = torch.Tensor(\n            [[[1, 2, 3, 4, 5],\n              [1, 2, 3, 4, 5],\n              [1, 2, 3, 4, 5],\n              [1, 2, 3, 4, 5],\n              [1, 2, 3, 4, 5],\n              [1, 2, 3, 4, 5]]])\n\n        # Scale from lower and upper image intensity percentiles\n        # to output range [b_min, b_max]\n        scaler = ScaleIntensityRangePercentiles(10, 90, 0, 200, False, False)\n        print(scaler(image))\n        metatensor([[[  0.,  50., 100., 150., 200.],\n             [  0.,  50., 100., 150., 200.],\n             [  0.,  50., 100., 150., 200.],\n             [  0.,  50., 100., 150., 200.],\n             [  0.,  50., 100., 150., 200.],\n             [  0.,  50., 100., 150., 200.]]])\n\n\n        # Scale from lower and upper image intensity percentiles\n        # to lower and upper percentiles of the output range [b_min, b_max]\n        rel_scaler = ScaleIntensityRangePercentiles(10, 90, 0, 200, False, True)\n        print(rel_scaler(image))\n        metatensor([[[ 20.,  60., 100., 140., 180.],\n             [ 20.,  60., 100., 140., 180.],\n             [ 20.,  60., 
100., 140., 180.],\n             [ 20.,  60., 100., 140., 180.],\n             [ 20.,  60., 100., 140., 180.],\n             [ 20.,  60., 100., 140., 180.]]])\n\n    See Also:\n\n        - :py:class:`monai.transforms.ScaleIntensityRange`\n\n    Args:\n        lower: lower intensity percentile.\n        upper: upper intensity percentile.\n        b_min: intensity target range min.\n        b_max: intensity target range max.\n        clip: whether to perform clip after scaling.\n        relative: whether to scale to the corresponding percentiles of [b_min, b_max].\n        channel_wise: if True, compute intensity percentile and normalize every channel separately.\n            default to False.\n        dtype: output data type, if None, same as input image. defaults to float32.\n    \"\"\"\n\n    backend = ScaleIntensityRange.backend\n\n    def __init__(\n        self,\n        lower: float,\n        upper: float,\n        b_min: float | None,\n        b_max: float | None,\n        clip: bool = False,\n        relative: bool = False,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        if lower < 0.0 or lower > 100.0:\n            raise ValueError(\"Percentiles must be in the range [0, 100]\")\n        if upper < 0.0 or upper > 100.0:\n            raise ValueError(\"Percentiles must be in the range [0, 100]\")\n        self.lower = lower\n        self.upper = upper\n        self.b_min = b_min\n        self.b_max = b_max\n        self.clip = clip\n        self.relative = relative\n        self.channel_wise = channel_wise\n        self.dtype = dtype\n\n    def _normalize(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        a_min: float = percentile(img, self.lower)  # type: ignore\n        a_max: float = percentile(img, self.upper)  # type: ignore\n        b_min = self.b_min\n        b_max = self.b_max\n\n        if self.relative:\n            if (self.b_min is None) or (self.b_max is None):\n                raise 
ValueError(\"If it is relative, b_min and b_max should not be None.\")\n            b_min = ((self.b_max - self.b_min) * (self.lower / 100.0)) + self.b_min\n            b_max = ((self.b_max - self.b_min) * (self.upper / 100.0)) + self.b_min\n\n        scalar = ScaleIntensityRange(\n            a_min=a_min, a_max=a_max, b_min=b_min, b_max=b_max, clip=self.clip, dtype=self.dtype\n        )\n        img = scalar(img)\n        img = convert_to_tensor(img, track_meta=False)\n        return img\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_t = convert_to_tensor(img, track_meta=False)\n        if self.channel_wise:\n            img_t = torch.stack([self._normalize(img=d) for d in img_t])  # type: ignore\n        else:\n            img_t = self._normalize(img=img_t)\n\n        return convert_to_dst_type(img_t, dst=img, dtype=self.dtype)[0]\n\n\nclass MaskIntensity(Transform):\n    \"\"\"\n    Mask the intensity values of input image with the specified mask data.\n    Mask data must have the same spatial size as the input image, and all\n    the intensity values of input image corresponding to the selected values\n    in the mask data will keep the original value, others will be set to `0`.\n\n    Args:\n        mask_data: if `mask_data` is single channel, apply to every channel\n            of input image. if multiple channels, the number of channels must\n            match the input data. the intensity values of input image corresponding\n            to the selected values in the mask data will keep the original value,\n            others will be set to `0`. 
if None, must specify the `mask_data` at runtime.\n        select_fn: function to select valid values of the `mask_data`, default is\n            to select `values > 0`.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, mask_data: NdarrayOrTensor | None = None, select_fn: Callable = is_positive) -> None:\n        self.mask_data = mask_data\n        self.select_fn = select_fn\n\n    def __call__(self, img: NdarrayOrTensor, mask_data: NdarrayOrTensor | None = None) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            mask_data: if mask data is single channel, apply to every channel\n                of input image. if multiple channels, the channel number must\n                match input data. mask_data will be converted to `bool` values\n                by `mask_data > 0` before applying transform to input image.\n\n        Raises:\n            - ValueError: When both ``mask_data`` and ``self.mask_data`` are None.\n            - ValueError: When ``mask_data`` and ``img`` channels differ and ``mask_data`` is not single channel.\n\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        mask_data = self.mask_data if mask_data is None else mask_data\n        if mask_data is None:\n            raise ValueError(\"must provide the mask_data when initializing the transform or at runtime.\")\n\n        mask_data_, *_ = convert_to_dst_type(src=mask_data, dst=img)\n\n        mask_data_ = self.select_fn(mask_data_)\n        if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]:\n            raise ValueError(\n                \"When mask_data is not single channel, mask_data channels must match img, \"\n                f\"got img channels={img.shape[0]} mask_data channels={mask_data_.shape[0]}.\"\n            )\n\n        return convert_to_dst_type(img * mask_data_, dst=img)[0]\n\n\nclass SavitzkyGolaySmooth(Transform):\n    \"\"\"\n    Smooth the input data 
along the given axis using a Savitzky-Golay filter.\n\n    Args:\n        window_length: Length of the filter window, must be a positive odd integer.\n        order: Order of the polynomial to fit to each window, must be less than ``window_length``.\n        axis: Optional axis along which to apply the filter kernel. Default 1 (first spatial dimension).\n        mode: Optional padding mode, passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``\n            or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, window_length: int, order: int, axis: int = 1, mode: str = \"zeros\"):\n        if axis < 0:\n            raise ValueError(\"axis must be zero or positive.\")\n\n        self.window_length = window_length\n        self.order = order\n        self.axis = axis\n        self.mode = mode\n        self.img_t: torch.Tensor = torch.tensor(0.0)\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: array containing input data. 
Must be real and in shape [channels, spatial1, spatial2, ...].\n\n        Returns:\n            array containing smoothed result.\n\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        self.img_t = convert_to_tensor(img, track_meta=False)\n\n        # add one to transform axis because a batch axis will be added at dimension 0\n        savgol_filter = SavitzkyGolayFilter(self.window_length, self.order, self.axis + 1, self.mode)\n        # convert to Tensor and add Batch axis expected by HilbertTransform\n        smoothed = savgol_filter(self.img_t.unsqueeze(0)).squeeze(0)\n        out, *_ = convert_to_dst_type(smoothed, dst=img)\n\n        return out\n\n\nclass DetectEnvelope(Transform):\n    \"\"\"\n    Find the envelope of the input data along the requested axis using a Hilbert transform.\n\n    Args:\n        axis: Axis along which to detect the envelope. Default 1, i.e. the first spatial dimension.\n        n: FFT size. Default img.shape[axis]. Input will be zero-padded or truncated to this size along dimension\n        ``axis``.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, axis: int = 1, n: int | None = None) -> None:\n        if axis < 0:\n            raise ValueError(\"axis must be zero or positive.\")\n\n        self.axis = axis\n        self.n = n\n\n    def __call__(self, img: NdarrayOrTensor):\n        \"\"\"\n\n        Args:\n            img: numpy.ndarray containing input data. 
Must be real and in shape [channels, spatial1, spatial2, ...].\n\n        Returns:\n            np.ndarray containing envelope of data in img along the specified axis.\n\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_t, *_ = convert_data_type(img, torch.Tensor)\n        # add one to transform axis because a batch axis will be added at dimension 0\n        hilbert_transform = HilbertTransform(self.axis + 1, self.n)\n        # convert to Tensor and add Batch axis expected by HilbertTransform\n        out = hilbert_transform(img_t.unsqueeze(0)).squeeze(0).abs()\n        out, *_ = convert_to_dst_type(src=out, dst=img)\n\n        return out\n\n\nclass MedianSmooth(Transform):\n    \"\"\"\n    Apply median filter to the input data based on specified `radius` parameter.\n    A default value `radius=1` is provided for reference.\n\n    See also: :py:func:`monai.networks.layers.median_filter`\n\n    Args:\n        radius: if a list of values, must match the count of spatial dimensions of input data,\n            and apply every value in the list to 1 spatial dimension. 
if only 1 value provided,\n            use it for all spatial dimensions.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, radius: Sequence[int] | int = 1) -> None:\n        self.radius = radius\n\n    def __call__(self, img: NdarrayTensor) -> NdarrayTensor:\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)\n        spatial_dims = img_t.ndim - 1\n        r = ensure_tuple_rep(self.radius, spatial_dims)\n        median_filter_instance = MedianFilter(r, spatial_dims=spatial_dims)\n        out_t: torch.Tensor = median_filter_instance(img_t)\n        out, *_ = convert_to_dst_type(out_t, dst=img, dtype=out_t.dtype)\n        return out\n\n\nclass GaussianSmooth(Transform):\n    \"\"\"\n    Apply Gaussian smooth to the input data based on specified `sigma` parameter.\n    A default value `sigma=1.0` is provided for reference.\n\n    Args:\n        sigma: if a list of values, must match the count of spatial dimensions of input data,\n            and apply every value in the list to 1 spatial dimension. 
if only 1 value provided,\n            use it for all spatial dimensions.\n        approx: discrete Gaussian kernel type, available options are \"erf\", \"sampled\", and \"scalespace\".\n            see also :py:meth:`monai.networks.layers.GaussianFilter`.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, sigma: Sequence[float] | float = 1.0, approx: str = \"erf\") -> None:\n        self.sigma = sigma\n        self.approx = approx\n\n    def __call__(self, img: NdarrayTensor) -> NdarrayTensor:\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)\n        sigma: Sequence[torch.Tensor] | torch.Tensor\n        if isinstance(self.sigma, Sequence):\n            sigma = [torch.as_tensor(s, device=img_t.device) for s in self.sigma]\n        else:\n            sigma = torch.as_tensor(self.sigma, device=img_t.device)\n        gaussian_filter = GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx)\n        out_t: torch.Tensor = gaussian_filter(img_t.unsqueeze(0)).squeeze(0)\n        out, *_ = convert_to_dst_type(out_t, dst=img, dtype=out_t.dtype)\n\n        return out\n\n\nclass RandGaussianSmooth(RandomizableTransform):\n    \"\"\"\n    Apply Gaussian smooth to the input data based on randomly selected `sigma` parameters.\n\n    Args:\n        sigma_x: randomly select sigma value for the first spatial dimension.\n        sigma_y: randomly select sigma value for the second spatial dimension if have.\n        sigma_z: randomly select sigma value for the third spatial dimension if have.\n        prob: probability of Gaussian smooth.\n        approx: discrete Gaussian kernel type, available options are \"erf\", \"sampled\", and \"scalespace\".\n            see also :py:meth:`monai.networks.layers.GaussianFilter`.\n\n    \"\"\"\n\n    backend = GaussianSmooth.backend\n\n    def __init__(\n        self,\n        sigma_x: tuple[float, float] = (0.25, 
1.5),\n        sigma_y: tuple[float, float] = (0.25, 1.5),\n        sigma_z: tuple[float, float] = (0.25, 1.5),\n        prob: float = 0.1,\n        approx: str = \"erf\",\n    ) -> None:\n        RandomizableTransform.__init__(self, prob)\n        self.sigma_x = sigma_x\n        self.sigma_y = sigma_y\n        self.sigma_z = sigma_z\n        self.approx = approx\n\n        self.x = self.sigma_x[0]\n        self.y = self.sigma_y[0]\n        self.z = self.sigma_z[0]\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        self.x = self.R.uniform(low=self.sigma_x[0], high=self.sigma_x[1])\n        self.y = self.R.uniform(low=self.sigma_y[0], high=self.sigma_y[1])\n        self.z = self.R.uniform(low=self.sigma_z[0], high=self.sigma_z[1])\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            self.randomize()\n\n        if not self._do_transform:\n            return img\n\n        sigma = ensure_tuple_size(vals=(self.x, self.y, self.z), dim=img.ndim - 1)\n        return GaussianSmooth(sigma=sigma, approx=self.approx)(img)\n\n\nclass GaussianSharpen(Transform):\n    \"\"\"\n    Sharpen images using the Gaussian Blur filter.\n    Referring to: http://scipy-lectures.org/advanced/image_processing/auto_examples/plot_sharpen.html.\n    The algorithm is shown as below\n\n    .. code-block:: python\n\n        blurred_f = gaussian_filter(img, sigma1)\n        filter_blurred_f = gaussian_filter(blurred_f, sigma2)\n        img = blurred_f + alpha * (blurred_f - filter_blurred_f)\n\n    A set of default values `sigma1=3.0`, `sigma2=1.0` and `alpha=30.0` is provide for reference.\n\n    Args:\n        sigma1: sigma parameter for the first gaussian kernel. 
if a list of values, must match the count\n            of spatial dimensions of input data, and apply every value in the list to 1 spatial dimension.\n            if only 1 value provided, use it for all spatial dimensions.\n        sigma2: sigma parameter for the second gaussian kernel. if a list of values, must match the count\n            of spatial dimensions of input data, and apply every value in the list to 1 spatial dimension.\n            if only 1 value provided, use it for all spatial dimensions.\n        alpha: weight parameter to compute the final result.\n        approx: discrete Gaussian kernel type, available options are \"erf\", \"sampled\", and \"scalespace\".\n            see also :py:meth:`monai.networks.layers.GaussianFilter`.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        sigma1: Sequence[float] | float = 3.0,\n        sigma2: Sequence[float] | float = 1.0,\n        alpha: float = 30.0,\n        approx: str = \"erf\",\n    ) -> None:\n        self.sigma1 = sigma1\n        self.sigma2 = sigma2\n        self.alpha = alpha\n        self.approx = approx\n\n    def __call__(self, img: NdarrayTensor) -> NdarrayTensor:\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32)\n\n        gf1, gf2 = (\n            GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device)\n            for sigma in (self.sigma1, self.sigma2)\n        )\n        blurred_f = gf1(img_t.unsqueeze(0))\n        filter_blurred_f = gf2(blurred_f)\n        out_t: torch.Tensor = (blurred_f + self.alpha * (blurred_f - filter_blurred_f)).squeeze(0)\n        out, *_ = convert_to_dst_type(out_t, dst=img, dtype=out_t.dtype)\n        return out\n\n\nclass RandGaussianSharpen(RandomizableTransform):\n    \"\"\"\n    Sharpen images using the Gaussian Blur filter based on randomly selected `sigma1`, `sigma2` and `alpha`.\n    The 
algorithm is :py:class:`monai.transforms.GaussianSharpen`.\n\n    Args:\n        sigma1_x: randomly select sigma value for the first spatial dimension of first gaussian kernel.\n        sigma1_y: randomly select sigma value for the second spatial dimension(if have) of first gaussian kernel.\n        sigma1_z: randomly select sigma value for the third spatial dimension(if have) of first gaussian kernel.\n        sigma2_x: randomly select sigma value for the first spatial dimension of second gaussian kernel.\n            if only 1 value `X` provided, it must be smaller than `sigma1_x` and randomly select from [X, sigma1_x].\n        sigma2_y: randomly select sigma value for the second spatial dimension(if have) of second gaussian kernel.\n            if only 1 value `Y` provided, it must be smaller than `sigma1_y` and randomly select from [Y, sigma1_y].\n        sigma2_z: randomly select sigma value for the third spatial dimension(if have) of second gaussian kernel.\n            if only 1 value `Z` provided, it must be smaller than `sigma1_z` and randomly select from [Z, sigma1_z].\n        alpha: randomly select weight parameter to compute the final result.\n        approx: discrete Gaussian kernel type, available options are \"erf\", \"sampled\", and \"scalespace\".\n            see also :py:meth:`monai.networks.layers.GaussianFilter`.\n        prob: probability of Gaussian sharpen.\n\n    \"\"\"\n\n    backend = GaussianSharpen.backend\n\n    def __init__(\n        self,\n        sigma1_x: tuple[float, float] = (0.5, 1.0),\n        sigma1_y: tuple[float, float] = (0.5, 1.0),\n        sigma1_z: tuple[float, float] = (0.5, 1.0),\n        sigma2_x: tuple[float, float] | float = 0.5,\n        sigma2_y: tuple[float, float] | float = 0.5,\n        sigma2_z: tuple[float, float] | float = 0.5,\n        alpha: tuple[float, float] = (10.0, 30.0),\n        approx: str = \"erf\",\n        prob: float = 0.1,\n    ) -> None:\n        RandomizableTransform.__init__(self, prob)\n 
       self.sigma1_x = sigma1_x\n        self.sigma1_y = sigma1_y\n        self.sigma1_z = sigma1_z\n        self.sigma2_x = sigma2_x\n        self.sigma2_y = sigma2_y\n        self.sigma2_z = sigma2_z\n        self.alpha = alpha\n        self.approx = approx\n        self.x1: float | None = None\n        self.y1: float | None = None\n        self.z1: float | None = None\n        self.x2: float | None = None\n        self.y2: float | None = None\n        self.z2: float | None = None\n        self.a: float | None = None\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        self.x1 = self.R.uniform(low=self.sigma1_x[0], high=self.sigma1_x[1])\n        self.y1 = self.R.uniform(low=self.sigma1_y[0], high=self.sigma1_y[1])\n        self.z1 = self.R.uniform(low=self.sigma1_z[0], high=self.sigma1_z[1])\n        sigma2_x = (self.sigma2_x, self.x1) if not isinstance(self.sigma2_x, Iterable) else self.sigma2_x\n        sigma2_y = (self.sigma2_y, self.y1) if not isinstance(self.sigma2_y, Iterable) else self.sigma2_y\n        sigma2_z = (self.sigma2_z, self.z1) if not isinstance(self.sigma2_z, Iterable) else self.sigma2_z\n        self.x2 = self.R.uniform(low=sigma2_x[0], high=sigma2_x[1])\n        self.y2 = self.R.uniform(low=sigma2_y[0], high=sigma2_y[1])\n        self.z2 = self.R.uniform(low=sigma2_z[0], high=sigma2_z[1])\n        self.a = self.R.uniform(low=self.alpha[0], high=self.alpha[1])\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            self.randomize()\n\n        if not self._do_transform:\n            return img\n\n        if self.x2 is None or self.y2 is None or self.z2 is None or self.a is None:\n            raise RuntimeError(\"please call the `randomize()` function first.\")\n        sigma1 = 
ensure_tuple_size(vals=(self.x1, self.y1, self.z1), dim=img.ndim - 1)\n        sigma2 = ensure_tuple_size(vals=(self.x2, self.y2, self.z2), dim=img.ndim - 1)\n        return GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(img)\n\n\nclass RandHistogramShift(RandomizableTransform):\n    \"\"\"\n    Apply random nonlinear transform to the image's intensity histogram.\n\n    Args:\n        num_control_points: number of control points governing the nonlinear intensity mapping.\n            a smaller number of control points allows for larger intensity shifts. if two values provided, number of\n            control points selecting from range (min_value, max_value).\n        prob: probability of histogram shift.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, num_control_points: tuple[int, int] | int = 10, prob: float = 0.1) -> None:\n        RandomizableTransform.__init__(self, prob)\n\n        if isinstance(num_control_points, int):\n            if num_control_points <= 2:\n                raise ValueError(\"num_control_points should be greater than or equal to 3\")\n            self.num_control_points = (num_control_points, num_control_points)\n        else:\n            if len(num_control_points) != 2:\n                raise ValueError(\"num_control points should be a number or a pair of numbers\")\n            if min(num_control_points) <= 2:\n                raise ValueError(\"num_control_points should be greater than or equal to 3\")\n            self.num_control_points = (min(num_control_points), max(num_control_points))\n        self.reference_control_points: NdarrayOrTensor\n        self.floating_control_points: NdarrayOrTensor\n\n    def interp(self, x: NdarrayOrTensor, xp: NdarrayOrTensor, fp: NdarrayOrTensor) -> NdarrayOrTensor:\n        ns = torch if isinstance(x, torch.Tensor) else np\n        if isinstance(x, np.ndarray):\n            # approx 2x faster than code 
below for ndarray\n            return np.interp(x, xp, fp)\n\n        m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])\n        b = fp[:-1] - (m * xp[:-1])\n\n        indices = ns.searchsorted(xp.reshape(-1), x.reshape(-1)) - 1\n        indices = ns.clip(indices, 0, len(m) - 1)\n\n        f: NdarrayOrTensor = (m[indices] * x.reshape(-1) + b[indices]).reshape(x.shape)\n        f[x < xp[0]] = fp[0]\n        f[x > xp[-1]] = fp[-1]\n        return f\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        num_control_point = self.R.randint(self.num_control_points[0], self.num_control_points[1] + 1)\n        self.reference_control_points = np.linspace(0, 1, num_control_point)\n        self.floating_control_points = np.copy(self.reference_control_points)\n        for i in range(1, num_control_point - 1):\n            self.floating_control_points[i] = self.R.uniform(\n                self.floating_control_points[i - 1], self.floating_control_points[i + 1]\n            )\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            self.randomize()\n\n        if not self._do_transform:\n            return img\n\n        if self.reference_control_points is None or self.floating_control_points is None:\n            raise RuntimeError(\"please call the `randomize()` function first.\")\n        img_t = convert_to_tensor(img, track_meta=False)\n        img_min, img_max = img_t.min(), img_t.max()\n        if img_min == img_max:\n            warn(\n                f\"The image's intensity is a single value {img_min}. 
\"\n                \"The original image is simply returned, no histogram shift is done.\"\n            )\n            return img\n        xp, *_ = convert_to_dst_type(self.reference_control_points, dst=img_t)\n        yp, *_ = convert_to_dst_type(self.floating_control_points, dst=img_t)\n        reference_control_points_scaled = xp * (img_max - img_min) + img_min\n        floating_control_points_scaled = yp * (img_max - img_min) + img_min\n        img_t = self.interp(img_t, reference_control_points_scaled, floating_control_points_scaled)\n        return convert_to_dst_type(img_t, dst=img)[0]\n\n\nclass GibbsNoise(Transform, Fourier):\n    \"\"\"\n    The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts\n    are one of the common type of type artifacts appearing in MRI scans.\n\n    The transform is applied to all the channels in the data.\n\n    For general information on Gibbs artifacts, please refer to:\n\n    `An Image-based Approach to Understanding the Physics of MR Artifacts\n    <https://pubs.rsna.org/doi/full/10.1148/rg.313105115>`_.\n\n    `The AAPM/RSNA Physics Tutorial for Residents\n    <https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949>`_\n\n    Args:\n        alpha: Parametrizes the intensity of the Gibbs noise filter applied. 
Takes\n            values in the interval [0,1] with alpha = 0 acting as the identity mapping.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, alpha: float = 0.1) -> None:\n        if alpha > 1 or alpha < 0:\n            raise ValueError(\"alpha must take values in the interval [0, 1].\")\n        self.alpha = alpha\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_t = convert_to_tensor(img, track_meta=False)\n        n_dims = len(img_t.shape[1:])\n\n        # FT\n        k = self.shift_fourier(img_t, n_dims)\n        # build and apply mask\n        k = self._apply_mask(k)\n        # map back\n        out = self.inv_shift_fourier(k, n_dims)\n        img, *_ = convert_to_dst_type(out, dst=img, dtype=out.dtype)\n\n        return img\n\n    def _apply_mask(self, k: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"Builds and applies a mask on the spatial dimensions.\n\n        Args:\n            k: k-space version of the image.\n        Returns:\n            masked version of the k-space image.\n        \"\"\"\n        shape = k.shape[1:]\n\n        # compute masking radius and center\n        r = (1 - self.alpha) * np.max(shape) * np.sqrt(2) / 2.0\n        center = (np.array(shape) - 1) / 2\n\n        # gives list w/ len==self.dim. 
Each dim gives coordinate in that dimension\n        coords = np.ogrid[tuple(slice(0, i) for i in shape)]\n\n        # need to subtract center coord and then square for Euc distance\n        coords_from_center_sq = [(coord - c) ** 2 for coord, c in zip(coords, center)]\n        dist_from_center = np.sqrt(sum(coords_from_center_sq))\n        mask = dist_from_center <= r\n\n        # add channel dimension into mask\n        mask = np.repeat(mask[None], k.shape[0], axis=0)\n\n        if isinstance(k, torch.Tensor):\n            mask, *_ = convert_data_type(mask, torch.Tensor, device=k.device)\n\n        # apply binary mask\n        k_masked: NdarrayOrTensor\n        k_masked = k * mask\n        return k_masked\n\n\nclass RandGibbsNoise(RandomizableTransform):\n    \"\"\"\n    Naturalistic image augmentation via Gibbs artifacts. The transform\n    randomly applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts\n    are one of the common types of artifacts appearing in MRI scans.\n\n    The transform is applied to all the channels in the data.\n\n    For general information on Gibbs artifacts, please refer to:\n    https://pubs.rsna.org/doi/full/10.1148/rg.313105115\n    https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949\n\n\n    Args:\n        prob (float): probability of applying the transform.\n        alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes\n            values in the interval [0,1] with alpha = 0 acting as the identity mapping.\n            If a length-2 list is given as [a,b] then the value of alpha will be\n            sampled uniformly from the interval [a,b]. 
0 <= a <= b <= 1.\n            If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha].\n    \"\"\"\n\n    backend = GibbsNoise.backend\n\n    def __init__(self, prob: float = 0.1, alpha: float | Sequence[float] = (0.0, 1.0)) -> None:\n        if isinstance(alpha, float):\n            alpha = (0, alpha)\n        alpha = ensure_tuple(alpha)\n        if len(alpha) != 2:\n            raise ValueError(\"alpha length must be 2.\")\n        if alpha[1] > 1 or alpha[0] < 0:\n            raise ValueError(\"alpha must take values in the interval [0, 1]\")\n        if alpha[0] > alpha[1]:\n            raise ValueError(\"When alpha = [a,b] we need a < b.\")\n\n        self.alpha = alpha\n        self.sampled_alpha = -1.0  # stores last alpha sampled by randomize()\n\n        RandomizableTransform.__init__(self, prob=prob)\n\n    def randomize(self, data: Any) -> None:\n        \"\"\"\n        (1) Set random variable to apply the transform.\n        (2) Get alpha from uniform distribution.\n        \"\"\"\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1])\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True):\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            # randomize application and possibly alpha\n            self.randomize(None)\n\n        if not self._do_transform:\n            return img\n\n        return GibbsNoise(self.sampled_alpha)(img)\n\n\nclass KSpaceSpikeNoise(Transform, Fourier):\n    \"\"\"\n    Apply localized spikes in `k`-space at the given locations and intensities.\n    Spike (Herringbone) artifact is a type of data acquisition artifact which\n    may occur during MRI scans.\n\n    For general information on spike artifacts, please refer to:\n\n    `AAPM/RSNA physics tutorial for residents: fundamental physics of MR imaging\n  
  <https://pubmed.ncbi.nlm.nih.gov/16009826>`_.\n\n    `Body MRI artifacts in clinical practice: A physicist's and radiologist's\n    perspective <https://doi.org/10.1002/jmri.24288>`_.\n\n    Args:\n        loc: spatial location for the spikes. For\n            images with 3D spatial dimensions, the user can provide (C, X, Y, Z)\n            to fix which channel C is affected, or (X, Y, Z) to place the same\n            spike in all channels. For 2D cases, the user can provide (C, X, Y)\n            or (X, Y).\n        k_intensity: value for the log-intensity of the\n            `k`-space version of the image. If one location is passed to ``loc`` or the\n            channel is not specified, then this argument should receive a float. If\n            ``loc`` is given a sequence of locations, then this argument should\n            receive a sequence of intensities. This value should be tested as it is\n            data-dependent. The default values are the 2.5 the mean of the\n            log-intensity for each channel.\n\n    Example:\n        When working with 4D data, ``KSpaceSpikeNoise(loc = ((3,60,64,32), (64,60,32)), k_intensity = (13,14))``\n        will place a spike at `[3, 60, 64, 32]` with `log-intensity = 13`, and\n        one spike per channel located respectively at `[: , 64, 60, 32]`\n        with `log-intensity = 14`.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, loc: tuple | Sequence[tuple], k_intensity: Sequence[float] | float | None = None):\n        self.loc = ensure_tuple(loc)\n        self.k_intensity = k_intensity\n\n        # assert one-to-one relationship between factors and locations\n        if isinstance(k_intensity, Sequence):\n            if not isinstance(loc[0], Sequence):\n                raise ValueError(\n                    \"If a sequence is passed to k_intensity, then a sequence of locations must be passed to loc\"\n                )\n            if len(k_intensity) != 
len(loc):\n                raise ValueError(\"There must be one intensity_factor value for each tuple of indices in loc.\")\n        if isinstance(self.loc[0], Sequence) and k_intensity is not None and not isinstance(self.k_intensity, Sequence):\n            raise ValueError(\"There must be one intensity_factor value for each tuple of indices in loc.\")\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: image with dimensions (C, H, W) or (C, H, W, D)\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        # checking that tuples in loc are consistent with img size\n        self._check_indices(img)\n\n        if len(img.shape) < 3:\n            raise RuntimeError(\"Image needs a channel direction.\")\n        if isinstance(self.loc[0], int) and len(img.shape) == 4 and len(self.loc) == 2:\n            raise RuntimeError(\"Input images of dimension 4 need location tuple to be length 3 or 4\")\n        if isinstance(self.loc[0], Sequence) and len(img.shape) == 4 and min(map(len, self.loc)) == 2:\n            raise RuntimeError(\"Input images of dimension 4 need location tuple to be length 3 or 4\")\n\n        n_dims = len(img.shape[1:])\n\n        # FT\n        k = self.shift_fourier(img, n_dims)\n        lib = np if isinstance(k, np.ndarray) else torch\n        log_abs = lib.log(lib.abs(k) + 1e-10)\n        phase = lib.angle(k)\n\n        k_intensity = self.k_intensity\n        # default log intensity\n        if k_intensity is None:\n            k_intensity = tuple(lib.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5)\n\n        # highlight\n        if isinstance(self.loc[0], Sequence):\n            for idx, val in zip(self.loc, ensure_tuple(k_intensity)):\n                self._set_spike(log_abs, idx, val)\n        else:\n            self._set_spike(log_abs, self.loc, k_intensity)\n        # map back\n        k = lib.exp(log_abs) * lib.exp(1j * phase)\n        img, 
*_ = convert_to_dst_type(self.inv_shift_fourier(k, n_dims), dst=img)\n\n        return img\n\n    def _check_indices(self, img) -> None:\n        \"\"\"Helper method to check consistency of self.loc and input image.\n\n        Raises ValueError if any index in loc is out of bounds.\"\"\"\n\n        loc = list(self.loc)\n        if not isinstance(loc[0], Sequence):\n            loc = [loc]\n        for i in range(len(loc)):\n            if len(loc[i]) < len(img.shape):\n                loc[i] = [0] + list(loc[i])\n\n        for i in range(len(img.shape)):\n            if img.shape[i] <= max(x[i] for x in loc):\n                raise ValueError(\n                    f\"The index value at position {i} of one of the tuples in loc = {self.loc} is out of bounds for current image.\"\n                )\n\n    def _set_spike(self, k: NdarrayOrTensor, idx: tuple, val: Sequence[float] | float):\n        \"\"\"\n        Helper function to introduce a given intensity at given location.\n\n        Args:\n            k: intensity array to alter.\n            idx: index of location where to apply change.\n            val: value of intensity to write in.\n        \"\"\"\n        if len(k.shape) == len(idx):\n            k[idx] = val[idx[0]] if isinstance(val, Sequence) else val\n        elif len(k.shape) == 4 and len(idx) == 3:\n            k[:, idx[0], idx[1], idx[2]] = val  # type: ignore\n        elif len(k.shape) == 3 and len(idx) == 2:\n            k[:, idx[0], idx[1]] = val  # type: ignore\n\n\nclass RandKSpaceSpikeNoise(RandomizableTransform, Fourier):\n    \"\"\"\n    Naturalistic data augmentation via spike artifacts. The transform applies\n    localized spikes in `k`-space, and it is the random version of\n    :py:class:`monai.transforms.KSpaceSpikeNoise`.\n\n    Spike (Herringbone) artifact is a type of data acquisition artifact which\n    may occur during MRI scans. 
For general information on spike artifacts,\n    please refer to:\n\n    `AAPM/RSNA physics tutorial for residents: fundamental physics of MR imaging\n    <https://pubmed.ncbi.nlm.nih.gov/16009826>`_.\n\n    `Body MRI artifacts in clinical practice: A physicist's and radiologist's\n    perspective <https://doi.org/10.1002/jmri.24288>`_.\n\n    Args:\n        prob: probability of applying the transform, either on all\n            channels at once, or channel-wise if ``channel_wise = True``.\n        intensity_range: pass a tuple (a, b) to sample the log-intensity from the interval (a, b)\n            uniformly for all channels. Or pass sequence of intervals\n            ((a0, b0), (a1, b1), ...) to sample for each respective channel.\n            In the second case, the number of 2-tuples must match the number of channels.\n            The default range is `(0.95x, 1.10x)` where `x` is 2.5 times the mean\n            log-intensity for each channel.\n        channel_wise: treat each channel independently. 
True by\n            default.\n\n    Example:\n        To apply `k`-space spikes randomly with probability 0.5, and\n        log-intensity sampled from the interval [11, 12] for each channel\n        independently, one uses\n        ``RandKSpaceSpikeNoise(prob=0.5, intensity_range=(11, 12), channel_wise=True)``\n    \"\"\"\n\n    backend = KSpaceSpikeNoise.backend\n\n    def __init__(\n        self,\n        prob: float = 0.1,\n        intensity_range: Sequence[Sequence[float] | float] | None = None,\n        channel_wise: bool = True,\n    ):\n        self.intensity_range = intensity_range\n        self.channel_wise = channel_wise\n        self.sampled_k_intensity: list = []\n        self.sampled_locs: list[tuple] = []\n\n        if intensity_range is not None and isinstance(intensity_range[0], Sequence) and not channel_wise:\n            raise ValueError(\"When channel_wise = False, intensity_range should be a 2-tuple (low, high) or None.\")\n\n        super().__init__(prob)\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True):\n        \"\"\"\n        Apply transform to `img`. 
Assumes data is in channel-first form.\n\n        Args:\n            img: image with dimensions (C, H, W) or (C, H, W, D)\n        \"\"\"\n\n        if (\n            self.intensity_range is not None\n            and isinstance(self.intensity_range[0], Sequence)\n            and len(self.intensity_range) != img.shape[0]\n        ):\n            raise RuntimeError(\n                \"If intensity_range is a sequence of sequences, then there must be one (low, high) tuple for each channel.\"\n            )\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        self.sampled_k_intensity = []\n        self.sampled_locs = []\n\n        if randomize:\n            intensity_range = self._make_sequence(img)\n            self.randomize(img, intensity_range)\n\n        if not self._do_transform:\n            return img\n\n        return KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity)(img)\n\n    def randomize(self, img: NdarrayOrTensor, intensity_range: Sequence[Sequence[float]]) -> None:  # type: ignore\n        \"\"\"\n        Helper method to sample both the location and intensity of the spikes.\n        When not working channel wise (channel_wise=False) it use the random\n        variable ``self._do_transform`` to decide whether to sample a location\n        and intensity.\n\n        When working channel wise, the method randomly samples a location and\n        intensity for each channel depending on ``self._do_transform``.\n        \"\"\"\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        if self.channel_wise:\n            # randomizing per channel\n            for i, chan in enumerate(img):\n                self.sampled_locs.append((i,) + tuple(self.R.randint(0, k) for k in chan.shape))\n                self.sampled_k_intensity.append(self.R.uniform(intensity_range[i][0], intensity_range[i][1]))\n        else:\n            # working with all channels together\n            spatial = 
tuple(self.R.randint(0, k) for k in img.shape[1:])\n            self.sampled_locs = [(i,) + spatial for i in range(img.shape[0])]\n            if isinstance(intensity_range[0], Sequence):\n                self.sampled_k_intensity = [self.R.uniform(p[0], p[1]) for p in intensity_range]\n            else:\n                self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img)\n\n    def _make_sequence(self, x: NdarrayOrTensor) -> Sequence[Sequence[float]]:\n        \"\"\"\n        Formats the sequence of intensities ranges to Sequence[Sequence[float]].\n        \"\"\"\n        if self.intensity_range is None:\n            # set default range if one not provided\n            return self._set_default_range(x)\n\n        if not isinstance(self.intensity_range[0], Sequence):\n            return (ensure_tuple(self.intensity_range),) * x.shape[0]\n        return ensure_tuple(self.intensity_range)\n\n    def _set_default_range(self, img: NdarrayOrTensor) -> Sequence[Sequence[float]]:\n        \"\"\"\n        Sets default intensity ranges to be sampled.\n\n        Args:\n            img: image to transform.\n        \"\"\"\n        n_dims = len(img.shape[1:])\n\n        k = self.shift_fourier(img, n_dims)\n        mod = torch if isinstance(k, torch.Tensor) else np\n        log_abs = mod.log(mod.absolute(k) + 1e-10)\n        shifted_means = mod.mean(log_abs, tuple(range(-n_dims, 0))) * 2.5\n        if isinstance(shifted_means, torch.Tensor):\n            shifted_means = shifted_means.to(\"cpu\")\n        return tuple((i * 0.95, i * 1.1) for i in shifted_means)\n\n\nclass RandCoarseTransform(RandomizableTransform):\n    \"\"\"\n    Randomly select coarse regions in the image, then execute transform operations for the regions.\n    It's the base class of all kinds of region transforms.\n    Refer to papers: https://arxiv.org/abs/1708.04552\n\n    Args:\n        holes: number of regions to dropout, if `max_holes` is not None, use this arg 
as the minimum number to\n            randomly select the expected number of regions.\n        spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg\n            as the minimum spatial size to randomly select size for every region.\n            if some components of the `spatial_size` are non-positive values, the transform will use the\n            corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        max_holes: if not None, define the maximum number to randomly select the expected number of regions.\n        max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.\n            if some components of the `max_spatial_size` are non-positive values, the transform will use the\n            corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        prob: probability of applying the transform.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        holes: int,\n        spatial_size: Sequence[int] | int,\n        max_holes: int | None = None,\n        max_spatial_size: Sequence[int] | int | None = None,\n        prob: float = 0.1,\n    ) -> None:\n        RandomizableTransform.__init__(self, prob)\n        if holes < 1:\n            raise ValueError(\"number of holes must be greater than 0.\")\n        self.holes = holes\n        self.spatial_size = spatial_size\n        self.max_holes = max_holes\n        self.max_spatial_size = max_spatial_size\n        self.hole_coords: list = []\n\n    def randomize(self, img_size: Sequence[int]) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        size = 
fall_back_tuple(self.spatial_size, img_size)\n        self.hole_coords = []  # clear previously computed coords\n        num_holes = self.holes if self.max_holes is None else self.R.randint(self.holes, self.max_holes + 1)\n        for _ in range(num_holes):\n            if self.max_spatial_size is not None:\n                max_size = fall_back_tuple(self.max_spatial_size, img_size)\n                size = tuple(self.R.randint(low=size[i], high=max_size[i] + 1) for i in range(len(img_size)))\n            valid_size = get_valid_patch_size(img_size, size)\n            self.hole_coords.append((slice(None),) + get_random_patch(img_size, valid_size, self.R))\n\n    @abstractmethod\n    def _transform_holes(self, img: np.ndarray) -> np.ndarray:\n        \"\"\"\n        Transform the randomly selected `self.hole_coords` in input images.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            self.randomize(img.shape[1:])\n\n        if not self._do_transform:\n            return img\n\n        img_np, *_ = convert_data_type(img, np.ndarray)\n        out = self._transform_holes(img=img_np)\n        ret, *_ = convert_to_dst_type(src=out, dst=img)\n        return ret\n\n\nclass RandCoarseDropout(RandCoarseTransform):\n    \"\"\"\n    Randomly coarse dropout regions in the image, then fill in the rectangular regions with specified value.\n    Or keep the rectangular regions and fill in the other areas with specified value.\n    Refer to papers: https://arxiv.org/abs/1708.04552, https://arxiv.org/pdf/1604.07379\n    And other implementation: https://albumentations.ai/docs/api_reference/augmentations/transforms/\n    #albumentations.augmentations.transforms.CoarseDropout.\n\n    Args:\n        holes: number of regions to 
dropout, if `max_holes` is not None, use this arg as the minimum number to\n            randomly select the expected number of regions.\n        spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg\n            as the minimum spatial size to randomly select size for every region.\n            if some components of the `spatial_size` are non-positive values, the transform will use the\n            corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        dropout_holes: if `True`, dropout the regions of holes and fill value, if `False`, keep the holes and\n            dropout the outside and fill value. default to `True`.\n        fill_value: target value to fill the dropout regions, if providing a number, will use it as constant\n            value to fill all the regions. if providing a tuple for the `min` and `max`, will randomly select\n            value for every pixel / voxel from the range `[min, max)`. if None, will compute the `min` and `max`\n            value of input image then randomly select value to fill, default to None.\n        max_holes: if not None, define the maximum number to randomly select the expected number of regions.\n        max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.\n            if some components of the `max_spatial_size` are non-positive values, the transform will use the\n            corresponding components of input img size. 
For example, `max_spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        prob: probability of applying the transform.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        holes: int,\n        spatial_size: Sequence[int] | int,\n        dropout_holes: bool = True,\n        fill_value: tuple[float, float] | float | None = None,\n        max_holes: int | None = None,\n        max_spatial_size: Sequence[int] | int | None = None,\n        prob: float = 0.1,\n    ) -> None:\n        super().__init__(\n            holes=holes, spatial_size=spatial_size, max_holes=max_holes, max_spatial_size=max_spatial_size, prob=prob\n        )\n        self.dropout_holes = dropout_holes\n        if isinstance(fill_value, (tuple, list)):\n            if len(fill_value) != 2:\n                raise ValueError(\"fill value should contain 2 numbers if providing the `min` and `max`.\")\n        self.fill_value = fill_value\n\n    def _transform_holes(self, img: np.ndarray):\n        \"\"\"\n        Fill the randomly selected `self.hole_coords` in input images.\n        Please note that we usually only use `self.R` in `randomize()` method, here is a special case.\n\n        \"\"\"\n        fill_value = (img.min(), img.max()) if self.fill_value is None else self.fill_value\n\n        if self.dropout_holes:\n            for h in self.hole_coords:\n                if isinstance(fill_value, (tuple, list)):\n                    img[h] = self.R.uniform(fill_value[0], fill_value[1], size=img[h].shape)\n                else:\n                    img[h] = fill_value\n            ret = img\n        else:\n            if isinstance(fill_value, (tuple, list)):\n                ret = self.R.uniform(fill_value[0], fill_value[1], size=img.shape).astype(img.dtype, copy=False)\n            else:\n                ret = np.full_like(img, fill_value)\n            for h in self.hole_coords:\n                ret[h] = img[h]\n        
return ret\n\n\nclass RandCoarseShuffle(RandCoarseTransform):\n    \"\"\"\n    Randomly select regions in the image, then shuffle the pixels within every region.\n    It shuffles every channel separately.\n    Refer to paper:\n    Kang, Guoliang, et al. \"Patchshuffle regularization.\" arXiv preprint arXiv:1707.07103 (2017).\n    https://arxiv.org/abs/1707.07103\n\n    Args:\n        holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to\n            randomly select the expected number of regions.\n        spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg\n            as the minimum spatial size to randomly select size for every region.\n            if some components of the `spatial_size` are non-positive values, the transform will use the\n            corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        max_holes: if not None, define the maximum number to randomly select the expected number of regions.\n        max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.\n            if some components of the `max_spatial_size` are non-positive values, the transform will use the\n            corresponding components of input img size. 
For example, `max_spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        prob: probability of applying the transform.\n\n    \"\"\"\n\n    def _transform_holes(self, img: np.ndarray):\n        \"\"\"\n        Shuffle the content of randomly selected `self.hole_coords` in input images.\n        Please note that we usually only use `self.R` in `randomize()` method, here is a special case.\n\n        \"\"\"\n        for h in self.hole_coords:\n            # shuffle every channel separately\n            for i, c in enumerate(img[h]):\n                patch_channel = c.flatten()\n                self.R.shuffle(patch_channel)\n                img[h][i] = patch_channel.reshape(c.shape)\n        return img\n\n\nclass HistogramNormalize(Transform):\n    \"\"\"\n    Apply the histogram normalization to input image.\n    Refer to: https://github.com/facebookresearch/CovidPrognosis/blob/master/covidprognosis/data/transforms.py#L83.\n\n    Args:\n        num_bins: number of the bins to use in histogram, default to `256`. for more details:\n            https://numpy.org/doc/stable/reference/generated/numpy.histogram.html.\n        min: the min value to normalize input image, default to `0`.\n        max: the max value to normalize input image, default to `255`.\n        mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`.\n            only points at which `mask==True` are used for the equalization.\n            can also provide the mask along with img at runtime.\n        dtype: data type of the output, if None, same as input image. 
default to `float32`.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        num_bins: int = 256,\n        min: int = 0,\n        max: int = 255,\n        mask: NdarrayOrTensor | None = None,\n        dtype: DtypeLike = np.float32,\n    ) -> None:\n        self.num_bins = num_bins\n        self.min = min\n        self.max = max\n        self.mask = mask\n        self.dtype = dtype\n\n    def __call__(self, img: NdarrayOrTensor, mask: NdarrayOrTensor | None = None) -> NdarrayOrTensor:\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_np, *_ = convert_data_type(img, np.ndarray)\n        mask = mask if mask is not None else self.mask\n        mask_np: np.ndarray | None = None\n        if mask is not None:\n            mask_np, *_ = convert_data_type(mask, np.ndarray)\n\n        ret = equalize_hist(img=img_np, mask=mask_np, num_bins=self.num_bins, min=self.min, max=self.max)\n        out, *_ = convert_to_dst_type(src=ret, dst=img, dtype=self.dtype or img.dtype)\n\n        return out\n\n\nclass IntensityRemap(RandomizableTransform):\n    \"\"\"\n    Transform for intensity remapping of images. The intensity at each\n    pixel is replaced by a new value coming from an intensity remapping\n    curve.\n\n    The remapping curve is created by uniformly sampling values from the\n    possible intensities for the input image and then adding a linear\n    component. The curve is then rescaled to the input image intensity range.\n\n    Intended to be used as a means to data augmentation via:\n    :py:class:`monai.transforms.RandIntensityRemap`.\n\n    Implementation is described in the work:\n    `Intensity augmentation for domain transfer of whole breast segmentation\n    in MRI <https://ieeexplore.ieee.org/abstract/document/9166708>`_.\n\n    Args:\n        kernel_size: window size for averaging operation for the remapping\n            curve.\n        slope: slope of the linear component. 
Easiest to leave default value\n            and tune the kernel_size parameter instead.\n    \"\"\"\n\n    def __init__(self, kernel_size: int = 30, slope: float = 0.7):\n        super().__init__()\n\n        self.kernel_size = kernel_size\n        self.slope = slope\n\n    def __call__(self, img: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: image to remap.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_ = convert_to_tensor(img, track_meta=False)\n        # sample noise\n        vals_to_sample = torch.unique(img_).tolist()\n        noise = torch.from_numpy(self.R.choice(vals_to_sample, len(vals_to_sample) - 1 + self.kernel_size))\n        # smooth\n        noise = torch.nn.AvgPool1d(self.kernel_size, stride=1)(noise.unsqueeze(0)).squeeze()\n        # add linear component\n        grid = torch.arange(len(noise)) / len(noise)\n        noise += self.slope * grid\n        # rescale\n        noise = (noise - noise.min()) / (noise.max() - noise.min()) * img_.max() + img_.min()\n\n        # intensity remapping function\n        index_img = torch.bucketize(img_, torch.tensor(vals_to_sample))\n        img, *_ = convert_to_dst_type(noise[index_img], dst=img)\n\n        return img\n\n\nclass RandIntensityRemap(RandomizableTransform):\n    \"\"\"\n    Transform for intensity remapping of images. The intensity at each\n    pixel is replaced by a new value coming from an intensity remapping\n    curve.\n\n    The remapping curve is created by uniformly sampling values from the\n    possible intensities for the input image and then adding a linear\n    component. 
The curve is the rescaled to the input image intensity range.\n\n    Implementation is described in the work:\n    `Intensity augmentation for domain transfer of whole breast segmentation\n    in MRI <https://ieeexplore.ieee.org/abstract/document/9166708>`_.\n\n    Args:\n        prob: probability of applying the transform.\n        kernel_size: window size for averaging operation for the remapping\n            curve.\n        slope: slope of the linear component. Easiest to leave default value\n            and tune the kernel_size parameter instead.\n        channel_wise: set to True to treat each channel independently.\n    \"\"\"\n\n    def __init__(self, prob: float = 0.1, kernel_size: int = 30, slope: float = 0.7, channel_wise: bool = True):\n        RandomizableTransform.__init__(self, prob=prob)\n        self.kernel_size = kernel_size\n        self.slope = slope\n        self.channel_wise = channel_wise\n\n    def __call__(self, img: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: image to remap.\n        \"\"\"\n        super().randomize(None)\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if self._do_transform:\n            if self.channel_wise:\n                img = torch.stack(\n                    [\n                        IntensityRemap(self.kernel_size, self.R.choice([-self.slope, self.slope]))(img[i])\n                        for i in range(len(img))\n                    ]\n                )\n            else:\n                img = IntensityRemap(self.kernel_size, self.R.choice([-self.slope, self.slope]))(img)\n\n        return img\n\n\nclass ForegroundMask(Transform):\n    \"\"\"\n    Creates a binary mask that defines the foreground based on thresholds in RGB or HSV color space.\n    This transform receives an RGB (or grayscale) image where by default it is assumed that the foreground has\n    low values (dark) while the background has high values (white). 
Otherwise, set `invert` argument to `True`.\n\n    Args:\n        threshold: an int or a float number that defines the threshold that values less than that are foreground.\n            It also can be a callable that receives each dimension of the image and calculate the threshold,\n            or a string that defines such callable from `skimage.filter.threshold_...`. For the list of available\n            threshold functions, please refer to https://scikit-image.org/docs/stable/api/skimage.filters.html\n            Moreover, a dictionary can be passed that defines such thresholds for each channel, like\n            {\"R\": 100, \"G\": \"otsu\", \"B\": skimage.filter.threshold_mean}\n        hsv_threshold: similar to threshold but HSV color space (\"H\", \"S\", and \"V\").\n            Unlike RBG, in HSV, value greater than `hsv_threshold` are considered foreground.\n        invert: invert the intensity range of the input image, so that the dtype maximum is now the dtype minimum,\n            and vice-versa.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        threshold: dict | Callable | str | float | int = \"otsu\",\n        hsv_threshold: dict | Callable | str | float | int | None = None,\n        invert: bool = False,\n    ) -> None:\n        self.thresholds: dict[str, Callable | float] = {}\n        if threshold is not None:\n            if isinstance(threshold, dict):\n                for mode, th in threshold.items():\n                    self._set_threshold(th, mode.upper())\n            else:\n                self._set_threshold(threshold, \"R\")\n                self._set_threshold(threshold, \"G\")\n                self._set_threshold(threshold, \"B\")\n        if hsv_threshold is not None:\n            if isinstance(hsv_threshold, dict):\n                for mode, th in hsv_threshold.items():\n                    self._set_threshold(th, mode.upper())\n            else:\n           
     self._set_threshold(hsv_threshold, \"H\")\n                self._set_threshold(hsv_threshold, \"S\")\n                self._set_threshold(hsv_threshold, \"V\")\n\n        self.thresholds = {k: v for k, v in self.thresholds.items() if v is not None}\n        if self.thresholds.keys().isdisjoint(set(\"RGBHSV\")):\n            raise ValueError(\n                f\"Threshold for at least one channel of RGB or HSV needs to be set. {self.thresholds} is provided.\"\n            )\n        self.invert = invert\n\n    def _set_threshold(self, threshold, mode):\n        if callable(threshold):\n            self.thresholds[mode] = threshold\n        elif isinstance(threshold, str):\n            self.thresholds[mode] = getattr(skimage.filters, \"threshold_\" + threshold.lower())\n        elif isinstance(threshold, (float, int)):\n            self.thresholds[mode] = float(threshold)\n        else:\n            raise ValueError(\n                f\"`threshold` should be either a callable, string, or float number, {type(threshold)} was given.\"\n            )\n\n    def _get_threshold(self, image, mode):\n        threshold = self.thresholds.get(mode)\n        if callable(threshold):\n            return threshold(image)\n        return threshold\n\n    def __call__(self, image: NdarrayOrTensor):\n        image = convert_to_tensor(image, track_meta=get_track_meta())\n        img_rgb, *_ = convert_data_type(image, np.ndarray)\n        if self.invert:\n            img_rgb = skimage.util.invert(img_rgb)\n        foregrounds = []\n        if not self.thresholds.keys().isdisjoint(set(\"RGB\")):\n            rgb_foreground = np.zeros_like(img_rgb[:1])\n            for img, mode in zip(img_rgb, \"RGB\"):\n                threshold = self._get_threshold(img, mode)\n                if threshold:\n                    rgb_foreground = np.logical_or(rgb_foreground, img <= threshold)\n            foregrounds.append(rgb_foreground)\n        if not 
self.thresholds.keys().isdisjoint(set(\"HSV\")):\n            img_hsv = skimage.color.rgb2hsv(img_rgb, channel_axis=0)\n            hsv_foreground = np.zeros_like(img_rgb[:1])\n            for img, mode in zip(img_hsv, \"HSV\"):\n                threshold = self._get_threshold(img, mode)\n                if threshold:\n                    hsv_foreground = np.logical_or(hsv_foreground, img > threshold)\n            foregrounds.append(hsv_foreground)\n\n        mask = np.stack(foregrounds).all(axis=0)\n        return convert_to_dst_type(src=mask, dst=image)[0]\n\n\nclass ComputeHoVerMaps(Transform):\n    \"\"\"Compute horizontal and vertical maps from an instance mask\n    It generates normalized horizontal and vertical distances to the center of mass of each region.\n    Input data with the size of [1xHxW[xD]], which channel dim will temporarily removed for calculating coordinates.\n\n    Args:\n        dtype: the data type of output Tensor. Defaults to `\"float32\"`.\n\n    Return:\n        A torch.Tensor with the size of [2xHxW[xD]], which is stack horizontal and vertical maps\n\n    \"\"\"\n\n    def __init__(self, dtype: DtypeLike = \"float32\") -> None:\n        super().__init__()\n        self.dtype = dtype\n\n    def __call__(self, mask: NdarrayOrTensor):\n        instance_mask: np.ndarray = convert_data_type(mask, np.ndarray)[0]  # type: ignore[assignment]\n\n        h_map = instance_mask.astype(self.dtype, copy=True)\n        v_map = instance_mask.astype(self.dtype, copy=True)\n        instance_mask = instance_mask.squeeze(0)  # remove channel dim\n\n        for region in skimage.measure.regionprops(instance_mask):\n            v_dist = region.coords[:, 0] - region.centroid[0]\n            h_dist = region.coords[:, 1] - region.centroid[1]\n\n            h_dist[h_dist < 0] /= -np.amin(h_dist)\n            h_dist[h_dist > 0] /= np.amax(h_dist)\n\n            v_dist[v_dist < 0] /= -np.amin(v_dist)\n            v_dist[v_dist > 0] /= np.amax(v_dist)\n\n          
  h_map[h_map == region.label] = h_dist\n            v_map[v_map == region.label] = v_dist\n\n        hv_maps = convert_to_tensor(np.concatenate([h_map, v_map]), track_meta=get_track_meta())\n        return hv_maps\n\n\nclass UltrasoundConfidenceMapTransform(Transform):\n    \"\"\"Compute confidence map from an ultrasound image.\n    This transform uses the method introduced by Karamalis et al. in https://doi.org/10.1016/j.media.2012.07.005.\n    It generates a confidence map by setting source and sink points in the image and computing the probability\n    for random walks to reach the source for each pixel.\n\n    The official code is available at:\n    https://campar.in.tum.de/Main/AthanasiosKaramalisCode\n\n    Args:\n        alpha (float, optional): Alpha parameter. Defaults to 2.0.\n        beta (float, optional): Beta parameter. Defaults to 90.0.\n        gamma (float, optional): Gamma parameter. Defaults to 0.05.\n        mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'.\n        sink_mode (str, optional): Sink mode. Defaults to 'all'. If 'mask' is selected, a mask must be when\n            calling the transform. Can be one of 'all', 'mid', 'min', 'mask'.\n        use_cg (bool, optional): Use Conjugate Gradient method for solving the linear system. Defaults to False.\n        cg_tol (float, optional): Tolerance for the Conjugate Gradient method. Defaults to 1e-6.\n            Will be used only if `use_cg` is True.\n        cg_maxiter (int, optional): Maximum number of iterations for the Conjugate Gradient method. 
Defaults to 200.\n            Will be used only if `use_cg` is True.\n    \"\"\"\n\n    def __init__(\n        self,\n        alpha: float = 2.0,\n        beta: float = 90.0,\n        gamma: float = 0.05,\n        mode=\"B\",\n        sink_mode=\"all\",\n        use_cg=False,\n        cg_tol: float = 1.0e-6,\n        cg_maxiter: int = 200,\n    ):\n        self.alpha = alpha\n        self.beta = beta\n        self.gamma = gamma\n        self.mode = mode\n        self.sink_mode = sink_mode\n        self.use_cg = use_cg\n        self.cg_tol = cg_tol\n        self.cg_maxiter = cg_maxiter\n\n        if self.mode not in [\"B\", \"RF\"]:\n            raise ValueError(f\"Unknown mode: {self.mode}. Supported modes are 'B' and 'RF'.\")\n\n        if self.sink_mode not in [\"all\", \"mid\", \"min\", \"mask\"]:\n            raise ValueError(\n                f\"Unknown sink mode: {self.sink_mode}. Supported modes are 'all', 'mid', 'min' and 'mask'.\"\n            )\n\n        self._compute_conf_map = UltrasoundConfidenceMap(\n            self.alpha, self.beta, self.gamma, self.mode, self.sink_mode, self.use_cg, self.cg_tol, self.cg_maxiter\n        )\n\n    def __call__(self, img: NdarrayOrTensor, mask: NdarrayOrTensor | None = None) -> NdarrayOrTensor:\n        \"\"\"Compute confidence map from an ultrasound image.\n\n        Args:\n            img (ndarray or Tensor): Ultrasound image of shape [1, H, W] or [1, D, H, W]. If the image has channels,\n                they will be averaged before computing the confidence map.\n            mask (ndarray or Tensor, optional): Mask of shape [1, H, W]. Defaults to None. Must be\n                provided when sink mode is 'mask'. 
The non-zero values of the mask are used as sink points.\n\n        Returns:\n            ndarray or Tensor: Confidence map of shape [1, H, W].\n        \"\"\"\n\n        if self.sink_mode == \"mask\" and mask is None:\n            raise ValueError(\"A mask must be provided when sink mode is 'mask'.\")\n\n        if img.shape[0] != 1:\n            raise ValueError(\"The correct shape of the image is [1, H, W] or [1, D, H, W].\")\n\n        _img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_np, *_ = convert_data_type(_img, np.ndarray)\n        img_np = img_np[0]  # Remove the first dimension\n\n        mask_np = None\n        if mask is not None:\n            mask = convert_to_tensor(mask, dtype=torch.bool, track_meta=get_track_meta())\n            mask_np, *_ = convert_data_type(mask, np.ndarray)\n            mask_np = mask_np[0]  # Remove the first dimension\n\n        # If the image is RGB, convert it to grayscale\n        if len(img_np.shape) == 3:\n            img_np = np.mean(img_np, axis=0)\n\n        if mask_np is not None and mask_np.shape != img_np.shape:\n            raise ValueError(\"The mask must have the same shape as the image.\")\n\n        # Compute confidence map\n        conf_map: NdarrayOrTensor = self._compute_conf_map(img_np, mask_np)\n\n        if type(img) is torch.Tensor:\n            conf_map = torch.from_numpy(conf_map)\n\n        return conf_map\n"
  },
  {
    "path": "monai/transforms/intensity/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of dictionary-based wrappers around the \"vanilla\" transforms for intensity adjustment\ndefined in :py:class:`monai.transforms.intensity.array`.\n\nClass names are ended with 'd' to denote dictionary-based transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Hashable, Mapping, Sequence\nfrom typing import Callable\n\nimport numpy as np\n\nfrom monai.config import DtypeLike, KeysCollection\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.transforms.intensity.array import (\n    AdjustContrast,\n    ClipIntensityPercentiles,\n    ComputeHoVerMaps,\n    ForegroundMask,\n    GaussianSharpen,\n    GaussianSmooth,\n    GibbsNoise,\n    HistogramNormalize,\n    KSpaceSpikeNoise,\n    MaskIntensity,\n    MedianSmooth,\n    NormalizeIntensity,\n    RandAdjustContrast,\n    RandBiasField,\n    RandCoarseDropout,\n    RandCoarseShuffle,\n    RandGaussianNoise,\n    RandGaussianSharpen,\n    RandGaussianSmooth,\n    RandGibbsNoise,\n    RandHistogramShift,\n    RandKSpaceSpikeNoise,\n    RandRicianNoise,\n    RandScaleIntensity,\n    RandScaleIntensityFixedMean,\n    RandShiftIntensity,\n    RandStdShiftIntensity,\n    SavitzkyGolaySmooth,\n    ScaleIntensity,\n    ScaleIntensityRange,\n    ScaleIntensityRangePercentiles,\n    ShiftIntensity,\n    
StdShiftIntensity,\n    ThresholdIntensity,\n)\nfrom monai.transforms.transform import MapTransform, RandomizableTransform\nfrom monai.transforms.utils import is_positive\nfrom monai.utils import convert_to_tensor, ensure_tuple, ensure_tuple_rep\nfrom monai.utils.enums import PostFix\n\n__all__ = [\n    \"RandGaussianNoised\",\n    \"RandRicianNoised\",\n    \"ShiftIntensityd\",\n    \"RandShiftIntensityd\",\n    \"ScaleIntensityd\",\n    \"RandScaleIntensityd\",\n    \"StdShiftIntensityd\",\n    \"RandStdShiftIntensityd\",\n    \"RandBiasFieldd\",\n    \"NormalizeIntensityd\",\n    \"ThresholdIntensityd\",\n    \"ScaleIntensityRanged\",\n    \"ClipIntensityPercentilesd\",\n    \"AdjustContrastd\",\n    \"RandAdjustContrastd\",\n    \"ScaleIntensityRangePercentilesd\",\n    \"MaskIntensityd\",\n    \"SavitzkyGolaySmoothd\",\n    \"MedianSmoothd\",\n    \"GaussianSmoothd\",\n    \"RandGaussianSmoothd\",\n    \"GaussianSharpend\",\n    \"RandGaussianSharpend\",\n    \"GibbsNoised\",\n    \"RandGibbsNoised\",\n    \"KSpaceSpikeNoised\",\n    \"RandKSpaceSpikeNoised\",\n    \"RandHistogramShiftd\",\n    \"RandCoarseDropoutd\",\n    \"RandCoarseShuffled\",\n    \"HistogramNormalized\",\n    \"ForegroundMaskd\",\n    \"ComputeHoVerMapsd\",\n    \"RandGaussianNoiseD\",\n    \"RandGaussianNoiseDict\",\n    \"ShiftIntensityD\",\n    \"ShiftIntensityDict\",\n    \"RandShiftIntensityD\",\n    \"RandShiftIntensityDict\",\n    \"ScaleIntensityD\",\n    \"ScaleIntensityDict\",\n    \"StdShiftIntensityD\",\n    \"StdShiftIntensityDict\",\n    \"RandScaleIntensityD\",\n    \"RandScaleIntensityDict\",\n    \"RandScaleIntensityFixedMeand\",\n    \"RandScaleIntensityFixedMeanDict\",\n    \"RandScaleIntensityFixedMeanD\",\n    \"RandStdShiftIntensityD\",\n    \"RandStdShiftIntensityDict\",\n    \"RandBiasFieldD\",\n    \"RandBiasFieldDict\",\n    \"NormalizeIntensityD\",\n    \"NormalizeIntensityDict\",\n    \"ThresholdIntensityD\",\n    \"ThresholdIntensityDict\",\n    
\"ScaleIntensityRangeD\",\n    \"ScaleIntensityRangeDict\",\n    \"ClipIntensityPercentilesD\",\n    \"ClipIntensityPercentilesDict\",\n    \"AdjustContrastD\",\n    \"AdjustContrastDict\",\n    \"RandAdjustContrastD\",\n    \"RandAdjustContrastDict\",\n    \"ScaleIntensityRangePercentilesD\",\n    \"ScaleIntensityRangePercentilesDict\",\n    \"MaskIntensityD\",\n    \"MaskIntensityDict\",\n    \"SavitzkyGolaySmoothD\",\n    \"SavitzkyGolaySmoothDict\",\n    \"MedianSmoothD\",\n    \"MedianSmoothDict\",\n    \"GaussianSmoothD\",\n    \"GaussianSmoothDict\",\n    \"RandGaussianSmoothD\",\n    \"RandGaussianSmoothDict\",\n    \"GaussianSharpenD\",\n    \"GaussianSharpenDict\",\n    \"RandGaussianSharpenD\",\n    \"RandGaussianSharpenDict\",\n    \"GibbsNoiseD\",\n    \"GibbsNoiseDict\",\n    \"RandGibbsNoiseD\",\n    \"RandGibbsNoiseDict\",\n    \"KSpaceSpikeNoiseD\",\n    \"KSpaceSpikeNoiseDict\",\n    \"RandHistogramShiftD\",\n    \"RandHistogramShiftDict\",\n    \"RandRicianNoiseD\",\n    \"RandRicianNoiseDict\",\n    \"RandCoarseDropoutD\",\n    \"RandCoarseDropoutDict\",\n    \"RandCoarseShuffleD\",\n    \"RandCoarseShuffleDict\",\n    \"HistogramNormalizeD\",\n    \"HistogramNormalizeDict\",\n    \"RandKSpaceSpikeNoiseD\",\n    \"RandKSpaceSpikeNoiseDict\",\n    \"ForegroundMaskD\",\n    \"ForegroundMaskDict\",\n    \"ComputeHoVerMapsD\",\n    \"ComputeHoVerMapsDict\",\n]\n\nDEFAULT_POST_FIX = PostFix.meta()\n\n\nclass RandGaussianNoised(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandGaussianNoise`.\n    Add Gaussian noise to image. 
This transform assumes all the expected fields have same shape, if you want to add\n    different noise for every field, please use this transform separately.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        prob: Probability to add Gaussian noise.\n        mean: Mean or “centre” of the distribution.\n        std: Standard deviation (spread) of distribution.\n        dtype: output data type, if None, same as input image. defaults to float32.\n        allow_missing_keys: don't raise exception if key is missing.\n        sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std.\n    \"\"\"\n\n    backend = RandGaussianNoise.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        prob: float = 0.1,\n        mean: float = 0.0,\n        std: float = 0.1,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n        sample_std: bool = True,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype, sample_std=sample_std)\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandGaussianNoised:\n        super().set_random_state(seed, state)\n        self.rand_gaussian_noise.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # all the keys share the same random noise\n        first_key: 
Hashable = self.first_key(d)\n        if first_key == ():\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        self.rand_gaussian_noise.randomize(d[first_key])\n\n        for key in self.key_iterator(d):\n            d[key] = self.rand_gaussian_noise(img=d[key], randomize=False)\n        return d\n\n\nclass RandRicianNoised(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandRicianNoise`.\n    Add Rician noise to image. This transform assumes all the expected fields have same shape, if want to add\n    different noise for every field, please use this transform separately.\n\n    Args:\n        keys: Keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        prob: Probability to add Rician noise to the dictionary.\n        mean: Mean or \"centre\" of the Gaussian distributions sampled to make up\n            the Rician noise.\n        std: Standard deviation (spread) of the Gaussian distributions sampled\n            to make up the Rician noise.\n        channel_wise: If True, treats each channel of the image separately.\n        relative: If True, the spread of the sampled Gaussian distributions will\n            be std times the standard deviation of the image or channel's intensity\n            histogram.\n        sample_std: If True, sample the spread of the Gaussian distributions\n            uniformly from 0 to std.\n        dtype: output data type, if None, same as input image. 
defaults to float32.\n        allow_missing_keys: Don't raise exception if key is missing.\n    \"\"\"\n\n    backend = RandRicianNoise.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        prob: float = 0.1,\n        mean: Sequence[float] | float = 0.0,\n        std: Sequence[float] | float = 1.0,\n        channel_wise: bool = False,\n        relative: bool = False,\n        sample_std: bool = True,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.rand_rician_noise = RandRicianNoise(\n            prob=1.0,\n            mean=mean,\n            std=std,\n            channel_wise=channel_wise,\n            relative=relative,\n            sample_std=sample_std,\n            dtype=dtype,\n        )\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandRicianNoised:\n        super().set_random_state(seed, state)\n        self.rand_rician_noise.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        for key in self.key_iterator(d):\n            d[key] = self.rand_rician_noise(d[key], randomize=True)\n        return d\n\n\nclass ShiftIntensityd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ShiftIntensity`.\n    \"\"\"\n\n    backend = ShiftIntensity.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        offset: float,\n        safe: bool = False,\n        factor_key: str | None = None,\n        
meta_keys: KeysCollection | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            offset: offset value to shift the intensity of image.\n            safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.\n                E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.\n            factor_key: if not None, use it as the key to extract a value from the corresponding\n                metadata dictionary of `key` at runtime, and multiply the `offset` to shift intensity.\n                Usually, `IntensityStatsd` transform can pre-compute statistics of intensity values\n                and store in the metadata.\n                it also can be a sequence of strings, map to `keys`.\n            meta_keys: explicitly indicate the key of the corresponding metadata dictionary.\n                used to extract the factor value is `factor_key` is not None.\n                for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n                the metadata is a dictionary object which contains: filename, original_shape, etc.\n                it can be a sequence of string, map to the `keys`.\n                if None, will try to construct meta_keys by `key_{meta_key_postfix}`.\n            meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according\n                to the key data, default is `meta_dict`, the metadata is a dictionary object.\n                used to extract the factor value is `factor_key` is not None.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n      
  self.factor_key = ensure_tuple_rep(factor_key, len(self.keys))\n        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)\n        if len(self.keys) != len(self.meta_keys):\n            raise ValueError(\"meta_keys should have the same length as keys.\")\n        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))\n        self.shifter = ShiftIntensity(offset, safe)\n\n    def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key, factor_key, meta_key, meta_key_postfix in self.key_iterator(\n            d, self.factor_key, self.meta_keys, self.meta_key_postfix\n        ):\n            meta_key = meta_key or f\"{key}_{meta_key_postfix}\"\n            factor: float | None = d[meta_key].get(factor_key) if meta_key in d else None\n            offset = None if factor is None else self.shifter.offset * factor\n            d[key] = self.shifter(d[key], offset=offset)\n        return d\n\n\nclass RandShiftIntensityd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandShiftIntensity`.\n    \"\"\"\n\n    backend = RandShiftIntensity.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        offsets: tuple[float, float] | float,\n        safe: bool = False,\n        factor_key: str | None = None,\n        meta_keys: KeysCollection | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        prob: float = 0.1,\n        channel_wise: bool = False,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            offsets: offset range to randomly shift.\n                if single number, offset value is picked from (-offsets, offsets).\n            safe: if `True`, then do safe 
dtype convert when intensity overflow. default to `False`.\n                E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.\n            factor_key: if not None, use it as the key to extract a value from the corresponding\n                metadata dictionary of `key` at runtime, and multiply the random `offset` to shift intensity.\n                Usually, `IntensityStatsd` transform can pre-compute statistics of intensity values\n                and store in the metadata.\n                it also can be a sequence of strings, map to `keys`.\n            meta_keys: explicitly indicate the key of the corresponding metadata dictionary.\n                used to extract the factor value is `factor_key` is not None.\n                for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n                the metadata is a dictionary object which contains: filename, original_shape, etc.\n                it can be a sequence of string, map to the `keys`.\n                if None, will try to construct meta_keys by `key_{meta_key_postfix}`.\n            meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according\n                to the key data, default is `meta_dict`, the metadata is a dictionary object.\n                used to extract the factor value is `factor_key` is not None.\n            prob: probability of shift.\n                (Default 0.1, with 10% probability it returns an array shifted intensity.)\n            channel_wise: if True, shift intensity on each channel separately. 
For each channel, a random offset will be chosen.\n                Please ensure that the first dimension represents the channel of the image if True.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n\n        self.factor_key = ensure_tuple_rep(factor_key, len(self.keys))\n        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)\n        if len(self.keys) != len(self.meta_keys):\n            raise ValueError(\"meta_keys should have the same length as keys.\")\n        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))\n        self.shifter = RandShiftIntensity(offsets=offsets, safe=safe, prob=1.0, channel_wise=channel_wise)\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandShiftIntensityd:\n        super().set_random_state(seed, state)\n        self.shifter.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # expect all the specified keys have same spatial shape and share same random holes\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # all the keys share the same random shift factor\n        self.shifter.randomize(d[first_key])\n        for key, factor_key, meta_key, meta_key_postfix in self.key_iterator(\n            d, self.factor_key, self.meta_keys, 
self.meta_key_postfix\n        ):\n            meta_key = meta_key or f\"{key}_{meta_key_postfix}\"\n            factor: float | None = d[meta_key].get(factor_key) if meta_key in d else None\n            d[key] = self.shifter(d[key], factor=factor, randomize=False)\n        return d\n\n\nclass StdShiftIntensityd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.StdShiftIntensity`.\n    \"\"\"\n\n    backend = StdShiftIntensity.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        factor: float,\n        nonzero: bool = False,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            factor: factor shift by ``v = v + factor * std(v)``.\n            nonzero: whether only count non-zero values.\n            channel_wise: if True, calculate on each channel separately. Please ensure\n                that the first dimension represents the channel of the image if True.\n            dtype: output data type, if None, same as input image. 
defaults to float32.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.shifter = StdShiftIntensity(factor, nonzero, channel_wise, dtype)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.shifter(d[key])\n        return d\n\n\nclass RandStdShiftIntensityd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandStdShiftIntensity`.\n    \"\"\"\n\n    backend = RandStdShiftIntensity.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        factors: tuple[float, float] | float,\n        prob: float = 0.1,\n        nonzero: bool = False,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            factors: if tuple, the randomly picked range is (min(factors), max(factors)).\n                If single number, the range is (-factors, factors).\n            prob: probability of std shift.\n            nonzero: whether only count non-zero values.\n            channel_wise: if True, calculate on each channel separately.\n            dtype: output data type, if None, same as input image. 
defaults to float32.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.shifter = RandStdShiftIntensity(\n            factors=factors, nonzero=nonzero, channel_wise=channel_wise, dtype=dtype, prob=1.0\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandStdShiftIntensityd:\n        super().set_random_state(seed, state)\n        self.shifter.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # all the keys share the same random shift factor\n        self.shifter.randomize(None)\n        for key in self.key_iterator(d):\n            d[key] = self.shifter(d[key], randomize=False)\n        return d\n\n\nclass ScaleIntensityd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ScaleIntensity`.\n    Scale the intensity of input image to the given value range (minv, maxv).\n    If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``.\n    \"\"\"\n\n    backend = ScaleIntensity.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        minv: float | None = 0.0,\n        maxv: float | None = 1.0,\n        factor: float | None = None,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                
See also: :py:class:`monai.transforms.compose.MapTransform`\n            minv: minimum value of output data.\n            maxv: maximum value of output data.\n            factor: factor scale by ``v = v * (1 + factor)``. In order to use\n                this parameter, please set both `minv` and `maxv` into None.\n            channel_wise: if True, scale on each channel separately. Please ensure\n                that the first dimension represents the channel of the image if True.\n            dtype: output data type, if None, same as input image. defaults to float32.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.scaler = ScaleIntensity(minv, maxv, factor, channel_wise, dtype)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.scaler(d[key])\n        return d\n\n\nclass RandScaleIntensityd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`.\n    \"\"\"\n\n    backend = RandScaleIntensity.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        factors: tuple[float, float] | float,\n        prob: float = 0.1,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            factors: factor range to randomly scale by ``v = v * (1 + factor)``.\n                if single number, factor value is picked from (-factors, factors).\n            prob: probability of scale.\n                (Default 0.1, with 10% probability it returns a scaled array.)\n            
channel_wise: if True, scale on each channel separately. Please ensure\n                that the first dimension represents the channel of the image if True.\n            dtype: output data type, if None, same as input image. defaults to float32.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.scaler = RandScaleIntensity(factors=factors, dtype=dtype, prob=1.0, channel_wise=channel_wise)\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandScaleIntensityd:\n        super().set_random_state(seed, state)\n        self.scaler.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # expect all the specified keys have same spatial shape and share same random holes\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # all the keys share the same random scale factor\n        self.scaler.randomize(d[first_key])\n        for key in self.key_iterator(d):\n            d[key] = self.scaler(d[key], randomize=False)\n        return d\n\n\nclass RandScaleIntensityFixedMeand(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`.\n    Subtract the mean intensity before scaling with `factor`, then add the same value after scaling\n    to ensure 
that the output has the same mean as the input.\n    \"\"\"\n\n    backend = RandScaleIntensityFixedMean.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        factors: Sequence[float] | float,\n        fixed_mean: bool = True,\n        preserve_range: bool = False,\n        prob: float = 0.1,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            factors: factor range to randomly scale by ``v = v * (1 + factor)``.\n                if single number, factor value is picked from (-factors, factors).\n            preserve_range: clips the output array/tensor to the range of the input array/tensor\n            fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling\n                to ensure that the output has the same mean as the input.\n            channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied\n            on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the\n            channel of the image if True.\n            dtype: output data type, if None, same as input image. 
defaults to float32.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.fixed_mean = fixed_mean\n        self.preserve_range = preserve_range\n        self.scaler = RandScaleIntensityFixedMean(\n            factors=factors, fixed_mean=self.fixed_mean, preserve_range=preserve_range, dtype=dtype, prob=1.0\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandScaleIntensityFixedMeand:\n        super().set_random_state(seed, state)\n        self.scaler.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # all the keys share the same random scale factor\n        self.scaler.randomize(None)\n        for key in self.key_iterator(d):\n            d[key] = self.scaler(d[key], randomize=False)\n        return d\n\n\nclass RandBiasFieldd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandBiasField`.\n    \"\"\"\n\n    backend = RandBiasField.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        degree: int = 3,\n        coeff_range: tuple[float, float] = (0.0, 0.1),\n        dtype: DtypeLike = np.float32,\n        prob: float = 0.1,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            degree: degree 
of freedom of the polynomials. The value should be no less than 1.\n                Defaults to 3.\n            coeff_range: range of the random coefficients. Defaults to (0.0, 0.1).\n            dtype: output data type, if None, same as input image. defaults to float32.\n            prob: probability to do random bias field.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n\n        self.rand_bias_field = RandBiasField(degree=degree, coeff_range=coeff_range, dtype=dtype, prob=1.0)\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandBiasFieldd:\n        super().set_random_state(seed, state)\n        self.rand_bias_field.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # all the keys share the same random bias factor\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        self.rand_bias_field.randomize(img_size=d[first_key].shape[1:])\n\n        for key in self.key_iterator(d):\n            d[key] = self.rand_bias_field(d[key], randomize=False)\n        return d\n\n\nclass NormalizeIntensityd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.NormalizeIntensity`.\n    This transform can normalize only non-zero values or entire image, and can also calculate\n    mean and std on each channel 
separately.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        subtrahend: the amount to subtract by (usually the mean)\n        divisor: the amount to divide by (usually the standard deviation)\n        nonzero: whether only normalize non-zero values.\n        channel_wise: if True, calculate on each channel separately, otherwise, calculate on\n            the entire image directly. default to False.\n        dtype: output data type, if None, same as input image. defaults to float32.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = NormalizeIntensity.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        subtrahend: NdarrayOrTensor | None = None,\n        divisor: NdarrayOrTensor | None = None,\n        nonzero: bool = False,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.normalizer(d[key])\n        return d\n\n\nclass ThresholdIntensityd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ThresholdIntensity`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        threshold: the threshold to filter intensity values.\n        above: filter values above the threshold or below the threshold, default is True.\n        cval: value to fill the remaining parts of the image, default is 0.\n        allow_missing_keys: don't raise exception if key is missing.\n  
  \"\"\"\n\n    backend = ThresholdIntensity.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        threshold: float,\n        above: bool = True,\n        cval: float = 0.0,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.filter = ThresholdIntensity(threshold, above, cval)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.filter(d[key])\n        return d\n\n\nclass ScaleIntensityRanged(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ScaleIntensityRange`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        a_min: intensity original range min.\n        a_max: intensity original range max.\n        b_min: intensity target range min.\n        b_max: intensity target range max.\n        clip: whether to perform clip after scaling.\n        dtype: output data type, if None, same as input image. 
defaults to float32.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = ScaleIntensityRange.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        a_min: float,\n        a_max: float,\n        b_min: float | None = None,\n        b_max: float | None = None,\n        clip: bool = False,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip, dtype)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.scaler(d[key])\n        return d\n\n\nclass ClipIntensityPercentilesd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ClipIntensityPercentiles`.\n    Clip the intensity values of input image to a specific range based on the intensity distribution of the input.\n    If `sharpness_factor` is provided, the intensity values will be soft clipped according to\n    f(x) = x + (1/sharpness_factor) * softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        lower: float | None,\n        upper: float | None,\n        sharpness_factor: float | None = None,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.scaler = ClipIntensityPercentiles(\n            lower=lower, upper=upper, sharpness_factor=sharpness_factor, channel_wise=channel_wise, dtype=dtype\n        )\n\n    def __call__(self, data: dict) -> dict:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = 
self.scaler(d[key])\n        return d\n\n\nclass AdjustContrastd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.AdjustContrast`.\n    Changes image intensity with gamma transform. Each pixel/voxel intensity is updated as:\n\n        `x = ((x - min) / intensity_range) ^ gamma * intensity_range + min`\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        gamma: gamma value to adjust the contrast as function.\n        invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity\n            values with -1 before the gamma transform and again after the gamma transform. This behaviour is mimicked\n            from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this\n            <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_\n            function.\n        retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to\n            ensure that the output intensity distribution has the same mean and standard deviation as the intensity\n            distribution of the input. 
This behaviour is mimicked from `nnU-Net\n            <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this\n            <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_\n            function.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = AdjustContrast.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        gamma: float,\n        invert_image: bool = False,\n        retain_stats: bool = False,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.adjuster = AdjustContrast(gamma, invert_image, retain_stats)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.adjuster(d[key])\n        return d\n\n\nclass RandAdjustContrastd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandAdjustContrast`.\n    Randomly changes image intensity with gamma transform. Each pixel/voxel intensity is updated as:\n\n        `x = ((x - min) / intensity_range) ^ gamma * intensity_range + min`\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        prob: Probability of adjustment.\n        gamma: Range of gamma values.\n            If single number, value is picked from (0.5, gamma), default is (0.5, 4.5).\n        invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity\n            values with -1 before the gamma transform and again after the gamma transform. 
This behaviour is mimicked\n            from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this\n            <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_\n            function.\n        retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to\n            ensure that the output intensity distribution has the same mean and standard deviation as the intensity\n            distribution of the input. This behaviour is mimicked from `nnU-Net\n            <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this\n            <https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_\n            function.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = RandAdjustContrast.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        prob: float = 0.1,\n        gamma: tuple[float, float] | float = (0.5, 4.5),\n        invert_image: bool = False,\n        retain_stats: bool = False,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.adjuster = RandAdjustContrast(gamma=gamma, prob=1.0, invert_image=invert_image, retain_stats=retain_stats)\n        self.invert_image = invert_image\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandAdjustContrastd:\n        super().set_random_state(seed, state)\n        self.adjuster.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n      
  self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # all the keys share the same random gamma value\n        self.adjuster.randomize(None)\n        for key in self.key_iterator(d):\n            d[key] = self.adjuster(d[key], randomize=False)\n        return d\n\n\nclass ScaleIntensityRangePercentilesd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ScaleIntensityRangePercentiles`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        lower: lower percentile.\n        upper: upper percentile.\n        b_min: intensity target range min.\n        b_max: intensity target range max.\n        clip: whether to perform clip after scaling.\n        relative: whether to scale to the corresponding percentiles of [b_min, b_max]\n        channel_wise: if True, compute intensity percentile and normalize every channel separately.\n            default to False.\n        dtype: output data type, if None, same as input image. 
defaults to float32.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = ScaleIntensityRangePercentiles.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        lower: float,\n        upper: float,\n        b_min: float | None,\n        b_max: float | None,\n        clip: bool = False,\n        relative: bool = False,\n        channel_wise: bool = False,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, channel_wise, dtype)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.scaler(d[key])\n        return d\n\n\nclass MaskIntensityd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.MaskIntensity`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        mask_data: if mask data is single channel, apply to every channel\n            of input image. if multiple channels, the channel number must\n            match input data. the intensity values of input image corresponding\n            to the selected values in the mask data will keep the original value,\n            others will be set to `0`. 
if None, will extract the mask data from\n            input data based on `mask_key`.\n        mask_key: the key to extract mask data from input dictionary, only works\n            when `mask_data` is None.\n        select_fn: function to select valid values of the `mask_data`, default is\n            to select `values > 0`.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = MaskIntensity.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        mask_data: NdarrayOrTensor | None = None,\n        mask_key: str | None = None,\n        select_fn: Callable = is_positive,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.converter = MaskIntensity(mask_data=mask_data, select_fn=select_fn)\n        self.mask_key = mask_key if mask_data is None else None\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key], d[self.mask_key]) if self.mask_key is not None else self.converter(d[key])\n        return d\n\n\nclass SavitzkyGolaySmoothd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.SavitzkyGolaySmooth`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        window_length: length of the filter window, must be a positive odd integer.\n        order: order of the polynomial to fit to each window, must be less than ``window_length``.\n        axis: optional axis along which to apply the filter kernel. Default 1 (first spatial dimension).\n        mode: optional padding mode, passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``\n            or ``'circular'``. default: ``'zeros'``. 
See ``torch.nn.Conv1d()`` for more information.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = SavitzkyGolaySmooth.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        window_length: int,\n        order: int,\n        axis: int = 1,\n        mode: str = \"zeros\",\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.converter = SavitzkyGolaySmooth(window_length=window_length, order=order, axis=axis, mode=mode)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass MedianSmoothd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.MedianSmooth`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        radius: if a list of values, must match the count of spatial dimensions of input data,\n            and apply every value in the list to 1 spatial dimension. 
if only 1 value provided,\n            use it for all spatial dimensions.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = MedianSmooth.backend\n\n    def __init__(self, keys: KeysCollection, radius: Sequence[int] | int, allow_missing_keys: bool = False) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.converter = MedianSmooth(radius)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass GaussianSmoothd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.GaussianSmooth`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        sigma: if a list of values, must match the count of spatial dimensions of input data,\n            and apply every value in the list to 1 spatial dimension. 
if only 1 value provided,\n            use it for all spatial dimensions.\n        approx: discrete Gaussian kernel type, available options are \"erf\", \"sampled\", and \"scalespace\".\n            see also :py:meth:`monai.networks.layers.GaussianFilter`.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = GaussianSmooth.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        sigma: Sequence[float] | float,\n        approx: str = \"erf\",\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.converter = GaussianSmooth(sigma, approx=approx)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass RandGaussianSmoothd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.GaussianSmooth`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        sigma_x: randomly select sigma value for the first spatial dimension.\n        sigma_y: randomly select sigma value for the second spatial dimension if have.\n        sigma_z: randomly select sigma value for the third spatial dimension if have.\n        approx: discrete Gaussian kernel type, available options are \"erf\", \"sampled\", and \"scalespace\".\n            see also :py:meth:`monai.networks.layers.GaussianFilter`.\n        prob: probability of Gaussian smooth.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = RandGaussianSmooth.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        sigma_x: tuple[float, float] = (0.25, 1.5),\n        sigma_y: 
tuple[float, float] = (0.25, 1.5),\n        sigma_z: tuple[float, float] = (0.25, 1.5),\n        approx: str = \"erf\",\n        prob: float = 0.1,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.rand_smooth = RandGaussianSmooth(\n            sigma_x=sigma_x, sigma_y=sigma_y, sigma_z=sigma_z, approx=approx, prob=1.0\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandGaussianSmoothd:\n        super().set_random_state(seed, state)\n        self.rand_smooth.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # all the keys share the same random sigma\n        self.rand_smooth.randomize(None)\n        for key in self.key_iterator(d):\n            d[key] = self.rand_smooth(d[key], randomize=False)\n        return d\n\n\nclass GaussianSharpend(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.GaussianSharpen`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        sigma1: sigma parameter for the first gaussian kernel. if a list of values, must match the count\n            of spatial dimensions of input data, and apply every value in the list to 1 spatial dimension.\n            if only 1 value provided, use it for all spatial dimensions.\n        sigma2: sigma parameter for the second gaussian kernel. 
if a list of values, must match the count\n            of spatial dimensions of input data, and apply every value in the list to 1 spatial dimension.\n            if only 1 value provided, use it for all spatial dimensions.\n        alpha: weight parameter to compute the final result.\n        approx: discrete Gaussian kernel type, available options are \"erf\", \"sampled\", and \"scalespace\".\n            see also :py:meth:`monai.networks.layers.GaussianFilter`.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = GaussianSharpen.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        sigma1: Sequence[float] | float = 3.0,\n        sigma2: Sequence[float] | float = 1.0,\n        alpha: float = 30.0,\n        approx: str = \"erf\",\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.converter = GaussianSharpen(sigma1, sigma2, alpha, approx=approx)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass RandGaussianSharpend(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.GaussianSharpen`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        sigma1_x: randomly select sigma value for the first spatial dimension of first gaussian kernel.\n        sigma1_y: randomly select sigma value for the second spatial dimension(if have) of first gaussian kernel.\n        sigma1_z: randomly select sigma value for the third spatial dimension(if have) of first gaussian kernel.\n        sigma2_x: randomly select sigma value for the first spatial dimension of second gaussian kernel.\n   
         if only 1 value `X` provided, it must be smaller than `sigma1_x` and randomly select from [X, sigma1_x].\n        sigma2_y: randomly select sigma value for the second spatial dimension(if have) of second gaussian kernel.\n            if only 1 value `Y` provided, it must be smaller than `sigma1_y` and randomly select from [Y, sigma1_y].\n        sigma2_z: randomly select sigma value for the third spatial dimension(if have) of second gaussian kernel.\n            if only 1 value `Z` provided, it must be smaller than `sigma1_z` and randomly select from [Z, sigma1_z].\n        alpha: randomly select weight parameter to compute the final result.\n        approx: discrete Gaussian kernel type, available options are \"erf\", \"sampled\", and \"scalespace\".\n            see also :py:meth:`monai.networks.layers.GaussianFilter`.\n        prob: probability of Gaussian sharpen.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = RandGaussianSharpen.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        sigma1_x: tuple[float, float] = (0.5, 1.0),\n        sigma1_y: tuple[float, float] = (0.5, 1.0),\n        sigma1_z: tuple[float, float] = (0.5, 1.0),\n        sigma2_x: tuple[float, float] | float = 0.5,\n        sigma2_y: tuple[float, float] | float = 0.5,\n        sigma2_z: tuple[float, float] | float = 0.5,\n        alpha: tuple[float, float] = (10.0, 30.0),\n        approx: str = \"erf\",\n        prob: float = 0.1,\n        allow_missing_keys: bool = False,\n    ):\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.rand_sharpen = RandGaussianSharpen(\n            sigma1_x=sigma1_x,\n            sigma1_y=sigma1_y,\n            sigma1_z=sigma1_z,\n            sigma2_x=sigma2_x,\n            sigma2_y=sigma2_y,\n            sigma2_z=sigma2_z,\n            alpha=alpha,\n            approx=approx,\n            
prob=1.0,\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandGaussianSharpend:\n        super().set_random_state(seed, state)\n        self.rand_sharpen.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # all the keys share the same random sigma1, sigma2, etc.\n        self.rand_sharpen.randomize(None)\n        for key in self.key_iterator(d):\n            d[key] = self.rand_sharpen(d[key], randomize=False)\n        return d\n\n\nclass RandHistogramShiftd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandHistogramShift`.\n    Apply random nonlinear transform the image's intensity histogram.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        num_control_points: number of control points governing the nonlinear intensity mapping.\n            a smaller number of control points allows for larger intensity shifts. 
if two values provided, number of\n            control points selecting from range (min_value, max_value).\n        prob: probability of histogram shift.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = RandHistogramShift.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        num_control_points: tuple[int, int] | int = 10,\n        prob: float = 0.1,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.shifter = RandHistogramShift(num_control_points=num_control_points, prob=1.0)\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandHistogramShiftd:\n        super().set_random_state(seed, state)\n        self.shifter.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # all the keys share the same random shift params\n        self.shifter.randomize(None)\n        for key in self.key_iterator(d):\n            d[key] = self.shifter(d[key], randomize=False)\n        return d\n\n\nclass RandGibbsNoised(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based version of RandGibbsNoise.\n\n    Naturalistic image augmentation via Gibbs artifacts. The transform\n    randomly applies Gibbs noise to 2D/3D MRI images. 
Gibbs artifacts\n    are one of the common type of type artifacts appearing in MRI scans.\n\n    The transform is applied to all the channels in the data.\n\n    For general information on Gibbs artifacts, please refer to:\n    https://pubs.rsna.org/doi/full/10.1148/rg.313105115\n    https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949\n\n    Args:\n        keys: 'image', 'label', or ['image', 'label'] depending on which data\n                you need to transform.\n        prob (float): probability of applying the transform.\n        alpha (float, Sequence[float]): Parametrizes the intensity of the Gibbs noise filter applied. Takes\n            values in the interval [0,1] with alpha = 0 acting as the identity mapping.\n            If a length-2 list is given as [a,b] then the value of alpha will be sampled\n            uniformly from the interval [a,b].\n            If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha].\n        allow_missing_keys: do not raise exception if key is missing.\n    \"\"\"\n\n    backend = RandGibbsNoise.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        prob: float = 0.1,\n        alpha: float | Sequence[float] = (0.0, 1.0),\n        allow_missing_keys: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob=prob)\n        self.rand_gibbs_noise = RandGibbsNoise(alpha=alpha, prob=1.0)\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandGibbsNoised:\n        super().set_random_state(seed, state)\n        self.rand_gibbs_noise.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in 
self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # all the keys share the same random noise params\n        self.rand_gibbs_noise.randomize(None)\n        for key in self.key_iterator(d):\n            d[key] = self.rand_gibbs_noise(d[key], randomize=False)\n        return d\n\n\nclass GibbsNoised(MapTransform):\n    \"\"\"\n    Dictionary-based version of GibbsNoise.\n\n    The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts\n    are one of the common type of type artifacts appearing in MRI scans.\n\n    For general information on Gibbs artifacts, please refer to:\n    https://pubs.rsna.org/doi/full/10.1148/rg.313105115\n    https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949\n\n    Args:\n        keys: 'image', 'label', or ['image', 'label'] depending on which data\n                you need to transform.\n        alpha (float): Parametrizes the intensity of the Gibbs noise filter applied. 
Takes\n            values in the interval [0,1] with alpha = 0 acting as the identity mapping.\n        allow_missing_keys: do not raise exception if key is missing.\n    \"\"\"\n\n    backend = GibbsNoise.backend\n\n    def __init__(self, keys: KeysCollection, alpha: float = 0.5, allow_missing_keys: bool = False) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        self.transform = GibbsNoise(alpha)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.transform(d[key])\n        return d\n\n\nclass KSpaceSpikeNoised(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.KSpaceSpikeNoise`.\n\n    Applies localized spikes in `k`-space at the given locations and intensities.\n    Spike (Herringbone) artifact is a type of data acquisition artifact which\n    may occur during MRI scans.\n\n    For general information on spike artifacts, please refer to:\n\n    `AAPM/RSNA physics tutorial for residents: fundamental physics of MR imaging\n    <https://pubmed.ncbi.nlm.nih.gov/16009826>`_.\n\n    `Body MRI artifacts in clinical practice: A physicist's and radiologist's\n    perspective <https://doi.org/10.1002/jmri.24288>`_.\n\n    Args:\n        keys: \"image\", \"label\", or [\"image\", \"label\"] depending\n             on which data you need to transform.\n        loc: spatial location for the spikes. For\n            images with 3D spatial dimensions, the user can provide (C, X, Y, Z)\n            to fix which channel C is affected, or (X, Y, Z) to place the same\n            spike in all channels. For 2D cases, the user can provide (C, X, Y)\n            or (X, Y).\n        k_intensity: value for the log-intensity of the\n            `k`-space version of the image. 
If one location is passed to ``loc`` or the\n            channel is not specified, then this argument should receive a float. If\n            ``loc`` is given a sequence of locations, then this argument should\n            receive a sequence of intensities. This value should be tested as it is\n            data-dependent. The default values are the 2.5 the mean of the\n            log-intensity for each channel.\n        allow_missing_keys: do not raise exception if key is missing.\n\n    Example:\n        When working with 4D data,\n        ``KSpaceSpikeNoised(\"image\", loc = ((3,60,64,32), (64,60,32)), k_intensity = (13,14))``\n        will place a spike at `[3, 60, 64, 32]` with `log-intensity = 13`, and\n        one spike per channel located respectively at `[: , 64, 60, 32]`\n        with `log-intensity = 14`.\n    \"\"\"\n\n    backend = KSpaceSpikeNoise.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        loc: tuple | Sequence[tuple],\n        k_intensity: Sequence[float] | float | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.transform = KSpaceSpikeNoise(loc, k_intensity)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        \"\"\"\n        Args:\n            data: Expects image/label to have dimensions (C, H, W) or\n                (C, H, W, D), where C is the channel.\n        \"\"\"\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.transform(d[key])\n        return d\n\n\nclass RandKSpaceSpikeNoised(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based version of :py:class:`monai.transforms.RandKSpaceSpikeNoise`.\n\n    Naturalistic data augmentation via spike artifacts. 
The transform applies\n    localized spikes in `k`-space.\n\n    For general information on spike artifacts, please refer to:\n\n    `AAPM/RSNA physics tutorial for residents: fundamental physics of MR imaging\n    <https://pubmed.ncbi.nlm.nih.gov/16009826>`_.\n\n    `Body MRI artifacts in clinical practice: A physicist's and radiologist's\n    perspective <https://doi.org/10.1002/jmri.24288>`_.\n\n    Args:\n        keys: \"image\", \"label\", or [\"image\", \"label\"] depending\n             on which data you need to transform.\n        prob: probability to add spike artifact to each item in the\n            dictionary provided it is realized that the noise will be applied\n            to the dictionary.\n        intensity_range: pass a tuple (a, b) to sample the log-intensity from the interval (a, b)\n            uniformly for all channels. Or pass sequence of intervals\n            ((a0, b0), (a1, b1), ...) to sample for each respective channel.\n            In the second case, the number of 2-tuples must match the number of channels.\n            Default ranges is `(0.95x, 1.10x)` where `x` is the mean\n            log-intensity for each channel.\n        channel_wise: treat each channel independently. 
True by default.\n        allow_missing_keys: do not raise exception if key is missing.\n\n    Example:\n        To apply `k`-space spikes randomly on the image only, with probability\n        0.5, and log-intensity sampled from the interval [13, 15] for each\n        channel independently, one uses\n        ``RandKSpaceSpikeNoised(\"image\", prob=0.5, intensity_ranges=(13, 15), channel_wise=True)``.\n    \"\"\"\n\n    backend = RandKSpaceSpikeNoise.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        prob: float = 0.1,\n        intensity_range: Sequence[Sequence[float] | float] | None = None,\n        channel_wise: bool = True,\n        allow_missing_keys: bool = False,\n    ):\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob=prob)\n        self.rand_noise = RandKSpaceSpikeNoise(prob=1.0, intensity_range=intensity_range, channel_wise=channel_wise)\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandKSpaceSpikeNoised:\n        super().set_random_state(seed, state)\n        self.rand_noise.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: dict[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        for key in self.key_iterator(d):\n            d[key] = self.rand_noise(d[key], randomize=True)\n        return d\n\n\nclass RandCoarseDropoutd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.RandCoarseDropout`.\n    Expect all the data specified by `keys` have same spatial shape and will randomly dropout the same regions\n    for every key, if want to 
dropout differently for every key, please use this transform separately.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to\n            randomly select the expected number of regions.\n        spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg\n            as the minimum spatial size to randomly select size for every region.\n            if some components of the `spatial_size` are non-positive values, the transform will use the\n            corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        dropout_holes: if `True`, dropout the regions of holes and fill value, if `False`, keep the holes and\n            dropout the outside and fill value. default to `True`.\n        fill_value: target value to fill the dropout regions, if providing a number, will use it as constant\n            value to fill all the regions. if providing a tuple for the `min` and `max`, will randomly select\n            value for every pixel / voxel from the range `[min, max)`. if None, will compute the `min` and `max`\n            value of input image then randomly select value to fill, default to None.\n        max_holes: if not None, define the maximum number to randomly select the expected number of regions.\n        max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.\n            if some components of the `max_spatial_size` are non-positive values, the transform will use the\n            corresponding components of input img size. 
For example, `max_spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        prob: probability of applying the transform.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = RandCoarseDropout.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        holes: int,\n        spatial_size: Sequence[int] | int,\n        dropout_holes: bool = True,\n        fill_value: tuple[float, float] | float | None = None,\n        max_holes: int | None = None,\n        max_spatial_size: Sequence[int] | int | None = None,\n        prob: float = 0.1,\n        allow_missing_keys: bool = False,\n    ):\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob=prob)\n        self.dropper = RandCoarseDropout(\n            holes=holes,\n            spatial_size=spatial_size,\n            dropout_holes=dropout_holes,\n            fill_value=fill_value,\n            max_holes=max_holes,\n            max_spatial_size=max_spatial_size,\n            prob=1.0,\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandCoarseDropoutd:\n        super().set_random_state(seed, state)\n        self.dropper.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # expect all the specified keys have same spatial shape and share same random holes\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            for key in self.key_iterator(d):\n                
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        self.dropper.randomize(d[first_key].shape[1:])\n        for key in self.key_iterator(d):\n            d[key] = self.dropper(img=d[key], randomize=False)\n\n        return d\n\n\nclass RandCoarseShuffled(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.RandCoarseShuffle`.\n    Expect all the data specified by `keys` have same spatial shape and will randomly dropout the same regions\n    for every key, if want to shuffle different regions for every key, please use this transform separately.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to\n            randomly select the expected number of regions.\n        spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg\n            as the minimum spatial size to randomly select size for every region.\n            if some components of the `spatial_size` are non-positive values, the transform will use the\n            corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        max_holes: if not None, define the maximum number to randomly select the expected number of regions.\n        max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.\n            if some components of the `max_spatial_size` are non-positive values, the transform will use the\n            corresponding components of input img size. 
For example, `max_spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        prob: probability of applying the transform.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = RandCoarseShuffle.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        holes: int,\n        spatial_size: Sequence[int] | int,\n        max_holes: int | None = None,\n        max_spatial_size: Sequence[int] | int | None = None,\n        prob: float = 0.1,\n        allow_missing_keys: bool = False,\n    ):\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob=prob)\n        self.shuffle = RandCoarseShuffle(\n            holes=holes, spatial_size=spatial_size, max_holes=max_holes, max_spatial_size=max_spatial_size, prob=1.0\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandCoarseShuffled:\n        super().set_random_state(seed, state)\n        self.shuffle.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        # expect all the specified keys have same spatial shape and share same random holes\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        self.shuffle.randomize(d[first_key].shape[1:])\n        for key in self.key_iterator(d):\n            d[key] = self.shuffle(img=d[key], 
randomize=False)\n\n        return d\n\n\nclass HistogramNormalized(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.HistogramNormalize`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        num_bins: number of the bins to use in histogram, default to `256`. for more details:\n            https://numpy.org/doc/stable/reference/generated/numpy.histogram.html.\n        min: the min value to normalize input image, default to `0`.\n        max: the max value to normalize input image, default to `255`.\n        mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`.\n            only points at which `mask==True` are used for the equalization.\n            can also provide the mask by `mask_key` at runtime.\n        mask_key: if mask is None, will try to get the mask with `mask_key`.\n        dtype: output data type, if None, same as input image. 
defaults to float32.\n        allow_missing_keys: do not raise exception if key is missing.\n\n    \"\"\"\n\n    backend = HistogramNormalize.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        num_bins: int = 256,\n        min: int = 0,\n        max: int = 255,\n        mask: NdarrayOrTensor | None = None,\n        mask_key: str | None = None,\n        dtype: DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.transform = HistogramNormalize(num_bins=num_bins, min=min, max=max, mask=mask, dtype=dtype)\n        self.mask_key = mask_key if mask is None else None\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.transform(d[key], d[self.mask_key]) if self.mask_key is not None else self.transform(d[key])\n\n        return d\n\n\nclass ForegroundMaskd(MapTransform):\n    \"\"\"\n    Creates a binary mask that defines the foreground based on thresholds in RGB or HSV color space.\n    This transform receives an RGB (or grayscale) image where by default it is assumed that the foreground has\n    low values (dark) while the background is white.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        threshold: an int or a float number that defines the threshold that values less than that are foreground.\n            It also can be a callable that receives each dimension of the image and calculate the threshold,\n            or a string that defines such callable from `skimage.filter.threshold_...`. 
For the list of available\n            threshold functions, please refer to https://scikit-image.org/docs/stable/api/skimage.filters.html\n            Moreover, a dictionary can be passed that defines such thresholds for each channel, like\n            {\"R\": 100, \"G\": \"otsu\", \"B\": skimage.filter.threshold_mean}\n        hsv_threshold: similar to threshold but HSV color space (\"H\", \"S\", and \"V\").\n            Unlike RBG, in HSV, value greater than `hsv_threshold` are considered foreground.\n        invert: invert the intensity range of the input image, so that the dtype maximum is now the dtype minimum,\n            and vice-versa.\n        new_key_prefix: this prefix be prepended to the key to create a new key for the output and keep the value of\n            key intact. By default not prefix is set and the corresponding array to the key will be replaced.\n        allow_missing_keys: do not raise exception if key is missing.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        threshold: dict | Callable | str | float = \"otsu\",\n        hsv_threshold: dict | Callable | str | float | int | None = None,\n        invert: bool = False,\n        new_key_prefix: str | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.transform = ForegroundMask(threshold=threshold, hsv_threshold=hsv_threshold, invert=invert)\n        self.new_key_prefix = new_key_prefix\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            new_key = key if self.new_key_prefix is None else self.new_key_prefix + key\n            d[new_key] = self.transform(d[key])\n\n        return d\n\n\nclass ComputeHoVerMapsd(MapTransform):\n    \"\"\"Compute horizontal and vertical maps from an instance mask\n    It generates normalized horizontal and 
vertical distances to the center of mass of each region.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        dtype: the type of output Tensor. Defaults to `\"float32\"`.\n        new_key_prefix: this prefix be prepended to the key to create a new key for the output and keep the value of\n            key intact. Defaults to '\"_hover\", so if the input key is \"mask\" the output will be \"hover_mask\".\n        allow_missing_keys: do not raise exception if key is missing.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        dtype: DtypeLike = \"float32\",\n        new_key_prefix: str = \"hover_\",\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.transform = ComputeHoVerMaps(dtype=dtype)\n        self.new_key_prefix = new_key_prefix\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            new_key = key if self.new_key_prefix is None else self.new_key_prefix + key\n            d[new_key] = self.transform(d[key])\n\n        return d\n\n\nRandGaussianNoiseD = RandGaussianNoiseDict = RandGaussianNoised\nRandRicianNoiseD = RandRicianNoiseDict = RandRicianNoised\nShiftIntensityD = ShiftIntensityDict = ShiftIntensityd\nRandShiftIntensityD = RandShiftIntensityDict = RandShiftIntensityd\nStdShiftIntensityD = StdShiftIntensityDict = StdShiftIntensityd\nRandStdShiftIntensityD = RandStdShiftIntensityDict = RandStdShiftIntensityd\nRandBiasFieldD = RandBiasFieldDict = RandBiasFieldd\nScaleIntensityD = ScaleIntensityDict = ScaleIntensityd\nRandScaleIntensityD = RandScaleIntensityDict = RandScaleIntensityd\nRandScaleIntensityFixedMeanD = RandScaleIntensityFixedMeanDict = RandScaleIntensityFixedMeand\nNormalizeIntensityD = NormalizeIntensityDict = NormalizeIntensityd\nThresholdIntensityD = ThresholdIntensityDict = 
ThresholdIntensityd\nScaleIntensityRangeD = ScaleIntensityRangeDict = ScaleIntensityRanged\nClipIntensityPercentilesD = ClipIntensityPercentilesDict = ClipIntensityPercentilesd\nAdjustContrastD = AdjustContrastDict = AdjustContrastd\nRandAdjustContrastD = RandAdjustContrastDict = RandAdjustContrastd\nScaleIntensityRangePercentilesD = ScaleIntensityRangePercentilesDict = ScaleIntensityRangePercentilesd\nMaskIntensityD = MaskIntensityDict = MaskIntensityd\nSavitzkyGolaySmoothD = SavitzkyGolaySmoothDict = SavitzkyGolaySmoothd\nMedianSmoothD = MedianSmoothDict = MedianSmoothd\nGaussianSmoothD = GaussianSmoothDict = GaussianSmoothd\nRandGaussianSmoothD = RandGaussianSmoothDict = RandGaussianSmoothd\nGaussianSharpenD = GaussianSharpenDict = GaussianSharpend\nRandGaussianSharpenD = RandGaussianSharpenDict = RandGaussianSharpend\nRandHistogramShiftD = RandHistogramShiftDict = RandHistogramShiftd\nRandGibbsNoiseD = RandGibbsNoiseDict = RandGibbsNoised\nGibbsNoiseD = GibbsNoiseDict = GibbsNoised\nKSpaceSpikeNoiseD = KSpaceSpikeNoiseDict = KSpaceSpikeNoised\nRandKSpaceSpikeNoiseD = RandKSpaceSpikeNoiseDict = RandKSpaceSpikeNoised\nRandCoarseDropoutD = RandCoarseDropoutDict = RandCoarseDropoutd\nHistogramNormalizeD = HistogramNormalizeDict = HistogramNormalized\nRandCoarseShuffleD = RandCoarseShuffleDict = RandCoarseShuffled\nForegroundMaskD = ForegroundMaskDict = ForegroundMaskd\nComputeHoVerMapsD = ComputeHoVerMapsDict = ComputeHoVerMapsd\n"
  },
  {
    "path": "monai/transforms/inverse.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport threading\nimport warnings\nfrom collections.abc import Hashable, Mapping\nfrom contextlib import contextmanager\nfrom typing import Any\n\nimport torch\n\nfrom monai import transforms\nfrom monai.data.meta_obj import MetaObj, get_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import affine_to_spacing, to_affine_nd\nfrom monai.transforms.traits import InvertibleTrait\nfrom monai.transforms.transform import Transform\nfrom monai.utils import (\n    LazyAttr,\n    MetaKeys,\n    TraceKeys,\n    TraceStatusKeys,\n    convert_to_dst_type,\n    convert_to_numpy,\n    convert_to_tensor,\n)\nfrom monai.utils.misc import MONAIEnvVars\n\n__all__ = [\"TraceableTransform\", \"InvertibleTransform\"]\n\n\nclass TraceableTransform(Transform):\n    \"\"\"\n    Maintains a stack of applied transforms to data.\n\n    Data can be one of two types:\n        1. A `MetaTensor` (this is the preferred data type).\n        2. A dictionary of data containing arrays/tensors and auxiliary metadata. In\n            this case, a key must be supplied (this dictionary-based approach is deprecated).\n\n    If `data` is of type `MetaTensor`, then the applied transform will be added to ``data.applied_operations``.\n\n    If `data` is a dictionary, then one of two things can happen:\n        1. 
If data[key] is a `MetaTensor`, the applied transform will be added to ``data[key].applied_operations``.\n        2. Else, the applied transform will be appended to an adjacent list using\n            `trace_key`. If, for example, the key is `image`, then the transform\n            will be appended to `image_transforms` (this dictionary-based approach is deprecated).\n\n    Hopefully it is clear that there are three total possibilities:\n        1. data is `MetaTensor`\n        2. data is dictionary, data[key] is `MetaTensor`\n        3. data is dictionary, data[key] is not `MetaTensor` (this is a deprecated approach).\n\n    The ``__call__`` method of this transform class must be implemented so\n    that the transformation information is stored during the data transformation.\n\n    The information in the stack of applied transforms must be compatible with the\n    default collate, by only storing strings, numbers and arrays.\n\n    `tracing` could be enabled by assigning to `self.tracing` or setting\n    `MONAI_TRACE_TRANSFORM` when initializing the class.\n    \"\"\"\n\n    def _init_trace_threadlocal(self):\n        \"\"\"Create a `_tracing` instance member to store the thread-local tracing state value.\"\"\"\n        # needed since this class is meant to be a trait with no constructor\n        if not hasattr(self, \"_tracing\"):\n            self._tracing = threading.local()\n\n        # This is True while the above initialising _tracing is False when this is\n        # called from a different thread than the one initialising _tracing.\n        if not hasattr(self._tracing, \"value\"):\n            self._tracing.value = MONAIEnvVars.trace_transform() != \"0\"\n\n    def __getstate__(self):\n        \"\"\"When pickling, remove the `_tracing` member from the output, if present, since it's not picklable.\"\"\"\n        _dict = dict(getattr(self, \"__dict__\", {}))  # this makes __dict__ always present in the unpickled object\n        _slots = {k: getattr(self, k) 
for k in getattr(self, \"__slots__\", [])}\n        _dict.pop(\"_tracing\", None)  # remove tracing\n        return _dict if len(_slots) == 0 else (_dict, _slots)\n\n    @property\n    def tracing(self) -> bool:\n        \"\"\"\n        Returns the tracing state, which is thread-local and initialised to `MONAIEnvVars.trace_transform() != \"0\"`.\n        \"\"\"\n        self._init_trace_threadlocal()\n        return bool(self._tracing.value)\n\n    @tracing.setter\n    def tracing(self, val: bool):\n        \"\"\"Sets the thread-local tracing state to `val`.\"\"\"\n        self._init_trace_threadlocal()\n        self._tracing.value = val\n\n    @staticmethod\n    def trace_key(key: Hashable = None):\n        \"\"\"The key to store the stack of applied transforms.\"\"\"\n        if key is None:\n            return f\"{TraceKeys.KEY_SUFFIX}\"\n        return f\"{key}{TraceKeys.KEY_SUFFIX}\"\n\n    @staticmethod\n    def transform_info_keys():\n        \"\"\"The keys to store necessary info of an applied transform.\"\"\"\n        return (TraceKeys.CLASS_NAME, TraceKeys.ID, TraceKeys.TRACING, TraceKeys.DO_TRANSFORM)\n\n    def get_transform_info(self) -> dict:\n        \"\"\"\n        Return a dictionary with the relevant information pertaining to an applied transform.\n        \"\"\"\n        vals = (\n            self.__class__.__name__,\n            id(self),\n            self.tracing,\n            self._do_transform if hasattr(self, \"_do_transform\") else True,\n        )\n        return dict(zip(self.transform_info_keys(), vals))\n\n    def push_transform(self, data, *args, **kwargs):\n        \"\"\"\n        Push to a stack of applied transforms of ``data``.\n\n        Args:\n            data: dictionary of data or `MetaTensor`.\n            args: additional positional arguments to track_transform_meta.\n            kwargs: additional keyword arguments to track_transform_meta,\n                set ``replace=True`` (default False) to rewrite the last transform 
info in\n                applied_operation/pending_operation based on ``self.get_transform_info()``.\n        \"\"\"\n        lazy_eval = kwargs.get(\"lazy\", False)\n        transform_info = self.get_transform_info()\n        do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, True)\n        kwargs = kwargs or {}\n        replace = kwargs.pop(\"replace\", False)  # whether to rewrite the most recently pushed transform info\n        if replace and get_track_meta() and isinstance(data, MetaTensor):\n            if not lazy_eval:\n                xform = self.pop_transform(data, check=False) if do_transform else {}\n                meta_obj = self.push_transform(data, orig_size=xform.get(TraceKeys.ORIG_SIZE), extra_info=xform)\n                return data.copy_meta_from(meta_obj)\n            if do_transform:\n                xform = data.pending_operations.pop()\n                extra = xform.copy()\n                xform.update(transform_info)\n            else:  # lazy, replace=True, do_transform=False\n                xform, extra = transform_info, {}\n            meta_obj = self.push_transform(data, transform_info=xform, lazy=True, extra_info=extra)\n            return data.copy_meta_from(meta_obj)\n        kwargs[\"lazy\"] = lazy_eval\n        if \"transform_info\" in kwargs and isinstance(kwargs[\"transform_info\"], dict):\n            kwargs[\"transform_info\"].update(transform_info)\n        else:\n            kwargs[\"transform_info\"] = transform_info\n        meta_obj = TraceableTransform.track_transform_meta(data, *args, **kwargs)\n        return data.copy_meta_from(meta_obj) if isinstance(data, MetaTensor) else data\n\n    @classmethod\n    def track_transform_meta(\n        cls,\n        data,\n        key: Hashable = None,\n        sp_size=None,\n        affine=None,\n        extra_info: dict | None = None,\n        orig_size: tuple | None = None,\n        transform_info=None,\n        lazy=False,\n    ):\n        \"\"\"\n        Update a stack 
of applied/pending transforms metadata of ``data``.\n\n        Args:\n            data: dictionary of data or `MetaTensor`.\n            key: if data is a dictionary, data[key] will be modified.\n            sp_size: the expected output spatial size when the transform is applied.\n                it can be tensor or numpy, but will be converted to a list of integers.\n            affine: the affine representation of the (spatial) transform in the image space.\n                When the transform is applied, meta_tensor.affine will be updated to ``meta_tensor.affine @ affine``.\n            extra_info: if desired, any extra information pertaining to the applied\n                transform can be stored in this dictionary. These are often needed for\n                computing the inverse transformation.\n            orig_size: sometimes during the inverse it is useful to know what the size\n                of the original image was, in which case it can be supplied here.\n            transform_info: info from self.get_transform_info().\n            lazy: whether to push the transform to pending_operations or applied_operations.\n\n        Returns:\n\n            For backward compatibility, if ``data`` is a dictionary, it returns the dictionary with\n            updated ``data[key]``. 
Otherwise, this function returns a MetaObj with updated transform metadata.\n        \"\"\"\n        data_t = data[key] if key is not None else data  # compatible with the dict data representation\n        out_obj = MetaObj()\n        # after deprecating metadict, we should always convert data_t to metatensor here\n        if isinstance(data_t, MetaTensor):\n            out_obj.copy_meta_from(data_t, keys=out_obj.__dict__.keys())\n\n        if lazy and (not get_track_meta()):\n            warnings.warn(\"metadata is not tracked, please call 'set_track_meta(True)' if doing lazy evaluation.\")\n\n        if not lazy and affine is not None and isinstance(data_t, MetaTensor):\n            # not lazy evaluation, directly update the metatensor affine (don't push to the stack)\n            orig_affine = data_t.peek_pending_affine()\n            orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0]\n            try:\n                affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64)\n            except RuntimeError as e:\n                if orig_affine.ndim > 2:\n                    if data_t.is_batch:\n                        msg = \"Transform applied to batched tensor, should be applied to instances only\"\n                    else:\n                        msg = \"Mismatch affine matrix, ensured that the batch dimension is not included in the calculation.\"\n                    raise RuntimeError(msg) from e\n                else:\n                    raise\n            out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device(\"cpu\"), dtype=torch.float64)\n            if MetaKeys.PIXDIM in out_obj.meta:\n                spacing = affine_to_spacing(out_obj.meta[MetaKeys.AFFINE])\n                out_obj.meta[MetaKeys.PIXDIM][1 : 1 + len(spacing)] = spacing\n\n        if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):\n            if isinstance(data, 
Mapping):\n                if not isinstance(data, dict):\n                    data = dict(data)\n                data[key] = data_t.copy_meta_from(out_obj) if isinstance(data_t, MetaTensor) else data_t\n                return data\n            return out_obj  # return with data_t as tensor if get_track_meta() is False\n\n        info = transform_info.copy()\n        # track the current spatial shape\n        if orig_size is not None:\n            info[TraceKeys.ORIG_SIZE] = orig_size\n        elif isinstance(data_t, MetaTensor):\n            info[TraceKeys.ORIG_SIZE] = data_t.peek_pending_shape()\n        elif hasattr(data_t, \"shape\"):\n            info[TraceKeys.ORIG_SIZE] = data_t.shape[1:]\n\n        # add lazy status to the transform info\n        info[TraceKeys.LAZY] = lazy\n\n        # include extra_info\n        if extra_info is not None:\n            extra_info.pop(LazyAttr.SHAPE, None)\n            extra_info.pop(LazyAttr.AFFINE, None)\n            info[TraceKeys.EXTRA_INFO] = extra_info\n\n        # push the transform info to the applied_operation or pending_operation stack\n        if lazy:\n            if sp_size is None:\n                if LazyAttr.SHAPE not in info:\n                    info[LazyAttr.SHAPE] = info.get(TraceKeys.ORIG_SIZE, [])\n            else:\n                info[LazyAttr.SHAPE] = sp_size\n            info[LazyAttr.SHAPE] = tuple(convert_to_numpy(info[LazyAttr.SHAPE], wrap_sequence=True).tolist())\n            if affine is None:\n                if LazyAttr.AFFINE not in info:\n                    info[LazyAttr.AFFINE] = MetaTensor.get_default_affine()\n            else:\n                info[LazyAttr.AFFINE] = affine\n            info[LazyAttr.AFFINE] = convert_to_tensor(info[LazyAttr.AFFINE], device=torch.device(\"cpu\"))\n            out_obj.push_pending_operation(info)\n        else:\n            if out_obj.pending_operations:\n                transform_name = info.get(TraceKeys.CLASS_NAME, \"\") if isinstance(info, dict) 
else \"\"\n                msg = (\n                    f\"Transform {transform_name} has been applied to a MetaTensor with pending operations: \"\n                    f\"{[x.get(TraceKeys.CLASS_NAME) for x in out_obj.pending_operations]}\"\n                )\n                if key is not None:\n                    msg += f\" for key {key}\"\n\n                pend = out_obj.pending_operations[-1]\n                statuses = pend.get(TraceKeys.STATUSES, dict())\n                messages = statuses.get(TraceStatusKeys.PENDING_DURING_APPLY, list())\n                messages.append(msg)\n                statuses[TraceStatusKeys.PENDING_DURING_APPLY] = messages\n                info[TraceKeys.STATUSES] = statuses\n            out_obj.push_applied_operation(info)\n        if isinstance(data, Mapping):\n            if not isinstance(data, dict):\n                data = dict(data)\n            if isinstance(data_t, MetaTensor):\n                data[key] = data_t.copy_meta_from(out_obj)\n            else:\n                x_k = TraceableTransform.trace_key(key)\n                if x_k not in data:\n                    data[x_k] = []  # If this is the first, create list\n                data[x_k].append(info)\n            return data\n        return out_obj\n\n    def check_transforms_match(self, transform: Mapping) -> None:\n        \"\"\"Check transforms are of same instance.\"\"\"\n        xform_id = transform.get(TraceKeys.ID, \"\")\n        if xform_id == id(self):\n            return\n        # TraceKeys.NONE to skip the id check\n        if xform_id == TraceKeys.NONE:\n            return\n        xform_name = transform.get(TraceKeys.CLASS_NAME, \"\")\n        warning_msg = transform.get(TraceKeys.EXTRA_INFO, {}).get(\"warn\")\n        if warning_msg:\n            warnings.warn(warning_msg)\n        # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)\n        if torch.multiprocessing.get_start_method() in (\"spawn\", None) and 
xform_name == self.__class__.__name__:\n            return\n        raise RuntimeError(\n            f\"Error {self.__class__.__name__} getting the most recently \"\n            f\"applied invertible transform {xform_name} {xform_id} != {id(self)}.\"\n        )\n\n    def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False):\n        \"\"\"\n        Get most recent matching transform for the current class from the sequence of applied operations.\n\n        Args:\n            data: dictionary of data or `MetaTensor`.\n            key: if data is a dictionary, data[key] will be modified.\n            check: if true, check that `self` is the same type as the most recently-applied transform.\n            pop: if true, remove the transform as it is returned.\n\n        Returns:\n            Dictionary of most recently applied transform\n\n        Raises:\n            - RuntimeError: data is neither `MetaTensor` nor dictionary\n        \"\"\"\n        if not self.tracing:\n            raise RuntimeError(\"Transform Tracing must be enabled to get the most recent transform.\")\n        if isinstance(data, MetaTensor):\n            all_transforms = data.applied_operations\n        elif isinstance(data, Mapping):\n            if key in data and isinstance(data[key], MetaTensor):\n                all_transforms = data[key].applied_operations\n            else:\n                all_transforms = data.get(self.trace_key(key), MetaTensor.get_default_applied_operations())\n        else:\n            raise ValueError(f\"`data` should be either `MetaTensor` or dictionary, got {type(data)}.\")\n\n        if not all_transforms:\n            raise ValueError(f\"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'\")\n\n        if check:\n            self.check_transforms_match(all_transforms[-1])\n\n        return all_transforms.pop(-1) if pop else all_transforms[-1]\n\n    def pop_transform(self, data, key: 
Hashable = None, check: bool = True):\n        \"\"\"\n        Return and pop the most recent transform.\n\n        Args:\n            data: dictionary of data or `MetaTensor`\n            key: if data is a dictionary, data[key] will be modified\n            check: if true, check that `self` is the same type as the most recently-applied transform.\n\n        Returns:\n            Dictionary of most recently applied transform\n\n        Raises:\n            - RuntimeError: data is neither `MetaTensor` nor dictionary\n        \"\"\"\n        return self.get_most_recent_transform(data, key, check, pop=True)\n\n    @contextmanager\n    def trace_transform(self, to_trace: bool):\n        \"\"\"Temporarily set the tracing status of a transform with a context manager.\"\"\"\n        prev = self.tracing\n        self.tracing = to_trace\n        yield\n        self.tracing = prev\n\n\nclass InvertibleTransform(TraceableTransform, InvertibleTrait):\n    \"\"\"Classes for invertible transforms.\n\n    This class exists so that an ``invert`` method can be implemented. This allows, for\n    example, images to be cropped, rotated, padded, etc., during training and inference,\n    and after be returned to their original size before saving to file for comparison in\n    an external viewer.\n\n    When the ``inverse`` method is called:\n\n        - the inverse is called on each key individually, which allows for\n          different parameters being passed to each label (e.g., different\n          interpolation for image and label).\n\n        - the inverse transforms are applied in a last-in-first-out order. As\n          the inverse is applied, its entry is removed from the list detailing\n          the applied transformations. 
That is to say that during the forward\n          pass, the list of applied transforms grows, and then during the\n          inverse it shrinks back down to an empty list.\n\n    We currently check that the ``id()`` of the transform is the same in the forward and\n    inverse directions. This is a useful check to ensure that the inverses are being\n    processed in the correct order.\n\n    Note to developers: When converting a transform to an invertible transform, you need to:\n\n        #. Inherit from this class.\n        #. In ``__call__``, add a call to ``push_transform``.\n        #. Any extra information that might be needed for the inverse can be included with the\n           dictionary ``extra_info``. This dictionary should have the same keys regardless of\n           whether ``do_transform`` was `True` or `False` and can only contain objects that are\n           accepted in pytorch data loader's collate function (e.g., `None` is not allowed).\n        #. Implement an ``inverse`` method. 
Make sure that after performing the inverse,\n           ``pop_transform`` is called.\n\n    \"\"\"\n\n    def inverse_update(self, data):\n        \"\"\"\n        This function is to be called before every `self.inverse(data)`,\n        update each MetaTensor `data[key]` using `data[key_transforms]` and `data[key_meta_dict]`,\n        for MetaTensor backward compatibility 0.9.0.\n        \"\"\"\n        if not isinstance(data, dict) or not isinstance(self, transforms.MapTransform):\n            return data\n        d = dict(data)\n        for k in self.key_iterator(data):\n            transform_key = transforms.TraceableTransform.trace_key(k)\n            if transform_key not in data or not data[transform_key]:\n                continue\n            d = transforms.sync_meta_info(k, data, t=False)\n        return d\n\n    def inverse(self, data: Any) -> Any:\n        \"\"\"\n        Inverse of ``__call__``.\n\n        Raises:\n            NotImplementedError: When the subclass does not override this method.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n"
  },
  {
    "path": "monai/transforms/inverse_batch_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable, Sequence\nfrom typing import Any\n\nfrom torch.utils.data import Dataset\nfrom torch.utils.data.dataloader import DataLoader as TorchDataLoader\n\nfrom monai.config import KeysCollection\nfrom monai.data.dataloader import DataLoader\nfrom monai.data.utils import decollate_batch, no_collation, pad_list_data_collate\nfrom monai.transforms.croppad.batch import PadListDataCollate\nfrom monai.transforms.inverse import InvertibleTransform\nfrom monai.transforms.transform import MapTransform, Transform\nfrom monai.utils import first\n\n__all__ = [\"BatchInverseTransform\", \"Decollated\", \"DecollateD\", \"DecollateDict\"]\n\n\nclass _BatchInverseDataset(Dataset):\n\n    def __init__(self, data: Sequence[Any], transform: InvertibleTransform, pad_collation_used: bool) -> None:\n        self.data = data\n        self.invertible_transform = transform\n        self.pad_collation_used = pad_collation_used\n\n    def __getitem__(self, index: int):\n        data = dict(self.data[index])\n        # If pad collation was used, then we need to undo this first\n        if self.pad_collation_used:\n            data = PadListDataCollate.inverse(data)\n\n        if not isinstance(self.invertible_transform, InvertibleTransform):\n            warnings.warn(\"transform is not invertible, can't invert 
transform for the input data.\")\n            return data\n        return self.invertible_transform.inverse(data)\n\n    def __len__(self) -> int:\n        return len(self.data)\n\n\nclass BatchInverseTransform(Transform):\n    \"\"\"\n    Perform inverse on a batch of data. This is useful if you have inferred a batch of images and want to invert\n    them all.\n    \"\"\"\n\n    def __init__(\n        self,\n        transform: InvertibleTransform,\n        loader: TorchDataLoader,\n        collate_fn: Callable | None = no_collation,\n        num_workers: int | None = 0,\n        detach: bool = True,\n        pad_batch: bool = True,\n        fill_value=None,\n    ) -> None:\n        \"\"\"\n        Args:\n            transform: a callable data transform on input data.\n            loader: data loader used to run `transforms` and generate the batch of data.\n            collate_fn: how to collate data after inverse transformations.\n                default won't do any collation, so the output will be a list of size batch size.\n            num_workers: number of workers when run data loader for inverse transforms,\n                default to 0 as only run 1 iteration and multi-processing may be even slower.\n                if the transforms are really slow, set num_workers for multi-processing.\n                if set to `None`, use the `num_workers` of the transform data loader.\n            detach: whether to detach the tensors. 
Scalar tensors will be detached into number types\n                instead of torch tensors.\n            pad_batch: when the items in a batch indicate different batch size,\n                whether to pad all the sequences to the longest.\n                If False, the batch size will be the length of the shortest sequence.\n            fill_value: the value to fill the padded sequences when `pad_batch=True`.\n\n        \"\"\"\n        self.transform = transform\n        self.batch_size = loader.batch_size\n        self.num_workers = loader.num_workers if num_workers is None else num_workers\n        self.collate_fn = collate_fn\n        self.detach = detach\n        self.pad_batch = pad_batch\n        self.fill_value = fill_value\n        self.pad_collation_used = loader.collate_fn.__doc__ == pad_list_data_collate.__doc__ or isinstance(\n            loader.collate_fn, PadListDataCollate\n        )\n\n    def __call__(self, data: dict[str, Any]) -> Any:\n        decollated_data = decollate_batch(data, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value)\n        inv_ds = _BatchInverseDataset(decollated_data, self.transform, self.pad_collation_used)\n        inv_loader = DataLoader(\n            inv_ds, batch_size=self.batch_size, num_workers=self.num_workers, collate_fn=self.collate_fn\n        )\n        try:\n            return first(inv_loader)\n        except RuntimeError as re:\n            re_str = str(re)\n            if \"equal size\" in re_str:\n                re_str += \"\\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`.\"\n            raise RuntimeError(re_str) from re\n\n\nclass Decollated(MapTransform):\n    \"\"\"\n    Decollate a batch of data. 
If input is a dictionary, it also supports decollating only the specified keys.\n    Note that unlike most MapTransforms, it will delete the other keys that are not specified.\n    if `keys=None`, it will decollate all the data in the input.\n    It replicates the scalar values to every item of the decollated list.\n\n    Args:\n        keys: keys of the corresponding items to decollate, note that it will delete other keys not specified.\n            if None, will decollate all the keys. see also: :py:class:`monai.transforms.compose.MapTransform`.\n        detach: whether to detach the tensors. Scalar tensors will be detached into number types\n            instead of torch tensors.\n        pad_batch: when the items in a batch indicate different batch size,\n            whether to pad all the sequences to the longest.\n            If False, the batch size will be the length of the shortest sequence.\n        fill_value: the value to fill the padded sequences when `pad_batch=True`.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection | None = None,\n        detach: bool = True,\n        pad_batch: bool = True,\n        fill_value=None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.detach = detach\n        self.pad_batch = pad_batch\n        self.fill_value = fill_value\n\n    def __call__(self, data: dict | list):\n        d: dict | list\n        if len(self.keys) == 1 and self.keys[0] is None:\n            # it doesn't support `None` as the key\n            d = data\n        else:\n            if not isinstance(data, dict):\n                raise TypeError(\"input data is not a dictionary, but specified keys to decollate.\")\n            d = {}\n            for key in self.key_iterator(data):\n                d[key] = data[key]\n\n        return decollate_batch(d, detach=self.detach, 
pad=self.pad_batch, fill_value=self.fill_value)\n\n\nDecollateD = DecollateDict = Decollated\n"
  },
  {
    "path": "monai/transforms/io/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/transforms/io/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of \"vanilla\" transforms for IO functions.\n\"\"\"\nfrom __future__ import annotations\n\nimport inspect\nimport json\nimport logging\nimport sys\nimport traceback\nimport warnings\nfrom collections.abc import Sequence\nfrom pathlib import Path\nfrom pydoc import locate\nfrom typing import Callable\n\nimport numpy as np\nimport torch\n\nfrom monai.config import DtypeLike, NdarrayOrTensor, PathLike\nfrom monai.data import image_writer\nfrom monai.data.folder_layout import FolderLayout, FolderLayoutBase, default_name_formatter\nfrom monai.data.image_reader import (\n    ImageReader,\n    ITKReader,\n    NibabelReader,\n    NrrdReader,\n    NumpyReader,\n    PILReader,\n    PydicomReader,\n)\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import is_no_channel\nfrom monai.transforms.transform import Transform\nfrom monai.transforms.utility.array import EnsureChannelFirst\nfrom monai.utils import (\n    GridSamplePadMode,\n    ImageMetaKey,\n    MetaKeys,\n    OptionalImportError,\n    convert_to_dst_type,\n    ensure_tuple,\n    look_up_option,\n    optional_import,\n)\n\nnib, _ = optional_import(\"nibabel\")\nImage, _ = optional_import(\"PIL.Image\")\nnrrd, _ = optional_import(\"nrrd\")\nFileLock, has_filelock = optional_import(\"filelock\", name=\"FileLock\")\n\n__all__ = [\"LoadImage\", \"SaveImage\", 
\"SUPPORTED_READERS\"]\n\nSUPPORTED_READERS = {\n    \"pydicomreader\": PydicomReader,\n    \"itkreader\": ITKReader,\n    \"nrrdreader\": NrrdReader,\n    \"numpyreader\": NumpyReader,\n    \"pilreader\": PILReader,\n    \"nibabelreader\": NibabelReader,\n}\n\n\ndef switch_endianness(data, new=\"<\"):\n    \"\"\"\n    Convert the input `data` endianness to `new`.\n\n    Args:\n        data: input to be converted.\n        new: the target endianness, currently support \"<\" or \">\".\n    \"\"\"\n    if isinstance(data, torch.Tensor):\n        device = data.device\n        requires_grad: bool = data.requires_grad\n        data = (\n            torch.from_numpy(switch_endianness(data.cpu().detach().numpy(), new))\n            .to(device)\n            .requires_grad_(requires_grad=requires_grad)  # type: ignore\n        )\n    elif isinstance(data, np.ndarray):\n        # default to system endian\n        sys_native = \"<\" if (sys.byteorder == \"little\") else \">\"\n        current_ = sys_native if data.dtype.byteorder not in (\"<\", \">\") else data.dtype.byteorder\n        if new not in (\"<\", \">\"):\n            raise NotImplementedError(f\"Not implemented option new={new}.\")\n        if current_ != new:\n            data = data.byteswap().view(data.dtype.newbyteorder(new))\n    elif isinstance(data, tuple):\n        data = tuple(switch_endianness(x, new) for x in data)\n    elif isinstance(data, list):\n        data = [switch_endianness(x, new) for x in data]\n    elif isinstance(data, dict):\n        data = {k: switch_endianness(v, new) for k, v in data.items()}\n    elif not isinstance(data, (bool, str, float, int, type(None))):\n        raise RuntimeError(f\"Unknown type: {type(data).__name__}\")\n    return data\n\n\nclass LoadImage(Transform):\n    \"\"\"\n    Load image file or files from provided path based on reader.\n    If reader is not specified, this class automatically chooses readers\n    based on the supported suffixes and in the following 
order:\n\n        - User-specified reader at runtime when calling this loader.\n        - User-specified reader in the constructor of `LoadImage`.\n        - Readers from the last to the first in the registered list.\n        - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader),\n          (npz, npy -> NumpyReader), (nrrd -> NrrdReader), (DICOM file -> ITKReader).\n\n    Please note that for png, jpg, bmp, and other 2D formats, readers by default swap axis 0 and 1 after\n    loading the array with ``reverse_indexing`` set to ``True`` because the spatial axes definition\n    for non-medical specific file formats is different from other common medical packages.\n\n    See also:\n\n        - tutorial: https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb\n\n    \"\"\"\n\n    def __init__(\n        self,\n        reader=None,\n        image_only: bool = True,\n        dtype: DtypeLike | None = np.float32,\n        ensure_channel_first: bool = False,\n        simple_keys: bool = False,\n        prune_meta_pattern: str | None = None,\n        prune_meta_sep: str = \".\",\n        expanduser: bool = True,\n        *args,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Args:\n            reader: reader to load image file and metadata\n                - if `reader` is None, a default set of `SUPPORTED_READERS` will be used.\n                - if `reader` is a string, it's treated as a class name or dotted path\n                (such as ``\"monai.data.ITKReader\"``), the supported built-in reader classes are\n                ``\"ITKReader\"``, ``\"NibabelReader\"``, ``\"NumpyReader\"``, ``\"PydicomReader\"``.\n                a reader instance will be constructed with the `*args` and `**kwargs` parameters.\n                - if `reader` is a reader class/instance, it will be registered to this loader accordingly.\n            image_only: if True return only the image MetaTensor, otherwise 
return image and header dict.\n            dtype: if not None convert the loaded image to this data type.\n            ensure_channel_first: if `True` and loaded both image array and metadata, automatically convert\n                the image array shape to `channel first`. default to `False`.\n            simple_keys: whether to remove redundant metadata keys, default to False for backward compatibility.\n            prune_meta_pattern: combined with `prune_meta_sep`, a regular expression used to match and prune keys\n                in the metadata (nested dictionary), default to None, no key deletion.\n            prune_meta_sep: combined with `prune_meta_pattern`, used to match and prune keys\n                in the metadata (nested dictionary). default is \".\", see also :py:class:`monai.transforms.DeleteItemsd`.\n                e.g. ``prune_meta_pattern=\".*_code$\", prune_meta_sep=\" \"`` removes meta keys that ends with ``\"_code\"``.\n            expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is.\n            args: additional parameters for reader if providing a reader name.\n            kwargs: additional parameters for reader if providing a reader name.\n\n        Note:\n\n            - The transform returns a MetaTensor, unless `set_track_meta(False)` has been used, in which case, a\n              `torch.Tensor` will be returned.\n            - If `reader` is specified, the loader will attempt to use the specified readers and the default supported\n              readers. 
This might introduce overheads when handling the exceptions of trying the incompatible loaders.\n              In this case, it is therefore recommended setting the most appropriate reader as\n              the last item of the `reader` parameter.\n\n        \"\"\"\n\n        self.auto_select = reader is None\n        self.image_only = image_only\n        self.dtype = dtype\n        self.ensure_channel_first = ensure_channel_first\n        self.simple_keys = simple_keys\n        self.pattern = prune_meta_pattern\n        self.sep = prune_meta_sep\n        self.expanduser = expanduser\n\n        self.readers: list[ImageReader] = []\n        for r in SUPPORTED_READERS:  # set predefined readers as default\n            try:\n                self.register(SUPPORTED_READERS[r](*args, **kwargs))\n            except OptionalImportError:\n                logging.getLogger(self.__class__.__name__).debug(\n                    f\"required package for reader {r} is not installed, or the version doesn't match requirement.\"\n                )\n            except TypeError:  # the reader doesn't have the corresponding args/kwargs\n                logging.getLogger(self.__class__.__name__).debug(\n                    f\"{r} is not supported with the given parameters {args} {kwargs}.\"\n                )\n                self.register(SUPPORTED_READERS[r]())\n        if reader is None:\n            return  # no user-specified reader, no need to register\n\n        for _r in ensure_tuple(reader):\n            if isinstance(_r, str):\n                the_reader, has_built_in = optional_import(\"monai.data\", name=f\"{_r}\")  # search built-in\n                if not has_built_in:\n                    the_reader = locate(f\"{_r}\")  # search dotted path\n                if the_reader is None:\n                    the_reader = look_up_option(_r.lower(), SUPPORTED_READERS)\n                try:\n                    self.register(the_reader(*args, **kwargs))\n                except 
OptionalImportError:\n                    warnings.warn(\n                        f\"required package for reader {_r} is not installed, or the version doesn't match requirement.\"\n                    )\n                except TypeError:  # the reader doesn't have the corresponding args/kwargs\n                    warnings.warn(f\"{_r} is not supported with the given parameters {args} {kwargs}.\")\n                    self.register(the_reader())\n            elif inspect.isclass(_r):\n                self.register(_r(*args, **kwargs))\n            else:\n                self.register(_r)  # reader instance, ignoring the constructor args/kwargs\n        return\n\n    def register(self, reader: ImageReader):\n        \"\"\"\n        Register image reader to load image file and metadata.\n\n        Args:\n            reader: reader instance to be registered with this loader.\n\n        \"\"\"\n        if not isinstance(reader, ImageReader):\n            warnings.warn(f\"Preferably the reader should inherit ImageReader, but got {type(reader)}.\")\n        self.readers.append(reader)\n\n    def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader | None = None):\n        \"\"\"\n        Load image file and metadata from the given filename(s).\n        If `reader` is not specified, this class automatically chooses readers based on the\n        reversed order of registered readers `self.readers`.\n\n        Args:\n            filename: path file or file-like object or a list of files.\n                will save the filename to meta_data with key `filename_or_obj`.\n                if provided a list of files, use the filename of first file to save,\n                and will stack them together as multi-channels data.\n                if provided directory path instead of file path, will treat it as\n                DICOM images series and read.\n            reader: runtime reader to load image file and metadata.\n\n        \"\"\"\n        filename = 
tuple(\n            f\"{Path(s).expanduser()}\" if self.expanduser else s for s in ensure_tuple(filename)  # allow Path objects\n        )\n        img, err = None, []\n        if reader is not None:\n            img = reader.read(filename)  # runtime specified reader\n        else:\n            for reader in self.readers[::-1]:\n                if self.auto_select:  # rely on the filename extension to choose the reader\n                    if reader.verify_suffix(filename):\n                        img = reader.read(filename)\n                        break\n                else:  # try the user designated readers\n                    try:\n                        img = reader.read(filename)\n                    except Exception as e:\n                        err.append(traceback.format_exc())\n                        logging.getLogger(self.__class__.__name__).debug(e, exc_info=True)\n                        logging.getLogger(self.__class__.__name__).info(\n                            f\"{reader.__class__.__name__}: unable to load {filename}.\\n\"\n                        )\n                    else:\n                        err = []\n                        break\n\n        if img is None or reader is None:\n            if isinstance(filename, Sequence) and len(filename) == 1:\n                filename = filename[0]\n            msg = \"\\n\".join([f\"{e}\" for e in err])\n            raise RuntimeError(\n                f\"{self.__class__.__name__} cannot find a suitable reader for file: {filename}.\\n\"\n                \"    Please install the reader libraries, see also the installation instructions:\\n\"\n                \"    https://monai.readthedocs.io/en/latest/installation.html#installing-the-recommended-dependencies.\\n\"\n                f\"   The current registered: {self.readers}.\\n{msg}\"\n            )\n        img_array: NdarrayOrTensor\n        img_array, meta_data = reader.get_data(img)\n        img_array = convert_to_dst_type(img_array, 
dst=img_array, dtype=self.dtype)[0]\n        if not isinstance(meta_data, dict):\n            raise ValueError(f\"`meta_data` must be a dict, got type {type(meta_data)}.\")\n        # make sure all elements in metadata are little endian\n        meta_data = switch_endianness(meta_data, \"<\")\n\n        # Path obj should be strings for data loader\n        meta_data[ImageMetaKey.FILENAME_OR_OBJ] = f\"{ensure_tuple(filename)[0]}\"\n        img = MetaTensor.ensure_torch_and_prune_meta(\n            img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep\n        )\n        if self.ensure_channel_first:\n            img = EnsureChannelFirst()(img)\n        if self.image_only:\n            return img\n        return img, img.meta if isinstance(img, MetaTensor) else meta_data\n\n\nclass SaveImage(Transform):\n    \"\"\"\n    Save the image (in the form of torch tensor or numpy ndarray) and metadata dictionary into files.\n\n    The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`,\n    where the `input_image_name` is extracted from the provided metadata dictionary.\n    If no metadata provided, a running index starting from 0 will be used as the filename prefix.\n\n    Args:\n        output_dir: output image directory.\n            Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.\n        output_postfix: a string appended to all output file names, default to `trans`.\n            Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.\n        output_ext: output file extension name.\n            Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.\n        output_dtype: data type (if not None) for saving data. 
Defaults to ``np.float32``.\n        resample: whether to resample image (if needed) before saving the data array,\n            based on the ``\"spatial_shape\"`` (and ``\"original_affine\"``) from metadata.\n        mode: This option is used when ``resample=True``. Defaults to ``\"nearest\"``.\n            Depending on the writers, the possible options are\n\n            - {``\"bilinear\"``, ``\"nearest\"``, ``\"bicubic\"``}.\n              See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample\n            - {``\"nearest\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}.\n              See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate\n\n        padding_mode: This option is used when ``resample = True``. Defaults to ``\"border\"``.\n            Possible options are {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n            See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample\n        scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling\n            [0, 255] (``uint8``) or [0, 65535] (``uint16``). Default is ``None`` (no scaling).\n        dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision.\n            if ``None``, use the data type of input data. To set the output data type, use ``output_dtype``.\n        squeeze_end_dims: if ``True``, any trailing singleton dimensions will be removed (after the channel\n            has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and\n            then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If ``False``,\n            image will always be saved as (H,W,D,C).\n        data_root_dir: if not empty, it specifies the beginning parts of the input file's\n            absolute path. 
It's used to compute ``input_file_rel_path``, the relative path to the file from\n            ``data_root_dir`` to preserve folder structure when saving in case there are files in different\n            folders with the same file names. For example, with the following inputs:\n\n            - input_file_name: ``/foo/bar/test1/image.nii``\n            - output_postfix: ``seg``\n            - output_ext: ``.nii.gz``\n            - output_dir: ``/output``\n            - data_root_dir: ``/foo/bar``\n\n            The output will be: ``/output/test1/image/image_seg.nii.gz``\n\n            Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.\n        separate_folder: whether to save every file in a separate folder. For example: for the input filename\n            ``image.nii``, postfix ``seg`` and ``folder_path`` ``output``, if ``separate_folder=True``, it will be\n            saved as: ``output/image/image_seg.nii``, if ``False``, saving as ``output/image_seg.nii``.\n            Default to ``True``.\n            Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.\n        print_log: whether to print logs when saving. Default to ``True``.\n        output_format: an optional string of filename extension to specify the output image writer.\n            see also: ``monai.data.image_writer.SUPPORTED_WRITERS``.\n        writer: a customised ``monai.data.ImageWriter`` subclass to save data arrays.\n            if ``None``, use the default writer from ``monai.data.image_writer`` according to ``output_ext``.\n            if it's a string, it's treated as a class name or dotted path (such as ``\"monai.data.ITKWriter\"``);\n            the supported built-in writer classes are ``\"NibabelWriter\"``, ``\"ITKWriter\"``, ``\"PILWriter\"``.\n        channel_dim: the index of the channel dimension. 
Default to ``0``.\n            ``None`` to indicate no channel dimension.\n        output_name_formatter: a callable function (returning a kwargs dict) to format the output file name.\n            If using a custom ``monai.data.FolderLayoutBase`` class in ``folder_layout``, consider providing\n            your own formatter.\n            see also: :py:func:`monai.data.folder_layout.default_name_formatter`.\n        folder_layout: A customized ``monai.data.FolderLayoutBase`` subclass to define file naming schemes.\n            if ``None``, uses the default ``FolderLayout``.\n        savepath_in_metadict: if ``True``, adds a key ``\"saved_to\"`` to the metadata, which contains the path\n            to where the input image has been saved.\n    \"\"\"\n\n    def __init__(\n        self,\n        output_dir: PathLike = \"./\",\n        output_postfix: str = \"trans\",\n        output_ext: str = \".nii.gz\",\n        output_dtype: DtypeLike | None = np.float32,\n        resample: bool = False,\n        mode: str = \"nearest\",\n        padding_mode: str = GridSamplePadMode.BORDER,\n        scale: int | None = None,\n        dtype: DtypeLike = np.float64,\n        squeeze_end_dims: bool = True,\n        data_root_dir: PathLike = \"\",\n        separate_folder: bool = True,\n        print_log: bool = True,\n        output_format: str = \"\",\n        writer: type[image_writer.ImageWriter] | str | None = None,\n        channel_dim: int | None = 0,\n        output_name_formatter: Callable[[dict, Transform], dict] | None = None,\n        folder_layout: FolderLayoutBase | None = None,\n        savepath_in_metadict: bool = False,\n    ) -> None:\n        self.folder_layout: FolderLayoutBase\n        if folder_layout is None:\n            self.folder_layout = FolderLayout(\n                output_dir=output_dir,\n                postfix=output_postfix,\n                extension=output_ext,\n                parent=separate_folder,\n                makedirs=True,\n               
 data_root_dir=data_root_dir,\n            )\n        else:\n            self.folder_layout = folder_layout\n\n        self.fname_formatter: Callable\n        if output_name_formatter is None:\n            self.fname_formatter = default_name_formatter\n        else:\n            self.fname_formatter = output_name_formatter\n\n        self.output_ext = output_ext.lower() or output_format.lower()\n        self.output_ext = (\n            f\".{self.output_ext}\" if self.output_ext and not self.output_ext.startswith(\".\") else self.output_ext\n        )\n        if isinstance(writer, str):\n            writer_, has_built_in = optional_import(\"monai.data\", name=f\"{writer}\")  # search built-in\n            if not has_built_in:\n                writer_ = locate(f\"{writer}\")  # search dotted path\n            if writer_ is None:\n                raise ValueError(f\"writer {writer} not found\")\n            writer = writer_\n        self.writers = image_writer.resolve_writer(self.output_ext) if writer is None else (writer,)\n        self.writer_obj = None\n\n        _output_dtype = output_dtype\n        if self.output_ext == \".png\" and _output_dtype not in (np.uint8, np.uint16, None):\n            _output_dtype = np.uint8\n        if self.output_ext == \".dcm\" and _output_dtype not in (np.uint8, np.uint16, None):\n            _output_dtype = np.uint8\n        self.init_kwargs = {\"output_dtype\": _output_dtype, \"scale\": scale}\n        self.data_kwargs = {\"squeeze_end_dims\": squeeze_end_dims, \"channel_dim\": channel_dim}\n        self.meta_kwargs = {\"resample\": resample, \"mode\": mode, \"padding_mode\": padding_mode, \"dtype\": dtype}\n        self.write_kwargs = {\"verbose\": print_log}\n        self._data_index = 0\n        self.savepath_in_metadict = savepath_in_metadict\n\n    def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):\n        \"\"\"\n        Set the options for the underlying writer by updating the 
`self.*_kwargs` dictionaries.\n\n        The arguments correspond to the following usage:\n\n            - `writer = ImageWriter(**init_kwargs)`\n            - `writer.set_data_array(array, **data_kwargs)`\n            - `writer.set_metadata(meta_data, **meta_kwargs)`\n            - `writer.write(filename, **write_kwargs)`\n\n        \"\"\"\n        if init_kwargs is not None:\n            self.init_kwargs.update(init_kwargs)\n        if data_kwargs is not None:\n            self.data_kwargs.update(data_kwargs)\n        if meta_kwargs is not None:\n            self.meta_kwargs.update(meta_kwargs)\n        if write_kwargs is not None:\n            self.write_kwargs.update(write_kwargs)\n        return self\n\n    def __call__(\n        self, img: torch.Tensor | np.ndarray, meta_data: dict | None = None, filename: str | PathLike | None = None\n    ):\n        \"\"\"\n        Args:\n            img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`.\n            meta_data: key-value pairs of metadata corresponding to the data.\n            filename: str or file-like object which to save img.\n                If specified, will ignore `self.output_name_formatter` and `self.folder_layout`.\n        \"\"\"\n        meta_data = img.meta if isinstance(img, MetaTensor) else meta_data\n        if filename is not None:\n            filename = f\"{filename}{self.output_ext}\"\n        else:\n            kw = self.fname_formatter(meta_data, self)\n            filename = self.folder_layout.filename(**kw)\n\n        if meta_data:\n            meta_spatial_shape = ensure_tuple(meta_data.get(\"spatial_shape\", ()))\n            if len(meta_spatial_shape) >= len(img.shape):\n                self.data_kwargs[\"channel_dim\"] = None\n            elif is_no_channel(self.data_kwargs.get(\"channel_dim\")):\n                warnings.warn(\n                    f\"data shape {img.shape} (with spatial shape {meta_spatial_shape}) \"\n             
       f\"but SaveImage `channel_dim` is set to {self.data_kwargs.get('channel_dim')} no channel.\"\n                )\n\n        err = []\n        for writer_cls in self.writers:\n            try:\n                writer_obj = writer_cls(**self.init_kwargs)\n                writer_obj.set_data_array(data_array=img, **self.data_kwargs)\n                writer_obj.set_metadata(meta_dict=meta_data, **self.meta_kwargs)\n                writer_obj.write(filename, **self.write_kwargs)\n                self.writer_obj = writer_obj\n            except Exception as e:\n                err.append(traceback.format_exc())\n                logging.getLogger(self.__class__.__name__).debug(e, exc_info=True)\n                logging.getLogger(self.__class__.__name__).info(\n                    f\"{writer_cls.__class__.__name__}: unable to write {filename}.\\n\"\n                )\n            else:\n                self._data_index += 1\n                if self.savepath_in_metadict and meta_data is not None:\n                    meta_data[MetaKeys.SAVED_TO] = filename\n                return img\n        msg = \"\\n\".join([f\"{e}\" for e in err])\n        raise RuntimeError(\n            f\"{self.__class__.__name__} cannot find a suitable writer for {filename}.\\n\"\n            \"    Please install the writer libraries, see also the installation instructions:\\n\"\n            \"    https://monai.readthedocs.io/en/latest/installation.html#installing-the-recommended-dependencies.\\n\"\n            f\"   The current registered writers for {self.output_ext}: {self.writers}.\\n{msg}\"\n        )\n\n\nclass WriteFileMapping(Transform):\n    \"\"\"\n    Writes a JSON file that logs the mapping between input image paths and their corresponding output paths.\n    This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment.\n\n    Args:\n        mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved.\n    \"\"\"\n\n    
def __init__(self, mapping_file_path: Path | str = \"mapping.json\"):\n        self.mapping_file_path = Path(mapping_file_path)\n\n    def __call__(self, img: NdarrayOrTensor):\n        \"\"\"\n        Args:\n            img: The input image with metadata.\n        \"\"\"\n        if isinstance(img, MetaTensor):\n            meta_data = img.meta\n\n        if MetaKeys.SAVED_TO not in meta_data:\n            raise KeyError(\n                \"Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.\"\n            )\n\n        input_path = meta_data[ImageMetaKey.FILENAME_OR_OBJ]\n        output_path = meta_data[MetaKeys.SAVED_TO]\n        log_data = {\"input\": input_path, \"output\": output_path}\n\n        if has_filelock:\n            with FileLock(str(self.mapping_file_path) + \".lock\"):\n                self._write_to_file(log_data)\n        else:\n            self._write_to_file(log_data)\n        return img\n\n    def _write_to_file(self, log_data):\n        try:\n            with self.mapping_file_path.open(\"r\") as f:\n                existing_log_data = json.load(f)\n        except (FileNotFoundError, json.JSONDecodeError):\n            existing_log_data = []\n        existing_log_data.append(log_data)\n        with self.mapping_file_path.open(\"w\") as f:\n            json.dump(existing_log_data, f, indent=4)\n"
  },
  {
    "path": "monai/transforms/io/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of dictionary-based wrappers around the \"vanilla\" transforms for IO functions\ndefined in :py:class:`monai.transforms.io.array`.\n\nClass names are ended with 'd' to denote dictionary-based transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Hashable, Mapping\nfrom pathlib import Path\nfrom typing import Callable\n\nimport numpy as np\n\nimport monai\nfrom monai.config import DtypeLike, KeysCollection, NdarrayOrTensor\nfrom monai.data import image_writer\nfrom monai.data.image_reader import ImageReader\nfrom monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping\nfrom monai.transforms.transform import MapTransform, Transform\nfrom monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep\nfrom monai.utils.enums import PostFix\n\n__all__ = [\"LoadImaged\", \"LoadImageD\", \"LoadImageDict\", \"SaveImaged\", \"SaveImageD\", \"SaveImageDict\"]\n\nDEFAULT_POST_FIX = PostFix.meta()\n\n\nclass LoadImaged(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.LoadImage`,\n    It can load both image data and metadata. 
When loading a list of files in one key,\n    the arrays will be stacked and a new dimension will be added as the first dimension\n    In this case, the metadata of the first image will be used to represent the stacked result.\n    The affine transform of all the stacked images should be same.\n    The output metadata field will be created as ``meta_keys`` or ``key_{meta_key_postfix}``.\n\n    If reader is not specified, this class automatically chooses readers\n    based on the supported suffixes and in the following order:\n\n        - User-specified reader at runtime when calling this loader.\n        - User-specified reader in the constructor of `LoadImage`.\n        - Readers from the last to the first in the registered list.\n        - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader),\n          (npz, npy -> NumpyReader), (dcm, DICOM series and others -> ITKReader).\n\n    Please note that for png, jpg, bmp, and other 2D formats, readers by default swap axis 0 and 1 after\n    loading the array with ``reverse_indexing`` set to ``True`` because the spatial axes definition\n    for non-medical specific file formats is different from other common medical packages.\n\n    Note:\n\n        - If `reader` is specified, the loader will attempt to use the specified readers and the default supported\n          readers. 
This might introduce overheads when handling the exceptions of trying the incompatible loaders.\n          In this case, it is therefore recommended setting the most appropriate reader as\n          the last item of the `reader` parameter.\n\n    See also:\n\n        - tutorial: https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb\n\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        reader: type[ImageReader] | str | None = None,\n        dtype: DtypeLike = np.float32,\n        meta_keys: KeysCollection | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        overwriting: bool = False,\n        image_only: bool = True,\n        ensure_channel_first: bool = False,\n        simple_keys: bool = False,\n        prune_meta_pattern: str | None = None,\n        prune_meta_sep: str = \".\",\n        allow_missing_keys: bool = False,\n        expanduser: bool = True,\n        *args,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            reader: reader to load image file and metadata\n                - if `reader` is None, a default set of `SUPPORTED_READERS` will be used.\n                - if `reader` is a string, it's treated as a class name or dotted path\n                (such as ``\"monai.data.ITKReader\"``), the supported built-in reader classes are\n                ``\"ITKReader\"``, ``\"NibabelReader\"``, ``\"NumpyReader\"``.\n                a reader instance will be constructed with the `*args` and `**kwargs` parameters.\n                - if `reader` is a reader class/instance, it will be registered to this loader accordingly.\n            dtype: if not None, convert the loaded image data to this data type.\n            meta_keys: explicitly indicate the key to store the corresponding metadata dictionary.\n 
               the metadata is a dictionary object which contains: filename, original_shape, etc.\n                it can be a sequence of string, map to the `keys`.\n                if None, will try to construct meta_keys by `key_{meta_key_postfix}`.\n            meta_key_postfix: if meta_keys is None, use `key_{postfix}` to store the metadata of the nifti image,\n                default is `meta_dict`. The metadata is a dictionary object.\n                For example, load nifti file for `image`, store the metadata into `image_meta_dict`.\n            overwriting: whether allow overwriting existing metadata of same key.\n                default is False, which will raise exception if encountering existing key.\n            image_only: if True return dictionary containing just only the image volumes, otherwise return\n                dictionary containing image data array and header dict per input key.\n            ensure_channel_first: if `True` and loaded both image array and metadata, automatically convert\n                the image array shape to `channel first`. default to `False`.\n            simple_keys: whether to remove redundant metadata keys, default to False for backward compatibility.\n            prune_meta_pattern: combined with `prune_meta_sep`, a regular expression used to match and prune keys\n                in the metadata (nested dictionary), default to None, no key deletion.\n            prune_meta_sep: combined with `prune_meta_pattern`, used to match and prune keys\n                in the metadata (nested dictionary). default is \".\", see also :py:class:`monai.transforms.DeleteItemsd`.\n                e.g. 
``prune_meta_pattern=\".*_code$\", prune_meta_sep=\" \"`` removes meta keys that ends with ``\"_code\"``.\n            allow_missing_keys: don't raise exception if key is missing.\n            expanduser: if True cast filename to Path and call .expanduser on it, otherwise keep filename as is.\n            args: additional parameters for reader if providing a reader name.\n            kwargs: additional parameters for reader if providing a reader name.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self._loader = LoadImage(\n            reader,\n            image_only,\n            dtype,\n            ensure_channel_first,\n            simple_keys,\n            prune_meta_pattern,\n            prune_meta_sep,\n            expanduser,\n            *args,\n            **kwargs,\n        )\n        if not isinstance(meta_key_postfix, str):\n            raise TypeError(f\"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.\")\n        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)\n        if len(self.keys) != len(self.meta_keys):\n            raise ValueError(\n                f\"meta_keys should have the same length as keys, got {len(self.keys)} and {len(self.meta_keys)}.\"\n            )\n        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))\n        self.overwriting = overwriting\n\n    def register(self, reader: ImageReader):\n        self._loader.register(reader)\n\n    def __call__(self, data, reader: ImageReader | None = None):\n        \"\"\"\n        Raises:\n            KeyError: When not ``self.overwriting`` and key already exists in ``data``.\n\n        \"\"\"\n        d = dict(data)\n        for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):\n            data = self._loader(d[key], reader)\n            if self._loader.image_only:\n                d[key] = data\n           
 else:\n                if not isinstance(data, (tuple, list)):\n                    raise ValueError(\n                        f\"loader must return a tuple or list (because image_only=False was used), got {type(data)}.\"\n                    )\n                d[key] = data[0]\n                if not isinstance(data[1], dict):\n                    raise ValueError(f\"metadata must be a dict, got {type(data[1])}.\")\n                meta_key = meta_key or f\"{key}_{meta_key_postfix}\"\n                if meta_key in d and not self.overwriting:\n                    raise KeyError(f\"Metadata with key {meta_key} already exists and overwriting=False.\")\n                d[meta_key] = data[1]\n        return d\n\n\nclass SaveImaged(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.SaveImage`.\n\n    Note:\n        Image should be channel-first shape: [C,H,W,[D]].\n        If the data is a patch of an image, the patch index will be appended to the filename.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        meta_keys: explicitly indicate the key of the corresponding metadata dictionary.\n            For example, for data with key ``image``, the metadata by default is in ``image_meta_dict``.\n            The metadata is a dictionary contains values such as ``filename``, ``original_shape``.\n            This argument can be a sequence of strings, mapped to the ``keys``.\n            If ``None``, will try to construct ``meta_keys`` by ``key_{meta_key_postfix}``.\n        meta_key_postfix: if ``meta_keys`` is ``None``, use ``key_{meta_key_postfix}`` to retrieve the metadict.\n        output_dir: output image directory.\n                    Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.\n        output_postfix: a string appended to all output file names, default to ``trans``.\n                        
Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.\n        output_ext: output file extension name, available extensions: ``.nii.gz``, ``.nii``, ``.png``, ``.dcm``.\n                    Handled by ``folder_layout`` instead, if ``folder_layout`` not ``None``.\n        resample: whether to resample image (if needed) before saving the data array,\n            based on the ``spatial_shape`` (and ``original_affine``) from metadata.\n        mode: This option is used when ``resample=True``. Defaults to ``\"nearest\"``.\n            Depending on the writers, the possible options are:\n\n            - {``\"bilinear\"``, ``\"nearest\"``, ``\"bicubic\"``}.\n              See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample\n            - {``\"nearest\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}.\n              See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate\n\n        padding_mode: This option is used when ``resample = True``. Defaults to ``\"border\"``.\n            Possible options are {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n            See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample\n        scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling\n            [0, 255] (``uint8``) or [0, 65535] (``uint16``). Default is ``None`` (no scaling).\n        dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision.\n            if None, use the data type of input data. To set the output data type, use ``output_dtype``.\n        output_dtype: data type for saving data. Defaults to ``np.float32``.\n        allow_missing_keys: don't raise exception if key is missing.\n        squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel\n            has been moved to the end). 
So if input is (C,H,W,D), this will be altered to (H,W,D,C), and\n            then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If `false`,\n            image will always be saved as (H,W,D,C).\n        data_root_dir: if not empty, it specifies the beginning parts of the input file's\n            absolute path. It's used to compute ``input_file_rel_path``, the relative path to the file from\n            ``data_root_dir`` to preserve folder structure when saving in case there are files in different\n            folders with the same file names. For example, with the following inputs:\n\n            - input_file_name: ``/foo/bar/test1/image.nii``\n            - output_postfix: ``seg``\n            - output_ext: ``.nii.gz``\n            - output_dir: ``/output``\n            - data_root_dir: ``/foo/bar``\n\n            The output will be: ``/output/test1/image/image_seg.nii.gz``\n\n            Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.\n        separate_folder: whether to save every file in a separate folder. For example: for the input filename\n            ``image.nii``, postfix ``seg`` and folder_path ``output``, if ``separate_folder=True``, it will be saved as:\n            ``output/image/image_seg.nii``, if ``False``, saving as ``output/image_seg.nii``. Default to ``True``.\n            Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.\n        print_log: whether to print logs when saving. 
Default to ``True``.\n        output_format: an optional string to specify the output image writer.\n            see also: ``monai.data.image_writer.SUPPORTED_WRITERS``.\n        writer: a customised ``monai.data.ImageWriter`` subclass to save data arrays.\n            if ``None``, use the default writer from ``monai.data.image_writer`` according to ``output_ext``.\n            if it's a string, it's treated as a class name or dotted path;\n            the supported built-in writer classes are ``\"NibabelWriter\"``, ``\"ITKWriter\"``, ``\"PILWriter\"``.\n        output_name_formatter: a callable function (returning a kwargs dict) to format the output file name.\n            see also: :py:func:`monai.data.folder_layout.default_name_formatter`.\n            If using a custom ``folder_layout``, consider providing your own formatter.\n        folder_layout: A customized ``monai.data.FolderLayoutBase`` subclass to define file naming schemes.\n            if ``None``, uses the default ``FolderLayout``.\n        savepath_in_metadict: if ``True``, adds a key ``saved_to`` to the metadata, which contains the path\n            to where the input image has been saved.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        meta_keys: KeysCollection | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        output_dir: Path | str = \"./\",\n        output_postfix: str = \"trans\",\n        output_ext: str = \".nii.gz\",\n        resample: bool = False,\n        mode: str = \"nearest\",\n        padding_mode: str = GridSamplePadMode.BORDER,\n        scale: int | None = None,\n        dtype: DtypeLike = np.float64,\n        output_dtype: DtypeLike | None = np.float32,\n        allow_missing_keys: bool = False,\n        squeeze_end_dims: bool = True,\n        data_root_dir: str = \"\",\n        separate_folder: bool = True,\n        print_log: bool = True,\n        output_format: str = \"\",\n        writer: 
type[image_writer.ImageWriter] | str | None = None,\n        output_name_formatter: Callable[[dict, Transform], dict] | None = None,\n        folder_layout: monai.data.FolderLayoutBase | None = None,\n        savepath_in_metadict: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))\n        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))\n        self.saver = SaveImage(\n            output_dir=output_dir,\n            output_postfix=output_postfix,\n            output_ext=output_ext,\n            resample=resample,\n            mode=mode,\n            padding_mode=padding_mode,\n            scale=scale,\n            dtype=dtype,\n            output_dtype=output_dtype,\n            squeeze_end_dims=squeeze_end_dims,\n            data_root_dir=data_root_dir,\n            separate_folder=separate_folder,\n            print_log=print_log,\n            output_format=output_format,\n            writer=writer,\n            output_name_formatter=output_name_formatter,\n            folder_layout=folder_layout,\n            savepath_in_metadict=savepath_in_metadict,\n        )\n\n    def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):\n        self.saver.set_options(init_kwargs, data_kwargs, meta_kwargs, write_kwargs)\n        return self\n\n    def __call__(self, data):\n        d = dict(data)\n        for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):\n            if meta_key is None and meta_key_postfix is not None:\n                meta_key = f\"{key}_{meta_key_postfix}\"\n            meta_data = d.get(meta_key) if meta_key is not None else None\n            self.saver(img=d[key], meta_data=meta_data)\n        return d\n\n\nclass WriteFileMappingd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of 
:py:class:`monai.transforms.WriteFileMapping`.\n\n    Args:\n          keys: keys of the corresponding items to be transformed.\n              See also: :py:class:`monai.transforms.compose.MapTransform`\n          mapping_file_path: Path to the JSON file where the mappings will be saved.\n              Defaults to \"mapping.json\".\n          allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    def __init__(\n        self, keys: KeysCollection, mapping_file_path: Path | str = \"mapping.json\", allow_missing_keys: bool = False\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.mapping = WriteFileMapping(mapping_file_path)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.mapping(d[key])\n        return d\n\n\nLoadImageD = LoadImageDict = LoadImaged\nSaveImageD = SaveImageDict = SaveImaged\nWriteFileMappingD = WriteFileMappingDict = WriteFileMappingd\n"
  },
  {
    "path": "monai/transforms/lazy/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/transforms/lazy/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom monai.transforms.traits import InvertibleTrait\n\n__all__ = [\"ApplyPending\"]\n\n\nclass ApplyPending(InvertibleTrait):\n    \"\"\"\n    ApplyPending can be inserted into a pipeline that is being executed lazily in order to ensure\n    resampling happens before the next transform. It doesn't do anything itself, but its presence\n    causes the pipeline to be executed as ApplyPending doesn't implement ```LazyTrait``.\n\n    See ``Compose`` for a detailed explanation of the lazy resampling feature.\n    \"\"\"\n\n    def __call__(self, data):\n        return data\n\n    def inverse(self, data):\n        return data\n"
  },
  {
    "path": "monai/transforms/lazy/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom monai.config import KeysCollection\nfrom monai.transforms.traits import InvertibleTrait\nfrom monai.transforms.transform import MapTransform\n\n__all__ = [\"ApplyPendingd\", \"ApplyPendingD\", \"ApplyPendingDict\"]\n\n\nclass ApplyPendingd(InvertibleTrait, MapTransform):\n    \"\"\"\n    ApplyPendingd can be inserted into a pipeline that is being executed lazily in order\n    to ensure resampling happens before the next transform. It doesn't do anything itself,\n    but its presence causes the pipeline to be executed as it doesn't implement ``LazyTrait``\n\n    See ``Compose`` for a detailed explanation of the lazy resampling feature.\n\n    Args:\n        keys: the keys for tensors that should have their pending transforms executed\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection):\n        super().__init__(keys)\n\n    def __call__(self, data):\n        return data\n\n    def inverse(self, data):\n        return data\n\n\nApplyPendingD = ApplyPendingDict = ApplyPendingd\n"
  },
  {
    "path": "monai/transforms/lazy/functional.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Mapping, Sequence\nfrom typing import Any\n\nimport torch\n\nfrom monai.apps.utils import get_logger\nfrom monai.config import NdarrayOrTensor\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import to_affine_nd\nfrom monai.transforms.lazy.utils import (\n    affine_from_pending,\n    combine_transforms,\n    is_compatible_apply_kwargs,\n    kwargs_from_pending,\n    resample,\n)\nfrom monai.transforms.traits import LazyTrait\nfrom monai.transforms.transform import MapTransform\nfrom monai.utils import LazyAttr, look_up_option\n\n__all__ = [\"apply_pending_transforms\", \"apply_pending_transforms_in_order\", \"apply_pending\"]\n\n__override_keywords = {\"mode\", \"padding_mode\", \"dtype\", \"align_corners\", \"resample_mode\", \"device\"}\n\n\ndef _log_pending_info(\n    transform: Any,\n    data: Any,\n    activity: str,\n    *,\n    lazy: bool | None = None,\n    key: str | None = None,\n    logger_name: bool | str = False,\n):\n    if logger_name is False:\n        return\n    logger_name = logger_name if isinstance(logger_name, str) else \"apply_pending_transforms\"\n    logger = get_logger(logger_name)\n\n    tcname = type(transform).__name__\n    if isinstance(transform, LazyTrait):\n        tlazy = f\", transform.lazy: {transform.lazy}\"\n        if lazy is not None and lazy != 
transform.lazy:\n            tlazy += \" (overridden)\"\n    else:\n        tlazy = \", transform is not lazy\"\n\n    msg = f\"{activity} - lazy: {lazy}, {{key_msg}}pending: {{pcount}}, upcoming '{tcname}'{tlazy}\"\n\n    if isinstance(transform, MapTransform):\n        transform_keys = transform.keys if key is None else (key,)\n        for k in transform_keys:\n            if k in data:\n                pcount = len(data[k].pending_operations) if isinstance(data[k], MetaTensor) else 0\n                logger.info(msg.format(pcount=pcount, key_msg=f\"key: '{k}', \"))\n    else:\n        pcount = len(data.pending_operations) if isinstance(data, MetaTensor) else 0\n        logger.info(msg.format(pcount=pcount, key_msg=\"\" if key is None else f\"key: '{key}', \"))\n\n\ndef _log_applied_info(data: Any, key=None, logger_name: bool | str = False):\n    if logger_name is False:\n        return\n    logger_name = logger_name if isinstance(logger_name, str) else \"apply_pending_transforms\"\n    logger = get_logger(logger_name)\n\n    key_str = \"\" if key is None else f\"key: '{key}', \"\n    logger.info(f\"Pending transforms applied: {key_str}applied_operations: {len(data.applied_operations)}\")\n\n\ndef apply_pending_transforms(\n    data: NdarrayOrTensor | Sequence[Any | NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],\n    keys: tuple | None,\n    overrides: dict | None = None,\n    logger_name: bool | str = False,\n):\n    \"\"\"\n    apply_pending_transforms is called with either a tensor or a dictionary, some entries of which contain\n    tensors.\n\n    When operating on a dictionary of tensors, the 'keys' parameter determines what tensors should be checked.\n    If 'keys' is not set, all keys of 'data' are considered.\n\n    This method optionally takes a set of overrides that can be used to change specific parameters on the\n    transform pipeline. See ``Compose`` for more details. 
This method takes a logger_name that can be used\n    to override the default logger, to provide telemetry during the execution of pending transforms.\n\n    This method is intended primarily for use by ``execute_compose`` and other methods that handle the\n    underlying execution of transform pipelines. You should not need to use it in the general case, unless\n    you are developing functionality to perform such operations.\n\n    Args:\n        data: a ``torch.Tensor`` or ``MetaTensor``, or dictionary of tensors.\n        keys: an optional tuple of keys that filters the keys on 'data' if it is a dict\n        overrides: An optional dictionary that specifies parameters that can be used to override transform\n            arguments when they are called. When 'data' is a dict, this dictionary should contain a dictionary\n            of overrides for each key that needs them\n        logger_name: An optional name for a logger to be used when applying pending transforms. If None,\n            logging is suppressed.\n    Returns:\n        an object of the same type as data if pending transforms were applied, or 'data' if they were not\n    \"\"\"\n    if isinstance(data, list):\n        return [apply_pending_transforms(d, keys, overrides, logger_name) for d in data]\n    if isinstance(data, tuple):\n        return tuple(apply_pending_transforms(d, keys, overrides, logger_name) for d in data)\n\n    if isinstance(data, dict):\n        # get the keys from 'data' for metatensors with pending operations. 
If 'keys' is set, select\n        # only data keys that are in 'keys'\n        active_keys = [k for k in data.keys() if keys is None or k in keys]\n        keys_to_update = [k for k in active_keys if isinstance(data[k], MetaTensor) and data[k].has_pending_operations]\n\n        if len(keys_to_update) > 0:\n            rdata = dict(data)\n\n            for k in keys_to_update:\n                overrides_ = None if overrides is None else overrides.get(k, None)\n                rdata[k], _ = apply_pending(data[k], overrides=overrides_)\n                _log_applied_info(rdata[k], key=k, logger_name=logger_name)\n\n            return rdata\n    else:\n        if isinstance(data, MetaTensor) and data.has_pending_operations:\n            rdata, _ = apply_pending(data, overrides=overrides)\n            _log_applied_info(rdata, logger_name=logger_name)\n            return rdata\n\n    return data\n\n\ndef apply_pending_transforms_in_order(\n    transform, data, lazy: bool | None = None, overrides: dict | None = None, logger_name: bool | str = False\n):\n    \"\"\"\n    This method causes \"in order\" processing of pending transforms to occur.\n    \"in order\" processing of pending transforms ensures that all pending transforms have been applied to the\n    tensor before a non-lazy transform (or lazy transform that is executing non-lazily) is carried out.\n    It ensures that no operations will be added to a metatensor's apply_operations while there are outstanding\n    pending_operations. 
Note that there is only one mechanism for executing lazy resampling at present but this\n    is expected to change in future releases.\n\n    Evaluation of pending transforms is performed under the following circumstances:\n    * If the transform is a lazy transform and:\n      * The transform checks data as part of its execution, or\n      * the transform is not executing lazily\n    * If the transform is an ApplyPending[d] transform\n    * If the transform is not a lazy transform\n\n    This method is designed to be used only in the context of implementing lazy resampling functionality. In general\n    you should not need to interact with or use this method directly, and its API may change without warning between\n    releases. See the :ref:`Lazy Resampling topic<lazy_resampling> for more information about lazy resampling.\n\n    Args:\n        transform: a transform that should be evaluated to determine whether pending transforms should be applied\n        data: a tensor / MetaTensor, or dictionary containing tensors / MetaTensors whose pending transforms may\n            need to be applied\n        lazy: The lazy mode that is being applied (this can be False, True or None)\n        overrides: An optional dictionary containing overrides to be applied to the pending transforms when they\n            are lazily executed. If data is a dict, it should contain a dictionary of overrides for each key that\n            needs them\n        logger_name: An optional name for a logger to be used when applying pending transforms. 
If None,\n            logging is suppressed.\n    Returns:\n        an object of the same type as data if pending transforms were applied, or 'data' if they were not\n\n    \"\"\"\n    from monai.transforms.lazy.dictionary import ApplyPendingd\n\n    must_apply_pending = True\n    keys = transform.keys if isinstance(transform, ApplyPendingd) else None\n    if isinstance(transform, LazyTrait) and not transform.requires_current_data:\n        must_apply_pending = not (transform.lazy if lazy is None else lazy)\n\n    if must_apply_pending is True:\n        _log_pending_info(transform, data, \"Apply pending transforms\", lazy=lazy, logger_name=logger_name)\n        return apply_pending_transforms(data, keys, overrides, logger_name)\n\n    _log_pending_info(transform, data, \"Accumulate pending transforms\", lazy=lazy, logger_name=logger_name)\n    return data\n\n\ndef apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None, overrides: dict | None = None):\n    \"\"\"\n    This method applies pending transforms to `data` tensors.\n    Currently, only 2d and 3d inputs are supported.\n\n    This method is designed to be called by ``apply_pending_transforms`` and other methods / classes\n    that are part of the implementation of lazy resampling. In general, you should not need to call\n    this method unless you are directly developing custom lazy execution strategies.\n\n    It works by calculating the overall effect of the accumulated pending transforms. When it runs\n    out of pending transforms or when it finds incompatibilities between the accumulated pending\n    transform and the next pending transform, it then applies the accumulated transform in a call to\n    ``resample``.\n\n    Pending transforms are incompatible with each other if one or more of the arguments in the pending\n    transforms differ. These are parameters such as 'mode', 'padding_mode', 'dtype' and so forth. 
If\n    a pending transform doesn't have a given parameter, it is considered compatible with the\n    accumulated transform. If a subsequent transform has a parameter that is incompatible with\n    the accumulated transform (e.g. 'mode' of 'bilinear' vs. 'mode' of 'nearest'), an intermediate\n    resample will be performed and the accumulated transform reset to its starting state.\n\n    After resampling, the pending transforms are pushed to the ``applied_transforms`` field of the\n    resulting MetaTensor. Note, if a torch.tensor is passed to this method along with a list of\n    pending transforms, the resampled tensor will be wrapped in a MetaTensor before being returned.\n\n    Args:\n        data: A torch Tensor or a monai MetaTensor.\n        pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor.\n        overrides: a dictionary of overrides for the transform arguments. The keys must be one of:\n\n            - mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order ``0-5`` (integers).\n                Interpolation mode to calculate output values. Defaults to None.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's `an integer`, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            - padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to None.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            - dtype: data type for resampling computation. Defaults to ``float64``.\n                If ``None``, use the data type of input data, this option may not be compatible the resampling backend.\n            - align_corners: Geometrically, we consider the pixels of the input as squares rather than points, when using\n                the PyTorch resampling backend. Defaults to ``False``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            - device: device for resampling computation. Defaults to ``None``.\n            - resample_mode: the mode of resampling, currently support ``\"auto\"``. 
Setting to other values will use the\n                :py:class:`monai.transforms.SpatialResample` for resampling (instead of potentially crop/pad).\n    \"\"\"\n    overrides = (overrides or {}).copy()\n    for k in overrides:\n        look_up_option(k, __override_keywords)  # check existence of the key\n\n    if isinstance(data, MetaTensor) and pending is None:\n        pending = data.pending_operations.copy()\n        data.clear_pending_operations()\n    pending = [] if pending is None else pending\n\n    if not pending:\n        return data, []\n\n    cumulative_xform = affine_from_pending(pending[0])\n    if cumulative_xform.shape[0] == 3:\n        cumulative_xform = to_affine_nd(3, cumulative_xform)\n\n    cur_kwargs = kwargs_from_pending(pending[0])\n    override_kwargs: dict[str, Any] = {}\n    if \"mode\" in overrides:\n        override_kwargs[LazyAttr.INTERP_MODE] = overrides[\"mode\"]\n    if \"padding_mode\" in overrides:\n        override_kwargs[LazyAttr.PADDING_MODE] = overrides[\"padding_mode\"]\n    if \"align_corners\" in overrides:\n        override_kwargs[LazyAttr.ALIGN_CORNERS] = overrides[\"align_corners\"]\n    if \"resample_mode\" in overrides:\n        override_kwargs[LazyAttr.RESAMPLE_MODE] = overrides[\"resample_mode\"]\n    override_dtype = overrides.get(\"dtype\", torch.float64)\n    override_kwargs[LazyAttr.DTYPE] = data.dtype if override_dtype is None else override_dtype\n    device = overrides.get(\"device\")\n\n    for p in pending[1:]:\n        new_kwargs = kwargs_from_pending(p)\n        if not is_compatible_apply_kwargs(cur_kwargs, new_kwargs):\n            # carry out an intermediate resample here due to incompatibility between arguments\n            _cur_kwargs = cur_kwargs.copy()\n            _cur_kwargs.update(override_kwargs)\n            data = resample(data.to(device), cumulative_xform, _cur_kwargs)\n\n        next_matrix = affine_from_pending(p)\n        if next_matrix.shape[0] == 3:\n            next_matrix = 
to_affine_nd(3, next_matrix)\n\n        cumulative_xform = combine_transforms(cumulative_xform, next_matrix)\n        cur_kwargs.update(new_kwargs)\n    cur_kwargs.update(override_kwargs)\n    data = resample(data.to(device), cumulative_xform, cur_kwargs)\n    if isinstance(data, MetaTensor):\n        for p in pending:\n            data.push_applied_operation(p)\n    return data, pending\n"
  },
  {
    "path": "monai/transforms/lazy/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\n\nimport numpy as np\nimport torch\n\nimport monai\nfrom monai.config import NdarrayOrTensor\nfrom monai.data.utils import AFFINE_TOL\nfrom monai.transforms.utils_pytorch_numpy_unification import allclose\nfrom monai.utils import LazyAttr, convert_to_numpy, convert_to_tensor, look_up_option\n\n__all__ = [\"resample\", \"combine_transforms\"]\n\n\nclass Affine:\n    \"\"\"A class to represent an affine transform matrix.\"\"\"\n\n    __slots__ = (\"data\",)\n\n    def __init__(self, data):\n        self.data = data\n\n    @staticmethod\n    def is_affine_shaped(data):\n        \"\"\"Check if the data is an affine matrix.\"\"\"\n        if isinstance(data, Affine):\n            return True\n        if isinstance(data, DisplacementField):\n            return False\n        if not hasattr(data, \"shape\") or len(data.shape) < 2:\n            return False\n        return data.shape[-1] in (3, 4) and data.shape[-1] == data.shape[-2]\n\n\nclass DisplacementField:\n    \"\"\"A class to represent a dense displacement field.\"\"\"\n\n    __slots__ = (\"data\",)\n\n    def __init__(self, data):\n        self.data = data\n\n    @staticmethod\n    def is_ddf_shaped(data):\n        \"\"\"Check if the data is a DDF.\"\"\"\n        if isinstance(data, DisplacementField):\n            return True\n        if isinstance(data, Affine):\n 
           return False\n        if not hasattr(data, \"shape\") or len(data.shape) < 3:\n            return False\n        return not Affine.is_affine_shaped(data)\n\n\ndef combine_transforms(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:\n    \"\"\"Given transforms A and B to be applied to x, return the combined transform (AB), so that A(B(x)) becomes AB(x)\"\"\"\n    if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right):  # linear transforms\n        left = convert_to_tensor(left.data if isinstance(left, Affine) else left, wrap_sequence=True)\n        right = convert_to_tensor(right.data if isinstance(right, Affine) else right, wrap_sequence=True)\n        return torch.matmul(left, right)\n    if DisplacementField.is_ddf_shaped(left) and DisplacementField.is_ddf_shaped(\n        right\n    ):  # adds DDFs, do we need metadata if metatensor input?\n        left = convert_to_tensor(left.data if isinstance(left, DisplacementField) else left, wrap_sequence=True)\n        right = convert_to_tensor(right.data if isinstance(right, DisplacementField) else right, wrap_sequence=True)\n        return left + right\n    raise NotImplementedError\n\n\ndef affine_from_pending(pending_item):\n    \"\"\"Extract the affine matrix from a pending transform item.\"\"\"\n    if isinstance(pending_item, (torch.Tensor, np.ndarray)):\n        return pending_item\n    if isinstance(pending_item, dict):\n        return pending_item[LazyAttr.AFFINE]\n    return pending_item\n\n\ndef kwargs_from_pending(pending_item):\n    \"\"\"Extract kwargs from a pending transform item.\"\"\"\n    if not isinstance(pending_item, dict):\n        return {}\n    ret = {\n        LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None),  # interpolation mode\n        LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None),  # padding mode\n    }\n    if LazyAttr.SHAPE in pending_item:\n        ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE]\n    if 
LazyAttr.DTYPE in pending_item:\n        ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE]\n    return ret  # adding support of pending_item['extra_info']??\n\n\ndef is_compatible_apply_kwargs(kwargs_1, kwargs_2):\n    \"\"\"Check if two sets of kwargs are compatible (to be combined in `apply`).\"\"\"\n    return True\n\n\ndef requires_interp(matrix, atol=AFFINE_TOL):\n    \"\"\"\n    Check whether the transformation matrix suggests voxel-wise interpolation.\n\n    Returns None if the affine matrix suggests interpolation.\n    Otherwise, the matrix suggests that the resampling could be achieved by simple array operations\n    such as flip/permute/pad_nd/slice; in this case this function returns axes information about simple axes\n    operations.\n\n    Args:\n        matrix: the affine matrix to check.\n        atol: absolute tolerance for checking if the matrix is close to an integer.\n    \"\"\"\n    matrix = convert_to_numpy(matrix, wrap_sequence=True)\n    s = matrix[:, -1]\n    if not np.allclose(s, np.round(s), atol=atol):\n        return None\n\n    ndim = len(matrix) - 1\n    ox, oy = [], [0]\n    for x, r in enumerate(matrix[:ndim, :ndim]):\n        for y, c in enumerate(r):\n            if np.isclose(c, -1, atol=atol) or np.isclose(c, 1, atol=atol):\n                y_channel = y + 1  # the returned axis index starting with channel dim\n                if x in ox or y_channel in oy:\n                    return None\n                ox.append(x)\n                oy.append(y_channel)\n            elif not np.isclose(c, 0.0, atol=atol):\n                return None\n    return oy\n\n\n__override_lazy_keywords = {*list(LazyAttr), \"atol\"}\n\n\ndef resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None):\n    \"\"\"\n    Resample `data` using the affine transformation defined by ``matrix``.\n\n    Args:\n        data: input data to be resampled.\n        matrix: affine transformation matrix.\n        kwargs: currently supports (see 
also: ``monai.utils.enums.LazyAttr``)\n\n            - \"lazy_shape\" for output spatial shape\n            - \"lazy_padding_mode\"\n            - \"lazy_interpolation_mode\" (this option might be ignored when ``mode=\"auto\"``.)\n            - \"lazy_align_corners\"\n            - \"lazy_dtype\" (dtype for resampling computation; this might be ignored when ``mode=\"auto\"``.)\n            - \"atol\" for tolerance for matrix floating point comparison.\n            - \"lazy_resample_mode\" for resampling backend, default to `\"auto\"`. Setting to other values will use the\n              `monai.transforms.SpatialResample` for resampling.\n\n    See Also:\n        :py:class:`monai.transforms.SpatialResample`\n    \"\"\"\n    if not Affine.is_affine_shaped(matrix):\n        raise NotImplementedError(f\"Calling the dense grid resample API directly not implemented, {matrix.shape}.\")\n    if isinstance(data, monai.data.MetaTensor) and data.pending_operations:\n        warnings.warn(\"data.pending_operations is not empty, the resampling output may be incorrect.\")\n    kwargs = kwargs or {}\n    for k in kwargs:\n        look_up_option(k, __override_lazy_keywords)\n    atol = kwargs.get(\"atol\", AFFINE_TOL)\n    mode = kwargs.get(LazyAttr.RESAMPLE_MODE, \"auto\")\n\n    init_kwargs = {\n        \"dtype\": kwargs.get(LazyAttr.DTYPE, data.dtype),\n        \"align_corners\": kwargs.get(LazyAttr.ALIGN_CORNERS, False),\n    }\n    ndim = len(matrix) - 1\n    img = convert_to_tensor(data=data, track_meta=monai.data.get_track_meta())\n    init_affine = monai.data.to_affine_nd(ndim, img.affine)\n    spatial_size = kwargs.get(LazyAttr.SHAPE, None)\n    out_spatial_size = img.peek_pending_shape() if spatial_size is None else spatial_size\n    out_spatial_size = convert_to_numpy(out_spatial_size, wrap_sequence=True)\n    call_kwargs = {\n        \"spatial_size\": out_spatial_size,\n        \"dst_affine\": init_affine @ monai.utils.convert_to_dst_type(matrix, init_affine)[0],\n      
  \"mode\": kwargs.get(LazyAttr.INTERP_MODE),\n        \"padding_mode\": kwargs.get(LazyAttr.PADDING_MODE),\n    }\n\n    axes = requires_interp(matrix, atol=atol)\n    if axes is not None and mode == \"auto\" and not init_kwargs[\"align_corners\"]:\n        matrix_np = np.round(convert_to_numpy(matrix, wrap_sequence=True))\n        full_transpose = np.argsort(axes).tolist()\n        if not np.allclose(full_transpose, np.arange(len(full_transpose))):\n            img = img.permute(full_transpose[: len(img.shape)])\n        in_shape = img.shape[1 : ndim + 1]  # requires that ``img`` has empty pending operations\n        matrix_np[:ndim] = matrix_np[[x - 1 for x in full_transpose[1:]]]\n        flip = [idx + 1 for idx, val in enumerate(matrix_np[:ndim]) if val[idx] == -1]\n        if flip:\n            img = torch.flip(img, dims=flip)  # todo: if on cpu, using the np.flip is faster?\n            for f in flip:\n                ind_f = f - 1\n                matrix_np[ind_f, ind_f] = 1\n                matrix_np[ind_f, -1] = in_shape[ind_f] - 1 - matrix_np[ind_f, -1]\n        if not np.all(out_spatial_size > 0):\n            raise ValueError(f\"Resampling out_spatial_size should be positive, got {out_spatial_size}.\")\n        if (\n            allclose(matrix_np, np.eye(len(matrix_np)), atol=atol)\n            and len(in_shape) == len(out_spatial_size)\n            and allclose(convert_to_numpy(in_shape, wrap_sequence=True), out_spatial_size)\n        ):\n            img.affine = call_kwargs[\"dst_affine\"]\n            img = img.to(torch.float32)  # consistent with monai.transforms.spatial.functional.spatial_resample\n            return img\n        img = monai.transforms.crop_or_pad_nd(img, matrix_np, out_spatial_size, mode=call_kwargs[\"padding_mode\"])\n        img = img.to(torch.float32)  # consistent with monai.transforms.spatial.functional.spatial_resample\n        img.affine = call_kwargs[\"dst_affine\"]\n        return img\n\n    resampler = 
monai.transforms.SpatialResample(**init_kwargs)\n    resampler.lazy = False  # resampler is a lazytransform\n    with resampler.trace_transform(False):  # don't track this transform in `img`\n        return resampler(img=img, **call_kwargs)\n"
  },
  {
    "path": "monai/transforms/meta_utility/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/transforms/meta_utility/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of dictionary-based wrappers for moving between MetaTensor types and dictionaries of data.\nThese can be used to make backwards compatible code.\n\nClass names are ended with 'd' to denote dictionary-based transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Hashable, Mapping, Sequence\n\nimport numpy as np\nimport torch\n\nfrom monai.config.type_definitions import KeysCollection, NdarrayOrTensor\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms.inverse import InvertibleTransform\nfrom monai.transforms.transform import MapTransform\nfrom monai.utils.enums import PostFix, TransformBackends\nfrom monai.utils.misc import ensure_tuple_rep\n\n__all__ = [\n    \"FromMetaTensord\",\n    \"FromMetaTensorD\",\n    \"FromMetaTensorDict\",\n    \"ToMetaTensord\",\n    \"ToMetaTensorD\",\n    \"ToMetaTensorDict\",\n]\n\n\nclass FromMetaTensord(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based transform to convert MetaTensor to a dictionary.\n\n    If input is `{\"a\": MetaTensor, \"b\": MetaTensor}`, then output will\n    have the form `{\"a\": torch.Tensor, \"a_meta_dict\": dict, \"a_transforms\": list, \"b\": ...}`.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY, TransformBackends.CUPY]\n\n    def __init__(\n        self, keys: KeysCollection, data_type: 
Sequence[str] | str = \"tensor\", allow_missing_keys: bool = False\n    ):\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            data_type: target data type to convert, should be \"tensor\" or \"numpy\".\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.as_tensor_output = tuple(d == \"tensor\" for d in ensure_tuple_rep(data_type, len(self.keys)))\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key, t in self.key_iterator(d, self.as_tensor_output):\n            im: MetaTensor = d[key]  # type: ignore\n            d.update(im.as_dict(key, output_type=torch.Tensor if t else np.ndarray))\n            self.push_transform(d, key)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            # check transform\n            _ = self.get_most_recent_transform(d, key)\n            # do the inverse\n            im = d[key]\n            meta = d.pop(PostFix.meta(key), None)\n            transforms = d.pop(PostFix.transforms(key), None)\n            im = MetaTensor(im, meta=meta, applied_operations=transforms)  # type: ignore\n            d[key] = im\n            # Remove the applied transform\n            self.pop_transform(d, key)\n        return d\n\n\nclass ToMetaTensord(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based transform to convert a dictionary to MetaTensor.\n\n    If input is `{\"a\": torch.Tensor, \"a_meta_dict\": dict, \"b\": ...}`, then output will\n    have the form `{\"a\": MetaTensor, \"b\": MetaTensor}`.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, 
TransformBackends.NUMPY, TransformBackends.CUPY]\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            self.push_transform(d, key)\n            im = d[key]\n            meta = d.pop(PostFix.meta(key), None)\n            transforms = d.pop(PostFix.transforms(key), None)\n            im = MetaTensor(im, meta=meta, applied_operations=transforms)  # type: ignore\n            d[key] = im\n        return d\n\n    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            # check transform\n            _ = self.get_most_recent_transform(d, key)\n            # do the inverse\n            im: MetaTensor = d[key]  # type: ignore\n            d.update(im.as_dict(key))\n            # Remove the applied transform\n            self.pop_transform(d, key)\n        return d\n\n\nFromMetaTensorD = FromMetaTensorDict = FromMetaTensord\nToMetaTensorD = ToMetaTensorDict = ToMetaTensord\n"
  },
  {
    "path": "monai/transforms/nvtx.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nWrapper around NVIDIA Tools Extension for profiling MONAI transformations\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom monai.transforms.traits import RandomizableTrait\nfrom monai.transforms.transform import Transform\nfrom monai.utils import optional_import\n\n_nvtx, _ = optional_import(\"torch._C._nvtx\", descriptor=\"NVTX is not installed. Are you sure you have a CUDA build?\")\n\n__all__ = [\n    \"Mark\",\n    \"Markd\",\n    \"MarkD\",\n    \"MarkDict\",\n    \"RandMark\",\n    \"RandMarkd\",\n    \"RandMarkD\",\n    \"RandMarkDict\",\n    \"RandRangePop\",\n    \"RandRangePopd\",\n    \"RandRangePopD\",\n    \"RandRangePopDict\",\n    \"RandRangePush\",\n    \"RandRangePushd\",\n    \"RandRangePushD\",\n    \"RandRangePushDict\",\n    \"RangePop\",\n    \"RangePopd\",\n    \"RangePopD\",\n    \"RangePopDict\",\n    \"RangePush\",\n    \"RangePushd\",\n    \"RangePushD\",\n    \"RangePushDict\",\n]\n\n\nclass RangePush(Transform):\n    \"\"\"\n    Pushes a range onto a stack of nested range span.\n    Stores zero-based depth of the range that is started.\n\n    Args:\n        msg: ASCII message to associate with range\n    \"\"\"\n\n    def __init__(self, msg: str) -> None:\n        self.msg = msg\n        self.depth = None\n\n    def __call__(self, data):\n        self.depth = _nvtx.rangePushA(self.msg)\n        return data\n\n\nclass 
RandRangePush(RangePush, RandomizableTrait):\n    \"\"\"\n    Pushes a range onto a stack of nested range span (for randomizable transforms).\n    Stores zero-based depth of the range that is started.\n\n    Args:\n        msg: ASCII message to associate with range\n    \"\"\"\n\n\nclass RangePop(Transform):\n    \"\"\"\n    Pops a range off of a stack of nested range spans.\n    Stores zero-based depth of the range that is ended.\n    \"\"\"\n\n    def __call__(self, data):\n        _nvtx.rangePop()\n        return data\n\n\nclass RandRangePop(RangePop, RandomizableTrait):\n    \"\"\"\n    Pops a range off of a stack of nested range spans (for randomizable transforms).\n    Stores zero-based depth of the range that is ended.\n    \"\"\"\n\n\nclass Mark(Transform):\n    \"\"\"\n    Mark an instantaneous event that occurred at some point.\n\n    Args:\n        msg: ASCII message to associate with the event.\n    \"\"\"\n\n    def __init__(self, msg: str) -> None:\n        self.msg = msg\n\n    def __call__(self, data):\n        _nvtx.markA(self.msg)\n        return data\n\n\nclass RandMark(Mark, RandomizableTrait):\n    \"\"\"\n    Mark an instantaneous event that occurred at some point (for randomizable transforms).\n\n    Args:\n        msg: ASCII message to associate with the event.\n    \"\"\"\n\n\nRangePushDict = RangePushD = RangePushd = RangePush\nRandRangePushDict = RandRangePushD = RandRangePushd = RandRangePush\n\nRangePopDict = RangePopD = RangePopd = RangePop\nRandRangePopDict = RandRangePopD = RandRangePopd = RandRangePop\n\nMarkDict = MarkD = Markd = Mark\nRandMarkDict = RandMarkD = RandMarkd = RandMark\n"
  },
  {
    "path": "monai/transforms/post/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/transforms/post/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of \"vanilla\" transforms for the model output tensors.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable, Iterable, Sequence\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.networks import one_hot\nfrom monai.networks.layers import GaussianFilter, apply_filter, separable_filtering\nfrom monai.transforms.inverse import InvertibleTransform\nfrom monai.transforms.transform import Transform\nfrom monai.transforms.utility.array import ToTensor\nfrom monai.transforms.utils import (\n    convert_applied_interp_mode,\n    distance_transform_edt,\n    fill_holes,\n    get_largest_connected_component_mask,\n    get_unique_labels,\n    remove_small_objects,\n)\nfrom monai.transforms.utils_pytorch_numpy_unification import unravel_index\nfrom monai.utils import (\n    TransformBackends,\n    convert_data_type,\n    convert_to_tensor,\n    ensure_tuple,\n    get_equivalent_dtype,\n    look_up_option,\n)\nfrom monai.utils.type_conversion import convert_to_dst_type\n\n__all__ = [\n    \"Activations\",\n    \"AsDiscrete\",\n    \"FillHoles\",\n    \"KeepLargestConnectedComponent\",\n    \"RemoveSmallObjects\",\n    
\"LabelFilter\",\n    \"LabelToContour\",\n    \"MeanEnsemble\",\n    \"ProbNMS\",\n    \"SobelGradients\",\n    \"VoteEnsemble\",\n    \"Invert\",\n    \"GenerateHeatmap\",\n    \"DistanceTransformEDT\",\n]\n\n\nclass Activations(Transform):\n    \"\"\"\n    Activation operations, typically `Sigmoid` or `Softmax`.\n\n    Args:\n        sigmoid: whether to execute sigmoid function on model output before transform.\n            Defaults to ``False``.\n        softmax: whether to execute softmax function on model output before transform.\n            Defaults to ``False``.\n        other: callable function to execute other activation layers, for example:\n            `other = lambda x: torch.tanh(x)`. Defaults to ``None``.\n        kwargs: additional parameters to `torch.softmax` (used when ``softmax=True``).\n            Defaults to ``dim=0``, unrecognized parameters will be ignored.\n\n    Raises:\n        TypeError: When ``other`` is not an ``Optional[Callable]``.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, sigmoid: bool = False, softmax: bool = False, other: Callable | None = None, **kwargs) -> None:\n        self.sigmoid = sigmoid\n        self.softmax = softmax\n        self.kwargs = kwargs\n        if other is not None and not callable(other):\n            raise TypeError(f\"other must be None or callable but is {type(other).__name__}.\")\n        self.other = other\n\n    def __call__(\n        self,\n        img: NdarrayOrTensor,\n        sigmoid: bool | None = None,\n        softmax: bool | None = None,\n        other: Callable | None = None,\n    ) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            sigmoid: whether to execute sigmoid function on model output before transform.\n                Defaults to ``self.sigmoid``.\n            softmax: whether to execute softmax function on model output before transform.\n                Defaults to ``self.softmax``.\n            other: callable function to 
execute other activation layers, for example:\n                `other = torch.tanh`. Defaults to ``self.other``.\n\n        Raises:\n            ValueError: When ``sigmoid=True`` and ``softmax=True``. Incompatible values.\n            TypeError: When ``other`` is not an ``Optional[Callable]``.\n            ValueError: When ``self.other=None`` and ``other=None``. Incompatible values.\n\n        \"\"\"\n        if sigmoid and softmax:\n            raise ValueError(\"Incompatible values: sigmoid=True and softmax=True.\")\n        if other is not None and not callable(other):\n            raise TypeError(f\"other must be None or callable but is {type(other).__name__}.\")\n\n        # convert to float as activation must operate on float tensor\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)\n        if sigmoid or self.sigmoid:\n            img_t = torch.sigmoid(img_t)\n        if softmax or self.softmax:\n            img_t = torch.softmax(img_t, dim=self.kwargs.get(\"dim\", 0))\n\n        act_func = self.other if other is None else other\n        if act_func is not None:\n            img_t = act_func(img_t)\n        out, *_ = convert_to_dst_type(img_t, img)\n        return out\n\n\nclass AsDiscrete(Transform):\n    \"\"\"\n    Convert the input tensor/array into discrete values, possible operations are:\n\n        -  `argmax`.\n        -  threshold input value to binary values.\n        -  convert input value to One-Hot format (set ``to_one_hot=N``, `N` is the number of classes).\n        -  round the value to the closest integer.\n\n    Args:\n        argmax: whether to execute argmax function on input data before transform.\n            Defaults to ``False``.\n        to_onehot: if not None, convert input data into the one-hot format with specified number of classes.\n            Defaults to ``None``.\n        threshold: if not None, threshold the float values to int number 
0 or 1 with specified threshold.\n            Defaults to ``None``.\n        rounding: if not None, round the data according to the specified option,\n            available options: [\"torchrounding\"].\n        kwargs: additional parameters to `torch.argmax`, `monai.networks.one_hot`.\n            currently ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored.\n            These default to ``0``, ``True``, ``torch.float`` respectively.\n\n    Example:\n\n        >>> transform = AsDiscrete(argmax=True)\n        >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]])))\n        # [[[1.0, 1.0]]]\n\n        >>> transform = AsDiscrete(threshold=0.6)\n        >>> print(transform(np.array([[[0.0, 0.5], [0.8, 3.0]]])))\n        # [[[0.0, 0.0], [1.0, 1.0]]]\n\n        >>> transform = AsDiscrete(argmax=True, to_onehot=2, threshold=0.5)\n        >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]])))\n        # [[[0.0, 0.0]], [[1.0, 1.0]]]\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        argmax: bool = False,\n        to_onehot: int | None = None,\n        threshold: float | None = None,\n        rounding: str | None = None,\n        **kwargs,\n    ) -> None:\n        self.argmax = argmax\n        if isinstance(to_onehot, bool):  # for backward compatibility\n            raise ValueError(\"`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.\")\n        self.to_onehot = to_onehot\n        self.threshold = threshold\n        self.rounding = rounding\n        self.kwargs = kwargs\n\n    def __call__(\n        self,\n        img: NdarrayOrTensor,\n        argmax: bool | None = None,\n        to_onehot: int | None = None,\n        threshold: float | None = None,\n        rounding: str | None = None,\n    ) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: the input tensor data to convert, if no channel dimension when converting to 
`One-Hot`,\n                will automatically add it.\n            argmax: whether to execute argmax function on input data before transform.\n                Defaults to ``self.argmax``.\n            to_onehot: if not None, convert input data into the one-hot format with specified number of classes.\n                Defaults to ``self.to_onehot``.\n            threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.\n                Defaults to ``self.threshold``.\n            rounding: if not None, round the data according to the specified option,\n                available options: [\"torchrounding\"].\n\n        \"\"\"\n        if isinstance(to_onehot, bool):\n            raise ValueError(\"`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.\")\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_t, *_ = convert_data_type(img, torch.Tensor)\n        argmax = self.argmax if argmax is None else argmax\n        if argmax:\n            img_t = torch.argmax(img_t, dim=self.kwargs.get(\"dim\", 0), keepdim=self.kwargs.get(\"keepdim\", True))\n\n        to_onehot = self.to_onehot if to_onehot is None else to_onehot\n        if to_onehot is not None:\n            if not isinstance(to_onehot, int):\n                raise ValueError(f\"the number of classes for One-Hot must be an integer, got {type(to_onehot)}.\")\n            img_t = one_hot(\n                img_t, num_classes=to_onehot, dim=self.kwargs.get(\"dim\", 0), dtype=self.kwargs.get(\"dtype\", torch.float)\n            )\n\n        threshold = self.threshold if threshold is None else threshold\n        if threshold is not None:\n            img_t = img_t >= threshold\n\n        rounding = self.rounding if rounding is None else rounding\n        if rounding is not None:\n            look_up_option(rounding, [\"torchrounding\"])\n            img_t = torch.round(img_t)\n\n        img, *_ = 
convert_to_dst_type(img_t, img, dtype=self.kwargs.get(\"dtype\", torch.float))\n        return img\n\n\nclass KeepLargestConnectedComponent(Transform):\n    \"\"\"\n    Keeps only the largest connected component in the image.\n    This transform can be used as a post-processing step to clean up over-segment areas in model output.\n\n    The input is assumed to be a channel-first PyTorch Tensor:\n      1) For not OneHot format data, the values correspond to expected labels,\n      0 will be treated as background and the over-segment pixels will be set to 0.\n      2) For OneHot format data, the values should be 0, 1 on each labels,\n      the over-segment pixels will be set to 0 in its channel.\n\n    For example:\n    Use with applied_labels=[1], is_onehot=False, connectivity=1::\n\n       [1, 0, 0]         [0, 0, 0]\n       [0, 1, 1]    =>   [0, 1 ,1]\n       [0, 1, 1]         [0, 1, 1]\n\n    Use with applied_labels=[1, 2], is_onehot=False, independent=False, connectivity=1::\n\n      [0, 0, 1, 0 ,0]           [0, 0, 1, 0 ,0]\n      [0, 2, 1, 1 ,1]           [0, 2, 1, 1 ,1]\n      [1, 2, 1, 0 ,0]    =>     [1, 2, 1, 0 ,0]\n      [1, 2, 0, 1 ,0]           [1, 2, 0, 0 ,0]\n      [2, 2, 0, 0 ,2]           [2, 2, 0, 0 ,0]\n\n    Use with applied_labels=[1, 2], is_onehot=False, independent=True, connectivity=1::\n\n      [0, 0, 1, 0 ,0]           [0, 0, 1, 0 ,0]\n      [0, 2, 1, 1 ,1]           [0, 2, 1, 1 ,1]\n      [1, 2, 1, 0 ,0]    =>     [0, 2, 1, 0 ,0]\n      [1, 2, 0, 1 ,0]           [0, 2, 0, 0 ,0]\n      [2, 2, 0, 0 ,2]           [2, 2, 0, 0 ,0]\n\n    Use with applied_labels=[1, 2], is_onehot=False, independent=False, connectivity=2::\n\n      [0, 0, 1, 0 ,0]           [0, 0, 1, 0 ,0]\n      [0, 2, 1, 1 ,1]           [0, 2, 1, 1 ,1]\n      [1, 2, 1, 0 ,0]    =>     [1, 2, 1, 0 ,0]\n      [1, 2, 0, 1 ,0]           [1, 2, 0, 1 ,0]\n      [2, 2, 0, 0 ,2]           [2, 2, 0, 0 ,2]\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY, 
TransformBackends.CUPY]\n\n    def __init__(\n        self,\n        applied_labels: Sequence[int] | int | None = None,\n        is_onehot: bool | None = None,\n        independent: bool = True,\n        connectivity: int | None = None,\n        num_components: int = 1,\n    ) -> None:\n        \"\"\"\n        Args:\n            applied_labels: Labels for applying the connected component analysis on.\n                If given, voxels whose value is in this list will be analyzed.\n                If `None`, all non-zero values will be analyzed.\n            is_onehot: if `True`, treat the input data as OneHot format data, otherwise, not OneHot format data.\n                default to None, which treats multi-channel data as OneHot and single channel data as not OneHot.\n            independent: whether to treat ``applied_labels`` as a union of foreground labels.\n                If ``True``, the connected component analysis will be performed on each foreground label independently\n                and return the intersection of the largest components.\n                If ``False``, the analysis will be performed on the union of foreground labels.\n                default is `True`.\n            connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.\n                Accepted values are ranging from  1 to input.ndim. If ``None``, a full\n                connectivity of ``input.ndim`` is used. 
for more details:\n                https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.\n            num_components: The number of largest components to preserve.\n\n        \"\"\"\n        super().__init__()\n        self.applied_labels = ensure_tuple(applied_labels) if applied_labels is not None else None\n        self.is_onehot = is_onehot\n        self.independent = independent\n        self.connectivity = connectivity\n        self.num_components = num_components\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: shape must be (C, spatial_dim1[, spatial_dim2, ...]).\n\n        Returns:\n            An array with shape (C, spatial_dim1[, spatial_dim2, ...]).\n        \"\"\"\n        is_onehot = img.shape[0] > 1 if self.is_onehot is None else self.is_onehot\n        if self.applied_labels is not None:\n            applied_labels = self.applied_labels\n        else:\n            applied_labels = tuple(get_unique_labels(img, is_onehot, discard=0))\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_: torch.Tensor = convert_to_tensor(img, track_meta=False)\n        if self.independent:\n            for i in applied_labels:\n                foreground = img_[i] > 0 if is_onehot else img_[0] == i\n                mask = get_largest_connected_component_mask(foreground, self.connectivity, self.num_components)\n                if is_onehot:\n                    img_[i][foreground != mask] = 0\n                else:\n                    img_[0][foreground != mask] = 0\n            return convert_to_dst_type(img_, dst=img)[0]\n        if not is_onehot:  # not one-hot, union of labels\n            labels, *_ = convert_to_dst_type(applied_labels, dst=img_, wrap_sequence=True)\n            foreground = (img_[..., None] == labels).any(-1)[0]\n            mask = get_largest_connected_component_mask(foreground, self.connectivity, self.num_components)\n      
      img_[0][foreground != mask] = 0\n            return convert_to_dst_type(img_, dst=img)[0]\n        # one-hot, union of labels\n        foreground = (img_[applied_labels, ...] == 1).any(0)\n        mask = get_largest_connected_component_mask(foreground, self.connectivity, self.num_components)\n        for i in applied_labels:\n            img_[i][foreground != mask] = 0\n        return convert_to_dst_type(img_, dst=img)[0]\n\n\nclass RemoveSmallObjects(Transform):\n    \"\"\"\n    Use `skimage.morphology.remove_small_objects` to remove small objects from images.\n    See: https://scikit-image.org/docs/dev/api/skimage.morphology.html#remove-small-objects.\n\n    Data should be one-hotted.\n\n    Args:\n        min_size: objects smaller than this size (in number of voxels; or surface area/volume value\n            in whatever units your image is if by_measure is True) are removed.\n        connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.\n            Accepted values are ranging from  1 to input.ndim. If ``None``, a full\n            connectivity of ``input.ndim`` is used. For more details refer to linked scikit-image\n            documentation.\n        independent_channels: Whether or not to consider channels as independent. If true, then\n            conjoining islands from different labels will be removed if they are below the threshold.\n            If false, the overall size islands made from all non-background voxels will be used.\n        by_measure: Whether the specified min_size is in number of voxels. if this is True then min_size\n            represents a surface area or volume value of whatever units your image is in (mm^3, cm^2, etc.)\n            default is False. e.g. if min_size is 3, by_measure is True and the units of your data is mm,\n            objects smaller than 3mm^3 are removed.\n        pixdim: the pixdim of the input image. 
if a single number, this is used for all axes.\n            If a sequence of numbers, the length of the sequence must be equal to the image dimensions.\n\n    Example::\n\n        .. code-block:: python\n\n            from monai.transforms import RemoveSmallObjects, Spacing, Compose\n            from monai.data import MetaTensor\n\n            data1 = torch.tensor([[[0, 0, 0, 0, 0], [0, 1, 1, 0, 1], [0, 0, 0, 1, 1]]])\n            affine = torch.as_tensor([[2,0,0,0],\n                                      [0,1,0,0],\n                                      [0,0,1,0],\n                                      [0,0,0,1]], dtype=torch.float64)\n            data2 = MetaTensor(data1, affine=affine)\n\n            # remove objects smaller than 3mm^3, input is MetaTensor\n            trans = RemoveSmallObjects(min_size=3, by_measure=True)\n            out = trans(data2)\n            # remove objects smaller than 3mm^3, input is not MetaTensor\n            trans = RemoveSmallObjects(min_size=3, by_measure=True, pixdim=(2, 1, 1))\n            out = trans(data1)\n\n            # remove objects smaller than 3 (in pixel)\n            trans = RemoveSmallObjects(min_size=3)\n            out = trans(data2)\n\n            # If the affine of the data is not identity, you can also add Spacing before.\n            trans = Compose([\n                Spacing(pixdim=(1, 1, 1)),\n                RemoveSmallObjects(min_size=3)\n            ])\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        min_size: int = 64,\n        connectivity: int = 1,\n        independent_channels: bool = True,\n        by_measure: bool = False,\n        pixdim: Sequence[float] | float | np.ndarray | None = None,\n    ) -> None:\n        self.min_size = min_size\n        self.connectivity = connectivity\n        self.independent_channels = independent_channels\n        self.by_measure = by_measure\n        self.pixdim = pixdim\n\n    def __call__(self, img: 
NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: shape must be (C, spatial_dim1[, spatial_dim2, ...]). Data\n                should be one-hotted.\n\n        Returns:\n            An array with shape (C, spatial_dim1[, spatial_dim2, ...]).\n        \"\"\"\n\n        return remove_small_objects(\n            img, self.min_size, self.connectivity, self.independent_channels, self.by_measure, self.pixdim\n        )\n\n\nclass LabelFilter(Transform):\n    \"\"\"\n    This transform filters out labels and can be used as a processing step to view only certain labels.\n\n    The list of applied labels defines which labels will be kept.\n\n    Note:\n        All labels which do not match the `applied_labels` are set to the background label (0).\n\n    For example:\n\n    Use LabelFilter with applied_labels=[1, 5, 9]::\n\n        [1, 2, 3]         [1, 0, 0]\n        [4, 5, 6]    =>   [0, 5 ,0]\n        [7, 8, 9]         [0, 0, 9]\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, applied_labels: Iterable[int] | int) -> None:\n        \"\"\"\n        Initialize the LabelFilter class with the labels to filter on.\n\n        Args:\n            applied_labels: Label(s) to filter on.\n        \"\"\"\n        self.applied_labels = ensure_tuple(applied_labels)\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Filter the image on the `applied_labels`.\n\n        Args:\n            img: Pytorch tensor or numpy array of any shape.\n\n        Raises:\n            NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.\n\n        Returns:\n            Pytorch tensor or numpy array of the same shape as the input.\n        \"\"\"\n        if not isinstance(img, (np.ndarray, torch.Tensor)):\n            raise NotImplementedError(f\"{self.__class__} can not handle data of type {type(img)}.\")\n\n        if isinstance(img, 
torch.Tensor):\n            img = convert_to_tensor(img, track_meta=get_track_meta())\n            img_ = convert_to_tensor(img, track_meta=False)\n            if hasattr(torch, \"isin\"):  # `isin` is new in torch 1.10.0\n                appl_lbls = torch.as_tensor(self.applied_labels, device=img_.device)\n                out = torch.where(torch.isin(img_, appl_lbls), img_, torch.tensor(0.0).to(img_))\n                return convert_to_dst_type(out, dst=img)[0]\n            out: NdarrayOrTensor = self(img_.detach().cpu().numpy())  # type: ignore\n            out = convert_to_dst_type(out, img)[0]  # type: ignore\n            return out\n        return np.asarray(np.where(np.isin(img, self.applied_labels), img, 0))\n\n\nclass FillHoles(Transform):\n    r\"\"\"\n    This transform fills holes in the image and can be used to remove artifacts inside segments.\n\n    An enclosed hole is defined as a background pixel/voxel which is only enclosed by a single class.\n    The definition of enclosed can be defined with the connectivity parameter::\n\n        1-connectivity     2-connectivity     diagonal connection close-up\n\n             [ ]           [ ]  [ ]  [ ]             [ ]\n              |               \\  |  /                 |  <- hop 2\n        [ ]--[x]--[ ]      [ ]--[x]--[ ]        [x]--[ ]\n              |               /  |  \\             hop 1\n             [ ]           [ ]  [ ]  [ ]\n\n    It is possible to define for which labels the hole filling should be applied.\n    The input image is assumed to be a PyTorch Tensor or numpy array with shape [C, spatial_dim1[, spatial_dim2, ...]].\n    If C = 1, then the values correspond to expected labels.\n    If C > 1, then a one-hot-encoding is expected where the index of C matches the label indexing.\n\n    Note:\n\n        The label 0 will be treated as background and the enclosed holes will be set to the neighboring class label.\n\n        The performance of this method heavily depends on the number of 
labels.\n        It is a bit faster if the list of `applied_labels` is provided.\n        Limiting the number of `applied_labels` results in a big decrease in processing time.\n\n    For example:\n\n        Use FillHoles with default parameters::\n\n            [1, 1, 1, 2, 2, 2, 3, 3]         [1, 1, 1, 2, 2, 2, 3, 3]\n            [1, 0, 1, 2, 0, 0, 3, 0]    =>   [1, 1 ,1, 2, 0, 0, 3, 0]\n            [1, 1, 1, 2, 2, 2, 3, 3]         [1, 1, 1, 2, 2, 2, 3, 3]\n\n        The hole in label 1 is fully enclosed and therefore filled with label 1.\n        The background label near label 2 and 3 is not fully enclosed and therefore not filled.\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(self, applied_labels: Iterable[int] | int | None = None, connectivity: int | None = None) -> None:\n        \"\"\"\n        Initialize the connectivity and limit the labels for which holes are filled.\n\n        Args:\n            applied_labels: Labels for which to fill holes. Defaults to None, that is filling holes for all labels.\n            connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.\n                Accepted values are ranging from  1 to input.ndim. 
Defaults to a full connectivity of ``input.ndim``.\n        \"\"\"\n        super().__init__()\n        self.applied_labels = ensure_tuple(applied_labels) if applied_labels else None\n        self.connectivity = connectivity\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Fill the holes in the provided image.\n\n        Note:\n            The value 0 is assumed as background label.\n\n        Args:\n            img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].\n\n        Raises:\n            NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.\n\n        Returns:\n            Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_np, *_ = convert_data_type(img, np.ndarray)\n        out_np: np.ndarray = fill_holes(img_np, self.applied_labels, self.connectivity)\n        out, *_ = convert_to_dst_type(out_np, img)\n        return out\n\n\nclass LabelToContour(Transform):\n    \"\"\"\n    Return the contour of binary input images that only compose of 0 and 1, with Laplacian kernel\n    set as default for edge detection. 
Typical usage is to plot the edge of label or segmentation output.\n\n    Args:\n        kernel_type: the method applied to do edge detection, default is \"Laplace\".\n\n    Raises:\n        NotImplementedError: When ``kernel_type`` is not \"Laplace\".\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, kernel_type: str = \"Laplace\") -> None:\n        if kernel_type != \"Laplace\":\n            raise NotImplementedError('Currently only kernel_type=\"Laplace\" is supported.')\n        self.kernel_type = kernel_type\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: torch tensor data to extract the contour, with shape: [channels, height, width[, depth]]\n\n        Raises:\n            ValueError: When ``image`` ndim is not one of [3, 4].\n\n        Returns:\n            A torch tensor with the same shape as img, note:\n                1. it's the binary classification result of whether a pixel is edge or not.\n                2. in order to keep the original shape of mask image, we use padding as default.\n                3. 
the edge detection is just approximate because it defects inherent to Laplace kernel,\n                   ideally the edge should be thin enough, but now it has a thickness.\n\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_: torch.Tensor = convert_to_tensor(img, track_meta=False)\n        spatial_dims = len(img_.shape) - 1\n        img_ = img_.unsqueeze(0)  # adds a batch dim\n        if spatial_dims == 2:\n            kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32)\n        elif spatial_dims == 3:\n            kernel = -1.0 * torch.ones(3, 3, 3, dtype=torch.float32)\n            kernel[1, 1, 1] = 26.0\n        else:\n            raise ValueError(f\"{self.__class__} can only handle 2D or 3D images.\")\n        contour_img = apply_filter(img_, kernel)\n        contour_img.clamp_(min=0.0, max=1.0)\n        output, *_ = convert_to_dst_type(contour_img.squeeze(0), img)\n        return output\n\n\nclass Ensemble:\n\n    @staticmethod\n    def get_stacked_torch(img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> torch.Tensor:\n        \"\"\"Get either a sequence or single instance of np.ndarray/torch.Tensor. 
Return single torch.Tensor.\"\"\"\n        if isinstance(img, Sequence) and isinstance(img[0], np.ndarray):\n            img = [torch.as_tensor(i) for i in img]\n        elif isinstance(img, np.ndarray):\n            img = torch.as_tensor(img)\n        out: torch.Tensor = torch.stack(img) if isinstance(img, Sequence) else img  # type: ignore\n        return out\n\n    @staticmethod\n    def post_convert(img: torch.Tensor, orig_img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayOrTensor:\n        orig_img_ = orig_img[0] if isinstance(orig_img, Sequence) else orig_img\n        out, *_ = convert_to_dst_type(img, orig_img_)\n        return out\n\n\nclass MeanEnsemble(Ensemble, Transform):\n    \"\"\"\n    Execute mean ensemble on the input data.\n    The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]],\n    Or a single PyTorch Tensor with shape: [E, C[, H, W, D]], the `E` dimension represents\n    the output data from different models.\n    Typically, the input data is model output of segmentation task or classification task.\n    And it also can support to add `weights` for the input data.\n\n    Args:\n        weights: can be a list or tuple of numbers for input data with shape: [E, C, H, W[, D]].\n            or a Numpy ndarray or a PyTorch Tensor data.\n            the `weights` will be added to input data from highest dimension, for example:\n            1. if the `weights` only has 1 dimension, it will be added to the `E` dimension of input data.\n            2. 
if the `weights` has 2 dimensions, it will be added to `E` and `C` dimensions.\n            it's a typical practice to add weights for different classes:\n            to ensemble 3 segmentation model outputs, every output has 4 channels(classes),\n            so the input data shape can be: [3, 4, H, W, D].\n            and add different `weights` for different classes, so the `weights` shape can be: [3, 4].\n            for example: `weights = [[1, 2, 3, 4], [4, 3, 2, 1], [1, 1, 1, 1]]`.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, weights: Sequence[float] | NdarrayOrTensor | None = None) -> None:\n        self.weights = torch.as_tensor(weights, dtype=torch.float) if weights is not None else None\n\n    def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayOrTensor:\n        img_ = self.get_stacked_torch(img)\n        if self.weights is not None:\n            self.weights = self.weights.to(img_.device)\n            shape = tuple(self.weights.shape)\n            for _ in range(img_.ndimension() - self.weights.ndimension()):\n                shape += (1,)\n            weights = self.weights.reshape(*shape)\n\n            img_ = img_ * weights / weights.mean(dim=0, keepdim=True)\n\n        out_pt = torch.mean(img_, dim=0)\n        return self.post_convert(out_pt, img)\n\n\nclass VoteEnsemble(Ensemble, Transform):\n    \"\"\"\n    Execute vote ensemble on the input data.\n    The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]],\n    Or a single PyTorch Tensor with shape: [E[, C, H, W, D]], the `E` dimension represents\n    the output data from different models.\n    Typically, the input data is model output of segmentation task or classification task.\n\n    Note:\n        This vote transform expects the input data is discrete values. It can be multiple channels\n        data in One-Hot format or single channel data. 
It will vote to select the most common data\n        between items.\n        The output data has the same shape as every item of the input data.\n\n    Args:\n        num_classes: if the input is single channel data instead of One-Hot, we can't get class number\n            from channel, need to explicitly specify the number of classes to vote.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, num_classes: int | None = None) -> None:\n        self.num_classes = num_classes\n\n    def __call__(self, img: Sequence[NdarrayOrTensor] | NdarrayOrTensor) -> NdarrayOrTensor:\n        img_ = self.get_stacked_torch(img)\n\n        if self.num_classes is not None:\n            has_ch_dim = True\n            if img_.ndimension() > 1 and img_.shape[1] > 1:\n                warnings.warn(\"no need to specify num_classes for One-Hot format data.\")\n            else:\n                if img_.ndimension() == 1:\n                    # if no channel dim, need to remove channel dim after voting\n                    has_ch_dim = False\n                img_ = one_hot(img_, self.num_classes, dim=1)\n\n        img_ = torch.mean(img_.float(), dim=0)\n\n        if self.num_classes is not None:\n            # if not One-Hot, use \"argmax\" to vote the most common class\n            out_pt = torch.argmax(img_, dim=0, keepdim=has_ch_dim)\n        else:\n            # for One-Hot data, round the float number to 0 or 1\n            out_pt = torch.round(img_)\n        return self.post_convert(out_pt, img)\n\n\nclass GenerateHeatmap(Transform):\n    \"\"\"\n    Generate per-landmark Gaussian heatmaps for 2D or 3D coordinates.\n\n    Notes:\n        - Coordinates are interpreted in voxel units and expected in (Y, X) for 2D or (Z, Y, X) for 3D.\n        - Target spatial_shape is (Y, X) for 2D and (Z, Y, X) for 3D.\n        - Output layout uses channel-first convention with one channel per landmark.\n        - Input points shape: (N, D) where N is number of 
landmarks, D is spatial dimensions (2 or 3).\n        - Output heatmap shape: (N, Y, X) for 2D or (N, Z, Y, X) for 3D.\n        - Each channel index corresponds to one landmark.\n\n    Args:\n        sigma: gaussian standard deviation. A single value is broadcast across all spatial dimensions.\n        spatial_shape: optional fallback spatial shape. If ``None`` it must be provided when calling the transform.\n        truncated: extent, in multiples of ``sigma``, used to crop the gaussian support window.\n        normalize: normalize every heatmap channel to ``[0, 1]`` when ``True``.\n        dtype: target dtype for the generated heatmaps (accepts numpy or torch dtypes).\n\n    Raises:\n        ValueError: when ``sigma`` is non-positive or ``spatial_shape`` cannot be resolved.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        sigma: Sequence[float] | float = 5.0,\n        spatial_shape: Sequence[int] | None = None,\n        truncated: float = 4.0,\n        normalize: bool = True,\n        dtype: np.dtype | torch.dtype | type = np.float32,\n    ) -> None:\n        if isinstance(sigma, Sequence) and not isinstance(sigma, (str, bytes)):\n            if any(s <= 0 for s in sigma):\n                raise ValueError(\"Argument `sigma` values must be positive.\")\n            self._sigma = tuple(float(s) for s in sigma)\n        else:\n            if float(sigma) <= 0:\n                raise ValueError(\"Argument `sigma` must be positive.\")\n            self._sigma = (float(sigma),)\n        if truncated <= 0:\n            raise ValueError(\"Argument `truncated` must be positive.\")\n        self.truncated = float(truncated)\n        self.normalize = normalize\n        self.torch_dtype = get_equivalent_dtype(dtype, torch.Tensor)\n        self.numpy_dtype = get_equivalent_dtype(dtype, np.ndarray)\n        # Validate that dtype is floating-point for meaningful Gaussian values\n        if not 
self.torch_dtype.is_floating_point:\n            raise ValueError(f\"Argument `dtype` must be a floating-point type, got {self.torch_dtype}\")\n        self.spatial_shape = None if spatial_shape is None else tuple(int(s) for s in spatial_shape)\n\n    def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None = None) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            points: landmark coordinates as ndarray/Tensor with shape (N, D),\n                ordered as (Y, X) for 2D or (Z, Y, X) for 3D, where N is the number\n                of landmarks and D is the spatial dimensionality.\n            spatial_shape: spatial size as a sequence. If None, uses the value provided at construction.\n\n        Returns:\n            Heatmaps with shape (N, *spatial), one channel per landmark.\n\n        Raises:\n            ValueError: if points shape/dimension or spatial_shape is invalid.\n        \"\"\"\n        original_points = points\n        points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)\n\n        if points_t.ndim != 2:\n            raise ValueError(\n                f\"Argument `points` must be a 2D array with shape (num_points, spatial_dims), got shape {points_t.shape}.\"\n            )\n\n        if points_t.shape[-1] not in (2, 3):\n            raise ValueError(\"GenerateHeatmap only supports 2D or 3D landmarks.\")\n\n        device = points_t.device\n        num_points, spatial_dims = points_t.shape\n\n        target_shape = self._resolve_spatial_shape(spatial_shape, spatial_dims)\n        sigma = self._resolve_sigma(spatial_dims)\n\n        # Create sparse image with impulses at landmark locations\n        heatmap = torch.zeros((num_points, *target_shape), dtype=self.torch_dtype, device=device)\n        bounds_t = torch.as_tensor(target_shape, device=device, dtype=points_t.dtype)\n\n        for idx, center in enumerate(points_t):\n            if not torch.isfinite(center).all():\n                continue\n 
           if not ((center >= 0).all() and (center < bounds_t).all()):\n                continue\n            # Round to nearest integer for impulse placement, then clamp to valid index range\n            center_int = center.round().long()\n            # Clamp indices to [0, size-1] to avoid out-of-bounds (e.g., 9.7 rounds to 10 in size-10 array)\n            bounds_max = (bounds_t - 1).long()\n            center_int = torch.minimum(torch.maximum(center_int, torch.zeros_like(center_int)), bounds_max)\n            # Place impulse (use maximum in case of overlapping landmarks)\n            current_val = heatmap[idx][tuple(center_int)]\n            heatmap[idx][tuple(center_int)] = torch.maximum(\n                current_val, torch.tensor(1.0, dtype=self.torch_dtype, device=device)\n            )\n\n        # Apply Gaussian blur using GaussianFilter\n        # Reshape to (num_points, 1, *spatial) for per-channel filtering\n        heatmap_input = heatmap.unsqueeze(1)  # Add channel dimension\n\n        gaussian_filter = GaussianFilter(\n            spatial_dims=spatial_dims, sigma=sigma, truncated=self.truncated, approx=\"erf\", requires_grad=False\n        ).to(device=device, dtype=self.torch_dtype)\n\n        heatmap_blurred = gaussian_filter(heatmap_input)\n        heatmap = heatmap_blurred.squeeze(1)  # Remove channel dimension\n\n        # Normalize per channel if requested\n        if self.normalize:\n            for idx in range(num_points):\n                peak = heatmap[idx].amax()\n                if peak > 0:\n                    heatmap[idx].div_(peak)\n\n        target_dtype = self.torch_dtype if isinstance(original_points, (torch.Tensor, MetaTensor)) else self.numpy_dtype\n        converted, _, _ = convert_to_dst_type(heatmap, original_points, dtype=target_dtype)\n        return converted\n\n    def _resolve_spatial_shape(self, call_shape: Sequence[int] | None, spatial_dims: int) -> tuple[int, ...]:\n        shape = call_shape if call_shape is not None 
else self.spatial_shape\n        if shape is None:\n            raise ValueError(\"Argument `spatial_shape` must be provided either at construction time or call time.\")\n        shape_tuple = ensure_tuple(shape)\n        if len(shape_tuple) != spatial_dims:\n            if len(shape_tuple) == 1:\n                shape_tuple = shape_tuple * spatial_dims  # type: ignore\n            else:\n                raise ValueError(\n                    \"Argument `spatial_shape` length must match the landmarks' spatial dims (or pass a single int to broadcast).\"\n                )\n        return tuple(int(s) for s in shape_tuple)\n\n    def _resolve_sigma(self, spatial_dims: int) -> tuple[float, ...]:\n        if len(self._sigma) == spatial_dims:\n            return self._sigma\n        if len(self._sigma) == 1:\n            return self._sigma * spatial_dims\n        raise ValueError(\"Argument `sigma` sequence length must equal the number of spatial dimensions.\")\n\n\nclass ProbNMS(Transform):\n    \"\"\"\n    Performs probability based non-maximum suppression (NMS) on the probabilities map via\n    iteratively selecting the coordinate with highest probability and then move it as well\n    as its surrounding values. The remove range is determined by the parameter `box_size`.\n    If multiple coordinates have the same highest probability, only one of them will be\n    selected.\n\n    Args:\n        spatial_dims: number of spatial dimensions of the input probabilities map.\n            Defaults to 2.\n        sigma: the standard deviation for gaussian filter.\n            It could be a single value, or `spatial_dims` number of values. Defaults to 0.0.\n        prob_threshold: the probability threshold, the function will stop searching if\n            the highest probability is no larger than the threshold. The value should be\n            no less than 0.0. 
Defaults to 0.5.\n        box_size: the box size (in pixel) to be removed around the pixel with the maximum probability.\n            It can be an integer that defines the size of a square or cube,\n            or a list containing different values for each dimensions. Defaults to 48.\n\n    Return:\n        a list of selected lists, where inner lists contain probability and coordinates.\n        For example, for 3D input, the inner lists are in the form of [probability, x, y, z].\n\n    Raises:\n        ValueError: When ``prob_threshold`` is less than 0.0.\n        ValueError: When ``box_size`` is a list or tuple, and its length is not equal to `spatial_dims`.\n        ValueError: When ``box_size`` has a less than 1 value.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        spatial_dims: int = 2,\n        sigma: Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor = 0.0,\n        prob_threshold: float = 0.5,\n        box_size: int | Sequence[int] = 48,\n    ) -> None:\n        self.sigma = sigma\n        self.spatial_dims = spatial_dims\n        if self.sigma != 0:\n            self.filter = GaussianFilter(spatial_dims=spatial_dims, sigma=sigma)\n        if prob_threshold < 0:\n            raise ValueError(\"prob_threshold should be no less than 0.0.\")\n        self.prob_threshold = prob_threshold\n        if isinstance(box_size, int):\n            self.box_size = np.asarray([box_size] * spatial_dims)\n        elif len(box_size) != spatial_dims:\n            raise ValueError(\"the sequence length of box_size should be the same as spatial_dims.\")\n        else:\n            self.box_size = np.asarray(box_size)\n        if self.box_size.min() <= 0:\n            raise ValueError(\"box_size should be larger than 0.\")\n\n        self.box_lower_bd = self.box_size // 2\n        self.box_upper_bd = self.box_size - self.box_lower_bd\n\n    def __call__(self, prob_map: NdarrayOrTensor):\n        \"\"\"\n        
prob_map: the input probabilities map, it must have shape (H[, W, ...]).\n        \"\"\"\n        if self.sigma != 0:\n            if not isinstance(prob_map, torch.Tensor):\n                prob_map = torch.as_tensor(prob_map, dtype=torch.float)\n            self.filter.to(prob_map.device)\n            prob_map = self.filter(prob_map)\n\n        prob_map_shape = prob_map.shape\n\n        outputs = []\n        while prob_map.max() > self.prob_threshold:\n            max_idx = unravel_index(prob_map.argmax(), prob_map_shape)\n            prob_max = prob_map[tuple(max_idx)]\n            max_idx = max_idx.cpu().numpy() if isinstance(max_idx, torch.Tensor) else max_idx\n            prob_max = prob_max.item() if isinstance(prob_max, torch.Tensor) else prob_max\n            outputs.append([prob_max] + list(max_idx))\n\n            idx_min_range = (max_idx - self.box_lower_bd).clip(0, None)\n            idx_max_range = (max_idx + self.box_upper_bd).clip(None, prob_map_shape)\n            # for each dimension, set values during index ranges to 0\n            slices = tuple(slice(idx_min_range[i], idx_max_range[i]) for i in range(self.spatial_dims))\n            prob_map[slices] = 0\n\n        return outputs\n\n\nclass Invert(Transform):\n    \"\"\"\n    Utility transform to automatically invert the previously applied transforms.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        transform: InvertibleTransform | None = None,\n        nearest_interp: bool | Sequence[bool] = True,\n        device: str | torch.device | None = None,\n        post_func: Callable | None = None,\n        to_tensor: bool | Sequence[bool] = True,\n    ) -> None:\n        \"\"\"\n        Args:\n            transform: the previously applied transform.\n            nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,\n                default to `True`. 
If `False`, use the same interpolation mode as the original transform.\n            device: move the inverted results to a target device before `post_func`, default to `None`.\n            post_func: postprocessing for the inverted result, should be a callable function.\n            to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.\n        \"\"\"\n        if not isinstance(transform, InvertibleTransform):\n            raise ValueError(\"transform is not invertible, can't invert transform for the data.\")\n        self.transform = transform\n        self.nearest_interp = nearest_interp\n        self.device = device\n        self.post_func = post_func\n        self.to_tensor = to_tensor\n        self._totensor = ToTensor()\n\n    def __call__(self, data):\n        if not isinstance(data, MetaTensor):\n            return data\n\n        if self.nearest_interp:\n            data.applied_operations = convert_applied_interp_mode(\n                trans_info=data.applied_operations, mode=\"nearest\", align_corners=None\n            )\n\n        data = data.detach()\n        inverted = self.transform.inverse(data)\n        if self.to_tensor and not isinstance(inverted, MetaTensor):\n            inverted = self._totensor(inverted)\n        if isinstance(inverted, torch.Tensor):\n            inverted = inverted.to(device=self.device)\n        if callable(self.post_func):\n            inverted = self.post_func(inverted)\n        return inverted\n\n\nclass SobelGradients(Transform):\n    \"\"\"Calculate Sobel gradients of a grayscale image with the shape of CxH[xWxDx...] or BxH[xWxDx...].\n\n    Args:\n        kernel_size: the size of the Sobel kernel. Defaults to 3.\n        spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient\n            along each of the provide axis. 
By default it calculate the gradient for all spatial axes.\n        normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True.\n        normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False.\n        padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `\"reflect\"`.\n            Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.\n            See ``torch.nn.Conv1d()`` for more information.\n        dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        kernel_size: int = 3,\n        spatial_axes: Sequence[int] | int | None = None,\n        normalize_kernels: bool = True,\n        normalize_gradients: bool = False,\n        padding_mode: str = \"reflect\",\n        dtype: torch.dtype = torch.float32,\n    ) -> None:\n        super().__init__()\n        self.padding = padding_mode\n        self.spatial_axes = spatial_axes\n        self.normalize_kernels = normalize_kernels\n        self.normalize_gradients = normalize_gradients\n        self.kernel_diff, self.kernel_smooth = self._get_kernel(kernel_size, dtype)\n\n    def _get_kernel(self, size, dtype) -> tuple[torch.Tensor, torch.Tensor]:\n        if size < 3:\n            raise ValueError(f\"Sobel kernel size should be at least three. {size} was given.\")\n        if size % 2 == 0:\n            raise ValueError(f\"Sobel kernel size should be an odd number. 
{size} was given.\")\n\n        kernel_diff = torch.tensor([[[-1, 0, 1]]], dtype=dtype)\n        kernel_smooth = torch.tensor([[[1, 2, 1]]], dtype=dtype)\n        kernel_expansion = torch.tensor([[[1, 2, 1]]], dtype=dtype)\n\n        if self.normalize_kernels:\n            if not dtype.is_floating_point:\n                raise ValueError(\n                    f\"`dtype` for Sobel kernel should be floating point when `normalize_kernel==True`. {dtype} was given.\"\n                )\n            kernel_diff /= 2.0\n            kernel_smooth /= 4.0\n            kernel_expansion /= 4.0\n\n        # Expand the kernel to larger size than 3\n        expand = (size - 3) // 2\n        for _ in range(expand):\n            kernel_diff = F.conv1d(kernel_diff, kernel_expansion, padding=2)\n            kernel_smooth = F.conv1d(kernel_smooth, kernel_expansion, padding=2)\n\n        return kernel_diff.squeeze(), kernel_smooth.squeeze()\n\n    def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:\n        image_tensor = convert_to_tensor(image, track_meta=get_track_meta())\n\n        # Check/set spatial axes\n        n_spatial_dims = image_tensor.ndim - 1  # excluding the channel dimension\n        valid_spatial_axes = list(range(n_spatial_dims)) + list(range(-n_spatial_dims, 0))\n\n        # Check gradient axes to be valid\n        if self.spatial_axes is None:\n            spatial_axes = list(range(n_spatial_dims))\n        else:\n            invalid_axis = set(ensure_tuple(self.spatial_axes)) - set(valid_spatial_axes)\n            if invalid_axis:\n                raise ValueError(\n                    f\"The provide axes to calculate gradient is not valid: {invalid_axis}. 
\"\n                    f\"The image has {n_spatial_dims} spatial dimensions so it should be: {valid_spatial_axes}.\"\n                )\n            spatial_axes = [ax % n_spatial_dims if ax < 0 else ax for ax in ensure_tuple(self.spatial_axes)]\n\n        # Add batch dimension for separable_filtering\n        image_tensor = image_tensor.unsqueeze(0)\n\n        # Get the Sobel kernels\n        kernel_diff = self.kernel_diff.to(image_tensor.device)\n        kernel_smooth = self.kernel_smooth.to(image_tensor.device)\n\n        # Calculate gradient\n        grad_list = []\n        for ax in spatial_axes:\n            kernels = [kernel_smooth] * n_spatial_dims\n            kernels[ax] = kernel_diff\n            grad = separable_filtering(image_tensor, kernels, mode=self.padding)\n            if self.normalize_gradients:\n                grad_min = grad.min()\n                if grad_min != grad.max():\n                    grad -= grad_min\n                grad_max = grad.max()\n                if grad_max > 0:\n                    grad /= grad_max\n            grad_list.append(grad)\n\n        grads = torch.cat(grad_list, dim=1)\n\n        # Remove batch dimension and convert the gradient type to be the same as input image\n        grads = convert_to_dst_type(grads.squeeze(0), image_tensor)[0]\n\n        return grads\n\n\nclass DistanceTransformEDT(Transform):\n    \"\"\"\n    Applies the Euclidean distance transform on the input.\n    Either GPU based with CuPy / cuCIM or CPU based with scipy.\n    To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.\n\n    Note that the results of the libraries can differ, so stick to one if possible.\n    For details, check out the `SciPy`_ and `cuCIM`_ documentation and / or :func:`monai.transforms.utils.distance_transform_edt`.\n\n    .. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html\n    .. 
_cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY, TransformBackends.CUPY]\n\n    def __init__(self, sampling: None | float | list[float] = None) -> None:\n        super().__init__()\n        self.sampling = sampling\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: Input image on which the distance transform shall be run.\n                Has to be a channel first array, must have shape: (num_channels, H, W [,D]).\n                Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.\n                Input gets passed channel-wise to the distance-transform, thus results from this function will differ\n                from directly calling ``distance_transform_edt()`` in CuPy or SciPy.\n            sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;\n                if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.\n\n        Returns:\n            An array with the same shape and data type as img\n        \"\"\"\n        return distance_transform_edt(img=img, sampling=self.sampling)  # type: ignore\n"
  },
  {
    "path": "monai/transforms/post/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of dictionary-based wrappers around the \"vanilla\" transforms for model output tensors\ndefined in :py:class:`monai.transforms.utility.array`.\n\nClass names are ended with 'd' to denote dictionary-based transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable, Hashable, Iterable, Mapping, Sequence\nfrom copy import deepcopy\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai import config\nfrom monai.config.type_definitions import KeysCollection, NdarrayOrTensor, PathLike\nfrom monai.data.csv_saver import CSVSaver\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms.inverse import InvertibleTransform\nfrom monai.transforms.post.array import (\n    Activations,\n    AsDiscrete,\n    DistanceTransformEDT,\n    FillHoles,\n    GenerateHeatmap,\n    KeepLargestConnectedComponent,\n    LabelFilter,\n    LabelToContour,\n    MeanEnsemble,\n    ProbNMS,\n    RemoveSmallObjects,\n    SobelGradients,\n    VoteEnsemble,\n)\nfrom monai.transforms.transform import MapTransform\nfrom monai.transforms.utility.array import ToTensor\nfrom monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode\nfrom monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep\nfrom monai.utils.type_conversion import 
convert_to_dst_type\n\n__all__ = [\n    \"ActivationsD\",\n    \"ActivationsDict\",\n    \"Activationsd\",\n    \"AsDiscreteD\",\n    \"AsDiscreteDict\",\n    \"AsDiscreted\",\n    \"Ensembled\",\n    \"EnsembleD\",\n    \"EnsembleDict\",\n    \"FillHolesD\",\n    \"FillHolesDict\",\n    \"FillHolesd\",\n    \"InvertD\",\n    \"InvertDict\",\n    \"Invertd\",\n    \"KeepLargestConnectedComponentD\",\n    \"KeepLargestConnectedComponentDict\",\n    \"KeepLargestConnectedComponentd\",\n    \"RemoveSmallObjectsD\",\n    \"RemoveSmallObjectsDict\",\n    \"RemoveSmallObjectsd\",\n    \"LabelFilterD\",\n    \"LabelFilterDict\",\n    \"LabelFilterd\",\n    \"LabelToContourD\",\n    \"LabelToContourDict\",\n    \"LabelToContourd\",\n    \"MeanEnsembleD\",\n    \"MeanEnsembleDict\",\n    \"MeanEnsembled\",\n    \"ProbNMSD\",\n    \"ProbNMSDict\",\n    \"ProbNMSd\",\n    \"SaveClassificationD\",\n    \"SaveClassificationDict\",\n    \"SaveClassificationd\",\n    \"SobelGradientsD\",\n    \"SobelGradientsDict\",\n    \"SobelGradientsd\",\n    \"VoteEnsembleD\",\n    \"VoteEnsembleDict\",\n    \"VoteEnsembled\",\n    \"DistanceTransformEDTd\",\n    \"DistanceTransformEDTD\",\n    \"DistanceTransformEDTDict\",\n    \"GenerateHeatmapd\",\n    \"GenerateHeatmapD\",\n    \"GenerateHeatmapDict\",\n]\n\nDEFAULT_POST_FIX = PostFix.meta()\n\n\nclass Activationsd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Activations`.\n    Add activation layers to the input data specified by `keys`.\n    \"\"\"\n\n    backend = Activations.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        sigmoid: Sequence[bool] | bool = False,\n        softmax: Sequence[bool] | bool = False,\n        other: Sequence[Callable] | Callable | None = None,\n        allow_missing_keys: bool = False,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to model output and label.\n 
               See also: :py:class:`monai.transforms.compose.MapTransform`\n            sigmoid: whether to execute sigmoid function on model output before transform.\n                it also can be a sequence of bool, each element corresponds to a key in ``keys``.\n            softmax: whether to execute softmax function on model output before transform.\n                it also can be a sequence of bool, each element corresponds to a key in ``keys``.\n            other: callable function to execute other activation layers,\n                for example: `other = torch.tanh`. it also can be a sequence of Callable, each\n                element corresponds to a key in ``keys``.\n            allow_missing_keys: don't raise exception if key is missing.\n            kwargs: additional parameters to `torch.softmax` (used when ``softmax=True``).\n                Defaults to ``dim=0``, unrecognized parameters will be ignored.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.sigmoid = ensure_tuple_rep(sigmoid, len(self.keys))\n        self.softmax = ensure_tuple_rep(softmax, len(self.keys))\n        self.other = ensure_tuple_rep(other, len(self.keys))\n        self.converter = Activations()\n        self.converter.kwargs = kwargs\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key, sigmoid, softmax, other in self.key_iterator(d, self.sigmoid, self.softmax, self.other):\n            d[key] = self.converter(d[key], sigmoid, softmax, other)\n        return d\n\n\nclass AsDiscreted(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.AsDiscrete`.\n    \"\"\"\n\n    backend = AsDiscrete.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        argmax: Sequence[bool] | bool = False,\n        to_onehot: Sequence[int | None] | int | None = None,\n        threshold: Sequence[float | None] | float 
| None = None,\n        rounding: Sequence[str | None] | str | None = None,\n        allow_missing_keys: bool = False,\n        **kwargs,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to model output and label.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            argmax: whether to execute argmax function on input data before transform.\n                it also can be a sequence of bool, each element corresponds to a key in ``keys``.\n            to_onehot: if not None, convert input data into the one-hot format with specified number of classes.\n                defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``.\n            threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.\n                defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``.\n            rounding: if not None, round the data according to the specified option,\n                available options: [\"torchrounding\"]. 
it also can be a sequence of str or None,\n                each element corresponds to a key in ``keys``.\n            allow_missing_keys: don't raise exception if key is missing.\n            kwargs: additional parameters to ``AsDiscrete``.\n                ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored.\n                These default to ``0``, ``True``, ``torch.float`` respectively.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.argmax = ensure_tuple_rep(argmax, len(self.keys))\n        self.to_onehot = []\n        for flag in ensure_tuple_rep(to_onehot, len(self.keys)):\n            if isinstance(flag, bool):\n                raise ValueError(\"`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.\")\n            self.to_onehot.append(flag)\n\n        self.threshold = []\n        for flag in ensure_tuple_rep(threshold, len(self.keys)):\n            if isinstance(flag, bool):\n                raise ValueError(\"`threshold_values=True/False` is deprecated, please use `threshold=value` instead.\")\n            self.threshold.append(flag)\n\n        self.rounding = ensure_tuple_rep(rounding, len(self.keys))\n        self.converter = AsDiscrete()\n        self.converter.kwargs = kwargs\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key, argmax, to_onehot, threshold, rounding in self.key_iterator(\n            d, self.argmax, self.to_onehot, self.threshold, self.rounding\n        ):\n            d[key] = self.converter(d[key], argmax, to_onehot, threshold, rounding)\n        return d\n\n\nclass KeepLargestConnectedComponentd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.KeepLargestConnectedComponent`.\n    \"\"\"\n\n    backend = KeepLargestConnectedComponent.backend\n\n    def __init__(\n        self,\n        keys: 
KeysCollection,\n        applied_labels: Sequence[int] | int | None = None,\n        is_onehot: bool | None = None,\n        independent: bool = True,\n        connectivity: int | None = None,\n        num_components: int = 1,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            applied_labels: Labels for applying the connected component analysis on.\n                If given, voxels whose value is in this list will be analyzed.\n                If `None`, all non-zero values will be analyzed.\n            is_onehot: if `True`, treat the input data as OneHot format data, otherwise, not OneHot format data.\n                default to None, which treats multi-channel data as OneHot and single channel data as not OneHot.\n            independent: whether to treat ``applied_labels`` as a union of foreground labels.\n                If ``True``, the connected component analysis will be performed on each foreground label independently\n                and return the intersection of the largest components.\n                If ``False``, the analysis will be performed on the union of foreground labels.\n                default is `True`.\n            connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.\n                Accepted values are ranging from  1 to input.ndim. If ``None``, a full\n                connectivity of ``input.ndim`` is used. 
for more details:\n                https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.\n            num_components: The number of largest components to preserve.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.converter = KeepLargestConnectedComponent(\n            applied_labels=applied_labels,\n            is_onehot=is_onehot,\n            independent=independent,\n            connectivity=connectivity,\n            num_components=num_components,\n        )\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass RemoveSmallObjectsd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.RemoveSmallObjects`.\n\n    Args:\n        min_size: objects smaller than this size (in number of voxels; or surface area/volume value\n            in whatever units your image is if by_measure is True) are removed.\n        connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.\n            Accepted values are ranging from  1 to input.ndim. If ``None``, a full\n            connectivity of ``input.ndim`` is used. For more details refer to linked scikit-image\n            documentation.\n        independent_channels: Whether or not to consider channels as independent. If true, then\n            conjoining islands from different labels will be removed if they are below the threshold.\n            If false, the overall size islands made from all non-background voxels will be used.\n        by_measure: Whether the specified min_size is in number of voxels. 
if this is True then min_size\n            represents a surface area or volume value of whatever units your image is in (mm^3, cm^2, etc.)\n            default is False. e.g. if min_size is 3, by_measure is True and the units of your data is mm,\n            objects smaller than 3mm^3 are removed.\n        pixdim: the pixdim of the input image. if a single number, this is used for all axes.\n            If a sequence of numbers, the length of the sequence must be equal to the image dimensions.\n    \"\"\"\n\n    backend = RemoveSmallObjects.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        min_size: int = 64,\n        connectivity: int = 1,\n        independent_channels: bool = True,\n        by_measure: bool = False,\n        pixdim: Sequence[float] | float | np.ndarray | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.converter = RemoveSmallObjects(min_size, connectivity, independent_channels, by_measure, pixdim)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass LabelFilterd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.LabelFilter`.\n    \"\"\"\n\n    backend = LabelFilter.backend\n\n    def __init__(\n        self, keys: KeysCollection, applied_labels: Sequence[int] | int, allow_missing_keys: bool = False\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            applied_labels: Label(s) to filter on.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        
self.converter = LabelFilter(applied_labels)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass FillHolesd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.FillHoles`.\n    \"\"\"\n\n    backend = FillHoles.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        applied_labels: Iterable[int] | int | None = None,\n        connectivity: int | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Initialize the connectivity and limit the labels for which holes are filled.\n\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            applied_labels (Optional[Union[Iterable[int], int]], optional): Labels for which to fill holes. Defaults to None,\n                that is filling holes for all labels.\n            connectivity (int, optional): Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.\n                Accepted values are ranging from  1 to input.ndim. 
Defaults to a full\n                connectivity of ``input.ndim``.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.converter = FillHoles(applied_labels=applied_labels, connectivity=connectivity)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass LabelToContourd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.LabelToContour`.\n    \"\"\"\n\n    backend = LabelToContour.backend\n\n    def __init__(self, keys: KeysCollection, kernel_type: str = \"Laplace\", allow_missing_keys: bool = False) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            kernel_type: the method applied to do edge detection, default is \"Laplace\".\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.converter = LabelToContour(kernel_type=kernel_type)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass Ensembled(MapTransform):\n    \"\"\"\n    Base class of dictionary-based ensemble transforms.\n\n    \"\"\"\n\n    backend = list(set(VoteEnsemble.backend) & set(MeanEnsemble.backend))\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        ensemble: Callable[[Sequence[NdarrayOrTensor] | NdarrayOrTensor], NdarrayOrTensor],\n        output_key: str | None = None,\n        allow_missing_keys: bool = False,\n    
) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be stack and execute ensemble.\n                if only 1 key provided, suppose it's a PyTorch Tensor with data stacked on dimension `E`.\n            ensemble: callable method to execute ensemble on specified data.\n            output_key: the key to store ensemble result in the dictionary.\n                if only 1 key provided in `keys`, `output_key` can be None and use `keys` as default.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        Raises:\n            TypeError: When ``ensemble`` is not ``callable``.\n            ValueError: When ``len(keys) > 1`` and ``output_key=None``. Incompatible values.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        if not callable(ensemble):\n            raise TypeError(f\"ensemble must be callable but is {type(ensemble).__name__}.\")\n        self.ensemble = ensemble\n        if len(self.keys) > 1 and output_key is None:\n            raise ValueError(\"Incompatible values: len(self.keys) > 1 and output_key=None.\")\n        self.output_key = output_key if output_key is not None else self.keys[0]\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        items: list[NdarrayOrTensor] | NdarrayOrTensor\n        if len(self.keys) == 1 and self.keys[0] in d:\n            items = d[self.keys[0]]\n        else:\n            items = [d[key] for key in self.key_iterator(d)]\n\n        if len(items) > 0:\n            d[self.output_key] = self.ensemble(items)\n\n        return d\n\n\nclass MeanEnsembled(Ensembled):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.MeanEnsemble`.\n    \"\"\"\n\n    backend = MeanEnsemble.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        output_key: str | None = None,\n        weights: Sequence[float] | 
NdarrayOrTensor | None = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be stack and execute ensemble.\n                if only 1 key provided, suppose it's a PyTorch Tensor with data stacked on dimension `E`.\n            output_key: the key to store ensemble result in the dictionary.\n                if only 1 key provided in `keys`, `output_key` can be None and use `keys` as default.\n            weights: can be a list or tuple of numbers for input data with shape: [E, C, H, W[, D]].\n                or a Numpy ndarray or a PyTorch Tensor data.\n                the `weights` will be added to input data from highest dimension, for example:\n                1. if the `weights` only has 1 dimension, it will be added to the `E` dimension of input data.\n                2. if the `weights` has 2 dimensions, it will be added to `E` and `C` dimensions.\n                it's a typical practice to add weights for different classes:\n                to ensemble 3 segmentation model outputs, every output has 4 channels(classes),\n                so the input data shape can be: [3, 4, H, W, D].\n                and add different `weights` for different classes, so the `weights` shape can be: [3, 4].\n                for example: `weights = [[1, 2, 3, 4], [4, 3, 2, 1], [1, 1, 1, 1]]`.\n\n        \"\"\"\n        ensemble = MeanEnsemble(weights=weights)\n        super().__init__(keys, ensemble, output_key)\n\n\nclass VoteEnsembled(Ensembled):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.VoteEnsemble`.\n    \"\"\"\n\n    backend = VoteEnsemble.backend\n\n    def __init__(self, keys: KeysCollection, output_key: str | None = None, num_classes: int | None = None) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be stack and execute ensemble.\n                if only 1 key provided, suppose it's a PyTorch Tensor with data stacked on dimension `E`.\n   
         output_key: the key to store ensemble result in the dictionary.\n                if only 1 key provided in `keys`, `output_key` can be None and use `keys` as default.\n            num_classes: if the input is single channel data instead of One-Hot, we can't get class number\n                from channel, need to explicitly specify the number of classes to vote.\n\n        \"\"\"\n        ensemble = VoteEnsemble(num_classes=num_classes)\n        super().__init__(keys, ensemble, output_key)\n\n\nclass GenerateHeatmapd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.GenerateHeatmap`.\n    Converts landmark coordinates into gaussian heatmaps and optionally copies metadata from a reference image.\n\n    Args:\n        keys: keys of the corresponding items in the dictionary, where each key references a tensor\n            of landmark point coordinates with shape (N, D), where N is the number of landmarks\n            and D is the spatial dimensionality (2 or 3).\n        sigma: standard deviation for the Gaussian kernel. Can be a single value or a sequence matching the number\n            of spatial dimensions.\n        heatmap_keys: keys to store output heatmaps. Default: \"{key}_heatmap\" for each key.\n        ref_image_keys: keys of reference images to inherit spatial metadata from. When provided, heatmaps will\n            have the same shape, affine, and spatial metadata as the reference images.\n        spatial_shape: spatial dimensions of output heatmaps. Can be:\n            - Single shape (tuple): applied to all keys\n            - List of shapes: one per key (must match keys length)\n        truncated: truncation distance for Gaussian kernel computation (in sigmas).\n        normalize: if True, normalize each heatmap's peak value to 1.0.\n        dtype: output data type for heatmaps. 
Defaults to np.float32.\n        allow_missing_keys: if True, don't raise error if some keys are missing in data.\n\n    Returns:\n        Dictionary with original data plus generated heatmaps at specified keys.\n\n    Raises:\n        ValueError: If heatmap_keys/ref_image_keys length doesn't match keys length.\n        ValueError: If no spatial shape can be determined (need spatial_shape or ref_image_keys).\n        ValueError: If input points have invalid shape (must be 2D array with shape (N, D)).\n\n    Example:\n        .. code-block:: python\n\n            import numpy as np\n            from monai.transforms import GenerateHeatmapd\n\n            # Create sample data with landmark points and a reference image\n            data = {\n                \"landmarks\": np.array([[10.0, 15.0], [20.0, 25.0]]),  # 2 points in 2D\n                \"image\": np.zeros((32, 32))  # reference image\n            }\n\n            # Transform with reference image\n            transform = GenerateHeatmapd(\n                keys=\"landmarks\",\n                sigma=2.0,\n                ref_image_keys=\"image\"\n            )\n            result = transform(data)\n            # result[\"landmarks_heatmap\"] has shape (2, 32, 32) - one channel per landmark\n\n            # Or with explicit spatial_shape\n            transform = GenerateHeatmapd(\n                keys=\"landmarks\",\n                sigma=2.0,\n                spatial_shape=(64, 64)\n            )\n            result = transform(data)\n            # result[\"landmarks_heatmap\"] has shape (2, 64, 64)\n\n    Notes:\n        - Default heatmap_keys are generated as \"{key}_heatmap\" for each input key\n        - Shape inference precedence: static spatial_shape > ref_image\n        - Input points shape: (N, D) where N is number of landmarks, D is spatial dimensions\n        - Output heatmap shape: (N, H, W) for 2D or (N, H, W, D) for 3D\n        - When using ref_image_keys, heatmaps inherit affine and spatial 
metadata from reference\n    \"\"\"\n\n    backend = GenerateHeatmap.backend\n\n    # Error messages\n    _ERR_HEATMAP_KEYS_LEN = \"Argument `heatmap_keys` length must match keys length.\"\n    _ERR_REF_KEYS_LEN = \"Argument `ref_image_keys` length must match keys length when provided.\"\n    _ERR_SHAPE_LEN = \"Argument `spatial_shape` length must match keys length when providing per-key shapes.\"\n    _ERR_NO_SHAPE = \"Unable to determine spatial shape for GenerateHeatmapd. Provide spatial_shape or ref_image_keys.\"\n    _ERR_INVALID_POINTS = \"Landmark arrays must be 2D with shape (N, D).\"\n    _ERR_REF_NO_SHAPE = \"Reference data must define a shape attribute.\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        sigma: Sequence[float] | float = 5.0,\n        heatmap_keys: KeysCollection | None = None,\n        ref_image_keys: KeysCollection | None = None,\n        spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None = None,\n        truncated: float = 4.0,\n        normalize: bool = True,\n        dtype: np.dtype | torch.dtype | type = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.heatmap_keys = self._prepare_heatmap_keys(heatmap_keys)\n        self.ref_image_keys = self._prepare_optional_keys(ref_image_keys)\n        self.static_shapes = self._prepare_shapes(spatial_shape)\n        self.generator = GenerateHeatmap(\n            sigma=sigma, spatial_shape=None, truncated=truncated, normalize=normalize, dtype=dtype\n        )\n\n    def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:\n        d = dict(data)\n        for key, out_key, ref_key, static_shape in self.key_iterator(\n            d, self.heatmap_keys, self.ref_image_keys, self.static_shapes\n        ):\n            points = d[key]\n            shape = self._determine_shape(points, static_shape, d, ref_key)\n            # The GenerateHeatmap transform will 
handle type conversion based on input points\n            heatmap = self.generator(points, spatial_shape=shape)\n            # If there's a reference image and we need to match its type/device\n            reference = d.get(ref_key) if ref_key is not None and ref_key in d else None\n            if reference is not None and isinstance(reference, (torch.Tensor, np.ndarray)):\n                # Convert to match reference type and device while preserving heatmap's dtype\n                heatmap, _, _ = convert_to_dst_type(\n                    heatmap, reference, dtype=heatmap.dtype, device=getattr(reference, \"device\", None)\n                )\n                # Copy metadata if reference is MetaTensor\n                if isinstance(reference, MetaTensor) and isinstance(heatmap, MetaTensor):\n                    heatmap.affine = reference.affine\n                    self._update_spatial_metadata(heatmap, shape)\n            d[out_key] = heatmap\n        return d\n\n    def _prepare_heatmap_keys(self, heatmap_keys: KeysCollection | None) -> tuple[Hashable, ...]:\n        if heatmap_keys is None:\n            return tuple(f\"{key}_heatmap\" for key in self.keys)\n        keys_tuple = ensure_tuple(heatmap_keys)\n        if len(keys_tuple) == 1 and len(self.keys) > 1:\n            keys_tuple = keys_tuple * len(self.keys)\n        if len(keys_tuple) != len(self.keys):\n            raise ValueError(self._ERR_HEATMAP_KEYS_LEN)\n        return keys_tuple\n\n    def _prepare_optional_keys(self, maybe_keys: KeysCollection | None) -> tuple[Hashable | None, ...]:\n        if maybe_keys is None:\n            return (None,) * len(self.keys)\n        keys_tuple = ensure_tuple(maybe_keys)\n        if len(keys_tuple) == 1 and len(self.keys) > 1:\n            keys_tuple = keys_tuple * len(self.keys)\n        if len(keys_tuple) != len(self.keys):\n            raise ValueError(self._ERR_REF_KEYS_LEN)\n        return tuple(keys_tuple)\n\n    def _prepare_shapes(\n        self, 
spatial_shape: Sequence[int] | Sequence[Sequence[int]] | None\n    ) -> tuple[tuple[int, ...] | None, ...]:\n        if spatial_shape is None:\n            return (None,) * len(self.keys)\n        shape_tuple = ensure_tuple(spatial_shape)\n        if shape_tuple and all(isinstance(v, (int, np.integer)) for v in shape_tuple):\n            shape = tuple(int(v) for v in shape_tuple)\n            return (shape,) * len(self.keys)\n        if len(shape_tuple) == 1 and len(self.keys) > 1:\n            shape_tuple = shape_tuple * len(self.keys)\n        if len(shape_tuple) != len(self.keys):\n            raise ValueError(self._ERR_SHAPE_LEN)\n        prepared: list[tuple[int, ...] | None] = []\n        for item in shape_tuple:\n            if item is None:\n                prepared.append(None)\n            else:\n                dims = ensure_tuple(item)\n                prepared.append(tuple(int(v) for v in dims))\n        return tuple(prepared)\n\n    def _determine_shape(\n        self, points: Any, static_shape: tuple[int, ...] 
| None, data: Mapping[Hashable, Any], ref_key: Hashable | None\n    ) -> tuple[int, ...]:\n        points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)\n        if points_t.ndim != 2:\n            raise ValueError(f\"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.\")\n        spatial_dims = int(points_t.shape[-1])\n        if static_shape is not None:\n            if len(static_shape) == 1 and spatial_dims > 1:\n                static_shape = tuple([static_shape[0]] * spatial_dims)\n            if len(static_shape) != spatial_dims:\n                raise ValueError(\n                    f\"Provided static spatial_shape has {len(static_shape)} dims; expected {spatial_dims}.\"\n                )\n            return static_shape\n        if ref_key is not None and ref_key in data:\n            return self._shape_from_reference(data[ref_key], spatial_dims)\n        raise ValueError(self._ERR_NO_SHAPE)\n\n    def _shape_from_reference(self, reference: Any, spatial_dims: int) -> tuple[int, ...]:\n        if isinstance(reference, MetaTensor):\n            meta_shape = reference.meta.get(\"spatial_shape\")\n            if meta_shape is not None:\n                dims = ensure_tuple(meta_shape)\n                if len(dims) == spatial_dims:\n                    return tuple(int(v) for v in dims)\n            return tuple(int(v) for v in reference.shape[-spatial_dims:])\n        if hasattr(reference, \"shape\"):\n            return tuple(int(v) for v in reference.shape[-spatial_dims:])\n        raise ValueError(self._ERR_REF_NO_SHAPE)\n\n    def _update_spatial_metadata(self, heatmap: MetaTensor, spatial_shape: tuple[int, ...]) -> None:\n        \"\"\"Set spatial_shape explicitly from resolved shape.\"\"\"\n        heatmap.meta[\"spatial_shape\"] = tuple(int(v) for v in spatial_shape)\n\n\nGenerateHeatmapD = GenerateHeatmapDict = GenerateHeatmapd\n\n\nclass ProbNMSd(MapTransform):\n    \"\"\"\n    Performs probability based non-maximum 
suppression (NMS) on the probabilities map via\n    iteratively selecting the coordinate with highest probability and then move it as well\n    as its surrounding values. The remove range is determined by the parameter `box_size`.\n    If multiple coordinates have the same highest probability, only one of them will be\n    selected.\n\n    Args:\n        spatial_dims: number of spatial dimensions of the input probabilities map.\n            Defaults to 2.\n        sigma: the standard deviation for gaussian filter.\n            It could be a single value, or `spatial_dims` number of values. Defaults to 0.0.\n        prob_threshold: the probability threshold, the function will stop searching if\n            the highest probability is no larger than the threshold. The value should be\n            no less than 0.0. Defaults to 0.5.\n        box_size: the box size (in pixel) to be removed around the pixel with the maximum probability.\n            It can be an integer that defines the size of a square or cube,\n            or a list containing different values for each dimensions. 
Defaults to 48.\n\n    Return:\n        a list of selected lists, where inner lists contain probability and coordinates.\n        For example, for 3D input, the inner lists are in the form of [probability, x, y, z].\n\n    Raises:\n        ValueError: When ``prob_threshold`` is less than 0.0.\n        ValueError: When ``box_size`` is a list or tuple, and its length is not equal to `spatial_dims`.\n        ValueError: When ``box_size`` has a less than 1 value.\n\n    \"\"\"\n\n    backend = ProbNMS.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        spatial_dims: int = 2,\n        sigma: Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor = 0.0,\n        prob_threshold: float = 0.5,\n        box_size: int | Sequence[int] = 48,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.prob_nms = ProbNMS(\n            spatial_dims=spatial_dims, sigma=sigma, prob_threshold=prob_threshold, box_size=box_size\n        )\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]):\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.prob_nms(d[key])\n        return d\n\n\nclass Invertd(MapTransform):\n    \"\"\"\n    Utility transform to invert the previously applied transforms.\n\n    Taking the ``transform`` previously applied on ``orig_keys``, this ``Invertd`` will apply the inverse of it\n    to the data stored at ``keys``.\n\n    ``Invertd``'s output will also include a copy of the metadata\n    dictionary (originally from  ``orig_meta_keys`` or the metadata of ``orig_keys``),\n    with the relevant fields inverted and stored at ``meta_keys``.\n\n    A typical usage is to apply the inverse of the preprocessing (``transform=preprocessings``) on\n    input ``orig_keys=image`` to the model predictions ``keys=pred``.\n\n    A detailed usage example is available in the tutorial:\n    
https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_inference_dict.py\n\n    Note:\n\n        - The output of the inverted data and metadata will be stored at ``keys`` and ``meta_keys`` respectively.\n        - To correctly invert the transforms, the information of the previously applied transforms should be\n          available at ``{orig_keys}_transforms``, and the original metadata at ``orig_meta_keys``.\n          (``meta_key_postfix`` is an optional string to conveniently construct \"meta_keys\" and/or \"orig_meta_keys\".)\n          see also: :py:class:`monai.transforms.TraceableTransform`.\n        - The transform will not change the content in ``orig_keys`` and ``orig_meta_key``.\n          These keys are only used to represent the data status of ``key`` before inverting.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        transform: InvertibleTransform,\n        orig_keys: KeysCollection | None = None,\n        meta_keys: KeysCollection | None = None,\n        orig_meta_keys: KeysCollection | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        nearest_interp: bool | Sequence[bool] = True,\n        to_tensor: bool | Sequence[bool] = True,\n        device: str | torch.device | Sequence[str | torch.device] | None = None,\n        post_func: Callable | Sequence[Callable] | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: the key of expected data in the dict, the inverse of ``transforms`` will be applied on it in-place.\n                It also can be a list of keys, will apply the inverse transform respectively.\n            transform: the transform applied to ``orig_key``, its inverse will be applied on ``key``.\n            orig_keys: the key of the original input data in the dict. 
These keys default to `self.keys` if not set.\n                the transform trace information of ``transforms`` should be stored at ``{orig_keys}_transforms``.\n                It can also be a list of keys, each matches the ``keys``.\n            meta_keys: The key to output the inverted metadata dictionary.\n                The metadata is a dictionary optionally containing: filename, original_shape.\n                It can be a sequence of strings, maps to ``keys``.\n                If None, will try to create a metadata dict with the default key: `{key}_{meta_key_postfix}`.\n            orig_meta_keys: the key of the metadata of original input data.\n                The metadata is a dictionary optionally containing: filename, original_shape.\n                It can be a sequence of strings, maps to the `keys`.\n                If None, will try to create a metadata dict with the default key: `{orig_key}_{meta_key_postfix}`.\n                This metadata dict will also be included in the inverted dict, stored in `meta_keys`.\n            meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to fetch the\n                metadata from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`. Default: ``\"meta_dict\"``.\n            nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,\n                default to `True`. 
If `False`, use the same interpolation mode as the original transform.\n                It also can be a list of bool, each matches to the `keys` data.\n            to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.\n                It also can be a list of bool, each matches to the `keys` data.\n            device: if converted to Tensor, move the inverted results to target device before `post_func`,\n                default to None, it also can be a list of string or `torch.device`, each matches to the `keys` data.\n            post_func: post processing for the inverted data, should be a callable function.\n                It also can be a list of callable, each matches to the `keys` data.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        if not isinstance(transform, InvertibleTransform):\n            raise ValueError(\"transform is not invertible, can't invert transform for the data.\")\n        self.transform = transform\n        self.orig_keys = ensure_tuple_rep(orig_keys, len(self.keys)) if orig_keys is not None else self.keys\n        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)\n        if len(self.keys) != len(self.meta_keys):\n            raise ValueError(\"meta_keys should have the same length as keys.\")\n        self.orig_meta_keys = ensure_tuple_rep(orig_meta_keys, len(self.keys))\n        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))\n        self.nearest_interp = ensure_tuple_rep(nearest_interp, len(self.keys))\n        self.to_tensor = ensure_tuple_rep(to_tensor, len(self.keys))\n        self.device = ensure_tuple_rep(device, len(self.keys))\n        self.post_func = ensure_tuple_rep(post_func, len(self.keys))\n        self._totensor = ToTensor()\n\n    def __call__(self, data: Mapping[Hashable, Any]) -> 
dict[Hashable, Any]:\n        d = dict(data)\n        for (\n            key,\n            orig_key,\n            meta_key,\n            orig_meta_key,\n            meta_key_postfix,\n            nearest_interp,\n            to_tensor,\n            device,\n            post_func,\n        ) in self.key_iterator(\n            d,\n            self.orig_keys,\n            self.meta_keys,\n            self.orig_meta_keys,\n            self.meta_key_postfix,\n            self.nearest_interp,\n            self.to_tensor,\n            self.device,\n            self.post_func,\n        ):\n            if isinstance(d[key], MetaTensor):\n                if orig_key not in d:\n                    warnings.warn(f\"transform info of `{orig_key}` is not available in MetaTensor {key}.\")\n                    continue\n            else:\n                transform_key = InvertibleTransform.trace_key(orig_key)\n                if transform_key not in d:\n                    warnings.warn(f\"transform info of `{orig_key}` is not available or no InvertibleTransform applied.\")\n                    continue\n\n            orig_meta_key = orig_meta_key or f\"{orig_key}_{meta_key_postfix}\"\n            if orig_key in d and isinstance(d[orig_key], MetaTensor):\n                transform_info = d[orig_key].applied_operations\n                meta_info = d[orig_key].meta\n            else:\n                transform_info = d[InvertibleTransform.trace_key(orig_key)]\n                meta_info = d.get(orig_meta_key, {})\n            if nearest_interp:\n                transform_info = convert_applied_interp_mode(\n                    trans_info=transform_info, mode=\"nearest\", align_corners=None\n                )\n\n            inputs = d[key]\n            if isinstance(inputs, torch.Tensor):\n                inputs = inputs.detach()\n\n            if not isinstance(inputs, MetaTensor):\n                inputs = convert_to_tensor(inputs, track_meta=True)\n            
inputs.applied_operations = deepcopy(transform_info)\n            inputs.meta = deepcopy(meta_info)\n\n            # construct the input dict data\n            input_dict = {orig_key: inputs}\n            if config.USE_META_DICT:\n                input_dict[InvertibleTransform.trace_key(orig_key)] = transform_info\n                input_dict[PostFix.meta(orig_key)] = meta_info\n            with allow_missing_keys_mode(self.transform):  # type: ignore\n                inverted = self.transform.inverse(input_dict)\n\n            # save the inverted data\n            inverted_data = inverted[orig_key]\n            if to_tensor and not isinstance(inverted_data, MetaTensor):\n                inverted_data = self._totensor(inverted_data)\n            if isinstance(inverted_data, np.ndarray) and device is not None and torch.device(device).type != \"cpu\":\n                raise ValueError(f\"Inverted data with type of 'numpy.ndarray' support device='cpu', got {device}.\")\n            if isinstance(inverted_data, torch.Tensor):\n                inverted_data = inverted_data.to(device=device)\n            d[key] = post_func(inverted_data) if callable(post_func) else inverted_data\n            # save the invertd applied_operations if it's in the source dict\n            if InvertibleTransform.trace_key(orig_key) in d:\n                d[InvertibleTransform.trace_key(orig_key)] = inverted_data.applied_operations\n            # save the inverted meta dict if it's in the source dict\n            if orig_meta_key in d:\n                meta_key = meta_key or f\"{key}_{meta_key_postfix}\"\n                d[meta_key] = inverted.get(orig_meta_key)\n        return d\n\n\nclass SaveClassificationd(MapTransform):\n    \"\"\"\n    Save the classification results and metadata into CSV file or other storage.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        meta_keys: KeysCollection | None = None,\n        meta_key_postfix: str = 
DEFAULT_POST_FIX,\n        saver: CSVSaver | None = None,\n        output_dir: PathLike = \"./\",\n        filename: str = \"predictions.csv\",\n        delimiter: str = \",\",\n        overwrite: bool = True,\n        flush: bool = True,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to model output, this transform only supports 1 key.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            meta_keys: explicitly indicate the key of the corresponding metadata dictionary.\n                for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n                the metadata is a dictionary object which contains: filename, original_shape, etc.\n                it can be a sequence of string, map to the `keys`.\n                if None, will try to construct meta_keys by `key_{meta_key_postfix}`.\n                will extract the filename of input image to save classification results.\n            meta_key_postfix: `key_{postfix}` was used to store the metadata in `LoadImaged`.\n                so need the key to extract the metadata of input image, like filename, etc. default is `meta_dict`.\n                for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n                the metadata is a dictionary object which contains: filename, original_shape, etc.\n                this arg only works when `meta_keys=None`. 
if no corresponding metadata, set to `None`.\n            saver: the saver instance to save classification results, if None, create a CSVSaver internally.\n                the saver must provide `save(data, meta_data)` and `finalize()` APIs.\n            output_dir: if `saver=None`, specify the directory to save the CSV file.\n            filename: if `saver=None`, specify the name of the saved CSV file.\n            delimiter: the delimiter character in the saved file, default to \",\" as the default output type is `csv`.\n                to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.\n            overwrite: if `saver=None`, indicate whether to overwriting existing CSV file content, if True,\n                will clear the file before saving. otherwise, will append new content to the CSV file.\n            flush: if `saver=None`, indicate whether to write the cache data to CSV file immediately\n                in this transform and clear the cache. 
default to True.\n                If False, may need user to call `saver.finalize()` manually or use `ClassificationSaver` handler.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        if len(self.keys) != 1:\n            raise ValueError(\"only 1 key is allowed when saving the classification result.\")\n        self.saver = saver or CSVSaver(\n            output_dir=output_dir, filename=filename, overwrite=overwrite, flush=flush, delimiter=delimiter\n        )\n        self.flush = flush\n        self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))\n        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))\n\n    def __call__(self, data):\n        d = dict(data)\n        for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):\n            if meta_key is None and meta_key_postfix is not None:\n                meta_key = f\"{key}_{meta_key_postfix}\"\n            meta_data = d[meta_key] if meta_key is not None else None\n            self.saver.save(data=d[key], meta_data=meta_data)\n            if self.flush:\n                self.saver.finalize()\n\n        return d\n\n    def get_saver(self):\n        \"\"\"\n        If want to write content into file, may need to call `finalize` of saver when epoch completed.\n        Or users can also get the cache content from `saver` instead of writing into file.\n\n        \"\"\"\n        return self.saver\n\n\nclass SobelGradientsd(MapTransform):\n    \"\"\"Calculate Sobel horizontal and vertical gradients of a grayscale image.\n\n    Args:\n        keys: keys of the corresponding items to model output.\n        kernel_size: the size of the Sobel kernel. Defaults to 3.\n        spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient\n            along each of the provide axis. 
By default it calculate the gradient for all spatial axes.\n        normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True.\n        normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False.\n        padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `\"reflect\"`.\n            Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.\n            See ``torch.nn.Conv1d()`` for more information.\n        dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.\n        new_key_prefix: this prefix be prepended to the key to create a new key for the output and keep the value of\n            key intact. By default not prefix is set and the corresponding array to the key will be replaced.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = SobelGradients.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        kernel_size: int = 3,\n        spatial_axes: Sequence[int] | int | None = None,\n        normalize_kernels: bool = True,\n        normalize_gradients: bool = False,\n        padding_mode: str = \"reflect\",\n        dtype: torch.dtype = torch.float32,\n        new_key_prefix: str | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.transform = SobelGradients(\n            kernel_size=kernel_size,\n            spatial_axes=spatial_axes,\n            normalize_kernels=normalize_kernels,\n            normalize_gradients=normalize_gradients,\n            padding_mode=padding_mode,\n            dtype=dtype,\n        )\n        self.new_key_prefix = new_key_prefix\n        self.kernel_diff = self.transform.kernel_diff\n        self.kernel_smooth = self.transform.kernel_smooth\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, 
NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            new_key = key if self.new_key_prefix is None else self.new_key_prefix + key\n            d[new_key] = self.transform(d[key])\n\n        return d\n\n\nclass DistanceTransformEDTd(MapTransform):\n    \"\"\"\n    Applies the Euclidean distance transform on the input.\n    Either GPU based with CuPy / cuCIM or CPU based with scipy.\n    To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.\n\n    Note that the results of the libraries can differ, so stick to one if possible.\n    For details, check out the `SciPy`_ and `cuCIM`_ documentation and / or :func:`monai.transforms.utils.distance_transform_edt`.\n\n\n    Note on the input shape:\n        Has to be a channel first array, must have shape: (num_channels, H, W [,D]).\n        Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.\n        Input gets passed channel-wise to the distance-transform, thus results from this function will differ\n        from directly calling ``distance_transform_edt()`` in CuPy or SciPy.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        allow_missing_keys: don't raise exception if key is missing.\n        sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;\n            if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.\n\n    .. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html\n    .. 
_cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt\n\n\n    \"\"\"\n\n    backend = DistanceTransformEDT.backend\n\n    def __init__(\n        self, keys: KeysCollection, allow_missing_keys: bool = False, sampling: None | float | list[float] = None\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.sampling = sampling\n        self.distance_transform = DistanceTransformEDT(sampling=self.sampling)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.distance_transform(img=d[key])\n\n        return d\n\n\nActivationsD = ActivationsDict = Activationsd\nAsDiscreteD = AsDiscreteDict = AsDiscreted\nFillHolesD = FillHolesDict = FillHolesd\nInvertD = InvertDict = Invertd\nKeepLargestConnectedComponentD = KeepLargestConnectedComponentDict = KeepLargestConnectedComponentd\nRemoveSmallObjectsD = RemoveSmallObjectsDict = RemoveSmallObjectsd\nLabelFilterD = LabelFilterDict = LabelFilterd\nLabelToContourD = LabelToContourDict = LabelToContourd\nMeanEnsembleD = MeanEnsembleDict = MeanEnsembled\nProbNMSD = ProbNMSDict = ProbNMSd\nSaveClassificationD = SaveClassificationDict = SaveClassificationd\nVoteEnsembleD = VoteEnsembleDict = VoteEnsembled\nEnsembleD = EnsembleDict = Ensembled\nSobelGradientsD = SobelGradientsDict = SobelGradientsd\nDistanceTransformEDTD = DistanceTransformEDTDict = DistanceTransformEDTd\n"
  },
  {
    "path": "monai/transforms/regularization/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/transforms/regularization/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom abc import abstractmethod\nfrom math import ceil, sqrt\n\nimport torch\n\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.transforms.transform import RandomizableTransform\nfrom monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor\n\n__all__ = [\"MixUp\", \"CutMix\", \"CutOut\", \"Mixer\"]\n\n\nclass Mixer(RandomizableTransform):\n\n    def __init__(self, batch_size: int, alpha: float = 1.0) -> None:\n        \"\"\"\n        Mixer is a base class providing the basic logic for the mixup-class of\n        augmentations. In all cases, we need to sample the mixing weights for each\n        sample (lambda in the notation used in the papers). Also, pairs of samples\n        being mixed are picked by randomly shuffling the batch samples.\n\n        Args:\n            batch_size (int): number of samples per batch. That is, samples are expected tp\n                be of size batchsize x channels [x depth] x height x width.\n            alpha (float, optional): mixing weights are sampled from the Beta(alpha, alpha)\n                distribution. 
Defaults to 1.0, the uniform distribution.\n        \"\"\"\n        super().__init__()\n        if alpha <= 0:\n            raise ValueError(f\"Expected positive number, but got {alpha=}\")\n        self.alpha = alpha\n        self.batch_size = batch_size\n\n    @abstractmethod\n    def apply(self, data: torch.Tensor):\n        raise NotImplementedError()\n\n    def randomize(self, data=None) -> None:\n        \"\"\"\n        Sometimes you need may to apply the same transform to different tensors.\n        The idea is to get a sample and then apply it with apply() as often\n        as needed. You need to call this method everytime you apply the transform to a new\n        batch.\n        \"\"\"\n        super().randomize(None)\n        self._params = (\n            torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32),\n            self.R.permutation(self.batch_size),\n            [torch.from_numpy(self.R.randint(0, d, size=(1,))) for d in data.shape[2:]] if data is not None else [],\n        )\n\n\nclass MixUp(Mixer):\n    \"\"\"MixUp as described in:\n    Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.\n    mixup: Beyond Empirical Risk Minimization, ICLR 2018\n\n    Class derived from :py:class:`monai.transforms.Mixer`. 
See corresponding\n    documentation for details on the constructor parameters.\n    \"\"\"\n\n    def apply(self, data: torch.Tensor):\n        weight, perm, _ = self._params\n        nsamples, *dims = data.shape\n        if len(weight) != nsamples:\n            raise ValueError(f\"Expected batch of size: {len(weight)}, but got {nsamples}\")\n\n        if len(dims) not in [3, 4]:\n            raise ValueError(\"Unexpected number of dimensions\")\n\n        mixweight = weight[(Ellipsis,) + (None,) * len(dims)]\n        return mixweight * data + (1 - mixweight) * data[perm, ...]\n\n    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):\n        data_t = convert_to_tensor(data, track_meta=get_track_meta())\n        labels_t = data_t  # will not stay this value, needed to satisfy pylint/mypy\n        if labels is not None:\n            labels_t = convert_to_tensor(labels, track_meta=get_track_meta())\n        if randomize:\n            self.randomize()\n        if labels is None:\n            return convert_to_dst_type(self.apply(data_t), dst=data)[0]\n\n        return (\n            convert_to_dst_type(self.apply(data_t), dst=data)[0],\n            convert_to_dst_type(self.apply(labels_t), dst=labels)[0],\n        )\n\n\nclass CutMix(Mixer):\n    \"\"\"CutMix augmentation as described in:\n        Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.\n        CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,\n        ICCV 2019\n\n        Class derived from :py:class:`monai.transforms.Mixer`. See corresponding\n        documentation for details on the constructor parameters. Here, alpha not only determines\n        the mixing weight but also the size of the random rectangles used during for mixing.\n        Please refer to the paper for details.\n\n        Please note that there is a change in behavior starting from version 1.4.0. 
In the previous\n        implementation, the transform would generate a different label each time it was called.\n        To ensure determinism, the new implementation will now generate the same label for\n        the same input image when using the same operation.\n\n        The most common use case is something close to:\n\n    .. code-block:: python\n\n        cm = CutMix(batch_size=8, alpha=0.5)\n        for batch in loader:\n            images, labels = batch\n            augimg, auglabels = cm(images, labels)\n            output = model(augimg)\n            loss = loss_function(output, auglabels)\n            ...\n\n    \"\"\"\n\n    def apply(self, data: torch.Tensor):\n        weights, perm, coords = self._params\n        nsamples, _, *dims = data.shape\n        if len(weights) != nsamples:\n            raise ValueError(f\"Expected batch of size: {len(weights)}, but got {nsamples}\")\n\n        mask = torch.ones_like(data)\n        for s, weight in enumerate(weights):\n            lengths = [d * sqrt(1 - weight) for d in dims]\n            idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]\n            mask[s][idx] = 0\n\n        return mask * data + (1 - mask) * data[perm, ...]\n\n    def apply_on_labels(self, labels: torch.Tensor):\n        weights, perm, _ = self._params\n        nsamples, *dims = labels.shape\n        if len(weights) != nsamples:\n            raise ValueError(f\"Expected batch of size: {len(weights)}, but got {nsamples}\")\n\n        mixweight = weights[(Ellipsis,) + (None,) * len(dims)]\n        return mixweight * labels + (1 - mixweight) * labels[perm, ...]\n\n    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):\n        data_t = convert_to_tensor(data, track_meta=get_track_meta())\n        augmented_label = None\n        if labels is not None:\n            labels_t = convert_to_tensor(labels, track_meta=get_track_meta())\n        if randomize:\n 
           self.randomize(data)\n        augmented = convert_to_dst_type(self.apply(data_t), dst=data)[0]\n\n        if labels is not None:\n            augmented_label = convert_to_dst_type(self.apply(labels_t), dst=labels)[0]\n        return (augmented, augmented_label) if labels is not None else augmented\n\n\nclass CutOut(Mixer):\n    \"\"\"Cutout as described in the paper:\n    Terrance DeVries, Graham W. Taylor.\n    Improved Regularization of Convolutional Neural Networks with Cutout,\n    arXiv:1708.04552\n\n    Class derived from :py:class:`monai.transforms.Mixer`. See corresponding\n    documentation for details on the constructor parameters. Here, alpha not only determines\n    the mixing weight but also the size of the random rectangles being cut out.\n    Please refer to the paper for details.\n    \"\"\"\n\n    def apply(self, data: torch.Tensor):\n        weights, _, coords = self._params\n        nsamples, _, *dims = data.shape\n        if len(weights) != nsamples:\n            raise ValueError(f\"Expected batch of size: {len(weights)}, but got {nsamples}\")\n\n        mask = torch.ones_like(data)\n        for s, weight in enumerate(weights):\n            lengths = [d * sqrt(1 - weight) for d in dims]\n            idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]\n            mask[s][idx] = 0\n\n        return mask * data\n\n    def __call__(self, data: torch.Tensor, randomize=True):\n        data_t = convert_to_tensor(data, track_meta=get_track_meta())\n        if randomize:\n            self.randomize(data)\n        return convert_to_dst_type(self.apply(data_t), dst=data)[0]\n"
  },
  {
    "path": "monai/transforms/regularization/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Hashable\n\nimport numpy as np\n\nfrom monai.config import KeysCollection\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.transforms.transform import MapTransform, RandomizableTransform\nfrom monai.utils import convert_to_tensor\nfrom monai.utils.misc import ensure_tuple\n\nfrom .array import CutMix, CutOut, MixUp\n\n__all__ = [\"MixUpd\", \"MixUpD\", \"MixUpDict\", \"CutMixd\", \"CutMixD\", \"CutMixDict\", \"CutOutd\", \"CutOutD\", \"CutOutDict\"]\n\n\nclass MixUpd(MapTransform, RandomizableTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.MixUp`.\n\n    Notice that the mixup transformation will be the same for all entries\n    for consistency, i.e. 
images and labels must receive the same augmentation.\n    \"\"\"\n\n    def __init__(\n        self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        self.mixup = MixUp(batch_size, alpha)\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> MixUpd:\n        super().set_random_state(seed, state)\n        self.mixup.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data):\n        d = dict(data)\n        # all the keys share the same random state\n        self.mixup.randomize(None)\n        for k in self.key_iterator(d):\n            d[k] = self.mixup(data[k], randomize=False)\n        return d\n\n\nclass CutMixd(MapTransform, RandomizableTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.CutMix`.\n\n    Notice that the mixture weights will be the same for all entries\n    for consistency, i.e. 
images and labels must be aggregated with the same weights,\n    but the random crops are not.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        batch_size: int,\n        label_keys: KeysCollection | None = None,\n        alpha: float = 1.0,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.mixer = CutMix(batch_size, alpha)\n        self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> CutMixd:\n        super().set_random_state(seed, state)\n        self.mixer.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data):\n        d = dict(data)\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())\n            return out\n        self.mixer.randomize(d[first_key])\n        for key, label_key in self.key_iterator(d, self.label_keys):\n            ret = self.mixer(data[key], data.get(label_key, None), randomize=False)\n            d[key] = ret[0]\n            if label_key in d:\n                d[label_key] = ret[1]\n        return d\n\n\nclass CutOutd(MapTransform, RandomizableTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.CutOut`.\n\n    Notice that the cutout is different for every entry in the dictionary.\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.cutout = CutOut(batch_size)\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> CutOutd:\n        super().set_random_state(seed, state)\n        self.cutout.set_random_state(seed, state)\n        return 
self\n\n    def __call__(self, data):\n        d = dict(data)\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())\n            return out\n        self.cutout.randomize(d[first_key])\n        for k in self.key_iterator(d):\n            d[k] = self.cutout(data[k], randomize=False)\n        return d\n\n\nMixUpD = MixUpDict = MixUpd\nCutMixD = CutMixDict = CutMixd\nCutOutD = CutOutDict = CutOutd\n"
  },
  {
    "path": "monai/transforms/signal/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/transforms/signal/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of transforms for signal operations.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.transforms.transform import RandomizableTransform, Transform\nfrom monai.transforms.utils import check_boundaries, paste, squarepulse\nfrom monai.utils import optional_import\nfrom monai.utils.enums import TransformBackends\nfrom monai.utils.type_conversion import convert_data_type, convert_to_tensor\n\nshift, has_shift = optional_import(\"scipy.ndimage\", name=\"shift\")\niirnotch, has_iirnotch = optional_import(\"scipy.signal\", name=\"iirnotch\")\nwith warnings.catch_warnings():\n    warnings.simplefilter(\"ignore\", UserWarning)  # project-monai/monai#5204\n    filtfilt, has_filtfilt = optional_import(\"torchaudio.functional\", name=\"filtfilt\")\ncentral_frequency, has_central_frequency = optional_import(\"pywt\", name=\"central_frequency\")\ncwt, has_cwt = optional_import(\"pywt\", name=\"cwt\")\n\n__all__ = [\n    \"SignalRandDrop\",\n    \"SignalRandScale\",\n    \"SignalRandShift\",\n    \"SignalRandAddSine\",\n    \"SignalRandAddSquarePulse\",\n    \"SignalRandAddGaussianNoise\",\n    \"SignalRandAddSinePartial\",\n    \"SignalRandAddSquarePulsePartial\",\n    
\"SignalFillEmpty\",\n    \"SignalRemoveFrequency\",\n    \"SignalContinuousWavelet\",\n]\n\n\nclass SignalRandShift(RandomizableTransform):\n    \"\"\"\n    Apply a random shift on a signal\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]\n\n    def __init__(\n        self, mode: str | None = \"wrap\", filling: float | None = 0.0, boundaries: Sequence[float] = (-1.0, 1.0)\n    ) -> None:\n        \"\"\"\n        Args:\n            mode: define how the extension of the input array is done beyond its boundaries, see for more details :\n                https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.shift.html.\n            filling: value to fill past edges of input if mode is ‘constant’. Default is 0.0. see for mode details :\n                https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.shift.html.\n            boundaries: list defining lower and upper boundaries for the signal shift, default : ``[-1.0, 1.0]``\n        \"\"\"\n        super().__init__()\n        check_boundaries(boundaries)\n        self.filling = filling\n        self.mode = mode\n        self.boundaries = boundaries\n\n    def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            signal: input 1 dimension signal to be shifted\n        \"\"\"\n        self.randomize(None)\n        self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])\n        length = signal.shape[1]\n        shift_idx = round(self.magnitude * length)\n        sig = convert_data_type(signal, np.ndarray)[0]\n        signal = convert_to_tensor(shift(input=sig, mode=self.mode, shift=shift_idx, cval=self.filling))\n        return signal\n\n\nclass SignalRandScale(RandomizableTransform):\n    \"\"\"\n    Apply a random rescaling on a signal\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, boundaries: Sequence[float] = (-1.0, 1.0)) -> 
None:\n        \"\"\"\n        Args:\n            boundaries: list defining lower and upper boundaries for the signal scaling, default : ``[-1.0, 1.0]``\n        \"\"\"\n        super().__init__()\n        check_boundaries(boundaries)\n        self.boundaries = boundaries\n\n    def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            signal: input 1 dimension signal to be scaled\n        \"\"\"\n        self.randomize(None)\n        self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])\n        signal = convert_to_tensor(self.magnitude * signal)\n\n        return signal\n\n\nclass SignalRandDrop(RandomizableTransform):\n    \"\"\"\n    Randomly drop a portion of a signal\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, boundaries: Sequence[float] = (0.0, 1.0)) -> None:\n        \"\"\"\n        Args:\n            boundaries: list defining lower and upper boundaries for the signal drop,\n            lower and upper values need to be positive default : ``[0.0, 1.0]``\n        \"\"\"\n        super().__init__()\n        check_boundaries(boundaries)\n        self.boundaries = boundaries\n\n    def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            signal: input 1 dimension signal to be dropped\n        \"\"\"\n        self.randomize(None)\n        self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])\n\n        length = signal.shape[-1]\n        mask = torch.zeros(round(self.magnitude * length))\n        trange = torch.arange(length)\n        loc = trange[torch.randint(0, trange.size(0), (1,))]\n        signal = convert_to_tensor(paste(signal, mask, (loc,)))\n\n        return signal\n\n\nclass SignalRandAddSine(RandomizableTransform):\n    \"\"\"\n    Add a random sinusoidal signal to the input signal\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, 
TransformBackends.NUMPY]\n\n    def __init__(self, boundaries: Sequence[float] = (0.1, 0.3), frequencies: Sequence[float] = (0.001, 0.02)) -> None:\n        \"\"\"\n        Args:\n            boundaries: list defining lower and upper boundaries for the sinusoidal magnitude,\n                lower and upper values need to be positive ,default : ``[0.1, 0.3]``\n            frequencies: list defining lower and upper frequencies for sinusoidal\n                signal generation ,default : ``[0.001, 0.02]``\n        \"\"\"\n        super().__init__()\n        check_boundaries(boundaries)\n        self.boundaries = boundaries\n        self.frequencies = frequencies\n\n    def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            signal: input 1 dimension signal to which sinusoidal signal will be added\n        \"\"\"\n        self.randomize(None)\n        self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])\n        self.freqs = self.R.uniform(low=self.frequencies[0], high=self.frequencies[1])\n\n        length = signal.shape[1]\n\n        time = np.arange(0, length, 1)\n        data = convert_to_tensor(self.freqs * time)\n        sine = self.magnitude * torch.sin(data)\n        signal = convert_to_tensor(signal) + sine\n\n        return signal\n\n\nclass SignalRandAddSquarePulse(RandomizableTransform):\n    \"\"\"\n    Add a random square pulse signal to the input signal\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, boundaries: Sequence[float] = (0.01, 0.2), frequencies: Sequence[float] = (0.001, 0.02)) -> None:\n        \"\"\"\n        Args:\n            boundaries: list defining lower and upper boundaries for the square pulse magnitude,\n                lower and upper values need to be positive , default : ``[0.01, 0.2]``\n            frequencies: list defining lower and upper frequencies for the square pulse\n                
signal generation , default : ``[0.001, 0.02]``\n        \"\"\"\n        super().__init__()\n        check_boundaries(boundaries)\n        self.boundaries = boundaries\n        self.frequencies = frequencies\n\n    def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            signal: input 1 dimension signal to which square pulse will be added\n        \"\"\"\n        self.randomize(None)\n        self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])\n        self.freqs = self.R.uniform(low=self.frequencies[0], high=self.frequencies[1])\n\n        length = signal.shape[1]\n\n        time = np.arange(0, length, 1)\n        squaredpulse = self.magnitude * squarepulse(self.freqs * time)\n        signal = convert_to_tensor(signal) + squaredpulse\n\n        return signal\n\n\nclass SignalRandAddSinePartial(RandomizableTransform):\n    \"\"\"\n    Add a random partial sinusoidal signal to the input signal\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        boundaries: Sequence[float] = (0.1, 0.3),\n        frequencies: Sequence[float] = (0.001, 0.02),\n        fraction: Sequence[float] = (0.01, 0.2),\n    ) -> None:\n        \"\"\"\n        Args:\n            boundaries: list defining lower and upper boundaries for the sinusoidal magnitude,\n                lower and upper values need to be positive , default : ``[0.1, 0.3]``\n            frequencies: list defining lower and upper frequencies for sinusoidal\n                signal generation , default : ``[0.001, 0.02]``\n            fraction: list defining lower and upper boundaries for partial signal generation\n                default : ``[0.01, 0.2]``\n        \"\"\"\n        super().__init__()\n        check_boundaries(boundaries)\n        self.boundaries = boundaries\n        self.frequencies = frequencies\n        self.fraction = fraction\n\n    def __call__(self, 
signal: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            signal: input 1 dimension signal to which a partial sinusoidal signal\n            will be added\n        \"\"\"\n        self.randomize(None)\n        self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])\n        self.fracs = self.R.uniform(low=self.fraction[0], high=self.fraction[1])\n        self.freqs = self.R.uniform(low=self.frequencies[0], high=self.frequencies[1])\n\n        length = signal.shape[-1]\n\n        time_partial = np.arange(0, round(self.fracs * length), 1)\n        data = convert_to_tensor(self.freqs * time_partial)\n        sine_partial = self.magnitude * torch.sin(data)\n\n        loc = np.random.choice(range(length))\n        signal = paste(signal, sine_partial, (loc,))\n\n        return signal\n\n\nclass SignalRandAddGaussianNoise(RandomizableTransform):\n    \"\"\"\n    Add a random gaussian noise to the input signal\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, boundaries: Sequence[float] = (0.001, 0.02)) -> None:\n        \"\"\"\n        Args:\n            boundaries: list defining lower and upper boundaries for the signal magnitude,\n                default : ``[0.001,0.02]``\n        \"\"\"\n        super().__init__()\n        check_boundaries(boundaries)\n        self.boundaries = boundaries\n\n    def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            signal: input 1 dimension signal to which gaussian noise will be added\n        \"\"\"\n        self.randomize(None)\n        self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])\n        length = signal.shape[1]\n        gaussiannoise = self.magnitude * torch.randn(length)\n\n        signal = convert_to_tensor(signal) + gaussiannoise\n\n        return signal\n\n\nclass SignalRandAddSquarePulsePartial(RandomizableTransform):\n    
\"\"\"\n    Add a random partial square pulse to a signal\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        boundaries: Sequence[float] = (0.01, 0.2),\n        frequencies: Sequence[float] = (0.001, 0.02),\n        fraction: Sequence[float] = (0.01, 0.2),\n    ) -> None:\n        \"\"\"\n        Args:\n            boundaries: list defining lower and upper boundaries for the square pulse magnitude,\n                lower and upper values need to be positive , default : ``[0.01, 0.2]``\n            frequencies: list defining lower and upper frequencies for square pulse\n                signal generation example : ``[0.001, 0.02]``\n            fraction: list defining lower and upper boundaries for partial square pulse generation\n                default: ``[0.01, 0.2]``\n        \"\"\"\n        super().__init__()\n        check_boundaries(boundaries)\n        self.boundaries = boundaries\n        self.frequencies = frequencies\n        self.fraction = fraction\n\n    def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            signal: input 1 dimension signal to which a partial square pulse will be added\n        \"\"\"\n        self.randomize(None)\n        self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])\n        self.fracs = self.R.uniform(low=self.fraction[0], high=self.fraction[1])\n        self.freqs = self.R.uniform(low=self.frequencies[0], high=self.frequencies[1])\n\n        length = signal.shape[-1]\n\n        time_partial = np.arange(0, round(self.fracs * length), 1)\n        squaredpulse_partial = self.magnitude * squarepulse(self.freqs * time_partial)\n\n        loc = np.random.choice(range(length))\n        signal = paste(signal, squaredpulse_partial, (loc,))\n\n        return signal\n\n\nclass SignalFillEmpty(Transform):\n    \"\"\"\n    replace empty part of a signal (NaN)\n    \"\"\"\n\n    backend 
= [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, replacement: float = 0.0) -> None:\n        \"\"\"\n        Args:\n            replacement: value to replace nan items in signal\n        \"\"\"\n        super().__init__()\n        self.replacement = replacement\n\n    def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            signal: signal to be filled\n        \"\"\"\n        signal = torch.nan_to_num(convert_to_tensor(signal, track_meta=True), nan=self.replacement)\n        return signal\n\n\nclass SignalRemoveFrequency(Transform):\n    \"\"\"\n    Remove a frequency from a signal\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self, frequency: float | None = None, quality_factor: float | None = None, sampling_freq: float | None = None\n    ) -> None:\n        \"\"\"\n        Args:\n            frequency: frequency to be removed from the signal\n            quality_factor: quality factor for notch filter\n                see : https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirnotch.html\n            sampling_freq: sampling frequency of the input signal\n        \"\"\"\n        super().__init__()\n        self.frequency = frequency\n        self.quality_factor = quality_factor\n        self.sampling_freq = sampling_freq\n\n    def __call__(self, signal: np.ndarray) -> Any:\n        \"\"\"\n        Args:\n            signal: signal to be frequency removed\n        \"\"\"\n        b_notch, a_notch = convert_to_tensor(\n            iirnotch(self.frequency, self.quality_factor, self.sampling_freq), dtype=torch.float\n        )\n        y_notched = filtfilt(convert_to_tensor(signal, dtype=torch.float), a_notch, b_notch)\n\n        return y_notched\n\n\nclass SignalContinuousWavelet(Transform):\n    \"\"\"\n    Generate continuous wavelet transform of a signal\n    \"\"\"\n\n    backend = 
[TransformBackends.NUMPY]\n\n    def __init__(self, type: str = \"mexh\", length: float = 125.0, frequency: float = 500.0) -> None:\n        \"\"\"\n        Args:\n            type: mother wavelet type.\n                Available options are: {``\"mexh\"``, ``\"morl\"``, ``\"cmorB-C\"``, ``\"gausP\"``}\n            see : https://pywavelets.readthedocs.io/en/latest/ref/cwt.html\n            length: expected length, default ``125.0``\n            frequency: signal frequency, default ``500.0``\n        \"\"\"\n        super().__init__()\n        self.frequency = frequency\n        self.length = length\n        self.type = type\n\n    def __call__(self, signal: np.ndarray) -> Any:\n        \"\"\"\n        Args:\n            signal: signal for which to generate continuous wavelet transform\n        \"\"\"\n        mother_wavelet = self.type\n        spread = np.arange(1, self.length + 1, 1)\n        scales = central_frequency(mother_wavelet) * self.frequency / spread\n\n        coeffs, _ = cwt(signal, scales, mother_wavelet, 1.0 / self.frequency)\n\n        coeffs = np.transpose(coeffs, [1, 0, 2])\n\n        return coeffs\n"
  },
  {
    "path": "monai/transforms/signal/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of dictionary-based wrappers around the signal operations defined in :py:class:`monai.transforms.signal.array`.\n\nClass names are ended with 'd' to denote dictionary-based transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Hashable, Mapping\n\nfrom monai.config.type_definitions import KeysCollection, NdarrayOrTensor\nfrom monai.transforms.signal.array import SignalFillEmpty\nfrom monai.transforms.transform import MapTransform\n\n__all__ = [\"SignalFillEmptyd\", \"SignalFillEmptyD\", \"SignalFillEmptyDict\"]\n\n\nclass SignalFillEmptyd(MapTransform):\n    \"\"\"\n    Applies the SignalFillEmptyd transform on the input. 
All NaN values will be replaced with the\n    replacement value.\n\n    Args:\n        keys: keys of the corresponding items to model output.\n        allow_missing_keys: don't raise exception if key is missing.\n        replacement: The value that the NaN entries shall be mapped to.\n    \"\"\"\n\n    backend = SignalFillEmpty.backend\n\n    def __init__(self, keys: KeysCollection = None, allow_missing_keys: bool = False, replacement=0.0):\n        super().__init__(keys, allow_missing_keys)\n        self.signal_fill_empty = SignalFillEmpty(replacement=replacement)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:\n        for key in self.key_iterator(data):\n            data[key] = self.signal_fill_empty(data[key])  # type: ignore\n\n        return data\n\n\nSignalFillEmptyD = SignalFillEmptyDict = SignalFillEmptyd\n"
  },
  {
    "path": "monai/transforms/smooth_field/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/transforms/smooth_field/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Transforms using a smooth spatial field generated by interpolating from smaller randomized fields.\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport numpy as np\nimport torch\nfrom torch.nn.functional import grid_sample, interpolate\n\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.networks.utils import meshgrid_ij\nfrom monai.transforms.transform import Randomizable, RandomizableTransform\nfrom monai.transforms.utils_pytorch_numpy_unification import moveaxis\nfrom monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode\nfrom monai.utils.enums import TransformBackends\nfrom monai.utils.module import look_up_option\nfrom monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor\n\n__all__ = [\"SmoothField\", \"RandSmoothFieldAdjustContrast\", \"RandSmoothFieldAdjustIntensity\", \"RandSmoothDeform\"]\n\n\nclass SmoothField(Randomizable):\n    \"\"\"\n    Generate a smooth field array by defining a smaller randomized field and then reinterpolating to the desired size.\n\n    This exploits interpolation to create a smoothly varying field used for other applications. 
An initial randomized\n    field is defined with `rand_size` dimensions with `pad` number of values padding it along each dimension using\n    `pad_val` as the value. If `spatial_size` is given this is interpolated to that size, otherwise if None the random\n    array is produced uninterpolated. The output is always a Pytorch tensor allocated on the specified device.\n\n    Args:\n        rand_size: size of the randomized field to start from\n        pad: number of pixels/voxels along the edges of the field to pad with `pad_val`\n        pad_val: value with which to pad field edges\n        low: low value for randomized field\n        high: high value for randomized field\n        channels: number of channels of final output\n        spatial_size: final output size of the array, None to produce original uninterpolated field\n        mode: interpolation mode for resizing the field\n        align_corners: if True align the corners when upsampling field\n        device: Pytorch device to define field on\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        rand_size: Sequence[int],\n        pad: int = 0,\n        pad_val: float = 0,\n        low: float = -1.0,\n        high: float = 1.0,\n        channels: int = 1,\n        spatial_size: Sequence[int] | None = None,\n        mode: str = InterpolateMode.AREA,\n        align_corners: bool | None = None,\n        device: torch.device | None = None,\n    ):\n        self.rand_size = tuple(rand_size)\n        self.pad = pad\n        self.low = low\n        self.high = high\n        self.channels = channels\n        self.mode = mode\n        self.align_corners = align_corners\n        self.device = device\n\n        self.spatial_size: Sequence[int] | None = None\n        self.spatial_zoom: Sequence[float] | None = None\n\n        if low >= high:\n            raise ValueError(\"Value for `low` must be less than `high` otherwise field will be zeros\")\n\n        
self.total_rand_size = tuple(rs + self.pad * 2 for rs in self.rand_size)\n\n        self.field = torch.ones((1, self.channels) + self.total_rand_size, device=self.device) * pad_val\n\n        self.crand_size = (self.channels,) + self.rand_size\n\n        pad_slice = slice(None) if self.pad == 0 else slice(self.pad, -self.pad)\n        self.rand_slices = (0, slice(None)) + (pad_slice,) * len(self.rand_size)\n\n        self.set_spatial_size(spatial_size)\n\n    def randomize(self, data: Any | None = None) -> None:\n        self.field[self.rand_slices] = torch.from_numpy(self.R.uniform(self.low, self.high, self.crand_size))  # type: ignore[index]\n\n    def set_spatial_size(self, spatial_size: Sequence[int] | None) -> None:\n        \"\"\"\n        Set the `spatial_size` and `spatial_zoom` attributes used for interpolating the field to the given\n        dimension, or not interpolate at all if None.\n\n        Args:\n            spatial_size: new size to interpolate to, or None to not interpolate\n        \"\"\"\n        if spatial_size is None:\n            self.spatial_size = None\n            self.spatial_zoom = None\n        else:\n            self.spatial_size = tuple(spatial_size)\n            self.spatial_zoom = tuple(s / f for s, f in zip(self.spatial_size, self.total_rand_size))\n\n    def set_mode(self, mode: str) -> None:\n        self.mode = mode\n\n    def __call__(self, randomize=False) -> torch.Tensor:\n        if randomize:\n            self.randomize()\n\n        field = self.field.clone()\n\n        if self.spatial_zoom is not None:\n            resized_field = interpolate(\n                input=field,\n                scale_factor=self.spatial_zoom,\n                mode=look_up_option(self.mode, InterpolateMode),\n                align_corners=self.align_corners,\n                recompute_scale_factor=False,\n            )\n\n            mina = resized_field.min()\n            maxa = resized_field.max()\n            minv = self.field.min()\n      
      maxv = self.field.max()\n\n            # faster than rescale_array, this uses in-place operations and doesn't perform unneeded range checks\n            norm_field = (resized_field.squeeze(0) - mina).div_(maxa - mina)\n            field = norm_field.mul_(maxv - minv).add_(minv)\n\n        return field\n\n\nclass RandSmoothFieldAdjustContrast(RandomizableTransform):\n    \"\"\"\n    Randomly adjust the contrast of input images by calculating a randomized smooth field for each invocation.\n\n    This uses SmoothField internally to define the adjustment over the image. If `pad` is greater than 0 the\n    edges of the input volume of that width will be mostly unchanged. Contrast is changed by raising input\n    values by the power of the smooth field so the range of values given by `gamma` should be chosen with this\n    in mind. For example, a minimum value of 0 in `gamma` will produce white areas so this should be avoided.\n    After the contrast is adjusted the values of the result are rescaled to the range of the original input.\n\n    Args:\n        spatial_size: size of input array's spatial dimensions\n        rand_size: size of the randomized field to start from\n        pad: number of pixels/voxels along the edges of the field to pad with 1\n        mode: interpolation mode to use when upsampling\n        align_corners: if True align the corners when upsampling field\n        prob: probability transform is applied\n        gamma: (min, max) range for exponential field\n        device: Pytorch device to define field on\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        spatial_size: Sequence[int],\n        rand_size: Sequence[int],\n        pad: int = 0,\n        mode: str = InterpolateMode.AREA,\n        align_corners: bool | None = None,\n        prob: float = 0.1,\n        gamma: Sequence[float] | float = (0.5, 4.5),\n        device: torch.device | None = None,\n    ):\n        super().__init__(prob)\n\n 
       if isinstance(gamma, (int, float)):\n            self.gamma = (0.5, gamma)\n        else:\n            if len(gamma) != 2:\n                raise ValueError(\"Argument `gamma` should be a number or pair of numbers.\")\n\n            self.gamma = (min(gamma), max(gamma))\n\n        self.sfield = SmoothField(\n            rand_size=rand_size,\n            pad=pad,\n            pad_val=1,\n            low=self.gamma[0],\n            high=self.gamma[1],\n            channels=1,\n            spatial_size=spatial_size,\n            mode=mode,\n            align_corners=align_corners,\n            device=device,\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandSmoothFieldAdjustContrast:\n        super().set_random_state(seed, state)\n        self.sfield.set_random_state(seed, state)\n        return self\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n\n        if self._do_transform:\n            self.sfield.randomize()\n\n    def set_mode(self, mode: str) -> None:\n        self.sfield.set_mode(mode)\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise reusing the previous.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            self.randomize()\n\n        if not self._do_transform:\n            return img\n\n        img_min = img.min()\n        img_max = img.max()\n        img_rng = img_max - img_min\n\n        field = self.sfield()\n        rfield, *_ = convert_to_dst_type(field, img)\n\n        # everything below here is to be computed using the destination type (numpy, tensor, etc.)\n\n        img = (img - img_min) / (img_rng + 1e-10)  # rescale to unit values\n        img = img**rfield  # contrast is changed by raising 
image data to a power, in this case the field\n\n        out = (img * img_rng) + img_min  # rescale back to the original image value range\n\n        return out\n\n\nclass RandSmoothFieldAdjustIntensity(RandomizableTransform):\n    \"\"\"\n    Randomly adjust the intensity of input images by calculating a randomized smooth field for each invocation.\n\n    This uses SmoothField internally to define the adjustment over the image. If `pad` is greater than 0 the\n    edges of the input volume of that width will be mostly unchanged. Intensity is changed by multiplying the\n    inputs by the smooth field, so the values of `gamma` should be chosen with this in mind. The default values\n    of `(0.1, 1.0)` are sensible in that values will not be zeroed out by the field nor multiplied greater than\n    the original value range.\n\n    Args:\n        spatial_size: size of input array\n        rand_size: size of the randomized field to start from\n        pad: number of pixels/voxels along the edges of the field to pad with 1\n        mode: interpolation mode to use when upsampling\n        align_corners: if True align the corners when upsampling field\n        prob: probability transform is applied\n        gamma: (min, max) range of intensity multipliers\n        device: Pytorch device to define field on\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        spatial_size: Sequence[int],\n        rand_size: Sequence[int],\n        pad: int = 0,\n        mode: str = InterpolateMode.AREA,\n        align_corners: bool | None = None,\n        prob: float = 0.1,\n        gamma: Sequence[float] | float = (0.1, 1.0),\n        device: torch.device | None = None,\n    ):\n        super().__init__(prob)\n\n        if isinstance(gamma, (int, float)):\n            self.gamma = (0.5, gamma)\n        else:\n            if len(gamma) != 2:\n                raise ValueError(\"Argument `gamma` should be a number or pair of numbers.\")\n\n         
   self.gamma = (min(gamma), max(gamma))\n\n        self.sfield = SmoothField(\n            rand_size=rand_size,\n            pad=pad,\n            pad_val=1,\n            low=self.gamma[0],\n            high=self.gamma[1],\n            channels=1,\n            spatial_size=spatial_size,\n            mode=mode,\n            align_corners=align_corners,\n            device=device,\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandSmoothFieldAdjustIntensity:\n        super().set_random_state(seed, state)\n        self.sfield.set_random_state(seed, state)\n        return self\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n\n        if self._do_transform:\n            self.sfield.randomize()\n\n    def set_mode(self, mode: str) -> None:\n        self.sfield.set_mode(mode)\n\n    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`, if `randomize` randomizing the smooth field otherwise reusing the previous.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n\n        if randomize:\n            self.randomize()\n\n        if not self._do_transform:\n            return img\n\n        field = self.sfield()\n        rfield, *_ = convert_to_dst_type(field, img)\n\n        # everything below here is to be computed using the destination type (numpy, tensor, etc.)\n\n        out = img * rfield\n\n        return out\n\n\nclass RandSmoothDeform(RandomizableTransform):\n    \"\"\"\n    Deform an image using a random smooth field and Pytorch's grid_sample.\n\n    The amount of deformation is given by `def_range` in fractions of the size of the image. 
The size of each dimension\n    of the input image is always defined as 2 regardless of actual image voxel dimensions, that is the coordinates in\n    every dimension range from -1 to 1. A value of 0.1 means pixels/voxels can be moved by up to 5% of the image's size.\n\n    Args:\n        spatial_size: input array size to which deformation grid is interpolated\n        rand_size: size of the randomized field to start from\n        pad: number of pixels/voxels along the edges of the field to pad with 0\n        field_mode: interpolation mode to use when upsampling the deformation field\n        align_corners: if True align the corners when upsampling field\n        prob: probability transform is applied\n        def_range: value of the deformation range in image size fractions, single min/max value  or min/max pair\n        grid_dtype: type for the deformation grid calculated from the field\n        grid_mode: interpolation mode used for sampling input using deformation grid\n        grid_padding_mode: padding mode used for sampling input using deformation grid\n        grid_align_corners: if True align the corners when sampling the deformation grid\n        device: Pytorch device to define field on\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        spatial_size: Sequence[int],\n        rand_size: Sequence[int],\n        pad: int = 0,\n        field_mode: str = InterpolateMode.AREA,\n        align_corners: bool | None = None,\n        prob: float = 0.1,\n        def_range: Sequence[float] | float = 1.0,\n        grid_dtype=torch.float32,\n        grid_mode: str = GridSampleMode.NEAREST,\n        grid_padding_mode: str = GridSamplePadMode.BORDER,\n        grid_align_corners: bool | None = False,\n        device: torch.device | None = None,\n    ):\n        super().__init__(prob)\n\n        self.grid_dtype = grid_dtype\n        self.grid_mode = grid_mode\n        self.def_range = def_range\n        self.device = 
device\n        self.grid_align_corners = grid_align_corners\n        self.grid_padding_mode = grid_padding_mode\n\n        if isinstance(def_range, (int, float)):\n            self.def_range = (-def_range, def_range)\n        else:\n            if len(def_range) != 2:\n                raise ValueError(\"Argument `def_range` should be a number or pair of numbers.\")\n\n            self.def_range = (min(def_range), max(def_range))\n\n        self.sfield = SmoothField(\n            spatial_size=spatial_size,\n            rand_size=rand_size,\n            pad=pad,\n            low=self.def_range[0],\n            high=self.def_range[1],\n            channels=len(rand_size),\n            mode=field_mode,\n            align_corners=align_corners,\n            device=device,\n        )\n\n        grid_space = tuple(spatial_size) if spatial_size is not None else self.sfield.field.shape[2:]\n        grid_ranges = [torch.linspace(-1, 1, d) for d in grid_space]\n\n        grid = meshgrid_ij(*grid_ranges)\n\n        self.grid = torch.stack(grid).unsqueeze(0).to(self.device, self.grid_dtype)\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable:\n        super().set_random_state(seed, state)\n        self.sfield.set_random_state(seed, state)\n        return self\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n\n        if self._do_transform:\n            self.sfield.randomize()\n\n    def set_field_mode(self, mode: str) -> None:\n        self.sfield.set_mode(mode)\n\n    def set_grid_mode(self, mode: str) -> None:\n        self.grid_mode = mode\n\n    def __call__(\n        self, img: NdarrayOrTensor, randomize: bool = True, device: torch.device | None = None\n    ) -> NdarrayOrTensor:\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if randomize:\n            self.randomize()\n\n        if not self._do_transform:\n            return img\n\n 
       device = device if device is not None else self.device\n\n        field = self.sfield()\n\n        dgrid = self.grid + field.to(self.grid_dtype)\n        dgrid = moveaxis(dgrid, 1, -1)  # type: ignore\n        dgrid = dgrid[..., list(range(dgrid.shape[-1] - 1, -1, -1))]  # invert order of coordinates\n\n        img_t = convert_to_tensor(img[None], torch.float32, device)\n\n        out = grid_sample(\n            input=img_t,\n            grid=dgrid,\n            mode=look_up_option(self.grid_mode, GridSampleMode),\n            align_corners=self.grid_align_corners,\n            padding_mode=look_up_option(self.grid_padding_mode, GridSamplePadMode),\n        )\n\n        out_t, *_ = convert_to_dst_type(out.squeeze(0), img)\n\n        return out_t\n"
  },
  {
    "path": "monai/transforms/smooth_field/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Hashable, Mapping, Sequence\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nfrom monai.config import KeysCollection, SequenceStr\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.transforms.smooth_field.array import (\n    RandSmoothDeform,\n    RandSmoothFieldAdjustContrast,\n    RandSmoothFieldAdjustIntensity,\n)\nfrom monai.transforms.transform import MapTransform, RandomizableTransform\nfrom monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, convert_to_tensor, ensure_tuple_rep\n\n__all__ = [\n    \"RandSmoothFieldAdjustContrastd\",\n    \"RandSmoothFieldAdjustIntensityd\",\n    \"RandSmoothDeformd\",\n    \"RandSmoothFieldAdjustContrastD\",\n    \"RandSmoothFieldAdjustIntensityD\",\n    \"RandSmoothDeformD\",\n    \"RandSmoothFieldAdjustContrastDict\",\n    \"RandSmoothFieldAdjustIntensityDict\",\n    \"RandSmoothDeformDict\",\n]\n\n\nclass RandSmoothFieldAdjustContrastd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary version of RandSmoothFieldAdjustContrast.\n\n    The field is randomized once per invocation by default so the same field is applied to every selected key. 
The\n    `mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values with\n    one for each key in `keys`.\n\n    Args:\n        keys: key names to apply the augment to\n        spatial_size: size of input arrays, all arrays stated in `keys` must have same dimensions\n        rand_size: size of the randomized field to start from\n        pad: number of pixels/voxels along the edges of the field to pad with 0\n        mode: interpolation mode to use when upsampling\n        align_corners: if True align the corners when upsampling field\n        prob: probability transform is applied\n        gamma: (min, max) range for exponential field\n        device: Pytorch device to define field on\n    \"\"\"\n\n    backend = RandSmoothFieldAdjustContrast.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        spatial_size: Sequence[int],\n        rand_size: Sequence[int],\n        pad: int = 0,\n        mode: SequenceStr = InterpolateMode.AREA,\n        align_corners: bool | None = None,\n        prob: float = 0.1,\n        gamma: Sequence[float] | float = (0.5, 4.5),\n        device: torch.device | None = None,\n    ):\n        RandomizableTransform.__init__(self, prob)\n        MapTransform.__init__(self, keys)\n\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n\n        self.trans = RandSmoothFieldAdjustContrast(\n            spatial_size=spatial_size,\n            rand_size=rand_size,\n            pad=pad,\n            mode=self.mode[0],\n            align_corners=align_corners,\n            prob=1.0,\n            gamma=gamma,\n            device=device,\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandSmoothFieldAdjustContrastd:\n        super().set_random_state(seed, state)\n        self.trans.set_random_state(seed, state)\n        return self\n\n    def randomize(self, data: Any | None = None) -> 
None:\n        super().randomize(None)\n\n        if self._do_transform:\n            self.trans.randomize()\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:\n        self.randomize()\n        d = dict(data)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        for idx, key in enumerate(self.key_iterator(d)):\n            self.trans.set_mode(self.mode[idx % len(self.mode)])\n            d[key] = self.trans(d[key], False)\n\n        return d\n\n\nclass RandSmoothFieldAdjustIntensityd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary version of RandSmoothFieldAdjustIntensity.\n\n    The field is randomized once per invocation by default so the same field is applied to every selected key. The\n    `mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values with\n    one for each key in `keys`.\n\n    Args:\n        keys: key names to apply the augment to\n        spatial_size: size of input arrays, all arrays stated in `keys` must have same dimensions\n        rand_size: size of the randomized field to start from\n        pad: number of pixels/voxels along the edges of the field to pad with 0\n        mode: interpolation mode to use when upsampling\n        align_corners: if True align the corners when upsampling field\n        prob: probability transform is applied\n        gamma: (min, max) range of intensity multipliers\n        device: Pytorch device to define field on\n    \"\"\"\n\n    backend = RandSmoothFieldAdjustIntensity.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        spatial_size: Sequence[int],\n        rand_size: Sequence[int],\n        pad: int = 0,\n        mode: SequenceStr = InterpolateMode.AREA,\n        align_corners: bool | None = None,\n        prob: 
float = 0.1,\n        gamma: Sequence[float] | float = (0.1, 1.0),\n        device: torch.device | None = None,\n    ):\n        RandomizableTransform.__init__(self, prob)\n        MapTransform.__init__(self, keys)\n\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n\n        self.trans = RandSmoothFieldAdjustIntensity(\n            spatial_size=spatial_size,\n            rand_size=rand_size,\n            pad=pad,\n            mode=self.mode[0],\n            align_corners=align_corners,\n            prob=1.0,\n            gamma=gamma,\n            device=device,\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandSmoothFieldAdjustIntensityd:\n        super().set_random_state(seed, state)\n        self.trans.set_random_state(seed, state)\n        return self\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        self.trans.randomize()\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:\n        self.randomize()\n\n        d = dict(data)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        for idx, key in enumerate(self.key_iterator(d)):\n            self.trans.set_mode(self.mode[idx % len(self.mode)])\n            d[key] = self.trans(d[key], False)\n\n        return d\n\n\nclass RandSmoothDeformd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary version of RandSmoothDeform.\n\n    The field is randomized once per invocation by default so the same field is applied to every selected key. The\n    `field_mode` parameter specifying interpolation mode for the field can be a single value or a sequence of values\n    with one for each key in `keys`. 
Similarly the `grid_mode` parameter can be one value or one per key.\n\n    Args:\n        keys: key names to apply the augment to\n        spatial_size: input array size to which deformation grid is interpolated\n        rand_size: size of the randomized field to start from\n        pad: number of pixels/voxels along the edges of the field to pad with 0\n        field_mode: interpolation mode to use when upsampling the deformation field\n        align_corners: if True align the corners when upsampling field\n        prob: probability transform is applied\n        def_range: value of the deformation range in image size fractions\n        grid_dtype: type for the deformation grid calculated from the field\n        grid_mode: interpolation mode used for sampling input using deformation grid\n        grid_padding_mode: padding mode used for sampling input using deformation grid\n        grid_align_corners: if True align the corners when sampling the deformation grid\n        device: Pytorch device to define field on\n    \"\"\"\n\n    backend = RandSmoothDeform.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        spatial_size: Sequence[int],\n        rand_size: Sequence[int],\n        pad: int = 0,\n        field_mode: SequenceStr = InterpolateMode.AREA,\n        align_corners: bool | None = None,\n        prob: float = 0.1,\n        def_range: Sequence[float] | float = 1.0,\n        grid_dtype=torch.float32,\n        grid_mode: SequenceStr = GridSampleMode.NEAREST,\n        grid_padding_mode: str = GridSamplePadMode.BORDER,\n        grid_align_corners: bool | None = False,\n        device: torch.device | None = None,\n    ):\n        RandomizableTransform.__init__(self, prob)\n        MapTransform.__init__(self, keys)\n\n        self.field_mode = ensure_tuple_rep(field_mode, len(self.keys))\n        self.grid_mode = ensure_tuple_rep(grid_mode, len(self.keys))\n\n        self.trans = RandSmoothDeform(\n            rand_size=rand_size,\n  
          spatial_size=spatial_size,\n            pad=pad,\n            field_mode=self.field_mode[0],\n            align_corners=align_corners,\n            prob=1.0,\n            def_range=def_range,\n            grid_dtype=grid_dtype,\n            grid_mode=self.grid_mode[0],\n            grid_padding_mode=grid_padding_mode,\n            grid_align_corners=grid_align_corners,\n            device=device,\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandSmoothDeformd:\n        super().set_random_state(seed, state)\n        self.trans.set_random_state(seed, state)\n        return self\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        self.trans.randomize()\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:\n        self.randomize()\n\n        d = dict(data)\n        if not self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            return d\n\n        for idx, key in enumerate(self.key_iterator(d)):\n            self.trans.set_field_mode(self.field_mode[idx % len(self.field_mode)])\n            self.trans.set_grid_mode(self.grid_mode[idx % len(self.grid_mode)])\n\n            d[key] = self.trans(d[key], False, self.trans.device)\n\n        return d\n\n\nRandSmoothDeformD = RandSmoothDeformDict = RandSmoothDeformd\nRandSmoothFieldAdjustIntensityD = RandSmoothFieldAdjustIntensityDict = RandSmoothFieldAdjustIntensityd\nRandSmoothFieldAdjustContrastD = RandSmoothFieldAdjustContrastDict = RandSmoothFieldAdjustContrastd\n"
  },
  {
    "path": "monai/transforms/spatial/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/transforms/spatial/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of \"vanilla\" transforms for spatial operations.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable, Sequence\nfrom copy import deepcopy\nfrom itertools import zip_longest\nfrom typing import Any, Optional, Union, cast\n\nimport numpy as np\nimport torch\n\nfrom monai.config import USE_COMPILED, DtypeLike\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.box_utils import BoxMode, StandardMode\nfrom monai.data.meta_obj import get_track_meta, set_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine\nfrom monai.networks.layers import AffineTransform, GaussianFilter, grid_pull\nfrom monai.networks.utils import meshgrid_ij\nfrom monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop\nfrom monai.transforms.inverse import InvertibleTransform\nfrom monai.transforms.spatial.functional import (\n    affine_func,\n    convert_box_to_points,\n    convert_points_to_box,\n    flip,\n    orientation,\n    resize,\n    rotate,\n    rotate90,\n    spatial_resample,\n    zoom,\n)\nfrom monai.transforms.traits import MultiSampleTrait\nfrom monai.transforms.transform import LazyTransform, Randomizable, RandomizableTransform, Transform\nfrom 
monai.transforms.utils import (\n    create_control_grid,\n    create_grid,\n    create_rotate,\n    create_scale,\n    create_shear,\n    create_translate,\n    map_spatial_axes,\n    resolves_modes,\n    scale_affine,\n)\nfrom monai.transforms.utils_pytorch_numpy_unification import argsort, argwhere, linalg_inv, moveaxis\nfrom monai.utils import (\n    GridSampleMode,\n    GridSamplePadMode,\n    InterpolateMode,\n    NumpyPadMode,\n    SpaceKeys,\n    convert_to_cupy,\n    convert_to_dst_type,\n    convert_to_numpy,\n    convert_to_tensor,\n    ensure_tuple,\n    ensure_tuple_rep,\n    ensure_tuple_size,\n    fall_back_tuple,\n    issequenceiterable,\n    optional_import,\n)\nfrom monai.utils.deprecate_utils import deprecated_arg_default\nfrom monai.utils.enums import GridPatchSort, PatchKeys, TraceKeys, TransformBackends\nfrom monai.utils.misc import ImageMetaKey as Key\nfrom monai.utils.module import look_up_option\nfrom monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string\n\nnib, has_nib = optional_import(\"nibabel\")\ncupy, _ = optional_import(\"cupy\")\ncupy_ndi, _ = optional_import(\"cupyx.scipy.ndimage\")\nnp_ndi, _ = optional_import(\"scipy.ndimage\")\n\n__all__ = [\n    \"SpatialResample\",\n    \"ResampleToMatch\",\n    \"Spacing\",\n    \"Orientation\",\n    \"Flip\",\n    \"GridDistortion\",\n    \"GridSplit\",\n    \"GridPatch\",\n    \"RandGridPatch\",\n    \"Resize\",\n    \"Rotate\",\n    \"Zoom\",\n    \"Rotate90\",\n    \"RandRotate90\",\n    \"RandRotate\",\n    \"RandFlip\",\n    \"RandGridDistortion\",\n    \"RandAxisFlip\",\n    \"RandZoom\",\n    \"AffineGrid\",\n    \"RandAffineGrid\",\n    \"RandDeformGrid\",\n    \"Resample\",\n    \"Affine\",\n    \"RandAffine\",\n    \"Rand2DElastic\",\n    \"Rand3DElastic\",\n    \"RandSimulateLowResolution\",\n]\n\nRandRange = Optional[Union[Sequence[Union[tuple[float, float], float]], float]]\n\n\nclass SpatialResample(InvertibleTransform, 
LazyTransform):\n    \"\"\"\n    Resample input image from the orientation/spacing defined by ``src_affine`` affine matrix into\n    the ones specified by ``dst_affine`` affine matrix.\n\n    Internally this transform computes the affine transform matrix from ``src_affine`` to ``dst_affine``,\n    by ``xform = linalg.solve(src_affine, dst_affine)``, and call ``monai.transforms.Affine`` with ``xform``.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY, TransformBackends.CUPY]\n\n    def __init__(\n        self,\n        mode: str | int = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.BORDER,\n        align_corners: bool = False,\n        dtype: DtypeLike = np.float64,\n        lazy: bool = False,\n    ):\n        \"\"\"\n        Args:\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``\"border\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            dtype: data type for resampling computation. Defaults to ``float64`` for best precision.\n                If ``None``, use the data type of input data. To be compatible with other modules,\n                the output data type is always ``float32``.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n        \"\"\"\n        LazyTransform.__init__(self, lazy=lazy)\n        self.mode = mode\n        self.padding_mode = padding_mode\n        self.align_corners = align_corners\n        self.dtype = dtype\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        dst_affine: torch.Tensor | None = None,\n        spatial_size: Sequence[int] | torch.Tensor | int | None = None,\n        mode: str | int | None = None,\n        padding_mode: str | None = None,\n        align_corners: bool | None = None,\n        dtype: DtypeLike = None,\n        lazy: bool | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: input image to be resampled. It currently supports channel-first arrays with\n                at most three spatial dimensions.\n            dst_affine: destination affine matrix. 
Defaults to ``None``, which means the same as `img.affine`.\n                the shape should be `(r+1, r+1)` where `r` is the spatial rank of ``img``.\n                when `dst_affine` and `spatial_size` are None, the input will be returned without resampling,\n                but the data type will be `float32`.\n            spatial_size: output image spatial size.\n                if `spatial_size` and `self.spatial_size` are not defined,\n                the transform will compute a spatial size automatically containing the previous field of view.\n                if `spatial_size` is ``-1``, the transform will use the corresponding input img size.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``self.padding_mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                Defaults to ``None``, effectively using the value of `self.align_corners`.\n            dtype: data type for resampling computation. Defaults to ``self.dtype`` or\n                ``np.float64`` (for best precision). If ``None``, use the data type of input data.\n                To be compatible with other modules, the output data type is always `float32`.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n        The spatial rank is determined by the smallest among ``img.ndim -1``, ``len(src_affine) - 1``, and ``3``.\n\n        When both ``monai.config.USE_COMPILED`` and ``align_corners`` are set to ``True``,\n        MONAI's resampling implementation will be used.\n        Set `dst_affine` and `spatial_size` to `None` to turn off the resampling step.\n        \"\"\"\n        # get dtype as torch (e.g., torch.float64)\n        dtype_pt = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)\n        align_corners = align_corners if align_corners is not None else self.align_corners\n        mode = mode if mode is not None else self.mode\n        padding_mode = padding_mode if padding_mode is not None else self.padding_mode\n        lazy_ = self.lazy if lazy is None else lazy\n        return spatial_resample(\n            img,\n            dst_affine,\n            spatial_size,\n            mode,\n            padding_mode,\n            align_corners,\n            dtype_pt,\n            lazy=lazy_,\n            transform_info=self.get_transform_info(),\n        )\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        transform = self.pop_transform(data)\n        # Create inverse transform\n        kw_args = transform[TraceKeys.EXTRA_INFO]\n        # need to convert dtype from string back to torch.dtype\n        kw_args[\"dtype\"] = get_torch_dtype_from_string(kw_args[\"dtype\"])\n        # source becomes destination\n        kw_args[\"dst_affine\"] = kw_args.pop(\"src_affine\")\n        kw_args[\"spatial_size\"] = transform[TraceKeys.ORIG_SIZE]\n        if kw_args.get(\"align_corners\") == TraceKeys.NONE:\n            kw_args[\"align_corners\"] = False\n        with self.trace_transform(False):\n            # we can't use `self.__call__` in case a child class calls this inverse.\n            out: torch.Tensor = SpatialResample.__call__(self, data, **kw_args)\n        kw_args[\"src_affine\"] = 
kw_args.get(\"dst_affine\")\n        return out\n\n\nclass ResampleToMatch(SpatialResample):\n    \"\"\"\n    Resample an image to match given metadata. The affine matrix will be aligned,\n    and the size of the output image will match.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    def __call__(  # type: ignore\n        self,\n        img: torch.Tensor,\n        img_dst: torch.Tensor,\n        mode: str | int | None = None,\n        padding_mode: str | None = None,\n        align_corners: bool | None = None,\n        dtype: DtypeLike = None,\n        lazy: bool | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: input image to be resampled to match ``img_dst``. It currently supports channel-first arrays with\n                at most three spatial dimensions.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``\"border\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.\n                Defaults to ``None``, effectively using the value of `self.align_corners`.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            dtype: data type for resampling computation. Defaults to ``self.dtype`` or\n                ``np.float64`` (for best precision). If ``None``, use the data type of input data.\n                To be compatible with other modules, the output data type is always `float32`.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Raises:\n            ValueError: When the affine matrix of the source image is not invertible.\n        Returns:\n            Resampled input tensor or MetaTensor.\n        \"\"\"\n        if img_dst is None:\n            raise RuntimeError(\"`img_dst` is missing.\")\n        dst_affine = img_dst.peek_pending_affine() if isinstance(img_dst, MetaTensor) else torch.eye(4)\n        lazy_ = self.lazy if lazy is None else lazy\n        img = super().__call__(\n            img=img,\n            dst_affine=dst_affine,\n            spatial_size=img_dst.peek_pending_shape() if isinstance(img_dst, MetaTensor) else img_dst.shape[1:],\n            mode=mode,\n            padding_mode=padding_mode,\n            align_corners=align_corners,\n            dtype=dtype,\n            lazy=lazy_,\n        )\n        if not lazy_:\n            if isinstance(img, MetaTensor):\n                img.affine = dst_affine\n                if isinstance(img_dst, MetaTensor):\n                    original_fname = img.meta.get(Key.FILENAME_OR_OBJ, \"resample_to_match_source\")\n                    img.meta = deepcopy(img_dst.meta)\n                    img.meta[Key.FILENAME_OR_OBJ] = original_fname  # keep the original name, the others are overwritten\n        else:\n            if isinstance(img, MetaTensor) and isinstance(img_dst, MetaTensor):\n                original_fname = img.meta.get(Key.FILENAME_OR_OBJ, \"resample_to_match_source\")\n                meta_dict = deepcopy(img_dst.meta)\n                for k in (\"affine\", \"spatial_shape\"):  # keys that don't copy from img_dst in lazy evaluation\n                    meta_dict.pop(k, None)\n                img.meta.update(meta_dict)\n                img.meta[Key.FILENAME_OR_OBJ] = original_fname  # keep the original name, the others are overwritten\n        return img\n\n\nclass Spacing(InvertibleTransform, LazyTransform):\n    \"\"\"\n    Resample input image into the specified `pixdim`.\n\n    This 
transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = SpatialResample.backend\n\n    def __init__(\n        self,\n        pixdim: Sequence[float] | float | np.ndarray,\n        diagonal: bool = False,\n        mode: str | int = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.BORDER,\n        align_corners: bool = False,\n        dtype: DtypeLike = np.float64,\n        scale_extent: bool = False,\n        recompute_affine: bool = False,\n        min_pixdim: Sequence[float] | float | np.ndarray | None = None,\n        max_pixdim: Sequence[float] | float | np.ndarray | None = None,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            pixdim: output voxel spacing. if providing a single number, will use it for the first dimension.\n                items of the pixdim sequence map to the spatial dimensions of input image, if length\n                of pixdim sequence is longer than image spatial dimensions, will ignore the longer part,\n                if shorter, will pad with the last value. 
For example, for 3D image if pixdim is [1.0, 2.0] it\n                will be padded to [1.0, 2.0, 2.0]\n                if the components of the `pixdim` are non-positive values, the transform will use the\n                corresponding components of the original pixdim, which is computed from the `affine`\n                matrix of input image.\n            diagonal: whether to resample the input to have a diagonal affine matrix.\n                If True, the input data is resampled to the following affine::\n\n                    np.diag((pixdim_0, pixdim_1, ..., pixdim_n, 1))\n\n                This effectively resets the volume to the world coordinate system (RAS+ in nibabel).\n                The original orientation, rotation, shearing are not preserved.\n\n                If False, this transform preserves the axes orientation, orthogonal rotation and\n                translation components from the original affine. This option will not flip/swap axes\n                of the original data.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``\"border\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            dtype: data type for resampling computation. Defaults to ``float64`` for best precision.\n                If None, use the data type of input data. To be compatible with other modules,\n                the output data type is always ``float32``.\n            scale_extent: whether the scale is computed based on the spacing or the full extent of voxels,\n                default False. The option is ignored if output spatial size is specified when calling this transform.\n                See also: :py:func:`monai.data.utils.compute_shape_offset`. When this is True, `align_corners`\n                should be `True` because `compute_shape_offset` already provides the corner alignment shift/scaling.\n            recompute_affine: whether to recompute affine based on the output shape. The affine computed\n                analytically does not reflect the potential quantization errors in terms of the output shape.\n                Set this flag to True to recompute the output affine based on the actual pixdim. Default to ``False``.\n            min_pixdim: minimal input spacing to be resampled. If provided, input image with a larger spacing than this\n                value will be kept in its original spacing (not be resampled to `pixdim`). 
Set it to `None` to use the\n                value of `pixdim`. Default to `None`.\n            max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this\n                value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the\n                value of `pixdim`. Default to `None`.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n        \"\"\"\n        LazyTransform.__init__(self, lazy=lazy)\n        self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64)\n        self.min_pixdim = np.array(ensure_tuple(min_pixdim), dtype=np.float64)\n        self.max_pixdim = np.array(ensure_tuple(max_pixdim), dtype=np.float64)\n        self.diagonal = diagonal\n        self.scale_extent = scale_extent\n        self.recompute_affine = recompute_affine\n\n        for mn, mx in zip(self.min_pixdim, self.max_pixdim):\n            if (not np.isnan(mn)) and (not np.isnan(mx)) and ((mx < mn) or (mn < 0)):\n                raise ValueError(f\"min_pixdim {self.min_pixdim} must be positive, smaller than max {self.max_pixdim}.\")\n\n        self.sp_resample = SpatialResample(\n            mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy=lazy\n        )\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool) -> None:\n        self._lazy = val\n        self.sp_resample.lazy = val\n\n    def __call__(\n        self,\n        data_array: torch.Tensor,\n        mode: str | int | None = None,\n        padding_mode: str | None = None,\n        align_corners: bool | None = None,\n        dtype: DtypeLike = None,\n        scale_extent: bool | None = None,\n        output_spatial_shape: Sequence[int] | np.ndarray | int | None = None,\n        lazy: bool | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            data_array: in shape 
(num_channels, H[, W, ...]).\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"self.mode\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``\"self.padding_mode\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                Defaults to ``None``, effectively using the value of `self.align_corners`.\n            dtype: data type for resampling computation. Defaults to ``self.dtype``.\n                If None, use the data type of input data. 
To be compatible with other modules,\n                the output data type is always ``float32``.\n            scale_extent: whether the scale is computed based on the spacing or the full extent of voxels,\n                The option is ignored if output spatial size is specified when calling this transform.\n                See also: :py:func:`monai.data.utils.compute_shape_offset`. When this is True, `align_corners`\n                should be `True` because `compute_shape_offset` already provides the corner alignment shift/scaling.\n            output_spatial_shape: specify the shape of the output data_array. This is typically useful for\n                the inverse of `Spacingd` where sometimes we could not compute the exact shape due to the quantization\n                error with the affine.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Raises:\n            ValueError: When ``data_array`` has no spatial dimensions.\n            ValueError: When ``pixdim`` is nonpositive.\n\n        Returns:\n            data tensor or MetaTensor (resampled into `self.pixdim`).\n\n        \"\"\"\n        original_spatial_shape = (\n            data_array.peek_pending_shape() if isinstance(data_array, MetaTensor) else data_array.shape[1:]\n        )\n        sr = len(original_spatial_shape)\n        if sr <= 0:\n            raise ValueError(f\"data_array must have at least one spatial dimension, got {original_spatial_shape}.\")\n        affine_: np.ndarray\n        input_affine = data_array.peek_pending_affine() if isinstance(data_array, MetaTensor) else None\n        if input_affine is None:\n            warnings.warn(\"`data_array` is not of type MetaTensor, assuming affine to be identity.\")\n            # default to identity\n            input_affine = np.eye(sr + 1, dtype=np.float64)\n        affine_ = to_affine_nd(sr, convert_data_type(input_affine, np.ndarray)[0])\n\n        out_d = self.pixdim[:sr].copy()\n        if out_d.size < sr:\n            out_d = np.append(out_d, [out_d[-1]] * (sr - out_d.size))\n\n        orig_d = affine_to_spacing(affine_, sr, out_d.dtype)\n        for idx, (_d, mn, mx) in enumerate(\n            zip_longest(orig_d, self.min_pixdim[:sr], self.max_pixdim[:sr], fillvalue=np.nan)\n        ):\n            target = out_d[idx]\n            mn = target if np.isnan(mn) else min(mn, target)\n            mx = target if np.isnan(mx) else max(mx, target)\n            if mn > mx:\n                raise ValueError(f\"min_pixdim is larger than max_pixdim at dim {idx}: min {mn} max {mx} out {target}.\")\n            out_d[idx] = _d if (mn - AFFINE_TOL) <= _d <= (mx + AFFINE_TOL) else target\n\n        if not align_corners and scale_extent:\n            warnings.warn(\"align_corners=False is not compatible with scale_extent=True.\")\n\n        # compute output affine, 
shape and offset\n        new_affine = zoom_affine(affine_, out_d, diagonal=self.diagonal)\n        scale_extent = self.scale_extent if scale_extent is None else scale_extent\n        output_shape, offset = compute_shape_offset(original_spatial_shape, affine_, new_affine, scale_extent)\n        new_affine[:sr, -1] = offset[:sr]\n\n        actual_shape = list(output_shape) if output_spatial_shape is None else output_spatial_shape\n        lazy_ = self.lazy if lazy is None else lazy\n        data_array = self.sp_resample(\n            data_array,\n            dst_affine=torch.as_tensor(new_affine),\n            spatial_size=actual_shape,  # type: ignore\n            mode=mode,\n            padding_mode=padding_mode,\n            align_corners=align_corners,\n            dtype=dtype,\n            lazy=lazy_,\n        )\n        if self.recompute_affine and isinstance(data_array, MetaTensor):\n            if lazy_:\n                raise NotImplementedError(\"recompute_affine is not supported with lazy evaluation.\")\n            a = scale_affine(original_spatial_shape, actual_shape)\n            data_array.affine = convert_to_dst_type(a, affine_)[0]  # type: ignore\n        return data_array\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        return self.sp_resample.inverse(data)\n\n\nclass Orientation(InvertibleTransform, LazyTransform):\n    \"\"\"\n    Change the input image's orientation into the specified based on `axcodes`.\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]\n\n    @deprecated_arg_default(\n        name=\"labels\",\n        old_default=((\"L\", \"R\"), (\"P\", \"A\"), (\"I\", \"S\")),\n        new_default=None,\n        msg_suffix=(\n            \"Default value changed to None meaning that the transform now uses the 'space' of a \"\n            \"meta-tensor, if applicable, to determine appropriate axis labels.\"\n        ),\n    )\n    def __init__(\n        self,\n        axcodes: str | None = None,\n        as_closest_canonical: bool = False,\n        labels: Sequence[tuple[str, str]] | None = None,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            axcodes: N elements sequence for spatial ND input's orientation.\n                e.g. axcodes='RAS' represents 3D orientation:\n                (Left, Right), (Posterior, Anterior), (Inferior, Superior).\n                default orientation labels options are: 'L' and 'R' for the first dimension,\n                'P' and 'A' for the second, 'I' and 'S' for the third.\n            as_closest_canonical: if True, load the image as closest to canonical axis format.\n            labels: optional, None or sequence of (2,) sequences\n                (2,) sequences are labels for (beginning, end) of output axis.\n                If ``None``, an appropriate value is chosen depending on the\n                value of the ``\"space\"`` metadata item of a metatensor: if\n                ``\"space\"`` is ``\"LPS\"``, the value used is ``(('R', 'L'),\n                ('A', 'P'), ('I', 'S'))``, if ``\"space\"`` is ``\"RAS\"`` or the\n                input is not a meta-tensor or has no ``\"space\"`` item, the\n                value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. 
If not\n                ``None``, the provided value is always used and the ``\"space\"``\n                metadata item (if any) of the input is ignored.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n\n        Raises:\n            ValueError: When ``axcodes=None`` and ``as_closest_canonical=False``. Incompatible values.\n\n        See Also: `nibabel.orientations.ornt2axcodes`.\n\n        \"\"\"\n        LazyTransform.__init__(self, lazy=lazy)\n        if axcodes is None and not as_closest_canonical:\n            raise ValueError(\"Incompatible values: axcodes=None and as_closest_canonical=True.\")\n        if axcodes is not None and as_closest_canonical:\n            warnings.warn(\"using as_closest_canonical=True, axcodes ignored.\")\n        self.axcodes = axcodes\n        self.as_closest_canonical = as_closest_canonical\n        self.labels = labels\n\n    def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:\n        \"\"\"\n        If input type is `MetaTensor`, original affine is extracted with `data_array.affine`.\n        If input type is `torch.Tensor`, original affine is assumed to be identity.\n\n        Args:\n            data_array: in shape (num_channels, H[, W, ...]).\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. Defaults to None.\n\n        Raises:\n            ValueError: When ``data_array`` has no spatial dimensions.\n            ValueError: When ``axcodes`` spatiality differs from ``data_array``.\n\n        Returns:\n            data_array [reoriented in `self.axcodes`]. 
Output type will be `MetaTensor`\n                unless `get_track_meta() == False`, in which case it will be\n                `torch.Tensor`.\n\n        \"\"\"\n        spatial_shape = data_array.peek_pending_shape() if isinstance(data_array, MetaTensor) else data_array.shape[1:]\n        sr = len(spatial_shape)\n        if sr <= 0:\n            raise ValueError(f\"data_array must have at least one spatial dimension, got {spatial_shape}.\")\n        affine_: np.ndarray\n        affine_np: np.ndarray\n        labels = self.labels\n        if isinstance(data_array, MetaTensor):\n            affine_np, *_ = convert_data_type(data_array.peek_pending_affine(), np.ndarray)\n            affine_ = to_affine_nd(sr, affine_np)\n\n            # Set up \"labels\" such that LPS tensors are handled correctly by default\n            if (\n                self.labels is None\n                and \"space\" in data_array.meta\n                and SpaceKeys(data_array.meta[\"space\"]) == SpaceKeys.LPS\n            ):\n                labels = ((\"R\", \"L\"), (\"A\", \"P\"), (\"I\", \"S\"))  # value for LPS\n\n        else:\n            warnings.warn(\"`data_array` is not of type `MetaTensor, assuming affine to be identity.\")\n            # default to identity\n            affine_np = np.eye(sr + 1, dtype=np.float64)\n            affine_ = np.eye(sr + 1, dtype=np.float64)\n\n        src = nib.io_orientation(affine_)\n        if self.as_closest_canonical:\n            spatial_ornt = src\n        else:\n            if self.axcodes is None:\n                raise ValueError(\"Incompatible values: axcodes=None and as_closest_canonical=True.\")\n            if sr < len(self.axcodes):\n                warnings.warn(\n                    f\"axcodes ('{self.axcodes}') length is smaller than number of input spatial dimensions D={sr}.\\n\"\n                    f\"{self.__class__.__name__}: spatial shape = {spatial_shape}, channels = {data_array.shape[0]},\"\n                    \"please 
make sure the input is in the channel-first format.\"\n                )\n            dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=labels)\n            if len(dst) < sr:\n                raise ValueError(\n                    f\"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D\"\n                )\n            spatial_ornt = nib.orientations.ornt_transform(src, dst)\n        lazy_ = self.lazy if lazy is None else lazy\n        return orientation(data_array, affine_np, spatial_ornt, lazy=lazy_, transform_info=self.get_transform_info())\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        transform = self.pop_transform(data)\n        # Create inverse transform\n        orig_affine = transform[TraceKeys.EXTRA_INFO][\"original_affine\"]\n        labels = self.labels\n\n        # Set up \"labels\" such that LPS tensors are handled correctly by default\n        if (\n            isinstance(data, MetaTensor)\n            and self.labels is None\n            and \"space\" in data.meta\n            and SpaceKeys(data.meta[\"space\"]) == SpaceKeys.LPS\n        ):\n            labels = ((\"R\", \"L\"), (\"A\", \"P\"), (\"I\", \"S\"))  # value for LPS\n\n        orig_axcodes = nib.orientations.aff2axcodes(orig_affine, labels=labels)\n        inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=labels)\n        # Apply inverse\n        with inverse_transform.trace_transform(False):\n            data = inverse_transform(data)\n\n        return data\n\n\nclass Flip(InvertibleTransform, LazyTransform):\n    \"\"\"\n    Reverses the order of elements along the given spatial axis. Preserves shape.\n    See `torch.flip` documentation for additional details:\n    https://pytorch.org/docs/stable/generated/torch.flip.html\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        spatial_axis: spatial axes along which to flip over. Default is None.\n            The default `axis=None` will flip over all of the axes of the input array.\n            If axis is negative it counts from the last to the first axis.\n            If axis is a tuple of ints, flipping is performed on all of the axes\n            specified in the tuple.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, spatial_axis: Sequence[int] | int | None = None, lazy: bool = False) -> None:\n        LazyTransform.__init__(self, lazy=lazy)\n        self.spatial_axis = spatial_axis\n\n    def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: channel first array, must have shape: (num_channels, H[, W, ..., ])\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        lazy_ = self.lazy if lazy is None else lazy\n        return flip(img, self.spatial_axis, lazy=lazy_, transform_info=self.get_transform_info())  # type: ignore\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        self.pop_transform(data)\n        flipper = Flip(spatial_axis=self.spatial_axis)\n        with flipper.trace_transform(False):\n            return flipper(data)\n\n\nclass Resize(InvertibleTransform, LazyTransform):\n    \"\"\"\n    Resize the input image to given spatial size (with scaling, not cropping/padding).\n    Implemented using :py:class:`torch.nn.functional.interpolate`.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        spatial_size: expected shape of spatial dimensions after resize operation.\n            if some components of the `spatial_size` are non-positive values, the transform will use the\n            corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        size_mode: should be \"all\" or \"longest\", if \"all\", will use `spatial_size` for all the spatial dims,\n            if \"longest\", rescale the image so that only the longest side is equal to specified `spatial_size`,\n            which must be an int number in this case, keeping the aspect ratio of the initial image, refer to:\n            https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/\n            #albumentations.augmentations.geometric.resize.LongestMaxSize.\n        mode: {``\"nearest\"``, ``\"nearest-exact\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}\n            The interpolation mode. 
Defaults to ``\"area\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n        align_corners: This only has an effect when mode is\n            'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n        anti_aliasing: bool\n            Whether to apply a Gaussian filter to smooth the image prior\n            to downsampling. It is crucial to filter when downsampling\n            the image to avoid aliasing artifacts. See also ``skimage.transform.resize``\n        anti_aliasing_sigma: {float, tuple of floats}, optional\n            Standard deviation for Gaussian filtering used when anti-aliasing.\n            By default, this value is chosen as (s - 1) / 2 where s is the\n            downsampling factor, where s > 1. For the up-size case, s < 1, no\n            anti-aliasing is performed prior to rescaling.\n        dtype: data type for resampling computation. 
Defaults to ``float32``.\n            If None, use the data type of input data.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        spatial_size: Sequence[int] | int,\n        size_mode: str = \"all\",\n        mode: str = InterpolateMode.AREA,\n        align_corners: bool | None = None,\n        anti_aliasing: bool = False,\n        anti_aliasing_sigma: Sequence[float] | float | None = None,\n        dtype: DtypeLike | torch.dtype = torch.float32,\n        lazy: bool = False,\n    ) -> None:\n        LazyTransform.__init__(self, lazy=lazy)\n        self.size_mode = look_up_option(size_mode, [\"all\", \"longest\"])\n        self.spatial_size = spatial_size\n        self.mode = mode\n        self.align_corners = align_corners\n        self.anti_aliasing = anti_aliasing\n        self.anti_aliasing_sigma = anti_aliasing_sigma\n        self.dtype = dtype\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        mode: str | None = None,\n        align_corners: bool | None = None,\n        anti_aliasing: bool | None = None,\n        anti_aliasing_sigma: Sequence[float] | float | None = None,\n        dtype: DtypeLike | torch.dtype = None,\n        lazy: bool | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: channel first array, must have shape: (num_channels, H[, W, ..., ]).\n            mode: {``\"nearest\"``, ``\"nearest-exact\"``, ``\"linear\"``,\n                ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}\n                The interpolation mode. Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            align_corners: This only has an effect when mode is\n                'linear', 'bilinear', 'bicubic' or 'trilinear'. 
Defaults to ``self.align_corners``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            anti_aliasing: bool, optional\n                Whether to apply a Gaussian filter to smooth the image prior\n                to downsampling. It is crucial to filter when downsampling\n                the image to avoid aliasing artifacts. See also ``skimage.transform.resize``\n            anti_aliasing_sigma: {float, tuple of floats}, optional\n                Standard deviation for Gaussian filtering used when anti-aliasing.\n                By default, this value is chosen as (s - 1) / 2 where s is the\n                downsampling factor, where s > 1. For the up-size case, s < 1, no\n                anti-aliasing is performed prior to rescaling.\n            dtype: data type for resampling computation. Defaults to ``self.dtype``.\n                If None, use the data type of input data.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n        Raises:\n            ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions.\n\n        \"\"\"\n        anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing\n        anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma\n\n        input_ndim = img.ndim - 1  # spatial ndim\n        if self.size_mode == \"all\":\n            output_ndim = len(ensure_tuple(self.spatial_size))\n            if output_ndim > input_ndim:\n                input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1)\n                img = img.reshape(input_shape)\n            elif output_ndim < input_ndim:\n                raise ValueError(\n                    \"len(spatial_size) must be greater or equal to img spatial dimensions, \"\n                    f\"got spatial_size={output_ndim} img={input_ndim}.\"\n                )\n            _sp = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n            sp_size = fall_back_tuple(self.spatial_size, _sp)\n        else:  # for the \"longest\" mode\n            img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n            if not isinstance(self.spatial_size, int):\n                raise ValueError(\"spatial_size must be an int number if size_mode is 'longest'.\")\n            scale = self.spatial_size / max(img_size)\n            sp_size = tuple(int(round(s * scale)) for s in img_size)\n\n        _mode = self.mode if mode is None else mode\n        _align_corners = self.align_corners if align_corners is None else align_corners\n        _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)\n        lazy_ = self.lazy if lazy is None else lazy\n        return resize(  # type: ignore\n            img,\n            tuple(int(_s) for _s in sp_size),\n            _mode,\n            _align_corners,\n            _dtype,\n           
 input_ndim,\n            anti_aliasing,\n            anti_aliasing_sigma,\n            lazy_,\n            self.get_transform_info(),\n        )\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        transform = self.pop_transform(data)\n        return self.inverse_transform(data, transform)\n\n    def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:\n        orig_size = transform[TraceKeys.ORIG_SIZE]\n        mode = transform[TraceKeys.EXTRA_INFO][\"mode\"]\n        align_corners = transform[TraceKeys.EXTRA_INFO][\"align_corners\"]\n        dtype = transform[TraceKeys.EXTRA_INFO][\"dtype\"]\n        xform = Resize(\n            spatial_size=orig_size,\n            mode=mode,\n            align_corners=None if align_corners == TraceKeys.NONE else align_corners,\n            dtype=dtype,\n        )\n        with xform.trace_transform(False):\n            data = xform(data)\n        for _ in range(transform[TraceKeys.EXTRA_INFO][\"new_dim\"]):\n            data = data.squeeze(-1)  # remove the additional dims\n        return data\n\n\nclass Rotate(InvertibleTransform, LazyTransform):\n    \"\"\"\n    Rotates an input image by given angle using :py:class:`monai.networks.layers.AffineTransform`.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D.\n        keep_size: If it is True, the output shape is kept the same as the input.\n            If it is False, the output shape is adapted so that the\n            input array is contained completely in the output. Default is True.\n        mode: {``\"bilinear\"``, ``\"nearest\"``}\n            Interpolation mode to calculate output values. 
Defaults to ``\"bilinear\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n            Padding mode for outside grid values. Defaults to ``\"border\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        align_corners: Defaults to False.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        dtype: data type for resampling computation. Defaults to ``float32``.\n            If None, use the data type of input data. To be compatible with other modules,\n            the output data type is always ``float32``.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        angle: Sequence[float] | float,\n        keep_size: bool = True,\n        mode: str = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.BORDER,\n        align_corners: bool = False,\n        dtype: DtypeLike | torch.dtype = torch.float32,\n        lazy: bool = False,\n    ) -> None:\n        LazyTransform.__init__(self, lazy=lazy)\n        self.angle = angle\n        self.keep_size = keep_size\n        self.mode: str = mode\n        self.padding_mode: str = padding_mode\n        self.align_corners = align_corners\n        self.dtype = dtype\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        mode: str | None = None,\n        padding_mode: str | None = None,\n        align_corners: bool | None = None,\n        dtype: DtypeLike | torch.dtype = None,\n        lazy: bool | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D].\n            mode: {``\"bilinear\"``, 
``\"nearest\"``}\n                Interpolation mode to calculate output values. Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``self.padding_mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                align_corners: Defaults to ``self.align_corners``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            align_corners: Defaults to ``self.align_corners``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            dtype: data type for resampling computation. Defaults to ``self.dtype``.\n                If None, use the data type of input data. To be compatible with other modules,\n                the output data type is always ``float32``.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Raises:\n            ValueError: When ``img`` spatially is not one of [2D, 3D].\n\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)\n        _mode = mode or self.mode\n        _padding_mode = padding_mode or self.padding_mode\n        _align_corners = self.align_corners if align_corners is None else align_corners\n        im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n        output_shape = im_shape if self.keep_size else None\n        lazy_ = self.lazy if lazy is None else lazy\n        return rotate(  # type: ignore\n            img,\n            self.angle,\n            output_shape,\n            _mode,\n            _padding_mode,\n            _align_corners,\n            _dtype,\n            lazy=lazy_,\n            transform_info=self.get_transform_info(),\n        )\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        transform = self.pop_transform(data)\n        return self.inverse_transform(data, transform)\n\n    def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:\n        fwd_rot_mat = transform[TraceKeys.EXTRA_INFO][\"rot_mat\"]\n        mode = transform[TraceKeys.EXTRA_INFO][\"mode\"]\n        padding_mode = transform[TraceKeys.EXTRA_INFO][\"padding_mode\"]\n        align_corners = transform[TraceKeys.EXTRA_INFO][\"align_corners\"]\n        dtype = transform[TraceKeys.EXTRA_INFO][\"dtype\"]\n        inv_rot_mat = linalg_inv(convert_to_numpy(fwd_rot_mat))\n\n        _, _m, _p, _ = resolves_modes(mode, padding_mode)\n        xform = AffineTransform(\n            normalized=False,\n            mode=_m,\n            padding_mode=_p,\n            align_corners=False if align_corners == TraceKeys.NONE else align_corners,\n            reverse_indexing=True,\n        )\n        img_t: torch.Tensor = convert_data_type(data, MetaTensor, 
dtype=dtype)[0]\n        transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t)\n        sp_size = transform[TraceKeys.ORIG_SIZE]\n        out: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=sp_size).float().squeeze(0)\n        out = convert_to_dst_type(out, dst=data, dtype=out.dtype)[0]\n        if isinstance(out, MetaTensor):\n            affine = convert_to_tensor(out.peek_pending_affine(), track_meta=False)\n            mat = to_affine_nd(len(affine) - 1, transform_t)\n            out.affine @= convert_to_dst_type(mat, affine)[0]\n        return out\n\n\nclass Zoom(InvertibleTransform, LazyTransform):\n    \"\"\"\n    Zooms an ND image using :py:class:`torch.nn.functional.interpolate`.\n    For details, please see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html.\n\n    Different from :py:class:`monai.transforms.resize`, this transform takes scaling factors\n    as input, and provides an option of preserving the input spatial size.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        zoom: The zoom factor along the spatial axes.\n            If a float, zoom is the same for each spatial axis.\n            If a sequence, zoom should contain one value for each spatial axis.\n        mode: {``\"nearest\"``, ``\"nearest-exact\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}\n            The interpolation mode. 
Defaults to ``\"area\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n        padding_mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. Defaults to ``\"edge\"``.\n            The mode to pad data after zooming.\n            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        align_corners: This only has an effect when mode is\n            'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n        dtype: data type for resampling computation. 
Defaults to ``float32``.\n            If None, use the data type of input data.\n        keep_size: Should keep original size (padding/slicing if needed), default is True.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        zoom: Sequence[float] | float,\n        mode: str = InterpolateMode.AREA,\n        padding_mode: str = NumpyPadMode.EDGE,\n        align_corners: bool | None = None,\n        dtype: DtypeLike | torch.dtype = torch.float32,\n        keep_size: bool = True,\n        lazy: bool = False,\n        **kwargs,\n    ) -> None:\n        LazyTransform.__init__(self, lazy=lazy)\n        self.zoom = zoom\n        self.mode = mode\n        self.padding_mode = padding_mode\n        self.align_corners = align_corners\n        self.dtype = dtype\n        self.keep_size = keep_size\n        self.kwargs = kwargs\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        mode: str | None = None,\n        padding_mode: str | None = None,\n        align_corners: bool | None = None,\n        dtype: DtypeLike | torch.dtype = None,\n        lazy: bool | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: channel first array, must have shape: (num_channels, H[, W, ..., ]).\n            mode: {``\"nearest\"``, ``\"nearest-exact\"``, ``\"linear\"``,\n                ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}\n                The interpolation mode. 
Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            padding_mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n                ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n                available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n                One of the listed string values or a user supplied function. Defaults to ``\"edge\"``.\n                The mode to pad data after zooming.\n                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            align_corners: This only has an effect when mode is\n                'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            dtype: data type for resampling computation. Defaults to ``self.dtype``.\n                If None, use the data type of input data.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1)  # match the spatial image dim\n        _mode = self.mode if mode is None else mode\n        _padding_mode = padding_mode or self.padding_mode\n        _align_corners = self.align_corners if align_corners is None else align_corners\n        _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)\n        lazy_ = self.lazy if lazy is None else lazy\n        return zoom(  # type: ignore\n            img,\n            _zoom,\n            self.keep_size,\n            _mode,\n            _padding_mode,\n            _align_corners,\n            _dtype,\n            lazy=lazy_,\n            transform_info=self.get_transform_info(),\n            **self.kwargs,\n        )\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        transform = self.pop_transform(data)\n        return self.inverse_transform(data, transform)\n\n    def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:\n        if transform[TraceKeys.EXTRA_INFO][\"do_padcrop\"]:\n            orig_size = transform[TraceKeys.ORIG_SIZE]\n            pad_or_crop = ResizeWithPadOrCrop(spatial_size=orig_size, mode=\"edge\")\n            padcrop_xform = transform[TraceKeys.EXTRA_INFO][\"padcrop\"]\n            padcrop_xform[TraceKeys.EXTRA_INFO][\"pad_info\"][TraceKeys.ID] = TraceKeys.NONE\n            padcrop_xform[TraceKeys.EXTRA_INFO][\"crop_info\"][TraceKeys.ID] = TraceKeys.NONE\n            # this uses inverse because spatial_size // 2 in the forward pass of center crop may cause issues\n            data = pad_or_crop.inverse_transform(data, padcrop_xform)  # type: ignore\n        # Create inverse transform\n        mode = transform[TraceKeys.EXTRA_INFO][\"mode\"]\n        align_corners = transform[TraceKeys.EXTRA_INFO][\"align_corners\"]\n        dtype = transform[TraceKeys.EXTRA_INFO][\"dtype\"]\n    
    inverse_transform = Resize(spatial_size=transform[TraceKeys.ORIG_SIZE])\n        # Apply inverse\n        with inverse_transform.trace_transform(False):\n            out = inverse_transform(\n                data, mode=mode, align_corners=None if align_corners == TraceKeys.NONE else align_corners, dtype=dtype\n            )\n        return out\n\n\nclass Rotate90(InvertibleTransform, LazyTransform):\n    \"\"\"\n    Rotate an array by 90 degrees in the plane specified by `axes`.\n    See `torch.rot90` for additional details:\n    https://pytorch.org/docs/stable/generated/torch.rot90.html#torch-rot90.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, k: int = 1, spatial_axes: tuple[int, int] = (0, 1), lazy: bool = False) -> None:\n        \"\"\"\n        Args:\n            k: number of times to rotate by 90 degrees.\n            spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.\n                Default: (0, 1), this is the first two axis in spatial dimensions.\n                If axis is negative it counts from the last to the first axis.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n        \"\"\"\n        LazyTransform.__init__(self, lazy=lazy)\n        self.k = (4 + (k % 4)) % 4  # 0, 1, 2, 3\n        spatial_axes_: tuple[int, int] = ensure_tuple(spatial_axes)\n        if len(spatial_axes_) != 2:\n            raise ValueError(f\"spatial_axes must be 2 numbers to define the plane to rotate, got {spatial_axes_}.\")\n        self.spatial_axes = spatial_axes_\n\n    def __call__(self, img: torch.Tensor, lazy: bool | None = None) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: channel first array, must have shape: (num_channels, H[, W, ..., ]),\n            lazy: a flag to 
indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. Defaults to None.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        axes = map_spatial_axes(img.ndim, self.spatial_axes)\n        lazy_ = self.lazy if lazy is None else lazy\n        return rotate90(img, axes, self.k, lazy=lazy_, transform_info=self.get_transform_info())  # type: ignore\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        transform = self.pop_transform(data)\n        return self.inverse_transform(data, transform)\n\n    def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:\n        axes = transform[TraceKeys.EXTRA_INFO][\"axes\"]\n        k = transform[TraceKeys.EXTRA_INFO][\"k\"]\n        inv_k = 4 - k % 4\n        xform = Rotate90(k=inv_k, spatial_axes=axes)\n        with xform.trace_transform(False):\n            return xform(data)\n\n\nclass RandRotate90(RandomizableTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    With probability `prob`, input arrays are rotated by 90 degrees\n    in the plane specified by `spatial_axes`.\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = Rotate90.backend\n\n    def __init__(\n        self, prob: float = 0.1, max_k: int = 3, spatial_axes: tuple[int, int] = (0, 1), lazy: bool = False\n    ) -> None:\n        \"\"\"\n        Args:\n            prob: probability of rotating.\n                (Default 0.1, with 10% probability it returns a rotated array)\n            max_k: number of rotations will be sampled from `np.random.randint(max_k) + 1`, (Default 3).\n            spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.\n                Default: (0, 1), this is the first two axis in spatial dimensions.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n        \"\"\"\n        RandomizableTransform.__init__(self, prob)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.max_k = max_k\n        self.spatial_axes = spatial_axes\n\n        self._rand_k = 0\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        self._rand_k = self.R.randint(self.max_k) + 1\n\n    def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: channel first array, must have shape: (num_channels, H[, W, ..., ]),\n            randomize: whether to execute `randomize()` function first, default to True.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n        \"\"\"\n\n        if randomize:\n            self.randomize()\n\n        lazy_ = self.lazy if lazy is None else lazy\n        if self._do_transform:\n            xform = Rotate90(self._rand_k, self.spatial_axes, lazy=lazy_)\n            out = xform(img)\n        else:\n            out = convert_to_tensor(img, track_meta=get_track_meta())\n\n        self.push_transform(out, replace=True, lazy=lazy_)\n        return out\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        xform_info = self.pop_transform(data)\n        if not xform_info[TraceKeys.DO_TRANSFORM]:\n            return data\n        rotate_xform = xform_info[TraceKeys.EXTRA_INFO]\n        return Rotate90().inverse_transform(data, rotate_xform)\n\n\nclass RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Randomly rotate the input arrays.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        range_x: Range of rotation angle in radians in the plane defined by the first and second axes.\n            If single number, angle is uniformly sampled from (-range_x, range_x).\n        range_y: Range of rotation angle in radians in the plane defined by the first and third axes.\n            If single number, angle is uniformly sampled from (-range_y, range_y). only work for 3D data.\n        range_z: Range of rotation angle in radians in the plane defined by the second and third axes.\n            If single number, angle is uniformly sampled from (-range_z, range_z). only work for 3D data.\n        prob: Probability of rotation.\n        keep_size: If it is False, the output shape is adapted so that the\n            input array is contained completely in the output.\n            If it is True, the output shape is the same as the input. 
Default is True.\n        mode: {``\"bilinear\"``, ``\"nearest\"``}\n            Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n            Padding mode for outside grid values. Defaults to ``\"border\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        align_corners: Defaults to False.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        dtype: data type for resampling computation. Defaults to ``float32``.\n            If None, use the data type of input data. To be compatible with other modules,\n            the output data type is always ``float32``.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n    \"\"\"\n\n    backend = Rotate.backend\n\n    def __init__(\n        self,\n        range_x: tuple[float, float] | float = 0.0,\n        range_y: tuple[float, float] | float = 0.0,\n        range_z: tuple[float, float] | float = 0.0,\n        prob: float = 0.1,\n        keep_size: bool = True,\n        mode: str = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.BORDER,\n        align_corners: bool = False,\n        dtype: DtypeLike | torch.dtype = np.float32,\n        lazy: bool = False,\n    ) -> None:\n        RandomizableTransform.__init__(self, prob)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.range_x = ensure_tuple(range_x)\n        if len(self.range_x) == 1:\n            self.range_x = tuple(sorted([-self.range_x[0], self.range_x[0]]))\n        self.range_y = ensure_tuple(range_y)\n        if len(self.range_y) == 1:\n            self.range_y = tuple(sorted([-self.range_y[0], self.range_y[0]]))\n        self.range_z = 
ensure_tuple(range_z)\n        if len(self.range_z) == 1:\n            self.range_z = tuple(sorted([-self.range_z[0], self.range_z[0]]))\n\n        self.keep_size = keep_size\n        self.mode: str = mode\n        self.padding_mode: str = padding_mode\n        self.align_corners = align_corners\n        self.dtype = dtype\n\n        self.x = 0.0\n        self.y = 0.0\n        self.z = 0.0\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        self.x = self.R.uniform(low=self.range_x[0], high=self.range_x[1])\n        self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1])\n        self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1])\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        mode: str | None = None,\n        padding_mode: str | None = None,\n        align_corners: bool | None = None,\n        dtype: DtypeLike | torch.dtype = None,\n        randomize: bool = True,\n        lazy: bool | None = None,\n    ):\n        \"\"\"\n        Args:\n            img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D).\n            mode: {``\"bilinear\"``, ``\"nearest\"``}\n                Interpolation mode to calculate output values. Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``self.padding_mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            align_corners: Defaults to ``self.align_corners``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            dtype: data type for resampling computation. 
Defaults to ``self.dtype``.\n                If None, use the data type of input data. To be compatible with other modules,\n                the output data type is always ``float32``.\n            randomize: whether to execute `randomize()` function first, default to True.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. Defaults to None.\n        \"\"\"\n        if randomize:\n            self.randomize()\n\n        lazy_ = self.lazy if lazy is None else lazy\n        if self._do_transform:\n            ndim = len(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:])\n            rotator = Rotate(\n                angle=self.x if ndim == 2 else (self.x, self.y, self.z),\n                keep_size=self.keep_size,\n                mode=mode or self.mode,\n                padding_mode=padding_mode or self.padding_mode,\n                align_corners=self.align_corners if align_corners is None else align_corners,\n                dtype=dtype or self.dtype or img.dtype,\n                lazy=lazy_,\n            )\n            out = rotator(img)\n        else:\n            out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32)\n        self.push_transform(out, replace=True, lazy=lazy_)\n        return out\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        xform_info = self.pop_transform(data)\n        if not xform_info[TraceKeys.DO_TRANSFORM]:\n            return data\n        return Rotate(0).inverse_transform(data, xform_info[TraceKeys.EXTRA_INFO])\n\n\nclass RandFlip(RandomizableTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Randomly flips the image along axes. 
Preserves shape.\n    See numpy.flip for additional details.\n    https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        prob: Probability of flipping.\n        spatial_axis: Spatial axes along which to flip over. Default is None.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n    \"\"\"\n\n    backend = Flip.backend\n\n    def __init__(self, prob: float = 0.1, spatial_axis: Sequence[int] | int | None = None, lazy: bool = False) -> None:\n        RandomizableTransform.__init__(self, prob)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.flipper = Flip(spatial_axis=spatial_axis, lazy=lazy)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool):\n        self.flipper.lazy = val\n        self._lazy = val\n\n    def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: channel first array, must have shape: (num_channels, H[, W, ..., ]),\n            randomize: whether to execute `randomize()` function first, default to True.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n        \"\"\"\n        if randomize:\n            self.randomize(None)\n        lazy_ = self.lazy if lazy is None else lazy\n        out = self.flipper(img, lazy=lazy_) if self._do_transform else img\n        out = convert_to_tensor(out, track_meta=get_track_meta())\n        self.push_transform(out, replace=True, lazy=lazy_)\n        return out\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        transform = self.pop_transform(data)\n        if not transform[TraceKeys.DO_TRANSFORM]:\n            return data\n        data.applied_operations.append(transform[TraceKeys.EXTRA_INFO])  # type: ignore\n        return self.flipper.inverse(data)\n\n\nclass RandAxisFlip(RandomizableTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Randomly select a spatial axis and flip along it.\n    See numpy.flip for additional details.\n    https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        prob: Probability of flipping.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n    \"\"\"\n\n    backend = Flip.backend\n\n    def __init__(self, prob: float = 0.1, lazy: bool = False) -> None:\n        RandomizableTransform.__init__(self, prob)\n        LazyTransform.__init__(self, lazy=lazy)\n        self._axis: int | None = None\n        self.flipper = Flip(spatial_axis=self._axis)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool):\n        self.flipper.lazy = val\n        self._lazy = val\n\n    def randomize(self, data: NdarrayOrTensor) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        self._axis = self.R.randint(data.ndim - 1)\n\n    def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: channel first array, must have shape: (num_channels, H[, W, ..., ])\n            randomize: whether to execute `randomize()` function first, default to True.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n        \"\"\"\n        if randomize:\n            self.randomize(data=img)\n\n        lazy_ = self.lazy if lazy is None else lazy\n        if self._do_transform:\n            self.flipper.spatial_axis = self._axis\n            out = self.flipper(img, lazy=lazy_)\n        else:\n            out = convert_to_tensor(img, track_meta=get_track_meta())\n        self.push_transform(out, replace=True, lazy=lazy_)\n        return out\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        transform = self.pop_transform(data)\n        if not transform[TraceKeys.DO_TRANSFORM]:\n            return data\n        flipper = Flip(spatial_axis=transform[TraceKeys.EXTRA_INFO][TraceKeys.EXTRA_INFO][\"axes\"])\n        with flipper.trace_transform(False):\n            return flipper(data)\n\n\nclass RandZoom(RandomizableTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Randomly zooms input arrays with given probability within given zoom range.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        prob: Probability of zooming.\n        min_zoom: Min zoom factor. Can be float or sequence same size as image.\n            If a float, select a random factor from `[min_zoom, max_zoom]` then apply to all spatial dims\n            to keep the original spatial shape ratio.\n            If a sequence, min_zoom should contain one value for each spatial axis.\n            If 2 values provided for 3D data, use the first value for both H & W dims to keep the same zoom ratio.\n        max_zoom: Max zoom factor. 
Can be float or sequence same size as image.\n            If a float, select a random factor from `[min_zoom, max_zoom]` then apply to all spatial dims\n            to keep the original spatial shape ratio.\n            If a sequence, max_zoom should contain one value for each spatial axis.\n            If 2 values provided for 3D data, use the first value for both H & W dims to keep the same zoom ratio.\n        mode: {``\"nearest\"``, ``\"nearest-exact\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}\n            The interpolation mode. Defaults to ``\"area\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n        padding_mode: available modes for numpy array: {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. Defaults to ``\"edge\"``.\n            The mode to pad data after zooming.\n            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        align_corners: This only has an effect when mode is\n            'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n        dtype: data type for resampling computation. 
Defaults to ``float32``.\n            If None, use the data type of input data.\n        keep_size: Should keep original size (pad if needed), default is True.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    \"\"\"\n\n    backend = Zoom.backend\n\n    def __init__(\n        self,\n        prob: float = 0.1,\n        min_zoom: Sequence[float] | float = 0.9,\n        max_zoom: Sequence[float] | float = 1.1,\n        mode: str = InterpolateMode.AREA,\n        padding_mode: str = NumpyPadMode.EDGE,\n        align_corners: bool | None = None,\n        dtype: DtypeLike | torch.dtype = torch.float32,\n        keep_size: bool = True,\n        lazy: bool = False,\n        **kwargs,\n    ) -> None:\n        RandomizableTransform.__init__(self, prob)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.min_zoom = ensure_tuple(min_zoom)\n        self.max_zoom = ensure_tuple(max_zoom)\n        if len(self.min_zoom) != len(self.max_zoom):\n            raise ValueError(\n                f\"min_zoom and max_zoom must have same length, got {len(self.min_zoom)} and {len(self.max_zoom)}.\"\n            )\n        self.mode = mode\n        self.padding_mode = padding_mode\n        self.align_corners = align_corners\n        self.dtype = dtype\n        self.keep_size = keep_size\n        self.kwargs = kwargs\n\n        self._zoom: Sequence[float] = [1.0]\n\n    def randomize(self, img: NdarrayOrTensor) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)]\n        if len(self._zoom) == 1:\n            # to keep the spatial shape ratio, use same random zoom factor for all dims\n            self._zoom = 
ensure_tuple_rep(self._zoom[0], img.ndim - 1)\n        elif len(self._zoom) == 2 and img.ndim > 3:\n            # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim\n            self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1])\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        mode: str | None = None,\n        padding_mode: str | None = None,\n        align_corners: bool | None = None,\n        dtype: DtypeLike | torch.dtype = None,\n        randomize: bool = True,\n        lazy: bool | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D).\n            mode: {``\"nearest\"``, ``\"nearest-exact\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``,\n                ``\"area\"``}, the interpolation mode. Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            padding_mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n                ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n                available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n                One of the listed string values or a user supplied function. Defaults to ``\"constant\"``.\n                The mode to pad data after zooming.\n                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            align_corners: This only has an effect when mode is\n                'linear', 'bilinear', 'bicubic' or 'trilinear'. 
Defaults to ``self.align_corners``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            dtype: data type for resampling computation. Defaults to ``self.dtype``.\n                If None, use the data type of input data.\n            randomize: whether to execute `randomize()` function first, default to True.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. Defaults to None.\n        \"\"\"\n        # match the spatial image dim\n        if randomize:\n            self.randomize(img=img)\n\n        lazy_ = self.lazy if lazy is None else lazy\n        if not self._do_transform:\n            out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32)\n        else:\n            xform = Zoom(\n                self._zoom,\n                keep_size=self.keep_size,\n                mode=mode or self.mode,\n                padding_mode=padding_mode or self.padding_mode,\n                align_corners=self.align_corners if align_corners is None else align_corners,\n                dtype=dtype or self.dtype,\n                lazy=lazy_,\n                **self.kwargs,\n            )\n            out = xform(img)\n        self.push_transform(out, replace=True, lazy=lazy_)\n        return out  # type: ignore\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        xform_info = self.pop_transform(data)\n        if not xform_info[TraceKeys.DO_TRANSFORM]:\n            return data\n        return Zoom(self._zoom).inverse_transform(data, xform_info[TraceKeys.EXTRA_INFO])\n\n\nclass AffineGrid(LazyTransform):\n    \"\"\"\n    Affine transforms on the coordinates.\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        rotate_params: a rotation angle in radians, a scalar for 2D image, a tuple of 3 floats for 3D.\n            Defaults to no rotation.\n        shear_params: shearing factors for affine matrix, take a 3D affine as example::\n\n            [\n                [1.0, params[0], params[1], 0.0],\n                [params[2], 1.0, params[3], 0.0],\n                [params[4], params[5], 1.0, 0.0],\n                [0.0, 0.0, 0.0, 1.0],\n            ]\n\n            a tuple of 2 floats for 2D, a tuple of 6 floats for 3D. Defaults to no shearing.\n        translate_params: a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Translation is in\n            pixel/voxel relative to the center of the input image. Defaults to no translation.\n        scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D,\n            a tuple of 3 floats for 3D. Defaults to `1.0`.\n        dtype: data type for the grid computation. Defaults to ``float32``.\n            If ``None``, use the data type of input data (if `grid` is provided).\n        device: device on which the tensor will be allocated, if a new grid is generated.\n        align_corners: Defaults to False.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        affine: If applied, ignore the params (`rotate_params`, etc.) and use the\n            supplied matrix. 
Should be square with each side = num of image spatial\n            dimensions + 1.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        rotate_params: Sequence[float] | float | None = None,\n        shear_params: Sequence[float] | float | None = None,\n        translate_params: Sequence[float] | float | None = None,\n        scale_params: Sequence[float] | float | None = None,\n        device: torch.device | None = None,\n        dtype: DtypeLike = np.float32,\n        align_corners: bool = False,\n        affine: NdarrayOrTensor | None = None,\n        lazy: bool = False,\n    ) -> None:\n        LazyTransform.__init__(self, lazy=lazy)\n        self.rotate_params = rotate_params\n        self.shear_params = shear_params\n        self.translate_params = translate_params\n        self.scale_params = scale_params\n        self.device = device\n        _dtype = get_equivalent_dtype(dtype, torch.Tensor)\n        self.dtype = _dtype if _dtype in (torch.float16, torch.float64, None) else torch.float32\n        self.align_corners = align_corners\n        self.affine = affine\n\n    def __call__(\n        self, spatial_size: Sequence[int] | None = None, grid: torch.Tensor | None = None, lazy: bool | None = None\n    ) -> tuple[torch.Tensor | None, torch.Tensor]:\n        \"\"\"\n        The grid can be initialized with a `spatial_size` parameter, or provided directly as `grid`.\n        Therefore, either `spatial_size` or `grid` must be provided.\n        When initialising from `spatial_size`, the backend \"torch\" will be used.\n\n        Args:\n            spatial_size: output grid size.\n            grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. 
Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. Defaults to None.\n        Raises:\n            ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values.\n\n        \"\"\"\n        lazy_ = self.lazy if lazy is None else lazy\n        _device: torch.device | None\n\n        if not lazy_:\n            if grid is None:  # create grid from spatial_size\n                if spatial_size is None:\n                    raise ValueError(\"Incompatible values: grid=None and spatial_size=None.\")\n                grid_ = create_grid(spatial_size, device=self.device, backend=\"torch\", dtype=self.dtype)\n            else:\n                grid_ = grid\n            _dtype = self.dtype or grid_.dtype\n            grid_: torch.Tensor = convert_to_tensor(grid_, dtype=_dtype, track_meta=get_track_meta())  # type: ignore\n            _device = torch.device(grid_.device)  # type: ignore\n            spatial_dims = len(grid_.shape) - 1\n        else:\n            _device = self.device  # type: ignore[assignment]\n            spatial_dims = len(spatial_size)  # type: ignore\n        _b = TransformBackends.TORCH\n        affine: torch.Tensor\n        if self.affine is None:\n            affine = torch.eye(spatial_dims + 1, device=_device)\n            if self.rotate_params:\n                affine @= create_rotate(spatial_dims, self.rotate_params, device=_device, backend=_b)  # type: ignore[assignment]\n            if self.shear_params:\n                affine @= create_shear(spatial_dims, self.shear_params, device=_device, backend=_b)  # type: ignore[assignment]\n            if self.translate_params:\n                affine @= create_translate(spatial_dims, self.translate_params, device=_device, backend=_b)  # type: ignore[assignment]\n            if self.scale_params:\n                affine @= create_scale(spatial_dims, self.scale_params, device=_device, backend=_b)  # type: ignore[assignment]\n   
     else:\n            affine = self.affine  # type: ignore\n        affine = to_affine_nd(spatial_dims, affine)\n        if lazy_:\n            return None, affine\n\n        affine = convert_to_tensor(affine, device=grid_.device, dtype=grid_.dtype, track_meta=False)  # type: ignore\n        if self.align_corners:\n            sc = create_scale(\n                spatial_dims, [max(d, 2) / (max(d, 2) - 1) for d in grid_.shape[1:]], device=_device, backend=_b\n            )\n            sc = convert_to_dst_type(sc, affine)[0]\n            grid_ = ((affine @ sc) @ grid_.view((grid_.shape[0], -1))).view([-1] + list(grid_.shape[1:]))\n        else:\n            grid_ = (affine @ grid_.view((grid_.shape[0], -1))).view([-1] + list(grid_.shape[1:]))\n        return grid_, affine  # type: ignore[return-value]\n\n\nclass RandAffineGrid(Randomizable, LazyTransform):\n    \"\"\"\n    Generate randomised affine grid.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = AffineGrid.backend\n\n    def __init__(\n        self,\n        rotate_range: RandRange = None,\n        shear_range: RandRange = None,\n        translate_range: RandRange = None,\n        scale_range: RandRange = None,\n        device: torch.device | None = None,\n        dtype: DtypeLike = np.float32,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            rotate_range: angle range in radians. If element `i` is a pair of (min, max) values, then\n                `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter\n                for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used.\n                This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be\n                in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. 
Setting a single value will use `[-x, x]`\n                for dim0 and nothing for the remaining dimensions.\n            shear_range: shear range with format matching `rotate_range`, it defines the range to randomly select\n                shearing factors (a tuple of 2 floats for 2D, a tuple of 6 floats for 3D) for affine matrix,\n                take a 3D affine as example::\n\n                    [\n                        [1.0, params[0], params[1], 0.0],\n                        [params[2], 1.0, params[3], 0.0],\n                        [params[4], params[5], 1.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0],\n                    ]\n\n            translate_range: translate range with format matching `rotate_range`, it defines the range to randomly\n                select the translation (in pixels/voxels) for every spatial dim.\n            scale_range: scaling range with format matching `rotate_range`. It defines the range to randomly select\n                the scale factor for every spatial dim. A value of 1.0 is added to the result.\n                This allows 0 to correspond to no change (i.e., a scaling of 1.0).\n            device: device to store the output grid data.\n            dtype: data type for the grid computation. 
Defaults to ``np.float32``.\n                If ``None``, use the data type of input data (if `grid` is provided).\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n\n        See also:\n            - :py:meth:`monai.transforms.utils.create_rotate`\n            - :py:meth:`monai.transforms.utils.create_shear`\n            - :py:meth:`monai.transforms.utils.create_translate`\n            - :py:meth:`monai.transforms.utils.create_scale`\n\n        \"\"\"\n        LazyTransform.__init__(self, lazy=lazy)\n        self.rotate_range = ensure_tuple(rotate_range)\n        self.shear_range = ensure_tuple(shear_range)\n        self.translate_range = ensure_tuple(translate_range)\n        self.scale_range = ensure_tuple(scale_range)\n\n        self.rotate_params: list[float] | None = None\n        self.shear_params: list[float] | None = None\n        self.translate_params: list[float] | None = None\n        self.scale_params: list[float] | None = None\n\n        self.device = device\n        self.dtype = dtype\n        self.affine: torch.Tensor | None = torch.eye(4, dtype=torch.float64)\n\n    def _get_rand_param(self, param_range, add_scalar: float = 0.0):\n        out_param = []\n        for f in param_range:\n            if issequenceiterable(f):\n                if len(f) != 2:\n                    raise ValueError(f\"If giving range as [min,max], should have 2 elements per dim, got {f}.\")\n                out_param.append(self.R.uniform(f[0], f[1]) + add_scalar)\n            elif f is not None:\n                out_param.append(self.R.uniform(-f, f) + add_scalar)\n        return out_param\n\n    def randomize(self, data: Any | None = None) -> None:\n        self.rotate_params = self._get_rand_param(self.rotate_range)\n        self.shear_params = self._get_rand_param(self.shear_range)\n        self.translate_params = self._get_rand_param(self.translate_range)\n        self.scale_params = 
self._get_rand_param(self.scale_range, 1.0)\n\n    def __call__(\n        self,\n        spatial_size: Sequence[int] | None = None,\n        grid: NdarrayOrTensor | None = None,\n        randomize: bool = True,\n        lazy: bool | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            spatial_size: output grid size.\n            grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D.\n            randomize: boolean as to whether the grid parameters governing the grid should be randomized.\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. Defaults to None.\n\n        Returns:\n            a 2D (3xHxW) or 3D (4xHxWxD) grid.\n        \"\"\"\n        if randomize:\n            self.randomize()\n        lazy_ = self.lazy if lazy is None else lazy\n        affine_grid = AffineGrid(\n            rotate_params=self.rotate_params,\n            shear_params=self.shear_params,\n            translate_params=self.translate_params,\n            scale_params=self.scale_params,\n            device=self.device,\n            dtype=self.dtype,\n            lazy=lazy_,\n        )\n        if lazy_:  # return the affine only, don't construct the grid\n            self.affine = affine_grid(spatial_size, grid)[1]  # type: ignore\n            return None  # type: ignore\n        _grid: torch.Tensor\n        _grid, self.affine = affine_grid(spatial_size, grid)  # type: ignore\n        return _grid\n\n    def get_transformation_matrix(self) -> torch.Tensor | None:\n        \"\"\"Get the most recently applied transformation matrix\"\"\"\n        return self.affine\n\n\nclass RandDeformGrid(Randomizable, Transform):\n    \"\"\"\n    Generate random deformation grid.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        
self, spacing: Sequence[float] | float, magnitude_range: tuple[float, float], device: torch.device | None = None\n    ) -> None:\n        \"\"\"\n        Args:\n            spacing: spacing of the grid in 2D or 3D.\n                e.g., spacing=(1, 1) indicates pixel-wise deformation in 2D,\n                spacing=(1, 1, 1) indicates voxel-wise deformation in 3D,\n                spacing=(2, 2) indicates deformation field defined on every other pixel in 2D.\n            magnitude_range: the random offsets will be generated from\n                `uniform[magnitude[0], magnitude[1])`.\n            device: device to store the output grid data.\n        \"\"\"\n        self.spacing = spacing\n        self.magnitude = magnitude_range\n\n        self.rand_mag = 1.0\n        self.random_offset: np.ndarray\n        self.device = device\n\n    def randomize(self, grid_size: Sequence[int]) -> None:\n        self.random_offset = self.R.normal(size=([len(grid_size)] + list(grid_size))).astype(np.float32, copy=False)\n        self.rand_mag = self.R.uniform(self.magnitude[0], self.magnitude[1])\n\n    def __call__(self, spatial_size: Sequence[int]) -> torch.Tensor:\n        \"\"\"\n        Args:\n            spatial_size: spatial size of the grid.\n        \"\"\"\n        self.spacing = fall_back_tuple(self.spacing, (1.0,) * len(spatial_size))\n        control_grid = create_control_grid(spatial_size, self.spacing, device=self.device, backend=\"torch\")\n        self.randomize(control_grid.shape[1:])\n        _offset, *_ = convert_to_dst_type(self.rand_mag * self.random_offset, control_grid)\n        control_grid[: len(spatial_size)] += _offset\n        return control_grid  # type: ignore\n\n\nclass Resample(Transform):\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        mode: str | int = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.BORDER,\n        norm_coords: bool = True,\n        device: 
torch.device | None = None,\n        align_corners: bool = False,\n        dtype: DtypeLike = np.float64,\n    ) -> None:\n        \"\"\"\n        computes output image using values from `img`, locations from `grid` using pytorch.\n        supports spatially 2D or 3D (num_channels, H, W[, D]).\n\n        Args:\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `USE_COMPILED` is `True`, this argument uses\n                ``\"nearest\"``, ``\"bilinear\"``, ``\"bicubic\"`` to indicate 0, 1, 3 order interpolations.\n                See also: https://monai.readthedocs.io/en/stable/networks.html#grid-pull (experimental).\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``\"border\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `USE_COMPILED` is `True`, this argument uses an integer to represent the padding mode.\n                See also: https://monai.readthedocs.io/en/stable/networks.html#grid-pull (experimental).\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            norm_coords: whether to normalize the coordinates from `[-(size-1)/2, (size-1)/2]` to\n                `[0, size - 1]` (for ``monai/csrc`` implementation) or\n                `[-1, 1]` (for torch ``grid_sample`` implementation) to be compatible with the underlying\n                resampling API.\n            device: device on which the tensor will be allocated.\n            align_corners: Defaults to False.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            dtype: data type for resampling computation. Defaults to ``float64`` for best precision.\n                If ``None``, use the data type of input data. 
To be compatible with other modules,\n                the output data type is always `float32`.\n\n        \"\"\"\n        self.mode = mode\n        self.padding_mode = padding_mode\n        self.norm_coords = norm_coords\n        self.device = device\n        self.align_corners = align_corners\n        self.dtype = dtype\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        grid: torch.Tensor | None = None,\n        mode: str | int | None = None,\n        padding_mode: str | None = None,\n        dtype: DtypeLike = None,\n        align_corners: bool | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: shape must be (num_channels, H, W[, D]).\n            grid: shape must be (3, H, W) for 2D or (4, H, W, D) for 3D.\n                if ``norm_coords`` is True, the grid values must be in `[-(size-1)/2, (size-1)/2]`.\n                if ``USE_COMPILED=True`` and ``norm_coords=False``, grid values must be in `[0, size-1]`.\n                if ``USE_COMPILED=False`` and ``norm_coords=False``, grid values must be in `[-1, 1]`.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. 
Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `USE_COMPILED` is `True`, this argument uses\n                ``\"nearest\"``, ``\"bilinear\"``, ``\"bicubic\"`` to indicate 0, 1, 3 order interpolations.\n                See also: https://monai.readthedocs.io/en/stable/networks.html#grid-pull (experimental).\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``self.padding_mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `USE_COMPILED` is `True`, this argument uses an integer to represent the padding mode.\n                See also: https://monai.readthedocs.io/en/stable/networks.html#grid-pull (experimental).\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            dtype: data type for resampling computation. 
Defaults to ``self.dtype``.\n                To be compatible with other modules, the output data type is always `float32`.\n            align_corners: Defaults to ``self.align_corners``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n\n        See also:\n            :py:const:`monai.config.USE_COMPILED`\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if grid is None:\n            return img\n\n        _device = img.device if isinstance(img, torch.Tensor) else self.device\n        _dtype = dtype or self.dtype or img.dtype\n        _align_corners = self.align_corners if align_corners is None else align_corners\n        img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype, device=_device)\n        sr = min(len(img_t.peek_pending_shape() if isinstance(img_t, MetaTensor) else img_t.shape[1:]), 3)\n        backend, _interp_mode, _padding_mode, _ = resolves_modes(\n            self.mode if mode is None else mode,\n            self.padding_mode if padding_mode is None else padding_mode,\n            backend=None,\n            use_compiled=USE_COMPILED,\n        )\n\n        if USE_COMPILED or backend == TransformBackends.NUMPY:\n            grid_t, *_ = convert_to_dst_type(grid[:sr], img_t, dtype=grid.dtype, wrap_sequence=True)\n            if isinstance(grid, torch.Tensor) and grid_t.data_ptr() == grid.data_ptr():\n                grid_t = grid_t.clone(memory_format=torch.contiguous_format)\n            for i, dim in enumerate(img_t.shape[1 : 1 + sr]):\n                _dim = max(2, dim)\n                t = (_dim - 1) / 2.0\n                if self.norm_coords:\n                    grid_t[i] = ((_dim - 1) / _dim) * grid_t[i] + t if _align_corners else grid_t[i] + t\n                elif _align_corners:\n                    grid_t[i] = ((_dim - 1) / _dim) * (grid_t[i] + 0.5)\n            if USE_COMPILED and backend == TransformBackends.TORCH:  # compiled is 
using torch backend param name\n                grid_t = moveaxis(grid_t, 0, -1)  # type: ignore\n                out = grid_pull(\n                    img_t.unsqueeze(0),\n                    grid_t.unsqueeze(0).to(img_t),\n                    bound=_padding_mode,\n                    extrapolate=True,\n                    interpolation=_interp_mode,\n                )[0]\n            elif backend == TransformBackends.NUMPY:\n                is_cuda = img_t.is_cuda\n                img_np = (convert_to_cupy if is_cuda else convert_to_numpy)(img_t, wrap_sequence=True)\n                grid_np, *_ = convert_to_dst_type(grid_t, img_np, dtype=grid_t.dtype, wrap_sequence=True)\n                _map_coord = (cupy_ndi if is_cuda else np_ndi).map_coordinates\n                out = (cupy if is_cuda else np).stack(\n                    [_map_coord(c, grid_np, order=_interp_mode, mode=_padding_mode) for c in img_np]\n                )\n                out = convert_to_dst_type(out, img_t)[0]\n        else:\n            grid_t = moveaxis(grid[list(range(sr - 1, -1, -1))], 0, -1)  # type: ignore\n            grid_t = convert_to_dst_type(grid_t, img_t, wrap_sequence=True)[0].unsqueeze(0)\n            if isinstance(grid, torch.Tensor) and grid_t.data_ptr() == grid.data_ptr():\n                grid_t = grid_t.clone(memory_format=torch.contiguous_format)\n            if self.norm_coords:\n                for i, dim in enumerate(img_t.shape[sr + 1 : 0 : -1]):\n                    grid_t[0, ..., i] *= 2.0 / max(2, dim)\n            out = torch.nn.functional.grid_sample(\n                img_t.unsqueeze(0),\n                grid_t,\n                mode=_interp_mode,\n                padding_mode=_padding_mode,\n                align_corners=None if _align_corners == TraceKeys.NONE else _align_corners,  # type: ignore\n            )[0]\n        out_val, *_ = convert_to_dst_type(out, dst=img, dtype=np.float32)\n        return out_val\n\n\nclass Affine(InvertibleTransform, 
LazyTransform):\n    \"\"\"\n    Transform ``img`` given the affine parameters.\n    A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = list(set(AffineGrid.backend) & set(Resample.backend))\n\n    def __init__(\n        self,\n        rotate_params: Sequence[float] | float | None = None,\n        shear_params: Sequence[float] | float | None = None,\n        translate_params: Sequence[float] | float | None = None,\n        scale_params: Sequence[float] | float | None = None,\n        affine: NdarrayOrTensor | None = None,\n        spatial_size: Sequence[int] | int | None = None,\n        mode: str | int = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.REFLECTION,\n        normalized: bool = False,\n        device: torch.device | None = None,\n        dtype: DtypeLike = np.float32,\n        align_corners: bool = False,\n        image_only: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        The affine transformations are applied in rotate, shear, translate, scale order.\n\n        Args:\n            rotate_params: a rotation angle in radians, a scalar for 2D image, a tuple of 3 floats for 3D.\n                Defaults to no rotation.\n            shear_params: shearing factors for affine matrix, take a 3D affine as example::\n\n                [\n                    [1.0, params[0], params[1], 0.0],\n                    [params[2], 1.0, params[3], 0.0],\n                    [params[4], params[5], 1.0, 0.0],\n                    [0.0, 0.0, 0.0, 1.0],\n                ]\n\n                a tuple of 2 floats for 2D, a tuple of 6 floats for 3D. Defaults to no shearing.\n            translate_params: a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. 
Translation is in\n                pixel/voxel relative to the center of the input image. Defaults to no translation.\n            scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D,\n                a tuple of 3 floats for 3D. Defaults to `1.0`.\n            affine: If applied, ignore the params (`rotate_params`, etc.) and use the\n                supplied matrix. Should be square with each side = num of image spatial\n                dimensions + 1.\n            spatial_size: output image spatial size.\n                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,\n                the transform will use the spatial size of `img`.\n                if some components of the `spatial_size` are non-positive values, the transform will use the\n                corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted\n                to `(32, 64)` if the second spatial dimension size of img is `64`.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``\"reflection\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            normalized: indicating whether the provided `affine` is defined to include a normalization\n                transform converting the coordinates from `[-(size-1)/2, (size-1)/2]` (defined in ``create_grid``) to\n                `[0, size - 1]` or `[-1, 1]` in order to be compatible with the underlying resampling API.\n                If `normalized=False`, additional coordinate normalization will be applied before resampling.\n                See also: :py:func:`monai.networks.utils.normalize_transform`.\n            device: device on which the tensor will be allocated.\n            dtype: data type for resampling computation. Defaults to ``float32``.\n                If ``None``, use the data type of input data. 
To be compatible with other modules,\n                the output data type is always `float32`.\n            align_corners: Defaults to False.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            image_only: if True return only the image volume, otherwise return (image, affine).\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n        \"\"\"\n        LazyTransform.__init__(self, lazy=lazy)\n        self.affine_grid = AffineGrid(\n            rotate_params=rotate_params,\n            shear_params=shear_params,\n            translate_params=translate_params,\n            scale_params=scale_params,\n            affine=affine,\n            dtype=dtype,\n            align_corners=align_corners,\n            device=device,\n            lazy=lazy,\n        )\n        self.image_only = image_only\n        self.norm_coord = not normalized\n        self.resampler = Resample(norm_coords=self.norm_coord, device=device, dtype=dtype, align_corners=align_corners)\n        self.spatial_size = spatial_size\n        self.mode = mode\n        self.padding_mode: str = padding_mode\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool) -> None:\n        self.affine_grid.lazy = val\n        self._lazy = val\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        spatial_size: Sequence[int] | int | None = None,\n        mode: str | int | None = None,\n        padding_mode: str | None = None,\n        lazy: bool | None = None,\n    ) -> torch.Tensor | tuple[torch.Tensor, NdarrayOrTensor]:\n        \"\"\"\n        Args:\n            img: shape must be (num_channels, H, W[, D]),\n            spatial_size: output image spatial size.\n                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,\n                the transform will use the spatial size of `img`.\n                
if `img` has two spatial dimensions, `spatial_size` should have 2 elements [h, w].\n                if `img` has three spatial dimensions, `spatial_size` should have 3 elements [h, w, d].\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``self.padding_mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n        sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img_size)\n        lazy_ = self.lazy if lazy is None else lazy\n        _mode = mode if mode is not None else self.mode\n        _padding_mode = padding_mode if padding_mode is not None else self.padding_mode\n        grid, affine = self.affine_grid(spatial_size=sp_size, lazy=lazy_)\n\n        return affine_func(  # type: ignore\n            img,\n            affine,\n            grid,\n            self.resampler,\n            sp_size,\n            _mode,\n            _padding_mode,\n            True,\n            self.image_only,\n            lazy=lazy_,\n            transform_info=self.get_transform_info(),\n        )\n\n    @classmethod\n    def compute_w_affine(cls, spatial_rank, mat, img_size, sp_size):\n        r = int(spatial_rank)\n        mat = to_affine_nd(r, mat)\n        shift_1 = create_translate(r, [float(d - 1) / 2 for d in img_size[:r]])\n        shift_2 = create_translate(r, [-float(d - 1) / 2 for d in sp_size[:r]])\n        mat = shift_1 @ convert_data_type(mat, np.ndarray)[0] @ shift_2\n        return mat\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        transform = self.pop_transform(data)\n        orig_size = transform[TraceKeys.ORIG_SIZE]\n        # Create inverse transform\n        fwd_affine = transform[TraceKeys.EXTRA_INFO][\"affine\"]\n        mode = transform[TraceKeys.EXTRA_INFO][\"mode\"]\n        padding_mode = transform[TraceKeys.EXTRA_INFO][\"padding_mode\"]\n        align_corners = transform[TraceKeys.EXTRA_INFO][\"align_corners\"]\n        inv_affine = linalg_inv(convert_to_numpy(fwd_affine))\n        inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0]\n\n        affine_grid = 
AffineGrid(affine=inv_affine, align_corners=align_corners)\n        grid, _ = affine_grid(orig_size)\n        # Apply inverse transform\n        out = self.resampler(data, grid, mode, padding_mode, align_corners=align_corners)\n        if not isinstance(out, MetaTensor):\n            out = MetaTensor(out)\n        out.meta = data.meta  # type: ignore\n        affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0]\n        xform, *_ = convert_to_dst_type(\n            Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine\n        )\n        out.affine @= xform\n        return out\n\n\nclass RandAffine(RandomizableTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Random affine transform.\n    A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = Affine.backend\n\n    def __init__(\n        self,\n        prob: float = 0.1,\n        rotate_range: RandRange = None,\n        shear_range: RandRange = None,\n        translate_range: RandRange = None,\n        scale_range: RandRange = None,\n        spatial_size: Sequence[int] | int | None = None,\n        mode: str | int = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.REFLECTION,\n        cache_grid: bool = False,\n        device: torch.device | None = None,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            prob: probability of returning a randomized affine grid.\n                defaults to 0.1, with 10% chance returns a randomized grid.\n            rotate_range: angle range in radians. 
If element `i` is a pair of (min, max) values, then\n                `uniform[rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter\n                for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used.\n                This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be\n                in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]`\n                for dim0 and nothing for the remaining dimensions.\n            shear_range: shear range with format matching `rotate_range`, it defines the range to randomly select\n                shearing factors(a tuple of 2 floats for 2D, a tuple of 6 floats for 3D) for affine matrix,\n                take a 3D affine as example::\n\n                    [\n                        [1.0, params[0], params[1], 0.0],\n                        [params[2], 1.0, params[3], 0.0],\n                        [params[4], params[5], 1.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0],\n                    ]\n\n            translate_range: translate range with format matching `rotate_range`, it defines the range to randomly\n                select pixel/voxel to translate for every spatial dims.\n            scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select\n                the scale factor to apply for every spatial dims. 
A value of 1.0 is added to the result.\n                This allows 0 to correspond to no change (i.e., a scaling of 1.0).\n            spatial_size: output image spatial size.\n                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,\n                the transform will use the spatial size of `img`.\n                if some components of the `spatial_size` are non-positive values, the transform will use the\n                corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted\n                to `(32, 64)` if the second spatial dimension size of img is `64`.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``bilinear``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``reflection``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            cache_grid: whether to cache the identity sampling grid.\n                If the spatial size is not dynamically defined by input image, enabling this option could\n                accelerate the transform.\n            device: device on which the tensor will be allocated.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n\n        See also:\n            - :py:class:`RandAffineGrid` for the random affine parameters configurations.\n            - :py:class:`Affine` for the affine transformation parameters configurations.\n\n        Note:\n            The affine transformations in MONAI use a 'backward mapping' (image-to-grid) logic.\n            This can be counter-intuitive:\n            - Translation: A positive value shifts the image in the negative direction.\n            - Scaling: Positive scale_range values decrease the image size; values in [-1, 0) increase it.\n            - Rotation: The direction (CW/CCW) may vary depending on the axis.\n\n        \"\"\"\n        RandomizableTransform.__init__(self, prob)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.rand_affine_grid = RandAffineGrid(\n            rotate_range=rotate_range,\n            shear_range=shear_range,\n            translate_range=translate_range,\n            scale_range=scale_range,\n            device=device,\n            lazy=lazy,\n        )\n        self.resampler = Resample(device=device)\n\n        self.spatial_size = spatial_size\n   
     self.cache_grid = cache_grid\n        self._cached_grid = self._init_identity_cache(lazy)\n        self.mode = mode\n        self.padding_mode: str = padding_mode\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool) -> None:\n        self._lazy = val\n        self.rand_affine_grid.lazy = val\n\n    def _init_identity_cache(self, lazy: bool):\n        \"\"\"\n        Create cache of the identity grid if cache_grid=True and spatial_size is known.\n        \"\"\"\n        if lazy:\n            return None\n        if self.spatial_size is None:\n            if self.cache_grid:\n                warnings.warn(\n                    \"cache_grid=True is not compatible with the dynamic spatial_size, please specify 'spatial_size'.\"\n                )\n            return None\n        _sp_size = ensure_tuple(self.spatial_size)\n        _ndim = len(_sp_size)\n        if _sp_size != fall_back_tuple(_sp_size, [1] * _ndim) or _sp_size != fall_back_tuple(_sp_size, [2] * _ndim):\n            # dynamic shape because it falls back to different outcomes\n            if self.cache_grid:\n                warnings.warn(\n                    \"cache_grid=True is not compatible with the dynamic spatial_size \"\n                    f\"'spatial_size={self.spatial_size}', please specify 'spatial_size'.\"\n                )\n            return None\n        return create_grid(spatial_size=_sp_size, device=self.rand_affine_grid.device, backend=\"torch\")\n\n    def get_identity_grid(self, spatial_size: Sequence[int], lazy: bool):\n        \"\"\"\n        Return a cached or new identity grid depends on the availability.\n\n        Args:\n            spatial_size: non-dynamic spatial size\n        \"\"\"\n        if lazy:\n            return None\n        ndim = len(spatial_size)\n        if spatial_size != fall_back_tuple(spatial_size, [1] * ndim) or spatial_size != fall_back_tuple(\n            spatial_size, [2] * ndim\n        ):\n            raise 
RuntimeError(f\"spatial_size should not be dynamic, got {spatial_size}.\")\n        return (\n            create_grid(spatial_size=spatial_size, device=self.rand_affine_grid.device, backend=\"torch\")\n            if self._cached_grid is None\n            else self._cached_grid\n        )\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandAffine:\n        self.rand_affine_grid.set_random_state(seed, state)\n        super().set_random_state(seed, state)\n        return self\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        self.rand_affine_grid.randomize()\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        spatial_size: Sequence[int] | int | None = None,\n        mode: str | int | None = None,\n        padding_mode: str | None = None,\n        randomize: bool = True,\n        grid=None,\n        lazy: bool | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: shape must be (num_channels, H, W[, D]),\n            spatial_size: output image spatial size.\n                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,\n                the transform will use the spatial size of `img`.\n                if `img` has two spatial dimensions, `spatial_size` should have 2 elements [h, w].\n                if `img` has three spatial dimensions, `spatial_size` should have 3 elements [h, w, d].\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. 
Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``self.padding_mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            randomize: whether to execute `randomize()` function first, default to True.\n            grid: precomputed grid to be used (mainly to accelerate `RandAffined`).\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n        \"\"\"\n        if randomize:\n            self.randomize()\n        # if not doing transform and spatial size doesn't change, nothing to do\n        # except convert to float and device\n        ori_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n        sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, ori_size)\n        do_resampling = self._do_transform or (sp_size != ensure_tuple(ori_size))\n        _mode = mode if mode is not None else self.mode\n        _padding_mode = padding_mode if padding_mode is not None else self.padding_mode\n        lazy_ = self.lazy if lazy is None else lazy\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if lazy_:\n            if self._do_transform:\n                if grid is None:\n                    self.rand_affine_grid(sp_size, randomize=randomize, lazy=True)\n                affine = self.rand_affine_grid.get_transformation_matrix()\n            else:\n                affine = convert_to_dst_type(torch.eye(len(sp_size) + 1), img, dtype=self.rand_affine_grid.dtype)[0]\n        else:\n            if grid is None:\n                grid = self.get_identity_grid(sp_size, lazy_)\n                if self._do_transform:\n                    grid = self.rand_affine_grid(grid=grid, randomize=randomize, lazy=lazy_)\n            affine = self.rand_affine_grid.get_transformation_matrix()\n        return affine_func(  # type: ignore\n            img,\n            affine,\n            grid,\n            self.resampler,\n            sp_size,\n            _mode,\n            _padding_mode,\n            do_resampling,\n            True,\n            lazy=lazy_,\n            transform_info=self.get_transform_info(),\n        )\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        transform = self.pop_transform(data)\n        # if transform was not performed nothing to do.\n        if not 
transform[TraceKeys.EXTRA_INFO][\"do_resampling\"]:\n            return data\n        orig_size = transform[TraceKeys.ORIG_SIZE]\n        orig_size = fall_back_tuple(orig_size, data.shape[1:])\n        # Create inverse transform\n        fwd_affine = transform[TraceKeys.EXTRA_INFO][\"affine\"]\n        mode = transform[TraceKeys.EXTRA_INFO][\"mode\"]\n        padding_mode = transform[TraceKeys.EXTRA_INFO][\"padding_mode\"]\n        inv_affine = linalg_inv(convert_to_numpy(fwd_affine))\n        inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0]\n        affine_grid = AffineGrid(affine=inv_affine)\n        grid, _ = affine_grid(orig_size)\n\n        # Apply inverse transform\n        out = self.resampler(data, grid, mode, padding_mode)\n        if not isinstance(out, MetaTensor):\n            out = MetaTensor(out)\n        out.meta = data.meta  # type: ignore\n        affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0]\n        xform, *_ = convert_to_dst_type(\n            Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine\n        )\n        out.affine @= xform\n        return out\n\n\nclass Rand2DElastic(RandomizableTransform):\n    \"\"\"\n    Random elastic deformation and affine in 2D.\n    A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb.\n\n    \"\"\"\n\n    backend = Resample.backend\n\n    def __init__(\n        self,\n        spacing: tuple[float, float] | float,\n        magnitude_range: tuple[float, float],\n        prob: float = 0.1,\n        rotate_range: RandRange = None,\n        shear_range: RandRange = None,\n        translate_range: RandRange = None,\n        scale_range: RandRange = None,\n        spatial_size: tuple[int, int] | int | None = None,\n        mode: str | int = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.REFLECTION,\n        device: torch.device | None = None,\n   
 ) -> None:\n        \"\"\"\n        Args:\n            spacing : distance in between the control points.\n            magnitude_range: the random offsets will be generated from ``uniform[magnitude[0], magnitude[1])``.\n            prob: probability of returning a randomized elastic transform.\n                defaults to 0.1, with 10% chance returns a randomized elastic transform,\n                otherwise returns a ``spatial_size`` centered area extracted from the input image.\n            rotate_range: angle range in radians. If element `i` is a pair of (min, max) values, then\n                `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter\n                for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used.\n                This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be\n                in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]`\n                for dim0 and nothing for the remaining dimensions.\n            shear_range: shear range with format matching `rotate_range`, it defines the range to randomly select\n                shearing factors(a tuple of 2 floats for 2D) for affine matrix, take a 2D affine as example::\n\n                    [\n                        [1.0, params[0], 0.0],\n                        [params[1], 1.0, 0.0],\n                        [0.0, 0.0, 1.0],\n                    ]\n\n            translate_range: translate range with format matching `rotate_range`, it defines the range to randomly\n                select pixel to translate for every spatial dims.\n            scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select\n                the scale factor to translate for every spatial dims. 
A value of 1.0 is added to the result.\n                This allows 0 to correspond to no change (i.e., a scaling of 1.0).\n            spatial_size: specifying output image spatial size [h, w].\n                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,\n                the transform will use the spatial size of `img`.\n                if some components of the `spatial_size` are non-positive values, the transform will use the\n                corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted\n                to `(32, 64)` if the second spatial dimension size of img is `64`.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``\"reflection\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            device: device on which the tensor will be allocated.\n\n        See also:\n            - :py:class:`RandAffineGrid` for the random affine parameters configurations.\n            - :py:class:`Affine` for the affine transformation parameters configurations.\n\n        \"\"\"\n        RandomizableTransform.__init__(self, prob)\n        self.deform_grid = RandDeformGrid(spacing=spacing, magnitude_range=magnitude_range, device=device)\n        self.rand_affine_grid = RandAffineGrid(\n            rotate_range=rotate_range,\n            shear_range=shear_range,\n            translate_range=translate_range,\n            scale_range=scale_range,\n            device=device,\n            lazy=False,\n        )\n        self.resampler = Resample(device=device)\n\n        self.device = device\n        self.spatial_size = spatial_size\n        self.mode = mode\n        self.padding_mode: str = padding_mode\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Rand2DElastic:\n        self.deform_grid.set_random_state(seed, state)\n        self.rand_affine_grid.set_random_state(seed, state)\n        super().set_random_state(seed, state)\n        return self\n\n    def set_device(self, device):\n        self.deform_grid.device = device\n        self.rand_affine_grid.device = device\n        self.resampler.device = device\n        self.device = device\n\n    def randomize(self, spatial_size: Sequence[int]) -> None:\n        super().randomize(None)\n        if not 
self._do_transform:\n            return None\n        self.deform_grid.randomize(spatial_size)\n        self.rand_affine_grid.randomize()\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        spatial_size: tuple[int, int] | int | None = None,\n        mode: str | int | None = None,\n        padding_mode: str | None = None,\n        randomize: bool = True,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: shape must be (num_channels, H, W),\n            spatial_size: specifying output image spatial size [h, w].\n                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,\n                the transform will use the spatial size of `img`.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``self.padding_mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            randomize: whether to execute `randomize()` function first, default to True.\n        \"\"\"\n        sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img.shape[1:])\n        if randomize:\n            self.randomize(spatial_size=sp_size)\n\n        if self._do_transform:\n            grid = self.deform_grid(spatial_size=sp_size)\n            grid = self.rand_affine_grid(grid=grid)\n            grid = torch.nn.functional.interpolate(\n                recompute_scale_factor=True,\n                input=grid.unsqueeze(0),\n                scale_factor=list(ensure_tuple(self.deform_grid.spacing)),\n                mode=InterpolateMode.BICUBIC.value,\n                align_corners=False,\n            )\n            grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])\n        else:\n            _device = img.device if isinstance(img, torch.Tensor) else self.device\n            grid = cast(torch.Tensor, create_grid(spatial_size=sp_size, device=_device, backend=\"torch\"))\n        out: torch.Tensor = self.resampler(\n            img,\n            grid,\n            mode=mode if mode is not None else self.mode,\n            padding_mode=padding_mode if padding_mode is not None else self.padding_mode,\n        )\n        return out\n\n\nclass Rand3DElastic(RandomizableTransform):\n    \"\"\"\n    Random elastic deformation and affine in 3D.\n    A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb.\n\n    
\"\"\"\n\n    backend = Resample.backend\n\n    def __init__(\n        self,\n        sigma_range: tuple[float, float],\n        magnitude_range: tuple[float, float],\n        prob: float = 0.1,\n        rotate_range: RandRange = None,\n        shear_range: RandRange = None,\n        translate_range: RandRange = None,\n        scale_range: RandRange = None,\n        spatial_size: tuple[int, int, int] | int | None = None,\n        mode: str | int = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.REFLECTION,\n        device: torch.device | None = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            sigma_range: a Gaussian kernel with standard deviation sampled from\n                ``uniform[sigma_range[0], sigma_range[1])`` will be used to smooth the random offset grid.\n            magnitude_range: the random offsets on the grid will be generated from\n                ``uniform[magnitude[0], magnitude[1])``.\n            prob: probability of returning a randomized elastic transform.\n                defaults to 0.1, with 10% chance returns a randomized elastic transform,\n                otherwise returns a ``spatial_size`` centered area extracted from the input image.\n            rotate_range: angle range in radians. If element `i` is a pair of (min, max) values, then\n                `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter\n                for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used.\n                This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be\n                in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. 
Setting a single value will use `[-x, x]`\n                for dim0 and nothing for the remaining dimensions.\n            shear_range: shear range with format matching `rotate_range`, it defines the range to randomly select\n                shearing factors(a tuple of 6 floats for 3D) for affine matrix, take a 3D affine as example::\n\n                    [\n                        [1.0, params[0], params[1], 0.0],\n                        [params[2], 1.0, params[3], 0.0],\n                        [params[4], params[5], 1.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0],\n                    ]\n\n            translate_range: translate range with format matching `rotate_range`, it defines the range to randomly\n                select voxel to translate for every spatial dims.\n            scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select\n                the scale factor to translate for every spatial dims. A value of 1.0 is added to the result.\n                This allows 0 to correspond to no change (i.e., a scaling of 1.0).\n            spatial_size: specifying output image spatial size [h, w, d].\n                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,\n                the transform will use the spatial size of `img`.\n                if some components of the `spatial_size` are non-positive values, the transform will use the\n                corresponding components of img size. For example, `spatial_size=(32, 32, -1)` will be adapted\n                to `(32, 32, 64)` if the third spatial dimension size of img is `64`.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. 
Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``\"reflection\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            device: device on which the tensor will be allocated.\n\n        See also:\n            - :py:class:`RandAffineGrid` for the random affine parameters configurations.\n            - :py:class:`Affine` for the affine transformation parameters configurations.\n\n        \"\"\"\n        RandomizableTransform.__init__(self, prob)\n        self.rand_affine_grid = RandAffineGrid(\n            rotate_range=rotate_range,\n            shear_range=shear_range,\n            translate_range=translate_range,\n            scale_range=scale_range,\n            device=device,\n            lazy=False,\n        )\n        self.resampler = Resample(device=device)\n\n        self.sigma_range = sigma_range\n        self.magnitude_range = magnitude_range\n        self.spatial_size = spatial_size\n        self.mode = mode\n        self.padding_mode: str = padding_mode\n        self.device = device\n\n        self.rand_offset: np.ndarray\n        self.magnitude = 
1.0\n        self.sigma = 1.0\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Rand3DElastic:\n        self.rand_affine_grid.set_random_state(seed, state)\n        super().set_random_state(seed, state)\n        return self\n\n    def set_device(self, device):\n        self.rand_affine_grid.device = device\n        self.resampler.device = device\n        self.device = device\n\n    def randomize(self, grid_size: Sequence[int]) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return None\n        self.rand_offset = self.R.uniform(-1.0, 1.0, [3] + list(grid_size)).astype(np.float32, copy=False)\n        self.magnitude = self.R.uniform(self.magnitude_range[0], self.magnitude_range[1])\n        self.sigma = self.R.uniform(self.sigma_range[0], self.sigma_range[1])\n        self.rand_affine_grid.randomize()\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        spatial_size: tuple[int, int, int] | int | None = None,\n        mode: str | int | None = None,\n        padding_mode: str | None = None,\n        randomize: bool = True,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: shape must be (num_channels, H, W, D),\n            spatial_size: specifying spatial 3D output image spatial size [h, w, d].\n                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,\n                the transform will use the spatial size of `img`.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. 
Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``self.padding_mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            randomize: whether to execute `randomize()` function first, default to True.\n        \"\"\"\n        sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img.shape[1:])\n        if randomize:\n            self.randomize(grid_size=sp_size)\n\n        _device = img.device if isinstance(img, torch.Tensor) else self.device\n        grid = create_grid(spatial_size=sp_size, device=_device, backend=\"torch\")\n        if self._do_transform:\n            if self.rand_offset is None:\n                raise RuntimeError(\"rand_offset is not initialized.\")\n            gaussian = GaussianFilter(3, self.sigma, 3.0).to(device=_device)\n            offset = torch.as_tensor(self.rand_offset, device=_device).unsqueeze(0)\n            grid[:3] += gaussian(offset)[0] * self.magnitude\n            grid = self.rand_affine_grid(grid=grid)\n        out: torch.Tensor = self.resampler(\n            img,\n            grid,  # 
type: ignore\n            mode=mode if mode is not None else self.mode,\n            padding_mode=padding_mode if padding_mode is not None else self.padding_mode,\n        )\n        return out\n\n\nclass GridDistortion(Transform):\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        num_cells: tuple[int] | int,\n        distort_steps: Sequence[Sequence[float]],\n        mode: str | int = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.BORDER,\n        device: torch.device | None = None,\n    ) -> None:\n        \"\"\"\n        Grid distortion transform. Refer to:\n        https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/transforms.py\n\n        Args:\n            num_cells: number of grid cells on each dimension.\n            distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the\n                corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`.\n                Each value in the tuple represents the distort step of the related cell.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``\"border\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            device: device on which the tensor will be allocated.\n\n        \"\"\"\n        self.resampler = Resample(mode=mode, padding_mode=padding_mode, device=device)\n        self.num_cells = num_cells\n        self.distort_steps = distort_steps\n        self.device = device\n\n    def __call__(\n        self,\n        img: torch.Tensor,\n        distort_steps: Sequence[Sequence] | None = None,\n        mode: str | None = None,\n        padding_mode: str | None = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: shape must be (num_channels, H, W[, D]).\n            distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the\n                corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`.\n                Each value in the tuple represents the distort step of the related cell.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. 
Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``self.padding_mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n\n        \"\"\"\n        distort_steps = self.distort_steps if distort_steps is None else distort_steps\n        if len(img.shape) != len(distort_steps) + 1:\n            raise ValueError(\"the spatial size of `img` does not match with the length of `distort_steps`\")\n\n        all_ranges = []\n        num_cells = ensure_tuple_rep(self.num_cells, len(img.shape) - 1)\n        if isinstance(img, MetaTensor) and img.pending_operations:\n            warnings.warn(\"MetaTensor img has pending operations, transform may return incorrect results.\")\n        for dim_idx, dim_size in enumerate(img.shape[1:]):\n            dim_distort_steps = distort_steps[dim_idx]\n            ranges = torch.zeros(dim_size, dtype=torch.float32)\n            cell_size = dim_size // num_cells[dim_idx]\n            prev = 0\n            for idx in range(num_cells[dim_idx] + 1):\n                start = int(idx * cell_size)\n                end = start + cell_size\n                
if end > dim_size:\n                    end = dim_size\n                    cur = dim_size\n                else:\n                    cur = prev + cell_size * dim_distort_steps[idx]\n                ranges[start:end] = torch.linspace(prev, cur, end - start)\n                prev = cur\n            ranges = ranges - (dim_size - 1.0) / 2.0\n            all_ranges.append(ranges)\n\n        coords = meshgrid_ij(*all_ranges)\n        grid = torch.stack([*coords, torch.ones_like(coords[0])])\n\n        return self.resampler(img, grid=grid, mode=mode, padding_mode=padding_mode)\n\n\nclass RandGridDistortion(RandomizableTransform):\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        num_cells: tuple[int] | int = 5,\n        prob: float = 0.1,\n        distort_limit: tuple[float, float] | float = (-0.03, 0.03),\n        mode: str | int = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.BORDER,\n        device: torch.device | None = None,\n    ) -> None:\n        \"\"\"\n        Random grid distortion transform. Refer to:\n        https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/transforms.py\n\n        Args:\n            num_cells: number of grid cells on each dimension.\n            prob: probability of returning a randomized grid distortion transform. Defaults to 0.1.\n            distort_limit: range to randomly distort.\n                If single number, distort_limit is picked from (-distort_limit, distort_limit).\n                Defaults to (-0.03, 0.03).\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. 
Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``\"border\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            device: device on which the tensor will be allocated.\n\n        \"\"\"\n        RandomizableTransform.__init__(self, prob)\n        self.num_cells = num_cells\n        if isinstance(distort_limit, (int, float)):\n            self.distort_limit = (min(-distort_limit, distort_limit), max(-distort_limit, distort_limit))\n        else:\n            self.distort_limit = (min(distort_limit), max(distort_limit))\n        self.distort_steps: Sequence[Sequence[float]] = ((1.0,),)\n        self.grid_distortion = GridDistortion(\n            num_cells=num_cells, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode, device=device\n        )\n\n    def randomize(self, spatial_shape: Sequence[int]) -> None:\n        super().randomize(None)\n        if not self._do_transform:\n            return\n        self.distort_steps = tuple(\n            tuple(1.0 + self.R.uniform(low=self.distort_limit[0], high=self.distort_limit[1], size=n_cells + 
1))\n            for n_cells in ensure_tuple_rep(self.num_cells, len(spatial_shape))\n        )\n\n    def __call__(\n        self, img: torch.Tensor, mode: str | None = None, padding_mode: str | None = None, randomize: bool = True\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: shape must be (num_channels, H, W[, D]).\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``self.mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``self.padding_mode``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n            randomize: whether to shuffle the random factors using `randomize()`, default to True.\n        \"\"\"\n        if randomize:\n            if isinstance(img, MetaTensor) and img.pending_operations:\n                warnings.warn(\"MetaTensor img has pending operations, transform may return incorrect results.\")\n            self.randomize(img.shape[1:])\n        if not self._do_transform:\n            return convert_to_tensor(img, track_meta=get_track_meta())  # type: ignore\n        return self.grid_distortion(img, distort_steps=self.distort_steps, mode=mode, padding_mode=padding_mode)\n\n\nclass GridSplit(Transform, MultiSampleTrait):\n    \"\"\"\n    Split the image into patches based on the provided grid in 2D.\n\n    Args:\n        grid: a tuple define the shape of the grid upon which the image is split. 
Defaults to (2, 2)\n        size: a tuple or an integer that defines the output patch sizes.\n            If it's an integer, the value will be repeated for each dimension.\n            The default is None, where the patch size will be inferred from the grid shape.\n\n    Example:\n        Given an image (torch.Tensor or numpy.ndarray) with size of (3, 10, 10) and a grid of (2, 2),\n        it will return a Tensor or array with the size of (4, 3, 5, 5).\n        Here, if the `size` is provided, the returned shape will be (4, 3, size, size)\n\n    Note: This transform currently support only image with two spatial dimensions.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, grid: tuple[int, int] = (2, 2), size: int | tuple[int, int] | None = None):\n        # Grid size\n        self.grid = grid\n\n        # Patch size\n        self.size = None if size is None else ensure_tuple_rep(size, len(self.grid))\n\n    def __call__(\n        self, image: NdarrayOrTensor, size: int | tuple[int, int] | np.ndarray | None = None\n    ) -> list[NdarrayOrTensor]:\n        input_size = self.size if size is None else ensure_tuple_rep(size, len(self.grid))\n\n        if self.grid == (1, 1) and input_size is None:\n            return [image]\n        if isinstance(image, MetaTensor) and image.pending_operations:\n            warnings.warn(\"MetaTensor img has pending operations, transform may return incorrect results.\")\n        split_size, steps = self._get_params(image.shape[1:], input_size)\n        patches: list[NdarrayOrTensor]\n        as_strided_func: Callable\n        if isinstance(image, torch.Tensor):\n            as_strided_func = torch.as_strided\n            c_stride, x_stride, y_stride = image.stride()\n        elif isinstance(image, np.ndarray):\n            as_strided_func = np.lib.stride_tricks.as_strided\n            c_stride, x_stride, y_stride = image.strides\n        else:\n            raise 
ValueError(f\"Input type [{type(image)}] is not supported.\")\n\n        x_step, y_step = steps\n        n_channels = image.shape[0]\n        strided_image = as_strided_func(\n            image,\n            (*self.grid, n_channels, split_size[0], split_size[1]),\n            (x_stride * x_step, y_stride * y_step, c_stride, x_stride, y_stride),\n        )\n        # Flatten the first two dimensions\n        strided_image = strided_image.reshape(-1, *strided_image.shape[2:])\n        # Make a list of contiguous patches\n        if isinstance(image, torch.Tensor):\n            patches = [p.contiguous() for p in strided_image]\n        elif isinstance(image, np.ndarray):\n            patches = [np.ascontiguousarray(p) for p in strided_image]\n\n        return patches\n\n    def _get_params(self, image_size: Sequence[int] | np.ndarray, size: Sequence[int] | np.ndarray | None = None):\n        \"\"\"\n        Calculate the size and step required for splitting the image\n        Args:\n            The size of the input image\n        \"\"\"\n        if size is None:\n            # infer each sub-image size from the image size and the grid\n            size = tuple(image_size[i] // self.grid[i] for i in range(len(self.grid)))\n\n        if any(size[i] > image_size[i] for i in range(len(self.grid))):\n            raise ValueError(f\"The image size ({image_size})is smaller than the requested split size ({size})\")\n\n        steps = tuple(\n            (image_size[i] - size[i]) // (self.grid[i] - 1) if self.grid[i] > 1 else image_size[i]\n            for i in range(len(self.grid))\n        )\n\n        return size, steps\n\n\nclass GridPatch(Transform, MultiSampleTrait):\n    \"\"\"\n    Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps.\n    It can sort the patches and return all or a subset of them.\n\n    Args:\n        patch_size: size of patches to generate slices for, 0 or None selects whole dimension\n       
 offset: offset of starting position in the array, default is 0 for each dimension.\n        num_patches: number of patches (or maximum number of patches) to return.\n            If the requested number of patches is greater than the number of available patches,\n            padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.\n            When `threshold` is set, this value is treated as the maximum number of patches.\n            Defaults to None, which does not limit number of the patches.\n        overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).\n            If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.\n        sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`\"max\"`),\n            lowest values (`\"min\"`), or in their default order (`None`). Default to None.\n        threshold: a value to keep only the patches whose sum of intensities are less than the threshold.\n            Defaults to no filtering.\n        pad_mode: the  mode for padding the input image by `patch_size` to include patches that cross boundaries.\n            Available modes: (Numpy) {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            (PyTorch) {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function.\n            Defaults to `None`, which means no padding will be applied.\n            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            requires pytorch >= 1.10 for best compatibility.\n        pad_kwargs: other arguments for the `np.pad` or `torch.pad` 
function.\n            note that `np.pad` treats channel dimension as the first dimension.\n    Returns:\n        MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),\n            with following metadata:\n\n            - `PatchKeys.LOCATION`: the starting location of the patch in the image,\n            - `PatchKeys.COUNT`: total number of patches in the image,\n            - \"spatial_shape\": spatial size of the extracted patch, and\n            - \"offset\": the amount of offset for the patches in the image (starting position of the first patch)\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        patch_size: Sequence[int],\n        offset: Sequence[int] | None = None,\n        num_patches: int | None = None,\n        overlap: Sequence[float] | float = 0.0,\n        sort_fn: str | None = None,\n        threshold: float | None = None,\n        pad_mode: str | None = None,\n        **pad_kwargs,\n    ):\n        self.patch_size = ensure_tuple(patch_size)\n        self.offset = ensure_tuple(offset) if offset else (0,) * len(self.patch_size)\n        self.pad_mode = pad_mode\n        self.pad_kwargs = pad_kwargs\n        self.overlap = overlap\n        self.num_patches = num_patches\n        self.sort_fn = sort_fn.lower() if sort_fn else None\n        self.threshold = threshold\n\n    def filter_threshold(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tuple[NdarrayOrTensor, np.ndarray]:\n        \"\"\"\n        Filter the patches and their locations according to a threshold.\n\n        Args:\n            image_np: a numpy.ndarray or torch.Tensor representing a stack of patches.\n            locations: a numpy.ndarray representing the stack of location of each patch.\n\n        Returns:\n            tuple[NdarrayOrTensor, numpy.ndarray]:  tuple of filtered patches and locations.\n        \"\"\"\n        n_dims = len(image_np.shape)\n   
     idx = argwhere(image_np.sum(tuple(range(1, n_dims))) < self.threshold).reshape(-1)  # type: ignore[operator]\n        idx_np = convert_data_type(idx, np.ndarray)[0]\n        return image_np[idx], locations[idx_np]\n\n    def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tuple[NdarrayOrTensor, np.ndarray]:\n        \"\"\"\n        Sort the patches based on the sum of their intensity, and just keep `self.num_patches` of them.\n\n        Args:\n            image_np: a numpy.ndarray or torch.Tensor representing a stack of patches.\n            locations: a numpy.ndarray representing the stack of location of each patch.\n        \"\"\"\n        if self.sort_fn is None:\n            image_np = image_np[: self.num_patches]\n            locations = locations[: self.num_patches]\n        elif self.num_patches is not None:\n            n_dims = len(image_np.shape)\n            if self.sort_fn == GridPatchSort.MIN:\n                idx = argsort(image_np.sum(tuple(range(1, n_dims))))\n            elif self.sort_fn == GridPatchSort.MAX:\n                idx = argsort(-image_np.sum(tuple(range(1, n_dims))))\n            else:\n                raise ValueError(f'`sort_fn` should be either \"min\", \"max\", or None! 
{self.sort_fn} provided!')\n            idx = idx[: self.num_patches]\n            idx_np = convert_data_type(idx, np.ndarray)[0]\n            image_np = image_np[idx]\n            locations = locations[idx_np]\n        return image_np, locations\n\n    def __call__(self, array: NdarrayOrTensor) -> MetaTensor:\n        \"\"\"\n        Extract the patches (sweeping the entire image in a row-major sliding-window manner with possible overlaps).\n\n        Args:\n            array: a input image as `numpy.ndarray` or `torch.Tensor`\n\n        Return:\n            MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),\n                with defined `PatchKeys.LOCATION` and `PatchKeys.COUNT` metadata.\n        \"\"\"\n        # create the patch iterator which sweeps the image row-by-row\n        patch_iterator = iter_patch(\n            array,\n            patch_size=(None,) + self.patch_size,  # expand to have the channel dim\n            start_pos=(0,) + self.offset,  # expand to have the channel dim\n            overlap=self.overlap,\n            copy_back=False,\n            mode=self.pad_mode,\n            **self.pad_kwargs,\n        )\n        patches = list(zip(*patch_iterator))\n        patched_image: NdarrayOrTensor\n        patched_image = np.stack(patches[0]) if isinstance(array, np.ndarray) else torch.stack(patches[0])\n        locations = np.stack(patches[1])[:, 1:, 0]  # only keep the starting location\n\n        # Apply threshold filtering\n        if self.threshold is not None:\n            patched_image, locations = self.filter_threshold(patched_image, locations)\n\n        # Apply count filtering\n        if self.num_patches:\n            # Limit number of patches\n            patched_image, locations = self.filter_count(patched_image, locations)\n            # Pad the patch list to have the requested number of patches\n            if self.threshold is None:\n                padding = self.num_patches - 
len(patched_image)\n                if padding > 0:\n                    # pad constant patches to the end of the first dim\n                    constant_values = self.pad_kwargs.get(\"constant_values\", 0)\n                    padding_shape = (padding, *list(patched_image.shape)[1:])\n                    constant_padding: NdarrayOrTensor\n                    if isinstance(patched_image, np.ndarray):\n                        constant_padding = np.full(padding_shape, constant_values, dtype=patched_image.dtype)\n                        patched_image = np.concatenate([patched_image, constant_padding], axis=0)\n                    else:\n                        constant_padding = torch.full(\n                            padding_shape,\n                            constant_values,\n                            dtype=patched_image.dtype,\n                            layout=patched_image.layout,\n                            device=patched_image.device,\n                        )\n                        patched_image = torch.cat([patched_image, constant_padding], dim=0)\n                    locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0)\n\n        # Convert to MetaTensor\n        metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta()\n        metadata[PatchKeys.LOCATION] = locations.T\n        metadata[PatchKeys.COUNT] = len(locations)\n        metadata[\"spatial_shape\"] = np.tile(np.array(self.patch_size), (len(locations), 1)).T\n        metadata[\"offset\"] = self.offset\n        output = MetaTensor(x=patched_image, meta=metadata)\n        output.is_batch = True\n\n        return output\n\n\nclass RandGridPatch(GridPatch, RandomizableTransform, MultiSampleTrait):\n    \"\"\"\n    Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps,\n    and with random offset for the minimal corner of the image, (0,0) for 2D and (0,0,0) for 3D.\n    It can sort the 
patches and return all or a subset of them.\n\n    Args:\n        patch_size: size of patches to generate slices for, 0 or None selects whole dimension\n        min_offset: the minimum range of offset to be selected randomly. Defaults to 0.\n        max_offset: the maximum range of offset to be selected randomly.\n            Defaults to image size modulo patch size.\n        num_patches: number of patches (or maximum number of patches) to return.\n            If the requested number of patches is greater than the number of available patches,\n            padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.\n            When `threshold` is set, this value is treated as the maximum number of patches.\n            Defaults to None, which does not limit number of the patches.\n        overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).\n            If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.\n        sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`\"max\"`),\n            lowest values (`\"min\"`), in random (\"random\"), or in their default order (`None`). 
Default to None.\n        threshold: a value to keep only the patches whose sum of intensities are less than the threshold.\n            Defaults to no filtering.\n        pad_mode: the  mode for padding the input image by `patch_size` to include patches that cross boundaries.\n            Available modes: (Numpy) {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            (PyTorch) {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function.\n            Defaults to `None`, which means no padding will be applied.\n            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            requires pytorch >= 1.10 for best compatibility.\n        pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    Returns:\n        MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),\n            with following metadata:\n\n            - `PatchKeys.LOCATION`: the starting location of the patch in the image,\n            - `PatchKeys.COUNT`: total number of patches in the image,\n            - \"spatial_shape\": spatial size of the extracted patch, and\n            - \"offset\": the amount of offset for the patches in the image (starting position of the first patch)\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        patch_size: Sequence[int],\n        min_offset: Sequence[int] | int | None = None,\n        max_offset: Sequence[int] | int | None = None,\n        num_patches: int | None = None,\n        overlap: 
Sequence[float] | float = 0.0,\n        sort_fn: str | None = None,\n        threshold: float | None = None,\n        pad_mode: str | None = None,\n        **pad_kwargs,\n    ):\n        super().__init__(\n            patch_size=patch_size,\n            offset=(),\n            num_patches=num_patches,\n            overlap=overlap,\n            sort_fn=sort_fn,\n            threshold=threshold,\n            pad_mode=pad_mode,\n            **pad_kwargs,\n        )\n        self.min_offset = min_offset\n        self.max_offset = max_offset\n        self.num_patches = num_patches\n        self.sort_fn = sort_fn\n\n    def randomize(self, array):\n        if self.min_offset is None:\n            min_offset = (0,) * len(self.patch_size)\n        else:\n            min_offset = ensure_tuple_rep(self.min_offset, len(self.patch_size))\n        if self.max_offset is None:\n            max_offset = tuple(s % p for s, p in zip(array.shape[1:], self.patch_size))\n        else:\n            max_offset = ensure_tuple_rep(self.max_offset, len(self.patch_size))\n\n        self.offset = tuple(self.R.randint(low=low, high=high + 1) for low, high in zip(min_offset, max_offset))\n\n    def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tuple[NdarrayOrTensor, np.ndarray]:\n        if self.sort_fn == GridPatchSort.RANDOM:\n            idx = self.R.permutation(image_np.shape[0])\n            idx = idx[: self.num_patches]\n            idx_np = convert_data_type(idx, np.ndarray)[0]\n            image_np = image_np[idx]  # type: ignore[index]\n            locations = locations[idx_np]\n            return image_np, locations\n        elif self.sort_fn not in (None, GridPatchSort.MIN, GridPatchSort.MAX):\n            raise ValueError(f'`sort_fn` should be either \"min\", \"max\", \"random\" or None! 
{self.sort_fn} provided!')\n        return super().filter_count(image_np, locations)\n\n    def __call__(self, array: NdarrayOrTensor, randomize: bool = True):\n        if randomize:\n            self.randomize(array)\n        return super().__call__(array)\n\n\nclass RandSimulateLowResolution(RandomizableTransform):\n    \"\"\"\n    Random simulation of low resolution corresponding to nnU-Net's SimulateLowResolutionTransform\n    (https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23)\n    First, the array/tensor is resampled at lower resolution as determined by the zoom_factor which is uniformly sampled\n    from the `zoom_range`. Then, the array/tensor is resampled at the original resolution.\n    \"\"\"\n\n    backend = Affine.backend\n\n    def __init__(\n        self,\n        prob: float = 0.1,\n        downsample_mode: InterpolateMode | str = InterpolateMode.NEAREST,\n        upsample_mode: InterpolateMode | str = InterpolateMode.TRILINEAR,\n        zoom_range: Sequence[float] = (0.5, 1.0),\n        align_corners=False,\n        device: torch.device | None = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            prob: probability of performing this augmentation\n            downsample_mode: interpolation mode for downsampling operation\n            upsample_mode: interpolation mode for upsampling operation\n            zoom_range: range from which the random zoom factor for the downsampling and upsampling operation is\n            sampled. It determines the shape of the downsampled tensor.\n            align_corners: This only has an effect when downsample_mode or upsample_mode  is 'linear', 'bilinear',\n                'bicubic' or 'trilinear'. 
Default: False\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            device: device on which the tensor will be allocated.\n\n        \"\"\"\n        RandomizableTransform.__init__(self, prob)\n\n        self.downsample_mode = downsample_mode\n        self.upsample_mode = upsample_mode\n        self.zoom_range = zoom_range\n        self.align_corners = align_corners\n        self.device = device\n        self.zoom_factor = 1.0\n\n    def randomize(self, data: Any | None = None) -> None:\n        super().randomize(None)\n        self.zoom_factor = self.R.uniform(self.zoom_range[0], self.zoom_range[1])\n        if not self._do_transform:\n            return None\n\n    def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: shape must be (num_channels, H, W[, D]),\n            randomize: whether to execute `randomize()` function first, defaults to True.\n        \"\"\"\n        if randomize:\n            self.randomize()\n\n        if self._do_transform:\n            input_shape = img.shape[1:]\n            target_shape = tuple(np.round(np.array(input_shape) * self.zoom_factor).astype(np.int_).tolist())\n\n            resize_tfm_downsample = Resize(\n                spatial_size=target_shape, size_mode=\"all\", mode=self.downsample_mode, anti_aliasing=False\n            )\n\n            resize_tfm_upsample = Resize(\n                spatial_size=input_shape,\n                size_mode=\"all\",\n                mode=self.upsample_mode,\n                anti_aliasing=False,\n                align_corners=self.align_corners,\n            )\n            # temporarily disable metadata tracking, since we do not want to invert the two Resize functions during\n            # post-processing\n            original_tack_meta_value = get_track_meta()\n            set_track_meta(False)\n\n            img_downsampled = resize_tfm_downsample(img)\n 
           img_upsampled = resize_tfm_upsample(img_downsampled)\n\n            # reset metadata tracking to original value\n            set_track_meta(original_tack_meta_value)\n\n            # copy metadata from original image to down-and-upsampled image\n            img_upsampled = MetaTensor(img_upsampled)\n            img_upsampled.copy_meta_from(img)\n\n            return img_upsampled\n\n        else:\n            return img\n\n\nclass ConvertBoxToPoints(Transform):\n    \"\"\"\n    Converts an axis-aligned bounding box to points. It can automatically convert the boxes to the points based on the box mode.\n    Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2] for 3D for each box.\n    Return shape will be (N, 4, 2) for 2D or (N, 8, 3) for 3D.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None:\n        \"\"\"\n        Args:\n            mode: the mode of the box, can be a string, a BoxMode instance or a BoxMode class. Defaults to StandardMode.\n        \"\"\"\n        super().__init__()\n        self.mode = StandardMode if mode is None else mode\n\n    def __call__(self, data: Any):\n        data = convert_to_tensor(data, track_meta=get_track_meta())\n        points = convert_box_to_points(data, mode=self.mode)\n        return convert_to_dst_type(points, data)[0]\n\n\nclass ConvertPointsToBoxes(Transform):\n    \"\"\"\n    Converts points to an axis-aligned bounding box.\n    Points representing the corners of the bounding box. 
Shape (N, 8, 3) for the 8 corners of a 3D cuboid or\n    (N, 4, 2) for the 4 corners of a 2D rectangle.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self) -> None:\n        super().__init__()\n\n    def __call__(self, data: Any):\n        data = convert_to_tensor(data, track_meta=get_track_meta())\n        box = convert_points_to_box(data)\n        return convert_to_dst_type(box, data)[0]\n"
  },
  {
    "path": "monai/transforms/spatial/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of dictionary-based wrappers around the \"vanilla\" transforms for spatial operations\ndefined in :py:class:`monai.transforms.spatial.array`.\n\nClass names are ended with 'd' to denote dictionary-based transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Hashable, Mapping, Sequence\nfrom typing import Any, cast\n\nimport numpy as np\nimport torch\n\nfrom monai.config import DtypeLike, KeysCollection, SequenceStr\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.box_utils import BoxMode, StandardMode\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import is_supported_format\nfrom monai.networks.layers.simplelayers import GaussianFilter\nfrom monai.transforms.croppad.array import CenterSpatialCrop\nfrom monai.transforms.inverse import InvertibleTransform\nfrom monai.transforms.spatial.array import (\n    Affine,\n    ConvertBoxToPoints,\n    ConvertPointsToBoxes,\n    Flip,\n    GridDistortion,\n    GridPatch,\n    GridSplit,\n    Orientation,\n    Rand2DElastic,\n    Rand3DElastic,\n    RandAffine,\n    RandAxisFlip,\n    RandGridDistortion,\n    RandGridPatch,\n    RandRotate,\n    RandSimulateLowResolution,\n    RandZoom,\n    ResampleToMatch,\n    Resize,\n    Rotate,\n    Rotate90,\n    Spacing,\n    
SpatialResample,\n    Zoom,\n)\nfrom monai.transforms.traits import MultiSampleTrait\nfrom monai.transforms.transform import LazyTransform, MapTransform, RandomizableTransform\nfrom monai.transforms.utils import create_grid\nfrom monai.utils import (\n    GridSampleMode,\n    GridSamplePadMode,\n    InterpolateMode,\n    NumpyPadMode,\n    convert_to_tensor,\n    ensure_tuple,\n    ensure_tuple_rep,\n    fall_back_tuple,\n)\nfrom monai.utils.deprecate_utils import deprecated_arg_default\nfrom monai.utils.enums import TraceKeys\nfrom monai.utils.module import optional_import\n\nnib, _ = optional_import(\"nibabel\")\n\n__all__ = [\n    \"SpatialResampled\",\n    \"ResampleToMatchd\",\n    \"Spacingd\",\n    \"Orientationd\",\n    \"Rotate90d\",\n    \"RandRotate90d\",\n    \"Resized\",\n    \"Affined\",\n    \"RandAffined\",\n    \"Rand2DElasticd\",\n    \"Rand3DElasticd\",\n    \"Flipd\",\n    \"RandFlipd\",\n    \"GridDistortiond\",\n    \"RandGridDistortiond\",\n    \"RandAxisFlipd\",\n    \"Rotated\",\n    \"RandRotated\",\n    \"Zoomd\",\n    \"RandZoomd\",\n    \"SpatialResampleD\",\n    \"SpatialResampleDict\",\n    \"SpacingD\",\n    \"SpacingDict\",\n    \"OrientationD\",\n    \"OrientationDict\",\n    \"Rotate90D\",\n    \"Rotate90Dict\",\n    \"RandRotate90D\",\n    \"RandRotate90Dict\",\n    \"ResizeD\",\n    \"ResizeDict\",\n    \"AffineD\",\n    \"AffineDict\",\n    \"RandAffineD\",\n    \"RandAffineDict\",\n    \"Rand2DElasticD\",\n    \"Rand2DElasticDict\",\n    \"Rand3DElasticD\",\n    \"Rand3DElasticDict\",\n    \"FlipD\",\n    \"FlipDict\",\n    \"RandFlipD\",\n    \"RandFlipDict\",\n    \"GridDistortionD\",\n    \"GridDistortionDict\",\n    \"RandGridDistortionD\",\n    \"RandGridDistortionDict\",\n    \"RandAxisFlipD\",\n    \"RandAxisFlipDict\",\n    \"RotateD\",\n    \"RotateDict\",\n    \"RandRotateD\",\n    \"RandRotateDict\",\n    \"ZoomD\",\n    \"ZoomDict\",\n    \"RandZoomD\",\n    \"RandZoomDict\",\n    \"GridSplitd\",\n    
\"GridSplitD\",\n    \"GridSplitDict\",\n    \"GridPatchd\",\n    \"GridPatchD\",\n    \"GridPatchDict\",\n    \"RandGridPatchd\",\n    \"RandGridPatchD\",\n    \"RandGridPatchDict\",\n    \"RandSimulateLowResolutiond\",\n    \"RandSimulateLowResolutionD\",\n    \"RandSimulateLowResolutionDict\",\n]\n\n\nclass SpatialResampled(MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.SpatialResample`.\n\n    This transform assumes the ``data`` dictionary has a key for the input\n    data's metadata and contains ``src_affine`` and ``dst_affine`` required by\n    `SpatialResample`. The key is formed by ``key_{meta_key_postfix}``.  The\n    transform will swap ``src_affine`` and ``dst_affine`` affine (with potential data type\n    changes) in the dictionary so that ``src_affine`` always refers to the current\n    status of affine.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    See also:\n        :py:class:`monai.transforms.SpatialResample`\n    \"\"\"\n\n    backend = SpatialResample.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        mode: SequenceStr = GridSampleMode.BILINEAR,\n        padding_mode: SequenceStr = GridSamplePadMode.BORDER,\n        align_corners: Sequence[bool] | bool = False,\n        dtype: Sequence[DtypeLike] | DtypeLike = np.float64,\n        dst_keys: KeysCollection | None = \"dst_affine\",\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. 
Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``\"border\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.\n                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample\n                It also can be a sequence of bool, each element corresponds to a key in ``keys``.\n            dtype: data type for resampling computation. Defaults to ``float64`` for best precision.\n                If None, use the data type of input data. 
To be compatible with other modules,\n                the output data type is always ``float32``.\n                It also can be a sequence of dtypes, each element corresponds to a key in ``keys``.\n            dst_keys: the key of the corresponding ``dst_affine`` in the metadata dictionary.\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False.\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.sp_transform = SpatialResample(lazy=lazy)\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))\n        self.dtype = ensure_tuple_rep(dtype, len(self.keys))\n        self.dst_keys = ensure_tuple_rep(dst_keys, len(self.keys))\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool) -> None:\n        self._lazy = val\n        self.sp_transform.lazy = val\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        lazy_ = self.lazy if lazy is None else lazy\n        d: dict = dict(data)\n        for key, mode, padding_mode, align_corners, dtype, dst_key in self.key_iterator(\n            d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.dst_keys\n        ):\n            d[key] = self.sp_transform(\n                img=d[key],\n                dst_affine=d[dst_key],\n                spatial_size=None,  # None means shape auto inferred\n                mode=mode,\n                padding_mode=padding_mode,\n                align_corners=align_corners,\n                dtype=dtype,\n                lazy=lazy_,\n            )\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.sp_transform.inverse(d[key])\n        return d\n\n\nclass ResampleToMatchd(MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ResampleToMatch`.\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    \"\"\"\n\n    backend = ResampleToMatch.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        key_dst: str,\n        mode: SequenceStr = GridSampleMode.BILINEAR,\n        padding_mode: SequenceStr = GridSamplePadMode.BORDER,\n        align_corners: Sequence[bool] | bool = False,\n        dtype: Sequence[DtypeLike] | DtypeLike = np.float64,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ):\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            key_dst: key of image to resample to match.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``\"border\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.\n                See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample\n                It also can be a sequence of bool, each element corresponds to a key in ``keys``.\n            dtype: data type for resampling computation. Defaults to ``float64`` for best precision.\n                If None, use the data type of input data. 
To be compatible with other modules,\n                the output data type is always ``float32``.\n                It also can be a sequence of dtypes, each element corresponds to a key in ``keys``.\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.key_dst = key_dst\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))\n        self.dtype = ensure_tuple_rep(dtype, len(self.keys))\n        self.resampler = ResampleToMatch(lazy=lazy)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool) -> None:\n        self._lazy = val\n        self.resampler.lazy = val\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        lazy_ = self.lazy if lazy is None else lazy\n        d = dict(data)\n        for key, mode, padding_mode, align_corners, dtype in self.key_iterator(\n            d, self.mode, self.padding_mode, self.align_corners, self.dtype\n        ):\n            d[key] = self.resampler(\n                img=d[key],\n                img_dst=d[self.key_dst],\n                mode=mode,\n                padding_mode=padding_mode,\n                align_corners=align_corners,\n                dtype=dtype,\n                lazy=lazy_,\n            )\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.resampler.inverse(d[key])\n        return d\n\n\nclass Spacingd(MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Spacing`.\n\n    This transform assumes the ``data`` dictionary has a key for the input\n    data's metadata and contains `affine` field.  The key is formed by ``key_{meta_key_postfix}``.\n\n    After resampling the input array, this transform will write the new affine\n    to the `affine` field of metadata which is formed by ``key_{meta_key_postfix}``.\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    see also:\n        :py:class:`monai.transforms.Spacing`\n    \"\"\"\n\n    backend = Spacing.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        pixdim: Sequence[float] | float,\n        diagonal: bool = False,\n        mode: SequenceStr = GridSampleMode.BILINEAR,\n        padding_mode: SequenceStr = GridSamplePadMode.BORDER,\n        align_corners: Sequence[bool] | bool = False,\n        dtype: Sequence[DtypeLike] | DtypeLike = np.float64,\n        scale_extent: bool = False,\n        recompute_affine: bool = False,\n        min_pixdim: Sequence[float] | float | None = None,\n        max_pixdim: Sequence[float] | float | None = None,\n        ensure_same_shape: bool = True,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            pixdim: output voxel spacing. if providing a single number, will use it for the first dimension.\n                items of the pixdim sequence map to the spatial dimensions of input image, if length\n                of pixdim sequence is longer than image spatial dimensions, will ignore the longer part,\n                if shorter, will pad with `1.0`.\n                if the components of the `pixdim` are non-positive values, the transform will use the\n                corresponding components of the original pixdim, which is computed from the `affine`\n                matrix of input image.\n            diagonal: whether to resample the input to have a diagonal affine matrix.\n                If True, the input data is resampled to the following affine::\n\n                    np.diag((pixdim_0, pixdim_1, pixdim_2, 1))\n\n                This effectively resets the volume to the world coordinate system (RAS+ in nibabel).\n                The original orientation, rotation, shearing are not preserved.\n\n                If False, the axes 
orientation, orthogonal rotation and\n                translations components from the original affine will be\n                preserved in the target affine. This option will not flip/swap\n                axes against the original ones.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``\"border\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                It also can be a sequence of bool, each element corresponds to a key in ``keys``.\n            dtype: data type for resampling computation. 
Defaults to ``float64`` for best precision.\n                If None, use the data type of input data. To be compatible with other modules,\n                the output data type is always ``float32``.\n                It also can be a sequence of dtypes, each element corresponds to a key in ``keys``.\n            scale_extent: whether the scale is computed based on the spacing or the full extent of voxels,\n                default False. The option is ignored if output spatial size is specified when calling this transform.\n                See also: :py:func:`monai.data.utils.compute_shape_offset`. When this is True, `align_corners`\n                should be `True` because `compute_shape_offset` already provides the corner alignment shift/scaling.\n            recompute_affine: whether to recompute affine based on the output shape. The affine computed\n                analytically does not reflect the potential quantization errors in terms of the output shape.\n                Set this flag to True to recompute the output affine based on the actual pixdim. Default to ``False``.\n            min_pixdim: minimal input spacing to be resampled. If provided, input image with a larger spacing than this\n                value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the\n                value of `pixdim`. Default to `None`.\n            max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this\n                value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the\n                value of `pixdim`. Default to `None`.\n            ensure_same_shape: when the inputs have the same spatial shape, and almost the same pixdim,\n                whether to ensure exactly the same output spatial shape.  
Default to True.\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.spacing_transform = Spacing(\n            pixdim,\n            diagonal=diagonal,\n            recompute_affine=recompute_affine,\n            min_pixdim=min_pixdim,\n            max_pixdim=max_pixdim,\n            lazy=lazy,\n        )\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))\n        self.dtype = ensure_tuple_rep(dtype, len(self.keys))\n        self.scale_extent = ensure_tuple_rep(scale_extent, len(self.keys))\n        self.ensure_same_shape = ensure_same_shape\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool) -> None:\n        self._lazy = val\n        self.spacing_transform.lazy = val\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d: dict = dict(data)\n\n        _init_shape, _pixdim, should_match = None, None, False\n        output_shape_k = None  # tracking output shape\n        lazy_ = self.lazy if lazy is None else lazy\n\n        for key, mode, padding_mode, align_corners, dtype, scale_extent in self.key_iterator(\n            d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.scale_extent\n        ):\n            if self.ensure_same_shape and isinstance(d[key], MetaTensor):\n                if _init_shape is None and _pixdim is None:\n                    _init_shape, _pixdim = d[key].peek_pending_shape(), d[key].pixdim\n                else:\n                    should_match = np.allclose(_init_shape, d[key].peek_pending_shape()) and np.allclose(\n                        _pixdim, d[key].pixdim, atol=1e-3\n                    )\n            d[key] = self.spacing_transform(\n                data_array=d[key],\n                mode=mode,\n                padding_mode=padding_mode,\n                align_corners=align_corners,\n                dtype=dtype,\n                scale_extent=scale_extent,\n                output_spatial_shape=output_shape_k if should_match else None,\n                lazy=lazy_,\n            )\n            if isinstance(d[key], MetaTensor):\n                meta_keys = [k for k in d.keys() if k is not None and k.startswith(f\"{key}_\")]\n                for meta_key in meta_keys:\n                    if \"filename_or_obj\" in d[key].meta and is_supported_format(\n                        d[key].meta[\"filename_or_obj\"], [\"nii\", \"nii.gz\"]\n                    ):\n                        d[meta_key].update(d[key].meta)\n            if output_shape_k is None:\n                output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else 
d[key].shape[1:]\n        return d\n\n    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.spacing_transform.inverse(cast(torch.Tensor, d[key]))\n        return d\n\n\nclass Orientationd(MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Orientation`.\n\n    This transform assumes the channel-first input format.\n    In the case of using this transform for normalizing the orientations of images,\n    it should be used before any anisotropic spatial transforms.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = Orientation.backend\n\n    @deprecated_arg_default(\n        name=\"labels\",\n        old_default=((\"L\", \"R\"), (\"P\", \"A\"), (\"I\", \"S\")),\n        new_default=None,\n        msg_suffix=(\n            \"Default value changed to None meaning that the transform now uses the 'space' of a \"\n            \"meta-tensor, if applicable, to determine appropriate axis labels.\"\n        ),\n    )\n    def __init__(\n        self,\n        keys: KeysCollection,\n        axcodes: str | None = None,\n        as_closest_canonical: bool = False,\n        labels: Sequence[tuple[str, str]] | None = None,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            axcodes: N elements sequence for spatial ND input's orientation.\n                e.g. 
axcodes='RAS' represents 3D orientation:\n                (Left, Right), (Posterior, Anterior), (Inferior, Superior).\n                default orientation labels options are: 'L' and 'R' for the first dimension,\n                'P' and 'A' for the second, 'I' and 'S' for the third.\n            as_closest_canonical: if True, load the image as closest to canonical axis format.\n            labels: optional, None or sequence of (2,) sequences\n                (2,) sequences are labels for (beginning, end) of output axis.\n                If ``None``, an appropriate value is chosen depending on the\n                value of the ``\"space\"`` metadata item of a metatensor: if\n                ``\"space\"`` is ``\"LPS\"``, the value used is ``(('R', 'L'),\n                ('A', 'P'), ('I', 'S'))``, if ``\"space\"`` is ``\"RAS\"`` or the\n                input is not a meta-tensor or has no ``\"space\"`` item, the\n                value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. If not\n                ``None``, the provided value is always used and the ``\"space\"``\n                metadata item (if any) of the input is ignored.\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n\n        See Also:\n            `nibabel.orientations.ornt2axcodes`.\n\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.ornt_transform = Orientation(\n            axcodes=axcodes, as_closest_canonical=as_closest_canonical, labels=labels, lazy=lazy\n        )\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool) -> None:\n        self._lazy = val\n        self.ornt_transform.lazy = val\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        
\"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d: dict = dict(data)\n        lazy_ = self.lazy if lazy is None else lazy\n        for key in self.key_iterator(d):\n            d[key] = self.ornt_transform(d[key], lazy=lazy_)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.ornt_transform.inverse(d[key])\n        return d\n\n\nclass Rotate90d(MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`.\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = Rotate90.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        k: int = 1,\n        spatial_axes: tuple[int, int] = (0, 1),\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            k: number of times to rotate by 90 degrees.\n            spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.\n                Default: (0, 1), this is the first two axis in spatial dimensions.\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.rotator = Rotate90(k, spatial_axes, lazy=lazy)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool) -> None:\n        self._lazy = val\n        self.rotator.lazy = val\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        lazy_ = self.lazy if lazy is None else lazy\n        for key in self.key_iterator(d):\n            d[key] = self.rotator(d[key], lazy=lazy_)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.rotator.inverse(d[key])\n        return d\n\n\nclass RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandRotate90`.\n    With probability `prob`, input arrays are rotated by 90 degrees\n    in the plane specified by `spatial_axes`.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = Rotate90.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        prob: float = 0.1,\n        max_k: int = 3,\n        spatial_axes: tuple[int, int] = (0, 1),\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            prob: probability of rotating.\n                (Default 0.1, with 10% probability it returns a rotated array.)\n            max_k: number of rotations will be sampled from `np.random.randint(max_k) + 1`.\n                (Default 3)\n            spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.\n                Default: (0, 1), this is the first two axis in spatial dimensions.\n            allow_missing_keys: don't raise exception if key is 
missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        LazyTransform.__init__(self, lazy=lazy)\n\n        self.max_k = max_k\n        self.spatial_axes = spatial_axes\n\n        self._rand_k = 0\n\n    def randomize(self, data: Any | None = None) -> None:\n        self._rand_k = self.R.randint(self.max_k) + 1\n        super().randomize(None)\n\n    def __call__(\n        self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None\n    ) -> Mapping[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        self.randomize()\n        d = dict(data)\n\n        # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need\n        # to be compatible with the random status of some previous integration tests\n        lazy_ = self.lazy if lazy is None else lazy\n        rotator = Rotate90(self._rand_k, self.spatial_axes, lazy=lazy_)\n        for key in self.key_iterator(d):\n            d[key] = rotator(d[key]) if self._do_transform else convert_to_tensor(d[key], track_meta=get_track_meta())\n            self.push_transform(d[key], replace=True, lazy=lazy_)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            if not isinstance(d[key], MetaTensor):\n                continue\n            xform = self.pop_transform(d[key])\n            if xform[TraceKeys.DO_TRANSFORM]:\n                d[key] = Rotate90().inverse_transform(d[key], xform[TraceKeys.EXTRA_INFO])\n        return d\n\n\nclass Resized(MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Resize`.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        spatial_size: expected shape of spatial dimensions after resize operation.\n            if some components of the `spatial_size` are non-positive values, the transform will use the\n            corresponding components of img size. 
For example, `spatial_size=(32, -1)` will be adapted\n            to `(32, 64)` if the second spatial dimension size of img is `64`.\n        size_mode: should be \"all\" or \"longest\", if \"all\", will use `spatial_size` for all the spatial dims,\n            if \"longest\", rescale the image so that only the longest side is equal to specified `spatial_size`,\n            which must be an int number in this case, keeping the aspect ratio of the initial image, refer to:\n            https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/\n            #albumentations.augmentations.geometric.resize.LongestMaxSize.\n        mode: {``\"nearest\"``, ``\"nearest-exact\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}\n            The interpolation mode. Defaults to ``\"area\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            It also can be a sequence of string, each element corresponds to a key in ``keys``.\n        align_corners: This only has an effect when mode is\n            'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            It also can be a sequence of bool or None, each element corresponds to a key in ``keys``.\n        anti_aliasing: bool\n            Whether to apply a Gaussian filter to smooth the image prior\n            to downsampling. It is crucial to filter when downsampling\n            the image to avoid aliasing artifacts. See also ``skimage.transform.resize``\n        anti_aliasing_sigma: {float, tuple of floats}, optional\n            Standard deviation for Gaussian filtering used when anti-aliasing.\n            By default, this value is chosen as (s - 1) / 2 where s is the\n            downsampling factor, where s > 1. 
For the up-size case, s < 1, no\n            anti-aliasing is performed prior to rescaling.\n        dtype: data type for resampling computation. Defaults to ``float32``.\n            If None, use the data type of input data.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n    \"\"\"\n\n    backend = Resize.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        spatial_size: Sequence[int] | int,\n        size_mode: str = \"all\",\n        mode: SequenceStr = InterpolateMode.AREA,\n        align_corners: Sequence[bool | None] | bool | None = None,\n        anti_aliasing: Sequence[bool] | bool = False,\n        anti_aliasing_sigma: Sequence[Sequence[float] | float | None] | Sequence[float] | float | None = None,\n        dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))\n        self.dtype = ensure_tuple_rep(dtype, len(self.keys))\n        self.anti_aliasing = ensure_tuple_rep(anti_aliasing, len(self.keys))\n        self.anti_aliasing_sigma = ensure_tuple_rep(anti_aliasing_sigma, len(self.keys))\n        self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode, lazy=lazy)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool) -> None:\n        self._lazy = val\n        self.resizer.lazy = val\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the 
tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        lazy_ = self.lazy if lazy is None else lazy\n        for key, mode, align_corners, anti_aliasing, anti_aliasing_sigma, dtype in self.key_iterator(\n            d, self.mode, self.align_corners, self.anti_aliasing, self.anti_aliasing_sigma, self.dtype\n        ):\n            d[key] = self.resizer(\n                d[key],\n                mode=mode,\n                align_corners=align_corners,\n                anti_aliasing=anti_aliasing,\n                anti_aliasing_sigma=anti_aliasing_sigma,\n                dtype=dtype,\n                lazy=lazy_,\n            )\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.resizer.inverse(d[key])\n        return d\n\n\nclass Affined(MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Affine`.\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = Affine.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        rotate_params: Sequence[float] | float | None = None,\n        shear_params: Sequence[float] | float | None = None,\n        translate_params: Sequence[float] | float | None = None,\n        scale_params: Sequence[float] | float | None = None,\n        affine: NdarrayOrTensor | None = None,\n        spatial_size: Sequence[int] | int | None = None,\n        mode: SequenceStr = GridSampleMode.BILINEAR,\n        padding_mode: SequenceStr = GridSamplePadMode.REFLECTION,\n        device: torch.device | None = None,\n        dtype: DtypeLike | torch.dtype = np.float32,\n        align_corners: bool = False,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            rotate_params: a rotation angle in radians, a scalar for 2D image, a tuple of 3 floats for 3D.\n                Defaults to no rotation.\n            shear_params: shearing factors for affine matrix, take a 3D affine as example::\n\n                [\n                    [1.0, params[0], params[1], 0.0],\n                    [params[2], 1.0, params[3], 0.0],\n                    [params[4], params[5], 1.0, 0.0],\n                    [0.0, 0.0, 0.0, 1.0],\n                ]\n\n                a tuple of 2 floats for 2D, a tuple of 6 floats for 3D. Defaults to no shearing.\n            translate_params: a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Translation is in\n                pixel/voxel relative to the center of the input image. Defaults to no translation.\n            scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D,\n                a tuple of 3 floats for 3D. 
Defaults to `1.0`.\n            affine: if applied, ignore the params (`rotate_params`, etc.) and use the\n                supplied matrix. Should be square with each side = num of image spatial\n                dimensions + 1.\n            spatial_size: output image spatial size.\n                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,\n                the transform will use the spatial size of `img`.\n                if some components of the `spatial_size` are non-positive values, the transform will use the\n                corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted\n                to `(32, 64)` if the second spatial dimension size of img is `64`.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``\"reflection\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            device: device on which the tensor will be allocated.\n            dtype: data type for resampling computation. Defaults to ``float32``.\n                If ``None``, use the data type of input data. To be compatible with other modules,\n                the output data type is always `float32`.\n            align_corners: Defaults to False.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n\n        See also:\n            - :py:class:`monai.transforms.compose.MapTransform`\n            - :py:class:`RandAffineGrid` for the random affine parameters configurations.\n\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.affine = Affine(\n            rotate_params=rotate_params,\n            shear_params=shear_params,\n            translate_params=translate_params,\n            scale_params=scale_params,\n            affine=affine,\n            spatial_size=spatial_size,\n            device=device,\n            dtype=dtype,  # type: ignore\n            align_corners=align_corners,\n            lazy=lazy,\n        )\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n 
       self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool) -> None:\n        self._lazy = val\n        self.affine.lazy = val\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        lazy_ = self.lazy if lazy is None else lazy\n        d = dict(data)\n        for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):\n            d[key], _ = self.affine(d[key], mode=mode, padding_mode=padding_mode, lazy=lazy_)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.affine.inverse(d[key])\n        return d\n\n\nclass RandAffined(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`.\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n    \"\"\"\n\n    backend = RandAffine.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        spatial_size: Sequence[int] | int | None = None,\n        prob: float = 0.1,\n        rotate_range: Sequence[tuple[float, float] | float] | float | None = None,\n        shear_range: Sequence[tuple[float, float] | float] | float | None = None,\n        translate_range: Sequence[tuple[float, float] | float] | float | None = None,\n        scale_range: Sequence[tuple[float, float] | float] | float | None = None,\n        mode: SequenceStr = GridSampleMode.BILINEAR,\n        padding_mode: SequenceStr = GridSamplePadMode.REFLECTION,\n        cache_grid: bool = False,\n        device: torch.device | None = None,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            spatial_size: output image spatial size.\n                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,\n                the transform will use the spatial size of `img`.\n                if some components of the `spatial_size` are non-positive values, the transform will use the\n                corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted\n                to `(32, 64)` if the second spatial dimension size of img is `64`.\n            prob: probability of returning a randomized affine grid.\n                defaults to 0.1, with 10% chance returns a randomized grid.\n            rotate_range: angle range in radians. If element `i` is a pair of (min, max) values, then\n                `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter\n                for the `i`th spatial dimension. 
If not, `uniform[-rotate_range[i], rotate_range[i])` will be used.\n                This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be\n                in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]`\n                for dim0 and nothing for the remaining dimensions.\n            shear_range: shear range with format matching `rotate_range`, it defines the range to randomly select\n                shearing factors (a tuple of 2 floats for 2D, a tuple of 6 floats for 3D) for affine matrix,\n                take a 3D affine as example::\n\n                    [\n                        [1.0, params[0], params[1], 0.0],\n                        [params[2], 1.0, params[3], 0.0],\n                        [params[4], params[5], 1.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0],\n                    ]\n\n            translate_range: translate range with format matching `rotate_range`, it defines the range to randomly\n                select pixel/voxel to translate for every spatial dim.\n            scale_range: scaling range with format matching `rotate_range`. It defines the range to randomly select\n                the scale factor for every spatial dim. A value of 1.0 is added to the result.\n                This allows 0 to correspond to no change (i.e., a scaling of 1.0).\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. 
Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``\"reflection\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            cache_grid: whether to cache the identity sampling grid.\n                If the spatial size is not dynamically defined by input image, enabling this option could\n                accelerate the transform.\n            device: device on which the tensor will be allocated.\n            allow_missing_keys: don't raise exception if key is missing.\n            lazy: a flag to indicate whether this transform should execute lazily or not.\n                Defaults to False\n\n        See also:\n            - :py:class:`monai.transforms.compose.MapTransform`\n            - :py:class:`RandAffineGrid` for the random affine parameters configurations.\n\n        Note:\n            The affine transformations in MONAI use a 'backward mapping' (image-to-grid) logic.\n        
    This can be counter-intuitive:\n            - Translation: A positive value shifts the image in the negative direction.\n            - Scaling: Positive scale_range values decrease the image size; values in [-1, 0) increase it.\n            - Rotation: The direction (CW/CCW) may vary depending on the axis.\n\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.rand_affine = RandAffine(\n            prob=1.0,  # because probability handled in this class\n            rotate_range=rotate_range,\n            shear_range=shear_range,\n            translate_range=translate_range,\n            scale_range=scale_range,\n            spatial_size=spatial_size,\n            cache_grid=cache_grid,\n            device=device,\n            lazy=lazy,\n        )\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool) -> None:\n        self._lazy = val\n        self.rand_affine.lazy = val\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandAffined:\n        self.rand_affine.set_random_state(seed, state)\n        super().set_random_state(seed, state)\n        return self\n\n    def __call__(\n        self, data: Mapping[Hashable, NdarrayOrTensor], lazy: bool | None = None\n    ) -> dict[Hashable, NdarrayOrTensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. 
Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())\n            return out\n\n        self.randomize(None)\n        # all the keys share the same random Affine factor\n        self.rand_affine.randomize()\n\n        item = d[first_key]\n        spatial_size = item.peek_pending_shape() if isinstance(item, MetaTensor) else item.shape[1:]\n        lazy_ = self.lazy if lazy is None else lazy\n\n        sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size)\n        # change image size or do random transform\n        do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size))\n        # converting affine to tensor because the resampler currently only support torch backend\n        grid = None\n        if do_resampling:  # need to prepare grid\n            grid = self.rand_affine.get_identity_grid(sp_size, lazy=lazy_)\n            if self._do_transform:  # add some random factors\n                grid = self.rand_affine.rand_affine_grid(sp_size, grid=grid, lazy=lazy_)\n        grid = 0 if grid is None else grid  # always provide a grid to self.rand_affine\n\n        for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):\n            # do the transform\n            if do_resampling:\n                d[key] = self.rand_affine(d[key], None, mode, padding_mode, True, grid, lazy=lazy_)  # type: ignore\n            else:\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)\n            self._do_transform = do_resampling  # TODO: unify 
self._do_transform and do_resampling\n            self.push_transform(d[key], replace=True, lazy=lazy_)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            tr = self.pop_transform(d[key])\n            if TraceKeys.EXTRA_INFO not in tr[TraceKeys.EXTRA_INFO]:\n                continue\n            do_resampling = tr[TraceKeys.EXTRA_INFO][TraceKeys.EXTRA_INFO][\"do_resampling\"]\n            if do_resampling:\n                d[key].applied_operations.append(tr[TraceKeys.EXTRA_INFO])  # type: ignore\n                d[key] = self.rand_affine.inverse(d[key])  # type: ignore\n\n        return d\n\n\nclass Rand2DElasticd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Rand2DElastic`.\n    \"\"\"\n\n    backend = Rand2DElastic.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        spacing: tuple[float, float] | float,\n        magnitude_range: tuple[float, float],\n        spatial_size: tuple[int, int] | int | None = None,\n        prob: float = 0.1,\n        rotate_range: Sequence[tuple[float, float] | float] | float | None = None,\n        shear_range: Sequence[tuple[float, float] | float] | float | None = None,\n        translate_range: Sequence[tuple[float, float] | float] | float | None = None,\n        scale_range: Sequence[tuple[float, float] | float] | float | None = None,\n        mode: SequenceStr = GridSampleMode.BILINEAR,\n        padding_mode: SequenceStr = GridSamplePadMode.REFLECTION,\n        device: torch.device | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            spacing: distance in between the control points.\n            magnitude_range: 2 int numbers, the random offsets will be 
generated from\n                ``uniform[magnitude[0], magnitude[1])``.\n            spatial_size: specifying output image spatial size [h, w].\n                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,\n                the transform will use the spatial size of `img`.\n                if some components of the `spatial_size` are non-positive values, the transform will use the\n                corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted\n                to `(32, 64)` if the second spatial dimension size of img is `64`.\n            prob: probability of returning a randomized affine grid.\n                defaults to 0.1, with 10% chance returns a randomized grid,\n                otherwise returns a ``spatial_size`` centered area extracted from the input image.\n            rotate_range: angle range in radians. If element `i` is a pair of (min, max) values, then\n                `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter\n                for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used.\n                This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be\n                in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. 
Setting a single value will use `[-x, x]`\n                for dim0 and nothing for the remaining dimensions.\n            shear_range: shear range with format matching `rotate_range`, it defines the range to randomly select\n                shearing factors (a tuple of 2 floats for 2D) for affine matrix, take a 2D affine as example::\n\n                    [\n                        [1.0, params[0], 0.0],\n                        [params[1], 1.0, 0.0],\n                        [0.0, 0.0, 1.0],\n                    ]\n\n            translate_range: translate range with format matching `rotate_range`, it defines the range to randomly\n                select pixel to translate for every spatial dim.\n            scale_range: scaling range with format matching `rotate_range`. It defines the range to randomly select\n                the scale factor for every spatial dim. A value of 1.0 is added to the result.\n                This allows 0 to correspond to no change (i.e., a scaling of 1.0).\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``\"reflection\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            device: device on which the tensor will be allocated.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        See also:\n            - :py:class:`RandAffineGrid` for the random affine parameters configurations.\n            - :py:class:`Affine` for the affine transformation parameters configurations.\n\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.rand_2d_elastic = Rand2DElastic(\n            spacing=spacing,\n            magnitude_range=magnitude_range,\n            prob=1.0,  # because probability controlled by this class\n            rotate_range=rotate_range,\n            shear_range=shear_range,\n            translate_range=translate_range,\n            scale_range=scale_range,\n            spatial_size=spatial_size,\n            device=device,\n        )\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Rand2DElasticd:\n        self.rand_2d_elastic.set_random_state(seed, state)\n        super().set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        \"\"\"\n        Args:\n            
data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        first_key: Hashable = self.first_key(d)\n\n        if first_key == ():\n            out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())\n            return out\n\n        self.randomize(None)\n        device = self.rand_2d_elastic.device\n        if device is None and isinstance(d[first_key], torch.Tensor):\n            device = d[first_key].device  # type: ignore\n            self.rand_2d_elastic.set_device(device)\n        if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations:  # type: ignore\n            warnings.warn(f\"data['{first_key}'] has pending operations, transform may return incorrect results.\")\n        sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[first_key].shape[1:])\n\n        # all the keys share the same random elastic factor\n        self.rand_2d_elastic.randomize(sp_size)\n\n        if self._do_transform:\n            grid = self.rand_2d_elastic.deform_grid(spatial_size=sp_size)\n            grid = self.rand_2d_elastic.rand_affine_grid(grid=grid)\n            grid = torch.nn.functional.interpolate(\n                recompute_scale_factor=True,\n                input=grid.unsqueeze(0),\n                scale_factor=ensure_tuple_rep(self.rand_2d_elastic.deform_grid.spacing, 2),\n                mode=InterpolateMode.BICUBIC.value,\n                align_corners=False,\n            )\n            grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])\n        else:\n            grid = cast(torch.Tensor, create_grid(spatial_size=sp_size, device=device, 
backend=\"torch\"))\n\n        for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):\n            d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)  # type: ignore\n        return d\n\n\nclass Rand3DElasticd(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Rand3DElastic`.\n    \"\"\"\n\n    backend = Rand3DElastic.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        sigma_range: tuple[float, float],\n        magnitude_range: tuple[float, float],\n        spatial_size: tuple[int, int, int] | int | None = None,\n        prob: float = 0.1,\n        rotate_range: Sequence[tuple[float, float] | float] | float | None = None,\n        shear_range: Sequence[tuple[float, float] | float] | float | None = None,\n        translate_range: Sequence[tuple[float, float] | float] | float | None = None,\n        scale_range: Sequence[tuple[float, float] | float] | float | None = None,\n        mode: SequenceStr = GridSampleMode.BILINEAR,\n        padding_mode: SequenceStr = GridSamplePadMode.REFLECTION,\n        device: torch.device | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            sigma_range: a Gaussian kernel with standard deviation sampled from\n                ``uniform[sigma_range[0], sigma_range[1])`` will be used to smooth the random offset grid.\n            magnitude_range: the random offsets on the grid will be generated from\n                ``uniform[magnitude[0], magnitude[1])``.\n            spatial_size: specifying output image spatial size [h, w, d].\n                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,\n                the transform will use the spatial size of `img`.\n                if some components of the `spatial_size` are 
non-positive values, the transform will use the\n                corresponding components of img size. For example, `spatial_size=(32, 32, -1)` will be adapted\n                to `(32, 32, 64)` if the third spatial dimension size of img is `64`.\n            prob: probability of returning a randomized affine grid.\n                defaults to 0.1, with 10% chance returns a randomized grid,\n                otherwise returns a ``spatial_size`` centered area extracted from the input image.\n            rotate_range: angle range in radians. If element `i` is a pair of (min, max) values, then\n                `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter\n                for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used.\n                This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be\n                in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]`\n                for dim0 and nothing for the remaining dimensions.\n            shear_range: shear range with format matching `rotate_range`, it defines the range to randomly select\n                shearing factors (a tuple of 6 floats for 3D) for affine matrix, take a 3D affine as example::\n\n                    [\n                        [1.0, params[0], params[1], 0.0],\n                        [params[2], 1.0, params[3], 0.0],\n                        [params[4], params[5], 1.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0],\n                    ]\n\n            translate_range: translate range with format matching `rotate_range`, it defines the range to randomly\n                select voxel to translate for every spatial dim.\n            scale_range: scaling range with format matching `rotate_range`. It defines the range to randomly select\n                the scale factor for every spatial dim. 
A value of 1.0 is added to the result.\n                This allows 0 to correspond to no change (i.e., a scaling of 1.0).\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``\"reflection\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            device: device on which the tensor will be allocated.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        See also:\n            - :py:class:`RandAffineGrid` for the random affine parameters configurations.\n            - :py:class:`Affine` for the affine transformation parameters configurations.\n\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        
self.rand_3d_elastic = Rand3DElastic(\n            sigma_range=sigma_range,\n            magnitude_range=magnitude_range,\n            prob=1.0,  # because probability controlled by this class\n            rotate_range=rotate_range,\n            shear_range=shear_range,\n            translate_range=translate_range,\n            scale_range=scale_range,\n            spatial_size=spatial_size,\n            device=device,\n        )\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Rand3DElasticd:\n        self.rand_3d_elastic.set_random_state(seed, state)\n        super().set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. 
The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        first_key: Hashable = self.first_key(d)\n\n        if first_key == ():\n            out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())\n            return out\n\n        self.randomize(None)\n        if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations:  # type: ignore\n            warnings.warn(f\"data['{first_key}'] has pending operations, transform may return incorrect results.\")\n        sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[first_key].shape[1:])\n\n        # all the keys share the same random elastic factor\n        self.rand_3d_elastic.randomize(sp_size)\n\n        device = self.rand_3d_elastic.device\n        if device is None and isinstance(d[first_key], torch.Tensor):\n            device = d[first_key].device\n            self.rand_3d_elastic.set_device(device)\n        grid = create_grid(spatial_size=sp_size, device=device, backend=\"torch\")\n        if self._do_transform:\n            gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device)\n            offset = torch.as_tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0)\n            grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude\n            grid = self.rand_3d_elastic.rand_affine_grid(grid=grid)\n\n        for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):\n            d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)  # type: ignore\n        return d\n\n\nclass Flipd(MapTransform, InvertibleTransform, LazyTransform):\n    
\"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Flip`.\n\n    See `numpy.flip` for additional details.\n    https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: Keys to pick data for transformation.\n        spatial_axis: Spatial axes along which to flip over. Default is None.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n    \"\"\"\n\n    backend = Flip.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        spatial_axis: Sequence[int] | int | None = None,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.flipper = Flip(spatial_axis=spatial_axis)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool):\n        self.flipper.lazy = val\n        self._lazy = val\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        lazy_ = self.lazy if lazy is None else lazy\n        for key in self.key_iterator(d):\n            d[key] = self.flipper(d[key], lazy=lazy_)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.flipper.inverse(d[key])\n        return d\n\n\nclass RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandFlip`.\n\n    See `numpy.flip` for additional details.\n    https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: Keys to pick data for transformation.\n        prob: Probability of flipping.\n        spatial_axis: Spatial axes along which to flip over. 
Default is None.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n    \"\"\"\n\n    backend = Flip.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        prob: float = 0.1,\n        spatial_axis: Sequence[int] | int | None = None,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.flipper = Flip(spatial_axis=spatial_axis, lazy=lazy)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool):\n        self.flipper.lazy = val\n        self._lazy = val\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandFlipd:\n        super().set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        self.randomize(None)\n\n        lazy_ = self.lazy if lazy is None else lazy\n        for key in self.key_iterator(d):\n            if self._do_transform:\n                d[key] = self.flipper(d[key], lazy=lazy_)\n            else:\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            self.push_transform(d[key], replace=True, lazy=lazy_)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            xform = self.pop_transform(d[key])\n            if not xform[TraceKeys.DO_TRANSFORM]:\n                continue\n            with self.flipper.trace_transform(False):\n                d[key] = self.flipper(d[key])\n        return d\n\n\nclass RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandAxisFlip`.\n\n    See `numpy.flip` for additional details.\n    https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html\n\n    This transform is capable of lazy execution. 
See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: Keys to pick data for transformation.\n        prob: Probability of flipping.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n    \"\"\"\n\n    backend = RandAxisFlip.backend\n\n    def __init__(\n        self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False, lazy: bool = False\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.flipper = RandAxisFlip(prob=1.0, lazy=lazy)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool):\n        self.flipper.lazy = val\n        self._lazy = val\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandAxisFlipd:\n        super().set_random_state(seed, state)\n        self.flipper.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            return d\n\n        self.randomize(None)\n\n        # all the keys share the same random selected axis\n        self.flipper.randomize(d[first_key])\n\n        lazy_ = self.lazy if lazy is None else lazy\n        for key in self.key_iterator(d):\n            if self._do_transform:\n                d[key] = self.flipper(d[key], randomize=False, lazy=lazy_)\n            else:\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())\n            self.push_transform(d[key], replace=True, lazy=lazy_)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            xform = self.pop_transform(d[key])\n            if xform[TraceKeys.DO_TRANSFORM]:\n                d[key].applied_operations.append(xform[TraceKeys.EXTRA_INFO])  # type: ignore\n                d[key] = self.flipper.inverse(d[key])\n        return d\n\n\nclass Rotated(MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Rotate`.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: Keys to pick data for transformation.\n        angle: Rotation angle(s) in radians.\n        keep_size: If it is False, the output shape is adapted so that the\n            input array is contained completely in the output.\n            If it is True, the output shape is the same as the input. Default is True.\n        mode: {``\"bilinear\"``, ``\"nearest\"``}\n            Interpolation mode to calculate output values. 
Defaults to ``\"bilinear\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            It also can be a sequence of string, each element corresponds to a key in ``keys``.\n        padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n            Padding mode for outside grid values. Defaults to ``\"border\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            It also can be a sequence of string, each element corresponds to a key in ``keys``.\n        align_corners: Defaults to False.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            It also can be a sequence of bool, each element corresponds to a key in ``keys``.\n        dtype: data type for resampling computation. Defaults to ``float32``.\n            If None, use the data type of input data. To be compatible with other modules,\n            the output data type is always ``float32``.\n            It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n    \"\"\"\n\n    backend = Rotate.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        angle: Sequence[float] | float,\n        keep_size: bool = True,\n        mode: SequenceStr = GridSampleMode.BILINEAR,\n        padding_mode: SequenceStr = GridSamplePadMode.BORDER,\n        align_corners: Sequence[bool] | bool = False,\n        dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy=lazy)\n        
self.rotator = Rotate(angle=angle, keep_size=keep_size, lazy=lazy)\n\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))\n        self.dtype = ensure_tuple_rep(dtype, len(self.keys))\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool):\n        self.rotator.lazy = val\n        self._lazy = val\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        lazy_ = self.lazy if lazy is None else lazy\n        for key, mode, padding_mode, align_corners, dtype in self.key_iterator(\n            d, self.mode, self.padding_mode, self.align_corners, self.dtype\n        ):\n            d[key] = self.rotator(\n                d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy=lazy_\n            )\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.rotator.inverse(d[key])\n        return d\n\n\nclass RandRotated(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based version :py:class:`monai.transforms.RandRotate`\n    Randomly rotates the input arrays.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: Keys to pick data for transformation.\n        range_x: Range of rotation angle in radians in the plane defined by the first and second axes.\n            If single number, angle is uniformly sampled from (-range_x, range_x).\n        range_y: Range of rotation angle in radians in the plane defined by the first and third axes.\n            If single number, angle is uniformly sampled from (-range_y, range_y). only work for 3D data.\n        range_z: Range of rotation angle in radians in the plane defined by the second and third axes.\n            If single number, angle is uniformly sampled from (-range_z, range_z). 
only work for 3D data.\n        prob: Probability of rotation.\n        keep_size: If it is False, the output shape is adapted so that the\n            input array is contained completely in the output.\n            If it is True, the output shape is the same as the input. Default is True.\n        mode: {``\"bilinear\"``, ``\"nearest\"``}\n            Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            It also can be a sequence of string, each element corresponds to a key in ``keys``.\n        padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n            Padding mode for outside grid values. Defaults to ``\"border\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            It also can be a sequence of string, each element corresponds to a key in ``keys``.\n        align_corners: Defaults to False.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            It also can be a sequence of bool, each element corresponds to a key in ``keys``.\n        dtype: data type for resampling computation. Defaults to ``float32``.\n            If None, use the data type of input data. 
To be compatible with other modules,\n            the output data type is always ``float32``.\n            It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n    \"\"\"\n\n    backend = RandRotate.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        range_x: tuple[float, float] | float = 0.0,\n        range_y: tuple[float, float] | float = 0.0,\n        range_z: tuple[float, float] | float = 0.0,\n        prob: float = 0.1,\n        keep_size: bool = True,\n        mode: SequenceStr = GridSampleMode.BILINEAR,\n        padding_mode: SequenceStr = GridSamplePadMode.BORDER,\n        align_corners: Sequence[bool] | bool = False,\n        dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        LazyTransform.__init__(self, lazy=lazy)\n        self.rand_rotate = RandRotate(\n            range_x=range_x, range_y=range_y, range_z=range_z, prob=1.0, keep_size=keep_size, lazy=lazy\n        )\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))\n        self.dtype = ensure_tuple_rep(dtype, len(self.keys))\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool):\n        self.rand_rotate.lazy = val\n        self._lazy = val\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandRotated:\n        super().set_random_state(seed, state)\n        
self.rand_rotate.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        self.randomize(None)\n\n        # all the keys share the same random rotate angle\n        self.rand_rotate.randomize()\n        lazy_ = self.lazy if lazy is None else lazy\n\n        for key, mode, padding_mode, align_corners, dtype in self.key_iterator(\n            d, self.mode, self.padding_mode, self.align_corners, self.dtype\n        ):\n            if self._do_transform:\n                d[key] = self.rand_rotate(\n                    d[key],\n                    mode=mode,\n                    padding_mode=padding_mode,\n                    align_corners=align_corners,\n                    dtype=dtype,\n                    randomize=False,\n                    lazy=lazy_,\n                )\n            else:\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)\n            self.push_transform(d[key], replace=True, lazy=lazy_)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n           
 xform = self.pop_transform(d[key])\n            if xform[TraceKeys.DO_TRANSFORM]:\n                d[key].applied_operations.append(xform[TraceKeys.EXTRA_INFO])  # type: ignore\n                d[key] = self.rand_rotate.inverse(d[key])\n        return d\n\n\nclass Zoomd(MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Zoom`.\n\n    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: Keys to pick data for transformation.\n        zoom: The zoom factor along the spatial axes.\n            If a float, zoom is the same for each spatial axis.\n            If a sequence, zoom should contain one value for each spatial axis.\n        mode: {``\"nearest\"``, ``\"nearest-exact\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}\n            The interpolation mode. Defaults to ``\"area\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            It also can be a sequence of string, each element corresponds to a key in ``keys``.\n        padding_mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. 
Defaults to ``\"edge\"``.\n            The mode to pad data after zooming.\n            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        align_corners: This only has an effect when mode is\n            'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            It also can be a sequence of bool or None, each element corresponds to a key in ``keys``.\n        dtype: data type for resampling computation. Defaults to ``float32``.\n            If None, use the data type of input data.\n        keep_size: Should keep original size (pad if needed), default is True.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n        kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    \"\"\"\n\n    backend = Zoom.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        zoom: Sequence[float] | float,\n        mode: SequenceStr = InterpolateMode.AREA,\n        padding_mode: SequenceStr = NumpyPadMode.EDGE,\n        align_corners: Sequence[bool | None] | bool | None = None,\n        dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32,\n        keep_size: bool = True,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n        **kwargs,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        LazyTransform.__init__(self, lazy=lazy)\n\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n        self.align_corners = 
ensure_tuple_rep(align_corners, len(self.keys))\n        self.dtype = ensure_tuple_rep(dtype, len(self.keys))\n        self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, lazy=lazy, **kwargs)\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool):\n        self.zoomer.lazy = val\n        self._lazy = val\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        lazy_ = self.lazy if lazy is None else lazy\n        for key, mode, padding_mode, align_corners, dtype in self.key_iterator(\n            d, self.mode, self.padding_mode, self.align_corners, self.dtype\n        ):\n            d[key] = self.zoomer(\n                d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, lazy=lazy_\n            )\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.zoomer.inverse(d[key])\n        return d\n\n\nclass RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform):\n    \"\"\"\n    Dict-based version :py:class:`monai.transforms.RandZoom`.\n\n    This transform is 
capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`\n    for more information.\n\n    Args:\n        keys: Keys to pick data for transformation.\n        prob: Probability of zooming.\n        min_zoom: Min zoom factor. Can be float or sequence same size as image.\n            If a float, select a random factor from `[min_zoom, max_zoom]` then apply to all spatial dims\n            to keep the original spatial shape ratio.\n            If a sequence, min_zoom should contain one value for each spatial axis.\n            If 2 values provided for 3D data, use the first value for both H & W dims to keep the same zoom ratio.\n        max_zoom: Max zoom factor. Can be float or sequence same size as image.\n            If a float, select a random factor from `[min_zoom, max_zoom]` then apply to all spatial dims\n            to keep the original spatial shape ratio.\n            If a sequence, max_zoom should contain one value for each spatial axis.\n            If 2 values provided for 3D data, use the first value for both H & W dims to keep the same zoom ratio.\n        mode: {``\"nearest\"``, ``\"nearest-exact\"``, ``\"linear\"``, ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}\n            The interpolation mode. Defaults to ``\"area\"``.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            It also can be a sequence of string, each element corresponds to a key in ``keys``.\n        padding_mode: available modes for numpy array:{``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            available modes for PyTorch Tensor: {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function. 
Defaults to ``\"edge\"``.\n            The mode to pad data after zooming.\n            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n        align_corners: This only has an effect when mode is\n            'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            It also can be a sequence of bool or None, each element corresponds to a key in ``keys``.\n        dtype: data type for resampling computation. Defaults to ``float32``.\n            If None, use the data type of input data.\n        keep_size: Should keep original size (pad if needed), default is True.\n        allow_missing_keys: don't raise exception if key is missing.\n        lazy: a flag to indicate whether this transform should execute lazily or not.\n            Defaults to False\n        kwargs: other args for `np.pad` API, note that `np.pad` treats channel dimension as the first dimension.\n            more details: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n    \"\"\"\n\n    backend = RandZoom.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        prob: float = 0.1,\n        min_zoom: Sequence[float] | float = 0.9,\n        max_zoom: Sequence[float] | float = 1.1,\n        mode: SequenceStr = InterpolateMode.AREA,\n        padding_mode: SequenceStr = NumpyPadMode.EDGE,\n        align_corners: Sequence[bool | None] | bool | None = None,\n        dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32,\n        keep_size: bool = True,\n        allow_missing_keys: bool = False,\n        lazy: bool = False,\n        **kwargs,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        LazyTransform.__init__(self, 
lazy=lazy)\n        self.rand_zoom = RandZoom(\n            prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, keep_size=keep_size, lazy=lazy, **kwargs\n        )\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))\n        self.dtype = ensure_tuple_rep(dtype, len(self.keys))\n\n    @LazyTransform.lazy.setter  # type: ignore\n    def lazy(self, val: bool):\n        self.rand_zoom.lazy = val\n        self._lazy = val\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandZoomd:\n        super().set_random_state(seed, state)\n        self.rand_zoom.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n            lazy: a flag to indicate whether this transform should execute lazily or not\n                during this call. Setting this to False or True overrides the ``lazy`` flag set\n                during initialization for this call. 
Defaults to None.\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())\n            return out\n\n        self.randomize(None)\n\n        # all the keys share the same random zoom factor\n        self.rand_zoom.randomize(d[first_key])\n        lazy_ = self.lazy if lazy is None else lazy\n\n        for key, mode, padding_mode, align_corners, dtype in self.key_iterator(\n            d, self.mode, self.padding_mode, self.align_corners, self.dtype\n        ):\n            if self._do_transform:\n                d[key] = self.rand_zoom(\n                    d[key],\n                    mode=mode,\n                    padding_mode=padding_mode,\n                    align_corners=align_corners,\n                    dtype=dtype,\n                    randomize=False,\n                    lazy=lazy_,\n                )\n            else:\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)\n            self.push_transform(d[key], replace=True, lazy=lazy_)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            xform = self.pop_transform(d[key])\n            if xform[TraceKeys.DO_TRANSFORM]:\n                d[key].applied_operations.append(xform[TraceKeys.EXTRA_INFO])  # type: ignore\n                d[key] = self.rand_zoom.inverse(d[key])\n        return d\n\n\nclass GridDistortiond(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.GridDistortion`.\n    \"\"\"\n\n    backend = GridDistortion.backend\n\n    def __init__(\n        self,\n        keys: 
KeysCollection,\n        num_cells: tuple[int] | int,\n        distort_steps: list[tuple],\n        mode: str = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.BORDER,\n        device: torch.device | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            num_cells: number of grid cells on each dimension.\n            distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the\n                corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`.\n                Each value in the tuple represents the distort step of the related cell.\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. 
Defaults to ``\"border\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            device: device on which the tensor will be allocated.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.grid_distortion = GridDistortion(num_cells=num_cells, distort_steps=distort_steps, device=device)\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. 
The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):\n            d[key] = self.grid_distortion(d[key], mode=mode, padding_mode=padding_mode)\n        return d\n\n\nclass RandGridDistortiond(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.RandGridDistortion`.\n    \"\"\"\n\n    backend = RandGridDistortion.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        num_cells: tuple[int] | int = 5,\n        prob: float = 0.1,\n        distort_limit: tuple[float, float] | float = (-0.03, 0.03),\n        mode: str = GridSampleMode.BILINEAR,\n        padding_mode: str = GridSamplePadMode.BORDER,\n        device: torch.device | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            num_cells: number of grid cells on each dimension.\n            prob: probability of returning a randomized grid distortion transform. Defaults to 0.1.\n            distort_limit: range to randomly distort.\n                If single number, distort_limit is picked from (-distort_limit, distort_limit).\n                Defaults to (-0.03, 0.03).\n            mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n                Interpolation mode to calculate output values. 
Defaults to ``\"bilinear\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n                and the value represents the order of the spline interpolation.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n                Padding mode for outside grid values. Defaults to ``\"border\"``.\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n                When `mode` is an integer, using numpy/cupy backends, this argument accepts\n                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n                It also can be a sequence, each element corresponds to a key in ``keys``.\n            device: device on which the tensor will be allocated.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.rand_grid_distortion = RandGridDistortion(\n            num_cells=num_cells, prob=1.0, distort_limit=distort_limit, device=device\n        )\n        self.mode = ensure_tuple_rep(mode, len(self.keys))\n        self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandGridDistortiond:\n        super().set_random_state(seed, state)\n        
self.rand_grid_distortion.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        self.randomize(None)\n        if not self._do_transform:\n            out: dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())\n            return out\n\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            out = convert_to_tensor(d, track_meta=get_track_meta())\n            return out\n        if isinstance(d[first_key], MetaTensor) and d[first_key].pending_operations:  # type: ignore\n            warnings.warn(f\"data['{first_key}'] has pending operations, transform may return incorrect results.\")\n        self.rand_grid_distortion.randomize(d[first_key].shape[1:])\n\n        for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):\n            d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False)\n        return d\n\n\nclass GridSplitd(MapTransform, MultiSampleTrait):\n    \"\"\"\n    Split the image into patches based on the provided grid in 2D.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        grid: a tuple define the shape of the grid upon which the image is split. 
Defaults to (2, 2)\n        size: a tuple or an integer that defines the output patch sizes,\n            or a dictionary that defines it separately for each key, like {\"image\": 3, \"mask\": (2, 2)}.\n            If it's an integer, the value will be repeated for each dimension.\n            The default is None, where the patch size will be inferred from the grid shape.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    Note: This transform currently supports only images with two spatial dimensions.\n    \"\"\"\n\n    backend = GridSplit.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        grid: tuple[int, int] = (2, 2),\n        size: int | tuple[int, int] | dict[Hashable, int | tuple[int, int] | None] | None = None,\n        allow_missing_keys: bool = False,\n    ):\n        super().__init__(keys, allow_missing_keys)\n        self.grid = grid\n        self.size = size if isinstance(size, dict) else {key: size for key in self.keys}\n        self.splitter = GridSplit(grid=grid)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> list[dict[Hashable, NdarrayOrTensor]]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. 
The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n\n        Returns:\n            a list of dictionaries, one per grid cell. Each is a copy of the input dictionary in which the values\n            of ``keys`` are replaced by the corresponding patch.\n        \"\"\"\n        d = dict(data)\n        n_outputs = np.prod(self.grid)\n        output: list[dict[Hashable, NdarrayOrTensor]] = [dict(d) for _ in range(n_outputs)]\n        for key in self.key_iterator(d):\n            result = self.splitter(d[key], self.size[key])\n            for i in range(n_outputs):\n                output[i][key] = result[i]\n        return output\n\n\nclass GridPatchd(MapTransform, MultiSampleTrait):\n    \"\"\"\n    Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps.\n    It can sort the patches and return all or a subset of them.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        patch_size: size of patches to generate slices for, 0 or None selects whole dimension\n        offset: starting position in the array, default is 0 for each dimension.\n            np.random.randint(0, patch_size, 2) creates random start between 0 and `patch_size` for a 2D image.\n        num_patches: number of patches (or maximum number of patches) to return.\n            If the requested number of patches is greater than the number of available patches,\n            padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.\n            When `threshold` is set, this value is treated as the maximum number of patches.\n            Defaults to None, which does not limit number of the patches.\n        overlap: amount of overlap between patches in each dimension. 
Defaults to 0.0.\n        sort_fn: when `num_patches` is provided, it determines whether to keep patches with highest values (`\"max\"`),\n            lowest values (`\"min\"`), or in their default order (`None`). Defaults to None.\n        threshold: a value to keep only the patches whose sum of intensities is less than the threshold.\n            Defaults to no filtering.\n        pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.\n            Available modes: (Numpy) {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            (PyTorch) {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function.\n            Defaults to `None`, which means no padding will be applied.\n            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            requires pytorch >= 1.10 for best compatibility.\n        allow_missing_keys: don't raise exception if key is missing.\n        pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    Returns:\n        dictionary, contains all the original key/value pairs with the values for `keys`\n            replaced by the patches, a MetaTensor with following metadata:\n\n            - `PatchKeys.LOCATION`: the starting location of the patch in the image,\n            - `PatchKeys.COUNT`: total number of patches in the image,\n            - \"spatial_shape\": spatial size of the extracted patch, and\n            - \"offset\": the amount of offset for the patches in the image (starting position of the first patch)\n    \"\"\"\n\n    backend = 
GridPatch.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        patch_size: Sequence[int],\n        offset: Sequence[int] | None = None,\n        num_patches: int | None = None,\n        overlap: float = 0.0,\n        sort_fn: str | None = None,\n        threshold: float | None = None,\n        pad_mode: str | None = None,\n        allow_missing_keys: bool = False,\n        **pad_kwargs,\n    ):\n        super().__init__(keys, allow_missing_keys)\n        self.patcher = GridPatch(\n            patch_size=patch_size,\n            offset=offset,\n            num_patches=num_patches,\n            overlap=overlap,\n            sort_fn=sort_fn,\n            threshold=threshold,\n            pad_mode=pad_mode,\n            **pad_kwargs,\n        )\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.patcher(d[key])\n        return d\n\n\nclass RandGridPatchd(RandomizableTransform, MapTransform, MultiSampleTrait):\n    \"\"\"\n    Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps,\n    and with random offset for the minimal corner of the image, (0,0) for 2D and (0,0,0) for 3D.\n    It can sort the patches and return all or a subset of them.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n        patch_size: size of patches to generate slices for, 0 or None selects whole dimension\n        
min_offset: the minimum range of starting position to be selected randomly. Defaults to 0.\n        max_offset: the maximum range of starting position to be selected randomly.\n            Defaults to image size modulo patch size.\n        num_patches: number of patches (or maximum number of patches) to return.\n            If the requested number of patches is greater than the number of available patches,\n            padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.\n            When `threshold` is set, this value is treated as the maximum number of patches.\n            Defaults to None, which does not limit number of the patches.\n        overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).\n            If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.\n        sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`\"max\"`),\n            lowest values (`\"min\"`), in random (\"random\"), or in their default order (`None`). 
Defaults to None.\n        threshold: a value to keep only the patches whose sum of intensities is less than the threshold.\n            Defaults to no filtering.\n        pad_mode: the mode for padding the input image by `patch_size` to include patches that cross boundaries.\n            Available modes: (Numpy) {``\"constant\"``, ``\"edge\"``, ``\"linear_ramp\"``, ``\"maximum\"``,\n            ``\"mean\"``, ``\"median\"``, ``\"minimum\"``, ``\"reflect\"``, ``\"symmetric\"``, ``\"wrap\"``, ``\"empty\"``}\n            (PyTorch) {``\"constant\"``, ``\"reflect\"``, ``\"replicate\"``, ``\"circular\"``}.\n            One of the listed string values or a user supplied function.\n            Defaults to `None`, which means no padding will be applied.\n            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html\n            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n            requires pytorch >= 1.10 for best compatibility.\n        allow_missing_keys: don't raise exception if key is missing.\n        pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.\n            note that `np.pad` treats channel dimension as the first dimension.\n\n    Returns:\n        dictionary, contains all the original key/value pairs with the values for `keys`\n            replaced by the patches, a MetaTensor with following metadata:\n\n            - `PatchKeys.LOCATION`: the starting location of the patch in the image,\n            - `PatchKeys.COUNT`: total number of patches in the image,\n            - \"spatial_shape\": spatial size of the extracted patch, and\n            - \"offset\": the amount of offset for the patches in the image (starting position of the first patch)\n\n    \"\"\"\n\n    backend = RandGridPatch.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        patch_size: Sequence[int],\n        min_offset: Sequence[int] | int | None = None,\n        max_offset: Sequence[int] | 
int | None = None,\n        num_patches: int | None = None,\n        overlap: float = 0.0,\n        sort_fn: str | None = None,\n        threshold: float | None = None,\n        pad_mode: str | None = None,\n        allow_missing_keys: bool = False,\n        **pad_kwargs,\n    ):\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        self.patcher = RandGridPatch(\n            patch_size=patch_size,\n            min_offset=min_offset,\n            max_offset=max_offset,\n            num_patches=num_patches,\n            overlap=overlap,\n            sort_fn=sort_fn,\n            threshold=threshold,\n            pad_mode=pad_mode,\n            **pad_kwargs,\n        )\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandGridPatchd:\n        super().set_random_state(seed, state)\n        self.patcher.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be processed. 
The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n\n        Returns:\n            a dictionary containing the transformed data, as well as any other data present in the dictionary\n        \"\"\"\n        d = dict(data)\n        # All the keys share the same random offset\n        for key in self.key_iterator(d):\n            self.patcher.randomize(d[key])\n            break\n        for key in self.key_iterator(d):\n            d[key] = self.patcher(d[key], randomize=False)\n        return d\n\n\nclass RandSimulateLowResolutiond(RandomizableTransform, MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.RandSimulateLowResolution`.\n    Random simulation of low resolution corresponding to nnU-Net's SimulateLowResolutionTransform\n    (https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23)\n    First, the array/tensor is resampled at lower resolution as determined by the zoom_factor which is uniformly sampled\n    from the `zoom_range`. 
Then, the array/tensor is resampled at the original resolution.\n    \"\"\"\n\n    backend = RandAffine.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        prob: float = 0.1,\n        downsample_mode: InterpolateMode | str = InterpolateMode.NEAREST,\n        upsample_mode: InterpolateMode | str = InterpolateMode.TRILINEAR,\n        zoom_range=(0.5, 1.0),\n        align_corners=False,\n        allow_missing_keys: bool = False,\n        device: torch.device | None = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            prob: probability of performing this augmentation\n            downsample_mode: interpolation mode for downsampling operation\n            upsample_mode: interpolation mode for upsampling operation\n            zoom_range: range from which the random zoom factor for the downsampling and upsampling operation is\n            sampled. It determines the shape of the downsampled tensor.\n            align_corners: This only has an effect when downsample_mode or upsample_mode  is 'linear', 'bilinear',\n                'bicubic' or 'trilinear'. 
Default: False\n                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n            allow_missing_keys: don't raise exception if key is missing.\n            device: device on which the tensor will be allocated.\n\n        See also:\n            - :py:class:`monai.transforms.compose.MapTransform`\n\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n\n        self.downsample_mode = downsample_mode\n        self.upsample_mode = upsample_mode\n        self.zoom_range = zoom_range\n        self.align_corners = align_corners\n        self.device = device\n\n        self.sim_lowres_tfm = RandSimulateLowResolution(\n            prob=1.0,  # probability is handled by dictionary class\n            downsample_mode=self.downsample_mode,\n            upsample_mode=self.upsample_mode,\n            zoom_range=self.zoom_range,\n            align_corners=self.align_corners,\n            device=self.device,\n        )\n\n    def set_random_state(\n        self, seed: int | None = None, state: np.random.RandomState | None = None\n    ) -> RandSimulateLowResolutiond:\n        super().set_random_state(seed, state)\n        self.sim_lowres_tfm.set_random_state(seed, state)\n        return self\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        \"\"\"\n        Args:\n            data: a dictionary containing the tensor-like data to be transformed. 
The ``keys`` specified\n                in this dictionary must be tensor like arrays that are channel first and have at most\n                three spatial dimensions\n        \"\"\"\n        d = dict(data)\n        first_key: Hashable = self.first_key(d)\n        if first_key == ():\n            out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())\n            return out\n\n        self.randomize(None)\n\n        for key in self.key_iterator(d):\n            # do the transform\n            if self._do_transform:\n                d[key] = self.sim_lowres_tfm(d[key])  # type: ignore\n            else:\n                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)\n        return d\n\n\nclass ConvertBoxToPointsd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`.\n    \"\"\"\n\n    backend = ConvertBoxToPoints.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        point_key=\"points\",\n        mode: str | BoxMode | type[BoxMode] | None = StandardMode,\n        allow_missing_keys: bool = False,\n    ):\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            point_key: key to store the point data.\n            mode: the mode of the input boxes. 
Defaults to StandardMode.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.point_key = point_key\n        self.converter = ConvertBoxToPoints(mode=mode)\n\n    def __call__(self, data):\n        d = dict(data)\n        for key in self.key_iterator(d):\n            data[self.point_key] = self.converter(d[key])\n        return data\n\n\nclass ConvertPointsToBoxesd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertPointsToBoxes`.\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, box_key=\"box\", allow_missing_keys: bool = False):\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n            box_key: key to store the box data.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.box_key = box_key\n        self.converter = ConvertPointsToBoxes()\n\n    def __call__(self, data):\n        d = dict(data)\n        for key in self.key_iterator(d):\n            data[self.box_key] = self.converter(d[key])\n        return data\n\n\nSpatialResampleD = SpatialResampleDict = SpatialResampled\nResampleToMatchD = ResampleToMatchDict = ResampleToMatchd\nSpacingD = SpacingDict = Spacingd\nOrientationD = OrientationDict = Orientationd\nRotate90D = Rotate90Dict = Rotate90d\nRandRotate90D = RandRotate90Dict = RandRotate90d\nResizeD = ResizeDict = Resized\nAffineD = AffineDict = Affined\nRandAffineD = RandAffineDict = RandAffined\nRand2DElasticD = Rand2DElasticDict = Rand2DElasticd\nRand3DElasticD = Rand3DElasticDict = Rand3DElasticd\nFlipD = FlipDict = Flipd\nRandFlipD = RandFlipDict = RandFlipd\nGridDistortionD = GridDistortionDict = GridDistortiond\nRandGridDistortionD = RandGridDistortionDict = RandGridDistortiond\nRandAxisFlipD = RandAxisFlipDict = RandAxisFlipd\nRotateD = 
RotateDict = Rotated\nRandRotateD = RandRotateDict = RandRotated\nZoomD = ZoomDict = Zoomd\nRandZoomD = RandZoomDict = RandZoomd\nGridSplitD = GridSplitDict = GridSplitd\nGridPatchD = GridPatchDict = GridPatchd\nRandGridPatchD = RandGridPatchDict = RandGridPatchd\nRandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond\nConvertBoxToPointsD = ConvertBoxToPointsDict = ConvertBoxToPointsd\nConvertPointsToBoxesD = ConvertPointsToBoxesDict = ConvertPointsToBoxesd\n"
  },
  {
    "path": "monai/transforms/spatial/functional.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of \"functional\" transforms for spatial operations.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport math\nimport warnings\nfrom enum import Enum\n\nimport numpy as np\nimport torch\n\nimport monai\nfrom monai.config import USE_COMPILED\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.box_utils import get_boxmode\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd\nfrom monai.networks.layers import AffineTransform\nfrom monai.transforms.croppad.array import ResizeWithPadOrCrop\nfrom monai.transforms.intensity.array import GaussianSmooth\nfrom monai.transforms.inverse import TraceableTransform\nfrom monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine\nfrom monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack\nfrom monai.utils import (\n    LazyAttr,\n    TraceKeys,\n    convert_to_dst_type,\n    convert_to_numpy,\n    convert_to_tensor,\n    ensure_tuple,\n    ensure_tuple_rep,\n    fall_back_tuple,\n    optional_import,\n)\n\nnib, has_nib = optional_import(\"nibabel\")\ncupy, _ = optional_import(\"cupy\")\ncupy_ndi, _ = optional_import(\"cupyx.scipy.ndimage\")\nnp_ndi, _ = optional_import(\"scipy.ndimage\")\n\n__all__ = 
[\"spatial_resample\", \"orientation\", \"flip\", \"resize\", \"rotate\", \"zoom\", \"rotate90\", \"affine_func\"]\n\n\ndef _maybe_new_metatensor(img, dtype=None, device=None):\n    \"\"\"create a metatensor with fresh metadata if track_meta is True otherwise convert img into a torch tensor\"\"\"\n    return convert_to_tensor(\n        img.as_tensor() if isinstance(img, MetaTensor) else img,\n        dtype=dtype,\n        device=device,\n        track_meta=get_track_meta(),\n        wrap_sequence=True,\n    )\n\n\ndef spatial_resample(\n    img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, lazy, transform_info\n) -> torch.Tensor:\n    \"\"\"\n    Functional implementation of resampling the input image to the specified ``dst_affine`` matrix and ``spatial_size``.\n    This function operates eagerly or lazily according to\n    ``lazy`` (default ``False``).\n\n    Args:\n        img: data to be resampled, assuming `img` is channel-first.\n        dst_affine: target affine matrix, if None, use the input affine matrix, effectively no resampling.\n        spatial_size: output spatial size, if the component is ``-1``, use the corresponding input spatial size.\n        mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n            Interpolation mode to calculate output values.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n            and the value represents the order of the spline interpolation.\n            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n        padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n            Padding mode for outside grid values.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            When `mode` is 
an integer, using numpy/cupy backends, this argument accepts\n            {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n        align_corners: Geometrically, we consider the pixels of the input as squares rather than points.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        dtype_pt: data `dtype` for resampling computation.\n        lazy: a flag that indicates whether the operation should be performed lazily or not\n        transform_info: a dictionary with the relevant information pertaining to an applied transform.\n    \"\"\"\n    original_spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n    src_affine: torch.Tensor = img.peek_pending_affine() if isinstance(img, MetaTensor) else torch.eye(4)\n    img = convert_to_tensor(data=img, track_meta=get_track_meta())\n    # ensure spatial rank is <= 3\n    spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3)\n    if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None:\n        spatial_rank = min(len(ensure_tuple(spatial_size)), 3)  # infer spatial rank based on spatial_size\n    src_affine = to_affine_nd(spatial_rank, src_affine).to(torch.float64)\n    dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine\n    dst_affine = convert_to_dst_type(dst_affine, src_affine)[0]\n    if not isinstance(dst_affine, torch.Tensor):\n        raise ValueError(f\"dst_affine should be a torch.Tensor, got {type(dst_affine)}\")\n\n    in_spatial_size = torch.tensor(original_spatial_shape[:spatial_rank])\n    if isinstance(spatial_size, int) and (spatial_size == -1):  # using the input spatial size\n        spatial_size = in_spatial_size\n    elif spatial_size is None and spatial_rank > 1:  
# auto spatial size\n        spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine, dst_affine)  # type: ignore\n    spatial_size = torch.tensor(\n        fall_back_tuple(ensure_tuple(spatial_size)[:spatial_rank], in_spatial_size, lambda x: x >= 0)\n    )\n    extra_info = {\n        \"dtype\": str(dtype_pt)[6:],  # remove \"torch\": torch.float32 -> float32\n        \"mode\": mode.value if isinstance(mode, Enum) else mode,\n        \"padding_mode\": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode,\n        \"align_corners\": align_corners if align_corners is not None else TraceKeys.NONE,\n        \"src_affine\": src_affine,\n    }\n    try:\n        _s = convert_to_numpy(src_affine)\n        _d = convert_to_numpy(dst_affine)\n        xform = np.eye(spatial_rank + 1) if spatial_rank < 2 else np.linalg.solve(_s, _d)\n    except (np.linalg.LinAlgError, RuntimeError) as e:\n        raise ValueError(f\"src affine is not invertible {_s}, {_d}.\") from e\n    xform = convert_to_tensor(to_affine_nd(spatial_rank, xform)).to(device=img.device, dtype=torch.float64)\n    affine_unchanged = (\n        allclose(src_affine, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size)\n    ) or (allclose(xform, np.eye(len(xform)), atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size))\n    meta_info = TraceableTransform.track_transform_meta(\n        img,\n        sp_size=spatial_size,\n        affine=None if affine_unchanged and not lazy else xform,\n        extra_info=extra_info,\n        orig_size=original_spatial_shape,\n        transform_info=transform_info,\n        lazy=lazy,\n    )\n    if lazy:\n        out = _maybe_new_metatensor(img)\n        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info  # type: ignore\n    if affine_unchanged:\n        # no significant change or lazy change, return original image\n        out = _maybe_new_metatensor(img, dtype=torch.float32)\n        return 
out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out  # type: ignore\n    # drop current meta first\n    img = img.as_tensor() if isinstance(img, MetaTensor) else img\n    im_size = list(img.shape)\n    chns, in_sp_size, additional_dims = im_size[0], im_size[1 : spatial_rank + 1], im_size[spatial_rank + 1 :]\n\n    if additional_dims:\n        xform_shape = [-1] + in_sp_size\n        img = img.reshape(xform_shape)\n    img = img.to(dtype_pt)\n    if isinstance(mode, int) or USE_COMPILED:\n        dst_xform = create_translate(spatial_rank, [float(d - 1) / 2 for d in spatial_size])\n        xform = xform @ convert_to_dst_type(dst_xform, xform)[0]\n        affine_xform = monai.transforms.Affine(\n            affine=xform,\n            spatial_size=spatial_size,\n            normalized=True,\n            image_only=True,\n            dtype=dtype_pt,\n            align_corners=align_corners,\n        )\n        with affine_xform.trace_transform(False):\n            img = affine_xform(img, mode=mode, padding_mode=padding_mode)\n    else:\n        _, _m, _p, _ = resolves_modes(mode, padding_mode)\n        affine_xform = AffineTransform(  # type: ignore\n            normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True\n        )\n        img = affine_xform(img.unsqueeze(0), theta=xform.to(img), spatial_size=spatial_size).squeeze(0)  # type: ignore\n    if additional_dims:\n        full_shape = (chns, *spatial_size, *additional_dims)\n        img = img.reshape(full_shape)\n    out = _maybe_new_metatensor(img, dtype=torch.float32)\n    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out  # type: ignore\n\n\ndef orientation(img, original_affine, spatial_ornt, lazy, transform_info) -> torch.Tensor:\n    \"\"\"\n    Functional implementation of changing the input image's orientation into the specified based on `spatial_ornt`.\n    This function operates eagerly or lazily according to\n    
``lazy`` (default ``False``).\n\n    Args:\n        img: data to be changed, assuming `img` is channel-first.\n        original_affine: original affine of the input image.\n        spatial_ornt: orientations of the spatial axes,\n            see also https://nipy.org/nibabel/reference/nibabel.orientations.html\n        lazy: a flag that indicates whether the operation should be performed lazily or not\n        transform_info: a dictionary with the relevant information pertaining to an applied transform.\n    \"\"\"\n    spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n    xform = nib.orientations.inv_ornt_aff(spatial_ornt, spatial_shape)\n    img = convert_to_tensor(img, track_meta=get_track_meta())\n\n    spatial_ornt[:, 0] += 1  # skip channel dim\n    spatial_ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt])\n    axes = [ax for ax, flip in enumerate(spatial_ornt[:, 1]) if flip == -1]\n    full_transpose = np.arange(len(spatial_shape) + 1)  # channel-first array\n    full_transpose[: len(spatial_ornt)] = np.argsort(spatial_ornt[:, 0])\n    extra_info = {\"original_affine\": original_affine}\n\n    shape_np = convert_to_numpy(spatial_shape, wrap_sequence=True)\n    shape_np = shape_np[[i - 1 for i in full_transpose if i > 0]]\n    meta_info = TraceableTransform.track_transform_meta(\n        img,\n        sp_size=shape_np,\n        affine=xform,\n        extra_info=extra_info,\n        orig_size=spatial_shape,\n        transform_info=transform_info,\n        lazy=lazy,\n    )\n    out = _maybe_new_metatensor(img)\n    if lazy:\n        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info  # type: ignore\n    if axes:\n        out = torch.flip(out, dims=axes)\n    if not np.all(full_transpose == np.arange(len(out.shape))):\n        out = out.permute(full_transpose.tolist())\n    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out  # type: ignore\n\n\ndef 
flip(img, sp_axes, lazy, transform_info):\n    \"\"\"\n    Functional implementation of flip.\n    This function operates eagerly or lazily according to\n    ``lazy`` (default ``False``).\n\n    Args:\n        img: data to be changed, assuming `img` is channel-first.\n        sp_axes: spatial axes along which to flip over.\n            If None, will flip over all of the axes of the input array.\n            If axis is negative it counts from the last to the first axis.\n            If axis is a tuple of ints, flipping is performed on all of the axes\n            specified in the tuple.\n        lazy: a flag that indicates whether the operation should be performed lazily or not\n        transform_info: a dictionary with the relevant information pertaining to an applied transform.\n    \"\"\"\n    sp_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n    sp_size = convert_to_numpy(sp_size, wrap_sequence=True).tolist()\n    extra_info = {\"axes\": sp_axes}  # track the spatial axes\n    axes = monai.transforms.utils.map_spatial_axes(img.ndim, sp_axes)  # use the axes with channel dim\n    rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else torch.tensor(3.0, dtype=torch.double)\n    # axes include the channel dim\n    xform = torch.eye(int(rank) + 1, dtype=torch.double)\n    for axis in axes:\n        sp = axis - 1\n        xform[sp, sp], xform[sp, -1] = xform[sp, sp] * -1, sp_size[sp] - 1\n    meta_info = TraceableTransform.track_transform_meta(\n        img, sp_size=sp_size, affine=xform, extra_info=extra_info, transform_info=transform_info, lazy=lazy\n    )\n    out = _maybe_new_metatensor(img)\n    if lazy:\n        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info\n    out = torch.flip(out, axes)\n    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out\n\n\ndef resize(\n    img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, 
anti_aliasing_sigma, lazy, transform_info\n):\n    \"\"\"\n    Functional implementation of resize.\n    This function operates eagerly or lazily according to\n    ``lazy`` (default ``False``).\n\n    Args:\n        img: data to be changed, assuming `img` is channel-first.\n        out_size: expected shape of spatial dimensions after resize operation.\n        mode: {``\"nearest\"``, ``\"nearest-exact\"``, ``\"linear\"``,\n            ``\"bilinear\"``, ``\"bicubic\"``, ``\"trilinear\"``, ``\"area\"``}\n            The interpolation mode.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n        align_corners: This only has an effect when mode is\n            'linear', 'bilinear', 'bicubic' or 'trilinear'.\n        dtype: data type for resampling computation. If None, use the data type of input data.\n        input_ndim: number of spatial dimensions.\n        anti_aliasing: whether to apply a Gaussian filter to smooth the image prior\n            to downsampling. It is crucial to filter when downsampling\n            the image to avoid aliasing artifacts. 
See also ``skimage.transform.resize``\n        anti_aliasing_sigma: {float, tuple of floats}, optional\n            Standard deviation for Gaussian filtering used when anti-aliasing.\n        lazy: a flag that indicates whether the operation should be performed lazily or not\n        transform_info: a dictionary with the relevant information pertaining to an applied transform.\n    \"\"\"\n    img = convert_to_tensor(img, track_meta=get_track_meta())\n    orig_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n    extra_info = {\n        \"mode\": mode,\n        \"align_corners\": align_corners if align_corners is not None else TraceKeys.NONE,\n        \"dtype\": str(dtype)[6:],  # dtype as string; remove \"torch\": torch.float32 -> float32\n        \"new_dim\": len(orig_size) - input_ndim,\n    }\n    meta_info = TraceableTransform.track_transform_meta(\n        img,\n        sp_size=out_size,\n        affine=scale_affine(orig_size, out_size),\n        extra_info=extra_info,\n        orig_size=orig_size,\n        transform_info=transform_info,\n        lazy=lazy,\n    )\n    if lazy:\n        if anti_aliasing and lazy:\n            warnings.warn(\"anti-aliasing is not compatible with lazy evaluation.\")\n        out = _maybe_new_metatensor(img)\n        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info\n    if tuple(convert_to_numpy(orig_size)) == out_size:\n        out = _maybe_new_metatensor(img, dtype=torch.float32)\n        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out\n    out = _maybe_new_metatensor(img)\n    img_ = convert_to_tensor(out, dtype=dtype, track_meta=False)  # convert to a regular tensor\n    if anti_aliasing and any(x < y for x, y in zip(out_size, img_.shape[1:])):\n        factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(out_size))\n        if anti_aliasing_sigma is None:\n            # if sigma is not given, use the default sigma 
in skimage.transform.resize\n            anti_aliasing_sigma = torch.maximum(torch.zeros(factors.shape), (factors - 1) / 2).tolist()\n        else:\n            # if sigma is given, use the given value for downsampling axis\n            anti_aliasing_sigma = list(ensure_tuple_rep(anti_aliasing_sigma, len(out_size)))\n            for axis in range(len(out_size)):\n                anti_aliasing_sigma[axis] = anti_aliasing_sigma[axis] * int(factors[axis] > 1)\n        anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma)\n        img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False)\n    _, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_.shape) - 1)\n    resized = torch.nn.functional.interpolate(\n        input=img_.unsqueeze(0), size=out_size, mode=_m, align_corners=align_corners\n    )\n    out, *_ = convert_to_dst_type(resized.squeeze(0), out, dtype=torch.float32)\n    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out\n\n\ndef rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info):\n    \"\"\"\n    Functional implementation of rotate.\n    This function operates eagerly or lazily according to\n    ``lazy`` (default ``False``).\n\n    Args:\n        img: data to be changed, assuming `img` is channel-first.\n        angle: Rotation angle(s) in radians. 
should a float for 2D, three floats for 3D.\n        output_shape: output shape of the rotated data.\n        mode: {``\"bilinear\"``, ``\"nearest\"``}\n            Interpolation mode to calculate output values.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n            Padding mode for outside grid values.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        dtype: data type for resampling computation.\n            If None, use the data type of input data. To be compatible with other modules,\n            the output data type is always ``float32``.\n        lazy: a flag that indicates whether the operation should be performed lazily or not\n        transform_info: a dictionary with the relevant information pertaining to an applied transform.\n\n    \"\"\"\n\n    im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n    input_ndim = len(im_shape)\n    if input_ndim not in (2, 3):\n        raise ValueError(f\"Unsupported image dimension: {input_ndim}, available options are [2, 3].\")\n    _angle = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3)\n    transform = create_rotate(input_ndim, _angle)\n    if output_shape is None:\n        corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing=\"ij\")).reshape((len(im_shape), -1))\n        corners = transform[:-1, :-1] @ corners  # type: ignore\n        output_shape = np.asarray(np.ptp(corners, axis=1) + 0.5, dtype=int)\n    else:\n        output_shape = np.asarray(output_shape, dtype=int)\n    shift = create_translate(input_ndim, ((np.array(im_shape) - 1) / 2).tolist())\n    shift_1 = create_translate(input_ndim, (-(np.asarray(output_shape, dtype=int) - 1) / 
2).tolist())\n    transform = shift @ transform @ shift_1\n    extra_info = {\n        \"rot_mat\": transform,\n        \"mode\": mode,\n        \"padding_mode\": padding_mode,\n        \"align_corners\": align_corners if align_corners is not None else TraceKeys.NONE,\n        \"dtype\": str(dtype)[6:],  # dtype as string; remove \"torch\": torch.float32 -> float32\n    }\n    meta_info = TraceableTransform.track_transform_meta(\n        img,\n        sp_size=output_shape,\n        affine=transform,\n        extra_info=extra_info,\n        orig_size=im_shape,\n        transform_info=transform_info,\n        lazy=lazy,\n    )\n    out = _maybe_new_metatensor(img)\n    if lazy:\n        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info\n    _, _m, _p, _ = resolves_modes(mode, padding_mode)\n    xform = AffineTransform(\n        normalized=False, mode=_m, padding_mode=_p, align_corners=align_corners, reverse_indexing=True\n    )\n    img_t = out.to(dtype)\n    transform_t, *_ = convert_to_dst_type(transform, img_t)\n    output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=tuple(int(i) for i in output_shape))\n    output = output.float().squeeze(0)\n    out, *_ = convert_to_dst_type(output, dst=out, dtype=torch.float32)\n    return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out\n\n\ndef zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info, **kwargs):\n    \"\"\"\n    Functional implementation of zoom.\n    This function operates eagerly or lazily according to\n    ``lazy`` (default ``False``).\n\n    Args:\n        img: data to be changed, assuming `img` is channel-first.\n        scale_factor: The zoom factor along the spatial axes.\n            If a float, zoom is the same for each spatial axis.\n            If a sequence, zoom should contain one value for each spatial axis.\n        keep_size: Whether keep original size (padding/slicing if 
needed).\n        mode: {``\"bilinear\"``, ``\"nearest\"``}\n            Interpolation mode to calculate output values.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n            Padding mode for outside grid values.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n        dtype: data type for resampling computation.\n            If None, use the data type of input data. To be compatible with other modules,\n            the output data type is always ``float32``.\n        lazy: a flag that indicates whether the operation should be performed lazily or not\n        transform_info: a dictionary with the relevant information pertaining to an applied transform.\n\n    \"\"\"\n    im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n    output_size = [int(math.floor(float(i) * z)) for i, z in zip(im_shape, scale_factor)]\n    xform = scale_affine(im_shape, output_size)\n    extra_info = {\n        \"mode\": mode,\n        \"align_corners\": align_corners if align_corners is not None else TraceKeys.NONE,\n        \"dtype\": str(dtype)[6:],  # dtype as string; remove \"torch\": torch.float32 -> float32\n        \"do_padcrop\": False,\n        \"padcrop\": {},\n    }\n    if keep_size:\n        do_pad_crop = not np.allclose(output_size, im_shape)\n        if do_pad_crop and lazy:  # update for lazy evaluation\n            _pad_crop = ResizeWithPadOrCrop(spatial_size=im_shape, mode=padding_mode, **kwargs)\n            _pad_crop.lazy = True\n            _tmp_img = MetaTensor([], affine=torch.eye(len(output_size) + 1))\n            _tmp_img.push_pending_operation({LazyAttr.SHAPE: list(output_size), LazyAttr.AFFINE: xform})\n            lazy_cropped = 
_pad_crop(_tmp_img)\n            if isinstance(lazy_cropped, MetaTensor):\n                xform = lazy_cropped.peek_pending_affine()\n                extra_info[\"padcrop\"] = lazy_cropped.pending_operations[-1]\n            extra_info[\"do_padcrop\"] = do_pad_crop\n        output_size = [int(i) for i in im_shape]\n    meta_info = TraceableTransform.track_transform_meta(\n        img,\n        sp_size=output_size,\n        affine=xform,\n        extra_info=extra_info,\n        orig_size=im_shape,\n        transform_info=transform_info,\n        lazy=lazy,\n    )\n    out = _maybe_new_metatensor(img)\n    if lazy:\n        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info\n    img_t = out.to(dtype)\n    _, _m, _, _ = resolves_modes(mode, torch_interpolate_spatial_nd=len(img_t.shape) - 1)\n    zoomed: NdarrayOrTensor = torch.nn.functional.interpolate(\n        recompute_scale_factor=True,\n        input=img_t.unsqueeze(0),\n        scale_factor=list(scale_factor),\n        mode=_m,\n        align_corners=align_corners,\n    ).squeeze(0)\n    out, *_ = convert_to_dst_type(zoomed, dst=out, dtype=torch.float32)\n    if isinstance(out, MetaTensor):\n        out = out.copy_meta_from(meta_info)\n    do_pad_crop = not np.allclose(output_size, zoomed.shape[1:])\n    if do_pad_crop:\n        _pad_crop = ResizeWithPadOrCrop(spatial_size=img_t.shape[1:], mode=padding_mode, **kwargs)\n        out = _pad_crop(out)\n    if get_track_meta() and do_pad_crop:\n        padcrop_xform = out.applied_operations.pop()\n        out.applied_operations[-1][\"extra_info\"][\"do_padcrop\"] = True\n        out.applied_operations[-1][\"extra_info\"][\"padcrop\"] = padcrop_xform\n    return out\n\n\ndef rotate90(img, axes, k, lazy, transform_info):\n    \"\"\"\n    Functional implementation of rotate90.\n    This function operates eagerly or lazily according to\n    ``lazy`` (default ``False``).\n\n    Args:\n        img: data to be changed, assuming `img` is 
channel-first.\n        axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.\n            If axis is negative it counts from the last to the first axis.\n        k: number of times to rotate by 90 degrees.\n        lazy: a flag that indicates whether the operation should be performed lazily or not\n        transform_info: a dictionary with the relevant information pertaining to an applied transform.\n    \"\"\"\n    extra_info = {\"axes\": [d - 1 for d in axes], \"k\": k}\n    ori_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n    sp_shape = list(ori_shape)\n    if k in (1, 3):\n        a_0, a_1 = axes[0] - 1, axes[1] - 1\n        sp_shape[a_0], sp_shape[a_1] = ori_shape[a_1], ori_shape[a_0]\n    rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else torch.tensor(3.0, dtype=torch.double)\n    r, sp_r = int(rank), len(ori_shape)\n    xform = to_affine_nd(r, create_translate(sp_r, [-float(d - 1) / 2 for d in sp_shape]))\n    s = -1.0 if int(axes[0]) - int(axes[1]) in (-1, 2) else 1.0\n    if sp_r == 2:\n        rot90 = to_affine_nd(r, create_rotate(sp_r, [s * np.pi / 2]))\n    else:\n        idx = {1, 2, 3} - set(axes)\n        angle: list[float] = [0, 0, 0]\n        angle[idx.pop() - 1] = s * np.pi / 2\n        rot90 = to_affine_nd(r, create_rotate(sp_r, angle))\n    for _ in range(k):\n        xform = rot90 @ xform\n    xform = to_affine_nd(r, create_translate(sp_r, [float(d - 1) / 2 for d in ori_shape])) @ xform\n    meta_info = TraceableTransform.track_transform_meta(\n        img,\n        sp_size=sp_shape,\n        affine=xform,\n        extra_info=extra_info,\n        orig_size=ori_shape,\n        transform_info=transform_info,\n        lazy=lazy,\n    )\n    out = _maybe_new_metatensor(img)\n    if lazy:\n        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info\n    out = torch.rot90(out, k, axes)\n    return out.copy_meta_from(meta_info) if isinstance(out, 
MetaTensor) else out\n\n\ndef affine_func(\n    img, affine, grid, resampler, sp_size, mode, padding_mode, do_resampling, image_only, lazy, transform_info\n):\n    \"\"\"\n    Functional implementation of affine.\n    This function operates eagerly or lazily according to\n    ``lazy`` (default ``False``).\n\n    Args:\n        img: data to be changed, assuming `img` is channel-first.\n        affine: the affine transformation to be applied, it can be a 3x3 or 4x4 matrix. This should be defined\n            for the voxel space spatial centers (``float(size - 1)/2``).\n        grid: used in non-lazy mode to pre-compute the grid to do the resampling.\n        resampler: the resampler function, see also: :py:class:`monai.transforms.Resample`.\n        sp_size: output image spatial size.\n        mode: {``\"bilinear\"``, ``\"nearest\"``} or spline interpolation order 0-5 (integers).\n            Interpolation mode to calculate output values.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used\n            and the value represents the order of the spline interpolation.\n            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n        padding_mode: {``\"zeros\"``, ``\"border\"``, ``\"reflection\"``}\n            Padding mode for outside grid values.\n            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n            When `mode` is an integer, using numpy/cupy backends, this argument accepts\n            {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.\n            See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n        do_resampling: whether to do the resampling, this is a flag for the use case of updating metadata but\n        
    skipping the actual (potentially heavy) resampling operation.\n        image_only: if True return only the image volume, otherwise return (image, affine).\n        lazy: a flag that indicates whether the operation should be performed lazily or not\n        transform_info: a dictionary with the relevant information pertaining to an applied transform.\n\n    \"\"\"\n\n    # resampler should carry the align_corners and type info\n    img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]\n    rank = img.peek_pending_rank() if isinstance(img, MetaTensor) else torch.tensor(3.0, dtype=torch.double)\n    extra_info = {\n        \"affine\": affine,\n        \"mode\": mode,\n        \"padding_mode\": padding_mode,\n        \"do_resampling\": do_resampling,\n        \"align_corners\": resampler.align_corners,\n    }\n    affine = monai.transforms.Affine.compute_w_affine(rank, affine, img_size, sp_size)\n    meta_info = TraceableTransform.track_transform_meta(\n        img,\n        sp_size=sp_size,\n        affine=affine,\n        extra_info=extra_info,\n        orig_size=img_size,\n        transform_info=transform_info,\n        lazy=lazy,\n    )\n    if lazy:\n        out = _maybe_new_metatensor(img)\n        out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info\n        return out if image_only else (out, affine)\n    if do_resampling:\n        out = resampler(img=img, grid=grid, mode=mode, padding_mode=padding_mode)\n        out = _maybe_new_metatensor(out)\n    else:\n        out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)\n    out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out\n    return out if image_only else (out, affine)\n\n\ndef convert_box_to_points(bbox, mode):\n    \"\"\"\n    Converts an axis-aligned bounding box to points.\n\n    Args:\n        mode: The mode specifying how to interpret the bounding box.\n        bbox: Bounding boxes of the 
shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2]\n            for 3D for each box. Return shape will be (N, 4, 2) for 2D or (N, 8, 3) for 3D.\n\n    Returns:\n        sequence of points representing the corners of the bounding box.\n    \"\"\"\n\n    mode = get_boxmode(mode)\n\n    points_list = []\n    for _num in range(bbox.shape[0]):\n        corners = mode.boxes_to_corners(bbox[_num : _num + 1])\n        if len(corners) == 4:\n            points_list.append(\n                concatenate(\n                    [\n                        concatenate([corners[0], corners[1]], axis=1),\n                        concatenate([corners[2], corners[1]], axis=1),\n                        concatenate([corners[2], corners[3]], axis=1),\n                        concatenate([corners[0], corners[3]], axis=1),\n                    ],\n                    axis=0,\n                )\n            )\n        else:\n            points_list.append(\n                concatenate(\n                    [\n                        concatenate([corners[0], corners[1], corners[2]], axis=1),\n                        concatenate([corners[3], corners[1], corners[2]], axis=1),\n                        concatenate([corners[3], corners[4], corners[2]], axis=1),\n                        concatenate([corners[0], corners[4], corners[2]], axis=1),\n                        concatenate([corners[0], corners[1], corners[5]], axis=1),\n                        concatenate([corners[3], corners[1], corners[5]], axis=1),\n                        concatenate([corners[3], corners[4], corners[5]], axis=1),\n                        concatenate([corners[0], corners[4], corners[5]], axis=1),\n                    ],\n                    axis=0,\n                )\n            )\n\n    return stack(points_list, dim=0)\n\n\ndef convert_points_to_box(points):\n    \"\"\"\n    Converts points to an axis-aligned bounding box.\n\n    Args:\n        points: Points representing the corners of 
the bounding box. Shape (N, 8, 3) for the 8 corners of\n            a 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle.\n    \"\"\"\n    from monai.transforms.utils_pytorch_numpy_unification import max, min\n\n    mins = min(points, dim=1)\n    maxs = max(points, dim=1)\n    # Concatenate the min and max values to get the bounding boxes\n    bboxes = concatenate([mins, maxs], axis=1)\n\n    return bboxes\n"
  },
  {
    "path": "monai/transforms/traits.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of generic traits for MONAI transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\n__all__ = [\"LazyTrait\", \"InvertibleTrait\", \"RandomizableTrait\", \"MultiSampleTrait\", \"ThreadUnsafe\", \"ReduceTrait\"]\n\nfrom typing import Any\n\n\nclass LazyTrait:\n    \"\"\"\n    An interface to indicate that the transform has the capability to execute using\n    MONAI's lazy resampling feature. 
In order to do this, the implementing class needs\n    to be able to describe its operation as an affine matrix or grid with accompanying metadata.\n    This interface can be extended from by people adapting transforms to the MONAI framework as\n    well as by implementors of MONAI transforms.\n    \"\"\"\n\n    @property\n    def lazy(self):\n        \"\"\"\n        Get whether lazy evaluation is enabled for this transform instance.\n        Returns:\n            True if the transform is operating in a lazy fashion, False if not.\n        \"\"\"\n        raise NotImplementedError()\n\n    @lazy.setter\n    def lazy(self, enabled: bool):\n        \"\"\"\n        Set whether lazy evaluation is enabled for this transform instance.\n        Args:\n            enabled: True if the transform should operate in a lazy fashion, False if not.\n        \"\"\"\n        raise NotImplementedError()\n\n    @property\n    def requires_current_data(self):\n        \"\"\"\n        Get whether the transform requires the input data to be up to date before the transform executes.\n        Such transforms can still execute lazily by adding pending operations to the output tensors.\n        Returns:\n            True if the transform requires its inputs to be up to date and False if it does not\n        \"\"\"\n\n\nclass InvertibleTrait:\n    \"\"\"\n    An interface to indicate that the transform can be inverted, i.e. undone by performing\n    the inverse of the operation performed during `__call__`.\n    \"\"\"\n\n    def inverse(self, data: Any) -> Any:\n        raise NotImplementedError()\n\n\nclass RandomizableTrait:\n    \"\"\"\n    An interface to indicate that the transform has the capability to perform\n    randomized transforms to the data that it is called upon. 
This interface\n    can be extended from by people adapting transforms to the MONAI framework as well as by\n    implementors of MONAI transforms.\n    \"\"\"\n\n\nclass MultiSampleTrait:\n    \"\"\"\n    An interface to indicate that the transform has the capability to return multiple samples\n    given an input, such as when performing random crops of a sample. This interface can be\n    extended from by people adapting transforms to the MONAI framework as well as by implementors\n    of MONAI transforms.\n    \"\"\"\n\n\nclass ThreadUnsafe:\n    \"\"\"\n    A class to denote that the transform will mutate its member variables,\n    when being applied. Transforms inheriting this class should be used\n    cautiously in a multi-thread context.\n\n    This type is typically used by :py:class:`monai.data.CacheDataset` and\n    its extensions, where the transform cache is built with multiple threads.\n    \"\"\"\n\n\nclass ReduceTrait:\n    \"\"\"\n    An interface to indicate that the transform has the capability to reduce multiple samples\n    into a single sample.\n    This interface can be extended from by people adapting transforms to the MONAI framework as well\n    as by implementors of MONAI transforms.\n    \"\"\"\n"
  },
  {
    "path": "monai/transforms/transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of generic interfaces for MONAI transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Callable, Generator, Hashable, Iterable, Mapping\nfrom typing import Any, TypeVar\n\nimport numpy as np\nimport torch\n\nfrom monai import config, transforms\nfrom monai.config import KeysCollection\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms.traits import LazyTrait, RandomizableTrait, ReduceTrait, ThreadUnsafe\nfrom monai.utils import MAX_SEED, ensure_tuple, first\nfrom monai.utils.enums import TransformBackends\nfrom monai.utils.misc import MONAIEnvVars\n\n__all__ = [\n    \"ThreadUnsafe\",\n    \"apply_transform\",\n    \"Randomizable\",\n    \"LazyTransform\",\n    \"RandomizableTransform\",\n    \"Transform\",\n    \"MapTransform\",\n]\n\nReturnType = TypeVar(\"ReturnType\")\n\n\ndef _apply_transform(\n    transform: Callable[..., ReturnType],\n    data: Any,\n    unpack_parameters: bool = False,\n    lazy: bool | None = False,\n    overrides: dict | None = None,\n    logger_name: bool | str = False,\n) -> ReturnType:\n    \"\"\"\n    Perform a transform 'transform' on 'data', according to the other parameters specified.\n\n    If `data` is a tuple and `unpack_parameters` is True, each parameter of `data` is unpacked\n    as arguments to `transform`. 
Otherwise `data` is considered as single argument to `transform`.\n\n    If 'lazy' is True, this method first checks whether it can execute this method lazily. If it\n    can't, it will ensure that all pending lazy transforms on 'data' are applied before applying\n    this 'transform' to it. If 'lazy' is True, and 'overrides' are provided, those overrides will\n    be applied to the pending operations on 'data'. See ``Compose`` for more details on lazy\n    resampling, which is an experimental feature for 1.2.\n\n    Please note, this function is designed to be called by ``apply_transform``.\n    In general, you should not need to make specific use of it unless you are implementing\n    pipeline execution mechanisms.\n\n    Args:\n        transform: a callable to be used to transform `data`.\n        data: the tensorlike or dictionary of tensorlikes to be executed on\n        unpack_parameters: whether to unpack parameters for `transform`. Defaults to False.\n        lazy: whether to enable lazy evaluation for lazy transforms. If False, transforms will be\n            carried out on a transform by transform basis. If True, all lazy transforms will\n            be executed by accumulating changes and resampling as few times as possible.\n            See the :ref:`Lazy Resampling topic<lazy_resampling>` for more information about lazy resampling.\n        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden\n            when executing a pipeline. Each parameter that is compatible with a given transform is then applied\n            to that transform before it is executed. Note that overrides are currently only applied when\n            :ref:`Lazy Resampling<lazy_resampling>` is enabled for the pipeline or a given transform. If lazy is False\n            they are ignored. 
Currently supported args are:\n            {``\"mode\"``, ``\"padding_mode\"``, ``\"dtype\"``, ``\"align_corners\"``, ``\"resample_mode\"``, ``device``}.\n        logger_name: this optional parameter allows you to specify a logger by name for logging of pipeline execution.\n            Setting this to False disables logging. Setting it to True enables logging to the default loggers.\n            Setting a string overrides the logger name to which logging is performed.\n\n    Returns:\n        ReturnType: The return type of `transform`.\n    \"\"\"\n    from monai.transforms.lazy.functional import apply_pending_transforms_in_order\n\n    data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name)\n\n    if isinstance(data, tuple) and unpack_parameters:\n        return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data)\n\n    return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)\n\n\ndef apply_transform(\n    transform: Callable[..., ReturnType],\n    data: Any,\n    map_items: bool | int = True,\n    unpack_items: bool = False,\n    log_stats: bool | str = False,\n    lazy: bool | None = None,\n    overrides: dict | None = None,\n) -> list[Any] | ReturnType:\n    \"\"\"\n    Transform `data` with `transform`.\n\n    If `data` is a list or tuple and `map_items` is True, each item of `data` will be transformed\n    and this method returns a list of outcomes.\n    Otherwise transform will be applied once with `data` as the argument.\n\n    Args:\n        transform: a callable to be used to transform `data`.\n        data: an object to be transformed.\n        map_items: controls whether to apply a transformation to each item in `data`. 
If `data` is a list or tuple,\n            it can behave as follows:\n            - Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied\n              to the first level of items in `data`.\n            - If an integer is provided, it specifies the maximum level of nesting to which the transformation\n              should be recursively applied. This allows treating multi-sample transforms applied after another\n              multi-sample transform while controlling how deep the mapping goes.\n        unpack_items: whether to unpack parameters using `*`. Defaults to False.\n        log_stats: log errors when they occur in the processing pipeline. By default, this is set to False, which\n            disables the logger for processing pipeline errors. Setting it to None or True will enable logging to the\n            default logger name. Setting it to a string specifies the logger to which errors should be logged.\n        lazy: whether to execute in lazy mode or not. See the :ref:`Lazy Resampling topic<lazy_resampling>` for more\n            information about lazy resampling. Defaults to None.\n        overrides: optional overrides to apply to transform parameters. This parameter is ignored unless transforms\n            are being executed lazily. 
See the :ref:`Lazy Resampling topic<lazy_resampling> for more details and\n            examples of its usage.\n\n    Raises:\n        Exception: When ``transform`` raises an exception.\n\n    Returns:\n        Union[List[ReturnType], ReturnType]: The return type of `transform` or a list thereof.\n    \"\"\"\n    try:\n        map_items_ = int(map_items) if isinstance(map_items, bool) else map_items\n        if isinstance(data, (list, tuple)) and map_items_ > 0 and not isinstance(transform, ReduceTrait):\n            return [\n                apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)\n                for item in data\n            ]\n        return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)\n    except Exception as e:\n        # if in debug mode, don't swallow exception so that the breakpoint\n        # appears where the exception was raised.\n        if MONAIEnvVars.debug():\n            raise\n        if log_stats is not False and not isinstance(transform, transforms.compose.Compose):\n            # log the input data information of exact transform in the transform chain\n            if isinstance(log_stats, str):\n                datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False, name=log_stats)\n            else:\n                datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False)\n            logger = logging.getLogger(datastats._logger_name)\n            logger.error(f\"\\n=== Transform input info -- {type(transform).__name__} ===\")\n            if isinstance(data, (list, tuple)):\n                data = data[0]\n\n            def _log_stats(data, prefix: str | None = \"Data\"):\n                if isinstance(data, (np.ndarray, torch.Tensor)):\n                    # log data type, shape, range for array\n                    datastats(img=data, data_shape=True, value_range=True, prefix=prefix)\n                
else:\n                    # log data type and value for other metadata\n                    datastats(img=data, data_value=True, prefix=prefix)\n\n            if isinstance(data, dict):\n                for k, v in data.items():\n                    _log_stats(data=v, prefix=k)\n            else:\n                _log_stats(data=data)\n        raise RuntimeError(f\"applying transform {transform}\") from e\n\n\nclass Randomizable(ThreadUnsafe, RandomizableTrait):\n    \"\"\"\n    An interface for handling random state locally, currently based on a class\n    variable `R`, which is an instance of `np.random.RandomState`.  This\n    provides the flexibility of component-specific determinism without\n    affecting the global states.  It is recommended to use this API with\n    :py:class:`monai.data.DataLoader` for deterministic behaviour of the\n    preprocessing pipelines. This API is not thread-safe. Additionally,\n    deepcopying instance of this class often causes insufficient randomness as\n    the random states will be duplicated.\n    \"\"\"\n\n    R: np.random.RandomState = np.random.RandomState()\n\n    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable:\n        \"\"\"\n        Set the random state locally, to control the randomness, the derived\n        classes should use :py:attr:`self.R` instead of `np.random` to introduce random\n        factors.\n\n        Args:\n            seed: set the random state with an integer seed.\n            state: set the random state with a `np.random.RandomState` object.\n\n        Raises:\n            TypeError: When ``state`` is not an ``Optional[np.random.RandomState]``.\n\n        Returns:\n            a Randomizable instance.\n\n        \"\"\"\n        if seed is not None:\n            _seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed)\n            _seed = _seed % MAX_SEED  # need to account for Numpy2.0 which doesn't silently 
convert to int64\n            self.R = np.random.RandomState(_seed)\n            return self\n\n        if state is not None:\n            if not isinstance(state, np.random.RandomState):\n                raise TypeError(f\"state must be None or a np.random.RandomState but is {type(state).__name__}.\")\n            self.R = state\n            return self\n\n        self.R = np.random.RandomState()\n        return self\n\n    def randomize(self, data: Any) -> None:\n        \"\"\"\n        Within this method, :py:attr:`self.R` should be used, instead of `np.random`, to introduce random factors.\n\n        all :py:attr:`self.R` calls happen here so that we have a better chance to\n        identify errors of sync the random state.\n\n        This method can generate the random factors based on properties of the input data.\n\n        Raises:\n            NotImplementedError: When the subclass does not override this method.\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n\nclass Transform(ABC):\n    \"\"\"\n    An abstract class of a ``Transform``.\n    A transform is callable that processes ``data``.\n\n    It could be stateful and may modify ``data`` in place,\n    the implementation should be aware of:\n\n        #. thread safety when mutating its own states.\n           When used from a multi-process context, transform's instance variables are read-only.\n           thread-unsafe transforms should inherit :py:class:`monai.transforms.ThreadUnsafe`.\n        #. ``data`` content unused by this transform may still be used in the\n           subsequent transforms in a composed transform.\n        #. 
storing too much information in ``data`` may cause some memory issue or IPC sync issue,\n           especially in the multi-processing environment of PyTorch DataLoader.\n\n    See Also\n\n        :py:class:`monai.transforms.Compose`\n    \"\"\"\n\n    # Transforms should add `monai.transforms.utils.TransformBackends` to this list if they are performing\n    # the data processing using the corresponding backend APIs.\n    # Most of MONAI transform's inputs and outputs will be converted into torch.Tensor or monai.data.MetaTensor.\n    # This variable provides information about whether the input will be converted\n    # to other data types during the transformation. Note that not all `dtype` (such as float32, uint8) are supported\n    # by all the data types, the `dtype` during the conversion is determined automatically by each transform,\n    # please refer to the transform's docstring.\n    backend: list[TransformBackends] = []\n\n    @abstractmethod\n    def __call__(self, data: Any):\n        \"\"\"\n        ``data`` is an element which often comes from an iteration over an\n        iterable, such as :py:class:`torch.utils.data.Dataset`. This method should\n        return an updated version of ``data``.\n        To simplify the input validations, most of the transforms assume that\n\n        - ``data`` is a Numpy ndarray, PyTorch Tensor or string,\n        - the data shape can be:\n\n          #. string data without shape, `LoadImage` transform expects file paths,\n          #. 
most of the pre-/post-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``,\n             except for example: `AddChannel` expects (spatial_dim_1[, spatial_dim_2, ...])\n\n        - the channel dimension is often not omitted even if number of channels is one.\n\n        This method can optionally take additional arguments to help execute transformation operation.\n\n        Raises:\n            NotImplementedError: When the subclass does not override this method.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n\nclass LazyTransform(Transform, LazyTrait):\n    \"\"\"\n    An implementation of functionality for lazy transforms that can be subclassed by array and\n    dictionary transforms to simplify implementation of new lazy transforms.\n    \"\"\"\n\n    def __init__(self, lazy: bool | None = False):\n        if lazy is not None:\n            if not isinstance(lazy, bool):\n                raise TypeError(f\"lazy must be a bool but is of type {type(lazy)}\")\n        self._lazy = lazy\n\n    @property\n    def lazy(self):\n        return self._lazy\n\n    @lazy.setter\n    def lazy(self, lazy: bool | None):\n        if lazy is not None:\n            if not isinstance(lazy, bool):\n                raise TypeError(f\"lazy must be a bool but is of type {type(lazy)}\")\n        self._lazy = lazy\n\n    @property\n    def requires_current_data(self):\n        return False\n\n\nclass RandomizableTransform(Randomizable, Transform):\n    \"\"\"\n    An interface for handling random state locally, currently based on a class variable `R`,\n    which is an instance of `np.random.RandomState`.\n    This class introduces a randomized flag `_do_transform`, is mainly for randomized data augmentation transforms.\n    For example:\n\n    .. 
code-block:: python\n\n        from monai.transforms import RandomizableTransform\n\n        class RandShiftIntensity100(RandomizableTransform):\n            def randomize(self):\n                super().randomize(None)\n                self._offset = self.R.uniform(low=0, high=100)\n\n            def __call__(self, img):\n                self.randomize()\n                if not self._do_transform:\n                    return img\n                return img + self._offset\n\n        transform = RandShiftIntensity100()\n        transform.set_random_state(seed=0)\n        print(transform(10))\n\n    \"\"\"\n\n    def __init__(self, prob: float = 1.0, do_transform: bool = True):\n        self._do_transform = do_transform\n        self.prob = min(max(prob, 0.0), 1.0)\n\n    def randomize(self, data: Any) -> None:\n        \"\"\"\n        Within this method, :py:attr:`self.R` should be used, instead of `np.random`, to introduce random factors.\n\n        all :py:attr:`self.R` calls happen here so that we have a better chance to\n        identify errors of sync the random state.\n\n        This method can generate the random factors based on properties of the input data.\n        \"\"\"\n        self._do_transform = self.R.rand() < self.prob\n\n\nclass MapTransform(Transform):\n    \"\"\"\n    A subclass of :py:class:`monai.transforms.Transform` with an assumption\n    that the ``data`` input of ``self.__call__`` is a MutableMapping such as ``dict``.\n\n    The ``keys`` parameter will be used to get and set the actual data\n    item to transform.  That is, the callable of this transform should\n    follow the pattern:\n\n        .. 
code-block:: python\n\n            def __call__(self, data):\n                for key in self.keys:\n                    if key in data:\n                        # update output data with some_transform_function(data[key]).\n                    else:\n                        # raise exception unless allow_missing_keys==True.\n                return data\n\n    Raises:\n        ValueError: When ``keys`` is an empty iterable.\n        TypeError: When ``keys`` type is not in ``Union[Hashable, Iterable[Hashable]]``.\n\n    \"\"\"\n\n    def __new__(cls, *args, **kwargs):\n        if config.USE_META_DICT:\n            # call_update after MapTransform.__call__\n            cls.__call__ = transforms.attach_hook(cls.__call__, MapTransform.call_update, \"post\")  # type: ignore\n\n            if hasattr(cls, \"inverse\"):\n                # inverse_update before InvertibleTransform.inverse\n                cls.inverse: Any = transforms.attach_hook(cls.inverse, transforms.InvertibleTransform.inverse_update)\n        return Transform.__new__(cls)\n\n    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:\n        super().__init__()\n        self.keys: tuple[Hashable, ...] 
= ensure_tuple(keys)\n        self.allow_missing_keys = allow_missing_keys\n        if not self.keys:\n            raise ValueError(\"keys must be non empty.\")\n        for key in self.keys:\n            if not isinstance(key, Hashable):\n                raise TypeError(f\"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.\")\n\n    def call_update(self, data):\n        \"\"\"\n        This function is to be called after every `self.__call__(data)`,\n        update `data[key_transforms]` and `data[key_meta_dict]` using the content from MetaTensor `data[key]`,\n        for MetaTensor backward compatibility 0.9.0.\n        \"\"\"\n        if not isinstance(data, (list, tuple, Mapping)):\n            return data\n        is_dict = False\n        if isinstance(data, Mapping):\n            data, is_dict = [data], True\n        if not data or not isinstance(data[0], Mapping):\n            return data[0] if is_dict else data\n        list_d = [dict(x) for x in data]  # list of dict for crop samples\n        for idx, dict_i in enumerate(list_d):\n            for k in dict_i:\n                if not isinstance(dict_i[k], MetaTensor):\n                    continue\n                list_d[idx] = transforms.sync_meta_info(k, dict_i, t=not isinstance(self, transforms.InvertD))\n        return list_d[0] if is_dict else list_d\n\n    @abstractmethod\n    def __call__(self, data):\n        \"\"\"\n        ``data`` often comes from an iteration over an iterable,\n        such as :py:class:`torch.utils.data.Dataset`.\n\n        To simplify the input validations, this method assumes:\n\n        - ``data`` is a Python dictionary,\n        - ``data[key]`` is a Numpy ndarray, PyTorch Tensor or string, where ``key`` is an element\n          of ``self.keys``, the data shape can be:\n\n          #. string data without shape, `LoadImaged` transform expects file paths,\n          #. 
most of the pre-/post-processing transforms expect: ``(num_channels, spatial_dim_1[, spatial_dim_2, ...])``,\n             except for example: `AddChanneld` expects (spatial_dim_1[, spatial_dim_2, ...])\n\n        - the channel dimension is often not omitted even if number of channels is one.\n\n        Raises:\n            NotImplementedError: When the subclass does not override this method.\n\n        returns:\n            An updated dictionary version of ``data`` by applying the transform.\n\n        \"\"\"\n        raise NotImplementedError(f\"Subclass {self.__class__.__name__} must implement this method.\")\n\n    def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None) -> Generator:\n        \"\"\"\n        Iterate across keys and optionally extra iterables. If key is missing, exception is raised if\n        `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped.\n\n        Args:\n            data: data that the transform will be applied to\n            extra_iterables: anything else to be iterated through\n        \"\"\"\n        # if no extra iterables given, create a dummy list of Nones\n        ex_iters = extra_iterables or [[None] * len(self.keys)]\n\n        # loop over keys and any extra iterables\n        _ex_iters: list[Any]\n        for key, *_ex_iters in zip(self.keys, *ex_iters):\n            # all normal, yield (what we yield depends on whether extra iterables were given)\n            if key in data:\n                yield (key,) + tuple(_ex_iters) if extra_iterables else key\n            elif not self.allow_missing_keys:\n                raise KeyError(\n                    f\"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data and allow_missing_keys==False.\"\n                )\n\n    def first_key(self, data: dict[Hashable, Any]):\n        \"\"\"\n        Get the first available key of `self.keys` in the input `data` dictionary.\n        If no available key, 
return an empty tuple `()`.\n\n        Args:\n            data: data that the transform will be applied to.\n\n        \"\"\"\n        return first(self.key_iterator(data), ())\n"
  },
  {
    "path": "monai/transforms/utility/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "monai/transforms/utility/array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of \"vanilla\" transforms for utility functions.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nimport sys\nimport time\nimport warnings\nfrom collections.abc import Hashable, Mapping, Sequence\nfrom copy import deepcopy\nfrom functools import partial\nfrom typing import Any, Callable\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom monai.config import DtypeLike\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.meta_obj import get_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import is_no_channel, no_collation, orientation_ras_lps\nfrom monai.networks.layers.simplelayers import (\n    ApplyFilter,\n    EllipticalFilter,\n    GaussianFilter,\n    LaplaceFilter,\n    MeanFilter,\n    SavitzkyGolayFilter,\n    SharpenFilter,\n    median_filter,\n)\nfrom monai.transforms.inverse import InvertibleTransform, TraceableTransform\nfrom monai.transforms.traits import MultiSampleTrait, ReduceTrait\nfrom monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform\nfrom monai.transforms.utils import (\n    apply_affine_to_points,\n    extreme_points_to_image,\n    get_extreme_points,\n    map_binary_to_indices,\n    map_classes_to_indices,\n)\nfrom monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, 
linalg_inv, moveaxis, unravel_indices\nfrom monai.utils import (\n    MetaKeys,\n    TraceKeys,\n    convert_data_type,\n    convert_to_cupy,\n    convert_to_numpy,\n    convert_to_tensor,\n    ensure_tuple,\n    look_up_option,\n    min_version,\n    optional_import,\n)\nfrom monai.utils.enums import TransformBackends\nfrom monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype\n\nPILImageImage, has_pil = optional_import(\"PIL.Image\", name=\"Image\")\npil_image_fromarray, _ = optional_import(\"PIL.Image\", name=\"fromarray\")\ncp, has_cp = optional_import(\"cupy\")\n\n__all__ = [\n    \"Identity\",\n    \"RandIdentity\",\n    \"AsChannelLast\",\n    \"AddCoordinateChannels\",\n    \"EnsureChannelFirst\",\n    \"EnsureType\",\n    \"RepeatChannel\",\n    \"RemoveRepeatedChannel\",\n    \"SplitDim\",\n    \"CastToType\",\n    \"ToTensor\",\n    \"ToNumpy\",\n    \"ToPIL\",\n    \"Transpose\",\n    \"SqueezeDim\",\n    \"DataStats\",\n    \"SimulateDelay\",\n    \"Lambda\",\n    \"RandLambda\",\n    \"LabelToMask\",\n    \"FgBgToIndices\",\n    \"ClassesToIndices\",\n    \"ConvertToMultiChannelBasedOnBratsClasses\",\n    \"AddExtremePointsChannel\",\n    \"TorchVision\",\n    \"TorchIO\",\n    \"MapLabelValue\",\n    \"IntensityStats\",\n    \"ToDevice\",\n    \"CuCIM\",\n    \"RandCuCIM\",\n    \"RandTorchIO\",\n    \"RandTorchVision\",\n    \"ToCupy\",\n    \"ImageFilter\",\n    \"RandImageFilter\",\n    \"ApplyTransformToPoints\",\n    \"FlattenSequence\",\n]\n\n\nclass Identity(Transform):\n    \"\"\"\n    Do nothing to the data.\n    As the output value is same as input, it can be used as a testing tool to verify the transform chain,\n    Compose or transform adaptor, etc.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        return img\n\n\nclass 
RandIdentity(RandomizableTrait):\n    \"\"\"\n    Do nothing to the data. This transform is random, so can be used to stop the caching of any\n    subsequent transforms.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __call__(self, data: Any) -> Any:\n        return data\n\n\nclass AsChannelLast(Transform):\n    \"\"\"\n    Change the channel dimension of the image to the last dimension.\n\n    Some of other 3rd party transforms assume the input image is in the channel-last format with shape\n    (spatial_dim_1[, spatial_dim_2, ...], num_channels).\n\n    This transform could be used to convert, for example, a channel-first image array in shape\n    (num_channels, spatial_dim_1[, spatial_dim_2, ...]) into the channel-last format,\n    so that MONAI transforms can construct a chain with other 3rd party transforms together.\n\n    Args:\n        channel_dim: which dimension of input image is the channel, default is the first dimension.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, channel_dim: int = 0) -> None:\n        if not (isinstance(channel_dim, int) and channel_dim >= -1):\n            raise ValueError(f\"invalid channel dimension ({channel_dim}).\")\n        self.channel_dim = channel_dim\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        out: NdarrayOrTensor = convert_to_tensor(moveaxis(img, self.channel_dim, -1), track_meta=get_track_meta())\n        return out\n\n\nclass EnsureChannelFirst(Transform):\n    \"\"\"\n    Adjust or add the channel dimension of input data to ensure `channel_first` shape.\n\n    This extracts the `original_channel_dim` info from provided meta_data dictionary or MetaTensor input. 
This value\n    should state which dimension is the channel dimension so that it can be moved forward, or contain \"no_channel\" to\n    state no dimension is the channel and so a 1-size first dimension is to be added.\n\n    Args:\n        strict_check: whether to raise an error when the meta information is insufficient.\n        channel_dim: This argument can be used to specify the original channel dimension (integer) of the input array.\n            It overrides the `original_channel_dim` from provided MetaTensor input.\n            If the input array doesn't have a channel dim, this value should be ``'no_channel'``.\n            If this is set to `None`, this class relies on `img` or `meta_dict` to provide the channel dimension.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, strict_check: bool = True, channel_dim: None | str | int = None):\n        self.strict_check = strict_check\n        self.input_channel_dim = channel_dim\n\n    def __call__(self, img: torch.Tensor, meta_dict: Mapping | None = None) -> torch.Tensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        if not isinstance(img, MetaTensor) and not isinstance(meta_dict, Mapping):\n            if self.input_channel_dim is None:\n                msg = \"Metadata not available and channel_dim=None, EnsureChannelFirst is not in use.\"\n                if self.strict_check:\n                    raise ValueError(msg)\n                warnings.warn(msg)\n                return img\n            else:\n                img = MetaTensor(img)\n\n        if isinstance(img, MetaTensor):\n            meta_dict = img.meta\n\n        channel_dim = meta_dict.get(MetaKeys.ORIGINAL_CHANNEL_DIM, None) if isinstance(meta_dict, Mapping) else None\n        if self.input_channel_dim is not None:\n            channel_dim = float(\"nan\") if self.input_channel_dim == \"no_channel\" else self.input_channel_dim\n\n        if channel_dim is 
None:\n            msg = \"Unknown original_channel_dim in the MetaTensor meta dict or `meta_dict` or `channel_dim`.\"\n            if self.strict_check:\n                raise ValueError(msg)\n            warnings.warn(msg)\n            return img\n\n        # track the original channel dim\n        if isinstance(meta_dict, dict):\n            meta_dict[MetaKeys.ORIGINAL_CHANNEL_DIM] = channel_dim\n\n        if is_no_channel(channel_dim):\n            result = img[None]\n        else:\n            result = moveaxis(img, int(channel_dim), 0)  # type: ignore\n\n        return convert_to_tensor(result, track_meta=get_track_meta())  # type: ignore\n\n\nclass RepeatChannel(Transform):\n    \"\"\"\n    Repeat channel data to construct expected input shape for models.\n    The `repeats` count includes the origin data, for example:\n    ``RepeatChannel(repeats=2)([[1, 2], [3, 4]])`` generates: ``[[1, 2], [1, 2], [3, 4], [3, 4]]``\n\n    Args:\n        repeats: the number of repetitions for each element.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, repeats: int) -> None:\n        if repeats <= 0:\n            raise ValueError(f\"repeats count must be greater than 0, got {repeats}.\")\n        self.repeats = repeats\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`, assuming `img` is a \"channel-first\" array.\n        \"\"\"\n        repeat_fn = torch.repeat_interleave if isinstance(img, torch.Tensor) else np.repeat\n        return convert_to_tensor(repeat_fn(img, self.repeats, 0), track_meta=get_track_meta())  # type: ignore\n\n\nclass RemoveRepeatedChannel(Transform):\n    \"\"\"\n    RemoveRepeatedChannel data to undo RepeatChannel\n    The `repeats` count specifies the deletion of the origin data, for example:\n    ``RemoveRepeatedChannel(repeats=2)([[1, 2], [1, 2], [3, 4], [3, 4]])`` generates: ``[[1, 2], [3, 4]]``\n\n    Args:\n        repeats: the number of 
repetitions to be deleted for each element.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, repeats: int) -> None:\n        if repeats <= 0:\n            raise ValueError(f\"repeats count must be greater than 0, got {repeats}.\")\n\n        self.repeats = repeats\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`, assuming `img` is a \"channel-first\" array.\n        \"\"\"\n        if img.shape[0] < 2:\n            raise ValueError(f\"Image must have more than one channel, got {img.shape[0]} channels.\")\n\n        out: NdarrayOrTensor = convert_to_tensor(img[:: self.repeats, :], track_meta=get_track_meta())\n        return out\n\n\nclass SplitDim(Transform, MultiSampleTrait):\n    \"\"\"\n    Given an image of size X along a certain dimension, return a list of length X containing\n    images. Useful for converting 3D images into a stack of 2D images, splitting multichannel inputs into\n    single channels, for example.\n\n    Note: `torch.split`/`np.split` is used, so the outputs are views of the input (shallow copy).\n\n    Args:\n        dim: dimension on which to split\n        keepdim: if `True`, output will have singleton in the split dimension. 
If `False`, this\n            dimension will be squeezed.\n        update_meta: whether to update the MetaObj in each split result.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, dim: int = -1, keepdim: bool = True, update_meta=True) -> None:\n        self.dim = dim\n        self.keepdim = keepdim\n        self.update_meta = update_meta\n\n    def __call__(self, img: torch.Tensor) -> list[torch.Tensor]:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        n_out = img.shape[self.dim]\n        if isinstance(img, torch.Tensor):\n            outputs = list(torch.split(img, 1, self.dim))\n        else:\n            outputs = np.split(img, n_out, self.dim)\n        for idx, item in enumerate(outputs):\n            if not self.keepdim:\n                outputs[idx] = item.squeeze(self.dim)\n            if self.update_meta and isinstance(img, MetaTensor):\n                if not isinstance(item, MetaTensor):\n                    item = MetaTensor(item, meta=img.meta)\n                if self.dim == 0:  # don't update affine if channel dim\n                    continue\n                ndim = len(item.affine)\n                shift = torch.eye(ndim, device=item.affine.device, dtype=item.affine.dtype)\n                shift[self.dim - 1, -1] = idx\n                item.affine = item.affine @ shift\n        return outputs\n\n\nclass CastToType(Transform):\n    \"\"\"\n    Cast the Numpy data to specified numpy data type, or cast the PyTorch Tensor to\n    specified PyTorch data type.\n\n    Example:\n        >>> import numpy as np\n        >>> import torch\n        >>> transform = CastToType(dtype=np.float32)\n\n        >>> # Example with a numpy array\n        >>> img_np = np.array([0, 127, 255], dtype=np.uint8)\n        >>> img_np_casted = transform(img_np)\n        >>> img_np_casted\n        array([  0. , 127. , 255. 
], dtype=float32)\n\n        >>> # Example with a PyTorch tensor\n        >>> img_tensor = torch.tensor([0, 127, 255], dtype=torch.uint8)\n        >>> img_tensor_casted = transform(img_tensor)\n        >>> img_tensor_casted\n        tensor([  0., 127., 255.])  # dtype is float32\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, dtype=np.float32) -> None:\n        \"\"\"\n        Args:\n            dtype: convert image to this data type, default is `np.float32`.\n        \"\"\"\n        self.dtype = dtype\n\n    def __call__(self, img: NdarrayOrTensor, dtype: DtypeLike | torch.dtype = None) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`, assuming `img` is a numpy array or PyTorch Tensor.\n\n        Args:\n            dtype: convert image to this data type, default is `self.dtype`.\n\n        Raises:\n            TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``.\n\n        \"\"\"\n        return convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype)[0]\n\n\nclass ToTensor(Transform):\n    \"\"\"\n    Converts the input image to a tensor without applying any other transformations.\n    Input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.\n    Will convert Tensor, Numpy array, float, int, bool to Tensor, strings and objects keep the original.\n    For dictionary, list or tuple, convert every item to a Tensor if applicable and `wrap_sequence=False`.\n\n    Args:\n        dtype: target data type to when converting to Tensor.\n        device: target device to put the converted Tensor data.\n        wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.\n            E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.\n        track_meta: whether to convert to `MetaTensor` or regular tensor, default to `None`,\n         
   use the return value of ``get_track_meta``.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        dtype: torch.dtype | None = None,\n        device: torch.device | str | None = None,\n        wrap_sequence: bool = True,\n        track_meta: bool | None = None,\n    ) -> None:\n        super().__init__()\n        self.dtype = dtype\n        self.device = device\n        self.wrap_sequence = wrap_sequence\n        self.track_meta = get_track_meta() if track_meta is None else bool(track_meta)\n\n    def __call__(self, img: NdarrayOrTensor):\n        \"\"\"\n        Apply the transform to `img` and make it contiguous.\n        \"\"\"\n        if isinstance(img, MetaTensor):\n            img.applied_operations = []  # drops tracking info\n        return convert_to_tensor(\n            img, dtype=self.dtype, device=self.device, wrap_sequence=self.wrap_sequence, track_meta=self.track_meta\n        )\n\n\nclass EnsureType(Transform):\n    \"\"\"\n    Ensure the input data to be a PyTorch Tensor or numpy array, support: `numpy array`, `PyTorch Tensor`,\n    `float`, `int`, `bool`, `string` and `object` keep the original.\n    If passing a dictionary, list or tuple, still return dictionary, list or tuple will recursively convert\n    every item to the expected data type if `wrap_sequence=False`.\n\n    Args:\n        data_type: target data type to convert, should be \"tensor\" or \"numpy\".\n        dtype: target data content type to convert, for example: np.float32, torch.float, etc.\n        device: for Tensor data type, specify the target device.\n        wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.\n        track_meta: if `True` convert to ``MetaTensor``, otherwise to Pytorch ``Tensor``,\n            if ``None`` behave according to return value of py:func:`monai.data.meta_obj.get_track_meta`.\n\n    Example with wrap_sequence=True:\n        >>> import numpy as np\n        
>>> import torch\n        >>> transform = EnsureType(data_type=\"tensor\", wrap_sequence=True)\n        >>> # Converting a list to a tensor\n        >>> data_list = [1, 2., 3]\n        >>> tensor_data = transform(data_list)\n        >>> tensor_data\n        tensor([1., 2., 3.])    # All elements have dtype float32\n\n    Example with wrap_sequence=False:\n        >>> transform = EnsureType(data_type=\"tensor\", wrap_sequence=False)\n        >>> # Converting each element in a list to individual tensors\n        >>> data_list = [1, 2, 3]\n        >>> tensors_list = transform(data_list)\n        >>> tensors_list\n        [tensor(1), tensor(2.), tensor(3)]  # Only second element is float32 rest are int64\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        data_type: str = \"tensor\",\n        dtype: DtypeLike | torch.dtype = None,\n        device: torch.device | None = None,\n        wrap_sequence: bool = True,\n        track_meta: bool | None = None,\n    ) -> None:\n        self.data_type = look_up_option(data_type.lower(), {\"tensor\", \"numpy\"})\n        self.dtype = dtype\n        self.device = device\n        self.wrap_sequence = wrap_sequence\n        self.track_meta = get_track_meta() if track_meta is None else bool(track_meta)\n\n    def __call__(self, data: NdarrayOrTensor, dtype: DtypeLike | torch.dtype = None):\n        \"\"\"\n        Args:\n            data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.\n                will ensure Tensor, Numpy array, float, int, bool as Tensors or numpy arrays, strings and\n                objects keep the original. 
for dictionary, list or tuple, ensure every item as expected type\n                if applicable and `wrap_sequence=False`.\n            dtype: target data content type to convert, for example: np.float32, torch.float, etc.\n\n        \"\"\"\n        if self.data_type == \"tensor\":\n            output_type = MetaTensor if self.track_meta else torch.Tensor\n        else:\n            output_type = np.ndarray  # type: ignore\n        out: NdarrayOrTensor\n        out, *_ = convert_data_type(\n            data=data,\n            output_type=output_type,  # type: ignore\n            dtype=self.dtype if dtype is None else dtype,\n            device=self.device,\n            wrap_sequence=self.wrap_sequence,\n        )\n        return out\n\n\nclass ToNumpy(Transform):\n    \"\"\"\n    Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor.\n\n    Args:\n        dtype: target data type when converting to numpy array.\n        wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.\n            E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(self, dtype: DtypeLike = None, wrap_sequence: bool = True) -> None:\n        super().__init__()\n        self.dtype = dtype\n        self.wrap_sequence = wrap_sequence\n\n    def __call__(self, img: NdarrayOrTensor):\n        \"\"\"\n        Apply the transform to `img` and make it contiguous.\n        \"\"\"\n        return convert_to_numpy(img, dtype=self.dtype, wrap_sequence=self.wrap_sequence)\n\n\nclass ToCupy(Transform):\n    \"\"\"\n    Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor.\n\n    Args:\n        dtype: data type specifier. 
It is inferred from the input by default.\n            if not None, must be an argument of `numpy.dtype`, for more details:\n            https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html.\n        wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.\n            E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`.\n\n    \"\"\"\n\n    backend = [TransformBackends.CUPY]\n\n    def __init__(self, dtype: np.dtype | None = None, wrap_sequence: bool = True) -> None:\n        super().__init__()\n        self.dtype = dtype\n        self.wrap_sequence = wrap_sequence\n\n    def __call__(self, data: NdarrayOrTensor):\n        \"\"\"\n        Create a CuPy array from `data` and make it contiguous\n        \"\"\"\n        return convert_to_cupy(data, dtype=self.dtype, wrap_sequence=self.wrap_sequence)\n\n\nclass ToPIL(Transform):\n    \"\"\"\n    Converts the input image (in the form of NumPy array or PyTorch Tensor) to PIL image\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __call__(self, img):\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        if isinstance(img, PILImageImage):\n            return img\n        if isinstance(img, torch.Tensor):\n            img = img.detach().cpu().numpy()\n        return pil_image_fromarray(img)\n\n\nclass Transpose(Transform):\n    \"\"\"\n    Transposes the input image based on the given `indices` dimension ordering.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, indices: Sequence[int] | None) -> None:\n        self.indices = None if indices is None else tuple(indices)\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Apply the transform to `img`.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        return img.permute(self.indices or tuple(range(img.ndim)[::-1]))  # type: 
ignore\n\n\nclass SqueezeDim(Transform):\n    \"\"\"\n    Squeeze a unitary dimension.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, dim: int | None = 0, update_meta=True) -> None:\n        \"\"\"\n        Args:\n            dim: dimension to be squeezed. Default = 0\n                \"None\" works when the input is numpy array.\n            update_meta: whether to update the meta info if the input is a metatensor. Default is ``True``.\n\n        Raises:\n            TypeError: When ``dim`` is not an ``Optional[int]``.\n\n        \"\"\"\n        if dim is not None and not isinstance(dim, int):\n            raise TypeError(f\"dim must be None or a int but is {type(dim).__name__}.\")\n        self.dim = dim\n        self.update_meta = update_meta\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: numpy arrays with required dimension `dim` removed\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if self.dim is None:\n            if self.update_meta:\n                warnings.warn(\"update_meta=True is ignored when dim=None.\")\n            return img.squeeze()\n        dim = (self.dim + len(img.shape)) if self.dim < 0 else self.dim\n        # for pytorch/numpy unification\n        if img.shape[dim] != 1:\n            raise ValueError(f\"Can only squeeze singleton dimension, got shape {img.shape[dim]} of {img.shape}.\")\n        img = img.squeeze(dim)\n        if self.update_meta and isinstance(img, MetaTensor) and dim > 0 and len(img.affine.shape) == 2:\n            h, w = img.affine.shape\n            affine, device = img.affine, img.affine.device if isinstance(img.affine, torch.Tensor) else None\n            if h > dim:\n                affine = affine[torch.arange(0, h, device=device) != dim - 1]\n            if w > dim:\n                affine = affine[:, torch.arange(0, w, device=device) != dim - 
1]\n            if (affine.shape[0] == affine.shape[1]) and not np.linalg.det(convert_to_numpy(affine, wrap_sequence=True)):\n                warnings.warn(f\"After SqueezeDim, img.affine is ill-posed: \\n{img.affine}.\")\n            img.affine = affine\n        return img\n\n\nclass DataStats(Transform):\n    \"\"\"\n    Utility transform to show the statistics of data for debug or analysis.\n    It can be inserted into any place of a transform chain and check results of previous transforms.\n    It support both `numpy.ndarray` and `torch.tensor` as input data,\n    so it can be used in pre-processing and post-processing.\n\n    It gets logger from `logging.getLogger(name)`, we can setup a logger outside first with the same `name`.\n    If the log level of `logging.RootLogger` is higher than `INFO`, will add a separate `StreamHandler`\n    log handler with `INFO` level and record to `stdout`.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        prefix: str = \"Data\",\n        data_type: bool = True,\n        data_shape: bool = True,\n        value_range: bool = True,\n        data_value: bool = False,\n        meta_info: bool = False,\n        additional_info: Callable | None = None,\n        name: str = \"DataStats\",\n    ) -> None:\n        \"\"\"\n        Args:\n            prefix: will be printed in format: \"{prefix} statistics\".\n            data_type: whether to show the type of input data.\n            data_shape: whether to show the shape of input data.\n            value_range: whether to show the value range of input data.\n            data_value: whether to show the raw value of input data.\n                a typical example is to print some properties of Nifti image: affine, pixdim, etc.\n            meta_info: whether to show the data of MetaTensor.\n            additional_info: user can define callable function to extract additional info from input data.\n            name: 
identifier of `logging.logger` to use, defaulting to \"DataStats\".\n\n        Raises:\n            TypeError: When ``additional_info`` is not an ``Optional[Callable]``.\n\n        \"\"\"\n        if not isinstance(prefix, str):\n            raise ValueError(f\"prefix must be a string, got {type(prefix)}.\")\n        self.prefix = prefix\n        self.data_type = data_type\n        self.data_shape = data_shape\n        self.value_range = value_range\n        self.data_value = data_value\n        self.meta_info = meta_info\n        if additional_info is not None and not callable(additional_info):\n            raise TypeError(f\"additional_info must be None or callable but is {type(additional_info).__name__}.\")\n        self.additional_info = additional_info\n        self._logger_name = name\n        _logger = logging.getLogger(self._logger_name)\n        _logger.setLevel(logging.INFO)\n        if logging.root.getEffectiveLevel() > logging.INFO:\n            # Avoid duplicate stream handlers to be added when multiple DataStats are used in a chain.\n            has_console_handler = any(\n                hasattr(h, \"is_data_stats_handler\") and h.is_data_stats_handler for h in _logger.handlers\n            )\n            if not has_console_handler:\n                # if the root log level is higher than INFO, set a separate stream handler to record\n                console = logging.StreamHandler(sys.stdout)\n                console.setLevel(logging.INFO)\n                console.is_data_stats_handler = True  # type:ignore[attr-defined]\n                _logger.addHandler(console)\n\n    def __call__(\n        self,\n        img: NdarrayOrTensor,\n        prefix: str | None = None,\n        data_type: bool | None = None,\n        data_shape: bool | None = None,\n        value_range: bool | None = None,\n        data_value: bool | None = None,\n        meta_info: bool | None = None,\n        additional_info: Callable | None = None,\n    ) -> NdarrayOrTensor:\n        
\"\"\"\n        Apply the transform to `img`, optionally take arguments similar to the class constructor.\n        \"\"\"\n        lines = [f\"{prefix or self.prefix} statistics:\"]\n\n        if self.data_type if data_type is None else data_type:\n            lines.append(f\"Type: {type(img)} {img.dtype if hasattr(img, 'dtype') else None}\")\n        if self.data_shape if data_shape is None else data_shape:\n            lines.append(f\"Shape: {img.shape if hasattr(img, 'shape') else None}\")\n        if self.value_range if value_range is None else value_range:\n            if isinstance(img, np.ndarray):\n                lines.append(f\"Value range: ({np.min(img)}, {np.max(img)})\")\n            elif isinstance(img, torch.Tensor):\n                lines.append(f\"Value range: ({torch.min(img)}, {torch.max(img)})\")\n            else:\n                lines.append(f\"Value range: (not a PyTorch or Numpy array, type: {type(img)})\")\n        if self.data_value if data_value is None else data_value:\n            lines.append(f\"Value: {img}\")\n        if self.meta_info if meta_info is None else meta_info:\n            metadata = getattr(img, \"meta\", \"(input is not a MetaTensor)\")\n            lines.append(f\"Meta info: {repr(metadata)}\")\n        additional_info = self.additional_info if additional_info is None else additional_info\n        if additional_info is not None:\n            lines.append(f\"Additional info: {additional_info(img)}\")\n        separator = \"\\n\"\n        output = f\"{separator.join(lines)}\"\n        logging.getLogger(self._logger_name).info(output)\n        return img\n\n\nclass SimulateDelay(Transform):\n    \"\"\"\n    This is a pass through transform to be used for testing purposes. 
It allows\n    adding fake behaviors that are useful for testing purposes to simulate\n    how large datasets behave without needing to test on large data sets.\n\n    For example, simulating slow NFS data transfers, or slow network transfers\n    in testing by adding explicit timing delays. Testing of small test data\n    can lead to incomplete understanding of real world issues, and may lead\n    to sub-optimal design choices.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, delay_time: float = 0.0) -> None:\n        \"\"\"\n        Args:\n            delay_time: The minimum amount of time, in fractions of seconds,\n                to accomplish this delay task.\n        \"\"\"\n        super().__init__()\n        self.delay_time: float = delay_time\n\n    def __call__(self, img: NdarrayOrTensor, delay_time: float | None = None) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: data remain unchanged throughout this transform.\n            delay_time: The minimum amount of time, in fractions of seconds,\n                to accomplish this delay task.\n        \"\"\"\n        time.sleep(self.delay_time if delay_time is None else delay_time)\n        return img\n\n\nclass Lambda(InvertibleTransform):\n    \"\"\"\n    Apply a user-defined lambda as a transform.\n\n    For example:\n\n    .. code-block:: python\n        :emphasize-lines: 2\n\n        image = np.ones((10, 2, 2))\n        lambd = Lambda(func=lambda x: x[:4, :, :])\n        print(lambd(image).shape)\n        (4, 2, 2)\n\n    Args:\n        func: Lambda/function to be applied.\n        inv_func: Lambda/function of inverse operation, default to `lambda x: x`.\n        track_meta:  If `False`, then standard data objects will be returned (e.g., torch.Tensor` and `np.ndarray`)\n            as opposed to MONAI's enhanced objects. 
By default, this is `True`.\n\n    Raises:\n        TypeError: When ``func`` is not an ``Optional[Callable]``.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self, func: Callable | None = None, inv_func: Callable = no_collation, track_meta: bool = True\n    ) -> None:\n        if func is not None and not callable(func):\n            raise TypeError(f\"func must be None or callable but is {type(func).__name__}.\")\n        self.func = func\n        self.inv_func = inv_func\n        self.track_meta = track_meta\n\n    def __call__(self, img: NdarrayOrTensor, func: Callable | None = None):\n        \"\"\"\n        Apply `self.func` to `img`.\n\n        Args:\n            func: Lambda/function to be applied. Defaults to `self.func`.\n\n        Raises:\n            TypeError: When ``func`` is not an ``Optional[Callable]``.\n\n        \"\"\"\n        fn = func if func is not None else self.func\n        if not callable(fn):\n            raise TypeError(f\"func must be None or callable but is {type(fn).__name__}.\")\n        out = fn(img)\n        # convert to MetaTensor if necessary\n        if isinstance(out, (np.ndarray, torch.Tensor)) and not isinstance(out, MetaTensor) and self.track_meta:\n            out = MetaTensor(out)\n        if isinstance(out, MetaTensor):\n            self.push_transform(out)\n        return out\n\n    def inverse(self, data: torch.Tensor):\n        if isinstance(data, MetaTensor):\n            self.pop_transform(data)\n        return self.inv_func(data)\n\n\nclass RandLambda(Lambda, RandomizableTransform):\n    \"\"\"\n    Randomizable version :py:class:`monai.transforms.Lambda`, the input `func` may contain random logic,\n    or randomly execute the function based on `prob`.\n\n    Args:\n        func: Lambda/function to be applied.\n        prob: probability of executing the random function, default to 1.0, with 100% probability to execute.\n        inv_func: Lambda/function 
of inverse operation, default to `lambda x: x`.\n        track_meta:  If `False`, then standard data objects will be returned (e.g., torch.Tensor` and `np.ndarray`)\n            as opposed to MONAI's enhanced objects. By default, this is `True`.\n\n    For more details, please check :py:class:`monai.transforms.Lambda`.\n    \"\"\"\n\n    backend = Lambda.backend\n\n    def __init__(\n        self,\n        func: Callable | None = None,\n        prob: float = 1.0,\n        inv_func: Callable = no_collation,\n        track_meta: bool = True,\n    ) -> None:\n        Lambda.__init__(self=self, func=func, inv_func=inv_func, track_meta=track_meta)\n        RandomizableTransform.__init__(self=self, prob=prob)\n\n    def __call__(self, img: NdarrayOrTensor, func: Callable | None = None):\n        self.randomize(img)\n        out = deepcopy(super().__call__(img, func) if self._do_transform else img)\n        # convert to MetaTensor if necessary\n        if not isinstance(out, MetaTensor) and self.track_meta:\n            out = MetaTensor(out)\n        if isinstance(out, MetaTensor):\n            lambda_info = self.pop_transform(out) if self._do_transform else {}\n            self.push_transform(out, extra_info=lambda_info)\n        return out\n\n    def inverse(self, data: torch.Tensor):\n        do_transform = self.get_most_recent_transform(data).pop(TraceKeys.DO_TRANSFORM)\n        if do_transform:\n            data = super().inverse(data)\n        else:\n            self.pop_transform(data)\n        return data\n\n\nclass LabelToMask(Transform):\n    \"\"\"\n    Convert labels to mask for other tasks. 
A typical usage is to convert segmentation labels\n    to mask data to pre-process images and then feed the images into classification network.\n    It can support single channel labels or One-Hot labels with specified `select_labels`.\n    For example, users can select `label value = [2, 3]` to construct mask data, or select the\n    second and the third channels of labels to construct mask data.\n    The output mask data can be a multiple channels binary data or a single channel binary\n    data that merges all the channels.\n\n    Args:\n        select_labels: labels to generate mask from. for 1 channel label, the `select_labels`\n            is the expected label values, like: [1, 2, 3]. for One-Hot format label, the\n            `select_labels` is the expected channel indices.\n        merge_channels: whether to use `np.any()` to merge the result on channel dim. if yes,\n            will return a single channel mask with binary data.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(  # pytype: disable=annotation-type-mismatch\n        self, select_labels: Sequence[int] | int, merge_channels: bool = False\n    ) -> None:  # pytype: disable=annotation-type-mismatch\n        self.select_labels = ensure_tuple(select_labels)\n        self.merge_channels = merge_channels\n\n    def __call__(\n        self, img: NdarrayOrTensor, select_labels: Sequence[int] | int | None = None, merge_channels: bool = False\n    ) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            select_labels: labels to generate mask from. for 1 channel label, the `select_labels`\n                is the expected label values, like: [1, 2, 3]. for One-Hot format label, the\n                `select_labels` is the expected channel indices.\n            merge_channels: whether to use `np.any()` to merge the result on channel dim. 
if yes,\n                will return a single channel mask with binary data.\n        \"\"\"\n        img = convert_to_tensor(img, track_meta=get_track_meta())\n        if select_labels is None:\n            select_labels = self.select_labels\n        else:\n            select_labels = ensure_tuple(select_labels)\n\n        if img.shape[0] > 1:\n            data = img[[*select_labels]]\n        else:\n            where: Callable = np.where if isinstance(img, np.ndarray) else torch.where  # type: ignore\n            data = where(in1d(img, select_labels), True, False).reshape(img.shape)\n\n        if merge_channels or self.merge_channels:\n            return data.any(0)[None]  # type: ignore\n\n        return data\n\n\nclass FgBgToIndices(Transform, MultiSampleTrait):\n    \"\"\"\n    Compute foreground and background of the input label data, return the indices.\n    If no output_shape specified, output data will be 1 dim indices after flattening.\n    This transform can help pre-compute foreground and background regions for other transforms.\n    A typical usage is to randomly select foreground and background to crop.\n    The main logic is based on :py:class:`monai.transforms.utils.map_binary_to_indices`.\n\n    Args:\n        image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to\n            determine the valid image content area and select background only in this area.\n        output_shape: expected shape of output indices. 
if not None, unravel indices to specified shape.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]\n\n    def __init__(self, image_threshold: float = 0.0, output_shape: Sequence[int] | None = None) -> None:\n        self.image_threshold = image_threshold\n        self.output_shape = output_shape\n\n    def __call__(\n        self, label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, output_shape: Sequence[int] | None = None\n    ) -> tuple[NdarrayOrTensor, NdarrayOrTensor]:\n        \"\"\"\n        Args:\n            label: input data to compute foreground and background indices.\n            image: if image is not None, use ``label = 0 & image > image_threshold``\n                to define background. so the output items will not map to all the voxels in the label.\n            output_shape: expected shape of output indices. if None, use `self.output_shape` instead.\n\n        \"\"\"\n        if output_shape is None:\n            output_shape = self.output_shape\n        fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold)\n        if output_shape is not None:\n            fg_indices = unravel_indices(fg_indices, output_shape)\n            bg_indices = unravel_indices(bg_indices, output_shape)\n        return fg_indices, bg_indices\n\n\nclass ClassesToIndices(Transform, MultiSampleTrait):\n    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]\n\n    def __init__(\n        self,\n        num_classes: int | None = None,\n        image_threshold: float = 0.0,\n        output_shape: Sequence[int] | None = None,\n        max_samples_per_class: int | None = None,\n    ) -> None:\n        \"\"\"\n        Compute indices of every class of the input label data, return a list of indices.\n        If no output_shape specified, output data will be 1 dim indices after flattening.\n        This transform can help pre-compute indices of the class regions for other transforms.\n        A typical 
usage is to randomly select indices of classes to crop.\n        The main logic is based on :py:class:`monai.transforms.utils.map_classes_to_indices`.\n\n        Args:\n            num_classes: number of classes for argmax label, not necessary for One-Hot label.\n            image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to\n                determine the valid image content area and select only the indices of classes in this area.\n            output_shape: expected shape of output indices. if not None, unravel indices to specified shape.\n            max_samples_per_class: maximum length of indices to sample in each class to reduce memory consumption.\n                Default is None, no subsampling.\n\n        \"\"\"\n        self.num_classes = num_classes\n        self.image_threshold = image_threshold\n        self.output_shape = output_shape\n        self.max_samples_per_class = max_samples_per_class\n\n    def __call__(\n        self, label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, output_shape: Sequence[int] | None = None\n    ) -> list[NdarrayOrTensor]:\n        \"\"\"\n        Args:\n            label: input data to compute the indices of every class.\n            image: if image is not None, use ``image > image_threshold`` to define valid region, and only select\n                the indices within the valid region.\n            output_shape: expected shape of output indices. 
if None, use `self.output_shape` instead.\n\n        \"\"\"\n\n        if output_shape is None:\n            output_shape = self.output_shape\n        indices: list[NdarrayOrTensor]\n        indices = map_classes_to_indices(\n            label, self.num_classes, image, self.image_threshold, self.max_samples_per_class\n        )\n        if output_shape is not None:\n            indices = [unravel_indices(cls_indices, output_shape) for cls_indices in indices]\n\n        return indices\n\n\nclass ConvertToMultiChannelBasedOnBratsClasses(Transform):\n    \"\"\"\n    Convert labels to multi channels based on `brats18 <https://www.med.upenn.edu/sbia/brats2018/data.html>`_ classes,\n    which include TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor):\n    label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion,\n    label 2 is the peritumoral edema, which is counted only under WT subregion,\n    label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        # if img has channel dim, squeeze it\n        if img.ndim == 4 and img.shape[0] == 1:\n            img = img.squeeze(0)\n\n        result = [(img == 1) | (img == 4), (img == 1) | (img == 4) | (img == 2), img == 4]\n        # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT\n        # label 4 is ET\n        return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)\n\n\nclass AddExtremePointsChannel(Randomizable, Transform):\n    \"\"\"\n    Add extreme points of label to the image as a new channel. This transform generates extreme\n    point from label and applies a gaussian filter. The pixel values in points image are rescaled\n    to range [rescale_min, rescale_max] and added as a new channel to input image. 
The algorithm is\n    described in Roth et al., Going to Extremes: Weakly Supervised Medical Image Segmentation\n    https://arxiv.org/abs/2009.11988.\n\n    This transform only supports single channel labels (1, spatial_dim1, [spatial_dim2, ...]). The\n    background ``index`` is ignored when calculating extreme points.\n\n    Args:\n        background: Class index of background label, defaults to 0.\n        pert: Random perturbation amount to add to the points, defaults to 0.0.\n\n    Raises:\n        ValueError: When no label image provided.\n        ValueError: When label image is not single channel.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, background: int = 0, pert: float = 0.0) -> None:\n        self._background = background\n        self._pert = pert\n        self._points: list[tuple[int, ...]] = []\n\n    def randomize(self, label: NdarrayOrTensor) -> None:\n        self._points = get_extreme_points(label, rand_state=self.R, background=self._background, pert=self._pert)\n\n    def __call__(\n        self,\n        img: NdarrayOrTensor,\n        label: NdarrayOrTensor | None = None,\n        sigma: Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor = 3.0,\n        rescale_min: float = -1.0,\n        rescale_max: float = 1.0,\n    ) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: the image that we want to add new channel to.\n            label: label image to get extreme points from. Shape must be\n                (1, spatial_dim1, [, spatial_dim2, ...]). Doesn't support one-hot labels.\n            sigma: if a list of values, must match the count of spatial dimensions of input data,\n                and apply every value in the list to 1 spatial dimension. 
if only 1 value provided,\n                use it for all spatial dimensions.\n            rescale_min: minimum value of output data.\n            rescale_max: maximum value of output data.\n        \"\"\"\n        if label is None:\n            raise ValueError(\"This transform requires a label array!\")\n        if label.shape[0] != 1:\n            raise ValueError(\"Only supports single channel labels!\")\n\n        # Generate extreme points\n        self.randomize(label[0, :])\n\n        points_image = extreme_points_to_image(\n            points=self._points, label=label, sigma=sigma, rescale_min=rescale_min, rescale_max=rescale_max\n        )\n        points_image, *_ = convert_to_dst_type(points_image, img)  # type: ignore\n        return concatenate((img, points_image), axis=0)\n\n\nclass TorchVision(Transform):\n    \"\"\"\n    This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args.\n    Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, name: str, *args, **kwargs) -> None:\n        \"\"\"\n        Args:\n            name: The transform name in TorchVision package.\n            args: parameters for the TorchVision transform.\n            kwargs: parameters for the TorchVision transform.\n\n        \"\"\"\n        super().__init__()\n        self.name = name\n        transform, _ = optional_import(\"torchvision.transforms\", \"0.8.0\", min_version, name=name)\n        self.trans = transform(*args, **kwargs)\n\n    def __call__(self, img: NdarrayOrTensor):\n        \"\"\"\n        Args:\n            img: PyTorch Tensor data for the TorchVision transform.\n\n        \"\"\"\n        img_t, *_ = convert_data_type(img, torch.Tensor)\n\n        out = self.trans(img_t)\n        out, *_ = convert_to_dst_type(src=out, dst=img)\n        return 
out\n\n\nclass RandTorchVision(Transform, RandomizableTrait):\n    \"\"\"\n    This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args.\n    Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, name: str, *args, **kwargs) -> None:\n        \"\"\"\n        Args:\n            name: The transform name in TorchVision package.\n            args: parameters for the TorchVision transform.\n            kwargs: parameters for the TorchVision transform.\n\n        \"\"\"\n        super().__init__()\n        self.name = name\n        transform, _ = optional_import(\"torchvision.transforms\", \"0.8.0\", min_version, name=name)\n        self.trans = transform(*args, **kwargs)\n\n    def __call__(self, img: NdarrayOrTensor):\n        \"\"\"\n        Args:\n            img: PyTorch Tensor data for the TorchVision transform.\n\n        \"\"\"\n        img_t, *_ = convert_data_type(img, torch.Tensor)\n\n        out = self.trans(img_t)\n        out, *_ = convert_to_dst_type(src=out, dst=img)\n        return out\n\n\nclass TorchIO(Transform):\n    \"\"\"\n    This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args.\n    See https://torchio.readthedocs.io/transforms/transforms.html for more details.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, name: str, *args, **kwargs) -> None:\n        \"\"\"\n        Args:\n            name: The transform name in TorchIO package.\n            args: parameters for the TorchIO transform.\n            kwargs: parameters for the TorchIO transform.\n        \"\"\"\n        super().__init__()\n        self.name = name\n        transform, _ = optional_import(\"torchio.transforms\", \"0.18.0\", min_version, name=name)\n        self.trans = transform(*args, 
**kwargs)\n\n    def __call__(self, img: NdarrayOrTensor | Mapping[Hashable, NdarrayOrTensor]):\n        \"\"\"\n        Args:\n            img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,\n                 or dict containing 4D tensors as values\n\n        \"\"\"\n        return self.trans(img)\n\n\nclass RandTorchIO(Transform, RandomizableTrait):\n    \"\"\"\n    This is a wrapper for TorchIO randomized transforms based on the specified transform name and args.\n    See https://torchio.readthedocs.io/transforms/transforms.html for more details.\n    Use this wrapper for all TorchIO transform inheriting from RandomTransform:\n    https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, name: str, *args, **kwargs) -> None:\n        \"\"\"\n        Args:\n            name: The transform name in TorchIO package.\n            args: parameters for the TorchIO transform.\n            kwargs: parameters for the TorchIO transform.\n        \"\"\"\n        super().__init__()\n        self.name = name\n        transform, _ = optional_import(\"torchio.transforms\", \"0.18.0\", min_version, name=name)\n        self.trans = transform(*args, **kwargs)\n\n    def __call__(self, img: NdarrayOrTensor | Mapping[Hashable, NdarrayOrTensor]):\n        \"\"\"\n        Args:\n            img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,\n                 or dict containing 4D tensors as values\n\n        \"\"\"\n        return self.trans(img)\n\n\nclass MapLabelValue:\n    \"\"\"\n    Utility to map label values to another set of values.\n    For example, map [3, 2, 1] to [0, 1, 2], [1, 2, 3] -> [0.5, 1.5, 2.5], [\"label3\", \"label2\", \"label1\"] -> [0, 1, 2],\n    [3.5, 2.5, 1.5] -> [\"label0\", \"label1\", \"label2\"], etc.\n    The label data must be numpy array or array-like data and the 
output data will be numpy array.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY, TransformBackends.TORCH]\n\n    def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None:\n        \"\"\"\n        Args:\n            orig_labels: original labels that map to others.\n            target_labels: expected label values, 1: 1 map to the `orig_labels`.\n            dtype: convert the output data to dtype, default to float32.\n                if dtype is from PyTorch, the transform will use the pytorch backend, else with numpy backend.\n\n        \"\"\"\n        if len(orig_labels) != len(target_labels):\n            raise ValueError(\"orig_labels and target_labels must have the same length.\")\n\n        self.orig_labels = orig_labels\n        self.target_labels = target_labels\n        self.pair = tuple((o, t) for o, t in zip(self.orig_labels, self.target_labels) if o != t)\n        type_dtype = type(dtype)\n        if getattr(type_dtype, \"__module__\", \"\") == \"torch\":\n            self.use_numpy = False\n            self.dtype = get_equivalent_dtype(dtype, data_type=torch.Tensor)\n        else:\n            self.use_numpy = True\n            self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray)\n\n    def __call__(self, img: NdarrayOrTensor):\n        if self.use_numpy:\n            img_np, *_ = convert_data_type(img, np.ndarray)\n            _out_shape = img_np.shape\n            img_flat = img_np.flatten()\n            try:\n                out_flat = img_flat.astype(self.dtype)\n            except ValueError:\n                # can't copy unchanged labels as the expected dtype is not supported, must map all the label values\n                out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype)\n            for o, t in self.pair:\n                out_flat[img_flat == o] = t\n            out_t = out_flat.reshape(_out_shape)\n        else:\n            img_t, *_ = convert_data_type(img, 
torch.Tensor)\n            out_t = img_t.detach().clone().to(self.dtype)  # type: ignore\n            for o, t in self.pair:\n                out_t[img_t == o] = t\n        out, *_ = convert_to_dst_type(src=out_t, dst=img, dtype=self.dtype)\n        return out\n\n\nclass IntensityStats(Transform):\n    \"\"\"\n    Compute statistics for the intensity values of input image and store into the metadata dictionary.\n    For example: if `ops=[lambda x: np.mean(x), \"max\"]` and `key_prefix=\"orig\"`, may generate below stats:\n    `{\"orig_custom_0\": 1.5, \"orig_max\": 3.0}`.\n\n    Args:\n        ops: expected operations to compute statistics for the intensity.\n            if a string, will map to the predefined operations, supported: [\"mean\", \"median\", \"max\", \"min\", \"std\"]\n            mapping to `np.nanmean`, `np.nanmedian`, `np.nanmax`, `np.nanmin`, `np.nanstd`.\n            if a callable function, will execute the function on input image.\n        key_prefix: the prefix to combine with `ops` name to generate the key to store the results in the\n            metadata dictionary. 
if some `ops` are callable functions, will use \"{key_prefix}_custom_{index}\"\n            as the key, where index counts from 0.\n        channel_wise: whether to compute statistics for every channel of input image separately.\n            if True, return a list of values for every operation, default to False.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(self, ops: Sequence[str | Callable], key_prefix: str, channel_wise: bool = False) -> None:\n        self.ops = ensure_tuple(ops)\n        self.key_prefix = key_prefix\n        self.channel_wise = channel_wise\n\n    def __call__(\n        self, img: NdarrayOrTensor, meta_data: dict | None = None, mask: np.ndarray | None = None\n    ) -> tuple[NdarrayOrTensor, dict]:\n        \"\"\"\n        Compute statistics for the intensity of input image.\n\n        Args:\n            img: input image to compute intensity stats.\n            meta_data: metadata dictionary to store the statistics data, if None, will create an empty dictionary.\n            mask: if not None, mask the image to extract only the interested area to compute statistics.\n                mask must have the same shape as input `img`.\n\n        \"\"\"\n        img_np, *_ = convert_data_type(img, np.ndarray)\n        if meta_data is None:\n            meta_data = {}\n\n        if mask is not None:\n            if mask.shape != img_np.shape:\n                raise ValueError(f\"mask must have the same shape as input `img`, got {mask.shape} and {img_np.shape}.\")\n            if mask.dtype != bool:\n                raise TypeError(f\"mask must be bool array, got type {mask.dtype}.\")\n            img_np = img_np[mask]\n\n        supported_ops = {\n            \"mean\": np.nanmean,\n            \"median\": np.nanmedian,\n            \"max\": np.nanmax,\n            \"min\": np.nanmin,\n            \"std\": np.nanstd,\n        }\n\n        def _compute(op: Callable, data: np.ndarray):\n            if self.channel_wise:\n     
           return [op(c) for c in data]\n            return op(data)\n\n        custom_index = 0\n        for o in self.ops:\n            if isinstance(o, str):\n                o = look_up_option(o, supported_ops.keys())\n                meta_data[self.key_prefix + \"_\" + o] = _compute(supported_ops[o], img_np)  # type: ignore\n            elif callable(o):\n                meta_data[self.key_prefix + \"_custom_\" + str(custom_index)] = _compute(o, img_np)\n                custom_index += 1\n            else:\n                raise ValueError(\"ops must be key string for predefined operations or callable function.\")\n\n        return img, meta_data\n\n\nclass ToDevice(Transform):\n    \"\"\"\n    Move PyTorch Tensor to the specified device.\n    It can help cache data into GPU and execute following logic on GPU directly.\n\n    Note:\n        If moving data to GPU device in the multi-processing workers of DataLoader, may got below CUDA error:\n        \"RuntimeError: Cannot re-initialize CUDA in forked subprocess. 
To use CUDA with multiprocessing,\n        you must use the 'spawn' start method.\"\n        So usually suggest to set `num_workers=0` in the `DataLoader` or `ThreadDataLoader`.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH]\n\n    def __init__(self, device: torch.device | str, **kwargs) -> None:\n        \"\"\"\n        Args:\n            device: target device to move the Tensor, for example: \"cuda:1\".\n            kwargs: other args for the PyTorch `Tensor.to()` API, for more details:\n                https://pytorch.org/docs/stable/generated/torch.Tensor.to.html.\n\n        \"\"\"\n        self.device = device\n        self.kwargs = kwargs\n\n    def __call__(self, img: torch.Tensor):\n        if not isinstance(img, torch.Tensor):\n            raise ValueError(\"img must be PyTorch Tensor, consider converting img by `EnsureType` transform first.\")\n\n        return img.to(self.device, **self.kwargs)\n\n\nclass CuCIM(Transform):\n    \"\"\"\n    Wrap a non-randomized cuCIM transform, defined based on the transform name and args.\n    For randomized transforms use :py:class:`monai.transforms.RandCuCIM`.\n\n    Args:\n        name: the transform name in CuCIM package\n        args: parameters for the CuCIM transform\n        kwargs: parameters for the CuCIM transform\n\n    Note:\n        CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`.\n        Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array.\n    \"\"\"\n\n    def __init__(self, name: str, *args, **kwargs) -> None:\n        super().__init__()\n        self.name = name\n        self.transform, _ = optional_import(\"cucim.core.operations.expose.transform\", name=name)\n        self.args = args\n        self.kwargs = kwargs\n\n    def __call__(self, data):\n        \"\"\"\n        Args:\n            data: a CuPy array (`cupy.ndarray`) for the cuCIM transform\n\n        Returns:\n            `cupy.ndarray`\n\n  
      \"\"\"\n        return self.transform(data, *self.args, **self.kwargs)\n\n\nclass RandCuCIM(CuCIM, RandomizableTrait):\n    \"\"\"\n    Wrap a randomized cuCIM transform, defined based on the transform name and args\n    For deterministic non-randomized transforms use :py:class:`monai.transforms.CuCIM`.\n\n    Args:\n        name: the transform name in CuCIM package.\n        args: parameters for the CuCIM transform.\n        kwargs: parameters for the CuCIM transform.\n\n    Note:\n        - CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`.\n          Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array.\n        - If the random factor of the underlying cuCIM transform is not derived from `self.R`,\n          the results may not be deterministic. See Also: :py:class:`monai.transforms.Randomizable`.\n    \"\"\"\n\n    def __init__(self, name: str, *args, **kwargs) -> None:\n        CuCIM.__init__(self, name, *args, **kwargs)\n\n\nclass AddCoordinateChannels(Transform):\n    \"\"\"\n    Appends additional channels encoding coordinates of the input. Useful when e.g. training using patch-based sampling,\n    to allow feeding of the patch's location into the network.\n\n    This can be seen as a input-only version of CoordConv:\n\n    Liu, R. et al. An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution, NeurIPS 2018.\n\n    Args:\n        spatial_dims: the spatial dimensions that are to have their coordinates encoded in a channel and\n            appended to the input image. 
E.g., `(0, 1, 2)` represents `H, W, D` dims and append three channels\n            to the input image, encoding the coordinates of the input's three spatial dimensions.\n\n    \"\"\"\n\n    backend = [TransformBackends.NUMPY]\n\n    def __init__(self, spatial_dims: Sequence[int]) -> None:\n        self.spatial_dims = spatial_dims\n\n    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: data to be transformed, assuming `img` is channel first.\n        \"\"\"\n        if max(self.spatial_dims) > img.ndim - 2 or min(self.spatial_dims) < 0:\n            raise ValueError(f\"`spatial_dims` values must be within [0, {img.ndim - 2}]\")\n\n        spatial_size = img.shape[1:]\n        coord_channels = np.array(np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_size), indexing=\"ij\"))\n        coord_channels, *_ = convert_to_dst_type(coord_channels, img)  # type: ignore\n        coord_channels = coord_channels[list(self.spatial_dims)]\n        return concatenate((img, coord_channels), axis=0)\n\n\nclass ImageFilter(Transform):\n    \"\"\"\n    Applies a convolution filter to the input image.\n\n    Args:\n        filter:\n            A string specifying the filter, a custom filter as ``torch.Tenor`` or ``np.ndarray`` or a ``nn.Module``.\n            Available options for string are: ``mean``, ``laplace``, ``elliptical``, ``sobel``, ``sharpen``, ``median``, ``gauss``\n            See below for short explanations on every filter.\n        filter_size:\n            A single integer value specifying the size of the quadratic or cubic filter.\n            Computational complexity scales to the power of 2 (2D filter) or 3 (3D filter), which\n            should be considered when choosing filter size.\n        kwargs:\n            Additional arguments passed to filter function, required by ``sobel`` and ``gauss``.\n            See below for details.\n\n    Raises:\n        ValueError: When ``filter_size`` is 
not an uneven integer\n        ValueError: When ``filter`` is an array and ``ndim`` is not in [1,2,3]\n        ValueError: When ``filter`` is an array and any dimension has an even shape\n        NotImplementedError: When ``filter`` is a string and not in ``self.supported_filters``\n        KeyError: When necessary ``kwargs`` are not passed to a filter that requires additional arguments.\n\n\n    **Mean Filtering:** ``filter='mean'``\n\n    Mean filtering can smooth edges and remove aliasing artifacts in an segmentation image.\n    See also py:func:`monai.networks.layers.simplelayers.MeanFilter`\n    Example 2D filter (5 x 5)::\n\n        [[1, 1, 1, 1, 1],\n         [1, 1, 1, 1, 1],\n         [1, 1, 1, 1, 1],\n         [1, 1, 1, 1, 1],\n         [1, 1, 1, 1, 1]]\n\n    If smoothing labels with this filter, ensure they are in one-hot format.\n\n    **Outline Detection:** ``filter='laplace'``\n\n    Laplacian filtering for outline detection in images. Can be used to transform labels to contours.\n    See also py:func:`monai.networks.layers.simplelayers.LaplaceFilter`\n\n    Example 2D filter (5x5)::\n\n        [[-1., -1., -1., -1., -1.],\n         [-1., -1., -1., -1., -1.],\n         [-1., -1., 24., -1., -1.],\n         [-1., -1., -1., -1., -1.],\n         [-1., -1., -1., -1., -1.]]\n\n\n    **Dilation:** ``filter='elliptical'``\n\n    An elliptical filter can be used to dilate labels or label-contours.\n    Example 2D filter (5x5)::\n\n        [[0., 0., 1., 0., 0.],\n         [1., 1., 1., 1., 1.],\n         [1., 1., 1., 1., 1.],\n         [1., 1., 1., 1., 1.],\n         [0., 0., 1., 0., 0.]]\n\n\n    **Edge Detection:** ``filter='sobel'``\n\n    This filter allows for additional arguments passed as ``kwargs`` during initialization.\n    See also py:func:`monai.transforms.post.SobelGradients`\n\n    *kwargs*\n\n    * ``spatial_axes``: the axes that define the direction of the gradient to be calculated.\n      It calculates the gradient along each of the provide 
axis.\n      By default it calculate the gradient for all spatial axes.\n    * ``normalize_kernels``: if normalize the Sobel kernel to provide proper gradients. Defaults to True.\n    * ``normalize_gradients``: if normalize the output gradient to 0 and 1. Defaults to False.\n    * ``padding_mode``: the padding mode of the image when convolving with Sobel kernels. Defaults to ``\"reflect\"``.\n      Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.\n      See ``torch.nn.Conv1d()`` for more information.\n    * ``dtype``: kernel data type (torch.dtype). Defaults to ``torch.float32``.\n\n\n    **Sharpening:** ``filter='sharpen'``\n\n    Sharpen an image with a 2D or 3D filter.\n    Example 2D filter (5x5)::\n\n        [[ 0.,  0., -1.,  0.,  0.],\n         [-1., -1., -1., -1., -1.],\n         [-1., -1., 17., -1., -1.],\n         [-1., -1., -1., -1., -1.],\n         [ 0.,  0., -1.,  0.,  0.]]\n\n\n    **Gaussian Smooth:** ``filter='gauss'``\n\n    Blur/smooth an image with 2D or 3D gaussian filter.\n    This filter requires additional arguments passed as ``kwargs`` during initialization.\n    See also py:func:`monai.networks.layers.simplelayers.GaussianFilter`\n\n    *kwargs*\n\n    * ``sigma``: std. 
could be a single value, or spatial_dims number of values.\n    * ``truncated``: spreads how many stds.\n    * ``approx``: discrete Gaussian kernel type, available options are \"erf\", \"sampled\", and \"scalespace\".\n\n\n    **Median Filter:** ``filter='median'``\n\n    Blur an image with 2D or 3D median filter to remove noise.\n    Useful in image preprocessing to improve results of later processing.\n    See also py:func:`monai.networks.layers.simplelayers.MedianFilter`\n\n\n    **Savitzky Golay Filter:** ``filter = 'savitzky_golay'``\n\n    Convolve a Tensor along a particular axis with a Savitzky-Golay kernel.\n    This filter requires additional arguments passed as ``kwargs`` during initialization.\n    See also py:func:`monai.networks.layers.simplelayers.SavitzkyGolayFilter`\n\n    *kwargs*\n\n    * ``order``: Order of the polynomial to fit to each window, must be less than ``window_length``.\n    * ``axis``: (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension).\n    * ``mode``: (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or\n      ``'circular'``. Default: ``'zeros'``. 
See torch.nn.Conv1d() for more information.\n\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n    supported_filters = sorted(\n        [\"mean\", \"laplace\", \"elliptical\", \"sobel\", \"sharpen\", \"median\", \"gauss\", \"savitzky_golay\"]\n    )\n\n    def __init__(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | None = None, **kwargs) -> None:\n        self._check_filter_format(filter, filter_size)\n        self._check_kwargs_are_present(filter, **kwargs)\n        self.filter = filter\n        self.filter_size = filter_size\n        self.additional_args_for_filter = kwargs\n\n    def __call__(\n        self, img: NdarrayOrTensor, meta_dict: dict | None = None, applied_operations: list | None = None\n    ) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]]\n            meta_dict: An optional dictionary with metadata\n            applied_operations: An optional list of operations that have been applied to the data\n\n        Returns:\n            A MetaTensor with the same shape as `img` and identical metadata\n        \"\"\"\n        if isinstance(img, MetaTensor):\n            meta_dict = img.meta\n            applied_operations = img.applied_operations\n\n        img_, prev_type, device = convert_data_type(img, torch.Tensor)\n        ndim = img_.ndim - 1  # assumes channel first format\n\n        if isinstance(self.filter, str):\n            self.filter = self._get_filter_from_string(self.filter, self.filter_size, ndim)  # type: ignore\n        elif isinstance(self.filter, (torch.Tensor, np.ndarray)):\n            self.filter = ApplyFilter(self.filter)\n\n        img_ = self._apply_filter(img_)\n        if meta_dict is not None or applied_operations is not None:\n            img_ = MetaTensor(img_, meta=meta_dict, applied_operations=applied_operations)\n        else:\n            img_, *_ = 
convert_data_type(img_, prev_type, device)\n        return img_\n\n    def _check_all_values_uneven(self, x: tuple) -> None:\n        for value in x:\n            if value % 2 == 0:\n                raise ValueError(f\"Only uneven filters are supported, but filter size is {x}\")\n\n    def _check_filter_format(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | None = None) -> None:\n        if isinstance(filter, str):\n            if filter != \"gauss\" and not filter_size:  # Gauss is the only filter that does not require `filter_size`\n                raise ValueError(\"`filter_size` must be specified when specifying filters by string.\")\n            if filter_size and filter_size % 2 == 0:\n                raise ValueError(\"`filter_size` should be a single uneven integer.\")\n            if filter not in self.supported_filters:\n                raise NotImplementedError(f\"{filter}. Supported filters are {self.supported_filters}.\")\n        elif isinstance(filter, (torch.Tensor, np.ndarray)):\n            if filter.ndim not in [1, 2, 3]:\n                raise ValueError(\"Only 1D, 2D, and 3D filters are supported.\")\n            self._check_all_values_uneven(filter.shape)\n        elif not isinstance(filter, (nn.Module, Transform)):\n            raise TypeError(\n                f\"{type(filter)} is not supported.\"\n                \"Supported types are `class 'str'`, `class 'torch.Tensor'`, `class 'np.ndarray'`, \"\n                \"`class 'torch.nn.modules.module.Module'`, `class 'monai.transforms.Transform'`\"\n            )\n\n    def _check_kwargs_are_present(self, filter: str | NdarrayOrTensor | nn.Module, **kwargs: Any) -> None:\n        \"\"\"\n        Perform sanity checks on the kwargs if the filter contains the required keys.\n        If the filter is ``gauss``, kwargs should contain ``sigma``.\n        If the filter is ``savitzky_golay``, kwargs should contain ``order``.\n\n        Args:\n            filter: A string 
specifying the filter, a custom filter as ``torch.Tenor`` or ``np.ndarray`` or a ``nn.Module``.\n            kwargs: additional arguments defining the filter.\n\n        Raises:\n            KeyError if the filter doesn't contain the requirement key.\n        \"\"\"\n\n        if not isinstance(filter, str):\n            return\n        if filter == \"gauss\" and \"sigma\" not in kwargs.keys():\n            raise KeyError(\"`filter='gauss', requires the additional keyword argument `sigma`\")\n        if filter == \"savitzky_golay\" and \"order\" not in kwargs.keys():\n            raise KeyError(\"`filter='savitzky_golay', requires the additional keyword argument `order`\")\n\n    def _get_filter_from_string(self, filter: str, size: int, ndim: int) -> nn.Module | Callable:\n        if filter == \"mean\":\n            return MeanFilter(ndim, size)\n        elif filter == \"laplace\":\n            return LaplaceFilter(ndim, size)\n        elif filter == \"elliptical\":\n            return EllipticalFilter(ndim, size)\n        elif filter == \"sobel\":\n            from monai.transforms.post.array import SobelGradients  # cannot import on top because of circular imports\n\n            allowed_keys = SobelGradients.__init__.__annotations__.keys()\n            kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys}\n            return SobelGradients(size, **kwargs)\n        elif filter == \"sharpen\":\n            return SharpenFilter(ndim, size)\n        elif filter == \"gauss\":\n            allowed_keys = GaussianFilter.__init__.__annotations__.keys()\n            kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys}\n            return GaussianFilter(ndim, **kwargs)\n        elif filter == \"median\":\n            return partial(median_filter, kernel_size=size, spatial_dims=ndim)\n        elif filter == \"savitzky_golay\":\n            allowed_keys = 
SavitzkyGolayFilter.__init__.__annotations__.keys()\n            kwargs = {k: v for k, v in self.additional_args_for_filter.items() if k in allowed_keys}\n            return SavitzkyGolayFilter(size, **kwargs)\n        else:\n            raise NotImplementedError(f\"Filter {filter} not implemented\")\n\n    def _apply_filter(self, img: torch.Tensor) -> torch.Tensor:\n        if isinstance(self.filter, Transform):\n            img = self.filter(img)\n        else:\n            img = self.filter(img.unsqueeze(0))  # type: ignore\n            img = img[0]  # add and remove batch dim\n        return img\n\n\nclass RandImageFilter(RandomizableTransform):\n    \"\"\"\n    Randomly apply a convolutional filter to the input data.\n\n    Args:\n        filter:\n            A string specifying the filter or a custom filter as `torch.Tenor` or `np.ndarray`.\n            Available options are: `mean`, `laplace`, `elliptical`, `gaussian``\n            See below for short explanations on every filter.\n        filter_size:\n            A single integer value specifying the size of the quadratic or cubic filter.\n            Computational complexity scales to the power of 2 (2D filter) or 3 (3D filter), which\n            should be considered when choosing filter size.\n        prob:\n            Probability the transform is applied to the data\n    \"\"\"\n\n    backend = ImageFilter.backend\n\n    def __init__(\n        self, filter: str | NdarrayOrTensor, filter_size: int | None = None, prob: float = 0.1, **kwargs\n    ) -> None:\n        super().__init__(prob)\n        self.filter = ImageFilter(filter, filter_size, **kwargs)\n\n    def __call__(self, img: NdarrayOrTensor, meta_dict: Mapping | None = None) -> NdarrayOrTensor:\n        \"\"\"\n        Args:\n            img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]]\n            meta_dict: An optional dictionary with metadata\n            kwargs: optional arguments required by specific 
filters. E.g. `sigma`if filter is `gauss`.\n                see py:func:`monai.transforms.utility.array.ImageFilter` for more details\n\n        Returns:\n            A MetaTensor with the same shape as `img` and identical metadata\n        \"\"\"\n        self.randomize(None)\n        if self._do_transform:\n            img = self.filter(img)\n        return img\n\n\nclass ApplyTransformToPoints(InvertibleTransform, Transform):\n    \"\"\"\n    Transform points between image coordinates and world coordinates.\n    The input coordinates are assumed to be in the shape (C, N, 2 or 3), where C represents the number of channels\n    and N denotes the number of points. It will return a tensor with the same shape as the input.\n\n    Args:\n        dtype: The desired data type for the output.\n        affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates\n            from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary\n            Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when\n            applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly.\n            The matrix is always converted to float64 for computation, which can be computationally\n            expensive when applied to a large number of points.\n            If None, will try to use the affine matrix from the input data.\n        invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``.\n            Typically, the affine matrix is derived from an image and represents its location in world space,\n            while the points are in world coordinates. A value of ``True`` represents transforming these\n            world space coordinates to the image's coordinate space, and ``False`` the inverse of this operation.\n        affine_lps_to_ras: Defaults to ``False``. 
Set to `True` if your point data is in the RAS coordinate system\n            or you're using `ITKReader` with `affine_lps_to_ras=True`.\n            This ensures the correct application of the affine transformation between LPS (left-posterior-superior)\n            and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine\n            matrix are in the same coordinate system.\n\n    Use Cases:\n        - Transforming points between world space and image space, and vice versa.\n        - Automatically handling inverse transformations between image space and world space.\n        - If points have an existing affine transformation, the class computes and\n          applies the required delta affine transformation.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        dtype: DtypeLike | torch.dtype | None = None,\n        affine: torch.Tensor | None = None,\n        invert_affine: bool = True,\n        affine_lps_to_ras: bool = False,\n    ) -> None:\n        self.dtype = dtype\n        self.affine = affine\n        self.invert_affine = invert_affine\n        self.affine_lps_to_ras = affine_lps_to_ras\n\n    def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tensor | None = None) -> torch.Tensor:\n        \"\"\"\n        Compute the final affine transformation matrix to apply to the point data.\n\n        Args:\n            data: Input coordinates assumed to be in the shape (C, N, 2 or 3).\n            affine: 3x3 or 4x4 affine transformation matrix.\n\n        Returns:\n            Final affine transformation matrix.\n        \"\"\"\n\n        affine = convert_data_type(affine, dtype=torch.float64)[0]\n\n        if self.affine_lps_to_ras:\n            affine = orientation_ras_lps(affine)\n\n        if self.invert_affine:\n            affine = linalg_inv(affine)\n            if applied_affine is not None:\n                affine = affine @ applied_affine\n\n        return affine\n\n    def 
transform_coordinates(\n        self, data: torch.Tensor, affine: torch.Tensor | None = None\n    ) -> tuple[torch.Tensor, dict]:\n        \"\"\"\n        Transform coordinates using an affine transformation matrix.\n\n        Args:\n            data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),\n                where C represents the number of channels and N denotes the number of points.\n            affine: 3x3 or 4x4 affine transformation matrix. The matrix is always converted to float64 for computation,\n                which can be computationally expensive when applied to a large number of points.\n\n        Returns:\n            Transformed coordinates.\n        \"\"\"\n        data = convert_to_tensor(data, track_meta=get_track_meta())\n        if affine is None and self.invert_affine:\n            raise ValueError(\"affine must be provided when invert_affine is True.\")\n        # applied_affine is the affine transformation matrix that has already been applied to the point data\n        applied_affine: torch.Tensor | None = getattr(data, \"affine\", None)\n        affine = applied_affine if affine is None else affine\n        if affine is None:\n            raise ValueError(\"affine must be provided if data does not have an affine matrix.\")\n\n        final_affine = self._compute_final_affine(affine, applied_affine)\n        out = apply_affine_to_points(data, final_affine, dtype=self.dtype)\n\n        extra_info = {\n            \"invert_affine\": self.invert_affine,\n            \"dtype\": get_dtype_string(self.dtype),\n            \"image_affine\": affine,\n            \"affine_lps_to_ras\": self.affine_lps_to_ras,\n        }\n\n        xform = orientation_ras_lps(linalg_inv(final_affine)) if self.affine_lps_to_ras else linalg_inv(final_affine)\n        meta_info = TraceableTransform.track_transform_meta(\n            data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info()\n        )\n\n        return 
out, meta_info\n\n    def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None):\n        \"\"\"\n        Args:\n            data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),\n                where C represents the number of channels and N denotes the number of points.\n            affine: A 3x3 or 4x4 affine transformation matrix, this argument will take precedence over ``self.affine``.\n        \"\"\"\n        if data.ndim != 3 or data.shape[-1] not in (2, 3):\n            raise ValueError(f\"data should be in shape (C, N, 2 or 3), got {data.shape}.\")\n        affine = self.affine if affine is None else affine\n        if affine is not None and affine.shape not in ((3, 3), (4, 4)):\n            raise ValueError(f\"affine should be in shape (3, 3) or (4, 4), got {affine.shape}.\")\n\n        out, meta_info = self.transform_coordinates(data, affine)\n\n        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out\n\n    def inverse(self, data: torch.Tensor) -> torch.Tensor:\n        transform = self.pop_transform(data)\n        inverse_transform = ApplyTransformToPoints(\n            dtype=transform[TraceKeys.EXTRA_INFO][\"dtype\"],\n            invert_affine=not transform[TraceKeys.EXTRA_INFO][\"invert_affine\"],\n            affine_lps_to_ras=transform[TraceKeys.EXTRA_INFO][\"affine_lps_to_ras\"],\n        )\n        with inverse_transform.trace_transform(False):\n            data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO][\"image_affine\"])\n\n        return data\n\n\nclass FlattenSequence(Transform, ReduceTrait):\n    \"\"\"\n    Flatten a nested sequence (list or tuple) by one level.\n    If the input is a sequence of sequences, it will flatten them into a single sequence.\n    Non-nested sequences and other data types are returned unchanged.\n\n    For example:\n\n    .. 
code-block:: python\n\n        flatten = FlattenSequence()\n        data = [[1, 2], [3, 4], [5, 6]]\n        print(flatten(data))\n        [1, 2, 3, 4, 5, 6]\n\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n    def __call__(self, data: list | tuple | Any) -> list | tuple | Any:\n        \"\"\"\n        Flatten a nested sequence by one level.\n        Args:\n            data: Input data, can be a nested sequence.\n        Returns:\n            Flattened list if input is a nested sequence, otherwise returns data unchanged.\n        \"\"\"\n        if isinstance(data, (list, tuple)):\n            if len(data) == 0:\n                return data\n            if all(isinstance(item, (list, tuple)) for item in data):\n                return [item for sublist in data for item in sublist]\n        return data\n"
  },
  {
    "path": "monai/transforms/utility/dictionary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nA collection of dictionary-based wrappers around the \"vanilla\" transforms for utility functions\ndefined in :py:class:`monai.transforms.utility.array`.\n\nClass names are ended with 'd' to denote dictionary-based transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport re\nfrom collections.abc import Callable, Hashable, Mapping, Sequence\nfrom copy import deepcopy\nfrom typing import Any, cast\n\nimport numpy as np\nimport torch\n\nfrom monai.config import DtypeLike, KeysCollection\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data.meta_tensor import MetaObj, MetaTensor\nfrom monai.data.utils import no_collation\nfrom monai.transforms.inverse import InvertibleTransform\nfrom monai.transforms.traits import MultiSampleTrait, RandomizableTrait, ReduceTrait\nfrom monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform\nfrom monai.transforms.utility.array import (\n    AddCoordinateChannels,\n    AddExtremePointsChannel,\n    ApplyTransformToPoints,\n    AsChannelLast,\n    CastToType,\n    ClassesToIndices,\n    ConvertToMultiChannelBasedOnBratsClasses,\n    CuCIM,\n    DataStats,\n    EnsureChannelFirst,\n    EnsureType,\n    FgBgToIndices,\n    FlattenSequence,\n    Identity,\n    ImageFilter,\n    IntensityStats,\n    LabelToMask,\n    Lambda,\n    MapLabelValue,\n    RemoveRepeatedChannel,\n    
RepeatChannel,\n    SimulateDelay,\n    SplitDim,\n    SqueezeDim,\n    ToCupy,\n    ToDevice,\n    ToNumpy,\n    ToPIL,\n    TorchIO,\n    TorchVision,\n    ToTensor,\n    Transpose,\n)\nfrom monai.transforms.utils import extreme_points_to_image, get_extreme_points\nfrom monai.transforms.utils_pytorch_numpy_unification import concatenate\nfrom monai.utils import ensure_tuple, ensure_tuple_rep\nfrom monai.utils.enums import PostFix, TraceKeys, TransformBackends\nfrom monai.utils.type_conversion import convert_to_dst_type\n\n__all__ = [\n    \"AddCoordinateChannelsD\",\n    \"AddCoordinateChannelsDict\",\n    \"AddCoordinateChannelsd\",\n    \"AddExtremePointsChannelD\",\n    \"AddExtremePointsChannelDict\",\n    \"AddExtremePointsChanneld\",\n    \"AsChannelLastD\",\n    \"AsChannelLastDict\",\n    \"AsChannelLastd\",\n    \"CastToTypeD\",\n    \"CastToTypeDict\",\n    \"CastToTyped\",\n    \"ConcatItemsD\",\n    \"ConcatItemsDict\",\n    \"ConcatItemsd\",\n    \"ConvertToMultiChannelBasedOnBratsClassesD\",\n    \"ConvertToMultiChannelBasedOnBratsClassesDict\",\n    \"ConvertToMultiChannelBasedOnBratsClassesd\",\n    \"CopyItemsD\",\n    \"CopyItemsDict\",\n    \"CopyItemsd\",\n    \"CuCIMd\",\n    \"CuCIMD\",\n    \"CuCIMDict\",\n    \"DataStatsD\",\n    \"DataStatsDict\",\n    \"DataStatsd\",\n    \"DeleteItemsD\",\n    \"DeleteItemsDict\",\n    \"DeleteItemsd\",\n    \"EnsureChannelFirstD\",\n    \"EnsureChannelFirstDict\",\n    \"EnsureChannelFirstd\",\n    \"EnsureTypeD\",\n    \"EnsureTypeDict\",\n    \"EnsureTyped\",\n    \"FgBgToIndicesD\",\n    \"FgBgToIndicesDict\",\n    \"FgBgToIndicesd\",\n    \"IdentityD\",\n    \"IdentityDict\",\n    \"Identityd\",\n    \"IntensityStatsd\",\n    \"IntensityStatsD\",\n    \"IntensityStatsDict\",\n    \"ImageFilterd\",\n    \"LabelToMaskD\",\n    \"LabelToMaskDict\",\n    \"LabelToMaskd\",\n    \"LambdaD\",\n    \"LambdaDict\",\n    \"Lambdad\",\n    \"MapLabelValueD\",\n    \"MapLabelValueDict\",\n    
\"MapLabelValued\",\n    \"FlattenSubKeysd\",\n    \"FlattenSubKeysD\",\n    \"FlattenSubKeysDict\",\n    \"RandCuCIMd\",\n    \"RandCuCIMD\",\n    \"RandCuCIMDict\",\n    \"RandImageFilterd\",\n    \"RandLambdaD\",\n    \"RandLambdaDict\",\n    \"RandLambdad\",\n    \"RandTorchIOd\",\n    \"RandTorchIOD\",\n    \"RandTorchIODict\",\n    \"RandTorchVisionD\",\n    \"RandTorchVisionDict\",\n    \"RandTorchVisiond\",\n    \"RemoveRepeatedChannelD\",\n    \"RemoveRepeatedChannelDict\",\n    \"RemoveRepeatedChanneld\",\n    \"RepeatChannelD\",\n    \"RepeatChannelDict\",\n    \"RepeatChanneld\",\n    \"SelectItemsD\",\n    \"SelectItemsDict\",\n    \"SelectItemsd\",\n    \"SimulateDelayD\",\n    \"SimulateDelayDict\",\n    \"SimulateDelayd\",\n    \"SplitDimD\",\n    \"SplitDimDict\",\n    \"SplitDimd\",\n    \"SqueezeDimD\",\n    \"SqueezeDimDict\",\n    \"SqueezeDimd\",\n    \"ToCupyD\",\n    \"ToCupyDict\",\n    \"ToCupyd\",\n    \"ToDeviced\",\n    \"ToDeviceD\",\n    \"ToDeviceDict\",\n    \"ToNumpyD\",\n    \"ToNumpyDict\",\n    \"ToNumpyd\",\n    \"ToPILD\",\n    \"ToPILDict\",\n    \"ToPILd\",\n    \"ToTensorD\",\n    \"ToTensorDict\",\n    \"ToTensord\",\n    \"TorchIOD\",\n    \"TorchIODict\",\n    \"TorchIOd\",\n    \"TorchVisionD\",\n    \"TorchVisionDict\",\n    \"TorchVisiond\",\n    \"Transposed\",\n    \"TransposeDict\",\n    \"TransposeD\",\n    \"ClassesToIndicesd\",\n    \"ClassesToIndicesD\",\n    \"ClassesToIndicesDict\",\n    \"ApplyTransformToPointsd\",\n    \"ApplyTransformToPointsD\",\n    \"ApplyTransformToPointsDict\",\n    \"FlattenSequenced\",\n    \"FlattenSequenceD\",\n    \"FlattenSequenceDict\",\n]\n\nDEFAULT_POST_FIX = PostFix.meta()\n\n\nclass Identityd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Identity`.\n    \"\"\"\n\n    backend = Identity.backend\n\n    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:\n        \"\"\"\n        Args:\n            keys: 
keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.identity = Identity()\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.identity(d[key])\n        return d\n\n\nclass AsChannelLastd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelLast`.\n    \"\"\"\n\n    backend = AsChannelLast.backend\n\n    def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_keys: bool = False) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            channel_dim: which dimension of input image is the channel, default is the first dimension.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.converter = AsChannelLast(channel_dim=channel_dim)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass EnsureChannelFirstd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.EnsureChannelFirst`.\n    \"\"\"\n\n    backend = EnsureChannelFirst.backend\n\n    def __init__(\n        self, keys: KeysCollection, strict_check: bool = True, allow_missing_keys: bool = False, channel_dim=None\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be 
transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            strict_check: whether to raise an error when the meta information is insufficient.\n            allow_missing_keys: don't raise exception if key is missing.\n            channel_dim: This argument can be used to specify the original channel dimension (integer) of the input array.\n                It overrides the `original_channel_dim` from provided MetaTensor input.\n                If the input array doesn't have a channel dim, this value should be ``'no_channel'``.\n                If this is set to `None`, this class relies on `img` or `meta_dict` to provide the channel dimension.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.adjuster = EnsureChannelFirst(strict_check=strict_check, channel_dim=channel_dim)\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            meta_dict = d[key].meta if isinstance(d[key], MetaTensor) else None  # type: ignore[attr-defined]\n            d[key] = self.adjuster(d[key], meta_dict)\n        return d\n\n\nclass RepeatChanneld(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`.\n    \"\"\"\n\n    backend = RepeatChannel.backend\n\n    def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            repeats: the number of repetitions for each element.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.repeater = RepeatChannel(repeats)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> 
dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.repeater(d[key])\n        return d\n\n\nclass RemoveRepeatedChanneld(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.RemoveRepeatedChannel`.\n    \"\"\"\n\n    backend = RemoveRepeatedChannel.backend\n\n    def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            repeats: the number of repetitions for each element.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.repeater = RemoveRepeatedChannel(repeats)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.repeater(d[key])\n        return d\n\n\nclass SplitDimd(MapTransform, MultiSampleTrait):\n    backend = SplitDim.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        output_postfixes: Sequence[str] | None = None,\n        dim: int = 0,\n        keepdim: bool = True,\n        update_meta: bool = True,\n        list_output: bool = False,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            output_postfixes: the postfixes to construct keys to store split data.\n                for example: if the key of input data is `pred` and split 2 classes, the output\n                data keys will be: pred_(output_postfixes[0]), pred_(output_postfixes[1])\n           
     if None, using the index number: `pred_0`, `pred_1`, ... `pred_N`.\n            dim: which dimension of input image is the channel, default to 0.\n            keepdim: if `True`, output will have singleton in the split dimension. If `False`, this\n                dimension will be squeezed.\n            update_meta: if `True`, copy `[key]_meta_dict` for each output and update affine to\n                reflect the cropped image\n            list_output: it `True`, the output will be a list of dictionaries with the same keys as original.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.output_postfixes = output_postfixes\n        self.splitter = SplitDim(dim, keepdim, update_meta)\n        self.list_output = list_output\n        if self.list_output is None and self.output_postfixes is not None:\n            raise ValueError(\"`output_postfixes` should not be provided when `list_output` is `True`.\")\n\n    def __call__(\n        self, data: Mapping[Hashable, torch.Tensor]\n    ) -> dict[Hashable, torch.Tensor] | list[dict[Hashable, torch.Tensor]]:\n        d = dict(data)\n        all_keys = list(set(self.key_iterator(d)))\n\n        if self.list_output:\n            output = []\n            results = [self.splitter(d[key]) for key in all_keys]\n            for row in zip(*results):\n                new_dict = dict(zip(all_keys, row))\n                # fill in the extra keys with unmodified data\n                for k in set(d.keys()).difference(set(all_keys)):\n                    new_dict[k] = deepcopy(d[k])\n                output.append(new_dict)\n            return output\n\n        for key in all_keys:\n            rets = self.splitter(d[key])\n            postfixes: Sequence = list(range(len(rets))) if self.output_postfixes is None else self.output_postfixes\n            if len(postfixes) != len(rets):\n                raise ValueError(f\"count of 
splits must match output_postfixes, {len(postfixes)} != {len(rets)}.\")\n            for i, r in enumerate(rets):\n                split_key = f\"{key}_{postfixes[i]}\"\n                if split_key in d:\n                    raise RuntimeError(f\"input data already contains key {split_key}.\")\n                d[split_key] = r\n        return d\n\n\nclass CastToTyped(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.CastToType`.\n    \"\"\"\n\n    backend = CastToType.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            dtype: convert image to this data type, default is `np.float32`.\n                it also can be a sequence of dtypes or torch.dtype,\n                each element corresponds to a key in ``keys``.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        self.dtype = ensure_tuple_rep(dtype, len(self.keys))\n        self.converter = CastToType()\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key, dtype in self.key_iterator(d, self.dtype):\n            d[key] = self.converter(d[key], dtype=dtype)\n\n        return d\n\n\nclass ToTensord(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ToTensor`.\n    \"\"\"\n\n    backend = ToTensor.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        dtype: torch.dtype | None = None,\n        device: torch.device | str | None = None,\n   
     wrap_sequence: bool = True,\n        track_meta: bool | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            dtype: target data content type to convert, for example: torch.float, etc.\n            device: specify the target device to put the Tensor data.\n            wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.\n                E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.\n            track_meta: if `True` convert to ``MetaTensor``, otherwise to Pytorch ``Tensor``,\n                if ``None`` behave according to return value of py:func:`monai.data.meta_obj.get_track_meta`.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.converter = ToTensor(dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n            self.push_transform(d, key)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            # Remove the applied transform\n            self.pop_transform(d, key)\n            # Create inverse transform\n            inverse_transform = ToNumpy()\n            # Apply inverse\n            d[key] = inverse_transform(d[key])\n        return d\n\n\nclass EnsureTyped(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.EnsureType`.\n\n    
Ensure the input data to be a PyTorch Tensor or numpy array, support: `numpy array`, `PyTorch Tensor`,\n    `float`, `int`, `bool`, `string` and `object` keep the original.\n    If passing a dictionary, list or tuple, still return dictionary, list or tuple and recursively convert\n    every item to the expected data type if `wrap_sequence=False`.\n\n    Note: Currently, we only convert tensor data to numpy array or scalar number in the inverse operation.\n\n    \"\"\"\n\n    backend = EnsureType.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        data_type: str = \"tensor\",\n        dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = None,\n        device: torch.device | None = None,\n        wrap_sequence: bool = True,\n        track_meta: bool | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            data_type: target data type to convert, should be \"tensor\" or \"numpy\".\n            dtype: target data content type to convert, for example: np.float32, torch.float, etc.\n                It also can be a sequence of dtype, each element corresponds to a key in ``keys``.\n            device: for Tensor data type, specify the target device.\n            wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.\n                E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.\n            track_meta: whether to convert to `MetaTensor` when `data_type` is \"tensor\".\n                If False, the output data type will be `torch.Tensor`. 
Default to the return value of `get_track_meta`.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.dtype = ensure_tuple_rep(dtype, len(self.keys))\n        self.converter = EnsureType(\n            data_type=data_type, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta\n        )\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key, dtype in self.key_iterator(d, self.dtype):\n            d[key] = self.converter(d[key], dtype)\n        return d\n\n\nclass ToNumpyd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`.\n    \"\"\"\n\n    backend = ToNumpy.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        dtype: DtypeLike = None,\n        wrap_sequence: bool = True,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            dtype: target data type when converting to numpy array.\n            wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.\n                E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.converter = ToNumpy(dtype=dtype, wrap_sequence=wrap_sequence)\n\n    def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass ToCupyd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of 
:py:class:`monai.transforms.ToCupy`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        dtype: data type specifier. It is inferred from the input by default.\n            if not None, must be an argument of `numpy.dtype`, for more details:\n            https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html.\n        wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.\n            E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`.\n        allow_missing_keys: don't raise exception if key is missing.\n    \"\"\"\n\n    backend = ToCupy.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        dtype: np.dtype | None = None,\n        wrap_sequence: bool = True,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.converter = ToCupy(dtype=dtype, wrap_sequence=wrap_sequence)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass ToPILd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ToPIL`.\n    \"\"\"\n\n    backend = ToPIL.backend\n\n    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.converter = ToPIL()\n\n    def __call__(self, data: Mapping[Hashable, Any]) -> 
dict[Hashable, Any]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass Transposed(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Transpose`.\n    \"\"\"\n\n    backend = Transpose.backend\n\n    def __init__(self, keys: KeysCollection, indices: Sequence[int] | None, allow_missing_keys: bool = False) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.transform = Transpose(indices)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.transform(d[key])\n            # if None was supplied then numpy uses range(a.ndim)[::-1]\n            indices = self.transform.indices or range(d[key].ndim)[::-1]\n            self.push_transform(d, key, extra_info={\"indices\": indices})\n        return d\n\n    def inverse(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            transform = self.get_most_recent_transform(d, key)\n            # Create inverse transform\n            fwd_indices = np.array(transform[TraceKeys.EXTRA_INFO][\"indices\"])\n            inv_indices = np.argsort(fwd_indices)\n            inverse_transform = Transpose(inv_indices.tolist())\n            # Apply inverse\n            d[key] = inverse_transform(d[key])\n            # Remove the applied transform\n            self.pop_transform(d, key)\n        return d\n\n\nclass DeleteItemsd(MapTransform):\n    \"\"\"\n    Delete specified items from data dictionary to release memory.\n    It will remove the key-values and copy the others to construct a new dictionary.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, keys: KeysCollection, sep: str = \".\", use_re: 
Sequence[bool] | bool = False) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to delete, can be \"A{sep}B{sep}C\"\n                to delete key `C` in nested dictionary, `C` can be regular expression.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            sep: the separator tag to define nested dictionary keys, default to \".\".\n            use_re: whether the specified key is a regular expression, it also can be\n                a list of bool values, mapping them to `keys`.\n        \"\"\"\n        super().__init__(keys)\n        self.sep = sep\n        self.use_re = ensure_tuple_rep(use_re, len(self.keys))\n\n    def __call__(self, data):\n\n        def _delete_item(keys, d, use_re: bool = False):\n            key = keys[0]\n            if len(keys) > 1:\n                d[key] = _delete_item(keys[1:], d[key], use_re)\n                return d\n            return {k: v for k, v in d.items() if (use_re and not re.search(key, f\"{k}\")) or (not use_re and k != key)}\n\n        d = dict(data)\n        for key, use_re in zip(cast(Sequence[str], self.keys), self.use_re):\n            d = _delete_item(key.split(self.sep), d, use_re)\n\n        return d\n\n\nclass SelectItemsd(MapTransform):\n    \"\"\"\n    Select only specified items from data dictionary to release memory.\n    It will copy the selected key-values and construct a new dictionary.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __call__(self, data):\n        return {key: data[key] for key in self.key_iterator(data)}\n\n\nclass FlattenSubKeysd(MapTransform):\n    \"\"\"\n    If an item is a dictionary, it flattens the item by moving the sub-items (defined by sub-keys) to the top level.\n    {\"pred\": {\"a\": ..., \"b\": ... }} --> {\"a\": ..., \"b\": ... }\n\n    Args:\n        keys: keys of the corresponding items to be flattened\n        sub_keys: the sub-keys of items to be flattened. 
If not provided all the sub-keys are flattened.\n        delete_keys: whether to delete the key of the items that their sub-keys are flattened. Default to True.\n        prefix: optional prefix to be added to the sub-keys when moving to the top level.\n            By default no prefix will be added.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        sub_keys: KeysCollection | None = None,\n        delete_keys: bool = True,\n        prefix: str | None = None,\n    ) -> None:\n        super().__init__(keys)\n        self.sub_keys = sub_keys\n        self.delete_keys = delete_keys\n        self.prefix = prefix\n\n    def __call__(self, data):\n        d = dict(data)\n        for key in self.key_iterator(d):\n            # set the sub-keys for the specified key\n            sub_keys = d[key].keys() if self.sub_keys is None else self.sub_keys\n\n            # move all the sub-keys to the top level\n            for sk in sub_keys:\n                # set the top-level key for the sub-key\n                sk_top = f\"{self.prefix}_{sk}\" if self.prefix else sk\n                if sk_top in d:\n                    raise ValueError(\n                        f\"'{sk_top}' already exists in the top-level keys. 
Please change `prefix` to avoid duplicity.\"\n                    )\n                d[sk_top] = d[key][sk]\n\n            # delete top level key that is flattened\n            if self.delete_keys:\n                del d[key]\n        return d\n\n\nclass SqueezeDimd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.SqueezeDim`.\n    \"\"\"\n\n    backend = SqueezeDim.backend\n\n    def __init__(\n        self, keys: KeysCollection, dim: int = 0, update_meta: bool = True, allow_missing_keys: bool = False\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            dim: dimension to be squeezed. Default: 0 (the first dimension)\n            update_meta: whether to update the meta info if the input is a metatensor. Default is ``True``.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.converter = SqueezeDim(dim=dim, update_meta=update_meta)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass DataStatsd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.DataStats`.\n    \"\"\"\n\n    backend = DataStats.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        prefix: Sequence[str] | str = \"Data\",\n        data_type: Sequence[bool] | bool = True,\n        data_shape: Sequence[bool] | bool = True,\n        value_range: Sequence[bool] | bool = True,\n        data_value: Sequence[bool] | bool = False,\n        meta_info: Sequence[bool] | bool = False,\n        additional_info: Sequence[Callable] | Callable | None = None,\n        name: 
str = \"DataStats\",\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            prefix: will be printed in format: \"{prefix} statistics\".\n                it also can be a sequence of string, each element corresponds to a key in ``keys``.\n            data_type: whether to show the type of input data.\n                it also can be a sequence of bool, each element corresponds to a key in ``keys``.\n            data_shape: whether to show the shape of input data.\n                it also can be a sequence of bool, each element corresponds to a key in ``keys``.\n            value_range: whether to show the value range of input data.\n                it also can be a sequence of bool, each element corresponds to a key in ``keys``.\n            data_value: whether to show the raw value of input data.\n                it also can be a sequence of bool, each element corresponds to a key in ``keys``.\n                a typical example is to print some properties of Nifti image: affine, pixdim, etc.\n            meta_info: whether to show the data of MetaTensor.\n                it also can be a sequence of bool, each element corresponds to a key in ``keys``.\n            additional_info: user can define callable function to extract\n                additional info from input data. 
it also can be a sequence of string, each element\n                corresponds to a key in ``keys``.\n            name: identifier of `logging.logger` to use, defaulting to \"DataStats\".\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.prefix = ensure_tuple_rep(prefix, len(self.keys))\n        self.data_type = ensure_tuple_rep(data_type, len(self.keys))\n        self.data_shape = ensure_tuple_rep(data_shape, len(self.keys))\n        self.value_range = ensure_tuple_rep(value_range, len(self.keys))\n        self.data_value = ensure_tuple_rep(data_value, len(self.keys))\n        self.meta_info = ensure_tuple_rep(meta_info, len(self.keys))\n        self.additional_info = ensure_tuple_rep(additional_info, len(self.keys))\n        self.printer = DataStats(name=name)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for (\n            key,\n            prefix,\n            data_type,\n            data_shape,\n            value_range,\n            data_value,\n            meta_info,\n            additional_info,\n        ) in self.key_iterator(\n            d,\n            self.prefix,\n            self.data_type,\n            self.data_shape,\n            self.value_range,\n            self.data_value,\n            self.meta_info,\n            self.additional_info,\n        ):\n            d[key] = self.printer(\n                d[key], prefix, data_type, data_shape, value_range, data_value, meta_info, additional_info\n            )\n        return d\n\n\nclass SimulateDelayd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.SimulateDelay`.\n    \"\"\"\n\n    backend = SimulateDelay.backend\n\n    def __init__(\n        self, keys: KeysCollection, delay_time: Sequence[float] | float = 0.0, allow_missing_keys: bool = False\n    ) -> None:\n   
     \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            delay_time: The minimum amount of time, in fractions of seconds, to accomplish this identity task.\n                It also can be a sequence of string, each element corresponds to a key in ``keys``.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.delay_time = ensure_tuple_rep(delay_time, len(self.keys))\n        self.delayer = SimulateDelay()\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key, delay_time in self.key_iterator(d, self.delay_time):\n            d[key] = self.delayer(d[key], delay_time=delay_time)\n        return d\n\n\nclass CopyItemsd(MapTransform):\n    \"\"\"\n    Copy specified items from data dictionary and save with different key names.\n    It can copy several items together and copy several times.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        times: int = 1,\n        names: KeysCollection | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            times: expected copy times, for example, if keys is \"img\", times is 3,\n                it will add 3 copies of \"img\" data to the dictionary, default to 1.\n            names: the names corresponding to the newly copied data,\n                the length should match `len(keys) x times`. 
for example, if keys is [\"img\", \"seg\"]\n                and times is 2, names can be: [\"img_1\", \"seg_1\", \"img_2\", \"seg_2\"].\n                if None, use \"{key}_{index}\" as key for copy times `N`, index from `0` to `N-1`.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        Raises:\n            ValueError: When ``times`` is nonpositive.\n            ValueError: When ``len(names)`` is not ``len(keys) * times``. Incompatible values.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        if times < 1:\n            raise ValueError(f\"times must be positive, got {times}.\")\n        self.times = times\n        names = [f\"{k}_{i}\" for k in self.keys for i in range(self.times)] if names is None else ensure_tuple(names)\n        if len(names) != (len(self.keys) * times):\n            raise ValueError(\n                \"len(names) must match len(keys) * times, \"\n                f\"got len(names)={len(names)} len(keys) * times={len(self.keys) * times}.\"\n            )\n        self.names = names\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        \"\"\"\n        Raises:\n            KeyError: When a key in ``self.names`` already exists in ``data``.\n\n        \"\"\"\n        d = dict(data)\n        key_len = len(self.keys)\n        for i in range(self.times):\n            for key, new_key in self.key_iterator(d, self.names[i * key_len : (i + 1) * key_len]):\n                if new_key in d:\n                    raise KeyError(f\"Key {new_key} already exists in data.\")\n                val = d[key]\n                d[new_key] = MetaObj.copy_items(val) if isinstance(val, (torch.Tensor, np.ndarray)) else deepcopy(val)\n        return d\n\n\nclass ConcatItemsd(MapTransform):\n    \"\"\"\n    Concatenate specified items from data dictionary together on the first dim to construct a big array.\n    Expect all the items are numpy array or PyTorch 
Tensor or MetaTensor.\n    Return the first input's meta information when items are MetaTensor.\n    \"\"\"\n\n    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]\n\n    def __init__(self, keys: KeysCollection, name: str, dim: int = 0, allow_missing_keys: bool = False) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be concatenated together.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            name: the name corresponding to the key to store the concatenated data.\n            dim: on which dimension to concatenate the items, default is 0.\n            allow_missing_keys: don't raise exception if key is missing.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.name = name\n        self.dim = dim\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        \"\"\"\n        Raises:\n            TypeError: When items in ``data`` differ in type.\n            TypeError: When the item type is not in ``Union[numpy.ndarray, torch.Tensor, MetaTensor]``.\n\n        \"\"\"\n        d = dict(data)\n        output = []\n        data_type = None\n        for key in self.key_iterator(d):\n            if data_type is None:\n                data_type = type(d[key])\n            elif not isinstance(d[key], data_type):\n                raise TypeError(\"All items in data must have the same type.\")\n            output.append(d[key])\n\n        if len(output) == 0:\n            return d\n\n        if data_type is np.ndarray:\n            d[self.name] = np.concatenate(output, axis=self.dim)\n        elif issubclass(data_type, torch.Tensor):  # type: ignore\n            d[self.name] = torch.cat(output, dim=self.dim)  # type: ignore\n        else:\n            raise TypeError(\n                f\"Unsupported data type: {data_type}, available options are (numpy.ndarray, torch.Tensor, MetaTensor).\"\n            
)\n        return d\n\n\nclass Lambdad(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.Lambda`.\n\n    For example:\n\n    .. code-block:: python\n        :emphasize-lines: 2\n\n        input_data={'image': np.zeros((10, 2, 2)), 'label': np.ones((10, 2, 2))}\n        lambd = Lambdad(keys='label', func=lambda x: x[:4, :, :])\n        print(lambd(input_data)['label'].shape)\n        (4, 2, 2)\n\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        func: Lambda/function to be applied. It also can be a sequence of Callable,\n            each element corresponds to a key in ``keys``.\n        inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`.\n            It also can be a sequence of Callable, each element corresponds to a key in ``keys``.\n        track_meta: If `False`, then standard data objects will be returned (e.g., `torch.Tensor` and `np.ndarray`)\n            as opposed to MONAI's enhanced objects. By default, this is `True`.\n        overwrite: whether to overwrite the original data in the input dictionary with lambda function output. it\n            can be bool or str, when setting to str, it will create a new key for the output and keep the value of\n            key intact. default to True. it also can be a sequence of bool or str, each element corresponds to a key\n            in ``keys``.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the\n        image's original size. 
If need these complicated information, please write a new InvertibleTransform directly.\n\n    \"\"\"\n\n    backend = Lambda.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        func: Sequence[Callable] | Callable,\n        inv_func: Sequence[Callable] | Callable = no_collation,\n        track_meta: bool = True,\n        overwrite: Sequence[bool] | bool | Sequence[str] | str = True,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.func = ensure_tuple_rep(func, len(self.keys))\n        self.inv_func = ensure_tuple_rep(inv_func, len(self.keys))\n        self.overwrite = ensure_tuple_rep(overwrite, len(self.keys))\n        self._lambd = Lambda(track_meta=track_meta)\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key, func, overwrite in self.key_iterator(d, self.func, self.overwrite):\n            ret = self._lambd(img=d[key], func=func)\n            if overwrite and isinstance(overwrite, bool):\n                d[key] = ret\n            elif isinstance(overwrite, str):\n                d[overwrite] = ret\n        return d\n\n    def inverse(self, data):\n        d = dict(data)\n        for key, overwrite in self.key_iterator(d, self.overwrite):\n            ret = self._lambd.inverse(data=d[key])\n            if overwrite:\n                d[key] = ret\n        return d\n\n\nclass RandLambdad(Lambdad, RandomizableTransform):\n    \"\"\"\n    Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` may contain random logic,\n    or randomly execute the function based on `prob`. so `CacheDataset` will not execute it and cache the results.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        func: Lambda/function to be applied. 
It also can be a sequence of Callable,\n            each element corresponds to a key in ``keys``.\n        inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`.\n            It also can be a sequence of Callable, each element corresponds to a key in ``keys``.\n        track_meta: If `False`, then standard data objects will be returned (e.g., `torch.Tensor` and `np.ndarray`)\n            as opposed to MONAI's enhanced objects. By default, this is `True`.\n        overwrite: whether to overwrite the original data in the input dictionary with lambda function output.\n            default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.\n        prob: probability of executing the random function, default to 1.0, with 100% probability to execute.\n            note that all the data specified by `keys` will share the same random probability to execute or not.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    For more details, please check :py:class:`monai.transforms.Lambdad`.\n\n    Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the\n        image's original size. 
If need these complicated information, please write a new InvertibleTransform directly.\n    \"\"\"\n\n    backend = Lambda.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        func: Sequence[Callable] | Callable,\n        inv_func: Sequence[Callable] | Callable = no_collation,\n        track_meta: bool = True,\n        overwrite: Sequence[bool] | bool = True,\n        prob: float = 1.0,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        Lambdad.__init__(\n            self=self,\n            keys=keys,\n            func=func,\n            inv_func=inv_func,\n            track_meta=track_meta,\n            overwrite=overwrite,\n            allow_missing_keys=allow_missing_keys,\n        )\n        RandomizableTransform.__init__(self=self, prob=prob, do_transform=True)\n\n    def __call__(self, data):\n        self.randomize(data)\n        d = dict(data)\n        for key, func, overwrite in self.key_iterator(d, self.func, self.overwrite):\n            ret = d[key]\n            if not isinstance(ret, MetaTensor):\n                ret = MetaTensor(ret)\n            if self._do_transform:\n                ret = self._lambd(ret, func=func)\n                self.push_transform(ret, extra_info={\"lambda_info\": self._lambd.pop_transform(ret)})\n            else:\n                self.push_transform(ret)\n            if overwrite:\n                d[key] = ret\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key, overwrite in self.key_iterator(d, self.overwrite):\n            if isinstance(d[key], MetaTensor):\n                tr = self.pop_transform(d[key])\n                if tr[TraceKeys.DO_TRANSFORM]:\n                    d[key].applied_operations.append(tr[TraceKeys.EXTRA_INFO][\"lambda_info\"])  # type: ignore\n                    ret = self._lambd.inverse(d[key])\n                    if overwrite:\n                     
   d[key] = ret\n        return d\n\n\nclass LabelToMaskd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.LabelToMask`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        select_labels: labels to generate mask from. for 1 channel label, the `select_labels`\n            is the expected label values, like: [1, 2, 3]. for One-Hot format label, the\n            `select_labels` is the expected channel indices.\n        merge_channels: whether to use `np.any()` to merge the result on channel dim.\n            if yes, will return a single channel mask with binary data.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = LabelToMask.backend\n\n    def __init__(  # pytype: disable=annotation-type-mismatch\n        self,\n        keys: KeysCollection,\n        select_labels: Sequence[int] | int,\n        merge_channels: bool = False,\n        allow_missing_keys: bool = False,\n    ) -> None:  # pytype: disable=annotation-type-mismatch\n        super().__init__(keys, allow_missing_keys)\n        self.converter = LabelToMask(select_labels=select_labels, merge_channels=merge_channels)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n\n        return d\n\n\nclass FgBgToIndicesd(MapTransform, MultiSampleTrait):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.FgBgToIndices`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        fg_postfix: postfix to save the computed foreground indices in dict.\n            for example, if computed on `label` and `postfix = \"_fg_indices\"`, the key will be 
`label_fg_indices`.\n        bg_postfix: postfix to save the computed background indices in dict.\n            for example, if computed on `label` and `postfix = \"_bg_indices\"`, the key will be `label_bg_indices`.\n        image_key: if image_key is not None, use ``label == 0 & image > image_threshold`` to determine\n            the negative sample(background). so the output items will not map to all the voxels in the label.\n        image_threshold: if enabled image_key, use ``image > image_threshold`` to determine\n            the valid image content area and select background only in this area.\n        output_shape: expected shape of output indices. if not None, unravel indices to specified shape.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = FgBgToIndices.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        fg_postfix: str = \"_fg_indices\",\n        bg_postfix: str = \"_bg_indices\",\n        image_key: str | None = None,\n        image_threshold: float = 0.0,\n        output_shape: Sequence[int] | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.fg_postfix = fg_postfix\n        self.bg_postfix = bg_postfix\n        self.image_key = image_key\n        self.converter = FgBgToIndices(image_threshold, output_shape)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        image = d[self.image_key] if self.image_key else None\n        for key in self.key_iterator(d):\n            d[str(key) + self.fg_postfix], d[str(key) + self.bg_postfix] = self.converter(d[key], image)\n\n        return d\n\n\nclass ClassesToIndicesd(MapTransform, MultiSampleTrait):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ClassesToIndices`.\n\n    Args:\n        keys: keys of the corresponding items to be 
transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        indices_postfix: postfix to save the computed indices of all classes in dict.\n            for example, if computed on `label` and `postfix = \"_cls_indices\"`, the key will be `label_cls_indices`.\n        num_classes: number of classes for argmax label, not necessary for One-Hot label.\n        image_key: if image_key is not None, use ``image > image_threshold`` to define valid region, and only select\n            the indices within the valid region.\n        image_threshold: if enabled image_key, use ``image > image_threshold`` to determine the valid image content\n            area and select only the indices of classes in this area.\n        output_shape: expected shape of output indices. if not None, unravel indices to specified shape.\n        max_samples_per_class: maximum length of indices to sample in each class to reduce memory consumption.\n            Default is None, no subsampling.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = ClassesToIndices.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        indices_postfix: str = \"_cls_indices\",\n        num_classes: int | None = None,\n        image_key: str | None = None,\n        image_threshold: float = 0.0,\n        output_shape: Sequence[int] | None = None,\n        max_samples_per_class: int | None = None,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.indices_postfix = indices_postfix\n        self.image_key = image_key\n        self.converter = ClassesToIndices(num_classes, image_threshold, output_shape, max_samples_per_class)\n\n    def __call__(self, data: Mapping[Hashable, Any]):\n        d = dict(data)\n        image = d[self.image_key] if self.image_key else None\n        for key in self.key_iterator(d):\n            d[str(key) + 
self.indices_postfix] = self.converter(d[key], image)\n\n        return d\n\n\nclass ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`.\n    Convert labels to multi channels based on brats18 classes:\n    label 1 is the necrotic and non-enhancing tumor core\n    label 2 is the peritumoral edema\n    label 4 is the GD-enhancing tumor\n    The possible classes are TC (Tumor core), WT (Whole tumor)\n    and ET (Enhancing tumor).\n    \"\"\"\n\n    backend = ConvertToMultiChannelBasedOnBratsClasses.backend\n\n    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):\n        super().__init__(keys, allow_missing_keys)\n        self.converter = ConvertToMultiChannelBasedOnBratsClasses()\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass AddExtremePointsChanneld(Randomizable, MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.AddExtremePointsChannel`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        label_key: key to label source to get the extreme points.\n        background: Class index of background label, defaults to 0.\n        pert: Random perturbation amount to add to the points, defaults to 0.0.\n        sigma: if a list of values, must match the count of spatial dimensions of input data,\n            and apply every value in the list to 1 spatial dimension. 
if only 1 value provided,\n            use it for all spatial dimensions.\n        rescale_min: minimum value of output data.\n        rescale_max: maximum value of output data.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = AddExtremePointsChannel.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        label_key: str,\n        background: int = 0,\n        pert: float = 0.0,\n        sigma: Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor = 3.0,\n        rescale_min: float = -1.0,\n        rescale_max: float = 1.0,\n        allow_missing_keys: bool = False,\n    ):\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        self.background = background\n        self.pert = pert\n        self.points: list[tuple[int, ...]] = []\n        self.label_key = label_key\n        self.sigma = sigma\n        self.rescale_min = rescale_min\n        self.rescale_max = rescale_max\n\n    def randomize(self, label: NdarrayOrTensor) -> None:\n        self.points = get_extreme_points(label, rand_state=self.R, background=self.background, pert=self.pert)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        label = d[self.label_key]\n        if label.shape[0] != 1:\n            raise ValueError(\"Only supports single channel labels!\")\n\n        # Generate extreme points\n        self.randomize(label[0, :])\n\n        for key in self.key_iterator(d):\n            img = d[key]\n            points_image = extreme_points_to_image(\n                points=self.points,\n                label=label,\n                sigma=self.sigma,\n                rescale_min=self.rescale_min,\n                rescale_max=self.rescale_max,\n            )\n            points_image, *_ = convert_to_dst_type(points_image, img)  # type: ignore\n            d[key] = concatenate([img, points_image], axis=0)\n        
return d\n\n\nclass TorchVisiond(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.TorchVision` for non-randomized transforms.\n    For randomized transforms of TorchVision use :py:class:`monai.transforms.RandTorchVisiond`.\n\n    Note:\n        As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input\n        data to be dict of PyTorch Tensors, users can easily call `ToTensord` transform to convert Numpy to Tensor.\n    \"\"\"\n\n    backend = TorchVision.backend\n\n    def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            name: The transform name in TorchVision package.\n            allow_missing_keys: don't raise exception if key is missing.\n            args: parameters for the TorchVision transform.\n            kwargs: parameters for the TorchVision transform.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.name = name\n        self.trans = TorchVision(name, *args, **kwargs)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.trans(d[key])\n        return d\n\n\nclass RandTorchVisiond(MapTransform, RandomizableTrait):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.TorchVision` for randomized transforms.\n    For deterministic non-randomized transforms of TorchVision use :py:class:`monai.transforms.TorchVisiond`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        name: The transform name in TorchVision package.\n   
     allow_missing_keys: don't raise exception if key is missing.\n        args: parameters for the TorchVision transform.\n        kwargs: parameters for the TorchVision transform.\n\n    Note:\n\n        - As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input\n          data to be dict of PyTorch Tensors. Users should call `ToTensord` transform first to convert Numpy to Tensor.\n        - This class inherits the ``Randomizable`` purely to prevent any dataset caching to skip the transform\n          computation. If the random factor of the underlying torchvision transform is not derived from `self.R`,\n          the results may not be deterministic. See Also: :py:class:`monai.transforms.Randomizable`.\n\n    \"\"\"\n\n    backend = TorchVision.backend\n\n    def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        self.name = name\n        self.trans = TorchVision(name, *args, **kwargs)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.trans(d[key])\n        return d\n\n\nclass TorchIOd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms.\n    For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`.\n    \"\"\"\n\n    backend = TorchIO.backend\n\n    def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            name: The transform name in TorchIO package.\n            allow_missing_keys: don't raise 
exception if key is missing.\n            args: parameters for the TorchIO transform.\n            kwargs: parameters for the TorchIO transform.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.name = name\n        kwargs[\"include\"] = self.keys\n\n        self.trans = TorchIO(name, *args, **kwargs)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:\n        return dict(self.trans(data))\n\n\nclass RandTorchIOd(MapTransform, RandomizableTrait):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms.\n    For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`.\n    \"\"\"\n\n    backend = TorchIO.backend\n\n    def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            name: The transform name in TorchIO package.\n            allow_missing_keys: don't raise exception if key is missing.\n            args: parameters for the TorchIO transform.\n            kwargs: parameters for the TorchIO transform.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.name = name\n        kwargs[\"include\"] = self.keys\n\n        self.trans = TorchIO(name, *args, **kwargs)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:\n        return dict(self.trans(data))\n\n\nclass MapLabelValued(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`.\n    \"\"\"\n\n    backend = MapLabelValue.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        orig_labels: Sequence,\n        target_labels: Sequence,\n        dtype: 
DtypeLike = np.float32,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            orig_labels: original labels that map to others.\n            target_labels: expected label values, 1: 1 map to the `orig_labels`.\n            dtype: convert the output data to dtype, default to float32.\n                if dtype is from PyTorch, the transform will use the pytorch backend, else with numpy backend.\n            allow_missing_keys: don't raise exception if key is missing.\n\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels, dtype=dtype)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.mapper(d[key])\n        return d\n\n\nclass IntensityStatsd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.IntensityStats`.\n    Compute statistics for the intensity values of input image and store into the metadata dictionary.\n    For example: if `ops=[lambda x: np.mean(x), \"max\"]` and `key_prefix=\"orig\"`, may generate below stats:\n    `{\"orig_custom_0\": 1.5, \"orig_max\": 3.0}`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        ops: expected operations to compute statistics for the intensity.\n            if a string, will map to the predefined operations, supported: [\"mean\", \"median\", \"max\", \"min\", \"std\"]\n            mapping to `np.nanmean`, `np.nanmedian`, `np.nanmax`, `np.nanmin`, `np.nanstd`.\n            if a callable function, will execute the function on input image.\n     
   key_prefix: the prefix to combine with `ops` name to generate the key to store the results in the\n            metadata dictionary. if some `ops` are callable functions, will use \"{key_prefix}_custom_{index}\"\n            as the key, where index counts from 0.\n        mask_keys: if not None, specify the mask array for the image to extract only the interested area to compute\n            statistics, mask must have the same shape as the image.\n            it should be a sequence of strings or None, map to the `keys`.\n        channel_wise: whether to compute statistics for every channel of input image separately.\n            if True, return a list of values for every operation, default to False.\n        meta_keys: explicitly indicate the key of the corresponding metadata dictionary.\n            used to store the computed statistics to the meta dict.\n            for example, for data with key `image`, the metadata by default is in `image_meta_dict`.\n            the metadata is a dictionary object which contains: filename, original_shape, etc.\n            it can be a sequence of string, map to the `keys`.\n            if None, will try to construct meta_keys by `key_{meta_key_postfix}`.\n        meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according\n            to the key data, default is `meta_dict`, the metadata is a dictionary object.\n            used to store the computed statistics to the meta dict.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = IntensityStats.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        ops: Sequence[str | Callable],\n        key_prefix: str,\n        mask_keys: KeysCollection | None = None,\n        channel_wise: bool = False,\n        meta_keys: KeysCollection | None = None,\n        meta_key_postfix: str = DEFAULT_POST_FIX,\n        allow_missing_keys: bool = False,\n    ) -> None:\n        
super().__init__(keys, allow_missing_keys)\n        self.stats = IntensityStats(ops=ops, key_prefix=key_prefix, channel_wise=channel_wise)\n        self.mask_keys = ensure_tuple_rep(None, len(self.keys)) if mask_keys is None else ensure_tuple(mask_keys)\n        self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)\n        if len(self.keys) != len(self.meta_keys):\n            raise ValueError(\"meta_keys should have the same length as keys.\")\n        self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))\n\n    def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key, mask_key, meta_key, meta_key_postfix in self.key_iterator(\n            d, self.mask_keys, self.meta_keys, self.meta_key_postfix\n        ):\n            meta_key = meta_key or f\"{key}_{meta_key_postfix}\"\n            d[key], d[meta_key] = self.stats(\n                img=d[key], meta_data=d.get(meta_key), mask=d.get(mask_key) if mask_key is not None else None\n            )\n        return d\n\n\nclass ToDeviced(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ToDevice`.\n    \"\"\"\n\n    backend = ToDevice.backend\n\n    def __init__(\n        self, keys: KeysCollection, device: torch.device | str, allow_missing_keys: bool = False, **kwargs\n    ) -> None:\n        \"\"\"\n        Args:\n            keys: keys of the corresponding items to be transformed.\n                See also: :py:class:`monai.transforms.compose.MapTransform`\n            device: target device to move the Tensor, for example: \"cuda:1\".\n            allow_missing_keys: don't raise exception if key is missing.\n            kwargs: other args for the PyTorch `Tensor.to()` API, for more details:\n                https://pytorch.org/docs/stable/generated/torch.Tensor.to.html.\n        \"\"\"\n        super().__init__(keys, allow_missing_keys)\n        self.converter = 
ToDevice(device=device, **kwargs)\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter(d[key])\n        return d\n\n\nclass CuCIMd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.CuCIM` for non-randomized transforms.\n    For randomized transforms of CuCIM use :py:class:`monai.transforms.RandCuCIMd`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        name: The transform name in CuCIM package.\n        allow_missing_keys: don't raise exception if key is missing.\n        args: parameters for the CuCIM transform.\n        kwargs: parameters for the CuCIM transform.\n\n    Note:\n        CuCIM transforms only work with CuPy arrays, this transform expects input data to be `cupy.ndarray`.\n        Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array.\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:\n        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)\n        self.name = name\n        self.trans = CuCIM(name, *args, **kwargs)\n\n    def __call__(self, data):\n        \"\"\"\n        Args:\n            data: Dict[Hashable, `cupy.ndarray`]\n\n        Returns:\n            Dict[Hashable, `cupy.ndarray`]\n\n        \"\"\"\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.trans(d[key])\n        return d\n\n\nclass RandCuCIMd(MapTransform, RandomizableTrait):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.CuCIM` for randomized transforms.\n    For deterministic non-randomized transforms of CuCIM use :py:class:`monai.transforms.CuCIMd`.\n\n    Args:\n        keys: keys of the 
corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        name: The transform name in CuCIM package.\n        allow_missing_keys: don't raise exception if key is missing.\n        args: parameters for the CuCIM transform.\n        kwargs: parameters for the CuCIM transform.\n\n    Note:\n        - CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`.\n          Users should call `ToCuPy` transform first to convert a numpy array or torch tensor to cupy array.\n        - This class inherits the ``Randomizable`` purely to prevent any dataset caching to skip the transform\n          computation. If the random factor of the underlying cuCIM transform is not derived from `self.R`,\n          the results may not be deterministic. See Also: :py:class:`monai.transforms.Randomizable`.\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        self.name = name\n        self.trans = CuCIM(name, *args, **kwargs)\n\n    def __call__(self, data):\n        \"\"\"\n        Args:\n            data: Dict[Hashable, `cupy.ndarray`]\n\n        Returns:\n            Dict[Hashable, `cupy.ndarray`]\n\n        \"\"\"\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.trans(d[key])\n        return d\n\n\nclass AddCoordinateChannelsd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.AddCoordinateChannels`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: :py:class:`monai.transforms.compose.MapTransform`\n        spatial_dims: the spatial dimensions that are to have their coordinates encoded in a channel and\n            appended to the input image. 
E.g., `(0, 1, 2)` represents `H, W, D` dims and append three channels\n            to the input image, encoding the coordinates of the input's three spatial dimensions.\n        allow_missing_keys: don't raise exception if key is missing.\n\n    \"\"\"\n\n    backend = AddCoordinateChannels.backend\n\n    def __init__(self, keys: KeysCollection, spatial_dims: Sequence[int], allow_missing_keys: bool = False) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.add_coordinate_channels = AddCoordinateChannels(spatial_dims=spatial_dims)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.add_coordinate_channels(d[key])\n        return d\n\n\nclass ImageFilterd(MapTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ImageFilter`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        kernel:\n            A string specifying the kernel or a custom kernel as `torch.Tenor` or `np.ndarray`.\n            Available options are: `mean`, `laplacian`, `elliptical`, `sobel_{w,h,d}``\n        kernel_size:\n            A single integer value specifying the size of the quadratic or cubic kernel.\n            Computational complexity increases exponentially with kernel_size, which\n            should be considered when choosing the kernel size.\n        allow_missing_keys:\n            Don't raise exception if key is missing.\n    \"\"\"\n\n    backend = ImageFilter.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        kernel: str | NdarrayOrTensor,\n        kernel_size: int | None = None,\n        allow_missing_keys: bool = False,\n        **kwargs,\n    ) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.filter = ImageFilter(kernel, kernel_size, 
**kwargs)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.filter(d[key])\n        return d\n\n\nclass RandImageFilterd(MapTransform, RandomizableTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.RandomFilterKernel`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        kernel:\n            A string specifying the kernel or a custom kernel as `torch.Tenor` or `np.ndarray`.\n            Available options are: `mean`, `laplacian`, `elliptical`, `sobel_{w,h,d}``\n        kernel_size:\n            A single integer value specifying the size of the quadratic or cubic kernel.\n            Computational complexity increases exponentially with kernel_size, which\n            should be considered when choosing the kernel size.\n        prob:\n            Probability the transform is applied to the data\n        allow_missing_keys:\n            Don't raise exception if key is missing.\n\n    Note:\n        - This transform does not scale output image values automatically to match the range of the input.\n          The output should be scaled by later transforms to match the input if this is desired.\n    \"\"\"\n\n    backend = ImageFilter.backend\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        kernel: str | NdarrayOrTensor,\n        kernel_size: int | None = None,\n        prob: float = 0.1,\n        allow_missing_keys: bool = False,\n        **kwargs,\n    ) -> None:\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        RandomizableTransform.__init__(self, prob)\n        self.filter = ImageFilter(kernel, kernel_size, **kwargs)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        
self.randomize(None)\n        if self._do_transform:\n            for key in self.key_iterator(d):\n                d[key] = self.filter(d[key])\n        return d\n\n\nclass ApplyTransformToPointsd(MapTransform, InvertibleTransform):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.ApplyTransformToPoints`.\n    The input coordinates are assumed to be in the shape (C, N, 2 or 3),\n    where C represents the number of channels and N denotes the number of points.\n    The output has the same shape as the input.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        refer_keys: The key of the reference item used for transformation.\n            It can directly refer to an affine or an image from which the affine can be derived. It can also be a\n            sequence of keys, in which case each refers to the affine applied to the matching points in `keys`.\n        dtype: The desired data type for the output.\n        affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates\n            from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary\n            Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when\n            applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly.\n            The matrix is always converted to float64 for computation, which can be computationally\n            expensive when applied to a large number of points.\n            If None, will try to use the affine matrix from the refer data.\n        invert_affine: Whether to invert the affine transformation matrix applied to the points. 
Defaults to ``True``.\n            Typically, the affine matrix is derived from the image, while the points are in world coordinates.\n            If you want to align the points with the image, set this to ``True``. Otherwise, set it to ``False``.\n        affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system\n            or you're using `ITKReader` with `affine_lps_to_ras=True`.\n            This ensures the correct application of the affine transformation between LPS (left-posterior-superior)\n            and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine\n            matrix are in the same coordinate system.\n        allow_missing_keys: Don't raise exception if key is missing.\n    \"\"\"\n\n    def __init__(\n        self,\n        keys: KeysCollection,\n        refer_keys: KeysCollection | None = None,\n        dtype: DtypeLike | torch.dtype = torch.float64,\n        affine: torch.Tensor | None = None,\n        invert_affine: bool = True,\n        affine_lps_to_ras: bool = False,\n        allow_missing_keys: bool = False,\n    ):\n        MapTransform.__init__(self, keys, allow_missing_keys)\n        self.refer_keys = ensure_tuple_rep(refer_keys, len(self.keys))\n        self.converter = ApplyTransformToPoints(\n            dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras\n        )\n\n    def __call__(self, data: Mapping[Hashable, torch.Tensor]):\n        d = dict(data)\n        for key, refer_key in self.key_iterator(d, self.refer_keys):\n            coords = d[key]\n            affine = None  # represents using affine given in constructor\n            if refer_key is not None:\n                if refer_key in d:\n                    refer_data = d[refer_key]\n                else:\n                    raise KeyError(f\"The refer_key '{refer_key}' is not found in the data.\")\n\n                # use the 
\"affine\" member of refer_data, or refer_data itself, as the affine matrix\n                affine = getattr(refer_data, \"affine\", refer_data)\n            d[key] = self.converter(coords, affine)\n        return d\n\n    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.converter.inverse(d[key])\n        return d\n\n\nclass FlattenSequenced(MapTransform, ReduceTrait):\n    \"\"\"\n    Dictionary-based wrapper of :py:class:`monai.transforms.FlattenSequence`.\n\n    Args:\n        keys: keys of the corresponding items to be transformed.\n            See also: monai.transforms.MapTransform\n        allow_missing_keys:\n            Don't raise exception if key is missing.\n    \"\"\"\n\n    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, **kwargs) -> None:\n        super().__init__(keys, allow_missing_keys)\n        self.flatten_sequence = FlattenSequence(**kwargs)\n\n    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:\n        d = dict(data)\n        for key in self.key_iterator(d):\n            d[key] = self.flatten_sequence(d[key])  # type: ignore[assignment]\n        return d\n\n\nRandImageFilterD = RandImageFilterDict = RandImageFilterd\nImageFilterD = ImageFilterDict = ImageFilterd\nIdentityD = IdentityDict = Identityd\nAsChannelLastD = AsChannelLastDict = AsChannelLastd\nEnsureChannelFirstD = EnsureChannelFirstDict = EnsureChannelFirstd\nRemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld\nRepeatChannelD = RepeatChannelDict = RepeatChanneld\nSplitDimD = SplitDimDict = SplitDimd\nCastToTypeD = CastToTypeDict = CastToTyped\nToTensorD = ToTensorDict = ToTensord\nEnsureTypeD = EnsureTypeDict = EnsureTyped\nToNumpyD = ToNumpyDict = ToNumpyd\nToCupyD = ToCupyDict = ToCupyd\nToPILD = ToPILDict = ToPILd\nTransposeD = TransposeDict = 
Transposed\nDeleteItemsD = DeleteItemsDict = DeleteItemsd\nSelectItemsD = SelectItemsDict = SelectItemsd\nSqueezeDimD = SqueezeDimDict = SqueezeDimd\nDataStatsD = DataStatsDict = DataStatsd\nSimulateDelayD = SimulateDelayDict = SimulateDelayd\nCopyItemsD = CopyItemsDict = CopyItemsd\nConcatItemsD = ConcatItemsDict = ConcatItemsd\nLambdaD = LambdaDict = Lambdad\nLabelToMaskD = LabelToMaskDict = LabelToMaskd\nFgBgToIndicesD = FgBgToIndicesDict = FgBgToIndicesd\nClassesToIndicesD = ClassesToIndicesDict = ClassesToIndicesd\nConvertToMultiChannelBasedOnBratsClassesD = ConvertToMultiChannelBasedOnBratsClassesDict = (\n    ConvertToMultiChannelBasedOnBratsClassesd\n)\nAddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld\nTorchIOD = TorchIODict = TorchIOd\nTorchVisionD = TorchVisionDict = TorchVisiond\nRandTorchVisionD = RandTorchVisionDict = RandTorchVisiond\nRandTorchIOD = RandTorchIODict = RandTorchIOd\nRandLambdaD = RandLambdaDict = RandLambdad\nMapLabelValueD = MapLabelValueDict = MapLabelValued\nIntensityStatsD = IntensityStatsDict = IntensityStatsd\nToDeviceD = ToDeviceDict = ToDeviced\nCuCIMD = CuCIMDict = CuCIMd\nRandCuCIMD = RandCuCIMDict = RandCuCIMd\nAddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd\nFlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd\nApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd\nFlattenSequenceD = FlattenSequenceDict = FlattenSequenced\n"
  },
  {
    "path": "monai/transforms/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport itertools\nimport random\nimport warnings\nfrom collections.abc import Callable, Hashable, Iterable, Mapping, Sequence\nfrom contextlib import contextmanager\nfrom functools import lru_cache, wraps\nfrom inspect import getmembers, isclass\nfrom typing import Any\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nimport monai\nfrom monai.config import DtypeLike, IndexSelection\nfrom monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor\nfrom monai.data.utils import to_affine_nd\nfrom monai.networks.layers import GaussianFilter\nfrom monai.networks.utils import meshgrid_ij\nfrom monai.transforms.compose import Compose\nfrom monai.transforms.transform import MapTransform, Transform, apply_transform\nfrom monai.transforms.utils_morphological_ops import erode\nfrom monai.transforms.utils_pytorch_numpy_unification import (\n    any_np_pt,\n    ascontiguousarray,\n    concatenate,\n    cumsum,\n    isfinite,\n    nonzero,\n    ravel,\n    searchsorted,\n    softplus,\n    unique,\n    unravel_index,\n    where,\n)\nfrom monai.utils import (\n    GridSampleMode,\n    GridSamplePadMode,\n    InterpolateMode,\n    NdimageMode,\n    NumpyPadMode,\n    PostFix,\n    PytorchPadMode,\n    SplineMode,\n    TraceKeys,\n    TraceStatusKeys,\n    ensure_tuple,\n    ensure_tuple_rep,\n    ensure_tuple_size,\n    
fall_back_tuple,\n    get_equivalent_dtype,\n    issequenceiterable,\n    look_up_option,\n    min_version,\n    optional_import,\n    unsqueeze_left,\n    unsqueeze_right,\n)\nfrom monai.utils.enums import TransformBackends\nfrom monai.utils.type_conversion import (\n    convert_data_type,\n    convert_to_cupy,\n    convert_to_dst_type,\n    convert_to_numpy,\n    convert_to_tensor,\n)\n\nmeasure, has_measure = optional_import(\"skimage.measure\", \"0.14.2\", min_version)\nmorphology, has_morphology = optional_import(\"skimage.morphology\")\nndimage, has_ndimage = optional_import(\"scipy.ndimage\")\ncp, has_cp = optional_import(\"cupy\")\ncp_ndarray, _ = optional_import(\"cupy\", name=\"ndarray\")\nexposure, has_skimage = optional_import(\"skimage.exposure\")\n\n__all__ = [\n    \"allow_missing_keys_mode\",\n    \"check_boundaries\",\n    \"compute_divisible_spatial_size\",\n    \"convert_applied_interp_mode\",\n    \"copypaste_arrays\",\n    \"check_non_lazy_pending_ops\",\n    \"create_control_grid\",\n    \"create_grid\",\n    \"create_rotate\",\n    \"create_scale\",\n    \"create_shear\",\n    \"create_translate\",\n    \"extreme_points_to_image\",\n    \"fill_holes\",\n    \"Fourier\",\n    \"generate_label_classes_crop_centers\",\n    \"generate_pos_neg_label_crop_centers\",\n    \"generate_spatial_bounding_box\",\n    \"get_extreme_points\",\n    \"get_largest_connected_component_mask\",\n    \"keep_merge_components_with_points\",\n    \"keep_components_with_positive_points\",\n    \"convert_points_to_disc\",\n    \"remove_small_objects\",\n    \"img_bounds\",\n    \"in_bounds\",\n    \"is_empty\",\n    \"is_positive\",\n    \"map_and_generate_sampling_centers\",\n    \"map_binary_to_indices\",\n    \"map_classes_to_indices\",\n    \"map_spatial_axes\",\n    \"rand_choice\",\n    \"rescale_array\",\n    \"rescale_array_int_max\",\n    \"rescale_instance_array\",\n    \"resize_center\",\n    \"weighted_patch_samples\",\n    \"zero_margins\",\n    
\"equalize_hist\",\n    \"get_number_image_type_conversions\",\n    \"get_transform_backends\",\n    \"print_transform_backends\",\n    \"convert_pad_mode\",\n    \"convert_to_contiguous\",\n    \"get_unique_labels\",\n    \"scale_affine\",\n    \"attach_hook\",\n    \"sync_meta_info\",\n    \"reset_ops_id\",\n    \"resolves_modes\",\n    \"has_status_keys\",\n    \"distance_transform_edt\",\n    \"soft_clip\",\n]\n\n\ndef soft_clip(\n    arr: NdarrayOrTensor,\n    sharpness_factor: float = 1.0,\n    minv: NdarrayOrTensor | float | int | None = None,\n    maxv: NdarrayOrTensor | float | int | None = None,\n    dtype: DtypeLike | torch.dtype = np.float32,\n) -> NdarrayOrTensor:\n    \"\"\"\n    Apply soft clip to the input array or tensor.\n    The intensity values will be soft clipped according to\n    f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))\n    From https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291\n\n    To perform one-sided clipping, set either minv or maxv to None.\n    Args:\n        arr: input array to clip.\n        sharpness_factor: the sharpness of the soft clip function, default to 1.\n        minv: minimum value of target clipped array.\n        maxv: maximum value of target clipped array.\n        dtype: if not None, convert input array to dtype before computation.\n\n    \"\"\"\n\n    if dtype is not None:\n        arr, *_ = convert_data_type(arr, dtype=dtype)\n\n    v = arr\n    if minv is not None:\n        v = v + softplus(-sharpness_factor * (arr - minv)) / sharpness_factor\n    if maxv is not None:\n        v = v - softplus(sharpness_factor * (arr - maxv)) / sharpness_factor\n\n    return v\n\n\ndef rand_choice(prob: float = 0.5) -> bool:\n    \"\"\"\n    Returns True if a randomly chosen number is less than or equal to `prob`, by default this is a 50/50 chance.\n    \"\"\"\n    return bool(random.random() <= prob)\n\n\ndef img_bounds(img: np.ndarray):\n    \"\"\"\n    
Returns the minimum and maximum indices of non-zero lines in axis 0 of `img`, followed by that for axis 1.\n    \"\"\"\n    ax0 = np.any(img, axis=0)\n    ax1 = np.any(img, axis=1)\n    return np.concatenate((np.where(ax0)[0][[0, -1]], np.where(ax1)[0][[0, -1]]))\n\n\ndef in_bounds(x: float, y: float, margin: float, maxx: float, maxy: float) -> bool:\n    \"\"\"\n    Returns True if (x,y) is within the rectangle (margin, margin, maxx-margin, maxy-margin).\n    \"\"\"\n    return bool(margin <= x < (maxx - margin) and margin <= y < (maxy - margin))\n\n\ndef is_empty(img: np.ndarray | torch.Tensor) -> bool:\n    \"\"\"\n    Returns True if `img` is empty, that is its maximum value is not greater than its minimum.\n    \"\"\"\n    return not (img.max() > img.min())  # use > instead of <= so that an image full of NaNs will result in True\n\n\ndef is_positive(img):\n    \"\"\"\n    Returns a boolean version of `img` where the positive values are converted into True, the other values are False.\n    \"\"\"\n    return img > 0\n\n\ndef zero_margins(img: np.ndarray, margin: int) -> bool:\n    \"\"\"\n    Returns True if the values within `margin` indices of the edges of `img` in dimensions 1 and 2 are 0.\n    \"\"\"\n    if np.any(img[:, :, :margin]) or np.any(img[:, :, -margin:]):\n        return False\n\n    return not np.any(img[:, :margin, :]) and not np.any(img[:, -margin:, :])\n\n\ndef rescale_array(\n    arr: NdarrayOrTensor,\n    minv: float | None = 0.0,\n    maxv: float | None = 1.0,\n    dtype: DtypeLike | torch.dtype = np.float32,\n) -> NdarrayOrTensor:\n    \"\"\"\n    Rescale the values of numpy array `arr` to be from `minv` to `maxv`.\n    If either `minv` or `maxv` is None, it returns `(a - min_a) / (max_a - min_a)`.\n\n    Args:\n        arr: input array to rescale.\n        minv: minimum value of target rescaled array.\n        maxv: maximum value of target rescaled array.\n        dtype: if not None, convert input array to dtype before computation.\n\n   
 \"\"\"\n    if dtype is not None:\n        arr, *_ = convert_data_type(arr, dtype=dtype)\n    mina = arr.min()\n    maxa = arr.max()\n\n    if mina == maxa:\n        return arr * minv if minv is not None else arr\n\n    norm = (arr - mina) / (maxa - mina)  # normalize the array first\n    if (minv is None) or (maxv is None):\n        return norm\n    return (norm * (maxv - minv)) + minv  # rescale by minv and maxv, which is the normalized array by default\n\n\ndef rescale_instance_array(\n    arr: np.ndarray, minv: float | None = 0.0, maxv: float | None = 1.0, dtype: DtypeLike = np.float32\n) -> np.ndarray:\n    \"\"\"\n    Rescale each array slice along the first dimension of `arr` independently.\n    \"\"\"\n    out: np.ndarray = np.zeros(arr.shape, dtype or arr.dtype)\n    for i in range(arr.shape[0]):\n        out[i] = rescale_array(arr[i], minv, maxv, dtype)\n\n    return out\n\n\ndef rescale_array_int_max(arr: np.ndarray, dtype: DtypeLike = np.uint16) -> np.ndarray:\n    \"\"\"\n    Rescale the array `arr` to be between the minimum and maximum values of the type `dtype`.\n    \"\"\"\n    info: np.iinfo = np.iinfo(dtype or arr.dtype)\n    return np.asarray(rescale_array(arr, info.min, info.max), dtype=dtype or arr.dtype)\n\n\ndef copypaste_arrays(\n    src_shape, dest_shape, srccenter: Sequence[int], destcenter: Sequence[int], dims: Sequence[int | None]\n) -> tuple[tuple[slice, ...], tuple[slice, ...]]:\n    \"\"\"\n    Calculate the slices to copy a sliced area of array in `src_shape` into array in `dest_shape`.\n\n    The area has dimensions `dims` (use 0 or None to copy everything in that dimension),\n    the source area is centered at `srccenter` index in `src` and copied into area centered at `destcenter` in `dest`.\n    The dimensions of the copied area will be clipped to fit within the\n    source and destination arrays so a smaller area may be copied than expected. 
Return value is the tuples of slice\n    objects indexing the copied area in `src`, and those indexing the copy area in `dest`.\n\n    Example\n\n    .. code-block:: python\n\n        src_shape = (6,6)\n        src = np.random.randint(0,10,src_shape)\n        dest = np.zeros_like(src)\n        srcslices, destslices = copypaste_arrays(src_shape, dest.shape, (3, 2),(2, 1),(3, 4))\n        dest[destslices] = src[srcslices]\n        print(src)\n        print(dest)\n\n        >>> [[9 5 6 6 9 6]\n             [4 3 5 6 1 2]\n             [0 7 3 2 4 1]\n             [3 0 0 1 5 1]\n             [9 4 7 1 8 2]\n             [6 6 5 8 6 7]]\n            [[0 0 0 0 0 0]\n             [7 3 2 4 0 0]\n             [0 0 1 5 0 0]\n             [4 7 1 8 0 0]\n             [0 0 0 0 0 0]\n             [0 0 0 0 0 0]]\n\n    \"\"\"\n    s_ndim = len(src_shape)\n    d_ndim = len(dest_shape)\n    srcslices = [slice(None)] * s_ndim\n    destslices = [slice(None)] * d_ndim\n\n    for i, ss, ds, sc, dc, dim in zip(range(s_ndim), src_shape, dest_shape, srccenter, destcenter, dims):\n        if dim:\n            # dimension before midpoint, clip to size fitting in both arrays\n            d1 = np.clip(dim // 2, 0, min(sc, dc))\n            # dimension after midpoint, clip to size fitting in both arrays\n            d2 = np.clip(dim // 2 + 1, 0, min(ss - sc, ds - dc))\n\n            srcslices[i] = slice(sc - d1, sc + d2)\n            destslices[i] = slice(dc - d1, dc + d2)\n\n    return tuple(srcslices), tuple(destslices)\n\n\ndef resize_center(img: np.ndarray, *resize_dims: int | None, fill_value: float = 0.0, inplace: bool = True):\n    \"\"\"\n    Resize `img` by cropping or expanding the image from the center. The `resize_dims` values are the output dimensions\n    (or None to use original dimension of `img`). If a dimension is smaller than that of `img` then the result will be\n    cropped and if larger padded with zeros, in both cases this is done relative to the center of `img`. 
The result is\n    a new image with the specified dimensions and values from `img` copied into its center.\n    \"\"\"\n\n    resize_dims = fall_back_tuple(resize_dims, img.shape)\n\n    half_img_shape = (np.asarray(img.shape) // 2).tolist()\n    half_dest_shape = (np.asarray(resize_dims) // 2).tolist()\n    srcslices, destslices = copypaste_arrays(img.shape, resize_dims, half_img_shape, half_dest_shape, resize_dims)\n\n    if not inplace:\n        dest = np.full(resize_dims, fill_value, img.dtype)  # type: ignore\n        dest[destslices] = img[srcslices]\n        return dest\n    return img[srcslices]\n\n\ndef check_non_lazy_pending_ops(\n    input_array: NdarrayOrTensor, name: None | str = None, raise_error: bool = False\n) -> None:\n    \"\"\"\n    Check whether the input array has pending operations, raise an error or warn when it has.\n\n    Args:\n        input_array: input array to be checked.\n        name: an optional name to be included in the error message.\n        raise_error: whether to raise an error, default to False, a warning message will be issued instead.\n    \"\"\"\n    if isinstance(input_array, monai.data.MetaTensor) and input_array.pending_operations:\n        msg = (\n            \"The input image is a MetaTensor and has pending operations,\\n\"\n            f\"but the function {name or ''} assumes non-lazy input, result may be incorrect.\"\n        )\n        if raise_error:\n            raise ValueError(msg)\n        warnings.warn(msg)\n\n\ndef map_and_generate_sampling_centers(\n    label: NdarrayOrTensor,\n    spatial_size: Sequence[int] | int,\n    num_samples: int,\n    label_spatial_shape: Sequence[int] | None = None,\n    num_classes: int | None = None,\n    image: NdarrayOrTensor | None = None,\n    image_threshold: float = 0.0,\n    max_samples_per_class: int | None = None,\n    ratios: list[float | int] | None = None,\n    rand_state: np.random.RandomState | None = None,\n    allow_smaller: bool = False,\n    warn: bool = 
True,\n) -> tuple[tuple]:\n    \"\"\"\n    Combine \"map_classes_to_indices\" and \"generate_label_classes_crop_centers\" functions, return crop center coordinates.\n    This calls `map_classes_to_indices` to get indices from `label`, gets the shape from `label_spatial_shape`\n    is given otherwise from the labels, calls `generate_label_classes_crop_centers`, and returns its results.\n\n    Args:\n        label: use the label data to get the indices of every class.\n        spatial_size: spatial size of the ROIs to be sampled.\n        num_samples: total sample centers to be generated.\n        label_spatial_shape: spatial shape of the original label data to unravel selected centers.\n        indices: sequence of pre-computed foreground indices of every class in 1 dimension.\n        num_classes: number of classes for argmax label, not necessary for One-Hot label.\n        image: if image is not None, only return the indices of every class that are within the valid\n            region of the image (``image > image_threshold``).\n        image_threshold: if enabled `image`, use ``image > image_threshold`` to\n            determine the valid image content area and select class indices only in this area.\n        max_samples_per_class: maximum length of indices in each class to reduce memory consumption.\n            Default is None, no subsampling.\n        ratios: ratios of every class in the label to generate crop centers, including background class.\n            if None, every class will have the same ratio to generate crop centers.\n        rand_state: numpy randomState object to align with other modules.\n        allow_smaller: if `False`, an exception will be raised if the image is smaller than\n            the requested ROI in any dimension. 
If `True`, any smaller dimensions will be set to\n            match the cropped size (i.e., no cropping in that dimension).\n        warn: if `True` prints a warning if a class is not present in the label.\n    Returns:\n        Tuple of crop centres\n    \"\"\"\n    if label is None:\n        raise ValueError(\"label must not be None.\")\n    indices = map_classes_to_indices(label, num_classes, image, image_threshold, max_samples_per_class)\n\n    if label_spatial_shape is not None:\n        _shape = label_spatial_shape\n    elif isinstance(label, monai.data.MetaTensor):\n        _shape = label.peek_pending_shape()\n    else:\n        _shape = label.shape[1:]\n\n    if _shape is None:\n        raise ValueError(\n            \"label_spatial_shape or label with a known shape must be provided to infer the output spatial shape.\"\n        )\n    centers = generate_label_classes_crop_centers(\n        spatial_size, num_samples, _shape, indices, ratios, rand_state, allow_smaller, warn\n    )\n\n    return ensure_tuple(centers)\n\n\ndef map_binary_to_indices(\n    label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, image_threshold: float = 0.0\n) -> tuple[NdarrayOrTensor, NdarrayOrTensor]:\n    \"\"\"\n    Compute the foreground and background of input label data, return the indices after fattening.\n    For example:\n    ``label = np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])``\n    ``foreground indices = np.array([1, 2, 3, 5, 6, 7])`` and ``background indices = np.array([0, 4, 8])``\n\n    Args:\n        label: use the label data to get the foreground/background information.\n        image: if image is not None, use ``label = 0 & image > image_threshold``\n            to define background. 
so the output items will not map to all the voxels in the label.\n        image_threshold: if enabled `image`, use ``image > image_threshold`` to\n            determine the valid image content area and select background only in this area.\n    \"\"\"\n    check_non_lazy_pending_ops(label, name=\"map_binary_to_indices\")\n    # Prepare fg/bg indices\n    if label.shape[0] > 1:\n        label = label[1:]  # for One-Hot format data, remove the background channel\n    label_flat = ravel(any_np_pt(label, 0))  # in case label has multiple dimensions\n    fg_indices = nonzero(label_flat)\n    if image is not None:\n        check_non_lazy_pending_ops(image, name=\"map_binary_to_indices\")\n        img_flat = ravel(any_np_pt(image > image_threshold, 0))\n        img_flat, *_ = convert_to_dst_type(img_flat, label, dtype=bool)\n        bg_indices = nonzero(img_flat & ~label_flat)\n    else:\n        bg_indices = nonzero(~label_flat)\n\n    # no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices\n    fg_indices, *_ = convert_data_type(fg_indices, device=torch.device(\"cpu\"))\n    bg_indices, *_ = convert_data_type(bg_indices, device=torch.device(\"cpu\"))\n    return fg_indices, bg_indices\n\n\ndef map_classes_to_indices(\n    label: NdarrayOrTensor,\n    num_classes: int | None = None,\n    image: NdarrayOrTensor | None = None,\n    image_threshold: float = 0.0,\n    max_samples_per_class: int | None = None,\n) -> list[NdarrayOrTensor]:\n    \"\"\"\n    Filter out indices of every class of the input label data, return the indices after fattening.\n    It can handle both One-Hot format label and Argmax format label, must provide `num_classes` for\n    Argmax label.\n\n    For example:\n    ``label = np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])`` and `num_classes=3`, will return a list\n    which contains the indices of the 3 classes:\n    ``[np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])]``\n\n    Args:\n      
  label: use the label data to get the indices of every class.\n        num_classes: number of classes for argmax label, not necessary for One-Hot label.\n        image: if image is not None, only return the indices of every class that are within the valid\n            region of the image (``image > image_threshold``).\n        image_threshold: if enabled `image`, use ``image > image_threshold`` to\n            determine the valid image content area and select class indices only in this area.\n        max_samples_per_class: maximum length of indices in each class to reduce memory consumption.\n            Default is None, no subsampling.\n\n    \"\"\"\n    check_non_lazy_pending_ops(label, name=\"map_classes_to_indices\")\n    img_flat: NdarrayOrTensor | None = None\n    if image is not None:\n        check_non_lazy_pending_ops(image, name=\"map_classes_to_indices\")\n        img_flat = ravel((image > image_threshold).any(0))  # type: ignore\n\n    # assuming the first dimension is channel\n    channels = len(label)\n\n    num_classes_: int = channels\n    if channels == 1:\n        if num_classes is None:\n            raise ValueError(\"channels==1 indicates not using One-Hot format label, must provide ``num_classes``.\")\n        num_classes_ = num_classes\n\n    indices: list[NdarrayOrTensor] = []\n    for c in range(num_classes_):\n        if channels > 1:\n            label_flat = ravel(convert_data_type(label[c], dtype=bool)[0])\n        else:\n            label_flat = ravel(label == c)\n        if img_flat is not None:\n            label_flat = img_flat & label_flat\n        # no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices\n        output_type = torch.Tensor if isinstance(label, monai.data.MetaTensor) else None\n        cls_indices: NdarrayOrTensor = convert_data_type(\n            nonzero(label_flat), output_type=output_type, device=torch.device(\"cpu\")\n        )[0]\n        if 
max_samples_per_class and len(cls_indices) > max_samples_per_class and len(cls_indices) > 1:\n            sample_id = np.round(np.linspace(0, len(cls_indices) - 1, max_samples_per_class)).astype(int)\n            indices.append(cls_indices[sample_id])\n        else:\n            indices.append(cls_indices)\n\n    return indices\n\n\ndef weighted_patch_samples(\n    spatial_size: int | Sequence[int],\n    w: NdarrayOrTensor,\n    n_samples: int = 1,\n    r_state: np.random.RandomState | None = None,\n) -> list:\n    \"\"\"\n    Computes `n_samples` of random patch sampling locations, given the sampling weight map `w` and patch `spatial_size`.\n\n    Args:\n        spatial_size: length of each spatial dimension of the patch.\n        w: weight map, the weights must be non-negative. each element denotes a sampling weight of the spatial location.\n            0 indicates no sampling.\n            The weight map shape is assumed ``(spatial_dim_0, spatial_dim_1, ..., spatial_dim_n)``.\n        n_samples: number of patch samples\n        r_state: a random state container\n\n    Returns:\n        a list of `n_samples` N-D integers representing the spatial sampling location of patches.\n\n    \"\"\"\n    check_non_lazy_pending_ops(w, name=\"weighted_patch_samples\")\n    if w is None:\n        raise ValueError(\"w must be an ND array, got None.\")\n    if r_state is None:\n        r_state = np.random.RandomState()\n    img_size = np.asarray(w.shape, dtype=int)\n    win_size = np.asarray(fall_back_tuple(spatial_size, img_size), dtype=int)\n\n    s = tuple(slice(w // 2, m - w + w // 2) if m > w else slice(m // 2, m // 2 + 1) for w, m in zip(win_size, img_size))\n    v = w[s]  # weight map in the 'valid' mode\n    v_size = v.shape\n    v = ravel(v)  # always copy\n    if (v < 0).any():\n        v -= v.min()  # shifting to non-negative\n    v = cumsum(v)\n    if not v[-1] or not isfinite(v[-1]) or v[-1] < 0:  # uniform sampling\n        idx = r_state.randint(0, len(v), 
size=n_samples)\n    else:\n        r_samples = r_state.random(n_samples)\n        r, *_ = convert_to_dst_type(r_samples, v, dtype=r_samples.dtype)\n        idx = searchsorted(v, r * v[-1], right=True)  # type: ignore\n    idx, *_ = convert_to_dst_type(idx, v, dtype=torch.int)  # type: ignore\n    # compensate 'valid' mode\n    diff = np.minimum(win_size, img_size) // 2\n    diff, *_ = convert_to_dst_type(diff, v)  # type: ignore\n    return [unravel_index(i, v_size) + diff for i in idx]\n\n\ndef correct_crop_centers(\n    centers: list[int],\n    spatial_size: Sequence[int] | int,\n    label_spatial_shape: Sequence[int],\n    allow_smaller: bool = False,\n) -> tuple[Any]:\n    \"\"\"\n    Utility to correct the crop center if the crop size and centers are not compatible with the image size.\n\n    Args:\n        centers: pre-computed crop centers of every dim, will correct based on the valid region.\n        spatial_size: spatial size of the ROIs to be sampled.\n        label_spatial_shape: spatial shape of the original label data to compare with ROI.\n        allow_smaller: if `False`, an exception will be raised if the image is smaller than\n            the requested ROI in any dimension. 
If `True`, any smaller dimensions will be set to\n            match the cropped size (i.e., no cropping in that dimension).\n\n    \"\"\"\n    spatial_size = fall_back_tuple(spatial_size, default=label_spatial_shape)\n    if any(np.subtract(label_spatial_shape, spatial_size) < 0):\n        if not allow_smaller:\n            raise ValueError(\n                \"The size of the proposed random crop ROI is larger than the image size, \"\n                f\"got ROI size {spatial_size} and label image size {label_spatial_shape} respectively.\"\n            )\n        spatial_size = tuple(min(l, s) for l, s in zip(label_spatial_shape, spatial_size))\n\n    # Select subregion to assure valid roi\n    valid_start = np.floor_divide(spatial_size, 2)\n    # add 1 for random\n    valid_end = np.subtract(label_spatial_shape + np.array(1), spatial_size / np.array(2)).astype(np.uint16)\n    # int generation to have full range on upper side, but subtract unfloored size/2 to prevent rounded range\n    # from being too high\n    for i, valid_s in enumerate(valid_start):\n        # need this because np.random.randint does not work with same start and end\n        if valid_s == valid_end[i]:\n            valid_end[i] += 1\n    valid_centers = []\n    for c, v_s, v_e in zip(centers, valid_start, valid_end):\n        center_i = min(max(c, v_s), v_e - 1)\n        valid_centers.append(int(center_i))\n    return ensure_tuple(valid_centers)\n\n\ndef generate_pos_neg_label_crop_centers(\n    spatial_size: Sequence[int] | int,\n    num_samples: int,\n    pos_ratio: float,\n    label_spatial_shape: Sequence[int],\n    fg_indices: NdarrayOrTensor,\n    bg_indices: NdarrayOrTensor,\n    rand_state: np.random.RandomState | None = None,\n    allow_smaller: bool = False,\n) -> tuple[tuple]:\n    \"\"\"\n    Generate valid sample locations based on the label with option for specifying foreground ratio\n    Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, 
W]\n\n    Args:\n        spatial_size: spatial size of the ROIs to be sampled.\n        num_samples: total sample centers to be generated.\n        pos_ratio: ratio of total locations generated that have center being foreground.\n        label_spatial_shape: spatial shape of the original label data to unravel selected centers.\n        fg_indices: pre-computed foreground indices in 1 dimension.\n        bg_indices: pre-computed background indices in 1 dimension.\n        rand_state: numpy randomState object to align with other modules.\n        allow_smaller: if `False`, an exception will be raised if the image is smaller than\n            the requested ROI in any dimension. If `True`, any smaller dimensions will be set to\n            match the cropped size (i.e., no cropping in that dimension).\n\n    Raises:\n        ValueError: When the proposed roi is larger than the image.\n        ValueError: When the foreground and background indices lengths are 0.\n\n    \"\"\"\n    if rand_state is None:\n        rand_state = np.random.random.__self__  # type: ignore\n\n    centers = []\n    fg_indices = np.asarray(fg_indices) if isinstance(fg_indices, Sequence) else fg_indices\n    bg_indices = np.asarray(bg_indices) if isinstance(bg_indices, Sequence) else bg_indices\n    if len(fg_indices) == 0 and len(bg_indices) == 0:\n        raise ValueError(\"No sampling location available.\")\n\n    if len(fg_indices) == 0 or len(bg_indices) == 0:\n        pos_ratio = 0 if len(fg_indices) == 0 else 1\n        warnings.warn(\n            f\"Num foregrounds {len(fg_indices)}, Num backgrounds {len(bg_indices)}, \"\n            f\"unable to generate class balanced samples, setting `pos_ratio` to {pos_ratio}.\"\n        )\n\n    for _ in range(num_samples):\n        indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices\n        random_int = rand_state.randint(len(indices_to_use))\n        idx = indices_to_use[random_int]\n        center = unravel_index(idx, 
label_spatial_shape).tolist()\n        # shift center to range of valid centers\n        centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller))\n\n    return ensure_tuple(centers)\n\n\ndef generate_label_classes_crop_centers(\n    spatial_size: Sequence[int] | int,\n    num_samples: int,\n    label_spatial_shape: Sequence[int],\n    indices: Sequence[NdarrayOrTensor],\n    ratios: list[float | int] | None = None,\n    rand_state: np.random.RandomState | None = None,\n    allow_smaller: bool = False,\n    warn: bool = True,\n) -> tuple[tuple]:\n    \"\"\"\n    Generate valid sample locations based on the specified ratios of label classes.\n    Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W]\n\n    Args:\n        spatial_size: spatial size of the ROIs to be sampled.\n        num_samples: total sample centers to be generated.\n        label_spatial_shape: spatial shape of the original label data to unravel selected centers.\n        indices: sequence of pre-computed foreground indices of every class in 1 dimension.\n        ratios: ratios of every class in the label to generate crop centers, including background class.\n            if None, every class will have the same ratio to generate crop centers.\n        rand_state: numpy randomState object to align with other modules.\n        allow_smaller: if `False`, an exception will be raised if the image is smaller than\n            the requested ROI in any dimension. 
If `True`, any smaller dimensions will be set to\n            match the cropped size (i.e., no cropping in that dimension).\n        warn: if `True` prints a warning if a class is not present in the label.\n\n    \"\"\"\n    if rand_state is None:\n        rand_state = np.random.random.__self__  # type: ignore\n\n    if num_samples < 1:\n        raise ValueError(f\"num_samples must be an int number and greater than 0, got {num_samples}.\")\n    ratios_: list[float | int] = list(ensure_tuple([1] * len(indices) if ratios is None else ratios))\n    if len(ratios_) != len(indices):\n        raise ValueError(\n            f\"random crop ratios must match the number of indices of classes, got {len(ratios_)} and {len(indices)}.\"\n        )\n    if any(i < 0 for i in ratios_):\n        raise ValueError(f\"ratios should not contain negative number, got {ratios_}.\")\n\n    for i, array in enumerate(indices):\n        if len(array) == 0:\n            if ratios_[i] != 0:\n                ratios_[i] = 0\n                if warn:\n                    warnings.warn(\n                        f\"no available indices of class {i} to crop, setting the crop ratio of this class to zero.\"\n                    )\n\n    centers = []\n    classes = rand_state.choice(len(ratios_), size=num_samples, p=np.asarray(ratios_) / np.sum(ratios_))\n    for i in classes:\n        # randomly select the indices of a class based on the ratios\n        indices_to_use = indices[i]\n        random_int = rand_state.randint(len(indices_to_use))\n        center = unravel_index(indices_to_use[random_int], label_spatial_shape).tolist()\n        # shift center to range of valid centers\n        centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller))\n\n    return ensure_tuple(centers)\n\n\ndef create_grid(\n    spatial_size: Sequence[int],\n    spacing: Sequence[float] | None = None,\n    homogeneous: bool = True,\n    dtype: DtypeLike | torch.dtype = float,\n    device: 
torch.device | None = None,\n    backend=TransformBackends.NUMPY,\n) -> NdarrayOrTensor:\n    \"\"\"\n    compute a `spatial_size` mesh.\n\n        - when ``homogeneous=True``, the output shape is (N+1, dim_size_1, dim_size_2, ..., dim_size_N)\n        - when ``homogeneous=False``, the output shape is (N, dim_size_1, dim_size_2, ..., dim_size_N)\n\n    Args:\n        spatial_size: spatial size of the grid.\n        spacing: same len as ``spatial_size``, defaults to 1.0 (dense grid).\n        homogeneous: whether to make homogeneous coordinates.\n        dtype: output grid data type, defaults to `float`.\n        device: device to compute and store the output (when the backend is \"torch\").\n        backend: APIs to use, ``numpy`` or ``torch``.\n\n    \"\"\"\n    _backend = look_up_option(backend, TransformBackends)\n    _dtype = dtype or float\n    if _backend == TransformBackends.NUMPY:\n        return _create_grid_numpy(spatial_size, spacing, homogeneous, _dtype)  # type: ignore\n    if _backend == TransformBackends.TORCH:\n        return _create_grid_torch(spatial_size, spacing, homogeneous, _dtype, device)  # type: ignore\n    raise ValueError(f\"backend {backend} is not supported\")\n\n\ndef _create_grid_numpy(\n    spatial_size: Sequence[int],\n    spacing: Sequence[float] | None = None,\n    homogeneous: bool = True,\n    dtype: DtypeLike | torch.dtype = float,\n):\n    \"\"\"\n    compute a `spatial_size` mesh with the numpy API.\n    \"\"\"\n    spacing = spacing or tuple(1.0 for _ in spatial_size)\n    ranges = [np.linspace(-(d - 1.0) / 2.0 * s, (d - 1.0) / 2.0 * s, int(d)) for d, s in zip(spatial_size, spacing)]\n    coords = np.asarray(np.meshgrid(*ranges, indexing=\"ij\"), dtype=get_equivalent_dtype(dtype, np.ndarray))\n    if not homogeneous:\n        return coords\n    return np.concatenate([coords, np.ones_like(coords[:1])])\n\n\ndef _create_grid_torch(\n    spatial_size: Sequence[int],\n    spacing: Sequence[float] | None = None,\n    homogeneous: 
bool = True,\n    dtype=torch.float32,\n    device: torch.device | None = None,\n):\n    \"\"\"\n    compute a `spatial_size` mesh with the torch API.\n    \"\"\"\n    spacing = spacing or tuple(1.0 for _ in spatial_size)\n    ranges = [\n        torch.linspace(\n            -(d - 1.0) / 2.0 * s,\n            (d - 1.0) / 2.0 * s,\n            int(d),\n            device=device,\n            dtype=get_equivalent_dtype(dtype, torch.Tensor),\n        )\n        for d, s in zip(spatial_size, spacing)\n    ]\n    coords = meshgrid_ij(*ranges)\n    if not homogeneous:\n        return torch.stack(coords)\n    return torch.stack([*coords, torch.ones_like(coords[0])])\n\n\ndef create_control_grid(\n    spatial_shape: Sequence[int],\n    spacing: Sequence[float],\n    homogeneous: bool = True,\n    dtype: DtypeLike = float,\n    device: torch.device | None = None,\n    backend=TransformBackends.NUMPY,\n):\n    \"\"\"\n    control grid with two additional points in each direction\n    \"\"\"\n    torch_backend = look_up_option(backend, TransformBackends) == TransformBackends.TORCH\n    ceil_func: Callable = torch.ceil if torch_backend else np.ceil  # type: ignore\n    grid_shape = []\n    for d, s in zip(spatial_shape, spacing):\n        d = torch.as_tensor(d, device=device) if torch_backend else int(d)  # type: ignore\n        if d % 2 == 0:\n            grid_shape.append(ceil_func((d - 1.0) / (2.0 * s) + 0.5) * 2.0 + 2.0)\n        else:\n            grid_shape.append(ceil_func((d - 1.0) / (2.0 * s)) * 2.0 + 3.0)\n    return create_grid(\n        spatial_size=grid_shape, spacing=spacing, homogeneous=homogeneous, dtype=dtype, device=device, backend=backend\n    )\n\n\ndef create_rotate(\n    spatial_dims: int,\n    radians: Sequence[float] | float,\n    device: torch.device | None = None,\n    backend: str = TransformBackends.NUMPY,\n) -> NdarrayOrTensor:\n    \"\"\"\n    create a 2D or 3D rotation matrix\n\n    Args:\n        spatial_dims: {``2``, ``3``} spatial rank\n        
radians: rotation radians\n            when spatial_dims == 3, the `radians` sequence corresponds to\n            rotation in the 1st, 2nd, and 3rd dim respectively.\n        device: device to compute and store the output (when the backend is \"torch\").\n        backend: APIs to use, ``numpy`` or ``torch``.\n\n    Raises:\n        ValueError: When ``radians`` is empty.\n        ValueError: When ``spatial_dims`` is not one of [2, 3].\n\n    \"\"\"\n    _backend = look_up_option(backend, TransformBackends)\n    if _backend == TransformBackends.NUMPY:\n        return _create_rotate(\n            spatial_dims=spatial_dims, radians=radians, sin_func=np.sin, cos_func=np.cos, eye_func=np.eye\n        )\n    if _backend == TransformBackends.TORCH:\n        return _create_rotate(\n            spatial_dims=spatial_dims,\n            radians=radians,\n            sin_func=lambda th: torch.sin(torch.as_tensor(th, dtype=torch.float32, device=device)),\n            cos_func=lambda th: torch.cos(torch.as_tensor(th, dtype=torch.float32, device=device)),\n            eye_func=lambda rank: torch.eye(rank, device=device),\n        )\n    raise ValueError(f\"backend {backend} is not supported\")\n\n\ndef _create_rotate(\n    spatial_dims: int,\n    radians: Sequence[float] | float,\n    sin_func: Callable = np.sin,\n    cos_func: Callable = np.cos,\n    eye_func: Callable = np.eye,\n) -> NdarrayOrTensor:\n    radians = ensure_tuple(radians)\n    if spatial_dims == 2:\n        if len(radians) >= 1:\n            sin_, cos_ = sin_func(radians[0]), cos_func(radians[0])\n            out = eye_func(3)\n            out[0, 0], out[0, 1] = cos_, -sin_\n            out[1, 0], out[1, 1] = sin_, cos_\n            return out  # type: ignore\n        raise ValueError(\"radians must be non empty.\")\n\n    if spatial_dims == 3:\n        affine = None\n        if len(radians) >= 1:\n            sin_, cos_ = sin_func(radians[0]), cos_func(radians[0])\n            affine = eye_func(4)\n            
affine[1, 1], affine[1, 2] = cos_, -sin_\n            affine[2, 1], affine[2, 2] = sin_, cos_\n        if len(radians) >= 2:\n            sin_, cos_ = sin_func(radians[1]), cos_func(radians[1])\n            if affine is None:\n                raise ValueError(\"Affine should be a matrix.\")\n            _affine = eye_func(4)\n            _affine[0, 0], _affine[0, 2] = cos_, sin_\n            _affine[2, 0], _affine[2, 2] = -sin_, cos_\n            affine = affine @ _affine\n        if len(radians) >= 3:\n            sin_, cos_ = sin_func(radians[2]), cos_func(radians[2])\n            if affine is None:\n                raise ValueError(\"Affine should be a matrix.\")\n            _affine = eye_func(4)\n            _affine[0, 0], _affine[0, 1] = cos_, -sin_\n            _affine[1, 0], _affine[1, 1] = sin_, cos_\n            affine = affine @ _affine\n        if affine is None:\n            raise ValueError(\"radians must be non empty.\")\n        return affine  # type: ignore\n\n    raise ValueError(f\"Unsupported spatial_dims: {spatial_dims}, available options are [2, 3].\")\n\n\ndef create_shear(\n    spatial_dims: int,\n    coefs: Sequence[float] | float,\n    device: torch.device | None = None,\n    backend=TransformBackends.NUMPY,\n) -> NdarrayOrTensor:\n    \"\"\"\n    create a shearing matrix\n\n    Args:\n        spatial_dims: spatial rank\n        coefs: shearing factors, a tuple of 2 floats for 2D, a tuple of 6 floats for 3D),\n            take a 3D affine as example::\n\n                [\n                    [1.0, coefs[0], coefs[1], 0.0],\n                    [coefs[2], 1.0, coefs[3], 0.0],\n                    [coefs[4], coefs[5], 1.0, 0.0],\n                    [0.0, 0.0, 0.0, 1.0],\n                ]\n\n        device: device to compute and store the output (when the backend is \"torch\").\n        backend: APIs to use, ``numpy`` or ``torch``.\n\n    Raises:\n        NotImplementedError: When ``spatial_dims`` is not one of [2, 3].\n\n    \"\"\"\n    
_backend = look_up_option(backend, TransformBackends)\n    if _backend == TransformBackends.NUMPY:\n        return _create_shear(spatial_dims=spatial_dims, coefs=coefs, eye_func=np.eye)\n    if _backend == TransformBackends.TORCH:\n        return _create_shear(\n            spatial_dims=spatial_dims, coefs=coefs, eye_func=lambda rank: torch.eye(rank, device=device)\n        )\n    raise ValueError(f\"backend {backend} is not supported\")\n\n\ndef _create_shear(spatial_dims: int, coefs: Sequence[float] | float, eye_func=np.eye) -> NdarrayOrTensor:\n    if spatial_dims == 2:\n        coefs = ensure_tuple_size(coefs, dim=2, pad_val=0.0)\n        out = eye_func(3)\n        out[0, 1], out[1, 0] = coefs[0], coefs[1]\n        return out  # type: ignore\n    if spatial_dims == 3:\n        coefs = ensure_tuple_size(coefs, dim=6, pad_val=0.0)\n        out = eye_func(4)\n        out[0, 1], out[0, 2] = coefs[0], coefs[1]\n        out[1, 0], out[1, 2] = coefs[2], coefs[3]\n        out[2, 0], out[2, 1] = coefs[4], coefs[5]\n        return out  # type: ignore\n    raise NotImplementedError(\"Currently only spatial_dims in [2, 3] are supported.\")\n\n\ndef create_scale(\n    spatial_dims: int,\n    scaling_factor: Sequence[float] | float,\n    device: torch.device | str | None = None,\n    backend=TransformBackends.NUMPY,\n) -> NdarrayOrTensor:\n    \"\"\"\n    create a scaling matrix\n\n    Args:\n        spatial_dims: spatial rank\n        scaling_factor: scaling factors for every spatial dim, defaults to 1.\n        device: device to compute and store the output (when the backend is \"torch\").\n        backend: APIs to use, ``numpy`` or ``torch``.\n    \"\"\"\n    _backend = look_up_option(backend, TransformBackends)\n    if _backend == TransformBackends.NUMPY:\n        return _create_scale(spatial_dims=spatial_dims, scaling_factor=scaling_factor, array_func=np.diag)\n    if _backend == TransformBackends.TORCH:\n        return _create_scale(\n            
spatial_dims=spatial_dims,\n            scaling_factor=scaling_factor,\n            array_func=lambda x: torch.diag(torch.as_tensor(x, device=device)),\n        )\n    raise ValueError(f\"backend {backend} is not supported\")\n\n\ndef _create_scale(spatial_dims: int, scaling_factor: Sequence[float] | float, array_func=np.diag) -> NdarrayOrTensor:\n    scaling_factor = ensure_tuple_size(scaling_factor, dim=spatial_dims, pad_val=1.0)\n    return array_func(scaling_factor[:spatial_dims] + (1.0,))  # type: ignore\n\n\ndef create_translate(\n    spatial_dims: int,\n    shift: Sequence[float] | float,\n    device: torch.device | None = None,\n    backend=TransformBackends.NUMPY,\n) -> NdarrayOrTensor:\n    \"\"\"\n    create a translation matrix\n\n    Args:\n        spatial_dims: spatial rank\n        shift: translate pixel/voxel for every spatial dim, defaults to 0.\n        device: device to compute and store the output (when the backend is \"torch\").\n        backend: APIs to use, ``numpy`` or ``torch``.\n    \"\"\"\n    _backend = look_up_option(backend, TransformBackends)\n    spatial_dims = int(spatial_dims)\n    if _backend == TransformBackends.NUMPY:\n        return _create_translate(spatial_dims=spatial_dims, shift=shift, eye_func=np.eye, array_func=np.asarray)\n    if _backend == TransformBackends.TORCH:\n        return _create_translate(\n            spatial_dims=spatial_dims,\n            shift=shift,\n            eye_func=lambda x: torch.eye(torch.as_tensor(x), device=device),  # type: ignore\n            array_func=lambda x: torch.as_tensor(x, device=device),\n        )\n    raise ValueError(f\"backend {backend} is not supported\")\n\n\ndef _create_translate(\n    spatial_dims: int, shift: Sequence[float] | float, eye_func=np.eye, array_func=np.asarray\n) -> NdarrayOrTensor:\n    shift = ensure_tuple(shift)\n    affine = eye_func(spatial_dims + 1)\n    for i, a in enumerate(shift[:spatial_dims]):\n        affine[i, spatial_dims] = a\n    return 
array_func(affine)  # type: ignore\n\n\ndef generate_spatial_bounding_box(\n    img: NdarrayOrTensor,\n    select_fn: Callable = is_positive,\n    channel_indices: IndexSelection | None = None,\n    margin: Sequence[int] | int = 0,\n    allow_smaller: bool = False,\n) -> tuple[list[int], list[int]]:\n    \"\"\"\n    Generate the spatial bounding box of foreground in the image with start-end positions (inclusive).\n    Users can define arbitrary function to select expected foreground from the whole image or specified channels.\n    And it can also add margin to every dim of the bounding box.\n    The output format of the coordinates is:\n\n        [1st_spatial_dim_start, 2nd_spatial_dim_start, ..., Nth_spatial_dim_start],\n        [1st_spatial_dim_end, 2nd_spatial_dim_end, ..., Nth_spatial_dim_end]\n\n    This function returns [0, 0, ...], [0, 0, ...] if there's no positive intensity.\n\n    Args:\n        img: a \"channel-first\" image of shape (C, spatial_dim1[, spatial_dim2, ...]) to generate bounding box from.\n        select_fn: function to select expected foreground, default is to select values > 0.\n        channel_indices: if defined, select foreground only on the specified channels\n            of image. if None, select foreground on the whole image.\n        margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.\n        allow_smaller: when computing box size with `margin`, whether to allow the image edges to be smaller than the\n            final box edges. If `True`, the bounding boxes edges are aligned with the input image edges, if `False`,\n            the bounding boxes edges are aligned with the final box edges. 
Default to `False`.\n            The default value is changed from `True` to `False` in v1.5.0.\n\n    \"\"\"\n    check_non_lazy_pending_ops(img, name=\"generate_spatial_bounding_box\")\n    spatial_size = img.shape[1:]\n    data = img[list(ensure_tuple(channel_indices))] if channel_indices is not None else img\n    data = select_fn(data).any(0)\n    ndim = len(data.shape)\n    margin = ensure_tuple_rep(margin, ndim)\n    for m in margin:\n        if m < 0:\n            raise ValueError(f\"margin value should not be negative number, got {margin}.\")\n\n    box_start = [0] * ndim\n    box_end = [0] * ndim\n\n    for di, ax in enumerate(itertools.combinations(reversed(range(ndim)), ndim - 1)):\n        dt = data\n        if len(ax) != 0:\n            dt = any_np_pt(dt, ax)\n\n        if not dt.any():\n            # if no foreground, return all zero bounding box coords\n            return [0] * ndim, [0] * ndim\n\n        arg_max = where(dt == dt.max())[0]\n        min_d = arg_max[0] - margin[di]\n        max_d = arg_max[-1] + margin[di] + 1\n        if allow_smaller:\n            min_d = max(min_d, 0)\n            max_d = min(max_d, spatial_size[di])\n\n        box_start[di] = min_d.detach().cpu().item() if isinstance(min_d, torch.Tensor) else min_d\n        box_end[di] = max_d.detach().cpu().item() if isinstance(max_d, torch.Tensor) else max_d\n\n    return box_start, box_end\n\n\ndef get_largest_connected_component_mask(\n    img: NdarrayTensor, connectivity: int | None = None, num_components: int = 1\n) -> NdarrayTensor:\n    \"\"\"\n    Gets the largest connected component mask of an image.\n\n    Args:\n        img: Image to get largest connected component from. Shape is (spatial_dim1 [, spatial_dim2, ...])\n        connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.\n            Accepted values are ranging from  1 to input.ndim. If ``None``, a full\n            connectivity of ``input.ndim`` is used. 
for more details:\n            https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.\n        num_components: The number of largest components to preserve.\n    \"\"\"\n    # use skimage/cucim.skimage and np/cp depending on whether packages are\n    # available and input is non-cpu torch.tensor\n    skimage, has_cucim = optional_import(\"cucim.skimage\")\n    use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device != torch.device(\"cpu\")\n    if use_cp:\n        img_ = convert_to_cupy(img.short())  # type: ignore\n        label = skimage.measure.label\n        lib = cp\n    else:\n        if not has_measure:\n            raise RuntimeError(\"Skimage.measure required.\")\n        img_, *_ = convert_data_type(img, np.ndarray)\n        label = measure.label\n        lib = np\n\n    # features will be an image -- 0 for background and then each different\n    # feature will have its own index.\n    features, num_features = label(img_, connectivity=connectivity, return_num=True)\n    # if num features less than max desired, nothing to do.\n    if num_features <= num_components:\n        out = img_.astype(bool)\n    else:\n        # ignore background\n        nonzeros = features[lib.nonzero(features)]\n        # get number voxels per feature (bincount). argsort[::-1] to get indices\n        # of largest components.\n        features_to_keep = lib.argsort(lib.bincount(nonzeros))[::-1]\n        # only keep the first n non-background indices\n        features_to_keep = features_to_keep[:num_components]\n        # generate labelfield. 
True if in list of features to keep\n        out = lib.isin(features, features_to_keep)\n\n    return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0]\n\n\ndef keep_merge_components_with_points(\n    img_pos: NdarrayTensor,\n    img_neg: NdarrayTensor,\n    point_coords: NdarrayTensor,\n    point_labels: NdarrayTensor,\n    pos_val: Sequence[int] = (1, 3),\n    neg_val: Sequence[int] = (0, 2),\n    margins: int = 3,\n) -> NdarrayTensor:\n    \"\"\"\n    Keep connected regions of img_pos and img_neg that include the positive points and\n    negative points separately. The function is used for merging automatic results with interactive\n    results in VISTA3D.\n\n    Args:\n        img_pos: bool type tensor, shape [B, 1, H, W, D], where B means the foreground masks from a single 3D image.\n        img_neg: same format as img_pos but corresponds to negative points.\n        pos_val: positive point label values.\n        neg_val: negative point label values.\n        point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points.\n        point_labels: the label of each point, shape [B, N].\n        margins: include points outside of the region but within the margin.\n    \"\"\"\n\n    cucim_skimage, has_cucim = optional_import(\"cucim.skimage\")\n\n    use_cp = has_cp and has_cucim and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device(\"cpu\")\n    if use_cp:\n        img_pos_ = convert_to_cupy(img_pos.short())  # type: ignore\n        img_neg_ = convert_to_cupy(img_neg.short())  # type: ignore\n        label = cucim_skimage.measure.label\n        lib = cp\n    else:\n        if not has_measure:\n            raise RuntimeError(\"skimage.measure required.\")\n        img_pos_, *_ = convert_data_type(img_pos, np.ndarray)\n        img_neg_, *_ = convert_data_type(img_neg, np.ndarray)\n        # for skimage.measure.label, the input must be bool type\n        if img_pos_.dtype != bool or img_neg_.dtype != bool:\n    
        raise ValueError(\"img_pos and img_neg must be bool type.\")\n        label = measure.label\n        lib = np\n\n    features_pos, _ = label(img_pos_, connectivity=3, return_num=True)\n    features_neg, _ = label(img_neg_, connectivity=3, return_num=True)\n\n    outs = np.zeros_like(img_pos_)\n    for bs in range(point_coords.shape[0]):\n        for i, p in enumerate(point_coords[bs]):\n            if point_labels[bs, i] in pos_val:\n                features = features_pos\n            elif point_labels[bs, i] in neg_val:\n                features = features_neg\n            else:\n                # if -1 padding point, skip\n                continue\n            for margin in range(margins):\n                if isinstance(p, np.ndarray):\n                    x, y, z = np.round(p).astype(int).tolist()\n                else:\n                    x, y, z = p.float().round().int().tolist()\n                l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3])\n                t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2])\n                f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1])\n                if (features[bs, 0, l:r, t:d, f:b] > 0).any():\n                    index = features[bs, 0, l:r, t:d, f:b].max()\n                    outs[[bs]] += lib.isin(features[[bs]], index)\n                    break\n    outs[outs > 1] = 1\n    return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0]\n\n\ndef keep_components_with_positive_points(\n    img: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor\n) -> torch.Tensor:\n    \"\"\"\n    Keep connected regions that include the positive points. Used for point-only inference postprocessing to remove\n    regions without positive points.\n    Args:\n        img: [1, B, H, W, D]. Output prediction from VISTA3D. Value is before sigmoid and contain NaN value.\n        point_coords: [B, N, 3]. 
Point click coordinates\n        point_labels: [B, N]. Point click labels.\n    \"\"\"\n    if not has_measure:\n        raise RuntimeError(\"skimage.measure required.\")\n    outs = torch.zeros_like(img)\n    for c in range(len(point_coords)):\n        if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()):\n            # skip if no positive points.\n            continue\n        coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist()\n        not_nan_mask = ~torch.isnan(img[0, c])\n        img_ = torch.nan_to_num(img[0, c] > 0, 0)\n        img_, *_ = convert_data_type(img_, np.ndarray)  # type: ignore\n        label = measure.label\n        features = label(img_, connectivity=3)\n        pos_mask = torch.from_numpy(img_).to(img.device) > 0\n        # if num features less than max desired, nothing to do.\n        features = torch.from_numpy(features).to(img.device)\n        # generate a map with all pos points\n        idx = []\n        for p in coords:\n            idx.append(features[round(p[0]), round(p[1]), round(p[2])].item())\n        idx = list(set(idx))\n        for i in idx:\n            if i == 0:\n                continue\n            outs[0, c] += features == i\n        outs = outs > 0\n        # find negative mean value\n        fill_in = img[0, c][torch.logical_and(~outs[0, c], not_nan_mask)].mean()\n        img[0, c][torch.logical_and(pos_mask, ~outs[0, c])] = fill_in\n    return img\n\n\ndef convert_points_to_disc(\n    image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False\n):\n    \"\"\"\n    Convert a 3D point coordinates into image mask. The returned mask has the same spatial\n    size as `image_size` while the batch dimension is the same as 'point' batch dimension.\n    The point is converted to a mask ball with radius defined by `radius`. 
The output\n    contains two channels each for negative (first channel) and positive points.\n\n    Args:\n        image_size: The output size of the converted mask. It should be a 3D tuple.\n        point: [B, N, 3], 3D point coordinates.\n        point_label: [B, N], 0 or 2 means negative points, 1 or 3 means positive points.\n        radius: disc ball radius size.\n        disc: If true, use regular disc, otherwise use gaussian.\n    \"\"\"\n    masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device)\n    _array = [\n        torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3)\n    ]\n    coord_rows, coord_cols, coord_z = torch.meshgrid(_array[0], _array[1], _array[2])\n    # [1, 3, h, w, d] -> [b, 2, 3, h, w, d]\n    coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6)\n    coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1)\n    for b, n in np.ndindex(*point.shape[:2]):\n        point_bn = unsqueeze_right(point[b, n], 4)\n        if point_label[b, n] > -1:\n            channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1\n            pow_diff = torch.pow(coords[b, channel] - point_bn, 2)\n            if disc:\n                masks[b, channel] += pow_diff.sum(0) < radius**2\n            else:\n                masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2))\n    return masks\n\n\ndef sample_points_from_label(\n    labels: Tensor,\n    label_set: Sequence[int],\n    max_ppoint: int = 1,\n    max_npoint: int = 0,\n    device: torch.device | str | None = \"cpu\",\n    use_center: bool = False,\n):\n    \"\"\"Sample points from labels.\n\n    Args:\n        labels: [1, 1, H, W, D]\n        label_set: local index, must match values in labels.\n        max_ppoint: maximum positive point samples.\n        max_npoint: maximum negative point samples.\n        device: returned tensor device.\n        
use_center: whether to sample points from center.\n\n    Returns:\n        point: point coordinates of [B, N, 3]. B equals to the length of label_set.\n        point_label: [B, N], always 0 for negative, 1 for positive.\n    \"\"\"\n    if not labels.shape[0] == 1:\n        raise ValueError(\"labels must have batch size 1.\")\n\n    if device is None:\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    labels = labels[0, 0]\n    unique_labels = labels.unique().cpu().numpy().tolist()\n    _point = []\n    _point_label = []\n    for id in label_set:\n        if id in unique_labels:\n            plabels = labels == int(id)\n            nlabels = ~plabels\n            _plabels = get_largest_connected_component_mask(erode(plabels.unsqueeze(0).unsqueeze(0))[0, 0])\n            plabelpoints = torch.nonzero(_plabels).to(device)\n            if len(plabelpoints) == 0:\n                plabelpoints = torch.nonzero(plabels).to(device)\n            nlabelpoints = torch.nonzero(nlabels).to(device)\n            num_p = min(len(plabelpoints), max_ppoint)\n            num_n = min(len(nlabelpoints), max_npoint)\n            pad = max_ppoint + max_npoint - num_p - num_n\n            if use_center:\n                pmean = plabelpoints.float().mean(0)\n                pdis = ((plabelpoints - pmean) ** 2).sum(-1)\n                _, sorted_indices_tensor = torch.sort(pdis)\n                sorted_indices = sorted_indices_tensor.cpu().tolist()\n            else:\n                sorted_indices = list(range(len(plabelpoints)))\n                random.shuffle(sorted_indices)\n            _point.append(\n                torch.stack(\n                    [plabelpoints[sorted_indices[i]] for i in range(num_p)]\n                    + random.choices(nlabelpoints, k=num_n)\n                    + [torch.tensor([0, 0, 0], device=device)] * pad\n                )\n            )\n            _point_label.append(torch.tensor([1] * num_p + [0] * num_n + [-1] * 
pad).to(device))\n        else:\n            # pad the background labels\n            _point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device))\n            _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1)\n    point = torch.stack(_point)\n    point_label = torch.stack(_point_label)\n\n    return point, point_label\n\n\ndef remove_small_objects(\n    img: NdarrayTensor,\n    min_size: int = 64,\n    connectivity: int = 1,\n    independent_channels: bool = True,\n    by_measure: bool = False,\n    pixdim: Sequence[float] | float | np.ndarray | None = None,\n) -> NdarrayTensor:\n    \"\"\"\n    Use `skimage.morphology.remove_small_objects` to remove small objects from images.\n    See: https://scikit-image.org/docs/dev/api/skimage.morphology.html#remove-small-objects.\n\n    Data should be one-hotted.\n\n    Args:\n        img: image to process. Expected shape: C, H,W,[D]. Expected to only have singleton channel dimension,\n            i.e., not be one-hotted. Converted to type int.\n        min_size: objects smaller than this size are removed.\n        connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.\n            Accepted values are ranging from  1 to input.ndim. If ``None``, a full\n            connectivity of ``input.ndim`` is used. For more details refer to linked scikit-image\n            documentation.\n        independent_channels: Whether to consider each channel independently.\n        by_measure: Whether the specified min_size is in number of voxels. if this is True then min_size\n            represents a surface area or volume value of whatever units your image is in (mm^3, cm^2, etc.)\n            default is False.\n        pixdim: the pixdim of the input image. 
if a single number, this is used for all axes.\n            If a sequence of numbers, the length of the sequence must be equal to the image dimensions.\n    \"\"\"\n    # if all equal to one value, no need to call skimage\n    if len(unique(img)) == 1:\n        return img\n\n    if not has_morphology:\n        raise RuntimeError(\"Skimage required.\")\n\n    if by_measure:\n        sr = len(img.shape[1:])\n        if isinstance(img, monai.data.MetaTensor):\n            _pixdim = img.pixdim\n        elif pixdim is not None:\n            _pixdim = ensure_tuple_rep(pixdim, sr)\n        else:\n            warnings.warn(\"`img` is not of type MetaTensor and `pixdim` is None, assuming affine to be identity.\")\n            _pixdim = (1.0,) * sr\n        voxel_volume = np.prod(np.array(_pixdim))\n        if voxel_volume == 0:\n            warnings.warn(\"Invalid `pixdim` value detected, set it to 1. Please verify the pixdim settings.\")\n            voxel_volume = 1\n        min_size = np.ceil(min_size / voxel_volume)\n    elif pixdim is not None:\n        warnings.warn(\"`pixdim` is specified but not in use when computing the volume.\")\n\n    img_np: np.ndarray\n    img_np, *_ = convert_data_type(img, np.ndarray)\n\n    # morphology.remove_small_objects assumes them to be independent by default\n    # else, convert to foreground vs background, remove small objects, then convert\n    # back by multiplying the output by the input\n    if not independent_channels:\n        img_np = img_np > 0\n    else:\n        # if binary, convert to boolean, else int\n        img_np = img_np.astype(bool if img_np.max() <= 1 else np.int32)\n\n    out_np = morphology.remove_small_objects(img_np, min_size, connectivity)\n    out, *_ = convert_to_dst_type(out_np, img)\n\n    # convert back by multiplying\n    if not independent_channels:\n        out = img * out  # type: ignore\n    return out\n\n\ndef get_unique_labels(img: NdarrayOrTensor, is_onehot: bool, discard: int | Iterable[int] | 
None = None) -> set[int]:\n    \"\"\"Get list of non-background labels in an image.\n\n    Args:\n        img: Image to be processed. Shape should be [C, W, H, [D]] with C=1 if not onehot else `num_classes`.\n        is_onehot: Boolean as to whether input image is one-hotted. If one-hotted, only return channels with\n            a sum greater than zero.\n        discard: Can be used to remove labels (e.g., background). Can be any value, sequence of values, or\n            `None` (nothing is discarded).\n\n    Returns:\n        Set of labels\n    \"\"\"\n    applied_labels: set[int]\n    n_channels = img.shape[0]\n    if is_onehot:\n        applied_labels = {i for i, s in enumerate(img) if s.sum() > 0}\n    else:\n        if n_channels != 1:\n            raise ValueError(f\"If input not one-hotted, should only be 1 channel, got {n_channels}.\")\n        applied_labels = set(unique(img).tolist())\n    if discard is not None:\n        for i in ensure_tuple(discard):\n            applied_labels.discard(i)\n    return applied_labels\n\n\ndef fill_holes(\n    img_arr: np.ndarray, applied_labels: Iterable[int] | None = None, connectivity: int | None = None\n) -> np.ndarray:\n    \"\"\"\n    Fill the holes in the provided image.\n\n    The label 0 will be treated as background and the enclosed holes will be set to the neighboring class label.\n    What is considered to be an enclosed hole is defined by the connectivity.\n    Holes on the edge are always considered to be open (not enclosed).\n\n    Note:\n\n        The performance of this method heavily depends on the number of labels.\n        It is a bit faster if the list of `applied_labels` is provided.\n        Limiting the number of `applied_labels` results in a big decrease in processing time.\n\n        If the image is one-hot-encoded, then the `applied_labels` need to match the channel index.\n\n    Args:\n        img_arr: numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].\n        applied_labels: Labels for which to fill holes. 
Defaults to None,\n            that is filling holes for all labels.\n        connectivity: Maximum number of orthogonal hops to\n            consider a pixel/voxel as a neighbor. Accepted values are ranging from  1 to input.ndim.\n            Defaults to a full connectivity of ``input.ndim``.\n\n    Returns:\n        numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].\n    \"\"\"\n    channel_axis = 0\n    num_channels = img_arr.shape[channel_axis]\n    is_one_hot = num_channels > 1\n    spatial_dims = img_arr.ndim - 1\n    structure = ndimage.generate_binary_structure(spatial_dims, connectivity or spatial_dims)\n\n    # Get labels if not provided. Exclude background label.\n    applied_labels = set(applied_labels) if applied_labels is not None else get_unique_labels(img_arr, is_one_hot)\n    background_label = 0\n    applied_labels.discard(background_label)\n\n    for label in applied_labels:\n        tmp = np.zeros(img_arr.shape[1:], dtype=bool)\n        ndimage.binary_dilation(\n            tmp,\n            structure=structure,\n            iterations=-1,\n            mask=np.logical_not(img_arr[label]) if is_one_hot else img_arr[0] != label,\n            origin=0,\n            border_value=1,\n            output=tmp,\n        )\n        if is_one_hot:\n            img_arr[label] = np.logical_not(tmp)\n        else:\n            img_arr[0, np.logical_not(tmp)] = label\n\n    return img_arr\n\n\ndef get_extreme_points(\n    img: NdarrayOrTensor, rand_state: np.random.RandomState | None = None, background: int = 0, pert: float = 0.0\n) -> list[tuple[int, ...]]:\n    \"\"\"\n    Generate extreme points from an image. These are used to generate initial segmentation\n    for annotation models. An optional perturbation can be passed to simulate user clicks.\n\n    Args:\n        img:\n            Image to generate extreme points from. 
Expected Shape is ``(spatial_dim1, [, spatial_dim2, ...])``.\n        rand_state: `np.random.RandomState` object used to select random indices.\n        background: Value to be consider as background, defaults to 0.\n        pert: Random perturbation amount to add to the points, defaults to 0.0.\n\n    Returns:\n        A list of extreme points, its length is equal to 2 * spatial dimension of input image.\n        The output format of the coordinates is:\n\n        [1st_spatial_dim_min, 1st_spatial_dim_max, 2nd_spatial_dim_min, ..., Nth_spatial_dim_max]\n\n    Raises:\n        ValueError: When the input image does not have any foreground pixel.\n    \"\"\"\n    check_non_lazy_pending_ops(img, name=\"get_extreme_points\")\n    if rand_state is None:\n        rand_state = np.random.random.__self__  # type: ignore\n    indices = where(img != background)\n    if np.size(indices[0]) == 0:\n        raise ValueError(\"get_extreme_points: no foreground object in mask!\")\n\n    def _get_point(val, dim):\n        \"\"\"\n        Select one of the indices within slice containing val.\n\n        Args:\n            val : value for comparison\n            dim : dimension in which to look for value\n        \"\"\"\n        idx = where(indices[dim] == val)[0]\n        idx = idx.cpu() if isinstance(idx, torch.Tensor) else idx\n        idx = rand_state.choice(idx) if rand_state is not None else idx\n        pt = []\n        for j in range(img.ndim):\n            # add +- pert to each dimension\n            val = int(indices[j][idx] + 2.0 * pert * (rand_state.rand() if rand_state is not None else 0.5 - 0.5))\n            val = max(val, 0)\n            val = min(val, img.shape[j] - 1)\n            pt.append(val)\n        return pt\n\n    points = []\n    for i in range(img.ndim):\n        points.append(tuple(_get_point(indices[i].min(), i)))\n        points.append(tuple(_get_point(indices[i].max(), i)))\n\n    return points\n\n\ndef extreme_points_to_image(\n    points: 
list[tuple[int, ...]],\n    label: NdarrayOrTensor,\n    sigma: Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor = 0.0,\n    rescale_min: float = -1.0,\n    rescale_max: float = 1.0,\n) -> torch.Tensor:\n    \"\"\"\n    Please refer to :py:class:`monai.transforms.AddExtremePointsChannel` for the usage.\n\n    Applies a gaussian filter to the extreme points image. Then the pixel values in points image are rescaled\n    to range [rescale_min, rescale_max].\n\n    Args:\n        points: Extreme points of the object/organ.\n        label: label image to get extreme points from. Shape must be\n            (1, spatial_dim1, [, spatial_dim2, ...]). Doesn't support one-hot labels.\n        sigma: if a list of values, must match the count of spatial dimensions of input data,\n            and apply every value in the list to 1 spatial dimension. if only 1 value provided,\n            use it for all spatial dimensions.\n        rescale_min: minimum value of output data.\n        rescale_max: maximum value of output data.\n    \"\"\"\n    # points to image\n    # points_image = torch.zeros(label.shape[1:], dtype=torch.float)\n    points_image = torch.zeros_like(torch.as_tensor(label[0]), dtype=torch.float)\n    for p in points:\n        points_image[p] = 1.0\n\n    if isinstance(sigma, Sequence):\n        sigma = [torch.as_tensor(s, device=points_image.device) for s in sigma]\n    else:\n        sigma = torch.as_tensor(sigma, device=points_image.device)\n\n    # add channel and add batch\n    points_image = points_image.unsqueeze(0).unsqueeze(0)\n    gaussian_filter = GaussianFilter(label.ndim - 1, sigma=sigma)\n    points_image = gaussian_filter(points_image).squeeze(0).detach()\n\n    # rescale the points image to [rescale_min, rescale_max]\n    min_intensity = points_image.min()\n    max_intensity = points_image.max()\n    points_image = (points_image - min_intensity) / (max_intensity - min_intensity)\n    return points_image * (rescale_max - rescale_min) + 
rescale_min\n\n\ndef map_spatial_axes(\n    img_ndim: int, spatial_axes: Sequence[int] | int | None = None, channel_first: bool = True\n) -> list[int]:\n    \"\"\"\n    Utility to map the spatial axes to real axes in channel first/last shape.\n    For example:\n    If `channel_first` is True, and `img` has 3 spatial dims, map spatial axes to real axes as below:\n    None -> [1, 2, 3]\n    [0, 1] -> [1, 2]\n    [0, -1] -> [1, -1]\n    If `channel_first` is False, and `img` has 3 spatial dims, map spatial axes to real axes as below:\n    None -> [0, 1, 2]\n    [0, 1] -> [0, 1]\n    [0, -1] -> [0, -2]\n\n    Args:\n        img_ndim: dimension number of the target image.\n        spatial_axes: spatial axes to be converted, default is None.\n            The default `None` will convert to all the spatial axes of the image.\n            If axis is negative it counts from the last to the first axis.\n            If axis is a tuple of ints.\n        channel_first: the image data is channel first or channel last, default to channel first.\n\n    \"\"\"\n    if spatial_axes is None:\n        return list(range(1, img_ndim) if channel_first else range(img_ndim - 1))\n    spatial_axes_ = []\n    for a in ensure_tuple(spatial_axes):\n        if channel_first:\n            spatial_axes_.append(a % img_ndim if a < 0 else a + 1)\n        else:\n            spatial_axes_.append((a - 1) % (img_ndim - 1) if a < 0 else a)\n    return spatial_axes_\n\n\n@contextmanager\ndef allow_missing_keys_mode(transform: MapTransform | Compose | tuple[MapTransform] | tuple[Compose]):\n    \"\"\"Temporarily set all MapTransforms to not throw an error if keys are missing. After, revert to original states.\n\n    Args:\n        transform: either MapTransform or a Compose\n\n    Example:\n\n    .. 
code-block:: python\n\n        data = {\"image\": np.arange(16, dtype=float).reshape(1, 4, 4)}\n        t = SpatialPadd([\"image\", \"label\"], 10, allow_missing_keys=False)\n        _ = t(data)  # would raise exception\n        with allow_missing_keys_mode(t):\n            _ = t(data)  # OK!\n    \"\"\"\n    # If given a sequence of transforms, Compose them to get a single list\n    if issequenceiterable(transform):\n        transform = Compose(transform)\n\n    # Get list of MapTransforms\n    transforms = []\n    if isinstance(transform, MapTransform):\n        transforms = [transform]\n    elif isinstance(transform, Compose):\n        # Only keep contained MapTransforms\n        transforms = [t for t in transform.flatten().transforms if isinstance(t, MapTransform)]\n    if len(transforms) == 0:\n        raise TypeError(\n            \"allow_missing_keys_mode expects either MapTransform(s) or Compose(s) containing MapTransform(s)\"\n        )\n\n    # Get the state of each `allow_missing_keys`\n    orig_states = [t.allow_missing_keys for t in transforms]\n\n    try:\n        # Set all to True\n        for t in transforms:\n            t.allow_missing_keys = True\n        yield\n    finally:\n        # Revert\n        for t, o_s in zip(transforms, orig_states):\n            t.allow_missing_keys = o_s\n\n\n_interp_modes = list(InterpolateMode) + list(GridSampleMode)\n\n\ndef convert_applied_interp_mode(trans_info, mode: str = \"nearest\", align_corners: bool | None = None):\n    \"\"\"\n    Recursively change the interpolation mode in the applied operation stacks, default to \"nearest\".\n\n    See also: :py:class:`monai.transform.inverse.InvertibleTransform`\n\n    Args:\n        trans_info: applied operation stack, tracking the previously applied invertible transform.\n        mode: target interpolation mode to convert, default to \"nearest\" as it's usually used to save the mode output.\n        align_corners: target align corner value in PyTorch interpolation 
API, need to align with the `mode`.\n\n    \"\"\"\n    if isinstance(trans_info, (list, tuple)):\n        return [convert_applied_interp_mode(x, mode=mode, align_corners=align_corners) for x in trans_info]\n    if not isinstance(trans_info, Mapping):\n        return trans_info\n    trans_info = dict(trans_info)\n    if \"mode\" in trans_info:\n        current_mode = trans_info[\"mode\"]\n        if isinstance(current_mode, int) or current_mode in _interp_modes:\n            trans_info[\"mode\"] = mode\n        elif isinstance(current_mode[0], int) or current_mode[0] in _interp_modes:\n            trans_info[\"mode\"] = [mode for _ in range(len(mode))]\n    if \"align_corners\" in trans_info:\n        _align_corners = TraceKeys.NONE if align_corners is None else align_corners\n        current_value = trans_info[\"align_corners\"]\n        trans_info[\"align_corners\"] = (\n            [_align_corners for _ in mode] if issequenceiterable(current_value) else _align_corners\n        )\n    if (\"mode\" not in trans_info) and (\"align_corners\" not in trans_info):\n        return {\n            k: convert_applied_interp_mode(trans_info[k], mode=mode, align_corners=align_corners) for k in trans_info\n        }\n    return trans_info\n\n\ndef reset_ops_id(data):\n    \"\"\"find MetaTensors in list or dict `data` and (in-place) set ``TraceKeys.ID`` to ``Tracekeys.NONE``.\"\"\"\n    if isinstance(data, (list, tuple)):\n        return [reset_ops_id(d) for d in data]\n    if isinstance(data, monai.data.MetaTensor):\n        data.applied_operations = reset_ops_id(data.applied_operations)\n        return data\n    if not isinstance(data, Mapping):\n        return data\n    data = dict(data)\n    if TraceKeys.ID in data:\n        data[TraceKeys.ID] = TraceKeys.NONE\n    return {k: reset_ops_id(v) for k, v in data.items()}\n\n\ndef compute_divisible_spatial_size(spatial_shape: Sequence[int], k: Sequence[int] | int):\n    \"\"\"\n    Compute the target spatial size which should be 
divisible by `k`.\n\n    Args:\n        spatial_shape: original spatial shape.\n        k: the target k for each spatial dimension.\n            if `k` is negative or 0, the original size is preserved.\n            if `k` is an int, the same `k` will be applied to all the input spatial dimensions.\n\n    \"\"\"\n    k = fall_back_tuple(k, (1,) * len(spatial_shape))\n    new_size = []\n    for k_d, dim in zip(k, spatial_shape):\n        new_dim = int(np.ceil(dim / k_d) * k_d) if k_d > 0 else dim\n        new_size.append(new_dim)\n\n    return tuple(new_size)\n\n\ndef equalize_hist(\n    img: np.ndarray, mask: np.ndarray | None = None, num_bins: int = 256, min: int = 0, max: int = 255\n) -> np.ndarray:\n    \"\"\"\n    Utility to equalize input image based on the histogram.\n    If `skimage` is installed, will leverage `skimage.exposure.histogram`, otherwise, use\n    `np.histogram` instead.\n\n    Args:\n        img: input image to equalize.\n        mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `img`.\n            only points at which `mask==True` are used for the equalization.\n        num_bins: number of the bins to use in histogram, default to `256`. 
for more details:\n            https://numpy.org/doc/stable/reference/generated/numpy.histogram.html.\n        min: the min value to normalize input image, default to `0`.\n        max: the max value to normalize input image, default to `255`.\n\n    \"\"\"\n\n    orig_shape = img.shape\n    hist_img = img[np.array(mask, dtype=bool)] if mask is not None else img\n    if has_skimage:\n        hist, bins = exposure.histogram(hist_img.flatten(), num_bins)\n    else:\n        hist, bins = np.histogram(hist_img.flatten(), num_bins)\n        bins = (bins[:-1] + bins[1:]) / 2\n\n    cum = hist.cumsum()\n    # normalize the cumulative result\n    cum = rescale_array(arr=cum, minv=min, maxv=max)\n\n    # apply linear interpolation\n    img = np.interp(img.flatten(), bins, cum)\n    return img.reshape(orig_shape)\n\n\nclass Fourier:\n    \"\"\"\n    Helper class storing Fourier mappings\n    \"\"\"\n\n    @staticmethod\n    def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, as_contiguous: bool = False) -> NdarrayOrTensor:\n        \"\"\"\n        Applies fourier transform and shifts the zero-frequency component to the\n        center of the spectrum. 
Only the spatial dimensions get transformed.\n\n        Args:\n            x: Image to transform.\n            spatial_dims: Number of spatial dimensions.\n            as_contiguous: Whether to convert the cached NumPy array or PyTorch tensor to be contiguous.\n\n        Returns:\n            k: K-space data.\n        \"\"\"\n        dims = tuple(range(-spatial_dims, 0))\n        k: NdarrayOrTensor\n        if isinstance(x, torch.Tensor):\n            k = torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims)\n        else:\n            k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims)\n        return ascontiguousarray(k) if as_contiguous else k\n\n    @staticmethod\n    def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, as_contiguous: bool = False) -> NdarrayOrTensor:\n        \"\"\"\n        Applies inverse shift and fourier transform. Only the spatial\n        dimensions are transformed.\n\n        Args:\n            k: K-space data.\n            spatial_dims: Number of spatial dimensions.\n            as_contiguous: Whether to convert the cached NumPy array or PyTorch tensor to be contiguous.\n\n        Returns:\n            x: Tensor in image space.\n        \"\"\"\n        dims = tuple(range(-spatial_dims, 0))\n        out: NdarrayOrTensor\n        if isinstance(k, torch.Tensor):\n            out = torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims, norm=\"backward\").real\n        else:\n            out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real\n        return ascontiguousarray(out) if as_contiguous else out\n\n\ndef get_number_image_type_conversions(transform: Compose, test_data: Any, key: Hashable | None = None) -> int:\n    \"\"\"\n    Get the number of times that the data need to be converted (e.g., numpy to torch).\n    Conversions between different devices are also counted (e.g., CPU to GPU).\n\n    Args:\n        transform: composed transforms to be tested\n        test_data: data to be used to count 
the number of conversions\n        key: if using dictionary transforms, this key will be used to check the number of conversions.\n    \"\"\"\n    from monai.transforms.compose import OneOf\n\n    def _get_data(obj, key):\n        return obj if key is None else obj[key]\n\n    # if the starting point is a string (e.g., input to LoadImage), start\n    # at -1 since we don't want to count the string -> image conversion.\n    num_conversions = 0 if not isinstance(_get_data(test_data, key), str) else -1\n\n    tr = transform.flatten().transforms\n\n    if isinstance(transform, OneOf) or any(isinstance(i, OneOf) for i in tr):\n        raise RuntimeError(\"Not compatible with `OneOf`, as the applied transform is deterministically chosen.\")\n\n    for _transform in tr:\n        prev_data = _get_data(test_data, key)\n        prev_type = type(prev_data)\n        prev_device = prev_data.device if isinstance(prev_data, torch.Tensor) else None\n        test_data = apply_transform(_transform, test_data, transform.map_items, transform.unpack_items)\n        # every time the type or device changes, increment the counter\n        curr_data = _get_data(test_data, key)\n        curr_device = curr_data.device if isinstance(curr_data, torch.Tensor) else None\n        if not isinstance(curr_data, prev_type) or curr_device != prev_device:\n            num_conversions += 1\n    return num_conversions\n\n\ndef get_transform_backends():\n    \"\"\"Get the backends of all MONAI transforms.\n\n    Returns:\n        Dictionary, where each key is a transform, and its\n        corresponding values are a boolean list, stating\n        whether that transform supports (1) `torch.Tensor`,\n        and (2) `np.ndarray` as input without needing to\n        convert.\n    \"\"\"\n    backends = {}\n    unique_transforms = []\n    for n, obj in getmembers(monai.transforms):\n        # skip aliases\n        if obj in unique_transforms:\n            continue\n        unique_transforms.append(obj)\n\n     
   if (\n            isclass(obj)\n            and issubclass(obj, Transform)\n            and n\n            not in [\n                \"BatchInverseTransform\",\n                \"Compose\",\n                \"CuCIM\",\n                \"CuCIMD\",\n                \"Decollated\",\n                \"InvertD\",\n                \"InvertibleTransform\",\n                \"Lambda\",\n                \"LambdaD\",\n                \"MapTransform\",\n                \"OneOf\",\n                \"RandCuCIM\",\n                \"RandCuCIMD\",\n                \"RandomOrder\",\n                \"PadListDataCollate\",\n                \"RandLambda\",\n                \"RandLambdaD\",\n                \"RandTorchVisionD\",\n                \"RandomizableTransform\",\n                \"TorchVisionD\",\n                \"Transform\",\n            ]\n        ):\n            backends[n] = [TransformBackends.TORCH in obj.backend, TransformBackends.NUMPY in obj.backend]\n    return backends\n\n\ndef print_transform_backends():\n    \"\"\"Prints a list of backends of all MONAI transforms.\"\"\"\n\n    class Colors:\n        none = \"\"\n        red = \"91\"\n        green = \"92\"\n        yellow = \"93\"\n\n    def print_color(t, color):\n        print(f\"\\033[{color}m{t}\\033[00m\")\n\n    def print_table_column(name, torch, numpy, color=Colors.none):\n        print_color(f\"{name:<50} {torch:<8} {numpy:<8}\", color)\n\n    backends = get_transform_backends()\n    n_total = len(backends)\n    n_t_or_np, n_t, n_np, n_uncategorized = 0, 0, 0, 0\n    print_table_column(\"Transform\", \"Torch?\", \"Numpy?\")\n    for k, v in backends.items():\n        if all(v):\n            color = Colors.green\n            n_t_or_np += 1\n        elif v[0]:\n            color = Colors.green\n            n_t += 1\n        elif v[1]:\n            color = Colors.yellow\n            n_np += 1\n        else:\n            color = Colors.red\n            n_uncategorized += 1\n        
print_table_column(k, v[0], v[1], color=color)\n\n    print(\"Total number of transforms:\", n_total)\n    print_color(f\"Number transforms allowing both torch and numpy: {n_t_or_np}\", Colors.green)\n    print_color(f\"Number of TorchTransform: {n_t}\", Colors.green)\n    print_color(f\"Number of NumpyTransform: {n_np}\", Colors.yellow)\n    print_color(f\"Number of uncategorized: {n_uncategorized}\", Colors.red)\n\n\ndef convert_pad_mode(dst: NdarrayOrTensor, mode: str | None):\n    \"\"\"\n    Utility to convert padding mode between numpy array and PyTorch Tensor.\n\n    Args:\n        dst: target data to convert padding mode for, should be numpy array or PyTorch Tensor.\n        mode: current padding mode.\n\n    \"\"\"\n    if isinstance(dst, torch.Tensor):\n        if mode == \"wrap\":\n            mode = \"circular\"\n        elif mode == \"edge\":\n            mode = \"replicate\"\n        return look_up_option(mode, PytorchPadMode)\n    if isinstance(dst, np.ndarray):\n        if mode == \"circular\":\n            mode = \"wrap\"\n        elif mode == \"replicate\":\n            mode = \"edge\"\n        return look_up_option(mode, NumpyPadMode)\n    raise ValueError(f\"unsupported data type: {type(dst)}.\")\n\n\ndef convert_to_contiguous(\n    data: NdarrayOrTensor | str | bytes | Mapping | Sequence[Any], **kwargs\n) -> NdarrayOrTensor | Mapping | Sequence[Any]:\n    \"\"\"\n    Check and ensure the numpy array or PyTorch Tensor in data to be contiguous in memory.\n\n    Args:\n        data: input data to convert, will recursively convert the numpy array or PyTorch Tensor in dict and sequence.\n        kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:\n            https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous.\n\n    \"\"\"\n    if isinstance(data, (np.ndarray, torch.Tensor, str, bytes)):\n        return ascontiguousarray(data, **kwargs)\n    elif isinstance(data, 
Mapping):\n        return {k: convert_to_contiguous(v, **kwargs) for k, v in data.items()}\n    elif isinstance(data, Sequence):\n        return type(data)(convert_to_contiguous(i, **kwargs) for i in data)  # type: ignore\n    else:\n        return data\n\n\ndef scale_affine(spatial_size, new_spatial_size, centered: bool = True):\n    \"\"\"\n    Compute the scaling matrix according to the new spatial size\n\n    Args:\n        spatial_size: original spatial size.\n        new_spatial_size: new spatial size.\n        centered: whether the scaling is with respect to the image center (True, default) or corner (False).\n\n    Returns:\n        the scaling matrix.\n\n    \"\"\"\n    r = max(len(new_spatial_size), len(spatial_size))\n    if spatial_size == new_spatial_size:\n        return np.eye(r + 1)\n    s = np.array([float(o) / float(max(n, 1)) for o, n in zip(spatial_size, new_spatial_size)], dtype=float)\n    scale = create_scale(r, s.tolist())\n    if centered:\n        scale[:r, -1] = (np.diag(scale)[:r] - 1) / 2.0  # type: ignore\n    return scale\n\n\ndef attach_hook(func, hook, mode=\"pre\"):\n    \"\"\"\n    Adds `hook` before or after a `func` call. 
If mode is \"pre\", the wrapper will call hook then func.\n    If the mode is \"post\", the wrapper will call func then hook.\n    \"\"\"\n    supported = {\"pre\", \"post\"}\n    if look_up_option(mode, supported) == \"pre\":\n        _hook, _func = hook, func\n    else:\n        _hook, _func = func, hook\n\n    @wraps(func)\n    def wrapper(inst, data):\n        data = _hook(inst, data)\n        return _func(inst, data)\n\n    return wrapper\n\n\ndef sync_meta_info(key, data_dict, t: bool = True):\n    \"\"\"\n    Given the key, sync up between metatensor `data_dict[key]` and meta_dict `data_dict[key_transforms/meta_dict]`.\n    t=True: the one with more applied_operations in metatensor vs meta_dict is the output, False: less is the output.\n    \"\"\"\n    if not isinstance(data_dict, Mapping):\n        return data_dict\n    d = dict(data_dict)\n\n    # update meta dicts\n    meta_dict_key = PostFix.meta(key)\n    if meta_dict_key not in d:\n        d[meta_dict_key] = monai.data.MetaTensor.get_default_meta()\n    if not isinstance(d[key], monai.data.MetaTensor):\n        d[key] = monai.data.MetaTensor(data_dict[key])\n        d[key].meta = d[meta_dict_key]\n    d[meta_dict_key].update(d[key].meta)  # prefer metatensor's data\n\n    # update xform info\n    xform_key = monai.transforms.TraceableTransform.trace_key(key)\n    if xform_key not in d:\n        d[xform_key] = monai.data.MetaTensor.get_default_applied_operations()\n    from_meta, from_dict = d[key].applied_operations, d[xform_key]\n    if not from_meta:  # avoid []\n        d[key].applied_operations = d[xform_key] = from_dict\n        return d\n    if not from_dict:\n        d[key].applied_operations = d[xform_key] = from_meta\n        return d\n    if t:  # larger transform info stack is used as the result\n        ref = from_meta if len(from_meta) > len(from_dict) else from_dict\n    else:  # smaller transform info stack is used as the result\n        ref = from_dict if len(from_meta) > len(from_dict) 
else from_meta\n    d[key].applied_operations = d[xform_key] = ref\n    return d\n\n\ndef check_boundaries(boundaries) -> None:\n    \"\"\"\n    Check boundaries for Signal transforms\n    \"\"\"\n    if not (\n        isinstance(boundaries, Sequence) and len(boundaries) == 2 and all(isinstance(i, float) for i in boundaries)\n    ):\n        raise ValueError(\"Incompatible values: boundaries needs to be a list of float.\")\n\n\ndef paste_slices(tup):\n    \"\"\"\n    given a tuple (pos,w,max_w), return a tuple of slices\n    \"\"\"\n    pos, w, max_w = tup\n    max_w = max_w.shape[-1]\n    orig_min = max(pos, 0)\n    orig_max = min(pos + w, max_w)\n    block_min = -min(pos, 0)\n    block_max = max_w - max(pos + w, max_w)\n    block_max = block_max if block_max != 0 else None\n    return slice(orig_min, orig_max), slice(block_min, block_max)\n\n\ndef paste(orig, block, loc):\n    \"\"\"\n    given a location (loc) and an original array (orig), paste a block array into it\n    \"\"\"\n    loc_zip = zip(loc, block.shape, orig)\n    orig_slices, block_slices = zip(*map(paste_slices, loc_zip))\n\n    orig[:, orig_slices[0]] = block[block_slices[0]]\n\n    if orig.shape[0] == 1:\n        orig = orig.squeeze()\n    return orig\n\n\ndef squarepulse(sig, duty: float = 0.5):\n    \"\"\"\n    compute squarepulse using pytorch\n    equivalent to numpy implementation from\n    https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.square.html\n    \"\"\"\n    t, w = convert_to_tensor(sig), convert_to_tensor(duty)\n    w = convert_to_tensor(w)\n    t = convert_to_tensor(t)\n\n    y = torch.zeros(t.shape)\n\n    mask1 = (w > 1) | (w < 0)\n\n    tmod = torch.remainder(t, 2 * torch.pi)\n    mask2 = (~mask1) & (tmod < w * 2 * torch.pi)\n    y[mask2] = 1\n    mask3 = (~mask1) & (~mask2)\n    y[mask3] = -1\n    return y\n\n\ndef _to_numpy_resample_interp_mode(interp_mode):\n    ret = look_up_option(str(interp_mode), SplineMode, default=None)\n    if ret is not None:\n       
 return int(ret)\n    _mapping = {\n        InterpolateMode.NEAREST: SplineMode.ZERO,\n        InterpolateMode.NEAREST_EXACT: SplineMode.ZERO,\n        InterpolateMode.LINEAR: SplineMode.ONE,\n        InterpolateMode.BILINEAR: SplineMode.ONE,\n        InterpolateMode.TRILINEAR: SplineMode.ONE,\n        InterpolateMode.BICUBIC: SplineMode.THREE,\n        InterpolateMode.AREA: SplineMode.ZERO,\n    }\n    ret = look_up_option(str(interp_mode), _mapping, default=None)\n    if ret is not None:\n        return ret\n    return look_up_option(str(interp_mode), list(_mapping) + list(SplineMode))  # for better error msg\n\n\ndef _to_torch_resample_interp_mode(interp_mode):\n    ret = look_up_option(str(interp_mode), InterpolateMode, default=None)\n    if ret is not None:\n        return ret\n    _mapping = {\n        SplineMode.ZERO: InterpolateMode.NEAREST_EXACT,\n        SplineMode.ONE: InterpolateMode.LINEAR,\n        SplineMode.THREE: InterpolateMode.BICUBIC,\n    }\n    ret = look_up_option(str(interp_mode), _mapping, default=None)\n    if ret is not None:\n        return ret\n    return look_up_option(str(interp_mode), list(_mapping) + list(InterpolateMode))\n\n\ndef _to_numpy_resample_padding_mode(m):\n    ret = look_up_option(str(m), NdimageMode, default=None)\n    if ret is not None:\n        return ret\n    _mapping = {\n        GridSamplePadMode.ZEROS: NdimageMode.CONSTANT,\n        GridSamplePadMode.BORDER: NdimageMode.NEAREST,\n        GridSamplePadMode.REFLECTION: NdimageMode.REFLECT,\n    }\n    ret = look_up_option(str(m), _mapping, default=None)\n    if ret is not None:\n        return ret\n    return look_up_option(str(m), list(_mapping) + list(NdimageMode))\n\n\ndef _to_torch_resample_padding_mode(m):\n    ret = look_up_option(str(m), GridSamplePadMode, default=None)\n    if ret is not None:\n        return ret\n    _mapping = {\n        NdimageMode.CONSTANT: GridSamplePadMode.ZEROS,\n        NdimageMode.GRID_CONSTANT: GridSamplePadMode.ZEROS,\n        
NdimageMode.NEAREST: GridSamplePadMode.BORDER,\n        NdimageMode.REFLECT: GridSamplePadMode.REFLECTION,\n        NdimageMode.WRAP: GridSamplePadMode.REFLECTION,\n        NdimageMode.GRID_WRAP: GridSamplePadMode.REFLECTION,\n        NdimageMode.GRID_MIRROR: GridSamplePadMode.REFLECTION,\n    }\n    ret = look_up_option(str(m), _mapping, default=None)\n    if ret is not None:\n        return ret\n    return look_up_option(str(m), list(_mapping) + list(GridSamplePadMode))\n\n\n@lru_cache(None)\ndef resolves_modes(\n    interp_mode: str | None = \"constant\", padding_mode=\"zeros\", backend=TransformBackends.TORCH, **kwargs\n):\n    \"\"\"\n    Automatically adjust the resampling interpolation mode and padding mode,\n    so that they are compatible with the corresponding API of the `backend`.\n    Depending on the availability of the backends, when there's no exact\n    equivalent, a similar mode is returned.\n\n    Args:\n        interp_mode: interpolation mode.\n        padding_mode: padding mode.\n        backend: optional backend of `TransformBackends`. If None, the backend will be decided from `interp_mode`.\n        kwargs: additional keyword arguments. 
currently support ``torch_interpolate_spatial_nd``, to provide\n            additional information to determine ``linear``, ``bilinear`` and ``trilinear``;\n            ``use_compiled`` to use MONAI's precompiled backend (pytorch c++ extensions), default to ``False``.\n    \"\"\"\n    _interp_mode, _padding_mode, _kwargs = None, None, (kwargs or {}).copy()\n    if backend is None:  # infer backend\n        backend = (\n            TransformBackends.NUMPY\n            if look_up_option(str(interp_mode), SplineMode, default=None) is not None\n            else TransformBackends.TORCH\n        )\n    if backend == TransformBackends.NUMPY:\n        _interp_mode = _to_numpy_resample_interp_mode(interp_mode)\n        _padding_mode = _to_numpy_resample_padding_mode(padding_mode)\n        return backend, _interp_mode, _padding_mode, _kwargs\n    _interp_mode = _to_torch_resample_interp_mode(interp_mode)\n    _padding_mode = _to_torch_resample_padding_mode(padding_mode)\n    if str(_interp_mode).endswith(\"linear\"):\n        nd = _kwargs.pop(\"torch_interpolate_spatial_nd\", 2)\n        if nd == 1:\n            _interp_mode = InterpolateMode.LINEAR\n        elif nd == 3:\n            _interp_mode = InterpolateMode.TRILINEAR\n        else:\n            _interp_mode = InterpolateMode.BILINEAR  # torch grid_sample bilinear is trilinear in 3D\n    if not _kwargs.pop(\"use_compiled\", False):\n        return backend, _interp_mode, _padding_mode, _kwargs\n    _padding_mode = 1 if _padding_mode == \"reflection\" else _padding_mode\n    if _interp_mode == \"bicubic\":\n        _interp_mode = 3\n    elif str(_interp_mode).endswith(\"linear\"):\n        _interp_mode = 1\n    else:\n        _interp_mode = GridSampleMode(_interp_mode)\n    return backend, _interp_mode, _padding_mode, _kwargs\n\n\ndef check_applied_operations(entry: list | dict, status_key: str, default_message: str = \"No message provided\"):\n    \"\"\"\n    Check the operations of a MetaTensor to determine whether 
there are any statuses\n    Args:\n        entry: a dictionary that may contain TraceKey.STATUS entries, or a list of such dictionaries\n        status_key: the status key to search for. This must be an entry in `TraceStatusKeys`_\n        default_message: The message to provide if no messages are provided for the given status key entry\n\n    Returns:\n        A list of status messages matching the providing status key\n\n    \"\"\"\n    if isinstance(entry, list):\n        results = list()\n        for sub_entry in entry:\n            results.extend(check_applied_operations(sub_entry, status_key, default_message))\n        return results\n    else:\n        status_key_ = TraceStatusKeys(status_key)\n        if TraceKeys.STATUSES in entry:\n            if status_key_ in entry[TraceKeys.STATUSES]:\n                reason = entry[TraceKeys.STATUSES][status_key_]\n                if reason is None:\n                    return [default_message]\n                return reason if isinstance(reason, list) else [reason]\n        return []\n\n\ndef has_status_keys(data: torch.Tensor, status_key: Any, default_message: str = \"No message provided\"):\n    \"\"\"\n    Checks whether a given tensor has a particular status key message on any of its\n    applied operations. If it doesn't, it returns the tuple `(True, None)`. 
If it does\n    it returns a tuple of False and a list of status messages for that status key.\n\n    Status keys are defined in :class:`TraceStatusKeys<monai.utils.enums.TraceStatusKeys>`.\n\n    This function also accepts:\n\n    * dictionaries of tensors\n    * lists or tuples of tensors\n    * list or tuples of dictionaries of tensors\n\n    In any of the above scenarios, it iterates through the collections and executes itself recursively until it is\n    operating on tensors.\n\n    Args:\n        data: a `torch.Tensor` or `MetaTensor` or collections of torch.Tensor or MetaTensor, as described above\n        status_key: the status key to look for, from `TraceStatusKeys`\n        default_message: a default message to use if the status key entry doesn't have a message set\n\n    Returns:\n        A tuple. The first entry is `False` or `True`. The second entry is the status messages that can be used for the\n        user to help debug their pipelines.\n\n    \"\"\"\n    status_key_occurrences = list()\n    if isinstance(data, (list, tuple)):\n        for d in data:\n            _, reasons = has_status_keys(d, status_key, default_message)\n            if reasons is not None:\n                status_key_occurrences.extend(reasons)\n    elif isinstance(data, monai.data.MetaTensor):\n        for op in data.applied_operations:\n            status_key_occurrences.extend(check_applied_operations(op, status_key, default_message))\n    elif isinstance(data, dict):\n        for d in data.values():\n            _, reasons = has_status_keys(d, status_key, default_message)\n            if reasons is not None:\n                status_key_occurrences.extend(reasons)\n\n    if len(status_key_occurrences) > 0:\n        return False, status_key_occurrences\n    return True, None\n\n\ndef distance_transform_edt(\n    img: NdarrayOrTensor,\n    sampling: None | float | list[float] = None,\n    return_distances: bool = True,\n    return_indices: bool = False,\n    distances: 
NdarrayOrTensor | None = None,\n    indices: NdarrayOrTensor | None = None,\n    *,\n    block_params: tuple[int, int, int] | None = None,\n    float64_distances: bool = False,\n) -> None | NdarrayOrTensor | tuple[NdarrayOrTensor, NdarrayOrTensor]:\n    \"\"\"\n    Euclidean distance transform, either GPU based with CuPy / cuCIM or CPU based with scipy.\n    To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.\n\n    Note that the results of the libraries can differ, so stick to one if possible.\n    For details, check out the `SciPy`_ and `cuCIM`_ documentation.\n\n    .. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html\n    .. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt\n\n    Args:\n        img: Input image on which the distance transform shall be run.\n            Has to be a channel first array, must have shape: (num_channels, H, W [,D]).\n            Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.\n            Input gets passed channel-wise to the distance-transform, thus results from this function will differ\n            from directly calling ``distance_transform_edt()`` in CuPy or SciPy.\n        sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;\n            if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.\n        return_distances: Whether to calculate the distance transform.\n        return_indices: Whether to calculate the feature transform.\n        distances: An output array to store the calculated distance transform, instead of returning it.\n            `return_distances` must be True.\n        indices: An output array to store the calculated feature transform, instead of returning it. 
`return_indicies` must be True.\n        block_params: This parameter is specific to cuCIM and does not exist in SciPy. For details, look into `cuCIM`_.\n        float64_distances: This parameter is specific to cuCIM and does not exist in SciPy.\n            If True, use double precision in the distance computation (to match SciPy behavior).\n            Otherwise, single precision will be used for efficiency.\n\n    Returns:\n        distances: The calculated distance transform. Returned only when `return_distances` is True and `distances` is not supplied.\n            It will have the same shape and type as image. For cuCIM: Will have dtype torch.float64 if float64_distances is True,\n            otherwise it will have dtype torch.float32. For SciPy: Will have dtype np.float64.\n        indices: The calculated feature transform. It has an image-shaped array for each dimension of the image.\n            The type will be equal to the type of the image.\n            Returned only when `return_indices` is True and `indices` is not supplied. dtype np.float64.\n\n    \"\"\"\n    distance_transform_edt, has_cucim = optional_import(\n        \"cucim.core.operations.morphology\", name=\"distance_transform_edt\"\n    )\n    use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == \"cuda\"\n    if not return_distances and not return_indices:\n        raise RuntimeError(\"Neither return_distances nor return_indices True\")\n\n    if not (img.ndim >= 3 and img.ndim <= 4):\n        raise RuntimeError(\"Wrong input dimensionality. 
Use (num_channels, H, W [,D])\")\n\n    distances_original, indices_original = distances, indices\n    distances, indices = None, None\n    if use_cp:\n        distances_, indices_ = None, None\n        if return_distances:\n            dtype = torch.float64 if float64_distances else torch.float32\n            if distances is None:\n                distances = torch.zeros_like(img, memory_format=torch.contiguous_format, dtype=dtype)  # type: ignore\n            else:\n                if not isinstance(distances, torch.Tensor) and distances.device != img.device:\n                    raise TypeError(\"distances must be a torch.Tensor on the same device as img\")\n                if not distances.dtype == dtype:\n                    raise TypeError(\"distances must be a torch.Tensor of dtype float32 or float64\")\n            distances_ = convert_to_cupy(distances)\n        if return_indices:\n            dtype = torch.int32\n            if indices is None:\n                indices = torch.zeros((img.shape[0],) + (img.dim() - 1,) + img.shape[1:], dtype=dtype)  # type: ignore\n            else:\n                if not isinstance(indices, torch.Tensor) and indices.device != img.device:\n                    raise TypeError(\"indices must be a torch.Tensor on the same device as img\")\n                if not indices.dtype == dtype:\n                    raise TypeError(\"indices must be a torch.Tensor of dtype int32\")\n            indices_ = convert_to_cupy(indices)\n        img_ = convert_to_cupy(img)\n        for channel_idx in range(img_.shape[0]):\n            distance_transform_edt(\n                img_[channel_idx],\n                sampling=sampling,\n                return_distances=return_distances,\n                return_indices=return_indices,\n                distances=distances_[channel_idx] if distances_ is not None else None,\n                indices=indices_[channel_idx] if indices_ is not None else None,\n                block_params=block_params,\n     
           float64_distances=float64_distances,\n            )\n        torch.cuda.synchronize()\n    else:\n        if not has_ndimage:\n            raise RuntimeError(\"scipy.ndimage required if cupy is not available\")\n        img_ = convert_to_numpy(img)\n        if return_distances:\n            if distances is None:\n                distances = np.zeros_like(img_, dtype=np.float64)\n            else:\n                if not isinstance(distances, np.ndarray):\n                    raise TypeError(\"distances must be a numpy.ndarray\")\n                if not distances.dtype == np.float64:\n                    raise TypeError(\"distances must be a numpy.ndarray of dtype float64\")\n        if return_indices:\n            if indices is None:\n                indices = np.zeros((img_.shape[0],) + (img_.ndim - 1,) + img_.shape[1:], dtype=np.int32)\n            else:\n                if not isinstance(indices, np.ndarray):\n                    raise TypeError(\"indices must be a numpy.ndarray\")\n                if not indices.dtype == np.int32:\n                    raise TypeError(\"indices must be a numpy.ndarray of dtype int32\")\n\n        for channel_idx in range(img_.shape[0]):\n            ndimage.distance_transform_edt(\n                img_[channel_idx],\n                sampling=sampling,\n                return_distances=return_distances,\n                return_indices=return_indices,\n                distances=distances[channel_idx] if distances is not None else None,\n                indices=indices[channel_idx] if indices is not None else None,\n            )\n\n    r_vals = []\n    if return_distances and distances_original is None:\n        r_vals.append(distances_ if use_cp else distances)\n    if return_indices and indices_original is None:\n        r_vals.append(indices)\n    if not r_vals:\n        return None\n    device = img.device if isinstance(img, torch.Tensor) else None\n    return convert_data_type(r_vals[0] if len(r_vals) == 1 else 
r_vals, output_type=type(img), device=device)[0]\n\n\ndef apply_affine_to_points(data: torch.Tensor, affine: torch.Tensor, dtype: DtypeLike | torch.dtype | None = None):\n    \"\"\"\n    apply affine transformation to a set of points.\n\n    Args:\n        data: input data to apply affine transformation, should be a tensor of shape (C, N, 2 or 3),\n            where C represents the number of channels and N denotes the number of points.\n        affine: affine matrix to be applied, should be a tensor of shape (3, 3) or (4, 4).\n        dtype: output data dtype.\n    \"\"\"\n    data_: torch.Tensor = convert_to_tensor(data, track_meta=False, dtype=torch.float64)\n    affine = to_affine_nd(data_.shape[-1], affine)\n\n    homogeneous: torch.Tensor = concatenate((data_, torch.ones((data_.shape[0], data_.shape[1], 1))), axis=2)  # type: ignore\n    transformed_homogeneous = torch.matmul(homogeneous, affine.T)\n    transformed_coordinates = transformed_homogeneous[:, :, :-1]\n    out, *_ = convert_to_dst_type(transformed_coordinates, data, dtype=dtype)\n\n    return out\n\n\nif __name__ == \"__main__\":\n    print_transform_backends()\n"
  },
  {
    "path": "monai/transforms/utils_create_transform_ims.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport pathlib\nimport tempfile\nimport textwrap\nfrom copy import deepcopy\nfrom glob import glob\nfrom typing import TYPE_CHECKING\n\nimport numpy as np\nimport torch\n\nfrom monai.apps import download_and_extract\nfrom monai.transforms import (\n    Affine,\n    Affined,\n    AsDiscrete,\n    Compose,\n    EnsureChannelFirstd,\n    Flip,\n    Flipd,\n    LoadImaged,\n    MapTransform,\n    Orientation,\n    Orientationd,\n    Rand3DElastic,\n    Rand3DElasticd,\n    RandFlip,\n    RandFlipd,\n    Randomizable,\n    RandRotate,\n    RandRotated,\n    RandZoom,\n    RandZoomd,\n    Rotate,\n    Rotate90,\n    Rotate90d,\n    Rotated,\n    ScaleIntensity,\n    ScaleIntensityd,\n    SpatialPadd,\n    Zoom,\n    Zoomd,\n)\nfrom monai.transforms.croppad.array import (\n    BorderPad,\n    CenterScaleCrop,\n    CenterSpatialCrop,\n    CropForeground,\n    DivisiblePad,\n    RandCropByLabelClasses,\n    RandCropByPosNegLabel,\n    RandScaleCrop,\n    RandSpatialCrop,\n    RandSpatialCropSamples,\n    RandWeightedCrop,\n    ResizeWithPadOrCrop,\n    SpatialCrop,\n    SpatialPad,\n)\nfrom monai.transforms.croppad.dictionary import (\n    BorderPadd,\n    CenterScaleCropd,\n    CenterSpatialCropd,\n    CropForegroundd,\n    DivisiblePadd,\n    RandCropByLabelClassesd,\n    RandCropByPosNegLabeld,\n    RandScaleCropd,\n    
RandSpatialCropd,\n    RandSpatialCropSamplesd,\n    RandWeightedCropd,\n    ResizeWithPadOrCropd,\n    SpatialCropd,\n)\nfrom monai.transforms.intensity.array import (\n    AdjustContrast,\n    ForegroundMask,\n    GaussianSharpen,\n    GaussianSmooth,\n    GibbsNoise,\n    HistogramNormalize,\n    KSpaceSpikeNoise,\n    MaskIntensity,\n    MedianSmooth,\n    NormalizeIntensity,\n    RandAdjustContrast,\n    RandBiasField,\n    RandCoarseDropout,\n    RandCoarseShuffle,\n    RandGaussianNoise,\n    RandGaussianSharpen,\n    RandGaussianSmooth,\n    RandGibbsNoise,\n    RandHistogramShift,\n    RandKSpaceSpikeNoise,\n    RandRicianNoise,\n    RandScaleIntensity,\n    RandShiftIntensity,\n    RandStdShiftIntensity,\n    SavitzkyGolaySmooth,\n    ScaleIntensityRange,\n    ScaleIntensityRangePercentiles,\n    ShiftIntensity,\n    StdShiftIntensity,\n    ThresholdIntensity,\n)\nfrom monai.transforms.intensity.dictionary import (\n    AdjustContrastd,\n    ForegroundMaskd,\n    GaussianSharpend,\n    GaussianSmoothd,\n    GibbsNoised,\n    HistogramNormalized,\n    KSpaceSpikeNoised,\n    MaskIntensityd,\n    MedianSmoothD,\n    NormalizeIntensityd,\n    RandAdjustContrastd,\n    RandBiasFieldd,\n    RandCoarseDropoutd,\n    RandCoarseShuffled,\n    RandGaussianNoised,\n    RandGaussianSharpend,\n    RandGaussianSmoothd,\n    RandGibbsNoised,\n    RandHistogramShiftd,\n    RandKSpaceSpikeNoised,\n    RandRicianNoised,\n    RandScaleIntensityd,\n    RandShiftIntensityd,\n    RandStdShiftIntensityd,\n    SavitzkyGolaySmoothd,\n    ScaleIntensityRanged,\n    ScaleIntensityRangePercentilesd,\n    ShiftIntensityd,\n    StdShiftIntensityd,\n    ThresholdIntensityd,\n)\nfrom monai.transforms.post.array import KeepLargestConnectedComponent, LabelFilter, LabelToContour, RemoveSmallObjects\nfrom monai.transforms.post.dictionary import (\n    AsDiscreted,\n    KeepLargestConnectedComponentd,\n    LabelFilterd,\n    LabelToContourd,\n    RemoveSmallObjectsd,\n)\nfrom 
monai.transforms.smooth_field.array import (\n    RandSmoothDeform,\n    RandSmoothFieldAdjustContrast,\n    RandSmoothFieldAdjustIntensity,\n)\nfrom monai.transforms.smooth_field.dictionary import (\n    RandSmoothDeformd,\n    RandSmoothFieldAdjustContrastd,\n    RandSmoothFieldAdjustIntensityd,\n)\nfrom monai.transforms.spatial.array import (\n    GridDistortion,\n    Rand2DElastic,\n    RandAffine,\n    RandAxisFlip,\n    RandGridDistortion,\n    RandRotate90,\n    Resize,\n    Spacing,\n)\nfrom monai.transforms.spatial.dictionary import (\n    GridDistortiond,\n    Rand2DElasticd,\n    RandAffined,\n    RandAxisFlipd,\n    RandGridDistortiond,\n    RandRotate90d,\n    Resized,\n    Spacingd,\n)\nfrom monai.utils.enums import CommonKeys\nfrom monai.utils.misc import MONAIEnvVars\nfrom monai.utils.module import optional_import\n\nif TYPE_CHECKING:\n    import matplotlib.pyplot as plt\n\n    has_matplotlib = True\n\nelse:\n    plt, has_matplotlib = optional_import(\"matplotlib.pyplot\")\n\n\ndef get_data(keys):\n    \"\"\"Get the example data to be used.\n\n    Use MarsAtlas as it only contains 1 image for quick download and\n    that image is parcellated.\n    \"\"\"\n    cache_dir = MONAIEnvVars.data_dir() or tempfile.mkdtemp()\n    fname = \"MarsAtlas-MNI-Colin27.zip\"\n    url = \"https://www.dropbox.com/s/ndz8qtqblkciole/\" + fname + \"?dl=1\"\n    out_path = os.path.join(cache_dir, \"MarsAtlas-MNI-Colin27\")\n    zip_path = os.path.join(cache_dir, fname)\n\n    download_and_extract(url, zip_path, out_path)\n\n    image, label = sorted(glob(os.path.join(out_path, \"*.nii\")))\n\n    transforms = Compose(\n        [\n            LoadImaged(keys),\n            EnsureChannelFirstd(keys),\n            ScaleIntensityd(CommonKeys.IMAGE),\n            Rotate90d(keys, spatial_axes=(0, 2)),\n        ]\n    )\n    data = transforms({CommonKeys.IMAGE: image, CommonKeys.LABEL: label})\n    max_size = max(data[keys[0]].shape)\n    padder = SpatialPadd(keys, (max_size, 
max_size, max_size))\n    return padder(data)\n\n\ndef update_docstring(code_path, transform_name):\n    \"\"\"\n    Find the documentation for a given transform and if it's missing,\n    add a pointer to the transform's example image.\n    \"\"\"\n    with open(code_path) as f:\n        contents = f.readlines()\n    doc_start = None\n    for i, line in enumerate(contents):\n        # find the line containing start of the transform documentation\n        if \"`\" + transform_name + \"`\" in line:\n            doc_start = i\n            break\n    if doc_start is None:\n        raise RuntimeError(\"Couldn't find transform documentation\")\n\n    # if image is already in docs, nothing to do\n    image_line = doc_start + 2\n    if \".. image\" in contents[image_line]:\n        return\n\n    # add the line for the image and the alt text\n    contents_orig = deepcopy(contents)\n    contents.insert(\n        image_line,\n        \".. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/\" + transform_name + \".png\\n\",\n    )\n    contents.insert(image_line + 1, \"    :alt: example of \" + transform_name + \"\\n\")\n\n    # check that we've only added two lines\n    if len(contents) != len(contents_orig) + 2:\n        raise AssertionError\n\n    # write the updated doc to overwrite the original\n    with open(code_path, \"w\") as f:\n        f.writelines(contents)\n\n\ndef pre_process_data(data, ndim, is_map, is_post):\n    \"\"\"If transform requires 2D data, then convert to 2D by selecting the middle of the last dimension.\"\"\"\n    if ndim == 2:\n        data = {k: v[..., v.shape[-1] // 2] for k, v in data.items()}\n    if is_map:\n        return data\n    return data[CommonKeys.LABEL] if is_post else data[CommonKeys.IMAGE]\n\n\ndef get_2d_slice(image, view: int, is_label):\n    \"\"\"If image is 3d, get the central slice. 
If is already 2d, return as-is.\n    If image is label, set 0 to np.nan.\n    \"\"\"\n    if image.ndim == 2:\n        out = image\n    else:\n        shape = image.shape\n        slices = [slice(0, s) for s in shape]\n        _slice = shape[view] // 2\n        slices[view] = slice(_slice, _slice + 1)\n        out = np.squeeze(image[tuple(slices)], view)\n    if is_label:\n        out[out == 0] = np.nan\n    return out\n\n\ndef get_stacked_2d_ims(im, is_label):\n    \"\"\"Get the 3 orthogonal views and stack them into 1 image.\n    Requires that all images be same size, but this is taken care\n    of by the `SpatialPadd` earlier.\n    \"\"\"\n    return [get_2d_slice(im, i, is_label) for i in range(3)]\n\n\ndef get_stacked_before_after(before, after, is_label=False):\n    \"\"\"Stack before and after images into 1 image if 3d.\n    Requires that before and after images be the same size.\n    \"\"\"\n    return [get_stacked_2d_ims(d, is_label) for d in (before, after)]\n\n\ndef save_image(images, labels, filename, transform_name, transform_args, shapes, colorbar=False):\n    \"\"\"Save image to file, ensuring there's no whitespace around the edge.\"\"\"\n    plt.rcParams.update({\"font.family\": \"monospace\"})\n    plt.style.use(\"dark_background\")\n    nrow = len(images)  # before and after (should always be 2)\n    ncol = len(images[0])  # num orthogonal views (either 1 or 3)\n    # roughly estimate the height_ratios of the first:second row\n    hs = [float(r[0].shape[0]) for r in images]\n    fig = plt.figure(tight_layout=True)\n    spec = fig.add_gridspec(nrow, ncol, hspace=0, wspace=0, height_ratios=hs)\n    for row in range(nrow):\n        vmin = min(i.min() for i in images[row])\n        vmax = max(i.max() for i in images[row])\n        for col in range(ncol):\n            ax = fig.add_subplot(spec[row, col])\n            imshow = ax.imshow(images[row][col], cmap=\"gray\", vmin=vmin, vmax=vmax)\n            ax.set_aspect(\"equal\")\n            if colorbar 
and col == ncol - 1:\n                plt.colorbar(imshow, ax=ax)\n            if col == 0:\n                y_label = \"After\" if row else \"Before\"\n                y_label += (\"\\n\" + shapes[row]) if shapes[0] != shapes[1] else \"\"\n                ax.set_ylabel(y_label)\n            # print yticks for the right most column\n            if col != ncol - 1 or colorbar:\n                ax.set_yticks([])\n            else:\n                ax.yaxis.tick_right()\n                for n, label in enumerate(ax.yaxis.get_ticklabels()):\n                    if n > 2:\n                        label.set_visible(False)\n            ax.set_xticks([])\n            ax.set_frame_on(False)\n            if labels is not None:\n                ax.imshow(labels[row][col], cmap=\"hsv\", alpha=0.9, interpolation=\"nearest\")\n    # title is e.g., Flipd(keys=keys, spatial_axis=0)\n    title = transform_name + \"(\"\n    for k, v in transform_args.items():\n        title += k + \"=\"\n        if isinstance(v, str):\n            title += \"'\" + v + \"'\"\n        elif isinstance(v, (np.ndarray, torch.Tensor)):\n            title += \"[array]\"\n        elif callable(v):\n            title += \"[callable]\"\n        else:\n            title += str(v)\n        title += \", \"\n    if len(transform_args) > 0:\n        title = title[:-2]\n    title += \")\"\n    # shorten the lines\n    title = textwrap.fill(title, 50, break_long_words=False, subsequent_indent=\" \" * (len(transform_name) + 1))\n    fig.suptitle(title, x=0.1, horizontalalignment=\"left\")\n    fig.savefig(filename)\n    plt.close(fig)\n\n\ndef get_images(data, is_label=False):\n    \"\"\"Get image. If is dictionary, extract key. If is list, stack. If both dictionary and list, do both.\n    Also return the image size as string to be used im the imshow. 
If it's a list, return `N x (H,W,D)`.\n    \"\"\"\n    # If not a list, convert\n    if not isinstance(data, list):\n        data = [data]\n    key = CommonKeys.LABEL if is_label else CommonKeys.IMAGE\n    is_map = isinstance(data[0], dict)\n    # length of the list will be equal to number of samples produced. This will be 1 except for transforms that\n    # produce `num_samples`.\n    data = [d[key] if is_map else d for d in data]\n    data = [d[0] for d in data]  # remove channel component\n\n    # for each sample, create a list of the orthogonal views. If image is 2d, length will be 1. If 3d, there\n    # will be three orthogonal views\n    num_samples = len(data)\n    num_orthog_views = 3 if data[0].ndim == 3 else 1\n    shape_str = (f\"{num_samples} x \" if num_samples > 1 else \"\") + str(data[0].shape)\n    for i in range(num_samples):\n        data[i] = [get_2d_slice(data[i], view, is_label) for view in range(num_orthog_views)]\n\n    out = []\n    if num_samples == 1:\n        out = data[0]\n    else:\n        # we might need to panel the images. this happens if a transform produces e.g. 4 output images.\n        # In this case, we create a 2-by-2 grid from them. Output will be a list containing n_orthog_views,\n        # each element being either the image (if num_samples is 1) or the panelled image.\n        nrows = int(np.floor(num_samples**0.5))\n        for view in range(num_orthog_views):\n            result = np.asarray([d[view] for d in data])\n            nindex, height, width = result.shape\n            ncols = nindex // nrows\n            # only implemented for square number of images (e.g. 
4 images goes to a 2-by-2 panel)\n            if nindex != nrows * ncols:\n                raise NotImplementedError\n            # want result.shape = (height*nrows, width*ncols), have to be careful about striding\n            result = result.reshape(nrows, ncols, height, width).swapaxes(1, 2).reshape(height * nrows, width * ncols)\n            out.append(result)\n    return out, shape_str\n\n\ndef create_transform_im(\n    transform, transform_args, data, ndim=3, colorbar=False, update_doc=True, seed=0, is_post=False\n):\n    \"\"\"Create an image with the before and after of the transform.\n    Also update the transform's documentation to point to this image.\"\"\"\n\n    transform = transform(**transform_args)\n\n    if not has_matplotlib:\n        raise RuntimeError\n\n    if isinstance(transform, Randomizable):\n        # increment the seed for map transforms so they're different to the array versions.\n        seed = seed + 1 if isinstance(transform, MapTransform) else seed\n        transform.set_random_state(seed)\n\n    out_dir = MONAIEnvVars.doc_images()\n    if out_dir is None:\n        raise RuntimeError(\n            \"Please git clone https://github.com/Project-MONAI/DocImages\"\n            + \" and then set the environment variable `MONAI_DOC_IMAGES`\"\n        )\n    out_dir = os.path.join(out_dir, \"transforms\")\n\n    # Path is transform name\n    transform_name = transform.__class__.__name__\n    out_fname = transform_name + \".png\"\n    out_file = os.path.join(out_dir, out_fname)\n\n    is_map = isinstance(transform, MapTransform)\n    data_in = pre_process_data(deepcopy(data), ndim, is_map, is_post)\n\n    data_tr = transform(deepcopy(data_in))\n\n    images_before, before_shape = get_images(data_in)\n    images_after, after_shape = get_images(data_tr)\n    images = (images_before, images_after)\n    shapes = (before_shape, after_shape)\n\n    labels = None\n    if is_map:\n        labels_before, *_ = get_images(data_in, is_label=True)\n     
   labels_after, *_ = get_images(data_tr, is_label=True)\n        labels = (labels_before, labels_after)\n\n    save_image(images, labels, out_file, transform_name, transform_args, shapes, colorbar)\n\n    if update_doc:\n        base_dir = pathlib.Path(__file__).parent.parent.parent\n        rst_path = os.path.join(base_dir, \"docs\", \"source\", \"transforms.rst\")\n        update_docstring(rst_path, transform_name)\n\n\nif __name__ == \"__main__\":\n    keys = [CommonKeys.IMAGE, CommonKeys.LABEL]\n    data = get_data(keys)\n    create_transform_im(RandFlip, dict(prob=1, spatial_axis=1), data)\n    create_transform_im(RandFlipd, dict(keys=keys, prob=1, spatial_axis=2), data)\n    create_transform_im(Flip, dict(spatial_axis=1), data)\n    create_transform_im(Flipd, dict(keys=keys, spatial_axis=2), data)\n    create_transform_im(Orientation, dict(axcodes=\"RPI\"), data)\n    create_transform_im(Orientationd, dict(keys=keys, axcodes=\"RPI\"), data)\n    create_transform_im(\n        Rand3DElastic, dict(prob=1.0, sigma_range=(1, 2), magnitude_range=(0.5, 0.5), shear_range=(1, 1, 1)), data\n    )\n    create_transform_im(Affine, dict(shear_params=(0, 0.5, 0), image_only=True, padding_mode=\"zeros\"), data)\n    create_transform_im(\n        Affined, dict(keys=keys, shear_params=(0, 0.5, 0), mode=[\"bilinear\", \"nearest\"], padding_mode=\"zeros\"), data\n    )\n    create_transform_im(RandAffine, dict(prob=1, shear_range=(0.5, 0.5), padding_mode=\"zeros\"), data)\n    create_transform_im(\n        RandAffined,\n        dict(keys=keys, prob=1, shear_range=(0.5, 0.5), mode=[\"bilinear\", \"nearest\"], padding_mode=\"zeros\"),\n        data,\n    )\n    create_transform_im(\n        Rand3DElastic, dict(sigma_range=(5, 7), magnitude_range=(50, 150), prob=1, padding_mode=\"zeros\"), data\n    )\n    create_transform_im(\n        Rand2DElastic, dict(prob=1, spacing=(20, 20), magnitude_range=(1, 2), padding_mode=\"zeros\"), data, 2\n    )\n    create_transform_im(\n        
Rand2DElasticd,\n        dict(\n            keys=keys,\n            prob=1,\n            spacing=(20, 20),\n            magnitude_range=(1, 2),\n            padding_mode=\"zeros\",\n            mode=[\"bilinear\", \"nearest\"],\n        ),\n        data,\n        2,\n    )\n    create_transform_im(\n        Rand3DElasticd,\n        dict(\n            keys=keys,\n            sigma_range=(5, 7),\n            magnitude_range=(50, 150),\n            prob=1,\n            padding_mode=\"zeros\",\n            mode=[\"bilinear\", \"nearest\"],\n        ),\n        data,\n    )\n    create_transform_im(Rotate90, dict(spatial_axes=(1, 2)), data)\n    create_transform_im(Rotate90d, dict(keys=keys, spatial_axes=(1, 2)), data)\n    create_transform_im(RandRotate90, dict(prob=1), data)\n    create_transform_im(RandRotate90d, dict(keys=keys, prob=1), data)\n    create_transform_im(Rotate, dict(angle=0.1), data)\n    create_transform_im(Rotated, dict(keys=keys, angle=0.1, mode=[\"bilinear\", \"nearest\"]), data)\n    create_transform_im(RandRotate, dict(prob=1, range_x=[0.4, 0.4]), data)\n    create_transform_im(RandRotated, dict(keys=keys, prob=1, range_x=[0.4, 0.4], mode=[\"bilinear\", \"nearest\"]), data)\n    create_transform_im(Zoom, dict(zoom=0.6), data)\n    create_transform_im(Zoomd, dict(keys=keys, zoom=1.3, mode=[\"area\", \"nearest\"]), data)\n    create_transform_im(RandZoom, dict(prob=1, min_zoom=0.6, max_zoom=0.8), data)\n    create_transform_im(RandZoomd, dict(keys=keys, prob=1, min_zoom=1.3, max_zoom=1.5, mode=[\"area\", \"nearest\"]), data)\n    create_transform_im(ScaleIntensity, dict(minv=0, maxv=10), data, colorbar=True)\n    create_transform_im(ScaleIntensityd, dict(keys=CommonKeys.IMAGE, minv=0, maxv=10), data, colorbar=True)\n    create_transform_im(RandScaleIntensity, dict(prob=1.0, factors=(5, 10)), data, colorbar=True)\n    create_transform_im(\n        RandScaleIntensityd, dict(keys=CommonKeys.IMAGE, prob=1.0, factors=(5, 10)), data, colorbar=True\n    
)\n    create_transform_im(DivisiblePad, dict(k=64), data)\n    create_transform_im(DivisiblePadd, dict(keys=keys, k=64), data)\n    create_transform_im(CropForeground, dict(allow_smaller=False), data)\n    create_transform_im(CropForegroundd, dict(keys=keys, source_key=CommonKeys.IMAGE, allow_smaller=False), data)\n    create_transform_im(RandGaussianNoise, dict(prob=1, mean=0, std=0.1), data)\n    create_transform_im(RandGaussianNoised, dict(keys=CommonKeys.IMAGE, prob=1, mean=0, std=0.1), data)\n    create_transform_im(KSpaceSpikeNoise, dict(loc=(100, 100, 100), k_intensity=13), data)\n    create_transform_im(KSpaceSpikeNoised, dict(keys=CommonKeys.IMAGE, loc=(100, 100, 100), k_intensity=13), data)\n    create_transform_im(RandKSpaceSpikeNoise, dict(prob=1, intensity_range=(10, 13)), data)\n    create_transform_im(RandKSpaceSpikeNoised, dict(keys=CommonKeys.IMAGE, prob=1, intensity_range=(13, 15)), data)\n    create_transform_im(RandRicianNoise, dict(prob=1.0, mean=1, std=0.5), data)\n    create_transform_im(RandRicianNoised, dict(keys=CommonKeys.IMAGE, prob=1.0, mean=1, std=0.5), data)\n    create_transform_im(SavitzkyGolaySmooth, dict(window_length=5, order=1), data)\n    create_transform_im(SavitzkyGolaySmoothd, dict(keys=CommonKeys.IMAGE, window_length=5, order=1), data)\n    create_transform_im(GibbsNoise, dict(alpha=0.8), data)\n    create_transform_im(GibbsNoised, dict(keys=CommonKeys.IMAGE, alpha=0.8), data)\n    create_transform_im(RandGibbsNoise, dict(prob=1.0, alpha=(0.6, 0.8)), data)\n    create_transform_im(RandGibbsNoised, dict(keys=CommonKeys.IMAGE, prob=1.0, alpha=(0.6, 0.8)), data)\n    create_transform_im(ShiftIntensity, dict(offset=1), data, colorbar=True)\n    create_transform_im(ShiftIntensityd, dict(keys=CommonKeys.IMAGE, offset=1), data, colorbar=True)\n    create_transform_im(RandShiftIntensity, dict(prob=1.0, offsets=(10, 20)), data, colorbar=True)\n    create_transform_im(\n        RandShiftIntensityd, dict(keys=CommonKeys.IMAGE, 
prob=1.0, offsets=(10, 20)), data, colorbar=True\n    )\n    create_transform_im(StdShiftIntensity, dict(factor=10), data, colorbar=True)\n    create_transform_im(StdShiftIntensityd, dict(keys=CommonKeys.IMAGE, factor=10), data, colorbar=True)\n    create_transform_im(RandStdShiftIntensity, dict(prob=1.0, factors=(5, 10)), data, colorbar=True)\n    create_transform_im(\n        RandStdShiftIntensityd, dict(keys=CommonKeys.IMAGE, prob=1.0, factors=(5, 10)), data, colorbar=True\n    )\n    create_transform_im(RandBiasField, dict(prob=1, coeff_range=(0.2, 0.3)), data)\n    create_transform_im(RandBiasFieldd, dict(keys=CommonKeys.IMAGE, prob=1, coeff_range=(0.2, 0.3)), data)\n    create_transform_im(NormalizeIntensity, dict(subtrahend=0, divisor=10), data, colorbar=True)\n    create_transform_im(NormalizeIntensityd, dict(keys=CommonKeys.IMAGE, subtrahend=0, divisor=10), data, colorbar=True)\n    create_transform_im(ThresholdIntensity, dict(threshold=0.4, above=False, cval=0.9), data, colorbar=True)\n    create_transform_im(\n        ThresholdIntensityd, dict(keys=CommonKeys.IMAGE, threshold=0.4, above=False, cval=0.9), data, colorbar=True\n    )\n    create_transform_im(ScaleIntensityRange, dict(a_min=0, a_max=1, b_min=1, b_max=10), data, colorbar=True)\n    create_transform_im(\n        ScaleIntensityRanged, dict(keys=CommonKeys.IMAGE, a_min=0, a_max=1, b_min=1, b_max=10), data, colorbar=True\n    )\n    create_transform_im(ScaleIntensityRangePercentiles, dict(lower=5, upper=95, b_min=1, b_max=10), data, colorbar=True)\n    create_transform_im(\n        ScaleIntensityRangePercentilesd,\n        dict(keys=CommonKeys.IMAGE, lower=5, upper=95, b_min=1, b_max=10),\n        data,\n        colorbar=True,\n    )\n    create_transform_im(AdjustContrast, dict(gamma=2), data, colorbar=True)\n    create_transform_im(AdjustContrastd, dict(keys=CommonKeys.IMAGE, gamma=2), data, colorbar=True)\n    create_transform_im(RandAdjustContrast, dict(prob=1, gamma=(1.5, 2)), data, 
colorbar=True)\n    create_transform_im(RandAdjustContrastd, dict(keys=CommonKeys.IMAGE, prob=1, gamma=(1.5, 2)), data, colorbar=True)\n    create_transform_im(MaskIntensity, dict(mask_data=data[CommonKeys.IMAGE], select_fn=lambda x: x > 0.3), data)\n    create_transform_im(\n        MaskIntensityd, dict(keys=CommonKeys.IMAGE, mask_key=CommonKeys.IMAGE, select_fn=lambda x: x > 0.3), data\n    )\n    create_transform_im(ForegroundMask, dict(invert=True), data)\n    create_transform_im(ForegroundMaskd, dict(keys=CommonKeys.IMAGE, invert=True), data)\n    create_transform_im(GaussianSmooth, dict(sigma=2), data)\n    create_transform_im(GaussianSmoothd, dict(keys=CommonKeys.IMAGE, sigma=2), data)\n    create_transform_im(MedianSmooth, dict(radius=3), data)\n    create_transform_im(MedianSmoothD, dict(keys=keys, radius=1), data)\n    create_transform_im(RandGaussianSmooth, dict(prob=1.0, sigma_x=(1, 2)), data)\n    create_transform_im(RandGaussianSmoothd, dict(keys=CommonKeys.IMAGE, prob=1.0, sigma_x=(1, 2)), data)\n    create_transform_im(GaussianSharpen, dict(), GaussianSmoothd(CommonKeys.IMAGE, 2)(data))\n    create_transform_im(GaussianSharpend, dict(keys=CommonKeys.IMAGE), GaussianSmoothd(CommonKeys.IMAGE, 2)(data))\n    create_transform_im(RandGaussianSharpen, dict(prob=1), GaussianSmoothd(CommonKeys.IMAGE, 2)(data))\n    create_transform_im(\n        RandGaussianSharpend, dict(keys=CommonKeys.IMAGE, prob=1), GaussianSmoothd(CommonKeys.IMAGE, 2)(data)\n    )\n    create_transform_im(RandHistogramShift, dict(prob=1, num_control_points=3), data, colorbar=True)\n    create_transform_im(\n        RandHistogramShiftd, dict(keys=CommonKeys.IMAGE, prob=1, num_control_points=3), data, colorbar=True\n    )\n    create_transform_im(RandCoarseDropout, dict(prob=1, holes=200, spatial_size=20, fill_value=0), data)\n    create_transform_im(\n        RandCoarseDropoutd, dict(keys=CommonKeys.IMAGE, prob=1, holes=200, spatial_size=20, fill_value=0), data\n    )\n    
create_transform_im(RandCoarseShuffle, dict(prob=1, holes=200, spatial_size=20), data)\n    create_transform_im(RandCoarseShuffled, dict(keys=CommonKeys.IMAGE, prob=1, holes=200, spatial_size=20), data)\n    create_transform_im(HistogramNormalize, dict(num_bins=10), data)\n    create_transform_im(HistogramNormalized, dict(keys=CommonKeys.IMAGE, num_bins=10), data)\n    create_transform_im(SpatialPad, dict(spatial_size=(300, 300, 300)), data)\n    create_transform_im(SpatialPadd, dict(keys=keys, spatial_size=(300, 300, 300)), data)\n    create_transform_im(BorderPad, dict(spatial_border=10), data)\n    create_transform_im(BorderPadd, dict(keys=keys, spatial_border=10), data)\n    create_transform_im(SpatialCrop, dict(roi_center=(75, 75, 75), roi_size=(100, 100, 100)), data)\n    create_transform_im(SpatialCropd, dict(keys=keys, roi_center=(75, 75, 75), roi_size=(100, 100, 100)), data)\n    create_transform_im(CenterSpatialCrop, dict(roi_size=(100, 100, 100)), data)\n    create_transform_im(CenterSpatialCropd, dict(keys=keys, roi_size=(100, 100, 100)), data)\n    create_transform_im(RandSpatialCrop, dict(roi_size=(100, 100, 100), random_size=False), data)\n    create_transform_im(RandSpatialCropd, dict(keys=keys, roi_size=(100, 100, 100), random_size=False), data)\n    create_transform_im(RandSpatialCropSamples, dict(num_samples=4, roi_size=(100, 100, 100), random_size=False), data)\n    create_transform_im(\n        RandSpatialCropSamplesd, dict(keys=keys, num_samples=4, roi_size=(100, 100, 100), random_size=False), data\n    )\n    create_transform_im(\n        RandWeightedCrop, dict(spatial_size=(100, 100, 100), num_samples=4, weight_map=data[CommonKeys.IMAGE] > 0), data\n    )\n    create_transform_im(\n        RandWeightedCropd, dict(keys=keys, spatial_size=(100, 100, 100), num_samples=4, w_key=CommonKeys.IMAGE), data\n    )\n    create_transform_im(\n        RandCropByPosNegLabel,\n        dict(spatial_size=(100, 100, 100), label=data[CommonKeys.LABEL], neg=0, 
num_samples=4),\n        data,\n    )\n    create_transform_im(\n        RandCropByPosNegLabeld,\n        dict(keys=keys, spatial_size=(100, 100, 100), label_key=CommonKeys.LABEL, neg=0, num_samples=4),\n        data,\n    )\n    create_transform_im(\n        RandCropByLabelClasses,\n        dict(\n            spatial_size=(100, 100, 100), label=data[CommonKeys.LABEL] > 0, num_classes=2, ratios=[0, 1], num_samples=4\n        ),\n        data,\n    )\n    create_transform_im(\n        RandCropByLabelClassesd,\n        dict(\n            keys=keys,\n            spatial_size=(100, 100, 100),\n            label_key=CommonKeys.LABEL,\n            num_classes=2,\n            ratios=[0, 1],\n            num_samples=4,\n        ),\n        data,\n    )\n    create_transform_im(ResizeWithPadOrCrop, dict(spatial_size=(100, 100, 100)), data)\n    create_transform_im(ResizeWithPadOrCropd, dict(keys=keys, spatial_size=(100, 100, 100)), data)\n    create_transform_im(RandScaleCrop, dict(roi_scale=0.4), data)\n    create_transform_im(RandScaleCropd, dict(keys=keys, roi_scale=0.4), data)\n    create_transform_im(CenterScaleCrop, dict(roi_scale=0.4), data)\n    create_transform_im(CenterScaleCropd, dict(keys=keys, roi_scale=0.4), data)\n    create_transform_im(AsDiscrete, dict(to_onehot=None, threshold=10), data, is_post=True, colorbar=True)\n    create_transform_im(AsDiscreted, dict(keys=CommonKeys.LABEL, to_onehot=None, threshold=10), data, is_post=True)\n    create_transform_im(LabelFilter, dict(applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True)\n    create_transform_im(\n        LabelFilterd, dict(keys=CommonKeys.LABEL, applied_labels=(1, 2, 3, 4, 5, 6)), data, is_post=True\n    )\n    create_transform_im(LabelToContour, dict(), data, is_post=True)\n    create_transform_im(LabelToContourd, dict(keys=CommonKeys.LABEL), data, is_post=True)\n    create_transform_im(Spacing, dict(pixdim=(5, 5, 5)), data)\n    create_transform_im(Spacingd, dict(keys=keys, pixdim=(5, 5, 5), 
mode=[\"bilinear\", \"nearest\"]), data)\n    create_transform_im(RandAxisFlip, dict(prob=1), data)\n    create_transform_im(RandAxisFlipd, dict(keys=keys, prob=1), data)\n    create_transform_im(Resize, dict(spatial_size=(100, 100, 100)), data)\n    create_transform_im(Resized, dict(keys=keys, spatial_size=(100, 100, 100), mode=[\"area\", \"nearest\"]), data)\n    data_binary = deepcopy(data)\n    data_binary[CommonKeys.LABEL] = (data_binary[CommonKeys.LABEL] > 0).astype(np.float32)\n    create_transform_im(KeepLargestConnectedComponent, dict(applied_labels=1), data_binary, is_post=True, ndim=2)\n    create_transform_im(\n        KeepLargestConnectedComponentd, dict(keys=CommonKeys.LABEL, applied_labels=1), data_binary, is_post=True, ndim=2\n    )\n    create_transform_im(RemoveSmallObjects, dict(min_size=100), data_binary, is_post=True, ndim=2)\n    create_transform_im(\n        RemoveSmallObjectsd, dict(keys=CommonKeys.LABEL, min_size=100), data_binary, is_post=True, ndim=2\n    )\n    create_transform_im(\n        GridDistortion, dict(num_cells=3, distort_steps=[(1.5,) * 4] * 3, mode=\"nearest\", padding_mode=\"zeros\"), data\n    )\n    create_transform_im(\n        GridDistortiond,\n        dict(\n            keys=keys, num_cells=3, distort_steps=[(1.5,) * 4] * 3, mode=[\"bilinear\", \"nearest\"], padding_mode=\"zeros\"\n        ),\n        data,\n    )\n    create_transform_im(RandGridDistortion, dict(num_cells=3, prob=1.0, distort_limit=(-0.1, 0.1)), data)\n    create_transform_im(\n        RandGridDistortiond,\n        dict(keys=keys, num_cells=4, prob=1.0, distort_limit=(-0.2, 0.2), mode=[\"bilinear\", \"nearest\"]),\n        data,\n    )\n    create_transform_im(\n        RandSmoothFieldAdjustContrast, dict(spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0), data\n    )\n    create_transform_im(\n        RandSmoothFieldAdjustContrastd,\n        dict(keys=keys, spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0),\n        data,\n  
  )\n    create_transform_im(\n        RandSmoothFieldAdjustIntensity,\n        dict(spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0, gamma=(0.5, 4.5)),\n        data,\n    )\n    create_transform_im(\n        RandSmoothFieldAdjustIntensityd,\n        dict(keys=keys, spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0, gamma=(0.5, 4.5)),\n        data,\n    )\n\n    create_transform_im(\n        RandSmoothDeform,\n        dict(spatial_size=(217, 217, 217), rand_size=(10, 10, 10), prob=1.0, def_range=0.05, grid_mode=\"bilinear\"),\n        data,\n    )\n    create_transform_im(\n        RandSmoothDeformd,\n        dict(\n            keys=keys,\n            spatial_size=(217, 217, 217),\n            rand_size=(10, 10, 10),\n            prob=1.0,\n            def_range=0.05,\n            grid_mode=\"bilinear\",\n        ),\n        data,\n    )\n"
  },
  {
    "path": "monai/transforms/utils_morphological_ops.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom monai.config import NdarrayOrTensor\nfrom monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep\n\n__all__ = [\"erode\", \"dilate\"]\n\n\ndef erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor:\n    \"\"\"\n    Erode 2D/3D binary mask.\n\n    Args:\n        mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.\n        filter_size: erosion filter size, has to be odd numbers, default to be 3.\n        pad_value: the filled value for padding. We need to pad the input before filtering\n                   to keep the output with the same size as input. Usually use default value\n                   and not changed.\n\n    Return:\n        eroded mask, same shape and data type as input.\n\n    Example:\n\n        .. 
code-block:: python\n\n            # define a naive mask\n            mask = torch.zeros(3,2,3,3,3)\n            mask[:,:,1,1,1] = 1.0\n            filter_size = 3\n            erode_result = erode(mask, filter_size)  # expect torch.zeros(3,2,3,3,3)\n            dilate_result = dilate(mask, filter_size)  # expect torch.ones(3,2,3,3,3)\n    \"\"\"\n    mask_t, *_ = convert_data_type(mask, torch.Tensor)\n    res_mask_t = erode_t(mask_t, filter_size=filter_size, pad_value=pad_value)\n    res_mask: NdarrayOrTensor\n    res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask)\n    return res_mask\n\n\ndef dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> NdarrayOrTensor:\n    \"\"\"\n    Dilate 2D/3D binary mask.\n\n    Args:\n        mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.\n        filter_size: dilation filter size, has to be odd numbers, default to be 3.\n        pad_value: the filled value for padding. We need to pad the input before filtering\n                   to keep the output with the same size as input. Usually use default value\n                   and not changed.\n\n    Return:\n        dilated mask, same shape and data type as input.\n\n    Example:\n\n        .. 
code-block:: python\n\n            # define a naive mask\n            mask = torch.zeros(3,2,3,3,3)\n            mask[:,:,1,1,1] = 1.0\n            filter_size = 3\n            erode_result = erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3)\n            dilate_result = dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3)\n    \"\"\"\n    mask_t, *_ = convert_data_type(mask, torch.Tensor)\n    res_mask_t = dilate_t(mask_t, filter_size=filter_size, pad_value=pad_value)\n    res_mask: NdarrayOrTensor\n    res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask)\n    return res_mask\n\n\ndef get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequence[int], pad_value: float) -> Tensor:\n    \"\"\"\n    Apply a morphological filter to a 2D/3D binary mask tensor.\n\n    Args:\n        mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.\n        filter_size: morphological filter size, has to be odd numbers.\n        pad_value: the filled value for padding. 
We need to pad the input before filtering\n                   to keep the output with the same size as input.\n\n    Return:\n        Tensor: Morphological filter result mask, same shape as input.\n    \"\"\"\n    spatial_dims = len(mask_t.shape) - 2\n    if spatial_dims not in [2, 3]:\n        raise ValueError(\n            f\"spatial_dims must be either 2 or 3, \"\n            f\"got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}.\"\n        )\n\n    # Define the structuring element\n    filter_size = ensure_tuple_rep(filter_size, spatial_dims)\n    if any(size % 2 == 0 for size in filter_size):\n        raise ValueError(f\"All dimensions in filter_size must be odd numbers, got {filter_size}.\")\n\n    structuring_element = torch.ones((mask_t.shape[1], mask_t.shape[1]) + filter_size).to(mask_t.device)\n\n    # Pad the input tensor to handle border pixels\n    # Calculate padding size\n    pad_size = [size // 2 for size in filter_size for _ in range(2)]\n\n    input_padded = F.pad(mask_t.float(), pad_size, mode=\"constant\", value=pad_value)\n\n    # Apply filter operation\n    conv_fn = F.conv2d if spatial_dims == 2 else F.conv3d\n    output = conv_fn(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...])\n\n    return output\n\n\ndef erode_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor:\n    \"\"\"\n    Erode 2D/3D binary mask with data type as torch tensor.\n\n    Args:\n        mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.\n        filter_size: erosion filter size, has to be odd numbers, default to be 3.\n        pad_value: the filled value for padding. We need to pad the input before filtering\n                   to keep the output with the same size as input. 
Usually use default value\n                   and not changed.\n\n    Return:\n        Tensor: eroded mask, same shape as input.\n    \"\"\"\n\n    output = get_morphological_filter_result_t(mask_t, filter_size, pad_value)\n\n    # Set output values based on the minimum value within the structuring element\n    output = torch.where(torch.abs(output - 1.0) < 1e-7, 1.0, 0.0)\n\n    return output\n\n\ndef dilate_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor:\n    \"\"\"\n    Dilate 2D/3D binary mask with data type as torch tensor.\n\n    Args:\n        mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.\n        filter_size: dilation filter size, has to be odd numbers, default to be 3.\n        pad_value: the filled value for padding. We need to pad the input before filtering\n                   to keep the output with the same size as input. Usually use default value\n                   and not changed.\n\n    Return:\n        Tensor: dilated mask, same shape as input.\n    \"\"\"\n    output = get_morphological_filter_result_t(mask_t, filter_size, pad_value)\n\n    # Set output values based on the maximum value within the structuring element\n    output = torch.where(output > 0, 1.0, 0.0)\n\n    return output\n"
  },
  {
    "path": "monai/transforms/utils_pytorch_numpy_unification.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Sequence\nfrom typing import TypeVar\n\nimport numpy as np\nimport torch\n\nfrom monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor\nfrom monai.utils.type_conversion import convert_data_type, convert_to_dst_type\n\n__all__ = [\n    \"allclose\",\n    \"moveaxis\",\n    \"in1d\",\n    \"clip\",\n    \"percentile\",\n    \"where\",\n    \"argwhere\",\n    \"argsort\",\n    \"nonzero\",\n    \"floor_divide\",\n    \"unravel_index\",\n    \"unravel_indices\",\n    \"ravel\",\n    \"any_np_pt\",\n    \"maximum\",\n    \"concatenate\",\n    \"cumsum\",\n    \"isfinite\",\n    \"searchsorted\",\n    \"repeat\",\n    \"isnan\",\n    \"ascontiguousarray\",\n    \"stack\",\n    \"mode\",\n    \"unique\",\n    \"max\",\n    \"min\",\n    \"median\",\n    \"mean\",\n    \"std\",\n    \"softplus\",\n]\n\n\ndef softplus(x: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"stable softplus through `np.logaddexp` with equivalent implementation for torch.\n\n    Args:\n        x: array/tensor.\n\n    Returns:\n        Softplus of the input.\n    \"\"\"\n    if isinstance(x, np.ndarray):\n        return np.logaddexp(np.zeros_like(x), x)\n    return torch.logaddexp(torch.zeros_like(x), x)\n\n\ndef allclose(a: NdarrayTensor, b: NdarrayOrTensor, rtol=1e-5, atol=1e-8, equal_nan=False) -> bool:\n    
\"\"\"`np.allclose` with equivalent implementation for torch.\"\"\"\n    b, *_ = convert_to_dst_type(b, a, wrap_sequence=True)\n    if isinstance(a, np.ndarray):\n        return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)\n    return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)  # type: ignore\n\n\ndef moveaxis(x: NdarrayOrTensor, src: int | Sequence[int], dst: int | Sequence[int]) -> NdarrayOrTensor:\n    \"\"\"`moveaxis` for pytorch and numpy\"\"\"\n    if isinstance(x, torch.Tensor):\n        return torch.movedim(x, src, dst)  # type: ignore\n    return np.moveaxis(x, src, dst)\n\n\ndef in1d(x, y):\n    \"\"\"`np.in1d` with equivalent implementation for torch.\"\"\"\n    if isinstance(x, np.ndarray):\n        return np.isin(x, y)\n    return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1)\n\n\ndef clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:\n    \"\"\"`np.clip` with equivalent implementation for torch.\"\"\"\n    result: NdarrayOrTensor\n    if isinstance(a, np.ndarray):\n        result = np.clip(a, a_min, a_max)\n    else:\n        result = torch.clamp(a, a_min, a_max)\n    return result\n\n\ndef percentile(\n    x: NdarrayOrTensor, q, dim: int | None = None, keepdim: bool = False, **kwargs\n) -> NdarrayOrTensor | float | int:\n    \"\"\"`np.percentile` with equivalent implementation for torch.\n\n    Pytorch uses `quantile`. For more details please refer to:\n    https://pytorch.org/docs/stable/generated/torch.quantile.html.\n    https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.\n\n    Args:\n        x: input data.\n        q: percentile to compute (should in range 0 <= q <= 100).\n        dim: the dim along which the percentiles are computed. 
default is to compute the percentile\n            along a flattened version of the array.\n        keepdim: whether the output data has dim retained or not.\n        kwargs: if `x` is numpy array, additional args for `np.percentile`, more details:\n            https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.\n\n    Returns:\n        Resulting value (scalar)\n    \"\"\"\n    q_np = convert_data_type(q, output_type=np.ndarray, wrap_sequence=True)[0]\n    if ((q_np < 0) | (q_np > 100)).any():\n        raise ValueError(f\"q values must be in [0, 100], got values: {q}.\")\n    result: NdarrayOrTensor | float | int\n    if isinstance(x, np.ndarray) or (isinstance(x, torch.Tensor) and torch.numel(x) > 1_000_000):  # pytorch#64947\n        _x = convert_data_type(x, output_type=np.ndarray)[0]\n        result = np.percentile(_x, q_np, axis=dim, keepdims=keepdim, **kwargs)\n        result = convert_to_dst_type(result, x)[0]\n    else:\n        q = convert_to_dst_type(q_np / 100.0, x)[0]\n        result = torch.quantile(x, q, dim=dim, keepdim=keepdim)\n    return result\n\n\ndef where(condition: NdarrayOrTensor, x=None, y=None) -> NdarrayOrTensor:\n    \"\"\"\n    Note that `torch.where` may convert y.dtype to x.dtype.\n    \"\"\"\n    result: NdarrayOrTensor\n    if isinstance(condition, np.ndarray):\n        if x is not None:\n            result = np.where(condition, x, y)\n        else:\n            result = np.where(condition)  # type: ignore\n    else:\n        if x is not None:\n            x = torch.as_tensor(x, device=condition.device)\n            y = torch.as_tensor(y, device=condition.device, dtype=x.dtype)\n            result = torch.where(condition, x, y)\n        else:\n            result = torch.where(condition)  # type: ignore\n    return result\n\n\ndef argwhere(a: NdarrayTensor) -> NdarrayTensor:\n    \"\"\"`np.argwhere` with equivalent implementation for torch.\n\n    Args:\n        a: input data.\n\n    Returns:\n        Indices of 
elements that are non-zero. Indices are grouped by element.\n        This array will have shape (N, a.ndim) where N is the number of non-zero items.\n    \"\"\"\n    if isinstance(a, np.ndarray):\n        return np.argwhere(a)  # type: ignore\n    return torch.argwhere(a)  # type: ignore\n\n\ndef argsort(a: NdarrayTensor, axis: int | None = -1) -> NdarrayTensor:\n    \"\"\"`np.argsort` with equivalent implementation for torch.\n\n    Args:\n        a: the array/tensor to sort.\n        axis: axis along which to sort.\n\n    Returns:\n        Array/Tensor of indices that sort a along the specified axis.\n    \"\"\"\n    if isinstance(a, np.ndarray):\n        return np.argsort(a, axis=axis)  # type: ignore\n    return torch.argsort(a, dim=axis)  # type: ignore\n\n\ndef nonzero(x: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"`np.nonzero` with equivalent implementation for torch.\n\n    Args:\n        x: array/tensor.\n\n    Returns:\n        Index unravelled for given shape\n    \"\"\"\n    if isinstance(x, np.ndarray):\n        return np.nonzero(x)[0]\n    return torch.nonzero(x).flatten()\n\n\ndef floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor:\n    \"\"\"`np.floor_divide` with equivalent implementation for torch.\n\n    As of pt1.8, use `torch.div(..., rounding_mode=\"floor\")`, and\n    before that, use `torch.floor_divide`.\n\n    Args:\n        a: first array/tensor\n        b: scalar to divide by\n\n    Returns:\n        Element-wise floor division between two arrays/tensors.\n    \"\"\"\n    if isinstance(a, torch.Tensor):\n        return torch.floor_divide(a, b)\n    else:\n        return np.asarray(np.floor_divide(a, b))\n\n\ndef unravel_index(idx, shape) -> NdarrayOrTensor:\n    \"\"\"`np.unravel_index` with equivalent implementation for torch.\n\n    Args:\n        idx: index to unravel.\n        shape: shape of array/tensor.\n\n    Returns:\n        Index unravelled for given shape\n    \"\"\"\n    if isinstance(idx, torch.Tensor):\n        
coord = []\n        for dim in reversed(shape):\n            coord.append(idx % dim)\n            idx = floor_divide(idx, dim)\n        return torch.stack(coord[::-1])\n    return np.asarray(np.unravel_index(idx, shape))\n\n\ndef unravel_indices(idx, shape) -> NdarrayOrTensor:\n    \"\"\"Computing unravel coordinates from indices.\n\n    Args:\n        idx: a sequence of indices to unravel.\n        shape: shape of array/tensor.\n\n    Returns:\n        Stacked indices unravelled for given shape\n    \"\"\"\n    lib_stack = torch.stack if isinstance(idx[0], torch.Tensor) else np.stack\n    return lib_stack([unravel_index(i, shape) for i in idx])  # type: ignore\n\n\ndef ravel(x: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"`np.ravel` with equivalent implementation for torch.\n\n    Args:\n        x: array/tensor to ravel.\n\n    Returns:\n        Return a contiguous flattened array/tensor.\n    \"\"\"\n    if isinstance(x, torch.Tensor):\n        if hasattr(torch, \"ravel\"):  # `ravel` is new in torch 1.8.0\n            return x.ravel()\n        return x.flatten().contiguous()\n    return np.ravel(x)\n\n\ndef any_np_pt(x: NdarrayOrTensor, axis: int | Sequence[int]) -> NdarrayOrTensor:\n    \"\"\"`np.any` with equivalent implementation for torch.\n\n    For pytorch, convert to boolean for compatibility with older versions.\n\n    Args:\n        x: input array/tensor.\n        axis: axis to perform `any` over.\n\n    Returns:\n        The result of the `any` reduction of `x` over `axis`.\n    \"\"\"\n    if isinstance(x, np.ndarray):\n        return np.any(x, axis)  # type: ignore\n\n    # pytorch can't handle multiple dimensions to `any` so loop across them\n    axis = [axis] if not isinstance(axis, Sequence) else axis\n    for ax in axis:\n        try:\n            x = torch.any(x, ax)\n        except RuntimeError:\n            # older versions of pytorch require the input to be cast to boolean\n            x = torch.any(x.bool(), ax)\n    return x\n\n\ndef maximum(a: 
NdarrayOrTensor, b: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"`np.maximum` with equivalent implementation for torch.\n\n    Args:\n        a: first array/tensor.\n        b: second array/tensor.\n\n    Returns:\n        Element-wise maximum between two arrays/tensors.\n    \"\"\"\n    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):\n        return torch.maximum(a, b)\n    return np.maximum(a, b)\n\n\ndef concatenate(to_cat: Sequence[NdarrayOrTensor], axis: int = 0, out=None) -> NdarrayOrTensor:\n    \"\"\"`np.concatenate` with equivalent implementation for torch (`torch.cat`).\"\"\"\n    if isinstance(to_cat[0], np.ndarray):\n        return np.concatenate(to_cat, axis, out)  # type: ignore\n    return torch.cat(to_cat, dim=axis, out=out)  # type: ignore\n\n\ndef cumsum(a: NdarrayOrTensor, axis=None, **kwargs) -> NdarrayOrTensor:\n    \"\"\"\n    `np.cumsum` with equivalent implementation for torch.\n\n    Args:\n        a: input data to compute cumsum.\n        axis: expected axis to compute cumsum.\n        kwargs: if `a` is PyTorch Tensor, additional args for `torch.cumsum`, more details:\n            https://pytorch.org/docs/stable/generated/torch.cumsum.html.\n\n    \"\"\"\n\n    if isinstance(a, np.ndarray):\n        return np.cumsum(a, axis)  # type: ignore\n    if axis is None:\n        return torch.cumsum(a[:], 0, **kwargs)\n    return torch.cumsum(a, dim=axis, **kwargs)\n\n\ndef isfinite(x: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"`np.isfinite` with equivalent implementation for torch.\"\"\"\n    if not isinstance(x, torch.Tensor):\n        return np.isfinite(x)  # type: ignore\n    return torch.isfinite(x)\n\n\ndef searchsorted(a: NdarrayTensor, v: NdarrayOrTensor, right=False, sorter=None, **kwargs) -> NdarrayTensor:\n    \"\"\"\n    `np.searchsorted` with equivalent implementation for torch.\n\n    Args:\n        a: numpy array or tensor, containing monotonically increasing sequence on the innermost dimension.\n        v: 
containing the search values.\n        right: if False, return the first suitable location that is found, if True, return the last such index.\n        sorter: if `a` is numpy array, optional array of integer indices that sort array `a` into ascending order.\n        kwargs: if `a` is PyTorch Tensor, additional args for `torch.searchsorted`, more details:\n            https://pytorch.org/docs/stable/generated/torch.searchsorted.html.\n\n    \"\"\"\n    side = \"right\" if right else \"left\"\n    if isinstance(a, np.ndarray):\n        return np.searchsorted(a, v, side, sorter)  # type: ignore\n    return torch.searchsorted(a, v, right=right, **kwargs)  # type: ignore\n\n\ndef repeat(a: NdarrayOrTensor, repeats: int, axis: int | None = None, **kwargs) -> NdarrayOrTensor:\n    \"\"\"\n    `np.repeat` with equivalent implementation for torch (`repeat_interleave`).\n\n    Args:\n        a: input data to repeat.\n        repeats: number of repetitions for each element, repeats is broadcast to fit the shape of the given axis.\n        axis: axis along which to repeat values.\n        kwargs: if `a` is PyTorch Tensor, additional args for `torch.repeat_interleave`, more details:\n            https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html.\n\n    \"\"\"\n    if isinstance(a, np.ndarray):\n        return np.repeat(a, repeats, axis)\n    return torch.repeat_interleave(a, repeats, dim=axis, **kwargs)\n\n\ndef isnan(x: NdarrayOrTensor) -> NdarrayOrTensor:\n    \"\"\"`np.isnan` with equivalent implementation for torch.\n\n    Args:\n        x: array/tensor.\n\n    \"\"\"\n    if isinstance(x, np.ndarray):\n        return np.isnan(x)  # type: ignore\n    return torch.isnan(x)\n\n\nT = TypeVar(\"T\")\n\n\ndef ascontiguousarray(x: NdarrayTensor | T, **kwargs) -> NdarrayOrTensor | T:\n    \"\"\"`np.ascontiguousarray` with equivalent implementation for torch (`contiguous`).\n\n    Args:\n        x: array/tensor.\n        kwargs: if `x` is PyTorch Tensor, 
additional args for `torch.contiguous`, more details:\n            https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html.\n\n    \"\"\"\n    if isinstance(x, np.ndarray):\n        if x.ndim == 0:\n            return x\n        return np.ascontiguousarray(x)\n    if isinstance(x, torch.Tensor):\n        return x.contiguous(**kwargs)\n    return x\n\n\ndef stack(x: Sequence[NdarrayTensor], dim: int) -> NdarrayTensor:\n    \"\"\"`np.stack` with equivalent implementation for torch.\n\n    Args:\n        x: array/tensor.\n        dim: dimension along which to perform the stack (referred to as `axis` by numpy).\n    \"\"\"\n    if isinstance(x[0], np.ndarray):\n        return np.stack(x, dim)  # type: ignore\n    return torch.stack(x, dim)  # type: ignore\n\n\ndef mode(x: NdarrayTensor, dim: int = -1, to_long: bool = True) -> NdarrayTensor:\n    \"\"\"`torch.mode` with equivalent implementation for numpy.\n\n    Args:\n        x: array/tensor.\n        dim: dimension along which to perform `mode` (referred to as `axis` by numpy).\n        to_long: convert input to long before performing mode.\n    \"\"\"\n    dtype = torch.int64 if to_long else None\n    x_t, *_ = convert_data_type(x, torch.Tensor, dtype=dtype)\n    o_t = torch.mode(x_t, dim).values\n    o, *_ = convert_to_dst_type(o_t, x)\n    return o\n\n\ndef unique(x: NdarrayTensor, **kwargs) -> NdarrayTensor:\n    \"\"\"`torch.unique` with equivalent implementation for numpy.\n\n    Args:\n        x: array/tensor.\n    \"\"\"\n    return np.unique(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.unique(x, **kwargs)  # type: ignore\n\n\ndef linalg_inv(x: NdarrayTensor) -> NdarrayTensor:\n    \"\"\"`torch.linalg.inv` with equivalent implementation for numpy.\n\n    Args:\n        x: array/tensor.\n    \"\"\"\n    if isinstance(x, torch.Tensor) and hasattr(torch, \"inverse\"):  # pytorch 1.7.0\n        return torch.inverse(x)  # type: ignore\n    return torch.linalg.inv(x) if isinstance(x, 
torch.Tensor) else np.linalg.inv(x)  # type: ignore\n\n\ndef max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:\n    \"\"\"`torch.max` with equivalent implementation for numpy\n\n    Args:\n        x: array/tensor.\n\n    Returns:\n        the maximum of x.\n\n    \"\"\"\n\n    ret: NdarrayTensor\n    if dim is None:\n        ret = np.max(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.max(x, **kwargs)  # type: ignore\n    else:\n        if isinstance(x, (np.ndarray, list)):\n            ret = np.max(x, axis=dim, **kwargs)\n        else:\n            ret = torch.max(x, int(dim), **kwargs)  # type: ignore\n\n    return ret[0] if isinstance(ret, tuple) else ret\n\n\ndef mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:\n    \"\"\"`torch.mean` with equivalent implementation for numpy\n\n    Args:\n        x: array/tensor.\n\n    Returns:\n        the mean of x.\n    \"\"\"\n\n    ret: NdarrayTensor\n    if dim is None:\n        ret = np.mean(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.mean(x, **kwargs)  # type: ignore\n    else:\n        if isinstance(x, (np.ndarray, list)):\n            ret = np.mean(x, axis=dim, **kwargs)\n        else:\n            ret = torch.mean(x, int(dim), **kwargs)  # type: ignore\n\n    return ret\n\n\ndef median(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:\n    \"\"\"`torch.median` with equivalent implementation for numpy\n\n    Args:\n        x: array/tensor.\n\n    Returns:\n        the median of x.\n    \"\"\"\n\n    ret: NdarrayTensor\n    if dim is None:\n        ret = np.median(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.median(x, **kwargs)  # type: ignore\n    else:\n        if isinstance(x, (np.ndarray, list)):\n            ret = np.median(x, axis=dim, **kwargs)\n        else:\n            ret = torch.median(x, int(dim), **kwargs)  # type: ignore\n\n    return ret\n\n\ndef min(x: 
NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:\n    \"\"\"`torch.min` with equivalent implementation for numpy\n\n    Args:\n        x: array/tensor.\n\n    Returns:\n        the minimum of x.\n    \"\"\"\n\n    ret: NdarrayTensor\n    if dim is None:\n        ret = np.min(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.min(x, **kwargs)  # type: ignore\n    else:\n        if isinstance(x, (np.ndarray, list)):\n            ret = np.min(x, axis=dim, **kwargs)\n        else:\n            ret = torch.min(x, int(dim), **kwargs)  # type: ignore\n\n    return ret[0] if isinstance(ret, tuple) else ret\n\n\ndef std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor:\n    \"\"\"`torch.std` with equivalent implementation for numpy\n\n    Args:\n        x: array/tensor.\n\n    Returns:\n        the standard deviation of x.\n    \"\"\"\n\n    ret: NdarrayTensor\n    if dim is None:\n        ret = np.std(x) if isinstance(x, (np.ndarray, list)) else torch.std(x, unbiased)  # type: ignore\n    else:\n        if isinstance(x, (np.ndarray, list)):\n            ret = np.std(x, axis=dim)\n        else:\n            ret = torch.std(x, int(dim), unbiased)  # type: ignore\n\n    return ret\n\n\ndef sum(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:\n    \"\"\"`torch.sum` with equivalent implementation for numpy\n\n    Args:\n        x: array/tensor.\n\n    Returns:\n        the sum of x.\n    \"\"\"\n\n    ret: NdarrayTensor\n    if dim is None:\n        ret = np.sum(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.sum(x, **kwargs)  # type: ignore\n    else:\n        if isinstance(x, (np.ndarray, list)):\n            ret = np.sum(x, axis=dim, **kwargs)\n        else:\n            ret = torch.sum(x, int(dim), **kwargs)  # type: ignore\n\n    return ret\n"
  },
  {
    "path": "monai/utils/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .component_store import ComponentStore\nfrom .decorators import MethodReplacer, RestartGenerator\nfrom .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default\nfrom .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather\nfrom .enums import (\n    AdversarialIterationEvents,\n    AdversarialKeys,\n    AlgoKeys,\n    Average,\n    BlendMode,\n    BoxModeName,\n    BundleProperty,\n    BundlePropertyConfig,\n    ChannelMatching,\n    ColorOrder,\n    CommonKeys,\n    CompInitMode,\n    DiceCEReduction,\n    DownsampleMode,\n    EngineStatsKeys,\n    FastMRIKeys,\n    ForwardMode,\n    GanKeys,\n    GridPatchSort,\n    GridSampleMode,\n    GridSamplePadMode,\n    HoVerNetBranch,\n    HoVerNetMode,\n    IgniteInfo,\n    InterpolateMode,\n    JITMetadataKeys,\n    LazyAttr,\n    LossReduction,\n    MetaKeys,\n    Method,\n    MetricReduction,\n    NdimageMode,\n    NumpyPadMode,\n    OrderingTransformations,\n    OrderingType,\n    PatchKeys,\n    PostFix,\n    ProbMapKeys,\n    PytorchPadMode,\n    SkipMode,\n    SpaceKeys,\n    SplineMode,\n    StrEnum,\n    TraceKeys,\n    TraceStatusKeys,\n    TransformBackends,\n    UpsampleMode,\n    Weight,\n    WSIPatchKeys,\n)\nfrom .jupyter_utils import StatusMembers, ThreadContainer\nfrom .misc import (\n    
MAX_SEED,\n    ImageMetaKey,\n    MONAIEnvVars,\n    check_kwargs_exist_in_class_init,\n    check_parent_dir,\n    copy_to_device,\n    ensure_tuple,\n    ensure_tuple_rep,\n    ensure_tuple_size,\n    fall_back_tuple,\n    first,\n    flatten_dict,\n    get_seed,\n    has_option,\n    is_immutable,\n    is_module_ver_at_least,\n    is_scalar,\n    is_scalar_tensor,\n    is_sqrt,\n    issequenceiterable,\n    list_to_dict,\n    path_to_uri,\n    pprint_edges,\n    progress_bar,\n    run_cmd,\n    sample_slices,\n    save_obj,\n    set_determinism,\n    star_zip_with,\n    str2bool,\n    str2list,\n    to_tuple_of_dictionaries,\n    unsqueeze_left,\n    unsqueeze_right,\n    zip_with,\n)\nfrom .module import (\n    InvalidPyTorchVersionError,\n    OptionalImportError,\n    allow_missing_reference,\n    compute_capabilities_after,\n    damerau_levenshtein_distance,\n    exact_version,\n    get_full_type_name,\n    get_package_version,\n    get_torch_version_tuple,\n    instantiate,\n    load_submodules,\n    look_up_option,\n    min_version,\n    optional_import,\n    pytorch_after,\n    require_pkg,\n    run_debug,\n    run_eval,\n    version_geq,\n    version_leq,\n)\nfrom .nvtx import Range\nfrom .ordering import Ordering\nfrom .profiling import (\n    PerfContext,\n    ProfileHandler,\n    WorkflowProfiler,\n    select_transform_call,\n    torch_profiler_full,\n    torch_profiler_time_cpu_gpu,\n    torch_profiler_time_end_to_end,\n)\nfrom .state_cacher import StateCacher\nfrom .tf32 import detect_default_tf32, has_ampere_or_later\nfrom .type_conversion import (\n    convert_data_type,\n    convert_to_cupy,\n    convert_to_dst_type,\n    convert_to_list,\n    convert_to_numpy,\n    convert_to_tensor,\n    dtype_numpy_to_torch,\n    dtype_torch_to_numpy,\n    get_dtype,\n    get_dtype_string,\n    get_equivalent_dtype,\n    get_numpy_dtype_from_string,\n    get_torch_dtype_from_string,\n)\n\n# have to explicitly bring these in here to resolve circular import 
issues\n"
  },
  {
    "path": "monai/utils/component_store.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections import namedtuple\nfrom collections.abc import Iterable\nfrom keyword import iskeyword\nfrom textwrap import dedent, indent\nfrom typing import Any, Callable, TypeVar\n\nT = TypeVar(\"T\")\n\n\ndef is_variable(name):\n    \"\"\"Returns True if `name` is a valid Python variable name and also not a keyword.\"\"\"\n    return name.isidentifier() and not iskeyword(name)\n\n\nclass ComponentStore:\n    \"\"\"\n    Represents a storage object for other objects (specifically functions) keyed to a name with a description.\n\n    These objects act as global named places for storing components for objects parameterised by component names.\n    Typically this is functions although other objects can be added. Printing a component store will produce a\n    list of members along with their docstring information if present.\n\n    Example:\n\n    .. 
code-block:: python\n\n        TestStore = ComponentStore(\"Test Store\", \"A test store for demo purposes\")\n\n        @TestStore.add_def(\"my_func_name\", \"Some description of your function\")\n        def _my_func(a, b):\n            '''A description of your function here.'''\n            return a * b\n\n        print(TestStore)  # will print out name, description, and 'my_func_name' with the docstring\n\n        func = TestStore[\"my_func_name\"]\n        result = func(7, 6)\n\n    \"\"\"\n\n    _Component = namedtuple(\"_Component\", (\"description\", \"value\"))  # internal value pair\n\n    def __init__(self, name: str, description: str) -> None:\n        self.components: dict[str, ComponentStore._Component] = {}\n        self.name: str = name\n        self.description: str = description\n\n        self.__doc__ = f\"Component Store '{name}': {description}\\n{self.__doc__ or ''}\".strip()\n\n    def add(self, name: str, desc: str, value: T) -> T:\n        \"\"\"Store the object `value` under the name `name` with description `desc`.\"\"\"\n        if not is_variable(name):\n            raise ValueError(\"Name of component must be valid Python identifier\")\n\n        self.components[name] = self._Component(desc, value)\n        return value\n\n    def add_def(self, name: str, desc: str) -> Callable:\n        \"\"\"Returns a decorator which stores the decorated function under `name` with description `desc`.\"\"\"\n\n        def deco(func):\n            \"\"\"Decorator to add a function to a store.\"\"\"\n            return self.add(name, desc, func)\n\n        return deco\n\n    @property\n    def names(self) -> tuple[str, ...]:\n        \"\"\"\n        Produces all factory names.\n        \"\"\"\n\n        return tuple(self.components)\n\n    def __contains__(self, name: str) -> bool:\n        \"\"\"Returns True if the given name is stored.\"\"\"\n        return name in self.components\n\n    def __len__(self) -> int:\n        \"\"\"Returns the number of 
stored components.\"\"\"\n        return len(self.components)\n\n    def __iter__(self) -> Iterable:\n        \"\"\"Yields name/component pairs.\"\"\"\n        for k, v in self.components.items():\n            yield k, v.value\n\n    def __str__(self):\n        result = f\"Component Store '{self.name}': {self.description}\\nAvailable components:\"\n        for k, v in self.components.items():\n            result += f\"\\n* {k}:\"\n\n            if hasattr(v.value, \"__doc__\") and v.value.__doc__:\n                doc = indent(dedent(v.value.__doc__.lstrip(\"\\n\").rstrip()), \"    \")\n                result += f\"\\n{doc}\\n\"\n            else:\n                result += f\" {v.description}\"\n\n        return result\n\n    def __getattr__(self, name: str) -> Any:\n        \"\"\"Returns the stored object under the given name.\"\"\"\n        if name in self.components:\n            return self.components[name].value\n        else:\n            return self.__getattribute__(name)\n\n    def __getitem__(self, name: str) -> Any:\n        \"\"\"Returns the stored object under the given name.\"\"\"\n        if name in self.components:\n            return self.components[name].value\n        else:\n            raise ValueError(f\"Component '{name}' not found\")\n"
  },
  {
    "path": "monai/utils/decorators.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom functools import wraps\n\n__all__ = [\"RestartGenerator\", \"MethodReplacer\"]\n\nfrom collections.abc import Generator\nfrom typing import Callable\n\n\nclass RestartGenerator:\n    \"\"\"\n    Wraps a generator callable which will be called whenever this class is iterated and its result returned. This is\n    used to create an iterator which can start iteration over the given generator multiple times.\n    \"\"\"\n\n    def __init__(self, create_gen: Callable[[], Generator]) -> None:\n        self.create_gen = create_gen\n\n    def __iter__(self) -> Generator:\n        return self.create_gen()\n\n\nclass MethodReplacer:\n    \"\"\"\n    Base class for method decorators which can be used to replace methods pass to replace_method() with wrapped versions.\n    \"\"\"\n\n    replace_list_name = \"__replacemethods__\"\n\n    def __init__(self, meth: Callable) -> None:\n        self.meth = meth\n\n    def replace_method(self, meth):\n        \"\"\"\n        Return a new method to replace `meth` in the instantiated object, or `meth` to do nothing.\n        \"\"\"\n        return meth\n\n    def __set_name__(self, owner, name):\n        \"\"\"\n        Add the (name,self.replace_method) pair to the list named by replace_list_name in `owner`, creating the list and\n        replacing the constructor of `owner` if necessary. 
The replaced constructor will call the old one then do the\n        replacing operation of substituting, for each (name,self.replace_method) pair, the named method with the returned\n        value from self.replace_method.\n        \"\"\"\n        entry = (name, owner, self.replace_method)\n\n        if not hasattr(owner, self.replace_list_name):\n            oldinit = owner.__init__\n\n            # replace the constructor with a new one which calls the old then replaces methods\n            @wraps(oldinit)\n            def newinit(_self, *args, **kwargs):\n                oldinit(_self, *args, **kwargs)\n\n                # replace each listed method of this newly constructed object\n                for m, owner, replacer in getattr(_self, self.replace_list_name):\n                    if isinstance(_self, owner):\n                        meth = getattr(_self, m)\n                        newmeth = replacer(meth)\n                        setattr(_self, m, newmeth)\n\n            owner.__init__ = newinit\n            setattr(owner, self.replace_list_name, [entry])\n        else:\n            namelist = getattr(owner, self.replace_list_name)\n\n            if not any(nl[0] == name for nl in namelist):\n                namelist.append(entry)\n\n        setattr(owner, name, self.meth)\n"
  },
  {
    "path": "monai/utils/deprecate_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport inspect\nimport sys\nimport warnings\nfrom collections.abc import Callable\nfrom functools import wraps\nfrom types import FunctionType\nfrom typing import Any, TypeVar\n\nfrom monai import __version__\nfrom monai.utils.module import version_leq\n\n__all__ = [\"deprecated\", \"deprecated_arg\", \"DeprecatedError\", \"deprecated_arg_default\"]\nT = TypeVar(\"T\", type, Callable)\n\n\nclass DeprecatedError(Exception):\n    pass\n\n\ndef warn_deprecated(obj, msg, warning_category=FutureWarning):\n    \"\"\"\n    Issue the warning message `msg`.\n    \"\"\"\n    warnings.warn(f\"{obj}: {msg}\", category=warning_category, stacklevel=2)\n\n\ndef deprecated(\n    since: str | None = None,\n    removed: str | None = None,\n    msg_suffix: str = \"\",\n    version_val: str = __version__,\n    warning_category: type[FutureWarning] = FutureWarning,\n) -> Callable[[T], T]:\n    \"\"\"\n    Marks a function or class as deprecated. If `since` is given this should be a version at or earlier than the\n    current version and states at what version of the definition was marked as deprecated. 
If `removed` is given\n    this can be any version and marks when the definition was removed.\n\n    When the decorated definition is called, that is when the function is called or the class instantiated,\n    a `warning_category` is issued if `since` is given and the current version is at or later than that given.\n    a `DeprecatedError` exception is instead raised if `removed` is given and the current version is at or later\n    than that, or if neither `since` nor `removed` is provided.\n\n    The relevant docstring of the deprecating function should also be updated accordingly,\n    using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`.\n    https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded\n\n    Args:\n        since: version at which the definition was marked deprecated but not removed.\n        removed: version at which the definition was/will be removed and no longer usable.\n        msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead.\n        version_val: (used for testing) version to compare since and removed against, default is MONAI version.\n        warning_category: a warning category class, defaults to `FutureWarning`.\n\n    Returns:\n        Decorated definition which warns or raises exception when used\n    \"\"\"\n\n    if since is not None and removed is not None and not version_leq(since, removed):\n        raise ValueError(f\"since must be less or equal to removed, got since={since}, removed={removed}.\")\n    is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)\n    if is_not_yet_deprecated:\n        # smaller than `since`, do nothing\n        return lambda obj: obj\n\n    if since is None and removed is None:\n        # raise a DeprecatedError directly\n        is_removed = True\n        is_deprecated = True\n    else:\n        # compare the 
numbers\n        is_deprecated = since is not None and version_leq(since, version_val)\n        is_removed = removed is not None and version_leq(removed, version_val)\n\n    def _decorator(obj):\n        is_func = isinstance(obj, FunctionType)\n        call_obj = obj if is_func else obj.__init__\n\n        msg_prefix = f\"{'Function' if is_func else 'Class'} `{obj.__qualname__}`\"\n\n        if is_removed:\n            msg_infix = f\"was removed in version {removed}.\"\n        elif is_deprecated:\n            msg_infix = f\"has been deprecated since version {since}.\"\n            if removed is not None:\n                msg_infix += f\" It will be removed in version {removed}.\"\n        else:\n            msg_infix = \"has been deprecated.\"\n\n        msg = f\"{msg_prefix} {msg_infix} {msg_suffix}\".strip()\n\n        @wraps(call_obj)\n        def _wrapper(*args, **kwargs):\n            if is_removed:\n                raise DeprecatedError(msg)\n            if is_deprecated:\n                warn_deprecated(obj, msg, warning_category)\n\n            return call_obj(*args, **kwargs)\n\n        if is_func:\n            return _wrapper\n        obj.__init__ = _wrapper\n        return obj\n\n    return _decorator\n\n\ndef deprecated_arg(\n    name: str,\n    since: str | None = None,\n    removed: str | None = None,\n    msg_suffix: str = \"\",\n    version_val: str = __version__,\n    new_name: str | None = None,\n    warning_category: type[FutureWarning] = FutureWarning,\n) -> Callable[[T], T]:\n    \"\"\"\n    Marks a particular named argument of a callable as deprecated. 
The same conditions for `since` and `removed` as\n    described in the `deprecated` decorator.\n\n    When the decorated definition is called, that is when the function is called or the class instantiated with args,\n    a `warning_category` is issued if `since` is given and the current version is at or later than that given.\n    a `DeprecatedError` exception is instead raised if `removed` is given and the current version is at or later\n    than that, or if neither `since` nor `removed` is provided.\n\n    The relevant docstring of the deprecating function should also be updated accordingly,\n    using the Sphinx directives such as `.. versionchanged:: version` and `.. deprecated:: version`.\n    https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded\n\n\n    Args:\n        name: name of position or keyword argument to mark as deprecated.\n        since: version at which the argument was marked deprecated but not removed.\n        removed: version at which the argument was/will be removed and no longer usable.\n        msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead.\n        version_val: (used for testing) version to compare since and removed against, default is MONAI version.\n        new_name: name of position or keyword argument to replace the deprecated argument.\n            if it is specified and the signature of the decorated function has a `kwargs`, the value to the\n            deprecated argument `name` will be removed.\n        warning_category: a warning category class, defaults to `FutureWarning`.\n\n    Returns:\n        Decorated callable which warns or raises exception when deprecated argument used.\n    \"\"\"\n\n    if version_val.startswith(\"0+\") or not f\"{version_val}\".strip()[0].isdigit():\n        # version unknown, set version_val to a large value (assuming the latest version)\n        version_val = f\"{sys.maxsize}\"\n    if since 
is not None and removed is not None and not version_leq(since, removed):\n        raise ValueError(f\"since must be less or equal to removed, got since={since}, removed={removed}.\")\n    is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)\n    if is_not_yet_deprecated:\n        # smaller than `since`, do nothing\n        return lambda obj: obj\n    if since is None and removed is None:\n        # raise a DeprecatedError directly\n        is_removed = True\n        is_deprecated = True\n    else:\n        # compare the numbers\n        is_deprecated = since is not None and version_leq(since, version_val)\n        is_removed = removed is not None and version_val != f\"{sys.maxsize}\" and version_leq(removed, version_val)\n\n    def _decorator(func):\n        argname = f\"{func.__module__} {func.__qualname__}:{name}\"\n\n        msg_prefix = f\"Argument `{name}`\"\n\n        if is_removed:\n            msg_infix = f\"was removed in version {removed}.\"\n        elif is_deprecated:\n            msg_infix = f\"has been deprecated since version {since}.\"\n            if removed is not None:\n                msg_infix += f\" It will be removed in version {removed}.\"\n        else:\n            msg_infix = \"has been deprecated.\"\n\n        msg = f\"{msg_prefix} {msg_infix} {msg_suffix}\".strip()\n\n        sig = inspect.signature(func)\n\n        @wraps(func)\n        def _wrapper(*args, **kwargs):\n            if new_name is not None and name in kwargs and new_name not in kwargs:\n                # replace the deprecated arg \"name\" with \"new_name\"\n                # if name is specified and new_name is not specified\n                kwargs[new_name] = kwargs[name]\n                try:\n                    _ = sig.bind(*args, **kwargs).arguments\n                except TypeError:\n                    # multiple values for new_name using both args and kwargs\n                    kwargs.pop(new_name, None)\n        
    binding = sig.bind(*args, **kwargs).arguments\n            positional_found = name in binding\n            kw_found = False\n            for k, param in sig.parameters.items():\n                if param.kind == inspect.Parameter.VAR_KEYWORD and k in binding and name in binding[k]:\n                    kw_found = True\n                    # if the deprecated arg is found in the **kwargs, it should be removed\n                    kwargs.pop(name, None)\n\n            if positional_found or kw_found:\n                if is_removed:\n                    raise DeprecatedError(msg)\n                if is_deprecated:\n                    warn_deprecated(argname, msg, warning_category)\n\n            return func(*args, **kwargs)\n\n        return _wrapper\n\n    return _decorator\n\n\ndef deprecated_arg_default(\n    name: str,\n    old_default: Any,\n    new_default: Any,\n    since: str | None = None,\n    replaced: str | None = None,\n    msg_suffix: str = \"\",\n    version_val: str = __version__,\n    warning_category: type[FutureWarning] = FutureWarning,\n) -> Callable[[T], T]:\n    \"\"\"\n    Marks a particular arguments default of a callable as deprecated. It is changed from `old_default` to `new_default`\n    in version `changed`.\n\n    When the decorated definition is called, a `warning_category` is issued if `since` is given,\n    the default is not explicitly set by the caller and the current version is at or later than that given.\n    Another warning with the same category is issued if `changed` is given and the current version is at or later.\n\n    The relevant docstring of the deprecating function should also be updated accordingly,\n    using the Sphinx directives such as `.. versionchanged:: version` and `.. 
deprecated:: version`.\n    https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-versionadded\n\n\n    Args:\n        name: name of position or keyword argument where the default is deprecated/changed.\n        old_default: name of the old default. This is only for the warning message, it will not be validated.\n        new_default: name of the new default.\n            It is validated that this value is not present as the default before version `replaced`.\n            This means, that you can also use this if the actual default value is `None` and set later in the function.\n            You can also set this to any string representation, e.g. `\"calculate_default_value()\"`\n            if the default is calculated from another function.\n        since: version at which the argument default was marked deprecated but not replaced.\n        replaced: version at which the argument default was/will be replaced.\n        msg_suffix: message appended to warning/exception detailing reasons for deprecation.\n        version_val: (used for testing) version to compare since and removed against, default is MONAI version.\n        warning_category: a warning category class, defaults to `FutureWarning`.\n\n    Returns:\n        Decorated callable which warns when deprecated default argument is not explicitly specified.\n    \"\"\"\n\n    if version_val.startswith(\"0+\") or not f\"{version_val}\".strip()[0].isdigit():\n        # version unknown, set version_val to a large value (assuming the latest version)\n        version_val = f\"{sys.maxsize}\"\n    if since is not None and replaced is not None and not version_leq(since, replaced):\n        raise ValueError(f\"since must be less or equal to replaced, got since={since}, replaced={replaced}.\")\n    is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)\n    if is_not_yet_deprecated:\n        # smaller than `since`, do nothing\n        return 
lambda obj: obj\n    if since is None and replaced is None:\n        # raise a DeprecatedError directly\n        is_replaced = True\n        is_deprecated = True\n    else:\n        # compare the numbers\n        is_deprecated = since is not None and version_leq(since, version_val)\n        is_replaced = replaced is not None and version_val != f\"{sys.maxsize}\" and version_leq(replaced, version_val)\n\n    def _decorator(func):\n        argname = f\"{func.__module__} {func.__qualname__}:{name}\"\n\n        msg_prefix = f\" Current default value of argument `{name}={old_default}`\"\n\n        if is_replaced:\n            msg_infix = f\"was changed in version {replaced} from `{name}={old_default}` to `{name}={new_default}`.\"\n        elif is_deprecated:\n            msg_infix = f\"has been deprecated since version {since}.\"\n            if replaced is not None:\n                msg_infix += f\" It will be changed to `{name}={new_default}` in version {replaced}.\"\n        else:\n            msg_infix = f\"has been deprecated from `{name}={old_default}` to `{name}={new_default}`.\"\n\n        msg = f\"{msg_prefix} {msg_infix} {msg_suffix}\".strip()\n\n        sig = inspect.signature(func)\n        if name not in sig.parameters:\n            raise ValueError(f\"Argument `{name}` not found in signature of {func.__qualname__}.\")\n        param = sig.parameters[name]\n        if param.default is inspect.Parameter.empty:\n            raise ValueError(f\"Argument `{name}` has no default value.\")\n\n        if param.default == new_default and not is_replaced:\n            raise ValueError(\n                f\"Argument `{name}` was replaced to the new default value `{new_default}` before the specified version {replaced}.\"\n            )\n\n        @wraps(func)\n        def _wrapper(*args, **kwargs):\n            if name not in sig.bind(*args, **kwargs).arguments and is_deprecated:\n                # arg was not found so the default value is used\n                
warn_deprecated(argname, msg, warning_category)\n\n            return func(*args, **kwargs)\n\n        return _wrapper\n\n    return _decorator\n"
  },
  {
    "path": "monai/utils/dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable\nfrom logging import Filter\nfrom typing import Literal, overload\n\nimport torch\nimport torch.distributed as dist\n\nfrom monai.utils.enums import IgniteInfo\nfrom monai.utils.module import min_version, optional_import\n\nidist, has_ignite = optional_import(\"ignite\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"distributed\")\n\n__all__ = [\"get_dist_device\", \"evenly_divisible_all_gather\", \"string_list_all_gather\", \"RankFilter\"]\n\n\ndef get_dist_device():\n    \"\"\"\n    Get the expected target device in the native PyTorch distributed data parallel.\n    For NCCL backend, return GPU device of current process.\n    For GLOO backend, return CPU.\n    For any other backends, return None as the default, tensor.to(None) will not change the device.\n\n    \"\"\"\n    if dist.is_initialized():\n        backend = dist.get_backend()\n        if backend == \"nccl\" and torch.cuda.is_available():\n            return torch.device(f\"cuda:{torch.cuda.current_device()}\")\n        if backend == \"gloo\":\n            return torch.device(\"cpu\")\n    return None\n\n\n@overload\ndef evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor: ...\n\n\n@overload\ndef evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> 
list[torch.Tensor]: ...\n\n\n@overload\ndef evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]: ...\n\n\ndef evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True) -> torch.Tensor | list[torch.Tensor]:\n    \"\"\"\n    Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather.\n    The input data of every rank should have the same number of dimensions, only the first dim can be different.\n\n    Note: If has ignite installed, will execute based on ignite distributed APIs, otherwise, if the native\n    PyTorch distributed group initialized, will execute based on native PyTorch distributed APIs.\n\n    Args:\n        data: source tensor to pad and execute all_gather in distributed data parallel.\n        concat: whether to concat the gathered list to be a Tensor, if False, return a list\n            of Tensors, similar behavior as torch.distributed.all_gather(). default to True.\n\n    Note:\n        The input data on different ranks must have exactly same `dtype`.\n\n    \"\"\"\n    if not isinstance(data, torch.Tensor):\n        raise ValueError(\"input data must be PyTorch Tensor.\")\n    # data of all the ranks must have same number of dimensions\n    ndims = data.ndimension()\n    length: int = data.shape[0] if ndims > 0 else 1\n\n    def _torch_all_gather(data: torch.Tensor) -> list[torch.Tensor]:\n        \"\"\"\n        Implementation based on native PyTorch distributed data parallel APIs.\n\n        \"\"\"\n        device = get_dist_device()\n        orig_device = data.device\n        data = data.to(device)\n        data = data.unsqueeze(0) if ndims == 0 else data\n\n        # make sure the data is evenly-divisible on multi-GPUs\n        length_tensor = torch.as_tensor([length], device=device)\n        all_lens = [torch.zeros_like(length_tensor) for _ in range(dist.get_world_size())]\n        dist.all_gather(all_lens, length_tensor)\n        
all_lens_: list[int] = [int(i.item()) for i in all_lens]\n\n        max_len: int = max(all_lens_)\n        if length < max_len:\n            size = [max_len - length] + list(data.shape[1:])\n            data = torch.cat([data, data.new_full(size, 0)], dim=0)\n        # all gather across all processes\n        output = [torch.zeros_like(data) for _ in range(dist.get_world_size())]\n        dist.all_gather(output, data)\n        # remove the padding items, if all the input data doesn't have batch dim, squeeze the first dim\n        return [(o.squeeze(0) if ndims == 0 else o[:l, ...]).to(orig_device) for o, l in zip(output, all_lens_)]\n\n    def _ignite_all_gather(data: torch.Tensor) -> list[torch.Tensor]:\n        \"\"\"\n        Implementation based on PyTorch ignite package, it can support more kinds of backends.\n\n        \"\"\"\n        data = data.unsqueeze(0) if ndims == 0 else data\n        # make sure the data is evenly-divisible on multi-GPUs\n        all_lens: list[int] = idist.all_gather(length)\n        max_len: int = max(all_lens)\n        if length < max_len:\n            size = [max_len - length] + list(data.shape[1:])\n            data = torch.cat([data, data.new_full(size, 0)], dim=0)\n        # all gather across all processes\n        output = idist.all_gather(data)\n        # delete the padding NaN items\n        if ndims == 0:\n            # if all the input data doesn't have batch dim, unbind to a list of 0-dim Tensors\n            return list(torch.unbind(output, dim=0))\n        return [output[i * max_len : i * max_len + l, ...] 
for i, l in enumerate(all_lens)]\n\n    output: list[torch.Tensor]\n    if has_ignite:\n        if idist.get_world_size() <= 1:\n            return data\n        output = _ignite_all_gather(data=data)\n    elif dist.is_available() and dist.is_initialized():\n        if dist.get_world_size() <= 1:\n            return data\n        output = _torch_all_gather(data=data)\n    else:\n        return data\n\n    return torch.cat(output, dim=0) if concat else output\n\n\ndef string_list_all_gather(strings: list[str], delimiter: str = \"\\t\") -> list[str]:\n    \"\"\"\n    Utility function for distributed data parallel to all gather a list of strings.\n    Refer to the idea of ignite `all_gather(string)`:\n    https://pytorch.org/ignite/v0.4.5/distributed.html#ignite.distributed.utils.all_gather.\n\n    Note: If has ignite installed, will execute based on ignite distributed APIs, otherwise, if the native\n    PyTorch distributed group initialized, will execute based on native PyTorch distributed APIs.\n\n    Args:\n        strings: a list of strings to all gather.\n        delimiter: use the delimiter to join the string list to be a long string,\n            then all gather across ranks and split to a list. 
default to \"\\t\".\n\n    \"\"\"\n    world_size: int = 1\n    if has_ignite:\n        world_size = idist.get_world_size()\n    elif dist.is_available() and dist.is_initialized():\n        world_size = dist.get_world_size()\n\n    if world_size <= 1:\n        return strings\n\n    joined = delimiter.join(strings)\n    gathered = evenly_divisible_all_gather(torch.tensor(bytearray(joined, \"utf-8\"), dtype=torch.long), concat=False)\n    _gathered = [bytearray(g.tolist()).decode(\"utf-8\").split(delimiter) for g in gathered]\n\n    return [i for k in _gathered for i in k]\n\n\nclass RankFilter(Filter):\n    \"\"\"\n    The RankFilter class is a convenient filter that extends the Filter class in the Python logging module.\n    The purpose is to control which log records are processed based on the rank in a distributed environment.\n\n    Args:\n        rank: the rank of the process in the torch.distributed. Default is None and then it will use dist.get_rank().\n        filter_fn: an optional lambda function used as the filtering criteria.\n            The default function logs only if the rank of the process is 0,\n            but the user can define their own function to implement custom filtering logic.\n    \"\"\"\n\n    def __init__(self, rank: int | None = None, filter_fn: Callable = lambda rank: rank == 0):\n        super().__init__()\n        self.filter_fn: Callable = filter_fn\n        if dist.is_available() and dist.is_initialized():\n            self.rank: int = rank if rank is not None else dist.get_rank()\n        else:\n            if torch.cuda.is_available() and torch.cuda.device_count() > 1:\n                warnings.warn(\n                    \"The torch.distributed is either unavailable and uninitiated when RankFilter is instantiated.\\n\"\n                    \"If torch.distributed is used, please ensure that the RankFilter() is called\\n\"\n                    \"after torch.distributed.init_process_group() in the script.\\n\"\n                )\n 
           self.rank = 0\n\n    def filter(self, *_args):\n        return self.filter_fn(self.rank)\n"
  },
  {
    "path": "monai/utils/enums.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport random\nfrom enum import Enum\nfrom typing import TYPE_CHECKING\n\nfrom monai.utils.module import min_version, optional_import\n\n__all__ = [\n    \"StrEnum\",\n    \"NumpyPadMode\",\n    \"GridSampleMode\",\n    \"SplineMode\",\n    \"InterpolateMode\",\n    \"UpsampleMode\",\n    \"DownsampleMode\",\n    \"BlendMode\",\n    \"PytorchPadMode\",\n    \"NdimageMode\",\n    \"GridSamplePadMode\",\n    \"Average\",\n    \"MetricReduction\",\n    \"LossReduction\",\n    \"DiceCEReduction\",\n    \"Weight\",\n    \"ChannelMatching\",\n    \"SkipMode\",\n    \"Method\",\n    \"TraceKeys\",\n    \"TraceStatusKeys\",\n    \"CommonKeys\",\n    \"GanKeys\",\n    \"PostFix\",\n    \"ForwardMode\",\n    \"TransformBackends\",\n    \"CompInitMode\",\n    \"BoxModeName\",\n    \"GridPatchSort\",\n    \"FastMRIKeys\",\n    \"SpaceKeys\",\n    \"MetaKeys\",\n    \"ColorOrder\",\n    \"EngineStatsKeys\",\n    \"DataStatsKeys\",\n    \"ImageStatsKeys\",\n    \"LabelStatsKeys\",\n    \"HoVerNetMode\",\n    \"HoVerNetBranch\",\n    \"LazyAttr\",\n    \"BundleProperty\",\n    \"BundlePropertyConfig\",\n    \"AlgoKeys\",\n    \"IgniteInfo\",\n]\n\n\nclass StrEnum(str, Enum):\n    \"\"\"\n    Enum subclass that converts its value to a string.\n\n    .. 
code-block:: python\n\n        from monai.utils import StrEnum\n\n        class Example(StrEnum):\n            MODE_A = \"A\"\n            MODE_B = \"B\"\n\n        assert (list(Example) == [\"A\", \"B\"])\n        assert Example.MODE_A == \"A\"\n        assert str(Example.MODE_A) == \"A\"\n        assert monai.utils.look_up_option(\"A\", Example) == \"A\"\n    \"\"\"\n\n    def __str__(self):\n        return self.value\n\n    def __repr__(self):\n        return self.value\n\n\nclass NumpyPadMode(StrEnum):\n    \"\"\"\n    See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html\n    \"\"\"\n\n    CONSTANT = \"constant\"\n    EDGE = \"edge\"\n    LINEAR_RAMP = \"linear_ramp\"\n    MAXIMUM = \"maximum\"\n    MEAN = \"mean\"\n    MEDIAN = \"median\"\n    MINIMUM = \"minimum\"\n    REFLECT = \"reflect\"\n    SYMMETRIC = \"symmetric\"\n    WRAP = \"wrap\"\n    EMPTY = \"empty\"\n\n\nclass NdimageMode(StrEnum):\n    \"\"\"\n    The available options determine how the input array is extended beyond its boundaries when interpolating.\n    See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n    \"\"\"\n\n    REFLECT = \"reflect\"\n    GRID_MIRROR = \"grid-mirror\"\n    CONSTANT = \"constant\"\n    GRID_CONSTANT = \"grid-constant\"\n    NEAREST = \"nearest\"\n    MIRROR = \"mirror\"\n    GRID_WRAP = \"grid-wrap\"\n    WRAP = \"wrap\"\n\n\nclass GridSampleMode(StrEnum):\n    \"\"\"\n    See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n\n    interpolation mode of `torch.nn.functional.grid_sample`\n\n    Note:\n        (documentation from `torch.nn.functional.grid_sample`)\n        `mode='bicubic'` supports only 4-D input.\n        When `mode='bilinear'` and the input is 5-D, the interpolation mode used internally will actually be trilinear.\n        However, when the input is 4-D, the interpolation mode will legitimately be bilinear.\n    \"\"\"\n\n    NEAREST = 
\"nearest\"\n    BILINEAR = \"bilinear\"\n    BICUBIC = \"bicubic\"\n\n\nclass SplineMode(StrEnum):\n    \"\"\"\n    Order of spline interpolation.\n\n    See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html\n    \"\"\"\n\n    ZERO = 0\n    ONE = 1\n    TWO = 2\n    THREE = 3\n    FOUR = 4\n    FIVE = 5\n\n\nclass InterpolateMode(StrEnum):\n    \"\"\"\n    See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html\n    \"\"\"\n\n    NEAREST = \"nearest\"\n    NEAREST_EXACT = \"nearest-exact\"\n    LINEAR = \"linear\"\n    BILINEAR = \"bilinear\"\n    BICUBIC = \"bicubic\"\n    TRILINEAR = \"trilinear\"\n    AREA = \"area\"\n\n\nclass UpsampleMode(StrEnum):\n    \"\"\"\n    See also: :py:class:`monai.networks.blocks.UpSample`\n    \"\"\"\n\n    DECONV = \"deconv\"\n    DECONVGROUP = \"deconvgroup\"\n    NONTRAINABLE = \"nontrainable\"  # e.g. using torch.nn.Upsample\n    PIXELSHUFFLE = \"pixelshuffle\"\n\n\nclass DownsampleMode(StrEnum):\n    \"\"\"\n    See also: :py:class:`monai.networks.blocks.UpSample`\n    \"\"\"\n\n    CONV = \"conv\"  # e.g. using strided convolution\n    CONVGROUP = \"convgroup\"  # e.g. 
using grouped strided convolution\n    PIXELUNSHUFFLE = \"pixelunshuffle\"\n    MAXPOOL = \"maxpool\"\n    AVGPOOL = \"avgpool\"\n\n\nclass BlendMode(StrEnum):\n    \"\"\"\n    See also: :py:class:`monai.data.utils.compute_importance_map`\n    \"\"\"\n\n    CONSTANT = \"constant\"\n    GAUSSIAN = \"gaussian\"\n\n\nclass PytorchPadMode(StrEnum):\n    \"\"\"\n    See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n    \"\"\"\n\n    CONSTANT = \"constant\"\n    REFLECT = \"reflect\"\n    REPLICATE = \"replicate\"\n    CIRCULAR = \"circular\"\n\n\nclass GridSamplePadMode(StrEnum):\n    \"\"\"\n    See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html\n    \"\"\"\n\n    ZEROS = \"zeros\"\n    BORDER = \"border\"\n    REFLECTION = \"reflection\"\n\n\nclass Average(StrEnum):\n    \"\"\"\n    See also: :py:class:`monai.metrics.rocauc.compute_roc_auc` or\n    :py:class:`monai.metrics.average_precision.compute_average_precision`\n    \"\"\"\n\n    MACRO = \"macro\"\n    WEIGHTED = \"weighted\"\n    MICRO = \"micro\"\n    NONE = \"none\"\n\n\nclass MetricReduction(StrEnum):\n    \"\"\"\n    See also: :py:func:`monai.metrics.utils.do_metric_reduction`\n    \"\"\"\n\n    NONE = \"none\"\n    MEAN = \"mean\"\n    SUM = \"sum\"\n    MEAN_BATCH = \"mean_batch\"\n    SUM_BATCH = \"sum_batch\"\n    MEAN_CHANNEL = \"mean_channel\"\n    SUM_CHANNEL = \"sum_channel\"\n\n\nclass LossReduction(StrEnum):\n    \"\"\"\n    See also:\n        - :py:class:`monai.losses.dice.DiceLoss`\n        - :py:class:`monai.losses.dice.GeneralizedDiceLoss`\n        - :py:class:`monai.losses.focal_loss.FocalLoss`\n        - :py:class:`monai.losses.tversky.TverskyLoss`\n    \"\"\"\n\n    NONE = \"none\"\n    MEAN = \"mean\"\n    SUM = \"sum\"\n\n\nclass DiceCEReduction(StrEnum):\n    \"\"\"\n    See also:\n        - :py:class:`monai.losses.dice.DiceCELoss`\n    \"\"\"\n\n    MEAN = \"mean\"\n    SUM = \"sum\"\n\n\nclass Weight(StrEnum):\n    
\"\"\"\n    See also: :py:class:`monai.losses.dice.GeneralizedDiceLoss`\n    \"\"\"\n\n    SQUARE = \"square\"\n    SIMPLE = \"simple\"\n    UNIFORM = \"uniform\"\n\n\nclass ChannelMatching(StrEnum):\n    \"\"\"\n    See also: :py:class:`monai.networks.nets.HighResBlock`\n    \"\"\"\n\n    PAD = \"pad\"\n    PROJECT = \"project\"\n\n\nclass SkipMode(StrEnum):\n    \"\"\"\n    See also: :py:class:`monai.networks.layers.SkipConnection`\n    \"\"\"\n\n    CAT = \"cat\"\n    ADD = \"add\"\n    MUL = \"mul\"\n\n\nclass Method(StrEnum):\n    \"\"\"\n    See also: :py:class:`monai.transforms.croppad.array.SpatialPad`\n    \"\"\"\n\n    SYMMETRIC = \"symmetric\"\n    END = \"end\"\n\n\nclass ForwardMode(StrEnum):\n    \"\"\"\n    See also: :py:class:`monai.transforms.engines.evaluator.Evaluator`\n    \"\"\"\n\n    TRAIN = \"train\"\n    EVAL = \"eval\"\n\n\nclass TraceKeys(StrEnum):\n    \"\"\"Extra metadata keys used for traceable transforms.\"\"\"\n\n    CLASS_NAME: str = \"class\"\n    ID: str = \"id\"\n    ORIG_SIZE: str = \"orig_size\"\n    EXTRA_INFO: str = \"extra_info\"\n    DO_TRANSFORM: str = \"do_transforms\"\n    KEY_SUFFIX: str = \"_transforms\"\n    NONE: str = \"none\"\n    TRACING: str = \"tracing\"\n    STATUSES: str = \"statuses\"\n    LAZY: str = \"lazy\"\n\n\nclass TraceStatusKeys(StrEnum):\n    \"\"\"Enumerable status keys for the TraceKeys.STATUS flag\"\"\"\n\n    PENDING_DURING_APPLY = \"pending_during_apply\"\n\n\nclass CommonKeys(StrEnum):\n    \"\"\"\n    A set of common keys for dictionary based supervised training process.\n    `IMAGE` is the input image data.\n    `LABEL` is the training or evaluation label of segmentation or classification task.\n    `PRED` is the prediction data of model output.\n    `LOSS` is the loss value of current iteration.\n    `METADATA` is some useful information during training or evaluation, like loss value, etc.\n\n    \"\"\"\n\n    IMAGE = \"image\"\n    LABEL = \"label\"\n    PRED = \"pred\"\n    LOSS = 
\"loss\"\n    METADATA = \"metadata\"\n\n\nclass GanKeys(StrEnum):\n    \"\"\"\n    A set of common keys for generative adversarial networks.\n\n    \"\"\"\n\n    REALS = \"reals\"\n    FAKES = \"fakes\"\n    LATENTS = \"latents\"\n    GLOSS = \"g_loss\"\n    DLOSS = \"d_loss\"\n\n\nclass PostFix(StrEnum):\n    \"\"\"Post-fixes.\"\"\"\n\n    @staticmethod\n    def _get_str(prefix: str | None, suffix: str) -> str:\n        return suffix if prefix is None else f\"{prefix}_{suffix}\"\n\n    @staticmethod\n    def meta(key: str | None = None) -> str:\n        return PostFix._get_str(key, \"meta_dict\")\n\n    @staticmethod\n    def orig_meta(key: str | None = None) -> str:\n        return PostFix._get_str(key, \"orig_meta_dict\")\n\n    @staticmethod\n    def transforms(key: str | None = None) -> str:\n        return PostFix._get_str(key, TraceKeys.KEY_SUFFIX[1:])\n\n\nclass TransformBackends(StrEnum):\n    \"\"\"\n    Transform backends. Most of `monai.transforms` components first converts the input data into ``torch.Tensor`` or\n    ``monai.data.MetaTensor``. Internally, some transforms are made by converting the data into ``numpy.array`` or\n    ``cupy.array`` and use the underlying transform backend API to achieve the actual output array and\n    converting back to ``Tensor``/``MetaTensor``. Transforms with more than one backend indicate the that they may\n    convert the input data types to accommodate the underlying API.\n    \"\"\"\n\n    TORCH = \"torch\"\n    NUMPY = \"numpy\"\n    CUPY = \"cupy\"\n\n\nclass CompInitMode(StrEnum):\n    \"\"\"\n    Mode names for instantiating a class or calling a callable.\n\n    See also: :py:func:`monai.utils.module.instantiate`\n    \"\"\"\n\n    DEFAULT = \"default\"\n    CALLABLE = \"callable\"\n    DEBUG = \"debug\"\n\n\nclass JITMetadataKeys(StrEnum):\n    \"\"\"\n    Keys stored in the metadata file for saved Torchscript models. 
Some of these are generated by the routines\n    and others are optionally provided by users.\n    \"\"\"\n\n    NAME = \"name\"\n    TIMESTAMP = \"timestamp\"\n    VERSION = \"version\"\n    DESCRIPTION = \"description\"\n\n\nclass BoxModeName(StrEnum):\n    \"\"\"\n    Box mode names.\n    \"\"\"\n\n    XYXY = \"xyxy\"  # [xmin, ymin, xmax, ymax]\n    XYZXYZ = \"xyzxyz\"  # [xmin, ymin, zmin, xmax, ymax, zmax]\n    XXYY = \"xxyy\"  # [xmin, xmax, ymin, ymax]\n    XXYYZZ = \"xxyyzz\"  # [xmin, xmax, ymin, ymax, zmin, zmax]\n    XYXYZZ = \"xyxyzz\"  # [xmin, ymin, xmax, ymax, zmin, zmax]\n    XYWH = \"xywh\"  # [xmin, ymin, xsize, ysize]\n    XYZWHD = \"xyzwhd\"  # [xmin, ymin, zmin, xsize, ysize, zsize]\n    CCWH = \"ccwh\"  # [xcenter, ycenter, xsize, ysize]\n    CCCWHD = \"cccwhd\"  # [xcenter, ycenter, zcenter, xsize, ysize, zsize]\n\n\nclass ProbMapKeys(StrEnum):\n    \"\"\"\n    The keys to be used for generating the probability maps from patches\n    \"\"\"\n\n    LOCATION = \"mask_location\"\n    SIZE = \"mask_size\"\n    COUNT = \"num_patches\"\n    NAME = \"name\"\n\n\nclass GridPatchSort(StrEnum):\n    \"\"\"\n    The sorting method for the generated patches in `GridPatch`\n    \"\"\"\n\n    RANDOM = \"random\"\n    MIN = \"min\"\n    MAX = \"max\"\n\n    @staticmethod\n    def min_fn(x):\n        return x[0].sum()\n\n    @staticmethod\n    def max_fn(x):\n        return -x[0].sum()\n\n    @staticmethod\n    def get_sort_fn(sort_fn):\n        if sort_fn == GridPatchSort.RANDOM:\n            return random.random\n        elif sort_fn == GridPatchSort.MIN:\n            return GridPatchSort.min_fn\n        elif sort_fn == GridPatchSort.MAX:\n            return GridPatchSort.max_fn\n        else:\n            raise ValueError(\n                f'sort_fn should be one of the following values, \"{sort_fn}\" was given:',\n                [e.value for e in GridPatchSort],\n            )\n\n\nclass PatchKeys(StrEnum):\n    \"\"\"\n    The keys to be used for 
metadata of patches extracted from any kind of image\n    \"\"\"\n\n    LOCATION = \"location\"\n    SIZE = \"size\"\n    COUNT = \"count\"\n\n\nclass WSIPatchKeys(StrEnum):\n    \"\"\"\n    The keys to be used for metadata of patches extracted from whole slide images\n    \"\"\"\n\n    LOCATION = PatchKeys.LOCATION\n    SIZE = PatchKeys.SIZE\n    COUNT = PatchKeys.COUNT\n    LEVEL = \"level\"\n    PATH = \"path\"\n\n\nclass FastMRIKeys(StrEnum):\n    \"\"\"\n    The keys to be used for extracting data from the fastMRI dataset\n    \"\"\"\n\n    KSPACE = \"kspace\"\n    MASK = \"mask\"\n    FILENAME = \"filename\"\n    RECON = \"reconstruction_rss\"\n    ACQUISITION = \"acquisition\"\n    MAX = \"max\"\n    NORM = \"norm\"\n    PID = \"patient_id\"\n\n\nclass SpaceKeys(StrEnum):\n    \"\"\"\n    The coordinate system keys, for example, Nifti1 uses Right-Anterior-Superior or \"RAS\",\n    DICOM (0020,0032) uses Left-Posterior-Superior or \"LPS\". This type does not distinguish spatial 1/2/3D.\n    \"\"\"\n\n    RAS = \"RAS\"\n    LPS = \"LPS\"\n\n\nclass MetaKeys(StrEnum):\n    \"\"\"\n    Typical keys for MetaObj.meta\n    \"\"\"\n\n    PIXDIM = \"pixdim\"  # MetaTensor.pixdim\n    ORIGINAL_PIXDIM = \"original_pixdim\"  # the pixdim after image loading before any data processing\n    AFFINE = \"affine\"  # MetaTensor.affine\n    ORIGINAL_AFFINE = \"original_affine\"  # the affine after image loading before any data processing\n    SPATIAL_SHAPE = \"spatial_shape\"  # optional key for the length in each spatial dimension\n    SPACE = \"space\"  # possible values of space type are defined in `SpaceKeys`\n    ORIGINAL_CHANNEL_DIM = \"original_channel_dim\"  # an integer or float(\"nan\")\n    SAVED_TO = \"saved_to\"\n\n\nclass ColorOrder(StrEnum):\n    \"\"\"\n    Enums for color order. 
Expand as necessary.\n    \"\"\"\n\n    RGB = \"RGB\"\n    BGR = \"BGR\"\n\n\nclass EngineStatsKeys(StrEnum):\n    \"\"\"\n    Default keys for the statistics of trainer and evaluator engines.\n\n    \"\"\"\n\n    RANK = \"rank\"\n    CURRENT_ITERATION = \"current_iteration\"\n    CURRENT_EPOCH = \"current_epoch\"\n    TOTAL_EPOCHS = \"total_epochs\"\n    TOTAL_ITERATIONS = \"total_iterations\"\n    BEST_VALIDATION_EPOCH = \"best_validation_epoch\"\n    BEST_VALIDATION_METRIC = \"best_validation_metric\"\n\n\nclass DataStatsKeys(StrEnum):\n    \"\"\"\n    Defaults keys for dataset statistical analysis modules\n\n    \"\"\"\n\n    SUMMARY = \"stats_summary\"\n    BY_CASE = \"stats_by_cases\"\n    BY_CASE_IMAGE_PATH = \"image_filepath\"\n    BY_CASE_LABEL_PATH = \"label_filepath\"\n    IMAGE_STATS = \"image_stats\"\n    FG_IMAGE_STATS = \"image_foreground_stats\"\n    LABEL_STATS = \"label_stats\"\n    IMAGE_HISTOGRAM = \"image_histogram\"\n\n\nclass ImageStatsKeys(StrEnum):\n    \"\"\"\n    Defaults keys for dataset statistical analysis image modules\n\n    \"\"\"\n\n    SHAPE = \"shape\"\n    CHANNELS = \"channels\"\n    CROPPED_SHAPE = \"cropped_shape\"\n    SPACING = \"spacing\"\n    SIZEMM = \"sizemm\"\n    INTENSITY = \"intensity\"\n    HISTOGRAM = \"histogram\"\n\n\nclass LabelStatsKeys(StrEnum):\n    \"\"\"\n    Defaults keys for dataset statistical analysis label modules\n\n    \"\"\"\n\n    LABEL_UID = \"labels\"\n    PIXEL_PCT = \"foreground_percentage\"\n    IMAGE_INTST = \"image_intensity\"\n    LABEL = \"label\"\n    LABEL_SHAPE = \"shape\"\n    LABEL_NCOMP = \"ncomponents\"\n\n\nclass HoVerNetMode(StrEnum):\n    \"\"\"\n    Modes for HoVerNet model:\n    `FAST`: a faster implementation (than original)\n    `ORIGINAL`: the original implementation\n    \"\"\"\n\n    FAST = \"FAST\"\n    ORIGINAL = \"ORIGINAL\"\n\n\nclass HoVerNetBranch(StrEnum):\n    \"\"\"\n    Three branches of HoVerNet model, which results in three outputs:\n    `HV` is horizontal and 
vertical gradient map of each nucleus (regression),\n    `NP` is the pixel prediction of all nuclei (segmentation), and\n    `NC` is the type of each nucleus (classification).\n    \"\"\"\n\n    HV = \"horizontal_vertical\"\n    NP = \"nucleus_prediction\"\n    NC = \"type_prediction\"\n\n\nclass LazyAttr(StrEnum):\n    \"\"\"\n    MetaTensor with pending operations requires some key attributes tracked especially when the primary array\n    is not up-to-date due to lazy evaluation.\n    This class specifies the set of key attributes to be tracked for each MetaTensor.\n    See also: :py:func:`monai.transforms.lazy.utils.resample` for more details.\n    \"\"\"\n\n    SHAPE = \"lazy_shape\"  # spatial shape\n    AFFINE = \"lazy_affine\"\n    PADDING_MODE = \"lazy_padding_mode\"\n    INTERP_MODE = \"lazy_interpolation_mode\"\n    DTYPE = \"lazy_dtype\"\n    ALIGN_CORNERS = \"lazy_align_corners\"\n    RESAMPLE_MODE = \"lazy_resample_mode\"\n\n\nclass BundleProperty(StrEnum):\n    \"\"\"\n    Bundle property fields:\n    `DESC` is the description of the property.\n    `REQUIRED` is flag to indicate whether the property is required or optional.\n    \"\"\"\n\n    DESC = \"description\"\n    REQUIRED = \"required\"\n\n\nclass BundlePropertyConfig(StrEnum):\n    \"\"\"\n    additional bundle property fields for config based bundle workflow:\n    `ID` is the config item ID of the property.\n    `REF_ID` is the ID of config item which is supposed to refer to this property.\n    For properties that do not have `REF_ID`, `None` should be set.\n    this field is only useful to check the optional property ID.\n    \"\"\"\n\n    ID = \"id\"\n    REF_ID = \"refer_id\"\n\n\nclass AlgoKeys(StrEnum):\n    \"\"\"\n    Default keys for templated Auto3DSeg Algo.\n    `ID` is the identifier of the algorithm. 
The string has the format of <name>_<idx>_<other>.\n    `ALGO` is the Auto3DSeg Algo instance.\n    `IS_TRAINED` is the status that shows if the Algo has been trained.\n    `SCORE` is the score the Algo has achieved after training.\n    \"\"\"\n\n    ID = \"identifier\"\n    ALGO = \"algo_instance\"\n    IS_TRAINED = \"is_trained\"\n    SCORE = \"best_metric\"\n\n\nclass AdversarialKeys(StrEnum):\n    \"\"\"\n    Keys used by the AdversarialTrainer.\n    `REALS` are real images from the batch.\n    `FAKES` are fake images generated by the generator. Are the same as PRED.\n    `REAL_LOGITS` are logits of the discriminator for the real images.\n    `FAKE_LOGIT` are logits of the discriminator for the fake images.\n    `RECONSTRUCTION_LOSS` is the loss value computed by the reconstruction loss function.\n    `GENERATOR_LOSS` is the loss value computed by the generator loss function. It is the\n                discriminator loss for the fake images. That is backpropagated through the generator only.\n    `DISCRIMINATOR_LOSS` is the loss value computed by the discriminator loss function. It is the\n                discriminator loss for the real images and the fake images. 
That is backpropagated through the\n                discriminator only.\n    \"\"\"\n\n    REALS = \"reals\"\n    REAL_LOGITS = \"real_logits\"\n    FAKES = \"fakes\"\n    FAKE_LOGITS = \"fake_logits\"\n    RECONSTRUCTION_LOSS = \"reconstruction_loss\"\n    GENERATOR_LOSS = \"generator_loss\"\n    DISCRIMINATOR_LOSS = \"discriminator_loss\"\n\n\nclass OrderingType(StrEnum):\n    RASTER_SCAN = \"raster_scan\"\n    S_CURVE = \"s_curve\"\n    RANDOM = \"random\"\n\n\nclass OrderingTransformations(StrEnum):\n    ROTATE_90 = \"rotate_90\"\n    TRANSPOSE = \"transpose\"\n    REFLECT = \"reflect\"\n\n\nclass IgniteInfo(StrEnum):\n    \"\"\"\n    Config information of the PyTorch ignite package.\n\n    \"\"\"\n\n    OPT_IMPORT_VERSION = \"0.4.11\"\n\n\nif TYPE_CHECKING:\n    from ignite.engine import EventEnum\nelse:\n    EventEnum, _ = optional_import(\n        \"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"EventEnum\", as_type=\"base\"\n    )\n\n\nclass AdversarialIterationEvents(EventEnum):\n    \"\"\"\n    Keys used to define events as used in the AdversarialTrainer.\n    \"\"\"\n\n    RECONSTRUCTION_LOSS_COMPLETED = \"reconstruction_loss_completed\"\n    GENERATOR_FORWARD_COMPLETED = \"generator_forward_completed\"\n    GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = \"generator_discriminator_forward_completed\"\n    GENERATOR_LOSS_COMPLETED = \"generator_loss_completed\"\n    GENERATOR_BACKWARD_COMPLETED = \"generator_backward_completed\"\n    GENERATOR_MODEL_COMPLETED = \"generator_model_completed\"\n    DISCRIMINATOR_REALS_FORWARD_COMPLETED = \"discriminator_reals_forward_completed\"\n    DISCRIMINATOR_FAKES_FORWARD_COMPLETED = \"discriminator_fakes_forward_completed\"\n    DISCRIMINATOR_LOSS_COMPLETED = \"discriminator_loss_completed\"\n    DISCRIMINATOR_BACKWARD_COMPLETED = \"discriminator_backward_completed\"\n    DISCRIMINATOR_MODEL_COMPLETED = \"discriminator_model_completed\"\n"
  },
  {
    "path": "monai/utils/jupyter_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nThis set of utility function is meant to make using Jupyter notebooks easier with MONAI. Plotting functions using\nMatplotlib produce common plots for metrics and images.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport copy\nfrom collections.abc import Callable, Mapping\nfrom enum import Enum\nfrom threading import RLock, Thread\nfrom typing import TYPE_CHECKING, Any\n\nimport numpy as np\nimport torch\n\nfrom monai.utils import IgniteInfo\nfrom monai.utils.module import min_version, optional_import\n\ntry:\n    import matplotlib.pyplot as plt\n\n    has_matplotlib = True\nexcept ImportError:\n    has_matplotlib = False\n\nif TYPE_CHECKING:\n    from ignite.engine import Engine, Events\nelse:\n    Engine, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n    Events, _ = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\n\nLOSS_NAME = \"loss\"\n\n\ndef plot_metric_graph(\n    ax: plt.Axes,\n    title: str,\n    graphmap: Mapping[str, list[float] | tuple[list[float], list[float]]],\n    yscale: str = \"log\",\n    avg_keys: tuple[str] = (LOSS_NAME,),\n    window_fraction: int = 20,\n) -> None:\n    \"\"\"\n    Plot metrics on a single graph with running averages plotted for selected keys. 
The values in `graphmap`\n    should be lists of (timepoint, value) pairs as stored in MetricLogger objects.\n\n    Args:\n        ax: Axes object to plot into\n        title: graph title\n        graphmap: dictionary of named graph values, which are lists of values or (index, value) pairs\n        yscale: scale for y-axis compatible with `Axes.set_yscale`\n        avg_keys: tuple of keys in `graphmap` to provide running average plots for\n        window_fraction: what fraction of the graph value length to use as the running average window\n    \"\"\"\n    from matplotlib.ticker import MaxNLocator\n\n    for n, v in graphmap.items():\n        if len(v) > 0:\n            if isinstance(v[0], (tuple, list)):  # values are (x,y) pairs\n                inds, vals = zip(*v)  # separate values into list of indices in X dimension and values\n            else:\n                inds, vals = tuple(range(len(v))), tuple(v)  # values are without indices, make indices for them\n\n            ax.plot(inds, vals, label=f\"{n} = {vals[-1]:.5g}\")\n\n            # if requested compute and plot a running average for the values using a fractional window size\n            if n in avg_keys and len(v) > window_fraction:\n                window = len(v) // window_fraction\n                kernel = np.ones((window,)) / window\n                ra = np.convolve((vals[0],) * (window - 1) + vals, kernel, mode=\"valid\")\n\n                ax.plot(inds, ra, label=f\"{n} Avg = {ra[-1]:.5g}\")\n\n    ax.set_title(title)\n    ax.set_yscale(yscale)\n    ax.axis(\"on\")\n    ax.legend(bbox_to_anchor=(1, 1), loc=1, borderaxespad=0.0)\n    ax.grid(True, \"both\", \"both\")\n    ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n\n\ndef plot_metric_images(\n    fig: plt.Figure,\n    title: str,\n    graphmap: Mapping[str, list[float] | tuple[list[float], list[float]]],\n    imagemap: dict[str, np.ndarray],\n    yscale: str = \"log\",\n    avg_keys: tuple[str] = (LOSS_NAME,),\n    window_fraction: 
int = 20,\n) -> list:\n    \"\"\"\n    Plot metric graph data with images below into figure `fig`. The intended use is for the graph data to be\n    metrics from a training run and the images to be the batch and output from the last iteration. This uses\n    `plot_metric_graph` to plot the metric graph.\n\n    Args:\n        fig: Figure object to plot into, reuse from previous plotting for flicker-free refreshing\n        title: graph title\n        graphmap: dictionary of named graph values, which are lists of values or (index, value) pairs\n        imagemap: dictionary of named images to show with metric plot\n        yscale: for metric plot, scale for y-axis compatible with `Axes.set_yscale`\n        avg_keys: for metric plot, tuple of keys in `graphmap` to provide running average plots for\n        window_fraction: for metric plot, what fraction of the graph value length to use as the running average window\n\n    Returns:\n        list of Axes objects for graph followed by images\n    \"\"\"\n    gridshape = (4, max(1, len(imagemap)))\n\n    graph = plt.subplot2grid(gridshape, (0, 0), colspan=gridshape[1], fig=fig)\n\n    plot_metric_graph(graph, title, graphmap, yscale, avg_keys, window_fraction)\n\n    axes = [graph]\n    for i, n in enumerate(imagemap):\n        im = plt.subplot2grid(gridshape, (1, i), rowspan=2, fig=fig)\n\n        if imagemap[n].shape[0] == 3:\n            im.imshow(imagemap[n].transpose([1, 2, 0]))\n        else:\n            im.imshow(np.squeeze(imagemap[n]), cmap=\"gray\")\n\n        im.set_title(f\"{n}\\n{imagemap[n].min():.3g} -> {imagemap[n].max():.3g}\")\n        im.axis(\"off\")\n        axes.append(im)\n\n    return axes\n\n\ndef tensor_to_images(name: str, tensor: torch.Tensor) -> np.ndarray | None:\n    \"\"\"\n    Return an tuple of images derived from the given tensor. 
The `name` value indices which key from the\n    output or batch value the tensor was stored as, or is \"Batch\" or \"Output\" if these were single tensors\n    instead of dictionaries. Returns a tuple of 2D images of shape HW, or 3D images of shape CHW where C is\n    color channels RGB or RGBA. This allows multiple images to be created from a single tensor, ie. to show\n    each channel separately.\n    \"\"\"\n    if tensor.ndim == 3 and tensor.shape[1] > 2 and tensor.shape[2] > 2:\n        return tensor.cpu().data.numpy()  # type: ignore[no-any-return]\n    if tensor.ndim == 4 and tensor.shape[2] > 2 and tensor.shape[3] > 2:\n        dmid = tensor.shape[1] // 2\n        return tensor[:, dmid].cpu().data.numpy()  # type: ignore[no-any-return]\n\n    return None\n\n\ndef plot_engine_status(\n    engine: Engine,\n    logger: Any,\n    title: str = \"Training Log\",\n    yscale: str = \"log\",\n    avg_keys: tuple[str] = (LOSS_NAME,),\n    window_fraction: int = 20,\n    image_fn: Callable[[str, torch.Tensor], Any] | None = tensor_to_images,\n    fig: plt.Figure | None = None,\n    selected_inst: int = 0,\n) -> tuple[plt.Figure, list]:\n    \"\"\"\n    Plot the status of the given Engine with its logger. The plot will consist of a graph of loss values and metrics\n    taken from the logger, and images taken from the `output` and `batch` members of `engine.state`. 
The images are\n    converted to Numpy arrays suitable for input to `Axes.imshow` using `image_fn`, if this is None then no image\n    plotting is done.\n\n    Args:\n        engine: Engine to extract images from\n        logger: MetricLogger to extract loss and metric data from\n        title: graph title\n        yscale: for metric plot, scale for y-axis compatible with `Axes.set_yscale`\n        avg_keys: for metric plot, tuple of keys in `graphmap` to provide running average plots for\n        window_fraction: for metric plot, what fraction of the graph value length to use as the running average window\n        image_fn: callable converting tensors keyed to a name in the Engine to a tuple of images to plot\n        fig: Figure object to plot into, reuse from previous plotting for flicker-free refreshing\n        selected_inst: index of the instance to show in the image plot\n\n    Returns:\n        Figure object (or `fig` if given), list of Axes objects for graph and images\n    \"\"\"\n    if fig is not None:\n        fig.clf()\n    else:\n        fig = plt.Figure(figsize=(20, 10), tight_layout=True, facecolor=\"white\")\n\n    graphmap: dict[str, list[float]] = {LOSS_NAME: logger.loss}\n    graphmap.update(logger.metrics)\n\n    imagemap: dict = {}\n    if image_fn is not None and engine.state is not None and engine.state.batch is not None:\n        for src in (engine.state.batch, engine.state.output):\n            label = \"Batch\" if src is engine.state.batch else \"Output\"\n            batch_selected_inst = selected_inst  # selected batch index, set to 0 when src is decollated\n\n            # if the src object is a list of elements, ie. 
a decollated batch, select an element and keep it as\n            # a dictionary of tensors with a batch dimension added\n            if isinstance(src, list):\n                selected_dict = src[selected_inst]  # select this element\n                batch_selected_inst = 0  # set the selection to be the single index in the batch dimension\n                # store each tensor that is interpretable as an image with an added batch dimension\n                src = {k: v[None] for k, v in selected_dict.items() if isinstance(v, torch.Tensor) and v.ndim >= 3}\n\n            # images will be generated from the batch item selected above only, or from the single item given as `src`\n\n            if isinstance(src, dict):\n                for k, v in src.items():\n                    if isinstance(v, torch.Tensor) and v.ndim >= 4:\n                        image = image_fn(k, v[batch_selected_inst])\n\n                        # if we have images add each one separately to the map\n                        if image is not None:\n                            for i, im in enumerate(image):\n                                imagemap[f\"{k}_{i}\"] = im\n\n            elif isinstance(src, torch.Tensor):\n                image = image_fn(label, src)\n                if image is not None:\n                    imagemap[f\"{label}_{i}\"] = image\n\n    axes = plot_metric_images(fig, title, graphmap, imagemap, yscale, avg_keys, window_fraction)\n\n    if logger.loss:\n        axes[0].axhline(logger.loss[-1][1], c=\"k\", ls=\":\")  # draw dotted horizontal line at last loss value\n\n    return fig, axes\n\n\ndef _get_loss_from_output(\n    output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor,\n) -> torch.Tensor:\n    \"\"\"Returns a single value from the network output, which is a dict or tensor.\"\"\"\n\n    def _get_loss(data: torch.Tensor | dict[str, torch.Tensor]) -> torch.Tensor:\n        if isinstance(data, dict):\n            return 
data[\"loss\"]\n        return data\n\n    if isinstance(output, list):\n        return _get_loss(output[0])\n    return _get_loss(output)\n\n\nclass StatusMembers(Enum):\n    \"\"\"\n    Named members of the status dictionary, others may be present for named metric values.\n    \"\"\"\n\n    STATUS = \"Status\"\n    EPOCHS = \"Epochs\"\n    ITERS = \"Iters\"\n    LOSS = \"Loss\"\n\n\nclass ThreadContainer(Thread):\n    \"\"\"\n    Contains a running `Engine` object within a separate thread from main thread in a Jupyter notebook. This\n    allows an engine to begin a run in the background and allow the starting notebook cell to complete. A\n    user can thus start a run and then navigate away from the notebook without concern for loosing connection\n    with the running cell. All output is acquired through methods which synchronize with the running engine\n    using an internal `lock` member, acquiring this lock allows the engine to be inspected while it's prevented\n    from starting the next iteration.\n\n    Args:\n        engine: wrapped `Engine` object, when the container is started its `run` method is called\n        loss_transform: callable to convert an output dict into a single numeric value\n        metric_transform: callable to convert a named metric value into a single numeric value\n        status_format: format string for status key-value pairs.\n    \"\"\"\n\n    def __init__(\n        self,\n        engine: Engine,\n        loss_transform: Callable = _get_loss_from_output,\n        metric_transform: Callable = lambda name, value: value,\n        status_format: str = \"{}: {:.4}\",\n    ):\n        super().__init__()\n        self.lock = RLock()\n        self.engine = engine\n        self._status_dict: dict[str, Any] = {}\n        self.loss_transform = loss_transform\n        self.metric_transform = metric_transform\n        self.fig: plt.Figure | None = None\n        self.status_format = status_format\n\n        
self.engine.add_event_handler(Events.ITERATION_COMPLETED, self._update_status)\n\n    def run(self):\n        \"\"\"Calls the `run` method of the wrapped engine.\"\"\"\n        self.engine.run()\n\n    def stop(self):\n        \"\"\"Stop the engine and join the thread.\"\"\"\n        self.engine.terminate()\n        self.join()\n\n    def _update_status(self):\n        \"\"\"Called as an event, updates the internal status dict at the end of iterations.\"\"\"\n        with self.lock:\n            state = self.engine.state\n            stats: dict[str, Any] = {\n                StatusMembers.EPOCHS.value: 0,\n                StatusMembers.ITERS.value: 0,\n                StatusMembers.LOSS.value: float(\"nan\"),\n            }\n\n            if state is not None:\n                if state.max_epochs is not None and state.max_epochs >= 1:\n                    epoch = f\"{state.epoch}/{state.max_epochs}\"\n                else:\n                    epoch = str(state.epoch)\n\n                if state.epoch_length is not None:\n                    iters = f\"{state.iteration % state.epoch_length}/{state.epoch_length}\"\n                else:\n                    iters = str(state.iteration)\n\n                stats[StatusMembers.EPOCHS.value] = epoch\n                stats[StatusMembers.ITERS.value] = iters\n                stats[StatusMembers.LOSS.value] = self.loss_transform(state.output)\n\n                metrics = state.metrics or {}\n                for m, v in metrics.items():\n                    v = self.metric_transform(m, v)\n                    if v is not None:\n                        stats[m].append(v)\n\n            self._status_dict.update(stats)\n\n    @property\n    def status_dict(self) -> dict[str, str]:\n        \"\"\"A dictionary containing status information, current loss, and current metric values.\"\"\"\n        with self.lock:\n            stats = {StatusMembers.STATUS.value: \"Running\" if self.is_alive() else \"Stopped\"}\n            
stats.update(self._status_dict)\n            return stats\n\n    def status(self) -> str:\n        \"\"\"Returns a status string for the current state of the engine.\"\"\"\n        stats = copy.deepcopy(self.status_dict)\n\n        msgs = [stats.pop(StatusMembers.STATUS.value), \"Iters: \" + str(stats.pop(StatusMembers.ITERS.value, 0))]\n\n        for key, val in stats.items():\n            if isinstance(val, float):\n                msg = self.status_format.format(key, val)\n            else:\n                msg = f\"{key}: {val}\"\n\n            msgs.append(msg)\n\n        return \", \".join(msgs)\n\n    def plot_status(self, logger: Any, plot_func: Callable = plot_engine_status) -> plt.Figure | None:\n        \"\"\"\n        Generate a plot of the current status of the contained engine whose loss and metrics were tracked by `logger`.\n        The function `plot_func` must accept arguments `title`, `engine`, `logger`, and `fig` which are the plot title,\n        `self.engine`, `logger`, and `self.fig` respectively. The return value must be a figure object (stored in\n        `self.fig`) and a list of Axes objects for the plots in the figure. Only the figure is returned by this method,\n        which holds the internal lock during the plot generation.\n        \"\"\"\n        with self.lock:\n            self.fig, _ = plot_func(title=self.status(), engine=self.engine, logger=logger, fig=self.fig)\n            return self.fig\n"
  },
  {
    "path": "monai/utils/misc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport inspect\nimport itertools\nimport math\nimport os\nimport pprint\nimport random\nimport shutil\nimport subprocess\nimport tempfile\nimport types\nimport warnings\nfrom ast import literal_eval\nfrom collections.abc import Callable, Iterable, Sequence\nfrom math import log10\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any, TypeVar, cast, overload\n\nimport numpy as np\nimport torch\n\nfrom monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike\nfrom monai.utils.module import optional_import, version_leq\n\nif TYPE_CHECKING:\n    from yaml import SafeLoader\nelse:\n    SafeLoader, _ = optional_import(\"yaml\", name=\"SafeLoader\", as_type=\"base\")\n\n__all__ = [\n    \"zip_with\",\n    \"star_zip_with\",\n    \"first\",\n    \"issequenceiterable\",\n    \"is_immutable\",\n    \"ensure_tuple\",\n    \"ensure_tuple_size\",\n    \"ensure_tuple_rep\",\n    \"to_tuple_of_dictionaries\",\n    \"fall_back_tuple\",\n    \"is_scalar_tensor\",\n    \"is_scalar\",\n    \"progress_bar\",\n    \"get_seed\",\n    \"set_determinism\",\n    \"list_to_dict\",\n    \"MAX_SEED\",\n    \"copy_to_device\",\n    \"str2bool\",\n    \"str2list\",\n    \"MONAIEnvVars\",\n    \"ImageMetaKey\",\n    \"is_module_ver_at_least\",\n    \"has_option\",\n    \"sample_slices\",\n    \"check_parent_dir\",\n    
\"save_obj\",\n    \"label_union\",\n    \"path_to_uri\",\n    \"pprint_edges\",\n    \"check_key_duplicates\",\n    \"CheckKeyDuplicatesYamlLoader\",\n    \"ConvertUnits\",\n    \"check_kwargs_exist_in_class_init\",\n    \"run_cmd\",\n]\n\n\ndef _strtobool(val: str) -> bool:\n    \"\"\"\n    Replaces deprecated (pre python 3.12)\n    distutils strtobool function.\n\n    True values are y, yes, t, true, on and 1;\n    False values are n, no, f, false, off and 0.\n    Raises ValueError if val is anything else.\n    \"\"\"\n    val = val.lower()\n    if val in (\"y\", \"yes\", \"t\", \"true\", \"on\", \"1\"):\n        return True\n    elif val in (\"n\", \"no\", \"f\", \"false\", \"off\", \"0\"):\n        return False\n    else:\n        raise ValueError(f\"invalid truth value {val}\")\n\n\n_seed = None\n_flag_deterministic = torch.backends.cudnn.deterministic\n_flag_cudnn_benchmark = torch.backends.cudnn.benchmark\nNP_MAX = np.iinfo(np.uint32).max\nMAX_SEED = NP_MAX + 1  # 2**32, the actual seed should be in [0, MAX_SEED - 1] for uint32\n\n# Environment variable must be set to enable determinism for algorithms (alternative value is \":16:8\").\n# This needs to be here to ensure it's set before deterministic algorithms are used/initialised.\nos.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = os.environ.get(\"CUBLAS_WORKSPACE_CONFIG\", \":4096:8\")\n\n\ndef zip_with(op, *vals, mapfunc=map):\n    \"\"\"\n    Map `op`, using `mapfunc`, to each tuple derived from zipping the iterables in `vals`.\n    \"\"\"\n    return mapfunc(op, zip(*vals))\n\n\ndef star_zip_with(op, *vals):\n    \"\"\"\n    Use starmap as the mapping function in zipWith.\n    \"\"\"\n    return zip_with(op, *vals, mapfunc=itertools.starmap)\n\n\nT = TypeVar(\"T\")\nNT = TypeVar(\"NT\", np.ndarray, torch.Tensor)\n\n\n@overload\ndef first(iterable: Iterable[T], default: T) -> T: ...\n\n\n@overload\ndef first(iterable: Iterable[T]) -> T | None: ...\n\n\ndef first(iterable: Iterable[T], default: T | None = None) 
-> T | None:\n    \"\"\"\n    Returns the first item in the given iterable or `default` if empty, meaningful mostly with 'for' expressions.\n    \"\"\"\n    for i in iterable:\n        return i\n    return default\n\n\ndef issequenceiterable(obj: Any) -> bool:\n    \"\"\"\n    Determine if the object is an iterable sequence and is not a string.\n    \"\"\"\n    try:\n        if hasattr(obj, \"ndim\") and obj.ndim == 0:\n            return False  # a 0-d tensor is not iterable\n    except Exception:\n        return False\n    return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes))\n\n\ndef is_immutable(obj: Any) -> bool:\n    \"\"\"\n    Determine if the object is an immutable object.\n\n    see also https://github.com/python/cpython/blob/3.11/Lib/copy.py#L109\n    \"\"\"\n    return isinstance(obj, (type(None), int, float, bool, complex, str, tuple, bytes, type, range, slice))\n\n\ndef ensure_tuple(vals: Any, wrap_array: bool = False) -> tuple:\n    \"\"\"\n    Returns a tuple of `vals`.\n\n    Args:\n        vals: input data to convert to a tuple.\n        wrap_array: if `True`, treat the input numerical array (ndarray/tensor) as one item of the tuple.\n            if `False`, try to convert the array with `tuple(vals)`, default to `False`.\n\n    \"\"\"\n    if wrap_array and isinstance(vals, (np.ndarray, torch.Tensor)):\n        return (vals,)\n    return tuple(vals) if issequenceiterable(vals) else (vals,)\n\n\ndef ensure_tuple_size(vals: Any, dim: int, pad_val: Any = 0, pad_from_start: bool = False) -> tuple:\n    \"\"\"\n    Returns a copy of `tup` with `dim` values by either shortened or padded with `pad_val` as necessary.\n    \"\"\"\n    tup = ensure_tuple(vals)\n    pad_dim = dim - len(tup)\n    if pad_dim <= 0:\n        return tup[:dim]\n    if pad_from_start:\n        return (pad_val,) * pad_dim + tup\n    return tup + (pad_val,) * pad_dim\n\n\ndef ensure_tuple_rep(tup: Any, dim: int) -> tuple[Any, ...]:\n    \"\"\"\n    Returns a copy of 
`tup` with `dim` values by either shortened or duplicated input.\n\n    Raises:\n        ValueError: When ``tup`` is a sequence and ``tup`` length is not ``dim``.\n\n    Examples::\n\n        >>> ensure_tuple_rep(1, 3)\n        (1, 1, 1)\n        >>> ensure_tuple_rep(None, 3)\n        (None, None, None)\n        >>> ensure_tuple_rep('test', 3)\n        ('test', 'test', 'test')\n        >>> ensure_tuple_rep([1, 2, 3], 3)\n        (1, 2, 3)\n        >>> ensure_tuple_rep(range(3), 3)\n        (0, 1, 2)\n        >>> ensure_tuple_rep([1, 2], 3)\n        ValueError: Sequence must have length 3, got length 2.\n\n    \"\"\"\n    if isinstance(tup, torch.Tensor):\n        tup = tup.detach().cpu().numpy()\n    if isinstance(tup, np.ndarray):\n        tup = tup.tolist()\n    if not issequenceiterable(tup):\n        return (tup,) * dim\n    if len(tup) == dim:\n        return tuple(tup)\n\n    raise ValueError(f\"Sequence must have length {dim}, got {len(tup)}.\")\n\n\ndef to_tuple_of_dictionaries(dictionary_of_tuples: dict, keys: Any) -> tuple[dict[Any, Any], ...]:\n    \"\"\"\n    Given a dictionary whose values contain scalars or tuples (with the same length as ``keys``),\n    Create a dictionary for each key containing the scalar values mapping to that key.\n\n    Args:\n        dictionary_of_tuples: a dictionary whose values are scalars or tuples whose length is\n            the length of ``keys``\n        keys: a tuple of string values representing the keys in question\n\n    Returns:\n        a tuple of dictionaries that contain scalar values, one dictionary for each key\n\n    Raises:\n        ValueError: when values in the dictionary are tuples but not the same length as the length\n        of ``keys``\n\n    Examples:\n        >>> to_tuple_of_dictionaries({'a': 1 'b': (2, 3), 'c': (4, 4)}, (\"x\", \"y\"))\n        ({'a':1, 'b':2, 'c':4}, {'a':1, 'b':3, 'c':4})\n\n    \"\"\"\n\n    keys = ensure_tuple(keys)\n    if len(keys) == 0:\n        return tuple({})\n\n    
dict_overrides = {k: ensure_tuple_rep(v, len(keys)) for k, v in dictionary_of_tuples.items()}\n    return tuple({k: v[ik] for (k, v) in dict_overrides.items()} for ik in range(len(keys)))\n\n\ndef fall_back_tuple(\n    user_provided: Any, default: Sequence | NdarrayTensor, func: Callable = lambda x: x and x > 0\n) -> tuple[Any, ...]:\n    \"\"\"\n    Refine `user_provided` according to the `default`, and returns as a validated tuple.\n\n    The validation is done for each element in `user_provided` using `func`.\n    If `func(user_provided[idx])` returns False, the corresponding `default[idx]` will be used\n    as the fallback.\n\n    Typically used when `user_provided` is a tuple of window size provided by the user,\n    `default` is defined by data, this function returns an updated `user_provided` with its non-positive\n    components replaced by the corresponding components from `default`.\n\n    Args:\n        user_provided: item to be validated.\n        default: a sequence used to provided the fallbacks.\n        func: a Callable to validate every components of `user_provided`.\n\n    Examples::\n\n        >>> fall_back_tuple((1, 2), (32, 32))\n        (1, 2)\n        >>> fall_back_tuple(None, (32, 32))\n        (32, 32)\n        >>> fall_back_tuple((-1, 10), (32, 32))\n        (32, 10)\n        >>> fall_back_tuple((-1, None), (32, 32))\n        (32, 32)\n        >>> fall_back_tuple((1, None), (32, 32))\n        (1, 32)\n        >>> fall_back_tuple(0, (32, 32))\n        (32, 32)\n        >>> fall_back_tuple(range(3), (32, 64, 48))\n        (32, 1, 2)\n        >>> fall_back_tuple([0], (32, 32))\n        ValueError: Sequence must have length 2, got length 1.\n\n    \"\"\"\n    ndim = len(default)\n    user = ensure_tuple_rep(user_provided, ndim)\n    return tuple(  # use the default values if user provided is not valid\n        user_c if func(user_c) else default_c for default_c, user_c in zip(default, user)\n    )\n\n\ndef is_scalar_tensor(val: Any) -> bool:\n 
   return isinstance(val, torch.Tensor) and val.ndim == 0\n\n\ndef is_scalar(val: Any) -> bool:\n    if isinstance(val, torch.Tensor) and val.ndim == 0:\n        return True\n    return bool(np.isscalar(val))\n\n\ndef progress_bar(index: int, count: int, desc: str | None = None, bar_len: int = 30, newline: bool = False) -> None:\n    \"\"\"print a progress bar to track some time consuming task.\n\n    Args:\n        index: current status in progress.\n        count: total steps of the progress.\n        desc: description of the progress bar, if not None, show before the progress bar.\n        bar_len: the total length of the bar on screen, default is 30 char.\n        newline: whether to print in a new line for every index.\n    \"\"\"\n    end = \"\\r\" if not newline else \"\\r\\n\"\n    filled_len = int(bar_len * index // count)\n    bar = f\"{desc} \" if desc is not None else \"\"\n    bar += \"[\" + \"=\" * filled_len + \" \" * (bar_len - filled_len) + \"]\"\n    print(f\"{index}/{count} {bar}\", end=end)\n    if index == count:\n        print(\"\")\n\n\ndef get_seed() -> int | None:\n    return _seed\n\n\ndef set_determinism(\n    seed: int | None = NP_MAX,\n    use_deterministic_algorithms: bool | None = None,\n    additional_settings: Sequence[Callable[[int], Any]] | Callable[[int], Any] | None = None,\n) -> None:\n    \"\"\"\n    Set random seed for modules to enable or disable deterministic training.\n\n    Args:\n        seed: the random seed to use, default is np.iinfo(np.int32).max.\n            It is recommended to set a large seed, i.e. a number that has a good balance\n            of 0 and 1 bits. 
Avoid having many 0 bits in the seed.\n            if set to None, will disable deterministic training.\n        use_deterministic_algorithms: Set whether PyTorch operations must use \"deterministic\" algorithms.\n        additional_settings: additional settings that need to set random seed.\n\n    Note:\n\n        This function will not affect the randomizable objects in :py:class:`monai.transforms.Randomizable`, which\n        have independent random states. For those objects, the ``set_random_state()`` method should be used to\n        ensure the deterministic behavior (alternatively, :py:class:`monai.data.DataLoader` by default sets the seeds\n        according to the global random state, please see also: :py:class:`monai.data.utils.worker_init_fn` and\n        :py:class:`monai.data.utils.set_rnd`).\n    \"\"\"\n    if seed is None:\n        # cast to 32 bit seed for CUDA\n        seed_ = torch.default_generator.seed() % MAX_SEED\n        torch.manual_seed(seed_)\n    else:\n        seed = int(seed) % MAX_SEED\n        torch.manual_seed(seed)\n\n    global _seed\n    _seed = seed\n    random.seed(seed)\n    np.random.seed(seed)\n\n    if additional_settings is not None:\n        additional_settings = ensure_tuple(additional_settings)\n        for func in additional_settings:\n            func(seed)\n\n    with torch.backends.__allow_nonbracketed_mutation():  # FIXME: better method without accessing private member\n        if seed is not None:\n            torch.backends.cudnn.deterministic = True\n            torch.backends.cudnn.benchmark = False\n        else:  # restore the original flags\n            torch.backends.cudnn.deterministic = _flag_deterministic\n            torch.backends.cudnn.benchmark = _flag_cudnn_benchmark\n\n    if use_deterministic_algorithms is not None:\n        torch.use_deterministic_algorithms(use_deterministic_algorithms)\n\n\ndef list_to_dict(items):\n    \"\"\"\n    To convert a list of \"key=value\" pairs into a dictionary.\n    
For examples: items: `[\"a=1\", \"b=2\", \"c=3\"]`, return: {\"a\": \"1\", \"b\": \"2\", \"c\": \"3\"}.\n    If no \"=\" in the pair, use None as the value, for example: [\"a\"], return: {\"a\": None}.\n    Note that it will remove the blanks around keys and values.\n\n    \"\"\"\n\n    def _parse_var(s):\n        items = s.split(\"=\", maxsplit=1)\n        key = items[0].strip(\" \\n\\r\\t'\")\n        value = items[1].strip(\" \\n\\r\\t'\") if len(items) > 1 else None\n        return key, value\n\n    d = {}\n    if items:\n        for item in items:\n            key, value = _parse_var(item)\n\n            try:\n                if key in d:\n                    raise KeyError(f\"encounter duplicated key {key}.\")\n                d[key] = literal_eval(value)\n            except ValueError:\n                try:\n                    d[key] = bool(_strtobool(str(value)))\n                except ValueError:\n                    d[key] = value\n    return d\n\n\ndef copy_to_device(\n    obj: Any, device: str | torch.device | None, non_blocking: bool = True, verbose: bool = False\n) -> Any:\n    \"\"\"\n    Copy object or tuple/list/dictionary of objects to ``device``.\n\n    Args:\n        obj: object or tuple/list/dictionary of objects to move to ``device``.\n        device: move ``obj`` to this device. Can be a string (e.g., ``cpu``, ``cuda``,\n            ``cuda:0``, etc.) or of type ``torch.device``.\n        non_blocking: when `True`, moves data to device asynchronously if\n            possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.\n        verbose: when `True`, will print a warning for any elements of incompatible type\n            not copied to ``device``.\n    Returns:\n        Same as input, copied to ``device`` where possible. 
Original input will be\n            unchanged.\n    \"\"\"\n\n    if hasattr(obj, \"to\"):\n        return obj.to(device, non_blocking=non_blocking)\n    if isinstance(obj, tuple):\n        return tuple(copy_to_device(o, device, non_blocking) for o in obj)\n    if isinstance(obj, list):\n        return [copy_to_device(o, device, non_blocking) for o in obj]\n    if isinstance(obj, dict):\n        return {k: copy_to_device(o, device, non_blocking) for k, o in obj.items()}\n    if verbose:\n        fn_name = cast(types.FrameType, inspect.currentframe()).f_code.co_name\n        warnings.warn(f\"{fn_name} called with incompatible type: \" + f\"{type(obj)}. Data will be returned unchanged.\")\n\n    return obj\n\n\ndef str2bool(value: str | bool, default: bool = False, raise_exc: bool = True) -> bool:\n    \"\"\"\n    Convert a string to a boolean. Case insensitive.\n    True: yes, true, t, y, 1. False: no, false, f, n, 0.\n\n    Args:\n        value: string to be converted to a boolean. If value is a bool already, simply return it.\n        raise_exc: if value not in tuples of expected true or false inputs,\n            should we raise an exception? 
If not, return `default`.\n    Raises\n        ValueError: value not in tuples of expected true or false inputs and\n            `raise_exc` is `True`.\n    Useful with argparse, for example:\n        parser.add_argument(\"--convert\", default=False, type=str2bool)\n        python mycode.py --convert=True\n    \"\"\"\n\n    if isinstance(value, bool):\n        return value\n\n    true_set = (\"yes\", \"true\", \"t\", \"y\", \"1\")\n    false_set = (\"no\", \"false\", \"f\", \"n\", \"0\")\n\n    if isinstance(value, str):\n        value = value.lower()\n        if value in true_set:\n            return True\n        if value in false_set:\n            return False\n\n    if raise_exc:\n        raise ValueError(f\"Got \\\"{value}\\\", expected a value from: {', '.join(true_set + false_set)}\")\n    return default\n\n\ndef str2list(value: str | list | None, raise_exc: bool = True) -> list | None:\n    \"\"\"\n    Convert a string to a list.  Useful with argparse commandline arguments:\n        parser.add_argument(\"--blocks\", default=[1,2,3], type=str2list)\n        python mycode.py --blocks=1,2,2,4\n\n    Args:\n        value: string (comma separated) to be converted to a list\n        raise_exc: if not possible to convert to a list, raise an exception\n    Raises\n        ValueError: value not a string or list or not possible to convert\n    \"\"\"\n\n    if value is None:\n        return None\n    elif isinstance(value, list):\n        return value\n    elif isinstance(value, str):\n        v = value.split(\",\")\n        for i in range(len(v)):\n            try:\n                a = literal_eval(v[i].strip())  # attempt to convert\n                v[i] = a\n            except Exception:\n                pass\n        return v\n    elif raise_exc:\n        raise ValueError(f'Unable to convert \"{value}\", expected a comma-separated str, e.g. 
1,2,3')\n\n    return None\n\n\nclass MONAIEnvVars:\n    \"\"\"\n    Environment variables used by MONAI.\n    \"\"\"\n\n    @staticmethod\n    def data_dir() -> str | None:\n        return os.environ.get(\"MONAI_DATA_DIRECTORY\")\n\n    @staticmethod\n    def debug() -> bool:\n        val = os.environ.get(\"MONAI_DEBUG\", False)\n        return val if isinstance(val, bool) else str2bool(val)\n\n    @staticmethod\n    def doc_images() -> str | None:\n        return os.environ.get(\"MONAI_DOC_IMAGES\")\n\n    @staticmethod\n    def algo_hash() -> str | None:\n        return os.environ.get(\"MONAI_ALGO_HASH\", \"21ed8e5\")\n\n    @staticmethod\n    def trace_transform() -> str | None:\n        return os.environ.get(\"MONAI_TRACE_TRANSFORM\", \"1\")\n\n    @staticmethod\n    def eval_expr() -> str | None:\n        return os.environ.get(\"MONAI_EVAL_EXPR\", \"1\")\n\n    @staticmethod\n    def allow_missing_reference() -> str | None:\n        return os.environ.get(\"MONAI_ALLOW_MISSING_REFERENCE\", \"1\")\n\n    @staticmethod\n    def extra_test_data() -> str | None:\n        return os.environ.get(\"MONAI_EXTRA_TEST_DATA\", \"1\")\n\n    @staticmethod\n    def testing_algo_template() -> str | None:\n        return os.environ.get(\"MONAI_TESTING_ALGO_TEMPLATE\", None)\n\n\nclass ImageMetaKey:\n    \"\"\"\n    Common key names in the metadata header of images\n    \"\"\"\n\n    FILENAME_OR_OBJ = \"filename_or_obj\"\n    PATCH_INDEX = \"patch_index\"\n    SPATIAL_SHAPE = \"spatial_shape\"\n\n\ndef has_option(obj: Callable, keywords: str | Sequence[str]) -> bool:\n    \"\"\"\n    Return a boolean indicating whether the given callable `obj` has the `keywords` in its signature.\n    \"\"\"\n    if not callable(obj):\n        return False\n    sig = inspect.signature(obj)\n    return all(key in sig.parameters for key in ensure_tuple(keywords))\n\n\ndef is_module_ver_at_least(module, version):\n    \"\"\"Determine if a module's version is at least equal to the given value.\n\n 
   Args:\n        module: imported module's name, e.g., `np` or `torch`.\n        version: required version, given as a tuple, e.g., `(1, 8, 0)`.\n    Returns:\n        `True` if module is the given version or newer.\n    \"\"\"\n    test_ver = \".\".join(map(str, version))\n    return module.__version__ != test_ver and version_leq(test_ver, module.__version__)\n\n\ndef sample_slices(data: NdarrayOrTensor, dim: int = 1, as_indices: bool = True, *slicevals: int) -> NdarrayOrTensor:\n    \"\"\"sample several slices of input numpy array or Tensor on specified `dim`.\n\n    Args:\n        data: input data to sample slices, can be numpy array or PyTorch Tensor.\n        dim: expected dimension index to sample slices, default to `1`.\n        as_indices: if `True`, `slicevals` arg will be treated as the expected indices of slice, like: `1, 3, 5`\n            means `data[..., [1, 3, 5], ...]`, if `False`, `slicevals` arg will be treated as args for `slice` func,\n            like: `1, None` means `data[..., [1:], ...]`, `1, 5` means `data[..., [1: 5], ...]`.\n        slicevals: indices of slices or start and end indices of expected slices, depends on `as_indices` flag.\n\n    \"\"\"\n    slices = [slice(None)] * len(data.shape)\n    slices[dim] = slicevals if as_indices else slice(*slicevals)  # type: ignore\n\n    return data[tuple(slices)]\n\n\ndef check_parent_dir(path: PathLike, create_dir: bool = True) -> None:\n    \"\"\"\n    Utility to check whether the parent directory of the `path` exists.\n\n    Args:\n        path: input path to check the parent directory.\n        create_dir: if True, when the parent directory doesn't exist, create the directory,\n            otherwise, raise exception.\n\n    \"\"\"\n    path = Path(path)\n    path_dir = path.parent\n    if not path_dir.exists():\n        if create_dir:\n            path_dir.mkdir(parents=True)\n        else:\n            raise ValueError(f\"the directory of specified path does not exist: 
`{path_dir}`.\")\n\n\ndef save_obj(\n    obj: object,\n    path: PathLike,\n    create_dir: bool = True,\n    atomic: bool = True,\n    func: Callable | None = None,\n    **kwargs: Any,\n) -> None:\n    \"\"\"\n    Save an object to file with specified path.\n    Support to serialize to a temporary file first, then move to final destination,\n    so that files are guaranteed to not be damaged if exception occurs.\n\n    Args:\n        obj: input object data to save.\n        path: target file path to save the input object.\n        create_dir: whether to create dictionary of the path if not existing, default to `True`.\n        atomic: if `True`, state is serialized to a temporary file first, then move to final destination.\n            so that files are guaranteed to not be damaged if exception occurs. default to `True`.\n        func: the function to save file, if None, default to `torch.save`.\n        kwargs: other args for the save `func` except for the checkpoint and filename.\n            default `func` is `torch.save()`, details of other args:\n            https://pytorch.org/docs/stable/generated/torch.save.html.\n\n    \"\"\"\n    path = Path(path)\n    check_parent_dir(path=path, create_dir=create_dir)\n    if path.exists():\n        # remove the existing file\n        os.remove(path)\n\n    if func is None:\n        func = torch.save\n\n    if not atomic:\n        func(obj=obj, f=path, **kwargs)\n        return\n    try:\n        # writing to a temporary directory and then using a nearly atomic rename operation\n        with tempfile.TemporaryDirectory() as tempdir:\n            temp_path: Path = Path(tempdir) / path.name\n            func(obj=obj, f=temp_path, **kwargs)\n            if temp_path.is_file():\n                shutil.move(str(temp_path), path)\n    except PermissionError:  # project-monai/monai issue #3613\n        pass\n\n\ndef label_union(x: list | np.ndarray) -> list:\n    \"\"\"\n    Compute the union of class IDs in label and generate 
a list to include all class IDs\n    Args:\n        x: a list of numbers (for example, class_IDs)\n\n    Returns\n        a list showing the union (the union the class IDs)\n    \"\"\"\n    return list(set.union(set(np.array(x).tolist())))\n\n\ndef prob2class(x: torch.Tensor, sigmoid: bool = False, threshold: float = 0.5, **kwargs: Any) -> torch.Tensor:\n    \"\"\"\n    Compute the lab from the probability of predicted feature maps\n\n    Args:\n        sigmoid: If the sigmoid function should be used.\n        threshold: threshold value to activate the sigmoid function.\n    \"\"\"\n    return torch.argmax(x, **kwargs) if not sigmoid else (x > threshold).int()\n\n\ndef path_to_uri(path: PathLike) -> str:\n    \"\"\"\n    Convert a file path to URI. if not absolute path, will convert to absolute path first.\n\n    Args:\n        path: input file path to convert, can be a string or `Path` object.\n\n    \"\"\"\n    return Path(path).absolute().as_uri()\n\n\ndef pprint_edges(val: Any, n_lines: int = 20) -> str:\n    \"\"\"\n    Pretty print the head and tail ``n_lines`` of ``val``, and omit the middle part if the part has more than 3 lines.\n\n    Returns: the formatted string.\n    \"\"\"\n    val_str = pprint.pformat(val).splitlines(True)\n    n_lines = max(n_lines, 1)\n    if len(val_str) > n_lines * 2 + 3:\n        hidden_n = len(val_str) - n_lines * 2\n        val_str = val_str[:n_lines] + [f\"\\n ... 
omitted {hidden_n} line(s)\\n\\n\"] + val_str[-n_lines:]\n    return \"\".join(val_str)\n\n\ndef check_key_duplicates(ordered_pairs: Sequence[tuple[Any, Any]]) -> dict[Any, Any]:\n    \"\"\"\n    Checks if there is a duplicated key in the sequence of `ordered_pairs`.\n    If there is - it will log a warning or raise ValueError\n    (if configured by environmental var `MONAI_FAIL_ON_DUPLICATE_CONFIG==1`)\n\n    Otherwise, it returns the dict made from this sequence.\n\n    Satisfies a format for an `object_pairs_hook` in `json.load`\n\n    Args:\n        ordered_pairs: sequence of (key, value)\n    \"\"\"\n    keys = set()\n    for k, _ in ordered_pairs:\n        if k in keys:\n            if os.environ.get(\"MONAI_FAIL_ON_DUPLICATE_CONFIG\", \"0\") == \"1\":\n                raise ValueError(f\"Duplicate key: `{k}`\")\n            else:\n                warnings.warn(f\"Duplicate key: `{k}`\")\n        else:\n            keys.add(k)\n    return dict(ordered_pairs)\n\n\nclass CheckKeyDuplicatesYamlLoader(SafeLoader):\n\n    def construct_mapping(self, node, deep=False):\n        mapping = set()\n        for key_node, _ in node.value:\n            key = self.construct_object(key_node, deep=deep)\n            if key in mapping:\n                if os.environ.get(\"MONAI_FAIL_ON_DUPLICATE_CONFIG\", \"0\") == \"1\":\n                    raise ValueError(f\"Duplicate key: `{key}`\")\n                else:\n                    warnings.warn(f\"Duplicate key: `{key}`\")\n            mapping.add(key)\n        return super().construct_mapping(node, deep)\n\n\nclass ConvertUnits:\n    \"\"\"\n    Convert the values from input unit to the target unit\n\n    Args:\n        input_unit: the unit of the input quantity\n        target_unit: the unit of the target quantity\n\n    \"\"\"\n\n    imperial_unit_of_length = {\"inch\": 0.0254, \"foot\": 0.3048, \"yard\": 0.9144, \"mile\": 1609.344}\n\n    unit_prefix = {\n        \"peta\": 15,\n        \"tera\": 12,\n        \"giga\": 
9,\n        \"mega\": 6,\n        \"kilo\": 3,\n        \"hecto\": 2,\n        \"deca\": 1,\n        \"deci\": -1,\n        \"centi\": -2,\n        \"milli\": -3,\n        \"micro\": -6,\n        \"nano\": -9,\n        \"pico\": -12,\n        \"femto\": -15,\n    }\n    base_units = [\"meter\", \"byte\", \"bit\"]\n\n    def __init__(self, input_unit: str, target_unit: str) -> None:\n        self.input_unit, input_base = self._get_valid_unit_and_base(input_unit)\n        self.target_unit, target_base = self._get_valid_unit_and_base(target_unit)\n        if input_base == target_base:\n            self.unit_base = input_base\n        else:\n            raise ValueError(\n                \"Both input and target units should be from the same quantity. \"\n                f\"Input quantity is {input_base} while target quantity is {target_base}\"\n            )\n        self.conversion_factor = self._calculate_conversion_factor()\n\n    def _get_valid_unit_and_base(self, unit):\n        unit = str(unit).lower()\n        if unit in self.imperial_unit_of_length:\n            return unit, \"meter\"\n        for base_unit in self.base_units:\n            if unit.endswith(base_unit):\n                return unit, base_unit\n        raise ValueError(f\"Currently, it only supports length conversion but `{unit}` is given.\")\n\n    def _get_unit_power(self, unit):\n        \"\"\"Calculate the power of the unit factor with respect to the base unit\"\"\"\n        if unit in self.imperial_unit_of_length:\n            return log10(self.imperial_unit_of_length[unit])\n\n        prefix = unit[: len(self.unit_base)]\n        if prefix == \"\":\n            return 1.0\n        return self.unit_prefix[prefix]\n\n    def _calculate_conversion_factor(self):\n        \"\"\"Calculate unit conversion factor with respect to the input unit\"\"\"\n        if self.input_unit == self.target_unit:\n            return 1.0\n        input_power = self._get_unit_power(self.input_unit)\n        
target_power = self._get_unit_power(self.target_unit)\n        return 10 ** (input_power - target_power)\n\n    def __call__(self, value: int | float) -> Any:\n        return float(value) * self.conversion_factor\n\n\ndef check_kwargs_exist_in_class_init(cls, kwargs):\n    \"\"\"\n    Check if the all keys in kwargs exist in the __init__ method of the class.\n\n    Args:\n        cls: the class to check.\n        kwargs: kwargs to examine.\n\n    Returns:\n        a boolean indicating if all keys exist.\n        a set of extra keys that are not used in the __init__.\n    \"\"\"\n    init_signature = inspect.signature(cls.__init__)\n    init_params = set(init_signature.parameters) - {\"self\"}  # Exclude 'self' from the parameter list\n    input_kwargs = set(kwargs)\n    extra_kwargs = input_kwargs - init_params\n\n    return extra_kwargs == set(), extra_kwargs\n\n\ndef run_cmd(cmd_list: list[str], **kwargs: Any) -> subprocess.CompletedProcess:\n    \"\"\"\n    Run a command by using ``subprocess.run`` with capture_output=True and stderr=subprocess.STDOUT\n    so that the raise exception will have that information. 
The argument `capture_output` can be set explicitly\n    if desired, but will be overriden with the debug status from the variable.\n\n    Args:\n        cmd_list: a list of strings describing the command to run.\n        kwargs: keyword arguments supported by the ``subprocess.run`` method.\n\n    Returns:\n        a CompletedProcess instance after the command completes.\n    \"\"\"\n    debug = MONAIEnvVars.debug()\n    # Always capture output when check=True so that error details are available\n    # in the CalledProcessError exception for debugging subprocess failures.\n    if kwargs.get(\"check\", False):\n        kwargs.setdefault(\"capture_output\", True)\n    else:\n        kwargs[\"capture_output\"] = kwargs.get(\"capture_output\", debug)\n\n    if kwargs.pop(\"run_cmd_verbose\", False):\n        import monai\n\n        monai.apps.utils.get_logger(\"monai.utils.run_cmd\").info(f\"{cmd_list}\")  # type: ignore[attr-defined]\n    try:\n        return subprocess.run(cmd_list, **kwargs)\n    except subprocess.CalledProcessError as e:\n        output = str(e.stdout.decode(errors=\"replace\")) if e.stdout else \"\"\n        errors = str(e.stderr.decode(errors=\"replace\")) if e.stderr else \"\"\n        raise RuntimeError(f\"subprocess call error {e.returncode}: {errors}, {output}\") from e\n\n\ndef is_sqrt(num: Sequence[int] | int) -> bool:\n    \"\"\"\n    Determine if the input is a square number or a squence of square numbers.\n    \"\"\"\n    num = ensure_tuple(num)\n    sqrt_num = [int(math.sqrt(_num)) for _num in num]\n    ret = [_i * _j for _i, _j in zip(sqrt_num, sqrt_num)]\n    return ensure_tuple(ret) == num\n\n\ndef unsqueeze_right(arr: NT, ndim: int) -> NT:\n    \"\"\"Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.\"\"\"\n    return arr[(...,) + (None,) * (ndim - arr.ndim)]\n\n\ndef unsqueeze_left(arr: NT, ndim: int) -> NT:\n    \"\"\"Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.\"\"\"\n  
  return arr[(None,) * (ndim - arr.ndim)]\n\n\ndef flatten_dict(metrics: dict[str, Any]) -> dict[str, Any]:\n    \"\"\"\n    Flatten the nested dictionary to a flat dictionary.\n    \"\"\"\n    result = {}\n    for key, value in metrics.items():\n        if isinstance(value, dict):\n            result.update(flatten_dict(value))\n        else:\n            result[key] = value\n    return result\n"
  },
  {
    "path": "monai/utils/module.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport enum\nimport functools\nimport os\nimport pdb\nimport re\nimport sys\nimport warnings\nfrom collections.abc import Callable, Collection, Hashable, Iterable, Mapping\nfrom functools import partial, wraps\nfrom importlib import import_module\nfrom pkgutil import walk_packages\nfrom pydoc import locate\nfrom re import match\nfrom types import FunctionType, ModuleType\nfrom typing import Any, cast\n\nimport torch\n\n# bundle config system flags\n# set MONAI_EVAL_EXPR=1 to use 'eval', default value: run_eval=True\nrun_eval = os.environ.get(\"MONAI_EVAL_EXPR\", \"1\") != \"0\"\n# set MONAI_DEBUG_CONFIG=1 to run in debug mode, default value: run_debug=False\nrun_debug = os.environ.get(\"MONAI_DEBUG_CONFIG\", \"0\") != \"0\"\n# set MONAI_ALLOW_MISSING_REFERENCE=1 to allow missing references, default value: allow_missing_reference=False\nallow_missing_reference = os.environ.get(\"MONAI_ALLOW_MISSING_REFERENCE\", \"0\") != \"0\"\n\nOPTIONAL_IMPORT_MSG_FMT = \"{}\"\n\n__all__ = [\n    \"InvalidPyTorchVersionError\",\n    \"OptionalImportError\",\n    \"exact_version\",\n    \"damerau_levenshtein_distance\",\n    \"look_up_option\",\n    \"min_version\",\n    \"optional_import\",\n    \"require_pkg\",\n    \"instantiate\",\n    \"get_full_type_name\",\n    \"get_package_version\",\n    \"get_torch_version_tuple\",\n    \"version_leq\",\n   
 \"version_geq\",\n    \"pytorch_after\",\n]\n\n\ndef look_up_option(\n    opt_str: Hashable,\n    supported: Collection | enum.EnumMeta,\n    default: Any = \"no_default\",\n    print_all_options: bool = True,\n) -> Any:\n    \"\"\"\n    Look up the option in the supported collection and return the matched item.\n    Raise a value error possibly with a guess of the closest match.\n\n    Args:\n        opt_str: The option string or Enum to look up.\n        supported: The collection of supported options, it can be list, tuple, set, dict, or Enum.\n        default: If it is given, this method will return `default` when `opt_str` is not found,\n            instead of raising a `ValueError`. Otherwise, it defaults to `\"no_default\"`,\n            so that the method may raise a `ValueError`.\n        print_all_options: whether to print all available options when `opt_str` is not found. Defaults to True\n\n    Examples:\n\n    .. code-block:: python\n\n        from enum import Enum\n        from monai.utils import look_up_option\n        class Color(Enum):\n            RED = \"red\"\n            BLUE = \"blue\"\n        look_up_option(\"red\", Color)  # <Color.RED: 'red'>\n        look_up_option(Color.RED, Color)  # <Color.RED: 'red'>\n        look_up_option(\"read\", Color)\n        # ValueError: By 'read', did you mean 'red'?\n        # 'read' is not a valid option.\n        # Available options are {'blue', 'red'}.\n        look_up_option(\"red\", {\"red\", \"blue\"})  # \"red\"\n\n    Adapted from https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/utilities/util_common.py#L249\n    \"\"\"\n    if not isinstance(opt_str, Hashable):\n        raise ValueError(f\"Unrecognized option type: {type(opt_str)}:{opt_str}.\")\n    if isinstance(opt_str, str):\n        opt_str = opt_str.strip()\n    if isinstance(supported, enum.EnumMeta):\n        if isinstance(opt_str, str) and opt_str in {item.value for item in supported}:  # type: ignore\n            # such as: 
\"example\" in MyEnum\n            return supported(opt_str)\n        if isinstance(opt_str, enum.Enum) and opt_str in supported:\n            # such as: MyEnum.EXAMPLE in MyEnum\n            return opt_str\n    elif isinstance(supported, Mapping) and opt_str in supported:\n        # such as: MyDict[key]\n        return supported[opt_str]\n    elif isinstance(supported, Collection) and opt_str in supported:\n        return opt_str\n\n    if default != \"no_default\":\n        return default\n\n    # find a close match\n    set_to_check: set\n    if isinstance(supported, enum.EnumMeta):\n        set_to_check = {item.value for item in supported}  # type: ignore\n    else:\n        set_to_check = set(supported) if supported is not None else set()\n    if not set_to_check:\n        raise ValueError(f\"No options available: {supported}.\")\n    edit_dists = {}\n    opt_str = f\"{opt_str}\"\n    for key in set_to_check:\n        edit_dist = damerau_levenshtein_distance(f\"{key}\", opt_str)\n        if edit_dist <= 3:\n            edit_dists[key] = edit_dist\n\n    supported_msg = f\"Available options are {set_to_check}.\\n\" if print_all_options else \"\"\n    if edit_dists:\n        guess_at_spelling = min(edit_dists, key=edit_dists.get)  # type: ignore\n        raise ValueError(\n            f\"By '{opt_str}', did you mean '{guess_at_spelling}'?\\n\"\n            + f\"'{opt_str}' is not a valid value.\\n\"\n            + supported_msg\n        )\n    raise ValueError(f\"Unsupported option '{opt_str}', \" + supported_msg)\n\n\ndef damerau_levenshtein_distance(s1: str, s2: str) -> int:\n    \"\"\"\n    Calculates the Damerau–Levenshtein distance between two strings for spelling correction.\n    https://en.wikipedia.org/wiki/Damerau–Levenshtein_distance\n    \"\"\"\n    if s1 == s2:\n        return 0\n    string_1_length = len(s1)\n    string_2_length = len(s2)\n    if not s1:\n        return string_2_length\n    if not s2:\n        return string_1_length\n    d = {(i, 
-1): i + 1 for i in range(-1, string_1_length + 1)}\n    for j in range(-1, string_2_length + 1):\n        d[(-1, j)] = j + 1\n\n    for i, s1i in enumerate(s1):\n        for j, s2j in enumerate(s2):\n            cost = 0 if s1i == s2j else 1\n            d[(i, j)] = min(\n                d[(i - 1, j)] + 1, d[(i, j - 1)] + 1, d[(i - 1, j - 1)] + cost  # deletion  # insertion  # substitution\n            )\n            if i and j and s1i == s2[j - 1] and s1[i - 1] == s2j:\n                d[(i, j)] = min(d[(i, j)], d[i - 2, j - 2] + cost)  # transposition\n\n    return d[string_1_length - 1, string_2_length - 1]\n\n\ndef load_submodules(\n    basemod: ModuleType, load_all: bool = True, exclude_pattern: str = \"(.*[tT]est.*)|(_.*)\"\n) -> tuple[list[ModuleType], list[str]]:\n    \"\"\"\n    Traverse the source of the module structure starting with module `basemod`, loading all packages plus all files if\n    `load_all` is True, excluding anything whose name matches `exclude_pattern`.\n    \"\"\"\n    submodules = []\n    err_mod: list[str] = []\n    for importer, name, is_pkg in walk_packages(\n        basemod.__path__, prefix=basemod.__name__ + \".\", onerror=err_mod.append\n    ):\n        if (is_pkg or load_all) and name not in sys.modules and match(exclude_pattern, name) is None:\n            try:\n                mod = import_module(name)\n                mod_spec = importer.find_spec(name)  # type: ignore\n                if mod_spec and mod_spec.loader:\n                    loader = mod_spec.loader\n                    loader.exec_module(mod)\n                    submodules.append(mod)\n            except OptionalImportError:\n                pass  # could not import the optional deps., they are ignored\n            except ImportError as e:\n                msg = (\n                    \"\\nMultiple versions of MONAI may have been installed?\\n\"\n                    \"Please see the installation guide: 
https://monai.readthedocs.io/en/stable/installation.html\\n\"\n                )  # issue project-monai/monai#5193\n                raise type(e)(f\"{e}\\n{msg}\").with_traceback(e.__traceback__) from e  # raise with modified message\n\n    return submodules, err_mod\n\n\ndef instantiate(__path: str, __mode: str, **kwargs: Any) -> Any:\n    \"\"\"\n    Create an object instance or call a callable object from a class or function represented by ``_path``.\n    `kwargs` will be part of the input arguments to the class constructor or function.\n    The target component must be a class or a function, if not, return the component directly.\n\n    Args:\n        __path: if a string is provided, it's interpreted as the full path of the target class or function component.\n            If a callable is provided, ``__path(**kwargs)`` will be invoked and returned for ``__mode=\"default\"``.\n            For ``__mode=\"callable\"``, the callable will be returned as ``__path`` or, if ``kwargs`` are provided,\n            as ``functools.partial(__path, **kwargs)`` for future invoking.\n\n        __mode: the operating mode for invoking the (callable) ``component`` represented by ``__path``:\n\n            - ``\"default\"``: returns ``component(**kwargs)``\n            - ``\"callable\"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``\n            - ``\"debug\"``: returns ``pdb.runcall(component, **kwargs)``\n\n        kwargs: keyword arguments to the callable represented by ``__path``.\n\n    \"\"\"\n    from monai.utils.enums import CompInitMode\n\n    component = locate(__path) if isinstance(__path, str) else __path\n    if component is None:\n        raise ModuleNotFoundError(f\"Cannot locate class or function path: '{__path}'.\")\n    m = look_up_option(__mode, CompInitMode)\n    try:\n        if kwargs.pop(\"_debug_\", False) or run_debug:\n            warnings.warn(\n                f\"\\n\\npdb: instantiating 
component={component}, mode={m}\\n\"\n                f\"See also Debugger commands documentation: https://docs.python.org/3/library/pdb.html\\n\"\n            )\n            breakpoint()\n        if not callable(component):\n            warnings.warn(f\"Component {component} is not callable when mode={m}.\")\n            return component\n        if m == CompInitMode.DEFAULT:\n            return component(**kwargs)\n        if m == CompInitMode.CALLABLE:\n            return partial(component, **kwargs) if kwargs else component\n        if m == CompInitMode.DEBUG:\n            warnings.warn(\n                f\"\\n\\npdb: instantiating component={component}, mode={m}\\n\"\n                f\"See also Debugger commands documentation: https://docs.python.org/3/library/pdb.html\\n\"\n            )\n            return pdb.runcall(component, **kwargs)\n    except Exception as e:\n        raise RuntimeError(\n            f\"Failed to instantiate component '{__path}' with keywords: {','.join(kwargs.keys())}\"\n            f\"\\n set '_mode_={CompInitMode.DEBUG}' to enter the debugging mode.\"\n        ) from e\n\n    warnings.warn(f\"Component to instantiate must represent a valid class or function, but got {__path}.\")\n    return component\n\n\ndef get_full_type_name(typeobj):\n    \"\"\"\n    Utility to get the full path name of a class or object type.\n\n    \"\"\"\n    module = typeobj.__module__\n    if module is None or module == str.__class__.__module__:\n        return typeobj.__name__  # Avoid reporting __builtin__\n    return module + \".\" + typeobj.__name__\n\n\ndef min_version(the_module: Any, min_version_str: str = \"\", *_args: Any) -> bool:\n    \"\"\"\n    Convert version strings into tuples of int and compare them.\n\n    Returns True if the module's version is greater or equal to the 'min_version'.\n    When min_version_str is not provided, it always returns True.\n    \"\"\"\n    if not min_version_str or not hasattr(the_module, \"__version__\"):\n    
    return True  # always valid version\n\n    mod_version = tuple(int(x) for x in the_module.__version__.split(\".\")[:2])\n    required = tuple(int(x) for x in min_version_str.split(\".\")[:2])\n    return mod_version >= required\n\n\ndef exact_version(the_module: Any, version_str: str = \"\", *_args: Any) -> bool:\n    \"\"\"\n    Returns True if the module's __version__ matches version_str\n    \"\"\"\n    if not hasattr(the_module, \"__version__\"):\n        warnings.warn(f\"{the_module} has no attribute __version__ in exact_version check.\")\n        return False\n    return bool(the_module.__version__ == version_str)\n\n\nclass InvalidPyTorchVersionError(Exception):\n    \"\"\"\n    Raised when called function or method requires a more recent\n    PyTorch version than that installed.\n    \"\"\"\n\n    def __init__(self, required_version, name):\n        message = f\"{name} requires PyTorch version {required_version} or later\"\n        super().__init__(message)\n\n\nclass OptionalImportError(ImportError):\n    \"\"\"\n    Could not import APIs from an optional dependency.\n    \"\"\"\n\n\ndef optional_import(\n    module: str,\n    version: str = \"\",\n    version_checker: Callable[..., bool] = min_version,\n    name: str = \"\",\n    descriptor: str = OPTIONAL_IMPORT_MSG_FMT,\n    version_args: Any = None,\n    allow_namespace_pkg: bool = False,\n    as_type: str = \"default\",\n) -> tuple[Any, bool]:\n    \"\"\"\n    Imports an optional module specified by `module` string.\n    Any importing related exceptions will be stored, and exceptions raise lazily\n    when attempting to use the failed-to-import module.\n\n    Args:\n        module: name of the module to be imported.\n        version: version string used by the version_checker.\n        version_checker: a callable to check the module version, Defaults to monai.utils.min_version.\n        name: a non-module attribute (such as method/class) to import from the imported module.\n        descriptor: a 
format string for the final error message when using a not imported module.\n        version_args: additional parameters to the version checker.\n        allow_namespace_pkg: whether importing a namespace package is allowed. Defaults to False.\n        as_type: there are cases where the optionally imported object is used as\n            a base class, or a decorator, the exceptions should raise accordingly. The current supported values\n            are \"default\" (call once to raise), \"decorator\" (call the constructor and the second call to raise),\n            and anything else will return a lazy class that can be used as a base class (call the constructor to raise).\n\n    Returns:\n        The imported module and a boolean flag indicating whether the import is successful.\n\n    Examples::\n\n        >>> torch, flag = optional_import('torch', '1.1')\n        >>> print(torch, flag)\n        <module 'torch' from 'python/lib/python3.6/site-packages/torch/__init__.py'> True\n\n        >>> the_module, flag = optional_import('unknown_module')\n        >>> print(flag)\n        False\n        >>> the_module.method  # trying to access a module which is not imported\n        OptionalImportError: import unknown_module (No module named 'unknown_module').\n\n        >>> torch, flag = optional_import('torch', '42', exact_version)\n        >>> torch.nn  # trying to access a module for which there isn't a proper version imported\n        OptionalImportError: import torch (requires version '42' by 'exact_version').\n\n        >>> conv, flag = optional_import('torch.nn.functional', '1.0', name='conv1d')\n        >>> print(conv)\n        <built-in method conv1d of type object at 0x11a49eac0>\n\n        >>> conv, flag = optional_import('torch.nn.functional', '42', name='conv1d')\n        >>> conv()  # trying to use a function from the not successfully imported module (due to unmatched version)\n        OptionalImportError: from torch.nn.functional import conv1d (requires version 
'42' by 'min_version').\n    \"\"\"\n\n    tb = None\n    exception_str = \"\"\n    if name:\n        actual_cmd = f\"from {module} import {name}\"\n    else:\n        actual_cmd = f\"import {module}\"\n    try:\n        pkg = __import__(module)  # top level module\n        the_module = import_module(module)\n        if not allow_namespace_pkg:\n            is_namespace = getattr(the_module, \"__file__\", None) is None and hasattr(the_module, \"__path__\")\n            if is_namespace:\n                raise AssertionError\n        if name:  # user specified to load class/function/... from the module\n            the_module = getattr(the_module, name)\n    except Exception as import_exception:  # any exceptions during import\n        tb = import_exception.__traceback__\n        exception_str = f\"{import_exception}\"\n    else:  # found the module\n        if version_args and version_checker(pkg, f\"{version}\", version_args):\n            return the_module, True\n        if not version_args and version_checker(pkg, f\"{version}\"):\n            return the_module, True\n\n    # preparing lazy error message\n    msg = descriptor.format(actual_cmd)\n    if version and tb is None:  # a pure version issue\n        msg += f\" (requires '{module} {version}' by '{version_checker.__name__}')\"\n    if exception_str:\n        msg += f\" ({exception_str})\"\n\n    class _LazyRaise:\n\n        def __init__(self, *_args, **_kwargs):\n            _default_msg = (\n                f\"{msg}.\"\n                + \"\\n\\nFor details about installing the optional dependencies, please visit:\"\n                + \"\\n    https://monai.readthedocs.io/en/latest/installation.html#installing-the-recommended-dependencies\"\n            )\n            if tb is None:\n                self._exception = OptionalImportError(_default_msg)\n            else:\n                self._exception = OptionalImportError(_default_msg).with_traceback(tb)\n\n        def __getattr__(self, name):\n          
  \"\"\"\n            Raises:\n                OptionalImportError: When you call this method.\n            \"\"\"\n            raise self._exception\n\n        def __call__(self, *_args, **_kwargs):\n            \"\"\"\n            Raises:\n                OptionalImportError: When you call this method.\n            \"\"\"\n            raise self._exception\n\n        def __getitem__(self, item):\n            raise self._exception\n\n        def __iter__(self):\n            raise self._exception\n\n    if as_type == \"default\":\n        return _LazyRaise(), False\n\n    class _LazyCls(_LazyRaise):\n\n        def __init__(self, *_args, **kwargs):\n            super().__init__()\n            if not as_type.startswith(\"decorator\"):\n                raise self._exception\n\n    return _LazyCls, False\n\n\ndef require_pkg(\n    pkg_name: str, version: str = \"\", version_checker: Callable[..., bool] = min_version, raise_error: bool = True\n) -> Callable:\n    \"\"\"\n    Decorator function to check the required package installation.\n\n    Args:\n        pkg_name: required package name, like: \"itk\", \"nibabel\", etc.\n        version: required version string used by the version_checker.\n        version_checker: a callable to check the module version, defaults to `monai.utils.min_version`.\n        raise_error: if True, raise `OptionalImportError` error if the required package is not installed\n            or the version doesn't match requirement, if False, print the error in a warning.\n\n    \"\"\"\n\n    def _decorator(obj):\n        is_func = isinstance(obj, FunctionType)\n        call_obj = obj if is_func else obj.__init__\n\n        @wraps(call_obj)\n        def _wrapper(*args, **kwargs):\n            _, has = optional_import(module=pkg_name, version=version, version_checker=version_checker)\n            if not has:\n                err_msg = f\"required package `{pkg_name}` is not installed or the version doesn't match requirement.\"\n                if 
raise_error:\n                    raise OptionalImportError(err_msg)\n                else:\n                    warnings.warn(err_msg)\n\n            return call_obj(*args, **kwargs)\n\n        if is_func:\n            return _wrapper\n        obj.__init__ = _wrapper\n        return obj\n\n    return _decorator\n\n\ndef get_package_version(dep_name, default=\"NOT INSTALLED or UNKNOWN VERSION.\"):\n    \"\"\"\n    Try to load package and get version. If not found, return `default`.\n    \"\"\"\n    dep, has_dep = optional_import(dep_name)\n    if has_dep and hasattr(dep, \"__version__\"):\n        return dep.__version__\n    return default\n\n\n@functools.lru_cache(None)\ndef get_torch_version_tuple():\n    \"\"\"\n    Returns:\n        tuple of ints represents the pytorch major/minor version.\n    \"\"\"\n    return tuple(int(x) for x in torch.__version__.split(\".\")[:2])\n\n\ndef parse_version_strs(lhs: str, rhs: str) -> tuple[Iterable[int | str], Iterable[int | str]]:\n    \"\"\"\n    Parse the version strings.\n    \"\"\"\n\n    def _try_cast(val: str) -> int | str:\n        val = val.strip()\n        try:\n            m = match(\"(\\\\d+)(.*)\", val)\n            if m is not None:\n                val = m.groups()[0]\n                return int(val)\n            return val\n        except ValueError:\n            return val\n\n    # remove git version suffixes if present\n    lhs = lhs.split(\"+\", 1)[0]\n    rhs = rhs.split(\"+\", 1)[0]\n\n    # parse the version strings in this basic way without `packaging` package\n    lhs_ = map(_try_cast, lhs.split(\".\"))\n    rhs_ = map(_try_cast, rhs.split(\".\"))\n    return lhs_, rhs_\n\n\ndef version_leq(lhs: str, rhs: str) -> bool:\n    \"\"\"\n    Returns True if version `lhs` is earlier or equal to `rhs`.\n\n    Args:\n        lhs: version name to compare with `rhs`, return True if earlier or equal to `rhs`.\n        rhs: version name to compare with `lhs`, return True if later or equal to `lhs`.\n\n    
\"\"\"\n\n    lhs, rhs = str(lhs), str(rhs)\n    pkging, has_ver = optional_import(\"packaging.version\")\n    if has_ver:\n        try:\n            return cast(bool, pkging.Version(lhs) <= pkging.Version(rhs))\n        except pkging.InvalidVersion:\n            return True\n\n    lhs_, rhs_ = parse_version_strs(lhs, rhs)\n    for l, r in zip(lhs_, rhs_):\n        if l != r:\n            if isinstance(l, int) and isinstance(r, int):\n                return l < r\n            return f\"{l}\" < f\"{r}\"\n\n    return True\n\n\ndef version_geq(lhs: str, rhs: str) -> bool:\n    \"\"\"\n    Returns True if version `lhs` is later or equal to `rhs`.\n\n    Args:\n        lhs: version name to compare with `rhs`, return True if later or equal to `rhs`.\n        rhs: version name to compare with `lhs`, return True if earlier or equal to `lhs`.\n\n    \"\"\"\n    lhs, rhs = str(lhs), str(rhs)\n    pkging, has_ver = optional_import(\"packaging.version\")\n\n    if has_ver:\n        try:\n            return cast(bool, pkging.Version(lhs) >= pkging.Version(rhs))\n        except pkging.InvalidVersion:\n            return True\n\n    lhs_, rhs_ = parse_version_strs(lhs, rhs)\n    for l, r in zip(lhs_, rhs_):\n        if l != r:\n            if isinstance(l, int) and isinstance(r, int):\n                return l > r\n            return f\"{l}\" > f\"{r}\"\n\n    return True\n\n\n@functools.lru_cache(None)\ndef pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: str | None = None) -> bool:\n    \"\"\"\n    Compute whether the current pytorch version is after or equal to the specified version.\n    The current system pytorch version is determined by `torch.__version__` or\n    via system environment variable `PYTORCH_VER`.\n\n    Args:\n        major: major version number to be compared with\n        minor: minor version number to be compared with\n        patch: patch version number to be compared with\n        current_ver_string: if None, `torch.__version__` 
will be used.\n\n    Returns:\n        True if the current pytorch version is greater than or equal to the specified version.\n    \"\"\"\n\n    try:\n        if current_ver_string is None:\n            _env_var = os.environ.get(\"PYTORCH_VER\", \"\")\n            current_ver_string = _env_var if _env_var else torch.__version__\n        ver, has_ver = optional_import(\"packaging.version\", name=\"parse\")\n        if has_ver:\n            return ver(\".\".join((f\"{major}\", f\"{minor}\", f\"{patch}\"))) <= ver(f\"{current_ver_string}\")  # type: ignore\n        parts = f\"{current_ver_string}\".split(\"+\", 1)[0].split(\".\", 3)\n        while len(parts) < 3:\n            parts += [\"0\"]\n        c_major, c_minor, c_patch = parts[:3]\n    except (AttributeError, ValueError, TypeError):\n        c_major, c_minor = get_torch_version_tuple()\n        c_patch = \"0\"\n    c_mn = int(c_major), int(c_minor)\n    mn = int(major), int(minor)\n    if c_mn != mn:\n        return c_mn > mn\n    is_prerelease = (\"a\" in f\"{c_patch}\".lower()) or (\"rc\" in f\"{c_patch}\".lower())\n    c_p = 0\n    try:\n        p_reg = re.search(r\"\\d+\", f\"{c_patch}\")\n        if p_reg:\n            c_p = int(p_reg.group())\n    except (AttributeError, TypeError, ValueError):\n        is_prerelease = True\n    patch = int(patch)\n    if c_p != patch:\n        return c_p > patch\n    if is_prerelease:\n        return False\n    return True\n\n\n@functools.lru_cache(None)\ndef compute_capabilities_after(major: int, minor: int = 0, current_ver_string: str | None = None) -> bool:\n    \"\"\"\n    Compute whether the current system GPU CUDA compute capability is after or equal to the specified version.\n    The current system GPU CUDA compute capability is determined by the first GPU in the system.\n    The compared version is a string in the form of \"major.minor\".\n\n    Args:\n        major: major version number to be compared with.\n        minor: minor version number to be compared 
with. Defaults to 0.\n        current_ver_string: if None, the current system GPU CUDA compute capability will be used.\n\n    Returns:\n        True if the current system GPU CUDA compute capability is greater than or equal to the specified version.\n    \"\"\"\n    if current_ver_string is None:\n        cuda_available = torch.cuda.is_available()\n        pynvml, has_pynvml = optional_import(\"pynvml\")\n        if not has_pynvml:  # assuming that the user has Ampere and later GPU\n            return True\n        if not cuda_available:\n            return False\n        else:\n            pynvml.nvmlInit()\n            handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # get the first GPU\n            major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle)\n            current_ver_string = f\"{major_c}.{minor_c}\"\n            pynvml.nvmlShutdown()\n\n    ver, has_ver = optional_import(\"packaging.version\", name=\"parse\")\n    if has_ver:\n        return ver(\".\".join((f\"{major}\", f\"{minor}\"))) <= ver(f\"{current_ver_string}\")  # type: ignore\n    parts = f\"{current_ver_string}\".split(\"+\", 1)[0].split(\".\", 2)\n    while len(parts) < 2:\n        parts += [\"0\"]\n    c_major, c_minor = parts[:2]\n    c_mn = int(c_major), int(c_minor)\n    mn = int(major), int(minor)\n    return c_mn > mn\n"
  },
  {
    "path": "monai/utils/nvtx.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nDecorators and context managers for NVIDIA Tools Extension to profile MONAI components\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections import defaultdict\nfrom functools import wraps\nfrom typing import Any\n\nfrom torch.autograd import Function\nfrom torch.nn import Module\nfrom torch.optim import Optimizer\nfrom torch.utils.data import Dataset\n\nfrom monai.utils import ensure_tuple, optional_import\n\n_nvtx, _ = optional_import(\"torch._C._nvtx\", descriptor=\"NVTX is not installed. 
Are you sure you have a CUDA build?\")\n\n__all__ = [\"Range\"]\n\n\nclass Range:\n    \"\"\"\n    A decorator and context manager for NVIDIA Tools Extension (NVTX) Range for profiling.\n    When used as a decorator it encloses a specific method of the object with an NVTX Range.\n    When used as a context manager, it encloses the runtime context (created by with statement) with an NVTX Range.\n\n    Args:\n        name: the name to be associated to the range\n        methods: (only when used as decorator) the name of a method (or a list of the name of the methods)\n            to be wrapped by NVTX range.\n            If None (default), the method(s) will be inferred based on the object's type for various MONAI components,\n            such as Networks, Losses, Functions, Transforms, and Datasets.\n            Otherwise, it look up predefined methods: \"forward\", \"__call__\", \"__next__\", \"__getitem__\"\n        append_method_name: if append the name of the methods to be decorated to the range's name\n            If None (default), it appends the method's name only if we are annotating more than one method.\n        recursive: if set to True, it will recursively annotate every individual module in a list\n            or in a chain of modules (chained using Compose). Default to False.\n\n    \"\"\"\n\n    name_counter: dict = defaultdict(int)\n\n    def __init__(\n        self,\n        name: str | None = None,\n        methods: str | tuple[str, ...] 
| None = None,\n        append_method_name: bool | None = None,\n        recursive: bool = False,\n    ) -> None:\n        self.name = name\n        self.methods = methods\n        self.append_method_name = append_method_name\n        self.recursive = recursive\n\n    def __call__(self, obj: Any) -> Any:\n        if self.recursive is True:\n            if isinstance(obj, (list, tuple)):\n                return type(obj)(Range(recursive=True)(t) for t in obj)\n\n            from monai.transforms.compose import Compose\n\n            if isinstance(obj, Compose):\n                obj.transforms = Range(recursive=True)(obj.transforms)\n\n            self.recursive = False\n\n        # Define the name to be associated to the range if not provided\n        if self.name is None:\n            name = type(obj).__name__\n            # If CuCIM or TorchVision transform wrappers are being used,\n            # append the underlying transform to the name for more clarity\n            if \"CuCIM\" in name or \"TorchVision\" in name:\n                name = f\"{name}_{obj.name}\"\n            self.name_counter[name] += 1\n            if self.name_counter[name] > 1:\n                self.name = f\"{name}_{self.name_counter[name]}\"\n            else:\n                self.name = name\n\n        # Define the methods to be wrapped if not provided\n        if self.methods is None:\n            self.methods = self._get_method(obj)\n        else:\n            self.methods = ensure_tuple(self.methods)\n\n        # Check if to append method's name to the range's name\n        if self.append_method_name is None:\n            if len(self.methods) > 1:\n                self.append_method_name = True\n            else:\n                self.append_method_name = False\n\n        # Decorate the methods\n        for method in self.methods:\n            self._decorate_method(obj, method, self.append_method_name)\n\n        return obj\n\n    def _decorate_method(self, obj, method, 
append_method_name):\n        # Append the method's name to the range's name\n        name = f\"{self.name}.{method}\" if append_method_name else self.name\n\n        # Get the class for special functions\n        if method.startswith(\"__\"):\n            owner = type(obj)\n        else:\n            owner = obj\n\n        # Get the method to be wrapped\n        _temp_func = getattr(owner, method)\n\n        # Wrap the method with NVTX range (range push/pop)\n        @wraps(_temp_func)\n        def range_wrapper(*args, **kwargs):\n            _nvtx.rangePushA(name)\n            output = _temp_func(*args, **kwargs)\n            _nvtx.rangePop()\n            return output\n\n        # Replace the method with the wrapped version\n        if method.startswith(\"__\"):\n            # If it is a special method, it requires special attention\n            class NVTXRangeDecoratedClass(owner):  # type: ignore\n                ...\n\n            setattr(NVTXRangeDecoratedClass, method, range_wrapper)\n            obj.__class__ = NVTXRangeDecoratedClass\n\n        else:\n            setattr(owner, method, range_wrapper)\n\n    def _get_method(self, obj: Any) -> tuple:\n        if isinstance(obj, Module):\n            method_list = [\"forward\"]\n        elif isinstance(obj, Optimizer):\n            method_list = [\"step\"]\n        elif isinstance(obj, Function):\n            method_list = [\"forward\", \"backward\"]\n        elif isinstance(obj, Dataset):\n            method_list = [\"__getitem__\"]\n        else:\n            default_methods = [\"forward\", \"__call__\", \"__next__\", \"__getitem__\"]\n            method_list = []\n            for method in default_methods:\n                if hasattr(obj, method):\n                    method_list.append(method)\n            if len(method_list) < 1:\n                raise ValueError(\n                    f\"The method to be wrapped for this object [{type(obj)}] is not recognized.\"\n                    \"The name of the 
method should be provided or the object should have one of these methods:\"\n                    f\"{default_methods}\"\n                )\n        return ensure_tuple(method_list)\n\n    def __enter__(self):\n        if self.name is None:\n            # Number the range with class variable counter to avoid duplicate names.\n            self.name_counter[\"context\"] += 1\n            self.name = f\"context_{self.name_counter['context']}\"\n\n        _nvtx.rangePushA(self.name)\n\n    def __exit__(self, type, value, traceback):\n        _nvtx.rangePop()\n"
  },
  {
    "path": "monai/utils/ordering.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport numpy as np\n\nfrom monai.utils.enums import OrderingTransformations, OrderingType\n\n\nclass Ordering:\n    \"\"\"\n    Ordering class that projects a 2D or 3D image into a 1D sequence. It also allows the image to be transformed with\n    one of the following transformations:\n    Reflection (see np.flip for more details).\n    Transposition (see np.transpose for more details).\n    90-degree rotation (see np.rot90 for more details).\n\n    The transformations are applied in the order specified by the transformation_order parameter.\n\n    Args:\n        ordering_type: The ordering type. One of the following:\n            - 'raster_scan': The image is projected into a 1D sequence by scanning the image from left to right and from\n            top to bottom. 
Also called a row major ordering.\n            - 's_curve': The image is projected into a 1D sequence by scanning the image in a circular snake like\n            pattern from top left towards right gowing in a spiral towards the center.\n            - random': The image is projected into a 1D sequence by randomly shuffling the image.\n        spatial_dims: The number of spatial dimensions of the image.\n        dimensions: The dimensions of the image.\n        reflected_spatial_dims: A tuple of booleans indicating whether to reflect the image along each spatial dimension.\n        transpositions_axes: A tuple of tuples indicating the axes to transpose the image along.\n        rot90_axes: A tuple of tuples indicating the axes to rotate the image along.\n        transformation_order: The order in which to apply the transformations.\n    \"\"\"\n\n    def __init__(\n        self,\n        ordering_type: str,\n        spatial_dims: int,\n        dimensions: tuple[int, int, int] | tuple[int, int, int, int],\n        reflected_spatial_dims: tuple[bool, bool] | None = None,\n        transpositions_axes: tuple[tuple[int, int], ...] | tuple[tuple[int, int, int], ...] | None = None,\n        rot90_axes: tuple[tuple[int, int], ...] | None = None,\n        transformation_order: tuple[str, ...] 
= (\n            OrderingTransformations.TRANSPOSE.value,\n            OrderingTransformations.ROTATE_90.value,\n            OrderingTransformations.REFLECT.value,\n        ),\n    ) -> None:\n        super().__init__()\n        self.ordering_type = ordering_type\n\n        if self.ordering_type not in list(OrderingType):\n            raise ValueError(\n                f\"ordering_type must be one of the following {list(OrderingType)}, but got {self.ordering_type}.\"\n            )\n\n        self.spatial_dims = spatial_dims\n        self.dimensions = dimensions\n\n        if len(dimensions) != self.spatial_dims + 1:\n            raise ValueError(f\"dimensions must be of length {self.spatial_dims + 1}, but got {len(dimensions)}.\")\n\n        self.reflected_spatial_dims = reflected_spatial_dims\n        self.transpositions_axes = transpositions_axes\n        self.rot90_axes = rot90_axes\n        if len(set(transformation_order)) != len(transformation_order):\n            raise ValueError(f\"No duplicates are allowed. 
Received {transformation_order}.\")\n\n        for transformation in transformation_order:\n            if transformation not in list(OrderingTransformations):\n                raise ValueError(\n                    f\"Valid transformations are {list(OrderingTransformations)} but received {transformation}.\"\n                )\n        self.transformation_order = transformation_order\n\n        self.template = self._create_template()\n        self._sequence_ordering = self._create_ordering()\n        self._revert_sequence_ordering = np.argsort(self._sequence_ordering)\n\n    def __call__(self, x: np.ndarray) -> np.ndarray:\n        x = x[self._sequence_ordering]\n\n        return x\n\n    def get_sequence_ordering(self) -> np.ndarray:\n        return self._sequence_ordering\n\n    def get_revert_sequence_ordering(self) -> np.ndarray:\n        return self._revert_sequence_ordering\n\n    def _create_ordering(self) -> np.ndarray:\n        self.template = self._transform_template()\n        order = self._order_template(template=self.template)\n\n        return order\n\n    def _create_template(self) -> np.ndarray:\n        spatial_dimensions = self.dimensions[1:]\n        template = np.arange(np.prod(spatial_dimensions)).reshape(*spatial_dimensions)\n\n        return template\n\n    def _transform_template(self) -> np.ndarray:\n        for transformation in self.transformation_order:\n            if transformation == OrderingTransformations.TRANSPOSE.value:\n                self.template = self._transpose_template(template=self.template)\n            elif transformation == OrderingTransformations.ROTATE_90.value:\n                self.template = self._rot90_template(template=self.template)\n            elif transformation == OrderingTransformations.REFLECT.value:\n                self.template = self._flip_template(template=self.template)\n\n        return self.template\n\n    def _transpose_template(self, template: np.ndarray) -> np.ndarray:\n        if 
self.transpositions_axes is not None:\n            for axes in self.transpositions_axes:\n                template = np.transpose(template, axes=axes)\n\n        return template\n\n    def _flip_template(self, template: np.ndarray) -> np.ndarray:\n        if self.reflected_spatial_dims is not None:\n            for axis, to_reflect in enumerate(self.reflected_spatial_dims):\n                template = np.flip(template, axis=axis) if to_reflect else template\n\n        return template\n\n    def _rot90_template(self, template: np.ndarray) -> np.ndarray:\n        if self.rot90_axes is not None:\n            for axes in self.rot90_axes:\n                template = np.rot90(template, axes=axes)\n\n        return template\n\n    def _order_template(self, template: np.ndarray) -> np.ndarray:\n        depths = None\n        if self.spatial_dims == 2:\n            rows, columns = template.shape[0], template.shape[1]\n        else:\n            rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2])\n\n        sequence = eval(f\"self.{self.ordering_type}_idx\")(rows, columns, depths)\n\n        ordering = np.array([template[tuple(e)] for e in sequence])\n\n        return ordering\n\n    @staticmethod\n    def raster_scan_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray:\n        idx: list[tuple] = []\n\n        for r in range(rows):\n            for c in range(cols):\n                if depths is not None:\n                    for d in range(depths):\n                        idx.append((r, c, d))\n                else:\n                    idx.append((r, c))\n\n        idx_np = np.array(idx)\n\n        return idx_np\n\n    @staticmethod\n    def s_curve_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray:\n        idx: list[tuple] = []\n\n        for r in range(rows):\n            col_idx = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1)\n            for c in col_idx:\n                if depths:\n             
       depth_idx = range(depths) if c % 2 == 0 else range(depths - 1, -1, -1)\n\n                    for d in depth_idx:\n                        idx.append((r, c, d))\n                else:\n                    idx.append((r, c))\n\n        idx_np = np.array(idx)\n\n        return idx_np\n\n    @staticmethod\n    def random_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray:\n        idx: list[tuple] = []\n\n        for r in range(rows):\n            for c in range(cols):\n                if depths:\n                    for d in range(depths):\n                        idx.append((r, c, d))\n                else:\n                    idx.append((r, c))\n\n        idx_np = np.array(idx)\n        np.random.shuffle(idx_np)\n\n        return idx_np\n"
  },
  {
    "path": "monai/utils/profiling.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport csv\nimport datetime\nimport multiprocessing\nimport os\nimport sys\nimport threading\nfrom collections import defaultdict, namedtuple\nfrom contextlib import contextmanager\nfrom functools import wraps\nfrom inspect import getframeinfo, stack\nfrom queue import Empty\nfrom time import perf_counter, perf_counter_ns\nfrom typing import TYPE_CHECKING, Any, cast\n\nimport numpy as np\nimport torch\n\nfrom monai.utils import optional_import\n\nif TYPE_CHECKING:\n    from ignite.engine import Events\nelse:\n    Events = optional_import(\"ignite.engine\", name=\"Events\")\n\npd, has_pandas = optional_import(\"pandas\")\n\n__all__ = [\n    \"torch_profiler_full\",\n    \"torch_profiler_time_cpu_gpu\",\n    \"torch_profiler_time_end_to_end\",\n    \"PerfContext\",\n    \"WorkflowProfiler\",\n    \"ProfileHandler\",\n    \"select_transform_call\",\n]\n\n\ndef torch_profiler_full(func):\n    \"\"\"\n    A decorator which will run the torch profiler for the decorated function,\n    printing the results in full.\n    Note: Enforces a gpu sync point which could slow down pipelines.\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(*args, **kwargs):\n        with torch.autograd.profiler.profile(use_cuda=True) as prof:\n            result = func(*args, **kwargs)\n\n        print(prof, flush=True)\n\n        return result\n\n    return 
wrapper\n\n\ndef torch_profiler_time_cpu_gpu(func):\n    \"\"\"\n    A decorator which measures the execution time of both the CPU and GPU components\n    of the decorated function, printing both results.\n    Note: Enforces a gpu sync point which could slow down pipelines.\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(*args, **kwargs):\n        with torch.autograd.profiler.profile(use_cuda=True) as prof:\n            result = func(*args, **kwargs)\n\n        cpu_time = prof.self_cpu_time_total\n        gpu_time = sum(evt.self_cuda_time_total for evt in prof.function_events)\n\n        cpu_time = torch.autograd.profiler.format_time(cpu_time)  # type: ignore\n        gpu_time = torch.autograd.profiler.format_time(gpu_time)  # type: ignore\n\n        print(f\"cpu time: {cpu_time}, gpu time: {gpu_time}\", flush=True)\n\n        return result\n\n    return wrapper\n\n\ndef torch_profiler_time_end_to_end(func):\n    \"\"\"\n    A decorator which measures the total execution time from when the decorated\n    function is called to when the last cuda operation finishes, printing the result.\n    Note: Enforces a gpu sync point which could slow down pipelines.\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(*args, **kwargs):\n        torch.cuda.synchronize()\n        start = perf_counter()\n\n        result = func(*args, **kwargs)\n\n        torch.cuda.synchronize()\n        end = perf_counter()\n\n        total_time = (end - start) * 1e6\n        total_time_str = torch.autograd.profiler.format_time(total_time)  # type: ignore\n        print(f\"End-to-end time: {total_time_str}\", flush=True)\n\n        return result\n\n    return wrapper\n\n\nclass PerfContext:\n    \"\"\"\n    Context manager for tracking how much time is spent within context blocks. 
This uses `time.perf_counter` to\n    accumulate the total amount of time in seconds in the attribute `total_time` over however many context blocks\n    the object is used in.\n    \"\"\"\n\n    def __init__(self):\n        self.total_time: float = 0\n        self.start_time: float | None = None\n\n    def __enter__(self):\n        self.start_time = perf_counter()\n        return self\n\n    def __exit__(self, exc_type, exc_value, exc_traceback):\n        if self.start_time is not None:\n            self.total_time += perf_counter() - self.start_time\n        self.start_time = None\n\n\n# stores the results from profiling with trace or with other helper methods\nProfileResult = namedtuple(\"ProfileResult\", [\"name\", \"time\", \"filename\", \"lineno\", \"pid\", \"timestamp\"])\n\n\ndef select_transform_call(frame):\n    \"\"\"Returns True if `frame` is a call to a `Transform` object's `_call__` method.\"\"\"\n    from monai.transforms import Transform  # prevents circular import\n\n    self_obj = frame.f_locals.get(\"self\", None)\n    return frame.f_code.co_name == \"__call__\" and isinstance(self_obj, Transform)\n\n\nclass WorkflowProfiler:\n    \"\"\"\n    Profiler for timing all aspects of a workflow. This includes using stack tracing to capture call times for\n    all selected calls (by default calls to `Transform.__call__` methods), times within context blocks, times\n    to generate items from iterables, and times to execute decorated functions.\n\n    This profiler must be used only within its context because it uses an internal thread to read results from a\n    multiprocessing queue. 
This allows the profiler to function across multiple threads and processes, though the\n    multiprocess tracing is at times unreliable and not available in Windows at all.\n\n    The profiler uses `sys.settrace` and `threading.settrace` to find all calls to profile, this will be set when\n    the context enters and cleared when it exits so proper use of the context is essential to prevent excessive\n    tracing. Note that tracing has a high overhead so times will not accurately reflect real world performance\n    but give an idea of relative share of time spent.\n\n    The tracing functionality uses a selector to choose which calls to trace, since tracing all calls induces\n    infinite loops and would be terribly slow even if not. This selector is a callable accepting a `call` trace\n    frame and returns True if the call should be traced. The default is `select_transform_call` which will return\n    True for `Transform.__call__` calls only.\n\n    Example showing use of all profiling functions:\n\n    .. 
code-block:: python\n\n        import monai.transform as mt\n        from monai.utils import WorkflowProfiler\n        import torch\n\n        comp=mt.Compose([mt.ScaleIntensity(),mt.RandAxisFlip(0.5)])\n\n        with WorkflowProfiler() as wp:\n            for _ in wp.profile_iter(\"range\",range(5)):\n                with wp.profile_ctx(\"Loop\"):\n                    for i in range(10):\n                        comp(torch.rand(1,16,16))\n\n            @wp.profile_callable()\n            def foo(): pass\n\n            foo()\n            foo()\n\n        print(wp.get_times_summary_pd())  # print results\n\n    Args:\n        call_selector: selector to determine which calls to trace, use None to disable tracing\n    \"\"\"\n\n    def __init__(self, call_selector=select_transform_call):\n        self.results = defaultdict(list)\n        self.parent_pid = os.getpid()\n        self.read_thread: threading.Thread | None = None\n        self.lock = threading.RLock()\n        self.queue: multiprocessing.SimpleQueue = multiprocessing.SimpleQueue()\n        self.queue_timeout = 0.1\n        self.call_selector = call_selector\n\n    def _is_parent(self):\n        \"\"\"Return True if this is the parent process.\"\"\"\n        return os.getpid() == self.parent_pid\n\n    def _is_thread_active(self):\n        \"\"\"Return True if the read thread should be still active.\"\"\"\n        return self.read_thread is not None or not self.queue.empty()\n\n    def _read_thread_func(self):\n        \"\"\"Read results from the queue and add to self.results in a thread stared by `__enter__`.\"\"\"\n        while self._is_parent() and self._is_thread_active():\n            try:\n                result = self.queue.get()\n\n                if result is None:\n                    break\n\n                self.add_result(result)\n            except Empty:\n                pass\n\n        if not (not self._is_parent() or self.queue.empty()):\n            raise AssertionError\n\n    def 
_put_result(self, name, timedelta, filename, lineno):\n        \"\"\"Add a ProfileResult object to the queue.\"\"\"\n        ts = str(datetime.datetime.now())\n        self.queue.put(ProfileResult(name, timedelta, filename, lineno, os.getpid(), ts))\n\n    def _trace_call(self, frame, why, arg):\n        \"\"\"\n        Trace calls, when a call is encountered that is accepted by self.call_selector, create a new function to\n        trace that call and measure the time from the call to a \"return\" frame.\n        \"\"\"\n        if why == \"call\":\n            if self.call_selector(frame):\n                calling_frame = frame\n                start = perf_counter_ns()\n\n                def _call_profiler(frame, why, arg):\n                    \"\"\"Defines a new inner trace function just for this call.\"\"\"\n                    if why == \"return\":\n                        diff = perf_counter_ns() - start\n                        f_code = calling_frame.f_code\n                        self_obj = calling_frame.f_locals.get(\"self\", None)\n                        name = f_code.co_name\n                        if self_obj is not None:\n                            name = f\"{type(self_obj).__name__}.{name}\"\n\n                        self._put_result(name, diff, f_code.co_filename, f_code.co_firstlineno)\n\n                # This function will be used to trace this specific call now, however any new functions calls\n                # within will cause a \"call\" frame to be sent to `_trace_call` rather than to it, ie. 
it's not\n                # actually recursively tracing everything below as the documentation suggests and so cannot\n                # control whether subsequence calls are traced (see https://bugs.python.org/issue11992).\n                return _call_profiler\n        else:\n            return self._trace_call\n\n    def __enter__(self):\n        \"\"\"Enter the context, creating the read thread and setting up tracing if needed.\"\"\"\n        self.read_thread = threading.Thread(target=self._read_thread_func)\n        self.read_thread.start()\n\n        if self.call_selector is not None:\n            threading.settrace(self._trace_call)\n            sys.settrace(self._trace_call)\n\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        \"\"\"Terminate the read thread cleanly and reset tracing if needed.\"\"\"\n        if not self._is_parent():\n            raise AssertionError\n\n        self.queue.put(None)\n\n        read_thread = cast(threading.Thread, self.read_thread)\n        self.read_thread = None\n\n        read_thread.join()\n\n        if self.call_selector is not None:\n            threading.settrace(None)  # type: ignore\n            sys.settrace(None)\n\n    def add_result(self, result: ProfileResult) -> None:\n        \"\"\"Add a result in a thread-safe manner to the internal results dictionary.\"\"\"\n        with self.lock:\n            self.results[result.name].append(result)\n\n    def get_results(self):\n        \"\"\"Get a fresh results dictionary containing fresh tuples of ProfileResult objects.\"\"\"\n        if not self._is_parent():\n            raise RuntimeError(\"Only parent process can collect results\")\n\n        with self.lock:\n            return {k: tuple(v) for k, v in self.results.items()}\n\n    @contextmanager\n    def profile_ctx(self, name, caller=None):\n        \"\"\"Creates a context to profile, placing a timing result onto the queue when it exits.\"\"\"\n        if caller is None:\n     
       caller = getframeinfo(stack()[2][0])  # caller of context, not something in contextlib\n\n        start = perf_counter_ns()\n        try:\n            yield\n        finally:\n            diff = perf_counter_ns() - start\n            self._put_result(name, diff, caller.filename, caller.lineno)\n\n    def profile_callable(self, name=None):\n        \"\"\"\n        Decorator which can be applied to a function which profiles any calls to it. All calls to decorated\n        callables must be done within the context of the profiler.\n        \"\"\"\n\n        def _outer(func):\n            _name = func.__name__ if name is None else name\n            return self.profile_ctx(_name)(func)\n\n        return _outer\n\n    def profile_iter(self, name, iterable):\n        \"\"\"Wrapper around anything iterable to profile how long it takes to generate items.\"\"\"\n\n        class _Iterable:\n\n            def __iter__(_self):  # noqa: N805 pylint: disable=E0213\n                do_iter = True\n                orig_iter = iter(iterable)\n                caller = getframeinfo(stack()[1][0])\n\n                while do_iter:\n                    try:\n                        start = perf_counter_ns()\n                        item = next(orig_iter)\n                        diff = perf_counter_ns() - start\n                        # don't put result when StopIteration is hit\n                        self._put_result(name, diff, caller.filename, caller.lineno)\n                        yield item\n                    except StopIteration:\n                        do_iter = False\n\n        return _Iterable()\n\n    def get_times_summary(self, times_in_s=True):\n        \"\"\"\n        Returns a dictionary mapping results entries to tuples containing the number of items, time sum, time average,\n        time std dev, time min, and time max.\n        \"\"\"\n        result = {}\n        for k, v in self.get_results().items():\n            timemult = 1e-9 if times_in_s else 1.0\n 
           all_times = [res.time * timemult for res in v]\n\n            timesum = sum(all_times)\n            timeavg = timesum / len(all_times)\n            timestd = np.std(all_times)\n            timemin = min(all_times)\n            timemax = max(all_times)\n\n            result[k] = (len(v), timesum, timeavg, timestd, timemin, timemax)\n\n        return result\n\n    def get_times_summary_pd(self, times_in_s=True):\n        \"\"\"Returns the same information as `get_times_summary` but in a Pandas DataFrame.\"\"\"\n        import pandas as pd\n\n        summ = self.get_times_summary(times_in_s)\n        suffix = \"s\" if times_in_s else \"ns\"\n        columns = [\"Count\", f\"Total Time ({suffix})\", \"Avg\", \"Std\", \"Min\", \"Max\"]\n\n        df = pd.DataFrame.from_dict(summ, orient=\"index\", columns=columns)\n        df = df.sort_values(columns[1], ascending=False)\n        return df\n\n    def dump_csv(self, stream=sys.stdout):\n        \"\"\"Save all results to a csv file.\"\"\"\n        all_results = list(self.get_results().values())\n        writer = csv.DictWriter(stream, fieldnames=all_results[0][0]._asdict().keys())\n        writer.writeheader()\n\n        for rlist in all_results:\n            for r in rlist:\n                writer.writerow(r._asdict())\n\n\nclass ProfileHandler:\n    \"\"\"\n    Handler for Ignite Engine classes which measures the time from a start event ton an end event. This can be used to\n    profile epoch, iteration, and other events as defined in `ignite.engine.Events`. 
This class should be used only\n    within the context of a profiler object.\n\n    Args:\n        name: name of event to profile\n        profiler: instance of WorkflowProfiler used by the handler, should be within the context of this object\n        start_event: item in `ignite.engine.Events` stating event at which to start timing\n        end_event: item in `ignite.engine.Events` stating event at which to stop timing\n    \"\"\"\n\n    def __init__(self, name: str, profiler: WorkflowProfiler, start_event: Events, end_event: Events):\n        self.name = name\n        self.profiler = profiler\n        self.start_event = start_event\n        self.end_event = end_event\n        self.ctx: Any = None\n\n    def attach(self, engine):\n        engine.add_event_handler(self.start_event, self.start)\n        engine.add_event_handler(self.end_event, self.end)\n        return self\n\n    def start(self, engine):\n        self.ctx = self.profiler.profile_ctx(self.name)\n        self.ctx.__enter__()\n\n    def end(self, engine):\n        self.ctx.__exit__(None, None, None)\n        self.ctx = None\n"
  },
  {
    "path": "monai/utils/state_cacher.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport copy\nimport os\nimport pickle\nimport tempfile\nfrom collections.abc import Hashable\nfrom types import ModuleType\nfrom typing import Any\n\nimport torch\nfrom torch.serialization import DEFAULT_PROTOCOL\n\nfrom monai.config.type_definitions import PathLike\n\n__all__ = [\"StateCacher\"]\n\n\nclass StateCacher:\n    \"\"\"Class to cache and retrieve the state of an object.\n\n    Objects can either be stored in memory or on disk. 
If stored on disk, they can be\n    stored in a given directory, or alternatively a temporary location will be used.\n\n    If necessary/possible, restored objects will be returned to their original device.\n\n    Example:\n\n    >>> state_cacher = StateCacher(memory_cache, cache_dir=cache_dir)\n    >>> state_cacher.store(\"model\", model.state_dict())\n    >>> model.load_state_dict(state_cacher.retrieve(\"model\"))\n    \"\"\"\n\n    def __init__(\n        self,\n        in_memory: bool,\n        cache_dir: PathLike | None = None,\n        allow_overwrite: bool = True,\n        pickle_module: ModuleType = pickle,\n        pickle_protocol: int = DEFAULT_PROTOCOL,\n    ) -> None:\n        \"\"\"Constructor.\n\n        Args:\n            in_memory: boolean to determine if the object will be cached in memory or on\n                disk.\n            cache_dir: directory for data to be cached if `in_memory==False`. Defaults\n                to using a temporary directory. Any created files will be deleted during\n                the `StateCacher`'s destructor.\n            allow_overwrite: allow the cache to be overwritten. If set to `False`, an\n                error will be thrown if a matching already exists in the list of cached\n                objects.\n            pickle_module: module used for pickling metadata and objects, default to `pickle`.\n                this arg is used by `torch.save`, for more details, please check:\n                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.\n            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.\n                Defaults to torch.serialization.DEFAULT_PROTOCOL. 
For more details, please check:\n                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.\n\n        \"\"\"\n        self.in_memory = in_memory\n        self.cache_dir = tempfile.gettempdir() if cache_dir is None else cache_dir\n        if not os.path.isdir(self.cache_dir):\n            raise ValueError(\"Given `cache_dir` is not a valid directory.\")\n\n        self.allow_overwrite = allow_overwrite\n        self.pickle_module = pickle_module\n        self.pickle_protocol = pickle_protocol\n        self.cached: dict = {}\n\n    def store(\n        self, key: Hashable, data_obj: Any, pickle_module: ModuleType | None = None, pickle_protocol: int | None = None\n    ) -> None:\n        \"\"\"\n        Store a given object with the given key name.\n\n        Args:\n            key: key of the data object to store.\n            data_obj: data object to store.\n            pickle_module: module used for pickling metadata and objects, default to `self.pickle_module`.\n                this arg is used by `torch.save`, for more details, please check:\n                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.\n            pickle_protocol: can be specified to override the default protocol, default to `self.pickle_protocol`.\n                this arg is used by `torch.save`, for more details, please check:\n                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.\n\n        \"\"\"\n        if key in self.cached and not self.allow_overwrite:\n            raise RuntimeError(\"Cached key already exists and overwriting is disabled.\")\n        if self.in_memory:\n            self.cached.update({key: {\"obj\": copy.deepcopy(data_obj)}})\n        else:\n            fn = os.path.join(self.cache_dir, f\"state_{key}_{id(self)}.pt\")\n            self.cached.update({key: {\"obj\": fn}})\n            torch.save(\n                obj=data_obj,\n                f=fn,\n                
pickle_module=self.pickle_module if pickle_module is None else pickle_module,\n                pickle_protocol=self.pickle_protocol if pickle_protocol is None else pickle_protocol,\n            )\n            # store object's device if relevant\n            if hasattr(data_obj, \"device\"):\n                self.cached[key][\"device\"] = data_obj.device\n\n    def retrieve(self, key: Hashable) -> Any:\n        \"\"\"Retrieve the object stored under a given key name.\"\"\"\n        if key not in self.cached:\n            raise KeyError(f\"Target {key} was not cached.\")\n\n        if self.in_memory:\n            return self.cached[key][\"obj\"]\n\n        fn = self.cached[key][\"obj\"]  # pytype: disable=attribute-error\n        if not os.path.exists(fn):  # pytype: disable=wrong-arg-types\n            raise RuntimeError(f\"Failed to load state in {fn}. File doesn't exist anymore.\")\n        data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=True)\n        # copy back to device if necessary\n        if \"device\" in self.cached[key]:\n            data_obj = data_obj.to(self.cached[key][\"device\"])\n        return data_obj\n\n    def __del__(self):\n        \"\"\"If necessary, delete any cached files existing in `cache_dir`.\"\"\"\n        if not self.in_memory:\n            for k in self.cached:\n                if os.path.exists(self.cached[k][\"obj\"]):\n                    os.remove(self.cached[k][\"obj\"])\n"
  },
  {
    "path": "monai/utils/tf32.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport functools\nimport os\nimport warnings\n\n__all__ = [\"has_ampere_or_later\", \"detect_default_tf32\"]\n\n\n@functools.lru_cache(None)\ndef has_ampere_or_later() -> bool:\n    \"\"\"\n    Check if there is any Ampere and later GPU.\n    \"\"\"\n    import torch\n\n    from monai.utils.module import optional_import, version_geq\n\n    if not (torch.version.cuda and version_geq(f\"{torch.version.cuda}\", \"11.0\")):\n        return False\n\n    pynvml, has_pynvml = optional_import(\"pynvml\")\n    if not has_pynvml:  # assuming that the user has Ampere and later GPU\n        return True\n\n    try:\n        pynvml.nvmlInit()\n        for i in range(pynvml.nvmlDeviceGetCount()):\n            handle = pynvml.nvmlDeviceGetHandleByIndex(i)\n            major, _ = pynvml.nvmlDeviceGetCudaComputeCapability(handle)\n            if major >= 8:\n                return True\n    except BaseException:\n        pass\n    finally:\n        pynvml.nvmlShutdown()\n\n    return False\n\n\n@functools.lru_cache(None)\ndef detect_default_tf32() -> bool:\n    \"\"\"\n    Detect if there is anything that may enable TF32 mode by default.\n    If any, show a warning message.\n    \"\"\"\n    may_enable_tf32 = False\n    try:\n        if not has_ampere_or_later():\n            return False\n\n        override_tf32_env_vars = {\"NVIDIA_TF32_OVERRIDE\": 
\"1\"}  # TORCH_ALLOW_TF32_CUBLAS_OVERRIDE not checked #6907\n        for name, override_val in override_tf32_env_vars.items():\n            if os.environ.get(name) == override_val:\n                warnings.warn(\n                    f\"Environment variable `{name} = {override_val}` is set.\\n\"\n                    f\"  This environment variable may enable TF32 mode accidentally and affect precision.\\n\"\n                    f\"  See https://monai.readthedocs.io/en/latest/precision_accelerating.html#precision-and-accelerating\"\n                )\n                may_enable_tf32 = True\n\n        return may_enable_tf32\n    except BaseException:\n        from monai.utils.misc import MONAIEnvVars\n\n        if MONAIEnvVars.debug():\n            raise\n        return False\n"
  },
  {
    "path": "monai/utils/type_conversion.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport re\nfrom collections.abc import Sequence\nfrom typing import Any\n\nimport numpy as np\nimport torch\n\nimport monai\nfrom monai.config.type_definitions import DtypeLike, NdarrayTensor\nfrom monai.utils import optional_import\n\ncp, has_cp = optional_import(\"cupy\")\ncp_ndarray, _ = optional_import(\"cupy\", name=\"ndarray\")\n\n__all__ = [\n    \"get_numpy_dtype_from_string\",\n    \"get_torch_dtype_from_string\",\n    \"dtype_torch_to_numpy\",\n    \"dtype_numpy_to_torch\",\n    \"get_equivalent_dtype\",\n    \"convert_data_type\",\n    \"get_dtype\",\n    \"get_dtype_string\",\n    \"convert_to_cupy\",\n    \"convert_to_numpy\",\n    \"convert_to_tensor\",\n    \"convert_to_dst_type\",\n]\n\n# conversion map for types unsupported by torch.as_tensor\nUNSUPPORTED_TYPES = {np.dtype(\"uint16\"): np.int32, np.dtype(\"uint32\"): np.int64, np.dtype(\"uint64\"): np.int64}\n\n\ndef get_numpy_dtype_from_string(dtype: str) -> np.dtype:\n    \"\"\"Get a numpy dtype (e.g., `np.float32`) from its string (e.g., `\"float32\"`).\"\"\"\n    return np.empty([], dtype=str(dtype).split(\".\")[-1]).dtype\n\n\ndef get_torch_dtype_from_string(dtype: str) -> torch.dtype:\n    \"\"\"Get a torch dtype (e.g., `torch.float32`) from its string (e.g., `\"float32\"`).\"\"\"\n    return dtype_numpy_to_torch(get_numpy_dtype_from_string(dtype))\n\n\ndef 
dtype_torch_to_numpy(dtype: torch.dtype) -> np.dtype:\n    \"\"\"Convert a torch dtype to its numpy equivalent.\"\"\"\n    return torch.empty([], dtype=dtype).numpy().dtype  # type: ignore\n\n\ndef dtype_numpy_to_torch(dtype: np.dtype) -> torch.dtype:\n    \"\"\"Convert a numpy dtype to its torch equivalent.\"\"\"\n    return torch.from_numpy(np.empty([], dtype=dtype)).dtype\n\n\ndef get_equivalent_dtype(dtype, data_type):\n    \"\"\"Convert to the `dtype` that corresponds to `data_type`.\n\n    The input dtype can also be a string. e.g., `\"float32\"` becomes `torch.float32` or\n    `np.float32` as necessary.\n\n    Example::\n\n        im = torch.tensor(1)\n        dtype = get_equivalent_dtype(np.float32, type(im))\n\n    \"\"\"\n    if dtype is None:\n        return None\n    if data_type is torch.Tensor or data_type.__name__ == \"MetaTensor\":\n        if isinstance(dtype, torch.dtype):\n            # already a torch dtype and target `data_type` is torch.Tensor\n            return dtype\n        return dtype_numpy_to_torch(dtype)\n    if not isinstance(dtype, torch.dtype):\n        # assuming the dtype is ok if it is not a torch dtype and target `data_type` is not torch.Tensor\n        return dtype\n    return dtype_torch_to_numpy(dtype)\n\n\ndef get_dtype(data: Any) -> DtypeLike | torch.dtype:\n    \"\"\"Get the dtype of an image, or if there is a sequence, recursively call the method on the 0th element.\n\n    This therefore assumes that in a `Sequence`, all types are the same.\n    \"\"\"\n    if hasattr(data, \"dtype\"):\n        return data.dtype  # type: ignore\n    # need recursion\n    if isinstance(data, Sequence):\n        return get_dtype(data[0])\n    # objects like float don't have dtype, so return their type\n    return type(data)\n\n\ndef get_dtype_string(dtype: DtypeLike | torch.dtype) -> str:\n    \"\"\"Get a string representation of the dtype.\"\"\"\n    if isinstance(dtype, torch.dtype):\n        return str(dtype)[6:]\n    return 
str(dtype)[3:]\n\n\ndef convert_to_tensor(\n    data: Any,\n    dtype: DtypeLike | torch.dtype = None,\n    device: None | str | torch.device = None,\n    wrap_sequence: bool = False,\n    track_meta: bool = False,\n    safe: bool = False,\n    convert_numeric: bool = True,\n) -> Any:\n    \"\"\"\n    Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`,\n    otherwise, the output will be a regular torch Tensor.\n    If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor.\n\n    Args:\n        data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.\n            will convert Tensor, Numpy array, float, int, bool to Tensor, strings and objects keep the original.\n            for dictionary, list or tuple, convert every item to a Tensor if applicable.\n        dtype: target data type to when converting to Tensor.\n        device: target device to put the converted Tensor data.\n        wrap_sequence: if `False`, then lists will recursively call this function.\n            E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`.\n        track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.\n            default to `False`.\n        safe: if `True`, then do safe dtype convert when intensity overflow. 
default to `False`.\n            E.g., `[256, -12]` -> `[tensor(0), tensor(244)]`.\n            If `True`, then `[256, -12]` -> `[tensor(255), tensor(0)]`.\n        convert_numeric: if `True`, convert numeric Python values to tensors.\n\n    \"\"\"\n\n    def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:\n        if not isinstance(tensor, torch.Tensor):\n            # certain numpy types are not supported as being directly convertible to Pytorch tensors\n            if isinstance(tensor, np.ndarray) and tensor.dtype in UNSUPPORTED_TYPES:\n                tensor = tensor.astype(UNSUPPORTED_TYPES[tensor.dtype])\n\n            # if input data is not Tensor, convert it to Tensor first\n            tensor = torch.as_tensor(tensor, **kwargs)\n        if track_meta and not isinstance(tensor, monai.data.MetaTensor):\n            return monai.data.MetaTensor(tensor)\n        if not track_meta and isinstance(tensor, monai.data.MetaTensor):\n            return tensor.as_tensor()\n        return tensor\n\n    if safe:\n        data = safe_dtype_range(data, dtype)\n    dtype = get_equivalent_dtype(dtype, torch.Tensor)\n\n    if isinstance(data, torch.Tensor):\n        return _convert_tensor(data).to(dtype=dtype, device=device, memory_format=torch.contiguous_format)\n    if isinstance(data, np.ndarray):\n        # skip array of string classes and object, refer to:\n        # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13\n        if re.search(r\"[SaUO]\", data.dtype.str) is None:\n            # numpy array with 0 dims is also sequence iterable,\n            # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims\n            if data.ndim > 0:\n                data = np.ascontiguousarray(data)\n            return _convert_tensor(data, dtype=dtype, device=device)\n    elif (has_cp and isinstance(data, cp_ndarray)) or (convert_numeric and isinstance(data, (float, int, bool))):\n        return 
_convert_tensor(data, dtype=dtype, device=device)\n    elif isinstance(data, list):\n        list_ret = [\n            convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)\n            for i in data\n        ]\n        return _convert_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret\n    elif isinstance(data, tuple):\n        tuple_ret = tuple(\n            convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)\n            for i in data\n        )\n        return _convert_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret\n    elif isinstance(data, dict):\n        return {\n            k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)\n            for k, v in data.items()\n        }\n\n    return data\n\n\ndef convert_to_numpy(data: Any, dtype: DtypeLike = None, wrap_sequence: bool = False, safe: bool = False) -> Any:\n    \"\"\"\n    Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple,\n    recursively check every item and convert it to numpy array.\n\n    Args:\n        data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.\n            will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original.\n            for dictionary, list or tuple, convert every item to a numpy array if applicable.\n        dtype: target data type when converting to numpy array.\n        wrap_sequence: if `False`, then lists will recursively call this function.\n            E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.\n        safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.\n            E.g., `[256, -12]` -> `[array(0), array(244)]`. 
If `True`, then `[256, -12]` -> `[array(255), array(0)]`.\n    \"\"\"\n    if safe:\n        data = safe_dtype_range(data, dtype)\n    if isinstance(data, torch.Tensor):\n        data = np.asarray(data.detach().to(device=\"cpu\").numpy(), dtype=get_equivalent_dtype(dtype, np.ndarray))\n    elif has_cp and isinstance(data, cp_ndarray):\n        data = cp.asnumpy(data).astype(dtype, copy=False)\n    elif isinstance(data, (np.ndarray, float, int, bool)):\n        # Convert into a contiguous array first if the current dtype's size is smaller than the target dtype's size.\n        # This help improve the performance because (convert to contiguous array) -> (convert dtype) is faster\n        # than (convert dtype) -> (convert to contiguous array) when src dtype (e.g., uint8) is smaller than\n        # target dtype(e.g., float32) and we are going to convert it to contiguous array anyway later in this\n        # method.\n        if isinstance(data, np.ndarray) and data.ndim > 0 and data.dtype.itemsize < np.dtype(dtype).itemsize:\n            data = np.ascontiguousarray(data)\n        data = np.asarray(data, dtype=dtype)\n    elif isinstance(data, list):\n        list_ret = [convert_to_numpy(i, dtype=dtype) for i in data]\n        return np.asarray(list_ret) if wrap_sequence else list_ret\n    elif isinstance(data, tuple):\n        tuple_ret = tuple(convert_to_numpy(i, dtype=dtype) for i in data)\n        return np.asarray(tuple_ret) if wrap_sequence else tuple_ret\n    elif isinstance(data, dict):\n        return {k: convert_to_numpy(v, dtype=dtype) for k, v in data.items()}\n\n    if isinstance(data, np.ndarray) and data.ndim > 0:\n        data = np.ascontiguousarray(data)\n\n    return data\n\n\ndef convert_to_cupy(data: Any, dtype: np.dtype | None = None, wrap_sequence: bool = False, safe: bool = False) -> Any:\n    \"\"\"\n    Utility to convert the input data to a cupy array. 
If passing a dictionary, list or tuple,\n    recursively check every item and convert it to cupy array.\n\n    Args:\n        data: input data can be PyTorch Tensor, numpy array, cupy array, list, dictionary, int, float, bool, str, etc.\n            Tensor, numpy array, cupy array, float, int, bool are converted to cupy arrays,\n            for dictionary, list or tuple, convert every item to a numpy array if applicable.\n        dtype: target data type when converting to Cupy array, tt must be an argument of `numpy.dtype`,\n            for more details: https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html.\n        wrap_sequence: if `False`, then lists will recursively call this function.\n            E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.\n        safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.\n            E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.\n    \"\"\"\n    if safe:\n        data = safe_dtype_range(data, dtype)\n    # direct calls\n    if isinstance(data, torch.Tensor) and data.device.type == \"cuda\":\n        # This is needed because of https://github.com/cupy/cupy/issues/7874#issuecomment-1727511030\n        if data.dtype == torch.bool:\n            data = data.detach().to(torch.uint8)\n            if dtype is None:\n                dtype = bool  # type: ignore\n        data = cp.asarray(data, dtype)\n    elif isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)):\n        data = cp.asarray(data, dtype)\n    elif isinstance(data, list):\n        list_ret = [convert_to_cupy(i, dtype) for i in data]\n        return cp.asarray(list_ret) if wrap_sequence else list_ret\n    elif isinstance(data, tuple):\n        tuple_ret = tuple(convert_to_cupy(i, dtype) for i in data)\n        return cp.asarray(tuple_ret) if wrap_sequence else tuple_ret\n    elif isinstance(data, 
dict):\n        return {k: convert_to_cupy(v, dtype) for k, v in data.items()}\n    # make it contiguous\n    if not isinstance(data, cp.ndarray):\n        raise ValueError(f\"The input data type [{type(data)}] cannot be converted into cupy arrays!\")\n\n    if data.ndim > 0:\n        data = cp.ascontiguousarray(data)\n    return data\n\n\ndef convert_data_type(\n    data: Any,\n    output_type: type[NdarrayTensor] | None = None,\n    device: None | str | torch.device = None,\n    dtype: DtypeLike | torch.dtype = None,\n    wrap_sequence: bool = False,\n    safe: bool = False,\n) -> tuple[NdarrayTensor, type, torch.device | None]:\n    \"\"\"\n    Convert to `MetaTensor`, `torch.Tensor` or `np.ndarray` from `MetaTensor`, `torch.Tensor`,\n    `np.ndarray`, `float`, `int`, etc.\n\n    Args:\n        data: data to be converted\n        output_type: `monai.data.MetaTensor`, `torch.Tensor`, or `np.ndarray` (if `None`, unchanged)\n        device: if output is `MetaTensor` or `torch.Tensor`, select device (if `None`, unchanged)\n        dtype: dtype of output data. Converted to correct library type (e.g.,\n            `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).\n            If left blank, it remains unchanged.\n        wrap_sequence: if `False`, then lists will recursively call this function.\n            E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.\n        safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.\n            E.g., `[256, -12]` -> `[array(0), array(244)]`. 
If `True`, then `[256, -12]` -> `[array(255), array(0)]`.\n\n    Returns:\n        modified data, orig_type, orig_device\n\n    Note:\n        When both `output_type` and `dtype` are specified with different backend\n        (e.g., `torch.Tensor` and `np.float32`), the `output_type` will be used as the primary type,\n        for example::\n\n            >>> convert_data_type(1, torch.Tensor, dtype=np.float32)\n            (1.0, <class 'torch.Tensor'>, None)\n\n    \"\"\"\n    orig_type: type\n    if isinstance(data, monai.data.MetaTensor):\n        orig_type = monai.data.MetaTensor\n    elif isinstance(data, torch.Tensor):\n        orig_type = torch.Tensor\n    elif isinstance(data, np.ndarray):\n        orig_type = np.ndarray\n    elif has_cp and isinstance(data, cp.ndarray):\n        orig_type = cp.ndarray\n    else:\n        orig_type = type(data)\n\n    orig_device = data.device if isinstance(data, torch.Tensor) else None\n\n    output_type = output_type or orig_type\n    dtype_ = get_equivalent_dtype(dtype, output_type)\n\n    data_: NdarrayTensor\n    if issubclass(output_type, torch.Tensor):\n        track_meta = issubclass(output_type, monai.data.MetaTensor)\n        data_ = convert_to_tensor(\n            data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta, safe=safe\n        )\n        return data_, orig_type, orig_device\n    if issubclass(output_type, np.ndarray):\n        data_ = convert_to_numpy(data, dtype=dtype_, wrap_sequence=wrap_sequence, safe=safe)\n        return data_, orig_type, orig_device\n    elif has_cp and issubclass(output_type, cp.ndarray):\n        data_ = convert_to_cupy(data, dtype=dtype_, wrap_sequence=wrap_sequence, safe=safe)\n        return data_, orig_type, orig_device\n    raise ValueError(f\"Unsupported output type: {output_type}\")\n\n\ndef convert_to_dst_type(\n    src: Any,\n    dst: NdarrayTensor,\n    dtype: DtypeLike | torch.dtype | None = None,\n    wrap_sequence: bool = False,\n    
device: None | str | torch.device = None,\n    safe: bool = False,\n) -> tuple[NdarrayTensor, type, torch.device | None]:\n    \"\"\"\n    Convert source data to the same data type and device as the destination data.\n    If `dst` is an instance of `torch.Tensor` or its subclass, convert `src` to `torch.Tensor` with the same data type as `dst`,\n    if `dst` is an instance of `numpy.ndarray` or its subclass, convert to `numpy.ndarray` with the same data type as `dst`,\n    otherwise, convert to the type of `dst` directly.\n\n    Args:\n        src: source data to convert type.\n        dst: destination data that convert to the same data type as it.\n        dtype: an optional argument if the target `dtype` is different from the original `dst`'s data type.\n        wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.\n            If `True`, then `[1, 2]` -> `array([1, 2])`.\n        device: target device to put the converted Tensor data. If unspecified, `dst.device` will be used if possible.\n        safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.\n            E.g., `[256, -12]` -> `[array(0), array(244)]`. 
If `True`, then `[256, -12]` -> `[array(255), array(0)]`.\n\n    See Also:\n        :func:`convert_data_type`\n    \"\"\"\n\n    device = dst.device if device is None and isinstance(dst, torch.Tensor) else device\n    if dtype is None:\n        dtype = getattr(dst, \"dtype\", None)  # sequence has no dtype\n\n    copy_meta = False\n    output_type: Any\n    if isinstance(dst, monai.data.MetaTensor):\n        output_type = monai.data.MetaTensor\n        if not isinstance(src, monai.data.MetaTensor):\n            copy_meta = True  # converting a non-meta tensor to a meta tensor, probably take the metadata as well.\n    elif isinstance(dst, torch.Tensor):\n        output_type = torch.Tensor\n    elif isinstance(dst, np.ndarray):\n        output_type = np.ndarray\n    else:\n        output_type = type(dst)\n    output: NdarrayTensor\n    output, _type, _device = convert_data_type(\n        data=src, output_type=output_type, device=device, dtype=dtype, wrap_sequence=wrap_sequence, safe=safe\n    )\n    if copy_meta and isinstance(output, monai.data.MetaTensor):\n        output.copy_meta_from(dst)\n    return output, _type, _device\n\n\ndef convert_to_list(data: Sequence | torch.Tensor | np.ndarray) -> list:\n    \"\"\"\n    Convert to list from `torch.Tensor`/`np.ndarray`/`list`/`tuple` etc.\n    Args:\n        data: data to be converted\n    Returns:\n        a list\n\n    \"\"\"\n    return data.tolist() if isinstance(data, (torch.Tensor, np.ndarray)) else list(data)\n\n\ndef get_dtype_bound_value(dtype: DtypeLike | torch.dtype) -> tuple[float, float]:\n    \"\"\"\n    Get dtype bound value\n    Args:\n        dtype: dtype to get bound value\n    Returns:\n        (bound_min_value, bound_max_value)\n    \"\"\"\n    if dtype in UNSUPPORTED_TYPES:\n        is_floating_point = False\n    else:\n        is_floating_point = get_equivalent_dtype(dtype, torch.Tensor).is_floating_point\n    dtype = get_equivalent_dtype(dtype, np.array)\n    if is_floating_point:\n        
return (np.finfo(dtype).min, np.finfo(dtype).max)  # type: ignore\n    else:\n        return (np.iinfo(dtype).min, np.iinfo(dtype).max)\n\n\ndef safe_dtype_range(data: Any, dtype: DtypeLike | torch.dtype = None) -> Any:\n    \"\"\"\n    Utility to safely convert the input data to target dtype.\n\n    Args:\n        data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.\n            will convert to target dtype and keep the original type.\n            for dictionary, list or tuple, convert every item.\n        dtype: target data type to convert.\n    \"\"\"\n\n    def _safe_dtype_range(data, dtype):\n        output_dtype = dtype if dtype is not None else data.dtype\n        dtype_bound_value = get_dtype_bound_value(output_dtype)\n        if data.ndim == 0:\n            data_bound = (data, data)\n        else:\n            if isinstance(data, torch.Tensor):\n                data_bound = (torch.min(data), torch.max(data))\n            else:\n                data_bound = (np.min(data), np.max(data))\n        if (data_bound[1] > dtype_bound_value[1]) or (data_bound[0] < dtype_bound_value[0]):\n            if isinstance(data, torch.Tensor):\n                return torch.clamp(data, dtype_bound_value[0], dtype_bound_value[1])\n            elif isinstance(data, np.ndarray):\n                return np.clip(data, dtype_bound_value[0], dtype_bound_value[1])\n            elif has_cp and isinstance(data, cp_ndarray):\n                return cp.clip(data, dtype_bound_value[0], dtype_bound_value[1])\n        else:\n            return data\n\n    if has_cp and isinstance(data, cp_ndarray):\n        return cp.asarray(_safe_dtype_range(data, dtype))\n    elif isinstance(data, np.ndarray):\n        return np.asarray(_safe_dtype_range(data, dtype))\n    elif isinstance(data, torch.Tensor):\n        return _safe_dtype_range(data, dtype)\n    elif isinstance(data, (float, int, bool)) and dtype is None:\n        return data\n    elif 
isinstance(data, (float, int, bool)) and dtype is not None:\n        output_dtype = dtype\n        dtype_bound_value = get_dtype_bound_value(output_dtype)\n        data = dtype_bound_value[1] if data > dtype_bound_value[1] else data\n        data = dtype_bound_value[0] if data < dtype_bound_value[0] else data\n        return data\n\n    elif isinstance(data, list):\n        return [safe_dtype_range(i, dtype=dtype) for i in data]\n    elif isinstance(data, tuple):\n        return tuple(safe_dtype_range(i, dtype=dtype) for i in data)\n    elif isinstance(data, dict):\n        return {k: safe_dtype_range(v, dtype=dtype) for k, v in data.items()}\n    return data\n"
  },
  {
    "path": "monai/visualize/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks, default_normalizer\nfrom .gradient_based import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad\nfrom .img2tensorboard import add_animated_gif, make_animated_gif_summary, plot_2d_or_3d_image\nfrom .occlusion_sensitivity import OcclusionSensitivity\nfrom .utils import blend_images, matshow3d\nfrom .visualizer import default_upsampler\n"
  },
  {
    "path": "monai/visualize/class_activation_maps.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport warnings\nfrom collections.abc import Callable, Sequence\nfrom typing import cast\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom monai.config import NdarrayTensor\nfrom monai.transforms import ScaleIntensity\nfrom monai.utils import ensure_tuple\nfrom monai.visualize.visualizer import default_upsampler\n\n__all__ = [\"CAM\", \"GradCAM\", \"GradCAMpp\", \"ModelWithHooks\", \"default_normalizer\"]\n\n\ndef default_normalizer(x: NdarrayTensor) -> NdarrayTensor:\n    \"\"\"\n    A linear intensity scaling by mapping the (min, max) to (1, 0).\n    If the input data is PyTorch Tensor, the output data will be Tensor on the same device,\n    otherwise, output data will be numpy array.\n\n    Note: This will flip magnitudes (i.e., smallest will become biggest and vice versa).\n    \"\"\"\n\n    def _compute(data: np.ndarray) -> np.ndarray:\n        scaler = ScaleIntensity(minv=1.0, maxv=0.0)\n        return np.stack([scaler(i) for i in data], axis=0)\n\n    if isinstance(x, torch.Tensor):\n        return torch.as_tensor(_compute(x.detach().cpu().numpy()), device=x.device)  # type: ignore\n\n    return _compute(x)  # type: ignore\n\n\nclass ModelWithHooks:\n    \"\"\"\n    A model wrapper to run model forward/backward steps and storing some intermediate feature/gradient information.\n  
  \"\"\"\n\n    def __init__(\n        self,\n        nn_module: nn.Module,\n        target_layer_names: str | Sequence[str],\n        register_forward: bool = False,\n        register_backward: bool = False,\n    ):\n        \"\"\"\n\n        Args:\n            nn_module: the model to be wrapped.\n            target_layer_names: the names of the layer to cache.\n            register_forward: whether to cache the forward pass output corresponding to `target_layer_names`.\n            register_backward: whether to cache the backward pass output corresponding to `target_layer_names`.\n        \"\"\"\n        self.model = nn_module\n        self.target_layers = ensure_tuple(target_layer_names)\n\n        self.gradients: dict[str, torch.Tensor] = {}\n        self.activations: dict[str, torch.Tensor] = {}\n        self.score: torch.Tensor | None = None\n        self.class_idx: int | None = None\n        self.register_backward = register_backward\n        self.register_forward = register_forward\n\n        _registered = []\n        for name, mod in nn_module.named_modules():\n            if name not in self.target_layers:\n                continue\n            _registered.append(name)\n            if self.register_backward:\n                if \"inplace\" in mod.__dict__ and mod.__dict__[\"inplace\"]:\n                    # inplace=True causes errors for register_full_backward_hook\n                    mod.__dict__[\"inplace\"] = False\n                mod.register_full_backward_hook(self.backward_hook(name))\n            if self.register_forward:\n                mod.register_forward_hook(self.forward_hook(name))\n        if self.target_layers and (len(_registered) != len(self.target_layers)):\n            warnings.warn(f\"Not all target_layers exist in the network module: targets: {self.target_layers}.\")\n\n    def backward_hook(self, name):\n\n        def _hook(_module, _grad_input, grad_output):\n            self.gradients[name] = grad_output[0]\n\n        return 
_hook\n\n    def forward_hook(self, name):\n\n        def _hook(_module, _input, output):\n            self.activations[name] = output\n\n        return _hook\n\n    def get_layer(self, layer_id: str | Callable[[nn.Module], nn.Module]) -> nn.Module:\n        \"\"\"\n\n        Args:\n            layer_id: a layer name string or a callable. If it is a callable such as `lambda m: m.fc`,\n                this method will return the module `self.model.fc`.\n\n        Returns:\n            a submodule from self.model.\n        \"\"\"\n        if callable(layer_id):\n            return layer_id(self.model)\n        if isinstance(layer_id, str):\n            for name, mod in self.model.named_modules():\n                if name == layer_id:\n                    return cast(nn.Module, mod)\n        raise NotImplementedError(f\"Could not find {layer_id}.\")\n\n    def class_score(self, logits: torch.Tensor, class_idx: int) -> torch.Tensor:\n        return logits[:, class_idx].squeeze()\n\n    def __call__(self, x, class_idx=None, retain_graph=False, **kwargs):\n        train = self.model.training\n        self.model.eval()\n        logits = self.model(x, **kwargs)\n        self.class_idx = logits.max(1)[-1] if class_idx is None else class_idx\n        acti, grad = None, None\n        if self.register_forward:\n            acti = tuple(self.activations[layer] for layer in self.target_layers)\n        if self.register_backward:\n            self.score = self.class_score(logits, cast(int, self.class_idx))\n            self.model.zero_grad()\n            self.score.sum().backward(retain_graph=retain_graph)\n            for layer in self.target_layers:\n                if layer not in self.gradients:\n                    warnings.warn(\n                        f\"Backward hook for {layer} is not triggered; `requires_grad` of {layer} should be `True`.\"\n                    )\n            grad = tuple(self.gradients[layer] for layer in self.target_layers if layer in 
self.gradients)\n        if train:\n            self.model.train()\n        return logits, acti, grad\n\n    def get_wrapped_net(self):\n        return self.model\n\n\nclass CAMBase:\n    \"\"\"\n    Base class for CAM methods.\n    \"\"\"\n\n    def __init__(\n        self,\n        nn_module: nn.Module,\n        target_layers: str,\n        upsampler: Callable = default_upsampler,\n        postprocessing: Callable = default_normalizer,\n        register_backward: bool = True,\n    ) -> None:\n        self.nn_module: ModelWithHooks\n        # Convert to model with hooks if necessary\n        if not isinstance(nn_module, ModelWithHooks):\n            self.nn_module = ModelWithHooks(\n                nn_module, target_layers, register_forward=True, register_backward=register_backward\n            )\n        else:\n            self.nn_module = nn_module\n\n        self.upsampler = upsampler\n        self.postprocessing = postprocessing\n\n    def feature_map_size(self, input_size, device=\"cpu\", layer_idx=-1, **kwargs):\n        \"\"\"\n        Computes the actual feature map size given `nn_module` and the target_layer name.\n        Args:\n            input_size: shape of the input tensor\n            device: the device used to initialise the input tensor\n            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.\n            kwargs: any extra arguments to be passed on to the module as part of its `__call__`.\n        Returns:\n            shape of the actual feature map.\n        \"\"\"\n        return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx, **kwargs).shape\n\n    def compute_map(self, x, class_idx=None, layer_idx=-1):\n        \"\"\"\n        Compute the actual feature map with input tensor `x`.\n\n        Args:\n            x: input to `nn_module`.\n            class_idx: index of the class to be visualized. 
Default to `None` (computing `class_idx` from `argmax`)\n            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.\n\n        Returns:\n            activation maps (raw outputs without upsampling/post-processing.)\n        \"\"\"\n        raise NotImplementedError()\n\n    def _upsample_and_post_process(self, acti_map, x):\n        # upsampling and postprocessing\n        img_spatial = x.shape[2:]\n        acti_map = self.upsampler(img_spatial)(acti_map)\n        return self.postprocessing(acti_map)\n\n    def __call__(self):\n        raise NotImplementedError()\n\n\nclass CAM(CAMBase):\n    \"\"\"\n    Compute class activation map from the last fully-connected layers before the spatial pooling.\n    This implementation is based on:\n\n        Zhou et al., Learning Deep Features for Discriminative Localization. CVPR '16,\n        https://arxiv.org/abs/1512.04150\n\n    Examples\n\n    .. code-block:: python\n\n        import torch\n\n        # densenet 2d\n        from monai.networks.nets import DenseNet121\n        from monai.visualize import CAM\n\n        model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)\n        cam = CAM(nn_module=model_2d, target_layers=\"class_layers.relu\", fc_layers=\"class_layers.out\")\n        result = cam(x=torch.rand((1, 1, 48, 64)))\n\n        # resnet 2d\n        from monai.networks.nets import seresnet50\n        from monai.visualize import CAM\n\n        model_2d = seresnet50(spatial_dims=2, in_channels=3, num_classes=4)\n        cam = CAM(nn_module=model_2d, target_layers=\"layer4\", fc_layers=\"last_linear\")\n        result = cam(x=torch.rand((2, 3, 48, 64)))\n\n    N.B.: To help select the target layer, it may be useful to list all layers:\n\n    .. 
code-block:: python\n\n        for name, _ in model.named_modules(): print(name)\n\n    See Also:\n\n        - :py:class:`monai.visualize.class_activation_maps.GradCAM`\n\n    \"\"\"\n\n    def __init__(\n        self,\n        nn_module: nn.Module,\n        target_layers: str,\n        fc_layers: str | Callable = \"fc\",\n        upsampler: Callable = default_upsampler,\n        postprocessing: Callable = default_normalizer,\n    ) -> None:\n        \"\"\"\n        Args:\n            nn_module: the model to be visualized\n            target_layers: name of the model layer to generate the feature map.\n            fc_layers: a string or a callable used to get fully-connected weights to compute activation map\n                from the target_layers (without pooling).  and evaluate it at every spatial location.\n            upsampler: An upsampling method to upsample the output image. Default is\n                N dimensional linear (bilinear, trilinear, etc.) depending on num spatial\n                dimensions of input.\n            postprocessing: a callable that applies on the upsampled output image.\n                Default is normalizing between min=1 and max=0 (i.e., largest input will become 0 and\n                smallest input will become 1).\n        \"\"\"\n        super().__init__(\n            nn_module=nn_module,\n            target_layers=target_layers,\n            upsampler=upsampler,\n            postprocessing=postprocessing,\n            register_backward=False,\n        )\n        self.fc_layers = fc_layers\n\n    def compute_map(self, x, class_idx=None, layer_idx=-1, **kwargs):  # type: ignore[override]\n        logits, acti, _ = self.nn_module(x, **kwargs)\n        acti = acti[layer_idx]\n        if class_idx is None:\n            class_idx = logits.max(1)[-1]\n        b, c, *spatial = acti.shape\n        acti = torch.split(acti.reshape(b, c, -1), 1, dim=2)  # make the spatial dims 1D\n        fc_layers = 
self.nn_module.get_layer(self.fc_layers)\n        output = torch.stack([fc_layers(a[..., 0]) for a in acti], dim=2)\n        output = torch.stack([output[i, b : b + 1] for i, b in enumerate(class_idx)], dim=0)\n        return output.reshape(b, 1, *spatial)  # resume the spatial dims on the selected class\n\n    def __call__(self, x, class_idx=None, layer_idx=-1, **kwargs):  # type: ignore[override]\n        \"\"\"\n        Compute the activation map with upsampling and postprocessing.\n\n        Args:\n            x: input tensor, shape must be compatible with `nn_module`.\n            class_idx: index of the class to be visualized. Default to argmax(logits)\n            layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.\n            kwargs: any extra arguments to be passed on to the module as part of its `__call__`.\n\n        Returns:\n            activation maps\n        \"\"\"\n        acti_map = self.compute_map(x, class_idx, layer_idx, **kwargs)\n        return self._upsample_and_post_process(acti_map, x)\n\n\nclass GradCAM(CAMBase):\n    \"\"\"\n    Computes Gradient-weighted Class Activation Mapping (Grad-CAM).\n    This implementation is based on:\n\n        Selvaraju et al., Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization,\n        https://arxiv.org/abs/1610.02391\n\n    Examples\n\n    .. 
code-block:: python\n\n        import torch\n\n        # densenet 2d\n        from monai.networks.nets import DenseNet121\n        from monai.visualize import GradCAM\n\n        model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)\n        cam = GradCAM(nn_module=model_2d, target_layers=\"class_layers.relu\")\n        result = cam(x=torch.rand((1, 1, 48, 64)))\n\n        # resnet 2d\n        from monai.networks.nets import seresnet50\n        from monai.visualize import GradCAM\n\n        model_2d = seresnet50(spatial_dims=2, in_channels=3, num_classes=4)\n        cam = GradCAM(nn_module=model_2d, target_layers=\"layer4\")\n        result = cam(x=torch.rand((2, 3, 48, 64)))\n\n    N.B.: To help select the target layer, it may be useful to list all layers:\n\n    .. code-block:: python\n\n        for name, _ in model.named_modules(): print(name)\n\n    See Also:\n\n        - :py:class:`monai.visualize.class_activation_maps.CAM`\n\n    \"\"\"\n\n    def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs):  # type: ignore[override]\n        _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph, **kwargs)\n        acti, grad = acti[layer_idx], grad[layer_idx]\n        b, c, *spatial = grad.shape\n        weights = grad.view(b, c, -1).mean(2).view(b, c, *[1] * len(spatial))\n        acti_map = (weights * acti).sum(1, keepdim=True)\n        return F.relu(acti_map)\n\n    def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False, **kwargs):  # type: ignore[override]\n        \"\"\"\n        Compute the activation map with upsampling and postprocessing.\n\n        Args:\n            x: input tensor, shape must be compatible with `nn_module`.\n            class_idx: index of the class to be visualized. Default to argmax(logits)\n            layer_idx: index of the target layer if there are multiple target layers. 
Defaults to -1.\n            retain_graph: whether to retain_graph for torch module backward call.\n            kwargs: any extra arguments to be passed on to the module as part of its `__call__`.\n\n        Returns:\n            activation maps\n        \"\"\"\n        acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx, **kwargs)\n        return self._upsample_and_post_process(acti_map, x)\n\n\nclass GradCAMpp(GradCAM):\n    \"\"\"\n    Computes Gradient-weighted Class Activation Mapping (Grad-CAM++).\n    This implementation is based on:\n\n        Chattopadhyay et al., Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks,\n        https://arxiv.org/abs/1710.11063\n\n    See Also:\n\n        - :py:class:`monai.visualize.class_activation_maps.GradCAM`\n\n    \"\"\"\n\n    def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs):  # type: ignore[override]\n        _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph, **kwargs)\n        acti, grad = acti[layer_idx], grad[layer_idx]\n        b, c, *spatial = grad.shape\n        alpha_nr = grad.pow(2)\n        alpha_dr = alpha_nr.mul(2) + acti.mul(grad.pow(3)).view(b, c, -1).sum(-1).view(b, c, *[1] * len(spatial))\n        alpha_dr = torch.where(alpha_dr != 0.0, alpha_dr, torch.ones_like(alpha_dr))\n        alpha = alpha_nr.div(alpha_dr + 1e-7)\n        relu_grad = F.relu(cast(torch.Tensor, self.nn_module.score).exp() * grad)\n        weights = (alpha * relu_grad).view(b, c, -1).sum(-1).view(b, c, *[1] * len(spatial))\n        acti_map = (weights * acti).sum(1, keepdim=True)\n        return F.relu(acti_map)\n"
  },
  {
    "path": "monai/visualize/gradient_based.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom functools import partial\nfrom typing import Any, Callable\n\nimport torch\n\nfrom monai.networks.utils import replace_modules_temp\nfrom monai.utils.module import optional_import\nfrom monai.visualize.class_activation_maps import ModelWithHooks\n\ntrange, has_trange = optional_import(\"tqdm\", name=\"trange\")\n\n__all__ = [\"VanillaGrad\", \"SmoothGrad\", \"GuidedBackpropGrad\", \"GuidedBackpropSmoothGrad\"]\n\n\nclass _AutoGradReLU(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, x):\n        pos_mask = (x > 0).type_as(x)\n        output = torch.mul(x, pos_mask)\n        ctx.save_for_backward(x, output)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        x, _ = ctx.saved_tensors\n        pos_mask_1 = (x > 0).type_as(grad_output)\n        pos_mask_2 = (grad_output > 0).type_as(grad_output)\n        y = torch.mul(grad_output, pos_mask_1)\n        grad_input = torch.mul(y, pos_mask_2)\n        return grad_input\n\n\nclass _GradReLU(torch.nn.Module):\n    \"\"\"\n    A customized ReLU with the backward pass imputed for guided backpropagation (https://arxiv.org/abs/1412.6806).\n    \"\"\"\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        out: torch.Tensor = _AutoGradReLU.apply(x)\n        return out\n\n\nclass VanillaGrad:\n    \"\"\"\n    Given an input 
image ``x``, calling this class will perform the forward pass, then set to zero\n    all activations except one (defined by ``index``) and propagate back to the image to achieve a gradient-based\n    saliency map.\n\n    If ``index`` is None, argmax of the output logits will be used.\n\n    See also:\n\n        - Simonyan et al. Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps\n          (https://arxiv.org/abs/1312.6034)\n    \"\"\"\n\n    def __init__(self, model: torch.nn.Module) -> None:\n        if not isinstance(model, ModelWithHooks):  # Convert to model with hooks if necessary\n            self._model = ModelWithHooks(model, target_layer_names=(), register_backward=True)\n        else:\n            self._model = model\n\n    @property\n    def model(self):\n        return self._model.model\n\n    @model.setter\n    def model(self, m):\n        if not isinstance(m, ModelWithHooks):  # regular model as ModelWithHooks\n            self._model.model = m\n        else:\n            self._model = m  # replace the ModelWithHooks\n\n    def get_grad(\n        self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph: bool = True, **kwargs: Any\n    ) -> torch.Tensor:\n        if x.shape[0] != 1:\n            raise ValueError(\"expect batch size of 1\")\n        x.requires_grad = True\n\n        self._model(x, class_idx=index, retain_graph=retain_graph, **kwargs)\n        grad: torch.Tensor = x.grad.detach()  # type: ignore\n        return grad\n\n    def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs: Any) -> torch.Tensor:\n        return self.get_grad(x, index, **kwargs)\n\n\nclass SmoothGrad(VanillaGrad):\n    \"\"\"\n    Compute averaged sensitivity map based on ``n_samples`` (Gaussian additive) of noisy versions\n    of the input image ``x``.\n\n    See also:\n\n        - Smilkov et al. 
SmoothGrad: removing noise by adding noise https://arxiv.org/abs/1706.03825\n    \"\"\"\n\n    def __init__(\n        self,\n        model: torch.nn.Module,\n        stdev_spread: float = 0.15,\n        n_samples: int = 25,\n        magnitude: bool = True,\n        verbose: bool = True,\n    ) -> None:\n        super().__init__(model)\n        self.stdev_spread = stdev_spread\n        self.n_samples = n_samples\n        self.magnitude = magnitude\n        self.range: Callable\n        if verbose and has_trange:\n            self.range = partial(trange, desc=f\"Computing {self.__class__.__name__}\")\n        else:\n            self.range = range\n\n    def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs: Any) -> torch.Tensor:\n        stdev = (self.stdev_spread * (x.max() - x.min())).item()\n        total_gradients = torch.zeros_like(x)\n        for _ in self.range(self.n_samples):\n            # create noisy image\n            noise = torch.normal(0, stdev, size=x.shape, dtype=torch.float32, device=x.device)\n            x_plus_noise = x + noise\n            x_plus_noise = x_plus_noise.detach()\n\n            # get gradient and accumulate\n            grad = self.get_grad(x_plus_noise, index, **kwargs)\n            total_gradients += (grad * grad) if self.magnitude else grad\n\n        # average\n        if self.magnitude:\n            total_gradients = total_gradients**0.5\n\n        return total_gradients / self.n_samples\n\n\nclass GuidedBackpropGrad(VanillaGrad):\n    \"\"\"\n    Based on Springenberg and Dosovitskiy et al. https://arxiv.org/abs/1412.6806,\n    compute gradient-based saliency maps by backpropagating positive gradients and inputs (see ``_AutoGradReLU``).\n\n    See also:\n\n        - Springenberg and Dosovitskiy et al. 
Striving for Simplicity: The All Convolutional Net\n          (https://arxiv.org/abs/1412.6806)\n    \"\"\"\n\n    def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs: Any) -> torch.Tensor:\n        with replace_modules_temp(self.model, \"relu\", _GradReLU(), strict_match=False):\n            return super().__call__(x, index, **kwargs)\n\n\nclass GuidedBackpropSmoothGrad(SmoothGrad):\n    \"\"\"\n    Compute gradient-based saliency maps based on both ``GuidedBackpropGrad`` and ``SmoothGrad``.\n    \"\"\"\n\n    def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs: Any) -> torch.Tensor:\n        with replace_modules_temp(self.model, \"relu\", _GradReLU(), strict_match=False):\n            return super().__call__(x, index, **kwargs)\n"
  },
  {
    "path": "monai/visualize/img2tensorboard.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Any\n\nimport numpy as np\nimport torch\n\nfrom monai.config import NdarrayTensor\nfrom monai.transforms import rescale_array\nfrom monai.utils import convert_data_type, optional_import\n\nPIL, _ = optional_import(\"PIL\")\nGifImage, _ = optional_import(\"PIL.GifImagePlugin\", name=\"Image\")\n\nif TYPE_CHECKING:\n    from tensorboard.compat.proto.summary_pb2 import Summary\n    from tensorboardX import SummaryWriter as SummaryWriterX\n    from tensorboardX.proto.summary_pb2 import Summary as SummaryX\n    from torch.utils.tensorboard import SummaryWriter\n\n    has_tensorboardx = True\nelse:\n    Summary, _ = optional_import(\"tensorboard.compat.proto.summary_pb2\", name=\"Summary\")\n    SummaryX, _ = optional_import(\"tensorboardX.proto.summary_pb2\", name=\"Summary\")\n    SummaryWriter, _ = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n    SummaryWriterX, has_tensorboardx = optional_import(\"tensorboardX\", name=\"SummaryWriter\")\n\n__all__ = [\"make_animated_gif_summary\", \"add_animated_gif\", \"plot_2d_or_3d_image\"]\n\n\ndef _image3_animated_gif(\n    tag: str,\n    image: np.ndarray | torch.Tensor,\n    writer: SummaryWriter | SummaryWriterX | None,\n    frame_dim: int = 0,\n    scale_factor: float = 1.0,\n) -> Any:\n    \"\"\"Function to actually create the animated 
gif.\n\n    Args:\n        tag: Data identifier\n        image: 3D image tensors expected to be in `HWD` format\n        writer: the tensorboard writer to plot image\n        frame_dim: the dimension used as frames for GIF image, expect data shape as `HWD`, default to `0`.\n        scale_factor: amount to multiply values by. if the image data is between 0 and 1, using 255 for this value will\n            scale it to displayable range\n    \"\"\"\n    if len(image.shape) != 3:\n        raise AssertionError(\"3D image tensors expected to be in `HWD` format, len(image.shape) != 3\")\n\n    image_np, *_ = convert_data_type(image, output_type=np.ndarray)\n    ims = [(i * scale_factor).astype(np.uint8, copy=False) for i in np.moveaxis(image_np, frame_dim, 0)]\n    ims = [GifImage.fromarray(im) for im in ims]\n    img_str = b\"\"\n    for b_data in PIL.GifImagePlugin.getheader(ims[0])[0]:\n        img_str += b_data\n    img_str += b\"\\x21\\xff\\x0b\\x4e\\x45\\x54\\x53\\x43\\x41\\x50\" b\"\\x45\\x32\\x2e\\x30\\x03\\x01\\x00\\x00\\x00\"\n    for i in ims:\n        for b_data in PIL.GifImagePlugin.getdata(i):\n            img_str += b_data\n    img_str += b\"\\x3b\"\n\n    summary = SummaryX if has_tensorboardx and isinstance(writer, SummaryWriterX) else Summary\n    summary_image_str = summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str)\n    image_summary = summary.Value(tag=tag, image=summary_image_str)\n    return summary(value=[image_summary])\n\n\ndef make_animated_gif_summary(\n    tag: str,\n    image: np.ndarray | torch.Tensor,\n    writer: SummaryWriter | SummaryWriterX | None = None,\n    max_out: int = 3,\n    frame_dim: int = -3,\n    scale_factor: float = 1.0,\n) -> Summary:\n    \"\"\"Creates an animated gif out of an image tensor in 'CHWD' format and returns Summary.\n\n    Args:\n        tag: Data identifier\n        image: The image, expected to be in `CHWD` format\n        writer: the tensorboard writer to plot image\n        
max_out: maximum number of image channels to animate through\n        frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`,\n            default to `-3` (the first spatial dim)\n        scale_factor: amount to multiply values by.\n            if the image data is between 0 and 1, using 255 for this value will scale it to displayable range\n    \"\"\"\n\n    suffix = \"/image\" if max_out == 1 else \"/image/{}\"\n    # GIF image has no channel dim, reduce the spatial dim index if positive\n    frame_dim = frame_dim - 1 if frame_dim > 0 else frame_dim\n\n    summary_op = []\n    for it_i in range(min(max_out, list(image.shape)[0])):\n        one_channel_img: torch.Tensor | np.ndarray = (\n            image[it_i, :, :, :].squeeze(dim=0) if isinstance(image, torch.Tensor) else image[it_i, :, :, :]\n        )\n        summary_op.append(\n            _image3_animated_gif(tag + suffix.format(it_i), one_channel_img, writer, frame_dim, scale_factor)\n        )\n    return summary_op\n\n\ndef add_animated_gif(\n    writer: SummaryWriter | SummaryWriterX,\n    tag: str,\n    image_tensor: np.ndarray | torch.Tensor,\n    max_out: int = 3,\n    frame_dim: int = -3,\n    scale_factor: float = 1.0,\n    global_step: int | None = None,\n) -> None:\n    \"\"\"Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter.\n\n    Args:\n        writer: Tensorboard SummaryWriter to write to\n        tag: Data identifier\n        image_tensor: tensor for the image to add, expected to be in `CHWD` format\n        max_out: maximum number of image channels to animate through\n        frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`,\n            default to `-3` (the first spatial dim)\n        scale_factor: amount to multiply values by. 
If the image data is between 0 and 1, using 255 for this value will\n            scale it to displayable range\n        global_step: Global step value to record\n    \"\"\"\n    summary = make_animated_gif_summary(\n        tag=tag, image=image_tensor, writer=writer, max_out=max_out, frame_dim=frame_dim, scale_factor=scale_factor\n    )\n    for s in summary:\n        # add GIF for every channel separately\n        writer._get_file_writer().add_summary(s, global_step)\n\n\ndef plot_2d_or_3d_image(\n    data: NdarrayTensor | list[NdarrayTensor],\n    step: int,\n    writer: SummaryWriter | SummaryWriterX,\n    index: int = 0,\n    max_channels: int = 1,\n    frame_dim: int = -3,\n    max_frames: int = 24,\n    tag: str = \"output\",\n) -> None:\n    \"\"\"Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image.\n\n    Note:\n        Plot 3D or 2D image(with more than 3 channels) as separate images.\n        And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video.\n\n    Args:\n        data: target data to be plotted as image on the TensorBoard.\n            The data is expected to have 'NCHW[D]' dimensions or a list of data with `CHW[D]` dimensions,\n            and only plot the first in the batch.\n        step: current step to plot in a chart.\n        writer: specify TensorBoard or TensorBoardX SummaryWriter to plot the image.\n        index: plot which element in the input data batch, default is the first element.\n        max_channels: number of channels to plot.\n        frame_dim: if plotting 3D image as GIF, specify the dimension used as frames,\n            expect input data shape as `NCHWD`, default to `-3` (the first spatial dim)\n        max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`.\n        tag: tag of the plotted image on TensorBoard.\n    \"\"\"\n    data_index = data[index]\n    # as the `d` data has no batch dim, reduce the spatial dim 
index if positive\n    frame_dim = frame_dim - 1 if frame_dim > 0 else frame_dim\n\n    d: np.ndarray = (\n        data_index.detach().cpu().numpy() if isinstance(data_index, torch.Tensor) else np.asarray(data_index)\n    )\n\n    if d.ndim == 2:\n        d = rescale_array(d, 0, 1)  # type: ignore\n        dataformats = \"HW\"\n        writer.add_image(f\"{tag}_{dataformats}\", d, step, dataformats=dataformats)\n        return\n\n    if d.ndim == 3:\n        if d.shape[0] == 3 and max_channels == 3:  # RGB\n            dataformats = \"CHW\"\n            writer.add_image(f\"{tag}_{dataformats}\", d, step, dataformats=dataformats)\n            return\n        dataformats = \"HW\"\n        for j, d2 in enumerate(d[:max_channels]):\n            d2 = rescale_array(d2, 0, 1)\n            writer.add_image(f\"{tag}_{dataformats}_{j}\", d2, step, dataformats=dataformats)\n        return\n\n    if d.ndim >= 4:\n        spatial = d.shape[-3:]\n        d = d.reshape([-1] + list(spatial))\n        d_chans = d.shape[0]  # type: ignore\n        if d_chans == 3 and max_channels == 3 and has_tensorboardx and isinstance(writer, SummaryWriterX):  # RGB\n            # move the expected frame dim to the end as `T` dim for video\n            d = np.moveaxis(d, frame_dim, -1)\n            writer.add_video(tag, d[None], step, fps=max_frames, dataformats=\"NCHWT\")\n            return\n        # scale data to 0 - 255 for visualization\n        max_channels = min(max_channels, d_chans)\n        d = np.stack([rescale_array(i, 0, 255) for i in d[:max_channels]], axis=0)\n        # will plot every channel as a separate GIF image\n        add_animated_gif(writer, f\"{tag}_HWD\", d, max_out=max_channels, frame_dim=frame_dim, global_step=step)\n        return\n"
  },
  {
    "path": "monai/visualize/occlusion_sensitivity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Mapping, Sequence\nfrom typing import Any\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.networks.utils import eval_mode\nfrom monai.transforms import Compose, GaussianSmooth, Lambda, ScaleIntensity, SpatialCrop\nfrom monai.utils import ensure_tuple_rep\n\n\nclass OcclusionSensitivity:\n    \"\"\"\n    This class computes the occlusion sensitivity for a model's prediction of a given image. By occlusion sensitivity,\n    we mean how the probability of a given prediction changes as the occluded section of an image changes. This can be\n    useful to understand why a network is making certain decisions.\n\n    As important parts of the image are occluded, the probability of classifying the image correctly will decrease.\n    Hence, more negative values imply the corresponding occluded volume was more important in the decision process.\n\n    Two ``torch.Tensor`` will be returned by the ``__call__`` method: an occlusion map and an image of the most probable\n    class. Both images will be cropped if a bounding box used, but voxel sizes will always match the input.\n\n    The occlusion map shows the inference probabilities when the corresponding part of the image is occluded. 
Hence,\n    more -ve values imply that region was important in the decision process. The map will have shape ``BCHW(D)N``,\n    where ``N`` is the number of classes to be inferred by the network. Hence, the occlusion for class ``i`` can\n    be seen with ``map[...,i]``.\n\n    The most probable class is an image of the probable class when the corresponding part of the image is occluded\n    (equivalent to ``occ_map.argmax(dim=-1)``).\n\n    See: R. R. Selvaraju et al. Grad-CAM: Visual Explanations from Deep Networks via\n    Gradient-based Localization. https://doi.org/10.1109/ICCV.2017.74.\n\n    Examples:\n\n    .. code-block:: python\n\n        # densenet 2d\n        from monai.networks.nets import DenseNet121\n        from monai.visualize import OcclusionSensitivity\n        import torch\n\n        model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)\n        occ_sens = OcclusionSensitivity(nn_module=model_2d)\n        occ_map, most_probable_class = occ_sens(x=torch.rand((1, 1, 48, 64)), b_box=[2, 40, 1, 62])\n\n        # densenet 3d\n        from monai.networks.nets import DenseNet\n        from monai.visualize import OcclusionSensitivity\n\n        model_3d = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,))\n        occ_sens = OcclusionSensitivity(nn_module=model_3d, n_batch=10)\n        occ_map, most_probable_class = occ_sens(torch.rand(1, 1, 6, 6, 6), b_box=[1, 3, -1, -1, -1, -1])\n\n    See Also:\n\n        - :py:class:`monai.visualize.occlusion_sensitivity.OcclusionSensitivity.`\n    \"\"\"\n\n    def __init__(\n        self,\n        nn_module: nn.Module,\n        mask_size: int | Sequence = 16,\n        n_batch: int = 16,\n        verbose: bool = True,\n        mode: str | float | Callable = \"gaussian\",\n        overlap: float = 0.6,\n        activate: bool | Callable = True,\n    ) -> None:\n        \"\"\"\n        Occlusion sensitivity constructor.\n\n        Args:\n          
  nn_module: Classification model to use for inference\n            mask_size: Size of box to be occluded, centred on the central voxel. If a single number\n                is given, this is used for all dimensions. If a sequence is given, this is used for each dimension\n                individually.\n            n_batch: Number of images in a batch for inference.\n            verbose: Use progress bar (if ``tqdm`` available).\n            mode: what should the occluded region be replaced with? If a float is given, that value will be used\n                throughout the occlusion. Else, ``gaussian``, ``mean_img`` and ``mean_patch`` can be supplied:\n\n                * ``gaussian``: occluded region is multiplied by 1 - gaussian kernel. In this fashion, the occlusion\n                  will be 0 at the center and will be unchanged towards the edges, varying smoothly between. When\n                  gaussian is used, a weighted average will be used to combine overlapping regions. This will be\n                  done using the gaussian (not 1-gaussian) as occluded regions count more.\n                * ``mean_patch``: occluded region will be replaced with the mean of occluded region.\n                * ``mean_img``: occluded region will be replaced with the mean of the whole image.\n\n            overlap: overlap between inferred regions. Should be in range 0<=x<1.\n            activate: if ``True``, do softmax activation if num_channels > 1 else do ``sigmoid``. If ``False``, don't do any\n                activation. 
If ``callable``, use callable on inferred outputs.\n\n        \"\"\"\n        self.nn_module = nn_module\n        self.mask_size = mask_size\n        self.n_batch = n_batch\n        self.verbose = verbose\n        self.overlap = overlap\n        self.activate = activate\n        # mode\n        if isinstance(mode, str) and mode not in (\"gaussian\", \"mean_patch\", \"mean_img\"):\n            raise NotImplementedError\n        self.mode = mode\n\n    @staticmethod\n    def constant_occlusion(x: torch.Tensor, val: float, mask_size: Sequence) -> tuple[float, torch.Tensor]:\n        \"\"\"Occlude with a constant occlusion. Multiplicative is zero, additive is constant value.\"\"\"\n        ones = torch.ones((*x.shape[:2], *mask_size), device=x.device, dtype=x.dtype)\n        return 0, ones * val\n\n    @staticmethod\n    def gaussian_occlusion(x: torch.Tensor, mask_size: Sequence, sigma: float = 0.25) -> tuple[torch.Tensor, float]:\n        \"\"\"\n        For Gaussian occlusion, Multiplicative is 1-Gaussian, additive is zero.\n        Default sigma of 0.25 empirically shown to give reasonable kernel, see here:\n        https://github.com/Project-MONAI/MONAI/pull/5230#discussion_r984520714.\n        \"\"\"\n        kernel = torch.zeros((x.shape[1], *mask_size), device=x.device, dtype=x.dtype)\n        spatial_shape = kernel.shape[1:]\n        # all channels (as occluded shape already takes into account per_channel), center in spatial dimensions\n        center = [slice(None)] + [slice(s // 2, s // 2 + 1) for s in spatial_shape]\n        # place value of 1 at center\n        kernel[center] = 1.0\n        # Smooth with sigma equal to quarter of image, flip +ve/-ve so largest values are at edge\n        # and smallest at center. 
Scale to [0, 1].\n        gaussian = Compose(\n            [GaussianSmooth(sigma=[b * sigma for b in spatial_shape]), Lambda(lambda x: -x), ScaleIntensity()]\n        )\n        # transform and add batch\n        mul: torch.Tensor = gaussian(kernel)[None]\n\n        return mul, 0\n\n    @staticmethod\n    def predictor(\n        cropped_grid: torch.Tensor,\n        nn_module: nn.Module,\n        x: torch.Tensor,\n        mul: torch.Tensor | float,\n        add: torch.Tensor | float,\n        mask_size: Sequence,\n        occ_mode: str,\n        activate: bool | Callable,\n        module_kwargs: Mapping[str, Any],\n    ) -> torch.Tensor:\n        \"\"\"\n        Predictor function to be passed to the sliding window inferer. Takes a cropped meshgrid,\n        referring to the coordinates in the input image. We use the index of the top-left corner\n        in combination ``mask_size`` to figure out which region of the image is to be occluded. The\n        occlusion is performed on the original image, ``x``, using ``cropped_region * mul + add``. ``mul``\n        and ``add`` are sometimes pre-computed (e.g., a constant Gaussian blur), or they are\n        sometimes calculated on the fly (e.g., the mean of the occluded patch). For this reason\n        ``occ_mode`` is given. Lastly, ``activate`` is used to activate after each call of the model.\n\n        Args:\n            cropped_grid: subsection of the meshgrid, where each voxel refers to the coordinate of\n                the input image. The meshgrid is created by the ``OcclusionSensitivity`` class, and\n                the generation of the subset is determined by ``sliding_window_inference``.\n            nn_module: module to call on data.\n            x: the image that was originally passed into ``OcclusionSensitivity.__call__``.\n            mul: occluded region will be multiplied by this. Can be ``torch.Tensor`` or ``float``.\n            add: after multiplication, this is added to the occluded region. 
Can be ``torch.Tensor`` or ``float``.\n            mask_size: Size of box to be occluded, centred on the central voxel. Should be\n                a sequence, one value for each spatial dimension.\n            occ_mode: might be used to calculate ``mul`` and ``add`` on the fly.\n            activate: if ``True``, do softmax activation if num_channels > 1 else do ``sigmoid``. If ``False``, don't do any\n                activation. If ``callable``, use callable on inferred outputs.\n            module_kwargs: kwargs to be passed onto module when inferring\n        \"\"\"\n        n_batch = cropped_grid.shape[0]\n        sd = cropped_grid.ndim - 2\n        # start with copies of x to infer\n        im = torch.repeat_interleave(x, n_batch, 0)\n        # get coordinates of top left corner of occluded region (possible because we use meshgrid)\n        corner_coord_slices = [slice(None)] * 2 + [slice(1)] * sd\n        top_corners = cropped_grid[corner_coord_slices]\n\n        # replace occluded regions\n        for b, t in enumerate(top_corners):\n            # starting from corner, get the slices to extract the occluded region from the image\n            slices = [slice(b, b + 1), slice(None)] + [slice(int(j), int(j) + m) for j, m in zip(t, mask_size)]\n            to_occlude = im[slices]\n            if occ_mode == \"mean_patch\":\n                add, mul = OcclusionSensitivity.constant_occlusion(x, to_occlude.mean().item(), mask_size)\n\n            if callable(occ_mode):\n                to_occlude = occ_mode(x, to_occlude)\n            else:\n                to_occlude = to_occlude * mul + add\n            if add is None or mul is None:\n                raise RuntimeError(\"Shouldn't be here, something's gone wrong...\")\n            im[slices] = to_occlude\n        # infer\n        out: torch.Tensor = nn_module(im, **module_kwargs)\n\n        # if activation is callable, call it\n        if callable(activate):\n            out = activate(out)\n        # else if 
True (should be boolean), sigmoid if n_chan == 1 else softmax\n        elif activate:\n            out = out.sigmoid() if x.shape[1] == 1 else out.softmax(1)\n\n        # the output will have shape [B,C] where C is number of channels output by model (inference classes)\n        # we need to return it to sliding window inference with shape [B,C,H,W,[D]], so add dims and repeat values\n        for m in mask_size:\n            out = torch.repeat_interleave(out.unsqueeze(-1), m, dim=-1)\n\n        return out\n\n    @staticmethod\n    def crop_meshgrid(\n        grid: MetaTensor, b_box: Sequence, mask_size: Sequence\n    ) -> tuple[MetaTensor, SpatialCrop, Sequence]:\n        \"\"\"Crop the meshgrid so we only perform occlusion sensitivity on a subsection of the image.\"\"\"\n        # distance from center of mask to edge is -1 // 2.\n        mask_edge = [(m - 1) // 2 for m in mask_size]\n        bbox_min = [max(b - m, 0) for b, m in zip(b_box[::2], mask_edge)]\n        bbox_max = []\n        for b, m, s in zip(b_box[1::2], mask_edge, grid.shape[2:]):\n            # if bbox is -ve for that dimension, no cropping so use current image size\n            if b == -1:\n                bbox_max.append(s)\n            # else bounding box plus distance to mask edge. Make sure it's not bigger than the size of the image\n            else:\n                bbox_max.append(min(b + m, s))\n        # bbox_max = [min(b + m, s) if b >= 0 else s for b, m, s in zip(b_box[1::2], mask_edge, grid.shape[2:])]\n        # No need for batch and channel slices. 
Batch will be removed and added back in, and\n        # SpatialCrop doesn't act on the first dimension anyway.\n        slices = [slice(s, e) for s, e in zip(bbox_min, bbox_max)]\n        cropper = SpatialCrop(roi_slices=slices)\n        cropped: MetaTensor = cropper(grid[0])[None]  # type: ignore\n        mask_size = list(mask_size)\n        for i, s in enumerate(cropped.shape[2:]):\n            mask_size[i] = min(s, mask_size[i])\n        return cropped, cropper, mask_size\n\n    def __call__(\n        self, x: torch.Tensor, b_box: Sequence | None = None, **kwargs: Any\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Args:\n            x: Image to use for inference. Should be a tensor consisting of 1 batch.\n            b_box: Bounding box on which to perform the analysis. The output image will be limited to this size.\n                There should be a minimum and maximum for all spatial dimensions: ``[min1, max1, min2, max2,...]``.\n                * By default, the whole image will be used. 
Decreasing the size will speed the analysis up, which might\n                    be useful for larger images.\n                * Min and max are inclusive, so ``[0, 63, ...]`` will have size ``(64, ...)``.\n                * Use -ve to use ``min=0`` and ``max=im.shape[x]-1`` for xth dimension.\n                * N.B.: we add half of the mask size to the bounding box to ensure that the region of interest has a\n                    sufficiently large area surrounding it.\n            kwargs: any extra arguments to be passed on to the module as part of its `__call__`.\n\n        Returns:\n            * Occlusion map:\n                * Shows the inference probabilities when the corresponding part of the image is occluded.\n                    Hence, more -ve values imply that region was important in the decision process.\n                * The map will have shape ``BCHW(D)N``, where N is the number of classes to be inferred by the\n                    network. Hence, the occlusion for class ``i`` can be seen with ``map[...,i]``.\n                * If `per_channel==False`, output ``C`` will equal 1: ``B1HW(D)N``\n            * Most probable class:\n                * The most probable class when the corresponding part of the image is occluded (``argmax(dim=-1)``).\n            Both images will be cropped if a bounding box used, but voxel sizes will always match the input.\n        \"\"\"\n        if x.shape[0] > 1:\n            raise ValueError(\"Expected batch size of 1.\")\n\n        sd = x.ndim - 2\n        mask_size: Sequence = ensure_tuple_rep(self.mask_size, sd)\n\n        # get the meshgrid (so that sliding_window_inference can tell us which bit to occlude)\n        grid: MetaTensor = MetaTensor(\n            np.stack(np.meshgrid(*[np.arange(0, i) for i in x.shape[2:]], indexing=\"ij\"))[None],\n            device=x.device,\n            dtype=x.dtype,\n        )\n        # if bounding box given, crop the grid to only infer subsections of the image\n        if 
b_box is not None:\n            grid, cropper, mask_size = self.crop_meshgrid(grid, b_box, mask_size)\n\n        # check that the grid is bigger than the mask size\n        if any(m > g for g, m in zip(grid.shape[2:], mask_size)):\n            raise ValueError(f\"Image (spatial shape) {grid.shape[2:]} should be bigger than mask {mask_size}.\")\n\n        # get additive and multiplicative factors if they are unchanged for all patches (i.e., not mean_patch)\n        add: float | torch.Tensor | None\n        mul: float | torch.Tensor | None\n        # multiply by 0, add value\n        if isinstance(self.mode, float):\n            mul, add = self.constant_occlusion(x, self.mode, mask_size)\n        # multiply by 0, add mean of image\n        elif self.mode == \"mean_img\":\n            mul, add = self.constant_occlusion(x, x.mean().item(), mask_size)\n        # for gaussian, additive = 0, multiplicative = gaussian\n        elif self.mode == \"gaussian\":\n            mul, add = self.gaussian_occlusion(x, mask_size)\n        # else will be determined on each patch individually so calculated later\n        else:\n            add, mul = None, None\n\n        with eval_mode(self.nn_module):\n            # needs to go here to avoid circular import\n            from monai.inferers import sliding_window_inference\n\n            sensitivity_im: MetaTensor = sliding_window_inference(  # type: ignore\n                grid,\n                roi_size=mask_size,\n                sw_batch_size=self.n_batch,\n                predictor=OcclusionSensitivity.predictor,\n                overlap=self.overlap,\n                mode=\"gaussian\" if self.mode == \"gaussian\" else \"constant\",\n                progress=self.verbose,\n                nn_module=self.nn_module,\n                x=x,\n                add=add,\n                mul=mul,\n                mask_size=mask_size,\n                occ_mode=self.mode,\n                activate=self.activate,\n                
module_kwargs=kwargs,\n            )\n\n        if b_box is not None:\n            # undo the cropping that was applied to the meshgrid\n            sensitivity_im = cropper.inverse(sensitivity_im[0])[None]  # type: ignore\n            # crop using the bounding box (ignoring the mask size this time)\n            bbox_min = [max(b, 0) for b in b_box[::2]]\n            bbox_max = [b if b > 0 else s for b, s in zip(b_box[1::2], x.shape[2:])]\n            cropper = SpatialCrop(roi_start=bbox_min, roi_end=bbox_max)\n            sensitivity_im = cropper(sensitivity_im[0])[None]  # type: ignore\n\n        # The most probable class is the max in the classification dimension (1)\n        most_probable_class = sensitivity_im.argmax(dim=1, keepdim=True)\n        return sensitivity_im, most_probable_class\n"
  },
  {
    "path": "monai/visualize/utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom typing import TYPE_CHECKING, Any\n\nimport numpy as np\nimport torch\n\nfrom monai.config.type_definitions import DtypeLike, NdarrayOrTensor\nfrom monai.transforms.croppad.array import SpatialPad\nfrom monai.transforms.utils import rescale_array\nfrom monai.transforms.utils_pytorch_numpy_unification import repeat\nfrom monai.utils.module import optional_import\nfrom monai.utils.type_conversion import convert_data_type, convert_to_dst_type\n\nif TYPE_CHECKING:\n    from matplotlib import pyplot as plt\nelse:\n    plt, _ = optional_import(\"matplotlib\", name=\"pyplot\")\n\n__all__ = [\"matshow3d\", \"blend_images\"]\n\n\ndef matshow3d(\n    volume: NdarrayOrTensor,\n    fig: Any = None,\n    title: str | None = None,\n    figsize: tuple[int, int] = (10, 10),\n    frames_per_row: int | None = None,\n    frame_dim: int = -3,\n    channel_dim: int | None = None,\n    vmin: float | None = None,\n    vmax: float | None = None,\n    every_n: int = 1,\n    interpolation: str = \"none\",\n    show: bool = False,\n    fill_value: Any = np.nan,\n    margin: int = 1,\n    dtype: DtypeLike = np.float32,\n    **kwargs: Any,\n) -> tuple[Any, np.ndarray]:\n    \"\"\"\n    Create a 3D volume figure as a grid of images.\n\n    Args:\n        volume: 3D volume to display. 
data shape can be `BCHWD`, `CHWD` or `HWD`.\n            Higher dimensional arrays will be reshaped into (-1, H, W, [C]), `C` depends on `channel_dim` arg.\n            A list of channel-first (C, H[, W, D]) arrays can also be passed in,\n            in which case they will be displayed as a padded and stacked volume.\n        fig: matplotlib figure or Axes to use. If None, a new figure will be created.\n        title: title of the figure.\n        figsize: size of the figure.\n        frames_per_row: number of frames to display in each row. If None, sqrt(firstdim) will be used.\n        frame_dim: for higher dimensional arrays, which dimension from (`-1`, `-2`, `-3`) is moved to\n            the `-3` dimension. dim and reshape to (-1, H, W) shape to construct frames, default to `-3`.\n        channel_dim: if not None, explicitly specify the channel dimension to be transposed to the\n            last dimensionas shape (-1, H, W, C). this can be used to plot RGB color image.\n            if None, the channel dimension will be flattened with `frame_dim` and `batch_dim` as shape (-1, H, W).\n            note that it can only support 3D input image. 
default is None.\n        vmin: `vmin` for the matplotlib `imshow`.\n        vmax: `vmax` for the matplotlib `imshow`.\n        every_n: factor to subsample the frames so that only every n-th frame is displayed.\n        interpolation: interpolation to use for the matplotlib `matshow`.\n        show: if True, show the figure.\n        fill_value: value to use for the empty part of the grid.\n        margin: margin to use for the grid.\n        dtype: data type of the output stacked frames.\n        kwargs: additional keyword arguments to matplotlib `matshow` and `imshow`.\n\n    See Also:\n        - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html\n        - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.matshow.html\n\n    Example:\n\n        >>> import numpy as np\n        >>> import matplotlib.pyplot as plt\n        >>> from monai.visualize import matshow3d\n        # create a figure of a 3D volume\n        >>> volume = np.random.rand(10, 10, 10)\n        >>> fig = plt.figure()\n        >>> matshow3d(volume, fig=fig, title=\"3D Volume\")\n        >>> plt.show()\n        # create a figure of a list of channel-first 3D volumes\n        >>> volumes = [np.random.rand(1, 10, 10, 10), np.random.rand(1, 10, 10, 10)]\n        >>> fig = plt.figure()\n        >>> matshow3d(volumes, fig=fig, title=\"List of Volumes\")\n        >>> plt.show()\n\n    \"\"\"\n    vol = convert_data_type(data=volume, output_type=np.ndarray)[0]\n    if channel_dim is not None:\n        if channel_dim not in [0, 1] or vol.shape[channel_dim] not in [1, 3, 4]:\n            raise ValueError(\"channel_dim must be: None, 0 or 1, and channels of image must be 1, 3 or 4.\")\n\n    if isinstance(vol, (list, tuple)):\n        # a sequence of channel-first volumes\n        if not isinstance(vol[0], np.ndarray):\n            raise ValueError(\"volume must be a list of arrays.\")\n        pad_size = np.max(np.asarray([v.shape for v in vol]), axis=0)\n        pad = 
SpatialPad(pad_size[1:])  # assuming channel-first for item in vol\n        vol = np.concatenate([pad(v) for v in vol], axis=0)\n    else:  # ndarray\n        while len(vol.shape) < 3:\n            vol = np.expand_dims(vol, 0)  # type: ignore  # so that we display 2d as well\n\n    if channel_dim is not None:  # move the expected dim to construct frames with `B` dim\n        vol = np.moveaxis(vol, frame_dim, -4)  # type: ignore\n        vol = vol.reshape((-1, vol.shape[-3], vol.shape[-2], vol.shape[-1]))  # type: ignore[assignment]\n    else:\n        vol = np.moveaxis(vol, frame_dim, -3)  # type: ignore\n        vol = vol.reshape((-1, vol.shape[-2], vol.shape[-1]))  # type: ignore[assignment]\n    vmin = np.nanmin(vol) if vmin is None else vmin\n    vmax = np.nanmax(vol) if vmax is None else vmax\n\n    # subsample every_n-th frame of the 3D volume\n    vol = vol[:: max(every_n, 1)]  # type: ignore[assignment]\n    if not frames_per_row:\n        frames_per_row = int(np.ceil(np.sqrt(len(vol))))\n    # create the grid of frames\n    cols = max(min(len(vol), frames_per_row), 1)\n    rows = int(np.ceil(len(vol) / cols))\n    width = [[0, cols * rows - len(vol)]]\n    if channel_dim is not None:\n        width += [[0, 0]]  # add pad width for the channel dim\n    width += [[margin, margin]] * 2\n    vol = np.pad(vol.astype(dtype, copy=False), width, mode=\"constant\", constant_values=fill_value)  # type: ignore\n    im = np.block([[vol[i * cols + j] for j in range(cols)] for i in range(rows)])\n    if channel_dim is not None:\n        # move channel dim to the end\n        im = np.moveaxis(im, 0, -1)\n\n    # figure related configurations\n    if isinstance(fig, plt.Axes):\n        ax = fig\n    else:\n        if fig is None:\n            fig = plt.figure(tight_layout=True)\n        if not fig.axes:\n            fig.add_subplot(111)\n        ax = fig.axes[0]\n    ax.matshow(im, vmin=vmin, vmax=vmax, interpolation=interpolation, **kwargs)\n    ax.axis(\"off\")\n\n    
if title is not None:\n        ax.set_title(title)\n    if figsize is not None and hasattr(fig, \"set_size_inches\"):\n        fig.set_size_inches(figsize)\n    if show:\n        plt.show()\n    return fig, im\n\n\ndef blend_images(\n    image: NdarrayOrTensor,\n    label: NdarrayOrTensor,\n    alpha: float | NdarrayOrTensor = 0.5,\n    cmap: str = \"hsv\",\n    rescale_arrays: bool = True,\n    transparent_background: bool = True,\n) -> NdarrayOrTensor:\n    \"\"\"\n    Blend an image and a label. Both should have the shape CHW[D].\n    The image may have C==1 or 3 channels (greyscale or RGB).\n    The label is expected to have C==1.\n\n    Args:\n        image: the input image to blend with label data.\n        label: the input label to blend with image data.\n        alpha: this specifies the weighting given to the label, where 0 is completely\n            transparent and 1 is completely opaque. This can be given as either a\n            single value or an array/tensor that is the same size as the input image.\n        cmap: specify colormap in the matplotlib, default to `hsv`, for more details, please refer to:\n            https://matplotlib.org/2.0.2/users/colormaps.html.\n        rescale_arrays: whether to rescale the array to [0, 1] first, default to `True`.\n        transparent_background: if true, any zeros in the label field will not be colored.\n\n    .. 
image:: ../../docs/images/blend_images.png\n\n    \"\"\"\n\n    if label.shape[0] != 1:\n        raise ValueError(\"Label should have 1 channel.\")\n    if image.shape[0] not in (1, 3):\n        raise ValueError(\"Image should have 1 or 3 channels.\")\n    if image.shape[1:] != label.shape[1:]:\n        raise ValueError(\"image and label should have matching spatial sizes.\")\n    if isinstance(alpha, (np.ndarray, torch.Tensor)):\n        if image.shape[1:] != alpha.shape[1:]:  # pytype: disable=attribute-error,invalid-directive\n            raise ValueError(\"if alpha is image, size should match input image and label.\")\n\n    # rescale arrays to [0, 1] if desired\n    if rescale_arrays:\n        image = rescale_array(image)\n        label = rescale_array(label)\n    # convert image to rgb (if necessary) and then rgba\n    if image.shape[0] == 1:\n        image = repeat(image, 3, axis=0)\n\n    def get_label_rgb(cmap: str, label: NdarrayOrTensor) -> NdarrayOrTensor:\n        _cmap = plt.colormaps.get_cmap(cmap)\n        label_np, *_ = convert_data_type(label, np.ndarray)\n        label_rgb_np = _cmap(label_np[0])\n        label_rgb_np = np.moveaxis(label_rgb_np, -1, 0)[:3]\n        label_rgb, *_ = convert_to_dst_type(label_rgb_np, label)\n        return label_rgb\n\n    label_rgb = get_label_rgb(cmap, label)\n    if isinstance(alpha, (torch.Tensor, np.ndarray)):\n        w_label = alpha\n    elif isinstance(label, torch.Tensor):\n        w_label = torch.full_like(label, alpha)\n    else:\n        w_label = np.full_like(label, alpha)\n    if transparent_background:\n        # where label == 0 (background), set label alpha to 0\n        w_label[label == 0] = 0  # pytype: disable=unsupported-operands\n\n    w_image = 1 - w_label\n    return w_image * image + w_label * label_rgb\n"
  },
  {
    "path": "monai/visualize/visualizer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom collections.abc import Callable, Sized\n\nimport torch\nimport torch.nn.functional as F\n\nfrom monai.utils import InterpolateMode\n\n__all__ = [\"default_upsampler\"]\n\n\ndef default_upsampler(spatial_size: Sized, align_corners: bool = False) -> Callable[[torch.Tensor], torch.Tensor]:\n    \"\"\"\n    A linear interpolation method for upsampling the feature map.\n    The output of this function is a callable `func`,\n    such that `func(x)` returns an upsampled tensor.\n    \"\"\"\n\n    def up(x):\n        linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR]\n        interp_mode = linear_mode[len(spatial_size) - 1]\n        return F.interpolate(x, size=spatial_size, mode=str(interp_mode.value), align_corners=align_corners)\n\n    return up\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\n  \"wheel\",\n  \"setuptools\",\n  \"more-itertools>=8.0\",\n  \"torch>=2.4.1\",\n  \"ninja\",\n  \"packaging\"\n]\n\n[tool.black]\nline-length = 120\ntarget-version = ['py39', 'py310', 'py311', 'py312']\ninclude = '\\.pyi?$'\nexclude = '''\n(\n  /(\n    # exclude a few common directories in the root of the project\n      \\.eggs\n    | \\.git\n    | \\.hg\n    | \\.mypy_cache\n    | \\.tox\n    | \\.venv\n    | venv\n    | \\.pytype\n    | _build\n    | buck-out\n    | build\n    | dist\n  )/\n  # also separately exclude a file named versioneer.py\n  | monai/_version.py\n)\n'''\n\n[tool.pycln]\nall = true\nexclude = \"monai/bundle/__main__.py\"\n\n[tool.ruff]\nline-length = 120\ntarget-version = \"py39\"\n\n[tool.ruff.lint]\nselect = [\n  \"B\",  # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b\n  \"C90\",  # mccabe (complexity) - https://docs.astral.sh/ruff/rules/#mccabe-c90\n  \"E\",  # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e\n  \"F\",  # pyflakes - https://docs.astral.sh/ruff/rules/#pyflakes-f\n  \"N\",  # pep8-naming - https://docs.astral.sh/ruff/rules/#pep8-naming-n\n  \"PIE\",  # flake8-pie - https://docs.astral.sh/ruff/rules/#flake8-pie-pie\n  \"TID\", # flake8-tidy-imports - https://docs.astral.sh/ruff/rules/#flake8-tidy-imports-tid\n  \"W\",  # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w\n  \"NPY\",         # NumPy specific rules\n  \"UP\",          # pyupgrade\n  \"RUF100\",      # aka yesqa\n]\nextend-ignore = [\n  \"E741\", # ambiguous variable name\n  \"F401\", # unused import\n  \"NPY002\", # numpy-legacy-random\n  \"E203\", # whitespace before ':' (pycodestyle)\n  \"E501\", # line too long (pycodestyle)\n  \"C408\", # unnecessary collection call (flake8-comprehensions)\n  \"N812\", # lowercase imported as non lowercase (pep8-naming)\n  \"B023\", # function uses loop variable (flake8-bugbear)\n  \"B905\", # zip() without an explicit 
strict= parameter (flake8-bugbear)\n  \"B028\", # no explicit stacklevel keyword argument found (flake8-bugbear)\n]\n\n[tool.ruff.lint.per-file-ignores]\n\"tests/**\" = [\n  \"B018\",\n  \"C901\",\n  \"N999\",\n  \"N801\"\n]\n\"monai/apps/detection/utils/ATSS_matcher.py\" = [\n  \"N999\"\n]\n\n[tool.ruff.lint.mccabe]\nmax-complexity = 50  # todo lower this threshold when yesqa is replaced with Ruff's RUF100\n\n[tool.pytype]\n# Space-separated list of files or directories to exclude.\nexclude = [\"versioneer.py\", \"_version.py\"]\n# Space-separated list of files or directories to process.\ninputs = [\"monai\"]\n# Keep going past errors to analyze as many files as possible.\nkeep_going = true\n# Run N jobs in parallel.\njobs = 8\n# All pytype output goes here.\noutput = \".pytype\"\n# Paths to source code directories, separated by ':'.\npythonpath = \".\"\n# Check attribute values against their annotations.\ncheck_attribute_types = true\n# Check container mutations against their annotations.\ncheck_container_types = true\n# Check parameter defaults and assignments against their annotations.\ncheck_parameter_types = true\n# Check variable values against their annotations.\ncheck_variable_types = true\n# Comma or space separated list of error names to ignore.\ndisable = [\"pyi-error\"]\n# Report errors.\nreport_errors = true\n# Experimental: Infer precise return types even for invalid function calls.\nprecise_return = true\n# Experimental: solve unknown types to label with structural types.\nprotocols = true\n# Experimental: Only load submodules that are explicitly imported.\nstrict_import = false\n"
  },
  {
    "path": "requirements-dev.txt",
    "content": "# Full requirements for developments\n-r requirements-min.txt\npytorch-ignite\ngdown>=4.7.3\nscipy>=1.12.0; python_version >= '3.9'\nitk>=5.2\nnibabel\npillow!=8.3.0  # https://github.com/python-pillow/Pillow/issues/5571\ntensorboard>=2.12.0  # https://github.com/Project-MONAI/MONAI/issues/7434\nscikit-image>=0.19.0\ntqdm>=4.47.0\nlmdb\nmccabe\npep8-naming\npycodestyle\npyflakes\nblack>=25.1.0\nisort>=5.1, !=6.0.0\nruff\npytype>=2020.6.1, <=2024.4.11; platform_system != \"Windows\"\ntypes-setuptools\nmypy>=1.5.0, <1.12.0\nninja\ntorchio\ntorchvision\npsutil\ncucim-cu12; platform_system == \"Linux\" and python_version >= \"3.9\" and python_version <= \"3.10\"\nopenslide-python\nopenslide-bin\nimagecodecs; platform_system == \"Linux\" or platform_system == \"Darwin\"\ntifffile; platform_system == \"Linux\" or platform_system == \"Darwin\"\npandas\nrequests\neinops\ntransformers>=4.53.0\nmlflow>=2.12.2\nclearml>=1.10.0rc0\nmatplotlib>=3.6.3\ntensorboardX\ntypes-PyYAML\npyyaml\nfire\njsonschema\npynrrd\npre-commit\npydicom\nh5py\nnni==2.10.1; platform_system == \"Linux\" and \"arm\" not in platform_machine and \"aarch\" not in platform_machine\noptuna\ngit+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded\nonnx>=1.13.0\nonnxscript\nonnxruntime\ntypeguard<3  # https://github.com/microsoft/nni/issues/5457\nfilelock<3.12.0  # https://github.com/microsoft/nni/issues/5523\nzarr\nlpips==0.1.4\nnvidia-ml-py\nhuggingface_hub\npyamg>=5.0.0, <5.3.0\ngit+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588\nonnx_graphsurgeon\npolygraphy\n"
  },
  {
    "path": "requirements-min.txt",
    "content": "# Requirements for minimal tests\n-r requirements.txt\nsetuptools>=50.3.0,<66.0.0,!=60.6.0 ; python_version < \"3.12\"\nsetuptools>=70.2.0,<=79.0.1; python_version >= \"3.12\"\ncoverage>=5.5\nparameterized\npackaging\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch>=2.4.1; platform_system != \"Windows\"\ntorch>=2.4.1, !=2.7.0; platform_system == \"Windows\"\nnumpy>=1.24,<3.0\n"
  },
  {
    "path": "runtests.sh",
    "content": "#! /bin/bash\n\n# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# script for running all tests\nset -e\n\n# output formatting\nseparator=\"\"\nblue=\"\"\ngreen=\"\"\nred=\"\"\nnoColor=\"\"\n\nif [[ -t 1 ]] # stdout is a terminal\nthen\n    separator=$'--------------------------------------------------------------------------------\\n'\n    blue=\"$(tput bold; tput setaf 4)\"\n    green=\"$(tput bold; tput setaf 2)\"\n    red=\"$(tput bold; tput setaf 1)\"\n    noColor=\"$(tput sgr0)\"\nfi\n\n# configuration values\ndoCoverage=false\ndoQuickTests=false\ndoMinTests=false\ndoNetTests=false\ndoDryRun=false\ndoZooTests=false\ndoUnitTests=false\ndoBuild=false\ndoBlackFormat=false\ndoBlackFix=false\ndoIsortFormat=false\ndoIsortFix=false\ndoPylintFormat=false\ndoRuffFormat=false\ndoRuffFix=false\ndoClangFormat=false\ndoCopyRight=false\ndoPytypeFormat=false\ndoMypyFormat=false\ndoCleanup=false\ndoDistTests=false\ndoPrecommit=false\n\nNUM_PARALLEL=1\n\nPY_EXE=${MONAI_PY_EXE:-$(which python)}\n\nfunction print_usage {\n    echo \"runtests.sh [--codeformat] [--autofix] [--black] [--isort] [--pylint] [--ruff]\"\n    echo \"            [--clangformat] [--precommit] [--pytype] [-j number] [--mypy]\"\n    echo \"            [--unittests] [--disttests] [--coverage] [--quick] [--min] [--net] [--build] [--list_tests]\"\n    echo \"            [--dryrun] [--copyright] [--clean] [--help] [--version] [--path] [--formatfix]\"\n    echo 
\"\"\n    echo \"MONAI unit testing utilities.\"\n    echo \"\"\n    echo \"Examples:\"\n    echo \"./runtests.sh -f -u --net --coverage  # run style checks, full tests, print code coverage (${green}recommended for pull requests${noColor}).\"\n    echo \"./runtests.sh -f -u                   # run style checks and unit tests.\"\n    echo \"./runtests.sh -f                      # run coding style and static type checking.\"\n    echo \"./runtests.sh --quick --unittests     # run minimal unit tests, for quick verification during code developments.\"\n    echo \"./runtests.sh --autofix               # run automatic code formatting using \\\"isort\\\" and \\\"black\\\".\"\n    echo \"./runtests.sh --clean                 # clean up temporary files and run \\\"${PY_EXE} -m pip uninstall -y monai\\\".\"\n    echo \"./runtests.sh --formatfix -p /my/code # run automatic code formatting using \\\"isort\\\" and \\\"black\\\" in specified path.\"\n    echo \"\"\n    echo \"Code style check options:\"\n    echo \"    --autofix         : format code using \\\"isort\\\" and \\\"black\\\"\"\n    echo \"    --black           : perform \\\"black\\\" code format checks\"\n    echo \"    --isort           : perform \\\"isort\\\" import sort checks\"\n    echo \"    --pylint          : perform \\\"pylint\\\" code format checks\"\n    echo \"    --ruff            : perform \\\"ruff\\\" code format checks\"\n    echo \"    --flake8          : perform \\\"ruff\\\" code format checks (deprecated alias for --ruff)\"\n    echo \"    --clangformat     : format csrc code using \\\"clang-format\\\"\"\n    echo \"    --precommit       : perform source code format check and fix using \\\"pre-commit\\\"\"\n    echo \"\"\n    echo \"Python type check options:\"\n    echo \"    --pytype          : perform \\\"pytype\\\" static type checks\"\n    echo \"    -j, --jobs        : number of parallel jobs to run \\\"pytype\\\" (default $NUM_PARALLEL)\"\n    echo \"    --mypy            : perform 
\\\"mypy\\\" static type checks\"\n    echo \"\"\n    echo \"MONAI unit testing options:\"\n    echo \"    -u, --unittests   : perform unit testing\"\n    echo \"    --disttests       : perform distributed unit testing\"\n    echo \"    --coverage        : report testing code coverage, to be used with \\\"--net\\\", \\\"--unittests\\\"\"\n    echo \"    -q, --quick       : skip long running unit tests and integration tests\"\n    echo \"    -m, --min         : only run minimal unit tests which do not require optional packages\"\n    echo \"    --net             : perform integration testing\"\n    echo \"    -b, --build       : compile and install the source code folder an editable release.\"\n    echo \"    --list_tests      : list unit tests and exit\"\n    echo \"\"\n    echo \"Misc. options:\"\n    echo \"    --dryrun          : display the commands to the screen without running\"\n    echo \"    --copyright       : check whether every source code has a copyright header\"\n    echo \"    -f, --codeformat  : shorthand to run all code style and static analysis tests\"\n    echo \"    -c, --clean       : clean temporary files from tests and exit\"\n    echo \"    -h, --help        : show this help message and exit\"\n    echo \"    -v, --version     : show MONAI and system version information and exit\"\n    echo \"    -p, --path        : specify the path used for formatting, default is the current dir if unspecified\"\n    echo \"    --formatfix       : format code using \\\"isort\\\" and \\\"black\\\" for user specified directories\"\n    echo \"\"\n    echo \"${separator}For bug reports and feature requests, please file an issue at:\"\n    echo \"    https://github.com/Project-MONAI/MONAI/issues/new/choose\"\n    echo \"\"\n    echo \"To choose an alternative python executable, set the environmental variable, \\\"MONAI_PY_EXE\\\".\"\n    exit 1\n}\n\n# FIXME: https://github.com/Project-MONAI/MONAI/issues/4354\nprotobuf_major_version=$(\"${PY_EXE}\" -m pip list 
| grep '^protobuf ' | tr -s ' ' | cut -d' ' -f2 | cut -d'.' -f1)\nif [ ! -z \"$protobuf_major_version\" ] && [ \"$protobuf_major_version\" -ge \"4\" ]\nthen\n    export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python\nfi\n\nfunction check_import {\n    echo \"Python: \"${PY_EXE}\"\"\n    ${cmdPrefix}\"${PY_EXE}\" -W error -W ignore::DeprecationWarning -W ignore::ResourceWarning -c \"import monai\"\n}\n\nfunction print_version {\n    ${cmdPrefix}\"${PY_EXE}\" -c 'import monai; monai.config.print_config()'  # project-monai/monai#6167\n}\n\nfunction install_deps {\n    echo \"Pip installing MONAI development dependencies and compile MONAI cpp extensions...\"\n    ${cmdPrefix}\"${PY_EXE}\" -m pip install --no-build-isolation -r requirements-dev.txt\n}\n\nfunction compile_cpp {\n    echo \"Compiling and installing MONAI cpp extensions...\"\n    # depends on setup.py behaviour for building\n    # currently setup.py uses environment variables: BUILD_MONAI and FORCE_CUDA\n    ${cmdPrefix}\"${PY_EXE}\" -m pip uninstall -y monai\n    if [[ \"$OSTYPE\" == \"darwin\"* ]];\n    then  # clang for mac os\n        BUILD_MONAI=1 CC=clang CXX=clang++ ${cmdPrefix}\"${PY_EXE}\" -m pip install -e .\n    else\n        BUILD_MONAI=1 ${cmdPrefix}\"${PY_EXE}\" -m pip install -e .\n    fi\n}\n\nfunction clang_format {\n    echo \"Running clang-format...\"\n    ${cmdPrefix}\"${PY_EXE}\" -m tests.clang_format_utils\n    clang_format_tool='.clang-format-bin/clang-format'\n    # Verify .\n    if ! 
type -p \"$clang_format_tool\" >/dev/null; then\n        echo \"'clang-format' not found, skipping the formatting.\"\n        exit 1\n    fi\n    find monai/csrc -type f | while read i; do $clang_format_tool -style=file -i $i; done\n    find monai/_extensions -type f -name \"*.cpp\" -o -name \"*.h\" -o -name \"*.cuh\" -o -name \"*.cu\" |\\\n        while read i; do $clang_format_tool -style=file -i $i; done\n}\n\nfunction is_pip_installed() {\n\treturn $(\"${PY_EXE}\" -c \"import sys, importlib.util; sys.exit(0 if importlib.util.find_spec(sys.argv[1]) else 1)\" $1)\n}\n\nfunction clean_py {\n    if is_pip_installed coverage\n    then\n      # remove coverage history\n      ${cmdPrefix}\"${PY_EXE}\" -m coverage erase\n    fi\n\n    # uninstall the development package\n    echo \"Uninstalling MONAI development files...\"\n    ${cmdPrefix}\"${PY_EXE}\" -m pip uninstall -y monai\n\n    # remove temporary files (in the directory of this script)\n    TO_CLEAN=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" >/dev/null 2>&1 && pwd )\"\n    echo \"Removing temporary files in ${TO_CLEAN}\"\n\n    find ${TO_CLEAN}/monai -type f -name \"*.py[co]\" -delete\n    find ${TO_CLEAN}/monai -type f -name \"*.so\" -delete\n    find ${TO_CLEAN}/monai -type d -name \"__pycache__\" -delete\n    find ${TO_CLEAN} -maxdepth 1 -type f -name \".coverage.*\" -delete\n\n    find ${TO_CLEAN} -depth -maxdepth 1 -type d -name \".eggs\" -exec rm -r \"{}\" +\n    find ${TO_CLEAN} -depth -maxdepth 1 -type d -name \"monai.egg-info\" -exec rm -r \"{}\" +\n    find ${TO_CLEAN} -depth -maxdepth 1 -type d -name \"build\" -exec rm -r \"{}\" +\n    find ${TO_CLEAN} -depth -maxdepth 1 -type d -name \"dist\" -exec rm -r \"{}\" +\n    find ${TO_CLEAN} -depth -maxdepth 1 -type d -name \".mypy_cache\" -exec rm -r \"{}\" +\n    find ${TO_CLEAN} -depth -maxdepth 1 -type d -name \".pytype\" -exec rm -r \"{}\" +\n    find ${TO_CLEAN} -depth -maxdepth 1 -type d -name \".coverage\" -exec rm -r \"{}\" +\n    find 
${TO_CLEAN} -depth -maxdepth 1 -type d -name \"__pycache__\" -exec rm -r \"{}\" +\n}\n\nfunction torch_validate {\n    ${cmdPrefix}\"${PY_EXE}\" -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'\n}\n\nfunction print_error_msg() {\n    echo \"${red}Error: $1.${noColor}\"\n    echo \"\"\n}\n\nfunction print_style_fail_msg() {\n    echo \"${red}Check failed!${noColor}\"\n    if [ \"$homedir\" = \"$currentdir\" ]\n    then\n        echo \"Please run auto style fixes: ${green}./runtests.sh --autofix${noColor}\"\n    else :\n    fi\n}\n\nfunction list_unittests() {\n    \"${PY_EXE}\" - << END\nimport unittest\ndef print_suite(suite):\n    if hasattr(suite, \"__iter__\"):\n        for x in suite:\n            print_suite(x)\n    else:\n        print(suite)\nprint_suite(unittest.defaultTestLoader.discover('./tests'))\nEND\n    exit 0\n}\n\nif [ -z \"$1\" ]\nthen\n    print_error_msg \"Too few arguments to $0\"\n    print_usage\nfi\n\n# parse arguments\nwhile [[ $# -gt 0 ]]\ndo\n    key=\"$1\"\n    case $key in\n        --coverage)\n            doCoverage=true\n        ;;\n        -q|--quick)\n            doQuickTests=true\n        ;;\n        -m|--min)\n            doMinTests=true\n        ;;\n        --net)\n            doNetTests=true\n        ;;\n        --list_tests)\n            list_unittests\n        ;;\n        --dryrun)\n            doDryRun=true\n        ;;\n        -u|--u*)  # allow --unittest | --unittests | --unittesting  etc.\n            doUnitTests=true\n        ;;\n        -f|--codeformat)\n            doBlackFormat=true\n            doIsortFormat=true\n            # doPylintFormat=true  # https://github.com/Project-MONAI/MONAI/issues/7094\n            doRuffFormat=true\n            doCopyRight=true\n        ;;\n        --disttests)\n            doDistTests=true\n        ;;\n        --black)\n            doBlackFormat=true\n        ;;\n        --autofix)\n            doIsortFix=true\n            doBlackFix=true\n            
doRuffFix=true\n            doIsortFormat=true\n            doBlackFormat=true\n            doRuffFormat=true\n            doCopyRight=true\n        ;;\n        --formatfix)\n            doIsortFix=true\n            doBlackFix=true\n            doIsortFormat=true\n            doBlackFormat=true\n        ;;\n        --clangformat)\n            doClangFormat=true\n        ;;\n        --isort)\n            doIsortFormat=true\n        ;;\n        --pylint)\n            doPylintFormat=true\n        ;;\n        --ruff)\n            doRuffFormat=true\n        ;;\n        --flake8)\n            echo \"${red}warning: --flake8 is deprecated, please use --ruff instead.${noColor}\"\n            doRuffFormat=true\n        ;;\n        --precommit)\n            doPrecommit=true\n        ;;\n        --pytype)\n            doPytypeFormat=true\n        ;;\n        --mypy)\n            doMypyFormat=true\n        ;;\n        -j|--jobs)\n            NUM_PARALLEL=$2\n            shift\n        ;;\n        --copyright)\n            doCopyRight=true\n        ;;\n        -b|--build)\n            doBuild=true\n        ;;\n        -c|--clean)\n            doCleanup=true\n        ;;\n        -h|--help)\n            print_usage\n        ;;\n        -v|--version)\n            print_version\n            exit 1\n        ;;\n        --nou*)  # allow --nounittest | --nounittests | --nounittesting  etc.\n            print_error_msg \"nounittest option is deprecated, no unit tests is the default setting\"\n            print_usage\n        ;;\n        -p|--path)\n            testdir=$2\n            shift\n        ;;\n        *)\n            print_error_msg \"Incorrect commandline provided, invalid key: $key\"\n            print_usage\n        ;;\n    esac\n    shift\ndone\n\n# home directory\ncurrentdir=\"$( cd -P \"$( dirname \"${BASH_SOURCE[0]}\" )\" && pwd )\"\nif [ -e \"$testdir\" ]\nthen\n    homedir=$testdir\nelse\n    homedir=$currentdir\nfi\necho \"Run tests under $homedir\"\ncd 
\"$homedir\"\n\n# python path\nexport PYTHONPATH=\"$homedir:$PYTHONPATH\"\necho \"PYTHONPATH: $PYTHONPATH\"\n\n# by default do nothing\ncmdPrefix=\"\"\n\nif [ $doDryRun = true ]\nthen\n    echo \"${separator}${blue}dryrun${noColor}\"\n\n    # commands are echoed instead of ran\n    cmdPrefix=\"dryrun \"\n    function dryrun { echo \"    \" \"$@\"; }\nelse\n    check_import\nfi\n\nif [ $doBuild = true ]\nthen\n    echo \"${separator}${blue}compile and install${noColor}\"\n    # try to compile MONAI cpp\n    compile_cpp\n\n    echo \"${green}done! (to uninstall and clean up, please use \\\"./runtests.sh --clean\\\")${noColor}\"\nfi\n\nif [ $doCleanup = true ]\nthen\n    echo \"${separator}${blue}clean${noColor}\"\n\n    clean_py\n\n    echo \"${green}done!${noColor}\"\n    exit\nfi\n\nif [ $doClangFormat = true ]\nthen\n    echo \"${separator}${blue}clang-formatting${noColor}\"\n\n    clang_format\n\n    echo \"${green}done!${noColor}\"\nfi\n\n# unconditionally report on the state of monai\nprint_version\n\nif [ $doCopyRight = true ]\nthen\n    # check copyright headers\n    copyright_bad=0\n    copyright_all=0\n    while read -r fname; do\n        copyright_all=$((copyright_all + 1))\n        if ! grep \"http://www.apache.org/licenses/LICENSE-2.0\" \"$fname\" > /dev/null; then\n            print_error_msg \"Missing the license header in file: $fname\"\n            copyright_bad=$((copyright_bad + 1))\n        fi\n    done <<< \"$(find \"$(pwd)/monai\" \"$(pwd)/tests\" -type f \\\n        ! 
-wholename \"*_version.py\" -and -name \"*.py\" -or -name \"*.cpp\" -or -name \"*.cu\" -or -name \"*.h\")\"\n    if [[ ${copyright_bad} -eq 0 ]];\n    then\n        echo \"${green}Source code copyright headers checked ($copyright_all).${noColor}\"\n    else\n        echo \"Please add the licensing header to the file ($copyright_bad of $copyright_all files).\"\n        echo \"  See also: https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md#checking-the-coding-style\"\n        echo \"\"\n        exit 1\n    fi\nfi\n\n\nif [ $doPrecommit = true ]\nthen\n    set +e  # disable exit on failure so that diagnostics can be given on failure\n    echo \"${separator}${blue}pre-commit${noColor}\"\n\n    # ensure that the necessary packages for code format testing are installed\n    if ! is_pip_installed pre_commit\n    then\n        install_deps\n    fi\n    ${cmdPrefix}\"${PY_EXE}\" -m pre_commit run --all-files\n\n    pre_commit_status=$?\n    if [ ${pre_commit_status} -ne 0 ]\n    then\n        print_style_fail_msg\n        exit ${pre_commit_status}\n    else\n        echo \"${green}passed!${noColor}\"\n    fi\n    set -e # enable exit on failure\nfi\n\n\nif [ $doIsortFormat = true ]\nthen\n    set +e  # disable exit on failure so that diagnostics can be given on failure\n    if [ $doIsortFix = true ]\n    then\n        echo \"${separator}${blue}isort-fix${noColor}\"\n    else\n        echo \"${separator}${blue}isort${noColor}\"\n    fi\n\n    # ensure that the necessary packages for code format testing are installed\n    if ! 
is_pip_installed isort\n    then\n        install_deps\n    fi\n    ${cmdPrefix}\"${PY_EXE}\" -m isort --version\n\n    if [ $doIsortFix = true ]\n    then\n        ${cmdPrefix}\"${PY_EXE}\" -m isort \"$homedir\"\n    else\n        ${cmdPrefix}\"${PY_EXE}\" -m isort --check \"$homedir\"\n    fi\n\n    isort_status=$?\n    if [ ${isort_status} -ne 0 ]\n    then\n        print_style_fail_msg\n        exit ${isort_status}\n    else\n        echo \"${green}passed!${noColor}\"\n    fi\n    set -e # enable exit on failure\nfi\n\n\nif [ $doBlackFormat = true ]\nthen\n    set +e  # disable exit on failure so that diagnostics can be given on failure\n    if [ $doBlackFix = true ]\n    then\n        echo \"${separator}${blue}black-fix${noColor}\"\n    else\n        echo \"${separator}${blue}black${noColor}\"\n    fi\n\n    # ensure that the necessary packages for code format testing are installed\n    if ! is_pip_installed black\n    then\n        install_deps\n    fi\n    ${cmdPrefix}\"${PY_EXE}\" -m black --version\n\n    if [ $doBlackFix = true ]\n    then\n        ${cmdPrefix}\"${PY_EXE}\" -m black --skip-magic-trailing-comma \"$homedir\"\n    else\n        ${cmdPrefix}\"${PY_EXE}\" -m black --skip-magic-trailing-comma --check \"$homedir\"\n    fi\n\n    black_status=$?\n    if [ ${black_status} -ne 0 ]\n    then\n        print_style_fail_msg\n        exit ${black_status}\n    else\n        echo \"${green}passed!${noColor}\"\n    fi\n    set -e # enable exit on failure\nfi\n\nif [ $doPylintFormat = true ]\nthen\n    set +e  # disable exit on failure so that diagnostics can be given on failure\n    echo \"${separator}${blue}pylint${noColor}\"\n\n    # ensure that the necessary packages for code format testing are installed\n    if ! 
is_pip_installed pylint\n    then\n        echo \"Pip installing pylint ...\"\n        ${cmdPrefix}\"${PY_EXE}\" -m pip install \"pylint>2.16,!=3.0.0\"\n    fi\n    ${cmdPrefix}\"${PY_EXE}\" -m pylint --version\n\n    ignore_codes=\"C,R,W,E1101,E1102,E0601,E1130,E1123,E0102,E1120,E1137,E1136\"\n    ${cmdPrefix}\"${PY_EXE}\" -m pylint monai tests --disable=$ignore_codes -j $NUM_PARALLEL\n    pylint_status=$?\n\n    if [ ${pylint_status} -ne 0 ]\n    then\n        print_style_fail_msg\n        exit ${pylint_status}\n    else\n        echo \"${green}passed!${noColor}\"\n    fi\n    set -e # enable exit on failure\nfi\n\n\nif [ $doRuffFormat = true ]\nthen\n    set +e  # disable exit on failure so that diagnostics can be given on failure\n    if [ $doRuffFix = true ]\n    then\n        echo \"${separator}${blue}ruff-fix${noColor}\"\n    else\n        echo \"${separator}${blue}ruff${noColor}\"\n    fi\n\n    # ensure that the necessary packages for code format testing are installed\n    if ! is_pip_installed ruff\n    then\n        install_deps\n    fi\n    ruff --version\n\n    if [ $doRuffFix = true ]\n    then\n        ruff check --fix --unsafe-fixes --exclude versioneer.py --exclude \"monai/_version.py\" \"$homedir\"\n    else\n        ruff check --exclude versioneer.py --exclude \"monai/_version.py\" \"$homedir\"\n    fi\n\n    ruff_status=$?\n    if [ ${ruff_status} -ne 0 ]\n    then\n        print_style_fail_msg\n        exit ${ruff_status}\n    else\n        echo \"${green}passed!${noColor}\"\n    fi\n    set -e # enable exit on failure\nfi\n\n\nif [ $doPytypeFormat = true ]\nthen\n    set +e  # disable exit on failure so that diagnostics can be given on failure\n    echo \"${separator}${blue}pytype${noColor}\"\n    # ensure that the necessary packages for code format testing are installed\n    if ! 
is_pip_installed pytype\n    then\n        install_deps\n    fi\n    pytype_ver=$(${cmdPrefix}\"${PY_EXE}\" -m pytype --version)\n    if [[ \"$OSTYPE\" == \"darwin\"* && \"$pytype_ver\" == \"2021.\"* ]]; then\n        echo \"${red}pytype not working on macOS 2021 (https://github.com/Project-MONAI/MONAI/issues/2391). Please upgrade to 2022*.${noColor}\"\n        exit 1\n    else\n        ${cmdPrefix}\"${PY_EXE}\" -m pytype --version\n\n        ${cmdPrefix}\"${PY_EXE}\" -m pytype -j ${NUM_PARALLEL} --python-version=\"$(${PY_EXE} -c \"import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')\")\" \"$homedir\"\n\n        pytype_status=$?\n        if [ ${pytype_status} -ne 0 ]\n        then\n            echo \"${red}failed!${noColor}\"\n            exit ${pytype_status}\n        else\n            echo \"${green}passed!${noColor}\"\n        fi\n    fi\n    set -e # enable exit on failure\nfi\n\n\nif [ $doMypyFormat = true ]\nthen\n    set +e  # disable exit on failure so that diagnostics can be given on failure\n    echo \"${separator}${blue}mypy${noColor}\"\n\n    # ensure that the necessary packages for code format testing are installed\n    if ! 
is_pip_installed mypy\n    then\n        install_deps\n    fi\n    ${cmdPrefix}\"${PY_EXE}\" -m mypy --version\n    ${cmdPrefix}\"${PY_EXE}\" -m mypy \"$homedir\"\n\n    mypy_status=$?\n    if [ ${mypy_status} -ne 0 ]\n    then\n        : # mypy output already follows format\n        exit ${mypy_status}\n    else\n        : # mypy output already follows format\n    fi\n    set -e # enable exit on failure\nfi\n\n\n# testing command to run\ncmd=\"${PY_EXE}\"\n\n# When running --quick, require doCoverage as well and set QUICKTEST environmental\n# variable to disable slow unit tests from running.\nif [ $doQuickTests = true ]\nthen\n    echo \"${separator}${blue}quick${noColor}\"\n    doCoverage=true\n    export QUICKTEST=True\nfi\n\nif [ $doMinTests = true ]\nthen\n    echo \"${separator}${blue}min${noColor}\"\n    doCoverage=false\n    ${cmdPrefix}\"${PY_EXE}\" -m tests.min_tests\nfi\n\n# set coverage command\nif [ $doCoverage = true ]\nthen\n    echo \"${separator}${blue}coverage${noColor}\"\n    # ensure that the necessary packages for code format testing are installed\n    if ! is_pip_installed coverage\n    then\n        install_deps\n    fi\n    cmd=\"\"${PY_EXE}\" -m coverage run --append\"\nfi\n\n# # download test data if needed\n# if [ ! 
-d testing_data ] && [ \"$doDryRun\" != 'true' ]\n# then\n# fi\n\n# unit tests\n# TODO: temp skip test_perceptual_loss, revert after #8652 merged\n# TODO: temp skip test_auto3dseg_ensemble, revert after #8737 resolved\nif [ $doUnitTests = true ]\nthen\n    echo \"${separator}${blue}unittests${noColor}\"\n    torch_validate\n    ${cmdPrefix}${cmd} ./tests/runner.py -p \"^(?!test_integration|test_perceptual_loss|test_auto3dseg_ensemble).*(?<!_dist)$\"  # excluding integration/dist/perceptual_loss tests\nfi\n\n# distributed test only\nif [ $doDistTests = true ]\nthen\n    echo \"${separator}${blue}run distributed unit test cases${noColor}\"\n    torch_validate\n    for i in  $(find ./tests/ -name \"*_dist.py\")\n    do\n        echo \"$i\"\n        ${cmdPrefix}${cmd} \"$i\"\n    done\nfi\n\n# network training/inference/eval integration tests\nif [ $doNetTests = true ]\nthen\n    set +e  # disable exit on failure so that diagnostics can be given on failure\n    echo \"${separator}${blue}integration${noColor}\"\n    for i in tests/integration/*.py\n    do\n        echo \"$i\"\n        ${cmdPrefix}${cmd} \"$i\"\n    done\n    set -e # enable exit on failure\nfi\n\n# run model zoo tests\nif [ $doZooTests = true ]\nthen\n    echo \"${separator}${blue}zoo${noColor}\"\n    print_error_msg \"--zoo option not yet implemented\"\n    exit 255\nfi\n\n# report on coverage\nif [ $doCoverage = true ]\nthen\n    echo \"${separator}${blue}coverage${noColor}\"\n    # ensure that the necessary packages for code format testing are installed\n    if ! is_pip_installed coverage\n    then\n        install_deps\n    fi\n    ${cmdPrefix}\"${PY_EXE}\" -m coverage combine --append .coverage/\n    ${cmdPrefix}\"${PY_EXE}\" -m coverage report --ignore-errors\nfi\n"
  },
  {
    "path": "setup.cfg",
    "content": "[metadata]\nname = monai\nauthor = MONAI Consortium\nauthor_email = monai.contact@gmail.com\nurl = https://project-monai.github.io/\ndescription = AI Toolkit for Healthcare Imaging\nlong_description = file:README.md\nlong_description_content_type = text/markdown; charset=UTF-8\nplatforms = OS Independent\nlicense = Apache License 2.0\nlicense_files =\n    LICENSE\nproject_urls =\n    Documentation=https://monai.readthedocs.io/\n    Bug Tracker=https://github.com/Project-MONAI/MONAI/issues\n    Source Code=https://github.com/Project-MONAI/MONAI\nclassifiers =\n    Intended Audience :: Developers\n    Intended Audience :: Education\n    Intended Audience :: Science/Research\n    Intended Audience :: Healthcare Industry\n    Programming Language :: C++\n    Programming Language :: Python :: 3\n    Programming Language :: Python :: 3.9\n    Programming Language :: Python :: 3.10\n    Programming Language :: Python :: 3.11\n    Topic :: Scientific/Engineering\n    Topic :: Scientific/Engineering :: Artificial Intelligence\n    Topic :: Scientific/Engineering :: Medical Science Apps.\n    Topic :: Scientific/Engineering :: Information Analysis\n    Topic :: Software Development\n    Topic :: Software Development :: Libraries\n    Typing :: Typed\n\n[options]\npython_requires = >= 3.9\n# for compiling and develop setup only\n# no need to specify the versions so that we could\n# compile for multiple targeted versions.\nsetup_requires =\n    torch\n    ninja\n    packaging\ninstall_requires =\n    torch>=2.4.1\n    numpy>=1.24,<3.0\n\n[options.extras_require]\nall =\n    nibabel\n    ninja\n    scikit-image>=0.14.2\n    scipy>=1.12.0; python_version >= '3.9'\n    pillow\n    tensorboard\n    gdown>=4.7.3\n    pytorch-ignite==0.4.11\n    torchio\n    torchvision\n    itk>=5.2\n    tqdm>=4.47.0\n    lmdb\n    psutil\n    cucim-cu12; platform_system == \"Linux\" and python_version >= '3.9' and python_version <= '3.10'\n    cucim-cu13; platform_system == 
\"Linux\" and python_version >= '3.11'\n    openslide-python\n    openslide-bin\n    tifffile; platform_system == \"Linux\" or platform_system == \"Darwin\"\n    imagecodecs; platform_system == \"Linux\" or platform_system == \"Darwin\"\n    pandas\n    einops\n    transformers>=4.36.0, <4.41.0; python_version <= '3.10'\n    mlflow>=2.12.2\n    clearml>=1.10.0rc0\n    matplotlib>=3.6.3\n    tensorboardX\n    pyyaml\n    fire\n    jsonschema\n    pynrrd\n    pydicom\n    h5py\n    nni; platform_system == \"Linux\" and \"arm\" not in platform_machine and \"aarch\" not in platform_machine\n    optuna\n    onnx>=1.13.0\n    onnxruntime; python_version <= '3.10'\n    zarr\n    lpips==0.1.4\n    nvidia-ml-py\n    huggingface_hub\n    pyamg>=5.0.0, <5.3.0\nnibabel =\n    nibabel\nninja =\n    ninja\nskimage =\n    scikit-image>=0.14.2\nscipy =\n    scipy>=1.12.0; python_version >= '3.9'\npillow =\n    pillow!=8.3.0\ntensorboard =\n    tensorboard\ngdown =\n    gdown>=4.7.3\nignite =\n    pytorch-ignite==0.4.11\ntorchio =\n    torchio\ntorchvision =\n    torchvision\nitk =\n    itk>=5.2\ntqdm =\n    tqdm>=4.47.0\nlmdb =\n    lmdb\npsutil =\n    psutil\ncucim =\n    cucim-cu12; platform_system == \"Linux\" and python_version >= '3.9' and python_version <= '3.10'\n    cucim-cu13; platform_system == \"Linux\" and python_version >= '3.11'\nopenslide =\n    openslide-python\n    openslide-bin\ntifffile =\n    tifffile; platform_system == \"Linux\" or platform_system == \"Darwin\"\nimagecodecs =\n    imagecodecs; platform_system == \"Linux\" or platform_system == \"Darwin\"\npandas =\n    pandas\neinops =\n    einops\ntransformers =\n    transformers>=4.36.0, <4.41.0; python_version <= '3.10'\nmlflow =\n    mlflow>=2.12.2\nmatplotlib =\n    matplotlib>=3.6.3\nclearml =\n    clearml\ntensorboardX =\n    tensorboardX\npyyaml =\n    pyyaml\nfire =\n    fire\npackaging =\n    packaging\njsonschema =\n    jsonschema\npynrrd =\n    pynrrd\npydicom =\n    pydicom\nh5py =\n    h5py\nnni 
=\n    nni; platform_system == \"Linux\" and \"arm\" not in platform_machine and \"aarch\" not in platform_machine\noptuna =\n    optuna\nonnx =\n    onnx>=1.13.0\n    onnxruntime; python_version <= '3.10'\nzarr =\n    zarr\nlpips =\n    lpips==0.1.4\npynvml =\n    nvidia-ml-py\npolygraphy =\n    polygraphy\n\n# # workaround https://github.com/Project-MONAI/MONAI/issues/5882\n# MetricsReloaded =\n    # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded\nhuggingface_hub =\n    huggingface_hub\npyamg =\n    pyamg>=5.0.0, <5.3.0\n# segment-anything =\n#     segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything\n\n[isort]\nknown_first_party = monai\nprofile = black\nline_length = 120\nskip = .git, .eggs, venv, .venv, versioneer.py, _version.py, conf.py, monai/__init__.py\nskip_glob = *.pyi\nadd_imports = from __future__ import annotations\nappend_only = true\n\n[versioneer]\nVCS = git\nstyle = pep440\nversionfile_source = monai/_version.py\nversionfile_build = monai/_version.py\ntag_prefix =\nparentdir_prefix =\n\n[mypy]\n# Suppresses error messages about imports that cannot be resolved.\nignore_missing_imports = True\n# Changes the treatment of arguments with a default value of None by not implicitly making their type Optional.\nno_implicit_optional = True\n# Warns about casting an expression to its inferred type.\nwarn_redundant_casts = True\n# No error on unneeded # type: ignore comments.\nwarn_unused_ignores = False\n# Shows a warning when returning a value with type Any from a function declared with a non-Any return type.\nwarn_return_any = True\n# Prohibit equality checks, identity checks, and container checks between non-overlapping types.\nstrict_equality = True\n# Shows column numbers in error messages.\nshow_column_numbers = True\n# Shows error codes in error messages.\nshow_error_codes = True\n# Use visually nicer output 
in error messages: use soft word wrap, show source code snippets, and show error location markers.\npretty = False\n# Warns about per-module sections in the config file that do not match any files processed when invoking mypy.\nwarn_unused_configs = True\n# Make arguments prepended via Concatenate be truly positional-only.\nextra_checks = True\n# Allows variables to be redefined with an arbitrary type,\n# as long as the redefinition is in the same block and nesting level as the original definition.\n# allow_redefinition = True\n\nexclude = venv/\n\n[mypy-versioneer]\n# Ignores all non-fatal errors.\nignore_errors = True\n\n[mypy-monai._version]\n# Ignores all non-fatal errors.\nignore_errors = True\n\n[mypy-monai.eggs]\n# Ignores all non-fatal errors.\nignore_errors = True\n\n[mypy-monai.*]\n# Also check the body of functions with no types in their type signature.\ncheck_untyped_defs = True\n# Warns about usage of untyped decorators.\ndisallow_untyped_decorators = True\n\n[mypy-monai.visualize.*,monai.utils.*,monai.optimizers.*,monai.losses.*,monai.inferers.*,monai.config.*,monai._extensions.*,monai.fl.*,monai.engines.*,monai.handlers.*,monai.auto3dseg.*,monai.bundle.*,monai.metrics.*,monai.apps.*]\ndisallow_incomplete_defs = True\n\n[coverage:run]\nconcurrency = multiprocessing\nsource = .\ndata_file = .coverage/.coverage\nomit = setup.py\n\n[coverage:report]\nexclude_lines =\n    pragma: no cover\n    if TYPE_CHECKING:\n    # Don't complain if tests don't hit code:\n    raise NotImplementedError\n    if __name__ == .__main__.:\nshow_missing = True\nskip_covered = True\n\n[coverage:xml]\noutput = coverage.xml\n"
  },
  {
    "path": "setup.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport glob\nimport os\nimport re\nimport sys\nimport warnings\n\nfrom packaging import version\nfrom setuptools import find_packages, setup\n\nimport versioneer\n\n# TODO: debug mode -g -O0, compile test cases\n\nRUN_BUILD = os.getenv(\"BUILD_MONAI\", \"0\") == \"1\"\nFORCE_CUDA = os.getenv(\"FORCE_CUDA\", \"0\") == \"1\"  # flag ignored if BUILD_MONAI is False\n\nBUILD_CPP = BUILD_CUDA = False\nTORCH_VERSION = 0\ntry:\n    import torch\n\n    print(f\"setup.py with torch {torch.__version__}\")\n    from torch.utils.cpp_extension import BuildExtension, CppExtension\n\n    BUILD_CPP = True\n    from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension\n\n    BUILD_CUDA = FORCE_CUDA or (torch.cuda.is_available() and (CUDA_HOME is not None))\n\n    _pt_version = version.parse(torch.__version__).release\n    if _pt_version is None or len(_pt_version) < 3:\n        raise AssertionError(\"unknown torch version\")\n    TORCH_VERSION = int(_pt_version[0]) * 10000 + int(_pt_version[1]) * 100 + int(_pt_version[2])\nexcept (ImportError, TypeError, AssertionError, AttributeError) as e:\n    warnings.warn(f\"extension build skipped: {e}\")\nfinally:\n    if not RUN_BUILD:\n        BUILD_CPP = BUILD_CUDA = False\n        print(\"Please set environment variable `BUILD_MONAI=1` to enable Cpp/CUDA extension build.\")\n    
print(f\"BUILD_MONAI_CPP={BUILD_CPP}, BUILD_MONAI_CUDA={BUILD_CUDA}, TORCH_VERSION={TORCH_VERSION}.\")\n\n\ndef torch_parallel_backend():\n    try:\n        match = re.search(\"^ATen parallel backend: (?P<backend>.*)$\", torch._C._parallel_info(), re.MULTILINE)\n        if match is None:\n            return None\n        backend = match.group(\"backend\")\n        if backend == \"OpenMP\":\n            return \"AT_PARALLEL_OPENMP\"\n        if backend == \"native thread pool\":\n            return \"AT_PARALLEL_NATIVE\"\n        if backend == \"native thread pool and TBB\":\n            return \"AT_PARALLEL_NATIVE_TBB\"\n    except (NameError, AttributeError):  # no torch or no binaries\n        warnings.warn(\"Could not determine torch parallel_info.\")\n    return None\n\n\ndef omp_flags():\n    if sys.platform == \"win32\":\n        return [\"/openmp\"]\n    if sys.platform == \"darwin\":\n        # https://stackoverflow.com/questions/37362414/\n        # return [\"-fopenmp=libiomp5\"]\n        return []\n    return [\"-fopenmp\"]\n\n\ndef get_extensions():\n    this_dir = os.path.dirname(os.path.abspath(__file__))\n    ext_dir = os.path.join(this_dir, \"monai\", \"csrc\")\n    include_dirs = [ext_dir]\n\n    source_cpu = glob.glob(os.path.join(ext_dir, \"**\", \"*.cpp\"), recursive=True)\n    source_cuda = glob.glob(os.path.join(ext_dir, \"**\", \"*.cu\"), recursive=True)\n\n    extension = None\n    define_macros = [(f\"{torch_parallel_backend()}\", 1), (\"MONAI_TORCH_VERSION\", TORCH_VERSION)]\n    extra_compile_args = {}\n    extra_link_args = []\n    sources = source_cpu\n    if BUILD_CPP:\n        extension = CppExtension\n        extra_compile_args.setdefault(\"cxx\", [])\n        if torch_parallel_backend() == \"AT_PARALLEL_OPENMP\":\n            extra_compile_args[\"cxx\"] += omp_flags()\n        extra_link_args = omp_flags()\n    if BUILD_CUDA:\n        extension = CUDAExtension\n        sources += source_cuda\n        define_macros += [(\"WITH_CUDA\", 
None)]\n        extra_compile_args = {\"cxx\": [], \"nvcc\": []}\n        if torch_parallel_backend() == \"AT_PARALLEL_OPENMP\":\n            extra_compile_args[\"cxx\"] += omp_flags()\n    if extension is None or not sources:\n        return []  # compile nothing\n\n    ext_modules = [\n        extension(\n            name=\"monai._C\",\n            sources=sources,\n            include_dirs=include_dirs,\n            define_macros=define_macros,\n            extra_compile_args=extra_compile_args,\n            extra_link_args=extra_link_args,\n        )\n    ]\n    return ext_modules\n\n\ndef get_cmds():\n    cmds = versioneer.get_cmdclass()\n\n    if not (BUILD_CPP or BUILD_CUDA):\n        return cmds\n\n    cmds.update({\"build_ext\": BuildExtension.with_options(no_python_abi_suffix=True)})\n    return cmds\n\n\n# Gathering source used for JIT extensions to include in package_data.\njit_extension_source = []\n\nfor ext in [\"cpp\", \"cu\", \"h\", \"cuh\"]:\n    glob_path = os.path.join(\"monai\", \"_extensions\", \"**\", f\"*.{ext}\")\n    jit_extension_source += glob.glob(glob_path, recursive=True)\n\njit_extension_source = [os.path.join(\"..\", path) for path in jit_extension_source]\n\nsetup(\n    version=versioneer.get_version(),\n    cmdclass=get_cmds(),\n    packages=find_packages(exclude=(\"docs\", \"examples\", \"tests\")),\n    zip_safe=False,\n    package_data={\"monai\": [\"py.typed\", *jit_extension_source]},  # type: ignore[arg-type]\n    ext_modules=get_extensions(),\n)\n"
  },
  {
    "path": "tests/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/deepedit/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/deepedit/test_deepedit_transforms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.deepedit.transforms import (\n    AddGuidanceFromPointsDeepEditd,\n    AddGuidanceSignalDeepEditd,\n    AddInitialSeedPointMissingLabelsd,\n    AddRandomGuidanceDeepEditd,\n    DiscardAddGuidanced,\n    FindAllValidSlicesMissingLabelsd,\n    FindDiscrepancyRegionsDeepEditd,\n    NormalizeLabelsInDatasetd,\n    RemapLabelsToSequentiald,\n    ResizeGuidanceMultipleLabelDeepEditd,\n    SingleLabelSelectiond,\n    SplitPredsLabeld,\n)\nfrom monai.utils import min_version, optional_import, set_determinism\nfrom monai.utils.enums import PostFix\n\nmeasure, _ = optional_import(\"skimage.measure\", \"0.14.2\", min_version)\n\nset_determinism(seed=0)\nIMAGE = np.random.randint(0, 256, size=(1, 10, 10, 10))\nTHREE_CHAN_IMAGE = np.random.randint(0, 255, size=(3, 10, 10, 10))\nLABEL = np.random.randint(0, 2, size=(10, 10, 10))\nPRED = np.random.randint(0, 2, size=(10, 10, 10))\nLABEL_NAMES = {\"spleen\": 1, \"background\": 0}\nDISCREPANCY = {\n    \"spleen\": np.random.randint(0, 2, size=(10, 10, 10)),\n    \"background\": np.random.randint(0, 2, size=(10, 10, 10)),\n}\nset_determinism(None)\n\nDATA_1 = {\n    \"image\": IMAGE,\n    \"label\": LABEL,\n    PostFix.meta(\"image\"): {\"dim\": IMAGE.shape, \"spatial_shape\": IMAGE[0, ...].shape},\n    
PostFix.meta(\"label\"): {},\n}\n\nDATA_2 = {\n    \"image\": IMAGE,\n    \"label\": LABEL,\n    \"label_names\": LABEL_NAMES,\n    \"guidance\": {\"spleen\": [[3, 5, 4, 6], [-1, -1, -1, -1]], \"background\": [[-1, -1, -1, -1], [-1, -1, -1, -1]]},\n    \"discrepancy\": DISCREPANCY,\n    \"probability\": 1.0,\n}\n\nDATA_3 = {\n    \"image\": IMAGE,\n    \"label\": LABEL,\n    \"guidance\": {\n        \"spleen\": np.array([[1, 0, 2, 2], [-1, -1, -1, -1]]),\n        \"background\": np.array([[1, 0, 2, 2], [-1, -1, -1, -1]]),\n    },\n    \"probability\": 1.0,\n    \"label_names\": LABEL_NAMES,\n    \"pred\": PRED,\n}\n\nDATA_4 = {\n    \"image\": IMAGE,\n    \"label\": LABEL,\n    PostFix.meta(\"image\"): {\"dim\": IMAGE.shape, \"spatial_shape\": IMAGE[0, ...].shape},\n    \"current_label\": \"spleen\",\n    \"probability\": 1.0,\n    \"label_names\": LABEL_NAMES,\n    \"spleen\": [[0, 4, 3], [0, 0, 3], [0, 1, 3]],\n    \"sids\": {\"spleen\": []},\n    \"pred\": PRED,\n}\n\nDATA_5 = {\n    \"image\": IMAGE,\n    \"label\": LABEL,\n    PostFix.meta(\"image\"): {\"dim\": IMAGE.shape, \"spatial_shape\": IMAGE[0, ...].shape},\n    \"current_label\": \"spleen\",\n    \"probability\": 1.0,\n    \"label_names\": LABEL_NAMES,\n    \"sids\": {\"spleen\": [2, 3, 4], \"background\": [0, 1, 5]},\n}\n\nDATA_6 = {\n    \"image\": IMAGE,\n    \"label\": LABEL[None],\n    PostFix.meta(\"image\"): {\"dim\": IMAGE.shape, \"spatial_shape\": IMAGE[0, ...].shape},\n    \"current_label\": \"spleen\",\n    \"label_names\": LABEL_NAMES,\n}\n\nDATA_7 = {\n    \"image\": IMAGE,\n    \"label\": LABEL,\n    \"pred\": PRED,\n    PostFix.meta(\"image\"): {\"dim\": IMAGE.shape, \"spatial_shape\": IMAGE[0, ...].shape},\n    \"current_label\": \"spleen\",\n    \"probability\": 1.0,\n    \"label_names\": LABEL_NAMES,\n    \"guidance\": {\n        \"spleen\": np.array([[1, 0, 2, 2], [-1, -1, -1, -1]]),\n        \"background\": np.array([[1, 0, 2, 2], [-1, -1, -1, -1]]),\n    },\n}\n\nDATA_8 = {\n    
\"image\": IMAGE,\n    \"label\": LABEL,\n    PostFix.meta(\"image\"): {\"dim\": IMAGE.shape, \"spatial_shape\": IMAGE[0, ...].shape},\n    \"label_names\": LABEL_NAMES,\n}\n\nDATA_9 = {\n    \"image\": IMAGE,\n    \"label\": LABEL,\n    PostFix.meta(\"image\"): {\"dim\": IMAGE.shape, \"spatial_shape\": IMAGE[0, ...].shape},\n    \"label_names\": LABEL_NAMES,\n    \"guidance\": {\"spleen\": np.array([0, 2, 2]), \"background\": np.array([-1, -1, -1])},\n}\n\nDATA_10 = {\n    \"image\": IMAGE,\n    \"label\": LABEL,\n    PostFix.meta(\"image\"): {\"dim\": IMAGE.shape, \"spatial_shape\": IMAGE[0, ...].shape},\n    \"current_label\": \"spleen\",\n}\n\nDATA_11 = {\"image\": IMAGE, \"label\": LABEL, \"label_names\": LABEL_NAMES, \"pred\": PRED}\n\nADD_GUIDANCE_FROM_POINTS_TEST_CASE = [\n    {\"ref_image\": \"image\", \"guidance\": \"guidance\", \"label_names\": LABEL_NAMES},  # arguments\n    DATA_4,  # input_data\n    [0, 4, 3],  # expected_result\n]\n\nADD_GUIDANCE_CUSTOM_TEST_CASE = [\n    {\"keys\": \"image\", \"guidance\": \"guidance\"},  # arguments\n    DATA_3,  # input_data\n    3,  # expected_result\n]\n\nADD_INITIAL_POINT_TEST_CASE = [\n    {\"keys\": \"label\", \"guidance\": \"guidance\", \"sids\": \"sids\"},  # arguments\n    DATA_5,  # input_data\n    {\n        \"spleen\": \"[[1, 0, 7], [-1, -1, -1], [-1, -1, -1], [-1, -1, -1], [-1, -1, -1]]\",\n        \"background\": \"[[1, 5, 3], [-1, -1, -1], [-1, -1, -1], [-1, -1, -1], [-1, -1, -1]]\",\n    },  # expected_result\n]\n\nADD_RANDOM_GUIDANCE_TEST_CASE = [\n    {\"keys\": \"NA\", \"guidance\": \"guidance\", \"discrepancy\": \"discrepancy\", \"probability\": \"probability\"},  # arguments\n    DATA_2,  # input_data\n    0,  # expected_result\n]\n\nDISCARD_ADD_GUIDANCE_TEST_CASE = [\n    {\"keys\": \"image\", \"label_names\": LABEL_NAMES},  # arguments\n    DATA_1,  # input_data\n    (3, 10, 10, 10),  # expected_result\n]\n\nFIND_DISCREPANCY_TEST_CASE = [\n    {\"keys\": \"label\", \"pred\": \"pred\", 
\"discrepancy\": \"discrepancy\"},  # arguments\n    DATA_7,  # input_data\n    240,  # expected_result\n]\n\nFIND_SLICE_TEST_CASE = [\n    {\"keys\": \"label\", \"sids\": \"sids\"},  # arguments\n    DATA_6,  # input_data\n    {\"spleen\": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], \"background\": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]},  # expected_result\n]\n\nNormalizeLabelsDatasetd_TEST_CASE = [\n    {\"keys\": \"label\", \"label_names\": LABEL_NAMES},  # arguments\n    DATA_8,  # input_data\n    len(LABEL_NAMES),  # expected_result\n]\n\nRESIZE_GUIDANCE_TEST_CASE = [\n    {\"guidance\": \"guidance\", \"ref_image\": \"image\"},  # arguments\n    DATA_9,  # input_data\n    {\"spleen\": [0, 2, 2], \"background\": [-1, -1, -1]},  # expected_result\n]\n\nSingleLabelSelectiond_TEST_CASE = [\n    {\"keys\": \"label\", \"label_names\": [\"spleen\"]},  # arguments\n    DATA_10,  # input_data\n    \"spleen\",  # expected_result\n]\n\nSplitPredsLabeld_TEST_CASE = [{\"keys\": \"pred\"}, DATA_11, (1, 10, 10)]  # arguments  # input_data  # expected_result\n\n\nclass TestAddGuidanceFromPointsCustomd(unittest.TestCase):\n\n    @parameterized.expand([ADD_GUIDANCE_FROM_POINTS_TEST_CASE])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        add_fn = AddGuidanceFromPointsDeepEditd(**arguments)\n        result = add_fn(input_data)\n        self.assertEqual(result[arguments[\"guidance\"]][\"spleen\"][0], expected_result)\n\n\nclass TestAddGuidanceSignalCustomd(unittest.TestCase):\n\n    @parameterized.expand([ADD_GUIDANCE_CUSTOM_TEST_CASE])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        add_fn = AddGuidanceSignalDeepEditd(**arguments)\n        result = add_fn(input_data)\n        self.assertEqual(result[\"image\"].shape[0], expected_result)\n\n\nclass TestAddInitialSeedPointMissingLabelsd(unittest.TestCase):\n\n    @parameterized.expand([ADD_INITIAL_POINT_TEST_CASE])\n    def test_correct_results(self, arguments, input_data, 
expected_result):\n        seed = 0\n        add_fn = AddInitialSeedPointMissingLabelsd(**arguments)\n        add_fn.set_random_state(seed)\n        result = add_fn(input_data)\n        self.assertEqual(result[arguments[\"guidance\"]], expected_result)\n\n\nclass TestAddRandomGuidanceCustomd(unittest.TestCase):\n\n    @parameterized.expand([ADD_RANDOM_GUIDANCE_TEST_CASE])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        add_fn = AddRandomGuidanceDeepEditd(**arguments)\n        result = add_fn(input_data)\n        label_key = list(result[arguments[\"guidance\"]].keys())[0]\n        self.assertGreaterEqual(len(result[arguments[\"guidance\"]][label_key]), expected_result)\n\n\nclass TestDiscardAddGuidanced(unittest.TestCase):\n\n    @parameterized.expand([DISCARD_ADD_GUIDANCE_TEST_CASE])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        add_fn = DiscardAddGuidanced(**arguments)\n        result = add_fn(input_data)\n        self.assertEqual(result[\"image\"].shape, expected_result)\n\n\nclass TestFindAllValidSlicesMissingLabelsd(unittest.TestCase):\n\n    @parameterized.expand([FIND_SLICE_TEST_CASE])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        add_fn = FindAllValidSlicesMissingLabelsd(**arguments)\n        result = add_fn(input_data)\n        self.assertEqual(result[arguments[\"sids\"]], expected_result)\n\n\nclass TestFindDiscrepancyRegionsCustomd(unittest.TestCase):\n\n    @parameterized.expand([FIND_DISCREPANCY_TEST_CASE])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        add_fn = FindDiscrepancyRegionsDeepEditd(**arguments)\n        result = add_fn(input_data)\n        self.assertEqual(np.sum(result[arguments[\"discrepancy\"]][\"spleen\"][0]), expected_result)\n\n\nclass TestNormalizeLabelsDatasetd(unittest.TestCase):\n\n    @parameterized.expand([NormalizeLabelsDatasetd_TEST_CASE])\n    def 
test_correct_results(self, arguments, input_data, expected_result):\n        add_fn = NormalizeLabelsInDatasetd(**arguments)\n        result = add_fn(input_data)\n        self.assertEqual(len(np.unique(result[\"label\"])), expected_result)\n\n    def test_ordering_determinism(self):\n        \"\"\"Test that different input ordering produces the same output (alphabetical)\"\"\"\n        # Create a label array with different label values\n        label = np.array([[[0, 1, 6, 3]]])  # background=0, spleen=1, liver=6, kidney=3\n\n        # Test case 1: liver first, then kidney, then spleen\n        data1 = {\"label\": label.copy()}\n        transform1 = RemapLabelsToSequentiald(\n            keys=\"label\", label_names={\"liver\": 6, \"kidney\": 3, \"spleen\": 1, \"background\": 0}\n        )\n        result1 = transform1(data1)\n\n        # Test case 2: spleen first, then kidney, then liver (different order)\n        data2 = {\"label\": label.copy()}\n        transform2 = RemapLabelsToSequentiald(\n            keys=\"label\", label_names={\"spleen\": 1, \"kidney\": 3, \"liver\": 6, \"background\": 0}\n        )\n        result2 = transform2(data2)\n\n        # Both should produce the same output (alphabetically sorted)\n        # Expected mapping: background=0, kidney=1, liver=2, spleen=3\n        np.testing.assert_array_equal(result1[\"label\"], result2[\"label\"])\n\n        # Verify the actual mapping is alphabetical\n        expected_output = np.array([[[0, 3, 2, 1]]])  # kidney=1, liver=2, spleen=3, background=0\n        np.testing.assert_array_equal(result1[\"label\"], expected_output)\n\n        # Verify label_names is correct\n        self.assertEqual(result1[\"label_names\"], {\"background\": 0, \"kidney\": 1, \"liver\": 2, \"spleen\": 3})\n        self.assertEqual(result2[\"label_names\"], {\"background\": 0, \"kidney\": 1, \"liver\": 2, \"spleen\": 3})\n\n    def test_multiple_labels(self):\n        \"\"\"Test with multiple non-background labels\"\"\"\n     
   label = np.array([[[0, 1, 2, 5]]])  # background, spleen, kidney, liver\n        data = {\"label\": label.copy()}\n        transform = RemapLabelsToSequentiald(\n            keys=\"label\", label_names={\"spleen\": 1, \"kidney\": 2, \"liver\": 5, \"background\": 0}\n        )\n        result = transform(data)\n\n        # Expected: background=0, kidney=1, liver=2, spleen=3 (alphabetical)\n        expected = np.array([[[0, 3, 1, 2]]])\n        np.testing.assert_array_equal(result[\"label\"], expected)\n        self.assertEqual(result[\"label_names\"], {\"background\": 0, \"kidney\": 1, \"liver\": 2, \"spleen\": 3})\n\n    def test_deprecated_name_warning(self):\n        \"\"\"Test that NormalizeLabelsInDatasetd is properly deprecated.\n\n        The deprecation warning only triggers when MONAI version >= 1.6 (since=\"1.6\").\n        This test verifies:\n        1. The actual NormalizeLabelsInDatasetd class is marked as deprecated in docstring\n        2. The class is a subclass of RemapLabelsToSequentiald\n        3. The deprecation mechanism works correctly (tested via version_val simulation)\n        4. 
The actual class functions correctly\n        \"\"\"\n        import warnings\n\n        from monai.utils import deprecated\n\n        # Verify NormalizeLabelsInDatasetd docstring indicates deprecation\n        self.assertIn(\"deprecated\", NormalizeLabelsInDatasetd.__doc__.lower())\n        self.assertIn(\"RemapLabelsToSequentiald\", NormalizeLabelsInDatasetd.__doc__)\n\n        # Verify NormalizeLabelsInDatasetd is a subclass of RemapLabelsToSequentiald\n        self.assertTrue(issubclass(NormalizeLabelsInDatasetd, RemapLabelsToSequentiald))\n\n        # Test the deprecation mechanism using version_val to simulate version 1.6\n        # This verifies the @deprecated decorator behavior that NormalizeLabelsInDatasetd uses\n        @deprecated(\n            since=\"1.6\",\n            removed=\"1.8\",\n            msg_suffix=\"Use `RemapLabelsToSequentiald` instead.\",\n            version_val=\"1.6\",  # Simulate version 1.6 to trigger warning\n        )\n        class DeprecatedNormalizeLabels(RemapLabelsToSequentiald):\n            pass\n\n        data = {\"label\": np.array([[[0, 1]]])}\n\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            transform = DeprecatedNormalizeLabels(keys=\"label\", label_names={\"spleen\": 1, \"background\": 0})\n            _ = transform(data)\n\n            # Check that a deprecation warning was raised\n            self.assertEqual(len(w), 1)\n            self.assertTrue(issubclass(w[0].category, FutureWarning))\n            self.assertIn(\"RemapLabelsToSequentiald\", str(w[0].message))\n\n        # Verify the actual NormalizeLabelsInDatasetd class works correctly\n        transform_actual = NormalizeLabelsInDatasetd(keys=\"label\", label_names={\"spleen\": 1, \"background\": 0})\n        result = transform_actual({\"label\": np.array([[[0, 1]]])})\n        self.assertIn(\"label\", result)\n\n\nclass TestResizeGuidanceMultipleLabelCustomd(unittest.TestCase):\n\n    
@parameterized.expand([RESIZE_GUIDANCE_TEST_CASE])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        add_fn = ResizeGuidanceMultipleLabelDeepEditd(**arguments)\n        result = add_fn(input_data)\n        self.assertEqual(result[arguments[\"guidance\"]], expected_result)\n\n\nclass TestSingleLabelSelectiond(unittest.TestCase):\n\n    @parameterized.expand([SingleLabelSelectiond_TEST_CASE])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        add_fn = SingleLabelSelectiond(**arguments)\n        result = add_fn(input_data)\n        self.assertEqual(result[\"current_label\"], expected_result)\n\n\nclass TestSplitPredsLabeld(unittest.TestCase):\n\n    @parameterized.expand([SplitPredsLabeld_TEST_CASE])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        add_fn = SplitPredsLabeld(**arguments)\n        result = add_fn(input_data)\n        self.assertEqual(result[\"pred_spleen\"].shape, expected_result)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/deepgrow/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/deepgrow/test_deepgrow_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.deepgrow.dataset import create_dataset\nfrom monai.utils import set_determinism\n\nTEST_CASE_1 = [{\"dimension\": 2, \"pixdim\": (1, 1)}, {\"length\": 3}, 9, 1]\n\nTEST_CASE_2 = [{\"dimension\": 2, \"pixdim\": (1, 1), \"limit\": 1}, {\"length\": 3}, 3, 1]\n\nTEST_CASE_3 = [{\"dimension\": 2, \"pixdim\": (1, 1)}, {\"length\": 1}, 3, 1]\n\nTEST_CASE_4 = [{\"dimension\": 3, \"pixdim\": (1, 1, 1)}, {\"length\": 1}, 1, 1]\n\nTEST_CASE_5 = [{\"dimension\": 3, \"pixdim\": (1, 1, 1)}, {\"length\": 1, \"image_channel\": 1}, 1, 1]\n\nTEST_CASE_6 = [{\"dimension\": 2, \"pixdim\": (1, 1)}, {\"length\": 1, \"image_channel\": 1}, 3, 1]\n\nTEST_CASE_7 = [\n    {\"dimension\": 2, \"pixdim\": (1, 1), \"label_key\": None},\n    {\"length\": 1, \"image_channel\": 1, \"with_label\": False},\n    40,\n    None,\n]\n\nTEST_CASE_8 = [\n    {\"dimension\": 3, \"pixdim\": (1, 1, 1), \"label_key\": None},\n    {\"length\": 1, \"image_channel\": 1, \"with_label\": False},\n    1,\n    None,\n]\n\n\nclass TestCreateDataset(unittest.TestCase):\n\n    def setUp(self):\n        set_determinism(1)\n        self.tempdir = tempfile.mkdtemp()\n\n    def _create_data(self, length=1, image_channel=1, 
with_label=True):\n        affine = np.eye(4)\n        datalist = []\n        for i in range(length):\n            if image_channel == 1:\n                image = np.random.randint(0, 2, size=(128, 128, 40))\n            else:\n                image = np.random.randint(0, 2, size=(128, 128, 40, image_channel))\n            image_file = os.path.join(self.tempdir, f\"image{i}.nii.gz\")\n            nib.save(nib.Nifti1Image(image.astype(float), affine), image_file)\n\n            if with_label:\n                # 3 slices has label\n                label = np.zeros((128, 128, 40))\n                label[0][1][0] = 1\n                label[0][1][1] = 1\n                label[0][0][2] = 1\n                label[0][1][2] = 1\n                label_file = os.path.join(self.tempdir, f\"label{i}.nii.gz\")\n                nib.save(nib.Nifti1Image(label.astype(float), affine), label_file)\n                datalist.append({\"image\": image_file, \"label\": label_file})\n            else:\n                datalist.append({\"image\": image_file})\n\n        return datalist\n\n    @parameterized.expand(\n        [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]\n    )\n    def test_create_dataset(self, args, data_args, expected_length, expected_region):\n        datalist = self._create_data(**data_args)\n        deepgrow_datalist = create_dataset(datalist=datalist, output_dir=self.tempdir, **args)\n        self.assertEqual(len(deepgrow_datalist), expected_length)\n        if expected_region is not None:\n            self.assertEqual(deepgrow_datalist[0][\"region\"], expected_region)\n\n    def test_invalid_dim(self):\n        with self.assertRaises(ValueError):\n            create_dataset(datalist=self._create_data(), output_dir=self.tempdir, dimension=4, pixdim=(1, 1, 1, 1))\n\n    def test_empty_datalist(self):\n        with self.assertRaises(ValueError):\n            create_dataset(datalist=[], output_dir=self.tempdir, 
dimension=3, pixdim=(1, 1, 1))\n\n    def tearDown(self):\n        shutil.rmtree(self.tempdir)\n        set_determinism(None)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/deepgrow/transforms/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/deepgrow/transforms/test_deepgrow_interaction.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.apps.deepgrow.interaction import Interaction\nfrom monai.apps.deepgrow.transforms import (\n    AddGuidanceSignald,\n    AddInitialSeedPointd,\n    AddRandomGuidanced,\n    FindAllValidSlicesd,\n    FindDiscrepancyRegionsd,\n)\nfrom monai.data import Dataset\nfrom monai.engines import SupervisedTrainer\nfrom monai.engines.utils import IterationEvents\nfrom monai.transforms import Activationsd, Compose, ToNumpyd, ToTensord\n\n\ndef add_one(engine):\n    if engine.state.best_metric == -1:\n        engine.state.best_metric = 0\n    else:\n        engine.state.best_metric = engine.state.best_metric + 1\n\n\nclass TestInteractions(unittest.TestCase):\n\n    def run_interaction(self, train, compose):\n        data = [{\"image\": np.ones((1, 2, 2, 2)).astype(np.float32), \"label\": np.ones((1, 2, 2, 2))} for _ in range(5)]\n        network = torch.nn.Linear(2, 2)\n        lr = 1e-3\n        opt = torch.optim.SGD(network.parameters(), lr)\n        loss = torch.nn.L1Loss()\n        train_transforms = Compose(\n            [\n                FindAllValidSlicesd(label=\"label\", sids=\"sids\"),\n                AddInitialSeedPointd(label=\"label\", guidance=\"guidance\", sids=\"sids\"),\n                AddGuidanceSignald(image=\"image\", guidance=\"guidance\"),\n                
ToTensord(keys=(\"image\", \"label\")),\n            ]\n        )\n        dataset = Dataset(data, transform=train_transforms)\n        data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)\n\n        iteration_transforms = [\n            Activationsd(keys=\"pred\", sigmoid=True),\n            ToNumpyd(keys=[\"image\", \"label\", \"pred\"]),\n            FindDiscrepancyRegionsd(label=\"label\", pred=\"pred\", discrepancy=\"discrepancy\"),\n            AddRandomGuidanced(guidance=\"guidance\", discrepancy=\"discrepancy\", probability=\"probability\"),\n            AddGuidanceSignald(image=\"image\", guidance=\"guidance\"),\n            ToTensord(keys=(\"image\", \"label\")),\n        ]\n        iteration_transforms = Compose(iteration_transforms) if compose else iteration_transforms\n\n        i = Interaction(transforms=iteration_transforms, train=train, max_interactions=5)\n        self.assertEqual(len(i.transforms.transforms), 6, \"Mismatch in expected transforms\")\n\n        # set up engine\n        engine = SupervisedTrainer(\n            device=torch.device(\"cpu\"),\n            max_epochs=1,\n            train_data_loader=data_loader,\n            network=network,\n            optimizer=opt,\n            loss_function=loss,\n            iteration_update=i,\n        )\n        engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one)\n        engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one)\n\n        engine.run()\n        self.assertIsNotNone(engine.state.batch[0].get(\"guidance\"), \"guidance is missing\")\n        self.assertEqual(engine.state.best_metric, 9)\n\n    def test_train_interaction(self):\n        self.run_interaction(train=True, compose=True)\n\n    def test_val_interaction(self):\n        self.run_interaction(train=False, compose=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/deepgrow/transforms/test_deepgrow_transforms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.deepgrow.transforms import (\n    AddGuidanceFromPointsd,\n    AddGuidanceSignald,\n    AddInitialSeedPointd,\n    AddRandomGuidanced,\n    Fetch2DSliced,\n    FindAllValidSlicesd,\n    FindDiscrepancyRegionsd,\n    ResizeGuidanced,\n    RestoreLabeld,\n    SpatialCropForegroundd,\n    SpatialCropGuidanced,\n)\nfrom monai.utils.enums import PostFix\n\nIMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]])\nLABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]])\n\nDATA_1 = {\"image\": IMAGE, \"label\": LABEL, PostFix.meta(\"image\"): {}, PostFix.meta(\"label\"): {}}\n\nDATA_2 = {\n    \"image\": np.array(\n        [\n            [\n                [[1, 2, 3, 2, 1], [1, 1, 3, 2, 1], [0, 0, 0, 0, 0], [1, 1, 1, 2, 1], [0, 2, 2, 2, 1]],\n                [[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]],\n            ]\n        ]\n    ),\n    \"label\": np.array(\n        [\n            [\n                [[0, 0, 1, 0, 0], [0, 1, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 0, 1, 0, 0]],\n                [[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]],\n            
]\n        ]\n    ),\n    \"guidance\": np.array([[[1, 0, 2, 2], [1, 1, 2, 2]], [[-1, -1, -1, -1], [-1, -1, -1, -1]]]),\n}\n\nDATA_3 = {\n    \"image\": IMAGE,\n    \"label\": LABEL,\n    \"pred\": np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]),\n}\n\nDATA_4 = {\n    \"image\": IMAGE,\n    \"label\": LABEL,\n    \"guidance\": np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]),\n    \"discrepancy\": np.array(\n        [\n            [[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],\n            [[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],\n        ]\n    ),\n    \"probability\": 1.0,\n}\n\nDATA_5 = {\n    \"image\": np.arange(25).reshape((1, 5, 5)),\n    PostFix.meta(\"image\"): {\"spatial_shape\": [5, 5, 1]},\n    \"foreground\": [[2, 2, 0]],\n    \"background\": [],\n}\n\nDATA_6 = {\n    \"image\": np.arange(25).reshape((1, 5, 5)),\n    PostFix.meta(\"image\"): {\"spatial_shape\": [5, 2, 1]},\n    \"foreground\": [[2, 1, 0]],\n    \"background\": [[1, 0, 0]],\n}\n\nDATA_7 = {\n    \"image\": np.arange(500).reshape((5, 10, 10)),\n    PostFix.meta(\"image\"): {\"spatial_shape\": [20, 20, 10]},\n    \"foreground\": [[10, 14, 6], [10, 14, 8]],\n    \"background\": [[10, 16, 8]],\n    \"slice\": 6,\n}\n\nDATA_8 = {\n    \"image\": np.arange(500).reshape((1, 5, 10, 10)),\n    PostFix.meta(\"image\"): {\"spatial_shape\": [20, 20, 10]},\n    \"guidance\": [[[3, 5, 7], [4, 5, 7]], [[4, 5, 8]]],\n}\n\nDATA_9 = {\n    \"image\": np.arange(1000).reshape((1, 5, 10, 20)),\n    PostFix.meta(\"image\"): {\"foreground_cropped_shape\": (1, 10, 20, 40)},\n    \"guidance\": [[[6, 10, 14], [8, 10, 14]], [[8, 10, 16]]],\n}\n\nDATA_10 = {\n    \"image\": np.arange(9).reshape((1, 1, 3, 3)),\n    PostFix.meta(\"image\"): {\n        \"spatial_shape\": [3, 3, 1],\n        \"foreground_start_coord\": np.array([0, 0, 0]),\n        \"foreground_end_coord\": 
np.array([1, 3, 3]),\n        \"foreground_original_shape\": (1, 1, 3, 3),\n        \"foreground_cropped_shape\": (1, 1, 3, 3),\n        \"original_affine\": np.array(\n            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]\n        ),\n    },\n    \"pred\": np.array([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]]),\n}\n\nDATA_11 = {\n    \"image\": np.arange(500).reshape((1, 5, 10, 10)),\n    PostFix.meta(\"image\"): {\n        \"spatial_shape\": [20, 20, 10],\n        \"foreground_start_coord\": np.array([2, 2, 2]),\n        \"foreground_end_coord\": np.array([4, 4, 4]),\n        \"foreground_original_shape\": (1, 5, 10, 10),\n        \"foreground_cropped_shape\": (1, 2, 2, 2),\n        \"original_affine\": np.array(\n            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]\n        ),\n    },\n    \"pred\": np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]),\n}\n\nDATA_12 = {\"image\": np.arange(27).reshape(3, 3, 3), PostFix.meta(\"image\"): {}, \"guidance\": [[0, 0, 0], [0, 1, 1], 1]}\n\nDATA_13 = {\n    \"image\": np.arange(64).reshape((1, 4, 4, 4)),\n    PostFix.meta(\"image\"): {\n        \"spatial_shape\": [8, 8, 4],\n        \"foreground_start_coord\": np.array([1, 1, 1]),\n        \"foreground_end_coord\": np.array([3, 3, 3]),\n        \"foreground_original_shape\": (1, 4, 4, 4),\n        \"foreground_cropped_shape\": (1, 2, 2, 2),\n        \"original_affine\": np.array(\n            [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]]\n        ),\n    },\n    \"pred\": np.array([[[[10, 20], [30, 40]], [[50, 60], [70, 80]]]]),\n}\n\nFIND_SLICE_TEST_CASE_1 = [{\"label\": \"label\", \"sids\": \"sids\"}, DATA_1, [0]]\n\nFIND_SLICE_TEST_CASE_2 = [{\"label\": \"label\", \"sids\": \"sids\"}, DATA_2, [0, 1]]\n\nCROP_TEST_CASE_1 = [\n    {\n        \"keys\": [\"image\", \"label\"],\n        \"source_key\": \"label\",\n        \"select_fn\": 
lambda x: x > 0,\n        \"channel_indices\": None,\n        \"margin\": 0,\n        \"spatial_size\": [1, 4, 4],\n    },\n    DATA_1,\n    np.array([[[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]]),\n]\n\nCROP_TEST_CASE_2 = [\n    {\n        \"keys\": [\"image\", \"label\"],\n        \"source_key\": \"label\",\n        \"select_fn\": lambda x: x > 0,\n        \"channel_indices\": None,\n        \"margin\": 0,\n        \"spatial_size\": [2, 4, 4],\n    },\n    DATA_1,\n    np.array([1, 1, 4, 4]),\n]\n\nADD_INITIAL_POINT_TEST_CASE_1 = [\n    {\"label\": \"label\", \"guidance\": \"guidance\", \"sids\": \"sids\"},\n    DATA_1,\n    \"[[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]\",\n]\n\nADD_GUIDANCE_TEST_CASE_1 = [\n    {\"image\": \"image\", \"guidance\": \"guidance\"},\n    DATA_2,\n    np.array(\n        [\n            [\n                [[1, 2, 3, 2, 1], [1, 1, 3, 2, 1], [0, 0, 0, 0, 0], [1, 1, 1, 2, 1], [0, 2, 2, 2, 1]],\n                [[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]],\n            ],\n            [\n                [\n                    [0.0, 0.26689214, 0.37996644, 0.26689214, 0.0],\n                    [0.26689214, 0.65222847, 0.81548417, 0.65222847, 0.26689214],\n                    [0.37996635, 0.81548399, 1.0, 0.81548399, 0.37996635],\n                    [0.26689214, 0.65222847, 0.81548417, 0.65222847, 0.26689214],\n                    [0.0, 0.26689214, 0.37996644, 0.26689214, 0.0],\n                ],\n                [\n                    [0.0, 0.26689214, 0.37996644, 0.26689214, 0.0],\n                    [0.26689214, 0.65222847, 0.81548417, 0.65222847, 0.26689214],\n                    [0.37996635, 0.81548399, 1.0, 0.81548399, 0.37996635],\n                    [0.26689214, 0.65222847, 0.81548417, 0.65222847, 0.26689214],\n                    [0.0, 0.26689214, 0.37996644, 0.26689214, 0.0],\n                ],\n            ],\n            [\n                [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 
0, 0, 0], [0, 0, 0, 0, 0]],\n                [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n            ],\n        ]\n    ),\n]\n\nFIND_DISCREPANCY_TEST_CASE_1 = [\n    {\"label\": \"label\", \"pred\": \"pred\", \"discrepancy\": \"discrepancy\"},\n    DATA_3,\n    np.array(\n        [\n            [[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],\n            [[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],\n        ]\n    ),\n]\n\nADD_RANDOM_GUIDANCE_TEST_CASE_1 = [\n    {\"guidance\": \"guidance\", \"discrepancy\": \"discrepancy\", \"probability\": \"probability\"},\n    DATA_4,\n    \"[[[1, 0, 2, 2], [1, 0, 1, 3]], [[-1, -1, -1, -1], [-1, -1, -1, -1]]]\",\n]\n\nADD_GUIDANCE_FROM_POINTS_TEST_CASE_1 = [\n    {\"ref_image\": \"image\", \"spatial_dims\": 3, \"guidance\": \"guidance\", \"depth_first\": True},\n    DATA_5,\n    [[0, 2, 2]],\n    [],\n]\n\nADD_GUIDANCE_FROM_POINTS_TEST_CASE_2 = [\n    {\"ref_image\": \"image\", \"spatial_dims\": 3, \"guidance\": \"guidance\", \"depth_first\": True},\n    DATA_6,\n    [[0, 2, 2]],\n    [[0, 1, 0]],\n]\n\nADD_GUIDANCE_FROM_POINTS_TEST_CASE_3 = [\n    {\"ref_image\": \"image\", \"spatial_dims\": 3, \"guidance\": \"guidance\", \"depth_first\": True},\n    DATA_7,\n    [[3, 5, 7], [4, 5, 7]],\n    [[4, 5, 8]],\n]\n\nADD_GUIDANCE_FROM_POINTS_TEST_CASE_4 = [\n    {\"ref_image\": \"image\", \"spatial_dims\": 2, \"guidance\": \"guidance\", \"depth_first\": True},\n    DATA_6,\n    [[2, 2]],\n    [[1, 0]],\n]\n\nADD_GUIDANCE_FROM_POINTS_TEST_CASE_5 = [\n    {\"ref_image\": \"image\", \"spatial_dims\": 2, \"guidance\": \"guidance\", \"depth_first\": True, \"slice_key\": \"slice\"},\n    DATA_7,\n    [[5, 7]],\n    [],\n]\n\nADD_GUIDANCE_FROM_POINTS_TEST_CASE_6 = [\n    {\"ref_image\": \"image\", \"spatial_dims\": 2, \"guidance\": \"guidance\", \"depth_first\": True},\n    DATA_5,\n    [[2, 2]],\n    
[],\n]\n\nSPATIAL_CROP_GUIDANCE_TEST_CASE_1 = [\n    {\"keys\": [\"image\"], \"guidance\": \"guidance\", \"spatial_size\": [1, 4, 4], \"margin\": 0},\n    DATA_8,\n    np.array([[[[357, 358]], [[457, 458]]]]),\n]\n\nSPATIAL_CROP_GUIDANCE_TEST_CASE_2 = [\n    {\"keys\": [\"image\"], \"guidance\": \"guidance\", \"spatial_size\": [2, 2], \"margin\": 1},\n    DATA_8,\n    np.array(\n        [\n            [\n                [[246, 247, 248, 249], [256, 257, 258, 259], [266, 267, 268, 269]],\n                [[346, 347, 348, 349], [356, 357, 358, 359], [366, 367, 368, 369]],\n                [[446, 447, 448, 449], [456, 457, 458, 459], [466, 467, 468, 469]],\n            ]\n        ]\n    ),\n]\n\nSPATIAL_CROP_GUIDANCE_TEST_CASE_3 = [\n    {\"keys\": [\"image\"], \"guidance\": \"guidance\", \"spatial_size\": [3, 3], \"margin\": 0},\n    DATA_8,\n    np.array(\n        [\n            [\n                [[47, 48, 49], [57, 58, 59], [67, 68, 69]],\n                [[147, 148, 149], [157, 158, 159], [167, 168, 169]],\n                [[247, 248, 249], [257, 258, 259], [267, 268, 269]],\n                [[347, 348, 349], [357, 358, 359], [367, 368, 369]],\n                [[447, 448, 449], [457, 458, 459], [467, 468, 469]],\n            ]\n        ]\n    ),\n]\n\nRESIZE_GUIDANCE_TEST_CASE_1 = [\n    {\"ref_image\": \"image\", \"guidance\": \"guidance\"},\n    DATA_9,\n    [[[3, 5, 7], [4, 5, 7]], [[4, 5, 8]]],\n]\n\nRESTORE_LABEL_TEST_CASE_1 = [\n    {\"keys\": [\"pred\"], \"ref_image\": \"image\", \"mode\": \"nearest\"},\n    DATA_10,\n    np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]),\n]\n\nRESULT = np.zeros((10, 20, 20))\nRESULT[4:8, 4:8, 4:8] = np.array(\n    [\n        [[1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0], [3.0, 3.0, 4.0, 4.0]],\n        [[1.0, 1.0, 2.0, 2.0], [1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0], [3.0, 3.0, 4.0, 4.0]],\n        [[5.0, 5.0, 6.0, 6.0], [5.0, 5.0, 6.0, 6.0], [7.0, 7.0, 8.0, 8.0], [7.0, 7.0, 8.0, 8.0]],\n        [[5.0, 
5.0, 6.0, 6.0], [5.0, 5.0, 6.0, 6.0], [7.0, 7.0, 8.0, 8.0], [7.0, 7.0, 8.0, 8.0]],\n    ]\n)\n\nRESTORE_LABEL_TEST_CASE_2 = [{\"keys\": [\"pred\"], \"ref_image\": \"image\", \"mode\": \"nearest\"}, DATA_11, RESULT]\n\nRESTORE_LABEL_TEST_CASE_3_RESULT = np.zeros((10, 20, 20))\nRESTORE_LABEL_TEST_CASE_3_RESULT[:5, 0:10, 0:10] = 1\nRESTORE_LABEL_TEST_CASE_3_RESULT[:5, 0:10, 10:20] = 2\nRESTORE_LABEL_TEST_CASE_3_RESULT[:5, 10:20, 0:10] = 3\nRESTORE_LABEL_TEST_CASE_3_RESULT[:5, 10:20, 10:20] = 4\nRESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 0:10, 0:10] = 5\nRESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 0:10, 10:20] = 6\nRESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 10:20, 0:10] = 7\nRESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 10:20, 10:20] = 8\n\nRESTORE_LABEL_TEST_CASE_3 = [\n    {\"keys\": [\"pred\"], \"ref_image\": \"image\", \"mode\": \"nearest\", \"restore_cropping\": False},\n    DATA_11,\n    RESTORE_LABEL_TEST_CASE_3_RESULT,\n]\n\nRESTORE_LABEL_TEST_CASE_4_RESULT = np.zeros((4, 8, 8))\nRESTORE_LABEL_TEST_CASE_4_RESULT[1, 2:6, 2:6] = np.array(\n    [[10.0, 10.0, 20.0, 20.0], [10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 40.0, 40.0], [30.0, 30.0, 40.0, 40.0]]\n)\nRESTORE_LABEL_TEST_CASE_4_RESULT[2, 2:6, 2:6] = np.array(\n    [[50.0, 50.0, 60.0, 60.0], [50.0, 50.0, 60.0, 60.0], [70.0, 70.0, 80.0, 80.0], [70.0, 70.0, 80.0, 80.0]]\n)\n\nRESTORE_LABEL_TEST_CASE_4 = [\n    {\"keys\": [\"pred\"], \"ref_image\": \"image\", \"mode\": \"nearest\", \"restore_resizing\": False},\n    DATA_13,\n    RESTORE_LABEL_TEST_CASE_4_RESULT,\n]\n\nRESTORE_LABEL_TEST_CASE_5_RESULT = np.zeros((4, 4, 4))\nRESTORE_LABEL_TEST_CASE_5_RESULT[1, 1:3, 1:3] = np.array([[10.0, 20.0], [30.0, 40.0]])\nRESTORE_LABEL_TEST_CASE_5_RESULT[2, 1:3, 1:3] = np.array([[50.0, 60.0], [70.0, 80.0]])\n\nRESTORE_LABEL_TEST_CASE_5 = [\n    {\"keys\": [\"pred\"], \"ref_image\": \"image\", \"mode\": \"nearest\", \"restore_spacing\": False},\n    DATA_13,\n    RESTORE_LABEL_TEST_CASE_5_RESULT,\n]\n\nRESTORE_LABEL_TEST_CASE_6_RESULT = np.zeros((1, 
4, 8, 8))\nRESTORE_LABEL_TEST_CASE_6_RESULT[-1, 1, 2:6, 2:6] = np.array(\n    [[10.0, 10.0, 20.0, 20.0], [10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 40.0, 40.0], [30.0, 30.0, 40.0, 40.0]]\n)\nRESTORE_LABEL_TEST_CASE_6_RESULT[-1, 2, 2:6, 2:6] = np.array(\n    [[50.0, 50.0, 60.0, 60.0], [50.0, 50.0, 60.0, 60.0], [70.0, 70.0, 80.0, 80.0], [70.0, 70.0, 80.0, 80.0]]\n)\n\nRESTORE_LABEL_TEST_CASE_6 = [\n    {\"keys\": [\"pred\"], \"ref_image\": \"image\", \"mode\": \"nearest\", \"restore_slicing\": False},\n    DATA_13,\n    RESTORE_LABEL_TEST_CASE_6_RESULT,\n]\n\nRESTORE_LABEL_TEST_CASE_7 = [\n    {\n        \"keys\": [\"pred\"],\n        \"ref_image\": \"image\",\n        \"mode\": \"nearest\",\n        \"restore_resizing\": False,\n        \"restore_cropping\": False,\n        \"restore_spacing\": False,\n        \"restore_slicing\": False,\n    },\n    DATA_11,\n    np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]),\n]\n\nFETCH_2D_SLICE_TEST_CASE_1 = [\n    {\"keys\": [\"image\"], \"guidance\": \"guidance\"},\n    DATA_12,\n    np.array([[9, 10, 11], [12, 13, 14], [15, 16, 17]]),\n]\n\n\nclass TestFindAllValidSlicesd(unittest.TestCase):\n\n    @parameterized.expand([FIND_SLICE_TEST_CASE_1, FIND_SLICE_TEST_CASE_2])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        result = FindAllValidSlicesd(**arguments)(input_data)\n        np.testing.assert_allclose(result[arguments[\"sids\"]], expected_result)\n\n\nclass TestSpatialCropForegroundd(unittest.TestCase):\n\n    @parameterized.expand([CROP_TEST_CASE_1])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        result = SpatialCropForegroundd(**arguments)(input_data)\n        np.testing.assert_allclose(result[\"image\"], expected_result)\n\n    @parameterized.expand([CROP_TEST_CASE_2])\n    def test_correct_shape(self, arguments, input_data, expected_shape):\n        result = SpatialCropForegroundd(**arguments)(input_data)\n        
np.testing.assert_equal(result[\"image\"].shape, expected_shape)\n\n    @parameterized.expand([CROP_TEST_CASE_1])\n    def test_foreground_position(self, arguments, input_data, _):\n        result = SpatialCropForegroundd(**arguments)(input_data)\n        np.testing.assert_allclose(result[PostFix.meta(\"image\")][\"foreground_start_coord\"], np.array([0, 1, 1]))\n        np.testing.assert_allclose(result[PostFix.meta(\"image\")][\"foreground_end_coord\"], np.array([1, 4, 4]))\n\n        arguments[\"start_coord_key\"] = \"test_start_coord\"\n        arguments[\"end_coord_key\"] = \"test_end_coord\"\n        result = SpatialCropForegroundd(**arguments)(input_data)\n        np.testing.assert_allclose(result[PostFix.meta(\"image\")][\"test_start_coord\"], np.array([0, 1, 1]))\n        np.testing.assert_allclose(result[PostFix.meta(\"image\")][\"test_end_coord\"], np.array([1, 4, 4]))\n\n\nclass TestAddInitialSeedPointd(unittest.TestCase):\n\n    @parameterized.expand([ADD_INITIAL_POINT_TEST_CASE_1])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        seed = 0\n        add_fn = AddInitialSeedPointd(**arguments)\n        add_fn.set_random_state(seed)\n        result = add_fn(input_data)\n        self.assertEqual(result[arguments[\"guidance\"]], expected_result)\n\n\nclass TestAddGuidanceSignald(unittest.TestCase):\n\n    @parameterized.expand([ADD_GUIDANCE_TEST_CASE_1])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        result = AddGuidanceSignald(**arguments)(input_data)\n        np.testing.assert_allclose(result[\"image\"], expected_result, rtol=1e-5)\n\n\nclass TestFindDiscrepancyRegionsd(unittest.TestCase):\n\n    @parameterized.expand([FIND_DISCREPANCY_TEST_CASE_1])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        result = FindDiscrepancyRegionsd(**arguments)(input_data)\n        np.testing.assert_allclose(result[arguments[\"discrepancy\"]], 
expected_result)\n\n\nclass TestAddRandomGuidanced(unittest.TestCase):\n\n    @parameterized.expand([ADD_RANDOM_GUIDANCE_TEST_CASE_1])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        seed = 0\n        add_fn = AddRandomGuidanced(**arguments)\n        add_fn.set_random_state(seed)\n        result = add_fn(input_data)\n        self.assertEqual(result[arguments[\"guidance\"]], expected_result)\n\n\nclass TestAddGuidanceFromPointsd(unittest.TestCase):\n\n    @parameterized.expand(\n        [\n            ADD_GUIDANCE_FROM_POINTS_TEST_CASE_1,\n            ADD_GUIDANCE_FROM_POINTS_TEST_CASE_2,\n            ADD_GUIDANCE_FROM_POINTS_TEST_CASE_3,\n            ADD_GUIDANCE_FROM_POINTS_TEST_CASE_4,\n            ADD_GUIDANCE_FROM_POINTS_TEST_CASE_5,\n            ADD_GUIDANCE_FROM_POINTS_TEST_CASE_6,\n        ]\n    )\n    def test_correct_results(self, arguments, input_data, expected_pos, expected_neg):\n        result = AddGuidanceFromPointsd(**arguments)(input_data)\n        self.assertEqual(result[arguments[\"guidance\"]][0], expected_pos)\n        self.assertEqual(result[arguments[\"guidance\"]][1], expected_neg)\n\n\nclass TestSpatialCropGuidanced(unittest.TestCase):\n\n    @parameterized.expand(\n        [SPATIAL_CROP_GUIDANCE_TEST_CASE_1, SPATIAL_CROP_GUIDANCE_TEST_CASE_2, SPATIAL_CROP_GUIDANCE_TEST_CASE_3]\n    )\n    def test_correct_results(self, arguments, input_data, expected_result):\n        result = SpatialCropGuidanced(**arguments)(input_data)\n        np.testing.assert_allclose(result[\"image\"], expected_result)\n\n\nclass TestResizeGuidanced(unittest.TestCase):\n\n    @parameterized.expand([RESIZE_GUIDANCE_TEST_CASE_1])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        result = ResizeGuidanced(**arguments)(input_data)\n        self.assertEqual(result[arguments[\"guidance\"]], expected_result)\n\n\nclass TestRestoreLabeld(unittest.TestCase):\n\n    @parameterized.expand(\n        [\n    
        RESTORE_LABEL_TEST_CASE_1,\n            RESTORE_LABEL_TEST_CASE_2,\n            RESTORE_LABEL_TEST_CASE_3,\n            RESTORE_LABEL_TEST_CASE_4,\n            RESTORE_LABEL_TEST_CASE_5,\n            RESTORE_LABEL_TEST_CASE_6,\n            RESTORE_LABEL_TEST_CASE_7,\n        ]\n    )\n    def test_correct_results(self, arguments, input_data, expected_result):\n        result = RestoreLabeld(**arguments)(input_data)\n        np.testing.assert_allclose(result[\"pred\"], expected_result)\n\n\nclass TestFetch2DSliced(unittest.TestCase):\n\n    @parameterized.expand([FETCH_2D_SLICE_TEST_CASE_1])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        result = Fetch2DSliced(**arguments)(input_data)\n        np.testing.assert_allclose(result[\"image\"], expected_result)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/detection/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/detection/metrics/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/detection/metrics/test_detection_coco_metrics.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport random\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.apps.detection.metrics.coco import COCOMetric\nfrom monai.apps.detection.metrics.matching import matching_batch\nfrom monai.data.box_utils import box_iou\n\n\nclass TestCOCOMetrics(unittest.TestCase):\n\n    def test_coco_run(self):\n        coco_metric = COCOMetric(classes=[\"c0\", \"c1\", \"c2\"], iou_list=[0.1], max_detection=[10])\n\n        num_images = 10\n\n        val_outputs_all = []\n        val_targets_all = []\n        for _ in range(num_images):\n            # randomly generate gt boxes and pred boxes\n            num_gt_boxes = random.randint(1, 3)\n            num_pred_boxes = random.randint(0, 3)\n\n            box_start = torch.randint(3, (num_pred_boxes, 3))\n            box_stop = box_start + torch.randint(1, 32, (num_pred_boxes, 3))\n            boxes = torch.cat((box_start, box_stop), dim=1).to(torch.float16)\n            val_outputs_all.append(\n                {\n                    \"boxes\": boxes,\n                    \"labels\": torch.randint(3, (num_pred_boxes,)),\n                    \"scores\": torch.randn((num_pred_boxes,)).absolute(),\n                }\n            )\n\n            box_start = torch.randint(3, (num_gt_boxes, 3))\n            box_stop = box_start + torch.randint(1, 32, (num_gt_boxes, 3))\n            boxes 
= torch.cat((box_start, box_stop), dim=1).to(torch.float16)\n            val_targets_all.append({\"boxes\": boxes, \"labels\": torch.randint(3, (num_gt_boxes,))})\n\n        results_metric = matching_batch(\n            iou_fn=box_iou,\n            iou_thresholds=coco_metric.iou_thresholds,\n            pred_boxes=[val_data_i[\"boxes\"].numpy() for val_data_i in val_outputs_all],\n            pred_classes=[val_data_i[\"labels\"].numpy() for val_data_i in val_outputs_all],\n            pred_scores=[val_data_i[\"scores\"].numpy() for val_data_i in val_outputs_all],\n            gt_boxes=[val_data_i[\"boxes\"].numpy() for val_data_i in val_targets_all],\n            gt_classes=[val_data_i[\"labels\"].numpy() for val_data_i in val_targets_all],\n        )\n        val_epoch_metric_dict = coco_metric(results_metric)[0]\n        np.testing.assert_array_less([-16.01], [sum(val_epoch_metric_dict.values())])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/detection/networks/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/detection/networks/test_retinanet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.detection.networks.retinanet_network import RetinaNet, resnet_fpn_feature_extractor\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200\nfrom monai.utils import ensure_tuple, optional_import\nfrom tests.test_utils import dict_product, skip_if_quick, test_onnx_save, test_script_save\n\n_, has_torchvision = optional_import(\"torchvision\")\n_, has_onnxruntime = optional_import(\"onnxruntime\")\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nnum_anchors = 7\n\nTEST_CASE_1 = [  # 3D, batch 3, 2 input channel\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 3,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": 7,\n        \"conv1_t_stride\": (2, 2, 2),\n    },\n    (3, 2, 32, 64, 48),\n]\n\nTEST_CASE_2 = [  # 2D, batch 2, 1 input channel\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 2,\n        \"n_input_channels\": 1,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [7, 7],\n        \"conv1_t_stride\": [2, 2],\n    },\n    (2, 1, 32, 64),\n]\n\nTEST_CASE_2_A = [  # 2D, batch 2, 1 input channel, shortcut type A\n    {\n        \"pretrained\": False,\n        
\"spatial_dims\": 2,\n        \"n_input_channels\": 1,\n        \"num_classes\": 3,\n        \"shortcut_type\": \"A\",\n        \"conv1_t_size\": (7, 7),\n        \"conv1_t_stride\": 2,\n    },\n    (2, 1, 32, 64),\n]\n\nTEST_CASE_3 = [  # 1D, batch 1, 2 input channels\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 1,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [3],\n        \"conv1_t_stride\": 1,\n    },\n    (1, 2, 32),\n]\n\nTEST_CASE_3_A = [  # 1D, batch 1, 2 input channels\n    {\"pretrained\": False, \"spatial_dims\": 1, \"n_input_channels\": 2, \"num_classes\": 3, \"shortcut_type\": \"A\"},\n    (1, 2, 32),\n]\n\nTEST_CASE_4 = [  # 2D, batch 2, 1 input channel\n    {\"pretrained\": False, \"spatial_dims\": 2, \"n_input_channels\": 1, \"num_classes\": 3, \"feed_forward\": False},\n    (2, 1, 32, 64),\n]\n\n# Create all test case combinations using dict_product\nCASE_LIST = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]\nMODEL_LIST = [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]\n\nTEST_CASES = [[params[\"model\"], *params[\"case\"]] for params in dict_product(model=MODEL_LIST, case=CASE_LIST)]\nTEST_CASES_TS = [[params[\"model\"], *params[\"case\"]] for params in dict_product(model=MODEL_LIST, case=[TEST_CASE_1])]\n\n\n@unittest.skipUnless(has_torchvision, \"Requires torchvision\")\n@skip_if_quick\nclass TestRetinaNet(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_retina_shape(self, model, input_param, input_shape):\n        backbone = model(**input_param)\n        feature_extractor = resnet_fpn_feature_extractor(\n            backbone=backbone,\n            spatial_dims=input_param[\"spatial_dims\"],\n            pretrained_backbone=input_param[\"pretrained\"],\n            trainable_backbone_layers=None,\n            returned_layers=[1, 2],\n        )\n        net = RetinaNet(\n            
spatial_dims=input_param[\"spatial_dims\"],\n            num_classes=input_param[\"num_classes\"],\n            num_anchors=num_anchors,\n            feature_extractor=feature_extractor,\n            size_divisible=32,\n        ).to(device)\n\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n\n            base_stride = ensure_tuple(input_param[\"conv1_t_stride\"])[0] if \"conv1_t_stride\" in input_param else 1\n            expected_cls_channel = input_param[\"num_classes\"] * num_anchors\n            expected_cls_shape = tuple(\n                (input_shape[0], expected_cls_channel)\n                + tuple(input_shape[2 + a] // s // base_stride for a in range(input_param[\"spatial_dims\"]))\n                for s in [2, 4, 8]\n            )\n            expected_box_channel = 2 * input_param[\"spatial_dims\"] * num_anchors\n            expected_box_shape = tuple(\n                (input_shape[0], expected_box_channel)\n                + tuple(input_shape[2 + a] // s // base_stride for a in range(input_param[\"spatial_dims\"]))\n                for s in [2, 4, 8]\n            )\n\n            self.assertEqual(tuple(cc.shape for cc in result[net.cls_key]), expected_cls_shape)\n            self.assertEqual(tuple(cc.shape for cc in result[net.box_reg_key]), expected_box_shape)\n\n    @parameterized.expand(TEST_CASES_TS)\n    def test_script(self, model, input_param, input_shape):\n        try:\n            idx = int(self.id().split(\"test_script_\")[-1])\n        except BaseException:\n            idx = 0\n        idx %= 3\n        # test whether support torchscript\n        data = torch.randn(input_shape)\n        backbone = model(**input_param)\n        if idx == 0:\n            test_script_save(backbone, data)\n            return\n        feature_extractor = resnet_fpn_feature_extractor(\n            backbone=backbone,\n            spatial_dims=input_param[\"spatial_dims\"],\n            
pretrained_backbone=input_param[\"pretrained\"],\n            trainable_backbone_layers=None,\n            returned_layers=[1, 2],\n        )\n        if idx == 1:\n            test_script_save(feature_extractor, data)\n            return\n        net = RetinaNet(\n            spatial_dims=input_param[\"spatial_dims\"],\n            num_classes=input_param[\"num_classes\"],\n            num_anchors=num_anchors,\n            feature_extractor=feature_extractor,\n            size_divisible=32,\n        )\n        if idx == 2:\n            test_script_save(net, data)\n\n    @parameterized.expand(TEST_CASES_TS)\n    @unittest.skipUnless(has_onnxruntime, \"onnxruntime not installed\")\n    def test_onnx(self, model, input_param, input_shape):\n        try:\n            idx = int(self.id().split(\"test_onnx_\")[-1])\n        except BaseException:\n            idx = 0\n        idx %= 3\n        # test whether support torchscript\n        data = torch.randn(input_shape)\n        backbone = model(**input_param)\n        if idx == 0:\n            test_onnx_save(backbone, data, rtol=2e-2, atol=1e-5)\n            return\n        feature_extractor = resnet_fpn_feature_extractor(\n            backbone=backbone,\n            spatial_dims=input_param[\"spatial_dims\"],\n            pretrained_backbone=input_param[\"pretrained\"],\n            trainable_backbone_layers=None,\n            returned_layers=[1, 2],\n        )\n        if idx == 1:\n            test_onnx_save(feature_extractor, data, rtol=2e-2, atol=1e-5)\n            return\n        net = RetinaNet(\n            spatial_dims=input_param[\"spatial_dims\"],\n            num_classes=input_param[\"num_classes\"],\n            num_anchors=num_anchors,\n            feature_extractor=feature_extractor,\n            size_divisible=32,\n        )\n        if idx == 2:\n            test_onnx_save(net, data, rtol=2e-2, atol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/detection/networks/test_retinanet_detector.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport random\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.detection.networks.retinanet_detector import RetinaNetDetector, retinanet_resnet50_fpn_detector\nfrom monai.apps.detection.utils.anchor_utils import AnchorGeneratorWithAnchorShape\nfrom monai.networks import eval_mode, train_mode\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_quick, test_script_save\n\n_, has_torchvision = optional_import(\"torchvision\")\n\nnum_anchors = 7\n\nTEST_CASE_1 = [  # 3D, batch 3, 2 input channel\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 3,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": 7,\n        \"conv1_t_stride\": (2, 2, 2),\n    },\n    (3, 2, 32, 64, 48),\n]\n\nTEST_CASE_2 = [  # 2D, batch 2, 1 input channel\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 2,\n        \"n_input_channels\": 1,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [7, 7],\n        \"conv1_t_stride\": [2, 2],\n    },\n    (2, 1, 32, 64),\n]\n\nTEST_CASE_2_A = [  # 2D, batch 2, 1 input channel, shortcut type A\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 2,\n        \"n_input_channels\": 1,\n        \"num_classes\": 3,\n        \"shortcut_type\": \"A\",\n        \"conv1_t_size\": 
(7, 7),\n        \"conv1_t_stride\": 2,\n    },\n    (2, 1, 32, 64),\n]\n\nTEST_CASE_3 = [  # 1D, batch 1, 2 input channels\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 1,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [3],\n        \"conv1_t_stride\": 1,\n    },\n    (1, 2, 32),\n]\n\nTEST_CASE_3_A = [  # 1D, batch 1, 2 input channels\n    {\"pretrained\": False, \"spatial_dims\": 1, \"n_input_channels\": 2, \"num_classes\": 3, \"shortcut_type\": \"A\"},\n    (1, 2, 32),\n]\n\nTEST_CASE_4 = [  # 2D, batch 2, 1 input channel\n    {\"pretrained\": False, \"spatial_dims\": 2, \"n_input_channels\": 1, \"num_classes\": 3, \"feed_forward\": False},\n    (2, 1, 32, 64),\n]\n\nTEST_CASES = []\nTEST_CASES = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_2_A]\n\nTEST_CASES_TS = [TEST_CASE_1]\n\n\nclass NaiveNetwork(torch.nn.Module):\n    def __init__(self, spatial_dims, num_classes, **kwargs):\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.num_classes = num_classes\n        self.num_anchors = 1\n        self.cls_key = \"cls\"\n        self.box_reg_key = \"box_reg\"\n        self.size_divisible = 1\n\n    def forward(self, images):\n        out_cls_shape = (images.shape[0], self.num_classes * self.num_anchors) + images.shape[-self.spatial_dims :]\n        out_box_reg_shape = (images.shape[0], 2 * self.spatial_dims * self.num_anchors) + images.shape[\n            -self.spatial_dims :\n        ]\n        return {self.cls_key: [torch.randn(out_cls_shape)], self.box_reg_key: [torch.randn(out_box_reg_shape)]}\n\n\n@unittest.skipUnless(has_torchvision, \"Requires torchvision\")\n@skip_if_quick\nclass TestRetinaNetDetector(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_retina_detector_resnet_backbone_shape(self, input_param, input_shape):\n        returned_layers = [1]\n        anchor_generator = AnchorGeneratorWithAnchorShape(\n            feature_map_scales=(1, 2), 
base_anchor_shapes=((8,) * input_param[\"spatial_dims\"],)\n        )\n        detector = retinanet_resnet50_fpn_detector(\n            **input_param, anchor_generator=anchor_generator, returned_layers=returned_layers\n        )\n\n        with eval_mode(detector):\n            input_data = torch.randn(input_shape)\n            result = detector.forward(input_data)\n            assert len(result) == len(result)\n\n            input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]\n            result = detector.forward(input_data)\n            assert len(result) == len(result)\n\n        detector.set_atss_matcher()\n        detector.set_hard_negative_sampler(10, 0.5)\n        for num_gt_box in [0, 3]:  # test for both empty and non-empty boxes\n            gt_box_start = torch.randint(2, (num_gt_box, input_param[\"spatial_dims\"])).to(torch.float16)\n            gt_box_end = gt_box_start + torch.randint(1, 10, (num_gt_box, input_param[\"spatial_dims\"]))\n            one_target = {\n                \"boxes\": torch.cat((gt_box_start, gt_box_end), dim=1),\n                \"labels\": torch.randint(input_param[\"num_classes\"], (num_gt_box,)),\n            }\n            with train_mode(detector):\n                input_data = torch.randn(input_shape)\n                targets = [one_target] * len(input_data)\n                result = detector.forward(input_data, targets)\n\n                input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]\n                targets = [one_target] * len(input_data)\n                result = detector.forward(input_data, targets)\n\n    @parameterized.expand(TEST_CASES)\n    def test_naive_retina_detector_shape(self, input_param, input_shape):\n        anchor_generator = AnchorGeneratorWithAnchorShape(\n            feature_map_scales=(1,), base_anchor_shapes=((8,) * input_param[\"spatial_dims\"],)\n        )\n        detector = RetinaNetDetector(network=NaiveNetwork(**input_param), 
anchor_generator=anchor_generator)\n\n        with eval_mode(detector):\n            input_data = torch.randn(input_shape)\n            result = detector.forward(input_data)\n            assert len(result) == len(result)\n\n            input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]\n            result = detector.forward(input_data)\n            assert len(result) == len(result)\n\n        detector.set_atss_matcher()\n        detector.set_hard_negative_sampler(10, 0.5)\n        gt_box_start = torch.randint(2, (3, input_param[\"spatial_dims\"])).to(torch.float16)\n        gt_box_end = gt_box_start + torch.randint(1, 10, (3, input_param[\"spatial_dims\"]))\n        one_target = {\n            \"boxes\": torch.cat((gt_box_start, gt_box_end), dim=1),\n            \"labels\": torch.randint(input_param[\"num_classes\"], (3,)),\n        }\n        with train_mode(detector):\n            input_data = torch.randn(input_shape)\n            targets = [one_target] * len(input_data)\n            result = detector.forward(input_data, targets)\n\n            input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]\n            targets = [one_target] * len(input_data)\n            result = detector.forward(input_data, targets)\n\n    @parameterized.expand(TEST_CASES_TS)\n    def test_script(self, input_param, input_shape):\n        # test whether support torchscript\n        returned_layers = [1]\n        anchor_generator = AnchorGeneratorWithAnchorShape(\n            feature_map_scales=(1, 2), base_anchor_shapes=((8,) * input_param[\"spatial_dims\"],)\n        )\n        detector = retinanet_resnet50_fpn_detector(\n            **input_param, anchor_generator=anchor_generator, returned_layers=returned_layers\n        )\n        with eval_mode(detector):\n            input_data = torch.randn(input_shape)\n            test_script_save(detector.network, input_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/detection/test_box_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.detection.transforms.box_ops import convert_mask_to_box\nfrom monai.apps.detection.transforms.dictionary import (\n    AffineBoxToImageCoordinated,\n    AffineBoxToWorldCoordinated,\n    BoxToMaskd,\n    ClipBoxToImaged,\n    ConvertBoxModed,\n    FlipBoxd,\n    MaskToBoxd,\n    RandCropBoxByPosNegLabeld,\n    RandFlipBoxd,\n    RandRotateBox90d,\n    RandZoomBoxd,\n    RotateBox90d,\n    ZoomBoxd,\n)\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import CastToTyped, Invertd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS_3D = []\nboxes = [[0, 0, 0, 0, 0, 0], [0, 1, 0, 2, 3, 3], [0, 1, 1, 2, 3, 4]]\nlabels = [1, 1, 0]\nscores = [[0.2, 0.8], [0.3, 0.7], [0.6, 0.4]]\nimage_size = [1, 4, 6, 4]\nimage = np.zeros(image_size)\n\nfor p in TEST_NDARRAYS:\n    TESTS_3D.append(\n        [\n            {\"box_keys\": \"boxes\", \"dst_mode\": \"xyzwhd\"},\n            {\"boxes\": p(boxes), \"image\": p(image), \"labels\": p(labels), \"scores\": p(scores)},\n            p([[0, 0, 0, 0, 0, 0], [0, 1, 0, 2, 2, 3], [0, 1, 1, 2, 2, 3]]),\n            p([[0, 0, 0, 0, 0, 0], [0, 3, 0, 1, 9, 4.5], [0, 3, 1.5, 1, 9, 6]]),\n            p([[1, -6, -1, 1, -6, -1], [1, -3, -1, 
2, 3, 3.5], [1, -3, 0.5, 2, 3, 5]]),\n            p([[4, 6, 4, 4, 6, 4], [2, 3, 1, 4, 5, 4], [2, 3, 0, 4, 5, 3]]),\n            p([[0, 1, 0, 2, 3, 3], [0, 1, 1, 2, 3, 4]]),\n            p([[6, 0, 0, 6, 0, 0], [3, 0, 0, 5, 2, 3], [3, 0, 1, 5, 2, 4]]),\n        ]\n    )\n\nTESTS_2D = []\nboxes = [[0, 1, 2, 2], [0, 0, 1, 1]]\nlabels = [1, 0]\nimage_size = [1, 2, 2]\nimage = np.zeros(image_size)\nfor p in TEST_NDARRAYS:\n    TESTS_2D.append(\n        [{\"boxes\": p(boxes), \"image\": p(image), \"labels\": p(labels)}, p([[[0, 2], [0, 2]], [[1, 0], [0, 0]]])]\n    )\n\nTESTS_2D_mask = []\nboxes_mask = [[[-1, 0], [0, -1]]]\nfor p in TEST_NDARRAYS:\n    TESTS_2D_mask.append([p(boxes_mask), (p([[0.0, 0.0, 2.0, 2.0]]), p([0]))])\nboxes_mask = [[[-1, 0], [0, -1]], [[-1, 1], [1, -1]]]\nfor p in TEST_NDARRAYS:\n    TESTS_2D_mask.append([p(boxes_mask), (p([[0.0, 0.0, 2.0, 2.0], [0.0, 0.0, 2.0, 2.0]]), p([0, 1]))])\n\n\nclass TestBoxTransform(unittest.TestCase):\n    @parameterized.expand(TESTS_2D_mask)\n    def test_value_2d_mask(self, mask, expected_box_label):\n        box_label = convert_mask_to_box(mask)\n        assert_allclose(box_label[0], expected_box_label[0], type_test=True, device_test=True, atol=1e-3)\n        assert_allclose(box_label[1], expected_box_label[1], type_test=True, device_test=True, atol=1e-3)\n\n    @parameterized.expand(TESTS_2D)\n    def test_value_2d(self, data, expected_mask):\n        test_dtype = [torch.float32, torch.float16]\n        for dtype in test_dtype:\n            data = CastToTyped(keys=[\"image\", \"boxes\"], dtype=dtype)(data)\n            transform_to_mask = BoxToMaskd(\n                box_keys=\"boxes\",\n                box_mask_keys=\"box_mask\",\n                box_ref_image_keys=\"image\",\n                label_keys=\"labels\",\n                min_fg_label=0,\n                ellipse_mask=False,\n            )\n            transform_to_box = MaskToBoxd(\n                box_keys=\"boxes\", box_mask_keys=\"box_mask\", 
label_keys=\"labels\", min_fg_label=0\n            )\n            data_mask = transform_to_mask(data)\n            assert_allclose(data_mask[\"box_mask\"], expected_mask, type_test=True, device_test=True, atol=1e-3)\n            data_back = transform_to_box(data_mask)\n            assert_allclose(data_back[\"boxes\"], data[\"boxes\"], type_test=False, device_test=False, atol=1e-3)\n            assert_allclose(data_back[\"labels\"], data[\"labels\"], type_test=False, device_test=False, atol=1e-3)\n\n    def test_value_3d_mask(self):\n        test_dtype = [torch.float32, torch.float16]\n        image = np.zeros((1, 32, 33, 34))\n        boxes = np.array([[7, 8, 9, 10, 12, 13], [1, 3, 5, 2, 5, 9], [0, 0, 0, 1, 1, 1]])\n        data = {\"image\": image, \"boxes\": boxes, \"labels\": np.array((1, 0, 3))}\n        for dtype in test_dtype:\n            data = CastToTyped(keys=[\"image\", \"boxes\"], dtype=dtype)(data)\n            transform_to_mask = BoxToMaskd(\n                box_keys=\"boxes\",\n                box_mask_keys=\"box_mask\",\n                box_ref_image_keys=\"image\",\n                label_keys=\"labels\",\n                min_fg_label=0,\n                ellipse_mask=False,\n            )\n            transform_to_box = MaskToBoxd(\n                box_keys=\"boxes\", box_mask_keys=\"box_mask\", label_keys=\"labels\", min_fg_label=0\n            )\n            data_mask = transform_to_mask(data)\n            assert_allclose(data_mask[\"box_mask\"].shape, (3, 32, 33, 34), type_test=True, device_test=True, atol=1e-3)\n            data_back = transform_to_box(data_mask)\n            assert_allclose(data_back[\"boxes\"], data[\"boxes\"], type_test=False, device_test=False, atol=1e-3)\n            assert_allclose(data_back[\"labels\"], data[\"labels\"], type_test=False, device_test=False, atol=1e-3)\n\n    def test_shape_assertion(self):\n        test_dtype = torch.float32\n        image = np.zeros((1, 10, 10, 10))\n        boxes = np.array([[7, 8, 9, 
10, 12, 13]])\n        data = {\"image\": image, \"boxes\": boxes, \"labels\": np.array((1,))}\n        data = CastToTyped(keys=[\"image\", \"boxes\"], dtype=test_dtype)(data)\n        transform_to_mask = BoxToMaskd(\n            box_keys=\"boxes\",\n            box_mask_keys=\"box_mask\",\n            box_ref_image_keys=\"image\",\n            label_keys=\"labels\",\n            min_fg_label=0,\n            ellipse_mask=False,\n        )\n        with self.assertRaises(ValueError) as context:\n            transform_to_mask(data)\n        self.assertTrue(\"Some boxes are larger than the image.\" in str(context.exception))\n\n    @parameterized.expand(TESTS_3D)\n    def test_value_3d(\n        self,\n        keys,\n        data,\n        expected_convert_result,\n        expected_zoom_result,\n        expected_zoom_keepsize_result,\n        expected_flip_result,\n        expected_clip_result,\n        expected_rotate_result,\n    ):\n        test_dtype = [torch.float32]\n        for dtype in test_dtype:\n            data = CastToTyped(keys=[\"image\", \"boxes\"], dtype=dtype)(data)\n            # test ConvertBoxToStandardModed\n            transform_convert_mode = ConvertBoxModed(**keys)\n            convert_result = transform_convert_mode(data)\n            assert_allclose(\n                convert_result[\"boxes\"], expected_convert_result, type_test=True, device_test=True, atol=1e-3\n            )\n\n            invert_transform_convert_mode = Invertd(\n                keys=[\"boxes\"], transform=transform_convert_mode, orig_keys=[\"boxes\"]\n            )\n            data_back = invert_transform_convert_mode(convert_result)\n            if \"boxes_transforms\" in data_back:  # if the transform is tracked in dict:\n                self.assertEqual(data_back[\"boxes_transforms\"], [])  # it should be updated\n            assert_allclose(data_back[\"boxes\"], data[\"boxes\"], type_test=False, device_test=False, atol=1e-3)\n\n            # test ZoomBoxd\n           
 transform_zoom = ZoomBoxd(\n                image_keys=\"image\", box_keys=\"boxes\", box_ref_image_keys=\"image\", zoom=[0.5, 3, 1.5], keep_size=False\n            )\n            zoom_result = transform_zoom(data)\n            self.assertEqual(len(zoom_result[\"image\"].applied_operations), 1)\n            assert_allclose(zoom_result[\"boxes\"], expected_zoom_result, type_test=True, device_test=True, atol=1e-3)\n            invert_transform_zoom = Invertd(\n                keys=[\"image\", \"boxes\"], transform=transform_zoom, orig_keys=[\"image\", \"boxes\"]\n            )\n            data_back = invert_transform_zoom(zoom_result)\n            self.assertEqual(data_back[\"image\"].applied_operations, [])\n            assert_allclose(data_back[\"boxes\"], data[\"boxes\"], type_test=False, device_test=False, atol=1e-3)\n            assert_allclose(data_back[\"image\"], data[\"image\"], type_test=False, device_test=False, atol=1e-3)\n\n            transform_zoom = ZoomBoxd(\n                image_keys=\"image\", box_keys=\"boxes\", box_ref_image_keys=\"image\", zoom=[0.5, 3, 1.5], keep_size=True\n            )\n            zoom_result = transform_zoom(data)\n            self.assertEqual(len(zoom_result[\"image\"].applied_operations), 1)\n            assert_allclose(\n                zoom_result[\"boxes\"], expected_zoom_keepsize_result, type_test=True, device_test=True, atol=1e-3\n            )\n\n            # test RandZoomBoxd\n            transform_zoom = RandZoomBoxd(\n                image_keys=\"image\",\n                box_keys=\"boxes\",\n                box_ref_image_keys=\"image\",\n                prob=1.0,\n                min_zoom=(0.3,) * 3,\n                max_zoom=(3.0,) * 3,\n                keep_size=False,\n            )\n            zoom_result = transform_zoom(data)\n            self.assertEqual(len(zoom_result[\"image\"].applied_operations), 1)\n            invert_transform_zoom = Invertd(\n                keys=[\"image\", \"boxes\"], 
transform=transform_zoom, orig_keys=[\"image\", \"boxes\"]\n            )\n            data_back = invert_transform_zoom(zoom_result)\n            self.assertEqual(data_back[\"image\"].applied_operations, [])\n            assert_allclose(data_back[\"boxes\"], data[\"boxes\"], type_test=False, device_test=False, atol=0.01)\n            assert_allclose(data_back[\"image\"], data[\"image\"], type_test=False, device_test=False, atol=1e-3)\n\n            # test AffineBoxToImageCoordinated, AffineBoxToWorldCoordinated\n            transform_affine = AffineBoxToImageCoordinated(box_keys=\"boxes\", box_ref_image_keys=\"image\")\n            if not isinstance(data[\"image\"], MetaTensor):  # metadict should be undefined and it's an exception\n                with self.assertRaises(Exception) as context:\n                    transform_affine(deepcopy(data))\n                self.assertTrue(\"Please check whether it is the correct the image meta key.\" in str(context.exception))\n\n            data[\"image\"] = MetaTensor(data[\"image\"], meta={\"affine\": torch.diag(1.0 / torch.Tensor([0.5, 3, 1.5, 1]))})\n            affine_result = transform_affine(data)\n            if \"boxes_transforms\" in affine_result:\n                self.assertEqual(len(affine_result[\"boxes_transforms\"]), 1)\n            assert_allclose(affine_result[\"boxes\"], expected_zoom_result, type_test=True, device_test=True, atol=0.01)\n            invert_transform_affine = Invertd(keys=[\"boxes\"], transform=transform_affine, orig_keys=[\"boxes\"])\n            data_back = invert_transform_affine(affine_result)\n            if \"boxes_transforms\" in data_back:\n                self.assertEqual(data_back[\"boxes_transforms\"], [])\n            assert_allclose(data_back[\"boxes\"], data[\"boxes\"], type_test=False, device_test=False, atol=0.01)\n            invert_transform_affine = AffineBoxToWorldCoordinated(box_keys=\"boxes\", box_ref_image_keys=\"image\")\n            data_back = 
invert_transform_affine(affine_result)\n            assert_allclose(data_back[\"boxes\"], data[\"boxes\"], type_test=False, device_test=False, atol=0.01)\n\n            # test FlipBoxd\n            transform_flip = FlipBoxd(\n                image_keys=\"image\", box_keys=\"boxes\", box_ref_image_keys=\"image\", spatial_axis=[0, 1, 2]\n            )\n            flip_result = transform_flip(data)\n            if \"boxes_transforms\" in flip_result:\n                self.assertEqual(len(flip_result[\"boxes_transforms\"]), 1)\n            assert_allclose(flip_result[\"boxes\"], expected_flip_result, type_test=True, device_test=True, atol=1e-3)\n            invert_transform_flip = Invertd(\n                keys=[\"image\", \"boxes\"], transform=transform_flip, orig_keys=[\"image\", \"boxes\"]\n            )\n            data_back = invert_transform_flip(flip_result)\n            if \"boxes_transforms\" in data_back:\n                self.assertEqual(data_back[\"boxes_transforms\"], [])\n            assert_allclose(data_back[\"boxes\"], data[\"boxes\"], type_test=False, device_test=False, atol=1e-3)\n            assert_allclose(data_back[\"image\"], data[\"image\"], type_test=False, device_test=False, atol=1e-3)\n\n            # test RandFlipBoxd\n            for spatial_axis in [(0,), (1,), (2,), (0, 1), (1, 2)]:\n                transform_flip = RandFlipBoxd(\n                    image_keys=\"image\",\n                    box_keys=\"boxes\",\n                    box_ref_image_keys=\"image\",\n                    prob=1.0,\n                    spatial_axis=spatial_axis,\n                )\n                flip_result = transform_flip(data)\n                if \"boxes_transforms\" in flip_result:\n                    self.assertEqual(len(flip_result[\"boxes_transforms\"]), 1)\n                invert_transform_flip = Invertd(\n                    keys=[\"image\", \"boxes\"], transform=transform_flip, orig_keys=[\"image\", \"boxes\"]\n                )\n                
data_back = invert_transform_flip(flip_result)\n                if \"boxes_transforms\" in data_back:\n                    self.assertEqual(data_back[\"boxes_transforms\"], [])\n                assert_allclose(data_back[\"boxes\"], data[\"boxes\"], type_test=False, device_test=False, atol=1e-3)\n                assert_allclose(data_back[\"image\"], data[\"image\"], type_test=False, device_test=False, atol=1e-3)\n\n            # test ClipBoxToImaged\n            transform_clip = ClipBoxToImaged(\n                box_keys=\"boxes\", box_ref_image_keys=\"image\", label_keys=[\"labels\", \"scores\"], remove_empty=True\n            )\n            clip_result = transform_clip(data)\n            assert_allclose(clip_result[\"boxes\"], expected_clip_result, type_test=True, device_test=True, atol=1e-3)\n            assert_allclose(clip_result[\"labels\"], data[\"labels\"][1:], type_test=True, device_test=True, atol=1e-3)\n            assert_allclose(clip_result[\"scores\"], data[\"scores\"][1:], type_test=True, device_test=True, atol=1e-3)\n\n            transform_clip = ClipBoxToImaged(\n                box_keys=\"boxes\", box_ref_image_keys=\"image\", label_keys=[], remove_empty=True\n            )  # corner case when label_keys is empty\n            clip_result = transform_clip(data)\n            assert_allclose(clip_result[\"boxes\"], expected_clip_result, type_test=True, device_test=True, atol=1e-3)\n\n            # test RandCropBoxByPosNegLabeld\n            transform_crop = RandCropBoxByPosNegLabeld(\n                image_keys=\"image\", box_keys=\"boxes\", label_keys=[\"labels\", \"scores\"], spatial_size=2, num_samples=3\n            )\n            crop_result = transform_crop(data)\n            assert len(crop_result) == 3\n            for ll in range(3):\n                assert_allclose(\n                    crop_result[ll][\"boxes\"].shape[0],\n                    crop_result[ll][\"labels\"].shape[0],\n                    type_test=True,\n                    
device_test=True,\n                    atol=1e-3,\n                )\n                assert_allclose(\n                    crop_result[ll][\"boxes\"].shape[0],\n                    crop_result[ll][\"scores\"].shape[0],\n                    type_test=True,\n                    device_test=True,\n                    atol=1e-3,\n                )\n\n            # test RotateBox90d\n            transform_rotate = RotateBox90d(\n                image_keys=\"image\", box_keys=\"boxes\", box_ref_image_keys=\"image\", k=1, spatial_axes=[0, 1]\n            )\n            rotate_result = transform_rotate(data)\n            self.assertEqual(len(rotate_result[\"image\"].applied_operations), 1)\n            assert_allclose(rotate_result[\"boxes\"], expected_rotate_result, type_test=True, device_test=True, atol=1e-3)\n            invert_transform_rotate = Invertd(\n                keys=[\"image\", \"boxes\"], transform=transform_rotate, orig_keys=[\"image\", \"boxes\"]\n            )\n            data_back = invert_transform_rotate(rotate_result)\n            self.assertEqual(data_back[\"image\"].applied_operations, [])\n            assert_allclose(data_back[\"boxes\"], data[\"boxes\"], type_test=False, device_test=False, atol=1e-3)\n            assert_allclose(data_back[\"image\"], data[\"image\"], type_test=False, device_test=False, atol=1e-3)\n\n            transform_rotate = RandRotateBox90d(\n                image_keys=\"image\", box_keys=\"boxes\", box_ref_image_keys=\"image\", prob=1.0, max_k=3, spatial_axes=[0, 1]\n            )\n            rotate_result = transform_rotate(data)\n            self.assertEqual(len(rotate_result[\"image\"].applied_operations), 1)\n            invert_transform_rotate = Invertd(\n                keys=[\"image\", \"boxes\"], transform=transform_rotate, orig_keys=[\"image\", \"boxes\"]\n            )\n            data_back = invert_transform_rotate(rotate_result)\n            self.assertEqual(data_back[\"image\"].applied_operations, [])\n     
       assert_allclose(data_back[\"boxes\"], data[\"boxes\"], type_test=False, device_test=False, atol=1e-3)\n            assert_allclose(data_back[\"image\"], data[\"image\"], type_test=False, device_test=False, atol=1e-3)\n\n    def test_crop_shape(self):\n        tt = RandCropBoxByPosNegLabeld(\n            image_keys=[\"image\"],\n            box_keys=\"box\",\n            label_keys=\"label\",\n            spatial_size=[10, 7, -1],\n            whole_box=True,\n            num_samples=1,\n            pos=1,\n            neg=0,\n        )\n        iii = {\n            \"image\": torch.rand(1, 10, 8, 7),\n            \"box\": torch.tensor(((1.0, 2.0, 3.0, 4.0, 5.0, 6.0),)),\n            \"label\": torch.tensor((1,)).long(),\n        }\n        self.assertEqual(tt(iii)[0][\"image\"].shape, (1, 10, 7, 7))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/detection/utils/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/detection/utils/test_anchor_box.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.detection.utils.anchor_utils import AnchorGenerator, AnchorGeneratorWithAnchorShape\nfrom monai.utils import optional_import\nfrom tests.test_utils import assert_allclose, test_script_save\n\n_, has_torchvision = optional_import(\"torchvision\")\n\nTEST_CASES_2D = [\n    [\n        {\"sizes\": ((10, 12, 14, 16), (20, 24, 28, 32)), \"aspect_ratios\": ((1.0, 0.5, 2.0), (1.0, 0.5, 2.0))},\n        (5, 3, 128, 128),\n        ((5, 7, 64, 32), (5, 7, 32, 16)),\n    ]\n]\n\nTEST_CASES_SHAPE_3D = [\n    [\n        {\"feature_map_scales\": (1, 2), \"base_anchor_shapes\": ((4, 3, 6), (8, 2, 4))},\n        (5, 3, 128, 128, 128),\n        ((5, 7, 64, 32, 32), (5, 7, 32, 16, 16)),\n    ]\n]\n\n\n@unittest.skipUnless(has_torchvision, \"Requires torchvision\")\nclass TestAnchorGenerator(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_2D)\n    def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):\n        torch_anchor_utils, _ = optional_import(\"torchvision.models.detection.anchor_utils\")\n        image_list, _ = optional_import(\"torchvision.models.detection.image_list\")\n\n        # test it behaves the same with torchvision for 2d\n        anchor = AnchorGenerator(**input_param, indexing=\"xy\")\n        anchor_ref = 
torch_anchor_utils.AnchorGenerator(**input_param)\n        for a, a_f in zip(anchor.cell_anchors, anchor_ref.cell_anchors):\n            assert_allclose(a, a_f, type_test=True, device_test=False, atol=1e-3)\n        for a, a_f in zip(anchor.num_anchors_per_location(), anchor_ref.num_anchors_per_location()):\n            assert_allclose(a, a_f, type_test=True, device_test=False, atol=1e-3)\n\n        grid_sizes = [[2, 2], [1, 1]]\n        strides = [[torch.tensor(1), torch.tensor(2)], [torch.tensor(2), torch.tensor(4)]]\n        for a, a_f in zip(anchor.grid_anchors(grid_sizes, strides), anchor_ref.grid_anchors(grid_sizes, strides)):\n            assert_allclose(a, a_f, type_test=True, device_test=False, atol=1e-3)\n\n        images = torch.rand(image_shape)\n        feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes)\n        result = anchor(images, feature_maps)\n        result_ref = anchor_ref(image_list.ImageList(images, ([123, 122],)), feature_maps)\n        for a, a_f in zip(result, result_ref):\n            assert_allclose(a, a_f, type_test=True, device_test=False, atol=0.1)\n\n    @parameterized.expand(TEST_CASES_2D)\n    def test_script_2d(self, input_param, image_shape, feature_maps_shapes):\n        # test whether support torchscript\n        anchor = AnchorGenerator(**input_param, indexing=\"xy\")\n        images = torch.rand(image_shape)\n        feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes)\n        test_script_save(anchor, images, feature_maps)\n\n    @parameterized.expand(TEST_CASES_SHAPE_3D)\n    def test_script_3d(self, input_param, image_shape, feature_maps_shapes):\n        # test whether support torchscript\n        anchor = AnchorGeneratorWithAnchorShape(**input_param, indexing=\"ij\")\n        images = torch.rand(image_shape)\n        feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes)\n        test_script_save(anchor, images, feature_maps)\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/apps/detection/utils/test_atss_box_matcher.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.detection.utils.ATSS_matcher import ATSSMatcher\nfrom monai.data.box_utils import box_iou\nfrom tests.test_utils import assert_allclose\n\nTEST_CASES = [\n    [\n        {\"num_candidates\": 2, \"similarity_fn\": box_iou, \"center_in_gt\": False},\n        torch.tensor([[0, 1, 2, 3, 2, 5]], dtype=torch.float16),\n        torch.tensor([[0, 1, 2, 3, 2, 5], [0, 1, 1, 3, 2, 5], [0, 1, 2, 3, 2, 4]], dtype=torch.float16),\n        [3],\n        3,\n        torch.tensor([0, -1, -1]),\n    ]\n]\n\n\nclass TestATSS(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_atss(self, input_param, boxes, anchors, num_anchors_per_level, num_anchors_per_loc, expected_matches):\n        matcher = ATSSMatcher(**input_param, debug=True)\n        match_quality_matrix, matches = matcher.compute_matches(\n            boxes, anchors, num_anchors_per_level, num_anchors_per_loc\n        )\n        assert_allclose(matches, expected_matches, type_test=True, device_test=True, atol=0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/detection/utils/test_box_coder.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\n\nfrom monai.apps.detection.utils.box_coder import BoxCoder\nfrom monai.transforms import CastToType\nfrom tests.test_utils import assert_allclose\n\n\nclass TestBoxTransform(unittest.TestCase):\n    def test_value(self):\n        box_coder = BoxCoder(weights=[1, 1, 1, 1, 1, 1])\n        test_dtype = [torch.float32, torch.float16]\n        for dtype in test_dtype:\n            gt_boxes_0 = torch.rand((10, 3)).abs()\n            gt_boxes_1 = gt_boxes_0 + torch.rand((10, 3)).abs() + 10\n            gt_boxes = torch.cat((gt_boxes_0, gt_boxes_1), dim=1)\n            gt_boxes = CastToType(dtype=dtype)(gt_boxes)\n\n            proposals_0 = (gt_boxes_0 + torch.rand(gt_boxes_0.shape)).abs()\n            proposals_1 = proposals_0 + torch.rand(gt_boxes_0.shape).abs() + 10\n            proposals = torch.cat((proposals_0, proposals_1), dim=1)\n\n            rel_gt_boxes = box_coder.encode_single(gt_boxes, proposals)\n            gt_back = box_coder.decode_single(rel_gt_boxes, proposals)\n            assert_allclose(gt_back, gt_boxes, type_test=True, device_test=True, atol=0.1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/detection/utils/test_detector_boxselector.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.detection.utils.box_selector import BoxSelector\nfrom tests.test_utils import assert_allclose\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nnum_anchors = 7\n\nTEST_CASE = []\nTEST_CASE.append(\n    [  # 2D\n        {\n            \"apply_sigmoid\": False,\n            \"score_thresh\": 0.1,\n            \"topk_candidates_per_level\": 2,\n            \"nms_thresh\": 0.1,\n            \"detections_per_img\": 5,\n        },\n        [torch.tensor([[1, 2, 3, 2, 3, 4], [5, 6, 7, 8, 9, 10]]).to(torch.float16)],\n        [torch.tensor([[0.1, 0.6], [0.2, 0.2]])],\n        (20, 20, 20),\n        torch.tensor([[1, 2, 3, 2, 3, 4], [5, 6, 7, 8, 9, 10]]),\n    ]\n)\nTEST_CASE.append(\n    [\n        {\n            \"apply_sigmoid\": False,\n            \"score_thresh\": 0.1,\n            \"topk_candidates_per_level\": 1,\n            \"nms_thresh\": 0.1,\n            \"detections_per_img\": 5,\n        },\n        [torch.tensor([[1, 2, 3, 2, 3, 4]]).to(torch.float32), torch.tensor([[5, 6, 7, 8, 9, 10]]).to(torch.float32)],\n        [torch.tensor([[0.3, 0.6]]), torch.tensor([[0.2, 0.2]])],\n        (20, 20, 8),\n        torch.tensor([[1, 2, 3, 2, 3, 4], [5, 6, 7, 8, 9, 8]]),\n    ]\n)\n\n\nclass TestBoxSelector(unittest.TestCase):\n    
@parameterized.expand(TEST_CASE)\n    def test_box_selector(self, input_param, boxes, logits, image_shape, expected_results):\n        box_selector = BoxSelector(**input_param)\n        result = box_selector.select_boxes_per_image(boxes, logits, image_shape)\n        assert_allclose(result[0], expected_results, type_test=True, device_test=False, atol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/detection/utils/test_detector_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport random\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.detection.utils.detector_utils import preprocess_images\nfrom monai.utils import ensure_tuple\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_1 = [  # 3D, batch 3, 2 input channel\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 3,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": 7,\n        \"conv1_t_stride\": (2, 2, 2),\n    },\n    (3, 2, 32, 64, 48),\n    (3, 2, 64, 64, 64),\n]\n\nTEST_CASE_2 = [  # 2D, batch 2, 1 input channel\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 2,\n        \"n_input_channels\": 1,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [7, 7],\n        \"conv1_t_stride\": [2, 2],\n    },\n    (2, 1, 32, 64),\n    (2, 1, 64, 64),\n]\n\nTEST_CASE_2_A = [  # 2D, batch 2, 1 input channel, shortcut type A\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 2,\n        \"n_input_channels\": 1,\n        \"num_classes\": 3,\n        \"shortcut_type\": \"A\",\n        \"conv1_t_size\": (7, 7),\n        \"conv1_t_stride\": 2,\n    },\n    (2, 1, 32, 64),\n    (2, 1, 64, 64),\n]\n\nTEST_CASE_3 = [  # 1D, batch 1, 2 input channels\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 1,\n        
\"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [3],\n        \"conv1_t_stride\": 1,\n    },\n    (1, 2, 32),\n    (1, 2, 32),\n]\n\nTEST_CASES = []\nTEST_CASES = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]\n\n\nclass TestDetectorUtils(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_detector_utils(self, input_param, input_shape, expected_shape):\n        size_divisible = 32 * ensure_tuple(input_param[\"conv1_t_stride\"])[0]\n        input_data = torch.randn(input_shape)\n        result, _ = preprocess_images(input_data, input_param[\"spatial_dims\"], size_divisible, mode=\"constant\", value=1)\n        assert_allclose(expected_shape, result.shape, type_test=True, device_test=False, atol=0.1)\n\n        input_data = [torch.randn(input_shape[1:]) for _ in range(random.randint(1, 9))]\n        result, _ = preprocess_images(input_data, input_param[\"spatial_dims\"], size_divisible, mode=\"edge\")\n        expected_shape = (len(input_data),) + expected_shape[1:]\n        assert_allclose(expected_shape, result.shape, type_test=True, device_test=False, atol=0.1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/detection/utils/test_hardnegsampler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.detection.utils.hard_negative_sampler import HardNegativeSampler\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE = [\n    [[], [], [], [torch.tensor([]), torch.tensor([])], [torch.tensor([]), torch.tensor([])]],\n    [\n        [0, 1],\n        [1, 0, 2, 3],\n        [0.1, 0.9, 0.4, 0.3, 0.3, 0.5],\n        [torch.tensor([0, 1]), torch.tensor([1, 0, 1, 1])],\n        [torch.tensor([1, 0]), torch.tensor([0, 1, 0, 0])],\n    ],\n]\n\nselect_sample_size_per_image = 6\npositive_fraction = 0.5\nmin_neg = 1\npool_size = 2\n\n\nclass TestSampleSlices(unittest.TestCase):\n    @parameterized.expand(TEST_CASE)\n    def test_shape(self, target_label0, target_label1, concat_fg_probs, expected_result_pos, expected_result_neg):\n        compute_dtypes = [torch.float16, torch.float32]\n        for compute_dtype in compute_dtypes:\n            sampler = HardNegativeSampler(select_sample_size_per_image, positive_fraction, min_neg, pool_size)\n            target_labels = [torch.tensor(target_label0), torch.tensor(target_label1)]\n            result_pos, result_neg = sampler(target_labels, torch.tensor(concat_fg_probs, dtype=compute_dtype))\n            for r, er in zip(result_pos, expected_result_pos):\n                assert_allclose(r, er)\n         
   for r, er in zip(result_neg, expected_result_neg):\n                assert_allclose(r, er)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/maisi/networks/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/maisi/networks/test_autoencoderkl_maisi.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi\nfrom monai.networks import eval_mode\nfrom monai.utils import optional_import\n\ntqdm, has_tqdm = optional_import(\"tqdm\", name=\"tqdm\")\n_, has_einops = optional_import(\"einops\")\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\nCASES_NO_ATTENTION = [\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"num_res_blocks\": (1, 1, 1),\n            \"norm_num_groups\": 4,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n            \"num_splits\": 2,\n            \"print_info\": False,\n        },\n        (1, 1, 32, 32, 32),\n        (1, 1, 32, 32, 32),\n        (1, 4, 8, 8, 8),\n    ]\n]\n\nCASES_ATTENTION = [\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, True),\n         
   \"num_res_blocks\": (1, 1, 1),\n            \"norm_num_groups\": 4,\n            \"with_encoder_nonlocal_attn\": True,\n            \"with_decoder_nonlocal_attn\": True,\n            \"num_splits\": 2,\n            \"print_info\": False,\n        },\n        (1, 1, 32, 32, 32),\n        (1, 1, 32, 32, 32),\n        (1, 4, 8, 8, 8),\n    ]\n]\n\nif has_einops:\n    CASES = CASES_NO_ATTENTION + CASES_ATTENTION\nelse:\n    CASES = CASES_NO_ATTENTION\n\n\nclass TestAutoencoderKlMaisi(unittest.TestCase):\n    @parameterized.expand(CASES)\n    def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):\n        net = AutoencoderKlMaisi(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result[0].shape, expected_shape)\n            self.assertEqual(result[1].shape, expected_latent_shape)\n            self.assertEqual(result[2].shape, expected_latent_shape)\n\n    @parameterized.expand(CASES)\n    def test_shape_with_convtranspose_and_checkpointing(\n        self, input_param, input_shape, expected_shape, expected_latent_shape\n    ):\n        input_param = input_param.copy()\n        input_param.update({\"use_checkpointing\": True, \"use_convtranspose\": True})\n        net = AutoencoderKlMaisi(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result[0].shape, expected_shape)\n            self.assertEqual(result[1].shape, expected_latent_shape)\n            self.assertEqual(result[2].shape, expected_latent_shape)\n\n    def test_model_channels_not_multiple_of_norm_num_group(self):\n        with self.assertRaises(ValueError):\n            AutoencoderKlMaisi(\n                spatial_dims=3,\n                in_channels=1,\n                out_channels=1,\n                num_channels=(24, 24, 24),\n                
attention_levels=(False, False, False),\n                latent_channels=8,\n                num_res_blocks=(1, 1, 1),\n                norm_num_groups=16,\n                num_splits=2,\n                print_info=False,\n            )\n\n    def test_model_num_channels_not_same_size_of_attention_levels(self):\n        with self.assertRaises(ValueError):\n            AutoencoderKlMaisi(\n                spatial_dims=3,\n                in_channels=1,\n                out_channels=1,\n                num_channels=(24, 24, 24),\n                attention_levels=(False, False),\n                latent_channels=8,\n                num_res_blocks=(1, 1, 1),\n                norm_num_groups=16,\n                num_splits=2,\n                print_info=False,\n            )\n\n    def test_model_num_channels_not_same_size_of_num_res_blocks(self):\n        with self.assertRaises(ValueError):\n            AutoencoderKlMaisi(\n                spatial_dims=3,\n                in_channels=1,\n                out_channels=1,\n                num_channels=(24, 24),\n                attention_levels=(False, False, False),\n                latent_channels=8,\n                num_res_blocks=(8, 8, 8),\n                norm_num_groups=16,\n                num_splits=2,\n                print_info=False,\n            )\n\n    def test_shape_reconstruction(self):\n        input_param, input_shape, expected_shape, _ = CASES[0]\n        net = AutoencoderKlMaisi(**input_param).to(device)\n        with eval_mode(net):\n            result = net.reconstruct(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_shape_reconstruction_with_convtranspose_and_checkpointing(self):\n        input_param, input_shape, expected_shape, _ = CASES[0]\n        input_param = input_param.copy()\n        input_param.update({\"use_checkpointing\": True, \"use_convtranspose\": True})\n        net = AutoencoderKlMaisi(**input_param).to(device)\n       
 with eval_mode(net):\n            result = net.reconstruct(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_shape_encode(self):\n        input_param, input_shape, _, expected_latent_shape = CASES[0]\n        net = AutoencoderKlMaisi(**input_param).to(device)\n        with eval_mode(net):\n            result = net.encode(torch.randn(input_shape).to(device))\n            self.assertEqual(result[0].shape, expected_latent_shape)\n            self.assertEqual(result[1].shape, expected_latent_shape)\n\n    def test_shape_encode_with_convtranspose_and_checkpointing(self):\n        input_param, input_shape, _, expected_latent_shape = CASES[0]\n        input_param = input_param.copy()\n        input_param.update({\"use_checkpointing\": True, \"use_convtranspose\": True})\n        net = AutoencoderKlMaisi(**input_param).to(device)\n        with eval_mode(net):\n            result = net.encode(torch.randn(input_shape).to(device))\n            self.assertEqual(result[0].shape, expected_latent_shape)\n            self.assertEqual(result[1].shape, expected_latent_shape)\n\n    def test_shape_sampling(self):\n        input_param, _, _, expected_latent_shape = CASES[0]\n        net = AutoencoderKlMaisi(**input_param).to(device)\n        with eval_mode(net):\n            result = net.sampling(\n                torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)\n            )\n            self.assertEqual(result.shape, expected_latent_shape)\n\n    def test_shape_sampling_convtranspose_and_checkpointing(self):\n        input_param, _, _, expected_latent_shape = CASES[0]\n        input_param = input_param.copy()\n        input_param.update({\"use_checkpointing\": True, \"use_convtranspose\": True})\n        net = AutoencoderKlMaisi(**input_param).to(device)\n        with eval_mode(net):\n            result = net.sampling(\n                
torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)\n            )\n            self.assertEqual(result.shape, expected_latent_shape)\n\n    def test_shape_decode(self):\n        input_param, expected_input_shape, _, latent_shape = CASES[0]\n        net = AutoencoderKlMaisi(**input_param).to(device)\n        with eval_mode(net):\n            result = net.decode(torch.randn(latent_shape).to(device))\n            self.assertEqual(result.shape, expected_input_shape)\n\n    def test_shape_decode_convtranspose_and_checkpointing(self):\n        input_param, expected_input_shape, _, latent_shape = CASES[0]\n        input_param = input_param.copy()\n        input_param.update({\"use_checkpointing\": True, \"use_convtranspose\": True})\n        net = AutoencoderKlMaisi(**input_param).to(device)\n        with eval_mode(net):\n            result = net.decode(torch.randn(latent_shape).to(device))\n            self.assertEqual(result.shape, expected_input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/maisi/networks/test_controlnet_maisi.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi\nfrom monai.networks import eval_mode\nfrom monai.utils import optional_import\n\n_, has_einops = optional_import(\"einops\")\n\nTEST_CASES = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"conditioning_embedding_in_channels\": 1,\n            \"conditioning_embedding_num_channels\": (8, 8),\n            \"use_checkpointing\": False,\n        },\n        6,\n        (1, 8, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"conditioning_embedding_in_channels\": 1,\n            \"conditioning_embedding_num_channels\": (8, 8),\n            \"use_checkpointing\": True,\n        },\n        6,\n        (1, 8, 4, 4, 4),\n   
 ],\n]\n\nTEST_CASES_CONDITIONAL = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"conditioning_embedding_in_channels\": 1,\n            \"conditioning_embedding_num_channels\": (8, 8),\n            \"use_checkpointing\": False,\n            \"with_conditioning\": True,\n            \"cross_attention_dim\": 2,\n        },\n        6,\n        (1, 8, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"conditioning_embedding_in_channels\": 1,\n            \"conditioning_embedding_num_channels\": (8, 8),\n            \"use_checkpointing\": True,\n            \"with_conditioning\": True,\n            \"cross_attention_dim\": 2,\n        },\n        6,\n        (1, 8, 4, 4, 4),\n    ],\n]\n\nTEST_CASES_ERROR = [\n    [\n        {\"spatial_dims\": 2, \"in_channels\": 1, \"with_conditioning\": True, \"cross_attention_dim\": None},\n        \"ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) \"\n        \"to be specified when with_conditioning=True.\",\n    ],\n    [\n        {\"spatial_dims\": 2, \"in_channels\": 1, \"with_conditioning\": False, \"cross_attention_dim\": 2},\n        \"ControlNet expects with_conditioning=True when specifying the cross_attention_dim.\",\n    ],\n    [\n        {\"spatial_dims\": 2, \"in_channels\": 1, \"num_channels\": (8, 16), \"norm_num_groups\": 16},\n        f\"ControlNet expects all channels to be a multiple of norm_num_groups, but got\"\n        f\" channels={(8, 16)} 
and norm_num_groups={16}\",\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"num_channels\": (8, 16),\n            \"attention_levels\": (True,),\n            \"norm_num_groups\": 8,\n        },\n        f\"ControlNet expects channels to have the same length as attention_levels, but got \"\n        f\"channels={(8, 16)} and attention_levels={(True,)}\",\n    ],\n]\n\n\nclass TestControlNet(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):\n        net = ControlNetMaisi(**input_param)\n        with eval_mode(net):\n            x = torch.rand((1, 1, 16, 16)) if input_param[\"spatial_dims\"] == 2 else torch.rand((1, 1, 16, 16, 16))\n            timesteps = torch.randint(0, 1000, (1,)).long()\n            controlnet_cond = (\n                torch.rand((1, 1, 32, 32)) if input_param[\"spatial_dims\"] == 2 else torch.rand((1, 1, 32, 32, 32))\n            )\n            result = net.forward(x, timesteps, controlnet_cond)\n            self.assertEqual(len(result[0]), expected_num_down_blocks_residuals)\n            self.assertEqual(result[1].shape, expected_shape)\n\n    @parameterized.expand(TEST_CASES_CONDITIONAL)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_conditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):\n        net = ControlNetMaisi(**input_param)\n        with eval_mode(net):\n            x = torch.rand((1, 1, 16, 16)) if input_param[\"spatial_dims\"] == 2 else torch.rand((1, 1, 16, 16, 16))\n            timesteps = torch.randint(0, 1000, (1,)).long()\n            controlnet_cond = (\n                torch.rand((1, 1, 32, 32)) if input_param[\"spatial_dims\"] == 2 else torch.rand((1, 1, 32, 32, 32))\n            )\n            context = torch.randn((1, 1, 
input_param[\"cross_attention_dim\"]))\n            result = net.forward(x, timesteps, controlnet_cond, context=context)\n            self.assertEqual(len(result[0]), expected_num_down_blocks_residuals)\n            self.assertEqual(result[1].shape, expected_shape)\n\n    @parameterized.expand(TEST_CASES_ERROR)\n    def test_error_input(self, input_param, expected_error):\n        with self.assertRaises(ValueError) as context:  # invalid constructor arguments must be rejected\n            _ = ControlNetMaisi(**input_param)\n        runtime_error = context.exception\n        self.assertEqual(str(runtime_error), expected_error)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/maisi/networks/test_diffusion_model_unet_maisi.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi\nfrom monai.networks import eval_mode\nfrom monai.utils import optional_import\n\n_, has_einops = optional_import(\"einops\")\n\nUNCOND_CASES_2D = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": (1, 1, 2),\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n        }\n    ],\n    [\n        
{\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, True, True),\n            \"num_head_channels\": (0, 2, 4),\n            \"norm_num_groups\": 8,\n        }\n    ],\n]\n\nUNCOND_CASES_3D = [\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": 
(False, False, False),\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": (0, 0, 4),\n            \"norm_num_groups\": 8,\n        }\n    ],\n]\n\nCOND_CASES_2D = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 
1,\n            \"cross_attention_dim\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            \"resblock_updown\": True,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            \"upcast_attention\": True,\n        }\n    ],\n]\n\nDROPOUT_OK = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            \"dropout_cattn\": 0.25,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": 
True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n        }\n    ],\n]\n\nDROPOUT_WRONG = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"num_channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            \"dropout_cattn\": 3.0,\n        }\n    ]\n]\n\n\nclass TestDiffusionModelUNetMaisi2D(unittest.TestCase):\n\n    @parameterized.expand(UNCOND_CASES_2D)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_unconditioned_models(self, input_param):\n        net = DiffusionModelUNetMaisi(**input_param)\n        with eval_mode(net):\n            result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long())\n            self.assertEqual(result.shape, (1, 1, 16, 16))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_timestep_with_wrong_shape(self):\n        net = DiffusionModelUNetMaisi(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            num_channels=(8, 8, 8),\n            attention_levels=(False, False, False),\n            norm_num_groups=8,\n        )\n        with self.assertRaises(ValueError):\n            with eval_mode(net):\n                net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long())\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_with_different_in_channel_out_channel(self):\n        in_channels = 6\n        out_channels = 3\n        net = DiffusionModelUNetMaisi(\n            spatial_dims=2,\n            in_channels=in_channels,\n            out_channels=out_channels,\n          
  num_res_blocks=1,\n            num_channels=(8, 8, 8),\n            attention_levels=(False, False, False),\n            norm_num_groups=8,\n        )\n        with eval_mode(net):\n            result = net.forward(torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long())\n            self.assertEqual(result.shape, (1, out_channels, 16, 16))\n\n    def test_model_channels_not_multiple_of_norm_num_group(self):\n        with self.assertRaises(ValueError):\n            DiffusionModelUNetMaisi(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                num_channels=(8, 8, 12),\n                attention_levels=(False, False, False),\n                norm_num_groups=8,\n            )\n\n    def test_attention_levels_with_different_length_num_head_channels(self):\n        with self.assertRaises(ValueError):\n            DiffusionModelUNetMaisi(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                num_channels=(8, 8, 8),\n                attention_levels=(False, False, False),\n                num_head_channels=(0, 2),\n                norm_num_groups=8,\n            )\n\n    def test_num_res_blocks_with_different_length_channels(self):\n        with self.assertRaises(ValueError):\n            DiffusionModelUNetMaisi(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=(1, 1),\n                num_channels=(8, 8, 8),\n                attention_levels=(False, False, False),\n                norm_num_groups=8,\n            )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_conditioned_models(self):\n        net = DiffusionModelUNetMaisi(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            
num_channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            with_conditioning=True,\n            transformer_num_layers=1,\n            cross_attention_dim=3,\n            norm_num_groups=8,\n            num_head_channels=8,\n        )\n        with eval_mode(net):\n            result = net.forward(\n                x=torch.rand((1, 1, 16, 32)),\n                timesteps=torch.randint(0, 1000, (1,)).long(),\n                context=torch.rand((1, 1, 3)),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 32))\n\n    def test_with_conditioning_cross_attention_dim_none(self):\n        with self.assertRaises(ValueError):\n            DiffusionModelUNetMaisi(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                num_channels=(8, 8, 8),\n                attention_levels=(False, False, True),\n                with_conditioning=True,\n                transformer_num_layers=1,\n                cross_attention_dim=None,\n                norm_num_groups=8,\n            )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_context_with_conditioning_none(self):\n        net = DiffusionModelUNetMaisi(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            num_channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            with_conditioning=False,\n            transformer_num_layers=1,\n            norm_num_groups=8,\n        )\n\n        with self.assertRaises(ValueError):\n            with eval_mode(net):\n                net.forward(\n                    x=torch.rand((1, 1, 16, 32)),\n                    timesteps=torch.randint(0, 1000, (1,)).long(),\n                    context=torch.rand((1, 1, 3)),\n                )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def 
test_shape_conditioned_models_class_conditioning(self):\n        net = DiffusionModelUNetMaisi(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            num_channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            norm_num_groups=8,\n            num_head_channels=8,\n            num_class_embeds=2,\n        )\n        with eval_mode(net):\n            result = net.forward(\n                x=torch.rand((1, 1, 16, 32)),\n                timesteps=torch.randint(0, 1000, (1,)).long(),\n                class_labels=torch.randint(0, 2, (1,)).long(),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 32))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_conditioned_models_no_class_labels(self):\n        net = DiffusionModelUNetMaisi(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            num_channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            norm_num_groups=8,\n            num_head_channels=8,\n            num_class_embeds=2,\n        )\n\n        with self.assertRaises(ValueError):\n            net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long())\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_model_channels_not_same_size_of_attention_levels(self):\n        with self.assertRaises(ValueError):\n            DiffusionModelUNetMaisi(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                num_channels=(8, 8, 8),\n                attention_levels=(False, False),\n                norm_num_groups=8,\n                num_head_channels=8,\n                num_class_embeds=2,\n            )\n\n    @parameterized.expand(COND_CASES_2D)\n    @skipUnless(has_einops, \"Requires einops\")\n    def 
test_conditioned_2d_models_shape(self, input_param):\n        net = DiffusionModelUNetMaisi(**input_param)\n        with eval_mode(net):\n            result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3)))\n            self.assertEqual(result.shape, (1, 1, 16, 16))\n\n    @parameterized.expand(UNCOND_CASES_2D)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_with_additional_inputs(self, input_param):\n        input_param[\"include_top_region_index_input\"] = True\n        input_param[\"include_bottom_region_index_input\"] = True\n        input_param[\"include_spacing_input\"] = True\n        net = DiffusionModelUNetMaisi(**input_param)\n        with eval_mode(net):\n            result = net.forward(\n                x=torch.rand((1, 1, 16, 16)),\n                timesteps=torch.randint(0, 1000, (1,)).long(),\n                top_region_index_tensor=torch.rand((1, 4)),\n                bottom_region_index_tensor=torch.rand((1, 4)),\n                spacing_tensor=torch.rand((1, 3)),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 16))\n\n\nclass TestDiffusionModelUNetMaisi3D(unittest.TestCase):\n\n    @parameterized.expand(UNCOND_CASES_3D)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_unconditioned_models(self, input_param):\n        net = DiffusionModelUNetMaisi(**input_param)\n        with eval_mode(net):\n            result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long())\n            self.assertEqual(result.shape, (1, 1, 16, 16, 16))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_with_different_in_channel_out_channel(self):\n        in_channels = 6\n        out_channels = 3\n        net = DiffusionModelUNetMaisi(\n            spatial_dims=3,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            num_res_blocks=1,\n            num_channels=(8, 8, 8),\n 
           attention_levels=(False, False, True),\n            norm_num_groups=4,\n        )\n        with eval_mode(net):\n            result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long())\n            self.assertEqual(result.shape, (1, out_channels, 16, 16, 16))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_conditioned_models(self):\n        net = DiffusionModelUNetMaisi(\n            spatial_dims=3,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            num_channels=(16, 16, 16),\n            attention_levels=(False, False, True),\n            norm_num_groups=16,\n            with_conditioning=True,\n            transformer_num_layers=1,\n            cross_attention_dim=3,\n        )\n        with eval_mode(net):\n            result = net.forward(\n                x=torch.rand((1, 1, 16, 16, 16)),\n                timesteps=torch.randint(0, 1000, (1,)).long(),\n                context=torch.rand((1, 1, 3)),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 16, 16))\n\n    # Test dropout specification for cross-attention blocks\n    @parameterized.expand(DROPOUT_WRONG)\n    def test_wrong_dropout(self, input_param):\n        with self.assertRaises(ValueError):\n            _ = DiffusionModelUNetMaisi(**input_param)\n\n    @parameterized.expand(DROPOUT_OK)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_right_dropout(self, input_param):\n        _ = DiffusionModelUNetMaisi(**input_param)\n\n    @parameterized.expand(UNCOND_CASES_3D)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_with_additional_inputs(self, input_param):\n        input_param[\"include_top_region_index_input\"] = True\n        input_param[\"include_bottom_region_index_input\"] = True\n        input_param[\"include_spacing_input\"] = True\n        net = DiffusionModelUNetMaisi(**input_param)\n        with eval_mode(net):\n    
        result = net.forward(\n                x=torch.rand((1, 1, 16, 16, 16)),\n                timesteps=torch.randint(0, 1000, (1,)).long(),\n                top_region_index_tensor=torch.rand((1, 4)),\n                bottom_region_index_tensor=torch.rand((1, 4)),\n                spacing_tensor=torch.rand((1, 3)),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 16, 16))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/nuclick/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/nuclick/test_nuclick_transforms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.nuclick.transforms import (\n    AddClickSignalsd,\n    AddLabelAsGuidanced,\n    AddPointGuidanceSignald,\n    ExtractPatchd,\n    FilterImaged,\n    FlattenLabeld,\n    PostFilterLabeld,\n    SetLabelClassd,\n    SplitLabeld,\n)\n\n# Data Definitions\nRGB_IMAGE_1 = np.array(\n    [[[0, 0, 0], [0, 1, 0], [0, 0, 1]], [[2, 0, 2], [0, 1, 0], [1, 0, 1]], [[3, 0, 2], [0, 1, 0], [1, 3, 1]]]\n)\n\nLABEL_1 = np.array(\n    [\n        [1, 1, 1, 0, 0, 0, 0],\n        [1, 1, 1, 0, 0, 0, 0],\n        [1, 1, 1, 0, 0, 0, 0],\n        [0, 0, 0, 0, 0, 0, 0],\n        [1, 1, 1, 0, 1, 1, 1],\n        [1, 1, 1, 0, 1, 1, 1],\n        [1, 1, 1, 0, 1, 1, 1],\n    ],\n    dtype=np.uint8,\n)\n\nLABEL_1_1 = np.array(\n    [\n        [1, 1, 1, 0, 0, 1, 1],\n        [1, 1, 1, 0, 0, 1, 1],\n        [1, 1, 1, 0, 0, 0, 0],\n        [0, 0, 0, 0, 0, 0, 0],\n        [1, 1, 1, 0, 2, 2, 2],\n        [1, 1, 1, 0, 2, 2, 2],\n        [1, 1, 1, 0, 2, 2, 2],\n    ],\n    dtype=np.uint8,\n)\n\nLABEL_2 = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], dtype=np.uint8)\n\nLABEL_3 = np.array([[[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]], dtype=np.uint8)\n\nLABEL_4 = np.array([[[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]]], 
dtype=np.uint8)\n\nIL_IMAGE_1 = np.array(\n    [\n        [[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 1, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 1]],\n        [[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 1, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 1]],\n        [[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 1, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 1]],\n    ]\n)\n\nIL_FG_IMAGE_1 = np.array([[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 1, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 1]])\n\nIL_LABEL_1 = np.array(\n    [[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], dtype=np.uint8\n)\n\nIL_OTHERS_1 = np.array(\n    [[[1, 1, 1, 1, 1], [2, 0, 0, 0, 2], [3, 0, 0, 0, 3], [4, 0, 0, 0, 4], [5, 5, 5, 5, 5]]], dtype=np.uint8\n)\n\nIL_IMAGE_2 = np.array(\n    [[[0, 0, 0], [0, 1, 0], [0, 0, 1]], [[0, 0, 0], [0, 1, 0], [0, 0, 1]], [[0, 0, 0], [0, 1, 0], [0, 0, 1]]]\n)\n\nIL_LABEL_2 = np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]], dtype=np.uint8)\n\nPRED_1 = np.array(\n    [[[1, 1, 1, 1, 1], [2, 0, 0, 0, 2], [3, 0, 0, 0, 3], [4, 0, 0, 0, 4], [5, 5, 5, 5, 5]]], dtype=np.float32\n)\n\nNUC_POINTS_1 = np.array(\n    [\n        [[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]]],\n        [[[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]],\n    ],\n    dtype=np.float32,\n)\nBB_1 = np.array([[0, 0, 5, 5], [1, 1, 6, 6]], dtype=np.uint8)\n\nDATA_FILTER_1 = {\"image\": RGB_IMAGE_1}\n\nDATA_FLATTEN_1 = {\"label\": LABEL_1}\nDATA_FLATTEN_2 = {\"label\": LABEL_2}\n\nDATA_EXTRACT_1 = {\"image\": IL_IMAGE_1, \"label\": IL_LABEL_1, \"centroid\": (2, 2)}\nDATA_EXTRACT_2 = {\"image\": IL_IMAGE_2, \"label\": IL_LABEL_2, \"centroid\": (1, 1)}\n\nDATA_SPLIT_1 = {\"label\": LABEL_3, \"mask_value\": 1}\nDATA_SPLIT_2 = {\"label\": LABEL_4, \"mask_value\": 4}\n\nDATA_GUIDANCE_1 = {\"image\": IL_IMAGE_1, \"label\": IL_LABEL_1, \"others\": IL_OTHERS_1, \"centroid\": (2, 2)}\n\nDATA_CLICK_1 = {\"image\": IL_IMAGE_1, \"foreground\": [[2, 2], [1, 1]]}\n\nDATA_LABEL_FILTER_1 = {\n    
\"pred\": PRED_1,\n    \"nuc_points\": NUC_POINTS_1,\n    \"bounding_boxes\": BB_1,\n    \"img_height\": 6,\n    \"img_width\": 6,\n}\n\n# Result Definitions\nEXTRACT_RESULT_TC1 = np.array([[[0, 0, 0], [0, 0, 0], [0, 0, 1]]], dtype=np.uint8)\nEXTRACT_RESULT_TC2 = np.array([[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], dtype=np.uint8)\n\nSPLIT_RESULT_TC1 = np.array([[[1, 1, 1, 1], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], dtype=np.uint8)\nSPLIT_RESULT_TC2 = np.array([[[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]], dtype=np.uint8)\n\n# Test Case Definitions\nFILTER_IMAGE_TEST_CASE_1 = [{\"keys\": \"image\", \"min_size\": 1}, DATA_FILTER_1, [3, 3, 3]]\n\nFLATTEN_LABEL_TEST_CASE_1 = [{\"keys\": \"label\"}, DATA_FLATTEN_1, [0, 1, 2, 3]]\nFLATTEN_LABEL_TEST_CASE_2 = [{\"keys\": \"label\"}, DATA_FLATTEN_2, [0]]\nFLATTEN_LABEL_TEST_CASE_3 = [{\"keys\": \"label\"}, {\"label\": LABEL_1_1}, [0, 1, 2, 3, 4]]\n\nEXTRACT_TEST_CASE_1 = [{\"keys\": [\"image\", \"label\"], \"patch_size\": 3}, DATA_EXTRACT_1, [1, 3, 3]]\nEXTRACT_TEST_CASE_2 = [{\"keys\": [\"image\", \"label\"], \"patch_size\": 5}, DATA_EXTRACT_1, [1, 5, 5]]\nEXTRACT_TEST_CASE_3 = [{\"keys\": [\"image\", \"label\"], \"patch_size\": 1}, DATA_EXTRACT_2, [1, 1, 1]]\n\nEXTRACT_RESULT_TEST_CASE_1 = [{\"keys\": [\"image\", \"label\"], \"patch_size\": 3}, DATA_EXTRACT_1, EXTRACT_RESULT_TC1]\nEXTRACT_RESULT_TEST_CASE_2 = [{\"keys\": [\"image\", \"label\"], \"patch_size\": 4}, DATA_EXTRACT_2, EXTRACT_RESULT_TC2]\n\nEXTRACT_KW_TEST_CASE_1 = [\n    {\"keys\": [\"image\", \"label\"], \"patch_size\": 3, \"mode\": \"constant\"},\n    DATA_EXTRACT_1,\n    EXTRACT_RESULT_TC1,\n]\n\nSPLIT_TEST_CASE_1 = [{\"keys\": [\"label\"], \"mask_value\": \"mask_value\", \"min_area\": 1}, DATA_SPLIT_1, SPLIT_RESULT_TC1]\nSPLIT_TEST_CASE_2 = [{\"keys\": [\"label\"], \"mask_value\": \"mask_value\", \"min_area\": 3}, DATA_SPLIT_2, SPLIT_RESULT_TC2]\n\nGUIDANCE_TEST_CASE_1 = [{\"image\": \"image\", \"label\": \"label\", 
\"others\": \"others\"}, DATA_GUIDANCE_1, [5, 5, 5]]\nGUIDANCE_TEST_CASE_2 = [\n    {\"image\": \"image\", \"label\": \"label\", \"others\": \"others\", \"gaussian\": True, \"use_distance\": True},\n    DATA_GUIDANCE_1,\n    [5, 5, 5],\n]\n\nCLICK_TEST_CASE_1 = [{\"image\": \"image\", \"foreground\": \"foreground\", \"bb_size\": 4}, DATA_CLICK_1, [2, 5, 4, 4]]\nCLICK_TEST_CASE_2 = [\n    {\"image\": \"image\", \"foreground\": \"foreground\", \"bb_size\": 4, \"gaussian\": True},\n    DATA_CLICK_1,\n    [2, 5, 4, 4],\n]\n\nLABEL_FILTER_TEST_CASE_1 = [{\"keys\": [\"pred\"]}, DATA_LABEL_FILTER_1, [6, 6]]\n\nLABEL_GUIDANCE_TEST_CASE_1 = [{\"keys\": [\"image\"], \"source\": \"label\"}, DATA_GUIDANCE_1, [4, 5, 5]]\n\nLABEL_CLASS_TEST_CASE_1 = [{\"keys\": [\"label\"], \"offset\": 2}, DATA_GUIDANCE_1, 3]\n\n# Test Case Classes\n\n\nclass TestFilterImaged(unittest.TestCase):\n\n    @parameterized.expand([FILTER_IMAGE_TEST_CASE_1])\n    def test_correct_shape(self, arguments, input_data, expected_shape):\n        result = FilterImaged(**arguments)(input_data)\n        np.testing.assert_equal(result[\"image\"].shape, expected_shape)\n\n\nclass TestFlattenLabeld(unittest.TestCase):\n\n    @parameterized.expand([FLATTEN_LABEL_TEST_CASE_1, FLATTEN_LABEL_TEST_CASE_2, FLATTEN_LABEL_TEST_CASE_3])\n    def test_correct_num_labels(self, arguments, input_data, expected_result):\n        result = FlattenLabeld(**arguments)(input_data)\n        np.testing.assert_equal(np.unique(result[\"label\"]), expected_result)\n\n\nclass TestExtractPatchd(unittest.TestCase):\n\n    @parameterized.expand([EXTRACT_TEST_CASE_1, EXTRACT_TEST_CASE_2, EXTRACT_TEST_CASE_3])\n    def test_correct_patch_size(self, arguments, input_data, expected_shape):\n        result = ExtractPatchd(**arguments)(input_data)\n        np.testing.assert_equal(result[\"label\"].shape, expected_shape)\n\n    @parameterized.expand([EXTRACT_RESULT_TEST_CASE_1, EXTRACT_RESULT_TEST_CASE_2, EXTRACT_KW_TEST_CASE_1])\n    def 
test_correct_results(self, arguments, input_data, expected_result):\n        result = ExtractPatchd(**arguments)(input_data)\n        np.testing.assert_equal(result[\"label\"], expected_result)\n\n\nclass TestSplitLabelsd(unittest.TestCase):\n\n    @parameterized.expand([SPLIT_TEST_CASE_1, SPLIT_TEST_CASE_2])\n    def test_correct_results(self, arguments, input_data, expected_result):\n        result = SplitLabeld(**arguments)(input_data)\n        np.testing.assert_equal(result[\"label\"], expected_result)\n\n\nclass TestGuidanceSignal(unittest.TestCase):\n\n    @parameterized.expand([GUIDANCE_TEST_CASE_1, GUIDANCE_TEST_CASE_2])\n    def test_correct_shape(self, arguments, input_data, expected_shape):\n        result = AddPointGuidanceSignald(**arguments)(input_data)\n        np.testing.assert_equal(result[\"image\"].shape, expected_shape)\n\n\nclass TestClickSignal(unittest.TestCase):\n\n    @parameterized.expand([CLICK_TEST_CASE_1, CLICK_TEST_CASE_2])\n    def test_correct_shape(self, arguments, input_data, expected_shape):\n        result = AddClickSignalsd(**arguments)(input_data)\n        np.testing.assert_equal(result[\"image\"].shape, expected_shape)\n\n\nclass TestPostFilterLabel(unittest.TestCase):\n\n    @parameterized.expand([LABEL_FILTER_TEST_CASE_1])\n    def test_correct_shape(self, arguments, input_data, expected_shape):\n        result = PostFilterLabeld(**arguments)(input_data)\n        np.testing.assert_equal(result[\"pred\"].shape, expected_shape)\n\n\nclass TestAddLabelAsGuidance(unittest.TestCase):\n\n    @parameterized.expand([LABEL_GUIDANCE_TEST_CASE_1])\n    def test_correct_shape(self, arguments, input_data, expected_shape):\n        result = AddLabelAsGuidanced(**arguments)(input_data)\n        np.testing.assert_equal(result[\"image\"].shape, expected_shape)\n\n\nclass TestSetLabelClass(unittest.TestCase):\n\n    @parameterized.expand([LABEL_CLASS_TEST_CASE_1])\n    def test_correct_results(self, arguments, input_data, expected_result):\n  
      result = SetLabelClassd(**arguments)(input_data)\n        assert result[\"label\"] == expected_result\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/pathology/handlers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/pathology/handlers/test_from_engine_hovernet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.handlers.utils import from_engine_hovernet\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_0 = [\n    [{\"A\": {\"C\": 1, \"D\": 2}, \"B\": {\"C\": 2, \"D\": 2}}, {\"A\": {\"C\": 3, \"D\": 2}, \"B\": {\"C\": 4, \"D\": 2}}],\n    ([1, 3], [2, 4]),\n]\nTEST_CASE_1 = [{\"A\": {\"C\": 1, \"D\": 2}, \"B\": {\"C\": 2, \"D\": 2}}, (1, 2)]\n\nCASES = [TEST_CASE_0, TEST_CASE_1]\n\n\nclass TestFromEngineHovernet(unittest.TestCase):\n    @parameterized.expand(CASES)\n    def test_results(self, input, expected):\n        output = from_engine_hovernet(keys=[\"A\", \"B\"], nested_key=\"C\")(input)\n        assert_allclose(output, expected, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/test_lesion_froc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom unittest import skipUnless\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.metrics import LesionFROC\nfrom monai.utils import optional_import\n\n_cucim, has_cucim = optional_import(\"cucim\")\nhas_cucim = has_cucim and hasattr(_cucim, \"CuImage\")\n_, has_skimage = optional_import(\"skimage.measure\")\n_, has_sp = optional_import(\"scipy.ndimage\")\nimwrite, has_tif = optional_import(\"tifffile\", name=\"imwrite\")\n\n\ndef save_as_tif(filename, array):\n    array = array[::-1, ...]  
# Upside-down\n    if not filename.endswith(\".tif\"):\n        filename += \".tif\"\n    file_path = os.path.join(\"tests\", \"testing_data\", filename)\n    imwrite(file_path, array, compression=\"jpeg\", tile=(16, 16))\n\n\ndef around(val, interval=3):\n    return slice(val - interval, val + interval)\n\n\n# mask and prediction image size\nHEIGHT = 101\nWIDTH = 800\n\n\ndef prepare_test_data():\n    # -------------------------------------\n    # Ground Truth - Binary Masks\n    # -------------------------------------\n    # ground truth with no tumor\n    ground_truth = np.zeros((HEIGHT, WIDTH), dtype=np.uint8)\n    save_as_tif(\"temp_ground_truth_0\", ground_truth)\n\n    # ground truth with one tumor\n    ground_truth[around(HEIGHT // 2), around(1 * WIDTH // 7)] = 1\n    save_as_tif(\"temp_ground_truth_1\", ground_truth)\n\n    # ground truth with two tumors\n    ground_truth[around(HEIGHT // 2), around(2 * WIDTH // 7)] = 1\n    save_as_tif(\"temp_ground_truth_2\", ground_truth)\n\n    # ground truth with three tumors\n    ground_truth[around(HEIGHT // 2), around(3 * WIDTH // 7)] = 1\n    save_as_tif(\"temp_ground_truth_3\", ground_truth)\n\n    # ground truth with four tumors\n    ground_truth[around(HEIGHT // 2), around(4 * WIDTH // 7)] = 1\n    save_as_tif(\"temp_ground_truth_4\", ground_truth)\n\n    # -------------------------------------\n    # predictions - Probability Maps\n    # -------------------------------------\n\n    # prediction with no tumor\n    prob_map = np.zeros((HEIGHT, WIDTH))\n    np.save(\"./tests/testing_data/temp_prob_map_0_0.npy\", prob_map)\n\n    # prediction with one incorrect tumor\n    prob_map[HEIGHT // 2, 5 * WIDTH // 7] = 0.6\n    np.save(\"./tests/testing_data/temp_prob_map_0_1.npy\", prob_map)\n\n    # prediction with correct first tumors and an incorrect tumor\n    prob_map[HEIGHT // 2, 1 * WIDTH // 7] = 0.8\n    np.save(\"./tests/testing_data/temp_prob_map_1_1.npy\", prob_map)\n\n    # prediction with correct first two 
tumors and an incorrect tumor\n    prob_map[HEIGHT // 2, 2 * WIDTH // 7] = 0.8\n    np.save(\"./tests/testing_data/temp_prob_map_2_1.npy\", prob_map)\n\n    # prediction with two incorrect tumors\n    prob_map = np.zeros((HEIGHT, WIDTH))\n    prob_map[HEIGHT // 2, 5 * WIDTH // 7] = 0.6\n    prob_map[HEIGHT // 2, 6 * WIDTH // 7] = 0.4\n    np.save(\"./tests/testing_data/temp_prob_map_0_2.npy\", prob_map)\n\n    # prediction with correct first tumors and two incorrect tumors\n    prob_map[HEIGHT // 2, 1 * WIDTH // 7] = 0.8\n    np.save(\"./tests/testing_data/temp_prob_map_1_2.npy\", prob_map)\n\n\nTEST_CASE_0 = [\n    {\n        \"data\": [\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_0_0.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_0.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            }\n        ],\n        \"grow_distance\": 2,\n        \"itc_diameter\": 0,\n    },\n    np.nan,\n]\n\nTEST_CASE_1 = [\n    {\n        \"data\": [\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_0_0.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_1.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            }\n        ],\n        \"grow_distance\": 2,\n        \"itc_diameter\": 0,\n    },\n    0.0,\n]\n\nTEST_CASE_2 = [\n    {\n        \"data\": [\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_1_1.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_1.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            }\n        ],\n        \"grow_distance\": 2,\n        \"itc_diameter\": 0,\n    },\n    1.0,\n]\n\nTEST_CASE_3 = [\n    {\n        \"data\": [\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_2_1.npy\",\n                \"tumor_mask\": 
\"./tests/testing_data/temp_ground_truth_1.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            }\n        ],\n        \"grow_distance\": 2,\n        \"itc_diameter\": 0,\n    },\n    1.0,\n]\n\nTEST_CASE_4 = [\n    {\n        \"data\": [\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_2_1.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_2.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            }\n        ],\n        \"grow_distance\": 2,\n        \"itc_diameter\": 0,\n    },\n    1.0,\n]\n\nTEST_CASE_5 = [\n    {\n        \"data\": [\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_1_2.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_2.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            }\n        ],\n        \"grow_distance\": 2,\n        \"itc_diameter\": 0,\n    },\n    0.5,\n]\n\nTEST_CASE_6 = [\n    {\n        \"data\": [\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_1_1.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_1.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            },\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_1_2.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_2.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            },\n        ],\n        \"grow_distance\": 2,\n        \"itc_diameter\": 0,\n    },\n    2.0 / 3.0,\n]\n\nTEST_CASE_7 = [\n    {\n        \"data\": [\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_1_1.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_3.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 
1,\n            },\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_1_2.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_2.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            },\n        ],\n        \"grow_distance\": 2,\n        \"itc_diameter\": 0,\n    },\n    0.4,\n]\n\nTEST_CASE_8 = [\n    {\n        \"data\": [\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_0_1.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_1.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            },\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_1_1.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_3.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            },\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_1_2.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_2.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            },\n        ],\n        \"grow_distance\": 2,\n        \"itc_diameter\": 0,\n    },\n    1.0 / 3.0,\n]\n\nTEST_CASE_9 = [\n    {\n        \"data\": [\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_0_2.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_4.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            },\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_1_1.npy\",\n                \"tumor_mask\": \"./tests/testing_data/temp_ground_truth_3.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            },\n            {\n                \"prob_map\": \"./tests/testing_data/temp_prob_map_1_2.npy\",\n                
\"tumor_mask\": \"./tests/testing_data/temp_ground_truth_2.tif\",\n                \"level\": 0,\n                \"pixel_spacing\": 1,\n            },\n        ],\n        \"grow_distance\": 2,\n        \"itc_diameter\": 0,\n    },\n    2.0 / 9.0,\n]\n\n\nclass TestEvaluateTumorFROC(unittest.TestCase):\n\n    @skipUnless(has_cucim, \"Requires cucim\")\n    @skipUnless(has_skimage, \"Requires skimage\")\n    @skipUnless(has_sp, \"Requires scipy\")\n    @skipUnless(has_tif, \"Requires tifffile\")\n    def setUp(self):\n        prepare_test_data()\n\n    @parameterized.expand(\n        [\n            TEST_CASE_0,\n            TEST_CASE_1,\n            TEST_CASE_2,\n            TEST_CASE_3,\n            TEST_CASE_4,\n            TEST_CASE_5,\n            TEST_CASE_6,\n            TEST_CASE_7,\n            TEST_CASE_8,\n            TEST_CASE_9,\n        ]\n    )\n    def test_read_patches_cucim(self, input_parameters, expected):\n        froc = LesionFROC(**input_parameters)\n        froc_score = froc.evaluate()\n        if np.isnan(expected):\n            self.assertTrue(np.isnan(froc_score))\n        else:\n            self.assertAlmostEqual(froc_score, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/test_pathology_prob_nms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.utils import PathologyProbNMS\n\nprobs_map_2d = np.random.rand(100, 100).clip(0, 0.5)\nprobs_map_2d[33, 33] = 0.7\nprobs_map_2d[66, 66] = 0.9\nexpected_2d = [[0.9, 133, 133], [0.7, 67, 67]]\nTEST_CASES_2D = [\n    {\"spatial_dims\": 2, \"prob_threshold\": 0.5, \"box_size\": [10, 10]},\n    {\"resolution_level\": 1},\n    probs_map_2d,\n    expected_2d,\n]\n\nprobs_map_3d = torch.rand([50, 50, 50]).uniform_(0, 0.5)\nprobs_map_3d[25, 25, 25] = 0.7\nprobs_map_3d[45, 45, 45] = 0.9\nexpected_3d = [[0.9, 91, 91, 91], [0.7, 51, 51, 51]]\nTEST_CASES_3D = [\n    {\"spatial_dims\": 3, \"prob_threshold\": 0.5, \"box_size\": (10, 10, 10)},\n    {\"resolution_level\": 1},\n    probs_map_3d,\n    expected_3d,\n]\n\n\nclass TestPathologyProbNMS(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASES_2D, TEST_CASES_3D])\n    def test_output(self, class_args, call_args, probs_map, expected):\n        nms = PathologyProbNMS(**class_args)\n        output = nms(probs_map, **call_args)\n        np.testing.assert_allclose(output, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/test_prepare_batch_hovernet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.engines import PrepareBatchHoVerNet\nfrom monai.engines import SupervisedEvaluator\nfrom monai.utils.enums import HoVerNetBranch\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_0 = [\n    {\"extra_keys\": [\"extra_label1\", \"extra_label2\"]},\n    {HoVerNetBranch.NP: torch.tensor([1, 2]), HoVerNetBranch.NC: torch.tensor([4, 4]), HoVerNetBranch.HV: 16},\n]\n\n\nclass TestNet(torch.nn.Module):\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def forward(self, x: torch.Tensor):\n        return {HoVerNetBranch.NP: torch.tensor([1, 2]), HoVerNetBranch.NC: torch.tensor([4, 4]), HoVerNetBranch.HV: 16}\n\n\nclass TestPrepareBatchHoVerNet(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_0])\n    def test_content(self, input_args, expected_value):\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        dataloader = [\n            {\n                \"image\": torch.tensor([1, 2]),\n                \"label\": torch.tensor([1, 2]),\n                \"extra_label1\": torch.tensor([3, 4]),\n                \"extra_label2\": 16,\n            }\n        ]\n        # set up engine\n        evaluator = SupervisedEvaluator(\n            
device=device,\n            val_data_loader=dataloader,\n            epoch_length=1,\n            network=TestNet(),\n            non_blocking=True,\n            prepare_batch=PrepareBatchHoVerNet(**input_args),\n            decollate=False,\n        )\n        evaluator.run()\n        output = evaluator.state.output\n        assert_allclose(output[\"image\"], torch.tensor([1, 2], device=device))\n        for k, v in output[\"pred\"].items():\n            if isinstance(v, torch.Tensor):\n                assert_allclose(v, expected_value[k].to(device))\n            else:\n                self.assertEqual(v, expected_value[k])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/test_sliding_window_hovernet_inference.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.inferers import SlidingWindowHoVerNetInferer\nfrom monai.data import MetaTensor\nfrom monai.inferers import sliding_window_inference\nfrom monai.utils import optional_import\nfrom tests.inferers.test_sliding_window_inference import TEST_CASES\n\n_, has_tqdm = optional_import(\"tqdm\")\n\nTEST_CASES_PADDING = [\n    [None, (1, 3, 16, 8), (4, 4), 7, 0.5, \"constant\", torch.device(\"cpu:0\"), None],\n    [\"hover\", (1, 3, 16, 8), (4, 4), 7, 0.5, \"constant\", torch.device(\"cpu:0\"), None],\n    [None, (1, 3, 16, 8), (4, 4), 7, 0.5, \"constant\", torch.device(\"cpu:0\"), (1,) * 4],\n    [\"hover\", (1, 3, 16, 8), (4, 4), 7, 0.5, \"constant\", torch.device(\"cpu:0\"), (1,) * 4],\n]\n\nTEST_CASES_MULTIOUTPUT = [[torch.ones((1, 6, 20, 20))], [MetaTensor(torch.ones((1, 6, 20, 20)))]]\n\n\nclass TestSlidingWindowHoVerNetInference(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_PADDING)\n    def test_sliding_window_with_padding(\n        self, key, image_shape, roi_shape, sw_batch_size, overlap, mode, device, extra_input_padding\n    ):\n        n_total = np.prod(image_shape)\n        if mode == \"constant\":\n            inputs = torch.arange(n_total, dtype=torch.float).reshape(*image_shape)\n        else:\n    
        inputs = torch.ones(*image_shape, dtype=torch.float)\n        if device.type == \"cuda\" and not torch.cuda.is_available():\n            device = torch.device(\"cpu:0\")\n\n        def compute(data):\n            if key:\n                return {key: torch.clone(data[..., 1:-1, 1:-1]) + 1}\n            else:\n                return torch.clone(data[..., 1:-1, 1:-1]) + 1\n\n        if mode == \"constant\":\n            expected_val = np.arange(n_total, dtype=np.float32).reshape(*image_shape) + 1.0\n        else:\n            expected_val = np.ones(image_shape, dtype=np.float32) + 1.0\n\n        if extra_input_padding is None:\n            expected_val[..., 0, :] = expected_val[..., -1, :] = None\n            expected_val[..., 0] = expected_val[..., -1] = None\n\n        sliding_inference = SlidingWindowHoVerNetInferer(\n            roi_shape, sw_batch_size, overlap, mode, extra_input_padding=extra_input_padding\n        )\n        result = sliding_inference(inputs.to(device), compute)\n        result = result[key] if key else result\n        np.testing.assert_string_equal(device.type, result.device.type)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_val)\n\n    @parameterized.expand(TEST_CASES)\n    def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size, overlap, mode, device):\n        n_total = np.prod(image_shape)\n        if mode == \"constant\":\n            inputs = torch.arange(n_total, dtype=torch.float).reshape(*image_shape)\n        else:\n            inputs = torch.ones(*image_shape, dtype=torch.float)\n        if device.type == \"cuda\" and not torch.cuda.is_available():\n            device = torch.device(\"cpu:0\")\n\n        def compute(data):\n            return data + 1\n\n        if mode == \"constant\":\n            expected_val = np.arange(n_total, dtype=np.float32).reshape(*image_shape) + 1.0\n        else:\n            expected_val = np.ones(image_shape, dtype=np.float32) + 1.0\n        result = 
sliding_window_inference(inputs.to(device), roi_shape, sw_batch_size, compute, overlap, mode=mode)\n        np.testing.assert_string_equal(device.type, result.device.type)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_val)\n\n        result = SlidingWindowHoVerNetInferer(roi_shape, sw_batch_size, overlap, mode)(inputs.to(device), compute)\n        np.testing.assert_string_equal(device.type, result.device.type)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_val)\n\n    def test_sigma(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = torch.ones((1, 1, 7, 7)).to(device=device)\n        roi_shape = (3, 3)\n        sw_batch_size = 10\n\n        class _Pred:\n            add = 1\n\n            def compute(self, data):\n                self.add += 1\n                return data + self.add\n\n        result = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            _Pred().compute,\n            overlap=0.5,\n            padding_mode=\"constant\",\n            cval=-1,\n            mode=\"constant\",\n            sigma_scale=1.0,\n        )\n\n        expected = np.array(\n            [\n                [\n                    [\n                        [3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000],\n                        [3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000],\n                        [3.3333, 3.3333, 3.3333, 3.3333, 3.3333, 3.3333, 3.3333],\n                        [3.6667, 3.6667, 3.6667, 3.6667, 3.6667, 3.6667, 3.6667],\n                        [4.3333, 4.3333, 4.3333, 4.3333, 4.3333, 4.3333, 4.3333],\n                        [4.5000, 4.5000, 4.5000, 4.5000, 4.5000, 4.5000, 4.5000],\n                        [5.0000, 5.0000, 5.0000, 5.0000, 5.0000, 5.0000, 5.0000],\n                    ]\n                ]\n            ]\n        )\n        np.testing.assert_allclose(result.cpu().numpy(), expected, 
rtol=1e-4)\n        result = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            _Pred().compute,\n            overlap=0.5,\n            padding_mode=\"constant\",\n            cval=-1,\n            mode=\"gaussian\",\n            sigma_scale=1.0,\n            progress=has_tqdm,\n        )\n        expected = np.array(\n            [\n                [\n                    [\n                        [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],\n                        [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],\n                        [3.3271625, 3.3271623, 3.3271623, 3.3271623, 3.3271623, 3.3271623, 3.3271625],\n                        [3.6728377, 3.6728377, 3.6728377, 3.6728377, 3.6728377, 3.6728377, 3.6728377],\n                        [4.3271623, 4.3271623, 4.3271627, 4.3271627, 4.3271627, 4.3271623, 4.3271623],\n                        [4.513757, 4.513757, 4.513757, 4.513757, 4.513757, 4.513757, 4.513757],\n                        [4.9999995, 5.0, 5.0, 5.0, 5.0, 5.0, 4.9999995],\n                    ]\n                ]\n            ]\n        )\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowHoVerNetInferer(roi_shape, sw_batch_size, overlap=0.5, mode=\"gaussian\", sigma_scale=1.0)(\n            inputs, _Pred().compute\n        )\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowHoVerNetInferer(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"gaussian\", sigma_scale=[1.0, 1.0]\n        )(inputs, _Pred().compute)\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowHoVerNetInferer(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"gaussian\", sigma_scale=[1.0, 1.0], cache_roi_weight_map=True\n        )(inputs, _Pred().compute)\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n    def 
test_cval(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = torch.ones((1, 1, 3, 3)).to(device=device)\n        roi_shape = (5, 5)\n        sw_batch_size = 10\n\n        def compute(data):\n            return data + data.sum()\n\n        result = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            compute,\n            overlap=0.5,\n            padding_mode=\"constant\",\n            cval=-1,\n            mode=\"constant\",\n            sigma_scale=1.0,\n        )\n        expected = np.ones((1, 1, 3, 3)) * -6.0\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowHoVerNetInferer(roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1)(\n            inputs, compute\n        )\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n    def test_args_kwargs(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = torch.ones((1, 1, 3, 3)).to(device=device)\n        t1 = torch.ones(1).to(device=device)\n        t2 = torch.ones(1).to(device=device)\n        roi_shape = (5, 5)\n        sw_batch_size = 10\n\n        def compute(data, test1, test2):\n            return data + test1 + test2\n\n        result = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            compute,\n            0.5,\n            \"constant\",\n            1.0,\n            \"constant\",\n            0.0,\n            device,\n            device,\n            has_tqdm,\n            None,\n            None,\n            None,\n            0,\n            False,\n            t1,\n            test2=t2,\n        )\n        expected = np.ones((1, 1, 3, 3)) + 2.0\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowHoVerNetInferer(\n            roi_shape, 
sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1, progress=has_tqdm\n        )(inputs, compute, t1, test2=t2)\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n    @parameterized.expand(TEST_CASES_MULTIOUTPUT)\n    def test_multioutput(self, inputs):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = inputs.to(device=device)\n        roi_shape = (8, 8)\n        sw_batch_size = 10\n\n        def compute(data):\n            return data + 1, data[:, ::3, ::2, ::2] + 2, data[:, ::2, ::4, ::4] + 3\n\n        def compute_dict(data):\n            return {1: data + 1, 2: data[:, ::3, ::2, ::2] + 2, 3: data[:, ::2, ::4, ::4] + 3}\n\n        result = sliding_window_inference(\n            inputs, roi_shape, sw_batch_size, compute, 0.5, \"constant\", 1.0, \"constant\", 0.0, device, device, has_tqdm\n        )\n        result_dict = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            compute_dict,\n            0.5,\n            \"constant\",\n            1.0,\n            \"constant\",\n            0.0,\n            device,\n            device,\n            has_tqdm,\n        )\n        expected = (np.ones((1, 6, 20, 20)) + 1, np.ones((1, 2, 10, 10)) + 2, np.ones((1, 3, 5, 5)) + 3)\n        expected_dict = {1: np.ones((1, 6, 20, 20)) + 1, 2: np.ones((1, 2, 10, 10)) + 2, 3: np.ones((1, 3, 5, 5)) + 3}\n        for rr, ee in zip(result, expected):\n            np.testing.assert_allclose(rr.cpu().numpy(), ee, rtol=1e-4)\n        for rr, _ in zip(result_dict, expected_dict):\n            np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)\n\n        result = SlidingWindowHoVerNetInferer(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1, progress=has_tqdm\n        )(inputs, compute)\n        for rr, ee in zip(result, expected):\n            np.testing.assert_allclose(rr.cpu().numpy(), 
ee, rtol=1e-4)\n\n        result_dict = SlidingWindowHoVerNetInferer(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1, progress=has_tqdm\n        )(inputs, compute_dict)\n        for rr, _ in zip(result_dict, expected_dict):\n            np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_distance_map.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.array import GenerateDistanceMap\nfrom monai.transforms.intensity.array import GaussianSmooth\nfrom tests.test_utils import TEST_NDARRAYS\n\nEXCEPTION_TESTS = []\nTESTS = []\n\nnp.random.RandomState(123)\n\nfor p in TEST_NDARRAYS:\n    EXCEPTION_TESTS.append([{}, p(np.random.rand(2, 5, 5)), p(np.random.rand(1, 5, 5)), ValueError])\n    EXCEPTION_TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError])\n\n    TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)])\n    TESTS.append(\n        [{\"smooth_fn\": GaussianSmooth(sigma=0.4)}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)]\n    )\n\n\nclass TestGenerateDistanceMap(unittest.TestCase):\n    @parameterized.expand(EXCEPTION_TESTS)\n    def test_value(self, arguments, mask, probmap, exception_type):\n        with self.assertRaises(exception_type):\n            GenerateDistanceMap(**arguments)(mask, probmap)\n\n    @parameterized.expand(TESTS)\n    def test_value2(self, arguments, mask, probmap, expected_shape):\n        result = GenerateDistanceMap(**arguments)(mask, probmap)\n        self.assertEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_distance_mapd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.dictionary import GenerateDistanceMapd\nfrom monai.transforms.intensity.array import GaussianSmooth\nfrom tests.test_utils import TEST_NDARRAYS\n\nEXCEPTION_TESTS = []\nTESTS = []\n\nnp.random.RandomState(123)\n\nfor p in TEST_NDARRAYS:\n    EXCEPTION_TESTS.append(\n        [\n            {\"mask_key\": \"mask\", \"border_key\": \"border\"},\n            p(np.random.rand(2, 5, 5)),\n            p(np.random.rand(1, 5, 5)),\n            ValueError,\n        ]\n    )\n    EXCEPTION_TESTS.append(\n        [\n            {\"mask_key\": \"mask\", \"border_key\": \"border\"},\n            p(np.random.rand(1, 5, 5)),\n            p(np.random.rand(2, 5, 5)),\n            ValueError,\n        ]\n    )\n\n    TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)])\n    TESTS.append(\n        [\n            {\"mask_key\": \"mask\", \"border_key\": \"border\", \"smooth_fn\": GaussianSmooth(sigma=0.4)},\n            p(np.random.rand(1, 5, 5)),\n            p(np.random.rand(1, 5, 5)),\n            (1, 5, 5),\n        ]\n    )\n\n\nclass TestGenerateDistanceMapd(unittest.TestCase):\n    @parameterized.expand(EXCEPTION_TESTS)\n    def test_value(self, arguments, mask, border_map, exception_type):\n   
     with self.assertRaises(exception_type):\n            GenerateDistanceMapd(**arguments)({\"mask\": mask, \"border\": border_map})\n\n    @parameterized.expand(TESTS)\n    def test_value2(self, arguments, mask, border_map, expected_shape):\n        result = GenerateDistanceMapd(**arguments)({\"mask\": mask, \"border\": border_map})\n        self.assertEqual(result[\"dist_map\"].shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_instance_border.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.array import GenerateInstanceBorder\nfrom tests.test_utils import TEST_NDARRAYS\n\nEXCEPTION_TESTS = []\nTESTS = []\n\nnp.random.RandomState(123)\n\nfor p in TEST_NDARRAYS:\n    EXCEPTION_TESTS.append([{\"kernel_size\": 3}, p(np.random.rand(1, 5, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError])\n    EXCEPTION_TESTS.append([{\"kernel_size\": 3}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), ValueError])\n    EXCEPTION_TESTS.append([{\"kernel_size\": 3}, p(np.random.rand(2, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError])\n\n    TESTS.append([{\"kernel_size\": 3}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), (1, 5, 5)])\n    TESTS.append([{\"kernel_size\": 3}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), (1, 5, 5)])\n\n\nclass TestGenerateInstanceBorder(unittest.TestCase):\n    @parameterized.expand(EXCEPTION_TESTS)\n    def test_value(self, arguments, mask, hover_map, exception_type):\n        with self.assertRaises(exception_type):\n            GenerateInstanceBorder(**arguments)(mask, hover_map)\n\n    @parameterized.expand(TESTS)\n    def test_value2(self, arguments, mask, hover_map, expected_shape):\n        result = GenerateInstanceBorder(**arguments)(mask, hover_map)\n        
self.assertEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_instance_borderd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.dictionary import GenerateInstanceBorderd\nfrom tests.test_utils import TEST_NDARRAYS\n\nEXCEPTION_TESTS = []\nTESTS = []\n\nnp.random.RandomState(123)\n\nfor p in TEST_NDARRAYS:\n    EXCEPTION_TESTS.append(\n        [{\"mask_key\": \"mask\", \"kernel_size\": 3}, p(np.random.rand(1, 5, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError]\n    )\n    EXCEPTION_TESTS.append(\n        [{\"mask_key\": \"mask\", \"kernel_size\": 3}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), ValueError]\n    )\n    EXCEPTION_TESTS.append(\n        [{\"mask_key\": \"mask\", \"kernel_size\": 3}, p(np.random.rand(2, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError]\n    )\n\n    TESTS.append(\n        [{\"mask_key\": \"mask\", \"kernel_size\": 3}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), (1, 5, 5)]\n    )\n    TESTS.append(\n        [{\"mask_key\": \"mask\", \"kernel_size\": 3}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), (1, 5, 5)]\n    )\n\n\nclass TestGenerateInstanceBorderd(unittest.TestCase):\n    @parameterized.expand(EXCEPTION_TESTS)\n    def test_value(self, arguments, mask, hover_map, exception_type):\n        with self.assertRaises(exception_type):\n            
GenerateInstanceBorderd(**arguments)({\"mask\": mask, \"hover_map\": hover_map})\n\n    @parameterized.expand(TESTS)\n    def test_value2(self, arguments, mask, hover_map, expected_shape):\n        result = GenerateInstanceBorderd(**arguments)({\"mask\": mask, \"hover_map\": hover_map})\n        self.assertEqual(result[\"border\"].shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_instance_centroid.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.array import GenerateInstanceCentroid\nfrom monai.transforms import BoundingRect\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_skimage = optional_import(\"skimage\", \"0.19.3\", min_version)\n\ny, x = np.ogrid[0:30, 0:30]\nget_bbox = BoundingRect()\n\nTEST_CASE_1 = [(x - 2) ** 2 + (y - 2) ** 2 <= 2**2, [0, 0], [2, 2]]\n\nTEST_CASE_2 = [(x - 8) ** 2 + (y - 8) ** 2 <= 2**2, [6, 6], [8, 8]]\n\nTEST_CASE_3 = [(x - 5) ** 2 / 3**2 + (y - 5) ** 2 / 2**2 <= 1, [2, 3], [4, 6]]\n\nTEST_CASE = []\nfor p in TEST_NDARRAYS:\n    TEST_CASE.append([p, *TEST_CASE_1])\n    TEST_CASE.append([p, *TEST_CASE_2])\n    TEST_CASE.append([p, *TEST_CASE_3])\n\n\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\nclass TestGenerateInstanceCentroid(unittest.TestCase):\n    @parameterized.expand(TEST_CASE)\n    def test_shape(self, in_type, test_data, offset, expected):\n        inst_bbox = get_bbox(test_data[None])\n        inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]]\n        result = GenerateInstanceCentroid()(in_type(inst_map[None]), offset=offset)\n        assert_allclose(result, expected, 
type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_instance_centroidd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.dictionary import GenerateInstanceCentroidd\nfrom monai.transforms import BoundingRect\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_skimage = optional_import(\"skimage\", \"0.19.3\", min_version)\n\ny, x = np.ogrid[0:30, 0:30]\nget_bbox = BoundingRect()\n\nTEST_CASE_1 = [(x - 2) ** 2 + (y - 2) ** 2 <= 2**2, [0, 0], [2, 2]]\n\nTEST_CASE_2 = [(x - 8) ** 2 + (y - 8) ** 2 <= 2**2, [6, 6], [8, 8]]\n\nTEST_CASE_3 = [(x - 5) ** 2 / 3**2 + (y - 5) ** 2 / 2**2 <= 1, [2, 3], [4, 6]]\n\nTEST_CASE = []\nfor p in TEST_NDARRAYS:\n    TEST_CASE.append([p, *TEST_CASE_1])\n    TEST_CASE.append([p, *TEST_CASE_2])\n    TEST_CASE.append([p, *TEST_CASE_3])\n\n\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\nclass TestGenerateInstanceCentroidd(unittest.TestCase):\n    @parameterized.expand(TEST_CASE)\n    def test_shape(self, in_type, test_data, offset, expected):\n        inst_bbox = get_bbox(test_data[None])\n        inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]]\n        test_case = {\"image\": in_type(inst_map[None]), \"offset\": offset}\n        result = 
GenerateInstanceCentroidd(keys=\"image\", centroid_key_postfix=\"centroid\", offset_key=\"offset\")(\n            test_case\n        )\n        assert_allclose(result[\"image_centroid\"], expected, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_instance_contour.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.array import GenerateInstanceContour\nfrom monai.transforms import BoundingRect\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_skimage = optional_import(\"skimage\", \"0.19.3\", min_version)\n\ny, x = np.ogrid[0:30, 0:30]\nget_bbox = BoundingRect()\n\nTEST_CASE_1 = [(x - 2) ** 2 + (y - 2) ** 2 <= 2**2, 3, [0, 0], [[2, 0], [0, 2], [2, 4], [4, 2]]]\n\nTEST_CASE_2 = [(x - 8) ** 2 + (y - 8) ** 2 <= 2**2, 3, [8, 8], [[10, 8], [8, 10], [10, 12], [12, 10]]]\n\nTEST_CASE_3 = [\n    (x - 5) ** 2 / 3**2 + (y - 5) ** 2 / 2**2 <= 1,\n    3,\n    [2, 3],\n    [[5, 3], [4, 4], [3, 4], [2, 5], [3, 6], [4, 6], [5, 7], [6, 6], [7, 6], [8, 5], [7, 4], [6, 4]],\n]\n\nTEST_CASE = []\nfor p in TEST_NDARRAYS:\n    TEST_CASE.append([p, *TEST_CASE_1])\n    TEST_CASE.append([p, *TEST_CASE_2])\n    TEST_CASE.append([p, *TEST_CASE_3])\n\n\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\nclass TestGenerateInstanceContour(unittest.TestCase):\n    @parameterized.expand(TEST_CASE)\n    def test_shape(self, in_type, test_data, min_num_points, offset, expected):\n        inst_bbox = get_bbox(test_data[None])\n        inst_map = 
test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]]\n        result = GenerateInstanceContour(min_num_points=min_num_points)(in_type(inst_map[None]), offset=offset)\n        assert_allclose(result, expected, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_instance_contourd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.dictionary import GenerateInstanceContourd\nfrom monai.transforms import BoundingRect\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_skimage = optional_import(\"skimage\", \"0.19.3\", min_version)\n\ny, x = np.ogrid[0:30, 0:30]\nget_bbox = BoundingRect()\n\nTEST_CASE_1 = [(x - 2) ** 2 + (y - 2) ** 2 <= 2**2, 3, [0, 0], [[2, 0], [0, 2], [2, 4], [4, 2]]]\n\nTEST_CASE_2 = [(x - 10) ** 2 + (y - 10) ** 2 <= 2**2, 3, [8, 8], [[10, 8], [8, 10], [10, 12], [12, 10]]]\n\nTEST_CASE_3 = [\n    (x - 5) ** 2 / 3**2 + (y - 5) ** 2 / 2**2 <= 1,\n    3,\n    [2, 3],\n    [[5, 3], [4, 4], [3, 4], [2, 5], [3, 6], [4, 6], [5, 7], [6, 6], [7, 6], [8, 5], [7, 4], [6, 4]],\n]\n\nTEST_CASE = []\nfor p in TEST_NDARRAYS:\n    TEST_CASE.append([p, *TEST_CASE_1])\n    TEST_CASE.append([p, *TEST_CASE_2])\n    TEST_CASE.append([p, *TEST_CASE_3])\n\n\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\nclass TestGenerateInstanceContourd(unittest.TestCase):\n    @parameterized.expand(TEST_CASE)\n    def test_shape(self, in_type, test_data, min_num_points, offset, expected):\n        inst_bbox = get_bbox(test_data[None])\n        inst_map = 
test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]]\n        test_data = {\"image\": in_type(inst_map[None]), \"offset\": offset}\n        result = GenerateInstanceContourd(\n            keys=\"image\", contour_key_postfix=\"contour\", offset_key=\"offset\", min_num_points=min_num_points\n        )(test_data)\n        assert_allclose(result[\"image_contour\"], expected, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_instance_type.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.array import GenerateInstanceType\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\ny, x = np.ogrid[0:30, 0:30]\n\nTEST_CASE_1 = [\n    (x - 2) ** 2 + (y - 2) ** 2 <= 2**2,\n    (x - 2) ** 2 + (y - 3) ** 2 <= 2**2,\n    np.array([[0, 5, 0, 5]]),\n    [1, 0.6666666111111158],\n]\n\nTEST_CASE_2 = [\n    (x - 8) ** 2 / 3**2 + (y - 8) ** 2 / 2**2 <= 1,\n    (x - 7) ** 2 / 3**2 + (y - 7) ** 2 / 2**2 <= 1,\n    np.array([[6, 11, 5, 12]]),\n    [1, 0.7058823114186875],\n]\nTEST_CASE = []\nfor p in TEST_NDARRAYS:\n    TEST_CASE.append([p, *TEST_CASE_1])\n    TEST_CASE.append([p, *TEST_CASE_2])\n\n\nclass TestGenerateInstanceType(unittest.TestCase):\n    @parameterized.expand(TEST_CASE)\n    def test_shape(self, in_type, type_pred, seg_pred, bbox, expected):\n        result = GenerateInstanceType()(in_type(type_pred[None]), in_type(seg_pred[None]), bbox, 1)\n        assert_allclose(result, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_instance_typed.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.dictionary import GenerateInstanceTyped\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\ny, x = np.ogrid[0:30, 0:30]\n\nTEST_CASE_1 = [\n    (x - 2) ** 2 + (y - 2) ** 2 <= 2**2,\n    (x - 2) ** 2 + (y - 3) ** 2 <= 2**2,\n    np.array([[0, 5, 0, 5]]),\n    [1, 0.6666666111111158],\n]\n\nTEST_CASE_2 = [\n    (x - 8) ** 2 / 3**2 + (y - 8) ** 2 / 2**2 <= 1,\n    (x - 7) ** 2 / 3**2 + (y - 7) ** 2 / 2**2 <= 1,\n    np.array([[6, 11, 5, 12]]),\n    [1, 0.7058823114186875],\n]\nTEST_CASE = []\nfor p in TEST_NDARRAYS:\n    TEST_CASE.append([p, *TEST_CASE_1])\n    TEST_CASE.append([p, *TEST_CASE_2])\n\n\nclass TestGenerateInstanceTyped(unittest.TestCase):\n    @parameterized.expand(TEST_CASE)\n    def test_shape(self, in_type, type_pred, seg_pred, bbox, expected):\n        test_data = {\"type_pred\": in_type(type_pred[None]), \"seg\": in_type(seg_pred[None]), \"bbox\": bbox, \"id\": 1}\n        result = GenerateInstanceTyped(keys=\"type_pred\")(test_data)\n        assert_allclose(result[\"type_info\"][\"inst_type\"], expected[0])\n        assert_allclose(result[\"type_info\"][\"type_prob\"], expected[1])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_succinct_contour.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.array import GenerateSuccinctContour\n\nTEST_CASE_1 = [\n    [\n        np.array([[1.5, 0.0], [1.0, 0.5], [0.5, 1.0], [0.0, 1.5]]),\n        np.array([[0.0, 2.5], [0.5, 3.0], [1.0, 3.5], [1.5, 4.0]]),\n        np.array([[4.0, 1.5], [3.5, 1.0], [3.0, 0.5], [2.5, 0.0]]),\n        np.array([[2.5, 4.0], [3.0, 3.5], [3.5, 3.0], [4.0, 2.5]]),\n    ],\n    5,\n    5,\n    [[2, 0], [0, 2], [2, 4], [4, 2]],\n]\n\nTEST_CASE_2 = [\n    [\n        np.array([[1.5, 0.0], [1.0, 0.5], [0.5, 1.0], [0.5, 2.0], [0.0, 2.5]]),\n        np.array([[0.0, 3.5], [0.5, 4.0], [0.5, 5.0], [1.0, 5.5], [1.5, 6.0]]),\n        np.array([[4.0, 2.5], [3.5, 2.0], [3.5, 1.0], [3.0, 0.5], [2.5, 0.0]]),\n        np.array([[2.5, 6.0], [3.0, 5.5], [3.5, 5.0], [3.5, 4.0], [4.0, 3.5]]),\n    ],\n    5,\n    7,\n    [[3, 0], [2, 1], [1, 1], [0, 2], [1, 3], [2, 3], [3, 4], [4, 3], [5, 3], [6, 2], [5, 1], [4, 1]],\n]\n\n\nclass TestGenerateSuccinctContour(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_shape(self, test_data, height, width, expected):\n        result = GenerateSuccinctContour(height=height, width=width)(test_data)\n        np.testing.assert_allclose(result, expected)\n\n\nif __name__ == 
\"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_succinct_contourd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.dictionary import GenerateSuccinctContourd\n\ny, x = np.ogrid[0:5, 0:5]\nTEST_CASE_1 = [\n    [\n        np.array([[1.5, 0.0], [1.0, 0.5], [0.5, 1.0], [0.0, 1.5]]),\n        np.array([[0.0, 2.5], [0.5, 3.0], [1.0, 3.5], [1.5, 4.0]]),\n        np.array([[4.0, 1.5], [3.5, 1.0], [3.0, 0.5], [2.5, 0.0]]),\n        np.array([[2.5, 4.0], [3.0, 3.5], [3.5, 3.0], [4.0, 2.5]]),\n    ],\n    5,\n    5,\n    [[2, 0], [0, 2], [2, 4], [4, 2]],\n]\n\nTEST_CASE_2 = [\n    [\n        np.array([[1.5, 0.0], [1.0, 0.5], [0.5, 1.0], [0.5, 2.0], [0.0, 2.5]]),\n        np.array([[0.0, 3.5], [0.5, 4.0], [0.5, 5.0], [1.0, 5.5], [1.5, 6.0]]),\n        np.array([[4.0, 2.5], [3.5, 2.0], [3.5, 1.0], [3.0, 0.5], [2.5, 0.0]]),\n        np.array([[2.5, 6.0], [3.0, 5.5], [3.5, 5.0], [3.5, 4.0], [4.0, 3.5]]),\n    ],\n    5,\n    7,\n    [[3, 0], [2, 1], [1, 1], [0, 2], [1, 3], [2, 3], [3, 4], [4, 3], [5, 3], [6, 2], [5, 1], [4, 1]],\n]\n\n\nclass TestGenerateSuccinctContour(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_shape(self, data, height, width, expected):\n        test_data = {\"contour\": data}\n        result = GenerateSuccinctContourd(keys=\"contour\", height=height, 
width=width)(test_data)\n        np.testing.assert_allclose(result[\"contour\"], expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_watershed_markers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.array import GenerateWatershedMarkers\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS\n\n_, has_skimage = optional_import(\"skimage\", \"0.19.3\", min_version)\n_, has_scipy = optional_import(\"scipy\", \"1.8.1\", min_version)\n\nEXCEPTION_TESTS = []\nTESTS = []\n\nnp.random.RandomState(123)\n\nfor p in TEST_NDARRAYS:\n    EXCEPTION_TESTS.append([{}, p(np.random.rand(2, 5, 5)), p(np.random.rand(1, 5, 5)), ValueError])\n    EXCEPTION_TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError])\n\n    TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)])\n\n\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\n@unittest.skipUnless(has_scipy, \"Requires scipy library.\")\nclass TestGenerateWatershedMarkers(unittest.TestCase):\n    @parameterized.expand(EXCEPTION_TESTS)\n    def test_value(self, arguments, mask, probmap, exception_type):\n        with self.assertRaises(exception_type):\n            GenerateWatershedMarkers(**arguments)(mask, probmap)\n\n    @parameterized.expand(TESTS)\n    def test_value2(self, arguments, mask, probmap, expected_shape):\n        result = 
GenerateWatershedMarkers(**arguments)(mask, probmap)\n        self.assertEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_watershed_markersd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.dictionary import GenerateWatershedMarkersd\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS\n\n_, has_skimage = optional_import(\"skimage\", \"0.19.3\", min_version)\n_, has_scipy = optional_import(\"scipy\", \"1.8.1\", min_version)\n\nEXCEPTION_TESTS = []\nTESTS = []\n\nnp.random.RandomState(123)\n\nfor p in TEST_NDARRAYS:\n    EXCEPTION_TESTS.append(\n        [\n            {\"mask_key\": \"mask\", \"border_key\": \"border\"},\n            p(np.random.rand(2, 5, 5)),\n            p(np.random.rand(1, 5, 5)),\n            ValueError,\n        ]\n    )\n    EXCEPTION_TESTS.append(\n        [\n            {\"mask_key\": \"mask\", \"border_key\": \"border\"},\n            p(np.random.rand(1, 5, 5)),\n            p(np.random.rand(2, 5, 5)),\n            ValueError,\n        ]\n    )\n    EXCEPTION_TESTS.append(\n        [\n            {\"mask_key\": \"mask\", \"border_key\": \"border\", \"markers_key\": \"old_markers\"},\n            p(np.random.rand(1, 5, 5)),\n            p(np.random.rand(1, 5, 5)),\n            KeyError,\n        ]\n    )\n\n    TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)])\n    TESTS.append(\n        [\n     
       {\"threshold\": 0.4, \"radius\": 1, \"min_object_size\": 0},\n            p(np.random.rand(1, 5, 5)),\n            p(np.random.rand(1, 5, 5)),\n            (1, 5, 5),\n        ]\n    )\n\n\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\n@unittest.skipUnless(has_scipy, \"Requires scipy library.\")\nclass TestGenerateWatershedMarkersd(unittest.TestCase):\n    @parameterized.expand(EXCEPTION_TESTS)\n    def test_value(self, arguments, mask, border_map, exception_type):\n        with self.assertRaises(exception_type):\n            GenerateWatershedMarkersd(**arguments)({\"mask\": mask, \"border\": border_map, \"old_markers\": 1})\n\n    @parameterized.expand(TESTS)\n    def test_value2(self, arguments, mask, border_map, expected_shape):\n        result = GenerateWatershedMarkersd(**arguments)({\"mask\": mask, \"border\": border_map})\n        self.assertEqual(result[\"markers\"].shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_watershed_mask.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.array import GenerateWatershedMask\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS\n\n_, has_scipy = optional_import(\"scipy\", \"1.8.1\", min_version)\n\nEXCEPTION_TESTS = []\nTESTS = []\n\nnp.random.RandomState(123)\n\nfor p in TEST_NDARRAYS:\n    EXCEPTION_TESTS.append([{\"activation\": \"incorrect\"}, ValueError])\n    EXCEPTION_TESTS.append([{\"activation\": 1}, ValueError])\n\n    TESTS.append(\n        [\n            {\"activation\": \"softmax\", \"min_object_size\": 0},\n            p(\n                [\n                    [[0.5022, 0.3403, 0.9997], [0.8793, 0.5514, 0.2697], [0.6134, 0.6389, 0.0680]],\n                    [[0.5000, 0.3400, 0.9900], [0.8900, 0.5600, 0.2700], [0.6100, 0.6300, 0.0600]],\n                ]\n            ),\n            (1, 3, 3),\n            [0, 1],\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"activation\": \"sigmoid\", \"threshold\": 0.5, \"min_object_size\": 0},\n            p([[[0.5022, 0.3403, 0.9997], [0.8793, 0.5514, 0.2697], [-0.1134, -0.0389, -0.0680]]]),\n            (1, 3, 3),\n            [0, 1],\n        ]\n    )\n\n\n@unittest.skipUnless(has_scipy, \"Requires scipy 
library.\")\nclass TestGenerateWatershedMask(unittest.TestCase):\n    @parameterized.expand(EXCEPTION_TESTS)\n    def test_value(self, arguments, exception_type):\n        with self.assertRaises(exception_type):\n            GenerateWatershedMask(**arguments)\n\n    @parameterized.expand(TESTS)\n    def test_value2(self, arguments, image, expected_shape, expected_value):\n        result = GenerateWatershedMask(**arguments)(image)\n        self.assertEqual(result.shape, expected_shape)\n\n        if isinstance(result, torch.Tensor):\n            result = result.cpu().numpy()\n        self.assertEqual(np.unique(result).tolist(), expected_value)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_generate_watershed_maskd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.dictionary import GenerateWatershedMaskd\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS\n\n_, has_scipy = optional_import(\"scipy\", \"1.8.1\", min_version)\n\nEXCEPTION_TESTS = []\nTESTS = []\n\nnp.random.RandomState(123)\n\nfor p in TEST_NDARRAYS:\n    EXCEPTION_TESTS.append([{\"keys\": \"img\", \"activation\": \"incorrect\"}, ValueError])\n    EXCEPTION_TESTS.append([{\"keys\": \"img\", \"activation\": 1}, ValueError])\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"mask_key\": \"mask\", \"activation\": \"softmax\", \"min_object_size\": 0},\n            p(\n                [\n                    [[0.5022, 0.3403, 0.9997], [0.8793, 0.5514, 0.2697], [0.6134, 0.6389, 0.0680]],\n                    [[0.5000, 0.3400, 0.9900], [0.8900, 0.5600, 0.2700], [0.6100, 0.6300, 0.0600]],\n                ]\n            ),\n            (1, 3, 3),\n            [0, 1],\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"mask_key\": \"mask\", \"activation\": \"sigmoid\", \"threshold\": 0.5, \"min_object_size\": 0},\n            p([[[0.5022, 0.3403, 0.9997], [0.8793, 0.5514, 0.2697], [-0.1134, -0.0389, 
-0.0680]]]),\n            (1, 3, 3),\n            [0, 1],\n        ]\n    )\n\n\n@unittest.skipUnless(has_scipy, \"Requires scipy library.\")\nclass TestGenerateWatershedMaskd(unittest.TestCase):\n    @parameterized.expand(EXCEPTION_TESTS)\n    def test_value(self, arguments, exception_type):\n        with self.assertRaises(exception_type):\n            GenerateWatershedMaskd(**arguments)\n\n    @parameterized.expand(TESTS)\n    def test_value2(self, arguments, image, expected_shape, expected_value):\n        result = GenerateWatershedMaskd(**arguments)({\"img\": image})\n        self.assertEqual(result[\"mask\"].shape, expected_shape)\n\n        if isinstance(result[\"mask\"], torch.Tensor):\n            result[\"mask\"] = result[\"mask\"].cpu().numpy()\n        self.assertEqual(np.unique(result[\"mask\"]).tolist(), expected_value)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processing.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.array import HoVerNetInstanceMapPostProcessing\nfrom monai.transforms import ComputeHoVerMaps, FillHoles, GaussianSmooth\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_scipy = optional_import(\"scipy\", \"1.8.1\", min_version)\n_, has_skimage = optional_import(\"skimage\", \"0.19.3\", min_version)\n\ny, x = np.ogrid[0:30, 0:30]\nimage = (x - 10) ** 2 + (y - 10) ** 2 <= 5**2\nimage = image[None, ...].astype(\"uint8\")\n\nTEST_CASE_1 = [{}, {\"1\": {\"centroid\": 1, \"bbox\": 1.0}}, np.zeros_like(image)]\nTEST_CASE_2 = [{\"distance_smooth_fn\": GaussianSmooth()}, {\"1\": {\"type\": 1, \"type_prob\": 1.0}}, np.zeros_like(image)]\nTEST_CASE_3 = [{\"marker_postprocess_fn\": FillHoles()}, {\"1\": {\"type\": 1, \"type_prob\": 1.0}}, np.zeros_like(image)]\n\nTEST_CASE = []\nfor p in TEST_NDARRAYS:\n    TEST_CASE.append([p, image] + TEST_CASE_1)\n    TEST_CASE.append([p, image] + TEST_CASE_2)\n    TEST_CASE.append([p, image] + TEST_CASE_3)\n\n\n@unittest.skipUnless(has_scipy, \"Requires scipy library.\")\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\nclass TestHoVerNetInstanceMapPostProcessing(unittest.TestCase):\n    
@parameterized.expand(TEST_CASE)\n    def test_value(self, in_type, test_data, kwargs, expected_info, expected_map):\n        nuclear_prediction = in_type(test_data.astype(float))\n        hover_map = in_type(ComputeHoVerMaps()(test_data.astype(int)))\n\n        inst_info, inst_map = HoVerNetInstanceMapPostProcessing(**kwargs)(nuclear_prediction, hover_map)\n\n        # instance info\n        for key in inst_info:\n            assert_allclose(inst_info[key][\"centroid\"], expected_info[key][\"centroid\"], type_test=False)\n\n        # instance map\n        assert_allclose(inst_map, expected_map, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_hovernet_instance_map_post_processingd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.dictionary import HoVerNetInstanceMapPostProcessingd\nfrom monai.transforms import ComputeHoVerMaps, FillHoles, GaussianSmooth\nfrom monai.utils import min_version, optional_import\nfrom monai.utils.enums import HoVerNetBranch\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_scipy = optional_import(\"scipy\", \"1.8.1\", min_version)\n_, has_skimage = optional_import(\"skimage\", \"0.19.3\", min_version)\n\ny, x = np.ogrid[0:30, 0:30]\nimage = (x - 10) ** 2 + (y - 10) ** 2 <= 5**2\nimage = image[None, ...].astype(\"uint8\")\n\nTEST_CASE_1 = [{}, {\"1\": {\"centroid\": 1, \"bbox\": 1.0}}, np.zeros_like(image)]\nTEST_CASE_2 = [{\"distance_smooth_fn\": GaussianSmooth()}, {\"1\": {\"type\": 1, \"type_prob\": 1.0}}, np.zeros_like(image)]\nTEST_CASE_3 = [{\"marker_postprocess_fn\": FillHoles()}, {\"1\": {\"type\": 1, \"type_prob\": 1.0}}, np.zeros_like(image)]\n\nTEST_CASE = []\nfor p in TEST_NDARRAYS:\n    TEST_CASE.append([p, image] + TEST_CASE_1)\n    TEST_CASE.append([p, image] + TEST_CASE_2)\n    TEST_CASE.append([p, image] + TEST_CASE_3)\n\n\n@unittest.skipUnless(has_scipy, \"Requires scipy library.\")\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\nclass 
TestHoVerNetInstanceMapPostProcessingd(unittest.TestCase):\n    @parameterized.expand(TEST_CASE)\n    def test_value(self, in_type, test_data, kwargs, expected_info, expected_map):\n        input = {\n            HoVerNetBranch.NP.value: in_type(test_data.astype(float)),\n            HoVerNetBranch.HV.value: in_type(ComputeHoVerMaps()(test_data.astype(int))),\n        }\n\n        outputs = HoVerNetInstanceMapPostProcessingd(**kwargs)(input)\n        inst_info_key = kwargs.get(\"instance_info_key\", \"instance_info\")\n        inst_map_key = kwargs.get(\"instance_map_key\", \"instance_map\")\n\n        # instance info\n        for key in outputs[inst_info_key]:\n            assert_allclose(outputs[inst_info_key][\"centroid\"], expected_info[key][\"centroid\"], type_test=False)\n\n        # instance map\n        assert_allclose(outputs[inst_map_key], expected_map, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_hovernet_nuclear_type_post_processing.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.array import (\n    HoVerNetInstanceMapPostProcessing,\n    HoVerNetNuclearTypePostProcessing,\n)\nfrom monai.transforms import ComputeHoVerMaps\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_scipy = optional_import(\"scipy\", \"1.8.1\", min_version)\n_, has_skimage = optional_import(\"skimage\", \"0.19.3\", min_version)\n\ny, x = np.ogrid[0:30, 0:30]\nimage = (x - 10) ** 2 + (y - 10) ** 2 <= 5**2\nimage = image[None, ...].astype(\"uint8\")\n\nTEST_CASE_1 = [{}, {\"1\": {\"type\": 1, \"type_prob\": 1.0}}, np.zeros_like(image)]\n\nTEST_CASE = []\nfor p in TEST_NDARRAYS:\n    TEST_CASE.append([p, image] + TEST_CASE_1)\n\n\n@unittest.skipUnless(has_scipy, \"Requires scipy library.\")\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\nclass TestHoVerNetNuclearTypePostProcessing(unittest.TestCase):\n    @parameterized.expand(TEST_CASE)\n    def test_value(self, in_type, test_data, kwargs, expected_info, expected_map):\n        nuclear_prediction = in_type(test_data.astype(float))\n        hover_map = in_type(ComputeHoVerMaps()(test_data.astype(int)))\n        nuclear_type = in_type(test_data)\n\n        inst_info, 
inst_map = HoVerNetInstanceMapPostProcessing()(nuclear_prediction, hover_map)\n        inst_info, type_map = HoVerNetNuclearTypePostProcessing(**kwargs)(nuclear_type, inst_info, inst_map)\n\n        # instance prediction info\n        for key in inst_info:\n            self.assertEqual(inst_info[key][\"type\"], expected_info[key][\"type\"])\n            self.assertEqual(inst_info[key][\"type_prob\"], expected_info[key][\"type_prob\"])\n\n        # type map\n        if expected_map is None:\n            self.assertIsNone(type_map)\n        else:\n            assert_allclose(type_map, expected_map, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_watershed.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.array import (\n    GenerateDistanceMap,\n    GenerateInstanceBorder,\n    GenerateWatershedMarkers,\n    GenerateWatershedMask,\n    Watershed,\n)\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS\n\n_, has_skimage = optional_import(\"skimage\", \"0.19.3\", min_version)\n_, has_scipy = optional_import(\"scipy\", \"1.8.1\", min_version)\n\nnp.random.RandomState(123)\n\nTESTS = []\nparams = {\"connectivity\": 1}\nfor p in TEST_NDARRAYS:\n    image = p(np.random.rand(1, 10, 10))\n    hover_map = p(np.random.rand(2, 10, 10))\n\n    TESTS.append([params, image, hover_map, (1, 10, 10)])\n\n\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\n@unittest.skipUnless(has_scipy, \"Requires scipy library.\")\nclass TestWatershed(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_output(self, args, image, hover_map, expected_shape):\n        mask = GenerateWatershedMask()(image)\n        border_map = GenerateInstanceBorder(kernel_size=3)(mask, hover_map)\n        distance_map = GenerateDistanceMap()(mask, border_map)\n        markers = GenerateWatershedMarkers()(mask, border_map)\n\n        calculate_instance_seg = Watershed(**args)\n        
output = calculate_instance_seg(distance_map, mask, markers)\n\n        self.assertTupleEqual(output.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/post/test_watershedd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.dictionary import (\n    GenerateDistanceMapd,\n    GenerateInstanceBorderd,\n    GenerateWatershedMarkersd,\n    GenerateWatershedMaskd,\n    Watershedd,\n)\nfrom monai.transforms import Compose\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS\n\n_, has_skimage = optional_import(\"skimage\", \"0.19.3\", min_version)\n_, has_scipy = optional_import(\"scipy\", \"1.8.1\", min_version)\n\nTESTS = []\nparams = {\"keys\": \"dist_map\", \"mask_key\": \"mask\", \"markers_key\": \"markers\", \"connectivity\": 1}\nfor p in TEST_NDARRAYS:\n    image = p(np.random.rand(1, 10, 10))\n    hover_map = p(np.random.rand(2, 10, 10))\n\n    TESTS.append([params, image, hover_map, (1, 10, 10)])\n\n    params.update({\"markers_key\": None})\n    TESTS.append([params, image, hover_map, (1, 10, 10)])\n\n    params.update({\"mask_key\": None, \"markers_key\": None})\n    TESTS.append([params, image, hover_map, (1, 10, 10)])\n\n\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\n@unittest.skipUnless(has_scipy, \"Requires scipy library.\")\nclass TestWatershedd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_output(self, args, image, hover_map, 
expected_shape):\n        data = {\"output\": image, \"hover_map\": hover_map}\n\n        trans = Compose(\n            [\n                GenerateWatershedMaskd(keys=\"output\"),\n                GenerateInstanceBorderd(mask_key=\"mask\", hover_map_key=\"hover_map\", kernel_size=3),\n                GenerateDistanceMapd(mask_key=\"mask\", border_key=\"border\"),\n                GenerateWatershedMarkersd(mask_key=\"mask\", border_key=\"border\"),\n                Watershedd(**args),\n            ]\n        )\n\n        output = trans(data)\n        self.assertTupleEqual(output[\"dist_map\"].shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/test_pathology_he_stain.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms import ExtractHEStains, NormalizeHEStains\n\n# None inputs\nEXTRACT_STAINS_TEST_CASE_0 = (None,)\nEXTRACT_STAINS_TEST_CASE_00 = (None, None)\nNORMALIZE_STAINS_TEST_CASE_0 = (None,)\nNORMALIZE_STAINS_TEST_CASE_00: tuple = ({}, None, None)\n\n# input pixels with negative values\nNEGATIVE_VALUE_TEST_CASE = [np.full((3, 2, 3), -1)]\n\n# input pixels with greater than 255 values\nINVALID_VALUE_TEST_CASE = [np.full((3, 2, 3), 256)]\n\n# input pixels all transparent and below the beta absorbance threshold\nEXTRACT_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)]\n\n# input pixels uniformly filled, but above beta absorbance threshold\nEXTRACT_STAINS_TEST_CASE_2 = [np.full((3, 2, 3), 100)]\n\n# input pixels uniformly filled (different value), but above beta absorbance threshold\nEXTRACT_STAINS_TEST_CASE_3 = [np.full((3, 2, 3), 150)]\n\n# input pixels uniformly filled with zeros, leading to two identical stains extracted\nEXTRACT_STAINS_TEST_CASE_4 = [\n    np.zeros((3, 2, 3)),\n    np.array([[0.0, 0.0], [0.70710678, 0.70710678], [0.70710678, 0.70710678]]),\n]\n\n# input pixels not uniformly filled, leading to two different stains extracted\nEXTRACT_STAINS_TEST_CASE_5 = [\n    np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 
0, 0]], [[0, 0, 0], [0, 0, 0]]]),\n    np.array([[0.18696113, 0.70710677], [0.0, 0.0], [0.98236734, 0.70710677]]),\n]\n\n# input pixels all transparent and below the beta absorbance threshold\nNORMALIZE_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)]\n\n# input pixels uniformly filled with zeros, and target stain matrix provided\nNORMALIZE_STAINS_TEST_CASE_2 = [{\"target_he\": np.full((3, 2), 1)}, np.zeros((3, 2, 3)), np.full((3, 2, 3), 11)]\n\n# input pixels uniformly filled with zeros, and target stain matrix not provided\nNORMALIZE_STAINS_TEST_CASE_3 = [\n    {},\n    np.zeros((3, 2, 3)),\n    np.array([[[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]]]),\n]\n\n# input pixels not uniformly filled\nNORMALIZE_STAINS_TEST_CASE_4 = [\n    {\"target_he\": np.full((3, 2), 1)},\n    np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),\n    np.array([[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]]]),\n]\n\n\nclass TestExtractHEStains(unittest.TestCase):\n\n    @parameterized.expand(\n        [NEGATIVE_VALUE_TEST_CASE, INVALID_VALUE_TEST_CASE, EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_1]\n    )\n    def test_transparent_image(self, image):\n        \"\"\"\n        Test HE stain extraction on an image that comprises\n        only transparent pixels - pixels with absorbance below the\n        beta absorbance threshold. 
A ValueError should be raised,\n        since once the transparent pixels are removed, there are no\n        remaining pixels to compute eigenvectors.\n        \"\"\"\n        if image is None:\n            with self.assertRaises(TypeError):\n                ExtractHEStains()(image)\n        else:\n            with self.assertRaises(ValueError):\n                ExtractHEStains()(image)\n\n    @parameterized.expand([EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_2, EXTRACT_STAINS_TEST_CASE_3])\n    def test_identical_result_vectors(self, image):\n        \"\"\"\n        Test HE stain extraction on input images that are\n        uniformly filled with pixels that have absorbance above the\n        beta absorbance threshold. Since input image is uniformly filled,\n        the two extracted stains should have the same RGB values. So,\n        we assert that the first column is equal to the second column\n        of the returned stain matrix.\n        \"\"\"\n        if image is None:\n            with self.assertRaises(TypeError):\n                ExtractHEStains()(image)\n        else:\n            result = ExtractHEStains()(image)\n            np.testing.assert_array_equal(result[:, 0], result[:, 1])\n\n    @parameterized.expand([EXTRACT_STAINS_TEST_CASE_00, EXTRACT_STAINS_TEST_CASE_4, EXTRACT_STAINS_TEST_CASE_5])\n    def test_result_value(self, image, expected_data):\n        \"\"\"\n        Test that an input image returns an expected stain matrix.\n\n        For test case 4:\n        - a uniformly filled input image should result in\n          eigenvectors [[1,0,0],[0,1,0],[0,0,1]]\n        - phi should be an array containing only values of\n          arctan(1) since the ratio between the eigenvectors\n          corresponding to the two largest eigenvalues is 1\n        - maximum phi and minimum phi should thus be arctan(1)\n        - thus, maximum vector and minimum vector should be\n          [[0],[0.70710677],[0.70710677]]\n        - the resulting 
extracted stain should be\n          [[0,0],[0.70710678,0.70710678],[0.70710678,0.70710678]]\n\n        For test case 5:\n        - the non-uniformly filled input image should result in\n          eigenvectors [[0,0,1],[1,0,0],[0,1,0]]\n        - maximum phi and minimum phi should thus be 0.785 and\n          0.188 respectively\n        - thus, maximum vector and minimum vector should be\n          [[0.18696113],[0],[0.98236734]] and\n          [[0.70710677],[0],[0.70710677]] respectively\n        - the resulting extracted stain should be\n          [[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]]\n        \"\"\"\n        if image is None:\n            with self.assertRaises(TypeError):\n                ExtractHEStains()(image)\n        else:\n            result = ExtractHEStains()(image)\n            np.testing.assert_allclose(result, expected_data)\n\n\nclass TestNormalizeHEStains(unittest.TestCase):\n\n    @parameterized.expand(\n        [NEGATIVE_VALUE_TEST_CASE, INVALID_VALUE_TEST_CASE, NORMALIZE_STAINS_TEST_CASE_0, NORMALIZE_STAINS_TEST_CASE_1]\n    )\n    def test_transparent_image(self, image):\n        \"\"\"\n        Test HE stain normalization on an image that comprises\n        only transparent pixels - pixels with absorbance below the\n        beta absorbance threshold. 
A ValueError should be raised,\n        since once the transparent pixels are removed, there are no\n        remaining pixels to compute eigenvectors.\n        \"\"\"\n        if image is None:\n            with self.assertRaises(TypeError):\n                NormalizeHEStains()(image)\n        else:\n            with self.assertRaises(ValueError):\n                NormalizeHEStains()(image)\n\n    @parameterized.expand(\n        [\n            NORMALIZE_STAINS_TEST_CASE_00,\n            NORMALIZE_STAINS_TEST_CASE_2,\n            NORMALIZE_STAINS_TEST_CASE_3,\n            NORMALIZE_STAINS_TEST_CASE_4,\n        ]\n    )\n    def test_result_value(self, arguments, image, expected_data):\n        \"\"\"\n        Test that an input image returns an expected normalized image.\n\n        For test case 2:\n        - This case tests calling the stain normalizer, after the\n          _deconvolution_extract_conc function. This is because the normalized\n          concentration returned for each pixel is the same as the reference\n          maximum stain concentrations in the case that the image is uniformly\n          filled, as in this test case. 
This is because the maximum concentration\n          for each stain is the same as each pixel's concentration.\n        - Thus, the normalized concentration matrix should be a (2, 6) matrix\n          with the first row having all values of 1.9705, second row all 1.0308.\n        - Taking the matrix product of the target stain matrix and the concentration\n          matrix, then using the inverse Beer-Lambert transform to obtain the RGB\n          image from the absorbance image, and finally converting to uint8,\n          we get that the stain normalized image should be a matrix of\n          dims (3, 2, 3), with all values 11.\n\n        For test case 3:\n        - This case also tests calling the stain normalizer, after the\n          _deconvolution_extract_conc function returns the image concentration\n          matrix.\n        - As in test case 2, the normalized concentration matrix should be a (2, 6) matrix\n          with the first row having all values of 1.9705, second row all 1.0308.\n        - Taking the matrix product of the target default stain matrix and the concentration\n          matrix, then using the inverse Beer-Lambert transform to obtain the RGB\n          image from the absorbance image, and finally converting to uint8,\n          we get that the stain normalized image should be [[[63, 25, 60], [63, 25, 60]],\n          [[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]]]\n\n        For test case 4:\n        - For this non-uniformly filled image, the stain extracted should be\n          [[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]], as validated for the\n          ExtractHEStains class. 
Solving the linear least squares problem (since\n          absorbance matrix = stain matrix * concentration matrix), we obtain the concentration\n          matrix that should be [[5.8022, 0, 0, 0, 0, 0],\n          [-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508]]\n        - Normalizing the concentration matrix, taking the matrix product of the\n          target stain matrix and the concentration matrix, using the inverse\n          Beer-Lambert transform to obtain the RGB image from the absorbance\n          image, and finally converting to uint8, we get that the stain normalized\n          image should be [[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]],\n          [[85, 85, 85], [85, 85, 85]]]\n        \"\"\"\n        if image is None:\n            with self.assertRaises(TypeError):\n                NormalizeHEStains()(image)\n        else:\n            result = NormalizeHEStains(**arguments)(image)\n            np.testing.assert_allclose(result, expected_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/pathology/transforms/test_pathology_he_stain_dict.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms import ExtractHEStainsD, NormalizeHEStainsD\n\n# None inputs\nEXTRACT_STAINS_TEST_CASE_0 = (None,)\nEXTRACT_STAINS_TEST_CASE_00 = (None, None)\nNORMALIZE_STAINS_TEST_CASE_0 = (None,)\nNORMALIZE_STAINS_TEST_CASE_00: tuple = ({}, None, None)\n\n# input pixels all transparent and below the beta absorbance threshold\nEXTRACT_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)]\n\n# input pixels uniformly filled, but above beta absorbance threshold\nEXTRACT_STAINS_TEST_CASE_2 = [np.full((3, 2, 3), 100)]\n\n# input pixels uniformly filled (different value), but above beta absorbance threshold\nEXTRACT_STAINS_TEST_CASE_3 = [np.full((3, 2, 3), 150)]\n\n# input pixels uniformly filled with zeros, leading to two identical stains extracted\nEXTRACT_STAINS_TEST_CASE_4 = [\n    np.zeros((3, 2, 3)),\n    np.array([[0.0, 0.0], [0.70710678, 0.70710678], [0.70710678, 0.70710678]]),\n]\n\n# input pixels not uniformly filled, leading to two different stains extracted\nEXTRACT_STAINS_TEST_CASE_5 = [\n    np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),\n    np.array([[0.18696113, 0.70710677], [0.0, 0.0], [0.98236734, 0.70710677]]),\n]\n\n# input pixels all transparent and below the beta absorbance 
threshold\nNORMALIZE_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)]\n\n# input pixels uniformly filled with zeros, and target stain matrix provided\nNORMALIZE_STAINS_TEST_CASE_2 = [{\"target_he\": np.full((3, 2), 1)}, np.zeros((3, 2, 3)), np.full((3, 2, 3), 11)]\n\n# input pixels uniformly filled with zeros, and target stain matrix not provided\nNORMALIZE_STAINS_TEST_CASE_3 = [\n    {},\n    np.zeros((3, 2, 3)),\n    np.array([[[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]]]),\n]\n\n# input pixels not uniformly filled\nNORMALIZE_STAINS_TEST_CASE_4 = [\n    {\"target_he\": np.full((3, 2), 1)},\n    np.array([[[100, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]),\n    np.array([[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]]]),\n]\n\n\nclass TestExtractHEStainsD(unittest.TestCase):\n\n    @parameterized.expand([EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_1])\n    def test_transparent_image(self, image):\n        \"\"\"\n        Test HE stain extraction on an image that comprises\n        only transparent pixels - pixels with absorbance below the\n        beta absorbance threshold. A ValueError should be raised,\n        since once the transparent pixels are removed, there are no\n        remaining pixels to compute eigenvectors.\n        \"\"\"\n        key = \"image\"\n        if image is None:\n            with self.assertRaises(TypeError):\n                ExtractHEStainsD([key])({key: image})\n        else:\n            with self.assertRaises(ValueError):\n                ExtractHEStainsD([key])({key: image})\n\n    @parameterized.expand([EXTRACT_STAINS_TEST_CASE_0, EXTRACT_STAINS_TEST_CASE_2, EXTRACT_STAINS_TEST_CASE_3])\n    def test_identical_result_vectors(self, image):\n        \"\"\"\n        Test HE stain extraction on input images that are\n        uniformly filled with pixels that have absorbance above the\n        beta absorbance threshold. 
Since input image is uniformly filled,\n        the two extracted stains should have the same RGB values. So,\n        we assert that the first column is equal to the second column\n        of the returned stain matrix.\n        \"\"\"\n        key = \"image\"\n        if image is None:\n            with self.assertRaises(TypeError):\n                ExtractHEStainsD([key])({key: image})\n        else:\n            result = ExtractHEStainsD([key])({key: image})\n            np.testing.assert_array_equal(result[key][:, 0], result[key][:, 1])\n\n    @parameterized.expand([EXTRACT_STAINS_TEST_CASE_00, EXTRACT_STAINS_TEST_CASE_4, EXTRACT_STAINS_TEST_CASE_5])\n    def test_result_value(self, image, expected_data):\n        \"\"\"\n        Test that an input image returns an expected stain matrix.\n\n        For test case 4:\n        - a uniformly filled input image should result in\n          eigenvectors [[1,0,0],[0,1,0],[0,0,1]]\n        - phi should be an array containing only values of\n          arctan(1) since the ratio between the eigenvectors\n          corresponding to the two largest eigenvalues is 1\n        - maximum phi and minimum phi should thus be arctan(1)\n        - thus, maximum vector and minimum vector should be\n          [[0],[0.70710677],[0.70710677]]\n        - the resulting extracted stain should be\n          [[0,0],[0.70710678,0.70710678],[0.70710678,0.70710678]]\n\n        For test case 5:\n        - the non-uniformly filled input image should result in\n          eigenvectors [[0,0,1],[1,0,0],[0,1,0]]\n        - maximum phi and minimum phi should thus be 0.785 and\n          0.188 respectively\n        - thus, maximum vector and minimum vector should be\n          [[0.18696113],[0],[0.98236734]] and\n          [[0.70710677],[0],[0.70710677]] respectively\n        - the resulting extracted stain should be\n          [[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]]\n        \"\"\"\n        key = \"image\"\n        if image is None:\n   
         with self.assertRaises(TypeError):\n                ExtractHEStainsD([key])({key: image})\n        else:\n            result = ExtractHEStainsD([key])({key: image})\n            np.testing.assert_allclose(result[key], expected_data)\n\n\nclass TestNormalizeHEStainsD(unittest.TestCase):\n\n    @parameterized.expand([NORMALIZE_STAINS_TEST_CASE_0, NORMALIZE_STAINS_TEST_CASE_1])\n    def test_transparent_image(self, image):\n        \"\"\"\n        Test HE stain normalization on an image that comprises\n        only transparent pixels - pixels with absorbance below the\n        beta absorbance threshold. A ValueError should be raised,\n        since once the transparent pixels are removed, there are no\n        remaining pixels to compute eigenvectors.\n        \"\"\"\n        key = \"image\"\n        if image is None:\n            with self.assertRaises(TypeError):\n                NormalizeHEStainsD([key])({key: image})\n        else:\n            with self.assertRaises(ValueError):\n                NormalizeHEStainsD([key])({key: image})\n\n    @parameterized.expand(\n        [\n            NORMALIZE_STAINS_TEST_CASE_00,\n            NORMALIZE_STAINS_TEST_CASE_2,\n            NORMALIZE_STAINS_TEST_CASE_3,\n            NORMALIZE_STAINS_TEST_CASE_4,\n        ]\n    )\n    def test_result_value(self, arguments, image, expected_data):\n        \"\"\"\n        Test that an input image returns an expected normalized image.\n\n        For test case 2:\n        - This case tests calling the stain normalizer, after the\n          _deconvolution_extract_conc function. This is because the normalized\n          concentration returned for each pixel is the same as the reference\n          maximum stain concentrations in the case that the image is uniformly\n          filled, as in this test case. 
This is because the maximum concentration\n          for each stain is the same as each pixel's concentration.\n        - Thus, the normalized concentration matrix should be a (2, 6) matrix\n          with the first row having all values of 1.9705, second row all 1.0308.\n        - Taking the matrix product of the target stain matrix and the concentration\n          matrix, then using the inverse Beer-Lambert transform to obtain the RGB\n          image from the absorbance image, and finally converting to uint8,\n          we get that the stain normalized image should be a matrix of\n          dims (3, 2, 3), with all values 11.\n\n        For test case 3:\n        - This case also tests calling the stain normalizer, after the\n          _deconvolution_extract_conc function returns the image concentration\n          matrix.\n        - As in test case 2, the normalized concentration matrix should be a (2, 6) matrix\n          with the first row having all values of 1.9705, second row all 1.0308.\n        - Taking the matrix product of the target default stain matrix and the concentration\n          matrix, then using the inverse Beer-Lambert transform to obtain the RGB\n          image from the absorbance image, and finally converting to uint8,\n          we get that the stain normalized image should be [[[63, 25, 60], [63, 25, 60]],\n          [[63, 25, 60], [63, 25, 60]], [[63, 25, 60], [63, 25, 60]]]\n\n        For test case 4:\n        - For this non-uniformly filled image, the stain extracted should be\n          [[0.18696113,0.70710677],[0,0],[0.98236734,0.70710677]], as validated for the\n          ExtractHEStains class. 
Solving the linear least squares problem (since\n          absorbance matrix = stain matrix * concentration matrix), we obtain the concentration\n          matrix that should be  [[5.8022, 0, 0, 0, 0, 0],\n          [-0.3101, 7.7508, 7.7508, 7.7508, 7.7508, 7.7508]]\n        - Normalizing the concentration matrix, taking the matrix product of the\n          target stain matrix and the concentration matrix, using the inverse\n          Beer-Lambert transform to obtain the RGB image from the absorbance\n          image, and finally converting to uint8, we get that the stain normalized\n          image should be [[[31, 31, 31], [85, 85, 85]], [[85, 85, 85], [85, 85, 85]],\n          [[85, 85, 85], [85, 85, 85]]]\n        \"\"\"\n        key = \"image\"\n        if image is None:\n            with self.assertRaises(TypeError):\n                NormalizeHEStainsD([key])({key: image})\n        else:\n            result = NormalizeHEStainsD([key], **arguments)({key: image})\n            np.testing.assert_allclose(result[key], expected_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/reconstruction/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/reconstruction/nets/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/reconstruction/nets/test_recon_net_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.reconstruction.networks.nets.utils import (\n    complex_normalize,\n    divisible_pad_t,\n    inverse_divisible_pad_t,\n    reshape_batch_channel_to_channel_dim,\n    reshape_channel_complex_to_last_dim,\n    reshape_channel_to_batch_dim,\n    reshape_complex_to_channel_dim,\n    sensitivity_map_expand,\n    sensitivity_map_reduce,\n)\nfrom tests.test_utils import assert_allclose\n\n# no need for checking devices, these functions don't change device format\n# reshape test case\nim_2d, im_3d = torch.ones([3, 4, 50, 70, 2]), torch.ones([3, 4, 50, 70, 80, 2])\nTEST_RESHAPE = [(im_2d,), (im_3d,)]\n\n# normalize test case\nim_2d, im_3d = (torch.randint(0, 3, [3, 4, 50, 70]).float(), torch.randint(0, 3, [3, 4, 50, 70, 80]).float())\nTEST_NORMALIZE = [(im_2d,), (im_3d,)]\n\n# pad test case\nim_2d, im_3d = torch.ones([3, 4, 50, 70]), torch.ones([3, 4, 50, 70, 80])\nTEST_PAD = [(im_2d,), (im_3d,)]\n\n# test case for sensitivity map expansion/reduction\nksp_2d, ksp_3d = torch.ones([3, 4, 50, 70, 2]), torch.ones([3, 4, 50, 70, 80, 2])\nsens_2d, sens_3d = torch.ones([3, 4, 50, 70, 2]), torch.ones([3, 4, 50, 70, 80, 2])\nTEST_SENS = [(ksp_2d, sens_2d), (ksp_3d, sens_3d)]\n\n\nclass TestReconNetUtils(unittest.TestCase):\n    
@parameterized.expand(TEST_RESHAPE)\n    def test_reshape_channel_complex(self, test_data):\n        result = reshape_complex_to_channel_dim(test_data)\n        result = reshape_channel_complex_to_last_dim(result)\n        self.assertEqual(result.shape, test_data.shape)\n\n        result, batch_size = reshape_channel_to_batch_dim(test_data)\n        result = reshape_batch_channel_to_channel_dim(result, batch_size)\n        self.assertEqual(result.shape, test_data.shape)\n\n    @parameterized.expand(TEST_NORMALIZE)\n    def test_complex_normalize(self, test_data):\n        result, mean, std = complex_normalize(test_data)\n        result = result * std + mean\n        self.assertLess((((result - test_data) ** 2).mean() ** 0.5).item(), 1e-5)\n\n    @parameterized.expand(TEST_PAD)\n    def test_pad(self, test_data):\n        result, padding_sizes = divisible_pad_t(test_data, k=16)\n        result = inverse_divisible_pad_t(result, padding_sizes)\n        assert_allclose(result, test_data)\n\n    @parameterized.expand(TEST_SENS)\n    def test_sens_expand_reduce(self, test_data, sens):\n        result = sensitivity_map_reduce(test_data, sens)\n        result = sensitivity_map_expand(result, sens)\n        self.assertEqual(result.shape, test_data.shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/reconstruction/test_complex_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.reconstruction.complex_utils import complex_abs, complex_conj, complex_mul, convert_to_tensor_complex\nfrom monai.utils.type_conversion import convert_data_type\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n# test case for convert_to_tensor_complex\nim_complex = [[1.0 + 1.0j, 1.0 + 1.0j], [1.0 + 1.0j, 1.0 + 1.0j]]\nexpected_shape = convert_data_type((2, 2, 2), torch.Tensor)[0]\nTESTS = [(im_complex, expected_shape)]\nfor p in TEST_NDARRAYS:\n    TESTS.append((p(im_complex), expected_shape))\n\n# test case for complex_abs\nim = [[3.0, 4.0], [3.0, 4.0]]\nres = [5.0, 5.0]\nTESTSC = []\nfor p in TEST_NDARRAYS:\n    TESTSC.append((p(im), p(res)))\n\n# test case for complex_mul\nx = [[1.0, 2.0], [3.0, 4.0]]\ny = [[1.0, 1.0], [1.0, 1.0]]\nres = [[-1.0, 3.0], [-1.0, 7.0]]  # type: ignore\nTESTSM = []\nfor p in TEST_NDARRAYS:\n    TESTSM.append((p(x), p(y), p(res)))\n\n# test case for complex_conj\nim = [[1.0, 2.0], [3.0, 4.0]]\nres = [[1.0, -2.0], [3.0, -4.0]]  # type: ignore\nTESTSJ = []\nfor p in TEST_NDARRAYS:\n    TESTSJ.append((p(im), p(res)))\n\n\nclass TestMRIUtils(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_to_tensor_complex(self, test_data, expected_shape):\n        result = 
convert_to_tensor_complex(test_data)\n        self.assertTrue(isinstance(result, torch.Tensor))\n        self.assertTupleEqual(result.shape, expected_shape)\n\n    @parameterized.expand(TESTSC)\n    def test_complex_abs(self, test_data, res_data):\n        result = complex_abs(test_data)\n        assert_allclose(result, res_data, type_test=False)\n\n    @parameterized.expand(TESTSM)\n    def test_complex_mul(self, test_x, test_y, res_data):\n        result = complex_mul(test_x, test_y)\n        assert_allclose(result, res_data, type_test=False)\n\n    @parameterized.expand(TESTSJ)\n    def test_complex_conj(self, test_data, res_data):\n        result = complex_conj(test_data)\n        assert_allclose(result, res_data, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/reconstruction/test_fastmri_reader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.reconstruction.fastmri_reader import FastMRIReader\nfrom tests.test_utils import SkipIfNoModule, assert_allclose\n\nTEST_CASE1 = [\n    {\n        \"kspace\": np.array([[1.0, 2.0]]),\n        \"filename\": \"test1\",\n        \"reconstruction_rss\": np.array([[1.0, 2.0]]),\n        \"acquisition\": \"FS\",\n        \"max\": 2.0,\n        \"norm\": 2.2,\n        \"patient_id\": 12,\n    },\n    np.array([[1.0, 2.0]]),\n    {\n        \"filename\": \"test1\",\n        \"reconstruction_rss\": np.array([[1.0, 2.0]]),\n        \"acquisition\": \"FS\",\n        \"max\": 2.0,\n        \"norm\": 2.2,\n        \"patient_id\": 12,\n        \"mask\": np.zeros([1, 2]),\n    },\n]\n\nTEST_CASE2 = [\n    {\n        \"kspace\": np.array([[1.0, 2.0], [3.0, 4.0]]),\n        \"filename\": \"test2\",\n        \"reconstruction_rss\": np.array([[1.0, 2.0], [3.0, 4.0]]),\n        \"acquisition\": \"FS\",\n        \"max\": 4.0,\n        \"norm\": 5.5,\n        \"patient_id\": 1234,\n    },\n    np.array([[1.0, 2.0], [3.0, 4.0]]),\n    {\n        \"filename\": \"test2\",\n        \"reconstruction_rss\": np.array([[1.0, 2.0], [3.0, 4.0]]),\n        \"acquisition\": \"FS\",\n        \"max\": 4.0,\n        \"norm\": 5.5,\n        \"patient_id\": 1234,\n        
\"mask\": np.zeros([2, 2]),\n    },\n]\n\n\n@SkipIfNoModule(\"h5py\")\nclass TestMRIUtils(unittest.TestCase):\n    @parameterized.expand([TEST_CASE1, TEST_CASE2])\n    def test_get_data(self, test_data, test_res, test_meta):\n        reader = FastMRIReader()\n        res, meta = reader.get_data(test_data)\n        assert_allclose(res, test_res)\n        for key in test_meta:\n            if isinstance(test_meta[key], np.ndarray):\n                assert_allclose(test_meta[key], meta[key])\n            else:\n                self.assertEqual(test_meta[key], meta[key])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/reconstruction/test_mri_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.apps.reconstruction.mri_utils import root_sum_of_squares\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n# root_sum_of_squares\nim = [[3.0, 4.0], [3.0, 4.0]]\nres = [5.0, 5.0]\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append((p(im), p(res)))\n\n\nclass TestMRIUtils(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_rss(self, test_data, res_data):\n        result = root_sum_of_squares(test_data, spatial_dim=1)\n        assert_allclose(result, res_data, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/reconstruction/transforms/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/reconstruction/transforms/test_kspace_mask.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask\nfrom monai.utils.type_conversion import convert_data_type\n\n# test case for apply_mask\nksp, *_ = convert_data_type(np.ones([50, 50, 2]), torch.Tensor)\nTESTSM = [(ksp,)]\n\n\nclass TestMRIUtils(unittest.TestCase):\n\n    @parameterized.expand(TESTSM)\n    def test_mask(self, test_data):\n        # random mask\n        masker = RandomKspaceMask(center_fractions=[0.08], accelerations=[4.0], spatial_dims=1, is_complex=True)\n        masker.set_random_state(seed=0)\n        result, _ = masker(test_data)\n        mask = masker.mask\n        result = result[..., mask.squeeze() == 0, :].sum()\n        self.assertEqual(result.item(), 0)\n\n        # equispaced mask\n        masker = EquispacedKspaceMask(center_fractions=[0.08], accelerations=[4.0], spatial_dims=1, is_complex=True)\n        masker.set_random_state(seed=0)\n        result, _ = masker(test_data)\n        mask = masker.mask\n        result = result[..., mask.squeeze() == 0, :].sum()\n        self.assertEqual(result.item(), 0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/reconstruction/transforms/test_reference_based_normalize_intensity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.reconstruction.transforms.dictionary import ReferenceBasedNormalizeIntensityd\nfrom monai.utils.type_conversion import convert_to_numpy\nfrom tests.test_utils import TEST_NDARRAYS_NO_META_TENSOR, assert_allclose\n\n# see test_normalize_intensityd for typical tests (like non-zero\n# normalization, device test, etc.)\n# here, we test DetailedNormalizeIntensityd's functionality\n# which focuses on (1) automatic target normalization and (2) mean-std\n# return values\n\nTESTS = []\nfor p in TEST_NDARRAYS_NO_META_TENSOR:\n    TESTS.append(\n        [\n            {\"keys\": [\"kspace_masked_ifft\", \"target\"], \"ref_key\": \"kspace_masked_ifft\", \"channel_wise\": True},\n            {\"kspace_masked_ifft\": p(np.array([[-2.0, 0.0, 2.0]])), \"target\": p(np.array([[1.0, 2.0, 3.0]]))},\n            p(np.array([[-1.225, 0.0, 1.225]])),  # normalized input\n            p(np.array([[0.612, 1.225, 1.837]])),  # normalized target\n            np.array([0.0]),  # mean\n            np.array([1.633]),  # std\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": [\"kspace_masked_ifft\", \"target\"], \"ref_key\": \"kspace_masked_ifft\", \"channel_wise\": False},\n            {\"kspace_masked_ifft\": p(np.array([[-2.0, 0.0, 2.0]])), 
\"target\": p(np.array([[1.0, 2.0, 3.0]]))},\n            p(np.array([[-1.225, 0.0, 1.225]])),  # normalized input\n            p(np.array([[0.612, 1.225, 1.837]])),  # normalized target\n            0.0,  # mean\n            1.633,  # std\n        ]\n    )\n\n\nclass TestDetailedNormalizeIntensityd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_target_mean_std(self, args, data, normalized_data, normalized_target, mean, std):\n        dtype = data[args[\"keys\"][0]].dtype\n        normalizer = ReferenceBasedNormalizeIntensityd(\n            keys=args[\"keys\"], ref_key=args[\"ref_key\"], channel_wise=args[\"channel_wise\"], dtype=dtype\n        )\n        res_data = normalizer(data)\n\n        img = np.round(convert_to_numpy(res_data[args[\"keys\"][0]]), 3)\n        normalized_data = np.round(convert_to_numpy(normalized_data), 3)\n\n        target = np.round(convert_to_numpy(res_data[args[\"keys\"][1]]), 3)\n        normalized_target = np.round(convert_to_numpy(normalized_target), 3)\n\n        assert_allclose(img, normalized_data)\n        assert_allclose(target, normalized_target)\n\n        assert_allclose(np.round(res_data[\"mean\"], 3), mean)\n        assert_allclose(np.round(res_data[\"std\"], 3), std)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/reconstruction/transforms/test_reference_based_spatial_cropd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.reconstruction.transforms.dictionary import ReferenceBasedSpatialCropd\nfrom tests.test_utils import TEST_NDARRAYS\n\n# see test_spatial_cropd for typical tests (like roi_start,\n# roi_slices, etc.)\n# here, we test TargetBasedSpatialCropd's functionality\n# which focuses on automatic input crop based on target image's shape.\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    # 2D\n    TESTS.append(\n        [\n            {\"keys\": [\"kspace_masked_ifft\"], \"ref_key\": \"target\"},\n            {\"kspace_masked_ifft\": p(np.ones([10, 20, 20])), \"target\": p(np.ones([5, 8, 8]))},\n            (8, 8),  # expected shape\n        ]\n    )\n\n    # 3D\n    TESTS.append(\n        [\n            {\"keys\": [\"kspace_masked_ifft\"], \"ref_key\": \"target\"},\n            {\"kspace_masked_ifft\": p(np.ones([10, 20, 20, 16])), \"target\": p(np.ones([5, 8, 8, 6]))},\n            (8, 8, 6),  # expected shape\n        ]\n    )\n\n\nclass TestTargetBasedSpatialCropd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, args, data, expected_shape):\n        cropper = ReferenceBasedSpatialCropd(keys=args[\"keys\"], ref_key=args[\"ref_key\"])\n        res_data = cropper(data)\n        
self.assertTupleEqual(res_data[args[\"keys\"][0]].shape[1:], expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/test_auto3dseg.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom copy import deepcopy\nfrom numbers import Number\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.auto3dseg import DataAnalyzer\nfrom monai.auto3dseg import (\n    Analyzer,\n    FgImageStats,\n    FgImageStatsSumm,\n    FilenameStats,\n    ImageStats,\n    ImageStatsSumm,\n    LabelStats,\n    LabelStatsSumm,\n    Operations,\n    SampleOperations,\n    SegSummarizer,\n    SummaryOperations,\n    datafold_read,\n    verify_report_format,\n)\nfrom monai.bundle import ConfigParser\nfrom monai.data import DataLoader, Dataset, create_test_image_2d, create_test_image_3d\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import no_collation\nfrom monai.transforms import (\n    Compose,\n    EnsureChannelFirstd,\n    EnsureTyped,\n    Lambdad,\n    LoadImaged,\n    Orientationd,\n    SqueezeDimd,\n    ToDeviced,\n)\nfrom monai.utils.enums import DataStatsKeys\nfrom tests.test_utils import skip_if_no_cuda\n\ndevice = \"cpu\"\nn_workers = 2\n\nsim_datalist = {\n    \"testing\": [{\"image\": \"val_001.fake.nii.gz\"}, {\"image\": \"val_002.fake.nii.gz\"}],\n    \"training\": [\n        {\"fold\": 0, \"image\": \"tr_image_001.fake.nii.gz\", \"label\": \"tr_label_001.fake.nii.gz\"},\n        {\"fold\": 0, 
\"image\": \"tr_image_002.fake.nii.gz\", \"label\": \"tr_label_002.fake.nii.gz\"},\n        {\"fold\": 1, \"image\": \"tr_image_001.fake.nii.gz\", \"label\": \"tr_label_001.fake.nii.gz\"},\n        {\"fold\": 1, \"image\": \"tr_image_004.fake.nii.gz\", \"label\": \"tr_label_004.fake.nii.gz\"},\n    ],\n}\n\nSIM_CPU_TEST_CASES = [\n    [{\"sim_dim\": (32, 32, 32), \"label_key\": \"label\"}],\n    [{\"sim_dim\": (32, 32, 32, 2), \"label_key\": \"label\"}],\n    [{\"sim_dim\": (32, 32, 32), \"label_key\": None}],\n    [{\"sim_dim\": (32, 32, 32), \"label_key\": \"None\"}],\n]\n\nSIM_GPU_TEST_CASES = [[{\"sim_dim\": (32, 32, 32), \"label_key\": \"label\"}], [{\"sim_dim\": (32, 32, 32), \"label_key\": None}]]\n\n\ndef create_sim_data(dataroot: str, sim_datalist: dict, sim_dim: tuple, image_only: bool = False, **kwargs) -> None:\n    \"\"\"\n    Create simulated data using create_test_image_3d.\n\n    Args:\n        dataroot: data directory path that hosts the \"nii.gz\" image files.\n        sim_datalist: a list of data to create.\n        sim_dim: the image sizes, for examples: a tuple of (64, 64, 64) for 3d, or (128, 128) for 2d\n    \"\"\"\n    if not os.path.isdir(dataroot):\n        os.makedirs(dataroot)\n\n    # Generate a fake dataset\n    for d in sim_datalist[\"testing\"] + sim_datalist[\"training\"]:\n        if len(sim_dim) == 2:  # 2D image\n            im, seg = create_test_image_2d(sim_dim[0], sim_dim[1], **kwargs)\n        elif len(sim_dim) == 3:  # 3D image\n            im, seg = create_test_image_3d(sim_dim[0], sim_dim[1], sim_dim[2], **kwargs)\n        elif len(sim_dim) == 4:  # multi-modality 3D image\n            im_list = []\n            seg_list = []\n            for _ in range(sim_dim[3]):\n                im_3d, seg_3d = create_test_image_3d(sim_dim[0], sim_dim[1], sim_dim[2], **kwargs)\n                im_list.append(im_3d[..., np.newaxis])\n                seg_list.append(seg_3d[..., np.newaxis])\n            im = np.concatenate(im_list, 
axis=3)\n            seg = np.concatenate(seg_list, axis=3)\n        else:\n            raise ValueError(f\"Invalid argument input. sim_dim has f{len(sim_dim)} values. 2-4 values are expected.\")\n        nib_image = nib.Nifti1Image(im, affine=np.eye(4))\n        image_fpath = os.path.join(dataroot, d[\"image\"])\n        nib.save(nib_image, image_fpath)\n\n        if not image_only and \"label\" in d:\n            nib_image = nib.Nifti1Image(seg, affine=np.eye(4))\n            label_fpath = os.path.join(dataroot, d[\"label\"])\n            nib.save(nib_image, label_fpath)\n\n\nclass TestOperations(Operations):\n    \"\"\"\n    Test example for user operation\n    \"\"\"\n\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def __init__(self) -> None:\n        self.data = {\"max\": np.max, \"mean\": np.mean, \"min\": np.min}\n\n\nclass TestAnalyzer(Analyzer):\n    \"\"\"\n    Test example for a simple Analyzer\n    \"\"\"\n\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def __init__(self, key, report_format, stats_name=\"test\"):\n        self.key = key\n        super().__init__(stats_name, report_format)\n\n    def __call__(self, data):\n        d = dict(data)\n        report = deepcopy(self.get_report_format())\n        report[\"stats\"] = self.ops[\"stats\"].evaluate(d[self.key])\n        d[self.stats_name] = report\n        return d\n\n\nclass TestImageAnalyzer(Analyzer):\n    \"\"\"\n    Test example for a simple Analyzer\n    \"\"\"\n\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def __init__(self, image_key=\"image\", stats_name=\"test_image\"):\n        self.image_key = image_key\n        report_format = {\"test_stats\": None}\n\n        super().__init__(stats_name, report_format)\n        self.update_ops(\"test_stats\", TestOperations())\n\n    def __call__(self, data):\n        d = dict(data)\n        report 
= deepcopy(self.get_report_format())\n        report[\"test_stats\"] = self.ops[\"test_stats\"].evaluate(d[self.image_key])\n        d[self.stats_name] = report\n        return d\n\n\nclass TestDataAnalyzer(unittest.TestCase):\n    def setUp(self):\n        self.test_dir = tempfile.TemporaryDirectory()\n        work_dir = self.test_dir.name\n        self.dataroot_dir = os.path.join(work_dir, \"sim_dataroot\")\n        self.datalist_file = os.path.join(work_dir, \"sim_datalist.json\")\n        self.datastat_file = os.path.join(work_dir, \"datastats.yml\")\n        ConfigParser.export_config_file(sim_datalist, self.datalist_file)\n\n    @parameterized.expand(SIM_CPU_TEST_CASES)\n    def test_data_analyzer_cpu(self, input_params):\n        sim_dim = input_params[\"sim_dim\"]\n        label_key = input_params[\"label_key\"]\n        image_only = not bool(label_key)\n        rmax = max(int(sim_dim[0] / 4), 1)\n        create_sim_data(\n            self.dataroot_dir, sim_datalist, sim_dim, image_only=image_only, rad_max=rmax, rad_min=1, num_seg_classes=1\n        )\n\n        analyser = DataAnalyzer(\n            self.datalist_file, self.dataroot_dir, output_path=self.datastat_file, label_key=label_key, device=device\n        )\n        datastat = analyser.get_all_case_stats()\n\n        assert len(datastat[\"stats_by_cases\"]) == len(sim_datalist[\"training\"])\n\n    def test_data_analyzer_histogram(self):\n        create_sim_data(\n            self.dataroot_dir, sim_datalist, [32] * 3, image_only=True, rad_max=8, rad_min=1, num_seg_classes=1\n        )\n        analyser = DataAnalyzer(\n            self.datalist_file,\n            self.dataroot_dir,\n            output_path=self.datastat_file,\n            label_key=None,\n            device=device,\n            histogram_only=True,\n        )\n        datastat = analyser.get_all_case_stats()\n        assert len(datastat[\"stats_by_cases\"]) == len(sim_datalist[\"training\"])\n\n    
@parameterized.expand(SIM_GPU_TEST_CASES)\n    @skip_if_no_cuda\n    def test_data_analyzer_gpu(self, input_params):\n        sim_dim = input_params[\"sim_dim\"]\n        label_key = input_params[\"label_key\"]\n        image_only = not bool(label_key)\n        rmax = max(int(sim_dim[0] / 4), 1)\n        create_sim_data(\n            self.dataroot_dir, sim_datalist, sim_dim, image_only=image_only, rad_max=rmax, rad_min=1, num_seg_classes=1\n        )\n        analyser = DataAnalyzer(\n            self.datalist_file, self.dataroot_dir, output_path=self.datastat_file, label_key=label_key, device=\"cuda\"\n        )\n        datastat = analyser.get_all_case_stats()\n\n        assert len(datastat[\"stats_by_cases\"]) == len(sim_datalist[\"training\"])\n\n    def test_basic_operation_class(self):\n        op = TestOperations()\n        test_data = np.random.rand(10, 10).astype(np.float64)\n        test_ret_1 = op.evaluate(test_data)\n        test_ret_2 = op.evaluate(test_data, axis=0)\n        assert isinstance(test_ret_1, dict) and isinstance(test_ret_2, dict)\n        assert (\"max\" in test_ret_1) and (\"max\" in test_ret_2)\n        assert (\"mean\" in test_ret_1) and (\"mean\" in test_ret_2)\n        assert (\"min\" in test_ret_1) and (\"min\" in test_ret_2)\n        assert isinstance(test_ret_1[\"max\"], np.float64)\n        assert isinstance(test_ret_2[\"max\"], np.ndarray)\n        assert test_ret_1[\"max\"].ndim == 0\n        assert test_ret_2[\"max\"].ndim == 1\n\n    def test_sample_operations(self):\n        op = SampleOperations()\n        test_data_np = np.random.rand(10, 10).astype(np.float64)\n        test_data_mt = MetaTensor(test_data_np, device=device)\n        test_ret_np = op.evaluate(test_data_np)\n        test_ret_mt = op.evaluate(test_data_mt)\n        assert isinstance(test_ret_np[\"max\"], Number)\n        assert isinstance(test_ret_np[\"percentile\"], list)\n        assert isinstance(test_ret_mt[\"max\"], Number)\n        assert 
isinstance(test_ret_mt[\"percentile\"], list)\n\n        op.update({\"sum\": np.sum})\n        test_ret_np = op.evaluate(test_data_np)\n        assert \"sum\" in test_ret_np\n\n    def test_summary_operations(self):\n        op = SummaryOperations()\n        test_dict = {\"min\": [0, 1, 2, 3], \"max\": [2, 3, 4, 5], \"mean\": [1, 2, 3, 4], \"sum\": [2, 4, 6, 8]}\n        test_ret = op.evaluate(test_dict)\n        assert isinstance(test_ret[\"max\"], Number)\n        assert isinstance(test_ret[\"min\"], Number)\n\n        op.update({\"sum\": np.sum})\n        test_ret = op.evaluate(test_dict)\n        assert \"sum\" in test_ret\n        assert isinstance(test_ret[\"sum\"], Number)\n\n    def test_basic_analyzer_class(self):\n        test_data = {}\n        test_data[\"image_test\"] = np.random.rand(10, 10)\n        report_format = {\"stats\": None}\n        user_analyzer = TestAnalyzer(\"image_test\", report_format)\n        user_analyzer.update_ops(\"stats\", TestOperations())\n        result = user_analyzer(test_data)\n        assert result[\"test\"][\"stats\"][\"max\"] == np.max(test_data[\"image_test\"])\n        assert result[\"test\"][\"stats\"][\"min\"] == np.min(test_data[\"image_test\"])\n        assert result[\"test\"][\"stats\"][\"mean\"] == np.mean(test_data[\"image_test\"])\n\n    def test_transform_analyzer_class(self):\n        transform = Compose([LoadImaged(keys=[\"image\"]), TestImageAnalyzer(image_key=\"image\")])\n        create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)\n        files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)\n        ds = Dataset(data=files)\n        self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0, collate_fn=no_collation)\n        for batch_data in self.dataset:\n            d = transform(batch_data[0])\n            assert \"test_image\" in d\n            assert \"test_stats\" in d[\"test_image\"]\n            assert \"max\" in 
d[\"test_image\"][\"test_stats\"]\n            assert \"min\" in d[\"test_image\"][\"test_stats\"]\n            assert \"mean\" in d[\"test_image\"][\"test_stats\"]\n\n    def test_image_stats_case_analyzer(self):\n        analyzer = ImageStats(image_key=\"image\")\n        transform = Compose(\n            [\n                LoadImaged(keys=[\"image\"]),\n                EnsureChannelFirstd(keys=[\"image\"]),  # this creates label to be (1,H,W,D)\n                ToDeviced(keys=[\"image\"], device=device, non_blocking=True),\n                Orientationd(keys=[\"image\"], axcodes=\"RAS\"),\n                EnsureTyped(keys=[\"image\"], data_type=\"tensor\"),\n                analyzer,\n            ]\n        )\n        create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)\n        files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)\n        ds = Dataset(data=files)\n        self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)\n        for batch_data in self.dataset:\n            d = transform(batch_data[0])\n            report_format = analyzer.get_report_format()\n            assert verify_report_format(d[\"image_stats\"], report_format)\n\n    def test_foreground_image_stats_cases_analyzer(self):\n        analyzer = FgImageStats(image_key=\"image\", label_key=\"label\")\n        transform_list = [\n            LoadImaged(keys=[\"image\", \"label\"]),\n            EnsureChannelFirstd(keys=[\"image\", \"label\"]),  # this creates label to be (1,H,W,D)\n            ToDeviced(keys=[\"image\", \"label\"], device=device, non_blocking=True),\n            Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n            EnsureTyped(keys=[\"image\", \"label\"], data_type=\"tensor\"),\n            Lambdad(keys=[\"label\"], func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),\n            SqueezeDimd(keys=[\"label\"], 
dim=0),\n            analyzer,\n        ]\n        transform = Compose(transform_list)\n        create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)\n        files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)\n        ds = Dataset(data=files)\n        self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)\n        for batch_data in self.dataset:\n            d = transform(batch_data[0])\n            report_format = analyzer.get_report_format()\n            assert verify_report_format(d[\"image_foreground_stats\"], report_format)\n\n    def test_label_stats_case_analyzer(self):\n        analyzer = LabelStats(image_key=\"image\", label_key=\"label\")\n        transform = Compose(\n            [\n                LoadImaged(keys=[\"image\", \"label\"]),\n                EnsureChannelFirstd(keys=[\"image\", \"label\"]),  # this creates label to be (1,H,W,D)\n                ToDeviced(keys=[\"image\", \"label\"], device=device, non_blocking=True),\n                Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n                EnsureTyped(keys=[\"image\", \"label\"], data_type=\"tensor\"),\n                Lambdad(keys=[\"label\"], func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),\n                SqueezeDimd(keys=[\"label\"], dim=0),\n                analyzer,\n            ]\n        )\n        create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)\n        files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)\n        ds = Dataset(data=files)\n        self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)\n        for batch_data in self.dataset:\n            d = transform(batch_data[0])\n            report_format = analyzer.get_report_format()\n            assert 
verify_report_format(d[\"label_stats\"], report_format)\n\n    def test_filename_case_analyzer(self):\n        analyzer_image = FilenameStats(\"image\", DataStatsKeys.BY_CASE_IMAGE_PATH)\n        analyzer_label = FilenameStats(\"label\", DataStatsKeys.BY_CASE_IMAGE_PATH)\n        transform_list = [LoadImaged(keys=[\"image\", \"label\"]), analyzer_image, analyzer_label]\n        transform = Compose(transform_list)\n        create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)\n        files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)\n        ds = Dataset(data=files)\n        self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)\n        for batch_data in self.dataset:\n            d = transform(batch_data[0])\n            assert DataStatsKeys.BY_CASE_IMAGE_PATH in d\n\n    def test_filename_case_analyzer_image_only(self):\n        analyzer_image = FilenameStats(\"image\", DataStatsKeys.BY_CASE_IMAGE_PATH)\n        analyzer_label = FilenameStats(None, DataStatsKeys.BY_CASE_IMAGE_PATH)\n        transform_list = [LoadImaged(keys=[\"image\"]), analyzer_image, analyzer_label]\n        transform = Compose(transform_list)\n        create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)\n        files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)\n        ds = Dataset(data=files)\n        self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)\n        for batch_data in self.dataset:\n            d = transform(batch_data[0])\n            assert DataStatsKeys.BY_CASE_IMAGE_PATH in d\n            assert d[DataStatsKeys.BY_CASE_IMAGE_PATH] == \"None\"\n\n    def test_image_stats_summary_analyzer(self):\n        summary_analyzer = ImageStatsSumm(\"image_stats\")\n\n        transform_list = [\n            LoadImaged(keys=[\"image\"]),\n            
EnsureChannelFirstd(keys=[\"image\"]),  # this creates label to be (1,H,W,D)\n            ToDeviced(keys=[\"image\"], device=device, non_blocking=True),\n            Orientationd(keys=[\"image\"], axcodes=\"RAS\"),\n            EnsureTyped(keys=[\"image\"], data_type=\"tensor\"),\n            ImageStats(image_key=\"image\"),\n        ]\n        transform = Compose(transform_list)\n        create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)\n        files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)\n        ds = Dataset(data=files)\n        self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)\n        stats = []\n        for batch_data in self.dataset:\n            stats.append(transform(batch_data[0]))\n        summary_report = summary_analyzer(stats)\n        report_format = summary_analyzer.get_report_format()\n        assert verify_report_format(summary_report, report_format)\n\n    def test_fg_image_stats_summary_analyzer(self):\n        summary_analyzer = FgImageStatsSumm(\"image_foreground_stats\")\n\n        transform_list = [\n            LoadImaged(keys=[\"image\", \"label\"]),\n            EnsureChannelFirstd(keys=[\"image\", \"label\"]),  # this creates label to be (1,H,W,D)\n            ToDeviced(keys=[\"image\", \"label\"], device=device, non_blocking=True),\n            Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n            EnsureTyped(keys=[\"image\", \"label\"], data_type=\"tensor\"),\n            Lambdad(keys=\"label\", func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),\n            SqueezeDimd(keys=[\"label\"], dim=0),\n            FgImageStats(image_key=\"image\", label_key=\"label\"),\n        ]\n        transform = Compose(transform_list)\n        create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)\n        files, _ = 
datafold_read(sim_datalist, self.dataroot_dir, fold=-1)\n        ds = Dataset(data=files)\n        self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)\n        stats = []\n        for batch_data in self.dataset:\n            stats.append(transform(batch_data[0]))\n        summary_report = summary_analyzer(stats)\n        report_format = summary_analyzer.get_report_format()\n        assert verify_report_format(summary_report, report_format)\n\n    def test_label_stats_summary_analyzer(self):\n        summary_analyzer = LabelStatsSumm(\"label_stats\")\n\n        transform_list = [\n            LoadImaged(keys=[\"image\", \"label\"]),\n            EnsureChannelFirstd(keys=[\"image\", \"label\"]),  # this creates label to be (1,H,W,D)\n            ToDeviced(keys=[\"image\", \"label\"], device=device, non_blocking=True),\n            Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n            EnsureTyped(keys=[\"image\", \"label\"], data_type=\"tensor\"),\n            Lambdad(keys=\"label\", func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),\n            SqueezeDimd(keys=[\"label\"], dim=0),\n            LabelStats(image_key=\"image\", label_key=\"label\"),\n        ]\n        transform = Compose(transform_list)\n        create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)\n        files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)\n        ds = Dataset(data=files)\n        self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)\n        stats = []\n        for batch_data in self.dataset:\n            stats.append(transform(batch_data[0]))\n        summary_report = summary_analyzer(stats)\n        report_format = summary_analyzer.get_report_format()\n        assert verify_report_format(summary_report, report_format)\n\n    def test_seg_summarizer(self):\n      
  summarizer = SegSummarizer(\"image\", \"label\")\n        keys = [\"image\", \"label\"]\n        transform_list = [\n            LoadImaged(keys=keys),\n            EnsureChannelFirstd(keys=keys),  # this creates label to be (1,H,W,D)\n            ToDeviced(keys=keys, device=device, non_blocking=True),\n            Orientationd(keys=keys, axcodes=\"RAS\"),\n            EnsureTyped(keys=keys, data_type=\"tensor\"),\n            Lambdad(keys=\"label\", func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),\n            SqueezeDimd(keys=[\"label\"], dim=0),\n            summarizer,\n        ]\n        transform = Compose(transform_list)\n        create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)\n        files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)\n        ds = Dataset(data=files)\n        self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)\n        stats = []\n        for batch_data in self.dataset:\n            d = transform(batch_data[0])\n            stats.append(d)\n        report = summarizer.summarize(stats)\n        assert str(DataStatsKeys.IMAGE_STATS) in report\n        assert str(DataStatsKeys.FG_IMAGE_STATS) in report\n        assert str(DataStatsKeys.LABEL_STATS) in report\n\n    def tearDown(self) -> None:\n        self.test_dir.cleanup()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/test_auto3dseg_bundlegen.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport sys\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch\n\nfrom monai.apps.auto3dseg import BundleGen, DataAnalyzer\nfrom monai.apps.auto3dseg.utils import export_bundle_algo_history, import_bundle_algo_history\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.data import create_test_image_3d\nfrom monai.utils import set_determinism\nfrom tests.test_utils import get_testing_algo_template_path, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick\n\nnum_images_perfold = max(torch.cuda.device_count(), 4)\nnum_images_per_batch = 2\n\nsim_datalist: dict[str, list[dict]] = {\n    \"testing\": [{\"image\": \"val_001.fake.nii.gz\"}, {\"image\": \"val_002.fake.nii.gz\"}],\n    \"training\": [\n        {\n            \"fold\": f,\n            \"image\": f\"tr_image_{(f * num_images_perfold + idx):03d}.nii.gz\",\n            \"label\": f\"tr_label_{(f * num_images_perfold + idx):03d}.nii.gz\",\n        }\n        for f in range(num_images_per_batch + 1)\n        for idx in range(num_images_perfold)\n    ],\n}\n\n\ndef create_sim_data(dataroot, sim_datalist, sim_dim, **kwargs):\n    \"\"\"\n    Create simulated data using create_test_image_3d.\n\n    Args:\n        dataroot: data directory path that hosts the \"nii.gz\" image files.\n        
sim_datalist: a list of data to create.\n        sim_dim: the image sizes, e.g. a tuple of (64, 64, 64).\n    \"\"\"\n    if not os.path.isdir(dataroot):\n        os.makedirs(dataroot)\n\n    # Generate a fake dataset\n    for d in sim_datalist[\"testing\"] + sim_datalist[\"training\"]:\n        im, seg = create_test_image_3d(sim_dim[0], sim_dim[1], sim_dim[2], **kwargs)\n        nib_image = nib.Nifti1Image(im, affine=np.eye(4))\n        image_fpath = os.path.join(dataroot, d[\"image\"])\n        nib.save(nib_image, image_fpath)\n\n        if \"label\" in d:\n            nib_image = nib.Nifti1Image(seg, affine=np.eye(4))\n            label_fpath = os.path.join(dataroot, d[\"label\"])\n            nib.save(nib_image, label_fpath)\n\n\ndef run_auto3dseg_before_bundlegen(test_path, work_dir):\n    \"\"\"\n    Run the Auto3DSeg modules before the BundleGen step.\n    Args:\n        test_path: a path to contain `sim_dataroot` which is for the simulated dataset file.\n        work_dir: working directory\n\n    Returns:\n        Paths of the outputs from the previous steps\n    \"\"\"\n\n    if not os.path.isdir(work_dir):\n        os.makedirs(work_dir)\n\n    # write to a json file\n    dataroot_dir = os.path.join(test_path, \"sim_dataroot\")\n    datalist_file = os.path.join(work_dir, \"sim_datalist.json\")\n    ConfigParser.export_config_file(sim_datalist, datalist_file)\n    create_sim_data(dataroot_dir, sim_datalist, (24, 24, 24), rad_max=10, rad_min=1, num_seg_classes=1)\n\n    datastats_file = os.path.join(work_dir, \"datastats.yaml\")\n    analyser = DataAnalyzer(datalist_file, dataroot_dir, output_path=os.path.join(work_dir, \"datastats.yaml\"))\n    analyser.get_all_case_stats()\n\n    return dataroot_dir, datalist_file, datastats_file\n\n\n@skip_if_no_cuda\n@skip_if_quick\nclass TestBundleGen(unittest.TestCase):\n    def setUp(self) -> None:\n        set_determinism(0)\n        self.test_dir = tempfile.TemporaryDirectory()\n\n    def 
test_move_bundle_gen_folder(self) -> None:\n        test_path = self.test_dir.name\n        work_dir = os.path.join(test_path, \"workdir\")\n        dataroot_dir, datalist_file, datastats_file = run_auto3dseg_before_bundlegen(test_path, work_dir)\n        data_src = {\n            \"name\": \"fake_data\",\n            \"task\": \"segmentation\",\n            \"modality\": \"MRI\",\n            \"datalist\": datalist_file,\n            \"dataroot\": dataroot_dir,\n            \"multigpu\": False,\n            \"class_names\": [\"label_class\"],\n        }\n        data_src_cfg = os.path.join(work_dir, \"data_src_cfg.yaml\")\n        ConfigParser.export_config_file(data_src, data_src_cfg)\n\n        sys_path = sys.path.copy()\n        with skip_if_downloading_fails():\n            bundle_generator = BundleGen(\n                algo_path=work_dir,\n                data_stats_filename=datastats_file,\n                data_src_cfg_name=data_src_cfg,\n                templates_path_or_url=get_testing_algo_template_path(),\n            )\n\n        bundle_generator.generate(work_dir, num_fold=1)\n        history_before = bundle_generator.get_history()\n        export_bundle_algo_history(history_before)\n\n        sys.path = sys_path  # prevent the import_bundle_algo_history from using the path \"work_dir/algorithm_templates\"\n        tempfile.TemporaryDirectory()\n        work_dir_new = os.path.join(test_path, \"workdir_2\")\n        shutil.move(work_dir, work_dir_new)\n        history_after = import_bundle_algo_history(work_dir_new, only_trained=False)\n        self.assertEqual(len(history_before), len(history_after))\n\n    def tearDown(self) -> None:\n        set_determinism(None)\n        self.test_dir.cleanup()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/test_check_hash.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps import check_hash\n\nTEST_CASE_1 = [\"b94716452086a054208395e8c9d1ae2a\", \"md5\", True]\n\nTEST_CASE_2 = [\"abcdefg\", \"md5\", False]\n\nTEST_CASE_3 = [None, \"md5\", True]\n\nTEST_CASE_4 = [None, \"sha1\", True]\n\nTEST_CASE_5 = [\"b4dc3c246b298eae37cefdfdd2a50b091ffd5e69\", \"sha1\", True]\n\n\nclass TestCheckMD5(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])\n    def test_result(self, md5_value, t, expected_result):\n        test_image = np.ones((5, 5, 3))\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_file.png\")\n            test_image.tofile(filename)\n\n            result = check_hash(filename, md5_value, hash_type=t)\n            self.assertTrue(result == expected_result)\n\n    def test_hash_type_error(self):\n        with self.assertRaises(ValueError):\n            with tempfile.TemporaryDirectory() as tempdir:\n                check_hash(tempdir, \"test_hash\", \"test_type\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/test_cross_validation.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom pathlib import Path\n\nfrom monai.apps import CrossValidation, DecathlonDataset\nfrom monai.data import MetaTensor\nfrom monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd\nfrom tests.test_utils import skip_if_downloading_fails, skip_if_quick\n\n\nclass TestCrossValidation(unittest.TestCase):\n    @skip_if_quick\n    def test_values(self):\n        testing_dir = Path(__file__).parents[1] / \"testing_data\"\n        train_transform = Compose(\n            [\n                LoadImaged(keys=[\"image\", \"label\"]),\n                EnsureChannelFirstd(keys=[\"image\", \"label\"], channel_dim=\"no_channel\"),\n                ScaleIntensityd(keys=\"image\"),\n            ]\n        )\n        val_transform = LoadImaged(keys=[\"image\", \"label\"])\n\n        def _test_dataset(dataset):\n            self.assertEqual(len(dataset), 52)\n            self.assertTrue(\"image\" in dataset[0])\n            self.assertTrue(\"label\" in dataset[0])\n            self.assertTrue(isinstance(dataset[0][\"image\"], MetaTensor))\n            self.assertTupleEqual(dataset[0][\"image\"].shape, (1, 34, 49, 41))\n\n        cvdataset = CrossValidation(\n            dataset_cls=DecathlonDataset,\n            nfolds=5,\n            seed=12345,\n            root_dir=testing_dir,\n            
task=\"Task04_Hippocampus\",\n            section=\"validation\",\n            transform=train_transform,\n            download=True,\n        )\n\n        with skip_if_downloading_fails():\n            data = cvdataset.get_dataset(folds=0)\n\n        _test_dataset(data)\n\n        # test training data for fold [1, 2, 3, 4] of 5 splits\n        data = cvdataset.get_dataset(folds=[1, 2, 3, 4])\n        self.assertTupleEqual(data[0][\"image\"].shape, (1, 35, 52, 33))\n        self.assertEqual(len(data), 208)\n        # test train / validation for fold 4 of 5 splits\n        data = cvdataset.get_dataset(folds=[4], transform=val_transform, download=False)\n        # val_transform doesn't add the channel dim to shape\n        self.assertTupleEqual(data[0][\"image\"].shape, (38, 53, 30))\n        self.assertEqual(len(data), 52)\n        data = cvdataset.get_dataset(folds=[0, 1, 2, 3])\n        self.assertTupleEqual(data[0][\"image\"].shape, (1, 34, 49, 41))\n        self.assertEqual(len(data), 208)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/test_decathlondataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport unittest\nfrom pathlib import Path\n\nfrom monai.apps import DecathlonDataset\nfrom monai.data import MetaTensor\nfrom monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd\nfrom tests.test_utils import skip_if_downloading_fails, skip_if_quick\n\n\nclass TestDecathlonDataset(unittest.TestCase):\n    @skip_if_quick\n    def test_values(self):\n        testing_dir = Path(__file__).resolve().parents[1] / \"testing_data\"\n        transform = Compose(\n            [\n                LoadImaged(keys=[\"image\", \"label\"]),\n                EnsureChannelFirstd(keys=[\"image\", \"label\"], channel_dim=\"no_channel\"),\n                ScaleIntensityd(keys=\"image\"),\n            ]\n        )\n\n        def _test_dataset(dataset):\n            self.assertEqual(len(dataset), 52)\n            self.assertTrue(\"image\" in dataset[0])\n            self.assertTrue(\"label\" in dataset[0])\n            self.assertTrue(isinstance(dataset[0][\"image\"], MetaTensor))\n            self.assertTupleEqual(dataset[0][\"image\"].shape, (1, 36, 47, 44))\n\n        with skip_if_downloading_fails():\n            data = DecathlonDataset(\n                root_dir=testing_dir,\n                task=\"Task04_Hippocampus\",\n                transform=transform,\n                
section=\"validation\",\n                download=True,\n                copy_cache=False,\n            )\n\n        _test_dataset(data)\n        data = DecathlonDataset(\n            root_dir=testing_dir,\n            task=\"Task04_Hippocampus\",\n            transform=transform,\n            section=\"validation\",\n            download=False,\n            runtime_cache=True,\n        )\n        _test_dataset(data)\n        self.assertTrue(data[0][\"image\"].meta[\"filename_or_obj\"].endswith(\"hippocampus_163.nii.gz\"))\n        self.assertTrue(data[0][\"label\"].meta[\"filename_or_obj\"].endswith(\"hippocampus_163.nii.gz\"))\n        # test validation without transforms\n        data = DecathlonDataset(root_dir=testing_dir, task=\"Task04_Hippocampus\", section=\"validation\", download=False)\n        self.assertTupleEqual(data[0][\"image\"].shape, (36, 47, 44))\n        self.assertEqual(len(data), 52)\n        data = DecathlonDataset(root_dir=testing_dir, task=\"Task04_Hippocampus\", section=\"training\", download=False)\n        self.assertTupleEqual(data[0][\"image\"].shape, (34, 56, 31))\n        self.assertEqual(len(data), 208)\n\n        # test dataset properties\n        data = DecathlonDataset(\n            root_dir=Path(testing_dir), task=\"Task04_Hippocampus\", section=\"validation\", download=False\n        )\n        properties = data.get_properties(keys=\"labels\")\n        self.assertDictEqual(properties[\"labels\"], {\"0\": \"background\", \"1\": \"Anterior\", \"2\": \"Posterior\"})\n\n        shutil.rmtree(os.path.join(testing_dir, \"Task04_Hippocampus\"))\n        with self.assertRaisesRegex(RuntimeError, \"^Cannot find dataset directory\"):\n            DecathlonDataset(\n                root_dir=testing_dir,\n                task=\"Task04_Hippocampus\",\n                transform=transform,\n                section=\"validation\",\n                download=False,\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/test_download_and_extract.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport tarfile\nimport tempfile\nimport unittest\nimport zipfile\nfrom pathlib import Path\nfrom urllib.error import ContentTooShortError, HTTPError\n\nfrom parameterized import parameterized\n\nfrom monai.apps import download_and_extract, download_url, extractall\nfrom tests.test_utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick, testing_data_config\n\n\n@SkipIfNoModule(\"requests\")\nclass TestDownloadAndExtract(unittest.TestCase):\n    @skip_if_quick\n    def test_actions(self):\n        testing_dir = Path(__file__).parents[1] / \"testing_data\"\n        config_dict = testing_data_config(\"images\", \"mednist\")\n        url = config_dict[\"url\"]\n        filepath = Path(testing_dir) / \"MedNIST.tar.gz\"\n        output_dir = Path(testing_dir)\n        hash_val, hash_type = config_dict[\"hash_val\"], config_dict[\"hash_type\"]\n        with skip_if_downloading_fails():\n            download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type)\n            download_and_extract(url, filepath, output_dir, hash_val=hash_val, hash_type=hash_type)\n\n        wrong_md5 = \"0\"\n        with self.assertLogs(logger=\"monai.apps\", level=\"ERROR\"):\n            try:\n                download_url(url, filepath, wrong_md5)\n            except (ContentTooShortError, HTTPError, RuntimeError) as 
e:\n                if isinstance(e, RuntimeError):\n                    # FIXME: skip MD5 check as current downloading method may fail\n                    self.assertTrue(str(e).startswith(\"md5 check\"))\n                return  # skipping this test due the network connection errors\n\n        try:\n            extractall(filepath, output_dir, wrong_md5)\n        except RuntimeError as e:\n            self.assertTrue(str(e).startswith(\"md5 check\"))\n\n    @skip_if_quick\n    @parameterized.expand(((\"icon\", \"tar\"), (\"favicon\", \"zip\")))\n    def test_default(self, key, file_type):\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            with skip_if_downloading_fails():\n                img_spec = testing_data_config(\"images\", key)\n                download_and_extract(\n                    img_spec[\"url\"],\n                    output_dir=tmp_dir,\n                    hash_val=img_spec[\"hash_val\"],\n                    hash_type=img_spec[\"hash_type\"],\n                    file_type=file_type,\n                )\n\n\nclass TestPathTraversalProtection(unittest.TestCase):\n    \"\"\"Test cases for path traversal attack protection in extractall function.\"\"\"\n\n    def test_valid_zip_extraction(self):\n        \"\"\"Test that valid zip files extract successfully without raising exceptions.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            # Create a valid zip file\n            zip_path = Path(tmp_dir) / \"valid_test.zip\"\n            extract_dir = Path(tmp_dir) / \"extract\"\n            extract_dir.mkdir()\n\n            # Create zip with normal file structure\n            with zipfile.ZipFile(zip_path, \"w\") as zf:\n                zf.writestr(\"normal_file.txt\", \"This is a normal file\")\n                zf.writestr(\"subdir/nested_file.txt\", \"This is a nested file\")\n                zf.writestr(\"another_file.json\", '{\"key\": \"value\"}')\n\n            # This should not raise any exception\n        
    try:\n                extractall(str(zip_path), str(extract_dir))\n\n                # Verify files were extracted correctly\n                self.assertTrue((extract_dir / \"normal_file.txt\").exists())\n                self.assertTrue((extract_dir / \"subdir\" / \"nested_file.txt\").exists())\n                self.assertTrue((extract_dir / \"another_file.json\").exists())\n\n                # Verify content\n                with open(extract_dir / \"normal_file.txt\") as f:\n                    self.assertEqual(f.read(), \"This is a normal file\")\n\n            except Exception as e:\n                self.fail(f\"Valid zip extraction should not raise exception: {e}\")\n\n    def test_malicious_zip_path_traversal(self):\n        \"\"\"Test that malicious zip files with path traversal attempts raise ValueError.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            # Create malicious zip file with path traversal\n            zip_path = Path(tmp_dir) / \"malicious_test.zip\"\n            extract_dir = Path(tmp_dir) / \"extract\"\n            extract_dir.mkdir()\n\n            # Create zip with malicious paths\n            with zipfile.ZipFile(zip_path, \"w\") as zf:\n                # Try to write outside extraction directory\n                zf.writestr(\"../../../etc/malicious.txt\", \"malicious content\")\n                zf.writestr(\"normal_file.txt\", \"normal content\")\n\n            # This should raise ValueError due to path traversal detection\n            with self.assertRaises(ValueError) as context:\n                extractall(str(zip_path), str(extract_dir))\n\n            self.assertIn(\"unsafe path\", str(context.exception).lower())\n\n    def test_valid_tar_extraction(self):\n        \"\"\"Test that valid tar files extract successfully without raising exceptions.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            # Create a valid tar file\n            tar_path = Path(tmp_dir) / \"valid_test.tar.gz\"\n 
           extract_dir = Path(tmp_dir) / \"extract\"\n            extract_dir.mkdir()\n\n            # Create tar with normal file structure\n            with tarfile.open(tar_path, \"w:gz\") as tf:\n                # Create temporary files to add to tar\n                temp_file1 = Path(tmp_dir) / \"temp1.txt\"\n                temp_file1.write_text(\"This is a normal file\")\n                tf.add(temp_file1, arcname=\"normal_file.txt\")\n\n                temp_file2 = Path(tmp_dir) / \"temp2.txt\"\n                temp_file2.write_text(\"This is a nested file\")\n                tf.add(temp_file2, arcname=\"subdir/nested_file.txt\")\n\n            # This should not raise any exception\n            try:\n                extractall(str(tar_path), str(extract_dir))\n\n                # Verify files were extracted correctly\n                self.assertTrue((extract_dir / \"normal_file.txt\").exists())\n                self.assertTrue((extract_dir / \"subdir\" / \"nested_file.txt\").exists())\n\n                # Verify content\n                with open(extract_dir / \"normal_file.txt\") as f:\n                    self.assertEqual(f.read(), \"This is a normal file\")\n\n            except Exception as e:\n                self.fail(f\"Valid tar extraction should not raise exception: {e}\")\n\n    def test_malicious_tar_path_traversal(self):\n        \"\"\"Test that malicious tar files with path traversal attempts raise ValueError.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            # Create malicious tar file with path traversal\n            tar_path = Path(tmp_dir) / \"malicious_test.tar.gz\"\n            extract_dir = Path(tmp_dir) / \"extract\"\n            extract_dir.mkdir()\n\n            # Create tar with malicious paths\n            with tarfile.open(tar_path, \"w:gz\") as tf:\n                # Create a temporary file\n                temp_file = Path(tmp_dir) / \"temp.txt\"\n                temp_file.write_text(\"malicious 
content\")\n\n                # Add with malicious path (trying to write outside extraction directory)\n                tf.add(temp_file, arcname=\"../../../etc/malicious.txt\")\n\n            # This should raise ValueError due to path traversal detection\n            with self.assertRaises(ValueError) as context:\n                extractall(str(tar_path), str(extract_dir))\n\n            self.assertIn(\"unsafe path\", str(context.exception).lower())\n\n    def test_absolute_path_protection(self):\n        \"\"\"Test protection against absolute paths in archives.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            # Create zip with absolute path\n            zip_path = Path(tmp_dir) / \"absolute_path_test.zip\"\n            extract_dir = Path(tmp_dir) / \"extract\"\n            extract_dir.mkdir()\n\n            with zipfile.ZipFile(zip_path, \"w\") as zf:\n                # Try to use absolute path\n                zf.writestr(\"/etc/passwd_bad\", \"malicious content\")\n\n            # This should raise ValueError due to absolute path detection\n            with self.assertRaises(ValueError) as context:\n                extractall(str(zip_path), str(extract_dir))\n\n            self.assertIn(\"unsafe path\", str(context.exception).lower())\n\n    def test_malicious_symlink_protection(self):\n        \"\"\"Test protection against malicious symlinks in tar archives.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            # Create malicious tar file with symlink\n            tar_path = Path(tmp_dir) / \"malicious_symlink_test.tar.gz\"\n            extract_dir = Path(tmp_dir) / \"extract\"\n            extract_dir.mkdir()\n\n            # Create tar with malicious symlink\n            with tarfile.open(tar_path, \"w:gz\") as tf:\n                temp_file = Path(tmp_dir) / \"normal.txt\"\n                temp_file.write_text(\"normal content\")\n                tf.add(temp_file, arcname=\"normal.txt\")\n\n                
symlink_info = tarfile.TarInfo(name=\"malicious_symlink.txt\")\n                symlink_info.type = tarfile.SYMTYPE\n                symlink_info.linkname = \"../../../etc/passwd_bad\"\n                symlink_info.size = 0\n                tf.addfile(symlink_info)\n\n            with self.assertRaises(ValueError) as context:\n                extractall(str(tar_path), str(extract_dir))\n\n            error_msg = str(context.exception).lower()\n            self.assertTrue(\"unsafe path\" in error_msg or \"symlink\" in error_msg)\n\n    def test_malicious_hardlink_protection(self):\n        \"\"\"Test protection against malicious hard links in tar archives.\"\"\"\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            # Create malicious tar file with hard link\n            tar_path = Path(tmp_dir) / \"malicious_hardlink_test.tar.gz\"\n            extract_dir = Path(tmp_dir) / \"extract\"\n            extract_dir.mkdir()\n\n            # Create tar with malicious hard link\n            with tarfile.open(tar_path, \"w:gz\") as tf:\n                temp_file = Path(tmp_dir) / \"normal.txt\"\n                temp_file.write_text(\"normal content\")\n                tf.add(temp_file, arcname=\"normal.txt\")\n\n                hardlink_info = tarfile.TarInfo(name=\"malicious_hardlink.txt\")\n                hardlink_info.type = tarfile.LNKTYPE\n                hardlink_info.linkname = \"/etc/passwd_bad\"\n                hardlink_info.size = 0\n                tf.addfile(hardlink_info)\n\n            with self.assertRaises(ValueError) as context:\n                extractall(str(tar_path), str(extract_dir))\n\n            error_msg = str(context.exception).lower()\n            self.assertTrue(\"unsafe path\" in error_msg or \"hardlink\" in error_msg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/test_download_url_yandex.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom urllib.error import HTTPError\n\nfrom monai.apps.utils import download_url\n\nYANDEX_MODEL_URL = (\n    \"https://cloud-api.yandex.net/v1/disk/public/resources/download?\"\n    \"public_key=https%3A%2F%2Fdisk.yandex.ru%2Fd%2Fxs0gzlj2_irgWA\"\n)\nYANDEX_MODEL_FLAWED_URL = (\n    \"https://cloud-api.yandex.net/v1/disk/public/resources/download?\"\n    \"public_key=https%3A%2F%2Fdisk.yandex.ru%2Fd%2Fxs0gzlj2_irgWA-url-with-error\"\n)\n\n\nclass TestDownloadUrlYandex(unittest.TestCase):\n\n    @unittest.skip(\"data source unstable\")\n    def test_verify(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            download_url(url=YANDEX_MODEL_URL, filepath=os.path.join(tempdir, \"model.pt\"))\n\n    def test_verify_error(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            with self.assertRaises(HTTPError):\n                download_url(url=YANDEX_MODEL_FLAWED_URL, filepath=os.path.join(tempdir, \"model.pt\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/test_mednistdataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport unittest\nfrom pathlib import Path\n\nfrom monai.apps import MedNISTDataset\nfrom monai.data import MetaTensor\nfrom monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd\nfrom tests.test_utils import skip_if_downloading_fails, skip_if_quick\n\nMEDNIST_FULL_DATASET_LENGTH = 58954\n\n\nclass TestMedNISTDataset(unittest.TestCase):\n    @skip_if_quick\n    def test_values(self):\n        testing_dir = Path(__file__).parents[1] / \"testing_data\"\n        transform = Compose(\n            [\n                LoadImaged(keys=\"image\"),\n                EnsureChannelFirstd(keys=\"image\", channel_dim=\"no_channel\"),\n                ScaleIntensityd(keys=\"image\"),\n            ]\n        )\n\n        def _test_dataset(dataset):\n            self.assertEqual(len(dataset), int(MEDNIST_FULL_DATASET_LENGTH * dataset.test_frac))\n            self.assertTrue(\"image\" in dataset[0])\n            self.assertTrue(\"label\" in dataset[0])\n            self.assertIsInstance(dataset[0][\"image\"], MetaTensor)\n            self.assertTupleEqual(dataset[0][\"image\"].shape, (1, 64, 64))\n\n        with skip_if_downloading_fails():\n            data = MedNISTDataset(\n                root_dir=testing_dir, transform=transform, section=\"test\", download=True, copy_cache=False\n            
)\n\n        _test_dataset(data)\n\n        # test loading from the existing local copy (download=False) with runtime cache enabled\n        data = MedNISTDataset(\n            root_dir=Path(testing_dir), transform=transform, section=\"test\", download=False, runtime_cache=True\n        )\n        self.assertEqual(data.get_num_classes(), 6)\n        _test_dataset(data)\n        data = MedNISTDataset(root_dir=testing_dir, section=\"test\", download=False)\n        self.assertTupleEqual(data[0][\"image\"].shape, (64, 64))\n        # test same dataset length with different random seed\n        data = MedNISTDataset(root_dir=testing_dir, transform=transform, section=\"test\", download=False, seed=42)\n        _test_dataset(data)\n        self.assertEqual(data[0][\"class_name\"], \"AbdomenCT\")\n        self.assertEqual(data[0][\"label\"], 0)\n        shutil.rmtree(os.path.join(testing_dir, \"MedNIST\"))\n        with self.assertRaisesRegex(RuntimeError, \"^Cannot find dataset directory\"):\n            MedNISTDataset(root_dir=testing_dir, transform=transform, section=\"test\", download=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/test_mmar_download.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai import __version__\nfrom monai.apps import download_mmar, load_from_mmar\nfrom monai.apps.mmars import MODEL_DESC\nfrom monai.apps.mmars.mmars import _get_val\nfrom monai.utils import version_leq\nfrom tests.test_utils import skip_if_downloading_fails, skip_if_quick\n\nTEST_CASES = [[\"clara_pt_prostate_mri_segmentation\"], [\"clara_pt_covid19_ct_lesion_segmentation\"]]\nTEST_EXTRACT_CASES = [\n    (\n        {\"item\": \"clara_pt_prostate_mri_segmentation\", \"map_location\": \"cuda\" if torch.cuda.is_available() else \"cpu\"},\n        \"UNet\",\n        np.array(\n            [\n                [[-0.0838, 0.0116, -0.0861], [-0.0792, 0.2216, -0.0301], [-0.0379, 0.0006, -0.0399]],\n                [[-0.0347, 0.0979, 0.0754], [0.1689, 0.3759, 0.2584], [-0.0698, 0.2740, 0.1414]],\n                [[-0.0772, 0.1046, -0.0103], [0.0917, 0.1942, 0.0284], [-0.0165, -0.0181, 0.0247]],\n            ]\n        ),\n    ),\n    (\n        {\n            \"item\": \"clara_pt_covid19_ct_lesion_segmentation\",\n            \"map_location\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n        },\n        \"SegResNet\",\n        np.array(\n            [\n                [\n           
         [0.01671106, 0.08502351, -0.1766469],\n                    [-0.13039736, -0.06137804, 0.03924942],\n                    [0.02268324, 0.159056, -0.03485069],\n                ],\n                [\n                    [0.04788467, -0.09365353, -0.05802464],\n                    [-0.19500689, -0.13514304, -0.08191573],\n                    [0.0238207, 0.08029253, 0.10818923],\n                ],\n                [\n                    [-0.11541673, -0.10622888, 0.039689],\n                    [0.18462701, -0.0499289, 0.14309818],\n                    [0.00528282, 0.02152331, 0.1698219],\n                ],\n            ]\n        ),\n    ),\n    (\n        {\n            \"item\": \"clara_pt_fed_learning_brain_tumor_mri_segmentation\",\n            \"map_location\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n            \"model_file\": os.path.join(\"models\", \"server\", \"best_FL_global_model.pt\"),\n        },\n        \"SegResNet\",\n        np.array(\n            [\n                [\n                    [0.01874463, 0.12237817, 0.09269974],\n                    [0.07691482, 0.00621202, -0.06682577],\n                    [-0.07718472, 0.08637864, -0.03222707],\n                ],\n                [\n                    [0.05117761, 0.07428649, -0.03053505],\n                    [0.11045473, 0.07083791, 0.06547518],\n                    [0.09555705, -0.03950734, -0.00819483],\n                ],\n                [\n                    [0.03704128, 0.062543, 0.0380853],\n                    [-0.02814676, -0.03078287, -0.01383446],\n                    [-0.08137762, 0.01385882, 0.01229484],\n                ],\n            ]\n        ),\n    ),\n    (\n        {\n            \"item\": \"clara_pt_pathology_metastasis_detection\",\n            \"map_location\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n        },\n        \"TorchVisionFCModel\",\n        np.array(\n            [\n                [-0.00540746, -0.00274996, -0.00837622, 
0.05415914, 0.03555066, -0.00071636, -0.02325751],\n                [0.00564625, 0.00674562, -0.1098334, -0.2936509, -0.28384757, -0.13580588, -0.00737865],\n                [-0.02159783, 0.04615543, 0.29717407, 0.6001161, 0.53496915, 0.2528417, 0.04530451],\n                [0.0225903, -0.07556137, -0.3070122, -0.43984795, -0.26286602, -0.00172576, 0.05003437],\n                [-0.0320133, 0.00855468, 0.06824744, -0.04786247, -0.30358723, -0.3960023, -0.24895012],\n                [0.02412516, 0.03411723, 0.06513759, 0.24332047, 0.41664436, 0.38999054, 0.15957521],\n                [-0.01303542, -0.00166874, -0.01965466, -0.06620175, -0.15635538, -0.10023144, -0.01698002],\n            ]\n        ),\n    ),\n]\n\n\n@unittest.skip(\"deprecating mmar tests\")\nclass TestMMMARDownload(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    @skip_if_quick\n    def test_download(self, idx):\n        with skip_if_downloading_fails():\n            with self.assertLogs(level=\"INFO\", logger=\"monai.apps\"):\n                download_mmar(idx)\n            download_mmar(idx, progress=False)  # repeated to check caching\n            with tempfile.TemporaryDirectory() as tmp_dir:\n                download_mmar(idx, mmar_dir=tmp_dir, progress=False)\n                download_mmar(idx, mmar_dir=Path(tmp_dir), progress=False, version=1)  # repeated to check caching\n                self.assertTrue(os.path.exists(os.path.join(tmp_dir, idx)))\n\n    @parameterized.expand(TEST_EXTRACT_CASES)\n    @skip_if_quick\n    @unittest.skipIf(version_leq(__version__, \"0.6\"), \"requires newer monai\")\n    def test_load_ckpt(self, input_args, expected_name, expected_val):\n        with skip_if_downloading_fails():\n            output = load_from_mmar(**input_args)\n        self.assertEqual(output.__class__.__name__, expected_name)\n        x = next(output.parameters())  # verify the first element\n        np.testing.assert_allclose(x[0][0].detach().cpu().numpy(), expected_val, 
rtol=1e-3, atol=1e-3)\n\n    def test_unique(self):\n        # model ids are unique\n        keys = sorted(m[\"id\"] for m in MODEL_DESC)\n        self.assertEqual(keys, sorted(set(keys)))\n\n    def test_search(self):\n        self.assertEqual(_get_val({\"a\": 1, \"b\": 2}, key=\"b\"), 2)\n        self.assertEqual(_get_val({\"a\": {\"c\": {\"c\": 4}}, \"b\": {\"c\": 2}}, key=\"b\"), {\"c\": 2})\n        self.assertEqual(_get_val({\"a\": {\"c\": 4}, \"b\": {\"c\": 2}}, key=\"c\"), 4)\n        self.assertEqual(_get_val({\"a\": {\"c\": None}, \"b\": {\"c\": 2}}, key=\"c\"), 2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/test_tciadataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport unittest\nfrom pathlib import Path\n\nfrom monai.apps import TciaDataset\nfrom monai.apps.tcia import DCM_FILENAME_REGEX, TCIA_LABEL_DICT\nfrom monai.data import MetaTensor\nfrom monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd\nfrom tests.test_utils import skip_if_downloading_fails, skip_if_quick\n\n\nclass TestTciaDataset(unittest.TestCase):\n    @skip_if_quick\n    def test_values(self):\n        testing_dir = Path(__file__).parents[1] / \"testing_data\"\n        download_len = 1\n        val_frac = 1.0\n        collection = \"QIN-PROSTATE-Repeatability\"\n\n        transform = Compose(\n            [\n                LoadImaged(\n                    keys=[\"image\", \"seg\"],\n                    reader=\"PydicomReader\",\n                    fname_regex=DCM_FILENAME_REGEX,\n                    label_dict=TCIA_LABEL_DICT[collection],\n                ),\n                EnsureChannelFirstd(keys=\"image\", channel_dim=\"no_channel\"),\n                ScaleIntensityd(keys=\"image\"),\n            ]\n        )\n\n        def _test_dataset(dataset):\n            self.assertEqual(len(dataset), int(download_len * val_frac))\n            self.assertTrue(\"image\" in dataset[0])\n            self.assertTrue(\"seg\" in dataset[0])\n            
self.assertTrue(isinstance(dataset[0][\"image\"], MetaTensor))\n            self.assertTupleEqual(dataset[0][\"image\"].shape, (1, 256, 256, 24))\n            self.assertTupleEqual(dataset[0][\"seg\"].shape, (256, 256, 24, 4))\n\n        with skip_if_downloading_fails():\n            data = TciaDataset(\n                root_dir=testing_dir,\n                collection=collection,\n                transform=transform,\n                section=\"validation\",\n                download=True,\n                download_len=download_len,\n                copy_cache=False,\n                val_frac=val_frac,\n            )\n\n        _test_dataset(data)\n        data = TciaDataset(\n            root_dir=testing_dir,\n            collection=collection,\n            transform=transform,\n            section=\"validation\",\n            download=False,\n            val_frac=val_frac,\n            runtime_cache=True,\n        )\n        _test_dataset(data)\n        self.assertTrue(\n            data[0][\"image\"].meta[\"filename_or_obj\"].endswith(\"QIN-PROSTATE-Repeatability/PCAMPMRI-00015/1901/image\")\n        )\n        self.assertTrue(\n            data[0][\"seg\"].meta[\"filename_or_obj\"].endswith(\"QIN-PROSTATE-Repeatability/PCAMPMRI-00015/1901/seg\")\n        )\n        # test validation without transforms\n        data = TciaDataset(\n            root_dir=testing_dir, collection=collection, section=\"validation\", download=False, val_frac=val_frac\n        )\n        self.assertTupleEqual(data[0][\"image\"].shape, (256, 256, 24))\n        self.assertEqual(len(data), int(download_len * val_frac))\n        data = TciaDataset(\n            root_dir=testing_dir,\n            collection=collection,\n            section=\"validation\",\n            download=False,\n            fname_regex=DCM_FILENAME_REGEX,\n            val_frac=val_frac,\n        )\n        self.assertTupleEqual(data[0][\"image\"].shape, (256, 256, 24))\n        self.assertEqual(len(data), 
download_len)\n        with self.assertWarns(UserWarning):\n            data = TciaDataset(\n                root_dir=testing_dir,\n                collection=collection,\n                section=\"validation\",\n                fname_regex=\".*\",  # all files including 'LICENSE' is not a valid input\n                download=False,\n                val_frac=val_frac,\n            )[0]\n\n        shutil.rmtree(os.path.join(testing_dir, collection))\n        with self.assertRaisesRegex(RuntimeError, \"^Cannot find dataset directory\"):\n            TciaDataset(\n                root_dir=testing_dir,\n                collection=collection,\n                transform=transform,\n                section=\"validation\",\n                download=False,\n                val_frac=val_frac,\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/vista3d/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/apps/vista3d/test_point_based_window_inferer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.vista3d.inferer import point_based_window_inferer\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.vista3d import vista3d132\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_quick\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n_, has_tqdm = optional_import(\"tqdm\")\n\nTEST_CASES = [\n    [\n        {\"encoder_embed_dim\": 48, \"in_channels\": 1},\n        (1, 1, 64, 64, 64),\n        {\n            \"roi_size\": [32, 32, 32],\n            \"point_coords\": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device),\n            \"point_labels\": torch.tensor([[1, 0]], device=device),\n        },\n    ],\n    [\n        {\"encoder_embed_dim\": 48, \"in_channels\": 1},\n        (1, 1, 64, 64, 64),\n        {\n            \"roi_size\": [32, 32, 32],\n            \"point_coords\": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device),\n            \"point_labels\": torch.tensor([[1, 0]], device=device),\n            \"class_vector\": torch.tensor([1], device=device),\n        },\n    ],\n    [\n        {\"encoder_embed_dim\": 48, \"in_channels\": 1},\n        (1, 1, 64, 64, 64),\n        {\n            \"roi_size\": [32, 32, 32],\n            \"point_coords\": torch.tensor([[[1, 2, 3], [1, 
2, 3]]], device=device),\n            \"point_labels\": torch.tensor([[1, 0]], device=device),\n            \"class_vector\": torch.tensor([1], device=device),\n            \"point_start\": 1,\n        },\n    ],\n]\n\n\n@skip_if_quick\nclass TestPointBasedWindowInferer(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_vista3d(self, vista3d_params, inputs_shape, inferer_params):\n        vista3d = vista3d132(**vista3d_params).to(device)\n        with eval_mode(vista3d):\n            inferer_params[\"predictor\"] = vista3d\n            inferer_params[\"inputs\"] = torch.randn(*inputs_shape).to(device)\n            stitched_output = point_based_window_inferer(**inferer_params)\n            self.assertEqual(stitched_output.shape, inputs_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/vista3d/test_vista3d_sampler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.vista3d.sampler import sample_prompt_pairs\n\nlabel = torch.zeros([1, 1, 64, 64, 64])\nlabel[:, :, :10, :10, :10] = 1\nlabel[:, :, 20:30, 20:30, 20:30] = 2\nlabel[:, :, 30:40, 30:40, 30:40] = 3\nlabel1 = torch.zeros([1, 1, 64, 64, 64])\n\nTEST_VISTA_SAMPLE_PROMPT = [\n    [\n        {\n            \"labels\": label,\n            \"label_set\": [0, 1, 2, 3, 4],\n            \"max_prompt\": 5,\n            \"max_foreprompt\": 4,\n            \"max_backprompt\": 1,\n            \"drop_label_prob\": 0,\n            \"drop_point_prob\": 0,\n        },\n        [4, 4, 4, 4],\n    ],\n    [\n        {\n            \"labels\": label,\n            \"label_set\": [0, 1],\n            \"max_prompt\": 5,\n            \"max_foreprompt\": 4,\n            \"max_backprompt\": 1,\n            \"drop_label_prob\": 0,\n            \"drop_point_prob\": 1,\n        },\n        [2, None, None, 2],\n    ],\n    [\n        {\n            \"labels\": label,\n            \"label_set\": [0, 1, 2, 3, 4],\n            \"max_prompt\": 5,\n            \"max_foreprompt\": 4,\n            \"max_backprompt\": 1,\n            \"drop_label_prob\": 1,\n            \"drop_point_prob\": 0,\n        },\n        [None, 3, 3, 3],\n    ],\n    [\n        {\n            \"labels\": 
label1,\n            \"label_set\": [0, 1],\n            \"max_prompt\": 5,\n            \"max_foreprompt\": 4,\n            \"max_backprompt\": 1,\n            \"drop_label_prob\": 0,\n            \"drop_point_prob\": 1,\n        },\n        [1, None, None, 1],\n    ],\n    [\n        {\n            \"labels\": label1,\n            \"label_set\": [0, 1],\n            \"max_prompt\": 5,\n            \"max_foreprompt\": 4,\n            \"max_backprompt\": 0,\n            \"drop_label_prob\": 0,\n            \"drop_point_prob\": 1,\n        },\n        [None, None, None, None],\n    ],\n]\n\n\nclass TestGeneratePrompt(unittest.TestCase):\n    @parameterized.expand(TEST_VISTA_SAMPLE_PROMPT)\n    def test_result(self, input_data, expected):\n        output = sample_prompt_pairs(**input_data)\n        result = [i.shape[0] if i is not None else None for i in output]\n        self.assertEqual(result, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/apps/vista3d/test_vista3d_transforms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest.case import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.vista3d.transforms import VistaPostTransformd, VistaPreTransformd\nfrom monai.utils import min_version\nfrom monai.utils.module import optional_import\n\nmeasure, has_measure = optional_import(\"skimage.measure\", \"0.14.2\", min_version)\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n\nTEST_VISTA_PRETRANSFORM = [\n    [\n        {\"label_prompt\": [1], \"points\": [[0, 0, 0]], \"point_labels\": [1]},\n        {\"label_prompt\": [1], \"points\": [[0, 0, 0]], \"point_labels\": [3]},\n    ],\n    [\n        {\"label_prompt\": [2], \"points\": [[0, 0, 0]], \"point_labels\": [0]},\n        {\"label_prompt\": [2], \"points\": [[0, 0, 0]], \"point_labels\": [2]},\n    ],\n    [\n        {\"label_prompt\": [3], \"points\": [[0, 0, 0]], \"point_labels\": [0]},\n        {\"label_prompt\": [4, 5], \"points\": [[0, 0, 0]], \"point_labels\": [0]},\n    ],\n    [\n        {\"label_prompt\": [6], \"points\": [[0, 0, 0]], \"point_labels\": [0]},\n        {\"label_prompt\": [7, 8], \"points\": [[0, 0, 0]], \"point_labels\": [0]},\n    ],\n]\n\n\npred1 = torch.zeros([2, 64, 64, 64])\npred1[0, :10, :10, :10] = 1\npred1[1, 20:30, 20:30, 20:30] = 1\noutput1 = torch.zeros([1, 64, 64, 
64])\noutput1[:, :10, :10, :10] = 2\noutput1[:, 20:30, 20:30, 20:30] = 3\n\n# -1 is needed since pred should be before sigmoid.\npred2 = torch.zeros([1, 64, 64, 64]) - 1\npred2[:, :10, :10, :10] = 1\npred2[:, 20:30, 20:30, 20:30] = 1\noutput2 = torch.zeros([1, 64, 64, 64])\noutput2[:, 20:30, 20:30, 20:30] = 1\n\nTEST_VISTA_POSTTRANSFORM = [\n    [{\"pred\": pred1.to(device), \"label_prompt\": torch.tensor([2, 3]).to(device)}, output1.to(device)],\n    [\n        {\n            \"pred\": pred2.to(device),\n            \"points\": torch.tensor([[25, 25, 25]]).to(device),\n            \"point_labels\": torch.tensor([1]).to(device),\n        },\n        output2.to(device),\n    ],\n]\n\n\nclass TestVistaPreTransformd(unittest.TestCase):\n    @parameterized.expand(TEST_VISTA_PRETRANSFORM)\n    def test_result(self, input_data, expected):\n        transform = VistaPreTransformd(keys=\"image\", subclass={\"3\": [4, 5], \"6\": [7, 8]}, special_index=[1, 2])\n        result = transform(input_data)\n        self.assertEqual(result, expected)\n\n\n@skipUnless(has_measure, \"skimage.measure required\")\nclass TestVistaPostTransformd(unittest.TestCase):\n    @parameterized.expand(TEST_VISTA_POSTTRANSFORM)\n    def test_result(self, input_data, expected):\n        transform = VistaPostTransformd(keys=\"pred\")\n        result = transform(input_data)\n        self.assertEqual((result[\"pred\"] == expected).all(), True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/bundle/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/bundle/test_bundle_ckpt_export.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nfrom parameterized import parameterized\n\nfrom monai.bundle import ConfigParser\nfrom monai.data import load_net_with_metadata\nfrom monai.networks import save_state\nfrom tests.test_utils import command_line_tests, skip_if_windows\n\nTESTS_PATH = Path(__file__).parents[1]\n\nTEST_CASE_1 = [\"\", \"\"]\n\nTEST_CASE_2 = [\"model\", \"\"]\n\nTEST_CASE_3 = [\"model\", \"True\"]\n\n\n@skip_if_windows\nclass TestCKPTExport(unittest.TestCase):\n\n    def setUp(self):\n        self.device = os.environ.get(\"CUDA_VISIBLE_DEVICES\")\n        if not self.device:\n            os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"  # default\n\n    def tearDown(self):\n        if self.device is not None:\n            os.environ[\"CUDA_VISIBLE_DEVICES\"] = self.device\n        else:\n            del os.environ[\"CUDA_VISIBLE_DEVICES\"]  # previously unset\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_export(self, key_in_ckpt, use_trace):\n        meta_file = os.path.join(TESTS_PATH, \"testing_data\", \"metadata.json\")\n        config_file = os.path.join(TESTS_PATH, \"testing_data\", \"inference.json\")\n        with tempfile.TemporaryDirectory() as tempdir:\n            def_args = {\"meta_file\": \"will be replaced by `meta_file` 
arg\"}\n            def_args_file = os.path.join(tempdir, \"def_args.yaml\")\n\n            ckpt_file = os.path.join(tempdir, \"model.pt\")\n            ts_file = os.path.join(tempdir, \"model.ts\")\n\n            parser = ConfigParser()\n            parser.export_config_file(config=def_args, filepath=def_args_file)\n            parser.read_config(config_file)\n            net = parser.get_parsed_content(\"network_def\")\n            save_state(src=net if key_in_ckpt == \"\" else {key_in_ckpt: net}, path=ckpt_file)\n\n            cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\", \"ckpt_export\", \"network_def\", \"--filepath\", ts_file]\n            cmd += [\"--meta_file\", meta_file, \"--config_file\", f\"['{config_file}','{def_args_file}']\", \"--ckpt_file\"]\n            cmd += [ckpt_file, \"--key_in_ckpt\", key_in_ckpt, \"--args_file\", def_args_file]\n            if use_trace == \"True\":\n                cmd += [\"--use_trace\", use_trace, \"--input_shape\", \"[1, 1, 96, 96, 96]\"]\n            command_line_tests(cmd)\n            self.assertTrue(os.path.exists(ts_file))\n\n            _, metadata, extra_files = load_net_with_metadata(\n                ts_file, more_extra_files=[\"inference.json\", \"def_args.json\"]\n            )\n            self.assertIn(\"schema\", metadata)\n            self.assertIn(\"meta_file\", json.loads(extra_files[\"def_args.json\"]))\n            self.assertIn(\"network_def\", json.loads(extra_files[\"inference.json\"]))\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_default_value(self, key_in_ckpt, use_trace):\n        config_file = os.path.join(TESTS_PATH, \"testing_data\", \"inference.json\")\n        with tempfile.TemporaryDirectory() as tempdir:\n            def_args = {\"meta_file\": \"will be replaced by `meta_file` arg\"}\n            def_args_file = os.path.join(tempdir, \"def_args.yaml\")\n            ckpt_file = os.path.join(tempdir, \"models/model.pt\")\n            ts_file 
= os.path.join(tempdir, \"models/model.ts\")\n\n            parser = ConfigParser()\n            parser.export_config_file(config=def_args, filepath=def_args_file)\n            parser.read_config(config_file)\n            net = parser.get_parsed_content(\"network_def\")\n            save_state(src=net if key_in_ckpt == \"\" else {key_in_ckpt: net}, path=ckpt_file)\n\n            # check with default value\n            cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\", \"ckpt_export\", \"--key_in_ckpt\", key_in_ckpt]\n            cmd += [\"--config_file\", config_file, \"--bundle_root\", tempdir]\n            if use_trace == \"True\":\n                cmd += [\"--use_trace\", use_trace, \"--input_shape\", \"[1, 1, 96, 96, 96]\"]\n            command_line_tests(cmd)\n            self.assertTrue(os.path.exists(ts_file))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/bundle/test_bundle_download.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport tempfile\nimport unittest\nfrom unittest.case import skipIf, skipUnless\nfrom unittest.mock import patch\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nimport monai.networks.nets as nets\nfrom monai.apps import check_hash\nfrom monai.bundle import ConfigParser, create_workflow, load\nfrom monai.bundle.scripts import _examine_monai_version, _list_latest_versions, download\nfrom monai.utils import optional_import\nfrom tests.test_utils import (\n    assert_allclose,\n    command_line_tests,\n    skip_if_downloading_fails,\n    skip_if_no_cuda,\n    skip_if_quick,\n    skip_if_windows,\n)\n\n_, has_huggingface_hub = optional_import(\"huggingface_hub\")\n\nTEST_CASE_1 = [\"test_bundle\", None]\n\nTEST_CASE_2 = [\"test_bundle\", \"0.1.1\"]\n\nTEST_CASE_3 = [\n    [\"model.pt\", \"model.ts\", \"network.json\", \"test_output.pt\", \"test_input.pt\"],\n    \"test_bundle\",\n    \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle.zip\",\n    \"a131d39a0af717af32d19e565b434928\",\n]\n\nTEST_CASE_4 = [\n    [\"model.pt\", \"model.ts\", \"network.json\", \"test_output.pt\", \"test_input.pt\"],\n    \"test_bundle\",\n    \"monai-test/test_bundle\",\n]\n\nTEST_CASE_5 = [\n    [\"models/model.pt\", \"models/model.ts\", \"configs/train.json\"],\n  
  \"brats_mri_segmentation\",\n    \"https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/brats_mri_segmentation/versions/0.4.0/files/brats_mri_segmentation_v0.4.0.zip\",\n]\n\nTEST_CASE_6 = [[\"models/model.pt\", \"configs/train.json\"], \"renalStructures_CECT_segmentation\", \"0.1.0\"]\n\nTEST_CASE_6_HF = [[\"models/model.pt\", \"configs/train.yaml\"], \"mednist_ddpm\", \"1.0.1\"]\n\nTEST_CASE_7 = [\n    [\"model.pt\", \"model.ts\", \"network.json\", \"test_output.pt\", \"test_input.pt\"],\n    \"test_bundle\",\n    \"Project-MONAI/MONAI-extra-test-data/0.8.1\",\n    \"cuda\" if torch.cuda.is_available() else \"cpu\",\n    \"model.pt\",\n]\n\nTEST_CASE_8 = [\n    \"spleen_ct_segmentation\",\n    \"cuda\" if torch.cuda.is_available() else \"cpu\",\n    {\"spatial_dims\": 3, \"out_channels\": 5},\n]\n\nTEST_CASE_9 = [\n    [\"test_output.pt\", \"test_input.pt\"],\n    \"test_bundle\",\n    \"0.1.1\",\n    \"Project-MONAI/MONAI-extra-test-data/0.8.1\",\n    \"cuda\" if torch.cuda.is_available() else \"cpu\",\n    \"model.ts\",\n]\n\nTEST_CASE_10 = [\n    [\"network.json\", \"test_output.pt\", \"test_input.pt\", \"large_files.yaml\"],\n    \"test_bundle\",\n    \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle_v0.1.3.zip\",\n    {\"model.pt\": \"27952767e2e154e3b0ee65defc5aed38\", \"model.ts\": \"97746870fe591f69ac09827175b00675\"},\n]\n\nTEST_CASE_NGC_1 = [\n    \"spleen_ct_segmentation\",\n    \"0.3.7\",\n    None,\n    \"monai_spleen_ct_segmentation\",\n    \"models/model.pt\",\n    \"b418a2dc8672ce2fd98dc255036e7a3d\",\n]\nTEST_CASE_NGC_2 = [\n    \"monai_spleen_ct_segmentation\",\n    \"0.3.7\",\n    \"monai_\",\n    \"spleen_ct_segmentation\",\n    \"models/model.pt\",\n    \"b418a2dc8672ce2fd98dc255036e7a3d\",\n]\n\nTESTCASE_NGC_WEIGHTS = {\n    \"key\": \"model.0.conv.unit0.adn.N.bias\",\n    \"value\": torch.tensor(\n        [\n            -0.0705,\n            -0.0937,\n            -0.0422,\n            
-0.2068,\n            0.1023,\n            -0.2007,\n            -0.0883,\n            0.0018,\n            -0.1719,\n            0.0116,\n            0.0285,\n            -0.0044,\n            0.1223,\n            -0.1287,\n            -0.1858,\n            0.0460,\n        ]\n    ),\n}\n\n\nclass TestDownload(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    @skip_if_quick\n    def test_github_download_bundle(self, bundle_name, version):\n        bundle_files = [\"model.pt\", \"model.ts\", \"network.json\", \"test_output.pt\", \"test_input.pt\"]\n        repo = \"Project-MONAI/MONAI-extra-test-data/0.8.1\"\n        hash_val = \"a131d39a0af717af32d19e565b434928\"\n        with skip_if_downloading_fails():\n            # download a whole bundle from github releases\n            with tempfile.TemporaryDirectory() as tempdir:\n                cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\", \"download\", \"--name\", bundle_name, \"--source\", \"github\"]\n                cmd += [\"--bundle_dir\", tempdir, \"--repo\", repo]\n                if version is not None:\n                    cmd += [\"--version\", version]\n                command_line_tests(cmd)\n                for file in bundle_files:\n                    file_path = os.path.join(tempdir, \"test_bundle\", file)\n                    self.assertTrue(os.path.exists(file_path))\n                    if file == \"network.json\":\n                        self.assertTrue(check_hash(filepath=file_path, val=hash_val))\n\n    @parameterized.expand([TEST_CASE_3])\n    @skip_if_quick\n    def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val):\n        with skip_if_downloading_fails():\n            # download a single file from url, also use `args_file`\n            with tempfile.TemporaryDirectory() as tempdir:\n                def_args = {\"name\": bundle_name, \"bundle_dir\": tempdir, \"url\": \"\"}\n                def_args_file = os.path.join(tempdir, 
\"def_args.json\")\n                parser = ConfigParser()\n                parser.export_config_file(config=def_args, filepath=def_args_file)\n                cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\", \"download\", \"--args_file\", def_args_file]\n                cmd += [\"--url\", url, \"--source\", \"github\"]\n                command_line_tests(cmd)\n                for file in bundle_files:\n                    file_path = os.path.join(tempdir, bundle_name, file)\n                    self.assertTrue(os.path.exists(file_path))\n                if file == \"network.json\":\n                    self.assertTrue(check_hash(filepath=file_path, val=hash_val))\n\n    @parameterized.expand([TEST_CASE_4])\n    @skip_if_quick\n    @skipUnless(has_huggingface_hub, \"Requires `huggingface_hub`.\")\n    def test_hf_hub_download_bundle(self, bundle_files, bundle_name, repo):\n        with skip_if_downloading_fails():\n            with tempfile.TemporaryDirectory() as tempdir:\n                cmd = [\n                    \"coverage\",\n                    \"run\",\n                    \"-m\",\n                    \"monai.bundle\",\n                    \"download\",\n                    \"--name\",\n                    bundle_name,\n                    \"--source\",\n                    \"huggingface_hub\",\n                ]\n                cmd += [\"--bundle_dir\", tempdir, \"--repo\", repo, \"--progress\", \"False\"]\n                command_line_tests(cmd)\n                for file in bundle_files:\n                    file_path = os.path.join(tempdir, bundle_name, file)\n                    self.assertTrue(os.path.exists(file_path))\n\n    @parameterized.expand([TEST_CASE_5])\n    @skip_if_quick\n    def test_monaihosting_url_download_bundle(self, bundle_files, bundle_name, url):\n        with skip_if_downloading_fails():\n            # download a single file from url, also use `args_file`\n            with tempfile.TemporaryDirectory() as tempdir:\n      
          def_args = {\"name\": bundle_name, \"bundle_dir\": tempdir, \"url\": \"\"}\n                def_args_file = os.path.join(tempdir, \"def_args.json\")\n                parser = ConfigParser()\n                parser.export_config_file(config=def_args, filepath=def_args_file)\n                cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\", \"download\", \"--args_file\", def_args_file]\n                cmd += [\"--url\", url, \"--progress\", \"False\"]\n                command_line_tests(cmd)\n                for file in bundle_files:\n                    file_path = os.path.join(tempdir, bundle_name, file)\n                    self.assertTrue(os.path.exists(file_path))\n\n    @parameterized.expand([TEST_CASE_5])\n    @skip_if_quick\n    @skipIf(os.getenv(\"NGC_API_KEY\", None) is None, \"NGC API key required for this test\")\n    def test_ngc_private_source_download_bundle(self, bundle_files, bundle_name, _url):\n        with skip_if_downloading_fails():\n            # download a single file from url, also use `args_file`\n            with tempfile.TemporaryDirectory() as tempdir:\n                def_args = {\"name\": bundle_name, \"bundle_dir\": tempdir}\n                def_args_file = os.path.join(tempdir, \"def_args.json\")\n                parser = ConfigParser()\n                parser.export_config_file(config=def_args, filepath=def_args_file)\n                cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\", \"download\", \"--args_file\", def_args_file]\n                cmd += [\"--progress\", \"False\", \"--source\", \"ngc_private\"]\n                command_line_tests(cmd)\n                for file in bundle_files:\n                    file_path = os.path.join(tempdir, bundle_name, file)\n                    self.assertTrue(os.path.exists(file_path))\n\n    @parameterized.expand([TEST_CASE_6])\n    @skip_if_quick\n    @skipUnless(has_huggingface_hub, \"Requires `huggingface_hub`.\")\n    def 
test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, version):\n        with skip_if_downloading_fails():\n            # download a single file from url, also use `args_file`\n            with tempfile.TemporaryDirectory() as tempdir:\n                def_args = {\"name\": bundle_name, \"bundle_dir\": tempdir, \"version\": version}\n                def_args_file = os.path.join(tempdir, \"def_args.json\")\n                parser = ConfigParser()\n                parser.export_config_file(config=def_args, filepath=def_args_file)\n                cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\", \"download\", \"--args_file\", def_args_file]\n                cmd += [\"--progress\", \"False\", \"--source\", \"monaihosting\"]\n                command_line_tests(cmd)\n                for file in bundle_files:\n                    file_path = os.path.join(tempdir, bundle_name, file)\n                    self.assertTrue(os.path.exists(file_path))\n\n    @patch(\"monai.bundle.scripts.get_versions\", return_value={\"version\": \"1.2\"})\n    def test_examine_monai_version(self, mock_get_versions):\n        self.assertTrue(_examine_monai_version(\"1.1\")[0])  # Should return True, compatible\n        self.assertTrue(_examine_monai_version(\"1.2rc1\")[0])  # Should return True, compatible\n        self.assertFalse(_examine_monai_version(\"1.3\")[0])  # Should return False, not compatible\n\n    @patch(\"monai.bundle.scripts.get_versions\", return_value={\"version\": \"1.2rc1\"})\n    def test_examine_monai_version_rc(self, mock_get_versions):\n        self.assertTrue(_examine_monai_version(\"1.2\")[0])  # Should return True, compatible\n        self.assertFalse(_examine_monai_version(\"1.3\")[0])  # Should return False, not compatible\n\n    def test_list_latest_versions(self):\n        \"\"\"Test listing of the latest versions.\"\"\"\n        data = {\n            \"modelVersions\": [\n                {\"createdDate\": \"2021-01-01\", \"versionId\": 
\"1.0\"},\n                {\"createdDate\": \"2021-01-02\", \"versionId\": \"1.1\"},\n                {\"createdDate\": \"2021-01-03\", \"versionId\": \"1.2\"},\n            ]\n        }\n        self.assertEqual(_list_latest_versions(data), [\"1.2\", \"1.1\", \"1.0\"])\n        self.assertEqual(_list_latest_versions(data, max_versions=2), [\"1.2\", \"1.1\"])\n        data = {\n            \"modelVersions\": [\n                {\"createdDate\": \"2021-01-01\", \"versionId\": \"1.0\"},\n                {\"createdDate\": \"2021-01-02\", \"versionId\": \"1.1\"},\n            ]\n        }\n        self.assertEqual(_list_latest_versions(data), [\"1.1\", \"1.0\"])\n\n    @skip_if_quick\n    @skipUnless(has_huggingface_hub, \"Requires `huggingface_hub`.\")\n    @patch(\"monai.bundle.scripts.get_versions\", return_value={\"version\": \"1.2\"})\n    def test_download_monaihosting(self, mock_get_versions):\n        \"\"\"Test checking MONAI version from a metadata file.\"\"\"\n        with patch(\"monai.bundle.scripts.logger\") as mock_logger:\n            with tempfile.TemporaryDirectory() as tempdir:\n                with skip_if_downloading_fails():\n                    download(name=\"spleen_ct_segmentation\", bundle_dir=tempdir, source=\"monaihosting\")\n                    # Should have a warning message because the latest version is using monai > 1.2\n                    mock_logger.warning.assert_called_once()\n\n    @skip_if_quick\n    @patch(\"monai.bundle.scripts.get_versions\", return_value={\"version\": \"1.3\"})\n    def test_download_ngc(self, mock_get_versions):\n        \"\"\"Test checking MONAI version from a metadata file.\"\"\"\n        with skip_if_downloading_fails():\n            with patch(\"monai.bundle.scripts.logger\") as mock_logger:\n                with tempfile.TemporaryDirectory() as tempdir:\n                    download(name=\"spleen_ct_segmentation\", bundle_dir=tempdir, source=\"ngc\")\n                    
mock_logger.warning.assert_not_called()\n\n\n@skip_if_no_cuda\nclass TestLoad(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_7])\n    @skip_if_quick\n    def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file):\n        with skip_if_downloading_fails():\n            with tempfile.TemporaryDirectory() as tempdir:\n                bundle_root = os.path.join(tempdir, bundle_name)\n                # load weights\n                model_1 = load(\n                    name=bundle_name,\n                    model_file=model_file,\n                    bundle_dir=tempdir,\n                    repo=repo,\n                    source=\"github\",\n                    progress=False,\n                    device=device,\n                )\n                # prepare network\n                with open(os.path.join(bundle_root, bundle_files[2])) as f:\n                    net_args = json.load(f)[\"network_def\"]\n                model_name = net_args[\"_target_\"]\n                del net_args[\"_target_\"]\n                model = getattr(nets, model_name)(**net_args)\n                model.to(device)\n                model.load_state_dict(model_1)\n                model.eval()\n\n                # prepare data and test\n                input_tensor = torch.load(\n                    os.path.join(bundle_root, bundle_files[4]), map_location=device, weights_only=True\n                )\n                output = model.forward(input_tensor)\n                expected_output = torch.load(\n                    os.path.join(bundle_root, bundle_files[3]), map_location=device, weights_only=True\n                )\n                assert_allclose(output, expected_output, atol=1e-3, rtol=1e-3, type_test=False)\n\n                # load instantiated model directly and test, since the bundle has been downloaded,\n                # there is no need to input `repo`\n                _model_2 = getattr(nets, model_name)(**net_args)\n                model_2 = 
load(\n                    name=bundle_name,\n                    model=_model_2,\n                    model_file=model_file,\n                    bundle_dir=tempdir,\n                    progress=False,\n                    device=device,\n                    source=\"github\",\n                )\n                model_2.eval()\n                output_2 = model_2.forward(input_tensor)\n                assert_allclose(output_2, expected_output, atol=1e-3, rtol=1e-3, type_test=False)\n\n    @parameterized.expand([TEST_CASE_8])\n    @skip_if_quick\n    @skipUnless(has_huggingface_hub, \"Requires `huggingface_hub`.\")\n    def test_load_weights_with_net_override(self, bundle_name, device, net_override):\n        with skip_if_downloading_fails():\n            # download bundle, and load weights from the downloaded path\n            with tempfile.TemporaryDirectory() as tempdir:\n                # load weights\n                model = load(name=bundle_name, bundle_dir=tempdir, source=\"monaihosting\", progress=False, device=device)\n\n                # prepare data and test\n                input_tensor = torch.rand(1, 1, 96, 96, 96).to(device)\n                output = model(input_tensor)\n                model_path = f\"{tempdir}/spleen_ct_segmentation/models/model.pt\"\n                workflow = create_workflow(\n                    config_file=f\"{tempdir}/spleen_ct_segmentation/configs/train.json\", workflow_type=\"train\"\n                )\n                expected_model = workflow.network_def.to(device)\n                expected_model.load_state_dict(torch.load(model_path, weights_only=True))\n                expected_output = expected_model(input_tensor)\n                assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False)\n\n                # using net_override to override kwargs in network directly\n                model_2 = load(\n                    name=bundle_name,\n                    bundle_dir=tempdir,\n                    
source=\"monaihosting\",\n                    progress=False,\n                    device=device,\n                    net_override=net_override,\n                )\n\n                # prepare data and test\n                input_tensor = torch.rand(1, 1, 96, 96, 96).to(device)\n                output = model_2(input_tensor)\n                expected_shape = (1, 5, 96, 96, 96)\n                np.testing.assert_equal(output.shape, expected_shape)\n\n    @parameterized.expand([TEST_CASE_9])\n    @skip_if_quick\n    def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device, model_file):\n        with skip_if_downloading_fails():\n            # load ts module\n            with tempfile.TemporaryDirectory() as tempdir:\n                bundle_root = os.path.join(tempdir, bundle_name)\n                # load ts module\n                model_ts, metadata, extra_file_dict = load(\n                    name=bundle_name,\n                    version=version,\n                    model_file=model_file,\n                    load_ts_module=True,\n                    bundle_dir=tempdir,\n                    repo=repo,\n                    progress=False,\n                    device=device,\n                    source=\"github\",\n                    config_files=(\"network.json\",),\n                )\n\n                # prepare and test ts\n                input_tensor = torch.load(\n                    os.path.join(bundle_root, bundle_files[1]), map_location=device, weights_only=True\n                )\n                output = model_ts.forward(input_tensor)\n                expected_output = torch.load(\n                    os.path.join(bundle_root, bundle_files[0]), map_location=device, weights_only=True\n                )\n                assert_allclose(output, expected_output, atol=1e-3, rtol=1e-3, type_test=False)\n                # test metadata\n                self.assertTrue(metadata[\"pytorch_version\"] == \"1.7.1\")\n                # test 
extra_file_dict\n                self.assertTrue(\"network.json\" in extra_file_dict.keys())\n\n\nclass TestDownloadLargefiles(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_10])\n    @skip_if_quick\n    def test_url_download_large_files(self, bundle_files, bundle_name, url, hash_val):\n        with skip_if_downloading_fails():\n            # download a single file from url, also use `args_file`\n            with tempfile.TemporaryDirectory() as tempdir:\n                def_args = {\"name\": bundle_name, \"bundle_dir\": tempdir, \"url\": \"\"}\n                def_args_file = os.path.join(tempdir, \"def_args.json\")\n                parser = ConfigParser()\n                parser.export_config_file(config=def_args, filepath=def_args_file)\n                cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\", \"download\", \"--args_file\", def_args_file]\n                cmd += [\"--url\", url, \"--source\", \"github\"]\n                command_line_tests(cmd)\n                for file in bundle_files:\n                    file_path = os.path.join(tempdir, bundle_name, file)\n                    print(file_path)\n                    self.assertTrue(os.path.exists(file_path))\n\n                # download large files\n                bundle_path = os.path.join(tempdir, bundle_name)\n                cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\", \"download_large_files\", \"--bundle_path\", bundle_path]\n                command_line_tests(cmd)\n                for file in [\"model.pt\", \"model.ts\"]:\n                    file_path = os.path.join(tempdir, bundle_name, f\"models/{file}\")\n                    self.assertTrue(check_hash(filepath=file_path, val=hash_val[file]))\n\n\n@skip_if_windows\nclass TestNgcBundleDownload(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_NGC_1, TEST_CASE_NGC_2])\n    @skip_if_quick\n    def test_ngc_download_bundle(self, bundle_name, version, remove_prefix, download_name, file_path, hash_val):\n        
with skip_if_downloading_fails():\n            with tempfile.TemporaryDirectory() as tempdir:\n                download(\n                    name=bundle_name, source=\"ngc\", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix\n                )\n                full_file_path = os.path.join(tempdir, download_name, file_path)\n                self.assertTrue(os.path.exists(full_file_path))\n                self.assertTrue(check_hash(filepath=full_file_path, val=hash_val))\n\n                model = load(\n                    name=bundle_name, source=\"ngc\", version=version, bundle_dir=tempdir, remove_prefix=remove_prefix\n                )\n                assert_allclose(\n                    model.state_dict()[TESTCASE_NGC_WEIGHTS[\"key\"]],\n                    TESTCASE_NGC_WEIGHTS[\"value\"],\n                    atol=1e-4,\n                    rtol=1e-4,\n                    type_test=False,\n                )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/bundle/test_bundle_get_data.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.bundle import get_all_bundles_list, get_bundle_info, get_bundle_versions\nfrom monai.utils import optional_import\nfrom tests.test_utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick, skip_if_windows\n\nrequests, _ = optional_import(\"requests\")\n\nTEST_CASE_1 = [{\"bundle_name\": \"brats_mri_segmentation\", \"tag\": \"hosting_storage_v1\"}]\n\nTEST_CASE_2 = [\n    {\"bundle_name\": \"spleen_ct_segmentation\", \"version\": \"0.1.0\", \"auth_token\": None, \"tag\": \"hosting_storage_v1\"}\n]\n\nTEST_CASE_3 = [{\"tag\": \"hosting_storage_v1\"}]\n\nTEST_CASE_4 = [{\"tag\": \"dev\"}]\n\nTEST_CASE_5 = [{\"bundle_name\": \"brats_mri_segmentation\", \"tag\": \"dev\"}]\n\nTEST_CASE_6 = [{\"bundle_name\": \"spleen_ct_segmentation\", \"version\": \"0.1.0\", \"auth_token\": None, \"tag\": \"dev\"}]\n\nTEST_CASE_FAKE_TOKEN_1 = [{\"bundle_name\": \"spleen_ct_segmentation\", \"version\": \"0.1.0\", \"auth_token\": \"ghp_errortoken\"}]\n\nTEST_CASE_FAKE_TOKEN_2 = [\n    {\"bundle_name\": \"spleen_ct_segmentation\", \"version\": \"0.1.0\", \"auth_token\": \"ghp_errortoken\", \"tag\": \"dev\"}\n]\n\n\n@skip_if_windows\n@SkipIfNoModule(\"requests\")\nclass TestGetBundleData(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_3, TEST_CASE_4])\n    
@skip_if_quick\n    def test_get_all_bundles_list(self, params):\n        with skip_if_downloading_fails():\n            output = get_all_bundles_list(**params)\n            self.assertIsInstance(output, list)\n            self.assertIsInstance(output[0], tuple)\n            self.assertTrue(len(output[0]) == 2)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_5])\n    @skip_if_quick\n    def test_get_bundle_versions(self, params):\n        with skip_if_downloading_fails():\n            output = get_bundle_versions(**params)\n            self.assertIsInstance(output, dict)\n            self.assertIn(\"latest_version\", output)\n            self.assertIn(\"all_versions\", output)\n            self.assertIn(\"0.1.0\", output[\"all_versions\"])\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    @skip_if_quick\n    def test_get_bundle_info(self, params):\n        with skip_if_downloading_fails():\n            output = get_bundle_info(**params)\n            self.assertIsInstance(output, dict)\n            for key in [\"id\", \"name\", \"size\", \"download_count\", \"browser_download_url\"]:\n                self.assertTrue(key in output)\n\n    @parameterized.expand([TEST_CASE_5, TEST_CASE_6])\n    @skip_if_quick\n    def test_get_bundle_info_monaihosting(self, params):\n        with skip_if_downloading_fails():\n            output = get_bundle_info(**params)\n            self.assertIsInstance(output, dict)\n            for key in [\"name\", \"browser_download_url\"]:\n                self.assertTrue(key in output)\n\n    @parameterized.expand([TEST_CASE_FAKE_TOKEN_1, TEST_CASE_FAKE_TOKEN_2])\n    @skip_if_quick\n    def test_fake_token(self, params):\n        with skip_if_downloading_fails():\n            with self.assertRaises(requests.exceptions.HTTPError):\n                get_bundle_info(**params)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/bundle/test_bundle_push_to_hf_hub.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport tempfile\nimport unittest\nfrom unittest.case import skipUnless\nfrom unittest.mock import patch\n\nfrom parameterized import parameterized\n\nfrom monai.bundle import push_to_hf_hub\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_quick\n\nhuggingface_hub, has_huggingface_hub = optional_import(\"huggingface_hub\")\n\nTEST_CASE_1 = [\"monai-test/test_bundle_push\", \"test_bundle\"]\n\n\nclass TestPushToHuggingFaceHub(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1])\n    @skip_if_quick\n    @skipUnless(has_huggingface_hub, \"Requires `huggingface_hub` package.\")\n    @patch.object(huggingface_hub.HfApi, \"create_repo\")\n    @patch.object(huggingface_hub.HfApi, \"upload_folder\")\n    @patch.object(huggingface_hub.HfApi, \"create_tag\")\n    def test_push_to_huggingface_hub(self, repo, bundle_name, test_createrepo, test_uploadfolder, test_createtag):\n        test_uploadfolder.return_value = \"https://hf.co/repo/test\"\n        with tempfile.TemporaryDirectory() as tempdir:\n            repo_url = push_to_hf_hub(repo, bundle_name, tempdir)\n            self.assertEqual(\"https://hf.co/repo/test\", repo_url)\n"
  },
  {
    "path": "tests/bundle/test_bundle_trt_export.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nfrom parameterized import parameterized\n\nfrom monai.bundle import ConfigParser\nfrom monai.data import load_net_with_metadata\nfrom monai.networks import save_state\nfrom monai.utils import optional_import\nfrom tests.test_utils import (\n    SkipIfBeforeComputeCapabilityVersion,\n    command_line_tests,\n    skip_if_no_cuda,\n    skip_if_quick,\n    skip_if_windows,\n)\n\n_, has_torchtrt = optional_import(\n    \"torch_tensorrt\",\n    version=\"1.4.0\",\n    descriptor=\"Torch-TRT is not installed. Are you sure you have a Torch-TensorRT compilation?\",\n)\n_, has_tensorrt = optional_import(\n    \"tensorrt\", descriptor=\"TensorRT is not installed. Are you sure you have a TensorRT compilation?\"\n)\n\n_, has_onnx = optional_import(\"onnx\", descriptor=\"Onnx is not installed. 
Will not test the onnx option.\")\n\nTEST_CASE_1 = [\"fp32\", [], []]\n\nTEST_CASE_2 = [\"fp16\", [], []]\n\nTEST_CASE_3 = [\"fp32\", [1, 1, 96, 96, 96], [1, 4, 8]]\n\nTEST_CASE_4 = [\"fp16\", [1, 1, 96, 96, 96], [1, 4, 8]]\n\n\n@skip_if_windows\n@skip_if_no_cuda\n@skip_if_quick\n@SkipIfBeforeComputeCapabilityVersion((7, 5))\nclass TestTRTExport(unittest.TestCase):\n    def setUp(self):\n        self.device = os.environ.get(\"CUDA_VISIBLE_DEVICES\")\n        if not self.device:\n            os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"  # default\n\n    def tearDown(self):\n        if self.device is not None:\n            os.environ[\"CUDA_VISIBLE_DEVICES\"] = self.device\n        else:\n            del os.environ[\"CUDA_VISIBLE_DEVICES\"]  # previously unset\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    @unittest.skipUnless(has_torchtrt and has_tensorrt, \"Torch-TensorRT is required for conversion!\")\n    def test_trt_export(self, convert_precision, input_shape, dynamic_batch):\n        tests_dir = Path(__file__).resolve().parents[1]\n        meta_file = os.path.join(tests_dir, \"testing_data\", \"metadata.json\")\n        config_file = os.path.join(tests_dir, \"testing_data\", \"inference.json\")\n        with tempfile.TemporaryDirectory() as tempdir:\n            def_args = {\"meta_file\": \"will be replaced by `meta_file` arg\"}\n            def_args_file = os.path.join(tempdir, \"def_args.yaml\")\n\n            ckpt_file = os.path.join(tempdir, \"model.pt\")\n            ts_file = os.path.join(tempdir, f\"model_trt_{convert_precision}.ts\")\n\n            parser = ConfigParser()\n            parser.export_config_file(config=def_args, filepath=def_args_file)\n            parser.read_config(config_file)\n            net = parser.get_parsed_content(\"network_def\")\n            save_state(src=net, path=ckpt_file)\n\n            cmd = [\"python\", \"-m\", \"monai.bundle\", \"trt_export\", \"network_def\", \"--filepath\", 
ts_file]\n            cmd += [\"--meta_file\", meta_file, \"--config_file\", f\"['{config_file}','{def_args_file}']\", \"--ckpt_file\"]\n            cmd += [ckpt_file, \"--args_file\", def_args_file, \"--precision\", convert_precision]\n            if input_shape:\n                cmd += [\"--input_shape\", str(input_shape)]\n            if dynamic_batch:\n                cmd += [\"--dynamic_batch\", str(dynamic_batch)]\n            command_line_tests(cmd)\n            self.assertTrue(os.path.exists(ts_file))\n\n            _, metadata, extra_files = load_net_with_metadata(\n                ts_file, more_extra_files=[\"inference.json\", \"def_args.json\"]\n            )\n            self.assertIn(\"schema\", metadata)\n            self.assertIn(\"meta_file\", json.loads(extra_files[\"def_args.json\"]))\n            self.assertIn(\"network_def\", json.loads(extra_files[\"inference.json\"]))\n\n    @parameterized.expand([TEST_CASE_3, TEST_CASE_4])\n    @unittest.skipUnless(\n        has_onnx and has_torchtrt and has_tensorrt, \"Onnx and TensorRT are required for onnx-trt conversion!\"\n    )\n    def test_onnx_trt_export(self, convert_precision, input_shape, dynamic_batch):\n        tests_dir = Path(__file__).resolve().parents[1]\n        meta_file = os.path.join(tests_dir, \"testing_data\", \"metadata.json\")\n        config_file = os.path.join(tests_dir, \"testing_data\", \"inference.json\")\n        with tempfile.TemporaryDirectory() as tempdir:\n            def_args = {\"meta_file\": \"will be replaced by `meta_file` arg\"}\n            def_args_file = os.path.join(tempdir, \"def_args.yaml\")\n\n            ckpt_file = os.path.join(tempdir, \"model.pt\")\n            ts_file = os.path.join(tempdir, f\"model_trt_{convert_precision}.ts\")\n\n            parser = ConfigParser()\n            parser.export_config_file(config=def_args, filepath=def_args_file)\n            parser.read_config(config_file)\n            net = parser.get_parsed_content(\"network_def\")\n    
        save_state(src=net, path=ckpt_file)\n\n            cmd = [\"python\", \"-m\", \"monai.bundle\", \"trt_export\", \"network_def\", \"--filepath\", ts_file]\n            cmd += [\"--meta_file\", meta_file, \"--config_file\", f\"['{config_file}','{def_args_file}']\", \"--ckpt_file\"]\n            cmd += [ckpt_file, \"--args_file\", def_args_file, \"--precision\", convert_precision]\n            cmd += [\"--use_onnx\", \"True\"]\n            if input_shape:\n                cmd += [\"--input_shape\", str(input_shape)]\n            if dynamic_batch:\n                cmd += [\"--dynamic_batch\", str(dynamic_batch)]\n            command_line_tests(cmd)\n            self.assertTrue(os.path.exists(ts_file))\n\n            _, metadata, extra_files = load_net_with_metadata(\n                ts_file, more_extra_files=[\"inference.json\", \"def_args.json\"]\n            )\n            self.assertIn(\"schema\", metadata)\n            self.assertIn(\"meta_file\", json.loads(extra_files[\"def_args.json\"]))\n            self.assertIn(\"network_def\", json.loads(extra_files[\"inference.json\"]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/bundle/test_bundle_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport tempfile\nimport unittest\n\nimport torch\n\nfrom monai.bundle import update_kwargs\nfrom monai.bundle.utils import load_bundle_config\nfrom monai.networks.nets import UNet\nfrom monai.utils import pprint_edges\nfrom tests.test_utils import command_line_tests, skip_if_windows\n\nmetadata = \"\"\"\n{\n    \"test_value\": 1,\n    \"test_list\": [2,3]\n}\n\"\"\"\n\ntest_json = \"\"\"\n{\n    \"test_dict\": {\n        \"a\": 3,\n        \"b\": \"c\"\n    },\n    \"network_def\": {\n        \"_target_\": \"UNet\",\n        \"spatial_dims\": 2,\n        \"in_channels\": 1,\n        \"out_channels\": 1,\n        \"channels\": [4,8],\n        \"strides\": [2]\n    }\n}\n\"\"\"\n\n\n@skip_if_windows\nclass TestLoadBundleConfig(unittest.TestCase):\n    def setUp(self):\n        self.bundle_dir = tempfile.TemporaryDirectory()\n        self.dir_name = os.path.join(self.bundle_dir.name, \"TestBundle\")\n        self.configs_name = os.path.join(self.dir_name, \"configs\")\n        self.models_name = os.path.join(self.dir_name, \"models\")\n        self.metadata_name = os.path.join(self.configs_name, \"metadata.json\")\n        self.test_name = os.path.join(self.configs_name, \"test.json\")\n        self.modelpt_name = os.path.join(self.models_name, \"model.pt\")\n\n        self.zip_file = 
os.path.join(self.bundle_dir.name, \"TestBundle.zip\")\n        self.ts_file = os.path.join(self.bundle_dir.name, \"TestBundle.ts\")\n\n        # create the directories for the bundle\n        os.mkdir(self.dir_name)\n        os.mkdir(self.configs_name)\n        os.mkdir(self.models_name)\n\n        # fill bundle configs\n\n        with open(self.metadata_name, \"w\") as o:\n            o.write(metadata)\n\n        with open(self.test_name, \"w\") as o:\n            o.write(test_json)\n\n        # save network\n        net = UNet(2, 1, 1, [4, 8], [2])\n        torch.save(net.state_dict(), self.modelpt_name)\n\n    def tearDown(self):\n        self.bundle_dir.cleanup()\n\n    def test_load_config_dir(self):\n        p = load_bundle_config(self.dir_name, \"test.json\")\n\n        self.assertEqual(p[\"_meta_\"][\"test_value\"], 1)\n\n        self.assertEqual(p[\"test_dict\"][\"b\"], \"c\")\n\n    def test_load_config_zip(self):\n        # create a zip of the bundle\n        shutil.make_archive(self.zip_file[:-4], \"zip\", self.bundle_dir.name)\n\n        p = load_bundle_config(self.zip_file, \"test.json\")\n\n        self.assertEqual(p[\"_meta_\"][\"test_value\"], 1)\n\n        self.assertEqual(p[\"test_dict\"][\"b\"], \"c\")\n\n    def test_run(self):\n        command_line_tests(\n            [\n                \"python\",\n                \"-m\",\n                \"monai.bundle\",\n                \"run\",\n                \"test\",\n                \"--test\",\n                \"$print('hello world')\",\n                \"--config_file\",\n                self.test_name,\n                \"--meta_file\",\n                self.metadata_name,\n            ]\n        )\n\n    def test_load_config_ts(self):\n        # create a Torchscript zip of the bundle\n        cmd = [\"python\", \"-m\", \"monai.bundle\", \"ckpt_export\", \"network_def\", \"--filepath\", self.ts_file]\n        cmd += [\"--meta_file\", self.metadata_name]\n        cmd += [\"--config_file\", 
self.test_name]\n        cmd += [\"--ckpt_file\", self.modelpt_name]\n\n        command_line_tests(cmd)\n\n        p = load_bundle_config(self.ts_file, \"test.json\")\n\n        self.assertEqual(p[\"_meta_\"][\"test_value\"], 1)\n\n        self.assertEqual(p[\"test_dict\"][\"b\"], \"c\")\n\n\nclass TestPPrintEdges(unittest.TestCase):\n    def test_str(self):\n        self.assertEqual(pprint_edges(\"\", 0), \"''\")\n        self.assertEqual(pprint_edges({\"a\": 1, \"b\": 2}, 0), \"{'a': 1, 'b': 2}\")\n        self.assertEqual(\n            pprint_edges([{\"a\": 1, \"b\": 2}] * 20, 1),\n            \"[{'a': 1, 'b': 2},\\n\\n ... omitted 18 line(s)\\n\\n {'a': 1, 'b': 2}]\",\n        )\n        self.assertEqual(pprint_edges([{\"a\": 1, \"b\": 2}] * 8, 4), pprint_edges([{\"a\": 1, \"b\": 2}] * 8, 3))\n        self.assertEqual(update_kwargs({\"a\": 1}, a=2, b=3), {\"a\": 2, \"b\": 3})\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/bundle/test_bundle_verify_metadata.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nfrom parameterized import parameterized\n\nfrom monai.bundle import ConfigParser, verify_metadata\nfrom tests.test_utils import command_line_tests, download_url_or_skip_test, skip_if_windows, testing_data_config\n\nTESTS_DIR = Path(__file__).parents[1]\nSCHEMA_FILE = os.path.join(TESTS_DIR, \"testing_data\", \"schema.json\")\n\nTEST_CASE_1 = [os.path.join(TESTS_DIR, \"testing_data\", \"metadata.json\"), SCHEMA_FILE]\n\n\n@skip_if_windows\nclass TestVerifyMetaData(unittest.TestCase):\n    def setUp(self):\n        self.config = testing_data_config(\"configs\", \"test_meta_file\")\n        download_url_or_skip_test(\n            url=self.config[\"url\"],\n            filepath=SCHEMA_FILE,\n            hash_val=self.config.get(\"hash_val\"),\n            hash_type=self.config.get(\"hash_type\", \"sha256\"),\n        )\n\n    @parameterized.expand([TEST_CASE_1])\n    def test_verify(self, meta_file, schema_file):\n        with tempfile.TemporaryDirectory() as tempdir:\n            def_args = {\"meta_file\": \"will be replaced by `meta_file` arg\"}\n            def_args_file = os.path.join(tempdir, \"def_args.json\")\n            ConfigParser.export_config_file(config=def_args, filepath=def_args_file)\n\n            cmd = [\"coverage\", \"run\", \"-m\", 
\"monai.bundle\", \"verify_metadata\", \"--meta_file\", meta_file]\n            cmd += [\"--filepath\", schema_file, \"--hash_val\", self.config[\"hash_val\"], \"--args_file\", def_args_file]\n            command_line_tests(cmd)\n\n    def test_verify_error(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            filepath = os.path.join(tempdir, \"schema.json\")\n            metafile = os.path.join(tempdir, \"metadata.json\")\n            meta_dict = {\"schema\": self.config[\"url\"], \"wrong_meta\": \"wrong content\"}\n            with open(metafile, \"w\") as f:\n                json.dump(meta_dict, f)\n\n            with self.assertRaises(ValueError):\n                verify_metadata(meta_file=metafile, filepath=filepath)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/bundle/test_bundle_verify_net.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nfrom parameterized import parameterized\n\nfrom monai.bundle import ConfigParser, verify_net_in_out\nfrom tests.test_utils import command_line_tests, skip_if_no_cuda, skip_if_windows\n\nTESTS_PATH = Path(__file__).parents[1].as_posix()\n\nTEST_CASE_1 = [\n    os.path.join(TESTS_PATH, \"testing_data\", \"metadata.json\"),\n    os.path.join(TESTS_PATH, \"testing_data\", \"inference.json\"),\n]\n\n\n@skip_if_windows\nclass TestVerifyNetwork(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1])\n    def test_verify(self, meta_file, config_file):\n        with tempfile.TemporaryDirectory() as tempdir:\n            def_args = {\"meta_file\": \"will be replaced by `meta_file` arg\", \"p\": 2}\n            def_args_file = os.path.join(tempdir, \"def_args.json\")\n            ConfigParser.export_config_file(config=def_args, filepath=def_args_file)\n\n            cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\", \"verify_net_in_out\", \"network_def\", \"--meta_file\"]\n            cmd += [meta_file, \"--config_file\", config_file, \"-n\", \"4\", \"--any\", \"16\", \"--args_file\", def_args_file]\n            cmd += [\"--device\", \"cpu\", \"--_meta_::network_data_format::inputs#image#spatial_shape\", \"[16,'*','2**p*n']\"]\n            
command_line_tests(cmd)\n\n    @parameterized.expand([TEST_CASE_1])\n    @skip_if_no_cuda\n    def test_verify_fp16(self, meta_file, config_file):\n        with tempfile.TemporaryDirectory() as tempdir:\n            def_args = {\"meta_file\": \"will be replaced by `meta_file` arg\", \"p\": 2}\n            def_args_file = os.path.join(tempdir, \"def_args.json\")\n            ConfigParser.export_config_file(config=def_args, filepath=def_args_file)\n\n            cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\", \"verify_net_in_out\", \"network_def\", \"--meta_file\"]\n            cmd += [meta_file, \"--config_file\", config_file, \"-n\", \"4\", \"--any\", \"16\", \"--args_file\", def_args_file]\n            cmd += [\"--device\", \"cuda\", \"--_meta_#network_data_format#inputs#image#spatial_shape\", \"[16,'*','2**p*n']\"]\n            cmd += [\"--_meta_#network_data_format#inputs#image#dtype\", \"float16\"]\n            cmd += [\"--_meta_::network_data_format::outputs::pred::dtype\", \"float16\"]\n            command_line_tests(cmd)\n\n    @parameterized.expand([TEST_CASE_1])\n    @skip_if_no_cuda\n    def test_verify_fp16_extra_forward_args(self, meta_file, config_file):\n        verify_net_in_out(\n            net_id=\"network_def\",\n            meta_file=meta_file,\n            config_file=config_file,\n            n=4,\n            any=16,\n            extra_forward_args={\"extra_arg1\": 1, \"extra_arg2\": 2},\n            **{\"network_def#_target_\": \"tests.testing_data.bundle_test_network.TestMultiInputUNet\"},\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/bundle/test_bundle_workflow.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport sys\nimport tempfile\nimport unittest\nfrom copy import deepcopy\nfrom pathlib import Path\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.bundle import ConfigWorkflow, create_workflow\nfrom monai.data import Dataset\nfrom monai.inferers import SimpleInferer, SlidingWindowInferer\nfrom monai.networks.nets import UNet\nfrom monai.transforms import Compose, LoadImage, LoadImaged, SaveImaged\nfrom tests.nonconfig_workflow import NonConfigWorkflow, PythonicWorkflowImpl\n\nMODULE_PATH = Path(__file__).resolve().parents[1]\n\nTEST_CASE_1 = [os.path.join(MODULE_PATH, \"testing_data\", \"inference.json\")]\n\nTEST_CASE_2 = [os.path.join(MODULE_PATH, \"testing_data\", \"inference.yaml\")]\n\nTEST_CASE_3 = [os.path.join(MODULE_PATH, \"testing_data\", \"config_fl_train.json\")]\n\nTEST_CASE_4 = [os.path.join(MODULE_PATH, \"testing_data\", \"responsive_inference.json\")]\n\nTEST_CASE_NON_CONFIG_WRONG_LOG = [None, \"logging.conf\", \"Cannot find the logging config file: logging.conf.\"]\n\n\nclass TestBundleWorkflow(unittest.TestCase):\n    def setUp(self):\n        self.data_dir = tempfile.mkdtemp()\n        self.expected_shape = (128, 128, 128)\n        test_image = np.random.rand(*self.expected_shape)\n        self.filename = 
os.path.join(self.data_dir, \"image.nii\")\n        self.filename1 = os.path.join(self.data_dir, \"image1.nii\")\n        nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename)\n        nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename1)\n\n    def tearDown(self):\n        shutil.rmtree(self.data_dir)\n\n    def _test_inferer(self, inferer):\n        # should initialize before parsing any bundle content\n        inferer.initialize()\n        # test required and optional properties\n        self.assertListEqual(inferer.check_properties(), [])\n        # test read / write the properties, note that we don't assume it as JSON or YAML config here\n        self.assertEqual(inferer.bundle_root, \"will override\")\n        self.assertEqual(inferer.device, torch.device(\"cpu\"))\n        net = inferer.network_def\n        self.assertTrue(isinstance(net, UNet))\n        sliding_window = inferer.inferer\n        self.assertTrue(isinstance(sliding_window, SlidingWindowInferer))\n        preprocessing = inferer.preprocessing\n        self.assertTrue(isinstance(preprocessing, Compose))\n        postprocessing = inferer.postprocessing\n        self.assertTrue(isinstance(postprocessing, Compose))\n        # test optional properties get\n        self.assertTrue(inferer.key_metric is None)\n        inferer.bundle_root = \"/workspace/data/spleen_ct_segmentation\"\n        inferer.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n        inferer.network_def = deepcopy(net)\n        inferer.inferer = deepcopy(sliding_window)\n        inferer.preprocessing = deepcopy(preprocessing)\n        inferer.postprocessing = deepcopy(postprocessing)\n        # test optional properties set\n        inferer.key_metric = \"set optional properties\"\n\n        # should initialize and parse again as changed the bundle content\n        inferer.initialize()\n        inferer.run()\n        inferer.finalize()\n        # verify inference output\n        
loader = LoadImage(image_only=True)\n        pred_file = os.path.join(self.data_dir, \"image\", \"image_seg.nii.gz\")\n        self.assertTupleEqual(loader(pred_file).shape, self.expected_shape)\n        os.remove(pred_file)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_inference_config(self, config_file):\n        override = {\n            \"network\": \"$@network_def.to(@device)\",\n            \"dataset#_target_\": \"Dataset\",\n            \"dataset#data\": [{\"image\": self.filename}],\n            \"postprocessing#transforms#2#output_postfix\": \"seg\",\n            \"output_dir\": self.data_dir,\n        }\n        # test standard MONAI model-zoo config workflow\n        inferer = ConfigWorkflow(\n            workflow_type=\"infer\",\n            config_file=config_file,\n            logging_file=os.path.join(MODULE_PATH, \"testing_data\", \"logging.conf\"),\n            **override,\n        )\n        self._test_inferer(inferer)\n\n        # test property path\n        inferer = ConfigWorkflow(\n            config_file=config_file,\n            workflow_type=\"infer\",\n            properties_path=os.path.join(MODULE_PATH, \"testing_data\", \"fl_infer_properties.json\"),\n            logging_file=os.path.join(MODULE_PATH, \"testing_data\", \"logging.conf\"),\n            **override,\n        )\n        self._test_inferer(inferer)\n        self.assertEqual(inferer.workflow_type, \"infer\")\n\n    @parameterized.expand([TEST_CASE_4])\n    def test_responsive_inference_config(self, config_file):\n        input_loader = LoadImaged(keys=\"image\")\n        output_saver = SaveImaged(keys=\"pred\", output_dir=self.data_dir, output_postfix=\"seg\")\n\n        # test standard MONAI model-zoo config workflow\n        inferer = ConfigWorkflow(\n            workflow_type=\"infer\",\n            config_file=config_file,\n            logging_file=os.path.join(MODULE_PATH, \"testing_data\", \"logging.conf\"),\n        )\n        # FIXME: temp add 
the property for test, we should add it to some formal realtime infer properties\n        inferer.add_property(name=\"dataflow\", required=True, config_id=\"dataflow\")\n\n        inferer.initialize()\n        inferer.dataflow.update(input_loader({\"image\": self.filename}))\n        inferer.run()\n        output_saver(inferer.dataflow)\n        self.assertTrue(os.path.exists(os.path.join(self.data_dir, \"image\", \"image_seg.nii.gz\")))\n\n        # bundle is instantiated and idle, just change the input for next inference\n        inferer.dataflow.clear()\n        inferer.dataflow.update(input_loader({\"image\": self.filename1}))\n        inferer.run()\n        output_saver(inferer.dataflow)\n        self.assertTrue(os.path.exists(os.path.join(self.data_dir, \"image1\", \"image1_seg.nii.gz\")))\n\n        inferer.finalize()\n\n    @parameterized.expand([TEST_CASE_3])\n    def test_train_config(self, config_file):\n        # test standard MONAI model-zoo config workflow\n        trainer = ConfigWorkflow(\n            workflow_type=\"train\",\n            config_file=config_file,\n            logging_file=os.path.join(MODULE_PATH, \"testing_data\", \"logging.conf\"),\n            init_id=\"initialize\",\n            run_id=\"run\",\n            final_id=\"finalize\",\n        )\n        # should initialize before parsing any bundle content\n        trainer.initialize()\n        # test required and optional properties\n        self.assertListEqual(trainer.check_properties(), [])\n        # test override optional properties\n        trainer.parser.update(\n            pairs={\"validate#evaluator#postprocessing\": \"$@validate#postprocessing if @val_interval > 0 else None\"}\n        )\n        trainer.initialize()\n        self.assertListEqual(trainer.check_properties(), [])\n        # test read / write the properties\n        dataset = trainer.train_dataset\n        self.assertIsInstance(dataset, Dataset)\n        inferer = trainer.train_inferer\n        
self.assertIsInstance(inferer, SimpleInferer)\n        # test optional properties get\n        self.assertIsNone(trainer.train_key_metric)\n        trainer.train_dataset = deepcopy(dataset)\n        trainer.train_inferer = deepcopy(inferer)\n        # test optional properties set\n        trainer.train_key_metric = \"set optional properties\"\n\n        # should initialize and parse again as changed the bundle content\n        trainer.initialize()\n        trainer.run()\n        trainer.finalize()\n\n    def test_non_config(self):\n        # test user defined python style workflow\n        inferer = NonConfigWorkflow(self.filename, self.data_dir)\n        self.assertEqual(inferer.meta_file, None)\n        self._test_inferer(inferer)\n\n    @parameterized.expand([TEST_CASE_NON_CONFIG_WRONG_LOG])\n    def test_non_config_wrong_log_cases(self, meta_file, logging_file, expected_error):\n        with self.assertRaisesRegex(FileNotFoundError, expected_error):\n            NonConfigWorkflow(self.filename, self.data_dir, meta_file, logging_file)\n\n    def test_pythonic_workflow(self):\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        config_file = {\"roi_size\": (64, 64, 32)}\n        meta_file = os.path.join(MODULE_PATH, \"testing_data\", \"metadata.json\")\n        property_path = os.path.join(MODULE_PATH, \"testing_data\", \"python_workflow_properties.json\")\n        workflow = PythonicWorkflowImpl(\n            workflow_type=\"infer\", config_file=config_file, meta_file=meta_file, properties_path=property_path\n        )\n        workflow.initialize()\n        # Load input data\n        input_loader = LoadImaged(keys=\"image\")\n        workflow.dataflow.update(input_loader({\"image\": self.filename}))\n        self.assertEqual(workflow.bundle_root, \".\")\n        self.assertEqual(workflow.device, device)\n        self.assertEqual(workflow.version, \"0.1.0\")\n        # check config override correctly\n        
self.assertEqual(workflow.inferer.roi_size, (64, 64, 32))\n        workflow.run()\n        # update input data and run again\n        workflow.dataflow.update(input_loader({\"image\": self.filename1}))\n        workflow.run()\n        pred = workflow.dataflow[\"pred\"]\n        self.assertEqual(pred.shape[2:], self.expected_shape)\n        self.assertEqual(pred.meta[\"filename_or_obj\"], self.filename1)\n        workflow.finalize()\n\n    def test_create_pythonic_workflow(self):\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        config_file = {\"roi_size\": (64, 64, 32)}\n        meta_file = os.path.join(MODULE_PATH, \"testing_data\", \"metadata.json\")\n        property_path = os.path.join(MODULE_PATH, \"testing_data\", \"python_workflow_properties.json\")\n        sys.path.append(MODULE_PATH)\n        workflow = create_workflow(\n            \"tests.nonconfig_workflow.PythonicWorkflowImpl\",\n            workflow_type=\"infer\",\n            config_file=config_file,\n            meta_file=meta_file,\n            properties_path=property_path,\n        )\n        # Load input data\n        input_loader = LoadImaged(keys=\"image\")\n        workflow.dataflow.update(input_loader({\"image\": self.filename}))\n        self.assertEqual(workflow.bundle_root, \".\")\n        self.assertEqual(workflow.device, device)\n        self.assertEqual(workflow.version, \"0.1.0\")\n        # check config override correctly\n        self.assertEqual(workflow.inferer.roi_size, (64, 64, 32))\n\n        # check set property override correctly\n        workflow.inferer = SlidingWindowInferer(roi_size=config_file[\"roi_size\"], sw_batch_size=1, overlap=0.5)\n        workflow.initialize()\n        self.assertEqual(workflow.inferer.overlap, 0.5)\n\n        workflow.run()\n        # update input data and run again\n        workflow.dataflow.update(input_loader({\"image\": self.filename1}))\n        workflow.run()\n        pred = 
workflow.dataflow[\"pred\"]\n        self.assertEqual(pred.shape[2:], self.expected_shape)\n        self.assertEqual(pred.meta[\"filename_or_obj\"], self.filename1)\n\n        # test add properties\n        workflow.add_property(name=\"net\", required=True, desc=\"network for the training.\")\n        self.assertIn(\"net\", workflow.properties)\n        workflow.finalize()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/bundle/test_component_locator.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom pydoc import locate\n\nfrom monai.bundle import ComponentLocator\nfrom monai.utils import optional_import\n\n_, has_ignite = optional_import(\"ignite\")\n\n\nclass TestComponentLocator(unittest.TestCase):\n\n    def test_locate(self):\n        locator = ComponentLocator(excludes=None if has_ignite else [\"monai.handlers\"])\n        # test init mapping table and get the module path of component\n        self.assertEqual(locator.get_component_module_name(\"LoadImage\"), \"monai.transforms.io.array\")\n        self.assertGreater(len(locator._components_table), 0)\n        for _, mods in locator._components_table.items():\n            for i in mods:\n                self.assertGreater(len(mods), 0)\n                # ensure we can locate all the items by `name`\n                self.assertIsNotNone(locate(i), msg=f\"can not locate target: {i}.\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/bundle/test_config_item.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom functools import partial\nfrom typing import Callable\n\nimport torch\nfrom parameterized import parameterized\n\nimport monai\nfrom monai.bundle import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem\nfrom monai.data import DataLoader, Dataset\nfrom monai.transforms import LoadImaged, RandTorchVisiond\nfrom monai.utils import min_version, optional_import\n\n_, has_tv = optional_import(\"torchvision\", \"0.8.0\", min_version)\n\nTEST_CASE_1 = [{\"lr\": 0.001}, 0.0001]\n\nTEST_CASE_2 = [{\"_target_\": \"LoadImaged\", \"keys\": [\"image\"], \"_desc_\": \"an image reader for 'image'\"}, LoadImaged]\n# test full module path\nTEST_CASE_3 = [{\"_target_\": \"monai.transforms.LoadImaged\", \"keys\": [\"image\"]}, LoadImaged]\n# test `_disabled_`\nTEST_CASE_4 = [{\"_target_\": \"LoadImaged\", \"_disabled_\": True, \"keys\": [\"image\"]}, dict]\n# test `_disabled_` with string\nTEST_CASE_5 = [{\"_target_\": \"LoadImaged\", \"_disabled_\": \"true\", \"keys\": [\"image\"]}, dict]\n# test non-monai modules and excludes\nTEST_CASE_6 = [{\"_target_\": \"torch.optim.Adam\", \"params\": torch.nn.PReLU().parameters(), \"lr\": 1e-4}, torch.optim.Adam]\nTEST_CASE_7 = [{\"_target_\": \"decollate_batch\", \"detach\": True, \"pad\": True, \"_mode_\": \"callable\"}, partial]\n# test args contains \"name\" 
field\nTEST_CASE_8 = [\n    {\"_target_\": \"RandTorchVisiond\", \"keys\": \"image\", \"name\": \"ColorJitter\", \"brightness\": 0.25},\n    RandTorchVisiond,\n]\n# test execute some function in args, test pre-imported global packages `monai`\nTEST_CASE_9 = [\"collate_fn\", \"$monai.data.list_data_collate\"]\n# test lambda function\nTEST_CASE_10 = [\"collate_fn\", \"$lambda x: monai.data.list_data_collate(x) + torch.tensor(var)\"]\n# test regular expression with reference\nTEST_CASE_11 = [\"collate_fn\", \"$var + 100\"]\n\n\nclass TestConfigItem(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1])\n    def test_item(self, test_input, expected):\n        item = ConfigItem(config=test_input)\n        conf = item.get_config()\n        conf[\"lr\"] = 0.0001\n        item.update_config(config=conf)\n        self.assertEqual(item.get_config()[\"lr\"], expected)\n\n    @parameterized.expand(\n        [TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]\n        + ([TEST_CASE_8] if has_tv else [])\n    )\n    def test_component(self, test_input, output_type):\n        locator = ComponentLocator(excludes=[\"metrics\"])\n        configer = ConfigComponent(id=\"test\", config=test_input, locator=locator)\n        ret = configer.instantiate()\n        if test_input.get(\"_disabled_\", False):\n            # test `_disabled_` works fine\n            self.assertEqual(ret, None)\n            return\n        self.assertTrue(isinstance(ret, output_type))\n        if isinstance(ret, LoadImaged):\n            self.assertEqual(ret.keys[0], \"image\")\n\n    @parameterized.expand([TEST_CASE_9, TEST_CASE_10, TEST_CASE_11])\n    def test_expression(self, id, test_input):\n        configer = ConfigExpression(id=id, config=test_input, globals={\"monai\": monai, \"torch\": torch})\n        var = 100\n        ret = configer.evaluate(globals={\"var\": var, \"monai\": monai})  # `{\"monai\": monai}` to verify the warning\n        if isinstance(ret, 
Callable):\n            self.assertTrue(isinstance(ret([torch.tensor(1), torch.tensor(2)]), torch.Tensor))\n        else:\n            # also test the `locals` for regular expressions\n            ret = configer.evaluate(locals={\"var\": var})\n            self.assertEqual(ret, 200)\n\n    def test_lazy_instantiation(self):\n        config = {\"_target_\": \"DataLoader\", \"dataset\": Dataset(data=[1, 2]), \"batch_size\": 2}\n        configer = ConfigComponent(config=config, locator=None)\n        init_config = configer.get_config()\n        # modify config content at runtime\n        init_config[\"batch_size\"] = 4\n        configer.update_config(config=init_config)\n\n        ret = configer.instantiate()\n        self.assertTrue(isinstance(ret, DataLoader))\n        self.assertEqual(ret.batch_size, 4)\n\n    @parameterized.expand([(\"$import json\", \"json\"), (\"$import json as j\", \"j\")])\n    def test_import(self, stmt, mod_name):\n        test_globals = {}\n        ConfigExpression(id=\"\", config=stmt, globals=test_globals).evaluate()\n        self.assertTrue(callable(test_globals[mod_name].dump))\n\n    @parameterized.expand(\n        [\n            (\"$from json import dump\", \"dump\"),\n            (\"$from json import dump, dumps\", \"dump\"),\n            (\"$from json import dump as jd\", \"jd\"),\n            (\"$from json import dump as jd, dumps as ds\", \"jd\"),\n        ]\n    )\n    def test_import_from(self, stmt, mod_name):\n        test_globals = {}\n        ConfigExpression(id=\"\", config=stmt, globals=test_globals).evaluate()\n        self.assertTrue(callable(test_globals[mod_name]))\n        self.assertTrue(ConfigExpression.is_import_statement(ConfigExpression(id=\"\", config=stmt).config))\n\n    @parameterized.expand(\n        [(\"$from json import dump\", True), (\"$print()\", False), (\"$import json\", True), (\"import json\", False)]\n    )\n    def test_is_import_stmt(self, stmt, expected):\n        expr = 
ConfigExpression(id=\"\", config=stmt)\n        flag = expr.is_import_statement(expr.config)\n        self.assertEqual(flag, expected)\n\n    def test_error_expr(self):\n        with self.assertRaisesRegex(RuntimeError, r\"1\\+\\[\\]\"):\n            ConfigExpression(id=\"\", config=\"$1+[]\").evaluate()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/bundle/test_config_parser.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nimport warnings\nfrom pathlib import Path\nfrom unittest import mock, skipUnless\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.bundle import ConfigParser, ReferenceResolver\nfrom monai.bundle.config_item import ConfigItem\nfrom monai.data import DataLoader, Dataset\nfrom monai.transforms import Compose, LoadImaged, RandTorchVisiond\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TimedCall\n\n_, has_tv = optional_import(\"torchvision\", \"0.8.0\", min_version)\n_, has_yaml = optional_import(\"yaml\")\n\n\n@TimedCall(seconds=100, force_quit=True)\ndef case_pdb(sarg=None):\n    config = {\"transform\": {\"_target_\": \"Compose\", \"transforms\": [], \"_debug_\": True}}\n    parser = ConfigParser(config=config)\n    parser.get_parsed_content()\n\n\n@TimedCall(seconds=100, force_quit=True)\ndef case_pdb_inst(sarg=None):\n    config = {\"transform\": {\"_target_\": \"Compose\", \"transforms\": [], \"_mode_\": \"debug\"}}\n    parser = ConfigParser(config=config)\n    return parser.transform\n\n\n# test the resolved and parsed instances\nTEST_CASE_1 = [\n    {\n        \"transform\": {\n            \"_target_\": \"Compose\",\n            \"transforms\": [\n                {\"_target_\": \"LoadImaged\", \"keys\": \"image\"},\n       
         # test relative id in `keys`\n                {\"_target_\": \"RandTorchVisiond\", \"keys\": \"@##0#keys\", \"name\": \"ColorJitter\", \"brightness\": 0.25},\n            ],\n        },\n        \"dataset\": {\"_target_\": \"Dataset\", \"data\": [1, 2], \"transform\": \"@transform\"},\n        \"dataloader\": {\n            \"_target_\": \"DataLoader\",\n            # test relative id in `dataset`\n            \"dataset\": \"@##dataset\",\n            \"batch_size\": 2,\n            \"collate_fn\": \"$monai.data.list_data_collate\",\n        },\n    },\n    [\"transform\", \"transform#transforms#0\", \"transform#transforms#1\", \"dataset\", \"dataloader\"],\n    [Compose, LoadImaged, RandTorchVisiond, Dataset, DataLoader],\n]\n\n\nclass TestClass:\n    @staticmethod\n    def compute(a, b, func=lambda x, y: x + y):\n        return func(a, b)\n\n    @classmethod\n    def cls_compute(cls, a, b, func=lambda x, y: x + y):\n        return cls.compute(a, b, func)\n\n    def __call__(self, a, b):\n        return self.compute(a, b)\n\n\nTEST_CASE_2 = [\n    {\n        \"basic_func\": \"$lambda x, y: x + y\",\n        \"static_func\": \"$TestClass.compute\",\n        \"cls_func\": \"$TestClass.cls_compute\",\n        \"lambda_static_func\": \"$lambda x, y: TestClass.compute(x, y)\",\n        \"lambda_cls_func\": \"$lambda x, y: TestClass.cls_compute(x, y)\",\n        \"compute\": {\"_target_\": \"tests.bundle.test_config_parser.TestClass.compute\", \"func\": \"@basic_func\"},\n        \"cls_compute\": {\"_target_\": \"tests.bundle.test_config_parser.TestClass.cls_compute\", \"func\": \"@basic_func\"},\n        \"call_compute\": {\"_target_\": \"tests.bundle.test_config_parser.TestClass\"},\n        \"error_func\": \"$TestClass.__call__\",\n        \"<test>\": \"$lambda x, y: x + y\",\n    }\n]\n\nTEST_CASE_3 = [\n    {\n        \"A\": 1,\n        \"B\": \"@A\",\n        \"C\": \"@#A\",\n        \"D\": {\"key\": \"@##A\", \"value1\": 2, \"value2\": \"%#value1\", 
\"value3\": [3, 4, \"@#1\", \"$100 + @#0 + @##value1\"]},\n    }\n]\n\nTEST_CASE_4 = [{\"A\": 1, \"B\": \"@A\", \"C\": \"@D\", \"E\": \"$'test' + '@F'\"}]\n\nTEST_CASE_5 = [{\"training\": {\"A\": 1, \"A_B\": 2}, \"total\": \"$@training#A + @training#A_B + 1\"}, 4]\n\nTEST_CASE_DUPLICATED_KEY_JSON = [\"\"\"{\"key\": {\"unique\": 1, \"duplicate\": 0, \"duplicate\": 4 } }\"\"\", \"json\", 1, [0, 4]]\n\nTEST_CASE_DUPLICATED_KEY_YAML = [\n    \"\"\"key:\n    unique: 1\n    duplicate: 0\n    duplicate: 4\"\"\",\n    \"yaml\",\n    1,\n    [0, 4],\n]\n\nTEST_CASE_MERGE_JSON = [\"\"\"{\"key1\": [0], \"key2\": [0] }\"\"\", \"\"\"{\"key1\": [1], \"+key2\": [4] }\"\"\", \"json\", [1], [0, 4]]\n\nTEST_CASE_MERGE_YAML = [\n    \"\"\"\n    key1: 0\n    key2: [0]\n    \"\"\",\n    \"\"\"\n    key1: 1\n    +key2: [4]\n    \"\"\",\n    \"yaml\",\n    1,\n    [0, 4],\n]\n\n\nclass TestConfigParser(unittest.TestCase):\n    def test_config_content(self):\n        test_config = {\"preprocessing\": [{\"_target_\": \"LoadImage\"}], \"dataset\": {\"_target_\": \"Dataset\"}}\n        parser = ConfigParser(config=test_config)\n        # test `get`, `set`, `__getitem__`, `__setitem__`\n        self.assertEqual(str(parser.get()), str(test_config))\n        parser.set(config=test_config)\n        self.assertListEqual(parser[\"preprocessing\"], test_config[\"preprocessing\"])\n        parser[\"dataset\"] = {\"_target_\": \"CacheDataset\"}\n        self.assertEqual(parser[\"dataset\"][\"_target_\"], \"CacheDataset\")\n        # test nested ids\n        parser[\"dataset#_target_\"] = \"Dataset\"\n        self.assertEqual(parser[\"dataset#_target_\"], \"Dataset\")\n        parser.update({\"dataset#_target_1\": \"Dataset1\"})\n        self.assertEqual(parser[\"dataset#_target_1\"], \"Dataset1\")\n        # test int id\n        parser.set([\"test1\", \"test2\", \"test3\"])\n        parser[1] = \"test4\"\n        self.assertEqual(parser[1], \"test4\")\n\n    @parameterized.expand([TEST_CASE_1])\n    
@skipUnless(has_tv, \"Requires torchvision >= 0.8.0.\")\n    def test_parse(self, config, expected_ids, output_types):\n        parser = ConfigParser(config=config, globals={\"monai\": \"monai\"})\n        # test lazy instantiation with original config content\n        parser[\"transform\"][\"transforms\"][0][\"keys\"] = \"label1\"\n        trans = parser.get_parsed_content(id=\"transform#transforms#0\")\n        self.assertEqual(trans.keys[0], \"label1\")\n        # test re-use the parsed content or not with the `lazy` option\n        self.assertEqual(trans, parser.get_parsed_content(id=\"transform#transforms#0\"))\n        self.assertEqual(trans, parser.get_parsed_content(id=\"transform#transforms#0\", lazy=True))\n        self.assertNotEqual(trans, parser.get_parsed_content(id=\"transform#transforms#0\", lazy=False))\n        # test new nested id\n        parser.set(\"fake_key\", \"transform#other_transforms#keys\", True)\n        self.assertEqual(parser.get(id=\"transform#other_transforms#keys\"), \"fake_key\")\n        # remove temp fake data\n        parser[\"transform\"].pop(\"other_transforms\")\n        # test update nested id\n        parser[\"transform#transforms#0#keys\"] = \"label2\"\n        self.assertEqual(parser.get_parsed_content(id=\"transform#transforms#0\").keys[0], \"label2\")\n\n        for id, cls in zip(expected_ids, output_types):\n            self.assertTrue(isinstance(parser.get_parsed_content(id), cls))\n        # test root content\n        root = parser.get_parsed_content(id=\"\")\n        for v, cls in zip(root.values(), [Compose, Dataset, DataLoader]):\n            self.assertTrue(isinstance(v, cls))\n        # test default value\n        self.assertEqual(parser.get_parsed_content(id=\"abc\", default=ConfigItem(12345, \"abc\")), 12345)\n        self.assertEqual(parser.get_parsed_content(id=\"abcd\", default=1), 1)\n\n    @parameterized.expand([TEST_CASE_2])\n    def test_function(self, config):\n        parser = 
ConfigParser(config=config, globals={\"TestClass\": TestClass})\n        for id in config:\n            if id in (\"compute\", \"cls_compute\"):\n                parser[f\"{id}#_mode_\"] = \"callable\"\n            func = parser.get_parsed_content(id=id)\n            self.assertIn(id, parser.ref_resolver.resolved_content)\n            if id == \"error_func\":\n                with self.assertRaises(TypeError):\n                    func(1, 2)\n                continue\n            self.assertEqual(func(1, 2), 3)\n\n    @parameterized.expand([TEST_CASE_3])\n    def test_relative_id(self, config):\n        parser = ConfigParser(config=config)\n        for id in config:\n            item = parser.get_parsed_content(id=id)\n            if isinstance(item, int):\n                self.assertEqual(item, 1)\n            if isinstance(item, dict):\n                self.assertEqual(str(item), str({\"key\": 1, \"value1\": 2, \"value2\": 2, \"value3\": [3, 4, 4, 105]}))\n\n    def test_macro_replace(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            another_file = os.path.join(tempdir, \"another.json\")\n            ConfigParser.export_config_file(config={\"E\": 4}, filepath=another_file)\n            # test macro with id, relative id, and macro in another file\n            config = {\"A\": {\"B\": 1, \"C\": 2}, \"D\": [3, \"%A#B\", \"%#0\", f\"%{another_file}#E\"]}\n            parser = ConfigParser(config=config)\n            parser.resolve_macro_and_relative_ids()\n            self.assertEqual(str(parser.get()), str({\"A\": {\"B\": 1, \"C\": 2}, \"D\": [3, 1, 3, 4]}))\n\n    @parameterized.expand([TEST_CASE_4])\n    def test_allow_missing_reference(self, config):\n        default = ReferenceResolver.allow_missing_reference\n        ReferenceResolver.allow_missing_reference = True\n        parser = ConfigParser(config=config)\n\n        for id in config:\n            item = parser.get_parsed_content(id=id)\n            if id in (\"A\", \"B\"):\n       
         self.assertEqual(item, 1)\n            elif id == \"C\":\n                self.assertEqual(item, \"@D\")\n            elif id == \"E\":\n                self.assertEqual(item, \"test@F\")\n\n        # restore the default value\n        ReferenceResolver.allow_missing_reference = default\n        with self.assertRaises(ValueError):\n            parser.parse()\n            parser.get_parsed_content(id=\"E\")\n\n    def test_list_expressions(self):\n        config = {\n            \"transform\": {\n                \"_target_\": \"Compose\",\n                \"transforms\": [{\"_target_\": \"RandScaleIntensity\", \"factors\": 0.5, \"prob\": 1.0}],\n            },\n            \"training\": [\"$monai.utils.set_determinism(seed=123)\", \"$@transform(np.asarray([1, 2]))\"],\n        }\n        parser = ConfigParser(config=config)\n        parser.get_parsed_content(\"training\", lazy=True, instantiate=True, eval_expr=True)\n        np.testing.assert_allclose(parser.get_parsed_content(\"training#1\", lazy=True), [0.7942, 1.5885], atol=1e-4)\n\n    def test_contains(self):\n        empty_parser = ConfigParser({})\n        empty_parser.parse()\n\n        parser = ConfigParser({\"value\": 1, \"entry\": \"string content\", \"array\": [1, 2]})\n        parser.parse()\n\n        with self.subTest(\"Testing empty parser\"):\n            self.assertFalse(\"something\" in empty_parser)\n        with self.assertRaises(KeyError):\n            empty_parser[\"something\"]\n        empty_parser[\"osmething\"] = \"test\"\n        with self.assertRaises(KeyError):\n            empty_parser[\"something\"]\n\n        with self.subTest(\"Testing with keys\"):\n            self.assertTrue(\"value\" in parser)\n            self.assertFalse(\"value1\" in parser)\n            self.assertTrue(\"entry\" in parser)\n            self.assertFalse(\"entr\" in parser)\n            self.assertFalse(\"array#2\" in parser)\n\n    def test_lambda_reference(self):\n        configs = {\n            
\"patch_size\": [8, 8],\n            \"transform\": {\"_target_\": \"Lambda\", \"func\": \"$lambda x: x.reshape((1, *@patch_size))\"},\n        }\n        parser = ConfigParser(config=configs)\n        trans = parser.get_parsed_content(id=\"transform\")\n        result = trans(np.ones(64))\n        self.assertTupleEqual(result.shape, (1, 8, 8))\n\n    def test_non_str_target(self):\n        configs = {\n            \"fwd\": {\"_target_\": \"$@model.forward\", \"x\": \"$torch.rand(1, 3, 256, 256)\", \"_mode_\": \"callable\"},\n            \"model\": {\"_target_\": \"monai.networks.nets.resnet.resnet18\", \"pretrained\": False, \"spatial_dims\": 2},\n        }\n        self.assertTrue(callable(ConfigParser(config=configs).fwd))\n        self.assertTupleEqual(tuple(ConfigParser(config=configs).fwd().shape), (1, 400))\n\n    def test_error_instance(self):\n        config = {\"transform\": {\"_target_\": \"Compose\", \"transforms_wrong_key\": []}}\n        parser = ConfigParser(config=config)\n        with self.assertRaises(RuntimeError):\n            parser.get_parsed_content(\"transform\", instantiate=True, eval_expr=True)\n\n    def test_pdb(self):\n        with self.assertRaisesRegex(RuntimeError, \".*bdb.BdbQuit.*\"):\n            case_pdb()\n        self.assertEqual(case_pdb_inst(), None)  # pdb.runcall without input is None\n\n    def test_get_via_attributes(self):\n        config = {\n            \"A\": {\"B\": {\"C\": 1}},\n            \"my_dims\": 2,\n            \"dims_1\": \"$@my_dims + 1\",\n            \"patch_size\": [8, 8],\n            \"transform\": {\"_target_\": \"Lambda\", \"func\": \"$lambda x: x.reshape((1, *@patch_size))\"},\n        }\n        parser = ConfigParser(config=config)\n        self.assertEqual(parser.A, {\"B\": {\"C\": 1}})\n        self.assertEqual(parser.dims_1, 3)\n\n        trans = parser.transform\n        result = trans(np.ones(64))\n        self.assertTupleEqual(result.shape, (1, 8, 8))\n\n    def test_builtin(self):\n        
config = {\"import statements\": \"$import math\", \"calc\": {\"_target_\": \"math.isclose\", \"a\": 0.001, \"b\": 0.001}}\n        self.assertEqual(ConfigParser(config).calc, True)\n\n    def test_slicing(self):\n        config = {\"test\": [1, 2, 3, 4], \"test1\": \"$@test[::]\", \"test2\": \"$@test[::-1]\", \"st\": \"aten::relu\"}\n        self.assertEqual(ConfigParser(config).test1, [1, 2, 3, 4])\n        self.assertEqual(ConfigParser(config).test2, [4, 3, 2, 1])\n        self.assertEqual(ConfigParser(config).st, \"aten::relu\")\n\n    @parameterized.expand([TEST_CASE_5])\n    def test_substring_reference(self, config, expected):\n        parser = ConfigParser(config=config)\n        self.assertEqual(parser.get_parsed_content(\"total\"), expected)\n\n    @parameterized.expand([TEST_CASE_DUPLICATED_KEY_JSON, TEST_CASE_DUPLICATED_KEY_YAML])\n    @mock.patch.dict(os.environ, {\"MONAI_FAIL_ON_DUPLICATE_CONFIG\": \"1\"})\n    @skipUnless(has_yaml, \"Requires pyyaml\")\n    def test_parse_json_raise(self, config_string, extension, _, __):\n        with tempfile.TemporaryDirectory() as tempdir:\n            config_path = Path(tempdir) / f\"config.{extension}\"\n            config_path.write_text(config_string)\n            parser = ConfigParser()\n\n            with self.assertRaises(ValueError) as context:\n                parser.read_config(config_path)\n\n            self.assertTrue(\"Duplicate key: `duplicate`\" in str(context.exception))\n\n    @parameterized.expand([TEST_CASE_DUPLICATED_KEY_JSON, TEST_CASE_DUPLICATED_KEY_YAML])\n    @skipUnless(has_yaml, \"Requires pyyaml\")\n    def test_parse_json_warn(self, config_string, extension, expected_unique_val, expected_duplicate_vals):\n        with tempfile.TemporaryDirectory() as tempdir:\n            config_path = Path(tempdir) / f\"config.{extension}\"\n            config_path.write_text(config_string)\n            parser = ConfigParser()\n\n            with warnings.catch_warnings(record=True) as w:\n           
     parser.read_config(config_path)\n            self.assertEqual(len(w), 1)\n            self.assertTrue(\"Duplicate key: `duplicate`\" in str(w[-1].message))\n\n            self.assertEqual(parser.get_parsed_content(\"key#unique\"), expected_unique_val)\n            self.assertIn(parser.get_parsed_content(\"key#duplicate\"), expected_duplicate_vals)\n\n    @parameterized.expand([TEST_CASE_MERGE_JSON, TEST_CASE_MERGE_YAML])\n    @skipUnless(has_yaml, \"Requires pyyaml\")\n    def test_load_configs(\n        self, config_string, config_string2, extension, expected_overridden_val, expected_merged_vals\n    ):\n        with tempfile.TemporaryDirectory() as tempdir:\n            config_path1 = Path(tempdir) / f\"config1.{extension}\"\n            config_path2 = Path(tempdir) / f\"config2.{extension}\"\n            config_path1.write_text(config_string)\n            config_path2.write_text(config_string2)\n\n            parser = ConfigParser.load_config_files([config_path1, config_path2])\n\n            self.assertEqual(parser[\"key1\"], expected_overridden_val)\n            self.assertEqual(parser[\"key2\"], expected_merged_vals)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/bundle/test_reference_resolver.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nimport monai\nfrom monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem\nfrom monai.bundle.reference_resolver import ReferenceResolver\nfrom monai.data import DataLoader\nfrom monai.transforms import LoadImaged, RandTorchVisiond\nfrom monai.utils import min_version, optional_import\n\n_, has_tv = optional_import(\"torchvision\", \"0.8.0\", min_version)\n\n# test instance with no dependencies\nTEST_CASE_1 = [\n    {\n        # all the recursively parsed config items\n        \"transform#1\": {\"_target_\": \"LoadImaged\", \"keys\": [\"image\"]},\n        \"transform#1#_target_\": \"LoadImaged\",\n        \"transform#1#keys\": [\"image\"],\n        \"transform#1#keys#0\": \"image\",\n    },\n    \"transform#1\",\n    LoadImaged,\n]\n# test depends on other component and executable code\nTEST_CASE_2 = [\n    {\n        # some the recursively parsed config items\n        \"dataloader\": {\"_target_\": \"DataLoader\", \"dataset\": \"@dataset\", \"collate_fn\": \"$monai.data.list_data_collate\"},\n        \"dataset\": {\"_target_\": \"Dataset\", \"data\": [1, 2]},\n        \"dataloader#_target_\": \"DataLoader\",\n        \"dataloader#dataset\": \"@dataset\",\n        \"dataloader#collate_fn\": 
\"$monai.data.list_data_collate\",\n        \"dataset#_target_\": \"Dataset\",\n        \"dataset#data\": [1, 2],\n        \"dataset#data#0\": 1,\n        \"dataset#data#1\": 2,\n    },\n    \"dataloader\",\n    DataLoader,\n]\n# test config has key `name`\nTEST_CASE_3 = [\n    {\n        # all the recursively parsed config items\n        \"transform::1\": {\"_target_\": \"RandTorchVisiond\", \"keys\": \"image\", \"name\": \"ColorJitter\", \"brightness\": 0.25},\n        \"transform#1::_target_\": \"RandTorchVisiond\",\n        \"transform::1::keys\": \"image\",\n        \"transform::1#name\": \"ColorJitter\",\n        \"transform::1::brightness\": 0.25,\n    },\n    \"transform#1\",\n    RandTorchVisiond,\n]\n\n\nclass TestReferenceResolver(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2] + ([TEST_CASE_3] if has_tv else []))\n    def test_resolve(self, configs, expected_id, output_type):\n        locator = ComponentLocator()\n        resolver = ReferenceResolver()\n        # add items to resolver\n        for k, v in configs.items():\n            k = k.replace(\"#\", \"::\")\n            if ConfigComponent.is_instantiable(v):\n                resolver.add_item(ConfigComponent(config=v, id=k, locator=locator))\n            elif ConfigExpression.is_expression(v):\n                resolver.add_item(ConfigExpression(config=v, id=k, globals={\"monai\": monai, \"torch\": torch}))\n            else:\n                resolver.add_item(ConfigItem(config=v, id=k))\n\n        result = resolver.get_resolved_content(expected_id)  # the root id is `expected_id` here\n        self.assertTrue(isinstance(result, output_type))\n\n        # test lazy instantiation\n        item = resolver.get_item(expected_id, resolve=True)\n        config = item.get_config()\n        config[\"_disabled_\"] = False\n        item.update_config(config=config)\n        if isinstance(item, ConfigComponent):\n            result = item.instantiate()\n        else:\n            
result = item.get_config()\n        self.assertTrue(isinstance(result, output_type))\n\n    def test_circular_references(self):\n        locator = ComponentLocator()\n        resolver = ReferenceResolver()\n        configs = {\"A\": \"@B\", \"B\": \"@C\", \"C\": \"@A\"}\n        for k, v in configs.items():\n            resolver.add_item(ConfigComponent(config=v, id=k, locator=locator))\n        for k in [\"A\", \"B\", \"C\"]:\n            with self.assertRaises(ValueError):\n                resolver.get_resolved_content(k)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/clang_format_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# this file is adapted from\n# github/pytorch/pytorch/blob/63d62d3e44a0a4ec09d94f30381d49b78cc5b095/tools/clang_format_utils.py\n\nfrom __future__ import annotations\n\nimport os\nimport platform\nimport stat\nimport sys\nfrom pathlib import Path\n\nfrom monai.apps.utils import download_url\n\n# String representing the host platform (e.g. Linux, Darwin).\nHOST_PLATFORM = platform.system()\n\n# MONAI directory root, derived from the location of this file.\nMONAI_ROOT = Path(__file__).resolve().parent.parent\n\n# This dictionary maps each platform to the S3 object URL for its clang-format binary.\nPLATFORM_TO_CF_URL = {\n    \"Darwin\": \"https://oss-clang-format.s3.us-east-2.amazonaws.com/mac/clang-format-mojave\",\n    \"Linux\": \"https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64\",\n}\n\n# This dictionary maps each platform to a relative path to a file containing its reference hash.\n# github/pytorch/pytorch/tree/63d62d3e44a0a4ec09d94f30381d49b78cc5b095/tools/clang_format_hash\nPLATFORM_TO_HASH = {\n    \"Darwin\": \"1485a242a96c737ba7cdd9f259114f2201accdb46d87ac7a8650b1a814cd4d4d\",\n    \"Linux\": \"e1c8b97b919541a99e0a355df5c3f9e8abebc64259dbee6f8c68e1ef90582856\",\n}\n\n# Directory and file paths for the clang-format binary.\nCLANG_FORMAT_DIR = os.path.join(MONAI_ROOT, \".clang-format-bin\")\nCLANG_FORMAT_PATH = 
os.path.join(CLANG_FORMAT_DIR, \"clang-format\")\n\n\ndef get_and_check_clang_format():\n    \"\"\"\n    Download a platform-appropriate clang-format binary if one doesn't already exist at the expected location and verify\n    that it is the right binary by checking its SHA256 hash against the expected hash.\n    \"\"\"\n    # If the host platform is not in PLATFORM_TO_HASH, it is unsupported.\n    if HOST_PLATFORM not in PLATFORM_TO_HASH:\n        print(f\"Unsupported platform: {HOST_PLATFORM}\")\n        return False\n    if HOST_PLATFORM not in PLATFORM_TO_CF_URL:\n        print(f\"Unsupported platform: {HOST_PLATFORM}\")\n        return False\n\n    try:\n        download_url(\n            PLATFORM_TO_CF_URL[HOST_PLATFORM], CLANG_FORMAT_PATH, PLATFORM_TO_HASH[HOST_PLATFORM], hash_type=\"sha256\"\n        )\n    except Exception as e:\n        print(f\"Download {CLANG_FORMAT_PATH} failed: {e}\")\n        print(f\"Please remove {CLANG_FORMAT_PATH} and retry.\")\n        return False\n\n    # Make sure the binary is executable.\n    mode = os.stat(CLANG_FORMAT_PATH).st_mode\n    mode |= stat.S_IXUSR\n    os.chmod(CLANG_FORMAT_PATH, mode)\n    print(f\"Using clang-format located at {CLANG_FORMAT_PATH}\")\n\n    return True\n\n\nif __name__ == \"__main__\":\n    ok = get_and_check_clang_format()\n    sys.exit(int(not ok))\n"
  },
  {
    "path": "tests/config/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/config/test_cv2_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\n# FIXME: test for the workaround of https://github.com/Project-MONAI/MONAI/issues/5291\nfrom monai.config.deviceconfig import print_config\nfrom tests.test_utils import skip_if_no_cuda\n\n\ndef main_worker(rank, ngpus_per_node, port):\n    dist.init_process_group(backend=\"nccl\", init_method=f\"tcp://127.0.0.1:{port}\", world_size=ngpus_per_node, rank=rank)\n    # `benchmark = True` is not compatible with openCV in PyTorch 22.09 docker for multi-gpu training\n    torch.backends.cudnn.benchmark = True\n\n    model = torch.nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, bias=True).to(rank)\n    model = torch.nn.parallel.DistributedDataParallel(\n        model, device_ids=[rank], output_device=rank, find_unused_parameters=False\n    )\n    x = torch.ones(1, 1, 12, 12, 12).to(rank)\n    with torch.autocast(\"cuda\"):\n        model(x)\n\n    if dist.is_initialized():\n        dist.destroy_process_group()\n\n\n@skip_if_no_cuda\nclass TestCV2Dist(unittest.TestCase):\n    def test_cv2_cuda_ops(self):\n        print_config()\n        ngpus_per_node = torch.cuda.device_count()\n        port = np.random.randint(10000, 20000)\n        torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, port))\n\n\nif __name__ == 
\"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/config/test_print_info.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.config import print_debug_info\n\n\nclass TestPrintInfo(unittest.TestCase):\n\n    def test_print_info(self):\n        print_debug_info()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/croppers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import Randomizable\nfrom monai.transforms.lazy.functional import apply_pending\nfrom monai.transforms.transform import MapTransform\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\n\nclass CropTest(unittest.TestCase):\n    @staticmethod\n    def get_arr(shape):\n        return np.random.randint(100, size=shape).astype(float)\n\n    def crop_test(self, input_param, input_shape, expected_shape, same_area=None):\n        base_comparison = None\n        input_image = self.get_arr(input_shape)\n\n        for im_type in TEST_NDARRAYS_ALL:\n            with self.subTest(im_type=im_type):\n                # input parameters, such as roi_start can be numpy, torch, list etc.\n                for param_type in TEST_NDARRAYS_ALL + (None,):\n                    with self.subTest(param_type=param_type):\n                        input_param_mod = deepcopy(input_param)\n                        if param_type is not None:\n                            for k in (\"roi_start\", \"roi_end\", \"roi_center\", \"roi_size\", \"roi_scale\"):\n                                if k in input_param:\n                                    input_param_mod[k] = param_type(input_param[k])\n                        im = 
im_type(input_image)\n                        cropper = self.Cropper(**input_param_mod)\n                        is_map = isinstance(cropper, MapTransform)\n                        input_data = {\"img\": im} if is_map else im\n                        result = cropper(input_data)\n                        out_im = result[\"img\"] if is_map else result\n                        self.assertIsInstance(out_im, MetaTensor)\n                        self.assertTupleEqual(out_im.shape, expected_shape)\n                        if same_area is not None:\n                            assert_allclose(out_im, im[same_area], type_test=False)\n                        # check result is the same regardless of input type\n                        if base_comparison is None:\n                            base_comparison = out_im\n                        else:\n                            assert_allclose(out_im, base_comparison)\n\n                        # test inverse\n                        inv = cropper.inverse(result)\n                        inv_im = inv[\"img\"] if is_map else inv\n                        self.assertIsInstance(inv_im, MetaTensor)\n                        if same_area is not None:\n                            assert_allclose(inv_im[same_area], im[same_area], type_test=False)\n                        self.assertEqual(inv_im.applied_operations, [])\n\n    def crop_test_value(self, input_param, input_arr, expected_array):\n        cropper = self.Cropper(**input_param)\n        is_map = isinstance(cropper, MapTransform)\n        for im_type in TEST_NDARRAYS_ALL:\n            with self.subTest(im_type=im_type):\n                im = im_type(input_arr)\n                input_data = {\"img\": im} if is_map else im\n                result = self.Cropper(**input_param)(input_data)\n                out_im = result[\"img\"] if is_map else result\n                self.assertIsInstance(out_im, MetaTensor)\n                assert_allclose(out_im, expected_array, 
type_test=False)\n\n    def multi_inverse(self, input_shape, init_params):\n        input_data = np.arange(np.prod(input_shape)).reshape(*input_shape) + 1\n        xform = self.Cropper(**init_params)\n        xform.set_random_state(1234)\n        out = xform(input_data)\n        if \"num_samples\" in init_params:\n            self.assertEqual(len(out), init_params[\"num_samples\"])\n        inv = xform.inverse(out)\n        self.assertIsInstance(inv, MetaTensor)\n        self.assertEqual(inv.applied_operations, [])\n        self.assertTrue(\"patch_index\" not in inv.meta)\n        self.assertTupleEqual(inv.shape, input_shape)\n        inv_np = inv.numpy()\n\n        # get list of all numbers that exist inside the crops\n        uniques = set()\n        for o in out:\n            uniques.update(set(o.flatten().tolist()))\n\n        # make sure that\n        for i in uniques:\n            a = np.where(input_data == i)\n            b = np.where(inv_np == i)\n            self.assertTupleEqual(a, b)\n        # there should be as many zeros as elements missing from uniques\n        missing = input_data.size - len(uniques)\n        self.assertEqual((inv_np == 0).sum(), missing)\n\n    def crop_test_pending_ops(self, input_param, input_shape, align_corners=False):\n        crop_fn = self.Cropper(**input_param)\n        data = self.get_arr(input_shape)\n        is_map = isinstance(crop_fn, MapTransform)\n        im = MetaTensor(data, meta={\"a\": \"b\", \"affine\": np.eye(len(input_shape))})\n        input_data = {\"img\": im} if is_map else im\n        # non-lazy\n        result_non_lazy = crop_fn(input_data)\n        expected = result_non_lazy[\"img\"] if is_map else result_non_lazy\n        self.assertIsInstance(expected, MetaTensor)\n        # lazy\n        crop_fn.lazy = True\n        pending_result = crop_fn(input_data)\n        pending_result = pending_result[\"img\"] if is_map else pending_result\n        self.assertIsInstance(pending_result, MetaTensor)\n        
assert_allclose(pending_result.peek_pending_affine(), expected.affine)\n        assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])\n        # only support nearest\n        overrides = {\"mode\": \"nearest\", \"align_corners\": align_corners}\n        result = apply_pending(pending_result, overrides=overrides)[0]\n        # compare\n        assert_allclose(result, expected, rtol=1e-5)\n        if isinstance(result, MetaTensor) and not isinstance(crop_fn, MapTransform):\n            crop_fn.lazy = False\n            inverted = crop_fn.inverse(result)\n            self.assertTrue((not inverted.applied_operations) and (not inverted.pending_operations))\n            self.assertEqual(inverted.shape, im.shape)\n\n    def crop_test_combine_ops(self, funcs, input_shape):\n        _funcs = []\n        for func in funcs:\n            for _func, _params in func.items():\n                _funcs.append(_func(**_params))\n        is_map = isinstance(_funcs[0], MapTransform)\n        data = self.get_arr(input_shape)\n        im = MetaTensor(data, meta={\"a\": \"b\", \"affine\": np.eye(len(input_shape))})\n        input_data = {\"img\": im} if is_map else im\n\n        # non-lazy\n        non_lazy_result = input_data\n        for _func in _funcs:\n            if isinstance(_func, Randomizable):\n                _func.set_random_state(seed=123)\n            non_lazy_result = _func(non_lazy_result)\n        expected = non_lazy_result[\"img\"] if is_map else non_lazy_result\n        self.assertIsInstance(expected, MetaTensor)\n\n        # lazy\n        pending_result = input_data\n        for _func in _funcs:\n            _func.lazy = True\n            if isinstance(_func, Randomizable):\n                _func.set_random_state(seed=123)\n            pending_result = _func(pending_result)\n        pending_result = pending_result[\"img\"] if is_map else pending_result\n        self.assertIsInstance(pending_result, MetaTensor)\n        
assert_allclose(pending_result.peek_pending_affine(), expected.affine)\n        assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])\n        # TODO: mode=\"bilinear\" may report error\n        overrides = {\"mode\": \"nearest\", \"align_corners\": False}\n        result = apply_pending(pending_result, overrides=overrides)[0]\n\n        # compare\n        assert_allclose(result, expected, rtol=1e-5)\n"
  },
  {
    "path": "tests/data/meta_tensor/test_meta_tensor.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport io\nimport os\nimport random\nimport string\nimport tempfile\nimport unittest\nfrom copy import deepcopy\nfrom multiprocessing.reduction import ForkingPickler\n\nimport numpy as np\nimport torch\nimport torch.multiprocessing\nfrom parameterized import parameterized\n\nfrom monai import config\nfrom monai.data import DataLoader, Dataset\nfrom monai.data.meta_obj import get_track_meta, set_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import decollate_batch, list_data_collate\nfrom monai.transforms import BorderPadd, Compose, DivisiblePadd, FromMetaTensord, ToMetaTensord\nfrom monai.utils.enums import PostFix\nfrom tests.test_utils import TEST_DEVICES, assert_allclose, dict_product, skip_if_no_cuda\n\nDTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32], [None]]\n\n# Replace nested loops with dict_product\n\nTESTS = [(*params[\"device\"], *params[\"dtype\"]) for params in dict_product(device=TEST_DEVICES, dtype=DTYPES)]\n\n\ndef rand_string(min_len=5, max_len=10):\n    str_size = random.randint(min_len, max_len)\n    chars = string.ascii_letters + string.punctuation\n    return \"\".join(random.choice(chars) for _ in range(str_size))\n\n\nclass TestMetaTensor(unittest.TestCase):\n    @staticmethod\n    def get_im(shape=None, dtype=None, device=None):\n        
if shape is None:\n            shape = (1, 10, 8)\n        affine = torch.randint(0, 10, (4, 4))\n        meta = {\"fname\": rand_string()}\n        t = torch.rand(shape)\n        if dtype is not None:\n            t = t.to(dtype)\n        if device is not None:\n            t = t.to(device)\n        m = MetaTensor(t.clone(), affine, meta)\n        return m, t\n\n    def check_ids(self, a, b, should_match):\n        comp = self.assertEqual if should_match else self.assertNotEqual\n        comp(id(a), id(b))\n\n    def check_meta(self, a: MetaTensor, b: MetaTensor) -> None:\n        self.assertEqual(a.is_batch, b.is_batch)\n        meta_a, meta_b = a.meta, b.meta\n        # need to split affine from rest of metadata\n        aff_a = meta_a.get(\"affine\", None)\n        aff_b = meta_b.get(\"affine\", None)\n        assert_allclose(aff_a, aff_b)\n        meta_a = {k: v for k, v in meta_a.items() if k != \"affine\"}\n        meta_b = {k: v for k, v in meta_b.items() if k != \"affine\"}\n        self.assertEqual(meta_a, meta_b)\n\n    def check(\n        self,\n        out: torch.Tensor,\n        orig: torch.Tensor,\n        *,\n        shape: bool = True,\n        vals: bool = True,\n        ids: bool = True,\n        device: str | torch.device | None = None,\n        meta: bool = True,\n        check_ids: bool = True,\n        **kwargs,\n    ):\n        if device is None:\n            device = orig.device\n\n        # check the image\n        self.assertIsInstance(out, type(orig))\n        if shape:\n            assert_allclose(torch.as_tensor(out.shape), torch.as_tensor(orig.shape))\n        if vals:\n            assert_allclose(out, orig, **kwargs)\n        if check_ids:\n            self.check_ids(out, orig, ids)\n        self.assertTrue(str(device) in str(out.device))\n\n        # check meta and affine are equal and affine is on correct device\n        if isinstance(orig, MetaTensor) and isinstance(out, MetaTensor) and meta:\n            self.check_meta(orig, 
out)\n            if check_ids:\n                self.check_ids(out.meta, orig.meta, ids)\n\n    @parameterized.expand(TESTS)\n    def test_as_tensor(self, device, dtype):\n        m, t = self.get_im(device=device, dtype=dtype)\n        t2 = m.as_tensor()\n        self.assertIsInstance(t2, torch.Tensor)\n        self.assertNotIsInstance(t2, MetaTensor)\n        self.assertIsInstance(m, MetaTensor)\n        self.check(t, t2, ids=False)\n\n    def test_as_dict(self):\n        m, _ = self.get_im()\n        m_dict = m.as_dict(\"im\")\n        im, meta = m_dict[\"im\"], m_dict[PostFix.meta(\"im\")]\n        affine = meta.pop(\"affine\")\n        m2 = MetaTensor(im, affine, meta)\n        self.check(m2, m, check_ids=False)\n\n    @parameterized.expand(TESTS)\n    def test_constructor(self, device, dtype):\n        m, t = self.get_im(device=device, dtype=dtype)\n        # construct from pre-existing\n        m1 = MetaTensor(m.clone())\n        self.check(m, m1, ids=False, meta=False)\n        # meta already has affine\n        m2 = MetaTensor(t.clone(), meta=m.meta)\n        self.check(m, m2, ids=False, meta=False)\n        # meta dosen't have affine\n        affine = m.meta.pop(\"affine\")\n        m3 = MetaTensor(t.clone(), affine=affine, meta=m.meta)\n        self.check(m, m3, ids=False, meta=False)\n\n    @parameterized.expand(TESTS)\n    @skip_if_no_cuda\n    def test_to_cuda(self, device, dtype):\n        \"\"\"Test `to`, `cpu` and `cuda`. 
For `to`, check args and kwargs.\"\"\"\n        orig, _ = self.get_im(device=device, dtype=dtype)\n        m = orig.clone()\n        m = m.to(\"cuda\")\n        self.check(m, orig, ids=False, device=\"cuda\")\n        m = m.cpu()\n        self.check(m, orig, ids=False, device=\"cpu\")\n        m = m.cuda()\n        self.check(m, orig, ids=False, device=\"cuda\")\n        m = m.to(\"cpu\")\n        self.check(m, orig, ids=False, device=\"cpu\")\n        m = m.to(device=\"cuda\")\n        self.check(m, orig, ids=False, device=\"cuda\")\n        m = m.to(device=\"cpu\")\n        self.check(m, orig, ids=False, device=\"cpu\")\n\n    @skip_if_no_cuda\n    def test_affine_device(self):\n        m, _ = self.get_im()  # device=\"cuda\")\n        m.affine = torch.eye(4)\n        self.assertTrue(\"cpu\" in str(m.affine.device))\n\n    @parameterized.expand(TESTS)\n    def test_copy(self, device, dtype):\n        m, _ = self.get_im(device=device, dtype=dtype)\n        # shallow copy\n        a = m\n        self.check(a, m, ids=True)\n        # deepcopy\n        a = deepcopy(m)\n        self.check(a, m, ids=False)\n        # clone\n        a = m.clone(memory_format=torch.preserve_format)\n        a = m.clone()\n        self.check(a, m, ids=False)\n        a = MetaTensor([[]], device=device, dtype=dtype)\n        self.check(a, deepcopy(a), ids=False)\n\n    @parameterized.expand(TESTS)\n    def test_add(self, device, dtype):\n        m1, t1 = self.get_im(device=device, dtype=dtype)\n        m2, t2 = self.get_im(device=device, dtype=dtype)\n        self.check(m1 + m2, t1 + t2, ids=False)\n        self.check(torch.add(m1, m2), t1 + t2, ids=False)\n        self.check(torch.add(input=m1, other=m2), t1 + t2, ids=False)\n        self.check(torch.add(m1, other=m2), t1 + t2, ids=False)\n        m3 = deepcopy(m2)\n        t3 = deepcopy(t2)\n        m3 += 3\n        t3 += 3\n        self.check(m3, t3, ids=False)\n        # check torch.Tensor+MetaTensor and MetaTensor+torch.Tensor\n       
 self.check(torch.add(m1, t2), t1 + t2, ids=False)\n        self.check(torch.add(t2, m1), t1 + t2, ids=False)\n\n    @parameterized.expand(TEST_DEVICES)\n    def test_conv(self, device):\n        im, _ = self.get_im((1, 3, 10, 8, 12), device=device)\n        conv = torch.nn.Conv3d(im.shape[1], 5, 3)\n        conv.to(device)\n        out = conv(im)\n        self.check(out, im, shape=False, vals=False, ids=False)\n\n    @parameterized.expand(TESTS)\n    def test_stack(self, device, dtype):\n        numel = 3\n        ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)]\n        stacked = torch.stack(ims)\n        self.assertIsInstance(stacked, MetaTensor)\n        orig_affine = ims[0].meta.pop(\"affine\")\n        stacked_affine = stacked.meta.pop(\"affine\")\n        assert_allclose(orig_affine, stacked_affine)\n        self.assertEqual(stacked.meta, ims[0].meta)\n\n    def test_get_set_meta_fns(self):\n        set_track_meta(False)\n        self.assertFalse(get_track_meta())\n        set_track_meta(True)\n        self.assertTrue(get_track_meta())\n\n    @parameterized.expand(TEST_DEVICES)\n    def test_torchscript(self, device):\n        shape = (1, 3, 10, 8)\n        im, _ = self.get_im(shape, device=device)\n        conv = torch.nn.Conv2d(im.shape[1], 5, 3)\n        conv.to(device)\n        im_conv = conv(im)\n        traced_fn = torch.jit.trace(conv, im.as_tensor())\n        # save it, load it, use it\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            fname = os.path.join(tmp_dir, \"im.pt\")\n            torch.jit.save(traced_fn, f=fname)\n            traced_fn = torch.jit.load(fname)\n            out = traced_fn(im)\n            self.assertIsInstance(out, torch.Tensor)\n        self.check(out, im_conv, ids=False)\n\n    def test_pickling(self):\n        m, _ = self.get_im()\n        with tempfile.TemporaryDirectory() as tmp_dir:\n            fname = os.path.join(tmp_dir, \"im.pt\")\n            torch.save(m, fname)\n        
    m2 = torch.load(fname, weights_only=False)\n        self.check(m2, m, ids=False)\n\n    @skip_if_no_cuda\n    def test_amp(self):\n        shape = (1, 3, 10, 8)\n        device = \"cuda\"\n        im, _ = self.get_im(shape, device=device)\n        conv = torch.nn.Conv2d(im.shape[1], 5, 3)\n        conv.to(device)\n        im_conv = conv(im)\n        with torch.autocast(\"cuda\"):\n            im_conv2 = conv(im)\n        self.check(im_conv2, im_conv, ids=False, rtol=1e-2, atol=1e-2)\n\n    def test_out(self):\n        \"\"\"Test when `out` is given as an argument.\"\"\"\n        m1, _ = self.get_im()\n        m2, _ = self.get_im()\n        m3, _ = self.get_im()\n        torch.add(m2, m3, out=m1)\n        m1_add = m2 + m3\n\n        assert_allclose(m1, m1_add)\n        # self.check_meta(m1, m2)  # meta is from first input tensor\n\n    @parameterized.expand(TESTS)\n    def test_collate(self, device, dtype):\n        numel = 3\n        ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)]\n        ims = [MetaTensor(im, applied_operations=[f\"t{i}\"]) for i, im in enumerate(ims)]\n        collated = list_data_collate(ims)\n        # tensor\n        self.assertIsInstance(collated, MetaTensor)\n        expected_shape = (numel,) + tuple(ims[0].shape)\n        self.assertTupleEqual(tuple(collated.shape), expected_shape)\n        for i, im in enumerate(ims):\n            self.check(im, ims[i], ids=True)\n        # affine\n        self.assertIsInstance(collated.affine, torch.Tensor)\n        expected_shape = (numel,) + tuple(ims[0].affine.shape)\n        self.assertTupleEqual(tuple(collated.affine.shape), expected_shape)\n        self.assertEqual(len(collated.applied_operations), numel)\n\n    @parameterized.expand(TESTS)\n    def test_dataset(self, device, dtype):\n        ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(4)]\n        ds = Dataset(ims)\n        for i, im in enumerate(ds):\n            self.check(im, ims[i], 
ids=True)\n\n    @parameterized.expand(DTYPES)\n    def test_dataloader(self, dtype):\n        batch_size = 5\n        ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)]\n        ims = [MetaTensor(im, applied_operations=[f\"t{i}\"]) for i, im in enumerate(ims)]\n        ds = Dataset(ims)\n        im_shape = tuple(ims[0].shape)\n        affine_shape = tuple(ims[0].affine.shape)\n        expected_im_shape = (batch_size,) + im_shape\n        expected_affine_shape = (batch_size,) + affine_shape\n        dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size)\n        for batch in dl:\n            self.assertIsInstance(batch, MetaTensor)\n            self.assertTupleEqual(tuple(batch.shape), expected_im_shape)\n            self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape)\n            self.assertEqual(len(batch.applied_operations), batch_size)\n\n    def test_indexing(self):\n        \"\"\"\n        Check the metadata is returned in the expected format depending on whether\n        the input `MetaTensor` is a batch of data or not.\n        \"\"\"\n        ims = [self.get_im()[0] for _ in range(5)]\n        data = list_data_collate(ims)\n\n        # check that when using non-batch data, metadata is copied wholly when indexing\n        # or iterating across data.\n        im = ims[0]\n        self.check_meta(im[0], im)\n        self.check_meta(next(iter(im)), im)\n\n        self.assertEqual(im[None].shape, (1, 1, 10, 8))\n        self.assertEqual(data[None].shape, (1, 5, 1, 10, 8))\n\n        # index\n        d = data[0]\n        self.check(d, ims[0], ids=False)\n\n        # iter\n        d = next(iter(data))\n        self.check(d, ims[0], ids=False)\n\n        # complex indexing\n\n        # `is_batch==True`, should have subset of image and metadata.\n        d = data[1:3]\n        self.check(d, list_data_collate(ims[1:3]), ids=False)\n\n        # is_batch==True, should have subset of image and same metadata as `[1:3]`.\n  
      d = data[1:3, 0]\n        self.check(d, list_data_collate([i[0] for i in ims[1:3]]), ids=False)\n\n        # `is_batch==False`, should have first metadata and subset of first image.\n        d = data[0, 0]\n        self.check(d, ims[0][0], ids=False)\n        self.assertEqual(d.applied_operations, ims[0][0].applied_operations)\n\n        # `is_batch==True`, should have all metadata and subset of all images.\n        d = data[:, 0]\n        self.check(d, list_data_collate([i[0] for i in ims]), ids=False)\n\n        # `is_batch==True`, should have all metadata and subset of all images.\n        d = data[..., -1]\n        self.check(d, list_data_collate([i[..., -1] for i in ims]), ids=False)\n\n        # `is_batch==False`, tuple split along batch dim. Should have individual\n        # metadata.\n        d = data.unbind(0)\n        self.assertIsInstance(d, tuple)\n        self.assertEqual(len(d), len(ims))\n        for _d, _im in zip(d, ims):\n            self.check(_d, _im, ids=False)\n\n        # `is_batch==False`, tuple split along batch dim. Should have individual\n        # metadata.\n        d = data.unbind(dim=0)\n        self.assertIsInstance(d, tuple)\n        self.assertEqual(len(d), len(ims))\n        for _d, _im in zip(d, ims):\n            self.check(_d, _im, ids=False)\n            self.assertEqual(_d.applied_operations, _im.applied_operations)\n\n        # `is_batch==True`, tuple split along non-batch dim. Should have all metadata.\n        d = data.unbind(-1)\n        self.assertIsInstance(d, tuple)\n        self.assertEqual(len(d), ims[0].shape[-1])\n        for _d in d:\n            self.check_meta(_d, data)\n\n        # `is_batch==True`, tuple split along non-batch dim. 
Should have all metadata.\n        d = data.unbind(dim=-1)\n        self.assertIsInstance(d, tuple)\n        self.assertEqual(len(d), ims[0].shape[-1])\n        for _d in d:\n            self.check_meta(_d, data)\n\n    def test_slicing(self):\n        x = MetaTensor(np.zeros((10, 3, 4)))\n        self.assertEqual(x[slice(4, 1)].shape[0], 0)\n        x.is_batch = True\n        with self.assertRaises(ValueError):\n            x[slice(0, 8)]\n        x = MetaTensor(np.zeros((3, 3, 4)))\n        x.is_batch = True\n        self.assertEqual(x[torch.tensor([True, False, True])].shape, (2, 3, 4))\n        self.assertEqual(x[[True, False, True]].shape, (2, 3, 4))\n\n    @parameterized.expand(DTYPES)\n    def test_decollate(self, dtype):\n        batch_size = 3\n        ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)]\n        ds = Dataset(ims)\n        dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size)\n        batch = next(iter(dl))\n        decollated = decollate_batch(batch)\n        self.assertIsInstance(decollated, list)\n        self.assertEqual(len(decollated), batch_size)\n        for elem, im in zip(decollated, ims):\n            self.assertIsInstance(elem, MetaTensor)\n            self.check(elem, im, ids=False)\n\n    def test_str(self):\n        t = MetaTensor([1.0], affine=torch.tensor(1), meta={\"fname\": \"filename\"})\n        self.assertEqual(str(t), \"metatensor([1.])\")\n        self.assertEqual(t.__repr__(), \"metatensor([1.])\")\n        self.assertEqual(f\"{t[0]:.2f}\", \"1.00\")\n\n    def test_shape(self):\n        s = MetaTensor([1])\n        self.assertEqual(s.shape, torch.Size([1]))\n        self.assertEqual(s.size(), torch.Size([1]))\n        self.assertEqual(s.size(0), 1)\n\n    def test_astype(self):\n        t = MetaTensor([1.0], affine=torch.tensor(1), meta={\"fname\": \"filename\"})\n        for np_types in (\"float32\", \"np.float32\", \"numpy.float32\", np.float32, float, \"int\", np.uint16):\n            
self.assertIsInstance(t.astype(np_types), np.ndarray)\n        for pt_types in (\"torch.float\", torch.float, \"torch.float64\"):\n            self.assertIsInstance(t.astype(pt_types), torch.Tensor)\n        self.assertIsInstance(t.astype(\"torch.float\", device=\"cpu\"), torch.Tensor)\n\n    def test_transforms(self):\n        key = \"im\"\n        _, im = self.get_im()\n        tr = Compose([ToMetaTensord(key), BorderPadd(key, 1), DivisiblePadd(key, 16), FromMetaTensord(key)])\n        num_tr = len(tr.transforms)\n        data = {key: im, PostFix.meta(key): {\"affine\": torch.eye(4)}}\n\n        # apply one at a time\n        for i, _tr in enumerate(tr.transforms):\n            data = _tr(data)\n            is_meta = isinstance(_tr, (ToMetaTensord, BorderPadd, DivisiblePadd))\n            if is_meta:\n                self.assertEqual(len(data), 1 if not config.USE_META_DICT else 2)  # im, im_transforms, compatibility\n                self.assertIsInstance(data[key], MetaTensor)\n                n_applied = len(data[key].applied_operations)\n            else:\n                self.assertEqual(len(data), 3)  # im, im_meta_dict, im_transforms\n                self.assertIsInstance(data[key], torch.Tensor)\n                self.assertNotIsInstance(data[key], MetaTensor)\n                n_applied = len(data[PostFix.transforms(key)])\n\n            self.assertEqual(n_applied, i + 1)\n\n        # inverse one at a time\n        for i, _tr in enumerate(tr.transforms[::-1]):\n            data = _tr.inverse(data)\n            is_meta = isinstance(_tr, (FromMetaTensord, BorderPadd, DivisiblePadd))\n            if is_meta:\n                self.assertEqual(len(data), 1)  # im\n                self.assertIsInstance(data[key], MetaTensor)\n                n_applied = len(data[key].applied_operations)\n            else:\n                self.assertEqual(len(data), 3)  # im, im_meta_dict, im_transforms\n                self.assertIsInstance(data[key], torch.Tensor)\n             
   self.assertNotIsInstance(data[key], MetaTensor)\n                n_applied = len(data[PostFix.transforms(key)])\n\n            self.assertEqual(n_applied, num_tr - i - 1)\n\n        # apply all in one go\n        data = tr({key: im, PostFix.meta(key): {\"affine\": torch.eye(4)}})\n        self.assertEqual(len(data), 3)  # im, im_meta_dict, im_transforms\n        self.assertIsInstance(data[key], torch.Tensor)\n        self.assertNotIsInstance(data[key], MetaTensor)\n        n_applied = len(data[PostFix.transforms(key)])\n        self.assertEqual(n_applied, num_tr)\n\n        # inverse all in one go\n        data = tr.inverse(data)\n        self.assertEqual(len(data), 3)  # im, im_meta_dict, im_transforms\n        self.assertIsInstance(data[key], torch.Tensor)\n        self.assertNotIsInstance(data[key], MetaTensor)\n        n_applied = len(data[PostFix.transforms(key)])\n        self.assertEqual(n_applied, 0)\n\n    def test_construct_with_pre_applied_transforms(self):\n        key = \"im\"\n        _, im = self.get_im()\n        tr = Compose([BorderPadd(key, 1), DivisiblePadd(key, 16)])\n        data = tr({key: im})\n        m = MetaTensor(im, applied_operations=data[\"im\"].applied_operations)\n        self.assertEqual(len(m.applied_operations), len(tr.transforms))\n\n    def test_pending_ops(self):\n        m, _ = self.get_im()\n        self.assertEqual(m.pending_operations, [])\n        self.assertEqual(m.peek_pending_shape(), (10, 8))\n        self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)\n        self.assertTrue(m.peek_pending_rank() >= 1)\n        m.push_pending_operation({})\n        self.assertEqual(m.peek_pending_shape(), (10, 8))\n        self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)\n        self.assertTrue(m.peek_pending_rank() >= 1)\n\n    @parameterized.expand(TESTS)\n    def test_multiprocessing(self, device=None, dtype=None):\n        \"\"\"multiprocessing sharing with 'device' and 'dtype'\"\"\"\n        buf = 
io.BytesIO()\n        t = MetaTensor([0, 0] if dtype in (torch.int32, torch.int64) else [0.0, 0.0], device=device, dtype=dtype)\n        t.is_batch = True\n        if t.is_cuda:\n            with self.assertRaises(NotImplementedError):\n                ForkingPickler(buf).dump(t)\n            return\n        ForkingPickler(buf).dump(t)\n        obj = ForkingPickler.loads(buf.getvalue())\n        self.assertIsInstance(obj, MetaTensor)\n        assert_allclose(obj.as_tensor(), t)\n        assert_allclose(obj.is_batch, True)\n\n    @parameterized.expand(TESTS)\n    def test_array_function(self, device=\"cpu\", dtype=float):\n        a = np.random.RandomState().randn(100, 100)\n        b = MetaTensor(a, device=device)\n        assert_allclose(np.sum(a), np.sum(b))\n        assert_allclose(np.sum(a, axis=1), np.sum(b, axis=1))\n        assert_allclose(np.linalg.qr(a), np.linalg.qr(b))\n        c = MetaTensor(\n            [1, 2, 3] if dtype in (torch.int32, torch.int64) else [1.0, 2.0, 3.0], device=device, dtype=dtype\n        )\n        assert_allclose(np.argwhere(c == 1.0).astype(int).tolist(), [[0]])\n        assert_allclose(np.concatenate([c, c]), np.asarray([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))\n        assert_allclose(c > np.asarray([1.0, 1.0, 1.0]), np.asarray([False, True, True]))\n        assert_allclose(\n            c > torch.as_tensor([1.0, 1.0, 1.0], device=device), torch.as_tensor([False, True, True], device=device)\n        )\n\n    @parameterized.expand(TESTS)\n    def test_numpy(self, device=None, dtype=None):\n        \"\"\"device, dtype\"\"\"\n        t = MetaTensor([0 if dtype in (torch.int32, torch.int64) else 0.0], device=device, dtype=dtype)\n        self.assertIsInstance(t, MetaTensor)\n        assert_allclose(t.array, np.asarray([0.0]))\n        t.array = np.asarray([1.0])\n        self.check_meta(t, MetaTensor([1.0]))\n        assert_allclose(t.as_tensor(), torch.as_tensor([1.0]))\n        t.array = [2.0]\n        self.check_meta(t, 
MetaTensor([2.0]))\n        assert_allclose(t.as_tensor(), torch.as_tensor([2.0]))\n        if not t.is_cuda:\n            t.array[0] = torch.as_tensor(3 if dtype in (torch.int32, torch.int64) else 3.0, device=device, dtype=dtype)\n            self.check_meta(t, MetaTensor([3.0]))\n            assert_allclose(t.as_tensor(), torch.as_tensor([3.0]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/meta_tensor/test_to_from_meta_tensord.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport random\nimport string\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai import config\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import FromMetaTensord, ToMetaTensord\nfrom monai.utils.enums import PostFix\nfrom tests.test_utils import TEST_DEVICES, assert_allclose\n\nDTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]]\nTESTS = []\nfor _device in TEST_DEVICES:\n    for _dtype in DTYPES:\n        for _data_type in (\"tensor\", \"numpy\"):\n            TESTS.append((*_device, *_dtype, _data_type))\n\n\ndef rand_string(min_len=5, max_len=10):\n    str_size = random.randint(min_len, max_len)\n    chars = string.ascii_letters + string.punctuation\n    return \"\".join(random.choice(chars) for _ in range(str_size))\n\n\n@unittest.skipIf(config.USE_META_DICT, \"skipping not metatensor\")\nclass TestToFromMetaTensord(unittest.TestCase):\n    @staticmethod\n    def get_im(shape=None, dtype=None, device=None):\n        if shape is None:\n            shape = (1, 10, 8)\n        affine = torch.randint(0, 10, (4, 4))\n        meta = {\"fname\": rand_string()}\n        t = torch.rand(shape)\n        if dtype is not None:\n            t = t.to(dtype)\n        if device is not None:\n            
t = t.to(device)\n        m = MetaTensor(t.clone(), affine, meta)\n        return m\n\n    def check_ids(self, a, b, should_match):\n        comp = self.assertEqual if should_match else self.assertNotEqual\n        comp(id(a), id(b))\n\n    def check(\n        self,\n        out: torch.Tensor,\n        orig: torch.Tensor,\n        *,\n        shape: bool = True,\n        vals: bool = True,\n        ids: bool = True,\n        device: str | torch.device | None = None,\n        meta: bool = True,\n        check_ids: bool = False,\n        **kwargs,\n    ):\n        if device is None:\n            device = orig.device\n\n        # check the image\n        self.assertIsInstance(out, type(orig))\n        if shape:\n            assert_allclose(torch.as_tensor(out.shape), torch.as_tensor(orig.shape))\n        if vals:\n            assert_allclose(out, orig, **kwargs)\n        if check_ids:\n            self.check_ids(out, orig, ids)\n        self.assertTrue(str(device) in str(out.device))\n\n        # check meta and affine are equal and affine is on correct device\n        if isinstance(orig, MetaTensor) and isinstance(out, MetaTensor) and meta:\n            orig_meta_no_affine = deepcopy(orig.meta)\n            del orig_meta_no_affine[\"affine\"]\n            out_meta_no_affine = deepcopy(out.meta)\n            del out_meta_no_affine[\"affine\"]\n            self.assertEqual(orig_meta_no_affine, out_meta_no_affine)\n            assert_allclose(out.affine, orig.affine)\n            if check_ids:\n                self.check_ids(out.affine, orig.affine, ids)\n                self.check_ids(out.meta, orig.meta, ids)\n\n    @parameterized.expand(TESTS)\n    def test_from_to_meta_tensord(self, device, dtype, data_type=\"tensor\"):\n        m1 = self.get_im(device=device, dtype=dtype)\n        m2 = self.get_im(device=device, dtype=dtype)\n        m3 = self.get_im(device=device, dtype=dtype)\n        d_metas = {\"m1\": m1, \"m2\": m2, \"m3\": m3}\n        m1_meta = {k: v for k, v 
in m1.meta.items() if k != \"affine\"}\n        m1_aff = m1.affine\n\n        # FROM -> forward\n        t_from_meta = FromMetaTensord([\"m1\", \"m2\"], data_type=data_type)\n        d_dict = t_from_meta(d_metas)\n\n        self.assertEqual(\n            sorted(d_dict.keys()),\n            [\n                \"m1\",\n                PostFix.meta(\"m1\"),\n                PostFix.transforms(\"m1\"),\n                \"m2\",\n                PostFix.meta(\"m2\"),\n                PostFix.transforms(\"m2\"),\n                \"m3\",\n            ],\n        )\n        self.check(d_dict[\"m3\"], m3, ids=True)  # unchanged\n        if data_type == \"tensor\":\n            self.check(d_dict[\"m1\"], m1.as_tensor(), ids=False)\n        else:\n            self.assertIsInstance(d_dict[\"m1\"], np.ndarray)\n        meta_out = {k: v for k, v in d_dict[\"m1_meta_dict\"].items() if k != \"affine\"}\n        aff_out = d_dict[\"m1_meta_dict\"][\"affine\"]\n        self.check(aff_out, m1_aff, ids=False)\n        self.assertEqual(meta_out, m1_meta)\n\n        # FROM -> inverse\n        d_meta_dict_meta = t_from_meta.inverse(d_dict)\n        self.assertEqual(sorted(d_meta_dict_meta.keys()), [\"m1\", \"m2\", \"m3\"])\n        if data_type == \"numpy\":\n            m1, m1_aff = m1.cpu(), m1_aff.cpu()\n        self.check(d_meta_dict_meta[\"m1\"], m1, ids=False)\n        meta_out = {k: v for k, v in d_meta_dict_meta[\"m1\"].meta.items() if k != \"affine\"}\n        aff_out = d_meta_dict_meta[\"m1\"].affine\n        self.check(aff_out, m1_aff, ids=False)\n        self.assertEqual(meta_out, m1_meta)\n\n        # TO -> Forward\n        t_to_meta = ToMetaTensord([\"m1\", \"m2\"])\n        d_dict_meta = t_to_meta(d_dict)\n        self.assertEqual(sorted(d_dict_meta.keys()), [\"m1\", \"m2\", \"m3\"], f\"flag: {config.USE_META_DICT}\")\n        self.check(d_dict_meta[\"m3\"], m3, ids=True)  # unchanged (except deep copy in inverse)\n        self.check(d_dict_meta[\"m1\"], m1, ids=False)\n     
   meta_out = {k: v for k, v in d_dict_meta[\"m1\"].meta.items() if k != \"affine\"}\n        aff_out = d_dict_meta[\"m1\"].meta[\"affine\"]\n        self.check(aff_out, m1_aff, ids=False)\n        self.assertEqual(meta_out, m1_meta)\n\n        # TO -> Inverse\n        d_dict_meta_dict = t_to_meta.inverse(d_dict_meta)\n        self.assertEqual(\n            sorted(d_dict_meta_dict.keys()),\n            [\n                \"m1\",\n                PostFix.meta(\"m1\"),\n                PostFix.transforms(\"m1\"),\n                \"m2\",\n                PostFix.meta(\"m2\"),\n                PostFix.transforms(\"m2\"),\n                \"m3\",\n            ],\n        )\n        self.check(d_dict_meta_dict[\"m3\"], m3.as_tensor(), ids=False)  # unchanged (except deep copy in inverse)\n        self.check(d_dict_meta_dict[\"m1\"], m1.as_tensor(), ids=False)\n        meta_out = {k: v for k, v in d_dict_meta_dict[\"m1_meta_dict\"].items() if k != \"affine\"}\n        aff_out = d_dict_meta_dict[\"m1_meta_dict\"][\"affine\"]\n        self.check(aff_out, m1_aff, ids=False)\n        self.assertEqual(meta_out, m1_meta)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_arraydataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport sys\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nfrom parameterized import parameterized\nfrom torch.utils.data import DataLoader\n\nfrom monai.data import ArrayDataset\nfrom monai.transforms import Compose, EnsureChannelFirst, LoadImage, RandAdjustContrast, RandGaussianNoise, Spacing\n\nTEST_CASE_1 = [\n    Compose([LoadImage(image_only=True), EnsureChannelFirst(channel_dim=\"no_channel\"), RandGaussianNoise(prob=1.0)]),\n    Compose([LoadImage(image_only=True), EnsureChannelFirst(channel_dim=\"no_channel\"), RandGaussianNoise(prob=1.0)]),\n    (0, 1),\n    (1, 128, 128, 128),\n]\n\nTEST_CASE_2 = [\n    Compose([LoadImage(image_only=True), EnsureChannelFirst(channel_dim=\"no_channel\"), RandAdjustContrast(prob=1.0)]),\n    Compose([LoadImage(image_only=True), EnsureChannelFirst(channel_dim=\"no_channel\"), RandAdjustContrast(prob=1.0)]),\n    (0, 1),\n    (1, 128, 128, 128),\n]\n\n\nclass TestCompose(Compose):\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def __call__(self, input_, lazy=False):\n        img = self.transforms[0](input_)\n        metadata = img.meta\n        img = self.transforms[1](img)\n        img = self.transforms[2](img, lazy=lazy)\n        metadata = img.meta\n        return self.transforms[3](img), 
metadata\n\n\nTEST_CASE_3 = [\n    TestCompose(\n        [\n            LoadImage(image_only=True),\n            EnsureChannelFirst(channel_dim=\"no_channel\"),\n            Spacing(pixdim=(2, 2, 4)),\n            RandAdjustContrast(prob=1.0),\n        ]\n    ),\n    TestCompose(\n        [\n            LoadImage(image_only=True),\n            EnsureChannelFirst(channel_dim=\"no_channel\"),\n            Spacing(pixdim=(2, 2, 4)),\n            RandAdjustContrast(prob=1.0),\n        ]\n    ),\n    (0, 2),\n    (1, 64, 64, 33),\n]\n\nTEST_CASE_4 = [\n    Compose([LoadImage(image_only=True), EnsureChannelFirst(channel_dim=\"no_channel\"), RandGaussianNoise(prob=1.0)]),\n    (1, 128, 128, 128),\n]\n\n\nclass TestArrayDataset(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_shape(self, img_transform, label_transform, indices, expected_shape):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            test_image1 = os.path.join(tempdir, \"test_image1.nii.gz\")\n            test_seg1 = os.path.join(tempdir, \"test_seg1.nii.gz\")\n            test_image2 = os.path.join(tempdir, \"test_image2.nii.gz\")\n            test_seg2 = os.path.join(tempdir, \"test_seg2.nii.gz\")\n            nib.save(test_image, test_image1)\n            nib.save(test_image, test_seg1)\n            nib.save(test_image, test_image2)\n            nib.save(test_image, test_seg2)\n            test_images = [test_image1, test_image2]\n            test_segs = [test_seg1, test_seg2]\n            test_labels = [1, 1]\n            dataset = ArrayDataset(test_images, img_transform, test_segs, label_transform, test_labels, None)\n            self.assertEqual(len(dataset), 2)\n            dataset.set_random_state(1234)\n            data1 = dataset[0]\n            data2 = dataset[1]\n\n            
self.assertTupleEqual(data1[indices[0]].shape, expected_shape)\n            self.assertTupleEqual(data1[indices[1]].shape, expected_shape)\n            np.testing.assert_allclose(data1[indices[0]], data1[indices[1]])\n            self.assertTupleEqual(data2[indices[0]].shape, expected_shape)\n            self.assertTupleEqual(data2[indices[1]].shape, expected_shape)\n            np.testing.assert_allclose(data2[indices[0]], data2[indices[0]])\n\n            dataset = ArrayDataset(test_images, img_transform, test_segs, label_transform, test_labels, None)\n            dataset.set_random_state(1234)\n            _ = dataset[0]\n            data2_new = dataset[1]\n            np.testing.assert_allclose(data2[indices[0]], data2_new[indices[0]], atol=1e-3)\n\n    @parameterized.expand([TEST_CASE_4])\n    def test_default_none(self, img_transform, expected_shape):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            test_image1 = os.path.join(tempdir, \"test_image1.nii.gz\")\n            test_image2 = os.path.join(tempdir, \"test_image2.nii.gz\")\n            nib.save(test_image, test_image1)\n            nib.save(test_image, test_image2)\n            test_images = [test_image1, test_image2]\n            dataset = ArrayDataset(test_images, img_transform)\n            self.assertEqual(len(dataset), 2)\n            dataset.set_random_state(1234)\n            data1 = dataset[0]\n            data2 = dataset[1]\n            self.assertTupleEqual(data1.shape, expected_shape)\n            self.assertTupleEqual(data2.shape, expected_shape)\n\n            dataset = ArrayDataset(test_images, img_transform)\n            dataset.set_random_state(1234)\n            _ = dataset[0]\n            data2_new = dataset[1]\n            np.testing.assert_allclose(data2, data2_new, atol=1e-3)\n\n    @parameterized.expand([TEST_CASE_4])\n    def test_dataloading_img(self, 
img_transform, expected_shape):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            test_image1 = os.path.join(tempdir, \"test_image1.nii.gz\")\n            test_image2 = os.path.join(tempdir, \"test_image2.nii.gz\")\n            nib.save(test_image, test_image1)\n            nib.save(test_image, test_image2)\n            test_images = [test_image1, test_image2]\n            dataset = ArrayDataset(test_images, img_transform)\n            self.assertEqual(len(dataset), 2)\n            dataset.set_random_state(1234)\n            n_workers = 0 if sys.platform == \"win32\" else 2\n            loader = DataLoader(dataset, batch_size=10, num_workers=n_workers)\n            imgs = next(iter(loader))  # test batching\n            np.testing.assert_allclose(imgs.shape, [2] + list(expected_shape))\n\n            dataset.set_random_state(1234)\n            new_imgs = next(iter(loader))  # test batching\n            np.testing.assert_allclose(imgs, new_imgs, atol=1e-3)\n\n    @parameterized.expand([TEST_CASE_4])\n    def test_dataloading_img_label(self, img_transform, expected_shape):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=(128, 128, 128)).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            test_image1 = os.path.join(tempdir, \"test_image1.nii.gz\")\n            test_image2 = os.path.join(tempdir, \"test_image2.nii.gz\")\n            test_label1 = os.path.join(tempdir, \"test_label1.nii.gz\")\n            test_label2 = os.path.join(tempdir, \"test_label2.nii.gz\")\n            nib.save(test_image, test_image1)\n            nib.save(test_image, test_image2)\n            nib.save(test_image, test_label1)\n            nib.save(test_image, test_label2)\n            test_images = [test_image1, test_image2]\n            test_labels = [test_label1, test_label2]\n            dataset = 
ArrayDataset(test_images, img_transform, test_labels, img_transform)\n            self.assertEqual(len(dataset), 2)\n            dataset.set_random_state(1234)\n            n_workers = 0 if sys.platform == \"win32\" else 2\n            loader = DataLoader(dataset, batch_size=10, num_workers=n_workers)\n            data = next(iter(loader))  # test batching\n            np.testing.assert_allclose(data[0].shape, [2] + list(expected_shape))\n\n            dataset.set_random_state(1234)\n            new_data = next(iter(loader))  # test batching\n            np.testing.assert_allclose(data[0], new_data[0], atol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_box_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.box_utils import (\n    CenterSizeMode,\n    CornerCornerModeTypeA,\n    CornerCornerModeTypeB,\n    CornerCornerModeTypeC,\n    CornerSizeMode,\n    box_area,\n    box_centers,\n    box_giou,\n    box_iou,\n    box_pair_giou,\n    boxes_center_distance,\n    centers_in_boxes,\n    clip_boxes_to_image,\n    convert_box_mode,\n    convert_box_to_standard_mode,\n    non_max_suppression,\n)\nfrom monai.utils.type_conversion import convert_data_type\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    boxes = [[0, 0, 0, 0, 0, 0], [0, 1, 0, 2, 2, 3], [0, 1, 1, 2, 2, 3]]\n    spatial_size = [4, 4, 4]\n    TESTS.append(\n        [\n            {\"boxes\": p(boxes), \"spatial_size\": spatial_size, \"mode\": \"cccwhd\", \"half\": False},\n            CornerSizeMode,\n            p([[0, 0, 0, 0, 0, 0], [-1, 0, -1.5, 2, 2, 3], [-1, 0, -0.5, 2, 2, 3]]),\n            p([0, 12, 12]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"boxes\": p(boxes), \"spatial_size\": spatial_size, \"mode\": \"xyzwhd\", \"half\": False},\n            CornerSizeMode,\n            p([[0, 0, 0, 0, 0, 0], [0, 1, 0, 2, 2, 3], [0, 1, 1, 2, 2, 3]]),\n            p([0, 12, 12]),\n        ]\n    
)\n    TESTS.append(\n        [\n            {\"boxes\": p(boxes), \"spatial_size\": spatial_size, \"mode\": \"xyzwhd\", \"half\": True},\n            \"xyzxyz\",\n            p([[0, 0, 0, 0, 0, 0], [0, 1, 0, 2, 3, 3], [0, 1, 1, 2, 3, 4]]),\n            p([0, 12, 12]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"boxes\": p(boxes), \"spatial_size\": spatial_size, \"mode\": \"xyzwhd\", \"half\": False},\n            \"xxyyzz\",\n            p([[0, 0, 0, 0, 0, 0], [0, 2, 1, 3, 0, 3], [0, 2, 1, 3, 1, 4]]),\n            p([0, 12, 12]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"boxes\": p(boxes), \"spatial_size\": spatial_size, \"mode\": \"xyzwhd\", \"half\": False},\n            CornerCornerModeTypeC,\n            p([[0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 0, 3], [0, 1, 2, 3, 1, 4]]),\n            p([0, 12, 12]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"boxes\": p(boxes), \"spatial_size\": spatial_size, \"mode\": CornerCornerModeTypeA(), \"half\": False},\n            \"xyzwhd\",\n            p([[0, 0, 0, 0, 0, 0], [0, 1, 0, 2, 1, 3], [0, 1, 1, 2, 1, 2]]),\n            p([0, 6, 4]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"boxes\": p(boxes), \"spatial_size\": spatial_size, \"mode\": CornerCornerModeTypeA, \"half\": True},\n            CornerCornerModeTypeA,\n            p([[0, 0, 0, 0, 0, 0], [0, 1, 0, 2, 2, 3], [0, 1, 1, 2, 2, 3]]),\n            p([0, 6, 4]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"boxes\": p(boxes), \"spatial_size\": spatial_size, \"mode\": \"xyzxyz\", \"half\": False},\n            CornerCornerModeTypeB(),\n            p([[0, 0, 0, 0, 0, 0], [0, 2, 1, 2, 0, 3], [0, 2, 1, 2, 1, 3]]),\n            p([0, 6, 4]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"boxes\": p(boxes), \"spatial_size\": spatial_size, \"mode\": \"xxyyzz\", \"half\": False},\n            \"xxyyzz\",\n            p([[0, 0, 0, 0, 0, 0], [0, 1, 0, 2, 2, 3], [0, 1, 
1, 2, 2, 3]]),\n            p([0, 2, 1]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"boxes\": p(boxes), \"spatial_size\": spatial_size, \"mode\": \"xxyyzz\", \"half\": True},\n            \"xyzxyz\",\n            p([[0, 0, 0, 0, 0, 0], [0, 0, 2, 1, 2, 3], [0, 1, 2, 1, 2, 3]]),\n            p([0, 2, 1]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"boxes\": p(boxes), \"spatial_size\": spatial_size, \"mode\": \"xxyyzz\", \"half\": False},\n            \"xyzwhd\",\n            p([[0, 0, 0, 0, 0, 0], [0, 0, 2, 1, 2, 1], [0, 1, 2, 1, 1, 1]]),\n            p([0, 2, 1]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"boxes\": p(boxes), \"spatial_size\": spatial_size, \"mode\": \"xxyyzz\", \"half\": False},\n            CenterSizeMode(),\n            p([[0, 0, 0, 0, 0, 0], [0.5, 1, 2.5, 1, 2, 1], [0.5, 1.5, 2.5, 1, 1, 1]]),\n            p([0, 2, 1]),\n        ]\n    )\n\n\nclass TestCreateBoxList(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, input_data, mode2, expected_box, expected_area):\n        expected_box = convert_data_type(expected_box, dtype=np.float32)[0]\n        boxes1 = convert_data_type(input_data[\"boxes\"], dtype=np.float32)[0]\n        mode1 = input_data[\"mode\"]\n        half_bool = input_data[\"half\"]\n        spatial_size = input_data[\"spatial_size\"]\n\n        # test float16\n        if half_bool:\n            boxes1 = convert_data_type(boxes1, dtype=np.float16)[0]\n            expected_box = convert_data_type(expected_box, dtype=np.float16)[0]\n\n        # test convert_box_mode, convert_box_to_standard_mode\n        result2 = convert_box_mode(boxes=boxes1, src_mode=mode1, dst_mode=mode2)\n        assert_allclose(result2, expected_box, type_test=True, device_test=True, atol=0.0)\n\n        result1 = convert_box_mode(boxes=result2, src_mode=mode2, dst_mode=mode1)\n        assert_allclose(result1, boxes1, type_test=True, device_test=True, atol=0.0)\n\n        
result_standard = convert_box_to_standard_mode(boxes=boxes1, mode=mode1)\n        expected_box_standard = convert_box_to_standard_mode(boxes=expected_box, mode=mode2)\n        assert_allclose(result_standard, expected_box_standard, type_test=True, device_test=True, atol=0.0)\n\n        # test box_area, box_iou, box_giou, box_pair_giou\n        assert_allclose(box_area(result_standard), expected_area, type_test=True, device_test=True, atol=0.0)\n        iou_metrics = (box_iou, box_giou)\n        for p in iou_metrics:\n            self_iou = p(boxes1=result_standard[1:2, :], boxes2=result_standard[1:1, :])\n            assert_allclose(self_iou, np.array([[]]), type_test=False)\n\n            self_iou = p(boxes1=result_standard[1:2, :], boxes2=result_standard[1:2, :])\n            assert_allclose(self_iou, np.array([[1.0]]), type_test=False)\n\n        self_iou = box_pair_giou(boxes1=result_standard[1:1, :], boxes2=result_standard[1:1, :])\n        assert_allclose(self_iou, np.array([]), type_test=False)\n\n        self_iou = box_pair_giou(boxes1=result_standard[1:2, :], boxes2=result_standard[1:2, :])\n        assert_allclose(self_iou, np.array([1.0]), type_test=False)\n\n        # test box_centers, centers_in_boxes, boxes_center_distance\n        result_standard_center = box_centers(result_standard)\n        expected_center = convert_box_mode(boxes=boxes1, src_mode=mode1, dst_mode=\"cccwhd\")[:, :3]\n        assert_allclose(result_standard_center, expected_center, type_test=True, device_test=True, atol=0.0)\n\n        center = expected_center\n        center[2, :] += 10\n        result_centers_in_boxes = centers_in_boxes(centers=center, boxes=result_standard)\n        assert_allclose(result_centers_in_boxes, np.array([False, True, False]), type_test=False)\n\n        center_dist, _, _ = boxes_center_distance(boxes1=result_standard[1:2, :], boxes2=result_standard[1:1, :])\n        assert_allclose(center_dist, np.array([[]]), type_test=False)\n        center_dist, _, 
_ = boxes_center_distance(boxes1=result_standard[1:2, :], boxes2=result_standard[1:2, :])\n        assert_allclose(center_dist, np.array([[0.0]]), type_test=False)\n        center_dist, _, _ = boxes_center_distance(boxes1=result_standard[0:1, :], boxes2=result_standard[0:1, :])\n        assert_allclose(center_dist, np.array([[0.0]]), type_test=False)\n\n        # test clip_boxes_to_image\n        clipped_boxes, keep = clip_boxes_to_image(expected_box_standard, spatial_size, remove_empty=True)\n        assert_allclose(\n            expected_box_standard[keep, :], expected_box_standard[1:, :], type_test=True, device_test=True, atol=0.0\n        )\n        assert_allclose(\n            id(clipped_boxes) != id(expected_box_standard), True, type_test=False, device_test=False, atol=0.0\n        )\n\n        # test non_max_suppression\n        nms_box = non_max_suppression(\n            boxes=result_standard, scores=boxes1[:, 1] / 2.0, nms_thresh=1.0, box_overlap_metric=box_giou\n        )\n        assert_allclose(nms_box, [1, 2, 0], type_test=False)\n\n        nms_box = non_max_suppression(\n            boxes=result_standard, scores=boxes1[:, 1] / 2.0, nms_thresh=-1.0, box_overlap_metric=box_iou\n        )\n        assert_allclose(nms_box, [1], type_test=False)\n\n\nclass TestBoxUtilsDtype(unittest.TestCase):\n    @parameterized.expand(\n        [\n            # numpy dtypes\n            (np.array([[1, 1, 1, 2, 2, 2]], dtype=np.int32), np.array([[1, 1, 1, 2, 2, 2]], dtype=np.int32)),\n            (np.array([[1, 1, 1, 2, 2, 2]], dtype=np.float32), np.array([[1, 1, 1, 2, 2, 2]], dtype=np.float32)),\n            # torch dtypes\n            (\n                torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.int64),\n                torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.int64),\n            ),\n            (\n                torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.float32),\n                torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.float32),\n            
),\n            # mixed numpy (int + float)\n            (np.array([[1, 1, 1, 2, 2, 2]], dtype=np.int32), np.array([[1, 1, 1, 2, 2, 2]], dtype=np.float32)),\n            # mixed torch (int + float)\n            (\n                torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.int64),\n                torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.float32),\n            ),\n        ]\n    )\n    def test_dtype_behavior(self, boxes1, boxes2):\n        funcs = [box_iou, box_giou, box_pair_giou]\n        for func in funcs:\n            result = func(boxes1, boxes2)\n\n            if isinstance(result, np.ndarray):\n                self.assertTrue(\n                    np.issubdtype(result.dtype, np.floating), f\"{func.__name__} expected float, got {result.dtype}\"\n                )\n            elif torch.is_tensor(result):\n                self.assertTrue(\n                    torch.is_floating_point(result), f\"{func.__name__} expected float tensor, got {result.dtype}\"\n                )\n            else:\n                self.fail(f\"Unexpected return type {type(result)}\")\n\n    def test_integer_truncation_bug(self):\n        # Verify fix for #8553: IoU < 1.0 with integer inputs should not truncate to 0\n        boxes1 = np.array([[0, 0, 0, 2, 2, 2]], dtype=np.int32)\n        boxes2 = np.array([[1, 1, 1, 3, 3, 3]], dtype=np.int32)\n\n        iou = box_iou(boxes1, boxes2)\n        self.assertTrue(np.issubdtype(iou.dtype, np.floating))\n        self.assertGreater(iou[0, 0], 0.0, \"IoU should not be truncated to 0\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_cachedataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport sys\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import CacheDataset, DataLoader, PersistentDataset, SmartCacheDataset\nfrom monai.transforms import Compose, Lambda, LoadImaged, RandLambda, ThreadUnsafe, Transform\n\nTEST_CASE_1 = [Compose([LoadImaged(keys=[\"image\", \"label\", \"extra\"])]), (128, 128, 128)]\n\nTEST_CASE_2 = [None, (128, 128, 128)]\n\nTEST_DS = []\nfor c in (0, 1, 2):\n    for l in (0, 1, 2):\n        TEST_DS.append([False, c, 0 if sys.platform in (\"darwin\", \"win32\") else l])\n    if sys.platform not in (\"darwin\", \"win32\"):\n        # persistent_workers need l > 0\n        for l in (1, 2):\n            TEST_DS.append([True, c, l])\n\n\nclass TestCacheDataset(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_shape(self, transform, expected_shape):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            test_data = []\n            for i in [\"1\", \"2\"]:\n                for k in [\"image\", \"label\", \"extra\"]:\n                    nib.save(test_image, os.path.join(tempdir, f\"{k}{i}.nii.gz\"))\n                test_data.append({k: 
os.path.join(tempdir, f\"{k}{i}.nii.gz\") for k in [\"image\", \"label\", \"extra\"]})\n\n            dataset = CacheDataset(data=test_data, transform=transform, cache_rate=0.5, as_contiguous=True)\n            data1 = dataset[0]\n            data2 = dataset[1]\n            data3 = dataset[0:-1]\n            data4 = dataset[-1]\n            self.assertEqual(len(data3), 1)\n\n            if transform is None:\n                # Check without providing transform\n                dataset2 = CacheDataset(data=test_data, cache_rate=0.5, as_contiguous=True)\n                for k in [\"image\", \"label\", \"extra\"]:\n                    self.assertEqual(dataset[0][k], dataset2[0][k])\n\n        if transform is None:\n            self.assertEqual(data1[\"image\"], os.path.join(tempdir, \"image1.nii.gz\"))\n            self.assertEqual(data2[\"label\"], os.path.join(tempdir, \"label2.nii.gz\"))\n            self.assertEqual(data4[\"image\"], os.path.join(tempdir, \"image2.nii.gz\"))\n        else:\n            self.assertTupleEqual(data1[\"image\"].shape, expected_shape)\n            self.assertTupleEqual(data1[\"label\"].shape, expected_shape)\n            self.assertTupleEqual(data1[\"extra\"].shape, expected_shape)\n            self.assertTupleEqual(data2[\"image\"].shape, expected_shape)\n            self.assertTupleEqual(data2[\"label\"].shape, expected_shape)\n            self.assertTupleEqual(data2[\"extra\"].shape, expected_shape)\n            for d in data3:\n                self.assertTupleEqual(d[\"image\"].shape, expected_shape)\n\n    def test_set_data(self):\n        data_list1 = list(range(10))\n\n        transform = Compose([Lambda(func=lambda x: np.array([x * 10])), RandLambda(func=lambda x: x + 1)])\n\n        dataset = CacheDataset(\n            data=data_list1,\n            transform=transform,\n            cache_rate=1.0,\n            num_workers=4,\n            progress=True,\n            copy_cache=not sys.platform == \"linux\",\n        )\n\n       
 num_workers = 2 if sys.platform == \"linux\" else 0\n        dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1)\n        for i, d in enumerate(dataloader):\n            np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d)\n        # simulate another epoch, the cache content should not be modified\n        for i, d in enumerate(dataloader):\n            np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d)\n\n        # update the datalist and fill the cache content\n        data_list2 = list(range(-10, 0))\n        dataset.set_data(data=data_list2)\n        # rerun with updated cache content\n        for i, d in enumerate(dataloader):\n            np.testing.assert_allclose([[data_list2[i] * 10 + 1]], d)\n\n\nclass _StatefulTransform(Transform, ThreadUnsafe):\n    \"\"\"\n    A transform with an internal state.\n    The state is changing at each call.\n    \"\"\"\n\n    def __init__(self):\n        self.property = 1\n\n    def __call__(self, data):\n        self.property = self.property + 1\n        return data * 100 + self.property\n\n\nclass TestCacheThread(unittest.TestCase):\n    \"\"\"\n    cache dataset and persistent dataset should behave in the same way when used with different loader settings.\n    loader's are tested with two epochs.\n    \"\"\"\n\n    @parameterized.expand(TEST_DS)\n    def test_thread_safe(self, persistent_workers, cache_workers, loader_workers):\n        expected = [102, 202, 302, 402, 502, 602, 702, 802, 902, 1002]\n        _kwg = {\"persistent_workers\": persistent_workers}\n        data_list = list(range(1, 11))\n        dataset = CacheDataset(\n            data=data_list, transform=_StatefulTransform(), cache_rate=1.0, num_workers=cache_workers, progress=False\n        )\n        self.assertListEqual(expected, list(dataset))\n        loader = DataLoader(\n            CacheDataset(\n                data=data_list,\n                transform=_StatefulTransform(),\n                
cache_rate=1.0,\n                num_workers=cache_workers,\n                progress=False,\n            ),\n            batch_size=1,\n            num_workers=loader_workers,\n            **_kwg,\n        )\n        self.assertListEqual(expected, [y.item() for y in loader])\n        self.assertListEqual(expected, [y.item() for y in loader])\n\n        dataset = SmartCacheDataset(\n            data=data_list,\n            transform=_StatefulTransform(),\n            cache_rate=0.7,\n            replace_rate=0.5,\n            num_replace_workers=cache_workers,\n            progress=False,\n            shuffle=False,\n        )\n        self.assertListEqual(expected[:7], list(dataset))\n        loader = DataLoader(\n            SmartCacheDataset(\n                data=data_list,\n                transform=_StatefulTransform(),\n                cache_rate=0.7,\n                replace_rate=0.5,\n                num_replace_workers=cache_workers,\n                progress=False,\n                shuffle=False,\n            ),\n            batch_size=1,\n            num_workers=loader_workers,\n            **_kwg,\n        )\n        self.assertListEqual(expected[:7], [y.item() for y in loader])\n        self.assertListEqual(expected[:7], [y.item() for y in loader])\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            pdata = PersistentDataset(data=data_list, transform=_StatefulTransform(), cache_dir=tempdir)\n            self.assertListEqual(expected, list(pdata))\n            loader = DataLoader(\n                PersistentDataset(data=data_list, transform=_StatefulTransform(), cache_dir=tempdir),\n                batch_size=1,\n                num_workers=loader_workers,\n                shuffle=False,\n                **_kwg,\n            )\n            self.assertListEqual(expected, [y.item() for y in loader])\n            self.assertListEqual(expected, [y.item() for y in loader])\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def 
test_hash_as_key(self, transform, expected_shape):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            test_data = []\n            for i in [\"1\", \"2\", \"2\", \"3\", \"3\"]:\n                for k in [\"image\", \"label\", \"extra\"]:\n                    nib.save(test_image, os.path.join(tempdir, f\"{k}{i}.nii.gz\"))\n                test_data.append({k: os.path.join(tempdir, f\"{k}{i}.nii.gz\") for k in [\"image\", \"label\", \"extra\"]})\n\n            dataset = CacheDataset(data=test_data, transform=transform, cache_num=4, num_workers=2, hash_as_key=True)\n            self.assertEqual(len(dataset), 5)\n            # ensure no duplicated cache content\n            self.assertEqual(len(dataset._cache), 3)\n            self.assertEqual(len(dataset._hash_keys), 3)\n            self.assertEqual(dataset.cache_num, 3)\n            data1 = dataset[0]\n            data2 = dataset[1]\n            data3 = dataset[-1]\n            # test slice indices\n            data4 = dataset[0:-1]\n            self.assertEqual(len(data4), 4)\n\n            if transform is None:\n                self.assertEqual(data1[\"image\"], os.path.join(tempdir, \"image1.nii.gz\"))\n                self.assertEqual(data2[\"label\"], os.path.join(tempdir, \"label2.nii.gz\"))\n                self.assertEqual(data3[\"image\"], os.path.join(tempdir, \"image3.nii.gz\"))\n            else:\n                self.assertTupleEqual(data1[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data2[\"label\"].shape, expected_shape)\n                self.assertTupleEqual(data3[\"image\"].shape, expected_shape)\n                for d in data4:\n                    self.assertTupleEqual(d[\"image\"].shape, expected_shape)\n\n            test_data2 = test_data[:3]\n            dataset.set_data(data=test_data2)\n            self.assertEqual(len(dataset), 3)\n        
    # ensure no duplicated cache content\n            self.assertEqual(len(dataset._cache), 2)\n            self.assertEqual(dataset.cache_num, 2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_cachedataset_parallel.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import CacheDataset\nfrom monai.transforms import Compose, LoadImaged\n\nTEST_CASE_1 = [0, 5, Compose([LoadImaged(keys=[\"image\", \"label\", \"extra\"])])]\n\nTEST_CASE_2 = [4, 5, Compose([LoadImaged(keys=[\"image\", \"label\", \"extra\"])])]\n\nTEST_CASE_3 = [4, 5, None]\n\n\nclass TestCacheDatasetParallel(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_shape(self, num_workers, dataset_size, transform):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            nib.save(test_image, os.path.join(tempdir, \"test_image1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_label1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_extra1.nii.gz\"))\n            test_data = [\n                {\n                    \"image\": os.path.join(tempdir, \"test_image1.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label1.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra1.nii.gz\"),\n                }\n            ] * dataset_size\n       
     dataset = CacheDataset(data=test_data, transform=transform, cache_rate=1, num_workers=num_workers)\n\n        self.assertEqual(len(dataset._cache), dataset.cache_num)\n        for i in range(dataset.cache_num):\n            self.assertIsNotNone(dataset._cache[i])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_cachedataset_persistent_workers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.data import CacheDataset, DataLoader, create_test_image_2d\nfrom monai.transforms import Compose, RandAffined, Spacingd\n\n\nclass TestTransformsWCacheDatasetAndPersistentWorkers(unittest.TestCase):\n\n    def test_duplicate_transforms(self):\n        data = [{\"img\": create_test_image_2d(128, 128, num_seg_classes=1, channel_dim=0)[0]} for _ in range(2)]\n\n        # at least 1 deterministic followed by at least 1 random\n        transform = Compose([Spacingd(\"img\", pixdim=(1, 1)), RandAffined(\"img\", prob=1.0)])\n\n        # cachedataset and data loader w persistent_workers\n        train_ds = CacheDataset(data, transform, cache_num=1)\n        # num_workers > 1 may fail randomly with 21.09 on A100 test node\n        # https://github.com/Project-MONAI/MONAI/issues/3283\n        train_loader = DataLoader(train_ds, num_workers=1, persistent_workers=True)\n\n        b1 = next(iter(train_loader))\n        b2 = next(iter(train_loader))\n\n        self.assertEqual(len(b1[\"img\"].applied_operations), len(b2[\"img\"].applied_operations))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_cachentransdataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import CacheNTransDataset\nfrom monai.transforms import LoadImaged, ShiftIntensityd\n\nTEST_CASE_1 = [\n    [\n        LoadImaged(keys=\"image\"),\n        ShiftIntensityd(keys=\"image\", offset=1.0),\n        ShiftIntensityd(keys=\"image\", offset=2.0),\n        ShiftIntensityd(keys=\"image\", offset=3.0),\n    ],\n    (128, 128, 128),\n]\n\n\nclass TestCacheNTransDataset(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1])\n    def test_n_trans(self, transform, expected_shape):\n        data_array = np.random.randint(0, 2, size=[128, 128, 128]).astype(float)\n        test_image = nib.Nifti1Image(data_array, np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            nib.save(test_image, os.path.join(tempdir, \"test_image.nii.gz\"))\n            test_data = [{\"image\": os.path.join(tempdir, \"test_image.nii.gz\")}]\n\n            cache_dir = os.path.join(os.path.join(tempdir, \"cache\"), \"data\")\n            dataset_precached = CacheNTransDataset(\n                data=test_data, transform=transform, cache_dir=cache_dir, cache_n_trans=2\n            )\n            data_precached = dataset_precached[0]\n            
self.assertTupleEqual(data_precached[\"image\"].shape, expected_shape)\n\n            dataset_postcached = CacheNTransDataset(\n                data=test_data, transform=transform, cache_dir=cache_dir, cache_n_trans=2\n            )\n            data_postcached = dataset_postcached[0]\n            self.assertTupleEqual(data_postcached[\"image\"].shape, expected_shape)\n            np.testing.assert_allclose(data_array + 6.0, data_postcached[\"image\"])\n            np.testing.assert_allclose(data_precached[\"image\"], data_postcached[\"image\"])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_check_missing_files.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nimport nibabel as nib\nimport numpy as np\n\nfrom monai.data import check_missing_files\n\n\nclass TestCheckMissingFiles(unittest.TestCase):\n\n    def test_content(self):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            nib.save(test_image, os.path.join(tempdir, \"test_image1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_label1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_extra1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_image2.nii.gz\"))\n\n            datalist = [\n                {\n                    \"image\": os.path.join(tempdir, \"test_image1.nii.gz\"),\n                    \"label\": [os.path.join(tempdir, \"test_label1.nii.gz\"), os.path.join(tempdir, \"test_extra1.nii.gz\")],\n                },\n                {\n                    \"image\": Path(os.path.join(tempdir, \"test_image2.nii.gz\")),\n                    \"label\": Path(os.path.join(tempdir, \"test_label_missing.nii.gz\")),\n                },\n            ]\n\n            missings = check_missing_files(datalist=datalist, keys=[\"image\", \"label\"])\n            
self.assertEqual(len(missings), 1)\n            self.assertEqual(str(missings[0]), os.path.join(tempdir, \"test_label_missing.nii.gz\"))\n\n            # test with missing key and relative path\n            datalist = [{\"image\": \"test_image1.nii.gz\", \"label\": \"test_label_missing.nii.gz\"}]\n            missings = check_missing_files(\n                datalist=datalist, keys=[\"image\", \"label\", \"test\"], root_dir=tempdir, allow_missing_keys=True\n            )\n            self.assertEqual(f\"{missings[0]}\", os.path.join(tempdir, \"test_label_missing.nii.gz\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_create_cross_validation_datalist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nfrom monai.data import create_cross_validation_datalist, load_decathlon_datalist\n\n\nclass TestCreateCrossValidationDatalist(unittest.TestCase):\n\n    def test_content(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            datalist = []\n            for i in range(5):\n                image = os.path.join(tempdir, f\"test_image{i}.nii.gz\")\n                label = os.path.join(tempdir, f\"test_label{i}.nii.gz\")\n                Path(image).touch()\n                Path(label).touch()\n                datalist.append({\"image\": image, \"label\": label})\n\n            filename = os.path.join(tempdir, \"test_datalist.json\")\n            result = create_cross_validation_datalist(\n                datalist=datalist,\n                nfolds=5,\n                train_folds=[0, 1, 2, 3],\n                val_folds=4,\n                train_key=\"test_train\",\n                val_key=\"test_val\",\n                filename=Path(filename),\n                shuffle=True,\n                seed=123,\n                check_missing=True,\n                keys=[\"image\", \"label\"],\n                root_dir=None,\n                allow_missing_keys=False,\n                raise_error=True,\n            )\n\n            loaded = 
load_decathlon_datalist(filename, data_list_key=\"test_train\")\n            for r, l in zip(result[\"test_train\"], loaded):\n                self.assertEqual(r[\"image\"], l[\"image\"])\n                self.assertEqual(r[\"label\"], l[\"label\"])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_csv_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport pandas as pd\n\nfrom monai.data import CSVDataset\nfrom monai.transforms import ToNumpyd\n\n\nclass TestCSVDataset(unittest.TestCase):\n\n    def test_values(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            test_data1 = [\n                [\"subject_id\", \"label\", \"image\", \"ehr_0\", \"ehr_1\", \"ehr_2\"],\n                [\"s000000\", 5, \"./imgs/s000000.png\", 2.007843256, 2.29019618, 2.054902077],\n                [\"s000001\", 0, \"./imgs/s000001.png\", 6.839215755, 6.474509716, 5.862744808],\n                [\"s000002\", 4, \"./imgs/s000002.png\", 3.772548914, 4.211764812, 4.635294437],\n                [\"s000003\", 1, \"./imgs/s000003.png\", 3.333333254, 3.235294342, 3.400000095],\n                [\"s000004\", 9, \"./imgs/s000004.png\", 6.427451134, 6.254901886, 5.976470947],\n            ]\n            test_data2 = [\n                [\"subject_id\", \"ehr_3\", \"ehr_4\", \"ehr_5\", \"ehr_6\", \"ehr_7\", \"ehr_8\"],\n                [\"s000000\", 3.019608021, 3.807843208, 3.584313869, 3.141176462, 3.1960783, 4.211764812],\n                [\"s000001\", 5.192157269, 5.274509907, 5.250980377, 4.647058964, 4.886274338, 4.392156601],\n                [\"s000002\", 5.298039436, 9.545097351, 12.57254887, 6.799999714, 
2.1960783, 1.882352948],\n                [\"s000003\", 3.164705753, 3.086274624, 3.725490093, 3.698039293, 3.698039055, 3.701960802],\n                [\"s000004\", 6.26274538, 7.717647076, 9.584313393, 6.082352638, 2.662744999, 2.34117651],\n            ]\n            test_data3 = [\n                [\"subject_id\", \"ehr_9\", \"ehr_10\", \"meta_0\", \"meta_1\", \"meta_2\"],\n                [\"s000000\", 6.301961422, 6.470588684, \"TRUE\", \"TRUE\", \"TRUE\"],\n                [\"s000001\", 5.219608307, 7.827450752, \"FALSE\", \"TRUE\", \"FALSE\"],\n                [\"s000002\", 1.882352948, 2.031372547, \"TRUE\", \"FALSE\", \"TRUE\"],\n                [\"s000003\", 3.309803963, 3.729412079, \"FALSE\", \"FALSE\", \"TRUE\"],\n                [\"s000004\", 2.062745094, 2.34117651, \"FALSE\", \"TRUE\", \"TRUE\"],\n                # generate NaN values in the row\n                [\"s000005\", 3.353655643, 1.675674543, \"TRUE\", \"TRUE\", \"FALSE\"],\n            ]\n\n            def prepare_csv_file(data, filepath):\n                with open(filepath, \"a\") as f:\n                    for d in data:\n                        f.write((\",\".join([str(i) for i in d])) + \"\\n\")\n\n            filepath1 = os.path.join(tempdir, \"test_data1.csv\")\n            filepath2 = os.path.join(tempdir, \"test_data2.csv\")\n            filepath3 = os.path.join(tempdir, \"test_data3.csv\")\n            filepaths = [filepath1, filepath2, filepath3]\n            prepare_csv_file(test_data1, filepath1)\n            prepare_csv_file(test_data2, filepath2)\n            prepare_csv_file(test_data3, filepath3)\n\n            # test single CSV file\n            dataset = CSVDataset(filepath1)\n            self.assertDictEqual(\n                {k: round(v, 4) if not isinstance(v, str) else v for k, v in dataset[2].items()},\n                {\n                    \"subject_id\": \"s000002\",\n                    \"label\": 4,\n                    \"image\": \"./imgs/s000002.png\",\n     
               \"ehr_0\": 3.7725,\n                    \"ehr_1\": 4.2118,\n                    \"ehr_2\": 4.6353,\n                },\n            )\n\n            # test multiple CSV files, join tables with kwargs\n            dataset = CSVDataset(filepaths, on=\"subject_id\")\n            self.assertDictEqual(\n                {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in dataset[3].items()},\n                {\n                    \"subject_id\": \"s000003\",\n                    \"label\": 1,\n                    \"image\": \"./imgs/s000003.png\",\n                    \"ehr_0\": 3.3333,\n                    \"ehr_1\": 3.2353,\n                    \"ehr_2\": 3.4000,\n                    \"ehr_3\": 3.1647,\n                    \"ehr_4\": 3.0863,\n                    \"ehr_5\": 3.7255,\n                    \"ehr_6\": 3.6980,\n                    \"ehr_7\": 3.6980,\n                    \"ehr_8\": 3.7020,\n                    \"ehr_9\": 3.3098,\n                    \"ehr_10\": 3.7294,\n                    \"meta_0\": False,\n                    \"meta_1\": False,\n                    \"meta_2\": True,\n                },\n            )\n\n            # test selected rows and columns\n            dataset = CSVDataset(\n                src=filepaths,\n                row_indices=[[0, 2], 3],  # load row: 0, 1, 3\n                col_names=[\"subject_id\", \"image\", \"ehr_1\", \"ehr_7\", \"meta_1\"],\n            )\n            self.assertEqual(len(dataset), 3)\n            self.assertDictEqual(\n                {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in dataset[-1].items()},\n                {\n                    \"subject_id\": \"s000003\",\n                    \"image\": \"./imgs/s000003.png\",\n                    \"ehr_1\": 3.2353,\n                    \"ehr_7\": 3.6980,\n                    \"meta_1\": False,\n                },\n            )\n\n            # test group columns\n            dataset = 
CSVDataset(\n                src=filepaths,\n                row_indices=[1, 3],  # load row: 1, 3\n                col_names=[\"subject_id\", \"image\", *[f\"ehr_{i}\" for i in range(11)], \"meta_0\", \"meta_1\", \"meta_2\"],\n                col_groups={\"ehr\": [f\"ehr_{i}\" for i in range(11)], \"meta12\": [\"meta_1\", \"meta_2\"]},\n            )\n            np.testing.assert_allclose(\n                [round(i, 4) for i in dataset[-1][\"ehr\"]],\n                [3.3333, 3.2353, 3.4000, 3.1647, 3.0863, 3.7255, 3.6980, 3.6980, 3.7020, 3.3098, 3.7294],\n            )\n            np.testing.assert_allclose(dataset[-1][\"meta12\"], [False, True])\n\n            # test transform\n            dataset = CSVDataset(\n                src=filepaths, col_groups={\"ehr\": [f\"ehr_{i}\" for i in range(5)]}, transform=ToNumpyd(keys=\"ehr\")\n            )\n            self.assertEqual(len(dataset), 5)\n            expected = [\n                [2.0078, 2.2902, 2.0549, 3.0196, 3.8078],\n                [6.8392, 6.4745, 5.8627, 5.1922, 5.2745],\n                [3.7725, 4.2118, 4.6353, 5.2980, 9.5451],\n                [3.3333, 3.2353, 3.4000, 3.1647, 3.0863],\n                [6.4275, 6.2549, 5.9765, 6.2627, 7.7176],\n            ]\n            for item, exp in zip(dataset, expected):\n                self.assertTrue(isinstance(item[\"ehr\"], np.ndarray))\n                np.testing.assert_allclose(np.around(item[\"ehr\"], 4), exp)\n\n            # test default values and dtype\n            dataset = CSVDataset(\n                src=filepaths,\n                col_names=[\"subject_id\", \"image\", \"ehr_1\", \"ehr_9\", \"meta_1\"],\n                col_types={\"image\": {\"type\": str, \"default\": \"No image\"}, \"ehr_1\": {\"type\": int, \"default\": 0}},\n                how=\"outer\",  # generate NaN values in this merge mode\n            )\n            self.assertEqual(len(dataset), 6)\n            self.assertEqual(dataset[-1][\"image\"], \"No image\")\n            
self.assertEqual(type(dataset[-1][\"ehr_1\"]), int)\n            np.testing.assert_allclose(dataset[-1][\"ehr_9\"], 3.3537, rtol=1e-2)\n\n            # test pre-loaded DataFrame\n            df = pd.read_csv(filepath1)\n            dataset = CSVDataset(src=df)\n            self.assertDictEqual(\n                {k: round(v, 4) if not isinstance(v, str) else v for k, v in dataset[2].items()},\n                {\n                    \"subject_id\": \"s000002\",\n                    \"label\": 4,\n                    \"image\": \"./imgs/s000002.png\",\n                    \"ehr_0\": 3.7725,\n                    \"ehr_1\": 4.2118,\n                    \"ehr_2\": 4.6353,\n                },\n            )\n\n            # test pre-loaded DataFrame subset\n            df = pd.read_csv(filepath1)\n            df_subset = df.iloc[[1, 3, 4]]\n            dataset = CSVDataset(src=df_subset, col_groups={\"ehr\": [f\"ehr_{i}\" for i in range(3)]})\n            self.assertEqual(len(dataset), 3)\n            np.testing.assert_allclose([round(i, 4) for i in dataset[1][\"ehr\"]], [3.3333, 3.2353, 3.4000])\n\n            # test pre-loaded DataFrame subset with row_indices != None\n            df = pd.read_csv(filepath1)\n            df_subset = df.iloc[[1, 3, 4]]\n            dataset = CSVDataset(src=df_subset, row_indices=[1, 3], col_groups={\"ehr\": [f\"ehr_{i}\" for i in range(3)]})\n            self.assertEqual(len(dataset), 2)\n            np.testing.assert_allclose([round(i, 4) for i in dataset[1][\"ehr\"]], [3.3333, 3.2353, 3.4000])\n\n            # test pre-loaded multiple DataFrames, join tables with kwargs\n            dfs = [pd.read_csv(i) for i in filepaths]\n            dataset = CSVDataset(src=dfs, on=\"subject_id\")\n            self.assertEqual(dataset[3][\"subject_id\"], \"s000003\")\n            self.assertEqual(dataset[3][\"label\"], 1)\n            self.assertEqual(round(dataset[3][\"ehr_0\"], 4), 3.3333)\n            self.assertEqual(dataset[3][\"meta_0\"], 
False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_csv_iterable_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport sys\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport pandas as pd\n\nfrom monai.data import CSVIterableDataset, DataLoader\nfrom monai.transforms import ToNumpyd\nfrom tests.test_utils import skip_if_windows\n\n\n@skip_if_windows\nclass TestCSVIterableDataset(unittest.TestCase):\n    def test_values(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            test_data1 = [\n                [\"subject_id\", \"label\", \"image\", \"ehr_0\", \"ehr_1\", \"ehr_2\"],\n                [\"s000000\", 5, \"./imgs/s000000.png\", 2.007843256, 2.29019618, 2.054902077],\n                [\"s000001\", 0, \"./imgs/s000001.png\", 6.839215755, 6.474509716, 5.862744808],\n                [\"s000002\", 4, \"./imgs/s000002.png\", 3.772548914, 4.211764812, 4.635294437],\n                [\"s000003\", 1, \"./imgs/s000003.png\", 3.333333254, 3.235294342, 3.400000095],\n                [\"s000004\", 9, \"./imgs/s000004.png\", 6.427451134, 6.254901886, 5.976470947],\n            ]\n            test_data2 = [\n                [\"subject_id\", \"ehr_3\", \"ehr_4\", \"ehr_5\", \"ehr_6\", \"ehr_7\", \"ehr_8\"],\n                [\"s000000\", 3.019608021, 3.807843208, 3.584313869, 3.141176462, 3.1960783, 4.211764812],\n                [\"s000001\", 5.192157269, 5.274509907, 5.250980377, 4.647058964, 4.886274338, 
4.392156601],\n                [\"s000002\", 5.298039436, 9.545097351, 12.57254887, 6.799999714, 2.1960783, 1.882352948],\n                [\"s000003\", 3.164705753, 3.086274624, 3.725490093, 3.698039293, 3.698039055, 3.701960802],\n                [\"s000004\", 6.26274538, 7.717647076, 9.584313393, 6.082352638, 2.662744999, 2.34117651],\n            ]\n            test_data3 = [\n                [\"subject_id\", \"ehr_9\", \"ehr_10\", \"meta_0\", \"meta_1\", \"meta_2\"],\n                [\"s000000\", 6.301961422, 6.470588684, \"TRUE\", \"TRUE\", \"TRUE\"],\n                [\"s000001\", 5.219608307, 7.827450752, \"FALSE\", \"TRUE\", \"FALSE\"],\n                [\"s000002\", 1.882352948, 2.031372547, \"TRUE\", \"FALSE\", \"TRUE\"],\n                [\"s000003\", 3.309803963, 3.729412079, \"FALSE\", \"FALSE\", \"TRUE\"],\n                [\"s000004\", 2.062745094, 2.34117651, \"FALSE\", \"TRUE\", \"TRUE\"],\n            ]\n\n            def prepare_csv_file(data, filepath):\n                with open(filepath, \"a\") as f:\n                    for d in data:\n                        f.write((\",\".join([str(i) for i in d])) + \"\\n\")\n\n            filepath1 = os.path.join(tempdir, \"test_data1.csv\")\n            filepath2 = os.path.join(tempdir, \"test_data2.csv\")\n            filepath3 = os.path.join(tempdir, \"test_data3.csv\")\n            filepaths = [filepath1, filepath2, filepath3]\n            prepare_csv_file(test_data1, filepath1)\n            prepare_csv_file(test_data2, filepath2)\n            prepare_csv_file(test_data3, filepath3)\n\n            # test single CSV file\n            dataset = CSVIterableDataset(filepath1, shuffle=False)\n            count = 0\n            for item in dataset:\n                count += 1\n                if count == 3:\n                    self.assertDictEqual(\n                        {k: round(v, 4) if not isinstance(v, str) else v for k, v in item.items()},\n                        {\n                            
\"subject_id\": \"s000002\",\n                            \"label\": 4,\n                            \"image\": \"./imgs/s000002.png\",\n                            \"ehr_0\": 3.7725,\n                            \"ehr_1\": 4.2118,\n                            \"ehr_2\": 4.6353,\n                        },\n                    )\n                    break\n            self.assertEqual(count, 3)\n            dataset.close()\n\n            # test reset iterables\n            dataset.reset(src=filepath3)\n            count = 0\n            for i, item in enumerate(dataset):\n                count += 1\n                if i == 4:\n                    self.assertEqual(item[\"meta_0\"], False)\n            self.assertEqual(count, 5)\n            dataset.close()\n\n            # test multiple CSV files, join tables with kwargs\n            dataset = CSVIterableDataset(filepaths, on=\"subject_id\", shuffle=False)\n            count = 0\n            for item in dataset:\n                count += 1\n                if count == 4:\n                    self.assertDictEqual(\n                        {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in item.items()},\n                        {\n                            \"subject_id\": \"s000003\",\n                            \"label\": 1,\n                            \"image\": \"./imgs/s000003.png\",\n                            \"ehr_0\": 3.3333,\n                            \"ehr_1\": 3.2353,\n                            \"ehr_2\": 3.4000,\n                            \"ehr_3\": 3.1647,\n                            \"ehr_4\": 3.0863,\n                            \"ehr_5\": 3.7255,\n                            \"ehr_6\": 3.6980,\n                            \"ehr_7\": 3.6980,\n                            \"ehr_8\": 3.7020,\n                            \"ehr_9\": 3.3098,\n                            \"ehr_10\": 3.7294,\n                            \"meta_0\": False,\n                            
\"meta_1\": False,\n                            \"meta_2\": True,\n                        },\n                    )\n            self.assertEqual(count, 5)\n            dataset.close()\n\n            # test selected columns and chunk size\n            dataset = CSVIterableDataset(\n                src=filepaths, chunksize=2, col_names=[\"subject_id\", \"image\", \"ehr_1\", \"ehr_7\", \"meta_1\"], shuffle=False\n            )\n            count = 0\n            for item in dataset:\n                count += 1\n                if count == 4:\n                    self.assertDictEqual(\n                        {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in item.items()},\n                        {\n                            \"subject_id\": \"s000003\",\n                            \"image\": \"./imgs/s000003.png\",\n                            \"ehr_1\": 3.2353,\n                            \"ehr_7\": 3.6980,\n                            \"meta_1\": False,\n                        },\n                    )\n            self.assertEqual(count, 5)\n            dataset.close()\n\n            # test group columns\n            dataset = CSVIterableDataset(\n                src=filepaths,\n                col_names=[\"subject_id\", \"image\", *[f\"ehr_{i}\" for i in range(11)], \"meta_0\", \"meta_1\", \"meta_2\"],\n                col_groups={\"ehr\": [f\"ehr_{i}\" for i in range(11)], \"meta12\": [\"meta_1\", \"meta_2\"]},\n                shuffle=False,\n            )\n            count = 0\n            for item in dataset:\n                count += 1\n                if count == 4:\n                    np.testing.assert_allclose(\n                        [round(i, 4) for i in item[\"ehr\"]],\n                        [3.3333, 3.2353, 3.4000, 3.1647, 3.0863, 3.7255, 3.6980, 3.6980, 3.7020, 3.3098, 3.7294],\n                    )\n                    np.testing.assert_allclose(item[\"meta12\"], [False, True])\n            self.assertEqual(count, 
5)\n            dataset.close()\n\n            # test transform\n            dataset = CSVIterableDataset(\n                chunksize=2,\n                buffer_size=4,\n                src=filepaths,\n                col_groups={\"ehr\": [f\"ehr_{i}\" for i in range(5)]},\n                transform=ToNumpyd(keys=\"ehr\"),\n                shuffle=True,\n                seed=123,\n            )\n            expected = [\n                [6.8392, 6.4745, 5.8627, 5.1922, 5.2745],\n                [3.3333, 3.2353, 3.4000, 3.1647, 3.0863],\n                [3.7725, 4.2118, 4.6353, 5.298, 9.5451],\n                [6.4275, 6.2549, 5.9765, 6.2627, 7.7176],\n                [2.0078, 2.2902, 2.0549, 3.0196, 3.8078],\n            ]\n            count = 0\n            for item, exp in zip(dataset, expected):\n                count += 1\n                self.assertTrue(isinstance(item[\"ehr\"], np.ndarray))\n                np.testing.assert_allclose(np.around(item[\"ehr\"], 4), exp)\n            self.assertEqual(count, 5)\n            dataset.close()\n\n            # test multiple processes loading\n            dataset = CSVIterableDataset(filepath1, transform=ToNumpyd(keys=\"label\"), shuffle=False)\n            # set num workers = 0 for mac / win\n            num_workers = 2 if sys.platform == \"linux\" else 0\n            dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=2)\n            count = 0\n            for item in dataloader:\n                count += 1\n                # test the last item which only has 1 data\n                if len(item) == 1:\n                    self.assertListEqual(item[\"subject_id\"], [\"s000002\"])\n                    np.testing.assert_allclose(item[\"label\"], [4])\n                    self.assertListEqual(item[\"image\"], [\"./imgs/s000002.png\"])\n            self.assertEqual(count, 3)\n            dataset.close()\n\n            # test iterable stream\n            iters = pd.read_csv(filepath1, 
chunksize=1000)\n            dataset = CSVIterableDataset(src=iters, shuffle=False)\n            count = 0\n            for item in dataset:\n                count += 1\n                if count == 3:\n                    self.assertDictEqual(\n                        {k: round(v, 4) if not isinstance(v, str) else v for k, v in item.items()},\n                        {\n                            \"subject_id\": \"s000002\",\n                            \"label\": 4,\n                            \"image\": \"./imgs/s000002.png\",\n                            \"ehr_0\": 3.7725,\n                            \"ehr_1\": 4.2118,\n                            \"ehr_2\": 4.6353,\n                        },\n                    )\n                    break\n            self.assertEqual(count, 3)\n            dataset.close()\n\n            # test multiple iterable streams, join tables with kwargs\n            iters = [pd.read_csv(i, chunksize=1000) for i in filepaths]\n            dataset = CSVIterableDataset(src=iters, on=\"subject_id\", shuffle=False)\n            count = 0\n            for item in dataset:\n                count += 1\n                if count == 4:\n                    self.assertEqual(item[\"subject_id\"], \"s000003\")\n                    self.assertEqual(item[\"label\"], 1)\n                    self.assertEqual(round(item[\"ehr_0\"], 4), 3.3333)\n                    self.assertEqual(item[\"meta_0\"], False)\n            self.assertEqual(count, 5)\n            # manually close the pre-loaded iterables instead of `dataset.close()`\n            for i in iters:\n                i.close()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_csv_saver.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport csv\nimport os\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.data import CSVSaver\n\n\nclass TestCSVSaver(unittest.TestCase):\n\n    def test_saved_content(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            saver = CSVSaver(output_dir=tempdir, filename=\"predictions.csv\", delimiter=\"\\t\")\n            meta_data = {\"filename_or_obj\": [\"testfile\" + str(i) for i in range(8)]}\n            saver.save_batch(torch.zeros(8), meta_data)\n            saver.finalize()\n            filepath = os.path.join(tempdir, \"predictions.csv\")\n            self.assertTrue(os.path.exists(filepath))\n            with open(filepath) as f:\n                reader = csv.reader(f, delimiter=\"\\t\")\n                i = 0\n                for row in reader:\n                    self.assertEqual(row[0], \"testfile\" + str(i))\n                    self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0)\n                    i += 1\n                self.assertEqual(i, 8)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_dataloader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import CacheDataset, DataLoader, Dataset, ZipDataset\nfrom monai.transforms import Compose, DataStatsd, Randomizable, SimulateDelayd\nfrom monai.utils import convert_to_numpy, set_determinism\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_1 = [[{\"image\": np.asarray([1, 2, 3])}, {\"image\": np.asarray([4, 5])}]]\n\nTEST_CASE_2 = [[{\"label\": torch.as_tensor([[3], [2]])}, {\"label\": np.asarray([[1], [2]])}]]\n\n\nclass TestDataLoader(unittest.TestCase):\n    def test_values(self):\n        datalist = [\n            {\"image\": \"spleen_19.nii.gz\", \"label\": \"spleen_label_19.nii.gz\"},\n            {\"image\": \"spleen_31.nii.gz\", \"label\": \"spleen_label_31.nii.gz\"},\n        ]\n        transform = Compose(\n            [\n                DataStatsd(keys=[\"image\", \"label\"], data_shape=False, value_range=False, data_value=True),\n                SimulateDelayd(keys=[\"image\", \"label\"], delay_time=0.1),\n            ]\n        )\n        dataset = CacheDataset(data=datalist, transform=transform, cache_rate=0.5, cache_num=1)\n        n_workers = 0 if sys.platform == \"win32\" else 2\n        dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=n_workers)\n        for d in 
dataloader:\n            self.assertEqual(d[\"image\"][0], \"spleen_19.nii.gz\")\n            self.assertEqual(d[\"image\"][1], \"spleen_31.nii.gz\")\n            self.assertEqual(d[\"label\"][0], \"spleen_label_19.nii.gz\")\n            self.assertEqual(d[\"label\"][1], \"spleen_label_31.nii.gz\")\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_exception(self, datalist):\n        dataset = Dataset(data=datalist, transform=None)\n        dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0)\n        with self.assertRaisesRegex((TypeError, RuntimeError), \"Collate error on the key\"):\n            for _ in dataloader:\n                pass\n\n\nclass _RandomDataset(torch.utils.data.Dataset, Randomizable):\n    def __getitem__(self, index):\n        return self.R.randint(0, 1000, (1,))\n\n    def __len__(self):\n        return 8\n\n\nclass TestLoaderRandom(unittest.TestCase):\n    \"\"\"\n    Testing data loader working with the randomizable interface\n    \"\"\"\n\n    def setUp(self):\n        set_determinism(0)\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @parameterized.expand([[1], [0]])\n    def test_randomize(self, workers):\n        set_determinism(0)\n        dataset = _RandomDataset()\n        dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=workers)\n        output = []\n        for _ in range(1):  # need persistent workers for reproducibility of num_workers 0, 1\n            for batch in dataloader:\n                output.extend(batch.data.numpy().flatten().tolist())\n        set_determinism(None)\n        self.assertListEqual(output, [594, 170, 292, 589, 153, 811, 21, 550])\n\n    def test_zipdataset(self):\n        dataset = ZipDataset([_RandomDataset(), ZipDataset([_RandomDataset(), _RandomDataset()])])\n        dataloader = DataLoader(dataset, batch_size=2, num_workers=2)\n        output = []\n        for _ in range(2):\n            for batch in dataloader:\n            
    output.extend([convert_to_numpy(batch, wrap_sequence=False)])\n        assert_allclose(np.stack(output).flatten()[:7], np.array([594, 170, 594, 170, 594, 170, 524]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport tempfile\nimport unittest\nfrom copy import deepcopy\nfrom io import StringIO\n\nimport nibabel as nib\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import Dataset\nfrom monai.transforms import Compose, Lambda, LoadImage, LoadImaged, SimulateDelay, SimulateDelayd\nfrom tests.transforms.compose.test_compose import TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES, data_from_keys\n\nTEST_CASE_1 = [(128, 128, 128)]\n\n\nclass TestDataset(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1])\n    def test_shape(self, expected_shape):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            nib.save(test_image, os.path.join(tempdir, \"test_image1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_label1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_extra1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_image2.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_label2.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_extra2.nii.gz\"))\n            test_data = [\n                {\n                    \"image\": 
os.path.join(tempdir, \"test_image1.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label1.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra1.nii.gz\"),\n                },\n                {\n                    \"image\": os.path.join(tempdir, \"test_image2.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label2.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra2.nii.gz\"),\n                },\n            ]\n            test_transform = Compose(\n                [\n                    LoadImaged(keys=[\"image\", \"label\", \"extra\"]),\n                    SimulateDelayd(keys=[\"image\", \"label\", \"extra\"], delay_time=[1e-7, 1e-6, 1e-5]),\n                ]\n            )\n            dataset = Dataset(data=test_data, transform=test_transform)\n            data1 = dataset[0]\n            data2 = dataset[1]\n\n            self.assertTupleEqual(data1[\"image\"].shape, expected_shape)\n            self.assertTupleEqual(data1[\"label\"].shape, expected_shape)\n            self.assertTupleEqual(data1[\"extra\"].shape, expected_shape)\n            self.assertTupleEqual(data2[\"image\"].shape, expected_shape)\n            self.assertTupleEqual(data2[\"label\"].shape, expected_shape)\n            self.assertTupleEqual(data2[\"extra\"].shape, expected_shape)\n\n            dataset = Dataset(data=test_data, transform=LoadImaged(keys=[\"image\", \"label\", \"extra\"]))\n            data1_simple = dataset[0]\n            data2_simple = dataset[1]\n            data3_simple = dataset[-1]\n            data4_simple = dataset[[0, 1]]\n\n            self.assertTupleEqual(data1_simple[\"image\"].shape, expected_shape)\n            self.assertTupleEqual(data1_simple[\"label\"].shape, expected_shape)\n            self.assertTupleEqual(data1_simple[\"extra\"].shape, expected_shape)\n            self.assertTupleEqual(data2_simple[\"image\"].shape, expected_shape)\n            
self.assertTupleEqual(data2_simple[\"label\"].shape, expected_shape)\n            self.assertTupleEqual(data2_simple[\"extra\"].shape, expected_shape)\n            self.assertTupleEqual(data3_simple[\"image\"].shape, expected_shape)\n            self.assertTupleEqual(data3_simple[\"label\"].shape, expected_shape)\n            self.assertTupleEqual(data3_simple[\"extra\"].shape, expected_shape)\n            self.assertTupleEqual(data4_simple[0][\"image\"].shape, expected_shape)\n            self.assertTupleEqual(data4_simple[1][\"label\"].shape, expected_shape)\n            self.assertTupleEqual(data4_simple[-1][\"extra\"].shape, expected_shape)\n\n            data4_list = dataset[0:1]\n            self.assertEqual(len(data4_list), 1)\n            for d in data4_list:\n                self.assertTupleEqual(d[\"image\"].shape, expected_shape)\n\n    def test_dataset_lazy_on_call(self):\n        data = np.zeros((1, 5, 5))\n        data[0, 0:2, 0:2] = 1\n\n\nclass TestTupleDataset(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1])\n    def test_shape(self, expected_shape):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            nib.save(test_image, os.path.join(tempdir, \"test_image1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_label1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_image2.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_label2.nii.gz\"))\n            test_data = [\n                (os.path.join(tempdir, \"test_image1.nii.gz\"), os.path.join(tempdir, \"test_label1.nii.gz\")),\n                (os.path.join(tempdir, \"test_image2.nii.gz\"), os.path.join(tempdir, \"test_label2.nii.gz\")),\n            ]\n\n            test_transform = Compose([LoadImage(), SimulateDelay(delay_time=1e-5)])\n\n            # Here test_transform is applied element 
by element for the tuple.\n            dataset = Dataset(data=test_data, transform=test_transform)\n            data1 = dataset[0]\n            data2 = dataset[1]\n\n            # Output is a list/tuple\n            self.assertTrue(isinstance(data1, (list, tuple)))\n            self.assertTrue(isinstance(data2, (list, tuple)))\n\n            # Number of elements are 2\n            self.assertEqual(len(data1), 2)\n            self.assertEqual(len(data2), 2)\n\n            # Output shapes are as expected\n            self.assertTupleEqual(data1[0].shape, expected_shape)\n            self.assertTupleEqual(data1[1].shape, expected_shape)\n            self.assertTupleEqual(data2[0].shape, expected_shape)\n            self.assertTupleEqual(data2[1].shape, expected_shape)\n\n            # Here test_transform is applied to the tuple as a whole.\n            test_transform = Compose(\n                [\n                    # LoadImage creates a channel-stacked image when applied to a tuple\n                    LoadImage(),\n                    # Get the channel-stacked image and the label\n                    Lambda(func=lambda x: (x[0].permute(2, 1, 0), x[1])),\n                ],\n                map_items=False,\n            )\n\n            dataset = Dataset(data=test_data, transform=test_transform)\n            data1 = dataset[0]\n            data2 = dataset[1]\n\n            # Output is a list/tuple\n            self.assertTrue(isinstance(data1, (list, tuple)))\n            self.assertTrue(isinstance(data2, (list, tuple)))\n\n            # Number of elements are 2\n            self.assertEqual(len(data1), 2)\n            self.assertEqual(len(data2), 2)\n\n            # Output shapes are as expected\n            self.assertTupleEqual(data1[0].shape, expected_shape)\n            self.assertTupleEqual(data1[1].shape, expected_shape)\n            self.assertTupleEqual(data2[0].shape, expected_shape)\n            self.assertTupleEqual(data2[1].shape, 
expected_shape)\n\n\nclass TestDatsesetWithLazy(unittest.TestCase):\n    LOGGER_NAME = \"a_logger_name\"\n\n    def init_logger(self, name=LOGGER_NAME):\n        stream = StringIO()\n        handler = logging.StreamHandler(stream)\n        formatter = logging.Formatter(\"%(levelname)s - %(message)s\")\n        handler.setFormatter(formatter)\n        logger = logging.getLogger(name)\n        logger.setLevel(logging.INFO)\n        while len(logger.handlers) > 0:\n            logger.removeHandler(logger.handlers[-1])\n        logger.addHandler(handler)\n        return handler, stream\n\n    @parameterized.expand(TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES)\n    def test_dataset_lazy_with_logging(self, compose_type, pipeline, lazy, expected):\n        handler, stream = self.init_logger(name=self.LOGGER_NAME)\n\n        data = data_from_keys(None, 12, 16)\n        c = compose_type(deepcopy(pipeline), log_stats=self.LOGGER_NAME, lazy=lazy)\n        ds = Dataset([data], transform=c)\n        ds[0]\n\n        handler.flush()\n        actual = stream.getvalue()\n        self.assertEqual(actual, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_dataset_func.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport tempfile\nimport unittest\n\nfrom monai.data import Dataset, DatasetFunc, load_decathlon_datalist, partition_dataset\n\n\nclass TestDatasetFunc(unittest.TestCase):\n\n    def test_seg_values(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            # prepare test datalist file\n            test_data = {\n                \"name\": \"Spleen\",\n                \"description\": \"Spleen Segmentation\",\n                \"labels\": {\"0\": \"background\", \"1\": \"spleen\"},\n                \"training\": [\n                    {\"image\": \"spleen_19.nii.gz\", \"label\": \"spleen_19.nii.gz\"},\n                    {\"image\": \"spleen_31.nii.gz\", \"label\": \"spleen_31.nii.gz\"},\n                ],\n                \"test\": [\"spleen_15.nii.gz\", \"spleen_23.nii.gz\"],\n            }\n            json_str = json.dumps(test_data)\n            file_path = os.path.join(tempdir, \"test_data.json\")\n            with open(file_path, \"w\") as json_file:\n                json_file.write(json_str)\n\n            data_list = DatasetFunc(\n                data=file_path, func=load_decathlon_datalist, data_list_key=\"training\", base_dir=tempdir\n            )\n            # partition dataset for train / validation\n            data_partition = DatasetFunc(\n                data=data_list, 
func=lambda x, **kwargs: partition_dataset(x, **kwargs)[0], num_partitions=2\n            )\n            dataset = Dataset(data=data_partition, transform=None)\n            self.assertEqual(dataset[0][\"image\"], os.path.join(tempdir, \"spleen_19.nii.gz\"))\n            self.assertEqual(dataset[0][\"label\"], os.path.join(tempdir, \"spleen_19.nii.gz\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_dataset_summary.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport glob\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\n\nfrom monai.data import Dataset, DatasetSummary, create_test_image_3d\nfrom monai.transforms import LoadImaged\nfrom monai.transforms.compose import Compose\nfrom monai.transforms.utility.dictionary import ToNumpyd\nfrom monai.utils import set_determinism\n\n\ndef test_collate(batch):\n    elem = batch[0]\n    elem_type = type(elem)\n    if isinstance(elem, np.ndarray):\n        return np.stack(batch, 0)\n    elif isinstance(elem, dict):\n        return elem_type({key: test_collate([d[key] for d in batch]) for key in elem})\n\n\nclass TestDatasetSummary(unittest.TestCase):\n\n    def test_spacing_intensity(self):\n        set_determinism(seed=0)\n        with tempfile.TemporaryDirectory() as tempdir:\n            for i in range(5):\n                im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0)\n                n = nib.Nifti1Image(im, np.eye(4))\n                nib.save(n, os.path.join(tempdir, f\"img{i:d}.nii.gz\"))\n                n = nib.Nifti1Image(seg, np.eye(4))\n                nib.save(n, os.path.join(tempdir, f\"seg{i:d}.nii.gz\"))\n\n            train_images = sorted(glob.glob(os.path.join(tempdir, \"img*.nii.gz\")))\n            train_labels = 
sorted(glob.glob(os.path.join(tempdir, \"seg*.nii.gz\")))\n            data_dicts = [\n                {\"image\": image_name, \"label\": label_name} for image_name, label_name in zip(train_images, train_labels)\n            ]\n\n            t = Compose(\n                [\n                    LoadImaged(keys=[\"image\", \"label\"], image_only=False),\n                    ToNumpyd(keys=[\"image\", \"label\", \"image_meta_dict\", \"label_meta_dict\"]),\n                ]\n            )\n            dataset = Dataset(data=data_dicts, transform=t)\n\n            # test **kwargs of `DatasetSummary` for `DataLoader`\n            calculator = DatasetSummary(dataset, num_workers=4, meta_key=\"image_meta_dict\", collate_fn=test_collate)\n\n            target_spacing = calculator.get_target_spacing(spacing_key=\"pixdim\")\n            self.assertEqual(target_spacing, (1.0, 1.0, 1.0))\n            calculator.calculate_statistics()\n            np.testing.assert_allclose(calculator.data_mean, 0.892599, rtol=1e-5, atol=1e-5)\n            np.testing.assert_allclose(calculator.data_std, 0.131731, rtol=1e-5, atol=1e-5)\n            calculator.calculate_percentiles(sampling_flag=True, interval=2)\n            self.assertEqual(calculator.data_max_percentile, 1.0)\n            np.testing.assert_allclose(calculator.data_min_percentile, 0.556411, rtol=1e-5, atol=1e-5)\n\n    def test_anisotropic_spacing(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            pixdims = [[1.0, 1.0, 5.0], [1.0, 1.0, 4.0], [1.0, 1.0, 4.5], [1.0, 1.0, 2.0], [1.0, 1.0, 1.0]]\n            for i in range(5):\n                im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0)\n                n = nib.Nifti1Image(im, np.eye(4))\n                n.header[\"pixdim\"][1:4] = pixdims[i]\n                nib.save(n, os.path.join(tempdir, f\"img{i:d}.nii.gz\"))\n                n = nib.Nifti1Image(seg, np.eye(4))\n                
n.header[\"pixdim\"][1:4] = pixdims[i]\n                nib.save(n, os.path.join(tempdir, f\"seg{i:d}.nii.gz\"))\n\n            train_images = sorted(glob.glob(os.path.join(tempdir, \"img*.nii.gz\")))\n            train_labels = sorted(glob.glob(os.path.join(tempdir, \"seg*.nii.gz\")))\n            data_dicts = [\n                {\"image\": image_name, \"label\": label_name} for image_name, label_name in zip(train_images, train_labels)\n            ]\n\n            t = Compose([LoadImaged(keys=[\"image\", \"label\"])])\n            dataset = Dataset(data=data_dicts, transform=t)\n\n            calculator = DatasetSummary(dataset, num_workers=4)\n\n            target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0)\n            np.testing.assert_allclose(target_spacing, (1.0, 1.0, 1.8))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_fft_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.data.fft_utils import fftn_centered, ifftn_centered\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n#\nim = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]\nres = [\n    [[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [3.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]]\n]\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append((p(im), p(res)))\n\n#\nTESTS_CONSISTENCY = []\nfor p in TEST_NDARRAYS:\n    TESTS_CONSISTENCY.append(p(im))\n\n#\nim_complex = [\n    [[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]]\n]\nTESTS_CONSISTENCY_COMPLEX = []\nfor p in TEST_NDARRAYS:\n    TESTS_CONSISTENCY_COMPLEX.append(p(im_complex))\n\n\nclass TestFFT(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test(self, test_data, res_data):\n        result = fftn_centered(test_data, spatial_dims=2, is_complex=False)\n        assert_allclose(result, res_data, type_test=True)\n\n    @parameterized.expand(TESTS_CONSISTENCY)\n    def test_consistency(self, test_data):\n        result = fftn_centered(test_data, spatial_dims=2, is_complex=False)\n        result = ifftn_centered(result, spatial_dims=2, is_complex=True)\n        result = (result[..., 0] ** 2 + result[..., 1] ** 2) 
** 0.5\n        assert_allclose(result, test_data, type_test=False)\n\n    @parameterized.expand(TESTS_CONSISTENCY_COMPLEX)\n    def test_consistency_complex(self, test_data):\n        result = fftn_centered(test_data, spatial_dims=2)\n        result = ifftn_centered(result, spatial_dims=2)\n        assert_allclose(result, test_data, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_folder_layout.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nfrom parameterized import parameterized\n\nfrom monai.data.folder_layout import FolderLayout\n\nTEST_CASES = [\n    ({\"output_dir\": \"\"}, {}, \"subject\"),\n    ({\"output_dir\": Path(\".\")}, {}, \"subject\"),\n    ({\"output_dir\": Path(\".\")}, {\"idx\": 1}, \"subject_1\"),\n    (dict(output_dir=Path(\"/test_run_1\"), extension=\".seg\", makedirs=False), {}, \"/test_run_1/subject.seg\"),\n    (dict(output_dir=Path(\"/test_run_1\"), extension=None, makedirs=False), {}, \"/test_run_1/subject\"),\n    (\n        dict(output_dir=Path(\"/test_run_1\"), postfix=\"seg\", extension=\".test\", makedirs=False),\n        {},  # using the default subject name\n        \"/test_run_1/subject_seg.test\",\n    ),\n    (\n        dict(output_dir=Path(\"/test_run_1\"), postfix=\"seg\", extension=\".test\", makedirs=False),\n        {\"subject\": \"test.abc\"},\n        \"/test_run_1/test_seg.test\",  # subject's extension is ignored\n    ),\n    (\n        dict(output_dir=Path(\"/test_run_1/dest/test1/\"), data_root_dir=\"/test_run\", makedirs=False),\n        {\"subject\": \"/test_run/source/test.abc\"},\n        \"/test_run_1/dest/test1/source/test\",  # preserves the structure from `subject`\n    ),\n    (\n        dict(output_dir=Path(\"/test_run_1/dest/test1/\"), 
makedirs=False),\n        {\"subject\": \"/test_run/source/test.abc\"},\n        \"/test_run_1/dest/test1/test\",  # data_root_dir used\n    ),\n    (\n        dict(output_dir=Path(\"/test_run_1/dest/test1/\"), makedirs=False),\n        {\"subject\": \"/test_run/source/test.abc\", \"key\": \"value\"},\n        \"/test_run_1/dest/test1/test_key-value\",  # data_root_dir used\n    ),\n    (\n        dict(output_dir=Path(\"/test_run_1/\"), postfix=\"seg\", extension=\".nii\", makedirs=False),\n        dict(subject=Path(\"Sub-A\"), idx=\"00\", modality=\"T1\"),\n        \"/test_run_1/Sub-A_seg_00_modality-T1.nii\",  # test the code example\n    ),\n]\n\n\nclass TestFolderLayout(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_value(self, con_params, f_params, expected):\n        fname = FolderLayout(**con_params).filename(**f_params)\n        self.assertEqual(Path(fname), Path(expected))\n\n    def test_mkdir(self):\n        \"\"\"mkdir=True should create the directory if it does not exist.\"\"\"\n        with tempfile.TemporaryDirectory() as tempdir:\n            output_tmp = os.path.join(tempdir, \"output\")\n            FolderLayout(output_tmp, makedirs=True).filename(\"subject_test\", \"001\")\n            self.assertTrue(os.path.exists(os.path.join(output_tmp)))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_gdsdataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import GDSDataset, json_hashing\nfrom monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform\nfrom monai.utils import optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose, skip_if_no_cuda\n\n_, has_cp = optional_import(\"cupy\")\nnib, has_nib = optional_import(\"nibabel\")\n_, has_kvikio_numpy = optional_import(\"kvikio.numpy\")\n\nTEST_CASE_1 = [\n    Compose(\n        [\n            LoadImaged(keys=[\"image\", \"label\", \"extra\"], image_only=True),\n            SimulateDelayd(keys=[\"image\", \"label\", \"extra\"], delay_time=[1e-7, 1e-6, 1e-5]),\n        ]\n    ),\n    (128, 128, 128),\n]\n\nTEST_CASE_2 = [\n    [\n        LoadImaged(keys=[\"image\", \"label\", \"extra\"], image_only=True),\n        SimulateDelayd(keys=[\"image\", \"label\", \"extra\"], delay_time=[1e-7, 1e-6, 1e-5]),\n    ],\n    (128, 128, 128),\n]\n\nTEST_CASE_3 = [None, (128, 128, 128)]\n\nDTYPES = {\n    np.dtype(np.uint8): torch.uint8,\n    np.dtype(np.int8): torch.int8,\n    np.dtype(np.int16): torch.int16,\n    np.dtype(np.int32): torch.int32,\n    np.dtype(np.int64): torch.int64,\n    np.dtype(np.float16): torch.float16,\n    np.dtype(np.float32): 
torch.float32,\n    np.dtype(np.float64): torch.float64,\n    np.dtype(np.complex64): torch.complex64,\n    np.dtype(np.complex128): torch.complex128,\n}\n\n\nclass _InplaceXform(Transform):\n    def __call__(self, data):\n        data[0] = data[0] + 1\n        return data\n\n\n@skip_if_no_cuda\n@unittest.skipUnless(has_cp, \"Requires CuPy library.\")\n@unittest.skipUnless(has_cp and has_kvikio_numpy, \"Requires CuPy and kvikio library.\")\nclass TestDataset(unittest.TestCase):\n    def test_cache(self):\n        \"\"\"testing no inplace change to the hashed item\"\"\"\n        for p in TEST_NDARRAYS[:2]:\n            shape = (1, 10, 9, 8)\n            items = [p(np.arange(0, np.prod(shape)).reshape(shape))]\n\n            with tempfile.TemporaryDirectory() as tempdir:\n                ds = GDSDataset(\n                    data=items,\n                    transform=_InplaceXform(),\n                    cache_dir=tempdir,\n                    device=0,\n                    pickle_module=\"pickle\",\n                    # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility\n                    pickle_protocol=torch.serialization.DEFAULT_PROTOCOL,\n                )\n                assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape)))\n                ds1 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, device=0)\n                assert_allclose(ds[0], ds1[0], type_test=False)\n                assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape)))\n\n                ds = GDSDataset(\n                    items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0\n                )\n                assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape)))\n                ds1 = GDSDataset(\n                    items, transform=_InplaceXform(), cache_dir=tempdir, hash_transform=json_hashing, device=0\n            
    )\n                assert_allclose(ds[0], ds1[0], type_test=False)\n                assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape)))\n\n    def test_metatensor(self):\n        shape = (1, 10, 9, 8)\n        items = [TEST_NDARRAYS[-1](np.arange(0, np.prod(shape)).reshape(shape))]\n        with tempfile.TemporaryDirectory() as tempdir:\n            ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)\n            assert_allclose(ds[0], ds[0][0], type_test=False)\n\n    def test_dtype(self):\n        shape = (1, 10, 9, 8)\n        data = np.arange(0, np.prod(shape)).reshape(shape)\n        for _dtype in DTYPES.keys():\n            items = [np.array(data).astype(_dtype)]\n            with tempfile.TemporaryDirectory() as tempdir:\n                ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)\n                ds1 = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)\n                self.assertEqual(ds[0].dtype, _dtype)\n                self.assertEqual(ds1[0].dtype, DTYPES[_dtype])\n\n        for _dtype in DTYPES.keys():\n            items = [torch.tensor(data, dtype=DTYPES[_dtype])]\n            with tempfile.TemporaryDirectory() as tempdir:\n                ds = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)\n                ds1 = GDSDataset(data=items, transform=_InplaceXform(), cache_dir=tempdir, device=0)\n                self.assertEqual(ds[0].dtype, DTYPES[_dtype])\n                self.assertEqual(ds1[0].dtype, DTYPES[_dtype])\n\n    @unittest.skipUnless(has_nib, \"Requires nibabel package.\")\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_shape(self, transform, expected_shape):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            
nib.save(test_image, os.path.join(tempdir, \"test_image1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_label1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_extra1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_image2.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_label2.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_extra2.nii.gz\"))\n            test_data = [\n                {\n                    \"image\": os.path.join(tempdir, \"test_image1.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label1.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra1.nii.gz\"),\n                },\n                {\n                    \"image\": os.path.join(tempdir, \"test_image2.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label2.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra2.nii.gz\"),\n                },\n            ]\n\n            cache_dir = os.path.join(os.path.join(tempdir, \"cache\"), \"data\")\n            dataset_precached = GDSDataset(data=test_data, transform=transform, cache_dir=cache_dir, device=0)\n            data1_precached = dataset_precached[0]\n            data2_precached = dataset_precached[1]\n\n            dataset_postcached = GDSDataset(data=test_data, transform=transform, cache_dir=cache_dir, device=0)\n            data1_postcached = dataset_postcached[0]\n            data2_postcached = dataset_postcached[1]\n            data3_postcached = dataset_postcached[0:2]\n\n            if transform is None:\n                self.assertEqual(data1_precached[\"image\"], os.path.join(tempdir, \"test_image1.nii.gz\"))\n                self.assertEqual(data2_precached[\"label\"], os.path.join(tempdir, \"test_label2.nii.gz\"))\n                self.assertEqual(data1_postcached[\"image\"], os.path.join(tempdir, 
\"test_image1.nii.gz\"))\n                self.assertEqual(data2_postcached[\"extra\"], os.path.join(tempdir, \"test_extra2.nii.gz\"))\n            else:\n                self.assertTupleEqual(data1_precached[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data1_precached[\"label\"].shape, expected_shape)\n                self.assertTupleEqual(data1_precached[\"extra\"].shape, expected_shape)\n                self.assertTupleEqual(data2_precached[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data2_precached[\"label\"].shape, expected_shape)\n                self.assertTupleEqual(data2_precached[\"extra\"].shape, expected_shape)\n\n                self.assertTupleEqual(data1_postcached[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data1_postcached[\"label\"].shape, expected_shape)\n                self.assertTupleEqual(data1_postcached[\"extra\"].shape, expected_shape)\n                self.assertTupleEqual(data2_postcached[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data2_postcached[\"label\"].shape, expected_shape)\n                self.assertTupleEqual(data2_postcached[\"extra\"].shape, expected_shape)\n                for d in data3_postcached:\n                    self.assertTupleEqual(d[\"image\"].shape, expected_shape)\n\n            # update the data to cache\n            test_data_new = [\n                {\n                    \"image\": os.path.join(tempdir, \"test_image1_new.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label1_new.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra1_new.nii.gz\"),\n                },\n                {\n                    \"image\": os.path.join(tempdir, \"test_image2_new.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label2_new.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra2_new.nii.gz\"),\n                
},\n            ]\n            dataset_postcached.set_data(data=test_data_new)\n            # test new exchanged cache content\n            if transform is None:\n                self.assertEqual(dataset_postcached[0][\"image\"], os.path.join(tempdir, \"test_image1_new.nii.gz\"))\n                self.assertEqual(dataset_postcached[0][\"label\"], os.path.join(tempdir, \"test_label1_new.nii.gz\"))\n                self.assertEqual(dataset_postcached[1][\"extra\"], os.path.join(tempdir, \"test_extra2_new.nii.gz\"))\n\n    def test_different_transforms(self):\n        \"\"\"\n        Different instances of `GDSDataset` with the same cache_dir,\n        same input data, but different transforms should give different results.\n        \"\"\"\n        shape = (1, 10, 9, 8)\n        im = np.arange(0, np.prod(shape)).reshape(shape)\n        with tempfile.TemporaryDirectory() as path:\n            im1 = GDSDataset([im], Identity(), cache_dir=path, hash_transform=json_hashing, device=0)[0]\n            im2 = GDSDataset([im], Flip(1), cache_dir=path, hash_transform=json_hashing, device=0)[0]\n            l2 = ((im1 - im2) ** 2).sum() ** 0.5\n            self.assertTrue(l2 > 1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_grid_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import DataLoader, GridPatchDataset, PatchIter, PatchIterd, iter_patch\nfrom monai.transforms import RandShiftIntensity, RandShiftIntensityd\nfrom monai.utils import set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose, get_arange_img\n\n\ndef identity_generator(x):\n    # simple transform that returns the input itself\n    for idx, item in enumerate(x):\n        yield item, idx\n\n\nTEST_CASES_ITER_PATCH = []\nfor p in TEST_NDARRAYS:\n    TEST_CASES_ITER_PATCH.append([p, True])\n    TEST_CASES_ITER_PATCH.append([p, False])\n\nA = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1)\nA11 = A[:, :2, :2]\nA12 = A[:, :2, 2:]\nA21 = A[:, 2:, :2]\nA22 = A[:, 2:, 2:]\nCOORD11 = [[0, 3], [0, 2], [0, 2]]\nCOORD12 = [[0, 3], [0, 2], [2, 4]]\nCOORD21 = [[0, 3], [2, 4], [0, 2]]\nCOORD22 = [[0, 3], [2, 4], [2, 4]]\n\nTEST_CASE_0 = [{\"patch_size\": (2, 2)}, A, [A11, A12, A21, A22], np.array([COORD11, COORD12, COORD21, COORD22])]\nTEST_CASE_1 = [{\"patch_size\": (2, 2), \"start_pos\": (0, 2, 2)}, A, [A22], np.array([COORD22])]\nTEST_CASE_2 = [{\"patch_size\": (2, 2), \"start_pos\": (0, 0, 2)}, A, [A12, A22], np.array([COORD12, COORD22])]\nTEST_CASE_3 = [{\"patch_size\": (2, 2), \"start_pos\": (0, 2, 0)}, A, 
[A21, A22], np.array([COORD21, COORD22])]\n\nTEST_CASES_PATCH_ITER = []\nfor p in TEST_NDARRAYS:\n    TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_0])\n    TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_1])\n    TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_2])\n    TEST_CASES_PATCH_ITER.append([p, *TEST_CASE_3])\n\n\nclass TestGridPatchDataset(unittest.TestCase):\n    def setUp(self):\n        set_determinism(seed=1234)\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @parameterized.expand(TEST_CASES_ITER_PATCH)\n    def test_iter_patch(self, in_type, cb):\n        shape = (10, 30, 30)\n        input_img = in_type(get_arange_img(shape))\n        for p, _ in iter_patch(input_img, patch_size=(None, 10, 30, None), copy_back=cb):\n            p += 1.0\n            assert_allclose(p, in_type(get_arange_img(shape)) + 1.0, type_test=True, device_test=True)\n        assert_allclose(\n            input_img, in_type(get_arange_img(shape)) + (1.0 if cb else 0.0), type_test=True, device_test=True\n        )\n\n    @parameterized.expand(TEST_CASES_PATCH_ITER)\n    def test_patch_iter(self, in_type, input_parameters, image, expected, coords):\n        input_image = in_type(image)\n        patch_iterator = PatchIter(**input_parameters)(input_image)\n        for (result_image, result_loc), expected_patch, coord in zip(patch_iterator, expected, coords):\n            assert_allclose(result_image, in_type(expected_patch), type_test=True, device_test=True)\n            assert_allclose(result_loc, coord, type_test=True, device_test=True)\n\n    @parameterized.expand(TEST_CASES_PATCH_ITER)\n    def test_patch_iterd(self, in_type, input_parameters, image, expected, coords):\n        image_key = \"image\"\n        input_dict = {image_key: in_type(image)}\n        patch_iterator = PatchIterd(keys=image_key, **input_parameters)(input_dict)\n        for (result_image_dict, result_loc), expected_patch, coord in zip(patch_iterator, expected, coords):\n            
assert_allclose(result_image_dict[image_key], in_type(expected_patch), type_test=True, device_test=True)\n            assert_allclose(result_loc, coord, type_test=True, device_test=True)\n\n    def test_shape(self):\n        # test Iterable input data\n        test_dataset = iter([\"vwxyz\", \"helloworld\", \"worldfoobar\"])\n        result = GridPatchDataset(data=test_dataset, patch_iter=identity_generator, with_coordinates=False)\n        output = []\n        n_workers = 0 if sys.platform == \"win32\" else 2\n        for item in DataLoader(result, batch_size=3, num_workers=n_workers):\n            output.append(\"\".join(item))\n        if sys.platform == \"win32\":\n            expected = [\"ar\", \"ell\", \"ldf\", \"oob\", \"owo\", \"rld\", \"vwx\", \"wor\", \"yzh\"]\n        else:\n            expected = [\"d\", \"dfo\", \"hel\", \"low\", \"oba\", \"orl\", \"orl\", \"r\", \"vwx\", \"yzw\"]\n            self.assertEqual(len(\"\".join(expected)), len(\"\".join(list(test_dataset))))\n        self.assertEqual(sorted(output), sorted(expected))\n\n    def test_loading_array(self):\n        # test sequence input data with images\n        images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)]\n        # image level\n        patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0).set_random_state(seed=1234)\n        patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))\n        ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity)\n        # use the grid patch dataset\n        for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=0):\n            np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))\n        np.testing.assert_allclose(\n            item[0],\n            np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]),\n            rtol=1e-4,\n        )\n        np.testing.assert_allclose(item[1], 
np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)\n        if sys.platform != \"win32\":\n            for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=2):\n                np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))\n            np.testing.assert_allclose(\n                item[0],\n                np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]),\n                rtol=1e-3,\n            )\n            np.testing.assert_allclose(\n                item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5\n            )\n\n    def test_loading_dict(self):\n        set_determinism(seed=1234)\n        # test sequence input data with dict\n        data = [\n            {\n                \"image\": np.arange(16, dtype=float).reshape(1, 4, 4),\n                \"label\": np.arange(16, dtype=float).reshape(1, 4, 4),\n                \"metadata\": \"test string\",\n            },\n            {\n                \"image\": np.arange(16, dtype=float).reshape(1, 4, 4),\n                \"label\": np.arange(16, dtype=float).reshape(1, 4, 4),\n                \"metadata\": \"test string\",\n            },\n        ]\n        # image level\n        patch_intensity = RandShiftIntensityd(keys=\"image\", offsets=1.0, prob=1.0)\n        patch_iter = PatchIterd(keys=[\"image\", \"label\"], patch_size=(2, 2), start_pos=(0, 0))\n        ds = GridPatchDataset(data=data, patch_iter=patch_iter, transform=patch_intensity, with_coordinates=True)\n        # use the grid patch dataset\n        for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=0):\n            np.testing.assert_equal(item[0][\"image\"].shape, (2, 1, 2, 2))\n            np.testing.assert_equal(item[0][\"label\"].shape, (2, 1, 2, 2))\n            self.assertListEqual(item[0][\"metadata\"], [\"test string\", \"test string\"])\n        np.testing.assert_allclose(\n            
item[0][\"image\"],\n            np.array([[[[8.708934, 9.708934], [12.708934, 13.708934]]], [[[10.8683, 11.8683], [14.8683, 15.8683]]]]),\n            rtol=1e-4,\n        )\n        np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)\n        if sys.platform != \"win32\":\n            for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=2):\n                np.testing.assert_equal(item[0][\"image\"].shape, (2, 1, 2, 2))\n            np.testing.assert_allclose(\n                item[0][\"image\"],\n                np.array([[[[7.27427, 8.27427], [11.27427, 12.27427]]], [[[9.4353, 10.4353], [13.4353, 14.4353]]]]),\n                rtol=1e-3,\n            )\n            np.testing.assert_allclose(\n                item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5\n            )\n\n    def test_set_data(self):\n        from monai.transforms import Compose, Lambda, RandLambda\n\n        images = [np.arange(2, 18, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)]\n\n        transform = Compose(\n            [Lambda(func=lambda x: np.array(x * 10)), RandLambda(func=lambda x: x + 1)], map_items=False\n        )\n        patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))\n        dataset = GridPatchDataset(\n            data=images,\n            patch_iter=patch_iter,\n            transform=transform,\n            cache=True,\n            cache_rate=1.0,\n            copy_cache=not sys.platform == \"linux\",\n        )\n\n        num_workers = 2 if sys.platform == \"linux\" else 0\n        for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):\n            np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))\n        np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4)\n        np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 
4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)\n        # simulate another epoch, the cache content should not be modified\n        for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):\n            np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))\n        np.testing.assert_allclose(item[0], np.array([[[[81, 91], [121, 131]]], [[[101, 111], [141, 151]]]]), rtol=1e-4)\n        np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)\n\n        # update the datalist and fill the cache content\n        data_list2 = [np.arange(1, 17, dtype=float).reshape(1, 4, 4)]\n        dataset.set_data(data=data_list2)\n        # rerun with updated cache content\n        for item in DataLoader(dataset, batch_size=2, shuffle=False, num_workers=num_workers):\n            np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))\n        np.testing.assert_allclose(\n            item[0], np.array([[[[91, 101], [131, 141]]], [[[111, 121], [151, 161]]]]), rtol=1e-4\n        )\n        np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_handler_smartcache.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport unittest\n\nimport torch\nfrom ignite.engine import Engine\n\nfrom monai.data import SmartCacheDataset\nfrom monai.handlers import SmartCacheHandler\n\n\nclass TestHandlerSmartCache(unittest.TestCase):\n\n    def test_content(self):\n        data = [0, 1, 2, 3, 4, 5, 6, 7, 8]\n        expected = [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7], [4, 5, 6, 7, 8]]\n\n        # set up engine\n        def _train_func(engine, batch):\n            self.assertListEqual(batch.tolist(), expected[engine.state.epoch - 1])\n\n        engine = Engine(_train_func)\n\n        # set up testing handler\n        dataset = SmartCacheDataset(data, transform=None, replace_rate=0.2, cache_num=5, shuffle=False)\n        workers = 2 if sys.platform == \"linux\" else 0\n        data_loader = torch.utils.data.DataLoader(dataset, batch_size=5, num_workers=workers, persistent_workers=False)\n        SmartCacheHandler(dataset).attach(engine)\n\n        engine.run(data_loader, max_epochs=5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_hashing.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.data import json_hashing, pickle_hashing\nfrom monai.utils import set_determinism\n\n\nclass TestPickleHashing(unittest.TestCase):\n\n    def test_pickle(self):\n        set_determinism(0)\n        data1 = np.random.rand(10)\n        data2 = np.random.rand(10)\n        set_determinism(0)\n        data3 = np.random.rand(10)\n        data4 = np.random.rand(10)\n        set_determinism(None)\n\n        h1 = pickle_hashing(data1)\n        h2 = pickle_hashing(data3)\n        self.assertEqual(h1, h2)\n\n        data_dict1 = {\"b\": data2, \"a\": data1}\n        data_dict2 = {\"a\": data3, \"b\": data4}\n\n        h1 = pickle_hashing(data_dict1)\n        h2 = pickle_hashing(data_dict2)\n        self.assertEqual(h1, h2)\n\n        with self.assertRaises(TypeError):\n            json_hashing(data_dict1)\n\n\nclass TestJSONHashing(unittest.TestCase):\n\n    def test_json(self):\n        data_dict1 = {\"b\": \"str2\", \"a\": \"str1\"}\n        data_dict2 = {\"a\": \"str1\", \"b\": \"str2\"}\n\n        h1 = json_hashing(data_dict1)\n        h2 = json_hashing(data_dict2)\n        self.assertEqual(h1, h2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_header_correct.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\n\nfrom monai.data import correct_nifti_header_if_necessary\n\n\nclass TestCorrection(unittest.TestCase):\n\n    def test_correct(self):\n        test_img = nib.Nifti1Image(np.zeros((1, 2, 3)), np.eye(4))\n        test_img.header.set_zooms((100, 100, 100))\n        test_img = correct_nifti_header_if_necessary(test_img)\n        np.testing.assert_allclose(\n            test_img.affine,\n            np.array([[100.0, 0.0, 0.0, 0.0], [0.0, 100.0, 0.0, 0.0], [0.0, 0.0, 100.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),\n        )\n\n    def test_affine(self):\n        test_img = nib.Nifti1Image(np.zeros((1, 2, 3)), np.eye(4) * 20.0)\n        test_img = correct_nifti_header_if_necessary(test_img)\n        np.testing.assert_allclose(\n            test_img.affine,\n            np.array([[20.0, 0.0, 0.0, 0.0], [0.0, 20.0, 0.0, 0.0], [0.0, 0.0, 20.0, 0.0], [0.0, 0.0, 0.0, 20.0]]),\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_image_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch\n\nfrom monai.data import ImageDataset\nfrom monai.transforms import (\n    Compose,\n    EnsureChannelFirst,\n    MapLabelValue,\n    RandAdjustContrast,\n    RandomizableTransform,\n    Spacing,\n)\nfrom monai.transforms.utility.array import ToNumpy\n\nFILENAMES = [\"test1.nii.gz\", \"test2.nii\", \"test3.nii.gz\"]\n\n\nclass RandTest(RandomizableTransform):\n    \"\"\"\n    randomisable transform for testing.\n    \"\"\"\n\n    def randomize(self, data=None):\n        self._a = self.R.random()\n\n    def __call__(self, data):\n        self.randomize()\n        return data + self._a\n\n\nclass _TestCompose(Compose):\n\n    def __call__(self, data, meta, lazy):\n        data = self.transforms[0](data)  # ensure channel first\n        data = self.transforms[1](data, lazy=lazy)  # spacing\n        meta = data.meta\n        if len(self.transforms) == 3:\n            return self.transforms[2](data), meta  # image contrast\n        return data, meta\n\n\nclass TestImageDataset(unittest.TestCase):\n\n    def test_use_case(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            img_ = nib.Nifti1Image(np.random.randint(0, 2, size=(20, 20, 20)).astype(float), np.eye(4))\n            seg_ = 
nib.Nifti1Image(np.random.randint(0, 2, size=(20, 20, 20)).astype(float), np.eye(4))\n            img_name, seg_name = os.path.join(tempdir, \"img.nii.gz\"), os.path.join(tempdir, \"seg.nii.gz\")\n            nib.save(img_, img_name)\n            nib.save(seg_, seg_name)\n            img_list, seg_list = [img_name], [seg_name]\n\n            img_xform = _TestCompose([EnsureChannelFirst(), Spacing(pixdim=(1.5, 1.5, 3.0)), RandAdjustContrast()])\n            seg_xform = _TestCompose([EnsureChannelFirst(), Spacing(pixdim=(1.5, 1.5, 3.0), mode=\"nearest\")])\n            img_dataset = ImageDataset(\n                image_files=img_list,\n                seg_files=seg_list,\n                transform=img_xform,\n                seg_transform=seg_xform,\n                image_only=False,\n                transform_with_metadata=True,\n            )\n            self.assertTupleEqual(img_dataset[0][0].shape, (1, 14, 14, 7))\n            self.assertTupleEqual(img_dataset[0][1].shape, (1, 14, 14, 7))\n\n    def test_dataset(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            full_names, ref_data = [], []\n            for filename in FILENAMES:\n                test_image = np.random.randint(0, 2, size=(4, 4, 4)).astype(float)\n                ref_data.append(test_image)\n                save_path = os.path.join(tempdir, filename)\n                full_names.append(save_path)\n                nib.save(nib.Nifti1Image(test_image, np.eye(4)), save_path)\n\n            # default loading no meta\n            dataset = ImageDataset(full_names)\n            for d, ref in zip(dataset, ref_data):\n                np.testing.assert_allclose(d, ref, atol=1e-3)\n\n            # loading no meta, int\n            dataset = ImageDataset(full_names, dtype=np.float16)\n            for d, _ in zip(dataset, ref_data):\n                self.assertEqual(d.dtype, torch.float16)\n\n            # loading with meta, no transform\n            dataset = 
ImageDataset(full_names, image_only=False)\n            for d_tuple, ref in zip(dataset, ref_data):\n                d, meta = d_tuple\n                np.testing.assert_allclose(d, ref, atol=1e-3)\n                np.testing.assert_allclose(meta[\"original_affine\"], np.eye(4))\n\n            # loading image/label, no meta\n            dataset = ImageDataset(full_names, seg_files=full_names, image_only=True)\n            for d_tuple, ref in zip(dataset, ref_data):\n                img, seg = d_tuple\n                np.testing.assert_allclose(img, ref, atol=1e-3)\n                np.testing.assert_allclose(seg, ref, atol=1e-3)\n\n            # loading image/label, no meta\n            dataset = ImageDataset(full_names, transform=lambda x: x + 1, image_only=True)\n            for d, ref in zip(dataset, ref_data):\n                np.testing.assert_allclose(d, ref + 1, atol=1e-3)\n\n            # loading image/label, with meta\n            dataset = ImageDataset(\n                full_names,\n                transform=lambda x: x + 1,\n                seg_files=full_names,\n                seg_transform=lambda x: x + 2,\n                image_only=False,\n            )\n            for d_tuple, ref in zip(dataset, ref_data):\n                img, seg, meta, seg_meta = d_tuple\n                np.testing.assert_allclose(img, ref + 1, atol=1e-3)\n                np.testing.assert_allclose(seg, ref + 2, atol=1e-3)\n                np.testing.assert_allclose(meta[\"original_affine\"], np.eye(4), atol=1e-3)\n                np.testing.assert_allclose(seg_meta[\"original_affine\"], np.eye(4), atol=1e-3)\n\n            # loading image/label, with meta\n            dataset = ImageDataset(\n                image_files=full_names,\n                seg_files=full_names,\n                labels=[1, 2, 3],\n                transform=lambda x: x + 1,\n                label_transform=Compose(\n                    [\n                        ToNumpy(),\n                        
MapLabelValue(orig_labels=[1, 2, 3], target_labels=[30.0, 20.0, 10.0], dtype=np.float32),\n                    ]\n                ),\n                image_only=False,\n            )\n            for idx, (d_tuple, ref) in enumerate(zip(dataset, ref_data)):\n                img, seg, label, meta, seg_meta = d_tuple\n                np.testing.assert_allclose(img, ref + 1, atol=1e-3)\n                np.testing.assert_allclose(seg, ref, atol=1e-3)\n                # test label_transform\n\n                np.testing.assert_allclose((3 - idx) * 10.0, label)\n                self.assertTrue(isinstance(label, np.ndarray))\n                self.assertEqual(label.dtype, np.float32)\n                np.testing.assert_allclose(meta[\"original_affine\"], np.eye(4), atol=1e-3)\n                np.testing.assert_allclose(seg_meta[\"original_affine\"], np.eye(4), atol=1e-3)\n\n            # loading image/label, with sync. transform\n            dataset = ImageDataset(\n                full_names, transform=RandTest(), seg_files=full_names, seg_transform=RandTest(), image_only=False\n            )\n            for d_tuple, ref in zip(dataset, ref_data):\n                img, seg, meta, seg_meta = d_tuple\n                np.testing.assert_allclose(img, seg, atol=1e-3)\n                self.assertTrue(not np.allclose(img, ref))\n                np.testing.assert_allclose(meta[\"original_affine\"], np.eye(4), atol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_image_rw.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport itertools\nimport os\nimport shutil\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.image_reader import ITKReader, NibabelReader, NrrdReader, PILReader\nfrom monai.data.image_writer import ITKWriter, NibabelWriter, PILWriter, register_writer, resolve_writer\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import LoadImage, SaveImage, moveaxis\nfrom monai.utils import MetaKeys, OptionalImportError, optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_itk = optional_import(\"itk\", allow_namespace_pkg=True)\n\n\n@unittest.skipUnless(has_itk, \"itk not installed\")\nclass TestLoadSaveNifti(unittest.TestCase):\n    def setUp(self):\n        self.test_dir = tempfile.mkdtemp()\n\n    def tearDown(self):\n        shutil.rmtree(self.test_dir, ignore_errors=True)\n\n    def nifti_rw(self, test_data, reader, writer, dtype, resample=True):\n        test_data = test_data.astype(dtype)\n        ndim = len(test_data.shape) - 1\n        for p in TEST_NDARRAYS:\n            output_ext = \".nii.gz\"\n            filepath = f\"testfile_{ndim}d\"\n            saver = SaveImage(\n                output_dir=self.test_dir,\n                output_ext=output_ext,\n                output_dtype=None,\n                
resample=resample,\n                separate_folder=False,\n                writer=writer,\n            )\n            meta_dict = {\n                \"filename_or_obj\": f\"{filepath}.png\",\n                \"affine\": np.eye(4),\n                \"original_affine\": np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),\n            }\n            test_data = MetaTensor(p(test_data), meta=meta_dict)\n            self.assertEqual(test_data.meta[MetaKeys.SPACE], \"RAS\")\n            saver(test_data)\n            saved_path = os.path.join(self.test_dir, filepath + \"_trans\" + output_ext)\n            self.assertTrue(os.path.exists(saved_path))\n            loader = LoadImage(image_only=True, reader=reader, squeeze_non_spatial_dims=True, dtype=None)\n            data = loader(saved_path)\n            self.assertIn(dtype.__name__, str(data.dtype))\n            meta = data.meta\n            if meta[\"original_channel_dim\"] == -1:\n                _test_data = moveaxis(test_data, 0, -1)\n            else:\n                _test_data = test_data[0]\n            if resample:\n                _test_data = moveaxis(_test_data, 0, 1)\n            assert_allclose(meta[\"qform_code\"], 1, type_test=False)\n            assert_allclose(meta[\"sform_code\"], 1, type_test=False)\n            assert_allclose(data, torch.as_tensor(_test_data))\n\n    @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, \"ITKWriter\"]))\n    def test_2d(self, reader, writer):\n        test_data = np.arange(48, dtype=np.uint8).reshape(1, 6, 8)\n        self.nifti_rw(test_data, reader, writer, np.uint8)\n        self.nifti_rw(test_data, reader, writer, np.float32)\n\n    @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, ITKWriter]))\n    def test_3d(self, reader, writer):\n        test_data = np.arange(48, dtype=np.uint8).reshape(1, 2, 3, 8)\n        self.nifti_rw(test_data, reader, writer, np.int16)\n        
self.nifti_rw(test_data, reader, writer, float, False)\n\n    @parameterized.expand(itertools.product([NibabelReader, ITKReader], [\"NibabelWriter\", ITKWriter]))\n    def test_4d(self, reader, writer):\n        test_data = np.arange(48, dtype=np.uint8).reshape(2, 1, 3, 8)\n        self.nifti_rw(test_data, reader, writer, np.float64)\n\n\n@unittest.skipUnless(has_itk, \"itk not installed\")\nclass TestLoadSavePNG(unittest.TestCase):\n    def setUp(self):\n        self.test_dir = tempfile.mkdtemp()\n\n    def tearDown(self):\n        shutil.rmtree(self.test_dir, ignore_errors=True)\n\n    def png_rw(self, test_data, reader, writer, dtype, resample=True):\n        test_data = test_data.astype(dtype)\n        ndim = len(test_data.shape) - 1\n        for p in TEST_NDARRAYS:\n            output_ext = \".png\"\n            filepath = f\"testfile_{ndim}d\"\n            saver = SaveImage(\n                output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer\n            )\n            test_data = MetaTensor(p(test_data), meta={\"filename_or_obj\": f\"{filepath}.png\", \"spatial_shape\": (6, 8)})\n            saver(test_data)\n            saved_path = os.path.join(self.test_dir, filepath + \"_trans\" + output_ext)\n            self.assertTrue(os.path.exists(saved_path))\n            loader = LoadImage(image_only=True, reader=reader)\n            data = loader(saved_path)\n            meta = data.meta\n            if meta[\"original_channel_dim\"] == -1:\n                _test_data = moveaxis(test_data, 0, -1)\n            else:\n                _test_data = test_data[0]\n            assert_allclose(data, torch.as_tensor(_test_data))\n\n    @parameterized.expand(itertools.product([PILReader, ITKReader], [PILWriter, ITKWriter]))\n    def test_2d(self, reader, writer):\n        test_data = np.arange(48, dtype=np.uint8).reshape(1, 6, 8)\n        self.png_rw(test_data, reader, writer, np.uint8)\n\n    
@parameterized.expand(itertools.product([PILReader, ITKReader], [\"monai.data.PILWriter\", ITKWriter]))\n    def test_rgb(self, reader, writer):\n        test_data = np.arange(48, dtype=np.uint8).reshape(3, 2, 8)\n        self.png_rw(test_data, reader, writer, np.uint8, False)\n\n\nclass TestRegRes(unittest.TestCase):\n    def test_0_default(self):\n        self.assertTrue(len(resolve_writer(\".png\")) > 0, \"has png writer\")\n        self.assertTrue(len(resolve_writer(\".nrrd\")) > 0, \"has nrrd writer\")\n        self.assertTrue(len(resolve_writer(\"unknown\")) > 0, \"has writer\")\n        register_writer(\"unknown1\", lambda: (_ for _ in ()).throw(OptionalImportError))\n        with self.assertRaises(OptionalImportError):\n            resolve_writer(\"unknown1\")\n\n    def test_1_new(self):\n        register_writer(\"new\", lambda x: x + 1)\n        register_writer(\"new2\", lambda x: x + 1)\n        self.assertEqual(resolve_writer(\"new\")[0](0), 1)\n\n\n@unittest.skipUnless(has_itk, \"itk not installed\")\nclass TestLoadSaveNrrd(unittest.TestCase):\n    def setUp(self):\n        self.test_dir = tempfile.mkdtemp()\n\n    def tearDown(self):\n        shutil.rmtree(self.test_dir, ignore_errors=True)\n\n    def nrrd_rw(self, test_data, reader, writer, dtype, resample=True):\n        test_data = test_data.astype(dtype)\n        ndim = len(test_data.shape)\n        for p in TEST_NDARRAYS:\n            output_ext = \".nrrd\"\n            filepath = f\"testfile_{ndim}d\"\n            saver = SaveImage(\n                output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer\n            ).set_options(init_kwargs={\"affine_lps_to_ras\": True})\n            test_data = MetaTensor(\n                p(test_data), meta={\"filename_or_obj\": f\"{filepath}{output_ext}\", \"spatial_shape\": test_data.shape}\n            )\n            saver(test_data)\n            saved_path = os.path.join(self.test_dir, filepath + 
\"_trans\" + output_ext)\n            loader = LoadImage(image_only=True, reader=reader)\n            data = loader(saved_path)\n            assert_allclose(data, torch.as_tensor(test_data))\n\n    @parameterized.expand(itertools.product([NrrdReader, ITKReader], [ITKWriter, ITKWriter]))\n    def test_2d(self, reader, writer):\n        test_data = np.random.randn(8, 8).astype(np.float32)\n        self.nrrd_rw(test_data, reader, writer, np.float32)\n\n    @parameterized.expand(itertools.product([NrrdReader, ITKReader], [ITKWriter, ITKWriter]))\n    def test_3d(self, reader, writer):\n        test_data = np.random.randn(8, 8, 8).astype(np.float32)\n        self.nrrd_rw(test_data, reader, writer, np.float32)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_init_reader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.data import ITKReader, NibabelReader, NrrdReader, NumpyReader, PILReader, PydicomReader\nfrom monai.transforms import LoadImage, LoadImaged\nfrom tests.test_utils import SkipIfNoModule\n\n\nclass TestInitLoadImage(unittest.TestCase):\n    def test_load_image(self):\n        instance1 = LoadImage(image_only=False, dtype=None)\n        instance2 = LoadImage(image_only=True, dtype=None)\n        self.assertIsInstance(instance1, LoadImage)\n        self.assertIsInstance(instance2, LoadImage)\n\n        for r in [\"NibabelReader\", \"PILReader\", \"ITKReader\", \"NumpyReader\", \"NrrdReader\", \"PydicomReader\", None]:\n            inst = LoadImaged(\"image\", reader=r)\n            self.assertIsInstance(inst, LoadImaged)\n\n    @SkipIfNoModule(\"nibabel\")\n    @SkipIfNoModule(\"cupy\")\n    @SkipIfNoModule(\"kvikio\")\n    def test_load_image_to_gpu(self):\n        for to_gpu in [True, False]:\n            instance1 = LoadImage(reader=\"NibabelReader\", to_gpu=to_gpu)\n            self.assertIsInstance(instance1, LoadImage)\n\n            instance2 = LoadImaged(\"image\", reader=\"NibabelReader\", to_gpu=to_gpu)\n            self.assertIsInstance(instance2, LoadImaged)\n\n    @SkipIfNoModule(\"itk\")\n    @SkipIfNoModule(\"nibabel\")\n    @SkipIfNoModule(\"PIL\")\n    @SkipIfNoModule(\"nrrd\")\n    
@SkipIfNoModule(\"Pydicom\")\n    def test_readers(self):\n        inst = ITKReader()\n        self.assertIsInstance(inst, ITKReader)\n\n        inst = NibabelReader()\n        self.assertIsInstance(inst, NibabelReader)\n        inst = NibabelReader(as_closest_canonical=True)\n        self.assertIsInstance(inst, NibabelReader)\n\n        inst = PydicomReader()\n        self.assertIsInstance(inst, PydicomReader)\n\n        inst = NumpyReader()\n        self.assertIsInstance(inst, NumpyReader)\n        inst = NumpyReader(npz_keys=\"test\")\n        self.assertIsInstance(inst, NumpyReader)\n\n        inst = PILReader()\n        self.assertIsInstance(inst, PILReader)\n\n        inst = NrrdReader()\n        self.assertIsInstance(inst, NrrdReader)\n\n    @SkipIfNoModule(\"nibabel\")\n    @SkipIfNoModule(\"cupy\")\n    @SkipIfNoModule(\"kvikio\")\n    def test_readers_to_gpu(self):\n        for to_gpu in [True, False]:\n            inst = NibabelReader(to_gpu=to_gpu)\n            self.assertIsInstance(inst, NibabelReader)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_is_supported_format.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.data import is_supported_format\n\nTEST_CASE_1 = [{\"filename\": \"testfile.nii.gz\", \"suffixes\": [\"nii\", \"nii.gz\"]}, True]\n\nTEST_CASE_2 = [{\"filename\": \"./testfile.nii.gz\", \"suffixes\": [\"nii\", \"nii.gz\"]}, True]\n\nTEST_CASE_3 = [{\"filename\": \"./test.data/file.nii.gz\", \"suffixes\": [\"nii\", \"nii.gz\"]}, True]\n\nTEST_CASE_4 = [{\"filename\": \"./test.data/file.nii\", \"suffixes\": [\"nii\", \"nii.gz\"]}, True]\n\nTEST_CASE_5 = [{\"filename\": \"C:\\\\documents\\\\testfile.nii.gz\", \"suffixes\": [\"nii\", \"nii.gz\"]}, True]\n\nTEST_CASE_6 = [{\"filename\": \"1.3.12.2.1107.5.4.4.145.nii.gz\", \"suffixes\": [\"nii.gz\"]}, True]\n\nTEST_CASE_7 = [{\"filename\": \"test.PNG\", \"suffixes\": [\"bmp\", \"png\"]}, True]\n\n\nclass TestIsSupportedFormat(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])\n    def test_value(self, input_param, result):\n        self.assertEqual(is_supported_format(**input_param), result)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_iterable_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport sys\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch.nn as nn\n\nfrom monai.data import DataLoader, Dataset, IterableDataset\nfrom monai.engines import SupervisedEvaluator\nfrom monai.transforms import Compose, LoadImaged, SimulateDelayd\n\n\nclass _Stream:\n\n    def __init__(self, data):\n        self.data = data\n\n    def __iter__(self):\n        return iter(self.data)\n\n\nclass TestIterableDataset(unittest.TestCase):\n\n    def test_shape(self):\n        expected_shape = (128, 128, 128)\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))\n        test_data = []\n        with tempfile.TemporaryDirectory() as tempdir:\n            for i in range(6):\n                nib.save(test_image, os.path.join(tempdir, f\"test_image{str(i)}.nii.gz\"))\n                test_data.append({\"image\": os.path.join(tempdir, f\"test_image{i}.nii.gz\")})\n\n            test_transform = Compose([LoadImaged(keys=\"image\"), SimulateDelayd(keys=\"image\", delay_time=1e-7)])\n\n            data_iterator = _Stream(test_data)\n            with self.assertRaises(TypeError):  # Dataset doesn't work\n                dataset = Dataset(data=data_iterator, transform=test_transform)\n                for _ in dataset:\n                    pass\n 
           dataset = IterableDataset(data=data_iterator, transform=test_transform)\n            for d in dataset:\n                self.assertTupleEqual(d[\"image\"].shape, expected_shape)\n\n            num_workers = 2 if sys.platform == \"linux\" else 0\n            dataloader = DataLoader(dataset=dataset, batch_size=3, num_workers=num_workers)\n            for d in dataloader:\n                self.assertTupleEqual(d[\"image\"].shape[1:], expected_shape)\n\n    def test_supervisedevaluator(self):\n        \"\"\"\n        Test that a SupervisedEvaluator is compatible with IterableDataset in conjunction with DataLoader.\n        \"\"\"\n        data = list(range(10))\n        dl = DataLoader(IterableDataset(data))\n        evaluator = SupervisedEvaluator(device=\"cpu\", val_data_loader=dl, network=nn.Identity())\n        evaluator.run()  # fails if the epoch length or other internal setup is not done correctly\n\n        self.assertEqual(evaluator.state.iteration, len(data))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_itk_torch_bridge.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport itertools\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nimport monai\nimport monai.transforms as mt\nfrom monai.apps import download_url\nfrom monai.data import ITKReader\nfrom monai.data.itk_torch_bridge import (\n    get_itk_image_center,\n    itk_image_to_metatensor,\n    itk_to_monai_affine,\n    metatensor_to_itk_image,\n    monai_to_itk_affine,\n    monai_to_itk_ddf,\n)\nfrom monai.networks.blocks import Warp\nfrom monai.transforms import Affine\nfrom monai.utils import optional_import, set_determinism\nfrom tests.test_utils import (\n    assert_allclose,\n    skip_if_downloading_fails,\n    skip_if_quick,\n    test_is_quick,\n    testing_data_config,\n)\n\nitk, has_itk = optional_import(\"itk\")\n_, has_nib = optional_import(\"nibabel\")\n\nTESTS = [\"CT_2D_head_fixed.mha\", \"CT_2D_head_moving.mha\"]\nif not test_is_quick():\n    TESTS += [\"copd1_highres_INSP_STD_COPD_img.nii.gz\", \"copd1_highres_EXP_STD_COPD_img.nii.gz\"]\n\nRW_TESTS = TESTS + [\"nrrd_example.nrrd\"]\n\n\n@unittest.skipUnless(has_itk, \"Requires `itk` package.\")\nclass TestITKTorchAffineMatrixBridge(unittest.TestCase):\n    def setUp(self):\n        set_determinism(seed=0)\n        self.data_dir = Path(__file__).parents[1] / \"testing_data\"\n    
    self.reader = ITKReader(pixel_type=itk.F)\n\n        for file_name in RW_TESTS:\n            path = os.path.join(self.data_dir, file_name)\n            if not os.path.exists(path):\n                with skip_if_downloading_fails():\n                    data_spec = testing_data_config(\"images\", f\"{file_name.split('.', 1)[0]}\")\n                    download_url(\n                        data_spec[\"url\"], path, hash_val=data_spec[\"hash_val\"], hash_type=data_spec[\"hash_type\"]\n                    )\n\n    def tearDown(self):\n        set_determinism(seed=None)\n\n    def create_itk_affine_from_parameters(\n        self, image, translation=None, rotation=None, scale=None, shear=None, center_of_rotation=None\n    ):\n        \"\"\"\n        Creates an affine transformation for an ITK image based on the provided parameters.\n\n        Args:\n            image: The ITK image.\n            translation: The translation (shift) to apply to the image.\n            rotation: The rotation to apply to the image, specified as angles in radians around the x, y, and z axes.\n            scale: The scaling factor to apply to the image.\n            shear: The shear to apply to the image.\n            center_of_rotation: The center of rotation for the image. 
If not specified,\n                                the center of the image is used.\n\n        Returns:\n            A tuple containing the affine transformation matrix and the translation vector.\n        \"\"\"\n        itk_transform = itk.AffineTransform[itk.D, image.ndim].New()\n\n        # Set center\n        if center_of_rotation:\n            itk_transform.SetCenter(center_of_rotation)\n        else:\n            itk_transform.SetCenter(get_itk_image_center(image))\n\n        # Set parameters\n        if rotation:\n            if image.ndim == 2:\n                itk_transform.Rotate2D(rotation[0])\n            else:\n                for i, angle_in_rads in enumerate(rotation):\n                    if angle_in_rads != 0:\n                        axis = [0, 0, 0]\n                        axis[i] = 1\n                        itk_transform.Rotate3D(axis, angle_in_rads)\n\n        if scale:\n            itk_transform.Scale(scale)\n\n        if shear:\n            itk_transform.Shear(*shear)\n\n        if translation:\n            itk_transform.Translate(translation)\n\n        matrix = np.asarray(itk_transform.GetMatrix(), dtype=np.float64)\n\n        return matrix, translation\n\n    def itk_affine_resample(self, image, matrix, translation, center_of_rotation=None, reference_image=None):\n        # Translation transform\n        itk_transform = itk.AffineTransform[itk.D, image.ndim].New()\n\n        # Set center\n        if center_of_rotation:\n            itk_transform.SetCenter(center_of_rotation)\n        else:\n            itk_transform.SetCenter(get_itk_image_center(image))\n\n        # Set matrix and translation\n        itk_transform.SetMatrix(itk.matrix_from_array(matrix))\n        itk_transform.Translate(translation)\n\n        # Interpolator\n        image = image.astype(itk.D)\n        interpolator = itk.LinearInterpolateImageFunction.New(image)\n\n        if not reference_image:\n            reference_image = image\n\n        # Resample with ITK\n   
     output_image = itk.resample_image_filter(\n            image, interpolator=interpolator, transform=itk_transform, output_parameters_from_image=reference_image\n        )\n\n        return np.asarray(output_image, dtype=np.float32)\n\n    def monai_affine_resample(self, metatensor, affine_matrix):\n        affine = Affine(\n            affine=affine_matrix, padding_mode=\"zeros\", mode=\"bilinear\", dtype=torch.float64, image_only=True\n        )\n        output_tensor = affine(metatensor)\n\n        return output_tensor.squeeze().permute(*torch.arange(output_tensor.ndim - 2, -1, -1)).array\n\n    def remove_border(self, image):\n        \"\"\"\n        MONAI seems to have different behavior in the borders of the image than ITK.\n        This helper function sets the border of the ITK image as 0 (padding but keeping\n        the same image size) in order to allow numerical comparison between the\n        result from resampling with ITK/Elastix and resampling with MONAI.\n        To use: image[:] = remove_border(image)\n        Args:\n            image: The ITK image to be padded.\n\n        Returns:\n            The padded array of data.\n        \"\"\"\n        return np.pad(image[1:-1, 1:-1, 1:-1] if image.ndim == 3 else image[1:-1, 1:-1], pad_width=1)\n\n    def itk_warp(self, image, ddf):\n        \"\"\"\n        Warping with python itk\n        Args:\n            image: itk image of array shape 2D: (H, W) or 3D: (D, H, W)\n            ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W)\n        Returns:\n            warped_image: numpy array of shape (H, W) or (D, H, W)\n        \"\"\"\n        # MONAI -> ITK ddf\n        displacement_field = monai_to_itk_ddf(image, ddf)\n\n        # Resample using the ddf\n        interpolator = itk.LinearInterpolateImageFunction.New(image)\n        warped_image = itk.warp_image_filter(\n            image, interpolator=interpolator, displacement_field=displacement_field, output_parameters_from_image=image\n        
)\n\n        return np.asarray(warped_image)\n\n    def monai_warp(self, image_tensor, ddf_tensor):\n        \"\"\"\n        Warping with MONAI\n        Args:\n            image_tensor: torch tensor of shape 2D: (1, 1, H, W) and 3D: (1, 1, D, H, W)\n            ddf_tensor: torch tensor of shape 2D: (1, 2, H, W) and 3D: (1, 3, D, H, W)\n        Returns:\n            warped_image: numpy array of shape (H, W) or (D, H, W)\n        \"\"\"\n        warp = Warp(mode=\"bilinear\", padding_mode=\"zeros\")\n        warped_image = warp(image_tensor.to(torch.float64), ddf_tensor.to(torch.float64))\n\n        return warped_image.to(torch.float32).squeeze().numpy()\n\n    @parameterized.expand(TESTS)\n    def test_setting_affine_parameters(self, filepath):\n        # Read image\n        image = self.reader.read(os.path.join(self.data_dir, filepath))\n        image[:] = self.remove_border(image)\n        ndim = image.ndim\n\n        # Affine parameters\n        translation = [65.2, -50.2, 33.9][:ndim]\n        rotation = [0.78539816339, 1.0, -0.66][:ndim]\n        scale = [2.0, 1.5, 3.2][:ndim]\n        shear = [0, 1, 1.6]  # axis1, axis2, coeff\n\n        # Spacing\n        spacing = np.array([1.2, 1.5, 2.0])[:ndim]\n        image.SetSpacing(spacing)\n\n        # ITK\n        matrix, translation = self.create_itk_affine_from_parameters(image, translation, rotation, scale, shear)\n        output_array_itk = self.itk_affine_resample(image, matrix=matrix, translation=translation)\n\n        # MONAI\n        metatensor = itk_image_to_metatensor(image)\n        affine_matrix_for_monai = itk_to_monai_affine(image, matrix, translation)\n        output_array_monai = self.monai_affine_resample(metatensor, affine_matrix=affine_matrix_for_monai)\n\n        # Make sure that the array conversion of the inputs is the same\n        input_array_monai = metatensor.squeeze().permute(*torch.arange(metatensor.ndim - 2, -1, -1)).array\n        np.testing.assert_array_equal(input_array_monai, 
np.asarray(image))\n\n        # Compare outputs\n        percentage = (\n            100 * np.isclose(output_array_monai, output_array_itk).sum(dtype=np.float64) / output_array_itk.size\n        )\n        self.assertGreaterEqual(percentage, 99.0)\n\n    @parameterized.expand(TESTS)\n    def test_arbitary_center_of_rotation(self, filepath):\n        # Read image\n        image = self.reader.read(os.path.join(self.data_dir, filepath))\n        image[:] = self.remove_border(image)\n        ndim = image.ndim\n\n        # ITK matrix (3x3 affine matrix)\n        matrix = np.array(\n            [\n                [0.55915995, 0.50344867, 0.43208387],\n                [0.01133669, 0.82088571, 0.86841365],\n                [0.30478496, 0.94998986, 0.32742505],\n            ]\n        )[:ndim, :ndim]\n        translation = [54.0, 2.7, -11.9][:ndim]\n\n        # Spatial properties\n        center_of_rotation = [-32.3, 125.1, 0.7][:ndim]\n        origin = [1.6, 0.5, 2.0][:ndim]\n        spacing = np.array([1.2, 1.5, 0.6])[:ndim]\n\n        image.SetSpacing(spacing)\n        image.SetOrigin(origin)\n\n        # ITK\n        output_array_itk = self.itk_affine_resample(image, matrix, translation, center_of_rotation)\n\n        # MONAI\n        metatensor = itk_image_to_metatensor(image)\n        affine_matrix_for_monai = itk_to_monai_affine(image, matrix, translation, center_of_rotation)\n        output_array_monai = self.monai_affine_resample(metatensor, affine_matrix=affine_matrix_for_monai)\n\n        # Make sure that the array conversion of the inputs is the same\n        input_array_monai = metatensor.squeeze().permute(*torch.arange(metatensor.ndim - 2, -1, -1)).array\n        np.testing.assert_array_equal(input_array_monai, np.asarray(image))\n\n        # Compare outputs\n        percentage = (\n            100 * np.isclose(output_array_monai, output_array_itk).sum(dtype=np.float64) / output_array_itk.size\n        )\n        self.assertGreaterEqual(percentage, 99.0)\n\n   
 @parameterized.expand(TESTS)\n    def test_monai_to_itk(self, filepath):\n        # Read image\n        image = self.reader.read(os.path.join(self.data_dir, filepath))\n        image[:] = self.remove_border(image)\n        ndim = image.ndim\n\n        # MONAI affine matrix\n        affine_matrix = torch.eye(ndim + 1, dtype=torch.float64)\n        affine_matrix[:ndim, :ndim] = torch.tensor(\n            [\n                [0.55915995, 0.50344867, 0.43208387],\n                [0.01133669, 0.82088571, 0.86841365],\n                [0.30478496, 0.94998986, 0.32742505],\n            ],\n            dtype=torch.float64,\n        )[:ndim, :ndim]\n\n        affine_matrix[:ndim, ndim] = torch.tensor([54.0, 2.7, -11.9], dtype=torch.float64)[:ndim]\n\n        # Spatial properties\n        center_of_rotation = [-32.3, 125.1, 0.7][:ndim]\n        origin = [1.6, 0.5, 2.0][:ndim]\n        spacing = np.array([1.2, 1.5, 0.6])[:ndim]\n\n        image.SetSpacing(spacing)\n        image.SetOrigin(origin)\n\n        # ITK\n        matrix, translation = monai_to_itk_affine(image, affine_matrix, center_of_rotation)\n        output_array_itk = self.itk_affine_resample(image, matrix, translation, center_of_rotation)\n\n        # MONAI\n        metatensor = itk_image_to_metatensor(image)\n        output_array_monai = self.monai_affine_resample(metatensor, affine_matrix)\n\n        # Make sure that the array conversion of the inputs is the same\n        input_array_monai = metatensor.squeeze().permute(*torch.arange(metatensor.ndim - 2, -1, -1)).array\n        np.testing.assert_array_equal(input_array_monai, np.asarray(image))\n\n        # Compare outputs\n        percentage = (\n            100 * np.isclose(output_array_monai, output_array_itk).sum(dtype=np.float64) / output_array_itk.size\n        )\n        self.assertGreaterEqual(percentage, 99.0)\n\n    @parameterized.expand(TESTS)\n    def test_cyclic_conversion(self, filepath):\n        image = 
self.reader.read(os.path.join(self.data_dir, filepath))\n        image[:] = self.remove_border(image)\n        ndim = image.ndim\n\n        # ITK matrix (3x3 affine matrix)\n        matrix = np.array(\n            [\n                [2.90971094, 1.18297296, 2.60008784],\n                [0.29416137, 0.10294283, 2.82302616],\n                [1.70578374, 1.39706003, 2.54652029],\n            ]\n        )[:ndim, :ndim]\n\n        translation = [-29.05463245, 35.27116398, 48.58759597][:ndim]\n\n        # Spatial properties\n        center_of_rotation = [-27.84789587, -60.7871084, 42.73501932][:ndim]\n        origin = [8.10416794, 5.4831944, 0.49211025][:ndim]\n        spacing = np.array([0.7, 3.2, 1.3])[:ndim]\n\n        direction = np.array(\n            [\n                [1.02895588, 0.22791448, 0.02429561],\n                [0.21927512, 1.28632268, -0.14932226],\n                [0.47455613, 0.38534345, 0.98505633],\n            ],\n            dtype=np.float64,\n        )\n        image.SetDirection(direction[:ndim, :ndim])\n\n        image.SetSpacing(spacing)\n        image.SetOrigin(origin)\n\n        affine_matrix = itk_to_monai_affine(image, matrix, translation, center_of_rotation)\n        matrix_result, translation_result = monai_to_itk_affine(image, affine_matrix, center_of_rotation)\n\n        meta_tensor = itk_image_to_metatensor(image)\n        image_result = metatensor_to_itk_image(meta_tensor)\n\n        np.testing.assert_allclose(matrix, matrix_result)\n        np.testing.assert_allclose(translation, translation_result)\n        np.testing.assert_array_equal(image.shape, image_result.shape)\n        np.testing.assert_array_equal(image, image_result)\n\n    @parameterized.expand([(2,), (3,)])\n    def test_random_array(self, ndim):\n        # Create image/array with random size and pixel intensities\n        s = torch.randint(low=2, high=20, size=(ndim,))\n        img = 100 * torch.rand((1, 1, *s.tolist()), dtype=torch.float32)\n\n        # Pad at the 
edges because ITK and MONAI have different behavior there\n        # during resampling\n        img = torch.nn.functional.pad(img, pad=ndim * (1, 1))\n        ddf = 5 * torch.rand((1, ndim, *img.shape[-ndim:]), dtype=torch.float32) - 2.5\n\n        # Warp with MONAI\n        img_resampled = self.monai_warp(img, ddf)\n\n        # Create ITK image\n        itk_img = itk.GetImageFromArray(img.squeeze().numpy())\n\n        # Set random spacing\n        spacing = 3 * np.random.rand(ndim)\n        itk_img.SetSpacing(spacing)\n\n        # Set random direction\n        direction = 5 * np.random.rand(ndim, ndim) - 5\n        direction = itk.matrix_from_array(direction)\n        itk_img.SetDirection(direction)\n\n        # Set random origin\n        origin = 100 * np.random.rand(ndim) - 100\n        itk_img.SetOrigin(origin)\n\n        # Warp with ITK\n        itk_img_resampled = self.itk_warp(itk_img, ddf.squeeze().numpy())\n\n        # Compare\n        np.testing.assert_allclose(img_resampled, itk_img_resampled, rtol=1e-2, atol=1e-2)\n\n    @parameterized.expand(TESTS)\n    @skip_if_quick\n    def test_real_data(self, filepath):\n        # Read image\n        image = self.reader.read(os.path.join(self.data_dir, filepath))\n        image[:] = self.remove_border(image)\n        ndim = image.ndim\n\n        # Random ddf\n        ddf = 10 * torch.rand((1, ndim, *image.shape), dtype=torch.float32) - 10\n\n        # Warp with MONAI\n        image_tensor = torch.tensor(itk.GetArrayFromImage(image), dtype=torch.float32).unsqueeze(0).unsqueeze(0)\n        img_resampled = self.monai_warp(image_tensor, ddf)\n\n        # Warp with ITK\n        itk_img_resampled = self.itk_warp(image, ddf.squeeze().numpy())\n\n        # Compare\n        np.testing.assert_allclose(img_resampled, itk_img_resampled, rtol=1e-3, atol=1e-3)\n\n    @parameterized.expand(zip(TESTS[::2], TESTS[1::2]))\n    @skip_if_quick\n    def test_use_reference_space(self, ref_filepath, filepath):\n        # Read the 
images\n        image = self.reader.read(os.path.join(self.data_dir, filepath))\n        image[:] = self.remove_border(image)\n        ndim = image.ndim\n\n        ref_image = self.reader.read(os.path.join(self.data_dir, ref_filepath))\n\n        # Set arbitary origin, spacing, direction for both of the images\n        image.SetSpacing([1.2, 2.0, 1.7][:ndim])\n        ref_image.SetSpacing([1.9, 1.5, 1.3][:ndim])\n\n        direction = np.array(\n            [\n                [1.02895588, 0.22791448, 0.02429561],\n                [0.21927512, 1.28632268, -0.14932226],\n                [0.47455613, 0.38534345, 0.98505633],\n            ],\n            dtype=np.float64,\n        )\n        image.SetDirection(direction[:ndim, :ndim])\n\n        ref_direction = np.array(\n            [\n                [1.26032417, -0.19243174, 0.54877414],\n                [0.31958275, 0.9543068, 0.2720827],\n                [-0.24106769, -0.22344502, 0.9143302],\n            ],\n            dtype=np.float64,\n        )\n        ref_image.SetDirection(ref_direction[:ndim, :ndim])\n\n        image.SetOrigin([57.3, 102.0, -20.9][:ndim])\n        ref_image.SetOrigin([23.3, -0.5, 23.7][:ndim])\n\n        # Set affine parameters\n        matrix = np.array(\n            [\n                [0.55915995, 0.50344867, 0.43208387],\n                [0.01133669, 0.82088571, 0.86841365],\n                [0.30478496, 0.94998986, 0.32742505],\n            ]\n        )[:ndim, :ndim]\n        translation = [54.0, 2.7, -11.9][:ndim]\n        center_of_rotation = [-32.3, 125.1, 0.7][:ndim]\n\n        # Resample using ITK\n        output_array_itk = self.itk_affine_resample(image, matrix, translation, center_of_rotation, ref_image)\n\n        # MONAI\n        metatensor = itk_image_to_metatensor(image)\n        affine_matrix_for_monai = itk_to_monai_affine(image, matrix, translation, center_of_rotation, ref_image)\n        output_array_monai = self.monai_affine_resample(metatensor, 
affine_matrix_for_monai)\n\n        # Compare outputs\n        np.testing.assert_allclose(output_array_monai, output_array_itk, rtol=1e-3, atol=1e-3)\n\n\n@unittest.skipUnless(has_itk, \"Requires `itk` package.\")\n@unittest.skipUnless(has_nib, \"Requires `nibabel` package.\")\n@skip_if_quick\nclass TestITKTorchRW(unittest.TestCase):\n    def setUp(self):\n        TestITKTorchAffineMatrixBridge.setUp(self)\n\n    def tearDown(self):\n        TestITKTorchAffineMatrixBridge.setUp(self)\n\n    @parameterized.expand(list(itertools.product(RW_TESTS, [\"ITKReader\", \"NrrdReader\"], [True, False])))\n    def test_rw_itk(self, filepath, reader, flip):\n        \"\"\"reading and convert: filepath, reader, flip\"\"\"\n        print(filepath, reader, flip)\n        fname = os.path.join(self.data_dir, filepath)\n        xform = mt.LoadImageD(\"img\", image_only=True, ensure_channel_first=True, affine_lps_to_ras=flip, reader=reader)\n        out = xform({\"img\": fname})[\"img\"]\n        itk_image = metatensor_to_itk_image(out, channel_dim=0, dtype=float)\n        with tempfile.TemporaryDirectory() as tempdir:\n            tname = os.path.join(tempdir, filepath) + (\".nii.gz\" if not filepath.endswith(\".nii.gz\") else \"\")\n            itk.imwrite(itk_image, tname, True)\n            ref = mt.LoadImage(image_only=True, ensure_channel_first=True, reader=\"NibabelReader\")(tname)\n        if out.meta[\"space\"] != ref.meta[\"space\"]:\n            ref.affine = monai.data.utils.orientation_ras_lps(ref.affine)\n        assert_allclose(\n            out.affine, monai.data.utils.to_affine_nd(len(out.affine) - 1, ref.affine), rtol=1e-3, atol=1e-3\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_itk_writer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.data import ITKWriter\nfrom monai.utils import optional_import\n\nitk, has_itk = optional_import(\"itk\")\nnib, has_nibabel = optional_import(\"nibabel\")\n\n\n@unittest.skipUnless(has_itk, \"Requires `itk` package.\")\nclass TestITKWriter(unittest.TestCase):\n\n    def test_channel_shape(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            for c in (0, 1, 2, 3):\n                fname = os.path.join(tempdir, f\"testing{c}.nii\")\n                itk_writer = ITKWriter()\n                itk_writer.set_data_array(torch.zeros(1, 2, 3, 4), channel_dim=c, squeeze_end_dims=False)\n                itk_writer.set_metadata({})\n                itk_writer.write(fname)\n                itk_obj = itk.imread(fname)\n                s = [1, 2, 3, 4]\n                s.pop(c)\n                np.testing.assert_allclose(itk.size(itk_obj), s)\n\n    def test_rgb(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            fname = os.path.join(tempdir, \"testing.png\")\n            writer = ITKWriter(output_dtype=np.uint8)\n            writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=0)\n            writer.set_metadata({\"spatial_shape\": (5, 5)})\n            writer.write(fname)\n\n            output = 
np.asarray(itk.imread(fname))\n            np.testing.assert_allclose(output.shape, (5, 5, 3))\n            np.testing.assert_allclose(output[1, 1], (5, 5, 4))\n\n    def test_no_channel(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            fname = os.path.join(tempdir, \"testing.nii.gz\")\n            writer = ITKWriter(output_dtype=np.uint8)\n            writer.set_data_array(np.arange(48).reshape(3, 4, 4), channel_dim=None)\n            writer.write(fname)\n\n            output = np.asarray(itk.imread(fname))\n            np.testing.assert_allclose(output.shape, (4, 4, 3))\n            np.testing.assert_allclose(output[1, 1], (5, 21, 37))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_list_data_collate.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, list_data_collate\n\na = {\"image\": np.array([1, 2, 3]), \"label\": MetaTensor([4, 5, 6])}\nb = {\"image\": np.array([7, 8, 9]), \"label\": MetaTensor([10, 11, 12])}\nc = {\"image\": np.array([13, 14, 15]), \"label\": MetaTensor([16, 7, 18])}\nd = {\"image\": np.array([19, 20, 21]), \"label\": MetaTensor([22, 23, 24])}\nTEST_CASE_1 = [[[a, b], [c, d]], dict, torch.Size([4, 3])]  # dataset returns a list of dictionary data\n\ne = (np.array([1, 2, 3]), MetaTensor([4, 5, 6]))\nf = (np.array([7, 8, 9]), MetaTensor([10, 11, 12]))\ng = (np.array([13, 14, 15]), MetaTensor([16, 7, 18]))\nh = (np.array([19, 20, 21]), MetaTensor([22, 23, 24]))\nTEST_CASE_2 = [[[e, f], [g, h]], list, torch.Size([4, 3])]  # dataset returns a list of tuple data\n\ng_m = (np.array([13, 14, 15]), MetaTensor([16, 7, 18], meta={\"key1\": 0}))\nh_m = (np.array([19, 20, 21]), MetaTensor([22, 23, 24], meta={\"key2\": 1}))\nTEST_CASE_3 = [[[g_m], [h_m]], list, torch.Size([2, 3])]\n\n\nclass TestListDataCollate(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_type_shape(self, input_data, expected_type, expected_shape):\n        result = list_data_collate(input_data)\n        
self.assertIsInstance(result, expected_type)\n        if isinstance(result, dict):\n            image = result[\"image\"]\n            label = result[\"label\"]\n        else:\n            image = result[0]\n            label = result[1]\n        self.assertEqual(image.shape, expected_shape)\n        self.assertEqual(label.shape, expected_shape)\n        self.assertTrue(isinstance(label, MetaTensor))\n        self.assertTrue(label.is_batch, True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_lmdbdataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import LMDBDataset, json_hashing\nfrom monai.transforms import Compose, LoadImaged, SimulateDelayd, Transform\nfrom tests.test_utils import skip_if_windows\n\nTEST_CASE_1 = [\n    Compose(\n        [\n            LoadImaged(keys=[\"image\", \"label\", \"extra\"]),\n            SimulateDelayd(keys=[\"image\", \"label\", \"extra\"], delay_time=[1e-7, 1e-6, 1e-5]),\n        ]\n    ),\n    (128, 128, 128),\n]\n\nTEST_CASE_2 = [\n    [\n        LoadImaged(keys=[\"image\", \"label\", \"extra\"]),\n        SimulateDelayd(keys=[\"image\", \"label\", \"extra\"], delay_time=[1e-7, 1e-6, 1e-5]),\n    ],\n    (128, 128, 128),\n]\n\nTEST_CASE_3 = [None, (128, 128, 128), None]\n\nTEST_CASE_4 = [\n    [\n        LoadImaged(keys=[\"image\", \"label\", \"extra\"]),\n        SimulateDelayd(keys=[\"image\", \"label\", \"extra\"], delay_time=[1e-7, 1e-6, 1e-5]),\n    ],\n    (128, 128, 128),\n    {\"db_name\": \"test42\"},\n]\n\nTEST_CASE_5 = [\n    [\n        LoadImaged(keys=[\"image\", \"label\", \"extra\"]),\n        SimulateDelayd(keys=[\"image\", \"label\", \"extra\"], delay_time=[1e-7, 1e-6, 1e-5]),\n    ],\n    (128, 128, 128),\n    {\"pickle_protocol\": 2, \"lmdb_kwargs\": {\"map_size\": 100 * 
1024**2}},\n]\n\nTEST_CASE_6 = [\n    [\n        LoadImaged(keys=[\"image\", \"label\", \"extra\"]),\n        SimulateDelayd(keys=[\"image\", \"label\", \"extra\"], delay_time=[1e-7, 1e-6, 1e-5]),\n    ],\n    (128, 128, 128),\n    {\"db_name\": \"testdb\", \"lmdb_kwargs\": {\"map_size\": 100 * 1024**2}},\n]\n\nTEST_CASE_7 = [\n    [\n        LoadImaged(keys=[\"image\", \"label\", \"extra\"]),\n        SimulateDelayd(keys=[\"image\", \"label\", \"extra\"], delay_time=[1e-7, 1e-6, 1e-5]),\n    ],\n    (128, 128, 128),\n    {\"db_name\": \"testdb\", \"lmdb_kwargs\": {\"map_size\": 2 * 1024**2}},\n]\n\n\nclass _InplaceXform(Transform):\n    def __call__(self, data):\n        if data:\n            data[0] = data[0] + np.pi\n        else:\n            data.append(1)\n        return data\n\n\n@skip_if_windows\nclass TestLMDBDataset(unittest.TestCase):\n    def test_cache(self):\n        \"\"\"testing no inplace change to the hashed item\"\"\"\n        items = [[list(range(i))] for i in range(5)]\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            ds = LMDBDataset(items, transform=_InplaceXform(), cache_dir=tempdir, lmdb_kwargs={\"map_size\": 10 * 1024})\n            self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n            ds1 = LMDBDataset(items, transform=_InplaceXform(), cache_dir=tempdir, lmdb_kwargs={\"map_size\": 10 * 1024})\n            self.assertEqual(list(ds1), list(ds))\n            self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n\n            ds = LMDBDataset(\n                items,\n                transform=_InplaceXform(),\n                cache_dir=tempdir,\n                lmdb_kwargs={\"map_size\": 10 * 1024},\n                hash_func=json_hashing,\n            )\n            self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n            ds1 = LMDBDataset(\n                items,\n                transform=_InplaceXform(),\n                
cache_dir=tempdir,\n                lmdb_kwargs={\"map_size\": 10 * 1024},\n                hash_func=json_hashing,\n            )\n            self.assertEqual(list(ds1), list(ds))\n            self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n\n        self.assertTrue(isinstance(ds1.info(), dict))\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])\n    def test_shape(self, transform, expected_shape, kwargs=None):\n        kwargs = kwargs or {}\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            nib.save(test_image, os.path.join(tempdir, \"test_image1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_label1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_extra1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_image2.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_label2.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_extra2.nii.gz\"))\n            test_data = [\n                {\n                    \"image\": os.path.join(tempdir, \"test_image1.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label1.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra1.nii.gz\"),\n                },\n                {\n                    \"image\": os.path.join(tempdir, \"test_image2.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label2.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra2.nii.gz\"),\n                },\n            ]\n\n            cache_dir = os.path.join(os.path.join(tempdir, \"cache\"), \"data\")\n            dataset_precached = LMDBDataset(\n                data=test_data, transform=transform, progress=False, 
cache_dir=cache_dir, **kwargs\n            )\n            data1_precached = dataset_precached[0]\n            data2_precached = dataset_precached[1]\n\n            dataset_postcached = LMDBDataset(\n                data=test_data, transform=transform, progress=False, cache_dir=cache_dir, **kwargs\n            )\n            data1_postcached = dataset_postcached[0]\n            data2_postcached = dataset_postcached[1]\n\n            if transform is None:\n                self.assertEqual(data1_precached[\"image\"], os.path.join(tempdir, \"test_image1.nii.gz\"))\n                self.assertEqual(data2_precached[\"label\"], os.path.join(tempdir, \"test_label2.nii.gz\"))\n                self.assertEqual(data1_postcached[\"image\"], os.path.join(tempdir, \"test_image1.nii.gz\"))\n                self.assertEqual(data2_postcached[\"extra\"], os.path.join(tempdir, \"test_extra2.nii.gz\"))\n            else:\n                self.assertTupleEqual(data1_precached[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data1_precached[\"label\"].shape, expected_shape)\n                self.assertTupleEqual(data1_precached[\"extra\"].shape, expected_shape)\n                self.assertTupleEqual(data2_precached[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data2_precached[\"label\"].shape, expected_shape)\n                self.assertTupleEqual(data2_precached[\"extra\"].shape, expected_shape)\n\n                self.assertTupleEqual(data1_postcached[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data1_postcached[\"label\"].shape, expected_shape)\n                self.assertTupleEqual(data1_postcached[\"extra\"].shape, expected_shape)\n                self.assertTupleEqual(data2_postcached[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data2_postcached[\"label\"].shape, expected_shape)\n                self.assertTupleEqual(data2_postcached[\"extra\"].shape, expected_shape)\n\n      
      # update the data to cache\n            test_data_new = [\n                {\n                    \"image\": os.path.join(tempdir, \"test_image1_new.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label1_new.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra1_new.nii.gz\"),\n                },\n                {\n                    \"image\": os.path.join(tempdir, \"test_image2_new.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label2_new.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra2_new.nii.gz\"),\n                },\n            ]\n            # test new exchanged cache content\n            if transform is None:\n                dataset_postcached.set_data(data=test_data_new)\n                self.assertEqual(dataset_postcached[0][\"image\"], os.path.join(tempdir, \"test_image1_new.nii.gz\"))\n                self.assertEqual(dataset_postcached[0][\"label\"], os.path.join(tempdir, \"test_label1_new.nii.gz\"))\n                self.assertEqual(dataset_postcached[1][\"extra\"], os.path.join(tempdir, \"test_extra2_new.nii.gz\"))\n            else:\n                with self.assertRaises(RuntimeError):\n                    dataset_postcached.set_data(data=test_data_new)  # filename list updated, files do not exist\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_lmdbdataset_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport shutil\nimport tempfile\nimport unittest\n\nimport numpy as np\n\nfrom monai.data import LMDBDataset, json_hashing\nfrom monai.transforms import Transform\nfrom tests.test_utils import DistCall, DistTestCase, skip_if_windows\n\n\nclass _InplaceXform(Transform):\n    def __call__(self, data):\n        if data:\n            data[0] = data[0] + np.pi\n        else:\n            data.append(1)\n        return data\n\n\n@skip_if_windows\nclass TestMPLMDBDataset(DistTestCase):\n    def setUp(self):\n        self.tempdir = tempfile.mkdtemp()\n\n    def tearDown(self):\n        shutil.rmtree(self.tempdir)\n\n    @DistCall(nnodes=1, nproc_per_node=1)\n    def test_mp_cache(self):\n        items = [[list(range(i))] for i in range(5)]\n\n        ds = LMDBDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir, lmdb_kwargs={\"map_size\": 10 * 1024})\n        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n        ds1 = LMDBDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir, lmdb_kwargs={\"map_size\": 10 * 1024})\n        self.assertEqual(list(ds1), list(ds))\n        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n\n        ds = LMDBDataset(\n            items,\n            transform=_InplaceXform(),\n            cache_dir=self.tempdir,\n            
lmdb_kwargs={\"map_size\": 10 * 1024},\n            hash_func=json_hashing,\n        )\n        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n        ds1 = LMDBDataset(\n            items,\n            transform=_InplaceXform(),\n            cache_dir=self.tempdir,\n            lmdb_kwargs={\"map_size\": 10 * 1024},\n            hash_func=json_hashing,\n        )\n        self.assertEqual(list(ds1), list(ds))\n        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n\n        self.assertTrue(isinstance(ds1.info(), dict))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_load_decathlon_datalist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nfrom monai.data import load_decathlon_datalist\n\n\nclass TestLoadDecathlonDatalist(unittest.TestCase):\n\n    def test_seg_values(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            test_data = {\n                \"name\": \"Spleen\",\n                \"description\": \"Spleen Segmentation\",\n                \"labels\": {\"0\": \"background\", \"1\": \"spleen\"},\n                \"training\": [\n                    {\"image\": \"spleen_19.nii.gz\", \"label\": \"spleen_19.nii.gz\"},\n                    {\"image\": \"spleen_31.nii.gz\", \"label\": \"spleen_31.nii.gz\"},\n                ],\n                \"test\": [{\"image\": \"spleen_15.nii.gz\"}, {\"image\": \"spleen_23.nii.gz\"}],\n            }\n            json_str = json.dumps(test_data)\n            file_path = os.path.join(tempdir, \"test_data.json\")\n            with open(file_path, \"w\") as json_file:\n                json_file.write(json_str)\n            result = load_decathlon_datalist(file_path, True, \"training\", tempdir)\n            self.assertEqual(result[0][\"image\"], os.path.join(tempdir, \"spleen_19.nii.gz\"))\n            self.assertEqual(result[0][\"label\"], os.path.join(tempdir, \"spleen_19.nii.gz\"))\n            result = 
load_decathlon_datalist(file_path, True, \"test\", None)\n            self.assertEqual(result[0][\"image\"], os.path.join(tempdir, \"spleen_15.nii.gz\"))\n\n    def test_cls_values(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            test_data = {\n                \"name\": \"ChestXRay\",\n                \"description\": \"Chest X-ray classification\",\n                \"labels\": {\"0\": \"background\", \"1\": \"chest\"},\n                \"training\": [{\"image\": \"chest_19.nii.gz\", \"label\": 0}, {\"image\": \"chest_31.nii.gz\", \"label\": 1}],\n                \"test\": [\"chest_15.nii.gz\", \"chest_23.nii.gz\"],\n            }\n            json_str = json.dumps(test_data)\n            file_path = os.path.join(tempdir, \"test_data.json\")\n            with open(file_path, \"w\") as json_file:\n                json_file.write(json_str)\n            result = load_decathlon_datalist(file_path, False, \"training\", tempdir)\n            self.assertEqual(result[0][\"image\"], os.path.join(tempdir, \"chest_19.nii.gz\"))\n            self.assertEqual(result[0][\"label\"], 0)\n\n    def test_seg_no_basedir(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            test_data = {\n                \"name\": \"Spleen\",\n                \"description\": \"Spleen Segmentation\",\n                \"labels\": {\"0\": \"background\", \"1\": \"spleen\"},\n                \"training\": [\n                    {\n                        \"image\": os.path.join(tempdir, \"spleen_19.nii.gz\"),\n                        \"label\": os.path.join(tempdir, \"spleen_19.nii.gz\"),\n                    },\n                    {\n                        \"image\": os.path.join(tempdir, \"spleen_31.nii.gz\"),\n                        \"label\": os.path.join(tempdir, \"spleen_31.nii.gz\"),\n                    },\n                ],\n                \"test\": [os.path.join(tempdir, \"spleen_15.nii.gz\"), os.path.join(tempdir, \"spleen_23.nii.gz\")],\n  
          }\n            json_str = json.dumps(test_data)\n            file_path = os.path.join(tempdir, \"test_data.json\")\n            with open(file_path, \"w\") as json_file:\n                json_file.write(json_str)\n            result = load_decathlon_datalist(file_path, True, \"training\", None)\n            self.assertEqual(result[0][\"image\"], os.path.join(tempdir, \"spleen_19.nii.gz\"))\n            self.assertEqual(result[0][\"label\"], os.path.join(tempdir, \"spleen_19.nii.gz\"))\n            result = load_decathlon_datalist(file_path, True, \"test\", None)\n            self.assertEqual(result[0][\"image\"], os.path.join(tempdir, \"spleen_15.nii.gz\"))\n\n    def test_seg_no_labels(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            test_data = {\n                \"name\": \"Spleen\",\n                \"description\": \"Spleen Segmentation\",\n                \"labels\": {\"0\": \"background\", \"1\": \"spleen\"},\n                \"test\": [\"spleen_15.nii.gz\", \"spleen_23.nii.gz\"],\n            }\n            json_str = json.dumps(test_data)\n            file_path = os.path.join(tempdir, \"test_data.json\")\n            with open(file_path, \"w\") as json_file:\n                json_file.write(json_str)\n            result = load_decathlon_datalist(file_path, True, \"test\", tempdir)\n            self.assertEqual(result[0][\"image\"], os.path.join(tempdir, \"spleen_15.nii.gz\"))\n\n    def test_additional_items(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            with open(os.path.join(tempdir, \"mask31.txt\"), \"w\") as f:\n                f.write(\"spleen31 mask\")\n\n            test_data = {\n                \"name\": \"Spleen\",\n                \"description\": \"Spleen Segmentation\",\n                \"labels\": {\"0\": \"background\", \"1\": \"spleen\"},\n                \"training\": [\n                    {\"image\": \"spleen_19.nii.gz\", \"label\": \"spleen_19.nii.gz\", \"mask\": \"spleen 
mask\"},\n                    {\"image\": \"spleen_31.nii.gz\", \"label\": \"spleen_31.nii.gz\", \"mask\": \"mask31.txt\"},\n                ],\n                \"test\": [\"spleen_15.nii.gz\", \"spleen_23.nii.gz\"],\n            }\n            json_str = json.dumps(test_data)\n            file_path = os.path.join(tempdir, \"test_data.json\")\n            with open(file_path, \"w\") as json_file:\n                json_file.write(json_str)\n            result = load_decathlon_datalist(file_path, True, \"training\", Path(tempdir))\n            self.assertEqual(result[0][\"image\"], os.path.join(tempdir, \"spleen_19.nii.gz\"))\n            self.assertEqual(result[0][\"label\"], os.path.join(tempdir, \"spleen_19.nii.gz\"))\n            self.assertEqual(result[1][\"mask\"], os.path.join(tempdir, \"mask31.txt\"))\n            self.assertEqual(result[0][\"mask\"], \"spleen mask\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_make_nifti.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.synthetic import create_test_image_2d\nfrom monai.utils import optional_import\nfrom tests.test_utils import make_nifti_image\n\n_, has_nib = optional_import(\"nibabel\")\n\nTESTS = []\nfor affine in (None, np.eye(4), torch.eye(4)):\n    for dir in (None, tempfile.mkdtemp()):\n        for fname in (None, \"fname\"):\n            TESTS.append([{\"affine\": affine, \"dir\": dir, \"fname\": fname}])\n\n\n@unittest.skipUnless(has_nib, \"Requires nibabel\")\nclass TestMakeNifti(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_make_nifti(self, params):\n        im, _ = create_test_image_2d(100, 88)\n        created_file = make_nifti_image(im, verbose=True, **params)\n        self.assertTrue(os.path.isfile(created_file))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_mapping_file.py",
    "content": "# Copyright (c) MONAI Consortium\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nfrom __future__ import annotations\r\n\r\nimport json\r\nimport os\r\nimport shutil\r\nimport tempfile\r\nimport unittest\r\n\r\nimport numpy as np\r\nfrom parameterized import parameterized\r\n\r\nfrom monai.data import DataLoader, Dataset\r\nfrom monai.transforms import Compose, LoadImage, SaveImage, WriteFileMapping\r\nfrom monai.utils import optional_import\r\n\r\nnib, has_nib = optional_import(\"nibabel\")\r\n\r\n\r\ndef create_input_file(temp_dir, name):\r\n    test_image = np.random.rand(128, 128, 128)\r\n    output_ext = \".nii.gz\"\r\n    input_file = os.path.join(temp_dir, name + output_ext)\r\n    nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file)\r\n    return input_file\r\n\r\n\r\ndef create_transform(temp_dir, mapping_file_path, savepath_in_metadict=True):\r\n    return Compose(\r\n        [\r\n            LoadImage(image_only=True),\r\n            SaveImage(output_dir=temp_dir, output_ext=\".nii.gz\", savepath_in_metadict=savepath_in_metadict),\r\n            WriteFileMapping(mapping_file_path=mapping_file_path),\r\n        ]\r\n    )\r\n\r\n\r\n@unittest.skipUnless(has_nib, \"nibabel required\")\r\nclass TestWriteFileMapping(unittest.TestCase):\r\n    def setUp(self):\r\n        self.temp_dir = tempfile.mkdtemp()\r\n\r\n    def tearDown(self):\r\n        shutil.rmtree(self.temp_dir)\r\n\r\n    @parameterized.expand([(True,), 
(False,)])\r\n    def test_mapping_file(self, savepath_in_metadict):\r\n        mapping_file_path = os.path.join(self.temp_dir, \"mapping.json\")\r\n        name = \"test_image\"\r\n        input_file = create_input_file(self.temp_dir, name)\r\n        output_file = os.path.join(self.temp_dir, name, name + \"_trans.nii.gz\")\r\n\r\n        transform = create_transform(self.temp_dir, mapping_file_path, savepath_in_metadict)\r\n\r\n        if savepath_in_metadict:\r\n            transform(input_file)\r\n            self.assertTrue(os.path.exists(mapping_file_path))\r\n            with open(mapping_file_path) as f:\r\n                mapping_data = json.load(f)\r\n            self.assertEqual(len(mapping_data), 1)\r\n            self.assertEqual(mapping_data[0][\"input\"], input_file)\r\n            self.assertEqual(mapping_data[0][\"output\"], output_file)\r\n        else:\r\n            with self.assertRaises(RuntimeError) as cm:\r\n                transform(input_file)\r\n            cause_exception = cm.exception.__cause__\r\n            self.assertIsInstance(cause_exception, KeyError)\r\n            self.assertIn(\r\n                \"Missing 'saved_to' key in metadata. 
Check SaveImage argument 'savepath_in_metadict' is True.\",\r\n                str(cause_exception),\r\n            )\r\n\r\n    def test_multiprocess_mapping_file(self):\r\n        num_images = 50\r\n\r\n        single_mapping_file = os.path.join(self.temp_dir, \"single_mapping.json\")\r\n        multi_mapping_file = os.path.join(self.temp_dir, \"multi_mapping.json\")\r\n\r\n        data = [create_input_file(self.temp_dir, f\"test_image_{i}\") for i in range(num_images)]\r\n\r\n        # single process\r\n        single_transform = create_transform(self.temp_dir, single_mapping_file)\r\n        single_dataset = Dataset(data=data, transform=single_transform)\r\n        single_loader = DataLoader(single_dataset, batch_size=1, num_workers=0, shuffle=True)\r\n        for _ in single_loader:\r\n            pass\r\n\r\n        # multiple processes\r\n        multi_transform = create_transform(self.temp_dir, multi_mapping_file)\r\n        multi_dataset = Dataset(data=data, transform=multi_transform)\r\n        multi_loader = DataLoader(multi_dataset, batch_size=4, num_workers=3, shuffle=True)\r\n        for _ in multi_loader:\r\n            pass\r\n\r\n        with open(single_mapping_file) as f:\r\n            single_mapping_data = json.load(f)\r\n        with open(multi_mapping_file) as f:\r\n            multi_mapping_data = json.load(f)\r\n\r\n        single_set = {(entry[\"input\"], entry[\"output\"]) for entry in single_mapping_data}\r\n        multi_set = {(entry[\"input\"], entry[\"output\"]) for entry in multi_mapping_data}\r\n\r\n        self.assertEqual(single_set, multi_set)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    unittest.main()\r\n"
  },
  {
    "path": "tests/data/test_masked_patch_wsi_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\nfrom unittest import skipUnless\n\nfrom numpy.testing import assert_array_equal\nfrom parameterized import parameterized\n\nfrom monai.data import Dataset, MaskedPatchWSIDataset\nfrom monai.transforms import Lambdad\nfrom monai.utils import ProbMapKeys, WSIPatchKeys, optional_import, set_determinism\nfrom tests.test_utils import download_url_or_skip_test, testing_data_config\n\nset_determinism(0)\n\ncucim, has_cucim = optional_import(\"cucim\")\nhas_cucim = has_cucim and hasattr(cucim, \"CuImage\")\n_, has_osl = optional_import(\"openslide\")\n_, has_tiff = optional_import(\"tifffile\", name=\"imwrite\")\n_, has_codec = optional_import(\"imagecodecs\")\nhas_tiff = has_tiff and has_codec\n\nFILE_KEY = \"wsi_generic_tiff\"\nFILE_URL = testing_data_config(\"images\", FILE_KEY, \"url\")\nTESTS_PATH = Path(__file__).parents[1]\nFILE_PATH = os.path.join(TESTS_PATH, \"testing_data\", f\"temp_{FILE_KEY}.tiff\")\n\nTEST_CASE_0 = [\n    {\"data\": [{\"image\": FILE_PATH, WSIPatchKeys.LEVEL: 8, WSIPatchKeys.SIZE: (2, 2)}], \"mask_level\": 8},\n    {\n        \"num_patches\": 4256,\n        \"wsi_size\": [32914, 46000],\n        \"mask_level\": 8,\n        \"patch_level\": 8,\n        \"mask_size\": (128, 179),\n        \"patch_size\": (2, 2),\n    },\n]\n\nTEST_CASE_1 = [\n    {\n        
\"data\": Dataset([{\"image\": FILE_PATH}], transform=Lambdad(keys=\"image\", func=lambda x: x[:])),\n        \"mask_level\": 8,\n        \"patch_level\": 8,\n        \"patch_size\": (2, 2),\n    },\n    {\n        \"num_patches\": 4256,\n        \"wsi_size\": [32914, 46000],\n        \"mask_level\": 8,\n        \"patch_level\": 8,\n        \"mask_size\": (128, 179),\n        \"patch_size\": (2, 2),\n    },\n]\n\n\n@skipUnless(has_cucim or has_osl or has_tiff, \"Requires cucim, openslide, or tifffile!\")\ndef setUpModule():\n    hash_type = testing_data_config(\"images\", FILE_KEY, \"hash_type\")\n    hash_val = testing_data_config(\"images\", FILE_KEY, \"hash_val\")\n    download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val)\n\n\nclass MaskedPatchWSIDatasetTests:\n    class Tests(unittest.TestCase):\n        backend = None\n\n        @parameterized.expand([TEST_CASE_0, TEST_CASE_1])\n        def test_gen_patches(self, input_parameters, expected):\n            dataset = MaskedPatchWSIDataset(reader=self.backend, **input_parameters)\n            self.assertEqual(len(dataset), expected[\"num_patches\"])\n            self.assertTrue(isinstance(dataset.image_data, list))\n            for d1, d2 in zip(dataset.image_data, input_parameters[\"data\"]):\n                self.assertTrue(d1[\"image\"] == d2[\"image\"])\n                self.assertTrue(d1[ProbMapKeys.NAME] == os.path.basename(d2[\"image\"]))\n\n            for i, sample in enumerate(dataset):\n                self.assertEqual(sample[\"image\"].meta[WSIPatchKeys.LEVEL], expected[\"patch_level\"])\n                assert_array_equal(sample[\"image\"].meta[WSIPatchKeys.SIZE], expected[\"patch_size\"])\n                assert_array_equal(sample[\"image\"].shape[1:], expected[\"patch_size\"])\n                self.assertTrue(sample[\"image\"].meta[WSIPatchKeys.LOCATION][0] >= 0)\n                self.assertTrue(sample[\"image\"].meta[WSIPatchKeys.LOCATION][0] < 
expected[\"wsi_size\"][0])\n                self.assertTrue(sample[\"image\"].meta[WSIPatchKeys.LOCATION][1] >= 0)\n                self.assertTrue(sample[\"image\"].meta[WSIPatchKeys.LOCATION][1] < expected[\"wsi_size\"][1])\n                if i > 10:\n                    break\n\n\n@skipUnless(has_cucim, \"Requires cucim\")\nclass TestSlidingPatchWSIDatasetCuCIM(MaskedPatchWSIDatasetTests.Tests):\n    @classmethod\n    def setUpClass(cls):\n        cls.backend = \"cucim\"\n\n\n@skipUnless(has_osl, \"Requires openslide\")\nclass TestSlidingPatchWSIDatasetOpenSlide(MaskedPatchWSIDatasetTests.Tests):\n    @classmethod\n    def setUpClass(cls):\n        cls.backend = \"openslide\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_nifti_header_revise.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\n\nfrom monai.data import rectify_header_sform_qform\n\n\nclass TestRectifyHeaderSformQform(unittest.TestCase):\n\n    def test_revise_q(self):\n        img = nib.Nifti1Image(np.zeros((10, 10, 10)), np.eye(4))\n        img.header.set_zooms((0.1, 0.2, 0.3))\n        output = rectify_header_sform_qform(img)\n        expected = np.diag([0.1, 0.2, 0.3, 1.0])\n        np.testing.assert_allclose(output.affine, expected)\n\n    def test_revise_both(self):\n        img = nib.Nifti1Image(np.zeros((10, 10, 10)), np.eye(4))\n        img.header.set_sform(np.diag([5, 3, 4, 1]))\n        img.header.set_qform(np.diag([2, 3, 4, 1]))\n        img.header.set_zooms((0.1, 0.2, 0.3))\n        output = rectify_header_sform_qform(img)\n        expected = np.diag([0.1, 0.2, 0.3, 1.0])\n        np.testing.assert_allclose(output.affine, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_nifti_rw.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import NibabelWriter\nfrom monai.transforms import LoadImage, Orientation, Spacing\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose, make_nifti_image\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    for q in TEST_NDARRAYS:\n        TEST_IMAGE = p(np.arange(24).reshape((2, 4, 3)))\n        TEST_AFFINE = q(\n            np.array(\n                [[-5.3, 0.0, 0.0, 102.01], [0.0, 0.52, 2.17, -7.50], [-0.0, 1.98, -0.26, -23.12], [0.0, 0.0, 0.0, 1.0]]\n            )\n        )\n        TESTS.append(\n            [\n                TEST_IMAGE,\n                TEST_AFFINE,\n                dict(reader=\"NibabelReader\", image_only=True, as_closest_canonical=True),\n                np.array(\n                    [\n                        [[12.0, 15.0, 18.0, 21.0], [13.0, 16.0, 19.0, 22.0], [14.0, 17.0, 20.0, 23.0]],\n                        [[0.0, 3.0, 6.0, 9.0], [1.0, 4.0, 7.0, 10.0], [2.0, 5.0, 8.0, 11.0]],\n                    ]\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                TEST_IMAGE,\n                TEST_AFFINE,\n                dict(reader=\"NibabelReader\", image_only=True, as_closest_canonical=False),\n                
np.arange(24).reshape((2, 4, 3)),\n            ]\n        )\n        TESTS.append(\n            [\n                TEST_IMAGE,\n                TEST_AFFINE,\n                dict(reader=\"NibabelReader\", image_only=True, as_closest_canonical=False),\n                np.arange(24).reshape((2, 4, 3)),\n            ]\n        )\n        TESTS.append(\n            [\n                TEST_IMAGE,\n                None,\n                dict(reader=\"NibabelReader\", image_only=True, as_closest_canonical=False),\n                np.arange(24).reshape((2, 4, 3)),\n            ]\n        )\n\n\nclass TestNiftiLoadRead(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_orientation(self, array, affine, reader_param, expected):\n        test_image = make_nifti_image(array, affine)\n\n        # read test cases\n        loader = LoadImage(**reader_param)\n        load_result = loader(test_image)\n        data_array = load_result.numpy()\n        if reader_param.get(\"image_only\", False):\n            header = None\n        else:\n            header = load_result.meta\n            header[\"affine\"] = header[\"affine\"].numpy()\n        if os.path.exists(test_image):\n            os.remove(test_image)\n\n        # write test cases\n        writer_obj = NibabelWriter()\n        writer_obj.set_data_array(data_array, channel_dim=None)\n        if header is not None:\n            writer_obj.set_metadata(header)\n        elif affine is not None:\n            writer_obj.set_metadata({\"affine\": affine})\n        writer_obj.write(test_image, verbose=True)\n        saved = nib.load(test_image)\n        saved_affine = saved.affine\n        saved_data = saved.get_fdata()\n        if os.path.exists(test_image):\n            os.remove(test_image)\n\n        if affine is not None:\n            assert_allclose(saved_affine, affine, type_test=False)\n        assert_allclose(saved_data, expected, type_test=False)\n\n    def test_consistency(self):\n        
np.set_printoptions(suppress=True, precision=3)\n        test_image = make_nifti_image(np.arange(64).reshape(1, 8, 8), np.diag([1.5, 1.5, 1.5, 1]))\n        data = LoadImage(image_only=True, reader=\"NibabelReader\", as_closest_canonical=False)(test_image)\n        data = Spacing([0.8, 0.8, 0.8])(data[None], mode=\"nearest\")\n        original_affine = data.meta[\"original_affine\"]\n        data = Orientation(\"ILP\")(data)\n        new_affine = data.affine\n        if os.path.exists(test_image):\n            os.remove(test_image)\n        writer_obj = NibabelWriter()\n        writer_obj.set_data_array(data[0], channel_dim=None)\n        writer_obj.set_metadata(\n            meta_dict={\"affine\": new_affine, \"original_affine\": original_affine}, mode=\"nearest\", padding_mode=\"border\"\n        )\n        writer_obj.write(test_image, verbose=True)\n        saved = nib.load(test_image)\n        saved_data = saved.get_fdata()\n        np.testing.assert_allclose(saved_data, np.arange(64).reshape(1, 8, 8), atol=1e-7)\n        if os.path.exists(test_image):\n            os.remove(test_image)\n        writer_obj.set_data_array(data[0], channel_dim=None)\n        writer_obj.set_metadata(\n            meta_dict={\"affine\": new_affine, \"original_affine\": original_affine, \"spatial_shape\": (1, 8, 8)},\n            mode=\"nearest\",\n            padding_mode=\"border\",\n        )\n        writer_obj.write(test_image, verbose=True)\n        saved = nib.load(test_image)\n        saved_data = saved.get_fdata()\n        np.testing.assert_allclose(saved_data, np.arange(64).reshape(1, 8, 8), atol=1e-7)\n        if os.path.exists(test_image):\n            os.remove(test_image)\n        # test the case no resample\n        writer_obj.set_data_array(data[0], channel_dim=None)\n        writer_obj.set_metadata(meta_dict={\"affine\": new_affine, \"original_affine\": original_affine}, resample=False)\n        writer_obj.write(test_image, verbose=True)\n        saved = 
nib.load(test_image)\n        np.testing.assert_allclose(saved.affine, new_affine)\n        if os.path.exists(test_image):\n            os.remove(test_image)\n\n    def test_write_2d(self):\n        with tempfile.TemporaryDirectory() as out_dir:\n            image_name = os.path.join(out_dir, \"test.nii.gz\")\n            for p in TEST_NDARRAYS:\n                img = p(np.arange(6).reshape((2, 3)))\n                writer_obj = NibabelWriter()\n                writer_obj.set_data_array(img, channel_dim=None)\n                writer_obj.set_metadata({\"affine\": np.diag([1, 1, 1]), \"original_affine\": np.diag([1.4, 1, 1])})\n                writer_obj.write(image_name, verbose=True)\n                out = nib.load(image_name)\n                np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]], atol=1e-4, rtol=1e-4)\n                np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1]), atol=1e-4, rtol=1e-4)\n\n                image_name = os.path.join(out_dir, \"test1.nii.gz\")\n                img = np.arange(5).reshape((1, 5))\n                writer_obj.set_data_array(img, channel_dim=None)\n                writer_obj.set_metadata(\n                    {\"affine\": np.diag([1, 1, 1, 3, 3]), \"original_affine\": np.diag([1.4, 2.0, 1, 3, 5])}\n                )\n                writer_obj.write(image_name, verbose=True)\n                out = nib.load(image_name)\n                np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]], atol=1e-4, rtol=1e-4)\n                np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1]), atol=1e-4, rtol=1e-4)\n\n    def test_write_3d(self):\n        with tempfile.TemporaryDirectory() as out_dir:\n            image_name = os.path.join(out_dir, \"test.nii.gz\")\n            for p in TEST_NDARRAYS:\n                img = p(np.arange(6).reshape((1, 2, 3)))\n                writer_obj = NibabelWriter()\n                writer_obj.set_data_array(img, channel_dim=None)\n                
writer_obj.set_metadata({\"affine\": np.diag([1, 1, 1, 1]), \"original_affine\": np.diag([1.4, 1, 1, 1])})\n                writer_obj.write(image_name, verbose=True)\n                out = nib.load(image_name)\n                np.testing.assert_allclose(out.get_fdata(), [[[0, 1, 2], [3, 4, 5]]], atol=1e-4, rtol=1e-4)\n                np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1]), atol=1e-4, rtol=1e-4)\n\n                image_name = os.path.join(out_dir, \"test1.nii.gz\")\n                img = p(np.arange(5).reshape((1, 1, 5)))\n                writer_obj.set_data_array(img, channel_dim=None)\n                writer_obj.set_metadata(\n                    {\"affine\": np.diag([1, 1, 1, 3, 3]), \"original_affine\": np.diag([1.4, 2.0, 2, 3, 5])}\n                )\n                writer_obj.write(image_name, verbose=True)\n                out = nib.load(image_name)\n                np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]], atol=1e-4, rtol=1e-4)\n                np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]), atol=1e-4, rtol=1e-4)\n\n    def test_write_4d(self):\n        with tempfile.TemporaryDirectory() as out_dir:\n            image_name = os.path.join(out_dir, \"test.nii.gz\")\n            for p in TEST_NDARRAYS:\n                img = p(np.arange(6).reshape((1, 1, 3, 2)))\n                writer_obj = NibabelWriter()\n                writer_obj.set_data_array(img, channel_dim=-1)\n                writer_obj.set_metadata({\"affine\": np.diag([1.4, 1, 1, 1]), \"original_affine\": np.diag([1, 1.4, 1, 1])})\n                writer_obj.write(image_name, verbose=True)\n                out = nib.load(image_name)\n                np.testing.assert_allclose(out.get_fdata(), [[[[0, 1], [2, 3], [4, 5]]]], atol=1e-4, rtol=1e-4)\n                np.testing.assert_allclose(out.affine, np.diag([1, 1.4, 1, 1]), atol=1e-4, rtol=1e-4)\n\n                image_name = os.path.join(out_dir, \"test1.nii.gz\")\n                img = 
p(np.arange(5).reshape((1, 1, 5, 1)))\n                writer_obj.set_data_array(img, channel_dim=-1, squeeze_end_dims=False)\n                writer_obj.set_metadata(\n                    {\"affine\": np.diag([1, 1, 1, 3, 3]), \"original_affine\": np.diag([1.4, 2.0, 2, 3, 5])}\n                )\n                writer_obj.write(image_name, verbose=True)\n                out = nib.load(image_name)\n                np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]], atol=1e-4, rtol=1e-4)\n                np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]), atol=1e-4, rtol=1e-4)\n\n    def test_write_5d(self):\n        with tempfile.TemporaryDirectory() as out_dir:\n            image_name = os.path.join(out_dir, \"test.nii.gz\")\n            for p in TEST_NDARRAYS:\n                img = p(np.arange(12).reshape((1, 1, 3, 2, 2)))\n                writer_obj = NibabelWriter()\n                writer_obj.set_data_array(img, channel_dim=-1, squeeze_end_dims=False, spatial_ndim=None)\n                writer_obj.set_metadata({\"affine\": np.diag([1, 1, 1, 1]), \"original_affine\": np.diag([1.4, 1, 1, 1])})\n                writer_obj.write(image_name, verbose=True)\n                out = nib.load(image_name)\n                np.testing.assert_allclose(\n                    out.get_fdata(),\n                    np.array([[[[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]], [[8.0, 9.0], [10.0, 11.0]]]]]),\n                    atol=1e-4,\n                    rtol=1e-4,\n                )\n                np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1]), atol=1e-4, rtol=1e-4)\n\n                image_name = os.path.join(out_dir, \"test1.nii.gz\")\n                img = p(np.arange(10).reshape((1, 1, 5, 1, 2)))\n                writer_obj.set_data_array(img, channel_dim=-1, squeeze_end_dims=False, spatial_ndim=None)\n                writer_obj.set_metadata({\"affine\": np.diag([1, 1, 1, 3]), \"original_affine\": np.diag([1.4, 2.0, 2, 
3])})\n                writer_obj.write(image_name, verbose=True)\n                out = nib.load(image_name)\n                np.testing.assert_allclose(\n                    out.get_fdata(), np.array([[[[[0.0, 2.0]], [[4.0, 5.0]], [[7.0, 9.0]]]]]), atol=1e-4, rtol=1e-4\n                )\n                np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]), atol=1e-4, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_npzdictitemdataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport tempfile\nimport unittest\nfrom io import BytesIO\n\nimport numpy as np\n\nfrom monai.data import NPZDictItemDataset\n\n\nclass TestNPZDictItemDataset(unittest.TestCase):\n\n    def test_load_stream(self):\n        dat0 = np.random.rand(10, 1, 4, 4)\n        dat1 = np.random.rand(10, 1, 4, 4)\n\n        npzfile = BytesIO()\n        np.savez_compressed(npzfile, dat0=dat0, dat1=dat1)\n        npzfile.seek(0)\n\n        npzds = NPZDictItemDataset(npzfile, {\"dat0\": \"images\", \"dat1\": \"seg\"})\n\n        item = npzds[0]\n\n        np.testing.assert_allclose(item[\"images\"].shape, (1, 4, 4))\n        np.testing.assert_allclose(item[\"seg\"].shape, (1, 4, 4))\n\n    def test_load_file(self):\n        dat0 = np.random.rand(10, 1, 4, 4)\n        dat1 = np.random.rand(10, 1, 4, 4)\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            npzfile = f\"{tempdir}/test.npz\"\n\n            np.savez_compressed(npzfile, dat0=dat0, dat1=dat1)\n\n            npzds = NPZDictItemDataset(npzfile, {\"dat0\": \"images\", \"dat1\": \"seg\"})\n\n            item = npzds[0]\n\n            np.testing.assert_allclose(item[\"images\"].shape, (1, 4, 4))\n            np.testing.assert_allclose(item[\"seg\"].shape, (1, 4, 4))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_nrrd_reader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom unittest.case import skipUnless\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import NrrdReader\nfrom monai.utils.module import optional_import\n\nnrrd, has_nrrd = optional_import(\"nrrd\", allow_namespace_pkg=True)\n\nTEST_CASE_1 = [(4, 4), \"test_image.nrrd\", (4, 4), np.uint8]\nTEST_CASE_2 = [(4, 4, 4), \"test_image.nrrd\", (4, 4, 4), np.uint16]\nTEST_CASE_3 = [(4, 4, 4, 4), \"test_image.nrrd\", (4, 4, 4, 4), np.uint32]\nTEST_CASE_4 = [(1, 2, 3, 4, 5), \"test_image.nrrd\", (1, 2, 3, 4, 5), np.uint64]\nTEST_CASE_5 = [(6, 5, 4, 3, 2, 1), \"test_image.nrrd\", (6, 5, 4, 3, 2, 1), np.float32]\nTEST_CASE_6 = [(4,), \"test_image.nrrd\", (4,), np.float64]\nTEST_CASE_7 = [(4, 4), [\"test_image.nrrd\", \"test_image2.nrrd\", \"test_image3.nrrd\"], (4, 4), np.float32]\nTEST_CASE_8 = [\n    (3, 4, 4, 1),\n    \"test_image.nrrd\",\n    (3, 4, 4, 1),\n    np.float32,\n    {\n        \"dimension\": 4,\n        \"space\": \"left-posterior-superior\",\n        \"sizes\": [3, 4, 4, 1],\n        \"space directions\": [[0.7, 0.0, 0.0], [0.0, 0.0, -0.8], [0.0, 0.9, 0.0]],\n        \"space origin\": [1.0, 5.0, 20.0],\n    },\n]\n\n\n@skipUnless(has_nrrd, \"nrrd required\")\nclass TestNrrdReader(unittest.TestCase):\n\n    def test_verify_suffix(self):\n        reader = 
NrrdReader()\n        self.assertFalse(reader.verify_suffix(\"test_image.nrd\"))\n        reader.verify_suffix(\"test_image.nrrd\")\n        reader.verify_suffix(\"test_image.seg.nrrd\")\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_read_int(self, data_shape, filename, expected_shape, dtype):\n        min_val, max_val = np.iinfo(dtype).min, np.iinfo(dtype).max\n        test_image = np.random.randint(min_val, max_val, size=data_shape, dtype=dtype)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, filename)\n            nrrd.write(filename, test_image.astype(dtype))\n            reader = NrrdReader()\n            result = reader.read(filename)\n        self.assertEqual(result.array.dtype, dtype)\n        self.assertTupleEqual(result.array.shape, expected_shape)\n        self.assertTupleEqual(tuple(result.header[\"sizes\"]), expected_shape)\n        np.testing.assert_allclose(result.array, test_image)\n\n    @parameterized.expand([TEST_CASE_5, TEST_CASE_6])\n    def test_read_float(self, data_shape, filename, expected_shape, dtype):\n        test_image = np.random.rand(*data_shape).astype(dtype)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, filename)\n            nrrd.write(filename, test_image.astype(dtype))\n            reader = NrrdReader()\n            result = reader.read(filename)\n        self.assertEqual(result.array.dtype, dtype)\n        self.assertTupleEqual(result.array.shape, expected_shape)\n        self.assertTupleEqual(tuple(result.header[\"sizes\"]), expected_shape)\n        np.testing.assert_allclose(result.array, test_image)\n\n    @parameterized.expand([TEST_CASE_7])\n    def test_read_list(self, data_shape, filenames, expected_shape, dtype):\n        test_image = np.random.rand(*data_shape).astype(dtype)\n        with tempfile.TemporaryDirectory() as tempdir:\n            for i, filename in 
enumerate(filenames):\n                filenames[i] = os.path.join(tempdir, filename)\n                nrrd.write(filenames[i], test_image.astype(dtype))\n            reader = NrrdReader()\n            results = reader.read(filenames)\n        for result in results:\n            self.assertTupleEqual(result.array.shape, expected_shape)\n            self.assertTupleEqual(tuple(result.header[\"sizes\"]), expected_shape)\n            np.testing.assert_allclose(result.array, test_image)\n\n    @parameterized.expand([TEST_CASE_8])\n    def test_read_with_header(self, data_shape, filename, expected_shape, dtype, reference_header):\n        test_image = np.random.rand(*data_shape).astype(dtype)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, filename)\n            nrrd.write(filename, test_image.astype(dtype), header=reference_header)\n            reader = NrrdReader()\n            image_array, image_header = reader.get_data(reader.read(filename))\n        self.assertIsInstance(image_array, np.ndarray)\n        self.assertEqual(image_array.dtype, dtype)\n        self.assertTupleEqual(image_array.shape, expected_shape)\n        np.testing.assert_allclose(image_array, test_image)\n        self.assertIsInstance(image_header, dict)\n        self.assertTupleEqual(tuple(image_header[\"spatial_shape\"]), expected_shape)\n        np.testing.assert_allclose(\n            image_header[\"affine\"],\n            np.array([[-0.7, 0.0, 0.0, -1.0], [0.0, 0.0, -0.9, -5.0], [0.0, -0.8, 0.0, 20.0], [0.0, 0.0, 0.0, 1.0]]),\n        )\n\n    @parameterized.expand([TEST_CASE_8])\n    def test_read_with_header_index_order_c(self, data_shape, filename, expected_shape, dtype, reference_header):\n        test_image = np.random.rand(*data_shape).astype(dtype)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, filename)\n            nrrd.write(filename, test_image.astype(dtype), 
header=reference_header)\n            reader = NrrdReader(index_order=\"C\")\n            image_array, image_header = reader.get_data(reader.read(filename))\n        self.assertIsInstance(image_array, np.ndarray)\n        self.assertEqual(image_array.dtype, dtype)\n        self.assertTupleEqual(image_array.shape, expected_shape[::-1])\n        self.assertTupleEqual(image_array.shape, tuple(image_header[\"spatial_shape\"]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_numpy_reader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport sys\nimport tempfile\nimport unittest\n\nimport numpy as np\n\nfrom monai.data import DataLoader, Dataset, NumpyReader\nfrom monai.transforms import LoadImage, LoadImaged\nfrom tests.test_utils import assert_allclose\n\n\nclass TestNumpyReader(unittest.TestCase):\n    def test_npy(self):\n        test_data = np.random.randint(0, 256, size=[3, 4, 4])\n        with tempfile.TemporaryDirectory() as tempdir:\n            filepath = os.path.join(tempdir, \"test_data.npy\")\n            np.save(filepath, test_data)\n\n            reader = NumpyReader()\n            result = reader.get_data(reader.read(filepath))\n        np.testing.assert_allclose(result[1][\"spatial_shape\"], test_data.shape)\n        np.testing.assert_allclose(result[0].shape, test_data.shape)\n        np.testing.assert_allclose(result[0], test_data)\n\n    def test_npz1(self):\n        test_data1 = np.random.randint(0, 256, size=[3, 4, 4])\n        with tempfile.TemporaryDirectory() as tempdir:\n            filepath = os.path.join(tempdir, \"test_data.npy\")\n            np.save(filepath, test_data1)\n\n            reader = NumpyReader()\n            result = reader.get_data(reader.read(filepath))\n        np.testing.assert_allclose(result[1][\"spatial_shape\"], test_data1.shape)\n        np.testing.assert_allclose(result[0].shape, test_data1.shape)\n   
     np.testing.assert_allclose(result[0], test_data1)\n\n    def test_npz2(self):\n        test_data1 = np.random.randint(0, 256, size=[3, 4, 4])\n        test_data2 = np.random.randint(0, 256, size=[3, 4, 4])\n        with tempfile.TemporaryDirectory() as tempdir:\n            filepath = os.path.join(tempdir, \"test_data.npz\")\n            np.savez(filepath, test_data1, test_data2)\n\n            reader = NumpyReader()\n            result = reader.get_data(reader.read(filepath))\n        np.testing.assert_allclose(result[1][\"spatial_shape\"], test_data1.shape)\n        np.testing.assert_allclose(result[0].shape, (2, 3, 4, 4))\n        np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2]))\n\n    def test_npz3(self):\n        test_data1 = np.random.randint(0, 256, size=[3, 4, 4])\n        test_data2 = np.random.randint(0, 256, size=[3, 4, 4])\n        with tempfile.TemporaryDirectory() as tempdir:\n            filepath = os.path.join(tempdir, \"test_data.npz\")\n            np.savez(filepath, test1=test_data1, test2=test_data2)\n\n            reader = NumpyReader(npz_keys=[\"test1\", \"test2\"])\n            result = reader.get_data(reader.read(filepath))\n        np.testing.assert_allclose(result[1][\"spatial_shape\"], test_data1.shape)\n        np.testing.assert_allclose(result[0].shape, (2, 3, 4, 4))\n        np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2]))\n\n    def test_npy_pickle(self):\n        test_data = {\"test\": np.random.randint(0, 256, size=[3, 4, 4])}\n        with tempfile.TemporaryDirectory() as tempdir:\n            filepath = os.path.join(tempdir, \"test_data.npy\")\n            np.save(filepath, test_data, allow_pickle=True)\n\n            reader = NumpyReader()\n            result = reader.get_data(reader.read(filepath))[0].item()\n        np.testing.assert_allclose(result[\"test\"].shape, test_data[\"test\"].shape)\n        np.testing.assert_allclose(result[\"test\"], test_data[\"test\"])\n\n    
def test_kwargs(self):\n        test_data = {\"test\": np.random.randint(0, 256, size=[3, 4, 4])}\n        with tempfile.TemporaryDirectory() as tempdir:\n            filepath = os.path.join(tempdir, \"test_data.npy\")\n            np.save(filepath, test_data, allow_pickle=True)\n\n            reader = NumpyReader(mmap_mode=\"r\")\n            result = reader.get_data(reader.read(filepath, mmap_mode=None))[0].item()\n        np.testing.assert_allclose(result[\"test\"].shape, test_data[\"test\"].shape)\n\n    def test_dataloader(self):\n        test_data = np.random.randint(0, 256, size=[3, 4, 5])\n        datalist_dict, datalist_array = [], []\n        with tempfile.TemporaryDirectory() as tempdir:\n            for i in range(4):\n                filepath = os.path.join(tempdir, f\"test_data{i}.npz\")\n                np.savez(filepath, test_data)\n                datalist_dict.append({\"image\": filepath})\n                datalist_array.append(filepath)\n\n            num_workers = 2 if sys.platform == \"linux\" else 0\n            loader = DataLoader(\n                Dataset(data=datalist_dict, transform=LoadImaged(keys=\"image\", reader=NumpyReader())),\n                batch_size=2,\n                num_workers=num_workers,\n            )\n            for d in loader:\n                for c in d[\"image\"]:\n                    assert_allclose(c, test_data, type_test=False)\n\n            loader = DataLoader(\n                Dataset(data=datalist_array, transform=LoadImage(reader=NumpyReader())),\n                batch_size=2,\n                num_workers=num_workers,\n            )\n            for d in loader:\n                for c in d:\n                    assert_allclose(c, test_data, type_test=False)\n\n    def test_channel_dim(self):\n        test_data = np.random.randint(0, 256, size=[3, 4, 5, 2])\n        with tempfile.TemporaryDirectory() as tempdir:\n            filepath = os.path.join(tempdir, \"test_data.npy\")\n            np.save(filepath, 
test_data)\n\n            reader = NumpyReader(channel_dim=-1)\n            result = reader.get_data(reader.read(filepath))\n        np.testing.assert_allclose(result[1][\"spatial_shape\"], test_data.shape[:-1])\n        self.assertEqual(result[1][\"original_channel_dim\"], -1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_partition_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.data import partition_dataset\n\nTEST_CASE_1 = [\n    {\n        \"data\": [1, 2, 3, 4],\n        \"num_partitions\": 2,\n        \"shuffle\": False,\n        \"seed\": 0,\n        \"drop_last\": False,\n        \"even_divisible\": False,\n    },\n    [[1, 3], [2, 4]],\n]\n\nTEST_CASE_2 = [\n    {\n        \"data\": [1, 2, 3, 4],\n        \"num_partitions\": 2,\n        \"shuffle\": True,\n        \"seed\": 123,\n        \"drop_last\": False,\n        \"even_divisible\": False,\n    },\n    [[4, 2], [1, 3]],\n]\n\nTEST_CASE_3 = [\n    {\n        \"data\": [1, 2, 3, 4, 5],\n        \"num_partitions\": 2,\n        \"shuffle\": False,\n        \"seed\": 0,\n        \"drop_last\": False,\n        \"even_divisible\": False,\n    },\n    [[1, 3, 5], [2, 4]],\n]\n\nTEST_CASE_4 = [\n    {\n        \"data\": [1, 2, 3, 4, 5],\n        \"num_partitions\": 2,\n        \"shuffle\": False,\n        \"seed\": 0,\n        \"drop_last\": False,\n        \"even_divisible\": True,\n    },\n    [[1, 3, 5], [2, 4, 1]],\n]\n\nTEST_CASE_5 = [\n    {\n        \"data\": [1, 2, 3, 4, 5],\n        \"num_partitions\": 2,\n        \"shuffle\": False,\n        \"seed\": 0,\n        \"drop_last\": True,\n        \"even_divisible\": True,\n    },\n    [[1, 3], [2, 4]],\n]\n\nTEST_CASE_6 = [\n   
 {\n        \"data\": [1, 2, 3, 4, 5],\n        \"ratios\": [3, 2],\n        \"num_partitions\": None,\n        \"shuffle\": False,\n        \"seed\": 0,\n        \"drop_last\": True,\n        \"even_divisible\": True,\n    },\n    [[1, 2, 3], [4, 5]],\n]\n\nTEST_CASE_7 = [\n    {\n        \"data\": [1, 2, 3, 4, 5],\n        \"ratios\": [2, 1],\n        \"num_partitions\": None,\n        \"shuffle\": False,\n        \"seed\": 0,\n        \"drop_last\": True,\n        \"even_divisible\": True,\n    },\n    [[1, 2, 3], [4, 5]],\n]\n\nTEST_CASE_8 = [\n    {\n        \"data\": [1, 2, 3, 4, 5],\n        \"ratios\": [2, 1],\n        \"num_partitions\": None,\n        \"shuffle\": True,\n        \"seed\": 123,\n        \"drop_last\": True,\n        \"even_divisible\": True,\n    },\n    [[2, 4, 5], [1, 3]],\n]\n\n\nclass TestPartitionDataset(unittest.TestCase):\n\n    @parameterized.expand(\n        [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]\n    )\n    def test_value(self, input_param, result):\n        self.assertListEqual(partition_dataset(**input_param), result)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_partition_dataset_classes.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import partition_dataset_classes\n\nTEST_CASE_1 = [\n    {\n        \"data\": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],\n        \"classes\": [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3],\n        \"ratios\": [2, 1],\n        \"num_partitions\": None,\n        \"shuffle\": False,\n        \"seed\": 0,\n        \"drop_last\": False,\n        \"even_divisible\": False,\n    },\n    [[2, 8, 4, 1, 3, 6, 5, 11, 12], [10, 13, 7, 9, 14]],\n]\n\nTEST_CASE_2 = [\n    {\n        \"data\": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],\n        \"classes\": [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3],\n        \"ratios\": None,\n        \"num_partitions\": 2,\n        \"shuffle\": False,\n        \"seed\": 0,\n        \"drop_last\": False,\n        \"even_divisible\": False,\n    },\n    [[2, 10, 4, 1, 6, 9, 5, 12], [8, 13, 3, 7, 11, 14]],\n]\n\nTEST_CASE_3 = [\n    {\n        \"data\": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],\n        \"classes\": [2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3],\n        \"ratios\": None,\n        \"num_partitions\": 2,\n        \"shuffle\": False,\n        \"seed\": 0,\n        \"drop_last\": False,\n        \"even_divisible\": True,\n    },\n    [[2, 10, 4, 1, 6, 9, 5, 12], [8, 2, 13, 3, 7, 1, 11, 
14]],\n]\n\nTEST_CASE_4 = [\n    {\n        \"data\": np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]),\n        \"classes\": np.array([2, 0, 2, 1, 3, 2, 2, 0, 2, 0, 3, 3, 1, 3]),\n        \"ratios\": [1, 2],\n        \"num_partitions\": None,\n        \"shuffle\": True,\n        \"seed\": 123,\n        \"drop_last\": False,\n        \"even_divisible\": False,\n    },\n    [[13, 7, 14, 2, 3], [6, 8, 1, 5, 12, 11, 4, 9, 10]],\n]\n\n\nclass TestPartitionDatasetClasses(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_value(self, input_param, result):\n        self.assertListEqual(partition_dataset_classes(**input_param), result)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_patch_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport unittest\n\nimport numpy as np\n\nfrom monai.data import DataLoader, Dataset, PatchDataset\nfrom monai.transforms import RandShiftIntensity, RandSpatialCropSamples\nfrom monai.utils import set_determinism\n\n\ndef identity(x):\n    # simple transform that returns the input itself\n    return x\n\n\nclass TestPatchDataset(unittest.TestCase):\n\n    def test_shape(self):\n        test_dataset = [\"vwxyz\", \"hello\", \"world\"]\n        n_per_image = len(test_dataset[0])\n\n        result = PatchDataset(data=test_dataset, patch_func=identity, samples_per_image=n_per_image)\n\n        output = []\n        n_workers = 0 if sys.platform == \"win32\" else 2\n        for item in DataLoader(result, batch_size=3, num_workers=n_workers):\n            output.append(\"\".join(item))\n        if n_workers == 0:\n            expected = [\"vwx\", \"yzh\", \"ell\", \"owo\", \"rld\"]\n        else:\n            expected = [\"vwx\", \"hel\", \"yzw\", \"lo\", \"orl\", \"d\"]\n        self.assertEqual(output, expected)\n\n    def test_loading_array(self):\n        set_determinism(seed=1234)\n        # image dataset\n        images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)]\n        # image patch sampler\n        n_samples = 8\n        sampler = RandSpatialCropSamples(roi_size=(3, 3), 
num_samples=n_samples, random_center=True, random_size=False)\n\n        # image level\n        patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)\n        image_ds = Dataset(images, transform=patch_intensity)\n        # patch level\n        ds = PatchDataset(data=image_ds, patch_func=sampler, samples_per_image=n_samples, transform=patch_intensity)\n\n        np.testing.assert_equal(len(ds), n_samples * len(images))\n        # use the patch dataset, length: len(images) x samplers_per_image\n        for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=0):\n            np.testing.assert_equal(tuple(item.shape), (2, 1, 3, 3))\n        np.testing.assert_allclose(\n            item[0],\n            np.array(\n                [[[4.970372, 5.970372, 6.970372], [8.970372, 9.970372, 10.970372], [12.970372, 13.970372, 14.970372]]]\n            ),\n            rtol=1e-5,\n        )\n        if sys.platform != \"win32\":\n            for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=2):\n                np.testing.assert_equal(tuple(item.shape), (2, 1, 3, 3))\n            np.testing.assert_allclose(\n                item[0],\n                np.array(\n                    [\n                        [\n                            [5.028125, 6.028125, 7.028125],\n                            [9.028125, 10.028125, 11.028125],\n                            [13.028125, 14.028125, 15.028125],\n                        ]\n                    ]\n                ),\n                rtol=1e-5,\n            )\n        set_determinism(seed=None)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_patch_wsi_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\nfrom unittest import skipUnless\n\nimport numpy as np\nfrom numpy.testing import assert_array_equal\nfrom parameterized import parameterized\n\nfrom monai.data import PatchWSIDataset\nfrom monai.data.wsi_reader import CuCIMWSIReader, OpenSlideWSIReader\nfrom monai.utils import optional_import\nfrom monai.utils.enums import WSIPatchKeys\nfrom tests.test_utils import download_url_or_skip_test, testing_data_config\n\ncucim, has_cim = optional_import(\"cucim\")\nhas_cim = has_cim and hasattr(cucim, \"CuImage\")\nopenslide, has_osl = optional_import(\"openslide\")\nimwrite, has_tiff = optional_import(\"tifffile\", name=\"imwrite\")\n_, has_codec = optional_import(\"imagecodecs\")\nhas_tiff = has_tiff and has_codec\n\nFILE_KEY = \"wsi_generic_tiff\"\nFILE_URL = testing_data_config(\"images\", FILE_KEY, \"url\")\nTESTS_PATH = Path(__file__).parents[1].as_posix()\nFILE_PATH = os.path.join(TESTS_PATH, \"testing_data\", f\"temp_{FILE_KEY}.tiff\")\n\nTEST_CASE_0 = [\n    {\n        \"data\": [{\"image\": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], \"label\": [1], \"patch_level\": 0}],\n        \"patch_size\": (1, 1),\n    },\n    {\"image\": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), \"label\": np.array([1])},\n]\n\nTEST_CASE_0_L1 = [\n    {\n        \"data\": [{\"image\": 
FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], \"label\": [1]}],\n        \"patch_size\": (1, 1),\n        \"patch_level\": 1,\n    },\n    {\"image\": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), \"label\": np.array([1])},\n]\n\nTEST_CASE_0_L2 = [\n    {\n        \"data\": [{\"image\": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], \"label\": [1]}],\n        \"patch_size\": (1, 1),\n        \"patch_level\": 1,\n    },\n    {\"image\": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), \"label\": np.array([1])},\n]\nTEST_CASE_1 = [\n    {\"data\": [{\"image\": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], WSIPatchKeys.SIZE.value: 1, \"label\": [1]}]},\n    {\"image\": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), \"label\": np.array([1])},\n]\n\nTEST_CASE_2 = [\n    {\n        \"data\": [{\"image\": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], \"label\": [1]}],\n        \"patch_size\": 1,\n        \"patch_level\": 0,\n    },\n    {\"image\": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), \"label\": np.array([1])},\n]\n\nTEST_CASE_3 = [\n    {\"data\": [{\"image\": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], \"label\": [[[0, 1], [1, 0]]]}], \"patch_size\": 1},\n    {\"image\": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), \"label\": np.array([[[0, 1], [1, 0]]])},\n]\n\nTEST_CASE_4 = [\n    {\n        \"data\": [\n            {\"image\": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], \"label\": [[[0, 1], [1, 0]]]},\n            {\"image\": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], \"label\": [[[1, 0], [0, 0]]]},\n        ],\n        \"patch_size\": 1,\n    },\n    [\n        {\"image\": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), \"label\": np.array([[[0, 1], [1, 0]]])},\n        {\"image\": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), \"label\": np.array([[[1, 0], [0, 0]]])},\n    ],\n]\n\nTEST_CASE_5 = [\n    {\n        \"data\": [\n            {\n                \"image\": FILE_PATH,\n  
              WSIPatchKeys.LOCATION.value: [0, 0],\n                \"label\": [[[0, 1], [1, 0]]],\n                WSIPatchKeys.SIZE.value: 1,\n                WSIPatchKeys.LEVEL.value: 1,\n            },\n            {\n                \"image\": FILE_PATH,\n                WSIPatchKeys.LOCATION.value: [100, 100],\n                \"label\": [[[1, 0], [0, 0]]],\n                WSIPatchKeys.SIZE.value: 1,\n                WSIPatchKeys.LEVEL.value: 1,\n            },\n        ]\n    },\n    [\n        {\"image\": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), \"label\": np.array([[[0, 1], [1, 0]]])},\n        {\"image\": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), \"label\": np.array([[[1, 0], [0, 0]]])},\n    ],\n]\n\n\n@skipUnless(has_cim or has_osl or has_tiff, \"Requires cucim, openslide, or tifffile!\")\ndef setUpModule():\n    hash_type = testing_data_config(\"images\", FILE_KEY, \"hash_type\")\n    hash_val = testing_data_config(\"images\", FILE_KEY, \"hash_val\")\n    download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val)\n\n\nclass PatchWSIDatasetTests:\n    class Tests(unittest.TestCase):\n        backend = None\n\n        @parameterized.expand([TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n        def test_read_patches_str(self, input_parameters, expected):\n            dataset = PatchWSIDataset(reader=self.backend, **input_parameters)\n            sample = dataset[0]\n            self.assertTupleEqual(sample[\"label\"].shape, expected[\"label\"].shape)\n            self.assertTupleEqual(sample[\"image\"].shape, expected[\"image\"].shape)\n            self.assertIsNone(assert_array_equal(sample[\"label\"], expected[\"label\"]))\n            self.assertIsNone(assert_array_equal(sample[\"image\"], expected[\"image\"]))\n\n        @parameterized.expand([TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n        def 
test_read_patches_class(self, input_parameters, expected):\n            if self.backend == \"openslide\":\n                reader = OpenSlideWSIReader\n            elif self.backend == \"cucim\":\n                reader = CuCIMWSIReader\n            else:\n                raise ValueError(\"Unsupported backend: {self.backend}\")\n            dataset = PatchWSIDataset(reader=reader, **input_parameters)\n            sample = dataset[0]\n            self.assertTupleEqual(sample[\"label\"].shape, expected[\"label\"].shape)\n            self.assertTupleEqual(sample[\"image\"].shape, expected[\"image\"].shape)\n            self.assertIsNone(assert_array_equal(sample[\"label\"], expected[\"label\"]))\n            self.assertIsNone(assert_array_equal(sample[\"image\"], expected[\"image\"]))\n\n        @parameterized.expand([TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n        def test_read_patches_object(self, input_parameters, expected):\n            if self.backend == \"openslide\":\n                reader = OpenSlideWSIReader(level=input_parameters.get(\"patch_level\", 0))\n            elif self.backend == \"cucim\":\n                reader = CuCIMWSIReader(level=input_parameters.get(\"patch_level\", 0))\n            else:\n                raise ValueError(\"Unsupported backend: {self.backend}\")\n            dataset = PatchWSIDataset(reader=reader, **input_parameters)\n            sample = dataset[0]\n            self.assertTupleEqual(sample[\"label\"].shape, expected[\"label\"].shape)\n            self.assertTupleEqual(sample[\"image\"].shape, expected[\"image\"].shape)\n            self.assertIsNone(assert_array_equal(sample[\"label\"], expected[\"label\"]))\n            self.assertIsNone(assert_array_equal(sample[\"image\"], expected[\"image\"]))\n\n        @parameterized.expand([TEST_CASE_4, TEST_CASE_5])\n        def test_read_patches_str_multi(self, input_parameters, expected):\n            dataset = 
PatchWSIDataset(reader=self.backend, **input_parameters)\n            for i, item in enumerate(dataset):\n                self.assertTupleEqual(item[\"label\"].shape, expected[i][\"label\"].shape)\n                self.assertTupleEqual(item[\"image\"].shape, expected[i][\"image\"].shape)\n                self.assertIsNone(assert_array_equal(item[\"label\"], expected[i][\"label\"]))\n                self.assertIsNone(assert_array_equal(item[\"image\"], expected[i][\"image\"]))\n\n\n@skipUnless(has_cim, \"Requires cucim\")\nclass TestPatchWSIDatasetCuCIM(PatchWSIDatasetTests.Tests):\n    @classmethod\n    def setUpClass(cls):\n        cls.backend = \"cucim\"\n\n\n@skipUnless(has_osl, \"Requires openslide\")\nclass TestPatchWSIDatasetOpenSlide(PatchWSIDatasetTests.Tests):\n    @classmethod\n    def setUpClass(cls):\n        cls.backend = \"openslide\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_persistentdataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport contextlib\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, PersistentDataset, json_hashing\nfrom monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform\n\nTEST_CASE_1 = [\n    Compose(\n        [\n            LoadImaged(keys=[\"image\", \"label\", \"extra\"]),\n            SimulateDelayd(keys=[\"image\", \"label\", \"extra\"], delay_time=[1e-7, 1e-6, 1e-5]),\n        ]\n    ),\n    (128, 128, 128),\n]\n\nTEST_CASE_2 = [\n    [\n        LoadImaged(keys=[\"image\", \"label\", \"extra\"]),\n        SimulateDelayd(keys=[\"image\", \"label\", \"extra\"], delay_time=[1e-7, 1e-6, 1e-5]),\n    ],\n    (128, 128, 128),\n]\n\nTEST_CASE_3 = [None, (128, 128, 128)]\n\nTEST_CASE_4 = [True, False, False, MetaTensor]\n\nTEST_CASE_5 = [True, True, True, None]\n\nTEST_CASE_6 = [False, False, False, torch.Tensor]\n\nTEST_CASE_7 = [False, True, False, torch.Tensor]\n\n\nclass _InplaceXform(Transform):\n    def __call__(self, data):\n        if data:\n            data[0] = data[0] + np.pi\n        else:\n            data.append(1)\n        return data\n\n\nclass TestDataset(unittest.TestCase):\n    def test_cache(self):\n        \"\"\"testing no inplace change to the hashed 
item\"\"\"\n        items = [[list(range(i))] for i in range(5)]\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            ds = PersistentDataset(\n                data=items,\n                transform=_InplaceXform(),\n                cache_dir=tempdir,\n                pickle_module=\"pickle\",\n                # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility\n                pickle_protocol=torch.serialization.DEFAULT_PROTOCOL,\n            )\n            self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n            ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir)\n            self.assertEqual(list(ds1), list(ds))\n            self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n\n            ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_func=json_hashing)\n            self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n            ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir, hash_func=json_hashing)\n            self.assertEqual(list(ds1), list(ds))\n            self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_shape(self, transform, expected_shape):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            nib.save(test_image, os.path.join(tempdir, \"test_image1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_label1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_extra1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_image2.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, 
\"test_label2.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_extra2.nii.gz\"))\n            test_data = [\n                {\n                    \"image\": os.path.join(tempdir, \"test_image1.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label1.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra1.nii.gz\"),\n                },\n                {\n                    \"image\": os.path.join(tempdir, \"test_image2.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label2.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra2.nii.gz\"),\n                },\n            ]\n\n            cache_dir = os.path.join(os.path.join(tempdir, \"cache\"), \"data\")\n            dataset_precached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir)\n            data1_precached = dataset_precached[0]\n            data2_precached = dataset_precached[1]\n\n            dataset_postcached = PersistentDataset(data=test_data, transform=transform, cache_dir=cache_dir)\n            data1_postcached = dataset_postcached[0]\n            data2_postcached = dataset_postcached[1]\n            data3_postcached = dataset_postcached[0:2]\n\n            if transform is None:\n                self.assertEqual(data1_precached[\"image\"], os.path.join(tempdir, \"test_image1.nii.gz\"))\n                self.assertEqual(data2_precached[\"label\"], os.path.join(tempdir, \"test_label2.nii.gz\"))\n                self.assertEqual(data1_postcached[\"image\"], os.path.join(tempdir, \"test_image1.nii.gz\"))\n                self.assertEqual(data2_postcached[\"extra\"], os.path.join(tempdir, \"test_extra2.nii.gz\"))\n            else:\n                self.assertTupleEqual(data1_precached[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data1_precached[\"label\"].shape, expected_shape)\n                
self.assertTupleEqual(data1_precached[\"extra\"].shape, expected_shape)\n                self.assertTupleEqual(data2_precached[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data2_precached[\"label\"].shape, expected_shape)\n                self.assertTupleEqual(data2_precached[\"extra\"].shape, expected_shape)\n\n                self.assertTupleEqual(data1_postcached[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data1_postcached[\"label\"].shape, expected_shape)\n                self.assertTupleEqual(data1_postcached[\"extra\"].shape, expected_shape)\n                self.assertTupleEqual(data2_postcached[\"image\"].shape, expected_shape)\n                self.assertTupleEqual(data2_postcached[\"label\"].shape, expected_shape)\n                self.assertTupleEqual(data2_postcached[\"extra\"].shape, expected_shape)\n                for d in data3_postcached:\n                    self.assertTupleEqual(d[\"image\"].shape, expected_shape)\n\n            # update the data to cache\n            test_data_new = [\n                {\n                    \"image\": os.path.join(tempdir, \"test_image1_new.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label1_new.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra1_new.nii.gz\"),\n                },\n                {\n                    \"image\": os.path.join(tempdir, \"test_image2_new.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label2_new.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra2_new.nii.gz\"),\n                },\n            ]\n            dataset_postcached.set_data(data=test_data_new)\n            # test new exchanged cache content\n            if transform is None:\n                self.assertEqual(dataset_postcached[0][\"image\"], os.path.join(tempdir, \"test_image1_new.nii.gz\"))\n                self.assertEqual(dataset_postcached[0][\"label\"], 
os.path.join(tempdir, \"test_label1_new.nii.gz\"))\n                self.assertEqual(dataset_postcached[1][\"extra\"], os.path.join(tempdir, \"test_extra2_new.nii.gz\"))\n\n    def test_different_transforms(self):\n        \"\"\"\n        Different instances of `PersistentDataset` with the same cache_dir,\n        same input data, but different transforms should give different results.\n        \"\"\"\n        shape = (1, 10, 9, 8)\n        im = np.arange(0, np.prod(shape)).reshape(shape)\n        with tempfile.TemporaryDirectory() as path:\n            im1 = PersistentDataset([im], Identity(), cache_dir=path, hash_transform=json_hashing)[0]\n            im2 = PersistentDataset([im], Flip(1), cache_dir=path, hash_transform=json_hashing)[0]\n            l2 = ((im1 - im2) ** 2).sum() ** 0.5\n            self.assertGreater(l2, 1)\n\n    @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])\n    def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_error, expected_type):\n        \"\"\"\n        Ensure expected behavior for all combinations of `track_meta` and `weights_only`.\n        \"\"\"\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            nib.save(test_image, os.path.join(tempdir, \"test_image.nii.gz\"))\n            test_data = [{\"image\": os.path.join(tempdir, \"test_image.nii.gz\")}]\n            transform = Compose([LoadImaged(keys=[\"image\"])])\n            cache_dir = os.path.join(os.path.join(tempdir, \"cache\"), \"data\")\n\n            cm = self.assertRaises(ValueError) if expected_error else contextlib.nullcontext()\n            with cm:\n                test_dataset = PersistentDataset(\n                    data=test_data,\n                    transform=transform,\n                    cache_dir=cache_dir,\n                    track_meta=track_meta,\n                    
weights_only=weights_only,\n                )\n\n                im = test_dataset[0][\"image\"]\n                self.assertIsInstance(im, expected_type)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_persistentdataset_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport torch.distributed as dist\n\nfrom monai.data import PersistentDataset, json_hashing\nfrom monai.transforms import Transform\nfrom tests.test_utils import DistCall, DistTestCase, skip_if_windows\n\n\nclass _InplaceXform(Transform):\n    def __call__(self, data):\n        if data:\n            data[0] = data[0] + np.pi\n        else:\n            data.append(1)\n        return data\n\n\n@skip_if_windows\nclass TestDistDataset(DistTestCase):\n    def setUp(self):\n        self.tempdir = tempfile.mkdtemp()\n\n    def tearDown(self):\n        shutil.rmtree(self.tempdir)\n\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_mp_dataset(self):\n        print(\"persistent\", dist.get_rank())\n        items = [[list(range(i))] for i in range(5)]\n        ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir)\n        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n        ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir)\n        self.assertEqual(list(ds1), list(ds))\n        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n\n        ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir, 
hash_func=json_hashing)\n        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n        ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=self.tempdir, hash_func=json_hashing)\n        self.assertEqual(list(ds1), list(ds))\n        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n\n\n@skip_if_windows\nclass TestDistCreateDataset(DistTestCase):\n    def setUp(self):\n        self.tempdir = tempfile.mkdtemp()\n\n    def tearDown(self):\n        shutil.rmtree(self.tempdir)\n\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_mp_dataset(self):\n        print(\"persistent\", dist.get_rank())\n        items = [[list(range(i))] for i in range(5)]\n        cache_dir = os.path.join(self.tempdir, \"test\")\n        ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=cache_dir)\n        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n        ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=cache_dir)\n        self.assertEqual(list(ds1), list(ds))\n        self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_pil_reader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\nfrom PIL import Image\n\nfrom monai.data import PILReader\n\nTEST_CASE_1 = [(128, 128), [\"test_image.png\"], (128, 128), (128, 128)]\n\nTEST_CASE_2 = [(128, 128, 3), [\"test_image.png\"], (128, 128, 3), (128, 128)]\n\nTEST_CASE_3 = [(128, 128, 4), [\"test_image.png\"], (128, 128, 4), (128, 128), False]\n\nTEST_CASE_4 = [(128, 128), [\"test_image1.png\", \"test_image2.png\", \"test_image3.png\"], (3, 128, 128), (128, 128)]\n\nTEST_CASE_5 = [(128, 128, 3), [\"test_image.jpg\"], (128, 128, 3), (128, 128)]\n\nTEST_CASE_6 = [(128, 128, 3), [\"test_image.bmp\"], (128, 128, 3), (128, 128)]\n\nTEST_CASE_7 = [(128, 128, 3), [\"test_image.png\"], (128, 128, 2), (128, 128)]\n\n\nclass TestPNGReader(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])\n    def test_shape_value(self, data_shape, filenames, expected_shape, meta_shape, reverse=True):\n        test_image = np.random.randint(0, 256, size=data_shape)\n        with tempfile.TemporaryDirectory() as tempdir:\n            for i, name in enumerate(filenames):\n                filenames[i] = os.path.join(tempdir, name)\n                Image.fromarray(test_image.astype(\"uint8\")).save(filenames[i])\n          
  reader = PILReader(mode=\"r\", reverse_indexing=reverse)\n            result = reader.get_data(reader.read(filenames))\n            # load image by PIL and compare the result\n            test_image = np.asarray(Image.open(filenames[0]))\n\n        self.assertTupleEqual(tuple(result[1][\"spatial_shape\"]), meta_shape)\n        self.assertTupleEqual(result[0].shape, expected_shape)\n        if reverse:\n            test_image = np.moveaxis(test_image, 0, 1)\n        if result[0].shape == test_image.shape:\n            np.testing.assert_allclose(result[0], test_image)\n        else:\n            np.testing.assert_allclose(result[0], np.tile(test_image, [result[0].shape[0], 1, 1]))\n\n    @parameterized.expand([TEST_CASE_7])\n    def test_converter(self, data_shape, filenames, expected_shape, meta_shape):\n        test_image = np.random.randint(0, 256, size=data_shape)\n        with tempfile.TemporaryDirectory() as tempdir:\n            for i, name in enumerate(filenames):\n                filenames[i] = os.path.join(tempdir, name)\n                Image.fromarray(test_image.astype(\"uint8\")).save(filenames[i])\n            reader = PILReader(converter=lambda image: image.convert(\"LA\"))\n            result = reader.get_data(reader.read(filenames, mode=\"r\"))\n            self.assertEqual(result[1][\"format\"], \"none\")  # project-monai/monai issue#5251\n            # load image by PIL and compare the result\n            test_image = np.asarray(Image.open(filenames[0]).convert(\"LA\"))\n\n        self.assertTupleEqual(tuple(result[1][\"spatial_shape\"]), meta_shape)\n        self.assertTupleEqual(result[0].shape, expected_shape)\n        test_image = np.moveaxis(test_image, 0, 1)\n        np.testing.assert_allclose(result[0], test_image)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_png_rw.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport numpy as np\nfrom PIL import Image\n\nfrom monai.data.image_writer import PILWriter\n\n\nclass TestPngWrite(unittest.TestCase):\n\n    def test_write_gray(self):\n        with tempfile.TemporaryDirectory() as out_dir:\n            image_name = os.path.join(out_dir, \"test.png\")\n            img = np.random.rand(2, 3)\n            img_save_val = (255 * img).astype(np.uint8)\n            writer_obj = PILWriter(output_dtype=np.uint8)\n            writer_obj.set_data_array(img, channel_dim=None)\n            writer_obj.write(image_name, format=\"PNG\")\n            out = np.asarray(Image.open(image_name))\n            out = np.moveaxis(out, 0, 1)\n            np.testing.assert_allclose(out, img_save_val)\n\n    def test_write_gray_1height(self):\n        with tempfile.TemporaryDirectory() as out_dir:\n            image_name = os.path.join(out_dir, \"test.png\")\n            img = np.random.rand(1, 3)\n            img_save_val = (65535 * img).astype(np.uint16)\n            writer_obj = PILWriter(output_dtype=np.uint16, scale=65535)\n            writer_obj.set_data_array(img, channel_dim=None)\n            writer_obj.write(image_name, format=\"PNG\")\n            out = np.asarray(Image.open(image_name))\n            out = np.moveaxis(out, 0, 1)\n            
np.testing.assert_allclose(out, img_save_val)\n\n    def test_write_gray_1channel(self):\n        with tempfile.TemporaryDirectory() as out_dir:\n            image_name = os.path.join(out_dir, \"test.png\")\n            img = np.random.rand(2, 3, 1)\n            img_save_val = (255 * img).astype(np.uint8).squeeze(2)\n            writer_obj = PILWriter(output_dtype=np.uint8, scale=255)\n            writer_obj.set_data_array(img, channel_dim=None)\n            writer_obj.write(image_name, format=\"PNG\")\n            out = np.asarray(Image.open(image_name))\n            out = np.moveaxis(out, 0, 1)\n            np.testing.assert_allclose(out, img_save_val)\n\n    def test_write_rgb(self):\n        \"\"\"testing default kwargs and obj_kwargs\"\"\"\n        with tempfile.TemporaryDirectory() as out_dir:\n            image_name = os.path.join(out_dir, \"test.png\")\n            img = np.random.rand(2, 3, 3)\n            img_save_val = (255 * img).astype(np.uint8)\n            writer_obj = PILWriter(output_dtype=np.uint8)\n            writer_obj.set_data_array(img, channel_dim=-1)\n            writer_obj.write(image_name)\n            out = np.asarray(Image.open(image_name))\n            out = np.moveaxis(out, 0, 1)\n            np.testing.assert_allclose(out, img_save_val)\n\n    def test_write_2channels(self):\n        with tempfile.TemporaryDirectory() as out_dir:\n            image_name = os.path.join(out_dir, \"test.png\")\n            img = np.random.rand(2, 3, 2)\n            img_save_val = (255 * img).astype(np.uint8)\n            writer_obj = PILWriter(output_dtype=np.uint8)\n            writer_obj.set_data_array(img, channel_dim=-1)\n            writer_obj.write(image_name, format=\"PNG\")\n            out = np.asarray(Image.open(image_name))\n            out = np.moveaxis(out, 0, 1)\n            np.testing.assert_allclose(out, img_save_val)\n\n    def test_write_output_shape(self):\n        with tempfile.TemporaryDirectory() as out_dir:\n            image_name 
= os.path.join(out_dir, \"test.png\")\n            img = np.random.rand(2, 2, 3)\n            writer_obj = PILWriter(output_dtype=np.uint8)\n            writer_obj.set_data_array(img, channel_dim=-1)\n            writer_obj.set_metadata({\"spatial_shape\": (4, 4)}, scale=255)\n            writer_obj.write(image_name, format=\"PNG\")\n            out = np.asarray(Image.open(image_name))\n            np.testing.assert_allclose(out.shape, (4, 4, 3))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_resample_datalist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import resample_datalist\n\nTEST_CASE_1 = [\n    {\"data\": [1, 2, 3, 4, 5], \"factor\": 2.5, \"random_pick\": True, \"seed\": 123},\n    [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 2, 4, 5],\n]\n\nTEST_CASE_2 = [\n    {\"data\": [1, 2, 3, 4, 5], \"factor\": 2.5, \"random_pick\": False, \"seed\": 0},\n    [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3],\n]\n\nTEST_CASE_3 = [{\"data\": [1, 2, 3, 4, 5], \"factor\": 0.6, \"random_pick\": True, \"seed\": 123}, [2, 4, 5]]\n\n\nclass TestResampleDatalist(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_value_shape(self, input_param, expected):\n        result = resample_datalist(**input_param)\n        np.testing.assert_allclose(result, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_sampler_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom torch.multiprocessing import Manager\n\nfrom monai.data import CacheDataset, DataLoader, DistributedSampler\nfrom monai.transforms import ToTensor\nfrom tests.test_utils import DistCall, DistTestCase, assert_allclose\n\n\nclass DistributedSamplerTest(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_even(self):\n        data = [1, 2, 3, 4, 5]\n        sampler = DistributedSampler(dataset=data, shuffle=False)\n        samples = np.array([data[i] for i in list(sampler)])\n        self.assertEqual(dist.get_rank(), sampler.rank)\n        if dist.get_rank() == 0:\n            np.testing.assert_allclose(samples, np.array([1, 3, 5]))\n\n        if dist.get_rank() == 1:\n            np.testing.assert_allclose(samples, np.array([2, 4, 1]))\n\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_uneven(self):\n        data = [1, 2, 3, 4, 5]\n        sampler = DistributedSampler(dataset=data, shuffle=False, even_divisible=False)\n        samples = np.array([data[i] for i in list(sampler)])\n        self.assertEqual(dist.get_rank(), sampler.rank)\n        if dist.get_rank() == 0:\n            np.testing.assert_allclose(samples, np.array([1, 3, 5]))\n\n        if dist.get_rank() == 1:\n            np.testing.assert_allclose(samples, 
np.array([2, 4]))\n\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_uneven_less_data(self):\n        data = [1]\n        with self.assertRaises(ValueError):\n            DistributedSampler(dataset=data, shuffle=False, even_divisible=False)\n\n    @DistCall(nnodes=1, nproc_per_node=2, timeout=120)\n    def test_cachedataset(self):\n        data = [1, 2, 3, 4, 5]\n        obj_list = [Manager().list([None] * len(data))]\n        dist.broadcast_object_list(obj_list, src=0)\n        dataset = CacheDataset(\n            data=data, transform=ToTensor(track_meta=False), cache_rate=1.0, runtime_cache=obj_list[0]\n        )\n        sampler = DistributedSampler(dataset=dataset, shuffle=False, even_divisible=False)\n        dataloader = DataLoader(dataset=dataset, sampler=sampler, batch_size=1, num_workers=2)\n        dist.barrier()\n        for i in range(3):\n            if i > 0:\n                # verify the runtime cache content is completed after first epoch\n                for j, c in enumerate(dataset._cache):\n                    self.assertIsInstance(c, torch.Tensor)\n                    assert_allclose(c, j + 1, type_test=False)\n            for k, d in enumerate(dataloader):\n                self.assertIsInstance(d, torch.Tensor)\n                if dist.get_rank() == 0:\n                    assert_allclose(d[0], k * 2 + 1, type_test=False)\n\n                if dist.get_rank() == 1:\n                    assert_allclose(d[0], (k + 1) * 2, type_test=False)\n            dist.barrier()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_select_cross_validation_folds.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.data import partition_dataset, select_cross_validation_folds\n\nTEST_CASE_1 = [\n    {\n        \"data\": [1, 2, 3, 4, 5],\n        \"num_partitions\": 5,\n        \"shuffle\": False,\n        \"seed\": 0,\n        \"drop_last\": False,\n        \"even_divisible\": False,\n    },\n    [[1, 3, 4, 5], [2]],\n]\n\nTEST_CASE_2 = [\n    {\n        \"data\": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],\n        \"num_partitions\": 10,\n        \"shuffle\": False,\n        \"seed\": 0,\n        \"drop_last\": False,\n        \"even_divisible\": False,\n    },\n    [[1, 3, 4, 5, 6, 7, 8, 9, 10], [2]],\n]\n\n\nclass TestSelectCrossValidationFolds(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_value(self, input_param, result):\n        partitions = partition_dataset(**input_param)\n        train = select_cross_validation_folds(partitions=partitions, folds=[0] + list(range(2, len(partitions))))\n        self.assertListEqual(train, result[0])\n        val = select_cross_validation_folds(partitions=partitions, folds=1)\n        self.assertListEqual(val, result[1])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_shuffle_buffer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport unittest\n\nimport numpy as np\n\nfrom monai.data import DataLoader, ShuffleBuffer\nfrom monai.utils import convert_data_type\n\n\nclass TestShuffleBuffer(unittest.TestCase):\n    def test_shape(self):\n        buffer = ShuffleBuffer([1, 2, 3, 4], seed=0)\n        num_workers = 2 if sys.platform == \"linux\" else 0\n        dataloader = DataLoader(\n            dataset=buffer, batch_size=2, num_workers=num_workers, persistent_workers=num_workers > 0\n        )\n        output = [convert_data_type(x, np.ndarray)[0] for x in dataloader]\n        buffer.seed += 1\n        output2 = [convert_data_type(x, np.ndarray)[0] for x in dataloader]  # test repeating\n        if num_workers == 0:\n            np.testing.assert_allclose(output, [[2, 1], [3, 4]])\n            np.testing.assert_allclose(output2, [[3, 1], [2, 4]])\n        else:  # multiprocess shuffle\n            np.testing.assert_allclose(output, [[2, 3], [1, 4]], err_msg=f\"seed {buffer.seed}\")\n            np.testing.assert_allclose(output2, [[1, 4], [2, 3]], err_msg=f\"seed {buffer.seed}\")\n\n    def test_epochs(self):\n        buffer = ShuffleBuffer([1, 2, 3, 4], seed=0, epochs=2)\n        output = [convert_data_type(x, np.ndarray)[0] for x in DataLoader(dataset=buffer, batch_size=2)]\n        np.testing.assert_allclose(output, [[2, 1], [3, 4], [4, 2], [3, 
1]])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_sliding_patch_wsi_dataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\nfrom unittest import skipUnless\n\nimport numpy as np\nfrom numpy.testing import assert_array_equal\nfrom parameterized import parameterized\n\nfrom monai.data import SlidingPatchWSIDataset\nfrom monai.utils import WSIPatchKeys, optional_import, set_determinism\nfrom tests.test_utils import download_url_or_skip_test, testing_data_config\n\nset_determinism(0)\n\ncucim, has_cucim = optional_import(\"cucim\")\nhas_cucim = has_cucim and hasattr(cucim, \"CuImage\")\nopenslide, has_osl = optional_import(\"openslide\")\nimwrite, has_tiff = optional_import(\"tifffile\", name=\"imwrite\")\n_, has_codec = optional_import(\"imagecodecs\")\nhas_tiff = has_tiff and has_codec\n\nFILE_KEY = \"wsi_generic_tiff\"\nFILE_URL = testing_data_config(\"images\", FILE_KEY, \"url\")\nTESTS_PATH = Path(__file__).parents[1].as_posix()\nFILE_PATH = os.path.join(TESTS_PATH, \"testing_data\", f\"temp_{FILE_KEY}.tiff\")\n\nFILE_PATH_SMALL_0 = os.path.join(TESTS_PATH, \"testing_data\", \"temp_wsi_inference_0.tiff\")\nFILE_PATH_SMALL_1 = os.path.join(TESTS_PATH, \"testing_data\", \"temp_wsi_inference_1.tiff\")\nARRAY_SMALL_0 = np.random.randint(low=0, high=255, size=(3, 4, 4), dtype=np.uint8)\nARRAY_SMALL_1 = np.random.randint(low=0, high=255, size=(3, 5, 5), dtype=np.uint8)\n\nTEST_CASE_SMALL_0 = [\n    
{\"data\": [{\"image\": FILE_PATH_SMALL_0, WSIPatchKeys.LEVEL: 0}], \"patch_size\": (2, 2)},\n    [\n        {\"image\": ARRAY_SMALL_0[:, :2, :2]},\n        {\"image\": ARRAY_SMALL_0[:, :2, 2:]},\n        {\"image\": ARRAY_SMALL_0[:, 2:, :2]},\n        {\"image\": ARRAY_SMALL_0[:, 2:, 2:]},\n    ],\n]\n\nTEST_CASE_SMALL_1 = [\n    {\"data\": [{\"image\": FILE_PATH_SMALL_0, WSIPatchKeys.LEVEL: 0, WSIPatchKeys.SIZE: (2, 2)}]},\n    [\n        {\"image\": ARRAY_SMALL_0[:, :2, :2]},\n        {\"image\": ARRAY_SMALL_0[:, :2, 2:]},\n        {\"image\": ARRAY_SMALL_0[:, 2:, :2]},\n        {\"image\": ARRAY_SMALL_0[:, 2:, 2:]},\n    ],\n]\n\nTEST_CASE_SMALL_2 = [\n    {\"data\": [{\"image\": FILE_PATH_SMALL_0, WSIPatchKeys.LEVEL: 0}], \"patch_size\": (2, 2), \"overlap\": 0.5},\n    [\n        {\"image\": ARRAY_SMALL_0[:, 0:2, 0:2]},\n        {\"image\": ARRAY_SMALL_0[:, 0:2, 1:3]},\n        {\"image\": ARRAY_SMALL_0[:, 0:2, 2:4]},\n        {\"image\": ARRAY_SMALL_0[:, 1:3, 0:2]},\n        {\"image\": ARRAY_SMALL_0[:, 1:3, 1:3]},\n        {\"image\": ARRAY_SMALL_0[:, 1:3, 2:4]},\n        {\"image\": ARRAY_SMALL_0[:, 2:4, 0:2]},\n        {\"image\": ARRAY_SMALL_0[:, 2:4, 1:3]},\n        {\"image\": ARRAY_SMALL_0[:, 2:4, 2:4]},\n    ],\n]\n\nTEST_CASE_SMALL_3 = [\n    {\"data\": [{\"image\": FILE_PATH_SMALL_0, WSIPatchKeys.LEVEL: 0}], \"patch_size\": (3, 3), \"overlap\": 2.0 / 3.0},\n    [\n        {\"image\": ARRAY_SMALL_0[:, :3, :3]},\n        {\"image\": ARRAY_SMALL_0[:, :3, 1:]},\n        {\"image\": ARRAY_SMALL_0[:, 1:, :3]},\n        {\"image\": ARRAY_SMALL_0[:, 1:, 1:]},\n    ],\n]\n\nTEST_CASE_SMALL_4 = [\n    {\n        \"data\": [\n            {\"image\": FILE_PATH_SMALL_0, WSIPatchKeys.LEVEL: 0},\n            {\"image\": FILE_PATH_SMALL_1, WSIPatchKeys.LEVEL: 0},\n        ],\n        \"patch_size\": (2, 2),\n    },\n    [\n        {\"image\": ARRAY_SMALL_0[:, 0:2, 0:2]},\n        {\"image\": ARRAY_SMALL_0[:, 0:2, 2:4]},\n        {\"image\": ARRAY_SMALL_0[:, 2:4, 
0:2]},\n        {\"image\": ARRAY_SMALL_0[:, 2:4, 2:4]},\n        {\"image\": ARRAY_SMALL_1[:, 0:2, 0:2]},\n        {\"image\": ARRAY_SMALL_1[:, 0:2, 2:4]},\n        {\"image\": ARRAY_SMALL_1[:, 2:4, 0:2]},\n        {\"image\": ARRAY_SMALL_1[:, 2:4, 2:4]},\n    ],\n]\n\nTEST_CASE_SMALL_5 = [\n    {\n        \"data\": [\n            {\"image\": FILE_PATH_SMALL_0, WSIPatchKeys.LEVEL: 0, WSIPatchKeys.SIZE: (2, 2)},\n            {\"image\": FILE_PATH_SMALL_1, WSIPatchKeys.LEVEL: 0, WSIPatchKeys.SIZE: (3, 3)},\n        ]\n    },\n    [\n        {\"image\": ARRAY_SMALL_0[:, 0:2, 0:2]},\n        {\"image\": ARRAY_SMALL_0[:, 0:2, 2:4]},\n        {\"image\": ARRAY_SMALL_0[:, 2:4, 0:2]},\n        {\"image\": ARRAY_SMALL_0[:, 2:4, 2:4]},\n        {\"image\": ARRAY_SMALL_1[:, 0:3, 0:3]},\n    ],\n]\n\nTEST_CASE_SMALL_6 = [\n    {\n        \"data\": [\n            {\"image\": FILE_PATH_SMALL_0, WSIPatchKeys.LEVEL: 1, WSIPatchKeys.SIZE: (1, 1)},\n            {\"image\": FILE_PATH_SMALL_1, WSIPatchKeys.LEVEL: 2, WSIPatchKeys.SIZE: (4, 4)},\n        ],\n        \"patch_size\": (2, 2),\n        \"patch_level\": 0,\n    },\n    [\n        {\"image\": ARRAY_SMALL_0[:, 0:2, 0:2]},\n        {\"image\": ARRAY_SMALL_0[:, 0:2, 2:4]},\n        {\"image\": ARRAY_SMALL_0[:, 2:4, 0:2]},\n        {\"image\": ARRAY_SMALL_0[:, 2:4, 2:4]},\n        {\"image\": ARRAY_SMALL_1[:, 0:2, 0:2]},\n        {\"image\": ARRAY_SMALL_1[:, 0:2, 2:4]},\n        {\"image\": ARRAY_SMALL_1[:, 2:4, 0:2]},\n        {\"image\": ARRAY_SMALL_1[:, 2:4, 2:4]},\n    ],\n]\n\nTEST_CASE_SMALL_7 = [\n    {\"data\": [{\"image\": FILE_PATH_SMALL_0, WSIPatchKeys.LEVEL: 0, WSIPatchKeys.SIZE: (2, 2)}], \"offset\": (1, 0)},\n    [{\"image\": ARRAY_SMALL_0[:, 1:3, :2]}, {\"image\": ARRAY_SMALL_0[:, 1:3, 2:]}],\n]\n\nTEST_CASE_SMALL_8 = [\n    {\n        \"data\": [{\"image\": FILE_PATH_SMALL_0, WSIPatchKeys.LEVEL: 0, WSIPatchKeys.SIZE: (2, 2)}],\n        \"offset\": \"random\",\n        \"offset_limits\": (0, 2),\n    },\n    
[{\"image\": ARRAY_SMALL_0[:, 1:3, :2]}, {\"image\": ARRAY_SMALL_0[:, 1:3, 2:]}],\n]\n\nTEST_CASE_SMALL_9 = [\n    {\n        \"data\": [{\"image\": FILE_PATH_SMALL_0, WSIPatchKeys.LEVEL: 0, WSIPatchKeys.SIZE: (2, 2)}],\n        \"offset\": \"random\",\n        \"offset_limits\": ((0, 3), (0, 2)),\n    },\n    [{\"image\": ARRAY_SMALL_0[:, :2, 1:3]}, {\"image\": ARRAY_SMALL_0[:, 2:, 1:3]}],\n]\n\nTEST_CASE_LARGE_0 = [\n    {\"data\": [{\"image\": FILE_PATH, WSIPatchKeys.LEVEL: 8, WSIPatchKeys.SIZE: (64, 50)}]},\n    [\n        {\"step_loc\": (0, 0), \"patch_size\": (64, 50), \"patch_level\": 8, \"ratio\": 257.06195068359375},\n        {\"step_loc\": (0, 1), \"patch_size\": (64, 50), \"patch_level\": 8, \"ratio\": 257.06195068359375},\n        {\"step_loc\": (0, 2), \"patch_size\": (64, 50), \"patch_level\": 8, \"ratio\": 257.06195068359375},\n        {\"step_loc\": (1, 0), \"patch_size\": (64, 50), \"patch_level\": 8, \"ratio\": 257.06195068359375},\n        {\"step_loc\": (1, 1), \"patch_size\": (64, 50), \"patch_level\": 8, \"ratio\": 257.06195068359375},\n        {\"step_loc\": (1, 2), \"patch_size\": (64, 50), \"patch_level\": 8, \"ratio\": 257.06195068359375},\n    ],\n]\n\nTEST_CASE_LARGE_1 = [\n    {\n        \"data\": [\n            {\"image\": FILE_PATH, WSIPatchKeys.LEVEL: 8, WSIPatchKeys.SIZE: (64, 50)},\n            {\"image\": FILE_PATH, WSIPatchKeys.LEVEL: 7, WSIPatchKeys.SIZE: (125, 110)},\n        ]\n    },\n    [\n        {\"step_loc\": (0, 0), \"patch_size\": (64, 50), \"patch_level\": 8, \"ratio\": 257.06195068359375},\n        {\"step_loc\": (0, 1), \"patch_size\": (64, 50), \"patch_level\": 8, \"ratio\": 257.06195068359375},\n        {\"step_loc\": (0, 2), \"patch_size\": (64, 50), \"patch_level\": 8, \"ratio\": 257.06195068359375},\n        {\"step_loc\": (1, 0), \"patch_size\": (64, 50), \"patch_level\": 8, \"ratio\": 257.06195068359375},\n        {\"step_loc\": (1, 1), \"patch_size\": (64, 50), \"patch_level\": 8, \"ratio\": 
257.06195068359375},\n        {\"step_loc\": (1, 2), \"patch_size\": (64, 50), \"patch_level\": 8, \"ratio\": 257.06195068359375},\n        {\"step_loc\": (0, 0), \"patch_size\": (125, 110), \"patch_level\": 7, \"ratio\": 128.10186767578125},\n        {\"step_loc\": (0, 1), \"patch_size\": (125, 110), \"patch_level\": 7, \"ratio\": 128.10186767578125},\n        {\"step_loc\": (0, 2), \"patch_size\": (125, 110), \"patch_level\": 7, \"ratio\": 128.10186767578125},\n        {\"step_loc\": (1, 0), \"patch_size\": (125, 110), \"patch_level\": 7, \"ratio\": 128.10186767578125},\n        {\"step_loc\": (1, 1), \"patch_size\": (125, 110), \"patch_level\": 7, \"ratio\": 128.10186767578125},\n        {\"step_loc\": (1, 2), \"patch_size\": (125, 110), \"patch_level\": 7, \"ratio\": 128.10186767578125},\n    ],\n]\n\n\n@skipUnless(has_cucim or has_tiff, \"Requires cucim, openslide, or tifffile!\")\ndef setUpModule():\n    for info in [(ARRAY_SMALL_0, FILE_PATH_SMALL_0), (ARRAY_SMALL_1, FILE_PATH_SMALL_1)]:\n        array = info[0].transpose([1, 2, 0])\n        imwrite(info[1], array, shape=array.shape, photometric=\"rgb\")\n    hash_type = testing_data_config(\"images\", FILE_KEY, \"hash_type\")\n    hash_val = testing_data_config(\"images\", FILE_KEY, \"hash_val\")\n    download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val)\n\n\nclass SlidingPatchWSIDatasetTests:\n    class Tests(unittest.TestCase):\n        backend = None\n\n        @parameterized.expand(\n            [\n                TEST_CASE_SMALL_0,\n                TEST_CASE_SMALL_1,\n                TEST_CASE_SMALL_2,\n                TEST_CASE_SMALL_3,\n                TEST_CASE_SMALL_4,\n                TEST_CASE_SMALL_5,\n                TEST_CASE_SMALL_6,\n                TEST_CASE_SMALL_7,\n                TEST_CASE_SMALL_8,\n                TEST_CASE_SMALL_9,\n            ]\n        )\n        def test_read_patches(self, input_parameters, expected):\n            if self.backend 
== \"openslide\":\n                return\n            dataset = SlidingPatchWSIDataset(reader=self.backend, **input_parameters)\n            self.assertEqual(len(dataset), len(expected))\n            for i, sample in enumerate(dataset):\n                self.assertTupleEqual(sample[\"image\"].shape, expected[i][\"image\"].shape)\n\n        @parameterized.expand([TEST_CASE_LARGE_0, TEST_CASE_LARGE_1])\n        def test_read_patches_large(self, input_parameters, expected):\n            dataset = SlidingPatchWSIDataset(reader=self.backend, **input_parameters)\n            self.assertEqual(len(dataset), len(expected))\n            for i, sample in enumerate(dataset):\n                self.assertEqual(sample[\"image\"].meta[WSIPatchKeys.LEVEL], expected[i][\"patch_level\"])\n                assert_array_equal(sample[\"image\"].meta[WSIPatchKeys.SIZE], expected[i][\"patch_size\"])\n                steps = [round(expected[i][\"ratio\"] * s) for s in expected[i][\"patch_size\"]]\n                expected_location = tuple(expected[i][\"step_loc\"][j] * steps[j] for j in range(len(steps)))\n                assert_array_equal(sample[\"image\"].meta[WSIPatchKeys.LOCATION], expected_location)\n\n\n@skipUnless(has_cucim, \"Requires cucim\")\nclass TestSlidingPatchWSIDatasetCuCIM(SlidingPatchWSIDatasetTests.Tests):\n    @classmethod\n    def setUpClass(cls):\n        cls.backend = \"cucim\"\n\n\n@skipUnless(has_osl, \"Requires openslide\")\nclass TestSlidingPatchWSIDatasetOpenSlide(SlidingPatchWSIDatasetTests.Tests):\n    @classmethod\n    def setUpClass(cls):\n        cls.backend = \"openslide\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_smartcachedataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport copy\nimport os\nimport sys\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import DataLoader, SmartCacheDataset\nfrom monai.transforms import Compose, Lambda, LoadImaged\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_1 = [0.1, 0, Compose([LoadImaged(keys=[\"image\", \"label\", \"extra\"])])]\n\nTEST_CASE_2 = [0.1, 4, Compose([LoadImaged(keys=[\"image\", \"label\", \"extra\"])])]\n\nTEST_CASE_3 = [0.1, None, Compose([LoadImaged(keys=[\"image\", \"label\", \"extra\"])])]\n\nTEST_CASE_4 = [0.1, 4, None]\n\nTEST_CASE_5 = [0.5, 2, Compose([LoadImaged(keys=[\"image\", \"label\", \"extra\"])])]\n\n\nclass TestSmartCacheDataset(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])\n    def test_shape(self, replace_rate, num_replace_workers, transform):\n        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]).astype(float), np.eye(4))\n        with tempfile.TemporaryDirectory() as tempdir:\n            nib.save(test_image, os.path.join(tempdir, \"test_image1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_label1.nii.gz\"))\n            nib.save(test_image, os.path.join(tempdir, \"test_extra1.nii.gz\"))\n            
test_data = [\n                {\n                    \"image\": os.path.join(tempdir, \"test_image1.nii.gz\"),\n                    \"label\": os.path.join(tempdir, \"test_label1.nii.gz\"),\n                    \"extra\": os.path.join(tempdir, \"test_extra1.nii.gz\"),\n                }\n            ] * 20\n            dataset = SmartCacheDataset(\n                data=test_data,\n                transform=transform,\n                replace_rate=replace_rate,\n                cache_num=16,\n                num_init_workers=4,\n                num_replace_workers=num_replace_workers,\n            )\n            if transform is None:\n                # Check without providing transform\n                dataset2 = SmartCacheDataset(\n                    data=test_data,\n                    replace_rate=replace_rate,\n                    cache_num=16,\n                    num_init_workers=4,\n                    num_replace_workers=num_replace_workers,\n                )\n                for k in [\"image\", \"label\", \"extra\"]:\n                    self.assertEqual(dataset[0][k], dataset2[0][k])\n\n            self.assertEqual(len(dataset._cache), dataset.cache_num)\n            for i in range(dataset.cache_num):\n                self.assertIsNotNone(dataset._cache[i])\n\n            for _ in range(2):\n                dataset.start()\n                for _ in range(3):\n                    dataset.update_cache()\n                    self.assertIsNotNone(dataset[15])\n                    if isinstance(dataset[15][\"image\"], (np.ndarray, torch.Tensor)):\n                        assert_allclose(dataset[15][\"image\"], dataset[15][\"label\"])\n                    else:\n                        self.assertIsInstance(dataset[15][\"image\"], str)\n                dataset.shutdown()\n\n    def test_update_cache(self):\n        # Given\n        test_data = [{\"image\": f\"test_image{i}.nii.gz\", \"label\": f\"test_image{i}.nii.gz\"} for i in range(40)]\n        dataset = 
SmartCacheDataset(\n            data=test_data,\n            transform=None,\n            replace_rate=0.2,\n            cache_num=10,\n            num_init_workers=4,\n            num_replace_workers=4,\n            shuffle=False,\n        )\n        dataset.start()\n        start_num = int(0.2 * 10)\n        remain_num = int((1 - 0.2) * 10)\n\n        old_cache = copy.deepcopy(dataset._cache)\n        # When\n        with dataset._update_lock:\n            replacements = copy.deepcopy(dataset._replacements)\n        dataset.update_cache()\n        new_cache = dataset._cache\n        kept_cache = old_cache[start_num:]\n        # Then\n        for string1, string2 in zip(kept_cache, new_cache[0:remain_num]):\n            assert string1 == string2\n        for string_new, string_replacement in zip(replacements, new_cache[remain_num:]):\n            assert string_new == string_replacement\n\n    def test_shuffle(self):\n        test_data = [{\"image\": f\"test_image{i}.nii.gz\"} for i in range(20)]\n        dataset = SmartCacheDataset(\n            data=test_data,\n            transform=None,\n            replace_rate=0.1,\n            cache_num=16,\n            num_init_workers=4,\n            num_replace_workers=4,\n            shuffle=True,\n            seed=123,\n        )\n\n        dataset.start()\n        for i in range(3):\n            dataset.update_cache()\n\n            if i == 0:\n                self.assertEqual(dataset[15][\"image\"], \"test_image18.nii.gz\")\n            elif i == 1:\n                self.assertEqual(dataset[15][\"image\"], \"test_image13.nii.gz\")\n            else:\n                self.assertEqual(dataset[15][\"image\"], \"test_image5.nii.gz\")\n\n        dataset.shutdown()\n\n    @unittest.skip(\"https://github.com/Project-MONAI/MONAI/issues/5660 blocks the ci\")\n    def test_set_data(self):\n        data_list1 = list(range(10))\n\n        transform = Lambda(func=lambda x: np.array([x * 10]))\n\n        dataset = 
SmartCacheDataset(\n            data=data_list1,\n            transform=transform,\n            cache_rate=0.5,\n            replace_rate=0.4,\n            num_init_workers=4,\n            num_replace_workers=2,\n            shuffle=False,\n            progress=True,\n        )\n\n        num_workers = 2 if sys.platform == \"linux\" else 0\n        dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1)\n\n        dataset.start()\n        for i, d in enumerate(dataloader):\n            np.testing.assert_allclose([[data_list1[i] * 10]], d)\n        # replace cache content, move forward 2(5 * 0.4) items\n        dataset.update_cache()\n        for i, d in enumerate(dataloader):\n            np.testing.assert_allclose([[data_list1[i + 2] * 10]], d)\n        # shutdown to update data\n        dataset.shutdown()\n        # update the datalist and fill the cache content\n        data_list2 = list(range(-10, 0))\n        dataset.set_data(data=data_list2)\n        # restart the dataset\n        dataset.start()\n        # rerun with updated cache content\n        for i, d in enumerate(dataloader):\n            np.testing.assert_allclose([[data_list2[i] * 10]], d)\n        # replace cache content, move forward 2(5 * 0.4) items\n        dataset.update_cache()\n        for i, d in enumerate(dataloader):\n            np.testing.assert_allclose([[data_list2[i + 2] * 10]], d)\n        # finally shutdown the dataset\n        dataset.shutdown()\n\n    def test_datalist(self):\n        data_list = [np.array([i]) for i in range(5)]\n        data_list_backup = copy.copy(data_list)\n\n        SmartCacheDataset(data=data_list, transform=None, cache_rate=0.5, replace_rate=0.4, shuffle=True)\n        np.testing.assert_allclose(data_list, data_list_backup)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_synthetic.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import create_test_image_2d, create_test_image_3d\nfrom monai.utils import set_determinism\n\nTEST_CASES = [\n    [2, {\"width\": 64, \"height\": 64, \"rad_max\": 10, \"rad_min\": 4}, 0.1479004, 0.739502, (64, 64), 5],\n    [\n        2,\n        {\"width\": 28, \"height\": 32, \"num_objs\": 3, \"rad_max\": 5, \"rad_min\": 1, \"noise_max\": 0.2},\n        0.1709315,\n        0.4040179,\n        (32, 28),\n        5,\n    ],\n    [\n        3,\n        {\"width\": 64, \"height\": 64, \"depth\": 45, \"num_seg_classes\": 3, \"channel_dim\": -1, \"rad_max\": 10, \"rad_min\": 4},\n        0.025132,\n        0.0753961,\n        (64, 64, 45, 1),\n        3,\n    ],\n]\n\n\nclass TestDiceCELoss(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_create_test_image(self, dim, input_param, expected_img, expected_seg, expected_shape, expected_max_cls):\n        set_determinism(seed=0)\n        if dim == 2:\n            img, seg = create_test_image_2d(**input_param)\n        else:  # dim == 3\n            img, seg = create_test_image_3d(**input_param)\n        self.assertEqual(img.shape, expected_shape)\n        self.assertEqual(seg.max(), expected_max_cls)\n        np.testing.assert_allclose(img.mean(), expected_img, atol=1e-7, 
rtol=1e-7)\n        np.testing.assert_allclose(seg.mean(), expected_seg, atol=1e-7, rtol=1e-7)\n\n    def test_ill_radius(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            img, seg = create_test_image_2d(32, 32, rad_max=20)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            img, seg = create_test_image_3d(32, 32, 32, rad_max=10, rad_min=11)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            img, seg = create_test_image_2d(32, 32, rad_max=10, rad_min=0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_thread_buffer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport time\nimport unittest\n\nimport torch\n\nfrom monai.data import DataLoader, Dataset, ThreadBuffer, ThreadDataLoader\nfrom monai.transforms import Compose, SimulateDelayd\nfrom monai.utils import PerfContext, set_determinism\nfrom tests.test_utils import assert_allclose\n\n\nclass TestDataLoader(unittest.TestCase):\n    def setUp(self):\n        super().setUp()\n\n        self.datalist = [\n            {\"image\": \"spleen_19.nii.gz\", \"label\": \"spleen_label_19.nii.gz\"},\n            {\"image\": \"spleen_31.nii.gz\", \"label\": \"spleen_label_31.nii.gz\"},\n        ]\n\n        self.transform = Compose([SimulateDelayd(keys=[\"image\", \"label\"], delay_time=0.1)])\n\n    def test_values(self):\n        dataset = Dataset(data=self.datalist, transform=self.transform)\n        dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0)\n\n        tbuffer = ThreadBuffer(dataloader)\n\n        for d in tbuffer:\n            self.assertEqual(d[\"image\"][0], \"spleen_19.nii.gz\")\n            self.assertEqual(d[\"image\"][1], \"spleen_31.nii.gz\")\n            self.assertEqual(d[\"label\"][0], \"spleen_label_19.nii.gz\")\n            self.assertEqual(d[\"label\"][1], \"spleen_label_31.nii.gz\")\n\n    def test_dataloader(self):\n        dataset = Dataset(data=self.datalist, transform=self.transform)\n        
dataloader = ThreadDataLoader(dataset=dataset, batch_size=2, num_workers=0)\n\n        for d in dataloader:\n            self.assertEqual(d[\"image\"][0], \"spleen_19.nii.gz\")\n            self.assertEqual(d[\"image\"][1], \"spleen_31.nii.gz\")\n\n        for d in dataloader:\n            self.assertEqual(d[\"label\"][0], \"spleen_label_19.nii.gz\")\n            self.assertEqual(d[\"label\"][1], \"spleen_label_31.nii.gz\")\n\n    def test_deterministic(self):\n        set_determinism(0)\n        res_1 = list(ThreadDataLoader(torch.arange(5), batch_size=2, buffer_size=2, shuffle=True, num_workers=0))\n\n        set_determinism(0)\n        num_workers = 2 if sys.platform == \"linux\" else 1\n        res_2 = list(\n            ThreadDataLoader(torch.arange(5), batch_size=2, buffer_size=3, shuffle=True, num_workers=num_workers)\n        )\n\n        set_determinism(None)\n        assert_allclose(torch.cat(res_1), torch.cat(res_2), type_test=False)\n\n    def test_time(self):\n        dataset = Dataset(data=self.datalist * 2, transform=self.transform)  # contains data for 2 batches\n        dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0)\n\n        tbuffer = ThreadBuffer(dataloader)\n\n        with PerfContext() as pc:\n            for _ in dataloader:\n                time.sleep(0.5)  # each batch takes 0.8 s to generate on top of this time\n\n        unbuffered_time = pc.total_time\n\n        with PerfContext() as pc:\n            for _ in tbuffer:\n                time.sleep(0.5)  # while \"computation\" is happening the next batch is being generated, saving 0.4 s\n\n        buffered_time = pc.total_time\n        if sys.platform == \"darwin\":  # skip macOS measure\n            print(f\"darwin: Buffered time {buffered_time} vs unbuffered time {unbuffered_time}\")\n        else:\n            self.assertTrue(\n                buffered_time < unbuffered_time,\n                f\"Buffered time {buffered_time} should be less than unbuffered time 
{unbuffered_time}\",\n            )\n\n    def test_dataloader_repeats(self):\n        dataset = Dataset(data=self.datalist, transform=self.transform)\n        dataloader = ThreadDataLoader(dataset=dataset, batch_size=2, num_workers=0, repeats=2)\n\n        previous_batch = None\n\n        for d in dataloader:\n            self.assertEqual(d[\"image\"][0], \"spleen_19.nii.gz\")\n            self.assertEqual(d[\"image\"][1], \"spleen_31.nii.gz\")\n\n            if previous_batch is None:\n                previous_batch = d\n            else:\n                self.assertTrue(previous_batch is d, \"Batch object was not repeated\")\n                previous_batch = None\n\n    def test_thread_workers(self):\n        dataset = Dataset(data=self.datalist, transform=self.transform)\n        dataloader = ThreadDataLoader(dataset=dataset, batch_size=2, num_workers=2, use_thread_workers=True)\n\n        for d in dataloader:\n            self.assertEqual(d[\"image\"][0], \"spleen_19.nii.gz\")\n            self.assertEqual(d[\"image\"][1], \"spleen_31.nii.gz\")\n            self.assertEqual(d[\"label\"][0], \"spleen_label_19.nii.gz\")\n            self.assertEqual(d[\"label\"][1], \"spleen_label_31.nii.gz\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_threadcontainer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport time\nimport unittest\nfrom pathlib import Path\n\nimport torch\n\nfrom monai.data import DataLoader\nfrom monai.utils import optional_import, set_determinism\nfrom monai.utils.enums import CommonKeys\nfrom tests.test_utils import SkipIfNoModule\n\ntry:\n    _, has_ignite = optional_import(\"ignite\")\n\n    from monai.engines import SupervisedTrainer\n    from monai.handlers import MetricLogger\n    from monai.utils import ThreadContainer\nexcept ImportError:\n    has_ignite = False\n\ncompare_images, _ = optional_import(\"matplotlib.testing.compare\", name=\"compare_images\")\n\n\nclass TestThreadContainer(unittest.TestCase):\n    @SkipIfNoModule(\"ignite\")\n    def test_container(self):\n        net = torch.nn.Conv2d(1, 1, 3, padding=1)\n\n        opt = torch.optim.Adam(net.parameters())\n\n        img = torch.rand(1, 16, 16)\n        data = {CommonKeys.IMAGE: img, CommonKeys.LABEL: img}\n        loader = DataLoader([data for _ in range(10)])\n\n        trainer = SupervisedTrainer(\n            device=torch.device(\"cpu\"),\n            max_epochs=1,\n            train_data_loader=loader,\n            network=net,\n            optimizer=opt,\n            loss_function=torch.nn.L1Loss(),\n        )\n\n        con = ThreadContainer(trainer)\n        con.start()\n        time.sleep(1)  # wait for 
trainer to start\n\n        self.assertTrue(con.is_alive)\n        self.assertIsNotNone(con.status())\n        self.assertGreater(len(con.status_dict), 0)\n\n        con.join()\n\n    @SkipIfNoModule(\"ignite\")\n    @SkipIfNoModule(\"matplotlib\")\n    def test_plot(self):\n        set_determinism(0)\n        test_dir = Path(__file__).parents[1]\n        testing_dir = os.path.join(test_dir, \"testing_data\")\n\n        net = torch.nn.Conv2d(1, 1, 3, padding=1)\n\n        opt = torch.optim.Adam(net.parameters())\n\n        img = torch.rand(1, 16, 16)\n\n        # a third non-image key is added to test that this is correctly ignored when plotting\n        data = {CommonKeys.IMAGE: img, CommonKeys.LABEL: img, \"Not Image Data\": [\"This isn't an image\"]}\n\n        loader = DataLoader([data] * 20, batch_size=2)\n\n        trainer = SupervisedTrainer(\n            device=torch.device(\"cpu\"),\n            max_epochs=1,\n            train_data_loader=loader,\n            network=net,\n            optimizer=opt,\n            loss_function=torch.nn.L1Loss(),\n        )\n\n        logger = MetricLogger()\n        logger.attach(trainer)\n\n        con = ThreadContainer(trainer)\n        con.start()\n        con.join()\n\n        fig = con.plot_status(logger)\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            tempimg = f\"{tempdir}/threadcontainer_plot_test.png\"\n            fig.savefig(tempimg)\n            comp = compare_images(f\"{testing_dir}/threadcontainer_plot_test.png\", tempimg, 5e-2)\n\n            self.assertIsNone(comp, comp)  # None indicates test passed\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_video_datasets.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\n\nimport torch\n\nimport monai.transforms as mt\nfrom monai.data.dataloader import DataLoader\nfrom monai.data.video_dataset import CameraDataset, VideoDataset, VideoFileDataset\nfrom monai.utils.module import optional_import\nfrom tests.test_utils import assert_allclose, download_url_or_skip_test, testing_data_config\n\ncv2, has_cv2 = optional_import(\"cv2\")\n\nNUM_CAPTURE_DEVICES = CameraDataset.get_num_devices()\nTRANSFORMS = mt.Compose(\n    [mt.EnsureChannelFirst(True, \"no_channel\"), mt.DivisiblePad(16), mt.ScaleIntensity(), mt.CastToType(torch.float32)]\n)\n\n\nclass Base:\n    class TestVideoDataset(unittest.TestCase):\n        video_source: int | str\n        ds: type[VideoDataset]\n\n        def get_video_source(self):\n            return self.video_source\n\n        def get_ds(self, *args, **kwargs) -> VideoDataset:\n            return self.ds(*args, video_source=self.get_video_source(), transform=TRANSFORMS, **kwargs)  # type: ignore\n\n        @unittest.skipIf(has_cv2, \"Only tested when OpenCV not installed.\")\n        def test_no_opencv_raises(self):\n            with self.assertRaises(RuntimeError):\n                _ = self.get_ds(max_num_frames=10)\n\n        @unittest.skipUnless(has_cv2, \"OpenCV required.\")\n        def test_multiprocessing(self):\n          
  for num_workers in (0, 2):\n                multiprocessing = num_workers > 0\n                ds = self.get_ds(max_num_frames=100, multiprocessing=multiprocessing)\n                dl = DataLoader(ds, num_workers=num_workers, batch_size=2)\n                _ = next(iter(dl))\n\n        @unittest.skipUnless(has_cv2, \"OpenCV required.\")\n        def test_multiple_sources(self, should_match: bool = True):\n            ds1 = self.get_ds()\n            ds2 = self.get_ds()\n            if should_match:\n                assert_allclose(ds1.get_frame(), ds2.get_frame())\n\n        @unittest.skipUnless(has_cv2, \"OpenCV required.\")\n        def test_dataset(self, known_num_frames=None, known_fps=None):\n            num_frames = (10,) if known_num_frames is None else (10, None)\n            for max_num_frames in num_frames:\n                ds = self.get_ds(max_num_frames=max_num_frames)\n                if known_fps is not None:\n                    self.assertEqual(ds.get_fps(), known_fps)\n                frames = list(ds)\n                if max_num_frames is not None:\n                    self.assertEqual(len(frames), max_num_frames)\n                elif known_num_frames is not None:\n                    self.assertEqual(len(frames), len(ds))\n                for f in frames:\n                    self.assertTupleEqual(f.shape, frames[0].shape)\n\n\n@unittest.skipIf(NUM_CAPTURE_DEVICES == 0, \"At least one capture device required.\")\nclass TestCameraDataset(Base.TestVideoDataset):\n    video_source = 0\n    ds = CameraDataset\n\n    @unittest.skipUnless(has_cv2, \"OpenCV required.\")\n    def test_multiple_sources(self):\n        super().test_multiple_sources(should_match=False)\n\n    @unittest.skipUnless(has_cv2, \"OpenCV required.\")\n    def test_device_out_of_range(self):\n        capture_device = NUM_CAPTURE_DEVICES + 1\n        with self.assertRaises(RuntimeError):\n            _ = CameraDataset(capture_device, TRANSFORMS, 0)\n\n\nclass 
TestVideoFileDataset(Base.TestVideoDataset):\n    ds = VideoFileDataset\n\n    @classmethod\n    def setUpClass(cls):\n        super().setUpClass()\n        codecs = VideoFileDataset.get_available_codecs()\n        if \".mp4\" in codecs.values():\n            fname = \"endo.mp4\"\n            config = testing_data_config(\"videos\", \"endovis\")\n            cls.known_fps = 2.0\n            cls.known_num_frames = 23\n        elif \".avi\" in codecs.values():\n            fname = \"ultrasound.avi\"\n            config = testing_data_config(\"videos\", \"ultrasound\")\n            cls.known_fps = 2.0\n            cls.known_num_frames = 523\n        else:\n            cls.known_fps = None\n            cls.known_num_frames = None\n            cls.video_source = None\n            return\n        tests_path = Path(__file__).parents[1].as_posix()\n        cls.video_source = os.path.join(tests_path, \"testing_data\", fname)\n        download_url_or_skip_test(\n            url=config[\"url\"],\n            filepath=cls.video_source,\n            hash_val=config.get(\"hash_val\"),\n            hash_type=config.get(\"hash_type\", \"sha256\"),\n        )\n\n    @unittest.skipUnless(has_cv2, \"OpenCV required.\")\n    def test_dataset(self):\n        super().test_dataset(self.known_num_frames, self.known_fps)\n        self.assertEqual(self.get_ds().get_num_frames(), self.known_num_frames)\n\n    def test_available_codecs(self):\n        codecs = VideoFileDataset.get_available_codecs()\n        if not has_cv2:\n            self.assertEqual(codecs, {})\n        else:\n            self.assertGreaterEqual(len(codecs), 0)\n\n    def get_video_source(self):\n        if self.video_source is None:\n            raise unittest.SkipTest(\"missing required codecs\")\n        return super().get_video_source()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_weighted_random_sampler_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom monai.data import DistributedWeightedRandomSampler\nfrom tests.test_utils import DistCall, DistTestCase, skip_if_darwin, skip_if_windows\n\n\n@skip_if_windows\n@skip_if_darwin\nclass DistributedWeightedRandomSamplerTest(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_sampling(self):\n        data = [1, 2, 3, 4, 5]\n        weights = [1, 2, 3, 4, 5]\n        sampler = DistributedWeightedRandomSampler(\n            weights=weights, dataset=data, shuffle=False, generator=torch.Generator().manual_seed(0)\n        )\n        samples = np.array([data[i] for i in list(sampler)])\n\n        if dist.get_rank() == 0:\n            np.testing.assert_allclose(samples, np.array([5, 5, 5]))\n\n        if dist.get_rank() == 1:\n            np.testing.assert_allclose(samples, np.array([1, 4, 4]))\n\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_num_samples(self):\n        data = [1, 2, 3, 4, 5]\n        weights = [1, 2, 3, 4, 5]\n        sampler = DistributedWeightedRandomSampler(\n            weights=weights,\n            num_samples_per_rank=5,\n            dataset=data,\n            shuffle=False,\n            generator=torch.Generator().manual_seed(123),\n        )\n        samples = np.array([data[i] for i in 
list(sampler)])\n\n        if dist.get_rank() == 0:\n            np.testing.assert_allclose(samples, np.array([3, 1, 5, 1, 5]))\n\n        if dist.get_rank() == 1:\n            np.testing.assert_allclose(samples, np.array([4, 2, 4, 2, 4]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/test_zipdataset.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import ZipDataset\n\n\nclass Dataset_(torch.utils.data.Dataset):\n\n    def __init__(self, length, index_only=True):\n        self.len = length\n        self.index_only = index_only\n\n    def __len__(self):\n        return self.len\n\n    def __getitem__(self, index):\n        if self.index_only:\n            return index\n        return 1, 2, index\n\n\nTEST_CASE_1 = [[Dataset_(5), Dataset_(5), Dataset_(5)], None, (0, 0, 0), 5]\n\nTEST_CASE_2 = [[Dataset_(3), Dataset_(4), Dataset_(5)], None, (0, 0, 0), 3]\n\nTEST_CASE_3 = [[Dataset_(3), Dataset_(4, index_only=False), Dataset_(5)], None, (0, 1, 2, 0, 0), 3]\n\nTEST_CASE_4 = [\n    [Dataset_(3), Dataset_(4, index_only=False), Dataset_(5)],\n    lambda x: [i + 1 for i in x],\n    (1, 2, 3, 1, 1),\n    3,\n]\n\n\nclass TestZipDataset(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_value(self, datasets, transform, expected_output, expected_length):\n        test_dataset = ZipDataset(datasets=datasets, transform=transform)\n        self.assertEqual(test_dataset[0], expected_output)\n        self.assertEqual(len(test_dataset), expected_length)\n\n    def test_slicing(self):\n        test_dataset = ZipDataset(datasets=[Dataset_(5), 
Dataset_(5), Dataset_(5)], transform=None)\n        subset = test_dataset[0:2]\n        self.assertEqual(subset[-1], (1, 1, 1))\n        self.assertEqual(len(subset), 2)\n\n    def test_sequence(self):\n        test_dataset = ZipDataset(datasets=[Dataset_(5), Dataset_(5), Dataset_(5)], transform=None)\n        subset = test_dataset[[1, 3, 4]]\n        self.assertEqual(subset[-1], (4, 4, 4))\n        self.assertEqual(len(subset), 3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/utils/test_decollate.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport unittest\nfrom enum import Enum\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import CacheDataset, DataLoader, Dataset, create_test_image_2d\nfrom monai.data.utils import decollate_batch\nfrom monai.transforms import (\n    Compose,\n    EnsureChannelFirst,\n    EnsureChannelFirstd,\n    LoadImage,\n    LoadImaged,\n    RandAffine,\n    RandFlip,\n    RandFlipd,\n    RandRotate90,\n    SpatialPad,\n    SpatialPadd,\n    ToTensor,\n    ToTensord,\n)\nfrom monai.transforms.inverse_batch_transform import Decollated\nfrom monai.transforms.spatial.dictionary import RandAffined, RandRotate90d\nfrom monai.utils import optional_import, set_determinism\nfrom monai.utils.enums import PostFix, TraceKeys\nfrom tests.test_utils import make_nifti_image\n\n_, has_nib = optional_import(\"nibabel\")\n\nKEYS = [\"image\"]\n\nTESTS_DICT: list[tuple] = []\nTESTS_DICT.append((SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1)))\nTESTS_DICT.append((RandRotate90d(KEYS, prob=0.0, max_k=1),))\nTESTS_DICT.append((RandAffined(KEYS, prob=0.0, translate_range=10),))\n\nTESTS_LIST: list[tuple] = []\nTESTS_LIST.append((SpatialPad(150), RandFlip(prob=1.0, spatial_axis=1)))\nTESTS_LIST.append((RandRotate90(prob=0.0, max_k=1),))\nTESTS_LIST.append((RandAffine(prob=0.0, 
translate_range=10),))\n\nTEST_BASIC = [\n    [(\"channel\", \"channel\"), [\"channel\", \"channel\"]],\n    [torch.Tensor([1, 2, 3]), [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]],\n    [\n        [[torch.Tensor((1.0, 2.0, 3.0)), torch.Tensor((2.0, 3.0, 1.0))]],\n        [\n            [[torch.tensor(1.0), torch.tensor(2.0)]],\n            [[torch.tensor(2.0), torch.tensor(3.0)]],\n            [[torch.tensor(3.0), torch.tensor(1.0)]],\n        ],\n    ],\n    [torch.Tensor((True, True, False, False)), [1.0, 1.0, 0.0, 0.0]],\n    [\n        [torch.Tensor([1]), torch.Tensor([2]), torch.Tensor([3])],\n        [[torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]],\n    ],\n    [[None, None], [None, None]],\n    [[\"test\"], [\"test\"]],\n    [np.array([64, 64]), [64, 64]],\n    [[], []],\n    [[(\"ch1\", \"ch2\"), (\"ch3\",)], [[\"ch1\", \"ch3\"], [\"ch2\", None]]],  # default pad None\n]\n\n\nclass TestDeCollate(unittest.TestCase):\n    def setUp(self) -> None:\n        set_determinism(seed=0)\n\n        im = create_test_image_2d(100, 101)[0]\n        self.data_dict = [{\"image\": make_nifti_image(im) if has_nib else im} for _ in range(6)]\n        self.data_list = [make_nifti_image(im) if has_nib else im for _ in range(6)]\n\n    def tearDown(self) -> None:\n        set_determinism(None)\n\n    def check_match(self, in1, in2):\n        if isinstance(in1, dict):\n            self.assertTrue(isinstance(in2, dict))\n            for (k1, v1), (k2, v2) in zip(in1.items(), in2.items()):\n                if isinstance(k1, Enum) and isinstance(k2, Enum):\n                    k1, k2 = k1.value, k2.value\n                self.check_match(k1, k2)\n                # Transform ids won't match for windows with multiprocessing, so don't check values\n                if k1 == TraceKeys.ID and sys.platform in [\"darwin\", \"win32\"]:\n                    continue\n                if not (isinstance(k1, str) and k1.endswith(\"_transforms\")):\n                    
self.check_match(v1, v2)  # transform stack not necessarily match\n        elif isinstance(in1, (list, tuple)):\n            for l1, l2 in zip(in1, in2):\n                self.check_match(l1, l2)\n        elif isinstance(in1, (str, int)):\n            self.assertEqual(in1, in2)\n        elif isinstance(in1, (torch.Tensor, np.ndarray)):\n            np.testing.assert_array_equal(in1, in2)\n        else:\n            raise RuntimeError(f\"Not sure how to compare types. type(in1): {type(in1)}, type(in2): {type(in2)}\")\n\n    def check_decollate(self, dataset):\n        batch_size = 2\n        num_workers = 2 if sys.platform == \"linux\" else 0\n\n        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n\n        for b, batch_data in enumerate(loader):\n            decollated_1 = decollate_batch(batch_data)\n            decollated_2 = Decollated(detach=True)(batch_data)\n\n            for decollated in [decollated_1, decollated_2]:\n                for i, d in enumerate(decollated):\n                    self.check_match(dataset[b * batch_size + i], d)\n\n    @parameterized.expand(TESTS_DICT)\n    def test_decollation_dict(self, *transforms):\n        t_compose = Compose([EnsureChannelFirstd(KEYS, channel_dim=\"no_channel\"), Compose(transforms), ToTensord(KEYS)])\n        # If nibabel present, read from disk\n        if has_nib:\n            t_compose = Compose([LoadImaged(\"image\", image_only=True), t_compose])\n\n        dataset = CacheDataset(self.data_dict, t_compose, progress=False)\n        self.check_decollate(dataset=dataset)\n\n    @parameterized.expand(TESTS_LIST)\n    def test_decollation_tensor(self, *transforms):\n        t_compose = Compose([EnsureChannelFirst(channel_dim=\"no_channel\"), Compose(transforms), ToTensor()])\n        # If nibabel present, read from disk\n        if has_nib:\n            t_compose = Compose([LoadImage(image_only=True), t_compose])\n\n        dataset = Dataset(self.data_list, 
t_compose)\n        self.check_decollate(dataset=dataset)\n\n    @parameterized.expand(TESTS_LIST)\n    def test_decollation_list(self, *transforms):\n        t_compose = Compose([EnsureChannelFirst(channel_dim=\"no_channel\"), Compose(transforms), ToTensor()])\n        # If nibabel present, read from disk\n        if has_nib:\n            t_compose = Compose([LoadImage(image_only=True), t_compose])\n\n        dataset = Dataset(self.data_list, t_compose)\n        self.check_decollate(dataset=dataset)\n\n\nclass TestBasicDeCollate(unittest.TestCase):\n    @parameterized.expand(TEST_BASIC)\n    def test_decollation_examples(self, input_val, expected_out):\n        out = decollate_batch(input_val)\n        self.assertListEqual(expected_out, out)\n\n    def test_dict_examples(self):\n        test_case = {\"meta\": {\"out\": [\"test\", \"test\"]}, PostFix.meta(\"image\"): {\"scl_slope\": torch.Tensor((0.0, 0.0))}}\n        out = decollate_batch(test_case)\n        self.assertEqual(out[0][\"meta\"][\"out\"], \"test\")\n        self.assertEqual(out[0][PostFix.meta(\"image\")][\"scl_slope\"], 0.0)\n\n        test_case = [torch.ones((2, 1, 10, 10)), torch.ones((2, 3, 5, 5))]\n        out = decollate_batch(test_case)\n        self.assertTupleEqual(out[0][0].shape, (1, 10, 10))\n        self.assertTupleEqual(out[0][1].shape, (3, 5, 5))\n\n        test_case = torch.rand((2, 1, 10, 10))\n        out = decollate_batch(test_case)\n        self.assertTupleEqual(out[0].shape, (1, 10, 10))\n\n        test_case = [torch.tensor(0), torch.tensor(0)]\n        out = decollate_batch(test_case, detach=True)\n        self.assertListEqual([0, 0], out)\n        self.assertFalse(isinstance(out[0], torch.Tensor))\n\n        test_case = {\"a\": [torch.tensor(0), torch.tensor(0)]}\n        out = decollate_batch(test_case, detach=False)\n        self.assertListEqual([{\"a\": torch.tensor(0)}, {\"a\": torch.tensor(0)}], out)\n        self.assertTrue(isinstance(out[0][\"a\"], torch.Tensor))\n\n      
  test_case = [torch.tensor(0), torch.tensor(0)]\n        out = decollate_batch(test_case, detach=False)\n        self.assertListEqual(test_case, out)\n\n        test_case = {\n            \"image\": torch.tensor([[[1, 2]], [[3, 4]]]),\n            \"label\": torch.tensor([[[5, 6]], [[7, 8]]]),\n            \"pred\": torch.tensor([[[9, 10]], [[11, 12]]]),\n            \"out\": [\"test\"],\n        }\n        out = decollate_batch(test_case, detach=False)\n        self.assertEqual(out[0][\"out\"], \"test\")\n\n        test_case = {\n            \"image\": torch.tensor([[[1, 2, 3]], [[3, 4, 5]]]),\n            \"label\": torch.tensor([[[5]], [[7]]]),\n            \"out\": [\"test\"],\n        }\n        out = decollate_batch(test_case, detach=False, pad=False)\n        self.assertEqual(len(out), 1)  # no padding\n        out = decollate_batch(test_case, detach=False, pad=True, fill_value=0)\n        self.assertEqual(out[1][\"out\"], 0)  # verify padding fill_value\n\n    def test_decollated(self):\n        test_case = {\n            \"image\": torch.tensor([[[1, 2]], [[3, 4]]]),\n            \"meta\": {\"out\": [\"test\", \"test\"]},\n            PostFix.meta(\"image\"): {\"scl_slope\": torch.Tensor((0.0, 0.0))},\n            \"loss\": 0.85,\n        }\n        transform = Decollated(keys=[\"meta\", PostFix.meta(\"image\")], detach=False)\n        out = transform(test_case)\n        self.assertFalse(\"loss\" in out)\n        self.assertEqual(out[0][\"meta\"][\"out\"], \"test\")\n        self.assertEqual(out[0][PostFix.meta(\"image\")][\"scl_slope\"], 0.0)\n        self.assertTrue(isinstance(out[0][PostFix.meta(\"image\")][\"scl_slope\"], torch.Tensor))\n        # decollate all data with keys=None\n        transform = Decollated(keys=None, detach=True)\n        out = transform(test_case)\n        self.assertEqual(out[1][\"loss\"], 0.85)\n        self.assertEqual(out[0][\"meta\"][\"out\"], \"test\")\n        
self.assertEqual(out[0][PostFix.meta(\"image\")][\"scl_slope\"], 0.0)\n        self.assertTrue(isinstance(out[0][PostFix.meta(\"image\")][\"scl_slope\"], float))\n\n        # test list input\n        test_case = [\n            torch.tensor([[[1, 2]], [[3, 4]]]),\n            {\"out\": [\"test\", \"test\"]},\n            {\"scl_slope\": torch.Tensor((0.0, 0.0))},\n            {\"out2\": [\"test1\"]},\n            0.85,\n            [],\n        ]\n        transform = Decollated(keys=None, detach=False, fill_value=-1)\n        out = transform(test_case)\n\n        self.assertEqual(out[0][-2], 0.85)  # scalar replicates\n        self.assertEqual(out[1][-2], 0.85)  # scalar replicates\n        self.assertEqual(out[1][-3], -1)  # fill value for the dictionary item\n        self.assertEqual(out[0][1][\"out\"], \"test\")\n        self.assertEqual(out[0][2][\"scl_slope\"], 0.0)\n        self.assertTrue(isinstance(out[0][2][\"scl_slope\"], torch.Tensor))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/utils/test_dev_collate.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.utils import dev_collate\n\nTEST_CASES = [\n    [\n        [\n            {\"img\": 2, \"meta\": {\"shape\": [torch.tensor(1.0)]}},\n            {\"img\": 3, \"meta\": {\"shape\": [np.asarray(1.0)]}},\n            {\"img\": 4, \"meta\": {\"shape\": [torch.tensor(1.0)]}},\n        ],\n        \"got numpy.ndarray\",\n    ],\n    [[[\"img\", np.array([2])], [\"img\", np.array([3, 4])], [\"img\", np.array([4])]], \"size\"],\n    [[[\"img\", [2]], [\"img\", [3, 4]], [\"img\", 4]], \"type\"],\n    [[[\"img\", [2, 2]], [\"img\", [3, 4]], [\"img\", 4]], \"type\"],\n]\n\n\nclass DevCollateTest(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_dev_collate(self, inputs, msg):\n        with self.assertLogs(level=logging.CRITICAL) as log:\n            dev_collate(inputs)\n            self.assertRegex(\" \".join(log.output), f\"{msg}\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/utils/test_file_basename.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nfrom monai.data.utils import create_file_basename\n\n\nclass TestFilename(unittest.TestCase):\n\n    def test_value(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            output_tmp = os.path.join(tempdir, \"output\")\n            result = create_file_basename(\"\", \"test.txt\", output_tmp, \"\")\n            expected = os.path.join(output_tmp, \"test\", \"test\")\n            self.assertEqual(result, expected)\n\n            result = create_file_basename(\"\", os.path.join(\"foo\", \"test.txt\"), output_tmp, \"\")\n            expected = os.path.join(output_tmp, \"test\", \"test\")\n            self.assertEqual(result, expected)\n\n            result = create_file_basename(\"\", os.path.join(\"foo\", \"test.txt\"), output_tmp, \"foo\")\n            expected = os.path.join(output_tmp, \"test\", \"test\")\n            self.assertEqual(result, expected)\n\n            result = create_file_basename(\"\", os.path.join(\"foo\", \"bar\", \"test.txt\"), output_tmp, \"foo\")\n            expected = os.path.join(output_tmp, \"bar\", \"test\", \"test\")\n            self.assertEqual(result, expected)\n\n            result = create_file_basename(\n                postfix=\"\",\n                input_file_name=os.path.join(\"foo\", \"bar\", \"data\", 
\"test.txt\"),\n                folder_path=output_tmp,\n                data_root_dir=os.path.join(\"foo\", \"bar\"),\n            )\n            expected = os.path.join(output_tmp, \"data\", \"test\", \"test\")\n            self.assertEqual(result, expected)\n\n            result = create_file_basename(\"\", os.path.join(\"foo\", \"bar\", \"test.txt\"), output_tmp, \"bar\")\n            expected = os.path.join(tempdir, \"foo\", \"bar\", \"test\", \"test\")\n            self.assertEqual(result, expected)\n\n            result = create_file_basename(\"\", os.path.join(\"rest\", \"test.txt\"), output_tmp, \"rest\")\n            expected = os.path.join(tempdir, \"output\", \"test\", \"test\")\n            self.assertEqual(result, expected)\n\n            result = create_file_basename(\"\", \"test.txt\", output_tmp, \"foo\")\n            expected = os.path.join(output_tmp, \"test\", \"test\")\n            self.assertEqual(result, expected)\n\n            result = create_file_basename(\"\", \"test.txt\", output_tmp, \"foo\", False, 5)\n            expected = os.path.join(output_tmp, \"test_5\")\n            self.assertEqual(result, expected)\n\n            result = create_file_basename(\"post\", \"test.tar.gz\", output_tmp, \"foo\")\n            expected = os.path.join(output_tmp, \"test\", \"test_post\")\n            self.assertEqual(result, expected)\n\n            result = create_file_basename(\"post\", \"test.tar.gz\", output_tmp, \"foo\", True, 8)\n            expected = os.path.join(output_tmp, \"test\", \"test_post_8\")\n            self.assertEqual(result, expected)\n\n            result = create_file_basename(\"post\", Path(\"test.tar.gz\"), Path(output_tmp), Path(\"foo\"), True, 8)\n            expected = os.path.join(output_tmp, \"test\", \"test_post_8\")\n            self.assertEqual(result, expected)\n\n    def test_relative_path(self):\n        output = create_file_basename(\"\", \"test.txt\", \"output\", \"\", makedirs=False)\n        expected = 
os.path.join(\"output\", \"test\", \"test\")\n        self.assertEqual(output, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/utils/test_ori_ras_lps.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.utils import orientation_ras_lps\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTEST_CASES_AFFINE = []\nfor p in TEST_NDARRAYS:\n    case_1d = p([[1.0, 0.0], [1.0, 1.0]]), p([[-1.0, 0.0], [1.0, 1.0]])\n    TEST_CASES_AFFINE.append(case_1d)\n    case_2d_1 = (p([[1.0, 0.0, 1.0], [1.0, 1.0, 1.0]]), p([[-1.0, 0.0, -1.0], [1.0, 1.0, 1.0]]))\n    TEST_CASES_AFFINE.append(case_2d_1)\n    case_2d_2 = (\n        p([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),\n        p([[-1.0, 0.0, -1.0], [0.0, -1.0, -1.0], [1.0, 1.0, 1.0]]),\n    )\n    TEST_CASES_AFFINE.append(case_2d_2)\n    case_3d = (\n        p([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 3.0]]),\n        p([[-1.0, 0.0, -1.0, -1.0], [0.0, -1.0, -1.0, -2.0], [1.0, 1.0, 1.0, 3.0]]),\n    )\n    TEST_CASES_AFFINE.append(case_3d)\n    case_4d = p(np.ones((5, 5))), p([[-1] * 5, [-1] * 5, [1] * 5, [1] * 5, [1] * 5])\n    TEST_CASES_AFFINE.append(case_4d)\n\n\nclass TestITKWriter(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_AFFINE)\n    def test_ras_to_lps(self, param, expected):\n        assert_allclose(orientation_ras_lps(param), expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/data/utils/test_zoom_affine.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.utils import zoom_affine\n\nVALID_CASES = [\n    (\n        np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]),\n        (10, 20, 30),\n        np.array([[8.94427191, -8.94427191, 0], [-4.47213595, -17.88854382, 0], [0.0, 0.0, 1.0]]),\n    ),\n    (\n        np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]),\n        (10, 20, 30),\n        np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 30, 0], [0, 0, 0, 1]]),\n    ),\n    (\n        np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]),\n        (10, 20),\n        np.array([[10, 0, 0, 0], [0, 20, 0, 0], [0, 0, 3, 0], [0, 0, 0, 1]]),\n    ),\n    (\n        np.array([[1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]),\n        (10,),\n        np.array([[10, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 1]]),\n    ),\n    (\n        [[1, 0, 10], [0, 1, 20], [0, 0, 1]]\n        @ ([[0, -1, 0], [1, 0, 0], [0, 0, 1]] @ np.array([[2, 0.3, 0], [0, 3, 0], [0, 0, 1]])),\n        (4, 5, 6),\n        ([[0, -1, 0], [1, 0, 0], [0, 0, 1]] @ np.array([[4, 0, 0], [0, 5, 0], [0, 0, 1]])),\n    ),\n]\n\nDIAGONAL_CASES = [\n    (\n        np.array([[-1, 0, 0, 4], [0, 2, 0, 5], [0, 0, 3, 6], [0, 0, 0, 1]]),\n        (10, 20, 30),\n        np.array([[10, 
0, 0, 0], [0, 20, 0, 0], [0, 0, 30, 0], [0, 0, 0, 1]]),\n    ),\n    (np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]), (10, 20, 30), np.array([[10, 0, 0], [0, 20, 0], [0.0, 0.0, 1.0]])),\n    (  # test default scale from affine\n        np.array([[2, 1, 4], [-1, -3, 5], [0, 0, 1]]),\n        (10,),\n        np.array([[10, 0, 0], [0, 3.162278, 0], [0.0, 0.0, 1.0]]),\n    ),\n]\n\n\nclass TestZoomAffine(unittest.TestCase):\n\n    @parameterized.expand(VALID_CASES)\n    def test_correct(self, affine, scale, expected):\n        output = zoom_affine(affine, scale, diagonal=False)\n        ornt_affine = nib.orientations.ornt2axcodes(nib.orientations.io_orientation(output))\n        ornt_output = nib.orientations.ornt2axcodes(nib.orientations.io_orientation(affine))\n        np.testing.assert_array_equal(ornt_affine, ornt_output)\n        np.testing.assert_allclose(output, expected, rtol=1e-6, atol=1e-6)\n\n    @parameterized.expand(DIAGONAL_CASES)\n    def test_diagonal(self, affine, scale, expected):\n        output = zoom_affine(affine, scale, diagonal=True)\n        np.testing.assert_allclose(output, expected, rtol=1e-6, atol=1e-6)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/engines/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/engines/test_ensemble_evaluator.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom ignite.engine import EventEnum, Events\nfrom parameterized import parameterized\n\nfrom monai.engines import EnsembleEvaluator\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_1 = [[\"pred_0\", \"pred_1\", \"pred_2\", \"pred_3\", \"pred_4\"]]\n\nTEST_CASE_2 = [None]\n\n\nclass TestEnsembleEvaluator(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_content(self, pred_keys):\n        device = torch.device(\"cpu:0\")\n\n        class TestDataset(torch.utils.data.Dataset):\n            def __len__(self):\n                return 8\n\n            def __getitem__(self, index):\n                return {\"image\": torch.tensor([index]), \"label\": torch.zeros(1)}\n\n        val_loader = torch.utils.data.DataLoader(TestDataset())\n\n        class TestNet(torch.nn.Module):\n            def __init__(self, func):\n                super().__init__()\n                self.func = func\n\n            def forward(self, x):\n                return self.func(x)\n\n        net0 = TestNet(lambda x: x + 1)\n        net1 = TestNet(lambda x: x + 2)\n        net2 = TestNet(lambda x: x + 3)\n        net3 = TestNet(lambda x: x + 4)\n        net4 = TestNet(lambda x: x + 5)\n\n        class CustomEvents(EventEnum):\n            FOO_EVENT = \"foo_event\"\n            BAR_EVENT = 
\"bar_event\"\n\n        val_engine = EnsembleEvaluator(\n            device=device,\n            val_data_loader=val_loader,\n            networks=[net0, net1, net2, net3, net4],\n            pred_keys=pred_keys,\n            event_names=[\"bwd_event\", \"opt_event\", CustomEvents],\n            event_to_attr={CustomEvents.FOO_EVENT: \"foo\", \"opt_event\": \"opt\"},\n        )\n\n        @val_engine.on(Events.ITERATION_COMPLETED)\n        def run_transform(engine):\n            for i in range(5):\n                expected_value = engine.state.iteration + i\n                assert_allclose(engine.state.output[0][f\"pred_{i}\"].item(), expected_value)\n\n        @val_engine.on(Events.EPOCH_COMPLETED)\n        def trigger_custom_event():\n            val_engine.fire_event(CustomEvents.FOO_EVENT)\n            val_engine.fire_event(CustomEvents.BAR_EVENT)\n            val_engine.fire_event(\"bwd_event\")\n            val_engine.fire_event(\"opt_event\")\n\n        @val_engine.on(CustomEvents.FOO_EVENT)\n        def do_foo_op():\n            self.assertEqual(val_engine.state.foo, 0)\n\n        @val_engine.on(\"opt_event\")\n        def do_bar_op():\n            self.assertEqual(val_engine.state.opt, 0)\n\n        val_engine.run()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/engines/test_prepare_batch_default.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.engines import PrepareBatchDefault, SupervisedEvaluator\nfrom tests.test_utils import assert_allclose\n\n\nclass TestNet(torch.nn.Module):\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def forward(self, x: torch.Tensor):\n        return x\n\n\nclass TestPrepareBatchDefault(unittest.TestCase):\n    @parameterized.expand(\n        [\n            (\n                [\n                    {\n                        \"image\": torch.tensor([1, 2]),\n                        \"label\": torch.tensor([3, 4]),\n                        \"extra1\": torch.tensor([5, 6]),\n                        \"extra2\": 16,\n                        \"extra3\": \"test\",\n                    }\n                ],\n                TestNet(),\n                True,\n            ),  # dict_content\n            ([torch.tensor([1, 2])], torch.nn.Identity(), True),  # tensor_content\n            ([(torch.tensor([1, 2]), torch.tensor([3, 4]))], torch.nn.Identity(), True),  # pair_content\n            ([], TestNet(), False),  # empty_data\n        ]\n    )\n    def test_prepare_batch(self, dataloader, network, should_run):\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        evaluator = 
SupervisedEvaluator(\n            device=device,\n            val_data_loader=dataloader,\n            epoch_length=len(dataloader) if should_run else 0,\n            network=network,\n            non_blocking=False,\n            prepare_batch=PrepareBatchDefault(),\n            decollate=False,\n            mode=\"eval\" if should_run else \"train\",\n        )\n        evaluator.run()\n\n        if should_run:\n            output = evaluator.state.output\n            if isinstance(dataloader[0], dict) or isinstance(dataloader[0], tuple):\n                assert_allclose(output[\"image\"], torch.tensor([1, 2], device=device))\n                assert_allclose(output[\"label\"], torch.tensor([3, 4], device=device))\n            else:\n                assert_allclose(output[\"image\"], torch.tensor([1, 2], device=device))\n                self.assertTrue(output[\"label\"] is None)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/engines/test_prepare_batch_default_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nimport torch.distributed as dist\nfrom parameterized import parameterized\n\nfrom monai.engines import PrepareBatchDefault, SupervisedEvaluator\nfrom tests.test_utils import DistCall, DistTestCase, assert_allclose\n\nTEST_CASE_1 = [\n    [\n        # data for rank 0, has 1 iteration\n        [{\"image\": torch.tensor([1, 1]), \"label\": torch.tensor([1, 0])}],\n        # data for rank 1, has 2 iterations\n        [\n            {\"image\": torch.tensor([1, 0]), \"label\": torch.tensor([1, 0])},\n            {\"image\": torch.tensor([1]), \"label\": torch.tensor([0])},\n        ],\n    ]\n]\n\nTEST_CASE_2 = [\n    [\n        # data for rank 0\n        [{\"image\": torch.tensor([0, 1, 1, 0, 1]), \"label\": torch.tensor([1, 1, 0, 0, 1])}],\n        # data for rank 1\n        [],\n    ]\n]\n\n\nclass TestNet(torch.nn.Module):\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def forward(self, x: torch.Tensor):\n        return x\n\n\nclass DistributedPrepareBatchDefault(DistTestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)\n    def test_compute(self, dataloaders):\n        device = torch.device(f\"cuda:{dist.get_rank()}\" if torch.cuda.is_available() else \"cpu\")\n        dataloader = 
dataloaders[dist.get_rank()]\n        # set up engine\n        evaluator = SupervisedEvaluator(\n            device=device,\n            val_data_loader=dataloader,\n            epoch_length=len(dataloader),\n            network=TestNet(),\n            non_blocking=False,\n            prepare_batch=PrepareBatchDefault(),\n            decollate=False,\n        )\n        evaluator.run()\n        output = evaluator.state.output\n        if len(dataloader) > 0:\n            assert_allclose(output[\"image\"], dataloader[-1][\"image\"].to(device=device))\n            assert_allclose(output[\"label\"], dataloader[-1][\"label\"].to(device=device))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/engines/test_prepare_batch_diffusion.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.engines import SupervisedEvaluator\nfrom monai.engines.utils import DiffusionPrepareBatch\nfrom monai.inferers import DiffusionInferer\nfrom monai.networks.nets import DiffusionModelUNet\nfrom monai.networks.schedulers import DDPMScheduler\n\nTEST_CASES = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [True],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        (2, 1, 8, 8),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [True],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        (2, 1, 8, 8, 8),\n    ],\n]\n\n\nclass TestPrepareBatchDiffusion(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_output_sizes(self, input_args, image_size):\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        dataloader = [{\"image\": 
torch.randn(image_size).to(device)}]\n        scheduler = DDPMScheduler(num_train_timesteps=20)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        network = DiffusionModelUNet(**input_args).to(device)\n        evaluator = SupervisedEvaluator(\n            device=device,\n            val_data_loader=dataloader,\n            epoch_length=1,\n            network=network,\n            inferer=inferer,\n            non_blocking=True,\n            prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20),\n            decollate=False,\n        )\n        evaluator.run()\n        output = evaluator.state.output\n        # check shapes are the same\n        self.assertEqual(output[\"pred\"].shape, image_size)\n        self.assertEqual(output[\"label\"].shape, output[\"image\"].shape)\n\n    @parameterized.expand(TEST_CASES)\n    def test_conditioning(self, input_args, image_size):\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        dataloader = [{\"image\": torch.randn(image_size).to(device), \"context\": torch.randn((2, 4, 3)).to(device)}]\n        scheduler = DDPMScheduler(num_train_timesteps=20)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        network = DiffusionModelUNet(**input_args, with_conditioning=True, cross_attention_dim=3).to(device)\n        evaluator = SupervisedEvaluator(\n            device=device,\n            val_data_loader=dataloader,\n            epoch_length=1,\n            network=network,\n            inferer=inferer,\n            non_blocking=True,\n            prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20, condition_name=\"context\"),\n            decollate=False,\n        )\n        evaluator.run()\n        output = evaluator.state.output\n        # check shapes are the same\n        self.assertEqual(output[\"pred\"].shape, image_size)\n        self.assertEqual(output[\"label\"].shape, output[\"image\"].shape)\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/engines/test_prepare_batch_extra_input.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.engines import PrepareBatchExtraInput, SupervisedEvaluator\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_0 = [\n    {\"extra_keys\": \"extra1\"},\n    {\"x\": torch.tensor([1, 2]), \"t1\": torch.tensor([5, 6]), \"t2\": None, \"t3\": None},\n]\n\nTEST_CASE_1 = [\n    {\"extra_keys\": [\"extra1\", \"extra3\"]},\n    {\"x\": torch.tensor([1, 2]), \"t1\": torch.tensor([5, 6]), \"t2\": \"test\", \"t3\": None},\n]\n\nTEST_CASE_2 = [\n    {\"extra_keys\": {\"t1\": \"extra2\", \"t2\": \"extra3\", \"t3\": \"extra1\"}},\n    {\"x\": torch.tensor([1, 2]), \"t1\": 16, \"t2\": \"test\", \"t3\": torch.tensor([5, 6])},\n]\n\n\nclass TestNet(torch.nn.Module):\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def forward(self, x: torch.Tensor, t1=None, t2=None, t3=None):\n        return {\"x\": x, \"t1\": t1, \"t2\": t2, \"t3\": t3}\n\n\nclass TestPrepareBatchExtraInput(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])\n    def test_content(self, input_args, expected_value):\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        dataloader = [\n            {\n                \"image\": torch.tensor([1, 2]),\n                
\"label\": torch.tensor([3, 4]),\n                \"extra1\": torch.tensor([5, 6]),\n                \"extra2\": 16,\n                \"extra3\": \"test\",\n            }\n        ]\n        # set up engine\n        evaluator = SupervisedEvaluator(\n            device=device,\n            val_data_loader=dataloader,\n            epoch_length=1,\n            network=TestNet(),\n            non_blocking=True,\n            prepare_batch=PrepareBatchExtraInput(**input_args),\n            decollate=False,\n        )\n        evaluator.run()\n        output = evaluator.state.output\n        assert_allclose(output[\"image\"], torch.tensor([1, 2], device=device))\n        assert_allclose(output[\"label\"], torch.tensor([3, 4], device=device))\n        for k, v in output[\"pred\"].items():\n            if isinstance(v, torch.Tensor):\n                assert_allclose(v, expected_value[k].to(device))\n            else:\n                self.assertEqual(v, expected_value[k])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/fl/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/fl/monai_algo/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/fl/monai_algo/test_fl_monai_algo.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport glob\nimport os\nimport shutil\nimport unittest\nfrom copy import deepcopy\nfrom os.path import join as pathjoin\nfrom pathlib import Path\n\nfrom parameterized import parameterized\n\nfrom monai.bundle import ConfigParser, ConfigWorkflow\nfrom monai.bundle.utils import DEFAULT_HANDLERS_ID\nfrom monai.fl.client.monai_algo import MonaiAlgo\nfrom monai.fl.utils.constants import ExtraItems\nfrom monai.fl.utils.exchange_object import ExchangeObject\nfrom monai.utils import path_to_uri\nfrom tests.test_utils import SkipIfNoModule\n\n_root_dir = Path(__file__).resolve().parents[2]\n_data_dir = os.path.join(_root_dir, \"testing_data\")\n_logging_file = pathjoin(_data_dir, \"logging.conf\")\n\nTEST_TRAIN_1 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"train_workflow\": ConfigWorkflow(\n            config_file=os.path.join(_data_dir, \"config_fl_train.json\"),\n            workflow_type=\"train\",\n            logging_file=_logging_file,\n        ),\n        \"config_evaluate_filename\": None,\n        \"config_filters_filename\": os.path.join(_data_dir, \"config_fl_filters.json\"),\n    }\n]\nTEST_TRAIN_2 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"config_train_filename\": os.path.join(_data_dir, \"config_fl_train.json\"),\n        \"config_evaluate_filename\": None,\n        \"config_filters_filename\": 
None,\n    }\n]\nTEST_TRAIN_3 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"train_workflow\": ConfigWorkflow(\n            config_file=os.path.join(_data_dir, \"config_fl_train.json\"),\n            workflow_type=\"train\",\n            logging_file=_logging_file,\n        ),\n        \"config_evaluate_filename\": None,\n        \"config_filters_filename\": os.path.join(_data_dir, \"config_fl_filters.json\"),\n    }\n]\n\nTEST_TRAIN_4 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"train_workflow\": ConfigWorkflow(\n            config_file=os.path.join(_data_dir, \"config_fl_train.json\"),\n            workflow_type=\"train\",\n            logging_file=_logging_file,\n            tracking={\n                \"handlers_id\": DEFAULT_HANDLERS_ID,\n                \"configs\": {\n                    \"save_execute_config\": f\"{_data_dir}/config_executed.json\",\n                    \"trainer\": {\n                        \"_target_\": \"MLFlowHandler\",\n                        \"tracking_uri\": path_to_uri(_data_dir) + \"/mlflow_override\",\n                        \"output_transform\": \"$monai.handlers.from_engine(['loss'], first=True)\",\n                        \"close_on_complete\": True,\n                    },\n                },\n            },\n        ),\n        \"config_evaluate_filename\": None,\n        \"config_filters_filename\": None,\n    }\n]\n\nTEST_EVALUATE_1 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"config_train_filename\": None,\n        \"eval_workflow\": ConfigWorkflow(\n            config_file=[\n                os.path.join(_data_dir, \"config_fl_train.json\"),\n                os.path.join(_data_dir, \"config_fl_evaluate.json\"),\n            ],\n            workflow_type=\"train\",\n            logging_file=_logging_file,\n            tracking=\"mlflow\",\n            tracking_uri=path_to_uri(_data_dir) + \"/mlflow_1\",\n            experiment_name=\"monai_eval1\",\n        ),\n        
\"config_filters_filename\": os.path.join(_data_dir, \"config_fl_filters.json\"),\n    }\n]\nTEST_EVALUATE_2 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"config_train_filename\": None,\n        \"config_evaluate_filename\": [\n            os.path.join(_data_dir, \"config_fl_train.json\"),\n            os.path.join(_data_dir, \"config_fl_evaluate.json\"),\n        ],\n        \"eval_kwargs\": {\n            \"tracking\": \"mlflow\",\n            \"tracking_uri\": path_to_uri(_data_dir) + \"/mlflow_2\",\n            \"experiment_name\": \"monai_eval2\",\n        },\n        \"eval_workflow_name\": \"training\",\n        \"config_filters_filename\": None,\n    }\n]\nTEST_EVALUATE_3 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"config_train_filename\": None,\n        \"eval_workflow\": ConfigWorkflow(\n            config_file=[\n                os.path.join(_data_dir, \"config_fl_train.json\"),\n                os.path.join(_data_dir, \"config_fl_evaluate.json\"),\n            ],\n            workflow_type=\"train\",\n            logging_file=_logging_file,\n        ),\n        \"config_filters_filename\": os.path.join(_data_dir, \"config_fl_filters.json\"),\n    }\n]\n\nTEST_GET_WEIGHTS_1 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"train_workflow\": ConfigWorkflow(\n            config_file=os.path.join(_data_dir, \"config_fl_train.json\"),\n            workflow_type=\"train\",\n            logging_file=_logging_file,\n        ),\n        \"config_evaluate_filename\": None,\n        \"send_weight_diff\": False,\n        \"config_filters_filename\": os.path.join(_data_dir, \"config_fl_filters.json\"),\n    }\n]\nTEST_GET_WEIGHTS_2 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"config_train_filename\": os.path.join(_data_dir, \"config_fl_train.json\"),\n        \"config_evaluate_filename\": None,\n        \"send_weight_diff\": True,\n        \"config_filters_filename\": os.path.join(_data_dir, 
\"config_fl_filters.json\"),\n    }\n]\nTEST_GET_WEIGHTS_3 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"train_workflow\": ConfigWorkflow(\n            config_file=os.path.join(_data_dir, \"config_fl_train.json\"),\n            workflow_type=\"train\",\n            logging_file=_logging_file,\n        ),\n        \"config_evaluate_filename\": None,\n        \"send_weight_diff\": True,\n        \"config_filters_filename\": os.path.join(_data_dir, \"config_fl_filters.json\"),\n    }\n]\n\n\n@SkipIfNoModule(\"ignite\")\n@SkipIfNoModule(\"mlflow\")\nclass TestFLMonaiAlgo(unittest.TestCase):\n    @parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3, TEST_TRAIN_4])\n    def test_train(self, input_params):\n        # initialize algo\n        algo = MonaiAlgo(**input_params)\n        algo.initialize(extra={ExtraItems.CLIENT_NAME: \"test_fl\"})\n        algo.abort()\n\n        # initialize model\n        parser = ConfigParser(config=deepcopy(algo.train_workflow.parser.get()))\n        parser.parse()\n        network = parser.get_parsed_content(\"network\")\n\n        data = ExchangeObject(weights=network.state_dict())\n\n        # test train\n        algo.train(data=data, extra={})\n        algo.finalize()\n\n        # test experiment management\n        if \"save_execute_config\" in algo.train_workflow.parser:\n            self.assertTrue(os.path.exists(f\"{_data_dir}/mlflow_override\"))\n            shutil.rmtree(f\"{_data_dir}/mlflow_override\")\n            self.assertTrue(os.path.exists(f\"{_data_dir}/config_executed.json\"))\n            os.remove(f\"{_data_dir}/config_executed.json\")\n\n    @parameterized.expand([TEST_EVALUATE_1, TEST_EVALUATE_2, TEST_EVALUATE_3])\n    def test_evaluate(self, input_params):\n        # initialize algo\n        algo = MonaiAlgo(**input_params)\n        algo.initialize(extra={ExtraItems.CLIENT_NAME: \"test_fl\"})\n\n        # initialize model\n        parser = 
ConfigParser(config=deepcopy(algo.eval_workflow.parser.get()))\n        parser.parse()\n        network = parser.get_parsed_content(\"network\")\n\n        data = ExchangeObject(weights=network.state_dict())\n\n        # test evaluate\n        algo.evaluate(data=data, extra={})\n\n        # test experiment management\n        if \"save_execute_config\" in algo.eval_workflow.parser:\n            self.assertGreater(len(list(glob.glob(f\"{_data_dir}/mlflow_*\"))), 0)\n            for f in list(glob.glob(f\"{_data_dir}/mlflow_*\")):\n                shutil.rmtree(f)\n            self.assertGreater(len(list(glob.glob(f\"{_data_dir}/eval/config_*\"))), 0)\n            for f in list(glob.glob(f\"{_data_dir}/eval/config_*\")):\n                os.remove(f)\n\n    @parameterized.expand([TEST_GET_WEIGHTS_1, TEST_GET_WEIGHTS_2, TEST_GET_WEIGHTS_3])\n    def test_get_weights(self, input_params):\n        # initialize algo\n        algo = MonaiAlgo(**input_params)\n        algo.initialize(extra={ExtraItems.CLIENT_NAME: \"test_fl\"})\n\n        # test train\n        if input_params[\"send_weight_diff\"]:  # should not work as test doesn't receive a global model\n            with self.assertRaises(ValueError):\n                weights = algo.get_weights(extra={})\n        else:\n            weights = algo.get_weights(extra={})\n            self.assertIsInstance(weights, ExchangeObject)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/fl/monai_algo/test_fl_monai_algo_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom os.path import join as pathjoin\nfrom pathlib import Path\n\nimport torch.distributed as dist\n\nfrom monai.bundle import ConfigParser, ConfigWorkflow\nfrom monai.fl.client.monai_algo import MonaiAlgo\nfrom monai.fl.utils.constants import ExtraItems\nfrom monai.fl.utils.exchange_object import ExchangeObject\nfrom monai.networks import get_state_dict\nfrom tests.test_utils import DistCall, DistTestCase, SkipIfNoModule, skip_if_no_cuda\n\nTESTS_PATH = TESTS_PATH = Path(__file__).parents[2].as_posix()\n_root_dir = os.path.abspath(pathjoin(TESTS_PATH))\n_data_dir = pathjoin(_root_dir, \"testing_data\")\n_logging_file = pathjoin(_data_dir, \"logging.conf\")\n\n\n@SkipIfNoModule(\"ignite\")\nclass TestFLMonaiAlgo(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2, init_method=\"no_init\")\n    @skip_if_no_cuda\n    def test_train(self):\n        train_configs = [pathjoin(_data_dir, \"config_fl_train.json\"), pathjoin(_data_dir, \"multi_gpu_train.json\")]\n        eval_configs = [\n            pathjoin(_data_dir, \"config_fl_train.json\"),\n            pathjoin(_data_dir, \"config_fl_evaluate.json\"),\n            pathjoin(_data_dir, \"multi_gpu_evaluate.json\"),\n        ]\n        train_workflow = ConfigWorkflow(config_file=train_configs, workflow_type=\"train\", logging_file=_logging_file)\n        # 
simulate the case that this application has specific requirements for a bundle workflow\n        train_workflow.add_property(name=\"loader\", required=True, config_id=\"train#training_transforms#0\", desc=\"NA\")\n\n        # initialize algo\n        algo = MonaiAlgo(\n            bundle_root=_data_dir,\n            train_workflow=ConfigWorkflow(config_file=train_configs, workflow_type=\"train\", logging_file=_logging_file),\n            eval_workflow=ConfigWorkflow(config_file=eval_configs, workflow_type=\"train\", logging_file=_logging_file),\n            config_filters_filename=pathjoin(_root_dir, \"testing_data\", \"config_fl_filters.json\"),\n        )\n        algo.initialize(extra={ExtraItems.CLIENT_NAME: \"test_fl\"})\n        self.assertTrue(dist.get_rank() in (0, 1))\n\n        # initialize model\n        parser = ConfigParser()\n        parser.read_config(train_configs)\n        parser.parse()\n        network = parser.get_parsed_content(\"network\")\n        data = ExchangeObject(weights=get_state_dict(network))\n        # test train\n        for i in range(2):\n            print(f\"Testing round {i + 1} of {2}...\")\n            # test evaluate\n            metric_eo = algo.evaluate(data=data, extra={})\n            self.assertIsInstance(metric_eo, ExchangeObject)\n            metric = metric_eo.metrics\n            self.assertIsInstance(metric[\"accuracy\"], float)\n\n            # test train\n            algo.train(data=data, extra={})\n            weights_eo = algo.get_weights()\n            self.assertIsInstance(weights_eo, ExchangeObject)\n            self.assertTrue(weights_eo.is_valid_weights())\n            self.assertIsInstance(weights_eo.weights, dict)\n            self.assertTrue(len(weights_eo.weights) > 0)\n\n    @DistCall(nnodes=1, nproc_per_node=2, init_method=\"no_init\")\n    @skip_if_no_cuda\n    def test_evaluate(self):\n        config_file = [\n            pathjoin(_data_dir, \"config_fl_train.json\"),\n            
pathjoin(_data_dir, \"config_fl_evaluate.json\"),\n            pathjoin(_data_dir, \"multi_gpu_evaluate.json\"),\n        ]\n        # initialize algo\n        algo = MonaiAlgo(\n            bundle_root=_data_dir,\n            config_train_filename=None,\n            eval_workflow=ConfigWorkflow(config_file=config_file, workflow_type=\"train\", logging_file=_logging_file),\n            config_filters_filename=pathjoin(_data_dir, \"config_fl_filters.json\"),\n        )\n        algo.initialize(extra={ExtraItems.CLIENT_NAME: \"test_fl\"})\n        self.assertTrue(dist.get_rank() in (0, 1))\n\n        # initialize model\n        parser = ConfigParser()\n        parser.read_config(\n            [pathjoin(_data_dir, \"config_fl_train.json\"), pathjoin(_data_dir, \"config_fl_evaluate.json\")]\n        )\n        parser.parse()\n        network = parser.get_parsed_content(\"network\")\n        data = ExchangeObject(weights=get_state_dict(network))\n        # test evaluate\n        metric_eo = algo.evaluate(data=data, extra={})\n        self.assertIsInstance(metric_eo, ExchangeObject)\n        metric = metric_eo.metrics\n        self.assertIsInstance(metric[\"accuracy\"], float)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/fl/test_fl_monai_algo_stats.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\n\nfrom parameterized import parameterized\n\nfrom monai.bundle import ConfigWorkflow\nfrom monai.fl.client import MonaiAlgoStats\nfrom monai.fl.utils.constants import ExtraItems, FlStatistics\nfrom monai.fl.utils.exchange_object import ExchangeObject\nfrom tests.test_utils import SkipIfNoModule\n\n_root_dir = Path(__file__).resolve().parents[1]\n_data_dir = os.path.join(_root_dir, \"testing_data\")\n_logging_file = os.path.join(_data_dir, \"logging.conf\")\n\nTEST_GET_DATA_STATS_1 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"workflow\": ConfigWorkflow(\n            workflow_type=\"train\",\n            config_file=os.path.join(_data_dir, \"config_fl_stats_1.json\"),\n            logging_file=_logging_file,\n            meta_file=None,\n        ),\n        \"config_filters_filename\": os.path.join(_data_dir, \"config_fl_filters.json\"),\n    }\n]\nTEST_GET_DATA_STATS_2 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"config_train_filename\": os.path.join(_data_dir, \"config_fl_stats_2.json\"),\n        \"config_filters_filename\": os.path.join(_data_dir, \"config_fl_filters.json\"),\n    }\n]\nTEST_GET_DATA_STATS_3 = [\n    {\n        \"bundle_root\": _data_dir,\n        \"workflow\": ConfigWorkflow(\n            workflow_type=\"train\",\n            
config_file=[\n                os.path.join(_data_dir, \"config_fl_stats_1.json\"),\n                os.path.join(_data_dir, \"config_fl_stats_2.json\"),\n            ],\n            logging_file=_logging_file,\n            meta_file=None,\n        ),\n        \"config_filters_filename\": os.path.join(_data_dir, \"config_fl_filters.json\"),\n    }\n]\n\n\n@SkipIfNoModule(\"ignite\")\nclass TestFLMonaiAlgo(unittest.TestCase):\n    @parameterized.expand([TEST_GET_DATA_STATS_1, TEST_GET_DATA_STATS_2, TEST_GET_DATA_STATS_3])\n    def test_get_data_stats(self, input_params):\n        # initialize algo\n        algo = MonaiAlgoStats(**input_params)\n        algo.initialize(extra={ExtraItems.CLIENT_NAME: \"test_fl\", ExtraItems.APP_ROOT: _data_dir})\n\n        requested_stats = {FlStatistics.HIST_BINS: 100, FlStatistics.HIST_RANGE: [-500, 500]}\n        # test train\n        stats = algo.get_data_stats(extra=requested_stats)\n        self.assertIsInstance(stats, ExchangeObject)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/fl/utils/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/fl/utils/test_fl_exchange_object.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.fl.utils.constants import WeightType\nfrom monai.fl.utils.exchange_object import ExchangeObject\nfrom monai.utils.module import optional_import\nfrom tests.test_utils import SkipIfNoModule\n\nmodels, has_torchvision = optional_import(\"torchvision.models\")\n\nTEST_INIT_1 = [{\"weights\": None, \"optim\": None, \"metrics\": None, \"weight_type\": None, \"statistics\": None}, \"{}\"]\nTEST_INIT_2: list = []\nif has_torchvision:\n    network = models.resnet18()\n    TEST_INIT_2.append(\n        {\n            \"weights\": network.state_dict(),\n            \"optim\": torch.optim.Adam(lr=1, params=network.parameters()).state_dict(),\n            \"metrics\": {\"accuracy\": 1},\n            \"weight_type\": WeightType.WEIGHT_DIFF,\n            \"statistics\": {\"some_stat\": 1},\n        }\n    )\n    TEST_INIT_2.append(\"{'weights': 122, 'optim': 2, 'metrics': 1, 'weight_type': fl_weight_diff, 'statistics': 1}\")\n\nTEST_FAILURE_METRICS = [{\"weights\": None, \"optim\": None, \"metrics\": 1, \"weight_type\": None, \"statistics\": None}]\nTEST_FAILURE_STATISTICS = [{\"weights\": None, \"optim\": None, \"metrics\": None, \"weight_type\": None, \"statistics\": 1}]\nTEST_FAILURE_WEIGHT_TYPE = [{\"weights\": None, \"optim\": None, \"metrics\": None, 
\"weight_type\": 1, \"statistics\": None}]\n\n\n@SkipIfNoModule(\"torchvision\")\n@SkipIfNoModule(\"ignite\")\nclass TestFLExchangeObject(unittest.TestCase):\n    @parameterized.expand([TEST_INIT_1, TEST_INIT_2])\n    def test_init(self, input_params, expected_str):\n        eo = ExchangeObject(**input_params)\n        self.assertIsInstance(eo, ExchangeObject)\n        eo.summary()\n        self.assertEqual(repr(eo), expected_str)\n\n    @parameterized.expand([TEST_FAILURE_METRICS, TEST_FAILURE_STATISTICS, TEST_FAILURE_WEIGHT_TYPE])\n    def test_failures(self, input_params):\n        with self.assertRaises(ValueError):\n            ExchangeObject(**input_params)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/handlers/test_handler_average_precision.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom monai.handlers import AveragePrecision\nfrom monai.transforms import Activations, AsDiscrete\nfrom tests.test_utils import DistCall, DistTestCase\n\n\nclass TestHandlerAveragePrecision(unittest.TestCase):\n\n    def test_compute(self):\n        ap_metric = AveragePrecision()\n        act = Activations(softmax=True)\n        to_onehot = AsDiscrete(to_onehot=2)\n\n        y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])]\n        y = [torch.Tensor([0]), torch.Tensor([1])]\n        y_pred = [act(p) for p in y_pred]\n        y = [to_onehot(y_) for y_ in y]\n        ap_metric.update([y_pred, y])\n\n        y_pred = [torch.Tensor([0.2, 0.1]), torch.Tensor([0.1, 0.5])]\n        y = [torch.Tensor([0]), torch.Tensor([1])]\n        y_pred = [act(p) for p in y_pred]\n        y = [to_onehot(y_) for y_ in y]\n\n        ap_metric.update([y_pred, y])\n\n        ap = ap_metric.compute()\n        np.testing.assert_allclose(0.8333333, ap)\n\n\nclass DistributedAveragePrecision(DistTestCase):\n\n    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)\n    def test_compute(self):\n        ap_metric = AveragePrecision()\n        act = Activations(softmax=True)\n        to_onehot = AsDiscrete(to_onehot=2)\n\n        device = f\"cuda:{dist.get_rank()}\" 
if torch.cuda.is_available() else \"cpu\"\n        if dist.get_rank() == 0:\n            y_pred = [torch.tensor([0.1, 0.9], device=device), torch.tensor([0.3, 1.4], device=device)]\n            y = [torch.tensor([0], device=device), torch.tensor([1], device=device)]\n\n        if dist.get_rank() == 1:\n            y_pred = [\n                torch.tensor([0.2, 0.1], device=device),\n                torch.tensor([0.1, 0.5], device=device),\n                torch.tensor([0.3, 0.4], device=device),\n            ]\n            y = [torch.tensor([0], device=device), torch.tensor([1], device=device), torch.tensor([1], device=device)]\n\n        y_pred = [act(p) for p in y_pred]\n        y = [to_onehot(y_) for y_ in y]\n        ap_metric.update([y_pred, y])\n\n        result = ap_metric.compute()\n        np.testing.assert_allclose(0.7778, result, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_calibration_error.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.handlers import CalibrationError, from_engine\nfrom monai.utils import IgniteInfo, min_version, optional_import\nfrom tests.test_utils import assert_allclose\n\nEngine, has_ignite = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Engine\")\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\n# Test cases for handler\n# Format: [input_params, expected_value, expected_rows, expected_channels]\nTEST_CASE_1 = [\n    {\n        \"num_bins\": 5,\n        \"include_background\": True,\n        \"calibration_reduction\": \"expected\",\n        \"metric_reduction\": \"mean\",\n        \"output_transform\": from_engine([\"pred\", \"label\"]),\n    },\n    0.2250,\n    4,  # 2 batches * 2 iterations\n    2,  # 2 channels\n]\n\nTEST_CASE_2 = [\n    {\n        \"num_bins\": 5,\n        \"include_background\": False,\n        \"calibration_reduction\": \"expected\",\n        \"metric_reduction\": \"mean\",\n        \"output_transform\": from_engine([\"pred\", \"label\"]),\n    },\n    0.2500,\n    4,  # 2 batches * 2 iterations\n    1,  # 1 channel (background excluded)\n]\n\nTEST_CASE_3 = [\n    {\n        \"num_bins\": 5,\n        \"include_background\": True,\n        \"calibration_reduction\": \"average\",\n    
    \"metric_reduction\": \"mean\",\n        \"output_transform\": from_engine([\"pred\", \"label\"]),\n    },\n    0.2584,  # Mean of [[0.2000, 0.4667], [0.2000, 0.1667]]\n    4,\n    2,\n]\n\nTEST_CASE_4 = [\n    {\n        \"num_bins\": 5,\n        \"include_background\": True,\n        \"calibration_reduction\": \"maximum\",\n        \"metric_reduction\": \"mean\",\n        \"output_transform\": from_engine([\"pred\", \"label\"]),\n    },\n    0.4000,  # Mean of [[0.3000, 0.7000], [0.3000, 0.3000]]\n    4,\n    2,\n]\n\n\n@unittest.skipUnless(has_ignite, \"Requires pytorch-ignite\")\nclass TestHandlerCalibrationError(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_compute(self, input_params, expected_value, expected_rows, expected_channels):\n        calibration_metric = CalibrationError(**input_params)\n\n        # Test data: 2 batches with 2 channels each\n        y_pred = torch.tensor(\n            [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]]\n        ).to(_device)\n        y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]).to(_device)\n\n        # Create data as list of batches (2 iterations)\n        data = [{\"pred\": y_pred, \"label\": y}, {\"pred\": y_pred, \"label\": y}]\n\n        def _val_func(engine, batch):\n            return batch\n\n        engine = Engine(_val_func)\n        calibration_metric.attach(engine=engine, name=\"calibration_error\")\n\n        engine.run(data, max_epochs=1)\n\n        assert_allclose(\n            engine.state.metrics[\"calibration_error\"], expected_value, atol=1e-4, rtol=1e-4, type_test=False\n        )\n\n        # Check details shape using invariants rather than exact tuple\n        details = engine.state.metric_details[\"calibration_error\"]\n        self.assertEqual(details.shape[0], expected_rows)\n        
self.assertEqual(details.shape[-1], expected_channels)\n\n\n@unittest.skipUnless(has_ignite, \"Requires pytorch-ignite\")\nclass TestHandlerCalibrationErrorEdgeCases(unittest.TestCase):\n\n    def test_single_iteration(self):\n        \"\"\"Test handler with single iteration.\"\"\"\n        calibration_metric = CalibrationError(\n            num_bins=5,\n            include_background=True,\n            calibration_reduction=\"expected\",\n            metric_reduction=\"mean\",\n            output_transform=from_engine([\"pred\", \"label\"]),\n        )\n\n        y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device)\n        y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device)\n\n        data = [{\"pred\": y_pred, \"label\": y}]\n\n        def _val_func(engine, batch):\n            return batch\n\n        engine = Engine(_val_func)\n        calibration_metric.attach(engine=engine, name=\"calibration_error\")\n\n        engine.run(data, max_epochs=1)\n\n        assert_allclose(engine.state.metrics[\"calibration_error\"], 0.2, atol=1e-4, rtol=1e-4, type_test=False)\n\n    def test_save_details_false(self):\n        \"\"\"Test handler with save_details=False.\"\"\"\n        calibration_metric = CalibrationError(\n            num_bins=5,\n            include_background=True,\n            calibration_reduction=\"expected\",\n            metric_reduction=\"mean\",\n            output_transform=from_engine([\"pred\", \"label\"]),\n            save_details=False,\n        )\n\n        y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device)\n        y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device)\n\n        data = [{\"pred\": y_pred, \"label\": y}]\n\n        def _val_func(engine, batch):\n            return batch\n\n        engine = Engine(_val_func)\n        calibration_metric.attach(engine=engine, name=\"calibration_error\")\n\n        engine.run(data, max_epochs=1)\n\n        assert_allclose(engine.state.metrics[\"calibration_error\"], 0.2, 
atol=1e-4, rtol=1e-4, type_test=False)\n\n        # When save_details=False, metric_details should not exist or should not have the metric key\n        if hasattr(engine.state, \"metric_details\"):\n            self.assertNotIn(\"calibration_error\", engine.state.metric_details or {})\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_checkpoint_loader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport tempfile\nimport unittest\n\nimport torch\nimport torch.optim as optim\nfrom ignite.engine import Engine, Events\n\nfrom monai.handlers import CheckpointLoader, CheckpointSaver\nfrom tests.test_utils import assert_allclose\n\n\nclass TestHandlerCheckpointLoader(unittest.TestCase):\n    def test_one_save_one_load(self):\n        net1 = torch.nn.PReLU()\n        data1 = net1.state_dict()\n        data1[\"weight\"] = torch.tensor([0.1])\n        net1.load_state_dict(data1)\n        net2 = torch.nn.PReLU()\n        data2 = net2.state_dict()\n        data2[\"weight\"] = torch.tensor([0.2])\n        net2.load_state_dict(data2)\n        with tempfile.TemporaryDirectory() as tempdir:\n            engine1 = Engine(lambda e, b: None)\n            CheckpointSaver(save_dir=tempdir, save_dict={\"net\": net1, \"eng\": engine1}, save_final=True).attach(engine1)\n            engine1.run([0] * 8, max_epochs=5)\n            path = tempdir + \"/checkpoint_final_iteration=40.pt\"\n            engine2 = Engine(lambda e, b: None)\n            CheckpointLoader(load_path=path, load_dict={\"net\": net2, \"eng\": engine2}, strict=True).attach(engine2)\n\n            @engine2.on(Events.STARTED)\n            def check_epoch(engine: Engine):\n                self.assertEqual(engine.state.epoch, 5)\n\n            engine2.run([0] * 8, max_epochs=8)\n         
   assert_allclose(net2.state_dict()[\"weight\"], torch.tensor([0.1]))\n\n            # test bad case with max_epochs smaller than current epoch\n            engine3 = Engine(lambda e, b: None)\n            CheckpointLoader(load_path=path, load_dict={\"net\": net2, \"eng\": engine3}, strict=True).attach(engine3)\n\n            try:\n                engine3.run([0] * 8, max_epochs=3)\n            except ValueError:\n                self.assertEqual(engine3.state.epoch, 5)\n                self.assertEqual(engine3.state.max_epochs, 5)\n\n    def test_two_save_one_load(self):\n        net1 = torch.nn.PReLU()\n        optimizer = optim.SGD(net1.parameters(), lr=0.02)\n        data1 = net1.state_dict()\n        data1[\"weight\"] = torch.tensor([0.1])\n        net1.load_state_dict(data1)\n        net2 = torch.nn.PReLU()\n        data2 = net2.state_dict()\n        data2[\"weight\"] = torch.tensor([0.2])\n        net2.load_state_dict(data2)\n        with tempfile.TemporaryDirectory() as tempdir:\n            engine = Engine(lambda e, b: None)\n            save_dict = {\"net\": net1, \"opt\": optimizer}\n            CheckpointSaver(save_dir=tempdir, save_dict=save_dict, save_final=True).attach(engine)\n            engine.run([0] * 8, max_epochs=5)\n            path = tempdir + \"/checkpoint_final_iteration=40.pt\"\n            engine = Engine(lambda e, b: None)\n            CheckpointLoader(load_path=path, load_dict={\"net\": net2}, strict=True).attach(engine)\n            engine.run([0] * 8, max_epochs=1)\n            assert_allclose(net2.state_dict()[\"weight\"], torch.tensor([0.1]))\n\n    def test_save_single_device_load_multi_devices(self):\n        net1 = torch.nn.PReLU()\n        data1 = net1.state_dict()\n        data1[\"weight\"] = torch.tensor([0.1])\n        net1.load_state_dict(data1)\n        net2 = torch.nn.PReLU()\n        data2 = net2.state_dict()\n        data2[\"weight\"] = torch.tensor([0.2])\n        net2.load_state_dict(data2)\n        net2 = 
torch.nn.DataParallel(net2)\n        with tempfile.TemporaryDirectory() as tempdir:\n            engine = Engine(lambda e, b: None)\n            CheckpointSaver(save_dir=tempdir, save_dict={\"net\": net1}, save_final=True).attach(engine)\n            engine.run([0] * 8, max_epochs=5)\n            path = tempdir + \"/net_final_iteration=40.pt\"\n            engine = Engine(lambda e, b: None)\n            CheckpointLoader(load_path=path, load_dict={\"net\": net2}, strict=True).attach(engine)\n            engine.run([0] * 8, max_epochs=1)\n            assert_allclose(net2.state_dict()[\"module.weight\"].cpu(), torch.tensor([0.1]))\n\n    def test_partial_under_load(self):\n        net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()])\n        data1 = net1.state_dict()\n        data1[\"0.weight\"] = torch.tensor([0.1])\n        data1[\"1.weight\"] = torch.tensor([0.2])\n        net1.load_state_dict(data1)\n\n        net2 = torch.nn.Sequential(*[torch.nn.PReLU()])\n        data2 = net2.state_dict()\n        data2[\"0.weight\"] = torch.tensor([0.3])\n        net2.load_state_dict(data2)\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            engine = Engine(lambda e, b: None)\n            CheckpointSaver(save_dir=tempdir, save_dict={\"net\": net1}, save_final=True).attach(engine)\n            engine.run([0] * 8, max_epochs=5)\n            path = tempdir + \"/net_final_iteration=40.pt\"\n            engine = Engine(lambda e, b: None)\n            CheckpointLoader(load_path=path, load_dict={\"net\": net2}, strict=False).attach(engine)\n            engine.run([0] * 8, max_epochs=1)\n            assert_allclose(net2.state_dict()[\"0.weight\"].cpu(), torch.tensor([0.1]))\n\n    def test_partial_over_load(self):\n        net1 = torch.nn.Sequential(*[torch.nn.PReLU()])\n        data1 = net1.state_dict()\n        data1[\"0.weight\"] = torch.tensor([0.1])\n        net1.load_state_dict(data1)\n\n        net2 = torch.nn.Sequential(*[torch.nn.PReLU(), 
torch.nn.PReLU()])\n        data2 = net2.state_dict()\n        data2[\"0.weight\"] = torch.tensor([0.2])\n        data2[\"1.weight\"] = torch.tensor([0.3])\n        net2.load_state_dict(data2)\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            engine = Engine(lambda e, b: None)\n            CheckpointSaver(save_dir=tempdir, save_dict={\"net\": net1}, save_final=True).attach(engine)\n            engine.run([0] * 8, max_epochs=5)\n            path = tempdir + \"/net_final_iteration=40.pt\"\n            engine = Engine(lambda e, b: None)\n            CheckpointLoader(load_path=path, load_dict={\"net\": net2}, strict=False).attach(engine)\n            engine.run([0] * 8, max_epochs=1)\n            assert_allclose(net2.state_dict()[\"0.weight\"].cpu(), torch.tensor([0.1]))\n\n    def test_strict_shape(self):\n        net1 = torch.nn.Sequential(*[torch.nn.PReLU(num_parameters=5)])\n        data1 = net1.state_dict()\n        data1[\"0.weight\"] = torch.tensor([1, 2, 3, 4, 5])\n        data1[\"new\"] = torch.tensor(0.1)\n        net1.load_state_dict(data1, strict=False)\n        opt1 = optim.SGD(net1.parameters(), lr=0.02)\n\n        net2 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()])\n        data2 = net2.state_dict()\n        data2[\"0.weight\"] = torch.tensor([0.2])\n        data2[\"1.weight\"] = torch.tensor([0.3])\n        net2.load_state_dict(data2)\n        opt2 = optim.SGD(net2.parameters(), lr=0.02)\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            engine = Engine(lambda e, b: None)\n            CheckpointSaver(save_dir=tempdir, save_dict={\"net\": net1, \"opt\": opt1}, save_final=True).attach(engine)\n            engine.run([0] * 8, max_epochs=5)\n            path = tempdir + \"/checkpoint_final_iteration=40.pt\"\n            engine = Engine(lambda e, b: None)\n            CheckpointLoader(\n                load_path=path,\n                # expect to print a warning because it loads not only `net` but also 
`opt` with `strict_shape=False`\n                load_dict={\"net\": net2, \"opt\": opt2},\n                strict=False,\n                strict_shape=False,\n            ).attach(engine)\n            engine.run([0] * 8, max_epochs=1)\n            assert_allclose(net2.state_dict()[\"0.weight\"].cpu(), torch.tensor([0.2]))\n            # test whether `opt2` had been skipped when loading with `strict_shape=False`,\n            # it should have 2 items in `params`(0.weight and 1.weight) while the checkpoint has 1 item(0.weight)\n            self.assertEqual(len(opt1.state_dict()[\"param_groups\"][0][\"params\"]), 1)\n            self.assertEqual(len(opt2.state_dict()[\"param_groups\"][0][\"params\"]), 2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_checkpoint_saver.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport torch\nimport torch.optim as optim\nfrom ignite.engine import Engine\nfrom parameterized import parameterized\n\nfrom monai.handlers import CheckpointLoader, CheckpointSaver\n\nTEST_CASE_1 = [\n    True,\n    None,\n    False,\n    None,\n    1,\n    None,\n    False,\n    False,\n    False,\n    True,\n    0,\n    None,\n    [\"test_checkpoint_final_iteration=40.pt\"],\n]\n\nTEST_CASE_2 = [\n    False,\n    None,\n    True,\n    \"val_loss\",\n    2,\n    None,\n    False,\n    True,\n    False,\n    False,\n    0,\n    None,\n    [\"test_checkpoint_key_metric=32.pt\", \"test_checkpoint_key_metric=40.pt\"],\n]\n\nTEST_CASE_3 = [\n    False,\n    None,\n    False,\n    None,\n    1,\n    None,\n    False,\n    True,\n    False,\n    True,\n    2,\n    2,\n    [\"test_checkpoint_epoch=2.pt\", \"test_checkpoint_epoch=4.pt\"],\n]\n\nTEST_CASE_4 = [\n    False,\n    None,\n    False,\n    None,\n    1,\n    None,\n    False,\n    False,\n    False,\n    False,\n    10,\n    2,\n    [\"test_checkpoint_iteration=30.pt\", \"test_checkpoint_iteration=40.pt\"],\n]\n\nTEST_CASE_5 = [\n    True,\n    None,\n    False,\n    None,\n    1,\n    None,\n    False,\n    False,\n    False,\n    True,\n    0,\n    None,\n    [\"test_checkpoint_final_iteration=40.pt\"],\n    True,\n]\n\nTEST_CASE_6 = 
[True, \"final_model.pt\", False, None, 1, None, False, False, False, True, 0, None, [\"final_model.pt\"]]\n\nTEST_CASE_7 = [False, None, True, \"val_loss\", 1, \"model.pt\", False, False, False, True, 0, None, [\"model.pt\"]]\n\nTEST_CASE_8 = [False, None, True, \"val_loss\", 1, \"model.pt\", False, True, False, True, 0, None, [\"model.pt\"]]\n\n\nclass TestHandlerCheckpointSaver(unittest.TestCase):\n\n    @parameterized.expand(\n        [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]\n    )\n    def test_file(\n        self,\n        save_final,\n        final_filename,\n        save_key_metric,\n        key_metric_name,\n        key_metric_n_saved,\n        key_metric_filename,\n        key_metric_save_state,\n        key_metric_greater_or_equal,\n        key_metric_negative_sign,\n        epoch_level,\n        save_interval,\n        n_saved,\n        filenames,\n        multi_devices=False,\n    ):\n        data = [0] * 8\n\n        # set up engine\n        def _train_func(engine, batch):\n            engine.state.metrics[\"val_loss\"] = engine.state.iteration\n\n        engine = Engine(_train_func)\n\n        # set up testing handler\n        net = torch.nn.PReLU()\n        if multi_devices:\n            net = torch.nn.DataParallel(net)\n        optimizer = optim.SGD(net.parameters(), lr=0.02)\n        with tempfile.TemporaryDirectory() as tempdir:\n            handler = CheckpointSaver(\n                tempdir,\n                {\"net\": net, \"opt\": optimizer},\n                \"CheckpointSaver\",\n                \"test\",\n                save_final,\n                final_filename,\n                save_key_metric,\n                key_metric_name,\n                key_metric_n_saved,\n                key_metric_filename,\n                key_metric_save_state,\n                key_metric_greater_or_equal,\n                key_metric_negative_sign,\n                epoch_level,\n                
save_interval,\n                n_saved,\n            )\n            handler.attach(engine)\n            engine.run(data, max_epochs=2)\n            engine.run(data, max_epochs=5)\n            for filename in filenames:\n                self.assertTrue(os.path.exists(os.path.join(tempdir, filename)))\n\n    def test_exception(self):\n        net = torch.nn.PReLU()\n\n        # set up engine\n        def _train_func(engine, batch):\n            raise RuntimeError(\"test exception.\")\n\n        engine = Engine(_train_func)\n\n        # set up testing handler\n        with tempfile.TemporaryDirectory() as tempdir:\n            stats_handler = CheckpointSaver(tempdir, {\"net\": net}, save_final=True)\n            stats_handler.attach(engine)\n\n            with self.assertRaises(RuntimeError):\n                engine.run(range(3), max_epochs=2)\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"net_final_iteration=1.pt\")))\n\n    def test_load_state_dict(self):\n        net = torch.nn.PReLU()\n\n        # set up engine\n        def _train_func(engine, batch):\n            engine.state.metrics[\"val_loss\"] = engine.state.iteration\n\n        engine = Engine(_train_func)\n\n        # set up testing handler\n        with tempfile.TemporaryDirectory() as tempdir:\n            engine = Engine(_train_func)\n            CheckpointSaver(\n                save_dir=tempdir,\n                save_dict={\"net\": net},\n                save_key_metric=True,\n                key_metric_name=\"val_loss\",\n                key_metric_n_saved=2,\n                key_metric_save_state=True,\n                key_metric_negative_sign=True,\n            ).attach(engine)\n            engine.run(range(3), max_epochs=3)\n\n            saver = CheckpointSaver(\n                save_dir=tempdir,\n                save_dict={\"net\": net},\n                save_key_metric=True,\n                key_metric_name=\"val_loss\",\n                key_metric_n_saved=2,\n              
  key_metric_negative_sign=True,\n            )\n            engine = Engine(_train_func)\n            CheckpointLoader(os.path.join(tempdir, \"net_key_metric=-6.pt\"), {\"checkpointer\": saver}).attach(engine)\n            engine.run(range(1), max_epochs=1)\n\n            resumed = saver._key_metric_checkpoint._saved\n            for i in range(2):\n                self.assertEqual(resumed[1 - i].priority, -3 * (i + 1))\n                self.assertEqual(resumed[1 - i].filename, f\"net_key_metric=-{3 * (i + 1)}.pt\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_classification_saver.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport csv\nimport os\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport torch\nfrom ignite.engine import Engine\n\nfrom monai.data import decollate_batch\nfrom monai.data.csv_saver import CSVSaver\nfrom monai.handlers import ClassificationSaver\n\n\nclass TestHandlerClassificationSaver(unittest.TestCase):\n\n    def test_saved_content(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            # set up engine\n            def _train_func(engine, batch):\n                engine.state.batch = decollate_batch(batch)\n                return [torch.zeros(1) for _ in range(8)]\n\n            engine = Engine(_train_func)\n\n            # set up testing handler\n            saver = CSVSaver(output_dir=tempdir, filename=\"predictions2.csv\", delimiter=\"\\t\")\n            ClassificationSaver(output_dir=tempdir, filename=\"predictions1.csv\", delimiter=\"\\t\").attach(engine)\n            ClassificationSaver(saver=saver).attach(engine)\n\n            data = [{\"filename_or_obj\": [\"testfile\" + str(i) for i in range(8)]}]\n            engine.run(data, max_epochs=1)\n\n            def _test_file(filename):\n                filepath = os.path.join(tempdir, filename)\n                self.assertTrue(os.path.exists(filepath))\n                with open(filepath) as f:\n                    reader = csv.reader(f, 
delimiter=\"\\t\")\n                    i = 0\n                    for row in reader:\n                        self.assertEqual(row[0], \"testfile\" + str(i))\n                        self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0)\n                        i += 1\n                    self.assertEqual(i, 8)\n\n            _test_file(\"predictions1.csv\")\n            _test_file(\"predictions2.csv\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_classification_saver_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport csv\nimport os\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom ignite.engine import Engine\n\nfrom monai.data import decollate_batch\nfrom monai.handlers import ClassificationSaver\nfrom tests.test_utils import DistCall, DistTestCase\n\n\nclass DistributedHandlerClassificationSaver(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_saved_content(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            rank = dist.get_rank()\n\n            # set up engine\n            def _train_func(engine, batch):\n                engine.state.batch = decollate_batch(batch)\n                return [torch.zeros(1) for _ in range(8 + rank * 2)]\n\n            engine = Engine(_train_func)\n\n            # set up testing handler\n            saver = ClassificationSaver(output_dir=tempdir, filename=\"predictions.csv\", save_rank=1)\n            saver.attach(engine)\n\n            # rank 0 has 8 images, rank 1 has 10 images\n            data = [\n                {\n                    \"filename_or_obj\": [\"testfile\" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))],\n                    \"data_shape\": torch.ones((8 + rank * 2, 1, 1)),\n                }\n            ]\n            # rank 1 has more iterations\n            if rank == 
1:\n                data.append(\n                    {\n                        \"filename_or_obj\": [\"testfile\" + str(i) for i in range(18, 28)],\n                        \"data_shape\": torch.ones((10, 1, 1)),\n                    }\n                )\n\n            engine.run(data, max_epochs=1)\n            filepath = os.path.join(tempdir, \"predictions.csv\")\n            if rank == 1:\n                self.assertTrue(os.path.exists(filepath))\n                with open(filepath) as f:\n                    reader = csv.reader(f)\n                    i = 0\n                    for row in reader:\n                        self.assertEqual(row[0], \"testfile\" + str(i))\n                        self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0)\n                        i += 1\n                    self.assertEqual(i, 28)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_clearml_image.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport tempfile\nimport unittest\nfrom os import environ\n\nfrom monai.handlers import ClearMLImageHandler\nfrom monai.utils import optional_import\n\nTask, has_clearml = optional_import(\"clearml\", name=\"Task\")\nget_active_config_file, has_get_active_config_file = optional_import(\n    \"clearml.backend_config.defs\", name=\"get_active_config_file\"\n)\n_, has_tb = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n\n\n@unittest.skipUnless(has_clearml, \"Requires 'clearml' installation\")\n@unittest.skipUnless(has_tb, \"Requires SummaryWriter installation\")\n@unittest.skipIf(not has_get_active_config_file, \"ClearML 'get_active_config_file' not found\")\nclass TestHandlerClearMLImageHandler(unittest.TestCase):\n\n    def test_task_init(self):\n        handle, path = tempfile.mkstemp()\n        with open(handle, \"w\") as new_config:\n            if get_active_config_file():\n                with open(get_active_config_file()) as old_config:\n                    new_config.write(old_config.read())\n            new_config.write(\n                \"\\nsdk.development.vcs_repo_detect_async: false\\nsdk.development.report_use_subprocess: false\\n\"\n            )\n        environ[\"CLEARML_CONFIG_FILE\"] = path\n        try:\n            Task.force_store_standalone_script(True)\n            
Task.set_offline(offline_mode=True)\n            ClearMLImageHandler(\n                project_name=\"MONAI\",\n                task_name=\"monai_experiment\",\n                output_uri=True,\n                tags=None,\n                reuse_last_task_id=True,\n                continue_last_task=False,\n                auto_connect_frameworks=True,\n                auto_connect_arg_parser=False,\n            )\n        except Exception as exc:\n            self.fail(exc)\n        self.assertEqual(Task.current_task().name, \"monai_experiment\")\n        self.assertEqual(Task.current_task()._project_name[1], \"MONAI\")\n        # Close ClearML Task\n        Task.current_task().close()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_clearml_stats.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport tempfile\nimport unittest\nfrom os import environ\n\nfrom monai.handlers import ClearMLStatsHandler\nfrom monai.utils import optional_import\n\nTask, has_clearml = optional_import(\"clearml\", name=\"Task\")\nget_active_config_file, has_get_active_config_file = optional_import(\n    \"clearml.backend_config.defs\", name=\"get_active_config_file\"\n)\n_, has_tb = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n\n\n@unittest.skipUnless(has_clearml, \"Requires 'clearml' installation\")\n@unittest.skipUnless(has_tb, \"Requires SummaryWriter installation\")\n@unittest.skipIf(not has_get_active_config_file, \"ClearML 'get_active_config_file' not found\")\nclass TestHandlerClearMLStatsHandler(unittest.TestCase):\n\n    def test_task_init(self):\n        handle, path = tempfile.mkstemp()\n        with open(handle, \"w\") as new_config:\n            if get_active_config_file():\n                with open(get_active_config_file()) as old_config:\n                    new_config.write(old_config.read())\n            new_config.write(\n                \"\\nsdk.development.vcs_repo_detect_async: false\\nsdk.development.report_use_subprocess: false\\n\"\n            )\n        environ[\"CLEARML_CONFIG_FILE\"] = path\n        try:\n            Task.force_store_standalone_script(True)\n            
Task.set_offline(offline_mode=True)\n            ClearMLStatsHandler(\n                project_name=\"MONAI\",\n                task_name=\"monai_experiment\",\n                output_uri=True,\n                tags=None,\n                reuse_last_task_id=True,\n                continue_last_task=False,\n                auto_connect_frameworks=True,\n                auto_connect_arg_parser=False,\n            )\n        except Exception as exc:\n            self.fail(exc)\n        self.assertEqual(Task.current_task().name, \"monai_experiment\")\n        self.assertEqual(Task.current_task()._project_name[1], \"MONAI\")\n        # Close ClearML Task\n        Task.current_task().close()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_confusion_matrix.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom typing import Any\n\nimport torch\nfrom ignite.engine import Engine\nfrom parameterized import parameterized\n\nfrom monai.handlers import ConfusionMatrix\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_1 = [{\"include_background\": True, \"save_details\": False, \"metric_name\": \"f1\"}, 0.75]\nTEST_CASE_2 = [{\"include_background\": False, \"save_details\": False, \"metric_name\": \"ppv\"}, 1.0]\nTEST_CASE_3 = [{\"save_details\": False, \"metric_name\": \"f1\", \"reduction\": \"mean_batch\"}, torch.tensor([0.6667, 0.8000])]\nTEST_CASE_SEG_1 = [{\"include_background\": True, \"metric_name\": \"tpr\"}, 0.7]\n\ndata_1: dict[Any, Any] = {\n    \"y_pred\": torch.tensor(\n        [\n            [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],\n            [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],\n        ]\n    ),\n    \"y\": torch.tensor(\n        [\n            [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],\n            [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],\n        ]\n    ),\n}\n\ndata_2: dict[Any, Any] = {\n    \"y_pred\": torch.tensor([[[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]]]),\n    \"y\": torch.tensor([[[[1.0, 1.0], 
[1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),\n}\n\n\nclass TestHandlerConfusionMatrix(unittest.TestCase):\n    # TODO test multi node averaged confusion matrix\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_compute(self, input_params, expected_avg):\n        metric = ConfusionMatrix(**input_params)\n        # test input a list of channel-first tensor\n        y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]\n        y = [torch.Tensor([[0], [1]]), torch.Tensor([[0], [1]])]\n        metric.update([y_pred, y])\n\n        y_pred = torch.Tensor([[[0], [1]], [[1], [0]]])\n        y = torch.Tensor([[[0], [1]], [[1], [0]]])\n        metric.update([y_pred, y])\n\n        avg_metric = metric.compute()\n        assert_allclose(avg_metric, expected_avg, atol=1e-4, rtol=1e-4, type_test=False)\n\n    @parameterized.expand([TEST_CASE_SEG_1])\n    def test_compute_seg(self, input_params, expected_avg):\n        metric = ConfusionMatrix(**input_params)\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        metric.attach(engine, \"confusion_matrix\")\n\n        y_pred = data_1[\"y_pred\"]\n        y = data_1[\"y\"]\n        metric.update([y_pred, y])\n\n        y_pred = data_2[\"y_pred\"]\n        y = data_2[\"y\"]\n        metric.update([y_pred, y])\n\n        avg_metric = metric.compute()\n        self.assertAlmostEqual(avg_metric, expected_avg, places=4)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_shape_mismatch(self, input_params, _expected):\n        metric = ConfusionMatrix(**input_params)\n        with self.assertRaises((AssertionError, ValueError)):\n            y_pred = torch.Tensor([[0, 1], [1, 0]])\n            y = torch.ones((2, 3))\n            metric.update([y_pred, y])\n\n        with self.assertRaises((AssertionError, ValueError)):\n            y_pred = torch.Tensor([[0, 1], [1, 0]])\n            y = torch.ones((3, 
2))\n            metric.update([y_pred, y])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_confusion_matrix_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom ignite.engine import Engine\n\nfrom monai.handlers import ConfusionMatrix\nfrom tests.test_utils import DistCall, DistTestCase\n\n\nclass DistributedConfusionMatrix(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_compute(self):\n        self._compute()\n\n    def _compute(self):\n        device = f\"cuda:{dist.get_rank()}\" if torch.cuda.is_available() else \"cpu\"\n        metric = ConfusionMatrix(include_background=True, metric_name=\"tpr\")\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        metric.attach(engine, \"confusion_matrix\")\n        if dist.get_rank() == 0:\n            y_pred = torch.tensor(\n                [\n                    [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],\n                    [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],\n                ],\n                device=device,\n            )\n            y = torch.tensor(\n                [\n                    [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],\n                    [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],\n                ],\n            
    device=device,\n            )\n            metric.update([y_pred, y])\n\n        if dist.get_rank() == 1:\n            y_pred = torch.tensor(\n                [[[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]]], device=device\n            )\n            y = torch.tensor(\n                [[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], device=device\n            )\n            metric.update([y_pred, y])\n\n        avg_metric = metric.compute()\n        np.testing.assert_allclose(avg_metric, 0.7, rtol=1e-04, atol=1e-04)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_decollate_batch.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\n\nfrom monai.engines import SupervisedEvaluator\nfrom monai.handlers import DecollateBatch, PostProcessing\nfrom monai.transforms import Activationsd, AsDiscreted, Compose, CopyItemsd\nfrom tests.test_utils import assert_allclose\n\n\nclass TestHandlerDecollateBatch(unittest.TestCase):\n    def test_compute(self):\n        data = [\n            {\"image\": torch.tensor([[[[2.0], [3.0]]]]), \"filename\": [\"test1\"]},\n            {\"image\": torch.tensor([[[[6.0], [8.0]]]]), \"filename\": [\"test2\"]},\n        ]\n\n        handlers = [\n            DecollateBatch(event=\"MODEL_COMPLETED\"),\n            PostProcessing(\n                transform=Compose(\n                    [\n                        Activationsd(keys=\"pred\", sigmoid=True),\n                        CopyItemsd(keys=\"filename\", times=1, names=\"filename_bak\"),\n                        AsDiscreted(keys=\"pred\", threshold=0.5, to_onehot=2),\n                    ]\n                )\n            ),\n        ]\n        # set up engine, PostProcessing handler works together with postprocessing transforms of engine\n        engine = SupervisedEvaluator(\n            device=torch.device(\"cpu:0\"),\n            val_data_loader=data,\n            epoch_length=2,\n            network=torch.nn.PReLU(),\n            # set decollate=False 
and execute some postprocessing first, then decollate in handlers\n            postprocessing=lambda x: dict(pred=x[\"pred\"] + 1.0),\n            decollate=False,\n            val_handlers=handlers,\n        )\n        engine.run()\n\n        expected = torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]])\n\n        for o, e in zip(engine.state.output, expected):\n            assert_allclose(o[\"pred\"], e)\n            filename = o.get(\"filename_bak\")\n            if filename is not None:\n                self.assertEqual(filename, \"test2\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_early_stop.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom ignite.engine import Engine, Events\n\nfrom monai.handlers import EarlyStopHandler\n\n\nclass TestHandlerEarlyStop(unittest.TestCase):\n\n    def test_early_stop_train_loss(self):\n\n        def _train_func(engine, batch):\n            return {\"loss\": 1.5}\n\n        trainer = Engine(_train_func)\n        EarlyStopHandler(\n            patience=5, score_function=lambda x: x.state.output[\"loss\"], trainer=trainer, epoch_level=False\n        ).attach(trainer)\n\n        trainer.run(range(4), max_epochs=2)\n        self.assertEqual(trainer.state.iteration, 6)\n        self.assertEqual(trainer.state.epoch, 2)\n\n    def test_early_stop_val_metric(self):\n\n        def _train_func(engine, batch):\n            pass\n\n        trainer = Engine(_train_func)\n        validator = Engine(_train_func)\n        validator.state.metrics[\"val_acc\"] = 0.90\n\n        @trainer.on(Events.EPOCH_COMPLETED)\n        def run_validation(engine):\n            validator.state.metrics[\"val_acc\"] += 0.01\n            validator.run(range(3))\n\n        handler = EarlyStopHandler(\n            patience=3,\n            score_function=lambda x: x.state.metrics[\"val_acc\"],\n            trainer=None,\n            min_delta=0.1,\n            cumulative_delta=True,\n            epoch_level=True,\n        )\n        
handler.attach(validator)\n        handler.set_trainer(trainer=trainer)\n\n        trainer.run(range(3), max_epochs=5)\n        self.assertEqual(trainer.state.iteration, 12)\n        self.assertEqual(trainer.state.epoch, 4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_garbage_collector.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport gc\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom ignite.engine import Engine\nfrom parameterized import parameterized\n\nfrom monai.data import Dataset\nfrom monai.handlers import GarbageCollector\nfrom monai.utils import IgniteInfo, min_version, optional_import\n\nEvents, has_ignite = optional_import(\"ignite.engine\", IgniteInfo.OPT_IMPORT_VERSION, min_version, \"Events\")\n\nTEST_CASE_0 = [[0, 1, 2], \"epoch\"]\n\nTEST_CASE_1 = [[0, 1, 2], \"iteration\"]\n\nTEST_CASE_2 = [[0, 1, 2], Events.EPOCH_COMPLETED]\n\n\nclass TestHandlerGarbageCollector(unittest.TestCase):\n\n    @skipUnless(has_ignite, \"Requires ignite\")\n    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])\n    def test_content(self, data, trigger_event):\n        # set up engine\n        gb_count_dict = {}\n\n        def _train_func(engine, batch):\n            # store garbage collection counts\n            if trigger_event == Events.EPOCH_COMPLETED or trigger_event.lower() == \"epoch\":\n                if engine.state.iteration % engine.state.epoch_length == 1:\n                    gb_count_dict[engine.state.epoch] = gc.get_count()\n            elif trigger_event.lower() == \"iteration\":\n                gb_count_dict[engine.state.iteration] = gc.get_count()\n\n        engine = Engine(_train_func)\n\n        # set up 
testing handler\n        dataset = Dataset(data, transform=None)\n        data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)\n        GarbageCollector(trigger_event=trigger_event, log_level=30).attach(engine)\n\n        engine.run(data_loader, max_epochs=5)\n\n        first_count = 0\n        for iter, gb_count in gb_count_dict.items():\n            # At least one zero-generation object is collected\n            # self.assertGreaterEqual(gb_count[0], 0)\n            if iter > 1:\n                # Since we are collecting all objects from all generations manually at each call,\n                # starting from the second call, there shouldn't be any 1st and 2nd\n                # generation objects available to collect.\n                self.assertEqual(gb_count[1], first_count)\n                self.assertEqual(gb_count[2], first_count)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_hausdorff_distance.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom ignite.engine import Engine\n\nfrom monai.handlers import HausdorffDistance\nfrom tests.test_utils import assert_allclose\n\n\ndef create_spherical_seg_3d(\n    radius: float = 20.0, centre: tuple[int, int, int] = (49, 49, 49), im_shape: tuple[int, int, int] = (99, 99, 99)\n) -> np.ndarray:\n    \"\"\"\n    Return a 3D image with a sphere inside. 
Voxel values will be\n    1 inside the sphere, and 0 elsewhere.\n\n    Args:\n        radius: radius of sphere (in terms of number of voxels, can be partial)\n        centre: location of sphere centre.\n        im_shape: shape of image to create\n\n    See also:\n        :py:meth:`~create_test_image_3d`\n    \"\"\"\n    # Create image\n    image = np.zeros(im_shape, dtype=np.int32)\n    spy, spx, spz = np.ogrid[\n        -centre[0] : im_shape[0] - centre[0], -centre[1] : im_shape[1] - centre[1], -centre[2] : im_shape[2] - centre[2]\n    ]\n    circle = (spx * spx + spy * spy + spz * spz) <= radius * radius\n\n    image[circle] = 1\n    image[~circle] = 0\n    return image\n\n\nsampler_sphere = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(20, 20, 20))).unsqueeze(0).unsqueeze(0)\n# test input a list of channel-first tensor\nsampler_sphere_gt = [torch.Tensor(create_spherical_seg_3d(radius=20, centre=(10, 20, 20))).unsqueeze(0)]\nsampler_sphere_zeros = torch.zeros_like(sampler_sphere)\n\nTEST_SAMPLE_1 = [sampler_sphere, sampler_sphere_gt]\nTEST_SAMPLE_2 = [sampler_sphere_gt, sampler_sphere_gt]\nTEST_SAMPLE_3 = [sampler_sphere_zeros, sampler_sphere_gt]\nTEST_SAMPLE_4 = [sampler_sphere_zeros, sampler_sphere_zeros]\n\n\nclass TestHandlerHausdorffDistance(unittest.TestCase):\n    # TODO test multi node Hausdorff Distance\n\n    def test_compute(self):\n        hd_metric = HausdorffDistance(include_background=True)\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        hd_metric.attach(engine, \"hausdorff_distance\")\n\n        y_pred, y = TEST_SAMPLE_1\n        hd_metric.update([y_pred, y])\n        self.assertEqual(hd_metric.compute(), 10)\n        y_pred, y = TEST_SAMPLE_2\n        hd_metric.update([y_pred, y])\n        self.assertEqual(hd_metric.compute(), 5)\n        y_pred, y = TEST_SAMPLE_3\n        hd_metric.update([y_pred, y])\n        self.assertEqual(hd_metric.compute(), float(\"inf\"))\n        
y_pred, y = TEST_SAMPLE_4\n        hd_metric.update([y_pred, y])\n        self.assertEqual(hd_metric.compute(), float(\"inf\"))\n\n    def test_shape_mismatch(self):\n        hd_metric = HausdorffDistance(include_background=True)\n        with self.assertRaises((AssertionError, ValueError)):\n            y_pred = TEST_SAMPLE_1[0]\n            y = torch.ones((1, 1, 10, 10, 10))\n            hd_metric.update([y_pred, y])\n\n    def test_reduction(self):\n        hd_metric = HausdorffDistance(include_background=True, reduction=\"mean_channel\")\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        hd_metric.attach(engine, \"hausdorff_distance\")\n\n        y_pred, y = TEST_SAMPLE_1\n        hd_metric.update([y_pred, y])\n        y_pred, y = TEST_SAMPLE_2\n        hd_metric.update([y_pred, y])\n        assert_allclose(hd_metric.compute().float(), torch.tensor([10.0, 0.0]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_ignite_metric.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.handlers import IgniteMetricHandler, from_engine\nfrom monai.losses import DiceLoss\nfrom monai.metrics import LossMetric\nfrom tests.test_utils import SkipIfNoModule, assert_allclose, optional_import\n\ntry:\n    _, has_ignite = optional_import(\"ignite\")\n    from ignite.engine import Engine, Events\nexcept ImportError:\n    has_ignite = False\n\nTEST_CASE_1 = [\n    {\"reduction\": \"none\", \"include_background\": True},\n    {},\n    {\"output_transform\": from_engine([\"pred\", \"label\"])},\n    0.25,\n]\nTEST_CASE_2 = [\n    {\"reduction\": \"mean\", \"include_background\": False},\n    {},\n    {\"output_transform\": from_engine([\"pred\", \"label\"])},\n    0.5,\n]\nTEST_CASE_3 = [\n    {\"reduction\": \"none\"},\n    {\"reduction\": \"mean_channel\"},\n    {\"output_transform\": from_engine([\"pred\", \"label\"])},\n    torch.Tensor([0.5, 0]),\n]\n\nTEST_CASES = [\n    [\n        {\"include_background\": True, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\n            \"input\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),\n        },\n        0,\n    ],\n    [\n        {\"include_background\": False, 
\"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\n            \"input\": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),\n        },\n        0,\n    ],\n    [\n        {\"include_background\": True, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\n            \"input\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]]),\n        },\n        1,\n    ],\n    [\n        {\"include_background\": False, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\n            \"input\": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]]),\n            \"target\": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]]]),\n        },\n        1,\n    ],\n    [\n        {\"include_background\": True, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\n            \"input\": torch.tensor([[[[0.0, 1.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),\n        },\n        0.333333,\n    ],\n    [\n        {\"include_background\": False, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\n            \"input\": torch.tensor([[[[0.0, 1.0], [0.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]]]),\n        },\n        0,\n    ],\n]\n\n\nclass TestHandlerIgniteMetricHandler(unittest.TestCase):\n    @SkipIfNoModule(\"ignite\")\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_metric_fn(self, loss_params, metric_params, handler_params, expected_avg):\n        loss_fn = DiceLoss(**loss_params)\n        metric_fn = LossMetric(loss_fn=loss_fn, **metric_params)\n        ignite_metric = 
IgniteMetricHandler(metric_fn=metric_fn, **handler_params)\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        ignite_metric.attach(engine=engine, name=\"ignite_dice_loss\")\n        y_pred = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]])\n        y = torch.tensor([[[[0.0, 1.0]], [[0.0, 1.0]]]])\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        y_pred = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]])\n        y = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]])\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        engine.fire_event(Events.EPOCH_COMPLETED)\n        assert_allclose(engine.state.metrics[\"ignite_dice_loss\"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)\n\n    @SkipIfNoModule(\"ignite\")\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_loss_fn(self, loss_params, metric_params, handler_params, expected_avg):\n        loss_fn = DiceLoss(**loss_params)\n        ignite_metric = IgniteMetricHandler(loss_fn=loss_fn, **handler_params, **metric_params)\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        ignite_metric.attach(engine=engine, name=\"ignite_dice_loss\")\n        y_pred = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]])\n        y = torch.tensor([[[[0.0, 1.0]], [[0.0, 1.0]]]])\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        y_pred = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]])\n        y = torch.tensor([[[[0.0, 1.0]], [[1.0, 0.0]]]])\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        engine.fire_event(Events.EPOCH_COMPLETED)\n        assert_allclose(engine.state.metrics[\"ignite_dice_loss\"], 
expected_avg, atol=1e-4, rtol=1e-4, type_test=False)\n\n    @SkipIfNoModule(\"ignite\")\n    @parameterized.expand(TEST_CASES)\n    def test_dice_loss(self, input_param, input_data, expected_val):\n        loss_fn = DiceLoss(**input_param)\n        ignite_metric = IgniteMetricHandler(loss_fn=loss_fn, output_transform=from_engine([\"pred\", \"label\"]))\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        ignite_metric.attach(engine=engine, name=\"ignite_dice_loss\")\n        y_pred = input_data[\"input\"]\n        y = input_data[\"target\"]\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        engine.fire_event(Events.EPOCH_COMPLETED)\n        assert_allclose(engine.state.metrics[\"ignite_dice_loss\"], expected_val, atol=1e-4, rtol=1e-4, type_test=False)\n\n    @SkipIfNoModule(\"ignite\")\n    @parameterized.expand(TEST_CASES[0:2])\n    def test_old_ignite_metric(self, input_param, input_data, expected_val):\n        loss_fn = DiceLoss(**input_param)\n        ignite_metric = IgniteMetricHandler(loss_fn=loss_fn, output_transform=from_engine([\"pred\", \"label\"]))\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        ignite_metric.attach(engine=engine, name=\"ignite_dice_loss\")\n        y_pred = input_data[\"input\"]\n        y = input_data[\"target\"]\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        engine.fire_event(Events.EPOCH_COMPLETED)\n        assert_allclose(engine.state.metrics[\"ignite_dice_loss\"], expected_val, atol=1e-4, rtol=1e-4, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_lr_scheduler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport re\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport torch\nfrom ignite.engine import Engine, Events\n\nfrom monai.handlers import LrScheduleHandler\n\n\nclass TestHandlerLrSchedule(unittest.TestCase):\n\n    def test_content(self):\n        data = [0] * 8\n        test_lr = 0.1\n        gamma = 0.1\n\n        # set up engine\n        def _train_func(engine, batch):\n            pass\n\n        val_engine = Engine(_train_func)\n        train_engine = Engine(_train_func)\n\n        @train_engine.on(Events.EPOCH_COMPLETED)\n        def run_validation(engine):\n            val_engine.run(data)\n            val_engine.state.metrics[\"val_loss\"] = 1\n\n        # set up testing handler\n        net = torch.nn.PReLU()\n\n        def _reduce_lr_on_plateau():\n            optimizer = torch.optim.SGD(net.parameters(), test_lr)\n            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1)\n            handler = LrScheduleHandler(lr_scheduler, step_transform=lambda x: val_engine.state.metrics[\"val_loss\"])\n            handler.attach(train_engine)\n            return handler\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            key_to_handler = \"test_log_lr\"\n            key_to_print = \"Current learning rate\"\n            filename = 
os.path.join(tempdir, \"test_lr.log\")\n            # test with additional logging handler\n            file_saver = logging.FileHandler(filename, mode=\"w\")\n            file_saver.setLevel(logging.INFO)\n            logger = logging.getLogger(key_to_handler)\n            logger.setLevel(logging.INFO)\n            logger.addHandler(file_saver)\n\n            def _reduce_on_step():\n                optimizer = torch.optim.SGD(net.parameters(), test_lr)\n                lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=gamma)\n                handler = LrScheduleHandler(lr_scheduler, name=key_to_handler)\n                handler.attach(train_engine)\n                return handler\n\n            schedulers = _reduce_lr_on_plateau(), _reduce_on_step()\n\n            train_engine.run(data, max_epochs=5)\n            file_saver.close()\n            logger.removeHandler(file_saver)\n\n            with open(filename) as f:\n                output_str = f.read()\n                has_key_word = re.compile(f\".*{key_to_print}.*\")\n                content_count = 0\n                for line in output_str.split(\"\\n\"):\n                    if has_key_word.match(line):\n                        content_count += 1\n                self.assertTrue(content_count > 0)\n\n        for scheduler in schedulers:\n            np.testing.assert_allclose(scheduler.lr_scheduler._last_lr[0], 0.001)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_mean_dice.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom ignite.engine import Engine, Events\nfrom parameterized import parameterized\n\nfrom monai.handlers import MeanDice, from_engine\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_1 = [{\"include_background\": True, \"output_transform\": from_engine([\"pred\", \"label\"])}, 0.75, (4, 2)]\nTEST_CASE_2 = [{\"include_background\": False, \"output_transform\": from_engine([\"pred\", \"label\"])}, 0.66666, (4, 1)]\nTEST_CASE_3 = [\n    {\"reduction\": \"mean_channel\", \"output_transform\": from_engine([\"pred\", \"label\"])},\n    torch.Tensor([1.0, 0.0, 1.0, 1.0]),\n    (4, 2),\n]\n\n\nclass TestHandlerMeanDice(unittest.TestCase):\n    # TODO test multi node averaged dice\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_compute(self, input_params, expected_avg, details_shape):\n        dice_metric = MeanDice(**input_params)\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        dice_metric.attach(engine=engine, name=\"mean_dice\")\n        # test input a list of channel-first tensor\n        y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]\n        y = torch.Tensor([[[0], [1]], [[0], [1]]])\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        
engine.fire_event(Events.ITERATION_COMPLETED)\n\n        y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]\n        y = torch.Tensor([[[0], [1]], [[1], [0]]])\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        engine.fire_event(Events.EPOCH_COMPLETED)\n        assert_allclose(engine.state.metrics[\"mean_dice\"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)\n        self.assertTupleEqual(tuple(engine.state.metric_details[\"mean_dice\"].shape), details_shape)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_shape_mismatch(self, input_params, _expected_avg, _details_shape):\n        dice_metric = MeanDice(**input_params)\n        with self.assertRaises((AssertionError, ValueError)):\n            y_pred = torch.Tensor([[0, 1], [1, 0]])\n            y = torch.ones((2, 3))\n            dice_metric.update([y_pred, y])\n\n        with self.assertRaises((AssertionError, ValueError)):\n            y_pred = torch.Tensor([[0, 1], [1, 0]])\n            y = torch.ones((3, 2))\n            dice_metric.update([y_pred, y])\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_compute_n_class(self, input_params, expected_avg, details_shape):\n        dice_metric = MeanDice(num_classes=2, **input_params)\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        dice_metric.attach(engine=engine, name=\"mean_dice\")\n        # test input a list of channel-first tensor\n        y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])]\n        y = torch.Tensor([[[0], [1]], [[0], [1]]])\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        y_pred = [torch.Tensor([[1]]), torch.Tensor([[0]])]  # class indices y_pred\n        y = torch.Tensor([[[1]], [[0]]])  # class indices y\n        engine.state.output = {\"pred\": y_pred, 
\"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        engine.fire_event(Events.EPOCH_COMPLETED)\n        assert_allclose(engine.state.metrics[\"mean_dice\"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)\n        self.assertTupleEqual(tuple(engine.state.metric_details[\"mean_dice\"].shape), details_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_mean_iou.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom ignite.engine import Engine, Events\nfrom parameterized import parameterized\n\nfrom monai.handlers import MeanIoUHandler, from_engine\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_1 = [{\"include_background\": True, \"output_transform\": from_engine([\"pred\", \"label\"])}, 0.75, (4, 2)]\nTEST_CASE_2 = [{\"include_background\": False, \"output_transform\": from_engine([\"pred\", \"label\"])}, 2 / 3, (4, 1)]\nTEST_CASE_3 = [\n    {\"reduction\": \"mean_channel\", \"output_transform\": from_engine([\"pred\", \"label\"])},\n    torch.Tensor([1.0, 0.0, 1.0, 1.0]),\n    (4, 2),\n]\n\n\nclass TestHandlerMeanIoU(unittest.TestCase):\n    # TODO test multi node averaged iou\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_compute(self, input_params, expected_avg, details_shape):\n        iou_metric = MeanIoUHandler(**input_params)\n\n        # set up engine\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        iou_metric.attach(engine=engine, name=\"mean_iou\")\n        # test input a list of channel-first tensor\n        y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]\n        y = torch.Tensor([[[0], [1]], [[0], [1]]])\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        
engine.fire_event(Events.ITERATION_COMPLETED)\n\n        y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]\n        y = torch.Tensor([[[0], [1]], [[1], [0]]])\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        engine.fire_event(Events.EPOCH_COMPLETED)\n        assert_allclose(engine.state.metrics[\"mean_iou\"], expected_avg)\n        self.assertTupleEqual(tuple(engine.state.metric_details[\"mean_iou\"].shape), details_shape)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_shape_mismatch(self, input_params, _expected_avg, _details_shape):\n        iou_metric = MeanIoUHandler(**input_params)\n        with self.assertRaises((AssertionError, ValueError)):\n            y_pred = torch.Tensor([[0, 1], [1, 0]])\n            y = torch.ones((3, 30))\n            iou_metric.update([y_pred, y])\n\n        with self.assertRaises((AssertionError, ValueError)):\n            y_pred = torch.Tensor([[0, 1], [1, 0]])\n            y = torch.ones((8, 30))\n            iou_metric.update([y_pred, y])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_metrics_reloaded.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom ignite.engine import Engine, Events\nfrom parameterized import parameterized\n\nfrom monai.handlers import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler, from_engine\nfrom monai.utils import optional_import\nfrom tests.test_utils import assert_allclose\n\n_, has_metrics = optional_import(\"MetricsReloaded\")\n\nTEST_CASE_BIN_1 = [\n    {\"metric_name\": \"Volume Difference\"},\n    [torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])],\n    [torch.tensor([[[1.0, 0.0], [1.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [1.0, 1.0]]])],\n    0.3333,\n]\n\nTEST_CASE_BIN_2 = [\n    {\"metric_name\": \"Boundary IoU\"},\n    [torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])],\n    [torch.tensor([[[1.0, 0.0], [1.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [1.0, 1.0]]])],\n    0.6667,\n]\n\nTEST_CASE_BIN_3 = [\n    {\"metric_name\": \"xTh Percentile Hausdorff Distance\"},\n    [torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])],\n    [torch.tensor([[[1.0, 0.0], [1.0, 1.0]]]), torch.tensor([[[1.0, 0.0], [1.0, 1.0]]])],\n    0.9,\n]\n\nTEST_CASE_CAT_1 = [\n    {\"metric_name\": \"Weighted Cohens Kappa\"},\n    [\n        torch.tensor([[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]),\n        
torch.tensor([[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]),\n    ],\n    [\n        torch.tensor([[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]),\n        torch.tensor([[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]),\n    ],\n    0.272727,\n]\n\nTEST_CASE_CAT_2 = [\n    {\"metric_name\": \"Matthews Correlation Coefficient\"},\n    [\n        torch.tensor([[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]),\n        torch.tensor([[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]),\n    ],\n    [\n        torch.tensor([[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]),\n        torch.tensor([[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]),\n    ],\n    0.387298,\n]\n\n\n@unittest.skipIf(not has_metrics, \"MetricsReloaded not available.\")\nclass TestHandlerMetricsReloadedBinary(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_BIN_1, TEST_CASE_BIN_2, TEST_CASE_BIN_3])\n    def test_compute(self, input_params, y_pred, y, expected_value):\n        input_params[\"output_transform\"] = from_engine([\"pred\", \"label\"])\n        metric = MetricsReloadedBinaryHandler(**input_params)\n\n        # set up engine\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        metric.attach(engine=engine, name=input_params[\"metric_name\"])\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        engine.fire_event(Events.EPOCH_COMPLETED)\n        assert_allclose(\n            engine.state.metrics[input_params[\"metric_name\"]], expected_value, atol=1e-4, rtol=1e-4, type_test=False\n        )\n\n    @parameterized.expand([TEST_CASE_BIN_1, TEST_CASE_BIN_2, TEST_CASE_BIN_3])\n    def test_shape_mismatch(self, input_params, _y_pred, _y, _expected_value):\n        
input_params[\"output_transform\"] = from_engine([\"pred\", \"label\"])\n        metric = MetricsReloadedBinaryHandler(**input_params)\n        with self.assertRaises((AssertionError, ValueError)):\n            y_pred = torch.Tensor([[0, 1], [1, 0]])\n            y = torch.ones((2, 3))\n            metric.update([y_pred, y])\n\n        with self.assertRaises((AssertionError, ValueError)):\n            y_pred = [torch.ones((2, 1, 1)), torch.ones((1, 1, 1))]\n            y = [torch.ones((2, 1, 1)), torch.ones((1, 1, 1))]\n            metric.update([y_pred, y])\n\n\n@unittest.skipIf(not has_metrics, \"MetricsReloaded not available.\")\nclass TestMetricsReloadedCategorical(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_CAT_1, TEST_CASE_CAT_2])\n    def test_compute(self, input_params, y_pred, y, expected_value):\n        input_params[\"output_transform\"] = from_engine([\"pred\", \"label\"])\n        metric = MetricsReloadedCategoricalHandler(**input_params)\n\n        # set up engine\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        metric.attach(engine=engine, name=input_params[\"metric_name\"])\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        engine.fire_event(Events.EPOCH_COMPLETED)\n        assert_allclose(\n            engine.state.metrics[input_params[\"metric_name\"]], expected_value, atol=1e-4, rtol=1e-4, type_test=False\n        )\n\n    @parameterized.expand([TEST_CASE_CAT_1, TEST_CASE_CAT_2])\n    def test_shape_mismatch(self, input_params, y_pred, y, _expected_value):\n        input_params[\"output_transform\"] = from_engine([\"pred\", \"label\"])\n        metric = MetricsReloadedCategoricalHandler(**input_params)\n        with self.assertRaises((AssertionError, ValueError)):\n            
y_pred[0] = torch.zeros([3, 2, 1])\n            metric.update([y_pred, y])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_metrics_saver.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport csv\nimport os\nimport tempfile\nimport unittest\n\nimport torch\nfrom ignite.engine import Engine, Events\n\nfrom monai.handlers import MetricsSaver\nfrom monai.utils.enums import PostFix\n\n\nclass TestHandlerMetricsSaver(unittest.TestCase):\n\n    def test_content(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            metrics_saver = MetricsSaver(\n                save_dir=tempdir,\n                metrics=[\"metric1\", \"metric2\"],\n                metric_details=[\"metric3\", \"metric4\"],\n                batch_transform=lambda x: x[PostFix.meta(\"image\")],\n                summary_ops=[\"mean\", \"median\", \"max\", \"5percentile\", \"95percentile\", \"notnans\"],\n                delimiter=\"\\t\",\n            )\n            # set up engine\n            data = [\n                {PostFix.meta(\"image\"): {\"filename_or_obj\": [\"filepath1\"]}},\n                {PostFix.meta(\"image\"): {\"filename_or_obj\": [\"filepath2\"]}},\n            ]\n\n            def _val_func(engine, batch):\n                pass\n\n            engine = Engine(_val_func)\n\n            @engine.on(Events.EPOCH_COMPLETED)\n            def _save_metrics(engine):\n                engine.state.metrics = {\"metric1\": 1, \"metric2\": 2}\n                engine.state.metric_details = {\n                    \"metric3\": 
torch.tensor([[1, 2], [2, 3]]),\n                    \"metric4\": torch.tensor([[5, 6], [7, torch.tensor(float(\"nan\"))]]),\n                }\n\n            metrics_saver.attach(engine)\n            engine.run(data, max_epochs=1)\n\n            # check the metrics.csv and content\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metrics.csv\")))\n            with open(os.path.join(tempdir, \"metrics.csv\")) as f:\n                f_csv = csv.reader(f)\n                for i, row in enumerate(f_csv):\n                    self.assertEqual(row, [f\"metric{i + 1}\\t{i + 1}\"])\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metric3_raw.csv\")))\n            # check the metric_raw.csv and content\n            with open(os.path.join(tempdir, \"metric3_raw.csv\")) as f:\n                f_csv = csv.reader(f)\n                for i, row in enumerate(f_csv):\n                    if i > 0:\n                        self.assertEqual(row, [f\"filepath{i}\\t{float(i):.4f}\\t{float(i + 1):.4f}\\t{i + 0.5:.4f}\"])\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metric3_summary.csv\")))\n            # check the metric_summary.csv and content\n            with open(os.path.join(tempdir, \"metric4_summary.csv\")) as f:\n                f_csv = csv.reader(f)\n                for i, row in enumerate(f_csv):\n                    if i == 1:\n                        self.assertEqual(row, [\"class0\\t6.0000\\t6.0000\\t7.0000\\t5.1000\\t6.9000\\t2.0000\"])\n                    elif i == 2:\n                        self.assertEqual(row, [\"class1\\t6.0000\\t6.0000\\t6.0000\\t6.0000\\t6.0000\\t1.0000\"])\n                    elif i == 3:\n                        self.assertEqual(row, [\"mean\\t6.2500\\t6.2500\\t7.0000\\t5.5750\\t6.9250\\t2.0000\"])\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metric4_raw.csv\")))\n            self.assertTrue(os.path.exists(os.path.join(tempdir, 
\"metric3_summary.csv\")))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_metrics_saver_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport csv\nimport os\nimport tempfile\nimport unittest\n\nimport torch\nimport torch.distributed as dist\nfrom ignite.engine import Engine, Events\n\nfrom monai.handlers import MetricsSaver\nfrom monai.utils import evenly_divisible_all_gather\nfrom monai.utils.enums import PostFix\nfrom tests.test_utils import DistCall, DistTestCase\n\n\nclass DistributedMetricsSaver(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_content(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            self._run(tempdir)\n\n    def _run(self, tempdir):\n        my_rank = dist.get_rank()\n        fnames = [\"aaa\" * 300, \"bbb\" * 301, \"ccc\" * 302]\n\n        metrics_saver = MetricsSaver(\n            save_dir=tempdir,\n            metrics=[\"metric1\", \"metric2\"],\n            metric_details=[\"metric3\", \"metric4\"],\n            batch_transform=lambda x: x[PostFix.meta(\"image\")],\n            summary_ops=\"*\",\n            delimiter=\"\\t\",\n        )\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n\n        # define here to ensure symbol always exists regardless of the following if conditions\n        data = [{PostFix.meta(\"image\"): {\"filename_or_obj\": [fnames[0]]}}]\n\n        if my_rank == 0:\n\n            @engine.on(Events.EPOCH_COMPLETED)\n       
     def _save_metrics0(engine):\n                engine.state.metrics = {\"metric1\": 1, \"metric2\": 2}\n                engine.state.metric_details = {\"metric3\": torch.tensor([[1, 2]]), \"metric4\": torch.tensor([[5, 6]])}\n\n        if my_rank == 1:\n            # different ranks have different data length\n            data = [\n                {PostFix.meta(\"image\"): {\"filename_or_obj\": [fnames[1]]}},\n                {PostFix.meta(\"image\"): {\"filename_or_obj\": [fnames[2]]}},\n            ]\n\n            @engine.on(Events.EPOCH_COMPLETED)\n            def _save_metrics1(engine):\n                engine.state.metrics = {\"metric1\": 1, \"metric2\": 2}\n                engine.state.metric_details = {\n                    \"metric3\": torch.tensor([[2, 3], [3, 4]]),\n                    \"metric4\": torch.tensor([[6, 7], [7, 8]]),\n                }\n\n        @engine.on(Events.EPOCH_COMPLETED)\n        def _all_gather(engine):\n            scores = engine.state.metric_details[\"metric3\"]\n            engine.state.metric_details[\"metric3\"] = evenly_divisible_all_gather(data=scores, concat=True)\n            scores = engine.state.metric_details[\"metric4\"]\n            engine.state.metric_details[\"metric4\"] = evenly_divisible_all_gather(data=scores, concat=True)\n\n        metrics_saver.attach(engine)\n        engine.run(data, max_epochs=1)\n\n        if my_rank == 0:\n            # check the metrics.csv and content\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metrics.csv\")))\n            with open(os.path.join(tempdir, \"metrics.csv\")) as f:\n                f_csv = csv.reader(f)\n                for i, row in enumerate(f_csv):\n                    self.assertEqual(row, [f\"metric{i + 1}\\t{i + 1}\"])\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metric3_raw.csv\")))\n            # check the metric_raw.csv and content\n            with open(os.path.join(tempdir, \"metric3_raw.csv\")) as f:\n          
      f_csv = csv.reader(f)\n                for i, row in enumerate(f_csv):\n                    if i > 0:\n                        expected = [f\"{fnames[i - 1]}\\t{float(i):.4f}\\t{float(i + 1):.4f}\\t{i + 0.5:.4f}\"]\n                        self.assertEqual(row, expected)\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metric3_summary.csv\")))\n            # check the metric_summary.csv and content\n            with open(os.path.join(tempdir, \"metric3_summary.csv\")) as f:\n                f_csv = csv.reader(f)\n                for i, row in enumerate(f_csv):\n                    if i == 1:\n                        self.assertEqual(row, [\"class0\\t2.0000\\t2.0000\\t3.0000\\t1.0000\\t2.8000\\t0.8165\\t3.0000\"])\n                    elif i == 2:\n                        self.assertEqual(row, [\"class1\\t3.0000\\t3.0000\\t4.0000\\t2.0000\\t3.8000\\t0.8165\\t3.0000\"])\n                    elif i == 3:\n                        self.assertEqual(row, [\"mean\\t2.5000\\t2.5000\\t3.5000\\t1.5000\\t3.3000\\t0.8165\\t3.0000\"])\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metric4_raw.csv\")))\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metric4_summary.csv\")))\n        dist.barrier()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_mlflow.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport glob\nimport os\nimport shutil\nimport tempfile\nimport unittest\nfrom concurrent.futures import ThreadPoolExecutor\nfrom unittest.mock import MagicMock\n\nimport numpy as np\nfrom ignite.engine import Engine, Events\nfrom parameterized import parameterized\n\nfrom monai.apps import download_and_extract\nfrom monai.bundle import ConfigWorkflow, download\nfrom monai.handlers import MLFlowHandler\nfrom monai.utils import optional_import, path_to_uri\nfrom tests.test_utils import skip_if_downloading_fails, skip_if_quick\n\n_, has_dataset_tracking = optional_import(\"mlflow\", \"2.4.0\")\n\n\ndef get_event_filter(e):\n    def event_filter(_, event):\n        if event in e:\n            return True\n        return False\n\n    return event_filter\n\n\ndef dummy_train(tracking_folder):\n    tempdir = tempfile.mkdtemp()\n\n    # set up engine\n    def _train_func(engine, batch):\n        return [batch + 1.0]\n\n    engine = Engine(_train_func)\n\n    # set up testing handler\n    test_path = os.path.join(tempdir, tracking_folder)\n    handler = MLFlowHandler(\n        iteration_log=False,\n        epoch_log=True,\n        tracking_uri=path_to_uri(test_path),\n        state_attributes=[\"test\"],\n        close_on_complete=True,\n    )\n    handler.attach(engine)\n    engine.run(range(3), max_epochs=2)\n    return test_path\n\n\nclass 
TestHandlerMLFlow(unittest.TestCase):\n    def setUp(self):\n        self.tmpdir_list = []\n\n    def tearDown(self):\n        for tmpdir in self.tmpdir_list:\n            if tmpdir and os.path.exists(tmpdir):\n                shutil.rmtree(tmpdir)\n\n    def test_multi_run(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            # set up the train function for engine\n            def _train_func(engine, batch):\n                return [batch + 1.0]\n\n            # create and run an engine several times to get several runs\n            create_engine_times = 3\n            for _ in range(create_engine_times):\n                engine = Engine(_train_func)\n\n                @engine.on(Events.EPOCH_COMPLETED)\n                def _update_metric(engine):\n                    current_metric = engine.state.metrics.get(\"acc\", 0.1)\n                    engine.state.metrics[\"acc\"] = current_metric + 0.1\n                    engine.state.test = current_metric\n\n                # set up testing handler\n                test_path = os.path.join(tempdir, \"mlflow_test\")\n                handler = MLFlowHandler(\n                    iteration_log=False,\n                    epoch_log=True,\n                    tracking_uri=path_to_uri(test_path),\n                    state_attributes=[\"test\"],\n                    close_on_complete=True,\n                )\n                handler.attach(engine)\n                engine.run(range(3), max_epochs=2)\n                run_cnt = len(handler.client.search_runs(handler.experiment.experiment_id))\n                handler.close()\n            # the run count should equal to the times of creating engine\n            self.assertEqual(create_engine_times, run_cnt)\n\n    def test_metrics_track(self):\n        experiment_param = {\"backbone\": \"efficientnet_b0\"}\n        with tempfile.TemporaryDirectory() as tempdir:\n            # set up engine\n            def _train_func(engine, batch):\n                return 
[batch + 1.0]\n\n            engine = Engine(_train_func)\n\n            # set up dummy metric\n            @engine.on(Events.EPOCH_COMPLETED)\n            def _update_metric(engine):\n                current_metric = engine.state.metrics.get(\"acc\", 0.1)\n                engine.state.metrics[\"acc\"] = current_metric + 0.1\n                # log nested metrics\n                engine.state.metrics[\"acc_per_label\"] = {\n                    \"label_0\": current_metric + 0.1,\n                    \"label_1\": current_metric + 0.2,\n                }\n                engine.state.test = current_metric\n\n            # set up testing handler\n            test_path = os.path.join(tempdir, \"mlflow_test\")\n            artifact_path = os.path.join(tempdir, \"artifacts\")\n            os.makedirs(artifact_path, exist_ok=True)\n            dummy_numpy = np.zeros((64, 64, 3))\n            dummy_path = os.path.join(artifact_path, \"tmp.npy\")\n            np.save(dummy_path, dummy_numpy)\n            handler = MLFlowHandler(\n                iteration_log=False,\n                epoch_log=True,\n                tracking_uri=path_to_uri(test_path),\n                state_attributes=[\"test\"],\n                experiment_param=experiment_param,\n                artifacts=[artifact_path],\n                close_on_complete=False,\n            )\n            handler.attach(engine)\n            engine.run(range(3), max_epochs=2)\n            cur_run = handler.client.get_run(handler.cur_run.info.run_id)\n            self.assertTrue(\"label_0\" in cur_run.data.metrics.keys())\n            handler.close()\n            # check logging output\n            self.assertTrue(len(glob.glob(test_path)) > 0)\n\n    @parameterized.expand([[True], [get_event_filter([1, 2])]])\n    def test_metrics_track_mock(self, epoch_log):\n        experiment_param = {\"backbone\": \"efficientnet_b0\"}\n        with tempfile.TemporaryDirectory() as tempdir:\n            # set up engine\n            def 
_train_func(engine, batch):\n                return [batch + 1.0]\n\n            engine = Engine(_train_func)\n\n            # set up dummy metric\n            @engine.on(Events.EPOCH_COMPLETED)\n            def _update_metric(engine):\n                current_metric = engine.state.metrics.get(\"acc\", 0.1)\n                engine.state.metrics[\"acc\"] = current_metric + 0.1\n                engine.state.test = current_metric\n\n            # set up testing handler\n            test_path = os.path.join(tempdir, \"mlflow_test\")\n            handler = MLFlowHandler(\n                iteration_log=False,\n                epoch_log=epoch_log,\n                tracking_uri=path_to_uri(test_path),\n                state_attributes=[\"test\"],\n                experiment_param=experiment_param,\n                close_on_complete=True,\n            )\n            handler._default_epoch_log = MagicMock()\n            handler.attach(engine)\n\n            max_epochs = 4\n            engine.run(range(3), max_epochs=max_epochs)\n            handler.close()\n            # check logging output\n            if epoch_log is True:\n                self.assertEqual(handler._default_epoch_log.call_count, max_epochs)\n            else:\n                self.assertEqual(handler._default_epoch_log.call_count, 2)  # 2 = len([1, 2]) from event_filter\n\n    @parameterized.expand([[True], [get_event_filter([1, 3])]])\n    def test_metrics_track_iters_mock(self, iteration_log):\n        experiment_param = {\"backbone\": \"efficientnet_b0\"}\n        with tempfile.TemporaryDirectory() as tempdir:\n            # set up engine\n            def _train_func(engine, batch):\n                return [batch + 1.0]\n\n            engine = Engine(_train_func)\n\n            # set up dummy metric\n            @engine.on(Events.EPOCH_COMPLETED)\n            def _update_metric(engine):\n                current_metric = engine.state.metrics.get(\"acc\", 0.1)\n                
engine.state.metrics[\"acc\"] = current_metric + 0.1\n                engine.state.test = current_metric\n\n            # set up testing handler\n            test_path = os.path.join(tempdir, \"mlflow_test\")\n            handler = MLFlowHandler(\n                iteration_log=iteration_log,\n                epoch_log=False,\n                tracking_uri=path_to_uri(test_path),\n                state_attributes=[\"test\"],\n                experiment_param=experiment_param,\n                close_on_complete=True,\n            )\n            handler._default_iteration_log = MagicMock()\n            handler.attach(engine)\n\n            num_iters = 3\n            max_epochs = 2\n            engine.run(range(num_iters), max_epochs=max_epochs)\n            handler.close()\n            # check logging output\n            if iteration_log is True:\n                self.assertEqual(handler._default_iteration_log.call_count, num_iters * max_epochs)\n            else:\n                self.assertEqual(handler._default_iteration_log.call_count, 2)  # 2 = len([1, 3]) from event_filter\n\n    def test_multi_thread(self):\n        test_uri_list = [\"monai_mlflow_test1\", \"monai_mlflow_test2\"]\n        with ThreadPoolExecutor(2, \"Training\") as executor:\n            futures = {}\n            for t in test_uri_list:\n                futures[t] = executor.submit(dummy_train, t)\n\n            for _, future in futures.items():\n                res = future.result()\n                self.tmpdir_list.append(res)\n                self.assertTrue(len(glob.glob(res)) > 0)\n\n    @skip_if_quick\n    @unittest.skipUnless(has_dataset_tracking, reason=\"Requires mlflow version >= 2.4.0.\")\n    def test_dataset_tracking(self):\n        test_bundle_name = \"endoscopic_tool_segmentation\"\n        with tempfile.TemporaryDirectory() as tempdir:\n            resource = \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/endoscopic_tool_dataset.zip\"\n           
 md5 = \"f82da47259c0a617202fb54624798a55\"\n            compressed_file = os.path.join(tempdir, \"endoscopic_tool_segmentation.zip\")\n            data_dir = os.path.join(tempdir, \"endoscopic_tool_dataset\")\n            with skip_if_downloading_fails():\n                if not os.path.exists(data_dir):\n                    download_and_extract(resource, compressed_file, tempdir, md5)\n\n                download(test_bundle_name, bundle_dir=tempdir)\n\n                bundle_root = os.path.join(tempdir, test_bundle_name)\n                config_file = os.path.join(bundle_root, \"configs/inference.json\")\n                meta_file = os.path.join(bundle_root, \"configs/metadata.json\")\n                logging_file = os.path.join(bundle_root, \"configs/logging.conf\")\n                workflow = ConfigWorkflow(\n                    workflow_type=\"infer\",\n                    config_file=config_file,\n                    meta_file=meta_file,\n                    logging_file=logging_file,\n                    init_id=\"initialize\",\n                    run_id=\"run\",\n                    final_id=\"finalize\",\n                )\n\n                tracking_path = os.path.join(bundle_root, \"eval\")\n                workflow.bundle_root = bundle_root\n                workflow.dataset_dir = data_dir\n                workflow.initialize()\n                infer_dataset = workflow.dataset\n                mlflow_handler = MLFlowHandler(\n                    iteration_log=False,\n                    epoch_log=False,\n                    dataset_dict={\"test\": infer_dataset},\n                    tracking_uri=path_to_uri(tracking_path),\n                )\n                mlflow_handler.attach(workflow.evaluator)\n                workflow.run()\n                workflow.finalize()\n\n                cur_run = mlflow_handler.client.get_run(mlflow_handler.cur_run.info.run_id)\n                logged_nontrain_set = [x for x in cur_run.inputs.dataset_inputs if 
x.dataset.name.startswith(\"test\")]\n                self.assertEqual(len(logged_nontrain_set), 1)\n                mlflow_handler.close()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_nvtx.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom ignite.engine import Events\nfrom parameterized import parameterized\n\nfrom monai.engines import SupervisedEvaluator\nfrom monai.handlers import StatsHandler, from_engine\nfrom monai.handlers.nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler\nfrom monai.utils import optional_import\nfrom tests.test_utils import assert_allclose\n\n_, has_nvtx = optional_import(\"torch._C._nvtx\", descriptor=\"NVTX is not installed. 
Are you sure you have a CUDA build?\")\n\nTENSOR_0 = torch.tensor([[[[1.0], [2.0]], [[3.0], [4.0]]]])\n\nTENSOR_1 = torch.tensor([[[[0.0], [-2.0]], [[-3.0], [4.0]]]])\n\nTENSOR_1_EXPECTED = torch.tensor([[[1.0], [0.5]], [[0.25], [5.0]]])\n\nTEST_CASE_0 = [[{\"image\": TENSOR_0}], TENSOR_0[0] + 1.0]\nTEST_CASE_1 = [[{\"image\": TENSOR_1}], TENSOR_1_EXPECTED]\n\n\nclass TestHandlerDecollateBatch(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_0, TEST_CASE_1])\n    @unittest.skipUnless(has_nvtx, \"CUDA is required for NVTX!\")\n    def test_compute(self, data, expected):\n        # Set up handlers\n        handlers = [\n            # Mark with Ignite Event\n            MarkHandler(Events.STARTED),\n            # Mark with literal\n            MarkHandler(\"EPOCH_STARTED\"),\n            # Mark with literal and providing the message\n            MarkHandler(\"EPOCH_STARTED\", \"Start of the epoch\"),\n            # Define a range using one prefix (between BATCH_STARTED and BATCH_COMPLETED)\n            RangeHandler(\"Batch\"),\n            # Define a range using a pair of events\n            RangeHandler((Events.STARTED, Events.COMPLETED)),\n            # Define a range using a pair of literals\n            RangeHandler((\"GET_BATCH_STARTED\", \"GET_BATCH_COMPLETED\"), msg=\"Batching!\"),\n            # Define a range using a pair of literal and events\n            RangeHandler((\"GET_BATCH_STARTED\", Events.COMPLETED)),\n            # Define the start of range using literal\n            RangePushHandler(\"ITERATION_STARTED\"),\n            # Define the start of range using event\n            RangePushHandler(Events.ITERATION_STARTED, \"Iteration 2\"),\n            # Define the start of range using literals and providing message\n            RangePushHandler(\"EPOCH_STARTED\", \"Epoch 2\"),\n            # Define the end of range using Ignite Event\n            RangePopHandler(Events.ITERATION_COMPLETED),\n            RangePopHandler(Events.EPOCH_COMPLETED),\n 
           # Define the end of range using literal\n            RangePopHandler(\"ITERATION_COMPLETED\"),\n            # Other handlers\n            StatsHandler(tag_name=\"train\", output_transform=from_engine([\"label\"], first=True)),\n        ]\n\n        # Set up an engine\n        engine = SupervisedEvaluator(\n            device=torch.device(\"cpu:0\"),\n            val_data_loader=data,\n            epoch_length=1,\n            network=torch.nn.PReLU(),\n            postprocessing=lambda x: dict(pred=x[\"pred\"] + 1.0),\n            decollate=True,\n            val_handlers=handlers,\n        )\n        # Run the engine\n        engine.run()\n\n        # Get the output from the engine\n        output = engine.state.output[0]\n\n        assert_allclose(output[\"pred\"], expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_panoptic_quality.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom ignite.engine import Engine, Events\nfrom parameterized import parameterized\n\nfrom monai.handlers import PanopticQuality, from_engine\nfrom tests.test_utils import SkipIfNoModule, assert_allclose\n\nsample_1_pred = torch.as_tensor(\n    [[[0, 1, 1, 1], [0, 0, 5, 5], [2, 0, 3, 3], [2, 2, 2, 0]], [[0, 1, 1, 1], [0, 0, 0, 0], [2, 0, 3, 3], [4, 2, 2, 0]]]\n)\n\nsample_1_gt = torch.as_tensor(\n    [[[0, 6, 6, 6], [1, 0, 5, 5], [1, 0, 3, 3], [1, 3, 2, 0]], [[0, 1, 1, 1], [0, 0, 1, 1], [2, 0, 3, 3], [4, 4, 4, 3]]]\n)\n\nsample_2_pred = torch.as_tensor(\n    [[[3, 1, 1, 1], [3, 1, 1, 4], [3, 1, 4, 4], [3, 2, 2, 4]], [[0, 1, 1, 1], [2, 2, 2, 2], [2, 0, 0, 3], [4, 2, 2, 3]]]\n)\n\nsample_2_gt = torch.as_tensor(\n    [[[0, 6, 6, 6], [1, 0, 5, 5], [1, 0, 3, 3], [1, 3, 2, 0]], [[0, 1, 1, 1], [2, 1, 1, 3], [2, 0, 0, 3], [4, 2, 2, 3]]]\n)\n\nTEST_CASE_1 = [{\"num_classes\": 4, \"output_transform\": from_engine([\"pred\", \"label\"])}, [0.6667, 0.1538, 0.6667, 0.5714]]\nTEST_CASE_2 = [\n    {\n        \"num_classes\": 5,\n        \"output_transform\": from_engine([\"pred\", \"label\"]),\n        \"metric_name\": \"rq\",\n        \"match_iou_threshold\": 0.3,\n    },\n    [0.6667, 0.7692, 0.8889, 0.5714, 0.0000],\n]\nTEST_CASE_3 = [\n    {\n        \"num_classes\": 5,\n        \"reduction\": \"mean\",\n        
\"output_transform\": from_engine([\"pred\", \"label\"]),\n        \"metric_name\": \"SQ\",\n        \"match_iou_threshold\": 0.2,\n    },\n    0.8235,\n]\n\n\n@SkipIfNoModule(\"scipy.optimize\")\nclass TestHandlerPanopticQuality(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_compute(self, input_params, expected_avg):\n        metric = PanopticQuality(**input_params)\n\n        # set up engine\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        metric.attach(engine=engine, name=\"panoptic_quality\")\n        # test input a list of channel-first tensor\n        y_pred = [sample_1_pred, sample_2_pred]\n        y = [sample_1_gt, sample_2_gt]\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n        y_pred = [sample_1_pred, sample_1_pred]\n        y = [sample_1_gt, sample_1_gt]\n        engine.state.output = {\"pred\": y_pred, \"label\": y}\n        engine.fire_event(Events.ITERATION_COMPLETED)\n\n        engine.fire_event(Events.EPOCH_COMPLETED)\n        assert_allclose(engine.state.metrics[\"panoptic_quality\"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_parameter_scheduler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom ignite.engine import Engine, Events\nfrom torch.nn import Module\n\nfrom monai.handlers.parameter_scheduler import ParamSchedulerHandler\nfrom tests.test_utils import assert_allclose\n\n\nclass ToyNet(Module):\n    def __init__(self, value):\n        super().__init__()\n        self.value = value\n\n    def forward(self, input):\n        return input\n\n    def get_value(self):\n        return self.value\n\n    def set_value(self, value):\n        self.value = value\n\n\nclass TestHandlerParameterScheduler(unittest.TestCase):\n    def test_linear_scheduler(self):\n        # Testing step_constant\n        net = ToyNet(value=-1)\n        engine = Engine(lambda e, b: None)\n        ParamSchedulerHandler(\n            parameter_setter=net.set_value,\n            value_calculator=\"linear\",\n            vc_kwargs={\"initial_value\": 0, \"step_constant\": 2, \"step_max_value\": 5, \"max_value\": 10},\n            epoch_level=True,\n            event=Events.EPOCH_COMPLETED,\n        ).attach(engine)\n        engine.run([0] * 8, max_epochs=2)\n        assert_allclose(net.get_value(), 0)\n\n        # Testing linear increase\n        net = ToyNet(value=-1)\n        engine = Engine(lambda e, b: None)\n        ParamSchedulerHandler(\n            parameter_setter=net.set_value,\n            value_calculator=\"linear\",\n   
         vc_kwargs={\"initial_value\": 0, \"step_constant\": 2, \"step_max_value\": 5, \"max_value\": 10},\n            epoch_level=True,\n            event=Events.EPOCH_COMPLETED,\n        ).attach(engine)\n        engine.run([0] * 8, max_epochs=3)\n        assert_allclose(net.get_value(), 3.333333, atol=0.001, rtol=0.0)\n\n        # Testing max_value\n        net = ToyNet(value=-1)\n        engine = Engine(lambda e, b: None)\n        ParamSchedulerHandler(\n            parameter_setter=net.set_value,\n            value_calculator=\"linear\",\n            vc_kwargs={\"initial_value\": 0, \"step_constant\": 2, \"step_max_value\": 5, \"max_value\": 10},\n            epoch_level=True,\n            event=Events.EPOCH_COMPLETED,\n        ).attach(engine)\n        engine.run([0] * 8, max_epochs=10)\n        assert_allclose(net.get_value(), 10)\n\n    def test_exponential_scheduler(self):\n        net = ToyNet(value=-1)\n        engine = Engine(lambda e, b: None)\n        ParamSchedulerHandler(\n            parameter_setter=net.set_value,\n            value_calculator=\"exponential\",\n            vc_kwargs={\"initial_value\": 10, \"gamma\": 0.99},\n            epoch_level=True,\n            event=Events.EPOCH_COMPLETED,\n        ).attach(engine)\n        engine.run([0] * 8, max_epochs=2)\n        assert_allclose(net.get_value(), 10 * 0.99 * 0.99)\n\n    def test_step_scheduler(self):\n        net = ToyNet(value=-1)\n        engine = Engine(lambda e, b: None)\n        ParamSchedulerHandler(\n            parameter_setter=net.set_value,\n            value_calculator=\"step\",\n            vc_kwargs={\"initial_value\": 10, \"gamma\": 0.99, \"step_size\": 5},\n            epoch_level=True,\n            event=Events.EPOCH_COMPLETED,\n        ).attach(engine)\n        engine.run([0] * 8, max_epochs=10)\n        assert_allclose(net.get_value(), 10 * 0.99 * 0.99)\n\n    def test_multistep_scheduler(self):\n        net = ToyNet(value=-1)\n        engine = Engine(lambda e, b: 
None)\n        ParamSchedulerHandler(\n            parameter_setter=net.set_value,\n            value_calculator=\"multistep\",\n            vc_kwargs={\"initial_value\": 10, \"gamma\": 0.99, \"milestones\": [3, 6]},\n            epoch_level=True,\n            event=Events.EPOCH_COMPLETED,\n        ).attach(engine)\n        engine.run([0] * 8, max_epochs=10)\n        assert_allclose(net.get_value(), 10 * 0.99 * 0.99)\n\n    def test_custom_scheduler(self):\n        def custom_logic(initial_value, gamma, current_step):\n            return initial_value * gamma ** (current_step % 9)\n\n        net = ToyNet(value=-1)\n        engine = Engine(lambda e, b: None)\n        ParamSchedulerHandler(\n            parameter_setter=net.set_value,\n            value_calculator=custom_logic,\n            vc_kwargs={\"initial_value\": 10, \"gamma\": 0.99},\n            epoch_level=True,\n            event=Events.EPOCH_COMPLETED,\n        ).attach(engine)\n        engine.run([0] * 8, max_epochs=2)\n        assert_allclose(net.get_value(), 10 * 0.99 * 0.99)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_post_processing.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.engines import SupervisedEvaluator\nfrom monai.handlers import PostProcessing\nfrom monai.transforms import Activationsd, AsDiscreted, Compose, CopyItemsd\nfrom tests.test_utils import assert_allclose\n\n# test lambda function as `transform`\nTEST_CASE_1 = [{\"transform\": lambda x: dict(pred=x[\"pred\"] + 1.0)}, False, torch.tensor([[[[1.9975], [1.9997]]]])]\n# test composed postprocessing transforms as `transform`\nTEST_CASE_2 = [\n    {\n        \"transform\": Compose(\n            [\n                CopyItemsd(keys=\"filename\", times=1, names=\"filename_bak\"),\n                AsDiscreted(keys=\"pred\", threshold=0.5, to_onehot=2),\n            ]\n        ),\n        \"event\": \"iteration_completed\",\n    },\n    True,\n    torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]]),\n]\n\n\nclass TestHandlerPostProcessing(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_compute(self, input_params, decollate, expected):\n        data = [\n            {\"image\": torch.tensor([[[[2.0], [3.0]]]]), \"filename\": [\"test1\"]},\n            {\"image\": torch.tensor([[[[6.0], [8.0]]]]), \"filename\": [\"test2\"]},\n        ]\n        # set up engine, PostProcessing handler works together with postprocessing transforms 
of engine\n        engine = SupervisedEvaluator(\n            device=torch.device(\"cpu:0\"),\n            val_data_loader=data,\n            epoch_length=2,\n            network=torch.nn.PReLU(),\n            postprocessing=Compose([Activationsd(keys=\"pred\", sigmoid=True)]),\n            val_handlers=[PostProcessing(**input_params)],\n            decollate=decollate,\n        )\n        engine.run()\n\n        if isinstance(engine.state.output, list):\n            # test decollated list items\n            for o, e in zip(engine.state.output, expected):\n                assert_allclose(o[\"pred\"], e, atol=1e-4, rtol=1e-4, type_test=False)\n                filename = o.get(\"filename_bak\")\n                if filename is not None:\n                    self.assertEqual(filename, \"test2\")\n        else:\n            # test batch data\n            assert_allclose(engine.state.output[\"pred\"], expected, atol=1e-4, rtol=1e-4, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_prob_map_producer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom ignite.engine import Engine\nfrom parameterized import parameterized\n\nfrom monai.data import DataLoader, Dataset, MetaTensor\nfrom monai.engines import Evaluator\nfrom monai.handlers import ProbMapProducer, ValidationHandler\nfrom monai.utils.enums import ProbMapKeys\n\nTEST_CASE_0 = [\"temp_image_inference_output_1\", 1]\nTEST_CASE_1 = [\"temp_image_inference_output_2\", 9]\nTEST_CASE_2 = [\"temp_image_inference_output_3\", 100]\n\n\nclass TestDataset(Dataset):\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def __init__(self, name, size):\n        super().__init__(\n            data=[\n                {\n                    \"image\": name,\n                    ProbMapKeys.COUNT.value: size,\n                    ProbMapKeys.SIZE.value: np.array([size + 1, size + 1]),\n                    ProbMapKeys.LOCATION.value: np.array([i, i + 1]),\n                }\n                for i in range(size)\n            ]\n        )\n        self.image_data = [\n            {\n                ProbMapKeys.NAME.value: name,\n                ProbMapKeys.COUNT.value: size,\n                ProbMapKeys.SIZE.value: np.array([size + 1, size + 1]),\n            }\n        ]\n\n    def __getitem__(self, 
index):\n        image = np.ones((3, 2, 2)) * index\n        metadata = {\n            ProbMapKeys.COUNT.value: self.data[index][ProbMapKeys.COUNT.value],\n            ProbMapKeys.NAME.value: self.data[index][\"image\"],\n            ProbMapKeys.SIZE.value: self.data[index][ProbMapKeys.SIZE.value],\n            ProbMapKeys.LOCATION.value: self.data[index][ProbMapKeys.LOCATION.value],\n        }\n\n        return {\"image\": MetaTensor(x=image, meta=metadata), \"pred\": index + 1}\n\n\nclass TestEvaluator(Evaluator):\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def _iteration(self, engine, batchdata):\n        return batchdata\n\n\nclass TestHandlerProbMapGenerator(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])\n    def test_prob_map_generator(self, name, size):\n        # set up dataset\n        dataset = TestDataset(name, size)\n        batch_size = 2\n        data_loader = DataLoader(dataset, batch_size=batch_size)\n\n        # set up engine\n        def inference(engine, batch):\n            pass\n\n        engine = Engine(inference)\n\n        tests_path = Path(__file__).parents[1].as_posix()\n        # add ProbMapGenerator() to evaluator\n        output_dir = os.path.join(tests_path, \"testing_data\")\n        prob_map_gen = ProbMapProducer(output_dir=output_dir)\n\n        evaluator = TestEvaluator(\n            torch.device(\"cpu:0\"), data_loader, np.ceil(size / batch_size), val_handlers=[prob_map_gen]\n        )\n\n        # set up validation handler\n        validation = ValidationHandler(interval=1, validator=None)\n        validation.attach(engine)\n        validation.set_validator(validator=evaluator)\n\n        engine.run(data_loader)\n\n        prob_map = np.load(os.path.join(output_dir, name + \".npy\"))\n        self.assertListEqual(np.vstack(prob_map.nonzero()).T.tolist(), [[i, i + 1] for i in range(size)])\n        
self.assertListEqual(prob_map[prob_map.nonzero()].tolist(), [i + 1 for i in range(size)])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_regression_metrics.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom functools import partial\n\nimport numpy as np\nimport torch\nfrom ignite.engine import Engine\n\nfrom monai.handlers import MeanAbsoluteError, MeanSquaredError, PeakSignalToNoiseRatio, RootMeanSquaredError\nfrom monai.utils import set_determinism\n\n\n# define a numpy flatten function that only preserves batch dimension\ndef flatten(data):\n    return np.reshape(data, [data.shape[0], -1])\n\n\n# define metrics computation truth functions to check our monai metrics against\ndef msemetric_np(y_pred, y):\n    return np.mean((flatten(y_pred) - flatten(y)) ** 2)\n\n\ndef maemetric_np(y_pred, y):\n    return np.mean(np.abs(flatten(y_pred) - flatten(y)))\n\n\ndef rmsemetric_np(y_pred, y):\n    return np.mean(np.sqrt(np.mean((flatten(y_pred) - flatten(y)) ** 2, axis=1)))\n\n\ndef psnrmetric_np(max_val, y_pred, y):\n    mse = np.mean((flatten(y_pred) - flatten(y)) ** 2, axis=1)\n    return np.mean(20 * np.log10(max_val) - 10 * np.log10(mse))\n\n\nclass TestHandlerRegressionMetrics(unittest.TestCase):\n\n    def test_compute(self):\n        set_determinism(seed=123)\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        # regression metrics to check + truth metric function in numpy\n        metrics = [\n            MeanSquaredError,\n            MeanAbsoluteError,\n            
RootMeanSquaredError,\n            partial(PeakSignalToNoiseRatio, max_val=1.0),\n        ]\n        metrics_np = [msemetric_np, maemetric_np, rmsemetric_np, partial(psnrmetric_np, max_val=1.0)]\n\n        # define variations in batch/base_dims/spatial_dims\n        batch_dims = [1, 2, 4, 16]\n        base_dims = [16, 32, 64]\n        spatial_dims = [2, 3, 4]\n\n        # iterate over all variations and check shapes for different reduction functions\n        for mt_fn, mt_fn_np in zip(metrics, metrics_np):\n            for batch in batch_dims:\n                for spatial in spatial_dims:\n                    for base in base_dims:\n                        mt_fn_obj = mt_fn(save_details=False)\n\n                        # create random tensor\n                        in_tensor_a1 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)\n                        in_tensor_b1 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)\n                        mt_fn_obj.update([in_tensor_a1, in_tensor_b1])\n                        out_tensor_np1 = mt_fn_np(y_pred=in_tensor_a1.cpu().numpy(), y=in_tensor_b1.cpu().numpy())\n\n                        in_tensor_a2 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)\n                        in_tensor_b2 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)\n                        mt_fn_obj.update([in_tensor_a2, in_tensor_b2])\n                        out_tensor_np2 = mt_fn_np(y_pred=in_tensor_a2.cpu().numpy(), y=in_tensor_b2.cpu().numpy())\n\n                        out_tensor = mt_fn_obj.compute()\n                        out_tensor_np = (out_tensor_np1 + out_tensor_np2) / 2.0\n\n                        np.testing.assert_allclose(out_tensor, out_tensor_np, atol=1e-4)\n\n    def test_compute_engine(self):\n        set_determinism(seed=123)\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        # regression metrics to check + truth metric function in numpy\n        metrics_names = 
[\"MSE\", \"MAE\", \"RMSE\", \"PSNR\"]\n        metrics = [\n            MeanSquaredError,\n            MeanAbsoluteError,\n            RootMeanSquaredError,\n            partial(PeakSignalToNoiseRatio, max_val=1.0),\n        ]\n        metrics_np = [msemetric_np, maemetric_np, rmsemetric_np, partial(psnrmetric_np, max_val=1.0)]\n\n        def _val_func(engine, batch):\n            pass\n\n        # define variations in batch/base_dims/spatial_dims\n        batch_dims = [1, 2, 4, 16]\n        base_dims = [16, 32, 64]\n        spatial_dims = [2, 3, 4]\n\n        # iterate over all variations and check shapes for different reduction functions\n        for mt_fn_name, mt_fn, mt_fn_np in zip(metrics_names, metrics, metrics_np):\n            for batch in batch_dims:\n                for spatial in spatial_dims:\n                    for base in base_dims:\n                        mt_fn_obj = mt_fn()  # 'save_details' == True\n                        engine = Engine(_val_func)\n                        mt_fn_obj.attach(engine, mt_fn_name)\n\n                        # create random tensor\n                        in_tensor_a1 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)\n                        in_tensor_b1 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)\n                        mt_fn_obj.update([in_tensor_a1, in_tensor_b1])\n                        out_tensor_np1 = mt_fn_np(y_pred=in_tensor_a1.cpu().numpy(), y=in_tensor_b1.cpu().numpy())\n\n                        in_tensor_a2 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)\n                        in_tensor_b2 = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)\n                        mt_fn_obj.update([in_tensor_a2, in_tensor_b2])\n                        out_tensor_np2 = mt_fn_np(y_pred=in_tensor_a2.cpu().numpy(), y=in_tensor_b2.cpu().numpy())\n\n                        out_tensor = mt_fn_obj.compute()\n                        out_tensor_np = (out_tensor_np1 + 
out_tensor_np2) / 2.0\n\n                        np.testing.assert_allclose(out_tensor, out_tensor_np, atol=1e-4)\n\n    def test_ill_shape(self):\n        set_determinism(seed=123)\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        # regression metrics to check + truth metric function in numpy\n        metrics = [\n            MeanSquaredError,\n            MeanAbsoluteError,\n            RootMeanSquaredError,\n            partial(PeakSignalToNoiseRatio, max_val=1.0),\n        ]\n        basedim = 10\n\n        # different shape for pred/target\n        with self.assertRaises((AssertionError, ValueError)):\n            in_tensor_a = torch.rand((basedim,)).to(device)\n            in_tensor_b = torch.rand((basedim, basedim)).to(device)\n            for mt_fn in metrics:\n                mt_fn().update([in_tensor_a, in_tensor_b])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_regression_metrics_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom ignite.engine import Engine\n\nfrom monai.handlers import MeanAbsoluteError, MeanSquaredError, PeakSignalToNoiseRatio, RootMeanSquaredError\nfrom monai.utils import set_determinism\nfrom tests.test_utils import DistCall, DistTestCase\n\n\n# define a numpy flatten function that only preserves batch dimension\ndef flatten(data):\n    return np.reshape(data, [data.shape[0], -1])\n\n\n# define metrics computation truth functions to check our monai metrics against\ndef msemetric_np(y_pred, y):\n    return np.mean((flatten(y_pred) - flatten(y)) ** 2)\n\n\ndef maemetric_np(y_pred, y):\n    return np.mean(np.abs(flatten(y_pred) - flatten(y)))\n\n\ndef rmsemetric_np(y_pred, y):\n    return np.mean(np.sqrt(np.mean((flatten(y_pred) - flatten(y)) ** 2, axis=1)))\n\n\ndef psnrmetric_np(max_val, y_pred, y):\n    mse = np.mean((flatten(y_pred) - flatten(y)) ** 2, axis=1)\n    return np.mean(20 * np.log10(max_val) - 10 * np.log10(mse))\n\n\n# define tensor size as (BATCH_SIZE, (BASE_DIM_SIZE,) * SPATIAL_DIM)\n# One tensor with following shape takes 4*32*32*32*32/(8*1000) = 512 MB on a single GPU\n# We have total of 2 tensors each on one GPU for following tests, so required GPU memory is 1024 MB on each GPU\n# The required GPU memory can be lowered by changing 
BASE_DIM_SIZE to another value e.g. BASE_DIM_SIZE=16 will\n# require 128 MB on each GPU\nBATCH_SIZE = 4\nBASE_DIM_SIZE = 32\nSPATIAL_DIM = 3\n\n\nclass DistributedMeanSquaredError(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_compute(self):\n        set_determinism(123)\n        self._compute()\n\n    def _compute(self):\n        device = f\"cuda:{dist.get_rank()}\" if torch.cuda.is_available() else \"cpu\"\n        metric = MeanSquaredError()\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        metric.attach(engine, \"MSE\")\n\n        # get testing data\n        batch = BATCH_SIZE\n        base = BASE_DIM_SIZE\n        spatial = SPATIAL_DIM\n        in_tensor_a1 = torch.rand((batch,) + (base,) * (spatial - 1))\n        in_tensor_b1 = torch.rand((batch,) + (base,) * (spatial - 1))\n\n        in_tensor_a2 = torch.rand((batch,) + (base,) * (spatial - 1))\n        in_tensor_b2 = torch.rand((batch,) + (base,) * (spatial - 1))\n\n        if dist.get_rank() == 0:\n            y_pred = in_tensor_a1.to(device)\n            y = in_tensor_b1.to(device)\n            metric.update([y_pred, y])\n\n        if dist.get_rank() == 1:\n            y_pred = in_tensor_a2.to(device)\n            y = in_tensor_b2.to(device)\n            metric.update([y_pred, y])\n\n        out_tensor = metric.compute()\n\n        # do numpy functions to get ground truth referece\n        out_tensor_np1 = msemetric_np(y_pred=in_tensor_a1.cpu().numpy(), y=in_tensor_b1.cpu().numpy())\n        out_tensor_np2 = msemetric_np(y_pred=in_tensor_a2.cpu().numpy(), y=in_tensor_b2.cpu().numpy())\n        out_tensor_np = (out_tensor_np1 + out_tensor_np2) / 2.0\n\n        np.testing.assert_allclose(out_tensor, out_tensor_np, rtol=1e-04, atol=1e-04)\n\n\nclass DistributedMeanAbsoluteError(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_compute(self):\n        set_determinism(123)\n        self._compute()\n\n    def 
_compute(self):\n        device = f\"cuda:{dist.get_rank()}\" if torch.cuda.is_available() else \"cpu\"\n        metric = MeanAbsoluteError()\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        metric.attach(engine, \"MAE\")\n\n        # get testing data\n        batch = BATCH_SIZE\n        base = BASE_DIM_SIZE\n        spatial = SPATIAL_DIM\n        in_tensor_a1 = torch.rand((batch,) + (base,) * (spatial - 1))\n        in_tensor_b1 = torch.rand((batch,) + (base,) * (spatial - 1))\n\n        in_tensor_a2 = torch.rand((batch,) + (base,) * (spatial - 1))\n        in_tensor_b2 = torch.rand((batch,) + (base,) * (spatial - 1))\n\n        if dist.get_rank() == 0:\n            y_pred = in_tensor_a1.to(device)\n            y = in_tensor_b1.to(device)\n            metric.update([y_pred, y])\n\n        if dist.get_rank() == 1:\n            y_pred = in_tensor_a2.to(device)\n            y = in_tensor_b2.to(device)\n            metric.update([y_pred, y])\n\n        out_tensor = metric.compute()\n\n        # do numpy functions to get ground truth referece\n        out_tensor_np1 = maemetric_np(y_pred=in_tensor_a1.cpu().numpy(), y=in_tensor_b1.cpu().numpy())\n        out_tensor_np2 = maemetric_np(y_pred=in_tensor_a2.cpu().numpy(), y=in_tensor_b2.cpu().numpy())\n        out_tensor_np = (out_tensor_np1 + out_tensor_np2) / 2.0\n\n        np.testing.assert_allclose(out_tensor, out_tensor_np, rtol=1e-04, atol=1e-04)\n\n\nclass DistributedRootMeanSquaredError(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_compute(self):\n        set_determinism(123)\n        self._compute()\n\n    def _compute(self):\n        device = f\"cuda:{dist.get_rank()}\" if torch.cuda.is_available() else \"cpu\"\n        metric = RootMeanSquaredError()\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        metric.attach(engine, \"RMSE\")\n\n        # get testing data\n        batch = 
BATCH_SIZE\n        base = BASE_DIM_SIZE\n        spatial = SPATIAL_DIM\n        in_tensor_a1 = torch.rand((batch,) + (base,) * (spatial - 1))\n        in_tensor_b1 = torch.rand((batch,) + (base,) * (spatial - 1))\n\n        in_tensor_a2 = torch.rand((batch,) + (base,) * (spatial - 1))\n        in_tensor_b2 = torch.rand((batch,) + (base,) * (spatial - 1))\n\n        if dist.get_rank() == 0:\n            y_pred = in_tensor_a1.to(device)\n            y = in_tensor_b1.to(device)\n            metric.update([y_pred, y])\n\n        if dist.get_rank() == 1:\n            y_pred = in_tensor_a2.to(device)\n            y = in_tensor_b2.to(device)\n            metric.update([y_pred, y])\n\n        out_tensor = metric.compute()\n\n        # do numpy functions to get ground truth referece\n        out_tensor_np1 = rmsemetric_np(y_pred=in_tensor_a1.cpu().numpy(), y=in_tensor_b1.cpu().numpy())\n        out_tensor_np2 = rmsemetric_np(y_pred=in_tensor_a2.cpu().numpy(), y=in_tensor_b2.cpu().numpy())\n        out_tensor_np = (out_tensor_np1 + out_tensor_np2) / 2.0\n\n        np.testing.assert_allclose(out_tensor, out_tensor_np, rtol=1e-04, atol=1e-04)\n\n\nclass DistributedPeakSignalToNoiseRatio(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_compute(self):\n        set_determinism(123)\n        self._compute()\n\n    def _compute(self):\n        device = f\"cuda:{dist.get_rank()}\" if torch.cuda.is_available() else \"cpu\"\n        max_val = 1.0\n        metric = PeakSignalToNoiseRatio(max_val=max_val)\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        metric.attach(engine, \"PSNR\")\n\n        # get testing data\n        batch = BATCH_SIZE\n        base = BASE_DIM_SIZE\n        spatial = SPATIAL_DIM\n        in_tensor_a1 = torch.rand((batch,) + (base,) * (spatial - 1))\n        in_tensor_b1 = torch.rand((batch,) + (base,) * (spatial - 1))\n\n        in_tensor_a2 = torch.rand((batch,) + (base,) * (spatial 
- 1))\n        in_tensor_b2 = torch.rand((batch,) + (base,) * (spatial - 1))\n\n        if dist.get_rank() == 0:\n            y_pred = in_tensor_a1.to(device)\n            y = in_tensor_b1.to(device)\n            metric.update([y_pred, y])\n\n        if dist.get_rank() == 1:\n            y_pred = in_tensor_a2.to(device)\n            y = in_tensor_b2.to(device)\n            metric.update([y_pred, y])\n\n        out_tensor = metric.compute()\n\n        # do numpy functions to get ground truth referece\n        out_tensor_np1 = psnrmetric_np(max_val=max_val, y_pred=in_tensor_a1.cpu().numpy(), y=in_tensor_b1.cpu().numpy())\n        out_tensor_np2 = psnrmetric_np(max_val=max_val, y_pred=in_tensor_a2.cpu().numpy(), y=in_tensor_b2.cpu().numpy())\n        out_tensor_np = (out_tensor_np1 + out_tensor_np2) / 2.0\n\n        np.testing.assert_allclose(out_tensor, out_tensor_np, rtol=1e-04, atol=1e-04)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_rocauc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.handlers import ROCAUC\nfrom monai.transforms import Activations, AsDiscrete\n\n\nclass TestHandlerROCAUC(unittest.TestCase):\n\n    def test_compute(self):\n        auc_metric = ROCAUC()\n        act = Activations(softmax=True)\n        to_onehot = AsDiscrete(to_onehot=2)\n\n        y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])]\n        y = [torch.Tensor([0]), torch.Tensor([1])]\n        y_pred = [act(p) for p in y_pred]\n        y = [to_onehot(y_) for y_ in y]\n        auc_metric.update([y_pred, y])\n\n        y_pred = [torch.Tensor([0.2, 0.1]), torch.Tensor([0.1, 0.5])]\n        y = [torch.Tensor([0]), torch.Tensor([1])]\n        y_pred = [act(p) for p in y_pred]\n        y = [to_onehot(y_) for y_ in y]\n\n        auc_metric.update([y_pred, y])\n\n        auc = auc_metric.compute()\n        np.testing.assert_allclose(0.75, auc)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_rocauc_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom monai.handlers import ROCAUC\nfrom monai.transforms import Activations, AsDiscrete\nfrom tests.test_utils import DistCall, DistTestCase\n\n\nclass DistributedROCAUC(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)\n    def test_compute(self):\n        auc_metric = ROCAUC()\n        act = Activations(softmax=True)\n        to_onehot = AsDiscrete(to_onehot=2)\n\n        device = f\"cuda:{dist.get_rank()}\" if torch.cuda.is_available() else \"cpu\"\n        if dist.get_rank() == 0:\n            y_pred = [torch.tensor([0.1, 0.9], device=device), torch.tensor([0.3, 1.4], device=device)]\n            y = [torch.tensor([0], device=device), torch.tensor([1], device=device)]\n\n        if dist.get_rank() == 1:\n            y_pred = [\n                torch.tensor([0.2, 0.1], device=device),\n                torch.tensor([0.1, 0.5], device=device),\n                torch.tensor([0.3, 0.4], device=device),\n            ]\n            y = [torch.tensor([0], device=device), torch.tensor([1], device=device), torch.tensor([1], device=device)]\n\n        y_pred = [act(p) for p in y_pred]\n        y = [to_onehot(y_) for y_ in y]\n        auc_metric.update([y_pred, y])\n\n        result = auc_metric.compute()\n        
np.testing.assert_allclose(0.66667, result, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_stats.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport re\nimport tempfile\nimport unittest\nfrom io import StringIO\n\nimport torch\nfrom ignite.engine import Engine, Events\nfrom parameterized import parameterized\n\nfrom monai.handlers import StatsHandler\n\n\ndef get_event_filter(e):\n\n    def event_filter(_, event):\n        if event in e:\n            return True\n        return False\n\n    return event_filter\n\n\nclass TestHandlerStats(unittest.TestCase):\n\n    @parameterized.expand([[True], [get_event_filter([1, 2])]])\n    def test_metrics_print(self, epoch_log):\n        log_stream = StringIO()\n        log_handler = logging.StreamHandler(log_stream)\n        log_handler.setLevel(logging.INFO)\n        key_to_handler = \"test_logging\"\n        key_to_print = \"testing_metric\"\n\n        # set up engine\n        def _train_func(engine, batch):\n            return [torch.tensor(0.0)]\n\n        engine = Engine(_train_func)\n\n        # set up dummy metric\n        @engine.on(Events.EPOCH_COMPLETED)\n        def _update_metric(engine):\n            current_metric = engine.state.metrics.get(key_to_print, 0.1)\n            engine.state.metrics[key_to_print] = current_metric + 0.1\n\n        # set up testing handler\n        logger = logging.getLogger(key_to_handler)\n        logger.setLevel(logging.INFO)\n        
logger.addHandler(log_handler)\n        stats_handler = StatsHandler(iteration_log=False, epoch_log=epoch_log, name=key_to_handler)\n        stats_handler.attach(engine)\n\n        max_epochs = 4\n        engine.run(range(3), max_epochs=max_epochs)\n\n        # check logging output\n        output_str = log_stream.getvalue()\n        log_handler.close()\n        has_key_word = re.compile(f\".*{key_to_print}.*\")\n        content_count = 0\n        for line in output_str.split(\"\\n\"):\n            if has_key_word.match(line):\n                content_count += 1\n        if epoch_log is True:\n            self.assertEqual(content_count, max_epochs)\n        else:\n            self.assertEqual(content_count, 2)  # 2 = len([1, 2]) from event_filter\n\n    @parameterized.expand([[True], [get_event_filter([1, 3])]])\n    def test_loss_print(self, iteration_log):\n        log_stream = StringIO()\n        log_handler = logging.StreamHandler(log_stream)\n        log_handler.setLevel(logging.INFO)\n        key_to_handler = \"test_logging\"\n        key_to_print = \"myLoss\"\n\n        # set up engine\n        def _train_func(engine, batch):\n            return [torch.tensor(0.0)]\n\n        engine = Engine(_train_func)\n\n        # set up testing handler\n        logger = logging.getLogger(key_to_handler)\n        logger.setLevel(logging.INFO)\n        logger.addHandler(log_handler)\n        stats_handler = StatsHandler(\n            iteration_log=iteration_log, epoch_log=False, name=key_to_handler, tag_name=key_to_print\n        )\n        stats_handler.attach(engine)\n\n        num_iters = 3\n        max_epochs = 2\n        engine.run(range(num_iters), max_epochs=max_epochs)\n\n        # check logging output\n        output_str = log_stream.getvalue()\n        log_handler.close()\n        has_key_word = re.compile(f\".*{key_to_print}.*\")\n        content_count = 0\n        for line in output_str.split(\"\\n\"):\n            if has_key_word.match(line):\n                
content_count += 1\n        if iteration_log is True:\n            self.assertEqual(content_count, num_iters * max_epochs)\n        else:\n            self.assertEqual(content_count, 2)  # 2 = len([1, 3]) from event_filter\n\n    def test_loss_dict(self):\n        log_stream = StringIO()\n        log_handler = logging.StreamHandler(log_stream)\n        log_handler.setLevel(logging.INFO)\n        key_to_handler = \"test_logging\"\n        key_to_print = \"myLoss1\"\n\n        # set up engine\n        def _train_func(engine, batch):\n            return [torch.tensor(0.0)]\n\n        engine = Engine(_train_func)\n\n        # set up testing handler\n        logger = logging.getLogger(key_to_handler)\n        logger.setLevel(logging.INFO)\n        logger.addHandler(log_handler)\n        stats_handler = StatsHandler(name=key_to_handler, output_transform=lambda x: {key_to_print: x[0]})\n        stats_handler.attach(engine)\n\n        engine.run(range(3), max_epochs=2)\n\n        # check logging output\n        output_str = log_stream.getvalue()\n        log_handler.close()\n        has_key_word = re.compile(f\".*{key_to_print}.*\")\n        content_count = 0\n        for line in output_str.split(\"\\n\"):\n            if has_key_word.match(line):\n                content_count += 1\n        self.assertGreater(content_count, 0)\n\n    def test_loss_file(self):\n        key_to_handler = \"test_logging\"\n        key_to_print = \"myLoss\"\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_loss_stats.log\")\n            handler = logging.FileHandler(filename, mode=\"w\")\n            handler.setLevel(logging.INFO)\n\n            # set up engine\n            def _train_func(engine, batch):\n                return [torch.tensor(0.0)]\n\n            engine = Engine(_train_func)\n\n            # set up testing handler\n            logger = logging.getLogger(key_to_handler)\n            logger.setLevel(logging.INFO)\n     
       logger.addHandler(handler)\n            stats_handler = StatsHandler(name=key_to_handler, tag_name=key_to_print)\n            stats_handler.attach(engine)\n\n            engine.run(range(3), max_epochs=2)\n            handler.close()\n            stats_handler.logger.removeHandler(handler)\n            with open(filename) as f:\n                output_str = f.read()\n                has_key_word = re.compile(f\".*{key_to_print}.*\")\n                content_count = 0\n                for line in output_str.split(\"\\n\"):\n                    if has_key_word.match(line):\n                        content_count += 1\n                self.assertGreater(content_count, 0)\n\n    def test_exception(self):\n        # set up engine\n        def _train_func(engine, batch):\n            raise RuntimeError(\"test exception.\")\n\n        engine = Engine(_train_func)\n\n        # set up testing handler\n        stats_handler = StatsHandler()\n        stats_handler.attach(engine)\n\n        with self.assertRaises(RuntimeError):\n            engine.run(range(3), max_epochs=2)\n\n    def test_attributes_print(self):\n        log_stream = StringIO()\n        log_handler = logging.StreamHandler(log_stream)\n        log_handler.setLevel(logging.INFO)\n        key_to_handler = \"test_logging\"\n\n        # set up engine\n        def _train_func(engine, batch):\n            return [torch.tensor(0.0)]\n\n        engine = Engine(_train_func)\n\n        # set up dummy metric\n        @engine.on(Events.EPOCH_COMPLETED)\n        def _update_metric(engine):\n            if not hasattr(engine.state, \"test1\"):\n                engine.state.test1 = 0.1\n                engine.state.test2 = 0.2\n            else:\n                engine.state.test1 += 0.1\n                engine.state.test2 += 0.2\n\n        # set up testing handler\n        logger = logging.getLogger(key_to_handler)\n        logger.setLevel(logging.INFO)\n        logger.addHandler(log_handler)\n        stats_handler = 
StatsHandler(name=key_to_handler, state_attributes=[\"test1\", \"test2\", \"test3\"])\n        stats_handler.attach(engine)\n\n        engine.run(range(3), max_epochs=2)\n\n        # check logging output\n        output_str = log_stream.getvalue()\n        log_handler.close()\n        has_key_word = re.compile(\".*State values.*\")\n        content_count = 0\n        for line in output_str.split(\"\\n\"):\n            if has_key_word.match(line):\n                content_count += 1\n        self.assertGreater(content_count, 0)\n\n    def test_default_logger(self):\n        log_stream = StringIO()\n        log_handler = logging.StreamHandler(log_stream)\n        log_handler.setLevel(logging.INFO)\n        key_to_print = \"myLoss\"\n\n        # set up engine\n        def _train_func(engine, batch):\n            return [torch.tensor(0.0)]\n\n        engine = Engine(_train_func)\n        engine.logger.addHandler(log_handler)\n\n        # set up testing handler\n        stats_handler = StatsHandler(name=None, tag_name=key_to_print)\n        engine.logger.setLevel(logging.WARNING)\n        with self.assertWarns(Warning):  # engine logging level warn\n            stats_handler.attach(engine)\n        # leverage `engine.logger` to print info\n        engine.logger.setLevel(logging.INFO)\n        level = logging.root.getEffectiveLevel()\n        logging.basicConfig(level=logging.INFO)\n        engine.run(range(3), max_epochs=2)\n        logging.basicConfig(level=level)\n\n        # check logging output\n        output_str = log_stream.getvalue()\n        log_handler.close()\n        has_key_word = re.compile(f\".*{key_to_print}.*\")\n        content_count = 0\n        for line in output_str.split(\"\\n\"):\n            if has_key_word.match(line):\n                content_count += 1\n        self.assertGreater(content_count, 0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_surface_distance.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom ignite.engine import Engine\n\nfrom monai.handlers import SurfaceDistance\nfrom tests.test_utils import assert_allclose\n\n\ndef create_spherical_seg_3d(\n    radius: float = 20.0, centre: tuple[int, int, int] = (49, 49, 49), im_shape: tuple[int, int, int] = (99, 99, 99)\n) -> np.ndarray:\n    \"\"\"\n    Return a 3D image with a sphere inside. 
Voxel values will be\n    1 inside the sphere, and 0 elsewhere.\n\n    Args:\n        radius: radius of sphere (in terms of number of voxels, can be partial)\n        centre: location of sphere centre.\n        im_shape: shape of image to create\n\n    See also:\n        :py:meth:`~create_test_image_3d`\n    \"\"\"\n    # Create image\n    image = np.zeros(im_shape, dtype=np.int32)\n    spy, spx, spz = np.ogrid[\n        -centre[0] : im_shape[0] - centre[0], -centre[1] : im_shape[1] - centre[1], -centre[2] : im_shape[2] - centre[2]\n    ]\n    circle = (spx * spx + spy * spy + spz * spz) <= radius * radius\n\n    image[circle] = 1\n    image[~circle] = 0\n    return image\n\n\nsampler_sphere = torch.Tensor(create_spherical_seg_3d(radius=20, centre=(20, 20, 20))).unsqueeze(0).unsqueeze(0)\n# test input a list of channel-first tensor\nsampler_sphere_gt = [torch.Tensor(create_spherical_seg_3d(radius=20, centre=(10, 20, 20))).unsqueeze(0)]\nsampler_sphere_zeros = torch.zeros_like(sampler_sphere)\n\nTEST_SAMPLE_1 = [sampler_sphere, sampler_sphere_gt]\nTEST_SAMPLE_2 = [sampler_sphere_gt, sampler_sphere_gt]\nTEST_SAMPLE_3 = [sampler_sphere_zeros, sampler_sphere_gt]\nTEST_SAMPLE_4 = [sampler_sphere_zeros, sampler_sphere_zeros]\n\n\nclass TestHandlerSurfaceDistance(unittest.TestCase):\n    # TODO test multi node Surface Distance\n\n    def test_compute(self):\n        sur_metric = SurfaceDistance(include_background=True)\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        sur_metric.attach(engine, \"surface_distance\")\n\n        y_pred, y = TEST_SAMPLE_1\n        sur_metric.update([y_pred, y])\n        self.assertAlmostEqual(sur_metric.compute(), 4.17133, places=4)\n        y_pred, y = TEST_SAMPLE_2\n        sur_metric.update([y_pred, y])\n        self.assertAlmostEqual(sur_metric.compute(), 2.08566, places=4)\n        y_pred, y = TEST_SAMPLE_3\n        sur_metric.update([y_pred, y])\n        
self.assertAlmostEqual(sur_metric.compute(), float(\"inf\"))\n        y_pred, y = TEST_SAMPLE_4\n        sur_metric.update([y_pred, y])\n        self.assertAlmostEqual(sur_metric.compute(), float(\"inf\"))\n\n    def test_shape_mismatch(self):\n        sur_metric = SurfaceDistance(include_background=True)\n        with self.assertRaises((AssertionError, ValueError)):\n            y_pred = TEST_SAMPLE_1[0]\n            y = torch.ones((1, 1, 10, 10, 10))\n            sur_metric.update([y_pred, y])\n\n    def test_reduction(self):\n        sur_metric = SurfaceDistance(include_background=True, reduction=\"mean_channel\")\n\n        def _val_func(engine, batch):\n            pass\n\n        engine = Engine(_val_func)\n        sur_metric.attach(engine, \"surface_distance\")\n\n        y_pred, y = TEST_SAMPLE_1\n        sur_metric.update([y_pred, y])\n        y_pred, y = TEST_SAMPLE_2\n        sur_metric.update([y_pred, y])\n        assert_allclose(\n            sur_metric.compute().float(), torch.tensor([4.1713, 0.0000]), atol=1e-4, rtol=1e-4, type_test=False\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_tb_image.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport glob\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport torch\nfrom ignite.engine import Engine, Events\nfrom parameterized import parameterized\n\nfrom monai.data import decollate_batch\nfrom monai.handlers import TensorBoardImageHandler\nfrom monai.utils import optional_import\n\n_, has_tb = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n\nTEST_CASES = [[[20, 20]], [[2, 20, 20]], [[3, 20, 20]], [[20, 20, 20]], [[2, 20, 20, 20]], [[2, 2, 20, 20, 20]]]\n\n\n@unittest.skipUnless(has_tb, \"Requires SummaryWriter installation\")\nclass TestHandlerTBImage(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_tb_image_shape(self, shape):\n        with tempfile.TemporaryDirectory() as tempdir:\n            # set up engine\n            def _train_func(engine, batch):\n                engine.state.batch = decollate_batch(list(batch))\n                return [torch.zeros((1, 10, 10))]\n\n            engine = Engine(_train_func)\n\n            # set up testing handler\n            stats_handler = TensorBoardImageHandler(log_dir=tempdir)\n            engine.add_event_handler(Events.ITERATION_COMPLETED, stats_handler)\n\n            data = zip(\n                torch.as_tensor(np.random.normal(size=(10, 4, *shape))),\n                torch.as_tensor(np.random.normal(size=(10, 4, 
*shape))),\n            )\n            engine.run(data, epoch_length=10, max_epochs=1)\n            stats_handler.close()\n\n            self.assertTrue(len(glob.glob(tempdir)) > 0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_tb_stats.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport glob\nimport tempfile\nimport unittest\nfrom unittest.mock import MagicMock\n\nfrom ignite.engine import Engine, Events\nfrom parameterized import parameterized\n\nfrom monai.handlers import TensorBoardStatsHandler\nfrom monai.utils import optional_import\n\nSummaryWriter, has_tb = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n\n\ndef get_event_filter(e):\n\n    def event_filter(_, event):\n        if event in e:\n            return True\n        return False\n\n    return event_filter\n\n\n@unittest.skipUnless(has_tb, \"Requires SummaryWriter installation\")\nclass TestHandlerTBStats(unittest.TestCase):\n\n    def test_metrics_print(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            # set up engine\n            def _train_func(engine, batch):\n                return [batch + 1.0]\n\n            engine = Engine(_train_func)\n\n            # set up dummy metric\n            @engine.on(Events.EPOCH_COMPLETED)\n            def _update_metric(engine):\n                current_metric = engine.state.metrics.get(\"acc\", 0.1)\n                engine.state.metrics[\"acc\"] = current_metric + 0.1\n\n            # set up testing handler\n            stats_handler = TensorBoardStatsHandler(log_dir=tempdir, iteration_log=False, epoch_log=True)\n            stats_handler.attach(engine)\n      
      engine.run(range(3), max_epochs=2)\n            stats_handler.close()\n            # check logging output\n            self.assertTrue(len(glob.glob(tempdir)) > 0)\n\n    @parameterized.expand([[True], [get_event_filter([1, 2])]])\n    def test_metrics_print_mock(self, epoch_log):\n        with tempfile.TemporaryDirectory() as tempdir:\n            # set up engine\n            def _train_func(engine, batch):\n                return [batch + 1.0]\n\n            engine = Engine(_train_func)\n\n            # set up dummy metric\n            @engine.on(Events.EPOCH_COMPLETED)\n            def _update_metric(engine):\n                current_metric = engine.state.metrics.get(\"acc\", 0.1)\n                engine.state.metrics[\"acc\"] = current_metric + 0.1\n\n            # set up testing handler\n            stats_handler = TensorBoardStatsHandler(log_dir=tempdir, iteration_log=False, epoch_log=epoch_log)\n            stats_handler._default_epoch_writer = MagicMock()\n            stats_handler.attach(engine)\n\n            max_epochs = 4\n            engine.run(range(3), max_epochs=max_epochs)\n            stats_handler.close()\n            # check logging output\n            if epoch_log is True:\n                self.assertEqual(stats_handler._default_epoch_writer.call_count, max_epochs)\n            else:\n                self.assertEqual(stats_handler._default_epoch_writer.call_count, 2)  # 2 = len([1, 2]) from event_filter\n\n    def test_metrics_writer(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            # set up engine\n            def _train_func(engine, batch):\n                return [batch + 1.0]\n\n            engine = Engine(_train_func)\n\n            # set up dummy metric\n            @engine.on(Events.EPOCH_COMPLETED)\n            def _update_metric(engine):\n                current_metric = engine.state.metrics.get(\"acc\", 0.1)\n                engine.state.metrics[\"acc\"] = current_metric + 0.1\n                
engine.state.test = current_metric\n\n            # set up testing handler\n            writer = SummaryWriter(log_dir=tempdir)\n            stats_handler = TensorBoardStatsHandler(\n                summary_writer=writer,\n                iteration_log=True,\n                epoch_log=False,\n                output_transform=lambda x: {\"loss\": x[0] * 2.0},\n                global_epoch_transform=lambda x: x * 3.0,\n                state_attributes=[\"test\"],\n            )\n            stats_handler.attach(engine)\n            engine.run(range(3), max_epochs=2)\n            writer.close()\n            # check logging output\n            self.assertTrue(len(glob.glob(tempdir)) > 0)\n\n    @parameterized.expand([[True], [get_event_filter([1, 3])]])\n    def test_metrics_writer_mock(self, iteration_log):\n        with tempfile.TemporaryDirectory() as tempdir:\n            # set up engine\n            def _train_func(engine, batch):\n                return [batch + 1.0]\n\n            engine = Engine(_train_func)\n\n            # set up dummy metric\n            @engine.on(Events.EPOCH_COMPLETED)\n            def _update_metric(engine):\n                current_metric = engine.state.metrics.get(\"acc\", 0.1)\n                engine.state.metrics[\"acc\"] = current_metric + 0.1\n                engine.state.test = current_metric\n\n            # set up testing handler\n            writer = SummaryWriter(log_dir=tempdir)\n            stats_handler = TensorBoardStatsHandler(\n                summary_writer=writer,\n                iteration_log=iteration_log,\n                epoch_log=False,\n                output_transform=lambda x: {\"loss\": x[0] * 2.0},\n                global_epoch_transform=lambda x: x * 3.0,\n                state_attributes=[\"test\"],\n            )\n            stats_handler._default_iteration_writer = MagicMock()\n            stats_handler.attach(engine)\n\n            num_iters = 3\n            max_epochs = 2\n            
engine.run(range(num_iters), max_epochs=max_epochs)\n            writer.close()\n\n            if iteration_log is True:\n                self.assertEqual(stats_handler._default_iteration_writer.call_count, num_iters * max_epochs)\n            else:\n                self.assertEqual(\n                    stats_handler._default_iteration_writer.call_count, 2\n                )  # 2 = len([1, 3]) from event_filter\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_handler_validation.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom ignite.engine import Engine\n\nfrom monai.data import Dataset\nfrom monai.engines import Evaluator\nfrom monai.handlers import ValidationHandler\n\n\nclass TestEvaluator(Evaluator):\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def _iteration(self, engine, batchdata):\n        engine.state.output = \"called\"\n        return engine.state.output\n\n\nclass TestHandlerValidation(unittest.TestCase):\n\n    def test_content(self):\n        data = [0] * 8\n\n        # set up engine\n        def _train_func(engine, batch):\n            pass\n\n        engine = Engine(_train_func)\n\n        # set up testing handler\n        val_data_loader = torch.utils.data.DataLoader(Dataset(data))\n        evaluator = TestEvaluator(torch.device(\"cpu:0\"), val_data_loader)\n        ValidationHandler(interval=2, validator=evaluator, exec_at_start=True).attach(engine)\n        # test execution at start\n        engine.run(data, max_epochs=1)\n        self.assertEqual(evaluator.state.max_epochs, 1)\n        self.assertEqual(evaluator.state.epoch_length, 8)\n        self.assertEqual(evaluator.state.output, \"called\")\n\n        engine.run(data, max_epochs=5)\n        self.assertEqual(evaluator.state.max_epochs, 4)\n        self.assertEqual(evaluator.state.epoch_length, 
8)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_trt_compile.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport tempfile\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.handlers import TrtHandler\nfrom monai.networks import trt_compile\nfrom monai.networks.nets import cell_sam_wrapper, vista3d132\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows\n\ntrt, trt_imported = optional_import(\"tensorrt\", \"10.1.0\", min_version)\ntorch_tensorrt, torch_trt_imported = optional_import(\"torch_tensorrt\")\npolygraphy, polygraphy_imported = optional_import(\"polygraphy\")\nbuild_sam_vit_b, has_sam = optional_import(\"segment_anything.build_sam\", name=\"build_sam_vit_b\")\n_, has_cudart = optional_import(\"cuda.bindings.runtime\")\nif not has_cudart:\n    _, has_cudart = optional_import(\"cuda.cudart\")\n\nTEST_CASE_1 = [\"fp32\"]\nTEST_CASE_2 = [\"fp16\"]\n\n\nclass ListAdd(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x: list[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = 0.1):\n        y1 = y.clone()\n        x1 = x.copy()\n        z1 = z + y\n        for xi in x:\n            y1 = y1 + xi + bs\n        return x1, [y1, z1], y1 + z1\n\n\n@skip_if_windows\n@skip_if_no_cuda\n@skip_if_quick\n@unittest.skipUnless(trt_imported, 
\"tensorrt is required\")\n@unittest.skipUnless(polygraphy_imported, \"polygraphy is required\")\n@unittest.skipUnless(has_cudart, \"cuda-python or cuda-bindings is required\")\n@SkipIfBeforeComputeCapabilityVersion((7, 5))\nclass TestTRTCompile(unittest.TestCase):\n    def setUp(self):\n        self.gpu_device = torch.cuda.current_device()\n\n    def tearDown(self):\n        current_device = torch.cuda.current_device()\n        if current_device != self.gpu_device:\n            torch.cuda.set_device(self.gpu_device)\n\n    # @unittest.skipUnless(torch_trt_imported, \"torch_tensorrt is required\")\n    def test_handler(self):\n        from ignite.engine import Engine\n\n        net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()])\n        data1 = net1.state_dict()\n        data1[\"0.weight\"] = torch.tensor([0.1])\n        data1[\"1.weight\"] = torch.tensor([0.2])\n        net1.load_state_dict(data1)\n        net1.cuda()\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            engine = Engine(lambda e, b: None)\n            args = {\"method\": \"onnx\", \"dynamic_batchsize\": [1, 4, 8]}\n            TrtHandler(net1, tempdir + \"/trt_handler\", args=args).attach(engine)\n            engine.run([0] * 8, max_epochs=1)\n            self.assertIsNotNone(net1._trt_compiler)\n            self.assertIsNone(net1._trt_compiler.engine)\n            net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=\"cuda\"))\n            self.assertIsNotNone(net1._trt_compiler.engine)\n\n    def test_lists(self):\n        model = ListAdd().cuda()\n\n        with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:\n            args = {\n                \"output_lists\": [[-1], [2], []],\n                \"export_args\": {\"dynamo\": False, \"verbose\": True},\n                \"dynamic_batchsize\": [1, 4, 8],\n            }\n            x = torch.randn(1, 16).to(\"cuda\")\n            y = torch.randn(1, 16).to(\"cuda\")\n            z = torch.randn(1, 
16).to(\"cuda\")\n            input_example = ([x, y, z], y.clone(), z.clone())\n            output_example = model(*input_example)\n            trt_compile(model, f\"{tmpdir}/test_lists\", args=args)\n            self.assertIsNone(model._trt_compiler.engine)\n            trt_output = model(*input_example)\n            # Check that lazy TRT build succeeded\n            self.assertIsNotNone(model._trt_compiler.engine)\n            torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    @unittest.skipUnless(has_sam, \"Requires SAM installation\")\n    def test_cell_sam_wrapper_value(self, precision):\n        model = cell_sam_wrapper.CellSamWrapper(checkpoint=None).to(\"cuda\")\n        with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:\n            model.eval()\n            input_example = torch.randn(1, 3, 128, 128).to(\"cuda\")\n            output_example = model(input_example)\n            trt_compile(model, f\"{tmpdir}/test_cell_sam_wrapper_trt_compile\", args={\"precision\": precision})\n            self.assertIsNone(model._trt_compiler.engine)\n            trt_output = model(input_example)\n            # Check that lazy TRT build succeeded\n            self.assertIsNotNone(model._trt_compiler.engine)\n            torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_vista3d(self, precision):\n        model = vista3d132(in_channels=1).to(\"cuda\")\n        with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:\n            model.eval()\n            input_example = torch.randn(1, 1, 64, 64, 64).to(\"cuda\")\n            output_example = model(input_example)\n            model = trt_compile(\n                model,\n                f\"{tmpdir}/test_vista3d_trt_compile\",\n                args={\"precision\": precision, \"dynamic_batchsize\": [1, 2, 4]},\n                
submodule=[\"image_encoder.encoder\", \"class_head\"],\n            )\n            self.assertIsNotNone(model.image_encoder.encoder._trt_compiler)\n            self.assertIsNotNone(model.class_head._trt_compiler)\n            trt_output = model.forward(input_example)\n            # Check that lazy TRT build succeeded\n            # TODO: set up input_example in such a way that image_encoder.encoder and class_head are called\n            # and uncomment the asserts below\n            # self.assertIsNotNone(model.image_encoder.encoder._trt_compiler.engine)\n            # self.assertIsNotNone(model.class_head._trt_compiler.engine)\n            torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/handlers/test_write_metrics_reports.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport csv\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nimport torch\n\nfrom monai.handlers.utils import write_metrics_reports\n\n\nclass TestWriteMetricsReports(unittest.TestCase):\n\n    def test_content(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            write_metrics_reports(\n                save_dir=Path(tempdir),\n                images=[\"filepath1\", \"filepath2\"],\n                metrics={\"metric1\": 1, \"metric2\": 2},\n                metric_details={\"metric3\": torch.tensor([[1, 2], [2, 3]]), \"metric4\": torch.tensor([[5, 6], [7, 8]])},\n                summary_ops=[\"mean\", \"median\", \"max\", \"90percentile\"],\n                deli=\"\\t\",\n                output_type=\"csv\",\n            )\n\n            # check the metrics.csv and content\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metrics.csv\")))\n            with open(os.path.join(tempdir, \"metrics.csv\")) as f:\n                f_csv = csv.reader(f)\n                for i, row in enumerate(f_csv):\n                    self.assertEqual(row, [f\"metric{i + 1}\\t{i + 1}\"])\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metric3_raw.csv\")))\n            # check the metric_raw.csv and content\n            with open(os.path.join(tempdir, \"metric3_raw.csv\")) as 
f:\n                f_csv = csv.reader(f)\n                for i, row in enumerate(f_csv):\n                    if i > 0:\n                        self.assertEqual(row, [f\"filepath{i}\\t{float(i):.4f}\\t{float(i + 1):.4f}\\t{i + 0.5:.4f}\"])\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metric3_summary.csv\")))\n            # check the metric_summary.csv and content\n            with open(os.path.join(tempdir, \"metric3_summary.csv\")) as f:\n                f_csv = csv.reader(f)\n                for i, row in enumerate(f_csv):\n                    if i == 1:\n                        self.assertEqual(row, [\"class0\\t1.5000\\t1.5000\\t2.0000\\t1.9000\"])\n                    elif i == 2:\n                        self.assertEqual(row, [\"class1\\t2.5000\\t2.5000\\t3.0000\\t2.9000\"])\n                    elif i == 3:\n                        self.assertEqual(row, [\"mean\\t2.0000\\t2.0000\\t2.5000\\t2.4000\"])\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metric4_raw.csv\")))\n            self.assertTrue(os.path.exists(os.path.join(tempdir, \"metric4_summary.csv\")))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/hvd_evenly_divisible_all_gather.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom monai.utils import evenly_divisible_all_gather\nfrom monai.utils.module import optional_import\nfrom tests.test_utils import assert_allclose\n\nhvd, has_hvd = optional_import(\"horovod\", name=\"torch\")\n\n\nclass HvdEvenlyDivisibleAllGather:\n    def test_data(self):\n        # initialize Horovod\n        hvd.init()\n        if torch.cuda.is_available():\n            torch.cuda.set_device(hvd.local_rank())\n        self._run()\n\n    def _run(self):\n        # if hvd.rank() == 0:\n        data1 = torch.tensor([[1, 2], [3, 4]])\n        data2 = torch.tensor([[1.0, 2.0]])\n        data3 = torch.tensor(7)\n\n        if hvd.rank() == 1:\n            data1 = torch.tensor([[5, 6]])\n            data2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]])\n            data3 = torch.tensor(8)\n\n        result1 = evenly_divisible_all_gather(data=data1, concat=True)\n        assert_allclose(result1, torch.tensor([[1, 2], [3, 4], [5, 6]]))\n        result2 = evenly_divisible_all_gather(data=data2, concat=False)\n        for r, e in zip(result2, [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0], [5.0, 6.0]])]):\n            assert_allclose(r, e)\n        result3 = evenly_divisible_all_gather(data=data3, concat=False)\n        for r in result3:\n            assert_allclose(r.ndimension(), 0)\n\n\nif __name__ == \"__main__\":\n    
\"\"\"\n    1. Install Horovod:\n    `HOROVOD_NCCL_INCLUDE=/usr/include HOROVOD_NCCL_LIB=/usr/lib/x86_64-linux-gnu HOROVOD_GPU_OPERATIONS=NCCL \\\n    HOROVOD_NCCL_LINK=SHARED pip install --no-cache-dir horovod`\n\n    2. Execute on 2 GPUs in a single machine:\n    `horovodrun -np 2 python test_evenly_divisible_all_gather_hvd.py`\n\n    \"\"\"\n    HvdEvenlyDivisibleAllGather().test_data()\n"
  },
  {
    "path": "tests/inferers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/inferers/test_avg_merger.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\nfrom torch.nn.functional import pad\n\nfrom monai.inferers import AvgMerger\nfrom tests.test_utils import assert_allclose\n\nTENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32)\nTENSOR_4x4_WITH_NAN = TENSOR_4x4.clone()\nTENSOR_4x4_WITH_NAN[..., 2:, 2:] = float(\"nan\")\n\n# no-overlapping 2x2\nTEST_CASE_0_DEFAULT_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# overlapping 2x2\nTEST_CASE_1_DEFAULT_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape),\n    [\n        (TENSOR_4x4[..., 0:2, 0:2], (0, 0)),\n        (TENSOR_4x4[..., 0:2, 1:3], (0, 1)),\n        (TENSOR_4x4[..., 0:2, 2:4], (0, 2)),\n        (TENSOR_4x4[..., 1:3, 0:2], (1, 0)),\n        (TENSOR_4x4[..., 1:3, 1:3], (1, 1)),\n        (TENSOR_4x4[..., 1:3, 2:4], (1, 2)),\n        (TENSOR_4x4[..., 2:4, 0:2], (2, 0)),\n        (TENSOR_4x4[..., 2:4, 1:3], (2, 1)),\n        (TENSOR_4x4[..., 2:4, 2:4], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# overlapping 3x3 (non-divisible)\nTEST_CASE_2_DEFAULT_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape),\n    [\n        
(TENSOR_4x4[..., :3, :3], (0, 0)),\n        (TENSOR_4x4[..., :3, 1:], (0, 1)),\n        (TENSOR_4x4[..., 1:, :3], (1, 0)),\n        (TENSOR_4x4[..., 1:, 1:], (1, 1)),\n    ],\n    TENSOR_4x4,\n]\n\n#  overlapping 2x2 with NaN values\nTEST_CASE_3_DEFAULT_DTYPE = [\n    dict(merged_shape=TENSOR_4x4_WITH_NAN.shape),\n    [\n        (TENSOR_4x4_WITH_NAN[..., 0:2, 0:2], (0, 0)),\n        (TENSOR_4x4_WITH_NAN[..., 0:2, 1:3], (0, 1)),\n        (TENSOR_4x4_WITH_NAN[..., 0:2, 2:4], (0, 2)),\n        (TENSOR_4x4_WITH_NAN[..., 1:3, 0:2], (1, 0)),\n        (TENSOR_4x4_WITH_NAN[..., 1:3, 1:3], (1, 1)),\n        (TENSOR_4x4_WITH_NAN[..., 1:3, 2:4], (1, 2)),\n        (TENSOR_4x4_WITH_NAN[..., 2:4, 0:2], (2, 0)),\n        (TENSOR_4x4_WITH_NAN[..., 2:4, 1:3], (2, 1)),\n        (TENSOR_4x4_WITH_NAN[..., 2:4, 2:4], (2, 2)),\n    ],\n    TENSOR_4x4_WITH_NAN,\n]\n\n# non-overlapping 2x2 with missing patch\nTEST_CASE_4_DEFAULT_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape),\n    [(TENSOR_4x4[..., :2, :2], (0, 0)), (TENSOR_4x4[..., :2, 2:], (0, 2)), (TENSOR_4x4[..., 2:, :2], (2, 0))],\n    TENSOR_4x4_WITH_NAN,\n]\n\n# with value_dtype set to half precision\nTEST_CASE_5_VALUE_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape, value_dtype=torch.float16),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n# with count_dtype set to int32\nTEST_CASE_6_COUNT_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape, count_dtype=torch.int32),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n# with both value_dtype, count_dtype set to double precision\nTEST_CASE_7_COUNT_VALUE_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape, value_dtype=torch.float64, count_dtype=torch.float64),\n 
   [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# shape larger than what is covered by patches\nTEST_CASE_8_LARGER_SHAPE = [\n    dict(merged_shape=(2, 3, 4, 6)),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    pad(TENSOR_4x4, (0, 2), value=float(\"nan\")),\n]\n\n\nclass AvgMergerTests(unittest.TestCase):\n    @parameterized.expand(\n        [\n            TEST_CASE_0_DEFAULT_DTYPE,\n            TEST_CASE_1_DEFAULT_DTYPE,\n            TEST_CASE_2_DEFAULT_DTYPE,\n            TEST_CASE_3_DEFAULT_DTYPE,\n            TEST_CASE_4_DEFAULT_DTYPE,\n            TEST_CASE_5_VALUE_DTYPE,\n            TEST_CASE_6_COUNT_DTYPE,\n            TEST_CASE_7_COUNT_VALUE_DTYPE,\n            TEST_CASE_8_LARGER_SHAPE,\n        ]\n    )\n    def test_avg_merger_patches(self, arguments, patch_locations, expected):\n        merger = AvgMerger(**arguments)\n        for pl in patch_locations:\n            merger.aggregate(pl[0], pl[1])\n        output = merger.finalize()\n        if \"value_dtype\" in arguments:\n            self.assertTrue(merger.get_values().dtype, arguments[\"value_dtype\"])\n        if \"count_dtype\" in arguments:\n            self.assertTrue(merger.get_counts().dtype, arguments[\"count_dtype\"])\n        # check for multiple call of finalize\n        self.assertIs(output, merger.finalize())\n        # check if the result is matching the expectation\n        assert_allclose(output, expected)\n\n    def test_avg_merger_finalized_error(self):\n        with self.assertRaises(ValueError):\n            merger = AvgMerger(merged_shape=(1, 3, 2, 3))\n            merger.finalize()\n            merger.aggregate(torch.zeros(1, 3, 2, 2), (3, 3))\n\n    def 
test_avg_merge_none_merged_shape_error(self):\n        with self.assertRaises(ValueError):\n            AvgMerger(merged_shape=None)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/inferers/test_controlnet_inferers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.inferers import ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer\nfrom monai.networks.nets import (\n    VQVAE,\n    AutoencoderKL,\n    ControlNet,\n    DiffusionModelUNet,\n    SPADEAutoencoderKL,\n    SPADEDiffusionModelUNet,\n)\nfrom monai.networks.schedulers import DDIMScheduler, DDPMScheduler, RFlowScheduler\nfrom monai.utils import optional_import\n\n_, has_scipy = optional_import(\"scipy\")\n_, has_einops = optional_import(\"einops\")\n\n\nCNDM_TEST_CASES = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [True],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"channels\": [8],\n            \"attention_levels\": [True],\n            \"norm_num_groups\": 8,\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n            \"conditioning_embedding_num_channels\": [16],\n            \"conditioning_embedding_in_channels\": 1,\n        },\n        (2, 1, 8, 8),\n    ],\n    [\n   
     {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [True],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"channels\": [8],\n            \"attention_levels\": [True],\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 8,\n            \"num_head_channels\": 8,\n            \"conditioning_embedding_num_channels\": [16],\n            \"conditioning_embedding_in_channels\": 1,\n        },\n        (2, 1, 8, 8, 8),\n    ],\n]\nLATENT_CNDM_TEST_CASES = [\n    [\n        \"AutoencoderKL\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"latent_channels\": 3,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n            \"norm_num_groups\": 4,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [4, 4],\n            \"norm_num_groups\": 4,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 4,\n        },\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"channels\": [4, 4],\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"num_head_channels\": 4,\n            \"conditioning_embedding_num_channels\": [16],\n            \"conditioning_embedding_in_channels\": 1,\n        },\n        (1, 1, 
8, 8),\n        (1, 3, 4, 4),\n    ],\n    [\n        \"VQVAE\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [4, 4],\n            \"num_res_layers\": 1,\n            \"num_res_channels\": [4, 4],\n            \"downsample_parameters\": ((2, 4, 1, 1), (2, 4, 1, 1)),\n            \"upsample_parameters\": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n            \"num_embeddings\": 16,\n            \"embedding_dim\": 3,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [8, 8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"channels\": [8, 8],\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 8,\n            \"num_head_channels\": 8,\n            \"conditioning_embedding_num_channels\": [16],\n            \"conditioning_embedding_in_channels\": 1,\n        },\n        (1, 1, 16, 16),\n        (1, 3, 4, 4),\n    ],\n    [\n        \"VQVAE\",\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [4, 4],\n            \"num_res_layers\": 1,\n            \"num_res_channels\": [4, 4],\n            \"downsample_parameters\": ((2, 4, 1, 1), (2, 4, 1, 1)),\n            \"upsample_parameters\": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n            \"num_embeddings\": 16,\n            \"embedding_dim\": 3,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [8, 8],\n    
        \"norm_num_groups\": 8,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 3,\n            \"channels\": [8, 8],\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 8,\n            \"num_head_channels\": 8,\n            \"conditioning_embedding_num_channels\": [16],\n            \"conditioning_embedding_in_channels\": 1,\n        },\n        (1, 1, 16, 16, 16),\n        (1, 3, 4, 4, 4),\n    ],\n    [\n        \"SPADEAutoencoderKL\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"latent_channels\": 3,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"label_nc\": 5,\n        },\n        \"SPADEDiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [4, 4],\n            \"norm_num_groups\": 4,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 4,\n            \"label_nc\": 5,\n        },\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"channels\": [4, 4],\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"num_head_channels\": 4,\n            \"conditioning_embedding_num_channels\": [16],\n            \"conditioning_embedding_in_channels\": 1,\n        },\n        (1, 1, 8, 8),\n        (1, 3, 4, 4),\n    ],\n]\nLATENT_CNDM_TEST_CASES_DIFF_SHAPES = [\n    [\n        \"AutoencoderKL\",\n        {\n            \"spatial_dims\": 2,\n 
           \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"latent_channels\": 3,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n            \"norm_num_groups\": 4,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [4, 4],\n            \"norm_num_groups\": 4,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 4,\n        },\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"channels\": [4, 4],\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"num_head_channels\": 4,\n            \"conditioning_embedding_num_channels\": [16],\n            \"conditioning_embedding_in_channels\": 1,\n        },\n        (1, 1, 12, 12),\n        (1, 3, 8, 8),\n    ],\n    [\n        \"VQVAE\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [4, 4],\n            \"num_res_layers\": 1,\n            \"num_res_channels\": [4, 4],\n            \"downsample_parameters\": ((2, 4, 1, 1), (2, 4, 1, 1)),\n            \"upsample_parameters\": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n            \"num_embeddings\": 16,\n            \"embedding_dim\": 3,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [8, 8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            
\"num_head_channels\": 8,\n        },\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"channels\": [8, 8],\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 8,\n            \"num_head_channels\": 8,\n            \"conditioning_embedding_num_channels\": [16],\n            \"conditioning_embedding_in_channels\": 1,\n        },\n        (1, 1, 12, 12),\n        (1, 3, 8, 8),\n    ],\n    [\n        \"VQVAE\",\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [4, 4],\n            \"num_res_layers\": 1,\n            \"num_res_channels\": [4, 4],\n            \"downsample_parameters\": ((2, 4, 1, 1), (2, 4, 1, 1)),\n            \"upsample_parameters\": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n            \"num_embeddings\": 16,\n            \"embedding_dim\": 3,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [8, 8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 3,\n            \"channels\": [8, 8],\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 8,\n            \"num_head_channels\": 8,\n            \"conditioning_embedding_num_channels\": [16],\n            \"conditioning_embedding_in_channels\": 1,\n        },\n        (1, 1, 12, 12, 12),\n        (1, 3, 8, 8, 8),\n    ],\n    [\n        \"SPADEAutoencoderKL\",\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            
\"channels\": (4, 4),\n            \"latent_channels\": 3,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n            \"norm_num_groups\": 4,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [4, 4],\n            \"norm_num_groups\": 4,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 4,\n        },\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"channels\": [4, 4],\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"num_head_channels\": 4,\n            \"conditioning_embedding_num_channels\": [16],\n            \"conditioning_embedding_in_channels\": 1,\n        },\n        (1, 1, 8, 8),\n        (1, 3, 4, 4),\n    ],\n    [\n        \"AutoencoderKL\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"latent_channels\": 3,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n            \"norm_num_groups\": 4,\n        },\n        \"SPADEDiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [4, 4],\n            \"norm_num_groups\": 4,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 4,\n        },\n        {\n            
\"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"channels\": [4, 4],\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"num_head_channels\": 4,\n            \"conditioning_embedding_num_channels\": [16],\n            \"conditioning_embedding_in_channels\": 1,\n        },\n        (1, 1, 8, 8),\n        (1, 3, 4, 4),\n    ],\n    [\n        \"SPADEAutoencoderKL\",\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"latent_channels\": 3,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n            \"norm_num_groups\": 4,\n        },\n        \"SPADEDiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [4, 4],\n            \"norm_num_groups\": 4,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 4,\n        },\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"channels\": [4, 4],\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"num_head_channels\": 4,\n            \"conditioning_embedding_num_channels\": [16],\n            \"conditioning_embedding_in_channels\": 1,\n        },\n        (1, 1, 8, 8),\n        (1, 3, 4, 4),\n    ],\n]\n\n\nclass ControlNetTestDiffusionSamplingInferer(unittest.TestCase):\n    @parameterized.expand(CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_call(self, model_params, 
controlnet_params, input_shape):\n        model = DiffusionModelUNet(**model_params)\n        controlnet = ControlNet(**controlnet_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        controlnet.to(device)\n        controlnet.eval()\n        input = torch.randn(input_shape).to(device)\n        mask = torch.randn(input_shape).to(device)\n        noise = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = ControlNetDiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()\n        sample = inferer(\n            inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, timesteps=timesteps, cn_cond=mask\n        )\n        self.assertEqual(sample.shape, input_shape)\n\n    @parameterized.expand(CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample_intermediates(self, model_params, controlnet_params, input_shape):\n        model = DiffusionModelUNet(**model_params)\n        controlnet = ControlNet(**controlnet_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        controlnet.to(device)\n        controlnet.eval()\n        noise = torch.randn(input_shape).to(device)\n        mask = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = ControlNetDiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n\n        for cfg in [5, None]:\n            sample, intermediates = inferer.sample(\n                input_noise=noise,\n                diffusion_model=model,\n                scheduler=scheduler,\n                controlnet=controlnet,\n                
cn_cond=mask,\n                save_intermediates=True,\n                intermediate_steps=1,\n                cfg=cfg,\n            )\n\n            self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_ddpm_sampler(self, model_params, controlnet_params, input_shape):\n        model = DiffusionModelUNet(**model_params)\n        controlnet = ControlNet(**controlnet_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        controlnet.to(device)\n        controlnet.eval()\n        mask = torch.randn(input_shape).to(device)\n        noise = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=1000)\n        inferer = ControlNetDiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            scheduler=scheduler,\n            controlnet=controlnet,\n            cn_cond=mask,\n            save_intermediates=True,\n            intermediate_steps=1,\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_ddim_sampler(self, model_params, controlnet_params, input_shape):\n        model = DiffusionModelUNet(**model_params)\n        controlnet = ControlNet(**controlnet_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        controlnet.to(device)\n        controlnet.eval()\n        mask = torch.randn(input_shape).to(device)\n        noise = torch.randn(input_shape).to(device)\n        scheduler = DDIMScheduler(num_train_timesteps=1000)\n        inferer = ControlNetDiffusionInferer(scheduler=scheduler)\n        
scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            scheduler=scheduler,\n            controlnet=controlnet,\n            cn_cond=mask,\n            save_intermediates=True,\n            intermediate_steps=1,\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_rflow_sampler(self, model_params, controlnet_params, input_shape):\n        model = DiffusionModelUNet(**model_params)\n        controlnet = ControlNet(**controlnet_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        controlnet.to(device)\n        controlnet.eval()\n        mask = torch.randn(input_shape).to(device)\n        noise = torch.randn(input_shape).to(device)\n        scheduler = RFlowScheduler(num_train_timesteps=1000)\n        inferer = ControlNetDiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            scheduler=scheduler,\n            controlnet=controlnet,\n            cn_cond=mask,\n            save_intermediates=True,\n            intermediate_steps=1,\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):\n        model_params[\"with_conditioning\"] = True\n        model_params[\"cross_attention_dim\"] = 3\n        controlnet_params[\"with_conditioning\"] = True\n        controlnet_params[\"cross_attention_dim\"] = 3\n        model = DiffusionModelUNet(**model_params)\n        controlnet = 
ControlNet(**controlnet_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        controlnet.to(device)\n        controlnet.eval()\n        mask = torch.randn(input_shape).to(device)\n        noise = torch.randn(input_shape).to(device)\n\n        # DDIM\n        scheduler = DDIMScheduler(num_train_timesteps=1000)\n        inferer = ControlNetDiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        conditioning = torch.randn([input_shape[0], 1, 3]).to(device)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            controlnet=controlnet,\n            cn_cond=mask,\n            scheduler=scheduler,\n            save_intermediates=True,\n            intermediate_steps=1,\n            conditioning=conditioning,\n        )\n        self.assertEqual(len(intermediates), 10)\n\n        # RFlow\n        scheduler = RFlowScheduler(num_train_timesteps=1000)\n        inferer = ControlNetDiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        conditioning = torch.randn([input_shape[0], 1, 3]).to(device)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            controlnet=controlnet,\n            cn_cond=mask,\n            scheduler=scheduler,\n            save_intermediates=True,\n            intermediate_steps=1,\n            conditioning=conditioning,\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_get_likelihood(self, model_params, controlnet_params, input_shape):\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        
controlnet = ControlNet(**controlnet_params)\n        controlnet.to(device)\n        controlnet.eval()\n        input = torch.randn(input_shape).to(device)\n        mask = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = ControlNetDiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        likelihood, intermediates = inferer.get_likelihood(\n            inputs=input,\n            diffusion_model=model,\n            scheduler=scheduler,\n            controlnet=controlnet,\n            cn_cond=mask,\n            save_intermediates=True,\n        )\n        self.assertEqual(intermediates[0].shape, input.shape)\n        self.assertEqual(likelihood.shape[0], input.shape[0])\n\n    @unittest.skipUnless(has_scipy, \"Requires scipy library.\")\n    def test_normal_cdf(self):\n        from scipy.stats import norm\n\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = ControlNetDiffusionInferer(scheduler=scheduler)\n        x = torch.linspace(-10, 10, 20)\n        cdf_approx = inferer._approx_standard_normal_cdf(x)\n        cdf_true = norm.cdf(x)\n        torch.testing.assert_close(cdf_approx, torch.as_tensor(cdf_true, dtype=cdf_approx.dtype), atol=1e-3, rtol=1e-5)\n\n    @parameterized.expand(CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape):\n        # copy the model_params dict to prevent from modifying test cases\n        model_params = model_params.copy()\n        n_concat_channel = 2\n        model_params[\"in_channels\"] = model_params[\"in_channels\"] + n_concat_channel\n        controlnet_params[\"in_channels\"] = controlnet_params[\"in_channels\"] + n_concat_channel\n        model_params[\"cross_attention_dim\"] = None\n        controlnet_params[\"cross_attention_dim\"] = None\n        model_params[\"with_conditioning\"] = 
False\n        controlnet_params[\"with_conditioning\"] = False\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        controlnet = ControlNet(**controlnet_params)\n        controlnet.to(device)\n        controlnet.eval()\n        noise = torch.randn(input_shape).to(device)\n        mask = torch.randn(input_shape).to(device)\n        conditioning_shape = list(input_shape)\n        conditioning_shape[1] = n_concat_channel\n        conditioning = torch.randn(conditioning_shape).to(device)\n\n        # DDIM\n        scheduler = DDIMScheduler(num_train_timesteps=1000)\n        inferer = ControlNetDiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            controlnet=controlnet,\n            cn_cond=mask,\n            scheduler=scheduler,\n            save_intermediates=True,\n            intermediate_steps=1,\n            conditioning=conditioning,\n            mode=\"concat\",\n        )\n        self.assertEqual(len(intermediates), 10)\n\n        # RFlow\n        scheduler = RFlowScheduler(num_train_timesteps=1000)\n        inferer = ControlNetDiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            controlnet=controlnet,\n            cn_cond=mask,\n            scheduler=scheduler,\n            save_intermediates=True,\n            intermediate_steps=1,\n            conditioning=conditioning,\n            mode=\"concat\",\n        )\n        self.assertEqual(len(intermediates), 10)\n\n\nclass LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):\n    @parameterized.expand(LATENT_CNDM_TEST_CASES)\n    
@skipUnless(has_einops, \"Requires einops\")\n    def test_prediction_shape(\n        self,\n        ae_model_type,\n        autoencoder_params,\n        dm_model_type,\n        stage_2_params,\n        controlnet_params,\n        input_shape,\n        latent_shape,\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n        controlnet = ControlNet(**controlnet_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        controlnet.to(device)\n        stage_1.eval()\n        stage_2.eval()\n        controlnet.eval()\n\n        input = torch.randn(input_shape).to(device)\n        mask = torch.randn(input_shape).to(device)\n        noise = torch.randn(latent_shape).to(device)\n\n        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:\n            inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n            scheduler.set_timesteps(num_inference_steps=10)\n            timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()\n            if ae_model_type == \"SPADEAutoencoderKL\" or dm_model_type == \"SPADEDiffusionModelUNet\":\n                input_shape_seg = list(input_shape)\n                if \"label_nc\" in stage_2_params.keys():\n                    input_shape_seg[1] = stage_2_params[\"label_nc\"]\n                else:\n                    
input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n                input_seg = torch.randn(input_shape_seg).to(device)\n                prediction = inferer(\n                    inputs=input,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    controlnet=controlnet,\n                    cn_cond=mask,\n                    seg=input_seg,\n                    noise=noise,\n                    timesteps=timesteps,\n                )\n            else:\n                prediction = inferer(\n                    inputs=input,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    noise=noise,\n                    timesteps=timesteps,\n                    controlnet=controlnet,\n                    cn_cond=mask,\n                )\n            self.assertEqual(prediction.shape, latent_shape)\n\n    @parameterized.expand(LATENT_CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_pred_shape(\n        self,\n        ae_model_type,\n        autoencoder_params,\n        dm_model_type,\n        stage_2_params,\n        controlnet_params,\n        input_shape,\n        latent_shape,\n    ):\n        stage_1 = None\n\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        controlnet = ControlNet(**controlnet_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        
controlnet.to(device)\n        stage_1.eval()\n        stage_2.eval()\n        controlnet.eval()\n\n        noise = torch.randn(latent_shape).to(device)\n        mask = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n        scheduler.set_timesteps(num_inference_steps=10)\n\n        if ae_model_type == \"SPADEAutoencoderKL\" or dm_model_type == \"SPADEDiffusionModelUNet\":\n            input_shape_seg = list(input_shape)\n            if \"label_nc\" in stage_2_params.keys():\n                input_shape_seg[1] = stage_2_params[\"label_nc\"]\n            else:\n                input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n            input_seg = torch.randn(input_shape_seg).to(device)\n            sample = inferer.sample(\n                input_noise=noise,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                controlnet=controlnet,\n                cn_cond=mask,\n                scheduler=scheduler,\n                seg=input_seg,\n            )\n        else:\n            sample = inferer.sample(\n                input_noise=noise,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                scheduler=scheduler,\n                controlnet=controlnet,\n                cn_cond=mask,\n            )\n        self.assertEqual(sample.shape, input_shape)\n\n    @parameterized.expand(LATENT_CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample_intermediates(\n        self,\n        ae_model_type,\n        autoencoder_params,\n        dm_model_type,\n        stage_2_params,\n        controlnet_params,\n        input_shape,\n        latent_shape,\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if 
ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n        controlnet = ControlNet(**controlnet_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        controlnet.to(device)\n        stage_1.eval()\n        stage_2.eval()\n        controlnet.eval()\n\n        noise = torch.randn(latent_shape).to(device)\n        mask = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n        scheduler.set_timesteps(num_inference_steps=10)\n\n        if ae_model_type == \"SPADEAutoencoderKL\" or dm_model_type == \"SPADEDiffusionModelUNet\":\n            input_shape_seg = list(input_shape)\n            if \"label_nc\" in stage_2_params.keys():\n                input_shape_seg[1] = stage_2_params[\"label_nc\"]\n            else:\n                input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n            input_seg = torch.randn(input_shape_seg).to(device)\n            sample, intermediates = inferer.sample(\n                input_noise=noise,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                scheduler=scheduler,\n                seg=input_seg,\n                controlnet=controlnet,\n                cn_cond=mask,\n                save_intermediates=True,\n                intermediate_steps=1,\n            )\n        else:\n            sample, intermediates = inferer.sample(\n                input_noise=noise,\n                autoencoder_model=stage_1,\n            
    diffusion_model=stage_2,\n                scheduler=scheduler,\n                save_intermediates=True,\n                intermediate_steps=1,\n                controlnet=controlnet,\n                cn_cond=mask,\n            )\n\n        self.assertEqual(len(intermediates), 10)\n        self.assertEqual(intermediates[0].shape, input_shape)\n\n    @parameterized.expand(LATENT_CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_get_likelihoods(\n        self,\n        ae_model_type,\n        autoencoder_params,\n        dm_model_type,\n        stage_2_params,\n        controlnet_params,\n        input_shape,\n        latent_shape,\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n        controlnet = ControlNet(**controlnet_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        controlnet.to(device)\n        stage_1.eval()\n        stage_2.eval()\n        controlnet.eval()\n\n        input = torch.randn(input_shape).to(device)\n        mask = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n        scheduler.set_timesteps(num_inference_steps=10)\n\n        if ae_model_type == \"SPADEAutoencoderKL\" or dm_model_type == \"SPADEDiffusionModelUNet\":\n            input_shape_seg = list(input_shape)\n            if 
\"label_nc\" in stage_2_params.keys():\n                input_shape_seg[1] = stage_2_params[\"label_nc\"]\n            else:\n                input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n            input_seg = torch.randn(input_shape_seg).to(device)\n            sample, intermediates = inferer.get_likelihood(\n                inputs=input,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                controlnet=controlnet,\n                cn_cond=mask,\n                scheduler=scheduler,\n                save_intermediates=True,\n                seg=input_seg,\n            )\n        else:\n            sample, intermediates = inferer.get_likelihood(\n                inputs=input,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                scheduler=scheduler,\n                controlnet=controlnet,\n                cn_cond=mask,\n                save_intermediates=True,\n            )\n        self.assertEqual(len(intermediates), 10)\n        self.assertEqual(intermediates[0].shape, latent_shape)\n\n    @parameterized.expand(LATENT_CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_resample_likelihoods(\n        self,\n        ae_model_type,\n        autoencoder_params,\n        dm_model_type,\n        stage_2_params,\n        controlnet_params,\n        input_shape,\n        latent_shape,\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n        
controlnet = ControlNet(**controlnet_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        controlnet.to(device)\n        stage_1.eval()\n        stage_2.eval()\n        controlnet.eval()\n\n        input = torch.randn(input_shape).to(device)\n        mask = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n        scheduler.set_timesteps(num_inference_steps=10)\n\n        if ae_model_type == \"SPADEAutoencoderKL\" or dm_model_type == \"SPADEDiffusionModelUNet\":\n            input_shape_seg = list(input_shape)\n            if \"label_nc\" in stage_2_params.keys():\n                input_shape_seg[1] = stage_2_params[\"label_nc\"]\n            else:\n                input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n            input_seg = torch.randn(input_shape_seg).to(device)\n            sample, intermediates = inferer.get_likelihood(\n                inputs=input,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                scheduler=scheduler,\n                controlnet=controlnet,\n                cn_cond=mask,\n                save_intermediates=True,\n                resample_latent_likelihoods=True,\n                seg=input_seg,\n            )\n        else:\n            sample, intermediates = inferer.get_likelihood(\n                inputs=input,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                scheduler=scheduler,\n                controlnet=controlnet,\n                cn_cond=mask,\n                save_intermediates=True,\n                resample_latent_likelihoods=True,\n            )\n        self.assertEqual(len(intermediates), 10)\n        self.assertEqual(intermediates[0].shape[2:], input_shape[2:])\n\n    
@parameterized.expand(LATENT_CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_prediction_shape_conditioned_concat(\n        self,\n        ae_model_type,\n        autoencoder_params,\n        dm_model_type,\n        stage_2_params,\n        controlnet_params,\n        input_shape,\n        latent_shape,\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        stage_2_params = stage_2_params.copy()\n        controlnet_params = controlnet_params.copy()\n        n_concat_channel = 3\n        stage_2_params[\"in_channels\"] = stage_2_params[\"in_channels\"] + n_concat_channel\n        controlnet_params[\"in_channels\"] = controlnet_params[\"in_channels\"] + n_concat_channel\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n        controlnet = ControlNet(**controlnet_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        controlnet.to(device)\n        stage_1.eval()\n        stage_2.eval()\n        controlnet.eval()\n\n        input = torch.randn(input_shape).to(device)\n        mask = torch.randn(input_shape).to(device)\n        noise = torch.randn(latent_shape).to(device)\n        conditioning_shape = list(latent_shape)\n        conditioning_shape[1] = n_concat_channel\n        conditioning = torch.randn(conditioning_shape).to(device)\n\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n        
scheduler.set_timesteps(num_inference_steps=10)\n\n        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()\n\n        if ae_model_type == \"SPADEAutoencoderKL\" or dm_model_type == \"SPADEDiffusionModelUNet\":\n            input_shape_seg = list(input_shape)\n            if \"label_nc\" in stage_2_params.keys():\n                input_shape_seg[1] = stage_2_params[\"label_nc\"]\n            else:\n                input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n            input_seg = torch.randn(input_shape_seg).to(device)\n            prediction = inferer(\n                inputs=input,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                noise=noise,\n                controlnet=controlnet,\n                cn_cond=mask,\n                timesteps=timesteps,\n                condition=conditioning,\n                mode=\"concat\",\n                seg=input_seg,\n            )\n        else:\n            prediction = inferer(\n                inputs=input,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                noise=noise,\n                controlnet=controlnet,\n                cn_cond=mask,\n                timesteps=timesteps,\n                condition=conditioning,\n                mode=\"concat\",\n            )\n        self.assertEqual(prediction.shape, latent_shape)\n\n    @parameterized.expand(LATENT_CNDM_TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample_shape_conditioned_concat(\n        self,\n        ae_model_type,\n        autoencoder_params,\n        dm_model_type,\n        stage_2_params,\n        controlnet_params,\n        input_shape,\n        latent_shape,\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            
stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        stage_2_params = stage_2_params.copy()\n        controlnet_params = controlnet_params.copy()\n        n_concat_channel = 3\n        stage_2_params[\"in_channels\"] = stage_2_params[\"in_channels\"] + n_concat_channel\n        controlnet_params[\"in_channels\"] = controlnet_params[\"in_channels\"] + n_concat_channel\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n        controlnet = ControlNet(**controlnet_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        controlnet.to(device)\n        stage_1.eval()\n        stage_2.eval()\n        controlnet.eval()\n\n        noise = torch.randn(latent_shape).to(device)\n        mask = torch.randn(input_shape).to(device)\n        conditioning_shape = list(latent_shape)\n        conditioning_shape[1] = n_concat_channel\n        conditioning = torch.randn(conditioning_shape).to(device)\n\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n        scheduler.set_timesteps(num_inference_steps=10)\n\n        if ae_model_type == \"SPADEAutoencoderKL\" or dm_model_type == \"SPADEDiffusionModelUNet\":\n            input_shape_seg = list(input_shape)\n            if \"label_nc\" in stage_2_params.keys():\n                input_shape_seg[1] = stage_2_params[\"label_nc\"]\n            else:\n                input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n            input_seg = torch.randn(input_shape_seg).to(device)\n            sample = inferer.sample(\n                input_noise=noise,\n                autoencoder_model=stage_1,\n 
               diffusion_model=stage_2,\n                controlnet=controlnet,\n                cn_cond=mask,\n                scheduler=scheduler,\n                conditioning=conditioning,\n                mode=\"concat\",\n                seg=input_seg,\n            )\n        else:\n            sample = inferer.sample(\n                input_noise=noise,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                controlnet=controlnet,\n                cn_cond=mask,\n                scheduler=scheduler,\n                conditioning=conditioning,\n                mode=\"concat\",\n            )\n        self.assertEqual(sample.shape, input_shape)\n\n    @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_different_latents(\n        self,\n        ae_model_type,\n        autoencoder_params,\n        dm_model_type,\n        stage_2_params,\n        controlnet_params,\n        input_shape,\n        latent_shape,\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n        controlnet = ControlNet(**controlnet_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        controlnet.to(device)\n        stage_1.eval()\n        stage_2.eval()\n        controlnet.eval()\n\n        input = torch.randn(input_shape).to(device)\n        noise = torch.randn(latent_shape).to(device)\n      
  mask = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        # We infer the VAE shape\n        autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params[\"channels\"]) - 1)) for i in input_shape[2:]]\n        inferer = ControlNetLatentDiffusionInferer(\n            scheduler=scheduler,\n            scale_factor=1.0,\n            ldm_latent_shape=list(latent_shape[2:]),\n            autoencoder_latent_shape=autoencoder_latent_shape,\n        )\n        scheduler.set_timesteps(num_inference_steps=10)\n\n        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()\n\n        if dm_model_type == \"SPADEDiffusionModelUNet\" or ae_model_type == \"SPADEAutoencoderKL\":\n            input_shape_seg = list(input_shape)\n            if \"label_nc\" in stage_2_params.keys():\n                input_shape_seg[1] = stage_2_params[\"label_nc\"]\n            else:\n                input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n            input_seg = torch.randn(input_shape_seg).to(device)\n            prediction = inferer(\n                inputs=input,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                controlnet=controlnet,\n                cn_cond=mask,\n                noise=noise,\n                timesteps=timesteps,\n                seg=input_seg,\n            )\n        else:\n            prediction = inferer(\n                inputs=input,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                noise=noise,\n                controlnet=controlnet,\n                cn_cond=mask,\n                timesteps=timesteps,\n            )\n        self.assertEqual(prediction.shape, latent_shape)\n\n    @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample_shape_different_latents(\n        
self,\n        ae_model_type,\n        autoencoder_params,\n        dm_model_type,\n        stage_2_params,\n        controlnet_params,\n        input_shape,\n        latent_shape,\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n        controlnet = ControlNet(**controlnet_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        controlnet.to(device)\n        stage_1.eval()\n        stage_2.eval()\n        controlnet.eval()\n\n        noise = torch.randn(latent_shape).to(device)\n        mask = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        # We infer the VAE shape\n        if ae_model_type == \"VQVAE\":\n            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params[\"channels\"]))) for i in input_shape[2:]]\n        else:\n            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params[\"channels\"]) - 1)) for i in input_shape[2:]]\n\n        inferer = ControlNetLatentDiffusionInferer(\n            scheduler=scheduler,\n            scale_factor=1.0,\n            ldm_latent_shape=list(latent_shape[2:]),\n            autoencoder_latent_shape=autoencoder_latent_shape,\n        )\n        scheduler.set_timesteps(num_inference_steps=10)\n\n        if dm_model_type == \"SPADEDiffusionModelUNet\" or ae_model_type == \"SPADEAutoencoderKL\":\n            input_shape_seg = list(input_shape)\n            if 
\"label_nc\" in stage_2_params.keys():\n                input_shape_seg[1] = stage_2_params[\"label_nc\"]\n            else:\n                input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n            input_seg = torch.randn(input_shape_seg).to(device)\n            prediction, _ = inferer.sample(\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                controlnet=controlnet,\n                cn_cond=mask,\n                input_noise=noise,\n                seg=input_seg,\n                save_intermediates=True,\n            )\n        else:\n            prediction = inferer.sample(\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                input_noise=noise,\n                controlnet=controlnet,\n                cn_cond=mask,\n                save_intermediates=False,\n            )\n        self.assertEqual(prediction.shape, input_shape)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_incompatible_spade_setup(self):\n        stage_1 = SPADEAutoencoderKL(\n            spatial_dims=2,\n            label_nc=6,\n            in_channels=1,\n            out_channels=1,\n            channels=(4, 4),\n            latent_channels=3,\n            attention_levels=[False, False],\n            num_res_blocks=1,\n            with_encoder_nonlocal_attn=False,\n            with_decoder_nonlocal_attn=False,\n            norm_num_groups=4,\n        )\n        stage_2 = SPADEDiffusionModelUNet(\n            spatial_dims=2,\n            label_nc=3,\n            in_channels=3,\n            out_channels=3,\n            channels=[4, 4],\n            norm_num_groups=4,\n            attention_levels=[False, False],\n            num_res_blocks=1,\n            num_head_channels=4,\n        )\n        controlnet = ControlNet(\n            spatial_dims=2,\n            in_channels=1,\n            channels=[4, 4],\n            norm_num_groups=4,\n            
attention_levels=[False, False],\n            num_res_blocks=1,\n            num_head_channels=4,\n            conditioning_embedding_num_channels=[16],\n        )\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        controlnet.to(device)\n        controlnet.to(device)\n        stage_1.eval()\n        stage_2.eval()\n        controlnet.eval()\n        noise = torch.randn((1, 3, 4, 4)).to(device)\n        mask = torch.randn((1, 1, 4, 4)).to(device)\n        input_seg = torch.randn((1, 3, 8, 8)).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n        scheduler.set_timesteps(num_inference_steps=10)\n\n        with self.assertRaises(ValueError):\n            _ = inferer.sample(\n                input_noise=noise,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                scheduler=scheduler,\n                controlnet=controlnet,\n                cn_cond=mask,\n                seg=input_seg,\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/inferers/test_diffusion_inferer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.inferers import DiffusionInferer\nfrom monai.networks.nets import DiffusionModelUNet\nfrom monai.networks.schedulers import DDIMScheduler, DDPMScheduler, RFlowScheduler\nfrom monai.utils import optional_import\n\n_, has_scipy = optional_import(\"scipy\")\n_, has_einops = optional_import(\"einops\")\n\nTEST_CASES = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [True],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        (2, 1, 8, 8),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [True],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        (2, 1, 8, 8, 8),\n    ],\n]\n\n\nclass TestDiffusionSamplingInferer(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_call(self, model_params, input_shape):\n        
model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        input = torch.randn(input_shape).to(device)\n        noise = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()\n        sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps)\n        self.assertEqual(sample.shape, input_shape)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample_intermediates(self, model_params, input_shape):\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        noise = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample_cfg(self, model_params, input_shape):\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        noise = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        
scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            scheduler=scheduler,\n            save_intermediates=True,\n            intermediate_steps=1,\n            cfg=5,\n            cfg_fill_value=-1,\n        )\n        self.assertEqual(sample.shape, noise.shape)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_ddpm_sampler(self, model_params, input_shape):\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        noise = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=1000)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_ddim_sampler(self, model_params, input_shape):\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        noise = torch.randn(input_shape).to(device)\n        scheduler = DDIMScheduler(num_train_timesteps=1000)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(TEST_CASES)\n    
@skipUnless(has_einops, \"Requires einops\")\n    def test_rflow_sampler(self, model_params, input_shape):\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        noise = torch.randn(input_shape).to(device)\n        scheduler = RFlowScheduler(num_train_timesteps=1000)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sampler_conditioned(self, model_params, input_shape):\n        model_params[\"with_conditioning\"] = True\n        model_params[\"cross_attention_dim\"] = 3\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        noise = torch.randn(input_shape).to(device)\n        scheduler = DDIMScheduler(num_train_timesteps=1000)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        conditioning = torch.randn([input_shape[0], 1, 3]).to(device)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            scheduler=scheduler,\n            save_intermediates=True,\n            intermediate_steps=1,\n            conditioning=conditioning,\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sampler_conditioned_rflow(self, model_params, input_shape):\n        
model_params[\"with_conditioning\"] = True\n        model_params[\"cross_attention_dim\"] = 3\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        noise = torch.randn(input_shape).to(device)\n        scheduler = RFlowScheduler(num_train_timesteps=1000)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        conditioning = torch.randn([input_shape[0], 1, 3]).to(device)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            scheduler=scheduler,\n            save_intermediates=True,\n            intermediate_steps=1,\n            conditioning=conditioning,\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_get_likelihood(self, model_params, input_shape):\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        input = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        likelihood, intermediates = inferer.get_likelihood(\n            inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True\n        )\n        self.assertEqual(intermediates[0].shape, input.shape)\n        self.assertEqual(likelihood.shape[0], input.shape[0])\n\n    @unittest.skipUnless(has_scipy, \"Requires scipy library.\")\n    def test_normal_cdf(self):\n        from scipy.stats import norm\n\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = DiffusionInferer(scheduler=scheduler)\n\n       
 x = torch.linspace(-10, 10, 20)\n        cdf_approx = inferer._approx_standard_normal_cdf(x)\n        cdf_true = norm.cdf(x)\n        torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sampler_conditioned_concat(self, model_params, input_shape):\n        # copy the model_params dict to prevent from modifying test cases\n        model_params = model_params.copy()\n        n_concat_channel = 2\n        model_params[\"in_channels\"] = model_params[\"in_channels\"] + n_concat_channel\n        model_params[\"cross_attention_dim\"] = None\n        model_params[\"with_conditioning\"] = False\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        noise = torch.randn(input_shape).to(device)\n        conditioning_shape = list(input_shape)\n        conditioning_shape[1] = n_concat_channel\n        conditioning = torch.randn(conditioning_shape).to(device)\n        scheduler = DDIMScheduler(num_train_timesteps=1000)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            scheduler=scheduler,\n            save_intermediates=True,\n            intermediate_steps=1,\n            conditioning=conditioning,\n            mode=\"concat\",\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sampler_conditioned_concat_cfg(self, model_params, input_shape):\n        # copy the model_params dict to prevent from modifying test cases\n        model_params = model_params.copy()\n        n_concat_channel = 2\n        model_params[\"in_channels\"] = 
model_params[\"in_channels\"] + n_concat_channel\n        model_params[\"cross_attention_dim\"] = None\n        model_params[\"with_conditioning\"] = False\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        noise = torch.randn(input_shape).to(device)\n        conditioning_shape = list(input_shape)\n        conditioning_shape[1] = n_concat_channel\n        conditioning = torch.randn(conditioning_shape).to(device)\n        scheduler = DDIMScheduler(num_train_timesteps=1000)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            scheduler=scheduler,\n            save_intermediates=True,\n            intermediate_steps=1,\n            conditioning=conditioning,\n            mode=\"concat\",\n            cfg=5,\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sampler_conditioned_concat_rflow(self, model_params, input_shape):\n        # copy the model_params dict to prevent from modifying test cases\n        model_params = model_params.copy()\n        n_concat_channel = 2\n        model_params[\"in_channels\"] = model_params[\"in_channels\"] + n_concat_channel\n        model_params[\"cross_attention_dim\"] = None\n        model_params[\"with_conditioning\"] = False\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        noise = torch.randn(input_shape).to(device)\n        conditioning_shape = list(input_shape)\n        conditioning_shape[1] = n_concat_channel\n        conditioning = 
torch.randn(conditioning_shape).to(device)\n        scheduler = RFlowScheduler(num_train_timesteps=1000)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        sample, intermediates = inferer.sample(\n            input_noise=noise,\n            diffusion_model=model,\n            scheduler=scheduler,\n            save_intermediates=True,\n            intermediate_steps=1,\n            conditioning=conditioning,\n            mode=\"concat\",\n        )\n        self.assertEqual(len(intermediates), 10)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_call_conditioned_concat(self, model_params, input_shape):\n        # copy the model_params dict to prevent from modifying test cases\n        model_params = model_params.copy()\n        n_concat_channel = 2\n        model_params[\"in_channels\"] = model_params[\"in_channels\"] + n_concat_channel\n        model_params[\"cross_attention_dim\"] = None\n        model_params[\"with_conditioning\"] = False\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        input = torch.randn(input_shape).to(device)\n        noise = torch.randn(input_shape).to(device)\n        conditioning_shape = list(input_shape)\n        conditioning_shape[1] = n_concat_channel\n        conditioning = torch.randn(conditioning_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()\n        sample = inferer(\n            inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode=\"concat\"\n        )\n        
self.assertEqual(sample.shape, input_shape)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_call_conditioned_concat_rflow(self, model_params, input_shape):\n        # copy the model_params dict to prevent from modifying test cases\n        model_params = model_params.copy()\n        n_concat_channel = 2\n        model_params[\"in_channels\"] = model_params[\"in_channels\"] + n_concat_channel\n        model_params[\"cross_attention_dim\"] = None\n        model_params[\"with_conditioning\"] = False\n        model = DiffusionModelUNet(**model_params)\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        input = torch.randn(input_shape).to(device)\n        noise = torch.randn(input_shape).to(device)\n        conditioning_shape = list(input_shape)\n        conditioning_shape[1] = n_concat_channel\n        conditioning = torch.randn(conditioning_shape).to(device)\n        scheduler = RFlowScheduler(num_train_timesteps=1000)\n        inferer = DiffusionInferer(scheduler=scheduler)\n        scheduler.set_timesteps(num_inference_steps=10)\n        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()\n        sample = inferer(\n            inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode=\"concat\"\n        )\n        self.assertEqual(sample.shape, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/inferers/test_latent_diffusion_inferer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.inferers import LatentDiffusionInferer\nfrom monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet\nfrom monai.networks.schedulers import DDPMScheduler, RFlowScheduler\nfrom monai.utils import optional_import\n\n_, has_einops = optional_import(\"einops\")\nTEST_CASES = [\n    [\n        \"AutoencoderKL\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"latent_channels\": 3,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n            \"norm_num_groups\": 4,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [4, 4],\n            \"norm_num_groups\": 4,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 4,\n        },\n        (1, 1, 8, 8),\n        (1, 3, 4, 4),\n    ],\n    [\n        \"VQVAE\",\n        
{\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [4, 4],\n            \"num_res_layers\": 1,\n            \"num_res_channels\": [4, 4],\n            \"downsample_parameters\": ((2, 4, 1, 1), (2, 4, 1, 1)),\n            \"upsample_parameters\": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n            \"num_embeddings\": 16,\n            \"embedding_dim\": 3,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [8, 8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        (1, 1, 16, 16),\n        (1, 3, 4, 4),\n    ],\n    [\n        \"VQVAE\",\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [4, 4],\n            \"num_res_layers\": 1,\n            \"num_res_channels\": [4, 4],\n            \"downsample_parameters\": ((2, 4, 1, 1), (2, 4, 1, 1)),\n            \"upsample_parameters\": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n            \"num_embeddings\": 16,\n            \"embedding_dim\": 3,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [8, 8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        (1, 1, 16, 16, 16),\n        (1, 3, 4, 4, 4),\n    ],\n    [\n        \"AutoencoderKL\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"latent_channels\": 3,\n            
\"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n            \"norm_num_groups\": 4,\n        },\n        \"SPADEDiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [4, 4],\n            \"norm_num_groups\": 4,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 4,\n        },\n        (1, 1, 8, 8),\n        (1, 3, 4, 4),\n    ],\n]\nTEST_CASES_DIFF_SHAPES = [\n    [\n        \"AutoencoderKL\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"latent_channels\": 3,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n            \"norm_num_groups\": 4,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [4, 4],\n            \"norm_num_groups\": 4,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 4,\n        },\n        (1, 1, 12, 12),\n        (1, 3, 8, 8),\n    ],\n    [\n        \"VQVAE\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [4, 4],\n            \"num_res_layers\": 1,\n            \"num_res_channels\": [4, 4],\n            \"downsample_parameters\": ((2, 4, 1, 1), (2, 4, 1, 1)),\n            \"upsample_parameters\": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n            
\"num_embeddings\": 16,\n            \"embedding_dim\": 3,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [8, 8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        (1, 1, 12, 12),\n        (1, 3, 8, 8),\n    ],\n    [\n        \"VQVAE\",\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": [4, 4],\n            \"num_res_layers\": 1,\n            \"num_res_channels\": [4, 4],\n            \"downsample_parameters\": ((2, 4, 1, 1), (2, 4, 1, 1)),\n            \"upsample_parameters\": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n            \"num_embeddings\": 16,\n            \"embedding_dim\": 3,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [8, 8],\n            \"norm_num_groups\": 8,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 8,\n        },\n        (1, 1, 12, 12, 12),\n        (1, 3, 8, 8, 8),\n    ],\n    [\n        \"SPADEAutoencoderKL\",\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"latent_channels\": 3,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n            \"norm_num_groups\": 4,\n        },\n        \"DiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            
\"out_channels\": 3,\n            \"channels\": [4, 4],\n            \"norm_num_groups\": 4,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 4,\n        },\n        (1, 1, 8, 8),\n        (1, 3, 4, 4),\n    ],\n    [\n        \"AutoencoderKL\",\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"latent_channels\": 3,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n            \"norm_num_groups\": 4,\n        },\n        \"SPADEDiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [4, 4],\n            \"norm_num_groups\": 4,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"num_head_channels\": 4,\n        },\n        (1, 1, 8, 8),\n        (1, 3, 4, 4),\n    ],\n    [\n        \"SPADEAutoencoderKL\",\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"latent_channels\": 3,\n            \"attention_levels\": [False, False],\n            \"num_res_blocks\": 1,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n            \"norm_num_groups\": 4,\n        },\n        \"SPADEDiffusionModelUNet\",\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"channels\": [4, 4],\n            \"norm_num_groups\": 4,\n            \"attention_levels\": [False, False],\n            
\"num_res_blocks\": 1,\n            \"num_head_channels\": 4,\n        },\n        (1, 1, 8, 8),\n        (1, 3, 4, 4),\n    ],\n]\n\n\nclass TestDiffusionSamplingInferer(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_prediction_shape(\n        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        input = torch.randn(input_shape).to(device)\n        noise = torch.randn(latent_shape).to(device)\n\n        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:\n            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n            scheduler.set_timesteps(num_inference_steps=10)\n            timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()\n\n            if dm_model_type == \"SPADEDiffusionModelUNet\":\n                input_shape_seg = list(input_shape)\n                if \"label_nc\" in stage_2_params.keys():\n                    input_shape_seg[1] = stage_2_params[\"label_nc\"]\n                else:\n                    input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n                input_seg = torch.randn(input_shape_seg).to(device)\n                prediction = inferer(\n                    
inputs=input,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    seg=input_seg,\n                    noise=noise,\n                    timesteps=timesteps,\n                )\n            else:\n                prediction = inferer(\n                    inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps\n                )\n            self.assertEqual(prediction.shape, latent_shape)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample_shape(\n        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        noise = torch.randn(latent_shape).to(device)\n\n        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:\n            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n            scheduler.set_timesteps(num_inference_steps=10)\n\n            if ae_model_type == \"SPADEAutoencoderKL\" or dm_model_type == \"SPADEDiffusionModelUNet\":\n                input_shape_seg = list(input_shape)\n                if \"label_nc\" in stage_2_params.keys():\n                    input_shape_seg[1] = stage_2_params[\"label_nc\"]\n                else:\n                    input_shape_seg[1] = 
autoencoder_params[\"label_nc\"]\n                input_seg = torch.randn(input_shape_seg).to(device)\n                sample = inferer.sample(\n                    input_noise=noise,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    scheduler=scheduler,\n                    seg=input_seg,\n                )\n            else:\n                sample = inferer.sample(\n                    input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler\n                )\n            self.assertEqual(sample.shape, input_shape)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample_shape_with_cfg(\n        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        noise = torch.randn(latent_shape).to(device)\n\n        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:\n            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n            scheduler.set_timesteps(num_inference_steps=10)\n\n            if ae_model_type == \"SPADEAutoencoderKL\" or dm_model_type == \"SPADEDiffusionModelUNet\":\n                input_shape_seg = list(input_shape)\n                if \"label_nc\" in stage_2_params.keys():\n 
                   input_shape_seg[1] = stage_2_params[\"label_nc\"]\n                else:\n                    input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n                input_seg = torch.randn(input_shape_seg).to(device)\n                sample = inferer.sample(\n                    input_noise=noise,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    scheduler=scheduler,\n                    seg=input_seg,\n                    cfg=5,\n                    cfg_fill_value=-1,\n                )\n            else:\n                sample = inferer.sample(\n                    input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler, cfg=5\n                )\n            self.assertEqual(sample.shape, input_shape)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample_intermediates(\n        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        noise = torch.randn(latent_shape).to(device)\n\n        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:\n            inferer = 
LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n            scheduler.set_timesteps(num_inference_steps=10)\n\n            if ae_model_type == \"SPADEAutoencoderKL\" or dm_model_type == \"SPADEDiffusionModelUNet\":\n                input_shape_seg = list(input_shape)\n                if \"label_nc\" in stage_2_params.keys():\n                    input_shape_seg[1] = stage_2_params[\"label_nc\"]\n                else:\n                    input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n                input_seg = torch.randn(input_shape_seg).to(device)\n                sample, intermediates = inferer.sample(\n                    input_noise=noise,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    scheduler=scheduler,\n                    seg=input_seg,\n                    save_intermediates=True,\n                    intermediate_steps=1,\n                )\n            else:\n                sample, intermediates = inferer.sample(\n                    input_noise=noise,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    scheduler=scheduler,\n                    save_intermediates=True,\n                    intermediate_steps=1,\n                )\n            self.assertEqual(len(intermediates), 10)\n            self.assertEqual(intermediates[0].shape, input_shape)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_get_likelihoods(\n        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = 
SPADEAutoencoderKL(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        input = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n        scheduler.set_timesteps(num_inference_steps=10)\n\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            input_shape_seg = list(input_shape)\n            if \"label_nc\" in stage_2_params.keys():\n                input_shape_seg[1] = stage_2_params[\"label_nc\"]\n            else:\n                input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n            input_seg = torch.randn(input_shape_seg).to(device)\n            sample, intermediates = inferer.get_likelihood(\n                inputs=input,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                scheduler=scheduler,\n                save_intermediates=True,\n                seg=input_seg,\n            )\n        else:\n            sample, intermediates = inferer.get_likelihood(\n                inputs=input,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                scheduler=scheduler,\n                save_intermediates=True,\n            )\n        self.assertEqual(len(intermediates), 10)\n        self.assertEqual(intermediates[0].shape, latent_shape)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_resample_likelihoods(\n        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape\n    
):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        input = torch.randn(input_shape).to(device)\n        scheduler = DDPMScheduler(num_train_timesteps=10)\n        inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n        scheduler.set_timesteps(num_inference_steps=10)\n\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            input_shape_seg = list(input_shape)\n            if \"label_nc\" in stage_2_params.keys():\n                input_shape_seg[1] = stage_2_params[\"label_nc\"]\n            else:\n                input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n            input_seg = torch.randn(input_shape_seg).to(device)\n            sample, intermediates = inferer.get_likelihood(\n                inputs=input,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                scheduler=scheduler,\n                save_intermediates=True,\n                resample_latent_likelihoods=True,\n                seg=input_seg,\n            )\n        else:\n            sample, intermediates = inferer.get_likelihood(\n                inputs=input,\n                autoencoder_model=stage_1,\n                diffusion_model=stage_2,\n                scheduler=scheduler,\n                save_intermediates=True,\n             
   resample_latent_likelihoods=True,\n            )\n        self.assertEqual(len(intermediates), 10)\n        self.assertEqual(intermediates[0].shape[2:], input_shape[2:])\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_prediction_shape_conditioned_concat(\n        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        stage_2_params = stage_2_params.copy()\n        n_concat_channel = 3\n        stage_2_params[\"in_channels\"] = stage_2_params[\"in_channels\"] + n_concat_channel\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        input = torch.randn(input_shape).to(device)\n        noise = torch.randn(latent_shape).to(device)\n        conditioning_shape = list(latent_shape)\n        conditioning_shape[1] = n_concat_channel\n        conditioning = torch.randn(conditioning_shape).to(device)\n\n        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:\n            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n            scheduler.set_timesteps(num_inference_steps=10)\n\n            timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()\n\n            if 
dm_model_type == \"SPADEDiffusionModelUNet\":\n                input_shape_seg = list(input_shape)\n                if \"label_nc\" in stage_2_params.keys():\n                    input_shape_seg[1] = stage_2_params[\"label_nc\"]\n                else:\n                    input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n                input_seg = torch.randn(input_shape_seg).to(device)\n                prediction = inferer(\n                    inputs=input,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    noise=noise,\n                    timesteps=timesteps,\n                    condition=conditioning,\n                    mode=\"concat\",\n                    seg=input_seg,\n                )\n            else:\n                prediction = inferer(\n                    inputs=input,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    noise=noise,\n                    timesteps=timesteps,\n                    condition=conditioning,\n                    mode=\"concat\",\n                )\n            self.assertEqual(prediction.shape, latent_shape)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample_shape_conditioned_concat(\n        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        stage_2_params = stage_2_params.copy()\n        n_concat_channel = 3\n        stage_2_params[\"in_channels\"] = stage_2_params[\"in_channels\"] + n_concat_channel\n        if dm_model_type == 
\"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        noise = torch.randn(latent_shape).to(device)\n        conditioning_shape = list(latent_shape)\n        conditioning_shape[1] = n_concat_channel\n        conditioning = torch.randn(conditioning_shape).to(device)\n\n        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:\n            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n            scheduler.set_timesteps(num_inference_steps=10)\n\n            if dm_model_type == \"SPADEDiffusionModelUNet\":\n                input_shape_seg = list(input_shape)\n                if \"label_nc\" in stage_2_params.keys():\n                    input_shape_seg[1] = stage_2_params[\"label_nc\"]\n                else:\n                    input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n                input_seg = torch.randn(input_shape_seg).to(device)\n                sample = inferer.sample(\n                    input_noise=noise,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    scheduler=scheduler,\n                    conditioning=conditioning,\n                    mode=\"concat\",\n                    seg=input_seg,\n                )\n            else:\n                sample = inferer.sample(\n                    input_noise=noise,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    scheduler=scheduler,\n                    conditioning=conditioning,\n                    mode=\"concat\",\n                )\n            self.assertEqual(sample.shape, 
input_shape)\n\n    @parameterized.expand(TEST_CASES_DIFF_SHAPES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_different_latents(\n        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        input = torch.randn(input_shape).to(device)\n        noise = torch.randn(latent_shape).to(device)\n        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:\n            # We infer the VAE shape\n            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params[\"channels\"]) - 1)) for i in input_shape[2:]]\n            inferer = LatentDiffusionInferer(\n                scheduler=scheduler,\n                scale_factor=1.0,\n                ldm_latent_shape=list(latent_shape[2:]),\n                autoencoder_latent_shape=autoencoder_latent_shape,\n            )\n            scheduler.set_timesteps(num_inference_steps=10)\n\n            timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()\n\n            if dm_model_type == \"SPADEDiffusionModelUNet\":\n                input_shape_seg = list(input_shape)\n                if \"label_nc\" in stage_2_params.keys():\n              
      input_shape_seg[1] = stage_2_params[\"label_nc\"]\n                else:\n                    input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n                input_seg = torch.randn(input_shape_seg).to(device)\n                prediction = inferer(\n                    inputs=input,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    noise=noise,\n                    timesteps=timesteps,\n                    seg=input_seg,\n                )\n            else:\n                prediction = inferer(\n                    inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps\n                )\n            self.assertEqual(prediction.shape, latent_shape)\n\n    @parameterized.expand(TEST_CASES_DIFF_SHAPES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample_shape_different_latents(\n        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape\n    ):\n        stage_1 = None\n\n        if ae_model_type == \"AutoencoderKL\":\n            stage_1 = AutoencoderKL(**autoencoder_params)\n        if ae_model_type == \"VQVAE\":\n            stage_1 = VQVAE(**autoencoder_params)\n        if ae_model_type == \"SPADEAutoencoderKL\":\n            stage_1 = SPADEAutoencoderKL(**autoencoder_params)\n        if dm_model_type == \"SPADEDiffusionModelUNet\":\n            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)\n        else:\n            stage_2 = DiffusionModelUNet(**stage_2_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        noise = torch.randn(latent_shape).to(device)\n        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:\n            # We infer the VAE shape\n            if ae_model_type == 
\"VQVAE\":\n                autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params[\"channels\"]))) for i in input_shape[2:]]\n            else:\n                autoencoder_latent_shape = [\n                    i // (2 ** (len(autoencoder_params[\"channels\"]) - 1)) for i in input_shape[2:]\n                ]\n\n            inferer = LatentDiffusionInferer(\n                scheduler=scheduler,\n                scale_factor=1.0,\n                ldm_latent_shape=list(latent_shape[2:]),\n                autoencoder_latent_shape=autoencoder_latent_shape,\n            )\n            scheduler.set_timesteps(num_inference_steps=10)\n\n            if dm_model_type == \"SPADEDiffusionModelUNet\" or ae_model_type == \"SPADEAutoencoderKL\":\n                input_shape_seg = list(input_shape)\n                if \"label_nc\" in stage_2_params.keys():\n                    input_shape_seg[1] = stage_2_params[\"label_nc\"]\n                else:\n                    input_shape_seg[1] = autoencoder_params[\"label_nc\"]\n                input_seg = torch.randn(input_shape_seg).to(device)\n                prediction, _ = inferer.sample(\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    input_noise=noise,\n                    save_intermediates=True,\n                    seg=input_seg,\n                )\n            else:\n                prediction = inferer.sample(\n                    autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False\n                )\n            self.assertEqual(prediction.shape, input_shape)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_incompatible_spade_setup(self):\n        stage_1 = SPADEAutoencoderKL(\n            spatial_dims=2,\n            label_nc=6,\n            in_channels=1,\n            out_channels=1,\n            channels=(4, 4),\n            latent_channels=3,\n            attention_levels=[False, 
False],\n            num_res_blocks=1,\n            with_encoder_nonlocal_attn=False,\n            with_decoder_nonlocal_attn=False,\n            norm_num_groups=4,\n        )\n        stage_2 = SPADEDiffusionModelUNet(\n            spatial_dims=2,\n            label_nc=3,\n            in_channels=3,\n            out_channels=3,\n            channels=[4, 4],\n            norm_num_groups=4,\n            attention_levels=[False, False],\n            num_res_blocks=1,\n            num_head_channels=4,\n        )\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n        noise = torch.randn((1, 3, 4, 4)).to(device)\n        input_seg = torch.randn((1, 3, 8, 8)).to(device)\n\n        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:\n            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)\n            scheduler.set_timesteps(num_inference_steps=10)\n\n            with self.assertRaises(ValueError):\n                _ = inferer.sample(\n                    input_noise=noise,\n                    autoencoder_model=stage_1,\n                    diffusion_model=stage_2,\n                    scheduler=scheduler,\n                    seg=input_seg,\n                )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/inferers/test_patch_inferer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\nfrom torch.nn.functional import avg_pool2d, pad\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.inferers import AvgMerger, PatchInferer, SlidingWindowSplitter\nfrom tests.test_utils import assert_allclose\n\nTENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32)\nTENSOR_2x2 = avg_pool2d(TENSOR_4x4, 2, 2)\n\n# no-overlapping 2x2 patches\nTEST_CASE_0_TENSOR = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),\n    lambda x: x,\n    TENSOR_4x4,\n]\n\n# no-overlapping 2x2 patches using all default parameters (except for splitter)\nTEST_CASE_1_TENSOR = [TENSOR_4x4, dict(splitter=SlidingWindowSplitter(patch_size=(2, 2))), lambda x: x, TENSOR_4x4]\n\n# divisible batch_size\nTEST_CASE_2_TENSOR = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, batch_size=2),\n    lambda x: x,\n    TENSOR_4x4,\n]\n\n# non-divisible batch_size\nTEST_CASE_3_TENSOR = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, batch_size=3),\n    lambda x: x,\n    TENSOR_4x4,\n]\n\n# patches that are already split (Splitter should be None)\nTEST_CASE_4_SPLIT_LIST = [\n    [\n        (TENSOR_4x4[..., :2, 
:2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    dict(splitter=None, merger_cls=AvgMerger, merged_shape=(2, 3, 4, 4)),\n    lambda x: x,\n    TENSOR_4x4,\n]\n\n# using all default parameters (patches are already split)\nTEST_CASE_5_SPLIT_LIST = [\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    dict(merger_cls=AvgMerger, merged_shape=(2, 3, 4, 4)),\n    lambda x: x,\n    TENSOR_4x4,\n]\n\n# output smaller than input patches\nTEST_CASE_6_SMALLER = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),\n    lambda x: torch.mean(x, dim=(-1, -2), keepdim=True),\n    TENSOR_2x2,\n]\n\n# preprocess patches\nTEST_CASE_7_PREPROCESS = [\n    TENSOR_4x4,\n    dict(\n        splitter=SlidingWindowSplitter(patch_size=(2, 2)),\n        merger_cls=AvgMerger,\n        preprocessing=lambda x: 2 * x,\n        postprocessing=None,\n    ),\n    lambda x: x,\n    2 * TENSOR_4x4,\n]\n\n# preprocess patches\nTEST_CASE_8_POSTPROCESS = [\n    TENSOR_4x4,\n    dict(\n        splitter=SlidingWindowSplitter(patch_size=(2, 2)),\n        merger_cls=AvgMerger,\n        preprocessing=None,\n        postprocessing=lambda x: 4 * x,\n    ),\n    lambda x: x,\n    4 * TENSOR_4x4,\n]\n\n# str merger as the class name\nTEST_CASE_9_STR_MERGER = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=\"AvgMerger\"),\n    lambda x: x,\n    TENSOR_4x4,\n]\n\n# str merger as dotted patch\nTEST_CASE_10_STR_MERGER = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=\"monai.inferers.merger.AvgMerger\"),\n    lambda x: x,\n    TENSOR_4x4,\n]\n\n# non-divisible patch_size leading to larger image (without matching spatial 
shape)\nTEST_CASE_11_PADDING = [\n    TENSOR_4x4,\n    dict(\n        splitter=SlidingWindowSplitter(patch_size=(2, 3), pad_mode=\"constant\", pad_value=0.0),\n        merger_cls=AvgMerger,\n        match_spatial_shape=False,\n    ),\n    lambda x: x,\n    pad(TENSOR_4x4, (0, 2), value=0.0),\n]\n\n# non-divisible patch_size with matching spatial shapes\nTEST_CASE_12_MATCHING = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 3), pad_mode=None), merger_cls=AvgMerger),\n    lambda x: x,\n    pad(TENSOR_4x4[..., :3], (0, 1), value=float(\"nan\")),\n]\n\n# non-divisible patch_size with matching spatial shapes\nTEST_CASE_13_PADDING_MATCHING = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 3)), merger_cls=AvgMerger),\n    lambda x: x,\n    TENSOR_4x4,\n]\n\n# multi-threading\nTEST_CASE_14_MULTITHREAD_BUFFER = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=2),\n    lambda x: x,\n    TENSOR_4x4,\n]\n\n# multi-threading with batch\nTEST_CASE_15_MULTITHREADD_BUFFER = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=4, batch_size=4),\n    lambda x: x,\n    TENSOR_4x4,\n]\n\n# list of tensor output\nTEST_CASE_0_LIST_TENSOR = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),\n    lambda x: (x, x),\n    (TENSOR_4x4, TENSOR_4x4),\n]\n\n# list of tensor output\nTEST_CASE_0_DICT = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),\n    lambda x: {\"model_output\": x},\n    {\"model_output\": TENSOR_4x4},\n]\n\n# ----------------------------------------------------------------------------\n# Error test cases\n# ----------------------------------------------------------------------------\n# invalid splitter: not callable\nTEST_CASE_ERROR_0 = [None, dict(splitter=1), TypeError]\n# invalid merger: 
non-existent merger\nTEST_CASE_ERROR_1 = [\n    None,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=\"NonExistent\"),\n    ValueError,\n]\n# invalid merger: callable\nTEST_CASE_ERROR_2 = [None, dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=lambda x: x), TypeError]\n# invalid merger: Merger object\nTEST_CASE_ERROR_3 = [\n    None,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger(merged_shape=(1, 1))),\n    TypeError,\n]\n# invalid merger: list of Merger class\nTEST_CASE_ERROR_4 = [\n    None,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=[AvgMerger, AvgMerger]),\n    TypeError,\n]\n# invalid preprocessing\nTEST_CASE_ERROR_5 = [None, dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), preprocessing=1), TypeError]\n# invalid postprocessing\nTEST_CASE_ERROR_6 = [None, dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), postprocessing=1), TypeError]\n# provide splitter when data is already split (splitter is not None)\nTEST_CASE_ERROR_7 = [\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2))),\n    ValueError,\n]\n# invalid inputs: split patches tensor without location\nTEST_CASE_ERROR_8 = [torch.zeros(2, 2), dict(splitter=None), ValueError]\n# invalid inputs: split patches MetaTensor without location metadata\nTEST_CASE_ERROR_9 = [MetaTensor(torch.zeros(2, 2)), dict(splitter=None), ValueError]\n# merged_shape is not provided for the merger\nTEST_CASE_ERROR_10 = [\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2))),\n    ValueError,\n]\n\n\nclass 
PatchInfererTests(unittest.TestCase):\n    @parameterized.expand(\n        [\n            TEST_CASE_0_TENSOR,\n            TEST_CASE_1_TENSOR,\n            TEST_CASE_2_TENSOR,\n            TEST_CASE_3_TENSOR,\n            TEST_CASE_4_SPLIT_LIST,\n            TEST_CASE_5_SPLIT_LIST,\n            TEST_CASE_6_SMALLER,\n            TEST_CASE_7_PREPROCESS,\n            TEST_CASE_8_POSTPROCESS,\n            TEST_CASE_9_STR_MERGER,\n            TEST_CASE_10_STR_MERGER,\n            TEST_CASE_11_PADDING,\n            TEST_CASE_12_MATCHING,\n            TEST_CASE_13_PADDING_MATCHING,\n            TEST_CASE_14_MULTITHREAD_BUFFER,\n            TEST_CASE_15_MULTITHREADD_BUFFER,\n        ]\n    )\n    def test_patch_inferer_tensor(self, inputs, arguments, network, expected):\n        inferer = PatchInferer(**arguments)\n        output = inferer(inputs=inputs, network=network)\n        assert_allclose(output, expected)\n\n    @parameterized.expand([TEST_CASE_0_LIST_TENSOR])\n    def test_patch_inferer_list_tensor(self, inputs, arguments, network, expected):\n        inferer = PatchInferer(**arguments)\n        output = inferer(inputs=inputs, network=network)\n        for out, exp in zip(output, expected):\n            assert_allclose(out, exp)\n\n    @parameterized.expand([TEST_CASE_0_DICT])\n    def test_patch_inferer_dict(self, inputs, arguments, network, expected):\n        inferer = PatchInferer(**arguments)\n        output = inferer(inputs=inputs, network=network)\n        for k in expected:\n            assert_allclose(output[k], expected[k])\n\n    @parameterized.expand(\n        [\n            TEST_CASE_ERROR_0,\n            TEST_CASE_ERROR_1,\n            TEST_CASE_ERROR_2,\n            TEST_CASE_ERROR_3,\n            TEST_CASE_ERROR_4,\n            TEST_CASE_ERROR_5,\n            TEST_CASE_ERROR_6,\n            TEST_CASE_ERROR_7,\n            TEST_CASE_ERROR_8,\n            TEST_CASE_ERROR_9,\n        ]\n    )\n    def test_patch_inferer_errors(self, inputs, arguments, 
expected_error):\n        with self.assertRaises(expected_error):\n            PatchInferer(**arguments)\n            inferer = PatchInferer(**arguments)\n            inferer(inputs=inputs, network=lambda x: x)\n\n\n# ----------------------------------------------------------------------------\n# Error test cases with conditionign\n# ----------------------------------------------------------------------------\n\n# no-overlapping 2x2 patches\nTEST_CASE_0_TENSOR_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),\n    lambda x, condition: x + condition,\n    TENSOR_4x4 * 2,\n]\n\n# no-overlapping 2x2 patches using all default parameters (except for splitter)\nTEST_CASE_1_TENSOR_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2))),\n    lambda x, condition: x + condition,\n    TENSOR_4x4 * 2,\n]\n\n# divisible batch_size\nTEST_CASE_2_TENSOR_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, batch_size=2),\n    lambda x, condition: x + condition,\n    TENSOR_4x4 * 2,\n]\n\n# non-divisible batch_size\nTEST_CASE_3_TENSOR_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, batch_size=3),\n    lambda x, condition: x + condition,\n    TENSOR_4x4 * 2,\n]\n\n# patches that are already split (Splitter should be None)\nTEST_CASE_4_SPLIT_LIST_c = [\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    dict(splitter=None, merger_cls=AvgMerger, merged_shape=(2, 3, 4, 4)),\n    lambda x, condition: x + condition,\n    TENSOR_4x4 * 2,\n]\n\n# using all default parameters (patches are already split)\nTEST_CASE_5_SPLIT_LIST_c = [\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, 
:2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    dict(merger_cls=AvgMerger, merged_shape=(2, 3, 4, 4)),\n    lambda x, condition: x + condition,\n    TENSOR_4x4 * 2,\n]\n\n# output smaller than input patches\nTEST_CASE_6_SMALLER_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),\n    lambda x, condition: torch.mean(x, dim=(-1, -2), keepdim=True) + torch.mean(condition, dim=(-1, -2), keepdim=True),\n    TENSOR_2x2 * 2,\n]\n\n# preprocess patches\nTEST_CASE_7_PREPROCESS_c = [\n    TENSOR_4x4,\n    dict(\n        splitter=SlidingWindowSplitter(patch_size=(2, 2)),\n        merger_cls=AvgMerger,\n        preprocessing=lambda x: 2 * x,\n        postprocessing=None,\n    ),\n    lambda x, condition: x + condition,\n    2 * TENSOR_4x4 + TENSOR_4x4,\n]\n\n# postprocess patches\nTEST_CASE_8_POSTPROCESS_c = [\n    TENSOR_4x4,\n    dict(\n        splitter=SlidingWindowSplitter(patch_size=(2, 2)),\n        merger_cls=AvgMerger,\n        preprocessing=None,\n        postprocessing=lambda x: 4 * x,\n    ),\n    lambda x, condition: x + condition,\n    4 * TENSOR_4x4 * 2,\n]\n\n# str merger as the class name\nTEST_CASE_9_STR_MERGER_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=\"AvgMerger\"),\n    lambda x, condition: x + condition,\n    TENSOR_4x4 * 2,\n]\n\n# str merger as dotted path\nTEST_CASE_10_STR_MERGER_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=\"monai.inferers.merger.AvgMerger\"),\n    lambda x, condition: x + condition,\n    TENSOR_4x4 * 2,\n]\n\n# non-divisible patch_size leading to larger image (without matching spatial shape)\nTEST_CASE_11_PADDING_c = [\n    TENSOR_4x4,\n    dict(\n        splitter=SlidingWindowSplitter(patch_size=(2, 3), pad_mode=\"constant\", pad_value=0.0),\n        merger_cls=AvgMerger,\n        match_spatial_shape=False,\n    ),\n    lambda x, condition: x + 
condition,\n    pad(TENSOR_4x4, (0, 2), value=0.0) * 2,\n]\n\n# non-divisible patch_size with matching spatial shapes\nTEST_CASE_12_MATCHING_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 3), pad_mode=None), merger_cls=AvgMerger),\n    lambda x, condition: x + condition,\n    pad(TENSOR_4x4[..., :3], (0, 1), value=float(\"nan\")) * 2,\n]\n\n# non-divisible patch_size with matching spatial shapes\nTEST_CASE_13_PADDING_MATCHING_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 3)), merger_cls=AvgMerger),\n    lambda x, condition: x + condition,\n    TENSOR_4x4 * 2,\n]\n\n# multi-threading\nTEST_CASE_14_MULTITHREAD_BUFFER_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=2),\n    lambda x, condition: x + condition,\n    TENSOR_4x4 * 2,\n]\n\n# multi-threading with batch\nTEST_CASE_15_MULTITHREADD_BUFFER_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=4, batch_size=4),\n    lambda x, condition: x + condition,\n    TENSOR_4x4 * 2,\n]\n\n# list of tensor output\nTEST_CASE_0_LIST_TENSOR_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),\n    lambda x, condition: (x + condition, x + condition),\n    (TENSOR_4x4 * 2, TENSOR_4x4 * 2),\n]\n\n# list of tensor output\nTEST_CASE_0_DICT_c = [\n    TENSOR_4x4,\n    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger),\n    lambda x, condition: {\"model_output\": x + condition},\n    {\"model_output\": TENSOR_4x4 * 2},\n]\n\n\nclass PatchInfererTestsCond(unittest.TestCase):\n    @parameterized.expand(\n        [\n            TEST_CASE_0_TENSOR_c,\n            TEST_CASE_1_TENSOR_c,\n            TEST_CASE_2_TENSOR_c,\n            TEST_CASE_3_TENSOR_c,\n            TEST_CASE_4_SPLIT_LIST_c,\n            TEST_CASE_5_SPLIT_LIST_c,\n            
TEST_CASE_6_SMALLER_c,\n            TEST_CASE_7_PREPROCESS_c,\n            TEST_CASE_8_POSTPROCESS_c,\n            TEST_CASE_9_STR_MERGER_c,\n            TEST_CASE_10_STR_MERGER_c,\n            TEST_CASE_11_PADDING_c,\n            TEST_CASE_12_MATCHING_c,\n            TEST_CASE_13_PADDING_MATCHING_c,\n            TEST_CASE_14_MULTITHREAD_BUFFER_c,\n            TEST_CASE_15_MULTITHREADD_BUFFER_c,\n        ]\n    )\n    def test_patch_inferer_tensor(self, inputs, arguments, network, expected):\n        if isinstance(inputs, list):  # case 4 and 5\n            condition = [(x[0].clone(), x[1]) for x in inputs]\n        else:\n            condition = inputs.clone()\n        inferer = PatchInferer(**arguments)\n        output = inferer(inputs=inputs, network=network, condition=condition)\n        assert_allclose(output, expected)\n\n    @parameterized.expand([TEST_CASE_0_LIST_TENSOR_c])\n    def test_patch_inferer_list_tensor(self, inputs, arguments, network, expected):\n        if isinstance(inputs, list):  # case 4 and 5\n            condition = [(x[0].clone(), x[1]) for x in inputs]\n        else:\n            condition = inputs.clone()\n        inferer = PatchInferer(**arguments)\n        output = inferer(inputs=inputs, network=network, condition=condition)\n        for out, exp in zip(output, expected):\n            assert_allclose(out, exp)\n\n    @parameterized.expand([TEST_CASE_0_DICT_c])\n    def test_patch_inferer_dict(self, inputs, arguments, network, expected):\n        if isinstance(inputs, list):  # case 4 and 5\n            condition = [(x[0].clone(), x[1]) for x in inputs]\n        else:\n            condition = inputs.clone()\n        inferer = PatchInferer(**arguments)\n        output = inferer(inputs=inputs, network=network, condition=condition)\n        for k in expected:\n            assert_allclose(output[k], expected[k])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/inferers/test_saliency_inferer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.inferers import SaliencyInferer\nfrom monai.networks.nets import DenseNet\nfrom monai.visualize.visualizer import default_upsampler\n\nTEST_CASE_1 = [\"CAM\"]\n\nTEST_CASE_2 = [\"GradCAM\"]\n\nTEST_CASE_3 = [\"GradCAMpp\"]\n\n\nclass TestSaliencyInferer(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_shape(self, cam_name):\n        model = DenseNet(\n            spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)\n        )\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n\n        image = torch.rand((2, 1, 6, 6, 6), device=device)\n        target_layer = \"class_layers.relu\"\n        fc_layer = \"class_layers.out\"\n        if cam_name == \"CAM\":\n            inferer = SaliencyInferer(cam_name, target_layer, None, fc_layer, upsampler=default_upsampler)\n            result = inferer(inputs=image, network=model, layer_idx=-1)\n        else:\n            inferer = SaliencyInferer(cam_name, target_layer, None, upsampler=default_upsampler)\n            result = inferer(image, model, -1, retain_graph=False)\n\n        self.assertTupleEqual(result.shape, (2, 1, 6, 6, 6))\n\n\nif 
__name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/inferers/test_slice_inferer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.inferers import SliceInferer\nfrom monai.networks.nets import UNet\n\nTEST_CASES = [\"0\", \"1\", \"2\"]\n\n\nclass TestSliceInferer(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, spatial_dim):\n        spatial_dim = int(spatial_dim)\n\n        model = UNet(\n            spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16), strides=(2, 2), num_res_units=2\n        )\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n\n        # Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)\n        input_volume = torch.ones(1, 1, 64, 256, 256, device=device)\n\n        # Remove spatial dim to slide across from the roi_size\n        roi_size = list(input_volume.shape[2:])\n        roi_size.pop(spatial_dim)\n\n        # Initialize and run inferer\n        inferer = SliceInferer(roi_size=roi_size, spatial_dim=spatial_dim, sw_batch_size=1, cval=-1)\n        result = inferer(input_volume, model)\n\n        self.assertTupleEqual(result.shape, input_volume.shape)\n\n        # test that the inferer can be run multiple times\n        result = inferer(input_volume, model)\n\n\nclass TestSliceInfererCond(unittest.TestCase):\n\n    
@parameterized.expand(TEST_CASES)\n    def test_shape(self, spatial_dim):\n        spatial_dim = int(spatial_dim)\n\n        model = UNet(\n            spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16), strides=(2, 2), num_res_units=2\n        )\n\n        # overwrite the forward method to test the inferer with a model that takes a condition\n        model.forward = lambda x, condition: x + condition if condition is not None else x\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n\n        # Initialize a dummy 3D tensor volume with shape (N,C,D,H,W)\n        input_volume = torch.ones(1, 1, 64, 256, 256, device=device)\n        condition_volume = torch.ones(1, 1, 64, 256, 256, device=device)\n        # Remove spatial dim to slide across from the roi_size\n        roi_size = list(input_volume.shape[2:])\n        roi_size.pop(spatial_dim)\n\n        # Initialize and run inferer\n        inferer = SliceInferer(roi_size=roi_size, spatial_dim=spatial_dim, sw_batch_size=1, cval=-1)\n        result = inferer(input_volume, model, condition=condition_volume)\n\n        self.assertTupleEqual(result.shape, input_volume.shape)\n        self.assertEqual(result.sum(), (input_volume + condition_volume).sum())\n        # test that the inferer can be run multiple times\n        result = inferer(input_volume, model, condition=condition_volume)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/inferers/test_sliding_window_inference.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport itertools\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.utils import list_data_collate\nfrom monai.inferers import SlidingWindowInferer, SlidingWindowInfererAdapt, sliding_window_inference\nfrom monai.inferers.utils import _compute_coords\nfrom monai.utils import optional_import\nfrom tests.test_utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda, test_is_quick\n\n_, has_tqdm = optional_import(\"tqdm\")\n\nTEST_CASES = [\n    [(2, 3, 16), (4,), 3, 0.25, \"constant\", torch.device(\"cpu:0\")],  # 1D small roi\n    [(2, 3, 16, 15, 7, 9), 4, 3, 0.25, \"constant\", torch.device(\"cpu:0\")],  # 4D small roi\n    [(1, 3, 16, 15, 7), (4, -1, 7), 3, 0.25, \"constant\", torch.device(\"cpu:0\")],  # 3D small roi\n    [(2, 3, 16, 15, 7), (4, -1, 7), 3, 0.25, \"constant\", torch.device(\"cpu:0\")],  # 3D small roi\n    [(3, 3, 16, 15, 7), (4, -1, 7), 3, 0.25, \"constant\", torch.device(\"cpu:0\")],  # 3D small roi\n    [(2, 3, 16, 15, 7), (4, -1, 7), 3, 0.25, \"constant\", torch.device(\"cpu:0\")],  # 3D small roi\n    [(1, 3, 16, 15, 7), (4, 10, 7), 3, 0.25, \"constant\", torch.device(\"cpu:0\")],  # 3D small roi\n    [(1, 3, 16, 15, 7), (20, 22, 23), 10, 0.25, \"constant\", torch.device(\"cpu:0\")],  # 3D large roi\n    [(2, 3, 15, 7), (2, 6), 1000, 0.25, 
\"constant\", torch.device(\"cpu:0\")],  # 2D small roi, large batch\n    [(1, 3, 16, 7), (80, 50), 7, 0.25, \"constant\", torch.device(\"cpu:0\")],  # 2D large roi\n    [(1, 3, 16, 15, 7), (20, 22, 23), 10, 0.5, \"constant\", torch.device(\"cpu:0\")],  # 3D large overlap\n    [(1, 3, 16, 7), (80, 50), 7, 0.5, \"gaussian\", torch.device(\"cpu:0\")],  # 2D large overlap, gaussian\n    [(1, 3, 16, 15, 7), (4, 10, 7), 3, 0.25, \"gaussian\", torch.device(\"cpu:0\")],  # 3D small roi, gaussian\n    [(3, 3, 16, 15, 7), (4, 10, 7), 3, 0.25, \"gaussian\", torch.device(\"cpu:0\")],  # 3D small roi, gaussian\n    [(1, 3, 16, 15, 7), (4, 10, 7), 3, 0.25, \"gaussian\", torch.device(\"cuda:0\")],  # test inference on gpu if availabe\n    [(1, 3, 16, 15, 7), (4, 1, 7), 3, 0.25, \"constant\", torch.device(\"cpu:0\")],  # 3D small roi\n    [(5, 3, 16, 15, 7), (4, 1, 7), 3, 0.25, \"constant\", torch.device(\"cpu:0\")],  # 3D small roi\n]\n\n_devices = [[\"cpu\", \"cuda:0\"]] if torch.cuda.is_available() else [[\"cpu\"]]\n_windows = [\n    [(2, 3, 10, 11), (7, 10), 0.8, 5],\n    [(2, 3, 10, 11), (15, 12), 0, 2],\n    [(2, 3, 10, 11), (10, 11), 0, 3],\n    [(2, 3, 511, 237), (96, 80), 0.4, 5],\n    [(2, 3, 512, 245), (96, 80), 0, 5],\n    [(2, 3, 512, 245), (512, 80), 0.125, 5],\n    [(2, 3, 10, 11, 12), (7, 8, 10), 0.2, 2],\n]\nif not test_is_quick():\n    _windows += [\n        [(2, 1, 125, 512, 200), (96, 97, 98), (0.4, 0.32, 0), 20],\n        [(2, 1, 10, 512, 200), (96, 97, 98), (0.4, 0.12, 0), 21],\n        [(2, 3, 100, 100, 200), (50, 50, 100), 0, 8],\n    ]\n\nBUFFER_CASES: list = []\nfor x in _windows:\n    for s in (1, 3, 4):\n        for d in (-1, 0, 1):\n            BUFFER_CASES.extend([x, s, d, dev] for dev in itertools.product(*_devices * 3))\n\n\nclass TestSlidingWindowInference(unittest.TestCase):\n    @parameterized.expand(BUFFER_CASES)\n    def test_buffers(self, size_params, buffer_steps, buffer_dim, device_params):\n        def mult_two(patch, *args, **kwargs):\n   
         return 2.0 * patch\n\n        img_size, roi_size, overlap, sw_batch_size = size_params\n        img_device, device, sw_device = device_params\n        dtype = [torch.float, torch.double][roi_size[0] % 2]  # test different input dtype\n        mode = [\"constant\", \"gaussian\"][img_size[1] % 2]\n        image = torch.randint(0, 255, size=img_size, dtype=dtype, device=img_device)\n        sw = sliding_window_inference(\n            image,\n            roi_size,\n            sw_batch_size,\n            mult_two,\n            overlap,\n            mode=mode,\n            sw_device=sw_device,\n            device=device,\n            buffer_steps=buffer_steps,\n            buffer_dim=buffer_dim,\n        )\n        max_diff = torch.max(torch.abs(image.to(sw) - 0.5 * sw)).item()\n        self.assertGreater(0.001, max_diff)\n\n    @parameterized.expand(TEST_CASES)\n    def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size, overlap, mode, device):\n        n_total = np.prod(image_shape)\n        if mode == \"constant\":\n            inputs = torch.arange(n_total, dtype=torch.float).reshape(*image_shape)\n        else:\n            inputs = torch.ones(*image_shape, dtype=torch.float)\n        if device.type == \"cuda\" and not torch.cuda.is_available():\n            device = torch.device(\"cpu:0\")\n\n        def compute(data):\n            return data + 1\n\n        if mode == \"constant\":\n            expected_val = np.arange(n_total, dtype=np.float32).reshape(*image_shape) + 1.0\n        else:\n            expected_val = np.ones(image_shape, dtype=np.float32) + 1.0\n        result = sliding_window_inference(inputs.to(device), roi_shape, sw_batch_size, compute, overlap, mode=mode)\n        np.testing.assert_string_equal(device.type, result.device.type)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_val)\n\n        result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap, mode)(inputs.to(device), compute)\n        
np.testing.assert_string_equal(device.type, result.device.type)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_val)\n\n    @parameterized.expand([[x] for x in TEST_TORCH_AND_META_TENSORS])\n    def test_default_device(self, data_type):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = data_type(torch.ones((3, 16, 15, 7))).to(device=device)\n        inputs = list_data_collate([inputs])  # make a proper batch\n        roi_shape = (4, 10, 7)\n        sw_batch_size = 10\n\n        def compute(data):\n            return data + 1\n\n        inputs.requires_grad = True\n        result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute)\n        self.assertTrue(result.requires_grad)\n        np.testing.assert_string_equal(inputs.device.type, result.device.type)\n        expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val)\n\n    @parameterized.expand(list(itertools.product(TEST_TORCH_AND_META_TENSORS, (\"cpu\", \"cuda\"), (\"cpu\", \"cuda\", None))))\n    @skip_if_no_cuda\n    def test_sw_device(self, data_type, device, sw_device):\n        inputs = data_type(torch.ones((3, 16, 15, 7))).to(device=device)\n        inputs = list_data_collate([inputs])  # make a proper batch\n        roi_shape = (4, 10, 7)\n        sw_batch_size = 10\n\n        def compute(data):\n            self.assertEqual(data.device.type, sw_device or device)\n            return data + torch.tensor(1, device=sw_device or device)\n\n        result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute, sw_device=sw_device, device=\"cpu\")\n        np.testing.assert_string_equal(\"cpu\", result.device.type)\n        expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1\n        np.testing.assert_allclose(result.cpu().numpy(), expected_val)\n\n    def test_sigma(self):\n        device = \"cuda\" if 
torch.cuda.is_available() else \"cpu:0\"\n        inputs = torch.ones((1, 1, 7, 7)).to(device=device)\n        roi_shape = (3, 3)\n        sw_batch_size = 10\n\n        class _Pred:\n            add = 1\n\n            def compute(self, data):\n                self.add += 1\n                return data + self.add\n\n        result = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            _Pred().compute,\n            overlap=0.5,\n            padding_mode=\"constant\",\n            cval=-1,\n            mode=\"constant\",\n            sigma_scale=1.0,\n        )\n\n        expected = np.array(\n            [\n                [\n                    [\n                        [3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000],\n                        [3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000],\n                        [3.3333, 3.3333, 3.3333, 3.3333, 3.3333, 3.3333, 3.3333],\n                        [3.6667, 3.6667, 3.6667, 3.6667, 3.6667, 3.6667, 3.6667],\n                        [4.3333, 4.3333, 4.3333, 4.3333, 4.3333, 4.3333, 4.3333],\n                        [4.5000, 4.5000, 4.5000, 4.5000, 4.5000, 4.5000, 4.5000],\n                        [5.0000, 5.0000, 5.0000, 5.0000, 5.0000, 5.0000, 5.0000],\n                    ]\n                ]\n            ]\n        )\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n        result = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            _Pred().compute,\n            overlap=0.5,\n            padding_mode=\"constant\",\n            cval=-1,\n            mode=\"gaussian\",\n            sigma_scale=1.0,\n            progress=has_tqdm,\n        )\n        expected = np.array(\n            [\n                [\n                    [\n                        [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],\n                        [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],\n       
                 [3.3271625, 3.3271623, 3.3271623, 3.3271623, 3.3271623, 3.3271623, 3.3271625],\n                        [3.6728377, 3.6728377, 3.6728377, 3.6728377, 3.6728377, 3.6728377, 3.6728377],\n                        [4.3271623, 4.3271623, 4.3271627, 4.3271627, 4.3271627, 4.3271623, 4.3271623],\n                        [4.513757, 4.513757, 4.513757, 4.513757, 4.513757, 4.513757, 4.513757],\n                        [4.9999995, 5.0, 5.0, 5.0, 5.0, 5.0, 4.9999995],\n                    ]\n                ]\n            ]\n        )\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode=\"gaussian\", sigma_scale=1.0)(\n            inputs, _Pred().compute\n        )\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode=\"gaussian\", sigma_scale=[1.0, 1.0])(\n            inputs, _Pred().compute\n        )\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowInferer(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"gaussian\", sigma_scale=[1.0, 1.0], cache_roi_weight_map=True\n        )(inputs, _Pred().compute)\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n    def test_cval(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = torch.ones((1, 1, 3, 3)).to(device=device)\n        roi_shape = (5, 5)\n        sw_batch_size = 10\n\n        def compute(data):\n            return data + data.sum()\n\n        result = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            compute,\n            overlap=0.5,\n            padding_mode=\"constant\",\n            cval=-1,\n            mode=\"constant\",\n            sigma_scale=1.0,\n        )\n        expected 
= np.ones((1, 1, 3, 3)) * -6.0\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1)(inputs, compute)\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n    def test_args_kwargs(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = torch.ones((1, 1, 3, 3)).to(device=device)\n        t1 = torch.ones(1).to(device=device)\n        t2 = torch.ones(1).to(device=device)\n        roi_shape = (5, 5)\n        sw_batch_size = 10\n\n        def compute(data, test1, test2):\n            return data + test1 + test2\n\n        result = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            compute,\n            0.5,\n            \"constant\",\n            1.0,\n            \"constant\",\n            0.0,\n            device,\n            device,\n            has_tqdm,\n            None,\n            None,\n            None,\n            0,\n            False,\n            t1,\n            test2=t2,\n        )\n        expected = np.ones((1, 1, 3, 3)) + 2.0\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowInferer(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1, progress=has_tqdm\n        )(inputs, compute, t1, test2=t2)\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowInfererAdapt(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1, progress=has_tqdm\n        )(inputs, compute, t1, test2=t2)\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n    def test_multioutput(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = torch.ones((1, 6, 20, 20)).to(device=device)\n    
    roi_shape = (8, 8)\n        sw_batch_size = 10\n\n        def compute(data):\n            return data + 1, data[:, ::3, ::2, ::2] + 2, data[:, ::2, ::4, ::4] + 3\n\n        def compute_dict(data):\n            return {1: data + 1, 2: data[:, ::3, ::2, ::2] + 2, 3: data[:, ::2, ::4, ::4] + 3}\n\n        result = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            compute,\n            0.5,\n            \"constant\",\n            1.0,\n            \"constant\",\n            0.0,\n            device,\n            device,\n            has_tqdm,\n            None,\n        )\n        result_dict = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            compute_dict,\n            0.5,\n            \"constant\",\n            1.0,\n            \"constant\",\n            0.0,\n            device,\n            device,\n            has_tqdm,\n            None,\n        )\n        expected = (np.ones((1, 6, 20, 20)) + 1, np.ones((1, 2, 10, 10)) + 2, np.ones((1, 3, 5, 5)) + 3)\n        expected_dict = {1: np.ones((1, 6, 20, 20)) + 1, 2: np.ones((1, 2, 10, 10)) + 2, 3: np.ones((1, 3, 5, 5)) + 3}\n        for rr, ee in zip(result, expected):\n            np.testing.assert_allclose(rr.cpu().numpy(), ee, rtol=1e-4)\n        for rr, _ in zip(result_dict, expected_dict):\n            np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)\n\n        result = SlidingWindowInferer(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1, progress=has_tqdm\n        )(inputs, compute)\n        for rr, ee in zip(result, expected):\n            np.testing.assert_allclose(rr.cpu().numpy(), ee, rtol=1e-4)\n\n        result_dict = SlidingWindowInferer(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1, progress=has_tqdm\n        )(inputs, compute_dict)\n        for rr, _ in zip(result_dict, 
expected_dict):\n            np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)\n\n    def test_strict_shape_validation(self):\n        \"\"\"Test strict shape validation to ensure inputs match roi_size dimensions.\"\"\"\n        device = \"cpu\"\n        roi_size = (16, 16, 16)\n        sw_batch_size = 4\n\n        def predictor(data):\n            return data\n\n        # Case 1: Input has fewer dimensions than expected (e.g., missing Batch or Channel)\n        # 3D roi_size requires 5D input (B, C, D, H, W), giving 4D here.\n        inputs_4d = torch.randn((1, 16, 16, 16), device=device)\n        with self.assertRaisesRegex(ValueError, \"Inputs must have 5 dimensions\"):\n            sliding_window_inference(inputs_4d, roi_size, sw_batch_size, predictor)\n\n        # Case 2: Input is 3D (missing Batch AND Channel)\n        inputs_3d = torch.randn((16, 16, 16), device=device)\n        with self.assertRaisesRegex(ValueError, \"Inputs must have 5 dimensions\"):\n            sliding_window_inference(inputs_3d, roi_size, sw_batch_size, predictor)\n\n\nclass TestSlidingWindowInferenceCond(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_sliding_window_default(self, image_shape, roi_shape, sw_batch_size, overlap, mode, device):\n        n_total = np.prod(image_shape)\n        if mode == \"constant\":\n            inputs = torch.arange(n_total, dtype=torch.float).reshape(*image_shape)\n        else:\n            inputs = torch.ones(*image_shape, dtype=torch.float)\n        if device.type == \"cuda\" and not torch.cuda.is_available():\n            device = torch.device(\"cpu:0\")\n\n        # condition\n        condition = torch.ones(*image_shape, dtype=torch.float)\n\n        def compute(data, condition):\n            return data + condition\n\n        if mode == \"constant\":\n            expected_val = np.arange(n_total, dtype=np.float32).reshape(*image_shape) + 1.0\n        else:\n            expected_val = 
np.ones(image_shape, dtype=np.float32) + 1.0\n\n        result = sliding_window_inference(\n            inputs.to(device), roi_shape, sw_batch_size, compute, overlap, mode=mode, condition=condition.to(device)\n        )\n        np.testing.assert_string_equal(device.type, result.device.type)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_val)\n\n        result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap, mode)(\n            inputs.to(device), compute, condition=condition.to(device)\n        )\n        np.testing.assert_string_equal(device.type, result.device.type)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_val)\n\n    @parameterized.expand([[x] for x in TEST_TORCH_AND_META_TENSORS])\n    def test_default_device(self, data_type):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = data_type(torch.ones((3, 16, 15, 7))).to(device=device)\n        condition = torch.ones((3, 16, 15, 7)).to(device=device)\n        inputs = list_data_collate([inputs])  # make a proper batch\n        condition = list_data_collate([condition])  # make a proper batch\n        roi_shape = (4, 10, 7)\n        sw_batch_size = 10\n\n        def compute(data, condition):\n            return data + condition\n\n        inputs.requires_grad = True\n        result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute, condition=condition)\n        self.assertTrue(result.requires_grad)\n        np.testing.assert_string_equal(inputs.device.type, result.device.type)\n        expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val)\n\n    @parameterized.expand(list(itertools.product(TEST_TORCH_AND_META_TENSORS, (\"cpu\", \"cuda\"), (\"cpu\", \"cuda\", None))))\n    @skip_if_no_cuda\n    def test_sw_device(self, data_type, device, sw_device):\n        inputs = data_type(torch.ones((3, 16, 15, 
7))).to(device=device)\n        condition = torch.ones((3, 16, 15, 7)).to(device=device)\n        inputs = list_data_collate([inputs])  # make a proper batch\n        condition = list_data_collate([condition])  # make a proper batch\n        roi_shape = (4, 10, 7)\n        sw_batch_size = 10\n\n        def compute(data, condition):\n            self.assertEqual(data.device.type, sw_device or device)\n            self.assertEqual(condition.device.type, sw_device or device)\n\n            return data + condition\n\n        result = sliding_window_inference(\n            inputs, roi_shape, sw_batch_size, compute, sw_device=sw_device, device=\"cpu\", condition=condition\n        )\n        np.testing.assert_string_equal(\"cpu\", result.device.type)\n        expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1\n        np.testing.assert_allclose(result.cpu().numpy(), expected_val)\n\n    def test_sigma(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = torch.ones((1, 1, 7, 7)).to(device=device)\n        roi_shape = (3, 3)\n        sw_batch_size = 10\n\n        class _Pred:\n            add = 1\n\n            def compute(self, data):\n                self.add += 1\n                return data + self.add\n\n        result = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            _Pred().compute,\n            overlap=0.5,\n            padding_mode=\"constant\",\n            cval=-1,\n            mode=\"constant\",\n            sigma_scale=1.0,\n        )\n\n        expected = np.array(\n            [\n                [\n                    [\n                        [3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000],\n                        [3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000, 3.0000],\n                        [3.3333, 3.3333, 3.3333, 3.3333, 3.3333, 3.3333, 3.3333],\n                        [3.6667, 3.6667, 3.6667, 3.6667, 3.6667, 3.6667, 
3.6667],\n                        [4.3333, 4.3333, 4.3333, 4.3333, 4.3333, 4.3333, 4.3333],\n                        [4.5000, 4.5000, 4.5000, 4.5000, 4.5000, 4.5000, 4.5000],\n                        [5.0000, 5.0000, 5.0000, 5.0000, 5.0000, 5.0000, 5.0000],\n                    ]\n                ]\n            ]\n        )\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n        result = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            _Pred().compute,\n            overlap=0.5,\n            padding_mode=\"constant\",\n            cval=-1,\n            mode=\"gaussian\",\n            sigma_scale=1.0,\n            progress=has_tqdm,\n        )\n        expected = np.array(\n            [\n                [\n                    [\n                        [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],\n                        [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],\n                        [3.3271625, 3.3271623, 3.3271623, 3.3271623, 3.3271623, 3.3271623, 3.3271625],\n                        [3.6728377, 3.6728377, 3.6728377, 3.6728377, 3.6728377, 3.6728377, 3.6728377],\n                        [4.3271623, 4.3271623, 4.3271627, 4.3271627, 4.3271627, 4.3271623, 4.3271623],\n                        [4.513757, 4.513757, 4.513757, 4.513757, 4.513757, 4.513757, 4.513757],\n                        [4.9999995, 5.0, 5.0, 5.0, 5.0, 5.0, 4.9999995],\n                    ]\n                ]\n            ]\n        )\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode=\"gaussian\", sigma_scale=1.0)(\n            inputs, _Pred().compute\n        )\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode=\"gaussian\", sigma_scale=[1.0, 1.0])(\n            inputs, _Pred().compute\n        
)\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowInferer(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"gaussian\", sigma_scale=[1.0, 1.0], cache_roi_weight_map=True\n        )(inputs, _Pred().compute)\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n    def test_cval(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = torch.ones((1, 1, 3, 3)).to(device=device)\n        condition = torch.ones((1, 1, 3, 3)).to(device=device)\n        roi_shape = (5, 5)\n        sw_batch_size = 10\n\n        def compute(data, condition):\n            return data + data.sum() + condition\n\n        result = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            compute,\n            overlap=0.5,\n            padding_mode=\"constant\",\n            cval=-1,\n            mode=\"constant\",\n            sigma_scale=1.0,\n            condition=condition,\n        )\n        expected = np.ones((1, 1, 3, 3)) * -6.0 + 1.0\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1)(\n            inputs, compute, condition=condition\n        )\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n    def test_args_kwargs(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = torch.ones((1, 1, 3, 3)).to(device=device)\n        condition = torch.ones((1, 1, 3, 3)).to(device=device)\n        t1 = torch.ones(1).to(device=device)\n        t2 = torch.ones(1).to(device=device)\n        roi_shape = (5, 5)\n        sw_batch_size = 10\n\n        def compute(data, test1, test2, condition):\n            return data + test1 + test2 + condition\n\n        result = sliding_window_inference(\n            
inputs,\n            roi_shape,\n            sw_batch_size,\n            compute,\n            0.5,\n            \"constant\",\n            1.0,\n            \"constant\",\n            0.0,\n            device,\n            device,\n            has_tqdm,\n            None,\n            None,\n            None,\n            0,\n            False,\n            t1,\n            condition=condition,\n            test2=t2,\n        )\n        expected = np.ones((1, 1, 3, 3)) + 3.0\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowInferer(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1, progress=has_tqdm\n        )(inputs, compute, t1, condition=condition, test2=t2)\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n        result = SlidingWindowInfererAdapt(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1, progress=has_tqdm\n        )(inputs, compute, t1, condition=condition, test2=t2)\n        np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)\n\n    def test_multioutput(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu:0\"\n        inputs = torch.ones((1, 6, 20, 20)).to(device=device)\n        condition = torch.ones((1, 6, 20, 20)).to(device=device)\n        roi_shape = (8, 8)\n        sw_batch_size = 10\n\n        def compute(data, condition):\n            return (\n                data + 1 + condition,\n                data[:, ::3, ::2, ::2] + 2 + condition[:, ::3, ::2, ::2],\n                data[:, ::2, ::4, ::4] + 3 + condition[:, ::2, ::4, ::4],\n            )\n\n        def compute_dict(data, condition):\n            return {\n                1: data + 1 + condition,\n                2: data[:, ::3, ::2, ::2] + 2 + condition[:, ::3, ::2, ::2],\n                3: data[:, ::2, ::4, ::4] + 3 + condition[:, ::2, ::4, ::4],\n            }\n\n        result = 
sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            compute,\n            0.5,\n            \"constant\",\n            1.0,\n            \"constant\",\n            0.0,\n            device,\n            device,\n            has_tqdm,\n            None,\n            condition=condition,\n        )\n        result_dict = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            compute_dict,\n            0.5,\n            \"constant\",\n            1.0,\n            \"constant\",\n            0.0,\n            device,\n            device,\n            has_tqdm,\n            None,\n            condition=condition,\n        )\n        expected = (np.ones((1, 6, 20, 20)) + 2, np.ones((1, 2, 10, 10)) + 3, np.ones((1, 3, 5, 5)) + 4)\n        expected_dict = {1: np.ones((1, 6, 20, 20)) + 2, 2: np.ones((1, 2, 10, 10)) + 3, 3: np.ones((1, 3, 5, 5)) + 4}\n        for rr, ee in zip(result, expected):\n            np.testing.assert_allclose(rr.cpu().numpy(), ee, rtol=1e-4)\n        for rr, _ in zip(result_dict, expected_dict):\n            np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)\n\n        result = SlidingWindowInferer(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1, progress=has_tqdm\n        )(inputs, compute, condition=condition)\n        for rr, ee in zip(result, expected):\n            np.testing.assert_allclose(rr.cpu().numpy(), ee, rtol=1e-4)\n\n        result_dict = SlidingWindowInferer(\n            roi_shape, sw_batch_size, overlap=0.5, mode=\"constant\", cval=-1, progress=has_tqdm\n        )(inputs, compute_dict, condition=condition)\n        for rr, _ in zip(result_dict, expected_dict):\n            np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)\n\n    @parameterized.expand([(1,), (4,)])\n    def 
test_conditioned_branches_and_buffered_parity(self, sw_batch_size):\n        \"\"\"Validate conditioned parity between buffered and non-buffered flows.\n\n        Args:\n            sw_batch_size (int): Sliding-window batch size.\n\n        Returns:\n            None.\n\n        Raises:\n            AssertionError: If device, conditioning alignment, or output parity checks fail.\n        \"\"\"\n        inputs = torch.arange(1 * 1 * 10 * 8, dtype=torch.float).reshape(1, 1, 10, 8)\n        condition = inputs + 100.0\n        roi_shape = (4, 4)\n\n        def compute(data, condition):\n            \"\"\"Compute output for a conditioned patch.\n\n            Args:\n                data (torch.Tensor): Input patch tensor.\n                condition (torch.Tensor): Conditioning patch tensor aligned to ``data``.\n\n            Returns:\n                torch.Tensor: Element-wise ``data + condition``.\n\n            Raises:\n                AssertionError: If device placement or conditioning alignment checks fail.\n            \"\"\"\n            self.assertEqual(data.device.type, \"cpu\")\n            self.assertEqual(condition.device.type, \"cpu\")\n            torch.testing.assert_close(condition - data, torch.full_like(data, 100.0))\n            return data + condition\n\n        # Non-buffered flow.\n        result_non_buffered = sliding_window_inference(\n            inputs, roi_shape, sw_batch_size, compute, overlap=0.5, mode=\"constant\", condition=condition\n        )\n        # Buffered flow; should match the non-buffered output.\n        result_buffered = sliding_window_inference(\n            inputs,\n            roi_shape,\n            sw_batch_size,\n            compute,\n            overlap=0.5,\n            mode=\"constant\",\n            condition=condition,\n            buffer_steps=2,\n            buffer_dim=0,\n        )\n\n        expected = inputs + condition\n        torch.testing.assert_close(result_non_buffered, expected)\n        
torch.testing.assert_close(result_buffered, expected)\n        torch.testing.assert_close(result_buffered, result_non_buffered)\n\n\nclass TestSlidingWindowUtils(unittest.TestCase):\n    \"\"\"Tests for low-level sliding-window utility helpers.\n\n    Args:\n        None.\n\n    Returns:\n        None.\n\n    Raises:\n        None.\n    \"\"\"\n\n    def test_compute_coords_accepts_list_indices(self):\n        \"\"\"Ensure ``_compute_coords`` handles list-based index containers.\n\n        Args:\n            None.\n\n        Returns:\n            None.\n\n        Raises:\n            AssertionError: If computed output placement differs from expected placement.\n        \"\"\"\n        out = torch.zeros((1, 1, 12, 12), dtype=torch.float)\n        patch = torch.arange(16, dtype=torch.float).reshape(1, 1, 4, 4)\n        coords = [[slice(0, 1), slice(None), slice(1, 3), slice(2, 4)]]\n\n        _compute_coords(coords=coords, z_scale=[2.0, 2.0], out=out, patch=patch)\n\n        expected = torch.zeros_like(out)\n        expected[0, 0, 2:6, 4:8] = patch[0, 0]\n        torch.testing.assert_close(out, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/inferers/test_sliding_window_splitter.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\nfrom torch.nn.functional import pad\n\nfrom monai.inferers import SlidingWindowSplitter\nfrom tests.test_utils import assert_allclose\n\n# ----------------------------------------------------------------------------\n# Tensor test cases\n# ----------------------------------------------------------------------------\n# random int tensor (0, 255)\nTENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32)\n\n# random int tensor (0, 255) with artifacts at [..., :2, 2:]\nTENSOR_4x4_artifact = TENSOR_4x4.clone()\nTENSOR_4x4_artifact[..., :2, 2:] = 512.0\n\n# no-overlapping 2x2\nTEST_CASE_TENSOR_0 = [\n    TENSOR_4x4,\n    {\"patch_size\": (2, 2), \"overlap\": 0.0},\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n]\n\n# no-overlapping 3x3 with pad\nTEST_CASE_TENSOR_1 = [\n    TENSOR_4x4,\n    {\"patch_size\": (3, 3), \"overlap\": 0.0, \"pad_mode\": \"constant\"},\n    [\n        (TENSOR_4x4[..., :3, :3], (0, 0)),\n        (pad(TENSOR_4x4[..., :3, 3:], (0, 2)), (0, 3)),\n        (pad(TENSOR_4x4[..., 3:, :3], (0, 0, 0, 2)), (3, 0)),\n        (pad(TENSOR_4x4[..., 3:, 3:], (0, 2, 0, 2)), (3, 3)),\n    
],\n]\n\n# overlapping 2x2 with fraction\nTEST_CASE_TENSOR_2 = [\n    TENSOR_4x4,\n    {\"patch_size\": (2, 2), \"overlap\": (0.5, 0.5)},\n    [\n        (TENSOR_4x4[..., 0:2, 0:2], (0, 0)),\n        (TENSOR_4x4[..., 0:2, 1:3], (0, 1)),\n        (TENSOR_4x4[..., 0:2, 2:4], (0, 2)),\n        (TENSOR_4x4[..., 1:3, 0:2], (1, 0)),\n        (TENSOR_4x4[..., 1:3, 1:3], (1, 1)),\n        (TENSOR_4x4[..., 1:3, 2:4], (1, 2)),\n        (TENSOR_4x4[..., 2:4, 0:2], (2, 0)),\n        (TENSOR_4x4[..., 2:4, 1:3], (2, 1)),\n        (TENSOR_4x4[..., 2:4, 2:4], (2, 2)),\n    ],\n]\n\n# overlapping 3x3 with fraction (non-divisible)\nTEST_CASE_TENSOR_3 = [\n    TENSOR_4x4,\n    {\"patch_size\": (3, 3), \"overlap\": 2.0 / 3.0},\n    [\n        (TENSOR_4x4[..., :3, :3], (0, 0)),\n        (TENSOR_4x4[..., :3, 1:], (0, 1)),\n        (TENSOR_4x4[..., 1:, :3], (1, 0)),\n        (TENSOR_4x4[..., 1:, 1:], (1, 1)),\n    ],\n]\n\n# overlapping 2x2 with number of pixels\nTEST_CASE_TENSOR_4 = [\n    TENSOR_4x4,\n    {\"patch_size\": (2, 2), \"overlap\": (1, 1)},\n    [\n        (TENSOR_4x4[..., 0:2, 0:2], (0, 0)),\n        (TENSOR_4x4[..., 0:2, 1:3], (0, 1)),\n        (TENSOR_4x4[..., 0:2, 2:4], (0, 2)),\n        (TENSOR_4x4[..., 1:3, 0:2], (1, 0)),\n        (TENSOR_4x4[..., 1:3, 1:3], (1, 1)),\n        (TENSOR_4x4[..., 1:3, 2:4], (1, 2)),\n        (TENSOR_4x4[..., 2:4, 0:2], (2, 0)),\n        (TENSOR_4x4[..., 2:4, 1:3], (2, 1)),\n        (TENSOR_4x4[..., 2:4, 2:4], (2, 2)),\n    ],\n]\n\n# overlapping 3x3 with number of pixels (non-divisible)\nTEST_CASE_TENSOR_5 = [\n    TENSOR_4x4,\n    {\"patch_size\": (3, 3), \"overlap\": 2},\n    [\n        (TENSOR_4x4[..., :3, :3], (0, 0)),\n        (TENSOR_4x4[..., :3, 1:], (0, 1)),\n        (TENSOR_4x4[..., 1:, :3], (1, 0)),\n        (TENSOR_4x4[..., 1:, 1:], (1, 1)),\n    ],\n]\n# non-overlapping 2x2 with positive offset\nTEST_CASE_TENSOR_6 = [\n    TENSOR_4x4,\n    {\"patch_size\": (2, 2), \"offset\": 1},\n    [\n        (TENSOR_4x4[..., 1:3, 1:3], (1, 
1)),\n        (pad(TENSOR_4x4[..., 1:3, 3:], (0, 1)), (1, 3)),\n        (pad(TENSOR_4x4[..., 3:, 1:3], (0, 0, 0, 1)), (3, 1)),\n        (pad(TENSOR_4x4[..., 3:, 3:], (0, 1, 0, 1)), (3, 3)),\n    ],\n]\n\n# non-overlapping 2x2 with negative offset\nTEST_CASE_TENSOR_7 = [\n    TENSOR_4x4,\n    {\"patch_size\": (2, 2), \"offset\": -1},\n    [\n        (pad(TENSOR_4x4[..., :1, :1], (1, 0, 1, 0)), (-1, -1)),\n        (pad(TENSOR_4x4[..., :1, 1:3], (0, 0, 1, 0)), (-1, 1)),\n        (pad(TENSOR_4x4[..., :1, 3:], (0, 1, 1, 0)), (-1, 3)),\n        (pad(TENSOR_4x4[..., 1:3, :1], (1, 0)), (1, -1)),\n        (TENSOR_4x4[..., 1:3, 1:3], (1, 1)),\n        (pad(TENSOR_4x4[..., 1:3, 3:], (0, 1)), (1, 3)),\n        (pad(TENSOR_4x4[..., 3:, :1], (1, 0, 0, 1)), (3, -1)),\n        (pad(TENSOR_4x4[..., 3:, 1:3], (0, 0, 0, 1)), (3, 1)),\n        (pad(TENSOR_4x4[..., 3:, 3:], (0, 1, 0, 1)), (3, 3)),\n    ],\n]\n\n# non-overlapping 2x2 with positive offset and no padding\nTEST_CASE_TENSOR_8 = [\n    TENSOR_4x4,\n    {\"patch_size\": (2, 2), \"offset\": 1, \"pad_mode\": None},\n    [(TENSOR_4x4[..., 1:3, 1:3], (1, 1))],\n]\n\n\n# ----------------------------------------------------------------------------\n# Filtering function test cases\n# ----------------------------------------------------------------------------\ndef gen_filter(filter_type, value=None):\n    \"\"\" \"Generate patch filtering function for testing\"\"\"\n    if filter_type.lower() == \"high\":\n\n        def my_filter(patch, location):\n            if torch.any(patch > value):\n                return True\n            return False\n\n    elif filter_type.lower() == \"low\":\n\n        def my_filter(patch, location):\n            if torch.any(patch < value):\n                return True\n            return False\n\n    elif filter_type.lower() == \"location\":\n\n        def my_filter(patch, location):\n            if location in value:\n                return True\n            return False\n\n    return 
my_filter\n\n\nTEST_CASE_FILTER_FN_0 = [\n    TENSOR_4x4_artifact,\n    {\"patch_size\": (2, 2), \"filter_fn\": gen_filter(\"low\", 256)},\n    [\n        (TENSOR_4x4_artifact[..., :2, :2], (0, 0)),\n        (TENSOR_4x4_artifact[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4_artifact[..., 2:, 2:], (2, 2)),\n    ],\n]\n\nTEST_CASE_FILTER_FN_1 = [\n    TENSOR_4x4_artifact,\n    {\"patch_size\": (2, 2), \"filter_fn\": gen_filter(\"high\", 256)},\n    [(TENSOR_4x4_artifact[..., :2, 2:], (0, 2))],\n]\n\nTEST_CASE_FILTER_FN_2 = [\n    TENSOR_4x4_artifact,\n    {\"patch_size\": (2, 2), \"filter_fn\": gen_filter(\"location\", [(2, 2), (2, 0)])},\n    [(TENSOR_4x4_artifact[..., 2:, :2], (2, 0)), (TENSOR_4x4_artifact[..., 2:, 2:], (2, 2))],\n]\n\n\n# ----------------------------------------------------------------------------\n# Error test cases\n# ----------------------------------------------------------------------------\ndef extra_parameter_filter(patch, location, extra):\n    return\n\n\ndef missing_parameter_filter(patch):\n    return\n\n\n# invalid overlap: float 1.0\nTEST_CASE_ERROR_0 = [TENSOR_4x4, {\"patch_size\": (2, 2), \"overlap\": 1.0}, ValueError]\n# invalid overlap: negative float\nTEST_CASE_ERROR_1 = [TENSOR_4x4, {\"patch_size\": (2, 2), \"overlap\": -0.1}, ValueError]\n# invalid overlap: negative integer\nTEST_CASE_ERROR_2 = [TENSOR_4x4, {\"patch_size\": (2, 2), \"overlap\": -1}, ValueError]\n# invalid overlap: integer larger than patch size\nTEST_CASE_ERROR_3 = [TENSOR_4x4, {\"patch_size\": (2, 2), \"overlap\": 3}, ValueError]\n\n# invalid offset: positive and larger than image size\nTEST_CASE_ERROR_4 = [TENSOR_4x4, {\"patch_size\": (2, 2), \"offset\": 4}, ValueError]\n# invalid offset: negative and larger than patch size (in magnitude)\nTEST_CASE_ERROR_5 = [TENSOR_4x4, {\"patch_size\": (2, 2), \"offset\": -3, \"pad_mode\": \"constant\"}, ValueError]\n# invalid offset: negative and no padding\nTEST_CASE_ERROR_6 = [TENSOR_4x4, {\"patch_size\": (2, 2), 
\"offset\": -1, \"pad_mode\": None}, ValueError]\n\n# invalid filter function: with more than two positional parameters\nTEST_CASE_ERROR_7 = [TENSOR_4x4, {\"patch_size\": (2, 2), \"filter_fn\": extra_parameter_filter}, ValueError]\n# invalid filter function: with less than two positional parameters\nTEST_CASE_ERROR_8 = [TENSOR_4x4, {\"patch_size\": (2, 2), \"filter_fn\": missing_parameter_filter}, ValueError]\n# invalid filter function: non-callable\nTEST_CASE_ERROR_9 = [TENSOR_4x4, {\"patch_size\": (2, 2), \"filter_fn\": 1}, ValueError]\n\n\nclass SlidingWindowSplitterTests(unittest.TestCase):\n    @parameterized.expand(\n        [\n            TEST_CASE_TENSOR_0,\n            TEST_CASE_TENSOR_1,\n            TEST_CASE_TENSOR_2,\n            TEST_CASE_TENSOR_3,\n            TEST_CASE_TENSOR_4,\n            TEST_CASE_TENSOR_5,\n            TEST_CASE_TENSOR_6,\n            TEST_CASE_TENSOR_7,\n            TEST_CASE_TENSOR_8,\n            TEST_CASE_FILTER_FN_0,\n            TEST_CASE_FILTER_FN_1,\n            TEST_CASE_FILTER_FN_2,\n        ]\n    )\n    def test_split_patches_tensor(self, image, arguments, expected):\n        patches = SlidingWindowSplitter(**arguments)(image)\n        patches = list(patches)\n        self.assertEqual(len(patches), len(expected))\n        for p, e in zip(patches, expected):\n            assert_allclose(p[0], e[0])\n            self.assertTupleEqual(p[1], e[1])\n\n    @parameterized.expand(\n        [\n            TEST_CASE_ERROR_0,\n            TEST_CASE_ERROR_1,\n            TEST_CASE_ERROR_2,\n            TEST_CASE_ERROR_3,\n            TEST_CASE_ERROR_4,\n            TEST_CASE_ERROR_5,\n            TEST_CASE_ERROR_6,\n            TEST_CASE_ERROR_7,\n            TEST_CASE_ERROR_8,\n            TEST_CASE_ERROR_9,\n        ]\n    )\n    def test_split_patches_errors(self, image, arguments, expected_error):\n        with self.assertRaises(expected_error):\n            patches = SlidingWindowSplitter(**arguments)(image)\n            
patches = list(patches)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/inferers/test_wsi_sliding_window_splitter.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import CuCIMWSIReader, ImageReader, OpenSlideWSIReader, WSIReader\nfrom monai.inferers import WSISlidingWindowSplitter\nfrom tests.test_utils import download_url_or_skip_test, optional_import, testing_data_config\n\ncucim, has_cucim = optional_import(\"cucim\")\nhas_cucim = has_cucim and hasattr(cucim, \"CuImage\")\n_, has_osl = optional_import(\"openslide\")\n\nWSI_READER_STR = None\nWSI_READER_CLASS: type[CuCIMWSIReader] | type[OpenSlideWSIReader] | None = None\nif has_cucim:\n    WSI_READER_STR = \"cuCIM\"\n    WSI_READER_CLASS = CuCIMWSIReader\nelif has_osl:\n    WSI_READER_STR = \"OpenSlide\"\n    WSI_READER_CLASS = OpenSlideWSIReader\n\nWSI_GENERIC_TIFF_KEY = \"wsi_generic_tiff\"\nTESTS_PATH = Path(__file__).parents[1]\nWSI_GENERIC_TIFF_PATH = os.path.join(TESTS_PATH, \"testing_data\", f\"temp_{WSI_GENERIC_TIFF_KEY}.tiff\")\n\nHEIGHT = 32914\nWIDTH = 46000\n\n# ----------------------------------------------------------------------------\n# WSI test cases\n# ----------------------------------------------------------------------------\n\nTEST_CASE_WSI_0_BASE = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (1000, 1000), \"reader\": WSI_READER_STR},\n    
{\"locations\": [(0, 0), (0, 1000), (0, 2000), (0, 3000)]},\n]\n\nTEST_CASE_WSI_1_BASE = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (1000, 1000), \"level\": 1, \"reader\": WSI_READER_STR},\n    {\"locations\": [(0, 0), (0, 1000), (0, 2000), (0, 3000)]},\n]\n\n# Check readers\nif WSI_READER_STR is not None:\n    TEST_CASE_WSI_2_READER = [\n        WSI_GENERIC_TIFF_PATH,\n        {\"patch_size\": (1000, 1000), \"reader\": WSIReader(backend=WSI_READER_STR)},\n        {\"locations\": [(0, 0), (0, 1000), (0, 2000), (0, 3000)]},\n    ]\nelse:\n    TEST_CASE_WSI_2_READER = []\nTEST_CASE_WSI_3_READER = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (1000, 1000), \"overlap\": 0.0, \"reader\": WSIReader, \"backend\": WSI_READER_STR},\n    {\"locations\": [(0, 0), (0, 1000), (0, 2000), (0, 3000)]},\n]\nTEST_CASE_WSI_4_READER = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (1000, 1000), \"overlap\": 0.0, \"reader\": WSI_READER_CLASS},\n    {\"locations\": [(0, 0), (0, 1000), (0, 2000), (0, 3000)]},\n]\nTEST_CASE_WSI_5_READER = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (1000, 1000), \"overlap\": 0.0, \"level\": 1, \"reader\": WSI_READER_STR},\n    {\"locations\": [(0, 0), (0, 1000), (0, 2000), (0, 3000)]},\n]\n\n# Check overlaps\nTEST_CASE_WSI_6_OVERLAP = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (1000, 1000), \"overlap\": 0.0, \"reader\": WSI_READER_STR},\n    {\"locations\": [(0, 0), (0, 1000), (0, 2000), (0, 3000)]},\n]\nTEST_CASE_WSI_7_OVERLAP = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (1000, 1000), \"overlap\": 0.5, \"reader\": WSI_READER_STR},\n    {\"locations\": [(0, 0), (0, 500), (0, 1000), (0, 1500)]},\n]\nTEST_CASE_WSI_8_OVERLAP = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (1000, 1000), \"overlap\": 0.999, \"reader\": WSI_READER_STR},\n    {\"locations\": [(0, 0), (0, 1), (0, 2), (0, 3)]},\n]\n\n\n# Filtering functions test cases\ndef gen_location_filter(locations):\n    def my_filter(patch, loc):\n        
if loc in locations:\n            return False\n        return True\n\n    return my_filter\n\n\nTEST_CASE_WSI_9_FILTER = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (1000, 1000), \"reader\": WSI_READER_STR, \"filter_fn\": gen_location_filter([(0, 0), (0, 2000)])},\n    {\"locations\": [(0, 1000), (0, 3000)]},\n]\n\n\n# ----------------------------------------------------------------------------\n# Error test cases\n# ----------------------------------------------------------------------------\ndef extra_parameter_filter(patch, location, extra):\n    return\n\n\ndef missing_parameter_filter(patch):\n    return\n\n\n# invalid overlap: float 1.0\nTEST_CASE_ERROR_0 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (2, 2), \"overlap\": 1.0, \"reader\": WSI_READER_STR},\n    ValueError,\n]\n# invalid overlap: negative float\nTEST_CASE_ERROR_1 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (2, 2), \"overlap\": -0.1, \"reader\": WSI_READER_STR},\n    ValueError,\n]\n# invalid overlap: negative integer\nTEST_CASE_ERROR_2 = [WSI_GENERIC_TIFF_PATH, {\"patch_size\": (2, 2), \"overlap\": -1, \"reader\": WSI_READER_STR}, ValueError]\n# invalid overlap: integer larger than patch size\nTEST_CASE_ERROR_3 = [WSI_GENERIC_TIFF_PATH, {\"patch_size\": (2, 2), \"overlap\": 3, \"reader\": WSI_READER_STR}, ValueError]\n\n# invalid offset: positive and larger than image size\nTEST_CASE_ERROR_4 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (2, 2), \"offset\": WIDTH, \"reader\": WSI_READER_STR},\n    ValueError,\n]\n# invalid offset: negative and larger than patch size (in magnitude)\nTEST_CASE_ERROR_5 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (2, 2), \"offset\": -3, \"pad_mode\": \"constant\", \"reader\": WSI_READER_STR},\n    ValueError,\n]\n# invalid offset: negative and no padding\nTEST_CASE_ERROR_6 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (2, 2), \"pad_mode\": None, \"offset\": -1, \"reader\": WSI_READER_STR},\n    ValueError,\n]\n\n# 
invalid filter function: with more than two positional parameters\nTEST_CASE_ERROR_7 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (2, 2), \"filter_fn\": extra_parameter_filter, \"reader\": WSI_READER_STR},\n    ValueError,\n]\n# invalid filter function: with less than two positional parameters\nTEST_CASE_ERROR_8 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (2, 2), \"filter_fn\": missing_parameter_filter, \"reader\": WSI_READER_STR},\n    ValueError,\n]\n# invalid filter function: non-callable\nTEST_CASE_ERROR_9 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"patch_size\": (2, 2), \"filter_fn\": 1, \"reader\": WSI_READER_STR},\n    ValueError,\n]\n\n# invalid reader\nTEST_CASE_ERROR_10 = [WSI_GENERIC_TIFF_PATH, {\"patch_size\": (2, 2), \"reader\": ImageReader}, ValueError]\n\n\n@skipUnless(WSI_READER_STR, \"Requires cucim or openslide!\")\ndef setUpModule():\n    download_url_or_skip_test(\n        testing_data_config(\"images\", WSI_GENERIC_TIFF_KEY, \"url\"),\n        WSI_GENERIC_TIFF_PATH,\n        hash_type=testing_data_config(\"images\", WSI_GENERIC_TIFF_KEY, \"hash_type\"),\n        hash_val=testing_data_config(\"images\", WSI_GENERIC_TIFF_KEY, \"hash_val\"),\n    )\n\n\nclass WSISlidingWindowSplitterTests(unittest.TestCase):\n    @parameterized.expand(\n        [\n            TEST_CASE_WSI_0_BASE,\n            TEST_CASE_WSI_1_BASE,\n            TEST_CASE_WSI_2_READER,\n            TEST_CASE_WSI_3_READER,\n            TEST_CASE_WSI_4_READER,\n            TEST_CASE_WSI_5_READER,\n            TEST_CASE_WSI_6_OVERLAP,\n            TEST_CASE_WSI_7_OVERLAP,\n            TEST_CASE_WSI_8_OVERLAP,\n            TEST_CASE_WSI_9_FILTER,\n        ]\n    )\n    def test_split_patches_wsi(self, filepath, arguments, expected):\n        patches = WSISlidingWindowSplitter(**arguments)(filepath)\n        for sample, expected_loc in zip(patches, expected[\"locations\"]):\n            patch = sample[0]\n            loc = sample[1]\n            
self.assertTrue(isinstance(patch, torch.Tensor))\n            self.assertTupleEqual(patch.shape[2:], arguments[\"patch_size\"])\n            self.assertTrue(isinstance(loc, tuple))\n            self.assertTupleEqual(loc, expected_loc)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_ERROR_0,\n            TEST_CASE_ERROR_1,\n            TEST_CASE_ERROR_2,\n            TEST_CASE_ERROR_3,\n            TEST_CASE_ERROR_4,\n            TEST_CASE_ERROR_5,\n            TEST_CASE_ERROR_6,\n            TEST_CASE_ERROR_7,\n            TEST_CASE_ERROR_8,\n            TEST_CASE_ERROR_9,\n            TEST_CASE_ERROR_10,\n        ]\n    )\n    def test_split_patches_errors(self, image, arguments, expected_error):\n        with self.assertRaises(expected_error):\n            patches = WSISlidingWindowSplitter(**arguments)(image)\n            patches = list(patches)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/inferers/test_zarr_avg_merger.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom tempfile import TemporaryDirectory\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom torch.nn.functional import pad\n\nfrom monai.inferers import ZarrAvgMerger\nfrom monai.utils import get_package_version, optional_import, version_geq\nfrom tests.test_utils import assert_allclose\n\nzarr, has_zarr = optional_import(\"zarr\")\nnumcodecs, has_numcodecs = optional_import(\"numcodecs\")\n\nTENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32)\nTENSOR_4x4_WITH_NAN = TENSOR_4x4.clone()\nTENSOR_4x4_WITH_NAN[..., 2:, 2:] = float(\"nan\")\n\n# no-overlapping 2x2\nTEST_CASE_0_DEFAULT_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# overlapping 2x2\nTEST_CASE_1_DEFAULT_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape),\n    [\n        (TENSOR_4x4[..., 0:2, 0:2], (0, 0)),\n        (TENSOR_4x4[..., 0:2, 1:3], (0, 1)),\n        (TENSOR_4x4[..., 0:2, 2:4], (0, 2)),\n        (TENSOR_4x4[..., 1:3, 0:2], (1, 0)),\n        (TENSOR_4x4[..., 1:3, 1:3], (1, 1)),\n        (TENSOR_4x4[..., 1:3, 2:4], (1, 2)),\n        (TENSOR_4x4[..., 2:4, 0:2], 
(2, 0)),\n        (TENSOR_4x4[..., 2:4, 1:3], (2, 1)),\n        (TENSOR_4x4[..., 2:4, 2:4], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# overlapping 3x3 (non-divisible)\nTEST_CASE_2_DEFAULT_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape),\n    [\n        (TENSOR_4x4[..., :3, :3], (0, 0)),\n        (TENSOR_4x4[..., :3, 1:], (0, 1)),\n        (TENSOR_4x4[..., 1:, :3], (1, 0)),\n        (TENSOR_4x4[..., 1:, 1:], (1, 1)),\n    ],\n    TENSOR_4x4,\n]\n\n#  overlapping 2x2 with NaN values\nTEST_CASE_3_DEFAULT_DTYPE = [\n    dict(merged_shape=TENSOR_4x4_WITH_NAN.shape),\n    [\n        (TENSOR_4x4_WITH_NAN[..., 0:2, 0:2], (0, 0)),\n        (TENSOR_4x4_WITH_NAN[..., 0:2, 1:3], (0, 1)),\n        (TENSOR_4x4_WITH_NAN[..., 0:2, 2:4], (0, 2)),\n        (TENSOR_4x4_WITH_NAN[..., 1:3, 0:2], (1, 0)),\n        (TENSOR_4x4_WITH_NAN[..., 1:3, 1:3], (1, 1)),\n        (TENSOR_4x4_WITH_NAN[..., 1:3, 2:4], (1, 2)),\n        (TENSOR_4x4_WITH_NAN[..., 2:4, 0:2], (2, 0)),\n        (TENSOR_4x4_WITH_NAN[..., 2:4, 1:3], (2, 1)),\n        (TENSOR_4x4_WITH_NAN[..., 2:4, 2:4], (2, 2)),\n    ],\n    TENSOR_4x4_WITH_NAN,\n]\n\n# non-overlapping 2x2 with missing patch\nTEST_CASE_4_DEFAULT_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape),\n    [(TENSOR_4x4[..., :2, :2], (0, 0)), (TENSOR_4x4[..., :2, 2:], (0, 2)), (TENSOR_4x4[..., 2:, :2], (2, 0))],\n    TENSOR_4x4_WITH_NAN,\n]\n\n# with value_dtype set to half precision\nTEST_CASE_5_VALUE_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape, value_dtype=np.float16),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n# with count_dtype set to int32\nTEST_CASE_6_COUNT_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape, count_dtype=np.int32),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        
(TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n# with both value_dtype, count_dtype set to double precision\nTEST_CASE_7_COUNT_VALUE_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape, value_dtype=np.float64, count_dtype=np.float64),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n# with dtype set to double precision\nTEST_CASE_8_DTYPE = [\n    dict(merged_shape=TENSOR_4x4.shape, dtype=np.float64),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# shape larger than what is covered by patches\nTEST_CASE_9_LARGER_SHAPE = [\n    dict(merged_shape=(2, 3, 4, 6)),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    pad(TENSOR_4x4, (0, 2), value=float(\"nan\")),\n]\n\n# explicit directory store, defer creating the store until test time by using the placeholder value \"directory_store\"\nTEST_CASE_10_DIRECTORY_STORE = [\n    dict(merged_shape=TENSOR_4x4.shape, store=\"directory_store\"),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# memory store for all arrays\nTEST_CASE_11_MEMORY_STORE = [\n    dict(\n        merged_shape=TENSOR_4x4.shape,\n        store=zarr.storage.MemoryStore(),\n        value_store=zarr.storage.MemoryStore(),\n        count_store=zarr.storage.MemoryStore(),\n    ),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        
(TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# explicit chunk size\nTEST_CASE_12_CHUNKS = [\n    dict(merged_shape=TENSOR_4x4.shape, chunks=(1, 1, 2, 2)),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# Define zarr v3 codec configurations with proper bytes codec\nZARR_V3_LZ4_CODECS = [{\"name\": \"bytes\", \"configuration\": {}}, {\"name\": \"blosc\", \"configuration\": {\"cname\": \"lz4\"}}]\n\nZARR_V3_PICKLE_CODECS = [{\"name\": \"bytes\", \"configuration\": {}}, {\"name\": \"blosc\", \"configuration\": {\"cname\": \"zstd\"}}]\n\nZARR_V3_LZMA_CODECS = [{\"name\": \"bytes\", \"configuration\": {}}, {\"name\": \"blosc\", \"configuration\": {\"cname\": \"zlib\"}}]\n\n# test for LZ4 compressor (zarr v2) or codecs (zarr v3)\nTEST_CASE_13_COMPRESSOR_LZ4 = [\n    (\n        dict(merged_shape=TENSOR_4x4.shape, compressor=\"LZ4\")\n        if not version_geq(get_package_version(\"zarr\"), \"3.0.0\")\n        else dict(merged_shape=TENSOR_4x4.shape, codecs=ZARR_V3_LZ4_CODECS)\n    ),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# test for pickle compressor (zarr v2) or codecs (zarr v3)\nTEST_CASE_14_COMPRESSOR_PICKLE = [\n    (\n        dict(merged_shape=TENSOR_4x4.shape, compressor=\"Pickle\")\n        if not version_geq(get_package_version(\"zarr\"), \"3.0.0\")\n        else dict(merged_shape=TENSOR_4x4.shape, codecs=ZARR_V3_PICKLE_CODECS)\n    ),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# test for LZMA 
compressor (zarr v2) or codecs (zarr v3)\nTEST_CASE_15_COMPRESSOR_LZMA = [\n    (\n        dict(merged_shape=TENSOR_4x4.shape, compressor=\"LZMA\")\n        if not version_geq(get_package_version(\"zarr\"), \"3.0.0\")\n        else dict(merged_shape=TENSOR_4x4.shape, codecs=ZARR_V3_LZMA_CODECS)\n    ),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# test with thread locking\nTEST_CASE_16_WITH_LOCK = [\n    dict(merged_shape=TENSOR_4x4.shape, thread_locking=True),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# test without thread locking\nTEST_CASE_17_WITHOUT_LOCK = [\n    dict(merged_shape=TENSOR_4x4.shape, thread_locking=False),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# test with codecs for zarr v3\nTEST_CASE_18_CODECS = [\n    dict(merged_shape=TENSOR_4x4.shape, codecs=ZARR_V3_LZ4_CODECS),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\n# test with value_codecs for zarr v3\nTEST_CASE_19_VALUE_CODECS = [\n    dict(\n        merged_shape=TENSOR_4x4.shape,\n        value_codecs=[{\"name\": \"bytes\", \"configuration\": {}}, {\"name\": \"blosc\", \"configuration\": {\"cname\": \"zstd\"}}],\n    ),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    
TENSOR_4x4,\n]\n\n# test with count_codecs for zarr v3\nTEST_CASE_20_COUNT_CODECS = [\n    dict(\n        merged_shape=TENSOR_4x4.shape,\n        count_codecs=[{\"name\": \"bytes\", \"configuration\": {}}, {\"name\": \"blosc\", \"configuration\": {\"cname\": \"zlib\"}}],\n    ),\n    [\n        (TENSOR_4x4[..., :2, :2], (0, 0)),\n        (TENSOR_4x4[..., :2, 2:], (0, 2)),\n        (TENSOR_4x4[..., 2:, :2], (2, 0)),\n        (TENSOR_4x4[..., 2:, 2:], (2, 2)),\n    ],\n    TENSOR_4x4,\n]\n\nALL_TESTS = [\n    TEST_CASE_0_DEFAULT_DTYPE,\n    TEST_CASE_1_DEFAULT_DTYPE,\n    TEST_CASE_2_DEFAULT_DTYPE,\n    TEST_CASE_3_DEFAULT_DTYPE,\n    TEST_CASE_4_DEFAULT_DTYPE,\n    TEST_CASE_5_VALUE_DTYPE,\n    TEST_CASE_6_COUNT_DTYPE,\n    TEST_CASE_7_COUNT_VALUE_DTYPE,\n    TEST_CASE_8_DTYPE,\n    TEST_CASE_9_LARGER_SHAPE,\n    TEST_CASE_10_DIRECTORY_STORE,\n    TEST_CASE_11_MEMORY_STORE,\n    TEST_CASE_12_CHUNKS,\n    TEST_CASE_16_WITH_LOCK,\n    TEST_CASE_17_WITHOUT_LOCK,\n    # Add compression/codec tests regardless of zarr version - they're now version-aware\n    TEST_CASE_13_COMPRESSOR_LZ4,\n    TEST_CASE_14_COMPRESSOR_PICKLE,\n    TEST_CASE_15_COMPRESSOR_LZMA,\n]\n\n# Add zarr v3 specific codec tests only when using Zarr version 3.0 or later\nif version_geq(get_package_version(\"zarr\"), \"3.0.0\"):\n    ALL_TESTS += [TEST_CASE_18_CODECS, TEST_CASE_19_VALUE_CODECS, TEST_CASE_20_COUNT_CODECS]\n\n\n@unittest.skipUnless(has_zarr and has_numcodecs, \"Requires zarr (and numcodecs) packages.)\")\nclass ZarrAvgMergerTests(unittest.TestCase):\n\n    def setUp(self):\n        self.orig_settings = np.seterr(divide=\"ignore\", invalid=\"ignore\")\n        self.temp_dir = TemporaryDirectory()\n        self.merged_name = os.path.join(self.temp_dir.name, \"merged.zarr\")\n\n    def tearDown(self):\n        np.seterr(**self.orig_settings)\n        self.temp_dir.cleanup()\n\n    def _get_directory_store(self, base_dir):\n        zarr_path = os.path.join(base_dir, \"test.zarr\")\n\n        
if version_geq(get_package_version(\"zarr\"), \"3.0.0\"):\n            directory_store = zarr.storage.LocalStore(zarr_path)\n        else:\n            directory_store = zarr.storage.DirectoryStore(zarr_path)\n\n        return directory_store\n\n    @parameterized.expand(ALL_TESTS)\n    def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected):\n        is_zarr_v3 = version_geq(get_package_version(\"zarr\"), \"3.0.0\")\n        codec_reg = numcodecs.registry.codec_registry\n        arguments = dict(arguments)\n\n        # Handle compressor/codecs based on zarr version\n        if \"compressor\" in arguments and is_zarr_v3:\n            # For zarr v3, convert compressor to codecs\n            if arguments[\"compressor\"] != \"default\" and arguments[\"compressor\"] is not None:\n                compressor_name = arguments[\"compressor\"].lower()\n                if compressor_name == \"lz4\":\n                    arguments[\"codecs\"] = ZARR_V3_LZ4_CODECS\n                elif compressor_name == \"pickle\":\n                    arguments[\"codecs\"] = ZARR_V3_PICKLE_CODECS\n                elif compressor_name == \"lzma\":\n                    arguments[\"codecs\"] = ZARR_V3_LZMA_CODECS\n                # Remove compressor as it's not supported in zarr v3\n                del arguments[\"compressor\"]\n        elif \"compressor\" in arguments and not is_zarr_v3:\n            # For zarr v2, use the compressor registry\n            if arguments[\"compressor\"] != \"default\" and arguments[\"compressor\"] is not None:\n                arguments[\"compressor\"] = codec_reg[arguments[\"compressor\"].lower()]()\n\n        # Same for value_compressor\n        if \"value_compressor\" in arguments and is_zarr_v3:\n            if arguments[\"value_compressor\"] != \"default\" and arguments[\"value_compressor\"] is not None:\n                compressor_name = arguments[\"value_compressor\"].lower()\n                if compressor_name == \"lz4\":\n            
        arguments[\"value_codecs\"] = ZARR_V3_LZ4_CODECS\n                elif compressor_name == \"pickle\":\n                    arguments[\"value_codecs\"] = ZARR_V3_PICKLE_CODECS\n                elif compressor_name == \"lzma\":\n                    arguments[\"value_codecs\"] = ZARR_V3_LZMA_CODECS\n                del arguments[\"value_compressor\"]\n        elif \"value_compressor\" in arguments and not is_zarr_v3:\n            if arguments[\"value_compressor\"] != \"default\" and arguments[\"value_compressor\"] is not None:\n                arguments[\"value_compressor\"] = codec_reg[arguments[\"value_compressor\"].lower()]()\n\n        # Same for count_compressor\n        if \"count_compressor\" in arguments and is_zarr_v3:\n            if arguments[\"count_compressor\"] != \"default\" and arguments[\"count_compressor\"] is not None:\n                compressor_name = arguments[\"count_compressor\"].lower()\n                if compressor_name == \"lz4\":\n                    arguments[\"count_codecs\"] = ZARR_V3_LZ4_CODECS\n                elif compressor_name == \"pickle\":\n                    arguments[\"count_codecs\"] = ZARR_V3_PICKLE_CODECS\n                elif compressor_name == \"lzma\":\n                    arguments[\"count_codecs\"] = ZARR_V3_LZMA_CODECS\n                del arguments[\"count_compressor\"]\n        elif \"count_compressor\" in arguments and not is_zarr_v3:\n            if arguments[\"count_compressor\"] != \"default\" and arguments[\"count_compressor\"] is not None:\n                arguments[\"count_compressor\"] = codec_reg[arguments[\"count_compressor\"].lower()]()\n\n        # ensure the merged directory is in the temporary directory and not the current directory\n\n        if \"store\" not in arguments:\n            arguments[\"store\"] = self.merged_name\n        elif arguments[\"store\"] == \"directory_store\":\n            arguments[\"store\"] = self._get_directory_store(self.temp_dir.name)  # get store object now\n\n   
     merger = ZarrAvgMerger(**arguments)\n\n        for pl in patch_locations:\n            merger.aggregate(pl[0], pl[1])\n        output = merger.finalize()\n        if \"value_dtype\" in arguments:\n            self.assertTrue(merger.get_values().dtype, arguments[\"value_dtype\"])\n        if \"count_dtype\" in arguments:\n            self.assertTrue(merger.get_counts().dtype, arguments[\"count_dtype\"])\n        # check for multiple call of finalize\n        self.assertIs(output, merger.finalize())\n        # check if the result is matching the expectation\n        assert_allclose(output[:], expected.numpy())\n\n    def test_zarr_avg_merger_finalized_error(self):\n        with self.assertRaises(ValueError):\n            merger = ZarrAvgMerger(merged_shape=(1, 3, 2, 3), store=self.merged_name)\n            merger.finalize()\n            merger.aggregate(torch.zeros(1, 3, 2, 2), (3, 3))\n\n    def test_zarr_avg_merge_none_merged_shape_error(self):\n        with self.assertRaises(ValueError):\n            ZarrAvgMerger(merged_shape=None, store=self.merged_name)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/integration/test_auto3dseg_ensemble.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch\n\nfrom monai.apps.auto3dseg import (\n    AlgoEnsembleBestByFold,\n    AlgoEnsembleBestN,\n    AlgoEnsembleBuilder,\n    BundleGen,\n    DataAnalyzer,\n    EnsembleRunner,\n)\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.data import create_test_image_3d\nfrom monai.transforms import SaveImage\nfrom monai.utils import check_parent_dir, optional_import, set_determinism\nfrom monai.utils.enums import AlgoKeys\nfrom tests.test_utils import get_testing_algo_template_path, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick\n\n_, has_tb = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n\nnum_images_perfold = max(torch.cuda.device_count(), 4)\nnum_images_per_batch = 2\n\nfake_datalist: dict[str, list[dict]] = {\n    \"testing\": [{\"image\": f\"imagesTs/ts_image_{idx:03d}.nii.gz\"} for idx in range(num_images_perfold)],\n    \"training\": [\n        {\n            \"fold\": f,\n            \"image\": f\"imagesTr/tr_image_{(f * num_images_perfold + idx):03d}.nii.gz\",\n            \"label\": f\"labelsTr/tr_label_{(f * num_images_perfold + idx):03d}.nii.gz\",\n        }\n        for f in range(num_images_per_batch + 1)\n        for idx in range(num_images_perfold)\n    ],\n}\n\ntrain_param = (\n 
   {\n        \"num_images_per_batch\": num_images_per_batch,\n        \"num_epochs\": 2,\n        \"num_epochs_per_validation\": 1,\n        \"num_warmup_epochs\": 1,\n        \"use_pretrain\": False,\n        \"pretrained_path\": \"\",\n        \"determ\": True,\n    }\n    if torch.cuda.is_available()\n    else {}\n)\n\npred_param = {\n    \"files_slices\": slice(0, 1),\n    \"mode\": \"mean\",\n    \"sigmoid\": True,\n    \"algo_spec_params\": {\"segresnet\": {\"network#init_filters\": 8}, \"swinunetr\": {\"network#feature_size\": 12}},\n}\n\n\ndef create_sim_data(dataroot, sim_datalist, sim_dim, **kwargs):\n    \"\"\"\n    Create simulated data using create_test_image_3d.\n\n    Args:\n        dataroot: data directory path that hosts the \"nii.gz\" image files.\n        sim_datalist: a list of data to create.\n        sim_dim: the image sizes, e.g. a tuple of (64, 64, 64).\n    \"\"\"\n    if not os.path.isdir(dataroot):\n        os.makedirs(dataroot)\n\n    # Generate a fake dataset\n    for d in sim_datalist[\"testing\"] + sim_datalist[\"training\"]:\n        im, seg = create_test_image_3d(sim_dim[0], sim_dim[1], sim_dim[2], **kwargs)\n        nib_image = nib.Nifti1Image(im, affine=np.eye(4))\n        image_fpath = os.path.join(dataroot, d[\"image\"])\n        check_parent_dir(image_fpath)\n        nib.save(nib_image, image_fpath)\n\n        if \"label\" in d:\n            nib_image = nib.Nifti1Image(seg, affine=np.eye(4))\n            label_fpath = os.path.join(dataroot, d[\"label\"])\n            check_parent_dir(label_fpath)\n            nib.save(nib_image, label_fpath)\n\n\n@skip_if_quick\n@skip_if_no_cuda\n@unittest.skipIf(not has_tb, \"no tensorboard summary writer\")\nclass TestEnsembleBuilder(unittest.TestCase):\n    def setUp(self) -> None:\n        set_determinism(0)\n        self.test_dir = tempfile.TemporaryDirectory()\n        test_path = self.test_dir.name\n\n        dataroot = os.path.join(test_path, \"dataroot\")\n        work_dir = 
os.path.join(test_path, \"workdir\")\n\n        da_output_yaml = os.path.join(work_dir, \"datastats.yaml\")\n        data_src_cfg = os.path.join(work_dir, \"data_src_cfg.yaml\")\n\n        if not os.path.isdir(work_dir):\n            os.makedirs(work_dir)\n\n        create_sim_data(dataroot, fake_datalist, (24, 24, 24), rad_max=10, rad_min=1, num_seg_classes=1)\n\n        # write to a json file\n        fake_json_datalist = os.path.join(dataroot, \"fake_input.json\")\n        ConfigParser.export_config_file(fake_datalist, fake_json_datalist)\n\n        da = DataAnalyzer(fake_json_datalist, dataroot, output_path=da_output_yaml)\n        da.get_all_case_stats()\n\n        data_src = {\n            \"name\": \"fake_data\",\n            \"task\": \"segmentation\",\n            \"modality\": \"MRI\",\n            \"datalist\": fake_json_datalist,\n            \"dataroot\": dataroot,\n            \"multigpu\": False,\n            \"class_names\": [\"label_class\"],\n        }\n\n        ConfigParser.export_config_file(data_src, data_src_cfg)\n\n        self.da_output_yaml = da_output_yaml\n        self.work_dir = work_dir\n        self.data_src_cfg_name = data_src_cfg\n\n    def test_ensemble(self) -> None:\n        with skip_if_downloading_fails():\n            bundle_generator = BundleGen(\n                algo_path=self.work_dir,\n                data_stats_filename=self.da_output_yaml,\n                data_src_cfg_name=self.data_src_cfg_name,\n                templates_path_or_url=get_testing_algo_template_path(),\n            )\n        bundle_generator.generate(self.work_dir, num_fold=1)\n        history = bundle_generator.get_history()\n\n        for algo_dict in history:\n            name = algo_dict[AlgoKeys.ID]\n            algo = algo_dict[AlgoKeys.ALGO]\n            _train_param = train_param.copy()\n            if name.startswith(\"segresnet\"):\n                _train_param[\"network#init_filters\"] = 8\n                
_train_param[\"pretrained_ckpt_name\"] = \"\"\n            elif name.startswith(\"swinunetr\"):\n                _train_param[\"network#feature_size\"] = 12\n            algo.train(_train_param)\n\n        builder = AlgoEnsembleBuilder(history, data_src_cfg_name=self.data_src_cfg_name)\n        builder.set_ensemble_method(AlgoEnsembleBestN(n_best=1))\n        ensemble = builder.get_ensemble()\n        preds = ensemble(pred_param)\n        self.assertTupleEqual(preds[0].shape, (2, 24, 24, 24))\n\n        builder.set_ensemble_method(AlgoEnsembleBestByFold(1))\n        ensemble = builder.get_ensemble()\n        for algo in ensemble.get_algo_ensemble():\n            print(algo[AlgoKeys.ID])\n\n    def test_ensemble_runner(self) -> None:\n        runner = EnsembleRunner(data_src_cfg_name=self.data_src_cfg_name, mgpu=False)\n        runner.set_num_fold(3)\n        self.assertTrue(runner.num_fold == 3)\n        runner.set_ensemble_method(ensemble_method_name=\"AlgoEnsembleBestByFold\")\n        self.assertIsInstance(runner.ensemble_method, AlgoEnsembleBestByFold)\n        self.assertTrue(runner.ensemble_method.n_fold == 3)  # type: ignore\n\n        runner.set_ensemble_method(ensemble_method_name=\"AlgoEnsembleBestN\", n_best=3)\n        self.assertIsInstance(runner.ensemble_method, AlgoEnsembleBestN)\n        self.assertTrue(runner.ensemble_method.n_best == 3)\n\n        save_output = os.path.join(self.test_dir.name, \"workdir\")\n        save_image = runner._pop_kwargs_to_get_image_save_transform(\n            output_dir=save_output, output_postfix=\"test_ensemble\", output_dtype=float, resample=True\n        )\n        self.assertIsInstance(ConfigParser(save_image).get_parsed_content(), SaveImage)\n\n    def tearDown(self) -> None:\n        set_determinism(None)\n        self.test_dir.cleanup()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_auto3dseg_hpo.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom functools import partial\n\nimport nibabel as nib\nimport numpy as np\nimport torch\n\nfrom monai.apps.auto3dseg import BundleGen, DataAnalyzer, NNIGen, OptunaGen, import_bundle_algo_history\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.data import create_test_image_3d\nfrom monai.utils import optional_import\nfrom monai.utils.enums import AlgoKeys\nfrom tests.test_utils import get_testing_algo_template_path, skip_if_downloading_fails, skip_if_no_cuda\n\n_, has_tb = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\noptuna, has_optuna = optional_import(\"optuna\")\n\noverride_param = (\n    {\n        \"num_images_per_batch\": 2,\n        \"num_epochs\": 2,\n        \"num_epochs_per_validation\": 1,\n        \"num_warmup_epochs\": 1,\n        \"use_pretrain\": False,\n        \"pretrained_path\": \"\",\n        \"auto_scale_allowed\": False,\n    }\n    if torch.cuda.is_available()\n    else {}\n)\n\n\ndef skip_if_no_optuna(obj):\n    \"\"\"\n    Skip the unit tests if torch.cuda.is_available is False.\n    \"\"\"\n    return unittest.skipUnless(has_optuna, \"Skipping optuna tests\")(obj)\n\n\nfake_datalist: dict[str, list[dict]] = {\n    \"testing\": [{\"image\": \"val_001.fake.nii.gz\"}, {\"image\": \"val_002.fake.nii.gz\"}],\n    \"training\": 
[\n        {\"fold\": 0, \"image\": \"tr_image_001.fake.nii.gz\", \"label\": \"tr_label_001.fake.nii.gz\"},\n        {\"fold\": 0, \"image\": \"tr_image_002.fake.nii.gz\", \"label\": \"tr_label_002.fake.nii.gz\"},\n        {\"fold\": 0, \"image\": \"tr_image_003.fake.nii.gz\", \"label\": \"tr_label_003.fake.nii.gz\"},\n        {\"fold\": 0, \"image\": \"tr_image_004.fake.nii.gz\", \"label\": \"tr_label_004.fake.nii.gz\"},\n        {\"fold\": 1, \"image\": \"tr_image_005.fake.nii.gz\", \"label\": \"tr_label_005.fake.nii.gz\"},\n        {\"fold\": 1, \"image\": \"tr_image_006.fake.nii.gz\", \"label\": \"tr_label_006.fake.nii.gz\"},\n        {\"fold\": 1, \"image\": \"tr_image_007.fake.nii.gz\", \"label\": \"tr_label_007.fake.nii.gz\"},\n        {\"fold\": 1, \"image\": \"tr_image_008.fake.nii.gz\", \"label\": \"tr_label_008.fake.nii.gz\"},\n        {\"fold\": 2, \"image\": \"tr_image_009.fake.nii.gz\", \"label\": \"tr_label_009.fake.nii.gz\"},\n        {\"fold\": 2, \"image\": \"tr_image_010.fake.nii.gz\", \"label\": \"tr_label_010.fake.nii.gz\"},\n        {\"fold\": 2, \"image\": \"tr_image_011.fake.nii.gz\", \"label\": \"tr_label_011.fake.nii.gz\"},\n        {\"fold\": 2, \"image\": \"tr_image_012.fake.nii.gz\", \"label\": \"tr_label_012.fake.nii.gz\"},\n    ],\n}\n\n\n@unittest.skipIf(not has_tb, \"no tensorboard summary writer\")\nclass TestHPO(unittest.TestCase):\n    def setUp(self) -> None:\n        self.test_dir = tempfile.TemporaryDirectory()\n        test_path = self.test_dir.name\n\n        work_dir = os.path.abspath(os.path.join(test_path, \"workdir\"))\n        dataroot = os.path.join(work_dir, \"dataroot\")\n\n        da_output_yaml = os.path.join(work_dir, \"datastats.yaml\")\n        data_src_cfg = os.path.join(work_dir, \"data_src_cfg.yaml\")\n\n        if not os.path.isdir(dataroot):\n            os.makedirs(dataroot)\n\n        if not os.path.isdir(work_dir):\n            os.makedirs(work_dir)\n\n        # Generate a fake dataset\n        for d in 
fake_datalist[\"testing\"] + fake_datalist[\"training\"]:\n            im, seg = create_test_image_3d(24, 24, 24, rad_max=10, num_seg_classes=1)\n            nib_image = nib.Nifti1Image(im, affine=np.eye(4))\n            image_fpath = os.path.join(dataroot, d[\"image\"])\n            nib.save(nib_image, image_fpath)\n\n            if \"label\" in d:\n                nib_image = nib.Nifti1Image(seg, affine=np.eye(4))\n                label_fpath = os.path.join(dataroot, d[\"label\"])\n                nib.save(nib_image, label_fpath)\n\n        # write to a json file\n        fake_json_datalist = os.path.join(dataroot, \"fake_input.json\")\n        ConfigParser.export_config_file(fake_datalist, fake_json_datalist)\n\n        da = DataAnalyzer(fake_json_datalist, dataroot, output_path=da_output_yaml)\n        da.get_all_case_stats()\n\n        data_src = {\n            \"name\": \"fake_data\",\n            \"task\": \"segmentation\",\n            \"modality\": \"MRI\",\n            \"datalist\": fake_json_datalist,\n            \"dataroot\": dataroot,\n            \"multigpu\": False,\n            \"class_names\": [\"label_class\"],\n        }\n\n        ConfigParser.export_config_file(data_src, data_src_cfg)\n        with skip_if_downloading_fails():\n            bundle_generator = BundleGen(\n                algo_path=work_dir,\n                data_stats_filename=da_output_yaml,\n                data_src_cfg_name=data_src_cfg,\n                templates_path_or_url=get_testing_algo_template_path(),\n            )\n        bundle_generator.generate(work_dir, num_fold=1)\n\n        self.history = bundle_generator.get_history()\n        self.work_dir = work_dir\n        self.test_path = test_path\n\n    @skip_if_no_cuda\n    def test_run_algo(self) -> None:\n        algo_dict = self.history[-1]\n        algo = algo_dict[AlgoKeys.ALGO]\n        nni_gen = NNIGen(algo=algo, params=override_param)\n        obj_filename = nni_gen.get_obj_filename()\n        # this function 
will be used in HPO via Python Fire\n        NNIGen().run_algo(obj_filename, self.work_dir)\n\n    @skip_if_no_cuda\n    @skip_if_no_optuna\n    def test_run_optuna(self) -> None:\n        algo_dict = self.history[-1]\n        algo = algo_dict[AlgoKeys.ALGO]\n\n        class OptunaGenLearningRate(OptunaGen):\n            def get_hyperparameters(self):\n                return {\"learning_rate\": self.trial.suggest_float(\"learning_rate\", 0.00001, 0.1)}\n\n        optuna_gen = OptunaGenLearningRate(algo=algo, params=override_param)\n        search_space = {\"learning_rate\": [0.0001, 0.001, 0.01, 0.1]}\n        study = optuna.create_study(sampler=optuna.samplers.GridSampler(search_space), direction=\"maximize\")\n        study.optimize(\n            partial(\n                optuna_gen,\n                obj_filename=optuna_gen.get_obj_filename(),\n                output_folder=os.path.join(self.test_path, \"optuna_test\"),\n            ),\n            n_trials=2,\n        )\n        print(f\"Best value: {study.best_value} (params: {study.best_params})\\n\")\n\n    @skip_if_no_cuda\n    def test_get_history(self) -> None:\n        algo_dict = self.history[-1]\n        algo = algo_dict[AlgoKeys.ALGO]\n        nni_gen = NNIGen(algo=algo, params=override_param)\n        obj_filename = nni_gen.get_obj_filename()\n\n        NNIGen().run_algo(obj_filename, self.work_dir)\n        history = import_bundle_algo_history(self.work_dir, only_trained=True)\n        assert len(history) == 1\n\n    def tearDown(self) -> None:\n        self.test_dir.cleanup()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_deepedit_interaction.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.apps.deepedit.interaction import Interaction\nfrom monai.apps.deepedit.transforms import (\n    AddGuidanceSignalDeepEditd,\n    AddInitialSeedPointMissingLabelsd,\n    AddRandomGuidanceDeepEditd,\n    FindAllValidSlicesMissingLabelsd,\n    FindDiscrepancyRegionsDeepEditd,\n    SplitPredsLabeld,\n)\nfrom monai.data import DataLoader, Dataset\nfrom monai.engines import SupervisedTrainer\nfrom monai.engines.utils import IterationEvents\nfrom monai.losses import DiceCELoss\nfrom monai.transforms import Activationsd, AsDiscreted, Compose, ToTensord\n\n\ndef add_one(engine):\n    if engine.state.best_metric == -1:\n        engine.state.best_metric = 0\n    else:\n        engine.state.best_metric = engine.state.best_metric + 1\n\n\nclass TestInteractions(unittest.TestCase):\n\n    def run_interaction(self, train):\n        label_names = {\"spleen\": 1, \"background\": 0}\n        np.random.seed(0)\n        data = [\n            {\n                \"image\": np.random.randint(0, 256, size=(1, 15, 15, 15)).astype(np.float32),\n                \"label\": np.random.randint(0, 2, size=(1, 15, 15, 15)),\n                \"label_names\": label_names,\n            }\n            for _ in range(5)\n        ]\n        network = torch.nn.Conv3d(3, len(label_names), 1)\n        lr = 
1e-3\n        opt = torch.optim.Adam(network.parameters(), lr)\n        loss = DiceCELoss(to_onehot_y=True, softmax=True)\n        pre_transforms = Compose(\n            [\n                FindAllValidSlicesMissingLabelsd(keys=\"label\", sids=\"sids\"),\n                AddInitialSeedPointMissingLabelsd(keys=\"label\", guidance=\"guidance\", sids=\"sids\"),\n                AddGuidanceSignalDeepEditd(keys=\"image\", guidance=\"guidance\", number_intensity_ch=1),\n                ToTensord(keys=(\"image\", \"label\")),\n            ]\n        )\n        dataset = Dataset(data, transform=pre_transforms)\n        data_loader = DataLoader(dataset, batch_size=5)\n\n        iteration_transforms = [\n            FindDiscrepancyRegionsDeepEditd(keys=\"label\", pred=\"pred\", discrepancy=\"discrepancy\"),\n            AddRandomGuidanceDeepEditd(\n                keys=\"NA\", guidance=\"guidance\", discrepancy=\"discrepancy\", probability=\"probability\"\n            ),\n            AddGuidanceSignalDeepEditd(keys=\"image\", guidance=\"guidance\", number_intensity_ch=1),\n            ToTensord(keys=(\"image\", \"label\")),\n        ]\n        post_transforms = [\n            Activationsd(keys=\"pred\", softmax=True),\n            AsDiscreted(keys=(\"pred\", \"label\"), argmax=(True, False), to_onehot=len(label_names)),\n            SplitPredsLabeld(keys=\"pred\"),\n            ToTensord(keys=(\"image\", \"label\")),\n        ]\n        iteration_transforms = Compose(iteration_transforms)\n        post_transforms = Compose(post_transforms)\n\n        i = Interaction(\n            deepgrow_probability=1.0,\n            transforms=iteration_transforms,\n            click_probability_key=\"probability\",\n            train=train,\n            label_names=label_names,\n        )\n        self.assertEqual(len(i.transforms.transforms), 4, \"Mismatch in expected transforms\")\n\n        # set up engine\n        engine = SupervisedTrainer(\n            device=\"cpu\",\n            
max_epochs=1,\n            train_data_loader=data_loader,\n            network=network,\n            optimizer=opt,\n            loss_function=loss,\n            postprocessing=post_transforms,\n            iteration_update=i,\n        )\n        engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one)\n        engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one)\n\n        engine.run()\n        self.assertIsNotNone(engine.state.batch[0].get(\"guidance\"), \"guidance is missing\")\n        self.assertEqual(engine.state.best_metric, 1)\n\n    def test_train_interaction(self):\n        self.run_interaction(train=True)\n\n    def test_val_interaction(self):\n        self.run_interaction(train=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_hovernet_nuclear_type_post_processingd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.apps.pathology.transforms.post.dictionary import (\n    HoVerNetInstanceMapPostProcessingd,\n    HoVerNetNuclearTypePostProcessingd,\n)\nfrom monai.transforms import ComputeHoVerMaps\nfrom monai.utils import min_version, optional_import\nfrom monai.utils.enums import HoVerNetBranch\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_scipy = optional_import(\"scipy\", \"1.8.1\", min_version)\n_, has_skimage = optional_import(\"skimage\", \"0.19.3\", min_version)\n\ny, x = np.ogrid[0:30, 0:30]\nimage = (x - 10) ** 2 + (y - 10) ** 2 <= 5**2\nimage = image[None, ...].astype(\"uint8\")\n\nTEST_CASE_1 = [{}, [{\"1\": [10, 10]}, np.zeros_like(image), np.zeros_like(image)]]\n\nTEST_CASE = []\nfor p in TEST_NDARRAYS:\n    TEST_CASE.append([p, image] + TEST_CASE_1)\n\n\n@unittest.skipUnless(has_scipy, \"Requires scipy library.\")\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\nclass TestHoVerNetNuclearTypePostProcessingd(unittest.TestCase):\n    @parameterized.expand(TEST_CASE)\n    def test_value(self, in_type, test_data, kwargs, expected):\n        input = {\n            HoVerNetBranch.NP.value: in_type(test_data.astype(float)),\n            HoVerNetBranch.HV.value: 
in_type(ComputeHoVerMaps()(test_data.astype(int))),\n            HoVerNetBranch.NC.value: in_type(test_data),\n        }\n\n        outputs = HoVerNetInstanceMapPostProcessingd()(input)\n        outputs = HoVerNetNuclearTypePostProcessingd(**kwargs)(outputs)\n\n        # instance prediction info\n        for key in outputs[\"instance_info\"]:\n            assert_allclose(outputs[\"instance_info\"][key][\"centroid\"], expected[0][key], type_test=False)\n\n        # instance map\n        assert_allclose(outputs[\"instance_map\"], expected[1], type_test=False)\n\n        # type map\n        if expected[2] is None:\n            self.assertIsNone(outputs[\"type_map\"])\n        else:\n            assert_allclose(outputs[\"type_map\"], expected[2], type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_autorunner.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch\n\nfrom monai.apps.auto3dseg import AutoRunner\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.data import create_test_image_3d\nfrom monai.utils import optional_import\nfrom tests.test_utils import get_testing_algo_template_path, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick\n\n_, has_tb = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n_, has_nni = optional_import(\"nni\")\n\nnum_images_perfold = max(torch.cuda.device_count(), 4)\nnum_images_per_batch = 2\n\nsim_datalist: dict[str, list[dict]] = {\n    \"testing\": [{\"image\": f\"ts_image__{idx:03d}.nii.gz\"} for idx in range(num_images_perfold)],\n    \"training\": [\n        {\n            \"fold\": f,\n            \"image\": f\"tr_image_{(f * num_images_perfold + idx):03d}.nii.gz\",\n            \"label\": f\"tr_label_{(f * num_images_perfold + idx):03d}.nii.gz\",\n        }\n        for f in range(num_images_per_batch + 1)\n        for idx in range(num_images_perfold)\n    ],\n}\n\ntrain_param = (\n    {\n        \"num_images_per_batch\": num_images_per_batch,\n        \"num_epochs\": 2,\n        \"num_epochs_per_validation\": 1,\n        \"num_warmup_epochs\": 1,\n        \"use_pretrain\": False,\n        \"pretrained_path\": 
\"\",\n        \"num_steps_per_image\": 1,\n    }\n    if torch.cuda.is_available()\n    else {}\n)\n\npred_param = {\"files_slices\": slice(0, 1), \"mode\": \"mean\", \"sigmoid\": True}\n\n\n@skip_if_quick\n@unittest.skipIf(not has_tb, \"no tensorboard summary writer\")\nclass TestAutoRunner(unittest.TestCase):\n    def setUp(self) -> None:\n        self.test_dir = tempfile.TemporaryDirectory()\n        test_path = self.test_dir.name\n\n        sim_dataroot = os.path.join(test_path, \"dataroot\")\n        if not os.path.isdir(sim_dataroot):\n            os.makedirs(sim_dataroot)\n\n        # Generate a fake dataset\n        for d in sim_datalist[\"testing\"] + sim_datalist[\"training\"]:\n            im, seg = create_test_image_3d(24, 24, 24, rad_max=10, num_seg_classes=1)\n            nib_image = nib.Nifti1Image(im, affine=np.eye(4))\n            image_fpath = os.path.join(sim_dataroot, d[\"image\"])\n            nib.save(nib_image, image_fpath)\n\n            if \"label\" in d:\n                nib_image = nib.Nifti1Image(seg, affine=np.eye(4))\n                label_fpath = os.path.join(sim_dataroot, d[\"label\"])\n                nib.save(nib_image, label_fpath)\n\n        sim_json_datalist = os.path.join(sim_dataroot, \"sim_input.json\")\n        ConfigParser.export_config_file(sim_datalist, sim_json_datalist)\n\n        data_src_cfg = os.path.join(test_path, \"data_src_cfg.yaml\")\n        data_src = {\n            \"name\": \"sim_data\",\n            \"task\": \"segmentation\",\n            \"modality\": \"MRI\",\n            \"datalist\": sim_json_datalist,\n            \"dataroot\": sim_dataroot,\n            \"multigpu\": False,\n            \"class_names\": [\"label_class\"],\n        }\n\n        ConfigParser.export_config_file(data_src, data_src_cfg)\n        self.data_src_cfg = data_src_cfg\n        self.test_path = test_path\n\n    @skip_if_no_cuda\n    def test_autorunner(self) -> None:\n        work_dir = os.path.join(self.test_path, 
\"work_dir\")\n        runner = AutoRunner(\n            work_dir=work_dir,\n            input=self.data_src_cfg,\n            templates_path_or_url=get_testing_algo_template_path(),\n            allow_skip=False,\n        )\n        runner.set_training_params(train_param)  # 2 epochs\n        runner.set_num_fold(1)\n        with skip_if_downloading_fails():\n            runner.run()\n\n    @skip_if_no_cuda\n    def test_autorunner_ensemble(self) -> None:\n        work_dir = os.path.join(self.test_path, \"work_dir\")\n        runner = AutoRunner(\n            work_dir=work_dir,\n            input=self.data_src_cfg,\n            templates_path_or_url=get_testing_algo_template_path(),\n            allow_skip=False,\n        )\n        runner.set_training_params(train_param)  # 2 epochs\n        runner.set_ensemble_method(\"AlgoEnsembleBestByFold\")\n        runner.set_num_fold(1)\n        with skip_if_downloading_fails():\n            runner.run()\n\n    @skip_if_no_cuda\n    def test_autorunner_gpu_customization(self) -> None:\n        work_dir = os.path.join(self.test_path, \"work_dir\")\n        runner = AutoRunner(\n            work_dir=work_dir,\n            input=self.data_src_cfg,\n            templates_path_or_url=get_testing_algo_template_path(),\n            allow_skip=False,\n        )\n        gpu_customization_specs = {\n            \"universal\": {\"num_trials\": 1, \"range_num_images_per_batch\": [1, 2], \"range_num_sw_batch_size\": [1, 2]}\n        }\n        runner.set_gpu_customization(gpu_customization=True, gpu_customization_specs=gpu_customization_specs)\n        runner.set_training_params(train_param)  # 2 epochs\n        runner.set_num_fold(1)\n        with skip_if_downloading_fails():\n            runner.run()\n\n    @skip_if_no_cuda\n    @unittest.skipIf(not has_nni, \"nni required\")\n    def test_autorunner_hpo(self) -> None:\n        work_dir = os.path.join(self.test_path, \"work_dir\")\n        runner = AutoRunner(\n            
work_dir=work_dir,\n            input=self.data_src_cfg,\n            hpo=True,\n            ensemble=False,\n            templates_path_or_url=get_testing_algo_template_path(),\n            allow_skip=False,\n        )\n        hpo_param = {\n            \"num_epochs_per_validation\": train_param[\"num_epochs_per_validation\"],\n            \"num_images_per_batch\": train_param[\"num_images_per_batch\"],\n            \"num_epochs\": train_param[\"num_epochs\"],\n            \"num_warmup_epochs\": train_param[\"num_warmup_epochs\"],\n            \"use_pretrain\": train_param[\"use_pretrain\"],\n            \"pretrained_path\": train_param[\"pretrained_path\"],\n            # below are to shorten the time for dints\n            \"training#num_epochs_per_validation\": train_param[\"num_epochs_per_validation\"],\n            \"training#num_images_per_batch\": train_param[\"num_images_per_batch\"],\n            \"training#num_epochs\": train_param[\"num_epochs\"],\n            \"training#num_warmup_epochs\": train_param[\"num_warmup_epochs\"],\n            \"searching#num_epochs_per_validation\": train_param[\"num_epochs_per_validation\"],\n            \"searching#num_images_per_batch\": train_param[\"num_images_per_batch\"],\n            \"searching#num_epochs\": train_param[\"num_epochs\"],\n            \"searching#num_warmup_epochs\": train_param[\"num_warmup_epochs\"],\n            \"nni_dry_run\": True,\n        }\n        search_space = {\"learning_rate\": {\"_type\": \"choice\", \"_value\": [0.0001, 0.001, 0.01, 0.1]}}\n        runner.set_num_fold(1)\n        runner.set_nni_search_space(search_space)\n        runner.set_hpo_params(params=hpo_param)\n        with skip_if_downloading_fails():\n            runner.run()\n\n    def tearDown(self) -> None:\n        self.test_dir.cleanup()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_bundle_run.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport shutil\nimport subprocess\nimport sys\nimport tempfile\nimport unittest\nfrom glob import glob\nfrom pathlib import Path\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.bundle import ConfigParser\nfrom monai.bundle.utils import DEFAULT_HANDLERS_ID\nfrom monai.transforms import LoadImage\nfrom monai.utils import path_to_uri\nfrom tests.test_utils import command_line_tests\n\nTESTS_PATH = Path(__file__).parents[1]\nTEST_CASE_1 = [os.path.join(TESTS_PATH, \"testing_data\", \"inference.json\"), (128, 128, 128)]\n\nTEST_CASE_2 = [os.path.join(TESTS_PATH, \"testing_data\", \"inference.yaml\"), (128, 128, 128)]\n\n\nclass _Runnable42:\n    def __init__(self, val=1):\n        self.val = val\n\n    def run(self):\n        assert self.val == 42  # defined in `TestBundleRun.test_tiny``\n        return self.val\n\n\nclass _Runnable43:\n    def __init__(self, func):\n        self.func = func\n\n    def run(self):\n        self.func()\n\n\nclass TestBundleRun(unittest.TestCase):\n    def setUp(self):\n        self.data_dir = tempfile.mkdtemp()\n\n    def tearDown(self):\n        shutil.rmtree(self.data_dir)\n\n    def test_tiny(self):\n        config_file = os.path.join(self.data_dir, \"tiny_config.json\")\n        meta_file = 
os.path.join(self.data_dir, \"tiny_meta.json\")\n        with open(config_file, \"w\") as f:\n            json.dump(\n                {\n                    \"trainer\": {\"_target_\": \"tests.integration.test_integration_bundle_run._Runnable42\", \"val\": 42},\n                    # keep this test case to cover the \"run_id\" arg\n                    \"training\": \"$@trainer.run()\",\n                },\n                f,\n            )\n        with open(meta_file, \"w\") as f:\n            json.dump(\n                {\"version\": \"0.1.0\", \"monai_version\": \"1.1.0\", \"pytorch_version\": \"2.3.0\", \"numpy_version\": \"1.22.2\"}, f\n            )\n        cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\"]\n        # test both CLI entry \"run\" and \"run_workflow\"\n        command_line_tests(cmd + [\"run\", \"training\", \"--config_file\", config_file, \"--meta_file\", meta_file])\n        command_line_tests(\n            cmd + [\"run_workflow\", \"--run_id\", \"training\", \"--config_file\", config_file, \"--meta_file\", meta_file]\n        )\n        with self.assertRaises(RuntimeError):\n            # test wrong run_id=\"run\"\n            command_line_tests(cmd + [\"run\", \"run\", \"--config_file\", config_file])\n        # test missing meta file\n        self.assertIn(\"ERROR\", command_line_tests(cmd + [\"run\", \"training\", \"--config_file\", config_file]))\n\n    def test_scripts_fold(self):\n        # test scripts directory has been added to Python search directories automatically\n        config_file = os.path.join(self.data_dir, \"tiny_config.json\")\n        meta_file = os.path.join(self.data_dir, \"tiny_meta.json\")\n        scripts_dir = os.path.join(self.data_dir, \"scripts\")\n        script_file = os.path.join(scripts_dir, \"test_scripts_fold.py\")\n        init_file = os.path.join(scripts_dir, \"__init__.py\")\n\n        with open(config_file, \"w\") as f:\n            json.dump(\n                {\n                    \"imports\": 
[\"$import scripts\"],\n                    \"trainer\": {\n                        \"_target_\": \"tests.integration.test_integration_bundle_run._Runnable43\",\n                        \"func\": \"$scripts.tiny_test\",\n                    },\n                    # keep this test case to cover the \"run_id\" arg\n                    \"training\": \"$@trainer.run()\",\n                },\n                f,\n            )\n        with open(meta_file, \"w\") as f:\n            json.dump(\n                {\"version\": \"0.1.0\", \"monai_version\": \"1.1.0\", \"pytorch_version\": \"2.3.0\", \"numpy_version\": \"1.22.2\"}, f\n            )\n\n        os.mkdir(scripts_dir)\n        script_file_lines = [\"def tiny_test():\\n\", \"    print('successfully added scripts fold!') \\n\"]\n        init_file_line = \"from .test_scripts_fold import tiny_test\\n\"\n        with open(script_file, \"w\") as f:\n            f.writelines(script_file_lines)\n            f.close()\n        with open(init_file, \"w\") as f:\n            f.write(init_file_line)\n            f.close()\n\n        cmd = [\"coverage\", \"run\", \"-m\", \"monai.bundle\"]\n        # test both CLI entry \"run\" and \"run_workflow\"\n        expected_condition = \"successfully added scripts fold!\"\n        command_run = cmd + [\"run\", \"training\", \"--config_file\", config_file, \"--meta_file\", meta_file]\n        completed_process = subprocess.run(command_run, check=True, capture_output=True, text=True)\n        output = repr(completed_process.stdout).replace(\"\\\\n\", \"\\n\").replace(\"\\\\t\", \"\\t\")  # Get the captured output\n\n        self.assertIn(expected_condition, output)\n        command_run_workflow = cmd + [\n            \"run_workflow\",\n            \"--run_id\",\n            \"training\",\n            \"--config_file\",\n            config_file,\n            \"--meta_file\",\n            meta_file,\n        ]\n        completed_process = subprocess.run(command_run_workflow, check=True, 
capture_output=True, text=True)\n        output = repr(completed_process.stdout).replace(\"\\\\n\", \"\\n\").replace(\"\\\\t\", \"\\t\")  # Get the captured output\n        self.assertIn(expected_condition, output)\n\n        # test missing meta file\n        self.assertIn(\"ERROR\", command_line_tests(cmd + [\"run\", \"training\", \"--config_file\", config_file]))\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_shape(self, config_file, expected_shape):\n        test_image = np.random.rand(*expected_shape)\n        tempdir = self.data_dir\n        filename = os.path.join(tempdir, \"image.nii\")\n        nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename)\n\n        # generate default args in a JSON file\n        logging_conf = os.path.join(TESTS_PATH, \"testing_data\", \"logging.conf\")\n        def_args = {\"config_file\": \"will be replaced by `config_file` arg\", \"logging_file\": logging_conf}\n        def_args_file = os.path.join(tempdir, \"def_args.json\")\n        ConfigParser.export_config_file(config=def_args, filepath=def_args_file)\n\n        meta = {\"datalist\": [{\"image\": filename}], \"window\": (96, 96, 96)}\n        # test YAML file\n        meta_file = os.path.join(tempdir, \"meta.yaml\")\n        ConfigParser.export_config_file(config=meta, filepath=meta_file, fmt=\"yaml\")\n\n        # test MLFlow settings\n        settings = {\n            \"handlers_id\": DEFAULT_HANDLERS_ID,\n            \"configs\": {\n                \"no_epoch\": True,  # test override config in the settings file\n                \"evaluator\": {\n                    \"_target_\": \"MLFlowHandler\",\n                    \"tracking_uri\": \"$monai.utils.path_to_uri(@output_dir) + '/mlflow_override1'\",\n                    \"iteration_log\": \"@no_epoch\",\n                },\n            },\n        }\n        settings_file = os.path.join(tempdir, \"mlflow.json\")\n        ConfigParser.export_config_file(config=settings, 
filepath=settings_file, fmt=\"json\")\n\n        # test override with file, up case postfix\n        overridefile1 = os.path.join(tempdir, \"override1.JSON\")\n        with open(overridefile1, \"w\") as f:\n            # test override with part of the overriding file\n            json.dump({\"move_net\": \"$@network_def.to(@device)\"}, f)\n        os.makedirs(os.path.join(tempdir, \"jsons\"), exist_ok=True)\n        overridefile2 = os.path.join(tempdir, \"jsons/override2.JSON\")\n        with open(overridefile2, \"w\") as f:\n            # test override with the whole overriding file\n            json.dump(\"Dataset\", f)\n\n        if sys.platform == \"win32\":\n            override = \"--network $@network_def.to(@device) --dataset#_target_ Dataset\"\n        else:\n            override = f\"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}\"\n        device = \"$torch.device('cuda:0')\" if torch.cuda.is_available() else \"$torch.device('cpu')\"\n        # test with `monai.bundle` as CLI entry directly\n        cmd = \"-m monai.bundle run --postprocessing#transforms#2#output_postfix seg\"\n        cmd += f\" {override} --no_epoch False --output_dir {tempdir} --device {device}\"\n        la = [\"coverage\", \"run\"] + cmd.split(\" \") + [\"--meta_file\", meta_file] + [\"--config_file\", config_file]\n        test_env = os.environ.copy()\n        print(f\"CUDA_VISIBLE_DEVICES in {__file__}\", test_env.get(\"CUDA_VISIBLE_DEVICES\"))\n        command_line_tests(la + [\"--args_file\", def_args_file] + [\"--tracking\", settings_file])\n        loader = LoadImage(image_only=True)\n        self.assertTupleEqual(loader(os.path.join(tempdir, \"image\", \"image_seg.nii.gz\")).shape, expected_shape)\n        self.assertTrue(os.path.exists(f\"{tempdir}/mlflow_override1\"))\n\n        tracking_uri = path_to_uri(tempdir) + \"/mlflow_override2\"  # test override experiment management configs\n        # here test the script with `google fire` tool as CLI\n    
    cmd = \"-m fire monai.bundle.scripts run --tracking mlflow --evaluator#amp False\"\n        cmd += f\" --tracking_uri {tracking_uri} {override} --output_dir {tempdir} --device {device}\"\n        la = [\"coverage\", \"run\"] + cmd.split(\" \") + [\"--meta_file\", meta_file] + [\"--config_file\", config_file]\n        command_line_tests(la)\n        self.assertTupleEqual(loader(os.path.join(tempdir, \"image\", \"image_trans.nii.gz\")).shape, expected_shape)\n        self.assertTrue(os.path.exists(f\"{tempdir}/mlflow_override2\"))\n        # test the saved execution configs\n        self.assertTrue(len(glob(f\"{tempdir}/config_*.json\")), 2)\n\n    def test_customized_workflow(self):\n        expected_shape = (64, 64, 64)\n        test_image = np.random.rand(*expected_shape)\n        filename = os.path.join(self.data_dir, \"image.nii\")\n        nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename)\n\n        cmd = \"-m fire monai.bundle.scripts run_workflow --workflow_name tests.nonconfig_workflow.NonConfigWorkflow\"\n        cmd += f\" --filename {filename} --output_dir {self.data_dir}\"\n        command_line_tests([\"coverage\", \"run\"] + cmd.split(\" \"))\n        loader = LoadImage(image_only=True)\n        self.assertTupleEqual(loader(os.path.join(self.data_dir, \"image\", \"image_seg.nii.gz\")).shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_classification_2d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nimport warnings\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader\n\nimport monai\nfrom monai.apps import download_and_extract\nfrom monai.data import decollate_batch\nfrom monai.metrics import ROCAUCMetric\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import DenseNet121\nfrom monai.transforms import (\n    Activations,\n    AsDiscrete,\n    Compose,\n    EnsureChannelFirst,\n    LoadImage,\n    RandFlip,\n    RandRotate,\n    RandZoom,\n    ScaleIntensity,\n    Transpose,\n)\nfrom monai.utils import set_determinism\nfrom tests.test_utils import DistTestCase, TimedCall, skip_if_downloading_fails, skip_if_quick, testing_data_config\nfrom tests.testing_data.integration_answers import test_integration_value\n\nTASK = \"integration_classification_2d\"\n\n\nclass MedNISTDataset(torch.utils.data.Dataset):\n    def __init__(self, image_files, labels, transforms):\n        self.image_files = image_files\n        self.labels = labels\n        self.transforms = transforms\n\n    def __len__(self):\n        return len(self.image_files)\n\n    def __getitem__(self, index):\n        return self.transforms(self.image_files[index]), self.labels[index]\n\n\ndef run_training_test(root_dir, train_x, train_y, val_x, val_y, device=\"cuda:0\", num_workers=10):\n    
monai.config.print_config()\n    # define transforms for image and classification\n    train_transforms = Compose(\n        [\n            LoadImage(image_only=True, simple_keys=True),\n            EnsureChannelFirst(channel_dim=\"no_channel\"),\n            Transpose(indices=[0, 2, 1]),\n            ScaleIntensity(),\n            RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True, dtype=np.float64),\n            RandFlip(spatial_axis=0, prob=0.5),\n            RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),\n        ]\n    )\n    train_transforms.set_random_state(1234)\n    val_transforms = Compose(\n        [\n            LoadImage(image_only=True, simple_keys=True),\n            EnsureChannelFirst(channel_dim=\"no_channel\"),\n            Transpose(indices=[0, 2, 1]),\n            ScaleIntensity(),\n        ]\n    )\n    y_pred_trans = Compose([Activations(softmax=True)])\n    y_trans = AsDiscrete(to_onehot=len(np.unique(train_y)))\n    auc_metric = ROCAUCMetric()\n\n    # create train, val data loaders\n    train_ds = MedNISTDataset(train_x, train_y, train_transforms)\n    train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=num_workers)\n\n    val_ds = MedNISTDataset(val_x, val_y, val_transforms)\n    val_loader = DataLoader(val_ds, batch_size=300, num_workers=num_workers)\n\n    model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=len(np.unique(train_y))).to(device)\n    loss_function = torch.nn.CrossEntropyLoss()\n    optimizer = torch.optim.Adam(model.parameters(), 1e-5)\n    epoch_num = 1\n    val_interval = 1\n\n    # start training validation\n    best_metric = -1\n    best_metric_epoch = -1\n    epoch_loss_values = []\n    metric_values = []\n    model_filename = os.path.join(root_dir, \"best_metric_model.pth\")\n    for epoch in range(epoch_num):\n        print(\"-\" * 10)\n        print(f\"Epoch {epoch + 1}/{epoch_num}\")\n        model.train()\n        epoch_loss = 0\n        step = 0\n        for batch_data 
in train_loader:\n            step += 1\n            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)\n            optimizer.zero_grad()\n            outputs = model(inputs)\n            loss = loss_function(outputs, labels)\n            loss.backward()\n            optimizer.step()\n            epoch_loss += loss.item()\n        epoch_loss /= step\n        epoch_loss_values.append(epoch_loss)\n        print(f\"epoch {epoch + 1} average loss:{epoch_loss:0.4f}\")\n\n        if (epoch + 1) % val_interval == 0:\n            with eval_mode(model):\n                y_pred = torch.tensor([], dtype=torch.float32, device=device)\n                y = torch.tensor([], dtype=torch.long, device=device)\n                for val_data in val_loader:\n                    val_images, val_labels = (val_data[0].to(device), val_data[1].to(device))\n                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)\n                    y = torch.cat([y, val_labels], dim=0)\n\n                # compute accuracy\n                acc_value = torch.eq(y_pred.argmax(dim=1), y)\n                acc_metric = acc_value.sum().item() / len(acc_value)\n                # decollate prediction and label and execute post processing\n                y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)]\n                y = [y_trans(i) for i in decollate_batch(y, detach=False)]\n                # compute AUC\n                auc_metric(y_pred, y)\n                auc_value = auc_metric.aggregate()\n                auc_metric.reset()\n                metric_values.append(auc_value)\n                if auc_value > best_metric:\n                    best_metric = auc_value\n                    best_metric_epoch = epoch + 1\n                    torch.save(model.state_dict(), model_filename)\n                    print(\"saved new best metric model\")\n                print(\n                    f\"current epoch {epoch + 1} current AUC: {auc_value:0.4f} \"\n                    
f\"current accuracy: {acc_metric:0.4f} best AUC: {best_metric:0.4f} at epoch {best_metric_epoch}\"\n                )\n    print(f\"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}\")\n    return epoch_loss_values, best_metric, best_metric_epoch\n\n\ndef run_inference_test(root_dir, test_x, test_y, device=\"cuda:0\", num_workers=10):\n    # define transforms for image and classification\n    val_transforms = Compose(\n        [LoadImage(image_only=True), EnsureChannelFirst(channel_dim=\"no_channel\"), ScaleIntensity()]\n    )\n    val_ds = MedNISTDataset(test_x, test_y, val_transforms)\n    val_loader = DataLoader(val_ds, batch_size=300, num_workers=num_workers)\n\n    model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=len(np.unique(test_y))).to(device)\n\n    model_filename = os.path.join(root_dir, \"best_metric_model.pth\")\n    model.load_state_dict(torch.load(model_filename, weights_only=True))\n    y_true = []\n    y_pred = []\n    with eval_mode(model):\n        for test_data in val_loader:\n            test_images, test_labels = test_data[0].to(device), test_data[1].to(device)\n            pred = model(test_images).argmax(dim=1)\n            for i in range(len(pred)):\n                y_true.append(test_labels[i].item())\n                y_pred.append(pred[i].item())\n    tps = [np.sum((np.asarray(y_true) == idx) & (np.asarray(y_pred) == idx)) for idx in np.unique(test_y)]\n    return tps\n\n\n@skip_if_quick\nclass IntegrationClassification2D(DistTestCase):\n    def setUp(self):\n        set_determinism(seed=0)\n        self.data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), \"../testing_data\")\n        data_dir = os.path.join(self.data_dir, \"MedNIST\")\n        dataset_file = os.path.join(self.data_dir, \"MedNIST.tar.gz\")\n\n        if not os.path.exists(data_dir):\n            with skip_if_downloading_fails():\n                data_spec = testing_data_config(\"images\", \"mednist\")\n         
       download_and_extract(\n                    data_spec[\"url\"],\n                    dataset_file,\n                    self.data_dir,\n                    hash_val=data_spec[\"hash_val\"],\n                    hash_type=data_spec[\"hash_type\"],\n                )\n\n        assert os.path.exists(data_dir)\n\n        class_names = sorted(x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x)))\n        image_files = [\n            [os.path.join(data_dir, class_name, x) for x in sorted(os.listdir(os.path.join(data_dir, class_name)))]\n            for class_name in class_names\n        ]\n        image_file_list, image_classes = [], []\n        for i, _ in enumerate(class_names):\n            image_file_list.extend(image_files[i])\n            image_classes.extend([i] * len(image_files[i]))\n\n        # split train, val, test\n        valid_frac, test_frac = 0.1, 0.1\n        self.train_x, self.train_y = [], []\n        self.val_x, self.val_y = [], []\n        self.test_x, self.test_y = [], []\n        for i in range(len(image_classes)):\n            rann = np.random.random()\n            if rann < valid_frac:\n                self.val_x.append(image_file_list[i])\n                self.val_y.append(image_classes[i])\n            elif rann < test_frac + valid_frac:\n                self.test_x.append(image_file_list[i])\n                self.test_y.append(image_classes[i])\n            else:\n                self.train_x.append(image_file_list[i])\n                self.train_y.append(image_classes[i])\n\n        self.device = \"cuda:0\" if torch.cuda.is_available() else \"cpu:0\"\n\n    def tearDown(self):\n        set_determinism(seed=None)\n        try:\n            os.remove(os.path.join(self.data_dir, \"best_metric_model.pth\"))\n        except FileNotFoundError:\n            warnings.warn(\"not found best_metric_model.pth, training skipped?\")\n\n    def train_and_infer(self, idx=0):\n        results = []\n        if not 
os.path.exists(os.path.join(self.data_dir, \"MedNIST\")):\n            # skip test if no MedNIST dataset\n            return results\n\n        set_determinism(seed=0)\n        losses, best_metric, best_metric_epoch = run_training_test(\n            self.data_dir, self.train_x, self.train_y, self.val_x, self.val_y, device=self.device\n        )\n        infer_metric = run_inference_test(self.data_dir, self.test_x, self.test_y, device=self.device)\n\n        print(f\"integration_classification_2d {losses}\")\n        print(\"best metric\", best_metric)\n        print(\"infer metric\", infer_metric)\n        # check training properties\n        self.assertTrue(test_integration_value(TASK, key=\"losses\", data=losses, rtol=1e-2))\n        self.assertTrue(test_integration_value(TASK, key=\"best_metric\", data=best_metric, rtol=1e-4))\n        np.testing.assert_allclose(best_metric_epoch, 1)\n        model_file = os.path.join(self.data_dir, \"best_metric_model.pth\")\n        self.assertTrue(os.path.exists(model_file))\n        # check inference properties\n        self.assertTrue(test_integration_value(TASK, key=\"infer_prop\", data=np.asarray(infer_metric), rtol=1))\n        results.extend(losses)\n        results.append(best_metric)\n        results.extend(infer_metric)\n        return results\n\n    def test_training(self):\n        repeated = []\n        for i in range(2):\n            results = self.train_and_infer(i)\n            repeated.append(results)\n        np.testing.assert_allclose(repeated[0], repeated[1])\n\n    @TimedCall(seconds=2000, skip_timing=not torch.cuda.is_available(), force_quit=False, daemon=False)\n    def test_timing(self):\n        self.train_and_infer()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_determinism.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader, Dataset\n\nfrom monai.data import create_test_image_2d\nfrom monai.losses import DiceLoss\nfrom monai.networks.nets import UNet\nfrom monai.transforms import Compose, EnsureChannelFirst, RandRotate90, RandSpatialCrop, ScaleIntensity\nfrom monai.utils import set_determinism\nfrom tests.test_utils import DistTestCase, TimedCall\n\n\ndef run_test(batch_size=64, train_steps=200, device=\"cuda:0\"):\n    class _TestBatch(Dataset):\n        def __init__(self, transforms):\n            self.transforms = transforms\n\n        def __getitem__(self, _unused_id):\n            im, seg = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)\n            seed = np.random.randint(2147483647)\n            self.transforms.set_random_state(seed=seed)\n            im = self.transforms(im)\n            self.transforms.set_random_state(seed=seed)\n            seg = self.transforms(seg)\n            return im, seg\n\n        def __len__(self):\n            return train_steps\n\n    net = UNet(\n        spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2\n    ).to(device)\n\n    loss = DiceLoss(sigmoid=True)\n    opt = torch.optim.Adam(net.parameters(), 1e-2)\n    train_transforms = 
Compose(\n        [\n            EnsureChannelFirst(channel_dim=\"no_channel\"),\n            ScaleIntensity(),\n            RandSpatialCrop((96, 96), random_size=False),\n            RandRotate90(),\n        ]\n    )\n\n    src = DataLoader(_TestBatch(train_transforms), batch_size=batch_size, shuffle=True)\n\n    net.train()\n    epoch_loss = 0\n    step = 0\n    for img, seg in src:\n        step += 1\n        opt.zero_grad()\n        output = net(img.to(device))\n        step_loss = loss(output, seg.to(device))\n        step_loss.backward()\n        opt.step()\n        epoch_loss += step_loss.item()\n    epoch_loss /= step\n\n    return epoch_loss, step\n\n\nclass TestDeterminism(DistTestCase):\n    def setUp(self):\n        self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu:0\")\n\n    def tearDown(self):\n        set_determinism(seed=None)\n\n    @TimedCall(seconds=150, skip_timing=not torch.cuda.is_available())\n    def test_training(self):\n        set_determinism(seed=0)\n        loss, step = run_test(device=self.device)\n        print(f\"Deterministic loss {loss} at training step {step}\")\n        np.testing.assert_allclose(step, 4)\n        np.testing.assert_allclose(loss, 0.536134, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_fast_train.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport math\nimport os\nimport shutil\nimport tempfile\nimport time\nimport unittest\nfrom glob import glob\n\nimport nibabel as nib\nimport numpy as np\nimport torch\n\nimport monai\nfrom monai.data import CacheDataset, ThreadDataLoader, create_test_image_3d, decollate_batch\nfrom monai.inferers import sliding_window_inference\nfrom monai.losses import DiceCELoss\nfrom monai.metrics import DiceMetric\nfrom monai.networks.layers import Norm\nfrom monai.networks.nets import UNet\nfrom monai.optimizers import Novograd\nfrom monai.transforms import (\n    AsDiscrete,\n    Compose,\n    CropForegroundd,\n    EnsureChannelFirstd,\n    FgBgToIndicesd,\n    LoadImaged,\n    RandAffined,\n    RandAxisFlipd,\n    RandCropByPosNegLabeld,\n    RandFlipd,\n    RandGaussianNoised,\n    RandRotate90d,\n    RandRotated,\n    RandStdShiftIntensityd,\n    RandZoomd,\n    ScaleIntensityd,\n    Spacingd,\n    ToDeviced,\n)\nfrom monai.utils import set_determinism\nfrom tests.test_utils import DistTestCase, TimedCall, skip_if_no_cuda, skip_if_quick\n\n\n@skip_if_no_cuda\n@skip_if_quick\nclass IntegrationFastTrain(DistTestCase):\n    def setUp(self):\n        set_determinism(seed=0)\n        monai.config.print_config()\n\n        self.data_dir = tempfile.mkdtemp()\n        for i in range(41):\n            im, seg = create_test_image_3d(128, 128, 128, 
num_seg_classes=1, channel_dim=-1)\n            n = nib.Nifti1Image(im, np.eye(4))\n            nib.save(n, os.path.join(self.data_dir, f\"img{i:d}.nii.gz\"))\n            n = nib.Nifti1Image(seg, np.eye(4))\n            nib.save(n, os.path.join(self.data_dir, f\"seg{i:d}.nii.gz\"))\n\n    def tearDown(self):\n        set_determinism(seed=None)\n        shutil.rmtree(self.data_dir)\n\n    # test the fast training speed is as expected\n    @TimedCall(seconds=100, daemon=False, force_quit=False)\n    def test_train_timing(self):\n        images = sorted(glob(os.path.join(self.data_dir, \"img*.nii.gz\")))\n        segs = sorted(glob(os.path.join(self.data_dir, \"seg*.nii.gz\")))\n        train_files = [{\"image\": img, \"label\": seg} for img, seg in zip(images[:32], segs[:32])]\n        val_files = [{\"image\": img, \"label\": seg} for img, seg in zip(images[-9:], segs[-9:])]\n\n        device = torch.device(\"cuda:0\")\n        # define transforms for train and validation\n        train_transforms = Compose(\n            [\n                LoadImaged(keys=[\"image\", \"label\"]),\n                EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n                Spacingd(keys=[\"image\", \"label\"], pixdim=(1.0, 1.0, 1.0), mode=(\"bilinear\", \"nearest\")),\n                ScaleIntensityd(keys=\"image\"),\n                CropForegroundd(keys=[\"image\", \"label\"], source_key=\"image\"),\n                # pre-compute foreground and background indexes\n                # and cache them to accelerate training\n                FgBgToIndicesd(keys=\"label\", fg_postfix=\"_fg\", bg_postfix=\"_bg\"),\n                # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch\n                ToDeviced(keys=[\"image\", \"label\"], device=device),\n                # randomly crop out patch samples from big\n                # image based on pos / neg ratio\n                # the image centers of negative samples\n                # must be in valid image 
area\n                RandCropByPosNegLabeld(\n                    keys=[\"image\", \"label\"],\n                    label_key=\"label\",\n                    spatial_size=(64, 64, 64),\n                    pos=1,\n                    neg=1,\n                    num_samples=4,\n                    fg_indices_key=\"label_fg\",\n                    bg_indices_key=\"label_bg\",\n                ),\n                RandFlipd(keys=[\"image\", \"label\"], prob=0.5, spatial_axis=[1, 2]),\n                RandAxisFlipd(keys=[\"image\", \"label\"], prob=0.5),\n                RandRotate90d(keys=[\"image\", \"label\"], prob=0.5, spatial_axes=(1, 2)),\n                RandZoomd(keys=[\"image\", \"label\"], prob=0.5, min_zoom=0.8, max_zoom=1.2, keep_size=True),\n                RandRotated(\n                    keys=[\"image\", \"label\"],\n                    prob=0.5,\n                    range_x=np.pi / 4,\n                    mode=(\"bilinear\", \"nearest\"),\n                    align_corners=True,\n                    dtype=np.float64,\n                ),\n                RandAffined(keys=[\"image\", \"label\"], prob=0.5, rotate_range=np.pi / 2, mode=(\"bilinear\", \"nearest\")),\n                RandGaussianNoised(keys=\"image\", prob=0.5),\n                RandStdShiftIntensityd(keys=\"image\", prob=0.5, factors=0.05, nonzero=True),\n            ]\n        )\n\n        val_transforms = Compose(\n            [\n                LoadImaged(keys=[\"image\", \"label\"]),\n                EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n                Spacingd(keys=[\"image\", \"label\"], pixdim=(1.0, 1.0, 1.0), mode=(\"bilinear\", \"nearest\")),\n                ScaleIntensityd(keys=\"image\"),\n                CropForegroundd(keys=[\"image\", \"label\"], source_key=\"image\"),\n                # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch\n                ToDeviced(keys=[\"image\", \"label\"], device=device),\n            ]\n        )\n\n      
  max_epochs = 5\n        learning_rate = 2e-4\n        val_interval = 1  # do validation for every epoch\n\n        # set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training\n        train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8)\n        val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, runtime_cache=True)\n        # disable multi-workers because `ThreadDataLoader` works with multi-threads\n        train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)\n        val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)\n\n        loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)\n        model = UNet(\n            spatial_dims=3,\n            in_channels=1,\n            out_channels=2,\n            channels=(16, 32, 64, 128, 256),\n            strides=(2, 2, 2, 2),\n            num_res_units=2,\n            norm=Norm.BATCH,\n        ).to(device)\n\n        # Novograd paper suggests to use a bigger LR than Adam,\n        # because Adam does normalization by element-wise second moments\n        optimizer = Novograd(model.parameters(), learning_rate * 10)\n        scaler = torch.cuda.amp.GradScaler()\n\n        post_pred = AsDiscrete(argmax=True, to_onehot=2)\n        post_label = AsDiscrete(to_onehot=2)\n\n        dice_metric = DiceMetric(include_background=True, reduction=\"mean\", get_not_nans=False)\n\n        best_metric = -1\n        total_start = time.time()\n        for epoch in range(max_epochs):\n            epoch_start = time.time()\n            print(\"-\" * 10)\n            print(f\"epoch {epoch + 1}/{max_epochs}\")\n            model.train()\n            epoch_loss = 0\n            step = 0\n            for batch_data in train_loader:\n                step_start = time.time()\n                step += 1\n                optimizer.zero_grad()\n                # set AMP for 
training\n                with torch.autocast(\"cuda\"):\n                    outputs = model(batch_data[\"image\"])\n                    loss = loss_function(outputs, batch_data[\"label\"])\n                scaler.scale(loss).backward()\n                scaler.step(optimizer)\n                scaler.update()\n                epoch_loss += loss.item()\n                epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)\n                print(\n                    f\"{step}/{epoch_len}, train_loss: {loss.item():.4f}\" f\" step time: {(time.time() - step_start):.4f}\"\n                )\n            epoch_loss /= step\n            print(f\"epoch {epoch + 1} average loss: {epoch_loss:.4f}\")\n\n            if (epoch + 1) % val_interval == 0:\n                model.eval()\n                with torch.no_grad():\n                    for val_data in val_loader:\n                        roi_size = (96, 96, 96)\n                        sw_batch_size = 4\n                        # set AMP for validation\n                        with torch.autocast(\"cuda\"):\n                            val_outputs = sliding_window_inference(val_data[\"image\"], roi_size, sw_batch_size, model)\n\n                        val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]\n                        val_labels = [post_label(i) for i in decollate_batch(val_data[\"label\"])]\n                        dice_metric(y_pred=val_outputs, y=val_labels)\n\n                    metric = dice_metric.aggregate().item()\n                    dice_metric.reset()\n                    if metric > best_metric:\n                        best_metric = metric\n                    print(f\"epoch: {epoch + 1} current mean dice: {metric:.4f}, best mean dice: {best_metric:.4f}\")\n            print(f\"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}\")\n\n        total_time = time.time() - total_start\n        print(f\"train completed, best_metric: {best_metric:.4f} total 
time: {total_time:.4f}\")\n        # test expected metrics\n        self.assertGreater(best_metric, 0.95)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_gpu_customization.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch\n\nfrom monai.apps.auto3dseg import AlgoEnsembleBestByFold, AlgoEnsembleBestN, AlgoEnsembleBuilder, BundleGen, DataAnalyzer\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.data import create_test_image_3d\nfrom monai.utils import optional_import\nfrom monai.utils.enums import AlgoKeys\nfrom tests.test_utils import get_testing_algo_template_path, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick\n\n_, has_tb = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n\nnum_images_perfold = max(torch.cuda.device_count(), 4)\nnum_images_per_batch = 2\n\nfake_datalist: dict[str, list[dict]] = {\n    \"testing\": [{\"image\": \"val_001.fake.nii.gz\"}, {\"image\": \"val_002.fake.nii.gz\"}],\n    \"training\": [\n        {\n            \"fold\": f,\n            \"image\": f\"tr_image_{(f * num_images_perfold + idx):03d}.nii.gz\",\n            \"label\": f\"tr_label_{(f * num_images_perfold + idx):03d}.nii.gz\",\n        }\n        for f in range(num_images_per_batch + 1)\n        for idx in range(num_images_perfold)\n    ],\n}\n\ntrain_param = (\n    {\n        \"num_images_per_batch\": num_images_per_batch,\n        \"num_epochs\": 2,\n        \"num_epochs_per_validation\": 1,\n        
\"num_warmup_epochs\": 1,\n        \"use_pretrain\": False,\n        \"pretrained_path\": \"\",\n    }\n    if torch.cuda.is_available()\n    else {}\n)\n\npred_param = {\"files_slices\": slice(0, 1), \"mode\": \"mean\", \"sigmoid\": True}\n\n\n@skip_if_quick\n@unittest.skipIf(not has_tb, \"no tensorboard summary writer\")\nclass TestEnsembleGpuCustomization(unittest.TestCase):\n    def setUp(self) -> None:\n        self.test_dir = tempfile.TemporaryDirectory()\n\n    @skip_if_no_cuda\n    def test_ensemble_gpu_customization(self) -> None:\n        test_path = self.test_dir.name\n\n        dataroot = os.path.join(test_path, \"dataroot\")\n        work_dir = os.path.join(test_path, \"workdir\")\n\n        da_output_yaml = os.path.join(work_dir, \"datastats.yaml\")\n        data_src_cfg = os.path.join(work_dir, \"data_src_cfg.yaml\")\n\n        if not os.path.isdir(dataroot):\n            os.makedirs(dataroot)\n\n        if not os.path.isdir(work_dir):\n            os.makedirs(work_dir)\n\n        # Generate a fake dataset\n        for d in fake_datalist[\"testing\"] + fake_datalist[\"training\"]:\n            im, seg = create_test_image_3d(24, 24, 24, rad_max=10, num_seg_classes=1)\n            nib_image = nib.Nifti1Image(im, affine=np.eye(4))\n            image_fpath = os.path.join(dataroot, d[\"image\"])\n            nib.save(nib_image, image_fpath)\n\n            if \"label\" in d:\n                nib_image = nib.Nifti1Image(seg, affine=np.eye(4))\n                label_fpath = os.path.join(dataroot, d[\"label\"])\n                nib.save(nib_image, label_fpath)\n\n        # write to a json file\n        fake_json_datalist = os.path.join(dataroot, \"fake_input.json\")\n        ConfigParser.export_config_file(fake_datalist, fake_json_datalist)\n\n        da = DataAnalyzer(fake_json_datalist, dataroot, output_path=da_output_yaml)\n        da.get_all_case_stats()\n\n        data_src = {\n            \"name\": \"fake_data\",\n            \"task\": 
\"segmentation\",\n            \"modality\": \"MRI\",\n            \"datalist\": fake_json_datalist,\n            \"dataroot\": dataroot,\n            \"multigpu\": False,\n            \"class_names\": [\"label_class\"],\n        }\n\n        ConfigParser.export_config_file(data_src, data_src_cfg)\n\n        with skip_if_downloading_fails():\n            bundle_generator = BundleGen(\n                algo_path=work_dir,\n                data_stats_filename=da_output_yaml,\n                data_src_cfg_name=data_src_cfg,\n                templates_path_or_url=get_testing_algo_template_path(),\n            )\n\n        gpu_customization_specs = {\n            \"universal\": {\"num_trials\": 1, \"range_num_images_per_batch\": [1, 2], \"range_num_sw_batch_size\": [1, 2]}\n        }\n        bundle_generator.generate(\n            work_dir, num_fold=1, gpu_customization=True, gpu_customization_specs=gpu_customization_specs\n        )\n        history = bundle_generator.get_history()\n\n        for algo_dict in history:\n            algo = algo_dict[AlgoKeys.ALGO]\n            algo.train(train_param)\n\n        builder = AlgoEnsembleBuilder(history, data_src_cfg)\n        builder.set_ensemble_method(AlgoEnsembleBestN(n_best=2))\n        ensemble = builder.get_ensemble()\n        preds = ensemble(pred_param)\n        self.assertTupleEqual(preds[0].shape, (2, 24, 24, 24))\n\n        builder.set_ensemble_method(AlgoEnsembleBestByFold(1))\n        ensemble = builder.get_ensemble()\n        for algo in ensemble.get_algo_ensemble():\n            print(algo[AlgoKeys.ID])\n\n    def tearDown(self) -> None:\n        self.test_dir.cleanup()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_lazy_samples.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport tempfile\nimport unittest\nfrom glob import glob\n\nimport nibabel as nib\nimport numpy as np\nimport torch\n\nimport monai\nimport monai.transforms as mt\nfrom monai.data import create_test_image_3d, decollate_batch\nfrom monai.transforms.utils import has_status_keys\nfrom monai.utils import TraceStatusKeys, set_determinism\nfrom tests.test_utils import HAS_CUPY, DistTestCase, skip_if_quick\n\n\ndef _no_op(x):\n    return x\n\n\ndef run_training_test(root_dir, device=\"cuda:0\", cachedataset=0, readers=(None, None), num_workers=4, lazy=True):\n    print(f\"test case: {locals()}\")\n    images = sorted(glob(os.path.join(root_dir, \"img*.nii.gz\")))\n    segs = sorted(glob(os.path.join(root_dir, \"seg*.nii.gz\")))\n    train_files = [{\"img\": img, \"seg\": seg} for img, seg in zip(images[:20], segs[:20])]\n    device = \"cuda:0\" if HAS_CUPY and torch.cuda.is_available() else \"cpu\"  # mode 0 and cuda requires CUPY\n    num_workers = 0 if torch.cuda.is_available() else num_workers\n\n    # define transforms for image and segmentation\n    lazy_kwargs = {\n        \"img\": {\"mode\": \"bilinear\", \"device\": device, \"padding_mode\": \"border\", \"dtype\": torch.float32},\n        \"seg\": {\"mode\": 0, \"device\": device, \"padding_mode\": \"nearest\", \"dtype\": torch.uint8},\n    }\n    
train_transforms = mt.Compose(\n        [\n            mt.LoadImaged(keys=[\"img\", \"seg\"], reader=readers[0], image_only=True),\n            mt.EnsureChannelFirstd(keys=[\"img\", \"seg\"]),\n            mt.Spacingd(\n                keys=[\"img\", \"seg\"],\n                pixdim=[1.2, 0.8, 0.7],\n                mode=[\"bilinear\", 0],\n                padding_mode=(\"border\", \"nearest\"),\n                dtype=np.float32,\n            ),\n            mt.Orientationd(keys=[\"img\", \"seg\"], axcodes=\"ARS\"),\n            mt.RandRotate90d(keys=[\"img\", \"seg\"], prob=1.0, spatial_axes=(1, 2)),\n            mt.ScaleIntensityd(keys=\"img\"),\n            mt.ApplyPendingd(keys=[\"seg\"]),\n            mt.RandCropByPosNegLabeld(\n                keys=[\"img\", \"seg\"], label_key=\"seg\", spatial_size=[76, 82, 80], pos=1, neg=1, num_samples=4\n            ),\n            mt.RandRotate90d(keys=[\"img\", \"seg\"], prob=0.8, spatial_axes=(0, 2)),\n            mt.RandZoomd(\n                keys=[\"img\", \"seg\"], prob=1.0, min_zoom=1.0, max_zoom=1.0, mode=(\"trilinear\", 0), keep_size=True\n            ),\n            mt.ResizeWithPadOrCropD(keys=[\"img\", \"seg\"], spatial_size=[80, 72, 80]),\n            mt.Rotated(keys=[\"img\", \"seg\"], angle=[np.pi / 2, np.pi / 2, 0], mode=\"nearest\", keep_size=False),\n            mt.Lambdad(keys=[\"img\"], func=_no_op),\n        ],\n        lazy=lazy,\n        overrides=lazy_kwargs,\n        log_stats=num_workers > 0,\n    )\n\n    # create a training data loader\n    if cachedataset == 2:\n        train_ds = monai.data.CacheDataset(\n            data=train_files, transform=train_transforms, cache_rate=0.8, runtime_cache=False, num_workers=0\n        )\n    elif cachedataset == 3:\n        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms, cache_dir=root_dir)\n    else:\n        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)\n\n    # create UNet, DiceLoss and 
Adam optimizer\n    model = monai.networks.nets.UNet(\n        spatial_dims=3, in_channels=1, out_channels=1, channels=(2, 2, 2, 2), strides=(2, 2, 2), num_res_units=2\n    ).to(device)\n    optimizer = torch.optim.Adam(model.parameters(), 5e-4)\n    loss_function = monai.losses.DiceLoss(sigmoid=True)\n\n    saver = mt.SaveImage(\n        output_dir=os.path.join(root_dir, \"output\"),\n        dtype=np.float32,\n        output_ext=\".nii.gz\",\n        output_postfix=f\"seg_{lazy}_{num_workers}\",\n        mode=\"bilinear\",\n        resample=False,\n        separate_folder=False,\n        print_log=False,\n    )\n    inverter = mt.Invertd(\n        keys=\"seg\", orig_keys=\"img\", transform=mt.Compose(train_transforms.transforms[-5:]), to_tensor=True\n    )\n\n    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training\n    _g = torch.Generator()\n    _g.manual_seed(0)\n    set_determinism(0)\n    train_loader = monai.data.DataLoader(\n        train_ds, batch_size=1, shuffle=True, num_workers=num_workers, generator=_g, persistent_workers=num_workers > 0\n    )\n    all_coords = set()\n    batch_data = None\n    for epoch in range(5):\n        print(\"-\" * 10)\n        print(f\"Epoch {epoch + 1}/5\")\n        for step, batch_data in enumerate(train_loader, start=1):\n            inputs, labels = batch_data[\"img\"].to(device), batch_data[\"seg\"].to(device)\n            optimizer.zero_grad()\n            outputs = model(inputs)\n            loss = loss_function(outputs, labels)\n            loss.backward()\n            optimizer.step()\n            epoch_len = len(train_ds) // train_loader.batch_size\n            print(f\"{step}/{epoch_len}, train_loss:{loss.item():0.4f}\")\n\n            for item, in_img, in_seg in zip(outputs, inputs, labels):  # this decollates the batch, pt 1.9+\n                item.copy_meta_from(in_img)\n                np.testing.assert_array_equal(item.pending_operations, [])\n        
        np.testing.assert_array_equal(in_seg.pending_operations, [])\n                ops = [0]\n                if len(item.applied_operations) > 1:\n                    for idx, n in enumerate(item.applied_operations):\n                        if n[\"class\"] == \"RandCropByPosNegLabel\":\n                            ops = item.applied_operations[idx][\"extra_info\"][\"extra_info\"][\"cropped\"]\n                            break\n                img_name = os.path.basename(item.meta[\"filename_or_obj\"])\n                coords = f\"{img_name} - {ops}\"\n                print(coords)\n                # np.testing.assert_allclose(coords in all_coords, False)\n                all_coords.add(coords)\n                saver(item)  # just testing the saving\n                saver(in_img)\n                saver(in_seg)\n    invertible, reasons = has_status_keys(batch_data, TraceStatusKeys.PENDING_DURING_APPLY)\n    inverted = [inverter(b_data) for b_data in decollate_batch(batch_data)]  # expecting no error\n\n    return ops\n\n\n@skip_if_quick\nclass IntegrationLazyResampling(DistTestCase):\n    def setUp(self):\n        monai.config.print_config()\n        set_determinism(seed=0)\n\n        self.data_dir = tempfile.mkdtemp()\n        for i in range(3):\n            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)\n            n = nib.Nifti1Image(im, np.eye(4))\n            nib.save(n, os.path.join(self.data_dir, f\"img{i:d}.nii.gz\"))\n            n = nib.Nifti1Image(seg, np.eye(4))\n            nib.save(n, os.path.join(self.data_dir, f\"seg{i:d}.nii.gz\"))\n\n        self.device = \"cuda:0\" if torch.cuda.is_available() else \"cpu:0\"\n\n    def tearDown(self):\n        set_determinism(seed=None)\n        shutil.rmtree(self.data_dir)\n\n    def train_and_infer(self, idx=0):\n        results = []\n        _readers = (None, None)\n        _w = 2\n        if idx == 1:\n            _readers = (\"itkreader\", \"itkreader\")\n            _w 
= 1\n        elif idx == 2:\n            _readers = (\"itkreader\", \"nibabelreader\")\n            _w = 0\n\n        results_expected = run_training_test(\n            self.data_dir, device=self.device, cachedataset=0, readers=_readers, num_workers=_w, lazy=False\n        )\n        results = run_training_test(\n            self.data_dir, device=self.device, cachedataset=idx, readers=_readers, num_workers=_w, lazy=True\n        )\n        self.assertFalse(np.allclose(results, [0]))\n        self.assertFalse(np.allclose(results_expected, [0]))\n        np.testing.assert_allclose(results, results_expected)\n        lazy_files = glob(os.path.join(self.data_dir, \"output\", \"*_True_*.nii.gz\"))\n        regular_files = glob(os.path.join(self.data_dir, \"output\", \"*_False_*.nii.gz\"))\n        diffs = []\n        for a, b in zip(sorted(lazy_files), sorted(regular_files)):\n            img_lazy = mt.LoadImage(image_only=True)(a)\n            img_regular = mt.LoadImage(image_only=True)(b)\n            diff = np.size(img_lazy) - np.sum(np.isclose(img_lazy, img_regular, atol=1e-4))\n            diff_rate = diff / np.size(img_lazy)\n            diffs.append(diff_rate)\n            np.testing.assert_allclose(diff_rate, 0.0, atol=0.03)\n        print(\"volume diff:\", diffs)\n\n    def test_training(self):\n        for i in range(4):\n            self.train_and_infer(i)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_nnunet_bundle.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\n\nfrom monai.apps.nnunet import nnUNetV2Runner\nfrom monai.apps.nnunet.nnunet_bundle import (\n    convert_nnunet_to_monai_bundle,\n    get_nnunet_monai_predictor,\n    get_nnunet_trainer,\n)\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.data import DataLoader, Dataset, create_test_image_3d\nfrom monai.transforms import Compose, Decollated, EnsureChannelFirstd, LoadImaged, SaveImaged\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick\n\n_, has_tb = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n_, has_nnunet = optional_import(\"nnunetv2\")\n\nsim_datalist: dict[str, list[dict]] = {\n    \"testing\": [{\"image\": \"val_001.fake.nii.gz\"}, {\"image\": \"val_002.fake.nii.gz\"}],\n    \"training\": [\n        {\"fold\": 0, \"image\": \"tr_image_001.fake.nii.gz\", \"label\": \"tr_label_001.fake.nii.gz\"},\n        {\"fold\": 0, \"image\": \"tr_image_002.fake.nii.gz\", \"label\": \"tr_label_002.fake.nii.gz\"},\n        {\"fold\": 1, \"image\": \"tr_image_003.fake.nii.gz\", \"label\": \"tr_label_003.fake.nii.gz\"},\n        {\"fold\": 1, \"image\": \"tr_image_004.fake.nii.gz\", \"label\": \"tr_label_004.fake.nii.gz\"},\n        
{\"fold\": 2, \"image\": \"tr_image_005.fake.nii.gz\", \"label\": \"tr_label_005.fake.nii.gz\"},\n        {\"fold\": 2, \"image\": \"tr_image_006.fake.nii.gz\", \"label\": \"tr_label_006.fake.nii.gz\"},\n        {\"fold\": 3, \"image\": \"tr_image_007.fake.nii.gz\", \"label\": \"tr_label_007.fake.nii.gz\"},\n        {\"fold\": 3, \"image\": \"tr_image_008.fake.nii.gz\", \"label\": \"tr_label_008.fake.nii.gz\"},\n        {\"fold\": 4, \"image\": \"tr_image_009.fake.nii.gz\", \"label\": \"tr_label_009.fake.nii.gz\"},\n        {\"fold\": 4, \"image\": \"tr_image_010.fake.nii.gz\", \"label\": \"tr_label_010.fake.nii.gz\"},\n    ],\n}\n\n\n@skip_if_quick\n@unittest.skipIf(not has_tb, \"no tensorboard summary writer\")\n@unittest.skipIf(not has_nnunet, \"no nnunetv2\")\nclass TestnnUNetBundle(unittest.TestCase):\n\n    def setUp(self) -> None:\n\n        import nibabel as nib\n\n        self.test_dir = tempfile.TemporaryDirectory()\n        test_path = self.test_dir.name\n\n        sim_dataroot = os.path.join(test_path, \"dataroot\")\n        if not os.path.isdir(sim_dataroot):\n            os.makedirs(sim_dataroot)\n\n        self.sim_dataroot = sim_dataroot\n        # Generate a fake dataset\n        for d in sim_datalist[\"testing\"] + sim_datalist[\"training\"]:\n            im, seg = create_test_image_3d(24, 24, 24, rad_max=10, num_seg_classes=2)\n            nib_image = nib.Nifti1Image(im, affine=np.eye(4))\n            image_fpath = os.path.join(sim_dataroot, d[\"image\"])\n            nib.save(nib_image, image_fpath)\n\n            if \"label\" in d:\n                nib_image = nib.Nifti1Image(seg, affine=np.eye(4))\n                label_fpath = os.path.join(sim_dataroot, d[\"label\"])\n                nib.save(nib_image, label_fpath)\n\n        sim_json_datalist = os.path.join(sim_dataroot, \"sim_input.json\")\n        ConfigParser.export_config_file(sim_datalist, sim_json_datalist)\n\n        data_src_cfg = os.path.join(test_path, \"data_src_cfg.yaml\")\n     
   data_src = {\"modality\": \"CT\", \"datalist\": sim_json_datalist, \"dataroot\": sim_dataroot}\n\n        ConfigParser.export_config_file(data_src, data_src_cfg)\n        self.data_src_cfg = data_src_cfg\n        self.test_path = test_path\n\n    @skip_if_no_cuda\n    def test_nnunet_bundle(self) -> None:\n        runner = nnUNetV2Runner(\n            input_config=self.data_src_cfg, trainer_class_name=\"nnUNetTrainer_1epoch\", work_dir=self.test_path\n        )\n        with skip_if_downloading_fails():\n            runner.run(run_train=False, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False)\n\n            nnunet_trainer = get_nnunet_trainer(\n                dataset_name_or_id=runner.dataset_name, fold=0, configuration=\"3d_fullres\"\n            )\n\n            print(\"Max Epochs: \", nnunet_trainer.num_epochs)\n            print(\"Num Iterations: \", nnunet_trainer.num_iterations_per_epoch)\n            print(\"Train Batch dims: \", next(nnunet_trainer.dataloader_train.generator)[\"data\"].shape)\n            print(\"Val Batch dims: \", next(nnunet_trainer.dataloader_val.generator)[\"data\"].shape)\n            print(\"Network: \", nnunet_trainer.network)\n            print(\"Optimizer: \", nnunet_trainer.optimizer)\n            print(\"Loss Function: \", nnunet_trainer.loss)\n            print(\"LR Scheduler: \", nnunet_trainer.lr_scheduler)\n            print(\"Device: \", nnunet_trainer.device)\n            runner.train_single_model(\"3d_fullres\", fold=0)\n\n        nnunet_config = {\"dataset_name_or_id\": \"001\", \"nnunet_trainer\": \"nnUNetTrainer_1epoch\"}\n        self.bundle_root = os.path.join(\"bundle_root\")\n\n        Path(self.bundle_root).joinpath(\"models\").mkdir(parents=True, exist_ok=True)\n        convert_nnunet_to_monai_bundle(nnunet_config, self.bundle_root, 0)\n\n        data_transforms = Compose([LoadImaged(keys=\"image\"), EnsureChannelFirstd(keys=\"image\")])\n        dataset = Dataset(\n            
data=[{\"image\": os.path.join(self.test_path, \"dataroot\", \"val_001.fake.nii.gz\")}], transform=data_transforms\n        )\n        data_loader = DataLoader(dataset, batch_size=1)\n        input = next(iter(data_loader))\n\n        predictor = get_nnunet_monai_predictor(Path(self.bundle_root).joinpath(\"models\", \"fold_0\"))\n        pred_batch = predictor(input[\"image\"])\n        Path(self.sim_dataroot).joinpath(\"predictions\").mkdir(parents=True, exist_ok=True)\n\n        post_processing_transforms = Compose(\n            [\n                Decollated(keys=None, detach=True),\n                # Not needed after reading the data directly from the MONAI LoadImaged Transform\n                # Transposed(keys=\"pred\", indices=[0, 3, 2, 1]),\n                SaveImaged(\n                    keys=\"pred\", output_dir=Path(self.sim_dataroot).joinpath(\"predictions\"), output_postfix=\"pred\"\n                ),\n            ]\n        )\n        post_processing_transforms({\"pred\": pred_batch})\n\n    def tearDown(self) -> None:\n        self.test_dir.cleanup()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_nnunetv2_runner.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\n\nfrom monai.apps.nnunet import nnUNetV2Runner\nfrom monai.bundle.config_parser import ConfigParser\nfrom monai.data import create_test_image_3d\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick\n\n_, has_tb = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n_, has_nnunet = optional_import(\"nnunetv2\")\n\nsim_datalist: dict[str, list[dict]] = {\n    \"testing\": [{\"image\": \"val_001.fake.nii.gz\"}, {\"image\": \"val_002.fake.nii.gz\"}],\n    \"training\": [\n        {\"fold\": 0, \"image\": \"tr_image_001.fake.nii.gz\", \"label\": \"tr_label_001.fake.nii.gz\"},\n        {\"fold\": 0, \"image\": \"tr_image_002.fake.nii.gz\", \"label\": \"tr_label_002.fake.nii.gz\"},\n        {\"fold\": 1, \"image\": \"tr_image_003.fake.nii.gz\", \"label\": \"tr_label_003.fake.nii.gz\"},\n        {\"fold\": 1, \"image\": \"tr_image_004.fake.nii.gz\", \"label\": \"tr_label_004.fake.nii.gz\"},\n        {\"fold\": 2, \"image\": \"tr_image_005.fake.nii.gz\", \"label\": \"tr_label_005.fake.nii.gz\"},\n        {\"fold\": 2, \"image\": \"tr_image_006.fake.nii.gz\", \"label\": \"tr_label_006.fake.nii.gz\"},\n        {\"fold\": 3, \"image\": \"tr_image_007.fake.nii.gz\", 
\"label\": \"tr_label_007.fake.nii.gz\"},\n        {\"fold\": 3, \"image\": \"tr_image_008.fake.nii.gz\", \"label\": \"tr_label_008.fake.nii.gz\"},\n        {\"fold\": 4, \"image\": \"tr_image_009.fake.nii.gz\", \"label\": \"tr_label_009.fake.nii.gz\"},\n        {\"fold\": 4, \"image\": \"tr_image_010.fake.nii.gz\", \"label\": \"tr_label_010.fake.nii.gz\"},\n    ],\n}\n\n\n@skip_if_quick\n@unittest.skipIf(not has_tb, \"no tensorboard summary writer\")\n@unittest.skipIf(not has_nnunet, \"no nnunetv2\")\nclass TestnnUNetV2Runner(unittest.TestCase):\n    def setUp(self) -> None:\n        self.test_dir = tempfile.TemporaryDirectory()\n        test_path = self.test_dir.name\n\n        sim_dataroot = os.path.join(test_path, \"dataroot\")\n        if not os.path.isdir(sim_dataroot):\n            os.makedirs(sim_dataroot)\n\n        # Generate a fake dataset\n        for d in sim_datalist[\"testing\"] + sim_datalist[\"training\"]:\n            im, seg = create_test_image_3d(24, 24, 24, rad_max=10, num_seg_classes=2)\n            nib_image = nib.Nifti1Image(im, affine=np.eye(4))\n            image_fpath = os.path.join(sim_dataroot, d[\"image\"])\n            nib.save(nib_image, image_fpath)\n\n            if \"label\" in d:\n                nib_image = nib.Nifti1Image(seg, affine=np.eye(4))\n                label_fpath = os.path.join(sim_dataroot, d[\"label\"])\n                nib.save(nib_image, label_fpath)\n\n        sim_json_datalist = os.path.join(sim_dataroot, \"sim_input.json\")\n        ConfigParser.export_config_file(sim_datalist, sim_json_datalist)\n\n        data_src_cfg = os.path.join(test_path, \"data_src_cfg.yaml\")\n        data_src = {\"modality\": \"CT\", \"datalist\": sim_json_datalist, \"dataroot\": sim_dataroot}\n\n        ConfigParser.export_config_file(data_src, data_src_cfg)\n        self.data_src_cfg = data_src_cfg\n        self.test_path = test_path\n\n    @skip_if_no_cuda\n    def test_nnunetv2runner(self) -> None:\n        runner = 
nnUNetV2Runner(input_config=self.data_src_cfg, trainer_class_name=\"nnUNetTrainer_1epoch\")\n        with skip_if_downloading_fails():\n            runner.run(run_train=False, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False)\n            runner.train(configs=\"3d_fullres\")\n            runner.find_best_configuration(configs=\"3d_fullres\")\n            runner.predict_ensemble_postprocessing()\n\n    def tearDown(self) -> None:\n        self.test_dir.cleanup()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_segmentation_3d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport tempfile\nimport unittest\nfrom glob import glob\n\nimport nibabel as nib\nimport numpy as np\nimport torch\n\nimport monai\nfrom monai.data import create_test_image_3d, decollate_batch\nfrom monai.inferers import sliding_window_inference\nfrom monai.metrics import DiceMetric\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import UNet\nfrom monai.transforms import (\n    Activations,\n    AsDiscrete,\n    Compose,\n    EnsureChannelFirstd,\n    LoadImaged,\n    RandCropByPosNegLabeld,\n    RandRotate90d,\n    SaveImage,\n    ScaleIntensityd,\n    Spacingd,\n)\nfrom monai.utils import optional_import, set_determinism\nfrom monai.visualize import plot_2d_or_3d_image\nfrom tests.test_utils import DistTestCase, TimedCall, skip_if_quick\nfrom tests.testing_data.integration_answers import test_integration_value\n\nSummaryWriter, _ = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n\nTASK = \"integration_segmentation_3d\"\n\n\ndef run_training_test(root_dir, device=\"cuda:0\", cachedataset=0, readers=(None, None)):\n    monai.config.print_config()\n    images = sorted(glob(os.path.join(root_dir, \"img*.nii.gz\")))\n    segs = sorted(glob(os.path.join(root_dir, \"seg*.nii.gz\")))\n    train_files = [{\"img\": img, \"seg\": seg} for img, seg in zip(images[:20], segs[:20])]\n 
   val_files = [{\"img\": img, \"seg\": seg} for img, seg in zip(images[-20:], segs[-20:])]\n\n    # define transforms for image and segmentation\n    train_transforms = Compose(\n        [\n            LoadImaged(keys=[\"img\", \"seg\"], reader=readers[0]),\n            EnsureChannelFirstd(keys=[\"img\", \"seg\"]),\n            # resampling with align_corners=True or dtype=float64 will generate\n            # slight different results between PyTorch 1.5 an 1.6\n            Spacingd(keys=[\"img\", \"seg\"], pixdim=[1.2, 0.8, 0.7], mode=[\"bilinear\", \"nearest\"], dtype=np.float32),\n            ScaleIntensityd(keys=\"img\"),\n            RandCropByPosNegLabeld(\n                keys=[\"img\", \"seg\"], label_key=\"seg\", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4\n            ),\n            RandRotate90d(keys=[\"img\", \"seg\"], prob=0.8, spatial_axes=[0, 2]),\n        ]\n    )\n    train_transforms.set_random_state(1234)\n    val_transforms = Compose(\n        [\n            LoadImaged(keys=[\"img\", \"seg\"], reader=readers[1]),\n            EnsureChannelFirstd(keys=[\"img\", \"seg\"]),\n            # resampling with align_corners=True or dtype=float64 will generate\n            # slight different results between PyTorch 1.5 an 1.6\n            Spacingd(keys=[\"img\", \"seg\"], pixdim=[1.2, 0.8, 0.7], mode=[\"bilinear\", \"nearest\"], dtype=np.float32),\n            ScaleIntensityd(keys=\"img\"),\n        ]\n    )\n\n    # create a training data loader\n    if cachedataset == 2:\n        train_ds = monai.data.CacheDataset(\n            data=train_files, transform=train_transforms, cache_rate=0.8, runtime_cache=\"process\"\n        )\n    elif cachedataset == 3:\n        train_ds = monai.data.LMDBDataset(data=train_files, transform=train_transforms, cache_dir=root_dir)\n    else:\n        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)\n    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 
images for network training\n    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)\n    # create a validation data loader\n    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)\n    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)\n    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])\n    dice_metric = DiceMetric(include_background=True, reduction=\"mean\", get_not_nans=False)\n\n    # create UNet, DiceLoss and Adam optimizer\n    model = monai.networks.nets.UNet(\n        spatial_dims=3,\n        in_channels=1,\n        out_channels=1,\n        channels=(16, 32, 64, 128, 256),\n        strides=(2, 2, 2, 2),\n        num_res_units=2,\n    ).to(device)\n    loss_function = monai.losses.DiceLoss(sigmoid=True)\n    optimizer = torch.optim.Adam(model.parameters(), 5e-4)\n\n    # start a typical PyTorch training\n    val_interval = 2\n    best_metric, best_metric_epoch = -1, -1\n    epoch_loss_values = []\n    metric_values = []\n    writer = SummaryWriter(log_dir=os.path.join(root_dir, \"runs\"))\n    model_filename = os.path.join(root_dir, \"best_metric_model.pth\")\n    for epoch in range(6):\n        print(\"-\" * 10)\n        print(f\"Epoch {epoch + 1}/{6}\")\n        model.train()\n        epoch_loss = 0\n        step = 0\n        for batch_data in train_loader:\n            step += 1\n            inputs, labels = batch_data[\"img\"].to(device), batch_data[\"seg\"].to(device)\n            optimizer.zero_grad()\n            outputs = model(inputs)\n            loss = loss_function(outputs, labels)\n            loss.backward()\n            optimizer.step()\n            epoch_loss += loss.item()\n            epoch_len = len(train_ds) // train_loader.batch_size\n            print(f\"{step}/{epoch_len}, train_loss:{loss.item():0.4f}\")\n            writer.add_scalar(\"train_loss\", loss.item(), epoch_len * epoch + step)\n        epoch_loss /= step\n        
epoch_loss_values.append(epoch_loss)\n        print(f\"epoch {epoch + 1} average loss:{epoch_loss:0.4f}\")\n\n        if (epoch + 1) % val_interval == 0:\n            with eval_mode(model):\n                val_images = None\n                val_labels = None\n                val_outputs = None\n                for val_data in val_loader:\n                    val_images, val_labels = (val_data[\"img\"].to(device), val_data[\"seg\"].to(device))\n                    sw_batch_size, roi_size = 4, (96, 96, 96)\n                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)\n                    # decollate prediction into a list and execute post processing for every item\n                    val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]\n                    # compute metrics\n                    dice_metric(y_pred=val_outputs, y=val_labels)\n\n                metric = dice_metric.aggregate().item()\n                dice_metric.reset()\n                metric_values.append(metric)\n                if metric > best_metric:\n                    best_metric = metric\n                    best_metric_epoch = epoch + 1\n                    torch.save(model.state_dict(), model_filename)\n                    print(\"saved new best metric model\")\n                print(\n                    f\"current epoch {epoch + 1} current mean dice: {metric:0.4f} \"\n                    f\"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}\"\n                )\n                writer.add_scalar(\"val_mean_dice\", metric, epoch + 1)\n                # plot the last model output as GIF image in TensorBoard with the corresponding image and label\n                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag=\"image\")\n                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag=\"label\")\n                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, 
tag=\"output\")\n    print(f\"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}\")\n    writer.close()\n    return epoch_loss_values, best_metric\n\n\ndef run_inference_test(root_dir, device=\"cuda:0\"):\n    images = sorted(glob(os.path.join(root_dir, \"im*.nii.gz\")))\n    segs = sorted(glob(os.path.join(root_dir, \"seg*.nii.gz\")))\n    val_files = [{\"img\": img, \"seg\": seg} for img, seg in zip(images, segs)]\n\n    saver = SaveImage(\n        output_dir=os.path.join(root_dir, \"output\"),\n        dtype=np.float32,\n        output_ext=\".nii.gz\",\n        output_postfix=\"seg\",\n        mode=\"bilinear\",\n    )\n    # define transforms for image and segmentation\n    val_transforms = Compose(\n        [\n            LoadImaged(keys=[\"img\", \"seg\"]),\n            EnsureChannelFirstd(keys=[\"img\", \"seg\"]),\n            # resampling with align_corners=True or dtype=float64 will generate\n            # slight different results between PyTorch 1.5 an 1.6\n            Spacingd(keys=[\"img\", \"seg\"], pixdim=[1.2, 0.8, 0.7], mode=[\"bilinear\", \"nearest\"], dtype=np.float32),\n            ScaleIntensityd(keys=\"img\"),\n        ]\n    )\n    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)\n    # sliding window inference need to input 1 image in every iteration\n    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)\n    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5), saver])\n    dice_metric = DiceMetric(include_background=True, reduction=\"mean\", get_not_nans=False)\n\n    model = UNet(\n        spatial_dims=3,\n        in_channels=1,\n        out_channels=1,\n        channels=(16, 32, 64, 128, 256),\n        strides=(2, 2, 2, 2),\n        num_res_units=2,\n    ).to(device)\n\n    model_filename = os.path.join(root_dir, \"best_metric_model.pth\")\n    model.load_state_dict(torch.load(model_filename, weights_only=True))\n    with eval_mode(model):\n   
     # resampling with align_corners=True or dtype=float64 will generate\n        # slight different results between PyTorch 1.5 an 1.6\n        for val_data in val_loader:\n            val_images, val_labels = (val_data[\"img\"].to(device), val_data[\"seg\"].to(device))\n            # define sliding window size and batch size for windows inference\n            sw_batch_size, roi_size = 4, (96, 96, 96)\n            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)\n            # decollate prediction into a list\n            val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)]\n            # compute metrics\n            dice_metric(y_pred=val_outputs, y=val_labels)\n\n    return dice_metric.aggregate().item()\n\n\n@skip_if_quick\nclass IntegrationSegmentation3D(DistTestCase):\n    def setUp(self):\n        set_determinism(seed=0)\n\n        self.data_dir = tempfile.mkdtemp()\n        for i in range(40):\n            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)\n            n = nib.Nifti1Image(im, np.eye(4))\n            nib.save(n, os.path.join(self.data_dir, f\"img{i:d}.nii.gz\"))\n            n = nib.Nifti1Image(seg, np.eye(4))\n            nib.save(n, os.path.join(self.data_dir, f\"seg{i:d}.nii.gz\"))\n\n        self.device = \"cuda:0\" if torch.cuda.is_available() else \"cpu:0\"\n\n    def tearDown(self):\n        set_determinism(seed=None)\n        shutil.rmtree(self.data_dir)\n\n    def train_and_infer(self, idx=0):\n        results = []\n        set_determinism(0)\n        _readers = (None, None)\n        if idx == 1:\n            _readers = (\"itkreader\", \"itkreader\")\n        elif idx == 2:\n            _readers = (\"itkreader\", \"nibabelreader\")\n        losses, best_metric = run_training_test(self.data_dir, device=self.device, cachedataset=idx, readers=_readers)\n        infer_metric = run_inference_test(self.data_dir, device=self.device)\n\n        # check 
training properties\n        print(\"losses\", losses)\n        print(\"best metric\", best_metric)\n        print(\"infer metric\", infer_metric)\n        self.assertTrue(len(glob(os.path.join(self.data_dir, \"runs\"))) > 0)\n        model_file = os.path.join(self.data_dir, \"best_metric_model.pth\")\n        self.assertTrue(os.path.exists(model_file))\n\n        # check inference properties\n        output_files = sorted(glob(os.path.join(self.data_dir, \"output\", \"img*\", \"*.nii.gz\")))\n        print([np.mean(nib.load(output).get_fdata()) for output in output_files])\n        results.extend(losses)\n        results.append(best_metric)\n        results.append(infer_metric)\n        for output in output_files:\n            ave = np.mean(nib.load(output).get_fdata())\n            results.append(ave)\n        self.assertTrue(test_integration_value(TASK, key=\"losses\", data=results[:6], rtol=1e-3))\n        self.assertTrue(test_integration_value(TASK, key=\"best_metric\", data=results[6], rtol=1e-2))\n        self.assertTrue(test_integration_value(TASK, key=\"infer_metric\", data=results[7], rtol=1e-2))\n        self.assertTrue(test_integration_value(TASK, key=\"output_sums\", data=results[8:], rtol=5e-2))\n        return results\n\n    def test_training(self):\n        repeated = []\n        for i in range(4):\n            results = self.train_and_infer(i)\n            repeated.append(results)\n        np.testing.assert_allclose(repeated[0], repeated[1])\n        np.testing.assert_allclose(repeated[0], repeated[2])\n        np.testing.assert_allclose(repeated[0], repeated[3])\n\n    @TimedCall(seconds=360, daemon=False)\n    def test_timing(self):\n        self.train_and_infer(idx=3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_sliding_window.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom ignite.engine import Engine, Events\nfrom torch.utils.data import DataLoader\n\nfrom monai.data import ImageDataset, create_test_image_3d\nfrom monai.inferers import sliding_window_inference\nfrom monai.networks import eval_mode, predict_segmentation\nfrom monai.networks.nets import UNet\nfrom monai.transforms import EnsureChannelFirst, SaveImage\nfrom monai.utils import set_determinism\nfrom tests.test_utils import DistTestCase, TimedCall, make_nifti_image, skip_if_quick\n\n\ndef run_test(batch_size, img_name, seg_name, output_dir, device=\"cuda:0\"):\n    ds = ImageDataset(\n        [img_name],\n        [seg_name],\n        transform=EnsureChannelFirst(channel_dim=\"no_channel\"),\n        seg_transform=EnsureChannelFirst(channel_dim=\"no_channel\"),\n        image_only=True,\n    )\n    loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available())\n\n    net = UNet(\n        spatial_dims=3, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2\n    ).to(device)\n    roi_size = (16, 32, 48)\n    sw_batch_size = batch_size\n\n    saver = SaveImage(output_dir=output_dir, output_ext=\".nii.gz\", output_postfix=\"seg\")\n\n    def _sliding_window_processor(_engine, batch):\n      
  img = batch[0]  # first item from ImageDataset is the input image\n        with eval_mode(net):\n            seg_probs = sliding_window_inference(img.to(device), roi_size, sw_batch_size, net, device=device)\n            return predict_segmentation(seg_probs)\n\n    def save_func(engine):\n        for m in engine.state.output:\n            saver(m)\n\n    infer_engine = Engine(_sliding_window_processor)\n    infer_engine.add_event_handler(Events.ITERATION_COMPLETED, save_func)\n    infer_engine.run(loader)\n\n    basename = os.path.basename(img_name)[: -len(\".nii.gz\")]\n    saved_name = os.path.join(output_dir, basename, f\"{basename}_seg.nii.gz\")\n    return saved_name\n\n\n@skip_if_quick\nclass TestIntegrationSlidingWindow(DistTestCase):\n    def setUp(self):\n        set_determinism(seed=0)\n\n        im, seg = create_test_image_3d(28, 25, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1)\n        self.img_name = make_nifti_image(im)\n        self.seg_name = make_nifti_image(seg)\n        self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu:0\")\n\n    def tearDown(self):\n        set_determinism(seed=None)\n        if os.path.exists(self.img_name):\n            os.remove(self.img_name)\n        if os.path.exists(self.seg_name):\n            os.remove(self.seg_name)\n\n    @TimedCall(seconds=20)\n    def test_training(self):\n        set_determinism(seed=0)\n        with tempfile.TemporaryDirectory() as tempdir:\n            output_file = run_test(\n                batch_size=2, img_name=self.img_name, seg_name=self.seg_name, output_dir=tempdir, device=self.device\n            )\n            output_image = nib.load(output_file).get_fdata()\n            np.testing.assert_allclose(np.sum(output_image), 33621)\n            np.testing.assert_allclose(output_image.shape, (28, 25, 63))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_stn.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\n\nfrom monai.data import create_test_image_2d\nfrom monai.networks.layers import AffineTransform\nfrom monai.utils import set_determinism\nfrom tests.test_utils import DistTestCase, TimedCall\n\n\nclass STNBenchmark(nn.Module):\n    \"\"\"\n    adapted from https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html\n    \"\"\"\n\n    def __init__(self, is_ref=True, reverse_indexing=False):\n        super().__init__()\n        self.is_ref = is_ref\n        self.localization = nn.Sequential(\n            nn.Conv2d(1, 8, kernel_size=7),\n            nn.MaxPool2d(2, stride=2),\n            nn.ReLU(True),\n            nn.Conv2d(8, 10, kernel_size=5),\n            nn.MaxPool2d(2, stride=2),\n            nn.ReLU(True),\n        )\n        # Regressor for the 3 * 2 affine matrix\n        self.fc_loc = nn.Sequential(nn.Linear(10 * 3 * 3, 32), nn.ReLU(True), nn.Linear(32, 3 * 2))\n        # Initialize the weights/bias with identity transformation\n        self.fc_loc[2].weight.data.zero_()\n        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))\n        if not self.is_ref:\n            self.xform = AffineTransform(align_corners=False, normalized=True, 
reverse_indexing=reverse_indexing)\n\n    # Spatial transformer network forward function\n    def stn_ref(self, x):\n        xs = self.localization(x)\n        xs = xs.view(-1, 10 * 3 * 3)\n        theta = self.fc_loc(xs)\n        theta = theta.view(-1, 2, 3)\n\n        grid = F.affine_grid(theta, x.size(), align_corners=False)\n        x = F.grid_sample(x, grid, align_corners=False)\n        return x\n\n    def stn(self, x):\n        xs = self.localization(x)\n        xs = xs.view(-1, 10 * 3 * 3)\n        theta = self.fc_loc(xs)\n        theta = theta.view(-1, 2, 3)\n        x = self.xform(x, theta, spatial_size=x.size()[2:])\n        return x\n\n    def forward(self, x):\n        if self.is_ref:\n            return self.stn_ref(x)\n        return self.stn(x)\n\n\ndef compare_2d(is_ref=True, device=None, reverse_indexing=False):\n    batch_size = 32\n    img_a = [create_test_image_2d(28, 28, 5, rad_max=6, noise_max=1)[0][None] for _ in range(batch_size)]\n    img_b = [create_test_image_2d(28, 28, 5, rad_max=6, noise_max=1)[0][None] for _ in range(batch_size)]\n    img_a = np.stack(img_a, axis=0)\n    img_b = np.stack(img_b, axis=0)\n    img_a = torch.as_tensor(img_a, device=device)\n    img_b = torch.as_tensor(img_b, device=device)\n    model = STNBenchmark(is_ref=is_ref, reverse_indexing=reverse_indexing).to(device)\n    optimizer = optim.SGD(model.parameters(), lr=0.001)\n    model.train()\n    init_loss = None\n    for _ in range(20):\n        optimizer.zero_grad()\n        output_a = model(img_a)\n        loss = torch.mean((output_a - img_b) ** 2)\n        if init_loss is None:\n            init_loss = loss.item()\n        loss.backward()\n        optimizer.step()\n    return model(img_a).detach().cpu().numpy(), loss.item(), init_loss\n\n\nclass TestSpatialTransformerCore(DistTestCase):\n    def setUp(self):\n        self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu:0\")\n\n    def tearDown(self):\n        
set_determinism(seed=None)\n\n    @TimedCall(seconds=100, skip_timing=not torch.cuda.is_available())\n    def test_training(self):\n        \"\"\"\n        check that the quality AffineTransform backpropagation\n        \"\"\"\n        atol = 1e-5\n        set_determinism(seed=0)\n        out_ref, loss_ref, init_loss_ref = compare_2d(True, self.device)\n        print(out_ref.shape, loss_ref, init_loss_ref)\n\n        set_determinism(seed=0)\n        out, loss, init_loss = compare_2d(False, self.device)\n        print(out.shape, loss, init_loss)\n        np.testing.assert_allclose(out_ref, out, atol=atol)\n        np.testing.assert_allclose(init_loss_ref, init_loss, atol=atol)\n        np.testing.assert_allclose(loss_ref, loss, atol=atol)\n\n        set_determinism(seed=0)\n        out, loss, init_loss = compare_2d(False, self.device, True)\n        print(out.shape, loss, init_loss)\n        np.testing.assert_allclose(out_ref, out, atol=atol)\n        np.testing.assert_allclose(init_loss_ref, init_loss, atol=atol)\n        np.testing.assert_allclose(loss_ref, loss, atol=atol)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_unet_2d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom ignite.engine import create_supervised_trainer\nfrom torch.utils.data import DataLoader, Dataset\n\nfrom monai.data import create_test_image_2d\nfrom monai.losses import DiceLoss\nfrom monai.networks.nets import BasicUNet, UNet\nfrom tests.test_utils import DistTestCase, TimedCall, skip_if_quick\n\n\ndef run_test(net_name=\"basicunet\", batch_size=64, train_steps=100, device=\"cuda:0\"):\n    class _TestBatch(Dataset):\n        def __getitem__(self, _unused_id):\n            im, seg = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)\n            return im[None], seg[None].astype(np.float32)\n\n        def __len__(self):\n            return train_steps\n\n    net = None\n    if net_name == \"basicunet\":\n        net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=1, features=(4, 8, 8, 16, 16, 32))\n    elif net_name == \"unet\":\n        net = UNet(\n            spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2\n        )\n    net.to(device)\n\n    loss = DiceLoss(sigmoid=True)\n    opt = torch.optim.Adam(net.parameters(), 1e-4)\n    src = DataLoader(_TestBatch(), batch_size=batch_size)\n\n    trainer = create_supervised_trainer(net, opt, loss, device, False)\n\n    trainer.run(src, 1)\n  
  loss = trainer.state.output\n    return loss\n\n\n@skip_if_quick\nclass TestIntegrationUnet2D(DistTestCase):\n    @TimedCall(seconds=20, daemon=False)\n    def test_unet_training(self):\n        for n in [\"basicunet\", \"unet\"]:\n            loss = run_test(net_name=n, device=torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu:0\"))\n            print(loss)\n            self.assertGreaterEqual(0.85, loss)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_workers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.data import DataLoader\nfrom monai.utils import set_determinism\nfrom tests.test_utils import DistTestCase, TimedCall, skip_if_no_cuda, skip_if_quick\n\n\ndef run_loading_test(num_workers=50, device=None, pw=False):\n    \"\"\"multi workers stress tests\"\"\"\n    set_determinism(seed=0)\n    if device is None:\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n    train_ds = list(range(10000))\n    train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=num_workers, persistent_workers=pw)\n    answer = []\n    for _ in range(2):\n        np.testing.assert_equal(torch.cuda.memory_allocated(), 0)\n        for batch_data in train_loader:\n            x = batch_data.to(device)\n            mem = torch.cuda.memory_allocated()\n            np.testing.assert_equal(mem > 0 and mem < 5000, True)\n        answer.append(x[-1].item())\n        del x\n    return answer\n\n\n@skip_if_quick\n@skip_if_no_cuda\nclass IntegrationLoading(DistTestCase):\n    def tearDown(self):\n        set_determinism(seed=None)\n\n    @TimedCall(seconds=5000, skip_timing=not torch.cuda.is_available(), daemon=False)\n    def test_timing(self):\n        expected = None\n        for pw in (False, True):\n            result = run_loading_test(pw=pw)\n            if 
expected is None:\n                expected = result[0]\n        np.testing.assert_allclose(result[0], expected)  # test for deterministic first epoch in two settings\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_workflows.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport tempfile\nimport unittest\nimport warnings\nfrom glob import glob\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom ignite.engine import Events\nfrom ignite.metrics import Accuracy\n\nimport monai\nfrom monai.data import create_test_image_3d\nfrom monai.engines import IterationEvents, SupervisedEvaluator, SupervisedTrainer\nfrom monai.handlers import (\n    CheckpointLoader,\n    CheckpointSaver,\n    LrScheduleHandler,\n    MeanDice,\n    StatsHandler,\n    TensorBoardImageHandler,\n    TensorBoardStatsHandler,\n    ValidationHandler,\n    from_engine,\n)\nfrom monai.inferers import SimpleInferer, SlidingWindowInferer\nfrom monai.transforms import (\n    Activationsd,\n    AsDiscreted,\n    Compose,\n    EnsureChannelFirstd,\n    KeepLargestConnectedComponentd,\n    LoadImaged,\n    RandCropByPosNegLabeld,\n    RandRotate90d,\n    SaveImage,\n    SaveImaged,\n    ScaleIntensityd,\n)\nfrom monai.utils import optional_import, set_determinism\nfrom tests.test_utils import DistTestCase, TimedCall, assert_allclose, skip_if_quick\nfrom tests.testing_data.integration_answers import test_integration_value\n\nSummaryWriter, _ = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n\nTASK = \"integration_workflows\"\n\n\ndef run_training_test(root_dir, device=\"cuda:0\", 
amp=False, num_workers=4):\n    images = sorted(glob(os.path.join(root_dir, \"img*.nii.gz\")))\n    segs = sorted(glob(os.path.join(root_dir, \"seg*.nii.gz\")))\n    train_files = [{\"image\": img, \"label\": seg} for img, seg in zip(images[:20], segs[:20])]\n    val_files = [{\"image\": img, \"label\": seg} for img, seg in zip(images[-20:], segs[-20:])]\n\n    # define transforms for image and segmentation\n    train_transforms = Compose(\n        [\n            LoadImaged(keys=[\"image\", \"label\"]),\n            EnsureChannelFirstd(keys=[\"image\", \"label\"], channel_dim=-1),\n            ScaleIntensityd(keys=[\"image\", \"label\"]),\n            RandCropByPosNegLabeld(\n                keys=[\"image\", \"label\"], label_key=\"label\", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4\n            ),\n            RandRotate90d(keys=[\"image\", \"label\"], prob=0.5, spatial_axes=[0, 2]),\n        ]\n    )\n    val_transforms = Compose(\n        [\n            LoadImaged(keys=[\"image\", \"label\"]),\n            EnsureChannelFirstd(keys=[\"image\", \"label\"], channel_dim=-1),\n            ScaleIntensityd(keys=[\"image\", \"label\"]),\n        ]\n    )\n\n    # create a training data loader\n    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)\n    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training\n    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=num_workers)\n    # create a validation data loader\n    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)\n    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers)\n\n    # create UNet, DiceLoss and Adam optimizer\n    net = monai.networks.nets.UNet(\n        spatial_dims=3,\n        in_channels=1,\n        out_channels=1,\n        channels=(16, 32, 64, 128, 256),\n        strides=(2, 2, 2, 2),\n 
       num_res_units=2,\n    ).to(device)\n    loss = monai.losses.DiceLoss(sigmoid=True)\n    opt = torch.optim.Adam(net.parameters(), 1e-3)\n    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)\n    summary_writer = SummaryWriter(log_dir=root_dir)\n\n    val_postprocessing = Compose(\n        [\n            Activationsd(keys=\"pred\", sigmoid=True),\n            AsDiscreted(keys=\"pred\", threshold=0.5),\n            KeepLargestConnectedComponentd(keys=\"pred\", applied_labels=[1]),\n        ]\n    )\n\n    class _TestEvalIterEvents:\n        def attach(self, engine):\n            engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)\n\n        def _forward_completed(self, engine):\n            pass\n\n    val_handlers = [\n        StatsHandler(iteration_log=False),\n        TensorBoardStatsHandler(summary_writer=summary_writer, iteration_log=False),\n        TensorBoardImageHandler(\n            log_dir=root_dir, batch_transform=from_engine([\"image\", \"label\"]), output_transform=from_engine(\"pred\")\n        ),\n        CheckpointSaver(save_dir=root_dir, save_dict={\"net\": net}, save_key_metric=True),\n        _TestEvalIterEvents(),\n    ]\n\n    evaluator = SupervisedEvaluator(\n        device=device,\n        val_data_loader=val_loader,\n        network=net,\n        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),\n        postprocessing=val_postprocessing,\n        key_val_metric={\n            \"val_mean_dice\": MeanDice(include_background=True, output_transform=from_engine([\"pred\", \"label\"]))\n        },\n        additional_metrics={\"val_acc\": Accuracy(output_transform=from_engine([\"pred\", \"label\"]))},\n        metric_cmp_fn=lambda cur, prev: cur >= prev,  # if greater or equal, treat as new best metric\n        val_handlers=val_handlers,\n        amp=bool(amp),\n        to_kwargs={\"memory_format\": torch.preserve_format},\n        
amp_kwargs={\"dtype\": torch.float16 if bool(amp) else torch.float32},\n    )\n\n    train_postprocessing = Compose(\n        [\n            Activationsd(keys=\"pred\", sigmoid=True),\n            AsDiscreted(keys=\"pred\", threshold=0.5),\n            KeepLargestConnectedComponentd(keys=\"pred\", applied_labels=[1]),\n        ]\n    )\n\n    class _TestTrainIterEvents:\n        def attach(self, engine):\n            engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self._forward_completed)\n            engine.add_event_handler(IterationEvents.LOSS_COMPLETED, self._loss_completed)\n            engine.add_event_handler(IterationEvents.BACKWARD_COMPLETED, self._backward_completed)\n            engine.add_event_handler(IterationEvents.MODEL_COMPLETED, self._model_completed)\n\n        def _forward_completed(self, engine):\n            pass\n\n        def _loss_completed(self, engine):\n            pass\n\n        def _backward_completed(self, engine):\n            pass\n\n        def _model_completed(self, engine):\n            pass\n\n    train_handlers = [\n        LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),\n        ValidationHandler(validator=evaluator, interval=2, epoch_level=True),\n        StatsHandler(tag_name=\"train_loss\", output_transform=from_engine(\"loss\", first=True)),\n        TensorBoardStatsHandler(\n            summary_writer=summary_writer, tag_name=\"train_loss\", output_transform=from_engine(\"loss\", first=True)\n        ),\n        CheckpointSaver(save_dir=root_dir, save_dict={\"net\": net, \"opt\": opt}, save_interval=2, epoch_level=True),\n        _TestTrainIterEvents(),\n    ]\n\n    trainer = SupervisedTrainer(\n        device=device,\n        max_epochs=5,\n        train_data_loader=train_loader,\n        network=net,\n        optimizer=opt,\n        loss_function=loss,\n        inferer=SimpleInferer(),\n        postprocessing=train_postprocessing,\n        key_train_metric={\"train_acc\": 
Accuracy(output_transform=from_engine([\"pred\", \"label\"]))},\n        train_handlers=train_handlers,\n        amp=bool(amp),\n        optim_set_to_none=True,\n        to_kwargs={\"memory_format\": torch.preserve_format},\n        amp_kwargs={\"dtype\": torch.float16 if bool(amp) else torch.float32},\n    )\n    trainer.run()\n\n    # test train and validation stats\n    train_stats = trainer.get_stats(\"output\")\n    assert_allclose(train_stats[\"output\"][0][\"loss\"], trainer.state.output[0][\"loss\"])\n    val_stats = evaluator.get_stats(\"metrics\")\n\n    return val_stats[\"best_validation_metric\"]\n\n\ndef run_inference_test(root_dir, model_file, device=\"cuda:0\", amp=False, num_workers=4):\n    images = sorted(glob(os.path.join(root_dir, \"im*.nii.gz\")))\n    segs = sorted(glob(os.path.join(root_dir, \"seg*.nii.gz\")))\n    val_files = [{\"image\": img, \"label\": seg} for img, seg in zip(images, segs)]\n\n    # define transforms for image and segmentation\n    val_transforms = Compose(\n        [\n            LoadImaged(keys=[\"image\", \"label\"]),\n            EnsureChannelFirstd(keys=[\"image\", \"label\"], channel_dim=-1),\n            ScaleIntensityd(keys=[\"image\", \"label\"]),\n        ]\n    )\n\n    # create a validation data loader\n    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)\n    val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=num_workers)\n\n    # create UNet, DiceLoss and Adam optimizer\n    net = monai.networks.nets.UNet(\n        spatial_dims=3,\n        in_channels=1,\n        out_channels=1,\n        channels=(16, 32, 64, 128, 256),\n        strides=(2, 2, 2, 2),\n        num_res_units=2,\n    ).to(device)\n\n    val_postprocessing = Compose(\n        [\n            Activationsd(keys=\"pred\", sigmoid=True),\n            AsDiscreted(keys=\"pred\", threshold=0.5),\n            KeepLargestConnectedComponentd(keys=\"pred\", applied_labels=[1]),\n            # test the case that `pred` 
in `engine.state.output`, while `image_meta_dict` in `engine.state.batch`\n            SaveImaged(keys=\"pred\", output_dir=root_dir, output_postfix=\"seg_transform\"),\n        ]\n    )\n    val_handlers = [\n        StatsHandler(iteration_log=False),\n        CheckpointLoader(load_path=f\"{model_file}\", load_dict={\"net\": net}),\n    ]\n\n    saver = SaveImage(output_dir=root_dir, output_postfix=\"seg_handler\")\n\n    def save_func(engine):\n        for o in from_engine(\"pred\")(engine.state.output):\n            saver(o)\n\n    evaluator = SupervisedEvaluator(\n        device=device,\n        val_data_loader=val_loader,\n        network=net,\n        inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),\n        postprocessing=val_postprocessing,\n        key_val_metric={\n            \"val_mean_dice\": MeanDice(include_background=True, output_transform=from_engine([\"pred\", \"label\"]))\n        },\n        additional_metrics={\"val_acc\": Accuracy(output_transform=from_engine([\"pred\", \"label\"]))},\n        val_handlers=val_handlers,\n        amp=bool(amp),\n    )\n    evaluator.add_event_handler(Events.ITERATION_COMPLETED, save_func)\n    evaluator.run()\n\n    return evaluator.state.best_metric\n\n\n@skip_if_quick\nclass IntegrationWorkflows(DistTestCase):\n    def setUp(self):\n        set_determinism(seed=0)\n\n        self.data_dir = tempfile.mkdtemp()\n        for i in range(40):\n            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)\n            n = nib.Nifti1Image(im, np.eye(4))\n            nib.save(n, os.path.join(self.data_dir, f\"img{i:d}.nii.gz\"))\n            n = nib.Nifti1Image(seg, np.eye(4))\n            nib.save(n, os.path.join(self.data_dir, f\"seg{i:d}.nii.gz\"))\n\n        self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu:0\")\n        monai.config.print_config()\n\n    def tearDown(self):\n        set_determinism(seed=None)\n        
shutil.rmtree(self.data_dir)\n\n    def train_and_infer(self, idx=0):\n        results = []\n        set_determinism(seed=0)\n        best_metric = run_training_test(self.data_dir, device=self.device, amp=(idx == 2))\n        model_file = sorted(glob(os.path.join(self.data_dir, \"net_key_metric*.pt\")))[-1]\n        infer_metric = run_inference_test(self.data_dir, model_file, device=self.device, amp=(idx == 2))\n\n        print(\"best metric\", best_metric)\n        print(\"infer metric\", infer_metric)\n        if idx == 2:\n            self.assertTrue(test_integration_value(TASK, key=\"best_metric_2\", data=best_metric, rtol=1e-2))\n        else:\n            self.assertTrue(test_integration_value(TASK, key=\"best_metric\", data=best_metric, rtol=1e-2))\n        # check inference properties\n        if idx == 2:\n            self.assertTrue(test_integration_value(TASK, key=\"infer_metric_2\", data=infer_metric, rtol=1e-2))\n        else:\n            self.assertTrue(test_integration_value(TASK, key=\"infer_metric\", data=infer_metric, rtol=1e-2))\n        results.append(best_metric)\n        results.append(infer_metric)\n\n        def _test_saved_files(postfix):\n            output_files = sorted(glob(os.path.join(self.data_dir, \"img*\", f\"*{postfix}.nii.gz\")))\n            values = []\n            for output in output_files:\n                ave = np.mean(nib.load(output).get_fdata())\n                values.append(ave)\n            if idx == 2:\n                self.assertTrue(test_integration_value(TASK, key=\"output_sums_2\", data=values, rtol=1e-2))\n            else:\n                self.assertTrue(test_integration_value(TASK, key=\"output_sums\", data=values, rtol=1e-2))\n\n        _test_saved_files(postfix=\"seg_handler\")\n        _test_saved_files(postfix=\"seg_transform\")\n        try:\n            os.remove(model_file)\n        except Exception as e:\n            warnings.warn(f\"Fail to remove {model_file}: {e}.\")\n        if 
torch.cuda.is_available():\n            try:\n                torch.cuda.empty_cache()\n            except Exception:\n                pass\n\n        return results\n\n    def test_training(self):\n        repeated = []\n        test_rounds = 3\n        for i in range(test_rounds):\n            results = self.train_and_infer(idx=i)\n            repeated.append(results)\n        np.testing.assert_allclose(repeated[0], repeated[1])\n\n    @TimedCall(seconds=300, skip_timing=not torch.cuda.is_available(), daemon=False)\n    def test_timing(self):\n        self.train_and_infer(idx=2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_workflows_adversarial.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport tempfile\nimport unittest\nfrom glob import glob\n\nimport numpy as np\nimport torch\n\nimport monai\nfrom monai.data import create_test_image_2d\nfrom monai.engines import AdversarialTrainer\nfrom monai.handlers import CheckpointSaver, StatsHandler, TensorBoardStatsHandler\nfrom monai.networks.nets import AutoEncoder, Discriminator\nfrom monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, RandFlipd, ScaleIntensityd\nfrom monai.utils import AdversarialKeys as Keys\nfrom monai.utils import CommonKeys, optional_import, set_determinism\nfrom tests.test_utils import DistTestCase, TimedCall, skip_if_quick\n\nnib, has_nibabel = optional_import(\"nibabel\")\n\n\ndef run_training_test(root_dir, device=\"cuda:0\"):\n    learning_rate = 2e-4\n    real_label = 1\n    fake_label = 0\n\n    real_images = sorted(glob(os.path.join(root_dir, \"img*.nii.gz\")))\n    train_files = [{CommonKeys.IMAGE: img, CommonKeys.LABEL: img} for img in zip(real_images)]\n\n    # prepare real data\n    train_transforms = Compose(\n        [\n            LoadImaged(keys=[CommonKeys.IMAGE, CommonKeys.LABEL]),\n            EnsureChannelFirstd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], channel_dim=2),\n            ScaleIntensityd(keys=[CommonKeys.IMAGE]),\n            RandFlipd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], 
prob=0.5),\n        ]\n    )\n    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)\n    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)\n\n    # Create Discriminator\n    discriminator_net = Discriminator(\n        in_shape=(1, 64, 64), channels=(8, 16, 32, 64, 1), strides=(2, 2, 2, 2, 1), num_res_units=1, kernel_size=5\n    ).to(device)\n    discriminator_opt = torch.optim.Adam(discriminator_net.parameters(), learning_rate)\n    discriminator_loss_criterion = torch.nn.BCELoss()\n\n    def discriminator_loss(real_logits, fake_logits):\n        real_target = real_logits.new_full((real_logits.shape[0], 1), real_label)\n        fake_target = fake_logits.new_full((fake_logits.shape[0], 1), fake_label)\n        real_loss = discriminator_loss_criterion(real_logits, real_target)\n        fake_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target)\n        return torch.div(torch.add(real_loss, fake_loss), 2)\n\n    # Create Generator\n    generator_network = AutoEncoder(\n        spatial_dims=2,\n        in_channels=1,\n        out_channels=1,\n        channels=(8, 16, 32, 64),\n        strides=(2, 2, 2, 2),\n        num_res_units=1,\n        num_inter_units=1,\n    )\n    generator_network = generator_network.to(device)\n    generator_optimiser = torch.optim.Adam(generator_network.parameters(), learning_rate)\n    generator_loss_criterion = torch.nn.MSELoss()\n\n    def reconstruction_loss(recon_images, real_images):\n        return generator_loss_criterion(recon_images, real_images)\n\n    def generator_loss(fake_logits):\n        fake_target = fake_logits.new_full((fake_logits.shape[0], 1), real_label)\n        recon_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target)\n        return recon_loss\n\n    key_train_metric = None\n\n    train_handlers = [\n        StatsHandler(\n            name=\"training_loss\",\n            output_transform=lambda 
x: {\n                Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS],\n                Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS],\n                Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS],\n            },\n        ),\n        TensorBoardStatsHandler(\n            log_dir=root_dir,\n            tag_name=\"training_loss\",\n            output_transform=lambda x: {\n                Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS],\n                Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS],\n                Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS],\n            },\n        ),\n        CheckpointSaver(\n            save_dir=root_dir,\n            save_dict={\"g_net\": generator_network, \"d_net\": discriminator_net},\n            save_interval=2,\n            epoch_level=True,\n        ),\n    ]\n\n    num_epochs = 5\n\n    trainer = AdversarialTrainer(\n        device=device,\n        max_epochs=num_epochs,\n        train_data_loader=train_loader,\n        g_network=generator_network,\n        g_optimizer=generator_optimiser,\n        g_loss_function=generator_loss,\n        recon_loss_function=reconstruction_loss,\n        d_network=discriminator_net,\n        d_optimizer=discriminator_opt,\n        d_loss_function=discriminator_loss,\n        non_blocking=True,\n        key_train_metric=key_train_metric,\n        train_handlers=train_handlers,\n    )\n    trainer.run()\n\n    return trainer.state\n\n\n@skip_if_quick\n@unittest.skipUnless(has_nibabel, \"Requires nibabel library.\")\nclass IntegrationWorkflowsAdversarialTrainer(DistTestCase):\n    def setUp(self):\n        set_determinism(seed=0)\n\n        self.data_dir = tempfile.mkdtemp()\n        for i in range(40):\n            im, _ = create_test_image_2d(64, 64, num_objs=3, rad_max=14, num_seg_classes=1, channel_dim=-1)\n            n = nib.Nifti1Image(im, np.eye(4))\n            nib.save(n, os.path.join(self.data_dir, f\"img{i:d}.nii.gz\"))\n\n        self.device = 
torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu:0\")\n        monai.config.print_config()\n\n    def tearDown(self):\n        set_determinism(seed=None)\n        shutil.rmtree(self.data_dir)\n\n    @TimedCall(seconds=300, daemon=False)\n    def test_training(self):\n        torch.manual_seed(0)\n\n        finish_state = run_training_test(self.data_dir, device=self.device)\n\n        # Assert AdversarialTrainer training finished\n        self.assertEqual(finish_state.iteration, 100)\n        self.assertEqual(finish_state.epoch, 5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_integration_workflows_gan.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport tempfile\nimport unittest\nfrom glob import glob\n\nimport nibabel as nib\nimport numpy as np\nimport torch\n\nimport monai\nfrom monai.data import create_test_image_2d\nfrom monai.engines import GanTrainer\nfrom monai.handlers import CheckpointSaver, StatsHandler, TensorBoardStatsHandler\nfrom monai.networks import normal_init\nfrom monai.networks.nets import Discriminator, Generator\nfrom monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, RandFlipd, ScaleIntensityd\nfrom monai.utils import GanKeys as Keys\nfrom monai.utils import set_determinism\nfrom tests.test_utils import DistTestCase, TimedCall, skip_if_quick\n\n\ndef run_training_test(root_dir, device=\"cuda:0\"):\n    real_images = sorted(glob(os.path.join(root_dir, \"img*.nii.gz\")))\n    train_files = [{\"reals\": img} for img in zip(real_images)]\n\n    # prepare real data\n    train_transforms = Compose(\n        [\n            LoadImaged(keys=[\"reals\"]),\n            EnsureChannelFirstd(keys=[\"reals\"], channel_dim=-1),\n            ScaleIntensityd(keys=[\"reals\"]),\n            RandFlipd(keys=[\"reals\"], prob=0.5),\n        ]\n    )\n    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)\n    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, 
num_workers=4)\n\n    learning_rate = 2e-4\n    betas = (0.5, 0.999)\n    real_label = 1\n    fake_label = 0\n\n    # create discriminator\n    disc_net = Discriminator(\n        in_shape=(1, 64, 64), channels=(8, 16, 32, 64, 1), strides=(2, 2, 2, 2, 1), num_res_units=1, kernel_size=5\n    ).to(device)\n    disc_net.apply(normal_init)\n    disc_opt = torch.optim.Adam(disc_net.parameters(), learning_rate, betas=betas)\n    disc_loss_criterion = torch.nn.BCELoss()\n\n    def discriminator_loss(gen_images, real_images):\n        real = real_images.new_full((real_images.shape[0], 1), real_label)\n        gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)\n        realloss = disc_loss_criterion(disc_net(real_images), real)\n        genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)\n        return torch.div(torch.add(realloss, genloss), 2)\n\n    # create generator\n    latent_size = 64\n    gen_net = Generator(\n        latent_shape=latent_size, start_shape=(latent_size, 8, 8), channels=[32, 16, 8, 1], strides=[2, 2, 2, 1]\n    )\n    gen_net.apply(normal_init)\n    gen_net.conv.add_module(\"activation\", torch.nn.Sigmoid())\n    gen_net = gen_net.to(device)\n    gen_opt = torch.optim.Adam(gen_net.parameters(), learning_rate, betas=betas)\n    gen_loss_criterion = torch.nn.BCELoss()\n\n    def generator_loss(gen_images):\n        output = disc_net(gen_images)\n        cats = output.new_full(output.shape, real_label)\n        return gen_loss_criterion(output, cats)\n\n    key_train_metric = None\n\n    train_handlers = [\n        StatsHandler(\n            name=\"training_loss\", output_transform=lambda x: {Keys.GLOSS: x[Keys.GLOSS], Keys.DLOSS: x[Keys.DLOSS]}\n        ),\n        TensorBoardStatsHandler(\n            log_dir=root_dir,\n            tag_name=\"training_loss\",\n            output_transform=lambda x: {Keys.GLOSS: x[Keys.GLOSS], Keys.DLOSS: x[Keys.DLOSS]},\n        ),\n        CheckpointSaver(\n            save_dir=root_dir, 
save_dict={\"g_net\": gen_net, \"d_net\": disc_net}, save_interval=2, epoch_level=True\n        ),\n    ]\n\n    disc_train_steps = 2\n    num_epochs = 5\n\n    trainer = GanTrainer(\n        device,\n        num_epochs,\n        train_loader,\n        gen_net,\n        gen_opt,\n        generator_loss,\n        disc_net,\n        disc_opt,\n        discriminator_loss,\n        d_train_steps=disc_train_steps,\n        latent_shape=latent_size,\n        key_train_metric=key_train_metric,\n        train_handlers=train_handlers,\n        to_kwargs={\"memory_format\": torch.preserve_format, \"dtype\": torch.float32},\n    )\n    trainer.run()\n\n    return trainer.state\n\n\n@skip_if_quick\nclass IntegrationWorkflowsGAN(DistTestCase):\n    def setUp(self):\n        set_determinism(seed=0)\n\n        self.data_dir = tempfile.mkdtemp()\n        for i in range(40):\n            im, _ = create_test_image_2d(64, 64, num_objs=3, rad_max=14, num_seg_classes=1, channel_dim=-1)\n            n = nib.Nifti1Image(im, np.eye(4))\n            nib.save(n, os.path.join(self.data_dir, f\"img{i:d}.nii.gz\"))\n\n        self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu:0\")\n        monai.config.print_config()\n\n    def tearDown(self):\n        set_determinism(seed=None)\n        shutil.rmtree(self.data_dir)\n\n    @TimedCall(seconds=200, daemon=False)\n    def test_training(self):\n        torch.manual_seed(0)\n\n        finish_state = run_training_test(self.data_dir, device=self.device)\n\n        # assert GAN training finished\n        self.assertEqual(finish_state.iteration, 100)\n        self.assertEqual(finish_state.epoch, 5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_loader_semaphore.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"this test should not generate errors or\nUserWarning: semaphore_tracker: There appear to be 1 leaked semaphores\"\"\"\nfrom __future__ import annotations\n\nimport multiprocessing as mp\nimport unittest\n\n\ndef w():\n    pass\n\n\ndef _main():\n    ps = mp.Process(target=w)\n    ps.start()\n    ps.join()\n\n\ndef _run_test():\n    try:\n        tmp = mp.get_context(\"spawn\")\n    except RuntimeError:\n        tmp = mp\n    p = tmp.Process(target=_main)\n    p.start()\n    p.join()\n\n\nclass TestImportLock(unittest.TestCase):\n\n    def test_start(self):\n        _run_test()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_mapping_filed.py",
    "content": "# Copyright (c) MONAI Consortium\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nfrom __future__ import annotations\r\n\r\nimport json\r\nimport os\r\nimport shutil\r\nimport tempfile\r\nimport unittest\r\n\r\nimport numpy as np\r\nimport torch\r\nfrom parameterized import parameterized\r\n\r\nfrom monai.data import DataLoader, Dataset, decollate_batch\r\nfrom monai.inferers import sliding_window_inference\r\nfrom monai.networks.nets import UNet\r\nfrom monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, SaveImaged, WriteFileMappingd\r\nfrom monai.utils import optional_import\r\n\r\nnib, has_nib = optional_import(\"nibabel\")\r\n\r\n\r\ndef create_input_file(temp_dir, name):\r\n    test_image = np.random.rand(128, 128, 128)\r\n    input_file = os.path.join(temp_dir, name + \".nii.gz\")\r\n    nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file)\r\n    return input_file\r\n\r\n\r\n# Test cases that should succeed\r\nSUCCESS_CASES = [([\"seg\"], [\"seg\"]), ([\"image\", \"seg\"], [\"seg\"])]\r\n\r\n# Test cases that should fail\r\nFAILURE_CASES = [([\"seg\"], [\"image\"]), ([\"image\"], [\"seg\"]), ([\"seg\"], [\"image\", \"seg\"])]\r\n\r\n\r\n@unittest.skipUnless(has_nib, \"nibabel required\")\r\nclass TestWriteFileMappingd(unittest.TestCase):\r\n    def setUp(self):\r\n        self.temp_dir = tempfile.mkdtemp()\r\n        self.output_dir = os.path.join(self.temp_dir, \"output\")\r\n        
os.makedirs(self.output_dir)\r\n        self.mapping_file_path = os.path.join(self.temp_dir, \"mapping.json\")\r\n\r\n    def tearDown(self):\r\n        shutil.rmtree(self.temp_dir)\r\n        if os.path.exists(self.mapping_file_path):\r\n            os.remove(self.mapping_file_path)\r\n\r\n    def run_test(self, save_keys, write_keys):\r\n        name = \"test_image\"\r\n        input_file = create_input_file(self.temp_dir, name)\r\n        output_file = os.path.join(self.output_dir, name, name + \"_seg.nii.gz\")\r\n        data = [{\"image\": input_file}]\r\n\r\n        test_transforms = Compose([LoadImaged(keys=[\"image\"]), EnsureChannelFirstd(keys=[\"image\"])])\r\n\r\n        post_transforms = Compose(\r\n            [\r\n                SaveImaged(\r\n                    keys=save_keys,\r\n                    meta_keys=\"image_meta_dict\",\r\n                    output_dir=self.output_dir,\r\n                    output_postfix=\"seg\",\r\n                    savepath_in_metadict=True,\r\n                ),\r\n                WriteFileMappingd(keys=write_keys, mapping_file_path=self.mapping_file_path),\r\n            ]\r\n        )\r\n\r\n        dataset = Dataset(data=data, transform=test_transforms)\r\n        dataloader = DataLoader(dataset, batch_size=1)\r\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n        model = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32), strides=(2,)).to(device)\r\n        model.eval()\r\n\r\n        with torch.no_grad():\r\n            for batch_data in dataloader:\r\n                test_inputs = batch_data[\"image\"].to(device)\r\n                roi_size = (64, 64, 64)\r\n                sw_batch_size = 2\r\n                batch_data[\"seg\"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model)\r\n                batch_data = [post_transforms(i) for i in decollate_batch(batch_data)]\r\n\r\n        return input_file, output_file\r\n\r\n    
@parameterized.expand(SUCCESS_CASES)\r\n    def test_successful_mapping_filed(self, save_keys, write_keys):\r\n        input_file, output_file = self.run_test(save_keys, write_keys)\r\n        self.assertTrue(os.path.exists(self.mapping_file_path))\r\n        with open(self.mapping_file_path) as f:\r\n            mapping_data = json.load(f)\r\n        self.assertEqual(len(mapping_data), len(write_keys))\r\n        for entry in mapping_data:\r\n            self.assertEqual(entry[\"input\"], input_file)\r\n            self.assertEqual(entry[\"output\"], output_file)\r\n\r\n    @parameterized.expand(FAILURE_CASES)\r\n    def test_failure_mapping_filed(self, save_keys, write_keys):\r\n        with self.assertRaises(RuntimeError) as cm:\r\n            self.run_test(save_keys, write_keys)\r\n\r\n        cause_exception = cm.exception.__cause__\r\n        self.assertIsInstance(cause_exception, KeyError)\r\n        self.assertIn(\r\n            \"Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.\",\r\n            str(cause_exception),\r\n        )\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    unittest.main()\r\n"
  },
  {
    "path": "tests/integration/test_meta_affine.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom copy import deepcopy\nfrom pathlib import Path\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.image_writer import ITKWriter\nfrom monai.transforms import (\n    Compose,\n    EnsureChannelFirst,\n    EnsureChannelFirstd,\n    LoadImage,\n    LoadImaged,\n    MapTransform,\n    Orientation,\n    Orientationd,\n    Randomizable,\n    Spacing,\n    Spacingd,\n    Transform,\n)\nfrom monai.utils import convert_data_type, optional_import\nfrom tests.test_utils import assert_allclose, download_url_or_skip_test, testing_data_config\n\nitk, has_itk = optional_import(\"itk\")\nTINY_DIFF = 1e-4\n\nkeys = (\"img1\", \"img2\")\nkey, key_1 = \"ref_avg152T1_LR\", \"ref_avg152T1_RL\"\nTESTS_PATH = Path(__file__).parents[1]\nFILE_PATH = os.path.join(TESTS_PATH, \"testing_data\", f\"{key}.nii.gz\")\nFILE_PATH_1 = os.path.join(TESTS_PATH, \"testing_data\", f\"{key_1}.nii.gz\")\n\nTEST_CASES_ARRAY = [\n    [Compose([Spacing(pixdim=(1.0, 1.1, 1.2)), Orientation(axcodes=\"RAS\")]), {}, TINY_DIFF],\n    [Compose([Orientation(axcodes=\"RAS\"), Spacing(pixdim=(1.0, 1.1, 1.2))]), {}, TINY_DIFF],\n    [\"CropForeground\", {\"k_divisible\": 3}, TINY_DIFF],\n    [\"BorderPad\", {\"spatial_border\": (2, 3, 4)}, TINY_DIFF],\n    [\"CenterScaleCrop\", {\"roi_scale\": (0.6, 0.7, 0.8)}, TINY_DIFF],\n 
   [\"CenterSpatialCrop\", {\"roi_size\": (30, 200, 52)}, TINY_DIFF],\n    [\"DivisiblePad\", {\"k\": 16}, TINY_DIFF],\n    [\"RandScaleCrop\", {\"roi_scale\": (0.3, 0.4, 0.5)}, TINY_DIFF],\n    [\"RandSpatialCrop\", {\"roi_size\": (31, 32, 33)}, TINY_DIFF],\n    [\"ResizeWithPadOrCrop\", {\"spatial_size\": (50, 80, 200)}, TINY_DIFF],\n    [\"Spacing\", {\"pixdim\": (1.0, 1.1, 1.2)}, TINY_DIFF],\n    [\"Orientation\", {\"axcodes\": \"RAS\"}, TINY_DIFF],\n    [\"Flip\", {\"spatial_axis\": (0, 1)}, TINY_DIFF],\n    [\"Resize\", {\"spatial_size\": (100, 201, 1)}, 30.0],\n    [\"Rotate\", {\"angle\": (np.pi / 3, np.pi / 2, np.pi / 4), \"mode\": \"bilinear\"}, 20.0],\n    [\"Zoom\", {\"zoom\": (0.8, 0.91, 1.2)}, 20.0],\n    [\"Rotate90\", {\"k\": 3}, TINY_DIFF],\n    [\"RandRotate90\", {\"prob\": 1.0, \"max_k\": 3}, TINY_DIFF],\n    [\"RandRotate\", {\"prob\": 1.0, \"range_x\": np.pi / 3}, 20.0],\n    [\"RandFlip\", {\"prob\": 1.0}, TINY_DIFF],\n    [\"RandAxisFlip\", {\"prob\": 1.0}, TINY_DIFF],\n    [\"RandZoom\", {\"prob\": 1.0, \"mode\": \"trilinear\"}, TINY_DIFF],\n    [\n        \"RandAffine\",\n        {\n            \"prob\": 1.0,\n            \"rotate_range\": (np.pi / 4, np.pi / 3, np.pi / 2),\n            \"translate_range\": (3, 4, 5),\n            \"scale_range\": (-0.1, 0.2),\n            \"spatial_size\": (30, 40, 50),\n            \"cache_grid\": True,\n            \"mode\": \"bilinear\",\n        },\n        20.0,\n    ],\n    [\n        \"Affine\",\n        {\n            \"rotate_params\": (np.pi / 4, 0.0, 0.0),\n            \"translate_params\": (3, 4, 5),\n            \"spatial_size\": (30, 40, 50),\n            \"mode\": \"bilinear\",\n            \"image_only\": True,\n        },\n        20.0,\n    ],\n]\n\nTEST_CASES_DICT = [\n    [Compose([Spacingd(keys, pixdim=(1.0, 1.1, 1.2)), Orientationd(keys, axcodes=\"LAS\")]), {}, TINY_DIFF],\n    [Compose([Orientationd(keys, axcodes=\"LAS\"), Spacingd(keys, pixdim=(1.0, 1.1, 1.2))]), {}, TINY_DIFF],\n   
 [\"CropForegroundd\", {\"k_divisible\": 3, \"source_key\": \"img1\"}, TINY_DIFF],\n]\nfor c in TEST_CASES_ARRAY[3:-1]:  # exclude CropForegroundd and Affined\n    TEST_CASES_DICT.append(deepcopy(c))\n    TEST_CASES_DICT[-1][0] = TEST_CASES_DICT[-1][0] + \"d\"  # type: ignore\n\n\ndef _create_itk_obj(array, affine):\n    itk_img = deepcopy(array)\n    itk_img = convert_data_type(itk_img, np.ndarray)[0]\n    itk_obj = ITKWriter.create_backend_obj(itk_img, channel_dim=None, affine=affine, affine_lps=True)\n    return itk_obj\n\n\ndef _resample_to_affine(itk_obj, ref_obj):\n    \"\"\"linear resample\"\"\"\n    dim = itk_obj.GetImageDimension()\n    transform = itk.IdentityTransform[itk.D, dim].New()\n    interpolator = itk.LinearInterpolateImageFunction[type(itk_obj), itk.D].New()\n    resampled = itk.resample_image_filter(\n        Input=itk_obj, interpolator=interpolator, transform=transform, UseReferenceImage=True, ReferenceImage=ref_obj\n    )\n    return resampled\n\n\n@unittest.skipUnless(has_itk, \"Requires itk package.\")\nclass TestAffineConsistencyITK(unittest.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        super().setUpClass()\n        for k, n in ((key, FILE_PATH), (key_1, FILE_PATH_1)):\n            config = testing_data_config(\"images\", f\"{k}\")\n            download_url_or_skip_test(filepath=n, **config)\n\n    def run_transform(self, img, xform_cls, args_dict):\n        if isinstance(xform_cls, Transform):\n            xform = xform_cls\n            output = xform(img, **args_dict)\n        else:\n            if isinstance(xform_cls, str):\n                xform_cls, _ = optional_import(\"monai.transforms\", name=xform_cls)\n            if issubclass(xform_cls, MapTransform):\n                args_dict.update({\"keys\": keys})\n            xform = xform_cls(**args_dict)\n            if isinstance(xform, Randomizable):\n                xform.set_random_state(5)\n            output = xform(img)\n        return output\n\n    
@parameterized.expand(TEST_CASES_ARRAY)\n    def test_linear_consistent(self, xform_cls, input_dict, atol):\n        \"\"\"xform cls testing itk consistency\"\"\"\n        img = LoadImage(image_only=True, simple_keys=True)(FILE_PATH)\n        img = EnsureChannelFirst()(img)\n        ref_1 = _create_itk_obj(img[0], img.affine)\n        output = self.run_transform(img, xform_cls, input_dict)\n        ref_2 = _create_itk_obj(output[0], output.affine)\n        assert_allclose(output.pixdim, np.asarray(ref_2.GetSpacing()), type_test=False)\n        expected = _resample_to_affine(ref_1, ref_2)\n        # compare ref_2 and expected results from itk\n        diff = np.abs(itk.GetArrayFromImage(ref_2) - itk.GetArrayFromImage(expected))\n        avg_diff = np.mean(diff)\n\n        self.assertLess(avg_diff, atol, f\"{xform_cls} avg_diff: {avg_diff}, tol: {atol}\")\n\n    @parameterized.expand(TEST_CASES_DICT)\n    def test_linear_consistent_dict(self, xform_cls, input_dict, atol):\n        \"\"\"xform cls testing itk consistency\"\"\"\n        img = LoadImaged(keys, image_only=True, simple_keys=True)({keys[0]: FILE_PATH, keys[1]: FILE_PATH_1})\n        img = EnsureChannelFirstd(keys)(img)\n        ref_1 = {k: _create_itk_obj(img[k][0], img[k].affine) for k in keys}\n        output = self.run_transform(img, xform_cls, input_dict)\n        ref_2 = {k: _create_itk_obj(output[k][0], output[k].affine) for k in keys}\n        expected = {k: _resample_to_affine(ref_1[k], ref_2[k]) for k in keys}\n        # compare ref_2 and expected results from itk\n        diff = {k: np.abs(itk.GetArrayFromImage(ref_2[k]) - itk.GetArrayFromImage(expected[k])) for k in keys}\n        avg_diff = {k: np.mean(diff[k]) for k in keys}\n        for k in keys:\n            self.assertLess(avg_diff[k], atol, f\"{xform_cls} avg_diff: {avg_diff}, tol: {atol}\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_metatensor_integration.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom copy import deepcopy\nfrom pathlib import Path\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai import config as monai_config\nfrom monai.bundle import ConfigParser\nfrom monai.data import CacheDataset, DataLoader, MetaTensor, decollate_batch\nfrom monai.data.utils import TraceKeys\nfrom monai.transforms import InvertD, SaveImageD, reset_ops_id\nfrom monai.utils import optional_import, set_determinism\nfrom tests.test_utils import assert_allclose, download_url_or_skip_test, testing_data_config\n\nnib, has_nib = optional_import(\"nibabel\")\nTINY_DIFF = 0.1\n\nkeys = (\"img\", \"seg\")\nkey, key_1 = \"MNI152_T1_2mm\", \"MNI152_T1_2mm_strucseg\"\nTESTS_PATH = Path(__file__).parents[1]\nFILE_PATH = os.path.join(TESTS_PATH, \"testing_data\", f\"{key}.nii.gz\")\nFILE_PATH_1 = os.path.join(TESTS_PATH, \"testing_data\", f\"{key_1}.nii.gz\")\nTEST_CASES = os.path.join(TESTS_PATH, \"testing_data\", \"transform_metatensor_cases.yaml\")\n\n\n@unittest.skipUnless(has_nib, \"Requires nibabel package.\")\nclass TestMetaTensorIntegration(unittest.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        super().setUpClass()\n        for k, n in ((key, FILE_PATH), (key_1, FILE_PATH_1)):\n            config = testing_data_config(\"images\", f\"{k}\")\n            
download_url_or_skip_test(filepath=n, **config)\n        cls.files = [{keys[0]: x, keys[1]: y} for (x, y) in [[FILE_PATH, FILE_PATH_1]] * 4]\n\n    @classmethod\n    def tearDownClass(cls):\n        super().tearDownClass()\n        set_determinism(None)\n\n    @parameterized.expand([\"TEST_CASE_1\", \"TEST_CASE_2\", \"TEST_CASE_3\"])\n    def test_transforms(self, case_id):\n        set_determinism(2022)\n        config = ConfigParser()\n        config.read_config(TEST_CASES)\n        config[\"input_keys\"] = keys\n        test_case = config.get_parsed_content(id=case_id, instantiate=True, lazy=False)  # transform instance\n\n        dataset = CacheDataset(self.files, transform=test_case)\n        loader = DataLoader(dataset, batch_size=3, shuffle=True)\n        for x in loader:\n            self.assertIsInstance(x[keys[0]], MetaTensor)\n            self.assertIsInstance(x[keys[1]], MetaTensor)\n            out = decollate_batch(x)  # decollate every batch should work\n\n        # test forward patches\n        loaded = out[0]\n        if not monai_config.USE_META_DICT:\n            self.assertEqual(len(loaded), len(keys))\n        else:\n            self.assertNotEqual(len(loaded), len(keys))\n        img, seg = loaded[keys[0]], loaded[keys[1]]\n        expected = config.get_parsed_content(id=f\"{case_id}_answer\", instantiate=True)  # expected results\n        self.assertEqual(expected[\"load_shape\"], list(x[keys[0]].shape))\n        assert_allclose(expected[\"affine\"], img.affine, type_test=False, atol=TINY_DIFF, rtol=TINY_DIFF)\n        assert_allclose(expected[\"affine\"], seg.affine, type_test=False, atol=TINY_DIFF, rtol=TINY_DIFF)\n        test_cls = [type(x).__name__ for x in test_case.transforms]\n        tracked_cls = [x[TraceKeys.CLASS_NAME] for x in img.applied_operations]\n        self.assertTrue(len(tracked_cls) <= len(test_cls))  # tracked items should  be no more than the compose items.\n        with tempfile.TemporaryDirectory() as tempdir:  # 
test writer\n            SaveImageD(keys, resample=False, output_dir=tempdir, output_postfix=case_id)(loaded)\n        test_data = reset_ops_id(deepcopy(loaded))\n        for val in test_data.values():\n            if isinstance(val, MetaTensor) and val.applied_operations:\n                self.assertEqual(val.applied_operations[-1][TraceKeys.ID], TraceKeys.NONE)\n\n        # test inverse\n        inv = InvertD(keys, orig_keys=keys, transform=test_case, nearest_interp=True)\n        out = inv(loaded)\n        img, seg = out[keys[0]], out[keys[1]]\n        assert_allclose(expected[\"inv_affine\"], img.affine, type_test=False, atol=TINY_DIFF, rtol=TINY_DIFF)\n        assert_allclose(expected[\"inv_affine\"], seg.affine, type_test=False, atol=TINY_DIFF, rtol=TINY_DIFF)\n        self.assertFalse(img.applied_operations)\n        self.assertFalse(seg.applied_operations)\n        assert_allclose(expected[\"inv_shape\"], img.shape, type_test=False, atol=TINY_DIFF, rtol=TINY_DIFF)\n        assert_allclose(expected[\"inv_shape\"], seg.shape, type_test=False, atol=TINY_DIFF, rtol=TINY_DIFF)\n        with tempfile.TemporaryDirectory() as tempdir:  # test writer\n            SaveImageD(keys, resample=False, output_dir=tempdir, output_postfix=case_id)(out)\n            seg_file = os.path.join(tempdir, key_1, f\"{key_1}_{case_id}.nii.gz\")\n            segout = nib.load(seg_file).get_fdata()\n            segin = nib.load(FILE_PATH_1).get_fdata()\n            ndiff = np.sum(np.abs(segout - segin) > 0)\n            total = np.prod(segout.shape)\n        self.assertTrue(ndiff / total < 0.4, f\"{ndiff / total}\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_module_list.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport glob\nimport inspect\nimport os\nimport pathlib\nimport unittest\n\nimport monai\n\n\nclass TestAllImport(unittest.TestCase):\n\n    def test_public_api(self):\n        \"\"\"\n        This is to check \"monai.__all__\" should be consistent with\n        the top-level folders except for \"__pycache__\", \"_extensions\" and \"csrc\" (cpp/cuda src)\n        \"\"\"\n        base_folder = os.path.dirname(monai.__file__)\n        to_search = os.path.join(base_folder, \"*\", \"\")\n        subfolders = [os.path.basename(x[:-1]) for x in glob.glob(to_search)]\n        to_exclude = (\"__pycache__\", \"_extensions\", \"csrc\")\n        mod = []\n        for code_folder in subfolders:\n            if code_folder in to_exclude:\n                continue\n            mod.append(code_folder)\n        self.assertEqual(sorted(monai.__all__), sorted(mod))\n\n    def test_transform_api(self):\n        \"\"\"monai subclasses of MapTransforms must have alias names ending with 'd', 'D', 'Dict'\"\"\"\n        to_exclude = {\"MapTransform\"}  # except for these transforms\n        to_exclude_docs = {\"Decollate\", \"Ensemble\", \"Invert\", \"SaveClassification\", \"RandTorchVision\", \"RandCrop\"}\n        to_exclude_docs.update({\"DeleteItems\", \"SelectItems\", \"FlattenSubKeys\", \"CopyItems\", \"ConcatItems\"})\n        
to_exclude_docs.update({\"ToMetaTensor\", \"FromMetaTensor\"})\n        xforms = {\n            name: obj\n            for name, obj in monai.transforms.__dict__.items()\n            if inspect.isclass(obj) and issubclass(obj, monai.transforms.MapTransform)\n        }\n        names = sorted(x for x in xforms if x not in to_exclude)\n        remained = set(names)\n        doc_file = os.path.join(pathlib.Path(__file__).parent.parent, \"docs\", \"source\", \"transforms.rst\")\n        contents = pathlib.Path(doc_file).read_text() if os.path.exists(doc_file) else None\n        for n in names:\n            if not n.endswith(\"d\"):\n                continue\n            with self.subTest(n=n):\n                basename = n[:-1]  # Transformd basename is Transform\n\n                # remove aliases to check, do this before the assert below so that a failed assert does not skip this\n                for postfix in (\"D\", \"d\", \"Dict\"):\n                    remained.remove(f\"{basename}{postfix}\")\n\n                for docname in (f\"{basename}\", f\"{basename}d\"):\n                    if docname in to_exclude_docs:\n                        continue\n                    if (contents is not None) and f\"`{docname}`\" not in f\"{contents}\":\n                        self.assertTrue(False, f\"please add `{docname}` to docs/source/transforms.rst\")\n\n        self.assertFalse(remained)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_one_of.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nimport monai.transforms.intensity.array as ia\nimport monai.transforms.spatial.array as sa\nimport monai.transforms.spatial.dictionary as sd\nfrom monai.data import MetaTensor\nfrom monai.transforms import (\n    InvertibleTransform,\n    OneOf,\n    RandScaleIntensity,\n    RandScaleIntensityd,\n    RandShiftIntensity,\n    RandShiftIntensityd,\n    Resize,\n    Resized,\n    Transform,\n)\nfrom monai.transforms.compose import Compose\nfrom monai.transforms.transform import MapTransform\nfrom monai.utils.enums import TraceKeys\n\n\nclass X(Transform):\n\n    def __call__(self, x):\n        return x\n\n\nclass Y(Transform):\n\n    def __call__(self, x):\n        return x\n\n\nclass A(Transform):\n\n    def __call__(self, x):\n        return x + 1\n\n\nclass B(Transform):\n\n    def __call__(self, x):\n        return x + 2\n\n\nclass C(Transform):\n\n    def __call__(self, x):\n        return x + 3\n\n\nclass MapBase(MapTransform):\n\n    def __init__(self, keys):\n        super().__init__(keys)\n        self.fwd_fn, self.inv_fn = None, None\n\n    def __call__(self, data):\n        d = deepcopy(dict(data))\n        for key in self.key_iterator(d):\n            d[key] = self.fwd_fn(d[key])\n        return 
d\n\n\nclass NonInv(MapBase):\n\n    def __init__(self, keys):\n        super().__init__(keys)\n        self.fwd_fn = lambda x: x * 2\n\n\nclass Inv(MapBase, InvertibleTransform):\n\n    def __call__(self, data):\n        d = deepcopy(dict(data))\n        for key in self.key_iterator(d):\n            d[key] = self.fwd_fn(d[key])\n            self.push_transform(d, key)\n        return d\n\n    def inverse(self, data):\n        d = deepcopy(dict(data))\n        for key in self.key_iterator(d):\n            d[key] = self.inv_fn(d[key])\n            self.pop_transform(d, key)\n        return d\n\n\nclass InvA(Inv):\n\n    def __init__(self, keys):\n        super().__init__(keys)\n        self.fwd_fn = lambda x: x + 1\n        self.inv_fn = lambda x: x - 1\n\n\nclass InvB(Inv):\n\n    def __init__(self, keys):\n        super().__init__(keys)\n        self.fwd_fn = lambda x: x + 100\n        self.inv_fn = lambda x: x - 100\n\n\nTESTS = [((X(), Y(), X()), (1, 2, 1), (0.25, 0.5, 0.25))]\n\nKEYS = [\"x\", \"y\"]\nTEST_INVERSES = [\n    (OneOf((InvA(KEYS), InvB(KEYS))), True, True),\n    (OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True, False),\n    (OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True, False),\n    (OneOf((NonInv(KEYS), NonInv(KEYS))), False, False),\n]\n\n\nclass TestOneOf(unittest.TestCase):\n\n    @parameterized.expand(TESTS)\n    def test_normalize_weights(self, transforms, input_weights, expected_weights):\n        tr = OneOf(transforms, input_weights)\n        self.assertTupleEqual(tr.weights, expected_weights)\n\n    def test_no_weights_arg(self):\n        p = OneOf((X(), Y(), X(), Y()))\n        expected_weights = (0.25,) * 4\n        self.assertTupleEqual(p.weights, expected_weights)\n\n    def test_len_and_flatten(self):\n        p1 = OneOf((X(), Y()), (1, 3))  # 0.25, 0.75\n        p2 = OneOf((Y(), Y()), (2, 2))  # 0.5. 
0.5\n        p = OneOf((p1, p2, X()), (1, 2, 1))  # 0.25, 0.5, 0.25\n        expected_order = (X, Y, Y, Y, X)\n        expected_weights = (0.25 * 0.25, 0.25 * 0.75, 0.5 * 0.5, 0.5 * 0.5, 0.25)\n        self.assertEqual(len(p), len(expected_order))\n        self.assertTupleEqual(p.flatten().weights, expected_weights)\n\n    def test_compose_flatten_does_not_affect_one_of(self):\n        p = Compose([A(), B(), OneOf([C(), Inv(KEYS), Compose([X(), Y()])])])\n        f = p.flatten()\n\n        # in this case the flattened transform should be the same.\n\n        def _match(a, b):\n            self.assertEqual(type(a), type(b))\n            for a_, b_ in zip(a.transforms, b.transforms):\n                self.assertEqual(type(a_), type(b_))\n                if isinstance(a_, (Compose, OneOf)):\n                    _match(a_, b_)\n\n        _match(p, f)\n\n    @parameterized.expand(TEST_INVERSES)\n    def test_inverse(self, transform, invertible, use_metatensor):\n        data = {k: MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)}\n        fwd_data = transform(data)\n\n        if invertible:\n            for k in KEYS:\n                t = fwd_data[k].applied_operations[-1]\n                # make sure the OneOf index was stored\n                self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__)\n                # make sure index exists and is in bounds\n                self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO][\"index\"] < len(transform))\n\n        # call the inverse\n        fwd_inv_data = transform.inverse(fwd_data)\n\n        if invertible:\n            for k in KEYS:\n                # check data is same as original (and different from forward)\n                self.assertEqual(fwd_inv_data[k], data[k])\n                self.assertNotEqual(fwd_inv_data[k], fwd_data[k])\n        else:\n            # if not invertible, should not change the data\n            self.assertDictEqual(fwd_data, fwd_inv_data)\n\n    def test_inverse_compose(self):\n        
transform = Compose(\n            [\n                Resized(keys=\"img\", spatial_size=[100, 100, 100]),\n                OneOf(\n                    [\n                        RandScaleIntensityd(keys=\"img\", factors=0.5, prob=1.0),\n                        RandShiftIntensityd(keys=\"img\", offsets=0.5, prob=1.0),\n                    ]\n                ),\n                OneOf(\n                    [\n                        RandScaleIntensityd(keys=\"img\", factors=0.5, prob=1.0),\n                        RandShiftIntensityd(keys=\"img\", offsets=0.5, prob=1.0),\n                    ]\n                ),\n            ]\n        )\n        transform.set_random_state(seed=0)\n        result = transform({\"img\": np.ones((1, 101, 102, 103))})\n        result = transform.inverse(result)\n        # invert to the original spatial shape\n        self.assertTupleEqual(result[\"img\"].shape, (1, 101, 102, 103))\n\n    def test_inverse_metatensor(self):\n        transform = Compose(\n            [\n                Resize(spatial_size=[100, 100, 100]),\n                OneOf([RandScaleIntensity(factors=0.5, prob=1.0), RandShiftIntensity(offsets=0.5, prob=1.0)]),\n                OneOf([RandScaleIntensity(factors=0.5, prob=1.0), RandShiftIntensity(offsets=0.5, prob=1.0)]),\n            ]\n        )\n        transform.set_random_state(seed=0)\n        result = transform(np.ones((1, 101, 102, 103)))\n        self.assertTupleEqual(result.shape, (1, 100, 100, 100))\n        result = transform.inverse(result)\n        self.assertTupleEqual(result.shape, (1, 101, 102, 103))\n\n    def test_one_of(self):\n        p = OneOf((A(), B(), C()), (1, 2, 1))\n        counts = [0] * 3\n        for _i in range(10000):\n            out = p(1.0)\n            counts[int(out - 2)] += 1\n        self.assertAlmostEqual(counts[0] / 10000, 0.25, delta=1.0)\n        self.assertAlmostEqual(counts[1] / 10000, 0.50, delta=1.0)\n        self.assertAlmostEqual(counts[2] / 10000, 0.25, 
delta=1.0)\n\n\nTEST_ONEOF_EXTENDED_TEST_CASES = [\n    [None, tuple()],\n    [None, (sa.Rotate(np.pi / 8),)],\n    [None, (sa.Flip(0), sa.Flip(1), sa.Rotate90(1), sa.Zoom(0.8), ia.NormalizeIntensity())],\n    [(\"a\",), (sd.Rotated((\"a\",), np.pi / 8),)],\n]\n\n\nclass TestOneOfAPITests(unittest.TestCase):\n\n    @staticmethod\n    def data_from_keys(keys):\n        if keys is None:\n            data = torch.unsqueeze(torch.tensor(np.arange(12 * 16).reshape(12, 16)), dim=0)\n        else:\n            data = {}\n            for i_k, k in enumerate(keys):\n                data[k] = torch.unsqueeze(torch.tensor(np.arange(12 * 16)).reshape(12, 16) + i_k * 192, dim=0)\n        return data\n\n    @parameterized.expand(TEST_ONEOF_EXTENDED_TEST_CASES)\n    def test_execute_change_start_end(self, keys, pipeline):\n        data = self.data_from_keys(keys)\n\n        c = OneOf(deepcopy(pipeline))\n        with self.assertRaises(ValueError):\n            c(data, start=1)\n        with self.assertRaises(ValueError):\n            c(data, start=1)\n\n        c = OneOf(deepcopy(pipeline))\n        with self.assertRaises(ValueError):\n            c(data, end=1)\n        with self.assertRaises(ValueError):\n            c(data, end=1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_pad_collation.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport random\nimport unittest\nfrom contextlib import redirect_stderr\nfrom functools import wraps\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import CacheDataset, DataLoader\nfrom monai.data.utils import decollate_batch, pad_list_data_collate\nfrom monai.transforms import (\n    BatchInverseTransform,\n    Compose,\n    PadListDataCollate,\n    RandRotate,\n    RandRotate90,\n    RandRotate90d,\n    RandRotated,\n    RandSpatialCrop,\n    RandSpatialCropd,\n    RandZoom,\n    RandZoomd,\n    ToTensor,\n)\nfrom monai.utils import first, set_determinism\n\n\n@wraps(pad_list_data_collate)\ndef _testing_collate(x):\n    return pad_list_data_collate(batch=x, method=\"end\", mode=\"constant\")\n\n\nTESTS: list[tuple] = []\n\nfor pad_collate in [_testing_collate, PadListDataCollate(method=\"end\", mode=\"constant\")]:\n    TESTS.append((dict, pad_collate, RandSpatialCropd(\"image\", roi_size=[8, 7], random_size=True)))\n    TESTS.append((dict, pad_collate, RandRotated(\"image\", prob=1, range_x=np.pi, keep_size=False, dtype=np.float64)))\n    TESTS.append((dict, pad_collate, RandZoomd(\"image\", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))\n    TESTS.append(\n        (dict, pad_collate, Compose([RandRotate90d(\"image\", prob=1, max_k=3), RandRotate90d(\"image\", 
prob=1, max_k=4)]))\n    )\n\n    TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True)))\n    TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False, dtype=np.float64)))\n    TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))\n    TESTS.append((list, pad_collate, Compose([RandRotate90(prob=1, max_k=2), ToTensor()])))\n\n\nclass _Dataset(torch.utils.data.Dataset):\n\n    def __init__(self, images, labels, transforms):\n        self.images = images\n        self.labels = labels\n        self.transform = transforms\n\n    def __len__(self):\n        return len(self.images)\n\n    def __getitem__(self, index):\n        return self.transform(self.images[index]), self.labels[index]\n\n\nclass TestPadCollation(unittest.TestCase):\n\n    def setUp(self) -> None:\n        set_determinism(seed=0)\n        # image is non square to throw rotation errors\n        im = np.arange(0, 10 * 9).reshape(1, 10, 9)\n        num_elements = 20\n        self.dict_data = [{\"image\": im} for _ in range(num_elements)]\n        self.list_data = [im for _ in range(num_elements)]\n        self.list_labels = [random.randint(0, 1) for _ in range(num_elements)]\n\n    def tearDown(self) -> None:\n        set_determinism(None)\n\n    @parameterized.expand(TESTS)\n    def test_pad_collation(self, t_type, collate_method, transform):\n        if t_type is dict:\n            dataset = CacheDataset(self.dict_data, transform, progress=False)\n        else:\n            dataset = _Dataset(self.list_data, self.list_labels, transform)\n\n        # Default collation should raise an error\n        loader_fail = DataLoader(dataset, batch_size=10)\n        with self.assertRaises(RuntimeError):\n            # stifle PyTorch error reporting, we expect failure so don't need to look at it\n            with open(os.devnull) as f, redirect_stderr(f):\n                _ = first(loader_fail)\n\n        # Padded 
collation shouldn't\n        loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method)\n        # check collation in forward direction\n        for data in loader:\n            if t_type is dict:\n                shapes = []\n                decollated_data = decollate_batch(data)\n                for d in decollated_data:\n                    output = PadListDataCollate.inverse(d)\n                    shapes.append(output[\"image\"].shape)\n                    self.assertTrue(len(output[\"image\"].applied_operations), len(dataset.transform.transforms))\n                self.assertTrue(len(set(shapes)) > 1)  # inverted shapes must be different because of random xforms\n\n        if t_type is dict:\n            batch_inverse = BatchInverseTransform(dataset.transform, loader)\n            for data in loader:\n                output = batch_inverse(data)\n                self.assertEqual(output[0][\"image\"].shape, (1, 10, 9))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_reg_loss_integration.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom parameterized import parameterized\n\nfrom monai.losses import BendingEnergyLoss, GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss\nfrom monai.utils import set_determinism\n\nTEST_CASES = [\n    [BendingEnergyLoss, {}, [\"pred\"], 3],\n    [LocalNormalizedCrossCorrelationLoss, {\"kernel_size\": 7, \"kernel_type\": \"rectangular\"}, [\"pred\", \"target\"]],\n    [LocalNormalizedCrossCorrelationLoss, {\"kernel_size\": 5, \"kernel_type\": \"triangular\"}, [\"pred\", \"target\"]],\n    [LocalNormalizedCrossCorrelationLoss, {\"kernel_size\": 3, \"kernel_type\": \"gaussian\"}, [\"pred\", \"target\"]],\n    [GlobalMutualInformationLoss, {\"num_bins\": 10}, [\"pred\", \"target\"]],\n    [GlobalMutualInformationLoss, {\"kernel_type\": \"b-spline\", \"num_bins\": 10}, [\"pred\", \"target\"]],\n]\n\n\nclass TestRegLossIntegration(unittest.TestCase):\n    def setUp(self):\n        set_determinism(0)\n        self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu:0\")\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @parameterized.expand(TEST_CASES)\n    def test_convergence(self, loss_type, loss_args, forward_args, pred_channels=1):\n        \"\"\"\n        The goal of this test is to assess if the 
gradient of the loss function\n        is correct by testing if we can train a one layer neural network\n        to segment one image.\n        We verify that the loss is decreasing in almost all SGD steps.\n        \"\"\"\n        learning_rate = 0.001\n        max_iter = 100\n\n        # define a simple 3d example\n        target = torch.rand((1, 1, 5, 5, 5), device=self.device)\n        image = 12 * target + 27\n        image = image.to(device=self.device)\n\n        # define a one layer model\n        class OnelayerNet(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.layer = nn.Sequential(\n                    nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1),\n                    nn.ReLU(),\n                    nn.Conv3d(in_channels=1, out_channels=pred_channels, kernel_size=3, padding=1),\n                )\n\n            def forward(self, x):\n                return self.layer(x)\n\n        # initialise the network\n        net = OnelayerNet().to(self.device)\n\n        # initialize the loss\n        loss = loss_type(**loss_args).to(self.device)\n\n        # initialize a SGD optimizer\n        optimizer = optim.Adam(net.parameters(), lr=learning_rate)\n\n        # declare first for pylint\n        init_loss = None\n\n        # train the network\n        for it in range(max_iter):\n            # set the gradient to zero\n            optimizer.zero_grad()\n\n            # forward pass\n            output = net(image)\n            loss_input = {\"pred\": output, \"target\": target}\n\n            loss_val = loss(**{k: loss_input[k] for k in forward_args})\n            if it == 0:\n                init_loss = loss_val\n\n            # backward pass\n            loss_val.backward()\n            optimizer.step()\n        self.assertGreater(init_loss, loss_val, \"loss did not decrease\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_retinanet_predict_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.detection.utils.predict_utils import ensure_dict_value_to_list_, predict_with_inferer\nfrom monai.inferers import SlidingWindowInferer\n\nTEST_CASE_1 = [  # 3D, batch 3, 2 input channel\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 3,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": 7,\n        \"conv1_t_stride\": (2, 2, 2),\n    },\n    (3, 2, 32, 64, 48),\n]\n\nTEST_CASE_2 = [  # 2D, batch 2, 1 input channel\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 2,\n        \"n_input_channels\": 1,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [7, 7],\n        \"conv1_t_stride\": [2, 2],\n    },\n    (2, 1, 32, 64),\n]\n\nTEST_CASE_2_A = [  # 2D, batch 2, 1 input channel, shortcut type A\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 2,\n        \"n_input_channels\": 1,\n        \"num_classes\": 3,\n        \"shortcut_type\": \"A\",\n        \"conv1_t_size\": (7, 7),\n        \"conv1_t_stride\": 2,\n    },\n    (2, 1, 32, 64),\n]\n\nTEST_CASE_3 = [  # 1D, batch 1, 2 input channels\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 1,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [3],\n 
       \"conv1_t_stride\": 1,\n    },\n    (1, 2, 32),\n]\n\nTEST_CASE_3_A = [  # 1D, batch 1, 2 input channels\n    {\"pretrained\": False, \"spatial_dims\": 1, \"n_input_channels\": 2, \"num_classes\": 3, \"shortcut_type\": \"A\"},\n    (1, 2, 32),\n]\n\nTEST_CASE_4 = [  # 2D, batch 2, 1 input channel\n    {\"pretrained\": False, \"spatial_dims\": 2, \"n_input_channels\": 1, \"num_classes\": 3, \"feed_forward\": False},\n    (2, 1, 32, 64),\n]\n\nTEST_CASES = []\nTEST_CASES = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_2_A]\n\nTEST_CASES_TS = [TEST_CASE_1]\n\n\nclass NaiveNetwork(torch.nn.Module):\n\n    def __init__(self, spatial_dims, num_classes, **kwargs):\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.num_classes = num_classes\n        self.num_anchors = 1\n        self.cls_key = \"cls\"\n        self.box_reg_key = \"box_reg\"\n        self.size_divisible = 1\n\n    def forward(self, images):\n        out_cls_shape = (images.shape[0], self.num_classes * self.num_anchors) + images.shape[-self.spatial_dims :]\n        out_box_reg_shape = (images.shape[0], 2 * self.spatial_dims * self.num_anchors) + images.shape[\n            -self.spatial_dims :\n        ]\n        return {self.cls_key: torch.randn(out_cls_shape), self.box_reg_key: [torch.randn(out_box_reg_shape)]}\n\n\nclass NaiveNetwork2(torch.nn.Module):\n\n    def __init__(self, spatial_dims, num_classes, **kwargs):\n        super().__init__()\n        self.spatial_dims = spatial_dims\n        self.num_classes = num_classes\n        self.num_anchors = 1\n        self.cls_key = \"cls\"\n        self.box_reg_key = \"box_reg\"\n        self.size_divisible = 1\n\n    def forward(self, images):\n        out_cls_shape = (images.shape[0], self.num_classes * self.num_anchors) + images.shape[-self.spatial_dims :]\n        out_box_reg_shape = (images.shape[0], 2 * self.spatial_dims * self.num_anchors) + images.shape[\n            -self.spatial_dims :\n        ]\n        return 
{self.cls_key: [torch.randn(out_cls_shape)] * 2, self.box_reg_key: [torch.randn(out_box_reg_shape)] * 2}\n\n\nclass TestPredictor(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_naive_predictor(self, input_param, input_shape):\n        net = NaiveNetwork(**input_param)\n        net2 = NaiveNetwork2(**input_param)\n        inferer = SlidingWindowInferer(roi_size=16, overlap=0.25, cache_roi_weight_map=True)\n        network_output_keys = [\"cls\", \"box_reg\"]\n\n        input_data = torch.randn(input_shape)\n\n        result = predict_with_inferer(input_data, net, network_output_keys, inferer=inferer)\n        self.assertTrue(len(result[\"cls\"]) == 1)\n\n        result = net(input_data)\n        self.assertTrue(len(result[\"cls\"]) == input_data.shape[0])\n        ensure_dict_value_to_list_(result)\n        self.assertTrue(len(result[\"cls\"]) == 1)\n\n        result = predict_with_inferer(input_data, net2, network_output_keys, inferer=inferer)\n        self.assertTrue(len(result[\"cls\"]) == 2)\n\n        result = net2(input_data)\n        self.assertTrue(len(result[\"cls\"]) == 2)\n        ensure_dict_value_to_list_(result)\n        self.assertTrue(len(result[\"cls\"]) == 2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_seg_loss_integration.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom parameterized import parameterized\n\nfrom monai.losses import DiceLoss, FocalLoss, GeneralizedDiceLoss, TverskyLoss\nfrom monai.networks import one_hot\nfrom monai.utils import set_determinism\n\nTEST_CASES = [\n    [DiceLoss, {\"to_onehot_y\": True, \"squared_pred\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4}, {}],\n    [DiceLoss, {\"to_onehot_y\": True, \"squared_pred\": True, \"smooth_nr\": 0, \"smooth_dr\": 1e-3}, {}],\n    [DiceLoss, {\"to_onehot_y\": False, \"squared_pred\": True, \"smooth_nr\": 0, \"smooth_dr\": 1e-3}, {}],\n    [DiceLoss, {\"to_onehot_y\": True, \"squared_pred\": True, \"batch\": True}, {}],\n    [DiceLoss, {\"to_onehot_y\": True, \"sigmoid\": True}, {}],\n    [DiceLoss, {\"to_onehot_y\": True, \"softmax\": True}, {}],\n    [FocalLoss, {\"to_onehot_y\": True, \"gamma\": 1.5, \"weight\": torch.tensor([1, 2])}, {}],\n    [FocalLoss, {\"to_onehot_y\": False, \"gamma\": 1.5, \"weight\": [1, 2]}, {}],\n    [FocalLoss, {\"to_onehot_y\": False, \"gamma\": 1.5, \"weight\": 1.0}, {}],\n    [FocalLoss, {\"to_onehot_y\": True, \"gamma\": 1.5}, {}],\n    [GeneralizedDiceLoss, {\"to_onehot_y\": True, \"softmax\": True}, {}],\n    [GeneralizedDiceLoss, {\"to_onehot_y\": True, \"sigmoid\": True}, {}],\n    
[GeneralizedDiceLoss, {\"to_onehot_y\": True, \"sigmoid\": True, \"w_type\": \"simple\"}, {}],\n    [GeneralizedDiceLoss, {\"to_onehot_y\": True, \"sigmoid\": True, \"w_type\": \"uniform\"}, {}],\n    [GeneralizedDiceLoss, {\"to_onehot_y\": True, \"sigmoid\": True, \"w_type\": \"uniform\", \"batch\": True}, {}],\n    [GeneralizedDiceLoss, {\"to_onehot_y\": False, \"sigmoid\": True, \"w_type\": \"uniform\", \"batch\": True}, {}],\n    [TverskyLoss, {\"to_onehot_y\": True, \"softmax\": True, \"alpha\": 0.8, \"beta\": 0.2}, {}],\n    [TverskyLoss, {\"to_onehot_y\": True, \"softmax\": True, \"alpha\": 0.8, \"beta\": 0.2, \"batch\": True}, {}],\n    [TverskyLoss, {\"to_onehot_y\": True, \"softmax\": True, \"alpha\": 1.0, \"beta\": 0.0}, {}],\n    [TverskyLoss, {\"to_onehot_y\": False, \"softmax\": True, \"alpha\": 1.0, \"beta\": 0.0}, {}],\n]\n\n\nclass TestSegLossIntegration(unittest.TestCase):\n\n    def setUp(self):\n        set_determinism(0)\n        self.device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu:0\")\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @parameterized.expand(TEST_CASES)\n    def test_convergence(self, loss_type, loss_args, forward_args):\n        \"\"\"\n        The goal of this test is to assess if the gradient of the loss function\n        is correct by testing if we can train a one layer neural network\n        to segment one image.\n        We verify that the loss is decreasing in almost all SGD steps.\n        \"\"\"\n        learning_rate = 0.001\n        max_iter = 40\n\n        # define a simple 3d example\n        target_seg = torch.tensor(\n            [\n                [\n                    # raw 0\n                    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],\n                    # raw 1\n                    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],\n                    # raw 2\n                    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],\n     
           ]\n            ],\n            device=self.device,\n        )\n        target_seg = torch.unsqueeze(target_seg, dim=0)\n        image = 12 * target_seg + 27\n        image = image.float().to(self.device)\n        num_classes = 2\n        num_voxels = 3 * 4 * 4\n\n        target_onehot = one_hot(target_seg, num_classes=num_classes)\n\n        # define a one layer model\n        class OnelayerNet(nn.Module):\n\n            def __init__(self):\n                super().__init__()\n                self.layer_1 = nn.Linear(num_voxels, 200)\n                self.acti = nn.ReLU()\n                self.layer_2 = nn.Linear(200, num_voxels * num_classes)\n\n            def forward(self, x):\n                x = x.view(-1, num_voxels)\n                x = self.layer_1(x)\n                x = self.acti(x)\n                x = self.layer_2(x)\n                x = x.view(-1, num_classes, 3, 4, 4)\n                return x\n\n        # initialise the network\n        net = OnelayerNet().to(self.device)\n\n        # initialize the loss\n        loss = loss_type(**loss_args)\n\n        # initialize a SGD optimizer\n        optimizer = optim.Adam(net.parameters(), lr=learning_rate)\n\n        loss_history = []\n        init_output = None\n\n        # train the network\n        for iter_i in range(max_iter):\n            # set the gradient to zero\n            optimizer.zero_grad()\n\n            # forward pass\n            output = net(image)\n            if init_output is None:\n                init_output = torch.argmax(output, 1).detach().cpu().numpy()\n\n            if loss_args[\"to_onehot_y\"] is False:\n                loss_val = loss(output, target_onehot, **forward_args)\n            else:\n                loss_val = loss(output, target_seg, **forward_args)\n\n            if iter_i % 10 == 0:\n                pred = torch.argmax(output, 1).detach().cpu().numpy()\n                gt = target_seg.detach().cpu().numpy()[:, 0]\n                
print(f\"{loss_type.__name__} iter: {iter_i}, acc: {np.sum(pred == gt) / np.prod(pred.shape)}\")\n\n            # backward pass\n            loss_val.backward()\n            optimizer.step()\n\n            # stats\n            loss_history.append(loss_val.item())\n\n        pred = torch.argmax(output, 1).detach().cpu().numpy()\n        target = target_seg.detach().cpu().numpy()[:, 0]\n        # initial predictions are bad\n        self.assertTrue(not np.allclose(init_output, target))\n        # final predictions are good\n        np.testing.assert_allclose(pred, target)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_spatial_combine_transforms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nimport monai.transforms as mt\nfrom monai.data import create_test_image_2d, create_test_image_3d\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms.lazy.functional import apply_pending\nfrom monai.transforms.transform import MapTransform\nfrom monai.utils import set_determinism\nfrom tests.lazy_transforms_utils import get_apply_param\nfrom tests.test_utils import assert_allclose\n\nTEST_2D = [\n    [\n        (2, 62, 61),\n        [\n            (mt.Spacing, {\"pixdim\": (1.2, 1.5), \"padding_mode\": \"zeros\", \"dtype\": torch.float32}),\n            (mt.Orientation, {\"axcodes\": \"RA\"}),\n            (mt.Resize, {\"spatial_size\": (64, 48), \"mode\": \"bilinear\"}),\n            (mt.RandSpatialCrop, {\"roi_size\": (32, 32)}),\n            (\n                mt.RandAffine,\n                {\n                    \"prob\": 0.9,\n                    \"rotate_range\": (np.pi / 2,),\n                    \"shear_range\": [1, 2],\n                    \"translate_range\": [2, 1],\n                    \"mode\": \"bilinear\",\n                    \"padding_mode\": \"reflection\",\n                },\n            ),\n            (mt.RandFlip, {\"prob\": 0.9}),\n            (mt.RandRotate, {\"prob\": 0.9, \"range_x\": np.pi / 4, 
\"mode\": \"bilinear\", \"padding_mode\": \"reflection\"}),\n            (mt.CenterScaleCrop, {\"roi_scale\": (0.96, 0.8)}),\n            (mt.RandZoom, {\"prob\": 0.9, \"mode\": \"bilinear\", \"keep_size\": False, \"align_corners\": False}),\n        ],\n    ],\n    [\n        (2, 63, 64),\n        [\n            (mt.CenterScaleCropd, {\"roi_scale\": (0.96, 0.8), \"keys\": \"img\"}),\n            (mt.RandRotated, {\"prob\": 0.9, \"range_x\": np.pi / 4, \"keys\": \"img\"}),\n            (\n                mt.RandZoomd,\n                {\"prob\": 0.9, \"mode\": \"bilinear\", \"keep_size\": False, \"keys\": \"img\", \"align_corners\": False},\n            ),\n            (mt.Spacingd, {\"pixdim\": (1.2, 1.5), \"padding_mode\": \"zeros\", \"dtype\": torch.float32, \"keys\": \"img\"}),\n            (mt.RandFlipd, {\"prob\": 0.9, \"keys\": \"img\"}),\n            (\n                mt.RandAffined,\n                {\n                    \"prob\": 0.9,\n                    \"rotate_range\": (np.pi / 2,),\n                    \"shear_range\": [1, 2],\n                    \"translate_range\": [2, 1],\n                    \"mode\": \"bilinear\",\n                    \"keys\": \"img\",\n                },\n            ),\n            (mt.Orientationd, {\"axcodes\": \"RA\", \"keys\": \"img\"}),\n            (mt.Resized, {\"spatial_size\": (48, 48), \"mode\": \"bilinear\", \"keys\": \"img\"}),\n            (mt.RandScaleCropd, {\"roi_scale\": (0.4, 1.5), \"random_size\": False, \"keys\": \"img\"}),\n        ],\n    ],\n]\n\nTEST_3D = [\n    [\n        (2, 83, 100, 67),\n        [\n            (mt.Orientation, {\"axcodes\": \"RAS\"}),\n            (mt.CenterScaleCrop, {\"roi_scale\": (1.2, 0.8, 1.0)}),\n            (\n                mt.RandAffine,\n                {\n                    \"prob\": 0.9,\n                    \"rotate_range\": (np.pi / 2,),\n                    \"shear_range\": [1, 2],\n                    \"translate_range\": [2, 1],\n                    \"mode\": 
\"bilinear\",\n                },\n            ),\n            (mt.Spacing, {\"pixdim\": (0.9, 1.2, 1.0), \"padding_mode\": \"zeros\", \"dtype\": torch.float32}),\n            (mt.RandSpatialCrop, {\"roi_size\": (36, 36, 38), \"random_size\": False}),\n            (mt.RandZoom, {\"prob\": 0.9, \"mode\": \"nearest\", \"keep_size\": False}),\n            (mt.Resize, {\"spatial_size\": (32, 32, 32), \"mode\": \"nearest\"}),\n            (mt.RandFlip, {\"prob\": 0.9}),\n            (mt.RandRotate, {\"prob\": 0.9, \"range_x\": np.pi / 4}),\n        ],\n    ],\n    [\n        (2, 62, 64, 72),\n        [\n            (mt.RandScaleCropd, {\"roi_scale\": (0.9, 0.7, 1.1), \"random_size\": False, \"keys\": \"img\"}),\n            (mt.Spacingd, {\"pixdim\": (1.2, 1.5, 0.9), \"padding_mode\": \"zeros\", \"dtype\": torch.float32, \"keys\": \"img\"}),\n            (mt.Orientationd, {\"axcodes\": \"RAS\", \"keys\": \"img\"}),\n            (mt.Resized, {\"spatial_size\": (32, 32, 32), \"mode\": \"nearest\", \"keys\": \"img\"}),\n            (mt.RandFlipd, {\"prob\": 0.9, \"keys\": \"img\"}),\n            (mt.CenterScaleCropd, {\"roi_scale\": (0.96, 0.8, 1.25), \"keys\": \"img\"}),\n            (mt.RandZoomd, {\"prob\": 0.9, \"mode\": \"nearest\", \"keep_size\": False, \"keys\": \"img\"}),\n            (\n                mt.RandAffined,\n                {\n                    \"prob\": 0.9,\n                    \"rotate_range\": (np.pi / 2,),\n                    \"shear_range\": [1, 2],\n                    \"translate_range\": [2, 1],\n                    \"mode\": \"bilinear\",\n                    \"keys\": \"img\",\n                },\n            ),\n            (mt.RandRotated, {\"prob\": 0.9, \"range_x\": np.pi / 4, \"keys\": \"img\"}),\n        ],\n    ],\n]\n\n\nclass CombineLazyTest(unittest.TestCase):\n    @parameterized.expand(TEST_2D + TEST_3D)\n    def test_combine_transforms(self, input_shape, funcs):\n        for device in [\"cpu\", \"cuda\"] if 
torch.cuda.is_available() else [\"cpu\"]:\n            for seed in [10, 100, 1000, 10000]:\n                set_determinism(seed=seed)\n                _funcs = []\n                for _func, _params in funcs:\n                    _funcs.append(_func(**_params))\n                is_map = isinstance(_funcs[0], MapTransform)\n                chns, sp_size = input_shape[0], input_shape[1:]\n                imgs = []\n                for _ in range(chns):\n                    if len(sp_size) == 2:\n                        imgs.append(create_test_image_2d(sp_size[0], sp_size[1])[0])\n                    else:\n                        imgs.append(create_test_image_3d(sp_size[0], sp_size[1], sp_size[2])[0])\n                data = np.stack(imgs).astype(float)\n                im = MetaTensor(data, meta={\"a\": \"b\", \"affine\": np.eye(len(input_shape))}).to(device)\n                input_data = {\"img\": im} if is_map else im\n                # non lazy\n                non_lazy_result = input_data\n                for _func in _funcs:\n                    if isinstance(_func, mt.Randomizable):\n                        _func.set_random_state(seed=seed)\n                    non_lazy_result = _func(non_lazy_result)\n                expected = non_lazy_result[\"img\"] if is_map else non_lazy_result\n\n                # lazy\n                pending_result = input_data\n                for _func in _funcs:\n                    _func.lazy = True\n                    if isinstance(_func, mt.Randomizable):\n                        _func.set_random_state(seed=seed)\n                    pending_result = _func(pending_result)\n                pending_result = pending_result[\"img\"] if is_map else pending_result\n\n                assert_allclose(pending_result.peek_pending_affine(), expected.affine, atol=1e-7)\n                assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:4])\n\n                # test final result\n                init_param = 
funcs[-1][1]\n                call_param = {}\n                apply_param = get_apply_param(init_param, call_param)\n                result = apply_pending(pending_result, overrides=apply_param)[0]\n\n                match_ratio = np.sum(np.isclose(result.array, expected.array, atol=5e-1)) / np.prod(result.shape)\n                self.assertGreater(match_ratio, 0.5)  # at least half of the images are very close\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_testtimeaugmentation.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom functools import partial\nfrom typing import TYPE_CHECKING\n\nimport numpy as np\nimport torch\n\nfrom monai.data import CacheDataset, DataLoader, create_test_image_2d\nfrom monai.data.test_time_augmentation import TestTimeAugmentation\nfrom monai.data.utils import pad_list_data_collate\nfrom monai.losses import DiceLoss\nfrom monai.networks.nets import UNet\nfrom monai.transforms import (\n    Activations,\n    AsDiscrete,\n    Compose,\n    CropForegroundd,\n    DivisiblePadd,\n    EnsureChannelFirstd,\n    RandAffined,\n    RandScaleIntensityd,\n)\nfrom monai.transforms.croppad.dictionary import SpatialPadd\nfrom monai.transforms.spatial.dictionary import RandFlipd\nfrom monai.utils import optional_import, set_determinism\nfrom monai.utils.enums import PostFix\nfrom tests.test_utils import TEST_NDARRAYS\n\nif TYPE_CHECKING:\n    import tqdm\n\n    has_tqdm = True\n    has_nib = True\nelse:\n    tqdm, has_tqdm = optional_import(\"tqdm\")\n    _, has_nib = optional_import(\"nibabel\")\n\ntrange = partial(tqdm.trange, desc=\"training\") if has_tqdm else range\n\n\nclass TestTestTimeAugmentation(unittest.TestCase):\n    @staticmethod\n    def get_data(num_examples, input_size, data_type=np.asarray, include_label=True):\n        custom_create_test_image_2d = partial(\n            create_test_image_2d, *input_size, 
rad_max=7, num_seg_classes=1, num_objs=1\n        )\n        data = []\n        for i in range(num_examples):\n            im, label = custom_create_test_image_2d()\n            d = {\"image\": data_type(im[:, i:])}\n            if include_label:\n                d[\"label\"] = data_type(label[:, i:])\n                d[PostFix.meta(\"label\")] = {\"affine\": np.eye(4)}\n            data.append(d)\n        return data[0] if num_examples == 1 else data\n\n    def setUp(self) -> None:\n        set_determinism(seed=0)\n\n    def tearDown(self) -> None:\n        set_determinism(None)\n\n    def test_test_time_augmentation(self):\n        input_size = (20, 40)  # test different input data shape to pad list collate\n        keys = [\"image\", \"label\"]\n        num_training_ims = 10\n\n        train_data = self.get_data(num_training_ims, input_size)\n        test_data = self.get_data(1, input_size)\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        transforms = Compose(\n            [\n                EnsureChannelFirstd(keys, channel_dim=\"no_channel\"),\n                RandAffined(\n                    keys,\n                    prob=1.0,\n                    spatial_size=(30, 30),\n                    rotate_range=(np.pi / 3, np.pi / 3),\n                    translate_range=(3, 3),\n                    scale_range=((0.8, 1), (0.8, 1)),\n                    padding_mode=\"zeros\",\n                    mode=(\"bilinear\", \"nearest\"),\n                ),\n                CropForegroundd(keys, source_key=\"image\"),\n                DivisiblePadd(keys, 4),\n            ]\n        )\n\n        train_ds = CacheDataset(train_data, transforms)\n        # output might be different size, so pad so that they match\n        train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)\n\n        model = UNet(2, 1, 1, channels=(6, 6), strides=(2,)).to(device)\n        loss_function = DiceLoss(sigmoid=True)\n        optimizer = 
torch.optim.Adam(model.parameters(), 1e-3)\n\n        num_epochs = 10\n        for _ in trange(num_epochs):\n            epoch_loss = 0\n\n            for batch_data in train_loader:\n                inputs, labels = (batch_data[\"image\"].to(device), batch_data[\"label\"].to(device))\n                optimizer.zero_grad()\n                outputs = model(inputs)\n                loss = loss_function(outputs, labels)\n                loss.backward()\n                optimizer.step()\n                epoch_loss += loss.item()\n\n            epoch_loss /= len(train_loader)\n\n        post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])\n\n        tt_aug = TestTimeAugmentation(\n            transform=transforms,\n            batch_size=5,\n            num_workers=0,\n            inferrer_fn=model,\n            device=device,\n            to_tensor=True,\n            output_device=\"cpu\",\n            post_func=post_trans,\n        )\n        mode, mean, std, vvc = tt_aug(test_data)\n        self.assertEqual(mode.shape, (1,) + input_size)\n        self.assertEqual(mean.shape, (1,) + input_size)\n        self.assertTrue(all(np.unique(mode) == (0, 1)))\n        self.assertGreaterEqual(mean.min(), 0.0)\n        self.assertLessEqual(mean.max(), 1.0)\n        self.assertEqual(std.shape, (1,) + input_size)\n        self.assertIsInstance(vvc, float)\n\n    def test_warn_non_random(self):\n        transforms = Compose([EnsureChannelFirstd(\"im\", channel_dim=\"no_channel\"), SpatialPadd(\"im\", 1)])\n        with self.assertWarns(UserWarning):\n            TestTimeAugmentation(transforms, None, None, None)\n\n    def test_warn_random_but_has_no_invertible(self):\n        transforms = Compose(\n            [\n                EnsureChannelFirstd(\"image\", channel_dim=\"no_channel\"),\n                RandFlipd(\"image\", prob=1.0),\n                RandScaleIntensityd(\"image\", 0.1, prob=1.0),\n            ]\n        )\n        with 
self.assertWarns(UserWarning):\n            tta = TestTimeAugmentation(transforms, 5, 0, orig_key=\"image\")\n            tta(self.get_data(1, (20, 20), data_type=np.float32))\n\n    def test_warn_random_but_all_not_invertible(self):\n        \"\"\"test with no invertible stack\"\"\"\n        transforms = Compose(\n            [EnsureChannelFirstd(\"image\", channel_dim=\"no_channel\"), RandScaleIntensityd(\"image\", 0.1, prob=1.0)]\n        )\n        with self.assertWarns(UserWarning):\n            tta = TestTimeAugmentation(transforms, 1, 0, orig_key=\"image\")\n            tta(self.get_data(1, (20, 20), data_type=np.float32))\n\n    def test_single_transform(self):\n        for p in TEST_NDARRAYS:\n            transforms = RandFlipd([\"image\", \"label\"], prob=1.0)\n            tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x)\n            tta(self.get_data(1, (20, 20), data_type=p))\n\n    def test_image_no_label(self):\n        transforms = RandFlipd([\"image\"], prob=1.0)\n        tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key=\"image\")\n        tta(self.get_data(1, (20, 20), include_label=False))\n\n    def test_non_spatial_output(self):\n        \"\"\"\n        Test TTA for non-spatial output (e.g., classification scores).\n        Verifies that setting `apply_inverse_to_pred=False` correctly aggregates\n        predictions without attempting spatial inversion.\n        \"\"\"\n        input_size = (20, 20)\n        data = {\"image\": np.random.rand(1, *input_size).astype(np.float32)}\n\n        transforms = Compose(\n            [EnsureChannelFirstd(\"image\", channel_dim=\"no_channel\"), RandFlipd(\"image\", prob=1.0, spatial_axis=0)]\n        )\n\n        def mock_classifier(x):\n            batch_size = x.shape[0]\n            return torch.tensor([[0.2, 0.8]] * batch_size, dtype=torch.float32, device=x.device)\n\n        tt_aug = TestTimeAugmentation(\n  
          transform=transforms,\n            batch_size=2,\n            num_workers=0,\n            inferrer_fn=mock_classifier,\n            device=\"cpu\",\n            orig_key=\"image\",\n            apply_inverse_to_pred=False,\n            return_full_data=False,\n        )\n        mode, mean, std, vvc = tt_aug(data, num_examples=4)\n\n        self.assertEqual(mean.shape, (2,))\n        np.testing.assert_allclose(mean, [0.2, 0.8], atol=1e-6)\n        np.testing.assert_allclose(std, [0.0, 0.0], atol=1e-6)\n\n        tt_aug.return_full_data = True\n        full_output = tt_aug(data, num_examples=4)\n        self.assertEqual(full_output.shape, (4, 2))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_vis_gradbased.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.nets import DenseNet, DenseNet121, SEResNet50\nfrom monai.visualize import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad\n\n\nclass DenseNetAdjoint(DenseNet121):\n\n    def __call__(self, x, adjoint_info):\n        if adjoint_info != 42:\n            raise ValueError\n        return super().__call__(x)\n\n\nDENSENET2D = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)\nDENSENET3D = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,))\nSENET2D = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)\nSENET3D = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)\nDENSENET2DADJOINT = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=3)\n\nTESTS = []\nfor type in (VanillaGrad, SmoothGrad, GuidedBackpropGrad, GuidedBackpropSmoothGrad):\n    # 2D densenet\n    TESTS.append([type, DENSENET2D, (1, 1, 48, 64)])\n    # 3D densenet\n    TESTS.append([type, DENSENET3D, (1, 1, 6, 6, 6)])\n    # 2D senet\n    TESTS.append([type, SENET2D, (1, 3, 64, 64)])\n    # 3D senet\n    TESTS.append([type, SENET3D, (1, 3, 8, 8, 48)])\n    # 2D densenet - adjoint\n    TESTS.append([type, DENSENET2DADJOINT, (1, 1, 48, 64)])\n\n\nclass 
TestGradientClassActivationMap(unittest.TestCase):\n\n    @parameterized.expand(TESTS)\n    def test_shape(self, vis_type, model, shape):\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\n        # optionally test for adjoint info\n        kwargs = {\"adjoint_info\": 42} if isinstance(model, DenseNetAdjoint) else {}\n\n        model.to(device)\n        model.eval()\n        vis = vis_type(model)\n        x = torch.rand(shape, device=device)\n        result = vis(x, **kwargs)\n        self.assertTupleEqual(result.shape, x.shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/integration/test_vista3d_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest.case import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.utils import convert_points_to_disc, keep_merge_components_with_points, sample_points_from_label\nfrom monai.utils import min_version\nfrom monai.utils.module import optional_import\nfrom tests.test_utils import skip_if_no_cuda, skip_if_quick\n\ncp, has_cp = optional_import(\"cupy\")\ncucim_skimage, has_cucim = optional_import(\"cucim.skimage\")\nmeasure, has_measure = optional_import(\"skimage.measure\", \"0.14.2\", min_version)\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n\nTESTS_SAMPLE_POINTS_FROM_LABEL = []\nfor use_center in [True, False]:\n    labels = torch.zeros(1, 1, 32, 32, 32)\n    labels[0, 0, 5:10, 5:10, 5:10] = 1\n    labels[0, 0, 10:15, 10:15, 10:15] = 3\n    labels[0, 0, 20:25, 20:25, 20:25] = 5\n    TESTS_SAMPLE_POINTS_FROM_LABEL.append(\n        [{\"labels\": labels, \"label_set\": (1, 3, 5), \"use_center\": use_center}, (3, 1, 3), (3, 1)]\n    )\n\nTEST_CONVERT_POINTS_TO_DISC = []\nfor radius in [1, 2]:\n    for disc in [True, False]:\n        image_size = (32, 32, 32)\n        point = torch.randn(3, 1, 3)\n        point_label = torch.randint(0, 4, (3, 1))\n        expected_shape = (point.shape[0], 2, *image_size)\n        
TEST_CONVERT_POINTS_TO_DISC.append(\n            [\n                {\"image_size\": image_size, \"point\": point, \"point_label\": point_label, \"radius\": radius, \"disc\": disc},\n                expected_shape,\n            ]\n        )\n        image_size = (16, 32, 64)\n        point = torch.tensor([[[8, 16, 42], [2, 8, 21]]])\n        point_label = torch.tensor([[1, 0]])\n        expected_shape = (point.shape[0], 2, *image_size)\n        TEST_CONVERT_POINTS_TO_DISC.append(\n            [\n                {\"image_size\": image_size, \"point\": point, \"point_label\": point_label, \"radius\": radius, \"disc\": disc},\n                expected_shape,\n            ]\n        )\n\nTEST_CONVERT_POINTS_TO_DISC_VALUE = []\nimage_size = (16, 32, 64)\npoint = torch.tensor([[[8, 16, 42], [2, 8, 21]]])\npoint_label = torch.tensor([[1, 0]])\nexpected_shape = (point.shape[0], 2, *image_size)\nfor radius in [5, 10]:\n    for disc in [True, False]:\n        TEST_CONVERT_POINTS_TO_DISC_VALUE.append(\n            [\n                {\"image_size\": image_size, \"point\": point, \"point_label\": point_label, \"radius\": radius, \"disc\": disc},\n                [point, point_label],\n            ]\n        )\n\n\nTEST_LCC_MASK_POINT_TORCH = []\nfor bs in [1, 2]:\n    for num_points in [1, 3]:\n        shape = (bs, 1, 128, 32, 32)\n        TEST_LCC_MASK_POINT_TORCH.append(\n            [\n                {\n                    \"img_pos\": torch.randint(0, 2, shape, dtype=torch.bool),\n                    \"img_neg\": torch.randint(0, 2, shape, dtype=torch.bool),\n                    \"point_coords\": torch.randint(0, 10, (bs, num_points, 3)),\n                    \"point_labels\": torch.randint(0, 4, (bs, num_points)),\n                },\n                shape,\n            ]\n        )\n\nTEST_LCC_MASK_POINT_NP = []\nfor bs in [1, 2]:\n    for num_points in [1, 3]:\n        shape = (bs, 1, 32, 32, 64)\n        TEST_LCC_MASK_POINT_NP.append(\n            [\n                
{\n                    \"img_pos\": np.random.randint(0, 2, shape, dtype=bool),\n                    \"img_neg\": np.random.randint(0, 2, shape, dtype=bool),\n                    \"point_coords\": np.random.randint(0, 5, (bs, num_points, 3)),\n                    \"point_labels\": np.random.randint(0, 4, (bs, num_points)),\n                },\n                shape,\n            ]\n        )\n\n\n@skipUnless(has_measure or cucim_skimage, \"skimage or cucim.skimage required\")\nclass TestSamplePointsFromLabel(unittest.TestCase):\n    @parameterized.expand(TESTS_SAMPLE_POINTS_FROM_LABEL)\n    def test_shape(self, input_data, expected_point_shape, expected_point_label_shape):\n        point, point_label = sample_points_from_label(**input_data)\n        self.assertEqual(point.shape, expected_point_shape)\n        self.assertEqual(point_label.shape, expected_point_label_shape)\n\n\nclass TestConvertPointsToDisc(unittest.TestCase):\n    @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC)\n    def test_shape(self, input_data, expected_shape):\n        result = convert_points_to_disc(**input_data)\n        self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC_VALUE)\n    def test_value(self, input_data, points):\n        result = convert_points_to_disc(**input_data)\n        point, point_label = points\n        for i in range(point.shape[0]):\n            for j in range(point.shape[1]):\n                self.assertEqual(result[i, point_label[i, j], point[i, j][0], point[i, j][1], point[i, j][2]], True)\n\n\n@skipUnless(has_measure or cucim_skimage, \"skimage or cucim.skimage required\")\nclass TestKeepMergeComponentsWithPoints(unittest.TestCase):\n    @skip_if_quick\n    @skip_if_no_cuda\n    @skipUnless(has_cp and cucim_skimage, \"cupy and cucim.skimage required\")\n    @parameterized.expand(TEST_LCC_MASK_POINT_TORCH)\n    def test_cp_shape(self, input_data, shape):\n        for key in input_data:\n            
input_data[key] = input_data[key].to(device)\n        mask = keep_merge_components_with_points(**input_data)\n        self.assertEqual(mask.shape, shape)\n\n    @skipUnless(has_measure, \"skimage required\")\n    @parameterized.expand(TEST_LCC_MASK_POINT_NP)\n    def test_np_shape(self, input_data, shape):\n        mask = keep_merge_components_with_points(**input_data)\n        self.assertEqual(mask.shape, shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/lazy_transforms_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nfrom copy import deepcopy\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import InvertibleTransform, MapTransform, Randomizable\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.test_utils import assert_allclose\n\napply_transforms_kwargs = (\"pending\", \"mode\", \"padding_mode\", \"dtype\", \"align_corners\")\n\n\ndef get_apply_param(init_param=None, call_param=None, params=apply_transforms_kwargs):\n    apply_param = {}\n    for key in apply_transforms_kwargs:\n        if init_param and key in init_param.keys():\n            apply_param[key] = init_param[key]\n        if call_param and key in call_param.keys():\n            apply_param[key] = call_param[key]\n    return apply_param\n\n\ndef test_resampler_lazy(\n    resampler,\n    expected_output,\n    init_param=None,\n    call_param=None,\n    output_key=None,\n    output_idx=None,\n    rtol=1e-5,\n    atol=1e-7,\n    skip_shape_check=False,\n    seed=None,\n):\n    \"\"\"\n    This test function is used to test the consistency between non-lazy and lazy transforms.\n    Args:\n        resampler: instance of a resampling transform.\n        expected_output: output of non-lazy transform.\n        init_param: parameters that are used to initialize the transform.\n        call_param: parameters that are used when calling the 
transform.\n        output_key: key to get the output of the transform. This argument is used for dictionary based transforms.\n        output_idx: index to get the expected output from multiple outputs of the transform.\n        rtol: relative tolerance. This argument is only used to compare the output.\n        atol: absolute tolerance. This argument is only used to compare the output.\n        skip_shape_check: skip the check of shapes.\n        seed: set the random state with an integer seed. This argument is used for randomizable transforms.\n\n    \"\"\"\n    if isinstance(resampler, Randomizable):\n        resampler.set_random_state(seed=seed)\n    set_track_meta(True)\n    resampler.lazy = True\n    pending_output = resampler(**deepcopy(call_param))\n    if output_idx is not None:\n        expected_output, pending_output = (expected_output[output_idx], pending_output[output_idx])\n    if output_key is not None:\n        non_lazy_out, lazy_out = expected_output[output_key], pending_output[output_key]\n    else:\n        non_lazy_out, lazy_out = expected_output, pending_output\n    assert_allclose(lazy_out.peek_pending_affine(), non_lazy_out.affine)\n    if not skip_shape_check:\n        assert_allclose(lazy_out.peek_pending_shape(), non_lazy_out.shape[1:4])\n    apply_param = get_apply_param(init_param, call_param)\n    lazy_out = apply_pending(lazy_out, overrides=apply_param)[0]\n    assert_allclose(lazy_out, non_lazy_out, rtol=rtol, atol=atol)\n    if (\n        isinstance(resampler, InvertibleTransform)\n        and (not isinstance(resampler, MapTransform))\n        and isinstance(lazy_out, MetaTensor)\n        and isinstance(non_lazy_out, MetaTensor)\n        and non_lazy_out.applied_operations\n    ):\n        resampler.lazy = False\n        out = resampler.inverse(lazy_out.clone())\n        ref = resampler.inverse(non_lazy_out.clone())\n        assert_allclose(out.applied_operations, [])\n        assert_allclose(out.pending_operations, [])\n        
assert_allclose(ref, out, type_test=False, rtol=1e-3, atol=1e-3)\n        resampler.lazy = True\n"
  },
  {
    "path": "tests/losses/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/losses/deform/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/losses/deform/test_bending_energy.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses.deform import BendingEnergyLoss\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASES = [\n    [{}, {\"pred\": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0],\n    [{}, {\"pred\": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0],\n    [\n        {\"normalize\": False},\n        {\"pred\": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},\n        4.0,\n    ],\n    [\n        {\"normalize\": False},\n        {\"pred\": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},\n        4.0,\n    ],\n    [{\"normalize\": False}, {\"pred\": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 4.0],\n    [\n        {\"normalize\": True},\n        {\"pred\": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},\n        100.0,\n    ],\n    [\n        {\"normalize\": True},\n        {\"pred\": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},\n        100.0,\n    ],\n    [{\"normalize\": True}, {\"pred\": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 100.0],\n]\n\n\nclass 
TestBendingEnergy(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_data, expected_val):\n        result = BendingEnergyLoss(**input_param).forward(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)\n\n    def test_ill_shape(self):\n        loss = BendingEnergyLoss()\n        # not in 3-d, 4-d, 5-d\n        with self.assertRaisesRegex(ValueError, \"Expecting 3-d, 4-d or 5-d\"):\n            loss.forward(torch.ones((1, 3), device=device))\n        with self.assertRaisesRegex(ValueError, \"Expecting 3-d, 4-d or 5-d\"):\n            loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))\n        with self.assertRaisesRegex(ValueError, \"All spatial dimensions\"):\n            loss.forward(torch.ones((1, 3, 4, 5, 5), device=device))\n        with self.assertRaisesRegex(ValueError, \"All spatial dimensions\"):\n            loss.forward(torch.ones((1, 3, 5, 4, 5)))\n        with self.assertRaisesRegex(ValueError, \"All spatial dimensions\"):\n            loss.forward(torch.ones((1, 3, 5, 5, 4)))\n\n        # number of vector components unequal to number of spatial dims\n        with self.assertRaisesRegex(ValueError, \"Number of vector components\"):\n            loss.forward(torch.ones((1, 2, 5, 5, 5)))\n        with self.assertRaisesRegex(ValueError, \"Number of vector components\"):\n            loss.forward(torch.ones((1, 2, 5, 5, 5)))\n\n    def test_ill_opts(self):\n        pred = torch.rand(1, 3, 5, 5, 5).to(device=device)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            BendingEnergyLoss(reduction=\"unknown\")(pred)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            BendingEnergyLoss(reduction=None)(pred)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/deform/test_diffusion_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses.deform import DiffusionLoss\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASES = [\n    # all first partials are zero, so the diffusion loss is also zero\n    [{}, {\"pred\": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0],\n    # all first partials are one, so the diffusion loss is also one\n    [{}, {\"pred\": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0],\n    # before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 = 18.67\n    [\n        {\"normalize\": False},\n        {\"pred\": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},\n        56.0 / 3.0,\n    ],\n    # same as the previous case\n    [\n        {\"normalize\": False},\n        {\"pred\": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},\n        56.0 / 3.0,\n    ],\n    # same as the previous case\n    [{\"normalize\": False}, {\"pred\": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],\n    # we have shown in the demo notebook that\n    # diffusion loss is scale-invariant when the all axes have the same resolution\n    [\n        
{\"normalize\": True},\n        {\"pred\": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},\n        56.0 / 3.0,\n    ],\n    [\n        {\"normalize\": True},\n        {\"pred\": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},\n        56.0 / 3.0,\n    ],\n    [{\"normalize\": True}, {\"pred\": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],\n    # for the following case, consider the following 2D matrix:\n    # tensor([[[[0, 1, 2],\n    #           [1, 2, 3],\n    #           [2, 3, 4],\n    #           [3, 4, 5],\n    #           [4, 5, 6]],\n    #          [[0, 1, 2],\n    #           [1, 2, 3],\n    #           [2, 3, 4],\n    #           [3, 4, 5],\n    #           [4, 5, 6]]]])\n    # the first partials wrt x are all ones, and so are the first partials wrt y\n    # the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2\n    [{\"normalize\": False}, {\"pred\": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0],\n    # consider the same matrix, this time with normalization applied, using the same notation as in the demo notebook,\n    # the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y\n    # the diffusion loss is then (1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2 = (1 + 9/25 + 25/9 + 1) / 2 = 2.5689\n    [\n        {\"normalize\": True},\n        {\"pred\": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)},\n        (1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0,\n    ],\n]\n\n\nclass TestDiffusionLoss(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_data, expected_val):\n        result = DiffusionLoss(**input_param).forward(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)\n\n    def test_ill_shape(self):\n        loss = 
DiffusionLoss()\n        # not in 3-d, 4-d, 5-d\n        with self.assertRaisesRegex(ValueError, \"Expecting 3-d, 4-d or 5-d\"):\n            loss.forward(torch.ones((1, 3), device=device))\n        with self.assertRaisesRegex(ValueError, \"Expecting 3-d, 4-d or 5-d\"):\n            loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))\n        with self.assertRaisesRegex(ValueError, \"All spatial dimensions\"):\n            loss.forward(torch.ones((1, 3, 2, 5, 5), device=device))\n        with self.assertRaisesRegex(ValueError, \"All spatial dimensions\"):\n            loss.forward(torch.ones((1, 3, 5, 2, 5)))\n        with self.assertRaisesRegex(ValueError, \"All spatial dimensions\"):\n            loss.forward(torch.ones((1, 3, 5, 5, 2)))\n\n        # number of vector components unequal to number of spatial dims\n        with self.assertRaisesRegex(ValueError, \"Number of vector components\"):\n            loss.forward(torch.ones((1, 2, 5, 5, 5)))\n        with self.assertRaisesRegex(ValueError, \"Number of vector components\"):\n            loss.forward(torch.ones((1, 2, 5, 5, 5)))\n\n    def test_ill_opts(self):\n        pred = torch.rand(1, 3, 5, 5, 5).to(device=device)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            DiffusionLoss(reduction=\"unknown\")(pred)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            DiffusionLoss(reduction=None)(pred)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/image_dissimilarity/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/losses/image_dissimilarity/test_global_mutual_information_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai import transforms\nfrom monai.losses.image_dissimilarity import GlobalMutualInformationLoss\nfrom tests.test_utils import download_url_or_skip_test, skip_if_quick, testing_data_config\n\ndevice = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\nTESTS_PATH = Path(__file__).parents[2]\nFILE_PATH = os.path.join(TESTS_PATH, \"testing_data\", \"temp_\" + \"mri.nii\")\n\nEXPECTED_VALUE = {\n    \"xyz_translation\": [\n        -1.5860257,\n        -0.62433463,\n        -0.38217825,\n        -0.2905613,\n        -0.23233329,\n        -0.1961407,\n        -0.16905619,\n        -0.15100679,\n        -0.13666219,\n        -0.12635908,\n    ],\n    \"xyz_rotation\": [\n        -1.5860257,\n        -0.30265224,\n        -0.18666176,\n        -0.15887907,\n        -0.1625064,\n        -0.16603896,\n        -0.19222091,\n        -0.18158069,\n        -0.167644,\n        -0.16698098,\n    ],\n}\n\n\n@skip_if_quick\nclass TestGlobalMutualInformationLoss(unittest.TestCase):\n    def setUp(self):\n        config = testing_data_config(\"images\", \"Prostate_T2W_AX_1\")\n        download_url_or_skip_test(\n            url=config[\"url\"],\n            filepath=FILE_PATH,\n            
hash_val=config.get(\"hash_val\"),\n            hash_type=config.get(\"hash_type\", \"sha256\"),\n        )\n\n    def test_bspline(self):\n        loss_fn = GlobalMutualInformationLoss(kernel_type=\"b-spline\", num_bins=32, sigma_ratio=0.015)\n\n        transform_params_dict = {\n            \"xyz_translation\": [(i, i, i) for i in range(10)],\n            \"xyz_rotation\": [(np.pi / 100 * i, np.pi / 100 * i, np.pi / 100 * i) for i in range(10)],\n        }\n\n        def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.0)):\n            \"\"\"\n            Read and transform Prostate_T2W_AX_1.nii\n            Args:\n                translate_params: a tuple of 3 floats, translation is in pixel/voxel relative to the center of the input\n                        image. Defaults to no translation.\n                rotate_params: a rotation angle in radians, a tuple of 3 floats for 3D.\n                        Defaults to no rotation.\n            Returns:\n                numpy array of shape HWD\n            \"\"\"\n            transform_list = [\n                transforms.LoadImaged(keys=\"img\", image_only=True),\n                transforms.Affined(\n                    keys=\"img\",\n                    translate_params=translate_params,\n                    rotate_params=rotate_params,\n                    device=None,\n                    padding_mode=\"border\",\n                ),\n                transforms.NormalizeIntensityd(keys=[\"img\"]),\n            ]\n            transformation = transforms.Compose(transform_list)\n            return transformation({\"img\": FILE_PATH})[\"img\"]\n\n        a1 = transformation()\n        a1 = a1.clone().unsqueeze(0).unsqueeze(0).to(device)\n\n        for mode in transform_params_dict:\n            transform_params_list = transform_params_dict[mode]\n            expected_value_list = EXPECTED_VALUE[mode]\n            for transform_params, expected_value in zip(transform_params_list, 
expected_value_list):\n                a2 = transformation(\n                    translate_params=transform_params if \"translation\" in mode else (0.0, 0.0, 0.0),\n                    rotate_params=transform_params if \"rotation\" in mode else (0.0, 0.0, 0.0),\n                )\n                a2 = a2.clone().unsqueeze(0).unsqueeze(0).to(device)\n                result = loss_fn(a2, a1).detach().cpu().numpy()\n                np.testing.assert_allclose(result, expected_value, rtol=0.08, atol=0.05)\n\n\nclass TestGlobalMutualInformationLossIll(unittest.TestCase):\n    @parameterized.expand(\n        [\n            (torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)),  # mismatched_simple_dims\n            (\n                torch.ones((1, 3, 3), dtype=torch.float),\n                torch.ones((1, 3), dtype=torch.float),\n            ),  # mismatched_advanced_dims\n        ]\n    )\n    def test_ill_shape(self, input1, input2):\n        loss = GlobalMutualInformationLoss()\n        with self.assertRaises(ValueError):\n            loss.forward(input1, input2)\n\n    @parameterized.expand(\n        [\n            (0, \"mean\", ValueError, \"\"),  # num_bins_zero\n            (-1, \"mean\", ValueError, \"\"),  # num_bins_negative\n            (64, \"unknown\", ValueError, \"\"),  # reduction_unknown\n            (64, None, ValueError, \"\"),  # reduction_none\n        ]\n    )\n    def test_ill_opts(self, num_bins, reduction, expected_exception, expected_message):\n        pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)\n        target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)\n        with self.assertRaisesRegex(expected_exception, expected_message):\n            GlobalMutualInformationLoss(num_bins=num_bins, reduction=reduction)(pred, target)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses.image_dissimilarity import LocalNormalizedCrossCorrelationLoss\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASES = [\n    [\n        {\"spatial_dims\": 1, \"kernel_type\": \"rectangular\", \"reduction\": \"sum\"},\n        {\n            \"pred\": torch.arange(0, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device),\n            \"target\": torch.arange(0, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device),\n        },\n        -1.0 * 3,\n    ],\n    [\n        {\"spatial_dims\": 1, \"kernel_type\": \"rectangular\"},\n        {\n            \"pred\": torch.arange(0, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device),\n            \"target\": torch.arange(0, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device),\n        },\n        -1.0,\n    ],\n    [\n        {\"spatial_dims\": 1, \"kernel_type\": \"triangular\", \"smooth_dr\": 0.1},\n        {\n            \"pred\": torch.zeros(1, 2, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device),\n            \"target\": torch.zeros(1, 2, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device),\n        },\n        0.0,\n    ],\n    [\n        {\"spatial_dims\": 2, \"kernel_type\": \"rectangular\"},\n        {\n            
\"pred\": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(dtype=torch.float, device=device),\n            \"target\": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(dtype=torch.float, device=device),\n        },\n        -1.0,\n    ],\n    [\n        {\"spatial_dims\": 3, \"kernel_type\": \"rectangular\"},\n        {\n            \"pred\": torch.arange(0, 3)\n            .reshape(1, 1, -1, 1, 1)\n            .expand(1, 1, 3, 3, 3)\n            .to(dtype=torch.float, device=device),\n            \"target\": torch.arange(0, 3)\n            .reshape(1, 1, -1, 1, 1)\n            .expand(1, 1, 3, 3, 3)\n            .to(dtype=torch.float, device=device),\n        },\n        -1.0,\n    ],\n    [\n        {\"spatial_dims\": 3, \"kernel_type\": \"rectangular\"},\n        {\n            \"pred\": torch.arange(0, 3)\n            .reshape(1, 1, -1, 1, 1)\n            .expand(1, 3, 3, 3, 3)\n            .to(dtype=torch.float, device=device),\n            \"target\": torch.arange(0, 3)\n            .reshape(1, 1, -1, 1, 1)\n            .expand(1, 3, 3, 3, 3)\n            .to(dtype=torch.float, device=device)\n            ** 2,\n        },\n        -0.95801723,\n    ],\n    [\n        {\"spatial_dims\": 3, \"kernel_type\": \"triangular\", \"kernel_size\": 5},\n        {\n            \"pred\": torch.arange(0, 5)\n            .reshape(1, 1, -1, 1, 1)\n            .expand(1, 3, 5, 5, 5)\n            .to(dtype=torch.float, device=device),\n            \"target\": torch.arange(0, 5)\n            .reshape(1, 1, -1, 1, 1)\n            .expand(1, 3, 5, 5, 5)\n            .to(dtype=torch.float, device=device)\n            ** 2,\n        },\n        -0.918672,\n    ],\n    [\n        {\"spatial_dims\": 3, \"kernel_type\": \"gaussian\"},\n        {\n            \"pred\": torch.arange(0, 3)\n            .reshape(1, 1, -1, 1, 1)\n            .expand(1, 3, 3, 3, 3)\n            .to(dtype=torch.float, device=device),\n            \"target\": torch.arange(0, 3)\n    
        .reshape(1, 1, -1, 1, 1)\n            .expand(1, 3, 3, 3, 3)\n            .to(dtype=torch.float, device=device)\n            ** 2,\n        },\n        -0.95406944,\n    ],\n]\n\n\nclass TestLocalNormalizedCrossCorrelationLoss(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_data, expected_val):\n        result = LocalNormalizedCrossCorrelationLoss(**input_param).forward(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)\n\n    def test_ill_shape(self):\n        loss = LocalNormalizedCrossCorrelationLoss(spatial_dims=3)\n        # spatial_dims unmatch\n        with self.assertRaisesRegex(ValueError, \"\"):\n            loss.forward(\n                torch.ones((1, 3, 3, 3), dtype=torch.float, device=device),\n                torch.ones((1, 3, 3, 3), dtype=torch.float, device=device),\n            )\n        # pred, target shape unmatch\n        with self.assertRaisesRegex(ValueError, \"\"):\n            loss.forward(\n                torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device),\n                torch.ones((1, 3, 4, 4, 4), dtype=torch.float, device=device),\n            )\n\n    def test_ill_opts(self):\n        pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float)\n        target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            LocalNormalizedCrossCorrelationLoss(kernel_type=\"unknown\")(pred, target)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            LocalNormalizedCrossCorrelationLoss(kernel_type=None)(pred, target)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            LocalNormalizedCrossCorrelationLoss(kernel_size=4)(pred, target)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            LocalNormalizedCrossCorrelationLoss(reduction=\"unknown\")(pred, target)\n        with self.assertRaisesRegex(ValueError, 
\"\"):\n            LocalNormalizedCrossCorrelationLoss(reduction=None)(pred, target)\n\n\n#     def test_script(self):\n#         input_param, input_data, _ = TEST_CASES[0]\n#         loss = LocalNormalizedCrossCorrelationLoss(**input_param)\n#         test_script_save(loss, input_data[\"pred\"], input_data[\"target\"])\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_adversarial_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import PatchAdversarialLoss\n\nshapes_tensors = {\"2d\": [4, 1, 64, 64], \"3d\": [4, 1, 64, 64, 64]}\nreductions = [\"sum\", \"mean\"]\ncriterion = [\"bce\", \"least_squares\", \"hinge\"]\n\nTEST_CASE_CREATION_FAIL = [{\"reduction\": \"sum\", \"criterion\": \"invalid\"}]\n\nTEST_CASES_LOSS_LOGIC_2D = []\nTEST_CASES_LOSS_LOGIC_3D = []\n\nfor c in criterion:\n    for r in reductions:\n        TEST_CASES_LOSS_LOGIC_2D.append([{\"reduction\": r, \"criterion\": c}, shapes_tensors[\"2d\"]])\n        TEST_CASES_LOSS_LOGIC_3D.append([{\"reduction\": r, \"criterion\": c}, shapes_tensors[\"3d\"]])\n\nTEST_CASES_LOSS_LOGIC_LIST = []\nfor c in criterion:\n    TEST_CASES_LOSS_LOGIC_LIST.append([{\"reduction\": \"none\", \"criterion\": c}, shapes_tensors[\"2d\"]])\n    TEST_CASES_LOSS_LOGIC_LIST.append([{\"reduction\": \"none\", \"criterion\": c}, shapes_tensors[\"3d\"]])\n\n\nclass TestPatchAdversarialLoss(unittest.TestCase):\n\n    def get_input(self, shape, is_positive):\n        \"\"\"\n        Get tensor for the tests. 
The tensor is around (-1) or (+1), depending on\n        is_positive.\n        \"\"\"\n        if is_positive:\n            offset = 1\n        else:\n            offset = -1\n        return torch.ones(shape) * (offset) + 0.01 * torch.randn(shape)\n\n    def test_criterion(self):\n        \"\"\"\n        Make sure that unknown criterion fail.\n        \"\"\"\n        with self.assertRaises(ValueError):\n            PatchAdversarialLoss(**TEST_CASE_CREATION_FAIL[0])\n\n    @parameterized.expand(TEST_CASES_LOSS_LOGIC_2D + TEST_CASES_LOSS_LOGIC_3D)\n    def test_loss_logic(self, input_param: dict, shape_input: list):\n        \"\"\"\n        We want to make sure that the adversarial losses do what they should.\n        If the discriminator takes in a tensor that looks positive, yet the label is fake,\n        the loss should be bigger than that obtained with a tensor that looks negative.\n        Same for the real label, and for the generator.\n        \"\"\"\n        loss = PatchAdversarialLoss(**input_param)\n        fakes = self.get_input(shape_input, is_positive=False)\n        reals = self.get_input(shape_input, is_positive=True)\n        # Discriminator: fake label\n        loss_disc_f_f = loss(fakes, target_is_real=False, for_discriminator=True)\n        loss_disc_f_r = loss(reals, target_is_real=False, for_discriminator=True)\n        assert loss_disc_f_f < loss_disc_f_r\n        # Discriminator: real label\n        loss_disc_r_f = loss(fakes, target_is_real=True, for_discriminator=True)\n        loss_disc_r_r = loss(reals, target_is_real=True, for_discriminator=True)\n        assert loss_disc_r_f > loss_disc_r_r\n        # Generator:\n        loss_gen_f = loss(fakes, target_is_real=True, for_discriminator=False)  # target_is_real is overridden\n        loss_gen_r = loss(reals, target_is_real=True, for_discriminator=False)  # target_is_real is overridden\n        assert loss_gen_f > loss_gen_r\n\n    @parameterized.expand(TEST_CASES_LOSS_LOGIC_LIST)\n    def 
test_multiple_discs(self, input_param: dict, shape_input):\n        shapes = [shape_input] + [shape_input[0:2] + [int(i / j) for i in shape_input[2:]] for j in range(1, 3)]\n        inputs = [self.get_input(shapes[i], is_positive=True) for i in range(len(shapes))]\n        loss = PatchAdversarialLoss(**input_param)\n        assert len(loss(inputs, for_discriminator=True, target_is_real=True)) == 3\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_barlow_twins_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import BarlowTwinsLoss\n\nTEST_CASES = [\n    [  # shape: (2, 4), (2, 4)\n        {\"lambd\": 5e-3},\n        {\n            \"input\": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),\n            \"target\": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),\n        },\n        4.0,\n    ],\n    [  # shape: (2, 4), (2, 4)\n        {\"lambd\": 5e-3},\n        {\n            \"input\": torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]),\n            \"target\": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),\n        },\n        4.0,\n    ],\n    [  # shape: (2, 4), (2, 4)\n        {\"lambd\": 5e-3},\n        {\n            \"input\": torch.tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 0.0]]),\n            \"target\": torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 1.0]]),\n        },\n        5.2562,\n    ],\n    [  # shape: (2, 4), (2, 4)\n        {\"lambd\": 5e-4},\n        {\n            \"input\": torch.tensor([[2.0, 3.0, 1.0, 2.0], [0.0, 1.0, 2.0, 5.0]]),\n            \"target\": torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]),\n        },\n        5.0015,\n    ],\n    [  # shape: (4, 4), (4, 4)\n        {\"lambd\": 5e-3},\n        {\n            \"input\": 
torch.tensor(\n                [[1.0, 2.0, 1.0, 1.0], [3.0, 1.0, 1.0, 2.0], [1.0, 1.0, 1.0, 1.0], [2.0, 1.0, 1.0, 0.0]]\n            ),\n            \"target\": torch.tensor(\n                [\n                    [0.0, 1.0, -1.0, 0.0],\n                    [1 / 3, 0.0, -2 / 3, 1 / 3],\n                    [-2 / 3, -1.0, 7 / 3, 1 / 3],\n                    [1 / 3, 0.0, 1 / 3, -2 / 3],\n                ]\n            ),\n        },\n        1.4736,\n    ],\n]\n\n\nclass TestBarlowTwinsLoss(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_result(self, input_param, input_data, expected_val):\n        barlowtwinsloss = BarlowTwinsLoss(**input_param)\n        result = barlowtwinsloss(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n\n    def test_ill_shape(self):\n        loss = BarlowTwinsLoss(lambd=5e-3)\n        with self.assertRaises(ValueError):\n            loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))\n\n    def test_ill_batch_size(self):\n        loss = BarlowTwinsLoss(lambd=5e-3)\n        with self.assertRaises(ValueError):\n            loss(torch.ones((1, 2)), torch.ones((1, 2)))\n\n    def test_with_cuda(self):\n        loss = BarlowTwinsLoss(lambd=5e-3)\n        i = torch.ones((2, 10))\n        j = torch.ones((2, 10))\n        if torch.cuda.is_available():\n            i = i.cuda()\n            j = j.cuda()\n        output = loss(i, j)\n        np.testing.assert_allclose(output.detach().cpu().numpy(), 10.0, atol=1e-4, rtol=1e-4)\n\n    def check_warning_raised(self):\n        with self.assertWarns(Warning):\n            BarlowTwinsLoss(lambd=5e-3, batch_size=1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_cldice_loss.py",
    "content": "# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import SoftclDiceLoss, SoftDiceclDiceLoss\n\nTEST_CASES = [\n    [{\"y_pred\": torch.ones((7, 3, 11, 10)), \"y_true\": torch.ones((7, 3, 11, 10))}, 0.0],\n    [{\"y_pred\": torch.ones((2, 3, 13, 14, 5)), \"y_true\": torch.ones((2, 3, 13, 14, 5))}, 0.0],\n]\n\n\nclass TestclDiceLoss(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_result(self, y_pred_data, expected_val):\n        loss = SoftclDiceLoss()\n        loss_dice = SoftDiceclDiceLoss()\n        result = loss(**y_pred_data)\n        result_dice = loss_dice(**y_pred_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n        np.testing.assert_allclose(result_dice.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n\n    def test_with_cuda(self):\n        loss = SoftclDiceLoss()\n        loss_dice = SoftDiceclDiceLoss()\n        i = torch.ones((100, 3, 256, 256))\n        j = torch.ones((100, 3, 256, 256))\n        if torch.cuda.is_available():\n            i = i.cuda()\n            j = j.cuda()\n        output = loss(i, j)\n        output_dice = loss_dice(i, j)\n        np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)\n        np.testing.assert_allclose(output_dice.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_contrastive_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import ContrastiveLoss\n\nTEST_CASES = [\n    [  # shape: (1, 4), (1, 4)\n        {\"temperature\": 0.5},\n        {\"input\": torch.tensor([[1.0, 1.0, 0.0, 0.0]]), \"target\": torch.tensor([[1.0, 1.0, 0.0, 0.0]])},\n        0.0,\n    ],\n    [  # shape: (2, 4), (2, 4)\n        {\"temperature\": 0.5},\n        {\n            \"input\": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),\n            \"target\": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),\n        },\n        1.0986,\n    ],\n    [  # shape: (1, 4), (1, 4)\n        {\"temperature\": 0.5},\n        {\n            \"input\": torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 0.0, 0.0]]),\n            \"target\": torch.tensor([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),\n        },\n        0.8719,\n    ],\n    [  # shape: (1, 4), (1, 4)\n        {\"temperature\": 0.5},\n        {\"input\": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), \"target\": torch.tensor([[1.0, 1.0, 0.0, 0.0]])},\n        0.0,\n    ],\n    [  # shape: (1, 4), (1, 4)\n        {\"temperature\": 0.05},\n        {\"input\": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), \"target\": torch.tensor([[1.0, 1.0, 0.0, 0.0]])},\n        0.0,\n    ],\n]\n\n\nclass 
TestContrastiveLoss(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_result(self, input_param, input_data, expected_val):\n        contrastiveloss = ContrastiveLoss(**input_param)\n        result = contrastiveloss(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n\n    def test_ill_shape(self):\n        loss = ContrastiveLoss(temperature=0.5)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))\n\n    def test_with_cuda(self):\n        loss = ContrastiveLoss(temperature=0.5)\n        i = torch.ones((1, 10))\n        j = torch.ones((1, 10))\n        if torch.cuda.is_available():\n            i = i.cuda()\n            j = j.cuda()\n        output = loss(i, j)\n        np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)\n\n    def check_warning_rasied(self):\n        with self.assertWarns(Warning):\n            ContrastiveLoss(temperature=0.5, batch_size=1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_dice_ce_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import DiceCELoss\n\nTEST_CASES = [\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"to_onehot_y\": True},\n        {\n            \"input\": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n            \"target\": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),\n        },\n        0.3133,  # the result equals to -1 + np.log(1 + np.exp(1))\n    ],\n    [  # shape: (2, 2, 3), (2, 2, 3), one-hot target\n        {\"to_onehot_y\": False},\n        {\n            \"input\": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n        },\n        0.3133,\n    ],\n    [  # shape: (2, 2, 3), (2, 2, 3), one-hot target\n        {\"to_onehot_y\": False},\n        {\n            \"input\": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n            \"target\": torch.tensor([[[1, 1, 0], [0, 0, 1]], [[1, 0, 1], [0, 1, 0]]], dtype=torch.uint8),\n        },\n        0.3133,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": False, \"to_onehot_y\": True, \"weight\": 
torch.tensor([1.0, 1.0])},\n        {\n            \"input\": torch.tensor([[[100.0, 100.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n            \"target\": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),\n        },\n        0.2088,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3) lambda_dice: 1.0, lambda_ce: 2.0\n        {\n            \"include_background\": False,\n            \"to_onehot_y\": True,\n            \"weight\": torch.tensor([1.0, 1.0]),\n            \"lambda_dice\": 1.0,\n            \"lambda_ce\": 2.0,\n        },\n        {\n            \"input\": torch.tensor([[[100.0, 100.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n            \"target\": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),\n        },\n        0.4176,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3), do not include class 0\n        {\"include_background\": False, \"to_onehot_y\": True, \"weight\": torch.tensor([0.0, 1.0])},\n        {\n            \"input\": torch.tensor([[[100.0, 100.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n            \"target\": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),\n        },\n        0.3133,\n    ],\n    [  # shape: (2, 1, 3), (2, 1, 3), bceloss\n        {\"weight\": torch.tensor([0.5]), \"sigmoid\": True},\n        {\n            \"input\": torch.tensor([[[0.8, 0.6, 0.0]], [[0.0, 0.0, 0.9]]]),\n            \"target\": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),\n        },\n        1.445239,\n    ],\n]\n\n\nclass TestDiceCELoss(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_result(self, input_param, input_data, expected_val):\n        diceceloss = DiceCELoss(**input_param)\n        result = diceceloss(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n\n    def test_ill_shape(self):\n        loss = DiceCELoss()\n        with self.assertRaises(AssertionError):\n      
      loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))\n\n    def test_ill_shape2(self):\n        loss = DiceCELoss()\n        with self.assertRaises(ValueError):\n            loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))\n\n    def test_ill_shape3(self):\n        loss = DiceCELoss()\n        with self.assertRaises(ValueError):\n            loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))\n\n    # def test_ill_reduction(self):\n    #     with self.assertRaisesRegex(ValueError, \"\"):\n    #         loss = DiceCELoss(reduction=\"none\")\n    #         loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))\n\n    # def test_script(self):\n    #     loss = DiceCELoss()\n    #     test_input = torch.ones(2, 2, 8, 8)\n    #     test_script_save(loss, test_input, test_input)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_dice_focal_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import DiceFocalLoss, DiceLoss, FocalLoss\nfrom tests.test_utils import test_script_save\n\n\nclass TestDiceFocalLoss(unittest.TestCase):\n    def test_result_onehot_target_include_bg(self):\n        size = [3, 3, 5, 5]\n        label = torch.randint(low=0, high=2, size=size)\n        pred = torch.randn(size)\n        for reduction in [\"sum\", \"mean\", \"none\"]:\n            for weight in [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]:\n                common_params = {\n                    \"include_background\": True,\n                    \"to_onehot_y\": False,\n                    \"reduction\": reduction,\n                    \"weight\": weight,\n                }\n                for lambda_focal in [0.5, 1.0, 1.5]:\n                    dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, **common_params)\n                    dice = DiceLoss(**common_params)\n                    focal = FocalLoss(gamma=1.0, **common_params)\n                    result = dice_focal(pred, label)\n                    expected_val = dice(pred, label) + lambda_focal * focal(pred, label)\n                    np.testing.assert_allclose(result, expected_val)\n\n    @parameterized.expand([[[3, 3, 5, 5], True], [[3, 2, 5, 
5], False]])\n    def test_result_no_onehot_no_bg(self, size, onehot):\n        label = torch.randint(low=0, high=size[1] - 1, size=size)\n        if onehot:\n            label = torch.argmax(label, dim=1, keepdim=True)\n        pred = torch.randn(size)\n        for reduction in [\"sum\", \"mean\", \"none\"]:\n            for weight in [2.0] + [] if size[1] != 3 else [torch.tensor([1.0, 2.0]), (2.0, 1)]:\n                for lambda_focal in [0.5, 1.0, 1.5]:\n                    common_params = {\n                        \"include_background\": False,\n                        \"softmax\": True,\n                        \"to_onehot_y\": onehot,\n                        \"reduction\": reduction,\n                        \"weight\": weight,\n                    }\n                    dice_focal = DiceFocalLoss(lambda_focal=lambda_focal, **common_params)\n                    dice = DiceLoss(**common_params)\n                    common_params.pop(\"softmax\", None)\n                    focal = FocalLoss(**common_params)\n                    result = dice_focal(pred, label)\n                    expected_val = dice(pred, label) + lambda_focal * focal(pred, label)\n                    np.testing.assert_allclose(result, expected_val)\n\n    def test_ill_shape(self):\n        loss = DiceFocalLoss()\n        with self.assertRaises(AssertionError):\n            loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))\n\n    def test_ill_shape2(self):\n        loss = DiceFocalLoss()\n        with self.assertRaises(ValueError):\n            loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))\n\n    def test_ill_shape3(self):\n        loss = DiceFocalLoss()\n        with self.assertRaises(ValueError):\n            loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))\n\n    def test_ill_lambda(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            DiceFocalLoss(lambda_dice=-1.0)\n\n    def test_script(self):\n        loss = 
DiceFocalLoss()\n        test_input = torch.ones(2, 1, 8, 8)\n        test_script_save(loss, test_input, test_input)\n\n    @parameterized.expand(\n        [\n            (\"sum_None_0.5_0.25\", \"sum\", None, 0.5, 0.25),\n            (\"sum_weight_0.5_0.25\", \"sum\", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),\n            (\"sum_weight_tuple_0.5_0.25\", \"sum\", (3, 2.0, 1), 0.5, 0.25),\n            (\"mean_None_0.5_0.25\", \"mean\", None, 0.5, 0.25),\n            (\"mean_weight_0.5_0.25\", \"mean\", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),\n            (\"mean_weight_tuple_0.5_0.25\", \"mean\", (3, 2.0, 1), 0.5, 0.25),\n            (\"none_None_0.5_0.25\", \"none\", None, 0.5, 0.25),\n            (\"none_weight_0.5_0.25\", \"none\", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),\n            (\"none_weight_tuple_0.5_0.25\", \"none\", (3, 2.0, 1), 0.5, 0.25),\n        ]\n    )\n    def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha):\n        size = [3, 3, 5, 5]\n        label = torch.randint(low=0, high=2, size=size)\n        pred = torch.randn(size)\n\n        common_params = {\"include_background\": True, \"to_onehot_y\": False, \"reduction\": reduction, \"weight\": weight}\n\n        dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params)\n        dice = DiceLoss(**common_params)\n        focal = FocalLoss(gamma=1.0, alpha=alpha, **common_params)\n\n        result = dice_focal(pred, label)\n        expected_val = dice(pred, label) + lambda_focal * focal(pred, label)\n\n        np.testing.assert_allclose(result, expected_val, err_msg=f\"Failed on case: {name}\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_dice_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import DiceLoss\nfrom tests.test_utils import test_script_save\n\nTEST_CASES = [\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},\n        0.307576,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),\n        },\n        0.416657,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4, \"soft_label\": True},\n        {\n            \"input\": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),\n            \"target\": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),\n        },\n        0.0,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, 
\"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4, \"soft_label\": False},\n        {\n            \"input\": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),\n            \"target\": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),\n        },\n        0.307773,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": False, \"to_onehot_y\": True, \"smooth_nr\": 0, \"smooth_dr\": 0},\n        {\n            \"input\": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n            \"target\": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),\n        },\n        0.0,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": True, \"to_onehot_y\": True, \"sigmoid\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        0.435050,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"sigmoid\": True,\n            \"reduction\": \"none\",\n            \"smooth_nr\": 1e-4,\n            \"smooth_dr\": 1e-4,\n        },\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        [[[0.296529], [0.415136]], [[0.599976], [0.428559]]],\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": True, \"to_onehot_y\": True, \"softmax\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 
0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        0.383713,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"softmax\": True,\n            \"reduction\": \"sum\",\n            \"smooth_nr\": 1e-4,\n            \"smooth_dr\": 1e-4,\n        },\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        1.534853,\n    ],\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},\n        0.307576,\n    ],\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"squared_pred\": True},\n        {\"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},\n        0.178337,\n    ],\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"jaccard\": True},\n        {\"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},\n        0.470451,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"other_act\": torch.tanh, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),\n        },\n        0.999963,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"other_act\": 
lambda x: torch.log_softmax(x, dim=1),\n            \"smooth_nr\": 1e-4,\n            \"smooth_dr\": 1e-4,\n        },\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        -8.522593,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"other_act\": torch.tanh, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4, \"batch\": True},\n        {\n            \"input\": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),\n        },\n        0.774718,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"other_act\": torch.tanh, \"smooth_nr\": 0, \"smooth_dr\": 1e-4, \"batch\": True},\n        {\n            \"input\": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),\n        },\n        0.774733,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"other_act\": torch.tanh, \"smooth_nr\": 0, \"smooth_dr\": 1e-4, \"batch\": False},\n        {\n            \"input\": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),\n        },\n        0.840058,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3) weight\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"other_act\": lambda x: torch.log_softmax(x, dim=1),\n            \"smooth_nr\": 1e-4,\n            \"smooth_dr\": 1e-4,\n            \"weight\": (0, 1),\n        },\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], 
[1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        -8.268515,\n    ],\n]\n\n\nclass TestDiceLoss(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_data, expected_val):\n        result = DiceLoss(**input_param).forward(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)\n\n    def test_ill_shape(self):\n        loss = DiceLoss()\n        with self.assertRaisesRegex(AssertionError, \"\"):\n            loss.forward(torch.ones((1, 2, 3)), torch.ones((4, 5, 6)))\n\n    def test_ill_opts(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            DiceLoss(sigmoid=True, softmax=True)\n        chn_input = torch.ones((1, 1, 3))\n        chn_target = torch.ones((1, 1, 3))\n        with self.assertRaisesRegex(ValueError, \"\"):\n            DiceLoss(reduction=\"unknown\")(chn_input, chn_target)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            DiceLoss(reduction=None)(chn_input, chn_target)\n\n    def test_input_warnings(self):\n        chn_input = torch.ones((1, 1, 3))\n        chn_target = torch.ones((1, 1, 3))\n        with self.assertWarns(Warning):\n            loss = DiceLoss(include_background=False)\n            loss.forward(chn_input, chn_target)\n        with self.assertWarns(Warning):\n            loss = DiceLoss(softmax=True)\n            loss.forward(chn_input, chn_target)\n        with self.assertWarns(Warning):\n            loss = DiceLoss(to_onehot_y=True)\n            loss.forward(chn_input, chn_target)\n\n    def test_script(self):\n        loss = DiceLoss()\n        test_input = torch.ones(2, 1, 8, 8)\n        test_script_save(loss, test_input, test_input)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_ds_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import DeepSupervisionLoss, DiceCELoss, DiceFocalLoss, DiceLoss\nfrom tests.test_utils import test_script_save\n\nTEST_CASES_DICECE = [\n    [\n        {\"to_onehot_y\": True},\n        {},\n        {\n            \"input\": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n        },\n        0.606557,\n    ]\n]\n\nTEST_CASES_DICECE2 = [\n    [\n        {\"to_onehot_y\": True},\n        {},\n        {\n            \"input\": [\n                torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n                torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]),\n                torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),\n            ],\n            \"target\": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n        },\n        1.78144,\n    ],\n    [\n        {\"to_onehot_y\": True},\n        {\"weight_mode\": \"same\"},\n        {\n            \"input\": [\n                torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n                torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], 
[0.0, 1.0]]]]),\n                torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),\n            ],\n            \"target\": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n        },\n        3.5529,\n    ],\n    [\n        {\"to_onehot_y\": True},\n        {\"weight_mode\": \"two\"},\n        {\n            \"input\": [\n                torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n                torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]),\n                torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),\n            ],\n            \"target\": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n        },\n        2.07973,\n    ],\n    [\n        {\"to_onehot_y\": True},\n        {\"weights\": [0.1, 0.2, 0.3]},\n        {\n            \"input\": [\n                torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n                torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]),\n                torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),\n            ],\n            \"target\": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n        },\n        0.76924,\n    ],\n]\n\nTEST_CASES_DICE = [\n    [\n        {\"to_onehot_y\": True},\n        {\n            \"input\": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n        },\n        0.166666,  # the result equals to -1 + np.log(1 + np.exp(1))\n    ],\n    [\n        {\"to_onehot_y\": True},\n        {\n            \"input\": [\n                torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n                torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]),\n                torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),\n            ],\n            \"target\": torch.tensor([[[[1.0, 
0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n        },\n        0.666665,\n    ],\n]\n\nTEST_CASES_DICEFOCAL = [\n    [\n        {\"to_onehot_y\": True},\n        {\n            \"input\": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n        },\n        0.32124,  # the result equals to -1 + np.log(1 + np.exp(1))\n    ],\n    [\n        {\"to_onehot_y\": True},\n        {\n            \"input\": [\n                torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n                torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]),\n                torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),\n            ],\n            \"target\": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n        },\n        1.06452,\n    ],\n]\n\n\nclass TestDSLossDiceCE(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_DICECE)\n    def test_result(self, input_param, input_param2, input_data, expected_val):\n        diceceloss = DeepSupervisionLoss(DiceCELoss(**input_param), **input_param2)\n        result = diceceloss(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n\n    def test_ill_shape(self):\n        loss = DeepSupervisionLoss(DiceCELoss())\n        with self.assertRaisesRegex(ValueError, \"\"):\n            loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))\n\n    def test_ill_reduction(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            loss = DeepSupervisionLoss(DiceCELoss(reduction=\"none\"))\n            loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))\n\n    def test_script(self):\n        loss = DeepSupervisionLoss(DiceCELoss())\n        test_input = torch.ones(2, 2, 8, 8)\n        test_script_save(loss, test_input, test_input)\n\n\nclass TestDSLossDiceCE2(unittest.TestCase):\n    
@parameterized.expand(TEST_CASES_DICECE2)\n    def test_result(self, input_param, input_param2, input_data, expected_val):\n        diceceloss = DeepSupervisionLoss(DiceCELoss(**input_param), **input_param2)\n        result = diceceloss(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n\n\nclass TestDSLossDice(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_DICE)\n    def test_result(self, input_param, input_data, expected_val):\n        loss = DeepSupervisionLoss(DiceLoss(**input_param))\n        result = loss(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n\n\nclass TestDSLossDiceFocal(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_DICEFOCAL)\n    def test_result(self, input_param, input_data, expected_val):\n        loss = DeepSupervisionLoss(DiceFocalLoss(**input_param))\n        result = loss(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_focal_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom parameterized import parameterized\n\nfrom monai.losses import FocalLoss\nfrom monai.networks import one_hot\nfrom tests.test_utils import TEST_DEVICES, test_script_save\n\nTEST_CASES = []\nfor case in TEST_DEVICES:\n    device = case[0]\n    input_data = {\n        \"input\": torch.tensor(\n            [[[[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]], [[1.0, 1.0], [0.5, 0.0]]]], device=device\n        ),  # (1, 3, 2, 2)\n        \"target\": torch.tensor([[[[0, 1], [2, 0]]]], device=device),  # (1, 1, 2, 2)\n    }\n    TEST_CASES.append([{\"to_onehot_y\": True}, input_data, 0.34959])\n    TEST_CASES.append(\n        [\n            {\"to_onehot_y\": False},\n            {\n                \"input\": input_data[\"input\"],  # (1, 3, 2, 2)\n                \"target\": F.one_hot(input_data[\"target\"].squeeze(1)).permute(0, 3, 1, 2),  # (1, 3, 2, 2)\n            },\n            0.34959,\n        ]\n    )\n    TEST_CASES.append([{\"to_onehot_y\": True, \"include_background\": False}, input_data, 0.36498])\n    TEST_CASES.append([{\"to_onehot_y\": True, \"alpha\": 0.8}, input_data, 0.08423])\n    TEST_CASES.append(\n        [\n            {\"to_onehot_y\": True, \"reduction\": \"none\"},\n            input_data,\n            
np.array(\n                [\n                    [\n                        [[0.02266, 0.70187], [0.37741, 0.17329]],\n                        [[0.70187, 0.02266], [0.37741, 0.17329]],\n                        [[0.70187, 0.70187], [0.06757, 0.17329]],\n                    ]\n                ]\n            ),\n        ]\n    )\n    TEST_CASES.append(\n        [\n            {\"to_onehot_y\": True, \"weight\": torch.tensor([0.5, 0.1, 0.2]), \"reduction\": \"none\"},\n            input_data,\n            np.array(\n                [\n                    [\n                        [[0.01133, 0.35093], [0.18871, 0.08664]],\n                        [[0.07019, 0.00227], [0.03774, 0.01733]],\n                        [[0.14037, 0.14037], [0.01352, 0.03466]],\n                    ]\n                ]\n            ),\n        ]\n    )\n    TEST_CASES.append([{\"to_onehot_y\": True, \"use_softmax\": True}, input_data, 0.16276])\n    TEST_CASES.append([{\"to_onehot_y\": True, \"alpha\": 0.8, \"use_softmax\": True}, input_data, 0.08138])\n\nTEST_ALPHA_BROADCASTING = []\nfor case in TEST_DEVICES:\n    device = case[0]\n    for include_background in [True, False]:\n        for use_softmax in [True, False]:\n            TEST_ALPHA_BROADCASTING.append([device, include_background, use_softmax])\n\n\nclass TestFocalLoss(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_result(self, input_param, input_data, expected_val):\n        focal_loss = FocalLoss(**input_param)\n        result = focal_loss(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n\n    def test_consistency_with_cross_entropy_2d(self):\n        \"\"\"For gamma=0 the focal loss reduces to the cross entropy loss\"\"\"\n        focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction=\"mean\", weight=1.0)\n        ce = nn.BCEWithLogitsLoss(reduction=\"mean\")\n        max_error = 0\n        class_num = 10\n        batch_size = 
128\n        for _ in range(100):\n            # Create a random tensor of shape (batch_size, class_num, 8, 4)\n            x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True)\n            # Create a random batch of classes\n            l = torch.randint(low=0, high=2, size=(batch_size, class_num, 8, 4)).float()\n            if torch.cuda.is_available():\n                x = x.cuda()\n                l = l.cuda()\n            output0 = focal_loss(x, l)\n            output1 = ce(x, l)\n            a = float(output0.cpu().detach())\n            b = float(output1.cpu().detach())\n            if abs(a - b) > max_error:\n                max_error = abs(a - b)\n        self.assertAlmostEqual(max_error, 0.0, places=3)\n\n    def test_consistency_with_cross_entropy_2d_no_reduction(self):\n        \"\"\"For gamma=0 the focal loss reduces to the cross entropy loss\"\"\"\n\n        focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction=\"none\", weight=1.0)\n        ce = nn.BCEWithLogitsLoss(reduction=\"none\")\n        max_error = 0\n        class_num = 10\n        batch_size = 128\n        for _ in range(100):\n            # Create a random tensor of shape (batch_size, class_num, 8, 4)\n            x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True)\n            # Create a random batch of classes\n            l = torch.randint(low=0, high=2, size=(batch_size, class_num, 8, 4)).float()\n            if torch.cuda.is_available():\n                x = x.cuda()\n                l = l.cuda()\n            output0 = focal_loss(x, l)\n            output1 = ce(x, l)\n            a = output0.cpu().detach().numpy()\n            b = output1.cpu().detach().numpy()\n            error = np.abs(a - b)\n            max_error = np.maximum(error, max_error)\n\n        assert np.allclose(max_error, 0, atol=1e-6)\n\n    def test_consistency_with_cross_entropy_2d_onehot_label(self):\n        \"\"\"For gamma=0 the focal loss reduces to the cross entropy loss\"\"\"\n 
       focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction=\"mean\")\n        ce = nn.BCEWithLogitsLoss(reduction=\"mean\")\n        max_error = 0\n        class_num = 10\n        batch_size = 128\n        for _ in range(100):\n            # Create a random tensor of shape (batch_size, class_num, 8, 4)\n            x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True)\n            # Create a random batch of classes\n            l = torch.randint(low=0, high=class_num, size=(batch_size, 1, 8, 4))\n            if torch.cuda.is_available():\n                x = x.cuda()\n                l = l.cuda()\n            output0 = focal_loss(x, l)\n            output1 = ce(x, one_hot(l, num_classes=class_num))\n            a = float(output0.cpu().detach())\n            b = float(output1.cpu().detach())\n            if abs(a - b) > max_error:\n                max_error = abs(a - b)\n        self.assertAlmostEqual(max_error, 0.0, places=3)\n\n    def test_consistency_with_cross_entropy_classification(self):\n        \"\"\"for gamma=0 the focal loss reduces to the cross entropy loss\"\"\"\n        focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction=\"mean\")\n        ce = nn.BCEWithLogitsLoss(reduction=\"mean\")\n        max_error = 0\n        class_num = 10\n        batch_size = 128\n        for _ in range(100):\n            # Create a random scores tensor of shape (batch_size, class_num)\n            x = torch.rand(batch_size, class_num, requires_grad=True)\n            # Create a random batch of classes\n            l = torch.randint(low=0, high=class_num, size=(batch_size, 1))\n            l = l.long()\n            if torch.cuda.is_available():\n                x = x.cuda()\n                l = l.cuda()\n            output0 = focal_loss(x, l)\n            output1 = ce(x, one_hot(l, num_classes=class_num))\n            a = float(output0.cpu().detach())\n            b = float(output1.cpu().detach())\n            if abs(a - b) > max_error:\n         
       max_error = abs(a - b)\n        self.assertAlmostEqual(max_error, 0.0, places=3)\n\n    def test_consistency_with_cross_entropy_classification_01(self):\n        # for gamma=0.1 the focal loss differs from the cross entropy loss\n        focal_loss = FocalLoss(to_onehot_y=True, gamma=0.1, reduction=\"mean\")\n        ce = nn.BCEWithLogitsLoss(reduction=\"mean\")\n        max_error = 0\n        class_num = 10\n        batch_size = 128\n        for _ in range(100):\n            # Create a random scores tensor of shape (batch_size, class_num)\n            x = torch.rand(batch_size, class_num, requires_grad=True)\n            # Create a random batch of classes\n            l = torch.randint(low=0, high=class_num, size=(batch_size, 1))\n            l = l.long()\n            if torch.cuda.is_available():\n                x = x.cuda()\n                l = l.cuda()\n            output0 = focal_loss(x, l)\n            output1 = ce(x, one_hot(l, num_classes=class_num))\n            a = float(output0.cpu().detach())\n            b = float(output1.cpu().detach())\n            if abs(a - b) > max_error:\n                max_error = abs(a - b)\n        self.assertNotAlmostEqual(max_error, 0.0, places=3)\n\n    def test_bin_seg_2d(self):\n        for use_softmax in [True, False]:\n            # define 2d examples\n            target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])\n            # add another dimension corresponding to the batch (batch size = 1 here)\n            target = target.unsqueeze(0)  # shape (1, H, W)\n            pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0\n\n            # initialize the mean dice loss\n            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)\n\n            # focal loss for pred_very_good should be close to 0\n            target = target.unsqueeze(1)  # shape (1, 1, H, W)\n            focal_loss_good = float(loss(pred_very_good, target).cpu())\n   
         self.assertAlmostEqual(focal_loss_good, 0.0, places=3)\n\n            # with alpha\n            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)\n            focal_loss_good = float(loss(pred_very_good, target).cpu())\n            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)\n\n    def test_empty_class_2d(self):\n        for use_softmax in [True, False]:\n            num_classes = 2\n            # define 2d examples\n            target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])\n            # add another dimension corresponding to the batch (batch size = 1 here)\n            target = target.unsqueeze(0)  # shape (1, H, W)\n            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0\n\n            # initialize the mean dice loss\n            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)\n\n            # focal loss for pred_very_good should be close to 0\n            target = target.unsqueeze(1)  # shape (1, 1, H, W)\n            focal_loss_good = float(loss(pred_very_good, target).cpu())\n            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)\n\n            # with alpha\n            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)\n            focal_loss_good = float(loss(pred_very_good, target).cpu())\n            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)\n\n    def test_multi_class_seg_2d(self):\n        for use_softmax in [True, False]:\n            num_classes = 6  # labels 0 to 5\n            # define 2d examples\n            target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])\n            # add another dimension corresponding to the batch (batch size = 1 here)\n            target = target.unsqueeze(0)  # shape (1, H, W)\n            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0\n            # 
initialize the mean dice loss\n            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)\n            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)\n\n            # focal loss for pred_very_good should be close to 0\n            target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2)  # test one hot\n            target = target.unsqueeze(1)  # shape (1, 1, H, W)\n\n            focal_loss_good = float(loss(pred_very_good, target).cpu())\n            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)\n\n            focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())\n            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)\n\n            # with alpha\n            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)\n            focal_loss_good = float(loss(pred_very_good, target).cpu())\n            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)\n            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)\n            focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())\n            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)\n\n    def test_bin_seg_3d(self):\n        for use_softmax in [True, False]:\n            num_classes = 2  # labels 0, 1\n            # define 3d examples\n            target = torch.tensor(\n                [\n                    # raw 0\n                    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],\n                    # raw 1\n                    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],\n                    # raw 2\n                    [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],\n                ]\n            )\n            # add another dimension corresponding to the batch (batch size = 1 here)\n            target = target.unsqueeze(0)  # shape (1, H, W, D)\n            target_one_hot = F.one_hot(target, 
num_classes=num_classes).permute(0, 4, 1, 2, 3)  # test one hot\n            pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0\n\n            # initialize the mean dice loss\n            loss = FocalLoss(to_onehot_y=True, use_softmax=use_softmax)\n            loss_onehot = FocalLoss(to_onehot_y=False, use_softmax=use_softmax)\n\n            # focal loss for pred_very_good should be close to 0\n            target = target.unsqueeze(1)  # shape (1, 1, H, W)\n            focal_loss_good = float(loss(pred_very_good, target).cpu())\n            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)\n\n            focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())\n            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)\n\n            # with alpha\n            loss = FocalLoss(to_onehot_y=True, alpha=0.5, use_softmax=use_softmax)\n            focal_loss_good = float(loss(pred_very_good, target).cpu())\n            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)\n            loss_onehot = FocalLoss(to_onehot_y=False, alpha=0.5, use_softmax=use_softmax)\n            focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())\n            self.assertAlmostEqual(focal_loss_good, 0.0, places=3)\n\n    def test_foreground(self):\n        background = torch.ones(1, 1, 5, 5)\n        foreground = torch.zeros(1, 1, 5, 5)\n        target = torch.cat((background, foreground), dim=1)\n        input = torch.cat((background, foreground), dim=1)\n        target[:, 0, 2, 2] = 0\n        target[:, 1, 2, 2] = 1\n\n        fgbg = FocalLoss(to_onehot_y=False, include_background=True)(input, target)\n        fg = FocalLoss(to_onehot_y=False, include_background=False)(input, target)\n        self.assertAlmostEqual(float(fgbg.cpu()), 0.1116, places=3)\n        self.assertAlmostEqual(float(fg.cpu()), 0.1733, places=3)\n\n    def test_ill_opts(self):\n        chn_input = 
torch.ones((1, 2, 3))\n        chn_target = torch.ones((1, 2, 3))\n        with self.assertRaisesRegex(ValueError, \"\"):\n            FocalLoss(reduction=\"unknown\")(chn_input, chn_target)\n\n    def test_ill_shape(self):\n        chn_input = torch.ones((1, 2, 3))\n        chn_target = torch.ones((1, 3))\n        with self.assertRaisesRegex(ValueError, \"\"):\n            FocalLoss(reduction=\"mean\")(chn_input, chn_target)\n\n    def test_ill_class_weight(self):\n        chn_input = torch.ones((1, 4, 3, 3))\n        chn_target = torch.ones((1, 4, 3, 3))\n        with self.assertRaisesRegex(ValueError, \"\"):\n            FocalLoss(include_background=True, weight=(1.0, 1.0, 2.0))(chn_input, chn_target)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            FocalLoss(include_background=False, weight=(1.0, 1.0, 1.0, 1.0))(chn_input, chn_target)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            FocalLoss(include_background=False, weight=(1.0, 1.0, -1.0))(chn_input, chn_target)\n\n    def test_warnings(self):\n        with self.assertWarns(Warning):\n            chn_input = torch.ones((1, 1, 3))\n            chn_target = torch.ones((1, 1, 3))\n            loss = FocalLoss(to_onehot_y=True)\n            loss(chn_input, chn_target)\n        with self.assertWarns(Warning):\n            chn_input = torch.ones((1, 1, 3))\n            chn_target = torch.ones((1, 1, 3))\n            loss = FocalLoss(include_background=False)\n            loss(chn_input, chn_target)\n        with self.assertWarns(Warning):\n            chn_input = torch.ones((1, 3, 3))\n            chn_target = torch.ones((1, 3, 3))\n            loss = FocalLoss(include_background=False, use_softmax=True, alpha=0.5)\n            loss(chn_input, chn_target)\n\n    def test_script(self):\n        for use_softmax in [True, False]:\n            loss = FocalLoss(use_softmax=use_softmax)\n            test_input = torch.ones(2, 2, 8, 8)\n            test_script_save(loss, 
test_input, test_input)\n\n    @parameterized.expand(TEST_ALPHA_BROADCASTING)\n    def test_alpha_sequence_broadcasting(self, device, include_background, use_softmax):\n        \"\"\"\n        Test FocalLoss with alpha as a sequence for proper broadcasting.\n        \"\"\"\n        num_classes = 3\n        batch_size = 2\n        spatial_dims = (4, 4)\n\n        logits = torch.randn(batch_size, num_classes, *spatial_dims, device=device)\n        target = torch.randint(0, num_classes, (batch_size, 1, *spatial_dims), device=device)\n\n        if include_background:\n            alpha_seq = [0.1, 0.5, 2.0]\n        else:\n            alpha_seq = [0.5, 2.0]\n\n        loss_func = FocalLoss(\n            to_onehot_y=True,\n            gamma=2.0,\n            alpha=alpha_seq,\n            include_background=include_background,\n            use_softmax=use_softmax,\n            reduction=\"mean\",\n        )\n\n        result = loss_func(logits, target)\n\n        self.assertTrue(torch.is_tensor(result))\n        self.assertEqual(result.ndim, 0)\n        self.assertTrue(\n            result > 0, f\"Loss should be positive. params: dev={device}, bg={include_background}, softmax={use_softmax}\"\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_generalized_dice_focal_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.losses import FocalLoss, GeneralizedDiceFocalLoss, GeneralizedDiceLoss\nfrom tests.test_utils import test_script_save\n\n\nclass TestGeneralizedDiceFocalLoss(unittest.TestCase):\n    def test_result_onehot_target_include_bg(self):\n        size = [3, 3, 5, 5]\n        label = torch.randint(low=0, high=2, size=size)\n        pred = torch.randn(size)\n        for reduction in [\"sum\", \"mean\", \"none\"]:\n            common_params = {\"include_background\": True, \"to_onehot_y\": False, \"reduction\": reduction}\n            for weight in [None, torch.tensor([1.0, 1.0, 2.0]), (3, 2.0, 1)]:\n                for lambda_focal in [0.5, 1.0, 1.5]:\n                    generalized_dice_focal = GeneralizedDiceFocalLoss(\n                        weight=weight, gamma=1.0, lambda_focal=lambda_focal, **common_params\n                    )\n                    generalized_dice = GeneralizedDiceLoss(**common_params)\n                    focal = FocalLoss(weight=weight, gamma=1.0, **common_params)\n                    result = generalized_dice_focal(pred, label)\n                    expected_val = generalized_dice(pred, label) + lambda_focal * focal(pred, label)\n                    np.testing.assert_allclose(result, expected_val)\n\n    def test_result_no_onehot_no_bg(self):\n       
 size = [3, 3, 5, 5]\n        label = torch.randint(low=0, high=2, size=size)\n        label = torch.argmax(label, dim=1, keepdim=True)\n        pred = torch.randn(size)\n        for reduction in [\"sum\", \"mean\", \"none\"]:\n            common_params = {\"include_background\": False, \"to_onehot_y\": True, \"reduction\": reduction}\n            for weight in [2.0, torch.tensor([1.0, 2.0]), (2.0, 1)]:\n                for lambda_focal in [0.5, 1.0, 1.5]:\n                    generalized_dice_focal = GeneralizedDiceFocalLoss(\n                        weight=weight, lambda_focal=lambda_focal, **common_params\n                    )\n                    generalized_dice = GeneralizedDiceLoss(**common_params)\n                    focal = FocalLoss(weight=weight, **common_params)\n                    result = generalized_dice_focal(pred, label)\n                    expected_val = generalized_dice(pred, label) + lambda_focal * focal(pred, label)\n                    np.testing.assert_allclose(result, expected_val)\n\n    def test_ill_shape(self):\n        loss = GeneralizedDiceFocalLoss()\n        with self.assertRaises(AssertionError):\n            loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))\n\n    def test_ill_shape2(self):\n        loss = GeneralizedDiceFocalLoss()\n        with self.assertRaises(ValueError):\n            loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))\n\n    def test_ill_shape3(self):\n        loss = GeneralizedDiceFocalLoss()\n        with self.assertRaises(ValueError):\n            loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))\n\n    def test_ill_lambda(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            GeneralizedDiceFocalLoss(lambda_gdl=-1.0)\n\n    def test_script(self):\n        loss = GeneralizedDiceFocalLoss()\n        test_input = torch.ones(2, 1, 8, 8)\n        test_script_save(loss, test_input, test_input)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_generalized_dice_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import GeneralizedDiceLoss\nfrom tests.test_utils import test_script_save\n\nTEST_CASES = [\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},\n        0.307576,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),\n        },\n        0.416597,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4, \"soft_label\": True},\n        {\n            \"input\": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),\n            \"target\": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),\n        },\n        0.0,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        
{\"include_background\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4, \"soft_label\": False},\n        {\n            \"input\": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),\n            \"target\": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),\n        },\n        0.307748,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": False, \"to_onehot_y\": True, \"smooth_nr\": 0.0, \"smooth_dr\": 0.0},\n        {\n            \"input\": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n            \"target\": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),\n        },\n        0.0,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": True, \"to_onehot_y\": True, \"sigmoid\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        0.469964,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": True, \"to_onehot_y\": True, \"softmax\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        0.414507,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"softmax\": True,\n            \"reduction\": \"sum\",\n            \"smooth_nr\": 1e-4,\n            \"smooth_dr\": 1e-4,\n        },\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 
1.0, 0.0]]]),\n        },\n        0.829015,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"softmax\": True,\n            \"reduction\": \"none\",\n            \"smooth_nr\": 1e-4,\n            \"smooth_dr\": 1e-4,\n        },\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        [[[0.273476]], [[0.555539]]],\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": False, \"to_onehot_y\": True, \"smooth_nr\": 1e-8, \"smooth_dr\": 1e-8},\n        {\n            \"input\": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n            \"target\": torch.tensor([[[0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0]]]),\n        },\n        0.0,\n    ],\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},\n        0.307576,\n    ],\n    [  # shape: (1, 2, 4), (1, 1, 4)\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"softmax\": True,\n            \"w_type\": \"simple\",\n            \"smooth_nr\": 0,\n            \"smooth_dr\": 0,\n        },\n        {\n            \"input\": torch.tensor([[[0.0, 10.0, 10.0, 10.0], [10.0, 0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1, 1, 0, 0]]]),\n        },\n        0.250023,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"other_act\": torch.tanh, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 
1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),\n        },\n        0.99970,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"other_act\": lambda x: torch.log_softmax(x, dim=1),\n            \"smooth_nr\": 1e-4,\n            \"smooth_dr\": 1e-4,\n        },\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        -0.097833,\n    ],\n]\n\n\nclass TestGeneralizedDiceLoss(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_data, expected_val):\n        result = GeneralizedDiceLoss(**input_param).forward(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)\n\n    def test_ill_shape(self):\n        loss = GeneralizedDiceLoss()\n        with self.assertRaisesRegex(AssertionError, \"\"):\n            loss.forward(torch.ones((1, 2, 3)), torch.ones((4, 5, 6)))\n\n    def test_ill_opts(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            GeneralizedDiceLoss(sigmoid=True, softmax=True)\n        chn_input = torch.ones((1, 1, 3))\n        chn_target = torch.ones((1, 1, 3))\n        with self.assertRaisesRegex(ValueError, \"\"):\n            GeneralizedDiceLoss(reduction=\"unknown\")(chn_input, chn_target)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            GeneralizedDiceLoss(reduction=None)(chn_input, chn_target)\n\n    def test_input_warnings(self):\n        chn_input = torch.ones((1, 1, 3))\n        chn_target = torch.ones((1, 1, 3))\n        with self.assertWarns(Warning):\n            loss = GeneralizedDiceLoss(include_background=False)\n            loss.forward(chn_input, chn_target)\n        with 
self.assertWarns(Warning):\n            loss = GeneralizedDiceLoss(softmax=True)\n            loss.forward(chn_input, chn_target)\n        with self.assertWarns(Warning):\n            loss = GeneralizedDiceLoss(to_onehot_y=True)\n            loss.forward(chn_input, chn_target)\n\n    def test_differentiability(self):\n        prediction = torch.ones((1, 1, 1, 3))\n        target = torch.ones((1, 1, 1, 3))\n        prediction.requires_grad = True\n        target.requires_grad = True\n\n        generalized_dice_loss = GeneralizedDiceLoss()\n        loss = generalized_dice_loss(prediction, target)\n        self.assertIsNotNone(loss.grad_fn)\n\n    def test_batch(self):\n        prediction = torch.zeros(2, 3, 3, 3)\n        target = torch.zeros(2, 3, 3, 3)\n        prediction.requires_grad = True\n        target.requires_grad = True\n\n        generalized_dice_loss = GeneralizedDiceLoss(batch=True)\n        loss = generalized_dice_loss(prediction, target)\n        self.assertIsNotNone(loss.grad_fn)\n\n    def test_script(self):\n        loss = GeneralizedDiceLoss()\n        test_input = torch.ones(2, 1, 8, 8)\n        test_script_save(loss, test_input, test_input)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_generalized_wasserstein_dice_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\n\nfrom monai.losses import GeneralizedWassersteinDiceLoss\nfrom tests.test_utils import test_script_save\n\n\nclass TestGeneralizedWassersteinDiceLoss(unittest.TestCase):\n    def test_bin_seg_2d(self):\n        target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])\n\n        # add another dimension corresponding to the batch (batch size = 1 here)\n        target = target.unsqueeze(0)\n        pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()\n        pred_very_poor = 1000 * F.one_hot(1 - target, num_classes=2).permute(0, 3, 1, 2).float()\n\n        for weight_mode in [\"default\", \"GDL\"]:\n            # initialize the loss\n            loss = GeneralizedWassersteinDiceLoss(\n                dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=weight_mode\n            )\n\n            # the loss for pred_very_good should be close to 0\n            loss_good = float(loss.forward(pred_very_good, target))\n            self.assertAlmostEqual(loss_good, 0.0, places=3)\n\n            # same test, but with target with a class dimension\n            target_4dim = target.unsqueeze(1)\n            loss_good = 
float(loss.forward(pred_very_good, target_4dim))\n            self.assertAlmostEqual(loss_good, 0.0, places=3)\n\n            # the loss for pred_very_poor should be close to 1\n            loss_poor = float(loss.forward(pred_very_poor, target))\n            self.assertAlmostEqual(loss_poor, 1.0, places=3)\n\n    def test_different_target_data_type(self):\n        \"\"\"\n        Test if the loss is compatible with all the integer types\n        for the target segmentation.\n        \"\"\"\n        # define 2d examples\n        target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])\n\n        # add another dimension corresponding to the batch (batch size = 1 here)\n        target = target.unsqueeze(0)  # shape (1, H, W)\n        pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()\n\n        target_uint8 = target.to(torch.uint8)\n        target_int8 = target.to(torch.int8)\n        target_short = target.short()\n        target_int = target.int()\n        target_long = target.long()\n        target_list = [target_uint8, target_int8, target_short, target_int, target_long]\n\n        for w_mode in [\"default\", \"GDL\"]:\n            # initialize the loss\n            loss = GeneralizedWassersteinDiceLoss(dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode)\n\n            # The test should pass irrespectively of the integer type used\n            for t in target_list:\n                # the loss for pred_very_good should be close to 0\n                loss_good = float(loss.forward(pred_very_good, t))\n                self.assertAlmostEqual(loss_good, 0.0, places=3)\n\n    def test_empty_class_2d(self):\n        num_classes = 2\n        target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])\n\n        # add another dimension corresponding to the batch (batch size = 1 here)\n        target = target.unsqueeze(0)\n        pred_very_good = 1000 * F.one_hot(target, 
num_classes=num_classes).permute(0, 3, 1, 2).float()\n        pred_very_poor = 1000 * F.one_hot(1 - target, num_classes=num_classes).permute(0, 3, 1, 2).float()\n\n        for w_mode in [\"default\", \"GDL\"]:\n            # initialize the loss\n            loss = GeneralizedWassersteinDiceLoss(dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode)\n\n            # loss for pred_very_good should be close to 0\n            loss_good = float(loss.forward(pred_very_good, target))\n            self.assertAlmostEqual(loss_good, 0.0, places=3)\n\n            # loss for pred_very_poor should be close to 1\n            loss_poor = float(loss.forward(pred_very_poor, target))\n            self.assertAlmostEqual(loss_poor, 1.0, places=3)\n\n    def test_bin_seg_3d(self):\n        # define 3d examples\n        target = torch.tensor(\n            [\n                # raw 0\n                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],\n                # raw 1\n                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],\n                # raw 2\n                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],\n            ]\n        )\n\n        # add another dimension corresponding to the batch (batch size = 1 here)\n        target = target.unsqueeze(0)  # shape (1, H, W, D)\n        pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 4, 1, 2, 3).float()\n        pred_very_poor = 1000 * F.one_hot(1 - target, num_classes=2).permute(0, 4, 1, 2, 3).float()\n\n        for w_mode in [\"default\", \"GDL\"]:\n            # initialize the loss\n            loss = GeneralizedWassersteinDiceLoss(dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode)\n\n            # mean dice loss for pred_very_good should be close to 0\n            loss_good = float(loss.forward(pred_very_good, target))\n            self.assertAlmostEqual(loss_good, 0.0, places=3)\n\n            # mean dice loss for pred_very_poor should 
be close to 1\n            loss_poor = float(loss.forward(pred_very_poor, target))\n            self.assertAlmostEqual(loss_poor, 1.0, places=3)\n\n    def test_convergence(self):\n        \"\"\"\n        The goal of this test is to assess if the gradient of the loss function\n        is correct by testing if we can train a one layer neural network\n        to segment one image.\n        We verify that the loss is decreasing in almost all SGD steps.\n        \"\"\"\n        learning_rate = 0.001\n        max_iter = 50\n\n        # define a simple 3d example\n        target_seg = torch.tensor(\n            [\n                # raw 0\n                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],\n                # raw 1\n                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],\n                # raw 2\n                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],\n            ]\n        )\n        target_seg = torch.unsqueeze(target_seg, dim=0)\n        image = 12 * target_seg + 27  # dummy image to segment\n        image = image.float()\n        num_classes = 2\n        num_voxels = 3 * 4 * 4\n\n        # define a model with one layer\n        class OnelayerNet(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.layer = nn.Linear(num_voxels, num_voxels * num_classes)\n\n            def forward(self, x):\n                x = x.view(-1, num_voxels)\n                x = self.layer(x)\n                x = x.view(-1, num_classes, 3, 4, 4)\n                return x\n\n        for w_mode in [\"default\", \"GDL\"]:\n            # initialise the network\n            net = OnelayerNet()\n\n            # initialize the loss\n            loss = GeneralizedWassersteinDiceLoss(dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode)\n\n            # initialize an optimizer\n            optimizer = optim.Adam(net.parameters(), lr=learning_rate)\n\n            # initial 
difference between pred and target\n            pred_start = torch.argmax(net(image), dim=1)\n            diff_start = torch.norm(pred_start.float() - target_seg.float())\n\n            loss_history = []\n            # train the network\n            for _ in range(max_iter):\n                # set the gradient to zero\n                optimizer.zero_grad()\n\n                # forward pass\n                output = net(image)\n                loss_val = loss(output, target_seg)\n\n                # backward pass\n                loss_val.backward()\n                optimizer.step()\n\n                # stats\n                loss_history.append(loss_val.item())\n\n            # difference between pred and target after training\n            pred_end = torch.argmax(net(image), dim=1)\n            diff_end = torch.norm(pred_end.float() - target_seg.float())\n\n            # count the number of SGD steps in which the loss decreases\n            num_decreasing_steps = 0\n            for i in range(len(loss_history) - 1):\n                if loss_history[i] > loss_history[i + 1]:\n                    num_decreasing_steps += 1\n            decreasing_steps_ratio = float(num_decreasing_steps) / (len(loss_history) - 1)\n\n            # verify that the loss is decreasing for sufficiently many SGD steps\n            self.assertTrue(decreasing_steps_ratio > 0.75)\n\n            # check that the predicted segmentation has improved\n            self.assertGreater(diff_start, diff_end)\n\n    def test_batch_size_greater_than_one(self):\n        \"\"\"\n        Regression test for https://github.com/Project-MONAI/MONAI/issues/4650\n        With M=identity and batch_size > 1, the GWDL should produce the same\n        per-sample loss values as with batch_size=1.\n        \"\"\"\n        target_single = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])\n        target_single = target_single.unsqueeze(0)  # shape (1, H, W)\n        pred_single = 1000 * 
F.one_hot(target_single, num_classes=2).permute(0, 3, 1, 2).float()\n\n        # Create a batch of size 2 by repeating the same sample\n        target_batch = target_single.repeat(2, 1, 1)  # shape (2, H, W)\n        pred_batch = pred_single.repeat(2, 1, 1, 1)  # shape (2, C, H, W)\n\n        for w_mode in [\"default\", \"GDL\"]:\n            loss_fn = GeneralizedWassersteinDiceLoss(\n                dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction=\"none\"\n            )\n\n            loss_single = loss_fn(pred_single, target_single)\n            loss_batch = loss_fn(pred_batch, target_batch)\n\n            # Each sample in the batch should produce the same loss as the single sample\n            for i in range(2):\n                self.assertAlmostEqual(\n                    float(loss_batch[i]),\n                    float(loss_single[0]),\n                    places=5,\n                    msg=f\"Batch loss[{i}] != single loss for weighting_mode={w_mode}\",\n                )\n\n        # Also test with mean reduction using a non-trivial (poor) prediction\n        # so the expected loss is not near zero\n        pred_poor = 1000 * F.one_hot(1 - target_single, num_classes=2).permute(0, 3, 1, 2).float()\n        pred_poor_batch = pred_poor.repeat(2, 1, 1, 1)\n\n        for w_mode in [\"default\", \"GDL\"]:\n            loss_fn = GeneralizedWassersteinDiceLoss(\n                dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction=\"mean\"\n            )\n\n            loss_single = float(loss_fn(pred_poor, target_single))\n            loss_batch = float(loss_fn(pred_poor_batch, target_batch))\n\n            # Verify the loss is non-trivial (close to 1 for poor predictions)\n            self.assertGreater(loss_single, 0.5, msg=f\"Expected non-trivial loss for weighting_mode={w_mode}\")\n            self.assertAlmostEqual(\n                loss_batch,\n                loss_single,\n                
places=5,\n                msg=f\"Batch mean loss != single mean loss for weighting_mode={w_mode}\",\n            )\n\n    def test_batch_size_different_samples(self):\n        \"\"\"\n        Regression test for https://github.com/Project-MONAI/MONAI/issues/4650\n        Verify loss is computed correctly when batch contains different samples.\n        \"\"\"\n        target_a = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]).unsqueeze(0)\n        target_b = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]).unsqueeze(0)\n\n        pred_a = 1000 * F.one_hot(target_a, num_classes=2).permute(0, 3, 1, 2).float()\n        # Use a poor prediction for sample b so its loss is non-trivial (~1.0)\n        pred_b = 1000 * F.one_hot(1 - target_b, num_classes=2).permute(0, 3, 1, 2).float()\n\n        # Combine into a batch\n        target_batch = torch.cat([target_a, target_b], dim=0)\n        pred_batch = torch.cat([pred_a, pred_b], dim=0)\n\n        for w_mode in [\"default\", \"GDL\"]:\n            loss_fn = GeneralizedWassersteinDiceLoss(\n                dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction=\"none\"\n            )\n\n            loss_a = float(loss_fn(pred_a, target_a))\n            loss_b = float(loss_fn(pred_b, target_b))\n            loss_batch = loss_fn(pred_batch, target_batch)\n\n            self.assertAlmostEqual(\n                float(loss_batch[0]), loss_a, places=5, msg=f\"Batch loss[0] != loss_a for weighting_mode={w_mode}\"\n            )\n            self.assertAlmostEqual(\n                float(loss_batch[1]), loss_b, places=5, msg=f\"Batch loss[1] != loss_b for weighting_mode={w_mode}\"\n            )\n\n    def test_script(self):\n        target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])\n\n        # add another dimension corresponding to the batch (batch size = 1 here)\n        target = target.unsqueeze(0)\n        pred_very_good = 1000 * 
F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()\n\n        loss = GeneralizedWassersteinDiceLoss(dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=\"default\")\n\n        test_script_save(loss, pred_very_good, target)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_giou_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import BoxGIoULoss\n\nTEST_CASES = [\n    [  # shape: (1, 4), (1, 4)\n        {\"input\": torch.tensor([[1.0, 1.0, 2.0, 2.0]]), \"target\": torch.tensor([[1.0, 1.0, 2.0, 2.0]])},\n        0.0,\n    ],\n    [  # shape: (1, 6), (1, 6)\n        {\n            \"input\": torch.tensor([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0]]),\n            \"target\": torch.tensor([[1.0, 1.0, 1.0, 2.0, 2.0, 2.0]]),\n        },\n        0.0,\n    ],\n]\n\n\nclass TestGIoULoss(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_result(self, input_data, expected_val):\n        loss = BoxGIoULoss()\n        result = loss(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n\n    def test_ill_shape(self):\n        loss = BoxGIoULoss()\n        with self.assertRaisesRegex(ValueError, \"\"):\n            loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))\n\n    def test_with_cuda(self):\n        loss = BoxGIoULoss()\n        i = torch.tensor([[1.0, 1.0, 2.0, 2.0]])\n        j = torch.tensor([[1.0, 1.0, 2.0, 2.0]])\n        if torch.cuda.is_available():\n            i = i.cuda()\n            j = j.cuda()\n        output = loss(i, j)\n        
np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_hausdorff_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest.case import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import HausdorffDTLoss, LogHausdorffDTLoss\nfrom monai.utils import optional_import\n\n_, has_scipy = optional_import(\"scipy\")\n\nTEST_CASES = []\nfor device in [\"cpu\", \"cuda\"] if torch.cuda.is_available() else [\"cpu\"]:\n    TEST_CASES.append(\n        [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n            {\"include_background\": True, \"sigmoid\": True},\n            {\n                \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device),\n                \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device),\n            },\n            0.509329,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (1, 1, 1, 2, 2), (1, 1, 1, 2, 2)\n            {\"include_background\": True, \"sigmoid\": True},\n            {\n                \"input\": torch.tensor([[[[[1.0, -1.0], [-1.0, 1.0]]]]], device=device),\n                \"target\": torch.tensor([[[[[1.0, 0.0], [1.0, 1.0]]]]], device=device),\n            },\n            0.509329,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (1, 1, 2, 2, 2), (1, 1, 2, 2, 2)\n            {\"include_background\": True, \"sigmoid\": True},\n            {\n                \"input\": 
torch.tensor([[[[[1.0, -1.0], [1.0, -1.0]], [[-1.0, 1.0], [-1.0, 1.0]]]]], device=device),\n                \"target\": torch.tensor([[[[[1.0, 0.0], [1.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]]]], device=device),\n            },\n            0.375718,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (1, 2, 2, 2), (1, 2, 2, 2)\n            {\"include_background\": True, \"sigmoid\": True},\n            {\n                \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]], device=device),\n                \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]], device=device),\n            },\n            0.326994,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n            {\"include_background\": True, \"sigmoid\": True},\n            {\n                \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device),\n                \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device),\n            },\n            0.455082,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (1, 2, 2, 2), (1, 2, 2, 2)\n            {\"include_background\": False, \"sigmoid\": True},\n            {\n                \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]], [[1.0, -1.0], [-1.0, 1.0]]]], device=device),\n                \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]]], device=device),\n            },\n            0.144659,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (2, 2, 3, 1), (2, 1, 3, 1)\n            {\"include_background\": True, \"to_onehot_y\": True, \"sigmoid\": True, \"reduction\": \"none\"},\n            {\n                \"input\": torch.tensor(\n                    [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]],\n                    device=device,\n               
 ),\n                \"target\": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device),\n            },\n            [[[[0.407765]], [[0.407765]]], [[[0.5000]], [[0.5000]]]],\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (2, 2, 3, 1), (2, 1, 3, 1)\n            {\"include_background\": True, \"to_onehot_y\": True, \"softmax\": True},\n            {\n                \"input\": torch.tensor(\n                    [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]],\n                    device=device,\n                ),\n                \"target\": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device),\n            },\n            0.357016,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (2, 2, 3, 1), (2, 1, 3, 1)\n            {\"include_background\": True, \"to_onehot_y\": True, \"softmax\": True, \"reduction\": \"sum\"},\n            {\n                \"input\": torch.tensor(\n                    [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]],\n                    device=device,\n                ),\n                \"target\": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device),\n            },\n            1.428062,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n            {\"include_background\": True, \"sigmoid\": True},\n            {\n                \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device),\n                \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device),\n            },\n            0.509329,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n            {\"include_background\": True, \"other_act\": torch.tanh},\n            {\n                \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], 
[-1.0, 1.0]]]], device=device),\n                \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device),\n            },\n            1.870039,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (2, 2, 3, 1), (2, 1, 3, 1)\n            {\"include_background\": True, \"to_onehot_y\": True, \"other_act\": lambda x: torch.log_softmax(x, dim=1)},\n            {\n                \"input\": torch.tensor(\n                    [[[[-1.0], [0.0], [1.0]], [[1.0], [0.0], [-1.0]]], [[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]],\n                    device=device,\n                ),\n                \"target\": torch.tensor([[[[1.0], [0.0], [0.0]]], [[[1.0], [1.0], [0.0]]]], device=device),\n            },\n            4.366613,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n            {\"include_background\": True, \"other_act\": torch.tanh, \"batch\": True},\n            {\n                \"input\": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device),\n                \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device),\n            },\n            1.607137,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n            {\"include_background\": True, \"other_act\": torch.tanh, \"batch\": True},\n            {\n                \"input\": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]], device=device),\n                \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device),\n            },\n            1.607137,\n        ]\n    )\n    TEST_CASES.append(\n        [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n            {\"include_background\": True, \"other_act\": torch.tanh, \"batch\": False},\n            {\n                \"input\": torch.tensor([[[[1.0, -0.0], [-1.0, 1.0]]], [[[1.0, -1.0], 
[-1.0, 1.0]]]], device=device),\n                \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]], device=device),\n            },\n            1.607137,\n        ]\n    )\n\nTEST_CASES_LOG = [[*inputs, np.log(np.array(output) + 1)] for *inputs, output in TEST_CASES]\n\n\ndef _describe_test_case(test_func, test_number, params):\n    input_param, input_data, _ = params.args\n    return f\"params:{input_param}, shape:{input_data['input'].shape}, device:{input_data['input'].device}\"\n\n\n@skipUnless(has_scipy, \"Scipy required\")\nclass TestHausdorffDTLoss(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES, doc_func=_describe_test_case)\n    def test_shape(self, input_param, input_data, expected_val):\n        result = HausdorffDTLoss(**input_param).forward(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)\n\n    def test_ill_shape(self):\n        loss = HausdorffDTLoss()\n        with self.assertRaisesRegex(AssertionError, \"\"):\n            loss.forward(torch.ones((1, 1, 2, 3)), torch.ones((1, 4, 5, 6)))\n\n    def test_ill_opts(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            HausdorffDTLoss(sigmoid=True, softmax=True)\n        chn_input = torch.ones((1, 1, 3))\n        chn_target = torch.ones((1, 1, 3))\n        with self.assertRaisesRegex(ValueError, \"\"):\n            HausdorffDTLoss(reduction=\"unknown\")(chn_input, chn_target)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            HausdorffDTLoss(reduction=None)(chn_input, chn_target)\n\n    @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])\n    def test_input_warnings(self, include_background, softmax, to_onehot_y):\n        chn_input = torch.ones((1, 1, 1, 3))\n        chn_target = torch.ones((1, 1, 1, 3))\n        with self.assertWarns(Warning):\n            loss = HausdorffDTLoss(include_background=include_background, 
softmax=softmax, to_onehot_y=to_onehot_y)\n            loss.forward(chn_input, chn_target)\n\n\n@skipUnless(has_scipy, \"Scipy required\")\nclass TesLogtHausdorffDTLoss(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_LOG, doc_func=_describe_test_case)\n    def test_shape(self, input_param, input_data, expected_val):\n        result = LogHausdorffDTLoss(**input_param).forward(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)\n\n    def test_ill_shape(self):\n        loss = LogHausdorffDTLoss()\n        with self.assertRaisesRegex(AssertionError, \"\"):\n            loss.forward(torch.ones((1, 1, 2, 3)), torch.ones((1, 4, 5, 6)))\n\n    def test_ill_opts(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            LogHausdorffDTLoss(sigmoid=True, softmax=True)\n        chn_input = torch.ones((1, 1, 3))\n        chn_target = torch.ones((1, 1, 3))\n        with self.assertRaisesRegex(ValueError, \"\"):\n            LogHausdorffDTLoss(reduction=\"unknown\")(chn_input, chn_target)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            LogHausdorffDTLoss(reduction=None)(chn_input, chn_target)\n\n    @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])\n    def test_input_warnings(self, include_background, softmax, to_onehot_y):\n        chn_input = torch.ones((1, 1, 1, 3))\n        chn_target = torch.ones((1, 1, 1, 3))\n        with self.assertWarns(Warning):\n            loss = LogHausdorffDTLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)\n            loss.forward(chn_input, chn_target)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_masked_dice_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import MaskedDiceLoss\n\nTEST_CASES = [\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\n            \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),\n            \"mask\": torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]]),\n        },\n        0.333333,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),\n            \"mask\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 1.0], [0.0, 0.0]]]]),\n        },\n        0.301128,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": False, \"to_onehot_y\": True, \"smooth_nr\": 0, \"smooth_dr\": 0},\n        {\n            \"input\": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n            \"target\": torch.tensor([[[0.0, 0.0, 
1.0]], [[0.0, 1.0, 0.0]]]),\n            \"mask\": torch.tensor([[[1.0, 1.0, 1.0]], [[0.0, 1.0, 0.0]]]),\n        },\n        0.0,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": True, \"to_onehot_y\": True, \"sigmoid\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n            \"mask\": torch.tensor([[[1.0, 1.0, 0.0]]]),\n        },\n        0.579184,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"sigmoid\": True,\n            \"reduction\": \"none\",\n            \"smooth_nr\": 1e-4,\n            \"smooth_dr\": 1e-4,\n        },\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        [[[0.296529], [0.415136]], [[0.599976], [0.428559]]],\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": True, \"to_onehot_y\": True, \"softmax\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        0.383713,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"softmax\": True,\n            \"reduction\": \"sum\",\n            \"smooth_nr\": 1e-4,\n            \"smooth_dr\": 1e-4,\n        },\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            
\"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        1.534853,\n    ],\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},\n        0.307576,\n    ],\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"squared_pred\": True, \"smooth_nr\": 1e-5, \"smooth_dr\": 1e-5},\n        {\"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},\n        0.178337,\n    ],\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"jaccard\": True, \"smooth_nr\": 1e-5, \"smooth_dr\": 1e-5},\n        {\"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},\n        0.470451,\n    ],\n]\n\n\nclass TestDiceLoss(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_data, expected_val):\n        result = MaskedDiceLoss(**input_param).forward(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)\n\n    def test_ill_shape(self):\n        loss = MaskedDiceLoss()\n        with self.assertRaisesRegex(AssertionError, \"\"):\n            loss.forward(torch.ones((1, 2, 3)), torch.ones((4, 5, 6)))\n\n    def test_ill_opts(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            MaskedDiceLoss(sigmoid=True, softmax=True)\n        chn_input = torch.ones((1, 1, 3))\n        chn_target = torch.ones((1, 1, 3))\n        with self.assertRaisesRegex(ValueError, \"\"):\n            MaskedDiceLoss(reduction=\"unknown\")(chn_input, chn_target)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            
MaskedDiceLoss(reduction=None)(chn_input, chn_target)\n\n    def test_input_warnings(self):\n        chn_input = torch.ones((1, 1, 3))\n        chn_target = torch.ones((1, 1, 3))\n        with self.assertWarns(Warning):\n            loss = MaskedDiceLoss(include_background=False)\n            loss.forward(chn_input, chn_target)\n        with self.assertWarns(Warning):\n            loss = MaskedDiceLoss(softmax=True)\n            loss.forward(chn_input, chn_target)\n        with self.assertWarns(Warning):\n            loss = MaskedDiceLoss(to_onehot_y=True)\n            loss.forward(chn_input, chn_target)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_masked_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses.dice import DiceFocalLoss, DiceLoss\nfrom monai.losses.spatial_mask import MaskedLoss\nfrom monai.utils import set_determinism\nfrom tests.test_utils import test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASES = [\n    [\n        {\n            \"loss\": DiceFocalLoss,\n            \"weight\": torch.tensor([1.0, 1.0, 2.0]),\n            \"gamma\": 0.1,\n            \"lambda_focal\": 0.5,\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"reduction\": \"sum\",\n        },\n        [17.1679, 15.5623],\n    ]\n]\n\n\nclass TestMaskedLoss(unittest.TestCase):\n    def setUp(self):\n        set_determinism(0)\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, expected_val):\n        size = [3, 3, 5, 5]\n        label = torch.randint(low=0, high=2, size=size)\n        label = torch.argmax(label, dim=1, keepdim=True)\n        pred = torch.randn(size)\n        result = MaskedLoss(**input_param)(pred, label, None)\n        out = result.detach().cpu().numpy()\n        self.assertTrue(np.allclose(out, expected_val[0]))\n\n        mask = torch.randint(low=0, high=2, 
size=label.shape)\n        result = MaskedLoss(**input_param)(pred, label, mask)\n        out = result.detach().cpu().numpy()\n        self.assertTrue(np.allclose(out, expected_val[1]))\n\n    def test_ill_opts(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            MaskedLoss(loss=[])\n\n        dice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            masked = MaskedLoss(loss=dice_loss)\n            masked(input=torch.zeros((3, 1, 2, 2)), target=torch.zeros((3, 1, 2, 2)), mask=torch.zeros((3, 3, 2, 2)))\n        with self.assertRaisesRegex(ValueError, \"\"):\n            masked = MaskedLoss(loss=dice_loss)\n            masked(input=torch.zeros((3, 3, 2, 2)), target=torch.zeros((3, 2, 2, 2)), mask=torch.zeros((3, 3, 2, 2)))\n\n    def test_script(self):\n        input_param, expected_val = TEST_CASES[0]\n        size = [3, 3, 5, 5]\n        label = torch.randint(low=0, high=2, size=size)\n        label = torch.argmax(label, dim=1, keepdim=True)\n        pred = torch.randn(size)\n        loss = MaskedLoss(**input_param)\n        test_script_save(loss, pred, label)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_multi_scale.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import DiceLoss\nfrom monai.losses.multi_scale import MultiScaleLoss\nfrom tests.test_utils import test_script_save\n\ndice_loss = DiceLoss(include_background=True, sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5)\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASES = [\n    [\n        {\"loss\": dice_loss, \"scales\": None, \"kernel\": \"gaussian\"},\n        {\n            \"y_pred\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device),\n            \"y_true\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device),\n        },\n        0.307576,\n    ],\n    [\n        {\"loss\": dice_loss, \"scales\": [0, 1], \"kernel\": \"gaussian\"},\n        {\n            \"y_pred\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]], device=device),\n            \"y_true\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=device),\n        },\n        0.463116,\n    ],\n    [\n        {\"loss\": dice_loss, \"scales\": [0, 1, 2], \"kernel\": \"cauchy\"},\n        {\n            \"y_pred\": torch.tensor([[[[[1.0, -1.0], [-1.0, 1.0]]]]], device=device),\n            \"y_true\": torch.tensor([[[[[1.0, 0.0], [1.0, 1.0]]]]], device=device),\n        },\n        0.715228,\n    ],\n]\n\n\nclass 
TestMultiScale(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_data, expected_val):\n        result = MultiScaleLoss(**input_param).forward(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)\n\n    @parameterized.expand(\n        [\n            ({\"loss\": dice_loss, \"kernel\": \"none\"}, None, None),  # kernel_none\n            ({\"loss\": dice_loss, \"scales\": [-1]}, torch.ones((1, 1, 3)), torch.ones((1, 1, 3))),  # scales_negative\n            (\n                {\"loss\": dice_loss, \"scales\": [-1], \"reduction\": \"none\"},\n                torch.ones((1, 1, 3)),\n                torch.ones((1, 1, 3)),\n            ),  # scales_negative_reduction_none\n        ]\n    )\n    def test_ill_opts(self, kwargs, input, target):\n        if input is None and target is None:\n            with self.assertRaisesRegex(ValueError, \"\"):\n                MultiScaleLoss(**kwargs)\n        else:\n            with self.assertRaisesRegex(ValueError, \"\"):\n                MultiScaleLoss(**kwargs)(input, target)\n\n    def test_script(self):\n        input_param, input_data, expected_val = TEST_CASES[0]\n        loss = MultiScaleLoss(**input_param)\n        test_script_save(loss, input_data[\"y_pred\"], input_data[\"y_true\"])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_nacl_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import NACLLoss\n\ninputs = torch.tensor(\n    [\n        [\n            [\n                [0.1498, 0.1158, 0.3996, 0.3730],\n                [0.2155, 0.1585, 0.8541, 0.8579],\n                [0.6640, 0.2424, 0.0774, 0.0324],\n                [0.0580, 0.2180, 0.3447, 0.8722],\n            ],\n            [\n                [0.3908, 0.9366, 0.1779, 0.1003],\n                [0.9630, 0.6118, 0.4405, 0.7916],\n                [0.5782, 0.9515, 0.4088, 0.3946],\n                [0.7860, 0.3910, 0.0324, 0.9568],\n            ],\n            [\n                [0.0759, 0.0238, 0.5570, 0.1691],\n                [0.2703, 0.7722, 0.1611, 0.6431],\n                [0.8051, 0.6596, 0.4121, 0.1125],\n                [0.5283, 0.6746, 0.5528, 0.7913],\n            ],\n        ]\n    ]\n)\ntargets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]])\n\nTEST_CASES = [\n    [{\"classes\": 3, \"dim\": 2}, {\"inputs\": inputs, \"targets\": targets}, 1.1442],\n    [{\"classes\": 3, \"dim\": 2}, {\"inputs\": inputs.repeat(4, 1, 1, 1), \"targets\": targets.repeat(4, 1, 1)}, 1.1442],\n    [{\"classes\": 3, \"dim\": 2, \"kernel_ops\": \"gaussian\"}, {\"inputs\": inputs, \"targets\": targets}, 1.1433],\n    [{\"classes\": 3, 
\"dim\": 2, \"kernel_ops\": \"gaussian\", \"sigma\": 0.5}, {\"inputs\": inputs, \"targets\": targets}, 1.1469],\n    [{\"classes\": 3, \"dim\": 2, \"distance_type\": \"l2\"}, {\"inputs\": inputs, \"targets\": targets}, 1.1269],\n    [{\"classes\": 3, \"dim\": 2, \"alpha\": 0.2}, {\"inputs\": inputs, \"targets\": targets}, 1.1790],\n    [\n        {\"classes\": 3, \"dim\": 3, \"kernel_ops\": \"gaussian\"},\n        {\n            \"inputs\": torch.tensor(\n                [\n                    [\n                        [\n                            [\n                                [0.5977, 0.2767, 0.0591, 0.1675],\n                                [0.4835, 0.3778, 0.8406, 0.3065],\n                                [0.6047, 0.2860, 0.9742, 0.2013],\n                                [0.9128, 0.8368, 0.6711, 0.4384],\n                            ],\n                            [\n                                [0.9797, 0.1863, 0.5584, 0.6652],\n                                [0.2272, 0.2004, 0.7914, 0.4224],\n                                [0.5097, 0.8818, 0.2581, 0.3495],\n                                [0.1054, 0.5483, 0.3732, 0.3587],\n                            ],\n                            [\n                                [0.3060, 0.7066, 0.7922, 0.4689],\n                                [0.1733, 0.8902, 0.6704, 0.2037],\n                                [0.8656, 0.5561, 0.2701, 0.0092],\n                                [0.1866, 0.7714, 0.6424, 0.9791],\n                            ],\n                            [\n                                [0.5067, 0.3829, 0.6156, 0.8985],\n                                [0.5192, 0.8347, 0.2098, 0.2260],\n                                [0.8887, 0.3944, 0.6400, 0.5345],\n                                [0.1207, 0.3763, 0.5282, 0.7741],\n                            ],\n                        ],\n                        [\n                            [\n                                [0.8499, 0.4759, 0.1964, 
0.5701],\n                                [0.3190, 0.1238, 0.2368, 0.9517],\n                                [0.0797, 0.6185, 0.0135, 0.8672],\n                                [0.4116, 0.1683, 0.1355, 0.0545],\n                            ],\n                            [\n                                [0.7533, 0.2658, 0.5955, 0.4498],\n                                [0.9500, 0.2317, 0.2825, 0.9763],\n                                [0.1493, 0.1558, 0.3743, 0.8723],\n                                [0.1723, 0.7980, 0.8816, 0.0133],\n                            ],\n                            [\n                                [0.8426, 0.2666, 0.2077, 0.3161],\n                                [0.1725, 0.8414, 0.1515, 0.2825],\n                                [0.4882, 0.5159, 0.4120, 0.1585],\n                                [0.2551, 0.9073, 0.7691, 0.9898],\n                            ],\n                            [\n                                [0.4633, 0.8717, 0.8537, 0.2899],\n                                [0.3693, 0.7953, 0.1183, 0.4596],\n                                [0.0087, 0.7925, 0.0989, 0.8385],\n                                [0.8261, 0.6920, 0.7069, 0.4464],\n                            ],\n                        ],\n                        [\n                            [\n                                [0.0110, 0.1608, 0.4814, 0.6317],\n                                [0.0194, 0.9669, 0.3259, 0.0028],\n                                [0.5674, 0.8286, 0.0306, 0.5309],\n                                [0.3973, 0.8183, 0.0238, 0.1934],\n                            ],\n                            [\n                                [0.8947, 0.6629, 0.9439, 0.8905],\n                                [0.0072, 0.1697, 0.4634, 0.0201],\n                                [0.7184, 0.2424, 0.0820, 0.7504],\n                                [0.3937, 0.1424, 0.4463, 0.5779],\n                            ],\n                            [\n                
                [0.4123, 0.6227, 0.0523, 0.8826],\n                                [0.0051, 0.0353, 0.3662, 0.7697],\n                                [0.4867, 0.8986, 0.2510, 0.5316],\n                                [0.1856, 0.2634, 0.9140, 0.9725],\n                            ],\n                            [\n                                [0.2041, 0.4248, 0.2371, 0.7256],\n                                [0.2168, 0.5380, 0.4538, 0.7007],\n                                [0.9013, 0.2623, 0.0739, 0.2998],\n                                [0.1366, 0.5590, 0.2952, 0.4592],\n                            ],\n                        ],\n                    ]\n                ]\n            ),\n            \"targets\": torch.tensor(\n                [\n                    [\n                        [[0, 1, 0, 1], [1, 2, 1, 0], [2, 1, 1, 1], [1, 1, 0, 1]],\n                        [[2, 1, 0, 2], [1, 2, 0, 2], [1, 0, 1, 1], [1, 1, 0, 0]],\n                        [[1, 0, 2, 1], [0, 2, 2, 1], [1, 0, 1, 1], [0, 0, 2, 1]],\n                        [[2, 1, 1, 0], [1, 0, 0, 2], [1, 0, 2, 1], [2, 1, 0, 1]],\n                    ]\n                ]\n            ),\n        },\n        1.15035,\n    ],\n]\n\n\nclass TestNACLLoss(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_result(self, input_param, input_data, expected_val):\n        loss = NACLLoss(**input_param)\n        result = loss(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_perceptual_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import PerceptualLoss\nfrom monai.utils import optional_import\nfrom tests.test_utils import assert_allclose, skip_if_downloading_fails, skip_if_quick\n\n_, has_torchvision = optional_import(\"torchvision\")\nTEST_CASES = [\n    [{\"spatial_dims\": 2, \"network_type\": \"squeeze\"}, (2, 1, 64, 64), (2, 1, 64, 64)],\n    [\n        {\"spatial_dims\": 3, \"network_type\": \"squeeze\", \"is_fake_3d\": True, \"fake_3d_ratio\": 0.1},\n        (2, 1, 64, 64, 64),\n        (2, 1, 64, 64, 64),\n    ],\n    [{\"spatial_dims\": 2, \"network_type\": \"radimagenet_resnet50\"}, (2, 1, 64, 64), (2, 1, 64, 64)],\n    [{\"spatial_dims\": 2, \"network_type\": \"radimagenet_resnet50\"}, (2, 3, 64, 64), (2, 3, 64, 64)],\n    [\n        {\"spatial_dims\": 3, \"network_type\": \"radimagenet_resnet50\", \"is_fake_3d\": True, \"fake_3d_ratio\": 0.1},\n        (2, 1, 64, 64, 64),\n        (2, 1, 64, 64, 64),\n    ],\n    [\n        {\"spatial_dims\": 3, \"network_type\": \"medicalnet_resnet10_23datasets\", \"is_fake_3d\": False},\n        (2, 1, 64, 64, 64),\n        (2, 1, 64, 64, 64),\n    ],\n    [\n        {\"spatial_dims\": 3, \"network_type\": \"medicalnet_resnet10_23datasets\", \"is_fake_3d\": False},\n        (2, 6, 64, 64, 64),\n        (2, 6, 64, 64, 
64),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"network_type\": \"medicalnet_resnet10_23datasets\",\n            \"is_fake_3d\": False,\n            \"channel_wise\": True,\n        },\n        (2, 6, 64, 64, 64),\n        (2, 6, 64, 64, 64),\n    ],\n    [\n        {\"spatial_dims\": 3, \"network_type\": \"medicalnet_resnet50_23datasets\", \"is_fake_3d\": False},\n        (2, 1, 64, 64, 64),\n        (2, 1, 64, 64, 64),\n    ],\n    [\n        {\"spatial_dims\": 3, \"network_type\": \"medicalnet_resnet50_23datasets\", \"is_fake_3d\": False},\n        (2, 6, 64, 64, 64),\n        (2, 6, 64, 64, 64),\n    ],\n    [\n        {\"spatial_dims\": 3, \"network_type\": \"resnet50\", \"is_fake_3d\": True, \"pretrained\": True, \"fake_3d_ratio\": 0.2},\n        (2, 1, 64, 64, 64),\n        (2, 1, 64, 64, 64),\n    ],\n]\n\n\n@unittest.skipUnless(has_torchvision, \"Requires torchvision\")\n@skip_if_quick\nclass TestPerceptualLoss(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_shape, target_shape):\n        with skip_if_downloading_fails():\n            loss = PerceptualLoss(**input_param)\n        result = loss(torch.randn(input_shape), torch.randn(target_shape))\n\n        if \"channel_wise\" in input_param.keys() and input_param[\"channel_wise\"]:\n            self.assertEqual(result.shape, torch.Size([input_shape[1]]))\n        else:\n            self.assertEqual(result.shape, torch.Size([]))\n\n    @parameterized.expand(TEST_CASES)\n    def test_identical_input(self, input_param, input_shape, target_shape):\n        with skip_if_downloading_fails():\n            loss = PerceptualLoss(**input_param)\n        tensor = torch.randn(input_shape)\n        result = loss(tensor, tensor)\n\n        if \"channel_wise\" in input_param.keys() and input_param[\"channel_wise\"]:\n            assert_allclose(result, torch.Tensor([0.0] * input_shape[1]))\n        else:\n            
self.assertEqual(result, torch.Tensor([0.0]))\n\n    def test_different_shape(self):\n        with skip_if_downloading_fails():\n            loss = PerceptualLoss(spatial_dims=2, network_type=\"squeeze\")\n        tensor = torch.randn(2, 1, 64, 64)\n        target = torch.randn(2, 1, 32, 32)\n        with self.assertRaises(ValueError):\n            loss(tensor, target)\n\n    def test_1d(self):\n        with self.assertRaises(NotImplementedError):\n            PerceptualLoss(spatial_dims=1)\n\n    @parameterized.expand([\"medicalnet_resnet10_23datasets\", \"medicalnet_resnet50_23datasets\"])\n    def test_medicalnet_on_2d_data(self, network_type):\n        with self.assertRaises(ValueError):\n            PerceptualLoss(spatial_dims=2, network_type=network_type)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_spectral_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import JukeboxLoss\nfrom tests.test_utils import test_script_save\n\nTEST_CASES = [\n    [\n        {\"spatial_dims\": 2},\n        {\n            \"input\": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n        },\n        0.070648,\n    ],\n    [\n        {\"spatial_dims\": 2, \"reduction\": \"sum\"},\n        {\n            \"input\": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),\n        },\n        0.8478,\n    ],\n    [\n        {\"spatial_dims\": 3},\n        {\n            \"input\": torch.tensor(\n                [\n                    [\n                        [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]],\n                        [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]],\n                    ]\n                ]\n            ),\n            \"target\": torch.tensor(\n                [\n                    [\n                        [[[1.0, 
0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]],\n                        [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]],\n                    ]\n                ]\n            ),\n        },\n        0.03838,\n    ],\n]\n\n\nclass TestJukeboxLoss(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_results(self, input_param, input_data, expected_val):\n        results = JukeboxLoss(**input_param).forward(**input_data)\n        np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4)\n\n    def test_2d_shape(self):\n        results = JukeboxLoss(spatial_dims=2, reduction=\"none\").forward(**TEST_CASES[0][1])\n        self.assertEqual(results.shape, (1, 2, 2, 3))\n\n    def test_3d_shape(self):\n        results = JukeboxLoss(spatial_dims=3, reduction=\"none\").forward(**TEST_CASES[2][1])\n        self.assertEqual(results.shape, (1, 2, 2, 2, 3))\n\n    def test_script(self):\n        loss = JukeboxLoss(spatial_dims=2)\n        test_input = torch.ones(2, 1, 8, 8)\n        test_script_save(loss, test_input, test_input)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_ssim_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.losses.ssim_loss import SSIMLoss\nfrom monai.utils import set_determinism\n\n# from tests.utils import test_script_save\n\n\nclass TestSSIMLoss(unittest.TestCase):\n\n    def test_shape(self):\n        set_determinism(0)\n        preds = torch.abs(torch.randn(2, 3, 16, 16))\n        target = torch.abs(torch.randn(2, 3, 16, 16))\n        preds = preds / preds.max()\n        target = target / target.max()\n\n        result = SSIMLoss(spatial_dims=2, data_range=1.0, kernel_type=\"gaussian\", reduction=\"mean\").forward(\n            preds, target\n        )\n        expected_val = 0.9546\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)\n\n        result = SSIMLoss(spatial_dims=2, data_range=1.0, kernel_type=\"gaussian\", reduction=\"sum\").forward(\n            preds, target\n        )\n        expected_val = 1.9092\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)\n\n        result = SSIMLoss(spatial_dims=2, data_range=1.0, kernel_type=\"gaussian\", reduction=\"none\").forward(\n            preds, target\n        )\n        expected_val = [[0.9121], [0.9971]]\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)\n\n    # def test_script(self):\n    #  
   loss = SSIMLoss(spatial_dims=2)\n    #     test_input = torch.ones(2, 2, 16, 16)\n    #     test_script_save(loss, test_input, test_input)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_sure_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\n\nfrom monai.losses import SURELoss\n\n\nclass TestSURELoss(unittest.TestCase):\n\n    def test_real_value(self):\n        \"\"\"Test SURELoss with real-valued input: when the input is real value, the loss should be 0.0.\"\"\"\n        sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1)\n\n        def operator(x):\n            return x\n\n        y_pseudo_gt = torch.randn(2, 1, 128, 128)\n        x = torch.randn(2, 1, 128, 128)\n        loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False)\n        self.assertAlmostEqual(loss.item(), 0.0)\n\n    def test_complex_value(self):\n        \"\"\"Test SURELoss with complex-valued input: when the input is complex value, the loss should be 0.0.\"\"\"\n\n        def operator(x):\n            return x\n\n        sure_loss_complex = SURELoss(perturb_noise=torch.zeros(2, 2, 128, 128), eps=0.1)\n        y_pseudo_gt = torch.randn(2, 2, 128, 128)\n        x = torch.randn(2, 2, 128, 128)\n        loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True)\n        self.assertAlmostEqual(loss.item(), 0.0)\n\n    def test_complex_general_input(self):\n        \"\"\"Test SURELoss with complex-valued input: when the input is general complex value, the loss should be 0.0.\"\"\"\n\n        def operator(x):\n            
return x\n\n        perturb_noise_real = torch.randn(2, 1, 128, 128)\n        perturb_noise_complex = torch.zeros(2, 2, 128, 128)\n        perturb_noise_complex[:, 0, :, :] = perturb_noise_real.squeeze()\n        y_pseudo_gt_real = torch.randn(2, 1, 128, 128)\n        y_pseudo_gt_complex = torch.zeros(2, 2, 128, 128)\n        y_pseudo_gt_complex[:, 0, :, :] = y_pseudo_gt_real.squeeze()\n        x_real = torch.randn(2, 1, 128, 128)\n        x_complex = torch.zeros(2, 2, 128, 128)\n        x_complex[:, 0, :, :] = x_real.squeeze()\n\n        sure_loss_real = SURELoss(perturb_noise=perturb_noise_real, eps=0.1)\n        sure_loss_complex = SURELoss(perturb_noise=perturb_noise_complex, eps=0.1)\n\n        loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False)\n        loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True)\n        self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_tversky_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import TverskyLoss\nfrom tests.test_utils import test_script_save\n\nTEST_CASES = [\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},\n        0.307576,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),\n        },\n        0.416657,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4, \"soft_label\": True},\n        {\n            \"input\": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),\n            \"target\": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),\n        },\n        0.0,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": 
True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4, \"soft_label\": False},\n        {\n            \"input\": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),\n            \"target\": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),\n        },\n        0.307773,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": False, \"to_onehot_y\": True, \"smooth_nr\": 0, \"smooth_dr\": 0},\n        {\n            \"input\": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n            \"target\": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),\n        },\n        0.0,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": False, \"to_onehot_y\": True, \"smooth_nr\": 0, \"smooth_dr\": 1e-3},\n        {\n            \"input\": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),\n            \"target\": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),\n        },\n        0.000999,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": True, \"to_onehot_y\": True, \"sigmoid\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        0.435050,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": True, \"to_onehot_y\": True, \"sigmoid\": True, \"batch\": True},\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        0.422979,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n           
 \"sigmoid\": True,\n            \"reduction\": \"sum\",\n            \"smooth_nr\": 1e-4,\n            \"smooth_dr\": 1e-4,\n        },\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        1.74013,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\"include_background\": True, \"to_onehot_y\": True, \"softmax\": True, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        0.383713,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"softmax\": True,\n            \"reduction\": \"none\",\n            \"smooth_nr\": 1e-4,\n            \"smooth_dr\": 1e-4,\n        },\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        [[0.210961, 0.295339], [0.599952, 0.428547]],\n    ],\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"alpha\": 0.3, \"beta\": 0.7, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},\n        0.3589,\n    ],\n    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2)\n        {\"include_background\": True, \"sigmoid\": True, \"alpha\": 0.7, \"beta\": 0.3, \"smooth_nr\": 1e-6, \"smooth_dr\": 1e-6},\n        {\"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), \"target\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},\n        0.247366,\n    ],\n   
 [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"other_act\": torch.tanh, \"smooth_nr\": 1e-4, \"smooth_dr\": 1e-4},\n        {\n            \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),\n        },\n        0.999963,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\"include_background\": True, \"other_act\": torch.tanh, \"smooth_nr\": 0, \"smooth_dr\": 1e-3, \"batch\": True},\n        {\n            \"input\": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),\n            \"target\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),\n        },\n        0.999963,\n    ],\n    [  # shape: (2, 2, 3), (2, 1, 3)\n        {\n            \"include_background\": True,\n            \"to_onehot_y\": True,\n            \"other_act\": lambda x: torch.log_softmax(x, dim=1),\n            \"smooth_nr\": 1e-4,\n            \"smooth_dr\": 1e-4,\n        },\n        {\n            \"input\": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),\n            \"target\": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),\n        },\n        -8.533317,\n    ],\n]\n\n\nclass TestTverskyLoss(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_data, expected_val):\n        result = TverskyLoss(**input_param).forward(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)\n\n    def test_ill_shape(self):\n        loss = TverskyLoss()\n        with self.assertRaisesRegex(AssertionError, \"\"):\n            loss.forward(torch.ones((2, 2, 3)), torch.ones((4, 5, 6)))\n        chn_input = torch.ones((1, 1, 3))\n        chn_target = torch.ones((1, 1, 3))\n        with self.assertRaisesRegex(ValueError, \"\"):\n            
TverskyLoss(reduction=\"unknown\")(chn_input, chn_target)\n        with self.assertRaisesRegex(ValueError, \"\"):\n            TverskyLoss(reduction=None)(chn_input, chn_target)\n\n    @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])\n    def test_input_warnings(self, include_background, softmax, to_onehot_y):\n        chn_input = torch.ones((1, 1, 3))\n        chn_target = torch.ones((1, 1, 3))\n        with self.assertWarns(Warning):\n            loss = TverskyLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)\n            loss.forward(chn_input, chn_target)\n\n    def test_script(self):\n        loss = TverskyLoss()\n        test_input = torch.ones(2, 1, 8, 8)\n        test_script_save(loss, test_input, test_input)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/losses/test_unified_focal_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import AsymmetricUnifiedFocalLoss\n\nTEST_CASES = [\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\n            \"y_pred\": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),\n            \"y_true\": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),\n        },\n        0.0,\n    ],\n    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)\n        {\n            \"y_pred\": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),\n            \"y_true\": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]),\n        },\n        0.0,\n    ],\n]\n\n\nclass TestAsymmetricUnifiedFocalLoss(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_result(self, input_data, expected_val):\n        loss = AsymmetricUnifiedFocalLoss()\n        result = loss(**input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)\n\n    def test_ill_shape(self):\n        loss = AsymmetricUnifiedFocalLoss()\n        with self.assertRaisesRegex(ValueError, \"\"):\n            loss(torch.ones((2, 2, 2)), torch.ones((2, 2, 2, 2)))\n\n    def test_with_cuda(self):\n        loss = AsymmetricUnifiedFocalLoss()\n        i = 
torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])\n        j = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])\n        if torch.cuda.is_available():\n            i = i.cuda()\n            j = j.cuda()\n        output = loss(i, j)\n        print(output)\n        np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/metrics/test_calibration_metric.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import mock\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import CalibrationErrorMetric, CalibrationReduction, calibration_binning\nfrom monai.utils import MetricReduction\nfrom tests.test_utils import assert_allclose\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\n# Test cases for calibration binning\n# Format: [name, y_pred, y, num_bins, right, expected_mean_p, expected_mean_gt, expected_counts]\nTEST_BINNING_SMALL_MID = [\n    \"small_mid\",\n    torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]),\n    torch.tensor([[[[1, 0], [0, 1]]]]),\n    5,\n    False,\n    torch.tensor([[[0.1, 0.3, float(\"nan\"), 0.7, 0.9]]]),\n    torch.tensor([[[0.0, 0.0, float(\"nan\"), 1.0, 1.0]]]),\n    torch.tensor([[[1.0, 1.0, 0.0, 1.0, 1.0]]]),\n]\n\nTEST_BINNING_LARGE_MID = [\n    \"large_mid\",\n    torch.tensor(\n        [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]]\n    ),\n    torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]),\n    5,\n    False,\n    torch.tensor(\n        [\n            [[0.1, 0.3, float(\"nan\"), 0.7, 0.9], [float(\"nan\"), 0.3, 0.5, 0.7, float(\"nan\")]],\n            [[float(\"nan\"), 0.3, float(\"nan\"), float(\"nan\"), 0.9], [0.1, 
float(\"nan\"), float(\"nan\"), 0.7, 0.9]],\n        ]\n    ),\n    torch.tensor(\n        [\n            [[0.0, 0.0, float(\"nan\"), 1.0, 1.0], [float(\"nan\"), 1.0, 0.5, 0.0, float(\"nan\")]],\n            [[float(\"nan\"), 0.0, float(\"nan\"), float(\"nan\"), 1.0], [0.0, float(\"nan\"), float(\"nan\"), 1.0, 1.0]],\n        ]\n    ),\n    torch.tensor(\n        [[[1.0, 1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 2.0, 1.0, 0.0]], [[0.0, 2.0, 0.0, 0.0, 2.0], [2.0, 0.0, 0.0, 1.0, 1.0]]]\n    ),\n]\n\nTEST_BINNING_SMALL_LEFT_EDGE = [\n    \"small_left_edge\",\n    torch.tensor([[[[0.8, 0.2], [0.4, 0.6]]]]),\n    torch.tensor([[[[1, 0], [0, 1]]]]),\n    5,\n    False,\n    torch.tensor([[[0.2, 0.4, 0.6, 0.8, float(\"nan\")]]]),\n    torch.tensor([[[0.0, 0.0, 1.0, 1.0, float(\"nan\")]]]),\n    torch.tensor([[[1.0, 1.0, 1.0, 1.0, 0.0]]]),\n]\n\nTEST_BINNING_SMALL_RIGHT_EDGE = [\n    \"small_right_edge\",\n    torch.tensor([[[[0.8, 0.2], [0.4, 0.6]]]]),\n    torch.tensor([[[[1, 0], [0, 1]]]]),\n    5,\n    True,\n    torch.tensor([[[float(\"nan\"), 0.2, 0.4, 0.6, 0.8]]]),\n    torch.tensor([[[float(\"nan\"), 0.0, 0.0, 1.0, 1.0]]]),\n    torch.tensor([[[0.0, 1.0, 1.0, 1.0, 1.0]]]),\n]\n\nBINNING_TEST_CASES = [\n    TEST_BINNING_SMALL_MID,\n    TEST_BINNING_LARGE_MID,\n    TEST_BINNING_SMALL_LEFT_EDGE,\n    TEST_BINNING_SMALL_RIGHT_EDGE,\n]\n\n# Test cases for calibration error metric values\n# Format: [name, y_pred, y, num_bins, expected_expected, expected_average, expected_maximum]\nTEST_VALUE_1B1C = [\n    \"1b1c\",\n    torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]),\n    torch.tensor([[[[1, 0], [0, 1]]]]),\n    5,\n    torch.tensor([[0.2]]),\n    torch.tensor([[0.2]]),\n    torch.tensor([[0.3]]),\n]\n\nTEST_VALUE_2B2C = [\n    \"2b2c\",\n    torch.tensor(\n        [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]]\n    ),\n    torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]),\n    5,\n 
   torch.tensor([[0.2000, 0.3500], [0.2000, 0.1500]]),\n    torch.tensor([[0.2000, 0.4667], [0.2000, 0.1667]]),\n    torch.tensor([[0.3000, 0.7000], [0.3000, 0.3000]]),\n]\n\nVALUE_TEST_CASES = [TEST_VALUE_1B1C, TEST_VALUE_2B2C]\n\n\nclass TestCalibrationBinning(unittest.TestCase):\n\n    @parameterized.expand(BINNING_TEST_CASES)\n    def test_binning(self, _name, y_pred, y, num_bins, right, expected_mean_p, expected_mean_gt, expected_counts):\n        y_pred = y_pred.to(_device)\n        y = y.to(_device)\n        expected_mean_p = expected_mean_p.to(_device)\n        expected_mean_gt = expected_mean_gt.to(_device)\n        expected_counts = expected_counts.to(_device)\n\n        # Use mock.patch to replace torch.linspace\n        # This is to avoid floating point precision issues when looking at edge conditions\n        mock_boundaries = torch.tensor([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], device=_device)\n        with mock.patch(\"monai.metrics.calibration.torch.linspace\", return_value=mock_boundaries):\n            mean_p_per_bin, mean_gt_per_bin, bin_counts = calibration_binning(y_pred, y, num_bins=num_bins, right=right)\n\n        # Handle NaN comparisons: compare NaN masks separately, then compare non-NaN values\n        # mean_p_per_bin\n        self.assertTrue(torch.equal(torch.isnan(mean_p_per_bin), torch.isnan(expected_mean_p)))\n        mask_p = ~torch.isnan(expected_mean_p)\n        if mask_p.any():\n            assert_allclose(mean_p_per_bin[mask_p], expected_mean_p[mask_p], atol=1e-4, rtol=1e-4)\n\n        # mean_gt_per_bin\n        self.assertTrue(torch.equal(torch.isnan(mean_gt_per_bin), torch.isnan(expected_mean_gt)))\n        mask_gt = ~torch.isnan(expected_mean_gt)\n        if mask_gt.any():\n            assert_allclose(mean_gt_per_bin[mask_gt], expected_mean_gt[mask_gt], atol=1e-4, rtol=1e-4)\n\n        # bin_counts (no NaNs)\n        assert_allclose(bin_counts, expected_counts, atol=1e-4, rtol=1e-4)\n\n    def test_shape_mismatch_raises(self):\n      
  \"\"\"Test that mismatched shapes raise ValueError.\"\"\"\n        y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device)\n        y = torch.tensor([[[[1, 0], [0, 1], [0, 0]]]]).to(_device)  # Different shape\n        with self.assertRaises(ValueError) as context:\n            calibration_binning(y_pred, y, num_bins=5)\n        self.assertIn(\"same shape\", str(context.exception))\n\n    def test_insufficient_ndim_raises(self):\n        \"\"\"Test that tensors with ndim < 3 raise ValueError.\"\"\"\n        y_pred = torch.tensor([[0.7, 0.3]]).to(_device)  # Only 2D\n        y = torch.tensor([[1, 0]]).to(_device)\n        with self.assertRaises(ValueError) as context:\n            calibration_binning(y_pred, y, num_bins=5)\n        self.assertIn(\"ndim\", str(context.exception))\n\n    def test_invalid_num_bins_raises(self):\n        \"\"\"Test that num_bins < 1 raises ValueError.\"\"\"\n        y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device)\n        y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device)\n        with self.assertRaises(ValueError) as context:\n            calibration_binning(y_pred, y, num_bins=0)\n        self.assertIn(\"num_bins\", str(context.exception))\n\n\nclass TestCalibrationErrorMetricValue(unittest.TestCase):\n\n    @parameterized.expand(VALUE_TEST_CASES)\n    def test_expected_reduction(self, _name, y_pred, y, num_bins, expected_expected, _expected_average, _expected_max):\n        y_pred = y_pred.to(_device)\n        y = y.to(_device)\n        expected_expected = expected_expected.to(_device)\n\n        metric = CalibrationErrorMetric(\n            num_bins=num_bins,\n            include_background=True,\n            calibration_reduction=CalibrationReduction.EXPECTED,\n            metric_reduction=MetricReduction.NONE,\n        )\n\n        metric(y_pred=y_pred, y=y)\n        result = metric.aggregate()\n\n        assert_allclose(result, expected_expected, atol=1e-4, rtol=1e-4)\n\n    
@parameterized.expand(VALUE_TEST_CASES)\n    def test_average_reduction(self, _name, y_pred, y, num_bins, _expected_expected, expected_average, _expected_max):\n        y_pred = y_pred.to(_device)\n        y = y.to(_device)\n        expected_average = expected_average.to(_device)\n\n        metric = CalibrationErrorMetric(\n            num_bins=num_bins,\n            include_background=True,\n            calibration_reduction=CalibrationReduction.AVERAGE,\n            metric_reduction=MetricReduction.NONE,\n        )\n\n        metric(y_pred=y_pred, y=y)\n        result = metric.aggregate()\n\n        assert_allclose(result, expected_average, atol=1e-4, rtol=1e-4)\n\n    @parameterized.expand(VALUE_TEST_CASES)\n    def test_maximum_reduction(self, _name, y_pred, y, num_bins, _expected_expected, _expected_average, expected_max):\n        y_pred = y_pred.to(_device)\n        y = y.to(_device)\n        expected_max = expected_max.to(_device)\n\n        metric = CalibrationErrorMetric(\n            num_bins=num_bins,\n            include_background=True,\n            calibration_reduction=CalibrationReduction.MAXIMUM,\n            metric_reduction=MetricReduction.NONE,\n        )\n\n        metric(y_pred=y_pred, y=y)\n        result = metric.aggregate()\n\n        assert_allclose(result, expected_max, atol=1e-4, rtol=1e-4)\n\n\nclass TestCalibrationErrorEmptyBins(unittest.TestCase):\n    \"\"\"Test edge cases when all bins are empty (division by zero scenarios).\"\"\"\n\n    def test_expected_reduction_all_empty_bins_returns_nan(self):\n        \"\"\"Test that EXPECTED reduction returns NaN when all bins are empty (division by zero case).\"\"\"\n        from unittest import mock\n\n        y_pred = torch.tensor([[[[0.5, 0.5], [0.5, 0.5]]]]).to(_device)\n        y = torch.tensor([[[[1, 1], [1, 1]]]]).to(_device)\n\n        metric = CalibrationErrorMetric(\n            num_bins=5,\n            include_background=True,\n            
calibration_reduction=CalibrationReduction.EXPECTED,\n            metric_reduction=MetricReduction.NONE,\n        )\n\n        # Mock calibration_binning to return zero bin_counts (all empty bins)\n        def mock_binning(y_pred, y, num_bins, right):\n            batch_size, num_channels = y_pred.shape[:2]\n            device = y_pred.device\n            mean_p = torch.full((batch_size, num_channels, num_bins), float(\"nan\"), device=device)\n            mean_gt = torch.full((batch_size, num_channels, num_bins), float(\"nan\"), device=device)\n            counts = torch.zeros((batch_size, num_channels, num_bins), device=device)\n            return mean_p, mean_gt, counts\n\n        with mock.patch(\"monai.metrics.calibration.calibration_binning\", side_effect=mock_binning):\n            metric(y_pred=y_pred, y=y)\n            result = metric.aggregate()\n\n        # All bins empty should result in NaN\n        self.assertTrue(torch.isnan(result).all(), \"Result should be NaN when all bins are empty\")\n\n    def test_average_reduction_all_empty_bins_returns_nan(self):\n        \"\"\"Test that AVERAGE reduction returns NaN when all bins are empty.\"\"\"\n        from unittest import mock\n\n        y_pred = torch.tensor([[[[0.5, 0.5], [0.5, 0.5]]]]).to(_device)\n        y = torch.tensor([[[[1, 1], [1, 1]]]]).to(_device)\n\n        metric = CalibrationErrorMetric(\n            num_bins=5,\n            include_background=True,\n            calibration_reduction=CalibrationReduction.AVERAGE,\n            metric_reduction=MetricReduction.NONE,\n        )\n\n        def mock_binning(y_pred, y, num_bins, right):\n            batch_size, num_channels = y_pred.shape[:2]\n            device = y_pred.device\n            mean_p = torch.full((batch_size, num_channels, num_bins), float(\"nan\"), device=device)\n            mean_gt = torch.full((batch_size, num_channels, num_bins), float(\"nan\"), device=device)\n            counts = torch.zeros((batch_size, num_channels, 
num_bins), device=device)\n            return mean_p, mean_gt, counts\n\n        with mock.patch(\"monai.metrics.calibration.calibration_binning\", side_effect=mock_binning):\n            metric(y_pred=y_pred, y=y)\n            result = metric.aggregate()\n\n        self.assertTrue(torch.isnan(result).all(), \"Result should be NaN when all bins are empty\")\n\n    def test_maximum_reduction_all_empty_bins_returns_nan(self):\n        \"\"\"Test that MAXIMUM reduction returns NaN when all bins are empty.\"\"\"\n        from unittest import mock\n\n        y_pred = torch.tensor([[[[0.5, 0.5], [0.5, 0.5]]]]).to(_device)\n        y = torch.tensor([[[[1, 1], [1, 1]]]]).to(_device)\n\n        metric = CalibrationErrorMetric(\n            num_bins=5,\n            include_background=True,\n            calibration_reduction=CalibrationReduction.MAXIMUM,\n            metric_reduction=MetricReduction.NONE,\n        )\n\n        def mock_binning(y_pred, y, num_bins, right):\n            batch_size, num_channels = y_pred.shape[:2]\n            device = y_pred.device\n            mean_p = torch.full((batch_size, num_channels, num_bins), float(\"nan\"), device=device)\n            mean_gt = torch.full((batch_size, num_channels, num_bins), float(\"nan\"), device=device)\n            counts = torch.zeros((batch_size, num_channels, num_bins), device=device)\n            return mean_p, mean_gt, counts\n\n        with mock.patch(\"monai.metrics.calibration.calibration_binning\", side_effect=mock_binning):\n            metric(y_pred=y_pred, y=y)\n            result = metric.aggregate()\n\n        self.assertTrue(torch.isnan(result).all(), \"Result should be NaN when all bins are empty\")\n\n    def test_expected_reduction_with_zeros_only_returns_nan(self):\n        \"\"\"Test EXPECTED reduction returns NaN for channels where all bin_counts are zero.\n\n        This tests the actual division-by-zero fix: if we have values that all fall\n        outside the valid probability range [0, 1], 
all bins would be empty.\n        \"\"\"\n        # Create a 2-channel tensor where one channel has valid data and one is out of range\n        # Note: calibration_binning clamps values, but we can test with very extreme distributions\n        # that result in some channels having all NaN abs_diff\n        # A simpler test: create data where bin_counts sum to zero for a channel\n\n        # Use mock to simulate the scenario where bin_counts are zero for one channel\n        from unittest import mock\n\n        y_pred = torch.tensor([[[[0.5, 0.5], [0.5, 0.5]], [[0.5, 0.5], [0.5, 0.5]]]]).to(_device)\n        y = torch.tensor([[[[1, 1], [1, 1]], [[1, 1], [1, 1]]]]).to(_device)\n\n        metric = CalibrationErrorMetric(\n            num_bins=5,\n            include_background=True,\n            calibration_reduction=CalibrationReduction.EXPECTED,\n            metric_reduction=MetricReduction.NONE,\n        )\n\n        # Mock calibration_binning to return zero bin_counts for first channel\n        def mock_binning(y_pred, y, num_bins, right):\n            batch_size, num_channels = y_pred.shape[:2]\n            device = y_pred.device\n\n            # Create normal results for channel 1, all zeros for channel 0\n            mean_p = torch.full((batch_size, num_channels, num_bins), float(\"nan\"), device=device)\n            mean_gt = torch.full((batch_size, num_channels, num_bins), float(\"nan\"), device=device)\n            counts = torch.zeros((batch_size, num_channels, num_bins), device=device)\n\n            # Channel 1 has some data\n            mean_p[0, 1, 2] = 0.5\n            mean_gt[0, 1, 2] = 0.6\n            counts[0, 1, 2] = 4.0\n\n            return mean_p, mean_gt, counts\n\n        with mock.patch(\"monai.metrics.calibration.calibration_binning\", side_effect=mock_binning):\n            metric(y_pred=y_pred, y=y)\n            result = metric.aggregate()\n\n        # Channel 0 should be NaN (all bins empty), Channel 1 should have a value\n        
self.assertTrue(torch.isnan(result[0, 0]).item(), \"Channel 0 should be NaN when all bins are empty\")\n        self.assertFalse(torch.isnan(result[0, 1]).item(), \"Channel 1 should have a valid value\")\n        assert_allclose(result[0, 1], torch.tensor(0.1, device=_device), atol=1e-4, rtol=1e-4)\n\n\nclass TestCalibrationErrorMetricOptions(unittest.TestCase):\n\n    def test_include_background_false(self):\n        y_pred = torch.tensor(\n            [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]]\n        ).to(_device)\n        y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]).to(_device)\n\n        metric = CalibrationErrorMetric(\n            num_bins=5,\n            include_background=False,\n            calibration_reduction=CalibrationReduction.EXPECTED,\n            metric_reduction=MetricReduction.MEAN,\n        )\n\n        metric(y_pred=y_pred, y=y)\n        result = metric.aggregate()\n\n        assert_allclose(result, torch.tensor(0.2500, device=_device), atol=1e-4, rtol=1e-4)\n\n    def test_metric_reduction_mean(self):\n        y_pred = torch.tensor(\n            [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]]\n        ).to(_device)\n        y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]).to(_device)\n\n        metric = CalibrationErrorMetric(\n            num_bins=5,\n            include_background=True,\n            calibration_reduction=CalibrationReduction.EXPECTED,\n            metric_reduction=MetricReduction.MEAN,\n        )\n\n        metric(y_pred=y_pred, y=y)\n        result = metric.aggregate()\n\n        # Mean of [[0.2000, 0.3500], [0.2000, 0.1500]] = 0.225\n        assert_allclose(result, torch.tensor(0.2250, device=_device), atol=1e-4, rtol=1e-4)\n\n    def test_get_not_nans(self):\n        y_pred = torch.tensor([[[[0.7, 
0.3], [0.1, 0.9]]]]).to(_device)\n        y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device)\n\n        metric = CalibrationErrorMetric(\n            num_bins=5,\n            include_background=True,\n            calibration_reduction=CalibrationReduction.EXPECTED,\n            metric_reduction=MetricReduction.MEAN,\n            get_not_nans=True,\n        )\n\n        metric(y_pred=y_pred, y=y)\n        result, not_nans = metric.aggregate()\n\n        assert_allclose(result, torch.tensor(0.2, device=_device), atol=1e-4, rtol=1e-4)\n        self.assertEqual(not_nans.item(), 1)\n\n    def test_cumulative_iterations(self):\n        \"\"\"Test that the metric correctly accumulates over multiple iterations.\"\"\"\n        y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device)\n        y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device)\n\n        metric = CalibrationErrorMetric(\n            num_bins=5,\n            include_background=True,\n            calibration_reduction=CalibrationReduction.EXPECTED,\n            metric_reduction=MetricReduction.MEAN,\n        )\n\n        # First iteration\n        metric(y_pred=y_pred, y=y)\n        # Second iteration\n        metric(y_pred=y_pred, y=y)\n\n        result = metric.aggregate()\n        # Should still be 0.2 since both iterations have the same data\n        assert_allclose(result, torch.tensor(0.2, device=_device), atol=1e-4, rtol=1e-4)\n\n        # Test reset\n        metric.reset()\n        data = metric.get_buffer()\n        self.assertIsNone(data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_average_precision.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import decollate_batch\nfrom monai.metrics import AveragePrecisionMetric, compute_average_precision\nfrom monai.transforms import Activations, AsDiscrete, Compose, ToTensor\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\nTEST_CASE_1 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device),\n    torch.tensor([[0], [0], [1], [1]], device=_device),\n    True,\n    2,\n    \"macro\",\n    0.41667,\n]\n\nTEST_CASE_2 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device),\n    torch.tensor([[1], [1], [0], [0]], device=_device),\n    True,\n    2,\n    \"micro\",\n    0.85417,\n]\n\nTEST_CASE_3 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device),\n    torch.tensor([[0], [1], [0], [1]], device=_device),\n    True,\n    2,\n    \"macro\",\n    0.83333,\n]\n\nTEST_CASE_4 = [\n    torch.tensor([[0.5], [0.5], [0.2], [8.3]]),\n    torch.tensor([[0], [1], [0], [1]]),\n    False,\n    None,\n    \"macro\",\n    0.83333,\n]\n\nTEST_CASE_5 = [torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([0, 1, 0, 1]), False, None, \"macro\", 0.83333]\n\nTEST_CASE_6 = [torch.tensor([0.5, 0.5, 0.2, 8.3]), torch.tensor([0, 1, 0, 
1]), False, None, \"macro\", 0.83333]\n\nTEST_CASE_7 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),\n    torch.tensor([[0], [1], [0], [1]]),\n    True,\n    2,\n    \"none\",\n    [0.83333, 0.83333],\n]\n\nTEST_CASE_8 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),\n    torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),\n    True,\n    None,\n    \"weighted\",\n    0.66667,\n]\n\nTEST_CASE_9 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),\n    torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),\n    True,\n    None,\n    \"micro\",\n    0.71111,\n]\n\nTEST_CASE_10 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),\n    torch.tensor([[0], [0], [0], [0]]),\n    True,\n    2,\n    \"macro\",\n    float(\"nan\"),\n]\n\nTEST_CASE_11 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),\n    torch.tensor([[1], [1], [1], [1]]),\n    True,\n    2,\n    \"macro\",\n    float(\"nan\"),\n]\n\nTEST_CASE_12 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),\n    torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]]),\n    True,\n    None,\n    \"macro\",\n    float(\"nan\"),\n]\n\nALL_TESTS = [\n    TEST_CASE_1,\n    TEST_CASE_2,\n    TEST_CASE_3,\n    TEST_CASE_4,\n    TEST_CASE_5,\n    TEST_CASE_6,\n    TEST_CASE_7,\n    TEST_CASE_8,\n    TEST_CASE_9,\n    TEST_CASE_10,\n    TEST_CASE_11,\n    TEST_CASE_12,\n]\n\n\nclass TestComputeAveragePrecision(unittest.TestCase):\n\n    @parameterized.expand(ALL_TESTS)\n    def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value):\n        y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)])\n        y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)])\n        y_pred = torch.stack([y_pred_trans(i) for i in decollate_batch(y_pred)], dim=0)\n        y = torch.stack([y_trans(i) for i in decollate_batch(y)], dim=0)\n        result = 
compute_average_precision(y_pred=y_pred, y=y, average=average)\n        np.testing.assert_allclose(expected_value, result, rtol=1e-5)\n\n    @parameterized.expand(ALL_TESTS)\n    def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value):\n        y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)])\n        y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)])\n        y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)]\n        y = [y_trans(i) for i in decollate_batch(y)]\n        metric = AveragePrecisionMetric(average=average)\n        metric(y_pred=y_pred, y=y)\n        result = metric.aggregate()\n        np.testing.assert_allclose(expected_value, result, rtol=1e-5)\n        result = metric.aggregate(average=average)  # test optional argument\n        metric.reset()\n        np.testing.assert_allclose(expected_value, result, rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_confusion_matrix.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom typing import Any\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import (\n    ConfusionMatrixMetric,\n    compute_confusion_matrix_metric,\n    do_metric_reduction,\n    get_confusion_matrix,\n)\nfrom tests.test_utils import assert_allclose\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n# input data\ndata: dict[Any, Any] = {\n    \"y_pred\": torch.tensor(\n        [\n            [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],\n            [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],\n        ],\n        device=_device,\n    ),\n    \"y\": torch.tensor(\n        [\n            [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]]],\n            [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n        ],\n        device=_device,\n    ),\n}\n\ndata_nan: dict[Any, Any] = {\n    # confusion matrix:[[[0,1,2,1],[1,1,1,1],[0,1,2,1]],\n    #                   [[0,0,0,4],[0,0,4,0],[0,4,0,0]],\n    #                   [[0,0,2,2],[0,0,2,2],[0,4,0,0]]]\n    \"y_pred\": torch.tensor(\n        [\n            [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],\n            [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 
0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]],\n            [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]],\n        ]\n    ),\n    \"y\": torch.tensor(\n        [\n            [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]]],\n            [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n            [[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]],\n        ]\n    ),\n}\n\ndata_clf: dict[Any, Any] = {\n    \"y_pred\": torch.tensor([[1, 0, 0], [0, 0, 1]]),\n    \"y\": torch.tensor([[1, 0, 0], [0, 1, 0]]),\n    \"compute_sample\": False,\n    \"include_background\": True,\n    \"metric_name\": \"tpr\",\n    \"reduction\": \"mean_channel\",\n    \"get_not_nans\": True,\n}\n\n# 1. test confusion matrix\nTEST_CASE_CONFUSION_MATRIX = [\n    data.copy(),\n    torch.tensor(\n        [\n            [[0.0, 1.0, 2.0, 1.0], [1.0, 1.0, 1.0, 1.0], [0.0, 1.0, 2.0, 1.0]],\n            [[1.0, 0.0, 3.0, 0.0], [1.0, 0.0, 2.0, 1.0], [1.0, 1.0, 2.0, 0.0]],\n        ]\n    ),\n]\n\n# 2. 
test metric with compute_sample\nTEST_CASES_COMPUTE_SAMPLE = []\nTEST_CASES_COMPUTE_SAMPLE_MULTI_METRICS = []\nresult_mean_batch = [\n    torch.tensor([0.5000, 0.5000, 0.5000]),\n    torch.tensor([0.1667, 0.2500, 0.3333]),\n    torch.tensor([0.8333, 0.7500, 0.6667]),\n    torch.tensor([0.5000, 0.7500, 0.2500]),\n    torch.tensor([0.8333, 0.5833, 0.8333]),\n    torch.tensor([0.5000, 0.5000, 0.5000]),\n    torch.tensor([0.5000, 0.2500, 0.7500]),\n    torch.tensor([0.1667, 0.4167, 0.1667]),\n    torch.tensor([0.5000, 0.0000, 0.6830]),\n    torch.tensor([0.5000, 0.4167, 0.2500]),\n    torch.tensor([0.7500, 0.6250, 0.6250]),\n    torch.tensor([0.6667, 0.6250, 0.5833]),\n    torch.tensor([0.5000, 0.5833, 0.3333]),\n    torch.tensor([0.3333, 0.2887, 0.1220]),\n    torch.tensor([0.5000, 0.6036, 0.3536]),\n    torch.tensor([0.3333, 0.2500, 0.1667]),\n    torch.tensor([0.3333, 0.3333, 0.0833]),\n]\nresult_mean = [\n    torch.tensor([0.5000]),\n    torch.tensor([0.2500]),\n    torch.tensor([0.7500]),\n    torch.tensor([0.5000]),\n    torch.tensor([0.7500]),\n    torch.tensor([0.5000]),\n    torch.tensor([0.5000]),\n    torch.tensor([0.2500]),\n    torch.tensor([0.5610]),\n    torch.tensor([0.3889]),\n    torch.tensor([0.6667]),\n    torch.tensor([0.6250]),\n    torch.tensor([0.4722]),\n    torch.tensor([0.2480]),\n    torch.tensor([0.4857]),\n    torch.tensor([0.2500]),\n    torch.tensor([0.2500]),\n]\nmetric_names = [\n    \"tpr\",\n    \"fpr\",\n    \"tnr\",\n    \"ppv\",\n    \"npv\",\n    \"fnr\",\n    \"fdr\",\n    \"for\",\n    \"pt\",\n    \"ts\",\n    \"acc\",\n    \"ba\",\n    \"f1\",\n    \"mcc\",\n    \"fm\",\n    \"bm\",\n    \"mk\",\n]\nresult: Any = None\nfor idx, item in enumerate(metric_names):\n    for reduction in [\"mean\", \"mean_batch\"]:\n        TEST_CASE: list[Any] = [data.copy()]\n        TEST_CASE[0][\"compute_sample\"] = True\n        TEST_CASE[0][\"include_background\"] = True\n        TEST_CASE[0][\"metric_name\"] = item\n        
TEST_CASE[0][\"reduction\"] = reduction\n        TEST_CASE[0][\"get_not_nans\"] = True\n        if reduction == \"mean_batch\":\n            result = result_mean_batch[idx]\n        elif reduction == \"mean\":\n            result = result_mean[idx]\n        TEST_CASE.append(result)\n        TEST_CASES_COMPUTE_SAMPLE.append(TEST_CASE)\n\n# one input to compute multiple metrics\nfor reduction in [\"mean\", \"mean_batch\"]:\n    TEST_CASE_MULTIPLE: list[Any] = [data.copy()]\n    TEST_CASE_MULTIPLE[0][\"compute_sample\"] = True\n    TEST_CASE_MULTIPLE[0][\"include_background\"] = True\n    TEST_CASE_MULTIPLE[0][\"metric_name\"] = metric_names\n    TEST_CASE_MULTIPLE[0][\"reduction\"] = reduction\n    TEST_CASE_MULTIPLE[0][\"get_not_nans\"] = True\n    if reduction == \"mean_batch\":\n        result = result_mean_batch\n    elif reduction == \"mean\":\n        result = result_mean\n    TEST_CASE_MULTIPLE.append(result)\n    TEST_CASES_COMPUTE_SAMPLE_MULTI_METRICS.append(TEST_CASE_MULTIPLE)\n\n# 3. 
test metric with compute_sample, denominator may have zeros\nTEST_CASES_COMPUTE_SAMPLE_NAN = []\nmetric_names = [\"tpr\", \"tnr\"]\nresult_sum = [torch.tensor([0.5000]), torch.tensor([4.8333])]\nnot_nans_sum = [torch.tensor([6]), torch.tensor([8])]\nresult_sum_batch = [torch.tensor([0.0000, 0.5000, 0.0000]), torch.tensor([1.6667, 2.5000, 0.6667])]\nnot_nans_sum_batch = [torch.tensor([3.0, 2.0, 1.0]), torch.tensor([2.0, 3.0, 3.0])]\nfor idx in range(2):\n    for reduction in [\"sum\", \"sum_batch\"]:\n        TEST_CASE = [data_nan.copy()]\n        TEST_CASE[0][\"compute_sample\"] = True\n        TEST_CASE[0][\"include_background\"] = True\n        TEST_CASE[0][\"reduction\"] = reduction\n        TEST_CASE[0][\"metric_name\"] = metric_names[idx]\n        TEST_CASE[0][\"get_not_nans\"] = True\n        if reduction == \"sum\":\n            TEST_CASE.append(result_sum[idx])\n            TEST_CASE.append(not_nans_sum[idx])\n        elif reduction == \"sum_batch\":\n            TEST_CASE.append(result_sum_batch[idx])\n            TEST_CASE.append(not_nans_sum_batch[idx])\n        TEST_CASES_COMPUTE_SAMPLE_NAN.append(TEST_CASE)\n\n# 4. 
test classification task\nresult_clf = torch.tensor(\n    [\n        [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 0.0]],\n        [[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0]],\n    ]\n)\n\nTEST_CASES_CLF = [data_clf.copy(), result_clf]\n\nTEST_CASE_PRECISION = [\n    {\n        \"y_pred\": torch.zeros([1, 1, 1024, 1024, 44], device=_device),\n        \"y\": torch.zeros([1, 1, 1024, 1024, 44], device=_device),\n    },\n    torch.tensor([[[0.0, 0.0, 46137344.0, 0.0]]]),\n]\n\n\nclass TestConfusionMatrix(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_CONFUSION_MATRIX])\n    def test_value(self, input_data, expected_value):\n        # include or ignore background\n        input_data[\"include_background\"] = True\n        result = get_confusion_matrix(**input_data)\n        assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)\n        input_data[\"include_background\"] = False\n        result = get_confusion_matrix(**input_data)\n        assert_allclose(result, expected_value[:, 1:, :], atol=1e-4, rtol=1e-4)\n        np.testing.assert_equal(result.device, input_data[\"y_pred\"].device)\n\n    @parameterized.expand(TEST_CASES_COMPUTE_SAMPLE)\n    def test_compute_sample(self, input_data, expected_value):\n        params = input_data.copy()\n        vals = {}\n        vals[\"y_pred\"] = params.pop(\"y_pred\")\n        vals[\"y\"] = params.pop(\"y\")\n        metric = ConfusionMatrixMetric(**params)\n        metric(**vals)\n        result, _ = metric.aggregate()[0]\n        assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)\n\n    @parameterized.expand(TEST_CASES_COMPUTE_SAMPLE_MULTI_METRICS)\n    def test_compute_sample_multiple_metrics(self, input_data, expected_values):\n        params = input_data.copy()\n        vals = {}\n        vals[\"y_pred\"] = params.pop(\"y_pred\")\n        vals[\"y\"] = params.pop(\"y\")\n        metric = ConfusionMatrixMetric(**params)\n        metric(**vals)\n        
results = metric.aggregate()\n        for idx, item in enumerate(results):\n            result = item[0]\n            expected_value = expected_values[idx]\n            assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)\n\n    @parameterized.expand(TEST_CASES_COMPUTE_SAMPLE_NAN)\n    def test_compute_sample_with_nan(self, input_data, expected_value, expected_not_nans):\n        params = input_data.copy()\n        vals = {}\n        vals[\"y_pred\"] = params.pop(\"y_pred\")\n        vals[\"y\"] = params.pop(\"y\")\n        metric = ConfusionMatrixMetric(**params)\n        metric(**vals)\n        result, not_nans = metric.aggregate()[0]\n        assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)\n        assert_allclose(not_nans, expected_not_nans, atol=1e-4, rtol=1e-4)\n\n    @parameterized.expand([TEST_CASES_CLF])\n    def test_clf_with_nan(self, input_data, expected_value):\n        params = input_data.copy()\n        vals = {}\n        vals[\"y_pred\"] = params.pop(\"y_pred\")\n        vals[\"y\"] = params.pop(\"y\")\n        metric = ConfusionMatrixMetric(**params)\n        result = metric(**vals)\n        assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)\n        result, _ = metric.aggregate(reduction=\"mean_channel\")[0]\n        expected_value, _ = do_metric_reduction(expected_value, \"mean_channel\")\n        expected_value = compute_confusion_matrix_metric(\"tpr\", expected_value)\n        assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)\n\n    @parameterized.expand([TEST_CASE_PRECISION])\n    def test_precision(self, input_data, expected_value):\n        # include or ignore background\n        result = get_confusion_matrix(**input_data)\n        assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)\n        np.testing.assert_equal(result.device, input_data[\"y_pred\"].device)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_f_beta.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import FBetaScore\nfrom tests.test_utils import assert_allclose\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\n\nclass TestFBetaScore(unittest.TestCase):\n    def test_expecting_success_and_device(self):\n        metric = FBetaScore()\n        y_pred = torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]], device=_device)\n        y = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]], device=_device)\n        metric(y_pred=y_pred, y=y)\n        result = metric.aggregate()[0]\n        assert_allclose(result, torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6)\n        np.testing.assert_equal(result.device, y_pred.device)\n\n    @parameterized.expand(\n        [\n            (0.5, torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]), torch.Tensor([0.609756])),  # success_beta_0_5\n            (2, torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]), torch.Tensor([0.862069])),  # success_beta_2\n            (\n                2,  # success_beta_2, denominator_zero\n                torch.Tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]]),\n                torch.Tensor([0.0]),\n            ),\n        ]\n    )\n    def test_success_and_zero(self, beta, y, expected_score):\n        metric = FBetaScore(beta=beta)\n        
metric(y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=y)\n        assert_allclose(metric.aggregate()[0], expected_score, atol=1e-6, rtol=1e-6)\n\n    def test_number_of_dimensions_less_than_2_should_raise_error(self):\n        metric = FBetaScore()\n        with self.assertRaises(ValueError):\n            metric(y_pred=torch.Tensor([1, 1, 1]), y=torch.Tensor([0, 0, 0]))\n\n    def test_with_nan_values(self):\n        metric = FBetaScore(get_not_nans=True)\n        metric(\n            y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]),\n            y=torch.Tensor([[1, 0, 1], [np.nan, np.nan, np.nan], [1, 0, 1]]),\n        )\n        assert_allclose(metric.aggregate()[0][0], torch.Tensor([0.727273]), atol=1e-6, rtol=1e-6)\n\n    def test_do_not_include_background(self):\n        metric = FBetaScore(include_background=False)\n        metric(\n            y_pred=torch.Tensor([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]),\n            y=torch.Tensor([[0, 1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 1]]),\n        )\n        assert_allclose(metric.aggregate()[0], torch.Tensor([1.0]), atol=1e-7, rtol=1e-7)\n\n    def test_prediction_and_result_have_different_shape(self):\n        metric = FBetaScore()\n        with self.assertRaises(ValueError):\n            metric(y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1]]), y=torch.Tensor([1, 1, 1]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_fid_metric.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.metrics import FIDMetric\nfrom monai.utils import optional_import\n\n_, has_scipy = optional_import(\"scipy\")\n\n\n@unittest.skipUnless(has_scipy, \"Requires scipy\")\nclass TestFIDMetric(unittest.TestCase):\n\n    def test_results(self):\n        x = torch.Tensor([[1, 2], [1, 2], [1, 2]])\n        y = torch.Tensor([[2, 2], [1, 2], [1, 2]])\n        results = FIDMetric()(x, y)\n        np.testing.assert_allclose(results.cpu().numpy(), 0.4444, atol=1e-4)\n\n    def test_input_dimensions(self):\n        with self.assertRaises(ValueError):\n            FIDMetric()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_froc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import compute_fp_tp_probs, compute_fp_tp_probs_nd, compute_froc_curve_data, compute_froc_score\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\nTEST_CASE_1 = [\n    {\n        \"probs\": torch.tensor([1, 0.6, 0.8], device=_device),\n        \"y_coord\": torch.tensor([0, 2, 3], device=_device),\n        \"x_coord\": torch.tensor([3, 0, 1], device=_device),\n        \"evaluation_mask\": np.array([[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]]),\n        \"labels_to_exclude\": [2],\n        \"resolution_level\": 0,\n    },\n    np.array([0.6]),\n    np.array([1, 0, 0.8]),\n    2,\n]\n\nTEST_CASE_2 = [\n    {\n        \"probs\": torch.tensor([1, 0.6, 0.8]),\n        \"y_coord\": torch.tensor([0, 2, 3]),\n        \"x_coord\": torch.tensor([3, 0, 1]),\n        \"evaluation_mask\": np.array([[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]]),\n        \"resolution_level\": 0,\n    },\n    np.array([0.6]),\n    np.array([1, 0, 0.8]),\n    3,\n]\n\nTEST_CASE_3 = [\n    {\n        \"probs\": torch.tensor([1, 0.6, 0.8]),\n        \"y_coord\": torch.tensor([0, 4, 6]),\n        \"x_coord\": torch.tensor([6, 0, 2]),\n        \"evaluation_mask\": np.array([[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 
3, 3, 3]]),\n        \"resolution_level\": 1,\n    },\n    np.array([0.6]),\n    np.array([1, 0, 0.8]),\n    3,\n]\n\nTEST_CASE_4 = [\n    {\n        \"fp_probs\": np.array([0.8, 0.6]),\n        \"tp_probs\": np.array([1, 1, 0, 0, 0.8, 0.8, 0]),\n        \"num_targets\": 4,\n        \"num_images\": 2,\n    },\n    (0.25, 0.5, 1, 2, 4, 8),\n    0.95833333,\n]\n\nTEST_CASE_5 = [\n    {\n        \"fp_probs\": torch.tensor([0.8, 0.6]),\n        \"tp_probs\": torch.tensor([1, 1, 0, 0, 0.8, 0.8, 0]),\n        \"num_targets\": 4,\n        \"num_images\": 2,\n    },\n    (0.25),\n    0.75,\n]\n\nTEST_CASE_ND_1 = [\n    {\n        \"probs\": torch.tensor([1, 0.6, 0.8]),\n        \"coords\": torch.tensor([[0, 3], [2, 0], [3, 1]]),\n        \"evaluation_mask\": np.array([[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]]),\n    },\n    np.array([0.6]),\n    np.array([1, 0, 0.8]),\n    3,\n]\n\nTEST_CASE_ND_2 = [\n    {\n        \"probs\": torch.tensor([1, 0.6, 0.8]),\n        \"coords\": torch.tensor([[0, 0, 3], [1, 2, 0], [0, 3, 1]]),\n        \"evaluation_mask\": np.array(\n            [\n                [[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]],\n                [[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]],\n            ]\n        ),\n    },\n    np.array([0.6]),\n    np.array([1, 0, 0.8]),\n    3,\n]\n\n\nclass TestComputeFpTp(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_value(self, input_data, expected_fp, expected_tp, expected_num):\n        fp_probs, tp_probs, num_tumors = compute_fp_tp_probs(**input_data)\n        np.testing.assert_allclose(fp_probs, expected_fp, rtol=1e-5)\n        np.testing.assert_allclose(tp_probs, expected_tp, rtol=1e-5)\n        np.testing.assert_equal(num_tumors, expected_num)\n\n\nclass TestComputeFpTpNd(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_ND_1, TEST_CASE_ND_2])\n    def test_value(self, input_data, expected_fp, expected_tp, 
expected_num):\n        fp_probs, tp_probs, num_tumors = compute_fp_tp_probs_nd(**input_data)\n        np.testing.assert_allclose(fp_probs, expected_fp, rtol=1e-5)\n        np.testing.assert_allclose(tp_probs, expected_tp, rtol=1e-5)\n        np.testing.assert_equal(num_tumors, expected_num)\n\n\nclass TestComputeFrocScore(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_4, TEST_CASE_5])\n    def test_value(self, input_data, thresholds, expected_score):\n        fps_per_image, total_sensitivity = compute_froc_curve_data(**input_data)\n        score = compute_froc_score(fps_per_image, total_sensitivity, thresholds)\n        np.testing.assert_allclose(score, expected_score, rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_generalized_dice.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import GeneralizedDiceScore, compute_generalized_dice\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\n# keep background\nTEST_CASE_1 = [  # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1) with compute_generalized_dice\n    {\n        \"y_pred\": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device),\n        \"y\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device),\n        \"include_background\": True,\n    },\n    [[0.8]],\n]\n\n# remove background\nTEST_CASE_2 = [  # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background) with GeneralizedDiceScore\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"include_background\": False,\n        \"reduction\": \"mean_batch\",\n  
  },\n    [0.583333, 0.333333],\n]\n\nTEST_CASE_3 = [  # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (1) with GeneralizedDiceScore\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"include_background\": True,\n        \"reduction\": \"mean\",\n    },\n    [0.5454],\n]\n\nTEST_CASE_4 = [  # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (1) with GeneralizedDiceScore\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n            ]\n        ),\n        \"include_background\": True,\n        \"reduction\": \"sum\",\n    },\n    [1.045455],\n]\n\nTEST_CASE_5 = [  # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice\n    {\"y\": torch.ones((2, 2, 3, 3)), \"y_pred\": torch.ones((2, 2, 3, 3))},\n    [[1.0000, 1.0000], [1.0000, 1.0000]],\n]\n\nTEST_CASE_6 = [  # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice\n    {\"y\": torch.zeros((2, 2, 3, 3)), \"y_pred\": torch.ones((2, 2, 3, 3))},\n    [[0.0000, 0.0000], [0.0000, 
0.0000]],\n]\n\nTEST_CASE_7 = [  # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice\n    {\"y\": torch.ones((2, 2, 3, 3)), \"y_pred\": torch.zeros((2, 2, 3, 3))},\n    [[0.0000, 0.0000], [0.0000, 0.0000]],\n]\n\nTEST_CASE_8 = [  # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice\n    {\"y\": torch.zeros((2, 2, 3, 3)), \"y_pred\": torch.zeros((2, 2, 3, 3))},\n    [[1.0000, 1.0000], [1.0000, 1.0000]],\n]\n\nTEST_CASE_9 = [  # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2) with GeneralizedDiceScore\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"include_background\": True,\n        \"reduction\": \"mean_channel\",\n    },\n    [0.545455, 0.545455],\n]\n\n\nTEST_CASE_10 = [  # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 3) with compute_generalized_dice\n    # and (3) with GeneralizedDiceScore \"mean_batch\"\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"include_background\": True,\n    },\n 
   [[0.857143, 0.0, 0.0], [0.5, 0.4, 0.666667]],\n]\n\nTEST_CASE_11 = [  # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 1) with compute_generalized_dice (summed over classes)\n    # and (2) with GeneralizedDiceScore \"mean_channel\"\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"include_background\": True,\n        \"sum_over_classes\": True,\n    },\n    [[0.545455], [0.545455]],\n]\n\n\nclass TestComputeGeneralizedDiceScore(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1])\n    def test_device(self, input_data, _expected_value):\n        \"\"\"\n        Test if the result tensor is on the same device as the input tensor.\n        \"\"\"\n        result = compute_generalized_dice(**input_data)\n        np.testing.assert_equal(result.device, input_data[\"y_pred\"].device)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8])\n    def test_value(self, input_data, expected_value):\n        \"\"\"\n        Test if the computed generalized dice score matches the expected value.\n        \"\"\"\n        result = compute_generalized_dice(**input_data)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n    @parameterized.expand([TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_9])\n    def test_value_class(self, input_data, expected_value):\n        \"\"\"\n        Test if the GeneralizedDiceScore class computes the correct values.\n        \"\"\"\n        y_pred = 
input_data.pop(\"y_pred\")\n        y = input_data.pop(\"y\")\n        generalized_dice_score = GeneralizedDiceScore(**input_data)\n        generalized_dice_score(y_pred=y_pred, y=y)\n        result = generalized_dice_score.aggregate()\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n    @parameterized.expand([TEST_CASE_10])\n    def test_values_compare(self, input_data, expected_value):\n        \"\"\"\n        Compare the results of compute_generalized_dice function and GeneralizedDiceScore class.\n        \"\"\"\n        result = compute_generalized_dice(**input_data)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n        y_pred = input_data.pop(\"y_pred\")\n        y = input_data.pop(\"y\")\n        generalized_dice_score = GeneralizedDiceScore(**input_data, reduction=\"mean_batch\")\n        generalized_dice_score(y_pred=y_pred, y=y)\n        result_class_mean = generalized_dice_score.aggregate()\n        np.testing.assert_allclose(result_class_mean.cpu().numpy(), np.mean(expected_value, axis=0), atol=1e-4)\n\n    @parameterized.expand([TEST_CASE_11])\n    def test_values_compare_sum_over_classes(self, input_data, expected_value):\n        \"\"\"\n        Compare the results when summing over classes between compute_generalized_dice function and GeneralizedDiceScore class.\n        \"\"\"\n        result = compute_generalized_dice(**input_data)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n        y_pred = input_data.pop(\"y_pred\")\n        y = input_data.pop(\"y\")\n        input_data.pop(\"sum_over_classes\")\n        generalized_dice_score = GeneralizedDiceScore(**input_data, reduction=\"mean_channel\")\n        generalized_dice_score(y_pred=y_pred, y=y)\n        result_class_mean = generalized_dice_score.aggregate()\n        np.testing.assert_allclose(result_class_mean.cpu().numpy(), np.mean(expected_value, axis=1), atol=1e-4)\n\n\nif 
__name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_meandice.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import DiceHelper, DiceMetric, compute_dice\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n# keep background\nTEST_CASE_1 = [  # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1)\n    {\n        \"y_pred\": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device),\n        \"y\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device),\n        \"include_background\": True,\n    },\n    [[0.8]],\n]\n\n# remove background and not One-Hot target\nTEST_CASE_2 = [  # y (2, 1, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 2) (no background)\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"include_background\": False,\n    },\n    [[0.5000, 0.0000], [0.6667, 0.6667]],\n]\n\n# should return Nan for all 
labels=0 case and skip for MeanDice\nTEST_CASE_3 = [\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]],\n                [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]],\n            ]\n        ),\n        \"include_background\": True,\n    },\n    [[False, True, True], [False, False, True]],\n]\n\nTEST_CASE_4 = [\n    {\"include_background\": True, \"reduction\": \"mean_batch\", \"get_not_nans\": True},\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n            ]\n        ),\n    },\n    [0.6786, 0.4000, 0.6667],\n]\n\nTEST_CASE_5 = [\n    {\"include_background\": True, \"reduction\": \"mean\", \"get_not_nans\": True},\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 
0.0], [1.0, 0.0]]],\n            ]\n        ),\n    },\n    0.689683,\n]\n\nTEST_CASE_6 = [\n    {\"include_background\": True, \"reduction\": \"sum_batch\", \"get_not_nans\": True},\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n            ]\n        ),\n    },\n    [1.7143, 0.0000, 0.0000],\n]\n\nTEST_CASE_7 = [\n    {\"include_background\": True, \"reduction\": \"mean\", \"get_not_nans\": True},\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n            ]\n        ),\n    },\n    0.857143,\n]\n\nTEST_CASE_8 = [\n    {\"include_background\": False, \"reduction\": \"sum_batch\", \"get_not_nans\": True},\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[1.0, 
1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n            ]\n        ),\n    },\n    [0.0000, 0.0000],\n]\n\nTEST_CASE_9 = [\n    {\"y\": torch.ones((2, 2, 3, 3)), \"y_pred\": torch.ones((2, 2, 3, 3))},\n    [[1.0000, 1.0000], [1.0000, 1.0000]],\n]\n\nTEST_CASE_10 = [\n    {\"y\": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))], \"y_pred\": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))]},\n    [[1.0000, 1.0000], [1.0000, 1.0000]],\n]\n\nTEST_CASE_11 = [\n    {\"y\": torch.zeros((2, 2, 3, 3)), \"y_pred\": torch.zeros((2, 2, 3, 3)), \"ignore_empty\": False},\n    [[1.0000, 1.0000], [1.0000, 1.0000]],\n]\n\nTEST_CASE_12 = [\n    {\"y\": torch.zeros((2, 2, 3, 3)), \"y_pred\": torch.ones((2, 2, 3, 3)), \"ignore_empty\": False},\n    [[0.0000, 0.0000], [0.0000, 0.0000]],\n]\n\n# test return_with_label\nTEST_CASE_13 = [\n    {\n        \"include_background\": True,\n        \"reduction\": \"mean_batch\",\n        \"get_not_nans\": True,\n        \"return_with_label\": [\"bg\", \"fg0\", \"fg1\"],\n    },\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n            ]\n        ),\n    },\n    {\"bg\": 0.6786, \"fg0\": 0.4000, \"fg1\": 0.6667},\n]\n\n# test return_with_label, include_background\nTEST_CASE_14 = [\n    {\"include_background\": True, \"reduction\": \"mean_batch\", \"get_not_nans\": True, \"return_with_label\": True},\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n        
        [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n            ]\n        ),\n    },\n    {\"label_0\": 0.6786, \"label_1\": 0.4000, \"label_2\": 0.6667},\n]\n\n# test return_with_label, not include_background\nTEST_CASE_15 = [\n    {\"include_background\": False, \"reduction\": \"mean_batch\", \"get_not_nans\": True, \"return_with_label\": True},\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n            ]\n        ),\n    },\n    {\"label_1\": 0.4000, \"label_2\": 0.6667},\n]\n\n\nclass TestComputeMeanDice(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])\n    def test_value(self, input_data, expected_value):\n        result = compute_dice(**input_data)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n        np.testing.assert_equal(result.device, input_data[\"y_pred\"].device)\n\n    @parameterized.expand([TEST_CASE_3])\n    def test_nans(self, input_data, expected_value):\n        result = compute_dice(**input_data)\n        self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value))\n\n    @parameterized.expand([TEST_CASE_3])\n    def test_helper(self, input_data, _unused):\n      
  vals = {\"y_pred\": dict(input_data).pop(\"y_pred\"), \"y\": dict(input_data).pop(\"y\")}\n        result = DiceHelper(threshold=True)(**vals)\n        np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)\n        np.testing.assert_allclose(sorted(result[1].cpu().numpy()), [0.0, 1.0, 2.0], atol=1e-4)\n        result = DiceHelper(apply_argmax=True, get_not_nans=False)(**vals)\n        np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0], atol=1e-4)\n\n        num_classes = vals[\"y_pred\"].shape[1]\n        vals[\"y_pred\"] = torch.argmax(vals[\"y_pred\"], dim=1, keepdim=True)\n        result = DiceHelper(threshold=True, num_classes=num_classes)(**vals)\n        np.testing.assert_allclose(result[0].cpu().numpy(), [0.0, 0.0, 0.0], atol=1e-4)\n\n    # DiceMetric class tests\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_10])\n    def test_value_class(self, input_data, expected_value):\n        # same test as for compute_dice\n        vals = {\"y_pred\": input_data.pop(\"y_pred\")}\n        vals[\"y\"] = input_data.pop(\"y\")\n        dice_metric = DiceMetric(**input_data)\n        dice_metric(**vals)\n        result = dice_metric.aggregate(reduction=\"none\")\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n    @parameterized.expand(\n        [TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_13, TEST_CASE_14, TEST_CASE_15]\n    )\n    def test_nans_class(self, params, input_data, expected_value):\n        dice_metric = DiceMetric(**params)\n        dice_metric(**input_data)\n        result, _ = dice_metric.aggregate()\n        if isinstance(result, dict):\n            self.assertEqual(result, expected_value)\n        else:\n            np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_meaniou.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import MeanIoU, compute_iou\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n# keep background\nTEST_CASE_1 = [  # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1)\n    {\n        \"y_pred\": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device),\n        \"y\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device),\n        \"include_background\": True,\n    },\n    [[0.6667]],\n]\n\n# remove background and not One-Hot target\nTEST_CASE_2 = [  # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 2) (no background)\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"include_background\": False,\n    },\n    [[0.3333, 0.0000], [0.5000, 0.5000]],\n]\n\n# should return Nan for all labels=0 
case and skip for MeanIoU\nTEST_CASE_3 = [\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]],\n                [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]],\n            ]\n        ),\n        \"include_background\": True,\n    },\n    [[False, True, True], [False, False, True]],\n]\n\nTEST_CASE_4 = [\n    {\"include_background\": True, \"reduction\": \"mean_batch\", \"get_not_nans\": True},\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],\n            ]\n        ),\n    },\n    [0.5416, 0.2500, 0.5000],\n]\n\nTEST_CASE_5 = [\n    {\"include_background\": True, \"reduction\": \"mean\", \"get_not_nans\": True},\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 
0.0]]],\n            ]\n        ),\n    },\n    0.5555,\n]\n\nTEST_CASE_6 = [\n    {\"include_background\": True, \"reduction\": \"sum_batch\", \"get_not_nans\": True},\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n            ]\n        ),\n    },\n    [1.5000, 0.0000, 0.0000],\n]\n\nTEST_CASE_7 = [\n    {\"include_background\": True, \"reduction\": \"mean\", \"get_not_nans\": True},\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n            ]\n        ),\n    },\n    0.7500,\n]\n\nTEST_CASE_8 = [\n    {\"include_background\": False, \"reduction\": \"sum_batch\", \"get_not_nans\": True},\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],\n                [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],\n            ]\n        ),\n        \"y\": torch.tensor(\n            [\n                [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                [[[1.0, 1.0], [1.0, 
1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n            ]\n        ),\n    },\n    [0.0000, 0.0000],\n]\n\nTEST_CASE_9 = [\n    {\"y\": torch.ones((2, 2, 3, 3)), \"y_pred\": torch.ones((2, 2, 3, 3))},\n    [[1.0000, 1.0000], [1.0000, 1.0000]],\n]\n\nTEST_CASE_10 = [\n    {\"y\": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))], \"y_pred\": [torch.ones((2, 3, 3)), torch.ones((2, 3, 3))]},\n    [[1.0000, 1.0000], [1.0000, 1.0000]],\n]\n\nTEST_CASE_11 = [\n    {\"y\": torch.zeros((2, 2, 3, 3)), \"y_pred\": torch.zeros((2, 2, 3, 3)), \"ignore_empty\": False},\n    [[1.0000, 1.0000], [1.0000, 1.0000]],\n]\n\nTEST_CASE_12 = [\n    {\"y\": torch.zeros((2, 2, 3, 3)), \"y_pred\": torch.ones((2, 2, 3, 3)), \"ignore_empty\": False},\n    [[0.0000, 0.0000], [0.0000, 0.0000]],\n]\n\n\nclass TestComputeMeanIoU(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])\n    def test_value(self, input_data, expected_value):\n        result = compute_iou(**input_data)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n        np.testing.assert_equal(result.device, input_data[\"y_pred\"].device)\n\n    @parameterized.expand([TEST_CASE_3])\n    def test_nans(self, input_data, expected_value):\n        result = compute_iou(**input_data)\n        self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value))\n\n    # MeanIoU class tests\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_10])\n    def test_value_class(self, input_data, expected_value):\n        # same test as for compute_iou\n        vals = {}\n        vals[\"y_pred\"] = input_data.pop(\"y_pred\")\n        vals[\"y\"] = input_data.pop(\"y\")\n        iou_metric = MeanIoU(**input_data)\n        iou_metric(**vals)\n        result = iou_metric.aggregate(reduction=\"none\")\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n    
@parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8])\n    def test_nans_class(self, params, input_data, expected_value):\n        iou_metric = MeanIoU(**params)\n        iou_metric(**input_data)\n        result, _ = iou_metric.aggregate()\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_mmd_metric.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import MMDMetric\n\nTEST_CASES = [\n    [{\"y_mapping\": None}, {\"y\": torch.ones([3, 3, 144, 144]), \"y_pred\": torch.ones([3, 3, 144, 144])}, 0.0],\n    [{\"y_mapping\": None}, {\"y\": torch.ones([3, 3, 144, 144, 144]), \"y_pred\": torch.ones([3, 3, 144, 144, 144])}, 0.0],\n    [\n        {\"y_mapping\": lambda x: x.square()},\n        {\"y\": torch.ones([3, 3, 144, 144]), \"y_pred\": torch.ones([3, 3, 144, 144])},\n        0.0,\n    ],\n    [\n        {\"y_mapping\": lambda x: x.square()},\n        {\"y\": torch.ones([3, 3, 144, 144, 144]), \"y_pred\": torch.ones([3, 3, 144, 144, 144])},\n        0.0,\n    ],\n]\n\n\nclass TestMMDMetric(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_results(self, input_param, input_data, expected_val):\n        metric = MMDMetric(**input_param)\n        results = metric(**input_data)\n        np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4)\n\n    def test_if_inputs_different_shapes(self):\n        with self.assertRaises(ValueError):\n            MMDMetric()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))\n\n    def test_if_inputs_have_one_sample(self):\n        with self.assertRaises(ValueError):\n            
MMDMetric()(torch.ones([1, 3, 144, 144]), torch.ones([1, 3, 144, 144]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_multiscalessim_metric.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\n\nfrom monai.metrics import MultiScaleSSIMMetric\nfrom monai.utils import set_determinism\n\n\nclass TestMultiScaleSSIMMetric(unittest.TestCase):\n\n    def test2d_gaussian(self):\n        set_determinism(0)\n        preds = torch.abs(torch.randn(1, 1, 64, 64))\n        target = torch.abs(torch.randn(1, 1, 64, 64))\n        preds = preds / preds.max()\n        target = target / target.max()\n\n        metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type=\"gaussian\", weights=[0.5, 0.5])\n        metric(preds, target)\n        result = metric.aggregate()\n        expected_value = 0.023176\n        self.assertAlmostEqual(expected_value, result.item(), 4)\n\n    def test2d_uniform(self):\n        set_determinism(0)\n        preds = torch.abs(torch.randn(1, 1, 64, 64))\n        target = torch.abs(torch.randn(1, 1, 64, 64))\n        preds = preds / preds.max()\n        target = target / target.max()\n\n        metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type=\"uniform\", weights=[0.5, 0.5])\n        metric(preds, target)\n        result = metric.aggregate()\n        expected_value = 0.022655\n        self.assertAlmostEqual(expected_value, result.item(), 4)\n\n    def test3d_gaussian(self):\n        set_determinism(0)\n        preds = torch.abs(torch.randn(1, 1, 
64, 64, 64))\n        target = torch.abs(torch.randn(1, 1, 64, 64, 64))\n        preds = preds / preds.max()\n        target = target / target.max()\n\n        metric = MultiScaleSSIMMetric(spatial_dims=3, data_range=1.0, kernel_type=\"gaussian\", weights=[0.5, 0.5])\n        metric(preds, target)\n        result = metric.aggregate()\n        expected_value = 0.061796\n        self.assertAlmostEqual(expected_value, result.item(), 4)\n\n    def input_ill_input_shape2d(self):\n        metric = MultiScaleSSIMMetric(spatial_dims=3, weights=[0.5, 0.5])\n\n        with self.assertRaises(ValueError):\n            metric(torch.randn(1, 1, 64, 64), torch.randn(1, 1, 64, 64))\n\n    def input_ill_input_shape3d(self):\n        metric = MultiScaleSSIMMetric(spatial_dims=2, weights=[0.5, 0.5])\n\n        with self.assertRaises(ValueError):\n            metric(torch.randn(1, 1, 64, 64, 64), torch.randn(1, 1, 64, 64, 64))\n\n    def small_inputs(self):\n        metric = MultiScaleSSIMMetric(spatial_dims=2)\n\n        with self.assertRaises(ValueError):\n            metric(torch.randn(1, 1, 16, 16, 16), torch.randn(1, 1, 16, 16, 16))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_panoptic_quality.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import PanopticQualityMetric, compute_panoptic_quality\nfrom monai.metrics.panoptic_quality import compute_mean_iou\nfrom tests.test_utils import SkipIfNoModule\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\n# TEST_FUNC_CASE related cases are used to test for single image with HW input shape\n\nsample_1 = torch.randint(low=0, high=5, size=(64, 64), device=_device)\nsample_2_pred = torch.as_tensor([[0, 1, 1, 1], [0, 0, 0, 0], [2, 0, 3, 3], [4, 2, 2, 0]], device=_device)\nsample_2_pred_need_remap = torch.as_tensor([[0, 7, 7, 7], [0, 0, 0, 0], [1, 0, 8, 8], [9, 1, 1, 0]], device=_device)\nsample_2_gt = torch.as_tensor([[1, 1, 2, 1], [0, 0, 0, 0], [1, 3, 0, 0], [4, 3, 3, 3]], device=_device)\n# if pred == gt, result should be 1\nTEST_FUNC_CASE_1 = [{\"pred\": sample_1, \"gt\": sample_1, \"match_iou_threshold\": 0.99}, 1.0]\n\n# test sample_2 when match_iou_threshold = 0.5\nTEST_FUNC_CASE_2 = [{\"pred\": sample_2_pred, \"gt\": sample_2_gt, \"match_iou_threshold\": 0.5}, 0.25]\n# test sample_2 when match_iou_threshold = 0.3, metric_name = \"sq\"\nTEST_FUNC_CASE_3 = [{\"pred\": sample_2_pred, \"gt\": sample_2_gt, \"metric_name\": \"sq\", \"match_iou_threshold\": 0.3}, 0.6]\n# test sample_2 when 
match_iou_threshold = 0.3, pred has different order, metric_name = \"RQ\"\nTEST_FUNC_CASE_4 = [\n    {\"pred\": sample_2_pred_need_remap, \"gt\": sample_2_gt, \"metric_name\": \"RQ\", \"match_iou_threshold\": 0.3},\n    0.75,\n]\n\n# TEST_CLS_CASE related cases are used to test the PanopticQualityMetric with B2HW input\nsample_3_pred = torch.as_tensor(\n    [\n        [[[2, 0, 1], [2, 1, 1], [0, 1, 1]], [[0, 1, 3], [0, 0, 0], [1, 2, 1]]],\n        [[[1, 1, 1], [3, 2, 0], [3, 2, 1]], [[1, 1, 3], [3, 1, 1], [0, 3, 0]]],\n    ],\n    device=_device,\n)\n\nsample_3_gt = torch.as_tensor(\n    [\n        [[[2, 0, 0], [2, 0, 0], [2, 2, 3]], [[3, 3, 3], [3, 2, 1], [2, 2, 3]]],\n        [[[1, 1, 1], [0, 0, 3], [0, 0, 3]], [[0, 1, 3], [2, 1, 0], [3, 0, 3]]],\n    ],\n    device=_device,\n)\n\n# test sample_3, num_classes = 3, match_iou_threshold = 0.5\nTEST_CLS_CASE_1 = [{\"num_classes\": 3, \"match_iou_threshold\": 0.5}, sample_3_pred, sample_3_gt, (0.0, 0.0, 0.25)]\n\n# test sample_3, num_classes = 3, match_iou_threshold = 0.3\nTEST_CLS_CASE_2 = [{\"num_classes\": 3, \"match_iou_threshold\": 0.3}, sample_3_pred, sample_3_gt, (0.25, 0.5, 0.25)]\n\n# test sample_3, num_classes = 4, match_iou_threshold = 0.3, metric_name = \"segmentation_quality\"\nTEST_CLS_CASE_3 = [\n    {\"num_classes\": 4, \"match_iou_threshold\": 0.3, \"metric_name\": \"segmentation_quality\"},\n    sample_3_pred,\n    sample_3_gt,\n    (0.5, 0.5, 1.0, 0.0),\n]\n\n# test sample_3, num_classes = 3, match_iou_threshold = 0.4, reduction = \"none\", metric_name = \"Recognition Quality\"\nTEST_CLS_CASE_4 = [\n    {\"num_classes\": 3, \"reduction\": \"none\", \"match_iou_threshold\": 0.4, \"metric_name\": \"Recognition Quality\"},\n    sample_3_pred,\n    sample_3_gt,\n    [[0.0, 1.0, 0.0], [0.6667, 0.0, 0.4]],\n]\n\n# test sample_3, num_classes = 3, match_iou_threshold = 0.4, reduction = \"none\", multiple metrics\nTEST_CLS_CASE_5 = [\n    {\"num_classes\": 3, \"reduction\": \"none\", \"match_iou_threshold\": 
0.4, \"metric_name\": [\"Recognition Quality\", \"pq\"]},\n    sample_3_pred,\n    sample_3_gt,\n    [torch.as_tensor([[0.0, 1.0, 0.0], [0.6667, 0.0, 0.4]]), torch.as_tensor([[0.0, 0.5, 0.0], [0.3333, 0.0, 0.4]])],\n]\n\n# 3D test cases\nsample_3d_pred = torch.as_tensor(\n    [[[[[2, 0], [1, 1]], [[0, 1], [2, 1]]], [[[0, 1], [3, 0]], [[1, 0], [1, 1]]]]],  # instance channel  # class channel\n    device=_device,\n)\n\nsample_3d_gt = torch.as_tensor(\n    [[[[[2, 0], [0, 0]], [[2, 2], [2, 3]]], [[[3, 3], [3, 2]], [[2, 2], [3, 3]]]]],  # instance channel  # class channel\n    device=_device,\n)\n\n# test 3D sample, num_classes = 3, match_iou_threshold = 0.5\nTEST_3D_CASE_1 = [{\"num_classes\": 3, \"match_iou_threshold\": 0.5}, sample_3d_pred, sample_3d_gt]\n\n# test confusion matrix return\nTEST_CM_CASE_1 = [\n    {\"num_classes\": 3, \"match_iou_threshold\": 0.5, \"return_confusion_matrix\": True},\n    sample_3_pred,\n    sample_3_gt,\n]\n\n\n@SkipIfNoModule(\"scipy.optimize\")\nclass TestPanopticQualityMetric(unittest.TestCase):\n    @parameterized.expand([TEST_FUNC_CASE_1, TEST_FUNC_CASE_2, TEST_FUNC_CASE_3, TEST_FUNC_CASE_4])\n    def test_value(self, input_params, expected_value):\n        result = compute_panoptic_quality(**input_params)\n        np.testing.assert_allclose(result.cpu().detach().item(), expected_value, atol=1e-4)\n        np.testing.assert_equal(result.device, input_params[\"pred\"].device)\n\n    @parameterized.expand([TEST_CLS_CASE_1, TEST_CLS_CASE_2, TEST_CLS_CASE_3, TEST_CLS_CASE_4, TEST_CLS_CASE_5])\n    def test_value_class(self, input_params, y_pred, y_gt, expected_value):\n        metric = PanopticQualityMetric(**input_params)\n        metric(y_pred, y_gt)\n        outputs = metric.aggregate()\n        if isinstance(outputs, list):\n            for output, value in zip(outputs, expected_value):\n                np.testing.assert_allclose(output.cpu().numpy(), np.asarray(value), atol=1e-4)\n        else:\n            
np.testing.assert_allclose(outputs.cpu().numpy(), np.asarray(expected_value), atol=1e-4)\n\n    def test_3d_support(self):\n        \"\"\"Test that 3D input is properly supported.\"\"\"\n        input_params, y_pred, y_gt = TEST_3D_CASE_1\n        metric = PanopticQualityMetric(**input_params)\n        # Should not raise an error for 3D input\n        metric(y_pred, y_gt)\n        outputs = metric.aggregate()\n        # Check that output is a tensor\n        self.assertIsInstance(outputs, torch.Tensor)\n        # Check that output shape is correct (num_classes,)\n        self.assertEqual(outputs.shape, torch.Size([3]))\n\n    def test_confusion_matrix_return(self):\n        \"\"\"Test that confusion matrix can be returned instead of computed metrics.\"\"\"\n        input_params, y_pred, y_gt = TEST_CM_CASE_1\n        metric = PanopticQualityMetric(**input_params)\n        metric(y_pred, y_gt)\n        outputs = metric.aggregate()\n        # Check that output is a tensor with shape (batch_size, num_classes, 4)\n        self.assertIsInstance(outputs, torch.Tensor)\n        self.assertEqual(outputs.shape[-1], 4)\n        # Verify that values correspond to [tp, fp, fn, iou_sum]\n        tp, fp, fn, iou_sum = outputs[..., 0], outputs[..., 1], outputs[..., 2], outputs[..., 3]\n        # tp, fp, fn should be non-negative integers\n        self.assertTrue(torch.all(tp >= 0))\n        self.assertTrue(torch.all(fp >= 0))\n        self.assertTrue(torch.all(fn >= 0))\n        # iou_sum should be non-negative float\n        self.assertTrue(torch.all(iou_sum >= 0))\n\n    def test_compute_mean_iou(self):\n        \"\"\"Test mean IoU computation from confusion matrix.\"\"\"\n        input_params, y_pred, y_gt = TEST_CM_CASE_1\n        metric = PanopticQualityMetric(**input_params)\n        metric(y_pred, y_gt)\n        confusion_matrix = metric.aggregate()\n        mean_iou = compute_mean_iou(confusion_matrix)\n\n        # Check shape is correct\n        
self.assertEqual(mean_iou.shape, confusion_matrix.shape[:-1])\n\n        # Check values are non-negative\n        self.assertTrue(torch.all(mean_iou >= 0))\n\n        # Validate against expected values\n        # mean_iou = iou_sum / (tp + smooth_numerator)\n        tp = confusion_matrix[..., 0]\n        iou_sum = confusion_matrix[..., 3]\n        expected_mean_iou = iou_sum / (tp + 1e-6)  # smooth_numerator=1e-6 is default\n        np.testing.assert_allclose(mean_iou.cpu().numpy(), expected_mean_iou.cpu().numpy(), atol=1e-4)\n\n    def test_metric_name_filtering(self):\n        \"\"\"Test that metric_name parameter properly filters output.\"\"\"\n        # Test single metric \"sq\"\n        metric_sq = PanopticQualityMetric(num_classes=3, metric_name=\"sq\", match_iou_threshold=0.5)\n        metric_sq(sample_3_pred, sample_3_gt)\n        result_sq = metric_sq.aggregate()\n        self.assertIsInstance(result_sq, torch.Tensor)\n        self.assertEqual(result_sq.shape, torch.Size([3]))\n\n        # Test single metric \"rq\"\n        metric_rq = PanopticQualityMetric(num_classes=3, metric_name=\"rq\", match_iou_threshold=0.5)\n        metric_rq(sample_3_pred, sample_3_gt)\n        result_rq = metric_rq.aggregate()\n        self.assertIsInstance(result_rq, torch.Tensor)\n        self.assertEqual(result_rq.shape, torch.Size([3]))\n\n        # Results should be different for different metrics\n        self.assertFalse(torch.allclose(result_sq, result_rq, atol=1e-4))\n\n    def test_invalid_3d_shape(self):\n        \"\"\"Test that invalid 3D shapes are rejected.\"\"\"\n        # Shape with 3 dimensions should fail\n        invalid_pred = torch.randint(0, 5, (2, 2, 10))\n        invalid_gt = torch.randint(0, 5, (2, 2, 10))\n        metric = PanopticQualityMetric(num_classes=3)\n        with self.assertRaises(ValueError):\n            metric(invalid_pred, invalid_gt)\n\n        # Shape with 6 dimensions should fail\n        invalid_pred = torch.randint(0, 5, (1, 2, 8, 8, 
8, 8))\n        invalid_gt = torch.randint(0, 5, (1, 2, 8, 8, 8, 8))\n        with self.assertRaises(ValueError):\n            metric(invalid_pred, invalid_gt)\n\n    def test_compute_mean_iou_invalid_shape(self):\n        \"\"\"Test that compute_mean_iou raises ValueError for invalid shapes.\"\"\"\n        from monai.metrics.panoptic_quality import compute_mean_iou\n\n        # Shape (..., 3) instead of (..., 4) should fail\n        invalid_confusion_matrix = torch.zeros(3, 3)\n        with self.assertRaises(ValueError):\n            compute_mean_iou(invalid_confusion_matrix)\n\n        # Shape (..., 5) should also fail\n        invalid_confusion_matrix = torch.zeros(2, 5)\n        with self.assertRaises(ValueError):\n            compute_mean_iou(invalid_confusion_matrix)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_regression_metrics.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom functools import partial\n\nimport numpy as np\nimport torch\n\nfrom monai.metrics import MAEMetric, MAPEMetric, MSEMetric, PSNRMetric, RMSEMetric\nfrom monai.utils import set_determinism\n\n\n# define a numpy flatten function that only preserves batch dimension\ndef flatten(data):\n    return np.reshape(data, [data.shape[0], -1])\n\n\n# define metrics computation truth functions to check our monai metrics against\ndef msemetric_np(y_pred, y):\n    return np.mean((flatten(y_pred) - flatten(y)) ** 2)\n\n\ndef maemetric_np(y_pred, y):\n    return np.mean(np.abs(flatten(y_pred) - flatten(y)))\n\n\ndef rmsemetric_np(y_pred, y):\n    return np.mean(np.sqrt(np.mean((flatten(y_pred) - flatten(y)) ** 2, axis=1)))\n\n\ndef psnrmetric_np(max_val, y_pred, y):\n    mse = np.mean((flatten(y_pred) - flatten(y)) ** 2, axis=1)\n    return np.mean(20 * np.log10(max_val) - 10 * np.log10(mse))\n\n\ndef mapemetric_np(y_pred, y, epsilon=1e-7):\n    percentage_error = np.abs(y - y_pred) / np.clip(np.abs(y), a_min=epsilon, a_max=None) * 100.0\n    return np.mean(flatten(percentage_error))\n\n\nclass TestRegressionMetrics(unittest.TestCase):\n\n    def test_shape_reduction(self):\n        set_determinism(seed=123)\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        # regression metrics to check\n        
metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]\n\n        # define variations in batch/base_dims/spatial_dims\n        batch_dims = [1, 2, 4, 16]\n        base_dims = [16, 32, 64]\n        spatial_dims = [2, 3, 4]\n\n        # iterate over all variations and check shapes for different reduction functions\n        for batch in batch_dims:\n            for spatial in spatial_dims:\n                for base in base_dims:\n                    # create random tensors\n                    in_tensor = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)\n\n                    # iterate over regression metrics, check shape for diff. reduction func\n                    for mt_fn in metrics:\n                        mt = mt_fn(reduction=\"mean\")\n                        mt(in_tensor, in_tensor)\n                        out_tensor = mt.aggregate()\n                        self.assertEqual(len(out_tensor.shape), 1)\n\n                        mt = mt_fn(reduction=\"sum\")\n                        mt(in_tensor, in_tensor)\n                        out_tensor = mt.aggregate()\n                        self.assertEqual(len(out_tensor.shape), 0)\n\n                        mt = mt_fn(reduction=\"sum\")  # test reduction arg overriding\n                        mt(in_tensor, in_tensor)\n                        out_tensor = mt.aggregate(reduction=\"mean_channel\")\n                        self.assertEqual(len(out_tensor.shape), 1)\n                        self.assertEqual(out_tensor.shape[0], batch)\n\n                        mt = mt_fn(reduction=\"sum_channel\")\n                        mt(in_tensor, in_tensor)\n                        out_tensor = mt.aggregate()\n                        self.assertEqual(len(out_tensor.shape), 1)\n                        self.assertEqual(out_tensor.shape[0], batch)\n\n    def test_compare_numpy(self):\n        set_determinism(seed=123)\n        device = \"cuda\" if torch.cuda.is_available() else 
\"cpu\"\n\n        # regression metrics to check + truth metric function in numpy\n        metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]\n        metrics_np = [msemetric_np, maemetric_np, mapemetric_np, rmsemetric_np, partial(psnrmetric_np, max_val=1.0)]\n\n        # define variations in batch/base_dims/spatial_dims\n        batch_dims = [1, 2, 4, 16]\n        base_dims = [16, 32, 64]\n        spatial_dims = [2, 3, 4]\n\n        # iterate over all variations and check shapes for different reduction functions\n        for batch in batch_dims:\n            for spatial in spatial_dims:\n                for base in base_dims:\n                    # create random tensors\n                    in_tensor_a = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)\n                    in_tensor_b = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)\n\n                    # check metrics\n                    for mt_fn, mt_fn_np in zip(metrics, metrics_np):\n                        mt = mt_fn()\n                        mt(y_pred=in_tensor_a, y=in_tensor_b)\n                        out_tensor = mt.aggregate(reduction=\"mean\")\n                        out_np = mt_fn_np(y_pred=in_tensor_a.cpu().numpy(), y=in_tensor_b.cpu().numpy())\n\n                        np.testing.assert_allclose(out_tensor.cpu().numpy(), out_np, atol=1e-3, rtol=1e-4)\n\n    def test_ill_shape(self):\n        set_determinism(seed=123)\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        # regression metrics to check + truth metric function in numpy\n        metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]\n        basedim = 10\n\n        # too small shape\n        with self.assertRaises(ValueError):\n            in_tensor = torch.rand((basedim,)).to(device)\n            for mt_fn in metrics:\n                mt_fn()(in_tensor, in_tensor)\n\n        # different shape for pred/target\n   
     with self.assertRaises(ValueError):\n            in_tensor_a = torch.rand((basedim,)).to(device)\n            in_tensor_b = torch.rand((basedim, basedim)).to(device)\n            for mt_fn in metrics:\n                mt_fn()(y_pred=in_tensor_a, y=in_tensor_b)\n\n    def test_same_input(self):\n        set_determinism(seed=123)\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]\n        results = [0.0, 0.0, 0.0, 0.0, float(\"inf\")]\n\n        # define variations in batch/base_dims/spatial_dims\n        batch_dims = [1, 2, 4, 16]\n        base_dims = [16, 32, 64]\n        spatial_dims = [2, 3, 4]\n\n        # iterate over all variations and check shapes for different reduction functions\n        for batch in batch_dims:\n            for spatial in spatial_dims:\n                for base in base_dims:\n                    # create random tensors\n                    in_tensor = torch.rand((batch,) + (base,) * (spatial - 1)).to(device)\n\n                    # check metrics\n                    for mt_fn, rs in zip(metrics, results):\n                        mt = mt_fn(reduction=\"mean\")\n                        mt(in_tensor, in_tensor)\n                        out_tensor = mt.aggregate()\n                        np.testing.assert_allclose(out_tensor.cpu(), rs, atol=1e-4)\n\n    def test_diff_input(self):\n        set_determinism(seed=123)\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)]\n        results = [1.0, 1.0, 100.0, 1.0, 0.0]\n\n        # define variations in batch/base_dims/spatial_dims\n        batch_dims = [1, 2, 4, 16]\n        base_dims = [16, 32, 64]\n        spatial_dims = [2, 3, 4]\n\n        # iterate over all variations and check shapes for different reduction functions\n        for batch in batch_dims:\n      
      for spatial in spatial_dims:\n                for base in base_dims:\n                    # create random tensors\n                    in_tensor_a = torch.zeros((batch,) + (base,) * (spatial - 1)).to(device)\n                    in_tensor_b = torch.ones((batch,) + (base,) * (spatial - 1)).to(device)\n\n                    # check metrics\n                    for mt_fn, rs in zip(metrics, results):\n                        mt = mt_fn(reduction=\"mean\")\n                        mt(in_tensor_a, in_tensor_b)\n                        out_tensor = mt.aggregate()\n                        np.testing.assert_allclose(out_tensor.cpu(), rs, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_roc_auc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import decollate_batch\nfrom monai.metrics import ROCAUCMetric, compute_roc_auc\nfrom monai.transforms import Activations, AsDiscrete, Compose, ToTensor\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\nTEST_CASE_1 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device),\n    torch.tensor([[0], [1], [0], [1]], device=_device),\n    True,\n    2,\n    \"macro\",\n    0.75,\n]\n\nTEST_CASE_2 = [\n    torch.tensor([[0.5], [0.5], [0.2], [8.3]]),\n    torch.tensor([[0], [1], [0], [1]]),\n    False,\n    None,\n    \"macro\",\n    0.875,\n]\n\nTEST_CASE_3 = [torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([0, 1, 0, 1]), False, None, \"macro\", 0.875]\n\nTEST_CASE_4 = [torch.tensor([0.5, 0.5, 0.2, 8.3]), torch.tensor([0, 1, 0, 1]), False, None, \"macro\", 0.875]\n\nTEST_CASE_5 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),\n    torch.tensor([[0], [1], [0], [1]]),\n    True,\n    2,\n    \"none\",\n    [0.75, 0.75],\n]\n\nTEST_CASE_6 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),\n    torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),\n    True,\n    None,\n    \"weighted\",\n    0.56667,\n]\n\nTEST_CASE_7 = 
[\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),\n    torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),\n    True,\n    None,\n    \"micro\",\n    0.62,\n]\n\nTEST_CASE_8 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),\n    torch.tensor([[0], [0], [0], [0]]),\n    True,\n    2,\n    \"macro\",\n    float(\"nan\"),\n]\n\nTEST_CASE_9 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),\n    torch.tensor([[1], [1], [1], [1]]),\n    True,\n    2,\n    \"macro\",\n    float(\"nan\"),\n]\n\nTEST_CASE_10 = [\n    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),\n    torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]]),\n    True,\n    None,\n    \"macro\",\n    float(\"nan\"),\n]\n\n\nclass TestComputeROCAUC(unittest.TestCase):\n\n    @parameterized.expand(\n        [\n            TEST_CASE_1,\n            TEST_CASE_2,\n            TEST_CASE_3,\n            TEST_CASE_4,\n            TEST_CASE_5,\n            TEST_CASE_6,\n            TEST_CASE_7,\n            TEST_CASE_8,\n            TEST_CASE_9,\n            TEST_CASE_10,\n        ]\n    )\n    def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value):\n        y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)])\n        y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)])\n        y_pred = torch.stack([y_pred_trans(i) for i in decollate_batch(y_pred)], dim=0)\n        y = torch.stack([y_trans(i) for i in decollate_batch(y)], dim=0)\n        result = compute_roc_auc(y_pred=y_pred, y=y, average=average)\n        np.testing.assert_allclose(expected_value, result, rtol=1e-5)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_1,\n            TEST_CASE_2,\n            TEST_CASE_3,\n            TEST_CASE_4,\n            TEST_CASE_5,\n            TEST_CASE_6,\n            TEST_CASE_7,\n            TEST_CASE_8,\n            TEST_CASE_9,\n            TEST_CASE_10,\n        ]\n   
 )\n    def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value):\n        y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)])\n        y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)])\n        y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)]\n        y = [y_trans(i) for i in decollate_batch(y)]\n        metric = ROCAUCMetric(average=average)\n        metric(y_pred=y_pred, y=y)\n        result = metric.aggregate()\n        np.testing.assert_allclose(expected_value, result, rtol=1e-5)\n        result = metric.aggregate(average=average)  # test optional argument\n        metric.reset()\n        np.testing.assert_allclose(expected_value, result, rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_compute_variance.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import VarianceMetric, compute_variance\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\n# keep background, 1D Case\nTEST_CASE_1 = [  # y_pred (3, 1, 3), expected out (0.0)\n    {\n        \"y_pred\": torch.tensor([[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]], device=_device),\n        \"include_background\": True,\n        \"spatial_map\": False,\n    },\n    [[0.0]],\n]\n\n# keep background, 2D Case\nTEST_CASE_2 = [  # y_pred (1, 1, 2, 2), expected out (0.0)\n    {\n        \"y_pred\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),\n        \"include_background\": True,\n        \"spatial_map\": False,\n    },\n    [[0.0]],\n]\n\n# keep background, 3D Case\nTEST_CASE_3 = [  # y_pred (1, 1, 1, 2, 2), expected out (0.0)\n    {\n        \"y_pred\": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]]]]], device=_device),\n        \"include_background\": True,\n        \"spatial_map\": False,\n    },\n    [[0.0]],\n]\n\n# remove background, 1D Case\nTEST_CASE_4 = [  # y_pred (3, 1, 3), expected out (0.0)\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]],\n                [[4.0, 5.0, 6.0], [1.0, 1.0, 1.0]],\n                [[7.0, 
8.0, 9.0], [1.0, 1.0, 1.0]],\n            ],\n            device=_device,\n        ),\n        \"include_background\": False,\n        \"spatial_map\": False,\n    },\n    [[0.0]],\n]\n\n# Spatial Map Test Case for 2D Case\nTEST_CASE_5 = [  # y_pred (1, 1, 2, 2), expected out all (0.0) map of 2x2\n    {\n        \"y_pred\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),\n        \"include_background\": True,\n        \"spatial_map\": True,\n    },\n    [[0.0, 0.0], [0.0, 0.0]],\n]\n\n# Spatial Map Test Case for 3D Case\nTEST_CASE_6 = [  # y_pred (1, 1, 2, 2, 2), expected out all (0.0) map of 2x2x2\n    {\n        \"y_pred\": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]], device=_device),\n        \"include_background\": True,\n        \"spatial_map\": True,\n    },\n    [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n]\n\n# Threshold test for a 1D Case\nTEST_CASE_7 = [  # y_pred (3, 1, 3), expected out (0.0)\n    {\n        \"y_pred\": torch.tensor(\n            [\n                [[1.0, 2.0, 3.0], [1.0, 1.0, 0.0]],\n                [[4.0, 5.0, 6.0], [1.0, 1.0, 1.0]],\n                [[7.0, 8.0, 9.0], [1.0, 1.0, 0.0]],\n                [[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]],\n            ],\n            device=_device,\n        ),\n        \"include_background\": False,\n        \"spatial_map\": False,\n        \"threshold\": 0.001,\n    },\n    [[0.083167]],\n]\n\n\nclass TestComputeVariance(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_value(self, input_data, expected_value):\n        result = compute_variance(**input_data)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n        np.testing.assert_equal(result.device, input_data[\"y_pred\"].device)\n\n    @parameterized.expand([TEST_CASE_5, TEST_CASE_6])\n    def test_spatial_case(self, input_data, expected_value):\n        result = compute_variance(**input_data)\n   
     np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n    @parameterized.expand([TEST_CASE_7])\n    def test_threshold_case(self, input_data, expected_value):\n        result = compute_variance(**input_data)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_value_class(self, input_data, expected_value):\n        vals = {}\n        vals[\"y_pred\"] = input_data.pop(\"y_pred\")\n        comp_var = VarianceMetric(**input_data)\n        result = comp_var(**vals)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n    @parameterized.expand([TEST_CASE_5, TEST_CASE_6])\n    def test_spatial_case_class(self, input_data, expected_value):\n        vals = {}\n        vals[\"y_pred\"] = input_data.pop(\"y_pred\")\n        comp_var = VarianceMetric(**input_data)\n        result = comp_var(**vals)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_cumulative.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\n\nfrom monai.metrics import Cumulative\nfrom tests.test_utils import assert_allclose\n\n\nclass TestCumulative(unittest.TestCase):\n    def test_single(self):\n        c = Cumulative()\n        c.extend([2, 3])\n        c.append(1)\n        assert_allclose(c.get_buffer(), torch.tensor([2, 3, 1]))\n\n    def test_multi(self):\n        c = Cumulative()\n        c.extend([2, 3], [4, 6])\n        c.append(1)\n        assert_allclose(c.get_buffer()[0], torch.tensor([2, 3, 1]))\n        assert_allclose(c.get_buffer()[1], torch.tensor([4, 6]))\n\n        c.reset()\n        c.append()\n        c.extend()\n        self.assertEqual(c.get_buffer(), [])\n        c.get_buffer().append(1)\n        self.assertEqual(c.get_buffer(), [])  # no in-place update for the buffer\n\n        c.reset()\n\n    def test_ill(self):\n        c = Cumulative()\n        with self.assertRaises(TypeError):\n            c.extend(None)\n        with self.assertRaises(TypeError):\n            c.extend([])\n        with self.assertRaises(TypeError):\n            c.extend(1)\n        with self.assertRaises(TypeError):\n            c.append([])\n            c.append([1, 2])\n            c.get_buffer()\n        with self.assertRaises(TypeError):\n            c.append(None)\n            c.get_buffer()\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_cumulative_average.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import CumulativeAverage\n\nTEST_CASE_1 = []\nTEST_CASE_1.append([{\"vals\": [1, 2, 3], \"avg\": 2}])\nTEST_CASE_1.append([{\"vals\": [[1, 1, 1], [2, 2, 2], [3, 6, 9]], \"avg\": [2, 3, 4]}])\n\nTEST_CASE_1.append([{\"vals\": [2, 4, 6], \"counts\": [2, 1, 2], \"avg\": 4}])\nTEST_CASE_1.append(\n    [{\"vals\": [[3, 2, 1], [2, 3, 2], [0, 0, 9]], \"counts\": [[4, 4, 4], [4, 4, 4], [2, 2, 2]], \"avg\": [2, 2, 3]}]\n)\n\nTEST_CASE_1.append([{\"vals\": [1, 2, float(\"nan\")], \"avg\": 1.5}])\n\n\nclass TestAverageMeter(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASE_1)\n    def test_value_all(self, data):\n        # test orig\n        self.run_test(data)\n\n        # test in numpy\n        data[\"vals\"] = np.array(data[\"vals\"])\n        data[\"avg\"] = np.array(data[\"avg\"])\n        self.run_test(data)\n\n        # test as Tensors\n        data[\"vals\"] = torch.tensor(data[\"vals\"])\n        data[\"avg\"] = torch.tensor(data[\"avg\"], dtype=torch.float)\n        self.run_test(data)\n\n        if torch.cuda.is_available():\n            data[\"vals\"] = data[\"vals\"].cuda()\n            self.run_test(data)\n\n    def run_test(self, data):\n        vals = data[\"vals\"]\n        avg = data[\"avg\"]\n\n        
counts = data.get(\"counts\", None)\n        if counts is not None and not isinstance(counts, list) and isinstance(vals, list):\n            counts = [counts] * len(vals)\n\n        avg_meter = CumulativeAverage()\n        for i in range(len(vals)):\n            if counts is not None:\n                avg_meter.append(vals[i], counts[i])\n            else:\n                avg_meter.append(vals[i])\n\n        np.testing.assert_equal(avg_meter.aggregate(), avg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_cumulative_average_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom monai.metrics import CumulativeAverage\nfrom tests.test_utils import DistCall, DistTestCase\n\n\nclass DistributedCumulativeAverage(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_value(self):\n        rank = dist.get_rank()\n        nprocs = dist.get_world_size()\n        is_cuda = dist.get_backend() == dist.Backend.NCCL\n        if is_cuda:\n            torch.cuda.set_device(rank)\n\n        device = torch.device(rank) if is_cuda else torch.device(\"cpu\")\n\n        avg_meter = CumulativeAverage()  # each process rank has its own AverageMeter\n        n_iter = 10\n        for i in range(n_iter):\n            val = torch.as_tensor(rank + i, device=device)\n            avg_meter.append(val=val)\n\n        avg_val = avg_meter.aggregate()  # average across all processes\n        expected_val = sum(sum(list(range(rank_i, rank_i + n_iter))) for rank_i in range(nprocs)) / (n_iter * nprocs)\n        np.testing.assert_equal(avg_val, expected_val)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_hausdorff_distance.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom itertools import product\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import HausdorffDistanceMetric\n\n_devices = [\"cpu\"]\nif torch.cuda.is_available():\n    _devices.append(\"cuda\")\n\n\ndef create_spherical_seg_3d(\n    radius: float = 20.0,\n    centre: tuple[int, int, int] = (49, 49, 49),\n    im_shape: tuple[int, int, int] = (99, 99, 99),\n    im_spacing: tuple[float, float, float] = (1.0, 1.0, 1.0),\n) -> np.ndarray:\n    \"\"\"\n    Return a 3D image with a sphere inside. 
Voxel values will be\n    1 inside the sphere, and 0 elsewhere.\n\n    Args:\n        radius: radius of sphere (in terms of number of voxels, can be partial)\n        centre: location of sphere centre.\n        im_shape: shape of image to create.\n        im_spacing: spacing of image to create.\n\n    See also:\n        :py:meth:`~create_test_image_3d`\n    \"\"\"\n    # Create image\n    image = np.zeros(im_shape, dtype=np.int32)\n    spy, spx, spz = np.ogrid[: im_shape[0], : im_shape[1], : im_shape[2]]\n    spy = spy.astype(float) * im_spacing[0]\n    spx = spx.astype(float) * im_spacing[1]\n    spz = spz.astype(float) * im_spacing[2]\n\n    spy -= centre[0]\n    spx -= centre[1]\n    spz -= centre[2]\n\n    circle = (spx * spx + spy * spy + spz * spz) <= radius * radius\n\n    image[circle] = 1\n    image[~circle] = 0\n    return image\n\n\ntest_spacing = (0.85, 1.2, 0.9)\nTEST_CASES = [\n    [[create_spherical_seg_3d(), create_spherical_seg_3d(), None, 1], [0, 0, 0, 0, 0, 0]],\n    [\n        [\n            create_spherical_seg_3d(radius=20, centre=(20, 20, 20)),\n            create_spherical_seg_3d(radius=20, centre=(19, 19, 19)),\n            None,\n        ],\n        [1.7320508075688772, 1.7320508075688772, 1, 1, 3, 3],\n    ],\n    [\n        [\n            create_spherical_seg_3d(radius=33, centre=(19, 33, 22)),\n            create_spherical_seg_3d(radius=33, centre=(20, 33, 22)),\n            None,\n        ],\n        [1, 1, 1, 1, 1, 1],\n    ],\n    [\n        [\n            create_spherical_seg_3d(radius=20, centre=(20, 33, 22)),\n            create_spherical_seg_3d(radius=40, centre=(20, 33, 22)),\n            None,\n        ],\n        [20.09975124224178, 20.223748416156685, 15, 20, 24, 35],\n    ],\n    [\n        [\n            # pred does not have foreground (but gt has), the metric should be inf\n            np.zeros([99, 99, 99]),\n            create_spherical_seg_3d(radius=40, centre=(20, 33, 22)),\n            None,\n        ],\n        
[np.inf, np.inf, np.inf, np.inf, np.inf, np.inf],\n    ],\n    [\n        [\n            # gt does not have foreground (but pred has), the metric should be inf\n            create_spherical_seg_3d(),\n            np.zeros([99, 99, 99]),\n            None,\n        ],\n        [np.inf, np.inf, np.inf, np.inf, np.inf, np.inf],\n    ],\n    [\n        [\n            create_spherical_seg_3d(radius=20, centre=(20, 33, 22)),\n            create_spherical_seg_3d(radius=40, centre=(20, 33, 22)),\n            None,\n            95,\n        ],\n        [19.924858845171276, 20.09975124224178, 14, 18, 22, 33],\n    ],\n    [\n        [\n            create_spherical_seg_3d(radius=20, centre=(20, 20, 20), im_spacing=test_spacing),\n            create_spherical_seg_3d(radius=20, centre=(19, 19, 19), im_spacing=test_spacing),\n            test_spacing,\n        ],\n        [2.0808651447296143, 2.2671568, 2, 2, 3, 4],\n    ],\n    [\n        [\n            create_spherical_seg_3d(radius=15, centre=(20, 33, 22), im_spacing=test_spacing),\n            create_spherical_seg_3d(radius=30, centre=(20, 33, 22), im_spacing=test_spacing),\n            test_spacing,\n        ],\n        [15.439640998840332, 15.62594, 11, 17, 20, 28],\n    ],\n]\n\nTEST_CASES_NANS = [\n    [\n        [\n            # both pred and gt do not have foreground, spacing is None, metric and not_nans should be 0\n            np.zeros([99, 99, 99]),\n            np.zeros([99, 99, 99]),\n            None,\n        ]\n    ],\n    [\n        [\n            # both pred and gt do not have foreground, metric and not_nans should be 0\n            np.zeros([99, 99, 99]),\n            np.zeros([99, 99, 99]),\n            test_spacing,\n        ]\n    ],\n]\n\nTEST_CASES_EXPANDED = []\nfor test_case in TEST_CASES:\n    test_output: list[float | int]\n    test_input, test_output = test_case  # type: ignore\n    for _device in _devices:\n        for i, (metric, directed) in enumerate(product([\"euclidean\", \"chessboard\", 
\"taxicab\"], [True, False])):\n            TEST_CASES_EXPANDED.append((_device, metric, directed, test_input, test_output[i]))\n\n\ndef _describe_test_case(test_func, test_number, params):\n    _device, metric, directed, test_input, test_output = params.args\n    return f\"device: {_device} metric: {metric} directed:{directed} expected: {test_output}\"\n\n\nclass TestHausdorffDistance(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_EXPANDED, doc_func=_describe_test_case)\n    def test_value(self, device, metric, directed, input_data, expected_value):\n        percentile = None\n        if len(input_data) == 4:\n            [seg_1, seg_2, spacing, percentile] = input_data\n        else:\n            [seg_1, seg_2, spacing] = input_data\n\n        seg_1 = torch.tensor(seg_1, device=device)\n        seg_2 = torch.tensor(seg_2, device=device)\n        hd_metric = HausdorffDistanceMetric(\n            include_background=False, distance_metric=metric, percentile=percentile, directed=directed\n        )\n        # shape of seg_1, seg_2 are: HWD, converts to BNHWD\n        batch, n_class = 2, 3\n        batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1])\n        batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1])\n        hd_metric(batch_seg_1, batch_seg_2, spacing=spacing)\n        result: torch.Tensor = hd_metric.aggregate(reduction=\"mean\")  # type: ignore\n        np.testing.assert_allclose(expected_value, result.cpu(), rtol=1e-6)\n        np.testing.assert_equal(result.device, seg_1.device)\n\n    @parameterized.expand(TEST_CASES_NANS)\n    def test_nans(self, input_data):\n        [seg_1, seg_2, spacing] = input_data\n        seg_1 = torch.tensor(seg_1)\n        seg_2 = torch.tensor(seg_2)\n        hd_metric = HausdorffDistanceMetric(include_background=False, get_not_nans=True)\n        batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0)\n        batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0)\n        
hd_metric(batch_seg_1, batch_seg_2, spacing=spacing)\n        result, not_nans = hd_metric.aggregate()\n        np.testing.assert_allclose(0, result, rtol=1e-7)\n        np.testing.assert_allclose(0, not_nans, rtol=1e-7)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_label_quality_score.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import LabelQualityScore, label_quality_score\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\n# keep background, 1D Case\nTEST_CASE_1 = [  # y_pred (3, 1, 3), expected out (0.0)\n    {\n        \"y_pred\": torch.tensor([[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]], device=_device),\n        \"y\": torch.tensor([[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]], device=_device),\n        \"include_background\": True,\n        \"scalar_reduction\": \"sum\",\n    },\n    [0.0, 0.0, 0.0],\n]\n\n# keep background, 2D Case\nTEST_CASE_2 = [  # y_pred (1, 1, 2, 2), expected out (0.0)\n    {\n        \"y_pred\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),\n        \"y\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),\n        \"include_background\": True,\n        \"scalar_reduction\": \"sum\",\n    },\n    [0.0],\n]\n\n# keep background, 3D Case\nTEST_CASE_3 = [  # y_pred (1, 1, 1, 2, 2), expected out (0.0)\n    {\n        \"y_pred\": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]]]]], device=_device),\n        \"y\": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]]]]], device=_device),\n        \"include_background\": True,\n        \"scalar_reduction\": \"sum\",\n    },\n    
[0.0],\n]\n\n# keep background, 2D Case\nTEST_CASE_4 = [  # y_pred (1, 1, 2, 2), expected out (0.0)\n    {\n        \"y_pred\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),\n        \"y\": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]]], device=_device),\n        \"include_background\": True,\n        \"scalar_reduction\": \"sum\",\n    },\n    [4.0],\n]\n\nTEST_CASE_5 = [  # y_pred (1, 1, 2, 2), expected out (0.0)\n    {\n        \"y_pred\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),\n        \"y\": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]]], device=_device),\n        \"include_background\": True,\n        \"scalar_reduction\": \"mean\",\n    },\n    [1.0],\n]\n\n# Spatial Map Test Case for 3D Case\nTEST_CASE_6 = [  # y_pred (1, 1, 2, 2, 2), expected out all (0.0) map of 2x2x2\n    {\n        \"y_pred\": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]], device=_device),\n        \"y\": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]], device=_device),\n        \"include_background\": True,\n        \"scalar_reduction\": \"none\",\n    },\n    [[[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]]],\n]\n\n# Spatial Map Test Case for 2D Case\nTEST_CASE_7 = [  # y_pred (1, 1, 2, 2)\n    {\n        \"y_pred\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], device=_device),\n        \"y\": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], device=_device),\n        \"include_background\": True,\n        \"scalar_reduction\": \"none\",\n    },\n    [[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]],\n]\n\n\nclass TestLabelQualityScore(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])\n    def test_value(self, input_data, expected_value):\n        result = label_quality_score(**input_data)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n        
np.testing.assert_equal(result.device, input_data[\"y_pred\"].device)\n\n    @parameterized.expand([TEST_CASE_6, TEST_CASE_7])\n    def test_spatial_case(self, input_data, expected_value):\n        result = label_quality_score(**input_data)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])\n    def test_value_class(self, input_data, expected_value):\n        vals = {}\n        vals[\"y_pred\"] = input_data.pop(\"y_pred\")\n        vals[\"y\"] = input_data.pop(\"y\")\n        comp_var = LabelQualityScore(**input_data)\n        result = comp_var(**vals)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n    @parameterized.expand([TEST_CASE_6, TEST_CASE_7])\n    def test_spatial_case_class(self, input_data, expected_value):\n        vals = {}\n        vals[\"y_pred\"] = input_data.pop(\"y_pred\")\n        vals[\"y\"] = input_data.pop(\"y\")\n        comp_var = LabelQualityScore(**input_data)\n        result = comp_var(**vals)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_loss_metric.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.losses import DiceLoss\nfrom monai.metrics import LossMetric\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\nTEST_CASE_1 = [  # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1)\n    {\n        \"loss_class\": DiceLoss,\n        \"loss_kwargs\": {\"include_background\": True},\n        \"reduction\": \"mean\",\n        \"get_not_nans\": False,\n        \"y_pred\": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device),\n        \"y\": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device),\n        \"include_background\": True,\n    },\n    [0.2],\n]\n\n\nclass TestComputeLossMetric(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1])\n    def test_value_class(self, input_data, expected_value):\n        loss_fn = input_data[\"loss_class\"](**input_data[\"loss_kwargs\"])\n        loss_metric = LossMetric(\n            loss_fn=loss_fn, reduction=input_data[\"reduction\"], get_not_nans=input_data[\"get_not_nans\"]\n        )\n\n        loss_metric(y_pred=input_data.get(\"y_pred\"), y=input_data.get(\"y\"))\n        loss_metric(y_pred=input_data.get(\"y_pred\"), y=input_data.get(\"y\"))\n        result = loss_metric.aggregate()\n        np.testing.assert_allclose(result.cpu().numpy(), 
expected_value, atol=1e-4)\n        loss_metric.reset()\n        result = loss_metric.aggregate()\n        np.testing.assert_allclose(result.cpu().numpy(), 0.0, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_metrics_reloaded.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import MetricsReloadedBinary, MetricsReloadedCategorical\nfrom monai.utils import optional_import\n\n_, has_metrics = optional_import(\"MetricsReloaded\")\n\n# shape: (1, 1, 2, 2)\ny_pred = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])\ny = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])\nTEST_CASES_BINARY = [\n    [{\"metric_name\": \"False Positives\"}, [y_pred, y], 0.0],\n    [{\"metric_name\": \"False Negatives\"}, [y_pred, y], 1.0],\n    [{\"metric_name\": \"True Positives\"}, [y_pred, y], 2.0],\n    [{\"metric_name\": \"True Negatives\"}, [y_pred, y], 1.0],\n    [{\"metric_name\": \"Youden Index\"}, [y_pred, y], 0.666654],\n    [{\"metric_name\": \"Sensitivity\"}, [y_pred, y], 0.666664],\n    [{\"metric_name\": \"Specificity\"}, [y_pred, y], 0.99999],\n    [{\"metric_name\": \"Balanced Accuracy\"}, [y_pred, y], 0.833327],\n    [{\"metric_name\": \"Accuracy\"}, [y_pred, y], 0.75],\n    [{\"metric_name\": \"False Positive Rate\"}, [y_pred, y], 0.0],\n    [{\"metric_name\": \"Normalised Expected Cost\"}, [y_pred, y], 0.333333],\n    [{\"metric_name\": \"Matthews Correlation Coefficient\"}, [y_pred, y], 0.57735],\n    [{\"metric_name\": \"Cohens Kappa\"}, [y_pred, y], 0.5],\n    [{\"metric_name\": \"Positive Likelihood 
Ratio\"}, [y_pred, y], 66576.03],\n    [{\"metric_name\": \"Prediction Overlaps Reference\"}, [y_pred, y], 1.0],\n    [{\"metric_name\": \"Positive Predictive Value\"}, [y_pred, y], 0.999995],\n    [{\"metric_name\": \"Recall\"}, [y_pred, y], 0.666664],\n    [{\"metric_name\": \"FBeta\"}, [y_pred, y], 0.799992],\n    [{\"metric_name\": \"Net Benefit Treated\"}, [y_pred, y], 0.5],\n    [{\"metric_name\": \"Negative Predictive Values\"}, [y_pred, y], 0.5],\n    [{\"metric_name\": \"Dice Score\"}, [y_pred, y], 0.799992],\n    [{\"metric_name\": \"False Positives Per Image\"}, [y_pred, y], 0.0],\n    [{\"metric_name\": \"Intersection Over Reference\"}, [y_pred, y], 0.666664],\n    [{\"metric_name\": \"Intersection Over Union\"}, [y_pred, y], 0.666664],\n    [{\"metric_name\": \"Volume Difference\"}, [y_pred, y], 0.333333],\n    [{\"metric_name\": \"Topology Precision\"}, [y_pred, y], 1.0],\n    [{\"metric_name\": \"Topology Sensitivity\"}, [y_pred, y], 1.0],\n    [{\"metric_name\": \"Centreline Dice Score\"}, [y_pred, y], 1.0],\n    [{\"metric_name\": \"Boundary IoU\"}, [y_pred, y], 0.666667],\n    [{\"metric_name\": \"Normalised Surface Distance\"}, [y_pred, y], 1.0],\n    [{\"metric_name\": \"Average Symmetric Surface Distance\"}, [y_pred, y], 0.2],\n    [{\"metric_name\": \"Mean Average Surfance Distance\"}, [y_pred, y], 0.166666],\n    [{\"metric_name\": \"Hausdorff Distance\"}, [y_pred, y], 1.0],\n    [{\"metric_name\": \"xTh Percentile Hausdorff Distance\"}, [y_pred, y], 0.9],\n]\n\n# shape: (1, 3, 2, 2)\ny_pred = torch.tensor([[[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]])\ny = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]])\nTEST_CASES_CATEGORICAL = [\n    [{\"metric_name\": \"Balanced Accuracy\"}, [y_pred, y], 0.5],\n    [{\"metric_name\": \"Weighted Cohens Kappa\"}, [y_pred, y], 0.272727],\n    [{\"metric_name\": \"Matthews Correlation Coefficient\"}, [y_pred, y], 0.387298],\n    [{\"metric_name\": \"Expected Cost\"}, [y_pred, 
y], 0.5],\n    [{\"metric_name\": \"Normalised Expected Cost\"}, [y_pred, y], 0.75],\n]\n\n\n@unittest.skipIf(not has_metrics, \"MetricsReloaded not available.\")\nclass TestMetricsReloaded(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_BINARY)\n    def test_binary(self, input_param, input_data, expected_val):\n        metric = MetricsReloadedBinary(**input_param)\n        result = metric(*input_data)\n        np.testing.assert_allclose(\n            result.detach().cpu().numpy(), expected_val, rtol=1e-5, err_msg=input_param[\"metric_name\"]\n        )\n\n    @parameterized.expand(TEST_CASES_CATEGORICAL)\n    def test_categorical(self, input_param, input_data, expected_val):\n        metric = MetricsReloadedCategorical(**input_param)\n        result = metric(*input_data)\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_ssim_metric.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\n\nfrom monai.metrics.regression import SSIMMetric, compute_ssim_and_cs\nfrom monai.utils import set_determinism\n\n\nclass TestSSIMMetric(unittest.TestCase):\n\n    def test2d_gaussian(self):\n        set_determinism(0)\n        preds = torch.abs(torch.randn(2, 3, 16, 16))\n        target = torch.abs(torch.randn(2, 3, 16, 16))\n        preds = preds / preds.max()\n        target = target / target.max()\n\n        metric = SSIMMetric(spatial_dims=2, data_range=1.0, kernel_type=\"gaussian\")\n        metric(preds, target)\n        result = metric.aggregate()\n        expected_value = 0.045415\n        self.assertTrue(expected_value - result.item() < 0.000001)\n\n    def test2d_uniform(self):\n        set_determinism(0)\n        preds = torch.abs(torch.randn(2, 3, 16, 16))\n        target = torch.abs(torch.randn(2, 3, 16, 16))\n        preds = preds / preds.max()\n        target = target / target.max()\n\n        metric = SSIMMetric(spatial_dims=2, data_range=1.0, kernel_type=\"uniform\")\n        metric(preds, target)\n        result = metric.aggregate()\n        expected_value = 0.050103\n        self.assertTrue(expected_value - result.item() < 0.000001)\n\n    def test3d_gaussian(self):\n        set_determinism(0)\n        preds = torch.abs(torch.randn(2, 3, 16, 16, 16))\n        target = 
torch.abs(torch.randn(2, 3, 16, 16, 16))\n        preds = preds / preds.max()\n        target = target / target.max()\n\n        metric = SSIMMetric(spatial_dims=3, data_range=1.0, kernel_type=\"gaussian\")\n        metric(preds, target)\n        result = metric.aggregate()\n        expected_value = 0.017644\n        self.assertTrue(expected_value - result.item() < 0.000001)\n\n    def input_ill_input_shape(self):\n        with self.assertRaises(ValueError):\n            metric = SSIMMetric(spatial_dims=3)\n            metric(torch.randn(1, 1, 16, 16), torch.randn(1, 1, 16, 16))\n\n        with self.assertRaises(ValueError):\n            metric = SSIMMetric(spatial_dims=2)\n            metric(torch.randn(1, 1, 16, 16, 16), torch.randn(1, 1, 16, 16, 16))\n\n    def mismatch_y_pred_and_y(self):\n        with self.assertRaises(ValueError):\n            compute_ssim_and_cs(y_pred=torch.randn(1, 1, 16, 8), y=torch.randn(1, 1, 16, 16), spatial_dims=2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_surface_dice.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom monai.metrics.surface_dice import SurfaceDiceMetric, compute_surface_dice\nfrom tests.test_utils import assert_allclose\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\n\nclass TestAllSurfaceDiceMetrics(unittest.TestCase):\n    def test_tolerance_euclidean_distance_with_spacing(self):\n        batch_size = 2\n        n_class = 2\n        test_spacing = (0.85, 1.2)\n        predictions = torch.zeros((batch_size, 480, 640), dtype=torch.int64, device=_device)\n        labels = torch.zeros((batch_size, 480, 640), dtype=torch.int64, device=_device)\n        predictions[0, :, 50:] = 1\n        labels[0, :, 60:] = 1  # 10 px shift\n        predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 3, 1, 2)\n        labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 3, 1, 2)\n\n        sd0 = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True)\n        res0 = sd0(predictions_hot, labels_hot, spacing=test_spacing)\n        agg0 = sd0.aggregate()  # aggregation: nanmean across image then nanmean across batch\n        sd0_nans = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True, get_not_nans=True)\n        res0_nans = sd0_nans(predictions_hot, labels_hot)\n        agg0_nans, 
not_nans = sd0_nans.aggregate()\n\n        np.testing.assert_array_equal(res0.cpu(), res0_nans.cpu())\n        np.testing.assert_equal(res0.device, predictions.device)\n        np.testing.assert_array_equal(agg0.cpu(), agg0_nans.cpu())\n        np.testing.assert_equal(agg0.device, predictions.device)\n\n        res1 = SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(\n            predictions_hot, labels_hot, spacing=test_spacing\n        )\n        res9 = SurfaceDiceMetric(class_thresholds=[9, 9], include_background=True)(\n            predictions_hot, labels_hot, spacing=test_spacing\n        )\n        res10 = SurfaceDiceMetric(class_thresholds=[10, 10], include_background=True)(\n            predictions_hot, labels_hot, spacing=test_spacing\n        )\n        res11 = SurfaceDiceMetric(class_thresholds=[11, 11], include_background=True)(\n            predictions_hot, labels_hot, spacing=test_spacing\n        )\n        # because spacing is (0.85, 1.2) and we moved 10 pixels in the columns direction,\n        # everything with tolerance 12 or more should be the same as tolerance 12 (surface dice is 1.0)\n        res12 = SurfaceDiceMetric(class_thresholds=[12, 12], include_background=True)(\n            predictions_hot, labels_hot, spacing=test_spacing\n        )\n        res13 = SurfaceDiceMetric(class_thresholds=[13, 13], include_background=True)(\n            predictions_hot, labels_hot, spacing=test_spacing\n        )\n\n        for res in [res0, res9, res10, res11, res12, res13]:\n            assert res.shape == torch.Size([2, 2])\n\n        assert res0[0, 0] < res1[0, 0] < res9[0, 0] < res10[0, 0] < res11[0, 0]\n        assert res0[0, 1] < res1[0, 1] < res9[0, 1] < res10[0, 1] < res11[0, 1]\n        np.testing.assert_array_equal(res12.cpu(), res13.cpu())\n\n        expected_res0 = np.zeros((batch_size, n_class))\n        expected_res0[0, 1] = 1 - (478 + 480 + 9 * 2) / (480 * 4 + 588 * 2 + 578 * 2)\n        expected_res0[0, 0] = 1 - (478 + 
480 + 9 * 2) / (480 * 4 + 48 * 2 + 58 * 2)\n        expected_res0[1, 0] = 1\n        expected_res0[1, 1] = np.nan\n        for b, c in np.ndindex(batch_size, n_class):\n            np.testing.assert_allclose(expected_res0[b, c], res0[b, c].cpu())\n        np.testing.assert_allclose(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0))\n        np.testing.assert_equal(not_nans.cpu(), torch.tensor(2))\n\n    def test_tolerance_euclidean_distance(self):\n        batch_size = 2\n        n_class = 2\n        predictions = torch.zeros((batch_size, 480, 640), dtype=torch.int64, device=_device)\n        labels = torch.zeros((batch_size, 480, 640), dtype=torch.int64, device=_device)\n        predictions[0, :, 50:] = 1\n        labels[0, :, 60:] = 1  # 10 px shift\n        predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 3, 1, 2)\n        labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 3, 1, 2)\n\n        sd0 = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True)\n        res0 = sd0(predictions_hot, labels_hot)\n        agg0 = sd0.aggregate()  # aggregation: nanmean across image then nanmean across batch\n        sd0_nans = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True, get_not_nans=True)\n        res0_nans = sd0_nans(predictions_hot, labels_hot)\n        agg0_nans, not_nans = sd0_nans.aggregate()\n\n        np.testing.assert_array_equal(res0.cpu(), res0_nans.cpu())\n        np.testing.assert_equal(res0.device, predictions.device)\n        np.testing.assert_array_equal(agg0.cpu(), agg0_nans.cpu())\n        np.testing.assert_equal(agg0.device, predictions.device)\n\n        res1 = SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels_hot)\n        res9 = SurfaceDiceMetric(class_thresholds=[9, 9], include_background=True)(predictions_hot, labels_hot)\n        res10 = SurfaceDiceMetric(class_thresholds=[10, 10], include_background=True)(predictions_hot, 
labels_hot)\n        res11 = SurfaceDiceMetric(class_thresholds=[11, 11], include_background=True)(predictions_hot, labels_hot)\n\n        for res in [res0, res9, res10, res11]:\n            assert res.shape == torch.Size([2, 2])\n\n        assert res0[0, 0] < res1[0, 0] < res9[0, 0] < res10[0, 0]\n        assert res0[0, 1] < res1[0, 1] < res9[0, 1] < res10[0, 1]\n        np.testing.assert_array_equal(res10.cpu(), res11.cpu())\n\n        expected_res0 = np.zeros((batch_size, n_class))\n        expected_res0[0, 1] = 1 - (478 + 480 + 9 * 2) / (480 * 4 + 588 * 2 + 578 * 2)\n        expected_res0[0, 0] = 1 - (478 + 480 + 9 * 2) / (480 * 4 + 48 * 2 + 58 * 2)\n        expected_res0[1, 0] = 1\n        expected_res0[1, 1] = np.nan\n        for b, c in np.ndindex(batch_size, n_class):\n            np.testing.assert_allclose(expected_res0[b, c], res0[b, c].cpu())\n        np.testing.assert_allclose(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0))\n        np.testing.assert_equal(not_nans.cpu(), torch.tensor(2))\n\n    def test_tolerance_euclidean_distance_3d(self):\n        batch_size = 2\n        n_class = 2\n        predictions = torch.zeros((batch_size, 200, 110, 80), dtype=torch.int64, device=_device)\n        labels = torch.zeros((batch_size, 200, 110, 80), dtype=torch.int64, device=_device)\n        predictions[0, :, :, 20:] = 1\n        labels[0, :, :, 30:] = 1  # offset by 10\n        predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 4, 1, 2, 3)\n        labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 4, 1, 2, 3)\n\n        sd0 = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True)\n        res0 = sd0(predictions_hot, labels_hot)\n        agg0 = sd0.aggregate()  # aggregation: nanmean across image then nanmean across batch\n        sd0_nans = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True, get_not_nans=True)\n        res0_nans = sd0_nans(predictions_hot, labels_hot)\n        
agg0_nans, not_nans = sd0_nans.aggregate()\n\n        np.testing.assert_array_equal(res0.cpu(), res0_nans.cpu())\n        np.testing.assert_equal(res0.device, predictions.device)\n        np.testing.assert_array_equal(agg0.cpu(), agg0_nans.cpu())\n        np.testing.assert_equal(agg0.device, predictions.device)\n\n        res1 = SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels_hot)\n        res10 = SurfaceDiceMetric(class_thresholds=[10, 10], include_background=True)(predictions_hot, labels_hot)\n        res11 = SurfaceDiceMetric(class_thresholds=[11, 11], include_background=True)(predictions_hot, labels_hot)\n\n        for res in [res0, res1, res10, res11]:\n            assert res.shape == torch.Size([2, 2])\n\n        assert res0[0, 0] < res1[0, 0] < res10[0, 0]\n        assert res0[0, 1] < res1[0, 1] < res10[0, 1]\n        np.testing.assert_array_equal(res10.cpu(), res11.cpu())\n\n        expected_res0 = np.zeros((batch_size, n_class))\n        expected_res0[0, 1] = 1 - (200 * 110 + 198 * 108 + 9 * 200 * 2 + 9 * 108 * 2) / (\n            200 * 110 * 4 + (58 + 48) * 200 * 2 + (58 + 48) * 108 * 2\n        )\n        expected_res0[0, 0] = 1 - (200 * 110 + 198 * 108 + 9 * 200 * 2 + 9 * 108 * 2) / (\n            200 * 110 * 4 + (28 + 18) * 200 * 2 + (28 + 18) * 108 * 2\n        )\n        expected_res0[1, 0] = 1\n        expected_res0[1, 1] = np.nan\n        for b, c in np.ndindex(batch_size, n_class):\n            np.testing.assert_allclose(expected_res0[b, c], res0[b, c].cpu())\n        np.testing.assert_allclose(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0))\n        np.testing.assert_equal(not_nans.cpu(), torch.tensor(2))\n\n    def test_tolerance_all_distances(self):\n        batch_size = 1\n        n_class = 2\n        predictions = torch.zeros((batch_size, 10, 10), dtype=torch.int64)\n        labels = torch.zeros((batch_size, 10, 10), dtype=torch.int64)\n        predictions[0, 1:4, 1] = 1\n        
\"\"\"\n        [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]\n        \"\"\"\n        labels[0, 5:8, 6] = 1\n        \"\"\"\n        [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]\n        \"\"\"\n        predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 3, 1, 2)\n        labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 3, 1, 2)\n\n        # Euclidean distance:\n        # background:\n        # 36 boundary pixels have 0 distances; non-zero distances:\n        # distances gt_pred: [3, np.sqrt(9+4), 2, 3, 2, 2, 2, 1]\n        # distances pred_gt: [1, 2, 2, 1]\n        # class 1:\n        # distances gt_pred: [sqrt(25+4), sqrt(25+9), sqrt(25+16)] = [5.38516481, 5.83095189, 6.40312424]\n        # distances pred_gt: [sqrt(25+16), sqrt(25+9), sqrt(25+4)] = [6.40312424, 5.83095189, 5.38516481]\n\n        res = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True)(predictions_hot, labels_hot)\n        expected_res = [[1 - (8 + 4) / (36 * 2 + 8 + 4), 0]]\n        np.testing.assert_array_almost_equal(res, expected_res)\n\n        res = SurfaceDiceMetric(class_thresholds=[2.8, 5.5], include_background=True)(predictions_hot, labels_hot)\n        expected_res = [[1 - 3 / (36 * 2 + 8 + 4), 1 
- (2 + 2) / (3 + 3)]]\n        np.testing.assert_array_almost_equal(res, expected_res)\n\n        res = SurfaceDiceMetric(class_thresholds=[3, 6], include_background=True)(predictions_hot, labels_hot)\n        expected_res = [[1 - 1 / (36 * 2 + 8 + 4), 1 - 2 / (3 + 3)]]\n        np.testing.assert_array_almost_equal(res, expected_res)\n\n        # Chessboard distance:\n        # background:\n        # 36 boundary pixels have 0 distances; non-zero distances:\n        # distances gt_pred: [max(3,0), max(3,2), max(2,0), max(3,3), max(2,0), max(0,2), max(2,0), max(0,1)] =\n        # [3, 3, 2, 3, 2, 2, 2, 1]\n        # distances pred_gt: [max(0,1), max(2,0), max(2,0), max(1,0)] = [1, 2, 2, 1]\n        # class 1:\n        # distances gt_pred: [max(5,2), max(5,3), max(5,4)] = [5, 5, 5]\n        # distances pred_gt: [max(5,4), max(5,3), max(5,2)] = [5, 5, 5]\n\n        res = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True, distance_metric=\"chessboard\")(\n            predictions_hot, labels_hot\n        )\n        expected_res = [[1 - (8 + 4) / (36 * 2 + 8 + 4), 0]]\n        np.testing.assert_array_almost_equal(res, expected_res)\n\n        res = SurfaceDiceMetric(class_thresholds=[1, 4.999], include_background=True, distance_metric=\"chessboard\")(\n            predictions_hot, labels_hot\n        )\n        expected_res = [[1 - (7 + 2) / (36 * 2 + 8 + 4), 0]]\n        np.testing.assert_array_almost_equal(res, expected_res)\n\n        res = SurfaceDiceMetric(class_thresholds=[2, 5], include_background=True, distance_metric=\"chessboard\")(\n            predictions_hot, labels_hot\n        )\n        expected_res = [[1 - 3 / (36 * 2 + 8 + 4), 1]]\n        np.testing.assert_array_almost_equal(res, expected_res)\n\n        # Taxicab distance (= Manhattan distance):\n        # background:\n        # 36 boundary pixels have 0 distances; non-zero distances:\n        # distances gt_pred: [3+0, 4+0, 2+0, 0+3, 2+0, 0+2, 2+0, 0+1] = [3, 4, 2, 3, 2, 2, 2, 1]\n     
   # distances pred_gt: [0+1, 2+0, 2+0, 1+0] = [1, 2, 2, 1]\n        # class 1:\n        # distances gt_pred: [5+2, 5+3, 5+4] = [7, 8, 9]\n        # distances pred_gt: [5+4, 5+3, 5+2] = [9, 8, 7]\n\n        res = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True, distance_metric=\"taxicab\")(\n            predictions_hot, labels_hot\n        )\n        expected_res = [[1 - (8 + 4) / (36 * 2 + 8 + 4), 0]]\n        np.testing.assert_array_almost_equal(res, expected_res)\n\n        res = SurfaceDiceMetric(class_thresholds=[1, 7], include_background=True, distance_metric=\"taxicab\")(\n            predictions_hot, labels_hot\n        )\n        expected_res = [[1 - (7 + 2) / (36 * 2 + 8 + 4), 1 - (2 + 2) / (3 + 3)]]\n        np.testing.assert_array_almost_equal(res, expected_res)\n\n        res = SurfaceDiceMetric(class_thresholds=[3, 9], include_background=True, distance_metric=\"taxicab\")(\n            predictions_hot, labels_hot\n        )\n        expected_res = [[1 - 1 / (36 * 2 + 8 + 4), 1]]\n        np.testing.assert_array_almost_equal(res, expected_res)\n\n    def test_asserts(self):\n        batch_size = 1\n        n_class = 2\n        predictions = torch.zeros((batch_size, 80, 80), dtype=torch.int64)\n        labels = torch.zeros((batch_size, 80, 80), dtype=torch.int64)\n        predictions[0, 10:20, 10:20] = 1\n        labels[0, 20:30, 20:30] = 1\n        predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 3, 1, 2)\n        labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 3, 1, 2)\n\n        # no torch tensor\n        with self.assertRaises(ValueError) as context:\n            SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot.numpy(), labels_hot)\n        self.assertEqual(\n            \"y_pred or y must be a list/tuple of `channel-first` Tensors or a `batch-first` Tensor.\",\n            str(context.exception),\n        )\n        with self.assertRaises(ValueError) as 
context:\n            SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels_hot.numpy())\n        self.assertEqual(\"y_pred and y must be PyTorch Tensor.\", str(context.exception))\n\n        # wrong dimensions\n        with self.assertRaises(ValueError) as context:\n            SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions, labels_hot)\n        self.assertEqual(\"y_pred and y should be one-hot encoded: [B,C,H,W] or [B,C,H,W,D].\", str(context.exception))\n        with self.assertRaises(ValueError) as context:\n            SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels)\n        self.assertEqual(\"y_pred and y should be one-hot encoded: [B,C,H,W] or [B,C,H,W,D].\", str(context.exception))\n\n        # mismatch of shape of input tensors\n        input_bad_shape = torch.clone(predictions_hot)\n        input_bad_shape = input_bad_shape[:, :, :, :50]\n\n        with self.assertRaises(ValueError) as context:\n            SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, input_bad_shape)\n        self.assertEqual(\n            \"y_pred and y should have same shape, but instead, shapes are torch.Size([1, 2, 80, 80]) (y_pred) and \"\n            \"torch.Size([1, 2, 80, 50]) (y).\",\n            str(context.exception),\n        )\n\n        # wrong number of class thresholds\n        with self.assertRaises(ValueError) as context:\n            SurfaceDiceMetric(class_thresholds=[1, 1, 1], include_background=True)(predictions_hot, labels_hot)\n        self.assertEqual(\"number of classes (2) does not match number of class thresholds (3).\", str(context.exception))\n\n        # inf and nan values in class thresholds\n        with self.assertRaises(ValueError) as context:\n            SurfaceDiceMetric(class_thresholds=[np.inf, 1], include_background=True)(predictions_hot, labels_hot)\n        self.assertEqual(\"All class 
thresholds need to be finite.\", str(context.exception))\n\n        with self.assertRaises(ValueError) as context:\n            SurfaceDiceMetric(class_thresholds=[np.nan, 1], include_background=True)(predictions_hot, labels_hot)\n            self.assertEqual(\"All class thresholds need to be finite.\", str(context.exception))\n\n        # negative values in class thresholds:\n        with self.assertRaises(ValueError) as context:\n            SurfaceDiceMetric(class_thresholds=[-0.22, 1], include_background=True)(predictions_hot, labels_hot)\n        self.assertEqual(\"All class thresholds need to be >= 0.\", str(context.exception))\n\n    def test_not_predicted_not_present(self):\n        # class is present in labels, but not in prediction -> nsd of 0 should be yielded for that class; class is\n        # neither present on labels, nor prediction -> nan should be yielded\n        batch_size = 1\n        n_class = 4\n        predictions = torch.zeros((batch_size, 80, 80), dtype=torch.int64)\n        labels = torch.zeros((batch_size, 80, 80), dtype=torch.int64)\n        predictions[0, 10:20, 10:20] = 1\n        labels[0, 10:20, 10:20] = 2\n        predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 3, 1, 2)\n        labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 3, 1, 2)\n\n        # with and without background class\n        sur_metric_bgr = SurfaceDiceMetric(class_thresholds=[1, 1, 1, 1], include_background=True)\n        sur_metric = SurfaceDiceMetric(class_thresholds=[1, 1, 1], include_background=False)\n\n        # test per-class results\n        res_bgr_classes = sur_metric_bgr(predictions_hot, labels_hot)\n        np.testing.assert_array_equal(res_bgr_classes, [[1, 0, 0, np.nan]])\n        res_classes = sur_metric(predictions_hot, labels_hot)\n        np.testing.assert_array_equal(res_classes, [[0, 0, np.nan]])\n\n        # test aggregation\n        res_bgr = sur_metric_bgr.aggregate(reduction=\"mean\")\n        
np.testing.assert_equal(res_bgr, torch.tensor([1 / 3], dtype=torch.float))\n        res = sur_metric.aggregate()\n        np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float))\n\n        predictions_empty = torch.zeros((2, 3, 1, 1))\n        sur_metric_nans = SurfaceDiceMetric(class_thresholds=[1, 1, 1], include_background=True, get_not_nans=True)\n        res_classes = sur_metric_nans(predictions_empty, predictions_empty)\n        res, not_nans = sur_metric_nans.aggregate()\n        np.testing.assert_array_equal(res_classes, [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]])\n        np.testing.assert_equal(res, torch.tensor([0], dtype=torch.float))\n        np.testing.assert_equal(not_nans, torch.tensor([0], dtype=torch.float))\n\n    def test_compute_surface_dice_subvoxel(self):\n        mask_gt, mask_pred = (torch.zeros(1, 1, 128, 128, 128), torch.zeros(1, 1, 128, 128, 128))\n        mask_gt[0, 0, 50, 60, 70] = 1\n        res = compute_surface_dice(\n            mask_pred, mask_gt, [1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True\n        )\n        assert_allclose(res, 0.0, type_test=False)\n        mask_gt[0, 0, 50, 60, 70] = 0\n        mask_pred[0, 0, 50, 60, 72] = 1\n        res = compute_surface_dice(\n            mask_pred, mask_gt, [1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True\n        )\n        assert_allclose(res, 0.0, type_test=False)\n        mask_gt[0, 0, 50, 60, 70] = 1\n        mask_pred[0, 0, 50, 60, 72] = 1\n        res = compute_surface_dice(\n            mask_pred, mask_gt, [1.0], include_background=True, spacing=(3, 2, 1), use_subvoxels=True\n        )\n        assert_allclose(res, 0.5, type_test=False)\n\n        d = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n        mask_gt, mask_pred = (torch.zeros(1, 1, 100, 100, 100, device=d), torch.zeros(1, 1, 100, 100, 100, device=d))\n        mask_gt[0, 0, 0:50, :, :] = 1\n        mask_pred[0, 0, 0:51, :, :] = 1\n  
      res = compute_surface_dice(\n            mask_pred, mask_gt, [1.0], include_background=True, spacing=(2, 1, 1), use_subvoxels=True\n        )\n        assert_allclose(res, 0.836145, type_test=False, atol=1e-3, rtol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/metrics/test_surface_distance.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.metrics import SurfaceDistanceMetric\n\n_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\n\ndef create_spherical_seg_3d(\n    radius: float = 20.0,\n    centre: tuple[int, int, int] = (49, 49, 49),\n    im_shape: tuple[int, int, int] = (99, 99, 99),\n    im_spacing: tuple[float, float, float] = (1.0, 1.0, 1.0),\n) -> np.ndarray:\n    \"\"\"\n    Return a 3D image with a sphere inside. 
Voxel values will be\n    1 inside the sphere, and 0 elsewhere.\n\n    Args:\n        radius: radius of sphere (in terms of number of voxels, can be partial)\n        centre: location of sphere centre.\n        im_shape: shape of image to create.\n        im_spacing: spacing of image to create.\n\n    See also:\n        :py:meth:`~create_test_image_3d`\n    \"\"\"\n    # Create image\n    image = np.zeros(im_shape, dtype=np.int32)\n    spy, spx, spz = np.ogrid[: im_shape[0], : im_shape[1], : im_shape[2]]\n\n    spy = spy.astype(float) * im_spacing[0]\n    spx = spx.astype(float) * im_spacing[1]\n    spz = spz.astype(float) * im_spacing[2]\n\n    spy -= centre[0]\n    spx -= centre[1]\n    spz -= centre[2]\n\n    circle = (spx * spx + spy * spy + spz * spz) <= radius * radius\n\n    image[circle] = 1\n    image[~circle] = 0\n    return image\n\n\ntest_spacing = (0.85, 1.2, 0.9)\nTEST_CASES = [\n    [[create_spherical_seg_3d(), create_spherical_seg_3d(), \"euclidean\", None], [0, 0]],\n    [\n        [\n            create_spherical_seg_3d(radius=20, centre=(20, 20, 20)),\n            create_spherical_seg_3d(radius=20, centre=(19, 19, 19)),\n            \"taxicab\",\n        ],\n        [1.0380029806259314, 1.0380029806259314],\n    ],\n    [\n        [\n            create_spherical_seg_3d(radius=33, centre=(19, 33, 22)),\n            create_spherical_seg_3d(radius=33, centre=(20, 33, 22)),\n            \"euclidean\",\n            None,\n        ],\n        [0.350217, 0.3483278807706289],\n    ],\n    [\n        [\n            create_spherical_seg_3d(radius=20, centre=(20, 33, 22)),\n            create_spherical_seg_3d(radius=40, centre=(20, 33, 22)),\n            \"euclidean\",\n            None,\n        ],\n        [15.117741, 12.040033513150455],\n    ],\n    [\n        [\n            create_spherical_seg_3d(radius=20, centre=(20, 33, 22)),\n            create_spherical_seg_3d(radius=40, centre=(20, 33, 22)),\n            \"chessboard\",\n        ],\n        
[11.492719, 9.605067064083457],\n    ],\n    [\n        [\n            create_spherical_seg_3d(radius=20, centre=(20, 33, 22)),\n            create_spherical_seg_3d(radius=40, centre=(20, 33, 22)),\n            \"taxicab\",\n        ],\n        [20.214613, 12.432687531048186],\n    ],\n    [\n        [np.zeros([99, 99, 99]), create_spherical_seg_3d(radius=40, centre=(20, 33, 22)), \"euclidean\", None],\n        [np.inf, np.inf],\n    ],\n    [[create_spherical_seg_3d(), np.zeros([99, 99, 99]), \"taxicab\"], [np.inf, np.inf]],\n    [\n        [\n            create_spherical_seg_3d(radius=33, centre=(42, 45, 52), im_spacing=test_spacing),\n            create_spherical_seg_3d(radius=33, centre=(43, 45, 52), im_spacing=test_spacing),\n            \"euclidean\",\n            test_spacing,\n        ],\n        [0.4951, 0.4951],\n    ],\n]\n\nTEST_CASES_NANS = [\n    [\n        [\n            # both pred and gt do not have foreground, spacing is None, metric and not_nans should be 0\n            np.zeros([99, 99, 99]),\n            np.zeros([99, 99, 99]),\n            None,\n        ]\n    ],\n    [\n        [\n            # both pred and gt do not have foreground, spacing is not None, metric and not_nans should be 0\n            np.zeros([99, 99, 99]),\n            np.zeros([99, 99, 99]),\n            test_spacing,\n        ]\n    ],\n]\n\n\nclass TestAllSurfaceMetrics(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_value(self, input_data, expected_value):\n        if len(input_data) == 3:\n            [seg_1, seg_2, metric] = input_data\n            spacing = None\n        else:\n            [seg_1, seg_2, metric, spacing] = input_data\n\n        ct = 0\n        seg_1 = torch.tensor(seg_1, device=_device)\n        seg_2 = torch.tensor(seg_2, device=_device)\n        for symmetric in [True, False]:\n            sur_metric = SurfaceDistanceMetric(include_background=False, symmetric=symmetric, distance_metric=metric)\n            # shape of 
seg_1, seg_2 are: HWD, converts to BNHWD\n            batch, n_class = 2, 3\n            batch_seg_1 = seg_1.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1])\n            batch_seg_2 = seg_2.unsqueeze(0).unsqueeze(0).repeat([batch, n_class, 1, 1, 1])\n            sur_metric(batch_seg_1, batch_seg_2, spacing=spacing)\n            result = sur_metric.aggregate()\n            expected_value_curr = expected_value[ct]\n            np.testing.assert_allclose(expected_value_curr, result.cpu(), rtol=1e-5)\n            np.testing.assert_equal(result.device, seg_1.device)\n            ct += 1\n\n    @parameterized.expand(TEST_CASES_NANS)\n    def test_nans(self, input_data):\n        [seg_1, seg_2, spacing] = input_data\n        seg_1 = torch.tensor(seg_1)\n        seg_2 = torch.tensor(seg_2)\n        sur_metric = SurfaceDistanceMetric(include_background=False, get_not_nans=True)\n        # test list of channel-first Tensor\n        batch_seg_1 = [seg_1.unsqueeze(0)]\n        batch_seg_2 = [seg_2.unsqueeze(0)]\n        sur_metric(batch_seg_1, batch_seg_2, spacing=spacing)\n        result, not_nans = sur_metric.aggregate(reduction=\"mean\")\n        np.testing.assert_allclose(0, result, rtol=1e-5)\n        np.testing.assert_allclose(0, not_nans, rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/min_tests.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport sys\nimport unittest\nfrom pathlib import Path\n\n\ndef run_testsuit():\n    \"\"\"\n    Load test cases by excluding those need external dependencies.\n    The loaded cases should work with \"requirements-min.txt\"::\n\n        # in the monai repo folder:\n        pip install -r requirements-min.txt\n        QUICKTEST=true python -m tests.min_tests\n\n    :return: a test suite\n    \"\"\"\n    exclude_cases = [  # these cases use external dependencies\n        \"test_ahnet\",\n        \"test_arraydataset\",\n        \"test_auto3dseg_bundlegen\",\n        \"test_auto3dseg_ensemble\",\n        \"test_auto3dseg_hpo\",\n        \"test_auto3dseg\",\n        \"test_bundle_onnx_export\",\n        \"test_bundle_trt_export\",\n        \"test_bundle_push_to_hf_hub\",\n        \"test_cachedataset\",\n        \"test_cachedataset_parallel\",\n        \"test_cachedataset_persistent_workers\",\n        \"test_cachentransdataset\",\n        \"test_check_missing_files\",\n        \"test_compute_f_beta\",\n        \"test_compute_ho_ver_maps\",\n        \"test_compute_ho_ver_maps_d\",\n        \"test_compute_panoptic_quality\",\n        \"test_contrastive_loss\",\n        \"test_convert_to_onnx\",\n        \"test_convert_to_trt\",\n        \"test_csv_dataset\",\n        \"test_csv_iterable_dataset\",\n        
\"test_cumulative_average_dist\",\n        \"test_sampler_dist\",\n        \"test_dataset\",\n        \"test_dataset_summary\",\n        \"test_deepedit_transforms\",\n        \"test_deepedit_interaction\",\n        \"test_deepgrow_dataset\",\n        \"test_deepgrow_interaction\",\n        \"test_deepgrow_transforms\",\n        \"test_detect_envelope\",\n        \"test_dints_network\",\n        \"test_distance_transform_edt\",\n        \"test_efficientnet\",\n        \"test_ensemble_evaluator\",\n        \"test_ensure_channel_first\",\n        \"test_ensure_channel_firstd\",\n        \"test_fill_holes\",\n        \"test_fill_holesd\",\n        \"test_foreground_mask\",\n        \"test_foreground_maskd\",\n        \"test_global_mutual_information_loss\",\n        \"test_grid_patch\",\n        \"test_gmm\",\n        \"test_handler_metrics_reloaded\",\n        \"test_handler_average_precision\",\n        \"test_handler_checkpoint_loader\",\n        \"test_handler_checkpoint_saver\",\n        \"test_handler_classification_saver\",\n        \"test_handler_classification_saver_dist\",\n        \"test_handler_confusion_matrix\",\n        \"test_handler_confusion_matrix_dist\",\n        \"test_handler_decollate_batch\",\n        \"test_handler_early_stop\",\n        \"test_handler_garbage_collector\",\n        \"test_handler_hausdorff_distance\",\n        \"test_handler_ignite_metric\",\n        \"test_handler_lr_scheduler\",\n        \"test_handler_mean_dice\",\n        \"test_handler_panoptic_quality\",\n        \"test_handler_mean_iou\",\n        \"test_handler_metrics_saver\",\n        \"test_handler_metrics_saver_dist\",\n        \"test_handler_mlflow\",\n        \"test_handler_nvtx\",\n        \"test_handler_parameter_scheduler\",\n        \"test_handler_post_processing\",\n        \"test_handler_prob_map_producer\",\n        \"test_handler_regression_metrics\",\n        \"test_handler_regression_metrics_dist\",\n        \"test_handler_rocauc\",\n        
\"test_handler_rocauc_dist\",\n        \"test_handler_smartcache\",\n        \"test_handler_stats\",\n        \"test_handler_surface_distance\",\n        \"test_handler_tb_image\",\n        \"test_handler_tb_stats\",\n        \"test_handler_validation\",\n        \"test_hausdorff_distance\",\n        \"test_header_correct\",\n        \"test_hilbert_transform\",\n        \"test_hovernet_loss\",\n        \"test_image_dataset\",\n        \"test_image_rw\",\n        \"test_img2tensorboard\",\n        \"test_integration_fast_train\",\n        \"test_integration_gpu_customization\",\n        \"test_integration_segmentation_3d\",\n        \"test_integration_lazy_samples\",\n        \"test_integration_sliding_window\",\n        \"test_integration_unet_2d\",\n        \"test_integration_workflows\",\n        \"test_integration_workflows_gan\",\n        \"test_integration_bundle_run\",\n        \"test_integration_autorunner\",\n        \"test_integration_nnunetv2_runner\",\n        \"test_integration_nnunet_bundle\",\n        \"test_invert\",\n        \"test_invertd\",\n        \"test_iterable_dataset\",\n        \"test_keep_largest_connected_component\",\n        \"test_keep_largest_connected_componentd\",\n        \"test_label_filter\",\n        \"test_lltm\",\n        \"test_lmdbdataset\",\n        \"test_lmdbdataset_dist\",\n        \"test_load_image\",\n        \"test_load_imaged\",\n        \"test_load_spacing_orientation\",\n        \"test_mednistdataset\",\n        \"test_milmodel\",\n        \"test_mlp\",\n        \"test_nifti_header_revise\",\n        \"test_nifti_rw\",\n        \"test_nuclick_transforms\",\n        \"test_nrrd_reader\",\n        \"test_occlusion_sensitivity\",\n        \"test_orientation\",\n        \"test_orientationd\",\n        \"test_patchembedding\",\n        \"test_persistentdataset\",\n        \"test_pil_reader\",\n        \"test_plot_2d_or_3d_image\",\n        \"test_png_rw\",\n        \"test_prepare_batch_default\",\n        
\"test_prepare_batch_diffusion\",\n        \"test_prepare_batch_extra_input\",\n        \"test_prepare_batch_hovernet\",\n        \"test_rand_grid_patch\",\n        \"test_rand_rotate\",\n        \"test_rand_rotated\",\n        \"test_rand_zoom\",\n        \"test_rand_zoomd\",\n        \"test_randtorchvisiond\",\n        \"test_rankfilter_dist\",\n        \"test_resample_backends\",\n        \"test_resize\",\n        \"test_resized\",\n        \"test_resample_to_match\",\n        \"test_resample_to_matchd\",\n        \"test_rotate\",\n        \"test_rotated\",\n        \"test_save_image\",\n        \"test_save_imaged\",\n        \"test_selfattention\",\n        \"test_senet\",\n        \"test_smartcachedataset\",\n        \"test_spacing\",\n        \"test_spacingd\",\n        \"test_splitdimd\",\n        \"test_surface_distance\",\n        \"test_surface_dice\",\n        \"test_testtimeaugmentation\",\n        \"test_torchvision\",\n        \"test_torchvisiond\",\n        \"test_transchex\",\n        \"test_transformerblock\",\n        \"test_trt_compile\",\n        \"test_unetr\",\n        \"test_unetr_block\",\n        \"test_vit\",\n        \"test_vitautoenc\",\n        \"test_vnet\",\n        \"test_write_metrics_reports\",\n        \"test_wsireader\",\n        \"test_zoom\",\n        \"test_zoom_affine\",\n        \"test_zoomd\",\n        \"test_prepare_batch_default_dist\",\n        \"test_bundle_verify_metadata\",\n        \"test_bundle_verify_net\",\n        \"test_bundle_ckpt_export\",\n        \"test_bundle_utils\",\n        \"test_bundle_init_bundle\",\n        \"test_fastmri_reader\",\n        \"test_metrics_reloaded\",\n        \"test_spatial_combine_transforms\",\n        \"test_bundle_workflow\",\n        \"test_zarr_avg_merger\",\n        \"test_perceptual_loss\",\n        \"test_ultrasound_confidence_map_transform\",\n        \"test_vista3d_utils\",\n        \"test_vista3d_transforms\",\n        \"test_matshow3d\",\n    ]\n    assert 
sorted(exclude_cases) == sorted(set(exclude_cases)), f\"Duplicated items in {exclude_cases}\"\n\n    files = [f.relative_to(Path(__file__).parent.parent) for f in Path(__file__).parent.rglob(\"test_*.py\")]\n    files = [str(f).replace(os.sep, \".\").replace(\".py\", \"\") for f in files]\n\n    cases = []\n    for test_module in files:\n        test_case = test_module.split(\".\")[-1]\n        if test_case in exclude_cases:\n            exclude_cases.remove(test_case)\n            print(f\"skipping {test_module}\")\n        else:\n            print(f\"adding {test_module}\")\n            cases.append(test_module)\n    exclude_cases = [str(list(Path(__file__).parent.rglob(f\"*{et}*\"))[0]) for et in exclude_cases]\n    assert not exclude_cases, f\"items in exclude_cases not used: {' '.join(exclude_cases)}\"\n    test_suite = unittest.TestLoader().loadTestsFromNames(cases)\n    return test_suite\n\n\nif __name__ == \"__main__\":\n    # testing import submodules\n    from monai.utils.module import load_submodules\n\n    _, err_mod = load_submodules(sys.modules[\"monai\"], True)\n    assert not err_mod, f\"err_mod={err_mod} not empty\"\n\n    # testing all modules\n    test_runner = unittest.TextTestRunner(stream=sys.stdout, verbosity=2)\n    result = test_runner.run(run_testsuit())\n    sys.exit(int(not result.wasSuccessful()))\n"
  },
  {
    "path": "tests/networks/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/networks/blocks/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/networks/blocks/dints_block/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/networks/blocks/dints_block/test_acn_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.blocks.dints_block import ActiConvNormBlock\n\nTEST_CASES = [\n    [{\"in_channel\": 32, \"out_channel\": 16, \"kernel_size\": 3, \"padding\": 1}, (7, 32, 16, 31, 7), (7, 16, 16, 31, 7)],\n    [\n        {\"in_channel\": 32, \"out_channel\": 16, \"kernel_size\": 3, \"padding\": 1, \"spatial_dims\": 2},\n        (7, 32, 13, 32),\n        (7, 16, 13, 32),\n    ],\n]\n\n\nclass TestACNBlock(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_acn_block(self, input_param, input_shape, expected_shape):\n        net = ActiConvNormBlock(**input_param)\n        result = net(torch.randn(input_shape))\n        self.assertEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/dints_block/test_factorized_increase.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.blocks.dints_block import FactorizedIncreaseBlock\n\nTEST_CASES_3D = [\n    [{\"in_channel\": 32, \"out_channel\": 16}, (7, 32, 24, 16, 8), (7, 16, 48, 32, 16)],\n    [{\"in_channel\": 1, \"out_channel\": 2}, (1, 1, 1, 1, 1), (1, 2, 2, 2, 2)],\n]\n\n\nclass TestFactInc(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_3D)\n    def test_factorized_increase_3d(self, input_param, input_shape, expected_shape):\n        net = FactorizedIncreaseBlock(**input_param)\n        result = net(torch.randn(input_shape))\n        self.assertEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/dints_block/test_factorized_reduce.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.blocks.dints_block import FactorizedReduceBlock\n\nTEST_CASES_3D = [\n    [{\"in_channel\": 32, \"out_channel\": 16}, (7, 32, 24, 16, 8), (7, 16, 12, 8, 4)],\n    [{\"in_channel\": 16, \"out_channel\": 32}, (7, 16, 22, 14, 6), (7, 32, 11, 7, 3)],\n]\n\n\nclass TestFactRed(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_3D)\n    def test_factorized_reduce_3d(self, input_param, input_shape, expected_shape):\n        net = FactorizedReduceBlock(**input_param)\n        result = net(torch.randn(input_shape))\n        self.assertEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/dints_block/test_p3d_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.blocks.dints_block import P3DActiConvNormBlock\n\nTEST_CASES_3D = [\n    [\n        {\"in_channel\": 32, \"out_channel\": 16, \"kernel_size\": 3, \"padding\": 0, \"mode\": 0},\n        (7, 32, 16, 32, 8),\n        (7, 16, 14, 30, 6),\n    ],\n    [\n        {\"in_channel\": 32, \"out_channel\": 16, \"kernel_size\": 3, \"padding\": 1, \"mode\": 0},  # check padding\n        (7, 32, 16, 32, 8),\n        (7, 16, 16, 32, 8),\n    ],\n    [\n        {\"in_channel\": 32, \"out_channel\": 16, \"kernel_size\": 3, \"padding\": 0, \"mode\": 1},\n        (7, 32, 16, 32, 8),\n        (7, 16, 14, 30, 6),\n    ],\n    [\n        {\n            \"in_channel\": 32,\n            \"out_channel\": 16,\n            \"kernel_size\": 3,\n            \"padding\": 0,\n            \"mode\": 2,\n            \"act_name\": (\"leakyrelu\", {\"inplace\": True, \"negative_slope\": 0.2}),\n        },\n        (7, 32, 16, 32, 8),\n        (7, 16, 14, 30, 6),\n    ],\n    [\n        {\n            \"in_channel\": 32,\n            \"out_channel\": 16,\n            \"kernel_size\": 4,\n            \"padding\": 0,\n            \"mode\": 0,\n            \"norm_name\": (\"INSTANCE\", {\"affine\": True}),\n        },\n        (7, 32, 16, 32, 8),\n        (7, 16, 13, 29, 5),\n    
],\n]\n\n\nclass TestP3D(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_3D)\n    def test_3d(self, input_param, input_shape, expected_shape):\n        net = P3DActiConvNormBlock(**input_param)\n        result = net(torch.randn(input_shape))\n        self.assertEqual(result.shape, expected_shape)\n\n    def test_ill(self):\n        with self.assertRaises(ValueError):\n            P3DActiConvNormBlock(in_channel=32, out_channel=16, kernel_size=3, padding=0, mode=3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_CABlock.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.cablock import CABlock, FeedForward\nfrom monai.utils import optional_import\nfrom tests.test_utils import assert_allclose, dict_product\n\neinops, has_einops = optional_import(\"einops\")\n\nTEST_CASES_CAB = [\n    [\n        {**params, \"flash_attention\": False},\n        (2, params[\"dim\"], *([16] * params[\"spatial_dims\"])),\n        (2, params[\"dim\"], *([16] * params[\"spatial_dims\"])),\n    ]\n    for params in dict_product(spatial_dims=[2, 3], dim=[32, 64, 128], num_heads=[2, 4, 8], bias=[True, False])\n]\n\n\nTEST_CASES_FEEDFORWARD = [\n    # Test different spatial dims, dimensions and expansion factors\n    [{\"spatial_dims\": 2, \"dim\": 64, \"ffn_expansion_factor\": 2.0, \"bias\": True}, (2, 64, 32, 32)],\n    [{\"spatial_dims\": 3, \"dim\": 128, \"ffn_expansion_factor\": 1.5, \"bias\": False}, (2, 128, 16, 16, 16)],\n    [{\"spatial_dims\": 2, \"dim\": 256, \"ffn_expansion_factor\": 1.0, \"bias\": True}, (1, 256, 64, 64)],\n]\n\n\nclass TestFeedForward(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_FEEDFORWARD)\n    def test_shape(self, input_param, input_shape):\n        net = FeedForward(**input_param)\n        with eval_mode(net):\n 
           result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, input_shape)\n\n    def test_gating_mechanism(self):\n        net = FeedForward(spatial_dims=2, dim=32, ffn_expansion_factor=2.0, bias=True)\n        x = torch.ones(1, 32, 16, 16)\n        out = net(x)\n        self.assertNotEqual(torch.sum(out), torch.sum(x))\n\n\nclass TestCABlock(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_CAB)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = CABlock(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_invalid_spatial_dims(self):\n        with self.assertRaises(ValueError):\n            CABlock(spatial_dims=4, dim=64, num_heads=4, bias=True)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_flash_attention(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device)\n        x = torch.randn(2, 64, 32, 32).to(device)\n        output = block(x)\n        self.assertEqual(output.shape, x.shape)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_temperature_parameter(self):\n        block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True)\n        self.assertTrue(isinstance(block.temperature, torch.nn.Parameter))\n        self.assertEqual(block.temperature.shape, (4, 1, 1))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_qkv_transformation_2d(self):\n        block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True)\n        x = torch.randn(2, 64, 32, 32)\n        qkv = block.qkv(x)\n        self.assertEqual(qkv.shape, (2, 192, 32, 32))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def 
test_qkv_transformation_3d(self):\n        block = CABlock(spatial_dims=3, dim=64, num_heads=4, bias=True)\n        x = torch.randn(2, 64, 16, 16, 16)\n        qkv = block.qkv(x)\n        self.assertEqual(qkv.shape, (2, 192, 16, 16, 16))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_flash_vs_normal_attention(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        block_flash = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device)\n        block_normal = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=False).to(device)\n\n        block_normal.load_state_dict(block_flash.state_dict())\n\n        x = torch.randn(2, 64, 32, 32).to(device)\n        with torch.no_grad():\n            out_flash = block_flash(x)\n            out_normal = block_normal(x)\n\n        assert_allclose(out_flash, out_normal, atol=1e-4)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_deterministic_small_input(self):\n        block = CABlock(spatial_dims=2, dim=2, num_heads=1, bias=False)\n        with torch.no_grad():\n            block.qkv.conv.weight.data.fill_(1.0)\n            block.qkv_dwconv.conv.weight.data.fill_(1.0)\n            block.temperature.data.fill_(1.0)\n            block.project_out.conv.weight.data.fill_(1.0)\n\n        x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]], dtype=torch.float32)\n\n        output = block(x)\n        # Channel attention: sum([1..8]) * (qkv_conv=1) * (dwconv=1) * (attn_weights=1) * (proj=1) = 36 * 2 = 72\n        expected = torch.full_like(x, 72.0)\n\n        assert_allclose(output, expected, atol=1e-6)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_adn.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.networks.blocks import ADN\nfrom tests.test_utils import TorchImageTestCase2D, TorchImageTestCase3D\n\nTEST_CASES_2D = [\n    [{\"act\": None}],\n    [{\"norm_dim\": 2}],\n    [{\"norm_dim\": 2, \"act\": \"relu\", \"dropout\": 0.8, \"ordering\": \"DA\"}],\n    [{\"dropout_dim\": 1, \"dropout\": 0.8, \"ordering\": \"DA\"}],\n    [{\"norm\": \"BATCH\", \"norm_dim\": 2, \"in_channels\": 1, \"dropout_dim\": 1, \"dropout\": 0.8, \"ordering\": \"NDA\"}],\n    [{\"norm\": \"BATCH\", \"norm_dim\": 2, \"in_channels\": 1, \"dropout_dim\": 1, \"dropout\": 0.8, \"ordering\": \"AND\"}],\n    [{\"norm\": \"INSTANCE\", \"norm_dim\": 2, \"dropout_dim\": 1, \"dropout\": 0.8, \"ordering\": \"AND\"}],\n    [\n        {\n            \"norm\": (\"GROUP\", {\"num_groups\": 1, \"affine\": False}),\n            \"in_channels\": 1,\n            \"norm_dim\": 2,\n            \"dropout_dim\": 1,\n            \"dropout\": 0.8,\n            \"ordering\": \"AND\",\n        }\n    ],\n    [{\"norm\": (\"localresponse\", {\"size\": 4}), \"norm_dim\": 2, \"dropout_dim\": 1, \"dropout\": 0.8, \"ordering\": \"AND\"}],\n]\n\nTEST_CASES_3D = [\n    [{\"norm_dim\": 3}],\n    [{\"act\": \"prelu\", \"dropout_dim\": 1, \"dropout\": 0.8, \"ordering\": \"DA\"}],\n    [{\"dropout_dim\": 1, \"dropout\": 
0.8, \"ordering\": \"DA\"}],\n    [{\"norm\": \"BATCH\", \"norm_dim\": 3, \"in_channels\": 1, \"dropout_dim\": 1, \"dropout\": 0.8, \"ordering\": \"NDA\"}],\n    [{\"norm\": \"BATCH\", \"norm_dim\": 3, \"in_channels\": 1, \"dropout_dim\": 1, \"dropout\": 0.8, \"ordering\": \"AND\"}],\n    [{\"norm\": \"INSTANCE\", \"norm_dim\": 3, \"dropout_dim\": 1, \"dropout\": 0.8, \"ordering\": \"AND\"}],\n    [\n        {\n            \"norm\": (\"layer\", {\"normalized_shape\": (48, 80)}),\n            \"norm_dim\": 3,\n            \"dropout_dim\": 1,\n            \"dropout\": 0.8,\n            \"ordering\": \"AND\",\n        }\n    ],\n]\n\n\nclass TestADN2D(TorchImageTestCase2D):\n    @parameterized.expand(TEST_CASES_2D)\n    def test_adn_2d(self, args):\n        adn = ADN(**args)\n        print(adn)\n        out = adn(self.imt)\n        expected_shape = (1, self.input_channels, self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_no_input(self):\n        with self.assertRaises(ValueError):\n            ADN(norm=\"instance\")\n\n\nclass TestADN3D(TorchImageTestCase3D):\n    @parameterized.expand(TEST_CASES_3D)\n    def test_adn_3d(self, args):\n        adn = ADN(**args)\n        print(adn)\n        out = adn(self.imt)\n        expected_shape = (1, self.input_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])\n        self.assertEqual(out.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_convolutions.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.networks.blocks import Convolution, ResidualUnit\nfrom tests.test_utils import TorchImageTestCase2D, TorchImageTestCase3D\n\n\nclass TestConvolution2D(TorchImageTestCase2D):\n    def test_conv1(self):\n        conv = Convolution(2, self.input_channels, self.output_channels)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_conv1_no_acti(self):\n        conv = Convolution(2, self.input_channels, self.output_channels, act=None)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_conv_only1(self):\n        conv = Convolution(2, self.input_channels, self.output_channels, conv_only=True)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_stride1(self):\n        for strides in [2, [2, 2], (2, 2)]:\n            conv = Convolution(2, self.input_channels, self.output_channels, strides=strides)\n            out = conv(self.imt)\n            expected_shape = (1, self.output_channels, self.im_shape[0] // 2, 
self.im_shape[1] // 2)\n            self.assertEqual(out.shape, expected_shape)\n\n    def test_dilation1(self):\n        conv = Convolution(2, self.input_channels, self.output_channels, dilation=3)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_dropout1(self):\n        conv = Convolution(2, self.input_channels, self.output_channels, dropout=0.15)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_transpose1(self):\n        conv = Convolution(2, self.input_channels, self.output_channels, is_transposed=True)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_transpose2(self):\n        conv = Convolution(2, self.input_channels, self.output_channels, strides=2, is_transposed=True)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0] * 2, self.im_shape[1] * 2)\n        self.assertEqual(out.shape, expected_shape)\n\n\nclass TestConvolution3D(TorchImageTestCase3D):\n    def test_conv1(self):\n        conv = Convolution(3, self.input_channels, self.output_channels, dropout=0.1, adn_ordering=\"DAN\")\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_conv1_no_acti(self):\n        conv = Convolution(3, self.input_channels, self.output_channels, act=None, adn_ordering=\"AND\")\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])\n        self.assertEqual(out.shape, expected_shape)\n\n    def 
test_conv_only1(self):\n        conv = Convolution(3, self.input_channels, self.output_channels, conv_only=True)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_stride1(self):\n        for strides in [2, (2, 2, 2), [2, 2, 2]]:\n            conv = Convolution(3, self.input_channels, self.output_channels, strides=strides)\n            out = conv(self.imt)\n            expected_shape = (\n                1,\n                self.output_channels,\n                self.im_shape[0] // 2,\n                self.im_shape[1] // 2,\n                self.im_shape[2] // 2,\n            )\n            self.assertEqual(out.shape, expected_shape)\n\n    def test_dilation1(self):\n        conv = Convolution(3, self.input_channels, self.output_channels, dilation=3)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_dropout1(self):\n        conv = Convolution(3, self.input_channels, self.output_channels, dropout=0.15)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_transpose1(self):\n        conv = Convolution(3, self.input_channels, self.output_channels, is_transposed=True)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1], self.im_shape[2])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_transpose2(self):\n        conv = Convolution(3, self.input_channels, self.output_channels, strides=2, is_transposed=True)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0] * 2, self.im_shape[1] * 2, 
self.im_shape[2] * 2)\n        self.assertEqual(out.shape, expected_shape)\n\n\nclass TestResidualUnit2D(TorchImageTestCase2D):\n    def test_conv_only1(self):\n        conv = ResidualUnit(2, 1, self.output_channels)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_stride1(self):\n        for strides in [2, [2, 2], (2, 2)]:\n            conv = ResidualUnit(2, 1, self.output_channels, strides=strides)\n            out = conv(self.imt)\n            expected_shape = (1, self.output_channels, self.im_shape[0] // 2, self.im_shape[1] // 2)\n            self.assertEqual(out.shape, expected_shape)\n\n    def test_dilation1(self):\n        conv = ResidualUnit(2, 1, self.output_channels, dilation=3)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_dropout1(self):\n        conv = ResidualUnit(2, 1, self.output_channels, dropout=0.15)\n        out = conv(self.imt)\n        expected_shape = (1, self.output_channels, self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_crf_cpu.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.blocks import CRF\nfrom tests.test_utils import skip_if_no_cpp_extension\n\nTEST_CASES = [\n    [\n        # Case Description\n        \"2 batche(s), 1 dimension(s), 2 classe(s), 1 channel(s)\",\n        # Parameters\n        [\n            5,  # iterations\n            1.0,  # bilateral_weight\n            0.3,  # gaussian_weight\n            5.0,  # bilateral_spatial_sigma\n            0.5,  # bilateral_color_sigma\n            5.0,  # gaussian_spatial_sigma\n            1.0,  # update_factor\n            None,  # compatibility_matrix\n        ],\n        # Input\n        [\n            # Batch 0\n            [\n                # Class 0\n                [0.8, 0.9, 0.6, 0.2, 0.3],\n                # Class 1\n                [0.1, 0.3, 0.5, 0.8, 0.7],\n            ],\n            # Batch 1\n            [\n                # Class 0\n                [0.8, 0.9, 0.6, 0.2, 0.3],\n                # Class 1\n                [0.1, 0.3, 0.5, 0.8, 0.7],\n            ],\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 1, 1, 0.5, 0]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [1, 1, 0.5, 0, 0]\n  
          ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Class 0\n                [0.726896, 0.704883, 0.589467, 0.376669, 0.380321],\n                # Class 1\n                [0.273104, 0.295117, 0.410533, 0.623331, 0.619679],\n            ],\n            # Batch 1\n            [\n                # Class 0\n                [0.741916, 0.720671, 0.551116, 0.328360, 0.376258],\n                # Class 1\n                [0.258084, 0.279329, 0.448885, 0.671640, 0.623742],\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"2 batche(s), 1 dimension(s), 2 classe(s), 1 channel(s), with_matrix\",\n        # Parameters\n        [\n            5,  # iterations\n            1.0,  # bilateral_weight\n            0.3,  # gaussian_weight\n            5.0,  # bilateral_spatial_sigma\n            0.5,  # bilateral_color_sigma\n            5.0,  # gaussian_spatial_sigma\n            1.0,  # update_factor\n            2 * torch.eye(2),  # compatibility_matrix\n        ],\n        # Input\n        [\n            # Batch 0\n            [\n                # Class 0\n                [0.8, 0.9, 0.6, 0.2, 0.3],\n                # Class 1\n                [0.1, 0.3, 0.5, 0.8, 0.7],\n            ],\n            # Batch 1\n            [\n                # Class 0\n                [0.8, 0.9, 0.6, 0.2, 0.3],\n                # Class 1\n                [0.1, 0.3, 0.5, 0.8, 0.7],\n            ],\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 1, 1, 0.5, 0]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [1, 1, 0.5, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Class 0\n                [0.870921, 0.857105, 0.781170, 0.544729, 0.476710],\n                # Class 1\n                [0.129078, 0.142894, 
0.218830, 0.455271, 0.523290],\n            ],\n            # Batch 1\n            [\n                # Class 0\n                [0.867234, 0.852610, 0.648074, 0.334584, 0.386766],\n                # Class 1\n                [0.132766, 0.147390, 0.351926, 0.665416, 0.613234],\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 batche(s), 2 dimension(s), 3 classe(s), 2 channel(s)\",\n        # Parameters\n        [\n            5,  # iterations\n            1.0,  # bilateral_weight\n            0.3,  # gaussian_weight\n            5.0,  # bilateral_spatial_sigma\n            0.5,  # bilateral_color_sigma\n            5.0,  # gaussian_spatial_sigma\n            1.0,  # update_factor\n            None,  # compatibility_matrix\n        ],\n        # Input\n        [\n            # Batch 0\n            [\n                # Class 0\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 1.0, 1.0],\n                    [0.0, 0.0, 0.0, 1.0, 1.0],\n                ],\n                # Class 1\n                [\n                    [1.0, 1.0, 0.0, 0.0, 0.0],\n                    [1.0, 1.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                # Class 2\n                [\n                    [0.0, 0.0, 0.0, 1.0, 1.0],\n                    [0.0, 0.0, 1.0, 1.0, 1.0],\n                    [0.0, 1.0, 1.0, 1.0, 0.0],\n                    [1.0, 1.0, 1.0, 0.0, 0.0],\n                    [1.0, 1.0, 0.0, 0.0, 0.0],\n                ],\n            ]\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [1.0, 1.0, 0.0, 0.0, 0.0],\n                    [1.0, 1.0, 0.0, 0.0, 
0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                # Channel 1\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 1.0, 1.0],\n                    [0.0, 0.0, 0.0, 1.0, 1.0],\n                ],\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Class 0\n                [\n                    [0.159525, 0.161449, 0.270907, 0.152424, 0.152515],\n                    [0.161763, 0.163849, 0.154026, 0.154187, 0.154360],\n                    [0.273231, 0.154715, 0.155208, 0.155677, 0.275885],\n                    [0.155076, 0.155748, 0.156349, 0.598796, 0.600179],\n                    [0.156186, 0.156858, 0.277928, 0.598459, 0.600289],\n                ],\n                # Class 1\n                [\n                    [0.647632, 0.639540, 0.276122, 0.155184, 0.155117],\n                    [0.638555, 0.629703, 0.155613, 0.155552, 0.155509],\n                    [0.276475, 0.156138, 0.156061, 0.155919, 0.275726],\n                    [0.156109, 0.156397, 0.156575, 0.172626, 0.172270],\n                    [0.156380, 0.156690, 0.277053, 0.172495, 0.172123],\n                ],\n                # Class 2\n                [\n                    [0.192843, 0.199011, 0.452971, 0.692392, 0.692368],\n                    [0.199682, 0.206448, 0.690361, 0.690261, 0.690130],\n                    [0.450294, 0.689147, 0.688731, 0.688403, 0.448389],\n                    [0.688815, 0.687855, 0.687076, 0.228579, 0.227552],\n                    [0.687434, 0.686453, 0.445019, 0.229047, 0.227588],\n                ],\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"1 batche(s), 3 dimension(s), 2 classe(s), 1 
channel(s)\",\n        # Parameters\n        [\n            2,  # iterations\n            1.0,  # bilateral_weight\n            0.3,  # gaussian_weight\n            5.0,  # bilateral_spatial_sigma\n            0.1,  # bilateral_color_sigma\n            5.0,  # gaussian_spatial_sigma\n            1.0,  # update_factor\n            None,  # compatibility_matrix\n        ],\n        # Input\n        [\n            # Batch 0\n            [\n                # Class 0\n                [\n                    # Slice 0\n                    [\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 1\n                    [\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 2\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 3\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 4\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n   
                     [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                ],\n                # Class 1\n                [\n                    # Slice 0\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 1\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 2\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 3\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                    ],\n                    # Slice 4\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                    ],\n                ],\n            ]\n        ],\n        # 
Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Slice 0\n                    [\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 1\n                    [\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 2\n                    [\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.5, 0.5, 0.8, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                    ],\n                    # Slice 3\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                    ],\n                    # Slice 4\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                    ],\n                ]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Class 0\n                [\n                    # Slice 0\n           
         [\n                        [0.775729, 0.774871, 0.557369, 0.501589, 0.501239],\n                        [0.774804, 0.774011, 0.556061, 0.501171, 0.500821],\n                        [0.557136, 0.556079, 0.554716, 0.500764, 0.500415],\n                        [0.501416, 0.501049, 0.500709, 0.500370, 0.500021],\n                        [0.500989, 0.500631, 0.500300, 0.499986, 0.499665],\n                    ],\n                    # Slice 1\n                    [\n                        [0.774559, 0.773821, 0.555753, 0.501108, 0.500757],\n                        [0.773701, 0.772905, 0.554399, 0.500680, 0.500342],\n                        [0.555462, 0.554443, 0.553025, 0.500300, 0.499967],\n                        [0.500892, 0.500562, 0.500256, 0.499931, 0.499666],\n                        [0.500477, 0.500156, 0.499859, 0.499572, 0.499355],\n                    ],\n                    # Slice 2\n                    [\n                        [0.556395, 0.555530, 0.554037, 0.500641, 0.500290],\n                        [0.555370, 0.554400, 0.552711, 0.500238, 0.499967],\n                        [0.553709, 0.552798, 0.459696, 0.449011, 0.448406],\n                        [0.500418, 0.500123, 0.448768, 0.448438, 0.447680],\n                        [0.500064, 0.499770, 0.448217, 0.447788, 0.446945],\n                    ],\n                    # Slice 3\n                    [\n                        [0.500963, 0.500754, 0.500531, 0.500187, 0.499956],\n                        [0.500662, 0.500394, 0.500144, 0.499822, 0.499657],\n                        [0.500353, 0.500090, 0.448429, 0.448021, 0.447234],\n                        [0.499966, 0.499724, 0.447893, 0.229453, 0.228867],\n                        [0.499779, 0.499514, 0.447548, 0.229087, 0.228434],\n                    ],\n                    # Slice 4\n                    [\n                        [0.500406, 0.500208, 0.500018, 0.499775, 0.499615],\n                        [0.500126, 0.499892, 0.499725, 
0.499501, 0.499322],\n                        [0.499869, 0.499645, 0.447670, 0.446978, 0.446165],\n                        [0.499609, 0.499403, 0.447168, 0.228777, 0.228153],\n                        [0.499467, 0.499255, 0.446656, 0.228424, 0.227778],\n                    ],\n                ],\n                # Class 1\n                [\n                    # Slice 0\n                    [\n                        [0.224271, 0.225129, 0.442631, 0.498411, 0.498761],\n                        [0.225196, 0.225989, 0.443939, 0.498829, 0.499179],\n                        [0.442864, 0.443921, 0.445284, 0.499236, 0.499585],\n                        [0.498584, 0.498951, 0.499291, 0.499630, 0.499979],\n                        [0.499011, 0.499369, 0.499700, 0.500014, 0.500335],\n                    ],\n                    # Slice 1\n                    [\n                        [0.225441, 0.226179, 0.444247, 0.498892, 0.499243],\n                        [0.226299, 0.227095, 0.445601, 0.499320, 0.499658],\n                        [0.444538, 0.445557, 0.446975, 0.499700, 0.500033],\n                        [0.499108, 0.499438, 0.499744, 0.500069, 0.500334],\n                        [0.499523, 0.499844, 0.500141, 0.500428, 0.500645],\n                    ],\n                    # Slice 2\n                    [\n                        [0.443605, 0.444470, 0.445963, 0.499359, 0.499710],\n                        [0.444630, 0.445600, 0.447289, 0.499762, 0.500033],\n                        [0.446291, 0.447202, 0.540304, 0.550989, 0.551594],\n                        [0.499582, 0.499877, 0.551232, 0.551562, 0.552320],\n                        [0.499936, 0.500230, 0.551783, 0.552212, 0.553055],\n                    ],\n                    # Slice 3\n                    [\n                        [0.499037, 0.499246, 0.499469, 0.499813, 0.500044],\n                        [0.499338, 0.499606, 0.499856, 0.500178, 0.500343],\n                        [0.499647, 0.499910, 0.551571, 
0.551979, 0.552766],\n                        [0.500034, 0.500276, 0.552106, 0.770547, 0.771133],\n                        [0.500221, 0.500486, 0.552452, 0.770913, 0.771566],\n                    ],\n                    # Slice 4\n                    [\n                        [0.499594, 0.499792, 0.499982, 0.500225, 0.500385],\n                        [0.499874, 0.500108, 0.500275, 0.500499, 0.500678],\n                        [0.500131, 0.500355, 0.552330, 0.553022, 0.553835],\n                        [0.500391, 0.500597, 0.552832, 0.771223, 0.771847],\n                        [0.500533, 0.500745, 0.553344, 0.771576, 0.772222],\n                    ],\n                ],\n            ]\n        ],\n    ],\n]\n\n\n@skip_if_no_cpp_extension\nclass CRFTestCaseCpu(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test(self, test_case_description, params, input, features, expected):\n        # Create input tensors\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device(\"cpu\"))\n        feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device(\"cpu\"))\n\n        # apply filter\n        crf = CRF(*params)\n        output = crf(input_tensor, feature_tensor).cpu().numpy()\n\n        # Ensure result are as expected\n        np.testing.assert_allclose(output, expected, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_crf_cuda.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.blocks import CRF\nfrom tests.test_utils import skip_if_no_cpp_extension, skip_if_no_cuda\n\nTEST_CASES = [\n    [\n        # Case Description\n        \"2 batche(s), 1 dimension(s), 2 classe(s), 1 channel(s)\",\n        # Parameters\n        [\n            5,  # iterations\n            1.0,  # bilateral_weight\n            0.3,  # gaussian_weight\n            5.0,  # bilateral_spatial_sigma\n            0.5,  # bilateral_color_sigma\n            5.0,  # gaussian_spatial_sigma\n            1.0,  # update_factor\n            None,  # compatibility_matrix\n        ],\n        # Input\n        [\n            # Batch 0\n            [\n                # Class 0\n                [0.8, 0.9, 0.6, 0.2, 0.3],\n                # Class 1\n                [0.1, 0.3, 0.5, 0.8, 0.7],\n            ],\n            # Batch 1\n            [\n                # Class 0\n                [0.8, 0.9, 0.6, 0.2, 0.3],\n                # Class 1\n                [0.1, 0.3, 0.5, 0.8, 0.7],\n            ],\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 1, 1, 0.5, 0]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [1, 
1, 0.5, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Class 0\n                [0.724431, 0.702247, 0.586338, 0.364053, 0.362328],\n                # Class 1\n                [0.275569, 0.297753, 0.413662, 0.635947, 0.637672],\n            ],\n            # Batch 1\n            [\n                # Class 0\n                [0.735150, 0.713455, 0.522234, 0.301106, 0.345620],\n                # Class 1\n                [0.264850, 0.286545, 0.477766, 0.698894, 0.654381],\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"2 batche(s), 1 dimension(s), 2 classe(s), 1 channel(s), with_matrix\",\n        # Parameters\n        [\n            5,  # iterations\n            1.0,  # bilateral_weight\n            0.3,  # gaussian_weight\n            5.0,  # bilateral_spatial_sigma\n            0.5,  # bilateral_color_sigma\n            5.0,  # gaussian_spatial_sigma\n            1.0,  # update_factor\n            2 * torch.eye(2),  # compatibility_matrix\n        ],\n        # Input\n        [\n            # Batch 0\n            [\n                # Class 0\n                [0.8, 0.9, 0.6, 0.2, 0.3],\n                # Class 1\n                [0.1, 0.3, 0.5, 0.8, 0.7],\n            ],\n            # Batch 1\n            [\n                # Class 0\n                [0.8, 0.9, 0.6, 0.2, 0.3],\n                # Class 1\n                [0.1, 0.3, 0.5, 0.8, 0.7],\n            ],\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 1, 1, 0.5, 0]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [1, 1, 0.5, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Class 0\n                [0.854686, 0.839089, 0.755770, 0.463087, 0.357129],\n                # Class 1\n                
[0.145314, 0.160911, 0.244230, 0.536913, 0.642871],\n            ],\n            # Batch 1\n            [\n                # Class 0\n                [0.825893, 0.807061, 0.492641, 0.196325, 0.231688],\n                # Class 1\n                [0.174107, 0.192939, 0.507359, 0.803675, 0.768312],\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 batche(s), 2 dimension(s), 3 classe(s), 2 channel(s)\",\n        # Parameters\n        [\n            5,  # iterations\n            1.0,  # bilateral_weight\n            0.3,  # gaussian_weight\n            5.0,  # bilateral_spatial_sigma\n            0.5,  # bilateral_color_sigma\n            5.0,  # gaussian_spatial_sigma\n            1.0,  # update_factor\n            None,  # compatibility_matrix\n        ],\n        # Input\n        [\n            # Batch 0\n            [\n                # Class 0\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 1.0, 1.0],\n                    [0.0, 0.0, 0.0, 1.0, 1.0],\n                ],\n                # Class 1\n                [\n                    [1.0, 1.0, 0.0, 0.0, 0.0],\n                    [1.0, 1.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                # Class 2\n                [\n                    [0.0, 0.0, 0.0, 0.5, 1.0],\n                    [0.0, 0.0, 0.5, 1.0, 0.5],\n                    [0.0, 0.5, 1.0, 0.5, 0.0],\n                    [0.5, 1.0, 0.5, 0.0, 0.0],\n                    [1.0, 0.5, 0.0, 0.0, 0.0],\n                ],\n            ]\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [1.0, 1.0, 0.0, 0.0, 0.0],\n                    
[1.0, 1.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                # Channel 1\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 1.0, 1.0],\n                    [0.0, 0.0, 0.0, 1.0, 1.0],\n                ],\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Class 0\n                [\n                    [0.154633, 0.164076, 0.300110, 0.239729, 0.179437],\n                    [0.156664, 0.161426, 0.254582, 0.191402, 0.253060],\n                    [0.316391, 0.259811, 0.201576, 0.271977, 0.333670],\n                    [0.263658, 0.204998, 0.276233, 0.686272, 0.687161],\n                    [0.208480, 0.281425, 0.355033, 0.690412, 0.692331],\n                ],\n                # Class 1\n                [\n                    [0.681083, 0.652977, 0.312156, 0.245985, 0.181768],\n                    [0.675692, 0.662155, 0.247827, 0.183893, 0.240174],\n                    [0.309154, 0.249075, 0.186364, 0.240742, 0.293918],\n                    [0.243739, 0.185445, 0.242820, 0.151819, 0.151363],\n                    [0.180842, 0.238059, 0.292488, 0.150209, 0.149395],\n                ],\n                # Class 2\n                [\n                    [0.164284, 0.182947, 0.387733, 0.514285, 0.638795],\n                    [0.167644, 0.176419, 0.497592, 0.624705, 0.506766],\n                    [0.374455, 0.491115, 0.612060, 0.487281, 0.372412],\n                    [0.492602, 0.609557, 0.480947, 0.161909, 0.161476],\n                    [0.610678, 0.480516, 0.352479, 0.159380, 0.158274],\n                ],\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"1 batche(s), 3 dimension(s), 2 
classe(s), 1 channel(s)\",\n        # Parameters\n        [\n            2,  # iterations\n            1.0,  # bilateral_weight\n            0.3,  # gaussian_weight\n            5.0,  # bilateral_spatial_sigma\n            0.1,  # bilateral_color_sigma\n            5.0,  # gaussian_spatial_sigma\n            1.0,  # update_factor\n            None,  # compatibility_matrix\n        ],\n        # Input\n        [\n            # Batch 0\n            [\n                # Class 0\n                [\n                    # Slice 0\n                    [\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 1\n                    [\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 2\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 3\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 4\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 
0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                ],\n                # Class 1\n                [\n                    # Slice 0\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 1\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 2\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 3\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                    ],\n                    # Slice 4\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                    ],\n                ],\n            ]\n        ],\n     
   # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Slice 0\n                    [\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 1\n                    [\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    # Slice 2\n                    [\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.5, 0.5, 0.5, 0.0, 0.0],\n                        [0.5, 0.5, 0.8, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                    ],\n                    # Slice 3\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                    ],\n                    # Slice 4\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                    ],\n                ]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Class 0\n                [\n                    # Slice 0\n      
              [\n                        [0.778237, 0.777561, 0.561416, 0.501611, 0.501294],\n                        [0.777517, 0.776882, 0.560301, 0.501103, 0.500791],\n                        [0.561231, 0.560339, 0.559060, 0.500619, 0.500311],\n                        [0.501322, 0.500872, 0.500468, 0.500156, 0.499851],\n                        [0.500883, 0.500449, 0.500059, 0.499713, 0.499420],\n                    ],\n                    # Slice 1\n                    [\n                        [0.777409, 0.776861, 0.560182, 0.501111, 0.500808],\n                        [0.776784, 0.776102, 0.558887, 0.500618, 0.500329],\n                        [0.559943, 0.558969, 0.557350, 0.500183, 0.499897],\n                        [0.500789, 0.500403, 0.500052, 0.499765, 0.499487],\n                        [0.500378, 0.500005, 0.499668, 0.499363, 0.499072],\n                    ],\n                    # Slice 2\n                    [\n                        [0.560846, 0.560185, 0.558597, 0.500660, 0.500369],\n                        [0.560078, 0.559146, 0.556974, 0.500209, 0.499950],\n                        [0.558225, 0.557130, 0.486025, 0.448784, 0.445606],\n                        [0.500340, 0.500005, 0.448945, 0.448551, 0.444201],\n                        [0.499972, 0.499644, 0.448537, 0.447195, 0.443425],\n                    ],\n                    # Slice 3\n                    [\n                        [0.500887, 0.500713, 0.500529, 0.500251, 0.499999],\n                        [0.500596, 0.500312, 0.500109, 0.499848, 0.499605],\n                        [0.500301, 0.500002, 0.447391, 0.445814, 0.442289],\n                        [0.499940, 0.499662, 0.447338, 0.227284, 0.225224],\n                        [0.499650, 0.499367, 0.445866, 0.227800, 0.225564],\n                    ],\n                    # Slice 4\n                    [\n                        [0.500399, 0.500241, 0.500090, 0.499883, 0.499637],\n                        [0.500134, 0.499888, 
0.499756, 0.499526, 0.499261],\n                        [0.499888, 0.499631, 0.446166, 0.442215, 0.440038],\n                        [0.499603, 0.499369, 0.445307, 0.225463, 0.223935],\n                        [0.499337, 0.499113, 0.443668, 0.226403, 0.224790],\n                    ],\n                ],\n                # Class 1\n                [\n                    # Slice 0\n                    [\n                        [0.221763, 0.222439, 0.438584, 0.498389, 0.498706],\n                        [0.222483, 0.223118, 0.439699, 0.498897, 0.499209],\n                        [0.438769, 0.439661, 0.440940, 0.499381, 0.499689],\n                        [0.498678, 0.499128, 0.499532, 0.499844, 0.500149],\n                        [0.499117, 0.499551, 0.499941, 0.500287, 0.500580],\n                    ],\n                    # Slice 1\n                    [\n                        [0.222591, 0.223139, 0.439818, 0.498889, 0.499192],\n                        [0.223216, 0.223898, 0.441113, 0.499382, 0.499671],\n                        [0.440057, 0.441031, 0.442650, 0.499817, 0.500103],\n                        [0.499211, 0.499597, 0.499948, 0.500235, 0.500513],\n                        [0.499622, 0.499995, 0.500332, 0.500637, 0.500928],\n                    ],\n                    # Slice 2\n                    [\n                        [0.439154, 0.439815, 0.441403, 0.499340, 0.499631],\n                        [0.439922, 0.440854, 0.443026, 0.499791, 0.500050],\n                        [0.441775, 0.442870, 0.513975, 0.551216, 0.554394],\n                        [0.499660, 0.499995, 0.551055, 0.551449, 0.555799],\n                        [0.500028, 0.500356, 0.551463, 0.552805, 0.556575],\n                    ],\n                    # Slice 3\n                    [\n                        [0.499113, 0.499287, 0.499471, 0.499749, 0.500001],\n                        [0.499404, 0.499688, 0.499891, 0.500152, 0.500395],\n                        [0.499699, 0.499998, 
0.552609, 0.554186, 0.557711],\n                        [0.500060, 0.500338, 0.552662, 0.772716, 0.774776],\n                        [0.500350, 0.500633, 0.554134, 0.772200, 0.774436],\n                    ],\n                    # Slice 4\n                    [\n                        [0.499601, 0.499759, 0.499910, 0.500117, 0.500363],\n                        [0.499866, 0.500112, 0.500244, 0.500474, 0.500739],\n                        [0.500112, 0.500369, 0.553834, 0.557785, 0.559962],\n                        [0.500397, 0.500631, 0.554693, 0.774537, 0.776065],\n                        [0.500663, 0.500887, 0.556332, 0.773597, 0.775210],\n                    ],\n                ],\n            ]\n        ],\n    ],\n]\n\n\n@skip_if_no_cpp_extension\n@skip_if_no_cuda\nclass CRFTestCaseCuda(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test(self, test_case_description, params, input, features, expected):\n        # Create input tensors\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device(\"cuda\"))\n        feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device(\"cuda\"))\n\n        params[-1] = None if params[-1] is None else params[-1].cuda()\n\n        # apply filter\n        crf = CRF(*params)\n        output = crf(input_tensor, feature_tensor).cpu().numpy()\n\n        # Ensure result are as expected\n        # np.testing.assert_allclose(output, expected, atol=1e-4)\n\n        # Temporarily allowing some (10%) mismatched elements due to non determinism.\n        absolute_diff_tolerance = 5e-2\n        mismatch_ratio_tolerance = 0.1\n\n        output = np.array(output).flatten()\n        expected = np.array(expected).flatten()\n\n        abs_diff = abs(output - expected)\n        mismatch_count = sum(np.where(abs_diff > absolute_diff_tolerance, 1, 0))\n\n        self.assertLessEqual(mismatch_count / len(output), mismatch_ratio_tolerance)\n\n\nif __name__ 
== \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_crossattention.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.crossattention import CrossAttentionBlock\nfrom monai.networks.layers.factories import RelPosEmbedding\nfrom monai.utils import optional_import\nfrom tests.test_utils import assert_allclose, dict_product\n\neinops, has_einops = optional_import(\"einops\")\n\nTEST_CASE_CABLOCK = [\n    [\n        {\n            **{k: v for k, v in params.items() if k not in [\"rel_pos_embedding_val\"]},\n            \"rel_pos_embedding\": params[\"rel_pos_embedding_val\"] if not params[\"use_flash_attention\"] else None,\n        },\n        (2, 512, params[\"hidden_size\"]),\n        (2, 512, params[\"hidden_size\"]),\n    ]\n    for params in dict_product(\n        dropout_rate=np.linspace(0, 1, 4),\n        hidden_size=[360, 480, 600, 768],\n        num_heads=[4, 6, 8, 12],\n        rel_pos_embedding_val=[None, RelPosEmbedding.DECOMPOSED],\n        input_size=[(16, 32), (8, 8, 8)],\n        use_flash_attention=[True, False],\n    )\n]\n\n\nclass TestResBlock(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_CABLOCK)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape(self, input_param, input_shape, expected_shape):\n        # Without 
flash attention\n        net = CrossAttentionBlock(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param[\"hidden_size\"]))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            CrossAttentionBlock(hidden_size=128, num_heads=12, dropout_rate=6.0)\n\n        with self.assertRaises(ValueError):\n            CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4)\n\n    def test_save_attn_with_flash_attention(self):\n        with self.assertRaises(ValueError):\n            CrossAttentionBlock(\n                hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True\n            )\n\n    def test_rel_pos_embedding_with_flash_attention(self):\n        with self.assertRaises(ValueError):\n            CrossAttentionBlock(\n                hidden_size=128,\n                num_heads=3,\n                dropout_rate=0.1,\n                use_flash_attention=True,\n                save_attn=False,\n                rel_pos_embedding=RelPosEmbedding.DECOMPOSED,\n            )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_attention_dim_not_multiple_of_heads(self):\n        with self.assertRaises(ValueError):\n            CrossAttentionBlock(hidden_size=128, num_heads=3, dropout_rate=0.1)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_inner_dim_different(self):\n        CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30)\n\n    def test_causal_no_sequence_length(self):\n        with self.assertRaises(ValueError):\n            CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_causal_flash_attention(self):\n        block = CrossAttentionBlock(\n            hidden_size=128,\n            num_heads=1,\n 
           dropout_rate=0.1,\n            causal=True,\n            sequence_length=16,\n            save_attn=False,\n            use_flash_attention=True,\n        )\n        input_shape = (1, 16, 128)\n        # Check it runs correctly\n        block(torch.randn(input_shape))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_causal(self):\n        block = CrossAttentionBlock(\n            hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True\n        )\n        input_shape = (1, 16, 128)\n        block(torch.randn(input_shape))\n        # check upper triangular part of the attention matrix is zero\n        assert torch.triu(block.att_mat, diagonal=1).sum() == 0\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_context_input(self):\n        block = CrossAttentionBlock(\n            hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12\n        )\n        input_shape = (1, 16, 128)\n        block(torch.randn(input_shape), context=torch.randn(1, 3, 12))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_context_wrong_input_size(self):\n        block = CrossAttentionBlock(\n            hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12\n        )\n        input_shape = (1, 16, 128)\n        with self.assertRaises(RuntimeError):\n            block(torch.randn(input_shape), context=torch.randn(1, 3, 24))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_access_attn_matrix(self):\n        # input format\n        hidden_size = 128\n        num_heads = 2\n        dropout_rate = 0\n        input_shape = (2, 256, hidden_size)\n\n        # be  not able to access the matrix\n        no_matrix_acess_blk = CrossAttentionBlock(\n            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate\n        )\n        
no_matrix_acess_blk(torch.randn(input_shape))\n        assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor)\n        # number of elements is zero\n        assert no_matrix_acess_blk.att_mat.nelement() == 0\n\n        # be able to access the attention matrix.\n        matrix_acess_blk = CrossAttentionBlock(\n            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True\n        )\n        matrix_acess_blk(torch.randn(input_shape))\n        assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1])\n\n    @parameterized.expand([[True], [False]])\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_flash_attention(self, causal):\n        input_param = {\"hidden_size\": 128, \"num_heads\": 1, \"causal\": causal, \"sequence_length\": 16 if causal else None}\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        block_w_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device)\n        block_wo_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device)\n        block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict())\n        test_data = torch.randn(1, 16, 128).to(device)\n\n        out_1 = block_w_flash_attention(test_data)\n        out_2 = block_wo_flash_attention(test_data)\n        assert_allclose(out_1, out_2, atol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_denseblock.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch.nn as nn\n\nfrom monai.networks.blocks import ConvDenseBlock, DenseBlock\nfrom tests.test_utils import TorchImageTestCase2D, TorchImageTestCase3D\n\n\nclass TestDenseBlock2D(TorchImageTestCase2D):\n    def test_block_empty(self):\n        block = DenseBlock([])\n        out = block(self.imt)\n        expected_shape = self.imt.shape\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_block_conv(self):\n        conv1 = nn.Conv2d(self.input_channels, self.output_channels, 3, padding=1)\n        conv2 = nn.Conv2d(self.input_channels + self.output_channels, self.input_channels, 3, padding=1)\n        block = DenseBlock([conv1, conv2])\n        out = block(self.imt)\n        expected_shape = (1, self.output_channels + self.input_channels * 2, self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n\nclass TestDenseBlock3D(TorchImageTestCase3D):\n    def test_block_conv(self):\n        conv1 = nn.Conv3d(self.input_channels, self.output_channels, 3, padding=1)\n        conv2 = nn.Conv3d(self.input_channels + self.output_channels, self.input_channels, 3, padding=1)\n        block = DenseBlock([conv1, conv2])\n        out = block(self.imt)\n        expected_shape = (\n            1,\n            self.output_channels + self.input_channels * 2,\n            
self.im_shape[0],\n            self.im_shape[1],\n            self.im_shape[2],\n        )\n        self.assertEqual(out.shape, expected_shape)\n\n\nclass TestConvDenseBlock2D(TorchImageTestCase2D):\n    def test_block_empty(self):\n        conv = ConvDenseBlock(spatial_dims=2, in_channels=self.input_channels, channels=[])\n        out = conv(self.imt)\n        expected_shape = self.imt.shape\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_except(self):\n        with self.assertRaises(ValueError):\n            _ = ConvDenseBlock(spatial_dims=2, in_channels=self.input_channels, channels=[1, 2], dilations=[1, 2, 3])\n\n    def test_block1(self):\n        channels = [2, 4]\n        conv = ConvDenseBlock(spatial_dims=2, in_channels=self.input_channels, channels=channels)\n        out = conv(self.imt)\n        expected_shape = (1, self.input_channels + sum(channels), self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_block2(self):\n        channels = [2, 4]\n        dilations = [1, 2]\n        conv = ConvDenseBlock(spatial_dims=2, in_channels=self.input_channels, channels=channels, dilations=dilations)\n        out = conv(self.imt)\n        expected_shape = (1, self.input_channels + sum(channels), self.im_shape[0], self.im_shape[1])\n        self.assertEqual(out.shape, expected_shape)\n\n\nclass TestConvDenseBlock3D(TorchImageTestCase3D):\n    def test_block_empty(self):\n        conv = ConvDenseBlock(spatial_dims=3, in_channels=self.input_channels, channels=[])\n        out = conv(self.imt)\n        expected_shape = self.imt.shape\n        self.assertEqual(out.shape, expected_shape)\n\n    def test_block1(self):\n        channels = [2, 4]\n        conv = ConvDenseBlock(spatial_dims=3, in_channels=self.input_channels, channels=channels)\n        out = conv(self.imt)\n        expected_shape = (1, self.input_channels + sum(channels), self.im_shape[0], self.im_shape[1], self.im_shape[2])\n        
self.assertEqual(out.shape, expected_shape)\n\n    def test_block2(self):\n        channels = [2, 4]\n        dilations = [1, 2]\n        conv = ConvDenseBlock(spatial_dims=3, in_channels=self.input_channels, channels=channels, dilations=dilations)\n        out = conv(self.imt)\n        expected_shape = (1, self.input_channels + sum(channels), self.im_shape[0], self.im_shape[1], self.im_shape[2])\n        self.assertEqual(out.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_downsample_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks import DownSample, MaxAvgPool, SubpixelDownsample, SubpixelUpsample\nfrom monai.utils import optional_import\n\neinops, has_einops = optional_import(\"einops\")\n\nTEST_CASES = [\n    [{\"spatial_dims\": 2, \"kernel_size\": 2}, (7, 4, 64, 48), (7, 8, 32, 24)],  # 4-channel 2D, batch 7\n    [{\"spatial_dims\": 1, \"kernel_size\": 4}, (16, 4, 63), (16, 8, 15)],  # 4-channel 1D, batch 16\n    [{\"spatial_dims\": 1, \"kernel_size\": 4, \"padding\": 1}, (16, 4, 63), (16, 8, 16)],  # 4-channel 1D, batch 16\n    [  # 4-channel 3D, batch 16\n        {\"spatial_dims\": 3, \"kernel_size\": 3, \"ceil_mode\": True},\n        (16, 4, 32, 24, 48),\n        (16, 8, 11, 8, 16),\n    ],\n    [  # 1-channel 3D, batch 16\n        {\"spatial_dims\": 3, \"kernel_size\": 3, \"ceil_mode\": False},\n        (16, 1, 32, 24, 48),\n        (16, 2, 10, 8, 16),\n    ],\n]\n\nTEST_CASES_SUBPIXEL = [\n    [{\"spatial_dims\": 2, \"in_channels\": 1, \"scale_factor\": 2}, (1, 1, 8, 8), (1, 4, 4, 4)],\n    [{\"spatial_dims\": 3, \"in_channels\": 2, \"scale_factor\": 2}, (1, 2, 8, 8, 8), (1, 16, 4, 4, 4)],\n    [{\"spatial_dims\": 1, \"in_channels\": 3, \"scale_factor\": 2}, (1, 3, 8), (1, 6, 4)],\n]\n\nTEST_CASES_DOWNSAMPLE = [\n  
  [{\"spatial_dims\": 2, \"in_channels\": 4, \"mode\": \"conv\"}, (1, 4, 16, 16), (1, 4, 8, 8)],\n    [{\"spatial_dims\": 2, \"in_channels\": 4, \"out_channels\": 8, \"mode\": \"convgroup\"}, (1, 4, 16, 16), (1, 8, 8, 8)],\n    [{\"spatial_dims\": 3, \"in_channels\": 2, \"mode\": \"maxpool\"}, (1, 2, 16, 16, 16), (1, 2, 8, 8, 8)],\n    [{\"spatial_dims\": 2, \"in_channels\": 4, \"mode\": \"avgpool\"}, (1, 4, 16, 16), (1, 4, 8, 8)],\n    [{\"spatial_dims\": 2, \"in_channels\": 1, \"mode\": \"pixelunshuffle\"}, (1, 1, 16, 16), (1, 4, 8, 8)],\n]\n\n\nclass TestMaxAvgPool(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = MaxAvgPool(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n\nclass TestSubpixelDownsample(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_SUBPIXEL)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        downsampler = SubpixelDownsample(**input_param)\n        with eval_mode(downsampler):\n            result = downsampler(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_predefined_tensor(self):\n        test_tensor = torch.arange(4).view(4, 1, 1).repeat(1, 4, 4)\n        test_tensor = test_tensor.unsqueeze(0)\n\n        downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None)\n        with eval_mode(downsampler):\n            result = downsampler(test_tensor)\n            self.assertEqual(result.shape, (1, 16, 2, 2))\n            self.assertTrue(torch.all(result[0, 0:3] == 0))\n            self.assertTrue(torch.all(result[0, 4:7] == 1))\n            self.assertTrue(torch.all(result[0, 8:11] == 2))\n            self.assertTrue(torch.all(result[0, 12:15] == 3))\n\n    def test_reconstruction_2d(self):\n        input_tensor = 
torch.randn(1, 1, 4, 4)\n        down = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None)\n        up = SubpixelUpsample(spatial_dims=2, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False)\n        with eval_mode(down), eval_mode(up):\n            downsampled = down(input_tensor)\n            reconstructed = up(downsampled)\n            self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5))\n\n    def test_reconstruction_3d(self):\n        input_tensor = torch.randn(1, 1, 4, 4, 4)\n        down = SubpixelDownsample(spatial_dims=3, in_channels=1, scale_factor=2, conv_block=None)\n        up = SubpixelUpsample(spatial_dims=3, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False)\n        with eval_mode(down), eval_mode(up):\n            downsampled = down(input_tensor)\n            reconstructed = up(downsampled)\n            self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5))\n\n    def test_invalid_spatial_size(self):\n        downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2)\n        with self.assertRaises(ValueError):\n            downsampler(torch.randn(1, 1, 3, 4))\n\n    def test_custom_conv_block(self):\n        custom_conv = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1)\n        downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=custom_conv)\n        with eval_mode(downsampler):\n            result = downsampler(torch.randn(1, 1, 4, 4))\n            self.assertEqual(result.shape, (1, 8, 2, 2))\n\n\nclass TestDownSample(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_DOWNSAMPLE)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = DownSample(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_pre_post_conv(self):\n        net = 
DownSample(\n            spatial_dims=2,\n            in_channels=4,\n            out_channels=8,\n            mode=\"maxpool\",\n            pre_conv=\"default\",\n            post_conv=torch.nn.Conv2d(8, 16, 1),\n        )\n        with eval_mode(net):\n            result = net(torch.randn(1, 4, 16, 16))\n            self.assertEqual(result.shape, (1, 16, 8, 8))\n\n    def test_pixelunshuffle_equivalence(self):\n        class DownSampleLocal(torch.nn.Module):\n            def __init__(self, n_feat: int):\n                super().__init__()\n                self.conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)\n                self.pixelunshuffle = torch.nn.PixelUnshuffle(2)\n\n            def forward(self, x: torch.Tensor) -> torch.Tensor:\n                x = self.conv(x)\n                x = self.pixelunshuffle(x)\n                return x\n\n        n_feat = 2\n        x = torch.randn(1, n_feat, 64, 64)\n\n        fix_weight_conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)\n\n        monai_down = DownSample(\n            spatial_dims=2,\n            in_channels=n_feat,\n            out_channels=n_feat // 2,\n            mode=\"pixelunshuffle\",\n            pre_conv=fix_weight_conv,\n        )\n\n        local_down = DownSampleLocal(n_feat)\n        local_down.conv.weight.data = fix_weight_conv.weight.data.clone()\n\n        with eval_mode(monai_down), eval_mode(local_down):\n            out_monai = monai_down(x)\n            out_local = local_down(x)\n\n        self.assertTrue(torch.allclose(out_monai, out_local, rtol=1e-5))\n\n    def test_invalid_mode(self):\n        with self.assertRaises(ValueError):\n            DownSample(spatial_dims=2, in_channels=4, mode=\"invalid\")\n\n    def test_missing_channels(self):\n        with self.assertRaises(ValueError):\n            DownSample(spatial_dims=2, mode=\"conv\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_dynunet_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, UnetUpBlock, get_padding\nfrom tests.test_utils import dict_product, test_script_save\n\nTEST_CASE_RES_BASIC_BLOCK = []\nfor params in dict_product(\n    spatial_dims=range(2, 4),\n    kernel_size=[1, 3],\n    stride=[1, 2],\n    norm_name=[(\"GROUP\", {\"num_groups\": 16}), (\"batch\", {\"track_running_stats\": False}), \"instance\"],\n    in_size=[15, 16],\n):\n    padding = get_padding(params[\"kernel_size\"], params[\"stride\"])\n    if not isinstance(padding, int):\n        padding = padding[0]\n    out_size = int((params[\"in_size\"] + 2 * padding - params[\"kernel_size\"]) / params[\"stride\"]) + 1\n    test_case = [\n        {\n            **{k: v for k, v in params.items() if k != \"in_size\"},\n            \"in_channels\": 16,\n            \"out_channels\": 16,\n            \"act_name\": (\"leakyrelu\", {\"inplace\": True, \"negative_slope\": 0.1}),\n        },\n        (1, 16, *([params[\"in_size\"]] * params[\"spatial_dims\"])),\n        (1, 16, *([out_size] * params[\"spatial_dims\"])),\n    ]\n    TEST_CASE_RES_BASIC_BLOCK.append(test_case)\n\nTEST_UP_BLOCK = []\nin_channels, out_channels = 4, 2\nfor params in dict_product(\n    
spatial_dims=range(2, 4),\n    kernel_size=[1, 3],\n    stride=[1, 2],\n    norm_name=[\"batch\", \"instance\"],\n    in_size=[15, 16],\n    trans_bias=[True, False],\n):\n    out_size = params[\"in_size\"] * params[\"stride\"]\n    test_case = [\n        {\n            **{k: v for k, v in params.items() if k != \"in_size\"},\n            \"in_channels\": in_channels,\n            \"out_channels\": out_channels,\n            \"upsample_kernel_size\": params[\"stride\"],\n        },\n        (1, in_channels, *([params[\"in_size\"]] * params[\"spatial_dims\"])),\n        (1, out_channels, *([out_size] * params[\"spatial_dims\"])),\n        (1, out_channels, *([params[\"in_size\"] * params[\"stride\"]] * params[\"spatial_dims\"])),\n    ]\n    TEST_UP_BLOCK.append(test_case)\n\n\nclass TestResBasicBlock(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_RES_BASIC_BLOCK)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        for net in [UnetResBlock(**input_param), UnetBasicBlock(**input_param)]:\n            with eval_mode(net):\n                result = net(torch.randn(input_shape))\n                self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            UnetBasicBlock(3, 4, 2, kernel_size=3, stride=1, norm_name=\"norm\")\n        with self.assertRaises(AssertionError):\n            UnetResBlock(3, 4, 2, kernel_size=1, stride=4, norm_name=\"batch\")\n\n    def test_script(self):\n        input_param, input_shape, _ = TEST_CASE_RES_BASIC_BLOCK[0]\n\n        for net_type in (UnetResBlock, UnetBasicBlock):\n            net = net_type(**input_param)\n            test_data = torch.randn(input_shape)\n            test_script_save(net, test_data)\n\n\nclass TestUpBlock(unittest.TestCase):\n    @parameterized.expand(TEST_UP_BLOCK)\n    def test_shape(self, input_param, input_shape, expected_shape, skip_shape):\n        net = UnetUpBlock(**input_param)\n        with 
eval_mode(net):\n            result = net(torch.randn(input_shape), torch.randn(skip_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        input_param, input_shape, _, skip_shape = TEST_UP_BLOCK[0]\n\n        net = UnetUpBlock(**input_param)\n        test_data = torch.randn(input_shape)\n        skip_data = torch.randn(skip_shape)\n        test_script_save(net, test_data, skip_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_fpn_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom collections import OrderedDict\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.blocks.backbone_fpn_utils import _resnet_fpn_extractor\nfrom monai.networks.blocks.feature_pyramid_network import FeaturePyramidNetwork\nfrom monai.networks.nets.resnet import resnet50\nfrom monai.utils import optional_import\nfrom tests.test_utils import test_script_save\n\n_, has_torchvision = optional_import(\"torchvision\")\n\nTEST_CASES = [\n    [\n        {\"spatial_dims\": 3, \"in_channels_list\": [32, 64], \"out_channels\": 6},\n        ((7, 32, 16, 32, 64), (7, 64, 8, 16, 32)),\n        ((7, 6, 16, 32, 64), (7, 6, 8, 16, 32)),\n    ],\n    [\n        {\"spatial_dims\": 2, \"in_channels_list\": [32, 64], \"out_channels\": 6},\n        ((7, 32, 16, 32), (7, 64, 8, 16)),\n        ((7, 6, 16, 32), (7, 6, 8, 16)),\n    ],\n]\n\nTEST_CASES2 = [\n    [{\"spatial_dims\": 3, \"returned_layers\": [1]}, (7, 3, 32, 64, 32), ((7, 256, 16, 32, 16), (7, 256, 8, 16, 8))]\n]\n\n\nclass TestFPNBlock(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_fpn_block(self, input_param, input_shape, expected_shape):\n        net = FeaturePyramidNetwork(**input_param)\n        data = OrderedDict()\n        data[\"feat0\"] = torch.rand(input_shape[0])\n        data[\"feat1\"] = 
torch.rand(input_shape[1])\n        result = net(data)\n        self.assertEqual(result[\"feat0\"].shape, expected_shape[0])\n        self.assertEqual(result[\"feat1\"].shape, expected_shape[1])\n\n    @parameterized.expand(TEST_CASES)\n    def test_script(self, input_param, input_shape, expected_shape):\n        # test whether torchscript is supported\n        net = FeaturePyramidNetwork(**input_param)\n        data = OrderedDict()\n        data[\"feat0\"] = torch.rand(input_shape[0])\n        data[\"feat1\"] = torch.rand(input_shape[1])\n        test_script_save(net, data)\n\n\n@unittest.skipUnless(has_torchvision, \"Requires torchvision\")\nclass TestFPN(unittest.TestCase):\n    @parameterized.expand(TEST_CASES2)\n    def test_fpn(self, input_param, input_shape, expected_shape):\n        net = _resnet_fpn_extractor(backbone=resnet50(), spatial_dims=input_param[\"spatial_dims\"], returned_layers=[1])\n        data = torch.rand(input_shape)\n        result = net(data)\n        self.assertEqual(result[\"0\"].shape, expected_shape[0])\n        self.assertEqual(result[\"pool\"].shape, expected_shape[1])\n\n    @parameterized.expand(TEST_CASES2)\n    def test_script(self, input_param, input_shape, expected_shape):\n        # test whether torchscript is supported\n        net = _resnet_fpn_extractor(backbone=resnet50(), spatial_dims=input_param[\"spatial_dims\"], returned_layers=[1])\n        data = torch.rand(input_shape)\n        test_script_save(net, data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_localnet_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.localnet_block import (\n    LocalNetDownSampleBlock,\n    LocalNetFeatureExtractorBlock,\n    LocalNetUpSampleBlock,\n)\n\nTEST_CASE_DOWN_SAMPLE = [\n    [{\"spatial_dims\": spatial_dims, \"in_channels\": 2, \"out_channels\": 4, \"kernel_size\": 3}] for spatial_dims in [2, 3]\n]\n\nTEST_CASE_UP_SAMPLE = [\n    [\n        {\n            \"spatial_dims\": spatial_dims,\n            \"in_channels\": 4,\n            \"out_channels\": 2,\n            \"mode\": \"bilinear\" if spatial_dims == 2 else \"trilinear\",\n        }\n    ]\n    for spatial_dims in [2, 3]\n]\n\nTEST_CASE_EXTRACT = [\n    [{\"spatial_dims\": spatial_dims, \"in_channels\": 2, \"out_channels\": 3, \"act\": act, \"initializer\": initializer}]\n    for spatial_dims, act, initializer in zip([2, 3], [\"sigmoid\", None], [\"kaiming_uniform\", \"zeros\"])\n]\n\nin_size = 4\n\n\nclass TestLocalNetDownSampleBlock(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASE_DOWN_SAMPLE)\n    def test_shape(self, input_param):\n        net = LocalNetDownSampleBlock(**input_param)\n        input_shape = (1, input_param[\"in_channels\"], *([in_size] * input_param[\"spatial_dims\"]))\n        expect_mid_shape = (1, 
input_param[\"out_channels\"], *([in_size] * input_param[\"spatial_dims\"]))\n        expect_x_shape = (1, input_param[\"out_channels\"], *([in_size / 2] * input_param[\"spatial_dims\"]))\n        with eval_mode(net):\n            x, mid = net(torch.randn(input_shape))\n            self.assertEqual(x.shape, expect_x_shape)\n            self.assertEqual(mid.shape, expect_mid_shape)\n\n    def test_ill_arg(self):\n        # even kernel_size\n        with self.assertRaises(NotImplementedError):\n            LocalNetDownSampleBlock(spatial_dims=2, in_channels=2, out_channels=4, kernel_size=4)\n\n    @parameterized.expand(TEST_CASE_DOWN_SAMPLE)\n    def test_ill_shape(self, input_param):\n        net = LocalNetDownSampleBlock(**input_param)\n        input_shape = (1, input_param[\"in_channels\"], *([5] * input_param[\"spatial_dims\"]))\n        with self.assertRaises(ValueError):\n            with eval_mode(net):\n                net(torch.randn(input_shape))\n\n\nclass TestLocalNetUpSampleBlock(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASE_UP_SAMPLE)\n    def test_shape(self, input_param):\n        net = LocalNetUpSampleBlock(**input_param)\n        input_shape = (1, input_param[\"in_channels\"], *([in_size] * input_param[\"spatial_dims\"]))\n        mid_shape = (1, input_param[\"out_channels\"], *([in_size * 2] * input_param[\"spatial_dims\"]))\n        expected_shape = mid_shape\n        with eval_mode(net):\n            result = net(torch.randn(input_shape), torch.randn(mid_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        # channel unmatch\n        with self.assertRaises(ValueError):\n            LocalNetUpSampleBlock(spatial_dims=2, in_channels=2, out_channels=2)\n\n    @parameterized.expand(TEST_CASE_UP_SAMPLE)\n    def test_ill_shape(self, input_param):\n        net = LocalNetUpSampleBlock(**input_param)\n        input_shape = (1, input_param[\"in_channels\"], *([in_size] * 
input_param[\"spatial_dims\"]))\n        mid_shape = (1, input_param[\"out_channels\"], *([in_size] * input_param[\"spatial_dims\"]))\n        with self.assertRaises(ValueError):\n            with eval_mode(net):\n                net(torch.randn(input_shape), torch.randn(mid_shape))\n\n\nclass TestExtractBlock(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASE_EXTRACT)\n    def test_shape(self, input_param):\n        net = LocalNetFeatureExtractorBlock(**input_param)\n        input_shape = (1, input_param[\"in_channels\"], *([in_size] * input_param[\"spatial_dims\"]))\n        expected_shape = (1, input_param[\"out_channels\"], *([in_size] * input_param[\"spatial_dims\"]))\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            LocalNetFeatureExtractorBlock(spatial_dims=2, in_channels=2, out_channels=2, initializer=\"none\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_mlp.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.mlp import MLPBlock\nfrom monai.networks.layers.factories import split_args\n\nTEST_CASE_MLP = []\nfor dropout_rate in np.linspace(0, 1, 4):\n    for hidden_size in [128, 256, 512, 768]:\n        for mlp_dim in [0, 1028, 2048, 3072]:\n            test_case = [\n                {\"hidden_size\": hidden_size, \"mlp_dim\": mlp_dim, \"dropout_rate\": dropout_rate},\n                (2, 512, hidden_size),\n                (2, 512, hidden_size),\n            ]\n            TEST_CASE_MLP.append(test_case)\n\n# test different activation layers\nTEST_CASE_ACT = []\nfor act in [\"GELU\", \"GEGLU\", (\"GEGLU\", {})]:  # type: ignore\n    TEST_CASE_ACT.append([{\"hidden_size\": 128, \"mlp_dim\": 0, \"act\": act}, (2, 512, 128), (2, 512, 128)])\n\n# test different dropout modes\nTEST_CASE_DROP = [[\"vit\", nn.Dropout], [\"swin\", nn.Dropout], [\"vista3d\", nn.Identity]]\n\n\nclass TestMLPBlock(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASE_MLP)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = MLPBlock(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            
self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=5.0)\n\n    @parameterized.expand(TEST_CASE_ACT)\n    def test_act(self, input_param, input_shape, expected_shape):\n        net = MLPBlock(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n        act_name, _ = split_args(input_param[\"act\"])\n        if act_name == \"GEGLU\":\n            self.assertEqual(net.linear1.in_features, net.linear1.out_features // 2)\n        else:\n            self.assertEqual(net.linear1.in_features, net.linear1.out_features)\n\n    @parameterized.expand(TEST_CASE_DROP)\n    def test_dropout_mode(self, dropout_mode, dropout_layer):\n        net = MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=0.1, dropout_mode=dropout_mode)\n        self.assertTrue(isinstance(net.drop1, dropout_layer))\n        self.assertTrue(isinstance(net.drop2, dropout_layer))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_patchembedding.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nimport torch.nn as nn\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.patchembedding import PatchEmbed, PatchEmbeddingBlock\nfrom monai.utils import optional_import\nfrom tests.test_utils import dict_product\n\neinops, has_einops = optional_import(\"einops\")\n\n\nTEST_CASE_PATCHEMBEDDINGBLOCK = [\n    [\n        params,\n        (2, params[\"in_channels\"], *([params[\"img_size\"]] * params[\"spatial_dims\"])),\n        (2, (params[\"img_size\"] // params[\"patch_size\"]) ** params[\"spatial_dims\"], params[\"hidden_size\"]),\n    ]\n    for params in dict_product(\n        dropout_rate=[0.5],\n        in_channels=[1, 4],\n        hidden_size=[96, 288],\n        img_size=[32, 64],\n        patch_size=[8, 16],\n        num_heads=[8, 12],\n        proj_type=[\"conv\", \"perceptron\"],\n        pos_embed_type=[\"none\", \"learnable\", \"sincos\"],\n        spatial_dims=[2, 3],\n    )\n]\n\nimg_size = 96\nTEST_CASE_PATCHEMBED = [\n    [\n        params,\n        (2, params[\"in_chans\"], *([img_size] * params[\"spatial_dims\"])),\n        (2, params[\"embed_dim\"], *([img_size // params[\"patch_size\"]]) * params[\"spatial_dims\"]),\n    ]\n    for params in dict_product(\n        patch_size=[2], in_chans=[1, 4], 
embed_dim=[6, 12], norm_layer=[nn.LayerNorm], spatial_dims=[2, 3]\n    )\n]\n\n\nclass TestPatchEmbeddingBlock(unittest.TestCase):\n    def setUp(self):\n        self.threads = torch.get_num_threads()\n        torch.set_num_threads(4)\n\n    def tearDown(self):\n        torch.set_num_threads(self.threads)\n\n    @parameterized.expand(TEST_CASE_PATCHEMBEDDINGBLOCK)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = PatchEmbeddingBlock(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_sincos_pos_embed(self):\n        net = PatchEmbeddingBlock(\n            in_channels=1,\n            img_size=(32, 32, 32),\n            patch_size=(8, 8, 8),\n            hidden_size=96,\n            num_heads=8,\n            pos_embed_type=\"sincos\",\n            dropout_rate=0.5,\n        )\n\n        self.assertEqual(net.position_embeddings.requires_grad, False)\n\n    def test_fourier_pos_embed(self):\n        net = PatchEmbeddingBlock(\n            in_channels=1,\n            img_size=(32, 32, 32),\n            patch_size=(8, 8, 8),\n            hidden_size=96,\n            num_heads=8,\n            pos_embed_type=\"fourier\",\n            dropout_rate=0.5,\n        )\n\n        self.assertEqual(net.position_embeddings.requires_grad, False)\n\n    def test_learnable_pos_embed(self):\n        net = PatchEmbeddingBlock(\n            in_channels=1,\n            img_size=(32, 32, 32),\n            patch_size=(8, 8, 8),\n            hidden_size=96,\n            num_heads=8,\n            pos_embed_type=\"learnable\",\n            dropout_rate=0.5,\n        )\n\n        self.assertEqual(net.position_embeddings.requires_grad, True)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            PatchEmbeddingBlock(\n                in_channels=1,\n              
  img_size=(128, 128, 128),\n                patch_size=(16, 16, 16),\n                hidden_size=128,\n                num_heads=12,\n                proj_type=\"conv\",\n                dropout_rate=0.1,\n                pos_embed_type=\"fourier\",\n                pos_embed_kwargs=dict(scales=[1.0, 1.0]),\n            )\n\n        with self.assertRaises(ValueError):\n            PatchEmbeddingBlock(\n                in_channels=1,\n                img_size=(128, 128),\n                patch_size=(16, 16),\n                hidden_size=128,\n                num_heads=12,\n                proj_type=\"conv\",\n                dropout_rate=0.1,\n                pos_embed_type=\"fourier\",\n                pos_embed_kwargs=dict(scales=[1.0, 1.0, 1.0]),\n            )\n\n        with self.assertRaises(ValueError):\n            PatchEmbeddingBlock(\n                in_channels=1,\n                img_size=(128, 128, 128),\n                patch_size=(16, 16, 16),\n                hidden_size=128,\n                num_heads=12,\n                proj_type=\"conv\",\n                pos_embed_type=\"sincos\",\n                dropout_rate=5.0,\n            )\n\n        with self.assertRaises(ValueError):\n            PatchEmbeddingBlock(\n                in_channels=1,\n                img_size=(32, 32, 32),\n                patch_size=(64, 64, 64),\n                hidden_size=512,\n                num_heads=8,\n                proj_type=\"perceptron\",\n                pos_embed_type=\"sincos\",\n                dropout_rate=0.3,\n            )\n\n        with self.assertRaises(ValueError):\n            PatchEmbeddingBlock(\n                in_channels=1,\n                img_size=(96, 96, 96),\n                patch_size=(8, 8, 8),\n                hidden_size=512,\n                num_heads=14,\n                proj_type=\"conv\",\n                dropout_rate=0.3,\n            )\n\n        with self.assertRaises(ValueError):\n            PatchEmbeddingBlock(\n        
        in_channels=1,\n                img_size=(97, 97, 97),\n                patch_size=(4, 4, 4),\n                hidden_size=768,\n                num_heads=8,\n                proj_type=\"perceptron\",\n                dropout_rate=0.3,\n            )\n        with self.assertRaises(ValueError):\n            PatchEmbeddingBlock(\n                in_channels=1,\n                img_size=(97, 97, 97),\n                patch_size=(4, 4, 4),\n                hidden_size=768,\n                num_heads=8,\n                proj_type=\"perceptron\",\n                dropout_rate=0.3,\n            )\n\n        with self.assertRaises(ValueError):\n            PatchEmbeddingBlock(\n                in_channels=4,\n                img_size=(96, 96, 96),\n                patch_size=(16, 16, 16),\n                hidden_size=768,\n                num_heads=12,\n                proj_type=\"perc\",\n                dropout_rate=0.3,\n            )\n\n\nclass TestPatchEmbed(unittest.TestCase):\n    def setUp(self):\n        self.threads = torch.get_num_threads()\n        torch.set_num_threads(4)\n\n    def tearDown(self):\n        torch.set_num_threads(self.threads)\n\n    @parameterized.expand(TEST_CASE_PATCHEMBED)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = PatchEmbed(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            PatchEmbed(patch_size=(2, 2, 2), in_chans=1, embed_dim=24, norm_layer=nn.LayerNorm, spatial_dims=5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_regunet_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.regunet_block import (\n    RegistrationDownSampleBlock,\n    RegistrationExtractionBlock,\n    RegistrationResidualConvBlock,\n)\n\nTEST_CASE_RESIDUAL = [\n    [{\"spatial_dims\": 2, \"in_channels\": 1, \"out_channels\": 2, \"num_layers\": 1}, (1, 1, 5, 5), (1, 2, 5, 5)],\n    [{\"spatial_dims\": 3, \"in_channels\": 2, \"out_channels\": 2, \"num_layers\": 2}, (1, 2, 5, 5, 5), (1, 2, 5, 5, 5)],\n]\n\nTEST_CASE_DOWN_SAMPLE = [\n    [{\"spatial_dims\": 2, \"channels\": 1, \"pooling\": False}, (1, 1, 4, 4), (1, 1, 2, 2)],\n    [{\"spatial_dims\": 3, \"channels\": 2, \"pooling\": True}, (1, 2, 4, 4, 4), (1, 2, 2, 2, 2)],\n]\n\nTEST_CASE_EXTRACTION = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"extract_levels\": (0,),\n            \"num_channels\": [1],\n            \"out_channels\": 1,\n            \"kernel_initializer\": \"kaiming_uniform\",\n            \"activation\": None,\n        },\n        [(1, 1, 2, 2)],\n        (3, 3),\n        (1, 1, 3, 3),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"extract_levels\": (1, 2),\n            \"num_channels\": [1, 2, 3],\n            \"out_channels\": 1,\n            \"kernel_initializer\": \"zeros\",\n   
         \"activation\": \"sigmoid\",\n            \"mode\": \"trilinear\",\n        },\n        [(1, 3, 2, 2, 2), (1, 2, 4, 4, 4), (1, 1, 8, 8, 8)],\n        (3, 3, 3),\n        (1, 1, 3, 3, 3),\n    ],\n]\n\n\nclass TestRegistrationResidualConvBlock(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASE_RESIDUAL)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = RegistrationResidualConvBlock(**input_param)\n        with eval_mode(net):\n            x = net(torch.randn(input_shape))\n            self.assertEqual(x.shape, expected_shape)\n\n\nclass TestRegistrationDownSampleBlock(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASE_DOWN_SAMPLE)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = RegistrationDownSampleBlock(**input_param)\n        with eval_mode(net):\n            x = net(torch.rand(input_shape))\n            self.assertEqual(x.shape, expected_shape)\n\n    def test_ill_shape(self):\n        net = RegistrationDownSampleBlock(spatial_dims=2, channels=2, pooling=True)\n        with self.assertRaises(ValueError):\n            net(torch.rand((1, 2, 3, 3)))\n\n\nclass TestRegistrationExtractionBlock(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASE_EXTRACTION)\n    def test_shape(self, input_param, input_shapes, image_size, expected_shape):\n        net = RegistrationExtractionBlock(**input_param)\n        with eval_mode(net):\n            x = net([torch.rand(input_shape) for input_shape in input_shapes], image_size)\n            self.assertEqual(x.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_se_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks import SEBlock\nfrom monai.networks.layers.factories import Act, Norm\nfrom tests.test_utils import test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASES = [\n    [\n        {\"spatial_dims\": 2, \"in_channels\": 4, \"n_chns_1\": 20, \"n_chns_2\": 30, \"n_chns_3\": 4, \"r\": 2},\n        (7, 4, 64, 48),  # 4-channel 2D, batch 7\n        (7, 4, 64, 48),\n    ],\n    [\n        {\"spatial_dims\": 1, \"in_channels\": 3, \"n_chns_1\": 20, \"n_chns_2\": 30, \"n_chns_3\": 40, \"r\": 5},\n        (16, 3, 63),  # 3-channel 1D, batch 16\n        (16, 40, 63),\n    ],\n]\n\nTEST_CASES_3D = []\nfor type_1 in (\n    {\"kernel_size\": 3, \"act\": Act.PRELU, \"norm\": Norm.INSTANCE},\n    {\"kernel_size\": 1, \"act\": None, \"norm\": Norm.INSTANCE},\n):\n    for type_2 in (\n        {\"kernel_size\": 3, \"act\": Act.PRELU, \"norm\": Norm.INSTANCE},\n        {\"kernel_size\": 1, \"act\": None, \"norm\": Norm.INSTANCE},\n    ):\n        test_case = [\n            {\n                \"spatial_dims\": 3,\n                \"in_channels\": 10,\n                \"r\": 3,\n                \"n_chns_1\": 3,\n                \"n_chns_2\": 5,\n                \"n_chns_3\": 11,\n     
           \"conv_param_1\": type_1,\n                \"conv_param_3\": type_2,\n            },\n            (16, 10, 32, 24, 48),  # 10-channel 3D, batch 16\n            (16, 11, 32, 24, 48),\n        ]\n        TEST_CASES_3D.append(test_case)\n\n\nclass TestSEBlockLayer(unittest.TestCase):\n    @parameterized.expand(TEST_CASES + TEST_CASES_3D)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = SEBlock(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        input_param, input_shape, _ = TEST_CASES[0]\n        net = SEBlock(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            SEBlock(spatial_dims=1, in_channels=4, n_chns_1=2, n_chns_2=3, n_chns_3=4, r=100)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_se_blocks.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks import ChannelSELayer, ResidualSELayer\nfrom tests.test_utils import test_script_save\n\nTEST_CASES = [  # single channel 3D, batch 16\n    [{\"spatial_dims\": 2, \"in_channels\": 4, \"r\": 3}, (7, 4, 64, 48), (7, 4, 64, 48)],  # 4-channel 2D, batch 7\n    [  # 4-channel 1D, batch 16\n        {\"spatial_dims\": 1, \"in_channels\": 4, \"r\": 3, \"acti_type_1\": \"relu\"},\n        (16, 4, 63),\n        (16, 4, 63),\n    ],\n]\n\nTEST_CASES_3D = []\nfor type_1 in {\"relu\", \"relu6\", \"leakyrelu\"}:\n    for type_2 in {\"prelu\", \"sigmoid\", \"relu\"}:\n        test_case = [\n            {\"spatial_dims\": 3, \"in_channels\": 10, \"r\": 3, \"acti_type_1\": type_1, \"acti_type_2\": type_2},\n            (16, 10, 32, 24, 48),\n            (16, 10, 32, 24, 48),\n        ]\n        TEST_CASES_3D.append(test_case)\n\n\nclass TestChannelSELayer(unittest.TestCase):\n    @parameterized.expand(TEST_CASES + TEST_CASES_3D)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = ChannelSELayer(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n   
     input_param, input_shape, _ = TEST_CASES[0]\n        net = ChannelSELayer(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            ChannelSELayer(spatial_dims=1, in_channels=4, r=100)\n\n\nclass TestResidualSELayer(unittest.TestCase):\n    @parameterized.expand(TEST_CASES[:1])\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = ResidualSELayer(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        input_param, input_shape, _ = TEST_CASES[0]\n        net = ResidualSELayer(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_segresnet_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.segresnet_block import ResBlock\nfrom tests.test_utils import dict_product\n\nTEST_CASE_RESBLOCK = [\n    [\n        params,\n        (2, params[\"in_channels\"], *([16] * params[\"spatial_dims\"])),\n        (2, params[\"in_channels\"], *([16] * params[\"spatial_dims\"])),\n    ]\n    for params in dict_product(\n        spatial_dims=range(2, 4),\n        in_channels=range(1, 4),\n        kernel_size=[1, 3],\n        norm=[(\"group\", {\"num_groups\": 1}), \"batch\", \"instance\"],\n    )\n]\n\n\nclass TestResBlock(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_RESBLOCK)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = ResBlock(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(AssertionError):\n            ResBlock(spatial_dims=3, in_channels=8, norm=\"group\", kernel_size=2)\n        with self.assertRaises(ValueError):\n            ResBlock(spatial_dims=3, in_channels=8, norm=\"norm\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_selfattention.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.selfattention import SABlock\nfrom monai.networks.layers.factories import RelPosEmbedding\nfrom monai.utils import optional_import\nfrom tests.test_utils import assert_allclose, test_script_save\n\neinops, has_einops = optional_import(\"einops\")\n\nTEST_CASE_SABLOCK = []\nfor dropout_rate in np.linspace(0, 1, 4):\n    for hidden_size in [360, 480, 600, 768]:\n        for num_heads in [4, 6, 8, 12]:\n            for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:\n                for input_size in [(16, 32), (8, 8, 8)]:\n                    for include_fc in [True, False]:\n                        for use_combined_linear in [True, False]:\n                            test_case = [\n                                {\n                                    \"hidden_size\": hidden_size,\n                                    \"num_heads\": num_heads,\n                                    \"dropout_rate\": dropout_rate,\n                                    \"rel_pos_embedding\": rel_pos_embedding,\n                                    \"input_size\": input_size,\n                                    \"include_fc\": include_fc,\n                    
                \"use_combined_linear\": use_combined_linear,\n                                    \"use_flash_attention\": True if rel_pos_embedding is None else False,\n                                },\n                                (2, 512, hidden_size),\n                                (2, 512, hidden_size),\n                            ]\n                            TEST_CASE_SABLOCK.append(test_case)\n\n\nclass TestResBlock(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_SABLOCK)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = SABlock(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            SABlock(hidden_size=128, num_heads=12, dropout_rate=6.0)\n\n        with self.assertRaises(ValueError):\n            SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)\n\n    def test_rel_pos_embedding_with_flash_attention(self):\n        with self.assertRaises(ValueError):\n            SABlock(\n                hidden_size=128,\n                num_heads=3,\n                dropout_rate=0.1,\n                use_flash_attention=True,\n                save_attn=False,\n                rel_pos_embedding=RelPosEmbedding.DECOMPOSED,\n            )\n\n    def test_save_attn_with_flash_attention(self):\n        with self.assertRaises(ValueError):\n            SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True)\n\n    def test_attention_dim_not_multiple_of_heads(self):\n        with self.assertRaises(ValueError):\n            SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_inner_dim_different(self):\n        SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, 
dim_head=30)\n\n    def test_causal_no_sequence_length(self):\n        with self.assertRaises(ValueError):\n            SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_causal_flash_attention(self):\n        block = SABlock(\n            hidden_size=128,\n            num_heads=1,\n            dropout_rate=0.1,\n            causal=True,\n            sequence_length=16,\n            save_attn=False,\n            use_flash_attention=True,\n        )\n        input_shape = (1, 16, 128)\n        # Check it runs correctly\n        block(torch.randn(input_shape))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_causal(self):\n        block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True)\n        input_shape = (1, 16, 128)\n        block(torch.randn(input_shape))\n        # check upper triangular part of the attention matrix is zero\n        assert torch.triu(block.att_mat, diagonal=1).sum() == 0\n\n    def test_masked_selfattention(self):\n        n = 64\n        block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True)\n        input_shape = (1, n, 128)\n        # generate a mask randomly with zeros and ones of shape (1, n)\n        mask = torch.randint(0, 2, (1, n)).bool()\n        block(torch.randn(input_shape), attn_mask=mask)\n        att_mat = block.att_mat.squeeze()\n        # ensure all masked columns are zeros\n        assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)]))\n\n    def test_causal_and_mask(self):\n        with self.assertRaises(ValueError):\n            block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)\n            inputs = torch.randn(2, 64, 128)\n            mask = torch.randint(0, 2, (2, 64)).bool()\n            block(inputs, attn_mask=mask)\n\n    @skipUnless(has_einops, 
\"Requires einops\")\n    def test_access_attn_matrix(self):\n        # input format\n        hidden_size = 128\n        num_heads = 2\n        dropout_rate = 0\n        input_shape = (2, 256, hidden_size)\n\n        # be  not able to access the matrix\n        no_matrix_acess_blk = SABlock(hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate)\n        no_matrix_acess_blk(torch.randn(input_shape))\n        assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor)\n        # no of elements is zero\n        assert no_matrix_acess_blk.att_mat.nelement() == 0\n\n        # be able to acess the attention matrix\n        matrix_acess_blk = SABlock(\n            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True\n        )\n        matrix_acess_blk(torch.randn(input_shape))\n        assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1])\n\n    def test_number_of_parameters(self):\n        def count_sablock_params(*args, **kwargs):\n            \"\"\"Count the number of parameters in a SABlock.\"\"\"\n            sablock = SABlock(*args, **kwargs)\n            return sum([x.numel() for x in sablock.parameters() if x.requires_grad])\n\n        hidden_size = 128\n        num_heads = 8\n        default_dim_head = hidden_size // num_heads\n\n        # Default dim_head is hidden_size // num_heads\n        nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads)\n        nparams_like_default = count_sablock_params(\n            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head\n        )\n        self.assertEqual(nparams_default, nparams_like_default)\n\n        # Increasing dim_head should increase the number of parameters\n        nparams_custom_large = count_sablock_params(\n            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2\n        )\n        
self.assertGreater(nparams_custom_large, nparams_default)\n\n        # Decreasing dim_head should decrease the number of parameters\n        nparams_custom_small = count_sablock_params(\n            hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2\n        )\n        self.assertGreater(nparams_default, nparams_custom_small)\n\n        # Increasing the number of heads with the default behaviour should not change the number of params.\n        nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2)\n        self.assertEqual(nparams_default, nparams_default_more_heads)\n\n    @parameterized.expand([[True, False], [True, True], [False, True], [False, False]])\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_script(self, include_fc, use_combined_linear):\n        input_param = {\n            \"hidden_size\": 360,\n            \"num_heads\": 4,\n            \"dropout_rate\": 0.0,\n            \"rel_pos_embedding\": None,\n            \"input_size\": (16, 32),\n            \"include_fc\": include_fc,\n            \"use_combined_linear\": use_combined_linear,\n        }\n        net = SABlock(**input_param)\n        input_shape = (2, 512, 360)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_flash_attention(self):\n        for causal in [True, False]:\n            input_param = {\"hidden_size\": 360, \"num_heads\": 4, \"input_size\": (16, 32), \"causal\": causal}\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        block_w_flash_attention = SABlock(**input_param, use_flash_attention=True).to(device)\n        block_wo_flash_attention = SABlock(**input_param, use_flash_attention=False).to(device)\n        block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict())\n        test_data = torch.randn(2, 512, 360).to(device)\n\n        out_1 = 
block_w_flash_attention(test_data)\n        out_2 = block_wo_flash_attention(test_data)\n        assert_allclose(out_1, out_2, atol=1e-4)\n\n    @parameterized.expand([[True], [False]])\n    def test_no_extra_weights_if_no_fc(self, include_fc):\n        input_param = {\n            \"hidden_size\": 360,\n            \"num_heads\": 4,\n            \"dropout_rate\": 0.0,\n            \"rel_pos_embedding\": None,\n            \"input_size\": (16, 32),\n            \"include_fc\": include_fc,\n            \"use_combined_linear\": use_combined_linear,\n        }\n        net = SABlock(**input_param)\n        if not include_fc:\n            self.assertNotIn(\"out_proj.weight\", net.state_dict())\n            self.assertNotIn(\"out_proj.bias\", net.state_dict())\n            self.assertIsInstance(net.out_proj, torch.nn.Identity)\n        else:\n            self.assertIn(\"out_proj.weight\", net.state_dict())\n            self.assertIn(\"out_proj.bias\", net.state_dict())\n            self.assertIsInstance(net.out_proj, torch.nn.Linear)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_simple_aspp.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks import SimpleASPP\n\nTEST_CASES = [\n    [  # 32-channel 2D, batch 7\n        {\"spatial_dims\": 2, \"in_channels\": 32, \"conv_out_channels\": 3, \"norm_type\": (\"batch\", {\"affine\": False})},\n        (7, 32, 18, 20),\n        (7, 12, 18, 20),\n    ],\n    [  # 4-channel 1D, batch 16\n        {\"spatial_dims\": 1, \"in_channels\": 4, \"conv_out_channels\": 8, \"acti_type\": (\"PRELU\", {\"num_parameters\": 32})},\n        (16, 4, 17),\n        (16, 32, 17),\n    ],\n    [  # 3-channel 3D, batch 16\n        {\"spatial_dims\": 3, \"in_channels\": 3, \"conv_out_channels\": 2},\n        (16, 3, 17, 18, 19),\n        (16, 8, 17, 18, 19),\n    ],\n    [  # 3-channel 3D, batch 16\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 3,\n            \"conv_out_channels\": 2,\n            \"kernel_sizes\": (1, 3, 3),\n            \"dilations\": (1, 2, 4),\n        },\n        (16, 3, 17, 18, 19),\n        (16, 6, 17, 18, 19),\n    ],\n]\n\nTEST_ILL_CASES = [\n    [  # 3-channel 3D, batch 16, wrong k and d sizes.\n        {\"spatial_dims\": 3, \"in_channels\": 3, \"conv_out_channels\": 2, \"kernel_sizes\": (1, 3, 3), \"dilations\": (1, 2)},\n        (16, 3, 17, 18, 19),\n        
ValueError,\n    ],\n    [  # 3-channel 3D, batch 16, wrong k and d sizes.\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 3,\n            \"conv_out_channels\": 2,\n            \"kernel_sizes\": (1, 3, 4),\n            \"dilations\": (1, 2, 3),\n        },\n        (16, 3, 17, 18, 19),\n        NotImplementedError,  # unknown padding k=4, d=3\n    ],\n]\n\n\nclass TestChannelSELayer(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = SimpleASPP(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(TEST_ILL_CASES)\n    def test_ill_args(self, input_param, input_shape, error_type):\n        with self.assertRaises(error_type):\n            SimpleASPP(**input_param)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_spatialattention.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.spatialattention import SpatialAttentionBlock\nfrom monai.utils import optional_import\n\neinops, has_einops = optional_import(\"einops\")\n\nTEST_CASES = [\n    [\n        {\"spatial_dims\": 2, \"num_channels\": 128, \"num_head_channels\": 32, \"norm_num_groups\": 32, \"norm_eps\": 1e-6},\n        (1, 128, 32, 32),\n        (1, 128, 32, 32),\n    ],\n    [\n        {\"spatial_dims\": 3, \"num_channels\": 16, \"num_head_channels\": 8, \"norm_num_groups\": 8, \"norm_eps\": 1e-6},\n        (1, 16, 8, 8, 8),\n        (1, 16, 8, 8, 8),\n    ],\n]\n\n\nclass TestBlock(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = SpatialAttentionBlock(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_attention_dim_not_multiple_of_heads(self):\n        with self.assertRaises(ValueError):\n            SpatialAttentionBlock(spatial_dims=2, num_channels=128, num_head_channels=33)\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_subpixel_upsample.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nimport torch.nn as nn\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks import SubpixelUpsample\nfrom monai.networks.layers.factories import Conv\nfrom tests.test_utils import test_script_save\n\nTEST_CASE_SUBPIXEL = []\nfor inch in range(1, 5):\n    for dim in range(1, 4):\n        for factor in range(1, 3):\n            test_case = [\n                {\"spatial_dims\": dim, \"in_channels\": inch, \"scale_factor\": factor},\n                (2, inch, *([8] * dim)),\n                (2, inch, *([8 * factor] * dim)),\n            ]\n            TEST_CASE_SUBPIXEL.append(test_case)\n\nTEST_CASE_SUBPIXEL_2D_EXTRA = [\n    {\"spatial_dims\": 2, \"in_channels\": 2, \"scale_factor\": 3},\n    (2, 2, 8, 4),  # different size for H and W\n    (2, 2, 24, 12),\n]\n\nTEST_CASE_SUBPIXEL_3D_EXTRA = [\n    {\"spatial_dims\": 3, \"in_channels\": 1, \"scale_factor\": 2},\n    (2, 1, 16, 8, 4),  # different size for H, W and D\n    (2, 1, 32, 16, 8),\n]\n\nconv_block = nn.Sequential(\n    Conv[Conv.CONV, 3](1, 4, kernel_size=1), Conv[Conv.CONV, 3](4, 8, kernel_size=3, stride=1, padding=1)\n)\n\nTEST_CASE_SUBPIXEL_CONV_BLOCK_EXTRA = [\n    {\"spatial_dims\": 3, \"in_channels\": 1, \"scale_factor\": 2, \"conv_block\": conv_block},\n    (2, 1, 16, 8, 4),  # 
different size for H, W and D\n    (2, 1, 32, 16, 8),\n]\n\nTEST_CASE_SUBPIXEL.append(TEST_CASE_SUBPIXEL_2D_EXTRA)  # type: ignore\nTEST_CASE_SUBPIXEL.append(TEST_CASE_SUBPIXEL_3D_EXTRA)  # type: ignore\nTEST_CASE_SUBPIXEL.append(TEST_CASE_SUBPIXEL_CONV_BLOCK_EXTRA)  # type: ignore\n\n# add every test back with the pad/pool sequential component omitted\nfor tests in list(TEST_CASE_SUBPIXEL):\n    args: dict = tests[0]  # type: ignore\n    args = dict(args)\n    args[\"apply_pad_pool\"] = False\n    TEST_CASE_SUBPIXEL.append([args, tests[1], tests[2]])\n\n\nclass TestSUBPIXEL(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_SUBPIXEL)\n    def test_subpixel_shape(self, input_param, input_shape, expected_shape):\n        net = SubpixelUpsample(**input_param)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        input_param, input_shape, _ = TEST_CASE_SUBPIXEL[0]\n        net = SubpixelUpsample(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_text_encoding.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.networks.blocks.text_embedding import TextEncoder\nfrom tests.test_utils import skip_if_downloading_fails\n\n\nclass TestTextEncoder(unittest.TestCase):\n    def test_test_encoding_shape(self):\n        with skip_if_downloading_fails():\n            # test 2D encoder\n            text_encoder = TextEncoder(\n                spatial_dims=2, out_channels=32, encoding=\"clip_encoding_universal_model_32\", pretrained=True\n            )\n            text_encoding = text_encoder()\n            self.assertEqual(text_encoding.shape, (32, 256, 1, 1))\n\n            # test 3D encoder\n            text_encoder = TextEncoder(\n                spatial_dims=3, out_channels=32, encoding=\"clip_encoding_universal_model_32\", pretrained=True\n            )\n            text_encoding = text_encoder()\n            self.assertEqual(text_encoding.shape, (32, 256, 1, 1, 1))\n\n        # test random enbedding 3D\n        text_encoder = TextEncoder(spatial_dims=3, out_channels=32, encoding=\"rand_embedding\", pretrained=True)\n        text_encoding = text_encoder()\n        self.assertEqual(text_encoding.shape, (32, 256, 1, 1, 1))\n\n        # test random enbedding 2D\n        text_encoder = TextEncoder(spatial_dims=2, out_channels=32, encoding=\"rand_embedding\", pretrained=True)\n        text_encoding = text_encoder()\n       
 self.assertEqual(text_encoding.shape, (32, 256, 1, 1))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_transformerblock.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.transformerblock import TransformerBlock\nfrom monai.utils import optional_import\nfrom tests.test_utils import dict_product\n\neinops, has_einops = optional_import(\"einops\")\nTEST_CASE_TRANSFORMERBLOCK = [\n    [params, (2, 512, params[\"hidden_size\"]), (2, 512, params[\"hidden_size\"])]\n    for params in dict_product(\n        dropout_rate=np.linspace(0, 1, 4),\n        hidden_size=[360, 480, 600, 768],\n        num_heads=[4, 8, 12],\n        mlp_dim=[1024, 3072],\n        with_cross_attention=[False, True],\n    )\n]\n\n\nclass TestTransformerBlock(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = TransformerBlock(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            TransformerBlock(hidden_size=128, num_heads=12, mlp_dim=2048, dropout_rate=4.0)\n\n        with 
self.assertRaises(ValueError):\n            TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_access_attn_matrix(self):\n        # input format\n        hidden_size = 128\n        mlp_dim = 12\n        num_heads = 2\n        dropout_rate = 0\n        input_shape = (2, 256, hidden_size)\n\n        # returns an empty attention matrix\n        no_matrix_acess_blk = TransformerBlock(\n            hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate\n        )\n        no_matrix_acess_blk(torch.randn(input_shape))\n        assert isinstance(no_matrix_acess_blk.attn.att_mat, torch.Tensor)\n        # no of elements is zero\n        assert no_matrix_acess_blk.attn.att_mat.nelement() == 0\n\n        # be able to acess the attention matrix\n        matrix_acess_blk = TransformerBlock(\n            hidden_size=hidden_size, mlp_dim=mlp_dim, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True\n        )\n        matrix_acess_blk(torch.randn(input_shape))\n        assert matrix_acess_blk.attn.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_unetr_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.dynunet_block import get_padding\nfrom monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock\nfrom tests.test_utils import dict_product, test_script_save\n\n\ndef _get_out_size(params):\n    in_size = params[\"in_size\"]\n    kernel_size = params[\"kernel_size\"]\n    stride = params[\"stride\"]\n    padding = get_padding(kernel_size, stride)\n    if not isinstance(padding, int):\n        padding = padding[0]\n    return int((in_size + 2 * padding - kernel_size) / stride) + 1\n\n\nnorm_names = [(\"GROUP\", {\"num_groups\": 16}), (\"batch\", {\"track_running_stats\": False}), \"instance\"]\nparam_dicts = dict_product(\n    spatial_dims=range(1, 4), kernel_size=[1, 3], stride=[2], norm_name=norm_names, in_size=[15, 16]\n)\nTEST_CASE_UNETR_BASIC_BLOCK = []\nfor params in param_dicts:\n    input_param = {**{k: v for k, v in params.items() if k != \"in_size\"}, \"in_channels\": 16, \"out_channels\": 16}\n    input_shape = (1, 16, *([params[\"in_size\"]] * params[\"spatial_dims\"]))\n    expected_shape = (1, 16, *([_get_out_size(params)] * params[\"spatial_dims\"]))\n    TEST_CASE_UNETR_BASIC_BLOCK.append([input_param, input_shape, expected_shape])\n\n\nTEST_UP_BLOCK = 
[\n    [\n        {\n            **{k: v for k, v in params.items() if k not in [\"in_size\", \"stride\", \"upsample_kernel_size\"]},\n            \"upsample_kernel_size\": params[\"stride\"],\n        },\n        (1, params[\"in_channels\"], *([params[\"in_size\"]] * params[\"spatial_dims\"])),\n        (1, params[\"out_channels\"], *([params[\"in_size\"] * params[\"stride\"]] * params[\"spatial_dims\"])),\n        (1, params[\"out_channels\"], *([params[\"in_size\"] * params[\"stride\"]] * params[\"spatial_dims\"])),\n    ]\n    for params in dict_product(\n        spatial_dims=range(1, 4),\n        in_channels=[4],\n        out_channels=[2],\n        kernel_size=[1, 3],\n        norm_name=[\"instance\"],\n        res_block=[False, True],\n        upsample_kernel_size=[2, 3],\n        stride=[1, 2],\n        in_size=[15, 16],\n    )\n]\n\nTEST_PRUP_BLOCK = []\nin_channels, out_channels = 4, 2\nfor params in dict_product(\n    spatial_dims=range(1, 4),\n    kernel_size=[1, 3],\n    upsample_kernel_size=[2, 3],\n    stride=[1, 2],\n    res_block=[False, True],\n    norm_name=[\"instance\"],\n    in_size_scalar=[15, 16],\n    num_layer=[0, 2],\n):\n    in_size_tmp = params[\"in_size_scalar\"]\n    out_size = 0  # Initialize out_size\n    for _ in range(params[\"num_layer\"] + 1):\n        out_size = in_size_tmp * params[\"upsample_kernel_size\"]\n        in_size_tmp = out_size\n\n    test_case = [\n        {\n            **{k: v for k, v in params.items() if k != \"in_size_scalar\"},\n            \"in_channels\": in_channels,\n            \"out_channels\": out_channels,\n        },\n        (1, in_channels, *([params[\"in_size_scalar\"]] * params[\"spatial_dims\"])),\n        (1, out_channels, *([out_size] * params[\"spatial_dims\"])),\n    ]\n    TEST_PRUP_BLOCK.append(test_case)\n\n\nclass TestResBasicBlock(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_UNETR_BASIC_BLOCK)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        for 
net in [UnetrBasicBlock(**input_param)]:\n            with eval_mode(net):\n                result = net(torch.randn(input_shape))\n                self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            UnetrBasicBlock(3, 4, 2, kernel_size=3, stride=1, norm_name=\"norm\")\n        with self.assertRaises(AssertionError):\n            UnetrBasicBlock(3, 4, 2, kernel_size=1, stride=4, norm_name=\"batch\")\n\n    def test_script(self):\n        input_param, input_shape, _ = TEST_CASE_UNETR_BASIC_BLOCK[0]\n        net = UnetrBasicBlock(**input_param)\n        with eval_mode(net):\n            test_data = torch.randn(input_shape)\n            test_script_save(net, test_data)\n\n\nclass TestUpBlock(unittest.TestCase):\n    @parameterized.expand(TEST_UP_BLOCK)\n    def test_shape(self, input_param, input_shape, expected_shape, skip_shape):\n        net = UnetrUpBlock(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape), torch.randn(skip_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        input_param, input_shape, _, skip_shape = TEST_UP_BLOCK[0]\n        net = UnetrUpBlock(**input_param)\n        test_data = torch.randn(input_shape)\n        skip_data = torch.randn(skip_shape)\n        test_script_save(net, test_data, skip_data)\n\n\nclass TestPrUpBlock(unittest.TestCase):\n    @parameterized.expand(TEST_PRUP_BLOCK)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = UnetrPrUpBlock(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        input_param, input_shape, _ = TEST_PRUP_BLOCK[0]\n        net = UnetrPrUpBlock(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\nif 
__name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/test_upsample_block.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks import UpSample\nfrom monai.utils import UpsampleMode\n\nTEST_CASES = [\n    [{\"spatial_dims\": 2, \"in_channels\": 4}, (7, 4, 32, 48), (7, 4, 64, 96)],  # 4-channel 2D, batch 7\n    [{\"spatial_dims\": 1, \"in_channels\": 4, \"out_channels\": 3}, (16, 4, 63), (16, 3, 126)],  # 4-channel 1D, batch 16\n    [\n        {\"spatial_dims\": 1, \"in_channels\": 4, \"out_channels\": 8, \"mode\": \"deconv\", \"align_corners\": False},\n        (16, 4, 20),\n        (16, 8, 40),\n    ],  # 4-channel 1D, batch 16\n    [\n        {\"spatial_dims\": 3, \"in_channels\": 4, \"mode\": \"nontrainable\"},\n        (16, 4, 32, 24, 48),\n        (16, 4, 64, 48, 96),\n    ],  # 4-channel 3D, batch 16\n    [\n        {\"spatial_dims\": 3, \"in_channels\": 4, \"mode\": \"nontrainable\", \"size\": 64},\n        (16, 4, 32, 24, 48),\n        (16, 4, 64, 64, 64),\n    ],  # 4-channel 3D, batch 16\n    [\n        {\"spatial_dims\": 3, \"in_channels\": 4, \"mode\": \"nontrainable\", \"size\": (64, 24, 48)},\n        (16, 4, 32, 24, 48),\n        (16, 4, 64, 24, 48),\n    ],  # 4-channel 3D, batch 16\n    [\n        {\"spatial_dims\": 3, \"in_channels\": 1, \"mode\": \"deconv\", \"scale_factor\": 3, \"align_corners\": 
False},\n        (16, 1, 10, 15, 20),\n        (16, 1, 30, 45, 60),\n    ],  # 1-channel 3D, batch 16\n    [\n        {\"spatial_dims\": 3, \"in_channels\": 1, \"mode\": \"pixelshuffle\", \"scale_factor\": 2, \"align_corners\": False},\n        (16, 1, 10, 15, 20),\n        (16, 1, 20, 30, 40),\n    ],  # 1-channel 3D, batch 16\n    [\n        {\"spatial_dims\": 2, \"in_channels\": 4, \"mode\": \"pixelshuffle\", \"scale_factor\": 2},\n        (16, 4, 10, 15),\n        (16, 4, 20, 30),\n    ],  # 4-channel 2D, batch 16\n    [\n        {\n            \"spatial_dims\": 3,\n            \"mode\": \"pixelshuffle\",\n            \"scale_factor\": 2,\n            \"align_corners\": False,\n            \"pre_conv\": torch.nn.Conv3d(in_channels=1, out_channels=24, kernel_size=3, stride=1, padding=1),\n        },\n        (16, 1, 10, 15, 20),\n        (16, 3, 20, 30, 40),\n    ],  # 1-channel 3D, batch 16, pre_conv\n    [\n        {\"spatial_dims\": 3, \"in_channels\": 8, \"out_channels\": 4, \"mode\": \"deconvgroup\"},\n        (16, 8, 16, 16, 16),\n        (16, 4, 32, 32, 32),\n    ],  # 8-channel 3D, batch 16\n    [\n        {\"spatial_dims\": 2, \"in_channels\": 32, \"out_channels\": 16, \"mode\": \"deconvgroup\", \"scale_factor\": 2},\n        (8, 32, 16, 16),\n        (8, 16, 32, 32),\n    ],  # 32-channel 2D, batch 8\n]\n\nTEST_CASES_EQ = []\nfor s in range(1, 5):\n    expected_shape = (16, 5, 4 * s, 5 * s, 6 * s)\n    for t in UpsampleMode:\n        test_case = [\n            {\n                \"spatial_dims\": 3,\n                \"in_channels\": 3,\n                \"out_channels\": 5,\n                \"mode\": t,\n                \"scale_factor\": s,\n                \"align_corners\": True,\n            },\n            (16, 3, 4, 5, 6),\n            expected_shape,\n        ]\n        TEST_CASES_EQ.append(test_case)\n\nTEST_CASES_EQ2 = []  # type: ignore\nfor s in range(2, 5):\n    for k in range(1, 7):\n        expected_shape = (16, 5, 4 * s, 5 * s, 6 * s)\n    
    for t in UpsampleMode:\n            test_case = [\n                {\n                    \"spatial_dims\": 3,\n                    \"in_channels\": 3,\n                    \"out_channels\": 5,\n                    \"mode\": t,\n                    \"scale_factor\": s,\n                    \"kernel_size\": k,\n                    \"align_corners\": False,\n                },\n                (16, 3, 4, 5, 6),\n                expected_shape,\n            ]\n            TEST_CASES_EQ.append(test_case)\n\n\nclass TestUpsample(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES + TEST_CASES_EQ + TEST_CASES_EQ2)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = UpSample(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape, msg=str(input_param))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/warp/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/networks/blocks/warp/test_dvf2ddf.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom torch import nn\nfrom torch.optim import SGD\n\nfrom monai.networks.blocks.warp import DVF2DDF\nfrom monai.utils import set_determinism\n\nTEST_CASES = [\n    [{\"num_steps\": 1}, {\"dvf\": torch.zeros(1, 2, 2, 2)}, torch.zeros(1, 2, 2, 2)],\n    [\n        {\"num_steps\": 1},\n        {\"dvf\": torch.ones(1, 3, 2, 2, 2)},\n        torch.tensor([[[1.0000, 0.7500], [0.7500, 0.6250]], [[0.7500, 0.6250], [0.6250, 0.5625]]])\n        .reshape(1, 1, 2, 2, 2)\n        .expand(-1, 3, -1, -1, -1),\n    ],\n    [\n        {\"num_steps\": 2},\n        {\"dvf\": torch.ones(1, 3, 2, 2, 2)},\n        torch.tensor([[[0.9175, 0.6618], [0.6618, 0.5306]], [[0.6618, 0.5306], [0.5306, 0.4506]]])\n        .reshape(1, 1, 2, 2, 2)\n        .expand(-1, 3, -1, -1, -1),\n    ],\n]\n\n\nclass TestDVF2DDF(unittest.TestCase):\n\n    def setUp(self):\n        set_determinism(0)\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @parameterized.expand(TEST_CASES)\n    def test_value(self, input_param, input_data, expected_val):\n        layer = DVF2DDF(**input_param)\n        result = layer(**input_data)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4)\n\n    def test_gradient(self):\n        
network = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=1)\n        dvf2ddf = DVF2DDF(num_steps=1)\n        optimizer = SGD(network.parameters(), lr=0.01)\n        x = torch.ones((1, 1, 5, 5))\n        x = network(x)\n        x = dvf2ddf(x)\n        loss = torch.sum(x)\n        loss.backward()\n        optimizer.step()\n        np.testing.assert_allclose(network.weight.grad.cpu().numpy(), np.array([[[[22.471329]]], [[[22.552576]]]]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/blocks/warp/test_warp.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom torch.autograd import gradcheck\n\nfrom monai.config.deviceconfig import USE_COMPILED\nfrom monai.networks.blocks.warp import Warp\nfrom monai.transforms import LoadImaged\nfrom monai.utils import GridSampleMode, GridSamplePadMode\nfrom tests.test_utils import SkipIfNoModule, download_url_or_skip_test, skip_if_quick, testing_data_config\n\nLOW_POWER_TEST_CASES = [  # run with BUILD_MONAI=1 to test csrc/resample, BUILD_MONAI=0 to test native grid_sample\n    [\n        {\"mode\": \"nearest\", \"padding_mode\": \"zeros\"},\n        {\"image\": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), \"ddf\": torch.zeros(1, 2, 2, 2)},\n        torch.arange(4).reshape((1, 1, 2, 2)),\n    ],\n    [\n        {\"mode\": \"bilinear\", \"padding_mode\": \"zeros\"},\n        {\"image\": torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), \"ddf\": torch.ones(1, 2, 2, 2)},\n        torch.tensor([[[[3, 0], [0, 0]]]]),\n    ],\n    [\n        {\"mode\": \"bilinear\", \"padding_mode\": \"border\"},\n        {\n            \"image\": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),\n            \"ddf\": torch.ones(1, 3, 2, 2, 2) * -1,\n        },\n        torch.tensor([[[[[0, 0], [0, 0]], [[0, 0], [0, 
0]]]]]),\n    ],\n    [\n        {\"mode\": \"bilinear\", \"padding_mode\": \"reflection\"},\n        {\n            \"image\": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),\n            \"ddf\": torch.ones(1, 3, 2, 2, 2) * -1,\n        },\n        torch.tensor([[[[[7.0, 6.0], [5.0, 4.0]], [[3.0, 2.0], [1.0, 0.0]]]]]),\n    ],\n]\n\nCPP_TEST_CASES = [  # high order, BUILD_MONAI=1 to test csrc/resample\n    [\n        {\"mode\": 2, \"padding_mode\": \"border\"},\n        {\n            \"image\": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),\n            \"ddf\": torch.ones(1, 3, 2, 2, 2) * -1,\n        },\n        torch.tensor([[[[[0.0000, 0.1250], [0.2500, 0.3750]], [[0.5000, 0.6250], [0.7500, 0.8750]]]]]),\n    ],\n    [\n        {\"mode\": 2, \"padding_mode\": \"reflection\"},\n        {\n            \"image\": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),\n            \"ddf\": torch.ones(1, 3, 2, 2, 2) * -1,\n        },\n        torch.tensor([[[[[5.2500, 4.7500], [4.2500, 3.7500]], [[3.2500, 2.7500], [2.2500, 1.7500]]]]]),\n    ],\n    [\n        {\"mode\": 2, \"padding_mode\": \"zeros\"},\n        {\n            \"image\": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),\n            \"ddf\": torch.ones(1, 3, 2, 2, 2) * -1,\n        },\n        torch.tensor([[[[[0.0000, 0.0020], [0.0039, 0.0410]], [[0.0078, 0.0684], [0.0820, 0.6699]]]]]),\n    ],\n    [\n        {\"mode\": 2, \"padding_mode\": 7},\n        {\n            \"image\": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float),\n            \"ddf\": torch.ones(1, 3, 2, 2, 2) * -1,\n        },\n        torch.tensor([[[[[0.0000, 0.0020], [0.0039, 0.0410]], [[0.0078, 0.0684], [0.0820, 0.6699]]]]]),\n    ],\n    [\n        {\"mode\": 3, \"padding_mode\": \"reflection\"},\n        {\"image\": torch.arange(8).reshape((1, 1, 2, 2, 2)).to(dtype=torch.float), \"ddf\": torch.ones(1, 3, 2, 2, 2)},\n        torch.tensor([[[[[4.6667, 
4.3333], [4.0000, 3.6667]], [[3.3333, 3.0000], [2.6667, 2.3333]]]]]),\n    ],\n]\n\nTEST_CASES = LOW_POWER_TEST_CASES\nif USE_COMPILED:\n    TEST_CASES += CPP_TEST_CASES\n\n\n@skip_if_quick\nclass TestWarp(unittest.TestCase):\n    def setUp(self):\n        config = testing_data_config(\"images\", \"Prostate_T2W_AX_1\")\n        download_url_or_skip_test(\n            url=config[\"url\"],\n            filepath=FILE_PATH,\n            hash_val=config.get(\"hash_val\"),\n            hash_type=config.get(\"hash_type\", \"sha256\"),\n        )\n\n    @SkipIfNoModule(\"itk\")\n    def test_itk_benchmark(self):\n        img, ddf = load_img_and_sample_ddf()\n        monai_result = monai_warp(img, ddf)\n        itk_result = itk_warp(img, ddf)\n        relative_diff = np.mean(\n            np.divide(monai_result - itk_result, itk_result, out=np.zeros_like(itk_result), where=(itk_result != 0))\n        )\n        self.assertLess(relative_diff, 0.01)\n\n    @parameterized.expand(TEST_CASES, skip_on_empty=True)\n    def test_resample(self, input_param, input_data, expected_val):\n        warp_layer = Warp(**input_param)\n        result = warp_layer(**input_data)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4)\n\n    def test_ill_shape(self):\n        warp_layer = Warp()\n        with self.assertRaisesRegex(ValueError, \"\"):\n            warp_layer(\n                image=torch.arange(4).reshape((1, 1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 2, 2)\n            )\n        with self.assertRaisesRegex(ValueError, \"\"):\n            warp_layer(\n                image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 1, 2, 2)\n            )\n        with self.assertRaisesRegex(ValueError, \"\"):\n            warp_layer(image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 3, 3))\n\n    def test_grad(self):\n        for b in GridSampleMode:\n 
           for p in GridSamplePadMode:\n                warp_layer = Warp(mode=b.value, padding_mode=p.value)\n                input_image = torch.rand((2, 3, 20, 20), dtype=torch.float64) * 10.0\n                ddf = torch.rand((2, 2, 20, 20), dtype=torch.float64) * 2.0\n                input_image.requires_grad = True\n                ddf.requires_grad = False  # Jacobian mismatch for output 0 with respect to input 1\n                gradcheck(warp_layer, (input_image, ddf), atol=1e-2, eps=1e-2)\n\n\nTESTS_PATH = Path(__file__).parents[3]\nFILE_PATH = TESTS_PATH / \"testing_data\" / \"temp_\" / \"mri.nii\"\n\n\ndef load_img_and_sample_ddf():\n    # load image\n    img = LoadImaged(keys=\"img\")({\"img\": FILE_PATH})[\"img\"]\n    img = img.detach().numpy()\n    # W, H, D -> D, H, W\n    img = img.transpose((2, 1, 0)).copy()\n\n    # randomly sample ddf such that maximum displacement in each direction equals to one-tenth of the image dimension in\n    # that direction.\n    ddf = np.random.random((3, *img.shape)).astype(np.float32)  # (3, D, H, W)\n    ddf[0] = ddf[0] * img.shape[0] * 0.1\n    ddf[1] = ddf[1] * img.shape[1] * 0.1\n    ddf[2] = ddf[2] * img.shape[2] * 0.1\n    return img, ddf\n\n\ndef itk_warp(img, ddf):\n    \"\"\"\n    warping with python itk\n    Args:\n        img: numpy array of shape (D, H, W)\n        ddf: numpy array of shape (3, D, H, W)\n\n    Returns:\n        warped_img: numpy arrap of shape (D, H, W)\n    \"\"\"\n    import itk\n\n    # 3, D, H, W -> D, H, W, 3\n    ddf = ddf.transpose((1, 2, 3, 0))\n    # x, y, z -> z, x, y\n    ddf = ddf[..., ::-1]\n\n    dimension = 3\n\n    # initialise image\n    pixel_type = itk.F  # float32\n    image_type = itk.Image[pixel_type, dimension]\n    itk_img = itk.PyBuffer[image_type].GetImageFromArray(img.astype(np.float32), is_vector=None)\n\n    # initialise displacement field\n    vector_component_type = itk.F\n    vector_pixel_type = itk.Vector[vector_component_type, dimension]\n    
displacement_field_type = itk.Image[vector_pixel_type, dimension]\n    displacement_field = itk.PyBuffer[displacement_field_type].GetImageFromArray(ddf.astype(np.float32), is_vector=True)\n\n    # initialise warp_filter\n    warp_filter = itk.WarpImageFilter[image_type, image_type, displacement_field_type].New()\n    interpolator = itk.LinearInterpolateImageFunction[image_type, itk.D].New()\n    warp_filter.SetInterpolator(interpolator)\n    warp_filter.SetOutputSpacing(itk_img.GetSpacing())\n    warp_filter.SetOutputOrigin(itk_img.GetOrigin())\n    warp_filter.SetOutputDirection(itk_img.GetDirection())\n\n    # warp\n    warp_filter.SetDisplacementField(displacement_field)\n    warp_filter.SetInput(itk_img)\n    warp_filter.Update()\n    warped_img = warp_filter.GetOutput()\n    warped_img = np.asarray(warped_img)\n\n    return warped_img\n\n\ndef monai_warp(img, ddf):\n    \"\"\"\n    warp with MONAI\n    Args:\n        img: numpy array of shape (D, H, W)\n        ddf: numpy array of shape (3, D, H, W)\n\n    Returns:\n        warped_img: numpy array of shape (D, H, W)\n    \"\"\"\n    warp_layer = Warp(padding_mode=\"zeros\")\n    # turn to tensor and add channel dim\n    monai_img = torch.tensor(img).unsqueeze(0)\n    ddf = torch.tensor(ddf)\n    # img -> batch -> img\n    warped_img = warp_layer(monai_img.unsqueeze(0), ddf.unsqueeze(0)).squeeze(0)\n    # remove channel dim\n    warped_img = np.asarray(warped_img.squeeze(0))\n\n    return warped_img\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/networks/layers/filtering/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/networks/layers/filtering/test_bilateral_approx_cpu.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom torch.autograd import gradcheck\n\nfrom monai.networks.layers.filtering import BilateralFilter\nfrom tests.test_utils import skip_if_no_cpp_extension\n\nTEST_CASES = [\n    [\n        # Case Description\n        \"1 dimension, 1 channel, low spatial sigma, low color sigma\",\n        # Spatial and Color Sigmas\n        (1, 0.2),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1.000000, 0.000000, 0.000000, 0.000000, 1.000000]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.000000, 0.000000, 1.000000, 0.000000, 0.000000]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, low spatial sigma, high color sigma\",\n        # Spatial and Color Sigmas\n        (1, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n         
   ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.631360, 0.099349, 0.070177, 0.164534, 0.649869]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.052271, 0.173599, 0.481337, 0.183721, 0.045619]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, high spatial sigma, low color sigma\",\n        # Spatial and Color Sigmas\n        (4, 0.2),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1.000000, 0.000000, 0.000000, 0.000000, 1.000000]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.000000, 0.000000, 1.000000, 0.000000, 0.000000]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, high spatial sigma, high color sigma\",\n        # Sigmas\n        (4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.497667, 0.268683, 0.265026, 0.261467, 0.495981]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.145959, 0.142282, 0.315710, 
0.135609, 0.132572]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 4 channel, low spatial sigma, high color sigma\",\n        # Spatial and Color Sigmas\n        (1, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 0],\n                # Channel 1\n                [1, 0, 1, 0, 0],\n                # Channel 2\n                [0, 0, 1, 0, 1],\n                # Channel 3\n                [0, 0, 0, 0, 1],\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.960843, 0.073540, 0.027689, 0.002676, 0.000000],\n                # Channel 1\n                [0.960843, 0.073540, 0.951248, 0.003033, 0.000750],\n                # Channel 2\n                [0.000000, 0.000000, 0.923559, 0.000357, 0.981324],\n                # Channel 3\n                [0.000000, 0.000000, 0.000000, 0.000000, 0.980574],\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"2 dimension, 1 channel, high spatial sigma, high color sigma\",\n        # Sigmas\n        (4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [0.213684, 0.094356, 0.092973, 0.091650, 0.216281],\n                    [0.094085, 0.092654, 0.091395, 0.090186, 0.089302],\n                    [0.092436, 0.091150, 0.090008, 0.088896, 0.088897],\n                    [0.090849, 0.089717, 
0.088759, 0.087751, 0.088501],\n                    [0.211458, 0.088334, 0.087495, 0.087049, 0.212173],\n                ]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [\n                    [0.033341, 0.031314, 0.029367, 0.027494, 0.025692],\n                    [0.031869, 0.030632, 0.028820, 0.027074, 0.025454],\n                    [0.030455, 0.029628, 0.084257, 0.026704, 0.025372],\n                    [0.029095, 0.028391, 0.027790, 0.026375, 0.025292],\n                    [0.027786, 0.027197, 0.026692, 0.026181, 0.025213],\n                ]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"2 dimension, 4 channel, high spatial sigma, high color sigma\",\n        # Spatial and Color Sigmas\n        (4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]],\n                # Channel 1\n                [[1, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 1]],\n                # Channel 2\n                [[0, 0, 1, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 0]],\n                # Channel 3\n                [[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]],\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [0.244373, 0.014488, 0.036589, 0.014226, 0.024329],\n                    [0.014108, 0.014228, 0.014096, 0.013961, 0.013823],\n                    [0.013574, 0.013757, 0.013836, 0.013699, 0.013558],\n                    [0.013008, 0.013211, 0.013404, 0.013438, 0.013295],\n                    [0.025179, 0.012634, 0.034555, 0.013050, 0.237582],\n                ],\n                # Channel 1\n                [\n              
      [0.271496, 0.015547, 0.439432, 0.015700, 0.089579],\n                    [0.015252, 0.015702, 0.015779, 0.015859, 0.015940],\n                    [0.015020, 0.015556, 0.015935, 0.016015, 0.016098],\n                    [0.014774, 0.015331, 0.015860, 0.016171, 0.016255],\n                    [0.107384, 0.015094, 0.462471, 0.016166, 0.263480],\n                ],\n                # Channel 2\n                [\n                    [0.027123, 0.003527, 0.467273, 0.004912, 0.645776],\n                    [0.003810, 0.004908, 0.005605, 0.006319, 0.007050],\n                    [0.004816, 0.005991, 0.006989, 0.007716, 0.008459],\n                    [0.005880, 0.007060, 0.008179, 0.009101, 0.009858],\n                    [0.633398, 0.008191, 0.496893, 0.010376, 0.025898],\n                ],\n                # Channel 3\n                [\n                    [0.000000, 0.002468, 0.064430, 0.003437, 0.580526],\n                    [0.002666, 0.003434, 0.003922, 0.004422, 0.004933],\n                    [0.003370, 0.004192, 0.004890, 0.005399, 0.005919],\n                    [0.004115, 0.004940, 0.005723, 0.006368, 0.006898],\n                    [0.551194, 0.005731, 0.068977, 0.007260, 0.000000],\n                ],\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"3 dimension, 1 channel, high spatial sigma, high color sigma\",\n        # Sigmas\n        (4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]],\n                    # Frame 1\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 2\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 3\n                   
 [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 4\n                    [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]],\n                ]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [\n                        [0.086801, 0.036670, 0.035971, 0.035304, 0.088456],\n                        [0.036639, 0.035652, 0.035009, 0.034394, 0.033803],\n                        [0.035899, 0.034897, 0.034136, 0.033566, 0.033129],\n                        [0.035180, 0.034238, 0.033413, 0.032811, 0.032577],\n                        [0.088290, 0.033597, 0.032821, 0.032134, 0.088786],\n                    ],\n                    # Frame 1\n                    [\n                        [0.036286, 0.035269, 0.034632, 0.034021, 0.033435],\n                        [0.035398, 0.034485, 0.033922, 0.033381, 0.033177],\n                        [0.034688, 0.033822, 0.033169, 0.032664, 0.032780],\n                        [0.034024, 0.033234, 0.032533, 0.032005, 0.032388],\n                        [0.033564, 0.032797, 0.032118, 0.031525, 0.032105],\n                    ],\n                    # Frame 2\n                    [\n                        [0.035225, 0.034169, 0.033404, 0.032843, 0.032766],\n                        [0.034383, 0.033487, 0.032908, 0.032415, 0.032650],\n                        [0.033691, 0.032921, 0.032353, 0.031900, 0.032384],\n                        [0.033080, 0.032390, 0.031786, 0.031432, 0.032008],\n                        [0.033099, 0.032373, 0.031737, 0.031479, 0.032054],\n                    ],\n                    # Frame 3\n                    [\n                        [0.034216, 0.033231, 0.032337, 0.031758, 0.032101],\n                        [0.033456, 0.032669, 0.031913, 0.031455, 0.032034],\n            
            [0.032788, 0.032140, 0.031618, 0.031413, 0.031977],\n                        [0.032221, 0.031650, 0.031145, 0.031130, 0.031652],\n                        [0.032642, 0.031968, 0.031378, 0.031433, 0.032003],\n                    ],\n                    # Frame 4\n                    [\n                        [0.086207, 0.032335, 0.031499, 0.030832, 0.087498],\n                        [0.032570, 0.031884, 0.031155, 0.030858, 0.031401],\n                        [0.031967, 0.031417, 0.030876, 0.030881, 0.031388],\n                        [0.031602, 0.031103, 0.030696, 0.030960, 0.031455],\n                        [0.090599, 0.031546, 0.031127, 0.031386, 0.083483],\n                    ],\n                ]\n            ]\n        ],\n    ],\n]\n\n\n@skip_if_no_cpp_extension\nclass BilateralFilterTestCaseCpuApprox(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_cpu_approx(self, test_case_description, sigmas, input, expected):\n        # Params to determine the implementation to test\n        device = torch.device(\"cpu\")\n        fast_approx = True\n\n        # Create input tensor and apply filter\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)\n        output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy()\n\n        # Ensure result are as expected\n        np.testing.assert_allclose(output, expected, atol=1e-5)\n\n    @parameterized.expand(TEST_CASES)\n    def test_cpu_approx_backwards(self, test_case_description, sigmas, input, expected):\n        # Params to determine the implementation to test\n        device = torch.device(\"cpu\")\n        fast_approx = True\n\n        # Prepare input tensor\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)\n        input_tensor.requires_grad = True\n\n        # Prepare args\n        args = (input_tensor, *sigmas, fast_approx)\n\n        # Run grad check\n        
gradcheck(BilateralFilter.apply, args, raise_exception=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/filtering/test_bilateral_approx_cuda.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom torch.autograd import gradcheck\n\nfrom monai.networks.layers.filtering import BilateralFilter\nfrom tests.test_utils import skip_if_no_cpp_extension, skip_if_no_cuda\n\nTEST_CASES = [\n    [\n        # Case Description\n        \"1 dimension, 1 channel, low spatial sigma, low color sigma\",\n        # Spatial and Color Sigmas\n        (1, 0.2),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1.000000, 0.000000, 0.000000, 0.000000, 1.000000]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.000000, 0.000000, 1.000000, 0.000000, 0.000000]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, low spatial sigma, high color sigma\",\n        # Spatial and Color Sigmas\n        (1, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 
0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.880626, 0.306148, 0.158734, 0.164534, 0.754386]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.019010, 0.104507, 0.605634, 0.183721, 0.045619]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, high spatial sigma, low color sigma\",\n        # Spatial and Color Sigmas\n        (4, 0.2),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1.000000, 0.000000, 0.000000, 0.000000, 1.000000]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.000000, 0.000000, 1.000000, 0.000000, 0.000000]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, high spatial sigma, high color sigma\",\n        # Sigmas\n        (4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.497667, 0.268683, 0.265026, 0.261467, 0.495981]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.149889, 
0.148226, 0.367978, 0.144023, 0.141317]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 4 channel, low spatial sigma, high color sigma\",\n        # Spatial and Color Sigmas\n        (1, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 0],\n                # Channel 1\n                [1, 0, 1, 0, 0],\n                # Channel 2\n                [0, 0, 1, 0, 1],\n                # Channel 3\n                [0, 0, 0, 0, 1],\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.988107, 0.061340, 0.001565, 0.000011, 0.000000],\n                # Channel 1\n                [0.988107, 0.061340, 0.998000, 0.000016, 0.000123],\n                # Channel 2\n                [0.000000, 0.000000, 0.996435, 0.000006, 0.999236],\n                # Channel 3\n                [0.000000, 0.000000, 0.000000, 0.000000, 0.999113],\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"2 dimension, 1 channel, high spatial sigma, high color sigma\",\n        # Sigmas\n        (4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [0.211469, 0.094356, 0.092973, 0.091650, 0.211894],\n                    [0.093755, 0.091753, 0.090524, 0.089343, 0.088384],\n                    [0.091803, 0.089783, 0.088409, 0.087346, 0.086927],\n                    
[0.089938, 0.088126, 0.086613, 0.085601, 0.085535],\n                    [0.208359, 0.086535, 0.085179, 0.084210, 0.205858],\n                ]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [\n                    [0.032760, 0.030146, 0.027442, 0.024643, 0.021744],\n                    [0.030955, 0.029416, 0.026574, 0.023629, 0.020841],\n                    [0.028915, 0.027834, 0.115442, 0.022515, 0.020442],\n                    [0.026589, 0.025447, 0.024319, 0.021286, 0.019964],\n                    [0.023913, 0.022704, 0.021510, 0.020388, 0.019379],\n                ]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"2 dimension, 4 channel, high spatial sigma, high color sigma\",\n        # Spatial and Color Sigmas\n        (4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]],\n                # Channel 1\n                [[1, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 1]],\n                # Channel 2\n                [[0, 0, 1, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 0]],\n                # Channel 3\n                [[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]],\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [0.557349, 0.011031, 0.001800, 0.011265, 0.000631],\n                    [0.009824, 0.010361, 0.010429, 0.010506, 0.010595],\n                    [0.008709, 0.009252, 0.009688, 0.009714, 0.009744],\n                    [0.007589, 0.008042, 0.008576, 0.008887, 0.008852],\n                    [0.000420, 0.006827, 0.001048, 0.007763, 0.190722],\n                ],\n                # Channel 1\n            
    [\n                    [0.614072, 0.011045, 0.925766, 0.011287, 0.007548],\n                    [0.009838, 0.010382, 0.010454, 0.010536, 0.010630],\n                    [0.008727, 0.009277, 0.009720, 0.009751, 0.009787],\n                    [0.007611, 0.008071, 0.008613, 0.008932, 0.008904],\n                    [0.027088, 0.006859, 0.950749, 0.007815, 0.230270],\n                ],\n                # Channel 2\n                [\n                    [0.056723, 0.000150, 0.973790, 0.000233, 0.990814],\n                    [0.000151, 0.000214, 0.000257, 0.000307, 0.000364],\n                    [0.000186, 0.000257, 0.000328, 0.000384, 0.000449],\n                    [0.000221, 0.000295, 0.000382, 0.000465, 0.000538],\n                    [0.993884, 0.000333, 0.984743, 0.000532, 0.039548],\n                ],\n                # Channel 3\n                [\n                    [0.000000, 0.000136, 0.049824, 0.000210, 0.983897],\n                    [0.000136, 0.000193, 0.000232, 0.000277, 0.000329],\n                    [0.000168, 0.000232, 0.000297, 0.000347, 0.000405],\n                    [0.000200, 0.000266, 0.000345, 0.000420, 0.000485],\n                    [0.967217, 0.000301, 0.035041, 0.000481, 0.000000],\n                ],\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"3 dimension, 1 channel, high spatial sigma, high color sigma\",\n        # Sigmas\n        (4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]],\n                    # Frame 1\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 2\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 
3\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 4\n                    [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]],\n                ]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [\n                        [0.085451, 0.037820, 0.036880, 0.035978, 0.084296],\n                        [0.037939, 0.036953, 0.036155, 0.035385, 0.034640],\n                        [0.037167, 0.036302, 0.035603, 0.034931, 0.034465],\n                        [0.036469, 0.035724, 0.035137, 0.034572, 0.034480],\n                        [0.088942, 0.035193, 0.034682, 0.034266, 0.090568],\n                    ],\n                    # Frame 1\n                    [\n                        [0.037125, 0.035944, 0.035103, 0.033429, 0.033498],\n                        [0.033380, 0.032653, 0.033748, 0.033073, 0.032549],\n                        [0.034834, 0.034001, 0.033500, 0.032902, 0.032560],\n                        [0.033972, 0.033554, 0.033220, 0.032765, 0.032570],\n                        [0.033590, 0.033222, 0.032927, 0.032689, 0.032629],\n                    ],\n                    # Frame 2\n                    [\n                        [0.035635, 0.034468, 0.033551, 0.032818, 0.032302],\n                        [0.034523, 0.032830, 0.032146, 0.031536, 0.031149],\n                        [0.033612, 0.032011, 0.031664, 0.031128, 0.030839],\n                        [0.032801, 0.031668, 0.031529, 0.031198, 0.030978],\n                        [0.032337, 0.031550, 0.031419, 0.031383, 0.031211],\n                    ],\n                    # Frame 3\n                    [\n                        [0.034300, 0.033236, 0.032239, 0.031517, 0.031133],\n                        [0.033357, 0.031842, 0.031035, 0.030471, 
0.030126],\n                        [0.032563, 0.031094, 0.030156, 0.029703, 0.029324],\n                        [0.031850, 0.030505, 0.030027, 0.029802, 0.029461],\n                        [0.031555, 0.030121, 0.029943, 0.030000, 0.029700],\n                    ],\n                    # Frame 4\n                    [\n                        [0.083156, 0.032122, 0.031204, 0.030380, 0.080582],\n                        [0.032296, 0.030936, 0.030170, 0.029557, 0.029124],\n                        [0.031617, 0.030293, 0.029377, 0.028886, 0.028431],\n                        [0.031084, 0.029859, 0.028839, 0.028439, 0.027973],\n                        [0.164616, 0.029457, 0.028484, 0.028532, 0.211082],\n                    ],\n                ]\n            ]\n        ],\n    ],\n]\n\n\n@skip_if_no_cuda\n@skip_if_no_cpp_extension\nclass BilateralFilterTestCaseCudaApprox(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_cuda_approx(self, test_case_description, sigmas, input, expected):\n        # Skip this test\n        if not torch.cuda.is_available():\n            return\n\n        # Params to determine the implementation to test\n        device = torch.device(\"cuda\")\n        fast_approx = True\n\n        # Create input tensor and apply filter\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)\n        output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy()\n\n        # Ensure result are as expected\n        np.testing.assert_allclose(output, expected, atol=2e-1)\n\n    @parameterized.expand(TEST_CASES)\n    def test_cpu_approx_backwards(self, test_case_description, sigmas, input, expected):\n        # Params to determine the implementation to test\n        device = torch.device(\"cuda\")\n        fast_approx = True\n\n        # Prepare input tensor\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)\n        input_tensor.requires_grad 
= True\n\n        # Prepare args\n        args = (input_tensor, *sigmas, fast_approx)\n\n        # Run grad check\n        gradcheck(BilateralFilter.apply, args, raise_exception=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/filtering/test_bilateral_precise.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom torch.autograd import gradcheck\n\nfrom monai.networks.layers.filtering import BilateralFilter\nfrom tests.test_utils import skip_if_no_cpp_extension, skip_if_no_cuda, skip_if_quick\n\nTEST_CASES = [\n    [\n        # Case Description\n        \"1 dimension, 1 channel, low spatial sigma, low color sigma\",\n        # Spatial and Color Sigmas\n        (1, 0.2),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.999998, 0.000002, 0.000000, 0.000002, 0.999998]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.000000, 0.000001, 0.999995, 0.000001, 0.000000]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, low spatial sigma, high color sigma\",\n        # Spatial and Color Sigmas\n        (1, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n          
      [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.813183, 0.186817, 0.061890, 0.186817, 0.813183]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.030148, 0.148418, 0.555452, 0.148418, 0.030148]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, high spatial sigma, low color sigma\",\n        # Spatial and Color Sigmas\n        (4, 0.2),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.999999, 0.000009, 0.000009, 0.000009, 0.999999]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.000000, 0.000000, 0.999967, 0.000000, 0.000000]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, high spatial sigma, high color sigma\",\n        # Sigmas\n        (4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.839145, 0.572834, 0.562460, 0.572834, 0.839145]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                
[0.049925, 0.055062, 0.171732, 0.055062, 0.049925]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 4 channel, low spatial sigma, high color sigma\",\n        # Spatial and Color Sigmas\n        (1, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 0],\n                # Channel 1\n                [1, 0, 1, 0, 0],\n                # Channel 2\n                [0, 0, 1, 0, 1],\n                # Channel 3\n                [0, 0, 0, 0, 1],\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.889742, 0.141296, 0.027504, 0.000000, 0.000000],\n                # Channel 1\n                [0.909856, 0.256817, 0.725970, 0.115520, 0.020114],\n                # Channel 2\n                [0.020114, 0.115520, 0.725970, 0.256817, 0.909856],\n                # Channel 3\n                [0.000000, 0.000000, 0.027504, 0.141296, 0.889742],\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"2 dimension, 1 channel, high spatial sigma, high color sigma\",\n        # Sigmas\n        (4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [0.688943, 0.374599, 0.368574, 0.374599, 0.688943],\n                    [0.374599, 0.358248, 0.352546, 0.358248, 0.374599],\n                    [0.368574, 0.352546, 0.346955, 0.352546, 0.368574],\n            
        [0.374599, 0.358248, 0.352546, 0.358248, 0.374599],\n                    [0.688943, 0.374599, 0.368574, 0.374599, 0.688943],\n                ]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [\n                    [0.004266, 0.004687, 0.004836, 0.004687, 0.004266],\n                    [0.004687, 0.005150, 0.005314, 0.005150, 0.004687],\n                    [0.004836, 0.005314, 0.018598, 0.005314, 0.004836],\n                    [0.004687, 0.005150, 0.005314, 0.005150, 0.004687],\n                    [0.004266, 0.004687, 0.004836, 0.004687, 0.004266],\n                ]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"2 dimension, 4 channel, high spatial sigma, high color sigma\",\n        # Spatial and Color Sigmas\n        (4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]],\n                # Channel 1\n                [[1, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 1]],\n                # Channel 2\n                [[0, 0, 1, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 0]],\n                # Channel 3\n                [[0, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]],\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [0.692549, 0.149979, 0.220063, 0.115840, 0.035799],\n                    [0.148403, 0.133935, 0.123253, 0.116828, 0.114623],\n                    [0.128773, 0.122804, 0.120731, 0.122804, 0.128773],\n                    [0.114623, 0.116828, 0.123253, 0.133935, 0.148403],\n                    [0.035799, 0.115840, 0.220063, 0.149979, 0.692549],\n                ],\n                # Channel 1\n    
            [\n                    [0.731597, 0.186319, 0.436069, 0.152181, 0.074847],\n                    [0.180049, 0.168217, 0.158453, 0.151110, 0.146269],\n                    [0.159760, 0.156381, 0.155211, 0.156381, 0.159760],\n                    [0.146269, 0.151110, 0.158453, 0.168217, 0.180049],\n                    [0.074847, 0.152181, 0.436068, 0.186319, 0.731597],\n                ],\n                # Channel 2\n                [\n                    [0.074847, 0.152181, 0.436068, 0.186319, 0.731597],\n                    [0.146269, 0.151110, 0.158453, 0.168217, 0.180049],\n                    [0.159760, 0.156381, 0.155211, 0.156381, 0.159760],\n                    [0.180049, 0.168217, 0.158453, 0.151110, 0.146269],\n                    [0.731597, 0.186319, 0.436069, 0.152181, 0.074847],\n                ],\n                # Channel 3\n                [\n                    [0.035799, 0.115840, 0.220063, 0.149979, 0.692549],\n                    [0.114623, 0.116828, 0.123253, 0.133935, 0.148403],\n                    [0.128773, 0.122804, 0.120731, 0.122804, 0.128773],\n                    [0.148403, 0.133935, 0.123253, 0.116828, 0.114623],\n                    [0.692549, 0.149979, 0.220063, 0.115840, 0.035799],\n                ],\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"3 dimension, 1 channel, high spatial sigma, high color sigma\",\n        # Sigmas\n        (4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]],\n                    # Frame 1\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 2\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    
# Frame 3\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 4\n                    [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]],\n                ]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [\n                        [0.554430, 0.254995, 0.251207, 0.254996, 0.554430],\n                        [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],\n                        [0.251207, 0.241082, 0.237534, 0.241082, 0.251207],\n                        [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],\n                        [0.554430, 0.254995, 0.251207, 0.254996, 0.554430],\n                    ],\n                    # Frame 1\n                    [\n                        [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],\n                        [0.244692, 0.234873, 0.231432, 0.234873, 0.244692],\n                        [0.241082, 0.231431, 0.228049, 0.231432, 0.241082],\n                        [0.244692, 0.234873, 0.231432, 0.234873, 0.244692],\n                        [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],\n                    ],\n                    # Frame 2\n                    [\n                        [0.251207, 0.241081, 0.237534, 0.241082, 0.251207],\n                        [0.241082, 0.231431, 0.228049, 0.231432, 0.241082],\n                        [0.237534, 0.228048, 0.224724, 0.228049, 0.237534],\n                        [0.241082, 0.231431, 0.228049, 0.231432, 0.241082],\n                        [0.251207, 0.241081, 0.237534, 0.241082, 0.251207],\n                    ],\n                    # Frame 3\n                    [\n                        [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],\n                        [0.244692, 0.234873, 0.231432, 
0.234873, 0.244692],\n                        [0.241082, 0.231431, 0.228049, 0.231432, 0.241082],\n                        [0.244692, 0.234873, 0.231432, 0.234873, 0.244692],\n                        [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],\n                    ],\n                    # Frame 4\n                    [\n                        [0.554430, 0.254995, 0.251207, 0.254996, 0.554430],\n                        [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],\n                        [0.251207, 0.241082, 0.237534, 0.241082, 0.251207],\n                        [0.254996, 0.244691, 0.241082, 0.244692, 0.254996],\n                        [0.554430, 0.254995, 0.251207, 0.254996, 0.554430],\n                    ],\n                ]\n            ]\n        ],\n    ],\n]\n\n\n@skip_if_no_cpp_extension\n@skip_if_quick\nclass BilateralFilterTestCaseCpuPrecise(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_cpu_precise(self, test_case_description, sigmas, input, expected):\n        # Params to determine the implementation to test\n        device = torch.device(\"cpu\")\n        fast_approx = False\n\n        # Create input tensor and apply filter\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)\n        output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy()\n\n        # Ensure result are as expected\n        np.testing.assert_allclose(output, expected, atol=1e-5)\n\n    @parameterized.expand(TEST_CASES)\n    def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expected):\n        # Params to determine the implementation to test\n        device = torch.device(\"cpu\")\n        fast_approx = False\n\n        # Prepare input tensor\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)\n        input_tensor.requires_grad = True\n\n        # Prepare args\n        args = (input_tensor, *sigmas, 
fast_approx)\n\n        # Run grad check\n        gradcheck(BilateralFilter.apply, args, raise_exception=False)\n\n\n@skip_if_no_cuda\n@skip_if_no_cpp_extension\nclass BilateralFilterTestCaseCudaPrecise(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_cuda_precise(self, test_case_description, sigmas, input, expected):\n        # Skip this test\n        if not torch.cuda.is_available():\n            return\n\n        # Params to determine the implementation to test\n        device = torch.device(\"cuda\")\n        fast_approx = False\n\n        # Create input tensor and apply filter\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)\n        output = BilateralFilter.apply(input_tensor, *sigmas, fast_approx).cpu().numpy()\n\n        # Ensure result are as expected\n        np.testing.assert_allclose(output, expected, atol=1e-5)\n\n    @parameterized.expand(TEST_CASES)\n    def test_cuda_precise_backwards(self, test_case_description, sigmas, input, expected):\n        # Params to determine the implementation to test\n        device = torch.device(\"cuda\")\n        fast_approx = False\n\n        # Prepare input tensor\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=device)\n        input_tensor.requires_grad = True\n\n        # Prepare args\n        args = (input_tensor, *sigmas, fast_approx)\n\n        # Run grad check\n        gradcheck(BilateralFilter.apply, args, raise_exception=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/filtering/test_phl_cpu.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers.filtering import PHLFilter\nfrom tests.test_utils import skip_if_no_cpp_extension\n\nTEST_CASES = [\n    [\n        # Case Description\n        \"2 batches, 1 dimensions, 1 channels, 1 features\",\n        # Sigmas\n        [1, 0.2],\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0.2, 0.5, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.5, 0, 1, 1, 1]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.468968, 0.364596, 0.4082, 0.332579, 0.468968]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.202473, 0.176527, 0.220995, 0.220995, 0.220995]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 batches, 1 dimensions, 3 channels, 1 
features\",\n        # Sigmas\n        [1],\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 0],\n                # Channel 1\n                [0, 0, 0, 0, 1],\n                # Channel 2\n                [0, 0, 1, 0, 0],\n            ]\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0.2, 0.5, 0.2, 1]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.229572, 0.182884, 0.202637, 0.182884, 0.229572],\n                # Channel 1\n                [0.229572, 0.182884, 0.202637, 0.182884, 0.229572],\n                # Channel 2\n                [0.201235, 0.208194, 0.205409, 0.208194, 0.201235],\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"1 batches, 2 dimensions, 1 channels, 3 features\",\n        # Sigmas\n        [5, 3, 3],\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]]\n            ]\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]],\n                # Channel 1\n                [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]],\n                # Channel 2\n                [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3], [4, 4, 4, 4, 4]],\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [7.696051, 7.427121, 1.191990, 1.156004, 1.157489],\n                    [7.670297, 7.371155, 1.340232, 1.287871, 
1.304018],\n                    [7.639579, 7.365163, 1.473319, 1.397826, 1.416861],\n                    [7.613517, 7.359183, 5.846500, 5.638952, 5.350098],\n                    [7.598255, 7.458446, 5.912375, 5.583625, 5.233126],\n                ]\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"1 batches, 3 dimensions, 1 channels, 1 features\",\n        # Sigmas\n        [5, 3, 3],\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0]],\n                    # Frame 1\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0]],\n                    # Frame 2\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 3\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 4\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                ]\n            ]\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0]],\n                    # Frame 1\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0], [9, 9, 9, 0, 0]],\n                    # Frame 2\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 3\n                    [[0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 4\n                    
[[0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 5, 5, 5], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                ]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [\n                        [0.284234, 0.284234, 0.284234, 0.284234, 0.284234],\n                        [0.284234, 0.284234, 0.284234, 0.284234, 0.284234],\n                        [3.578490, 3.578490, 3.578490, 0.284234, 0.284234],\n                        [3.578490, 3.578490, 3.578490, 0.284234, 0.284234],\n                        [3.578490, 3.578490, 3.578490, 0.284234, 0.284234],\n                    ],\n                    # Frame 1\n                    [\n                        [0.284234, 0.284234, 0.284234, 0.284234, 0.284234],\n                        [0.284234, 0.284234, 0.284234, 0.284234, 0.284234],\n                        [3.578490, 3.578490, 3.578490, 0.284234, 0.284234],\n                        [3.578490, 3.578490, 3.578490, 0.284234, 0.284234],\n                        [3.578490, 3.578490, 3.578490, 0.284234, 0.284234],\n                    ],\n                    # Frame 2\n                    [\n                        [0.284234, 0.284234, 0.284234, 0.284234, 0.284234],\n                        [0.284234, 0.284234, 0.284234, 0.284234, 0.284234],\n                        [0.284234, 0.284234, 0.284234, 0.284234, 0.284234],\n                        [0.284234, 0.284234, 0.284234, 0.284234, 0.284234],\n                        [0.284234, 0.284234, 0.284234, 0.284234, 0.284234],\n                    ],\n                    # Frame 3\n                    [\n                        [0.284234, 0.284234, 1.359728, 1.359728, 1.359728],\n                        [0.284234, 0.284234, 1.359728, 1.359728, 1.359728],\n                        [0.284234, 0.284234, 1.359728, 1.359728, 1.359728],\n                        [0.284234, 0.284234, 0.284234, 0.284234, 
0.284234],\n                        [0.284234, 0.284234, 0.284234, 0.284234, 0.284234],\n                    ],\n                    # Frame 4\n                    [\n                        [0.284234, 0.284234, 1.359728, 1.359728, 1.359728],\n                        [0.284234, 0.284234, 1.359728, 1.359728, 1.359728],\n                        [0.284234, 0.284234, 1.359728, 1.359728, 1.359728],\n                        [0.284234, 0.284234, 0.284234, 0.284234, 0.284234],\n                        [0.284234, 0.284234, 0.284234, 0.284234, 0.284234],\n                    ],\n                ]\n            ]\n        ],\n    ],\n]\n\n\n@skip_if_no_cpp_extension\nclass PHLFilterTestCaseCpu(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_cpu(self, test_case_description, sigmas, input, features, expected):\n        # Create input tensors\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device(\"cpu\"))\n        feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device(\"cpu\"))\n\n        # apply filter\n        output = PHLFilter.apply(input_tensor, feature_tensor, sigmas).cpu().numpy()\n\n        # Ensure result are as expected\n        np.testing.assert_allclose(output, expected, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/filtering/test_phl_cuda.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers.filtering import PHLFilter\nfrom tests.test_utils import skip_if_no_cpp_extension, skip_if_no_cuda\n\nTEST_CASES = [\n    [\n        # Case Description\n        \"2 batches, 1 dimensions, 1 channels, 1 features\",\n        # Sigmas\n        [1, 0.2],\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0.2, 0.5, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.5, 0, 1, 1, 1]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.468968, 0.364596, 0.408200, 0.332579, 0.468968]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.202473, 0.176527, 0.220995, 0.220995, 0.220995]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 batches, 1 dimensions, 3 
channels, 1 features\",\n        # Sigmas\n        [1],\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 0],\n                # Channel 1\n                [0, 0, 0, 0, 1],\n                # Channel 2\n                [0, 0, 1, 0, 0],\n            ]\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0.2, 0.5, 0.2, 1]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.229572, 0.182884, 0.202637, 0.182884, 0.229572],\n                # Channel 1\n                [0.229572, 0.182884, 0.202637, 0.182884, 0.229572],\n                # Channel 2\n                [0.201235, 0.208194, 0.205409, 0.208194, 0.201235],\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"1 batches, 2 dimensions, 1 channels, 3 features\",\n        # Sigmas\n        [5, 3, 3],\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]]\n            ]\n        ],\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 0, 0, 0], [9, 9, 6, 6, 6], [9, 9, 6, 6, 6]],\n                # Channel 1\n                [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]],\n                # Channel 2\n                [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3], [4, 4, 4, 4, 4]],\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [7.792655, 7.511395, 0.953769, 0.860538, 0.912978],\n                    [7.758870, 7.426762, 1.164386, 
1.050956, 1.121830],\n                    [7.733974, 7.429964, 1.405752, 1.244949, 1.320862],\n                    [7.712976, 7.429060, 5.789552, 5.594258, 5.371737],\n                    [7.701185, 7.492719, 5.860026, 5.538241, 5.281656],\n                ]\n            ]\n        ],\n    ],\n]\n\n\n@skip_if_no_cuda\n@skip_if_no_cpp_extension\nclass PHLFilterTestCaseCuda(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_cuda(self, test_case_description, sigmas, input, features, expected):\n        # Create input tensors\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.float, device=torch.device(\"cuda\"))\n        feature_tensor = torch.from_numpy(np.array(features)).to(dtype=torch.float, device=torch.device(\"cuda\"))\n\n        # apply filter\n        output = PHLFilter.apply(input_tensor, feature_tensor, sigmas).cpu().numpy()\n\n        # Ensure result are as expected\n        np.testing.assert_allclose(output, expected, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/filtering/test_trainable_bilateral.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom torch.autograd import gradcheck\n\nfrom monai.networks.layers.filtering import TrainableBilateralFilterFunction\nfrom tests.test_utils import skip_if_no_cpp_extension, skip_if_no_cuda\n\nTEST_CASES = [\n    [\n        # Case Description\n        \"1 dimension, 1 channel, low spatial sigmas, low color sigma\",\n        # (sigma_x, sigma_y, sigma_z, color_sigma)\n        (1.0, 1.0, 1.0, 0.2),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.999997, 0.000001, 0.000000, 0.000001, 0.999997]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.000000, 0.000001, 0.999995, 0.000001, 0.000000]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, low spatial sigmas, high color sigma\",\n        # (sigma_x, sigma_y, sigma_z, color_sigma)\n        (1, 1, 1, 0.9),\n        # Input\n        [\n            # Batch 
0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.714200, 0.158126, 0.061890, 0.158126, 0.714200]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.043465, 0.158126, 0.555452, 0.158126, 0.043465]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, high spatial sigmas, low color sigma\",\n        # (sigma_x, sigma_y, sigma_z, color_sigma)\n        (4, 4, 4, 0.2),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.999994, 0.000002, 0.000002, 0.000002, 0.999994]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.000001, 0.000001, 0.999986, 0.000001, 0.000001]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, high spatial sigmas, high color sigma\",\n        # (sigma_x, sigma_y, sigma_z, color_sigma)\n        (4, 4, 4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.533282, 0.245915, 0.244711, 
0.245915, 0.533282]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.125052, 0.126608, 0.333592, 0.126608, 0.125052]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"2 dimensions, 1 channel, high spatial sigmas, high color sigma\",\n        # (sigma_x, sigma_y, sigma_z, color_sigma)\n        (4, 4, 4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [0.239789, 0.082990, 0.082630, 0.082990, 0.239789],\n                    [0.082990, 0.081934, 0.081579, 0.081934, 0.082990],\n                    [0.082630, 0.081579, 0.081225, 0.081579, 0.082630],\n                    [0.082990, 0.081934, 0.081579, 0.081934, 0.082990],\n                    [0.239789, 0.082990, 0.082630, 0.082990, 0.239789],\n                ]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [\n                    [0.024155, 0.024432, 0.024525, 0.024432, 0.024155],\n                    [0.024432, 0.024712, 0.024806, 0.024712, 0.024432],\n                    [0.024525, 0.024806, 0.080686, 0.024806, 0.024525],\n                    [0.024432, 0.024712, 0.024806, 0.024712, 0.024432],\n                    [0.024155, 0.024432, 0.024525, 0.024432, 0.024155],\n                ]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"3 dimensions, 1 channel, high spatial sigmas, high color sigma\",\n        # (sigma_x, sigma_y, sigma_z, color_sigma)\n    
    (4, 4, 4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]],\n                    # Frame 1\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 2\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 3\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 4\n                    [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]],\n                ]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [\n                        [0.098142, 0.030317, 0.030191, 0.030316, 0.098142],\n                        [0.030316, 0.029947, 0.029822, 0.029947, 0.030316],\n                        [0.030191, 0.029822, 0.029698, 0.029822, 0.030191],\n                        [0.030316, 0.029947, 0.029822, 0.029947, 0.030316],\n                        [0.098142, 0.030317, 0.030191, 0.030317, 0.098142],\n                    ],\n                    # Frame 1\n                    [\n                        [0.030316, 0.029947, 0.029822, 0.029947, 0.030316],\n                        [0.029947, 0.029581, 0.029458, 0.029581, 0.029947],\n                        [0.029822, 0.029458, 0.029336, 0.029458, 0.029822],\n                        [0.029947, 0.029581, 0.029458, 0.029581, 0.029947],\n                        [0.030316, 0.029947, 0.029822, 0.029947, 0.030316],\n                    ],\n                    # Frame 2\n                    [\n                        
[0.030191, 0.029822, 0.029698, 0.029822, 0.030191],\n                        [0.029822, 0.029458, 0.029336, 0.029458, 0.029822],\n                        [0.029698, 0.029336, 0.029214, 0.029336, 0.029698],\n                        [0.029822, 0.029458, 0.029336, 0.029458, 0.029822],\n                        [0.030191, 0.029822, 0.029698, 0.029822, 0.030191],\n                    ],\n                    # Frame 3\n                    [\n                        [0.030316, 0.029947, 0.029822, 0.029947, 0.030317],\n                        [0.029947, 0.029581, 0.029458, 0.029581, 0.029947],\n                        [0.029822, 0.029458, 0.029336, 0.029458, 0.029822],\n                        [0.029947, 0.029581, 0.029458, 0.029581, 0.029947],\n                        [0.030316, 0.029947, 0.029822, 0.029947, 0.030316],\n                    ],\n                    # Frame 4\n                    [\n                        [0.098142, 0.030317, 0.030191, 0.030317, 0.098142],\n                        [0.030317, 0.029947, 0.029822, 0.029947, 0.030316],\n                        [0.030191, 0.029822, 0.029698, 0.029822, 0.030191],\n                        [0.030317, 0.029947, 0.029822, 0.029947, 0.030316],\n                        [0.098142, 0.030317, 0.030191, 0.030316, 0.098142],\n                    ],\n                ]\n            ]\n        ],\n    ],\n]\n\n\n@skip_if_no_cpp_extension\nclass BilateralFilterTestCaseCpuPrecise(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_cpu_precise(self, test_case_description, sigmas, input, expected):\n        # Params to determine the implementation to test\n        device = torch.device(\"cpu\")\n\n        # Create input tensor and apply filter\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device)\n\n        len_input = len(input_tensor.shape)\n        # C++ extension so far only supports 5-dim inputs.\n        if len_input == 3:\n            input_tensor = 
input_tensor.unsqueeze(3).unsqueeze(4)\n        elif len_input == 4:\n            input_tensor = input_tensor.unsqueeze(4)\n\n        output = TrainableBilateralFilterFunction.apply(input_tensor, *sigmas).cpu().numpy()\n\n        # Make sure to return tensor of the same shape as the input.\n        if len_input == 3:\n            output = output.squeeze(4).squeeze(3)\n        elif len_input == 4:\n            output = output.squeeze(4)\n\n        # Ensure result are as expected.\n        np.testing.assert_allclose(output, expected, atol=1e-5)\n\n    @parameterized.expand(TEST_CASES)\n    def test_cpu_precise_backwards(self, test_case_description, sigmas, input, expected):\n        # Params to determine the implementation to test\n        device = torch.device(\"cpu\")\n\n        # Prepare input tensor\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device)\n        input_tensor.requires_grad = True\n\n        # C++ extension so far only supports 5-dim inputs.\n        len_input = len(input_tensor.shape)\n        if len_input == 3:\n            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)\n        elif len_input == 4:\n            input_tensor = input_tensor.unsqueeze(4)\n\n        # Check gradient toward input.\n        args = (\n            input_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False)\n        input_tensor = input_tensor.detach()\n        input_tensor.requires_grad = False\n\n        # Check gradient toward sigma_x.\n        args = (\n            input_tensor,\n            torch.tensor(sigmas[0], dtype=torch.double, requires_grad=True),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        
gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)\n\n        # Check gradient toward sigma_y.\n        args = (\n            input_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1], dtype=torch.double, requires_grad=True),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)\n\n        # Check gradient toward sigma_z.\n        args = (\n            input_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2], dtype=torch.double, requires_grad=True),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)\n\n        # Check gradient toward sigma_color.\n        args = (\n            input_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3], dtype=torch.double, requires_grad=True),\n        )\n        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-3, atol=1e-3, raise_exception=False)\n\n\n@skip_if_no_cuda\n@skip_if_no_cpp_extension\nclass BilateralFilterTestCaseCudaPrecise(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_cuda_precise(self, test_case_description, sigmas, input, expected):\n        # Skip this test\n        if not torch.cuda.is_available():\n            return\n\n        # Params to determine the implementation to test\n        device = torch.device(\"cuda\")\n\n        # Create input tensor and apply filter\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device)\n\n        len_input = len(input_tensor.shape)\n        # C++ extension so far only supports 5-dim inputs.\n        if 
len_input == 3:\n            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)\n        elif len_input == 4:\n            input_tensor = input_tensor.unsqueeze(4)\n\n        output = TrainableBilateralFilterFunction.apply(input_tensor, *sigmas).cpu().numpy()\n\n        # Make sure to return tensor of the same shape as the input.\n        if len_input == 3:\n            output = output.squeeze(4).squeeze(3)\n        elif len_input == 4:\n            output = output.squeeze(4)\n\n        # Ensure result are as expected.\n        np.testing.assert_allclose(output, expected, atol=1e-5)\n\n    @parameterized.expand(TEST_CASES)\n    def test_cuda_precise_backwards(self, test_case_description, sigmas, input, expected):\n        # Params to determine the implementation to test\n        device = torch.device(\"cuda\")\n\n        # Prepare input tensor\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device)\n        input_tensor.requires_grad = True\n\n        # C++ extension so far only supports 5-dim inputs.\n        len_input = len(input_tensor.shape)\n        if len_input == 3:\n            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)\n        elif len_input == 4:\n            input_tensor = input_tensor.unsqueeze(4)\n\n        # Check gradient toward input.\n        args = (\n            input_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False)\n        input_tensor = input_tensor.detach()\n        input_tensor.requires_grad = False\n\n        # Check gradient toward sigma_x.\n        args = (\n            input_tensor,\n            torch.tensor(sigmas[0], dtype=torch.double, requires_grad=True),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            
torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)\n\n        # Check gradient toward sigma_y.\n        args = (\n            input_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1], dtype=torch.double, requires_grad=True),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)\n\n        # Check gradient toward sigma_z.\n        args = (\n            input_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2], dtype=torch.double, requires_grad=True),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)\n\n        # Check gradient toward sigma_color.\n        args = (\n            input_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3], dtype=torch.double, requires_grad=True),\n        )\n        gradcheck(TrainableBilateralFilterFunction.apply, args, eps=1e-3, atol=1e-3, raise_exception=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/filtering/test_trainable_joint_bilateral.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom torch.autograd import gradcheck\n\nfrom monai.networks.layers.filtering import TrainableJointBilateralFilterFunction\nfrom tests.test_utils import skip_if_no_cpp_extension, skip_if_no_cuda, skip_if_quick\n\nTEST_CASES = [\n    [\n        # Case Description\n        \"1 dimension, 1 channel, low spatial sigmas, low color sigma\",\n        # (sigma_x, sigma_y, sigma_z, color_sigma)\n        (1.0, 1.0, 1.0, 0.2),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Guide\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 1, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 1]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.622459, 0.377540, 0.000001, 0.000001, 0.999997]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.000000, 0.000001, 0.880793, 
0.000002, 0.119203]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, low spatial sigmas, high color sigma\",\n        # (sigma_x, sigma_y, sigma_z, color_sigma)\n        (1, 1, 1, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Guide\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 1, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 1]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.595404, 0.302253, 0.070203, 0.163038, 0.714200]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.043465, 0.158126, 0.536864, 0.182809, 0.092537]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, high spatial sigmas, low color sigma\",\n        # (sigma_x, sigma_y, sigma_z, color_sigma)\n        (4, 4, 4, 0.2),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Guide\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 1, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 1]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.623709, 
0.632901, 0.000003, 0.000003, 0.680336]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.000001, 0.000001, 0.531206, 0.000001, 0.468788]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 dimension, 1 channel, high spatial sigmas, high color sigma\",\n        # (sigma_x, sigma_y, sigma_z, color_sigma)\n        (4, 4, 4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 0, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 0]\n            ],\n        ],\n        # Guide\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 1, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0, 1, 0, 1]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0.464455, 0.463098, 0.276430, 0.275530, 0.478105]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0.134956, 0.138247, 0.293759, 0.141954, 0.281082]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"2 dimensions, 1 channel, high spatial sigmas, high color sigma\",\n        # (sigma_x, sigma_y, sigma_z, color_sigma)\n        (4, 4, 4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]\n            ],\n        ],\n        # Guide\n        [\n            # Batch 0\n            [\n                # 
Channel 0\n                [[1, 1, 0, 0, 1], [0, 0, 0, 1, 0], [1, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 1]]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [[0, 0, 0, 1, 0], [0, 0, 0, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [0.186535, 0.187357, 0.105377, 0.103652, 0.198665],\n                    [0.112617, 0.108847, 0.105970, 0.189602, 0.102954],\n                    [0.178338, 0.179829, 0.107473, 0.105256, 0.103963],\n                    [0.117651, 0.113304, 0.109876, 0.107392, 0.105853],\n                    [0.121557, 0.177689, 0.113150, 0.110388, 0.192877],\n                ]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [\n                    [0.047156, 0.047865, 0.048233, 0.038611, 0.047911],\n                    [0.047607, 0.048292, 0.048633, 0.039251, 0.038611],\n                    [0.047715, 0.048369, 0.048678, 0.048633, 0.048233],\n                    [0.047477, 0.048094, 0.048369, 0.048292, 0.047865],\n                    [0.039190, 0.047477, 0.047715, 0.047607, 0.047156],\n                ]\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"3 dimensions, 1 channel, high spatial sigmas, high color sigma\",\n        # (sigma_x, sigma_y, sigma_z, color_sigma)\n        (4, 4, 4, 0.9),\n        # Input\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]],\n                    # Frame 1\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 2\n                    [[0, 0, 0, 0, 
0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 3\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 4\n                    [[1, 0, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 1]],\n                ]\n            ]\n        ],\n        # Guide\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [[1, 1, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 1]],\n                    # Frame 1\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 2\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 3\n                    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],\n                    # Frame 4\n                    [[1, 1, 0, 1, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 1, 1]],\n                ]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Frame 0\n                    [\n                        [0.089316, 0.088903, 0.033707, 0.033461, 0.091881],\n                        [0.035173, 0.034324, 0.033747, 0.033448, 0.033427],\n                        [0.035619, 0.034710, 0.034074, 0.033720, 0.033646],\n                        [0.036364, 0.035387, 0.034688, 0.034275, 0.034148],\n                        [0.037401, 0.085687, 0.035583, 0.035109, 0.089741],\n                    ],\n                    # Frame 1\n                    [\n                        [0.034248, 0.033502, 0.033023, 0.032816, 0.032881],\n                        [0.034339, 
0.033546, 0.033020, 0.032767, 0.032785],\n                        [0.034721, 0.033876, 0.033298, 0.032994, 0.032965],\n                        [0.035397, 0.034490, 0.033856, 0.033501, 0.033424],\n                        [0.036357, 0.035383, 0.034688, 0.034279, 0.034155],\n                    ],\n                    # Frame 2\n                    [\n                        [0.033748, 0.033047, 0.032609, 0.032441, 0.032541],\n                        [0.033782, 0.033041, 0.032562, 0.032353, 0.032410],\n                        [0.034104, 0.033316, 0.032792, 0.032538, 0.032554],\n                        [0.034714, 0.033872, 0.033298, 0.032998, 0.032972],\n                        [0.035604, 0.034702, 0.034074, 0.033727, 0.033660],\n                    ],\n                    # Frame 3\n                    [\n                        [0.033533, 0.032871, 0.032471, 0.032340, 0.032476],\n                        [0.033511, 0.032815, 0.032380, 0.032212, 0.032310],\n                        [0.033775, 0.033037, 0.032562, 0.032356, 0.032417],\n                        [0.034324, 0.033539, 0.033020, 0.032774, 0.032799],\n                        [0.035151, 0.034313, 0.033747, 0.033459, 0.033449],\n                    ],\n                    # Frame 4\n                    [\n                        [0.091383, 0.090681, 0.032608, 0.091418, 0.092851],\n                        [0.033525, 0.032867, 0.032471, 0.032344, 0.032483],\n                        [0.033733, 0.033039, 0.032609, 0.032448, 0.032555],\n                        [0.034226, 0.033491, 0.033023, 0.032827, 0.032903],\n                        [0.089445, 0.034216, 0.033707, 0.090126, 0.091748],\n                    ],\n                ]\n            ]\n        ],\n    ],\n]\n\n\n@skip_if_no_cpp_extension\n@skip_if_quick\nclass JointBilateralFilterTestCaseCpuPrecise(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_cpu_precise(self, test_case_description, sigmas, input, guide, expected):\n        # Params 
to determine the implementation to test\n        device = torch.device(\"cpu\")\n\n        # Create input tensor and apply filter\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device)\n        guide_tensor = torch.from_numpy(np.array(guide)).to(dtype=torch.double, device=device)\n\n        len_input = len(input_tensor.shape)\n        # C++ extension so far only supports 5-dim inputs.\n        if len_input == 3:\n            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)\n            guide_tensor = guide_tensor.unsqueeze(3).unsqueeze(4)\n        elif len_input == 4:\n            input_tensor = input_tensor.unsqueeze(4)\n            guide_tensor = guide_tensor.unsqueeze(4)\n\n        output = TrainableJointBilateralFilterFunction.apply(input_tensor, guide_tensor, *sigmas).cpu().numpy()\n\n        # Make sure to return tensor of the same shape as the input.\n        if len_input == 3:\n            output = output.squeeze(4).squeeze(3)\n        elif len_input == 4:\n            output = output.squeeze(4)\n\n        # Ensure result are as expected.\n        np.testing.assert_allclose(output, expected, atol=1e-5)\n\n    @parameterized.expand(TEST_CASES)\n    def test_cpu_precise_backwards(self, test_case_description, sigmas, input, guide, expected):\n        # Params to determine the implementation to test\n        device = torch.device(\"cpu\")\n\n        # Prepare input tensor\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device)\n        input_tensor.requires_grad = True\n        guide_tensor = torch.from_numpy(np.array(guide)).to(dtype=torch.double, device=device)\n\n        # C++ extension so far only supports 5-dim inputs.\n        len_input = len(input_tensor.shape)\n        if len_input == 3:\n            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)\n            guide_tensor = guide_tensor.unsqueeze(3).unsqueeze(4)\n        elif len_input == 4:\n            
input_tensor = input_tensor.unsqueeze(4)\n            guide_tensor = guide_tensor.unsqueeze(4)\n\n        # Check gradient toward input.\n        args = (\n            input_tensor,\n            guide_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False)\n        input_tensor = input_tensor.detach()\n        input_tensor.requires_grad = False\n\n        # Check gradient toward guide.\n        guide_tensor.requires_grad = True\n        args = (\n            input_tensor,\n            guide_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False)\n        guide_tensor = guide_tensor.detach()\n        guide_tensor.guide_tensor = False\n\n        # Check gradient toward sigma_x.\n        args = (\n            input_tensor,\n            guide_tensor,\n            torch.tensor(sigmas[0], dtype=torch.double, requires_grad=True),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)\n\n        # Check gradient toward sigma_y.\n        args = (\n            input_tensor,\n            guide_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1], dtype=torch.double, requires_grad=True),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)\n\n        # Check gradient toward 
sigma_z.\n        args = (\n            input_tensor,\n            guide_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2], dtype=torch.double, requires_grad=True),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)\n\n        # Check gradient toward sigma_color.\n        args = (\n            input_tensor,\n            guide_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3], dtype=torch.double, requires_grad=True),\n        )\n        gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-3, atol=1e-3, raise_exception=False)\n\n\n@skip_if_no_cuda\n@skip_if_no_cpp_extension\nclass JointBilateralFilterTestCaseCudaPrecise(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_cuda_precise(self, test_case_description, sigmas, input, guide, expected):\n        # Skip this test\n        if not torch.cuda.is_available():\n            return\n\n        # Params to determine the implementation to test\n        device = torch.device(\"cuda\")\n\n        # Create input tensor and apply filter\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device)\n        guide_tensor = torch.from_numpy(np.array(guide)).to(dtype=torch.double, device=device)\n\n        len_input = len(input_tensor.shape)\n        # C++ extension so far only supports 5-dim inputs.\n        if len_input == 3:\n            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)\n            guide_tensor = guide_tensor.unsqueeze(3).unsqueeze(4)\n        elif len_input == 4:\n            input_tensor = input_tensor.unsqueeze(4)\n            guide_tensor = guide_tensor.unsqueeze(4)\n\n        output = TrainableJointBilateralFilterFunction.apply(input_tensor, 
guide_tensor, *sigmas).cpu().numpy()\n\n        # Make sure to return tensor of the same shape as the input.\n        if len_input == 3:\n            output = output.squeeze(4).squeeze(3)\n        elif len_input == 4:\n            output = output.squeeze(4)\n\n        # Ensure result are as expected.\n        np.testing.assert_allclose(output, expected, atol=1e-5)\n\n    @parameterized.expand(TEST_CASES)\n    def test_cuda_precise_backwards(self, test_case_description, sigmas, input, guide, expected):\n        # Params to determine the implementation to test\n        device = torch.device(\"cuda\")\n\n        # Prepare input tensor\n        input_tensor = torch.from_numpy(np.array(input)).to(dtype=torch.double, device=device)\n        input_tensor.requires_grad = True\n        guide_tensor = torch.from_numpy(np.array(guide)).to(dtype=torch.double, device=device)\n\n        # C++ extension so far only supports 5-dim inputs.\n        len_input = len(input_tensor.shape)\n        if len_input == 3:\n            input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)\n            guide_tensor = guide_tensor.unsqueeze(3).unsqueeze(4)\n        elif len_input == 4:\n            input_tensor = input_tensor.unsqueeze(4)\n            guide_tensor = guide_tensor.unsqueeze(4)\n\n        # Check gradient toward input.\n        args = (\n            input_tensor,\n            guide_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False)\n        input_tensor = input_tensor.detach()\n        input_tensor.requires_grad = False\n\n        # Check gradient toward guide.\n        guide_tensor.requires_grad = True\n        args = (\n            input_tensor,\n            guide_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n   
         torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-6, atol=1e-5, raise_exception=False)\n        guide_tensor = guide_tensor.detach()\n        guide_tensor.guide_tensor = False\n\n        # Check gradient toward sigma_x.\n        args = (\n            input_tensor,\n            guide_tensor,\n            torch.tensor(sigmas[0], dtype=torch.double, requires_grad=True),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)\n\n        # Check gradient toward sigma_y.\n        args = (\n            input_tensor,\n            guide_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1], dtype=torch.double, requires_grad=True),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)\n\n        # Check gradient toward sigma_z.\n        args = (\n            input_tensor,\n            guide_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2], dtype=torch.double, requires_grad=True),\n            torch.tensor(sigmas[3]),\n        )\n        gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-2, atol=1e-3, raise_exception=False)\n\n        # Check gradient toward sigma_color.\n        args = (\n            input_tensor,\n            guide_tensor,\n            torch.tensor(sigmas[0]),\n            torch.tensor(sigmas[1]),\n            torch.tensor(sigmas[2]),\n            torch.tensor(sigmas[3], dtype=torch.double, requires_grad=True),\n        )\n        gradcheck(TrainableJointBilateralFilterFunction.apply, args, eps=1e-3, atol=1e-3, 
raise_exception=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_affine_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import normalize_transform, to_norm_affine\nfrom monai.networks.layers import AffineTransform\nfrom tests.test_utils import is_tf32_env\n\n_rtol = 1e-4 if not is_tf32_env() else 5e-3\n\nTEST_NORM_CASES = [\n    [(4, 5), True, [[[0.666667, 0, -1], [0, 0.5, -1], [0, 0, 1]]]],\n    [(4, 5), True, [[[0.5, 0, 0], [0, 0.4, 0], [0, 0, 1]]], True],\n    [\n        (2, 4, 5),\n        True,\n        [[[2.0, 0.0, 0.0, -1.0], [0.0, 0.6666667, 0.0, -1.0], [0.0, 0.0, 0.5, -1.0], [0.0, 0.0, 0.0, 1.0]]],\n    ],\n    [(4, 5), False, [[[0.5, 0.0, -0.75], [0.0, 0.4, -0.8], [0.0, 0.0, 1.0]]]],\n    [(4, 5), False, [[[0.6666667, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 1.0]]], True],\n    [(2, 4, 5), False, [[[1.0, 0.0, 0.0, -0.5], [0.0, 0.5, 0.0, -0.75], [0.0, 0.0, 0.4, -0.8], [0.0, 0.0, 0.0, 1.0]]]],\n]\n\nTEST_TO_NORM_AFFINE_CASES = [\n    [\n        [[[1, 0, 0], [0, 1, 0], [0, 0, 1]]],\n        (4, 6),\n        (5, 3),\n        True,\n        [[[1.3333334, 0.0, 0.33333337], [0.0, 0.4, -0.6], [0.0, 0.0, 1.0]]],\n    ],\n    [\n        [[[1, 0, 0], [0, 1, 0], [0, 0, 1]]],\n        (4, 6),\n        (5, 3),\n        False,\n        [[[1.25, 0.0, 0.25], [0.0, 0.5, -0.5], [0.0, 0.0, 1.0]]],\n    ],\n    [\n        [[[1, 0, 0, 0], [0, 1, 0, 0], 
[0, 0, 1, 0], [0, 0, 0, 1]]],\n        (2, 4, 6),\n        (3, 5, 3),\n        True,\n        [[[2.0, 0.0, 0.0, 1.0], [0.0, 1.3333334, 0.0, 0.33333337], [0.0, 0.0, 0.4, -0.6], [0.0, 0.0, 0.0, 1.0]]],\n    ],\n    [\n        [[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]],\n        (2, 4, 6),\n        (3, 5, 3),\n        False,\n        [[[1.5, 0.0, 0.0, 0.5], [0.0, 1.25, 0.0, 0.25], [0.0, 0.0, 0.5, -0.5], [0.0, 0.0, 0.0, 1.0]]],\n    ],\n    [\n        [[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]],\n        (2, 4, 6),\n        (3, 5, 3),\n        False,\n        [[[2.0, 0.0, 0.0, 0.0], [0.0, 1.3333334, 0.0, 0.0], [0.0, 0.0, 0.4, 0.0], [0.0, 0.0, 0.0, 1.0]]],\n        True,\n    ],\n]\n\nTEST_ILL_TO_NORM_AFFINE_CASES = [\n    [[[[1, 0, 0], [0, 1, 0], [0, 0, 1]]], (3, 4, 6), (3, 5, 3), False],\n    [[[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]], (4, 6), (3, 5, 3), True],\n    [[[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]], (4, 6), (3, 5, 3), True],\n]\n\n\nclass TestNormTransform(unittest.TestCase):\n    @parameterized.expand(TEST_NORM_CASES)\n    def test_norm_xform(self, input_shape, align_corners, expected, zero_centered=False):\n        norm = normalize_transform(\n            input_shape,\n            device=torch.device(\"cpu:0\"),\n            dtype=torch.float32,\n            align_corners=align_corners,\n            zero_centered=zero_centered,\n        )\n        norm = norm.detach().cpu().numpy()\n        np.testing.assert_allclose(norm, expected, atol=1e-6)\n        if torch.cuda.is_available():\n            norm = normalize_transform(\n                input_shape,\n                device=torch.device(\"cuda:0\"),\n                dtype=torch.float32,\n                align_corners=align_corners,\n                zero_centered=zero_centered,\n            )\n            norm = norm.detach().cpu().numpy()\n            np.testing.assert_allclose(norm, expected, atol=1e-4)\n\n\nclass TestToNormAffine(unittest.TestCase):\n   
 @parameterized.expand(TEST_TO_NORM_AFFINE_CASES)\n    def test_to_norm_affine(self, affine, src_size, dst_size, align_corners, expected, zero_centered=False):\n        affine = torch.as_tensor(affine, device=torch.device(\"cpu:0\"), dtype=torch.float32)\n        new_affine = to_norm_affine(affine, src_size, dst_size, align_corners, zero_centered)\n        new_affine = new_affine.detach().cpu().numpy()\n        np.testing.assert_allclose(new_affine, expected, atol=1e-6)\n\n        if torch.cuda.is_available():\n            affine = torch.as_tensor(affine, device=torch.device(\"cuda:0\"), dtype=torch.float32)\n            new_affine = to_norm_affine(affine, src_size, dst_size, align_corners, zero_centered)\n            new_affine = new_affine.detach().cpu().numpy()\n            np.testing.assert_allclose(new_affine, expected, atol=1e-5, rtol=_rtol)\n\n    @parameterized.expand(TEST_ILL_TO_NORM_AFFINE_CASES)\n    def test_to_norm_affine_ill(self, affine, src_size, dst_size, align_corners):\n        with self.assertRaises(TypeError):\n            to_norm_affine(affine, src_size, dst_size, align_corners)\n        with self.assertRaises(ValueError):\n            affine = torch.as_tensor(affine, device=torch.device(\"cpu:0\"), dtype=torch.float32)\n            to_norm_affine(affine, src_size, dst_size, align_corners)\n\n\nclass TestAffineTransform(unittest.TestCase):\n    @parameterized.expand(\n        [\n            (torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]]), [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]]),\n            (torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]), [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]]),\n            (torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]), [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]]),\n        ]\n    )\n    def test_affine_transforms(self, affine, expected):\n        image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])\n        out = 
AffineTransform(align_corners=False)(image, affine)\n        out = out.detach().cpu().numpy()\n        np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)\n\n    def test_zoom(self):\n        affine = torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])\n        image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device(\"cpu:0\"))\n        out = AffineTransform((3, 2), align_corners=False)(image, affine)\n        expected = [[[[1, 3], [5, 7], [9, 11]]]]\n        np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)\n\n    def test_zoom_1(self):\n        affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]])\n        image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device(\"cpu:0\"))\n        out = AffineTransform()(image, affine, (1, 4))\n        expected = [[[[2.333333, 3.333333, 4.333333, 5.333333]]]]\n        np.testing.assert_allclose(out, expected, atol=_rtol)\n\n    def test_zoom_2(self):\n        affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float32)\n        image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device(\"cpu:0\"))\n        out = AffineTransform((1, 2))(image, affine)\n        expected = [[[[1.458333, 4.958333]]]]\n        np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)\n\n    def test_zoom_zero_center(self):\n        affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float32)\n        image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device(\"cpu:0\"))\n        out = AffineTransform((1, 2), zero_centered=True)(image, affine)\n        expected = [[[[5.5, 7.5]]]]\n        np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)\n\n    def test_affine_transform_minimum(self):\n        t = np.pi / 3\n        affine = [[np.cos(t), -np.sin(t), 0], [np.sin(t), np.cos(t), 0], [0, 0, 1]]\n        affine = torch.as_tensor(affine, device=torch.device(\"cpu:0\"), dtype=torch.float32)\n        
image = torch.arange(24.0).view(1, 1, 4, 6).to(device=torch.device(\"cpu:0\"))\n        out = AffineTransform(align_corners=False)(image, affine)\n        out = out.detach().cpu().numpy()\n        expected = [\n            [\n                [\n                    [0.0, 0.06698727, 0.0, 0.0, 0.0, 0.0],\n                    [3.8660254, 0.86602557, 0.0, 0.0, 0.0, 0.0],\n                    [7.732051, 3.035899, 0.73205125, 0.0, 0.0, 0.0],\n                    [11.598076, 6.901923, 2.7631402, 0.0, 0.0, 0.0],\n                ]\n            ]\n        ]\n        np.testing.assert_allclose(out, expected, atol=1e-3, rtol=_rtol)\n\n    def test_affine_transform_2d(self):\n        t = np.pi / 3\n        affine = [[np.cos(t), -np.sin(t), 0], [np.sin(t), np.cos(t), 0], [0, 0, 1]]\n        affine = torch.as_tensor(affine, device=torch.device(\"cpu:0\"), dtype=torch.float32)\n        image = torch.arange(24.0).view(1, 1, 4, 6).to(device=torch.device(\"cpu:0\"))\n        xform = AffineTransform((3, 4), padding_mode=\"border\", align_corners=False, mode=\"bilinear\")\n        out = xform(image, affine)\n        out = out.detach().cpu().numpy()\n        expected = [\n            [\n                [\n                    [7.1525574e-07, 4.9999994e-01, 1.0000000e00, 1.4999999e00],\n                    [3.8660259e00, 1.3660253e00, 1.8660252e00, 2.3660252e00],\n                    [7.7320518e00, 3.0358994e00, 2.7320509e00, 3.2320507e00],\n                ]\n            ]\n        ]\n        np.testing.assert_allclose(out, expected, atol=1e-3, rtol=_rtol)\n\n        if torch.cuda.is_available():\n            affine = torch.as_tensor(affine, device=torch.device(\"cuda:0\"), dtype=torch.float32)\n            image = torch.arange(24.0).view(1, 1, 4, 6).to(device=torch.device(\"cuda:0\"))\n            xform = AffineTransform(padding_mode=\"border\", align_corners=False, mode=\"bilinear\")\n            out = xform(image, affine, (3, 4))\n            out = out.detach().cpu().numpy()\n        
    expected = [\n                [\n                    [\n                        [7.1525574e-07, 4.9999994e-01, 1.0000000e00, 1.4999999e00],\n                        [3.8660259e00, 1.3660253e00, 1.8660252e00, 2.3660252e00],\n                        [7.7320518e00, 3.0358994e00, 2.7320509e00, 3.2320507e00],\n                    ]\n                ]\n            ]\n            np.testing.assert_allclose(out, expected, atol=5e-3)\n\n    def test_affine_transform_3d(self):\n        t = np.pi / 3\n        affine = [[1, 0, 0, 0], [0.0, np.cos(t), -np.sin(t), 0], [0, np.sin(t), np.cos(t), 0], [0, 0, 0, 1]]\n        affine = torch.as_tensor(affine, device=torch.device(\"cpu:0\"), dtype=torch.float32)\n        image = torch.arange(48.0).view(2, 1, 4, 2, 3).to(device=torch.device(\"cpu:0\"))\n        xform = AffineTransform((3, 4, 2), padding_mode=\"border\", align_corners=False, mode=\"bilinear\")\n        out = xform(image, affine)\n        out = out.detach().cpu().numpy()\n        expected = [\n            [\n                [\n                    [[0.00000006, 0.5000001], [2.3660254, 1.3660254], [4.732051, 2.4019241], [5.0, 3.9019237]],\n                    [[6.0, 6.5], [8.366026, 7.3660254], [10.732051, 8.401924], [11.0, 9.901924]],\n                    [[12.0, 12.5], [14.366026, 13.366025], [16.732052, 14.401924], [17.0, 15.901923]],\n                ]\n            ],\n            [\n                [\n                    [[24.0, 24.5], [26.366024, 25.366024], [28.732052, 26.401924], [29.0, 27.901924]],\n                    [[30.0, 30.5], [32.366028, 31.366026], [34.732048, 32.401924], [35.0, 33.901924]],\n                    [[36.0, 36.5], [38.366024, 37.366024], [40.73205, 38.401924], [41.0, 39.901924]],\n                ]\n            ],\n        ]\n        np.testing.assert_allclose(out, expected, atol=1e-4, rtol=_rtol)\n\n        if torch.cuda.is_available():\n            affine = torch.as_tensor(affine, device=torch.device(\"cuda:0\"), dtype=torch.float32)\n    
        image = torch.arange(48.0).view(2, 1, 4, 2, 3).to(device=torch.device(\"cuda:0\"))\n            xform = AffineTransform(padding_mode=\"border\", align_corners=False, mode=\"bilinear\")\n            out = xform(image, affine, (3, 4, 2))\n            out = out.detach().cpu().numpy()\n            expected = [\n                [\n                    [\n                        [[0.00000006, 0.5000001], [2.3660254, 1.3660254], [4.732051, 2.4019241], [5.0, 3.9019237]],\n                        [[6.0, 6.5], [8.366026, 7.3660254], [10.732051, 8.401924], [11.0, 9.901924]],\n                        [[12.0, 12.5], [14.366026, 13.366025], [16.732052, 14.401924], [17.0, 15.901923]],\n                    ]\n                ],\n                [\n                    [\n                        [[24.0, 24.5], [26.366024, 25.366024], [28.732052, 26.401924], [29.0, 27.901924]],\n                        [[30.0, 30.5], [32.366028, 31.366026], [34.732048, 32.401924], [35.0, 33.901924]],\n                        [[36.0, 36.5], [38.366024, 37.366024], [40.73205, 38.401924], [41.0, 39.901924]],\n                    ]\n                ],\n            ]\n            np.testing.assert_allclose(out, expected, atol=5e-3)\n\n    def test_ill_affine_transform(self):\n        with self.assertRaises(ValueError):  # image too small\n            t = np.pi / 3\n            affine = [[1, 0, 0, 0], [0.0, np.cos(t), -np.sin(t), 0], [0, np.sin(t), np.cos(t), 0], [0, 0, 0, 1]]\n            affine = torch.as_tensor(affine, device=torch.device(\"cpu:0\"), dtype=torch.float32)\n            xform = AffineTransform((3, 4, 2), padding_mode=\"border\", align_corners=False, mode=\"bilinear\")\n            xform(torch.as_tensor([1.0, 2.0, 3.0]), affine)\n\n        with self.assertRaises(ValueError):  # output shape too small\n            t = np.pi / 3\n            affine = [[1, 0, 0, 0], [0.0, np.cos(t), -np.sin(t), 0], [0, np.sin(t), np.cos(t), 0], [0, 0, 0, 1]]\n            affine = torch.as_tensor(affine, 
device=torch.device(\"cpu:0\"), dtype=torch.float32)\n            image = torch.arange(48).view(2, 1, 4, 2, 3).to(device=torch.device(\"cpu:0\"))\n            xform = AffineTransform((3, 4), padding_mode=\"border\", align_corners=False, mode=\"bilinear\")\n            xform(image, affine)\n\n        with self.assertRaises(ValueError):  # incorrect affine\n            t = np.pi / 3\n            affine = [[1, 0, 0, 0], [0.0, np.cos(t), -np.sin(t), 0], [0, np.sin(t), np.cos(t), 0], [0, 0, 0, 1]]\n            affine = torch.as_tensor(affine, device=torch.device(\"cpu:0\"), dtype=torch.float32)\n            affine = affine.unsqueeze(0).unsqueeze(0)\n            image = torch.arange(48).view(2, 1, 4, 2, 3).to(device=torch.device(\"cpu:0\"))\n            xform = AffineTransform((2, 3, 4), padding_mode=\"border\", align_corners=False, mode=\"bilinear\")\n            xform(image, affine)\n\n        with self.assertRaises(ValueError):  # batch doesn't match\n            t = np.pi / 3\n            affine = [[1, 0, 0, 0], [0.0, np.cos(t), -np.sin(t), 0], [0, np.sin(t), np.cos(t), 0], [0, 0, 0, 1]]\n            affine = torch.as_tensor(affine, device=torch.device(\"cpu:0\"), dtype=torch.float32)\n            affine = affine.unsqueeze(0)\n            affine = affine.repeat(3, 1, 1)\n            image = torch.arange(48).view(2, 1, 4, 2, 3).to(device=torch.device(\"cpu:0\"))\n            xform = AffineTransform((2, 3, 4), padding_mode=\"border\", align_corners=False, mode=\"bilinear\")\n            xform(image, affine)\n\n        with self.assertRaises(RuntimeError):  # input grid dtypes different\n            t = np.pi / 3\n            affine = [[1, 0, 0, 0], [0.0, np.cos(t), -np.sin(t), 0], [0, np.sin(t), np.cos(t), 0], [0, 0, 0, 1]]\n            affine = torch.as_tensor(affine, device=torch.device(\"cpu:0\"), dtype=torch.float32)\n            affine = affine.unsqueeze(0)\n            affine = affine.repeat(2, 1, 1)\n            image = torch.arange(48).view(2, 1, 4, 2, 
3).to(device=torch.device(\"cpu:0\"), dtype=torch.int32)\n            xform = AffineTransform((2, 3, 4), padding_mode=\"border\", mode=\"bilinear\", normalized=True)\n            xform(image, affine)\n\n        with self.assertRaises(ValueError):  # wrong affine\n            affine = torch.as_tensor([[1, 0, 0, 0], [0, 0, 0, 1]])\n            image = torch.arange(48).view(2, 1, 4, 2, 3).to(device=torch.device(\"cpu:0\"))\n            xform = AffineTransform((2, 3, 4), padding_mode=\"border\", align_corners=False, mode=\"bilinear\")\n            xform(image, affine)\n\n        with self.assertRaises(RuntimeError):  # dtype doesn't match\n            affine = torch.as_tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float64)\n            image = torch.arange(1.0, 13.0).view(1, 1, 3, 4).to(device=torch.device(\"cpu:0\"))\n            AffineTransform((1, 2))(image, affine)\n\n    def test_forward_2d(self):\n        x = torch.rand(2, 1, 4, 4)\n        theta = torch.Tensor([[[0, -1, 0], [1, 0, 0]]]).repeat(2, 1, 1)\n        grid = torch.nn.functional.affine_grid(theta, x.size(), align_corners=False)\n        expected = torch.nn.functional.grid_sample(x, grid, align_corners=False)\n        expected = expected.detach().cpu().numpy()\n\n        actual = AffineTransform(normalized=True, reverse_indexing=False, align_corners=False)(x, theta)\n        actual = actual.detach().cpu().numpy()\n        np.testing.assert_allclose(actual, expected)\n        np.testing.assert_allclose(list(theta.shape), [2, 2, 3])\n\n        theta = torch.Tensor([[0, -1, 0], [1, 0, 0]])\n        actual = AffineTransform(normalized=True, reverse_indexing=False, align_corners=False)(x, theta)\n        actual = actual.detach().cpu().numpy()\n        np.testing.assert_allclose(actual, expected)\n        np.testing.assert_allclose(list(theta.shape), [2, 3])\n\n        theta = torch.Tensor([[[0, -1, 0], [1, 0, 0]]])\n        actual = AffineTransform(normalized=True, reverse_indexing=False, 
align_corners=False)(x, theta)\n        actual = actual.detach().cpu().numpy()\n        np.testing.assert_allclose(actual, expected)\n        np.testing.assert_allclose(list(theta.shape), [1, 2, 3])\n\n    def test_forward_3d(self):\n        x = torch.rand(2, 1, 4, 4, 4)\n        theta = torch.Tensor([[[0, 0, -1, 0], [1, 0, 0, 0], [0, 0, 1, 0]]]).repeat(2, 1, 1)\n        grid = torch.nn.functional.affine_grid(theta, x.size(), align_corners=False)\n        expected = torch.nn.functional.grid_sample(x, grid, align_corners=False)\n        expected = expected.detach().cpu().numpy()\n\n        actual = AffineTransform(normalized=True, reverse_indexing=False, align_corners=False)(x, theta)\n        actual = actual.detach().cpu().numpy()\n        np.testing.assert_allclose(actual, expected)\n        np.testing.assert_allclose(list(theta.shape), [2, 3, 4])\n\n        theta = torch.Tensor([[0, 0, -1, 0], [1, 0, 0, 0], [0, 0, 1, 0]])\n        actual = AffineTransform(normalized=True, reverse_indexing=False, align_corners=False)(x, theta)\n        actual = actual.detach().cpu().numpy()\n        np.testing.assert_allclose(actual, expected)\n        np.testing.assert_allclose(list(theta.shape), [3, 4])\n\n        theta = torch.Tensor([[[0, 0, -1, 0], [1, 0, 0, 0], [0, 0, 1, 0]]])\n        actual = AffineTransform(normalized=True, reverse_indexing=False, align_corners=False)(x, theta)\n        actual = actual.detach().cpu().numpy()\n        np.testing.assert_allclose(actual, expected)\n        np.testing.assert_allclose(list(theta.shape), [1, 3, 4])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_apply_filter.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.networks.layers import apply_filter\n\n\nclass ApplyFilterTestCase(unittest.TestCase):\n\n    def test_1d(self):\n        a = torch.tensor([[list(range(10))]], dtype=torch.float)\n        out = apply_filter(a, torch.tensor([-1, 0, 1]), stride=1)\n        expected = np.array([[[1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, -8.0]]])\n        np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4)\n        if torch.cuda.is_available():\n            out = apply_filter(a.cuda(), torch.tensor([-1, 0, 1]).cuda())\n            np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4)\n\n    def test_2d(self):\n        a = torch.tensor([[[list(range(7)), list(range(7, 0, -1)), list(range(7))]]], dtype=torch.float)\n        expected = np.array(\n            [\n                [14.0, 21.0, 21.0, 21.0, 21.0, 21.0, 14.0],\n                [15.0, 24.0, 27.0, 30.0, 33.0, 36.0, 25.0],\n                [14.0, 21.0, 21.0, 21.0, 21.0, 21.0, 14.0],\n            ]\n        )\n        expected = expected[None][None]\n        out = apply_filter(a, torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]))\n        np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4)\n        if torch.cuda.is_available():\n            out = apply_filter(a.cuda(), torch.tensor([[1, 1, 1], 
[1, 1, 1], [1, 1, 1]]).cuda())\n            np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4)\n\n    def test_3d(self):\n        a = torch.tensor(\n            [[list(range(7)), list(range(7)), list(range(7))], [list(range(7)), list(range(7)), list(range(7))]],\n            dtype=torch.float,\n        )\n        a = a[None][None]\n        a = a.expand(2, 3, -1, -1, -1)\n        expected = np.array(\n            [\n                [\n                    [2.0, 6.0, 12.0, 18.0, 24.0, 30.0, 22.0],\n                    [3.0, 9.0, 18.0, 27.0, 36.0, 45.0, 33.0],\n                    [2.0, 6.0, 12.0, 18.0, 24.0, 30.0, 22.0],\n                ],\n                [\n                    [2.0, 6.0, 12.0, 18.0, 24.0, 30.0, 22.0],\n                    [3.0, 9.0, 18.0, 27.0, 36.0, 45.0, 33.0],\n                    [2.0, 6.0, 12.0, 18.0, 24.0, 30.0, 22.0],\n                ],\n            ]\n        )\n        # testing shapes\n        k = torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]]])\n        for kernel in (k, k[None], k[None][None]):\n            out = apply_filter(a, kernel)\n            np.testing.assert_allclose(out.cpu().numpy()[1][2], expected, rtol=1e-4)\n            if torch.cuda.is_available():\n                out = apply_filter(a.cuda(), kernel.cuda())\n                np.testing.assert_allclose(out.cpu().numpy()[0][1], expected, rtol=1e-4)\n\n    def test_wrong_args(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            apply_filter(torch.ones((1, 2, 3, 2)), torch.ones((2,)))\n        with self.assertRaisesRegex(NotImplementedError, \"\"):\n            apply_filter(torch.ones((1, 1, 1, 2, 3, 2)), torch.ones((2,)))\n        with self.assertRaisesRegex(TypeError, \"\"):\n            apply_filter(((1, 1, 1, 2, 3, 2)), torch.ones((2,)))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_channel_pad.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.layers import ChannelPad\n\nTEST_CASES_3D = []\nfor type_1 in (\"pad\", \"project\"):\n    input_shape = (16, 10, 32, 24, 48)\n    out_chns = 13\n    result_shape = list(input_shape)\n    result_shape[1] = out_chns\n    test_case = [\n        {\"spatial_dims\": 3, \"in_channels\": 10, \"out_channels\": out_chns, \"mode\": type_1},\n        input_shape,\n        result_shape,\n    ]\n    TEST_CASES_3D.append(test_case)\n\n\nclass TestChannelPad(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_3D)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = ChannelPad(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(list(result.shape), list(expected_shape))\n\n    def test_wrong_mode(self):\n        with self.assertRaises(ValueError):\n            ChannelPad(3, 10, 20, mode=\"test\")(torch.randn(10, 10))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_conjugate_gradient.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\n\nfrom monai.networks.layers import ConjugateGradient\n\n\nclass TestConjugateGradient(unittest.TestCase):\n\n    def test_real_valued_inverse(self):\n        \"\"\"Test ConjugateGradient with real-valued input: when the input is real\n        value, the output should be the inverse of the matrix.\"\"\"\n        a_dim = 3\n        a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.float)\n\n        def a_op(x):\n            return a_mat @ x\n\n        cg_solver = ConjugateGradient(a_op, num_iter=100)\n        # define the measurement\n        y = torch.tensor([1, 2, 3], dtype=torch.float)\n        # solve for x\n        x = cg_solver(torch.zeros(a_dim), y)\n        x_ref = torch.linalg.solve(a_mat, y)\n        # assert torch.allclose(x, x_ref, atol=1e-6), 'CG solver failed to converge to reference solution'\n        self.assertTrue(torch.allclose(x, x_ref, atol=1e-6))\n\n    def test_complex_valued_inverse(self):\n        a_dim = 3\n        a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.complex64)\n\n        def a_op(x):\n            return a_mat @ x\n\n        cg_solver = ConjugateGradient(a_op, num_iter=100)\n        y = torch.tensor([1, 2, 3], dtype=torch.complex64)\n        x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y)\n        x_ref = 
torch.linalg.solve(a_mat, y)\n        self.assertTrue(torch.allclose(x, x_ref, atol=1e-6))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_drop_path.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers import DropPath\n\nTEST_CASES = [\n    [{\"drop_prob\": 0.0, \"scale_by_keep\": True}, (1, 8, 8)],\n    [{\"drop_prob\": 0.7, \"scale_by_keep\": False}, (2, 16, 16, 16)],\n    [{\"drop_prob\": 0.3, \"scale_by_keep\": True}, (6, 16, 12)],\n]\n\nTEST_ERRORS = [[{\"drop_prob\": 2, \"scale_by_keep\": False}, (1, 24, 6)]]\n\n\nclass TestDropPath(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_shape):\n        im = torch.rand(input_shape)\n        dr_path = DropPath(**input_param)\n        out = dr_path(im)\n        self.assertEqual(out.shape, input_shape)\n\n    @parameterized.expand(TEST_ERRORS)\n    def test_ill_arg(self, input_param, input_shape):\n        with self.assertRaises(ValueError):\n            DropPath(**input_param)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_gaussian.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers.convutils import gaussian_1d\n\nTEST_CASES_NORM_F = [\n    [\n        0.5,\n        [\n            [\n                0.0000000e00,\n                0.0000000e00,\n                3.5762787e-07,\n                2.0313263e-04,\n                1.6743928e-02,\n                2.2280261e-01,\n                5.2049994e-01,\n                2.2280261e-01,\n                1.6743928e-02,\n                2.0313263e-04,\n                3.5762787e-07,\n                0.0000000e00,\n                0.0000000e00,\n            ],\n            [\n                1.3086457e-16,\n                7.8354033e-12,\n                6.3491058e-08,\n                6.9626461e-05,\n                1.0333488e-02,\n                2.0755373e-01,\n                5.6418961e-01,\n                2.0755373e-01,\n                1.0333488e-02,\n                6.9626461e-05,\n                6.3491058e-08,\n                7.8354033e-12,\n                1.3086457e-16,\n            ],\n            [\n                2.0750829e-07,\n                4.9876030e-06,\n                9.9959565e-05,\n                1.6043411e-03,\n                1.9352052e-02,\n                1.5642078e-01,\n                6.4503527e-01,\n                
1.5642078e-01,\n                1.9352052e-02,\n                1.6043411e-03,\n                9.9959565e-05,\n                4.9876030e-06,\n                2.0750829e-07,\n            ],\n        ],\n    ],\n    [\n        1.0,\n        [\n            [\n                2.9802322e-08,\n                3.3676624e-06,\n                2.2923946e-04,\n                5.9770346e-03,\n                6.0597539e-02,\n                2.4173033e-01,\n                3.8292491e-01,\n                2.4173033e-01,\n                6.0597539e-02,\n                5.9770346e-03,\n                2.2923946e-04,\n                3.3676624e-06,\n                2.9802322e-08,\n            ],\n            [\n                6.0758829e-09,\n                1.4867196e-06,\n                1.3383022e-04,\n                4.4318484e-03,\n                5.3990968e-02,\n                2.4197073e-01,\n                3.9894229e-01,\n                2.4197073e-01,\n                5.3990968e-02,\n                4.4318484e-03,\n                1.3383022e-04,\n                1.4867196e-06,\n                6.0758829e-09,\n            ],\n            [\n                8.2731149e-06,\n                9.9865720e-05,\n                1.0069301e-03,\n                8.1553087e-03,\n                4.9938772e-02,\n                2.0791042e-01,\n                4.6575961e-01,\n                2.0791042e-01,\n                4.9938772e-02,\n                8.1553087e-03,\n                1.0069301e-03,\n                9.9865720e-05,\n                8.2731149e-06,\n            ],\n        ],\n    ],\n    [\n        2.0,\n        [\n            [\n                4.81605530e-05,\n                6.81042671e-04,\n                5.93280792e-03,\n                3.18857729e-02,\n                1.05872214e-01,\n                2.17414647e-01,\n                2.76326418e-01,\n                2.17414647e-01,\n                1.05872214e-01,\n                3.18857729e-02,\n                
5.93280792e-03,\n                6.81042671e-04,\n                4.81605530e-05,\n            ],\n            [\n                3.48132307e-05,\n                5.44570561e-04,\n                5.16674388e-03,\n                2.97325663e-02,\n                1.03776865e-01,\n                2.19695643e-01,\n                2.82094806e-01,\n                2.19695643e-01,\n                1.03776865e-01,\n                2.97325663e-02,\n                5.16674388e-03,\n                5.44570561e-04,\n                3.48132307e-05,\n            ],\n            [\n                2.1655980e-04,\n                1.3297606e-03,\n                6.8653636e-03,\n                2.8791221e-02,\n                9.3239017e-02,\n                2.1526930e-01,\n                3.0850834e-01,\n                2.1526930e-01,\n                9.3239017e-02,\n                2.8791221e-02,\n                6.8653636e-03,\n                1.3297606e-03,\n                2.1655980e-04,\n            ],\n        ],\n    ],\n    [\n        4.0,\n        [\n            [\n                0.00240272,\n                0.00924471,\n                0.02783468,\n                0.06559062,\n                0.12097758,\n                0.17466632,\n                0.19741265,\n                0.17466632,\n                0.12097758,\n                0.06559062,\n                0.02783468,\n                0.00924471,\n                0.00240272,\n            ],\n            [\n                0.00221592,\n                0.00876415,\n                0.02699548,\n                0.0647588,\n                0.12098537,\n                0.17603266,\n                0.19947115,\n                0.17603266,\n                0.12098537,\n                0.0647588,\n                0.02699548,\n                0.00876415,\n                0.00221592,\n            ],\n            [\n                0.002829,\n                0.009244,\n                0.02594,\n                0.061124,\n      
          0.117627,\n                0.178751,\n                0.207002,\n                0.178751,\n                0.117627,\n                0.061124,\n                0.02594,\n                0.009244,\n                0.002829,\n            ],\n        ],\n    ],\n]\n\n\nclass TestGaussian1d(unittest.TestCase):\n\n    def test_gaussian(self):\n        np.testing.assert_allclose(\n            gaussian_1d(0.5, 8),\n            torch.tensor(\n                [\n                    0.0000e00,\n                    2.9802e-07,\n                    1.3496e-03,\n                    1.5731e-01,\n                    6.8269e-01,\n                    1.5731e-01,\n                    1.3496e-03,\n                    2.9802e-07,\n                    0.0000e00,\n                ]\n            ),\n            rtol=1e-4,\n        )\n\n        np.testing.assert_allclose(gaussian_1d(1, 1), torch.tensor([0.24173, 0.382925, 0.24173]), rtol=1e-4)\n        np.testing.assert_allclose(gaussian_1d(1, 1, normalize=True), torch.tensor([0.2790, 0.4420, 0.2790]), rtol=1e-4)\n\n    def test_scalespace_gaussian(self):\n        np.testing.assert_allclose(\n            gaussian_1d(0.5, 8, \"scalespace\"),\n            torch.tensor(\n                [\n                    7.9472e-06,\n                    2.5451e-04,\n                    6.1161e-03,\n                    9.8113e-02,\n                    7.9102e-01,\n                    9.8113e-02,\n                    6.1161e-03,\n                    2.5451e-04,\n                    7.9472e-06,\n                ]\n            ),\n            rtol=1e-4,\n        )\n\n        np.testing.assert_allclose(\n            gaussian_1d(1, 1, \"scalespace\"), torch.tensor([0.20791, 0.46576, 0.20791]), rtol=1e-3\n        )\n\n        np.testing.assert_allclose(\n            gaussian_1d(1, 1, \"scalespace\", normalize=True), torch.tensor([0.2358, 0.5283, 0.2358]), rtol=1e-3\n        )\n\n        np.testing.assert_allclose(\n            gaussian_1d(5, 1, 
\"scalespace\"),\n            torch.tensor(\n                [\n                    0.048225,\n                    0.057891,\n                    0.06675,\n                    0.073911,\n                    0.078576,\n                    0.080197,\n                    0.078576,\n                    0.073911,\n                    0.06675,\n                    0.057891,\n                    0.048225,\n                ]\n            ),\n            rtol=1e-3,\n        )\n\n    @parameterized.expand(TEST_CASES_NORM_F)\n    def test_norm_false(self, variance, expected):\n        extent = 6\n        atol = 1e-4\n        sigma = np.sqrt(variance)\n        k_erf = gaussian_1d(sigma, truncated=extent / sigma, approx=\"erf\", normalize=False).numpy()\n        k_sampled = gaussian_1d(sigma, truncated=extent / sigma, approx=\"sampled\").numpy()\n        k_scalespace = gaussian_1d(sigma, truncated=extent / sigma, approx=\"scalespace\").numpy()\n        np.testing.assert_allclose(k_erf, expected[0], atol=atol)\n        np.testing.assert_allclose(k_sampled, expected[1], atol=atol)\n        np.testing.assert_allclose(k_scalespace, expected[2], atol=atol)\n\n    def test_wrong_sigma(self):\n        with self.assertRaises(ValueError):\n            gaussian_1d(1, -10)\n        with self.assertRaises(NotImplementedError):\n            gaussian_1d(1, 10, \"wrong_arg\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_gaussian_filter.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers import GaussianFilter\nfrom monai.utils import unsqueeze_left\nfrom tests.test_utils import TEST_DEVICES, test_is_quick\n\n# trainable test cases\nTEST_CASES = [[{\"type\": \"erf\", \"gt\": 2.0}], [{\"type\": \"scalespace\", \"gt\": 3.0}], [{\"type\": \"sampled\", \"gt\": 5.0}]]\nTEST_CASES_GPU = [[{\"type\": \"erf\", \"gt\": 0.8, \"device\": \"cuda\"}], [{\"type\": \"sampled\", \"gt\": 5.0, \"device\": \"cuda\"}]]\nTEST_CASES_3D = [\n    [{\"type\": \"scalespace\", \"gt\": 0.5, \"dims\": (2, 3, 8, 9, 10), \"lr\": 0.01, \"device\": \"cuda\"}],\n    [{\"type\": \"erf\", \"gt\": 3.8, \"dims\": (2, 3, 8, 9, 10), \"lr\": 0.1, \"device\": \"cuda\"}],\n]\nTEST_CASES_SLOW = [\n    [{\"type\": \"erf\", \"gt\": 2.0, \"dims\": (2, 3, 8, 9, 10)}],\n    [{\"type\": \"scalespace\", \"gt\": 3.0, \"dims\": (2, 3, 8, 9, 10), \"device\": \"cuda\"}],\n    [{\"type\": \"sampled\", \"gt\": (0.5, 0.8, 3.0), \"dims\": (2, 3, 8, 9, 10), \"lr\": 0.1}],\n    [{\"type\": \"scalespace\", \"gt\": 3.0, \"device\": \"cuda\"}],\n]\n\nTEST_CASES_BP = TEST_CASES + TEST_CASES_GPU + TEST_CASES_3D\n\nif not test_is_quick():\n    TEST_CASES_BP += TEST_CASES_SLOW\n\n\nEXPECTED_1D = np.array(\n    [0.5654129, 0.68915915, 0.79146194, 0.8631974, 
0.8998163, 0.8998163, 0.8631973, 0.79146194, 0.6891592, 0.5654129]\n)\n\nEXPECTED_2D = np.array(\n    [[0.13239081, 0.13932934, 0.13239081], [0.13932936, 0.14663152, 0.13932936], [0.13239081, 0.13932934, 0.13239081]]\n)\n\nEXPECTED_3D = np.array(\n    [\n        [\n            [0.07189433, 0.07911152, 0.07911152, 0.07189433],\n            [0.07566228, 0.08325771, 0.08325771, 0.07566228],\n            [0.07189433, 0.07911152, 0.07911152, 0.07189433],\n        ],\n        [\n            [0.07911152, 0.08705322, 0.08705322, 0.07911152],\n            [0.08325771, 0.09161563, 0.09161563, 0.08325771],\n            [0.07911152, 0.08705322, 0.08705322, 0.07911152],\n        ],\n        [\n            [0.07911152, 0.08705322, 0.08705322, 0.07911152],\n            [0.08325771, 0.09161563, 0.09161563, 0.08325771],\n            [0.07911152, 0.08705322, 0.08705322, 0.07911152],\n        ],\n        [\n            [0.07189433, 0.07911152, 0.07911152, 0.07189433],\n            [0.07566228, 0.08325771, 0.08325771, 0.07566228],\n            [0.07189433, 0.07911152, 0.07911152, 0.07189433],\n        ],\n    ]\n)\n\nEXPECTED_3D_SIGMAS = np.array(\n    [\n        [[0.13690521, 0.13690521], [0.15181276, 0.15181276], [0.13690521, 0.13690521]],\n        [[0.1506486, 0.15064861], [0.16705267, 0.16705267], [0.1506486, 0.15064861]],\n        [[0.1506486, 0.15064861], [0.16705267, 0.16705267], [0.1506486, 0.15064861]],\n        [[0.13690521, 0.13690521], [0.15181276, 0.15181276], [0.13690521, 0.13690521]],\n    ]\n)\n\nDEVICE_RTOL = [d + [rtol] for d, rtol in zip(TEST_DEVICES, (1e-5, 1e-2))]  # device/tolerance pairs\nTEST_CASES_2D_3D = [[(1, 1, 4, 3, 4), (3, 3, 3), EXPECTED_3D] + d for d in DEVICE_RTOL]\nTEST_CASES_2D_3D += [[(1, 1, 4, 3, 2), (3, [3, 2, 1], 3), EXPECTED_3D_SIGMAS] + d for d in DEVICE_RTOL]\n\nif not test_is_quick():\n    TEST_CASES_2D_3D += [[(1, 1, 3, 3), (2, 3, 3), EXPECTED_2D] + d for d in DEVICE_RTOL]\n\n\nclass GaussianFilterTestCase(unittest.TestCase):\n    def 
test_wrong_args(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            GaussianFilter(3, [3, 2], 3)\n        GaussianFilter(3, [3, 2, 1], 3)  # test init\n\n    def test_1d(self):\n        a = torch.ones(1, 8, 10)\n        g = GaussianFilter(1, 3, 3)\n\n        expected = np.tile(unsqueeze_left(EXPECTED_1D, 3), (1, 8, 1))\n        np.testing.assert_allclose(g(a).cpu().numpy(), expected, rtol=1e-5)\n\n    @parameterized.expand(TEST_CASES_2D_3D)\n    def test_2d_3d(self, oargs, gargs, expected, device, rtol):\n        a = torch.ones(*oargs).to(device)\n        g = GaussianFilter(*gargs).to(device)\n\n        np.testing.assert_allclose(g(a).cpu().numpy(), unsqueeze_left(expected, len(oargs)), rtol=rtol)\n\n\nclass TestGaussianFilterBackprop(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_BP)\n    def test_training(self, input_args):\n        input_dims = input_args.get(\"dims\", (2, 3, 8))\n\n        device = torch.device(\"cpu\")\n        if input_args.get(\"device\") == \"cuda\" and torch.cuda.is_available():\n            device = torch.device(\"cuda:0\")\n\n        base = torch.ones(*input_dims).to(device)\n        gt = torch.tensor(input_args[\"gt\"], requires_grad=False)\n        g_type = input_args[\"type\"]\n        lr = input_args.get(\"lr\", 0.1)\n        init_sigma = input_args.get(\"init\", 1.0)\n\n        # static filter to generate a target\n        spatial_dims = len(base.shape) - 2\n        filtering = GaussianFilter(spatial_dims=spatial_dims, sigma=gt.to(device), approx=g_type, requires_grad=False)\n        filtering.to(device)\n        target = filtering(base)\n        self.assertFalse(filtering.sigma[0].requires_grad)\n\n        # build trainable\n        init_sigma = torch.tensor(init_sigma).to(device)\n        trainable = GaussianFilter(spatial_dims=spatial_dims, sigma=init_sigma, approx=g_type, requires_grad=True)\n        trainable.to(device)\n        self.assertTrue(trainable.sigma[0].requires_grad)\n\n        # 
train\n        optimizer = torch.optim.Adam(trainable.parameters(), lr=lr)\n        for _ in range(1000):\n            optimizer.zero_grad()\n            pred = trainable(base)\n            loss = torch.pow(pred - target, 2).mean()\n            loss.backward()\n            if loss.item() < 1e-7:\n                break\n            optimizer.step()\n\n        for idx, s in enumerate(trainable.sigma):\n            np.testing.assert_allclose(s.cpu().item(), gt if len(gt.shape) == 0 else gt[idx].item(), rtol=1e-2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_get_layers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.networks.layers import get_act_layer, get_dropout_layer, get_norm_layer\n\nTEST_CASE_NORM = [\n    [{\"name\": (\"group\", {\"num_groups\": 1})}, \"GroupNorm(1, 1, eps=1e-05, affine=True)\"],\n    [\n        {\"name\": \"instance\", \"spatial_dims\": 2},\n        \"InstanceNorm2d(1, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)\",\n    ],\n]\n\nTEST_CASE_ACT = [\n    [{\"name\": \"swish\"}, \"Swish()\"],\n    [{\"name\": (\"prelu\", {\"num_parameters\": 1, \"init\": 0.25})}, \"PReLU(num_parameters=1)\"],\n]\n\nTEST_CASE_DROPOUT = [\n    [{\"name\": \"dropout\"}, \"Dropout(p=0.5, inplace=False)\"],\n    [{\"name\": (\"alphadropout\", {\"p\": 0.25})}, \"AlphaDropout(p=0.25, inplace=False)\"],\n]\n\n\nclass TestGetLayers(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASE_NORM)\n    def test_norm_layer(self, input_param, expected):\n        layer = get_norm_layer(**input_param)\n        self.assertEqual(f\"{layer}\", expected)\n\n    @parameterized.expand(TEST_CASE_ACT)\n    def test_act_layer(self, input_param, expected):\n        layer = get_act_layer(**input_param)\n        self.assertEqual(f\"{layer}\", expected)\n\n    @parameterized.expand(TEST_CASE_DROPOUT)\n    def test_dropout_layer(self, input_param, expected):\n        layer = 
get_dropout_layer(**input_param)\n        self.assertEqual(f\"{layer}\", expected)\n\n\nclass TestSuggestion(unittest.TestCase):\n\n    def test_suggested(self):\n        with self.assertRaisesRegex(ValueError, \"did you mean 'GROUP'?\"):\n            get_norm_layer(name=\"grop\", spatial_dims=2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_gmm.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai._extensions import load_module\nfrom monai.networks.layers import GaussianMixtureModel\nfrom tests.test_utils import skip_if_darwin, skip_if_no_cuda, skip_if_quick, skip_if_windows\n\nTEST_CASES = [\n    [\n        # Case Description\n        \"2 batches, 1 dimensions, 1 channels, 2 classes, 2 mixtures\",\n        # Class Count\n        2,\n        # Mixture Count\n        1,\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 1, 0, 0, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [0, 0.2, 1, 0.8, 0.5]\n            ],\n        ],\n        # Labels\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, -1, 0, -1, 1]\n            ],\n            # Batch 1\n            [\n                # Channel 0\n                [1, 1, 0, 0, -1]\n            ],\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0, 0, 1, 1, 0],\n                # Channel 1\n                [1, 1, 0, 0, 1],\n            ],\n            # Batch 1\n            [\n                # 
Channel 0\n                [0, 0, 1, 1, 0.5],\n                # Channel 1\n                [1, 1, 0, 0, 0.5],\n            ],\n        ],\n    ],\n    [\n        # Case Description\n        \"1 batches, 1 dimensions, 5 channels, 2 classes, 1 mixtures\",\n        # Class Count\n        2,\n        # Mixture Count\n        1,\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1.0, 0.9, 0.0, 0.0, 0.0],\n                # Channel 1\n                [0.0, 0.0, 0.3, 0.3, 0.4],\n                # Channel 2\n                [0.9, 0.8, 0.0, 0.0, 0.0],\n                # Channel 3\n                [0.7, 0.9, 0.0, 0.0, 0.0],\n                # Channel 4\n                [0.2, 0.1, 0.2, 0.2, 0.1],\n            ]\n        ],\n        # Labels\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [0, 0, -1, 1, 1]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [1, 1, 0, 0, 0],\n                # Channel 1\n                [0, 0, 1, 1, 1],\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"1 batches, 2 dimensions, 2 channels, 4 classes, 4 mixtures\",\n        # Class Count\n        4,\n        # Mixture Count\n        1,\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [0.8, 0.8, 0.0, 0.0, 0.0],\n                    [1.0, 0.9, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.8, 0.9],\n                    [0.0, 0.0, 0.0, 0.9, 1.0],\n                ],\n                # Channel 1\n                [\n                    [0.8, 0.8, 0.0, 0.0, 0.0],\n                    [0.7, 0.7, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.4, 0.5, 0.0, 0.0, 0.0],\n                    [0.7, 
0.6, 0.0, 0.0, 0.0],\n                ],\n            ]\n        ],\n        # Labels\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [[-1, 1, -1, 0, -1], [1, -1, -1, -1, -1], [-1, -1, 0, -1, -1], [2, 2, -1, 3, -1], [-1, -1, -1, -1, 3]]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    [0.0, 0.0, 1.0, 1.0, 1.0],\n                    [0.0, 0.0, 1.0, 1.0, 1.0],\n                    [1.0, 1.0, 1.0, 1.0, 1.0],\n                    [0.0, 0.0, 1.0, 0.0, 0.0],\n                    [0.0, 0.0, 1.0, 0.0, 0.0],\n                ],\n                # Channel 1\n                [\n                    [1.0, 1.0, 0.0, 0.0, 0.0],\n                    [1.0, 1.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                ],\n                # Channel 2\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [1.0, 1.0, 0.0, 0.0, 0.0],\n                    [1.0, 1.0, 0.0, 0.0, 0.0],\n                ],\n                # Channel 3\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 1.0, 1.0],\n                    [0.0, 0.0, 0.0, 1.0, 1.0],\n                ],\n            ]\n        ],\n    ],\n    [\n        # Case Description\n        \"1 batches, 3 dimensions, 1 channels, 2 classes, 1 mixtures\",\n        # Class Count\n        2,\n        # Mixture Count\n        1,\n        # Features\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Slice 0\n                    
[[0.7, 0.6, 0.0], [0.5, 0.4, 0.0], [0.0, 0.0, 0.0]],\n                    # Slice 1\n                    [[0.5, 0.6, 0.0], [0.4, 0.3, 0.0], [0.0, 0.0, 0.0]],\n                    # Slice 2\n                    [[0.3, 0.3, 0.0], [0.2, 0.1, 0.0], [0.0, 0.0, 0.0]],\n                ]\n            ]\n        ],\n        # Labels\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Slice 0\n                    [[0, -1, -1], [0, -1, -1], [-1, -1, 1]],\n                    # Slice 1\n                    [[0, 0, -1], [-1, -1, 1], [-1, 1, 1]],\n                    # Slice 2\n                    [[0, -1, -1], [-1, -1, -1], [-1, -1, -1]],\n                ]\n            ]\n        ],\n        # Expected\n        [\n            # Batch 0\n            [\n                # Channel 0\n                [\n                    # Slice 0\n                    [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]],\n                    # Slice 1\n                    [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]],\n                    # Slice 2\n                    [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]],\n                ],\n                # Channel 1\n                [\n                    # Slice 0\n                    [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]],\n                    # Slice 1\n                    [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]],\n                    # Slice 2\n                    [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]],\n                ],\n            ]\n        ],\n    ],\n]\n\n\n@skip_if_quick\nclass GMMTestCase(unittest.TestCase):\n    def setUp(self):\n        self._var = os.environ.get(\"TORCH_EXTENSIONS_DIR\")\n        self.tempdir = tempfile.mkdtemp()\n        os.environ[\"TORCH_EXTENSIONS_DIR\"] = self.tempdir\n\n    def tearDown(self) -> None:\n        if self._var is None:\n            os.environ.pop(\"TORCH_EXTENSIONS_DIR\", None)\n    
    else:\n            os.environ[\"TORCH_EXTENSIONS_DIR\"] = f\"{self._var}\"\n        shutil.rmtree(self.tempdir)\n\n    @parameterized.expand(TEST_CASES)\n    @skip_if_no_cuda\n    def test_cuda(self, test_case_description, mixture_count, class_count, features, labels, expected):\n        # Device to run on\n        device = torch.device(\"cuda\")\n\n        # Create tensors\n        features_tensor = torch.tensor(features, dtype=torch.float32, device=device)\n        labels_tensor = torch.tensor(labels, dtype=torch.int32, device=device)\n\n        # Create GMM\n        try:\n            gmm = GaussianMixtureModel(features_tensor.size(1), mixture_count, class_count, verbose_build=True)\n        except RuntimeError as e:\n            if \"Error building extension\" in str(e):\n                self.skipTest(f\"GMM CUDA extension failed to compile: {e}\")\n            raise\n        # reload GMM to confirm the build\n        _ = GaussianMixtureModel(features_tensor.size(1), mixture_count, class_count, verbose_build=False)\n        # reload quietly\n        _ = GaussianMixtureModel(features_tensor.size(1), mixture_count, class_count, verbose_build=True)\n\n        # Apply GMM\n        gmm.learn(features_tensor, labels_tensor)\n        results_tensor = gmm.apply(features_tensor)\n\n        # Read back results\n        results = results_tensor.cpu().numpy()\n\n        # Ensure result are as expected\n        np.testing.assert_allclose(results, expected, atol=1e-3)\n\n    @skip_if_darwin\n    @skip_if_windows\n    def test_load(self):\n        if not torch.cuda.is_available():\n            with self.assertRaisesRegex(ImportError, \".*symbol.*\"):  # expecting import error if no cuda\n                load_module(\"gmm\", {\"CHANNEL_COUNT\": 2, \"MIXTURE_COUNT\": 2, \"MIXTURE_SIZE\": 3}, verbose_build=True)\n        else:\n            try:\n                load_module(\"gmm\", {\"CHANNEL_COUNT\": 2, \"MIXTURE_COUNT\": 2, \"MIXTURE_SIZE\": 3}, verbose_build=True)\n        
    except RuntimeError as e:\n                if \"Error building extension\" in str(e):\n                    self.skipTest(f\"GMM CUDA extension failed to compile: {e}\")\n                raise\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_grid_pull.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers import grid_pull\nfrom monai.networks.utils import meshgrid_ij\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_no_cpp_extension\nfrom tests.testing_data.cpp_resample_answers import Expected_1D_GP_bwd, Expected_1D_GP_fwd\n\nBType, has_b_type = optional_import(\"monai._C\", name=\"BoundType\")\nPType, has_p_type = optional_import(\"monai._C\", name=\"InterpolationType\")\n\n\ndef make_grid(shape, dtype=None, device=None, requires_grad=True):\n    ranges = [torch.arange(float(s), dtype=dtype, device=device, requires_grad=requires_grad) for s in shape]\n    grid = torch.stack(meshgrid_ij(*ranges), dim=-1)\n    return grid[None]\n\n\n# 1D combinations of bounds/interpolations\nbounds = set(BType.__members__.values()) if has_b_type else []\ninterps = set(PType.__members__.values()) if has_p_type else []\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nTEST_1D_GP = []\nfor bound in bounds:\n    for interp in interps:\n        if not Expected_1D_GP_fwd or not Expected_1D_GP_bwd:\n            break  # skip if the testing data are unavailable\n        expected_val = Expected_1D_GP_fwd.pop(0)\n\n        for input_g in (True, False):\n            for grid_g in (True, False):\n       
         expected_grad = Expected_1D_GP_bwd.pop(0)\n                test_case = [\n                    {\n                        \"input\": torch.arange(10, dtype=torch.float, requires_grad=input_g, device=device).reshape(\n                            (1, 1, 10)\n                        ),\n                        \"grid\": make_grid((20,), dtype=torch.float, device=device, requires_grad=grid_g) + 0.5,\n                        \"interpolation\": interp,\n                        \"bound\": bound,\n                    },\n                    {\"val\": torch.tensor([[expected_val]]), \"device\": device, \"grad\": torch.tensor(expected_grad)},\n                ]\n                TEST_1D_GP.append(test_case)\n\n\n@skip_if_no_cpp_extension\nclass TestGridPull(unittest.TestCase):\n    @parameterized.expand(TEST_1D_GP, skip_on_empty=True)\n    def test_grid_pull(self, input_param, expected):\n        result = grid_pull(**input_param)\n        if input_param[\"input\"].requires_grad:\n            input_param[\"input\"].retain_grad()\n        if input_param[\"grid\"].requires_grad:\n            input_param[\"grid\"].retain_grad()\n        if input_param[\"input\"].requires_grad or input_param[\"grid\"].requires_grad:\n            result.sum().backward()\n\n        grads = []\n        if input_param[\"input\"].requires_grad:\n            grads.append(input_param[\"input\"].grad.view(-1))\n        if input_param[\"grid\"].requires_grad:\n            grads.append(input_param[\"grid\"].grad.view(-1))\n        if not grads:\n            grads = torch.tensor(0.0, device=result.device)\n        elif len(grads) == 1:\n            grads = grads[0]\n        else:\n            grads = torch.cat(grads, dim=0)\n        self.assertTrue(f\"{result.device}\".startswith(expected[\"device\"]))\n        np.testing.assert_allclose(result.detach().cpu().numpy(), expected[\"val\"].cpu().numpy(), rtol=1e-4, atol=1e-4)\n        np.testing.assert_allclose(grads.detach().cpu().numpy(), 
expected[\"grad\"].cpu().numpy(), rtol=1e-4, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_hilbert_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers import HilbertTransform\n\n\ndef create_expected_numpy_output(input_datum, **kwargs):\n    x = np.fft.fft(input_datum.cpu().numpy(), **kwargs)\n    f = np.fft.fftfreq(x.shape[kwargs[\"axis\"]])\n    u = np.heaviside(f, 0.5)\n    new_dims_before = kwargs[\"axis\"]\n    new_dims_after = len(x.shape) - kwargs[\"axis\"] - 1\n    for _ in range(new_dims_before):\n        u = np.expand_dims(u, 0)\n    for _ in range(new_dims_after):\n        u = np.expand_dims(u, -1)\n    ht = np.fft.ifft(x * 2 * u, axis=kwargs[\"axis\"])\n\n    return ht\n\n\ncpu = torch.device(\"cpu\")\nn_samples = 500\nhann_windowed_sine = np.sin(2 * np.pi * 10 * np.linspace(0, 1, n_samples)) * np.hanning(n_samples)\n\n# CPU TEST DATA\n\ncpu_input_data = {}\ncpu_input_data[\"1D\"] = torch.as_tensor(hann_windowed_sine, device=cpu)[None, None]\ncpu_input_data[\"2D\"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu)[None, None]\ncpu_input_data[\"3D\"] = torch.as_tensor(\n    np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu\n)[None, None]\ncpu_input_data[\"1D 2CH\"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu)[None]\ncpu_input_data[\"2D 2CH\"] = torch.as_tensor(\n    
np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu\n)[None]\n\n# SINGLE-CHANNEL CPU VALUE TESTS\n\nTEST_CASE_1D_SINE_CPU = [\n    {},  # args (empty, so use default)\n    cpu_input_data[\"1D\"],  # Input data: Random 1D signal\n    create_expected_numpy_output(cpu_input_data[\"1D\"], axis=2),  # Expected output: FFT of signal\n    1e-5,  # absolute tolerance\n]\n\nTEST_CASE_2D_SINE_CPU = [\n    {},  # args (empty, so use default)\n    cpu_input_data[\"2D\"],  # Input data: Random 1D signal\n    create_expected_numpy_output(cpu_input_data[\"2D\"], axis=2),  # Expected output: FFT of signal\n    1e-5,  # absolute tolerance\n]\n\nTEST_CASE_3D_SINE_CPU = [\n    {},  # args (empty, so use default)\n    cpu_input_data[\"3D\"],  # Input data: Random 1D signal\n    create_expected_numpy_output(cpu_input_data[\"3D\"], axis=2),\n    1e-5,  # absolute tolerance\n]\n\n# MULTICHANNEL CPU VALUE TESTS, PROCESS ALONG FIRST SPATIAL AXIS\n\nTEST_CASE_1D_2CH_SINE_CPU = [\n    {},  # args (empty, so use default)\n    cpu_input_data[\"1D 2CH\"],  # Input data: Random 1D signal\n    create_expected_numpy_output(cpu_input_data[\"1D 2CH\"], axis=2),\n    1e-5,  # absolute tolerance\n]\n\nTEST_CASE_2D_2CH_SINE_CPU = [\n    {},  # args (empty, so use default)\n    cpu_input_data[\"2D 2CH\"],  # Input data: Random 1D signal\n    create_expected_numpy_output(cpu_input_data[\"2D 2CH\"], axis=2),\n    1e-5,  # absolute tolerance\n]\n\nTEST_CASES_CPU = [\n    TEST_CASE_1D_SINE_CPU,\n    TEST_CASE_2D_SINE_CPU,\n    TEST_CASE_3D_SINE_CPU,\n    TEST_CASE_1D_2CH_SINE_CPU,\n    TEST_CASE_2D_2CH_SINE_CPU,\n]\n\n# GPU TEST DATA\n\nif torch.cuda.is_available():\n    gpu = torch.device(\"cuda\")\n    TEST_CASES_GPU = [[args, image.to(gpu), exp_data, atol] for args, image, exp_data, atol in TEST_CASES_CPU]\nelse:\n    TEST_CASES_GPU = []\n\n# TESTS CHECKING PADDING, AXIS SELECTION ETC ARE COVERED BY test_detect_envelope.py\n\n\nclass 
TestHilbertTransformCPU(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_CPU + TEST_CASES_GPU)\n    def test_value(self, arguments, image, expected_data, atol):\n        result = HilbertTransform(**arguments)(image)\n        result = np.squeeze(result.cpu().numpy())\n        np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_lltm.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers import LLTM\nfrom tests.test_utils import SkipIfNoModule, assert_allclose, is_tf32_env\n\n_rtol = 0.001 if is_tf32_env() else 0.0001\n\nTEST_CASE_1 = [\n    {\"input_features\": 32, \"state_size\": 2},\n    torch.tensor([[-0.1622, 0.1663], [0.5465, 0.0459], [-0.1436, 0.6171], [0.3632, -0.0111]]),\n    torch.tensor([[-1.3773, 0.3348], [0.8353, 1.3064], [-0.2179, 4.1739], [1.3045, -0.1444]]),\n]\n\n\nclass TestLLTM(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1])\n    @SkipIfNoModule(\"monai._C\")\n    def test_value(self, input_param, expected_h, expected_c):\n        torch.manual_seed(0)\n        x = torch.randn(4, 32)\n        h = torch.randn(4, 2)\n        c = torch.randn(4, 2)\n        new_h, new_c = LLTM(**input_param)(x, (h, c))\n        (new_h.sum() + new_c.sum()).backward()\n\n        assert_allclose(new_h, expected_h, rtol=0.0001, atol=1e-04)\n        assert_allclose(new_c, expected_c, rtol=0.0001, atol=1e-04)\n\n    @parameterized.expand([TEST_CASE_1])\n    @SkipIfNoModule(\"monai._C\")\n    def test_value_cuda(self, input_param, expected_h, expected_c):\n        device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu:0\")\n        torch.manual_seed(0)\n        x = torch.randn(4, 
32).to(device)\n        h = torch.randn(4, 2).to(device)\n        c = torch.randn(4, 2).to(device)\n        lltm = LLTM(**input_param).to(device)\n        new_h, new_c = lltm(x, (h, c))\n        (new_h.sum() + new_c.sum()).backward()\n\n        assert_allclose(new_h, expected_h.to(device), rtol=_rtol, atol=0.001)\n        assert_allclose(new_c, expected_c.to(device), rtol=_rtol, atol=0.001)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_median_filter.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers import MedianFilter\n\n\nclass MedianFilterTestCase(unittest.TestCase):\n\n    @parameterized.expand([(torch.ones(1, 1, 2, 3, 5), [1, 2, 4]), (torch.ones(1, 1, 4, 3, 4), 1)])  # 3d_big  # 3d\n    def test_3d(self, input_tensor, radius):\n        filter = MedianFilter(radius).to(torch.device(\"cpu:0\"))\n\n        expected = input_tensor.numpy()\n        output = filter(input_tensor).cpu().numpy()\n\n        np.testing.assert_allclose(output, expected, rtol=1e-5)\n\n    def test_3d_radii(self):\n        a = torch.ones(1, 1, 4, 3, 2)\n        g = MedianFilter([3, 2, 1]).to(torch.device(\"cpu:0\"))\n\n        expected = a.numpy()\n        out = g(a).cpu().numpy()\n        np.testing.assert_allclose(out, expected, rtol=1e-5)\n        if torch.cuda.is_available():\n            g = MedianFilter([3, 2, 1]).to(torch.device(\"cuda:0\"))\n            np.testing.assert_allclose(g(a.cuda()).cpu().numpy(), expected, rtol=1e-2)\n\n    def test_wrong_args(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            MedianFilter([3, 2]).to(torch.device(\"cpu:0\"))\n        MedianFilter([3, 2, 1]).to(torch.device(\"cpu:0\"))  # test init\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_polyval.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers import polyval\n\nTEST_CASES = [\n    [[1.0, 2.5, -4.2], 5.0, 33.3],\n    [[2, 1, 0], 3.0, 21],\n    [[2, 1, 0], [3.0, 3.0], [21, 21]],\n    [torch.as_tensor([2, 1, 0]), [3.0, 3.0], [21, 21]],\n    [torch.as_tensor([2, 1, 0]), torch.as_tensor([3.0, 3.0]), [21, 21]],\n    [torch.as_tensor([2, 1, 0]), np.array([3.0, 3.0]), [21, 21]],\n    [[], np.array([3.0, 3.0]), [0, 0]],\n]\n\n\nclass TestPolyval(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_floats(self, coef, x, expected):\n        result = polyval(coef, x)\n        np.testing.assert_allclose(result.cpu().numpy(), expected)\n\n    @parameterized.expand(TEST_CASES)\n    def test_gpu(self, coef, x, expected):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        x = torch.as_tensor(x, dtype=torch.float, device=device)\n        x.requires_grad = True\n        coef = torch.as_tensor(coef, dtype=torch.float, device=device)\n        coef.requires_grad = True\n        result = polyval(coef, x)\n        if coef.shape[0] > 0:  # empty coef doesn't have grad\n            result.mean().backward()\n            np.testing.assert_allclose(coef.grad.shape, coef.shape)\n        
np.testing.assert_allclose(result.cpu().detach().numpy(), expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_preset_filters.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers import ApplyFilter, EllipticalFilter, LaplaceFilter, MeanFilter, SharpenFilter\n\nTEST_CASES_MEAN = [(3, 3, torch.ones(3, 3, 3)), (2, 5, torch.ones(5, 5))]\n\nTEST_CASES_LAPLACE = [\n    (\n        3,\n        3,\n        torch.Tensor(\n            [\n                [[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]],\n                [[-1, -1, -1], [-1, 26, -1], [-1, -1, -1]],\n                [[-1, -1, -1], [-1, -1, -1], [-1, -1, -1]],\n            ]\n        ),\n    ),\n    (2, 3, torch.Tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])),\n]\n\nTEST_CASES_ELLIPTICAL = [\n    (\n        3,\n        3,\n        torch.Tensor(\n            [[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [1, 1, 1], [0, 1, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]]]\n        ),\n    ),\n    (2, 3, torch.Tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]])),\n]\n\nTEST_CASES_SHARPEN = [\n    (\n        3,\n        3,\n        torch.Tensor(\n            [\n                [[0, 0, 0], [0, -1, 0], [0, 0, 0]],\n                [[0, -1, 0], [-1, 7, -1], [0, -1, 0]],\n                [[0, 0, 0], [0, -1, 0], [0, 0, 0]],\n            ]\n        ),\n    ),\n    (2, 3, torch.Tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])),\n]\n\n\nclass _TestFilter:\n\n    def test_init(self, 
spatial_dims, size, expected):\n        test_filter = self.filter_class(spatial_dims=spatial_dims, size=size)\n        torch.testing.assert_allclose(expected, test_filter.filter)\n        self.assertIsInstance(test_filter, torch.nn.Module)\n\n    def test_forward(self):\n        test_filter = self.filter_class(spatial_dims=2, size=3)\n        input = torch.ones(1, 1, 5, 5)\n        _ = test_filter(input)\n\n\nclass TestApplyFilter(unittest.TestCase):\n\n    def test_init_and_forward_2d(self):\n        filter_2d = torch.ones(3, 3)\n        image_2d = torch.ones(1, 3, 3)\n        apply_filter_2d = ApplyFilter(filter_2d)\n        out = apply_filter_2d(image_2d)\n        self.assertEqual(out.shape, image_2d.shape)\n\n    def test_init_and_forward_3d(self):\n        filter_2d = torch.ones(3, 3, 3)\n        image_2d = torch.ones(1, 3, 3, 3)\n        apply_filter_2d = ApplyFilter(filter_2d)\n        out = apply_filter_2d(image_2d)\n        self.assertEqual(out.shape, image_2d.shape)\n\n\nclass MeanFilterTestCase(_TestFilter, unittest.TestCase):\n\n    def setUp(self) -> None:\n        self.filter_class = MeanFilter\n\n    @parameterized.expand(TEST_CASES_MEAN)\n    def test_init(self, spatial_dims, size, expected):\n        super().test_init(spatial_dims, size, expected)\n\n\nclass LaplaceFilterTestCase(_TestFilter, unittest.TestCase):\n\n    def setUp(self) -> None:\n        self.filter_class = LaplaceFilter\n\n    @parameterized.expand(TEST_CASES_LAPLACE)\n    def test_init(self, spatial_dims, size, expected):\n        super().test_init(spatial_dims, size, expected)\n\n\nclass EllipticalTestCase(_TestFilter, unittest.TestCase):\n\n    def setUp(self) -> None:\n        self.filter_class = EllipticalFilter\n\n    @parameterized.expand(TEST_CASES_ELLIPTICAL)\n    def test_init(self, spatial_dims, size, expected):\n        super().test_init(spatial_dims, size, expected)\n\n\nclass SharpenTestCase(_TestFilter, unittest.TestCase):\n\n    def setUp(self) -> None:\n        
self.filter_class = SharpenFilter\n\n    @parameterized.expand(TEST_CASES_SHARPEN)\n    def test_init(self, spatial_dims, size, expected):\n        super().test_init(spatial_dims, size, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_savitzky_golay_filter.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers import SavitzkyGolayFilter\nfrom tests.test_utils import skip_if_no_cuda\n\n# Zero-padding trivial tests\n\nTEST_CASE_SINGLE_VALUE = [\n    {\"window_length\": 3, \"order\": 1},\n    torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0),  # Input data: Single value\n    torch.Tensor([1 / 3]).unsqueeze(0).unsqueeze(0),  # Expected output: With a window length of 3 and polyorder 1\n    # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed)\n    1e-6,  # absolute tolerance\n]\n\nTEST_CASE_1D = [\n    {\"window_length\": 3, \"order\": 1},\n    torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0),  # Input data\n    torch.Tensor([2 / 3, 1.0, 2 / 3])\n    .unsqueeze(0)\n    .unsqueeze(0),  # Expected output: zero padded, so linear interpolation\n    # over length-3 windows will result in output of [2/3, 1, 2/3].\n    1e-6,  # absolute tolerance\n]\n\nTEST_CASE_2D_AXIS_2 = [\n    {\"window_length\": 3, \"order\": 1},  # along default axis (2, first spatial dim)\n    torch.ones((3, 2)).unsqueeze(0).unsqueeze(0),\n    torch.Tensor([[2 / 3, 2 / 3], [1.0, 1.0], [2 / 3, 2 / 3]]).unsqueeze(0).unsqueeze(0),\n    1e-6,  # absolute tolerance\n]\n\nTEST_CASE_2D_AXIS_3 = [\n   
 {\"window_length\": 3, \"order\": 1, \"axis\": 3},  # along axis 3 (second spatial dim)\n    torch.ones((2, 3)).unsqueeze(0).unsqueeze(0),\n    torch.Tensor([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]).unsqueeze(0).unsqueeze(0),\n    1e-6,  # absolute tolerance\n]\n\n# Replicated-padding trivial tests\n\nTEST_CASE_SINGLE_VALUE_REP = [\n    {\"window_length\": 3, \"order\": 1, \"mode\": \"replicate\"},\n    torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0),  # Input data: Single value\n    torch.Tensor([1.0]).unsqueeze(0).unsqueeze(0),  # Expected output: With a window length of 3 and polyorder 1\n    # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed)\n    1e-6,  # absolute tolerance\n]\n\nTEST_CASE_1D_REP = [\n    {\"window_length\": 3, \"order\": 1, \"mode\": \"replicate\"},\n    torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0),  # Input data\n    torch.Tensor([1.0, 1.0, 1.0]).unsqueeze(0).unsqueeze(0),  # Expected output: zero padded, so linear interpolation\n    # over length-3 windows will result in output of [2/3, 1, 2/3].\n    1e-6,  # absolute tolerance\n]\n\nTEST_CASE_2D_AXIS_2_REP = [\n    {\"window_length\": 3, \"order\": 1, \"mode\": \"replicate\"},  # along default axis (2, first spatial dim)\n    torch.ones((3, 2)).unsqueeze(0).unsqueeze(0),\n    torch.Tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]).unsqueeze(0).unsqueeze(0),\n    1e-6,  # absolute tolerance\n]\n\nTEST_CASE_2D_AXIS_3_REP = [\n    {\"window_length\": 3, \"order\": 1, \"axis\": 3, \"mode\": \"replicate\"},  # along axis 3 (second spatial dim)\n    torch.ones((2, 3)).unsqueeze(0).unsqueeze(0),\n    torch.Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).unsqueeze(0).unsqueeze(0),\n    1e-6,  # absolute tolerance\n]\n\n# Sine smoothing\n\nTEST_CASE_SINE_SMOOTH = [\n    {\"window_length\": 3, \"order\": 1},\n    # Sine wave with period equal to savgol window length (windowed to reduce edge effects).\n    torch.as_tensor(np.sin(2 * 
np.pi * 1 / 3 * np.arange(100)) * np.hanning(100)).unsqueeze(0).unsqueeze(0),\n    # Should be smoothed out to zeros\n    torch.zeros(100).unsqueeze(0).unsqueeze(0),\n    # tolerance chosen by examining output of SciPy.signal.savgol_filter when provided the above input\n    2e-2,  # absolute tolerance\n]\n\n\nclass TestSavitzkyGolayCPU(unittest.TestCase):\n    @parameterized.expand(\n        [TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, TEST_CASE_SINE_SMOOTH]\n    )\n    def test_value(self, arguments, image, expected_data, atol, rtol=1e-5):\n        result = SavitzkyGolayFilter(**arguments)(image)\n        np.testing.assert_allclose(result, expected_data, atol=atol, rtol=rtol)\n\n\nclass TestSavitzkyGolayCPUREP(unittest.TestCase):\n    @parameterized.expand(\n        [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP]\n    )\n    def test_value(self, arguments, image, expected_data, atol, rtol=1e-5):\n        result = SavitzkyGolayFilter(**arguments)(image)\n        np.testing.assert_allclose(result, expected_data, atol=atol, rtol=rtol)\n\n\n@skip_if_no_cuda\nclass TestSavitzkyGolayGPU(unittest.TestCase):\n    @parameterized.expand(\n        [TEST_CASE_SINGLE_VALUE, TEST_CASE_1D, TEST_CASE_2D_AXIS_2, TEST_CASE_2D_AXIS_3, TEST_CASE_SINE_SMOOTH]\n    )\n    def test_value(self, arguments, image, expected_data, atol, rtol=1e-5):\n        result = SavitzkyGolayFilter(**arguments)(image.to(device=\"cuda\"))\n        np.testing.assert_allclose(result.cpu(), expected_data, atol=atol, rtol=rtol)\n\n\n@skip_if_no_cuda\nclass TestSavitzkyGolayGPUREP(unittest.TestCase):\n    @parameterized.expand(\n        [TEST_CASE_SINGLE_VALUE_REP, TEST_CASE_1D_REP, TEST_CASE_2D_AXIS_2_REP, TEST_CASE_2D_AXIS_3_REP]\n    )\n    def test_value(self, arguments, image, expected_data, atol, rtol=1e-5):\n        result = SavitzkyGolayFilter(**arguments)(image.to(device=\"cuda\"))\n        
np.testing.assert_allclose(result.cpu(), expected_data, atol=atol, rtol=rtol)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_separable_filter.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.networks.layers import separable_filtering\n\n\nclass SeparableFilterTestCase(unittest.TestCase):\n\n    def test_1d(self):\n        a = torch.tensor([[list(range(10))]], dtype=torch.float)\n        out = separable_filtering(a, torch.tensor([-1, 0, 1]))\n        expected = np.array([[[1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, -8.0]]])\n        np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4)\n        if torch.cuda.is_available():\n            out = separable_filtering(a.cuda(), torch.tensor([-1, 0, 1]).cuda())\n            np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4)\n\n    def test_2d(self):\n        a = torch.tensor([[[list(range(7)), list(range(7, 0, -1)), list(range(7))]]], dtype=torch.float)\n        expected = np.array(\n            [\n                [28.0, 28.0, 28.0, 28.0, 28.0, 28.0],\n                [30.0, 34.0, 38.0, 42.0, 46.0, 50.0],\n                [28.0, 28.0, 28.0, 28.0, 28.0, 28.0],\n            ]\n        )\n        expected = expected[None][None]\n        out = separable_filtering(a, [torch.tensor([1, 1, 1]), torch.tensor([2, 2])])\n        np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4)\n        if torch.cuda.is_available():\n            out = separable_filtering(a.cuda(), 
[torch.tensor([1, 1, 1]).cuda(), torch.tensor([2, 2]).cuda()])\n            np.testing.assert_allclose(out.cpu().numpy(), expected, rtol=1e-4)\n\n    def test_3d(self):\n        a = torch.tensor(\n            [[list(range(7)), list(range(7)), list(range(7))], [list(range(7)), list(range(7)), list(range(7))]],\n            dtype=torch.float,\n        )\n        a = a[None][None]\n        a = a.expand(2, 3, -1, -1, -1)\n        expected = np.array(\n            [\n                [\n                    [4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0],\n                    [6.0, 18.0, 36.0, 54.0, 72.0, 90.0, 66.0],\n                    [4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0],\n                ],\n                [\n                    [4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0],\n                    [6.0, 18.0, 36.0, 54.0, 72.0, 90.0, 66.0],\n                    [4.0, 12.0, 24.0, 36.0, 48.0, 60.0, 44.0],\n                ],\n            ]\n        )\n        # testing shapes\n        k = torch.tensor([1, 1, 1])\n        for kernel in (k, [k] * 3):\n            out = separable_filtering(a, kernel)\n            np.testing.assert_allclose(out.cpu().numpy()[1][2], expected, rtol=1e-4)\n            if torch.cuda.is_available():\n                out = separable_filtering(\n                    a.cuda(), kernel.cuda() if isinstance(kernel, torch.Tensor) else [k.cuda() for k in kernel]\n                )\n                np.testing.assert_allclose(out.cpu().numpy()[0][1], expected, rtol=1e-4)\n\n    def test_wrong_args(self):\n        with self.assertRaisesRegex(TypeError, \"\"):\n            separable_filtering(((1, 1, 1, 2, 3, 2)), torch.ones((2,)))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_skip_connection.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.layers import SkipConnection\n\nTEST_CASES_3D = []\nfor type_1 in (\"cat\", \"add\", \"mul\"):\n    input_shape = (16, 10, 32, 24, 48)\n    if type_1 == \"cat\":\n        result_shape = (input_shape[0] * 2, *input_shape[1:])\n    else:\n        result_shape = input_shape\n    test_case = [{\"dim\": 0, \"mode\": type_1}, input_shape, result_shape]\n    TEST_CASES_3D.append(test_case)\n\n\nclass TestSkipConnection(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_3D)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = SkipConnection(submodule=torch.nn.Softmax(dim=1), **input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_wrong_mode(self):\n        with self.assertRaises(ValueError):\n            SkipConnection(torch.nn.Softmax(dim=1), mode=\"test\")(torch.randn(10, 10))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_vector_quantizer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom math import prod\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers import EMAQuantizer, VectorQuantizer\n\nTEST_CASES = [\n    [{\"spatial_dims\": 2, \"num_embeddings\": 16, \"embedding_dim\": 8}, (1, 8, 4, 4), (1, 4, 4)],\n    [{\"spatial_dims\": 3, \"num_embeddings\": 16, \"embedding_dim\": 8}, (1, 8, 4, 4, 4), (1, 4, 4, 4)],\n]\n\n\nclass TestEMA(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_ema_shape(self, input_param, input_shape, output_shape):\n        layer = EMAQuantizer(**input_param)\n        x = torch.randn(input_shape)\n        layer = layer.train()\n        outputs = layer(x)\n        self.assertEqual(outputs[0].shape, input_shape)\n        self.assertEqual(outputs[2].shape, output_shape)\n\n        layer = layer.eval()\n        outputs = layer(x)\n        self.assertEqual(outputs[0].shape, input_shape)\n        self.assertEqual(outputs[2].shape, output_shape)\n\n    @parameterized.expand(TEST_CASES)\n    def test_ema_quantize(self, input_param, input_shape, output_shape):\n        layer = EMAQuantizer(**input_param)\n        x = torch.randn(input_shape)\n        outputs = layer.quantize(x)\n        self.assertEqual(outputs[0].shape, (prod(input_shape[2:]), input_shape[1]))  # (HxW[xD], C)\n        self.assertEqual(outputs[1].shape, 
(prod(input_shape[2:]), input_param[\"num_embeddings\"]))  # (HxW[xD], E)\n        self.assertEqual(outputs[2].shape, (input_shape[0],) + input_shape[2:])  # (1, H, W, [D])\n\n    def test_ema(self):\n        layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0)\n        original_weight_0 = layer.embedding.weight[0].clone()\n        original_weight_1 = layer.embedding.weight[1].clone()\n        x_0 = original_weight_0\n        x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)\n        x_0 = x_0.repeat(1, 1, 1, 2) + 0.001\n\n        x_1 = original_weight_1\n        x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)\n        x_1 = x_1.repeat(1, 1, 1, 2)\n\n        x = torch.cat([x_0, x_1], dim=0)\n        layer = layer.train()\n        _ = layer(x)\n\n        self.assertTrue(all(layer.embedding.weight[0] != original_weight_0))\n        self.assertTrue(all(layer.embedding.weight[1] == original_weight_1))\n\n\nclass TestVectorQuantizer(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_vector_quantizer_shape(self, input_param, input_shape, output_shape):\n        layer = VectorQuantizer(EMAQuantizer(**input_param))\n        x = torch.randn(input_shape)\n        outputs = layer(x)\n        self.assertEqual(outputs[1].shape, input_shape)\n\n    @parameterized.expand(TEST_CASES)\n    def test_vector_quantizer_quantize(self, input_param, input_shape, output_shape):\n        layer = VectorQuantizer(EMAQuantizer(**input_param))\n        x = torch.randn(input_shape)\n        outputs = layer.quantize(x)\n        self.assertEqual(outputs.shape, output_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/layers/test_weight_init.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers import trunc_normal_\n\nTEST_CASES = [\n    [{\"mean\": 0.0, \"std\": 1.0, \"a\": 2, \"b\": 4}, (6, 12, 3, 1, 7)],\n    [{\"mean\": 0.3, \"std\": 0.6, \"a\": -1.0, \"b\": 1.3}, (1, 4, 4, 4)],\n    [{\"mean\": 0.1, \"std\": 0.4, \"a\": 1.3, \"b\": 1.8}, (5, 7, 7, 8, 9)],\n]\n\nTEST_ERRORS = [\n    [{\"mean\": 0.0, \"std\": 1.0, \"a\": 5, \"b\": 1.1}, (1, 1, 8, 8, 8)],\n    [{\"mean\": 0.3, \"std\": -0.1, \"a\": 1.0, \"b\": 2.0}, (8, 5, 2, 6, 9)],\n    [{\"mean\": 0.7, \"std\": 0.0, \"a\": 0.1, \"b\": 2.0}, (4, 12, 23, 17)],\n]\n\n\nclass TestWeightInit(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_shape):\n        im = torch.rand(input_shape)\n        trunc_normal_(im, **input_param)\n        self.assertEqual(im.shape, input_shape)\n\n    @parameterized.expand(TEST_ERRORS)\n    def test_ill_arg(self, input_param, input_shape):\n        with self.assertRaises(ValueError):\n            im = torch.rand(input_shape)\n            trunc_normal_(im, **input_param)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/networks/nets/dints/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/networks/nets/dints/test_dints_cell.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.nets.dints import Cell\n\nTEST_CASES_3D = [\n    [\n        {\"c_prev\": 8, \"c\": 8, \"rate\": 1, \"arch_code_c\": None},\n        torch.tensor([1, 1, 1, 1, 1]),\n        torch.tensor([1, 1, 1, 1, 1]),\n        (2, 8, 32, 16, 8),\n        (2, 8, 64, 32, 16),\n    ],\n    [\n        {\"c_prev\": 8, \"c\": 4, \"rate\": 1, \"arch_code_c\": [1, 1, 0, 0, 1]},\n        torch.tensor([1, 1, 0, 0, 1]),\n        torch.tensor([1, 0.2, 1.3, 0, 1]),\n        (2, 8, 32, 16, 8),\n        (2, 4, 64, 32, 16),\n    ],\n    [\n        {\"c_prev\": 8, \"c\": 8, \"rate\": 0, \"arch_code_c\": None, \"act_name\": \"SELU\", \"norm_name\": \"BATCH\"},\n        torch.tensor([1, 1, 1, 1, 1]),\n        torch.tensor([0, 0, 0, 1, 0]),\n        (2, 8, 32, 16, 8),\n        (2, 8, 32, 16, 8),\n    ],\n    [\n        {\n            \"c_prev\": 8,\n            \"c\": 8,\n            \"rate\": -1,\n            \"arch_code_c\": None,\n            \"act_name\": \"PRELU\",\n            \"norm_name\": (\"BATCH\", {\"affine\": False}),\n        },\n        torch.tensor([1, 1, 1, 1, 1]),\n        torch.tensor([1, 1, 1, 1, 1]),\n        (2, 8, 32, 16, 8),\n        (2, 8, 16, 8, 4),\n    ],\n    [\n        {\"c_prev\": 8, \"c\": 8, \"rate\": -1, \"arch_code_c\": [1, 0, 0, 0, 1], 
\"act_name\": \"RELU\", \"norm_name\": \"INSTANCE\"},\n        torch.tensor([1, 0, 0, 0, 1]),\n        torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]),\n        (2, 8, 32, 16, 8),\n        (2, 8, 16, 8, 4),\n    ],\n]\n\nTEST_CASES_2D = [\n    [\n        {\n            \"c_prev\": 8,\n            \"c\": 7,\n            \"rate\": -1,\n            \"arch_code_c\": [1, 0, 0, 0, 1],\n            \"spatial_dims\": 2,\n            \"act_name\": \"PRELU\",\n            \"norm_name\": (\"BATCH\", {\"affine\": False}),\n        },\n        torch.tensor([1, 0]),\n        torch.tensor([0.2, 0.2]),\n        (2, 8, 16, 8),\n        (2, 7, 8, 4),\n    ],\n    [\n        {\n            \"c_prev\": 8,\n            \"c\": 8,\n            \"rate\": -1,\n            \"arch_code_c\": None,\n            \"spatial_dims\": 2,\n            \"act_name\": \"SELU\",\n            \"norm_name\": \"INSTANCE\",\n        },\n        torch.tensor([1, 0]),\n        torch.tensor([0.2, 0.2]),\n        (2, 8, 16, 8),\n        (2, 8, 8, 4),\n    ],\n]\n\n\nclass TestCell(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_2D + TEST_CASES_3D)\n    def test_cell_3d(self, input_param, ops, weight, input_shape, expected_shape):\n        net = Cell(**input_param)\n        result = net(torch.randn(input_shape), weight=weight)\n        self.assertEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/dints/test_dints_mixop.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.nets.dints import Cell, MixedOp\nfrom tests.test_utils import test_script_save\n\nTEST_CASES_3D = [\n    [\n        {\"c\": 8, \"arch_code_c\": None},\n        torch.tensor([1, 1, 1, 1, 1]),\n        torch.tensor([1, 1, 1, 1, 1]),\n        (2, 8, 32, 16, 8),\n        (2, 8, 32, 16, 8),\n    ],\n    [\n        {\"c\": 8, \"arch_code_c\": [1, 1, 0, 0, 1]},\n        torch.tensor([1, 1, 0, 0, 1]),\n        torch.tensor([1, 0.2, 1.3, 0, 1]),\n        (2, 8, 64, 32, 16),\n        (2, 8, 64, 32, 16),\n    ],\n    [\n        {\"c\": 8, \"arch_code_c\": None},\n        torch.tensor([1, 1, 1, 1, 1]),\n        torch.tensor([0, 0, 0, 1, 0]),\n        (2, 8, 32, 16, 8),\n        (2, 8, 32, 16, 8),\n    ],\n    [\n        {\"c\": 8, \"arch_code_c\": [1, 1, 1, 0, 1]},\n        torch.tensor([1, 1, 1, 1, 1]),\n        torch.tensor([0, 0, 0, 1, 0]),\n        (2, 8, 32, 16, 8),\n        (2, 8, 32, 16, 8),\n    ],\n]\nTEST_CASES_2D = [\n    [\n        {\"c\": 32, \"arch_code_c\": [1, 1, 1, 0, 1]},\n        torch.tensor([1, 1]),\n        torch.tensor([0, 0]),\n        (2, 32, 16, 8),\n        (2, 32, 16, 8),\n    ]\n]\n\n\nclass TestMixOP(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_3D)\n    def test_mixop_3d(self, input_param, ops, weight, 
input_shape, expected_shape):\n        net = MixedOp(ops=Cell.OPS3D, **input_param)\n        result = net(torch.randn(input_shape), weight=weight)\n        self.assertEqual(result.shape, expected_shape)\n        self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(TEST_CASES_2D)\n    def test_mixop_2d(self, input_param, ops, weight, input_shape, expected_shape):\n        net = MixedOp(ops=Cell.OPS2D, **input_param)\n        result = net(torch.randn(input_shape), weight=weight)\n        self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(TEST_CASES_3D)\n    def test_script(self, input_param, ops, weight, input_shape, expected_shape):\n        net = MixedOp(ops=Cell.OPS3D, **input_param)\n        test_script_save(net, torch.randn(input_shape), weight)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/regunet/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/networks/nets/regunet/test_localnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.regunet import LocalNet\nfrom tests.test_utils import test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_LOCALNET_2D = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 2,\n            \"num_channel_initial\": 16,\n            \"out_kernel_initializer\": \"kaiming_uniform\",\n            \"out_activation\": None,\n            \"out_channels\": 2,\n            \"extract_levels\": (0, 1),\n            \"pooling\": False,\n            \"concat_skip\": True,\n            \"mode\": \"bilinear\",\n            \"align_corners\": True,\n        },\n        (1, 2, 16, 16),\n        (1, 2, 16, 16),\n    ]\n]\n\nTEST_CASE_LOCALNET_3D = [\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 2,\n            \"num_channel_initial\": 16,\n            \"out_kernel_initializer\": \"zeros\",\n            \"out_activation\": \"sigmoid\",\n            \"out_channels\": 2,\n            \"extract_levels\": (0, 1, 2, 3),\n            \"pooling\": True,\n            \"concat_skip\": False,\n        },\n        (1, 2, 16, 16, 16),\n        (1, 2, 16, 16, 16),\n    ]\n]\n\n\nclass TestLocalNet(unittest.TestCase):\n    
@parameterized.expand(TEST_CASE_LOCALNET_2D + TEST_CASE_LOCALNET_3D)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = LocalNet(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(TEST_CASE_LOCALNET_2D + TEST_CASE_LOCALNET_3D)\n    def test_extract_levels(self, input_param, input_shape, expected_shape):\n        net = LocalNet(**input_param).to(device)\n        self.assertEqual(len(net.decode_deconvs), len(input_param[\"extract_levels\"]) - 1)\n        self.assertEqual(len(net.decode_convs), len(input_param[\"extract_levels\"]) - 1)\n\n    def test_script(self):\n        input_param, input_shape, _ = TEST_CASE_LOCALNET_2D[0]\n        net = LocalNet(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/regunet/test_regunet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.regunet import RegUNet\nfrom tests.test_utils import test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_REGUNET_2D = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 2,\n            \"num_channel_initial\": 16,\n            \"depth\": 3,\n            \"out_kernel_initializer\": \"kaiming_uniform\",\n            \"out_activation\": None,\n            \"out_channels\": 2,\n            \"pooling\": False,\n            \"concat_skip\": True,\n            \"encode_kernel_sizes\": 3,\n        },\n        (1, 2, 16, 16),\n        (1, 2, 16, 16),\n    ]\n]\n\nTEST_CASE_REGUNET_3D = [\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 2,\n            \"num_channel_initial\": 16,\n            \"depth\": 3,\n            \"out_kernel_initializer\": \"kaiming_uniform\",\n            \"out_activation\": \"sigmoid\",\n            \"out_channels\": 2,\n            \"extract_levels\": (0, 1, 2, 3),\n            \"pooling\": True,\n            \"concat_skip\": False,\n            \"encode_kernel_sizes\": (3, 3, 3, 7),\n        },\n        (1, 2, 16, 16, 16),\n        (1, 2, 16, 16, 16),\n    ]\n]\n\n\nclass 
TestREGUNET(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_REGUNET_2D + TEST_CASE_REGUNET_3D)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = RegUNet(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_shape(self):\n        with self.assertRaisesRegex(ValueError, \"\"):\n            input_param, _, _ = TEST_CASE_REGUNET_2D[0]\n            input_shape = (1, input_param[\"in_channels\"], 17, 17)\n            net = RegUNet(**input_param).to(device)\n            net.forward(torch.randn(input_shape).to(device))\n\n    def test_script(self):\n        input_param, input_shape, _ = TEST_CASE_REGUNET_2D[0]\n        net = RegUNet(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_ahnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks import FCN, MCFCN\nfrom monai.networks.nets import AHNet\nfrom tests.test_utils import skip_if_quick, test_pretrained_networks, test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_FCN_1 = [  # batch 2\n    {\"out_channels\": 3, \"upsample_mode\": \"transpose\", \"pretrained\": False},\n    (2, 3, 32, 32),\n    (2, 3, 32, 32),\n]\nTEST_CASE_FCN_2 = [\n    {\"out_channels\": 2, \"upsample_mode\": \"transpose\", \"pretrained\": False, \"progress\": False},\n    (1, 3, 32, 32),\n    (1, 2, 32, 32),\n]\nTEST_CASE_FCN_3 = [\n    {\"out_channels\": 1, \"upsample_mode\": \"bilinear\", \"pretrained\": False},\n    (1, 3, 32, 32),\n    (1, 1, 32, 32),\n]\nTEST_CASE_FCN_WITH_PRETRAIN_1 = [  # batch 2\n    {\"out_channels\": 3, \"upsample_mode\": \"transpose\", \"pretrained\": True},\n    (2, 3, 32, 32),\n    (2, 3, 32, 32),\n]\nTEST_CASE_FCN_WITH_PRETRAIN_2 = [\n    {\"out_channels\": 2, \"upsample_mode\": \"transpose\", \"pretrained\": True, \"progress\": False},\n    (1, 3, 32, 32),\n    (1, 2, 32, 32),\n]\n\nTEST_CASE_MCFCN_1 = [  # batch 5\n    {\"out_channels\": 3, \"in_channels\": 8, \"upsample_mode\": \"transpose\", \"pretrained\": False, \"progress\": 
False},\n    (5, 8, 32, 32),\n    (5, 3, 32, 32),\n]\nTEST_CASE_MCFCN_2 = [\n    {\"out_channels\": 2, \"in_channels\": 1, \"upsample_mode\": \"transpose\", \"pretrained\": False, \"progress\": True},\n    (1, 1, 32, 32),\n    (1, 2, 32, 32),\n]\nTEST_CASE_MCFCN_3 = [\n    {\"out_channels\": 1, \"in_channels\": 2, \"upsample_mode\": \"bilinear\", \"pretrained\": False},\n    (1, 2, 32, 32),\n    (1, 1, 32, 32),\n]\nTEST_CASE_MCFCN_WITH_PRETRAIN_1 = [  # batch 5\n    {\"out_channels\": 3, \"in_channels\": 8, \"upsample_mode\": \"transpose\", \"pretrained\": True, \"progress\": False},\n    (5, 8, 32, 32),\n    (5, 3, 32, 32),\n]\nTEST_CASE_MCFCN_WITH_PRETRAIN_2 = [\n    {\"out_channels\": 2, \"in_channels\": 1, \"upsample_mode\": \"transpose\", \"pretrained\": True, \"progress\": True},\n    (1, 1, 32, 32),\n    (1, 2, 32, 32),\n]\n\nTEST_CASE_AHNET_2D_1 = [\n    {\"spatial_dims\": 2, \"upsample_mode\": \"bilinear\", \"psp_block_num\": 0},\n    (1, 1, 32, 64),\n    (1, 1, 32, 64),\n]\nTEST_CASE_AHNET_2D_2 = [\n    {\"spatial_dims\": 2, \"upsample_mode\": \"transpose\", \"out_channels\": 2, \"psp_block_num\": 1},\n    (1, 1, 64, 32),\n    (1, 2, 64, 32),\n]\nTEST_CASE_AHNET_2D_3 = [\n    {\"spatial_dims\": 2, \"upsample_mode\": \"bilinear\", \"out_channels\": 2, \"psp_block_num\": 2},\n    (1, 1, 64, 32),\n    (1, 2, 64, 32),\n]\nTEST_CASE_AHNET_3D_1 = [\n    {\"spatial_dims\": 3, \"upsample_mode\": \"trilinear\", \"psp_block_num\": 0},\n    (2, 1, 32, 32, 64),\n    (2, 1, 32, 32, 64),\n]\nTEST_CASE_AHNET_3D_2 = [\n    {\"spatial_dims\": 3, \"upsample_mode\": \"transpose\", \"out_channels\": 2, \"psp_block_num\": 1},\n    (1, 1, 32, 32, 64),\n    (1, 2, 32, 32, 64),\n]\nTEST_CASE_AHNET_3D_3 = [\n    {\"spatial_dims\": 3, \"upsample_mode\": \"nearest\", \"out_channels\": 2, \"psp_block_num\": 3},\n    (1, 1, 96, 128, 32),\n    (1, 2, 96, 128, 32),\n]\nTEST_CASE_AHNET_3D_WITH_PRETRAIN_1 = [\n    {\"spatial_dims\": 3, \"upsample_mode\": \"trilinear\", \"psp_block_num\": 
0},\n    (2, 1, 32, 32, 64),\n    (2, 1, 32, 32, 64),\n    {\"out_channels\": 1, \"upsample_mode\": \"transpose\"},\n]\nTEST_CASE_AHNET_3D_WITH_PRETRAIN_2 = [\n    {\"spatial_dims\": 3, \"upsample_mode\": \"transpose\", \"out_channels\": 2, \"psp_block_num\": 3},\n    (1, 1, 64, 64, 64),\n    (1, 2, 64, 64, 64),\n    {\"out_channels\": 1, \"upsample_mode\": \"bilinear\"},\n]\nTEST_CASE_AHNET_3D_WITH_PRETRAIN_3 = [\n    {\"spatial_dims\": 3, \"upsample_mode\": \"transpose\", \"in_channels\": 2, \"out_channels\": 3},\n    (1, 2, 128, 128, 32),\n    (1, 3, 128, 128, 32),\n    {\"out_channels\": 1, \"upsample_mode\": \"bilinear\"},\n]\n\n\nclass TestFCN(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_FCN_1, TEST_CASE_FCN_2, TEST_CASE_FCN_3])\n    @skip_if_quick\n    def test_fcn_shape(self, input_param, input_shape, expected_shape):\n        net = FCN(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n\nclass TestFCNWithPretrain(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_FCN_WITH_PRETRAIN_1, TEST_CASE_FCN_WITH_PRETRAIN_2])\n    @skip_if_quick\n    def test_fcn_shape(self, input_param, input_shape, expected_shape):\n        net = test_pretrained_networks(FCN, input_param, device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n\nclass TestMCFCN(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_MCFCN_1, TEST_CASE_MCFCN_2, TEST_CASE_MCFCN_3])\n    def test_mcfcn_shape(self, input_param, input_shape, expected_shape):\n        net = MCFCN(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n\nclass TestMCFCNWithPretrain(unittest.TestCase):\n    
@parameterized.expand([TEST_CASE_MCFCN_WITH_PRETRAIN_1, TEST_CASE_MCFCN_WITH_PRETRAIN_2])\n    def test_mcfcn_shape(self, input_param, input_shape, expected_shape):\n        net = test_pretrained_networks(MCFCN, input_param, device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n\nclass TestAHNET(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_AHNET_2D_1, TEST_CASE_AHNET_2D_2, TEST_CASE_AHNET_2D_3])\n    def test_ahnet_shape_2d(self, input_param, input_shape, expected_shape):\n        net = AHNet(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand([TEST_CASE_AHNET_3D_1, TEST_CASE_AHNET_3D_2, TEST_CASE_AHNET_3D_3])\n    @skip_if_quick\n    def test_ahnet_shape_3d(self, input_param, input_shape, expected_shape):\n        net = AHNet(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    @skip_if_quick\n    def test_script(self):\n        # test 2D network\n        net = AHNet(spatial_dims=2, out_channels=2)\n        test_data = torch.randn(1, 1, 128, 64)\n        test_script_save(net, test_data)\n        # test 3D network\n        net = AHNet(spatial_dims=3, out_channels=2, psp_block_num=0, upsample_mode=\"nearest\")\n        test_data = torch.randn(1, 1, 32, 32, 64)\n        test_script_save(net, test_data)\n\n\nclass TestAHNETWithPretrain(unittest.TestCase):\n    @parameterized.expand(\n        [TEST_CASE_AHNET_3D_WITH_PRETRAIN_1, TEST_CASE_AHNET_3D_WITH_PRETRAIN_2, TEST_CASE_AHNET_3D_WITH_PRETRAIN_3]\n    )\n    def test_ahnet_shape(self, input_param, input_shape, expected_shape, fcn_input_param):\n        net = 
AHNet(**input_param).to(device)\n        net2d = FCN(**fcn_input_param).to(device)\n        net.copy_from(net2d)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    @skip_if_quick\n    def test_initialize_pretrained(self):\n        net = AHNet(\n            spatial_dims=3,\n            upsample_mode=\"transpose\",\n            in_channels=2,\n            out_channels=3,\n            psp_block_num=2,\n            pretrained=True,\n            progress=True,\n        ).to(device)\n        input_data = torch.randn(2, 2, 32, 32, 64).to(device)\n        with eval_mode(net):\n            result = net.forward(input_data)\n            self.assertEqual(result.shape, (2, 3, 32, 32, 64))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_attentionunet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nimport torch.nn as nn\n\nimport monai.networks.nets.attentionunet as att\nfrom tests.test_utils import skip_if_no_cuda, skip_if_quick\n\n\ndef get_net_parameters(net: nn.Module) -> int:\n    \"\"\"Returns the total number of parameters in a Module.\"\"\"\n    return sum(param.numel() for param in net.parameters())\n\n\nclass TestAttentionUnet(unittest.TestCase):\n    def test_attention_block(self):\n        for dims in [2, 3]:\n            block = att.AttentionBlock(dims, f_int=2, f_g=6, f_l=6)\n            shape = (4, 6) + (30,) * dims\n            x = torch.rand(*shape, dtype=torch.float32)\n            output = block(x, x)\n            self.assertEqual(output.shape, x.shape)\n\n            block = att.AttentionBlock(dims, f_int=2, f_g=3, f_l=6)\n            xshape = (4, 6) + (30,) * dims\n            x = torch.rand(*xshape, dtype=torch.float32)\n            gshape = (4, 3) + (30,) * dims\n            g = torch.rand(*gshape, dtype=torch.float32)\n            output = block(g, x)\n            self.assertEqual(output.shape, x.shape)\n\n    @skip_if_quick\n    def test_attentionunet(self):\n        for dims in [2, 3]:\n            shape = (3, 1) + (92,) * dims\n            input = torch.rand(*shape)\n            model = att.AttentionUnet(\n                spatial_dims=dims, in_channels=1, 
out_channels=2, channels=(3, 4, 5), up_kernel_size=5, strides=(1, 2)\n            )\n            output = model(input)\n            self.assertEqual(output.shape[2:], input.shape[2:])\n            self.assertEqual(output.shape[0], input.shape[0])\n            self.assertEqual(output.shape[1], 2)\n\n    def test_attentionunet_kernel_size(self):\n        args_dict = {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 2,\n            \"channels\": (3, 4, 5),\n            \"up_kernel_size\": 5,\n            \"strides\": (1, 2),\n        }\n        model_a = att.AttentionUnet(**args_dict, kernel_size=5)\n        model_b = att.AttentionUnet(**args_dict, kernel_size=7)\n        self.assertEqual(get_net_parameters(model_a), 3534)\n        self.assertEqual(get_net_parameters(model_b), 5574)\n\n    @skip_if_no_cuda\n    def test_attentionunet_gpu(self):\n        for dims in [2, 3]:\n            shape = (3, 1) + (92,) * dims\n            input = torch.rand(*shape).to(\"cuda:0\")\n            model = att.AttentionUnet(\n                spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2)\n            ).to(\"cuda:0\")\n            with torch.no_grad():\n                output = model(input)\n                self.assertEqual(output.shape[2:], input.shape[2:])\n                self.assertEqual(output.shape[0], input.shape[0])\n                self.assertEqual(output.shape[1], 2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_autoencoder.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.layers import Act\nfrom monai.networks.nets import AutoEncoder\nfrom tests.test_utils import test_script_save\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\nTEST_CASE_0 = [  # single channel 2D, batch 1, no residual\n    {\n        \"spatial_dims\": 2,\n        \"in_channels\": 1,\n        \"out_channels\": 1,\n        \"channels\": (4, 8, 16),\n        \"strides\": (2, 2, 2),\n        \"num_res_units\": 0,\n    },\n    (1, 1, 128, 128),\n    (1, 1, 128, 128),\n]\n\nTEST_CASE_1 = [  # single channel 2D, batch 1\n    {\"spatial_dims\": 2, \"in_channels\": 1, \"out_channels\": 1, \"channels\": (4, 8, 16), \"strides\": (2, 2, 2)},\n    (1, 1, 128, 128),\n    (1, 1, 128, 128),\n]\n\nTEST_CASE_2 = [  # 3-channel 2D, batch 1, LeakyReLU activation\n    {\n        \"spatial_dims\": 2,\n        \"in_channels\": 3,\n        \"out_channels\": 3,\n        \"channels\": (4, 8, 16),\n        \"strides\": (2, 2, 2),\n        \"act\": (Act.LEAKYRELU, {\"negative_slope\": 0.2}),\n    },\n    (1, 3, 128, 128),\n    (1, 3, 128, 128),\n]\n\nTEST_CASE_3 = [  # 4-channel 3D, batch 1\n    {\"spatial_dims\": 3, \"in_channels\": 4, \"out_channels\": 3, \"channels\": (4, 8, 16), \"strides\": 
(2, 2, 2)},\n    (1, 4, 128, 128, 128),\n    (1, 3, 128, 128, 128),\n]\n\nCASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]\n\n\nTEST_CASE_FAIL = {  # 2-channel 2D, should fail because of stride/channel mismatch.\n    \"spatial_dims\": 2,\n    \"in_channels\": 2,\n    \"out_channels\": 2,\n    \"channels\": (4, 8, 16),\n    \"strides\": (2, 2),\n}\n\n\nclass TestAutoEncoder(unittest.TestCase):\n    @parameterized.expand(CASES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = AutoEncoder(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        net = AutoEncoder(spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8), strides=(2, 2))\n        test_data = torch.randn(2, 1, 32, 32)\n        test_script_save(net, test_data)\n\n    def test_channel_stride_difference(self):\n        with self.assertRaises(ValueError):\n            AutoEncoder(**TEST_CASE_FAIL)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_autoencoderkl.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps import download_url\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import AutoencoderKL\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_downloading_fails, testing_data_config\n\ntqdm, has_tqdm = optional_import(\"tqdm\", name=\"tqdm\")\n_, has_einops = optional_import(\"einops\")\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\n\nCASES_NO_ATTENTION = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n        },\n        (1, 1, 16, 16),\n        (1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, 
False),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n        },\n        (1, 1, 16, 16, 16),\n        (1, 1, 16, 16, 16),\n        (1, 4, 4, 4, 4),\n    ],\n]\n\nCASES_ATTENTION = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n        },\n        (1, 1, 16, 16),\n        (1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"num_res_blocks\": (1, 1, 2),\n            \"norm_num_groups\": 4,\n        },\n        (1, 1, 16, 16),\n        (1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n        },\n        (1, 1, 16, 16),\n        (1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"with_encoder_nonlocal_attn\": False,\n        },\n        (1, 1, 16, 16),\n        
(1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, True),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n        },\n        (1, 1, 16, 16),\n        (1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, True),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n        },\n        (1, 1, 16, 16, 16),\n        (1, 1, 16, 16, 16),\n        (1, 4, 4, 4, 4),\n    ],\n]\n\nif has_einops:\n    CASES = CASES_NO_ATTENTION + CASES_ATTENTION\nelse:\n    CASES = CASES_NO_ATTENTION\n\n\nclass TestAutoEncoderKL(unittest.TestCase):\n    @parameterized.expand(CASES)\n    def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):\n        net = AutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result[0].shape, expected_shape)\n            self.assertEqual(result[1].shape, expected_latent_shape)\n            self.assertEqual(result[2].shape, expected_latent_shape)\n\n    @parameterized.expand(CASES)\n    def test_shape_with_convtranspose_and_checkpointing(\n        self, input_param, input_shape, expected_shape, expected_latent_shape\n    ):\n        input_param = input_param.copy()\n        input_param.update({\"use_checkpoint\": True, \"use_convtranspose\": True})\n        net = AutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = 
net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result[0].shape, expected_shape)\n            self.assertEqual(result[1].shape, expected_latent_shape)\n            self.assertEqual(result[2].shape, expected_latent_shape)\n\n    def test_model_channels_not_multiple_of_norm_num_group(self):\n        with self.assertRaises(ValueError):\n            AutoencoderKL(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                channels=(24, 24, 24),\n                attention_levels=(False, False, False),\n                latent_channels=8,\n                num_res_blocks=1,\n                norm_num_groups=16,\n            )\n\n    def test_model_num_channels_not_same_size_of_attention_levels(self):\n        with self.assertRaises(ValueError):\n            AutoencoderKL(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                channels=(24, 24, 24),\n                attention_levels=(False, False),\n                latent_channels=8,\n                num_res_blocks=1,\n                norm_num_groups=16,\n            )\n\n    def test_model_num_channels_not_same_size_of_num_res_blocks(self):\n        with self.assertRaises(ValueError):\n            AutoencoderKL(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                channels=(24, 24, 24),\n                attention_levels=(False, False, False),\n                latent_channels=8,\n                num_res_blocks=(8, 8),\n                norm_num_groups=16,\n            )\n\n    def test_shape_reconstruction(self):\n        input_param, input_shape, expected_shape, _ = CASES[0]\n        net = AutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.reconstruct(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def 
test_shape_reconstruction_with_convtranspose_and_checkpointing(self):\n        input_param, input_shape, expected_shape, _ = CASES[0]\n        input_param = input_param.copy()\n        input_param.update({\"use_checkpoint\": True, \"use_convtranspose\": True})\n        net = AutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.reconstruct(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_shape_encode(self):\n        input_param, input_shape, _, expected_latent_shape = CASES[0]\n        net = AutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.encode(torch.randn(input_shape).to(device))\n            self.assertEqual(result[0].shape, expected_latent_shape)\n            self.assertEqual(result[1].shape, expected_latent_shape)\n\n    def test_shape_encode_with_convtranspose_and_checkpointing(self):\n        input_param, input_shape, _, expected_latent_shape = CASES[0]\n        input_param = input_param.copy()\n        input_param.update({\"use_checkpoint\": True, \"use_convtranspose\": True})\n        net = AutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.encode(torch.randn(input_shape).to(device))\n            self.assertEqual(result[0].shape, expected_latent_shape)\n            self.assertEqual(result[1].shape, expected_latent_shape)\n\n    def test_shape_sampling(self):\n        input_param, _, _, expected_latent_shape = CASES[0]\n        net = AutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.sampling(\n                torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)\n            )\n            self.assertEqual(result.shape, expected_latent_shape)\n\n    def test_shape_sampling_convtranspose_and_checkpointing(self):\n        input_param, _, _, expected_latent_shape = CASES[0]\n       
 input_param = input_param.copy()\n        input_param.update({\"use_checkpoint\": True, \"use_convtranspose\": True})\n        net = AutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.sampling(\n                torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)\n            )\n            self.assertEqual(result.shape, expected_latent_shape)\n\n    def test_shape_decode(self):\n        input_param, expected_input_shape, _, latent_shape = CASES[0]\n        net = AutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.decode(torch.randn(latent_shape).to(device))\n            self.assertEqual(result.shape, expected_input_shape)\n\n    def test_shape_decode_convtranspose_and_checkpointing(self):\n        input_param, expected_input_shape, _, latent_shape = CASES[0]\n        input_param = input_param.copy()\n        input_param.update({\"use_checkpoint\": True, \"use_convtranspose\": True})\n        net = AutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.decode(torch.randn(latent_shape).to(device))\n            self.assertEqual(result.shape, expected_input_shape)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_compatibility_with_monai_generative(self):\n        # test loading weights from a model saved in MONAI Generative, version 0.2.3\n        with skip_if_downloading_fails():\n            net = AutoencoderKL(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                channels=(4, 4, 4),\n                latent_channels=4,\n                attention_levels=(False, False, True),\n                num_res_blocks=1,\n                norm_num_groups=4,\n            ).to(device)\n\n            tmpdir = tempfile.mkdtemp()\n            key = \"autoencoderkl_monai_generative_weights\"\n            url = 
testing_data_config(\"models\", key, \"url\")\n            hash_type = testing_data_config(\"models\", key, \"hash_type\")\n            hash_val = testing_data_config(\"models\", key, \"hash_val\")\n            filename = \"autoencoderkl_monai_generative_weights.pt\"\n\n            weight_path = os.path.join(tmpdir, filename)\n            download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)\n\n            net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_basic_unet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import BasicUNet\nfrom tests.test_utils import test_script_save\n\nCASES_1D = []\nfor mode in [\"pixelshuffle\", \"nontrainable\", \"deconv\", None]:\n    kwargs = {\"spatial_dims\": 1, \"in_channels\": 5, \"out_channels\": 8}\n    if mode is not None:\n        kwargs[\"upsample\"] = mode  # type: ignore\n    CASES_1D.append([kwargs, (10, 5, 33), (10, 8, 33)])\n\nCASES_2D = []\nfor mode in [\"pixelshuffle\", \"nontrainable\", \"deconv\"]:\n    for d1 in range(33, 64, 14):\n        for d2 in range(63, 33, -21):\n            in_channels, out_channels = 2, 3\n            CASES_2D.append(\n                [\n                    {\n                        \"spatial_dims\": 2,\n                        \"in_channels\": in_channels,\n                        \"out_channels\": out_channels,\n                        \"features\": (12, 12, 13, 14, 15, 16),\n                        \"upsample\": mode,\n                    },\n                    (2, in_channels, d1, d2),\n                    (2, out_channels, d1, d2),\n                ]\n            )\nCASES_3D = [\n    [  # single channel 3D, batch 2\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 
2,\n            \"features\": (16, 20, 21, 22, 23, 11),\n            \"upsample\": \"pixelshuffle\",\n        },\n        (2, 1, 33, 34, 35),\n        (2, 2, 33, 34, 35),\n    ],\n    [  # 2-channel 3D, batch 3\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 2,\n            \"out_channels\": 7,\n            \"features\": (14, 15, 16, 17, 18, 11),\n            \"upsample\": \"deconv\",\n        },\n        (3, 2, 33, 37, 34),\n        (3, 7, 33, 37, 34),\n    ],\n    [  # 4-channel 3D, batch 5\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 4,\n            \"out_channels\": 2,\n            \"features\": (14, 15, 16, 17, 18, 10),\n            \"upsample\": \"nontrainable\",\n        },\n        (5, 4, 34, 35, 37),\n        (5, 2, 34, 35, 37),\n    ],\n]\n\n\nclass TestBasicUNET(unittest.TestCase):\n    @parameterized.expand(CASES_1D + CASES_2D + CASES_3D)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        print(input_param)\n        net = BasicUNet(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n        self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=3)\n        test_data = torch.randn(16, 1, 32, 32)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_basic_unetplusplus.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import BasicUNetPlusPlus\nfrom tests.test_utils import test_script_save\n\nCASES_1D = []\nfor mode in [\"pixelshuffle\", \"nontrainable\", \"deconv\", None]:\n    kwargs = {\"spatial_dims\": 1, \"in_channels\": 5, \"out_channels\": 8}\n    if mode is not None:\n        kwargs[\"upsample\"] = mode  # type: ignore\n    CASES_1D.append([kwargs, (10, 5, 33), (10, 8, 33)])\n\nCASES_2D = []\nfor mode in [\"pixelshuffle\", \"nontrainable\", \"deconv\"]:\n    for d1 in range(33, 64, 14):\n        for d2 in range(63, 33, -21):\n            in_channels, out_channels = 2, 3\n            CASES_2D.append(\n                [\n                    {\n                        \"spatial_dims\": 2,\n                        \"in_channels\": in_channels,\n                        \"out_channels\": out_channels,\n                        \"features\": (12, 12, 13, 14, 15, 16),\n                        \"upsample\": mode,\n                    },\n                    (2, in_channels, d1, d2),\n                    (2, out_channels, d1, d2),\n                ]\n            )\nCASES_3D = [\n    [  # single channel 3D, batch 2\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            
\"out_channels\": 2,\n            \"features\": (16, 20, 21, 22, 23, 11),\n            \"upsample\": \"pixelshuffle\",\n        },\n        (2, 1, 33, 34, 35),\n        (2, 2, 33, 34, 35),\n    ],\n    [  # 2-channel 3D, batch 3\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 2,\n            \"out_channels\": 7,\n            \"features\": (14, 15, 16, 17, 18, 11),\n            \"upsample\": \"deconv\",\n        },\n        (3, 2, 33, 37, 34),\n        (3, 7, 33, 37, 34),\n    ],\n    [  # 4-channel 3D, batch 5\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 4,\n            \"out_channels\": 2,\n            \"features\": (14, 15, 16, 17, 18, 10),\n            \"upsample\": \"nontrainable\",\n        },\n        (5, 4, 34, 35, 37),\n        (5, 2, 34, 35, 37),\n    ],\n]\n\n\nclass TestBasicUNETPlusPlus(unittest.TestCase):\n    @parameterized.expand(CASES_1D + CASES_2D + CASES_3D)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        print(input_param)\n        net = BasicUNetPlusPlus(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n        self.assertEqual(result[0].shape, expected_shape)\n\n    def test_deep_supervision_shape(self):\n        net = BasicUNetPlusPlus(spatial_dims=2, deep_supervision=True, in_channels=3, out_channels=3)\n        test_data = torch.randn(16, 3, 32, 32)\n        with eval_mode(net):\n            result = net(test_data)\n        self.assertEqual(result[0].shape, test_data.shape)\n\n    def test_script(self):\n        net = BasicUNetPlusPlus(spatial_dims=2, deep_supervision=True, in_channels=1, out_channels=3)\n        test_data = torch.randn(16, 1, 32, 32)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_bundle_init_bundle.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport torch\n\nfrom monai.networks.nets import UNet\nfrom tests.test_utils import command_line_tests, skip_if_windows\n\n\n@skip_if_windows\nclass TestBundleInit(unittest.TestCase):\n    def test_bundle(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            net = UNet(2, 1, 1, [4, 8], [2])\n            torch.save(net.state_dict(), tempdir + \"/test.pt\")\n\n            bundle_root = tempdir + \"/test_bundle\"\n\n            cmd = [\n                \"coverage\",\n                \"run\",\n                \"-m\",\n                \"monai.bundle\",\n                \"init_bundle\",\n                bundle_root,\n                tempdir + \"/test.pt\",\n                \"--dataset_license\",\n                \"True\",\n            ]\n            command_line_tests(cmd)\n\n            self.assertTrue(os.path.exists(bundle_root + \"/configs/metadata.json\"))\n            self.assertTrue(os.path.exists(bundle_root + \"/configs/inference.json\"))\n            self.assertTrue(os.path.exists(bundle_root + \"/models/model.pt\"))\n            self.assertTrue(os.path.exists(bundle_root + \"/LICENSE\"))\n            self.assertTrue(os.path.exists(bundle_root + \"/docs/README.md\"))\n            self.assertTrue(os.path.exists(bundle_root + \"/docs/data_license.txt\"))\n\n\nif 
__name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_cell_sam_wrapper.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.cell_sam_wrapper import CellSamWrapper\nfrom monai.utils import optional_import\n\nbuild_sam_vit_b, has_sam = optional_import(\"segment_anything.build_sam\", name=\"build_sam_vit_b\")\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nTEST_CASE_CELLSEGWRAPPER = []\nfor dims in [128, 256, 512, 1024]:\n    test_case = [\n        {\"auto_resize_inputs\": True, \"network_resize_roi\": [1024, 1024], \"checkpoint\": None},\n        (1, 3, *([dims] * 2)),\n        (1, 3, *([dims] * 2)),\n    ]\n    TEST_CASE_CELLSEGWRAPPER.append(test_case)\n\n\n@unittest.skipUnless(has_sam, \"Requires SAM installation\")\nclass TestResNetDS(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASE_CELLSEGWRAPPER)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = CellSamWrapper(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape, msg=str(input_param))\n\n    def test_ill_arg0(self):\n        with self.assertRaises(RuntimeError):\n            net = CellSamWrapper(auto_resize_inputs=False, checkpoint=None).to(device)\n            net(torch.randn([1, 3, 
256, 256]).to(device))\n\n    def test_ill_arg1(self):\n        with self.assertRaises(RuntimeError):\n            net = CellSamWrapper(network_resize_roi=[256, 256], checkpoint=None).to(device)\n            net(torch.randn([1, 3, 1024, 1024]).to(device))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_checkpointunet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.layers import Act, Norm\nfrom monai.networks.nets.unet import CheckpointUNet, UNet\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_0 = [  # single channel 2D, batch 16, no residual\n    {\n        \"spatial_dims\": 2,\n        \"in_channels\": 1,\n        \"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 0,\n    },\n    (16, 1, 32, 32),\n    (16, 3, 32, 32),\n]\n\nTEST_CASE_1 = [  # single channel 2D, batch 16\n    {\n        \"spatial_dims\": 2,\n        \"in_channels\": 1,\n        \"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 1,\n    },\n    (16, 1, 32, 32),\n    (16, 3, 32, 32),\n]\n\nTEST_CASE_2 = [  # single channel 3D, batch 16\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 1,\n        \"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 1,\n    },\n    (16, 1, 32, 24, 48),\n    (16, 3, 32, 24, 48),\n]\n\nTEST_CASE_3 = [  # 4-channel 3D, batch 16\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 4,\n        \"out_channels\": 3,\n        
\"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 1,\n    },\n    (16, 4, 32, 64, 48),\n    (16, 3, 32, 64, 48),\n]\n\nTEST_CASE_4 = [  # 4-channel 3D, batch 16, batch normalization\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 4,\n        \"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 1,\n        \"norm\": Norm.BATCH,\n    },\n    (16, 4, 32, 64, 48),\n    (16, 3, 32, 64, 48),\n]\n\nTEST_CASE_5 = [  # 4-channel 3D, batch 16, LeakyReLU activation\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 4,\n        \"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 1,\n        \"act\": (Act.LEAKYRELU, {\"negative_slope\": 0.2}),\n        \"adn_ordering\": \"NA\",\n    },\n    (16, 4, 32, 64, 48),\n    (16, 3, 32, 64, 48),\n]\n\nTEST_CASE_6 = [  # 4-channel 3D, batch 16, LeakyReLU activation explicit\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 4,\n        \"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 1,\n        \"act\": (torch.nn.LeakyReLU, {\"negative_slope\": 0.2}),\n    },\n    (16, 4, 32, 64, 48),\n    (16, 3, 32, 64, 48),\n]\n\nCASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]\n\n\nclass TestCheckpointUNet(unittest.TestCase):\n    @parameterized.expand(CASES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        \"\"\"Validate CheckpointUNet output shapes across configurations.\n\n        Args:\n            input_param: Dictionary of UNet constructor arguments.\n            input_shape: Tuple specifying input tensor dimensions.\n            expected_shape: Tuple specifying expected output tensor dimensions.\n        \"\"\"\n        net = CheckpointUNet(**input_param).to(device)\n        with eval_mode(net):\n    
        result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_checkpointing_equivalence_eval(self):\n        \"\"\"Confirm eval parity when checkpointing is inactive.\"\"\"\n        params = dict(\n            spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2), num_res_units=1\n        )\n\n        x = torch.randn(2, 1, 32, 32, device=device)\n\n        torch.manual_seed(42)\n        net_plain = UNet(**params).to(device)\n\n        torch.manual_seed(42)\n        net_ckpt = CheckpointUNet(**params).to(device)\n\n        # Both in eval mode disables checkpointing logic\n        with eval_mode(net_ckpt), eval_mode(net_plain):\n            y_ckpt = net_ckpt(x)\n            y_plain = net_plain(x)\n\n        # Check shape equality\n        self.assertEqual(y_ckpt.shape, y_plain.shape)\n\n        # Check numerical equivalence\n        self.assertTrue(\n            torch.allclose(y_ckpt, y_plain, atol=1e-6, rtol=1e-5),\n            f\"Eval-mode outputs differ: max abs diff={torch.max(torch.abs(y_ckpt - y_plain)).item():.2e}\",\n        )\n\n    def test_checkpointing_activates_training(self):\n        \"\"\"Verify checkpointing recomputes activations during training.\"\"\"\n        params = dict(\n            spatial_dims=2, in_channels=1, out_channels=1, channels=(8, 16, 32), strides=(2, 2), num_res_units=1\n        )\n\n        net = CheckpointUNet(**params).to(device)\n        net.train()\n\n        x = torch.randn(2, 1, 32, 32, device=device, requires_grad=True)\n        y = net(x)\n        loss = y.mean()\n        loss.backward()\n\n        # gradient flow check\n        grad_norm = sum(p.grad.abs().sum() for p in net.parameters() if p.grad is not None)\n        self.assertGreater(grad_norm.item(), 0.0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_controlnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps import download_url\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.controlnet import ControlNet\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_downloading_fails, testing_data_config\n\n_, has_einops = optional_import(\"einops\")\nUNCOND_CASES_2D = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n        },\n        (1, 8, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n        },\n        (1, 8, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (4, 4, 4),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            
\"norm_num_groups\": 4,\n        },\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n        },\n        (1, 8, 4, 4),\n    ],\n]\n\nUNCOND_CASES_3D = [\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n        },\n        (1, 8, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (4, 4, 4),\n            \"num_head_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 4,\n            \"resblock_updown\": True,\n        },\n        (1, 4, 4, 4, 4),\n    ],\n]\n\nCOND_CASES_2D = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n        },\n        (1, 8, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            
\"resblock_updown\": True,\n        },\n        (1, 8, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            \"upcast_attention\": True,\n        },\n        (1, 8, 4, 4),\n    ],\n]\n\n\nclass TestControlNet(unittest.TestCase):\n    @parameterized.expand(UNCOND_CASES_2D + UNCOND_CASES_3D)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_unconditioned_models(self, input_param, expected_output_shape):\n        input_param[\"conditioning_embedding_in_channels\"] = input_param[\"in_channels\"]\n        input_param[\"conditioning_embedding_num_channels\"] = (input_param[\"channels\"][0],)\n        net = ControlNet(**input_param)\n        with eval_mode(net):\n            x = torch.rand((1, 1) + (16,) * input_param[\"spatial_dims\"])\n            timesteps = torch.randint(0, 1000, (1,)).long()\n            controlnet_cond = torch.rand((1, 1) + (16,) * input_param[\"spatial_dims\"])\n            result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond)\n            self.assertEqual(len(result[0]), 2 * len(input_param[\"channels\"]))\n            self.assertEqual(result[1].shape, expected_output_shape)\n\n    @parameterized.expand(COND_CASES_2D)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_conditioned_models(self, input_param, expected_output_shape):\n        input_param[\"conditioning_embedding_in_channels\"] = input_param[\"in_channels\"]\n        input_param[\"conditioning_embedding_num_channels\"] = (input_param[\"channels\"][0],)\n        net = ControlNet(**input_param)\n        with eval_mode(net):\n            x = torch.rand((1, 1) + (16,) * 
input_param[\"spatial_dims\"])\n            timesteps = torch.randint(0, 1000, (1,)).long()\n            controlnet_cond = torch.rand((1, 1) + (16,) * input_param[\"spatial_dims\"])\n            result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond, context=torch.rand((1, 1, 3)))\n            self.assertEqual(len(result[0]), 2 * len(input_param[\"channels\"]))\n            self.assertEqual(result[1].shape, expected_output_shape)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_compatibility_with_monai_generative(self):\n        # test loading weights from a model saved in MONAI Generative, version 0.2.3\n        with skip_if_downloading_fails():\n            net = ControlNet(\n                spatial_dims=2,\n                in_channels=1,\n                num_res_blocks=1,\n                channels=(8, 8, 8),\n                attention_levels=(False, False, True),\n                norm_num_groups=8,\n                with_conditioning=True,\n                transformer_num_layers=1,\n                cross_attention_dim=3,\n                resblock_updown=True,\n            )\n\n            tmpdir = tempfile.mkdtemp()\n            key = \"controlnet_monai_generative_weights\"\n            url = testing_data_config(\"models\", key, \"url\")\n            hash_type = testing_data_config(\"models\", key, \"hash_type\")\n            hash_val = testing_data_config(\"models\", key, \"hash_val\")\n            filename = \"controlnet_monai_generative_weights.pt\"\n\n            weight_path = os.path.join(tmpdir, filename)\n            download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)\n\n            net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_daf3d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import DAF3D\nfrom monai.utils import optional_import\nfrom tests.test_utils import test_script_save\n\n_, has_tv = optional_import(\"torchvision\")\n\nTEST_CASES = [\n    [{\"in_channels\": 1, \"out_channels\": 1}, (1, 1, 32, 32, 64), (1, 1, 32, 32, 64)],  # single channel 3D, batch 1\n    [{\"in_channels\": 2, \"out_channels\": 1}, (3, 2, 32, 64, 128), (3, 1, 32, 64, 128)],  # two channel 3D, batch 3\n    [\n        {\"in_channels\": 2, \"out_channels\": 2},\n        (3, 2, 32, 64, 128),\n        (3, 2, 32, 64, 128),\n    ],  # two channel 3D, same in & out channels\n    [{\"in_channels\": 4, \"out_channels\": 1}, (5, 4, 35, 35, 35), (5, 1, 35, 35, 35)],  # four channel 3D, batch 5\n    [\n        {\"in_channels\": 4, \"out_channels\": 4},\n        (5, 4, 35, 35, 35),\n        (5, 4, 35, 35, 35),\n    ],  # four channel 3D, same in & out channels\n]\n\n\n@unittest.skipUnless(has_tv, \"torchvision not installed\")\nclass TestDAF3D(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        print(input_param)\n        net = 
DAF3D(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n        self.assertEqual(result.shape, expected_shape)\n\n    @unittest.skip(\"daf3d: torchscript not currently supported\")\n    def test_script(self):\n        net = DAF3D(in_channels=1, out_channels=1)\n        test_data = torch.randn(16, 1, 32, 32)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_densenet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom typing import TYPE_CHECKING\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import DenseNet121, Densenet169, DenseNet264, densenet201\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_downloading_fails, skip_if_quick, test_script_save\n\nif TYPE_CHECKING:\n    import torchvision\n\n    has_torchvision = True\nelse:\n    torchvision, has_torchvision = optional_import(\"torchvision\")\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_1 = [  # 4-channel 3D, batch 2\n    {\"pretrained\": False, \"spatial_dims\": 3, \"in_channels\": 2, \"out_channels\": 3, \"norm\": (\"instance\", {\"eps\": 1e-5})},\n    (2, 2, 32, 64, 48),\n    (2, 3),\n]\n\nTEST_CASE_2 = [  # 4-channel 2D, batch 2\n    {\"pretrained\": False, \"spatial_dims\": 2, \"in_channels\": 2, \"out_channels\": 3, \"act\": \"PRELU\"},\n    (2, 2, 32, 64),\n    (2, 3),\n]\n\nTEST_CASE_3 = [  # 4-channel 1D, batch 1\n    {\"pretrained\": False, \"spatial_dims\": 1, \"in_channels\": 2, \"out_channels\": 3},\n    (1, 2, 32),\n    (1, 3),\n]\n\nTEST_CASES = []\nfor case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]:\n    for model in [DenseNet121, Densenet169, densenet201, DenseNet264]:\n        
TEST_CASES.append([model, *case])\n\nTEST_SCRIPT_CASES = [[model, *TEST_CASE_1] for model in [DenseNet121, Densenet169, densenet201, DenseNet264]]\n\nTEST_PRETRAINED_2D_CASE_1 = [  # 4-channel 2D, batch 2\n    DenseNet121,\n    {\"pretrained\": True, \"progress\": True, \"spatial_dims\": 2, \"in_channels\": 2, \"out_channels\": 3},\n    (1, 2, 32, 64),\n    (1, 3),\n]\n\nTEST_PRETRAINED_2D_CASE_2 = [  # 4-channel 2D, batch 2\n    DenseNet121,\n    {\"pretrained\": True, \"progress\": False, \"spatial_dims\": 2, \"in_channels\": 2, \"out_channels\": 1},\n    (1, 2, 32, 64),\n    (1, 1),\n]\n\nTEST_PRETRAINED_2D_CASE_3 = [\n    DenseNet121,\n    {\"pretrained\": True, \"progress\": False, \"spatial_dims\": 2, \"in_channels\": 3, \"out_channels\": 1},\n    (1, 3, 32, 32),\n]\n\n\nclass TestPretrainedDENSENET(unittest.TestCase):\n    @parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2])\n    @skip_if_quick\n    def test_121_2d_shape_pretrain(self, model, input_param, input_shape, expected_shape):\n        with skip_if_downloading_fails():\n            net = model(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand([TEST_PRETRAINED_2D_CASE_3])\n    @skipUnless(has_torchvision, \"Requires `torchvision` package.\")\n    def test_pretrain_consistency(self, model, input_param, input_shape):\n        example = torch.randn(input_shape).to(device)\n        with skip_if_downloading_fails():\n            net = model(**input_param).to(device)\n        with eval_mode(net):\n            result = net.features.forward(example)\n        torchvision_net = torchvision.models.densenet121(weights=\"DEFAULT\").to(device)\n        with eval_mode(torchvision_net):\n            expected_result = torchvision_net.features.forward(example)\n        self.assertTrue(torch.all(result == 
expected_result))\n\n\nclass TestDENSENET(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_densenet_shape(self, model, input_param, input_shape, expected_shape):\n        net = model(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(TEST_SCRIPT_CASES)\n    def test_script(self, model, input_param, input_shape, expected_shape):\n        net = model(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_diffusion_model_unet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps import download_url\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import DiffusionModelUNet\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_downloading_fails, testing_data_config\n\n_, has_einops = optional_import(\"einops\")\n\nUNCOND_CASES_2D = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": (1, 1, 2),\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            
\"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, True, True),\n            \"num_head_channels\": (0, 2, 4),\n            \"norm_num_groups\": 8,\n        }\n    ],\n]\n\nUNCOND_CASES_3D = [\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n 
           \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": (0, 0, 4),\n            \"norm_num_groups\": 8,\n        }\n    ],\n]\n\nCOND_CASES_2D = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            
\"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            \"resblock_updown\": True,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            \"upcast_attention\": True,\n        }\n    ],\n]\n\nDROPOUT_OK = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            \"dropout_cattn\": 0.25,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            
\"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n        }\n    ],\n]\n\nDROPOUT_WRONG = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            \"dropout_cattn\": 3.0,\n        }\n    ]\n]\n\n\nclass TestDiffusionModelUNet2D(unittest.TestCase):\n    @parameterized.expand(UNCOND_CASES_2D)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_unconditioned_models(self, input_param):\n        net = DiffusionModelUNet(**input_param)\n        with eval_mode(net):\n            result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long())\n            self.assertEqual(result.shape, (1, 1, 16, 16))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_timestep_with_wrong_shape(self):\n        net = DiffusionModelUNet(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, False),\n            norm_num_groups=8,\n        )\n        with self.assertRaises(ValueError):\n            with eval_mode(net):\n                net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long())\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_with_different_in_channel_out_channel(self):\n        in_channels = 6\n        out_channels = 3\n        net = DiffusionModelUNet(\n            spatial_dims=2,\n            in_channels=in_channels,\n            
out_channels=out_channels,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, False),\n            norm_num_groups=8,\n        )\n        with eval_mode(net):\n            result = net.forward(torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long())\n            self.assertEqual(result.shape, (1, out_channels, 16, 16))\n\n    def test_model_channels_not_multiple_of_norm_num_group(self):\n        with self.assertRaises(ValueError):\n            DiffusionModelUNet(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                channels=(8, 8, 12),\n                attention_levels=(False, False, False),\n                norm_num_groups=8,\n            )\n\n    def test_attention_levels_with_different_length_num_head_channels(self):\n        with self.assertRaises(ValueError):\n            DiffusionModelUNet(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                channels=(8, 8, 8),\n                attention_levels=(False, False, False),\n                num_head_channels=(0, 2),\n                norm_num_groups=8,\n            )\n\n    def test_num_res_blocks_with_different_length_channels(self):\n        with self.assertRaises(ValueError):\n            DiffusionModelUNet(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=(1, 1),\n                channels=(8, 8, 8),\n                attention_levels=(False, False, False),\n                norm_num_groups=8,\n            )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_conditioned_models(self):\n        net = DiffusionModelUNet(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            
channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            with_conditioning=True,\n            transformer_num_layers=1,\n            cross_attention_dim=3,\n            norm_num_groups=8,\n            num_head_channels=8,\n        )\n        with eval_mode(net):\n            result = net.forward(\n                x=torch.rand((1, 1, 16, 32)),\n                timesteps=torch.randint(0, 1000, (1,)).long(),\n                context=torch.rand((1, 1, 3)),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 32))\n\n    def test_with_conditioning_cross_attention_dim_none(self):\n        with self.assertRaises(ValueError):\n            DiffusionModelUNet(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                channels=(8, 8, 8),\n                attention_levels=(False, False, True),\n                with_conditioning=True,\n                transformer_num_layers=1,\n                cross_attention_dim=None,\n                norm_num_groups=8,\n            )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_context_with_conditioning_none(self):\n        net = DiffusionModelUNet(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            with_conditioning=False,\n            transformer_num_layers=1,\n            norm_num_groups=8,\n        )\n\n        with self.assertRaises(ValueError):\n            with eval_mode(net):\n                net.forward(\n                    x=torch.rand((1, 1, 16, 32)),\n                    timesteps=torch.randint(0, 1000, (1,)).long(),\n                    context=torch.rand((1, 1, 3)),\n                )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_conditioned_models_class_conditioning(self):\n        net 
= DiffusionModelUNet(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            norm_num_groups=8,\n            num_head_channels=8,\n            num_class_embeds=2,\n        )\n        with eval_mode(net):\n            result = net.forward(\n                x=torch.rand((1, 1, 16, 32)),\n                timesteps=torch.randint(0, 1000, (1,)).long(),\n                class_labels=torch.randint(0, 2, (1,)).long(),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 32))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_conditioned_models_no_class_labels(self):\n        net = DiffusionModelUNet(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            norm_num_groups=8,\n            num_head_channels=8,\n            num_class_embeds=2,\n        )\n\n        with self.assertRaises(ValueError):\n            net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long())\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_model_channels_not_same_size_of_attention_levels(self):\n        with self.assertRaises(ValueError):\n            DiffusionModelUNet(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                channels=(8, 8, 8),\n                attention_levels=(False, False),\n                norm_num_groups=8,\n                num_head_channels=8,\n                num_class_embeds=2,\n            )\n\n    @parameterized.expand(COND_CASES_2D)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_conditioned_2d_models_shape(self, input_param):\n        net = 
DiffusionModelUNet(**input_param)\n        with eval_mode(net):\n            result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3)))\n            self.assertEqual(result.shape, (1, 1, 16, 16))\n\n\nclass TestDiffusionModelUNet3D(unittest.TestCase):\n    @parameterized.expand(UNCOND_CASES_3D)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_unconditioned_models(self, input_param):\n        net = DiffusionModelUNet(**input_param)\n        with eval_mode(net):\n            result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long())\n            self.assertEqual(result.shape, (1, 1, 16, 16, 16))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_with_different_in_channel_out_channel(self):\n        in_channels = 6\n        out_channels = 3\n        net = DiffusionModelUNet(\n            spatial_dims=3,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            norm_num_groups=4,\n        )\n        with eval_mode(net):\n            result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long())\n            self.assertEqual(result.shape, (1, out_channels, 16, 16, 16))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_conditioned_models(self):\n        net = DiffusionModelUNet(\n            spatial_dims=3,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            channels=(16, 16, 16),\n            attention_levels=(False, False, True),\n            norm_num_groups=16,\n            with_conditioning=True,\n            transformer_num_layers=1,\n            cross_attention_dim=3,\n        )\n        with eval_mode(net):\n            result = net.forward(\n                x=torch.rand((1, 1, 16, 16, 
16)),\n                timesteps=torch.randint(0, 1000, (1,)).long(),\n                context=torch.rand((1, 1, 3)),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 16, 16))\n\n    # Test dropout specification for cross-attention blocks\n    @parameterized.expand(DROPOUT_WRONG)\n    def test_wrong_dropout(self, input_param):\n        with self.assertRaises(ValueError):\n            _ = DiffusionModelUNet(**input_param)\n\n    @parameterized.expand(DROPOUT_OK)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_right_dropout(self, input_param):\n        _ = DiffusionModelUNet(**input_param)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_compatibility_with_monai_generative(self):\n        # test loading weights from a model saved in MONAI Generative, version 0.2.3\n        with skip_if_downloading_fails():\n            net = DiffusionModelUNet(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                channels=(8, 8, 8),\n                attention_levels=(False, False, True),\n                with_conditioning=True,\n                cross_attention_dim=3,\n                transformer_num_layers=1,\n                norm_num_groups=8,\n            )\n\n            tmpdir = tempfile.mkdtemp()\n            key = \"diffusion_model_unet_monai_generative_weights\"\n            url = testing_data_config(\"models\", key, \"url\")\n            hash_type = testing_data_config(\"models\", key, \"hash_type\")\n            hash_val = testing_data_config(\"models\", key, \"hash_val\")\n            filename = \"diffusion_model_unet_monai_generative_weights.pt\"\n\n            weight_path = os.path.join(tmpdir, filename)\n            download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)\n\n            net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False)\n\n\nif __name__ == 
\"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_dints_network.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.nets import DiNTS, TopologyInstance, TopologySearch\nfrom monai.networks.nets.dints import Cell\nfrom tests.test_utils import skip_if_quick, test_script_save\n\nTEST_CASES_3D = [\n    [\n        {\n            \"channel_mul\": 0.2,\n            \"num_blocks\": 6,\n            \"num_depths\": 3,\n            \"device\": \"cpu\",\n            \"use_downsample\": False,\n            \"spatial_dims\": 3,\n        },\n        {\n            \"in_channels\": 1,\n            \"num_classes\": 3,\n            \"act_name\": \"RELU\",\n            \"norm_name\": (\"INSTANCE\", {\"affine\": True}),\n            \"use_downsample\": False,\n            \"spatial_dims\": 3,\n        },\n        (3, 1, 32, 32, 16),\n        (3, 3, 32, 32, 16),\n    ]\n]\nif torch.cuda.is_available():\n    TEST_CASES_3D += [\n        [\n            {\n                \"channel_mul\": 0.5,\n                \"num_blocks\": 7,\n                \"num_depths\": 4,\n                \"device\": \"cuda\",\n                \"use_downsample\": True,\n                \"spatial_dims\": 3,\n            },\n            {\n                \"in_channels\": 2,\n                \"num_classes\": 2,\n                \"act_name\": \"PRELU\",\n                \"norm_name\": 
\"BATCH\",\n                \"use_downsample\": True,\n                \"spatial_dims\": 3,\n            },\n            (3, 2, 32, 32, 16),\n            (3, 2, 32, 32, 16),\n        ]\n    ]\nTEST_CASES_2D = [\n    [\n        {\n            \"channel_mul\": 1,\n            \"num_blocks\": 7,\n            \"num_depths\": 4,\n            \"device\": \"cpu\",\n            \"use_downsample\": True,\n            \"spatial_dims\": 2,\n        },\n        {\n            \"in_channels\": 2,\n            \"num_classes\": 2,\n            \"act_name\": \"PRELU\",\n            \"norm_name\": \"BATCH\",\n            \"use_downsample\": True,\n            \"spatial_dims\": 2,\n        },\n        (2, 2, 32, 16),\n        (2, 2, 32, 16),\n    ]\n]\nif torch.cuda.is_available():\n    TEST_CASES_2D += [\n        [\n            {\n                \"channel_mul\": 0.5,\n                \"num_blocks\": 8,\n                \"num_depths\": 4,\n                \"device\": \"cuda\",\n                \"use_downsample\": False,\n                \"spatial_dims\": 2,\n            },\n            {\n                \"in_channels\": 1,\n                \"num_classes\": 4,\n                \"act_name\": \"RELU\",\n                \"norm_name\": (\"INSTANCE\", {\"affine\": True}),\n                \"use_downsample\": False,\n                \"spatial_dims\": 2,\n            },\n            (2, 1, 32, 16),\n            (2, 4, 32, 16),\n        ]\n    ]\n\n\n@skip_if_quick\nclass TestDints(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D)\n    def test_dints_inference(self, dints_grid_params, dints_params, input_shape, expected_shape):\n        grid = TopologySearch(**dints_grid_params)\n        dints_params[\"dints_space\"] = grid\n        net = DiNTS(**dints_params).to(dints_grid_params[\"device\"])\n        result = net(torch.randn(input_shape).to(dints_grid_params[\"device\"]))\n        self.assertEqual(result.shape, expected_shape)\n        # test functions\n       
 grid.get_ram_cost_usage(in_size=input_shape, full=True)\n        grid.get_ram_cost_usage(in_size=input_shape, full=False)\n        probs_a, _ = grid.get_prob_a(child=True)\n        grid.get_topology_entropy(probs_a)\n        grid.decode()\n        grid.gen_mtx(depth=4)\n\n    @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D)\n    def test_dints_search(self, dints_grid_params, dints_params, input_shape, expected_shape):\n        num_blocks = dints_grid_params[\"num_blocks\"]\n        num_depths = dints_grid_params[\"num_depths\"]\n        # init a Cell to obtain cell operation number\n        _cell = Cell(1, 1, 0, spatial_dims=dints_grid_params[\"spatial_dims\"])\n        num_cell_ops = len(_cell.OPS)\n        # define architecture codes\n        node_a = torch.ones((num_blocks + 1, num_depths))\n        arch_code_a = np.ones((num_blocks, 3 * num_depths - 2))\n        arch_code_c = np.random.randint(num_cell_ops, size=(num_blocks, 3 * num_depths - 2))\n        # initialize with codes\n        dints_grid_params[\"arch_code\"] = [arch_code_a, arch_code_c]\n        grid = TopologyInstance(**dints_grid_params)\n        # set as deploy stage\n        dints_params[\"dints_space\"] = grid\n        dints_params[\"node_a\"] = node_a\n        net = DiNTS(**dints_params).to(dints_grid_params[\"device\"])\n        result = net(torch.randn(input_shape).to(dints_grid_params[\"device\"]))\n        self.assertEqual(result.shape, expected_shape)\n        self.assertTrue(isinstance(net.weight_parameters(), list))\n\n\nclass TestDintsTS(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_3D + TEST_CASES_2D)\n    def test_script(self, dints_grid_params, dints_params, input_shape, _):\n        grid = TopologyInstance(**dints_grid_params)\n        dints_grid_params[\"device\"] = \"cpu\"\n        dints_params[\"dints_space\"] = grid\n        net = DiNTS(**dints_params).to(dints_grid_params[\"device\"])\n        test_script_save(net, 
torch.randn(input_shape).to(dints_grid_params[\"device\"]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_discriminator.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import Discriminator\nfrom tests.test_utils import test_script_save\n\nTEST_CASE_0 = [\n    {\"in_shape\": (1, 64, 64), \"channels\": (2, 4, 8), \"strides\": (2, 2, 2), \"num_res_units\": 0},\n    torch.rand(16, 1, 64, 64),\n    (16, 1),\n]\n\nTEST_CASE_1 = [\n    {\"in_shape\": (1, 64, 64), \"channels\": (2, 4, 8), \"strides\": (2, 2, 2), \"num_res_units\": 2},\n    torch.rand(16, 1, 64, 64),\n    (16, 1),\n]\n\nTEST_CASE_2 = [\n    {\"in_shape\": (1, 64, 64), \"channels\": (2, 4), \"strides\": (2, 2), \"num_res_units\": 0},\n    torch.rand(16, 1, 64, 64),\n    (16, 1),\n]\n\nCASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]\n\n\nclass TestDiscriminator(unittest.TestCase):\n    @parameterized.expand(CASES)\n    def test_shape(self, input_param, input_data, expected_shape):\n        net = Discriminator(**input_param)\n        with eval_mode(net):\n            result = net.forward(input_data)\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        net = Discriminator(in_shape=(1, 64, 64), channels=(2, 4), strides=(2, 2), num_res_units=0)\n        test_data = torch.rand(16, 1, 64, 64)\n        test_script_save(net, test_data)\n\n\nif __name__ == 
\"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_dynunet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport platform\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import DynUNet\nfrom monai.utils import optional_import\nfrom tests.test_utils import assert_allclose, dict_product, skip_if_no_cuda, skip_if_windows, test_script_save\n\nInstanceNorm3dNVFuser, _ = optional_import(\"apex.normalization\", name=\"InstanceNorm3dNVFuser\")\n\nON_AARCH64 = platform.machine() == \"aarch64\"\nif ON_AARCH64:\n    rtol, atol = 1e-2, 1e-2\nelse:\n    rtol, atol = 1e-4, 1e-4\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_DYNUNET_2D = []\nout_channels_2d = 2\nin_size_2d = 64\nspatial_dims_2d = 2\nfor params in dict_product(\n    kernel_size=[(3, 3, 3, 1), ((3, 1), 1, (3, 3), (1, 1))],\n    strides=[(1, 1, 1, 1), (2, 2, 2, 1)],\n    in_channels=[2, 3],\n    res_block=[True, False],\n):\n    kernel_size = params[\"kernel_size\"]\n    strides = params[\"strides\"]\n    in_channels = params[\"in_channels\"]\n    res_block = params[\"res_block\"]\n    expected_shape = (1, out_channels_2d, *[in_size_2d // strides[0]] * spatial_dims_2d)\n    test_case = [\n        {\n            \"spatial_dims\": spatial_dims_2d,\n            \"in_channels\": in_channels,\n            \"out_channels\": out_channels_2d,\n            \"kernel_size\": 
kernel_size,\n            \"strides\": strides,\n            \"upsample_kernel_size\": strides[1:],\n            \"norm_name\": \"batch\",\n            \"act_name\": (\"leakyrelu\", {\"inplace\": True, \"negative_slope\": 0.2}),\n            \"deep_supervision\": False,\n            \"res_block\": res_block,\n            \"dropout\": None,\n        },\n        (1, in_channels, in_size_2d, in_size_2d),\n        expected_shape,\n    ]\n    TEST_CASE_DYNUNET_2D.append(test_case)\n\nTEST_CASE_DYNUNET_3D = []  # in 3d cases, also test anisotropic kernel/strides\nin_channels_3d = 1\nin_size_3d = 64\nfor params in dict_product(out_channels=[2, 3], res_block=[True, False]):\n    out_channels = params[\"out_channels\"]\n    res_block = params[\"res_block\"]\n    expected_shape = (1, out_channels, 64, 32, 64)\n    test_case = [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": in_channels_3d,\n            \"out_channels\": out_channels,\n            \"kernel_size\": (3, (1, 1, 3), 3, 3),\n            \"strides\": ((1, 2, 1), 2, 2, 1),\n            \"upsample_kernel_size\": (2, 2, 1),\n            \"filters\": (64, 96, 128, 192),\n            \"norm_name\": (\"INSTANCE\", {\"affine\": True}),\n            \"deep_supervision\": True,\n            \"res_block\": res_block,\n            \"dropout\": (\"alphadropout\", {\"p\": 0.25}),\n        },\n        (1, in_channels_3d, in_size_3d, in_size_3d, in_size_3d),\n        expected_shape,\n    ]\n    TEST_CASE_DYNUNET_3D.append(test_case)\n\nTEST_CASE_DEEP_SUPERVISION = []\nin_size_ds = 64\nfor params in dict_product(\n    spatial_dims=[2, 3],\n    res_block=[True, False],\n    deep_supr_num=[1, 2],\n    strides=[(1, 2, 1, 2, 1), (2, 2, 2, 1), (2, 1, 1, 2, 2)],\n):\n    spatial_dims = params[\"spatial_dims\"]\n    res_block = params[\"res_block\"]\n    deep_supr_num = params[\"deep_supr_num\"]\n    strides = params[\"strides\"]\n    scale = strides[0]\n    test_case = [\n        {\n            
\"spatial_dims\": spatial_dims,\n            \"in_channels\": 1,\n            \"out_channels\": 2,\n            \"kernel_size\": [3] * len(strides),\n            \"strides\": strides,\n            \"upsample_kernel_size\": strides[1:],\n            \"norm_name\": (\"group\", {\"num_groups\": 16}),\n            \"deep_supervision\": True,\n            \"deep_supr_num\": deep_supr_num,\n            \"res_block\": res_block,\n        },\n        (1, 1, *[in_size_ds] * spatial_dims),\n        (1, 1 + deep_supr_num, 2, *[in_size_ds // scale] * spatial_dims),\n    ]\n    TEST_CASE_DEEP_SUPERVISION.append(test_case)\n\n\nclass TestDynUNet(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_DYNUNET_3D)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = DynUNet(**input_param).to(device)\n        if \"alphadropout\" in input_param.get(\"dropout\"):\n            self.assertTrue(any(isinstance(x, torch.nn.AlphaDropout) for x in net.modules()))\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        input_param, input_shape, _ = TEST_CASE_DYNUNET_2D[0]\n        net = DynUNet(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\n@skip_if_no_cuda\n@skip_if_windows\nclass TestDynUNetWithInstanceNorm3dNVFuser(unittest.TestCase):\n    def setUp(self):\n        try:\n            layer = InstanceNorm3dNVFuser(num_features=1, affine=False).to(\"cuda:0\")\n            inp = torch.randn([1, 1, 1, 1, 1]).to(\"cuda:0\")\n            out = layer(inp)\n            del inp, out, layer\n        except Exception:\n            self.skipTest(\"NVFuser not available\")\n\n    @parameterized.expand([TEST_CASE_DYNUNET_3D[0]])\n    def test_consistency(self, input_param, input_shape, _):\n        for eps in [1e-4, 1e-5]:\n            for momentum in [0.1, 0.01]:\n              
  for affine in [True, False]:\n                    norm_param = {\"eps\": eps, \"momentum\": momentum, \"affine\": affine}\n                    input_param[\"norm_name\"] = (\"instance\", norm_param)\n                    input_param_fuser = input_param.copy()\n                    input_param_fuser[\"norm_name\"] = (\"instance_nvfuser\", norm_param)\n                    for memory_format in [torch.contiguous_format, torch.channels_last_3d]:\n                        net = DynUNet(**input_param).to(\"cuda:0\", memory_format=memory_format)\n                        net_fuser = DynUNet(**input_param_fuser).to(\"cuda:0\", memory_format=memory_format)\n                        net_fuser.load_state_dict(net.state_dict())\n\n                        input_tensor = torch.randn(input_shape).to(\"cuda:0\", memory_format=memory_format)\n                        with eval_mode(net):\n                            result = net(input_tensor)\n                        with eval_mode(net_fuser):\n                            result_fuser = net_fuser(input_tensor)\n\n                        assert_allclose(result, result_fuser, rtol=rtol, atol=atol)\n\n\nclass TestDynUNetDeepSupervision(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_DEEP_SUPERVISION)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = DynUNet(**input_param).to(device)\n        with torch.no_grad():\n            results = net(torch.randn(input_shape).to(device))\n            self.assertEqual(results.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_efficientnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import (\n    BlockArgs,\n    EfficientNetBN,\n    EfficientNetBNFeatures,\n    drop_connect,\n    get_efficientnet_image_size,\n)\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_downloading_fails, skip_if_quick, test_pretrained_networks, test_script_save\n\nTESTS_PATH = Path(__file__).parents[2]\n\nif TYPE_CHECKING:\n    import torchvision\n\n    has_torchvision = True\nelse:\n    torchvision, has_torchvision = optional_import(\"torchvision\")\n\nif TYPE_CHECKING:\n    import PIL\n\n    has_pil = True\nelse:\n    PIL, has_pil = optional_import(\"PIL\")\n\n\ndef get_model_names():\n    return [f\"efficientnet-b{d}\" for d in range(8)]\n\n\ndef get_expected_model_shape(model_name):\n    model_input_shapes = {\n        \"efficientnet-b0\": 224,\n        \"efficientnet-b1\": 240,\n        \"efficientnet-b2\": 260,\n        \"efficientnet-b3\": 300,\n        \"efficientnet-b4\": 380,\n        \"efficientnet-b5\": 456,\n        \"efficientnet-b6\": 528,\n        \"efficientnet-b7\": 600,\n    }\n    return model_input_shapes[model_name]\n\n\ndef get_block_args():\n    # 
test string list\n    return [\n        \"r1_k3_s11_e1_i32_o16_se0.25\",\n        \"r2_k3_s22_e6_i16_o24_se0.25\",\n        \"r2_k5_s22_e6_i24_o40_se0.25\",\n        \"r3_k3_s22_e6_i40_o80_se0.25\",\n        \"r3_k5_s11_e6_i80_o112_se0.25\",\n        \"r4_k5_s22_e6_i112_o192_se0.25\",\n        \"r1_k3_s11_e6_i192_o320_se0.25\",\n        \"r1_k3_s11_e1_i32_o16_se0.25_noskip\",\n        \"r2_k3_s22_e6_i16_o24_se0.25_noskip\",\n        \"r2_k5_s22_e6_i24_o40_se0.25_noskip\",\n        \"r3_k3_s22_e6_i40_o80_se0.25_noskip\",\n        \"r3_k5_s11_e6_i80_o112_se0.25_noskip\",\n        \"r4_k5_s22_e6_i112_o192_se0.25_noskip\",\n        \"r1_k3_s11_e6_i192_o320_se0.25_noskip\",\n    ]\n\n\ndef make_shape_cases(\n    models,\n    spatial_dims,\n    batches,\n    pretrained,\n    in_channels=3,\n    num_classes=1000,\n    norm=(\"batch\", {\"eps\": 1e-3, \"momentum\": 0.01}),\n):\n    ret_tests = []\n    for spatial_dim in spatial_dims:  # selected spatial_dims\n        for batch in batches:  # check single batch as well as multiple batch input\n            for model in models:  # selected models\n                for is_pretrained in pretrained:  # pretrained or not pretrained\n                    kwargs = {\n                        \"model_name\": model,\n                        \"pretrained\": is_pretrained,\n                        \"progress\": False,\n                        \"spatial_dims\": spatial_dim,\n                        \"in_channels\": in_channels,\n                        \"num_classes\": num_classes,\n                        \"norm\": norm,\n                    }\n                    ret_tests.append(\n                        [\n                            kwargs,\n                            (batch, in_channels) + (get_expected_model_shape(model),) * spatial_dim,\n                            (batch, num_classes),\n                        ]\n                    )\n    return ret_tests\n\n\n# create list of selected models to speed up redundant tests\n# only 
test the models B0, B3, B7\nSEL_MODELS = [get_model_names()[i] for i in [0, 3, 7]]\n\n# pretrained=False cases\n# 1D models are cheap so do test for all models in 1D\nCASES_1D = make_shape_cases(\n    models=get_model_names(), spatial_dims=[1], batches=[1, 4], pretrained=[False], in_channels=3, num_classes=1000\n)\n\n# 2D and 3D models are expensive so use selected models\nCASES_2D = make_shape_cases(\n    models=SEL_MODELS,\n    spatial_dims=[2],\n    batches=[1, 4],\n    pretrained=[False],\n    in_channels=3,\n    num_classes=1000,\n    norm=\"instance\",\n)\nCASES_3D = make_shape_cases(\n    models=[SEL_MODELS[0]],\n    spatial_dims=[3],\n    batches=[1],\n    pretrained=[False],\n    in_channels=3,\n    num_classes=1000,\n    norm=\"batch\",\n)\n\n# pretrained=True cases\n# tabby kitty test with pretrained model\n# needs 'testing_data/kitty_test.jpg'\n# image from: https://commons.wikimedia.org/wiki/File:Tabby_cat_with_blue_eyes-3336579.jpg\nCASES_KITTY_TRAINED = [\n    (\n        {\n            \"model_name\": \"efficientnet-b0\",\n            \"pretrained\": True,\n            \"progress\": False,\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"num_classes\": 1000,\n            \"norm\": (\"batch\", {\"eps\": 1e-3, \"momentum\": 0.01}),\n            \"adv_prop\": False,\n        },\n        os.path.join(TESTS_PATH, \"testing_data\", \"kitty_test.jpg\"),\n        282,  # ~ tiger cat\n    ),\n    (\n        {\n            \"model_name\": \"efficientnet-b3\",\n            \"pretrained\": True,\n            \"progress\": False,\n            \"spatial_dims\": 2,\n            \"in_channels\": 3,\n            \"num_classes\": 1000,\n        },\n        os.path.join(TESTS_PATH, \"testing_data\", \"kitty_test.jpg\"),\n        282,  # ~ tiger cat\n    ),\n    (\n        {\n            \"model_name\": \"efficientnet-b7\",\n            \"pretrained\": True,\n            \"progress\": False,\n            \"spatial_dims\": 2,\n            
\"in_channels\": 3,\n            \"num_classes\": 1000,\n        },\n        os.path.join(TESTS_PATH, \"testing_data\", \"kitty_test.jpg\"),\n        282,  # ~ tiger cat\n    ),\n]\n\n# varying num_classes and in_channels\nCASES_VARIATIONS = []\n\n# change num_classes test\n# 10 classes\n# 2D\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=SEL_MODELS, spatial_dims=[2], batches=[1], pretrained=[False, True], in_channels=3, num_classes=10\n    )\n)\n# 3D\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=[SEL_MODELS[0]], spatial_dims=[3], batches=[1], pretrained=[False], in_channels=3, num_classes=10\n    )\n)\n\n# change in_channels test\n# 1 channel\n# 2D\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=SEL_MODELS, spatial_dims=[2], batches=[1], pretrained=[False, True], in_channels=1, num_classes=1000\n    )\n)\n# 8 channel\n# 2D\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=SEL_MODELS, spatial_dims=[2], batches=[1], pretrained=[False, True], in_channels=8, num_classes=1000\n    )\n)\n# 3D\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=[SEL_MODELS[0]], spatial_dims=[3], batches=[1], pretrained=[False], in_channels=1, num_classes=1000\n    )\n)\n\nCASE_EXTRACT_FEATURES = [\n    (\n        {\n            \"model_name\": \"efficientnet-b8\",\n            \"pretrained\": True,\n            \"progress\": False,\n            \"spatial_dims\": 2,\n            \"in_channels\": 2,\n            \"adv_prop\": True,\n        },\n        [1, 2, 224, 224],\n        ([1, 32, 112, 112], [1, 56, 56, 56], [1, 88, 28, 28], [1, 248, 14, 14], [1, 704, 7, 7]),\n    )\n]\n\n\nclass TestEFFICIENTNET(unittest.TestCase):\n    @parameterized.expand(CASES_1D + CASES_2D + CASES_3D + CASES_VARIATIONS)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        with skip_if_downloading_fails():\n            net = 
EfficientNetBN(**input_param).to(device)\n\n        # run inference with random tensor\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n\n        # check output shape\n        self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(CASES_1D + CASES_2D)\n    def test_non_default_shapes(self, input_param, input_shape, expected_shape):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        with skip_if_downloading_fails():\n            net = EfficientNetBN(**input_param).to(device)\n\n        # override input shape with different variations\n        num_dims = len(input_shape) - 2\n        non_default_sizes = [128, 256, 512]\n        for candidate_size in non_default_sizes:\n            input_shape = input_shape[0:2] + (candidate_size,) * num_dims\n            # run inference with random tensor\n            with eval_mode(net):\n                result = net(torch.randn(input_shape).to(device))\n\n            # check output shape\n            self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(CASES_KITTY_TRAINED)\n    @skip_if_quick\n    @skipUnless(has_torchvision, \"Requires `torchvision` package.\")\n    @skipUnless(has_pil, \"Requires `pillow` package.\")\n    def test_kitty_pretrained(self, input_param, image_path, expected_label):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        # open image\n        image_size = get_efficientnet_image_size(input_param[\"model_name\"])\n        img = PIL.Image.open(image_path)\n\n        # define ImageNet transforms\n        tfms = torchvision.transforms.Compose(\n            [\n                torchvision.transforms.Resize(image_size),\n                torchvision.transforms.CenterCrop(image_size),\n                torchvision.transforms.ToTensor(),\n                torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n            ]\n        )\n\n        # 
preprocess and prepare image tensor\n        img = tfms(img).unsqueeze(0).to(device)\n\n        # initialize a pretrained model\n        net = test_pretrained_networks(EfficientNetBN, input_param, device)\n\n        # run inference\n        with eval_mode(net):\n            result = net(img)\n        pred_label = torch.argmax(result, dim=-1)\n\n        # check output label\n        self.assertEqual(pred_label, expected_label)\n\n    def test_drop_connect_layer(self):\n        p_list = [float(d + 1) / 10.0 for d in range(9)]\n\n        # testing 1D, 2D and 3D shape\n        for rand_tensor_shape in [(512, 16, 4), (384, 16, 4, 4), (256, 16, 4, 4, 4)]:\n            # test validation mode, out tensor == in tensor\n            training = False\n            for p in p_list:\n                in_tensor = torch.rand(rand_tensor_shape) + 0.1\n                out_tensor = drop_connect(in_tensor, p, training=training)\n                self.assertTrue(torch.equal(out_tensor, in_tensor))\n\n            # test training mode, sum((out tensor * (1.0 - p)) != in tensor)/out_tensor.size() == p\n            # use tolerance of 0.175 to account for rounding errors due to finite set in/out\n            tol = 0.175\n            training = True\n            for p in p_list:\n                in_tensor = torch.rand(rand_tensor_shape) + 0.1\n                out_tensor = drop_connect(in_tensor, p, training=training)\n\n                p_calculated = 1.0 - torch.sum(torch.isclose(in_tensor, out_tensor * (1.0 - p))) / float(\n                    in_tensor.numel()\n                )\n                p_calculated = p_calculated.cpu().numpy()\n\n                self.assertTrue(abs(p_calculated - p) < tol)\n\n    def test_block_args_decode(self):\n        blocks_args_str = get_block_args()\n\n        # convert strings to BlockArgs\n        blocks_args = [BlockArgs.from_string(s) for s in blocks_args_str]\n        # convert BlockArgs back to string\n        blocks_args_str_convert = [s.to_string() 
for s in blocks_args]\n\n        # check if converted strings match original\n        [self.assertEqual(original, converted) for original, converted in zip(blocks_args_str, blocks_args_str_convert)]\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            # wrong spatial_dims\n            EfficientNetBN(model_name=\"efficientnet-b0\", spatial_dims=4)\n            # wrong model_name\n            EfficientNetBN(model_name=\"efficientnet-b10\", spatial_dims=3)\n\n    def test_func_get_efficientnet_input_shape(self):\n        for model in get_model_names():\n            result_shape = get_efficientnet_image_size(model_name=model)\n            expected_shape = get_expected_model_shape(model)\n            self.assertEqual(result_shape, expected_shape)\n\n    def test_script(self):\n        with skip_if_downloading_fails():\n            net = EfficientNetBN(model_name=\"efficientnet-b0\", spatial_dims=2, in_channels=3, num_classes=1000)\n        net.set_swish(memory_efficient=False)  # at the moment custom memory efficient swish is not exportable with jit\n        test_data = torch.randn(1, 3, 224, 224)\n        test_script_save(net, test_data)\n\n\nclass TestExtractFeatures(unittest.TestCase):\n    @parameterized.expand(CASE_EXTRACT_FEATURES)\n    def test_shape(self, input_param, input_shape, expected_shapes):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        with skip_if_downloading_fails():\n            net = EfficientNetBNFeatures(**input_param).to(device)\n\n        # run inference with random tensor\n        with eval_mode(net):\n            features = net(torch.randn(input_shape).to(device))\n\n        # check output shape\n        self.assertEqual(len(features), len(expected_shapes))\n        for feature, expected_shape in zip(features, expected_shapes):\n            self.assertEqual(feature.shape, torch.Size(expected_shape))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_flexible_unet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks.encoder import BaseEncoder\nfrom monai.networks.nets import (\n    FLEXUNET_BACKBONE,\n    EfficientNetBNFeatures,\n    FlexibleUNet,\n    FlexUNetEncoderRegister,\n    ResNetEncoder,\n    ResNetFeatures,\n)\nfrom monai.utils import optional_import\nfrom tests.test_utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick\n\ntorchvision, has_torchvision = optional_import(\"torchvision\")\nPIL, has_pil = optional_import(\"PIL\")\n\n\nclass DummyEncoder(BaseEncoder):\n    @classmethod\n    def get_encoder_parameters(cls):\n        basic_dict = {\"spatial_dims\": 2, \"in_channels\": 3, \"pretrained\": False}\n        param_dict_list = [basic_dict]\n        for key in basic_dict.keys():\n            cur_dict = basic_dict.copy()\n            del cur_dict[key]\n            param_dict_list.append(cur_dict)\n        return param_dict_list\n\n    @classmethod\n    def num_channels_per_output(cls):\n        return [(32, 64, 128, 256, 512, 1024), (32, 64, 128, 256), (32, 64, 128, 256), (32, 64, 128, 256)]\n\n    @classmethod\n    def num_outputs(cls):\n        return [6, 4, 4, 4]\n\n    @classmethod\n    def get_encoder_names(cls):\n        return [\"encoder_wrong_channels\", 
\"encoder_no_param1\", \"encoder_no_param2\", \"encoder_no_param3\"]\n\n\nFLEXUNET_BACKBONE.register_class(DummyEncoder)\n\n\ndef get_model_names():\n    return [f\"efficientnet-b{d}\" for d in range(8)]\n\n\ndef get_resnet_names():\n    return ResNetEncoder.get_encoder_names()\n\n\ndef make_shape_cases(\n    models,\n    spatial_dims,\n    batches,\n    pretrained,\n    in_channels=3,\n    num_classes=10,\n    input_shape=64,\n    norm=(\"batch\", {\"eps\": 1e-3, \"momentum\": 0.01}),\n    upsample=(\"nontrainable\", \"deconv\", \"pixelshuffle\"),\n):\n    ret_tests = []\n    for spatial_dim in spatial_dims:  # selected spatial_dims\n        for batch in batches:  # check single batch as well as multiple batch input\n            for model in models:  # selected models\n                for is_pretrained in pretrained:  # pretrained or not pretrained\n                    for upsample_method in upsample:\n                        if (\"resnet\" in model) and is_pretrained:\n                            continue\n                        kwargs = {\n                            \"in_channels\": in_channels,\n                            \"out_channels\": num_classes,\n                            \"backbone\": model,\n                            \"pretrained\": is_pretrained,\n                            \"spatial_dims\": spatial_dim,\n                            \"norm\": norm,\n                            \"upsample\": upsample_method,\n                        }\n                        ret_tests.append(\n                            [\n                                kwargs,\n                                (batch, in_channels) + (input_shape,) * spatial_dim,\n                                (batch, num_classes) + (input_shape,) * spatial_dim,\n                            ]\n                        )\n    return ret_tests\n\n\ndef make_error_case():\n    error_backbones = DummyEncoder.get_encoder_names()\n    error_param_list = []\n    for backbone in error_backbones:\n   
     error_param_list.append(\n            [{\"in_channels\": 3, \"out_channels\": 2, \"backbone\": backbone, \"pretrained\": True, \"spatial_dims\": 3}]\n        )\n    return error_param_list\n\n\n# create list of selected models to speed up redundant tests\n# only test efficient net B0, B3 and resnet 10, 18, 34\nSEL_MODELS = [get_model_names()[i] for i in [0, 3]]\nSEL_MODELS += [get_resnet_names()[i] for i in [0, 1, 2]]\n\n# pretrained=False cases\n# 2D and 3D models are expensive so use selected models\nCASES_2D = make_shape_cases(\n    models=SEL_MODELS,\n    spatial_dims=[2],\n    batches=[1, 4],\n    pretrained=[False],\n    in_channels=3,\n    num_classes=10,\n    norm=\"instance\",\n)\nCASES_3D = make_shape_cases(\n    models=[SEL_MODELS[0], SEL_MODELS[2]],\n    spatial_dims=[3],\n    batches=[1],\n    pretrained=[False],\n    in_channels=3,\n    num_classes=10,\n    norm=\"batch\",\n)\n\n# varying num_classes and in_channels\nCASES_VARIATIONS = []\n\n# change num_classes test\n# 20 classes\n# 2D\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=SEL_MODELS, spatial_dims=[2], batches=[1], pretrained=[False, True], in_channels=3, num_classes=20\n    )\n)\n# 3D\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=[SEL_MODELS[0]], spatial_dims=[3], batches=[1], pretrained=[False], in_channels=3, num_classes=20\n    )\n)\n\n# change in_channels test\n# 1 channel\n# 2D\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=SEL_MODELS, spatial_dims=[2], batches=[1], pretrained=[False, True], in_channels=1, num_classes=10\n    )\n)\n# 8 channel\n# 2D\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=SEL_MODELS, spatial_dims=[2], batches=[1], pretrained=[False, True], in_channels=8, num_classes=10\n    )\n)\n# 3D\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=[SEL_MODELS[0]], spatial_dims=[3], batches=[1], pretrained=[False], in_channels=1, num_classes=10\n    )\n)\n\n# change input shape test\n# 
96\n# 2D 96x96 input\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=SEL_MODELS,\n        spatial_dims=[2],\n        batches=[1],\n        pretrained=[False, True],\n        in_channels=3,\n        num_classes=10,\n        input_shape=96,\n    )\n)\n# 2D 64x64 input\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=SEL_MODELS,\n        spatial_dims=[2],\n        batches=[1],\n        pretrained=[False, True],\n        in_channels=3,\n        num_classes=10,\n        input_shape=64,\n    )\n)\n\n# 3D 32x32x32 input\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=SEL_MODELS,\n        spatial_dims=[2],\n        batches=[1],\n        pretrained=[False],\n        in_channels=3,\n        num_classes=10,\n        input_shape=32,\n    )\n)\n\n# 3D 64x64x64 input\nCASES_VARIATIONS.extend(\n    make_shape_cases(\n        models=SEL_MODELS,\n        spatial_dims=[2],\n        batches=[1],\n        pretrained=[False],\n        in_channels=3,\n        num_classes=10,\n        input_shape=64,\n    )\n)\n\n# pretrain weight verified\nCASES_PRETRAIN = [\n    (\n        {\n            \"in_channels\": 3,\n            \"out_channels\": 10,\n            \"backbone\": SEL_MODELS[0],\n            \"pretrained\": True,\n            \"spatial_dims\": 2,\n            \"norm\": (\"batch\", {\"eps\": 1e-3, \"momentum\": 0.01}),\n        },\n        EfficientNetBNFeatures,\n        {\n            \"in_channels\": 3,\n            \"num_classes\": 10,\n            \"model_name\": SEL_MODELS[0],\n            \"pretrained\": True,\n            \"spatial_dims\": 2,\n            \"norm\": (\"batch\", {\"eps\": 1e-3, \"momentum\": 0.01}),\n        },\n        [\"_conv_stem.weight\"],\n    ),\n    (\n        {\n            \"in_channels\": 1,\n            \"out_channels\": 10,\n            \"backbone\": SEL_MODELS[2],\n            \"pretrained\": True,\n            \"spatial_dims\": 3,\n            \"norm\": (\"batch\", {\"eps\": 1e-3, \"momentum\": 
0.01}),\n        },\n        ResNetFeatures,\n        {\"model_name\": SEL_MODELS[2], \"pretrained\": True, \"spatial_dims\": 3, \"in_channels\": 1},\n        [\"conv1.weight\"],\n    ),\n]\n\nCASE_ERRORS = make_error_case()\n\n# Verify Register class with string type\nCASE_REGISTER_ENCODER = [\"EfficientNetEncoder\", \"monai.networks.nets.EfficientNetEncoder\"]\n\n\n@SkipIfNoModule(\"hf_hub_download\")\n@skip_if_quick\nclass TestFLEXIBLEUNET(unittest.TestCase):\n    @parameterized.expand(CASES_2D + CASES_3D + CASES_VARIATIONS)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        with skip_if_downloading_fails():\n            net = FlexibleUNet(**input_param).to(device)\n\n        # run inference with random tensor\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n\n        # check output shape\n        self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(CASES_PRETRAIN)\n    def test_pretrain(self, flexunet_input_param, feature_extractor_class, feature_extractor_input_param, weight_list):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        with skip_if_downloading_fails():\n            net = FlexibleUNet(**flexunet_input_param).to(device)\n\n        with skip_if_downloading_fails():\n            feature_extractor_net = feature_extractor_class(**feature_extractor_input_param).to(device)\n\n        for weight_name in weight_list:\n            if weight_name in net.encoder.state_dict() and weight_name in feature_extractor_net.state_dict():\n                net_weight = net.encoder.state_dict()[weight_name]\n                download_weight = feature_extractor_net.state_dict()[weight_name]\n                weight_diff = torch.abs(net_weight - download_weight)\n                diff_sum = torch.sum(weight_diff)\n                # check if a weight in weight_list equals to the 
downloaded weight.\n                self.assertLess(abs(diff_sum.item() - 0), 1e-8)\n\n    @parameterized.expand(CASE_ERRORS)\n    def test_error_raise(self, input_param):\n        with self.assertRaises((ValueError, NotImplementedError)):\n            FlexibleUNet(**input_param)\n\n\nclass TestFlexUNetEncoderRegister(unittest.TestCase):\n    @parameterized.expand(CASE_REGISTER_ENCODER)\n    def test_regist(self, encoder):\n        tmp_backbone = FlexUNetEncoderRegister()\n        tmp_backbone.register_class(encoder)\n        for backbone in tmp_backbone.register_dict:\n            backbone_type = tmp_backbone.register_dict[backbone][\"type\"]\n            feature_number = backbone_type.num_outputs()\n            feature_channel = backbone_type.num_channels_per_output()\n            param_dict_list = backbone_type.get_encoder_parameters()\n            encoder_name_list = backbone_type.get_encoder_names()\n            encoder_cnt = encoder_name_list.index(backbone)\n            self.assertEqual(feature_number[encoder_cnt], tmp_backbone.register_dict[backbone][\"feature_number\"])\n            self.assertEqual(feature_channel[encoder_cnt], tmp_backbone.register_dict[backbone][\"feature_channel\"])\n            self.assertEqual(param_dict_list[encoder_cnt], tmp_backbone.register_dict[backbone][\"parameter\"])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_fullyconnectednet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import FullyConnectedNet, VarFullyConnectedNet\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\nFC_TEST_CASE_0 = [0]\nFC_TEST_CASE_1 = [0.15]\n\nFC_CASES = [FC_TEST_CASE_0, FC_TEST_CASE_1]\n\nVFC_TEST_CASE_0 = [\n    {\n        \"in_channels\": 10,\n        \"out_channels\": 10,\n        \"latent_size\": 30,\n        \"encode_channels\": (15, 20, 25),\n        \"decode_channels\": (15, 20, 25),\n    },\n    (3, 10),\n    (3, 10),\n]\n\nVFC_CASES = [VFC_TEST_CASE_0]\n\n\nclass TestFullyConnectedNet(unittest.TestCase):\n\n    def setUp(self):\n        self.batch_size = 10\n        self.inSize = 10\n        self.arrShape = (self.batch_size, self.inSize)\n        self.outSize = 3\n        self.channels = [8, 16]\n        self.arr = torch.randn(self.arrShape, dtype=torch.float32).to(device)\n\n    @parameterized.expand(FC_CASES)\n    def test_fc_shape(self, dropout):\n        net = FullyConnectedNet(self.inSize, self.outSize, self.channels, dropout).to(device)\n        out = net(self.arr)\n        self.assertEqual(out.shape, (self.batch_size, self.outSize))\n\n    @parameterized.expand(VFC_CASES)\n    def test_vfc_shape(self, input_param, input_shape, expected_shape):\n  
      net = VarFullyConnectedNet(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))[0]\n            self.assertEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_generator.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import Generator\nfrom tests.test_utils import test_script_save\n\nTEST_CASE_0 = [\n    {\"latent_shape\": (64,), \"start_shape\": (8, 8, 8), \"channels\": (8, 4, 1), \"strides\": (2, 2, 2), \"num_res_units\": 0},\n    torch.rand(16, 64),\n    (16, 1, 64, 64),\n]\n\nTEST_CASE_1 = [\n    {\"latent_shape\": (64,), \"start_shape\": (8, 8, 8), \"channels\": (8, 4, 1), \"strides\": (2, 2, 2), \"num_res_units\": 2},\n    torch.rand(16, 64),\n    (16, 1, 64, 64),\n]\n\nTEST_CASE_2 = [\n    {\"latent_shape\": (64,), \"start_shape\": (8, 8, 8), \"channels\": (8, 1), \"strides\": (2, 2), \"num_res_units\": 2},\n    torch.rand(16, 64),\n    (16, 1, 32, 32),\n]\n\nCASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]\n\n\nclass TestGenerator(unittest.TestCase):\n    @parameterized.expand(CASES)\n    def test_shape(self, input_param, input_data, expected_shape):\n        net = Generator(**input_param)\n        with eval_mode(net):\n            result = net.forward(input_data)\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        net = Generator(latent_shape=(64,), start_shape=(8, 8, 8), channels=(8, 1), strides=(2, 2), num_res_units=2)\n        test_data = torch.rand(16, 
64)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_globalnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.blocks import Warp\nfrom monai.networks.nets import GlobalNet\nfrom monai.networks.nets.regunet import AffineHead\nfrom tests.test_utils import assert_allclose, test_script_save\n\nTEST_CASES_AFFINE_TRANSFORM = [\n    [\n        {\"spatial_dims\": 3, \"image_size\": (2, 2, 2), \"decode_size\": (2, 2, 2), \"in_channels\": 1, \"save_theta\": True},\n        torch.ones(2, 12),\n        torch.tensor([[[1, 2], [2, 3]], [[2, 3], [3, 4]]]).unsqueeze(0).unsqueeze(0).expand(2, 3, 2, 2, 2),\n    ],\n    [\n        {\"spatial_dims\": 3, \"image_size\": (2, 2, 2), \"decode_size\": (2, 2, 2), \"in_channels\": 1},\n        torch.arange(1, 13).reshape(1, 12).to(torch.float),\n        torch.tensor(\n            [\n                [[[4.0, 7.0], [6.0, 9.0]], [[5.0, 8.0], [7.0, 10.0]]],\n                [[[8.0, 15.0], [14.0, 21.0]], [[13.0, 20.0], [19.0, 26.0]]],\n                [[[12.0, 23.0], [22.0, 33.0]], [[21.0, 32.0], [31.0, 42.0]]],\n            ]\n        ).unsqueeze(0),\n    ],\n]\n\nTEST_CASES_GLOBAL_NET = [\n    [\n        {\n            \"image_size\": (16, 16),\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"num_channel_initial\": 16,\n            
\"depth\": 1,\n            \"out_kernel_initializer\": \"kaiming_uniform\",\n            \"out_activation\": None,\n            \"pooling\": True,\n            \"concat_skip\": True,\n            \"encode_kernel_sizes\": 3,\n            \"save_theta\": theta,\n        },\n        (1, 1, 16, 16),\n        (1, 2, 16, 16),\n    ]\n    for theta in (False, True)\n]\n\n\nclass TestAffineHead(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_AFFINE_TRANSFORM)\n    def test_shape(self, input_param, theta, expected_val):\n        layer = AffineHead(**input_param)\n        if input_param.get(\"save_theta\"):\n            assert_allclose(layer.theta, torch.Tensor())\n        result = layer.affine_transform(theta)\n        np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4)\n\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n\nclass TestGlobalNet(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_GLOBAL_NET)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = GlobalNet(**input_param).to(device)\n        warp_layer = Warp()\n        with eval_mode(net):\n            img = torch.randn(input_shape)\n            result = net(img.to(device))\n            if input_param.get(\"save_theta\"):\n                assert_allclose(net.output_block.theta, torch.Tensor([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]]))\n            warped = warp_layer(img.to(device), result)\n            self.assertEqual(result.shape, expected_shape)\n            # testing initial pred identity\n            np.testing.assert_allclose(warped.detach().cpu().numpy(), img.detach().cpu().numpy(), rtol=1e-4, atol=1e-4)\n\n    @parameterized.expand(TEST_CASES_GLOBAL_NET)\n    def test_script(self, input_param, input_shape, _):\n        net = GlobalNet(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_highresnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import HighResNet\nfrom tests.test_utils import DistTestCase, TimedCall, test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_1 = [  # single channel 3D, batch 16\n    {\"spatial_dims\": 3, \"in_channels\": 1, \"out_channels\": 3, \"norm_type\": \"instance\"},\n    (16, 1, 32, 24, 48),\n    (16, 3, 32, 24, 48),\n]\n\nTEST_CASE_2 = [  # 4-channel 3D, batch 1\n    {\"spatial_dims\": 3, \"in_channels\": 4, \"out_channels\": 3, \"acti_type\": \"relu6\"},\n    (1, 4, 17, 64, 48),\n    (1, 3, 17, 64, 48),\n]\n\nTEST_CASE_3 = [  # 4-channel 2D, batch 7\n    {\"spatial_dims\": 2, \"in_channels\": 4, \"out_channels\": 3},\n    (7, 4, 64, 48),\n    (7, 3, 64, 48),\n]\n\nTEST_CASE_4 = [  # 4-channel 1D, batch 16\n    {\"spatial_dims\": 1, \"in_channels\": 4, \"out_channels\": 3, \"dropout_prob\": 0.1},\n    (16, 4, 63),\n    (16, 3, 63),\n]\n\n\nclass TestHighResNet(DistTestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = HighResNet(**input_param).to(device)\n        with eval_mode(net):\n            result = 
net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    @TimedCall(seconds=800, force_quit=True)\n    def test_script(self):\n        input_param, input_shape, expected_shape = TEST_CASE_1\n        net = HighResNet(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data, rtol=1e-4, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_hovernet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode, train_mode\nfrom monai.networks.nets import HoVerNet\nfrom monai.networks.nets.hovernet import _DenseLayerDecoder\nfrom tests.test_utils import test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_0 = [  # fast mode\n    {\"out_classes\": 5, \"mode\": HoVerNet.Mode.FAST},\n    (1, 3, 256, 256),\n    {HoVerNet.Branch.NP: (1, 2, 164, 164), HoVerNet.Branch.NC: (1, 5, 164, 164), HoVerNet.Branch.HV: (1, 2, 164, 164)},\n]\n\nTEST_CASE_1 = [  # original mode\n    {\"out_classes\": 6, \"mode\": HoVerNet.Mode.ORIGINAL},\n    (1, 3, 270, 270),\n    {HoVerNet.Branch.NP: (1, 2, 80, 80), HoVerNet.Branch.NC: (1, 6, 80, 80), HoVerNet.Branch.HV: (1, 2, 80, 80)},\n]\n\nTEST_CASE_2 = [  # dropout\n    {\"mode\": HoVerNet.Mode.FAST, \"dropout_prob\": 0.5, \"out_classes\": 3},\n    (1, 3, 256, 256),\n    {HoVerNet.Branch.NP: (1, 2, 164, 164), HoVerNet.Branch.NC: (1, 3, 164, 164), HoVerNet.Branch.HV: (1, 2, 164, 164)},\n]\n\nTEST_CASE_3 = [  # np_out_channels\n    {\"mode\": HoVerNet.Mode.FAST, \"np_out_channels\": 3, \"out_classes\": 2},\n    (1, 3, 256, 256),\n    {HoVerNet.Branch.NP: (1, 3, 164, 164), HoVerNet.Branch.NC: (1, 2, 164, 164), HoVerNet.Branch.HV: (1, 2, 164, 164)},\n]\n\nCASES = 
[TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]\n\nILL_CASES = [\n    [{\"out_classes\": 6, \"mode\": 3}],\n    [{\"out_classes\": 6, \"mode\": \"Wrong\"}],\n    [{\"out_classes\": 1000, \"mode\": HoVerNet.Mode.ORIGINAL}],\n    [{\"out_classes\": 1, \"mode\": HoVerNet.Mode.ORIGINAL}],\n    [{\"out_classes\": 6, \"mode\": HoVerNet.Mode.ORIGINAL, \"dropout_prob\": 100}],\n]\n\n\ndef check_branch(branch, mode):\n    if mode == HoVerNet.Mode.ORIGINAL:\n        ksize = 5\n    else:\n        ksize = 3\n\n    if branch.decoderblock1.conva.kernel_size != (ksize, ksize):\n        return True\n    if branch.decoderblock1.convf.kernel_size != (1, 1):\n        return True\n    for block in branch.decoderblock1:\n        if isinstance(block, _DenseLayerDecoder):\n            if block.layers.conv1.kernel_size != (1, 1) or block.layers.conv2.kernel_size != (ksize, ksize):\n                return True\n\n    if branch.decoderblock2.conva.kernel_size != (ksize, ksize):\n        return True\n    if branch.decoderblock2.convf.kernel_size != (1, 1):\n        return True\n\n    for block in branch.decoderblock2:\n        if isinstance(block, _DenseLayerDecoder):\n            if block.layers.conv1.kernel_size != (1, 1) or block.layers.conv2.kernel_size != (ksize, ksize):\n                return True\n\n    return False\n\n\ndef check_output(out_block, mode):\n    if mode == HoVerNet.Mode.ORIGINAL:\n        ksize = 5\n    else:\n        ksize = 3\n\n    if out_block.decoderblock3.conva.kernel_size != (ksize, ksize) or out_block.decoderblock3.conva.stride != (1, 1):\n        return True\n    if out_block.decoderblock4.conv.kernel_size != (1, 1) or out_block.decoderblock4.conv.stride != (1, 1):\n        return True\n\n    return False\n\n\ndef check_kernels(net, mode):\n    # Check the Encoder blocks\n    for layer_num, res_block in enumerate(net.res_blocks):\n        for inner_num, layer in enumerate(res_block.layers):\n            if layer_num > 0 and inner_num == 0:\n               
 sz = 2\n            else:\n                sz = 1\n\n            if (\n                layer.layers.conv1.kernel_size != (1, 1)\n                or layer.layers.conv2.kernel_size != (3, 3)\n                or layer.layers.conv3.kernel_size != (1, 1)\n            ):\n                return True\n\n            if (\n                layer.layers.conv1.stride != (1, 1)\n                or layer.layers.conv2.stride != (sz, sz)\n                or layer.layers.conv3.stride != (1, 1)\n            ):\n                return True\n\n        sz2 = 1\n        if layer_num > 0:\n            sz2 = 2\n        if res_block.shortcut.kernel_size != (1, 1) or res_block.shortcut.stride != (sz2, sz2):\n            return True\n\n    if net.bottleneck.conv_bottleneck.kernel_size != (1, 1) or net.bottleneck.conv_bottleneck.stride != (1, 1):\n        return True\n\n    # Check HV Branch\n    if check_branch(net.horizontal_vertical.decoder_blocks, mode):\n        return True\n    if check_output(net.horizontal_vertical.output_features, mode):\n        return True\n\n    # Check NP Branch\n    if check_branch(net.nucleus_prediction.decoder_blocks, mode):\n        return True\n    if check_output(net.nucleus_prediction.output_features, mode):\n        return True\n\n    # Check NC Branch\n    if check_branch(net.type_prediction.decoder_blocks, mode):\n        return True\n    if check_output(net.type_prediction.output_features, mode):\n        return True\n\n    return False\n\n\nclass TestHoverNet(unittest.TestCase):\n    @parameterized.expand(CASES)\n    def test_shape(self, input_param, input_shape, expected_shapes):\n        input_param[\"decoder_padding\"] = False\n        net = HoVerNet(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            for item in result:\n                self.assertEqual(result[item].shape, expected_shapes[item])\n\n    @parameterized.expand(CASES)\n    def 
test_decoder_padding_shape(self, input_param, input_shape, expected_shapes):\n        if input_param[\"mode\"] == HoVerNet.Mode.FAST:\n            input_param[\"decoder_padding\"] = True\n            net = HoVerNet(**input_param).to(device)\n            with eval_mode(net):\n                result = net.forward(torch.randn(input_shape).to(device))\n                for item in result:\n                    expected_shape = expected_shapes[item]\n                    padding_expected_shape = list(expected_shape)\n                    padding_expected_shape[2:] = input_shape[2:]\n                    self.assertEqual(result[item].shape, tuple(padding_expected_shape))\n        else:\n            pass\n\n    def test_script(self):\n        for padding_flag in [True, False]:\n            net = HoVerNet(mode=HoVerNet.Mode.FAST, decoder_padding=padding_flag)\n        test_data = torch.randn(1, 3, 256, 256)\n        test_script_save(net, test_data)\n\n    def test_ill_input_shape(self):\n        net = HoVerNet(mode=HoVerNet.Mode.FAST)\n        with eval_mode(net):\n            with self.assertRaises(ValueError):\n                net.forward(torch.randn(1, 3, 270, 260))\n\n    def test_kernels_strides(self):\n        net = HoVerNet(mode=HoVerNet.Mode.FAST, out_classes=2)\n        with eval_mode(net):\n            self.assertEqual(check_kernels(net, HoVerNet.Mode.FAST), False)\n\n        net = HoVerNet(mode=HoVerNet.Mode.ORIGINAL, out_classes=2)\n        with eval_mode(net):\n            self.assertEqual(check_kernels(net, HoVerNet.Mode.ORIGINAL), False)\n\n    def test_freeze_encoder(self):\n        net = HoVerNet(mode=HoVerNet.Mode.FAST, freeze_encoder=True)\n        with train_mode(net):\n            for _, param in net.res_blocks[1:].named_parameters():\n                self.assertFalse(param.requires_grad)\n            for name, param in net.res_blocks[0].named_parameters():\n                if param.requires_grad is True:\n                    self.assertTrue(\"bna_block\" 
or \"shortcut\" in name)\n\n    @parameterized.expand(ILL_CASES)\n    def test_ill_input_hyper_params(self, input_param):\n        with self.assertRaises(ValueError):\n            _ = HoVerNet(**input_param)\n\n\nif __name__ == \"__main__\":\n    unittest.main(argv=[\"first-arg-is-ignored\"], exit=False)\n"
  },
  {
    "path": "tests/networks/nets/test_masked_autoencoder_vit.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.masked_autoencoder_vit import MaskedAutoEncoderViT\nfrom tests.test_utils import dict_product, skip_if_quick\n\nTEST_CASE_MaskedAutoEncoderViT = []\n\nfor base_params in dict_product(\n    masking_ratio=[0.5],\n    dropout_rate=[0.6],\n    in_channels=[4],\n    hidden_size=[768],\n    img_size_scalar=[96, 128],\n    patch_size_scalar=[16],\n    num_heads=[12],\n    mlp_dim=[3072],\n    num_layers=[4],\n    decoder_hidden_size=[384],\n    decoder_mlp_dim=[512],\n    decoder_num_layers=[4],\n    decoder_num_heads=[16],\n    pos_embed_type=[\"sincos\", \"learnable\"],\n    proj_type=[\"conv\", \"perceptron\"],\n):\n    img_size_scalar = base_params.pop(\"img_size_scalar\")\n    patch_size_scalar = base_params.pop(\"patch_size_scalar\")\n    for nd in (2, 3):\n        # Parameters for the MaskedAutoEncoderViT model\n        model_params = base_params.copy()\n        model_params[\"img_size\"] = (img_size_scalar,) * nd\n        model_params[\"patch_size\"] = (patch_size_scalar,) * nd\n        model_params[\"decoder_pos_embed_type\"] = model_params[\"pos_embed_type\"]\n\n        # Expected input and output shapes\n        input_shape = (2, model_params[\"in_channels\"], *([img_size_scalar] * nd))\n       
 # N, num_patches, patch_dim_product\n        # num_patches = (img_size // patch_size) ** nd\n        # patch_dim_product = in_channels * (patch_size**nd)\n        expected_shape = (\n            2,\n            (img_size_scalar // patch_size_scalar) ** nd,\n            model_params[\"in_channels\"] * (patch_size_scalar**nd),\n        )\n\n        if nd == 2:\n            model_params[\"spatial_dims\"] = 2\n\n        test_case = [model_params, input_shape, expected_shape]\n        TEST_CASE_MaskedAutoEncoderViT.append(test_case)\n\nTEST_CASE_ill_args = [\n    [{\"in_channels\": 1, \"img_size\": (128, 128, 128), \"patch_size\": (16, 16, 16), \"dropout_rate\": 5.0}],\n    [{\"in_channels\": 1, \"img_size\": (128, 128, 128), \"patch_size\": (64, 64, 64), \"pos_embed_type\": \"sin\"}],\n    [{\"in_channels\": 1, \"img_size\": (128, 128, 128), \"patch_size\": (64, 64, 64), \"decoder_pos_embed_type\": \"sin\"}],\n    [{\"in_channels\": 1, \"img_size\": (32, 32, 32), \"patch_size\": (64, 64, 64)}],\n    [{\"in_channels\": 1, \"img_size\": (128, 128, 128), \"patch_size\": (64, 64, 64), \"num_layers\": 12, \"num_heads\": 14}],\n    [{\"in_channels\": 1, \"img_size\": (97, 97, 97), \"patch_size\": (16, 16, 16)}],\n    [{\"in_channels\": 1, \"img_size\": (128, 128, 128), \"patch_size\": (64, 64, 64), \"masking_ratio\": 1.1}],\n    [{\"in_channels\": 1, \"img_size\": (128, 128, 128), \"patch_size\": (64, 64, 64), \"masking_ratio\": -0.1}],\n]\n\n\n@skip_if_quick\nclass TestMaskedAutoencoderViT(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_MaskedAutoEncoderViT)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = MaskedAutoEncoderViT(**input_param)\n        with eval_mode(net):\n            result, _ = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_frozen_pos_embedding(self):\n        net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))\n\n   
     self.assertEqual(net.decoder_pos_embedding.requires_grad, False)\n\n    @parameterized.expand(TEST_CASE_ill_args)\n    def test_ill_arg(self, input_param):\n        with self.assertRaises(ValueError):\n            MaskedAutoEncoderViT(**input_param)\n\n    def test_access_attn_matrix(self):\n        # input format\n        in_channels = 1\n        img_size = (96, 96, 96)\n        patch_size = (16, 16, 16)\n        in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])\n\n        # no data in the matrix\n        no_matrix_acess_blk = MaskedAutoEncoderViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size)\n        no_matrix_acess_blk(torch.randn(in_shape))\n        assert isinstance(no_matrix_acess_blk.blocks[0].attn.att_mat, torch.Tensor)\n        # no of elements is zero\n        assert no_matrix_acess_blk.blocks[0].attn.att_mat.nelement() == 0\n\n        # be able to acess the attention matrix\n        matrix_acess_blk = MaskedAutoEncoderViT(\n            in_channels=in_channels, img_size=img_size, patch_size=patch_size, save_attn=True\n        )\n        matrix_acess_blk(torch.randn(in_shape))\n\n        assert matrix_acess_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 55, 55)\n\n    def test_masking_ratio(self):\n        # input format\n        in_channels = 1\n        img_size = (96, 96, 96)\n        patch_size = (16, 16, 16)\n        in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])\n\n        # masking ratio 0.25\n        masking_ratio_blk = MaskedAutoEncoderViT(\n            in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.25, save_attn=True\n        )\n        masking_ratio_blk(torch.randn(in_shape))\n        desired_num_tokens = int(\n            (img_size[0] // patch_size[0])\n            * (img_size[1] // patch_size[1])\n            * (img_size[2] // patch_size[2])\n            * (1 - 0.25)\n        )\n        assert 
masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens\n\n        # masking ratio 0.33\n        masking_ratio_blk = MaskedAutoEncoderViT(\n            in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.33, save_attn=True\n        )\n        masking_ratio_blk(torch.randn(in_shape))\n        desired_num_tokens = int(\n            (img_size[0] // patch_size[0])\n            * (img_size[1] // patch_size[1])\n            * (img_size[2] // patch_size[2])\n            * (1 - 0.33)\n        )\n\n        assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_mednext.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import MedNeXt, MedNeXtL, MedNeXtM, MedNeXtS\nfrom tests.test_utils import dict_product  # Import dict_product\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_MEDNEXT = [\n    [params, (2, 1, *([16] * params[\"spatial_dims\"])), (2, 2, *([16] * params[\"spatial_dims\"]))]\n    for params in dict_product(\n        spatial_dims=range(2, 4),\n        init_filters=[8, 16],\n        deep_supervision=[False, True],\n        use_residual_connection=[False, True],\n    )\n]\nTEST_CASE_MEDNEXT_2 = [\n    [params, (2, 1, *([16] * params[\"spatial_dims\"])), (2, params[\"out_channels\"], *([16] * params[\"spatial_dims\"]))]\n    for params in dict_product(\n        spatial_dims=range(2, 4), out_channels=[1, 2], deep_supervision=[False, True], init_filters=[8]\n    )\n]\n\n\nTEST_CASE_MEDNEXT_VARIANTS = [\n    [\n        params[\"model\"],\n        {\"spatial_dims\": params[\"spatial_dims\"], \"in_channels\": 1, \"out_channels\": params[\"out_channels\"]},\n        (2, 1, *([16] * params[\"spatial_dims\"])),\n        (2, params[\"out_channels\"], *([16] * params[\"spatial_dims\"])),\n    ]\n    for params in dict_product(\n        model=[MedNeXtS, MedNeXtM, MedNeXtL], 
spatial_dims=range(2, 4), out_channels=[1, 2], in_channels=[1]\n    )\n]\n\n\nclass TestMedNeXt(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_MEDNEXT)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = MedNeXt(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n            if input_param[\"deep_supervision\"] and net.training:\n                assert isinstance(result, tuple)\n                self.assertEqual(result[0].shape, expected_shape, msg=str(input_param))\n            else:\n                self.assertEqual(result.shape, expected_shape, msg=str(input_param))\n\n    @parameterized.expand(TEST_CASE_MEDNEXT_2)\n    def test_shape2(self, input_param, input_shape, expected_shape):\n        net = MedNeXt(**input_param).to(device)\n\n        net.train()\n        result = net(torch.randn(input_shape).to(device))\n        if input_param[\"deep_supervision\"]:\n            assert isinstance(result, tuple)\n            self.assertEqual(result[0].shape, expected_shape, msg=str(input_param))\n        else:\n            assert isinstance(result, torch.Tensor)\n            self.assertEqual(result.shape, expected_shape, msg=str(input_param))\n\n        net.eval()\n        result = net(torch.randn(input_shape).to(device))\n        assert isinstance(result, torch.Tensor)\n        self.assertEqual(result.shape, expected_shape, msg=str(input_param))\n\n    def test_ill_arg(self):\n        with self.assertRaises(AssertionError):\n            MedNeXt(spatial_dims=4)\n\n    @parameterized.expand(TEST_CASE_MEDNEXT_VARIANTS)\n    def test_mednext_variants(self, model, input_param, input_shape, expected_shape):\n        net = model(**input_param).to(device)\n\n        net.train()\n        result = net(torch.randn(input_shape).to(device))\n        assert isinstance(result, torch.Tensor)\n        self.assertEqual(result.shape, expected_shape, msg=str(input_param))\n\n        
net.eval()\n        with torch.no_grad():\n            result = net(torch.randn(input_shape).to(device))\n        assert isinstance(result, torch.Tensor)\n        self.assertEqual(result.shape, expected_shape, msg=str(input_param))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_milmodel.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import MILModel\nfrom monai.utils.module import optional_import\nfrom tests.test_utils import skip_if_downloading_fails, test_script_save\n\nmodels, _ = optional_import(\"torchvision.models\")\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_MILMODEL = []\nfor num_classes in [1, 5]:\n    for mil_mode in [\"mean\", \"max\", \"att\", \"att_trans\", \"att_trans_pyramid\"]:\n        test_case = [\n            {\"num_classes\": num_classes, \"mil_mode\": mil_mode, \"pretrained\": False},\n            (1, 2, 3, 512, 512),\n            (1, num_classes),\n        ]\n        TEST_CASE_MILMODEL.append(test_case)\n\nfor trans_blocks in [1, 3]:\n    test_case = [\n        {\"num_classes\": 5, \"pretrained\": False, \"trans_blocks\": trans_blocks, \"trans_dropout\": 0.5},\n        (1, 2, 3, 512, 512),\n        (1, 5),\n    ]\n    TEST_CASE_MILMODEL.append(test_case)\n\n# torchvision backbone\nfor pretrained in [True, False]:\n    TEST_CASE_MILMODEL.append(\n        [{\"num_classes\": 5, \"backbone\": \"resnet18\", \"pretrained\": pretrained}, (2, 2, 3, 512, 512), (2, 5)]\n    )\n\n# custom backbone\nbackbone = models.densenet121()\nbackbone_nfeatures = 
backbone.classifier.in_features\nbackbone.classifier = torch.nn.Identity()\nTEST_CASE_MILMODEL.append(\n    [\n        {\"num_classes\": 5, \"backbone\": backbone, \"backbone_num_features\": backbone_nfeatures, \"pretrained\": False},\n        (2, 2, 3, 512, 512),\n        (2, 5),\n    ]\n)\n\n\nclass TestMilModel(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_MILMODEL)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        with skip_if_downloading_fails():\n            net = MILModel(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape, dtype=torch.float).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_args(self):\n        with self.assertRaises(ValueError):\n            MILModel(\n                num_classes=5,\n                pretrained=False,\n                backbone=\"resnet50\",\n                backbone_num_features=2048,\n                mil_mode=\"att_trans_pyramid\",\n            )\n\n    def test_script(self):\n        input_param, input_shape, expected_shape = TEST_CASE_MILMODEL[0]\n        net = MILModel(**input_param)\n        test_data = torch.randn(input_shape, dtype=torch.float)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_net_adapter.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import NetAdapter, resnet18\nfrom tests.test_utils import test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_0 = [{\"num_classes\": 1, \"use_conv\": True, \"dim\": 2}, (2, 3, 224, 224), (2, 1, 8, 1)]\n\nTEST_CASE_1 = [{\"num_classes\": 1, \"use_conv\": True, \"dim\": 3, \"pool\": None}, (2, 3, 32, 32, 32), (2, 1, 1, 1, 1)]\n\nTEST_CASE_2 = [{\"num_classes\": 5, \"use_conv\": True, \"dim\": 3, \"pool\": None}, (2, 3, 32, 32, 32), (2, 5, 1, 1, 1)]\n\nTEST_CASE_3 = [\n    {\"num_classes\": 5, \"use_conv\": True, \"pool\": (\"avg\", {\"kernel_size\": 4, \"stride\": 1}), \"dim\": 3},\n    (2, 3, 128, 128, 128),\n    (2, 5, 5, 1, 1),\n]\n\nTEST_CASE_4 = [\n    {\"num_classes\": 5, \"use_conv\": False, \"pool\": (\"adaptiveavg\", {\"output_size\": (1, 1, 1)}), \"dim\": 3},\n    (2, 3, 32, 32, 32),\n    (2, 5),\n]\n\n\nclass TestNetAdapter(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_shape(self, input_param, input_shape, expected_shape):\n        spatial_dims = input_param[\"dim\"]\n        stride = (1, 2, 2)[:spatial_dims]\n        model = resnet18(spatial_dims=spatial_dims, 
conv1_t_stride=stride)\n        input_param[\"model\"] = model\n        net = NetAdapter(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand([TEST_CASE_0])\n    def test_script(self, input_param, input_shape, expected_shape):\n        spatial_dims = input_param[\"dim\"]\n        stride = (1, 2, 2)[:spatial_dims]\n        model = resnet18(spatial_dims=spatial_dims, conv1_t_stride=stride)\n        input_param[\"model\"] = model\n        net = NetAdapter(**input_param).to(\"cpu\")\n        test_data = torch.randn(input_shape).to(\"cpu\")\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_network_consistency.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport json\nimport os\nimport unittest\nfrom collections.abc import Sequence\nfrom glob import glob\nfrom unittest.case import skipIf\n\nimport torch\nfrom parameterized.parameterized import parameterized\n\nimport monai.networks.nets as nets\nfrom monai.utils import set_determinism\nfrom monai.utils.misc import MONAIEnvVars\nfrom tests.test_utils import assert_allclose\n\nextra_test_data_dir = MONAIEnvVars.extra_test_data()\n\nTESTS = []\nif extra_test_data_dir is not None:\n    for data_path in glob(os.path.join(extra_test_data_dir, \"**\", \"*.pt\")):\n        json_path = data_path[:-3] + \".json\"\n        # net_name is filename until first underscore (e.g., unet_0.pt is unet)\n        net_name = os.path.basename(data_path).split(\"_\")[0]\n        TESTS.append((net_name, data_path, json_path))\n\n\nclass TestNetworkConsistency(unittest.TestCase):\n    def setUp(self):\n        set_determinism(0)\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @skipIf(\n        len(TESTS) == 0,\n        \"To run these tests, clone https://github.com/Project-MONAI/MONAI-extra-test-data and set MONAI_EXTRA_TEST_DATA\",\n    )\n    @parameterized.expand(TESTS, skip_on_empty=True)\n    def test_network_consistency(self, net_name, data_path, json_path):\n        print(\"Net name: \" + net_name)\n        print(\"Data path: \" + 
data_path)\n        print(\"JSON path: \" + json_path)\n\n        # Load data\n        loaded_data = torch.load(data_path, weights_only=True)\n\n        # Load json from file\n        json_file = open(json_path)\n        model_params = json.load(json_file)\n        json_file.close()\n\n        # Create model\n        model = getattr(nets, net_name)(**model_params)\n        model.load_state_dict(loaded_data[\"model\"], strict=False)\n        model.eval()\n\n        in_data = loaded_data[\"in_data\"]\n        expected_out_data = loaded_data[\"out_data\"]\n\n        actual_out_data = model(in_data)\n\n        self.check_output_consistency(actual_out_data, expected_out_data)\n\n    def check_output_consistency(self, actual, expected):\n        if isinstance(actual, Sequence):\n            for a, e in zip(actual, expected):\n                self.check_output_consistency(a, e)\n        else:\n            assert_allclose(actual, expected, rtol=5e-2, atol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_patch_gan_dicriminator.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import MultiScalePatchDiscriminator, PatchDiscriminator\nfrom tests.test_utils import test_script_save\n\nTEST_PATCHGAN = [\n    [\n        {\n            \"num_layers_d\": 3,\n            \"spatial_dims\": 2,\n            \"channels\": 8,\n            \"in_channels\": 3,\n            \"out_channels\": 1,\n            \"kernel_size\": 3,\n            \"activation\": \"LEAKYRELU\",\n            \"norm\": \"instance\",\n            \"bias\": False,\n            \"dropout\": 0.1,\n        },\n        torch.rand([1, 3, 256, 512]),\n        (1, 8, 128, 256),\n        (1, 1, 32, 64),\n    ],\n    [\n        {\n            \"num_layers_d\": 3,\n            \"spatial_dims\": 3,\n            \"channels\": 8,\n            \"in_channels\": 3,\n            \"out_channels\": 1,\n            \"kernel_size\": 3,\n            \"activation\": \"LEAKYRELU\",\n            \"norm\": \"instance\",\n            \"bias\": False,\n            \"dropout\": 0.1,\n        },\n        torch.rand([1, 3, 256, 512, 256]),\n        (1, 8, 128, 256, 128),\n        (1, 1, 32, 64, 32),\n    ],\n]\n\nTEST_MULTISCALE_PATCHGAN = [\n    [\n        {\n            \"num_d\": 2,\n            \"num_layers_d\": 3,\n            
\"spatial_dims\": 2,\n            \"channels\": 8,\n            \"in_channels\": 3,\n            \"out_channels\": 1,\n            \"kernel_size\": 3,\n            \"activation\": \"LEAKYRELU\",\n            \"norm\": \"instance\",\n            \"bias\": False,\n            \"dropout\": 0.1,\n            \"minimum_size_im\": 256,\n        },\n        torch.rand([1, 3, 256, 512]),\n        [(1, 1, 32, 64), (1, 1, 4, 8)],\n        [4, 7],\n    ],\n    [\n        {\n            \"num_d\": 2,\n            \"num_layers_d\": 3,\n            \"spatial_dims\": 3,\n            \"channels\": 8,\n            \"in_channels\": 3,\n            \"out_channels\": 1,\n            \"kernel_size\": 3,\n            \"activation\": \"LEAKYRELU\",\n            \"norm\": \"instance\",\n            \"bias\": False,\n            \"dropout\": 0.1,\n            \"minimum_size_im\": 256,\n        },\n        torch.rand([1, 3, 256, 512, 256]),\n        [(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)],\n        [4, 7],\n    ],\n]\nTEST_TOO_SMALL_SIZE = [\n    {\n        \"num_d\": 2,\n        \"num_layers_d\": 6,\n        \"spatial_dims\": 2,\n        \"channels\": 8,\n        \"in_channels\": 3,\n        \"out_channels\": 1,\n        \"kernel_size\": 3,\n        \"activation\": \"LEAKYRELU\",\n        \"norm\": \"instance\",\n        \"bias\": False,\n        \"dropout\": 0.1,\n        \"minimum_size_im\": 256,\n    }\n]\n\n\nclass TestPatchGAN(unittest.TestCase):\n    @parameterized.expand(TEST_PATCHGAN)\n    def test_shape(self, input_param, input_data, expected_shape_feature, expected_shape_output):\n        net = PatchDiscriminator(**input_param)\n        with eval_mode(net):\n            result = net.forward(input_data)\n            self.assertEqual(tuple(result[0].shape), expected_shape_feature)\n            self.assertEqual(tuple(result[-1].shape), expected_shape_output)\n\n    def test_script(self):\n        net = PatchDiscriminator(\n            num_layers_d=3,\n            spatial_dims=2,\n     
       channels=8,\n            in_channels=3,\n            out_channels=1,\n            kernel_size=3,\n            activation=\"LEAKYRELU\",\n            norm=\"instance\",\n            bias=False,\n            dropout=0.1,\n        )\n        i = torch.rand([1, 3, 256, 512])\n        test_script_save(net, i)\n\n\nclass TestMultiscalePatchGAN(unittest.TestCase):\n    @parameterized.expand(TEST_MULTISCALE_PATCHGAN)\n    def test_shape(self, input_param, input_data, expected_shape, features_lengths=None):\n        net = MultiScalePatchDiscriminator(**input_param)\n        with eval_mode(net):\n            result, features = net.forward(input_data)\n            for r_ind, r in enumerate(result):\n                self.assertEqual(tuple(r.shape), expected_shape[r_ind])\n            for o_d_ind, o_d in enumerate(features):\n                self.assertEqual(len(o_d), features_lengths[o_d_ind])\n\n    def test_too_small_shape(self):\n        with self.assertRaises(AssertionError):\n            MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0])\n\n    def test_script(self):\n        net = MultiScalePatchDiscriminator(\n            num_d=2,\n            num_layers_d=3,\n            spatial_dims=2,\n            channels=8,\n            in_channels=3,\n            out_channels=1,\n            kernel_size=3,\n            activation=\"LEAKYRELU\",\n            norm=\"instance\",\n            bias=False,\n            dropout=0.1,\n            minimum_size_im=256,\n        )\n        i = torch.rand([1, 3, 256, 512])\n        test_script_save(net, i)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_quicknat.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import Quicknat\nfrom monai.utils import optional_import\nfrom tests.test_utils import test_script_save\n\n_, has_se = optional_import(\"squeeze_and_excitation\")\n\nTEST_CASES = [\n    # params, input_shape, expected_shape\n    [{\"num_classes\": 1, \"num_channels\": 1, \"num_filters\": 1, \"se_block\": None}, (1, 1, 32, 32), (1, 1, 32, 32)],\n    [{\"num_classes\": 1, \"num_channels\": 1, \"num_filters\": 4, \"se_block\": None}, (1, 1, 64, 64), (1, 1, 64, 64)],\n    [{\"num_classes\": 1, \"num_channels\": 1, \"num_filters\": 64, \"se_block\": None}, (1, 1, 128, 128), (1, 1, 128, 128)],\n    [{\"num_classes\": 4, \"num_channels\": 1, \"num_filters\": 64, \"se_block\": None}, (1, 1, 32, 32), (1, 4, 32, 32)],\n    [{\"num_classes\": 33, \"num_channels\": 1, \"num_filters\": 64, \"se_block\": None}, (1, 1, 32, 32), (1, 33, 32, 32)],\n    [{\"num_classes\": 1, \"num_channels\": 1, \"num_filters\": 64, \"se_block\": \"CSE\"}, (1, 1, 32, 32), (1, 1, 32, 32)],\n    [{\"num_classes\": 1, \"num_channels\": 1, \"num_filters\": 64, \"se_block\": \"SSE\"}, (1, 1, 32, 32), (1, 1, 32, 32)],\n    [{\"num_classes\": 1, \"num_channels\": 1, \"num_filters\": 64, \"se_block\": \"CSSE\"}, (1, 1, 32, 32), (1, 1, 32, 
32)],\n]\n\n\n@unittest.skipUnless(has_se, \"squeeze_and_excitation not installed\")\nclass TestQuicknat(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        print(input_param)\n        net = Quicknat(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n        self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        net = Quicknat(num_classes=1, num_channels=1)\n        test_data = torch.randn(16, 1, 32, 32)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_resnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport copy\nimport os\nimport re\nimport sys\nimport unittest\nfrom typing import TYPE_CHECKING\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import (\n    ResNet,\n    ResNetFeatures,\n    get_medicalnet_pretrained_resnet_args,\n    get_pretrained_resnet_medicalnet,\n    resnet10,\n    resnet18,\n    resnet34,\n    resnet50,\n    resnet101,\n    resnet152,\n    resnet200,\n)\nfrom monai.networks.nets.resnet import ResNetBlock\nfrom monai.utils import optional_import\nfrom tests.test_utils import (\n    SkipIfNoModule,\n    equal_state_dict,\n    skip_if_downloading_fails,\n    skip_if_no_cuda,\n    skip_if_quick,\n    test_script_save,\n)\n\nif TYPE_CHECKING:\n    import torchvision\n\n    has_torchvision = True\nelse:\n    torchvision, has_torchvision = optional_import(\"torchvision\")\n\nhas_hf_modules = \"huggingface_hub\" in sys.modules and \"huggingface_hub.utils._errors\" in sys.modules\n\n# from torchvision.models import ResNet50_Weights, resnet50\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_1 = [  # 3D, batch 3, 2 input channel\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 3,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": 7,\n        \"conv1_t_stride\": (2, 
2, 2),\n    },\n    (3, 2, 32, 64, 48),\n    (3, 3),\n]\n\nTEST_CASE_2 = [  # 2D, batch 2, 1 input channel\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 2,\n        \"n_input_channels\": 1,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [7, 7],\n        \"conv1_t_stride\": [2, 2],\n    },\n    (2, 1, 32, 64),\n    (2, 3),\n]\n\nTEST_CASE_2_A = [  # 2D, batch 2, 1 input channel, shortcut type A\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 2,\n        \"n_input_channels\": 1,\n        \"num_classes\": 3,\n        \"shortcut_type\": \"A\",\n        \"conv1_t_size\": (7, 7),\n        \"conv1_t_stride\": 2,\n    },\n    (2, 1, 32, 64),\n    (2, 3),\n]\n\nTEST_CASE_3 = [  # 1D, batch 1, 2 input channels\n    {\n        \"pretrained\": False,\n        \"spatial_dims\": 1,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [3],\n        \"conv1_t_stride\": 1,\n        \"act\": (\"relu\", {\"inplace\": False}),\n    },\n    (1, 2, 32),\n    (1, 3),\n]\n\nTEST_CASE_3_A = [  # 1D, batch 1, 2 input channels\n    {\"pretrained\": False, \"spatial_dims\": 1, \"n_input_channels\": 2, \"num_classes\": 3, \"shortcut_type\": \"A\"},\n    (1, 2, 32),\n    (1, 3),\n]\n\nTEST_CASE_4 = [  # 2D, batch 2, 1 input channel\n    {\"pretrained\": False, \"spatial_dims\": 2, \"n_input_channels\": 1, \"num_classes\": 3, \"feed_forward\": False},\n    (2, 1, 32, 64),\n    ((2, 512), (2, 2048)),\n]\n\nTEST_CASE_5 = [  # 1D, batch 1, 2 input channels\n    {\n        \"block\": \"basic\",\n        \"layers\": [1, 1, 1, 1],\n        \"block_inplanes\": [64, 128, 256, 512],\n        \"spatial_dims\": 1,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [3],\n        \"conv1_t_stride\": 1,\n    },\n    (1, 2, 32),\n    (1, 3),\n]\n\nTEST_CASE_5_A = [  # 1D, batch 1, 2 input channels\n    {\n        \"block\": ResNetBlock,\n        \"layers\": [1, 1, 1, 1],\n        
\"block_inplanes\": [64, 128, 256, 512],\n        \"spatial_dims\": 1,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [3],\n        \"conv1_t_stride\": 1,\n    },\n    (1, 2, 32),\n    (1, 3),\n]\n\nTEST_CASE_6 = [  # 1D, batch 1, 2 input channels\n    {\n        \"block\": \"bottleneck\",\n        \"layers\": [3, 4, 6, 3],\n        \"block_inplanes\": [64, 128, 256, 512],\n        \"spatial_dims\": 1,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [3],\n        \"conv1_t_stride\": 1,\n    },\n    (1, 2, 32),\n    (1, 3),\n]\n\nTEST_CASE_7 = [  # 1D, batch 1, 2 input channels, bias_downsample\n    {\n        \"block\": \"bottleneck\",\n        \"layers\": [3, 4, 6, 3],\n        \"block_inplanes\": [64, 128, 256, 512],\n        \"spatial_dims\": 1,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [3],\n        \"conv1_t_stride\": 1,\n        \"bias_downsample\": False,  # set to False if pretrained=True (PR #5477)\n    },\n    (1, 2, 32),\n    (1, 3),\n]\n\nTEST_CASE_8 = [\n    {\n        \"block\": \"bottleneck\",\n        \"layers\": [3, 4, 6, 3],\n        \"block_inplanes\": [64, 128, 256, 512],\n        \"spatial_dims\": 1,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [3],\n        \"conv1_t_stride\": 1,\n        \"act\": (\"relu\", {\"inplace\": False}),\n    },\n    (1, 2, 32),\n    (1, 3),\n]\n\nTEST_CASE_9 = [  # Layer norm\n    {\n        \"block\": ResNetBlock,\n        \"layers\": [3, 4, 6, 3],\n        \"block_inplanes\": [64, 128, 256, 512],\n        \"spatial_dims\": 1,\n        \"n_input_channels\": 2,\n        \"num_classes\": 3,\n        \"conv1_t_size\": [3],\n        \"conv1_t_stride\": 1,\n        \"act\": (\"relu\", {\"inplace\": False}),\n        \"norm\": (\"layer\", {\"normalized_shape\": (64, 32)}),\n    },\n    (1, 2, 32),\n    (1, 3),\n]\n\nTEST_CASES = 
[]\nPRETRAINED_TEST_CASES = []\nfor case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:\n    for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:\n        TEST_CASES.append([model, *case])\n        PRETRAINED_TEST_CASES.append([model, *case])\nfor case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]:\n    TEST_CASES.append([ResNet, *case])\n\nTEST_SCRIPT_CASES = [\n    [model, *TEST_CASE_1] for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]\n]\n\nCASE_EXTRACT_FEATURES = [\n    (\n        {\"model_name\": \"resnet10\", \"pretrained\": True, \"spatial_dims\": 3, \"in_channels\": 1},\n        [1, 1, 64, 64, 64],\n        ([1, 64, 32, 32, 32], [1, 64, 16, 16, 16], [1, 128, 8, 8, 8], [1, 256, 4, 4, 4], [1, 512, 2, 2, 2]),\n    )\n]\n\n\nclass TestResNet(unittest.TestCase):\n    def setUp(self):\n        self.tmp_ckpt_filename = os.path.join(\"tests\", \"monai_unittest_tmp_ckpt.pth\")\n\n    def tearDown(self):\n        if os.path.exists(self.tmp_ckpt_filename):\n            try:\n                os.remove(self.tmp_ckpt_filename)\n            except BaseException:\n                pass\n\n    @parameterized.expand(TEST_CASES)\n    def test_resnet_shape(self, model, input_param, input_shape, expected_shape):\n        net = model(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            if input_param.get(\"feed_forward\", True):\n                self.assertEqual(result.shape, expected_shape)\n            else:\n                self.assertIn(result.shape, expected_shape)\n\n    @parameterized.expand(PRETRAINED_TEST_CASES)\n    @skip_if_quick\n    @skip_if_no_cuda\n    def test_resnet_pretrained(self, model, input_param, _input_shape, _expected_shape):\n        net = model(**input_param).to(device)\n        # Save ckpt\n        torch.save(net.state_dict(), 
self.tmp_ckpt_filename)\n\n        cp_input_param = copy.copy(input_param)\n        # Custom pretrained weights\n        cp_input_param[\"pretrained\"] = self.tmp_ckpt_filename\n        pretrained_net = model(**cp_input_param)\n        equal_state_dict(net.state_dict(), pretrained_net.state_dict())\n\n        if has_hf_modules:\n            # True flag\n            cp_input_param[\"pretrained\"] = True\n            resnet_depth = int(re.search(r\"resnet(\\d+)\", model.__name__).group(1))\n\n            bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)\n\n            # With orig. test cases\n            if (\n                input_param.get(\"spatial_dims\", 3) == 3\n                and input_param.get(\"n_input_channels\", 3) == 1\n                and input_param.get(\"feed_forward\", True) is False\n                and input_param.get(\"shortcut_type\", \"B\") == shortcut_type\n                and (input_param.get(\"bias_downsample\", True) == bias_downsample)\n            ):\n                model(**cp_input_param)\n            else:\n                with self.assertRaises(NotImplementedError):\n                    model(**cp_input_param)\n\n            # forcing MedicalNet pretrained download for 3D tests cases\n            cp_input_param[\"n_input_channels\"] = 1\n            cp_input_param[\"feed_forward\"] = False\n            cp_input_param[\"shortcut_type\"] = shortcut_type\n            cp_input_param[\"bias_downsample\"] = bias_downsample\n            if cp_input_param.get(\"spatial_dims\", 3) == 3:\n                with skip_if_downloading_fails():\n                    pretrained_net = model(**cp_input_param).to(device)\n                    medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=device)\n                    medicalnet_state_dict = {\n                        key.replace(\"module.\", \"\"): value for key, value in medicalnet_state_dict.items()\n                    }\n                  
  equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict)\n\n    @parameterized.expand(TEST_SCRIPT_CASES)\n    def test_script(self, model, input_param, input_shape, expected_shape):\n        net = model(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\n@SkipIfNoModule(\"hf_hub_download\")\nclass TestExtractFeatures(unittest.TestCase):\n    @parameterized.expand(CASE_EXTRACT_FEATURES)\n    def test_shape(self, input_param, input_shape, expected_shapes):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        with skip_if_downloading_fails():\n            net = ResNetFeatures(**input_param).to(device)\n\n        # run inference with random tensor\n        with eval_mode(net):\n            features = net(torch.randn(input_shape).to(device))\n\n        # check output shape\n        self.assertEqual(len(features), len(expected_shapes))\n        for feature, expected_shape in zip(features, expected_shapes):\n            self.assertEqual(feature.shape, torch.Size(expected_shape))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_restormer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer\nfrom monai.utils import optional_import\n\neinops, has_einops = optional_import(\"einops\")\n\nTEST_CASES_TRANSFORMER = [\n    # [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape]\n    [2, 48, 8, 2.66, True, True, False, (2, 48, 64, 64)],\n    [2, 96, 8, 2.66, False, False, False, (2, 96, 32, 32)],\n    [3, 48, 4, 2.66, True, True, False, (2, 48, 32, 32, 32)],\n    [3, 96, 8, 2.66, False, False, True, (2, 96, 16, 16, 16)],\n]\n\nTEST_CASES_PATCHEMBED = [\n    # spatial_dims, in_channels, embed_dim, input_shape, expected_shape\n    [2, 1, 48, (2, 1, 64, 64), (2, 48, 64, 64)],\n    [2, 3, 96, (2, 3, 32, 32), (2, 96, 32, 32)],\n    [3, 1, 48, (2, 1, 32, 32, 32), (2, 48, 32, 32, 32)],\n    [3, 4, 64, (2, 4, 16, 16, 16), (2, 64, 16, 16, 16)],\n]\n\nRESTORMER_CONFIGS = [\n    # 2-level architecture\n    {\"num_blocks\": [1, 1], \"heads\": [1, 1]},\n    {\"num_blocks\": [2, 1], \"heads\": [2, 1]},\n    # 3-level architecture\n    {\"num_blocks\": [1, 1, 1], \"heads\": [1, 1, 1]},\n    {\"num_blocks\": [2, 1, 1], \"heads\": [2, 1, 1]},\n]\n\nTEST_CASES_RESTORMER 
= []\nfor config in RESTORMER_CONFIGS:\n    # 2D cases\n    TEST_CASES_RESTORMER.extend(\n        [\n            [\n                {\n                    \"spatial_dims\": 2,\n                    \"in_channels\": 1,\n                    \"out_channels\": 1,\n                    \"dim\": 48,\n                    \"num_blocks\": config[\"num_blocks\"],\n                    \"heads\": config[\"heads\"],\n                    \"num_refinement_blocks\": 2,\n                    \"ffn_expansion_factor\": 1.5,\n                },\n                (2, 1, 64, 64),\n                (2, 1, 64, 64),\n            ],\n            # 3D cases\n            [\n                {\n                    \"spatial_dims\": 3,\n                    \"in_channels\": 1,\n                    \"out_channels\": 1,\n                    \"dim\": 16,\n                    \"num_blocks\": config[\"num_blocks\"],\n                    \"heads\": config[\"heads\"],\n                    \"num_refinement_blocks\": 2,\n                    \"ffn_expansion_factor\": 1.5,\n                },\n                (2, 1, 32, 32, 32),\n                (2, 1, 32, 32, 32),\n            ],\n        ]\n    )\n\n\nclass TestMDTATransformerBlock(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_TRANSFORMER)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape):\n        if flash and not torch.cuda.is_available():\n            self.skipTest(\"Flash attention requires CUDA\")\n        block = MDTATransformerBlock(\n            spatial_dims=spatial_dims,\n            dim=dim,\n            num_heads=heads,\n            ffn_expansion_factor=ffn_factor,\n            bias=bias,\n            layer_norm_use_bias=layer_norm_use_bias,\n            flash_attention=flash,\n        )\n        with eval_mode(block):\n            x = torch.randn(shape)\n            output = block(x)\n            self.assertEqual(output.shape, 
x.shape)\n\n\nclass TestOverlapPatchEmbed(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_PATCHEMBED)\n    def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape):\n        net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n\nclass TestRestormer(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_RESTORMER)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape(self, input_param, input_shape, expected_shape):\n        if input_param.get(\"flash_attention\", False) and not torch.cuda.is_available():\n            self.skipTest(\"Flash attention requires CUDA\")\n        net = Restormer(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_small_input_error_2d(self):\n        net = Restormer(spatial_dims=2, in_channels=1, out_channels=1)\n        with self.assertRaises(AssertionError):\n            net(torch.randn(1, 1, 8, 8))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_small_input_error_3d(self):\n        net = Restormer(spatial_dims=3, in_channels=1, out_channels=1)\n        with self.assertRaises(AssertionError):\n            net(torch.randn(1, 1, 8, 8, 8))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_segresnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import SegResNet, SegResNetVAE\nfrom monai.utils import UpsampleMode\nfrom tests.test_utils import dict_product, test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_SEGRESNET = [\n    [\n        {**params, \"use_conv_final\": False},\n        (2, 1, *([16] * params[\"spatial_dims\"])),\n        (2, params[\"init_filters\"], *([16] * params[\"spatial_dims\"])),\n    ]\n    for params in dict_product(\n        spatial_dims=range(2, 4),\n        init_filters=[8, 16],\n        dropout_prob=[None, 0.2],\n        norm=[(\"GROUP\", {\"num_groups\": 8}), (\"batch\", {\"track_running_stats\": False}), \"instance\"],\n        upsample_mode=list(UpsampleMode),\n    )\n]\n\nTEST_CASE_SEGRESNET_2 = [\n    [params, (2, 1, *([16] * params[\"spatial_dims\"])), (2, params[\"out_channels\"], *([16] * params[\"spatial_dims\"]))]\n    for params in dict_product(\n        spatial_dims=range(2, 4), init_filters=[8, 16], out_channels=range(1, 3), upsample_mode=list(UpsampleMode)\n    )\n]\n\nTEST_CASE_SEGRESNET_VAE = [\n    [\n        {\n            **params,\n            \"act\": (\"leakyrelu\", {\"inplace\": True, \"negative_slope\": 0.01}),\n            \"input_image_size\": 
([16] * params[\"spatial_dims\"]),\n        },\n        (2, 1, *([16] * params[\"spatial_dims\"])),\n        (2, params[\"out_channels\"], *([16] * params[\"spatial_dims\"])),\n    ]\n    for params in dict_product(\n        spatial_dims=range(2, 4),\n        init_filters=[8, 16],\n        out_channels=range(1, 3),\n        upsample_mode=list(UpsampleMode),\n        vae_estimate_std=[True, False],\n    )\n]\n\n\nclass TestResNet(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_SEGRESNET + TEST_CASE_SEGRESNET_2)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = SegResNet(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            SegResNet(spatial_dims=4)\n\n    def test_script(self):\n        input_param, input_shape, expected_shape = TEST_CASE_SEGRESNET[0]\n        net = SegResNet(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\nclass TestResNetVAE(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_SEGRESNET_VAE)\n    def test_vae_shape(self, input_param, input_shape, expected_shape):\n        net = SegResNetVAE(**input_param).to(device)\n        with eval_mode(net):\n            result, _ = net(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        input_param, input_shape, expected_shape = TEST_CASE_SEGRESNET_VAE[0]\n        net = SegResNetVAE(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_segresnet_ds.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import SegResNetDS, SegResNetDS2\nfrom tests.test_utils import dict_product, test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_SEGRESNET_DS = [\n    [params, (2, 1, *([16] * params[\"spatial_dims\"])), (2, 2, *([16] * params[\"spatial_dims\"]))]\n    for params in dict_product(\n        spatial_dims=range(2, 4),\n        init_filters=[8, 16],\n        act=[\"relu\", \"leakyrelu\"],\n        norm=[\"BATCH\", (\"instance\", {\"affine\": True})],\n        upsample_mode=[\"deconv\", \"nontrainable\"],\n    )\n]\n\nTEST_CASE_SEGRESNET_DS2 = [\n    [\n        {**params, \"init_filters\": 8},\n        (2, 1, *([16] * params[\"spatial_dims\"])),\n        (2, params[\"out_channels\"], *([16] * params[\"spatial_dims\"])),\n    ]\n    for params in dict_product(spatial_dims=range(2, 4), out_channels=[1, 2], dsdepth=[1, 2, 3])\n]\n\nTEST_CASE_SEGRESNET_DS3 = [\n    ({\"init_filters\": 8, \"dsdepth\": 2, \"resolution\": None}, (2, 1, 16, 16, 16), ((2, 2, 16, 16, 16), (2, 2, 8, 8, 8))),\n    (\n        {\"init_filters\": 8, \"dsdepth\": 3, \"resolution\": None},\n        (2, 1, 16, 16, 16),\n        ((2, 2, 16, 16, 16), (2, 2, 8, 8, 8), (2, 2, 4, 4, 4)),\n    ),\n    
(\n        {\"init_filters\": 8, \"dsdepth\": 3, \"resolution\": [1, 1, 5]},\n        (2, 1, 16, 16, 16),\n        ((2, 2, 16, 16, 16), (2, 2, 8, 8, 16), (2, 2, 4, 4, 16)),\n    ),\n    (\n        {\"init_filters\": 8, \"dsdepth\": 3, \"resolution\": [1, 2, 5]},\n        (2, 1, 16, 16, 16),\n        ((2, 2, 16, 16, 16), (2, 2, 8, 8, 16), (2, 2, 4, 8, 16)),\n    ),\n]\n\n\nclass TestSegResNetDS(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_SEGRESNET_DS)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = SegResNetDS(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape, msg=str(input_param))\n\n    @parameterized.expand(TEST_CASE_SEGRESNET_DS)\n    def test_shape_ds2(self, input_param, input_shape, expected_shape):\n        net = SegResNetDS2(**input_param).to(device)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device), with_label=False)\n            self.assertEqual(result[0].shape, expected_shape, msg=str(input_param))\n            self.assertTrue(result[1] == [])\n\n            result = net(torch.randn(input_shape).to(device), with_point=False)\n            self.assertEqual(result[1].shape, expected_shape, msg=str(input_param))\n            self.assertTrue(result[0] == [])\n\n    @parameterized.expand(TEST_CASE_SEGRESNET_DS2)\n    def test_shape2(self, input_param, input_shape, expected_shape):\n        dsdepth = input_param.get(\"dsdepth\", 1)\n        for net in [SegResNetDS, SegResNetDS2]:\n            net = net(**input_param).to(device)\n            net.train()\n            if isinstance(net, SegResNetDS2):\n                result = net(torch.randn(input_shape).to(device), with_label=False)[0]\n            else:\n                result = net(torch.randn(input_shape).to(device))\n            if dsdepth > 1:\n                assert isinstance(result, list)\n 
               self.assertEqual(dsdepth, len(result))\n                for i in range(dsdepth):\n                    self.assertEqual(\n                        result[i].shape,\n                        expected_shape[:2] + tuple(e // (2**i) for e in expected_shape[2:]),\n                        msg=str(input_param),\n                    )\n            else:\n                assert isinstance(result, torch.Tensor)\n                self.assertEqual(result.shape, expected_shape, msg=str(input_param))\n\n            if not isinstance(net, SegResNetDS2):\n                # eval mode of SegResNetDS2 has same output as training mode\n                # so only test eval mode for SegResNetDS\n                net.eval()\n                result = net(torch.randn(input_shape).to(device))\n                assert isinstance(result, torch.Tensor)\n                self.assertEqual(result.shape, expected_shape, msg=str(input_param))\n\n    @parameterized.expand(TEST_CASE_SEGRESNET_DS3)\n    def test_shape3(self, input_param, input_shape, expected_shapes):\n        dsdepth = input_param.get(\"dsdepth\", 1)\n        for net in [SegResNetDS, SegResNetDS2]:\n            net = net(**input_param).to(device)\n            net.train()\n            if isinstance(net, SegResNetDS2):\n                result = net(torch.randn(input_shape).to(device), with_point=False)[1]\n            else:\n                result = net(torch.randn(input_shape).to(device))\n            assert isinstance(result, list)\n            self.assertEqual(dsdepth, len(result))\n            for i in range(dsdepth):\n                self.assertEqual(result[i].shape, expected_shapes[i], msg=str(input_param))\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            SegResNetDS(spatial_dims=4)\n\n        with self.assertRaises(ValueError):\n            SegResNetDS2(spatial_dims=4)\n\n    def test_script(self):\n        input_param, input_shape, _ = TEST_CASE_SEGRESNET_DS[0]\n        net = 
SegResNetDS(**input_param)\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_senet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nimport monai.networks.nets.senet as se_mod\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import SENet, SENet154, SEResNet50, SEResNet101, SEResNet152, SEResNext50, SEResNext101\nfrom monai.utils import optional_import\nfrom tests.test_utils import test_is_quick, test_pretrained_networks, test_script_save, testing_data_config\n\nif TYPE_CHECKING:\n    import pretrainedmodels\n\n    has_cadene_pretrain = True\nelse:\n    pretrainedmodels, has_cadene_pretrain = optional_import(\"pretrainedmodels\")\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nNET_ARGS = {\"spatial_dims\": 3, \"in_channels\": 2, \"num_classes\": 2}\nTEST_CASE_1 = [SENet154, NET_ARGS]\nTEST_CASE_2 = [SEResNet50, NET_ARGS]\nTEST_CASE_3 = [SEResNet101, NET_ARGS]\nTEST_CASE_4 = [SEResNet152, NET_ARGS]\nTEST_CASE_5 = [SEResNext50, NET_ARGS]\nTEST_CASE_6 = [SEResNext101, NET_ARGS]\nTEST_CASE_7 = [\n    SENet,\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 2,\n        \"num_classes\": 2,\n        \"block\": \"se_bottleneck\",\n        \"layers\": (3, 8, 36, 3),\n        \"groups\": 64,\n        \"reduction\": 16,\n    
},\n]\n\nTEST_CASE_PRETRAINED_1 = [SEResNet50, {\"spatial_dims\": 2, \"in_channels\": 3, \"num_classes\": 2, \"pretrained\": True}]\n\n\nclass TestSENET(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])\n    def test_senet_shape(self, net, net_args):\n        input_data = torch.randn(2, 2, 64, 64, 64).to(device)\n        expected_shape = (2, 2)\n        net = net(**net_args).to(device)\n        with eval_mode(net):\n            result = net(input_data)\n            self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])\n    def test_script(self, net, net_args):\n        net = net(**net_args)\n        input_data = torch.randn(2, 2, 64, 64, 64)\n        test_script_save(net, input_data)\n\n\nclass TestPretrainedSENET(unittest.TestCase):\n    def setUp(self):\n        self.original_urls = se_mod.SE_NET_MODELS.copy()\n        replace_url = test_is_quick()\n        if not replace_url:\n            try:\n                SEResNet50(pretrained=True, spatial_dims=2, in_channels=3, num_classes=2)\n            except OSError as rt_e:\n                print(rt_e)\n                if \"certificate\" in str(rt_e):  # [SSL: CERTIFICATE_VERIFY_FAILED]\n                    replace_url = True\n        if replace_url:\n            testing_dir = Path(__file__).parents[2] / \"testing_data\"\n            testing_data_urls = {\n                \"senet154\": {\n                    \"url\": testing_data_config(\"models\", \"senet154-c7b49a05\", \"url\"),\n                    \"filename\": \"senet154-c7b49a05.pth\",\n                },\n                \"se_resnet50\": {\n                    \"url\": testing_data_config(\"models\", \"se_resnet50-ce0d4300\", \"url\"),\n                    \"filename\": \"se_resnet50-ce0d4300.pth\",\n                },\n                \"se_resnet101\": {\n                
    \"url\": testing_data_config(\"models\", \"se_resnet101-7e38fcc6\", \"url\"),\n                    \"filename\": \"se_resnet101-7e38fcc6.pth\",\n                },\n                \"se_resnet152\": {\n                    \"url\": testing_data_config(\"models\", \"se_resnet152-d17c99b7\", \"url\"),\n                    \"filename\": \"se_resnet152-d17c99b7.pth\",\n                },\n                \"se_resnext50_32x4d\": {\n                    \"url\": testing_data_config(\"models\", \"se_resnext50_32x4d-a260b3a4\", \"url\"),\n                    \"filename\": \"se_resnext50_32x4d-a260b3a4.pth\",\n                },\n                \"se_resnext101_32x4d\": {\n                    \"url\": testing_data_config(\"models\", \"se_resnext101_32x4d-3b2fe3d8\", \"url\"),\n                    \"filename\": \"se_resnext101_32x4d-3b2fe3d8.pth\",\n                },\n            }\n            for item in testing_data_urls:\n                testing_data_urls[item][\"filename\"] = os.path.join(testing_dir, testing_data_urls[item][\"filename\"])\n            se_mod.SE_NET_MODELS = testing_data_urls\n\n    def tearDown(self):\n        se_mod.SE_NET_MODELS = self.original_urls.copy()\n\n    @parameterized.expand([TEST_CASE_PRETRAINED_1])\n    def test_senet_shape(self, model, input_param):\n        net = test_pretrained_networks(model, input_param, device)\n        input_data = torch.randn(3, 3, 64, 64).to(device)\n        expected_shape = (3, 2)\n        net = net.to(device)\n        with eval_mode(net):\n            result = net(input_data)\n            self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand([TEST_CASE_PRETRAINED_1])\n    @skipUnless(has_cadene_pretrain, \"Requires `pretrainedmodels` package.\")\n    def test_pretrain_consistency(self, model, input_param):\n        input_data = torch.randn(1, 3, 64, 64).to(device)\n        net = test_pretrained_networks(model, input_param, device)\n        with eval_mode(net):\n            result = 
net.features(input_data)\n        cadene_net = pretrainedmodels.se_resnet50().to(device)\n        with eval_mode(cadene_net):\n            expected_result = cadene_net.features(input_data)\n        # The difference between Cadene's senet and our version is that\n        # we use nn.Linear as the FC layer, but Cadene's version uses\n        # a conv layer with kernel size equals to 1. It may bring a little difference.\n        self.assertTrue(torch.allclose(result, expected_result, rtol=1e-5, atol=1e-5))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_spade_autoencoderkl.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import SPADEAutoencoderKL\nfrom monai.utils import optional_import\n\neinops, has_einops = optional_import(\"einops\")\n\nCASES_NO_ATTENTION = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n        },\n        (1, 1, 16, 16),\n        (1, 3, 16, 16),\n        (1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"with_encoder_nonlocal_attn\": False,\n            \"with_decoder_nonlocal_attn\": False,\n   
     },\n        (1, 1, 16, 16, 16),\n        (1, 3, 16, 16, 16),\n        (1, 1, 16, 16, 16),\n        (1, 4, 4, 4, 4),\n    ],\n]\n\nCASES_ATTENTION = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n        },\n        (1, 1, 16, 16),\n        (1, 3, 16, 16),\n        (1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"num_res_blocks\": (1, 1, 2),\n            \"norm_num_groups\": 4,\n        },\n        (1, 1, 16, 16),\n        (1, 3, 16, 16),\n        (1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n        },\n        (1, 1, 16, 16),\n        (1, 3, 16, 16),\n        (1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, True),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n        },\n        (1, 1, 16, 16),\n        (1, 3, 16, 16),\n 
       (1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, False),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"with_encoder_nonlocal_attn\": False,\n        },\n        (1, 1, 16, 16),\n        (1, 3, 16, 16),\n        (1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, True),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n        },\n        (1, 1, 16, 16, 16),\n        (1, 3, 16, 16, 16),\n        (1, 1, 16, 16, 16),\n        (1, 4, 4, 4, 4),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"label_nc\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4, 4),\n            \"latent_channels\": 4,\n            \"attention_levels\": (False, False, True),\n            \"num_res_blocks\": 1,\n            \"norm_num_groups\": 4,\n            \"spade_intermediate_channels\": 32,\n        },\n        (1, 1, 16, 16),\n        (1, 3, 16, 16),\n        (1, 1, 16, 16),\n        (1, 4, 4, 4),\n    ],\n]\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\nif has_einops:\n    CASES = CASES_ATTENTION + CASES_NO_ATTENTION\nelse:\n    CASES = CASES_NO_ATTENTION\n\n\nclass TestSPADEAutoEncoderKL(unittest.TestCase):\n    @parameterized.expand(CASES)\n    def test_shape(self, input_param, input_shape, input_seg, expected_shape, expected_latent_shape):\n        net = 
SPADEAutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device), torch.randn(input_seg).to(device))\n            self.assertEqual(result[0].shape, expected_shape)\n            self.assertEqual(result[1].shape, expected_latent_shape)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_model_channels_not_multiple_of_norm_num_group(self):\n        with self.assertRaises(ValueError):\n            SPADEAutoencoderKL(\n                spatial_dims=2,\n                label_nc=3,\n                in_channels=1,\n                out_channels=1,\n                channels=(24, 24, 24),\n                attention_levels=(False, False, False),\n                latent_channels=8,\n                num_res_blocks=1,\n                norm_num_groups=16,\n            )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_model_channels_not_same_size_of_attention_levels(self):\n        with self.assertRaises(ValueError):\n            SPADEAutoencoderKL(\n                spatial_dims=2,\n                label_nc=3,\n                in_channels=1,\n                out_channels=1,\n                channels=(24, 24, 24),\n                attention_levels=(False, False),\n                latent_channels=8,\n                num_res_blocks=1,\n                norm_num_groups=16,\n            )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_model_channels_not_same_size_of_num_res_blocks(self):\n        with self.assertRaises(ValueError):\n            SPADEAutoencoderKL(\n                spatial_dims=2,\n                label_nc=3,\n                in_channels=1,\n                out_channels=1,\n                channels=(24, 24, 24),\n                attention_levels=(False, False, False),\n                latent_channels=8,\n                num_res_blocks=(8, 8),\n                norm_num_groups=16,\n            )\n\n    def test_shape_encode(self):\n        
input_param, input_shape, _, _, expected_latent_shape = CASES[0]\n        net = SPADEAutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.encode(torch.randn(input_shape).to(device))\n            self.assertEqual(result[0].shape, expected_latent_shape)\n            self.assertEqual(result[1].shape, expected_latent_shape)\n\n    def test_shape_sampling(self):\n        input_param, _, _, _, expected_latent_shape = CASES[0]\n        net = SPADEAutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.sampling(\n                torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)\n            )\n            self.assertEqual(result.shape, expected_latent_shape)\n\n    def test_shape_decode(self):\n        input_param, _, input_seg_shape, expected_input_shape, latent_shape = CASES[0]\n        net = SPADEAutoencoderKL(**input_param).to(device)\n        with eval_mode(net):\n            result = net.decode(torch.randn(latent_shape).to(device), torch.randn(input_seg_shape).to(device))\n            self.assertEqual(result.shape, expected_input_shape)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_wrong_shape_decode(self):\n        net = SPADEAutoencoderKL(\n            spatial_dims=2,\n            label_nc=3,\n            in_channels=1,\n            out_channels=1,\n            channels=(4, 4, 4),\n            latent_channels=4,\n            attention_levels=(False, False, False),\n            num_res_blocks=1,\n            norm_num_groups=4,\n        )\n        with self.assertRaises(RuntimeError):\n            _ = net.decode(torch.randn((1, 1, 16, 16)).to(device), torch.randn((1, 6, 16, 16)).to(device))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_spade_diffusion_model_unet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import SPADEDiffusionModelUNet\nfrom monai.utils import optional_import\n\neinops, has_einops = optional_import(\"einops\")\nUNCOND_CASES_2D = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n            \"label_nc\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": (1, 1, 2),\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n            \"label_nc\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n            \"label_nc\": 
3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"label_nc\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n            \"label_nc\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"label_nc\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, True, True),\n            \"num_head_channels\": (0, 2, 4),\n            \"norm_num_groups\": 8,\n            \"label_nc\": 3,\n        }\n    ],\n]\n\nUNCOND_CASES_3D = [\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n            \"label_nc\": 3,\n            \"spade_intermediate_channels\": 256,\n        }\n    ],\n    [\n        
{\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n            \"label_nc\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, False),\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n            \"label_nc\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"label_nc\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 8,\n            \"norm_num_groups\": 8,\n            \"resblock_updown\": True,\n            \"label_nc\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"label_nc\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            
\"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": (0, 0, 4),\n            \"norm_num_groups\": 8,\n            \"label_nc\": 3,\n        }\n    ],\n]\n\nCOND_CASES_2D = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            \"label_nc\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            \"resblock_updown\": True,\n            \"label_nc\": 3,\n        }\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"num_res_blocks\": 1,\n            \"channels\": (8, 8, 8),\n            \"attention_levels\": (False, False, True),\n            \"num_head_channels\": 4,\n            \"norm_num_groups\": 8,\n            \"with_conditioning\": True,\n            \"transformer_num_layers\": 1,\n            \"cross_attention_dim\": 3,\n            \"upcast_attention\": True,\n            \"label_nc\": 3,\n        }\n    ],\n]\n\n\nclass TestSPADEDiffusionModelUNet2D(unittest.TestCase):\n    @parameterized.expand(UNCOND_CASES_2D)\n    @skipUnless(has_einops, 
\"Requires einops\")\n    def test_shape_unconditioned_models(self, input_param):\n        net = SPADEDiffusionModelUNet(**input_param)\n        with eval_mode(net):\n            result = net.forward(\n                torch.rand((1, 1, 16, 16)),\n                torch.randint(0, 1000, (1,)).long(),\n                torch.rand((1, input_param[\"label_nc\"], 16, 16)),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 16))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_timestep_with_wrong_shape(self):\n        net = SPADEDiffusionModelUNet(\n            spatial_dims=2,\n            label_nc=3,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, False),\n            norm_num_groups=8,\n        )\n        with self.assertRaises(ValueError):\n            with eval_mode(net):\n                net.forward(\n                    torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long(), torch.rand((1, 3, 16, 16))\n                )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_label_with_wrong_shape(self):\n        net = SPADEDiffusionModelUNet(\n            spatial_dims=2,\n            label_nc=3,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, False),\n            norm_num_groups=8,\n        )\n        with self.assertRaises(RuntimeError):\n            with eval_mode(net):\n                net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 6, 16, 16)))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_with_different_in_channel_out_channel(self):\n        in_channels = 6\n        out_channels = 3\n        net = SPADEDiffusionModelUNet(\n            spatial_dims=2,\n            label_nc=3,\n            
in_channels=in_channels,\n            out_channels=out_channels,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, False),\n            norm_num_groups=8,\n        )\n        with eval_mode(net):\n            result = net.forward(\n                torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 3, 16, 16))\n            )\n            self.assertEqual(result.shape, (1, out_channels, 16, 16))\n\n    def test_model_channels_not_multiple_of_norm_num_group(self):\n        with self.assertRaises(ValueError):\n            SPADEDiffusionModelUNet(\n                spatial_dims=2,\n                label_nc=3,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                channels=(8, 8, 12),\n                attention_levels=(False, False, False),\n                norm_num_groups=8,\n            )\n\n    def test_attention_levels_with_different_length_num_head_channels(self):\n        with self.assertRaises(ValueError):\n            SPADEDiffusionModelUNet(\n                spatial_dims=2,\n                label_nc=3,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                channels=(8, 8, 8),\n                attention_levels=(False, False, False),\n                num_head_channels=(0, 2),\n                norm_num_groups=8,\n            )\n\n    def test_num_res_blocks_with_different_length_channels(self):\n        with self.assertRaises(ValueError):\n            SPADEDiffusionModelUNet(\n                spatial_dims=2,\n                label_nc=3,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=(1, 1),\n                channels=(8, 8, 8),\n                attention_levels=(False, False, False),\n                norm_num_groups=8,\n            )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def 
test_shape_conditioned_models(self):\n        net = SPADEDiffusionModelUNet(\n            spatial_dims=2,\n            label_nc=3,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            with_conditioning=True,\n            transformer_num_layers=1,\n            cross_attention_dim=3,\n            norm_num_groups=8,\n            num_head_channels=8,\n        )\n        with eval_mode(net):\n            result = net.forward(\n                x=torch.rand((1, 1, 16, 32)),\n                timesteps=torch.randint(0, 1000, (1,)).long(),\n                seg=torch.rand((1, 3, 16, 32)),\n                context=torch.rand((1, 1, 3)),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 32))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_with_conditioning_cross_attention_dim_none(self):\n        with self.assertRaises(ValueError):\n            SPADEDiffusionModelUNet(\n                spatial_dims=2,\n                label_nc=3,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                channels=(8, 8, 8),\n                attention_levels=(False, False, True),\n                with_conditioning=True,\n                transformer_num_layers=1,\n                cross_attention_dim=None,\n                norm_num_groups=8,\n            )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_context_with_conditioning_none(self):\n        net = SPADEDiffusionModelUNet(\n            spatial_dims=2,\n            label_nc=3,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            with_conditioning=False,\n            transformer_num_layers=1,\n            norm_num_groups=8,\n        )\n\n        with 
self.assertRaises(ValueError):\n            with eval_mode(net):\n                net.forward(\n                    x=torch.rand((1, 1, 16, 32)),\n                    timesteps=torch.randint(0, 1000, (1,)).long(),\n                    seg=torch.rand((1, 3, 16, 32)),\n                    context=torch.rand((1, 1, 3)),\n                )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_conditioned_models_class_conditioning(self):\n        net = SPADEDiffusionModelUNet(\n            spatial_dims=2,\n            label_nc=3,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            norm_num_groups=8,\n            num_head_channels=8,\n            num_class_embeds=2,\n        )\n        with eval_mode(net):\n            result = net.forward(\n                x=torch.rand((1, 1, 16, 32)),\n                timesteps=torch.randint(0, 1000, (1,)).long(),\n                seg=torch.rand((1, 3, 16, 32)),\n                class_labels=torch.randint(0, 2, (1,)).long(),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 32))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_conditioned_models_no_class_labels(self):\n        net = SPADEDiffusionModelUNet(\n            spatial_dims=2,\n            label_nc=3,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            norm_num_groups=8,\n            num_head_channels=8,\n            num_class_embeds=2,\n        )\n\n        with self.assertRaises(ValueError):\n            net.forward(\n                x=torch.rand((1, 1, 16, 32)),\n                timesteps=torch.randint(0, 1000, (1,)).long(),\n                seg=torch.rand((1, 3, 16, 32)),\n            )\n\n    def 
test_model_channels_not_same_size_of_attention_levels(self):\n        with self.assertRaises(ValueError):\n            SPADEDiffusionModelUNet(\n                spatial_dims=2,\n                label_nc=3,\n                in_channels=1,\n                out_channels=1,\n                num_res_blocks=1,\n                channels=(8, 8, 8),\n                attention_levels=(False, False),\n                norm_num_groups=8,\n                num_head_channels=8,\n                num_class_embeds=2,\n            )\n\n    @parameterized.expand(COND_CASES_2D)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_conditioned_2d_models_shape(self, input_param):\n        net = SPADEDiffusionModelUNet(**input_param)\n        with eval_mode(net):\n            result = net.forward(\n                torch.rand((1, 1, 16, 16)),\n                torch.randint(0, 1000, (1,)).long(),\n                torch.rand((1, input_param[\"label_nc\"], 16, 16)),\n                torch.rand((1, 1, 3)),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 16))\n\n\nclass TestDiffusionModelUNet3D(unittest.TestCase):\n    @parameterized.expand(UNCOND_CASES_3D)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_unconditioned_models(self, input_param):\n        net = SPADEDiffusionModelUNet(**input_param)\n        with eval_mode(net):\n            result = net.forward(\n                torch.rand((1, 1, 16, 16, 16)),\n                torch.randint(0, 1000, (1,)).long(),\n                torch.rand((1, input_param[\"label_nc\"], 16, 16, 16)),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 16, 16))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_with_different_in_channel_out_channel(self):\n        in_channels = 6\n        out_channels = 3\n        net = SPADEDiffusionModelUNet(\n            spatial_dims=3,\n            label_nc=3,\n            in_channels=in_channels,\n            
out_channels=out_channels,\n            num_res_blocks=1,\n            channels=(8, 8, 8),\n            attention_levels=(False, False, True),\n            norm_num_groups=4,\n        )\n        with eval_mode(net):\n            result = net.forward(\n                torch.rand((1, in_channels, 16, 16, 16)),\n                torch.randint(0, 1000, (1,)).long(),\n                torch.rand((1, 3, 16, 16, 16)),\n            )\n            self.assertEqual(result.shape, (1, out_channels, 16, 16, 16))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape_conditioned_models(self):\n        net = SPADEDiffusionModelUNet(\n            spatial_dims=3,\n            label_nc=3,\n            in_channels=1,\n            out_channels=1,\n            num_res_blocks=1,\n            channels=(16, 16, 16),\n            attention_levels=(False, False, True),\n            norm_num_groups=16,\n            with_conditioning=True,\n            transformer_num_layers=1,\n            cross_attention_dim=3,\n        )\n        with eval_mode(net):\n            result = net.forward(\n                x=torch.rand((1, 1, 16, 16, 16)),\n                timesteps=torch.randint(0, 1000, (1,)).long(),\n                seg=torch.rand((1, 3, 16, 16, 16)),\n                context=torch.rand((1, 1, 3)),\n            )\n            self.assertEqual(result.shape, (1, 1, 16, 16, 16))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_spade_vaegan.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import SPADENet\n\nCASE_2D = [\n    [[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]],\n    [[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], None, False]],\n]\nCASE_3D = [\n    [[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], 16, True]],\n    [[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], None, False]],\n]\n\n\ndef create_semantic_data(shape: list, semantic_regions: int):\n    \"\"\"\n    To create semantic and image mock inputs for the network.\n    Args:\n        shape: input shape\n        semantic_regions: number of semantic region\n    Returns:\n    \"\"\"\n    out_label = torch.zeros(shape)\n    out_image = torch.zeros(shape) + torch.randn(shape) * 0.01\n    for i in range(1, semantic_regions):\n        shape_square = [i // np.random.choice(list(range(2, i // 2))) for i in shape]\n        start_point = [np.random.choice(list(range(shape[ind] - shape_square[ind]))) for ind, i in enumerate(shape)]\n        if len(shape) == 2:\n            out_label[\n                start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1])\n            ] = i\n            base_intensity = torch.ones(shape_square) * np.random.randn()\n            
out_image[\n                start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1])\n            ] = (base_intensity + torch.randn(shape_square) * 0.1)\n        elif len(shape) == 3:\n            out_label[\n                start_point[0] : (start_point[0] + shape_square[0]),\n                start_point[1] : (start_point[1] + shape_square[1]),\n                start_point[2] : (start_point[2] + shape_square[2]),\n            ] = i\n            base_intensity = torch.ones(shape_square) * np.random.randn()\n            out_image[\n                start_point[0] : (start_point[0] + shape_square[0]),\n                start_point[1] : (start_point[1] + shape_square[1]),\n                start_point[2] : (start_point[2] + shape_square[2]),\n            ] = (base_intensity + torch.randn(shape_square) * 0.1)\n        else:\n            ValueError(\"Supports only 2D and 3D tensors\")\n\n    # One hot encode label\n    out_label_ = torch.zeros([semantic_regions] + list(out_label.shape))\n    for ch in range(semantic_regions):\n        out_label_[ch, ...] 
= out_label == ch\n\n    return out_label_.unsqueeze(0), out_image.unsqueeze(0).unsqueeze(0)\n\n\nclass TestSpadeNet(unittest.TestCase):\n    @parameterized.expand(CASE_2D)\n    def test_forward_2d(self, input_param):\n        \"\"\"\n        Check that forward method is called correctly and output shape matches.\n        \"\"\"\n        net = SPADENet(*input_param)\n        in_label, in_image = create_semantic_data(input_param[4], input_param[3])\n        with eval_mode(net):\n            if not net.is_vae:\n                out = net(in_label, in_image)\n                out = out[0]\n            else:\n                out, z_mu, z_logvar = net(in_label, in_image)\n                self.assertTrue(torch.all(torch.isfinite(z_mu)))\n                self.assertTrue(torch.all(torch.isfinite(z_logvar)))\n\n            self.assertTrue(torch.all(torch.isfinite(out)))\n            self.assertEqual(list(out.shape), [1, 1, 64, 64])\n\n    @parameterized.expand(CASE_2D)\n    def test_encoder_decoder(self, input_param):\n        \"\"\"\n        Check that forward method is called correctly and output shape matches.\n        \"\"\"\n        net = SPADENet(*input_param)\n        in_label, in_image = create_semantic_data(input_param[4], input_param[3])\n        with eval_mode(net):\n            out_z = net.encode(in_image)\n            if net.is_vae:\n                self.assertEqual(list(out_z.shape), [1, 16])\n            else:\n                self.assertEqual(out_z, None)\n            out_i = net.decode(in_label, out_z)\n            self.assertEqual(list(out_i.shape), [1, 1, 64, 64])\n\n    @parameterized.expand(CASE_3D)\n    def test_forward_3d(self, input_param):\n        \"\"\"\n        Check that forward method is called correctly and output shape matches.\n        \"\"\"\n        net = SPADENet(*input_param)\n        in_label, in_image = create_semantic_data(input_param[4], input_param[3])\n        with eval_mode(net):\n            if net.is_vae:\n                out, 
z_mu, z_logvar = net(in_label, in_image)\n                self.assertTrue(torch.all(torch.isfinite(z_mu)))\n                self.assertTrue(torch.all(torch.isfinite(z_logvar)))\n            else:\n                out = net(in_label, in_image)\n                out = out[0]\n            self.assertTrue(torch.all(torch.isfinite(out)))\n            self.assertEqual(list(out.shape), [1, 1, 64, 64, 64])\n\n    def test_shape_wrong(self):\n        \"\"\"\n        We input an input shape that isn't divisible by 2**(n downstream steps)\n        \"\"\"\n        with self.assertRaises(ValueError):\n            _ = SPADENet(1, 1, 8, [16, 16], [16, 32, 64, 128], 16, True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_swin_unetr.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps import download_url\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR, filter_swinunetr\nfrom monai.networks.utils import copy_model_state\nfrom monai.utils import optional_import\nfrom tests.test_utils import (\n    assert_allclose,\n    dict_product,\n    skip_if_downloading_fails,\n    skip_if_no_cuda,\n    skip_if_quick,\n    testing_data_config,\n)\n\neinops, has_einops = optional_import(\"einops\")\n\ntest_merging_mode = [\"mergingv2\", \"merging\", PatchMerging, PatchMergingV2]\ncheckpoint_vals = [True, False]\n\nTEST_CASE_SWIN_UNETR = [\n    [\n        {\n            **{k: v for k, v in params.items() if k != \"img_size\"},\n            \"spatial_dims\": len(params[\"img_size\"]),\n            \"downsample\": test_merging_mode[i % len(test_merging_mode)],\n        },\n        (2, params[\"in_channels\"], *params[\"img_size\"]),\n        (2, params[\"out_channels\"], *params[\"img_size\"]),\n    ]\n    for i, params in enumerate(\n        dict_product(\n            attn_drop_rate=[0.4],\n            depths=[[2, 1, 1, 1], [1, 2, 1, 1]],\n            feature_size=[12],\n            img_size=((64, 32, 192), 
(96, 32)),\n            in_channels=[1],\n            norm_name=[\"instance\"],\n            out_channels=[2],\n            use_checkpoint=checkpoint_vals,\n        )\n    )\n]\n\nTEST_CASE_FILTER = [\n    [\n        {\"in_channels\": 1, \"out_channels\": 14, \"feature_size\": 48, \"use_checkpoint\": True},\n        \"swinViT.layers1.0.blocks.0.norm1.weight\",\n        torch.tensor([0.9473, 0.9343, 0.8566, 0.8487, 0.8065, 0.7779, 0.6333, 0.5555]),\n    ]\n]\n\n\nclass TestSWINUNETR(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_SWIN_UNETR)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = SwinUNETR(**input_param)\n        with eval_mode(net):\n            result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            SwinUNETR(spatial_dims=1, in_channels=1, out_channels=2, feature_size=48, norm_name=\"instance\")\n\n        with self.assertRaises(ValueError):\n            SwinUNETR(in_channels=1, out_channels=4, feature_size=50, norm_name=\"instance\")\n\n        with self.assertRaises(ValueError):\n            SwinUNETR(in_channels=1, out_channels=3, feature_size=24, norm_name=\"instance\", drop_rate=-1)\n\n    def test_patch_merging(self):\n        dim = 10\n        t = PatchMerging(dim)(torch.zeros((1, 21, 20, 20, dim)))\n        self.assertEqual(t.shape, torch.Size([1, 11, 10, 10, 20]))\n\n    @parameterized.expand(TEST_CASE_FILTER)\n    @skip_if_quick\n    @skip_if_no_cuda\n    def test_filter_swinunetr(self, input_param, key, value):\n        with skip_if_downloading_fails():\n            with tempfile.TemporaryDirectory() as tempdir:\n                file_name = \"ssl_pretrained_weights.pth\"\n                data_spec = testing_data_config(\"models\", f\"{file_name.split('.', 1)[0]}\")\n                weight_path = os.path.join(tempdir, 
file_name)\n                download_url(\n                    data_spec[\"url\"], weight_path, hash_val=data_spec[\"hash_val\"], hash_type=data_spec[\"hash_type\"]\n                )\n\n                ssl_weight = torch.load(weight_path, weights_only=True)[\"model\"]\n                net = SwinUNETR(**input_param)\n                dst_dict, loaded, not_loaded = copy_model_state(net, ssl_weight, filter_func=filter_swinunetr)\n                assert_allclose(dst_dict[key][:8], value, atol=1e-4, rtol=1e-4, type_test=False)\n                self.assertTrue(len(loaded) == 157 and len(not_loaded) == 2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_torchvision_fc_model.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import TorchVisionFCModel, UNet\nfrom monai.networks.utils import look_up_named_module, set_named_module\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import skip_if_downloading_fails\n\nInception_V3_Weights, has_enum = optional_import(\"torchvision.models.inception\", name=\"Inception_V3_Weights\")\n\n_, has_tv = optional_import(\"torchvision\", \"0.12\", min_version)\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_0 = [\n    {\"model_name\": \"resnet18\", \"num_classes\": 1, \"use_conv\": True, \"pretrained\": False},\n    (2, 3, 224, 224),\n    (2, 1, 1, 1),\n]\n\nTEST_CASE_1 = [\n    {\"model_name\": \"resnet18\", \"num_classes\": 1, \"use_conv\": True, \"pretrained\": False},\n    (2, 3, 256, 256),\n    (2, 1, 2, 2),\n]\n\nTEST_CASE_2 = [\n    {\"model_name\": \"resnet101\", \"num_classes\": 5, \"use_conv\": True, \"pretrained\": False},\n    (2, 3, 256, 256),\n    (2, 5, 2, 2),\n]\n\nTEST_CASE_3 = [\n    {\n        \"model_name\": \"resnet101\",\n        \"num_classes\": 5,\n        \"use_conv\": True,\n        \"pool\": (\"avg\", {\"kernel_size\": 6, \"stride\": 1}),\n        \"pretrained\": False,\n  
  },\n    (2, 3, 224, 224),\n    (2, 5, 2, 2),\n]\n\nTEST_CASE_4 = [\n    {\"model_name\": \"resnet18\", \"num_classes\": 1, \"use_conv\": False, \"pool\": None, \"pretrained\": False},\n    (2, 3, 224, 224),\n    (2, 1),\n]\n\nTEST_CASE_5 = [\n    {\"model_name\": \"resnet18\", \"num_classes\": 1, \"use_conv\": False, \"pool\": None, \"pretrained\": False},\n    (2, 3, 256, 256),\n    (2, 1),\n]\n\nTEST_CASE_6 = [\n    {\"model_name\": \"resnet101\", \"num_classes\": 5, \"use_conv\": False, \"pool\": None, \"pretrained\": False},\n    (2, 3, 256, 256),\n    (2, 5),\n]\n\nTEST_CASE_7 = [\n    {\n        \"model_name\": \"inception_v3\",\n        \"num_classes\": 5,\n        \"use_conv\": True,\n        \"pool\": \"\",\n        \"in_channels\": 2048,\n        \"node_name\": \"Mixed_7c.cat_2\",\n    },\n    (2, 3, 299, 299),\n    (2, 5, 8, 8),\n]\n\nTEST_CASE_8 = [\n    {\"model_name\": \"vit_b_16\", \"num_classes\": 5, \"in_channels\": 768, \"pool\": None, \"fc_name\": \"heads.head\"},\n    (2, 3, 224, 224),\n    (2, 5),\n]\n\nTEST_CASE_PRETRAINED_0 = [\n    {\"model_name\": \"resnet18\", \"num_classes\": 1, \"use_conv\": True, \"pretrained\": True},\n    (2, 3, 224, 224),\n    (2, 1, 1, 1),\n    -0.010419349186122417,\n]\n\nTEST_CASE_PRETRAINED_1 = [\n    {\"model_name\": \"resnet18\", \"num_classes\": 1, \"use_conv\": True, \"pretrained\": True},\n    (2, 3, 256, 256),\n    (2, 1, 2, 2),\n    -0.010419349186122417,\n]\n\nTEST_CASE_PRETRAINED_2 = [\n    {\"model_name\": \"resnet18\", \"num_classes\": 5, \"use_conv\": True, \"pretrained\": True},\n    (2, 3, 256, 256),\n    (2, 5, 2, 2),\n    -0.010419349186122417,\n]\n\nTEST_CASE_PRETRAINED_3 = [\n    {\"model_name\": \"resnet18\", \"num_classes\": 1, \"use_conv\": False, \"pool\": None, \"pretrained\": True},\n    (2, 3, 224, 224),\n    (2, 1),\n    -0.010419349186122417,\n]\n\nTEST_CASE_PRETRAINED_4 = [\n    {\"model_name\": \"resnet18\", \"num_classes\": 1, \"use_conv\": False, \"pool\": None, \"pretrained\": 
True},\n    (2, 3, 256, 256),\n    (2, 1),\n    -0.010419349186122417,\n]\n\nTEST_CASE_PRETRAINED_5 = [\n    {\"model_name\": \"resnet18\", \"num_classes\": 5, \"use_conv\": False, \"pool\": None, \"pretrained\": True},\n    (2, 3, 256, 256),\n    (2, 5),\n    -0.010419349186122417,\n]\n\nTEST_CASE_PRETRAINED_6 = [\n    {\n        \"model_name\": \"inception_v3\",\n        \"num_classes\": 5,\n        \"use_conv\": False,\n        \"pool\": None,\n        \"weights\": Inception_V3_Weights.IMAGENET1K_V1 if has_enum else None,\n    },\n    (2, 3, 299, 299),\n    (2, 5),\n    -0.21029122173786163,\n]\n\n\nclass TestTorchVisionFCModel(unittest.TestCase):\n    @parameterized.expand(\n        [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]\n        + ([TEST_CASE_8] if has_enum else [])\n    )\n    @skipUnless(has_tv, \"Requires TorchVision.\")\n    def test_without_pretrained(self, input_param, input_shape, expected_shape):\n        net = TorchVisionFCModel(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_PRETRAINED_0,\n            TEST_CASE_PRETRAINED_1,\n            TEST_CASE_PRETRAINED_2,\n            TEST_CASE_PRETRAINED_3,\n            TEST_CASE_PRETRAINED_4,\n            TEST_CASE_PRETRAINED_5,\n        ]\n        + ([TEST_CASE_PRETRAINED_6] if has_enum else [])\n    )\n    @skipUnless(has_tv, \"Requires TorchVision.\")\n    def test_with_pretrained(self, input_param, input_shape, expected_shape, expected_value):\n        with skip_if_downloading_fails():\n            net = TorchVisionFCModel(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            value = next(net.features.parameters())[0, 0, 0, 0].item()\n            
self.assertEqual(value, expected_value)\n            self.assertEqual(result.shape, expected_shape)\n\n\nclass TestLookup(unittest.TestCase):\n    def test_get_module(self):\n        net = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32, 64), strides=(2, 2, 2, 2))\n        self.assertEqual(look_up_named_module(\"\", net), net)\n        mod = look_up_named_module(\"model.1.submodule.1.submodule.1.submodule.0.conv\", net)\n        self.assertTrue(str(mod).startswith(\"Conv2d\"))\n        self.assertIsInstance(set_named_module(net, \"model\", torch.nn.Identity()).model, torch.nn.Identity)\n        self.assertIsNone(look_up_named_module(\"model.1.submodule.1.submodule.1.submodule.conv\", net))\n        self.assertIsNone(look_up_named_module(\"test attribute\", net))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_transchex.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.transchex import Transchex\nfrom tests.test_utils import dict_product, skip_if_downloading_fails, skip_if_quick\n\nTEST_CASE_TRANSCHEX = [\n    [\n        {\n            **{k: v for k, v in params.items() if k != \"img_size\"},\n            \"img_size\": (params[\"img_size\"],) * 2,\n            \"patch_size\": (params[\"patch_size\"],) * 2,\n        },\n        (2, params[\"num_classes\"]),\n    ]\n    for params in dict_product(\n        drop_out=[0.4],\n        img_size=[224],\n        in_channels=[3],\n        num_classes=[8],\n        num_language_layers=[2],\n        num_mixed_layers=[3],\n        num_vision_layers=[4],\n        patch_size=[16, 32],\n    )\n]\n\n\n@skip_if_quick\nclass TestTranschex(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_TRANSCHEX)\n    def test_shape(self, input_param, expected_shape):\n        with skip_if_downloading_fails():\n            net = Transchex(**input_param)\n        with eval_mode(net):\n            result = net(torch.randint(2, (2, 512)), torch.randint(2, (2, 512)), torch.randn((2, 3, 224, 224)))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            
Transchex(\n                in_channels=3,\n                img_size=(128, 128),\n                patch_size=(16, 16),\n                num_language_layers=2,\n                num_mixed_layers=4,\n                num_vision_layers=2,\n                num_classes=2,\n                drop_out=5.0,\n            )\n\n        with self.assertRaises(ValueError):\n            Transchex(\n                in_channels=1,\n                img_size=(97, 97),\n                patch_size=(16, 16),\n                num_language_layers=6,\n                num_mixed_layers=6,\n                num_vision_layers=8,\n                num_classes=8,\n                drop_out=0.4,\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_transformer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps import download_url\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import DecoderOnlyTransformer\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_downloading_fails, testing_data_config\n\n_, has_einops = optional_import(\"einops\")\nTEST_CASES = []\nfor dropout_rate in np.linspace(0, 1, 2):\n    for attention_layer_dim in [360, 480, 600, 768]:\n        for num_heads in [4, 6, 8, 12]:\n            TEST_CASES.append(\n                [\n                    {\n                        \"num_tokens\": 10,\n                        \"max_seq_len\": 16,\n                        \"attn_layers_dim\": attention_layer_dim,\n                        \"attn_layers_depth\": 2,\n                        \"attn_layers_heads\": num_heads,\n                        \"embedding_dropout_rate\": dropout_rate,\n                    }\n                ]\n            )\n\n\nclass TestDecoderOnlyTransformer(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_unconditioned_models(self, input_param):\n        net = DecoderOnlyTransformer(**input_param)\n        with 
eval_mode(net):\n            net.forward(torch.randint(0, 10, (1, 16)))\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_conditioned_models(self, input_param):\n        net = DecoderOnlyTransformer(**input_param, with_cross_attention=True)\n        with eval_mode(net):\n            net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 3, input_param[\"attn_layers_dim\"]))\n\n    def test_attention_dim_not_multiple_of_heads(self):\n        with self.assertRaises(ValueError):\n            DecoderOnlyTransformer(\n                num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=3\n            )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_dropout_rate_negative(self):\n        with self.assertRaises(ValueError):\n            DecoderOnlyTransformer(\n                num_tokens=10,\n                max_seq_len=16,\n                attn_layers_dim=8,\n                attn_layers_depth=2,\n                attn_layers_heads=2,\n                embedding_dropout_rate=-1,\n            )\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_compatibility_with_monai_generative(self):\n        # test loading weights from a model saved in MONAI Generative, version 0.2.3\n        with skip_if_downloading_fails():\n            net = DecoderOnlyTransformer(\n                num_tokens=10,\n                max_seq_len=16,\n                attn_layers_dim=8,\n                attn_layers_depth=2,\n                attn_layers_heads=2,\n                with_cross_attention=True,\n                embedding_dropout_rate=0,\n            )\n\n            tmpdir = tempfile.mkdtemp()\n            key = \"decoder_only_transformer_monai_generative_weights\"\n            url = testing_data_config(\"models\", key, \"url\")\n            hash_type = testing_data_config(\"models\", key, \"hash_type\")\n            hash_val = testing_data_config(\"models\", key, 
\"hash_val\")\n            filename = \"decoder_only_transformer_monai_generative_weights.pt\"\n            weight_path = os.path.join(tmpdir, filename)\n            download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)\n\n            net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_unet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.layers import Act, Norm\nfrom monai.networks.nets import UNet\nfrom tests.test_utils import test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_0 = [  # single channel 2D, batch 16, no residual\n    {\n        \"spatial_dims\": 2,\n        \"in_channels\": 1,\n        \"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 0,\n    },\n    (16, 1, 32, 32),\n    (16, 3, 32, 32),\n]\n\nTEST_CASE_1 = [  # single channel 2D, batch 16\n    {\n        \"spatial_dims\": 2,\n        \"in_channels\": 1,\n        \"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 1,\n    },\n    (16, 1, 32, 32),\n    (16, 3, 32, 32),\n]\n\nTEST_CASE_2 = [  # single channel 3D, batch 16\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 1,\n        \"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 1,\n    },\n    (16, 1, 32, 24, 48),\n    (16, 3, 32, 24, 48),\n]\n\nTEST_CASE_3 = [  # 4-channel 3D, batch 16\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 4,\n        
\"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 1,\n    },\n    (16, 4, 32, 64, 48),\n    (16, 3, 32, 64, 48),\n]\n\nTEST_CASE_4 = [  # 4-channel 3D, batch 16, batch normalization\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 4,\n        \"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 1,\n        \"norm\": Norm.BATCH,\n    },\n    (16, 4, 32, 64, 48),\n    (16, 3, 32, 64, 48),\n]\n\nTEST_CASE_5 = [  # 4-channel 3D, batch 16, LeakyReLU activation\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 4,\n        \"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 1,\n        \"act\": (Act.LEAKYRELU, {\"negative_slope\": 0.2}),\n        \"adn_ordering\": \"NA\",\n    },\n    (16, 4, 32, 64, 48),\n    (16, 3, 32, 64, 48),\n]\n\nTEST_CASE_6 = [  # 4-channel 3D, batch 16, LeakyReLU activation explicit\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 4,\n        \"out_channels\": 3,\n        \"channels\": (16, 32, 64),\n        \"strides\": (2, 2),\n        \"num_res_units\": 1,\n        \"act\": (torch.nn.LeakyReLU, {\"negative_slope\": 0.2}),\n    },\n    (16, 4, 32, 64, 48),\n    (16, 3, 32, 64, 48),\n]\n\nCASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]\n\nILL_CASES = [\n    [\n        {  # len(channels) < 2\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 3,\n            \"channels\": (16,),\n            \"strides\": (2, 2),\n            \"num_res_units\": 0,\n        }\n    ],\n    [\n        {  # len(strides) < len(channels) - 1\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 3,\n            \"channels\": (8, 8, 8),\n            \"strides\": (2,),\n            \"num_res_units\": 
0,\n        }\n    ],\n    [\n        {  # len(kernel_size) = 3, spatial_dims = 2\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 3,\n            \"channels\": (8, 8, 8),\n            \"strides\": (2, 2),\n            \"kernel_size\": (3, 3, 3),\n        }\n    ],\n    [\n        {  # len(up_kernel_size) = 2, spatial_dims = 3\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 3,\n            \"channels\": (8, 8, 8),\n            \"strides\": (2, 2),\n            \"up_kernel_size\": (3, 3),\n        }\n    ],\n]\n\n\nclass TestUNET(unittest.TestCase):\n    @parameterized.expand(CASES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = UNet(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        net = UNet(\n            spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0\n        )\n        test_data = torch.randn(16, 1, 32, 32)\n        test_script_save(net, test_data)\n\n    def test_script_without_running_stats(self):\n        net = UNet(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=3,\n            channels=(16, 32, 64),\n            strides=(2, 2),\n            num_res_units=0,\n            norm=(\"batch\", {\"track_running_stats\": False}),\n        )\n        test_data = torch.randn(16, 1, 16, 4)\n        test_script_save(net, test_data)\n\n    def test_ill_input_shape(self):\n        net = UNet(spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2))\n        with eval_mode(net):\n            with self.assertRaisesRegex(RuntimeError, \"Sizes of tensors must match\"):\n                net.forward(torch.randn(2, 1, 16, 5))\n\n    
@parameterized.expand(ILL_CASES)\n    def test_ill_input_hyper_params(self, input_param):\n        with self.assertRaises(ValueError):\n            _ = UNet(**input_param)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_unetr.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.unetr import UNETR\nfrom tests.test_utils import dict_product, skip_if_quick, test_script_save\n\nTEST_CASE_UNETR = [\n    [\n        {\n            **{k: v for k, v in params.items() if k not in [\"img_size\", \"nd\"]},\n            \"conv_block\": True,\n            \"res_block\": False,\n            \"img_size\": (params[\"img_size\"],) * params[\"nd\"],\n            **({\"spatial_dims\": 2} if params[\"nd\"] == 2 else {}),\n        },\n        (2, params[\"in_channels\"], *([params[\"img_size\"]] * params[\"nd\"])),\n        (2, params[\"out_channels\"], *([params[\"img_size\"]] * params[\"nd\"])),\n    ]\n    for params in dict_product(\n        dropout_rate=[0.4],\n        feature_size=[16],\n        hidden_size=[768],\n        img_size=[96, 128],\n        in_channels=[1],\n        mlp_dim=[3072],\n        nd=[2, 3],\n        norm_name=[\"instance\"],\n        num_heads=[8],\n        out_channels=[2],\n        proj_type=[\"perceptron\"],\n    )\n]\n\n\n@skip_if_quick\nclass TestUNETR(unittest.TestCase):\n    @parameterized.expand(TEST_CASE_UNETR)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = UNETR(**input_param)\n        with eval_mode(net):\n            
result = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_ill_arg(self):\n        with self.assertRaises(ValueError):\n            UNETR(\n                in_channels=1,\n                out_channels=3,\n                img_size=(128, 128, 128),\n                feature_size=16,\n                hidden_size=128,\n                mlp_dim=3072,\n                num_heads=12,\n                proj_type=\"conv\",\n                norm_name=\"instance\",\n                dropout_rate=5.0,\n            )\n\n        with self.assertRaises(ValueError):\n            UNETR(\n                in_channels=1,\n                out_channels=4,\n                img_size=(32, 32, 32),\n                feature_size=32,\n                hidden_size=512,\n                mlp_dim=3072,\n                num_heads=12,\n                proj_type=\"conv\",\n                norm_name=\"instance\",\n                dropout_rate=0.5,\n            )\n\n        with self.assertRaises(ValueError):\n            UNETR(\n                in_channels=1,\n                out_channels=3,\n                img_size=(96, 96, 96),\n                feature_size=16,\n                hidden_size=512,\n                mlp_dim=3072,\n                num_heads=14,\n                proj_type=\"conv\",\n                norm_name=\"batch\",\n                dropout_rate=0.4,\n            )\n\n        with self.assertRaises(ValueError):\n            UNETR(\n                in_channels=1,\n                out_channels=4,\n                img_size=(96, 96, 96),\n                feature_size=8,\n                hidden_size=768,\n                mlp_dim=3072,\n                num_heads=12,\n                proj_type=\"perc\",\n                norm_name=\"instance\",\n                dropout_rate=0.2,\n            )\n\n    @parameterized.expand(TEST_CASE_UNETR)\n    def test_script(self, input_param, input_shape, _):\n        net = UNETR(**(input_param))\n        
net.eval()\n        with torch.no_grad():\n            torch.jit.script(net)\n\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_varautoencoder.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.layers import Act\nfrom monai.networks.nets import VarAutoEncoder\nfrom tests.test_utils import test_script_save\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\nTEST_CASE_0 = [  # single channel 2D, batch 4, no residual\n    {\n        \"spatial_dims\": 2,\n        \"in_shape\": (1, 128, 128),\n        \"out_channels\": 1,\n        \"latent_size\": 2,\n        \"channels\": (4, 8, 16),\n        \"strides\": (2, 2, 2),\n        \"num_res_units\": 0,\n    },\n    (1, 1, 128, 128),\n    (1, 1, 128, 128),\n]\n\nTEST_CASE_1 = [  # single channel 2D, batch 4\n    {\n        \"spatial_dims\": 2,\n        \"in_shape\": (1, 128, 128),\n        \"out_channels\": 1,\n        \"latent_size\": 2,\n        \"channels\": (4, 8, 16),\n        \"strides\": (2, 2, 2),\n    },\n    (1, 1, 128, 128),\n    (1, 1, 128, 128),\n]\n\nTEST_CASE_2 = [  # 3-channel 2D, batch 4, LeakyReLU activation\n    {\n        \"spatial_dims\": 2,\n        \"in_shape\": (3, 128, 128),\n        \"out_channels\": 3,\n        \"latent_size\": 2,\n        \"channels\": (4, 8, 16),\n        \"strides\": (2, 2, 2),\n        \"act\": (Act.LEAKYRELU, {\"negative_slope\": 0.2}),\n    },\n    (1, 3, 128, 128),\n    (1, 
3, 128, 128),\n]\n\nTEST_CASE_3 = [  # 4-channel 3D, batch 4\n    {\n        \"spatial_dims\": 3,\n        \"in_shape\": (4, 128, 128, 128),\n        \"out_channels\": 3,\n        \"latent_size\": 2,\n        \"channels\": (4, 8, 16),\n        \"strides\": (2, 2, 2),\n    },\n    (1, 4, 128, 128, 128),\n    (1, 3, 128, 128, 128),\n]\n\nTEST_CASE_4 = [  # 4-channel 1D, batch 4\n    {\n        \"spatial_dims\": 1,\n        \"in_shape\": (4, 128),\n        \"out_channels\": 3,\n        \"latent_size\": 2,\n        \"channels\": (4, 8, 16),\n        \"strides\": (2, 2, 2),\n    },\n    (1, 4, 128),\n    (1, 3, 128),\n]\n\nTEST_CASE_5 = [  # 4-channel 1D, batch 4, use_sigmoid = False\n    {\n        \"spatial_dims\": 1,\n        \"in_shape\": (4, 128),\n        \"out_channels\": 3,\n        \"latent_size\": 2,\n        \"channels\": (4, 8, 16),\n        \"strides\": (2, 2, 2),\n        \"use_sigmoid\": False,\n    },\n    (1, 4, 128),\n    (1, 3, 128),\n]\n\nCASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]\n\n\nclass TestVarAutoEncoder(unittest.TestCase):\n    @parameterized.expand(CASES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = VarAutoEncoder(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))[0]\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        net = VarAutoEncoder(\n            spatial_dims=2, in_shape=(1, 32, 32), out_channels=1, latent_size=2, channels=(4, 8), strides=(2, 2)\n        )\n        test_data = torch.randn(2, 1, 32, 32)\n        test_script_save(net, test_data, rtol=1e-3, atol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_vista3d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import VISTA3D, SegResNetDS2\nfrom monai.networks.nets.vista3d import ClassMappingClassify, PointMappingSAM\nfrom tests.test_utils import skip_if_quick\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASES = [\n    [{\"encoder_embed_dim\": 48, \"in_channels\": 1}, {}, (1, 1, 64, 64, 64), (1, 1, 64, 64, 64)],\n    [{\"encoder_embed_dim\": 48, \"in_channels\": 2}, {}, (1, 2, 64, 64, 64), (1, 1, 64, 64, 64)],\n    [\n        {\"encoder_embed_dim\": 48, \"in_channels\": 1},\n        {\"class_vector\": torch.tensor([1, 2, 3], device=device)},\n        (1, 1, 64, 64, 64),\n        (3, 1, 64, 64, 64),\n    ],\n    [\n        {\"encoder_embed_dim\": 48, \"in_channels\": 1},\n        {\n            \"point_coords\": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device),\n            \"point_labels\": torch.tensor([[1, 0]], device=device),\n        },\n        (1, 1, 64, 64, 64),\n        (1, 1, 64, 64, 64),\n    ],\n    [\n        {\"encoder_embed_dim\": 48, \"in_channels\": 1},\n        {\n            \"class_vector\": torch.tensor([1, 2], device=device),\n            \"point_coords\": torch.tensor([[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]], device=device),\n            
\"point_labels\": torch.tensor([[1, 0], [1, 0]], device=device),\n        },\n        (1, 1, 64, 64, 64),\n        (2, 1, 64, 64, 64),\n    ],\n]\n\n\n@skip_if_quick\nclass TestVista3d(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_vista3d_shape(self, args, input_params, input_shape, expected_shape):\n        segresnet = SegResNetDS2(\n            in_channels=args[\"in_channels\"],\n            blocks_down=(1, 2, 2, 4, 4),\n            norm=\"instance\",\n            out_channels=args[\"encoder_embed_dim\"],\n            init_filters=args[\"encoder_embed_dim\"],\n            dsdepth=1,\n        )\n        point_head = PointMappingSAM(feature_size=args[\"encoder_embed_dim\"], n_classes=512, last_supported=132)\n        class_head = ClassMappingClassify(n_classes=512, feature_size=args[\"encoder_embed_dim\"], use_mlp=True)\n        net = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head).to(device)\n        with eval_mode(net):\n            result = net.forward(\n                torch.randn(input_shape).to(device),\n                point_coords=input_params.get(\"point_coords\", None),\n                point_labels=input_params.get(\"point_labels\", None),\n                class_vector=input_params.get(\"class_vector\", None),\n            )\n            self.assertEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_vit.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.vit import ViT\nfrom tests.test_utils import dict_product, skip_if_quick, test_script_save\n\nTEST_CASE_Vit = [\n    (\n        [\n            {\n                **{k: v for k, v in params.items() if k not in [\"nd\"]},\n                **({\"spatial_dims\": 2} if params[\"nd\"] == 2 else {}),\n                **({\"post_activation\": False} if params[\"nd\"] == 2 and params[\"classification\"] else {}),\n            },\n            (2, params[\"in_channels\"], *([params[\"img_size\"]] * params[\"nd\"])),\n            (\n                (2, params[\"num_classes\"])\n                if params[\"classification\"]\n                else (2, (params[\"img_size\"] // params[\"patch_size\"]) ** params[\"nd\"], params[\"hidden_size\"])\n            ),\n        ]\n    )\n    for params in dict_product(\n        dropout_rate=[0.6],\n        in_channels=[4],\n        hidden_size=[768],\n        img_size=[96, 128],\n        patch_size=[16],\n        num_heads=[12],\n        mlp_dim=[3072],\n        num_layers=[4],\n        num_classes=[8],\n        proj_type=[\"conv\", \"perceptron\"],\n        classification=[False, True],\n        nd=[2, 3],\n    )\n]\n\n\n@skip_if_quick\nclass TestViT(unittest.TestCase):\n  
  @parameterized.expand(TEST_CASE_Vit)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = ViT(**input_param)\n        with eval_mode(net):\n            result, _ = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(\n        [\n            (1, (128, 128, 128), (16, 16, 16), 128, 3072, 12, 12, \"conv\", False, 5.0),\n            (1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, \"perceptron\", False, 0.3),\n            (1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, \"conv\", False, 0.3),\n            (1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, \"perceptron\", True, 0.3),\n            (4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, \"perc\", False, 0.3),\n        ]\n    )\n    def test_ill_arg(\n        self,\n        in_channels,\n        img_size,\n        patch_size,\n        hidden_size,\n        mlp_dim,\n        num_layers,\n        num_heads,\n        proj_type,\n        classification,\n        dropout_rate,\n    ):\n        with self.assertRaises(ValueError):\n            ViT(\n                in_channels=in_channels,\n                img_size=img_size,\n                patch_size=patch_size,\n                hidden_size=hidden_size,\n                mlp_dim=mlp_dim,\n                num_layers=num_layers,\n                num_heads=num_heads,\n                proj_type=proj_type,\n                classification=classification,\n                dropout_rate=dropout_rate,\n            )\n\n    @parameterized.expand(TEST_CASE_Vit[:1])\n    def test_script(self, input_param, input_shape, _):\n        net = ViT(**(input_param))\n        net.eval()\n        with torch.no_grad():\n            torch.jit.script(net)\n\n        test_data = torch.randn(input_shape)\n        test_script_save(net, test_data)\n\n    def test_access_attn_matrix(self):\n        # input format\n        in_channels = 1\n        img_size = (96, 96, 96)\n        patch_size = (16, 16, 
16)\n        in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])\n\n        # no data in the matrix\n        no_matrix_acess_blk = ViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size)\n        no_matrix_acess_blk(torch.randn(in_shape))\n        assert isinstance(no_matrix_acess_blk.blocks[0].attn.att_mat, torch.Tensor)\n        # no of elements is zero\n        assert no_matrix_acess_blk.blocks[0].attn.att_mat.nelement() == 0\n\n        # be able to acess the attention matrix\n        matrix_acess_blk = ViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size, save_attn=True)\n        matrix_acess_blk(torch.randn(in_shape))\n        assert matrix_acess_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 216, 216)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_vitautoenc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.vitautoenc import ViTAutoEnc\nfrom tests.test_utils import dict_product, skip_if_quick, skip_if_windows\n\nTEST_CASE_Vitautoenc = [\n    [\n        {\n            \"in_channels\": params[\"in_channels\"],\n            \"img_size\": (params[\"img_size\"],) * params[\"nd\"],\n            \"patch_size\": (params[\"patch_size\"],) * params[\"nd\"],\n            \"hidden_size\": 768,\n            \"mlp_dim\": 3072,\n            \"num_layers\": 4,\n            \"num_heads\": 12,\n            \"proj_type\": params[\"proj_type\"],\n            \"dropout_rate\": 0.6,\n            \"spatial_dims\": params[\"nd\"],\n        },\n        (2, params[\"in_channels\"], *([params[\"img_size\"]] * params[\"nd\"])),\n        (2, 1, *([params[\"img_size\"]] * params[\"nd\"])),\n    ]\n    for params in dict_product(\n        in_channels=[1, 4], img_size=[64, 96, 128], patch_size=[16], proj_type=[\"conv\", \"perceptron\"], nd=[2, 3]\n    )\n]\n\nTEST_CASE_Vitautoenc.append(\n    [\n        {\n            \"in_channels\": 1,\n            \"img_size\": (512, 512, 32),\n            \"patch_size\": (64, 64, 16),\n            \"hidden_size\": 768,\n            \"mlp_dim\": 3072,\n            \"num_layers\": 4,\n            
\"num_heads\": 12,\n            \"proj_type\": \"conv\",\n            \"dropout_rate\": 0.6,\n            \"spatial_dims\": 3,\n        },\n        (2, 1, 512, 512, 32),\n        (2, 1, 512, 512, 32),\n    ]\n)\n\n\n@skip_if_quick\nclass TestVitAutoenc(unittest.TestCase):\n    def setUp(self):\n        self.threads = torch.get_num_threads()\n        torch.set_num_threads(4)\n\n    def tearDown(self):\n        torch.set_num_threads(self.threads)\n\n    @parameterized.expand(TEST_CASE_Vitautoenc)\n    @skip_if_windows\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = ViTAutoEnc(**input_param)\n        with eval_mode(net):\n            result, _ = net(torch.randn(input_shape))\n            self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(\n        [\n            (1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, \"perceptron\", 0.3),  # img_size_too_large_for_patch_size\n            (1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, \"conv\", 0.3),  # num_heads_out_of_bound\n            (1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, \"perceptron\", 0.3),  # img_size_not_divisible_by_patch_size\n            (4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, \"perc\", 0.3),  # invalid_pos_embed\n            (4, (96, 96, 96), (9, 9, 9), 768, 3072, 12, 12, \"perc\", 0.3),  # patch_size_not_divisible\n            # Add more test cases as needed\n        ]\n    )\n    def test_ill_arg(\n        self, in_channels, img_size, patch_size, hidden_size, mlp_dim, num_layers, num_heads, proj_type, dropout_rate\n    ):\n        with self.assertRaises(ValueError):\n            ViTAutoEnc(\n                in_channels=in_channels,\n                img_size=img_size,\n                patch_size=patch_size,\n                hidden_size=hidden_size,\n                mlp_dim=mlp_dim,\n                num_layers=num_layers,\n                num_heads=num_heads,\n                proj_type=proj_type,\n                
dropout_rate=dropout_rate,\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_vnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import VNet\nfrom tests.test_utils import test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_VNET_2D_1 = [\n    {\"spatial_dims\": 2, \"in_channels\": 4, \"out_channels\": 1, \"act\": \"elu\", \"dropout_dim\": 1},\n    (1, 4, 32, 32),\n    (1, 1, 32, 32),\n]\nTEST_CASE_VNET_2D_2 = [\n    {\"spatial_dims\": 2, \"in_channels\": 2, \"out_channels\": 2, \"act\": \"prelu\", \"dropout_dim\": 2},\n    (1, 2, 32, 32),\n    (1, 2, 32, 32),\n]\nTEST_CASE_VNET_2D_3 = [\n    {\"spatial_dims\": 2, \"in_channels\": 1, \"out_channels\": 3, \"dropout_dim\": 3},\n    (1, 1, 32, 32),\n    (1, 3, 32, 32),\n]\nTEST_CASE_VNET_3D_1 = [\n    {\"spatial_dims\": 3, \"in_channels\": 4, \"out_channels\": 1, \"act\": \"elu\", \"dropout_dim\": 1},\n    (1, 4, 32, 32, 32),\n    (1, 1, 32, 32, 32),\n]\nTEST_CASE_VNET_3D_2 = [\n    {\"spatial_dims\": 3, \"in_channels\": 2, \"out_channels\": 2, \"act\": \"prelu\", \"dropout_dim\": 2},\n    (1, 2, 32, 32, 32),\n    (1, 2, 32, 32, 32),\n]\nTEST_CASE_VNET_3D_3 = [\n    {\"spatial_dims\": 3, \"in_channels\": 1, \"out_channels\": 3, \"dropout_dim\": 3},\n    (1, 1, 32, 32, 32),\n    (1, 3, 32, 32, 32),\n]\n\n\nclass 
TestVNet(unittest.TestCase):\n    @parameterized.expand(\n        [\n            TEST_CASE_VNET_2D_1,\n            TEST_CASE_VNET_2D_2,\n            TEST_CASE_VNET_2D_3,\n            TEST_CASE_VNET_3D_1,\n            TEST_CASE_VNET_3D_2,\n            TEST_CASE_VNET_3D_3,\n        ]\n    )\n    def test_vnet_shape(self, input_param, input_shape, expected_shape):\n        net = VNet(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        net = VNet(spatial_dims=3, in_channels=1, out_channels=3, dropout_dim=3)\n        test_data = torch.randn(1, 1, 32, 32, 32)\n        test_script_save(net, test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_voxelmorph.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets import VoxelMorph, VoxelMorphUNet\nfrom tests.test_utils import test_script_save\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_0 = [  # single channel 3D, batch 1,\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32, 32),\n        \"final_conv_channels\": (16, 16),\n    },\n    (1, 2, 96, 96, 48),\n    (1, 3, 96, 96, 48),\n]\n\nTEST_CASE_1 = [  # single channel 3D, batch 1,\n    # using strided convolution for downsampling instead of maxpooling\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32, 32),\n        \"final_conv_channels\": (16, 16),\n        \"use_maxpool\": False,\n    },\n    (1, 2, 96, 96, 48),\n    (1, 3, 96, 96, 48),\n]\n\nTEST_CASE_2 = [  # single channel 3D, batch 1,\n    # using strided convolution for downsampling instead of maxpooling,\n    # explicitly specify leakyrelu with a different negative slope for final convolutions\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 
32, 32),\n        \"final_conv_channels\": (16, 16),\n        \"final_conv_act\": (\"leakyrelu\", {\"negative_slope\": 0.1, \"inplace\": True}),\n        \"use_maxpool\": False,\n    },\n    (1, 2, 96, 96, 48),\n    (1, 3, 96, 96, 48),\n]\n\nTEST_CASE_3 = [  # single channel 3D, batch 1,\n    # using strided convolution for downsampling instead of maxpooling,\n    # explicitly specify leakyrelu with a different negative slope for both unet and final convolutions.\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32, 32),\n        \"final_conv_channels\": (16, 16),\n        \"final_conv_act\": (\"leakyrelu\", {\"negative_slope\": 0.1, \"inplace\": True}),\n        \"act\": (\"leakyrelu\", {\"negative_slope\": 0.1, \"inplace\": True}),\n        \"use_maxpool\": False,\n    },\n    (1, 2, 96, 96, 48),\n    (1, 3, 96, 96, 48),\n]\n\nTEST_CASE_4 = [  # 2-channel 3D, batch 1,\n    # i.e., possible use case where the input contains both modalities (e.g., T1 and T2)\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 4,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32, 32),\n        \"final_conv_channels\": (16, 16),\n    },\n    (1, 4, 96, 96, 48),\n    (1, 3, 96, 96, 48),\n]\n\nTEST_CASE_5 = [  # single channel 3D, batch 2,\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32, 32),\n        \"final_conv_channels\": (16, 16),\n    },\n    (2, 2, 96, 96, 48),\n    (2, 3, 96, 96, 48),\n]\n\nTEST_CASE_6 = [  # single channel 2D, batch 2,\n    {\n        \"spatial_dims\": 2,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32, 32),\n        \"final_conv_channels\": (16, 16),\n    },\n    (2, 2, 96, 96),\n    (2, 2, 96, 96),\n]\n\nTEST_CASE_7 = [  # single channel 3D, batch 1,\n    # one additional 
level in the UNet with 32 channels in both down and up branch.\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32, 32, 32, 32),\n        \"final_conv_channels\": (16, 16),\n    },\n    (1, 2, 96, 96, 48),\n    (1, 3, 96, 96, 48),\n]\n\nTEST_CASE_8 = [  # single channel 3D, batch 1,\n    # one additional level in the UNet with 32 channels in both down and up branch.\n    # and removed one of the two final convolution blocks.\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32, 32, 32, 32),\n        \"final_conv_channels\": (16,),\n    },\n    (1, 2, 96, 96, 48),\n    (1, 3, 96, 96, 48),\n]\n\nTEST_CASE_9 = [  # single channel 3D, batch 1,\n    # only one level in the UNet\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32),\n        \"final_conv_channels\": (16, 16),\n    },\n    (1, 2, 96, 96, 48),\n    (1, 3, 96, 96, 48),\n]\n\nCASES = [\n    TEST_CASE_0,\n    TEST_CASE_1,\n    TEST_CASE_2,\n    TEST_CASE_3,\n    TEST_CASE_4,\n    TEST_CASE_5,\n    TEST_CASE_6,\n    TEST_CASE_7,\n    TEST_CASE_8,\n    TEST_CASE_9,\n]\n\nILL_CASE_0 = [  # spatial_dims = 1\n    {\n        \"spatial_dims\": 1,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32, 32),\n        \"final_conv_channels\": (16, 16),\n    }\n]\n\nILL_CASE_1 = [  # in_channels = 3 (not divisible by 2)\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 3,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32, 32),\n        \"final_conv_channels\": (16, 16),\n    }\n]\n\nILL_CASE_2 = [  # len(channels) = 0\n    {\"spatial_dims\": 3, \"in_channels\": 2, \"unet_out_channels\": 32, \"channels\": (), \"final_conv_channels\": (16, 16)}\n]\n\nILL_CASE_3 = [  # 
channels not in pairs\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32),\n        \"final_conv_channels\": (16, 16),\n    }\n]\n\nILL_CASE_4 = [  # len(kernel_size) = 3, spatial_dims = 2\n    {\n        \"spatial_dims\": 2,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32, 32),\n        \"final_conv_channels\": (16, 16),\n        \"kernel_size\": (3, 3, 3),\n    }\n]\n\nILL_CASE_5 = [  # len(up_kernel_size) = 2, spatial_dims = 3\n    {\n        \"spatial_dims\": 3,\n        \"in_channels\": 2,\n        \"unet_out_channels\": 32,\n        \"channels\": (16, 32, 32, 32, 32, 32),\n        \"final_conv_channels\": (16, 16),\n        \"up_kernel_size\": (3, 3),\n    }\n]\n\nILL_CASES = [ILL_CASE_0, ILL_CASE_1, ILL_CASE_2, ILL_CASE_3, ILL_CASE_4, ILL_CASE_5]\n\nILL_CASES_IN_SHAPE_0 = [  # moving and fixed image shape not match\n    {\"spatial_dims\": 3},\n    (1, 2, 96, 96, 48),\n    (1, 3, 96, 96, 48),\n]\n\nILL_CASES_IN_SHAPE_1 = [  # spatial_dims = 2, ddf has 3 channels\n    {\"spatial_dims\": 2},\n    (1, 1, 96, 96, 96),\n    (1, 1, 96, 96, 96),\n]\n\nILL_CASES_IN_SHAPE = [ILL_CASES_IN_SHAPE_0, ILL_CASES_IN_SHAPE_1]\n\n\nclass TestVOXELMORPH(unittest.TestCase):\n    @parameterized.expand(CASES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        net = VoxelMorphUNet(**input_param).to(device)\n        with eval_mode(net):\n            result = net.forward(torch.randn(input_shape).to(device))\n            self.assertEqual(result.shape, expected_shape)\n\n    def test_script(self):\n        net = VoxelMorphUNet(\n            spatial_dims=2,\n            in_channels=2,\n            unet_out_channels=32,\n            channels=(16, 32, 32, 32, 32, 32),\n            final_conv_channels=(16, 16),\n        )\n        test_data = torch.randn(1, 2, 96, 96)\n        test_script_save(net, 
test_data)\n\n    @parameterized.expand(ILL_CASES)\n    def test_ill_input_hyper_params(self, input_param):\n        with self.assertRaises(ValueError):\n            _ = VoxelMorphUNet(**input_param)\n\n    @parameterized.expand(ILL_CASES_IN_SHAPE)\n    def test_ill_input_shape(self, input_param, moving_shape, fixed_shape):\n        with self.assertRaises((ValueError, RuntimeError)):\n            net = VoxelMorph(**input_param).to(device)\n            with eval_mode(net):\n                _ = net.forward(torch.randn(moving_shape).to(device), torch.randn(fixed_shape).to(device))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_vqvae.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.networks.nets.vqvae import VQVAE\n\nTEST_CASES = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"num_res_layers\": 1,\n            \"num_res_channels\": (4, 4),\n            \"downsample_parameters\": ((2, 4, 1, 1),) * 2,\n            \"upsample_parameters\": ((2, 4, 1, 1, 0),) * 2,\n            \"num_embeddings\": 8,\n            \"embedding_dim\": 8,\n        },\n        (1, 1, 8, 8),\n        (1, 1, 8, 8),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"num_res_layers\": 1,\n            \"num_res_channels\": 4,\n            \"downsample_parameters\": ((2, 4, 1, 1),) * 2,\n            \"upsample_parameters\": ((2, 4, 1, 1, 0),) * 2,\n            \"num_embeddings\": 8,\n            \"embedding_dim\": 8,\n        },\n        (1, 1, 8, 8, 8),\n        (1, 1, 8, 8, 8),\n    ],\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"num_res_layers\": 1,\n   
         \"num_res_channels\": (4, 4),\n            \"downsample_parameters\": (2, 4, 1, 1),\n            \"upsample_parameters\": ((2, 4, 1, 1, 0),) * 2,\n            \"num_embeddings\": 8,\n            \"embedding_dim\": 8,\n        },\n        (1, 1, 8, 8),\n        (1, 1, 8, 8),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (4, 4),\n            \"num_res_layers\": 1,\n            \"num_res_channels\": (4, 4),\n            \"downsample_parameters\": ((2, 4, 1, 1),) * 2,\n            \"upsample_parameters\": (2, 4, 1, 1, 0),\n            \"num_embeddings\": 8,\n            \"embedding_dim\": 8,\n        },\n        (1, 1, 8, 8, 8),\n        (1, 1, 8, 8, 8),\n    ],\n]\n\nTEST_LATENT_SHAPE = {\n    \"spatial_dims\": 2,\n    \"in_channels\": 1,\n    \"out_channels\": 1,\n    \"downsample_parameters\": ((2, 4, 1, 1),) * 2,\n    \"upsample_parameters\": ((2, 4, 1, 1, 0),) * 2,\n    \"num_res_layers\": 1,\n    \"channels\": (8, 8),\n    \"num_res_channels\": (8, 8),\n    \"num_embeddings\": 16,\n    \"embedding_dim\": 8,\n}\n\n\nclass TestVQVAE(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        net = VQVAE(**input_param).to(device)\n\n        with eval_mode(net):\n            result, _ = net(torch.randn(input_shape).to(device))\n\n        self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(TEST_CASES)\n    def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        input_param = input_param.copy()\n        input_param.update({\"use_checkpointing\": True})\n\n        net = VQVAE(**input_param).to(device)\n\n        with eval_mode(net):\n            result, _ = 
net(torch.randn(input_shape).to(device))\n\n        self.assertEqual(result.shape, expected_shape)\n\n    # Removed this test case since TorchScript currently does not support activation checkpoint.\n    # def test_script(self):\n    #     net = VQVAE(\n    #         spatial_dims=2,\n    #         in_channels=1,\n    #         out_channels=1,\n    #         downsample_parameters=((2, 4, 1, 1),) * 2,\n    #         upsample_parameters=((2, 4, 1, 1, 0),) * 2,\n    #         num_res_layers=1,\n    #         channels=(8, 8),\n    #         num_res_channels=(8, 8),\n    #         num_embeddings=16,\n    #         embedding_dim=8,\n    #         ddp_sync=False,\n    #     )\n    #     test_data = torch.randn(1, 1, 16, 16)\n    #     test_script_save(net, test_data)\n\n    def test_channels_not_same_size_of_num_res_channels(self):\n        with self.assertRaises(ValueError):\n            VQVAE(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                channels=(16, 16),\n                num_res_channels=(16, 16, 16),\n                downsample_parameters=((2, 4, 1, 1),) * 2,\n                upsample_parameters=((2, 4, 1, 1, 0),) * 2,\n            )\n\n    def test_channels_not_same_size_of_downsample_parameters(self):\n        with self.assertRaises(ValueError):\n            VQVAE(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                channels=(16, 16),\n                num_res_channels=(16, 16),\n                downsample_parameters=((2, 4, 1, 1),) * 3,\n                upsample_parameters=((2, 4, 1, 1, 0),) * 2,\n            )\n\n    def test_channels_not_same_size_of_upsample_parameters(self):\n        with self.assertRaises(ValueError):\n            VQVAE(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                channels=(16, 16),\n                num_res_channels=(16, 16),\n                
downsample_parameters=((2, 4, 1, 1),) * 2,\n                upsample_parameters=((2, 4, 1, 1, 0),) * 3,\n            )\n\n    def test_downsample_parameters_not_sequence_or_int(self):\n        with self.assertRaises(ValueError):\n            VQVAE(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                channels=(16, 16),\n                num_res_channels=(16, 16),\n                downsample_parameters=((\"test\", 4, 1, 1),) * 2,\n                upsample_parameters=((2, 4, 1, 1, 0),) * 2,\n            )\n\n    def test_upsample_parameters_not_sequence_or_int(self):\n        with self.assertRaises(ValueError):\n            VQVAE(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                channels=(16, 16),\n                num_res_channels=(16, 16),\n                downsample_parameters=((2, 4, 1, 1),) * 2,\n                upsample_parameters=((\"test\", 4, 1, 1, 0),) * 2,\n            )\n\n    def test_downsample_parameter_length_different_4(self):\n        with self.assertRaises(ValueError):\n            VQVAE(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                channels=(16, 16),\n                num_res_channels=(16, 16),\n                downsample_parameters=((2, 4, 1),) * 3,\n                upsample_parameters=((2, 4, 1, 1, 0),) * 2,\n            )\n\n    def test_upsample_parameter_length_different_5(self):\n        with self.assertRaises(ValueError):\n            VQVAE(\n                spatial_dims=2,\n                in_channels=1,\n                out_channels=1,\n                channels=(16, 16),\n                num_res_channels=(16, 16, 16),\n                downsample_parameters=((2, 4, 1, 1),) * 2,\n                upsample_parameters=((2, 4, 1, 1, 0, 1),) * 3,\n            )\n\n    def test_encode_shape(self):\n        device = \"cuda\" if torch.cuda.is_available() 
else \"cpu\"\n\n        net = VQVAE(**TEST_LATENT_SHAPE).to(device)\n\n        with eval_mode(net):\n            latent = net.encode(torch.randn(1, 1, 32, 32).to(device))\n\n        self.assertEqual(latent.shape, (1, 8, 8, 8))\n\n    def test_index_quantize_shape(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        net = VQVAE(**TEST_LATENT_SHAPE).to(device)\n\n        with eval_mode(net):\n            latent = net.index_quantize(torch.randn(1, 1, 32, 32).to(device))\n\n        self.assertEqual(latent.shape, (1, 8, 8))\n\n    def test_decode_shape(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        net = VQVAE(**TEST_LATENT_SHAPE).to(device)\n\n        with eval_mode(net):\n            latent = net.decode(torch.randn(1, 8, 8, 8).to(device))\n\n        self.assertEqual(latent.shape, (1, 1, 32, 32))\n\n    def test_decode_samples_shape(self):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        net = VQVAE(**TEST_LATENT_SHAPE).to(device)\n\n        with eval_mode(net):\n            latent = net.decode_samples(torch.randint(low=0, high=16, size=(1, 8, 8)).to(device))\n\n        self.assertEqual(latent.shape, (1, 1, 32, 32))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/nets/test_vqvaetransformer_inferer.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.inferers import VQVAETransformerInferer\nfrom monai.networks.nets import VQVAE, DecoderOnlyTransformer\nfrom monai.utils import optional_import\nfrom monai.utils.ordering import Ordering, OrderingType\n\neinops, has_einops = optional_import(\"einops\")\nTEST_CASES = [\n    [\n        {\n            \"spatial_dims\": 2,\n            \"in_channels\": 1,\n            \"out_channels\": 1,\n            \"channels\": (8, 8),\n            \"num_res_channels\": (8, 8),\n            \"downsample_parameters\": ((2, 4, 1, 1),) * 2,\n            \"upsample_parameters\": ((2, 4, 1, 1, 0),) * 2,\n            \"num_res_layers\": 1,\n            \"num_embeddings\": 16,\n            \"embedding_dim\": 8,\n        },\n        {\n            \"num_tokens\": 16 + 1,\n            \"max_seq_len\": 4,\n            \"attn_layers_dim\": 4,\n            \"attn_layers_depth\": 2,\n            \"attn_layers_heads\": 1,\n            \"with_cross_attention\": False,\n        },\n        {\"ordering_type\": OrderingType.RASTER_SCAN.value, \"spatial_dims\": 2, \"dimensions\": (2, 2, 2)},\n        (2, 1, 8, 8),\n        (2, 4, 17),\n        (2, 2, 2),\n    ],\n    [\n        {\n            \"spatial_dims\": 3,\n            \"in_channels\": 1,\n      
      \"out_channels\": 1,\n            \"channels\": (8, 8),\n            \"num_res_channels\": (8, 8),\n            \"downsample_parameters\": ((2, 4, 1, 1),) * 2,\n            \"upsample_parameters\": ((2, 4, 1, 1, 0),) * 2,\n            \"num_res_layers\": 1,\n            \"num_embeddings\": 16,\n            \"embedding_dim\": 8,\n        },\n        {\n            \"num_tokens\": 16 + 1,\n            \"max_seq_len\": 8,\n            \"attn_layers_dim\": 4,\n            \"attn_layers_depth\": 2,\n            \"attn_layers_heads\": 1,\n            \"with_cross_attention\": False,\n        },\n        {\"ordering_type\": OrderingType.RASTER_SCAN.value, \"spatial_dims\": 3, \"dimensions\": (2, 2, 2, 2)},\n        (2, 1, 8, 8, 8),\n        (2, 8, 17),\n        (2, 2, 2, 2),\n    ],\n]\n\n\nclass TestVQVAETransformerInferer(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_prediction_shape(\n        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape\n    ):\n        stage_1 = VQVAE(**stage_1_params)\n        stage_2 = DecoderOnlyTransformer(**stage_2_params)\n        ordering = Ordering(**ordering_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        input = torch.randn(input_shape).to(device)\n\n        inferer = VQVAETransformerInferer()\n        prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)\n        self.assertEqual(prediction.shape, logits_shape)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_prediction_shape_shorter_sequence(\n        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape\n    ):\n        stage_1 = VQVAE(**stage_1_params)\n        
max_seq_len = 3\n        stage_2_params_shorter = dict(stage_2_params)\n        stage_2_params_shorter[\"max_seq_len\"] = max_seq_len\n        stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)\n        ordering = Ordering(**ordering_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        input = torch.randn(input_shape).to(device)\n\n        inferer = VQVAETransformerInferer()\n        prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)\n        cropped_logits_shape = (logits_shape[0], max_seq_len, logits_shape[2])\n        self.assertEqual(prediction.shape, cropped_logits_shape)\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample(self):\n\n        stage_1 = VQVAE(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            channels=(8, 8),\n            num_res_channels=(8, 8),\n            downsample_parameters=((2, 4, 1, 1),) * 2,\n            upsample_parameters=((2, 4, 1, 1, 0),) * 2,\n            num_res_layers=1,\n            num_embeddings=16,\n            embedding_dim=8,\n        )\n        stage_2 = DecoderOnlyTransformer(\n            num_tokens=16 + 1,\n            max_seq_len=4,\n            attn_layers_dim=4,\n            attn_layers_depth=2,\n            attn_layers_heads=1,\n            with_cross_attention=False,\n        )\n        ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        inferer = VQVAETransformerInferer()\n\n        starting_token = 16  # from stage_1 num_embeddings\n\n        sample = inferer.sample(\n            latent_spatial_dim=(2, 2),\n    
        starting_tokens=starting_token * torch.ones((2, 1), device=device),\n            vqvae_model=stage_1,\n            transformer_model=stage_2,\n            ordering=ordering,\n        )\n        self.assertEqual(sample.shape, (2, 1, 8, 8))\n\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_sample_shorter_sequence(self):\n        stage_1 = VQVAE(\n            spatial_dims=2,\n            in_channels=1,\n            out_channels=1,\n            channels=(8, 8),\n            num_res_channels=(8, 8),\n            downsample_parameters=((2, 4, 1, 1),) * 2,\n            upsample_parameters=((2, 4, 1, 1, 0),) * 2,\n            num_res_layers=1,\n            num_embeddings=16,\n            embedding_dim=8,\n        )\n        stage_2 = DecoderOnlyTransformer(\n            num_tokens=16 + 1,\n            max_seq_len=2,\n            attn_layers_dim=4,\n            attn_layers_depth=2,\n            attn_layers_heads=1,\n            with_cross_attention=False,\n        )\n        ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        inferer = VQVAETransformerInferer()\n\n        starting_token = 16  # from stage_1 num_embeddings\n\n        sample = inferer.sample(\n            latent_spatial_dim=(2, 2),\n            starting_tokens=starting_token * torch.ones((2, 1), device=device),\n            vqvae_model=stage_1,\n            transformer_model=stage_2,\n            ordering=ordering,\n        )\n        self.assertEqual(sample.shape, (2, 1, 8, 8))\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_get_likelihood(\n        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape\n    ):\n        stage_1 = 
VQVAE(**stage_1_params)\n        stage_2 = DecoderOnlyTransformer(**stage_2_params)\n        ordering = Ordering(**ordering_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        input = torch.randn(input_shape).to(device)\n\n        inferer = VQVAETransformerInferer()\n        likelihood = inferer.get_likelihood(\n            inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering\n        )\n        self.assertEqual(likelihood.shape, latent_shape)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_get_likelihood_shorter_sequence(\n        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape\n    ):\n        stage_1 = VQVAE(**stage_1_params)\n        max_seq_len = 3\n        stage_2_params_shorter = dict(stage_2_params)\n        stage_2_params_shorter[\"max_seq_len\"] = max_seq_len\n        stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)\n        ordering = Ordering(**ordering_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        input = torch.randn(input_shape).to(device)\n\n        inferer = VQVAETransformerInferer()\n        likelihood = inferer.get_likelihood(\n            inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering\n        )\n        self.assertEqual(likelihood.shape, latent_shape)\n\n    @parameterized.expand(TEST_CASES)\n    @skipUnless(has_einops, \"Requires einops\")\n    def test_get_likelihood_resampling(\n        self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape\n    ):\n        stage_1 = VQVAE(**stage_1_params)\n        stage_2 = 
DecoderOnlyTransformer(**stage_2_params)\n        ordering = Ordering(**ordering_params)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        stage_1.to(device)\n        stage_2.to(device)\n        stage_1.eval()\n        stage_2.eval()\n\n        input = torch.randn(input_shape).to(device)\n\n        inferer = VQVAETransformerInferer()\n        likelihood = inferer.get_likelihood(\n            inputs=input,\n            vqvae_model=stage_1,\n            transformer_model=stage_2,\n            ordering=ordering,\n            resample_latent_likelihoods=True,\n            resample_interpolation_mode=\"nearest\",\n        )\n        self.assertEqual(likelihood.shape, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/schedulers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/networks/schedulers/test_scheduler_ddim.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.schedulers import DDIMScheduler\nfrom tests.test_utils import assert_allclose\n\nTEST_2D_CASE = []\nfor beta_schedule in [\"linear_beta\", \"scaled_linear_beta\"]:\n    TEST_2D_CASE.append([{\"schedule\": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)])\n\nTEST_3D_CASE = []\nfor beta_schedule in [\"linear_beta\", \"scaled_linear_beta\"]:\n    TEST_3D_CASE.append([{\"schedule\": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)])\n\nTEST_CASES = TEST_2D_CASE + TEST_3D_CASE\n\nTEST_FULl_LOOP = [\n    [{\"schedule\": \"linear_beta\"}, (1, 1, 2, 2), torch.Tensor([[[[-0.9579, -0.6457], [0.4684, -0.9694]]]])]\n]\n\n\nclass TestDDPMScheduler(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_add_noise(self, input_param, input_shape, expected_shape):\n        scheduler = DDIMScheduler(**input_param)\n        scheduler.set_timesteps(num_inference_steps=100)\n        original_sample = torch.zeros(input_shape)\n        noise = torch.randn_like(original_sample)\n        timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()\n\n        noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)\n        self.assertEqual(noisy.shape, 
expected_shape)\n\n    @parameterized.expand(TEST_CASES)\n    def test_step_shape(self, input_param, input_shape, expected_shape):\n        scheduler = DDIMScheduler(**input_param)\n        scheduler.set_timesteps(num_inference_steps=100)\n        model_output = torch.randn(input_shape)\n        sample = torch.randn(input_shape)\n        output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)\n        self.assertEqual(output_step[0].shape, expected_shape)\n        self.assertEqual(output_step[1].shape, expected_shape)\n\n    @parameterized.expand(TEST_FULl_LOOP)\n    def test_full_timestep_loop(self, input_param, input_shape, expected_output):\n        scheduler = DDIMScheduler(**input_param)\n        scheduler.set_timesteps(50)\n        torch.manual_seed(42)\n        model_output = torch.randn(input_shape)\n        sample = torch.randn(input_shape)\n        for t in range(50):\n            sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)\n        assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3)\n\n    def test_set_timesteps(self):\n        scheduler = DDIMScheduler(num_train_timesteps=1000)\n        scheduler.set_timesteps(num_inference_steps=100)\n        self.assertEqual(scheduler.num_inference_steps, 100)\n        self.assertEqual(len(scheduler.timesteps), 100)\n\n    def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):\n        scheduler = DDIMScheduler(num_train_timesteps=1000)\n        with self.assertRaises(ValueError):\n            scheduler.set_timesteps(num_inference_steps=2000)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/schedulers/test_scheduler_ddpm.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.schedulers import DDPMScheduler\nfrom tests.test_utils import assert_allclose\n\nTEST_2D_CASE = []\nfor beta_schedule in [\"linear_beta\", \"scaled_linear_beta\"]:\n    for variance_type in [\"fixed_small\", \"fixed_large\"]:\n        TEST_2D_CASE.append(\n            [{\"schedule\": beta_schedule, \"variance_type\": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)]\n        )\n\nTEST_3D_CASE = []\nfor beta_schedule in [\"linear_beta\", \"scaled_linear_beta\"]:\n    for variance_type in [\"fixed_small\", \"fixed_large\"]:\n        TEST_3D_CASE.append(\n            [{\"schedule\": beta_schedule, \"variance_type\": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]\n        )\n\nTEST_CASES = TEST_2D_CASE + TEST_3D_CASE\n\nTEST_FULl_LOOP = [\n    [{\"schedule\": \"linear_beta\"}, (1, 1, 2, 2), torch.Tensor([[[[-1.0153, -0.3218], [0.8454, -0.7870]]]])]\n]\n\n\nclass TestDDPMScheduler(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_add_noise(self, input_param, input_shape, expected_shape):\n        scheduler = DDPMScheduler(**input_param)\n        original_sample = torch.zeros(input_shape)\n        noise = torch.randn_like(original_sample)\n        timesteps = torch.randint(0, scheduler.num_train_timesteps, 
(original_sample.shape[0],)).long()\n\n        noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)\n        self.assertEqual(noisy.shape, expected_shape)\n\n    @parameterized.expand(TEST_CASES)\n    def test_step_shape(self, input_param, input_shape, expected_shape):\n        scheduler = DDPMScheduler(**input_param)\n        model_output = torch.randn(input_shape)\n        sample = torch.randn(input_shape)\n        output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)\n        self.assertEqual(output_step[0].shape, expected_shape)\n        self.assertEqual(output_step[1].shape, expected_shape)\n\n    @parameterized.expand(TEST_FULl_LOOP)\n    def test_full_timestep_loop(self, input_param, input_shape, expected_output):\n        scheduler = DDPMScheduler(**input_param)\n        scheduler.set_timesteps(50)\n        torch.manual_seed(42)\n        model_output = torch.randn(input_shape)\n        sample = torch.randn(input_shape)\n        for t in range(50):\n            sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)\n        assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3)\n\n    @parameterized.expand(TEST_CASES)\n    def test_get_velocity_shape(self, input_param, input_shape, expected_shape):\n        scheduler = DDPMScheduler(**input_param)\n        sample = torch.randn(input_shape)\n        timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long()\n        velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps)\n        self.assertEqual(velocity.shape, expected_shape)\n\n    def test_step_learned(self):\n        for variance_type in [\"learned\", \"learned_range\"]:\n            scheduler = DDPMScheduler(variance_type=variance_type)\n        model_output = torch.randn(2, 6, 16, 16)\n        sample = torch.randn(2, 3, 16, 16)\n        output_step = scheduler.step(model_output=model_output, 
timestep=500, sample=sample)\n        self.assertEqual(output_step[0].shape, sample.shape)\n        self.assertEqual(output_step[1].shape, sample.shape)\n\n    def test_set_timesteps(self):\n        scheduler = DDPMScheduler(num_train_timesteps=1000)\n        scheduler.set_timesteps(num_inference_steps=100)\n        self.assertEqual(scheduler.num_inference_steps, 100)\n        self.assertEqual(len(scheduler.timesteps), 100)\n\n    def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):\n        scheduler = DDPMScheduler(num_train_timesteps=1000)\n        with self.assertRaises(ValueError):\n            scheduler.set_timesteps(num_inference_steps=2000)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/schedulers/test_scheduler_pndm.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.schedulers import PNDMScheduler\nfrom tests.test_utils import assert_allclose\n\nTEST_2D_CASE = []\nfor beta_schedule in [\"linear_beta\", \"scaled_linear_beta\"]:\n    TEST_2D_CASE.append([{\"schedule\": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)])\n\nTEST_3D_CASE = []\nfor beta_schedule in [\"linear_beta\", \"scaled_linear_beta\"]:\n    TEST_3D_CASE.append([{\"schedule\": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)])\n\nTEST_CASES = TEST_2D_CASE + TEST_3D_CASE\n\nTEST_FULl_LOOP = [\n    [\n        {\"schedule\": \"linear_beta\"},\n        (1, 1, 2, 2),\n        torch.Tensor([[[[-2123055.2500, -459014.2812], [2863438.0000, -1263401.7500]]]]),\n    ]\n]\n\n\nclass TestDDPMScheduler(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_add_noise(self, input_param, input_shape, expected_shape):\n        scheduler = PNDMScheduler(**input_param)\n        original_sample = torch.zeros(input_shape)\n        noise = torch.randn_like(original_sample)\n        timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()\n        noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)\n        self.assertEqual(noisy.shape, expected_shape)\n\n  
  @parameterized.expand(TEST_CASES)\n    def test_step_shape(self, input_param, input_shape, expected_shape):\n        scheduler = PNDMScheduler(**input_param)\n        scheduler.set_timesteps(600)\n        model_output = torch.randn(input_shape)\n        sample = torch.randn(input_shape)\n        output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)\n        self.assertEqual(output_step[0].shape, expected_shape)\n        self.assertEqual(output_step[1], None)\n\n    @parameterized.expand(TEST_FULl_LOOP)\n    def test_full_timestep_loop(self, input_param, input_shape, expected_output):\n        scheduler = PNDMScheduler(**input_param)\n        scheduler.set_timesteps(50)\n        torch.manual_seed(42)\n        model_output = torch.randn(input_shape)\n        sample = torch.randn(input_shape)\n        for t in range(50):\n            sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)\n        assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3)\n\n    @parameterized.expand(TEST_FULl_LOOP)\n    def test_timestep_two_loops(self, input_param, input_shape, expected_output):\n        scheduler = PNDMScheduler(**input_param)\n        scheduler.set_timesteps(50)\n        torch.manual_seed(42)\n        model_output = torch.randn(input_shape)\n        sample = torch.randn(input_shape)\n        for t in range(50):\n            sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)\n        torch.manual_seed(42)\n        model_output2 = torch.randn(input_shape)\n        sample2 = torch.randn(input_shape)\n        scheduler.set_timesteps(50)\n        for t in range(50):\n            sample2, _ = scheduler.step(model_output=model_output2, timestep=t, sample=sample2)\n        assert_allclose(sample, sample2, rtol=1e-3, atol=1e-3)\n\n    def test_set_timesteps(self):\n        scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True)\n        
scheduler.set_timesteps(num_inference_steps=100)\n        self.assertEqual(scheduler.num_inference_steps, 100)\n        self.assertEqual(len(scheduler.timesteps), 100)\n\n    def test_set_timesteps_prk(self):\n        scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=False)\n        scheduler.set_timesteps(num_inference_steps=100)\n        self.assertEqual(scheduler.num_inference_steps, 109)\n        self.assertEqual(len(scheduler.timesteps), 109)\n\n    def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):\n        scheduler = PNDMScheduler(num_train_timesteps=1000)\n        with self.assertRaises(ValueError):\n            scheduler.set_timesteps(num_inference_steps=2000)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/schedulers/test_scheduler_rflow.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.schedulers import RFlowScheduler\nfrom tests.test_utils import assert_allclose\n\nTEST_2D_CASE = []\nfor sample_method in [\"uniform\", \"logit-normal\"]:\n    TEST_2D_CASE.append(\n        [{\"sample_method\": sample_method, \"use_timestep_transform\": False}, (2, 6, 16, 16), (2, 6, 16, 16)]\n    )\n\nfor sample_method in [\"uniform\", \"logit-normal\"]:\n    TEST_2D_CASE.append(\n        [\n            {\"sample_method\": sample_method, \"use_timestep_transform\": True, \"spatial_dim\": 2},\n            (2, 6, 16, 16),\n            (2, 6, 16, 16),\n        ]\n    )\n\n\nTEST_3D_CASE = []\nfor sample_method in [\"uniform\", \"logit-normal\"]:\n    TEST_3D_CASE.append(\n        [{\"sample_method\": sample_method, \"use_timestep_transform\": False}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]\n    )\n\nfor sample_method in [\"uniform\", \"logit-normal\"]:\n    TEST_3D_CASE.append(\n        [\n            {\"sample_method\": sample_method, \"use_timestep_transform\": True, \"spatial_dim\": 3},\n            (2, 6, 16, 16, 16),\n            (2, 6, 16, 16, 16),\n        ]\n    )\n\nTEST_CASES = TEST_2D_CASE + TEST_3D_CASE\n\nTEST_FULl_LOOP = [\n    [{\"sample_method\": \"uniform\"}, (1, 1, 2, 2), torch.Tensor([[[[-0.786166, -0.057519], [2.442662, 
-0.407664]]]])]\n]\n\n\nclass TestRFlowScheduler(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_add_noise(self, input_param, input_shape, expected_shape):\n        scheduler = RFlowScheduler(**input_param)\n        original_sample = torch.zeros(input_shape)\n        timesteps = scheduler.sample_timesteps(original_sample)\n        noise = torch.randn_like(original_sample)\n        timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()\n        noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)\n        self.assertEqual(noisy.shape, expected_shape)\n\n    @parameterized.expand(TEST_CASES)\n    def test_step_shape(self, input_param, input_shape, expected_shape):\n        scheduler = RFlowScheduler(**input_param)\n        model_output = torch.randn(input_shape)\n        sample = torch.randn(input_shape)\n        scheduler.set_timesteps(num_inference_steps=100, input_img_size_numel=torch.numel(sample[0, 0, ...]))\n        output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)\n        self.assertEqual(output_step[0].shape, expected_shape)\n        self.assertEqual(output_step[1].shape, expected_shape)\n\n    @parameterized.expand(TEST_FULl_LOOP)\n    def test_full_timestep_loop(self, input_param, input_shape, expected_output):\n        scheduler = RFlowScheduler(**input_param)\n        torch.manual_seed(42)\n        model_output = torch.randn(input_shape)\n        sample = torch.randn(input_shape)\n        scheduler.set_timesteps(50, input_img_size_numel=torch.numel(sample[0, 0, ...]))\n        for t in range(50):\n            sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)\n        assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3)\n\n    def test_set_timesteps(self):\n        scheduler = RFlowScheduler(num_train_timesteps=1000)\n        scheduler.set_timesteps(num_inference_steps=100, 
input_img_size_numel=16 * 16 * 16)\n        self.assertEqual(scheduler.num_inference_steps, 100)\n        self.assertEqual(len(scheduler.timesteps), 100)\n\n    def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):\n        scheduler = RFlowScheduler(num_train_timesteps=1000)\n        with self.assertRaises(ValueError):\n            scheduler.set_timesteps(num_inference_steps=2000, input_img_size_numel=16 * 16 * 16)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/test_bundle_onnx_export.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nfrom parameterized import parameterized\n\nfrom monai.bundle import ConfigParser\nfrom monai.networks import save_state\nfrom tests.test_utils import SkipIfNoModule, command_line_tests, skip_if_windows\n\nTEST_CASE_1 = [\"True\"]\nTEST_CASE_2 = [\"False\"]\n\n\n@skip_if_windows\n@SkipIfNoModule(\"onnx\")\nclass TestONNXExport(unittest.TestCase):\n    def setUp(self):\n        self.device = os.environ.get(\"CUDA_VISIBLE_DEVICES\")\n        if not self.device:\n            os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"  # default\n\n    def tearDown(self):\n        if self.device is not None:\n            os.environ[\"CUDA_VISIBLE_DEVICES\"] = self.device\n        else:\n            del os.environ[\"CUDA_VISIBLE_DEVICES\"]  # previously unset\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_onnx_export(self, use_trace):\n        tests_path = Path(__file__).parents[1]\n        meta_file = os.path.join(tests_path, \"testing_data\", \"metadata.json\")\n        config_file = os.path.join(tests_path, \"testing_data\", \"inference.json\")\n        with tempfile.TemporaryDirectory() as tempdir:\n            def_args = {\"meta_file\": \"will be replaced by `meta_file` arg\"}\n            def_args_file = os.path.join(tempdir, \"def_args.yaml\")\n\n       
     ckpt_file = os.path.join(tempdir, \"model.pt\")\n            onnx_file = os.path.join(tempdir, \"model.onnx\")\n\n            parser = ConfigParser()\n            parser.export_config_file(config=def_args, filepath=def_args_file)\n            parser.read_config(config_file)\n            net = parser.get_parsed_content(\"network_def\")\n            save_state(src=net, path=ckpt_file)\n\n            cmd = [\"python\", \"-m\", \"monai.bundle\", \"onnx_export\", \"network_def\", \"--filepath\", onnx_file]\n            cmd += [\"--meta_file\", meta_file, \"--config_file\", f\"['{config_file}','{def_args_file}']\"]\n            cmd += [\"--ckpt_file\", ckpt_file, \"--args_file\", def_args_file, \"--input_shape\", \"[1, 1, 96, 96, 96]\"]\n            cmd += [\"--use_trace\", use_trace]\n            command_line_tests(cmd)\n            self.assertTrue(os.path.exists(onnx_file))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/test_convert_to_onnx.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport itertools\nimport platform\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import convert_to_onnx\nfrom monai.networks.nets import SegResNet, UNet\nfrom tests.test_utils import SkipIfNoModule, optional_import, skip_if_quick\n\nonnx, _ = optional_import(\"onnx\")\n\nTORCH_DEVICE_OPTIONS = [\"cpu\"]\n\n# FIXME: CUDA seems to produce different model outputs during testing vs. 
ONNX outputs, use CPU only for now\n# if torch.cuda.is_available():\n#     TORCH_DEVICE_OPTIONS.append(\"cuda\")\n\nTESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False], [True, False]))\nTESTS_ORT = list(itertools.product(TORCH_DEVICE_OPTIONS, [True]))\n\nON_AARCH64 = platform.machine() == \"aarch64\"\nif ON_AARCH64:\n    rtol, atol = 1e-1, 1e-2\nelse:\n    rtol, atol = 1e-2, 1e-2\n\n\n@SkipIfNoModule(\"onnx\")\n@skip_if_quick\nclass TestConvertToOnnx(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_unet(self, device, use_trace, use_ort):\n        \"\"\"Test converting UNet to ONNX.\"\"\"\n        if use_ort:\n            _, has_onnxruntime = optional_import(\"onnxruntime\")\n            if not has_onnxruntime:\n                self.skipTest(\"onnxruntime is not installed probably due to python version >= 3.11.\")\n        model = UNet(\n            spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0\n        )\n\n        onnx_model = convert_to_onnx(\n            model=model,\n            inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],\n            input_names=[\"x\"],\n            output_names=[\"y\"],\n            verify=True,\n            device=device,\n            use_ort=use_ort,\n            use_trace=use_trace,\n            rtol=rtol,\n            atol=atol,\n        )\n        self.assertTrue(isinstance(onnx_model, onnx.ModelProto))\n\n    @parameterized.expand(TESTS_ORT)\n    def test_seg_res_net(self, device, use_ort):\n        \"\"\"Test converting SegResNet to ONNX.\"\"\"\n        if use_ort:\n            _, has_onnxruntime = optional_import(\"onnxruntime\")\n            if not has_onnxruntime:\n                self.skipTest(\"onnxruntime is not installed probably due to python version >= 3.11.\")\n        model = SegResNet(\n            spatial_dims=3,\n            init_filters=32,\n            in_channels=1,\n            out_channels=105,\n            
dropout_prob=0.2,\n            act=(\"RELU\", {\"inplace\": True}),\n            norm=(\"GROUP\", {\"num_groups\": 8}),\n            norm_name=\"\",\n            num_groups=8,\n            use_conv_final=True,\n            blocks_down=[1, 2, 2, 4],\n            blocks_up=[1, 1, 1],\n        )\n        onnx_model = convert_to_onnx(\n            model=model,\n            inputs=[torch.randn((1, 1, 24, 24, 24), requires_grad=False)],\n            input_names=[\"x\"],\n            output_names=[\"y\"],\n            verify=True,\n            device=device,\n            use_ort=use_ort,\n            use_trace=True,\n            rtol=rtol,\n            atol=atol,\n        )\n        self.assertTrue(isinstance(onnx_model, onnx.ModelProto))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/test_convert_to_torchscript.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport torch\n\nfrom monai.networks import convert_to_torchscript\nfrom monai.networks.nets import UNet\n\n\nclass TestConvertToTorchScript(unittest.TestCase):\n\n    def test_value(self):\n        model = UNet(\n            spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0\n        )\n        with tempfile.TemporaryDirectory() as tempdir:\n            torchscript_model = convert_to_torchscript(\n                model=model,\n                filename_or_obj=os.path.join(tempdir, \"model.ts\"),\n                extra_files={\"foo.txt\": b\"bar\"},\n                verify=True,\n                inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],\n                device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n                rtol=1e-3,\n                atol=1e-4,\n                optimize=None,\n            )\n            self.assertTrue(isinstance(torchscript_model, torch.nn.Module))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/test_convert_to_trt.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport tempfile\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import convert_to_trt\nfrom monai.networks.nets import UNet\nfrom monai.utils import optional_import\nfrom tests.test_utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows\n\n_, has_torchtrt = optional_import(\n    \"torch_tensorrt\",\n    version=\"1.4.0\",\n    descriptor=\"Torch-TRT is not installed. Are you sure you have a Torch-TensorRT compilation?\",\n)\n_, has_tensorrt = optional_import(\n    \"tensorrt\", descriptor=\"TensorRT is not installed. 
Are you sure you have a TensorRT compilation?\"\n)\n\nTEST_CASE_1 = [\"fp32\"]\nTEST_CASE_2 = [\"fp16\"]\n\n\n@skip_if_windows\n@skip_if_no_cuda\n@skip_if_quick\n@SkipIfBeforeComputeCapabilityVersion((7, 5))\nclass TestConvertToTRT(unittest.TestCase):\n    def setUp(self):\n        self.gpu_device = torch.cuda.current_device()\n\n    def tearDown(self):\n        current_device = torch.cuda.current_device()\n        if current_device != self.gpu_device:\n            torch.cuda.set_device(self.gpu_device)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    @unittest.skipUnless(has_torchtrt and has_tensorrt, \"Torch-TensorRT is required for convert!\")\n    def test_value(self, precision):\n        model = UNet(\n            spatial_dims=3,\n            in_channels=1,\n            out_channels=2,\n            channels=(2, 2, 4, 8, 4),\n            strides=(2, 2, 2, 2),\n            num_res_units=2,\n            norm=\"batch\",\n        )\n        with tempfile.TemporaryDirectory() as _:\n            torchscript_model = convert_to_trt(\n                model=model,\n                precision=precision,\n                input_shape=[1, 1, 96, 96, 96],\n                dynamic_batchsize=[1, 4, 8],\n                use_trace=False,\n                verify=True,\n                device=0,\n                rtol=1e-2,\n                atol=1e-2,\n            )\n            self.assertTrue(isinstance(torchscript_model, torch.nn.Module))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/test_save_state.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport torch\nimport torch.optim as optim\nfrom parameterized import parameterized\n\nfrom monai.networks import save_state\n\nTEST_CASE_1 = [torch.nn.PReLU(), [\"weight\"]]\n\nTEST_CASE_2 = [{\"net\": torch.nn.PReLU()}, [\"net\"]]\n\nTEST_CASE_3 = [{\"net\": torch.nn.PReLU(), \"opt\": optim.SGD(torch.nn.PReLU().parameters(), lr=0.02)}, [\"net\", \"opt\"]]\n\nTEST_CASE_4 = [torch.nn.DataParallel(torch.nn.PReLU()), [\"weight\"]]\n\nTEST_CASE_5 = [{\"net\": torch.nn.DataParallel(torch.nn.PReLU())}, [\"net\"]]\n\nTEST_CASE_6 = [torch.nn.PReLU(), [\"weight\"], True, True, None, {\"pickle_protocol\": 2}]\n\nTEST_CASE_7 = [torch.nn.PReLU().state_dict(), [\"weight\"]]\n\nTEST_CASE_8 = [torch.nn.PReLU(), [\"weight\"], False]\n\nTEST_CASE_9 = [torch.nn.PReLU(), [\"weight\"], True, False]\n\nTEST_CASE_10 = [torch.nn.PReLU(), [\"weight\"], True, True, torch.save]\n\n\nclass TestSaveState(unittest.TestCase):\n\n    @parameterized.expand(\n        [\n            TEST_CASE_1,\n            TEST_CASE_2,\n            TEST_CASE_3,\n            TEST_CASE_4,\n            TEST_CASE_5,\n            TEST_CASE_6,\n            TEST_CASE_7,\n            TEST_CASE_8,\n            TEST_CASE_9,\n            TEST_CASE_10,\n        ]\n    )\n    def test_file(self, src, expected_keys, create_dir=True, atomic=True, 
func=None, kwargs=None):\n        with tempfile.TemporaryDirectory() as tempdir:\n            path = os.path.join(tempdir, \"test_ckpt.pt\")\n            if kwargs is None:\n                kwargs = {}\n            save_state(src=src, path=path, create_dir=create_dir, atomic=atomic, func=func, **kwargs)\n            ckpt = dict(torch.load(path, weights_only=True))\n            for k in ckpt.keys():\n                self.assertIn(k, expected_keys)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/test_to_onehot.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import one_hot\n\nTEST_CASE_1 = [  # single channel 2D, batch 3, shape (2, 1, 2, 2)\n    {\"labels\": torch.tensor([[[[0, 1], [1, 2]]], [[[2, 1], [1, 0]]]]), \"num_classes\": 3},\n    (2, 3, 2, 2),\n]\n\nTEST_CASE_2 = [  # single channel 1D, batch 2, shape (2, 1, 4)\n    {\"labels\": torch.tensor([[[1, 2, 2, 0]], [[2, 1, 0, 1]]]), \"num_classes\": 3},\n    (2, 3, 4),\n    np.array([[[0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 1, 0]], [[0, 0, 1, 0], [0, 1, 0, 1], [1, 0, 0, 0]]]),\n]\n\nTEST_CASE_3 = [  # single channel 0D, batch 2, shape (2, 1)\n    {\"labels\": torch.tensor([[1.0], [2.0]]), \"num_classes\": 3},\n    (2, 3),\n    np.array([[0, 1, 0], [0, 0, 1]]),\n]\n\nTEST_CASE_4 = [  # no channel 0D, batch 3, shape (3)\n    {\"labels\": torch.tensor([1, 2, 0]), \"num_classes\": 3, \"dtype\": torch.long},\n    (3, 3),\n    np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]),\n]\n\n\nclass TestToOneHot(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_shape(self, input_data, expected_shape, expected_result=None):\n        result = one_hot(**input_data)\n        self.assertEqual(result.shape, expected_shape)\n        if expected_result is not None:\n            
self.assertTrue(np.allclose(expected_result, result.numpy()))\n\n        if \"dtype\" in input_data:\n            self.assertEqual(result.dtype, input_data[\"dtype\"])\n        else:\n            # by default, expecting float type\n            self.assertEqual(result.dtype, torch.float)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/test_varnet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.apps.reconstruction.networks.nets.coil_sensitivity_model import CoilSensitivityModel\nfrom monai.apps.reconstruction.networks.nets.complex_unet import ComplexUnet\nfrom monai.apps.reconstruction.networks.nets.varnet import VariationalNetworkModel\nfrom monai.networks import eval_mode\nfrom tests.test_utils import test_script_save\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\ncoil_sens_model = CoilSensitivityModel(spatial_dims=2, features=[8, 16, 32, 64, 128, 8])\nrefinement_model = ComplexUnet(spatial_dims=2, features=[8, 16, 32, 64, 128, 8])\nnum_cascades = 2\nTESTS = []\nTESTS.append([coil_sens_model, refinement_model, num_cascades, (1, 3, 50, 50, 2), (1, 50, 50)])  # batch=1\nTESTS.append([coil_sens_model, refinement_model, num_cascades, (2, 3, 50, 50, 2), (2, 50, 50)])  # batch=2\n\n\nclass TestVarNet(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, coil_sens_model, refinement_model, num_cascades, input_shape, expected_shape):\n        net = VariationalNetworkModel(coil_sens_model, refinement_model, num_cascades).to(device)\n        mask_shape = [1 for _ in input_shape]\n        mask_shape[-2] = input_shape[-2]\n        mask = torch.zeros(mask_shape)\n        mask[..., mask_shape[-2] 
// 2 - 5 : mask_shape[-2] // 2 + 5, :] = 1\n\n        with eval_mode(net):\n            result = net(torch.randn(input_shape).to(device), mask.bool().to(device))\n        self.assertEqual(result.shape, expected_shape)\n\n    @parameterized.expand(TESTS)\n    def test_script(self, coil_sens_model, refinement_model, num_cascades, input_shape, expected_shape):\n        net = VariationalNetworkModel(coil_sens_model, refinement_model, num_cascades)\n\n        mask_shape = [1 for _ in input_shape]\n        mask_shape[-2] = input_shape[-2]\n        mask = torch.zeros(mask_shape)\n        mask[..., mask_shape[-2] // 2 - 5 : mask_shape[-2] // 2 + 5, :] = 1\n\n        test_data = torch.randn(input_shape)\n\n        test_script_save(net, test_data, mask.bool())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/utils/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/networks/utils/test_copy_model_state.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.utils import copy_model_state\nfrom monai.utils import set_determinism\n\n\nclass _TestModelOne(torch.nn.Module):\n\n    def __init__(self, n_n, n_m, n_class):\n        super().__init__()\n        self.layer = torch.nn.Linear(n_n, n_m)\n        self.class_layer = torch.nn.Linear(n_m, n_class)\n\n    def forward(self, x):\n        x = self.layer(x)\n        x = self.class_layer(x)\n        return x\n\n\nclass _TestModelTwo(torch.nn.Module):\n\n    def __init__(self, n_n, n_m, n_d, n_class):\n        super().__init__()\n        self.layer = torch.nn.Linear(n_n, n_m)\n        self.layer_1 = torch.nn.Linear(n_m, n_d)\n        self.class_layer = torch.nn.Linear(n_d, n_class)\n\n    def forward(self, x):\n        x = self.layer(x)\n        x = self.layer_1(x)\n        x = self.class_layer(x)\n        return x\n\n\nTEST_CASES = []\n__devices = (\"cpu\", \"cuda\") if torch.cuda.is_available() else (\"cpu\",)\nfor _x in __devices:\n    for _y in __devices:\n        TEST_CASES.append((_x, _y))\n\n\nclass TestModuleState(unittest.TestCase):\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @parameterized.expand(TEST_CASES)\n    def test_set_state(self, device_0, device_1):\n        set_determinism(0)\n        
model_one = _TestModelOne(10, 20, 3)\n        model_two = _TestModelTwo(10, 20, 10, 4)\n        model_one.to(device_0)\n        model_two.to(device_1)\n        model_dict, ch, unch = copy_model_state(model_one, model_two)\n        x = np.random.randn(4, 10)\n        x = torch.tensor(x, device=device_0, dtype=torch.float32)\n        output = model_one(x).detach().cpu().numpy()\n        expected = np.array(\n            [\n                [-0.36076584, -0.03177825, -0.7702266],\n                [-0.0526831, -0.15855855, -0.01149344],\n                [-0.3760508, -0.22485238, -0.0634037],\n                [0.5977675, -0.67991066, 0.1919502],\n            ]\n        )\n        np.testing.assert_allclose(output, expected, atol=1e-3)\n        self.assertEqual(len(ch), 2)\n        self.assertEqual(len(unch), 2)\n\n    @parameterized.expand(TEST_CASES)\n    def test_set_full_state(self, device_0, device_1):\n        set_determinism(0)\n        model_one = _TestModelOne(10, 20, 3)\n        model_two = _TestModelOne(10, 20, 3)\n        model_one.to(device_0)\n        model_two.to(device_1)\n        # test module input\n        model_dict, ch, unch = copy_model_state(model_one, model_two)\n        # test dict input\n        model_dict, ch, unch = copy_model_state(model_dict, model_two)\n        x = np.random.randn(4, 10)\n        x = torch.tensor(x, device=device_0, dtype=torch.float32)\n        output = model_one(x).detach().cpu().numpy()\n        model_two.to(device_0)\n        output_1 = model_two(x).detach().cpu().numpy()\n        np.testing.assert_allclose(output, output_1, atol=1e-3)\n        self.assertEqual(len(ch), 4)\n        self.assertEqual(len(unch), 0)\n\n    @parameterized.expand(TEST_CASES)\n    def test_set_exclude_vars(self, device_0, device_1):\n        set_determinism(0)\n        model_one = _TestModelOne(10, 20, 3)\n        model_two = _TestModelTwo(10, 20, 10, 4)\n        model_one.to(device_0)\n        model_two.to(device_1)\n        # test skip 
layer.bias\n        model_dict, ch, unch = copy_model_state(model_one, model_two, exclude_vars=\"layer.bias\")\n        x = np.random.randn(4, 10)\n        x = torch.tensor(x, device=device_0, dtype=torch.float32)\n        output = model_one(x).detach().cpu().numpy()\n        expected = np.array(\n            [\n                [-0.34172416, 0.0375042, -0.98340976],\n                [-0.03364138, -0.08927619, -0.2246768],\n                [-0.35700908, -0.15556987, -0.27658707],\n                [0.61680925, -0.6106281, -0.02123314],\n            ]\n        )\n        np.testing.assert_allclose(output, expected, atol=1e-3)\n        self.assertEqual(len(ch), 1)\n        self.assertEqual(len(unch), 3)\n\n    @parameterized.expand(TEST_CASES)\n    def test_set_map_across(self, device_0, device_1):\n        set_determinism(0)\n        model_one = _TestModelOne(10, 10, 3)\n        model_two = _TestModelTwo(10, 10, 10, 4)\n        model_one.to(device_0)\n        model_two.to(device_1)\n        # test weight map\n        model_dict, ch, unch = copy_model_state(\n            model_one, model_two, mapping={\"layer_1.weight\": \"layer.weight\", \"layer_1.bias\": \"layer_1.weight\"}\n        )\n        model_one.load_state_dict(model_dict)\n        x = np.random.randn(4, 10)\n        x = torch.tensor(x, device=device_0, dtype=torch.float32)\n        output = model_one(x).detach().cpu().numpy()\n        expected = np.array(\n            [\n                [0.8244487, -0.19650555, 0.65723234],\n                [0.71239626, 0.25617486, 0.5247122],\n                [0.24168758, 1.0301148, 0.39089814],\n                [0.25791705, 0.8653245, 0.14833644],\n            ]\n        )\n        np.testing.assert_allclose(output, expected, atol=1e-3)\n        self.assertEqual(len(ch), 2)\n        self.assertEqual(len(unch), 2)\n\n    @parameterized.expand(TEST_CASES)\n    def test_set_prefix(self, device_0, device_1):\n        set_determinism(0)\n        model_one = 
torch.nn.Sequential(_TestModelOne(10, 20, 3))\n        model_two = _TestModelTwo(10, 20, 10, 4)\n        model_one.to(device_0)\n        model_two.to(device_1)\n        # test skip layer.bias\n        model_dict, ch, unch = copy_model_state(\n            model_one, model_two, dst_prefix=\"0.\", exclude_vars=\"layer.bias\", inplace=False\n        )\n        model_one.load_state_dict(model_dict)\n        x = np.random.randn(4, 10)\n        x = torch.tensor(x, device=device_0, dtype=torch.float32)\n        output = model_one(x).detach().cpu().numpy()\n        expected = np.array(\n            [\n                [-0.360766, -0.031778, -0.770227],\n                [-0.052683, -0.158559, -0.011493],\n                [-0.376051, -0.224852, -0.063404],\n                [0.597767, -0.679911, 0.19195],\n            ]\n        )\n        np.testing.assert_allclose(output, expected, atol=1e-3)\n        self.assertEqual(len(ch), 2)\n        self.assertEqual(len(unch), 2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/utils/test_eval_mode.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\n\nfrom monai.networks.utils import eval_mode\n\n\nclass TestEvalMode(unittest.TestCase):\n\n    def test_eval_mode(self):\n        t = torch.rand(1, 1, 4, 4)\n        p = torch.nn.Conv2d(1, 1, 3)\n        self.assertTrue(p.training)  # True\n        with eval_mode(p):\n            self.assertFalse(p.training)  # False\n            with self.assertRaises(RuntimeError):\n                p(t).sum().backward()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/utils/test_freeze_layers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.utils import freeze_layers\nfrom monai.utils import set_determinism\nfrom tests.networks.utils.test_copy_model_state import _TestModelOne, _TestModelTwo\n\nTEST_CASES = []\n__devices = (\"cpu\", \"cuda\") if torch.cuda.is_available() else (\"cpu\",)\nfor _x in __devices:\n    TEST_CASES.append(_x)\n\n\nclass TestModuleState(unittest.TestCase):\n    def tearDown(self):\n        set_determinism(None)\n\n    @parameterized.expand(TEST_CASES)\n    def test_freeze_vars(self, device):\n        set_determinism(0)\n        model = _TestModelOne(10, 20, 3)\n        model.to(device)\n        freeze_layers(model, \"class\")\n\n        for name, param in model.named_parameters():\n            if \"class_layer\" in name:\n                self.assertFalse(param.requires_grad)\n            else:\n                self.assertTrue(param.requires_grad)\n\n    @parameterized.expand(TEST_CASES)\n    def test_exclude_vars(self, device):\n        set_determinism(0)\n        model = _TestModelTwo(10, 20, 10, 4)\n        model.to(device)\n        freeze_layers(model, exclude_vars=\"class\")\n\n        for name, param in model.named_parameters():\n            if \"class_layer\" in name:\n                self.assertTrue(param.requires_grad)\n            else:\n       
         self.assertFalse(param.requires_grad)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/utils/test_pixelunshuffle.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\n\nfrom monai.networks.utils import pixelshuffle, pixelunshuffle\n\n\nclass TestPixelUnshuffle(unittest.TestCase):\n\n    def test_2d_basic(self):\n        x = torch.randn(2, 4, 16, 16)\n        out = pixelunshuffle(x, spatial_dims=2, scale_factor=2)\n        self.assertEqual(out.shape, (2, 16, 8, 8))\n\n    def test_3d_basic(self):\n        x = torch.randn(2, 4, 16, 16, 16)\n        out = pixelunshuffle(x, spatial_dims=3, scale_factor=2)\n        self.assertEqual(out.shape, (2, 32, 8, 8, 8))\n\n    def test_non_square_input(self):\n        x = torch.arange(192).reshape(1, 2, 12, 8)\n        out = pixelunshuffle(x, spatial_dims=2, scale_factor=2)\n        torch.testing.assert_close(out, torch.pixel_unshuffle(x, 2))\n\n    def test_different_scale_factor(self):\n        x = torch.arange(360).reshape(1, 2, 12, 15)\n        out = pixelunshuffle(x, spatial_dims=2, scale_factor=3)\n        torch.testing.assert_close(out, torch.pixel_unshuffle(x, 3))\n\n    def test_inverse_operation(self):\n        x = torch.arange(4096).reshape(1, 8, 8, 8, 8)\n        shuffled = pixelshuffle(x, spatial_dims=3, scale_factor=2)\n        unshuffled = pixelunshuffle(shuffled, spatial_dims=3, scale_factor=2)\n        torch.testing.assert_close(x, unshuffled)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/utils/test_replace_module.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.nets import DenseNet121\nfrom monai.networks.utils import replace_modules, replace_modules_temp\nfrom tests.test_utils import TEST_DEVICES\n\nTESTS = []\nfor device in TEST_DEVICES:\n    for match_device in (True, False):\n        # replace 1\n        TESTS.append((\"features.denseblock1.denselayer1.layers.relu1\", True, match_device, *device))\n        # replace 1 (but not strict)\n        TESTS.append((\"features.denseblock1.denselayer1.layers.relu1\", False, match_device, *device))\n        # replace multiple\n        TESTS.append((\"relu\", False, match_device, *device))\n\n\nclass TestReplaceModule(unittest.TestCase):\n    def setUp(self):\n        self.net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)\n        self.num_relus = self.get_num_modules(torch.nn.ReLU)\n        self.total = self.get_num_modules()\n        self.assertGreater(self.num_relus, 0)\n\n    def get_num_modules(self, mod: type[torch.nn.Module] | None = None) -> int:\n        m = [m for _, m in self.net.named_modules()]\n        if mod is not None:\n            m = [_m for _m in m if isinstance(_m, mod)]\n        return len(m)\n\n    def check_replaced_modules(self, name, match_device):\n        # total num modules should remain the same\n        
self.assertEqual(self.total, self.get_num_modules())\n        num_relus_mod = self.get_num_modules(torch.nn.ReLU)\n        num_softmax = self.get_num_modules(torch.nn.Softmax)\n        # original number of ReLUs should equal the remaining ReLUs plus the new softmaxes\n        self.assertEqual(self.num_relus, num_relus_mod + num_softmax)\n        if name == \"relu\":\n            # at least 2 softmaxes\n            self.assertGreaterEqual(num_softmax, 2)\n        else:\n            # one softmax\n            self.assertEqual(num_softmax, 1)\n        if match_device:\n            self.assertEqual(len(list({i.device for i in self.net.parameters()})), 1)\n\n    @parameterized.expand(TESTS)\n    def test_replace(self, name, strict_match, match_device, device):\n        self.net.to(device)\n        # replace module(s)\n        replaced = replace_modules(self.net, name, torch.nn.Softmax(), strict_match, match_device)\n        self.check_replaced_modules(name, match_device)\n        # number of returned modules should equal number of softmax modules\n        self.assertEqual(len(replaced), self.get_num_modules(torch.nn.Softmax))\n        # all replaced modules should be ReLU\n        for r in replaced:\n            self.assertIsInstance(r[1], torch.nn.ReLU)\n        # if a specific module was named, check that the name matches exactly\n        if name == \"features.denseblock1.denselayer1.layers.relu1\":\n            self.assertEqual(replaced[0][0], name)\n\n    @parameterized.expand(TESTS)\n    def test_replace_context_manager(self, name, strict_match, match_device, device):\n        self.net.to(device)\n        with replace_modules_temp(self.net, name, torch.nn.Softmax(), strict_match, match_device):\n            self.check_replaced_modules(name, match_device)\n        # Check that model was correctly reverted\n        self.assertEqual(self.get_num_modules(), self.total)\n        self.assertEqual(self.get_num_modules(torch.nn.ReLU), self.num_relus)\n        
self.assertEqual(self.get_num_modules(torch.nn.Softmax), 0)\n\n    def test_raises(self):\n        # name doesn't exist in module\n        with self.assertRaises(AttributeError):\n            replace_modules(self.net, \"non_existent_module\", torch.nn.Softmax(), strict_match=True)\n        with self.assertRaises(AttributeError):\n            with replace_modules_temp(self.net, \"non_existent_module\", torch.nn.Softmax(), strict_match=True):\n                pass\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/networks/utils/test_train_mode.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\n\nfrom monai.networks.utils import train_mode\n\n\nclass TestEvalMode(unittest.TestCase):\n\n    def test_eval_mode(self):\n        t = torch.rand(1, 1, 4, 4)\n        p = torch.nn.Conv2d(1, 1, 3)\n        p.eval()\n        self.assertFalse(p.training)  # False\n        with train_mode(p):\n            self.assertTrue(p.training)  # True\n            p(t).sum().backward()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/nonconfig_workflow.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom monai.bundle import BundleWorkflow, PythonicWorkflow\nfrom monai.data import DataLoader, Dataset\nfrom monai.engines import SupervisedEvaluator\nfrom monai.inferers import SlidingWindowInferer\nfrom monai.networks.nets import UNet\nfrom monai.transforms import (\n    Activationsd,\n    AsDiscreted,\n    Compose,\n    EnsureChannelFirstd,\n    LoadImaged,\n    SaveImaged,\n    ScaleIntensityd,\n    ScaleIntensityRanged,\n)\nfrom monai.utils import BundleProperty, CommonKeys, set_determinism\n\n\nclass NonConfigWorkflow(BundleWorkflow):\n    \"\"\"\n    Test class simulates the bundle workflow defined by Python script directly.\n\n    \"\"\"\n\n    def __init__(self, filename, output_dir, meta_file=None, logging_file=None):\n        super().__init__(workflow_type=\"inference\", meta_file=meta_file, logging_file=logging_file)\n        self.filename = filename\n        self.output_dir = output_dir\n        self._bundle_root = \"will override\"\n        self._dataset_dir = \".\"\n        self._device = torch.device(\"cpu\")\n        self._data = [{\"image\": self.filename}]\n        self._dataset = None\n        self._network_def = None\n        self._inferer = None\n        self._preprocessing = None\n        self._postprocessing = None\n        self._evaluator = None\n        self._version = None\n        
self._monai_version = None\n        self._pytorch_version = None\n        self._numpy_version = None\n\n    def initialize(self):\n        set_determinism(0)\n        if self._version is None:\n            self._version = \"0.1.0\"\n\n        if self._monai_version is None:\n            self._monai_version = \"1.1.0\"\n\n        if self._pytorch_version is None:\n            self._pytorch_version = \"2.3.0\"\n\n        if self._numpy_version is None:\n            self._numpy_version = \"1.22.2\"\n\n        if self._preprocessing is None:\n            self._preprocessing = Compose(\n                [LoadImaged(keys=\"image\"), EnsureChannelFirstd(keys=\"image\"), ScaleIntensityd(keys=\"image\")]\n            )\n        self._dataset = Dataset(data=self._data, transform=self._preprocessing)\n        dataloader = DataLoader(self._dataset, batch_size=1, num_workers=4)\n\n        if self._network_def is None:\n            self._network_def = UNet(\n                spatial_dims=3,\n                in_channels=1,\n                out_channels=2,\n                channels=[2, 2, 4, 8, 4],\n                strides=[2, 2, 2, 2],\n                num_res_units=2,\n                norm=\"batch\",\n            )\n        if self._inferer is None:\n            self._inferer = SlidingWindowInferer(roi_size=(64, 64, 32), sw_batch_size=4, overlap=0.25)\n\n        if self._postprocessing is None:\n            self._postprocessing = Compose(\n                [\n                    Activationsd(keys=\"pred\", softmax=True),\n                    AsDiscreted(keys=\"pred\", argmax=True),\n                    SaveImaged(keys=\"pred\", output_dir=self.output_dir, output_postfix=\"seg\"),\n                ]\n            )\n\n        self._evaluator = SupervisedEvaluator(\n            device=self._device,\n            val_data_loader=dataloader,\n            network=self._network_def.to(self._device),\n            inferer=self._inferer,\n            postprocessing=self._postprocessing,\n     
       amp=False,\n        )\n\n    def run(self):\n        self._evaluator.run()\n\n    def finalize(self):\n        return True\n\n    def _get_property(self, name, property):\n        if name == \"bundle_root\":\n            return self._bundle_root\n        if name == \"dataset_dir\":\n            return self._dataset_dir\n        if name == \"dataset_data\":\n            return self._data\n        if name == \"dataset\":\n            return self._dataset\n        if name == \"device\":\n            return self._device\n        if name == \"evaluator\":\n            return self._evaluator\n        if name == \"network_def\":\n            return self._network_def\n        if name == \"inferer\":\n            return self._inferer\n        if name == \"preprocessing\":\n            return self._preprocessing\n        if name == \"postprocessing\":\n            return self._postprocessing\n        if name == \"version\":\n            return self._version\n        if name == \"monai_version\":\n            return self._monai_version\n        if name == \"pytorch_version\":\n            return self._pytorch_version\n        if name == \"numpy_version\":\n            return self._numpy_version\n        if property[BundleProperty.REQUIRED]:\n            raise ValueError(f\"unsupported property '{name}' is required in the bundle properties.\")\n\n    def _set_property(self, name, property, value):\n        if name == \"bundle_root\":\n            self._bundle_root = value\n        elif name == \"device\":\n            self._device = value\n        elif name == \"dataset_dir\":\n            self._dataset_dir = value\n        elif name == \"dataset_data\":\n            self._data = value\n        elif name == \"dataset\":\n            self._dataset = value\n        elif name == \"evaluator\":\n            self._evaluator = value\n        elif name == \"network_def\":\n            self._network_def = value\n        elif name == \"inferer\":\n            self._inferer = 
value\n        elif name == \"preprocessing\":\n            self._preprocessing = value\n        elif name == \"postprocessing\":\n            self._postprocessing = value\n        elif name == \"version\":\n            self._version = value\n        elif name == \"monai_version\":\n            self._monai_version = value\n        elif name == \"pytorch_version\":\n            self._pytorch_version = value\n        elif name == \"numpy_version\":\n            self._numpy_version = value\n        elif property[BundleProperty.REQUIRED]:\n            raise ValueError(f\"unsupported property '{name}' is required in the bundle properties.\")\n\n\nclass PythonicWorkflowImpl(PythonicWorkflow):\n    \"\"\"\n    Test class simulates the bundle workflow defined by Python script directly.\n    \"\"\"\n\n    def __init__(\n        self,\n        workflow_type: str = \"inference\",\n        config_file: str | None = None,\n        properties_path: str | None = None,\n        meta_file: str | None = None,\n    ):\n        super().__init__(\n            workflow_type=workflow_type, properties_path=properties_path, config_file=config_file, meta_file=meta_file\n        )\n        self.dataflow: dict = {}\n\n    def initialize(self):\n        self._props_vals = {}\n        self._is_initialized = True\n        self.net = UNet(\n            spatial_dims=3,\n            in_channels=1,\n            out_channels=2,\n            channels=(16, 32, 64, 128),\n            strides=(2, 2, 2),\n            num_res_units=2,\n        ).to(self.device)\n        preprocessing = Compose(\n            [\n                EnsureChannelFirstd(keys=[\"image\"]),\n                ScaleIntensityd(keys=\"image\"),\n                ScaleIntensityRanged(keys=\"image\", a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n            ]\n        )\n        self.dataset = Dataset(data=[self.dataflow], transform=preprocessing)\n        self.postprocessing = Compose([Activationsd(keys=\"pred\", softmax=True), 
AsDiscreted(keys=\"pred\", argmax=True)])\n\n    def run(self):\n        data = self.dataset[0]\n        inputs = data[CommonKeys.IMAGE].unsqueeze(0).to(self.device)\n        self.net.eval()\n        with torch.no_grad():\n            data[CommonKeys.PRED] = self.inferer(inputs, self.net)\n        self.dataflow.update({CommonKeys.PRED: self.postprocessing(data)[CommonKeys.PRED]})\n\n    def finalize(self):\n        pass\n\n    def get_bundle_root(self):\n        return \".\"\n\n    def get_device(self):\n        return torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    def get_inferer(self):\n        return SlidingWindowInferer(roi_size=self.parser.roi_size, sw_batch_size=1, overlap=0)\n"
  },
  {
    "path": "tests/optimizers/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/optimizers/test_generate_param_groups.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.nets import Unet\nfrom monai.optimizers import generate_param_groups\nfrom monai.utils import ensure_tuple\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_1 = [{\"layer_matches\": [lambda x: x.model[-1]], \"match_types\": \"select\", \"lr_values\": [1]}, (1, 100), [5, 21]]\n\nTEST_CASE_2 = [\n    {\n        \"layer_matches\": [lambda x: x.model[-1], lambda x: x.model[-2], lambda x: x.model[-3]],\n        \"match_types\": \"select\",\n        \"lr_values\": [1, 2, 3],\n    },\n    (1, 2, 3, 100),\n    [5, 16, 5, 0],\n]\n\nTEST_CASE_3 = [\n    {\"layer_matches\": [lambda x: x.model[2][1].conv[0].conv], \"match_types\": [\"select\"], \"lr_values\": [1]},\n    (1, 100),\n    [2, 24],\n]\n\nTEST_CASE_4 = [\n    {\n        \"layer_matches\": [lambda x: x.model[0], lambda x: \"2.0.conv\" in x[0]],\n        \"match_types\": [\"select\", \"filter\"],\n        \"lr_values\": [1, 2],\n    },\n    (1, 2, 100),\n    [5, 4, 17],\n]\n\nTEST_CASE_5 = [\n    {\"layer_matches\": [lambda x: x.model[-1]], \"match_types\": [\"select\"], \"lr_values\": [1], \"include_others\": False},\n    (1),\n    [5],\n]\n\nTEST_CASE_6 = [\n    {\n        \"layer_matches\": [lambda x: \"weight\" in x[0]],\n        \"match_types\": [\"filter\"],\n        
\"lr_values\": [1],\n        \"include_others\": True,\n    },\n    (1),\n    [16, 10],\n]\n\n\nclass TestGenerateParamGroups(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])\n    def test_lr_values(self, input_param, expected_values, expected_groups):\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        net = Unet(\n            spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1\n        ).to(device)\n\n        params = generate_param_groups(network=net, **input_param)\n        optimizer = torch.optim.Adam(params, 100)\n\n        for param_group, value in zip(optimizer.param_groups, ensure_tuple(expected_values)):\n            assert_allclose(param_group[\"lr\"], value)\n\n        n = [len(p[\"params\"]) for p in params]\n        self.assertListEqual(n, expected_groups)\n\n    def test_wrong(self):\n        \"\"\"overlapped\"\"\"\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        net = Unet(\n            spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1\n        ).to(device)\n\n        params = generate_param_groups(\n            network=net,\n            layer_matches=[lambda x: x.model[-1], lambda x: x.model[-1]],\n            match_types=\"select\",\n            lr_values=0.1,\n        )\n        with self.assertRaises(ValueError):\n            torch.optim.Adam(params, 100)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/optimizers/test_lr_finder.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport pickle\nimport random\nimport sys\nimport unittest\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom monai.apps import MedNISTDataset\nfrom monai.networks.nets import DenseNet\nfrom monai.optimizers import LearningRateFinder\nfrom monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd, ToTensord\nfrom monai.utils import optional_import, set_determinism\nfrom monai.utils.misc import MONAIEnvVars\nfrom tests.test_utils import skip_if_downloading_fails\n\nif TYPE_CHECKING:\n    import matplotlib.pyplot as plt\n\n    has_matplotlib = True\n    has_pil = True\nelse:\n    plt, has_matplotlib = optional_import(\"matplotlib.pyplot\")\n    _, has_pil = optional_import(\"PIL.Image\")\n\nRAND_SEED = 42\nrandom.seed(RAND_SEED)\nset_determinism(seed=RAND_SEED)\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n\n@unittest.skipUnless(sys.platform == \"linux\", \"requires linux\")\n@unittest.skipUnless(has_pil, \"requires PIL\")\nclass TestLRFinder(unittest.TestCase):\n    def setUp(self):\n        self.root_dir = MONAIEnvVars.data_dir()\n        if not self.root_dir:\n            self.root_dir = Path(__file__).parents[1] / \"testing_data\"\n\n        self.transforms = Compose(\n            [\n                
LoadImaged(keys=\"image\"),\n                EnsureChannelFirstd(keys=\"image\", channel_dim=\"no_channel\"),\n                ScaleIntensityd(keys=\"image\"),\n                ToTensord(keys=\"image\"),\n            ]\n        )\n\n    def test_lr_finder(self):\n        # 0.001 gives 54 examples\n        with skip_if_downloading_fails():\n            train_ds = MedNISTDataset(\n                root_dir=self.root_dir,\n                transform=self.transforms,\n                section=\"validation\",\n                val_frac=0.001,\n                download=True,\n                num_workers=2,\n            )\n        train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=2)\n        num_classes = train_ds.get_num_classes()\n\n        model = DenseNet(\n            spatial_dims=2, in_channels=1, out_channels=num_classes, init_features=2, growth_rate=2, block_config=(2,)\n        )\n        loss_function = torch.nn.CrossEntropyLoss()\n        learning_rate = 1e-5\n        optimizer = torch.optim.Adam(model.parameters(), learning_rate)\n\n        lr_finder = LearningRateFinder(\n            model=model,\n            optimizer=optimizer,\n            criterion=loss_function,\n            device=device,\n            pickle_module=pickle,\n            pickle_protocol=4,\n        )\n        lr_finder.range_test(train_loader, val_loader=train_loader, end_lr=10.0, num_iter=5)\n        print(lr_finder.get_steepest_gradient(0, 0)[0])\n\n        if has_matplotlib:\n            ax = plt.subplot()\n            plt.show(block=False)\n            lr_finder.plot(0, 0, ax=ax)  # to inspect the loss-learning rate graph\n            plt.pause(3)\n            plt.close()\n\n        lr_finder.reset()  # to reset the model and optimizer to their initial state\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/optimizers/test_lr_scheduler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.optimizers.lr_scheduler import WarmupCosineSchedule\n\n\nclass SchedulerTestNet(torch.nn.Module):\n\n    def __init__(self):\n        super().__init__()\n        self.conv1 = torch.nn.Conv2d(1, 1, 1)\n        self.conv2 = torch.nn.Conv2d(1, 1, 1)\n\n    def forward(self, x):\n        return self.conv2(torch.nn.functional.relu(self.conv1(x)))\n\n\nTEST_CASE_LRSCHEDULER = [\n    [{\"warmup_steps\": 2, \"t_total\": 10}, [0.000, 0.500, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038]],\n    [\n        {\"warmup_steps\": 2, \"t_total\": 10, \"warmup_multiplier\": 0.1},\n        [0.1, 0.55, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038],\n    ],\n    [\n        {\"warmup_steps\": 2, \"t_total\": 10, \"warmup_multiplier\": 0.1, \"end_lr\": 0.309},\n        [0.1, 0.55, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.309, 0.309],\n    ],\n]\n\n\nclass TestLRSCHEDULER(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASE_LRSCHEDULER)\n    def test_shape(self, input_param, expected_lr):\n        net = SchedulerTestNet()\n        optimizer = torch.optim.Adam(net.parameters(), lr=1.0)\n        scheduler = WarmupCosineSchedule(optimizer, **input_param)\n        self.assertEqual(len([scheduler.get_last_lr()[0]]), 1)\n        lrs_1 = []\n        for _ 
in range(input_param[\"t_total\"]):\n            lrs_1.append(float(f\"{scheduler.get_last_lr()[0]:.3f}\"))\n            optimizer.step()\n            scheduler.step()\n        for a, b in zip(lrs_1, expected_lr):\n            self.assertEqual(a, b, msg=f\"LR is wrong ! expected {b}, got {a}\")\n\n    def test_error(self):\n        \"\"\"Should fail because warmup_multiplier is outside 0..1\"\"\"\n        net = SchedulerTestNet()\n        optimizer = torch.optim.Adam(net.parameters(), lr=1.0)\n        with self.assertRaises(ValueError):\n            WarmupCosineSchedule(optimizer, warmup_steps=2, t_total=10, warmup_multiplier=-1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/optimizers/test_optim_novograd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\nfrom torch.autograd import Variable\n\nfrom monai.optimizers import Novograd\n\n\ndef build_test_cases(data):\n    [weight, bias, input] = data\n    weight = Variable(weight, requires_grad=True)\n    bias = Variable(bias, requires_grad=True)\n    input = Variable(input)\n\n    default_params = {\"lr\": 1e-3, \"amsgrad\": False, \"grad_averaging\": False, \"weight_decay\": 0}\n\n    test_case_same_param = [{\"params\": [weight, bias]}]\n    test_case_diff_param = [\n        {\"params\": [weight]},\n        {\"params\": [bias], \"lr\": 1e-2, \"amsgrad\": True, \"grad_averaging\": True, \"weight_decay\": 0.1},\n    ]\n\n    test_cases = [\n        [test_case_same_param, default_params, weight, bias, input],\n        [test_case_diff_param, default_params, weight, bias, input],\n    ]\n    return test_cases\n\n\nTEST_CASES_ALL = build_test_cases([torch.randn(10, 5), torch.randn(10), torch.randn(5)])  # normal parameters\n\nTEST_CASES_ALL += build_test_cases(  # non-contiguous parameters\n    [torch.randn(10, 5, 2)[..., 0], torch.randn(10, 2)[..., 0], torch.randn(5)]\n)\n\nif torch.cuda.is_available():\n    TEST_CASES_ALL += build_test_cases(  # gpu parameters\n        [torch.randn(10, 5).cuda(), torch.randn(10).cuda(), torch.randn(5).cuda()]\n    )\nif 
torch.cuda.device_count() > 1:\n    TEST_CASES_ALL += build_test_cases(  # multi-gpu parameters\n        [torch.randn(10, 5).cuda(0), torch.randn(10).cuda(1), torch.randn(5).cuda(0)]\n    )\n\n\nclass TestNovograd(unittest.TestCase):\n    \"\"\"\n    This class uses PyTorch's `test_optim` tests as a reference:\n    https://github.com/pytorch/pytorch/blob/v1.9.0/test/test_optim.py\n\n    \"\"\"\n\n    @parameterized.expand(TEST_CASES_ALL)\n    def test_step(self, specify_param, default_param, weight, bias, input):\n        optimizer = Novograd(specify_param, **default_param)\n\n        def fn():\n            optimizer.zero_grad()\n            y = weight.mv(input)\n            if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():\n                y = y.cuda(bias.get_device())\n            loss = (y + bias).pow(2).sum()\n            loss.backward()\n            return loss\n\n        initial_value = fn().item()\n        for _ in range(100):\n            optimizer.step(fn)\n        self.assertLess(fn().item(), initial_value)\n\n    def test_ill_arg(self):\n        param = {\"params\": [Variable(torch.randn(10), requires_grad=True)]}\n        with self.assertRaisesRegex(ValueError, \"Invalid learning rate: -1\"):\n            Novograd(param, lr=-1)\n        with self.assertRaisesRegex(ValueError, \"Invalid epsilon value: -1\"):\n            Novograd(param, eps=-1)\n        with self.assertRaisesRegex(ValueError, \"Invalid beta parameter at index 0: 1.0\"):\n            Novograd(param, betas=(1.0, 0.98))\n        with self.assertRaisesRegex(ValueError, \"Invalid beta parameter at index 1: -1\"):\n            Novograd(param, betas=(0.9, -1))\n        with self.assertRaisesRegex(ValueError, \"Invalid weight_decay value: -1\"):\n            Novograd(param, weight_decay=-1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/padders.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import Compose\nfrom monai.transforms.lazy.functional import apply_pending\nfrom monai.transforms.transform import MapTransform\nfrom monai.utils.enums import NumpyPadMode, PytorchPadMode\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nMODES = []\n# Test modes\nNP_MODES: list = [\n    \"constant\",\n    \"edge\",\n    # `reflect` mode is not supported in some PyTorch versions, skip the test\n    # \"reflect\",\n    \"wrap\",\n    \"median\",\n    \"mean\",\n]\nMODES += NP_MODES\nMODES += [NumpyPadMode(i) for i in NP_MODES]\n\nPT_MODES: list = [\n    \"constant\",\n    \"replicate\",\n    \"circular\",\n    # `reflect` mode is not supported in some PyTorch versions, skip the test\n    # \"reflect\",\n]\nMODES += PT_MODES\nMODES += [PytorchPadMode(i) for i in PT_MODES]\n\nTESTS_PENDING_MODE = [[\"constant\", \"zeros\"], [\"edge\", \"border\"]]\n\n\nclass PadTest(unittest.TestCase):\n    @staticmethod\n    def get_arr(shape):\n        return np.random.randint(100, size=shape).astype(float)\n\n    def pad_test(self, input_param, input_shape, expected_shape, modes=None):\n        # loop over each mode\n        for mode in modes or MODES:\n            with self.subTest(mode=mode):\n                
base_comparison = None\n                im = self.get_arr(input_shape)\n                padder = self.Padder(mode=mode, **input_param)\n                is_map = isinstance(padder, MapTransform)\n                # check result is the same regardless of input type\n                for im_type in TEST_NDARRAYS_ALL:\n                    with self.subTest(im_type=im_type):\n                        input_image = im_type(im)\n                        input_data = {\"img\": im_type(im)} if is_map else im_type(im)\n                        # our array transforms can also take `mode` as an argument to `__call__`\n                        # Check this gives equivalent results\n                        for call_extra_args in [{}] if is_map else [{}, {\"mode\": mode}]:\n                            with self.subTest(call_extra_args=call_extra_args):\n                                r_out = padder(input_data, **call_extra_args)\n                                r_im = r_out[\"img\"] if is_map else r_out\n                                # check shape, type, etc.\n                                np.testing.assert_allclose(r_im.shape, expected_shape)\n                                self.assertIsInstance(r_im, MetaTensor)\n                                self.assertEqual(len(r_im.applied_operations), 1)\n                                # check results are same regardless of input type\n                                if base_comparison is None:\n                                    base_comparison = r_im\n                                else:\n                                    assert_allclose(r_im, base_comparison)\n                                # test inverse\n                                if isinstance(r_im, MetaTensor):\n                                    r_out = padder.inverse(r_out)\n                                    r_im = r_out[\"img\"] if is_map else r_out\n                                    self.assertIsInstance(r_im, MetaTensor)\n                                    
assert_allclose(r_im, input_image, type_test=False)\n                                    self.assertEqual(r_im.applied_operations, [])\n\n    def pad_test_kwargs(self, unchanged_slices, **input_param):\n        for im_type in TEST_NDARRAYS_ALL:\n            with self.subTest(im_type=im_type):\n                for kwargs in ({\"value\": 2}, {\"constant_values\": ((0, 0), (1, 1), (2, 2))}):\n                    with self.subTest(kwargs=kwargs):\n                        im = im_type(np.random.randint(-100, -10, size=(3, 8, 4)))\n                        padder = self.Padder(**input_param, **kwargs)\n                        result = padder(im)\n                        if isinstance(result, torch.Tensor):\n                            result = result.cpu()\n                        assert_allclose(result[unchanged_slices], im, type_test=False)\n                        # we should have the same as the input plus some 2s (if value) or 1s and 2s (if constant_values)\n                        if isinstance(im, torch.Tensor):\n                            im = im.detach().cpu().numpy()\n                        expected_vals = np.unique(im).tolist()\n                        expected_vals += [2] if \"value\" in kwargs else [1, 2]\n                        assert_allclose(np.unique(result), expected_vals, type_test=False)\n                        # check inverse\n                        if isinstance(result, MetaTensor):\n                            inv = padder.inverse(result)\n                            assert_allclose(im, inv, type_test=False)\n                            self.assertEqual(inv.applied_operations, [])\n\n    def pad_test_pending_ops(self, input_param, input_shape):\n        for mode in TESTS_PENDING_MODE:\n            # TODO: One of the dim in the input data contains 1 report error.\n            pad_fn = self.Padder(mode=mode[0], **input_param)\n            data = self.get_arr(input_shape)\n            is_map = isinstance(pad_fn, MapTransform)\n            im = 
MetaTensor(data, meta={\"a\": \"b\", \"affine\": np.eye(len(input_shape))})\n            input_data = {\"img\": im} if is_map else im\n            # non-lazy\n            result_non_lazy = pad_fn(input_data)\n            expected = result_non_lazy[\"img\"] if is_map else result_non_lazy\n            self.assertIsInstance(expected, MetaTensor)\n            # lazy\n            pad_fn.lazy = True\n            pending_result = pad_fn(input_data)\n            pending_result = pending_result[\"img\"] if is_map else pending_result\n            self.assertIsInstance(pending_result, MetaTensor)\n            assert_allclose(pending_result.peek_pending_affine(), expected.affine)\n            assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])\n            # TODO: mode=\"bilinear\" may report error\n            overrides = {\"mode\": \"nearest\", \"padding_mode\": mode[1], \"align_corners\": False}\n            result = apply_pending(pending_result, overrides=overrides)[0]\n            # lazy in constructor\n            pad_fn_lazy = self.Padder(mode=mode[0], lazy=True, **input_param)\n            self.assertTrue(pad_fn_lazy.lazy)\n            # compare\n            assert_allclose(result, expected, rtol=1e-5)\n            if isinstance(result, MetaTensor) and not isinstance(pad_fn, MapTransform):\n                pad_fn.lazy = False\n                inverted = pad_fn.inverse(result)\n                self.assertTrue((not inverted.pending_operations) and (not inverted.applied_operations))\n                self.assertEqual(inverted.shape, im.shape)\n\n    def pad_test_combine_ops(self, funcs, input_shape, expected_shape):\n        for mode in TESTS_PENDING_MODE:\n            # non-lazy\n            _funcs = []\n            for func in funcs:\n                for _func, _params in func.items():\n                    _funcs.append(_func(mode=mode[0], **_params))\n            trans = Compose(_funcs)\n            data = self.get_arr(input_shape)\n            
is_map = isinstance(_funcs[0], MapTransform)\n            im = MetaTensor(data, meta={\"a\": \"b\", \"affine\": np.eye(len(input_shape))})\n            input_data = {\"img\": im} if is_map else im\n            result_non_lazy = trans(input_data)\n            expected = result_non_lazy[\"img\"] if is_map else result_non_lazy\n            self.assertIsInstance(expected, MetaTensor)\n            # lazy\n            pending_result = input_data\n            for _func in _funcs:\n                _func.lazy = True\n                pending_result = _func(pending_result)\n            pending_result = pending_result[\"img\"] if is_map else pending_result\n            self.assertIsInstance(pending_result, MetaTensor)\n            assert_allclose(pending_result.peek_pending_affine(), expected.affine)\n            assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])\n            # TODO: mode=\"bilinear\" may report error\n            overrides = {\"mode\": \"nearest\", \"padding_mode\": mode[1], \"align_corners\": False}\n            result = apply_pending(pending_result, overrides=overrides)[0]\n            # compare\n            assert_allclose(result, expected, rtol=1e-5)\n"
  },
  {
    "path": "tests/profile_subclass/README.md",
    "content": "# Profiling the performance of subclassing/`__torch_function__` in MONAI\n\n## Requirements\n```bash\npip install py-spy\npip install snakeviz  # for viewing the cProfile results\n```\n\n## Commands\n\n### Install MONAI\n```\n./runtests.sh --build   # from monai's root directory\n```\nor follow the installation guide (https://monai.readthedocs.io/en/latest/installation.html)\n\n### Profiling the task of adding two MetaTensors\n```bash\npython profiling.py\n```\n\n### Profiling using `py-spy`\n```bash\npy-spy record -o Tensor.svg -- python pyspy_profiling.py Tensor\npy-spy record -o SubTensor.svg -- python pyspy_profiling.py SubTensor\npy-spy record -o SubWithTorchFunc.svg -- python pyspy_profiling.py SubWithTorchFunc\npy-spy record -o MetaTensor.svg -- python pyspy_profiling.py MetaTensor\n```\n\n### Profiling using `cProfile` and `SNAKEVIZ`\n\n```bash\npython cprofile_profiling.py\nsnakeviz out_200.prof\n```\n\n---\nThese tests are based on the following code:\nhttps://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark\n\n- Overhead for torch functions when run on `torch.Tensor` objects is on the order of 2 microseconds.\n- `__torch_function__` should add zero overhead for `torch.Tensor` inputs, a small overhead for subclasses of `torch.Tensor`, and an order of microseconds for `MeatTensor`.\n- Changing the dispatching mechanism may result in changes that are on the order of 100 ns, which are hard to detect due to noise, but important.\n"
  },
  {
    "path": "tests/profile_subclass/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/profile_subclass/cprofile_profiling.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nProfiling MetaTensor\n\"\"\"\n\nfrom __future__ import annotations\n\nimport cProfile\n\nimport torch\n\nfrom monai.data.meta_tensor import MetaTensor\n\nif __name__ == \"__main__\":\n    n_chan = 3\n    for hwd in (10, 200):\n        shape = (n_chan, hwd, hwd, hwd)\n        a = MetaTensor(torch.rand(shape), meta={\"affine\": torch.eye(4) * 2, \"fname\": \"something1\"})\n        b = MetaTensor(torch.rand(shape), meta={\"affine\": torch.eye(4) * 3, \"fname\": \"something2\"})\n        cProfile.run(\"c = a + b\", filename=f\"out_{hwd}.prof\")\n"
  },
  {
    "path": "tests/profile_subclass/min_classes.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nMinimal subclassing as baselines\nAdapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark\n\"\"\"\n\nfrom __future__ import annotations\n\nimport torch\n\n__all__ = [\"SubTensor\", \"SubWithTorchFunc\"]\n\n\nclass SubTensor(torch.Tensor):\n    pass\n\n\nclass SubWithTorchFunc(torch.Tensor):\n\n    def __torch_function__(self, func, types, args=(), kwargs=None):\n        return super().__torch_function__(func, types, args, {} if kwargs is None else kwargs)\n"
  },
  {
    "path": "tests/profile_subclass/profiling.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nComparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor\nAdapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark\n\"\"\"\nfrom __future__ import annotations\n\nimport argparse\n\nimport torch\nfrom min_classes import SubTensor, SubWithTorchFunc\n\nfrom monai.data import MetaTensor\nfrom monai.utils.profiling import PerfContext\n\nNUM_REPEATS = 1000\nNUM_REPEAT_OF_REPEATS = 1000\n\n\ndef bench(t1, t2):\n    bench_times = []\n    for _ in range(NUM_REPEAT_OF_REPEATS):\n        with PerfContext() as pc:\n            for _ in range(NUM_REPEATS):\n                torch.add(t1, t2)\n        bench_times.append(pc.total_time)\n\n    bench_time_min = float(torch.min(torch.Tensor(bench_times))) / NUM_REPEATS\n    bench_time_avg = float(torch.sum(torch.Tensor(bench_times))) / (NUM_REPEATS * NUM_REPEAT_OF_REPEATS)\n    bench_time_med = float(torch.median(torch.Tensor(bench_times))) / NUM_REPEATS\n    bench_std = float(torch.std(torch.Tensor(bench_times))) / NUM_REPEATS\n    return bench_time_min, bench_time_avg, bench_time_med, bench_std\n\n\ndef main():\n    global NUM_REPEATS\n    global NUM_REPEAT_OF_REPEATS\n\n    parser = argparse.ArgumentParser(description=\"Run the __torch_function__ benchmarks.\")\n    parser.add_argument(\n        \"--nreps\", \"-n\", type=int, default=NUM_REPEATS, help=\"The number of repeats for one 
measurement.\"\n    )\n    parser.add_argument(\"--nrepreps\", \"-m\", type=int, default=NUM_REPEAT_OF_REPEATS, help=\"The number of measurements.\")\n    args = parser.parse_args()\n\n    NUM_REPEATS = args.nreps\n    NUM_REPEAT_OF_REPEATS = args.nrepreps\n\n    types = torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor\n\n    for t in types:\n        tensor_1 = t(1)\n        tensor_2 = t(2)\n\n        b_min, b_avg, b_med, b_std = bench(tensor_1, tensor_2)\n        print(\n            f\"Type {t.__name__} time (microseconds):\"\n            f\"  min: {10**6 * b_min}, avg: {(10**6) * b_avg}, median: {(10**6) * b_med}, and std {(10**6) * b_std}.\"\n        )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tests/profile_subclass/pyspy_profiling.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTo be used with py-spy, comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor\nAdapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark\n\"\"\"\nfrom __future__ import annotations\n\nimport argparse\n\nimport torch\n\nTensor = torch.Tensor\n\nNUM_REPEATS = 1000000\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run the torch.add for a given class a given number of times.\")\n    parser.add_argument(\"tensor_class\", metavar=\"TensorClass\", type=str, help=\"The class to benchmark.\")\n    parser.add_argument(\"--nreps\", \"-n\", type=int, default=NUM_REPEATS, help=\"The number of repeats.\")\n    args = parser.parse_args()\n\n    TensorClass = globals()[args.tensor_class]\n    NUM_REPEATS = args.nreps\n\n    t1 = TensorClass(1)\n    t2 = TensorClass(2)\n\n    for _ in range(NUM_REPEATS):\n        torch.add(t1, t2)\n"
  },
  {
    "path": "tests/runner.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport argparse\nimport inspect\nimport os\nimport re\nimport sys\nimport time\nimport unittest\nfrom pathlib import Path\n\nfrom monai.utils import PerfContext\n\nresults: dict = {}\n\n\nclass TimeLoggingTestResult(unittest.TextTestResult):\n    \"\"\"Overload the default results so that we can store the results.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.timed_tests = {}\n\n    def startTest(self, test):  # noqa: N802\n        \"\"\"Start timer, print test name, do normal test.\"\"\"\n        self.start_time = time.time()\n        name = self.getDescription(test)\n        self.stream.write(f\"Starting test: {name}...\\n\")\n        super().startTest(test)\n\n    def stopTest(self, test):  # noqa: N802\n        \"\"\"On test end, get time, print, store and do normal behaviour.\"\"\"\n        elapsed = time.time() - self.start_time\n        name = self.getDescription(test)\n        self.stream.write(f\"Finished test: {name} ({elapsed:.03}s)\\n\")\n        if name in results:\n            raise AssertionError(f\"expected all keys to be unique, but {name} is duplicated\")\n        results[name] = elapsed\n        super().stopTest(test)\n\n\ndef print_results(results, discovery_time, thresh, status):\n    # only keep results >= threshold\n    results = dict(filter(lambda x: 
x[1] > thresh, results.items()))\n    if len(results) == 0:\n        return\n    print(f\"\\n\\n{status}, printing completed times >{thresh}s in ascending order...\\n\")\n    timings = dict(sorted(results.items(), key=lambda item: item[1]))\n\n    for r in timings:\n        if timings[r] >= thresh:\n            print(f\"{r} ({timings[r]:.03}s)\")\n    print(f\"test discovery time: {discovery_time:.03}s\")\n    print(f\"total testing time: {sum(results.values()):.03}s\")\n    print(\"Remember to check above times for any errors!\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Runner for MONAI unittests with timing.\")\n    parser.add_argument(\n        \"-s\", action=\"store\", dest=\"path\", default=\".\", help=\"Directory to start discovery (default: '%(default)s')\"\n    )\n    parser.add_argument(\n        \"-p\",\n        action=\"store\",\n        dest=\"pattern\",\n        default=\"test_*.py\",\n        help=\"Pattern to match tests (default: '%(default)s')\",\n    )\n    parser.add_argument(\n        \"-t\",\n        \"--thresh\",\n        dest=\"thresh\",\n        default=10.0,\n        type=float,\n        help=\"Display tests longer than given threshold (default: %(default)d)\",\n    )\n    parser.add_argument(\n        \"-v\",\n        \"--verbosity\",\n        action=\"store\",\n        dest=\"verbosity\",\n        type=int,\n        default=1,\n        help=\"Verbosity level (default: %(default)d)\",\n    )\n    parser.add_argument(\"-q\", \"--quick\", action=\"store_true\", dest=\"quick\", default=False, help=\"Only do quick tests\")\n    parser.add_argument(\n        \"-f\", \"--failfast\", action=\"store_true\", dest=\"failfast\", default=False, help=\"Stop testing on first failure\"\n    )\n    args = parser.parse_args()\n    print(f\"Running tests in folder: '{args.path}'\")\n    if args.pattern:\n        print(f\"With file pattern: '{args.pattern}'\")\n\n    return args\n\n\ndef get_default_pattern(loader):\n    
signature = inspect.signature(loader.discover)\n    params = {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}\n    return params[\"pattern\"]\n\n\nif __name__ == \"__main__\":\n    # Parse input arguments\n    args = parse_args()\n\n    # If quick is desired, set environment variable\n    if args.quick:\n        os.environ[\"QUICKTEST\"] = \"True\"\n\n    # Get all test names (optionally from some path with some pattern)\n    with PerfContext() as pc:\n        # the files are searched from `tests/` folder, starting with `test_`\n        tests_path = Path(__file__).parent / args.path\n        files = {\n            file.relative_to(tests_path).as_posix()\n            for file in tests_path.rglob(\"test_*py\")\n            if re.search(args.pattern, file.name[:-3])\n        }\n        print(files)\n        cases = []\n        for test_module in tests_path.rglob(\"test_*py\"):\n            test_file = str(test_module.relative_to(tests_path).as_posix())\n            case_str = test_file.replace(\"/\", \".\")[:-3]\n            case_str = f\"tests.{case_str}\"\n            if test_file in files:\n                cases.append(case_str)\n            else:\n                print(f\"monai test runner: excluding {test_module.name}\")\n        print(cases)\n        tests = unittest.TestLoader().loadTestsFromNames(cases)\n    discovery_time = pc.total_time\n    print(f\"time to discover tests: {discovery_time}s, total cases: {tests.countTestCases()}.\")\n\n    test_runner = unittest.runner.TextTestRunner(\n        resultclass=TimeLoggingTestResult, verbosity=args.verbosity, failfast=args.failfast\n    )\n    # Use try catches to print the current results if encountering exception or keyboard interruption\n    try:\n        test_result = test_runner.run(tests)\n        print_results(results, discovery_time, args.thresh, \"tests finished\")\n        sys.exit(not test_result.wasSuccessful())\n    except KeyboardInterrupt:\n        
print_results(results, discovery_time, args.thresh, \"tests cancelled\")\n        sys.exit(1)\n    except Exception:\n        print_results(results, discovery_time, args.thresh, \"exception reached\")\n        raise\n"
  },
  {
    "path": "tests/test_call_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom tests.test_utils import DistCall, DistTestCase\n\n\nclass DistributedCallTest(DistTestCase):\n    def test_constructor(self):\n        with self.assertRaises(ValueError):\n            DistCall(nnodes=1, nproc_per_node=0)\n        with self.assertRaises(ValueError):\n            DistCall(nnodes=0, nproc_per_node=0)\n        with self.assertRaises(ValueError):\n            DistCall(nnodes=0, nproc_per_node=1)\n        _ = DistCall(nnodes=1, nproc_per_node=1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_query_memory.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom tests.test_utils import query_memory\n\n\nclass TestQueryMemory(unittest.TestCase):\n    def test_output_str(self):\n        self.assertTrue(isinstance(query_memory(2), str))\n        all_device = query_memory(-1)\n        self.assertTrue(isinstance(all_device, str))\n        self.assertEqual(query_memory(\"test\"), \"\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_timedcall_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport multiprocessing\nimport sys\nimport time\nimport unittest\n\nfrom tests.test_utils import TimedCall\n\n\n@TimedCall(seconds=20 if sys.platform == \"linux\" else 60, force_quit=False)\ndef case_1_seconds(arg=None):\n    time.sleep(1)\n    return \"good\" if not arg else arg\n\n\n@TimedCall(seconds=0.1, skip_timing=True, force_quit=True)\ndef case_1_seconds_skip(arg=None):\n    time.sleep(1)\n    return \"good\" if not arg else arg\n\n\n@TimedCall(seconds=0.1, force_quit=True)\ndef case_1_seconds_timeout(arg=None):\n    time.sleep(1)\n    return \"good\" if not arg else arg\n\n\n@TimedCall(seconds=0.1, force_quit=False)\ndef case_1_seconds_timeout_warning(arg=None):\n    time.sleep(1)\n    return \"good\" if not arg else arg\n\n\n@TimedCall(seconds=0.1, force_quit=True)\ndef case_1_seconds_bad(arg=None):\n    time.sleep(1)\n    assert 0 == 1, \"wrong case\"\n\n\nclass TestTimedCall(unittest.TestCase):\n    def test_good_call(self):\n        output = case_1_seconds()\n        self.assertEqual(output, \"good\")\n\n    def test_skip_timing(self):\n        output = case_1_seconds_skip(\"testing\")\n        self.assertEqual(output, \"testing\")\n\n    def test_timeout(self):\n        with self.assertRaises(multiprocessing.TimeoutError):\n            case_1_seconds_timeout()\n\n    def test_timeout_not_force_quit(self):\n        with 
self.assertWarns(Warning):\n            with self.assertRaises(multiprocessing.TimeoutError):\n                case_1_seconds_timeout_warning()\n\n    def test_timeout_bad(self):\n        # timeout before the method's error\n        with self.assertRaises(multiprocessing.TimeoutError):\n            case_1_seconds_bad()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/test_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport argparse\nimport copy\nimport datetime\nimport functools\nimport importlib\nimport json\nimport operator\nimport os\nimport queue\nimport ssl\nimport subprocess\nimport sys\nimport tempfile\nimport time\nimport traceback\nimport unittest\nimport warnings\nfrom collections.abc import Iterable\nfrom contextlib import contextmanager\nfrom functools import partial, reduce\nfrom itertools import product\nfrom pathlib import Path\nfrom subprocess import PIPE, Popen\nfrom typing import Any, Callable\nfrom urllib.error import ContentTooShortError, HTTPError\n\nimport numpy as np\nimport torch\nimport torch.distributed as dist\n\nfrom monai.apps.utils import download_url\nfrom monai.config import NdarrayTensor\nfrom monai.config.deviceconfig import USE_COMPILED\nfrom monai.config.type_definitions import NdarrayOrTensor\nfrom monai.data import create_test_image_2d, create_test_image_3d\nfrom monai.data.meta_tensor import MetaTensor, get_track_meta\nfrom monai.networks import convert_to_onnx, convert_to_torchscript\nfrom monai.utils import optional_import\nfrom monai.utils.misc import MONAIEnvVars\nfrom monai.utils.module import compute_capabilities_after, pytorch_after\nfrom monai.utils.tf32 import detect_default_tf32\nfrom monai.utils.type_conversion import convert_data_type\n\nnib, _ = optional_import(\"nibabel\")\nhttp_error, has_req 
= optional_import(\"requests\", name=\"HTTPError\")\nfile_url_error, has_gdown = optional_import(\"gdown.exceptions\", name=\"FileURLRetrievalError\")\nhf_http_error, has_hf_hub = optional_import(\"huggingface_hub.errors\", name=\"HfHubHTTPError\")\nhf_local_entry_error, _has_hf_local = optional_import(\"huggingface_hub.errors\", name=\"LocalEntryNotFoundError\")\n\n\nquick_test_var = \"QUICKTEST\"\n_tf32_enabled = None\n_test_data_config: dict = {}\n\nMODULE_PATH = Path(__file__).resolve().parents[1]\n\nDOWNLOAD_EXCEPTS: tuple[type, ...] = (ContentTooShortError, HTTPError, ConnectionError)\nif has_req:\n    DOWNLOAD_EXCEPTS += (http_error,)\nif has_gdown:\n    DOWNLOAD_EXCEPTS += (file_url_error,)\nif has_hf_hub:\n    DOWNLOAD_EXCEPTS += (hf_http_error, hf_local_entry_error)\n\nDOWNLOAD_FAIL_MSGS = (\n    \"unexpected EOF\",  # incomplete download\n    \"network issue\",\n    \"gdown dependency\",  # gdown not installed\n    \"md5 check\",\n    \"limit\",  # HTTP Error 503: Egress is over the account limit\n    \"authenticate\",\n    \"timed out\",  # urlopen error [Errno 110] Connection timed out\n    \"HTTPError\",  # HTTPError: 429 Client Error: Too Many Requests for huggingface hub\n)\n\n\ndef testing_data_config(*keys):\n    \"\"\"get _test_data_config[keys0][keys1]...[keysN]\"\"\"\n    if not _test_data_config:\n        with open(f\"{MODULE_PATH}/tests/testing_data/data_config.json\") as c:\n            _config = json.load(c)\n            for k, v in _config.items():\n                _test_data_config[k] = v\n    return reduce(operator.getitem, keys, _test_data_config)\n\n\ndef get_testing_algo_template_path():\n    \"\"\"\n    a local folder to the testing algorithm template or a url to the compressed template file.\n    Default to None, which effectively uses bundle_gen's ``default_algo_zip`` path.\n\n    https://github.com/Project-MONAI/MONAI/blob/1.1.0/monai/apps/auto3dseg/bundle_gen.py#L380-L381\n    \"\"\"\n    return 
MONAIEnvVars.testing_algo_template()\n\n\ndef clone(data: NdarrayTensor) -> NdarrayTensor:\n    \"\"\"\n    Clone data independent of type.\n\n    Args:\n        data (NdarrayTensor): This can be a Pytorch Tensor or numpy array.\n\n    Returns:\n        Any: Cloned data object\n    \"\"\"\n    return copy.deepcopy(data)\n\n\ndef assert_allclose(\n    actual: NdarrayOrTensor,\n    desired: NdarrayOrTensor,\n    type_test: bool | str = True,\n    device_test: bool = False,\n    *args,\n    **kwargs,\n):\n    \"\"\"\n    Assert that types and all values of two data objects are close.\n\n    Args:\n        actual: Pytorch Tensor or numpy array for comparison.\n        desired: Pytorch Tensor or numpy array to compare against.\n        type_test: whether to test that `actual` and `desired` are both numpy arrays or torch tensors.\n            if type_test == \"tensor\", it checks whether the `actual` is a torch.tensor or metatensor according to\n            `get_track_meta`.\n        device_test: whether to test the device property.\n        args: extra arguments to pass on to `np.testing.assert_allclose`.\n        kwargs: extra arguments to pass on to `np.testing.assert_allclose`.\n\n\n    \"\"\"\n    if isinstance(type_test, str) and type_test == \"tensor\":\n        if get_track_meta():\n            np.testing.assert_equal(isinstance(actual, MetaTensor), True, \"must be a MetaTensor\")\n        else:\n            np.testing.assert_equal(\n                isinstance(actual, torch.Tensor) and not isinstance(actual, MetaTensor), True, \"must be a torch.Tensor\"\n            )\n    elif type_test:\n        # check both actual and desired are of the same type\n        np.testing.assert_equal(isinstance(actual, np.ndarray), isinstance(desired, np.ndarray), \"numpy type\")\n        np.testing.assert_equal(isinstance(actual, torch.Tensor), isinstance(desired, torch.Tensor), \"torch type\")\n\n    if isinstance(desired, torch.Tensor) or isinstance(actual, torch.Tensor):\n      
  if device_test:\n            np.testing.assert_equal(str(actual.device), str(desired.device), \"torch device check\")  # type: ignore\n        actual = actual.detach().cpu().numpy() if isinstance(actual, torch.Tensor) else actual\n        desired = desired.detach().cpu().numpy() if isinstance(desired, torch.Tensor) else desired\n    np.testing.assert_allclose(actual, desired, *args, **kwargs)\n\n\n@contextmanager\ndef skip_if_downloading_fails():\n    \"\"\"\n    Skips a test if downloading something raises an exception recognised to indicate a download has failed.\n    \"\"\"\n\n    try:\n        yield\n    except DOWNLOAD_EXCEPTS as e:\n        raise unittest.SkipTest(f\"Error while downloading: {e}\") from e\n    except ssl.SSLError as ssl_e:\n        if \"decryption failed\" in str(ssl_e):\n            raise unittest.SkipTest(f\"SSL error while downloading: {ssl_e}\") from ssl_e\n    except (RuntimeError, OSError) as rt_e:\n        err_str = str(rt_e)\n        if any(k in err_str for k in DOWNLOAD_FAIL_MSGS):\n            raise unittest.SkipTest(f\"Error while downloading: {rt_e}\") from rt_e  # incomplete download\n\n        raise rt_e\n\n\ndef test_pretrained_networks(network, input_param, device):\n    with skip_if_downloading_fails():\n        return network(**input_param).to(device)\n\n\ndef test_is_quick():\n    return os.environ.get(quick_test_var, \"\").lower() == \"true\"\n\n\ndef is_tf32_env():\n    \"\"\"\n    When we may be using TF32 mode, check the precision of matrix operation.\n    If the checking result is greater than the threshold 0.001,\n    set _tf32_enabled=True (and relax _rtol for tests).\n    \"\"\"\n    global _tf32_enabled\n    if _tf32_enabled is None:\n        _tf32_enabled = False\n        if torch.cuda.is_available() and (detect_default_tf32() or torch.backends.cuda.matmul.allow_tf32):\n            try:\n                # with TF32 enabled, the speed is ~8x faster, but the precision has ~2 digits less in the result\n             
   g_gpu = torch.Generator(device=\"cuda\")\n                g_gpu.manual_seed(2147483647)\n                a_full = torch.randn(1024, 1024, dtype=torch.double, device=\"cuda\", generator=g_gpu)\n                b_full = torch.randn(1024, 1024, dtype=torch.double, device=\"cuda\", generator=g_gpu)\n                _tf32_enabled = (a_full.float() @ b_full.float() - a_full @ b_full).abs().max().item() > 0.001  # 0.1713\n            except BaseException:\n                pass\n        print(f\"tf32 enabled: {_tf32_enabled}\")\n    return _tf32_enabled\n\n\ndef skip_if_quick(obj):\n    \"\"\"\n    Skip the unit tests if environment variable `quick_test_var=true`.\n    For example, the user can skip the relevant tests by setting ``export QUICKTEST=true``.\n    \"\"\"\n    is_quick = test_is_quick()\n\n    return unittest.skipIf(is_quick, \"Skipping slow tests\")(obj)\n\n\nclass SkipIfNoModule:\n    \"\"\"Decorator to be used if test should be skipped\n    when optional module is not present.\"\"\"\n\n    def __init__(self, module_name):\n        self.module_name = module_name\n        self.module_missing = not optional_import(self.module_name)[1]\n\n    def __call__(self, obj):\n        return unittest.skipIf(self.module_missing, f\"optional module not present: {self.module_name}\")(obj)\n\n\nclass SkipIfModule:\n    \"\"\"Decorator to be used if test should be skipped\n    when optional module is present.\"\"\"\n\n    def __init__(self, module_name):\n        self.module_name = module_name\n        self.module_avail = optional_import(self.module_name)[1]\n\n    def __call__(self, obj):\n        return unittest.skipIf(self.module_avail, f\"Skipping because optional module present: {self.module_name}\")(obj)\n\n\ndef skip_if_no_cpp_extension(obj):\n    \"\"\"\n    Skip the unit tests if the cpp extension is not available.\n    \"\"\"\n    return unittest.skipUnless(USE_COMPILED, \"Skipping cpp extension tests\")(obj)\n\n\ndef skip_if_no_cuda(obj):\n    \"\"\"\n    Skip 
the unit tests if torch.cuda.is_available is False.\n    \"\"\"\n    return unittest.skipUnless(torch.cuda.is_available(), \"Skipping CUDA-based tests\")(obj)\n\n\ndef skip_if_windows(obj):\n    \"\"\"\n    Skip the unit tests if platform is win32.\n    \"\"\"\n    return unittest.skipIf(sys.platform == \"win32\", \"Skipping tests on Windows\")(obj)\n\n\ndef skip_if_darwin(obj):\n    \"\"\"\n    Skip the unit tests if platform is macOS (Darwin).\n    \"\"\"\n    return unittest.skipIf(sys.platform == \"darwin\", \"Skipping tests on macOS/Darwin\")(obj)\n\n\nclass SkipIfBeforePyTorchVersion:\n    \"\"\"Decorator to be used if test should be skipped\n    with PyTorch versions older than that given.\"\"\"\n\n    def __init__(self, pytorch_version_tuple):\n        self.min_version = pytorch_version_tuple\n        self.version_too_old = not pytorch_after(*pytorch_version_tuple)\n\n    def __call__(self, obj):\n        return unittest.skipIf(\n            self.version_too_old, f\"Skipping tests that fail on PyTorch versions before: {self.min_version}\"\n        )(obj)\n\n\nclass SkipIfAtLeastPyTorchVersion:\n    \"\"\"Decorator to be used if test should be skipped\n    with PyTorch versions newer than or equal to that given.\"\"\"\n\n    def __init__(self, pytorch_version_tuple):\n        self.max_version = pytorch_version_tuple\n        self.version_too_new = pytorch_after(*pytorch_version_tuple)\n\n    def __call__(self, obj):\n        return unittest.skipIf(\n            self.version_too_new, f\"Skipping tests that fail on PyTorch versions at least: {self.max_version}\"\n        )(obj)\n\n\nclass SkipIfBeforeComputeCapabilityVersion:\n    \"\"\"Decorator to be used if test should be skipped\n    with Compute Capability older than that given.\"\"\"\n\n    def __init__(self, compute_capability_tuple):\n        self.min_version = compute_capability_tuple\n        self.version_too_old = not compute_capabilities_after(*compute_capability_tuple)\n\n    def __call__(self, 
obj):\n        return unittest.skipIf(\n            self.version_too_old, f\"Skipping tests that fail on Compute Capability versions before: {self.min_version}\"\n        )(obj)\n\n\ndef is_main_test_process():\n    ps = torch.multiprocessing.current_process()\n    if not ps or not hasattr(ps, \"name\"):\n        return False\n    return ps.name.startswith(\"Main\")\n\n\ndef has_cupy():\n    \"\"\"\n    Returns True if the user has installed a version of cupy.\n    \"\"\"\n    cp, has_cp = optional_import(\"cupy\")\n    if not is_main_test_process():\n        return has_cp  # skip the check if we are running in subprocess\n    if not has_cp:\n        return False\n    try:  # test cupy installation with a basic example\n        x = cp.arange(6, dtype=\"f\").reshape(2, 3)\n        y = cp.arange(3, dtype=\"f\")\n        kernel = cp.ElementwiseKernel(\n            \"float32 x, float32 y\", \"float32 z\", \"\"\" if (x - 2 > y) { z = x * y; } else { z = x + y; } \"\"\", \"my_kernel\"\n        )\n        flag = kernel(x, y)[0, 0] == 0\n        del x, y, kernel\n        cp.get_default_memory_pool().free_all_blocks()\n        return flag\n    except Exception:\n        return False\n\n\nHAS_CUPY = has_cupy()\n\n\ndef make_nifti_image(\n    array: NdarrayOrTensor, affine=None, dir=None, fname=None, suffix=\".nii.gz\", verbose=False, dtype=float\n):\n    \"\"\"\n    Create a temporary nifti image on the disk and return the image name.\n    User is responsible for deleting the temporary file when done with it.\n    \"\"\"\n    if isinstance(array, torch.Tensor):\n        array, *_ = convert_data_type(array, np.ndarray)\n    if isinstance(affine, torch.Tensor):\n        affine, *_ = convert_data_type(affine, np.ndarray)\n    if affine is None:\n        affine = np.eye(4)\n    test_image = nib.Nifti1Image(array.astype(dtype), affine)  # type: ignore\n\n    # if dir not given, create random. 
Else, make sure it exists.\n    if dir is None:\n        dir = tempfile.mkdtemp()\n    else:\n        os.makedirs(dir, exist_ok=True)\n\n    # If fname not given, get random one. Else, concat dir, fname and suffix.\n    if fname is None:\n        temp_f, fname = tempfile.mkstemp(suffix=suffix, dir=dir)\n        os.close(temp_f)\n    else:\n        fname = os.path.join(dir, fname + suffix)\n\n    nib.save(test_image, fname)\n    if verbose:\n        print(f\"File written: {fname}.\")\n    return fname\n\n\ndef make_rand_affine(ndim: int = 3, random_state: np.random.RandomState | None = None):\n    \"\"\"Create random affine transformation (with values == -1, 0 or 1).\"\"\"\n    rs = np.random.random.__self__ if random_state is None else random_state  # type: ignore\n\n    vals = rs.choice([-1, 1], size=ndim)\n    positions = rs.choice(range(ndim), size=ndim, replace=False)\n    af = np.zeros([ndim + 1, ndim + 1])\n    af[ndim, ndim] = 1\n    for i, (v, p) in enumerate(zip(vals, positions)):\n        af[i, p] = v\n    return af\n\n\ndef get_arange_img(size, dtype=np.float32, offset=0):\n    \"\"\"\n    Returns an image as a numpy array (complete with channel as dim 0)\n    with contents that iterate like an arange.\n    \"\"\"\n    n_elem = np.prod(size)\n    img = np.arange(offset, offset + n_elem, dtype=dtype).reshape(size)\n    return np.expand_dims(img, 0)\n\n\nclass DistTestCase(unittest.TestCase):\n    \"\"\"\n    testcase without _outcome, so that it's picklable.\n    \"\"\"\n\n    def __getstate__(self):\n        self_dict = self.__dict__.copy()\n        del self_dict[\"_outcome\"]\n        return self_dict\n\n    def __setstate__(self, data_dict):\n        self.__dict__.update(data_dict)\n\n\nclass DistCall:\n    \"\"\"\n    Wrap a test case so that it will run in multiple processes on a single machine using `torch.distributed`.\n    It is designed to be used with `tests.utils.DistTestCase`.\n\n    Usage:\n\n        decorate a unittest testcase method with a 
`DistCall` instance::\n\n            class MyTests(unittest.TestCase):\n                @DistCall(nnodes=1, nproc_per_node=3, master_addr=\"localhost\")\n                def test_compute(self):\n                ...\n\n        the `test_compute` method should trigger different worker logic according to `dist.get_rank()`.\n\n    Multi-node tests require a fixed master_addr:master_port, with node_rank set manually in multiple scripts\n    or from environment variable \"NODE_RANK\".\n    \"\"\"\n\n    def __init__(\n        self,\n        nnodes: int = 1,\n        nproc_per_node: int = 1,\n        master_addr: str = \"localhost\",\n        master_port: int | None = None,\n        node_rank: int | None = None,\n        timeout=60,\n        init_method=None,\n        backend: str | None = None,\n        daemon: bool | None = None,\n        method: str | None = \"spawn\",\n        verbose: bool = False,\n    ):\n        \"\"\"\n\n        Args:\n            nnodes: The number of nodes to use for distributed call.\n            nproc_per_node: The number of processes to call on each node.\n            master_addr: Master node (rank 0)'s address, should be either the IP address or the hostname of node 0.\n            master_port: Master node (rank 0)'s free port.\n            node_rank: The rank of the node, this could be set via environment variable \"NODE_RANK\".\n            timeout: Timeout for operations executed against the process group.\n            init_method: URL specifying how to initialize the process group.\n                Default is \"env://\" or \"file:///d:/a_temp\" (windows) if unspecified.\n                If ``\"no_init\"``, the `dist.init_process_group` must be called within the code to be tested.\n            backend: The backend to use. 
Depending on build-time configurations,\n                valid values include ``mpi``, ``gloo``, and ``nccl``.\n            daemon: the process’s daemon flag.\n                When daemon=None, the initial value is inherited from the creating process.\n            method: set the method which should be used to start a child process.\n                method can be 'fork', 'spawn' or 'forkserver'.\n            verbose: whether to print NCCL debug info.\n        \"\"\"\n        self.nnodes = int(nnodes)\n        self.nproc_per_node = int(nproc_per_node)\n        if self.nnodes < 1 or self.nproc_per_node < 1:\n            raise ValueError(\n                f\"number of nodes and processes per node must be >= 1, got {self.nnodes} and {self.nproc_per_node}\"\n            )\n        self.node_rank = int(os.environ.get(\"NODE_RANK\", \"0\")) if node_rank is None else int(node_rank)\n        self.master_addr = master_addr\n        self.master_port = np.random.randint(10000, 20000) if master_port is None else master_port\n\n        if backend is None:\n            self.backend = \"nccl\" if torch.distributed.is_nccl_available() and torch.cuda.is_available() else \"gloo\"\n        else:\n            self.backend = backend\n        self.init_method = init_method\n        if self.init_method is None and sys.platform == \"win32\":\n            self.init_method = \"file:///d:/a_temp\"\n        self.timeout = datetime.timedelta(0, timeout)\n        self.daemon = daemon\n        self.method = method\n        self.verbose = verbose\n\n    def run_process(self, func, local_rank, args, kwargs, results):\n        _env = os.environ.copy()  # keep the original system env\n        try:\n            os.environ[\"MASTER_ADDR\"] = self.master_addr\n            os.environ[\"MASTER_PORT\"] = str(self.master_port)\n            os.environ[\"LOCAL_RANK\"] = str(local_rank)\n            if self.verbose:\n                os.environ[\"NCCL_DEBUG\"] = \"INFO\"\n                
os.environ[\"NCCL_DEBUG_SUBSYS\"] = \"ALL\"\n            os.environ[\"TORCH_NCCL_BLOCKING_WAIT\"] = str(1)\n            os.environ[\"OMP_NUM_THREADS\"] = str(1)\n            os.environ[\"WORLD_SIZE\"] = str(self.nproc_per_node * self.nnodes)\n            os.environ[\"RANK\"] = str(self.nproc_per_node * self.node_rank + local_rank)\n\n            if torch.cuda.is_available():\n                torch.cuda.set_device(int(local_rank))  # using device ids from CUDA_VISIBILE_DEVICES\n\n            if self.init_method != \"no_init\":\n                dist.init_process_group(\n                    backend=self.backend,\n                    init_method=self.init_method,\n                    timeout=self.timeout,\n                    world_size=int(os.environ[\"WORLD_SIZE\"]),\n                    rank=int(os.environ[\"RANK\"]),\n                )\n            func(*args, **kwargs)\n            # the primary node lives longer to\n            # avoid _store_based_barrier, RuntimeError: Broken pipe\n            # as the TCP store daemon is on the rank 0\n            if int(os.environ[\"RANK\"]) == 0:\n                time.sleep(0.1)\n            results.put(True)\n        except Exception as e:\n            results.put(False)\n            raise e\n        finally:\n            os.environ.clear()\n            os.environ.update(_env)\n            try:\n                dist.destroy_process_group()\n            except RuntimeError as e:\n                warnings.warn(f\"While closing process group: {e}.\")\n\n    def __call__(self, obj):\n        if not torch.distributed.is_available():\n            return unittest.skipIf(True, \"Skipping distributed tests because not torch.distributed.is_available()\")(obj)\n        if torch.cuda.is_available() and torch.cuda.device_count() < self.nproc_per_node:\n            return unittest.skipIf(\n                True,\n                f\"Skipping distributed tests because it requires {self.nnodes} devices \"\n                f\"but got 
{torch.cuda.device_count()}\",\n            )(obj)\n\n        _cache_original_func(obj)\n\n        @functools.wraps(obj)\n        def _wrapper(*args, **kwargs):\n            tmp = torch.multiprocessing.get_context(self.method)\n            processes = []\n            results = tmp.Queue()\n            func = _call_original_func\n            args = [obj.__name__, obj.__module__] + list(args)\n            for proc_rank in range(self.nproc_per_node):\n                p = tmp.Process(\n                    target=self.run_process, args=(func, proc_rank, args, kwargs, results), daemon=self.daemon\n                )\n                p.start()\n                processes.append(p)\n            for p in processes:\n                p.join()\n                assert results.get(), \"Distributed call failed.\"\n            _del_original_func(obj)\n\n        return _wrapper\n\n\nclass TimedCall:\n    \"\"\"\n    Wrap a test case so that it will run in a new process, raises a TimeoutError if the decorated method takes\n    more than `seconds` to finish. 
It is designed to be used with `tests.utils.DistTestCase`.\n    \"\"\"\n\n    def __init__(\n        self,\n        seconds: float = 60.0,\n        daemon: bool | None = None,\n        method: str | None = \"spawn\",\n        force_quit: bool = True,\n        skip_timing=False,\n    ):\n        \"\"\"\n\n        Args:\n            seconds: timeout seconds.\n            daemon: the process’s daemon flag.\n                When daemon=None, the initial value is inherited from the creating process.\n            method: set the method which should be used to start a child process.\n                method can be 'fork', 'spawn' or 'forkserver'.\n            force_quit: whether to terminate the child process when `seconds` elapsed.\n            skip_timing: whether to skip the timing constraint.\n                this is useful to include some system conditions such as\n                `torch.cuda.is_available()`.\n        \"\"\"\n        self.timeout_seconds = seconds\n        self.daemon = daemon\n        self.force_quit = force_quit\n        self.skip_timing = skip_timing\n        self.method = method\n\n    @staticmethod\n    def run_process(func, args, kwargs, results):\n        try:\n            output = func(*args, **kwargs)\n            results.put(output)\n        except Exception as e:\n            e.traceback = traceback.format_exc()\n            results.put(e)\n\n    def __call__(self, obj):\n        if self.skip_timing:\n            return obj\n\n        _cache_original_func(obj)\n\n        @functools.wraps(obj)\n        def _wrapper(*args, **kwargs):\n            tmp = torch.multiprocessing.get_context(self.method)\n            func = _call_original_func\n            args = [obj.__name__, obj.__module__] + list(args)\n            results = tmp.Queue()\n            p = tmp.Process(target=TimedCall.run_process, args=(func, args, kwargs, results), daemon=self.daemon)\n            p.start()\n\n            p.join(timeout=self.timeout_seconds)\n\n            
timeout_error = None\n            try:\n                if p.is_alive():\n                    # create an Exception\n                    timeout_error = torch.multiprocessing.TimeoutError(\n                        f\"'{obj.__name__}' in '{obj.__module__}' did not finish in {self.timeout_seconds}s.\"\n                    )\n                    if self.force_quit:\n                        p.terminate()\n                    else:\n                        warnings.warn(\n                            f\"TimedCall: deadline ({self.timeout_seconds}s) \"\n                            f\"reached but waiting for {obj.__name__} to finish.\"\n                        )\n            finally:\n                p.join()\n\n            _del_original_func(obj)\n            res = None\n            try:\n                res = results.get(block=False)\n            except queue.Empty:  # no result returned, took too long\n                pass\n            if isinstance(res, Exception):  # other errors from obj\n                if hasattr(res, \"traceback\"):\n                    raise RuntimeError(res.traceback) from res\n                raise res\n            if timeout_error:  # no force_quit finished\n                raise timeout_error\n            return res\n\n        return _wrapper\n\n\n_original_funcs = {}\n\n\ndef _cache_original_func(obj) -> None:\n    \"\"\"cache the original function by name, so that the decorator doesn't shadow it.\"\"\"\n    _original_funcs[obj.__name__] = obj\n\n\ndef _del_original_func(obj):\n    \"\"\"pop the original function from cache.\"\"\"\n    _original_funcs.pop(obj.__name__, None)\n    if torch.cuda.is_available():  # clean up the cached function\n        torch.cuda.synchronize()\n        torch.cuda.empty_cache()\n\n\ndef _call_original_func(name, module, *args, **kwargs):\n    if name not in _original_funcs:\n        _original_module = importlib.import_module(module)  # reimport, refresh _original_funcs\n        if not hasattr(_original_module, 
name):\n            # refresh module doesn't work\n            raise RuntimeError(f\"Could not recover the original {name} from {module}: {_original_funcs}.\")\n    f = _original_funcs[name]\n    return f(*args, **kwargs)\n\n\nclass NumpyImageTestCase2D(unittest.TestCase):\n    im_shape = (128, 64)\n    input_channels = 1\n    output_channels = 4\n    num_classes = 3\n\n    def setUp(self):\n        im, msk = create_test_image_2d(\n            self.im_shape[0], self.im_shape[1], num_objs=4, rad_max=20, noise_max=0.0, num_seg_classes=self.num_classes\n        )\n\n        self.imt = im[None, None]\n        self.seg1 = (msk[None, None] > 0).astype(np.float32)\n        self.segn = msk[None, None]\n\n\nclass TorchImageTestCase2D(NumpyImageTestCase2D):\n\n    def setUp(self):\n        NumpyImageTestCase2D.setUp(self)\n        self.imt = torch.tensor(self.imt)\n        self.seg1 = torch.tensor(self.seg1)\n        self.segn = torch.tensor(self.segn)\n\n\nclass NumpyImageTestCase3D(unittest.TestCase):\n    im_shape = (64, 48, 80)\n    input_channels = 1\n    output_channels = 4\n    num_classes = 3\n\n    def setUp(self):\n        im, msk = create_test_image_3d(\n            self.im_shape[0],\n            self.im_shape[1],\n            self.im_shape[2],\n            num_objs=4,\n            rad_max=20,\n            noise_max=0.0,\n            num_seg_classes=self.num_classes,\n        )\n\n        self.imt = im[None, None]\n        self.seg1 = (msk[None, None] > 0).astype(np.float32)\n        self.segn = msk[None, None]\n\n\nclass TorchImageTestCase3D(NumpyImageTestCase3D):\n\n    def setUp(self):\n        NumpyImageTestCase3D.setUp(self)\n        self.imt = torch.tensor(self.imt)\n        self.seg1 = torch.tensor(self.seg1)\n        self.segn = torch.tensor(self.segn)\n\n\ndef test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0):\n    \"\"\"\n    Test the ability to save `net` as a Torchscript object, reload it, and apply inference. 
The value `inputs` is\n    forward-passed through the original and loaded copy of the network and their results returned.\n    The forward pass for both is done without gradient accumulation.\n\n    The test will be performed with CUDA if available, else CPU.\n    \"\"\"\n    # TODO: would be nice to use GPU if available, but it currently causes CI failures.\n    device = \"cpu\"\n    with tempfile.TemporaryDirectory() as tempdir:\n        convert_to_torchscript(\n            model=net,\n            filename_or_obj=os.path.join(tempdir, \"model.ts\"),\n            verify=True,\n            inputs=inputs,\n            device=device,\n            rtol=rtol,\n            atol=atol,\n        )\n\n\ndef test_onnx_save(net, *inputs, device=None, rtol=1e-4, atol=0.0):\n    \"\"\"\n    Test the ability to save `net` in ONNX format, reload it and validate with runtime.\n    The value `inputs` is forward-passed through the `net` without gradient accumulation\n    to do onnx export and PyTorch inference.\n    PyTorch model inference is performed with CUDA if available, else CPU.\n    Saved ONNX model is validated with onnxruntime, if available, else ONNX native implementation.\n    \"\"\"\n    # TODO: would be nice to use GPU if available, but it currently causes CI failures.\n    device = \"cpu\"\n    _, has_onnxruntime = optional_import(\"onnxruntime\")\n    with tempfile.TemporaryDirectory() as tempdir:\n        convert_to_onnx(\n            model=net,\n            filename=os.path.join(tempdir, \"model.onnx\"),\n            verify=True,\n            inputs=inputs,\n            device=device,\n            use_ort=has_onnxruntime,\n            rtol=rtol,\n            atol=atol,\n        )\n\n\ndef download_url_or_skip_test(*args, **kwargs):\n    \"\"\"``download_url`` and skip the tests if any downloading error occurs.\"\"\"\n    with skip_if_downloading_fails():\n        download_url(*args, **kwargs)\n\n\ndef query_memory(n=2):\n    \"\"\"\n    Find best n idle devices and 
return a string of device ids using the `nvidia-smi` command.\n    \"\"\"\n    bash_string = \"nvidia-smi --query-gpu=power.draw,temperature.gpu,memory.used --format=csv,noheader,nounits\"\n\n    try:\n        print(f\"query memory with n={n}\")\n        p1 = Popen(bash_string.split(), stdout=PIPE)\n        output, error = p1.communicate()\n        free_memory = [x.split(\",\") for x in output.decode(\"utf-8\").split(\"\\n\")[:-1]]\n        free_memory = np.asarray(free_memory, dtype=float).T\n        free_memory[1] += free_memory[0]  # combine 0/1 column measures\n        ids = np.lexsort(free_memory)[:n]\n    except (TypeError, ValueError, IndexError, OSError):\n        ids = range(n) if isinstance(n, int) else []\n    return \",\".join(f\"{int(x)}\" for x in ids)\n\n\ndef test_local_inversion(invertible_xform, to_invert, im, dict_key=None):\n    \"\"\"test that invertible_xform can bring to_invert back to im\"\"\"\n    im_item = im if dict_key is None else im[dict_key]\n    if not isinstance(im_item, MetaTensor):\n        return\n    im_ref = copy.deepcopy(im)\n    im_inv = invertible_xform.inverse(to_invert)\n    if dict_key:\n        im_inv = im_inv[dict_key]\n        im_ref = im_ref[dict_key]\n    np.testing.assert_array_equal(im_inv.applied_operations, [])\n    assert_allclose(im_inv.shape, im_ref.shape)\n    assert_allclose(im_inv.affine, im_ref.affine, atol=1e-3, rtol=1e-3)\n\n\ndef command_line_tests(cmd, copy_env=True):\n    test_env = os.environ.copy() if copy_env else os.environ\n    print(f\"CUDA_VISIBLE_DEVICES in {__file__}\", test_env.get(\"CUDA_VISIBLE_DEVICES\"))\n    try:\n        normal_out = subprocess.run(cmd, env=test_env, check=True, capture_output=True)\n        print(repr(normal_out).replace(\"\\\\n\", \"\\n\").replace(\"\\\\t\", \"\\t\"))\n        return repr(normal_out)\n    except subprocess.CalledProcessError as e:\n        output = repr(e.stdout).replace(\"\\\\n\", \"\\n\").replace(\"\\\\t\", \"\\t\")\n        errors = 
repr(e.stderr).replace(\"\\\\n\", \"\\n\").replace(\"\\\\t\", \"\\t\")\n        raise RuntimeError(f\"subprocess call error {e.returncode}: {errors}, {output}\") from e\n\n\ndef equal_state_dict(st_1, st_2):\n    \"\"\"\n    assert equal state_dict (for the shared keys between st_1 and st_2).\n    \"\"\"\n    for key_st_1, val_st_1 in st_1.items():\n        if key_st_1 in st_2:\n            val_st_2 = st_2.get(key_st_1)\n            assert_allclose(val_st_1, val_st_2)\n\n\nTEST_TORCH_TENSORS: tuple = (torch.as_tensor,)\nif torch.cuda.is_available():\n    gpu_tensor: Callable = partial(torch.as_tensor, device=\"cuda\")\n    TEST_TORCH_TENSORS = TEST_TORCH_TENSORS + (gpu_tensor,)\n\nDEFAULT_TEST_AFFINE = torch.tensor(\n    [[2.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 1.0]]\n)\n_metatensor_creator = partial(MetaTensor, meta={\"a\": \"b\", \"affine\": DEFAULT_TEST_AFFINE})\nTEST_NDARRAYS_NO_META_TENSOR: tuple[Callable] = (np.array,) + TEST_TORCH_TENSORS\nTEST_NDARRAYS: tuple[Callable] = TEST_NDARRAYS_NO_META_TENSOR + (_metatensor_creator,)  # type: ignore\nTEST_TORCH_AND_META_TENSORS: tuple[Callable] = TEST_TORCH_TENSORS + (_metatensor_creator,)\n# alias for branch tests\nTEST_NDARRAYS_ALL = TEST_NDARRAYS\n\nTEST_DEVICES = [[torch.device(\"cpu\")]]\nif torch.cuda.is_available():\n    TEST_DEVICES.append([torch.device(\"cuda\")])\n\n\ndef dict_product(**items: Iterable[Any]) -> list[dict]:\n    \"\"\"Create cartesian product, equivalent to a nested for-loop, combinations of the items dict.\n\n    Args:\n        items: dict of items to be combined.\n\n    Returns:\n        list: list of dictionaries with the combinations of the input items.\n\n    Example:\n        >>> dict_product(x=[1, 2], y=[3, 4])\n        [{'x': 1, 'y': 3}, {'x': 1, 'y': 4}, {'x': 2, 'y': 3}, {'x': 2, 'y': 4}]\n    \"\"\"\n    keys = items.keys()\n    values = items.values()\n    prod_values = product(*values)\n    prod_dict = [dict(zip(keys, v)) for v in 
prod_values]\n    return prod_dict\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(prog=\"util\")\n    parser.add_argument(\"-c\", \"--count\", default=2, help=\"max number of gpus\")\n    args = parser.parse_args()\n    print(\"\\n\", query_memory(int(args.count)), sep=\"\\n\")  # print to stdout\n    sys.exit(0)\n"
  },
  {
    "path": "tests/testing_data/1D_BP_bwd.txt",
    "content": "0., 1., 1., 1., 1., 1., 1., 1., 1.,12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., # InterpolationType.nearest BoundType.replicate\n0., 1., 1., 1., 1., 1., 1., 1., 1.,12., # InterpolationType.nearest BoundType.replicate\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.replicate\n0., # InterpolationType.nearest BoundType.replicate\n0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,11.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.linear BoundType.replicate\n0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,11.5, # InterpolationType.linear BoundType.replicate\n1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.linear BoundType.replicate\n0., # InterpolationType.linear BoundType.replicate\n0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,11.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.quadratic BoundType.replicate\n0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,11.5, # InterpolationType.quadratic BoundType.replicate\n1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.quadratic BoundType.replicate\n0., # InterpolationType.quadratic BoundType.replicate\n0.5208333 , 0.9791666 , 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994,11.5 , 0.875 , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.875 , 0.125 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.cubic BoundType.replicate\n0.5208333 , 0.9791666 , 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994,11.5 , # InterpolationType.cubic BoundType.replicate\n0.875,1. ,1. ,1. ,1. ,1. ,1. ,1. ,0.875,0.125,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. 
, # InterpolationType.cubic BoundType.replicate\n0., # InterpolationType.cubic BoundType.replicate\n0.5416667 , 0.9583334 , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,11.5 , 0.8333334 , 1. , 1. , 1. , 1. , 1. , 0.9999999 , 1. , 0.833333 , 0.16666651, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.fourth BoundType.replicate\n0.5416667, 0.9583334, 1. , 1. , 1. , 1. , 1. , 1. , 1. ,11.5 , # InterpolationType.fourth BoundType.replicate\n0.8333334 ,1. ,1. ,1. ,1. ,1. ,0.9999999 ,1. ,0.833333 ,0.16666651,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. , # InterpolationType.fourth BoundType.replicate\n0., # InterpolationType.fourth BoundType.replicate\n5.6223959e-01,9.3802083e-01,9.9973959e-01,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.1499999e+01,7.9947913e-01,9.9739581e-01,1.0000000e+00,1.0000000e+00,9.9999994e-01,1.0000001e+00,9.9999976e-01,9.9739575e-01,7.9947948e-01,2.0052099e-01,2.6040077e-03,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.fifth BoundType.replicate\n0.5622396, 0.9380208, 0.9997396, 1. , 1. , 1. , 1. , 1. , 1. ,11.499999 , # InterpolationType.fifth BoundType.replicate\n0.7994791 ,0.9973958 ,1. ,1. ,0.99999994,1.0000001 ,0.99999976,0.99739575,0.7994795 ,0.20052099,0.00260401,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. ,0. , # InterpolationType.fifth BoundType.replicate\n0., # InterpolationType.fifth BoundType.replicate\n5.8194447e-01,9.1944444e-01,9.9861109e-01,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.1499997e+01,7.7499998e-01,9.9166673e-01,1.0000000e+00,1.0000000e+00,1.0000000e+00,9.9999982e-01,1.0000004e+00,9.9166673e-01,7.7499980e-01,2.2500010e-01,8.3333999e-03,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07, # InterpolationType.sixth BoundType.replicate\n0.58194447, 0.91944444, 0.9986111 , 1. , 1. , 1. , 1. , 1. , 1. 
,11.499997 , # InterpolationType.sixth BoundType.replicate\n7.7499998e-01,9.9166673e-01,1.0000000e+00,1.0000000e+00,1.0000000e+00,9.9999982e-01,1.0000004e+00,9.9166673e-01,7.7499980e-01,2.2500010e-01,8.3333999e-03,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07,1.9371510e-07, # InterpolationType.sixth BoundType.replicate\n0., # InterpolationType.sixth BoundType.replicate\n6.0078436e-01,9.0259641e-01,9.9662077e-01,9.9999845e-01,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.0000000e+00,1.1500004e+01,7.5551212e-01,9.8430985e-01,9.9997836e-01,9.9999994e-01,1.0000000e+00,1.0000001e+00,9.9997842e-01,9.8431003e-01,7.5551212e-01,2.4448761e-01,1.5690181e-02,2.1788481e-05,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07, # InterpolationType.seventh BoundType.replicate\n0.60078436, 0.9025964 , 0.9966208 , 0.99999845, 1. , 1. , 1. , 1. , 1. ,11.500004 , # InterpolationType.seventh BoundType.replicate\n7.5551212e-01,9.8430985e-01,9.9997836e-01,9.9999994e-01,1.0000000e+00,1.0000001e+00,9.9997842e-01,9.8431003e-01,7.5551212e-01,2.4448761e-01,1.5690181e-02,2.1788481e-05,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07,3.3080869e-07, # InterpolationType.seventh BoundType.replicate\n0., # InterpolationType.seventh BoundType.replicate\n1.,3.,3.,2.,2.,2.,2.,2.,2.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dct1\n1.,3.,3.,2.,2.,2.,2.,2.,2.,1., # InterpolationType.nearest BoundType.dct1\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dct1\n0., # InterpolationType.nearest BoundType.dct1\n1.5, 3. , 2.5, 2. , 2. , 2. , 2. , 2. , 2. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. , 1. , 1. , # InterpolationType.linear BoundType.dct1\n1.5,3. ,2.5,2. ,2. 
,2. ,2. ,2. ,2. ,1. , # InterpolationType.linear BoundType.dct1\n1., 1., 1., 1., 1., 1., 1., 1., 1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1., 1., 1., # InterpolationType.linear BoundType.dct1\n0., # InterpolationType.linear BoundType.dct1\n1.5, 3. , 2.5, 2. , 2. , 2. , 2. , 2. , 2. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. , 1. , 1. , # InterpolationType.quadratic BoundType.dct1\n1.5,3. ,2.5,2. ,2. ,2. ,2. ,2. ,2. ,1. , # InterpolationType.quadratic BoundType.dct1\n1., 1., 1., 1., 1., 1., 1., 1., 1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1., 1., 1., # InterpolationType.quadratic BoundType.dct1\n0., # InterpolationType.quadratic BoundType.dct1\n1.5 , 2.9791667 , 2.5 , 2.0208333 , 1.9999999 , 1.9999999 , 1.9999999 , 1.9999999 , 1.9999999 , 0.99999994, 0.75 , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.75 ,-0.75 ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-0.75 , 0.75 , 1. , # InterpolationType.cubic BoundType.dct1\n1.5 ,2.9791667 ,2.5 ,2.0208333 ,1.9999999 ,1.9999999 ,1.9999999 ,1.9999999 ,1.9999999 ,0.99999994, # InterpolationType.cubic BoundType.dct1\n0.75, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.75,-0.75,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-0.75, 0.75, 1. , # InterpolationType.cubic BoundType.dct1\n0., # InterpolationType.cubic BoundType.dct1\n1.5 , 2.9583333 , 2.5 , 2.0416667 , 2. , 2. , 2. , 2. , 2. , 1. , 0.6666666 , 1. , 1. , 1. , 1. , 1. , 0.9999999 , 1. , 0.6666664 ,-0.66666675,-1. ,-1.0000001 ,-1.0000002 ,-1. ,-1.0000001 ,-1.0000001 ,-1. ,-0.6666667 , 0.6666666 , 1. , # InterpolationType.fourth BoundType.dct1\n1.5 ,2.9583333,2.5 ,2.0416667,2. ,2. ,2. ,2. ,2. ,1. , # InterpolationType.fourth BoundType.dct1\n0.6666666 , 1. , 1. , 1. , 1. , 1. , 0.9999999 , 1. , 0.6666664 ,-0.66666675,-1. ,-1.0000001 ,-1.0000002 ,-1. ,-1.0000001 ,-1.0000001 ,-1. ,-0.6666667 , 0.6666666 , 1. , # InterpolationType.fourth BoundType.dct1\n0., # InterpolationType.fourth BoundType.dct1\n1.4997395 , 2.9380207 , 2.5 , 2.061979 , 2.0002604 , 2. , 2. , 2. , 2. 
, 1. , 0.5989583 , 0.9947917 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.99479157, 0.5989587 ,-0.59895825,-0.9947917 ,-0.9999998 ,-1.0000002 ,-1. ,-0.9999998 ,-1. ,-0.9947917 ,-0.5989583 , 0.5989583 , 0.9947917 , # InterpolationType.fifth BoundType.dct1\n1.4997395,2.9380207,2.5 ,2.061979 ,2.0002604,2. ,2. ,2. ,2. ,1. , # InterpolationType.fifth BoundType.dct1\n0.5989583 , 0.9947917 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.99479157, 0.5989587 ,-0.59895825,-0.9947917 ,-0.9999998 ,-1.0000002 ,-1. ,-0.9999998 ,-1. ,-0.9947917 ,-0.5989583 , 0.5989583 , 0.9947917 , # InterpolationType.fifth BoundType.dct1\n0., # InterpolationType.fifth BoundType.dct1\n1.498611 , 2.919444 , 2.5 , 2.0805554 , 2.0013888 , 2. , 2. , 2. , 2. , 1. , 0.54999995, 0.9833334 , 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.9833334 , 0.5499998 ,-0.5499999 ,-0.9833334 ,-1.0000004 ,-1.0000001 ,-1.0000001 ,-1. ,-0.99999994,-0.98333335,-0.55 , 0.54999995, 0.9833334 , # InterpolationType.sixth BoundType.dct1\n1.498611 ,2.919444 ,2.5 ,2.0805554,2.0013888,2. ,2. ,2. ,2. ,1. , # InterpolationType.sixth BoundType.dct1\n0.54999995, 0.9833334 , 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.9833334 , 0.5499998 ,-0.5499999 ,-0.9833334 ,-1.0000004 ,-1.0000001 ,-1.0000001 ,-1. ,-0.99999994,-0.98333335,-0.55 , 0.54999995, 0.9833334 , # InterpolationType.sixth BoundType.dct1\n0., # InterpolationType.sixth BoundType.dct1\n1.4966209 , 2.9025953 , 2.5000002 , 2.097404 , 2.0033796 , 2.000002 , 2.0000002 , 2.0000002 , 2.0000002 , 1. , 0.5110243 , 0.9686197 , 0.99995667, 0.99999994, 1. , 1.0000001 , 0.9999567 , 0.96861994, 0.51102436,-0.5110245 ,-0.9686197 ,-0.99995685,-1. ,-1. ,-1.0000001 ,-0.99995655,-0.9686198 ,-0.5110243 , 0.5110243 , 0.9686197 , # InterpolationType.seventh BoundType.dct1\n1.4966209,2.9025953,2.5000002,2.097404 ,2.0033796,2.000002 ,2.0000002,2.0000002,2.0000002,1. , # InterpolationType.seventh BoundType.dct1\n0.5110243 , 0.9686197 , 0.99995667, 0.99999994, 1. 
, 1.0000001 , 0.9999567 , 0.96861994, 0.51102436,-0.5110245 ,-0.9686197 ,-0.99995685,-1. ,-1. ,-1.0000001 ,-0.99995655,-0.9686198 ,-0.5110243 , 0.5110243 , 0.9686197 , # InterpolationType.seventh BoundType.dct1\n0., # InterpolationType.seventh BoundType.dct1\n2.,2.,2.,2.,2.,2.,2.,2.,2.,2.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dct2\n2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.nearest BoundType.dct2\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dct2\n0., # InterpolationType.nearest BoundType.dct2\n2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1., 0., # InterpolationType.linear BoundType.dct2\n2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.linear BoundType.dct2\n1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1., 0., # InterpolationType.linear BoundType.dct2\n0., # InterpolationType.linear BoundType.dct2\n2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1., 0., # InterpolationType.quadratic BoundType.dct2\n2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.quadratic BoundType.dct2\n1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1.,-1., 0., # InterpolationType.quadratic BoundType.dct2\n0., # InterpolationType.quadratic BoundType.dct2\n1.9999999, 2. , 1.9999999, 1.9999999, 1.9999999, 1.9999999, 1.9999999, 1.9999999, 1.9999999, 2. , 0.875 , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.875 , 0. ,-0.875 ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-0.875 , 0. , # InterpolationType.cubic BoundType.dct2\n1.9999999,2. ,1.9999999,1.9999999,1.9999999,1.9999999,1.9999999,1.9999999,1.9999999,2. , # InterpolationType.cubic BoundType.dct2\n0.875, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.875, 0. ,-0.875,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-1. ,-0.875, 0. 
, # InterpolationType.cubic BoundType.dct2\n0., # InterpolationType.cubic BoundType.dct2\n2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 8.3333337e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999988e-01, 1.0000000e+00, 8.3333302e-01,-1.1920929e-07,-8.3333325e-01,-1.0000000e+00,-1.0000001e+00,-1.0000002e+00,-1.0000000e+00,-1.0000001e+00,-1.0000001e+00,-1.0000000e+00,-8.3333337e-01, 0., # InterpolationType.fourth BoundType.dct2\n2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.fourth BoundType.dct2\n8.3333337e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999988e-01, 1.0000000e+00, 8.3333302e-01,-1.1920929e-07,-8.3333325e-01,-1.0000000e+00,-1.0000001e+00,-1.0000002e+00,-1.0000000e+00,-1.0000001e+00,-1.0000001e+00,-1.0000000e+00,-8.3333337e-01, 0., # InterpolationType.fourth BoundType.dct2\n0., # InterpolationType.fourth BoundType.dct2\n2.0000000e+00, 1.9999999e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 7.9687500e-01, 9.9739581e-01, 1.0000000e+00, 1.0000000e+00, 9.9999994e-01, 1.0000001e+00, 9.9999976e-01, 9.9739575e-01, 7.9687530e-01, 1.6018748e-07,-7.9687524e-01,-9.9739569e-01,-9.9999982e-01,-1.0000002e+00,-1.0000000e+00,-9.9999982e-01,-1.0000000e+00,-9.9739587e-01,-7.9687494e-01, 5.1222742e-09, # InterpolationType.fifth BoundType.dct2\n2. ,1.9999999,2. ,2. ,2. ,2. ,2. ,2. ,2. ,2. 
, # InterpolationType.fifth BoundType.dct2\n7.9687500e-01, 9.9739581e-01, 1.0000000e+00, 1.0000000e+00, 9.9999994e-01, 1.0000001e+00, 9.9999976e-01, 9.9739575e-01, 7.9687530e-01, 1.6018748e-07,-7.9687524e-01,-9.9739569e-01,-9.9999982e-01,-1.0000002e+00,-1.0000000e+00,-9.9999982e-01,-1.0000000e+00,-9.9739587e-01,-7.9687494e-01, 5.1222742e-09, # InterpolationType.fifth BoundType.dct2\n0., # InterpolationType.fifth BoundType.dct2\n2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 2.0000000e+00, 7.6666665e-01, 9.9166673e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999982e-01, 1.0000004e+00, 9.9166673e-01, 7.6666647e-01, 5.9604645e-08,-7.6666659e-01,-9.9166662e-01,-1.0000004e+00,-1.0000001e+00,-1.0000001e+00,-1.0000000e+00,-9.9999994e-01,-9.9166667e-01,-7.6666665e-01, 1.8626451e-09, # InterpolationType.sixth BoundType.dct2\n2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.sixth BoundType.dct2\n7.6666665e-01, 9.9166673e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999982e-01, 1.0000004e+00, 9.9166673e-01, 7.6666647e-01, 5.9604645e-08,-7.6666659e-01,-9.9166662e-01,-1.0000004e+00,-1.0000001e+00,-1.0000001e+00,-1.0000000e+00,-9.9999994e-01,-9.9166667e-01,-7.6666665e-01, 1.8626451e-09, # InterpolationType.sixth BoundType.dct2\n0., # InterpolationType.sixth BoundType.dct2\n2.0000002e+00, 2.0000000e+00, 2.0000000e+00, 2.0000002e+00, 2.0000002e+00, 2.0000002e+00, 2.0000002e+00, 2.0000002e+00, 2.0000002e+00, 2.0000002e+00, 7.3982203e-01, 9.8428816e-01, 9.9997836e-01, 9.9999994e-01, 1.0000000e+00, 1.0000001e+00, 9.9997842e-01, 9.8428833e-01, 7.3982203e-01,-1.6936974e-07,-7.3982191e-01,-9.8428810e-01,-9.9997830e-01,-1.0000000e+00,-1.0000000e+00,-1.0000001e+00,-9.9997824e-01,-9.8428822e-01,-7.3982203e-01,-2.7284841e-09, # InterpolationType.seventh BoundType.dct2\n2.0000002,2. ,2. 
,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002, # InterpolationType.seventh BoundType.dct2\n7.3982203e-01, 9.8428816e-01, 9.9997836e-01, 9.9999994e-01, 1.0000000e+00, 1.0000001e+00, 9.9997842e-01, 9.8428833e-01, 7.3982203e-01,-1.6936974e-07,-7.3982191e-01,-9.8428810e-01,-9.9997830e-01,-1.0000000e+00,-1.0000000e+00,-1.0000001e+00,-9.9997824e-01,-9.8428822e-01,-7.3982203e-01,-2.7284841e-09, # InterpolationType.seventh BoundType.dct2\n0., # InterpolationType.seventh BoundType.dct2\n-1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., # InterpolationType.nearest BoundType.dst1\n-1., 0., 0., 0., 0., 0., 0., 0., 0., 0., # InterpolationType.nearest BoundType.dst1\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dst1\n0., # InterpolationType.nearest BoundType.dst1\n0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1., # InterpolationType.linear BoundType.dst1\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.linear BoundType.dst1\n1., 1., 1., 1., 1., 1., 1., 1., 1.,-9.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1., # InterpolationType.linear BoundType.dst1\n0., # InterpolationType.linear BoundType.dst1\n0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1., # InterpolationType.quadratic BoundType.dst1\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.quadratic BoundType.dst1\n1., 1., 1., 1., 1., 1., 1., 1., 1.,-9.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1., # InterpolationType.quadratic BoundType.dst1\n0., # InterpolationType.quadratic BoundType.dst1\n0., 0.,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08, 8.7500000e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 
1.0000000e+00,-2.5000000e-01,-7.7500000e+00,-7.7500000e+00,-2.5000000e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 8.7500000e-01, # InterpolationType.cubic BoundType.dst1\n0., 0.,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08, # InterpolationType.cubic BoundType.dst1\n0.875, 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25 ,-7.75 ,-7.75 ,-0.25 , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.875, # InterpolationType.cubic BoundType.dst1\n0., # InterpolationType.cubic BoundType.dst1\n0., 0.,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-1.1175871e-08, 8.3333337e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999988e-01, 1.0000000e+00,-6.6666698e-01,-7.3333335e+00,-7.3333335e+00,-6.6666675e-01, 1.0000000e+00, 1.0000001e+00, 1.0000002e+00, 1.0000000e+00, 1.0000001e+00, 1.0000001e+00, 1.0000000e+00, 8.3333337e-01, # InterpolationType.fourth BoundType.dst1\n0., 0.,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-1.1175871e-08, # InterpolationType.fourth BoundType.dst1\n0.8333334 , 1. , 1. , 1. , 1. , 1. , 0.9999999 , 1. ,-0.666667 ,-7.3333335 ,-7.3333335 ,-0.66666675, 1. , 1.0000001 , 1.0000002 , 1. , 1.0000001 , 1.0000001 , 1. 
, 0.8333334 , # InterpolationType.fourth BoundType.dst1\n0., # InterpolationType.fourth BoundType.dst1\n3.9872248e-09, 0., 1.1175871e-08, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.9947913e-01, 9.9739581e-01, 1.0000000e+00, 1.0000000e+00, 9.9999994e-01, 1.0000001e+00, 9.9999976e-01, 9.7395825e-01,-1.0052080e+00,-6.9687500e+00,-6.9687500e+00,-1.0052083e+00, 9.7395819e-01, 9.9999982e-01, 1.0000002e+00, 1.0000000e+00, 9.9999982e-01, 1.0000000e+00, 9.9739587e-01, 7.9947913e-01, # InterpolationType.fifth BoundType.dst1\n3.9872248e-09,0.,1.1175871e-08,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09, # InterpolationType.fifth BoundType.dst1\n0.7994791 , 0.9973958 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-1.005208 ,-6.96875 ,-6.96875 ,-1.0052083 , 0.9739582 , 0.9999998 , 1.0000002 , 1. , 0.9999998 , 1. , 0.9973959 , 0.7994791 , # InterpolationType.fifth BoundType.dst1\n0., # InterpolationType.fifth BoundType.dst1\n4.1094609e-08, 0.,-1.4901161e-08, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09,-2.6193447e-08, 7.7499998e-01, 9.9166673e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999982e-01, 1.0000004e+00, 9.1666675e-01,-1.2500002e+00,-6.6666665e+00,-6.6666665e+00,-1.2500000e+00, 9.1666681e-01, 1.0000004e+00, 1.0000001e+00, 1.0000001e+00, 1.0000000e+00, 9.9999994e-01, 9.9166667e-01, 7.7499998e-01, # InterpolationType.sixth BoundType.dst1\n4.1094609e-08, 0.,-1.4901161e-08, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09,-2.6193447e-08, # InterpolationType.sixth BoundType.dst1\n0.775 , 0.99166673, 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.2500002 ,-6.6666665 ,-6.6666665 ,-1.25 , 0.9166668 , 1.0000004 , 1.0000001 , 1.0000001 , 1. 
, 0.99999994, 0.9916667 , 0.775 , # InterpolationType.sixth BoundType.dst1\n0., # InterpolationType.sixth BoundType.dst1\n-9.7788870e-09, 3.7846348e-10,-7.4505806e-09, 2.3283064e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 7.5553381e-01, 9.8430985e-01, 9.9997836e-01, 9.9999994e-01, 1.0000000e+00, 1.0000001e+00, 9.9978310e-01, 8.4309906e-01,-1.4446614e+00,-6.3982205e+00,-6.3982205e+00,-1.4446614e+00, 8.4309900e-01, 9.9978304e-01, 1.0000000e+00, 1.0000000e+00, 1.0000001e+00, 9.9997824e-01, 9.8430991e-01, 7.5553381e-01, # InterpolationType.seventh BoundType.dst1\n-9.7788870e-09, 3.7846348e-10,-7.4505806e-09, 2.3283064e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, # InterpolationType.seventh BoundType.dst1\n0.7555338 , 0.98430985, 0.99997836, 0.99999994, 1. , 1.0000001 , 0.9997831 , 0.84309906,-1.4446614 ,-6.3982205 ,-6.3982205 ,-1.4446614 , 0.843099 , 0.99978304, 1. , 1. , 1.0000001 , 0.99997824, 0.9843099 , 0.7555338 , # InterpolationType.seventh BoundType.dst1\n0., # InterpolationType.seventh BoundType.dst1\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dst2\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dst2\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dst2\n0., # InterpolationType.nearest BoundType.dst2\n 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-18., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., # InterpolationType.linear BoundType.dst2\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.linear BoundType.dst2\n 1., 1., 1., 1., 1., 1., 1., 1., 1.,-18., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., # InterpolationType.linear BoundType.dst2\n0., # InterpolationType.linear BoundType.dst2\n 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-18., 1., 1., 1., 1., 1., 
1., 1., 1., 1., 0., # InterpolationType.quadratic BoundType.dst2\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.quadratic BoundType.dst2\n 1., 1., 1., 1., 1., 1., 1., 1., 1.,-18., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., # InterpolationType.quadratic BoundType.dst2\n0., # InterpolationType.quadratic BoundType.dst2\n0., 0.,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08, 9.3132257e-09, 8.7500000e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,-1.3750000e+00,-1.3250000e+01,-1.3750000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 8.7500000e-01, 2.5000000e-01, # InterpolationType.cubic BoundType.dst2\n0., 0.,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08,-2.0489097e-08, 9.3132257e-09, # InterpolationType.cubic BoundType.dst2\n 0.875, 1. , 1. , 1. , 1. , 1. , 1. , 1. , -1.375,-13.25 , -1.375, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.875, 0.25 , # InterpolationType.cubic BoundType.dst2\n0., # InterpolationType.cubic BoundType.dst2\n0., 0.,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-1.1175871e-08, 8.3333337e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999988e-01, 1.0000000e+00,-2.1666670e+00,-1.1666667e+01,-2.1666667e+00, 1.0000000e+00, 1.0000001e+00, 1.0000002e+00, 1.0000000e+00, 1.0000001e+00, 1.0000001e+00, 1.0000000e+00, 8.3333337e-01, 3.3333334e-01, # InterpolationType.fourth BoundType.dst2\n0., 0.,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-4.0978193e-08,-1.1175871e-08, # InterpolationType.fourth BoundType.dst2\n 0.8333334 , 1. , 1. , 1. , 1. , 1. , 0.9999999 , 1. , -2.166667 ,-11.666667 , -2.1666667 , 1. , 1.0000001 , 1.0000002 , 1. , 1.0000001 , 1.0000001 , 1. 
, 0.8333334 , 0.33333334, # InterpolationType.fourth BoundType.dst2\n0., # InterpolationType.fourth BoundType.dst2\n0., 3.7252903e-09, 1.1175871e-08, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 7.1886461e-09, 1.0913936e-08, 8.0208331e-01, 9.9739581e-01, 1.0000000e+00, 1.0000000e+00, 9.9999994e-01, 1.0000001e+00, 9.9999976e-01, 9.5052075e-01,-2.7604163e+00,-1.0380208e+01,-2.7604165e+00, 9.5052069e-01, 9.9999982e-01, 1.0000002e+00, 1.0000000e+00, 9.9999982e-01, 1.0000000e+00, 9.9739587e-01, 8.0208331e-01, 4.0104166e-01, # InterpolationType.fifth BoundType.dst2\n0.,3.7252903e-09,1.1175871e-08,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09,7.1886461e-09,1.0913936e-08, # InterpolationType.fifth BoundType.dst2\n 0.8020833 , 0.9973958 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.95052075, -2.7604163 ,-10.380208 , -2.7604165 , 0.9505207 , 0.9999998 , 1.0000002 , 1. , 0.9999998 , 1. , 0.9973959 , 0.8020833 , 0.40104166, # InterpolationType.fifth BoundType.dst2\n0., # InterpolationType.fifth BoundType.dst2\n5.9604645e-08,-1.4901161e-08,-1.4901161e-08, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09,-1.1292286e-08, 7.8333330e-01, 9.9166673e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999982e-01, 1.0000004e+00, 8.4166676e-01,-3.1166668e+00,-9.4499998e+00,-3.1166666e+00, 8.4166652e-01, 1.0000004e+00, 1.0000001e+00, 1.0000001e+00, 1.0000000e+00, 9.9999994e-01, 9.9166667e-01, 7.8333330e-01, 4.5000002e-01, # InterpolationType.sixth BoundType.dst2\n5.9604645e-08,-1.4901161e-08,-1.4901161e-08, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09, 3.6088750e-09,-1.1292286e-08, # InterpolationType.sixth BoundType.dst2\n0.7833333 , 0.99166673, 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.84166676,-3.1166668 ,-9.45 ,-3.1166666 , 0.8416665 , 1.0000004 , 1.0000001 , 1.0000001 , 1. 
, 0.99999994, 0.9916667 , 0.7833333 , 0.45000002, # InterpolationType.sixth BoundType.dst2\n0., # InterpolationType.sixth BoundType.dst2\n0.,-7.4505806e-09,-6.9849193e-09, 2.3283064e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09,-5.0350764e-09, 7.7120221e-01, 9.8433155e-01, 9.9997836e-01, 9.9999994e-01, 1.0000000e+00, 1.0000001e+00, 9.9958777e-01, 7.0230043e-01,-3.3471570e+00,-8.7094622e+00,-3.3471570e+00, 7.0230043e-01, 9.9958777e-01, 1.0000000e+00, 1.0000000e+00, 1.0000001e+00, 9.9997824e-01, 9.8433161e-01, 7.7120221e-01, 4.8897570e-01, # InterpolationType.seventh BoundType.dst2\n0.,-7.4505806e-09,-6.9849193e-09, 2.3283064e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09, 1.9498430e-09,-5.0350764e-09, # InterpolationType.seventh BoundType.dst2\n0.7712022 , 0.98433155, 0.99997836, 0.99999994, 1. , 1.0000001 , 0.9995878 , 0.7023004 ,-3.347157 ,-8.709462 ,-3.347157 , 0.7023004 , 0.9995878 , 1. , 1. , 1.0000001 , 0.99997824, 0.9843316 , 0.7712022 , 0.4889757 , # InterpolationType.seventh BoundType.dst2\n0., # InterpolationType.seventh BoundType.dst2\n2.,2.,2.,2.,2.,2.,2.,2.,2.,2.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dft\n2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.nearest BoundType.dft\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.dft\n0., # InterpolationType.nearest BoundType.dft\n2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., # InterpolationType.linear BoundType.dft\n2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.linear BoundType.dft\n1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., # InterpolationType.linear BoundType.dft\n0., # InterpolationType.linear BoundType.dft\n2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., # 
InterpolationType.quadratic BoundType.dft\n2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.quadratic BoundType.dft\n1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., 1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., # InterpolationType.quadratic BoundType.dft\n0., # InterpolationType.quadratic BoundType.dft\n2. , 2. , 1.9999999, 1.9999999, 1.9999999, 1.9999999, 1.9999999, 1.9999999, 1.9999999, 2. ,-0.25 , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25 ,-6.5 ,-0.25 , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25 ,-6.5 , # InterpolationType.cubic BoundType.dft\n2. ,2. ,1.9999999,1.9999999,1.9999999,1.9999999,1.9999999,1.9999999,1.9999999,2. , # InterpolationType.cubic BoundType.dft\n-0.25, 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25,-6.5 ,-0.25, 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25,-6.5 , # InterpolationType.cubic BoundType.dft\n0., # InterpolationType.cubic BoundType.dft\n2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. ,-0.6666666, 1. , 1. , 1. , 1. , 1. , 0.9999999, 1. ,-0.666667 ,-5.666667 ,-0.6666666, 1. , 1. , 1. , 1. , 1. , 0.9999999, 1. ,-0.666667 ,-5.666667 , # InterpolationType.fourth BoundType.dft\n2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.fourth BoundType.dft\n-0.6666666, 1. , 1. , 1. , 1. , 1. , 0.9999999, 1. ,-0.666667 ,-5.666667 ,-0.6666666, 1. , 1. , 1. , 1. , 1. , 0.9999999, 1. ,-0.666667 ,-5.666667 , # InterpolationType.fourth BoundType.dft\n0., # InterpolationType.fourth BoundType.dft\n2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. ,-0.97916675, 0.9739583 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-0.9791663 ,-4.989583 ,-0.97916675, 0.9739583 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-0.9791663 ,-4.989583 , # InterpolationType.fifth BoundType.dft\n2.,2.,2.,2.,2.,2.,2.,2.,2.,2., # InterpolationType.fifth BoundType.dft\n-0.97916675, 0.9739583 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-0.9791663 ,-4.989583 ,-0.97916675, 0.9739583 , 1. , 1. 
, 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-0.9791663 ,-4.989583 , # InterpolationType.fifth BoundType.dft\n0., # InterpolationType.fifth BoundType.dft\n1.9999999 , 2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. , 2. ,-1.1666667 , 0.9166667 , 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.1666669 ,-4.4999995 ,-1.1666667 , 0.9166667 , 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.1666669 ,-4.4999995 , # InterpolationType.sixth BoundType.dft\n1.9999999,2. ,2. ,2. ,2. ,2. ,2. ,2. ,2. ,2. , # InterpolationType.sixth BoundType.dft\n-1.1666667 , 0.9166667 , 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.1666669 ,-4.4999995 ,-1.1666667 , 0.9166667 , 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.1666669 ,-4.4999995 , # InterpolationType.sixth BoundType.dft\n0., # InterpolationType.sixth BoundType.dft\n2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 , 2.0000002 ,-1.2879773 , 0.84331596, 0.999783 , 0.99999994, 1. , 1.0000001 , 0.9997831 , 0.8433161 ,-1.2879775 ,-4.110243 ,-1.2879773 , 0.84331596, 0.999783 , 0.99999994, 1. , 1.0000001 , 0.9997831 , 0.8433161 ,-1.2879775 ,-4.110243 , # InterpolationType.seventh BoundType.dft\n2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002,2.0000002, # InterpolationType.seventh BoundType.dft\n-1.2879773 , 0.84331596, 0.999783 , 0.99999994, 1. , 1.0000001 , 0.9997831 , 0.8433161 ,-1.2879775 ,-4.110243 ,-1.2879773 , 0.84331596, 0.999783 , 0.99999994, 1. 
, 1.0000001 , 0.9997831 , 0.8433161 ,-1.2879775 ,-4.110243 , # InterpolationType.seventh BoundType.dft\n0., # InterpolationType.seventh BoundType.dft\n0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.zero\n0.,1.,1.,1.,1.,1.,1.,1.,1.,1., # InterpolationType.nearest BoundType.zero\n0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0.,0., # InterpolationType.nearest BoundType.zero\n0., # InterpolationType.nearest BoundType.zero\n0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-9. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.linear BoundType.zero\n0.5,1. ,1. ,1. ,1. ,1. ,1. ,1. ,1. ,1. , # InterpolationType.linear BoundType.zero\n1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., # InterpolationType.linear BoundType.zero\n0., # InterpolationType.linear BoundType.zero\n0.5, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-9. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.quadratic BoundType.zero\n0.5,1. ,1. ,1. ,1. ,1. ,1. ,1. ,1. ,1. , # InterpolationType.quadratic BoundType.zero\n1., 1., 1., 1., 1., 1., 1., 1., 1.,-9., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., # InterpolationType.quadratic BoundType.zero\n0., # InterpolationType.quadratic BoundType.zero\n0.5 , 0.9791666 , 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.875 , 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25 ,-6.625 ,-1.125 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.cubic BoundType.zero\n0.5 ,0.9791666 ,0.99999994,0.99999994,0.99999994,0.99999994,0.99999994,0.99999994,0.99999994,0.99999994, # InterpolationType.cubic BoundType.zero\n0.875, 1. , 1. , 1. , 1. , 1. , 1. , 1. ,-0.25 ,-6.625,-1.125, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. 
, # InterpolationType.cubic BoundType.zero\n0., # InterpolationType.cubic BoundType.zero\n0.5 , 0.9583334, 1. , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.8333334, 1. , 1. , 1. , 1. , 1. , 0.9999999, 1. ,-0.666667 ,-5.8333335,-1.5 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.fourth BoundType.zero\n0.5 ,0.9583334,1. ,1. ,1. ,1. ,1. ,1. ,1. ,1. , # InterpolationType.fourth BoundType.zero\n0.8333334, 1. , 1. , 1. , 1. , 1. , 0.9999999, 1. ,-0.666667 ,-5.8333335,-1.5 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.fourth BoundType.zero\n0., # InterpolationType.fourth BoundType.zero\n0.5 , 0.9380208 , 0.9997396 , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.7994791 , 0.9973958 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-0.9817705 ,-5.190104 ,-1.7786459 ,-0.0234375 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.fifth BoundType.zero\n0.5 ,0.9380208,0.9997396,1. ,1. ,1. ,1. ,1. ,1. ,1. , # InterpolationType.fifth BoundType.zero\n0.7994791 , 0.9973958 , 1. , 1. , 0.99999994, 1.0000001 , 0.99999976, 0.97395825,-0.9817705 ,-5.190104 ,-1.7786459 ,-0.0234375 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.fifth BoundType.zero\n0., # InterpolationType.fifth BoundType.zero\n0.49999997, 0.91944444, 0.9986111 , 1. , 1. , 1. , 1. , 1. , 1. , 1. , 0.775 , 0.99166673, 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.1750002 ,-4.725 ,-1.9416667 ,-0.075 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , # InterpolationType.sixth BoundType.zero\n0.49999997,0.91944444,0.9986111 ,1. ,1. ,1. ,1. ,1. ,1. ,1. , # InterpolationType.sixth BoundType.zero\n0.775 , 0.99166673, 1. , 1. , 1. , 0.9999998 , 1.0000004 , 0.91666675,-1.1750002 ,-4.725 ,-1.9416667 ,-0.075 , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. 
, # InterpolationType.sixth BoundType.zero\n0., # InterpolationType.sixth BoundType.zero\n5.0000000e-01, 9.0259641e-01, 9.9662077e-01, 9.9999845e-01, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 7.5551212e-01, 9.8430985e-01, 9.9997836e-01, 9.9999994e-01, 1.0000000e+00, 1.0000001e+00, 9.9978310e-01, 8.4329438e-01,-1.3036675e+00,-4.3547311e+00,-2.0434895e+00,-1.4099392e-01,-1.9531250e-04, 0., 0., 0., 0., 0., 0., 0., # InterpolationType.seventh BoundType.zero\n0.5 ,0.9025964 ,0.9966208 ,0.99999845,1. ,1. ,1. ,1. ,1. ,1. , # InterpolationType.seventh BoundType.zero\n7.5551212e-01, 9.8430985e-01, 9.9997836e-01, 9.9999994e-01, 1.0000000e+00, 1.0000001e+00, 9.9978310e-01, 8.4329438e-01,-1.3036675e+00,-4.3547311e+00,-2.0434895e+00,-1.4099392e-01,-1.9531250e-04, 0., 0., 0., 0., 0., 0., 0., # InterpolationType.seventh BoundType.zero\n0., # InterpolationType.seventh BoundType.zero\n"
  },
  {
    "path": "tests/testing_data/1D_BP_fwd.txt",
    "content": "1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.nearest BoundType.replicate\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.linear BoundType.replicate\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.quadratic BoundType.replicate\n0.5208, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.4792, 8.9792, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.cubic BoundType.replicate\n0.5417, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.4583, 8.9583, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.fourth BoundType.replicate\n0.5622, 1.5003, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4997, 8.4378, 8.9378, 8.9997, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.fifth BoundType.replicate\n0.5819, 1.5014, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4986, 8.4181, 8.9181, 8.9986, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.sixth BoundType.replicate\n0.6008, 1.5034, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4966, 8.3992, 8.8992, 8.9966, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, # InterpolationType.seventh BoundType.replicate\n1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0, 1.0, 2.0, # InterpolationType.nearest BoundType.dct1\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 8.5, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5, 0.5, 1.5, # InterpolationType.linear BoundType.dct1\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 8.5, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5, 0.5, 1.5, # InterpolationType.quadratic BoundType.dct1\n0.5417, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.4583, 8.4583, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5417, 0.5417, 1.5, # InterpolationType.cubic BoundType.dct1\n0.5833, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.4167, 8.4167, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5833, 0.5833, 1.5, # 
InterpolationType.fourth BoundType.dct1\n0.6245, 1.5005, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4995, 8.3755, 8.3755, 7.4995, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5005, 0.6245, 0.6245, 1.5005, # InterpolationType.fifth BoundType.dct1\n0.6639, 1.5028, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4972, 8.3361, 8.3361, 7.4972, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5028, 0.6639, 0.6639, 1.5028, # InterpolationType.sixth BoundType.dct1\n0.7016, 1.5068, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4932, 8.2984, 8.2984, 7.4932, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5068, 0.7016, 0.7016, 1.5068, # InterpolationType.seventh BoundType.dct1\n1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0, 0.0, # InterpolationType.nearest BoundType.dct2\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.0, 8.5, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5, 0.0, # InterpolationType.linear BoundType.dct2\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.0, 8.5, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5, 0.0, # InterpolationType.quadratic BoundType.dct2\n0.5208, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.4792, 8.9583, 8.4792, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5208, 0.0417, # InterpolationType.cubic BoundType.dct2\n0.5417, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.4583, 8.9167, 8.4583, 7.5, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5, 0.5417, 0.0833, # InterpolationType.fourth BoundType.dct2\n0.5625, 1.5003, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4997, 8.4375, 8.8755, 8.4375, 7.4997, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5003, 0.5625, 0.1245, # InterpolationType.fifth BoundType.dct2\n0.5833, 1.5014, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4986, 8.4167, 8.8361, 8.4167, 7.4986, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5014, 0.5833, 0.1639, # InterpolationType.sixth BoundType.dct2\n0.6042, 1.5034, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4966, 8.3958, 8.7984, 8.3958, 7.4966, 6.5, 5.5, 4.5, 3.5, 2.5, 1.5034, 0.6042, 0.2016, # InterpolationType.seventh BoundType.dct2\n1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, -9.0, -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0, -0.0, # InterpolationType.nearest BoundType.dst1\n0.5, 1.5, 2.5, 
3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, -4.5, -8.5, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -0.5, # InterpolationType.linear BoundType.dst1\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, -4.5, -8.5, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -0.5, # InterpolationType.quadratic BoundType.dst1\n0.5208, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.2917, 4.2917, -4.2917, -8.2917, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -0.5208, # InterpolationType.cubic BoundType.dst1\n0.5417, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.0833, 4.0833, -4.0833, -8.0833, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -0.5417, # InterpolationType.fourth BoundType.dst1\n0.5622, 1.5003, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4974, 7.8776, 3.8802, -3.8802, -7.8776, -7.4974, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5003, -0.5622, # InterpolationType.fifth BoundType.dst1\n0.5819, 1.5014, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4861, 7.6806, 3.6944, -3.6944, -7.6806, -7.4861, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5014, -0.5819, # InterpolationType.sixth BoundType.dst1\n0.6008, 1.5034, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4662, 7.4922, 3.5260, -3.5260, -7.4922, -7.4662, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5034, -0.6008, # InterpolationType.seventh BoundType.dst1\n1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, -9.0, -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0, -0.0, 0.0, # InterpolationType.nearest BoundType.dst2\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 0.0, -8.5, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -0.5, 0.0, # InterpolationType.linear BoundType.dst2\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 0.0, -8.5, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -0.5, 0.0, # InterpolationType.quadratic BoundType.dst2\n5.2083e-01, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.1042, -1.6391e-07, -8.1042, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -5.2083e-01, 0.0, # InterpolationType.cubic BoundType.dst2\n5.4167e-01, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 7.7083, 1.4901e-07, -7.7083, -7.5, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5, -5.4167e-01, 0.0, # InterpolationType.fourth 
BoundType.dst2\n5.6198e-01, 1.5003, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4951, 7.3224, 1.2107e-07, -7.3224, -7.4951, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5003, -5.6198e-01, 5.2387e-10, # InterpolationType.fifth BoundType.dst2\n5.8056e-01, 1.5014, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4736, 6.9694, -1.0896e-07, -6.9694, -7.4736, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5014, -5.8056e-01, 2.3283e-10, # InterpolationType.sixth BoundType.dst2\n0.59740, 1.5034, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4358, 6.6493, 0.0, -6.6493, -7.4358, -6.5, -5.5, -4.5, -3.5, -2.5, -1.5034, -0.59740, 0.0, # InterpolationType.seventh BoundType.dst2\n1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, # InterpolationType.nearest BoundType.dft\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, # InterpolationType.linear BoundType.dft\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, # InterpolationType.quadratic BoundType.dft\n0.7083, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.2917, 4.5, 0.7083, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.2917, 4.5, # InterpolationType.cubic BoundType.dft\n0.9167, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.0833, 4.5, 0.9167, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.0833, 4.5, # InterpolationType.fourth BoundType.dft\n1.1198, 1.5026, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4974, 7.8802, 4.5, 1.1198, 1.5026, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4974, 7.8802, 4.5, # InterpolationType.fifth BoundType.dft\n1.3056, 1.5139, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4861, 7.6944, 4.5, 1.3056, 1.5139, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4861, 7.6944, 4.5, # InterpolationType.sixth BoundType.dft\n1.4740, 1.5338, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4662, 7.5260, 4.5, 1.4740, 1.5338, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4662, 7.5260, 4.5, # InterpolationType.seventh BoundType.dft\n1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.nearest BoundType.zero\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 
6.5, 7.5, 8.5, 4.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.linear BoundType.zero\n0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 4.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.quadratic BoundType.zero\n0.5208, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.2917, 4.4792, 0.1875, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.cubic BoundType.zero\n0.5417, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.0833, 4.4583, 0.3750, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.fourth BoundType.zero\n5.6224e-01, 1.5003, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4974, 7.8799, 4.4378, 5.5755e-01, 2.3438e-03, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.fifth BoundType.zero\n0.5819, 1.5014, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4861, 7.6931, 4.4181, 0.7236, 0.0125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.sixth BoundType.zero\n6.0078e-01, 1.5034, 2.5, 3.5, 4.5, 5.5, 6.5, 7.4662, 7.5226, 4.3992, 8.7325e-01, 3.0411e-02, 1.3951e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, # InterpolationType.seventh BoundType.zero\n"
  },
  {
    "path": "tests/testing_data/bundle_test_network.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom monai.networks.nets import UNet\n\n\nclass TestMultiInputUNet(UNet):\n    \"\"\"\n    This class is used by \"tests/test_bundle_verify_net.py\" to show that the monai.bundle.verify_net_in_out\n    function can verify networks whose forward function takes multiple input arguments.\n    \"\"\"\n\n    def forward(self, x: torch.Tensor, extra_arg1: int, extra_arg2: int) -> torch.Tensor:  # type: ignore\n        x = self.model(x)\n        x += extra_arg1\n        x += extra_arg2\n        return x\n"
  },
  {
    "path": "tests/testing_data/config_fl_evaluate.json",
    "content": "{\n    \"validate#handlers\": [\n        {\n            \"_target_\": \"CheckpointLoader\",\n            \"load_path\": \"$@bundle_root + '/models/model.pt'\",\n            \"load_dict\": {\n                \"model\": \"@network\"\n            }\n        },\n        {\n            \"_target_\": \"StatsHandler\",\n            \"iteration_log\": false\n        }\n    ],\n    \"run\": [\n        \"$@validate#evaluator.run()\"\n    ]\n}\n"
  },
  {
    "path": "tests/testing_data/config_fl_filters.json",
    "content": "{\n    \"pre_filters\": [\n        {\n            \"_target_\": \"monai.fl.utils.filters.SummaryFilter\"\n        }\n    ],\n    \"post_weight_filters\": [\n        {\n            \"_target_\": \"monai.fl.utils.filters.SummaryFilter\"\n        }\n    ],\n    \"post_evaluate_filters\": []\n}\n"
  },
  {
    "path": "tests/testing_data/config_fl_stats_1.json",
    "content": "{\n    \"imports\": [\n        \"$import os\"\n    ],\n    \"bundle_root\": \"\",\n    \"dataset_dir\": \"@bundle_root\",\n    \"train\": {\n        \"dataset\": {\n            \"_target_\": \"Dataset\",\n            \"data\": [\n                {\n                    \"image\": \"$os.path.join(@dataset_dir, 'anatomical.nii')\",\n                    \"label\": \"$os.path.join(@dataset_dir, 'anatomical_label.nii.gz')\"\n                },\n                {\n                    \"image\": \"$os.path.join(@dataset_dir, 'reoriented_anat_moved.nii')\",\n                    \"label\": \"$os.path.join(@dataset_dir, 'reoriented_anat_moved_label.nii.gz')\"\n                }\n            ],\n            \"transform\": \"@train#preprocessing\"\n        }\n    }\n}\n"
  },
  {
    "path": "tests/testing_data/config_fl_stats_2.json",
    "content": "{\n    \"imports\": [\n        \"$import os\"\n    ],\n    \"bundle_root\": \"\",\n    \"dataset_dir\": \"@bundle_root\",\n    \"train\": {\n        \"dataset\": {\n            \"_target_\": \"Dataset\",\n            \"data\": [\n                {\n                    \"image\": \"$os.path.join(@dataset_dir, 'anatomical.nii')\",\n                    \"label\": \"$os.path.join(@dataset_dir, 'anatomical_label.nii.gz')\"\n                },\n                {\n                    \"image\": \"$os.path.join(@dataset_dir, 'reoriented_anat_moved.nii')\",\n                    \"label\": \"$os.path.join(@dataset_dir, 'reoriented_anat_moved_label.nii.gz')\"\n                }\n            ],\n            \"transform\": \"@train#preprocessing\"\n        }\n    },\n    \"validate\": {\n        \"dataset\": {\n            \"_target_\": \"Dataset\",\n            \"data\": [\n                {\n                    \"image\": \"$os.path.join(@dataset_dir, 'anatomical.nii')\",\n                    \"label\": \"$os.path.join(@dataset_dir, 'anatomical_label.nii.gz')\"\n                },\n                {\n                    \"image\": \"$os.path.join(@dataset_dir, 'reoriented_anat_moved.nii')\",\n                    \"label\": \"$os.path.join(@dataset_dir, 'reoriented_anat_moved_label.nii.gz')\"\n                }\n            ],\n            \"transform\": \"@train#preprocessing\"\n        }\n    }\n}\n"
  },
  {
    "path": "tests/testing_data/config_fl_train.json",
    "content": "{\n    \"bundle_root\": \"tests/testing_data\",\n    \"dataset_dir\": \"@bundle_root\",\n    \"val_interval\": 1,\n    \"imports\": [\n        \"$import os\"\n    ],\n    \"device\": \"$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\",\n    \"network_def\": {\n        \"_target_\": \"DenseNet121\",\n        \"spatial_dims\": 2,\n        \"in_channels\": 1,\n        \"out_channels\": 6\n    },\n    \"network\": \"$@network_def.to(@device)\",\n    \"loss\": {\n        \"_target_\": \"torch.nn.CrossEntropyLoss\"\n    },\n    \"optimizer\": {\n        \"_target_\": \"torch.optim.Adam\",\n        \"params\": \"$@network.parameters()\",\n        \"lr\": 0.0001\n    },\n    \"train\": {\n        \"training_transforms\": [\n            {\n                \"_target_\": \"LoadImaged\",\n                \"keys\": [\n                    \"image\"\n                ],\n                \"image_only\": true\n            },\n            {\n                \"_target_\": \"EnsureChannelFirstD\",\n                \"keys\": [\n                    \"image\"\n                ]\n            },\n            {\n                \"_target_\": \"ScaleIntensityd\",\n                \"keys\": [\n                    \"image\"\n                ]\n            },\n            {\n                \"_target_\": \"RandRotated\",\n                \"keys\": [\n                    \"image\"\n                ],\n                \"range_x\": 15,\n                \"prob\": 0.5,\n                \"keep_size\": true\n            },\n            {\n                \"_target_\": \"RandFlipd\",\n                \"keys\": [\n                    \"image\"\n                ],\n                \"spatial_axis\": 0,\n                \"prob\": 0.5\n            },\n            {\n                \"_target_\": \"RandZoomd\",\n                \"keys\": [\n                    \"image\"\n                ],\n                \"min_zoom\": 0.9,\n                \"max_zoom\": 1.1,\n                
\"prob\": 0.5\n            }\n        ],\n        \"preprocessing\": {\n            \"_target_\": \"Compose\",\n            \"transforms\": \"$@train#training_transforms\"\n        },\n        \"dataset\": {\n            \"_target_\": \"Dataset\",\n            \"data\": [\n                {\n                    \"image\": \"$os.path.join(@dataset_dir, 'image0.jpeg')\",\n                    \"label\": 0\n                },\n                {\n                    \"image\": \"$os.path.join(@dataset_dir, 'image1.jpeg')\",\n                    \"label\": 1\n                }\n            ],\n            \"transform\": \"@train#preprocessing\"\n        },\n        \"dataloader\": {\n            \"_target_\": \"DataLoader\",\n            \"dataset\": \"@train#dataset\",\n            \"batch_size\": 3,\n            \"shuffle\": true,\n            \"num_workers\": 2\n        },\n        \"inferer\": {\n            \"_target_\": \"SimpleInferer\"\n        },\n        \"handlers\": [\n            {\n                \"_target_\": \"ValidationHandler\",\n                \"validator\": \"@validate#evaluator\",\n                \"epoch_level\": true,\n                \"interval\": \"@val_interval\"\n            },\n            {\n                \"_target_\": \"StatsHandler\",\n                \"tag_name\": \"train_loss\",\n                \"output_transform\": \"$monai.handlers.from_engine(['loss'], first=True)\"\n            }\n        ],\n        \"trainer\": {\n            \"_target_\": \"SupervisedTrainer\",\n            \"max_epochs\": 2,\n            \"device\": \"@device\",\n            \"train_data_loader\": \"@train#dataloader\",\n            \"network\": \"@network\",\n            \"loss_function\": \"@loss\",\n            \"optimizer\": \"@optimizer\",\n            \"inferer\": \"@train#inferer\",\n            \"train_handlers\": \"@train#handlers\"\n        }\n    },\n    \"validate\": {\n        \"preprocessing\": {\n            \"_target_\": \"Compose\",\n         
   \"transforms\": [\n                {\n                    \"_target_\": \"LoadImaged\",\n                    \"keys\": [\n                        \"image\"\n                    ],\n                    \"image_only\": true\n                },\n                {\n                    \"_target_\": \"EnsureChannelFirstD\",\n                    \"keys\": [\n                        \"image\"\n                    ]\n                },\n                {\n                    \"_target_\": \"ScaleIntensityd\",\n                    \"keys\": [\n                        \"image\"\n                    ]\n                }\n            ]\n        },\n        \"dataset\": {\n            \"_target_\": \"Dataset\",\n            \"data\": [\n                {\n                    \"image\": \"$os.path.join(@dataset_dir, 'image0.jpeg')\",\n                    \"label\": 0\n                },\n                {\n                    \"image\": \"$os.path.join(@dataset_dir, 'image1.jpeg')\",\n                    \"label\": 1\n                }\n            ],\n            \"transform\": \"@validate#preprocessing\"\n        },\n        \"dataloader\": {\n            \"_target_\": \"DataLoader\",\n            \"dataset\": \"@validate#dataset\",\n            \"batch_size\": 1,\n            \"shuffle\": false,\n            \"num_workers\": 2\n        },\n        \"inferer\": {\n            \"_target_\": \"SimpleInferer\"\n        },\n        \"postprocessing\": {\n            \"_target_\": \"Compose\",\n            \"transforms\": [\n                {\n                    \"_target_\": \"Activationsd\",\n                    \"keys\": \"pred\",\n                    \"softmax\": true\n                }\n            ]\n        },\n        \"key_metric\": {\n            \"accuracy\": {\n                \"_target_\": \"ignite.metrics.Accuracy\",\n                \"output_transform\": \"$monai.handlers.from_engine(['pred', 'label'])\"\n            }\n        },\n        \"handlers\": [\n       
     {\n                \"_target_\": \"StatsHandler\",\n                \"iteration_log\": false\n            }\n        ],\n        \"evaluator\": {\n            \"_target_\": \"SupervisedEvaluator\",\n            \"device\": \"@device\",\n            \"val_data_loader\": \"@validate#dataloader\",\n            \"network\": \"@network\",\n            \"inferer\": \"@validate#inferer\",\n            \"val_handlers\": \"@validate#handlers\",\n            \"postprocessing\": \"@validate#postprocessing\",\n            \"key_val_metric\": \"@validate#key_metric\"\n        }\n    },\n    \"initialize\": [\n        \"$monai.utils.set_determinism(seed=123)\"\n    ],\n    \"run\": [\n        \"$@train#trainer.run()\"\n    ],\n    \"finalize\": [\n        \"$monai.utils.set_determinism(seed=None)\"\n    ]\n}\n"
  },
  {
    "path": "tests/testing_data/cpp_resample_answers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport csv\nimport os\nimport warnings\n\n\ndef _read_testing_data_answers(fname: str | None = None, delimiter=\",\") -> list:\n    answers: list = []\n    if not fname:\n        return answers\n    # read answers from directory of the current file\n    pwd = os.path.dirname(os.path.abspath(__file__))\n    filename = os.path.join(pwd, fname)\n    if not os.path.isfile(filename):\n        warnings.warn(f\"test data {filename} not found.\")\n        return answers\n    with open(filename) as f:\n        res_reader = csv.reader(f, delimiter=delimiter)\n        for r in res_reader:\n            res_row = []\n            for item in r:\n                if item.strip().startswith(\"#\"):\n                    continue  # allow for some simple comments in the file\n                res_row.append(float(item))\n            answers.append(res_row)\n    return answers\n\n\nExpected_1D_GP_fwd: list = _read_testing_data_answers(fname=\"1D_BP_fwd.txt\")\nExpected_1D_GP_bwd: list = _read_testing_data_answers(fname=\"1D_BP_bwd.txt\")\n"
  },
  {
    "path": "tests/testing_data/data_config.json",
    "content": "{\n    \"images\": {\n        \"wsi_generic_tiff\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CMU-1.tiff\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"73a7e89bc15576587c3d68e55d9bf92f09690280166240b48ff4b48230b13bcd\"\n        },\n        \"wsi_aperio_svs\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/Aperio-CMU-1.svs\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"00a3d54482cd707abf254fe69dccc8d06b8ff757a1663f1290c23418c480eb30\"\n        },\n        \"favicon\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/favicon.ico.zip\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"3a3635c8d8adb81feebc5926b4106e8eb643a24a4be2a69a9d35f9a578acadb5\"\n        },\n        \"icon\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/icon.tar.gz\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"90f24cd8f20f3932624da95190ce384302261acf0ea15b358f7832e3b6becac0\"\n        },\n        \"mednist\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"f2f4881ff8799a170b10a403495f0ce0ad7486491901cde67a647e6627e7f916\"\n        },\n        \"Prostate_T2W_AX_1\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/Prostate_T2W_AX_1.nii\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"a14231f539c0f365a5f83f2a046969a9b9870e56ffd126fd8e7242364d25938a\"\n        },\n        \"0000_t2_tse_tra_4\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ProstateX-0000_t2_tse_tra_4.nii.gz\",\n   
         \"hash_type\": \"md5\",\n            \"hash_val\": \"adb3f1c4db66a6481c3e4a2a3033c7d5\"\n        },\n        \"0000_ep2d_diff_tra_7\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ProstateX-0000_ep2d_diff_tra_7.nii.gz\",\n            \"hash_type\": \"md5\",\n            \"hash_val\": \"f12a11ad0ebb0b1876e9e010564745d2\"\n        },\n        \"ref_avg152T1_LR\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/avg152T1_LR_nifti.nii.gz\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"c01a50caa7a563158ecda43d93a1466bfc8aa939bc16b06452ac1089c54661c8\"\n        },\n        \"ref_avg152T1_RL\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/avg152T1_RL_nifti.nii.gz\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"8a731128dac4de46ccb2cc60d972b98f75a52f21fb63ddb040ca96f0aed8b51a\"\n        },\n        \"MNI152_T1_2mm\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MNI152_T1_2mm.nii.gz\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"0585cd056bf5ccfb8bf97a5f6a66082d4e7caad525718fc11e40d80a827fcb92\"\n        },\n        \"MNI152_T1_2mm_strucseg\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MNI152_T1_2mm_strucseg.nii.gz\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"eb4f1e596ca85aadaefc359d409fb9a3e27d733e6def04b996953b7c54bc26d4\"\n        },\n        \"copd1_highres_INSP_STD_COPD_img\": {\n            \"url\": \"https://data.kitware.com/api/v1/file/62a0f067bddec9d0c4175c5a/download\",\n            \"hash_type\": \"sha512\",\n            \"hash_val\": 
\"60193cd6ef0cf055c623046446b74f969a2be838444801bd32ad5bedc8a7eeecb343e8a1208769c9c7a711e101c806a3133eccdda7790c551a69a64b9b3701e9\"\n        },\n        \"copd1_highres_EXP_STD_COPD_img\": {\n            \"url\": \"https://data.kitware.com/api/v1/item/62a0f045bddec9d0c4175c44/download\",\n            \"hash_type\": \"sha512\",\n            \"hash_val\": \"841ef303958541474e66c2d1ccdc8b7ed17ba2f2681101307766b979a07979f2ec818ddf13791c3f1ac5a8ec3258d6ea45b692b4b4a838de9188602618972b6d\"\n        },\n        \"CT_2D_head_fixed\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CT_2D_head_fixed.mha\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"06f2ce6fbf6a59f0874c735555fcf71717f631156b1b0697c1752442f7fc1cc5\"\n        },\n        \"CT_2D_head_moving\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CT_2D_head_moving.mha\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"a37c5fe388c38b3f4ac564f456277d09d3982eda58c4da05ead8ee2332360f47\"\n        },\n        \"DICOM_single\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CT_DICOM_SINGLE.zip\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"a41f6e93d2e3d68956144f9a847273041d36441da12377d6a1d5ae610e0a7023\"\n        },\n        \"nrrd_example\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/CT_IMAGE_cropped.nrrd\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"66971ad17f0bac50e6082ed6a4dc1ae7093c30517137e53327b15a752327a1c0\"\n        }\n    },\n    \"videos\": {\n        \"endovis\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/d1_im.mp4\",\n            \"hash_type\": \"md5\",\n            \"hash_val\": \"9b103c07326439b0ea376018d7189384\"\n    
    },\n        \"ultrasound\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/example_data_Ultrasound_Q000_04_tu_segmented_ultrasound_256.avi\",\n            \"hash_type\": \"md5\",\n            \"hash_val\": \"f0755960cc4a08a958561cda9a79a157\"\n        }\n    },\n    \"models\": {\n        \"senet154-c7b49a05\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/senet154-c7b49a05.pth\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"c7b49a056b98b0bed65b0237c27acdead655e599669215573d357ad337460413\"\n        },\n        \"se_resnet101-7e38fcc6\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnet101-7e38fcc6.pth\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"7e38fcc64eff3225a3ea4e6081efeb6087e8d5a61c204d94edc2ed1aab0b9d70\"\n        },\n        \"se_resnet152-d17c99b7\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnet152-d17c99b7.pth\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"d17c99b703dcca2d2507ddfb68f72625a2f7e23ee64396eb992f1b2cf7e6bdc1\"\n        },\n        \"se_resnet50-ce0d4300\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnet50-ce0d4300.pth\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"ce0d430017d3f4aa6b5658c72209f3bfffb060207fd26a2ef0b203ce592eba01\"\n        },\n        \"se_resnext101_32x4d-3b2fe3d8\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnext101_32x4d-3b2fe3d8.pth\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"3b2fe3d8acb8de7d5976c4baf518f24a0237509272a69366e816682d3e57b989\"\n        },\n        \"se_resnext50_32x4d-a260b3a4\": {\n            
\"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnext50_32x4d-a260b3a4.pth\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"a260b3a40f82dfe37c58d26a612bcf7bef0d27c6fed096226b0e4e9fb364168e\"\n        },\n        \"ssl_pretrained_weights\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8\"\n        },\n        \"decoder_only_transformer_monai_generative_weights\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/decoder_only_transformer.pth\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"f93de37d64d77cf91f3bde95cdf93d161aee800074c89a92aff9d5699120ec0d\"\n        },\n        \"diffusion_model_unet_monai_generative_weights\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/diffusion_model_unet.pth\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"0d2171b386902f5b4fd3e967b4024f63e353694ca45091b114970019d045beee\"\n        },\n        \"autoencoderkl_monai_generative_weights\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184\"\n        },\n        \"controlnet_monai_generative_weights\": {\n            \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/controlnet.pth\",\n            \"hash_type\": \"sha256\",\n            \"hash_val\": \"cd100d0c69f47569ae5b4b7df653a1cb19f5e02eff1630db3210e2646fb1ab2e\"\n        }\n    },\n    \"configs\": {\n        \"test_meta_file\": {\n          
  \"url\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json\",\n            \"hash_type\": \"md5\",\n            \"hash_val\": \"06954cad2cc5d3784e72077ac76f0fc8\"\n        }\n    }\n}\n"
  },
  {
    "path": "tests/testing_data/fl_infer_properties.json",
    "content": "{\n    \"infer\": {\n        \"bundle_root\": {\n            \"description\": \"root path of the bundle.\",\n            \"required\": true,\n            \"id\": \"bundle_root\"\n        },\n        \"device\": {\n            \"description\": \"target device to execute the bundle workflow.\",\n            \"required\": true,\n            \"id\": \"device\"\n        },\n        \"dataset_dir\": {\n            \"description\": \"directory path of the dataset.\",\n            \"required\": true,\n            \"id\": \"dataset_dir\"\n        },\n        \"dataset\": {\n            \"description\": \"PyTorch dataset object for the inference / evaluation logic.\",\n            \"required\": true,\n            \"id\": \"dataset\"\n        },\n        \"evaluator\": {\n            \"description\": \"inference / evaluation workflow engine.\",\n            \"required\": true,\n            \"id\": \"evaluator\"\n        },\n        \"network_def\": {\n            \"description\": \"network module for the inference.\",\n            \"required\": true,\n            \"id\": \"network_def\"\n        },\n        \"inferer\": {\n            \"description\": \"MONAI Inferer object to execute the model computation in inference.\",\n            \"required\": true,\n            \"id\": \"inferer\"\n        },\n        \"dataset_data\": {\n            \"description\": \"data source for the inference / evaluation dataset.\",\n            \"required\": false,\n            \"id\": \"dataset::data\",\n            \"refer_id\": null\n        },\n        \"handlers\": {\n            \"description\": \"event-handlers for the inference / evaluation logic.\",\n            \"required\": false,\n            \"id\": \"handlers\",\n            \"refer_id\": \"evaluator::val_handlers\"\n        },\n        \"preprocessing\": {\n            \"description\": \"preprocessing for the input data.\",\n            \"required\": false,\n            \"id\": \"preprocessing\",\n            
\"refer_id\": \"dataset::transform\"\n        },\n        \"postprocessing\": {\n            \"description\": \"postprocessing for the model output data.\",\n            \"required\": false,\n            \"id\": \"postprocessing\",\n            \"refer_id\": \"evaluator::postprocessing\"\n        },\n        \"key_metric\": {\n            \"description\": \"the key metric during evaluation.\",\n            \"required\": false,\n            \"id\": \"key_metric\",\n            \"refer_id\": \"evaluator::key_val_metric\"\n        }\n    },\n    \"meta\": {\n        \"version\": {\n            \"description\": \"version of the inference configuration.\",\n            \"required\": true,\n            \"id\": \"_meta_::version\"\n        }\n    }\n}\n"
  },
  {
    "path": "tests/testing_data/inference.json",
    "content": "{\n    \"dataset_dir\": \"/workspace/data/Task09_Spleen\",\n    \"bundle_root\": \"will override\",\n    \"output_dir\": \"need override\",\n    \"prediction_shape\": \"prediction shape:\",\n    \"import_glob\": \"$import glob\",\n    \"device\": \"$torch.device('cpu')\",\n    \"print_test_name\": \"$print('json_test')\",\n    \"print_glob_file\": \"$print(glob.__file__)\",\n    \"network_def\": {\n        \"_target_\": \"UNet\",\n        \"spatial_dims\": 3,\n        \"in_channels\": 1,\n        \"out_channels\": 2,\n        \"channels\": [\n            2,\n            2,\n            4,\n            8,\n            4\n        ],\n        \"strides\": [\n            2,\n            2,\n            2,\n            2\n        ],\n        \"num_res_units\": 2,\n        \"norm\": \"batch\"\n    },\n    \"network\": \"need override\",\n    \"preprocessing\": {\n        \"_target_\": \"Compose\",\n        \"transforms\": [\n            {\n                \"_target_\": \"LoadImaged\",\n                \"keys\": \"image\"\n            },\n            {\n                \"_target_\": \"EnsureChannelFirstd\",\n                \"keys\": \"image\"\n            },\n            {\n                \"_target_\": \"ScaleIntensityd\",\n                \"keys\": \"image\"\n            },\n            {\n                \"_target_\": \"RandRotated\",\n                \"_disabled_\": true,\n                \"keys\": \"image\"\n            }\n        ]\n    },\n    \"dataset\": {\n        \"_target_\": \"need override\",\n        \"data\": \"@_meta_#datalist\",\n        \"transform\": \"@preprocessing\"\n    },\n    \"dataloader\": {\n        \"_target_\": \"DataLoader\",\n        \"dataset\": \"@dataset\",\n        \"batch_size\": 1,\n        \"shuffle\": false,\n        \"num_workers\": 4\n    },\n    \"inferer\": {\n        \"_target_\": \"SlidingWindowInferer\",\n        \"roi_size\": [\n            64,\n            64,\n            32\n        ],\n        
\"sw_batch_size\": 4,\n        \"overlap\": 0.25\n    },\n    \"postprocessing\": {\n        \"_target_\": \"Compose\",\n        \"transforms\": [\n            {\n                \"_target_\": \"Activationsd\",\n                \"keys\": \"pred\",\n                \"softmax\": true\n            },\n            {\n                \"_target_\": \"AsDiscreted\",\n                \"keys\": \"pred\",\n                \"argmax\": true\n            },\n            {\n                \"_target_\": \"SaveImaged\",\n                \"keys\": \"pred\",\n                \"output_dir\": \"@output_dir\"\n            },\n            {\n                \"_target_\": \"Lambdad\",\n                \"keys\": \"pred\",\n                \"func\": \"$lambda x: print(@prediction_shape + str(x.shape))\",\n                \"overwrite\": false\n            }\n        ]\n    },\n    \"evaluator\": {\n        \"_target_\": \"SupervisedEvaluator\",\n        \"_requires_\": [\n            \"@print_test_name\",\n            \"@print_glob_file\",\n            \"$print('test_in_line_json')\"\n        ],\n        \"device\": \"@device\",\n        \"val_data_loader\": \"@dataloader\",\n        \"network\": \"@network\",\n        \"inferer\": \"@inferer\",\n        \"postprocessing\": \"@postprocessing\",\n        \"amp\": false\n    },\n    \"initialize\": [\n        \"$monai.utils.set_determinism(0)\"\n    ],\n    \"run\": [\n        \"$@evaluator.run()\"\n    ]\n}\n"
  },
  {
    "path": "tests/testing_data/inference.yaml",
    "content": "---\ndataset_dir: \"/workspace/data/Task09_Spleen\"\nbundle_root: \"will override\"\noutput_dir: \"need override\"\nprediction_shape: \"prediction shape:\"\ndevice: \"$torch.device('cpu')\"\nprint_test_name: \"$print('yaml_test')\"\nnetwork_def:\n  _target_: UNet\n  spatial_dims: 3\n  in_channels: 1\n  out_channels: 2\n  channels:\n  - 2\n  - 2\n  - 4\n  - 8\n  - 4\n  strides:\n  - 2\n  - 2\n  - 2\n  - 2\n  num_res_units: 2\n  norm: batch\nnetwork: need override\npreprocessing:\n  _target_: Compose\n  transforms:\n  - _target_: LoadImaged\n    keys: image\n  - _target_: EnsureChannelFirstd\n    keys: image\n  - _target_: ScaleIntensityd\n    keys: image\n  - _target_: RandRotated\n    _disabled_: true\n    keys: image\ndataset:\n  _target_: need override\n  data: \"@_meta_#datalist\"\n  transform: \"@preprocessing\"\ndataloader:\n  _target_: DataLoader\n  dataset: \"@dataset\"\n  batch_size: 1\n  shuffle: false\n  num_workers: 4\ninferer:\n  _target_: SlidingWindowInferer\n  roi_size:\n  - 64\n  - 64\n  - 32\n  sw_batch_size: 4\n  overlap: 0.25\npostprocessing:\n  _target_: Compose\n  transforms:\n  - _target_: Activationsd\n    keys: pred\n    softmax: true\n  - _target_: AsDiscreted\n    keys: pred\n    argmax: true\n  - _target_: SaveImaged\n    keys: pred\n    output_dir: \"@output_dir\"\n  - _target_: Lambdad\n    keys: pred\n    func: \"$lambda x: print(@prediction_shape + str(x.shape))\"\n    overwrite: false\nevaluator:\n  _target_: SupervisedEvaluator\n  _requires_:\n    - \"$print('test_in_line_yaml')\"\n    - \"@print_test_name\"\n  device: \"@device\"\n  val_data_loader: \"@dataloader\"\n  network: \"@network\"\n  inferer: \"@inferer\"\n  postprocessing: \"@postprocessing\"\n  amp: false\ninitialize:\n  - \"$monai.utils.set_determinism(0)\"\nrun:\n  - \"$@evaluator.run()\"\nfinalize:\n  - \"$print('test finalize section.')\"\n"
  },
  {
    "path": "tests/testing_data/integration_answers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport numpy as np\n\nEXPECTED_ANSWERS = [\n    {  # test answers for PyTorch 2.0\n        \"integration_segmentation_3d\": {\n            \"losses\": [\n                0.5430086106061935,\n                0.47010003924369814,\n                0.4453376233577728,\n                0.451901963353157,\n                0.4398456811904907,\n                0.43450237810611725,\n            ],\n            \"best_metric\": 0.9329540133476257,\n            \"infer_metric\": 0.9330471754074097,\n            \"output_sums\": [\n                0.14212507078546172,\n                0.15199039602949577,\n                0.15133471939291526,\n                0.13967984811021827,\n                0.18831614355832332,\n                0.1694076821827231,\n                0.14663931509271658,\n                0.16788710637623733,\n                0.1569452710008219,\n                0.17907130698392254,\n                0.16244092698688475,\n                0.1679350345855819,\n                0.14437674754879065,\n                0.11355098478396568,\n                0.161660275855964,\n                0.20082478187698194,\n                0.17575491677668853,\n                0.0974593860605401,\n                0.19366775441539907,\n                0.20293016863409002,\n                0.19610441127101647,\n                0.20812173772459808,\n 
               0.16184212006067655,\n                0.13185211452732482,\n                0.14824716961304257,\n                0.14229818359602905,\n                0.23141282114085215,\n                0.1609268635938338,\n                0.14825300029123678,\n                0.10286266811772046,\n                0.11873484714087054,\n                0.1296615212510262,\n                0.11386621034856693,\n                0.15203351148564773,\n                0.16300823766585265,\n                0.1936726544485426,\n                0.2227251185536394,\n                0.18067789917505797,\n                0.19005874127683337,\n                0.07462121515702229,\n            ],\n        }\n    },\n    {  # test answers for cuda 12\n        \"integration_segmentation_3d\": {\n            \"losses\": [\n                0.5362162500619888,\n                0.4704935997724533,\n                0.4335438072681427,\n                0.4507470965385437,\n                0.45187077224254607,\n                0.4363303750753403,\n            ],\n            \"best_metric\": 0.9334161877632141,\n            \"infer_metric\": 0.9335371851921082,\n            \"output_sums\": [\n                0.14210400101844414,\n                0.1521489829835625,\n                0.15127096315211278,\n                0.13992817339153868,\n                0.1884040828001848,\n                0.16929503899789516,\n                0.14662516818085808,\n                0.16803982264111883,\n                0.1570018930834878,\n                0.17916684191571494,\n                0.1626376090146162,\n                0.1680113549677271,\n                0.1446708736188978,\n                0.1140289628362559,\n                0.16191495673888556,\n                0.20066696225510708,\n                0.17581812459936835,\n                0.09836918048666465,\n                0.19355007524499268,\n                0.20291004237066343,\n                0.19606797329772976,\n                
0.2082113232291515,\n                0.16189564397603906,\n                0.13203990336741953,\n                0.14849477534402156,\n                0.14250633066863938,\n                0.23139529505006795,\n                0.16079877619802546,\n                0.14821067071610583,\n                0.10302449386782145,\n                0.11876349315302756,\n                0.13006925219380802,\n                0.11431448379763984,\n                0.15254606148569302,\n                0.16317147221367873,\n                0.19376668030880526,\n                0.22260597124465822,\n                0.18085088544070227,\n                0.19010916899493174,\n                0.07748195410499427,\n            ],\n        }\n    },\n    {  # test answers for 23.02\n        \"integration_segmentation_3d\": {\n            \"losses\": [\n                0.5401686698198318,\n                0.4789864182472229,\n                0.4417317628860474,\n                0.44183324575424193,\n                0.4418945342302322,\n                0.44213996827602386,\n            ],\n            \"best_metric\": 0.9316274523735046,\n            \"infer_metric\": 0.9321609735488892,\n            \"output_sums\": [\n                0.14212507078546172,\n                0.15199039602949577,\n                0.15133471939291526,\n                0.13967984811021827,\n                0.18831614355832332,\n                0.1694076821827231,\n                0.14663931509271658,\n                0.16788710637623733,\n                0.1569452710008219,\n                0.17907130698392254,\n                0.16244092698688475,\n                0.1679350345855819,\n                0.14437674754879065,\n                0.11355098478396568,\n                0.161660275855964,\n                0.20082478187698194,\n                0.17575491677668853,\n                0.0974593860605401,\n                0.19366775441539907,\n                0.20293016863409002,\n                
0.19610441127101647,\n                0.20812173772459808,\n                0.16184212006067655,\n                0.13185211452732482,\n                0.14824716961304257,\n                0.14229818359602905,\n                0.23141282114085215,\n                0.1609268635938338,\n                0.14825300029123678,\n                0.10286266811772046,\n                0.11873484714087054,\n                0.1296615212510262,\n                0.11386621034856693,\n                0.15203351148564773,\n                0.16300823766585265,\n                0.1936726544485426,\n                0.2227251185536394,\n                0.18067789917505797,\n                0.19005874127683337,\n                0.07462121515702229,\n            ],\n        }\n    },\n    {  # test answers for 24.03\n        \"integration_segmentation_3d\": {\n            \"losses\": [\n                0.5442982316017151,\n                0.4741817444562912,\n                0.4535954713821411,\n                0.44163046181201937,\n                0.4307525992393494,\n                0.428487154841423,\n            ],\n            \"best_metric\": 0.9314384460449219,\n            \"infer_metric\": 0.9315622448921204,\n            \"output_sums\": [\n                0.14268704426414708,\n                0.1528672845845743,\n                0.1521782248125706,\n                0.14028769128068194,\n                0.1889830671664784,\n                0.16999075690664475,\n                0.14736282992708227,\n                0.16877952654821815,\n                0.15779597155181269,\n                0.17987829927082263,\n                0.16320253928314676,\n                0.16854299322173155,\n                0.14497470986956967,\n                0.11437140546369519,\n                0.1624117412960871,\n                0.20156009294443875,\n                0.1764654154256958,\n                0.0982348259217418,\n                0.1942436068604293,\n                
0.20359421536407518,\n                0.19661953116976483,\n                0.2088326101468625,\n                0.16273043545239807,\n                0.1326107887439663,\n                0.1489245275752285,\n                0.143107476635514,\n                0.23189027677929547,\n                0.1613818424566088,\n                0.14889532196775188,\n                0.10332622984492143,\n                0.11940054688302351,\n                0.13040496302762658,\n                0.11472123087193181,\n                0.15307044007394474,\n                0.16371989575844717,\n                0.1942898223272055,\n                0.2230120930471398,\n                0.1814679187634795,\n                0.19069496508164732,\n                0.07537197031940022,\n            ],\n        }\n    },\n    {  # test answers for 24.10\n        \"integration_classification_2d\": {\n            \"losses\": 0.7806512035761669,\n            \"best_metric\": 0.9977695200407783,\n            \"infer_prop\": [805, 727, 955, 1033, 321, 993],\n        },\n        \"integration_workflows\": {\n            \"best_metric\": 0.9207136034965515,\n            \"best_metric_2\": 0.9216295480728149,\n            \"infer_metric\": 0.920440673828125,\n            \"infer_metric_2\": 0.9203161001205444,\n            \"output_sums\": [\n                0.1423349380493164,\n                0.15172767639160156,\n                0.1382155418395996,\n                0.13398218154907227,\n                0.18552064895629883,\n                0.16435527801513672,\n                0.14128494262695312,\n                0.16725540161132812,\n                0.15690851211547852,\n                0.17731285095214844,\n                0.16189050674438477,\n                0.16543960571289062,\n                0.14431238174438477,\n                0.11064529418945312,\n                0.16129302978515625,\n                0.1970067024230957,\n                0.17503118515014648,\n                
0.053476810455322266,\n                0.1914362907409668,\n                0.2001795768737793,\n                0.19636154174804688,\n                0.2040243148803711,\n                0.1606454849243164,\n                0.13213014602661133,\n                0.15132904052734375,\n                0.1370987892150879,\n                0.22805070877075195,\n                0.16170072555541992,\n                0.1477980613708496,\n                0.10428047180175781,\n                0.1195521354675293,\n                0.13089942932128906,\n                0.11238527297973633,\n                0.15204906463623047,\n                0.1603565216064453,\n                0.19054937362670898,\n                0.21789216995239258,\n                0.17824840545654297,\n                0.18654584884643555,\n                0.03622245788574219,\n            ],\n            \"output_sums_2\": [\n                0.1423349380493164,\n                0.15172767639160156,\n                0.1382155418395996,\n                0.13398218154907227,\n                0.18552064895629883,\n                0.16435527801513672,\n                0.14128494262695312,\n                0.16725540161132812,\n                0.15690851211547852,\n                0.17731285095214844,\n                0.16189050674438477,\n                0.16543960571289062,\n                0.14431238174438477,\n                0.11064529418945312,\n                0.16129302978515625,\n                0.1970067024230957,\n                0.17503118515014648,\n                0.053476810455322266,\n                0.1914362907409668,\n                0.2001795768737793,\n                0.19636154174804688,\n                0.2040243148803711,\n                0.1606454849243164,\n                0.13213014602661133,\n                0.15132904052734375,\n                0.1370987892150879,\n                0.22805070877075195,\n                0.16170072555541992,\n                0.1477980613708496,\n     
           0.10428047180175781,\n                0.1195521354675293,\n                0.13089942932128906,\n                0.11238527297973633,\n                0.15204906463623047,\n                0.1603565216064453,\n                0.19054937362670898,\n                0.21789216995239258,\n                0.17824840545654297,\n                0.18654584884643555,\n                0.03622245788574219,\n            ],\n        },\n    },\n]\n\n\ndef test_integration_value(test_name, key, data, rtol=1e-2):\n    for idx, expected in enumerate(EXPECTED_ANSWERS):\n        if test_name not in expected:\n            continue\n        if key not in expected[test_name]:\n            continue\n        value = expected[test_name][key]\n        if np.allclose(data, value, rtol=rtol):\n            print(f\"matched {idx} result of {test_name}, {key}, {rtol}.\")\n            return True\n    raise ValueError(f\"no matched results for {test_name}, {key}. {data}.\")\n"
  },
  {
    "path": "tests/testing_data/logging.conf",
    "content": "[loggers]\nkeys=root,ignite.engine.SupervisedEvaluator\n\n[handlers]\nkeys=consoleHandler\n\n[formatters]\nkeys=fullFormatter\n\n[logger_root]\nlevel=INFO\nhandlers=consoleHandler\n\n[logger_ignite.engine.SupervisedEvaluator]\nlevel=INFO\nhandlers=consoleHandler\nqualname=ignite.engine.SupervisedEvaluator\npropagate=0\n\n[handler_consoleHandler]\nclass=StreamHandler\nlevel=INFO\nformatter=fullFormatter\nargs=(sys.stdout,)\n\n[formatter_fullFormatter]\nformat=%(asctime)s - %(name)s - %(levelname)s - %(message)s\n"
  },
  {
    "path": "tests/testing_data/metadata.json",
    "content": "{\n    \"schema\": \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json\",\n    \"version\": \"0.1.0\",\n    \"changelog\": {\n        \"0.1.0\": \"complete the model package\",\n        \"0.0.1\": \"initialize the model package structure\"\n    },\n    \"monai_version\": \"0.9.0\",\n    \"pytorch_version\": \"1.10.0\",\n    \"numpy_version\": \"1.21.2\",\n    \"required_packages_version\": {\n        \"nibabel\": \"3.2.1\"\n    },\n    \"task\": \"Decathlon spleen segmentation\",\n    \"description\": \"A pre-trained model for volumetric (3D) segmentation of the spleen from CT image\",\n    \"authors\": \"MONAI team\",\n    \"copyright\": \"Copyright (c) MONAI Consortium\",\n    \"data_source\": \"Task09_Spleen.tar from http://medicaldecathlon.com/\",\n    \"data_type\": \"dicom\",\n    \"image_classes\": \"single channel data, intensity scaled to [0, 1]\",\n    \"label_classes\": \"single channel data, 1 is spleen, 0 is everything else\",\n    \"pred_classes\": \"2 channels OneHot data, channel 1 is spleen, channel 0 is background\",\n    \"eval_metrics\": {\n        \"mean_dice\": 0.96\n    },\n    \"intended_use\": \"This is an example, not to be used for diagnostic purposes\",\n    \"references\": [\n        \"Xia, Yingda, et al. '3D Semi-Supervised Learning with Uncertainty-Aware Multi-View Co-Training. arXiv preprint arXiv:1811.12506 (2018). https://arxiv.org/abs/1811.12506.\",\n        \"Kerfoot E., Clough J., Oksuz I., Lee J., King A.P., Schnabel J.A. (2019) Left-Ventricle Quantification Using Residual U-Net. In: Pop M. et al. (eds) Statistical Atlases and Computational Models of the Heart. Atrial Segmentation and LV Quantification Challenges. STACOM 2018. Lecture Notes in Computer Science, vol 11395. Springer, Cham. 
https://doi.org/10.1007/978-3-030-12029-0_40\"\n    ],\n    \"network_data_format\": {\n        \"inputs\": {\n            \"image\": {\n                \"type\": \"image\",\n                \"format\": \"magnitude\",\n                \"num_channels\": 1,\n                \"spatial_shape\": [\n                    160,\n                    160,\n                    160\n                ],\n                \"dtype\": \"float32\",\n                \"value_range\": [\n                    0,\n                    1\n                ],\n                \"is_patch_data\": false,\n                \"channel_def\": {\n                    \"0\": \"image\"\n                }\n            }\n        },\n        \"outputs\": {\n            \"pred\": {\n                \"type\": \"image\",\n                \"format\": \"segmentation\",\n                \"num_channels\": 2,\n                \"spatial_shape\": [\n                    160,\n                    160,\n                    160\n                ],\n                \"dtype\": \"float32\",\n                \"value_range\": [\n                    0,\n                    1\n                ],\n                \"is_patch_data\": false,\n                \"channel_def\": {\n                    \"0\": \"background\",\n                    \"1\": \"spleen\"\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "tests/testing_data/multi_gpu_evaluate.json",
    "content": "{\n    \"device\": \"$torch.device(f'cuda:{dist.get_rank()}')\",\n    \"network\": {\n        \"_target_\": \"torch.nn.parallel.DistributedDataParallel\",\n        \"module\": \"$@network_def.to(@device)\",\n        \"device_ids\": [\n            \"@device\"\n        ]\n    },\n    \"validate#sampler\": {\n        \"_target_\": \"DistributedSampler\",\n        \"dataset\": \"@validate#dataset\",\n        \"even_divisible\": false,\n        \"shuffle\": false\n    },\n    \"validate#dataloader#sampler\": \"@validate#sampler\",\n    \"initialize\": [\n        \"$import torch.distributed as dist\",\n        \"$dist.is_initialized() or dist.init_process_group(backend='nccl')\",\n        \"$torch.cuda.set_device(@device)\",\n        \"$setattr(torch.backends.cudnn, 'benchmark', True)\",\n        \"$import logging\",\n        \"$@validate#evaluator.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)\"\n    ],\n    \"run\": [\n        \"$@validate#evaluator.run()\"\n    ],\n    \"finalize\": [\n        \"$dist.destroy_process_group()\"\n    ]\n}\n"
  },
  {
    "path": "tests/testing_data/multi_gpu_train.json",
    "content": "{\n    \"device\": \"$torch.device(f'cuda:{dist.get_rank()}')\",\n    \"network\": {\n        \"_target_\": \"torch.nn.parallel.DistributedDataParallel\",\n        \"module\": \"$@network_def.to(@device)\",\n        \"device_ids\": [\n            \"@device\"\n        ]\n    },\n    \"train#sampler\": {\n        \"_target_\": \"DistributedSampler\",\n        \"dataset\": \"@train#dataset\",\n        \"even_divisible\": true,\n        \"shuffle\": true\n    },\n    \"train#dataloader#sampler\": \"@train#sampler\",\n    \"train#dataloader#shuffle\": false,\n    \"initialize\": [\n        \"$import torch.distributed as dist\",\n        \"$dist.is_initialized() or dist.init_process_group(backend='nccl')\",\n        \"$torch.cuda.set_device(@device)\",\n        \"$monai.utils.set_determinism(seed=123)\",\n        \"$setattr(torch.backends.cudnn, 'benchmark', True)\",\n        \"$import logging\",\n        \"$@train#trainer.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)\",\n        \"$@validate#evaluator.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)\"\n    ],\n    \"run\": [\n        \"$@train#trainer.run()\"\n    ],\n    \"finalize\": [\n        \"$dist.destroy_process_group()\"\n    ]\n}\n"
  },
  {
    "path": "tests/testing_data/python_workflow_properties.json",
    "content": "{\n    \"infer\": {\n        \"bundle_root\": {\n            \"description\": \"root path of the bundle.\",\n            \"required\": true,\n            \"id\": \"bundle_root\"\n        },\n        \"device\": {\n            \"description\": \"target device to execute the bundle workflow.\",\n            \"required\": true,\n            \"id\": \"device\"\n        },\n        \"inferer\": {\n            \"description\": \"MONAI Inferer object to execute the model computation in inference.\",\n            \"required\": true,\n            \"id\": \"inferer\"\n        }\n    },\n    \"meta\": {\n        \"version\": {\n            \"description\": \"version of the inference configuration.\",\n            \"required\": true,\n            \"id\": \"_meta_::version\"\n        }\n    }\n}\n"
  },
  {
    "path": "tests/testing_data/responsive_inference.json",
    "content": "{\n    \"imports\": [\n        \"$from collections import defaultdict\"\n    ],\n    \"bundle_root\": \"will override\",\n    \"device\": \"$torch.device('cpu')\",\n    \"network_def\": {\n        \"_target_\": \"UNet\",\n        \"spatial_dims\": 3,\n        \"in_channels\": 1,\n        \"out_channels\": 2,\n        \"channels\": [\n            2,\n            2,\n            4,\n            8,\n            4\n        ],\n        \"strides\": [\n            2,\n            2,\n            2,\n            2\n        ],\n        \"num_res_units\": 2,\n        \"norm\": \"batch\"\n    },\n    \"network\": \"$@network_def.to(@device)\",\n    \"dataflow\": \"$defaultdict()\",\n    \"preprocessing\": {\n        \"_target_\": \"Compose\",\n        \"transforms\": [\n            {\n                \"_target_\": \"EnsureChannelFirstd\",\n                \"keys\": \"image\"\n            },\n            {\n                \"_target_\": \"ScaleIntensityd\",\n                \"keys\": \"image\"\n            },\n            {\n                \"_target_\": \"RandRotated\",\n                \"_disabled_\": true,\n                \"keys\": \"image\"\n            }\n        ]\n    },\n    \"dataset\": {\n        \"_target_\": \"Dataset\",\n        \"data\": [\n            \"@dataflow\"\n        ],\n        \"transform\": \"@preprocessing\"\n    },\n    \"dataloader\": {\n        \"_target_\": \"DataLoader\",\n        \"dataset\": \"@dataset\",\n        \"batch_size\": 1,\n        \"shuffle\": false,\n        \"num_workers\": 0\n    },\n    \"inferer\": {\n        \"_target_\": \"SlidingWindowInferer\",\n        \"roi_size\": [\n            64,\n            64,\n            32\n        ],\n        \"sw_batch_size\": 4,\n        \"overlap\": 0.25\n    },\n    \"postprocessing\": {\n        \"_target_\": \"Compose\",\n        \"transforms\": [\n            {\n                \"_target_\": \"Activationsd\",\n                \"keys\": \"pred\",\n                
\"softmax\": true\n            },\n            {\n                \"_target_\": \"AsDiscreted\",\n                \"keys\": \"pred\",\n                \"argmax\": true\n            }\n        ]\n    },\n    \"evaluator\": {\n        \"_target_\": \"SupervisedEvaluator\",\n        \"device\": \"@device\",\n        \"val_data_loader\": \"@dataloader\",\n        \"network\": \"@network\",\n        \"inferer\": \"@inferer\",\n        \"postprocessing\": \"@postprocessing\",\n        \"amp\": false,\n        \"epoch_length\": 1\n    },\n    \"run\": [\n        \"$@evaluator.run()\",\n        \"$@dataflow.update(@evaluator.state.output[0])\"\n    ]\n}\n"
  },
  {
    "path": "tests/testing_data/transform_metatensor_cases.yaml",
    "content": "---\ninput_keys: [image, segs]\ntest_device: \"$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\"\ninit_affine: \"$np.array([[-2, 0, 0, 90], [0, 2, 0, -126], [0, 0, 2, -72], [0, 0, 0, 1]], dtype=np.float64)\"\ninit_shape: [1, 91, 109, 91]\nTEST_CASE_1:\n  _target_: Compose\n  transforms:\n  - _target_: LoadImageD\n    keys: \"@input_keys\"\n    ensure_channel_first: True\n    image_only: True\n  - _target_: ToDeviced\n    keys: \"@input_keys\"\n    device: \"@test_device\"\n  - _target_: CenterScaleCropD\n    keys: \"@input_keys\"\n    roi_scale: 0.98\n  - _target_: CropForegroundD\n    keys: \"@input_keys\"\n    source_key: seg\n    start_coord_key: null\n    end_coord_key: null\n    k_divisible: 5\n  - _target_: RandSpatialCropD\n    keys: \"@input_keys\"\n    roi_size: [76, 87, 73]\n    random_size: True\n  - _target_: RandScaleCropD\n    keys: \"@input_keys\"\n    roi_scale: 0.9\n    random_size: True\n  - _target_: ResizeWithPadOrCropD\n    keys: \"@input_keys\"\n    spatial_size: [32, 43, 54]\n  - _target_: DivisiblePadD\n    keys: \"@input_keys\"\n    k: 3\n\nTEST_CASE_2:\n  _target_: Compose\n  transforms:\n  - _target_: LoadImaged\n    keys: \"@input_keys\"\n    ensure_channel_first: False\n    image_only: True\n  - _target_: ToDeviced\n    keys: \"@input_keys\"\n    device: \"@test_device\"\n  - _target_: EnsureChannelFirstd\n    keys: \"@input_keys\"\n  - _target_: ScaleIntensityRangePercentilesd\n    keys: \"$@input_keys[0]\"\n    lower: 4\n    upper: 95\n    b_min: 1\n    b_max: 10\n  - _target_: RandScaleIntensityd\n    keys: \"$@input_keys[0]\"\n    prob: 1.0\n    factors: [5, 10]\n  - _target_: RandGaussianNoised\n    keys: \"$@input_keys[0]\"\n    prob: 1.0\n    mean: 10.0\n    std: 2.0\n  - _target_: RandCoarseShuffled\n    keys: \"$@input_keys[0]\"\n    prob: 1.0\n    holes: 2\n    spatial_size: [10, 13, 18]\n    max_spatial_size: [14, 30, 57]\n  - _target_: DataStatsd\n    keys: \"$@input_keys[0]\"\n  - _target_: 
RandBiasFieldd\n    keys: \"$@input_keys[0]\"\n    prob: 1.0\n  - _target_: RandGaussianSmoothd\n    keys: \"$@input_keys[0]\"\n    prob: 1.0\n  - _target_: RandGaussianSharpend\n    keys: \"$@input_keys[0]\"\n    prob: 1.0\n  - _target_: RandHistogramShiftd\n    keys: \"$@input_keys[0]\"\n    prob: 1.0\n  - _target_: RandAdjustContrastd\n    keys: \"$@input_keys[0]\"\n    prob: 1.0\n  - _target_: RandCoarseDropoutd\n    keys: \"$@input_keys[0]\"\n    prob: 1.0\n    holes: 3\n    spatial_size: [10, 13, 18]\n    max_spatial_size: [14, 30, 57]\n  - _target_: RandRicianNoised\n    keys: \"$@input_keys[0]\"\n    prob: 1.0\n\nTEST_CASE_3:\n  _target_: Compose\n  transforms:\n  - _target_: LoadImageD\n    keys: \"@input_keys\"\n    ensure_channel_first: True\n    image_only: True\n  - _target_: CenterScaleCropD\n    keys: \"@input_keys\"\n    roi_scale: 0.98\n  - _target_: CropForegroundD\n    keys: \"@input_keys\"\n    source_key: seg\n    start_coord_key: null\n    end_coord_key: null\n    k_divisible: 5\n  - _target_: ToDeviced\n    keys: \"@input_keys\"\n    device: \"@test_device\"\n  - _target_: RandRotate90d\n    keys: \"@input_keys\"\n    prob: 1.0\n    spatial_axes: [2, 1]\n  - _target_: Spacingd\n    keys: \"@input_keys\"\n    pixdim: [1.8, 2.1, 2.3]\n  - _target_: RandFlipd\n    keys: \"@input_keys\"\n    prob: 1.0\n    spatial_axis: 2\n  - _target_: RandAffined\n    keys: \"@input_keys\"\n    prob: 1.0\n    spatial_size: [80, 91, 92]\n    rotate_range: 1.0\n    scale_range: 0.1\n  - _target_: Flipd\n    keys: \"@input_keys\"\n    spatial_axis: 2\n  - _target_: Orientationd\n    keys: \"@input_keys\"\n    axcodes: \"RPI\"\n  - _target_: Affined\n    keys: \"@input_keys\"\n    shear_params: [0, 0.5, 0]\n  - _target_: Rotate90d\n    keys: \"@input_keys\"\n    spatial_axes: [1, 2]\n  - _target_: Zoomd\n    keys: \"@input_keys\"\n    zoom: 1.3\n  - _target_: ScaleIntensityd\n    keys: \"@input_keys\"\n    minv: 0\n    maxv: 10\n  - _target_: RandAxisFlipD\n    
keys: \"@input_keys\"\n    prob: 1.0\n  - _target_: RandRotated\n    keys: \"@input_keys\"\n    prob: 1.0\n    range_y: \"$np.pi/3\"\n  - _target_: RandZoomD\n    keys: \"@input_keys\"\n    prob: 1.0\n    max_zoom: 1.2\n    keep_size: True\n  - _target_: RandGaussianNoised\n    keys: \"@input_keys\"\n    prob: 1.0\n  - _target_: ResizeWithPadOrCropD\n    keys: \"@input_keys\"\n    spatial_size: [71, 56, 80]\n  - _target_: Rand3DElasticd\n    keys: \"@input_keys\"\n    spatial_size: [71, 56, 80]\n    sigma_range: [5, 7]\n    magnitude_range: [50, 150]\n    prob: 1.0\n  - _target_: Resized\n    keys: \"@input_keys\"\n    spatial_size: [72, 57, 82]\n\nTEST_CASE_1_answer:\n  load_shape: [1, 1, 33, 45, 54]\n  affine: \"$np.array([[-2, 0, 0, 30], [0, 2, 0, -62], [0, 0, 2, -48], [0, 0, 0, 1]], dtype=np.float64)\"\n  inv_affine: \"@init_affine\"\n  inv_shape: \"@init_shape\"\n\nTEST_CASE_2_answer:\n  load_shape: [1, 1, 91, 109, 91]\n  affine: \"$np.array([[-2, 0, 0, 90], [0, 2, 0, -126], [0, 0, 2, -72], [0, 0, 0, 1]], dtype=np.float64)\"\n  inv_affine: \"@init_affine\"\n  inv_shape: \"@init_shape\"\n\nTEST_CASE_3_answer:\n  load_shape: [1, 1, 72, 57, 82]\n  affine: \"$np.array([[1.300558,  -0.700765,  -0.511861,  -3.739605], [0.479723,  -1.171149,   1.193079, -50.087933], [0.395736,   1.183532,   0.984201, -80.496605], [0, 0, 0, 1]], dtype=np.float64)\"\n  inv_affine: \"@init_affine\"\n  inv_shape: \"@init_shape\"\n"
  },
  {
    "path": "tests/transforms/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/transforms/compose/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/transforms/compose/test_compose.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport sys\nimport unittest\nfrom copy import deepcopy\nfrom io import StringIO\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nimport monai.transforms as mt\nfrom monai.data import DataLoader, Dataset\nfrom monai.transforms.compose import execute_compose\nfrom monai.transforms.transform import Randomizable\nfrom monai.utils import set_determinism\n\n\ndef data_from_keys(keys, h, w):\n    if keys is None:\n        data = torch.arange(h * w).reshape(1, h, w)\n    else:\n        data = {}\n        for i_k, k in enumerate(keys):\n            data[k] = torch.arange(h * w).reshape(1, h, w).mul_(i_k * h * w)\n    return data\n\n\nclass _RandXform(Randomizable):\n\n    def randomize(self):\n        self.val = self.R.random_sample()\n\n    def __call__(self, __unused):\n        self.randomize()\n        return self.val\n\n\nclass TestCompose(unittest.TestCase):\n\n    def test_empty_compose(self):\n        c = mt.Compose()\n        i = 1\n        self.assertEqual(c(i), 1)\n\n    def test_non_dict_compose(self):\n\n        def a(i):\n            return i + \"a\"\n\n        def b(i):\n            return i + \"b\"\n\n        c = mt.Compose([a, b, a, b])\n        self.assertEqual(c(\"\"), \"abab\")\n\n    def test_dict_compose(self):\n\n        def a(d):\n            d = dict(d)\n            
d[\"a\"] += 1\n            return d\n\n        def b(d):\n            d = dict(d)\n            d[\"b\"] += 1\n            return d\n\n        transforms = [a, b, a, b, a]\n        data = {\"a\": 0, \"b\": 0}\n        expected = {\"a\": 3, \"b\": 2}\n\n        self.assertDictEqual(mt.Compose(transforms)(data), expected)\n        self.assertDictEqual(execute_compose(data, transforms), expected)\n\n    def test_list_dict_compose(self):\n\n        def a(d):  # transform to handle dict data\n            d = dict(d)\n            d[\"a\"] += 1\n            return d\n\n        def b(d):  # transform to generate a batch list of data\n            d = dict(d)\n            d[\"b\"] += 1\n            d = [d] * 5\n            return d\n\n        def c(d):  # transform to handle dict data\n            d = dict(d)\n            d[\"c\"] += 1\n            return d\n\n        transforms = [a, a, b, c, c]\n        data = {\"a\": 0, \"b\": 0, \"c\": 0}\n        expected = {\"a\": 2, \"b\": 1, \"c\": 2}\n        value = mt.Compose(transforms)(data)\n        for item in value:\n            self.assertDictEqual(item, expected)\n        value = execute_compose(data, transforms)\n        for item in value:\n            self.assertDictEqual(item, expected)\n\n    def test_non_dict_compose_with_unpack(self):\n\n        def a(i, i2):\n            return i + \"a\", i2 + \"a2\"\n\n        def b(i, i2):\n            return i + \"b\", i2 + \"b2\"\n\n        transforms = [a, b, a, b]\n        data = (\"\", \"\")\n        expected = (\"abab\", \"a2b2a2b2\")\n        self.assertEqual(mt.Compose(transforms, map_items=False, unpack_items=True)(data), expected)\n        self.assertEqual(execute_compose(data, transforms, map_items=False, unpack_items=True), expected)\n\n    def test_list_non_dict_compose_with_unpack(self):\n\n        def a(i, i2):\n            return i + \"a\", i2 + \"a2\"\n\n        def b(i, i2):\n            return i + \"b\", i2 + \"b2\"\n\n        transforms = [a, b, a, b]\n        
data = [(\"\", \"\"), (\"t\", \"t\")]\n        expected = [(\"abab\", \"a2b2a2b2\"), (\"tabab\", \"ta2b2a2b2\")]\n        self.assertEqual(mt.Compose(transforms, unpack_items=True)(data), expected)\n        self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected)\n\n    def test_list_non_dict_compose_with_unpack_map_2(self):\n\n        def a(i, i2):\n            return i + \"a\", i2 + \"a2\"\n\n        def b(i, i2):\n            return i + \"b\", i2 + \"b2\"\n\n        transforms = [a, b, a, b]\n        data = [[(\"\", \"\"), (\"\", \"\")], [(\"t\", \"t\"), (\"t\", \"t\")]]\n        expected = [[(\"abab\", \"a2b2a2b2\"), (\"abab\", \"a2b2a2b2\")], [(\"tabab\", \"ta2b2a2b2\"), (\"tabab\", \"ta2b2a2b2\")]]\n        self.assertEqual(mt.Compose(transforms, map_items=2, unpack_items=True)(data), expected)\n        self.assertEqual(execute_compose(data, transforms, map_items=2, unpack_items=True), expected)\n\n    def test_list_dict_compose_no_map(self):\n\n        def a(d):  # transform to handle dict data\n            d = dict(d)\n            d[\"a\"] += 1\n            return d\n\n        def b(d):  # transform to generate a batch list of data\n            d = dict(d)\n            d[\"b\"] += 1\n            d = [d] * 5\n            return d\n\n        def c(d):  # transform to handle dict data\n            d = [dict(di) for di in d]\n            for di in d:\n                di[\"c\"] += 1\n            return d\n\n        transforms = [a, a, b, c, c]\n        data = {\"a\": 0, \"b\": 0, \"c\": 0}\n        expected = {\"a\": 2, \"b\": 1, \"c\": 2}\n        value = mt.Compose(transforms, map_items=False)(data)\n        for item in value:\n            self.assertDictEqual(item, expected)\n        value = execute_compose(data, transforms, map_items=False)\n        for item in value:\n            self.assertDictEqual(item, expected)\n\n    def test_random_compose(self):\n\n        class _Acc(Randomizable):\n            self.rand = 0.0\n\n            
def randomize(self, data=None):\n                self.rand = self.R.rand()\n\n            def __call__(self, data):\n                self.randomize()\n                return self.rand + data\n\n        c = mt.Compose([_Acc(), _Acc()])\n        self.assertNotAlmostEqual(c(0), c(0))\n        c.set_random_state(123)\n        self.assertAlmostEqual(c(1), 1.61381597)\n        c.set_random_state(223)\n        c.randomize()\n        self.assertAlmostEqual(c(1), 1.90734751)\n\n    def test_randomize_warn(self):\n\n        class _RandomClass(Randomizable):\n\n            def randomize(self, foo1, foo2):\n                pass\n\n            def __call__(self, data):\n                pass\n\n        c = mt.Compose([_RandomClass(), _RandomClass()])\n        with self.assertWarns(Warning):\n            c.randomize()\n\n    def test_err_msg(self):\n        transforms = mt.Compose([abs, mt.EnsureChannelFirst(), round])\n        with self.assertRaisesRegex(Exception, \"EnsureChannelFirst\"):\n            transforms(42.1)\n\n    def test_data_loader(self):\n        xform_1 = mt.Compose([_RandXform()])\n        train_ds = Dataset([1], transform=xform_1)\n\n        xform_1.set_random_state(123)\n        out_1 = train_ds[0]\n        self.assertAlmostEqual(out_1, 0.2045649)\n\n        set_determinism(seed=123)\n        train_loader = DataLoader(train_ds, num_workers=0)\n        out_1 = next(iter(train_loader))\n        self.assertAlmostEqual(out_1.cpu().item(), 0.0409280)\n\n        if sys.platform != \"win32\":  # skip multi-worker tests on win32\n            train_loader = DataLoader(train_ds, num_workers=1)\n            out_1 = next(iter(train_loader))\n            self.assertAlmostEqual(out_1.cpu().item(), 0.78663897075)\n\n            train_loader = DataLoader(train_ds, num_workers=2)\n            out_1 = next(iter(train_loader))\n            self.assertAlmostEqual(out_1.cpu().item(), 0.785907334)\n        set_determinism(None)\n\n    def test_data_loader_2(self):\n        
set_determinism(seed=123)\n        xform_2 = mt.Compose([_RandXform(), _RandXform()])\n        train_ds = Dataset([1], transform=xform_2)\n\n        out_2 = train_ds[0]\n        self.assertAlmostEqual(out_2, 0.4092510)\n\n        train_loader = DataLoader(train_ds, num_workers=0)\n        out_2 = next(iter(train_loader))\n        self.assertAlmostEqual(out_2.cpu().item(), 0.98921915918)\n\n        if sys.platform != \"win32\":  # skip multi-worker tests on win32\n            train_loader = DataLoader(train_ds, num_workers=1)\n            out_2 = next(iter(train_loader))\n            self.assertAlmostEqual(out_2.cpu().item(), 0.32985207)\n\n            train_loader = DataLoader(train_ds, num_workers=2)\n            out_1 = next(iter(train_loader))\n            self.assertAlmostEqual(out_1.cpu().item(), 0.28602141572)\n        set_determinism(None)\n\n    def test_flatten_and_len(self):\n        x = mt.EnsureChannelFirst(channel_dim=\"no_channel\")\n        t1 = mt.Compose([x, x, x, x, mt.Compose([mt.Compose([x, x]), x, x])])\n\n        t2 = t1.flatten()\n        for t in t2.transforms:\n            self.assertNotIsInstance(t, mt.Compose)\n\n        # test len\n        self.assertEqual(len(t1), 8)\n\n    def test_backwards_compatible_imports(self):\n        pass\n\n    def test_list_extend_multi_sample_trait(self):\n        center_crop = mt.CenterSpatialCrop([128, 128])\n        multi_sample_transform = mt.RandSpatialCropSamples([64, 64], 1)\n        flatten_sequence_transform = mt.FlattenSequence()\n\n        img = torch.zeros([1, 512, 512])\n\n        self.assertEqual(execute_compose(img, [center_crop]).shape, torch.Size([1, 128, 128]))\n        single_multi_sample_trait_result = execute_compose(\n            img, [multi_sample_transform, center_crop, flatten_sequence_transform]\n        )\n        self.assertIsInstance(single_multi_sample_trait_result, list)\n        self.assertEqual(len(single_multi_sample_trait_result), 1)\n        
self.assertEqual(single_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64]))\n\n        double_multi_sample_trait_result = execute_compose(\n            img, [multi_sample_transform, multi_sample_transform, flatten_sequence_transform, center_crop]\n        )\n        self.assertIsInstance(double_multi_sample_trait_result, list)\n        self.assertEqual(len(double_multi_sample_trait_result), 1)\n        self.assertEqual(double_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64]))\n\n    def test_multi_sample_trait_cardinality(self):\n        img = torch.zeros([1, 128, 128])\n        t2 = mt.RandSpatialCropSamples([32, 32], num_samples=2)\n        flatten_sequence_transform = mt.FlattenSequence()\n\n        # chaining should multiply counts: 2 x 2 = 4, flattened\n        res = execute_compose(img, [t2, t2, flatten_sequence_transform])\n        self.assertIsInstance(res, list)\n        self.assertEqual(len(res), 4)\n        for r in res:\n            self.assertEqual(r.shape, torch.Size([1, 32, 32]))\n\n\nTEST_COMPOSE_EXECUTE_TEST_CASES = [\n    [None, tuple()],\n    [None, (mt.Rotate(np.pi / 8),)],\n    [None, (mt.Flip(0), mt.Flip(1), mt.Rotate90(1), mt.Zoom(0.8), mt.NormalizeIntensity())],\n    [(\"a\",), (mt.Rotated((\"a\",), np.pi / 8),)],\n]\n\n\nclass TestComposeExecute(unittest.TestCase):\n\n    @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES)\n    def test_compose_execute_equivalence(self, keys, pipeline):\n        data = data_from_keys(keys, 12, 16)\n\n        expected = mt.Compose(deepcopy(pipeline))(data)\n\n        for cutoff in range(len(pipeline)):\n            c = mt.Compose(deepcopy(pipeline))\n            actual = c(c(data, end=cutoff), start=cutoff)\n            if isinstance(actual, dict):\n                for k in actual.keys():\n                    self.assertTrue(torch.allclose(expected[k], actual[k]))\n            else:\n                self.assertTrue(torch.allclose(expected, actual))\n\n            p = 
deepcopy(pipeline)\n            actual = execute_compose(execute_compose(data, p, start=0, end=cutoff), p, start=cutoff)\n            if isinstance(actual, dict):\n                for k in actual.keys():\n                    self.assertTrue(torch.allclose(expected[k], actual[k]))\n            else:\n                self.assertTrue(torch.allclose(expected, actual))\n\n    @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES)\n    def test_compose_execute_bad_start_param(self, keys, pipeline):\n        data = data_from_keys(keys, 12, 16)\n\n        c = mt.Compose(deepcopy(pipeline))\n        with self.assertRaises(ValueError):\n            c(data, start=None)\n        with self.assertRaises(ValueError):\n            c(data, start=None)\n\n        with self.assertRaises(ValueError):\n            execute_compose(data, deepcopy(pipeline), start=None)\n\n        c = mt.Compose(deepcopy(pipeline))\n        with self.assertRaises(ValueError):\n            c(data, start=-1)\n        with self.assertRaises(ValueError):\n            c(data, start=-1)\n\n        with self.assertRaises(ValueError):\n            execute_compose(data, deepcopy(pipeline), start=-1)\n\n    @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES)\n    def test_compose_execute_negative_range(self, keys, pipeline):\n        data = data_from_keys(keys, 12, 16)\n\n        with self.assertRaises(ValueError):\n            c = mt.Compose(deepcopy(pipeline))\n            c(data, start=2, end=1)\n\n        with self.assertRaises(ValueError):\n            execute_compose(data, deepcopy(pipeline), start=2, end=1)\n\n    @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES)\n    def test_compose_execute_bad_end_param(self, keys, pipeline):\n        data = data_from_keys(keys, 12, 16)\n\n        with self.assertRaises(ValueError):\n            c = mt.Compose(deepcopy(pipeline))\n            c(data, end=len(pipeline) + 1)\n\n        with self.assertRaises(ValueError):\n            execute_compose(data, 
deepcopy(pipeline), end=len(pipeline) + 1)\n\n    @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES)\n    def test_compose_execute_empty_range(self, keys, pipeline):\n        data = data_from_keys(keys, 12, 16)\n\n        c = mt.Compose(deepcopy(pipeline))\n        for i in range(len(pipeline)):\n            result = c(data, start=i, end=i)\n            self.assertIs(data, result)\n\n    @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES)\n    def test_compose_with_logger(self, keys, pipeline):\n        data = data_from_keys(keys, 12, 16)\n\n        c = mt.Compose(deepcopy(pipeline), log_stats=\"a_logger_name\")\n        c(data)\n\n\nTEST_COMPOSE_EXECUTE_LOGGING_TEST_CASES = [\n    [\n        None,\n        (mt.Flip(0), mt.Spacing((1.2, 1.2)), mt.Flip(1), mt.Rotate90(1), mt.Zoom(0.8), mt.NormalizeIntensity()),\n        False,\n        (\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'Flip', transform.lazy: False\\n\"\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'Spacing', transform.lazy: False\\n\"\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'Flip', transform.lazy: False\\n\"\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'Rotate90', transform.lazy: False\\n\"\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'Zoom', transform.lazy: False\\n\"\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'NormalizeIntensity', transform is not lazy\\n\"\n        ),\n    ],\n    [\n        None,\n        (\n            mt.Flip(0, lazy=True),\n            mt.Spacing((1.2, 1.2), lazy=True),\n            mt.Flip(1, lazy=True),\n            mt.Rotate90(1),\n            mt.Zoom(0.8, lazy=True),\n            mt.NormalizeIntensity(),\n        ),\n        None,\n 
       (\n            \"INFO - Accumulate pending transforms - lazy: None, pending: 0, \"\n            \"upcoming 'Flip', transform.lazy: True\\n\"\n            \"INFO - Accumulate pending transforms - lazy: None, pending: 1, \"\n            \"upcoming 'Spacing', transform.lazy: True\\n\"\n            \"INFO - Accumulate pending transforms - lazy: None, pending: 2, \"\n            \"upcoming 'Flip', transform.lazy: True\\n\"\n            \"INFO - Apply pending transforms - lazy: None, pending: 3, \"\n            \"upcoming 'Rotate90', transform.lazy: False\\n\"\n            \"INFO - Pending transforms applied: applied_operations: 3\\n\"\n            \"INFO - Accumulate pending transforms - lazy: None, pending: 0, \"\n            \"upcoming 'Zoom', transform.lazy: True\\n\"\n            \"INFO - Apply pending transforms - lazy: None, pending: 1, \"\n            \"upcoming 'NormalizeIntensity', transform is not lazy\\n\"\n            \"INFO - Pending transforms applied: applied_operations: 5\\n\"\n        ),\n    ],\n    [\n        None,\n        (mt.Flip(0), mt.Spacing((1.2, 1.2)), mt.Flip(1), mt.Rotate90(1), mt.Zoom(0.8), mt.NormalizeIntensity()),\n        True,\n        (\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 0, \"\n            \"upcoming 'Flip', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 1, \"\n            \"upcoming 'Spacing', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 2, \"\n            \"upcoming 'Flip', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 3, \"\n            \"upcoming 'Rotate90', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 4, \"\n            \"upcoming 'Zoom', transform.lazy: False (overridden)\\n\"\n            \"INFO - Apply pending 
transforms - lazy: True, pending: 5, \"\n            \"upcoming 'NormalizeIntensity', transform is not lazy\\n\"\n            \"INFO - Pending transforms applied: applied_operations: 5\\n\"\n        ),\n    ],\n    [\n        (\"a\", \"b\"),\n        (\n            mt.Flipd((\"a\", \"b\"), 0),\n            mt.Spacingd((\"a\", \"b\"), 1.2),\n            mt.Rotate90d((\"a\", \"b\"), 1),\n            mt.NormalizeIntensityd((\"a\",)),\n        ),\n        True,\n        (\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 0, \"\n            \"upcoming 'Flipd', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 0, \"\n            \"upcoming 'Flipd', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 1, \"\n            \"upcoming 'Spacingd', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 1, \"\n            \"upcoming 'Spacingd', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 2, \"\n            \"upcoming 'Rotate90d', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 2, \"\n            \"upcoming 'Rotate90d', transform.lazy: False (overridden)\\n\"\n            \"INFO - Apply pending transforms - lazy: True, key: 'a', pending: 3, \"\n            \"upcoming 'NormalizeIntensityd', transform is not lazy\\n\"\n            \"INFO - Pending transforms applied: key: 'a', applied_operations: 3\\n\"\n            \"INFO - Pending transforms applied: key: 'b', applied_operations: 3\\n\"\n        ),\n    ],\n    [\n        (\"a\", \"b\"),\n        (\n            mt.Flipd(keys=\"a\", spatial_axis=0),\n            mt.Rotate90d(keys=\"b\", k=1, allow_missing_keys=True),\n           
 mt.Zoomd(keys=(\"a\", \"b\"), zoom=0.8, allow_missing_keys=True),\n            mt.Spacingd(keys=\"a\", pixdim=1.2),\n        ),\n        True,\n        (\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 0, \"\n            \"upcoming 'Flipd', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 0, \"\n            \"upcoming 'Rotate90d', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 1, \"\n            \"upcoming 'Zoomd', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 1, \"\n            \"upcoming 'Zoomd', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 2, \"\n            \"upcoming 'Spacingd', transform.lazy: False (overridden)\\n\"\n            \"INFO - Pending transforms applied: key: 'a', applied_operations: 3\\n\"\n            \"INFO - Pending transforms applied: key: 'b', applied_operations: 2\\n\"\n        ),\n    ],\n    [\n        None,\n        (\n            mt.Flip(0),\n            mt.Spacing((1.2, 1.2)),\n            mt.Flip(1),\n            mt.ApplyPending(),\n            mt.Rotate90(1),\n            mt.Zoom(0.8),\n            mt.NormalizeIntensity(),\n        ),\n        False,\n        (\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'Flip', transform.lazy: False\\n\"\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'Spacing', transform.lazy: False\\n\"\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'Flip', transform.lazy: False\\n\"\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'ApplyPending', transform 
is not lazy\\n\"\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'Rotate90', transform.lazy: False\\n\"\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'Zoom', transform.lazy: False\\n\"\n            \"INFO - Apply pending transforms - lazy: False, pending: 0, \"\n            \"upcoming 'NormalizeIntensity', transform is not lazy\\n\"\n        ),\n    ],\n    [\n        None,\n        (\n            mt.Flip(0),\n            mt.Spacing((1.2, 1.2)),\n            mt.Flip(1),\n            mt.ApplyPending(),\n            mt.Rotate90(1),\n            mt.Zoom(0.8),\n            mt.NormalizeIntensity(),\n        ),\n        True,\n        (\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 0, \"\n            \"upcoming 'Flip', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 1, \"\n            \"upcoming 'Spacing', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 2, \"\n            \"upcoming 'Flip', transform.lazy: False (overridden)\\n\"\n            \"INFO - Apply pending transforms - lazy: True, pending: 3, \"\n            \"upcoming 'ApplyPending', transform is not lazy\\n\"\n            \"INFO - Pending transforms applied: applied_operations: 3\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 0, \"\n            \"upcoming 'Rotate90', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 1, \"\n            \"upcoming 'Zoom', transform.lazy: False (overridden)\\n\"\n            \"INFO - Apply pending transforms - lazy: True, pending: 2, \"\n            \"upcoming 'NormalizeIntensity', transform is not lazy\\n\"\n            \"INFO - Pending transforms applied: applied_operations: 5\\n\"\n        ),\n    ],\n    [\n       
 (\"a\", \"b\"),\n        (\n            mt.Flipd(keys=\"a\", spatial_axis=0),\n            mt.Rotate90d(keys=\"b\", k=1, allow_missing_keys=True),\n            mt.ApplyPendingd(keys=(\"a\", \"b\")),\n            mt.Zoomd(keys=(\"a\", \"b\"), zoom=0.8, allow_missing_keys=True),\n            mt.Spacingd(keys=\"a\", pixdim=1.2),\n        ),\n        True,\n        (\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 0, \"\n            \"upcoming 'Flipd', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 0, \"\n            \"upcoming 'Rotate90d', transform.lazy: False (overridden)\\n\"\n            \"INFO - Apply pending transforms - lazy: True, key: 'a', pending: 1, \"\n            \"upcoming 'ApplyPendingd', transform is not lazy\\n\"\n            \"INFO - Apply pending transforms - lazy: True, key: 'b', pending: 1, \"\n            \"upcoming 'ApplyPendingd', transform is not lazy\\n\"\n            \"INFO - Pending transforms applied: key: 'a', applied_operations: 1\\n\"\n            \"INFO - Pending transforms applied: key: 'b', applied_operations: 1\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 0, \"\n            \"upcoming 'Zoomd', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'b', pending: 0, \"\n            \"upcoming 'Zoomd', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, key: 'a', pending: 1, \"\n            \"upcoming 'Spacingd', transform.lazy: False (overridden)\\n\"\n            \"INFO - Pending transforms applied: key: 'a', applied_operations: 3\\n\"\n            \"INFO - Pending transforms applied: key: 'b', applied_operations: 2\\n\"\n        ),\n    ],\n]\n\nTEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES = [\n    [\n        mt.Compose,\n        (mt.Flip(0), mt.Spacing((1.2, 
1.2))),\n        True,\n        (\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 0, \"\n            \"upcoming 'Flip', transform.lazy: False (overridden)\\n\"\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 1, \"\n            \"upcoming 'Spacing', transform.lazy: False (overridden)\\n\"\n            \"INFO - Pending transforms applied: applied_operations: 2\\n\"\n        ),\n    ],\n    [\n        mt.SomeOf,\n        (mt.Flip(0),),\n        True,\n        (\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 0, \"\n            \"upcoming 'Flip', transform.lazy: False (overridden)\\n\"\n            \"INFO - Pending transforms applied: applied_operations: 1\\n\"\n        ),\n    ],\n    [\n        mt.RandomOrder,\n        (mt.Flip(0),),\n        True,\n        (\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 0, \"\n            \"upcoming 'Flip', transform.lazy: False (overridden)\\n\"\n            \"INFO - Pending transforms applied: applied_operations: 1\\n\"\n        ),\n    ],\n    [\n        mt.OneOf,\n        (mt.Flip(0),),\n        True,\n        (\n            \"INFO - Accumulate pending transforms - lazy: True, pending: 0, \"\n            \"upcoming 'Flip', transform.lazy: False (overridden)\\n\"\n            \"INFO - Pending transforms applied: applied_operations: 1\\n\"\n        ),\n    ],\n    [\n        mt.OneOf,\n        (mt.Flip(0),),\n        False,\n        (\"INFO - Apply pending transforms - lazy: False, pending: 0, \" \"upcoming 'Flip', transform.lazy: False\\n\"),\n    ],\n]\n\n\nclass TestComposeExecuteWithLogging(unittest.TestCase):\n    LOGGER_NAME = \"a_logger_name\"\n\n    def init_logger(self, name=LOGGER_NAME):\n        stream = StringIO()\n        handler = logging.StreamHandler(stream)\n        formatter = logging.Formatter(\"%(levelname)s - %(message)s\")\n        handler.setFormatter(formatter)\n        logger = 
logging.getLogger(name)\n        logger.setLevel(logging.INFO)\n        while len(logger.handlers) > 0:\n            logger.removeHandler(logger.handlers[-1])\n        logger.addHandler(handler)\n        return handler, stream\n\n    @parameterized.expand(TEST_COMPOSE_EXECUTE_LOGGING_TEST_CASES)\n    def test_compose_with_logging(self, keys, pipeline, lazy, expected):\n        handler, stream = self.init_logger(name=self.LOGGER_NAME)\n\n        data = data_from_keys(keys, 12, 16)\n        c = mt.Compose(deepcopy(pipeline), lazy=lazy, log_stats=self.LOGGER_NAME)\n        c(data)\n\n        handler.flush()\n        actual = stream.getvalue()\n        self.assertEqual(actual, expected)\n\n    @parameterized.expand(TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES)\n    def test_compose_lazy_on_call_with_logging(self, compose_type, pipeline, lazy_on_call, expected):\n        handler, stream = self.init_logger(name=self.LOGGER_NAME)\n\n        data = data_from_keys(None, 12, 16)\n        c = compose_type(deepcopy(pipeline), log_stats=self.LOGGER_NAME)\n        c(data, lazy=lazy_on_call)\n\n        handler.flush()\n        actual = stream.getvalue()\n        self.assertEqual(actual, expected)\n\n\nclass TestOps:\n\n    @staticmethod\n    def concat(value):\n\n        def _inner(data):\n            return data + value\n\n        return _inner\n\n    @staticmethod\n    def concatd(value):\n\n        def _inner(data):\n            return {k: v + value for k, v in data.items()}\n\n        return _inner\n\n    @staticmethod\n    def concata(value):\n\n        def _inner(data1, data2):\n            return data1 + value, data2 + value\n\n        return _inner\n\n\nTEST_COMPOSE_EXECUTE_FLAG_TEST_CASES = [\n    [{}, (\"\",), (TestOps.concat(\"a\"), TestOps.concat(\"b\"))],\n    [{\"unpack_items\": True}, (\"x\", \"y\"), (TestOps.concat(\"a\"), TestOps.concat(\"b\"))],\n    [{\"map_items\": False}, {\"x\": \"1\", \"y\": \"2\"}, (TestOps.concatd(\"a\"), TestOps.concatd(\"b\"))],\n    
[{\"unpack_items\": True, \"map_items\": False}, (\"x\", \"y\"), (TestOps.concata(\"a\"), TestOps.concata(\"b\"))],\n]\n\n\nclass TestComposeExecuteWithFlags(unittest.TestCase):\n\n    @parameterized.expand(TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES)\n    def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline):\n        expected = mt.Compose(pipeline, **flags)(data)\n\n        for cutoff in range(len(pipeline)):\n            c = mt.Compose(deepcopy(pipeline), **flags)\n            actual = c(c(data, end=cutoff), start=cutoff)\n            if isinstance(actual, dict):\n                for k in actual.keys():\n                    self.assertEqual(expected[k], actual[k])\n            else:\n                self.assertEqual(expected, actual)\n\n            p = deepcopy(pipeline)\n            actual = execute_compose(execute_compose(data, p, start=0, end=cutoff, **flags), p, start=cutoff, **flags)\n            if isinstance(actual, dict):\n                for k in actual.keys():\n                    self.assertEqual(expected[k], actual[k])\n            else:\n                self.assertEqual(expected, actual)\n\n\nclass TestComposeCallableInput(unittest.TestCase):\n\n    def test_value_error_when_not_sequence(self):\n        data = torch.tensor(np.random.randn(1, 5, 5))\n\n        xform = mt.Compose([mt.Flip(0), mt.Flip(0)])\n        res = xform(data)\n        np.testing.assert_allclose(data, res, atol=1e-3)\n\n        with self.assertRaises(ValueError):\n            mt.Compose(mt.Flip(0), mt.Flip(0))(data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/compose/test_some_of.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nimport monai.transforms.intensity.array as ia\nimport monai.transforms.spatial.array as sa\nimport monai.transforms.spatial.dictionary as sd\nfrom monai.data import MetaTensor\nfrom monai.transforms import TraceableTransform, Transform\nfrom monai.transforms.compose import Compose, SomeOf\nfrom monai.utils import set_determinism\nfrom monai.utils.enums import TraceKeys\nfrom tests.integration.test_one_of import NonInv\nfrom tests.transforms.test_random_order import InvC, InvD\n\n\nclass A(Transform):\n\n    def __call__(self, x):\n        return 2 * x\n\n\nclass B(Transform):\n\n    def __call__(self, x):\n        return 3 * x\n\n\nclass C(Transform):\n\n    def __call__(self, x):\n        return 5 * x\n\n\nclass D(Transform):\n\n    def __call__(self, x):\n        return 7 * x\n\n\nKEYS = [\"x\", \"y\"]\nTEST_COMPOUND = [\n    (SomeOf((A(), B(), C()), num_transforms=3), 2 * 3 * 5),\n    (Compose((SomeOf((A(), B(), C()), num_transforms=3), D())), 2 * 3 * 5 * 7),\n    (SomeOf((A(), B(), C(), Compose(D())), num_transforms=4), 2 * 3 * 5 * 7),\n    (SomeOf(()), 1),\n    (SomeOf(None), 1),\n]\n\n# Modified from RandomOrder\nTEST_INVERSES = [\n    (SomeOf((InvC(KEYS), InvD(KEYS))), True, True),\n    
(Compose((SomeOf((InvC(KEYS), InvD(KEYS))), SomeOf((InvD(KEYS), InvC(KEYS))))), True, False),\n    (SomeOf((SomeOf((InvC(KEYS), InvD(KEYS))), SomeOf((InvD(KEYS), InvC(KEYS))))), True, False),\n    (SomeOf((Compose((InvC(KEYS), InvD(KEYS))), Compose((InvD(KEYS), InvC(KEYS))))), True, False),\n    (SomeOf((NonInv(KEYS), NonInv(KEYS))), False, False),\n    (SomeOf(()), False, False),\n]\n\n\nclass TestSomeOf(unittest.TestCase):\n\n    def setUp(self):\n        set_determinism(seed=0)\n\n    def tearDown(self):\n        set_determinism(None)\n\n    def update_transform_count(self, counts, output):\n        op_count = 0\n\n        if output % 2 == 0:\n            counts[0] += 1\n            op_count += 1\n        if output % 3 == 0:\n            counts[1] += 1\n            op_count += 1\n        if output % 5 == 0:\n            counts[2] += 1\n            op_count += 1\n\n        return op_count\n\n    def test_fixed(self):\n        iterations = 10000\n        num_transforms = 3\n        transform_counts = 3 * [0]\n        subset_size_counts = 4 * [0]\n\n        s = SomeOf((A(), B(), C()), num_transforms=num_transforms)\n\n        for _ in range(iterations):\n            output = s(1)\n            subset_size = self.update_transform_count(transform_counts, output)\n            subset_size_counts[subset_size] += 1\n\n        for i in range(3):\n            self.assertEqual(transform_counts[i], iterations)\n\n        for i in range(3):\n            self.assertEqual(subset_size_counts[i], 0)\n\n        self.assertEqual(subset_size_counts[3], iterations)\n\n    def test_unfixed(self):\n        iterations = 10000\n        num_transforms = (0, 3)\n        transform_counts = 3 * [0]\n        subset_size_counts = 4 * [0]\n\n        s = SomeOf((A(), B(), C()), num_transforms=num_transforms)\n\n        for _ in range(iterations):\n            output = s(1)\n            subset_size = self.update_transform_count(transform_counts, output)\n            subset_size_counts[subset_size] 
+= 1\n\n        for i in range(3):\n            self.assertAlmostEqual(transform_counts[i] / iterations, 0.5, delta=0.01)\n\n        for i in range(4):\n            self.assertAlmostEqual(subset_size_counts[i] / iterations, 0.25, delta=0.01)\n\n    def test_non_dict_metatensor(self):\n        data = MetaTensor(1)\n        s = SomeOf([A()], num_transforms=1)\n        out = s(data)\n        self.assertEqual(out, 2)\n        inv = s.inverse(out)  # A() is not invertible, nothing happens\n        self.assertEqual(inv, 2)\n\n    @parameterized.expand(TEST_COMPOUND)\n    def test_compound_pipeline(self, transform, expected_value):\n        output = transform(1)\n        self.assertEqual(output, expected_value)\n\n    # Modified from RandomOrder\n    @parameterized.expand(TEST_INVERSES)\n    def test_inverse(self, transform, invertible, use_metatensor):\n        data = {k: (i + 1) * 10.0 if not use_metatensor else MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)}\n        fwd_data1 = transform(data)\n        # test call twice won't affect inverse\n        fwd_data2 = transform(data)\n\n        if invertible:\n            for k in KEYS:\n                t = (\n                    fwd_data1[TraceableTransform.trace_key(k)][-1]\n                    if not use_metatensor\n                    else fwd_data1[k].applied_operations[-1]\n                )\n                # make sure the SomeOf applied_order was stored\n                self.assertEqual(t[TraceKeys.CLASS_NAME], SomeOf.__name__)\n\n        # call the inverse\n        fwd_inv_data1 = transform.inverse(fwd_data1)\n        fwd_inv_data2 = transform.inverse(fwd_data2)\n\n        fwd_data = [fwd_data1, fwd_data2]\n        fwd_inv_data = [fwd_inv_data1, fwd_inv_data2]\n        for i, _fwd_inv_data in enumerate(fwd_inv_data):\n            if invertible:\n                for k in KEYS:\n                    # check transform was removed\n                    if not use_metatensor:\n                        
self.assertTrue(\n                            len(_fwd_inv_data[TraceableTransform.trace_key(k)])\n                            < len(fwd_data[i][TraceableTransform.trace_key(k)])\n                        )\n                    # check data is same as original (and different from forward)\n                    self.assertEqual(_fwd_inv_data[k], data[k])\n                    self.assertNotEqual(_fwd_inv_data[k], fwd_data[i][k])\n            else:\n                # if not invertible, should not change the data\n                self.assertDictEqual(fwd_data[i], _fwd_inv_data)\n\n    def test_bad_inverse_data(self):\n        tr = SomeOf((A(), B(), C()), num_transforms=1, weights=(1, 2, 1))\n        self.assertRaises(RuntimeError, tr.inverse, [])\n\n    def test_normalize_weights(self):\n        tr = SomeOf((A(), B(), C()), num_transforms=1, weights=(1, 2, 1))\n        self.assertTupleEqual(tr.weights, (0.25, 0.5, 0.25))\n\n        tr = SomeOf((), num_transforms=1, weights=(1, 2, 1))\n        self.assertIsNone(tr.weights)\n\n    def test_no_weights_arg(self):\n        tr = SomeOf((A(), B(), C(), D()), num_transforms=1)\n        self.assertIsNone(tr.weights)\n\n    def test_bad_weights(self):\n        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=1, weights=(1, 2))\n        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=1, weights=(0, 0, 0))\n        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=1, weights=(-1, 1, 1))\n\n    def test_bad_num_transforms(self):\n        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=(-1, 2))\n        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=\"str\")\n        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=(1, 2, 3))\n        self.assertRaises(ValueError, SomeOf, (A(), B(), C()), num_transforms=(\"a\", 1))\n\n\nTEST_SOMEOF_EXTENDED_TEST_CASES = [\n    [None, tuple()],\n    [None, (sa.Rotate(np.pi / 
8),)],\n    [None, (sa.Flip(0), sa.Flip(1), sa.Rotate90(1), sa.Zoom(0.8), ia.NormalizeIntensity())],\n    [(\"a\",), (sd.Rotated((\"a\",), np.pi / 8),)],\n]\n\n\nclass TestSomeOfAPITests(unittest.TestCase):\n\n    @staticmethod\n    def data_from_keys(keys):\n        if keys is None:\n            data = torch.unsqueeze(torch.tensor(np.arange(12 * 16).reshape(12, 16)), dim=0)\n        else:\n            data = {}\n            for i_k, k in enumerate(keys):\n                data[k] = torch.unsqueeze(torch.tensor(np.arange(12 * 16)).reshape(12, 16) + i_k * 192, dim=0)\n        return data\n\n    @parameterized.expand(TEST_SOMEOF_EXTENDED_TEST_CASES)\n    def test_execute_change_start_end(self, keys, pipeline):\n        data = self.data_from_keys(keys)\n\n        c = SomeOf(deepcopy(pipeline))\n        with self.assertRaises(ValueError):\n            c(data, start=1)\n        with self.assertRaises(ValueError):\n            c(data, start=1)\n\n        c = SomeOf(deepcopy(pipeline))\n        with self.assertRaises(ValueError):\n            c(data, end=1)\n        with self.assertRaises(ValueError):\n            c(data, end=1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/croppad/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/transforms/croppad/test_pad_nd_dtypes.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nTests for pad_nd dtype support and backend selection.\nValidates PyTorch padding preference and NumPy fallback behavior.\n\"\"\"\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest.mock import Mock, patch\n\nimport torch\nfrom parameterized.parameterized import parameterized\n\nimport monai.transforms.croppad.functional as F\nfrom monai.transforms.croppad.functional import pad_nd\n\nDTYPES = [torch.bool, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, torch.float32]\nMODES_DTYPES = [\n    (\"constant\", torch.bool),\n    (\"constant\", torch.int8),\n    (\"constant\", torch.float32),\n    (\"reflect\", torch.bool),\n    (\"reflect\", torch.int8),\n    (\"reflect\", torch.float32),\n    (\"replicate\", torch.bool),\n    (\"replicate\", torch.int8),\n    (\"replicate\", torch.float32),\n]\n\n\nclass TestPadNdDtypes(unittest.TestCase):\n    def test_pad_uses_pt_for_bool(self):\n        \"\"\"Test that pad_nd uses PyTorch backend for bool dtype in constant mode.\"\"\"\n        img = torch.ones((1, 4, 4), dtype=torch.bool)\n        to_pad = [(0, 0), (1, 1), (2, 2)]\n        with (\n            patch.object(F, \"_pt_pad\", wraps=F._pt_pad) as mock_pt,\n            patch.object(F, \"_np_pad\", wraps=F._np_pad) as mock_np,\n        ):\n            out = pad_nd(img, to_pad, mode=\"constant\", value=0)\n\n        
self.assertTrue(mock_pt.called)\n        self.assertFalse(mock_np.called)\n        self.assertEqual(out.dtype, img.dtype)\n        self.assertEqual(out.shape, (1, 6, 8))\n\n    def test_pad_falls_back_to_np_if_pt_raises(self):\n        \"\"\"Test that pad_nd falls back to NumPy when PyTorch raises NotImplementedError.\"\"\"\n        img = torch.ones((1, 4, 4), dtype=torch.bool)\n        to_pad = [(0, 0), (1, 1), (2, 2)]\n        with (\n            patch.object(F, \"_pt_pad\", new=Mock(side_effect=NotImplementedError(\"no\"))) as mock_pt,\n            patch.object(F, \"_np_pad\", wraps=F._np_pad) as mock_np,\n        ):\n            out = pad_nd(img, to_pad, mode=\"constant\", value=0)\n\n        self.assertTrue(mock_pt.called)\n        self.assertTrue(mock_np.called)\n        self.assertEqual(out.dtype, img.dtype)\n        self.assertEqual(out.shape, (1, 6, 8))\n\n    @parameterized.expand(DTYPES)\n    def test_pad_dtype_no_error_and_dtype_preserved(self, dtype):\n        \"\"\"Test that pad_nd handles various dtypes without error and preserves dtype.\n        Args:\n            dtype: Input dtype under test.\n        \"\"\"\n        img = torch.ones((1, 4, 4), dtype=dtype)\n        to_pad = [(0, 0), (1, 1), (2, 2)]\n        out = pad_nd(img, to_pad, mode=\"constant\", value=0)\n\n        self.assertEqual(out.shape, (1, 6, 8))\n        self.assertEqual(out.dtype, img.dtype)\n\n    @parameterized.expand(MODES_DTYPES)\n    def test_pad_multiple_modes_dtype_preserved(self, mode, dtype):\n        \"\"\"Test that pad_nd preserves dtype across multiple padding modes.\n        Args:\n            mode: Padding mode under test.\n            dtype: Input dtype under test.\n        \"\"\"\n        img = torch.ones((1, 4, 4), dtype=dtype)\n        to_pad = [(0, 0), (1, 1), (2, 2)]\n\n        kwargs = {\"value\": 0} if mode == \"constant\" else {}\n        out = pad_nd(img, to_pad, mode=mode, **kwargs)\n\n        self.assertEqual(out.shape, (1, 6, 8))\n        
self.assertEqual(out.dtype, img.dtype)\n\n    def test_value_with_non_constant_mode_raises(self):\n        \"\"\"Test that pad_nd raises ValueError when 'value' is provided with non-constant mode.\"\"\"\n        img = torch.ones((1, 4, 4))\n        to_pad = [(0, 0), (1, 1), (2, 2)]\n        with self.assertRaises(ValueError):\n            pad_nd(img, to_pad, mode=\"reflect\", value=0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/croppad/test_rand_weighted_crop.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized.parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms.croppad.array import RandWeightedCrop\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.croppers import CropTest\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose\n\n\ndef get_data(ndim):\n    im_gen = NumpyImageTestCase2D() if ndim == 2 else NumpyImageTestCase3D()\n    im_gen.setUp()\n    return im_gen.imt[0], im_gen.seg1[0], im_gen.segn[0]\n\n\nIMT_2D, SEG1_2D, SEGN_2D = get_data(ndim=2)\nIMT_3D, SEG1_3D, SEGN_3D = get_data(ndim=3)\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    for q in TEST_NDARRAYS_ALL:\n        im = SEG1_2D\n        weight = np.zeros_like(im)\n        weight[0, 30, 17] = 1.1\n        weight[0, 40, 31] = 1\n        weight[0, 80, 21] = 1\n        TESTS.append(\n            [\n                \"small roi 2d\",\n                dict(spatial_size=(10, 12), num_samples=3),\n                p(im),\n                q(weight),\n                (1, 10, 12),\n                [[80, 21], [30, 17], [40, 31]],\n            ]\n        )\n        im = IMT_2D\n        TESTS.append(\n            [\n                \"default roi 2d\",\n                dict(spatial_size=(10, -1), 
num_samples=3),\n                p(im),\n                q(weight),\n                (1, 10, 64),\n                [[14, 32], [105, 32], [20, 32]],\n            ]\n        )\n        im = SEGN_2D\n        weight = np.zeros_like(im)\n        weight[0, 30, 17] = 1.1\n        weight[0, 10, 1] = 1\n        TESTS.append(\n            [\n                \"large roi 2d\",\n                dict(spatial_size=(10000, 400), num_samples=3),\n                p(im),\n                q(weight),\n                (1, 128, 64),\n                [[64, 32], [64, 32], [64, 32]],\n            ]\n        )\n        im = IMT_2D\n        weight = np.zeros_like(im)\n        weight[0, 30, 17] = np.inf\n        weight[0, 10, 1] = -np.inf\n        weight[0, 10, 20] = -np.nan\n        TESTS.append(\n            [\n                \"bad w 2d\",\n                dict(spatial_size=(20, 40), num_samples=3),\n                p(im),\n                q(weight),\n                (1, 20, 40),\n                [[63, 37], [31, 43], [66, 20]],\n            ]\n        )\n        im = SEG1_2D\n        weight_map = np.zeros_like(im, dtype=np.int32)\n        weight_map[0, 30, 20] = 3\n        weight_map[0, 45, 44] = 1\n        weight_map[0, 60, 50] = 2\n        TESTS.append(\n            [\n                \"int w 2d\",\n                dict(spatial_size=(10, 12), num_samples=3),\n                p(im),\n                q(weight_map),\n                (1, 10, 12),\n                [[60, 50], [30, 20], [45, 44]],\n            ]\n        )\n        im = SEG1_3D\n        weight = np.zeros_like(im)\n        weight[0, 5, 30, 17] = 1.1\n        weight[0, 8, 40, 31] = 1\n        weight[0, 11, 23, 21] = 1\n        TESTS.append(\n            [\n                \"small roi 3d\",\n                dict(spatial_size=(8, 10, 12), num_samples=3),\n                p(im),\n                q(weight),\n                (1, 8, 10, 12),\n                [[11, 23, 21], [5, 30, 17], [8, 40, 31]],\n            ]\n        )\n        im 
= IMT_3D\n        weight = np.zeros_like(im)\n        weight[0, 7, 17] = 1.1\n        weight[0, 13, 31] = 1.1\n        weight[0, 24, 21] = 1\n        TESTS.append(\n            [\n                \"default roi 3d\",\n                dict(spatial_size=(10, -1, -1), num_samples=3),\n                p(im),\n                q(weight),\n                (1, 10, 48, 80),\n                [[14, 24, 40], [41, 24, 40], [20, 24, 40]],\n            ]\n        )\n        im = SEGN_3D\n        weight = np.zeros_like(im)\n        weight[0, 30, 17, 20] = 1.1\n        weight[0, 10, 1, 17] = 1\n        TESTS.append(\n            [\n                \"large roi 3d\",\n                dict(spatial_size=(10000, 400, 80), num_samples=3),\n                p(im),\n                q(weight),\n                (1, 64, 48, 80),\n                [[32, 24, 40], [32, 24, 40], [32, 24, 40]],\n            ]\n        )\n        im = IMT_3D\n        weight = np.zeros_like(im)\n        weight[0, 30, 17] = np.inf\n        weight[0, 10, 1] = -np.inf\n        weight[0, 10, 20] = -np.nan\n        TESTS.append(\n            [\n                \"bad w 3d\",\n                dict(spatial_size=(64, 48, 80), num_samples=3),\n                p(im),\n                q(weight),\n                (1, 64, 48, 80),\n                [[32, 24, 40], [32, 24, 40], [32, 24, 40]],\n            ]\n        )\n        im = SEG1_3D\n        weight_map = np.zeros_like(im, dtype=np.int32)\n        weight_map[0, 6, 22, 19] = 4\n        weight_map[0, 8, 40, 31] = 2\n        weight_map[0, 13, 20, 24] = 3\n        TESTS.append(\n            [\n                \"int w 3d\",\n                dict(spatial_size=(8, 10, 12), num_samples=3),\n                p(im),\n                q(weight_map),\n                (1, 8, 10, 12),\n                [[13, 20, 24], [6, 22, 19], [8, 40, 31]],\n            ]\n        )\n\n\nclass TestRandWeightedCrop(CropTest):\n    Cropper = RandWeightedCrop\n\n    @parameterized.expand(TESTS)\n    def 
test_rand_weighted_crop(self, _, input_params, img, weight, expected_shape, expected_vals):\n        crop = RandWeightedCrop(**input_params)\n        crop.set_random_state(10)\n        result = crop(img, weight)\n        self.assertTrue(len(result) == input_params[\"num_samples\"])\n        assert_allclose(result[0].shape, expected_shape)\n        for c, e in zip(crop.centers, expected_vals):\n            assert_allclose(c, e, type_test=False)\n        # if desired ROI is larger than image, check image is unchanged\n        if all(s >= i for i, s in zip(img.shape[1:], input_params[\"spatial_size\"])):\n            for res in result:\n                assert_allclose(res, img, type_test=\"tensor\")\n                self.assertEqual(len(res.applied_operations), 1)\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, _, input_param, img, weight, expected_shape, expected_vals):\n        crop = RandWeightedCrop(**input_param)\n        # non-lazy\n        crop.set_random_state(10)\n        expected = crop(img, weight)\n        self.assertIsInstance(expected[0], MetaTensor)\n        # lazy\n        crop.set_random_state(10)\n        crop.lazy = True\n        pending_result = crop(img, weight)\n        for i, _pending_result in enumerate(pending_result):\n            self.assertIsInstance(_pending_result, MetaTensor)\n            assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine)\n            assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:])\n            # only support nearest\n            result = apply_pending(_pending_result, overrides={\"mode\": \"nearest\", \"align_corners\": False})[0]\n            # compare\n            assert_allclose(result, expected[i], rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/croppad/test_rand_weighted_cropd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms.croppad.dictionary import RandWeightedCropd\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose\n\n\ndef get_data(ndim):\n    im_gen = NumpyImageTestCase2D() if ndim == 2 else NumpyImageTestCase3D()\n    im_gen.setUp()\n    return im_gen.imt[0], im_gen.seg1[0], im_gen.segn[0]\n\n\nIMT_2D, SEG1_2D, SEGN_2D = get_data(ndim=2)\nIMT_3D, SEG1_3D, SEGN_3D = get_data(ndim=3)\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    for q in TEST_NDARRAYS_ALL:\n        im = IMT_2D\n        weight = np.zeros_like(im)\n        weight[0, 30, 17] = 1.1\n        weight[0, 40, 31] = 1\n        weight[0, 80, 21] = 1\n        TESTS.append(\n            [\n                \"small roi 2d\",\n                dict(keys=\"img\", w_key=\"w\", spatial_size=(10, 12), num_samples=3),\n                {\"img\": p(im), \"w\": q(weight)},\n                (1, 10, 12),\n                [[80, 21], [30, 17], [40, 31]],\n            ]\n        )\n\n        weight = np.zeros_like(im)\n        weight[0, 30, 17] = 1.1\n        weight[0, 40, 31] = 1\n        weight[0, 80, 21] = 1\n        TESTS.append(\n            
[\n                \"default roi 2d\",\n                dict(keys=\"img\", w_key=\"w\", spatial_size=(10, -1), num_samples=3),\n                {\"img\": p(im), \"w\": q(weight), \"others\": np.nan},\n                (1, 10, 64),\n                [[14, 32], [105, 32], [20, 32]],\n            ]\n        )\n\n        weight = np.zeros_like(im)\n        weight[0, 30, 17] = 1.1\n        weight[0, 10, 1] = 1\n        TESTS.append(\n            [\n                \"large roi 2d\",\n                dict(keys=(\"img\", \"seg\"), w_key=\"weight\", spatial_size=(10000, 400), num_samples=3),\n                {\"img\": p(im), \"seg\": p(SEGN_2D), \"weight\": q(weight)},\n                (1, 128, 64),\n                [[64, 32], [64, 32], [64, 32]],\n            ]\n        )\n\n        weight = np.zeros_like(im)\n        weight[0, 30, 17] = np.inf\n        weight[0, 10, 1] = -np.inf\n        weight[0, 10, 20] = -np.nan\n        TESTS.append(\n            [\n                \"bad w roi 2d\",\n                dict(keys=(\"img\", \"seg\"), w_key=\"w\", spatial_size=(20, 40), num_samples=3),\n                {\"img\": p(im), \"seg\": p(SEGN_2D), \"w\": q(weight)},\n                (1, 20, 40),\n                [[63, 37], [31, 43], [66, 20]],\n            ]\n        )\n\n        im = IMT_3D\n        weight = np.zeros_like(im)\n        weight[0, 5, 30, 17] = 1.1\n        weight[0, 8, 40, 31] = 1\n        weight[0, 11, 23, 21] = 1\n        TESTS.append(\n            [\n                \"small roi 3d\",\n                dict(keys=\"img\", w_key=\"w\", spatial_size=(8, 10, 12), num_samples=3),\n                {\"img\": p(im), \"w\": q(weight)},\n                (1, 8, 10, 12),\n                [[11, 23, 21], [5, 30, 17], [8, 40, 31]],\n            ]\n        )\n\n        weight = np.zeros_like(im)\n        weight[0, 5, 30, 17] = 1.1\n        weight[0, 8, 40, 31] = 1\n        weight[0, 11, 23, 21] = 1\n        TESTS.append(\n            [\n                \"default roi 3d\",\n           
     dict(keys=(\"img\", \"seg\"), w_key=\"w\", spatial_size=(10, -1, -1), num_samples=3),\n                {\"img\": p(im), \"seg\": p(SEGN_3D), \"w\": q(weight)},\n                (1, 10, 64, 80),\n                [[14, 32, 40], [41, 32, 40], [20, 32, 40]],\n            ]\n        )\n\n        weight = np.zeros_like(im)\n        weight[0, 30, 17, 20] = 1.1\n        weight[0, 10, 1, 17] = 1\n        TESTS.append(\n            [\n                \"large roi 3d\",\n                dict(keys=\"img\", w_key=\"w\", spatial_size=(10000, 400, 80), num_samples=3),\n                {\"img\": p(im), \"w\": q(weight)},\n                (1, 48, 64, 80),\n                [[24, 32, 40], [24, 32, 40], [24, 32, 40]],\n            ]\n        )\n\n        weight = np.zeros_like(im)\n        weight[0, 30, 17] = np.inf\n        weight[0, 10, 1] = -np.inf\n        weight[0, 10, 20] = -np.nan\n        TESTS.append(\n            [\n                \"bad w roi 3d\",\n                dict(keys=(\"img\", \"seg\"), w_key=\"w\", spatial_size=(48, 64, 80), num_samples=3),\n                {\"img\": p(im), \"seg\": p(SEGN_3D), \"w\": q(weight)},\n                (1, 48, 64, 80),\n                [[24, 32, 40], [24, 32, 40], [24, 32, 40]],\n            ]\n        )\n\n\nclass TestRandWeightedCrop(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_rand_weighted_cropd(self, _, init_params, input_data, expected_shape, expected_centers):\n        crop = RandWeightedCropd(**init_params)\n        crop.set_random_state(10)\n        result = crop(input_data)\n        self.assertEqual(len(result), init_params[\"num_samples\"])\n        _len = len(tuple(input_data.keys()))\n        self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys()))\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, _, input_param, input_data, expected_shape, expected_centers):\n        crop = RandWeightedCropd(**input_param)\n        # non-lazy\n        
crop.set_random_state(10)\n        expected = crop(input_data)\n        self.assertIsInstance(expected[0][\"img\"], MetaTensor)\n        # lazy\n        crop.set_random_state(10)\n        crop.lazy = True\n        pending_result = crop(input_data)\n        for i, _pending_result in enumerate(pending_result):\n            self.assertIsInstance(_pending_result[\"img\"], MetaTensor)\n            assert_allclose(_pending_result[\"img\"].peek_pending_affine(), expected[i][\"img\"].affine)\n            assert_allclose(_pending_result[\"img\"].peek_pending_shape(), expected[i][\"img\"].shape[1:])\n            # only support nearest\n            result = apply_pending(_pending_result[\"img\"], overrides={\"mode\": \"nearest\", \"align_corners\": False})[0]\n            # compare\n            assert_allclose(result, expected[i][\"img\"], rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/functional/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/transforms/functional/test_apply.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms.lazy.functional import apply_pending\nfrom monai.transforms.utils import create_rotate\nfrom monai.utils import LazyAttr, convert_to_tensor\nfrom tests.test_utils import get_arange_img\n\n\ndef single_2d_transform_cases():\n    return [\n        (\n            torch.as_tensor(get_arange_img((32, 32))),\n            [{LazyAttr.AFFINE: create_rotate(2, np.pi / 4)}, {LazyAttr.AFFINE: create_rotate(2, -np.pi / 4)}],\n            (1, 32, 32),\n        ),\n        (torch.as_tensor(get_arange_img((32, 32))), [create_rotate(2, np.pi / 2)], (1, 32, 32)),\n        (\n            torch.as_tensor(get_arange_img((16, 16))),\n            [{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (45, 45)}],\n            (1, 45, 45),\n        ),\n    ]\n\n\nclass TestApply(unittest.TestCase):\n    def _test_apply_impl(self, tensor, pending_transforms, expected_shape):\n        result = apply_pending(tensor, pending_transforms)\n        self.assertListEqual(result[1], pending_transforms)\n        self.assertEqual(result[0].shape, expected_shape)\n\n    def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape, pending_as_parameter):\n        tensor_ = convert_to_tensor(tensor, track_meta=True)\n        if pending_as_parameter:\n            result, 
transforms = apply_pending(tensor_, pending_transforms)\n        else:\n            for p in pending_transforms:\n                tensor_.push_pending_operation(p)\n                if not isinstance(p, dict):\n                    return\n            result, transforms = apply_pending(tensor_)\n        self.assertEqual(result.shape, expected_shape)\n\n    SINGLE_TRANSFORM_CASES = single_2d_transform_cases()\n\n    def test_apply_single_transform(self):\n        for case in self.SINGLE_TRANSFORM_CASES:\n            self._test_apply_impl(*case)\n\n    def test_apply_single_transform_metatensor(self):\n        for case in self.SINGLE_TRANSFORM_CASES:\n            self._test_apply_metatensor_impl(*case, False)\n\n    def test_apply_single_transform_metatensor_override(self):\n        for case in self.SINGLE_TRANSFORM_CASES:\n            self._test_apply_metatensor_impl(*case, True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/functional/test_resample.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.lazy.functional import resample\nfrom monai.utils import convert_to_tensor\nfrom tests.test_utils import assert_allclose, get_arange_img\n\n\ndef rotate_90_2d():\n    t = torch.eye(3)\n    t[:, 0] = torch.FloatTensor([0, -1, 0])\n    t[:, 1] = torch.FloatTensor([1, 0, 0])\n    return t\n\n\nRESAMPLE_FUNCTION_CASES = [\n    (get_arange_img((3, 3)), rotate_90_2d(), [[0, 3, 6], [0, 3, 6], [0, 3, 6]]),\n    (get_arange_img((3, 3)), torch.eye(3), get_arange_img((3, 3))[0]),\n]\n\n\nclass TestResampleFunction(unittest.TestCase):\n    @parameterized.expand(RESAMPLE_FUNCTION_CASES)\n    def test_resample_function_impl(self, img, matrix, expected):\n        out = resample(convert_to_tensor(img), matrix, {\"lazy_shape\": img.shape[1:], \"lazy_padding_mode\": \"border\"})\n        assert_allclose(out[0], expected, type_test=False)\n\n        img = convert_to_tensor(img, dtype=torch.uint8)\n        out = resample(img, matrix, {\"lazy_resample_mode\": \"auto\", \"lazy_dtype\": torch.float})\n        out_1 = resample(img, matrix, {\"lazy_resample_mode\": \"other value\", \"lazy_dtype\": torch.float})\n        self.assertIs(out.dtype, out_1.dtype)  # testing dtype in different lazy_resample_mode\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/intensity/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/transforms/intensity/test_compute_ho_ver_maps.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.intensity.array import ComputeHoVerMaps\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_skimage = optional_import(\"skimage\", \"0.19.0\", min_version)\n\nINSTANCE_MASK = np.zeros((1, 16, 16), dtype=\"int16\")\nINSTANCE_MASK[:, 5:8, 4:11] = 1\nINSTANCE_MASK[:, 3:5, 6:9] = 1\nINSTANCE_MASK[:, 8:10, 6:9] = 1\nINSTANCE_MASK[:, 13:, 13:] = 2\nH_MAP = torch.zeros((16, 16), dtype=torch.float32)\nH_MAP[5:8, 4] = -1.0\nH_MAP[5:8, 5] = -2.0 / 3.0\nH_MAP[3:10, 6] = -1.0 / 3.0\nH_MAP[3:10, 7] = 0.0\nH_MAP[3:10, 8] = 1.0 / 3.0\nH_MAP[5:8, 9] = 2.0 / 3.0\nH_MAP[5:8, 10] = 1.0\nH_MAP[13:, 13] = -1.0\nH_MAP[13:, 14] = 0.0\nH_MAP[13:, 15] = 1.0\nV_MAP = torch.zeros((16, 16), dtype=torch.float32)\nV_MAP[3, 6:9] = -1.0\nV_MAP[4, 6:9] = -2.0 / 3.0\nV_MAP[5, 4:11] = -1.0 / 3.0\nV_MAP[6, 4:11] = 0.0\nV_MAP[7, 4:11] = 1.0 / 3.0\nV_MAP[8, 6:9] = 2.0 / 3.0\nV_MAP[9, 6:9] = 1.0\nV_MAP[13, 13:] = -1.0\nV_MAP[14, 13:] = 0.0\nV_MAP[15, 13:] = 1.0\nHV_MAPS = torch.stack([H_MAP, V_MAP])\nTEST_CASE_0 = [{}, INSTANCE_MASK, HV_MAPS]\nTEST_CASE_1 = [{\"dtype\": \"float64\"}, INSTANCE_MASK, HV_MAPS]\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([p, *TEST_CASE_0])\n   
 TESTS.append([p, *TEST_CASE_1])\n\n\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\nclass ComputeHoVerMapsTests(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask):\n        input_image = in_type(mask)\n        result = ComputeHoVerMaps(**arguments)(input_image)\n        self.assertIsInstance(result, torch.Tensor)\n        self.assertEqual(str(result.dtype).split(\".\")[1], arguments.get(\"dtype\", \"float32\"))\n        assert_allclose(result, hv_mask, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/intensity/test_compute_ho_ver_maps_d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.intensity.dictionary import ComputeHoVerMapsd\nfrom monai.utils import min_version, optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_skimage = optional_import(\"skimage\", \"0.19.0\", min_version)\n\nINSTANCE_MASK = np.zeros((1, 16, 16), dtype=\"int16\")\nINSTANCE_MASK[:, 5:8, 4:11] = 1\nINSTANCE_MASK[:, 3:5, 6:9] = 1\nINSTANCE_MASK[:, 8:10, 6:9] = 1\nINSTANCE_MASK[:, 13:, 13:] = 2\nH_MAP = torch.zeros((16, 16), dtype=torch.float32)\nH_MAP[5:8, 4] = -1.0\nH_MAP[5:8, 5] = -2.0 / 3.0\nH_MAP[3:10, 6] = -1.0 / 3.0\nH_MAP[3:10, 7] = 0.0\nH_MAP[3:10, 8] = 1.0 / 3.0\nH_MAP[5:8, 9] = 2.0 / 3.0\nH_MAP[5:8, 10] = 1.0\nH_MAP[13:, 13] = -1.0\nH_MAP[13:, 14] = 0.0\nH_MAP[13:, 15] = 1.0\nV_MAP = torch.zeros((16, 16), dtype=torch.float32)\nV_MAP[3, 6:9] = -1.0\nV_MAP[4, 6:9] = -2.0 / 3.0\nV_MAP[5, 4:11] = -1.0 / 3.0\nV_MAP[6, 4:11] = 0.0\nV_MAP[7, 4:11] = 1.0 / 3.0\nV_MAP[8, 6:9] = 2.0 / 3.0\nV_MAP[9, 6:9] = 1.0\nV_MAP[13, 13:] = -1.0\nV_MAP[14, 13:] = 0.0\nV_MAP[15, 13:] = 1.0\nHV_MAPS = torch.stack([H_MAP, V_MAP])\nTEST_CASE_0 = [{}, {\"mask\": INSTANCE_MASK}, {\"hover_mask\": HV_MAPS}]\nTEST_CASE_1 = [{\"dtype\": \"float64\"}, {\"mask\": INSTANCE_MASK}, {\"hover_mask\": HV_MAPS}]\nTEST_CASE_1 
= [{\"dtype\": \"float64\", \"new_key_prefix\": \"\"}, {\"mask\": INSTANCE_MASK}, {\"mask\": HV_MAPS}]\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([p, *TEST_CASE_0])\n    TESTS.append([p, *TEST_CASE_1])\n\n\n@unittest.skipUnless(has_skimage, \"Requires scikit-image library.\")\nclass ComputeHoVerMapsDictTests(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask):\n        hv_key = list(hv_mask.keys())[0]\n        input_image = {}\n        for k in mask.keys():\n            input_image[k] = in_type(mask[k])\n        result = ComputeHoVerMapsd(keys=\"mask\", **arguments)(input_image)[hv_key]\n        self.assertIsInstance(result, torch.Tensor)\n        self.assertEqual(str(result.dtype).split(\".\")[1], arguments.get(\"dtype\", \"float32\"))\n        assert_allclose(result, hv_mask[hv_key], type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/intensity/test_foreground_mask.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms.intensity.array import ForegroundMask\nfrom monai.utils import min_version, optional_import, set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nskimage, has_skimage = optional_import(\"skimage\", \"0.19.0\", min_version)\nset_determinism(1234)\n\nA = np.random.randint(64, 128, (3, 3, 2)).astype(np.uint8)\nA3D = np.random.randint(64, 128, (3, 3, 2, 2)).astype(np.uint8)\nB = np.ones_like(A[:1])\nB3D = np.ones_like(A3D[:1])\nMASK = np.pad(B, ((0, 0), (2, 2), (2, 2)), constant_values=0)\nMASK3D = np.pad(B3D, ((0, 0), (2, 2), (2, 2), (2, 2)), constant_values=0)\nIMAGE1 = np.pad(A, ((0, 0), (2, 2), (2, 2)), constant_values=255)\nIMAGE3D = np.pad(A3D, ((0, 0), (2, 2), (2, 2), (2, 2)), constant_values=255)\nIMAGE2 = np.copy(IMAGE1)\nIMAGE2[0] = 0\nIMAGE3 = np.pad(A, ((0, 0), (2, 2), (2, 2)), constant_values=0)\nTEST_CASE_0 = [{}, IMAGE1, MASK]\nTEST_CASE_1 = [{\"threshold\": \"otsu\"}, IMAGE1, MASK]\nTEST_CASE_2 = [{\"threshold\": \"otsu\"}, IMAGE2, MASK]\nTEST_CASE_3 = [{\"threshold\": 140}, IMAGE1, MASK]\nTEST_CASE_4 = [{\"threshold\": \"otsu\", \"invert\": True}, IMAGE3, MASK]\nTEST_CASE_5 = [{\"threshold\": 0.5}, MASK, np.logical_not(MASK)]\nTEST_CASE_6 = [{\"threshold\": 140}, IMAGE2, 
np.ones_like(MASK)]\nTEST_CASE_7 = [{\"threshold\": {\"R\": \"otsu\", \"G\": \"otsu\", \"B\": \"otsu\"}}, IMAGE2, MASK]\nTEST_CASE_8 = [{\"threshold\": {\"R\": 140, \"G\": \"otsu\", \"B\": \"otsu\"}}, IMAGE2, np.ones_like(MASK)]\nTEST_CASE_9 = [{\"threshold\": {\"R\": 140, \"G\": skimage.filters.threshold_otsu, \"B\": \"otsu\"}}, IMAGE2, np.ones_like(MASK)]\nTEST_CASE_10 = [{\"threshold\": skimage.filters.threshold_mean}, IMAGE1, MASK]\nTEST_CASE_11 = [{\"threshold\": None, \"hsv_threshold\": \"otsu\"}, IMAGE1, np.ones_like(MASK)]\nTEST_CASE_12 = [{\"threshold\": None, \"hsv_threshold\": {\"S\": \"otsu\"}}, IMAGE1, MASK]\nTEST_CASE_13 = [{\"threshold\": 100, \"invert\": True}, IMAGE1, np.logical_not(MASK)]\nTEST_CASE_14 = [{}, IMAGE3D, MASK3D]\nTEST_CASE_15 = [{\"hsv_threshold\": {\"S\": 0.1}}, IMAGE3D, MASK3D]\n\nTEST_CASE_ERROR_1 = [{\"threshold\": None}, IMAGE1]\nTEST_CASE_ERROR_2 = [{\"threshold\": {\"K\": 1}}, IMAGE1]\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([p, *TEST_CASE_0])\n    TESTS.append([p, *TEST_CASE_1])\n    TESTS.append([p, *TEST_CASE_2])\n    TESTS.append([p, *TEST_CASE_3])\n    TESTS.append([p, *TEST_CASE_4])\n    TESTS.append([p, *TEST_CASE_5])\n    TESTS.append([p, *TEST_CASE_6])\n    TESTS.append([p, *TEST_CASE_7])\n    TESTS.append([p, *TEST_CASE_8])\n    TESTS.append([p, *TEST_CASE_9])\n    TESTS.append([p, *TEST_CASE_10])\n    TESTS.append([p, *TEST_CASE_11])\n    TESTS.append([p, *TEST_CASE_12])\n    TESTS.append([p, *TEST_CASE_13])\n    TESTS.append([p, *TEST_CASE_14])\n    TESTS.append([p, *TEST_CASE_15])\n\nTESTS_ERROR = []\nfor p in TEST_NDARRAYS:\n    TESTS_ERROR.append([p, *TEST_CASE_ERROR_1])\n    TESTS_ERROR.append([p, *TEST_CASE_ERROR_2])\n\n\n@unittest.skipUnless(has_skimage, \"Requires sci-kit image\")\nclass TestForegroundMask(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_foreground_mask(self, in_type, arguments, image, mask):\n        input_image = in_type(image)\n        result = 
ForegroundMask(**arguments)(input_image)\n        assert_allclose(result, mask, type_test=\"tensor\")\n\n    @parameterized.expand(TESTS_ERROR)\n    def test_foreground_mask_error(self, in_type, arguments, image):\n        input_image = in_type(image)\n        with self.assertRaises(ValueError):\n            ForegroundMask(**arguments)(input_image)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/intensity/test_foreground_maskd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms.intensity.dictionary import ForegroundMaskd\nfrom monai.utils import min_version, optional_import, set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nskimage, has_skimage = optional_import(\"skimage\", \"0.19.0\", min_version)\nset_determinism(1234)\n\nA = np.random.randint(64, 128, (3, 3, 2)).astype(np.uint8)\nA3D = np.random.randint(64, 128, (3, 3, 2, 2)).astype(np.uint8)\nB = np.ones_like(A[:1])\nB3D = np.ones_like(A3D[:1])\nMASK = np.pad(B, ((0, 0), (2, 2), (2, 2)), constant_values=0)\nMASK3D = np.pad(B3D, ((0, 0), (2, 2), (2, 2), (2, 2)), constant_values=0)\nIMAGE1 = np.pad(A, ((0, 0), (2, 2), (2, 2)), constant_values=255)\nIMAGE3D = np.pad(A3D, ((0, 0), (2, 2), (2, 2), (2, 2)), constant_values=255)\nIMAGE2 = np.copy(IMAGE1)\nIMAGE2[0] = 0\nIMAGE3 = np.pad(A, ((0, 0), (2, 2), (2, 2)), constant_values=0)\nTEST_CASE_0 = [{\"keys\": \"image\"}, {\"image\": IMAGE1}, MASK]\nTEST_CASE_1 = [{\"keys\": \"image\", \"threshold\": \"otsu\"}, {\"image\": IMAGE1}, MASK]\nTEST_CASE_2 = [{\"keys\": \"image\", \"threshold\": \"otsu\"}, {\"image\": IMAGE2}, MASK]\nTEST_CASE_3 = [{\"keys\": \"image\", \"threshold\": 140}, {\"image\": IMAGE1}, MASK]\nTEST_CASE_4 = [{\"keys\": \"image\", \"threshold\": \"otsu\", 
\"invert\": True}, {\"image\": IMAGE3}, MASK]\nTEST_CASE_5 = [{\"keys\": \"image\", \"threshold\": 0.5}, {\"image\": MASK}, np.logical_not(MASK)]\nTEST_CASE_6 = [{\"keys\": \"image\", \"threshold\": 140}, {\"image\": IMAGE2}, np.ones_like(MASK)]\nTEST_CASE_7 = [{\"keys\": \"image\", \"threshold\": {\"R\": \"otsu\", \"G\": \"otsu\", \"B\": \"otsu\"}}, {\"image\": IMAGE2}, MASK]\nTEST_CASE_8 = [\n    {\"keys\": \"image\", \"threshold\": {\"R\": 140, \"G\": \"otsu\", \"B\": \"otsu\"}},\n    {\"image\": IMAGE2},\n    np.ones_like(MASK),\n]\nTEST_CASE_9 = [\n    {\"keys\": \"image\", \"threshold\": {\"R\": 140, \"G\": skimage.filters.threshold_otsu, \"B\": \"otsu\"}},\n    {\"image\": IMAGE2},\n    np.ones_like(MASK),\n]\nTEST_CASE_10 = [{\"keys\": \"image\", \"threshold\": skimage.filters.threshold_mean}, {\"image\": IMAGE1}, MASK]\nTEST_CASE_11 = [{\"keys\": \"image\", \"threshold\": None, \"hsv_threshold\": \"otsu\"}, {\"image\": IMAGE1}, np.ones_like(MASK)]\nTEST_CASE_12 = [{\"keys\": \"image\", \"threshold\": None, \"hsv_threshold\": {\"S\": \"otsu\"}}, {\"image\": IMAGE1}, MASK]\nTEST_CASE_13 = [{\"keys\": \"image\", \"threshold\": 100, \"invert\": True}, {\"image\": IMAGE1}, np.logical_not(MASK)]\nTEST_CASE_14 = [{\"keys\": \"image\"}, {\"image\": IMAGE3D}, MASK3D]\nTEST_CASE_15 = [{\"keys\": \"image\", \"hsv_threshold\": {\"S\": 0.1}}, {\"image\": IMAGE3D}, MASK3D]\n\nTEST_CASE_ERROR_1 = [{\"keys\": \"image\", \"threshold\": None}, {\"image\": IMAGE1}]\nTEST_CASE_ERROR_2 = [{\"keys\": \"image\", \"threshold\": {\"K\": 1}}, {\"image\": IMAGE1}]\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([p, *TEST_CASE_0])\n    TESTS.append([p, *TEST_CASE_1])\n    TESTS.append([p, *TEST_CASE_2])\n    TESTS.append([p, *TEST_CASE_3])\n    TESTS.append([p, *TEST_CASE_4])\n    TESTS.append([p, *TEST_CASE_5])\n    TESTS.append([p, *TEST_CASE_6])\n    TESTS.append([p, *TEST_CASE_7])\n    TESTS.append([p, *TEST_CASE_8])\n    TESTS.append([p, *TEST_CASE_9])\n    
TESTS.append([p, *TEST_CASE_10])\n    TESTS.append([p, *TEST_CASE_11])\n    TESTS.append([p, *TEST_CASE_12])\n    TESTS.append([p, *TEST_CASE_13])\n    TESTS.append([p, *TEST_CASE_14])\n    TESTS.append([p, *TEST_CASE_15])\n\nTESTS_ERROR = []\nfor p in TEST_NDARRAYS:\n    TESTS_ERROR.append([p, *TEST_CASE_ERROR_1])\n    TESTS_ERROR.append([p, *TEST_CASE_ERROR_2])\n\n\n@unittest.skipUnless(has_skimage, \"Requires sci-kit image\")\nclass TestForegroundMaskd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_foreground_mask(self, in_type, arguments, data_dict, mask):\n        data_dict[arguments[\"keys\"]] = in_type(data_dict[arguments[\"keys\"]])\n        result = ForegroundMaskd(**arguments)(data_dict)[arguments[\"keys\"]]\n        assert_allclose(result, mask, type_test=False)\n\n    @parameterized.expand(TESTS_ERROR)\n    def test_foreground_mask_error(self, in_type, arguments, data_dict):\n        data_dict[arguments[\"keys\"]] = in_type(data_dict[arguments[\"keys\"]])\n        with self.assertRaises(ValueError):\n            ForegroundMaskd(**arguments)(data_dict)[arguments[\"keys\"]]\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/intensity/test_rand_histogram_shiftd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms.intensity.dictionary import RandHistogramShiftd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"keys\": (\"img\",), \"num_control_points\": 5, \"prob\": 0.0},\n            {\"img\": p(np.arange(8).reshape((1, 2, 2, 2))), \"seg\": p(np.ones(8).reshape((1, 2, 2, 2)))},\n            {\"img\": np.arange(8).reshape((1, 2, 2, 2)), \"seg\": np.ones(8).reshape((1, 2, 2, 2))},\n        ]\n    )\n    TESTS.append(\n        [\n            {\"keys\": (\"img\",), \"num_control_points\": 5, \"prob\": 0.9},\n            {\n                \"img\": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32)),\n                \"seg\": p(np.ones(8).reshape((1, 2, 2, 2))),\n            },\n            {\n                \"img\": np.array(\n                    [[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]]\n                ),\n                \"seg\": np.ones(8).reshape((1, 2, 2, 2)),\n            },\n        ]\n    )\n    TESTS.append(\n        [\n            {\"keys\": (\"img\",), \"num_control_points\": (5, 20), \"prob\": 0.9},\n            {\n                \"img\": p(np.arange(8).reshape((1, 2, 2, 
2)).astype(np.float32)),\n                \"seg\": p(np.ones(8).reshape((1, 2, 2, 2))),\n            },\n            {\n                \"img\": np.array(\n                    [[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]]\n                ),\n                \"seg\": np.ones(8).reshape((1, 2, 2, 2)),\n            },\n        ]\n    )\n\n\nclass TestRandHistogramShiftD(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_rand_histogram_shiftd(self, input_param, input_data, expected_val):\n        g = RandHistogramShiftd(**input_param)\n        g.set_random_state(123)\n        res = g(input_data)\n        for key in (\"img\",):\n            result = res[key]\n            expected = expected_val[key] if isinstance(expected_val, dict) else expected_val\n            assert_allclose(result, expected, rtol=1e-4, atol=1e-4, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/intensity/test_scale_intensity_range_percentiles.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms.intensity.array import ScaleIntensityRangePercentiles\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestScaleIntensityRangePercentiles(NumpyImageTestCase2D):\n    def test_scaling(self):\n        img = self.imt[0]\n        lower = 10\n        upper = 99\n        b_min = 0\n        b_max = 255\n\n        a_min = np.percentile(img, lower)\n        a_max = np.percentile(img, upper)\n        expected = (((img - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8)\n        scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8)\n        for p in TEST_NDARRAYS:\n            result = scaler(p(img))\n            self.assertEqual(result.dtype, torch.uint8)\n            assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4)\n\n    def test_relative_scaling(self):\n        img = self.imt[0]\n        lower = 10\n        upper = 99\n        b_min = 100\n        b_max = 300\n        scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max, relative=True)\n\n        expected_a_min = np.percentile(img, lower)\n        expected_a_max = np.percentile(img, upper)\n        expected_b_min = ((b_max - b_min) * (lower 
/ 100.0)) + b_min\n        expected_b_max = ((b_max - b_min) * (upper / 100.0)) + b_min\n        expected_img = (img - expected_a_min) / (expected_a_max - expected_a_min)\n        expected_img = (expected_img * (expected_b_max - expected_b_min)) + expected_b_min\n\n        for p in TEST_NDARRAYS:\n            result = scaler(p(img))\n            assert_allclose(result, p(expected_img), type_test=\"tensor\", rtol=1e-3)\n\n        scaler = ScaleIntensityRangePercentiles(\n            lower=lower, upper=upper, b_min=b_min, b_max=b_max, relative=True, clip=True\n        )\n        for p in TEST_NDARRAYS:\n            result = scaler(p(img))\n            assert_allclose(\n                result, p(np.clip(expected_img, expected_b_min, expected_b_max)), type_test=\"tensor\", rtol=0.1\n            )\n\n    def test_invalid_instantiation(self):\n        self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=-10, upper=99, b_min=0, b_max=255)\n        self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=101, upper=99, b_min=0, b_max=255)\n        self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=-20, b_min=0, b_max=255)\n        self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=30, upper=900, b_min=0, b_max=255)\n\n    def test_channel_wise(self):\n        img = np.tile(self.imt, (3, 1, 1, 1))\n        lower = 10\n        upper = 99\n        b_min = 0\n        b_max = 255\n        scaler = ScaleIntensityRangePercentiles(\n            lower=lower, upper=upper, b_min=b_min, b_max=b_max, channel_wise=True, dtype=np.uint8\n        )\n        expected = []\n        for c in img:\n            a_min = np.percentile(c, lower)\n            a_max = np.percentile(c, upper)\n            expected.append((((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8))\n        expected = np.stack(expected)\n\n        for p in TEST_NDARRAYS:\n            result = scaler(p(img))\n            
assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/intensity/test_scale_intensity_range_percentilesd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms.intensity.dictionary import ScaleIntensityRangePercentilesd\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestScaleIntensityRangePercentilesd(NumpyImageTestCase2D):\n    def test_scaling(self):\n        img = self.imt\n        lower = 10\n        upper = 99\n        b_min = 0\n        b_max = 255\n\n        a_min = np.percentile(img, lower)\n        a_max = np.percentile(img, upper)\n        expected = (((img - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8)\n\n        for p in TEST_NDARRAYS:\n            data = {\"img\": p(img)}\n            scaler = ScaleIntensityRangePercentilesd(\n                keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8\n            )\n            assert_allclose(scaler(data)[\"img\"], p(expected), type_test=\"tensor\", rtol=1e-4)\n\n    def test_relative_scaling(self):\n        img = self.imt\n        data = {\"img\": img}\n        lower = 10\n        upper = 99\n        b_min = 100\n        b_max = 300\n        scaler = ScaleIntensityRangePercentilesd(\n            keys=data.keys(), lower=lower, upper=upper, b_min=b_min, b_max=b_max, relative=True\n        )\n\n        expected_a_min = np.percentile(img, lower)\n        expected_a_max = 
np.percentile(img, upper)\n        expected_b_min = ((b_max - b_min) * (lower / 100.0)) + b_min\n        expected_b_max = ((b_max - b_min) * (upper / 100.0)) + b_min\n        expected_img = (img - expected_a_min) / (expected_a_max - expected_a_min)\n        expected_img = (expected_img * (expected_b_max - expected_b_min)) + expected_b_min\n\n        np.testing.assert_allclose(expected_img, scaler(data)[\"img\"], rtol=1e-3, atol=0.1)\n\n    def test_invalid_instantiation(self):\n        self.assertRaises(\n            ValueError, ScaleIntensityRangePercentilesd, keys=[\"img\"], lower=-1, upper=99, b_min=0, b_max=255\n        )\n        self.assertRaises(\n            ValueError, ScaleIntensityRangePercentilesd, keys=[\"img\"], lower=101, upper=99, b_min=0, b_max=255\n        )\n        self.assertRaises(\n            ValueError, ScaleIntensityRangePercentilesd, keys=[\"img\"], lower=30, upper=-2, b_min=0, b_max=255\n        )\n        self.assertRaises(\n            ValueError, ScaleIntensityRangePercentilesd, keys=[\"img\"], lower=30, upper=1000, b_min=0, b_max=255\n        )\n        with self.assertRaises(ValueError):\n            s = ScaleIntensityRangePercentilesd(keys=[\"img\"], lower=30, upper=90, b_min=None, b_max=20, relative=True)\n            s(self.imt)\n\n    def test_channel_wise(self):\n        img = np.tile(self.imt, (3, 1, 1, 1))\n        lower = 10\n        upper = 99\n        b_min = 0\n        b_max = 255\n        scaler = ScaleIntensityRangePercentilesd(\n            keys=\"img\", lower=lower, upper=upper, b_min=b_min, b_max=b_max, channel_wise=True, dtype=np.uint8\n        )\n        expected = []\n        for c in img:\n            a_min = np.percentile(c, lower)\n            a_max = np.percentile(c, upper)\n            expected.append((((c - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8))\n        expected = np.stack(expected)\n\n        for p in TEST_NDARRAYS:\n            data = {\"img\": p(img)}\n            
assert_allclose(scaler(data)[\"img\"], p(expected), type_test=\"tensor\", rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/inverse/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/transforms/inverse/test_inverse.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport random\nimport sys\nimport unittest\nfrom copy import deepcopy\nfrom functools import partial\nfrom typing import TYPE_CHECKING\nfrom unittest.case import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import CacheDataset, DataLoader, MetaTensor, create_test_image_2d, create_test_image_3d, decollate_batch\nfrom monai.networks.nets import UNet\nfrom monai.transforms import (\n    Affined,\n    BorderPadd,\n    CenterScaleCropd,\n    CenterSpatialCropd,\n    Compose,\n    CropForegroundd,\n    DivisiblePadd,\n    EnsureChannelFirstd,\n    Flipd,\n    FromMetaTensord,\n    InvertibleTransform,\n    Lambdad,\n    LoadImaged,\n    Orientationd,\n    RandAffined,\n    RandAxisFlipd,\n    RandCropByLabelClassesd,\n    RandCropByPosNegLabeld,\n    RandFlipd,\n    RandLambdad,\n    Randomizable,\n    RandRotate90d,\n    RandRotated,\n    RandSpatialCropd,\n    RandSpatialCropSamplesd,\n    RandWeightedCropd,\n    RandZoomd,\n    Resized,\n    ResizeWithPadOrCrop,\n    ResizeWithPadOrCropd,\n    Rotate90d,\n    Rotated,\n    Spacingd,\n    SpatialCropd,\n    SpatialPadd,\n    ToMetaTensord,\n    Transposed,\n    Zoomd,\n    allow_missing_keys_mode,\n    convert_applied_interp_mode,\n    reset_ops_id,\n)\nfrom monai.utils import first, get_seed, optional_import, 
set_determinism\nfrom tests.test_utils import make_nifti_image, make_rand_affine\n\nif TYPE_CHECKING:\n    has_nib = True\nelse:\n    _, has_nib = optional_import(\"nibabel\")\n\nKEYS = [\"image\", \"label\"]\n\nTESTS: list[tuple] = []\n\n# For pad, start with odd/even images and add odd/even amounts\nfor name in (\"1D even\", \"1D odd\"):\n    for val in (3, 4):\n        for t in (\n            partial(SpatialPadd, spatial_size=val, method=\"symmetric\"),\n            partial(SpatialPadd, spatial_size=val, method=\"end\"),\n            partial(BorderPadd, spatial_border=[val, val + 1]),\n            partial(DivisiblePadd, k=val),\n            partial(ResizeWithPadOrCropd, spatial_size=20 + val),\n            partial(CenterSpatialCropd, roi_size=10 + val),\n            partial(CenterScaleCropd, roi_scale=0.8),\n            partial(CropForegroundd, source_key=\"label\"),\n            partial(SpatialCropd, roi_center=10, roi_size=10 + val),\n            partial(SpatialCropd, roi_center=11, roi_size=10 + val),\n            partial(SpatialCropd, roi_start=val, roi_end=17),\n            partial(SpatialCropd, roi_start=val, roi_end=16),\n            partial(RandSpatialCropd, roi_size=12 + val),\n            partial(ResizeWithPadOrCropd, spatial_size=21 - val),\n        ):\n            TESTS.append((t.func.__name__ + name, name, 0, True, t(KEYS)))  # type: ignore\n\n# non-sensical tests: crop bigger or pad smaller or -ve values\nfor t in (\n    partial(DivisiblePadd, k=-3),\n    partial(CenterSpatialCropd, roi_size=-3),\n    partial(RandSpatialCropd, roi_size=-3),\n    partial(SpatialPadd, spatial_size=15),\n    partial(BorderPadd, spatial_border=[15, 16]),\n    partial(CenterSpatialCropd, roi_size=30),\n    partial(SpatialCropd, roi_center=10, roi_size=100),\n    partial(SpatialCropd, roi_start=3, roi_end=100),\n):\n    TESTS.append((t.func.__name__ + \"bad 1D even\", \"1D even\", 0, True, t(KEYS)))  # type: ignore\n\nTESTS.append(\n    (\n        \"SpatialPadd (x2) 
2d\",\n        \"2D\",\n        0,\n        True,\n        SpatialPadd(KEYS, spatial_size=[111, 113], method=\"end\"),\n        SpatialPadd(KEYS, spatial_size=[118, 117]),\n    )\n)\n\nTESTS.append((\"SpatialPadd 3d\", \"3D\", 0, True, SpatialPadd(KEYS, spatial_size=[112, 113, 116])))\n\nTESTS.append((\"SpatialCropd 2d\", \"2D\", 0, True, SpatialCropd(KEYS, [49, 51], [90, 89])))\n\nTESTS.append(\n    (\n        \"SpatialCropd 3d\",\n        \"3D\",\n        0,\n        True,\n        SpatialCropd(KEYS, roi_slices=[slice(s, e) for s, e in zip([None, None, -99], [None, -2, None])]),\n    )\n)\n\nTESTS.append((\"SpatialCropd 2d\", \"2D\", 0, True, SpatialCropd(KEYS, [49, 51], [390, 89])))\n\nTESTS.append((\"SpatialCropd 3d\", \"3D\", 0, True, SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93])))\n\nTESTS.append((\"RandSpatialCropd 2d\", \"2D\", 0, True, RandSpatialCropd(KEYS, [96, 93], None, True, False)))\n\nTESTS.append((\"RandSpatialCropd 3d\", \"3D\", 0, True, RandSpatialCropd(KEYS, [96, 93, 92], None, False, False)))\n\nTESTS.append((\"BorderPadd 2d\", \"2D\", 0, True, BorderPadd(KEYS, [3, 7, 2, 5])))\n\nTESTS.append((\"BorderPadd 2d\", \"2D\", 0, True, BorderPadd(KEYS, [3, 7])))\n\nTESTS.append((\"BorderPadd 3d\", \"3D\", 0, True, BorderPadd(KEYS, [4])))\n\nTESTS.append((\"DivisiblePadd 2d\", \"2D\", 0, True, DivisiblePadd(KEYS, k=4)))\n\nTESTS.append((\"DivisiblePadd 3d\", \"3D\", 0, True, DivisiblePadd(KEYS, k=[4, 8, 11])))\n\nTESTS.append((\"CenterSpatialCropd 2d\", \"2D\", 0, True, CenterSpatialCropd(KEYS, roi_size=95)))\n\nTESTS.append((\"CenterSpatialCropd 3d\", \"3D\", 0, True, CenterSpatialCropd(KEYS, roi_size=[95, 97, 98])))\n\nTESTS.append((\"CropForegroundd 2d\", \"2D\", 0, True, CropForegroundd(KEYS, source_key=\"label\", margin=2)))\n\nTESTS.append((\"CropForegroundd 3d\", \"3D\", 0, True, CropForegroundd(KEYS, source_key=\"label\", k_divisible=[5, 101, 2])))\n\nTESTS.append((\"ResizeWithPadOrCropd 3d\", \"3D\", 0, True, ResizeWithPadOrCropd(KEYS, [201, 
150, 105])))\n\nTESTS.append((\"Flipd 3d\", \"3D\", 0, True, Flipd(KEYS, [1, 2])))\nTESTS.append((\"Flipd 3d\", \"3D\", 0, True, Flipd(KEYS, [1, 2])))\n\nTESTS.append((\"RandFlipd 3d\", \"3D\", 0, True, RandFlipd(KEYS, 1, [1, 2])))\n\nTESTS.append((\"RandAxisFlipd 3d\", \"3D\", 0, True, RandAxisFlipd(KEYS, 1)))\nTESTS.append((\"RandAxisFlipd 3d\", \"3D\", 0, True, RandAxisFlipd(KEYS, 1)))\n\nfor acc in [True, False]:\n    TESTS.append((\"Orientationd 3d\", \"3D\", 0, True, Orientationd(KEYS, \"RAS\", as_closest_canonical=acc)))\n\nTESTS.append((\"Rotate90d 2d\", \"2D\", 0, True, Rotate90d(KEYS)))\n\nTESTS.append((\"Rotate90d 3d\", \"3D\", 0, True, Rotate90d(KEYS, k=2, spatial_axes=(1, 2))))\n\nTESTS.append((\"RandRotate90d 3d\", \"3D\", 0, True, RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2))))\n\nTESTS.append((\"Spacingd 3d\", \"3D\", 3e-2, True, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False)))\n\nTESTS.append((\"Resized 2d\", \"2D\", 2e-1, True, Resized(KEYS, [50, 47])))\n\nTESTS.append((\"Resized 3d\", \"3D\", 5e-2, True, Resized(KEYS, [201, 150, 78])))\n\nTESTS.append((\"Resized longest 2d\", \"2D\", 2e-1, True, Resized(KEYS, 47, \"longest\", \"area\")))\n\nTESTS.append((\"Resized longest 3d\", \"3D\", 5e-2, True, Resized(KEYS, 201, \"longest\", \"trilinear\", True)))\n\nTESTS.append(\n    (\"Lambdad 2d\", \"2D\", 5e-2, False, Lambdad(KEYS, func=lambda x: x + 5, inv_func=lambda x: x - 5, overwrite=True))\n)\n\nTESTS.append(\n    (\n        \"RandLambdad 3d\",\n        \"3D\",\n        5e-2,\n        False,\n        RandLambdad(KEYS, func=lambda x: x * 10, inv_func=lambda x: x / 10, overwrite=True, prob=0.5),\n    )\n)\n\nTESTS.append((\"Zoomd 1d\", \"1D odd\", 0, True, Zoomd(KEYS, zoom=2, keep_size=False)))\n\nTESTS.append((\"Zoomd 2d\", \"2D\", 2e-1, True, Zoomd(KEYS, zoom=0.9)))\n\nTESTS.append((\"Zoomd 3d\", \"3D\", 3e-2, True, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False)))\n\nTESTS.append((\"RandZoom 3d\", \"3D\", 9e-2, True, RandZoomd(KEYS, 1, [0.5, 
0.6, 0.9], [1.1, 1, 1.05], keep_size=True)))\n\nTESTS.append((\"RandRotated, prob 0\", \"2D\", 0, True, RandRotated(KEYS, prob=0, dtype=np.float64)))\n\nTESTS.append(\n    (\n        \"Rotated 2d\",\n        \"2D\",\n        8e-2,\n        True,\n        Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False, dtype=np.float64),\n    )\n)\n\nTESTS.append(\n    (\n        \"Rotated 3d\",\n        \"3D\",\n        1e-1,\n        True,\n        Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True, dtype=np.float64),\n    )\n)\n\nTESTS.append(\n    (\n        \"RandRotated 3d\",\n        \"3D\",\n        1e-1,\n        True,\n        RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1, dtype=np.float64),  # type: ignore\n    )\n)\n\nTESTS.append((\"Transposed 2d\", \"2D\", 0, False, Transposed(KEYS, [0, 2, 1])))  # channel=0\n\nTESTS.append((\"Transposed 3d\", \"3D\", 0, False, Transposed(KEYS, [0, 3, 1, 2])))  # channel=0\n\nTESTS.append(\n    (\n        \"Affine 3d\",\n        \"3D\",\n        1e-1,\n        True,\n        Affined(\n            KEYS,\n            spatial_size=[155, 179, 192],\n            rotate_params=[np.pi / 6, -np.pi / 5, np.pi / 7],\n            shear_params=[0.5, 0.5],\n            translate_params=[10, 5, -4],\n            scale_params=[0.8, 1.3],\n        ),\n    )\n)\n\nTESTS.append(\n    (\n        \"RandAffine 3d\",\n        \"3D\",\n        1e-1,\n        True,\n        RandAffined(\n            KEYS,\n            [155, 179, 192],\n            prob=1,\n            padding_mode=\"zeros\",\n            rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7],\n            shear_range=[(0.5, 0.5)],\n            translate_range=[10, 5, -4],\n            scale_range=[(0.8, 1.2), (0.9, 1.3)],\n        ),\n    )\n)\n\nTESTS.append((\"RandAffine 3d\", \"3D\", 0, True, RandAffined(KEYS, spatial_size=None, prob=0)))\n\nTESTS.append(\n    (\n        \"RandCropByLabelClassesd 2d\",\n       
 \"2D\",\n        1e-7,\n        True,\n        RandCropByLabelClassesd(KEYS, \"label\", (99, 96), ratios=[1, 2, 3, 4, 5], num_classes=5, num_samples=10),\n    )\n)\n\nTESTS.append(\n    (\"RandCropByPosNegLabeld 2d\", \"2D\", 1e-7, True, RandCropByPosNegLabeld(KEYS, \"label\", (99, 96), num_samples=10))\n)\n\nTESTS.append((\"RandSpatialCropSamplesd 2d\", \"2D\", 1e-7, True, RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=10)))\n\nTESTS.append((\"RandWeightedCropd 2d\", \"2D\", 1e-7, True, RandWeightedCropd(KEYS, \"label\", (90, 91), num_samples=10)))\n\nTESTS_COMPOSE_X2 = [(t[0] + \" Compose\", t[1], t[2], t[3], Compose(Compose(t[4:]))) for t in TESTS]\n\nTESTS = TESTS + TESTS_COMPOSE_X2\n\nNUM_SAMPLES = 5\nN_SAMPLES_TESTS = [\n    [RandCropByLabelClassesd(KEYS, \"label\", (110, 99), [1, 2, 3, 4, 5], num_classes=5, num_samples=NUM_SAMPLES)],\n    [RandCropByPosNegLabeld(KEYS, \"label\", (110, 99), num_samples=NUM_SAMPLES)],\n    [RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=NUM_SAMPLES, random_size=False)],\n    [RandWeightedCropd(KEYS, \"label\", (90, 91), num_samples=NUM_SAMPLES)],\n]\n\n\ndef no_collation(x):\n    return x\n\n\nclass TestInverse(unittest.TestCase):\n    \"\"\"Test inverse methods.\n\n    If tests are failing, the following function might be useful for displaying\n    `x`, `fx`, `f⁻¹fx` and `x - f⁻¹fx`.\n\n    .. 
code-block:: python\n\n        def plot_im(orig, fwd_bck, fwd):\n            import matplotlib.pyplot as plt\n            diff_orig_fwd_bck = orig - fwd_bck\n            ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck]\n            titles = [\"x\", \"fx\", \"f⁻¹fx\", \"x - f⁻¹fx\"]\n            fig, axes = plt.subplots(1, 4, gridspec_kw={\"width_ratios\": [i.shape[1] for i in ims_to_show]})\n            vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd])\n            vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd])\n            for im, title, ax in zip(ims_to_show, titles, axes):\n                _vmin, _vmax = (vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None)\n                im = np.squeeze(np.array(im))\n                while im.ndim > 2:\n                    im = im[..., im.shape[-1] // 2]\n                im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax)\n                ax.set_title(title, fontsize=25)\n                ax.axis(\"off\")\n                fig.colorbar(im_show, ax=ax)\n            plt.show()\n\n    This can then be added to the exception:\n\n    .. code-block:: python\n\n        except AssertionError:\n            print(\n                f\"Failed: {name}. 
Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}\"\n            )\n            if orig[0].ndim > 1:\n                plot_im(orig, fwd_bck, unmodified)\n    \"\"\"\n\n    def setUp(self):\n        if not has_nib:\n            self.skipTest(\"nibabel required for test_inverse\")\n\n        set_determinism(seed=0)\n\n        self.all_data = {}\n\n        affine = make_rand_affine()\n        affine[0] *= 2\n\n        for size in [10, 11]:\n            # pad 5 onto both ends so that cropping can be lossless\n            im_1d = np.pad(np.arange(size), 5)[None]\n            name = \"1D even\" if size % 2 == 0 else \"1D odd\"\n            self.all_data[name] = {\n                \"image\": torch.as_tensor(np.array(im_1d, copy=True)),\n                \"label\": torch.as_tensor(np.array(im_1d, copy=True)),\n                \"other\": torch.as_tensor(np.array(im_1d, copy=True)),\n            }\n\n        im_2d_fname, seg_2d_fname = (make_nifti_image(i) for i in create_test_image_2d(101, 100))\n        im_3d_fname, seg_3d_fname = (make_nifti_image(i, affine) for i in create_test_image_3d(100, 101, 107))\n\n        load_ims = Compose(\n            [LoadImaged(KEYS), EnsureChannelFirstd(KEYS, channel_dim=\"no_channel\"), FromMetaTensord(KEYS)]\n        )\n        self.all_data[\"2D\"] = load_ims({\"image\": im_2d_fname, \"label\": seg_2d_fname})\n        self.all_data[\"3D\"] = load_ims({\"image\": im_3d_fname, \"label\": seg_3d_fname})\n\n    def tearDown(self):\n        set_determinism(seed=None)\n\n    def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff):\n        for key in keys:\n            orig = orig_d[key]\n            fwd_bck = fwd_bck_d[key]\n            if isinstance(fwd_bck, torch.Tensor):\n                fwd_bck = fwd_bck.cpu().numpy()\n            unmodified = unmodified_d[key]\n            if isinstance(orig, np.ndarray):\n                mean_diff = np.mean(np.abs(orig - fwd_bck))\n    
            resized = ResizeWithPadOrCrop(orig.shape[1:])(unmodified)\n                if isinstance(resized, torch.Tensor):\n                    resized = resized.detach().cpu().numpy()\n                unmodded_diff = np.mean(np.abs(orig - resized))\n                try:\n                    self.assertLessEqual(mean_diff, acceptable_diff)\n                except AssertionError:\n                    print(\n                        f\"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}\"\n                    )\n                    if orig[0].ndim == 1:\n                        print(\"orig\", orig[0])\n                        print(\"fwd_bck\", fwd_bck[0])\n                        print(\"unmod\", unmodified[0])\n                    raise\n\n    @parameterized.expand(TESTS)\n    def test_inverse(self, _, data_name, acceptable_diff, is_meta, *transforms):\n        name = _\n\n        data = self.all_data[data_name]\n        if is_meta:\n            data = ToMetaTensord(KEYS)(data)\n\n        forwards = [data.copy()]\n\n        # Apply forwards\n        for t in transforms:\n            if isinstance(t, Randomizable):\n                t.set_random_state(seed=get_seed())\n            forwards.append(t(forwards[-1]))\n\n        # Apply inverses\n        fwd_bck = forwards[-1].copy()\n        for i, t in enumerate(reversed(transforms)):\n            if isinstance(t, InvertibleTransform):\n                if isinstance(fwd_bck, list):\n                    for j, _fwd_bck in enumerate(fwd_bck):\n                        fwd_bck = t.inverse(_fwd_bck)\n                        self.check_inverse(\n                            name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1][j], acceptable_diff\n                        )\n                else:\n                    fwd_bck = t.inverse(fwd_bck)\n                    self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff)\n\n    
# skip this test if multiprocessing uses 'spawn', as the check is only basic anyway\n    @skipUnless(torch.multiprocessing.get_start_method() == \"spawn\", \"requires spawn\")\n    def test_fail(self):\n        t1 = SpatialPadd(\"image\", [10, 5])\n        data = t1(self.all_data[\"2D\"])\n\n        # Check that error is thrown when inverse are used out of order.\n        t2 = ResizeWithPadOrCropd(\"image\", [10, 5])\n        with self.assertRaises(RuntimeError):\n            t2.inverse(data)\n\n    @parameterized.expand(N_SAMPLES_TESTS)\n    def test_inverse_inferred_seg(self, extra_transform):\n        test_data = []\n        for _ in range(20):\n            image, label = create_test_image_2d(100, 101)\n            test_data.append({\"image\": image, \"label\": label.astype(np.float32)})\n\n        batch_size = 10\n        # num workers = 0 for mac\n        num_workers = 2 if sys.platform == \"linux\" else 0\n        transforms = Compose(\n            [EnsureChannelFirstd(KEYS, channel_dim=\"no_channel\"), SpatialPadd(KEYS, (150, 153)), extra_transform]\n        )\n\n        dataset = CacheDataset(test_data, transform=transforms, progress=False)\n        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(1,)).to(device)\n\n        data = first(loader)\n        self.assertEqual(data[\"image\"].shape[0], batch_size * NUM_SAMPLES)\n\n        labels = data[\"label\"].to(device)\n        self.assertIsInstance(labels, MetaTensor)\n        segs = model(labels).detach().cpu()\n        segs_decollated = decollate_batch(segs)\n        self.assertIsInstance(segs_decollated[0], MetaTensor)\n        # inverse of individual segmentation\n        seg_metatensor = first(segs_decollated)\n        # test to convert interpolation mode for 1 data of model output batch\n        
convert_applied_interp_mode(seg_metatensor.applied_operations, mode=\"nearest\", align_corners=None)\n\n        # manually invert the last crop samples\n        xform = seg_metatensor.applied_operations.pop(-1)\n        shape_before_extra_xform = xform[\"orig_size\"]\n        resizer = ResizeWithPadOrCrop(spatial_size=shape_before_extra_xform)\n        with resizer.trace_transform(False):\n            seg_metatensor = resizer(seg_metatensor)\n        no_ops_id_tensor = reset_ops_id(deepcopy(seg_metatensor))\n\n        with allow_missing_keys_mode(transforms):\n            inv_seg = transforms.inverse({\"label\": seg_metatensor})[\"label\"]\n            inv_seg_1 = transforms.inverse({\"label\": no_ops_id_tensor})[\"label\"]\n        self.assertEqual(inv_seg.shape[1:], test_data[0][\"label\"].shape)\n        self.assertEqual(inv_seg_1.shape[1:], test_data[0][\"label\"].shape)\n\n        # # Inverse of batch\n        # batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True)\n        # with allow_missing_keys_mode(transforms):\n        #     inv_batch = batch_inverter(first(loader))\n        # self.assertEqual(inv_batch[0][\"label\"].shape[1:], test_data[0][\"label\"].shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/inverse/test_inverse_array.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor\nfrom monai.transforms import Compose, EnsureChannelFirst, Flip, Orientation, Spacing\nfrom monai.transforms.inverse import InvertibleTransform\nfrom monai.utils import optional_import\nfrom tests.test_utils import TEST_DEVICES\n\n_, has_nib = optional_import(\"nibabel\")\n\nTESTS = []\nfor use_compose in (False, True):\n    for dtype in (torch.float32, torch.float64):\n        for device in TEST_DEVICES:\n            TESTS.append([use_compose, dtype, *device])\n\n\n@unittest.skipUnless(has_nib, \"Requires nibabel\")\nclass TestInverseArray(unittest.TestCase):\n    @staticmethod\n    def get_image(dtype, device) -> MetaTensor:\n        affine = torch.tensor([[0, 0, 1, 0], [-1, 0, 0, 0], [0, 10, 0, 0], [0, 0, 0, 1]]).to(dtype).to(device)\n        img = torch.rand((15, 16, 17)).to(dtype).to(device)\n        return MetaTensor(img, affine=affine)\n\n    @parameterized.expand(TESTS)\n    def test_inverse_array(self, use_compose: bool, dtype: torch.dtype, device: torch.device):\n        img: MetaTensor\n        tr = Compose(\n            [\n                EnsureChannelFirst(channel_dim=\"no_channel\"),\n                Orientation(\"RAS\"),\n                Flip(1),\n                Spacing([1.0, 1.2, 0.9], align_corners=False),\n  
          ]\n        )\n        num_invertible = len([i for i in tr.transforms if isinstance(i, InvertibleTransform)])\n\n        # forward\n        img = tr(self.get_image(dtype, device))\n        self.assertEqual(len(img.applied_operations), num_invertible)\n\n        # inverse with Compose\n        if use_compose:\n            img = tr.inverse(img)\n            self.assertEqual(len(img.applied_operations), 0)\n\n        # inverse individually\n        else:\n            _tr: InvertibleTransform\n            num_to_inverse = num_invertible\n            for _tr in tr.transforms[::-1]:\n                if isinstance(_tr, InvertibleTransform):\n                    img = _tr.inverse(img)\n                    num_to_inverse -= 1\n                    self.assertEqual(len(img.applied_operations), num_to_inverse)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/inverse/test_inverse_dict.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom itertools import product\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import DataLoader, Dataset, MetaTensor, ThreadDataLoader, create_test_image_2d\nfrom monai.engines.evaluator import SupervisedEvaluator\nfrom monai.transforms import Compose, EnsureChannelFirstd, Invertd, Spacingd\nfrom monai.utils.enums import CommonKeys\nfrom tests.test_utils import TEST_DEVICES, SkipIfNoModule\n\n\nclass TestInvertDict(unittest.TestCase):\n\n    def setUp(self):\n        self.orig_size = (60, 60)\n        img, _ = create_test_image_2d(*self.orig_size, 2, 10, num_seg_classes=2)\n        self.img = MetaTensor(img, meta={\"original_channel_dim\": float(\"nan\"), \"pixdim\": [1.0, 1.0]})\n        self.key = CommonKeys.IMAGE\n        self.pred = CommonKeys.PRED\n        self.new_pixdim = 2.0\n\n        self.preprocessing = Compose([EnsureChannelFirstd(self.key), Spacingd(self.key, pixdim=[self.new_pixdim] * 2)])\n\n        self.postprocessing = Compose([Invertd(self.pred, transform=self.preprocessing, orig_keys=self.key)])\n\n    @parameterized.expand(TEST_DEVICES)\n    def test_simple_processing(self, device):\n        \"\"\"\n        Tests postprocessing operations perform correctly, in particular that `Invertd` does inversion correctly.\n\n        This will apply the preprocessing sequence which 
resizes the result, then the postprocess sequence which\n        returns it to the original shape using Invertd. This tests that the shape of the output is the same as the\n        original image. This will also test that Invertd doesn't get confused if transforms in the postprocessing\n        sequence are tracing and so adding information to `applied_operations`, this is what `Lambdad` is doing in\n        `self.postprocessing`.\n        \"\"\"\n\n        item = {self.key: self.img.to(device)}\n        pre = self.preprocessing(item)\n\n        nw = int(self.orig_size[0] / self.new_pixdim)\n        nh = int(self.orig_size[1] / self.new_pixdim)\n\n        self.assertTupleEqual(pre[self.key].shape, (1, nh, nw), \"Pre-processing did not reshape input correctly\")\n        self.assertTrue(len(pre[self.key].applied_operations) > 0, \"Pre-processing transforms did not trace correctly\")\n\n        pre[self.pred] = pre[self.key]  # the inputs are the prediction for this test\n\n        post = self.postprocessing(pre)\n\n        self.assertTupleEqual(\n            post[self.pred].shape, (1, *self.orig_size), \"Result does not have same shape as original input\"\n        )\n\n    @parameterized.expand(product(sum(TEST_DEVICES, []), [True, False]))\n    @SkipIfNoModule(\"ignite\")\n    def test_workflow(self, device, use_threads):\n        \"\"\"\n        This tests the interaction between pre and postprocesing transform sequences being executed in parallel.\n\n        When the `ThreadDataLoader` is used to load batches, this is done in parallel at times with the execution of\n        the post-process transform sequence. Previously this encountered a race condition at times because the\n        `TraceableTransform.tracing` variables of transforms was being toggled in different threads, so at times a\n        pre-process transform wouldn't trace correctly and so confuse `Invertd`. 
Using a `SupervisedEvaluator` is\n        the best way to induce this race condition, other methods didn't get the timing right..\n        \"\"\"\n        batch_size = 2\n        ds_size = 4\n        test_data = [{self.key: self.img.clone().to(device)} for _ in range(ds_size)]\n        ds = Dataset(test_data, transform=self.preprocessing)\n        dl_type = ThreadDataLoader if use_threads else DataLoader\n        dl = dl_type(ds, num_workers=0, batch_size=batch_size)\n\n        class AssertAppliedOps(torch.nn.Module):\n            def forward(self, x):\n                assert len(x.applied_operations) == x.shape[0]\n                assert all(len(a) > 0 for a in x.applied_operations)\n                return x\n\n        evaluator = SupervisedEvaluator(\n            device=device, network=AssertAppliedOps(), postprocessing=self.postprocessing, val_data_loader=dl\n        )\n\n        evaluator.run()\n\n        self.assertTupleEqual(evaluator.state.output[0][self.pred].shape, (1, *self.orig_size))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/inverse/test_invert.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\n\nfrom monai.data import DataLoader, Dataset, MetaTensor, create_test_image_3d, decollate_batch\nfrom monai.transforms import (\n    CastToType,\n    Compose,\n    EnsureChannelFirst,\n    Invert,\n    Lambda,\n    LoadImage,\n    Orientation,\n    RandAffine,\n    RandAxisFlip,\n    RandFlip,\n    RandRotate,\n    RandRotate90,\n    RandZoom,\n    ResizeWithPadOrCrop,\n    Spacing,\n)\nfrom monai.utils import set_determinism\nfrom tests.test_utils import assert_allclose, make_nifti_image\n\n\nclass TestInvert(unittest.TestCase):\n    def test_invert(self):\n        set_determinism(seed=0)\n        im_fname = make_nifti_image(create_test_image_3d(101, 100, 107, noise_max=100)[1])  # label image, discrete\n        data = [im_fname for _ in range(12)]\n        transform = Compose(\n            [\n                LoadImage(image_only=True),\n                EnsureChannelFirst(),\n                Orientation(\"RPS\"),\n                Spacing(pixdim=(1.2, 1.01, 0.9), mode=1, dtype=np.float32),\n                RandFlip(prob=0.5, spatial_axis=[1, 2]),\n                RandAxisFlip(prob=0.5),\n                RandRotate90(prob=0, spatial_axes=(1, 2)),\n                RandZoom(prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),\n                
RandRotate(prob=0.5, range_x=np.pi, mode=\"bilinear\", align_corners=True, dtype=np.float64),\n                RandAffine(prob=0.5, rotate_range=np.pi, mode=\"nearest\"),\n                ResizeWithPadOrCrop(100),\n                CastToType(dtype=torch.uint8),\n            ]\n        )\n\n        # num workers = 0 for mac or gpu transforms\n        num_workers = 0 if sys.platform != \"linux\" or torch.cuda.is_available() else 2\n        dataset = Dataset(data, transform=transform)\n        self.assertIsInstance(transform.inverse(dataset[0]), MetaTensor)\n        loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)\n        inverter = Invert(transform=transform, nearest_interp=True, device=\"cpu\", post_func=torch.as_tensor)\n\n        for d in loader:\n            d = decollate_batch(d)\n            for item in d:\n                orig = deepcopy(item)\n                i = inverter(item)\n                self.assertTupleEqual(orig.shape[1:], (100, 100, 100))\n                # check the nearest interpolation mode\n                assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))\n                self.assertTupleEqual(i.shape[1:], (101, 100, 107))\n        # check labels match\n        reverted = i.detach().cpu().numpy().astype(np.int32)\n        original = LoadImage(image_only=True)(data[-1])\n        n_good = np.sum(np.isclose(reverted, original.numpy(), atol=1e-3))\n        reverted_name = i.meta[\"filename_or_obj\"]\n        original_name = original.meta[\"filename_or_obj\"]\n        self.assertEqual(reverted_name, original_name)\n        print(\"invert diff\", reverted.size - n_good)\n        self.assertTrue((reverted.size - n_good) < 300000, f\"diff. 
{reverted.size - n_good}\")\n        set_determinism(seed=None)\n\n    def test_invert_warn_pending(self):\n        # this test shouldn't raise a warning or error any more as that issue was fixed\n        # by https://github.com/Project-MONAI/MONAI/pull/6257\n        set_determinism(seed=0)\n        im_fname = make_nifti_image(create_test_image_3d(101, 100, 107, noise_max=100)[1])  # label image, discrete\n        transform = Compose(\n            [LoadImage(image_only=True), EnsureChannelFirst(), Orientation(\"RPS\"), Lambda(func=lambda x: x)], lazy=True\n        )\n        output = transform([im_fname for _ in range(2)])\n        transform.inverse(output)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/inverse/test_invertd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.data import DataLoader, Dataset, create_test_image_3d, decollate_batch\nfrom monai.transforms import (\n    CastToTyped,\n    Compose,\n    CopyItemsd,\n    EnsureChannelFirstd,\n    Invertd,\n    LoadImaged,\n    Orientationd,\n    RandAffined,\n    RandAxisFlipd,\n    RandFlipd,\n    RandRotate90d,\n    RandRotated,\n    RandZoomd,\n    ResizeWithPadOrCropd,\n    ScaleIntensityd,\n    Spacingd,\n)\nfrom monai.utils import set_determinism\nfrom tests.test_utils import assert_allclose, make_nifti_image\n\nKEYS = [\"image\", \"label\"]\n\n\nclass TestInvertd(unittest.TestCase):\n    def test_invert(self):\n        set_determinism(seed=0)\n        im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))\n        transform = Compose(\n            [\n                LoadImaged(KEYS, image_only=True),\n                EnsureChannelFirstd(KEYS),\n                Orientationd(KEYS, \"RPS\"),\n                Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=[\"bilinear\", \"nearest\"], dtype=np.float32),\n                ScaleIntensityd(\"image\", minv=1, maxv=10),\n                RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),\n                RandAxisFlipd(KEYS, prob=0.5),\n                RandRotate90d(KEYS, prob=0, 
spatial_axes=(1, 2)),\n                RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),\n                RandRotated(KEYS, prob=0.5, range_x=np.pi, mode=\"bilinear\", align_corners=True, dtype=np.float64),\n                RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode=[\"nearest\", 0]),\n                ResizeWithPadOrCropd(KEYS, 100),\n                CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),\n                CopyItemsd(\"label\", times=2, names=[\"label_inverted\", \"label_inverted1\"]),\n                CopyItemsd(\"image\", times=2, names=[\"image_inverted\", \"image_inverted1\"]),\n            ]\n        )\n        data = [{\"image\": im_fname, \"label\": seg_fname} for _ in range(12)]\n\n        # num workers = 0 for mac or gpu transforms\n        num_workers = 0 if sys.platform != \"linux\" or torch.cuda.is_available() else 2\n\n        dataset = Dataset(data, transform=transform)\n        transform.inverse(dataset[0])\n        loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)\n        inverter = Invertd(\n            # `image` was not copied, invert the original value directly\n            keys=[\"image_inverted\", \"label_inverted\"],\n            transform=transform,\n            orig_keys=[\"label\", \"label\"],\n            nearest_interp=True,\n            device=None,\n            post_func=torch.as_tensor,\n        )\n\n        inverter_1 = Invertd(\n            # `image` was not copied, invert the original value directly\n            keys=[\"image_inverted1\", \"label_inverted1\"],\n            transform=transform,\n            orig_keys=[\"image\", \"image\"],\n            nearest_interp=[True, False],\n            device=\"cpu\",\n        )\n\n        expected_keys = [\"image\", \"image_inverted\", \"image_inverted1\", \"label\", \"label_inverted\", \"label_inverted1\"]\n        # execute 1 epoch\n        for d in loader:\n            d = decollate_batch(d)\n            for item in d:\n        
        item = inverter(item)\n                item = inverter_1(item)\n\n                self.assertListEqual(sorted(item), expected_keys)\n                self.assertTupleEqual(item[\"image\"].shape[1:], (100, 100, 100))\n                self.assertTupleEqual(item[\"label\"].shape[1:], (100, 100, 100))\n                # check the nearest interpolation mode\n                i = item[\"image_inverted\"]\n                assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))\n                self.assertTupleEqual(i.shape[1:], (101, 100, 107))\n                i = item[\"label_inverted\"]\n                assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))\n                self.assertTupleEqual(i.shape[1:], (101, 100, 107))\n\n                # check the case that different items use different interpolation mode to invert transforms\n                j = item[\"image_inverted1\"]\n                # if the interpolation mode is nearest, accumulated diff should be smaller than 1\n                self.assertLess(torch.sum(j.to(torch.float) - j.to(torch.uint8).to(torch.float)).item(), 1.0)\n                self.assertTupleEqual(j.shape, (1, 101, 100, 107))\n\n                k = item[\"label_inverted1\"]\n                # if the interpolation mode is not nearest, accumulated diff should be greater than 10000\n                self.assertGreater(torch.sum(k.to(torch.float) - k.to(torch.uint8).to(torch.float)).item(), 10000.0)\n                self.assertTupleEqual(k.shape, (1, 101, 100, 107))\n\n        # check labels match\n        reverted = item[\"label_inverted\"].detach().cpu().numpy().astype(np.int32)\n        original = LoadImaged(KEYS, image_only=True)(data[-1])[\"label\"]\n        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))\n        reverted_name = item[\"label_inverted\"].meta[\"filename_or_obj\"]\n        original_name = data[-1][\"label\"]\n        self.assertEqual(reverted_name, original_name)\n        
print(\"invert diff\", reverted.size - n_good)\n        # 25300: 2 workers (cpu, non-macos)\n        # 1812: 0 workers (gpu or macos)\n        # 1821: windows torch 1.10.0\n        self.assertLess((reverted.size - n_good), 40000, f\"diff.  {reverted.size - n_good}\")\n\n        set_determinism(seed=None)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/inverse/test_traceable_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.transforms.inverse import TraceableTransform\n\n\nclass _TraceTest(TraceableTransform):\n\n    def __call__(self, data):\n        self.push_transform(data)\n        return data\n\n    def pop(self, data):\n        self.pop_transform(data)\n        return data\n\n\nclass TestTraceable(unittest.TestCase):\n\n    def test_default(self):\n        expected_key = \"_transforms\"\n        a = _TraceTest()\n        for x in a.transform_info_keys():\n            self.assertIn(x, a.get_transform_info())\n        self.assertEqual(a.trace_key(), expected_key)\n\n        data = {\"image\": \"test\"}\n        data = a(data)  # adds to the stack\n        self.assertIsInstance(data[expected_key], list)\n        self.assertEqual(data[expected_key][0][\"class\"], \"_TraceTest\")\n\n        data = a(data)  # adds to the stack\n        self.assertEqual(len(data[expected_key]), 2)\n        self.assertEqual(data[expected_key][-1][\"class\"], \"_TraceTest\")\n\n        with self.assertRaises(ValueError):\n            a.pop({\"test\": \"test\"})  # no stack in the data\n        data = a.pop(data)\n        data = a.pop(data)\n        self.assertEqual(data[expected_key], [])\n\n        with self.assertRaises(ValueError):  # no more items\n            a.pop(data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/post/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/transforms/post/test_label_filterd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.post.dictionary import LabelFilterd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\ngrid_1 = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]])\n\nVALID_TESTS = []\nfor p in TEST_NDARRAYS:\n    VALID_TESTS.append(\n        [\n            \"filter_single_label\",\n            {\"applied_labels\": 3},\n            p(grid_1),\n            p(torch.tensor([[[[0, 0, 3], [0, 0, 0], [0, 0, 0]]]])),\n        ]\n    )\n\n    VALID_TESTS.append(\n        [\n            \"filter_single_label_list\",\n            {\"applied_labels\": [3]},\n            p(grid_1),\n            p(torch.tensor([[[[0, 0, 3], [0, 0, 0], [0, 0, 0]]]])),\n        ]\n    )\n\n    VALID_TESTS.append(\n        [\n            \"filter_multi_label\",\n            {\"applied_labels\": [3, 5, 8]},\n            p(grid_1),\n            p(torch.tensor([[[[0, 0, 3], [0, 5, 0], [0, 8, 0]]]])),\n        ]\n    )\n\n    VALID_TESTS.append([\"filter_all\", {\"applied_labels\": [1, 2, 3, 4, 5, 6, 7, 8, 9]}, p(grid_1), p(grid_1)])\n\nITEST_CASE_1 = [\"invalid_image_data_type\", {\"applied_labels\": 1}, [[[[1, 1, 1]]]], NotImplementedError]\n\nINVALID_CASES = [ITEST_CASE_1]\n\n\nclass TestLabelFilter(unittest.TestCase):\n    @parameterized.expand(VALID_TESTS)\n    def 
test_correct_results(self, _, args, input_image, expected):\n        converter = LabelFilterd(keys=\"image\", **args)\n        result = converter({\"image\": input_image})[\"image\"]\n        assert_allclose(result, expected)\n\n    @parameterized.expand(INVALID_CASES)\n    def test_raise_exception(self, _, args, input_image, expected_error):\n        with self.assertRaises(expected_error):\n            converter = LabelFilterd(keys=\"image\", **args)\n            if isinstance(input_image, torch.Tensor) and torch.cuda.is_available():\n                _ = converter({\"image\": input_image.cuda()})\n            else:\n                _ = converter({\"image\": input_image})\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/post/test_probnms.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.post.array import ProbNMS\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    probs_map_1 = p(np.random.rand(100, 100).clip(0, 0.5))\n    TESTS.append([{\"spatial_dims\": 2, \"prob_threshold\": 0.5, \"box_size\": 10}, probs_map_1, []])\n\n    probs_map_2 = p(np.random.rand(100, 100).clip(0, 0.5))\n    probs_map_2[33, 33] = 0.7\n    probs_map_2[66, 66] = 0.9\n    expected_2 = [[0.9, 66, 66], [0.7, 33, 33]]\n    TESTS.append([{\"spatial_dims\": 2, \"prob_threshold\": 0.5, \"box_size\": [10, 10]}, probs_map_2, expected_2])\n\n    probs_map_3 = p(np.random.rand(100, 100).clip(0, 0.5))\n    probs_map_3[56, 58] = 0.7\n    probs_map_3[60, 66] = 0.8\n    probs_map_3[66, 66] = 0.9\n    expected_3 = [[0.9, 66, 66], [0.8, 60, 66]]\n    TESTS.append([{\"spatial_dims\": 2, \"prob_threshold\": 0.5, \"box_size\": (10, 20)}, probs_map_3, expected_3])\n\n    probs_map_4 = p(np.random.rand(100, 100).clip(0, 0.5))\n    probs_map_4[33, 33] = 0.7\n    probs_map_4[66, 66] = 0.9\n    expected_4 = [[0.9, 66, 66]]\n    TESTS.append([{\"spatial_dims\": 2, \"prob_threshold\": 0.8, \"box_size\": 10}, probs_map_4, expected_4])\n\n    probs_map_5 = p(np.random.rand(100, 100).clip(0, 0.5))\n    
TESTS.append([{\"spatial_dims\": 2, \"prob_threshold\": 0.5, \"sigma\": 0.1}, probs_map_5, []])\n\n    probs_map_6 = p(np.random.rand(100, 100).clip(0, 0.5))\n    probs_map_6[33, 33] = 0.7\n    probs_map_6[66, 66] = 0.9\n    expected_6 = [[0.9, 66, 66], [0.7, 33, 33]]\n    TESTS.append([{\"spatial_dims\": 2, \"prob_threshold\": 0.5, \"sigma\": 0.1}, probs_map_6, expected_6])\n\n    probs_map_3d = p(torch.rand([50, 50, 50]).uniform_(0, 0.5))\n    probs_map_3d[25, 25, 25] = 0.7\n    probs_map_3d[45, 45, 45] = 0.9\n    expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]]\n    TESTS.append([{\"spatial_dims\": 3, \"prob_threshold\": 0.5, \"box_size\": (10, 10, 10)}, probs_map_3d, expected_3d])\n\n\nclass TestProbNMS(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_output(self, class_args, probs_map, expected):\n        nms = ProbNMS(**class_args)\n        output = nms(probs_map)\n        assert_allclose(output, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/post/test_probnmsd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom typing import Any\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.post.dictionary import ProbNMSD\nfrom tests.test_utils import TEST_NDARRAYS\n\nTESTS: list[Any] = []\nfor p in TEST_NDARRAYS:\n    probs_map_1 = p(np.random.rand(100, 100).clip(0, 0.5))\n    TESTS.append([{\"spatial_dims\": 2, \"prob_threshold\": 0.5, \"box_size\": 10}, {\"prob_map\": probs_map_1}, []])\n\n    probs_map_2 = p(np.random.rand(100, 100).clip(0, 0.5))\n    probs_map_2[33, 33] = 0.7\n    probs_map_2[66, 66] = 0.9\n    expected_2 = [[0.9, 66, 66], [0.7, 33, 33]]\n    TESTS.append(\n        [{\"spatial_dims\": 2, \"prob_threshold\": 0.5, \"box_size\": [10, 10]}, {\"prob_map\": probs_map_2}, expected_2]\n    )\n\n    probs_map_3 = p(np.random.rand(100, 100).clip(0, 0.5))\n    probs_map_3[56, 58] = 0.7\n    probs_map_3[60, 66] = 0.8\n    probs_map_3[66, 66] = 0.9\n    expected_3 = [[0.9, 66, 66], [0.8, 60, 66]]\n    TESTS.append(\n        [{\"spatial_dims\": 2, \"prob_threshold\": 0.5, \"box_size\": (10, 20)}, {\"prob_map\": probs_map_3}, expected_3]\n    )\n\n    probs_map_4 = p(np.random.rand(100, 100).clip(0, 0.5))\n    probs_map_4[33, 33] = 0.7\n    probs_map_4[66, 66] = 0.9\n    expected_4 = [[0.9, 66, 66]]\n    TESTS.append([{\"spatial_dims\": 2, \"prob_threshold\": 0.8, 
\"box_size\": 10}, {\"prob_map\": probs_map_4}, expected_4])\n\n    probs_map_5 = p(np.random.rand(100, 100).clip(0, 0.5))\n    TESTS.append([{\"spatial_dims\": 2, \"prob_threshold\": 0.5, \"sigma\": 0.1}, {\"prob_map\": probs_map_5}, []])\n\n    probs_map_6 = p(np.random.rand(100, 100).clip(0, 0.5))\n    probs_map_6[33, 33] = 0.7\n    probs_map_6[66, 66] = 0.9\n    expected_6 = [[0.9, 66, 66], [0.7, 33, 33]]\n    TESTS.append([{\"spatial_dims\": 2, \"prob_threshold\": 0.5, \"sigma\": 0.1}, {\"prob_map\": probs_map_6}, expected_6])\n\n    probs_map_3d = p(torch.rand([50, 50, 50]).uniform_(0, 0.5))\n    probs_map_3d[25, 25, 25] = 0.7\n    probs_map_3d[45, 45, 45] = 0.9\n    expected_3d = [[0.9, 45, 45, 45], [0.7, 25, 25, 25]]\n    TESTS.append(\n        [{\"spatial_dims\": 3, \"prob_threshold\": 0.5, \"box_size\": (10, 10, 10)}, {\"prob_map\": probs_map_3d}, expected_3d]\n    )\n\n\nclass TestProbNMS(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_output(self, class_args, probs_map, expected):\n        nms = ProbNMSD(keys=\"prob_map\", **class_args)\n        output = nms(probs_map)\n        np.testing.assert_allclose(output[\"prob_map\"], expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/post/test_remove_small_objects.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms.post.array import RemoveSmallObjects\nfrom monai.transforms.post.dictionary import RemoveSmallObjectsd\nfrom monai.utils import optional_import\nfrom tests.test_utils import TEST_NDARRAYS, SkipIfNoModule, assert_allclose\n\nmorphology, has_morphology = optional_import(\"skimage.morphology\")\n\nTEST_ZEROS = np.zeros((1, 9, 8, 7))\nTEST_ONES = np.ones((3, 7, 8, 9))\n\nTEST_INPUT1 = np.array([[[0, 0, 2, 1, 0], [1, 1, 1, 2, 0], [1, 1, 1, 0, 1]]])\n\nTEST_OUTPUT1 = np.array([[[0, 0, 2, 1, 0], [1, 1, 1, 2, 0], [1, 1, 1, 0, 0]]])\n\nTEST_INPUT2 = np.array([[[1, 1, 1, 0, 0], [1, 1, 1, 0, 0], [0, 0, 0, 0, 0], [0, 1, 1, 0, 1], [0, 0, 0, 1, 1]]])\naffine = torch.eye(4, dtype=torch.float64)\naffine[0, 0] = 2.0\nTEST_INPUT3 = MetaTensor(TEST_INPUT2, affine=affine)\n\nTESTS: list[tuple] = []\nfor dtype in (int, float):\n    for p in TEST_NDARRAYS:\n        TESTS.append((dtype, p, TEST_ZEROS, None))\n        TESTS.append((dtype, p, TEST_ONES, None))\n        TESTS.append((dtype, p, TEST_INPUT1, None, {\"min_size\": 6}))\n        TESTS.append((dtype, p, TEST_INPUT1, None, {\"min_size\": 7, \"connectivity\": 2}))\n        # for non-independent channels, the twos should stay\n        
TESTS.append((dtype, p, TEST_INPUT1, TEST_OUTPUT1, {\"min_size\": 2, \"independent_channels\": False}))\n\nTESTS_PHYSICAL: list[tuple] = []\nfor dtype in (int, float):\n    TESTS_PHYSICAL.append((dtype, np.array, TEST_INPUT2, None, {\"min_size\": 3, \"by_measure\": True, \"pixdim\": (2, 1)}))\n    TESTS_PHYSICAL.append((dtype, MetaTensor, TEST_INPUT3, None, {\"min_size\": 3, \"by_measure\": True}))\n\n\n@SkipIfNoModule(\"skimage.morphology\")\nclass TestRemoveSmallObjects(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_remove_small_objects(self, dtype, im_type, lbl, expected, params=None):\n        params = params or {}\n        if expected is None:\n            dtype = bool if len(np.unique(lbl)) == 1 else int\n            expected = morphology.remove_small_objects(lbl.astype(dtype), **params)\n        expected = im_type(expected, dtype=dtype)\n        lbl = im_type(lbl, dtype=dtype)\n        lbl_clean = RemoveSmallObjects(**params)(lbl)\n        assert_allclose(lbl_clean, expected, device_test=True)\n        if isinstance(lbl, MetaTensor):\n            assert_allclose(lbl.affine, lbl_clean.affine)\n\n    @parameterized.expand(TESTS_PHYSICAL)\n    def test_remove_small_objects_physical(self, dtype, im_type, lbl, expected, params):\n        params = params or {}\n        min_size = np.ceil(params[\"min_size\"] / 2)\n\n        if expected is None:\n            dtype = bool if lbl.max() <= 1 else int\n            expected = morphology.remove_small_objects(lbl.astype(dtype), min_size=min_size)\n        expected = im_type(expected, dtype=dtype)\n        lbl = im_type(lbl, dtype=dtype)\n        lbl_clean = RemoveSmallObjects(**params)(lbl)\n        assert_allclose(lbl_clean, expected, device_test=True)\n\n        lbl_clean = RemoveSmallObjectsd(\"lbl\", **params)({\"lbl\": lbl})[\"lbl\"]\n        assert_allclose(lbl_clean, expected, device_test=True)\n\n    @parameterized.expand(TESTS)\n    def test_remove_small_objects_dict(self, dtype, im_type, 
lbl, expected, params=None):\n        params = params or {}\n        if expected is None:\n            dtype = bool if len(np.unique(lbl)) == 1 else int\n            expected = morphology.remove_small_objects(lbl.astype(dtype), **params)\n        expected = im_type(expected, dtype=dtype)\n        lbl = im_type(lbl, dtype=dtype)\n        lbl_clean = RemoveSmallObjectsd(\"lbl\", **params)({\"lbl\": lbl})[\"lbl\"]\n        assert_allclose(lbl_clean, expected, device_test=True)\n        if isinstance(lbl, MetaTensor):\n            assert_allclose(lbl.affine, lbl_clean.affine)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/spatial/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/transforms/spatial/test_convert_box_points.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.box_utils import convert_box_to_standard_mode\nfrom monai.transforms.spatial.array import ConvertBoxToPoints, ConvertPointsToBoxes\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_POINTS_2D = [\n    [\n        torch.tensor([[10, 20, 30, 40], [50, 60, 70, 80]]),\n        \"xyxy\",\n        torch.tensor([[[10, 20], [30, 20], [30, 40], [10, 40]], [[50, 60], [70, 60], [70, 80], [50, 80]]]),\n    ],\n    [torch.tensor([[10, 20, 20, 20]]), \"ccwh\", torch.tensor([[[0, 10], [20, 10], [20, 30], [0, 30]]])],\n]\nTEST_CASE_POINTS_3D = [\n    [\n        torch.tensor([[10, 20, 30, 40, 50, 60], [70, 80, 90, 100, 110, 120]]),\n        \"xyzxyz\",\n        torch.tensor(\n            [\n                [\n                    [10, 20, 30],\n                    [40, 20, 30],\n                    [40, 50, 30],\n                    [10, 50, 30],\n                    [10, 20, 60],\n                    [40, 20, 60],\n                    [40, 50, 60],\n                    [10, 50, 60],\n                ],\n                [\n                    [70, 80, 90],\n                    [100, 80, 90],\n                    [100, 110, 90],\n                    [70, 110, 90],\n                    [70, 80, 120],\n                    [100, 80, 120],\n                    
[100, 110, 120],\n                    [70, 110, 120],\n                ],\n            ]\n        ),\n    ],\n    [\n        torch.tensor([[10, 20, 30, 10, 10, 10]]),\n        \"cccwhd\",\n        torch.tensor(\n            [\n                [\n                    [5, 15, 25],\n                    [15, 15, 25],\n                    [15, 25, 25],\n                    [5, 25, 25],\n                    [5, 15, 35],\n                    [15, 15, 35],\n                    [15, 25, 35],\n                    [5, 25, 35],\n                ]\n            ]\n        ),\n    ],\n    [\n        torch.tensor([[10, 20, 30, 40, 50, 60]]),\n        \"xxyyzz\",\n        torch.tensor(\n            [\n                [\n                    [10, 30, 50],\n                    [20, 30, 50],\n                    [20, 40, 50],\n                    [10, 40, 50],\n                    [10, 30, 60],\n                    [20, 30, 60],\n                    [20, 40, 60],\n                    [10, 40, 60],\n                ]\n            ]\n        ),\n    ],\n]\n\nTEST_CASES = TEST_CASE_POINTS_2D + TEST_CASE_POINTS_3D\n\n\nclass TestConvertBoxToPoints(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_convert_box_to_points(self, boxes, mode, expected_points):\n        transform = ConvertBoxToPoints(mode=mode)\n        converted_points = transform(boxes)\n        assert_allclose(converted_points, expected_points, type_test=False)\n\n\nclass TestConvertPointsToBoxes(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_convert_box_to_points(self, boxes, mode, points):\n        transform = ConvertPointsToBoxes()\n        converted_boxes = transform(points)\n        expected_boxes = convert_box_to_standard_mode(boxes, mode)\n        assert_allclose(converted_boxes, expected_boxes, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/spatial/test_grid_patch.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms.spatial.array import GridPatch\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nA = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1)\nA11 = A[:, :2, :2]\nA12 = A[:, :2, 2:]\nA21 = A[:, 2:, :2]\nA22 = A[:, 2:, 2:]\n\nTEST_CASE_0 = [{\"patch_size\": (2, 2)}, A, [A11, A12, A21, A22]]\nTEST_CASE_1 = [{\"patch_size\": (2, 2), \"num_patches\": 3}, A, [A11, A12, A21]]\nTEST_CASE_2 = [{\"patch_size\": (2, 2), \"num_patches\": 5}, A, [A11, A12, A21, A22, np.zeros((3, 2, 2))]]\nTEST_CASE_3 = [{\"patch_size\": (2, 2), \"offset\": (0, 0)}, A, [A11, A12, A21, A22]]\nTEST_CASE_4 = [{\"patch_size\": (2, 2), \"offset\": (0, 0)}, A, [A11, A12, A21, A22]]\nTEST_CASE_5 = [{\"patch_size\": (2, 2), \"offset\": (2, 2)}, A, [A22]]\nTEST_CASE_6 = [{\"patch_size\": (2, 2), \"offset\": (0, 2)}, A, [A12, A22]]\nTEST_CASE_7 = [{\"patch_size\": (2, 2), \"offset\": (2, 0)}, A, [A21, A22]]\nTEST_CASE_8 = [{\"patch_size\": (2, 2), \"num_patches\": 3, \"sort_fn\": \"max\"}, A, [A22, A21, A12]]\nTEST_CASE_9 = [{\"patch_size\": (2, 2), \"num_patches\": 4, \"sort_fn\": \"min\"}, A, [A11, A12, A21, A22]]\nTEST_CASE_10 = [{\"patch_size\": (2, 2), \"overlap\": 0.5, \"num_patches\": 
3}, A, [A11, A[:, :2, 1:3], A12]]\nTEST_CASE_11 = [\n    {\"patch_size\": (3, 3), \"num_patches\": 2, \"constant_values\": 255, \"pad_mode\": \"constant\"},\n    A,\n    [A[:, :3, :3], np.pad(A[:, :3, 3:], ((0, 0), (0, 0), (0, 2)), mode=\"constant\", constant_values=255)],\n]\nTEST_CASE_12 = [\n    {\"patch_size\": (3, 3), \"offset\": (-2, -2), \"num_patches\": 2, \"pad_mode\": \"constant\"},\n    A,\n    [np.zeros((3, 3, 3)), np.pad(A[:, :1, 1:4], ((0, 0), (2, 0), (0, 0)), mode=\"constant\")],\n]\n# Only threshold filtering\nTEST_CASE_13 = [{\"patch_size\": (2, 2), \"threshold\": 50.0}, A, [A11]]\nTEST_CASE_14 = [{\"patch_size\": (2, 2), \"threshold\": 150.0}, A, [A11, A12, A21]]\n# threshold filtering with num_patches more than available patches (no effect)\nTEST_CASE_15 = [{\"patch_size\": (2, 2), \"num_patches\": 3, \"threshold\": 50.0}, A, [A11]]\n# threshold filtering with num_patches less than available patches (count filtering)\nTEST_CASE_16 = [{\"patch_size\": (2, 2), \"num_patches\": 2, \"threshold\": 150.0}, A, [A11, A12]]\n\nTEST_CASE_META_0 = [\n    {\"patch_size\": (2, 2)},\n    A,\n    [A11, A12, A21, A22],\n    [{\"location\": [0, 0]}, {\"location\": [0, 2]}, {\"location\": [2, 0]}, {\"location\": [2, 2]}],\n]\n\nTEST_CASE_META_1 = [\n    {\"patch_size\": (2, 2)},\n    MetaTensor(x=A, meta={\"path\": \"path/to/file\"}),\n    [A11, A12, A21, A22],\n    [\n        {\"location\": [0, 0], \"path\": \"path/to/file\"},\n        {\"location\": [0, 2], \"path\": \"path/to/file\"},\n        {\"location\": [2, 0], \"path\": \"path/to/file\"},\n        {\"location\": [2, 2], \"path\": \"path/to/file\"},\n    ],\n]\n\nTEST_CASES = []\nfor p in TEST_NDARRAYS:\n    TEST_CASES.append([p, *TEST_CASE_0])\n    TEST_CASES.append([p, *TEST_CASE_1])\n    TEST_CASES.append([p, *TEST_CASE_2])\n    TEST_CASES.append([p, *TEST_CASE_3])\n    TEST_CASES.append([p, *TEST_CASE_4])\n    TEST_CASES.append([p, *TEST_CASE_5])\n    TEST_CASES.append([p, *TEST_CASE_6])\n    
TEST_CASES.append([p, *TEST_CASE_7])\n    TEST_CASES.append([p, *TEST_CASE_8])\n    TEST_CASES.append([p, *TEST_CASE_9])\n    TEST_CASES.append([p, *TEST_CASE_10])\n    TEST_CASES.append([p, *TEST_CASE_11])\n    TEST_CASES.append([p, *TEST_CASE_12])\n    TEST_CASES.append([p, *TEST_CASE_13])\n    TEST_CASES.append([p, *TEST_CASE_14])\n    TEST_CASES.append([p, *TEST_CASE_15])\n    TEST_CASES.append([p, *TEST_CASE_16])\n\n\nclass TestGridPatch(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_grid_patch(self, in_type, input_parameters, image, expected):\n        input_image = in_type(image)\n        splitter = GridPatch(**input_parameters)\n        output = splitter(input_image)\n        self.assertEqual(len(output), len(expected))\n        for output_patch, expected_patch in zip(output, expected):\n            assert_allclose(\n                output_patch,\n                in_type(expected_patch),\n                type_test=False,\n                device_test=bool(isinstance(in_type(expected_patch), torch.Tensor)),\n            )\n\n    @parameterized.expand([TEST_CASE_META_0, TEST_CASE_META_1])\n    def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta):\n        set_track_meta(True)\n        splitter = GridPatch(**input_parameters)\n        output = splitter(image)\n        self.assertEqual(len(output), len(expected))\n        if \"path\" in expected_meta[0]:\n            self.assertTrue(output.meta[\"path\"] == expected_meta[0][\"path\"])\n        for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta):\n            assert_allclose(output_patch, expected_patch, type_test=False)\n            self.assertIsInstance(output_patch, MetaTensor)\n            self.assertEqual(output_patch.meta[\"location\"], expected_patch_meta[\"location\"])\n            self.assertTrue(output_patch.meta[\"spatial_shape\"], list(output_patch.shape[1:]))\n            if \"path\" in expected_meta[0]:\n 
               self.assertEqual(output_patch.meta[\"path\"], expected_patch_meta[\"path\"])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/spatial/test_grid_patchd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.spatial.dictionary import GridPatchd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nA = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1)\nA11 = A[:, :2, :2]\nA12 = A[:, :2, 2:]\nA21 = A[:, 2:, :2]\nA22 = A[:, 2:, 2:]\n\nTEST_CASE_0 = [{\"patch_size\": (2, 2)}, {\"image\": A}, [A11, A12, A21, A22]]\nTEST_CASE_1 = [{\"patch_size\": (2, 2), \"num_patches\": 3}, {\"image\": A}, [A11, A12, A21]]\nTEST_CASE_2 = [{\"patch_size\": (2, 2), \"num_patches\": 5}, {\"image\": A}, [A11, A12, A21, A22, np.zeros((3, 2, 2))]]\nTEST_CASE_3 = [{\"patch_size\": (2, 2), \"offset\": (0, 0)}, {\"image\": A}, [A11, A12, A21, A22]]\nTEST_CASE_4 = [{\"patch_size\": (2, 2), \"offset\": (0, 0)}, {\"image\": A}, [A11, A12, A21, A22]]\nTEST_CASE_5 = [{\"patch_size\": (2, 2), \"offset\": (2, 2)}, {\"image\": A}, [A22]]\nTEST_CASE_6 = [{\"patch_size\": (2, 2), \"offset\": (0, 2)}, {\"image\": A}, [A12, A22]]\nTEST_CASE_7 = [{\"patch_size\": (2, 2), \"offset\": (2, 0)}, {\"image\": A}, [A21, A22]]\nTEST_CASE_8 = [{\"patch_size\": (2, 2), \"num_patches\": 3, \"sort_fn\": \"max\"}, {\"image\": A}, [A22, A21, A12]]\nTEST_CASE_9 = [{\"patch_size\": (2, 2), \"num_patches\": 4, \"sort_fn\": \"min\"}, {\"image\": A}, [A11, A12, 
A21, A22]]\nTEST_CASE_10 = [{\"patch_size\": (2, 2), \"overlap\": 0.5, \"num_patches\": 3}, {\"image\": A}, [A11, A[:, :2, 1:3], A12]]\nTEST_CASE_11 = [\n    {\"patch_size\": (3, 3), \"num_patches\": 2, \"constant_values\": 255, \"pad_mode\": \"constant\"},\n    {\"image\": A},\n    [A[:, :3, :3], np.pad(A[:, :3, 3:], ((0, 0), (0, 0), (0, 2)), mode=\"constant\", constant_values=255)],\n]\nTEST_CASE_12 = [\n    {\"patch_size\": (3, 3), \"offset\": (-2, -2), \"num_patches\": 2, \"pad_mode\": \"constant\"},\n    {\"image\": A},\n    [np.zeros((3, 3, 3)), np.pad(A[:, :1, 1:4], ((0, 0), (2, 0), (0, 0)), mode=\"constant\")],\n]\n# Only threshold filtering\nTEST_CASE_13 = [{\"patch_size\": (2, 2), \"threshold\": 50.0}, {\"image\": A}, [A11]]\nTEST_CASE_14 = [{\"patch_size\": (2, 2), \"threshold\": 150.0}, {\"image\": A}, [A11, A12, A21]]\n# threshold filtering with num_patches more than available patches (no effect)\nTEST_CASE_15 = [{\"patch_size\": (2, 2), \"threshold\": 50.0, \"num_patches\": 3}, {\"image\": A}, [A11]]\n# threshold filtering with num_patches less than available patches (count filtering)\nTEST_CASE_16 = [{\"patch_size\": (2, 2), \"threshold\": 150.0, \"num_patches\": 2}, {\"image\": A}, [A11, A12]]\n\nTEST_SINGLE = []\nfor p in TEST_NDARRAYS:\n    TEST_SINGLE.append([p, *TEST_CASE_0])\n    TEST_SINGLE.append([p, *TEST_CASE_1])\n    TEST_SINGLE.append([p, *TEST_CASE_2])\n    TEST_SINGLE.append([p, *TEST_CASE_3])\n    TEST_SINGLE.append([p, *TEST_CASE_4])\n    TEST_SINGLE.append([p, *TEST_CASE_5])\n    TEST_SINGLE.append([p, *TEST_CASE_6])\n    TEST_SINGLE.append([p, *TEST_CASE_7])\n    TEST_SINGLE.append([p, *TEST_CASE_8])\n    TEST_SINGLE.append([p, *TEST_CASE_9])\n    TEST_SINGLE.append([p, *TEST_CASE_10])\n    TEST_SINGLE.append([p, *TEST_CASE_11])\n    TEST_SINGLE.append([p, *TEST_CASE_12])\n    TEST_SINGLE.append([p, *TEST_CASE_13])\n    TEST_SINGLE.append([p, *TEST_CASE_14])\n    TEST_SINGLE.append([p, *TEST_CASE_15])\n    TEST_SINGLE.append([p, 
*TEST_CASE_16])\n\n\nclass TestGridPatchd(unittest.TestCase):\n    @parameterized.expand(TEST_SINGLE)\n    def test_grid_patchd(self, in_type, input_parameters, image_dict, expected):\n        image_key = \"image\"\n        input_dict = {}\n        for k, v in image_dict.items():\n            input_dict[k] = v\n            if k == image_key:\n                input_dict[k] = in_type(v)\n        splitter = GridPatchd(keys=image_key, **input_parameters)\n        output = splitter(input_dict)\n        self.assertEqual(len(output[image_key]), len(expected))\n        for output_patch, expected_patch in zip(output[image_key], expected):\n            assert_allclose(\n                output_patch,\n                in_type(expected_patch),\n                type_test=False,\n                device_test=bool(isinstance(in_type(expected_patch), torch.Tensor)),\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/spatial/test_rand_grid_patch.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms.spatial.array import RandGridPatch\nfrom monai.utils import set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nA = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1)\nA11 = A[:, :2, :2]\nA12 = A[:, :2, 2:]\nA21 = A[:, 2:, :2]\nA22 = A[:, 2:, 2:]\n\nTEST_CASE_0 = [{\"patch_size\": (2, 2), \"min_offset\": 0, \"max_offset\": 0}, A, [A11, A12, A21, A22]]\nTEST_CASE_1 = [{\"patch_size\": (2, 2), \"min_offset\": 0, \"num_patches\": 3}, A, [A11, A12, A21]]\nTEST_CASE_2 = [\n    {\"patch_size\": (2, 2), \"min_offset\": 0, \"max_offset\": 0, \"num_patches\": 5},\n    A,\n    [A11, A12, A21, A22, np.zeros((3, 2, 2))],\n]\nTEST_CASE_3 = [{\"patch_size\": (2, 2), \"min_offset\": 0, \"max_offset\": 0}, A, [A11, A12, A21, A22]]\nTEST_CASE_4 = [{\"patch_size\": (2, 2)}, A, [A11, A12, A21, A22]]\nTEST_CASE_5 = [{\"patch_size\": (2, 2), \"min_offset\": 2, \"max_offset\": 2}, A, [A22]]\nTEST_CASE_6 = [{\"patch_size\": (2, 2), \"min_offset\": (0, 2), \"max_offset\": (0, 2)}, A, [A12, A22]]\nTEST_CASE_7 = [{\"patch_size\": (2, 2), \"min_offset\": 1, \"max_offset\": 2}, A, [A22]]\nTEST_CASE_8 = [\n    {\"patch_size\": (2, 2), \"min_offset\": 0, 
\"max_offset\": 1, \"num_patches\": 1, \"sort_fn\": \"max\"},\n    A,\n    [A[:, 1:3, 1:3]],\n]\nTEST_CASE_9 = [\n    {\n        \"patch_size\": (3, 3),\n        \"min_offset\": -3,\n        \"max_offset\": -1,\n        \"sort_fn\": \"min\",\n        \"num_patches\": 1,\n        \"pad_mode\": \"constant\",\n        \"constant_values\": 255,\n    },\n    A,\n    [np.pad(A[:, :2, 1:], ((0, 0), (1, 0), (0, 0)), mode=\"constant\", constant_values=255)],\n]\nTEST_CASE_10 = [{\"patch_size\": (2, 2), \"min_offset\": 0, \"max_offset\": 0, \"threshold\": 50.0}, A, [A11]]\nTEST_CASE_11 = [{\"patch_size\": (2, 2), \"sort_fn\": \"random\", \"num_patches\": 2}, A, [A11, A12]]\nTEST_CASE_12 = [{\"patch_size\": (2, 2), \"sort_fn\": \"random\", \"num_patches\": 4}, A, [A11, A12, A21, A22]]\nTEST_CASE_13 = [\n    {\"patch_size\": (2, 2), \"min_offset\": 0, \"max_offset\": 1, \"num_patches\": 1, \"sort_fn\": \"random\"},\n    A,\n    [A[:, 1:3, 1:3]],\n]\n\nTEST_CASE_META_0 = [\n    {\"patch_size\": (2, 2)},\n    A,\n    [A11, A12, A21, A22],\n    [{\"location\": [0, 0]}, {\"location\": [0, 2]}, {\"location\": [2, 0]}, {\"location\": [2, 2]}],\n]\n\nTEST_CASE_META_1 = [\n    {\"patch_size\": (2, 2)},\n    MetaTensor(x=A, meta={\"path\": \"path/to/file\"}),\n    [A11, A12, A21, A22],\n    [\n        {\"location\": [0, 0], \"path\": \"path/to/file\"},\n        {\"location\": [0, 2], \"path\": \"path/to/file\"},\n        {\"location\": [2, 0], \"path\": \"path/to/file\"},\n        {\"location\": [2, 2], \"path\": \"path/to/file\"},\n    ],\n]\n\nTEST_SINGLE = []\nfor p in TEST_NDARRAYS:\n    TEST_SINGLE.append([p, *TEST_CASE_0])\n    TEST_SINGLE.append([p, *TEST_CASE_1])\n    TEST_SINGLE.append([p, *TEST_CASE_2])\n    TEST_SINGLE.append([p, *TEST_CASE_3])\n    TEST_SINGLE.append([p, *TEST_CASE_4])\n    TEST_SINGLE.append([p, *TEST_CASE_5])\n    TEST_SINGLE.append([p, *TEST_CASE_6])\n    TEST_SINGLE.append([p, *TEST_CASE_7])\n    TEST_SINGLE.append([p, *TEST_CASE_8])\n    
TEST_SINGLE.append([p, *TEST_CASE_9])\n    TEST_SINGLE.append([p, *TEST_CASE_10])\n    TEST_SINGLE.append([p, *TEST_CASE_11])\n    TEST_SINGLE.append([p, *TEST_CASE_12])\n    TEST_SINGLE.append([p, *TEST_CASE_13])\n\n\nclass TestRandGridPatch(unittest.TestCase):\n    def setUp(self):\n        set_determinism(seed=1234)\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @parameterized.expand(TEST_SINGLE)\n    def test_rand_grid_patch(self, in_type, input_parameters, image, expected):\n        input_image = in_type(image)\n        splitter = RandGridPatch(**input_parameters)\n        splitter.set_random_state(1234)\n        output = splitter(input_image)\n        self.assertEqual(len(output), len(expected))\n        for output_patch, expected_patch in zip(output, expected):\n            assert_allclose(\n                output_patch,\n                in_type(expected_patch),\n                type_test=False,\n                device_test=bool(isinstance(in_type(expected_patch), torch.Tensor)),\n            )\n\n    @parameterized.expand([TEST_CASE_META_0, TEST_CASE_META_1])\n    def test_rand_grid_patch_meta(self, input_parameters, image, expected, expected_meta):\n        set_track_meta(True)\n        splitter = RandGridPatch(**input_parameters)\n        splitter.set_random_state(1234)\n        output = splitter(image)\n        self.assertEqual(len(output), len(expected))\n        if \"path\" in expected_meta[0]:\n            self.assertTrue(output.meta[\"path\"] == expected_meta[0][\"path\"])\n        for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta):\n            assert_allclose(output_patch, expected_patch, type_test=False)\n            self.assertTrue(isinstance(output_patch, MetaTensor))\n            self.assertTrue(output_patch.meta[\"location\"] == expected_patch_meta[\"location\"])\n            self.assertTrue(output_patch.meta[\"spatial_shape\"], list(output_patch.shape[1:]))\n            if \"path\" 
in expected_meta[0]:\n                self.assertTrue(output_patch.meta[\"path\"] == expected_patch_meta[\"path\"])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/spatial/test_rand_grid_patchd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.spatial.dictionary import RandGridPatchd\nfrom monai.utils import set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nA = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1)\nA11 = A[:, :2, :2]\nA12 = A[:, :2, 2:]\nA21 = A[:, 2:, :2]\nA22 = A[:, 2:, 2:]\n\nTEST_CASE_0 = [{\"patch_size\": (2, 2), \"min_offset\": 0, \"max_offset\": 0}, {\"image\": A}, [A11, A12, A21, A22]]\nTEST_CASE_1 = [{\"patch_size\": (2, 2), \"min_offset\": 0, \"num_patches\": 3}, {\"image\": A}, [A11, A12, A21]]\nTEST_CASE_2 = [\n    {\"patch_size\": (2, 2), \"min_offset\": 0, \"max_offset\": 0, \"num_patches\": 5},\n    {\"image\": A},\n    [A11, A12, A21, A22, np.zeros((3, 2, 2))],\n]\nTEST_CASE_3 = [{\"patch_size\": (2, 2), \"min_offset\": 0, \"max_offset\": 0}, {\"image\": A}, [A11, A12, A21, A22]]\nTEST_CASE_4 = [{\"patch_size\": (2, 2)}, {\"image\": A}, [A11, A12, A21, A22]]\nTEST_CASE_5 = [{\"patch_size\": (2, 2), \"min_offset\": 2, \"max_offset\": 2}, {\"image\": A}, [A22]]\nTEST_CASE_6 = [{\"patch_size\": (2, 2), \"min_offset\": (0, 2), \"max_offset\": (0, 2)}, {\"image\": A}, [A12, A22]]\nTEST_CASE_7 = [{\"patch_size\": (2, 2), \"min_offset\": 1, \"max_offset\": 2}, {\"image\": A}, [A22]]\nTEST_CASE_8 = 
[\n    {\"patch_size\": (2, 2), \"min_offset\": 0, \"max_offset\": 1, \"num_patches\": 1, \"sort_fn\": \"max\"},\n    {\"image\": A},\n    [A[:, 1:3, 1:3]],\n]\nTEST_CASE_9 = [\n    {\n        \"patch_size\": (3, 3),\n        \"min_offset\": -3,\n        \"max_offset\": -1,\n        \"sort_fn\": \"min\",\n        \"num_patches\": 1,\n        \"pad_mode\": \"constant\",\n        \"constant_values\": 255,\n    },\n    {\"image\": A},\n    [np.pad(A[:, :2, 1:], ((0, 0), (1, 0), (0, 0)), mode=\"constant\", constant_values=255)],\n]\nTEST_CASE_10 = [{\"patch_size\": (2, 2), \"min_offset\": 0, \"max_offset\": 0, \"threshold\": 50.0}, {\"image\": A}, [A11]]\nTEST_CASE_11 = [{\"patch_size\": (2, 2), \"sort_fn\": \"random\", \"num_patches\": 2}, {\"image\": A}, [A11, A12]]\nTEST_CASE_12 = [{\"patch_size\": (2, 2), \"sort_fn\": \"random\", \"num_patches\": 4}, {\"image\": A}, [A11, A12, A21, A22]]\nTEST_CASE_13 = [\n    {\"patch_size\": (2, 2), \"min_offset\": 0, \"max_offset\": 1, \"num_patches\": 1, \"sort_fn\": \"random\"},\n    {\"image\": A},\n    [A[:, 1:3, 1:3]],\n]\n\nTEST_SINGLE = []\nfor p in TEST_NDARRAYS:\n    TEST_SINGLE.append([p, *TEST_CASE_0])\n    TEST_SINGLE.append([p, *TEST_CASE_1])\n    TEST_SINGLE.append([p, *TEST_CASE_2])\n    TEST_SINGLE.append([p, *TEST_CASE_3])\n    TEST_SINGLE.append([p, *TEST_CASE_4])\n    TEST_SINGLE.append([p, *TEST_CASE_5])\n    TEST_SINGLE.append([p, *TEST_CASE_6])\n    TEST_SINGLE.append([p, *TEST_CASE_7])\n    TEST_SINGLE.append([p, *TEST_CASE_8])\n    TEST_SINGLE.append([p, *TEST_CASE_9])\n    TEST_SINGLE.append([p, *TEST_CASE_10])\n    TEST_SINGLE.append([p, *TEST_CASE_11])\n    TEST_SINGLE.append([p, *TEST_CASE_12])\n    TEST_SINGLE.append([p, *TEST_CASE_13])\n\n\nclass TestRandGridPatchd(unittest.TestCase):\n    def setUp(self):\n        set_determinism(seed=1234)\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @parameterized.expand(TEST_SINGLE)\n    def test_rand_grid_patchd(self, in_type, 
input_parameters, image_dict, expected):\n        image_key = \"image\"\n        input_dict = {}\n        for k, v in image_dict.items():\n            input_dict[k] = v\n            if k == image_key:\n                input_dict[k] = in_type(v)\n        splitter = RandGridPatchd(keys=image_key, **input_parameters)\n        splitter.set_random_state(1234)\n        output = splitter(input_dict)\n        self.assertEqual(len(output[image_key]), len(expected))\n        for output_patch, expected_patch in zip(output[image_key], expected):\n            assert_allclose(\n                output_patch,\n                in_type(expected_patch),\n                type_test=False,\n                device_test=bool(isinstance(in_type(expected_patch), torch.Tensor)),\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/spatial/test_spatial_resampled.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport platform\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import to_affine_nd\nfrom monai.transforms.spatial.dictionary import SpatialResampled\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_DEVICES, assert_allclose, dict_product\n\nON_AARCH64 = platform.machine() == \"aarch64\"\nif ON_AARCH64:\n    rtol, atol = 1e-1, 1e-2\nelse:\n    rtol, atol = 1e-3, 1e-4\n\nTESTS = []\n\ndestinations_3d = [\n    torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),\n    torch.tensor([[-1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),\n]\nexpected_3d = [\n    torch.tensor([[[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]], [[10.0, 11.0, 12.0], [7.0, 8.0, 9.0]]]]),\n    torch.tensor([[[[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]]),\n]\n\nfor dst, expct in zip(destinations_3d, expected_3d):\n    TESTS.extend(\n        [\n            [\n                np.arange(12).reshape((1, 2, 2, 3)) + 1.0,  # data\n                *params[\"device\"],\n                dst,\n                {\n                    **{k: v for k, v in params.items() if k not in [\"device\", 
\"interp_mode\"]},\n                    \"dst_keys\": \"dst_affine\",\n                    \"padding_mode\": \"zeros\",\n                },\n                expct,\n            ]\n            for params in dict_product(\n                device=TEST_DEVICES,\n                align_corners=[True, False],\n                dtype=[torch.float32, torch.float64],\n                interp_mode=[\"nearest\", \"bilinear\"],\n                padding_mode=[\"zeros\", \"border\", \"reflection\"],\n            )\n        ]\n    )\n\ndestinations_2d = [\n    torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0], [0.0, 0.0, 1.0]]),  # flip the second\n    torch.tensor([[-1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),  # flip the first\n]\n\nexpected_2d = [torch.tensor([[[2.0, 1.0], [4.0, 3.0]]]), torch.tensor([[[3.0, 4.0], [1.0, 2.0]]])]\n\nfor dst, expct in zip(destinations_2d, expected_2d):\n    TESTS += [\n        [\n            np.arange(4).reshape((1, 2, 2)) + 1.0,  # data\n            *params.pop(\"device\"),\n            dst,\n            {\n                **{k: v for k, v in params.items() if k not in [\"align\", \"interp_mode\"]},\n                \"dst_keys\": \"dst_affine\",\n                \"padding_mode\": \"zeros\",\n            },\n            expct,\n        ]\n        for params in dict_product(\n            device=TEST_DEVICES,\n            align=[False, True],\n            dtype=[torch.float32, torch.float64],\n            interp_mode=[\"nearest\", \"bilinear\"],\n        )\n    ]\n\n\nclass TestSpatialResample(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_flips_inverse(self, img, device, dst_affine, kwargs, expected_output):\n        img = MetaTensor(img, affine=torch.eye(4)).to(device)\n        data = {\"img\": img, \"dst_affine\": dst_affine}\n        init_param = kwargs.copy()\n        init_param[\"keys\"] = \"img\"\n        call_param = {\"data\": data}\n        xform = SpatialResampled(**init_param)\n        output_data = 
xform(**call_param)\n        out = output_data[\"img\"]\n\n        assert_allclose(out, expected_output, rtol=1e-2, atol=1e-2)\n        assert_allclose(to_affine_nd(len(out.shape) - 1, out.affine), dst_affine, rtol=1e-2, atol=1e-2)\n\n        # check lazy\n        lazy_xform = SpatialResampled(**init_param)\n        test_resampler_lazy(lazy_xform, output_data, init_param, call_param, output_key=\"img\", rtol=rtol, atol=atol)\n\n        # check inverse\n        inverted = xform.inverse(output_data)[\"img\"]\n        self.assertEqual(inverted.applied_operations, [])  # no further invert after inverting\n        expected_affine = to_affine_nd(len(out.affine) - 1, torch.eye(4))\n        assert_allclose(inverted.affine, expected_affine, rtol=1e-2, atol=1e-2)\n        assert_allclose(inverted, img, rtol=1e-2, atol=1e-2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_activations.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.layers.factories import Act\nfrom monai.transforms import Activations\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTEST_CASES = []\nfor p in TEST_NDARRAYS:\n    TEST_CASES.append(\n        [\n            {\"sigmoid\": True, \"softmax\": False, \"other\": None},\n            p([[[0.0, 1.0], [2.0, 3.0]]]),\n            p([[[0.5000, 0.7311], [0.8808, 0.9526]]]),\n            (1, 2, 2),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"sigmoid\": False, \"softmax\": True, \"other\": None},\n            p([[[0.0, 1.0]], [[2.0, 3.0]]]),\n            p([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]),\n            (2, 1, 2),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"sigmoid\": False, \"softmax\": True, \"other\": None, \"unused\": True, \"dim\": 1},\n            p([[[0.0, 1.0]], [[2.0, 3.0]]]),\n            p([[[1.0, 1.0]], [[1.0, 1.0]]]),\n            (2, 1, 2),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"sigmoid\": False, \"softmax\": False, \"other\": torch.tanh},\n            p([[[0.0, 1.0], [2.0, 3.0]]]),\n            p([[[0.0000, 0.7616], [0.9640, 0.9951]]]),\n            (1, 2, 2),\n        ]\n    )\n\nTEST_CASE_4 = [\n    \"swish\",\n    
torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32),\n    torch.tensor(\n        [[[-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00]]]\n    ),\n    (1, 2, 5),\n]\n\nTEST_CASE_5 = [\n    \"memswish\",\n    torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32),\n    torch.tensor(\n        [[[-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00]]]\n    ),\n    (1, 2, 5),\n]\n\nTEST_CASE_6 = [\n    \"mish\",\n    torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32),\n    torch.tensor(\n        [[[-4.54e-04, -2.68e-03, -1.49e-02, -7.26e-02, -2.53e-01], [0.00e00, 1.94e00, 4.00e00, 6.00e00, 8.00e00]]]\n    ),\n    (1, 2, 5),\n]\n\nTEST_CASE_7 = [\n    \"geglu\",\n    torch.tensor([[[-10, -8, -6, -4, -2, 0], [0, 2, 4, 6, 8, 10]]], dtype=torch.float32),\n    torch.tensor([[[1.27e-03, 3.64e-01, 0.00e00], [0.00e00, 1.60e01, 4.00e01]]]),\n    (1, 2, 3),\n]\n\n\nclass TestActivations(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_value_shape(self, input_param, img, out, expected_shape):\n        result = Activations(**input_param)(img)\n\n        def _compare(ret, out, shape):\n            assert_allclose(ret, out, rtol=1e-3, type_test=False)\n            self.assertTupleEqual(ret.shape, shape)\n\n        if isinstance(result, (list, tuple)):\n            for r, e in zip(result, out):\n                _compare(r, e, expected_shape)\n        else:\n            _compare(result, out, expected_shape)\n\n    @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])\n    def test_monai_activations_value_shape(self, input_param, img, out, expected_shape):\n        act = Act[input_param]()\n        result = act(img)\n        assert_allclose(result, out, rtol=1e-2, atol=1e-5)\n        self.assertTupleEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n 
   unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_activationsd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import Activationsd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTEST_CASES = []\nfor p in TEST_NDARRAYS:\n    TEST_CASES.append(\n        [\n            {\"keys\": [\"pred\", \"label\"], \"sigmoid\": False, \"softmax\": [True, False], \"other\": None, \"dim\": 0},\n            {\"pred\": p([[[0.0, 1.0]], [[2.0, 3.0]]]), \"label\": p([[[0.0, 1.0]], [[2.0, 3.0]]])},\n            {\"pred\": p([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), \"label\": p([[[0.0, 1.0]], [[2.0, 3.0]]])},\n            (2, 1, 2),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"keys\": [\"pred\", \"label\"], \"sigmoid\": False, \"softmax\": False, \"other\": [torch.tanh, None]},\n            {\"pred\": p([[[0.0, 1.0], [2.0, 3.0]]]), \"label\": p([[[0.0, 1.0], [2.0, 3.0]]])},\n            {\"pred\": p([[[0.0000, 0.7616], [0.9640, 0.9951]]]), \"label\": p([[[0.0, 1.0], [2.0, 3.0]]])},\n            (1, 2, 2),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"keys\": \"pred\", \"sigmoid\": False, \"softmax\": False, \"other\": torch.tanh},\n            {\"pred\": p([[[0.0, 1.0], [2.0, 3.0]]])},\n            {\"pred\": p([[[0.0000, 0.7616], [0.9640, 0.9951]]])},\n            (1, 2, 2),\n        ]\n    
)\n\n\nclass TestActivationsd(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_value_shape(self, input_param, test_input, output, expected_shape):\n        result = Activationsd(**input_param)(test_input)\n        assert_allclose(result[\"pred\"], output[\"pred\"], rtol=1e-3, type_test=\"tensor\")\n        self.assertTupleEqual(result[\"pred\"].shape, expected_shape)\n        if \"label\" in result:\n            assert_allclose(result[\"label\"], output[\"label\"], rtol=1e-3, type_test=\"tensor\")\n            self.assertTupleEqual(result[\"label\"].shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_adaptors.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport itertools\nimport unittest\n\nfrom monai.transforms.adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs\n\n\nclass TestAdaptors(unittest.TestCase):\n\n    def test_function_signature(self):\n\n        def foo(image, label=None, *a, **kw):\n            pass\n\n        _ = FunctionSignature(foo)\n\n    def test_single_in_single_out(self):\n\n        def foo(image):\n            return image * 2\n\n        it = itertools.product([\"image\", [\"image\"]], [None, \"image\", [\"image\"], {\"image\": \"image\"}])\n        for i in it:\n            d = {\"image\": 2}\n            dres = adaptor(foo, i[0], i[1])(d)\n            self.assertEqual(dres[\"image\"], 4)\n\n        d = {\"image\": 2}\n        dres = adaptor(foo, \"image\")(d)\n        self.assertEqual(dres[\"image\"], 4)\n\n        d = {\"image\": 2}\n        dres = adaptor(foo, \"image\", \"image\")(d)\n        self.assertEqual(dres[\"image\"], 4)\n\n        d = {\"image\": 2}\n        dres = adaptor(foo, \"image\", {\"image\": \"image\"})(d)\n        self.assertEqual(dres[\"image\"], 4)\n\n        d = {\"img\": 2}\n        dres = adaptor(foo, \"img\", {\"img\": \"image\"})(d)\n        self.assertEqual(dres[\"img\"], 4)\n\n        d = {\"img\": 2}\n        dres = adaptor(foo, [\"img\"], {\"img\": \"image\"})(d)\n        self.assertEqual(dres[\"img\"], 4)\n\n    
def test_multi_in_single_out(self):\n\n        def foo(image, label):\n            return image * label\n\n        it = itertools.product([\"image\", [\"image\"]], [None, [\"image\", \"label\"], {\"image\": \"image\", \"label\": \"label\"}])\n\n        for i in it:\n            d = {\"image\": 2, \"label\": 3}\n            dres = adaptor(foo, i[0], i[1])(d)\n            self.assertEqual(dres[\"image\"], 6)\n            self.assertEqual(dres[\"label\"], 3)\n\n        it = itertools.product(\n            [\"newimage\", [\"newimage\"]], [None, [\"image\", \"label\"], {\"image\": \"image\", \"label\": \"label\"}]\n        )\n\n        for i in it:\n            d = {\"image\": 2, \"label\": 3}\n            dres = adaptor(foo, i[0], i[1])(d)\n            self.assertEqual(dres[\"image\"], 2)\n            self.assertEqual(dres[\"label\"], 3)\n            self.assertEqual(dres[\"newimage\"], 6)\n\n        it = itertools.product([\"img\", [\"img\"]], [{\"img\": \"image\", \"lbl\": \"label\"}])\n\n        for i in it:\n            d = {\"img\": 2, \"lbl\": 3}\n            dres = adaptor(foo, i[0], i[1])(d)\n            self.assertEqual(dres[\"img\"], 6)\n            self.assertEqual(dres[\"lbl\"], 3)\n\n    def test_default_arg_single_out(self):\n\n        def foo(a, b=2):\n            return a * b\n\n        d = {\"a\": 5}\n        dres = adaptor(foo, \"c\")(d)\n        self.assertEqual(dres[\"c\"], 10)\n\n        d = {\"b\": 5}\n        with self.assertRaises(TypeError):\n            dres = adaptor(foo, \"c\")(d)\n\n    def test_multi_out(self):\n\n        def foo(a, b):\n            return a * b, a / b\n\n        d = {\"a\": 3, \"b\": 4}\n        dres = adaptor(foo, [\"c\", \"d\"])(d)\n        self.assertEqual(dres[\"c\"], 12)\n        self.assertEqual(dres[\"d\"], 3 / 4)\n\n    def test_dict_out(self):\n\n        def foo(a):\n            return {\"a\": a * 2}\n\n        d = {\"a\": 2}\n        dres = adaptor(foo, {\"a\": \"a\"})(d)\n        self.assertEqual(dres[\"a\"], 
4)\n\n        d = {\"b\": 2}\n        dres = adaptor(foo, {\"a\": \"b\"}, {\"b\": \"a\"})(d)\n        self.assertEqual(dres[\"b\"], 4)\n\n\nclass TestApplyAlias(unittest.TestCase):\n\n    def test_apply_alias(self):\n\n        def foo(d):\n            d[\"x\"] *= 2\n            return d\n\n        d = {\"a\": 1, \"b\": 3}\n        result = apply_alias(foo, {\"b\": \"x\"})(d)\n        self.assertDictEqual({\"a\": 1, \"b\": 6}, result)\n\n\nclass TestToKwargs(unittest.TestCase):\n\n    def test_to_kwargs(self):\n\n        def foo(**kwargs):\n            results = {k: v * 2 for k, v in kwargs.items()}\n            return results\n\n        def compose_like(fn, data):\n            data = fn(data)\n            return data\n\n        d = {\"a\": 1, \"b\": 2}\n\n        actual = compose_like(to_kwargs(foo), d)\n        self.assertDictEqual(actual, {\"a\": 2, \"b\": 4})\n\n        with self.assertRaises(TypeError):\n            actual = compose_like(foo, d)\n"
  },
  {
    "path": "tests/transforms/test_add_coordinate_channels.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import AddCoordinateChannels\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS, TEST_CASES_ERROR_1, TEST_CASES_ERROR_2 = [], [], []\nfor p in TEST_NDARRAYS:\n    TESTS.append([{\"spatial_dims\": (0, 1, 2)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (4, 3, 3, 3)])\n    TESTS.append([{\"spatial_dims\": (0,)}, p(np.random.randint(0, 2, size=(1, 3, 3, 3))), (2, 3, 3, 3)])\n    TEST_CASES_ERROR_1.append([{\"spatial_dims\": (2,)}, p(np.random.randint(0, 2, size=(1, 3, 3)))])\n    TEST_CASES_ERROR_2.append([{\"spatial_dims\": (-1, 0, 1)}, p(np.random.randint(0, 2, size=(1, 3, 3)))])\n\n\nclass TestAddCoordinateChannels(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, input_param, input, expected_shape):\n        result = AddCoordinateChannels(**input_param)(input)\n        self.assertEqual(type(result), type(input))\n        if isinstance(result, torch.Tensor):\n            self.assertEqual(result.device, input.device)\n        self.assertEqual(list(result.shape), list(expected_shape))\n        assert_allclose(input[0, ...], result[0, ...])\n\n    @parameterized.expand(TEST_CASES_ERROR_1)\n    def test_max_channel(self, input_param, input):\n        with 
self.assertRaises(ValueError):\n            AddCoordinateChannels(**input_param)(input)\n\n    @parameterized.expand(TEST_CASES_ERROR_2)\n    def test_channel_dim(self, input_param, input):\n        with self.assertRaises(ValueError):\n            AddCoordinateChannels(**input_param)(input)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_add_coordinate_channelsd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import AddCoordinateChannelsd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS, TEST_CASES_ERROR_1, TEST_CASES_ERROR_2 = [], [], []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"spatial_dims\": (0, 1, 2), \"keys\": [\"img\"]},\n            {\"img\": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))},\n            (4, 3, 3, 3),\n        ]\n    )\n    TESTS.append(\n        [{\"spatial_dims\": (0,), \"keys\": [\"img\"]}, {\"img\": p(np.random.randint(0, 2, size=(1, 3, 3, 3)))}, (2, 3, 3, 3)]\n    )\n\n    TEST_CASES_ERROR_1.append(\n        [{\"spatial_dims\": (2,), \"keys\": [\"img\"]}, {\"img\": p(np.random.randint(0, 2, size=(1, 3, 3)))}]\n    )\n    TEST_CASES_ERROR_2.append(\n        [{\"spatial_dims\": (-1, 0, 1), \"keys\": [\"img\"]}, {\"img\": p(np.random.randint(0, 2, size=(1, 3, 3)))}]\n    )\n\n\nclass TestAddCoordinateChannels(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, input_param, input, expected_shape):\n        result = AddCoordinateChannelsd(**input_param)(input)[\"img\"]\n        input = input[\"img\"]\n        self.assertEqual(type(result), type(input))\n        if isinstance(result, torch.Tensor):\n            
self.assertEqual(result.device, input.device)\n        self.assertEqual(result.shape, expected_shape)\n        assert_allclose(input[0, ...], result[0, ...])\n\n    @parameterized.expand(TEST_CASES_ERROR_1)\n    def test_max_channel(self, input_param, input):\n        with self.assertRaises(ValueError):\n            AddCoordinateChannelsd(**input_param)(input)\n\n    @parameterized.expand(TEST_CASES_ERROR_2)\n    def test_channel_dim(self, input_param, input):\n        with self.assertRaises(ValueError):\n            AddCoordinateChannelsd(**input_param)(input)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_add_extreme_points_channel.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import AddExtremePointsChannel\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nIMG_CHANNEL = 3\nTESTS = []\nfor p in TEST_NDARRAYS:\n    for q in TEST_NDARRAYS:\n        TESTS.append(\n            [\n                {\n                    \"img\": p(np.zeros((IMG_CHANNEL, 4, 3))),\n                    \"label\": q(np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])),\n                    \"sigma\": 1.0,\n                    \"rescale_min\": 0.0,\n                    \"rescale_max\": 1.0,\n                },\n                p(\n                    np.array(\n                        [\n                            [0.38318458, 0.98615628, 0.85551184],\n                            [0.35422316, 0.94430935, 1.0],\n                            [0.46000731, 0.57319659, 0.46000722],\n                            [0.64577687, 0.38318464, 0.0],\n                        ]\n                    )\n                ),\n            ]\n        )\n\n        TESTS.append(\n            [\n                {\n                    \"img\": p(np.zeros((IMG_CHANNEL, 4, 3))),\n                    \"label\": q(np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])),\n                    \"sigma\": 1.0,\n                    \"rescale_min\": 
0.0,\n                    \"rescale_max\": 1.0,\n                },\n                p(\n                    np.array(\n                        [\n                            [0.44628328, 0.80495411, 0.44628328],\n                            [0.6779086, 1.0, 0.67790854],\n                            [0.33002687, 0.62079221, 0.33002687],\n                            [0.0, 0.31848389, 0.0],\n                        ]\n                    )\n                ),\n            ]\n        )\n\n\nclass TestAddExtremePointsChannel(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_correct_results(self, input_data, expected):\n        add_extreme_points_channel = AddExtremePointsChannel()\n        result = add_extreme_points_channel(**input_data)\n        assert_allclose(result[IMG_CHANNEL], expected, rtol=1e-4, atol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_add_extreme_points_channeld.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import AddExtremePointsChanneld\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nIMG_CHANNEL = 3\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    for q in TEST_NDARRAYS:\n        TESTS.append(\n            [\n                {\n                    \"img\": p(np.zeros((IMG_CHANNEL, 4, 3))),\n                    \"label\": q(np.array([[[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]]])),\n                },\n                p(\n                    np.array(\n                        [\n                            [0.38318458, 0.98615628, 0.85551184],\n                            [0.35422316, 0.94430935, 1.0],\n                            [0.46000731, 0.57319659, 0.46000722],\n                            [0.64577687, 0.38318464, 0.0],\n                        ]\n                    )\n                ),\n            ]\n        )\n\n        TESTS.append(\n            [\n                {\n                    \"img\": p(np.zeros((IMG_CHANNEL, 4, 3))),\n                    \"label\": q(np.array([[[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]]])),\n                },\n                p(\n                    np.array(\n                        [\n                            [0.44628328, 0.80495411, 0.44628328],\n                            
[0.6779086, 1.0, 0.67790854],\n                            [0.33002687, 0.62079221, 0.33002687],\n                            [0.0, 0.31848389, 0.0],\n                        ]\n                    )\n                ),\n            ]\n        )\n\n\nclass TestAddExtremePointsChanneld(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_correct_results(self, input_data, expected):\n        add_extreme_points_channel = AddExtremePointsChanneld(\n            keys=\"img\", label_key=\"label\", sigma=1.0, rescale_min=0.0, rescale_max=1.0\n        )\n        result = add_extreme_points_channel(input_data)\n        assert_allclose(result[\"img\"][IMG_CHANNEL], expected, rtol=1e-4, atol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_adjust_contrast.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import AdjustContrast\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\nTESTS = []\nfor invert_image in (True, False):\n    for retain_stats in (True, False):\n        TEST_CASE_1 = [1.0, invert_image, retain_stats]\n        TEST_CASE_2 = [0.5, invert_image, retain_stats]\n        TEST_CASE_3 = [4.5, invert_image, retain_stats]\n\n        TESTS.extend([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n\n\nclass TestAdjustContrast(NumpyImageTestCase2D):\n    @parameterized.expand(TESTS)\n    def test_correct_results(self, gamma, invert_image, retain_stats):\n        adjuster = AdjustContrast(gamma=gamma, invert_image=invert_image, retain_stats=retain_stats)\n        for p in TEST_NDARRAYS:\n            im = p(self.imt)\n            result = adjuster(im)\n            self.assertTrue(type(im), type(result))\n            if False:  # gamma == 1.0:\n                expected = self.imt\n            else:\n                if invert_image:\n                    self.imt = -self.imt\n\n                if retain_stats:\n                    mn = self.imt.mean()\n                    sd = self.imt.std()\n\n                epsilon = 1e-7\n                img_min = self.imt.min()\n                img_range = 
self.imt.max() - img_min\n\n                expected = np.power(((self.imt - img_min) / float(img_range + epsilon)), gamma) * img_range + img_min\n\n                if retain_stats:\n                    # zero mean and normalize\n                    expected = expected - expected.mean()\n                    expected = expected / (expected.std() + 1e-8)\n                    # restore old mean and standard deviation\n                    expected = sd * expected + mn\n\n                if invert_image:\n                    expected = -expected\n\n            assert_allclose(result, expected, atol=1e-05, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_adjust_contrastd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import AdjustContrastd\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\nTESTS = []\nfor invert_image in (True, False):\n    for retain_stats in (True, False):\n        TEST_CASE_1 = [1.0, invert_image, retain_stats]\n        TEST_CASE_2 = [0.5, invert_image, retain_stats]\n        TEST_CASE_3 = [4.5, invert_image, retain_stats]\n\n        TESTS.extend([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n\n\nclass TestAdjustContrastd(NumpyImageTestCase2D):\n    @parameterized.expand(TESTS)\n    def test_correct_results(self, gamma, invert_image, retain_stats):\n        adjuster = AdjustContrastd(\"img\", gamma=gamma, invert_image=invert_image, retain_stats=retain_stats)\n        for p in TEST_NDARRAYS:\n            result = adjuster({\"img\": p(self.imt)})\n            if invert_image:\n                self.imt = -self.imt\n\n            if retain_stats:\n                mn = self.imt.mean()\n                sd = self.imt.std()\n\n            epsilon = 1e-7\n            img_min = self.imt.min()\n            img_range = self.imt.max() - img_min\n\n            expected = np.power(((self.imt - img_min) / float(img_range + epsilon)), gamma) * img_range + img_min\n\n            if retain_stats:\n                # 
zero mean and normalize\n                expected = expected - expected.mean()\n                expected = expected / (expected.std() + 1e-8)\n                # restore old mean and standard deviation\n                expected = sd * expected + mn\n\n            if invert_image:\n                expected = -expected\n            assert_allclose(result[\"img\"], expected, atol=1e-05, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_affine.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import Affine, Resize\nfrom monai.transforms.lazy.functional import apply_pending\nfrom monai.utils import optional_import\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose, test_local_inversion\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    for device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n        TESTS.append(\n            [\n                dict(padding_mode=\"zeros\", device=device),\n                {\"img\": p(np.arange(9).reshape((1, 3, 3))), \"spatial_size\": (-1, 0)},\n                p(np.arange(9).reshape(1, 3, 3)),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(padding_mode=\"zeros\", device=device, image_only=True),\n                {\"img\": p(np.arange(9).reshape((1, 3, 3))), \"spatial_size\": (-1, 0)},\n                p(np.arange(9).reshape(1, 3, 3)),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(padding_mode=\"zeros\", device=device),\n                {\"img\": p(np.arange(4).reshape((1, 2, 2)))},\n                
p(np.arange(4).reshape(1, 2, 2)),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(padding_mode=\"zeros\", device=device),\n                {\"img\": p(np.arange(4).reshape((1, 2, 2))), \"spatial_size\": (4, 4)},\n                p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(rotate_params=[np.pi / 2], padding_mode=\"zeros\", device=device),\n                {\"img\": p(np.arange(4).reshape((1, 2, 2))), \"spatial_size\": (4, 4)},\n                p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(rotate_params=[np.pi / 2], padding_mode=\"zeros\", device=device, align_corners=False),\n                {\"img\": p(np.arange(4).reshape((1, 2, 2))), \"spatial_size\": (4, 4)},\n                p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(\n                    affine=p(torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])),\n                    padding_mode=\"zeros\",\n                    device=device,\n                ),\n                {\"img\": p(np.arange(4).reshape((1, 2, 2))), \"spatial_size\": (4, 4)},\n                p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(padding_mode=\"zeros\", device=device),\n                {\"img\": p(np.arange(27).reshape((1, 3, 3, 3))), \"spatial_size\": (-1, 0, 0)},\n                p(np.arange(27).reshape(1, 3, 3, 3)),\n            ]\n        )\n        TESTS.append(\n            [\n                
dict(padding_mode=\"zeros\", device=device),\n                {\"img\": p(np.arange(8).reshape((1, 2, 2, 2))), \"spatial_size\": (4, 4, 4)},\n                p(\n                    np.array(\n                        [\n                            [\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 1.0, 0.0],\n                                    [0.0, 2.0, 3.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 4.0, 5.0, 0.0],\n                                    [0.0, 6.0, 7.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                            ]\n                        ]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(rotate_params=[np.pi / 2], padding_mode=\"zeros\", device=device),\n                {\"img\": p(np.arange(8).reshape((1, 2, 2, 2))), \"spatial_size\": (4, 4, 4)},\n                p(\n                    np.array(\n                        [\n                            [\n                                [\n                           
         [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 2.0, 0.0, 0.0],\n                                    [0.0, 3.0, 1.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 6.0, 4.0, 0.0],\n                                    [0.0, 7.0, 5.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                            ]\n                        ]\n                    )\n                ),\n            ]\n        )\n\n\nclass TestAffine(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_affine(self, input_param, input_data, expected_val):\n        input_copy = deepcopy(input_data[\"img\"])\n        g = Affine(**input_param)\n        result = g(**input_data)\n        output_idx = None\n        if isinstance(result, tuple):\n            output_idx = 0\n            result = result[output_idx]\n\n        test_local_inversion(g, result, input_copy)\n        assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4, type_test=False)\n\n        set_track_meta(False)\n        result = g(**input_data)\n        if isinstance(result, tuple):\n            result = result[0]\n        
self.assertNotIsInstance(result, MetaTensor)\n        self.assertIsInstance(result, torch.Tensor)\n        set_track_meta(True)\n\n        # test lazy\n        lazy_input_param = input_param.copy()\n        for align_corners in [True, False]:\n            lazy_input_param[\"align_corners\"] = align_corners\n            resampler = Affine(**lazy_input_param)\n            non_lazy_result = resampler(**input_data)\n            test_resampler_lazy(\n                resampler, non_lazy_result, lazy_input_param, input_data, output_idx=output_idx, rtol=1e-3, atol=1e-3\n            )\n\n\n@unittest.skipUnless(optional_import(\"scipy\")[1], \"Requires scipy library.\")\nclass TestAffineConsistency(unittest.TestCase):\n    @parameterized.expand([[7], [8], [9]])\n    def test_affine_resize(self, s):\n        \"\"\"s\"\"\"\n        im = np.arange(4).reshape(1, 2, 2).astype(float)\n        mat = np.array([[1 / s, 0, 0], [0, 1 / s, 0], [0, 0, 1]])\n        sp_size = 2 * s\n\n        def method_0(im, ac):\n            xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=sp_size)\n            xform.lazy = True\n            out = xform(im)\n            overrides = {\"padding_mode\": \"border\", \"align_corners\": ac}\n            out = apply_pending(out, overrides=overrides)[0]\n            return out\n\n        def method_1(im, ac):\n            xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=sp_size)\n            xform.lazy = True\n            out = xform(im)\n            overrides = {\"mode\": 1, \"padding_mode\": \"nearest\", \"align_corners\": ac}\n            out = apply_pending(out, overrides=overrides)[0]\n            return out\n\n        def method_2(im, ac):\n            xform = Affine(align_corners=ac, affine=mat, padding_mode=\"border\", image_only=True, spatial_size=sp_size)\n            out = xform(im)\n            return out\n\n        def method_3(im, ac):\n            xform = Affine(\n                
align_corners=ac, affine=mat, mode=1, padding_mode=\"nearest\", image_only=True, spatial_size=sp_size\n            )\n            out = xform(im)\n            return out\n\n        for call in (method_0, method_1, method_2, method_3):\n            for ac in (False, True):\n                out = call(im, ac)\n                ref = Resize(align_corners=ac, spatial_size=(sp_size, sp_size), mode=\"bilinear\")(im)\n                assert_allclose(out, ref, rtol=1e-4, atol=1e-4, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_affine_grid.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import AffineGrid\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose, is_tf32_env\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    for device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n        TESTS.append(\n            [\n                {\"device\": device},\n                {\"spatial_size\": (2, 2)},\n                np.array([[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]),\n            ]\n        )\n\n        TESTS.append([{\"device\": device}, {\"grid\": p(np.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))])\n        TESTS.append([{\"device\": device}, {\"grid\": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))])\n        TESTS.append(\n            [\n                {\"rotate_params\": (1.0, 1.0), \"scale_params\": (-20, 10), \"device\": device},\n                {\"grid\": p(torch.ones((3, 3, 3)))},\n                p(\n                    torch.tensor(\n                        [\n                            [\n                                [-19.2208, -19.2208, -19.2208],\n                                [-19.2208, -19.2208, -19.2208],\n                                [-19.2208, 
-19.2208, -19.2208],\n                            ],\n                            [\n                                [-11.4264, -11.4264, -11.4264],\n                                [-11.4264, -11.4264, -11.4264],\n                                [-11.4264, -11.4264, -11.4264],\n                            ],\n                            [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],\n                        ]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"affine\": p(\n                        torch.tensor(\n                            [[-10.8060, -8.4147, 0.0000], [-16.8294, 5.4030, 0.0000], [0.0000, 0.0000, 1.0000]]\n                        )\n                    )\n                },\n                {\"grid\": p(torch.ones((3, 3, 3)))},\n                p(\n                    torch.tensor(\n                        [\n                            [\n                                [-19.2208, -19.2208, -19.2208],\n                                [-19.2208, -19.2208, -19.2208],\n                                [-19.2208, -19.2208, -19.2208],\n                            ],\n                            [\n                                [-11.4264, -11.4264, -11.4264],\n                                [-11.4264, -11.4264, -11.4264],\n                                [-11.4264, -11.4264, -11.4264],\n                            ],\n                            [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],\n                        ]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                {\"rotate_params\": (1.0, 1.0, 1.0), \"scale_params\": (-20, 10), \"device\": device},\n                {\"grid\": p(torch.ones((4, 3, 3, 3)))},\n                p(\n                    torch.tensor(\n                        [\n                            [\n                                [[-9.5435, 
-9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]],\n                                [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]],\n                                [[-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435], [-9.5435, -9.5435, -9.5435]],\n                            ],\n                            [\n                                [\n                                    [-20.2381, -20.2381, -20.2381],\n                                    [-20.2381, -20.2381, -20.2381],\n                                    [-20.2381, -20.2381, -20.2381],\n                                ],\n                                [\n                                    [-20.2381, -20.2381, -20.2381],\n                                    [-20.2381, -20.2381, -20.2381],\n                                    [-20.2381, -20.2381, -20.2381],\n                                ],\n                                [\n                                    [-20.2381, -20.2381, -20.2381],\n                                    [-20.2381, -20.2381, -20.2381],\n                                    [-20.2381, -20.2381, -20.2381],\n                                ],\n                            ],\n                            [\n                                [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]],\n                                [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]],\n                                [[-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844], [-0.5844, -0.5844, -0.5844]],\n                            ],\n                            [\n                                [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]],\n                                [[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]],\n                                [[1.0000, 1.0000, 1.0000], 
[1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000]],\n                            ],\n                        ]\n                    )\n                ),\n            ]\n        )\n\n_rtol = 5e-2 if is_tf32_env() else 1e-4\n\n\nclass TestAffineGrid(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_affine_grid(self, input_param, input_data, expected_val):\n        g = AffineGrid(**input_param)\n        set_track_meta(False)\n        result, _ = g(**input_data)\n        self.assertNotIsInstance(result, MetaTensor)\n        self.assertIsInstance(result, torch.Tensor)\n        set_track_meta(True)\n        if \"device\" in input_data:\n            self.assertEqual(result.device, input_data[device])\n        assert_allclose(result, expected_val, type_test=False, rtol=_rtol)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_affined.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import Affined\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose, test_local_inversion\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    for device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n        TESTS.append(\n            [\n                dict(keys=\"img\", padding_mode=\"zeros\", spatial_size=(-1, 0), device=device),\n                {\"img\": p(np.arange(9).reshape((1, 3, 3)))},\n                p(np.arange(9).reshape(1, 3, 3)),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(keys=\"img\", padding_mode=\"zeros\", spatial_size=(-1, 0), device=device, dtype=None),\n                {\"img\": p(np.arange(9, dtype=float).reshape((1, 3, 3)))},\n                p(np.arange(9).reshape(1, 3, 3)),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(keys=\"img\", padding_mode=\"zeros\", device=device),\n                {\"img\": p(np.arange(4).reshape((1, 2, 2)))},\n                p(np.arange(4).reshape(1, 2, 2)),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(keys=\"img\", 
padding_mode=\"zeros\", spatial_size=(4, 4), device=device),\n                {\"img\": p(np.arange(4).reshape((1, 2, 2)))},\n                p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(keys=\"img\", rotate_params=[np.pi / 2], padding_mode=\"zeros\", spatial_size=(4, 4), device=device),\n                {\"img\": p(np.arange(4).reshape((1, 2, 2)))},\n                p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(\n                    keys=\"img\",\n                    affine=p(torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])),\n                    padding_mode=\"zeros\",\n                    spatial_size=(4, 4),\n                    device=device,\n                ),\n                {\"img\": p(np.arange(4).reshape((1, 2, 2)))},\n                p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(keys=\"img\", padding_mode=\"zeros\", spatial_size=(-1, 0, 0), device=device),\n                {\"img\": p(np.arange(27).reshape((1, 3, 3, 3)))},\n                p(np.arange(27).reshape(1, 3, 3, 3)),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(keys=\"img\", padding_mode=\"zeros\", spatial_size=(-1, 0, 0), device=device, align_corners=False),\n                {\"img\": p(np.arange(27).reshape((1, 3, 3, 3)))},\n                p(np.arange(27).reshape(1, 3, 3, 3)),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(keys=\"img\", padding_mode=\"zeros\", spatial_size=(4, 4, 4), device=device),\n                {\"img\": p(np.arange(8).reshape((1, 2, 2, 
2)))},\n                p(\n                    np.array(\n                        [\n                            [\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 1.0, 0.0],\n                                    [0.0, 2.0, 3.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 4.0, 5.0, 0.0],\n                                    [0.0, 6.0, 7.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                            ]\n                        ]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(\n                    keys=\"img\", rotate_params=[np.pi / 2], padding_mode=\"zeros\", spatial_size=(4, 4, 4), device=device\n                ),\n                {\"img\": p(np.arange(8).reshape((1, 2, 2, 2)))},\n                p(\n                    np.array(\n                        [\n                            [\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 
0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 2.0, 0.0, 0.0],\n                                    [0.0, 3.0, 1.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 6.0, 4.0, 0.0],\n                                    [0.0, 7.0, 5.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                                [\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                    [0.0, 0.0, 0.0, 0.0],\n                                ],\n                            ]\n                        ]\n                    )\n                ),\n            ]\n        )\n\n\nclass TestAffined(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_affine(self, input_param, input_data, expected_val):\n        input_copy = deepcopy(input_data)\n        g = Affined(**input_param)\n        result = g(input_data)\n        test_local_inversion(g, result, input_copy, dict_key=\"img\")\n        assert_allclose(result[\"img\"], expected_val, rtol=1e-4, atol=1e-4, type_test=\"tensor\")\n\n        # test lazy\n        lazy_input_param = input_param.copy()\n        for align_corners in [True, False]:\n            lazy_input_param[\"align_corners\"] = align_corners\n            resampler = Affined(**lazy_input_param)\n            call_param = {\"data\": input_data}\n            non_lazy_result = resampler(**call_param)\n            
test_resampler_lazy(\n                resampler, non_lazy_result, lazy_input_param, call_param, output_key=\"img\", rtol=1e-3, atol=1e-3\n            )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_as_channel_last.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import AsChannelLast\nfrom tests.test_utils import TEST_NDARRAYS\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([p, {\"channel_dim\": 0}, (2, 3, 4, 1)])\n    TESTS.append([p, {\"channel_dim\": 1}, (1, 3, 4, 2)])\n    TESTS.append([p, {\"channel_dim\": 3}, (1, 2, 3, 4)])\n\n\nclass TestAsChannelLast(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, in_type, input_param, expected_shape):\n        test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4]))\n        result = AsChannelLast(**input_param)(test_data)\n        self.assertTupleEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_as_channel_lastd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import AsChannelLastd\nfrom tests.test_utils import TEST_NDARRAYS\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([p, {\"keys\": [\"image\", \"label\", \"extra\"], \"channel_dim\": 0}, (2, 3, 4, 1)])\n    TESTS.append([p, {\"keys\": [\"image\", \"label\", \"extra\"], \"channel_dim\": 1}, (1, 3, 4, 2)])\n    TESTS.append([p, {\"keys\": [\"image\", \"label\", \"extra\"], \"channel_dim\": 3}, (1, 2, 3, 4)])\n\n\nclass TestAsChannelLastd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, in_type, input_param, expected_shape):\n        test_data = {\n            \"image\": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])),\n            \"label\": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])),\n            \"extra\": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])),\n        }\n        result = AsChannelLastd(**input_param)(test_data)\n        self.assertTupleEqual(result[\"image\"].shape, expected_shape)\n        self.assertTupleEqual(result[\"label\"].shape, expected_shape)\n        self.assertTupleEqual(result[\"extra\"].shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_as_discrete.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import AsDiscrete\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTEST_CASES = []\nfor p in TEST_NDARRAYS:\n    TEST_CASES.append(\n        [\n            {\"argmax\": True, \"to_onehot\": None, \"threshold\": 0.5},\n            p([[[0.0, 1.0]], [[2.0, 3.0]]]),\n            p([[[1.0, 1.0]]]),\n            (1, 1, 2),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"argmax\": True, \"to_onehot\": 2, \"threshold\": 0.5, \"dim\": 0},\n            p([[[0.0, 1.0]], [[2.0, 3.0]]]),\n            p([[[0.0, 0.0]], [[1.0, 1.0]]]),\n            (2, 1, 2),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"argmax\": False, \"to_onehot\": None, \"threshold\": 0.6},\n            p([[[0.0, 1.0], [2.0, 3.0]]]),\n            p([[[0.0, 1.0], [1.0, 1.0]]]),\n            (1, 2, 2),\n        ]\n    )\n\n    # test threshold = 0.0\n    TEST_CASES.append(\n        [\n            {\"argmax\": False, \"to_onehot\": None, \"threshold\": 0.0},\n            p([[[0.0, -1.0], [-2.0, 3.0]]]),\n            p([[[1.0, 0.0], [0.0, 1.0]]]),\n            (1, 2, 2),\n        ]\n    )\n\n    TEST_CASES.append([{\"argmax\": False, \"to_onehot\": 3}, p(1), p([0.0, 1.0, 0.0]), (3,)])\n\n    TEST_CASES.append(\n        [{\"rounding\": 
\"torchrounding\"}, p([[[0.123, 1.345], [2.567, 3.789]]]), p([[[0.0, 1.0], [3.0, 4.0]]]), (1, 2, 2)]\n    )\n\n\nclass TestAsDiscrete(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_value_shape(self, input_param, img, out, expected_shape):\n        result = AsDiscrete(**input_param)(img)\n        assert_allclose(result, out, rtol=1e-3, type_test=\"tensor\")\n        self.assertTupleEqual(result.shape, expected_shape)\n\n    def test_additional(self):\n        for p in TEST_NDARRAYS:\n            out = AsDiscrete(argmax=True, dim=1, keepdim=False)(p([[[0.0, 1.0]], [[2.0, 3.0]]]))\n            assert_allclose(out, p([[0.0, 0.0], [0.0, 0.0]]), type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_as_discreted.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import AsDiscreted\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTEST_CASES = []\nfor p in TEST_NDARRAYS:\n    TEST_CASES.append(\n        [\n            {\"keys\": [\"pred\", \"label\"], \"argmax\": [True, False], \"to_onehot\": 2, \"threshold\": 0.5},\n            {\"pred\": p([[[0.0, 1.0]], [[2.0, 3.0]]]), \"label\": p([[[0, 1]]])},\n            {\"pred\": p([[[0.0, 0.0]], [[1.0, 1.0]]]), \"label\": p([[[1.0, 0.0]], [[0.0, 1.0]]])},\n            (2, 1, 2),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"keys\": [\"pred\", \"label\"], \"argmax\": False, \"to_onehot\": None, \"threshold\": [0.6, None]},\n            {\"pred\": p([[[0.0, 1.0], [2.0, 3.0]]]), \"label\": p([[[0, 1], [1, 1]]])},\n            {\"pred\": p([[[0.0, 1.0], [1.0, 1.0]]]), \"label\": p([[[0.0, 1.0], [1.0, 1.0]]])},\n            (1, 2, 2),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"keys\": [\"pred\"], \"argmax\": True, \"to_onehot\": 2, \"threshold\": 0.5, \"dim\": 0, \"keepdim\": True},\n            {\"pred\": p([[[0.0, 1.0]], [[2.0, 3.0]]])},\n            {\"pred\": p([[[0.0, 0.0]], [[1.0, 1.0]]])},\n            (2, 1, 2),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"keys\": \"pred\", 
\"rounding\": \"torchrounding\"},\n            {\"pred\": p([[[0.123, 1.345], [2.567, 3.789]]])},\n            {\"pred\": p([[[0.0, 1.0], [3.0, 4.0]]])},\n            (1, 2, 2),\n        ]\n    )\n\n    # test threshold = 0.0\n    TEST_CASES.append(\n        [\n            {\"keys\": [\"pred\", \"label\"], \"argmax\": False, \"to_onehot\": None, \"threshold\": [0.0, None]},\n            {\"pred\": p([[[0.0, -1.0], [-2.0, 3.0]]]), \"label\": p([[[0, 1], [1, 1]]])},\n            {\"pred\": p([[[1.0, 0.0], [0.0, 1.0]]]), \"label\": p([[[0.0, 1.0], [1.0, 1.0]]])},\n            (1, 2, 2),\n        ]\n    )\n\n\nclass TestAsDiscreted(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_value_shape(self, input_param, test_input, output, expected_shape):\n        result = AsDiscreted(**input_param)(test_input)\n        assert_allclose(result[\"pred\"], output[\"pred\"], rtol=1e-3, type_test=\"tensor\")\n        self.assertTupleEqual(result[\"pred\"].shape, expected_shape)\n        if \"label\" in result:\n            assert_allclose(result[\"label\"], output[\"label\"], rtol=1e-3, type_test=\"tensor\")\n            self.assertTupleEqual(result[\"label\"].shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_border_pad.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import BorderPad\nfrom monai.utils.enums import NumpyPadMode, PytorchPadMode\nfrom tests.padders import PadTest\n\nTESTS = [\n    [{\"spatial_border\": 2}, (3, 8, 8, 4), (3, 12, 12, 8)],\n    [{\"spatial_border\": [1, 2, 3]}, (3, 8, 8, 4), (3, 10, 12, 10)],\n    [{\"spatial_border\": [1, 2, 3, 4, 5, 6]}, (3, 8, 8, 4), (3, 11, 15, 15)],\n    [{\"spatial_border\": [1, 2, 3, 4, 5, 6]}, (3, 8, 8, 4), (3, 11, 15, 15)],\n]\n\n\nclass TestBorderPad(PadTest):\n    Padder = BorderPad\n\n    @parameterized.expand(TESTS)\n    def test_pad(self, input_param, input_shape, expected_shape):\n        modes = [\"constant\", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT]\n        self.pad_test(input_param, input_shape, expected_shape, modes)\n\n    def test_pad_kwargs(self):\n        kwargs = {\"spatial_border\": 2, \"mode\": \"constant\"}\n        unchanged_slices = [slice(None), slice(2, -2), slice(2, -2)]\n        self.pad_test_kwargs(unchanged_slices, **kwargs)\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, input_param, input_shape, _):\n        self.pad_test_pending_ops(input_param, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_border_padd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import BorderPadd\nfrom monai.utils.enums import NumpyPadMode, PytorchPadMode\nfrom tests.padders import PadTest\n\nTESTS = [\n    [{\"keys\": \"img\", \"spatial_border\": 2}, (3, 8, 8, 4), (3, 12, 12, 8)],\n    [{\"keys\": \"img\", \"spatial_border\": [1, 2, 3]}, (3, 8, 8, 4), (3, 10, 12, 10)],\n    [{\"keys\": \"img\", \"spatial_border\": [1, 2, 3, 4, 5, 6]}, (3, 8, 8, 4), (3, 11, 15, 15)],\n    [{\"keys\": \"img\", \"spatial_border\": 2}, (3, 8, 8, 4), (3, 12, 12, 8)],\n    [{\"keys\": \"img\", \"spatial_border\": 2}, (3, 8, 8, 4), (3, 12, 12, 8)],\n]\n\n\nclass TestBorderPadd(PadTest):\n    Padder = BorderPadd\n\n    @parameterized.expand(TESTS)\n    def test_pad(self, input_param, input_shape, expected_shape):\n        modes = [\"constant\", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT, \"edge\", NumpyPadMode.EDGE]\n        self.pad_test(input_param, input_shape, expected_shape, modes)\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, input_param, input_shape, _):\n        self.pad_test_pending_ops(input_param, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_bounding_rect.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nimport monai\nfrom monai.transforms import BoundingRect\nfrom tests.test_utils import TEST_NDARRAYS\n\nTEST_CASE_1 = [(2, 3), [[0, 0], [1, 2]]]\n\nTEST_CASE_2 = [(1, 8, 10), [[0, 7, 1, 9]]]\n\nTEST_CASE_3 = [(2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]]\n\n\nclass TestBoundingRect(unittest.TestCase):\n    def setUp(self):\n        monai.utils.set_determinism(1)\n\n    def tearDown(self):\n        monai.utils.set_determinism(None)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_shape(self, input_shape, expected):\n        test_data = np.random.randint(0, 8, size=input_shape)\n        test_data = test_data == 7\n        for p in TEST_NDARRAYS:\n            result = BoundingRect()(p(test_data))\n            np.testing.assert_allclose(result, expected)\n\n    def test_select_fn(self):\n        test_data = np.random.randint(0, 8, size=(2, 3))\n        test_data = test_data == 7\n        for p in TEST_NDARRAYS:\n            bbox = BoundingRect(select_fn=lambda x: x < 1)(p(test_data))\n            np.testing.assert_allclose(bbox, [[0, 3], [0, 3]])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_bounding_rectd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nimport monai\nfrom monai.transforms import BoundingRectD\nfrom tests.test_utils import TEST_NDARRAYS\n\nTEST_CASE_1 = [(2, 3), [[0, 0], [1, 2]]]\n\nTEST_CASE_2 = [(1, 8, 10), [[0, 7, 1, 9]]]\n\nTEST_CASE_3 = [(2, 16, 20, 18), [[0, 16, 0, 20, 0, 18], [0, 16, 0, 20, 0, 18]]]\n\n\nclass TestBoundingRectD(unittest.TestCase):\n    def setUp(self):\n        monai.utils.set_determinism(1)\n\n    def tearDown(self):\n        monai.utils.set_determinism(None)\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_shape(self, input_shape, expected):\n        test_data = np.random.randint(0, 8, size=input_shape)\n        test_data = test_data == 7\n        for p in TEST_NDARRAYS:\n            result = BoundingRectD(\"image\")({\"image\": p(test_data)})\n            np.testing.assert_allclose(result[\"image_bbox\"], expected)\n\n            result = BoundingRectD(\"image\", \"cc\")({\"image\": p(test_data)})\n            np.testing.assert_allclose(result[\"image_cc\"], expected)\n\n            with self.assertRaises(KeyError):\n                BoundingRectD(\"image\", \"cc\")({\"image\": p(test_data), \"image_cc\": None})\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_cast_to_type.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import CastToType\nfrom monai.utils import optional_import\nfrom monai.utils.type_conversion import get_equivalent_dtype\nfrom tests.test_utils import HAS_CUPY, TEST_NDARRAYS\n\ncp, _ = optional_import(\"cupy\")\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    for out_dtype in (np.float64, torch.float64):\n        TESTS.append([out_dtype, p(np.array([[0, 1], [1, 2]], dtype=np.float32)), out_dtype])\n\nTESTS_CUPY = [\n    [np.float32, np.array([[0, 1], [1, 2]], dtype=np.float32), np.float32],\n    [np.float32, np.array([[0, 1], [1, 2]], dtype=np.uint8), np.float32],\n    [np.uint8, np.array([[0, 1], [1, 2]], dtype=np.float32), np.uint8],\n]\n\n\nclass TestCastToType(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_type(self, out_dtype, input_data, expected_type):\n        result = CastToType(dtype=out_dtype)(input_data)\n        self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result)))\n\n        result = CastToType()(input_data, out_dtype)\n        self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result)))\n\n    @parameterized.expand(TESTS_CUPY)\n    @unittest.skipUnless(HAS_CUPY, \"Requires CuPy\")\n    def test_type_cupy(self, out_dtype, input_data, 
expected_type):\n        input_data = cp.asarray(input_data)\n\n        result = CastToType(dtype=out_dtype)(input_data)\n        self.assertTrue(isinstance(result, cp.ndarray))\n        self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result)))\n\n        result = CastToType()(input_data, out_dtype)\n        self.assertTrue(isinstance(result, cp.ndarray))\n        self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result)))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_cast_to_typed.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import CastToTyped\nfrom monai.utils import optional_import\nfrom tests.test_utils import HAS_CUPY\n\ncp, _ = optional_import(\"cupy\")\n\nTEST_CASE_1 = [\n    {\"keys\": [\"img\"], \"dtype\": np.float64},\n    {\"img\": np.array([[0, 1], [1, 2]], dtype=np.float32), \"seg\": np.array([[0, 1], [1, 2]], dtype=np.int8)},\n    {\"img\": np.float64, \"seg\": np.int8},\n]\n\nTEST_CASE_2 = [\n    {\"keys\": [\"img\"], \"dtype\": torch.float64},\n    {\n        \"img\": torch.tensor([[0, 1], [1, 2]], dtype=torch.float32),\n        \"seg\": torch.tensor([[0, 1], [1, 2]], dtype=torch.int8),\n    },\n    {\"img\": torch.float64, \"seg\": torch.int8},\n]\n\nTESTS_CUPY = [\n    [\n        {\"keys\": \"image\", \"dtype\": np.uint8},\n        {\"image\": np.array([[0, 1], [1, 2]], dtype=np.float32), \"label\": np.array([[0, 1], [1, 1]], dtype=np.float32)},\n        {\"image\": np.uint8, \"label\": np.float32},\n    ],\n    [\n        {\"keys\": [\"image\", \"label\"], \"dtype\": np.float32},\n        {\"image\": np.array([[0, 1], [1, 2]], dtype=np.uint8), \"label\": np.array([[0, 1], [1, 1]], dtype=np.uint8)},\n        {\"image\": np.float32, \"label\": np.float32},\n    ],\n]\n\n\nclass TestCastToTyped(unittest.TestCase):\n    
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_type(self, input_param, input_data, expected_type):\n        result = CastToTyped(**input_param)(input_data)\n        for k, v in result.items():\n            self.assertEqual(v.dtype, expected_type[k])\n\n    @parameterized.expand(TESTS_CUPY)\n    @unittest.skipUnless(HAS_CUPY, \"Requires CuPy\")\n    def test_type_cupy(self, input_param, input_data, expected_type):\n        input_data = {k: cp.asarray(v) for k, v in input_data.items()}\n\n        result = CastToTyped(**input_param)(input_data)\n        for k, v in result.items():\n            self.assertTrue(isinstance(v, cp.ndarray))\n            self.assertEqual(v.dtype, expected_type[k])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_center_scale_crop.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import CenterScaleCrop\nfrom tests.croppers import CropTest\n\nTEST_SHAPES = [\n    [{\"roi_scale\": [0.6, 0.3, -1]}, (3, 3, 3, 3), (3, 2, 1, 3), False],\n    [{\"roi_scale\": [0.6, 0.3, -1]}, (3, 3, 4, 3), (3, 2, 2, 3), True],\n    [{\"roi_scale\": 0.6}, (3, 3, 3, 3), (3, 2, 2, 2), True],\n    [{\"roi_scale\": 0.5}, (3, 3, 3, 3), (3, 2, 2, 2), True],\n]\n\nTEST_VALUES = [\n    [\n        {\"roi_scale\": [0.4, 0.4]},\n        np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),\n        np.array([[[1, 2], [2, 3]]]),\n    ]\n]\n\n\nclass TestCenterScaleCrop(CropTest):\n    Cropper = CenterScaleCrop\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_shape(self, input_param, input_shape, expected_shape, _):\n        self.crop_test(input_param, input_shape, expected_shape)\n\n    @parameterized.expand(TEST_VALUES)\n    def test_value(self, input_param, input_arr, expected_arr):\n        self.crop_test_value(input_param, input_arr, expected_arr)\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_pending_ops(self, input_param, input_shape, _, align_corners):\n        self.crop_test_pending_ops(input_param, input_shape, align_corners)\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_center_scale_cropd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import CenterScaleCropd\nfrom tests.croppers import CropTest\n\nTESTS = [\n    [{\"keys\": \"img\", \"roi_scale\": [0.6, 0.3, -1]}, (3, 3, 3, 3), (3, 2, 1, 3), False],\n    [{\"keys\": \"img\", \"roi_scale\": [0.6, 0.3, -1]}, (3, 3, 1, 3), (3, 2, 1, 3), True],\n    [{\"keys\": \"img\", \"roi_scale\": [0.6, 0.3, -1]}, (3, 3, 4, 3), (3, 2, 2, 3), True],\n    [{\"keys\": \"img\", \"roi_scale\": 0.6}, (3, 3, 3, 3), (3, 2, 2, 2), True],\n    [{\"keys\": \"img\", \"roi_scale\": 0.5}, (3, 3, 3, 3), (3, 2, 2, 2), True],\n]\n\nTEST_VALUES = [\n    [\n        {\"keys\": \"img\", \"roi_scale\": [0.4, 0.4]},\n        np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),\n        np.array([[[1, 2], [2, 3]]]),\n    ]\n]\n\n\nclass TestCenterScaleCropd(CropTest):\n    Cropper = CenterScaleCropd\n\n    @parameterized.expand(TESTS)\n    def test_shape(self, input_param, input_shape, expected_shape, _):\n        self.crop_test(input_param, input_shape, expected_shape)\n\n    @parameterized.expand(TEST_VALUES)\n    def test_value(self, input_param, input_arr, expected_arr):\n        self.crop_test_value(input_param, input_arr, expected_arr)\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, 
input_param, input_shape, _, align_corners):\n        self.crop_test_pending_ops(input_param, input_shape, align_corners)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_center_spatial_crop.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import CenterSpatialCrop\nfrom tests.croppers import CropTest\n\nTEST_SHAPES = [\n    [{\"roi_size\": [2, 2, -1]}, (3, 3, 3, 3), (3, 2, 2, 3), True],\n    [{\"roi_size\": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2), True],\n    [{\"roi_size\": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2), True],\n    [{\"roi_size\": [2, 1, 2]}, (3, 3, 3, 3), (3, 2, 1, 2), False],\n    [{\"roi_size\": [2, 1, 3]}, (3, 3, 1, 3), (3, 2, 1, 3), True],\n]\n\nTEST_VALUES = [\n    [\n        {\"roi_size\": [2, 2]},\n        np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),\n        np.array([[[1, 2], [2, 3]]]),\n    ]\n]\n\n\nclass TestCenterSpatialCrop(CropTest):\n    Cropper = CenterSpatialCrop\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_shape(self, input_param, input_shape, expected_shape, _):\n        self.crop_test(input_param, input_shape, expected_shape)\n\n    @parameterized.expand(TEST_VALUES)\n    def test_value(self, input_param, input_arr, expected_arr):\n        self.crop_test_value(input_param, input_arr, expected_arr)\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_pending_ops(self, input_param, input_shape, _, align_corners):\n        self.crop_test_pending_ops(input_param, 
input_shape, align_corners)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_center_spatial_cropd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import CenterSpatialCropd\nfrom tests.croppers import CropTest\n\nTEST_SHAPES = [\n    [\n        {\"keys\": \"img\", \"roi_size\": [2, -1, -1]},\n        (3, 3, 3, 3),\n        (3, 2, 3, 3),\n        (slice(None), slice(None, -1), slice(None), slice(None)),\n    ],\n    [\n        {\"keys\": \"img\", \"roi_size\": [2, 2, 2]},\n        (3, 3, 3, 3),\n        (3, 2, 2, 2),\n        (slice(None), slice(None, -1), slice(None, -1), slice(None, -1)),\n    ],\n]\n\nTEST_CASES = [\n    [\n        {\"keys\": \"img\", \"roi_size\": [2, 2]},\n        np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),\n        np.array([[[1, 2], [2, 3]]]),\n    ]\n]\n\n\nclass TestCenterSpatialCropd(CropTest):\n    Cropper = CenterSpatialCropd\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_shape(self, input_param, input_shape, expected_shape, same_area):\n        self.crop_test(input_param, input_shape, expected_shape, same_area)\n\n    @parameterized.expand(TEST_CASES)\n    def test_value(self, input_param, input_data, expected_value):\n        self.crop_test_value(input_param, input_data, expected_value)\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_pending_ops(self, input_param, input_shape, 
_expected_shape, _same_area):\n        self.crop_test_pending_ops(input_param, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_classes_to_indices.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import ClassesToIndices\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS_CASES = []\nfor p in TEST_NDARRAYS:\n    TESTS_CASES.append(\n        [\n            # test Argmax data\n            {\"num_classes\": 3, \"image_threshold\": 0.0},\n            p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]),\n            None,\n            [p([0, 4, 8]), p([1, 5, 6]), p([2, 3, 7])],\n        ]\n    )\n\n    TESTS_CASES.append(\n        [\n            {\"num_classes\": 3, \"image_threshold\": 60},\n            p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]),\n            p([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]),\n            [p([0, 8]), p([1, 5, 6]), p([3])],\n        ]\n    )\n\n    TESTS_CASES.append(\n        [\n            # test One-Hot data\n            {\"image_threshold\": 0.0},\n            p(\n                [\n                    [[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n                    [[0, 1, 0], [0, 0, 1], [1, 0, 0]],\n                    [[0, 0, 1], [1, 0, 0], [0, 1, 0]],\n                ]\n            ),\n            None,\n            [p([0, 4, 8]), p([1, 5, 6]), p([2, 3, 7])],\n        ]\n    )\n\n    TESTS_CASES.append(\n        [\n            {\"num_classes\": None, \"image_threshold\": 60},\n            p(\n                [\n  
                  [[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n                    [[0, 1, 0], [0, 0, 1], [1, 0, 0]],\n                    [[0, 0, 1], [1, 0, 0], [0, 1, 0]],\n                ]\n            ),\n            p([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]),\n            [p([0, 8]), p([1, 5, 6]), p([3])],\n        ]\n    )\n\n    TESTS_CASES.append(\n        [\n            # test output_shape\n            {\"num_classes\": 3, \"image_threshold\": 0.0, \"output_shape\": [3, 3]},\n            p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]),\n            None,\n            [p([[0, 0], [1, 1], [2, 2]]), p([[0, 1], [1, 2], [2, 0]]), p([[0, 2], [1, 0], [2, 1]])],\n        ]\n    )\n\n\nclass TestClassesToIndices(unittest.TestCase):\n    @parameterized.expand(TESTS_CASES)\n    def test_value(self, input_args, label, image, expected_indices):\n        indices = ClassesToIndices(**input_args)(label, image)\n        for i, e in zip(indices, expected_indices):\n            assert_allclose(i, e)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_classes_to_indicesd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import ClassesToIndicesd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS_CASES = []\nfor p in TEST_NDARRAYS:\n    TESTS_CASES.append(\n        [\n            # test Argmax data\n            {\"keys\": \"label\", \"num_classes\": 3, \"image_threshold\": 0.0},\n            {\"label\": p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])},\n            [p([0, 4, 8]), p([1, 5, 6]), p([2, 3, 7])],\n        ]\n    )\n\n    TESTS_CASES.append(\n        [\n            {\"keys\": \"label\", \"image_key\": \"image\", \"num_classes\": 3, \"image_threshold\": 60},\n            {\n                \"label\": p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]),\n                \"image\": p([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]),\n            },\n            [p([0, 8]), p([1, 5, 6]), p([3])],\n        ]\n    )\n\n    TESTS_CASES.append(\n        [\n            # test One-Hot data\n            {\"keys\": \"label\", \"image_threshold\": 0.0},\n            {\n                \"label\": p(\n                    [\n                        [[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n                        [[0, 1, 0], [0, 0, 1], [1, 0, 0]],\n                        [[0, 0, 1], [1, 0, 0], [0, 1, 0]],\n                    ]\n                )\n            },\n            
[p([0, 4, 8]), p([1, 5, 6]), p([2, 3, 7])],\n        ]\n    )\n\n    TESTS_CASES.append(\n        [\n            {\n                \"keys\": \"label\",\n                \"image_key\": \"image\",\n                \"num_classes\": None,\n                \"image_threshold\": 60,\n                \"max_samples_per_class\": 2,\n            },\n            {\n                \"label\": p(\n                    [\n                        [[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n                        [[0, 1, 0], [0, 0, 1], [1, 0, 0]],\n                        [[0, 0, 1], [1, 0, 0], [0, 1, 0]],\n                    ]\n                ),\n                \"image\": p([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]),\n            },\n            [p([0, 8]), p([1, 6]), p([3])],\n        ]\n    )\n\n    TESTS_CASES.append(\n        [\n            # test output_shape\n            {\n                \"keys\": \"label\",\n                \"indices_postfix\": \"cls\",\n                \"num_classes\": 3,\n                \"image_threshold\": 0.0,\n                \"output_shape\": [3, 3],\n            },\n            {\"label\": p([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])},\n            [p([[0, 0], [1, 1], [2, 2]]), p([[0, 1], [1, 2], [2, 0]]), p([[0, 2], [1, 0], [2, 1]])],\n        ]\n    )\n\n\nclass TestClassesToIndicesd(unittest.TestCase):\n    @parameterized.expand(TESTS_CASES)\n    def test_value(self, input_args, input_data, expected_indices):\n        result = ClassesToIndicesd(**input_args)(input_data)\n        key_postfix = input_args.get(\"indices_postfix\")\n        key_postfix = \"_cls_indices\" if key_postfix is None else key_postfix\n        for i, e in zip(result[\"label\" + key_postfix], expected_indices):\n            assert_allclose(i, e)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_clip_intensity_percentiles.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import ClipIntensityPercentiles\nfrom monai.transforms.utils import soft_clip\nfrom monai.transforms.utils_pytorch_numpy_unification import clip, percentile\nfrom monai.utils.type_conversion import convert_to_tensor\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose\n\n\ndef test_hard_clip_func(im, lower, upper):\n    im_t = convert_to_tensor(im)\n    if lower is None:\n        upper = percentile(im_t, upper)\n    elif upper is None:\n        lower = percentile(im_t, lower)\n    else:\n        lower, upper = percentile(im_t, (lower, upper))\n    return clip(im_t, lower, upper)\n\n\ndef test_soft_clip_func(im, lower, upper):\n    im_t = convert_to_tensor(im)\n    if lower is None:\n        upper = percentile(im_t, upper)\n    elif upper is None:\n        lower = percentile(im_t, lower)\n    else:\n        lower, upper = percentile(im_t, (lower, upper))\n    return soft_clip(im_t, minv=lower, maxv=upper, sharpness_factor=1.0, dtype=torch.float32)\n\n\nclass TestClipIntensityPercentiles2D(NumpyImageTestCase2D):\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_hard_clipping_two_sided(self, p):\n        hard_clipper = ClipIntensityPercentiles(upper=95, lower=5)\n        
im = p(self.imt)\n        result = hard_clipper(im)\n        expected = test_hard_clip_func(im, 5, 95)\n        assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_hard_clipping_one_sided_high(self, p):\n        hard_clipper = ClipIntensityPercentiles(upper=95, lower=None)\n        im = p(self.imt)\n        result = hard_clipper(im)\n        expected = test_hard_clip_func(im, 0, 95)\n        assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_hard_clipping_one_sided_low(self, p):\n        hard_clipper = ClipIntensityPercentiles(upper=None, lower=5)\n        im = p(self.imt)\n        result = hard_clipper(im)\n        expected = test_hard_clip_func(im, 5, 100)\n        assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_soft_clipping_two_sided(self, p):\n        soft_clipper = ClipIntensityPercentiles(upper=95, lower=5, sharpness_factor=1.0)\n        im = p(self.imt)\n        result = soft_clipper(im)\n        expected = test_soft_clip_func(im, 5, 95)\n        # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy\n        assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_soft_clipping_one_sided_high(self, p):\n        soft_clipper = ClipIntensityPercentiles(upper=95, lower=None, sharpness_factor=1.0)\n        im = p(self.imt)\n        result = soft_clipper(im)\n        expected = test_soft_clip_func(im, None, 95)\n        # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy\n        assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4, 
atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_soft_clipping_one_sided_low(self, p):\n        soft_clipper = ClipIntensityPercentiles(upper=None, lower=5, sharpness_factor=1.0)\n        im = p(self.imt)\n        result = soft_clipper(im)\n        expected = test_soft_clip_func(im, 5, None)\n        # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy\n        assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_channel_wise(self, p):\n        clipper = ClipIntensityPercentiles(upper=95, lower=5, channel_wise=True)\n        im = p(self.imt)\n        result = clipper(im)\n        im_t = convert_to_tensor(self.imt)\n        for i, c in enumerate(im_t):\n            lower, upper = percentile(c, (5, 95))\n            expected = clip(c, lower, upper)\n            assert_allclose(result[i], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    def test_ill_sharpness_factor(self):\n        with self.assertRaises(ValueError):\n            ClipIntensityPercentiles(upper=95, lower=5, sharpness_factor=0.0)\n\n    def test_ill_lower_percentile(self):\n        with self.assertRaises(ValueError):\n            ClipIntensityPercentiles(upper=None, lower=-1)\n\n    def test_ill_upper_percentile(self):\n        with self.assertRaises(ValueError):\n            ClipIntensityPercentiles(upper=101, lower=None)\n\n    def test_ill_percentiles(self):\n        with self.assertRaises(ValueError):\n            ClipIntensityPercentiles(upper=95, lower=96)\n\n    def test_ill_both_none(self):\n        with self.assertRaises(ValueError):\n            ClipIntensityPercentiles(upper=None, lower=None)\n\n\nclass TestClipIntensityPercentiles3D(NumpyImageTestCase3D):\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_hard_clipping_two_sided(self, p):\n        hard_clipper = 
ClipIntensityPercentiles(upper=95, lower=5)\n        im = p(self.imt)\n        result = hard_clipper(im)\n        expected = test_hard_clip_func(im, 5, 95)\n        assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_hard_clipping_one_sided_high(self, p):\n        hard_clipper = ClipIntensityPercentiles(upper=95, lower=None)\n        im = p(self.imt)\n        result = hard_clipper(im)\n        expected = test_hard_clip_func(im, 0, 95)\n        assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_hard_clipping_one_sided_low(self, p):\n        hard_clipper = ClipIntensityPercentiles(upper=None, lower=5)\n        im = p(self.imt)\n        result = hard_clipper(im)\n        expected = test_hard_clip_func(im, 5, 100)\n        assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_soft_clipping_two_sided(self, p):\n        soft_clipper = ClipIntensityPercentiles(upper=95, lower=5, sharpness_factor=1.0)\n        im = p(self.imt)\n        result = soft_clipper(im)\n        expected = test_soft_clip_func(im, 5, 95)\n        # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy\n        assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_soft_clipping_one_sided_high(self, p):\n        soft_clipper = ClipIntensityPercentiles(upper=95, lower=None, sharpness_factor=1.0)\n        im = p(self.imt)\n        result = soft_clipper(im)\n        expected = test_soft_clip_func(im, None, 95)\n        # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy\n        assert_allclose(result, 
p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_soft_clipping_one_sided_low(self, p):\n        soft_clipper = ClipIntensityPercentiles(upper=None, lower=5, sharpness_factor=1.0)\n        im = p(self.imt)\n        result = soft_clipper(im)\n        expected = test_soft_clip_func(im, 5, None)\n        # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy\n        assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_channel_wise(self, p):\n        clipper = ClipIntensityPercentiles(upper=95, lower=5, channel_wise=True)\n        im = p(self.imt)\n        result = clipper(im)\n        im_t = convert_to_tensor(self.imt)\n        for i, c in enumerate(im_t):\n            lower, upper = percentile(c, (5, 95))\n            expected = clip(c, lower, upper)\n            assert_allclose(result[i], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_clip_intensity_percentilesd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import ClipIntensityPercentilesd\nfrom monai.transforms.utils_pytorch_numpy_unification import clip, percentile\nfrom monai.utils.type_conversion import convert_to_tensor\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose\nfrom tests.transforms.test_clip_intensity_percentiles import test_hard_clip_func, test_soft_clip_func\n\n\nclass TestClipIntensityPercentilesd2D(NumpyImageTestCase2D):\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_hard_clipping_two_sided(self, p):\n        key = \"img\"\n        hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5)\n        im = p(self.imt)\n        result = hard_clipper({key: im})\n        expected = test_hard_clip_func(im, 5, 95)\n        assert_allclose(result[key], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_hard_clipping_one_sided_high(self, p):\n        key = \"img\"\n        hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None)\n        im = p(self.imt)\n        result = hard_clipper({key: im})\n        expected = test_hard_clip_func(im, 0, 95)\n        assert_allclose(result[key], p(expected), 
type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_hard_clipping_one_sided_low(self, p):\n        key = \"img\"\n        hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5)\n        im = p(self.imt)\n        result = hard_clipper({key: im})\n        expected = test_hard_clip_func(im, 5, 100)\n        assert_allclose(result[key], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_soft_clipping_two_sided(self, p):\n        key = \"img\"\n        soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, sharpness_factor=1.0)\n        im = p(self.imt)\n        result = soft_clipper({key: im})\n        expected = test_soft_clip_func(im, 5, 95)\n        # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy\n        assert_allclose(result[key], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_soft_clipping_one_sided_high(self, p):\n        key = \"img\"\n        soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None, sharpness_factor=1.0)\n        im = p(self.imt)\n        result = soft_clipper({key: im})\n        expected = test_soft_clip_func(im, None, 95)\n        # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy\n        assert_allclose(result[key], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_soft_clipping_one_sided_low(self, p):\n        key = \"img\"\n        soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5, sharpness_factor=1.0)\n        im = p(self.imt)\n        result = soft_clipper({key: im})\n        expected = test_soft_clip_func(im, 5, None)\n        # the rtol is set 
to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy\n        assert_allclose(result[key], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_channel_wise(self, p):\n        key = \"img\"\n        clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, channel_wise=True)\n        im = p(self.imt)\n        result = clipper({key: im})\n        im_t = convert_to_tensor(self.imt)\n        for i, c in enumerate(im_t):\n            lower, upper = percentile(c, (5, 95))\n            expected = clip(c, lower, upper)\n            assert_allclose(result[key][i], p(expected), type_test=\"tensor\", rtol=1e-3, atol=0)\n\n    def test_ill_sharpness_factor(self):\n        key = \"img\"\n        with self.assertRaises(ValueError):\n            ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, sharpness_factor=0.0)\n\n    def test_ill_lower_percentile(self):\n        key = \"img\"\n        with self.assertRaises(ValueError):\n            ClipIntensityPercentilesd(keys=[key], upper=None, lower=-1)\n\n    def test_ill_upper_percentile(self):\n        key = \"img\"\n        with self.assertRaises(ValueError):\n            ClipIntensityPercentilesd(keys=[key], upper=101, lower=None)\n\n    def test_ill_percentiles(self):\n        key = \"img\"\n        with self.assertRaises(ValueError):\n            ClipIntensityPercentilesd(keys=[key], upper=95, lower=96)\n\n    def test_ill_both_none(self):\n        key = \"img\"\n        with self.assertRaises(ValueError):\n            ClipIntensityPercentilesd(keys=[key], upper=None, lower=None)\n\n\nclass TestClipIntensityPercentilesd3D(NumpyImageTestCase3D):\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_hard_clipping_two_sided(self, p):\n        key = \"img\"\n        hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5)\n        im = p(self.imt)\n        result = 
hard_clipper({key: im})\n        expected = test_hard_clip_func(im, 5, 95)\n        assert_allclose(result[key], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_hard_clipping_one_sided_high(self, p):\n        key = \"img\"\n        hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None)\n        im = p(self.imt)\n        result = hard_clipper({key: im})\n        expected = test_hard_clip_func(im, 0, 95)\n        assert_allclose(result[key], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_hard_clipping_one_sided_low(self, p):\n        key = \"img\"\n        hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5)\n        im = p(self.imt)\n        result = hard_clipper({key: im})\n        expected = test_hard_clip_func(im, 5, 100)\n        assert_allclose(result[key], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_soft_clipping_two_sided(self, p):\n        key = \"img\"\n        soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, sharpness_factor=1.0)\n        im = p(self.imt)\n        result = soft_clipper({key: im})\n        expected = test_soft_clip_func(im, 5, 95)\n        # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable accross torch and numpy\n        assert_allclose(result[key], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_soft_clipping_one_sided_high(self, p):\n        key = \"img\"\n        soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None, sharpness_factor=1.0)\n        im = p(self.imt)\n        result = soft_clipper({key: im})\n        expected = test_soft_clip_func(im, None, 95)\n        # the rtol is set to 1e-4 because 
the logaddexp function used in softplus is not stable across torch and numpy\n        assert_allclose(result[key], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_soft_clipping_one_sided_low(self, p):\n        key = \"img\"\n        soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5, sharpness_factor=1.0)\n        im = p(self.imt)\n        result = soft_clipper({key: im})\n        expected = test_soft_clip_func(im, 5, None)\n        # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy\n        assert_allclose(result[key], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_channel_wise(self, p):\n        key = \"img\"\n        clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, channel_wise=True)\n        im = p(self.imt)\n        result = clipper({key: im})\n        im_t = convert_to_tensor(im)\n        for i, c in enumerate(im_t):\n            lower, upper = percentile(c, (5, 95))\n            expected = clip(c, lower, upper)\n            assert_allclose(result[key][i], p(expected), type_test=\"tensor\", rtol=1e-4, atol=0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_compose_get_number_conversions.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import Compose\nfrom monai.transforms.compose import OneOf\nfrom monai.transforms.transform import Transform\nfrom monai.transforms.utils import get_number_image_type_conversions\nfrom monai.utils import convert_to_numpy, convert_to_tensor\n\nNP_ARR = np.ones((10, 10, 10))\nPT_ARR = torch.as_tensor(NP_ARR)\nKEY = \"IMAGE\"\n\n\ndef _apply(x, fn):\n    if isinstance(x, dict):\n        d = deepcopy(x)\n        d[KEY] = fn(d[KEY])\n        return d\n    return fn(x)\n\n\nclass Load(Transform):\n\n    def __init__(self, as_tensor):\n        self.fn = lambda _: PT_ARR if as_tensor else NP_ARR\n\n    def __call__(self, x):\n        return _apply(x, self.fn)\n\n\nclass N(Transform):\n\n    def __call__(self, x):\n        return _apply(x, convert_to_numpy)\n\n\nclass T(Transform):\n\n    def __call__(self, x):\n        return _apply(x, convert_to_tensor)\n\n\nclass NT(Transform):\n\n    def __call__(self, x):\n        return _apply(x, lambda x: x)\n\n\nclass TCPU(Transform):\n\n    def __call__(self, x):\n        return _apply(x, lambda x: convert_to_tensor(x).cpu())\n\n\nclass TGPU(Transform):\n\n    def __call__(self, x):\n        return _apply(x, lambda x: convert_to_tensor(x).cuda())\n\n\nTESTS: 
list[tuple] = []\nfor is_dict in (False, True):\n    # same type depends on input\n    TESTS.append(((N(), N()), is_dict, NP_ARR, 0))\n    TESTS.append(((N(), N()), is_dict, PT_ARR, 1))\n    TESTS.append(((T(), T()), is_dict, NP_ARR, 1))\n    TESTS.append(((T(), T()), is_dict, PT_ARR, 0))\n\n    # loading depends on loader's output type and following transform\n    TESTS.append(((Load(as_tensor=False), N()), is_dict, \"fname.nii\", 0))\n    TESTS.append(((Load(as_tensor=True), N()), is_dict, \"fname.nii\", 1))\n    TESTS.append(((Load(as_tensor=False), T()), is_dict, \"fname.nii\", 1))\n    TESTS.append(((Load(as_tensor=True), T()), is_dict, \"fname.nii\", 0))\n    TESTS.append(((Load(as_tensor=True), NT()), is_dict, \"fname.nii\", 0))\n    TESTS.append(((Load(as_tensor=True), NT()), is_dict, \"fname.nii\", 0))\n\n    # no changes for ambivalent transforms\n    TESTS.append(((NT(), NT()), is_dict, NP_ARR, 0))\n    TESTS.append(((NT(), NT()), is_dict, PT_ARR, 0))\n\n    # multiple conversions\n    TESTS.append(((N(), T(), N()), is_dict, PT_ARR, 3))\n    TESTS.append(((N(), NT(), T(), T(), NT(), NT(), N()), is_dict, PT_ARR, 3))\n\n    # shouldn't matter that there are nested composes\n    TESTS.append(((N(), NT(), T(), Compose([T(), NT(), NT(), N()])), is_dict, PT_ARR, 3))\n\n    # changing device also counts\n    if torch.cuda.is_available():\n        TESTS.append(((TCPU(), TGPU(), TCPU()), is_dict, PT_ARR, 2))\n\n\nclass TestComposeNumConversions(unittest.TestCase):\n\n    @parameterized.expand(TESTS)\n    def test_get_number_of_conversions(self, transforms, is_dict, input, expected):\n        input = input if not is_dict else {KEY: input, \"Other\": NP_ARR}\n        tr = Compose(transforms)\n        n = get_number_image_type_conversions(tr, input, key=KEY if is_dict else None)\n        self.assertEqual(n, expected)\n\n    def test_raises(self):\n        tr = Compose([N(), OneOf([T(), T()])])\n        with self.assertRaises(RuntimeError):\n            
get_number_image_type_conversions(tr, NP_ARR)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_concat_itemsd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.data import MetaTensor\nfrom monai.transforms import ConcatItemsd\nfrom tests.test_utils import assert_allclose\n\n\nclass TestConcatItemsd(unittest.TestCase):\n    def test_tensor_values(self):\n        device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu:0\")\n        input_data = {\n            \"img1\": torch.tensor([[0, 1], [1, 2]], device=device),\n            \"img2\": torch.tensor([[0, 1], [1, 2]], device=device),\n        }\n        result = ConcatItemsd(keys=[\"img1\", \"img2\"], name=\"cat_img\")(input_data)\n        self.assertIn(\"cat_img\", result)\n        result[\"cat_img\"] += 1\n        assert_allclose(result[\"img1\"], torch.tensor([[0, 1], [1, 2]], device=device))\n        assert_allclose(result[\"cat_img\"], torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device))\n\n    def test_metatensor_values(self):\n        device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu:0\")\n        input_data = {\n            \"img1\": MetaTensor([[0, 1], [1, 2]], device=device),\n            \"img2\": MetaTensor([[0, 1], [1, 2]], device=device),\n        }\n        result = ConcatItemsd(keys=[\"img1\", \"img2\"], name=\"cat_img\")(input_data)\n        self.assertIn(\"cat_img\", result)\n        
self.assertIsInstance(result[\"cat_img\"], MetaTensor)\n        self.assertEqual(result[\"img1\"].meta, result[\"cat_img\"].meta)\n        result[\"cat_img\"] += 1\n        assert_allclose(result[\"img1\"], torch.tensor([[0, 1], [1, 2]], device=device))\n        assert_allclose(result[\"cat_img\"], torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device))\n\n    def test_numpy_values(self):\n        input_data = {\"img1\": np.array([[0, 1], [1, 2]]), \"img2\": np.array([[0, 1], [1, 2]])}\n        result = ConcatItemsd(keys=[\"img1\", \"img2\"], name=\"cat_img\")(input_data)\n        self.assertIn(\"cat_img\", result)\n        result[\"cat_img\"] += 1\n        np.testing.assert_allclose(result[\"img1\"], np.array([[0, 1], [1, 2]]))\n        np.testing.assert_allclose(result[\"cat_img\"], np.array([[1, 2], [2, 3], [1, 2], [2, 3]]))\n\n    def test_single_numpy(self):\n        input_data = {\"img\": np.array([[0, 1], [1, 2]])}\n        result = ConcatItemsd(keys=\"img\", name=\"cat_img\")(input_data)\n        result[\"cat_img\"] += 1\n        np.testing.assert_allclose(result[\"img\"], np.array([[0, 1], [1, 2]]))\n        np.testing.assert_allclose(result[\"cat_img\"], np.array([[1, 2], [2, 3]]))\n\n    def test_single_tensor(self):\n        input_data = {\"img\": torch.tensor([[0, 1], [1, 2]])}\n        result = ConcatItemsd(keys=\"img\", name=\"cat_img\")(input_data)\n        result[\"cat_img\"] += 1\n        assert_allclose(result[\"img\"], torch.tensor([[0, 1], [1, 2]]))\n        assert_allclose(result[\"cat_img\"], torch.tensor([[1, 2], [2, 3]]))\n\n    def test_single_metatensor(self):\n        input_data = {\"img\": MetaTensor([[0, 1], [1, 2]])}\n        result = ConcatItemsd(keys=\"img\", name=\"cat_img\")(input_data)\n        result[\"cat_img\"] += 1\n        assert_allclose(result[\"img\"], torch.tensor([[0, 1], [1, 2]]))\n        assert_allclose(result[\"cat_img\"], torch.tensor([[1, 2], [2, 3]]))\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_convert_to_multi_channel.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import ConvertToMultiChannelBasedOnBratsClasses\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.extend(\n        [\n            [\n                p([[0, 1, 2], [1, 2, 4], [0, 1, 4]]),\n                p(\n                    [\n                        [[0, 1, 0], [1, 0, 1], [0, 1, 1]],\n                        [[0, 1, 1], [1, 1, 1], [0, 1, 1]],\n                        [[0, 0, 0], [0, 0, 1], [0, 0, 1]],\n                    ]\n                ),\n            ],\n            [\n                p([[[[0, 1], [1, 2]], [[2, 4], [4, 4]]]]),\n                p(\n                    [\n                        [[[0, 1], [1, 0]], [[0, 1], [1, 1]]],\n                        [[[0, 1], [1, 1]], [[1, 1], [1, 1]]],\n                        [[[0, 0], [0, 0]], [[0, 1], [1, 1]]],\n                    ]\n                ),\n            ],\n        ]\n    )\n\n\nclass TestConvertToMultiChannel(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_type_shape(self, data, expected_result):\n        result = ConvertToMultiChannelBasedOnBratsClasses()(data)\n        assert_allclose(result, expected_result)\n        self.assertTrue(result.dtype in (bool, torch.bool))\n\n\nif __name__ == 
\"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_convert_to_multi_channeld.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import ConvertToMultiChannelBasedOnBratsClassesd\n\nTEST_CASE = [\n    {\"keys\": \"label\"},\n    {\"label\": np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]])},\n    np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]),\n]\n\n\nclass TestConvertToMultiChanneld(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE])\n    def test_type_shape(self, keys, data, expected_result):\n        result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data)\n        np.testing.assert_equal(result[\"label\"], expected_result)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_copy_itemsd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks import eval_mode\nfrom monai.transforms import CopyItemsd\nfrom monai.utils import ensure_tuple\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_1 = [\"img\", 1, \"img_1\"]\n\nTEST_CASE_2 = [[\"img\", \"seg\"], 1, [\"img_1\", \"seg_1\"]]\n\nTEST_CASE_3 = [\"img\", 2, [\"img_1\", \"img_2\"]]\n\nTEST_CASE_4 = [[\"img\", \"seg\"], 2, [\"img_1\", \"seg_1\", \"img_2\", \"seg_2\"]]\n\n\nclass TestCopyItemsd(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_numpy_values(self, keys, times, names):\n        input_data = {\"img\": np.array([[0, 1], [1, 2]]), \"seg\": np.array([[3, 4], [4, 5]])}\n        result = CopyItemsd(keys=keys, times=times, names=names)(input_data)\n        for name in ensure_tuple(names):\n            self.assertTrue(name in result)\n        result[\"img_1\"] += 1\n        np.testing.assert_allclose(result[\"img_1\"], np.array([[1, 2], [2, 3]]))\n        np.testing.assert_allclose(result[\"img\"], np.array([[0, 1], [1, 2]]))\n\n    def test_default_names(self):\n        input_data = {\"img\": np.array([[0, 1], [1, 2]]), \"seg\": np.array([[3, 4], [4, 5]])}\n        result = CopyItemsd(keys=[\"img\", \"seg\"], times=2, 
names=None)(input_data)\n        for name in [\"img_0\", \"seg_0\", \"img_1\", \"seg_1\"]:\n            self.assertTrue(name in result)\n\n    def test_tensor_values(self):\n        device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu:0\")\n        input_data = {\n            \"img\": torch.tensor([[0, 1], [1, 2]], device=device),\n            \"seg\": torch.tensor([[0, 1], [1, 2]], device=device),\n        }\n        # test default `times=1`\n        result = CopyItemsd(keys=\"img\", names=\"img_1\")(input_data)\n        self.assertTrue(\"img_1\" in result)\n        result[\"img_1\"] += 1\n        assert_allclose(result[\"img\"], torch.tensor([[0, 1], [1, 2]], device=device))\n        assert_allclose(result[\"img_1\"], torch.tensor([[1, 2], [2, 3]], device=device))\n\n    def test_array_values(self):\n        input_data = {\"img\": [[0, 1], [1, 2]], \"seg\": [[0, 1], [1, 2]]}\n        result = CopyItemsd(keys=\"img\", times=1, names=\"img_1\")(input_data)\n        self.assertTrue(\"img_1\" in result)\n        result[\"img_1\"][0][0] += 1\n        np.testing.assert_allclose(result[\"img\"], [[0, 1], [1, 2]])\n        np.testing.assert_allclose(result[\"img_1\"], [[1, 1], [1, 2]])\n\n    def test_graph_tensor_values(self):\n        device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu:0\")\n        net = torch.nn.PReLU().to(device)\n        with eval_mode(net):\n            pred = net(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device))\n        input_data = {\"pred\": pred, \"seg\": torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device)}\n        result = CopyItemsd(keys=\"pred\", times=1, names=\"pred_1\")(input_data)\n        self.assertTrue(\"pred_1\" in result)\n        result[\"pred_1\"] += 1.0\n        assert_allclose(result[\"pred\"], torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device))\n        assert_allclose(result[\"pred_1\"], torch.tensor([[1.0, 2.0], [2.0, 3.0]], 
device=device))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_create_grid_and_affine.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms import (\n    create_control_grid,\n    create_grid,\n    create_rotate,\n    create_scale,\n    create_shear,\n    create_translate,\n)\nfrom tests.test_utils import assert_allclose, is_tf32_env\n\n\nclass TestCreateGrid(unittest.TestCase):\n    def test_create_grid(self):\n        with self.assertRaisesRegex(TypeError, \"\"):\n            create_grid(None)\n        with self.assertRaisesRegex(TypeError, \"\"):\n            create_grid((1, 1), spacing=2.0)\n        with self.assertRaisesRegex(TypeError, \"\"):\n            create_grid((1, 1), spacing=2.0)\n\n        test_assert(create_grid, ((1, 1),), np.array([[[0.0]], [[0.0]], [[1.0]]]))\n\n        test_assert(create_grid, ((1, 1), None, False), np.array([[[0.0]], [[0.0]]]))\n\n        test_assert(create_grid, ((1, 1), (1.2, 1.3)), np.array([[[0.0]], [[0.0]], [[1.0]]]))\n\n        test_assert(create_grid, ((1, 1, 1), (1.2, 1.3, 1.0)), np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[1.0]]]]))\n\n        test_assert(create_grid, ((1, 1, 1), (1.2, 1.3, 1.0), False), np.array([[[[0.0]]], [[[0.0]]], [[[0.0]]]]))\n\n        g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0), dtype=np.int32)\n        np.testing.assert_equal(g.dtype, np.int32)\n\n        g = create_grid((1, 1, 1), spacing=(1.2, 1.3, 1.0), 
dtype=torch.float64, backend=\"torch\")\n        np.testing.assert_equal(g.dtype, torch.float64)\n\n        test_assert(\n            create_grid,\n            ((2, 2, 2),),\n            np.array(\n                [\n                    [[[-0.5, -0.5], [-0.5, -0.5]], [[0.5, 0.5], [0.5, 0.5]]],\n                    [[[-0.5, -0.5], [0.5, 0.5]], [[-0.5, -0.5], [0.5, 0.5]]],\n                    [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]],\n                    [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]],\n                ]\n            ),\n        )\n\n        test_assert(\n            create_grid,\n            ((2, 2, 2), (1.2, 1.3, 1.0)),\n            np.array(\n                [\n                    [[[-0.6, -0.6], [-0.6, -0.6]], [[0.6, 0.6], [0.6, 0.6]]],\n                    [[[-0.65, -0.65], [0.65, 0.65]], [[-0.65, -0.65], [0.65, 0.65]]],\n                    [[[-0.5, 0.5], [-0.5, 0.5]], [[-0.5, 0.5], [-0.5, 0.5]]],\n                    [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]],\n                ]\n            ),\n        )\n\n    def test_create_control_grid(self):\n        with self.assertRaisesRegex(TypeError, \"\"):\n            create_control_grid(None, None)\n        with self.assertRaisesRegex(TypeError, \"\"):\n            create_control_grid((1, 1), 2.0)\n\n        test_assert(\n            create_control_grid,\n            ((1.0, 1.0), (1.0, 1.0)),\n            np.array(\n                [\n                    [[-1.0, -1.0, -1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],\n                    [[-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0]],\n                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],\n                ]\n            ),\n        )\n\n        test_assert(\n            create_control_grid,\n            ((1.0, 1.0), (2.0, 2.0)),\n            np.array(\n                [\n                    [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]],\n                    [[-2.0, 0.0, 2.0], [-2.0, 0.0, 
2.0], [-2.0, 0.0, 2.0]],\n                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],\n                ]\n            ),\n        )\n\n        test_assert(\n            create_control_grid,\n            ((2.0, 2.0), (1.0, 1.0)),\n            np.array(\n                [\n                    [[-1.5, -1.5, -1.5, -1.5], [-0.5, -0.5, -0.5, -0.5], [0.5, 0.5, 0.5, 0.5], [1.5, 1.5, 1.5, 1.5]],\n                    [[-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5], [-1.5, -0.5, 0.5, 1.5]],\n                    [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]],\n                ]\n            ),\n        )\n\n        test_assert(\n            create_control_grid,\n            ((2.0, 2.0), (2.0, 2.0)),\n            np.array(\n                [\n                    [[-3.0, -3.0, -3.0, -3.0], [-1.0, -1.0, -1.0, -1.0], [1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0]],\n                    [[-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0], [-3.0, -1.0, 1.0, 3.0]],\n                    [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]],\n                ]\n            ),\n        )\n\n        test_assert(\n            create_control_grid,\n            ((1.0, 1.0, 1.0), (2.0, 2.0, 2.0), False),\n            np.array(\n                [\n                    [\n                        [[-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0], [-2.0, -2.0, -2.0]],\n                        [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],\n                        [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],\n                    ],\n                    [\n                        [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]],\n                        [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]],\n                        [[-2.0, -2.0, -2.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]],\n                    ],\n                    [\n                        
[[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]],\n                        [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]],\n                        [[-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0], [-2.0, 0.0, 2.0]],\n                    ],\n                ]\n            ),\n        )\n\n\ndef test_assert(func, params, expected):\n    gpu_test = (\"torch_gpu\",) if torch.cuda.is_available() else ()\n    for b in (\"torch\", \"numpy\") + gpu_test:\n        if b == \"torch_gpu\":\n            m = func(*params, device=\"cuda:0\", backend=\"torch\")\n        else:\n            m = func(*params, backend=b)\n        assert_allclose(m, expected, type_test=False, rtol=1e-2 if is_tf32_env() else 1e-5, atol=1e-5)\n\n\nclass TestCreateAffine(unittest.TestCase):\n    def test_create_rotate(self):\n        with self.assertRaisesRegex(TypeError, \"\"):\n            create_rotate(2, None)\n\n        with self.assertRaisesRegex(ValueError, \"\"):\n            create_rotate(5, 1)\n\n        test_assert(\n            create_rotate,\n            (2, 1.1),\n            np.array([[0.45359612, -0.89120736, 0.0], [0.89120736, 0.45359612, 0.0], [0.0, 0.0, 1.0]]),\n        )\n        test_assert(\n            create_rotate,\n            (3, 1.1),\n            np.array(\n                [\n                    [1.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.45359612, -0.89120736, 0.0],\n                    [0.0, 0.89120736, 0.45359612, 0.0],\n                    [0.0, 0.0, 0.0, 1.0],\n                ]\n            ),\n        )\n        test_assert(\n            create_rotate,\n            (3, (1.1, 1)),\n            np.array(\n                [\n                    [0.54030231, 0.0, 0.84147098, 0.0],\n                    [0.74992513, 0.45359612, -0.48152139, 0.0],\n                    [-0.38168798, 0.89120736, 0.24507903, 0.0],\n                    [0.0, 0.0, 0.0, 1.0],\n                ]\n            ),\n        )\n        test_assert(\n            create_rotate,\n            (3, (1, 
1, 1.1)),\n            np.array(\n                [\n                    [0.24507903, -0.48152139, 0.84147098, 0.0],\n                    [0.80270075, -0.38596121, -0.45464871, 0.0],\n                    [0.54369824, 0.78687425, 0.29192658, 0.0],\n                    [0.0, 0.0, 0.0, 1.0],\n                ]\n            ),\n        )\n        test_assert(\n            create_rotate,\n            (3, (0, 0, np.pi / 2)),\n            np.array([[0.0, -1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),\n        )\n\n    def test_create_shear(self):\n        test_assert(create_shear, (2, 1.0), np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))\n        test_assert(create_shear, (2, (2.0, 3.0)), np.array([[1.0, 2.0, 0.0], [3.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))\n        test_assert(\n            create_shear,\n            (3, 1.0),\n            np.array([[1.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),\n        )\n\n    def test_create_scale(self):\n        test_assert(create_scale, (2, 2), np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))\n        test_assert(create_scale, (2, [2, 2, 2]), np.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]]))\n        test_assert(\n            create_scale,\n            (3, [1.5, 2.4]),\n            np.array([[1.5, 0.0, 0.0, 0.0], [0.0, 2.4, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),\n        )\n        test_assert(\n            create_scale,\n            (3, 1.5),\n            np.array([[1.5, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),\n        )\n        test_assert(\n            create_scale,\n            (3, [1, 2, 3, 4, 5]),\n            np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 3.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),\n        )\n\n    def test_create_translate(self):\n        test_assert(create_translate, (2, 2), np.array([[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], 
[0.0, 0.0, 1.0]]))\n        test_assert(create_translate, (2, [2, 2, 2]), np.array([[1.0, 0.0, 2.0], [0.0, 1.0, 2.0], [0.0, 0.0, 1.0]]))\n        test_assert(\n            create_translate,\n            (3, [1.5, 2.4]),\n            np.array([[1.0, 0.0, 0.0, 1.5], [0.0, 1.0, 0.0, 2.4], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),\n        )\n        test_assert(\n            create_translate,\n            (3, 1.5),\n            np.array([[1.0, 0.0, 0.0, 1.5], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),\n        )\n        test_assert(\n            create_translate,\n            (3, [1, 2, 3, 4, 5]),\n            np.array([[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 2.0], [0.0, 0.0, 1.0, 3.0], [0.0, 0.0, 0.0, 1.0]]),\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_crop_foreground.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.config import USE_COMPILED\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import CropForeground\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTEST_COORDS, TESTS, TEST_LAZY_ERROR = [], [], []\n\nfor p in TEST_NDARRAYS_ALL:\n    TEST_COORDS.append(\n        [\n            {\"select_fn\": lambda x: x > 0, \"channel_indices\": None, \"margin\": 0},\n            p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),\n            p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]),\n            True,\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"select_fn\": lambda x: x > 1, \"channel_indices\": None, \"margin\": 0},\n            p([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]]),\n            p([[[3]]]),\n            False,\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"select_fn\": lambda x: x > 0, \"channel_indices\": 0, \"margin\": 0},\n            p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),\n            p([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]]),\n            True,\n        ]\n    )\n\n    
TESTS.append(\n        [\n            {\"select_fn\": lambda x: x > 0, \"channel_indices\": None, \"margin\": 1},\n            p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),\n            p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]),\n            True,\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"select_fn\": lambda x: x > 0, \"channel_indices\": None, \"margin\": [2, 1], \"allow_smaller\": True},\n            p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),\n            p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),\n            True,\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"select_fn\": lambda x: x > 0, \"channel_indices\": None, \"margin\": [2, 1], \"allow_smaller\": False},\n            p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),\n            p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),\n            True,\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"select_fn\": lambda x: x > 0, \"channel_indices\": None, \"margin\": 0, \"k_divisible\": 4},\n            p([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),\n            p([[[1, 2, 1, 0], [2, 3, 2, 0], [1, 2, 1, 0], [0, 0, 0, 0]]]),\n            True,\n        ]\n    )\n\n    TEST_LAZY_ERROR.append(\n        [\n            {\"select_fn\": lambda x: x > 0, \"channel_indices\": None, \"margin\": 0, \"k_divisible\": 10},\n            p([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),\n            p(np.zeros((1, 0, 0), dtype=np.int64)),\n            True,\n        ]\n    )\n\n\nclass TestCropForeground(unittest.TestCase):\n    @parameterized.expand(TEST_COORDS + TESTS)\n    def test_value(self, arguments, 
image, expected_data, _):\n        cropper = CropForeground(**arguments)\n        result = cropper(image)\n        assert_allclose(result, expected_data, type_test=False)\n        self.assertIsInstance(result, MetaTensor)\n        self.assertEqual(len(result.applied_operations), 1)\n        inv = cropper.inverse(result)\n        self.assertIsInstance(inv, MetaTensor)\n        self.assertEqual(inv.applied_operations, [])\n        self.assertTupleEqual(inv.shape, image.shape)\n\n    @parameterized.expand(TEST_COORDS)\n    def test_return_coords(self, arguments, image, _expected_data, _align_corners):\n        arguments[\"return_coords\"] = True\n        _, start_coord, end_coord = CropForeground(**arguments)(image)\n        arguments[\"return_coords\"] = False\n        np.testing.assert_allclose(start_coord, np.asarray([1, 1]))\n        np.testing.assert_allclose(end_coord, np.asarray([4, 4]))\n\n    @parameterized.expand(TEST_COORDS + TESTS)\n    def test_pending_ops(self, input_param, image, _expected_data, align_corners):\n        crop_fn = CropForeground(**input_param)\n        # non-lazy\n        expected = crop_fn(image)\n        self.assertIsInstance(expected, MetaTensor)\n        # lazy\n        crop_fn.lazy = True\n        pending_result = crop_fn(image)\n        self.assertIsInstance(pending_result, MetaTensor)\n        assert_allclose(pending_result.peek_pending_affine(), expected.affine)\n        assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])\n        # only support nearest\n        overrides = {\"mode\": \"nearest\", \"align_corners\": align_corners}\n        result = apply_pending(pending_result, overrides=overrides)[0]\n        # compare\n        assert_allclose(result, expected, rtol=1e-5)\n\n    @parameterized.expand(TEST_LAZY_ERROR)\n    @unittest.skipIf(USE_COMPILED, \"skip errors whe use compiled\")\n    def test_lazy_error(self, input_param, image, _expected_data, align_corners):\n        with 
self.assertRaises(ValueError):\n            crop_fn = CropForeground(**input_param)\n            # lazy\n            crop_fn.lazy = True\n            pending_result = crop_fn(image)\n            overrides = {\"mode\": \"nearest\", \"align_corners\": align_corners}\n            return apply_pending(pending_result, overrides=overrides)[0]\n\n    @parameterized.expand(TEST_COORDS + TESTS)\n    def test_inverse_pending_ops(self, input_param, image, _expected_data, align_corners):\n        crop_fn = CropForeground(**input_param)\n        crop_fn.lazy = True\n        pending_result = crop_fn(image)\n        self.assertIsInstance(pending_result, MetaTensor)\n        result = apply_pending(pending_result, overrides={\"mode\": \"nearest\", \"align_corners\": align_corners})[0]\n        inverted = crop_fn.inverse(result)\n        self.assertEqual(image.shape, inverted.shape)\n        self.assertTrue((not inverted.applied_operations) and (not inverted.pending_operations))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_crop_foregroundd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import CropForegroundd\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTEST_POSITION, TESTS = [], []\nfor p in TEST_NDARRAYS_ALL:\n    TEST_POSITION.append(\n        [\n            {\n                \"keys\": [\"img\", \"label\"],\n                \"source_key\": \"label\",\n                \"select_fn\": lambda x: x > 0,\n                \"channel_indices\": None,\n                \"margin\": 0,\n            },\n            {\n                \"img\": p(\n                    np.array([[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]])\n                ),\n                \"label\": p(\n                    np.array([[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]])\n                ),\n            },\n            p(np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]])),\n            True,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"keys\": [\"img\"], \"source_key\": \"img\", \"select_fn\": lambda x: x > 1, \"channel_indices\": None, \"margin\": 0},\n            {\n                \"img\": p(\n                    
np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]])\n                )\n            },\n            p(np.array([[[3]]])),\n            False,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"keys\": [\"img\"], \"source_key\": \"img\", \"select_fn\": lambda x: x > 0, \"channel_indices\": 0, \"margin\": 0},\n            {\n                \"img\": p(\n                    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])\n                )\n            },\n            p(np.array([[[1, 2, 1], [2, 3, 2], [1, 2, 1]]])),\n            True,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"keys\": [\"img\"], \"source_key\": \"img\", \"select_fn\": lambda x: x > 0, \"channel_indices\": None, \"margin\": 1},\n            {\n                \"img\": p(\n                    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])\n                )\n            },\n            p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]])),\n            True,\n        ]\n    )\n    TESTS.append(\n        [\n            {\n                \"keys\": [\"img\"],\n                \"source_key\": \"img\",\n                \"select_fn\": lambda x: x > 0,\n                \"channel_indices\": None,\n                \"margin\": [2, 1],\n                \"allow_smaller\": True,\n            },\n            {\n                \"img\": p(\n                    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])\n                )\n            },\n            p(np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])),\n            True,\n        ]\n    )\n    TESTS.append(\n        [\n            {\n                \"keys\": [\"img\"],\n                \"source_key\": \"img\",\n                \"select_fn\": lambda x: x > 
0,\n                \"channel_indices\": None,\n                \"margin\": [2, 1],\n                \"allow_smaller\": False,\n            },\n            {\n                \"img\": p(\n                    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])\n                )\n            },\n            p(\n                np.array(\n                    [\n                        [\n                            [0, 0, 0, 0, 0],\n                            [0, 0, 0, 0, 0],\n                            [0, 1, 2, 1, 0],\n                            [0, 2, 3, 2, 0],\n                            [0, 1, 2, 1, 0],\n                            [0, 0, 0, 0, 0],\n                            [0, 0, 0, 0, 0],\n                        ]\n                    ]\n                )\n            ),\n            True,\n        ]\n    )\n    TESTS.append(\n        [\n            {\n                \"keys\": [\"img\"],\n                \"source_key\": \"img\",\n                \"select_fn\": lambda x: x > 0,\n                \"channel_indices\": 0,\n                \"margin\": 0,\n                \"k_divisible\": [4, 6],\n                \"mode\": \"constant\",\n            },\n            {\n                \"img\": p(\n                    np.array(\n                        [[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]],\n                        dtype=np.float32,\n                    )\n                )\n            },\n            p(np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 0], [2, 2, 3, 2, 2, 0], [1, 1, 2, 1, 1, 0]]])),\n            False,\n        ]\n    )\n\n\nclass TestCropForegroundd(unittest.TestCase):\n    @parameterized.expand(TEST_POSITION + TESTS)\n    def test_value(self, arguments, input_data, expected_data, _):\n        cropper = CropForegroundd(**arguments)\n        result = cropper(input_data)\n        assert_allclose(result[\"img\"], expected_data, type_test=\"tensor\")\n   
     if \"label\" in input_data and \"img\" in input_data:\n            self.assertTupleEqual(result[\"img\"].shape, result[\"label\"].shape)\n        inv = cropper.inverse(result)\n        self.assertTupleEqual(inv[\"img\"].shape, input_data[\"img\"].shape)\n        if \"label\" in input_data:\n            self.assertTupleEqual(inv[\"label\"].shape, input_data[\"label\"].shape)\n\n    @parameterized.expand(TEST_POSITION)\n    def test_foreground_position(self, arguments, input_data, _expected_data, _align_corners):\n        result = CropForegroundd(**arguments)(input_data)\n        np.testing.assert_allclose(result[\"foreground_start_coord\"], np.array([1, 1]))\n        np.testing.assert_allclose(result[\"foreground_end_coord\"], np.array([4, 4]))\n\n        arguments[\"start_coord_key\"] = \"test_start_coord\"\n        arguments[\"end_coord_key\"] = \"test_end_coord\"\n        result = CropForegroundd(**arguments)(input_data)\n        np.testing.assert_allclose(result[\"test_start_coord\"], np.array([1, 1]))\n        np.testing.assert_allclose(result[\"test_end_coord\"], np.array([4, 4]))\n\n    @parameterized.expand(TEST_POSITION + TESTS)\n    def test_pending_ops(self, input_param, image, _expected_data, align_corners):\n        crop_fn = CropForegroundd(**input_param)\n        # non-lazy\n        expected = crop_fn(image)[\"img\"]\n        self.assertIsInstance(expected, MetaTensor)\n        # lazy\n        crop_fn.lazy = True\n        pending_result = crop_fn(image)[\"img\"]\n        self.assertIsInstance(pending_result, MetaTensor)\n        assert_allclose(pending_result.peek_pending_affine(), expected.affine)\n        assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])\n        # only support nearest\n        overrides = {\"mode\": \"nearest\", \"align_corners\": align_corners}\n        result = apply_pending(pending_result, overrides=overrides)[0]\n        # compare\n        assert_allclose(result, expected, rtol=1e-5)\n\n\nif __name__ 
== \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_cucim_dict_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import CuCIMd\nfrom monai.utils import optional_import, set_determinism\nfrom tests.test_utils import HAS_CUPY, skip_if_no_cuda\n\n_, has_cut = optional_import(\"cucim.core.operations.expose.transform\")\ncp, _ = optional_import(\"cupy\")\n\nset_determinism(seed=0)\n\nTEST_CASE_COLOR_JITTER_1 = [\n    {\"name\": \"color_jitter\", \"brightness\": 0.0, \"contrast\": 0.0, \"saturation\": 0.0, \"hue\": 0.0},\n    np.array([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]], dtype=np.float32),\n    np.array([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]], dtype=np.float32),\n]\n\nTEST_CASE_COLOR_JITTER_2 = [\n    {\"name\": \"color_jitter\", \"brightness\": 0.0, \"contrast\": 0.0, \"saturation\": 0.0, \"hue\": 0.0},\n    np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[0, 1], [2, 3]]], dtype=np.uint8),\n    np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[0, 1], [2, 3]]], dtype=np.uint8),\n]\n\nTEST_CASE_FLIP_1 = [\n    {\"name\": \"image_flip\", \"spatial_axis\": -1},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], 
dtype=np.float32),\n]\n\nTEST_CASE_ROTATE_1 = [\n    {\"name\": \"image_rotate_90\", \"k\": 1, \"spatial_axis\": (-2, -1)},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], dtype=np.float32),\n]\n\nTEST_CASE_SCALE_INTENSITY_1 = [\n    {\"name\": \"scale_intensity_range\", \"a_min\": 0.0, \"a_max\": 4.0, \"b_min\": 0.0, \"b_max\": 1.0, \"clip\": False},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]]], dtype=np.float32),\n]\n\nTEST_CASE_ZOOM_1 = [\n    {\"name\": \"zoom\", \"zoom_factor\": (0.5, 0.5)},\n    np.mgrid[:3, 1:4].astype(dtype=np.float32),\n    np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]),\n]\n\n\n@skip_if_no_cuda\n@unittest.skipUnless(HAS_CUPY, \"CuPy is required.\")\n@unittest.skipUnless(has_cut, \"cuCIM transforms are required.\")\nclass TestCuCIMDict(unittest.TestCase):\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_COLOR_JITTER_2,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_ROTATE_1,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n        ]\n    )\n    def test_tramsforms_numpy_single(self, params, input, expected):\n        input = {\"image\": input}\n        output = CuCIMd(keys=\"image\", **params)(input)[\"image\"]\n        self.assertEqual(output.dtype, expected.dtype)\n        self.assertIsInstance(output, np.ndarray)\n        cp.testing.assert_allclose(output, expected)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_COLOR_JITTER_2,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_ROTATE_1,\n            
TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n        ]\n    )\n    def test_tramsforms_numpy_batch(self, params, input, expected):\n        input = {\"image\": input[cp.newaxis, ...]}\n        expected = expected[cp.newaxis, ...]\n        output = CuCIMd(keys=\"image\", **params)(input)[\"image\"]\n        self.assertEqual(output.dtype, expected.dtype)\n        self.assertIsInstance(output, np.ndarray)\n        cp.testing.assert_allclose(output, expected)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_COLOR_JITTER_2,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_ROTATE_1,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n        ]\n    )\n    def test_tramsforms_cupy_single(self, params, input, expected):\n        input = {\"image\": cp.asarray(input)}\n        expected = cp.asarray(expected)\n        output = CuCIMd(keys=\"image\", **params)(input)[\"image\"]\n        self.assertEqual(output.dtype, expected.dtype)\n        self.assertIsInstance(output, cp.ndarray)\n        cp.testing.assert_allclose(output, expected)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_COLOR_JITTER_2,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_ROTATE_1,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n        ]\n    )\n    def test_tramsforms_cupy_batch(self, params, input, expected):\n        input = {\"image\": cp.asarray(input)[cp.newaxis, ...]}\n        expected = cp.asarray(expected)[cp.newaxis, ...]\n        output = CuCIMd(keys=\"image\", **params)(input)[\"image\"]\n        self.assertEqual(output.dtype, expected.dtype)\n        self.assertIsInstance(output, cp.ndarray)\n        cp.testing.assert_allclose(output, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_cucim_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import CuCIM\nfrom monai.utils import optional_import, set_determinism\nfrom tests.test_utils import HAS_CUPY, skip_if_no_cuda\n\n_, has_cut = optional_import(\"cucim.core.operations.expose.transform\")\ncp, _ = optional_import(\"cupy\")\n\nset_determinism(seed=0)\n\nTEST_CASE_COLOR_JITTER_1 = [\n    {\"name\": \"color_jitter\", \"brightness\": 0.0, \"contrast\": 0.0, \"saturation\": 0.0, \"hue\": 0.0},\n    np.array([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]], dtype=np.float32),\n    np.array([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]], dtype=np.float32),\n]\n\nTEST_CASE_COLOR_JITTER_2 = [\n    {\"name\": \"color_jitter\", \"brightness\": 0.0, \"contrast\": 0.0, \"saturation\": 0.0, \"hue\": 0.0},\n    np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[0, 1], [2, 3]]], dtype=np.uint8),\n    np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]], [[0, 1], [2, 3]]], dtype=np.uint8),\n]\n\nTEST_CASE_FLIP_1 = [\n    {\"name\": \"image_flip\", \"spatial_axis\": -1},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], 
dtype=np.float32),\n]\n\nTEST_CASE_ROTATE_1 = [\n    {\"name\": \"image_rotate_90\", \"k\": 1, \"spatial_axis\": (-2, -1)},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], dtype=np.float32),\n]\n\nTEST_CASE_SCALE_INTENSITY_1 = [\n    {\"name\": \"scale_intensity_range\", \"a_min\": 0.0, \"a_max\": 4.0, \"b_min\": 0.0, \"b_max\": 1.0, \"clip\": False},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]]], dtype=np.float32),\n]\n\nTEST_CASE_ZOOM_1 = [\n    {\"name\": \"zoom\", \"zoom_factor\": (0.5, 0.5)},\n    np.mgrid[:3, 1:4].astype(dtype=np.float32),\n    np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]),\n]\n\n\n@skip_if_no_cuda\n@unittest.skipUnless(HAS_CUPY, \"CuPy is required.\")\n@unittest.skipUnless(has_cut, \"cuCIM transforms are required.\")\nclass TestCuCIM(unittest.TestCase):\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_COLOR_JITTER_2,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_ROTATE_1,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n        ]\n    )\n    def test_tramsforms_numpy_single(self, params, input, expected):\n        output = CuCIM(**params)(input)\n        self.assertEqual(output.dtype, expected.dtype)\n        self.assertIsInstance(output, np.ndarray)\n        cp.testing.assert_allclose(output, expected)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_COLOR_JITTER_2,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_ROTATE_1,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n        ]\n    )\n    def 
test_tramsforms_numpy_batch(self, params, input, expected):\n        input = input[cp.newaxis, ...]\n        expected = expected[cp.newaxis, ...]\n        output = CuCIM(**params)(input)\n        self.assertEqual(output.dtype, expected.dtype)\n        self.assertIsInstance(output, np.ndarray)\n        cp.testing.assert_allclose(output, expected)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_COLOR_JITTER_2,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_ROTATE_1,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n        ]\n    )\n    def test_tramsforms_cupy_single(self, params, input, expected):\n        input = cp.asarray(input)\n        expected = cp.asarray(expected)\n        output = CuCIM(**params)(input)\n        self.assertEqual(output.dtype, expected.dtype)\n        self.assertIsInstance(output, cp.ndarray)\n        cp.testing.assert_allclose(output, expected)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_COLOR_JITTER_2,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_ROTATE_1,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n        ]\n    )\n    def test_tramsforms_cupy_batch(self, params, input, expected):\n        input = cp.asarray(input)[cp.newaxis, ...]\n        expected = cp.asarray(expected)[cp.newaxis, ...]\n        output = CuCIM(**params)(input)\n        self.assertEqual(output.dtype, expected.dtype)\n        self.assertIsInstance(output, cp.ndarray)\n        cp.testing.assert_allclose(output, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_data_stats.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport sys\nimport tempfile\nimport unittest\nfrom io import StringIO\nfrom unittest.mock import patch\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import DataStats\n\nTEST_CASE_1 = [\n    {\n        \"prefix\": \"test data\",\n        \"data_type\": False,\n        \"data_shape\": False,\n        \"value_range\": False,\n        \"data_value\": False,\n        \"additional_info\": None,\n        \"name\": \"DataStats\",\n    },\n    np.array([[0, 1], [1, 2]]),\n    \"test data statistics:\",\n]\n\nTEST_CASE_2 = [\n    {\n        \"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": False,\n        \"value_range\": False,\n        \"data_value\": False,\n        \"additional_info\": None,\n        \"name\": \"DataStats\",\n    },\n    np.array([[0, 1], [1, 2]]),\n    \"test data statistics:\\nType: <class 'numpy.ndarray'>\",\n]\n\nTEST_CASE_3 = [\n    {\n        \"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": True,\n        \"value_range\": False,\n        \"data_value\": False,\n        \"additional_info\": None,\n        \"name\": \"DataStats\",\n    },\n    np.array([[0, 1], [1, 2]]),\n    \"test data statistics:\\nType: <class 
'numpy.ndarray'>\\nShape: (2, 2)\",\n]\n\nTEST_CASE_4 = [\n    {\n        \"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": True,\n        \"value_range\": True,\n        \"data_value\": False,\n        \"additional_info\": None,\n        \"name\": \"DataStats\",\n    },\n    np.array([[0, 1], [1, 2]]),\n    \"test data statistics:\\nType: <class 'numpy.ndarray'>\\nShape: (2, 2)\\nValue range: (0, 2)\",\n]\n\nTEST_CASE_5 = [\n    {\n        \"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": True,\n        \"value_range\": True,\n        \"data_value\": True,\n        \"additional_info\": None,\n        \"name\": \"DataStats\",\n    },\n    np.array([[0, 1], [1, 2]]),\n    \"test data statistics:\\nType: <class 'numpy.ndarray'>\\nShape: (2, 2)\\nValue range: (0, 2)\\nValue: [[0 1]\\n [1 2]]\",\n]\n\nTEST_CASE_6 = [\n    {\n        \"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": True,\n        \"value_range\": True,\n        \"data_value\": True,\n        \"additional_info\": np.mean,\n        \"name\": \"DataStats\",\n    },\n    np.array([[0, 1], [1, 2]]),\n    (\n        \"test data statistics:\\nType: <class 'numpy.ndarray'>\\nShape: (2, 2)\\n\"\n        \"Value range: (0, 2)\\nValue: [[0 1]\\n [1 2]]\\nAdditional info: 1.0\"\n    ),\n]\n\nTEST_CASE_7 = [\n    {\n        \"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": True,\n        \"value_range\": True,\n        \"data_value\": True,\n        \"additional_info\": lambda x: torch.mean(x.float()),\n        \"name\": \"DataStats\",\n    },\n    torch.tensor([[0, 1], [1, 2]]).to(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n    (\n        \"test data statistics:\\nType: <class 'torch.Tensor'>\\nShape: torch.Size([2, 2])\\nValue range: (0, 2)\\n\"\n        \"Value: tensor([[0, 1],\\n        [1, 2]])\\nAdditional info: 1.0\"\n    ),\n]\n\nTEST_CASE_8 = [\n    {\n        
\"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": True,\n        \"value_range\": True,\n        \"data_value\": True,\n        \"additional_info\": np.mean,\n        \"name\": \"DataStats\",\n    },\n    np.array([[0, 1], [1, 2]]),\n    \"test data statistics:\\nType: <class 'numpy.ndarray'> int64\\nShape: (2, 2)\\nValue range: (0, 2)\\n\"\n    \"Value: [[0 1]\\n [1 2]]\\nAdditional info: 1.0\\n\",\n]\n\nTEST_CASE_9 = [\n    np.array([[0, 1], [1, 2]]),\n    \"test data statistics:\\nType: <class 'numpy.ndarray'> int64\\nShape: (2, 2)\\nValue range: (0, 2)\\n\"\n    \"Value: [[0 1]\\n [1 2]]\\n\"\n    \"Meta info: '(input is not a MetaTensor)'\\n\"\n    \"Additional info: 1.0\\n\",\n]\n\nTEST_CASE_10 = [\n    MetaTensor(\n        torch.tensor([[0, 1], [1, 2]]),\n        affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64),\n        meta={\"some\": \"info\"},\n    ),\n    \"test data statistics:\\nType: <class 'monai.data.meta_tensor.MetaTensor'> torch.int64\\n\"\n    \"Shape: torch.Size([2, 2])\\nValue range: (0, 2)\\n\"\n    \"Value: tensor([[0, 1],\\n        [1, 2]])\\n\"\n    \"Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\\n\"\n    \"        [0., 2., 0., 0.],\\n\"\n    \"        [0., 0., 2., 0.],\\n\"\n    \"        [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\\n\"\n    \"Additional info: 1.0\\n\",\n]\n\n\nclass TestDataStats(unittest.TestCase):\n\n    @parameterized.expand(\n        [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]\n    )\n    def test_value(self, input_param, input_data, expected_print):\n        transform = DataStats(**input_param)\n        _ = transform(input_data)\n\n    @parameterized.expand([TEST_CASE_9, TEST_CASE_10])\n    def test_file(self, input_data, expected_print):\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, 
\"test_data_stats.log\")\n            handler = logging.FileHandler(filename, mode=\"w\")\n            handler.setLevel(logging.INFO)\n            name = \"DataStats\"\n            logger = logging.getLogger(name)\n            logger.addHandler(handler)\n            input_param = {\n                \"prefix\": \"test data\",\n                \"data_type\": True,\n                \"data_shape\": True,\n                \"value_range\": True,\n                \"data_value\": True,\n                \"meta_info\": True,\n                \"additional_info\": np.mean,\n                \"name\": name,\n            }\n            transform = DataStats(**input_param)\n            _ = transform(input_data)\n            for h in logger.handlers[:]:\n                h.close()\n                logger.removeHandler(h)\n            with open(filename) as f:\n                content = f.read()\n            if sys.platform != \"win32\":\n                self.assertEqual(content, expected_print)\n\n    def test_multiple_data_stats(self):\n        with patch(\"sys.stdout\", new=StringIO()) as out:\n            input_data = np.array([[0, 1], [1, 2]])\n            transform = DataStats()\n            _ = DataStats()\n            _ = transform(input_data)\n            print(out.getvalue().strip())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_data_statsd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport sys\nimport tempfile\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import DataStatsd\n\nTEST_CASE_1 = [\n    {\n        \"keys\": \"img\",\n        \"prefix\": \"test data\",\n        \"data_type\": False,\n        \"data_shape\": False,\n        \"value_range\": False,\n        \"data_value\": False,\n        \"additional_info\": None,\n        \"name\": \"DataStats\",\n    },\n    {\"img\": np.array([[0, 1], [1, 2]])},\n    \"test data statistics:\",\n]\n\nTEST_CASE_2 = [\n    {\n        \"keys\": \"img\",\n        \"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": False,\n        \"value_range\": False,\n        \"data_value\": False,\n        \"additional_info\": None,\n        \"name\": \"DataStats\",\n    },\n    {\"img\": np.array([[0, 1], [1, 2]])},\n    \"test data statistics:\\nType: <class 'numpy.ndarray'>\",\n]\n\nTEST_CASE_3 = [\n    {\n        \"keys\": \"img\",\n        \"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": True,\n        \"value_range\": False,\n        \"data_value\": False,\n        \"additional_info\": None,\n        \"name\": \"DataStats\",\n    },\n    {\"img\": np.array([[0, 1], [1, 
2]])},\n    \"test data statistics:\\nType: <class 'numpy.ndarray'>\\nShape: (2, 2)\",\n]\n\nTEST_CASE_4 = [\n    {\n        \"keys\": \"img\",\n        \"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": True,\n        \"value_range\": True,\n        \"data_value\": False,\n        \"additional_info\": None,\n        \"name\": \"DataStats\",\n    },\n    {\"img\": np.array([[0, 1], [1, 2]])},\n    \"test data statistics:\\nType: <class 'numpy.ndarray'>\\nShape: (2, 2)\\nValue range: (0, 2)\",\n]\n\nTEST_CASE_5 = [\n    {\n        \"keys\": \"img\",\n        \"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": True,\n        \"value_range\": True,\n        \"data_value\": True,\n        \"additional_info\": None,\n        \"name\": \"DataStats\",\n    },\n    {\"img\": np.array([[0, 1], [1, 2]])},\n    \"test data statistics:\\nType: <class 'numpy.ndarray'>\\nShape: (2, 2)\\nValue range: (0, 2)\\nValue: [[0 1]\\n [1 2]]\",\n]\n\nTEST_CASE_6 = [\n    {\n        \"keys\": \"img\",\n        \"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": True,\n        \"value_range\": True,\n        \"data_value\": True,\n        \"additional_info\": np.mean,\n        \"name\": \"DataStats\",\n    },\n    {\"img\": np.array([[0, 1], [1, 2]])},\n    (\n        \"test data statistics:\\nType: <class 'numpy.ndarray'>\\nShape: (2, 2)\\n\"\n        \"Value range: (0, 2)\\nValue: [[0 1]\\n [1 2]]\\nAdditional info: 1.0\"\n    ),\n]\n\nTEST_CASE_7 = [\n    {\n        \"keys\": \"img\",\n        \"prefix\": \"test data\",\n        \"data_type\": True,\n        \"data_shape\": True,\n        \"value_range\": True,\n        \"data_value\": True,\n        \"additional_info\": lambda x: torch.mean(x.float()),\n        \"name\": \"DataStats\",\n    },\n    {\"img\": torch.tensor([[0, 1], [1, 2]]).to(\"cuda\" if torch.cuda.is_available() else \"cpu\")},\n    (\n        \"test data statistics:\\nType: 
<class 'torch.Tensor'>\\nShape: torch.Size([2, 2])\\nValue range: (0, 2)\\n\"\n        \"Value: tensor([[0, 1],\\n        [1, 2]])\\nAdditional info: 1.0\"\n    ),\n]\n\nTEST_CASE_8 = [\n    {\n        \"keys\": (\"img\", \"affine\"),\n        \"prefix\": (\"image\", \"affine\"),\n        \"data_type\": True,\n        \"data_shape\": True,\n        \"value_range\": (True, False),\n        \"data_value\": (False, True),\n        \"additional_info\": (np.mean, None),\n        \"name\": \"DataStats\",\n    },\n    {\"img\": np.array([[0, 1], [1, 2]]), \"affine\": np.eye(2, 2)},\n    \"affine statistics:\\nType: <class 'numpy.ndarray'>\\nShape: (2, 2)\\nValue: [[1. 0.]\\n [0. 1.]]\",\n]\n\nTEST_CASE_9 = [\n    {\n        \"keys\": \"img\",\n        \"prefix\": \"test data\",\n        \"data_shape\": True,\n        \"value_range\": True,\n        \"data_value\": True,\n        \"meta_info\": False,\n        \"additional_info\": np.mean,\n        \"name\": \"DataStats\",\n    },\n    {\"img\": np.array([[0, 1], [1, 2]])},\n    \"test data statistics:\\nType: <class 'numpy.ndarray'> int64\\nShape: (2, 2)\\nValue range: (0, 2)\\n\"\n    \"Value: [[0 1]\\n [1 2]]\\nAdditional info: 1.0\\n\",\n]\n\nTEST_CASE_10 = [\n    {\"img\": np.array([[0, 1], [1, 2]])},\n    \"test data statistics:\\nType: <class 'numpy.ndarray'> int64\\nShape: (2, 2)\\nValue range: (0, 2)\\n\"\n    \"Value: [[0 1]\\n [1 2]]\\n\"\n    \"Meta info: '(input is not a MetaTensor)'\\n\"\n    \"Additional info: 1.0\\n\",\n]\n\nTEST_CASE_11 = [\n    {\n        \"img\": (\n            MetaTensor(\n                torch.tensor([[0, 1], [1, 2]]),\n                affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64),\n                meta={\"some\": \"info\"},\n            )\n        )\n    },\n    \"test data statistics:\\nType: <class 'monai.data.meta_tensor.MetaTensor'> torch.int64\\n\"\n    \"Shape: torch.Size([2, 2])\\nValue range: (0, 2)\\n\"\n    \"Value: 
tensor([[0, 1],\\n        [1, 2]])\\n\"\n    \"Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\\n\"\n    \"        [0., 2., 0., 0.],\\n\"\n    \"        [0., 0., 2., 0.],\\n\"\n    \"        [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\\n\"\n    \"Additional info: 1.0\\n\",\n]\n\n\nclass TestDataStatsd(unittest.TestCase):\n\n    @parameterized.expand(\n        [\n            TEST_CASE_1,\n            TEST_CASE_2,\n            TEST_CASE_3,\n            TEST_CASE_4,\n            TEST_CASE_5,\n            TEST_CASE_6,\n            TEST_CASE_7,\n            TEST_CASE_8,\n            TEST_CASE_9,\n        ]\n    )\n    def test_value(self, input_param, input_data, expected_print):\n        transform = DataStatsd(**input_param)\n        _ = transform(input_data)\n\n    @parameterized.expand([TEST_CASE_10, TEST_CASE_11])\n    def test_file(self, input_data, expected_print):\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_stats.log\")\n            handler = logging.FileHandler(filename, mode=\"w\")\n            handler.setLevel(logging.INFO)\n            name = \"DataStats\"\n            logger = logging.getLogger(name)\n            logger.addHandler(handler)\n            input_param = {\n                \"keys\": \"img\",\n                \"prefix\": \"test data\",\n                \"data_shape\": True,\n                \"value_range\": True,\n                \"data_value\": True,\n                \"meta_info\": True,\n                \"additional_info\": np.mean,\n                \"name\": name,\n            }\n            transform = DataStatsd(**input_param)\n            _ = transform(input_data)\n            for h in logger.handlers[:]:\n                h.close()\n                logger.removeHandler(h)\n            del handler\n            with open(filename) as f:\n                content = f.read()\n            if sys.platform != \"win32\":\n                self.assertEqual(content, 
expected_print)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_delete_itemsd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport time\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import DeleteItemsd\nfrom monai.utils.enums import PostFix\n\nTEST_CASE_1 = [{\"keys\": [str(i) for i in range(30)]}, 20]\n\nTEST_CASE_2 = [{\"keys\": [\"image/\" + str(i) for i in range(30)], \"sep\": \"/\"}, 20]\n\nTEST_CASE_3 = [{\"keys\": \"meta_dict%0008\\\\|[0-9]\", \"sep\": \"%\", \"use_re\": True}]\n\n\nclass TestDeleteItemsd(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_memory(self, input_param, expected_key_size):\n        input_data = {\"image\": {}} if \"sep\" in input_param else {}\n        for i in range(50):\n            if \"sep\" in input_param:\n                input_data[\"image\"][str(i)] = [time.time()] * 100000\n            else:\n                input_data[str(i)] = [time.time()] * 100000\n        result = DeleteItemsd(**input_param)(input_data)\n        if \"sep\" in input_param:\n            self.assertEqual(len(result[\"image\"].keys()), expected_key_size)\n        else:\n            self.assertEqual(len(result.keys()), expected_key_size)\n        self.assertGreaterEqual(\n            sys.getsizeof(input_data) * float(expected_key_size) / len(input_data), sys.getsizeof(result)\n        )\n\n    @parameterized.expand([TEST_CASE_3])\n    def test_re(self, 
input_param):\n        input_data = {\"image\": [1, 2, 3], PostFix.meta(): {\"0008|0005\": 1, \"0008|1050\": 2, \"0008test\": 3}}\n        result = DeleteItemsd(**input_param)(input_data)\n        self.assertEqual(result[PostFix.meta()][\"0008test\"], 3)\n        self.assertEqual(len(result[PostFix.meta()]), 1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_detect_envelope.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import DetectEnvelope\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nn_samples = 500\nhann_windowed_sine = np.sin(2 * np.pi * 10 * np.linspace(0, 1, n_samples)) * np.hanning(n_samples)\n\n# SINGLE-CHANNEL VALUE TESTS\n# using np.expand_dims() to add length 1 channel dimension at dimension 0\n\nTEST_CASE_1D_SINE = [\n    {},  # args (empty, so use default)\n    np.expand_dims(hann_windowed_sine, 0),  # Input data: Hann windowed sine wave\n    np.expand_dims(np.hanning(n_samples), 0),  # Expected output: the Hann window\n    1e-4,  # absolute tolerance\n]\n\nTEST_CASE_2D_SINE = [\n    {},  # args (empty, so use default (i.e. process along first spatial dimension, axis=1)\n    # Create 10 identical windowed sine waves as a 2D numpy array\n    np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0),\n    # Expected output: Set of 10 identical Hann windows\n    np.expand_dims(np.stack([np.hanning(n_samples)] * 10, axis=1), 0),\n    1e-4,  # absolute tolerance\n]\n\nTEST_CASE_3D_SINE = [\n    {},  # args (empty, so use default (i.e. 
process along first spatial dimension, axis=1)\n    # Create 100 identical windowed sine waves as a (n_samples x 10 x 10) 3D numpy array\n    np.expand_dims(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), 0),\n    # Expected output: Set of 100 identical Hann windows in (n_samples x 10 x 10) 3D numpy array\n    np.expand_dims(np.stack([np.stack([np.hanning(n_samples)] * 10, axis=1)] * 10, axis=2), 0),\n    1e-4,  # absolute tolerance\n]\n\nTEST_CASE_2D_SINE_AXIS_1 = [\n    {\"axis\": 2},  # set axis argument to 2\n    # Create 10 identical windowed sine waves as a 2D numpy array\n    np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0),\n    # Expected output: absolute value of each sample of the waveform, repeated (i.e. flat envelopes)\n    np.expand_dims(np.abs(np.repeat(hann_windowed_sine, 10).reshape((n_samples, 10))), 0),\n    1e-4,  # absolute tolerance\n]\n\nTEST_CASE_1D_SINE_PADDING_N = [\n    {\"n\": 512},  # set FFT length to 512 (input has only 500 samples)\n    np.expand_dims(hann_windowed_sine, 0),  # Input data: Hann windowed sine wave\n    np.expand_dims(np.concatenate([np.hanning(500), np.zeros(12)]), 0),  # Expected output: zero-padded Hann window\n    1e-3,  # absolute tolerance\n]\n\n# MULTI-CHANNEL VALUE TEST\n\nTEST_CASE_2_CHAN_3D_SINE = [\n    {},  # args (empty, so use default (i.e. 
process along first spatial dimension, axis=1)\n    # Create 100 identical windowed sine waves as a (n_samples x 10 x 10) 3D numpy array, twice (2 channels)\n    torch.as_tensor(np.stack([np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2)] * 2, axis=0)),\n    # Expected output: Set of 100 identical Hann windows in (n_samples x 10 x 10) 3D numpy array, twice (2 channels)\n    torch.as_tensor(np.stack([np.stack([np.stack([np.hanning(n_samples)] * 10, axis=1)] * 10, axis=2)] * 2, axis=0)),\n    1e-4,  # absolute tolerance\n]\n\n# EXCEPTION TESTS\n\nTEST_CASE_INVALID_AXIS_1 = [\n    {\"axis\": 3},  # set axis argument to 3 when only 3 dimensions (1 channel + 2 spatial)\n    np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0),  # Create 2D dataset\n    \"__call__\",  # method expected to raise exception\n]\n\nTEST_CASE_INVALID_AXIS_2 = [\n    {\"axis\": -1},  # set axis argument negative\n    np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0),  # Create 2D dataset\n    \"__init__\",  # method expected to raise exception\n]\n\nTEST_CASE_INVALID_N = [\n    {\"n\": 0},  # set FFT length to zero\n    np.expand_dims(np.stack([hann_windowed_sine] * 10, axis=1), 0),  # Create 2D dataset\n    \"__call__\",  # method expected to raise exception\n]\n\nTEST_CASE_INVALID_DTYPE = [\n    {},\n    np.expand_dims(np.array(hann_windowed_sine, dtype=complex), 0),  # complex numbers are invalid\n    \"__call__\",  # method expected to raise exception\n]\n\nTEST_CASE_INVALID_IMG_LEN = [\n    {},\n    np.expand_dims(np.array([]), 0),  # empty array is invalid\n    \"__call__\",  # method expected to raise exception\n]\n\nTEST_CASE_INVALID_OBJ = [{}, \"a string\", \"__call__\"]  # method expected to raise exception\n\n\nclass TestDetectEnvelope(unittest.TestCase):\n    @parameterized.expand(\n        [\n            TEST_CASE_1D_SINE,\n            TEST_CASE_2D_SINE,\n            TEST_CASE_3D_SINE,\n            TEST_CASE_2D_SINE_AXIS_1,\n            
TEST_CASE_1D_SINE_PADDING_N,\n            TEST_CASE_2_CHAN_3D_SINE,\n        ]\n    )\n    def test_value(self, arguments, image, expected_data, atol):\n        for p in TEST_NDARRAYS:\n            result = DetectEnvelope(**arguments)(p(image))\n            assert_allclose(result, p(expected_data), atol=atol, type_test=\"tensor\")\n\n    @parameterized.expand(\n        [\n            TEST_CASE_INVALID_AXIS_1,\n            TEST_CASE_INVALID_AXIS_2,\n            TEST_CASE_INVALID_N,\n            TEST_CASE_INVALID_DTYPE,\n            TEST_CASE_INVALID_IMG_LEN,\n        ]\n    )\n    def test_value_error(self, arguments, image, method):\n        if method == \"__init__\":\n            self.assertRaises(ValueError, DetectEnvelope, **arguments)\n        elif method == \"__call__\":\n            self.assertRaises(ValueError, DetectEnvelope(**arguments), image)\n        else:\n            self.fail(\"Expected raising method invalid. Should be __init__ or __call__.\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_distance_transform_edt.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import DistanceTransformEDT, DistanceTransformEDTd\nfrom tests.test_utils import HAS_CUPY, assert_allclose, optional_import, skip_if_no_cuda\n\nmomorphology, has_cucim = optional_import(\"cucim.core.operations.morphology\")\nndimage, has_ndimage = optional_import(\"scipy.ndimage\")\ncp, _ = optional_import(\"cupy\")\n\nTEST_CASES = [\n    [\n        np.array(\n            ([[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 0]],), dtype=np.float32\n        ),\n        np.array(\n            [\n                [\n                    [0.0, 1.0, 1.4142, 2.2361, 3.0],\n                    [0.0, 0.0, 1.0, 2.0, 2.0],\n                    [0.0, 1.0, 1.4142, 1.4142, 1.0],\n                    [0.0, 1.0, 1.4142, 1.0, 0.0],\n                    [0.0, 1.0, 1.0, 0.0, 0.0],\n                ]\n            ]\n        ),\n    ],\n    [  # Example 4D input to test channel-wise CuPy\n        np.array(\n            [[[[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 0]]]], dtype=np.float32\n        ),\n        np.array(\n            [\n                [\n                    [\n                        [0.0, 1.0, 1.4142, 2.2361, 3.0],\n                        [0.0, 0.0, 1.0, 
2.0, 2.0],\n                        [0.0, 1.0, 1.4142, 1.4142, 1.0],\n                        [0.0, 1.0, 1.4142, 1.0, 0.0],\n                        [0.0, 1.0, 1.0, 0.0, 0.0],\n                    ]\n                ]\n            ]\n        ),\n    ],\n    [\n        np.array(\n            [\n                [\n                    [0.0, 1.0, 1.0, 1.0, 1.0],\n                    [0.0, 0.0, 1.0, 1.0, 1.0],\n                    [0.0, 1.0, 1.0, 1.0, 1.0],\n                    [0.0, 1.0, 1.0, 1.0, 0.0],\n                    [0.0, 1.0, 1.0, 0.0, 0.0],\n                ],\n                [\n                    [0.0, 1.0, 1.0, 1.0, 1.0],\n                    [0.0, 0.0, 1.0, 1.0, 1.0],\n                    [0.0, 1.0, 1.0, 1.0, 1.0],\n                    [0.0, 1.0, 1.0, 1.0, 0.0],\n                    [0.0, 1.0, 1.0, 0.0, 0.0],\n                ],\n                [\n                    [0.0, 1.0, 1.0, 1.0, 1.0],\n                    [0.0, 0.0, 1.0, 1.0, 1.0],\n                    [0.0, 1.0, 1.0, 1.0, 1.0],\n                    [0.0, 1.0, 1.0, 1.0, 0.0],\n                    [0.0, 1.0, 1.0, 0.0, 0.0],\n                ],\n            ],\n            dtype=np.float32,\n        ),\n        np.array(\n            [\n                [\n                    [0.0, 1.0, 1.4142135, 2.236068, 3.0],\n                    [0.0, 0.0, 1.0, 2.0, 2.0],\n                    [0.0, 1.0, 1.4142135, 1.4142135, 1.0],\n                    [0.0, 1.0, 1.4142135, 1.0, 0.0],\n                    [0.0, 1.0, 1.0, 0.0, 0.0],\n                ],\n                [\n                    [0.0, 1.0, 1.4142135, 2.236068, 3.0],\n                    [0.0, 0.0, 1.0, 2.0, 2.0],\n                    [0.0, 1.0, 1.4142135, 1.4142135, 1.0],\n                    [0.0, 1.0, 1.4142135, 1.0, 0.0],\n                    [0.0, 1.0, 1.0, 0.0, 0.0],\n                ],\n                [\n                    [0.0, 1.0, 1.4142135, 2.236068, 3.0],\n                    [0.0, 0.0, 1.0, 2.0, 2.0],\n                    [0.0, 1.0, 
1.4142135, 1.4142135, 1.0],\n                    [0.0, 1.0, 1.4142135, 1.0, 0.0],\n                    [0.0, 1.0, 1.0, 0.0, 0.0],\n                ],\n            ],\n            dtype=np.float32,\n        ),\n    ],\n]\n\nSAMPLING_TEST_CASES = [\n    [\n        2,\n        np.array(\n            ([[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 0]],), dtype=np.float32\n        ),\n        np.array(\n            [\n                [\n                    [0.0, 2.0, 2.828427, 4.472136, 6.0],\n                    [0.0, 0.0, 2.0, 4.0, 4.0],\n                    [0.0, 2.0, 2.828427, 2.828427, 2.0],\n                    [0.0, 2.0, 2.828427, 2.0, 0.0],\n                    [0.0, 2.0, 2.0, 0.0, 0.0],\n                ]\n            ]\n        ),\n    ]\n]\n\nRAISES_TEST_CASES = (\n    [  # Example 5D input. Should raise under CuPy\n        np.array(\n            [[[[[0, 1, 1, 1, 1], [0, 0, 1, 1, 1], [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 0]]]]],\n            dtype=np.float32,\n        )\n    ],\n)\n\n\nclass TestDistanceTransformEDT(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_scipy_transform(self, input, expected_output):\n        transform = DistanceTransformEDT()\n        output = transform(input)\n        assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False)\n\n    @parameterized.expand(TEST_CASES)\n    def test_scipy_transformd(self, input, expected_output):\n        transform = DistanceTransformEDTd(keys=(\"to_transform\",))\n        data = {\"to_transform\": input}\n        data_ = transform(data)\n        output = data_[\"to_transform\"]\n        assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False)\n\n    @parameterized.expand(SAMPLING_TEST_CASES)\n    def test_scipy_sampling(self, sampling, input, expected_output):\n        transform = DistanceTransformEDT(sampling=sampling)\n        output = transform(input)\n        assert_allclose(output, 
expected_output, atol=1e-4, rtol=1e-4, type_test=False)\n\n    @parameterized.expand(TEST_CASES)\n    @skip_if_no_cuda\n    @unittest.skipUnless(HAS_CUPY, \"CuPy is required.\")\n    @unittest.skipUnless(momorphology, \"cuCIM transforms are required.\")\n    def test_cucim_transform(self, input, expected_output):\n        input_ = torch.tensor(input, device=\"cuda\")\n        transform = DistanceTransformEDT()\n        output = transform(input_)\n        assert_allclose(cp.asnumpy(output), cp.asnumpy(expected_output), atol=1e-4, rtol=1e-4, type_test=False)\n\n    @parameterized.expand(SAMPLING_TEST_CASES)\n    @skip_if_no_cuda\n    @unittest.skipUnless(HAS_CUPY, \"CuPy is required.\")\n    @unittest.skipUnless(momorphology, \"cuCIM transforms are required.\")\n    def test_cucim_sampling(self, sampling, input, expected_output):\n        input_ = torch.tensor(input, device=\"cuda\")\n        transform = DistanceTransformEDT(sampling=sampling)\n        output = transform(input_)\n        assert_allclose(cp.asnumpy(output), cp.asnumpy(expected_output), atol=1e-4, rtol=1e-4, type_test=False)\n\n    @parameterized.expand(RAISES_TEST_CASES)\n    @skip_if_no_cuda\n    @unittest.skipUnless(HAS_CUPY, \"CuPy is required.\")\n    @unittest.skipUnless(momorphology, \"cuCIM transforms are required.\")\n    def test_cucim_raises(self, raises):\n        \"\"\"Currently only images of a certain shape are supported. This test checks for the corresponding error.\"\"\"\n        input_ = torch.tensor(raises, device=\"cuda\")\n        transform = DistanceTransformEDT()\n        with self.assertRaises(RuntimeError):\n            transform(input_)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_divisible_pad.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import DivisiblePad\nfrom monai.utils.enums import NumpyPadMode, PytorchPadMode\nfrom tests.padders import PadTest\n\nTESTS = []\n\n# pad first dim to be divisible by 7, the second unchanged.\nTESTS.append([{\"k\": (7, -1)}, (3, 8, 7), (3, 14, 7)])\n# pad all dimensions to be divisible by 5\nTESTS.append([{\"k\": 5, \"method\": \"end\"}, (3, 10, 5, 17), (3, 10, 5, 20)])\n\n\nclass TestDivisiblePad(PadTest):\n    Padder = DivisiblePad\n\n    @parameterized.expand(TESTS)\n    def test_pad(self, input_param, input_shape, expected_shape):\n        modes = [\"constant\", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT]\n        self.pad_test(input_param, input_shape, expected_shape, modes)\n\n    def test_pad_kwargs(self):\n        kwargs = {\"k\": 5, \"method\": \"end\"}\n        unchanged_slices = [slice(None), slice(None, 8), slice(None, 4)]\n        self.pad_test_kwargs(unchanged_slices, **kwargs)\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, input_param, input_shape, _):\n        self.pad_test_pending_ops(input_param, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_divisible_padd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import DivisiblePadd\nfrom monai.utils.enums import NumpyPadMode, PytorchPadMode\nfrom tests.padders import PadTest\n\nTESTS = [\n    [{\"keys\": \"img\", \"k\": [4, 3, 2]}, (3, 8, 8, 4), (3, 8, 9, 4)],\n    [{\"keys\": \"img\", \"k\": 7, \"method\": \"end\"}, (3, 8, 7), (3, 14, 7)],\n]\n\n\nclass TestDivisiblePadd(PadTest):\n    Padder = DivisiblePadd\n\n    @parameterized.expand(TESTS)\n    def test_pad(self, input_param, input_shape, expected_shape):\n        modes = [\"constant\", NumpyPadMode.CONSTANT, PytorchPadMode.CONSTANT, \"edge\", NumpyPadMode.EDGE]\n        self.pad_test(input_param, input_shape, expected_shape, modes)\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, input_param, input_shape, _):\n        self.pad_test_pending_ops(input_param, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_ensure_channel_first.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom PIL import Image\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import EnsureChannelFirst, LoadImage\nfrom monai.utils import optional_import\n\nitk, has_itk = optional_import(\"itk\", allow_namespace_pkg=True)\nITKReader, _ = optional_import(\"monai.data\", name=\"ITKReader\", as_type=\"decorator\")\n\nTEST_CASE_1 = [{}, [\"test_image.nii.gz\"], None]\n\nTEST_CASE_2 = [{}, [\"test_image.nii.gz\"], -1]\n\nTEST_CASE_3 = [{}, [\"test_image.nii.gz\", \"test_image2.nii.gz\", \"test_image3.nii.gz\"], None]\n\nTEST_CASE_4 = [{\"reader\": ITKReader() if has_itk else \"itkreader\"}, [\"test_image.nii.gz\"], None]\n\nTEST_CASE_5 = [{\"reader\": ITKReader() if has_itk else \"itkreader\"}, [\"test_image.nii.gz\"], -1]\n\nTEST_CASE_6 = [\n    {\"reader\": ITKReader() if has_itk else \"itkreader\"},\n    [\"test_image.nii.gz\", \"test_image2.nii.gz\", \"test_image3.nii.gz\"],\n    None,\n]\n\n\nclass TestEnsureChannelFirst(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])\n    @unittest.skipUnless(has_itk, \"itk not installed\")\n    def test_load_nifti(self, input_param, filenames, 
original_channel_dim):\n        # if original_channel_dim is None\n        test_image = np.random.rand(8, 8, 8)\n\n        if original_channel_dim == -1:\n            test_image = np.random.rand(8, 8, 8, 1)\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            for i, name in enumerate(filenames):\n                filenames[i] = os.path.join(tempdir, name)\n                nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])\n\n            result = LoadImage(image_only=True, **input_param)(filenames)\n            result = EnsureChannelFirst()(result)\n            self.assertEqual(result.shape[0], len(filenames))\n\n    @unittest.skipUnless(has_itk, \"itk not installed\")\n    def test_itk_dicom_series_reader(self):\n        filenames = \"tests/testing_data/CT_DICOM\"\n        itk.ProcessObject.SetGlobalWarningDisplay(False)\n        result = LoadImage(image_only=True, reader=ITKReader(pixel_type=itk.UC))(filenames)\n        result = EnsureChannelFirst()(result)\n        self.assertEqual(result.shape[0], 1)\n\n    def test_load_png(self):\n        spatial_size = (6, 6, 3)\n        test_image = np.random.randint(0, 6, size=spatial_size)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_image.png\")\n            Image.fromarray(test_image.astype(\"uint8\")).save(filename)\n            result = LoadImage(image_only=True)(filename)\n            result = EnsureChannelFirst()(result)\n            self.assertEqual(result.shape[0], 3)\n            result = EnsureChannelFirst(channel_dim=-1)(result)\n            self.assertEqual(result.shape, (6, 3, 6))\n\n    def test_check(self):\n        im = torch.zeros(1, 2, 3)\n        im_nodim = MetaTensor(im, meta={\"original_channel_dim\": None})\n\n        with self.assertRaises(ValueError):  # not MetaTensor\n            EnsureChannelFirst(channel_dim=None)(im)\n        with self.assertRaises(ValueError):  # no meta\n            
EnsureChannelFirst(channel_dim=None)(MetaTensor(im))\n        with self.assertRaises(ValueError):  # no meta channel\n            EnsureChannelFirst()(im_nodim)\n\n        with self.assertWarns(Warning):\n            EnsureChannelFirst(strict_check=False, channel_dim=None)(im)\n\n        with self.assertWarns(Warning):\n            EnsureChannelFirst(strict_check=False, channel_dim=None)(im_nodim)\n\n    def test_default_channel_first(self):\n        im = torch.rand(4, 4)\n        result = EnsureChannelFirst(channel_dim=\"no_channel\")(im)\n\n        self.assertEqual(result.shape, (1, 4, 4))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_ensure_channel_firstd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom PIL import Image\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import EnsureChannelFirstd, LoadImaged\n\nTEST_CASE_1 = [{\"keys\": \"img\"}, [\"test_image.nii.gz\"], None]\n\nTEST_CASE_2 = [{\"keys\": \"img\"}, [\"test_image.nii.gz\"], -1]\n\nTEST_CASE_3 = [{\"keys\": \"img\"}, [\"test_image.nii.gz\", \"test_image2.nii.gz\", \"test_image3.nii.gz\"], None]\n\n\nclass TestEnsureChannelFirstd(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_load_nifti(self, input_param, filenames, original_channel_dim):\n        # if original_channel_dim is None:\n        test_image = np.random.rand(8, 8, 8)\n\n        if original_channel_dim == -1:\n            test_image = np.random.rand(8, 8, 8, 1)\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            for i, name in enumerate(filenames):\n                filenames[i] = os.path.join(tempdir, name)\n                nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])\n            result = LoadImaged(**input_param)({\"img\": filenames})\n            result = EnsureChannelFirstd(**input_param)(result)\n            
self.assertEqual(result[\"img\"].shape[0], len(filenames))\n\n    def test_load_png(self):\n        spatial_size = (6, 6, 3)\n        test_image = np.random.randint(0, 256, size=spatial_size)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_image.png\")\n            Image.fromarray(test_image.astype(\"uint8\")).save(filename)\n            result = LoadImaged(keys=\"img\")({\"img\": filename})\n            result = EnsureChannelFirstd(keys=\"img\")(result)\n            self.assertEqual(result[\"img\"].shape[0], 3)\n\n    def test_exceptions(self):\n        im = torch.zeros((1, 2, 3))\n        im_nodim = MetaTensor(im, meta={\"original_channel_dim\": None})\n\n        with self.assertRaises(ValueError):  # no meta\n            EnsureChannelFirstd(\"img\", channel_dim=None)({\"img\": im})\n        with self.assertRaises(ValueError):  # no meta channel\n            EnsureChannelFirstd(\"img\", channel_dim=None)({\"img\": im_nodim})\n\n        with self.assertWarns(Warning):\n            EnsureChannelFirstd(\"img\", strict_check=False, channel_dim=None)({\"img\": im})\n\n        with self.assertWarns(Warning):\n            EnsureChannelFirstd(\"img\", strict_check=False, channel_dim=None)({\"img\": im_nodim})\n\n    def test_default_channel_first(self):\n        im = torch.rand(4, 4)\n        result = EnsureChannelFirstd(\"img\", channel_dim=\"no_channel\")({\"img\": im})\n\n        self.assertEqual(result[\"img\"].shape, (1, 4, 4))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_ensure_type.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.data import MetaTensor\nfrom monai.transforms import EnsureType\nfrom tests.test_utils import assert_allclose\n\n\nclass TestEnsureType(unittest.TestCase):\n    def test_array_input(self):\n        test_datas = [np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])]\n        if torch.cuda.is_available():\n            test_datas.append(test_datas[-1].cuda())\n        for test_data in test_datas:\n            for dtype in (\"tensor\", \"NUMPY\"):\n                result = EnsureType(dtype, dtype=np.float32 if dtype == \"NUMPY\" else None, device=\"cpu\")(test_data)\n                if dtype == \"NUMPY\":\n                    self.assertTrue(result.dtype == np.float32)\n                self.assertTrue(isinstance(result, torch.Tensor if dtype == \"tensor\" else np.ndarray))\n                assert_allclose(result, test_data, type_test=False)\n                self.assertTupleEqual(result.shape, (2, 2))\n\n    def test_single_input(self):\n        test_datas = [5, 5.0, False, np.asarray(5), torch.tensor(5)]\n        if torch.cuda.is_available():\n            test_datas.append(test_datas[-1].cuda())\n        for test_data in test_datas:\n            for dtype in (\"tensor\", \"numpy\"):\n                result = EnsureType(data_type=dtype, device=\"cpu\")(test_data)\n  
              self.assertTrue(isinstance(result, torch.Tensor if dtype == \"tensor\" else np.ndarray))\n                if isinstance(test_data, bool):\n                    self.assertFalse(result)\n                else:\n                    assert_allclose(result, test_data, type_test=False)\n                self.assertEqual(result.ndim, 0)\n\n    def test_string(self):\n        for dtype in (\"tensor\", \"numpy\"):\n            # string input\n            result = EnsureType(data_type=dtype)(\"test_string\")\n            self.assertTrue(isinstance(result, str))\n            self.assertEqual(result, \"test_string\")\n            # numpy array of string\n            result = EnsureType(data_type=dtype)(np.array([\"test_string0\", \"test_string1\"]))\n            self.assertTrue(isinstance(result, np.ndarray))\n            self.assertEqual(result[1], \"test_string1\")\n\n    def test_list_tuple(self):\n        for dtype in (\"tensor\", \"numpy\"):\n            result = EnsureType(data_type=dtype, wrap_sequence=False, track_meta=True)([[1, 2], [3, 4]])\n            self.assertTrue(isinstance(result, list))\n            self.assertTrue(isinstance(result[0][1], MetaTensor if dtype == \"tensor\" else np.ndarray))\n            assert_allclose(result[1][0], torch.as_tensor(3), type_test=False)\n            # tuple of numpy arrays\n            result = EnsureType(data_type=dtype, wrap_sequence=False)((np.array([1, 2]), np.array([3, 4])))\n            self.assertTrue(isinstance(result, tuple))\n            self.assertTrue(isinstance(result[0], torch.Tensor if dtype == \"tensor\" else np.ndarray))\n            assert_allclose(result[1], torch.as_tensor([3, 4]), type_test=False)\n\n    def test_dict(self):\n        # simulate complicated input data\n        test_data = {\n            \"img\": np.array([1.0, 2.0], dtype=np.float32),\n            \"meta\": {\"dims\": 3, \"size\": np.array([1, 2, 3]), \"path\": \"temp/test\"},\n            \"extra\": None,\n        }\n        
for dtype in (\"tensor\", \"numpy\"):\n            result = EnsureType(data_type=dtype, track_meta=False)(test_data)\n            self.assertTrue(isinstance(result, dict))\n            self.assertTrue(isinstance(result[\"img\"], torch.Tensor if dtype == \"tensor\" else np.ndarray))\n            assert_allclose(result[\"img\"], torch.as_tensor([1.0, 2.0]), type_test=False)\n            self.assertTrue(isinstance(result[\"meta\"][\"size\"], torch.Tensor if dtype == \"tensor\" else np.ndarray))\n            assert_allclose(result[\"meta\"][\"size\"], torch.as_tensor([1, 2, 3]), type_test=False)\n            self.assertEqual(result[\"meta\"][\"path\"], \"temp/test\")\n            self.assertEqual(result[\"extra\"], None)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_ensure_typed.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.data import MetaTensor\nfrom monai.transforms import EnsureTyped\nfrom tests.test_utils import assert_allclose\n\n\nclass TestEnsureTyped(unittest.TestCase):\n    def test_array_input(self):\n        test_datas = [np.array([[1, 2], [3, 4]]), torch.as_tensor([[1, 2], [3, 4]])]\n        if torch.cuda.is_available():\n            test_datas.append(test_datas[-1].cuda())\n        for test_data in test_datas:\n            for dtype in (\"tensor\", \"NUMPY\"):\n                result = EnsureTyped(\n                    keys=\"data\", data_type=dtype, dtype=np.float32 if dtype == \"NUMPY\" else None, device=\"cpu\"\n                )({\"data\": test_data})[\"data\"]\n                if dtype == \"NUMPY\":\n                    self.assertEqual(result.dtype, np.float32)\n                self.assertIsInstance(result, torch.Tensor if dtype == \"tensor\" else np.ndarray)\n                assert_allclose(result, test_data, type_test=False)\n                self.assertTupleEqual(result.shape, (2, 2))\n\n    def test_single_input(self):\n        test_datas = [5, 5.0, False, np.asarray(5), torch.tensor(5)]\n        if torch.cuda.is_available():\n            test_datas.append(test_datas[-1].cuda())\n        for test_data in test_datas:\n            for dtype in (\"tensor\", 
\"numpy\"):\n                result = EnsureTyped(keys=\"data\", data_type=dtype)({\"data\": test_data})[\"data\"]\n                self.assertIsInstance(result, torch.Tensor if dtype == \"tensor\" else np.ndarray)\n                if isinstance(test_data, bool):\n                    self.assertFalse(result)\n                else:\n                    assert_allclose(result, test_data, type_test=False)\n                self.assertEqual(result.ndim, 0)\n\n    def test_string(self):\n        for dtype in (\"tensor\", \"numpy\"):\n            # string input\n            result = EnsureTyped(keys=\"data\", data_type=dtype)({\"data\": \"test_string\"})[\"data\"]\n            self.assertIsInstance(result, str)\n            self.assertEqual(result, \"test_string\")\n            # numpy array of string\n            result = EnsureTyped(keys=\"data\", data_type=dtype)({\"data\": np.array([\"test_string\"])})[\"data\"]\n            self.assertIsInstance(result, np.ndarray)\n            self.assertEqual(result[0], \"test_string\")\n\n    def test_list_tuple(self):\n        for dtype in (\"tensor\", \"numpy\"):\n            result = EnsureTyped(keys=\"data\", data_type=dtype, wrap_sequence=False, track_meta=True)(\n                {\"data\": [[1, 2], [3, 4]]}\n            )[\"data\"]\n            self.assertIsInstance(result, list)\n            self.assertIsInstance(result[0][1], MetaTensor if dtype == \"tensor\" else np.ndarray)\n            assert_allclose(result[1][0], torch.as_tensor(3), type_test=False)\n            # tuple of numpy arrays\n            result = EnsureTyped(keys=\"data\", data_type=dtype, wrap_sequence=False)(\n                {\"data\": (np.array([1, 2]), np.array([3, 4]))}\n            )[\"data\"]\n            self.assertIsInstance(result, tuple)\n            self.assertIsInstance(result[0], torch.Tensor if dtype == \"tensor\" else np.ndarray)\n            assert_allclose(result[1], torch.as_tensor([3, 4]), type_test=False)\n\n    def test_dict(self):\n  
      # simulate complicated input data\n        test_data = {\n            \"img\": np.array([1.0, 2.0], dtype=np.float32),\n            \"meta\": {\"dims\": 3, \"size\": np.array([1, 2, 3]), \"path\": \"temp/test\"},\n            \"extra\": None,\n        }\n        for dtype in (\"tensor\", \"numpy\"):\n            trans = EnsureTyped(keys=[\"data\", \"label\"], data_type=dtype, dtype=[np.float32, np.int8], device=\"cpu\")(\n                {\"data\": test_data, \"label\": test_data}\n            )\n            for key in (\"data\", \"label\"):\n                result = trans[key]\n                self.assertIsInstance(result, dict)\n                self.assertIsInstance(result[\"img\"], torch.Tensor if dtype == \"tensor\" else np.ndarray)\n                self.assertIsInstance(result[\"meta\"][\"size\"], torch.Tensor if dtype == \"tensor\" else np.ndarray)\n                self.assertEqual(result[\"meta\"][\"path\"], \"temp/test\")\n                self.assertEqual(result[\"extra\"], None)\n                assert_allclose(result[\"img\"], torch.as_tensor([1.0, 2.0]), type_test=False)\n                assert_allclose(result[\"meta\"][\"size\"], torch.as_tensor([1, 2, 3]), type_test=False)\n            if dtype == \"numpy\":\n                self.assertEqual(trans[\"data\"][\"img\"].dtype, np.float32)\n                self.assertEqual(trans[\"label\"][\"img\"].dtype, np.int8)\n            else:\n                self.assertEqual(trans[\"data\"][\"img\"].dtype, torch.float32)\n                self.assertEqual(trans[\"label\"][\"img\"].dtype, torch.int8)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_fg_bg_to_indices.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import FgBgToIndices\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS_CASES = []\nfor p in TEST_NDARRAYS:\n    TESTS_CASES.append(\n        [\n            {\"image_threshold\": 0.0, \"output_shape\": None},\n            p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),\n            None,\n            p([1, 2, 3, 5, 6, 7]),\n            p([0, 4, 8]),\n        ]\n    )\n\n    TESTS_CASES.append(\n        [\n            {\"image_threshold\": 0.0, \"output_shape\": None},\n            p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),\n            p([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]),\n            p([1, 2, 3, 5, 6, 7]),\n            p([0, 8]),\n        ]\n    )\n\n    TESTS_CASES.append(\n        [\n            {\"image_threshold\": 1.0, \"output_shape\": None},\n            p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),\n            p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),\n            p([1, 2, 3, 5, 6, 7]),\n            p([0, 8]),\n        ]\n    )\n\n    TESTS_CASES.append(\n        [\n            {\"image_threshold\": 1.0, \"output_shape\": None},\n            p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]),\n            p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),\n            p([1, 2, 3, 5, 6, 7]),\n            p([0, 8]),\n        ]\n    )\n\n    TESTS_CASES.append(\n   
     [\n            {\"image_threshold\": 1.0, \"output_shape\": [3, 3]},\n            p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]),\n            p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]),\n            p([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]),\n            p([[0, 0], [2, 2]]),\n        ]\n    )\n\n\nclass TestFgBgToIndices(unittest.TestCase):\n    @parameterized.expand(TESTS_CASES)\n    def test_type_shape(self, input_data, label, image, expected_fg, expected_bg):\n        fg_indices, bg_indices = FgBgToIndices(**input_data)(label, image)\n        assert_allclose(fg_indices, expected_fg)\n        assert_allclose(bg_indices, expected_bg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_fg_bg_to_indicesd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import FgBgToIndicesd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTEST_CASES = []\nfor p in TEST_NDARRAYS:\n    TEST_CASES.append(\n        [\n            {\"keys\": \"label\", \"image_key\": None, \"image_threshold\": 0.0, \"output_shape\": None},\n            {\"label\": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])},\n            p([1, 2, 3, 5, 6, 7]),\n            p([0, 4, 8]),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"keys\": \"label\", \"image_key\": \"image\", \"image_threshold\": 0.0, \"output_shape\": None},\n            {\"label\": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), \"image\": p([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])},\n            p([1, 2, 3, 5, 6, 7]),\n            p([0, 8]),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"keys\": \"label\", \"image_key\": \"image\", \"image_threshold\": 1.0, \"output_shape\": None},\n            {\"label\": p([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), \"image\": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])},\n            p([1, 2, 3, 5, 6, 7]),\n            p([0, 8]),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"keys\": \"label\", \"image_key\": \"image\", \"image_threshold\": 1.0, \"output_shape\": None},\n            
{\"label\": p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), \"image\": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])},\n            p([1, 2, 3, 5, 6, 7]),\n            p([0, 8]),\n        ]\n    )\n\n    TEST_CASES.append(\n        [\n            {\"keys\": \"label\", \"image_key\": \"image\", \"image_threshold\": 1.0, \"output_shape\": [3, 3]},\n            {\"label\": p([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), \"image\": p([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])},\n            p([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]),\n            p([[0, 0], [2, 2]]),\n        ]\n    )\n\n\nclass TestFgBgToIndicesd(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_type_shape(self, input_data, data, expected_fg, expected_bg):\n        result = FgBgToIndicesd(**input_data)(data)\n        assert_allclose(result[\"label_fg_indices\"], expected_fg)\n        assert_allclose(result[\"label_bg_indices\"], expected_bg)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_fill_holes.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import FillHoles\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose, clone\n\ngrid_1_raw = [[1, 1, 1], [1, 0, 1], [1, 1, 1]]\n\ngrid_2_raw = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]\n\ngrid_3_raw = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]\n\ngrid_4_raw = [[0, 1, 0], [1, 1, 1], [0, 1, 0]]\n\ngrid_1 = torch.tensor([grid_1_raw])\n\ngrid_2 = torch.tensor([grid_2_raw])\n\ngrid_3 = torch.tensor([grid_3_raw])\n\ngrid_4 = torch.tensor([grid_4_raw])\n\ngrid_5 = torch.tensor([[[1, 1, 1], [1, 0, 0], [1, 1, 1]]])\n\ngrid_6 = torch.tensor([[[1, 1, 2, 2, 2], [1, 0, 2, 0, 2], [1, 1, 2, 2, 2]]])\n\ngrid_7 = torch.tensor([[[1, 1, 2, 2, 2], [1, 0, 2, 2, 2], [1, 1, 2, 2, 2]]])\n\nTEST_CASE_0 = [\"enclosed_default_full_connectivity_default_applied_labels\", {}, grid_1, grid_3]\n\nTEST_CASE_1 = [\"enclosed_full_connectivity_default_applied_labels\", {\"connectivity\": 2}, grid_1, grid_3]\n\nTEST_CASE_2 = [\n    \"enclosed_full_connectivity_applied_labels_same_single\",\n    {\"connectivity\": 2, \"applied_labels\": 1},\n    grid_1,\n    grid_3,\n]\n\nTEST_CASE_3 = [\n    \"enclosed_full_connectivity_applied_labels_same_list\",\n    {\"connectivity\": 2, \"applied_labels\": [1]},\n    grid_1,\n    grid_3,\n]\n\nTEST_CASE_4 = [\n    
\"enclosed_full_connectivity_applied_labels_other_single\",\n    {\"connectivity\": 2, \"applied_labels\": 2},\n    grid_1,\n    grid_1,\n]\n\nTEST_CASE_5 = [\n    \"enclosed_full_connectivity_applied_labels_other_list\",\n    {\"connectivity\": 2, \"applied_labels\": [2]},\n    grid_1,\n    grid_1,\n]\n\nTEST_CASE_6 = [\n    \"enclosed_full_connectivity_applied_labels_same_and_other\",\n    {\"connectivity\": 2, \"applied_labels\": [1, 2]},\n    grid_1,\n    grid_3,\n]\n\nTEST_CASE_7 = [\"enclosed_connectivity_1_default_applied_labels\", {\"connectivity\": 1}, grid_1, grid_3]\n\nTEST_CASE_8 = [\"enclosed_connectivity_1_default_applied_labels\", {\"connectivity\": 1}, grid_2, grid_4]\n\nTEST_CASE_9 = [\"open_full_connectivity_default_applied_labels\", {\"connectivity\": 2}, grid_2, grid_2]\n\nTEST_CASE_10 = [\"open_to_edge_connectivity_1_default_applied_labels\", {\"connectivity\": 1}, grid_5, grid_5]\n\nTEST_CASE_11 = [\"open_to_other_label_connectivity_1_default_applied_labels\", {\"connectivity\": 1}, grid_6, grid_7]\n\nTEST_CASE_12 = [\n    \"open_to_other_label_connectivity_1_applied_labels_other\",\n    {\"connectivity\": 1, \"applied_labels\": 1},\n    grid_6,\n    grid_6,\n]\n\nTEST_CASE_13 = [\n    \"numpy_enclosed_default_full_connectivity_default_applied_labels\",\n    {},\n    grid_1.cpu().numpy(),\n    grid_3.cpu().numpy(),\n]\n\nTEST_CASE_14 = [\n    \"3D_enclosed_full_connectivity_default_applied_labels\",\n    {\"connectivity\": 3},\n    torch.tensor([[grid_3_raw, grid_1_raw, grid_3_raw]]),\n    torch.tensor([[grid_3_raw, grid_3_raw, grid_3_raw]]),\n]\n\nTEST_CASE_15 = [\n    \"3D_enclosed_connectivity_1_default_applied_labels\",\n    {\"connectivity\": 1},\n    torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]),\n    torch.tensor([[grid_4_raw, grid_4_raw, grid_4_raw]]),\n]\n\nTEST_CASE_16 = [\n    \"3D_open_full_connectivity_default_applied_labels\",\n    {\"connectivity\": 3},\n    torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]),\n    
torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]),\n]\n\nTEST_CASE_17 = [\n    \"3D_open_to_edge_connectivity_1_default_applied_labels\",\n    {\"connectivity\": 1},\n    torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]),\n    torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]),\n]\n\nTEST_CASE_18 = [\n    \"enclosed_full_connectivity_applied_labels_with_background\",\n    {\"connectivity\": 2, \"applied_labels\": [0, 1]},\n    grid_1,\n    grid_3,\n]\n\nTEST_CASE_19 = [\n    \"enclosed_full_connectivity_applied_labels_only_background\",\n    {\"connectivity\": 2, \"applied_labels\": [0]},\n    grid_1,\n    grid_1,\n]\n\nTEST_CASE_20 = [\n    \"one-hot_enclosed_connectivity_1_default_applied_labels\",\n    {\"connectivity\": 1},\n    torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]),\n    torch.tensor([grid_1_raw, grid_3_raw, grid_4_raw]),\n]\n\nTEST_CASE_21 = [\n    \"one-hot_enclosed_connectivity_1_applied_labels_2\",\n    {\"connectivity\": 1, \"applied_labels\": [2]},\n    torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]),\n    torch.tensor([grid_1_raw, grid_1_raw, grid_4_raw]),\n]\n\nTEST_CASE_22 = [\n    \"one-hot_full_connectivity_applied_labels_2\",\n    {\"connectivity\": 2},\n    torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]),\n    torch.tensor([grid_1_raw, grid_3_raw, grid_2_raw]),\n]\n\nVALID_CASES = [\n    TEST_CASE_0,\n    TEST_CASE_1,\n    TEST_CASE_2,\n    TEST_CASE_3,\n    TEST_CASE_4,\n    TEST_CASE_5,\n    TEST_CASE_6,\n    TEST_CASE_7,\n    TEST_CASE_8,\n    TEST_CASE_9,\n    TEST_CASE_10,\n    TEST_CASE_11,\n    TEST_CASE_12,\n    TEST_CASE_13,\n    TEST_CASE_14,\n    TEST_CASE_15,\n    TEST_CASE_16,\n    TEST_CASE_17,\n    TEST_CASE_18,\n    TEST_CASE_19,\n    TEST_CASE_20,\n    TEST_CASE_21,\n    TEST_CASE_22,\n]\n\n\nclass TestFillHoles(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_results(self, _, args, input_image, expected):\n        converter = FillHoles(**args)\n        for p in 
TEST_NDARRAYS:\n            result = converter(p(clone(input_image)))\n            assert_allclose(result, p(expected), type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_fill_holesd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import FillHolesd\nfrom monai.utils.enums import CommonKeys\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose, clone\n\ngrid_1_raw = [[1, 1, 1], [1, 0, 1], [1, 1, 1]]\n\ngrid_2_raw = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]\n\ngrid_3_raw = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]\n\ngrid_4_raw = [[0, 1, 0], [1, 1, 1], [0, 1, 0]]\n\ngrid_1 = torch.tensor([grid_1_raw])\n\ngrid_2 = torch.tensor([grid_2_raw])\n\ngrid_3 = torch.tensor([grid_3_raw])\n\ngrid_4 = torch.tensor([grid_4_raw])\n\ngrid_5 = torch.tensor([[[1, 1, 1], [1, 0, 0], [1, 1, 1]]])\n\ngrid_6 = torch.tensor([[[1, 1, 2, 2, 2], [1, 0, 2, 0, 2], [1, 1, 2, 2, 2]]])\n\ngrid_7 = torch.tensor([[[1, 1, 2, 2, 2], [1, 0, 2, 2, 2], [1, 1, 2, 2, 2]]])\n\nTEST_CASE_0 = [\"enclosed_default_full_connectivity_default_applied_labels\", {}, grid_1, grid_3]\n\nTEST_CASE_1 = [\"enclosed_full_connectivity_default_applied_labels\", {\"connectivity\": 2}, grid_1, grid_3]\n\nTEST_CASE_2 = [\n    \"enclosed_full_connectivity_applied_labels_same_single\",\n    {\"connectivity\": 2, \"applied_labels\": 1},\n    grid_1,\n    grid_3,\n]\n\nTEST_CASE_3 = [\n    \"enclosed_full_connectivity_applied_labels_same_list\",\n    {\"connectivity\": 2, \"applied_labels\": [1]},\n    grid_1,\n    
grid_3,\n]\n\nTEST_CASE_4 = [\n    \"enclosed_full_connectivity_applied_labels_other_single\",\n    {\"connectivity\": 2, \"applied_labels\": 2},\n    grid_1,\n    grid_1,\n]\n\nTEST_CASE_5 = [\n    \"enclosed_full_connectivity_applied_labels_other_list\",\n    {\"connectivity\": 2, \"applied_labels\": [2]},\n    grid_1,\n    grid_1,\n]\n\nTEST_CASE_6 = [\n    \"enclosed_full_connectivity_applied_labels_same_and_other\",\n    {\"connectivity\": 2, \"applied_labels\": [1, 2]},\n    grid_1,\n    grid_3,\n]\n\nTEST_CASE_7 = [\"enclosed_connectivity_1_default_applied_labels\", {\"connectivity\": 1}, grid_1, grid_3]\n\nTEST_CASE_8 = [\"enclosed_connectivity_1_default_applied_labels\", {\"connectivity\": 1}, grid_2, grid_4]\n\nTEST_CASE_9 = [\"open_full_connectivity_default_applied_labels\", {\"connectivity\": 2}, grid_2, grid_2]\n\nTEST_CASE_10 = [\"open_to_edge_connectivity_1_default_applied_labels\", {\"connectivity\": 1}, grid_5, grid_5]\n\nTEST_CASE_11 = [\"open_to_other_label_connectivity_1_default_applied_labels\", {\"connectivity\": 1}, grid_6, grid_7]\n\nTEST_CASE_12 = [\n    \"open_to_other_label_connectivity_1_applied_labels_other\",\n    {\"connectivity\": 1, \"applied_labels\": 1},\n    grid_6,\n    grid_6,\n]\n\nTEST_CASE_13 = [\n    \"numpy_enclosed_default_full_connectivity_default_applied_labels\",\n    {},\n    grid_1.cpu().numpy(),\n    grid_3.cpu().numpy(),\n]\n\nTEST_CASE_14 = [\n    \"3D_enclosed_full_connectivity_default_applied_labels\",\n    {\"connectivity\": 3},\n    torch.tensor([[grid_3_raw, grid_1_raw, grid_3_raw]]),\n    torch.tensor([[grid_3_raw, grid_3_raw, grid_3_raw]]),\n]\n\nTEST_CASE_15 = [\n    \"3D_enclosed_connectivity_1_default_applied_labels\",\n    {\"connectivity\": 1},\n    torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]),\n    torch.tensor([[grid_4_raw, grid_4_raw, grid_4_raw]]),\n]\n\nTEST_CASE_16 = [\n    \"3D_open_full_connectivity_default_applied_labels\",\n    {\"connectivity\": 3},\n    torch.tensor([[grid_4_raw, 
grid_2_raw, grid_4_raw]]),\n    torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]),\n]\n\nTEST_CASE_17 = [\n    \"3D_open_to_edge_connectivity_1_default_applied_labels\",\n    {\"connectivity\": 1},\n    torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]),\n    torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]),\n]\n\nTEST_CASE_18 = [\n    \"enclosed_full_connectivity_applied_labels_with_background\",\n    {\"connectivity\": 2, \"applied_labels\": [0, 1]},\n    grid_1,\n    grid_3,\n]\n\nTEST_CASE_19 = [\n    \"enclosed_full_connectivity_applied_labels_only_background\",\n    {\"connectivity\": 2, \"applied_labels\": [0]},\n    grid_1,\n    grid_1,\n]\n\nTEST_CASE_20 = [\n    \"one-hot_enclosed_connectivity_1_default_applied_labels\",\n    {\"connectivity\": 1},\n    torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]),\n    torch.tensor([grid_1_raw, grid_3_raw, grid_4_raw]),\n]\n\nTEST_CASE_21 = [\n    \"one-hot_enclosed_connectivity_1_applied_labels_2\",\n    {\"connectivity\": 1, \"applied_labels\": [2]},\n    torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]),\n    torch.tensor([grid_1_raw, grid_1_raw, grid_4_raw]),\n]\n\nTEST_CASE_22 = [\n    \"one-hot_full_connectivity_applied_labels_2\",\n    {\"connectivity\": 2},\n    torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]),\n    torch.tensor([grid_1_raw, grid_3_raw, grid_2_raw]),\n]\n\nVALID_CASES = [\n    TEST_CASE_0,\n    TEST_CASE_1,\n    TEST_CASE_2,\n    TEST_CASE_3,\n    TEST_CASE_4,\n    TEST_CASE_5,\n    TEST_CASE_6,\n    TEST_CASE_7,\n    TEST_CASE_8,\n    TEST_CASE_9,\n    TEST_CASE_10,\n    TEST_CASE_11,\n    TEST_CASE_12,\n    TEST_CASE_13,\n    TEST_CASE_14,\n    TEST_CASE_15,\n    TEST_CASE_16,\n    TEST_CASE_17,\n    TEST_CASE_18,\n    TEST_CASE_19,\n    TEST_CASE_20,\n    TEST_CASE_21,\n    TEST_CASE_22,\n]\n\n\nclass TestFillHoles(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_results(self, _, args, input_image, expected):\n        key = CommonKeys.IMAGE\n  
      converter = FillHolesd(keys=key, **args)\n        for p in TEST_NDARRAYS:\n            result = converter({key: p(clone(input_image))})[key]\n            assert_allclose(result, p(expected), type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_flatten_sub_keysd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import FlattenSubKeysd\n\nA = torch.randn(2, 2)\nB = torch.randn(3, 3)\nC = torch.randn(1, 3)\nI = torch.randn(2, 3)\nD1 = {\"a\": A, \"b\": B}\nD2 = {\"a\": A, \"b\": B, \"c\": C}\n\nTEST_CASE_0 = [{\"keys\": \"pred\"}, {\"image\": I, \"pred\": D1}, {\"a\": A, \"b\": B, \"image\": I}]\nTEST_CASE_1 = [{\"keys\": \"pred\"}, {\"image\": I, \"pred\": D2}, {\"a\": A, \"b\": B, \"c\": C, \"image\": I}]\nTEST_CASE_2 = [{\"keys\": \"pred\", \"sub_keys\": [\"a\", \"b\"]}, {\"image\": I, \"pred\": D1}, {\"a\": A, \"b\": B, \"image\": I}]\nTEST_CASE_3 = [{\"keys\": \"pred\", \"sub_keys\": [\"a\", \"b\"]}, {\"image\": I, \"pred\": D2}, {\"a\": A, \"b\": B, \"image\": I}]\nTEST_CASE_4 = [\n    {\"keys\": \"pred\", \"sub_keys\": [\"a\", \"b\"], \"delete_keys\": False},\n    {\"image\": I, \"pred\": D1},\n    {\"a\": A, \"b\": B, \"image\": I, \"pred\": D1},\n]\nTEST_CASE_5 = [\n    {\"keys\": \"pred\", \"sub_keys\": [\"a\", \"b\"], \"prefix\": \"new\"},\n    {\"image\": I, \"pred\": D2},\n    {\"new_a\": A, \"new_b\": B, \"image\": I},\n]\nTEST_CASE_ERROR_1 = [  # error for duplicate key\n    {\"keys\": \"pred\", \"sub_keys\": [\"a\", \"b\"]},\n    {\"image\": I, \"pred\": D2, \"a\": None},\n]\n\n\nclass TestFlattenSubKeysd(unittest.TestCase):\n\n    
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])\n    def test_dict(self, params, input_data, expected):\n        result = FlattenSubKeysd(**params)(input_data)\n        self.assertSetEqual(set(result.keys()), set(expected.keys()))\n        for k in expected:\n            self.assertEqual(id(result[k]), id(expected[k]))\n\n    @parameterized.expand([TEST_CASE_ERROR_1])\n    def test_error(self, params, input_data):\n        with self.assertRaises(ValueError):\n            FlattenSubKeysd(**params)(input_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_flip.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_obj import set_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import Flip\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import (\n    TEST_DEVICES,\n    TEST_NDARRAYS_ALL,\n    NumpyImageTestCase2D,\n    assert_allclose,\n    test_local_inversion,\n)\n\nINVALID_CASES = [(\"wrong_axis\", [\"s\", 1], TypeError), (\"not_numbers\", \"s\", TypeError)]\n\nVALID_CASES = [(\"no_axis\", None), (\"one_axis\", 1), (\"many_axis\", [0, 1]), (\"negative_axis\", [0, -1])]\n\nTORCH_CASES = []\nfor track_meta in (False, True):\n    for device in TEST_DEVICES:\n        TORCH_CASES.append([[0, 1], torch.zeros((1, 3, 2)), track_meta, *device])\n\n\nclass TestFlip(NumpyImageTestCase2D):\n    @parameterized.expand(INVALID_CASES)\n    def test_invalid_inputs(self, _, spatial_axis, raises):\n        with self.assertRaises(raises):\n            flip = Flip(spatial_axis)\n            flip(self.imt[0])\n\n    @parameterized.expand(VALID_CASES)\n    def test_correct_results(self, _, spatial_axis):\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            init_param = {\"spatial_axis\": spatial_axis}\n            flip = Flip(**init_param)\n            
expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            call_param = {\"img\": im}\n            result = flip(**call_param)\n            test_resampler_lazy(flip, result, init_param, call_param)\n            assert_allclose(result, p(expected), type_test=\"tensor\")\n            test_local_inversion(flip, result, im)\n\n    @parameterized.expand(TORCH_CASES)\n    def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device):\n        set_track_meta(track_meta)\n        img = img.to(device)\n        init_param = {\"spatial_axis\": spatial_axis}\n        xform = Flip(**init_param)\n        call_param = {\"img\": img}\n        res = xform(**call_param)  # type: ignore[arg-type]\n        self.assertEqual(img.shape, res.shape)\n        if track_meta:\n            test_resampler_lazy(xform, res, init_param, call_param)\n            self.assertIsInstance(res, MetaTensor)\n        else:\n            self.assertNotIsInstance(res, MetaTensor)\n            self.assertIsInstance(res, torch.Tensor)\n            with self.assertRaisesRegex(ValueError, \"MetaTensor\"):\n                xform.inverse(res)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_flipd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai import config\nfrom monai.data.meta_obj import set_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import Flipd\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import (\n    TEST_DEVICES,\n    TEST_NDARRAYS_ALL,\n    NumpyImageTestCase2D,\n    assert_allclose,\n    test_local_inversion,\n)\n\nINVALID_CASES = [(\"wrong_axis\", [\"s\", 1], TypeError), (\"not_numbers\", \"s\", TypeError)]\n\nVALID_CASES = [(\"no_axis\", None), (\"one_axis\", 1), (\"many_axis\", [0, 1])]\n\nTORCH_CASES = []\nfor track_meta in (False, True):\n    for device in TEST_DEVICES:\n        TORCH_CASES.append([[0, 1], torch.zeros((1, 3, 2)), track_meta, *device])\n\n\nclass TestFlipd(NumpyImageTestCase2D):\n    @parameterized.expand(INVALID_CASES)\n    def test_invalid_cases(self, _, spatial_axis, raises):\n        with self.assertRaises(raises):\n            flip = Flipd(keys=\"img\", spatial_axis=spatial_axis)\n            flip({\"img\": self.imt[0]})\n\n    @parameterized.expand(VALID_CASES)\n    def test_correct_results(self, _, spatial_axis):\n        for p in TEST_NDARRAYS_ALL:\n            init_param = {\"keys\": \"img\", \"spatial_axis\": spatial_axis}\n            flip = 
Flipd(**init_param)\n            expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            im = p(self.imt[0])\n            call_param = {\"data\": {\"img\": im}}\n            result = flip(**call_param)\n            test_resampler_lazy(flip, result, init_param, call_param, output_key=\"img\")\n            assert_allclose(result[\"img\"], p(expected), type_test=\"tensor\")\n            test_local_inversion(flip, {\"img\": result[\"img\"]}, {\"img\": im}, \"img\")\n\n    @parameterized.expand(TORCH_CASES)\n    def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device):\n        set_track_meta(track_meta)\n        img = img.to(device)\n        init_param = {\"keys\": \"image\", \"spatial_axis\": spatial_axis}\n        xform = Flipd(**init_param)\n        call_param = {\"data\": {\"image\": img}}\n        res = xform(**call_param)  # type: ignore\n        self.assertEqual(img.shape, res[\"image\"].shape)\n        if track_meta:\n            test_resampler_lazy(xform, res, init_param, call_param, output_key=\"image\")\n            self.assertIsInstance(res[\"image\"], MetaTensor)\n        else:\n            self.assertNotIsInstance(res[\"image\"], MetaTensor)\n            self.assertIsInstance(res[\"image\"], torch.Tensor)\n            with self.assertRaisesRegex(ValueError, \"MetaTensor\"):\n                xform.inverse(res)\n\n    @unittest.skipIf(not config.USE_META_DICT, \"not using meta dict\")\n    def test_meta_dict(self):\n        xform = Flipd(\"image\", [0, 1])\n        res = xform({\"image\": torch.zeros(1, 3, 4)})\n        self.assertEqual(res[\"image\"].applied_operations, res[\"image_transforms\"])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_fourier.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.synthetic import create_test_image_2d, create_test_image_3d\nfrom monai.transforms import Fourier\nfrom monai.utils.misc import set_determinism\n\nTEST_CASES = [((128, 64),), ((64, 48, 80),)]\n\n\nclass TestFourier(unittest.TestCase):\n    def setUp(self):\n        set_determinism(0)\n        super().setUp()\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @staticmethod\n    def get_data(img_shape):\n        create_test_image = create_test_image_2d if len(img_shape) == 2 else create_test_image_3d\n        im = create_test_image(*img_shape, num_objs=4, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None]\n        return torch.Tensor(im)\n\n    @parameterized.expand(TEST_CASES)\n    def test_forward(self, img_shape):\n        n_dims = len(img_shape[1:])\n        x = self.get_data(img_shape)\n        t = Fourier()\n        out = t.shift_fourier(x, n_dims)\n\n        expect = torch.fft.fftshift(torch.fft.fftn(x, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0)))\n\n        np.testing.assert_allclose(out, expect)\n\n    @parameterized.expand(TEST_CASES)\n    def test_backward(self, img_shape):\n        n_dims = len(img_shape[1:])\n        x = self.get_data(img_shape)\n        t = Fourier()\n        out 
= t.inv_shift_fourier(x, n_dims)\n\n        expect = torch.fft.ifftn(\n            torch.fft.ifftshift(x, dim=tuple(range(-n_dims, 0))), dim=tuple(range(-n_dims, 0))\n        ).real\n\n        np.testing.assert_allclose(out, expect)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_gaussian_sharpen.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import GaussianSharpen\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\n\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {},\n            p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n            p(\n                [\n                    [\n                        [4.1081963, 3.4950666, 4.1081963],\n                        [3.7239995, 2.8491793, 3.7239995],\n                        [4.569839, 3.9529324, 4.569839],\n                    ],\n                    [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"sigma1\": 1.0, \"sigma2\": 0.75, \"alpha\": 20},\n            p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n            p(\n                [\n                    [\n                        [4.513644, 4.869134, 4.513644],\n                        [8.467242, 9.4004135, 8.467242],\n                        [10.416813, 12.0653515, 10.416813],\n                    ],\n                    [\n                        [15.711488, 17.569994, 15.711488],\n                        [21.16811, 23.501041, 
21.16811],\n                        [21.614658, 24.766209, 21.614658],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"sigma1\": (0.5, 1.0), \"sigma2\": (0.5, 0.75), \"alpha\": 20},\n            p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n            p(\n                [\n                    [\n                        [3.3324685, 3.335536, 3.3324673],\n                        [7.7666636, 8.16056, 7.7666636],\n                        [12.662973, 14.317837, 12.6629715],\n                    ],\n                    [\n                        [15.329051, 16.57557, 15.329051],\n                        [19.41665, 20.40139, 19.416655],\n                        [24.659554, 27.557873, 24.659554],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n\nclass TestGaussianSharpen(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        result = GaussianSharpen(**arguments)(image)\n        assert_allclose(result, expected_data, atol=0, rtol=1e-4, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_gaussian_sharpend.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import GaussianSharpend\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"keys\": \"img\"},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [4.1081963, 3.4950666, 4.1081963],\n                        [3.7239995, 2.8491793, 3.7239995],\n                        [4.569839, 3.9529324, 4.569839],\n                    ],\n                    [[10.616725, 9.081067, 10.616725], [9.309998, 7.12295, 9.309998], [11.078365, 9.538931, 11.078365]],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"sigma1\": 1.0, \"sigma2\": 0.75, \"alpha\": 20},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [4.513644, 4.869134, 4.513644],\n                        [8.467242, 9.4004135, 8.467242],\n                        [10.416813, 12.0653515, 10.416813],\n                    ],\n                    
[\n                        [15.711488, 17.569994, 15.711488],\n                        [21.16811, 23.501041, 21.16811],\n                        [21.614658, 24.766209, 21.614658],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"sigma1\": (0.5, 1.0), \"sigma2\": (0.5, 0.75), \"alpha\": 20},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [3.3324685, 3.335536, 3.3324673],\n                        [7.7666636, 8.16056, 7.7666636],\n                        [12.662973, 14.317837, 12.6629715],\n                    ],\n                    [\n                        [15.329051, 16.57557, 15.329051],\n                        [19.41665, 20.40139, 19.416655],\n                        [24.659554, 27.557873, 24.659554],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n\nclass TestGaussianSharpend(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        result = GaussianSharpend(**arguments)(image)\n        assert_allclose(result[\"img\"], expected_data, rtol=1e-4, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_gaussian_smooth.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import GaussianSmooth\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\n\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"sigma\": 1.5},\n            p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n            p(\n                [\n                    [\n                        [0.59167546, 0.69312394, 0.59167546],\n                        [0.7956997, 0.93213004, 0.7956997],\n                        [0.7668002, 0.8982755, 0.7668002],\n                    ],\n                    [\n                        [1.6105323, 1.8866735, 1.6105323],\n                        [1.9892492, 2.3303251, 1.9892492],\n                        [1.7856569, 2.091825, 1.7856569],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"sigma\": 0.5},\n            p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n            p(\n                [\n                    [\n                        [0.8424794, 0.99864554, 0.8424794],\n                        [1.678146, 1.9892154, 1.678146],\n                        [1.9889624, 2.3576462, 1.9889624],\n                    ],\n                    [\n                   
     [2.966061, 3.5158648, 2.966061],\n                        [4.1953645, 4.973038, 4.1953645],\n                        [4.112544, 4.8748655, 4.1125436],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"sigma\": [1.5, 0.5]},\n            p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n            p(\n                [\n                    [\n                        [0.8542037, 1.0125432, 0.8542037],\n                        [1.1487541, 1.3616928, 1.1487541],\n                        [1.1070318, 1.3122368, 1.1070318],\n                    ],\n                    [\n                        [2.3251305, 2.756128, 2.3251305],\n                        [2.8718853, 3.4042323, 2.8718853],\n                        [2.5779586, 3.0558217, 2.5779586],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n\nclass TestGaussianSmooth(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        result = GaussianSmooth(**arguments)(image)\n        assert_allclose(result, expected_data, atol=1e-4, rtol=1e-4, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_gaussian_smoothd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import GaussianSmoothd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"sigma\": 1.5},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [0.59167546, 0.69312394, 0.59167546],\n                        [0.7956997, 0.93213004, 0.7956997],\n                        [0.7668002, 0.8982755, 0.7668002],\n                    ],\n                    [\n                        [1.6105323, 1.8866735, 1.6105323],\n                        [1.9892492, 2.3303251, 1.9892492],\n                        [1.7856569, 2.091825, 1.7856569],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"sigma\": 0.5},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [0.8424794, 0.99864554, 0.8424794],\n                        [1.678146, 1.9892154, 1.678146],\n             
           [1.9889624, 2.3576462, 1.9889624],\n                    ],\n                    [\n                        [2.966061, 3.5158648, 2.966061],\n                        [4.1953645, 4.973038, 4.1953645],\n                        [4.112544, 4.8748655, 4.1125436],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"sigma\": [1.5, 0.5]},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [0.8542037, 1.0125432, 0.8542037],\n                        [1.1487541, 1.3616928, 1.1487541],\n                        [1.1070318, 1.3122368, 1.1070318],\n                    ],\n                    [\n                        [2.3251305, 2.756128, 2.3251305],\n                        [2.8718853, 3.4042323, 2.8718853],\n                        [2.5779586, 3.0558217, 2.5779586],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n\nclass TestGaussianSmoothd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        result = GaussianSmoothd(**arguments)(image)\n        assert_allclose(result[\"img\"], expected_data, rtol=1e-4, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_generate_heatmap.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.post.array import GenerateHeatmap\nfrom tests.test_utils import TEST_NDARRAYS\n\n\ndef _argmax_nd(x) -> np.ndarray:\n    \"\"\"argmax for N-D array → returns coordinate vector (z,y,x) or (y,x).\"\"\"\n    if isinstance(x, torch.Tensor):\n        x = x.cpu().numpy()\n    return np.asarray(np.unravel_index(np.argmax(x), x.shape))\n\n\n# Test cases for 2D array inputs with different data types\nTEST_CASES_2D = [\n    [\n        f\"2d_basic_type{idx}\",\n        p(np.array([[4.2, 7.8], [12.3, 3.6]], dtype=np.float32)),\n        {\"sigma\": 1.5, \"spatial_shape\": (16, 16)},\n        (2, 16, 16),\n    ]\n    for idx, p in enumerate(TEST_NDARRAYS)\n]\n\n# Test cases for 3D torch outputs with explicit dtype\nTEST_CASES_3D_TORCH = [\n    [\n        f\"3d_torch_{str(dtype).replace('torch.', '')}\",\n        torch.tensor([[1.5, 2.5, 3.5]], dtype=torch.float32),\n        {\"sigma\": 1.0, \"spatial_shape\": (8, 8, 8), \"dtype\": dtype},\n        (1, 8, 8, 8),\n        dtype,\n    ]\n    for dtype in [torch.float32, torch.float64]\n]\n\n# Test cases for 3D numpy outputs with explicit dtype\nTEST_CASES_3D_NUMPY = [\n    [\n        f\"3d_numpy_{dtype_obj.__name__}\",\n        np.array([[1.5, 2.5, 3.5]], dtype=np.float32),\n        
{\"sigma\": 1.0, \"spatial_shape\": (8, 8, 8), \"dtype\": dtype_obj},\n        (1, 8, 8, 8),\n        dtype_obj,\n    ]\n    for dtype_obj in [np.float32, np.float64]\n]\n\n# Test cases for different sigma values\nTEST_CASES_SIGMA = [\n    [\n        f\"sigma_{sigma}\",\n        np.array([[8.0, 8.0]], dtype=np.float32),\n        {\"sigma\": sigma, \"spatial_shape\": (16, 16)},\n        (1, 16, 16),\n    ]\n    for sigma in [0.5, 1.0, 2.0, 3.0]\n]\n\n# Test cases for truncated parameter\nTEST_CASES_TRUNCATED = [\n    [\n        f\"truncated_{truncated}\",\n        np.array([[8.0, 8.0]], dtype=np.float32),\n        {\"sigma\": 2.0, \"spatial_shape\": (32, 32), \"truncated\": truncated},\n        (1, 32, 32),\n    ]\n    for truncated in [2.0, 4.0, 6.0]\n]\n\n# Test cases for device and dtype propagation (torch only)\ntest_device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\ntest_dtypes = [torch.float32, torch.float64]\nif torch.cuda.is_available():\n    test_dtypes.append(torch.float16)\n\nTEST_CASES_DEVICE_DTYPE = [\n    [\n        f\"{test_device.split(':')[0]}_{str(dtype).replace('torch.', '')}\",\n        torch.tensor([[3.0, 4.0, 5.0]], dtype=torch.float32, device=test_device),\n        {\"sigma\": 1.2, \"spatial_shape\": (10, 10, 10), \"dtype\": dtype},\n        (1, 10, 10, 10),\n        dtype,\n        test_device,\n    ]\n    for dtype in test_dtypes\n]\n\n\nclass TestGenerateHeatmap(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_2D)\n    def test_array_2d(self, _, points, params, expected_shape):\n        transform = GenerateHeatmap(**params)\n        heatmap = transform(points)\n\n        # Check output type matches input type\n        if isinstance(points, torch.Tensor):\n            self.assertIsInstance(heatmap, torch.Tensor)\n            self.assertEqual(heatmap.dtype, torch.float32)  # Default dtype for torch\n            heatmap_np = heatmap.cpu().numpy()\n            points_np = points.cpu().numpy()\n        else:\n            
self.assertIsInstance(heatmap, np.ndarray)\n            self.assertEqual(heatmap.dtype, np.float32)  # Default dtype for numpy\n            heatmap_np = heatmap\n            points_np = points\n\n        self.assertEqual(heatmap.shape, expected_shape)\n        np.testing.assert_allclose(heatmap_np.max(axis=(1, 2)), np.ones(expected_shape[0]), rtol=1e-5, atol=1e-5)\n\n        # peak should be close to original point location (<= 1px tolerance due to discretization)\n        for idx in range(expected_shape[0]):\n            peak = _argmax_nd(heatmap_np[idx])\n            self.assertTrue(np.all(np.abs(peak - points_np[idx]) <= 1.0), msg=f\"peak={peak}, point={points_np[idx]}\")\n            self.assertLess(heatmap_np[idx, 0, 0], 1e-3)\n\n    @parameterized.expand(TEST_CASES_3D_TORCH)\n    def test_array_3d_torch_output(self, _, points, params, expected_shape, expected_dtype):\n        transform = GenerateHeatmap(**params)\n        heatmap = transform(points)\n\n        self.assertIsInstance(heatmap, torch.Tensor)\n        self.assertEqual(heatmap.device, points.device)\n        self.assertEqual(tuple(heatmap.shape), expected_shape)\n        self.assertEqual(heatmap.dtype, expected_dtype)\n        self.assertTrue(torch.isclose(heatmap.max(), torch.tensor(1.0, dtype=heatmap.dtype, device=heatmap.device)))\n\n    @parameterized.expand(TEST_CASES_3D_NUMPY)\n    def test_array_3d_numpy_output(self, _, points, params, expected_shape, expected_dtype):\n        transform = GenerateHeatmap(**params)\n        heatmap = transform(points)\n\n        self.assertIsInstance(heatmap, np.ndarray)\n        self.assertEqual(heatmap.shape, expected_shape)\n        self.assertEqual(heatmap.dtype, expected_dtype)\n        np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5)\n\n    @parameterized.expand(TEST_CASES_DEVICE_DTYPE)\n    def test_array_torch_device_and_dtype_propagation(\n        self, _, pts, params, expected_shape, expected_dtype, expected_device\n    ):\n        tr = 
GenerateHeatmap(**params)\n        hm = tr(pts)\n\n        self.assertIsInstance(hm, torch.Tensor)\n        self.assertEqual(str(hm.device).split(\":\")[0], expected_device.split(\":\")[0])\n        self.assertEqual(hm.dtype, expected_dtype)\n        self.assertEqual(tuple(hm.shape), expected_shape)\n        self.assertTrue(torch.all(hm >= 0))\n\n    def test_array_channel_order_identity(self):\n        # ensure the order of channels follows the order of input points\n        pts = np.array([[2.0, 2.0], [12.0, 2.0], [2.0, 12.0]], dtype=np.float32)  # point A  # point B  # point C\n        hm = GenerateHeatmap(sigma=1.2, spatial_shape=(16, 16))(pts)\n\n        self.assertIsInstance(hm, np.ndarray)\n        self.assertEqual(hm.shape, (3, 16, 16))\n\n        peaks = np.vstack([_argmax_nd(hm[i]) for i in range(3)])\n        # y,x close to points\n        np.testing.assert_allclose(peaks, pts, atol=1.0)\n\n    def test_array_points_out_of_bounds(self):\n        # points outside spatial domain: heatmap should still be valid (no NaN/Inf) and not all-zeros\n        pts = np.array(\n            [[-5.0, -5.0], [100.0, 100.0], [8.0, 8.0]],  # outside top-left  # outside bottom-right  # inside\n            dtype=np.float32,\n        )\n        hm = GenerateHeatmap(sigma=2.0, spatial_shape=(16, 16))(pts)\n\n        self.assertIsInstance(hm, np.ndarray)\n        self.assertEqual(hm.shape, (3, 16, 16))\n        self.assertFalse(np.isnan(hm).any() or np.isinf(hm).any())\n\n        # inside point channel should have max≈1; others may clip at border (≤1)\n        self.assertGreater(hm[2].max(), 0.9)\n\n    @parameterized.expand(TEST_CASES_SIGMA)\n    def test_array_sigma_scaling_effect(self, _, pt, params, expected_shape):\n        heatmap = GenerateHeatmap(**params)(pt)[0]\n        self.assertEqual(heatmap.shape, expected_shape[1:])\n\n        # All should have peak normalized to 1.0\n        np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5)\n\n        # Verify heatmap is 
valid\n        self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any())\n\n    def test_invalid_points_shape_raises(self):\n        # points must be (N, D) with D in {2,3}\n        tr = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8))\n        with self.assertRaises((ValueError, AssertionError, IndexError, RuntimeError)):\n            tr(np.zeros((2,), dtype=np.float32))  # wrong rank\n\n        with self.assertRaises((ValueError, AssertionError, IndexError, RuntimeError)):\n            tr(np.zeros((2, 4), dtype=np.float32))  # D=4 unsupported\n\n    @parameterized.expand(TEST_CASES_TRUNCATED)\n    def test_truncated_parameter(self, _, pt, params, expected_shape):\n        heatmap = GenerateHeatmap(**params)(pt)[0]\n\n        # All should have same peak value (normalized to 1.0)\n        np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5)\n\n        # Verify shape and no NaN/Inf\n        self.assertEqual(heatmap.shape, expected_shape[1:])\n        self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any())\n\n    def test_torch_to_torch_type_preservation(self):\n        \"\"\"Test that torch input produces torch output\"\"\"\n        pts = torch.tensor([[4.0, 4.0]], dtype=torch.float32)\n        hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8))(pts)\n\n        self.assertIsInstance(hm, torch.Tensor)\n        self.assertEqual(hm.dtype, torch.float32)\n        self.assertEqual(hm.device, pts.device)\n\n    def test_numpy_to_numpy_type_preservation(self):\n        \"\"\"Test that numpy input produces numpy output\"\"\"\n        pts = np.array([[4.0, 4.0]], dtype=np.float32)\n        hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8))(pts)\n\n        self.assertIsInstance(hm, np.ndarray)\n        self.assertEqual(hm.dtype, np.float32)\n\n    def test_dtype_override_torch(self):\n        \"\"\"Test dtype parameter works with torch tensors\"\"\"\n        pts = torch.tensor([[4.0, 4.0, 4.0]], dtype=torch.float32)\n        hm = 
GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8, 8), dtype=torch.float64)(pts)\n\n        self.assertIsInstance(hm, torch.Tensor)\n        self.assertEqual(hm.dtype, torch.float64)\n\n    def test_dtype_override_numpy(self):\n        \"\"\"Test dtype parameter works with numpy arrays\"\"\"\n        pts = np.array([[4.0, 4.0, 4.0]], dtype=np.float32)\n        hm = GenerateHeatmap(sigma=1.0, spatial_shape=(8, 8, 8), dtype=np.float64)(pts)\n\n        self.assertIsInstance(hm, np.ndarray)\n        self.assertEqual(hm.dtype, np.float64)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_generate_heatmapd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor\nfrom monai.transforms.post.dictionary import GenerateHeatmapd\nfrom tests.test_utils import assert_allclose\n\n# Test cases for dictionary transforms with reference image\n# Only test with non-MetaTensor types to avoid affine conflicts\nTEST_CASES_WITH_REF = [\n    [\n        \"dict_with_ref_3d_numpy\",\n        np.array([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=np.float32),\n        {\"sigma\": 2.0},\n        (2, 8, 8, 8),\n        torch.float32,\n        True,  # uses reference image\n    ],\n    [\n        \"dict_with_ref_3d_torch\",\n        torch.tensor([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=torch.float32),\n        {\"sigma\": 2.0},\n        (2, 8, 8, 8),\n        torch.float32,\n        True,  # uses reference image\n    ],\n]\n\n# Test cases for dictionary transforms with static spatial shape\nTEST_CASES_STATIC_SHAPE = [\n    [\n        f\"dict_static_shape_{len(shape)}d\",\n        np.array([[1.0] * len(shape)], dtype=np.float32),\n        {\"spatial_shape\": shape},\n        (1, *shape),\n        np.float32,\n    ]\n    for shape in [(6, 6), (8, 8, 8), (10, 10, 10)]\n]\n\n# Test cases for dtype control\nTEST_CASES_DTYPE = [\n    [\n        f\"dict_dtype_{str(dtype).replace('torch.', '')}\",\n  
      np.array([[2.0, 3.0, 4.0]], dtype=np.float32),\n        {\"sigma\": 1.4, \"dtype\": dtype},\n        (1, 10, 10, 10),\n        dtype,\n    ]\n    for dtype in [torch.float16, torch.float32, torch.float64]\n]\n\n# Test cases for various sigma values\nTEST_CASES_SIGMA_VALUES = [\n    [\n        f\"dict_sigma_{sigma}\",\n        np.array([[4.0, 4.0, 4.0]], dtype=np.float32),\n        {\"sigma\": sigma, \"spatial_shape\": (8, 8, 8)},\n        (1, 8, 8, 8),\n    ]\n    for sigma in [0.5, 1.0, 2.0, 3.0]\n]\n\n\nclass TestGenerateHeatmapd(unittest.TestCase):\n    @parameterized.expand(TEST_CASES_WITH_REF)\n    def test_dict_with_reference_meta(self, _, points, params, expected_shape, *_unused):\n        affine = torch.eye(4)\n        image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=affine)\n        image.meta[\"spatial_shape\"] = (8, 8, 8)\n        data = {\"points\": points, \"image\": image}\n\n        transform = GenerateHeatmapd(keys=\"points\", heatmap_keys=\"heatmap\", ref_image_keys=\"image\", **params)\n        result = transform(data)\n        heatmap = result[\"heatmap\"]\n\n        self.assertIsInstance(heatmap, MetaTensor)\n        self.assertEqual(tuple(heatmap.shape), expected_shape)\n        self.assertEqual(heatmap.meta[\"spatial_shape\"], (8, 8, 8))\n        # The heatmap should inherit the reference image's affine\n        assert_allclose(heatmap.affine, image.affine, type_test=False)\n\n        # Check max values are normalized to 1.0\n        max_vals = heatmap.cpu().numpy().max(axis=tuple(range(1, len(expected_shape))))\n        np.testing.assert_allclose(max_vals, np.ones(expected_shape[0]), rtol=1e-5, atol=1e-5)\n\n    @parameterized.expand(TEST_CASES_STATIC_SHAPE)\n    def test_dict_static_shape(self, _, points, params, expected_shape, expected_dtype):\n        transform = GenerateHeatmapd(keys=\"points\", heatmap_keys=\"heatmap\", **params)\n        result = transform({\"points\": points})\n        heatmap = 
result[\"heatmap\"]\n\n        self.assertIsInstance(heatmap, np.ndarray)\n        self.assertEqual(heatmap.shape, expected_shape)\n        self.assertEqual(heatmap.dtype, expected_dtype)\n\n        # Verify no NaN or Inf values\n        self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any())\n\n        # Verify max value is 1.0 for normalized heatmaps\n        np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5)\n\n    def test_dict_missing_shape_raises(self):\n        # Without ref image or explicit spatial_shape, must raise\n        transform = GenerateHeatmapd(keys=\"points\", heatmap_keys=\"heatmap\")\n        with self.assertRaisesRegex(ValueError, \"spatial_shape|ref_image_keys\"):\n            transform({\"points\": np.zeros((1, 2), dtype=np.float32)})\n\n    @parameterized.expand(TEST_CASES_DTYPE)\n    def test_dict_dtype_control(self, _, points, params, expected_shape, expected_dtype):\n        ref = MetaTensor(torch.zeros((1, 10, 10, 10), dtype=torch.float32), affine=torch.eye(4))\n        d = {\"pts\": points, \"img\": ref}\n\n        tr = GenerateHeatmapd(keys=\"pts\", heatmap_keys=\"hm\", ref_image_keys=\"img\", **params)\n        out = tr(d)\n        hm = out[\"hm\"]\n\n        self.assertIsInstance(hm, MetaTensor)\n        self.assertEqual(tuple(hm.shape), expected_shape)\n        self.assertEqual(hm.dtype, expected_dtype)\n\n    @parameterized.expand(TEST_CASES_SIGMA_VALUES)\n    def test_dict_various_sigma(self, _, points, params, expected_shape):\n        transform = GenerateHeatmapd(keys=\"points\", heatmap_keys=\"heatmap\", **params)\n        result = transform({\"points\": points})\n        heatmap = result[\"heatmap\"]\n\n        self.assertEqual(heatmap.shape, expected_shape)\n        # Verify heatmap is normalized\n        np.testing.assert_allclose(heatmap.max(), 1.0, rtol=1e-5)\n        # Verify no NaN or Inf\n        self.assertFalse(np.isnan(heatmap).any() or np.isinf(heatmap).any())\n\n    def 
test_dict_multiple_keys(self):\n        \"\"\"Test dictionary transform with multiple input/output keys\"\"\"\n        points1 = np.array([[2.0, 2.0]], dtype=np.float32)\n        points2 = np.array([[4.0, 4.0]], dtype=np.float32)\n\n        data = {\"pts1\": points1, \"pts2\": points2}\n        transform = GenerateHeatmapd(\n            keys=[\"pts1\", \"pts2\"], heatmap_keys=[\"hm1\", \"hm2\"], spatial_shape=(8, 8), sigma=1.0\n        )\n\n        result = transform(data)\n\n        self.assertIn(\"hm1\", result)\n        self.assertIn(\"hm2\", result)\n        self.assertEqual(result[\"hm1\"].shape, (1, 8, 8))\n        self.assertEqual(result[\"hm2\"].shape, (1, 8, 8))\n\n        # Verify peaks are at different locations\n        self.assertNotEqual(np.argmax(result[\"hm1\"]), np.argmax(result[\"hm2\"]))\n\n    def test_dict_mismatched_heatmap_keys_length(self):\n        \"\"\"Test ValueError when heatmap_keys length doesn't match keys\"\"\"\n        with self.assertRaises(ValueError):\n            GenerateHeatmapd(\n                keys=[\"pts1\", \"pts2\"],\n                heatmap_keys=[\"hm1\", \"hm2\", \"hm3\"],  # Mismatch: 3 heatmap keys for 2 input keys\n                spatial_shape=(8, 8),\n            )\n\n    def test_dict_mismatched_ref_image_keys_length(self):\n        \"\"\"Test ValueError when ref_image_keys length doesn't match keys\"\"\"\n        with self.assertRaises(ValueError):\n            GenerateHeatmapd(\n                keys=[\"pts1\", \"pts2\"],\n                heatmap_keys=[\"hm1\", \"hm2\"],\n                ref_image_keys=[\"img1\", \"img2\", \"img3\"],  # Mismatch: 3 ref keys for 2 input keys\n                spatial_shape=(8, 8),\n            )\n\n    def test_dict_per_key_spatial_shape_mismatch(self):\n        \"\"\"Test ValueError when per-key spatial_shape length doesn't match keys\"\"\"\n        with self.assertRaises(ValueError):\n            GenerateHeatmapd(\n                keys=[\"pts1\", \"pts2\"],\n                
heatmap_keys=[\"hm1\", \"hm2\"],\n                spatial_shape=[(8, 8), (8, 8), (8, 8)],  # Mismatch: 3 shapes for 2 keys\n                sigma=1.0,\n            )\n\n    def test_metatensor_points_with_ref(self):\n        \"\"\"Test MetaTensor points with reference image - documents current behavior\"\"\"\n        from monai.data import MetaTensor\n\n        # Create MetaTensor points with non-identity affine\n        points_affine = torch.tensor([[2.0, 0, 0, 0], [0, 2.0, 0, 0], [0, 0, 2.0, 0], [0, 0, 0, 1.0]])\n        points = MetaTensor(torch.tensor([[2.5, 2.5, 3.0], [5.0, 5.0, 4.0]], dtype=torch.float32), affine=points_affine)\n\n        # Reference image with identity affine\n        ref_affine = torch.eye(4)\n        image = MetaTensor(torch.zeros((1, 8, 8, 8), dtype=torch.float32), affine=ref_affine)\n        image.meta[\"spatial_shape\"] = (8, 8, 8)\n\n        data = {\"points\": points, \"image\": image}\n        transform = GenerateHeatmapd(keys=\"points\", heatmap_keys=\"heatmap\", ref_image_keys=\"image\", sigma=2.0)\n        result = transform(data)\n        heatmap = result[\"heatmap\"]\n\n        self.assertIsInstance(heatmap, MetaTensor)\n        self.assertEqual(tuple(heatmap.shape), (2, 8, 8, 8))\n\n        # Heatmap should inherit affine from the reference image\n        assert_allclose(heatmap.affine, image.affine, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_generate_label_classes_crop_centers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import generate_label_classes_crop_centers\nfrom monai.utils.misc import set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTEST_CASE_1 = [\n    {\n        \"spatial_size\": [2, 2, 2],\n        \"num_samples\": 2,\n        \"ratios\": [1, 2],\n        \"label_spatial_shape\": [3, 3, 3],\n        \"indices\": [[3, 12, 21], [1, 9, 18]],\n    },\n    tuple,\n    2,\n    3,\n]\n\nTEST_CASE_2 = [\n    {\n        \"spatial_size\": [2, 2, 2],\n        \"num_samples\": 1,\n        \"ratios\": None,\n        \"label_spatial_shape\": [3, 3, 3],\n        \"indices\": [[3, 12, 21], [1, 9, 18]],\n    },\n    tuple,\n    1,\n    3,\n]\n\n\nclass TestGenerateLabelClassesCropCenters(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_type_shape(self, input_data, expected_type, expected_count, expected_shape):\n        results = []\n        for p in TEST_NDARRAYS + (None,):\n            input_data = deepcopy(input_data)\n            if p is not None:\n                input_data[\"indices\"] = p(input_data[\"indices\"])\n            set_determinism(0)\n            result = generate_label_classes_crop_centers(**input_data)\n            self.assertIsInstance(result, 
expected_type)\n            self.assertEqual(len(result), expected_count)\n            self.assertEqual(len(result[0]), expected_shape)\n            # check for consistency between numpy, torch and torch.cuda\n            results.append(result)\n            if len(results) > 1:\n                for x, y in zip(results[0], results[-1]):\n                    assert_allclose(x, y, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_generate_pos_neg_label_crop_centers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import generate_pos_neg_label_crop_centers\nfrom monai.utils.misc import set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = [\n    [\n        {\n            \"spatial_size\": [2, 2, 2],\n            \"num_samples\": 2,\n            \"pos_ratio\": 1.0,\n            \"label_spatial_shape\": [3, 3, 3],\n            \"fg_indices\": [1, 9, 18],\n            \"bg_indices\": [3, 12, 21],\n        },\n        tuple,\n        2,\n        3,\n    ],\n    [\n        {\n            \"spatial_size\": [2, 2, 2],\n            \"num_samples\": 2,\n            \"pos_ratio\": 0.0,\n            \"label_spatial_shape\": [3, 3, 3],\n            \"fg_indices\": [],\n            \"bg_indices\": [3, 12, 21],\n        },\n        tuple,\n        2,\n        3,\n    ],\n]\n\n\nclass TestGeneratePosNegLabelCropCenters(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_type_shape(self, input_data, expected_type, expected_count, expected_shape):\n        results = []\n        for p in TEST_NDARRAYS + (None,):\n            input_data = deepcopy(input_data)\n            if p is not None:\n                for k in [\"fg_indices\", \"bg_indices\"]:\n                    input_data[k] = p(input_data[k])\n   
         set_determinism(0)\n            result = generate_pos_neg_label_crop_centers(**input_data)\n            self.assertIsInstance(result, expected_type)\n            self.assertEqual(len(result), expected_count)\n            self.assertEqual(len(result[0]), expected_shape)\n            # check for consistency between numpy, torch and torch.cuda\n            results.append(result)\n            if len(results) > 1:\n                # compare every crop center\n                for x, y in zip(results[0], results[-1]):\n                    assert_allclose(x, y, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_generate_spatial_bounding_box.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import generate_spatial_bounding_box\nfrom tests.test_utils import TEST_NDARRAYS\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\n                \"img\": p(\n                    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])\n                ),\n                \"select_fn\": lambda x: x > 0,\n                \"channel_indices\": None,\n                \"margin\": 0,\n            },\n            ([1, 1], [4, 4]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\n                \"img\": p(\n                    np.array([[[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 3, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]])\n                ),\n                \"select_fn\": lambda x: x > 1,\n                \"channel_indices\": None,\n                \"margin\": 0,\n            },\n            ([2, 2], [3, 3]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\n                \"img\": p(\n                    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])\n                ),\n                \"select_fn\": lambda x: x > 0,\n                \"channel_indices\": 0,\n                
\"margin\": 0,\n            },\n            ([1, 1], [4, 4]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\n                \"img\": p(\n                    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])\n                ),\n                \"select_fn\": lambda x: x > 0,\n                \"channel_indices\": None,\n                \"margin\": 1,\n            },\n            ([0, 0], [4, 5]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\n                \"img\": p(\n                    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])\n                ),\n                \"select_fn\": lambda x: x > 0,\n                \"channel_indices\": None,\n                \"margin\": [2, 1],\n                \"allow_smaller\": False,\n            },\n            ([-1, 0], [6, 5]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\n                \"img\": p(\n                    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]])\n                ),\n                \"select_fn\": lambda x: x > 0,\n                \"channel_indices\": None,\n                \"margin\": [2, 1],\n                \"allow_smaller\": True,\n            },\n            ([0, 0], [5, 5]),\n        ]\n    )\n\n\nclass TestGenerateSpatialBoundingBox(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, input_data, expected_box):\n        result = generate_spatial_bounding_box(**input_data)\n        self.assertTupleEqual(result, expected_box)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_get_extreme_points.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import get_extreme_points\nfrom tests.test_utils import TEST_NDARRAYS\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\n                \"img\": p(np.array([[0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]])),\n                \"rand_state\": np.random,\n                \"background\": 0,\n                \"pert\": 0.0,\n            },\n            [(0, 1), (3, 0), (3, 0), (1, 2)],\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\n                \"img\": p(np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0], [0, 1, 0]])),\n                \"rand_state\": np.random,\n                \"background\": 0,\n                \"pert\": 0.0,\n            },\n            [(0, 1), (3, 1), (1, 0), (1, 2)],\n        ]\n    )\n\n\nclass TestGetExtremePoints(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_type_shape(self, input_data, expected):\n        result = get_extreme_points(**input_data)\n        self.assertEqual(result, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_gibbs_noise.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.synthetic import create_test_image_2d, create_test_image_3d\nfrom monai.transforms import GibbsNoise\nfrom monai.utils.misc import set_determinism\nfrom monai.utils.module import optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose, dict_product\n\n_, has_torch_fft = optional_import(\"torch.fft\", name=\"fftshift\")\n\nshapes = ((128, 64), (64, 48, 80))\ninput_types = TEST_NDARRAYS if has_torch_fft else [np.array]\nTEST_CASES = [[p_dict[\"shape\"], p_dict[\"input_type\"]] for p_dict in dict_product(shape=shapes, input_type=input_types)]\n\n\nclass TestGibbsNoise(unittest.TestCase):\n    def setUp(self):\n        set_determinism(0)\n        super().setUp()\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @staticmethod\n    def get_data(im_shape, input_type):\n        create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d\n        im = create_test_image(*im_shape, num_objs=4, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None]\n        return input_type(im)\n\n    @parameterized.expand(TEST_CASES)\n    def test_same_result(self, im_shape, input_type):\n        im = self.get_data(im_shape, input_type)\n        alpha = 0.8\n        t = 
GibbsNoise(alpha)\n        out1 = t(deepcopy(im))\n        out2 = t(deepcopy(im))\n        assert_allclose(out1, out2, rtol=1e-7, atol=0, type_test=\"tensor\")\n\n    @parameterized.expand(TEST_CASES)\n    def test_identity(self, im_shape, input_type):\n        im = self.get_data(im_shape, input_type)\n        alpha = 0.0\n        t = GibbsNoise(alpha)\n        out = t(deepcopy(im))\n        assert_allclose(out, im, atol=1e-2, rtol=1e-7, type_test=\"tensor\")\n\n    @parameterized.expand(TEST_CASES)\n    def test_alpha_1(self, im_shape, input_type):\n        im = self.get_data(im_shape, input_type)\n        alpha = 1.0\n        t = GibbsNoise(alpha)\n        out = t(deepcopy(im))\n        assert_allclose(out, 0 * im, rtol=1e-7, atol=0, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_gibbs_noised.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.synthetic import create_test_image_2d, create_test_image_3d\nfrom monai.transforms import GibbsNoised\nfrom monai.utils.misc import set_determinism\nfrom monai.utils.module import optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_torch_fft = optional_import(\"torch.fft\", name=\"fftshift\")\n\nTEST_CASES = []\nfor shape in ((128, 64), (64, 48, 80)):\n    for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]:\n        TEST_CASES.append((shape, input_type))\nKEYS = [\"im\", \"label\"]\n\n\nclass TestGibbsNoised(unittest.TestCase):\n    def setUp(self):\n        set_determinism(0)\n        super().setUp()\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @staticmethod\n    def get_data(im_shape, input_type):\n        create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d\n        ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)\n        return {k: input_type(deepcopy(v)) for k, v in zip(KEYS, ims)}\n\n    @parameterized.expand(TEST_CASES)\n    def test_same_result(self, im_shape, input_type):\n        data = self.get_data(im_shape, input_type)\n        alpha = 0.8\n        t = 
GibbsNoised(KEYS, alpha)\n        out1 = t(deepcopy(data))\n        out2 = t(deepcopy(data))\n        for k in KEYS:\n            assert_allclose(out1[k], out2[k], rtol=1e-7, atol=0, type_test=\"tensor\")\n\n    @parameterized.expand(TEST_CASES)\n    def test_identity(self, im_shape, input_type):\n        data = self.get_data(im_shape, input_type)\n        alpha = 0.0\n        t = GibbsNoised(KEYS, alpha)\n        out = t(deepcopy(data))\n        for k in KEYS:\n            assert_allclose(out[k], data[k], atol=1e-2, type_test=\"tensor\")\n\n    @parameterized.expand(TEST_CASES)\n    def test_alpha_1(self, im_shape, input_type):\n        data = self.get_data(im_shape, input_type)\n        alpha = 1.0\n        t = GibbsNoised(KEYS, alpha)\n        out = t(deepcopy(data))\n        for k in KEYS:\n            assert_allclose(out[k], 0.0 * data[k], atol=1e-2, type_test=\"tensor\")\n\n    @parameterized.expand(TEST_CASES)\n    def test_dict_matches(self, im_shape, input_type):\n        data = self.get_data(im_shape, input_type)\n        data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])}\n        alpha = 1.0\n        t = GibbsNoised(KEYS, alpha)\n        out = t(deepcopy(data))\n        assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_grid_distortion.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import GridDistortion\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    TESTS.append(\n        [\n            dict(num_cells=3, distort_steps=[(1.5,) * 4] * 2, mode=\"nearest\", padding_mode=\"zeros\"),\n            p(np.indices([6, 6]).astype(np.float32)),\n            p(\n                np.array(\n                    [\n                        [\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                            [3.0, 3.0, 3.0, 0.0, 0.0, 0.0],\n                            [3.0, 3.0, 3.0, 0.0, 0.0, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                        ],\n                        [\n                            [0.0, 3.0, 3.0, 0.0, 0.0, 0.0],\n                            [0.0, 3.0, 3.0, 0.0, 0.0, 0.0],\n                            [0.0, 3.0, 3.0, 0.0, 0.0, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                        ],\n                    
]\n                ).astype(np.float32)\n            ),\n        ]\n    )\n    num_cells = (2, 2)\n    distort_steps = [(1.5,) * (1 + num_cells[0]), (1.0,) * (1 + num_cells[1])]\n    TESTS.append(\n        [\n            dict(num_cells=num_cells, distort_steps=distort_steps, mode=\"bilinear\", padding_mode=\"reflection\"),\n            p(np.indices([6, 6]).astype(np.float32)),\n            p(\n                np.array(\n                    [\n                        [\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                            [2.25, 2.25, 2.25, 2.25, 2.25, 2.25],\n                            [4.5, 4.5, 4.5, 4.5, 4.5, 4.5],\n                            [4.5, 4.5, 4.5, 4.5, 4.5, 4.5],\n                            [4.2500, 4.2500, 4.2500, 4.2500, 4.2500, 4.2500],\n                            [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],\n                        ],\n                        [\n                            [0.0, 1.5, 3.0, 3.0, 4.5, 5.0],\n                            [0.0, 1.5, 3.0, 3.0, 4.5, 5.0],\n                            [0.0, 1.5, 3.0, 3.0, 4.5, 5.0],\n                            [0.0, 1.5, 3.0, 3.0, 4.5, 5.0],\n                            [0.0, 1.5, 3.0, 3.0, 4.5, 5.0],\n                            [0.0, 1.5, 3.0, 3.0, 4.5, 5.0],\n                        ],\n                    ]\n                ).astype(np.float32)\n            ),\n        ]\n    )\n    TESTS.append(\n        [\n            dict(num_cells=2, distort_steps=[(1.26,) * 3] * 3, mode=\"nearest\", padding_mode=\"zeros\"),\n            p(np.indices([3, 3, 3])[:1].astype(np.float32)),\n            p(\n                np.array(\n                    [\n                        [\n                            [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],\n                            [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]],\n                            [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],\n                        ]\n                    ]\n    
            ).astype(np.float32)\n            ),\n        ]\n    )\n\n\nclass TestGridDistortion(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_grid_distortion(self, input_param, input_data, expected_val):\n        g = GridDistortion(**input_param)\n        result = g(input_data)\n        if input_param[\"padding_mode\"] != \"reflection\":\n            assert_allclose(result, expected_val, type_test=False, rtol=1e-4, atol=1e-4)\n        else:\n            assert_allclose(result.shape, expected_val.shape, type_test=False, rtol=1e-4, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_grid_distortiond.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import GridDistortiond\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTESTS = []\nnum_cells = (2, 2)\ndistort_steps = [(1.5,) * (1 + n_c) for n_c in num_cells]\nfor p in TEST_NDARRAYS_ALL:\n    img = np.indices([6, 6]).astype(np.float32)\n    TESTS.append(\n        [\n            dict(\n                keys=[\"img\", \"mask\"],\n                num_cells=num_cells,\n                distort_steps=distort_steps,\n                mode=[\"bilinear\", \"nearest\"],\n                padding_mode=[\"reflection\", \"zeros\"],\n            ),\n            {\"img\": p(img), \"mask\": p(np.ones_like(img[:1]))},\n            p(\n                np.array(\n                    [\n                        [\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                            [2.25, 2.25, 2.25, 2.25, 2.25, 2.25],\n                            [4.5, 4.5, 4.5, 4.5, 4.5, 4.5],\n                            [4.5, 4.5, 4.5, 4.5, 4.5, 4.5],\n                            [4.2500, 4.2500, 4.2500, 4.2500, 4.2500, 4.2500],\n                            [2.0000, 2.0000, 2.0000, 2.0000, 2.0000, 2.0000],\n                        ],\n                        [\n                            [0.0000, 2.2500, 4.5000, 4.5000, 
4.2500, 2.0000],\n                            [0.0000, 2.2500, 4.5000, 4.5000, 4.2500, 2.0000],\n                            [0.0000, 2.2500, 4.5000, 4.5000, 4.2500, 2.0000],\n                            [0.0000, 2.2500, 4.5000, 4.5000, 4.2500, 2.0000],\n                            [0.0000, 2.2500, 4.5000, 4.5000, 4.2500, 2.0000],\n                            [0.0000, 2.2500, 4.5000, 4.5000, 4.2500, 2.0000],\n                        ],\n                    ]\n                ).astype(np.float32)\n            ),\n            p(\n                np.array(\n                    [\n                        [\n                            [1.0, 1.0, 1.0, 1.0, 0.0, 0.0],\n                            [1.0, 1.0, 1.0, 1.0, 0.0, 0.0],\n                            [1.0, 1.0, 1.0, 1.0, 0.0, 0.0],\n                            [1.0, 1.0, 1.0, 1.0, 0.0, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                        ]\n                    ]\n                ).astype(np.float32)\n            ),\n        ]\n    )\n\n\nclass TestGridDistortiond(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_grid_distortiond(self, input_param, input_data, expected_val_img, expected_val_mask):\n        g = GridDistortiond(**input_param)\n        result = g(input_data)\n        assert_allclose(result[\"mask\"], expected_val_mask, type_test=False, rtol=1e-4, atol=1e-4)\n        assert_allclose(result[\"img\"].shape, expected_val_img.shape, type_test=False, rtol=1e-4, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_grid_split.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import GridSplit\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nA11 = torch.randn(3, 2, 2)\nA12 = torch.randn(3, 2, 2)\nA21 = torch.randn(3, 2, 2)\nA22 = torch.randn(3, 2, 2)\n\nA1 = torch.cat([A11, A12], 2)\nA2 = torch.cat([A21, A22], 2)\nA = torch.cat([A1, A2], 1)\n\nTEST_CASE_0 = [{\"grid\": (2, 2)}, A, [A11, A12, A21, A22]]\nTEST_CASE_1 = [{\"grid\": (2, 1)}, A, [A1, A2]]\nTEST_CASE_2 = [{\"grid\": (1, 2)}, A1, [A11, A12]]\nTEST_CASE_3 = [{\"grid\": (1, 2)}, A2, [A21, A22]]\nTEST_CASE_4 = [{\"grid\": (1, 1), \"size\": (2, 2)}, A, [A11]]\nTEST_CASE_5 = [{\"grid\": (1, 1), \"size\": 4}, A, [A]]\nTEST_CASE_6 = [{\"grid\": (2, 2), \"size\": 2}, A, [A11, A12, A21, A22]]\nTEST_CASE_7 = [{\"grid\": (1, 1)}, A, [A]]\nTEST_CASE_8 = [\n    {\"grid\": (2, 2), \"size\": 2},\n    torch.arange(12).reshape(1, 3, 4).to(torch.float32),\n    torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32),\n]\n\nTEST_SINGLE = []\nfor p in TEST_NDARRAYS:\n    TEST_SINGLE.append([p, *TEST_CASE_0])\n    TEST_SINGLE.append([p, *TEST_CASE_1])\n    TEST_SINGLE.append([p, *TEST_CASE_2])\n    TEST_SINGLE.append([p, *TEST_CASE_3])\n    TEST_SINGLE.append([p, *TEST_CASE_4])\n    
TEST_SINGLE.append([p, *TEST_CASE_5])\n    TEST_SINGLE.append([p, *TEST_CASE_6])\n    TEST_SINGLE.append([p, *TEST_CASE_7])\n    TEST_SINGLE.append([p, *TEST_CASE_8])\n\nTEST_CASE_MC_0 = [{\"grid\": (2, 2)}, [A, A], [[A11, A12, A21, A22], [A11, A12, A21, A22]]]\nTEST_CASE_MC_1 = [{\"grid\": (2, 1)}, [A] * 5, [[A1, A2]] * 5]\nTEST_CASE_MC_2 = [{\"grid\": (1, 2)}, [A1, A2], [[A11, A12], [A21, A22]]]\n\nTEST_MULTIPLE = []\nfor p in TEST_NDARRAYS:\n    TEST_MULTIPLE.append([p, *TEST_CASE_MC_0])\n    TEST_MULTIPLE.append([p, *TEST_CASE_MC_1])\n    TEST_MULTIPLE.append([p, *TEST_CASE_MC_2])\n\n\nclass TestGridSplit(unittest.TestCase):\n    @parameterized.expand(TEST_SINGLE)\n    def test_split_patch_single_call(self, in_type, input_parameters, image, expected):\n        input_image = in_type(image)\n        splitter = GridSplit(**input_parameters)\n        output = splitter(input_image)\n        for output_patch, expected_patch in zip(output, expected):\n            assert_allclose(output_patch, expected_patch, type_test=False)\n\n    @parameterized.expand(TEST_MULTIPLE)\n    def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list):\n        splitter = GridSplit(**input_parameters)\n        for image, expected in zip(img_list, expected_list):\n            input_image = in_type(image)\n            output = splitter(input_image)\n            for output_patch, expected_patch in zip(output, expected):\n                assert_allclose(output_patch, expected_patch, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_grid_splitd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import GridSplitd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nA11 = torch.randn(3, 2, 2)\nA12 = torch.randn(3, 2, 2)\nA21 = torch.randn(3, 2, 2)\nA22 = torch.randn(3, 2, 2)\n\nA1 = torch.cat([A11, A12], 2)\nA2 = torch.cat([A21, A22], 2)\nA = torch.cat([A1, A2], 1)\n\nTEST_CASE_0 = [{\"keys\": \"image\", \"grid\": (2, 2)}, {\"image\": A}, [A11, A12, A21, A22]]\nTEST_CASE_1 = [{\"keys\": \"image\", \"grid\": (2, 1)}, {\"image\": A}, [A1, A2]]\nTEST_CASE_2 = [{\"keys\": \"image\", \"grid\": (1, 2)}, {\"image\": A1}, [A11, A12]]\nTEST_CASE_3 = [{\"keys\": \"image\", \"grid\": (1, 2)}, {\"image\": A2}, [A21, A22]]\nTEST_CASE_4 = [{\"keys\": \"image\", \"grid\": (1, 1), \"size\": {\"image\": (2, 2)}}, {\"image\": A}, [A11]]\nTEST_CASE_5 = [{\"keys\": \"image\", \"grid\": (1, 1), \"size\": {\"image\": 4}}, {\"image\": A}, [A]]\nTEST_CASE_6 = [{\"keys\": \"image\", \"grid\": (2, 2), \"size\": {\"image\": 2}}, {\"image\": A}, [A11, A12, A21, A22]]\nTEST_CASE_7 = [{\"keys\": \"image\", \"grid\": (1, 1)}, {\"image\": A}, [A]]\nTEST_CASE_8 = [\n    {\"keys\": \"image\", \"grid\": (2, 2), \"size\": {\"image\": 2}},\n    {\"image\": torch.arange(12).reshape(1, 3, 4).to(torch.float32)},\n    torch.Tensor([[[[0, 1], [4, 5]]], [[[2, 
3], [6, 7]]], [[[4, 5], [8, 9]]], [[[6, 7], [10, 11]]]]).to(torch.float32),\n]\n\nTEST_SINGLE = []\nfor p in TEST_NDARRAYS:\n    TEST_SINGLE.append([p, *TEST_CASE_0])\n    TEST_SINGLE.append([p, *TEST_CASE_1])\n    TEST_SINGLE.append([p, *TEST_CASE_2])\n    TEST_SINGLE.append([p, *TEST_CASE_3])\n    TEST_SINGLE.append([p, *TEST_CASE_4])\n    TEST_SINGLE.append([p, *TEST_CASE_5])\n    TEST_SINGLE.append([p, *TEST_CASE_6])\n    TEST_SINGLE.append([p, *TEST_CASE_7])\n    TEST_SINGLE.append([p, *TEST_CASE_8])\n\nTEST_CASE_MC_0 = [\n    {\"keys\": \"image\", \"grid\": (2, 2)},\n    [{\"image\": A}, {\"image\": A}],\n    [[A11, A12, A21, A22], [A11, A12, A21, A22]],\n]\nTEST_CASE_MC_1 = [{\"keys\": \"image\", \"grid\": (2, 1)}, [{\"image\": A}, {\"image\": A}, {\"image\": A}], [[A1, A2]] * 3]\nTEST_CASE_MC_2 = [{\"keys\": \"image\", \"grid\": (1, 2)}, [{\"image\": A1}, {\"image\": A2}], [[A11, A12], [A21, A22]]]\n\nTEST_MULTIPLE = []\nfor p in TEST_NDARRAYS:\n    TEST_MULTIPLE.append([p, *TEST_CASE_MC_0])\n    TEST_MULTIPLE.append([p, *TEST_CASE_MC_1])\n    TEST_MULTIPLE.append([p, *TEST_CASE_MC_2])\n\n\nclass TestGridSplitd(unittest.TestCase):\n    @parameterized.expand(TEST_SINGLE)\n    def test_split_patch_single_call(self, in_type, input_parameters, img_dict, expected):\n        input_dict = {}\n        for k, v in img_dict.items():\n            input_dict[k] = in_type(v)\n        splitter = GridSplitd(**input_parameters)\n        output = splitter(input_dict)\n        for output_patch, expected_patch in zip(output, expected):\n            assert_allclose(output_patch[input_parameters[\"keys\"]], expected_patch, type_test=False)\n\n    @parameterized.expand(TEST_MULTIPLE)\n    def test_split_patch_multiple_call(self, in_type, input_parameters, img_list, expected_list):\n        splitter = GridSplitd(**input_parameters)\n        for img_dict, expected in zip(img_list, expected_list):\n            input_dict = {}\n            for k, v in img_dict.items():\n             
   input_dict[k] = in_type(v)\n            output = splitter(input_dict)\n            for output_patch, expected_patch in zip(output, expected):\n                assert_allclose(output_patch[input_parameters[\"keys\"]], expected_patch, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_histogram_normalize.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import HistogramNormalize\nfrom monai.utils import get_equivalent_dtype\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"num_bins\": 4, \"min\": 1, \"max\": 5, \"mask\": np.array([1, 1, 1, 1, 1, 0])},\n            p(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])),\n            p(np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0])),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"num_bins\": 4, \"max\": 4, \"dtype\": np.uint8},\n            p(np.array([0.0, 1.0, 2.0, 3.0, 4.0])),\n            p(np.array([0, 0, 1, 3, 4])),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"num_bins\": 256, \"max\": 255, \"dtype\": np.uint8},\n            p(np.array([[[100.0, 200.0], [150.0, 250.0]]])),\n            p(np.array([[[0, 170], [70, 255]]])),\n        ]\n    )\n\n\nclass TestHistogramNormalize(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        result = HistogramNormalize(**arguments)(image)\n        assert_allclose(result, expected_data, type_test=\"tensor\")\n        self.assertEqual(get_equivalent_dtype(result.dtype, data_type=np.ndarray), 
arguments.get(\"dtype\", np.float32))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_histogram_normalized.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import HistogramNormalized\nfrom monai.utils import get_equivalent_dtype\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"num_bins\": 4, \"min\": 1, \"max\": 5, \"mask_key\": \"mask\"},\n            {\"img\": p(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])), \"mask\": p(np.array([1, 1, 1, 1, 1, 0]))},\n            p(np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0])),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"num_bins\": 4, \"max\": 4, \"dtype\": np.uint8},\n            {\"img\": p(np.array([0.0, 1.0, 2.0, 3.0, 4.0]))},\n            p(np.array([0, 0, 1, 3, 4])),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"num_bins\": 256, \"max\": 255, \"dtype\": np.uint8},\n            {\"img\": p(np.array([[[100.0, 200.0], [150.0, 250.0]]]))},\n            p(np.array([[[0, 170], [70, 255]]])),\n        ]\n    )\n\n\nclass TestHistogramNormalized(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        result = HistogramNormalized(**arguments)(image)[\"img\"]\n        assert_allclose(result, 
expected_data, type_test=\"tensor\")\n        self.assertEqual(get_equivalent_dtype(result.dtype, data_type=np.ndarray), arguments.get(\"dtype\", np.float32))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_image_filter.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.networks.layers.simplelayers import GaussianFilter\nfrom monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd\n\nEXPECTED_FILTERS = {\n    \"mean\": torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]).float(),\n    \"laplace\": torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]).float(),\n    \"elliptical\": torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]]).float(),\n    \"sharpen\": torch.tensor([[0, -1, 0], [-1, 5, -1], [0, -1, 0]]).float(),\n}\n\nSUPPORTED_FILTERS = [\"mean\", \"laplace\", \"elliptical\", \"sobel\", \"sharpen\", \"median\", \"gauss\", \"savitzky_golay\"]\nSAMPLE_IMAGE_2D = torch.randn(1, 10, 10)\nSAMPLE_IMAGE_3D = torch.randn(1, 10, 10, 10)\nSAMPLE_DICT = {\"image_2d\": SAMPLE_IMAGE_2D, \"image_3d\": SAMPLE_IMAGE_3D}\n\n# Sobel filter uses reflect pad as default which is not implemented for 3d in torch 1.8.1 or 1.9.1\nADDITIONAL_ARGUMENTS = {\"order\": 1, \"sigma\": 1, \"padding_mode\": \"zeros\"}\n\n\nclass TestModule(torch.nn.Module):\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        return x + 
1\n\n\nclass TestNotAModuleOrTransform:\n    pass\n\n\nclass TestImageFilter(unittest.TestCase):\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_init_from_string(self, filter_name):\n        \"Test init from string\"\n        _ = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS)\n\n    def test_init_raises(self):\n        with self.assertRaises(Exception) as context:\n            _ = ImageFilter(\"mean\")\n            self.assertTrue(\"`filter_size` must be specified when specifying filters by string.\" in str(context.output))\n        with self.assertRaises(Exception) as context:\n            _ = ImageFilter(\"mean\")\n            self.assertTrue(\"`filter_size` should be a single uneven integer.\" in str(context.output))\n        with self.assertRaises(Exception) as context:\n            _ = ImageFilter(\"gauss\", 3)\n            self.assertTrue(\"`filter='gauss', requires the additonal keyword argument `sigma`\" in str(context.output))\n        with self.assertRaises(Exception) as context:\n            _ = ImageFilter(\"savitzky_golay\", 3)\n            self.assertTrue(\n                \"`filter='savitzky_golay', requires the additonal keyword argument `order`\" in str(context.output)\n            )\n\n    def test_init_from_array(self):\n        \"Test init with custom filter and assert wrong filter shape throws an error\"\n        _ = ImageFilter(torch.ones(3, 3))\n        _ = ImageFilter(torch.ones(3, 3, 3))\n        _ = ImageFilter(np.ones((3, 3)))\n        _ = ImageFilter(np.ones((3, 3, 3)))\n\n        with self.assertRaises(Exception) as context:\n            _ = ImageFilter(torch.ones(3, 3, 3, 3))\n            self.assertTrue(\"Only 1D, 2D, and 3D filters are supported.\" in str(context.output))\n\n    def test_init_from_module(self):\n        filter = ImageFilter(TestModule())\n        out = filter(torch.zeros(1, 3, 3, 3))\n        torch.testing.assert_allclose(torch.ones(1, 3, 3, 3), out)\n\n    def test_init_from_transform(self):\n     
   _ = ImageFilter(GaussianFilter(3, sigma=2))\n\n    def test_init_from_wrong_type_fails(self):\n        with self.assertRaises(Exception) as context:\n            _ = ImageFilter(TestNotAModuleOrTransform())\n            self.assertTrue(\"<class 'type'> is not supported.\" in str(context.output))\n\n    @parameterized.expand(EXPECTED_FILTERS.keys())\n    def test_2d_filter_correctness(self, filter_name):\n        \"Test correctness of filters (2d only)\"\n        tfm = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS)\n        filter = tfm._get_filter_from_string(filter_name, size=3, ndim=2).filter.squeeze()\n        torch.testing.assert_allclose(filter, EXPECTED_FILTERS[filter_name])\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_call_2d(self, filter_name):\n        \"Text function `__call__` for 2d images\"\n        filter = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(SAMPLE_IMAGE_2D)\n        self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:])\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_call_3d(self, filter_name):\n        \"Text function `__call__` for 3d images\"\n        filter = ImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(SAMPLE_IMAGE_3D)\n        self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:])\n\n    def test_pass_applied_operations(self):\n        \"Test that applied operations are passed through\"\n        applied_operations = [\"op1\", \"op2\"]\n        image = MetaTensor(SAMPLE_IMAGE_2D, applied_operations=applied_operations)\n        filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(image)\n        self.assertEqual(out_tensor.applied_operations, applied_operations)\n\n    def test_pass_empty_metadata_dict(self):\n        \"Test that applied operations are passed through\"\n        image = MetaTensor(SAMPLE_IMAGE_2D, meta={})\n        filter = 
ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(image)\n        self.assertTrue(isinstance(out_tensor, MetaTensor))\n\n    def test_gaussian_filter_without_filter_size(self):\n        \"Test Gaussian filter without specifying filter_size\"\n        filter = ImageFilter(\"gauss\", sigma=2)\n        out_tensor = filter(SAMPLE_IMAGE_2D)\n        self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:])\n\n\nclass TestImageFilterDict(unittest.TestCase):\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_init_from_string_dict(self, filter_name):\n        \"Test init from string and assert an error is thrown if no size is passed\"\n        _ = ImageFilterd(\"image\", filter_name, 3, **ADDITIONAL_ARGUMENTS)\n        with self.assertRaises(Exception) as _:\n            _ = ImageFilterd(self.image_key, filter_name)\n\n    def test_init_from_array_dict(self):\n        \"Test init with custom filter and assert wrong filter shape throws an error\"\n        _ = ImageFilterd(\"image\", torch.ones(3, 3))\n        with self.assertRaises(Exception) as _:\n            _ = ImageFilterd(self.image_key, torch.ones(3, 3, 3, 3))\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_call_2d(self, filter_name):\n        \"Text function `__call__` for 2d images\"\n        filter = ImageFilterd(\"image_2d\", filter_name, 3, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(SAMPLE_DICT)\n        self.assertEqual(out_tensor[\"image_2d\"].shape[1:], SAMPLE_IMAGE_2D.shape[1:])\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_call_3d(self, filter_name):\n        \"Text function `__call__` for 3d images\"\n        filter = ImageFilterd(\"image_3d\", filter_name, 3, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(SAMPLE_DICT)\n        self.assertEqual(out_tensor[\"image_3d\"].shape[1:], SAMPLE_IMAGE_3D.shape[1:])\n\n\nclass TestRandImageFilter(unittest.TestCase):\n\n    
@parameterized.expand(SUPPORTED_FILTERS)\n    def test_init_from_string(self, filter_name):\n        \"Test init from string and assert an error is thrown if no size is passed\"\n        _ = RandImageFilter(filter_name, 3, **ADDITIONAL_ARGUMENTS)\n        with self.assertRaises(Exception) as _:\n            _ = RandImageFilter(filter_name)\n\n    def test_init_from_array(self):\n        \"Test init with custom filter and assert wrong filter shape throws an error\"\n        _ = RandImageFilter(torch.ones(3, 3))\n        with self.assertRaises(Exception) as _:\n            _ = RandImageFilter(torch.ones(3, 3, 3, 3))\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_call_2d_prob_1(self, filter_name):\n        \"Text function `__call__` for 2d images\"\n        filter = RandImageFilter(filter_name, 3, 1, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(SAMPLE_IMAGE_2D)\n        self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:])\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_call_3d_prob_1(self, filter_name):\n        \"Text function `__call__` for 3d images\"\n        filter = RandImageFilter(filter_name, 3, 1, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(SAMPLE_IMAGE_3D)\n        self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:])\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_call_2d_prob_0(self, filter_name):\n        \"Text function `__call__` for 2d images\"\n        filter = RandImageFilter(filter_name, 3, 0, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(SAMPLE_IMAGE_2D)\n        torch.testing.assert_allclose(out_tensor, SAMPLE_IMAGE_2D)\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_call_3d_prob_0(self, filter_name):\n        \"Text function `__call__` for 3d images\"\n        filter = RandImageFilter(filter_name, 3, 0, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(SAMPLE_IMAGE_3D)\n        torch.testing.assert_allclose(out_tensor, 
SAMPLE_IMAGE_3D)\n\n\nclass TestRandImageFilterDict(unittest.TestCase):\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_init_from_string_dict(self, filter_name):\n        \"Test init from string and assert an error is thrown if no size is passed\"\n        _ = RandImageFilterd(\"image\", filter_name, 3, **ADDITIONAL_ARGUMENTS)\n        with self.assertRaises(Exception) as _:\n            _ = RandImageFilterd(\"image\", filter_name)\n\n    def test_init_from_array_dict(self):\n        \"Test init with custom filter and assert wrong filter shape throws an error\"\n        _ = RandImageFilterd(\"image\", torch.ones(3, 3))\n        with self.assertRaises(Exception) as _:\n            _ = RandImageFilterd(\"image\", torch.ones(3, 3, 3, 3))\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_call_2d_prob_1(self, filter_name):\n        filter = RandImageFilterd(\"image_2d\", filter_name, 3, 1.0, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(SAMPLE_DICT)\n        self.assertEqual(out_tensor[\"image_2d\"].shape[1:], SAMPLE_IMAGE_2D.shape[1:])\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_call_3d_prob_1(self, filter_name):\n        filter = RandImageFilterd(\"image_3d\", filter_name, 3, 1.0, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(SAMPLE_DICT)\n        self.assertEqual(out_tensor[\"image_3d\"].shape[1:], SAMPLE_IMAGE_3D.shape[1:])\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_call_2d_prob_0(self, filter_name):\n        filter = RandImageFilterd(\"image_2d\", filter_name, 3, 0.0, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(SAMPLE_DICT)\n        torch.testing.assert_allclose(out_tensor[\"image_2d\"].shape[1:], SAMPLE_IMAGE_2D.shape[1:])\n\n    @parameterized.expand(SUPPORTED_FILTERS)\n    def test_call_3d_prob_0(self, filter_name):\n        filter = RandImageFilterd(\"image_3d\", filter_name, 3, 0.0, **ADDITIONAL_ARGUMENTS)\n        out_tensor = filter(SAMPLE_DICT)\n        
torch.testing.assert_allclose(out_tensor[\"image_3d\"].shape[1:], SAMPLE_IMAGE_3D.shape[1:])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_intensity_stats.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import IntensityStats\nfrom tests.test_utils import TEST_NDARRAYS\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.extend(\n        [\n            [\n                {\"ops\": [\"max\", \"mean\"], \"key_prefix\": \"orig\"},\n                p([[[0.0, 1.0], [2.0, 3.0]]]),\n                {\"affine\": None},\n                {\"orig_max\": 3.0, \"orig_mean\": 1.5},\n            ],\n            [{\"ops\": \"std\", \"key_prefix\": \"orig\"}, p([[[0.0, 1.0], [2.0, 3.0]]]), None, {\"orig_std\": 1.118034}],\n            [\n                {\"ops\": [np.mean, \"max\", np.min], \"key_prefix\": \"orig\"},\n                p([[[0.0, 1.0], [2.0, 3.0]]]),\n                None,\n                {\"orig_custom_0\": 1.5, \"orig_max\": 3.0, \"orig_custom_1\": 0.0},\n            ],\n            [\n                {\"ops\": [\"max\", \"mean\"], \"key_prefix\": \"orig\", \"channel_wise\": True},\n                p([[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]),\n                {\"affine\": None},\n                {\"orig_max\": [3.0, 7.0], \"orig_mean\": [1.5, 5.5]},\n            ],\n            [\n                {\"ops\": [\"max\", \"mean\"], \"key_prefix\": \"orig\"},\n                p([[[0.0, 1.0], [2.0, 3.0]]]),\n                
{\"affine\": None},\n                {\"orig_max\": 3.0, \"orig_mean\": 1.5},\n            ],\n        ]\n    )\n\n\nclass TestIntensityStats(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, input_param, img, meta_dict, expected):\n        _, meta_dict = IntensityStats(**input_param)(img, meta_dict)\n        for k, v in expected.items():\n            self.assertTrue(k in meta_dict)\n            np.testing.assert_allclose(v, meta_dict[k], atol=1e-3)\n\n    def test_mask(self):\n        for p in TEST_NDARRAYS:\n            img = p([[[0.0, 1.0], [2.0, 3.0]]])\n            mask = np.array([[[1, 0], [1, 0]]], dtype=bool)\n            img, meta_dict = IntensityStats(ops=[\"max\", \"mean\"], key_prefix=\"orig\")(img, mask=mask)\n            np.testing.assert_allclose(meta_dict[\"orig_max\"], 2.0, atol=1e-3)\n            np.testing.assert_allclose(meta_dict[\"orig_mean\"], 1.0, atol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_intensity_statsd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport unittest\n\nimport numpy as np\nimport torch.multiprocessing as mp\nfrom parameterized import parameterized\n\nfrom monai.data import DataLoader, Dataset\nfrom monai.transforms import IntensityStatsd\nfrom monai.utils.enums import PostFix\n\nTEST_CASE_1 = [\n    {\"keys\": \"img\", \"ops\": [\"max\", \"mean\"], \"key_prefix\": \"orig\", \"meta_keys\": \"test_meta\"},\n    {\"img\": np.array([[[0.0, 1.0], [2.0, 3.0]]]), \"test_meta\": {\"affine\": None}},\n    \"test_meta\",\n    {\"orig_max\": 3.0, \"orig_mean\": 1.5},\n]\n\nTEST_CASE_2 = [\n    {\"keys\": \"img\", \"ops\": \"std\", \"key_prefix\": \"orig\"},\n    {\"img\": np.array([[[0.0, 1.0], [2.0, 3.0]]])},\n    PostFix.meta(\"img\"),\n    {\"orig_std\": 1.118034},\n]\n\nTEST_CASE_3 = [\n    {\"keys\": \"img\", \"ops\": [np.mean, \"max\", np.min], \"key_prefix\": \"orig\"},\n    {\"img\": np.array([[[0.0, 1.0], [2.0, 3.0]]])},\n    PostFix.meta(\"img\"),\n    {\"orig_custom_0\": 1.5, \"orig_max\": 3.0, \"orig_custom_1\": 0.0},\n]\n\nTEST_CASE_4 = [\n    {\"keys\": \"img\", \"ops\": [\"max\", \"mean\"], \"key_prefix\": \"orig\", \"channel_wise\": True, \"meta_key_postfix\": \"meta\"},\n    {\"img\": np.array([[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]), \"img_meta\": {\"affine\": None}},\n    \"img_meta\",\n    {\"orig_max\": [3.0, 7.0], \"orig_mean\": 
[1.5, 5.5]},\n]\n\n\nclass TestIntensityStatsd(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_value(self, input_param, data, meta_key, expected):\n        meta = IntensityStatsd(**input_param)(data)[meta_key]\n        for k, v in expected.items():\n            self.assertTrue(k in meta)\n            np.testing.assert_allclose(v, meta[k], atol=1e-3)\n\n    def test_dataloader(self):\n        dataset = Dataset(\n            data=[{\"img\": np.array([[[0.0, 1.0], [2.0, 3.0]]])}, {\"img\": np.array([[[0.0, 1.0], [2.0, 3.0]]])}],\n            transform=IntensityStatsd(keys=\"img\", ops=[\"max\", \"mean\"], key_prefix=\"orig\"),\n        )\n        # set num workers = 0 for mac / win\n        num_workers = 2 if sys.platform == \"linux\" else 0\n        dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=2)\n        orig_method = mp.get_start_method()\n        mp.set_start_method(\"spawn\", force=True)\n\n        for d in dataloader:\n            meta = d[PostFix.meta(\"img\")]\n            np.testing.assert_allclose(meta[\"orig_max\"], [3.0, 3.0], atol=1e-3)\n            np.testing.assert_allclose(meta[\"orig_mean\"], [1.5, 1.5], atol=1e-3)\n        # restore the mp method\n        mp.set_start_method(orig_method, force=True)\n\n    def test_mask(self):\n        data = {\"img\": np.array([[[0.0, 1.0], [2.0, 3.0]]]), \"img_mask\": np.array([[[1, 0], [1, 0]]], dtype=bool)}\n        stats = IntensityStatsd(keys=\"img\", ops=[\"max\", \"mean\"], mask_keys=\"img_mask\", key_prefix=\"orig\")\n        meta = stats(data)[PostFix.meta(\"img\")]\n        np.testing.assert_allclose(meta[\"orig_max\"], 2.0, atol=1e-3)\n        np.testing.assert_allclose(meta[\"orig_mean\"], 1.0, atol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_inverse_collation.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport unittest\nfrom typing import TYPE_CHECKING\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import (\n    CacheDataset,\n    DataLoader,\n    MetaTensor,\n    create_test_image_2d,\n    create_test_image_3d,\n    decollate_batch,\n    pad_list_data_collate,\n)\nfrom monai.transforms import (\n    Compose,\n    EnsureChannelFirstd,\n    Flipd,\n    LoadImaged,\n    RandAffined,\n    RandAxisFlipd,\n    RandFlipd,\n    RandRotate90d,\n    RandRotated,\n    RandZoomd,\n    ResizeWithPadOrCropd,\n    Rotated,\n)\nfrom monai.utils import optional_import, set_determinism\nfrom tests.test_utils import make_nifti_image\n\nif TYPE_CHECKING:\n    has_nib = True\nelse:\n    _, has_nib = optional_import(\"nibabel\")\n\nKEYS = [\"image\", \"label\"]\n\nTESTS_3D = [\n    (t.__class__.__name__ + (\" pad_list_data_collate\" if collate_fn else \" default_collate\"), t, collate_fn, 3)\n    for collate_fn in [None, pad_list_data_collate]\n    for t in [\n        Flipd(KEYS, spatial_axis=1),\n        RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]),\n        RandAxisFlipd(keys=KEYS, prob=0.5),\n        Compose([RandRotate90d(keys=KEYS, spatial_axes=(1, 2))]),\n        RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),\n        Rotated(keys=KEYS, angle=np.pi, 
dtype=np.float64),\n        RandRotated(keys=KEYS, prob=0.5, range_x=np.pi, dtype=np.float64),\n        RandAffined(\n            keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        ),\n    ]\n]\n\nTESTS_2D = [\n    (t.__class__.__name__ + (\" pad_list_data_collate\" if collate_fn else \" default_collate\"), t, collate_fn, 2)\n    for collate_fn in [None, pad_list_data_collate]\n    for t in [\n        Flipd(KEYS, spatial_axis=1),\n        RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]),\n        RandAxisFlipd(keys=KEYS, prob=0.5),\n        Compose([RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1))]),\n        RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),\n        Rotated(keys=KEYS, angle=np.pi / 2, dtype=np.float64),\n        RandRotated(keys=KEYS, prob=0.5, range_x=np.pi, dtype=np.float64),\n        RandAffined(\n            keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        ),\n    ]\n]\n\n\nclass TestInverseCollation(unittest.TestCase):\n    \"\"\"Test collation of random transformations with prob == 0 and 1.\"\"\"\n\n    def setUp(self):\n        if not has_nib:\n            self.skipTest(\"nibabel required for test_inverse\")\n\n        set_determinism(seed=0)\n\n        b_size = 11\n        im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107))\n        load_ims = Compose([LoadImaged(KEYS), EnsureChannelFirstd(KEYS, channel_dim=\"no_channel\")])\n        self.data_3d = [load_ims({\"image\": im_fname, \"label\": seg_fname}) for _ in range(b_size)]\n\n        b_size = 8\n        im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_2d(62, 37, rad_max=10))\n        load_ims = Compose([LoadImaged(KEYS), EnsureChannelFirstd(KEYS, channel_dim=\"no_channel\")])\n        self.data_2d = [load_ims({\"image\": im_fname, \"label\": 
seg_fname}) for _ in range(b_size)]\n\n        self.batch_size = 7\n\n    def tearDown(self):\n        set_determinism(seed=None)\n\n    @parameterized.expand(TESTS_2D + TESTS_3D)\n    def test_collation(self, _, transform, collate_fn, ndim):\n        \"\"\"transform, collate_fn, ndim\"\"\"\n        data = self.data_3d if ndim == 3 else self.data_2d\n        if collate_fn:\n            modified_transform = transform\n        else:\n            modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100)])\n\n        # num workers = 0 for mac or gpu transforms\n        num_workers = 0 if sys.platform != \"linux\" or torch.cuda.is_available() else 2\n\n        dataset = CacheDataset(data, transform=modified_transform, progress=False)\n        loader = DataLoader(dataset, num_workers, batch_size=self.batch_size, collate_fn=collate_fn)\n\n        for item in loader:\n            if isinstance(item, dict):\n                np.testing.assert_array_equal(item[\"image\"].shape, item[\"label\"].shape)\n                continue\n            d = decollate_batch(item)\n            self.assertTrue(len(d) <= self.batch_size)\n            for b in d:\n                self.assertIsInstance(b[\"image\"], MetaTensor)\n                np.testing.assert_array_equal(\n                    b[\"image\"].applied_operations[-1][\"orig_size\"], b[\"label\"].applied_operations[-1][\"orig_size\"]\n                )\n                np.testing.assert_array_equal(\n                    b[\"image\"].applied_operations[-1].get(\"_do_transform\"),\n                    b[\"label\"].applied_operations[-1].get(\"_do_transform\"),\n                )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_k_space_spike_noise.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom numpy.fft import fftn, fftshift\nfrom parameterized import parameterized\n\nfrom monai.data.synthetic import create_test_image_2d, create_test_image_3d\nfrom monai.transforms import KSpaceSpikeNoise\nfrom monai.utils.misc import set_determinism\nfrom tests.test_utils import TEST_NDARRAYS\n\nTESTS = []\nfor shape in ((128, 64), (64, 48, 80)):\n    for p in TEST_NDARRAYS:\n        for intensity in [10, None]:\n            TESTS.append((shape, p, intensity))\n\n\nclass TestKSpaceSpikeNoise(unittest.TestCase):\n    def setUp(self):\n        set_determinism(0)\n        super().setUp()\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @staticmethod\n    def get_data(im_shape, im_type):\n        create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d\n        im, _ = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)\n        return im_type(im[None])\n\n    @parameterized.expand(TESTS)\n    def test_same_result(self, im_shape, im_type, k_intensity):\n        im = self.get_data(im_shape, im_type)\n        loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0]\n        t = KSpaceSpikeNoise(loc, k_intensity)\n\n        out1 = t(deepcopy(im))\n        
out2 = t(deepcopy(im))\n\n        if isinstance(out1, torch.Tensor):\n            out1 = out1.cpu()\n            out2 = out2.cpu()\n\n        np.testing.assert_allclose(out1, out2)\n\n    @parameterized.expand(TESTS)\n    def test_highlighted_kspace_pixel(self, im_shape, as_tensor_input, k_intensity):\n        im = self.get_data(im_shape, as_tensor_input)\n        loc = [0, int(im.shape[1] / 2), 0] if len(im_shape) == 2 else [0, int(im.shape[1] / 2), 0, 0]\n        t = KSpaceSpikeNoise(loc, k_intensity)\n        out = t(im)\n\n        if isinstance(out, torch.Tensor):\n            out = out.cpu()\n\n        if k_intensity is not None:\n            n_dims = len(im_shape)\n            out_k = fftshift(fftn(out, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)))\n            log_mag = np.log(np.absolute(out_k))\n            np.testing.assert_allclose(k_intensity, log_mag[tuple(loc)], 1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_k_space_spike_noised.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom numpy.fft import fftn, fftshift\nfrom parameterized import parameterized\n\nfrom monai.data.synthetic import create_test_image_2d, create_test_image_3d\nfrom monai.transforms import KSpaceSpikeNoised\nfrom monai.utils.misc import set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor shape in ((128, 64), (64, 48, 80)):\n    for p in TEST_NDARRAYS:\n        TESTS.append((shape, p))\n\nKEYS = [\"image\", \"label\"]\n\n\nclass TestKSpaceSpikeNoised(unittest.TestCase):\n    def setUp(self):\n        set_determinism(0)\n        super().setUp()\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @staticmethod\n    def get_data(im_shape, im_type):\n        create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d\n        ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)\n        ims = [im_type(im[None]) for im in ims]\n        return dict(zip(KEYS, ims))\n\n    @parameterized.expand(TESTS)\n    def test_same_result(self, im_shape, im_type):\n        data = self.get_data(im_shape, im_type)\n        loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))]\n        k_intensity = 10\n\n        t = KSpaceSpikeNoised(KEYS, loc, 
k_intensity)\n        out1 = t(deepcopy(data))\n        out2 = t(deepcopy(data))\n\n        for k in KEYS:\n            if isinstance(out1[k], torch.Tensor):\n                out1[k] = out1[k].cpu()\n                out2[k] = out2[k].cpu()\n            np.testing.assert_allclose(out1[k], out2[k])\n\n    @parameterized.expand(TESTS)\n    def test_highlighted_kspace_pixel(self, im_shape, im_type):\n        data = self.get_data(im_shape, im_type)\n        loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))]\n        k_intensity = 10\n\n        t = KSpaceSpikeNoised(KEYS, loc, k_intensity)\n        out = t(data)\n\n        for k in KEYS:\n            if isinstance(out[k], torch.Tensor):\n                out[k] = out[k].cpu()\n\n            n_dims = len(im_shape)\n            out_k = fftshift(fftn(out[k], axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)))\n            log_mag = np.log(np.absolute(out_k))\n            np.testing.assert_allclose(k_intensity, log_mag[tuple(loc)], 1e-1)\n\n    @parameterized.expand(TESTS)\n    def test_dict_matches(self, im_shape, im_type):\n        data = self.get_data(im_shape, im_type)\n        # use same image for both dictionary entries to check same trans is applied to them\n        data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])}\n        loc = [0] + [int(im_shape[i] / 2) for i in range(len(im_shape))]\n        k_intensity = 10\n\n        t = KSpaceSpikeNoised(KEYS, loc, k_intensity)\n        out = t(deepcopy(data))\n        assert_allclose(out[KEYS[0]], out[KEYS[1]], type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_keep_largest_connected_component.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport torch\nimport torch.nn.functional as F\nfrom parameterized import parameterized\n\nfrom monai.transforms import KeepLargestConnectedComponent\nfrom monai.transforms.utils_pytorch_numpy_unification import moveaxis\nfrom monai.utils.type_conversion import convert_to_dst_type\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n\ndef to_onehot(x):\n    out = moveaxis(F.one_hot(torch.as_tensor(x).long())[0], -1, 0)\n    out, *_ = convert_to_dst_type(out, x)\n    return out\n\n\ngrid_1 = [[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]\ngrid_2 = [[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]\ngrid_3 = [\n    [\n        [1.0, 1.0, 0.0, 1.0, 1.0],\n        [1.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 1.0, 1.0],\n        [0.0, 0.0, 1.0, 0.0, 1.0],\n        [0.0, 0.0, 1.0, 1.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 1.0, 0.0, 0.0],\n        [0.0, 0.0, 1.0, 1.0, 1.0],\n        [1.0, 0.0, 1.0, 0.0, 0.0],\n        [1.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0],\n        [1.0, 1.0, 0.0, 0.0, 1.0],\n    ],\n]\ngrid_4 = [\n    
[\n        [1.0, 1.0, 1.0, 1.0, 0.0],\n        [1.0, 1.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0],\n        [1.0, 1.0, 1.0, 1.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 1.0],\n        [0.0, 0.0, 1.0, 1.0, 1.0],\n        [1.0, 0.0, 1.0, 1.0, 0.0],\n        [1.0, 0.0, 1.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 1.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 1.0],\n        [0.0, 0.0, 0.0, 1.0, 1.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n]\ngrid_5 = [[[0, 0, 1, 0, 0], [0, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 1, 0], [1, 1, 0, 0, 1]]]\n\ngrid_6 = [[[0, 0, 1, 1, 0, 0, 1], [0, 0, 0, 1, 0, 0, 1], [1, 1, 0, 0, 1, 0, 1], [0, 0, 0, 1, 0, 0, 1]]]\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            \"value_1\",\n            {\"independent\": False, \"applied_labels\": 1, \"is_onehot\": False},\n            p(grid_1),\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"value_2\",\n            {\"independent\": False, \"applied_labels\": [2], \"is_onehot\": False},\n            p(grid_1),\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"independent_value_1_2\",\n            {\"independent\": True, \"applied_labels\": [1, 2], \"is_onehot\": False},\n            p(grid_1),\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"dependent_value_1_2\",\n            {\"independent\": False, \"applied_labels\": [1, 2], \"is_onehot\": False},\n            p(grid_1),\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 
0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"value_1\",\n            {\"independent\": True, \"applied_labels\": [1], \"is_onehot\": False},\n            p(grid_2),\n            torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"independent_value_1_2\",\n            {\"independent\": True, \"applied_labels\": [1, 2], \"is_onehot\": False},\n            p(grid_2),\n            torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"dependent_value_1_2\",\n            {\"independent\": False, \"applied_labels\": [1, 2], \"is_onehot\": False},\n            p(grid_2),\n            torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"value_1_connect_1\",\n            {\"independent\": False, \"applied_labels\": [1], \"connectivity\": 1, \"is_onehot\": False},\n            p(grid_1),\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"independent_value_1_2_connect_1\",\n            {\"independent\": True, \"applied_labels\": [1, 2], \"connectivity\": 1, \"is_onehot\": False},\n            p(grid_1),\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"onehot_none_dependent_value_1_2_connect_1\",\n            {\"independent\": False, \"applied_labels\": [1, 2], \"connectivity\": 1},\n            p(grid_1),\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 
0]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"onehot_independent_batch_2_apply_label_1_connect_1\",\n            {\"independent\": True, \"applied_labels\": [1], \"connectivity\": 1, \"is_onehot\": True},\n            p(grid_3),\n            torch.tensor(\n                [\n                    [\n                        [1.0, 1.0, 0.0, 1.0, 1.0],\n                        [1.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 0.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 1.0],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"onehot_independent_batch_2_apply_label_1_connect_2\",\n            {\"independent\": True, \"applied_labels\": [1], \"connectivity\": 2, \"is_onehot\": True},\n            p(grid_3),\n            torch.tensor(\n                [\n                    [\n                        [1.0, 1.0, 0.0, 1.0, 1.0],\n                        [1.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 0.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n              
          [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 1.0],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"onehot_independent_batch_2_apply_label_1_2_connect_2\",\n            {\"independent\": True, \"applied_labels\": [1, 2], \"connectivity\": 2, \"is_onehot\": True},\n            p(grid_3),\n            torch.tensor(\n                [\n                    [\n                        [1.0, 1.0, 0.0, 1.0, 1.0],\n                        [1.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 0.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"onehot_dependent_batch_2_apply_label_1_2_connect_2\",\n            {\"independent\": False, \"applied_labels\": [1, 2], \"connectivity\": 2, \"is_onehot\": True},\n         
   p(grid_4),\n            torch.tensor(\n                [\n                    [\n                        [1.0, 1.0, 1.0, 1.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 1.0, 1.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 0.0],\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 1.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 1.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"onehot_none_dependent_batch_2_apply_label_1_2_connect_1\",\n            {\"independent\": False, \"applied_labels\": [1, 2], \"connectivity\": 1},\n            p(grid_4),\n            torch.tensor(\n                [\n                    [\n                        [1.0, 1.0, 1.0, 1.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 1.0, 1.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 0.0],\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 1.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n            
            [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 1.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            \"all_non_zero_labels\",\n            {\"independent\": True},\n            p(grid_1),\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]),\n        ]\n    )\n    # no connected regions\n    TESTS.append([\"0 regions\", {\"num_components\": 0}, p(grid_6), p(torch.zeros(1, 4, 7))])\n    # 1 connected region\n    TESTS.append(\n        [\n            \"1 region\",\n            {\"num_components\": 1},\n            p(grid_6),\n            p(\n                torch.tensor(\n                    [[[0, 0, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 1, 0, 0, 0]]]\n                )\n            ),\n        ]\n    )\n    # 2 connected regions\n    TESTS.append(\n        [\n            \"2 regions\",\n            {\"num_components\": 2},\n            p(grid_6),\n            p(\n                torch.tensor(\n                    [[[0, 0, 1, 1, 0, 0, 1], [0, 0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 1, 0, 1], [0, 0, 0, 1, 0, 0, 1]]]\n                )\n            ),\n        ]\n    )\n    # 3+ connected regions unchanged (as input has 3)\n    for num_connected in (3, 4):\n        TESTS.append([f\"{num_connected} regions\", {\"num_components\": num_connected}, p(grid_6), p(grid_6)])\n\n\nclass TestKeepLargestConnectedComponent(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_correct_results(self, _, args, input_image, expected):\n        converter = KeepLargestConnectedComponent(**args)\n        result = converter(input_image)\n        assert_allclose(result, expected, type_test=\"tensor\")\n\n    @parameterized.expand(TESTS)\n    def 
test_correct_results_before_after_onehot(self, _, args, input_image, expected):\n        \"\"\"\n        From torch==1.7, torch.argmax changes its mechanism that if there are multiple maximal values then the\n        indices of the first maximal value are returned (before this version, the indices of the last maximal value\n        are returned).\n        Therefore, we can make use of this change to convert the onehotted labels into un-onehot format directly\n        and then check if the result stays the same.\n\n        \"\"\"\n        converter = KeepLargestConnectedComponent(**args)\n        result = converter(deepcopy(input_image))\n\n        if \"is_onehot\" in args:\n            args[\"is_onehot\"] = not args[\"is_onehot\"]\n        # if not onehotted, onehot it and make sure result stays the same\n        if input_image.shape[0] == 1:\n            img = to_onehot(input_image)\n            result2 = KeepLargestConnectedComponent(**args)(img)\n            result2 = result2.argmax(0)[None]\n            assert_allclose(result, result2, type_test=\"tensor\")\n        # if onehotted, un-onehot and check result stays the same\n        else:\n            img = input_image.argmax(0)[None]\n            result2 = KeepLargestConnectedComponent(**args)(img)\n            assert_allclose(result.argmax(0)[None], result2, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_keep_largest_connected_componentd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import KeepLargestConnectedComponentd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\ngrid_1 = [[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]\ngrid_2 = [[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]\ngrid_3 = [\n    [\n        [1.0, 1.0, 0.0, 1.0, 1.0],\n        [1.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 1.0, 1.0],\n        [0.0, 0.0, 1.0, 0.0, 1.0],\n        [0.0, 0.0, 1.0, 1.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 1.0, 0.0, 0.0],\n        [0.0, 0.0, 1.0, 1.0, 1.0],\n        [1.0, 0.0, 1.0, 0.0, 0.0],\n        [1.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0],\n        [1.0, 1.0, 0.0, 0.0, 1.0],\n    ],\n]\ngrid_4 = [\n    [\n        [1.0, 1.0, 1.0, 1.0, 0.0],\n        [1.0, 1.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0],\n        [1.0, 1.0, 1.0, 1.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 1.0],\n        [0.0, 0.0, 1.0, 1.0, 1.0],\n        [1.0, 0.0, 1.0, 1.0, 0.0],\n        [1.0, 0.0, 1.0, 0.0, 0.0],\n   
     [0.0, 0.0, 0.0, 0.0, 1.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 1.0],\n        [0.0, 0.0, 0.0, 1.0, 1.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n]\ngrid_5 = [[[0, 0, 1, 0, 0], [0, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 1, 0], [1, 1, 0, 0, 1]]]\n\nVALID_CASES = []\nfor p in TEST_NDARRAYS:\n    VALID_CASES.append(\n        [\n            \"value_1\",\n            {\"keys\": [\"img\"], \"independent\": False, \"applied_labels\": 1, \"is_onehot\": False},\n            {\"img\": p(grid_1)},\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"value_2\",\n            {\"keys\": [\"img\"], \"independent\": False, \"applied_labels\": [2], \"is_onehot\": False},\n            {\"img\": p(grid_1)},\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"independent_value_1_2\",\n            {\"keys\": [\"img\"], \"independent\": True, \"applied_labels\": [1, 2], \"is_onehot\": False},\n            {\"img\": p(grid_1)},\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"dependent_value_1_2\",\n            {\"keys\": [\"img\"], \"independent\": False, \"applied_labels\": [1, 2], \"is_onehot\": False},\n            {\"img\": p(grid_1)},\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"value_1\",\n            {\"keys\": [\"img\"], \"independent\": True, \"applied_labels\": [1], \"is_onehot\": False},\n            {\"img\": p(grid_2)},\n            
torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"independent_value_1_2\",\n            {\"keys\": [\"img\"], \"independent\": True, \"applied_labels\": [1, 2], \"is_onehot\": False},\n            {\"img\": p(grid_2)},\n            torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"dependent_value_1_2\",\n            {\"keys\": [\"img\"], \"independent\": False, \"applied_labels\": [1, 2], \"is_onehot\": False},\n            {\"img\": p(grid_2)},\n            torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"value_1_connect_1\",\n            {\"keys\": [\"img\"], \"independent\": False, \"applied_labels\": [1], \"connectivity\": 1, \"is_onehot\": False},\n            {\"img\": p(grid_1)},\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"independent_value_1_2_connect_1\",\n            {\"keys\": [\"img\"], \"independent\": True, \"applied_labels\": [1, 2], \"connectivity\": 1, \"is_onehot\": False},\n            {\"img\": p(grid_1)},\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"onehot_none_dependent_value_1_2_connect_1\",\n            {\"keys\": [\"img\"], \"independent\": False, \"applied_labels\": [1, 2], \"connectivity\": 1},\n            {\"img\": p(grid_1)},\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]),\n        ]\n    )\n\n    VALID_CASES.append(\n      
  [\n            \"onehot_independent_batch_2_apply_label_1_connect_1\",\n            {\"keys\": [\"img\"], \"independent\": True, \"applied_labels\": [1], \"connectivity\": 1, \"is_onehot\": True},\n            {\"img\": p(grid_3)},\n            torch.tensor(\n                [\n                    [\n                        [1.0, 1.0, 0.0, 1.0, 1.0],\n                        [1.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 0.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 1.0],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"onehot_independent_batch_2_apply_label_1_connect_2\",\n            {\"keys\": [\"img\"], \"independent\": True, \"applied_labels\": [1], \"connectivity\": 2, \"is_onehot\": True},\n            {\"img\": p(grid_3)},\n            torch.tensor(\n                [\n                    [\n                        [1.0, 1.0, 0.0, 1.0, 1.0],\n                        [1.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 0.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 1.0, 1.0, 
1.0],\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 1.0],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"onehot_independent_batch_2_apply_label_1_2_connect_2\",\n            {\"keys\": [\"img\"], \"independent\": True, \"applied_labels\": [1, 2], \"connectivity\": 2, \"is_onehot\": True},\n            {\"img\": p(grid_3)},\n            torch.tensor(\n                [\n                    [\n                        [1.0, 1.0, 0.0, 1.0, 1.0],\n                        [1.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 0.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 1.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"onehot_dependent_batch_2_apply_label_1_2_connect_2\",\n            {\"keys\": [\"img\"], \"independent\": 
False, \"applied_labels\": [1, 2], \"connectivity\": 2, \"is_onehot\": True},\n            {\"img\": p(grid_4)},\n            torch.tensor(\n                [\n                    [\n                        [1.0, 1.0, 1.0, 1.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 1.0, 1.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 0.0],\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 1.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 1.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"onehot_none_dependent_batch_2_apply_label_1_2_connect_1\",\n            {\"keys\": [\"img\"], \"independent\": False, \"applied_labels\": [1, 2], \"connectivity\": 1},\n            {\"img\": p(grid_4)},\n            torch.tensor(\n                [\n                    [\n                        [1.0, 1.0, 1.0, 1.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0, 0.0, 0.0],\n                        [1.0, 1.0, 1.0, 1.0, 0.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 1.0],\n                        [0.0, 0.0, 1.0, 1.0, 0.0],\n                        [0.0, 0.0, 1.0, 0.0, 0.0],\n                        
[0.0, 0.0, 0.0, 0.0, 1.0],\n                    ],\n                    [\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0, 1.0],\n                        [0.0, 0.0, 0.0, 1.0, 1.0],\n                        [0.0, 0.0, 0.0, 0.0, 0.0],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    VALID_CASES.append(\n        [\n            \"single_channel_onehot\",\n            {\"keys\": [\"img\"], \"independent\": False, \"applied_labels\": 0, \"connectivity\": 1, \"is_onehot\": True},\n            {\"img\": p(grid_5)},\n            torch.tensor([[[0, 0, 1, 0, 0], [0, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 0, 0], [1, 1, 0, 0, 0]]]),\n        ]\n    )\n\n\nclass TestKeepLargestConnectedComponentd(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_results(self, _, args, input_dict, expected):\n        converter = KeepLargestConnectedComponentd(**args)\n        result = converter(input_dict)\n        assert_allclose(result[\"img\"], expected, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_label_filter.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import LabelFilter\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\ngrid_1 = torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]])\n\nVALID_TESTS = []\nfor p in TEST_NDARRAYS:\n    VALID_TESTS.append(\n        [\n            \"filter_single_label\",\n            {\"applied_labels\": 3},\n            p(grid_1),\n            p(torch.tensor([[[[0, 0, 3], [0, 0, 0], [0, 0, 0]]]])),\n        ]\n    )\n\n    VALID_TESTS.append(\n        [\n            \"filter_single_label_list\",\n            {\"applied_labels\": [3]},\n            p(grid_1),\n            p(torch.tensor([[[[0, 0, 3], [0, 0, 0], [0, 0, 0]]]])),\n        ]\n    )\n\n    VALID_TESTS.append(\n        [\n            \"filter_multi_label\",\n            {\"applied_labels\": [3, 5, 8]},\n            p(grid_1),\n            p(torch.tensor([[[[0, 0, 3], [0, 5, 0], [0, 8, 0]]]])),\n        ]\n    )\n\n    VALID_TESTS.append([\"filter_all\", {\"applied_labels\": [1, 2, 3, 4, 5, 6, 7, 8, 9]}, p(grid_1), p(grid_1)])\n\nITEST_CASE_1 = [\"invalid_image_data_type\", {\"applied_labels\": 1}, [[[[1, 1, 1]]]], NotImplementedError]\n\nINVALID_CASES = [ITEST_CASE_1]\n\n\nclass TestLabelFilter(unittest.TestCase):\n    @parameterized.expand(VALID_TESTS)\n    def 
test_correct_results(self, _, args, input_image, expected):\n        converter = LabelFilter(**args)\n        result = converter(input_image)\n        assert_allclose(result, expected)\n\n    @parameterized.expand(INVALID_CASES)\n    def test_raise_exception(self, _, args, input_image, expected_error):\n        with self.assertRaises(expected_error):\n            converter = LabelFilter(**args)\n            if isinstance(input_image, torch.Tensor) and torch.cuda.is_available():\n                _ = converter(input_image.cuda())\n            else:\n                _ = converter(input_image)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_label_to_contour.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms import LabelToContour\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nexpected_output_for_cube = [\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n     
   [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n]\n\n\ndef gen_fixed_cube(array_type):\n    scale, core_start, core_end = 8, 1, 7\n    cube = np.zeros((scale, scale, scale))\n    cube[core_start:core_end, core_start:core_end, core_start:core_end] = torch.ones(\n        core_end - core_start, core_end - core_start, core_end - core_start\n    )\n    cube = cube[None]\n\n    batch_size, channels = 10, 6\n    cube = np.tile(cube, (batch_size, channels, 1, 1, 1))\n    return array_type(cube), array_type(expected_output_for_cube)\n\n\ndef gen_fixed_img(array_type):\n    img = np.array(\n        [\n            [0, 0, 0, 1, 1, 1, 1],\n            [0, 0, 0, 1, 1, 1, 1],\n            [0, 0, 1, 1, 1, 1, 1],\n            [0, 1, 1, 1, 1, 1, 1],\n            [1, 1, 1, 1, 1, 1, 1],\n        ],\n        dtype=np.float32,\n    )\n    batch_size, channels = 10, 6\n    img = array_type(np.tile(img, (batch_size, channels, 1, 1)))\n    expected_output_for_img = array_type(\n        [\n            [0, 0, 0, 1, 1, 1, 1],\n            [0, 0, 0, 1, 0, 0, 1],\n            [0, 0, 1, 1, 0, 0, 1],\n            [0, 1, 1, 0, 0, 0, 1],\n            [1, 1, 1, 1, 1, 1, 1],\n        ]\n    )\n    return img, expected_output_for_img\n\n\nclass TestContour(unittest.TestCase):\n    def test_contour(self):\n        input_param = {\"kernel_type\": \"Laplace\"}\n\n        for p in TEST_NDARRAYS:\n            # check 5-dim input data\n            test_cube, expected_output = gen_fixed_cube(p)\n            for cube in test_cube:\n                test_result_cube = LabelToContour(**input_param)(cube)\n                self.assertEqual(test_result_cube.shape, cube.shape)\n\n                channels = cube.shape[0]\n                for channel in range(channels):\n                    assert_allclose(test_result_cube[channel, ...], 
expected_output, type_test=\"tensor\")\n\n            # check 4-dim input data\n            test_img, expected_output = gen_fixed_img(p)\n            for img in test_img:\n                channels = img.shape[0]\n                test_result_img = LabelToContour(**input_param)(img)\n                self.assertEqual(test_result_img.shape, img.shape)\n\n                for channel in range(channels):\n                    assert_allclose(test_result_img[channel, ...], expected_output, type_test=\"tensor\")\n\n        # check invalid input data\n        error_input = torch.rand(1, 2)\n        self.assertRaises(ValueError, LabelToContour(**input_param), error_input)\n        error_input = torch.rand(1, 2, 3, 4, 5)\n        self.assertRaises(ValueError, LabelToContour(**input_param), error_input)\n        error_input = np.random.rand(1, 2, 3, 4, 5)\n        self.assertRaises(ValueError, LabelToContour(**input_param), error_input)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_label_to_contourd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms import LabelToContourd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nexpected_output_for_cube = [\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n    
    [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    [\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n]\n\n\ndef gen_fixed_cube(array_type):\n    scale, core_start, core_end = 8, 1, 7\n    cube = np.zeros((scale, scale, scale))\n    cube[core_start:core_end, core_start:core_end, core_start:core_end] = torch.ones(\n        core_end - core_start, core_end - core_start, core_end - core_start\n    )\n    cube = cube[None]\n\n    batch_size, channels = 10, 6\n    cube = np.tile(cube, (batch_size, channels, 1, 1, 1))\n    return array_type(cube), array_type(expected_output_for_cube)\n\n\ndef gen_fixed_img(array_type):\n    img = np.array(\n        [\n            [0, 0, 0, 1, 1, 1, 1],\n            [0, 0, 0, 1, 1, 1, 1],\n            [0, 0, 1, 1, 1, 1, 1],\n            [0, 1, 1, 1, 1, 1, 1],\n            [1, 1, 1, 1, 1, 1, 1],\n        ],\n        dtype=np.float32,\n    )\n    batch_size, channels = 10, 6\n    img = np.tile(img, (batch_size, channels, 1, 1))\n    img = array_type(img)\n    expected_output_for_img = array_type(\n        [\n            [0, 0, 0, 1, 1, 1, 1],\n            [0, 0, 0, 1, 0, 0, 1],\n            [0, 0, 1, 1, 0, 0, 1],\n            [0, 1, 1, 0, 0, 0, 1],\n            [1, 1, 1, 1, 1, 1, 1],\n        ]\n    )\n    return img, expected_output_for_img\n\n\nclass TestContourd(unittest.TestCase):\n    def test_contour(self):\n        input_param = {\"keys\": \"img\", \"kernel_type\": \"Laplace\"}\n\n        for p in TEST_NDARRAYS:\n            # check 5-dim input data\n            test_cube, expected_output = gen_fixed_cube(p)\n            for cube in test_cube:\n                test_result_cube = LabelToContourd(**input_param)({\"img\": cube})\n                self.assertEqual(test_result_cube[\"img\"].shape, cube.shape)\n\n                test_result_np = test_result_cube[\"img\"]\n                channels = cube.shape[0]\n         
       for channel in range(channels):\n                    assert_allclose(test_result_np[channel, ...], expected_output, type_test=\"tensor\")\n\n            # check 4-dim input data\n            test_img, expected_output = gen_fixed_img(p)\n            for img in test_img:\n                channels = img.shape[0]\n                test_result_img = LabelToContourd(**input_param)({\"img\": img})\n                self.assertEqual(test_result_img[\"img\"].shape, img.shape)\n\n                test_result_np = test_result_img[\"img\"]\n                for channel in range(channels):\n                    assert_allclose(test_result_np[channel, ...], expected_output, type_test=\"tensor\")\n\n        # check invalid input data\n        error_input = {\"img\": torch.rand(1, 2)}\n        self.assertRaises(ValueError, LabelToContourd(**input_param), error_input)\n        error_input = {\"img\": np.random.rand(1, 2)}\n        self.assertRaises(ValueError, LabelToContourd(**input_param), error_input)\n        error_input = {\"img\": torch.rand(1, 2, 3, 4, 5)}\n        self.assertRaises(ValueError, LabelToContourd(**input_param), error_input)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_label_to_mask.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import LabelToMask\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"select_labels\": [2, 3], \"merge_channels\": False},\n            p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])),\n            np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"select_labels\": 2, \"merge_channels\": False},\n            p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]])),\n            np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"select_labels\": [1, 2], \"merge_channels\": False},\n            p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])),\n            np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"select_labels\": 2, \"merge_channels\": False},\n            p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])),\n            np.array([[[1, 0, 1], [1, 1, 
0]]]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"select_labels\": [1, 2], \"merge_channels\": True},\n            p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]])),\n            np.array([[[1, 0, 1], [1, 1, 1]]]),\n        ]\n    )\n\n\nclass TestLabelToMask(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        result = LabelToMask(**arguments)(image)\n        assert_allclose(result, expected_data, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_label_to_maskd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import LabelToMaskd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"select_labels\": [2, 3], \"merge_channels\": False},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array([[[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"select_labels\": 2, \"merge_channels\": False},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array([[[0, 0, 0], [1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"select_labels\": [1, 2], \"merge_channels\": False},\n            {\"img\": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))},\n            np.array([[[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"select_labels\": 2, \"merge_channels\": False},\n            {\"img\": 
p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))},\n            np.array([[[1, 0, 1], [1, 1, 0]]]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"select_labels\": [1, 2], \"merge_channels\": True},\n            {\"img\": p(np.array([[[0, 0, 1], [0, 1, 0]], [[1, 0, 0], [0, 1, 1]], [[1, 0, 1], [1, 1, 0]]]))},\n            np.array([[[1, 0, 1], [1, 1, 1]]]),\n        ]\n    )\n\n\nclass TestLabelToMaskd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, input_data, expected_data):\n        result = LabelToMaskd(**arguments)(input_data)\n        r = result[\"img\"]\n        assert_allclose(r, expected_data, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_load_image.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom PIL import Image\n\nfrom monai.apps import download_and_extract\nfrom monai.data import NibabelReader, PydicomReader\nfrom monai.data.meta_obj import set_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import LoadImage\nfrom monai.utils import optional_import\nfrom tests.test_utils import SkipIfNoModule, assert_allclose, skip_if_downloading_fails, testing_data_config\n\nitk, has_itk = optional_import(\"itk\", allow_namespace_pkg=True)\nITKReader, _ = optional_import(\"monai.data\", name=\"ITKReader\", as_type=\"decorator\")\nitk_uc, _ = optional_import(\"itk\", name=\"UC\", allow_namespace_pkg=True)\n\n\nclass _MiniReader:\n    \"\"\"a test case customised reader\"\"\"\n\n    def __init__(self, is_compatible=False):\n        self.is_compatible = is_compatible\n\n    def verify_suffix(self, _name):\n        return self.is_compatible\n\n    def read(self, name):\n        return name\n\n    def get_data(self, _obj):\n        return np.zeros((1, 1, 1)), {\"name\": \"my test\"}\n\n\nTEST_CASE_1 = [{}, [\"test_image.nii.gz\"], (128, 128, 128)]\n\nTEST_CASE_2 = [{}, [\"test_image.nii.gz\"], (128, 128, 128)]\n\nTEST_CASE_3 
= [{}, [\"test_image.nii.gz\", \"test_image2.nii.gz\", \"test_image3.nii.gz\"], (3, 128, 128, 128)]\n\nTEST_CASE_3_1 = [  # .mgz format\n    {\"reader\": \"nibabelreader\"},\n    [\"test_image.mgz\", \"test_image2.mgz\", \"test_image3.mgz\"],\n    (3, 128, 128, 128),\n]\n\nTEST_CASE_4 = [{}, [\"test_image.nii.gz\", \"test_image2.nii.gz\", \"test_image3.nii.gz\"], (3, 128, 128, 128)]\n\nTEST_CASE_4_1 = [  # additional parameter\n    {\"mmap\": False},\n    [\"test_image.nii.gz\", \"test_image2.nii.gz\", \"test_image3.nii.gz\"],\n    (3, 128, 128, 128),\n]\n\nTEST_CASE_5 = [{\"reader\": NibabelReader(mmap=False)}, [\"test_image.nii.gz\"], (128, 128, 128)]\n\nTEST_CASE_GPU_1 = [{\"reader\": \"nibabelreader\", \"to_gpu\": True}, [\"test_image.nii.gz\"], (128, 128, 128)]\n\nTEST_CASE_GPU_2 = [{\"reader\": \"nibabelreader\", \"to_gpu\": True}, [\"test_image.nii\"], (128, 128, 128)]\n\nTEST_CASE_GPU_3 = [\n    {\"reader\": \"nibabelreader\", \"to_gpu\": True},\n    [\"test_image.nii\", \"test_image2.nii\", \"test_image3.nii\"],\n    (3, 128, 128, 128),\n]\n\nTEST_CASE_GPU_4 = [\n    {\"reader\": \"nibabelreader\", \"to_gpu\": True},\n    [\"test_image.nii.gz\", \"test_image2.nii.gz\", \"test_image3.nii.gz\"],\n    (3, 128, 128, 128),\n]\n\nTEST_CASE_6 = [{\"reader\": ITKReader() if has_itk else \"itkreader\"}, [\"test_image.nii.gz\"], (128, 128, 128)]\n\nTEST_CASE_7 = [{\"reader\": ITKReader() if has_itk else \"itkreader\"}, [\"test_image.nii.gz\"], (128, 128, 128)]\n\nTEST_CASE_8 = [\n    {\"reader\": ITKReader() if has_itk else \"itkreader\"},\n    [\"test_image.nii.gz\", \"test_image2.nii.gz\", \"test_image3.nii.gz\"],\n    (3, 128, 128, 128),\n]\n\nTEST_CASE_8_1 = [\n    {\"reader\": ITKReader(channel_dim=0) if has_itk else \"itkreader\"},\n    [\"test_image.nii.gz\", \"test_image2.nii.gz\", \"test_image3.nii.gz\"],\n    (384, 128, 128),\n]\n\nTEST_CASE_9 = [\n    {\"reader\": ITKReader() if has_itk else \"itkreader\"},\n    [\"test_image.nii.gz\", 
\"test_image2.nii.gz\", \"test_image3.nii.gz\"],\n    (3, 128, 128, 128),\n]\n\nTEST_CASE_10 = [\n    {\"reader\": ITKReader(pixel_type=itk_uc) if has_itk else \"itkreader\"},\n    \"tests/testing_data/CT_DICOM\",\n    (16, 16, 4),\n    (16, 16, 4),\n]\n\nTEST_CASE_11 = [{\"reader\": \"ITKReader\", \"pixel_type\": itk_uc}, \"tests/testing_data/CT_DICOM\", (16, 16, 4), (16, 16, 4)]\n\nTEST_CASE_12 = [\n    {\"reader\": \"ITKReader\", \"pixel_type\": itk_uc, \"reverse_indexing\": True},\n    \"tests/testing_data/CT_DICOM\",\n    (16, 16, 4),\n    (4, 16, 16),\n]\n\nTEST_CASE_13 = [{\"reader\": \"nibabelreader\", \"channel_dim\": 0}, \"test_image.nii.gz\", (3, 128, 128, 128)]\n\nTEST_CASE_14 = [\n    {\"reader\": \"nibabelreader\", \"channel_dim\": -1, \"ensure_channel_first\": True},\n    \"test_image.nii.gz\",\n    (128, 128, 128, 3),\n]\n\nTEST_CASE_15 = [{\"reader\": \"nibabelreader\", \"channel_dim\": 2}, \"test_image.nii.gz\", (128, 128, 3, 128)]\n\nTEST_CASE_16 = [{\"reader\": \"itkreader\", \"channel_dim\": 0}, \"test_image.nii.gz\", (3, 128, 128, 128)]\n\nTEST_CASE_17 = [{\"reader\": \"monai.data.ITKReader\", \"channel_dim\": -1}, \"test_image.nii.gz\", (128, 128, 128, 3)]\n\nTEST_CASE_18 = [\n    {\"reader\": \"ITKReader\", \"channel_dim\": 2, \"ensure_channel_first\": True},\n    \"test_image.nii.gz\",\n    (128, 128, 3, 128),\n]\n\n# test same dicom data with PydicomReader\nTEST_CASE_19 = [{\"reader\": PydicomReader()}, \"tests/testing_data/CT_DICOM\", (16, 16, 4), (16, 16, 4)]\n\nTEST_CASE_20 = [\n    {\"reader\": \"PydicomReader\", \"ensure_channel_first\": True, \"force\": True},\n    \"tests/testing_data/CT_DICOM\",\n    (16, 16, 4),\n    (1, 16, 16, 4),\n]\n\nTEST_CASE_21 = [\n    {\"reader\": \"PydicomReader\", \"affine_lps_to_ras\": True, \"defer_size\": \"2 MB\", \"force\": True},\n    \"tests/testing_data/CT_DICOM\",\n    (16, 16, 4),\n    (16, 16, 4),\n]\n\n# test reader consistency between PydicomReader and ITKReader on dicom data\nTEST_CASE_22 
= [\"tests/testing_data/CT_DICOM\"]\n\n# test pydicom gpu reader\nTEST_CASE_GPU_5 = [{\"reader\": \"PydicomReader\", \"to_gpu\": True}, \"tests/testing_data/CT_DICOM\", (16, 16, 4), (16, 16, 4)]\n\nTEST_CASE_GPU_6 = [\n    {\"reader\": \"PydicomReader\", \"ensure_channel_first\": True, \"force\": True, \"to_gpu\": True},\n    \"tests/testing_data/CT_DICOM\",\n    (16, 16, 4),\n    (1, 16, 16, 4),\n]\n\nTESTS_META = []\nfor track_meta in (False, True):\n    TESTS_META.append([{}, (128, 128, 128), track_meta])\n    TESTS_META.append([{\"reader\": \"ITKReader\", \"fallback_only\": False}, (128, 128, 128), track_meta])\n\n\n@unittest.skipUnless(has_itk, \"itk not installed\")\nclass TestLoadImage(unittest.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        super().setUpClass()\n        with skip_if_downloading_fails():\n            cls.tmpdir = tempfile.mkdtemp()\n            key = \"DICOM_single\"\n            url = testing_data_config(\"images\", key, \"url\")\n            hash_type = testing_data_config(\"images\", key, \"hash_type\")\n            hash_val = testing_data_config(\"images\", key, \"hash_val\")\n            download_and_extract(\n                url=url, output_dir=cls.tmpdir, hash_val=hash_val, hash_type=hash_type, file_type=\"zip\"\n            )\n            cls.data_dir = os.path.join(cls.tmpdir, \"CT_DICOM_SINGLE\")\n\n    @classmethod\n    def tearDownClass(cls):\n        shutil.rmtree(cls.tmpdir)\n        super().tearDownClass()\n\n    @parameterized.expand(\n        [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_3_1, TEST_CASE_4, TEST_CASE_4_1, TEST_CASE_5]\n    )\n    def test_nibabel_reader(self, input_param, filenames, expected_shape):\n        test_image = np.random.rand(128, 128, 128)\n        with tempfile.TemporaryDirectory() as tempdir:\n            for i, name in enumerate(filenames):\n                filenames[i] = os.path.join(tempdir, name)\n                nib.save(nib.Nifti1Image(test_image, np.eye(4)), 
filenames[i])\n            result = LoadImage(image_only=True, **input_param)(filenames)\n            ext = \"\".join(Path(name).suffixes)\n            self.assertEqual(result.meta[\"filename_or_obj\"], os.path.join(tempdir, \"test_image\" + ext))\n            self.assertEqual(result.meta[\"space\"], \"RAS\")\n            assert_allclose(result.affine, torch.eye(4))\n            self.assertTupleEqual(result.shape, expected_shape)\n\n    @SkipIfNoModule(\"nibabel\")\n    @SkipIfNoModule(\"cupy\")\n    @SkipIfNoModule(\"kvikio\")\n    @parameterized.expand([TEST_CASE_GPU_1, TEST_CASE_GPU_2, TEST_CASE_GPU_3, TEST_CASE_GPU_4])\n    def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):\n        if torch.__version__.endswith(\"nv24.8\"):\n            # related issue: https://github.com/Project-MONAI/MONAI/issues/8274\n            # for this version, use randint test case to avoid the issue\n            test_image = torch.randint(0, 256, (128, 128, 128), dtype=torch.uint8).numpy()\n        else:\n            test_image = np.random.rand(128, 128, 128)\n        with tempfile.TemporaryDirectory() as tempdir:\n            for i, name in enumerate(filenames):\n                filenames[i] = os.path.join(tempdir, name)\n                nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])\n            result = LoadImage(image_only=True, **input_param)(filenames)\n            ext = \"\".join(Path(name).suffixes)\n            self.assertEqual(result.meta[\"filename_or_obj\"], os.path.join(tempdir, \"test_image\" + ext))\n            self.assertEqual(result.meta[\"space\"], \"RAS\")\n            assert_allclose(result.affine, torch.eye(4))\n            self.assertTupleEqual(result.shape, expected_shape)\n\n            # verify gpu and cpu loaded data are the same\n            input_param_cpu = input_param.copy()\n            input_param_cpu[\"to_gpu\"] = False\n            result_cpu = LoadImage(image_only=True, **input_param_cpu)(filenames)\n         
   assert_allclose(result_cpu, result.cpu(), atol=1e-6)\n\n    @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9])\n    def test_itk_reader(self, input_param, filenames, expected_shape):\n        test_image = torch.randint(0, 256, (128, 128, 128), dtype=torch.uint8).numpy()\n        print(\"Test image value range:\", test_image.min(), test_image.max())\n        with tempfile.TemporaryDirectory() as tempdir:\n            for i, name in enumerate(filenames):\n                filenames[i] = os.path.join(tempdir, name)\n                nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])\n            result = LoadImage(image_only=True, **input_param)(filenames)\n            ext = \"\".join(Path(name).suffixes)\n            self.assertEqual(result.meta[\"filename_or_obj\"], os.path.join(tempdir, \"test_image\" + ext))\n            self.assertEqual(result.meta[\"space\"], \"RAS\")\n            assert_allclose(result.affine, torch.eye(4))\n            self.assertTupleEqual(result.shape, expected_shape)\n\n    @parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12, TEST_CASE_19, TEST_CASE_20, TEST_CASE_21])\n    def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, expected_np_shape):\n        result = LoadImage(image_only=True, **input_param)(filenames)\n        self.assertEqual(result.meta[\"filename_or_obj\"], f\"{Path(filenames)}\")\n        assert_allclose(\n            result.affine,\n            torch.tensor(\n                [\n                    [-0.488281, 0.0, 0.0, 125.0],\n                    [0.0, -0.488281, 0.0, 128.100006],\n                    [0.0, 0.0, 68.33333333, -99.480003],\n                    [0.0, 0.0, 0.0, 1.0],\n                ]\n            ),\n        )\n        self.assertTupleEqual(result.shape, expected_np_shape)\n\n    @SkipIfNoModule(\"pydicom\")\n    @SkipIfNoModule(\"cupy\")\n    @SkipIfNoModule(\"kvikio\")\n    @parameterized.expand([TEST_CASE_GPU_5, 
TEST_CASE_GPU_6])\n    def test_pydicom_gpu_reader(self, input_param, filenames, expected_shape, expected_np_shape):\n        result = LoadImage(image_only=True, **input_param)(filenames)\n        self.assertEqual(result.meta[\"filename_or_obj\"], f\"{Path(filenames)}\")\n        assert_allclose(\n            result.affine,\n            torch.tensor(\n                [\n                    [-0.488281, 0.0, 0.0, 125.0],\n                    [0.0, -0.488281, 0.0, 128.100006],\n                    [0.0, 0.0, 68.33333333, -99.480003],\n                    [0.0, 0.0, 0.0, 1.0],\n                ]\n            ),\n        )\n        self.assertTupleEqual(result.shape, expected_np_shape)\n\n    def test_no_files(self):\n        with self.assertRaisesRegex(RuntimeError, \"list index out of range\"):  # fname_regex excludes everything\n            LoadImage(image_only=True, reader=\"PydicomReader\", fname_regex=r\"^(?!.*).*\")(\"tests/testing_data/CT_DICOM\")\n        LoadImage(image_only=True, reader=\"PydicomReader\", fname_regex=None)(\"tests/testing_data/CT_DICOM\")\n\n    def test_itk_dicom_series_reader_single(self):\n        result = LoadImage(image_only=True, reader=\"ITKReader\")(self.data_dir)\n        self.assertEqual(result.meta[\"filename_or_obj\"], f\"{Path(self.data_dir)}\")\n        assert_allclose(\n            result.affine,\n            torch.tensor(\n                [\n                    [-0.488281, 0.0, 0.0, 125.0],\n                    [0.0, -0.488281, 0.0, 128.100006],\n                    [0.0, 0.0, 1.0, -99.480003],\n                    [0.0, 0.0, 0.0, 1.0],\n                ]\n            ),\n        )\n        self.assertTupleEqual(result.shape, (16, 16, 1))\n\n    def test_itk_reader_multichannel(self):\n        test_image = np.random.randint(0, 256, size=(256, 224, 3)).astype(\"uint8\")\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_image.png\")\n            itk_np_view = 
itk.image_view_from_array(test_image, is_vector=True)\n            itk.imwrite(itk_np_view, filename)\n            for flag in (False, True):\n                result = LoadImage(image_only=True, reader=ITKReader(reverse_indexing=flag))(Path(filename))\n                test_image = test_image.transpose(1, 0, 2)\n                np.testing.assert_allclose(result[:, :, 0], test_image[:, :, 0])\n                np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1])\n                np.testing.assert_allclose(result[:, :, 2], test_image[:, :, 2])\n\n    @parameterized.expand([TEST_CASE_22])\n    def test_dicom_reader_consistency(self, filenames):\n        itk_param = {\"reader\": \"ITKReader\"}\n        pydicom_param = {\"reader\": \"PydicomReader\"}\n        for affine_flag in [True, False]:\n            itk_param[\"affine_lps_to_ras\"] = affine_flag\n            pydicom_param[\"affine_lps_to_ras\"] = affine_flag\n            itk_result = LoadImage(image_only=True, **itk_param)(filenames)\n            pydicom_result = LoadImage(image_only=True, **pydicom_param)(filenames)\n            np.testing.assert_allclose(pydicom_result, itk_result)\n            np.testing.assert_allclose(pydicom_result.affine, itk_result.affine)\n\n    @SkipIfNoModule(\"pydicom\")\n    @SkipIfNoModule(\"cupy\")\n    @SkipIfNoModule(\"kvikio\")\n    @parameterized.expand([TEST_CASE_22])\n    def test_pydicom_reader_gpu_cpu_consistency(self, filenames):\n        gpu_param = {\"reader\": \"PydicomReader\", \"to_gpu\": True}\n        cpu_param = {\"reader\": \"PydicomReader\", \"to_gpu\": False}\n        for affine_flag in [True, False]:\n            gpu_param[\"affine_lps_to_ras\"] = affine_flag\n            cpu_param[\"affine_lps_to_ras\"] = affine_flag\n            gpu_result = LoadImage(image_only=True, **gpu_param)(filenames)\n            cpu_result = LoadImage(image_only=True, **cpu_param)(filenames)\n            np.testing.assert_allclose(gpu_result.cpu(), cpu_result)\n            
np.testing.assert_allclose(gpu_result.affine.cpu(), cpu_result.affine)\n\n    def test_dicom_reader_consistency_single(self):\n        itk_param = {\"reader\": \"ITKReader\"}\n        pydicom_param = {\"reader\": \"PydicomReader\"}\n        for affine_flag in [True, False]:\n            itk_param[\"affine_lps_to_ras\"] = affine_flag\n            pydicom_param[\"affine_lps_to_ras\"] = affine_flag\n            itk_result = LoadImage(image_only=True, **itk_param)(self.data_dir)\n            pydicom_result = LoadImage(image_only=True, **pydicom_param)(self.data_dir)\n            np.testing.assert_allclose(pydicom_result, itk_result.squeeze())\n            np.testing.assert_allclose(pydicom_result.affine, itk_result.affine)\n\n    def test_load_nifti_multichannel(self):\n        test_image = np.random.randint(0, 256, size=(31, 64, 16, 2)).astype(np.float32)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_image.nii.gz\")\n            itk_np_view = itk.image_view_from_array(test_image, is_vector=True)\n            itk.imwrite(itk_np_view, filename)\n\n            itk_img = LoadImage(image_only=True, reader=ITKReader())(Path(filename))\n            self.assertTupleEqual(tuple(itk_img.shape), (16, 64, 31, 2))\n\n            nib_image = LoadImage(image_only=True, reader=NibabelReader(squeeze_non_spatial_dims=True))(Path(filename))\n            self.assertTupleEqual(tuple(nib_image.shape), (16, 64, 31, 2))\n\n            np.testing.assert_allclose(itk_img, nib_image, atol=1e-3, rtol=1e-3)\n\n    def test_load_png(self):\n        spatial_size = (256, 224)\n        test_image = np.random.randint(0, 256, size=spatial_size)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_image.png\")\n            Image.fromarray(test_image.astype(\"uint8\")).save(filename)\n            result = LoadImage(image_only=True)(filename)\n            
self.assertTupleEqual(result.shape, spatial_size[::-1])\n            np.testing.assert_allclose(result.T, test_image)\n\n    def test_register(self):\n        spatial_size = (32, 64, 128)\n        test_image = np.random.rand(*spatial_size)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_image.nii.gz\")\n            itk_np_view = itk.image_view_from_array(test_image)\n            itk.imwrite(itk_np_view, filename)\n\n            loader = LoadImage(image_only=True)\n            loader.register(ITKReader())\n            result = loader(filename)\n            self.assertTupleEqual(result.shape, spatial_size[::-1])\n\n    def test_kwargs(self):\n        spatial_size = (32, 64, 128)\n        test_image = np.random.rand(*spatial_size)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_image.nii.gz\")\n            itk_np_view = itk.image_view_from_array(test_image)\n            itk.imwrite(itk_np_view, filename)\n\n            loader = LoadImage(image_only=True)\n            reader = ITKReader(fallback_only=False)\n            loader.register(reader)\n            result = loader(filename)\n\n            reader = ITKReader()\n            img = reader.read(filename, fallback_only=False)\n            result_raw = reader.get_data(img)\n            result_raw = MetaTensor.ensure_torch_and_prune_meta(*result_raw)\n            self.assertTupleEqual(result.shape, result_raw.shape)\n\n    def test_my_reader(self):\n        \"\"\"test customised readers\"\"\"\n        out = LoadImage(image_only=True, reader=_MiniReader, is_compatible=True)(\"test\")\n        self.assertEqual(out.meta[\"name\"], \"my test\")\n        out = LoadImage(image_only=True, reader=_MiniReader, is_compatible=False)(\"test\")\n        self.assertEqual(out.meta[\"name\"], \"my test\")\n        for item in (_MiniReader, _MiniReader(is_compatible=False)):\n            out = 
LoadImage(image_only=True, reader=item)(\"test\")\n            self.assertEqual(out.meta[\"name\"], \"my test\")\n        out = LoadImage(image_only=True)(\"test\", reader=_MiniReader(is_compatible=False))\n        self.assertEqual(out.meta[\"name\"], \"my test\")\n\n    def test_itk_meta(self):\n        \"\"\"test metadata from a directory\"\"\"\n        out = LoadImage(image_only=True, reader=\"ITKReader\", pixel_type=itk_uc, series_meta=True)(\n            \"tests/testing_data/CT_DICOM\"\n        )\n        idx = \"0008|103e\"\n        label = itk.GDCMImageIO.GetLabelFromTag(idx, \"\")[1]\n        val = out.meta[idx]\n        expected = \"Series Description=Routine Brain \"\n        self.assertEqual(f\"{label}={val}\", expected)\n\n    @parameterized.expand([TEST_CASE_13, TEST_CASE_14, TEST_CASE_15, TEST_CASE_16, TEST_CASE_17, TEST_CASE_18])\n    def test_channel_dim(self, input_param, filename, expected_shape):\n        test_image = np.random.rand(*expected_shape)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, filename)\n            nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename)\n            result = LoadImage(image_only=True, **input_param)(filename)  # with itk, meta has 'qto_xyz': itkMatrixF44\n\n        self.assertTupleEqual(\n            result.shape, (3, 128, 128, 128) if input_param.get(\"ensure_channel_first\", False) else expected_shape\n        )\n        self.assertEqual(result.meta[\"original_channel_dim\"], input_param[\"channel_dim\"])\n\n\n@unittest.skipUnless(has_itk, \"itk not installed\")\nclass TestLoadImageMeta(unittest.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        super().setUpClass()\n        cls.tmpdir = tempfile.mkdtemp()\n        test_image = nib.Nifti1Image(np.random.rand(128, 128, 128), np.eye(4))\n        nib.save(test_image, os.path.join(cls.tmpdir, \"im.nii.gz\"))\n        cls.test_data = os.path.join(cls.tmpdir, \"im.nii.gz\")\n\n    
@classmethod\n    def tearDownClass(cls):\n        shutil.rmtree(cls.tmpdir)\n        super().tearDownClass()\n\n    @parameterized.expand(TESTS_META)\n    def test_correct(self, input_param, expected_shape, track_meta):\n        set_track_meta(track_meta)\n        r = LoadImage(image_only=True, prune_meta_pattern=\"glmax\", prune_meta_sep=\"%\", **input_param)(self.test_data)\n        self.assertTupleEqual(r.shape, expected_shape)\n        if track_meta:\n            self.assertIsInstance(r, MetaTensor)\n            self.assertTrue(hasattr(r, \"affine\"))\n            self.assertIsInstance(r.affine, torch.Tensor)\n            self.assertTrue(\"glmax\" not in r.meta)\n        else:\n            self.assertIsInstance(r, torch.Tensor)\n            self.assertNotIsInstance(r, MetaTensor)\n            self.assertFalse(hasattr(r, \"affine\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_load_imaged.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import ITKReader\nfrom monai.data.meta_obj import set_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import Compose, EnsureChannelFirstD, FromMetaTensord, LoadImaged, SaveImageD\nfrom monai.transforms.meta_utility.dictionary import ToMetaTensord\nfrom monai.utils import optional_import\nfrom tests.test_utils import assert_allclose\n\nitk, has_itk = optional_import(\"itk\", allow_namespace_pkg=True)\n\nKEYS = [\"image\", \"label\", \"extra\"]\n\nTEST_CASE_1 = [{\"keys\": KEYS}, (128, 128, 128)]\n\nTEST_CASE_2 = [{\"keys\": KEYS, \"reader\": \"ITKReader\", \"fallback_only\": False}, (128, 128, 128)]\n\nTESTS_META = []\nfor track_meta in (False, True):\n    TESTS_META.append([{\"keys\": KEYS}, (128, 128, 128), track_meta])\n    TESTS_META.append([{\"keys\": KEYS, \"reader\": \"ITKReader\", \"fallback_only\": False}, (128, 128, 128), track_meta])\n\n\n@unittest.skipUnless(has_itk, \"itk not installed\")\nclass TestLoadImaged(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_shape(self, input_param, expected_shape):\n        test_image = 
nib.Nifti1Image(np.random.rand(128, 128, 128), np.eye(4))\n        test_data = {}\n        with tempfile.TemporaryDirectory() as tempdir:\n            for key in KEYS:\n                nib.save(test_image, os.path.join(tempdir, key + \".nii.gz\"))\n                test_data.update({key: os.path.join(tempdir, key + \".nii.gz\")})\n            result = LoadImaged(image_only=True, **input_param)(test_data)\n\n        for key in KEYS:\n            self.assertTupleEqual(result[key].shape, expected_shape)\n\n    def test_register(self):\n        spatial_size = (32, 64, 128)\n        test_image = np.random.rand(*spatial_size)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_image.nii.gz\")\n            itk_np_view = itk.image_view_from_array(test_image)\n            itk.imwrite(itk_np_view, filename)\n\n            loader = LoadImaged(keys=\"img\", image_only=True)\n            loader.register(ITKReader())\n            result = loader({\"img\": Path(filename)})\n            self.assertTupleEqual(result[\"img\"].shape, spatial_size[::-1])\n\n    def test_channel_dim(self):\n        spatial_size = (32, 64, 3, 128)\n        test_image = np.random.rand(*spatial_size)\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_image.nii.gz\")\n            nib.save(nib.Nifti1Image(test_image, affine=np.eye(4)), filename)\n\n            loader = LoadImaged(keys=\"img\", image_only=True)\n            loader.register(ITKReader(channel_dim=2))\n            t = Compose([EnsureChannelFirstD(\"img\"), FromMetaTensord(\"img\")])\n            result = t(loader({\"img\": filename}))\n            self.assertTupleEqual(result[\"img\"].shape, (3, 32, 64, 128))\n\n    def test_no_file(self):\n        with self.assertRaises(RuntimeError):\n            LoadImaged(keys=\"img\", image_only=True)({\"img\": \"unknown\"})\n        with self.assertRaises(RuntimeError):\n            
LoadImaged(keys=\"img\", reader=\"nibabelreader\", image_only=True)({\"img\": \"unknown\"})\n\n\n@unittest.skipUnless(has_itk, \"itk not installed\")\nclass TestConsistency(unittest.TestCase):\n    def _cmp(self, filename, ch_shape, reader_1, reader_2, outname, ext):\n        data_dict = {\"img\": filename}\n        keys = data_dict.keys()\n        xforms = Compose([LoadImaged(keys, reader=reader_1, ensure_channel_first=True, image_only=True)])\n        img_dict = xforms(data_dict)  # load dicom with itk\n        self.assertTupleEqual(img_dict[\"img\"].shape, ch_shape)\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            save_xform = SaveImageD(keys, output_dir=tempdir, squeeze_end_dims=False, output_ext=ext)\n            save_xform(img_dict)  # save to nifti\n\n            new_xforms = Compose(\n                [\n                    LoadImaged(keys, reader=reader_2, image_only=True),\n                    EnsureChannelFirstD(keys),\n                    FromMetaTensord(keys),\n                    ToMetaTensord(keys),\n                ]\n            )\n            out = new_xforms({\"img\": os.path.join(tempdir, outname)})  # load nifti with itk\n            self.assertTupleEqual(out[\"img\"].shape, ch_shape)\n\n            def is_identity(x):\n                return (x == torch.eye(x.shape[0])).all()\n\n            if not is_identity(img_dict[\"img\"].affine) and not is_identity(out[\"img\"].affine):\n                assert_allclose(img_dict[\"img\"].affine, out[\"img\"].affine, rtol=1e-3)\n            assert_allclose(out[\"img\"], img_dict[\"img\"], rtol=1e-3)\n\n    def test_dicom(self):\n        img_dir = \"tests/testing_data/CT_DICOM\"\n        self._cmp(img_dir, (1, 16, 16, 4), \"itkreader\", \"itkreader\", \"CT_DICOM/CT_DICOM_trans.nii.gz\", \".nii.gz\")\n        output_name = \"CT_DICOM/CT_DICOM_trans.nii.gz\"\n        self._cmp(img_dir, (1, 16, 16, 4), \"nibabelreader\", \"itkreader\", output_name, \".nii.gz\")\n        self._cmp(img_dir, 
(1, 16, 16, 4), \"itkreader\", \"nibabelreader\", output_name, \".nii.gz\")\n\n    def test_multi_dicom(self):\n        \"\"\"multichannel dicom reading, saving to nifti, then load with itk or nibabel\"\"\"\n\n        img_dir = [\"tests/testing_data/CT_DICOM\", \"tests/testing_data/CT_DICOM\"]\n        self._cmp(img_dir, (2, 16, 16, 4), \"itkreader\", \"itkreader\", \"CT_DICOM/CT_DICOM_trans.nii.gz\", \".nii.gz\")\n        output_name = \"CT_DICOM/CT_DICOM_trans.nii.gz\"\n        self._cmp(img_dir, (2, 16, 16, 4), \"nibabelreader\", \"itkreader\", output_name, \".nii.gz\")\n        self._cmp(img_dir, (2, 16, 16, 4), \"itkreader\", \"nibabelreader\", output_name, \".nii.gz\")\n\n    def test_png(self):\n        \"\"\"png reading with itk, saving to nifti, then load with itk or nibabel or PIL\"\"\"\n\n        test_image = np.random.randint(0, 256, size=(256, 224, 3)).astype(\"uint8\")\n        with tempfile.TemporaryDirectory() as tempdir:\n            filename = os.path.join(tempdir, \"test_image.png\")\n            itk_np_view = itk.image_view_from_array(test_image, is_vector=True)\n            itk.imwrite(itk_np_view, filename)\n            output_name = \"test_image/test_image_trans.png\"\n            self._cmp(filename, (3, 224, 256), \"itkreader\", \"itkreader\", output_name, \".png\")\n            self._cmp(filename, (3, 224, 256), \"itkreader\", \"PILReader\", output_name, \".png\")\n            self._cmp(filename, (3, 224, 256), \"itkreader\", \"nibabelreader\", output_name, \".png\")\n\n\n@unittest.skipUnless(has_itk, \"itk not installed\")\nclass TestLoadImagedMeta(unittest.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        super().setUpClass()\n        cls.tmpdir = tempfile.mkdtemp()\n        test_image = nib.Nifti1Image(np.random.rand(128, 128, 128), np.eye(4))\n        cls.test_data = {}\n        for key in KEYS:\n            nib.save(test_image, os.path.join(cls.tmpdir, key + \".nii.gz\"))\n            cls.test_data.update({key: 
os.path.join(cls.tmpdir, key + \".nii.gz\")})\n\n    @classmethod\n    def tearDownClass(cls):\n        shutil.rmtree(cls.tmpdir)\n        super().tearDownClass()\n\n    @parameterized.expand(TESTS_META)\n    def test_correct(self, input_p, expected_shape, track_meta):\n        set_track_meta(track_meta)\n        result = LoadImaged(image_only=True, prune_meta_pattern=\".*_code$\", prune_meta_sep=\" \", **input_p)(\n            self.test_data\n        )\n\n        # shouldn't have any extra meta data keys\n        self.assertEqual(len(result), len(KEYS))\n        for key in KEYS:\n            r = result[key]\n            self.assertTupleEqual(r.shape, expected_shape)\n            if track_meta:\n                self.assertIsInstance(r, MetaTensor)\n                self.assertTrue(hasattr(r, \"affine\"))\n                self.assertIsInstance(r.affine, torch.Tensor)\n                self.assertEqual(r.meta[\"space\"], \"RAS\")\n                self.assertNotIn(\"qform_code\", r.meta)\n            else:\n                self.assertIsInstance(r, torch.Tensor)\n                self.assertNotIsInstance(r, MetaTensor)\n                self.assertFalse(hasattr(r, \"affine\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_load_spacing_orientation.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport time\nimport unittest\nfrom pathlib import Path\n\nimport nibabel\nimport numpy as np\nimport torch\nfrom nibabel.processing import resample_to_output\nfrom parameterized import parameterized\n\nfrom monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, Orientationd, Spacingd\n\nTESTS_PATH = Path(__file__).parents[1]\nFILES = tuple(\n    os.path.join(TESTS_PATH, \"testing_data\", filename) for filename in (\"anatomical.nii\", \"reoriented_anat_moved.nii\")\n)\n\n\nclass TestLoadSpacingOrientation(unittest.TestCase):\n    @staticmethod\n    def load_image(filename):\n        data = {\"image\": filename}\n        t = Compose([LoadImaged(keys=\"image\"), EnsureChannelFirstd(keys=\"image\", channel_dim=\"no_channel\")])\n        return t(data)\n\n    @parameterized.expand(FILES)\n    def test_load_spacingd(self, filename):\n        data_dict = self.load_image(filename)\n        t = time.time()\n        res_dict = Spacingd(keys=\"image\", pixdim=(1, 0.2, 1), diagonal=True, padding_mode=\"zeros\")(data_dict)\n        t1 = time.time()\n        print(f\"time monai: {t1 - t}\")\n        anat = nibabel.Nifti1Image(np.asarray(data_dict[\"image\"][0]), data_dict[\"image\"].meta[\"original_affine\"])\n        ref = resample_to_output(anat, (1, 0.2, 1), order=1)\n        t2 = time.time()\n        print(f\"time scipy: 
{t2 - t1}\")\n        self.assertGreaterEqual(t2, t1)\n        np.testing.assert_allclose(res_dict[\"image\"].affine, ref.affine)\n        np.testing.assert_allclose(res_dict[\"image\"].shape[1:], ref.shape)\n        np.testing.assert_allclose(ref.get_fdata(), res_dict[\"image\"][0], atol=0.05)\n\n    @parameterized.expand(FILES)\n    def test_load_spacingd_rotate(self, filename):\n        data_dict = self.load_image(filename)\n        affine = data_dict[\"image\"].affine\n        data_dict[\"image\"].meta[\"original_affine\"] = data_dict[\"image\"].affine = (\n            torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]], dtype=torch.float64) @ affine\n        )\n        t = time.time()\n        res_dict = Spacingd(keys=\"image\", pixdim=(1, 2, 3), diagonal=True, padding_mode=\"zeros\")(data_dict)\n        t1 = time.time()\n        print(f\"time monai: {t1 - t}\")\n        anat = nibabel.Nifti1Image(np.asarray(data_dict[\"image\"][0]), data_dict[\"image\"].meta[\"original_affine\"])\n        ref = resample_to_output(anat, (1, 2, 3), order=1)\n        t2 = time.time()\n        print(f\"time scipy: {t2 - t1}\")\n        self.assertGreaterEqual(t2, t1)\n        np.testing.assert_allclose(res_dict[\"image\"].affine, ref.affine)\n        if \"anatomical\" not in filename:\n            np.testing.assert_allclose(res_dict[\"image\"].shape[1:], ref.shape)\n            np.testing.assert_allclose(ref.get_fdata(), res_dict[\"image\"][0], atol=0.05)\n        else:\n            # different from the ref implementation (shape computed by round\n            # instead of ceil)\n            np.testing.assert_allclose(ref.get_fdata()[..., :-1], res_dict[\"image\"][0], atol=0.05)\n\n    def test_load_spacingd_non_diag(self):\n        data_dict = self.load_image(FILES[1])\n        affine = data_dict[\"image\"].affine\n        data_dict[\"image\"].meta[\"original_affine\"] = data_dict[\"image\"].affine = (\n            torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 
0, 0, 0], [0, 0, 0, 1]], dtype=torch.float64) @ affine\n        )\n        res_dict = Spacingd(keys=\"image\", pixdim=(1, 2, 3), diagonal=False, padding_mode=\"zeros\")(data_dict)\n        np.testing.assert_allclose(\n            res_dict[\"image\"].affine,\n            np.array(\n                [\n                    [0.0, 0.0, 3.0, -27.599409],\n                    [0.0, 2.0, 0.0, -47.977585],\n                    [-1.0, 0.0, 0.0, 35.297897],\n                    [0.0, 0.0, 0.0, 1.0],\n                ]\n            ),\n        )\n\n    def test_load_spacingd_rotate_non_diag(self):\n        data_dict = self.load_image(FILES[0])\n        res_dict = Spacingd(keys=\"image\", pixdim=(1, 2, 3), diagonal=False, padding_mode=\"border\")(data_dict)\n        np.testing.assert_allclose(\n            res_dict[\"image\"].affine,\n            np.array([[-1.0, 0.0, 0.0, 32.0], [0.0, 2.0, 0.0, -40.0], [0.0, 0.0, 3.0, -16.0], [0.0, 0.0, 0.0, 1.0]]),\n        )\n\n    def test_load_spacingd_rotate_non_diag_ornt(self):\n        data_dict = self.load_image(FILES[0])\n        t = Compose(\n            [\n                Spacingd(keys=\"image\", pixdim=(1, 2, 3), diagonal=False, padding_mode=\"border\"),\n                Orientationd(keys=\"image\", axcodes=\"LPI\"),\n            ]\n        )\n        res_dict = t(data_dict)\n        np.testing.assert_allclose(\n            res_dict[\"image\"].affine,\n            np.array([[-1.0, 0.0, 0.0, 32.0], [0.0, -2.0, 0.0, 40.0], [0.0, 0.0, -3.0, 32.0], [0.0, 0.0, 0.0, 1.0]]),\n        )\n\n    def test_load_spacingd_non_diag_ornt(self):\n        data_dict = self.load_image(FILES[1])\n        affine = data_dict[\"image\"].affine\n        data_dict[\"image\"].meta[\"original_affine\"] = data_dict[\"image\"].affine = (\n            torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]], dtype=torch.float64) @ affine\n        )\n        t = Compose(\n            [\n                Spacingd(keys=\"image\", pixdim=(1, 2, 3), 
diagonal=False, padding_mode=\"border\"),\n                Orientationd(keys=\"image\", axcodes=\"LPI\"),\n            ]\n        )\n        res_dict = t(data_dict)\n        np.testing.assert_allclose(\n            res_dict[\"image\"].affine,\n            np.array(\n                [\n                    [-3.0, 0.0, 0.0, 56.4005909],\n                    [0.0, -2.0, 0.0, 52.02241516],\n                    [0.0, 0.0, -1.0, 35.29789734],\n                    [0.0, 0.0, 0.0, 1.0],\n                ]\n            ),\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_map_and_generate_sampling_centers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import map_and_generate_sampling_centers\nfrom monai.utils.misc import set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTEST_CASE_1 = [\n    # test Argmax data\n    {\n        \"label\": (np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])),\n        \"spatial_size\": [2, 2, 2],\n        \"num_samples\": 2,\n        \"label_spatial_shape\": [3, 3, 3],\n        \"num_classes\": 3,\n        \"image\": None,\n        \"ratios\": [0, 1, 2],\n        \"image_threshold\": 0.0,\n    },\n    tuple,\n    2,\n    3,\n]\n\nTEST_CASE_2 = [\n    {\n        \"label\": (\n            np.array(\n                [\n                    [[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n                    [[0, 1, 0], [0, 0, 1], [1, 0, 0]],\n                    [[0, 0, 1], [1, 0, 0], [0, 1, 0]],\n                ]\n            )\n        ),\n        \"spatial_size\": [2, 2, 2],\n        \"num_samples\": 1,\n        \"ratios\": None,\n        \"label_spatial_shape\": [3, 3, 3],\n        \"image\": None,\n        \"image_threshold\": 0.0,\n    },\n    tuple,\n    1,\n    3,\n]\n\n\nclass TestMapAndGenerateSamplingCenters(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def 
test_map_and_generate_sampling_centers(self, input_data, expected_type, expected_count, expected_shape):\n        results = []\n        for p in TEST_NDARRAYS + (None,):\n            input_data = deepcopy(input_data)\n            if p is not None:\n                input_data[\"label\"] = p(input_data[\"label\"])\n            set_determinism(0)\n            result = map_and_generate_sampling_centers(**input_data)\n            self.assertIsInstance(result, expected_type)\n            self.assertEqual(len(result), expected_count)\n            self.assertEqual(len(result[0]), expected_shape)\n            # check for consistency between numpy, torch and torch.cuda\n            results.append(result)\n            if len(results) > 1:\n                for x, y in zip(result[0], result[-1]):\n                    assert_allclose(x, y, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_map_binary_to_indices.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import map_binary_to_indices\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"label\": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), \"image\": None, \"image_threshold\": 0.0},\n            np.array([1, 2, 3, 5, 6, 7]),\n            np.array([0, 4, 8]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\n                \"label\": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])),\n                \"image\": p(np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])),\n                \"image_threshold\": 0.0,\n            },\n            np.array([1, 2, 3, 5, 6, 7]),\n            np.array([0, 8]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\n                \"label\": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])),\n                \"image\": p(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])),\n                \"image_threshold\": 1.0,\n            },\n            np.array([1, 2, 3, 5, 6, 7]),\n            np.array([0, 8]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\n                \"label\": p(np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]])),\n                \"image\": p(np.array([[[3, 3, 3], 
[3, 1, 3], [3, 3, 3]]])),\n                \"image_threshold\": 1.0,\n            },\n            np.array([1, 2, 3, 5, 6, 7]),\n            np.array([0, 8]),\n        ]\n    )\n\n\nclass TestMapBinaryToIndices(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_type_shape(self, input_data, expected_fg, expected_bg):\n        fg_indices, bg_indices = map_binary_to_indices(**input_data)\n        assert_allclose(fg_indices, expected_fg, type_test=False)\n        assert_allclose(bg_indices, expected_bg, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_map_classes_to_indices.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import map_classes_to_indices\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            # test Argmax data\n            {\n                \"label\": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])),\n                \"num_classes\": 3,\n                \"image\": None,\n                \"image_threshold\": 0.0,\n            },\n            [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])],\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\n                \"label\": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])),\n                \"num_classes\": 3,\n                \"image\": p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])),\n                \"image_threshold\": 60,\n            },\n            [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])],\n        ]\n    )\n\n    TESTS.append(\n        [\n            # test One-Hot data\n            {\n                \"label\": p(\n                    np.array(\n                        [\n                            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n                            [[0, 1, 0], [0, 0, 1], [1, 0, 0]],\n                            [[0, 0, 1], [1, 
0, 0], [0, 1, 0]],\n                        ]\n                    )\n                ),\n                \"image\": None,\n                \"image_threshold\": 0.0,\n            },\n            [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])],\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\n                \"label\": p(\n                    np.array(\n                        [\n                            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n                            [[0, 1, 0], [0, 0, 1], [1, 0, 0]],\n                            [[0, 0, 1], [1, 0, 0], [0, 1, 0]],\n                        ]\n                    )\n                ),\n                \"num_classes\": None,\n                \"image\": p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])),\n                \"image_threshold\": 60,\n            },\n            [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])],\n        ]\n    )\n\n    TESTS.append(\n        [\n            # test empty class\n            {\n                \"label\": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])),\n                \"num_classes\": 5,\n                \"image\": None,\n                \"image_threshold\": 0.0,\n            },\n            [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])],\n        ]\n    )\n\n    TESTS.append(\n        [\n            # test empty class\n            {\n                \"label\": p(\n                    np.array(\n                        [\n                            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n                            [[0, 1, 0], [0, 0, 1], [1, 0, 0]],\n                            [[0, 0, 1], [1, 0, 0], [0, 1, 0]],\n                            [[0, 0, 0], [0, 0, 0], [0, 0, 0]],\n                            [[0, 0, 0], [0, 0, 0], [0, 0, 0]],\n                        ]\n                    )\n                ),\n                \"image\": None,\n                \"image_threshold\": 
0.0,\n                \"max_samples_per_class\": 2,\n            },\n            [np.array([0, 8]), np.array([1, 6]), np.array([2, 7]), np.array([]), np.array([])],\n        ]\n    )\n\n\nclass TestMapClassesToIndices(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, input_data, expected_indices):\n        indices = map_classes_to_indices(**input_data)\n        for i, e in zip(indices, expected_indices):\n            assert_allclose(i, e, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_map_label_value.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import MapLabelValue\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.extend(\n        [\n            [{\"orig_labels\": [3, 2, 1], \"target_labels\": [0, 1, 2]}, p([[3, 1], [1, 2]]), p([[0.0, 2.0], [2.0, 1.0]])],\n            [\n                {\"orig_labels\": [3, 5, 8], \"target_labels\": [0, 1, 2]},\n                p([[[3], [5], [5], [8]]]),\n                p([[[0.0], [1.0], [1.0], [2.0]]]),\n            ],\n            [{\"orig_labels\": [1, 2, 3], \"target_labels\": [0, 1, 2]}, p([3, 1, 1, 2]), p([2.0, 0.0, 0.0, 1.0])],\n            [{\"orig_labels\": [1, 2, 3], \"target_labels\": [0.5, 1.5, 2.5]}, p([3, 1, 1, 2]), p([2.5, 0.5, 0.5, 1.5])],\n        ]\n    )\n    # note: PyTorch 1.5.1 doesn't support rich dtypes\n    TESTS.append(\n        [\n            {\"orig_labels\": [1.5, 2.5, 3.5], \"target_labels\": [0, 1, 2], \"dtype\": np.int8},\n            p([3.5, 1.5, 1.5, 2.5]),\n            p([2, 0, 0, 1]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"orig_labels\": [1.5, 2.5, 3.5], \"target_labels\": [0, 1, 2], \"dtype\": torch.int8},\n            p([3.5, 1.5, 1.5, 2.5]),\n            p([2, 0, 0, 1]),\n        ]\n    
)\nTESTS.extend(\n    [\n        [\n            {\"orig_labels\": [\"label3\", \"label2\", \"label1\"], \"target_labels\": [0, 1, 2]},\n            np.array([[\"label3\", \"label1\"], [\"label1\", \"label2\"]]),\n            np.array([[0, 2], [2, 1]]),\n        ],\n        [\n            {\"orig_labels\": [3.5, 2.5, 1.5], \"target_labels\": [\"label0\", \"label1\", \"label2\"], \"dtype\": \"str\"},\n            np.array([[3.5, 1.5], [1.5, 2.5]]),\n            np.array([[\"label0\", \"label2\"], [\"label2\", \"label1\"]]),\n        ],\n        [\n            {\n                \"orig_labels\": [\"label3\", \"label2\", \"label1\"],\n                \"target_labels\": [\"label1\", \"label2\", \"label3\"],\n                \"dtype\": \"str\",\n            },\n            np.array([[\"label3\", \"label1\"], [\"label1\", \"label2\"]]),\n            np.array([[\"label1\", \"label3\"], [\"label3\", \"label2\"]]),\n        ],\n    ]\n)\n\n\nclass TestMapLabelValue(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, input_param, input_data, expected_value):\n        result = MapLabelValue(**input_param)(input_data)\n        if isinstance(expected_value, torch.Tensor):\n            assert_allclose(result, expected_value)\n        else:\n            np.testing.assert_equal(result, expected_value)\n        self.assertTupleEqual(result.shape, expected_value.shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_map_label_valued.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import MapLabelValued\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_1 = [\n    {\"keys\": \"seg\", \"orig_labels\": [3, 2, 1], \"target_labels\": [0, 1, 2]},\n    {\"seg\": np.array([[3, 1], [1, 2]])},\n    np.array([[0, 2], [2, 1]]),\n]\n\nTEST_CASE_2 = [\n    {\"keys\": \"seg\", \"orig_labels\": [3, 5, 8], \"target_labels\": [0, 1, 2]},\n    {\"seg\": np.array([[[3], [5], [5], [8]]])},\n    np.array([[[0], [1], [1], [2]]]),\n]\n\nTEST_CASE_3 = [\n    {\"keys\": \"seg\", \"orig_labels\": [1, 2, 3], \"target_labels\": [0, 1, 2]},\n    {\"seg\": np.array([3, 1, 1, 2])},\n    np.array([2, 0, 0, 1]),\n]\n\nTEST_CASE_4 = [\n    {\"keys\": \"seg\", \"orig_labels\": [1, 2, 3], \"target_labels\": [0.5, 1.5, 2.5]},\n    {\"seg\": np.array([3, 1, 1, 2])},\n    np.array([2.5, 0.5, 0.5, 1.5]),\n]\n\nTEST_CASE_5 = [\n    {\"keys\": \"seg\", \"orig_labels\": [1.5, 2.5, 3.5], \"target_labels\": [0, 1, 2], \"dtype\": np.int8},\n    {\"seg\": np.array([3.5, 1.5, 1.5, 2.5])},\n    np.array([2, 0, 0, 1]),\n]\nTEST_CASE_5_1 = [\n    {\"keys\": \"seg\", \"orig_labels\": [1.5, 2.5, 3.5], \"target_labels\": [0, 1, 2], \"dtype\": torch.int8},\n    {\"seg\": torch.as_tensor([3.5, 1.5, 1.5, 2.5])},\n    torch.as_tensor([2.0, 0.0, 
0.0, 1.0]),\n]\n\nTEST_CASE_6 = [\n    {\"keys\": \"seg\", \"orig_labels\": [\"label3\", \"label2\", \"label1\"], \"target_labels\": [0, 1, 2]},\n    {\"seg\": np.array([[\"label3\", \"label1\"], [\"label1\", \"label2\"]])},\n    np.array([[0, 2], [2, 1]]),\n]\n\nTEST_CASE_7 = [\n    {\"keys\": \"seg\", \"orig_labels\": [3.5, 2.5, 1.5], \"target_labels\": [\"label0\", \"label1\", \"label2\"], \"dtype\": \"str\"},\n    {\"seg\": np.array([[3.5, 1.5], [1.5, 2.5]])},\n    np.array([[\"label0\", \"label2\"], [\"label2\", \"label1\"]]),\n]\n\n\nclass TestMapLabelValued(unittest.TestCase):\n    @parameterized.expand(\n        [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_5_1, TEST_CASE_6, TEST_CASE_7]\n    )\n    def test_shape(self, input_param, input_data, expected_value):\n        result = MapLabelValued(**input_param)(input_data)\n        if isinstance(expected_value, torch.Tensor):\n            assert_allclose(result[\"seg\"], expected_value)\n        else:\n            np.testing.assert_equal(result[\"seg\"], expected_value)\n        self.assertTupleEqual(result[\"seg\"].shape, expected_value.shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_map_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import MapTransform\n\nTEST_CASES = [[\"item\", (\"item\",)], [None, (None,)], [[\"item1\", \"item2\"], (\"item1\", \"item2\")]]\n\nTEST_ILL_CASES = [[ValueError, []], [ValueError, ()], [TypeError, [[]]]]\n\n\nclass MapTest(MapTransform):\n\n    def __call__(self, data):\n        pass\n\n\nclass TestRandomizable(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_keys(self, keys, expected):\n        transform = MapTest(keys=keys)\n        self.assertEqual(transform.keys, expected)\n\n    @parameterized.expand(TEST_ILL_CASES)\n    def test_wrong_keys(self, exception, keys):\n        with self.assertRaisesRegex(exception, \"\"):\n            MapTest(keys=keys)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_mask_intensity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import MaskIntensity\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTEST_CASE_1 = [\n    {\"mask_data\": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])},\n    np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n    np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]),\n]\n\nTEST_CASE_2 = [\n    {\"mask_data\": np.array([[[0, 0, 0], [0, 5, 0], [0, 0, 0]]])},\n    np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n    np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]),\n]\n\nTEST_CASE_3 = [\n    {\"mask_data\": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])},\n    np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n    np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]),\n]\n\nTEST_CASE_4 = [\n    {\n        \"mask_data\": np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]),\n        \"select_fn\": lambda x: np.where((x > 3) & (x < 7), True, False),\n    },\n    np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n    np.array([[[0, 0, 0], [2, 2, 2], [0, 0, 0]], [[0, 0, 
0], [5, 5, 5], [0, 0, 0]]]),\n]\n\nTEST_CASE_5 = [\n    {\"mask_data\": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])},\n    torch.as_tensor([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n    torch.as_tensor([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]),\n]\n\n\nclass TestMaskIntensity(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])\n    def test_value(self, arguments, image, expected_data):\n        for p in TEST_NDARRAYS:\n            result = MaskIntensity(**arguments)(p(image))\n            assert_allclose(result, p(expected_data), type_test=\"tensor\")\n\n    def test_runtime_mask(self):\n        mask_data = np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])\n        img = np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])\n        expected = np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]])\n\n        result = MaskIntensity()(img=img, mask_data=mask_data)\n        assert_allclose(result, expected, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_mask_intensityd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import MaskIntensityd\n\nTEST_CASE_1 = [\n    {\"keys\": \"img\", \"mask_data\": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])},\n    {\"img\": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])},\n    np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]),\n]\n\nTEST_CASE_2 = [\n    {\"keys\": \"img\", \"mask_data\": np.array([[[0, 0, 0], [0, 5, 0], [0, 0, 0]]])},\n    {\"img\": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])},\n    np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 0, 0], [0, 5, 0], [0, 0, 0]]]),\n]\n\nTEST_CASE_3 = [\n    {\"keys\": \"img\", \"mask_data\": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]])},\n    {\"img\": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])},\n    np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]),\n]\n\nTEST_CASE_4 = [\n    {\"keys\": \"img\", \"mask_key\": \"mask\"},\n    {\n        \"img\": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n        \"mask\": np.array([[[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 1, 0], [0, 1, 0], [0, 1, 0]]]),\n    },\n    np.array([[[0, 0, 
0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]),\n]\n\nTEST_CASE_5 = [\n    {\n        \"keys\": \"img\",\n        \"mask_data\": np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]),\n        \"select_fn\": lambda x: np.where((x > 3) & (x < 7), True, False),\n    },\n    {\"img\": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])},\n    np.array([[[0, 0, 0], [2, 2, 2], [0, 0, 0]], [[0, 0, 0], [5, 5, 5], [0, 0, 0]]]),\n]\n\n\nclass TestMaskIntensityd(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])\n    def test_value(self, arguments, image, expected_data):\n        result = MaskIntensityd(**arguments)(image)\n        np.testing.assert_allclose(result[\"img\"], expected_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_mean_ensemble.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import MeanEnsemble\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([{\"weights\": None}, [p(torch.ones(2, 2, 2)), p(torch.ones(2, 2, 2)) + 2], p(torch.ones(2, 2, 2)) + 1])\n\n    TESTS.append(\n        [{\"weights\": None}, p(torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2])), p(torch.ones(2, 2, 2)) + 1]\n    )\n\n    TESTS.append(\n        [{\"weights\": [1, 3]}, [p(torch.ones(2, 2, 2)), p(torch.ones(2, 2, 2)) + 2], p(torch.ones(2, 2, 2)) * 2.5]\n    )\n\n    TESTS.append(\n        [\n            {\"weights\": [[1, 3], [3, 1]]},\n            [p(torch.ones(2, 2, 2)), p(torch.ones(2, 2, 2)) + 2],\n            p(torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1)),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"weights\": np.array([[1, 3], [3, 1]])},\n            [p(torch.ones(2, 2, 2)), p(torch.ones(2, 2, 2)) + 2],\n            p(torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1)),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"weights\": torch.tensor([[[1, 3]], [[3, 1]]])},\n            [p(torch.ones(2, 2, 2, 2)), p(torch.ones(2, 2, 2, 2)) + 2],\n            p(torch.ones(2, 2, 
2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1)),\n        ]\n    )\n\n\nclass TestMeanEnsemble(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, input_param, img, expected_value):\n        result = MeanEnsemble(**input_param)(img)\n        assert_allclose(result, expected_value)\n\n    def test_cuda_value(self):\n        img = torch.stack([torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2])\n        expected_value = torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1)\n        if torch.cuda.is_available():\n            img = img.to(torch.device(\"cuda:0\"))\n            expected_value = expected_value.to(torch.device(\"cuda:0\"))\n        result = MeanEnsemble(torch.tensor([[[1, 3]], [[3, 1]]]))(img)\n        assert_allclose(result, expected_value)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_mean_ensembled.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import MeanEnsembled\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"keys\": [\"pred0\", \"pred1\"], \"output_key\": \"output\", \"weights\": None},\n            {\"pred0\": p(torch.ones(2, 2, 2)), \"pred1\": p(torch.ones(2, 2, 2)) + 2},\n            p(torch.ones(2, 2, 2)) + 1,\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"output\", \"weights\": None},\n            {\"output\": p(torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2]))},\n            p(torch.ones(2, 2, 2)) + 1,\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": [\"pred0\", \"pred1\"], \"output_key\": \"output\", \"weights\": [1, 3]},\n            {\"pred0\": p(torch.ones(2, 2, 2, 2)), \"pred1\": p(torch.ones(2, 2, 2, 2)) + 2},\n            p(torch.ones(2, 2, 2, 2)) * 2.5,\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": [\"pred0\", \"pred1\"], \"output_key\": \"output\", \"weights\": [[1, 3], [3, 1]]},\n            {\"pred0\": p(torch.ones(2, 2, 2)), \"pred1\": p(torch.ones(2, 2, 2)) + 2},\n            p(torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1)),\n      
  ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": [\"pred0\", \"pred1\"], \"output_key\": \"output\", \"weights\": np.array([[[1, 3]], [[3, 1]]])},\n            {\"pred0\": p(torch.ones(2, 2, 2, 2)), \"pred1\": p(torch.ones(2, 2, 2, 2)) + 2},\n            p(torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1)),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": [\"pred0\", \"pred1\"], \"output_key\": \"output\", \"weights\": torch.tensor([[[1, 3]], [[3, 1]]])},\n            {\"pred0\": p(torch.ones(2, 2, 2, 2)), \"pred1\": p(torch.ones(2, 2, 2, 2)) + 2},\n            p(torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1)),\n        ]\n    )\n\n\nclass TestMeanEnsembled(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, input_param, data, expected_value):\n        result = MeanEnsembled(**input_param)(data)\n        assert_allclose(result[\"output\"], expected_value)\n\n    def test_cuda_value(self):\n        img = torch.stack([torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2])\n        expected_value = torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1)\n        if torch.cuda.is_available():\n            img = img.to(torch.device(\"cuda:0\"))\n            expected_value = expected_value.to(torch.device(\"cuda:0\"))\n        result = MeanEnsembled(keys=\"output\", weights=torch.tensor([[[1, 3]], [[3, 1]]]))({\"output\": img})\n        assert_allclose(result[\"output\"], expected_value)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_median_smooth.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import MedianSmooth\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\n\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"radius\": 1},\n            p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n            p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n        ]\n    )\n\n\nclass TestMedianSmooth(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        result = MedianSmooth(**arguments)(image)\n        assert_allclose(result, expected_data, atol=1e-4, rtol=1e-4, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_median_smoothd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import MedianSmoothd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS[0:1]:\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"radius\": [0, 1]},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]))},\n            np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"radius\": 1},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"radius\": [1, 1]},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"radius\": [1, 1, 1]},\n            {\"img\": p(np.array([[[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]]))},\n            np.array([[[[2, 2, 2], [3, 3, 3], [3, 3, 3]], [[4, 4, 4], [4, 4, 
4], [5, 5, 5]]]]),\n        ]\n    )\n\n\nclass TestMedianSmoothd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        result = MedianSmoothd(**arguments)(image)\n        assert_allclose(result[\"img\"], expected_data, rtol=1e-4, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_morphological_ops.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.utils_morphological_ops import dilate, erode, get_morphological_filter_result_t\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS_SHAPE = []\nfor p in TEST_NDARRAYS:\n    mask = torch.zeros(1, 1, 5, 5, 5)\n    filter_size = 3\n    TESTS_SHAPE.append([{\"mask\": p(mask), \"filter_size\": filter_size}, [1, 1, 5, 5, 5]])\n    mask = torch.zeros(3, 2, 5, 5, 5)\n    filter_size = 5\n    TESTS_SHAPE.append([{\"mask\": p(mask), \"filter_size\": filter_size}, [3, 2, 5, 5, 5]])\n    mask = torch.zeros(1, 1, 1, 1, 1)\n    filter_size = 5\n    TESTS_SHAPE.append([{\"mask\": p(mask), \"filter_size\": filter_size}, [1, 1, 1, 1, 1]])\n    mask = torch.zeros(1, 1, 1, 1)\n    filter_size = 5\n    TESTS_SHAPE.append([{\"mask\": p(mask), \"filter_size\": filter_size}, [1, 1, 1, 1]])\n\nTESTS_VALUE_T = []\nfilter_size = 3\nmask = torch.ones(3, 2, 3, 3, 3)\nTESTS_VALUE_T.append([{\"mask\": mask, \"filter_size\": filter_size, \"pad_value\": 1.0}, torch.ones(3, 2, 3, 3, 3)])\nmask = torch.zeros(3, 2, 3, 3, 3)\nTESTS_VALUE_T.append([{\"mask\": mask, \"filter_size\": filter_size, \"pad_value\": 0.0}, torch.zeros(3, 2, 3, 3, 3)])\nmask = torch.ones(3, 2, 3, 3)\nTESTS_VALUE_T.append([{\"mask\": mask, \"filter_size\": filter_size, \"pad_value\": 
1.0}, torch.ones(3, 2, 3, 3)])\nmask = torch.zeros(3, 2, 3, 3)\nTESTS_VALUE_T.append([{\"mask\": mask, \"filter_size\": filter_size, \"pad_value\": 0.0}, torch.zeros(3, 2, 3, 3)])\n\nTESTS_VALUE = []\nfor p in TEST_NDARRAYS:\n    mask = torch.zeros(3, 2, 5, 5, 5)\n    filter_size = 3\n    TESTS_VALUE.append(\n        [{\"mask\": p(mask), \"filter_size\": filter_size}, p(torch.zeros(3, 2, 5, 5, 5)), p(torch.zeros(3, 2, 5, 5, 5))]\n    )\n    mask = torch.ones(1, 1, 3, 3, 3)\n    filter_size = 3\n    TESTS_VALUE.append(\n        [{\"mask\": p(mask), \"filter_size\": filter_size}, p(torch.ones(1, 1, 3, 3, 3)), p(torch.ones(1, 1, 3, 3, 3))]\n    )\n    mask = torch.ones(1, 2, 3, 3, 3)\n    filter_size = 3\n    TESTS_VALUE.append(\n        [{\"mask\": p(mask), \"filter_size\": filter_size}, p(torch.ones(1, 2, 3, 3, 3)), p(torch.ones(1, 2, 3, 3, 3))]\n    )\n    mask = torch.zeros(3, 2, 3, 3, 3)\n    mask[:, :, 1, 1, 1] = 1.0\n    filter_size = 3\n    TESTS_VALUE.append(\n        [{\"mask\": p(mask), \"filter_size\": filter_size}, p(torch.zeros(3, 2, 3, 3, 3)), p(torch.ones(3, 2, 3, 3, 3))]\n    )\n    mask = torch.zeros(3, 2, 3, 3)\n    mask[:, :, 1, 1] = 1.0\n    filter_size = 3\n    TESTS_VALUE.append(\n        [{\"mask\": p(mask), \"filter_size\": filter_size}, p(torch.zeros(3, 2, 3, 3)), p(torch.ones(3, 2, 3, 3))]\n    )\n\n\nclass TestMorph(unittest.TestCase):\n    @parameterized.expand(TESTS_SHAPE)\n    def test_shape(self, input_data, expected_result):\n        result1 = erode(input_data[\"mask\"], input_data[\"filter_size\"])\n        assert_allclose(result1.shape, expected_result, type_test=False, device_test=False, atol=0.0)\n\n    @parameterized.expand(TESTS_VALUE_T)\n    def test_value_t(self, input_data, expected_result):\n        result1 = get_morphological_filter_result_t(\n            input_data[\"mask\"], input_data[\"filter_size\"], input_data[\"pad_value\"]\n        )\n        assert_allclose(result1, expected_result, type_test=False, 
device_test=False, atol=0.0)\n\n    @parameterized.expand(TESTS_VALUE)\n    def test_value(self, input_data, expected_erode_result, expected_dilate_result):\n        result1 = erode(input_data[\"mask\"], input_data[\"filter_size\"])\n        assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0)\n        result2 = dilate(input_data[\"mask\"], input_data[\"filter_size\"])\n        assert_allclose(result2, expected_dilate_result, type_test=True, device_test=True, atol=0.0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_nifti_endianness.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\nfrom unittest.case import skipUnless\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import DataLoader, Dataset, create_test_image_2d\nfrom monai.data.image_reader import PILReader\nfrom monai.transforms import LoadImage, LoadImaged\nfrom monai.transforms.io.array import switch_endianness\nfrom monai.utils.enums import PostFix\nfrom monai.utils.module import optional_import\n\nif TYPE_CHECKING:\n    import nibabel as nib\n    from PIL import Image as PILImage\n\n    has_nib = True\n    has_pil = True\nelse:\n    nib, has_nib = optional_import(\"nibabel\")\n    PILImage, has_pil = optional_import(\"PIL.Image\")\n\nTESTS: list[tuple] = []\nfor endianness in [\"<\", \">\"]:\n    for use_array in [True, False]:\n        for image_only in [True, False]:\n            TESTS.append((endianness, use_array, image_only))\n\n\nclass TestNiftiEndianness(unittest.TestCase):\n\n    def setUp(self):\n        self.im, _ = create_test_image_2d(100, 100)\n        self.fname = tempfile.NamedTemporaryFile(suffix=\".nii.gz\").name\n\n    @parameterized.expand(TESTS)\n    @skipUnless(has_nib, \"Requires NiBabel\")\n    def test_endianness(self, endianness, use_array, image_only):\n        hdr = 
nib.Nifti1Header(endianness=endianness)\n        nii = nib.Nifti1Image(self.im, np.eye(4), header=hdr)\n        nib.save(nii, self.fname)\n\n        data = [self.fname] if use_array else [{\"image\": self.fname}]\n        tr = LoadImage(image_only=image_only) if use_array else LoadImaged(\"image\", image_only=image_only)\n        check_ds = Dataset(data, tr)\n        check_loader = DataLoader(check_ds, batch_size=1)\n        ret = next(iter(check_loader))\n        if isinstance(ret, dict) and PostFix.meta(\"image\") in ret:\n            np.testing.assert_allclose(ret[PostFix.meta(\"image\")][\"spatial_shape\"], [[100, 100]])\n\n    def test_switch(self):  # verify data types\n        for data in (np.zeros((2, 1)), (\"test\",), [24, 42], {\"foo\": \"bar\"}, True, 42):\n            output = switch_endianness(data, \"<\")\n            self.assertEqual(type(data), type(output))\n\n        before = np.array((20, 20), dtype=\">i2\")\n        expected_float = before.astype(float)\n        after = switch_endianness(before)\n        np.testing.assert_allclose(after.astype(float), expected_float)\n        self.assertEqual(after.dtype.byteorder, \"<\")\n\n        before = np.array((20, 20), dtype=\"<i2\")\n        expected_float = before.astype(float)\n        after = switch_endianness(before)\n        np.testing.assert_allclose(after.astype(float), expected_float)\n\n        before = np.array([\"1.12\", \"-9.2\", \"42\"], dtype=np.bytes_)\n        after = switch_endianness(before)\n        np.testing.assert_array_equal(before, after)\n\n        with self.assertRaises(NotImplementedError):\n            switch_endianness(np.zeros((2, 1)), \"=\")\n\n        with self.assertRaises(RuntimeError):\n            switch_endianness(Path(\"test\"), \"<\")\n\n    @skipUnless(has_pil, \"Requires PIL\")\n    def test_pil(self):\n        tempdir = tempfile.mkdtemp()\n        test_image = np.random.randint(0, 256, size=[128, 256])\n        filename = os.path.join(tempdir, 
\"test_image.png\")\n        PILImage.fromarray(test_image.astype(\"uint8\")).save(filename)\n\n        loader = LoadImage(PILReader(converter=lambda image: image.convert(\"LA\")))\n        _ = loader(filename)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_normalize_intensity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import NormalizeIntensity\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([p, {\"nonzero\": True}, np.array([0.0, 3.0, 0.0, 4.0]), np.array([0.0, -1.0, 0.0, 1.0])])\n    for q in TEST_NDARRAYS:\n        for u in TEST_NDARRAYS:\n            TESTS.append(\n                [\n                    p,\n                    {\n                        \"subtrahend\": q(np.array([3.5, 3.5, 3.5, 3.5])),\n                        \"divisor\": u(np.array([0.5, 0.5, 0.5, 0.5])),\n                        \"nonzero\": True,\n                    },\n                    p(np.array([0.0, 3.0, 0.0, 4.0])),\n                    p(np.array([0.0, -1.0, 0.0, 1.0])),\n                ]\n            )\n    TESTS.append([p, {\"nonzero\": True}, p(np.array([0.0, 0.0, 0.0, 0.0])), p(np.array([0.0, 0.0, 0.0, 0.0]))])\n    TESTS.append([p, {\"nonzero\": False}, p(np.array([0.0, 0.0, 0.0, 0.0])), p(np.array([0.0, 0.0, 0.0, 0.0]))])\n    TESTS.append([p, {\"nonzero\": False}, p(np.array([1, 1, 1, 1])), p(np.array([0.0, 0.0, 0.0, 0.0]))])\n    TESTS.append(\n        [\n            p,\n            {\"nonzero\": False, \"channel_wise\": True, \"subtrahend\": [1, 2, 
3], \"dtype\": np.float32},\n            p(np.ones((3, 2, 2))),\n            p(np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]])),\n        ]\n    )\n    TESTS.append(\n        [\n            p,\n            {\"nonzero\": True, \"channel_wise\": True, \"subtrahend\": [1, 2, 3], \"divisor\": [0, 0, 2], \"dtype\": \"float32\"},\n            p(np.ones((3, 2, 2))),\n            p(np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]])),\n        ]\n    )\n    TESTS.append(\n        [\n            p,\n            {\"nonzero\": True, \"channel_wise\": False, \"subtrahend\": 2, \"divisor\": 0, \"dtype\": torch.float32},\n            p(np.ones((3, 2, 2))),\n            p(np.ones((3, 2, 2)) * -1.0),\n        ]\n    )\n    TESTS.append(\n        [\n            p,\n            {\"nonzero\": True, \"channel_wise\": False, \"subtrahend\": np.ones((3, 2, 2)) * 0.5, \"divisor\": 0},\n            p(np.ones((3, 2, 2))),\n            p(np.ones((3, 2, 2)) * 0.5),\n        ]\n    )\n    TESTS.append(\n        [\n            p,\n            {\"nonzero\": True, \"channel_wise\": True, \"subtrahend\": np.ones((3, 2, 2)) * 0.5, \"divisor\": [0, 1, 0]},\n            p(np.ones((3, 2, 2))),\n            p(np.ones((3, 2, 2)) * 0.5),\n        ]\n    )\n\n\nclass TestNormalizeIntensity(NumpyImageTestCase2D):\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_default(self, im_type):\n        im = im_type(self.imt.copy())\n        normalizer = NormalizeIntensity()\n        normalized = normalizer(im)\n        self.assertTrue(normalized.dtype in (np.float32, torch.float32))\n        expected = (self.imt - np.mean(self.imt)) / np.std(self.imt)\n        assert_allclose(normalized, expected, type_test=\"tensor\", rtol=1e-3)\n\n    @parameterized.expand(TESTS)\n    def test_nonzero(self, in_type, input_param, input_data, expected_data):\n        normalizer = NormalizeIntensity(**input_param)\n   
     im = in_type(input_data)\n        normalized = normalizer(im)\n        assert_allclose(normalized, in_type(expected_data), type_test=\"tensor\")\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_channel_wise(self, im_type):\n        normalizer = NormalizeIntensity(nonzero=True, channel_wise=True)\n        input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))\n        expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]])\n        normalized = normalizer(input_data)\n        assert_allclose(normalized, im_type(expected), type_test=\"tensor\")\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_channel_wise_int(self, im_type):\n        normalizer = NormalizeIntensity(nonzero=True, channel_wise=True)\n        input_data = im_type(torch.arange(1, 25).reshape(2, 3, 4))\n        expected = np.array(\n            [\n                [\n                    [-1.593255, -1.3035723, -1.0138896, -0.7242068],\n                    [-0.4345241, -0.1448414, 0.1448414, 0.4345241],\n                    [0.7242068, 1.0138896, 1.3035723, 1.593255],\n                ],\n                [\n                    [-1.593255, -1.3035723, -1.0138896, -0.7242068],\n                    [-0.4345241, -0.1448414, 0.1448414, 0.4345241],\n                    [0.7242068, 1.0138896, 1.3035723, 1.593255],\n                ],\n            ]\n        )\n        normalized = normalizer(input_data)\n        assert_allclose(normalized, im_type(expected), type_test=\"tensor\", rtol=1e-7, atol=1e-7)  # tolerance\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_value_errors(self, im_type):\n        input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))\n        normalizer = NormalizeIntensity(nonzero=True, channel_wise=True, subtrahend=[1])\n        with self.assertRaises(ValueError):\n            normalizer(input_data)\n        normalizer = NormalizeIntensity(nonzero=True, 
channel_wise=True, subtrahend=[1, 2], divisor=[1])\n        with self.assertRaises(ValueError):\n            normalizer(input_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_normalize_intensityd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import NormalizeIntensityd\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    for q in TEST_NDARRAYS:\n        TESTS.append(\n            [\n                {\"keys\": [\"img\"], \"nonzero\": True},\n                {\"img\": p(np.array([0.0, 3.0, 0.0, 4.0]))},\n                p(np.array([0.0, -1.0, 0.0, 1.0])),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"keys\": [\"img\"],\n                    \"subtrahend\": q(np.array([3.5, 3.5, 3.5, 3.5])),\n                    \"divisor\": q(np.array([0.5, 0.5, 0.5, 0.5])),\n                    \"nonzero\": True,\n                },\n                {\"img\": p(np.array([0.0, 3.0, 0.0, 4.0]))},\n                p(np.array([0.0, -1.0, 0.0, 1.0])),\n            ]\n        )\n        TESTS.append(\n            [\n                {\"keys\": [\"img\"], \"nonzero\": True},\n                {\"img\": p(np.array([0.0, 0.0, 0.0, 0.0]))},\n                p(np.array([0.0, 0.0, 0.0, 0.0])),\n            ]\n        )\n\n\nclass TestNormalizeIntensityd(NumpyImageTestCase2D):\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def 
test_image_normalize_intensityd(self, im_type):\n        key = \"img\"\n        im = im_type(self.imt)\n        normalizer = NormalizeIntensityd(keys=[key])\n        normalized = normalizer({key: im})[key]\n        expected = (self.imt - np.mean(self.imt)) / np.std(self.imt)\n        assert_allclose(normalized, im_type(expected), rtol=1e-3, type_test=\"tensor\")\n\n    @parameterized.expand(TESTS)\n    def test_nonzero(self, input_param, input_data, expected_data):\n        key = \"img\"\n        normalizer = NormalizeIntensityd(**input_param)\n        normalized = normalizer(input_data)[key]\n        assert_allclose(normalized, expected_data, type_test=\"tensor\")\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_channel_wise(self, im_type):\n        key = \"img\"\n        normalizer = NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True)\n        input_data = {key: im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))}\n        normalized = normalizer(input_data)[key]\n        expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]])\n        assert_allclose(normalized, im_type(expected), type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_nvtx_decorator.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import (\n    Compose,\n    CuCIM,\n    Flip,\n    Flipd,\n    OneOf,\n    RandAdjustContrast,\n    RandCuCIM,\n    RandFlip,\n    RandomizableTrait,\n    Rotate90,\n    ToCupy,\n    ToNumpy,\n    TorchVision,\n    ToTensor,\n    ToTensord,\n)\nfrom monai.utils import Range, optional_import\nfrom tests.test_utils import HAS_CUPY\n\n_, has_nvtx = optional_import(\"torch._C._nvtx\", descriptor=\"NVTX is not installed. 
Are you sure you have a CUDA build?\")\n_, has_tvt = optional_import(\"torchvision.transforms\")\n_, has_cut = optional_import(\"cucim.core.operations.expose.transform\")\n\nTEST_CASE_ARRAY_0 = [np.random.randn(3, 3)]\nTEST_CASE_ARRAY_1 = [np.random.randn(3, 10, 10)]\n\nTEST_CASE_DICT_0 = [{\"image\": np.random.randn(3, 3)}]\nTEST_CASE_DICT_1 = [{\"image\": np.random.randn(3, 10, 10)}]\n\nTEST_CASE_TORCH_0 = [torch.randn(3, 3)]\nTEST_CASE_TORCH_1 = [torch.randn(3, 10, 10)]\n\nTEST_CASE_WRAPPER = [np.random.randn(3, 10, 10)]\n\nTEST_CASE_RECURSIVE_0 = [\n    torch.randn(3, 3),\n    Compose([ToNumpy(), Flip(), RandAdjustContrast(prob=0.0), RandFlip(prob=1.0), ToTensor()]),\n]\nTEST_CASE_RECURSIVE_1 = [\n    torch.randn(3, 3),\n    Compose([ToNumpy(), Flip(), Compose([RandAdjustContrast(prob=0.0), RandFlip(prob=1.0)]), ToTensor()]),\n]\nTEST_CASE_RECURSIVE_2 = [\n    torch.randn(3, 3),\n    Compose([ToNumpy(), Flip(), OneOf([RandAdjustContrast(prob=0.0), RandFlip(prob=1.0)], weights=[0, 1]), ToTensor()]),\n]\nTEST_CASE_RECURSIVE_LIST = [\n    torch.randn(3, 3),\n    [ToNumpy(), Flip(), RandAdjustContrast(prob=0.0), RandFlip(prob=1.0), ToTensor()],\n]\n\n\n@unittest.skipUnless(has_nvtx, \"Required torch._C._nvtx for NVTX Range!\")\nclass TestNVTXRangeDecorator(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1])\n    def test_transform_array(self, input):\n        transforms = Compose([Range(\"random flip\")(Flip()), Range()(ToTensor())])\n        # Apply transforms\n        output = transforms(input)\n\n        # Decorate with NVTX Range\n        transforms1 = Range()(transforms)\n        transforms2 = Range(\"Transforms2\")(transforms)\n        transforms3 = Range(name=\"Transforms3\", methods=\"__call__\")(transforms)\n\n        # Apply transforms with Range\n        output1 = transforms1(input)\n        output2 = transforms2(input)\n        output3 = transforms3(input)\n\n        # Check the outputs\n        
self.assertIsInstance(output, torch.Tensor)\n        self.assertIsInstance(output1, torch.Tensor)\n        self.assertIsInstance(output2, torch.Tensor)\n        self.assertIsInstance(output3, torch.Tensor)\n        np.testing.assert_equal(output.numpy(), output1.numpy())\n        np.testing.assert_equal(output.numpy(), output2.numpy())\n        np.testing.assert_equal(output.numpy(), output3.numpy())\n\n    @parameterized.expand([TEST_CASE_DICT_0, TEST_CASE_DICT_1])\n    def test_tranform_dict(self, input):\n        transforms = Compose([Range(\"random flip dict\")(Flipd(keys=\"image\")), Range()(ToTensord(\"image\"))])\n        # Apply transforms\n        output = transforms(input)[\"image\"]\n\n        # Decorate with NVTX Range\n        transforms1 = Range()(transforms)\n        transforms2 = Range(\"Transforms2\")(transforms)\n        transforms3 = Range(name=\"Transforms3\", methods=\"__call__\")(transforms)\n\n        # Apply transforms with Range\n        output1 = transforms1(input)[\"image\"]\n        output2 = transforms2(input)[\"image\"]\n        output3 = transforms3(input)[\"image\"]\n\n        # Check the outputs\n        self.assertIsInstance(output, torch.Tensor)\n        self.assertIsInstance(output1, torch.Tensor)\n        self.assertIsInstance(output2, torch.Tensor)\n        self.assertIsInstance(output3, torch.Tensor)\n        np.testing.assert_equal(output.numpy(), output1.numpy())\n        np.testing.assert_equal(output.numpy(), output2.numpy())\n        np.testing.assert_equal(output.numpy(), output3.numpy())\n\n    @parameterized.expand([TEST_CASE_WRAPPER])\n    @unittest.skipUnless(HAS_CUPY, \"Requires CuPy.\")\n    @unittest.skipUnless(has_cut, \"Requires cuCIM transforms.\")\n    @unittest.skipUnless(has_tvt, \"Requires torchvision transforms.\")\n    def test_wrapper_tranforms(self, input):\n        transform_list = [\n            ToTensor(),\n            TorchVision(name=\"RandomHorizontalFlip\", p=1.0),\n            ToCupy(),\n        
    CuCIM(name=\"image_flip\", spatial_axis=-1),\n            RandCuCIM(name=\"rand_image_rotate_90\", prob=1.0, max_k=1, spatial_axis=(-2, -1)),\n        ]\n\n        transforms = Compose(transform_list)\n        transforms_range = Compose([Range()(t) for t in transform_list])\n\n        # Apply transforms\n        output = transforms(input)\n\n        # Apply transforms with Range\n        output_r = transforms_range(input)\n\n        # Check the outputs\n        np.testing.assert_equal(output.get(), output_r.get())\n\n    @parameterized.expand([TEST_CASE_RECURSIVE_0, TEST_CASE_RECURSIVE_1, TEST_CASE_RECURSIVE_2])\n    def test_recursive_tranforms(self, input, transforms):\n        transforms_range = Range(name=\"Recursive Compose\", recursive=True)(transforms)\n\n        # Apply transforms\n        output = transforms(input)\n\n        # Apply transforms with Range\n        output_r = transforms_range(input)\n\n        # Check the outputs\n        self.assertEqual(transforms.map_items, transforms_range.map_items)\n        self.assertEqual(transforms.unpack_items, transforms_range.unpack_items)\n        np.testing.assert_equal(output.numpy(), output_r.numpy())\n\n    @parameterized.expand([TEST_CASE_RECURSIVE_LIST])\n    def test_recursive_list_tranforms(self, input, transform_list):\n        transforms_list_range = Range(recursive=True)(transform_list)\n\n        # Apply transforms\n        output = Compose(transform_list)(input)\n\n        # Apply transforms with Range\n        output_r = Compose(transforms_list_range)(input)\n\n        # Check the outputs\n        np.testing.assert_equal(output.numpy(), output_r.numpy())\n\n    @parameterized.expand([TEST_CASE_ARRAY_1])\n    def test_tranform_randomized(self, input):\n        # Compose deterministic and randomized transforms\n        transforms = Compose(\n            [\n                Range(\"flip\")(Flip()),\n                Rotate90(),\n                Range()(RandAdjustContrast(prob=0.0)),\n               
 Range(\"random flip\")(RandFlip(prob=1.0)),\n                ToTensor(),\n            ]\n        )\n        # Apply transforms\n        output = transforms(input)\n\n        # Decorate with NVTX Range\n        transforms1 = Range()(transforms)\n        transforms2 = Range(\"Transforms2\")(transforms)\n        transforms3 = Range(name=\"Transforms3\", methods=\"__call__\")(transforms)\n\n        # Apply transforms with Range\n        output1 = transforms1(input)\n        output2 = transforms2(input)\n        output3 = transforms3(input)\n\n        # Check if the outputs are equal\n        self.assertIsInstance(output, torch.Tensor)\n        self.assertIsInstance(output1, torch.Tensor)\n        self.assertIsInstance(output2, torch.Tensor)\n        self.assertIsInstance(output3, torch.Tensor)\n        np.testing.assert_equal(output.numpy(), output1.numpy())\n        np.testing.assert_equal(output.numpy(), output2.numpy())\n        np.testing.assert_equal(output.numpy(), output3.numpy())\n\n        # Check if the first randomized is RandAdjustContrast\n        for tran in transforms.transforms:\n            if isinstance(tran, RandomizableTrait):\n                self.assertIsInstance(tran, RandAdjustContrast)\n                break\n\n    @parameterized.expand([TEST_CASE_TORCH_0, TEST_CASE_TORCH_1])\n    def test_network(self, input):\n        # Create a network\n        model = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Sigmoid())\n\n        # Forward\n        output = model(input)\n\n        # Decorate with NVTX Range\n        model1 = Range()(model)\n        model2 = Range(\"Model2\")(model)\n        model3 = Range(name=\"Model3\", methods=\"forward\")(model)\n\n        # Forward with Range\n        output1 = model1(input)\n        output2 = model2(input)\n        output3 = model3(input)\n\n        # Check the outputs\n        self.assertIsInstance(output, torch.Tensor)\n        self.assertIsInstance(output1, torch.Tensor)\n        
self.assertIsInstance(output2, torch.Tensor)\n        self.assertIsInstance(output3, torch.Tensor)\n        np.testing.assert_equal(output.numpy(), output1.numpy())\n        np.testing.assert_equal(output.numpy(), output2.numpy())\n        np.testing.assert_equal(output.numpy(), output3.numpy())\n\n    @parameterized.expand([TEST_CASE_TORCH_0, TEST_CASE_TORCH_1])\n    def test_loss(self, input):\n        # Create a network and loss\n        model = torch.nn.Sigmoid()\n        loss = torch.nn.BCELoss()\n        pred = model(input)\n        target = torch.empty_like(input).random_(2)\n\n        # Loss evaluation\n        output = loss(pred, target)\n\n        # Decorate with NVTX Range\n        loss1 = Range()(loss)\n        loss2 = Range(\"Loss2\")(loss)\n        loss3 = Range(name=\"Loss3\", methods=\"forward\")(loss)\n\n        # Loss evaluation with Range\n        output1 = loss1(pred, target)\n        output2 = loss2(pred, target)\n        output3 = loss3(pred, target)\n\n        # Check the outputs\n        self.assertIsInstance(output, torch.Tensor)\n        self.assertIsInstance(output1, torch.Tensor)\n        self.assertIsInstance(output2, torch.Tensor)\n        self.assertIsInstance(output3, torch.Tensor)\n        np.testing.assert_equal(output.numpy(), output1.numpy())\n        np.testing.assert_equal(output.numpy(), output2.numpy())\n        np.testing.assert_equal(output.numpy(), output3.numpy())\n\n    def test_context_manager(self):\n        model = torch.nn.Sigmoid()\n        loss = torch.nn.BCELoss()\n\n        with Range():\n            input = torch.randn(3, requires_grad=True)\n            target = torch.empty(3).random_(2)\n\n        with Range(\"Model\"):\n            output = loss(model(input), target)\n            output.backward()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_nvtx_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import Compose, Flip, RandFlip, RandFlipD, RandomizableTrait, ToTensor, ToTensorD\nfrom monai.transforms.nvtx import (\n    Mark,\n    MarkD,\n    RandMark,\n    RandMarkD,\n    RandRangePop,\n    RandRangePopD,\n    RandRangePush,\n    RandRangePushD,\n    RangePop,\n    RangePopD,\n    RangePush,\n    RangePushD,\n)\nfrom monai.utils import optional_import\n\n_, has_nvtx = optional_import(\"torch._C._nvtx\", descriptor=\"NVTX is not installed. 
Are you sure you have a CUDA build?\")\n\nTEST_CASE_ARRAY_0 = [np.random.randn(3, 3)]\nTEST_CASE_ARRAY_1 = [np.random.randn(3, 10, 10)]\nTEST_CASE_DICT_0 = [{\"image\": np.random.randn(3, 3)}]\nTEST_CASE_DICT_1 = [{\"image\": np.random.randn(3, 10, 10)}]\n\n\nclass TestNVTXTransforms(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1, TEST_CASE_DICT_0, TEST_CASE_DICT_1])\n    @unittest.skipUnless(has_nvtx, \"CUDA is required for NVTX!\")\n    def test_nvtx_transforms_alone(self, input):\n        transforms = Compose(\n            [\n                Mark(\"Mark: Transforms Start!\"),\n                RangePush(\"Range: RandFlipD\"),\n                RangePop(),\n                RandRangePush(\"Range: ToTensorD\"),\n                RandRangePop(),\n                RandMark(\"Mark: Transforms End!\"),\n            ]\n        )\n        output = transforms(input)\n        self.assertEqual(id(input), id(output))\n\n        # Check if chain of randomizable/non-randomizable transforms is not broken\n        for tran in transforms.transforms:\n            if isinstance(tran, RandomizableTrait):\n                self.assertIsInstance(tran, RangePush)\n                break\n\n    @parameterized.expand([TEST_CASE_ARRAY_0, TEST_CASE_ARRAY_1])\n    @unittest.skipUnless(has_nvtx, \"CUDA is required for NVTX!\")\n    def test_nvtx_transfroms_array(self, input):\n        # with prob == 0.0\n        transforms = Compose(\n            [\n                RandMark(\"Mark: Transforms Start!\"),\n                RandRangePush(\"Range: RandFlip\"),\n                RandFlip(prob=0.0),\n                RandRangePop(),\n                RangePush(\"Range: ToTensor\"),\n                ToTensor(),\n                RangePop(),\n                Mark(\"Mark: Transforms End!\"),\n            ]\n        )\n        output = transforms(input)\n        self.assertIsInstance(output, torch.Tensor)\n        np.testing.assert_array_equal(input, output)\n        # 
with prob == 1.0\n        transforms = Compose(\n            [\n                RandMark(\"Mark: Transforms Start!\"),\n                RandRangePush(\"Range: RandFlip\"),\n                RandFlip(prob=1.0),\n                RandRangePop(),\n                RangePush(\"Range: ToTensor\"),\n                ToTensor(),\n                RangePop(),\n                Mark(\"Mark: Transforms End!\"),\n            ]\n        )\n        output = transforms(input)\n        self.assertIsInstance(output, torch.Tensor)\n        np.testing.assert_array_equal(input, Flip()(output.numpy()))\n\n    @parameterized.expand([TEST_CASE_DICT_0, TEST_CASE_DICT_1])\n    @unittest.skipUnless(has_nvtx, \"CUDA is required for NVTX!\")\n    def test_nvtx_transfroms_dict(self, input):\n        # with prob == 0.0\n        transforms = Compose(\n            [\n                RandMarkD(\"Mark: Transforms (p=0) Start!\"),\n                RandRangePushD(\"Range: RandFlipD\"),\n                RandFlipD(keys=\"image\", prob=0.0),\n                RandRangePopD(),\n                RangePushD(\"Range: ToTensorD\"),\n                ToTensorD(keys=(\"image\")),\n                RangePopD(),\n                MarkD(\"Mark: Transforms (p=0) End!\"),\n            ]\n        )\n        output = transforms(input)\n        self.assertIsInstance(output[\"image\"], torch.Tensor)\n        np.testing.assert_array_equal(input[\"image\"], output[\"image\"])\n        # with prob == 1.0\n        transforms = Compose(\n            [\n                RandMarkD(\"Mark: Transforms (p=1) Start!\"),\n                RandRangePushD(\"Range: RandFlipD\"),\n                RandFlipD(keys=\"image\", prob=1.0),\n                RandRangePopD(),\n                RangePushD(\"Range: ToTensorD\"),\n                ToTensorD(keys=(\"image\")),\n                RangePopD(),\n                MarkD(\"Mark: Transforms (p=1) End!\"),\n            ]\n        )\n        output = transforms(input)\n        
self.assertIsInstance(output[\"image\"], torch.Tensor)\n        np.testing.assert_array_equal(input[\"image\"], Flip()(output[\"image\"].numpy()))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_orientation.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom typing import cast\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_obj import set_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import Orientation, create_rotate, create_translate\nfrom monai.utils import SpaceKeys\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_DEVICES, assert_allclose\n\nTESTS = []\nfor device in TEST_DEVICES:\n    TESTS.append(\n        [\n            {\"axcodes\": \"RAS\"},\n            torch.arange(12).reshape((2, 1, 2, 3)),\n            torch.eye(4),\n            torch.arange(12).reshape((2, 1, 2, 3)),\n            \"RAS\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"LPS\"},\n            torch.arange(12).reshape((2, 1, 2, 3)),\n            torch.eye(4),\n            torch.arange(12).reshape((2, 1, 2, 3)),\n            \"LPS\",\n            True,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"ALS\"},\n            torch.arange(12).reshape((2, 1, 2, 3)),\n            torch.as_tensor(np.diag([-1, -1, 1, 1])),\n            torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),\n            
\"ALS\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"PRS\"},\n            torch.arange(12).reshape((2, 1, 2, 3)),\n            torch.as_tensor(np.diag([-1, -1, 1, 1])),\n            torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),\n            \"PRS\",\n            True,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"RAS\"},\n            torch.arange(12).reshape((2, 1, 2, 3)),\n            torch.as_tensor(np.diag([-1, -1, 1, 1])),\n            torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),\n            \"RAS\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"LPS\"},\n            torch.arange(12).reshape((2, 1, 2, 3)),\n            torch.as_tensor(np.diag([-1, -1, 1, 1])),\n            torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),\n            \"LPS\",\n            True,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"AL\"},\n            torch.arange(6).reshape((2, 1, 3)),\n            torch.eye(3),\n            torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),\n            \"AL\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"PR\"},\n            torch.arange(6).reshape((2, 1, 3)),\n            torch.eye(3),\n            torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),\n            \"PR\",\n            True,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"L\"},\n            torch.arange(6).reshape((2, 3)),\n            torch.eye(2),\n            torch.tensor([[2, 1, 0], [5, 4, 3]]),\n            \"L\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"R\"},\n   
         torch.arange(6).reshape((2, 3)),\n            torch.eye(2),\n            torch.tensor([[2, 1, 0], [5, 4, 3]]),\n            \"R\",\n            True,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"L\"},\n            torch.arange(6).reshape((2, 3)),\n            torch.eye(2),\n            torch.tensor([[2, 1, 0], [5, 4, 3]]),\n            \"L\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"L\"},\n            torch.arange(6).reshape((2, 3)),\n            torch.as_tensor(np.diag([-1, 1])),\n            torch.arange(6).reshape((2, 3)),\n            \"L\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"LPS\"},\n            torch.arange(12).reshape((2, 1, 2, 3)),\n            torch.as_tensor(\n                create_translate(3, (10, 20, 30))\n                @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4))\n                @ np.diag([-1, 1, 1, 1])\n            ),\n            torch.tensor([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 9]]]]),\n            \"LPS\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"as_closest_canonical\": True},\n            torch.arange(12).reshape((2, 1, 2, 3)),\n            torch.as_tensor(\n                create_translate(3, (10, 20, 30))\n                @ create_rotate(3, (np.pi / 2, np.pi / 2, np.pi / 4))\n                @ np.diag([-1, 1, 1, 1])\n            ),\n            torch.tensor([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]),\n            \"RAS\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"as_closest_canonical\": True},\n            torch.arange(6).reshape((1, 2, 3)),\n            torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ 
np.diag([-1, -0.2, 1])),\n            torch.tensor([[[3, 0], [4, 1], [5, 2]]]),\n            \"RA\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"LP\"},\n            torch.arange(6).reshape((1, 2, 3)),\n            torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),\n            torch.tensor([[[2, 5], [1, 4], [0, 3]]]),\n            \"LP\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"axcodes\": \"LPID\", \"labels\": tuple(zip(\"LPIC\", \"RASD\"))},\n            torch.zeros((1, 2, 3, 4, 5)),\n            torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),\n            torch.zeros((1, 2, 3, 4, 5)),\n            \"LPID\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"as_closest_canonical\": True, \"labels\": tuple(zip(\"LPIC\", \"RASD\"))},\n            torch.zeros((1, 2, 3, 4, 5)),\n            torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),\n            torch.zeros((1, 2, 3, 4, 5)),\n            \"RASD\",\n            False,\n            *device,\n        ]\n    )\n\nTESTS_TORCH = []\nfor track_meta in (False, True):\n    for device in TEST_DEVICES:\n        TESTS_TORCH.append([{\"axcodes\": \"LPS\"}, torch.zeros((1, 3, 4, 5)), track_meta, *device])\n\nILL_CASES = [\n    # too short axcodes\n    [{\"axcodes\": \"RA\"}, torch.arange(12).reshape((2, 1, 2, 3)), torch.eye(4)]\n]\n\nTESTS_INVERSE = []\nfor device in TEST_DEVICES:\n    TESTS_INVERSE.append([True, *device])\n    TESTS_INVERSE.append([False, *device])\n\n\nclass TestOrientationCase(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_ornt_meta(\n        self,\n        init_param,\n        img: torch.Tensor,\n        affine: torch.Tensor,\n        expected_data: torch.Tensor,\n        expected_code: str,\n        lps_convention: bool,\n        device,\n 
   ):\n        meta = {\"space\": SpaceKeys.LPS} if lps_convention else None\n        img = MetaTensor(img, affine=affine, meta=meta).to(device)\n        ornt = Orientation(**init_param)\n        call_param = {\"data_array\": img}\n        res = ornt(**call_param)  # type: ignore[arg-type]\n        if img.ndim in (3, 4):\n            test_resampler_lazy(ornt, res, init_param, call_param)\n\n        assert_allclose(res, expected_data.to(device))\n        labels = ((\"R\", \"L\"), (\"A\", \"P\"), (\"I\", \"S\")) if lps_convention else ornt.labels\n        new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=labels)  # type: ignore\n        self.assertEqual(\"\".join(new_code), expected_code)\n\n    @parameterized.expand(TESTS_TORCH)\n    def test_ornt_torch(self, init_param, img: torch.Tensor, track_meta: bool, device):\n        set_track_meta(track_meta)\n        ornt = Orientation(**init_param)\n\n        img = img.to(device)\n        expected_data = img.clone()\n        expected_code = ornt.axcodes\n\n        res = ornt(img)\n        assert_allclose(res, expected_data)\n        if track_meta:\n            self.assertIsInstance(res, MetaTensor)\n            assert isinstance(res, MetaTensor)  # for mypy type narrowing\n            new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels)\n            self.assertEqual(\"\".join(new_code), expected_code)\n        else:\n            self.assertIsInstance(res, torch.Tensor)\n            self.assertNotIsInstance(res, MetaTensor)\n\n    @parameterized.expand(ILL_CASES)\n    def test_bad_params(self, init_param, img: torch.Tensor, affine: torch.Tensor):\n        img = MetaTensor(img, affine=affine)\n        with self.assertRaises(ValueError):\n            Orientation(**init_param)(img)\n\n    @parameterized.expand(TESTS_INVERSE)\n    def test_inverse(self, lps_convention: bool, device):\n        img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)\n        affine = 
torch.tensor(\n            [[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device=\"cpu\"\n        )\n        meta = {\"fname\": \"somewhere\", \"space\": SpaceKeys.LPS if lps_convention else SpaceKeys.RAS}\n        img = MetaTensor(img_t, affine=affine, meta=meta)\n        tr = Orientation(\"LPS\")\n        # check that image and affine have changed\n        img = cast(MetaTensor, tr(img))\n        self.assertNotEqual(img.shape, img_t.shape)\n        self.assertGreater(float((affine - img.affine).max()), 0.5)\n        # check that with inverse, image affine are back to how they were\n        img = cast(MetaTensor, tr.inverse(img))\n        self.assertEqual(img.shape, img_t.shape)\n        self.assertLess(float((affine - img.affine).max()), 1e-2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_orientationd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_obj import set_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import Orientationd\nfrom monai.utils import SpaceKeys\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_DEVICES\n\nTESTS = []\nfor device in TEST_DEVICES:\n    TESTS.append(\n        [{\"keys\": \"seg\", \"axcodes\": \"RAS\"}, torch.ones((2, 1, 2, 3)), torch.eye(4), (2, 1, 2, 3), \"RAS\", False, *device]\n    )\n    TESTS.append(\n        [{\"keys\": \"seg\", \"axcodes\": \"RAS\"}, torch.ones((2, 1, 2, 3)), torch.eye(4), (2, 1, 2, 3), \"RAS\", True, *device]\n    )\n    # 3d\n    TESTS.append(\n        [\n            {\"keys\": [\"img\", \"seg\"], \"axcodes\": \"PLI\"},\n            torch.ones((2, 1, 2, 3)),\n            torch.eye(4),\n            (2, 2, 1, 3),\n            \"PLI\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"keys\": [\"img\", \"seg\"], \"axcodes\": \"PLI\"},\n            torch.ones((2, 1, 2, 3)),\n            torch.eye(4),\n            (2, 2, 1, 3),\n            \"PLI\",\n            True,\n            *device,\n        ]\n    )\n    # 2d\n    TESTS.append(\n        [\n          
  {\"keys\": [\"img\", \"seg\"], \"axcodes\": \"PLI\"},\n            torch.ones((2, 1, 3)),\n            torch.eye(4),\n            (2, 3, 1),\n            \"PLS\",\n            False,\n            *device,\n        ]\n    )\n    TESTS.append(\n        [\n            {\"keys\": [\"img\", \"seg\"], \"axcodes\": \"PLI\"},\n            torch.ones((2, 1, 3)),\n            torch.eye(4),\n            (2, 3, 1),\n            \"PLS\",\n            True,\n            *device,\n        ]\n    )\n    # 1d\n    TESTS.append(\n        [{\"keys\": [\"img\", \"seg\"], \"axcodes\": \"L\"}, torch.ones((2, 3)), torch.eye(4), (2, 3), \"LAS\", False, *device]\n    )\n    TESTS.append(\n        [{\"keys\": [\"img\", \"seg\"], \"axcodes\": \"L\"}, torch.ones((2, 3)), torch.eye(4), (2, 3), \"LPS\", True, *device]\n    )\n    # canonical\n    TESTS.append(\n        [\n            {\"keys\": [\"img\", \"seg\"], \"as_closest_canonical\": True},\n            torch.ones((2, 1, 2, 3)),\n            torch.eye(4),\n            (2, 1, 2, 3),\n            \"RAS\",\n            False,\n            *device,\n        ]\n    )\n\nTESTS_TORCH = []\nfor track_meta in (False, True):\n    for device in TEST_DEVICES:\n        TESTS_TORCH.append([{\"keys\": \"seg\", \"axcodes\": \"RAS\"}, torch.ones(2, 1, 2, 3), track_meta, *device])\n\n\nclass TestOrientationdCase(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_orntd(\n        self,\n        init_param,\n        img: torch.Tensor,\n        affine: torch.Tensor | None,\n        expected_shape,\n        expected_code,\n        lps_convention: bool,\n        device,\n    ):\n        ornt = Orientationd(**init_param)\n        if affine is not None:\n            meta = {\"space\": SpaceKeys.LPS} if lps_convention else None\n            img = MetaTensor(img, affine=affine, meta=meta)\n        img = img.to(device)\n        call_param = {\"data\": {k: img.clone() for k in ornt.keys}}\n        res = ornt(**call_param)  # type: ignore[arg-type]\n 
       for k in ornt.keys:\n            if img.ndim in (3, 4):\n                test_resampler_lazy(ornt, res, init_param, call_param, output_key=k)\n            _im = res[k]\n            self.assertIsInstance(_im, MetaTensor)\n            np.testing.assert_allclose(_im.shape, expected_shape)\n            labels = ((\"R\", \"L\"), (\"A\", \"P\"), (\"I\", \"S\")) if lps_convention else ornt.ornt_transform.labels\n            code = nib.aff2axcodes(_im.affine.cpu(), labels)  # type: ignore\n            self.assertEqual(\"\".join(code), expected_code)\n\n    @parameterized.expand(TESTS_TORCH)\n    def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, device):\n        set_track_meta(track_meta)\n        ornt = Orientationd(**init_param)\n        img = img.to(device)\n        expected_shape = img.shape\n        expected_code = ornt.ornt_transform.axcodes\n        call_param = {\"data\": {k: img.clone() for k in ornt.keys}}\n        res = ornt(**call_param)  # type: ignore[arg-type]\n        for k in ornt.keys:\n            _im = res[k]\n            np.testing.assert_allclose(_im.shape, expected_shape)\n            if track_meta:\n                if img.ndim in (3, 4):\n                    test_resampler_lazy(ornt, res, init_param, call_param, output_key=k)\n                self.assertIsInstance(_im, MetaTensor)\n                assert isinstance(_im, MetaTensor)  # for mypy type narrowing\n                code = nib.aff2axcodes(_im.affine.cpu(), ornt.ornt_transform.labels)\n                self.assertEqual(\"\".join(code), expected_code)\n            else:\n                self.assertIsInstance(_im, torch.Tensor)\n                self.assertNotIsInstance(_im, MetaTensor)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_adjust_contrast.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandAdjustContrast\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\nTEST_CASE_1 = [(0.5, 4.5)]\n\nTEST_CASE_2 = [1.5]\n\n\nclass TestRandAdjustContrast(NumpyImageTestCase2D):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_correct_results(self, gamma):\n        adjuster = RandAdjustContrast(prob=1.0, gamma=gamma)\n        for p in TEST_NDARRAYS:\n            im = p(self.imt)\n            result = adjuster(im)\n            epsilon = 1e-7\n            img_min = self.imt.min()\n            img_range = self.imt.max() - img_min\n            expected = (\n                np.power(((self.imt - img_min) / float(img_range + epsilon)), adjuster.gamma_value) * img_range\n                + img_min\n            )\n            assert_allclose(result, expected, rtol=1e-05, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_adjust_contrastd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandAdjustContrastd\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\nTEST_CASE_1 = [(0.5, 4.5)]\n\nTEST_CASE_2 = [1.5]\n\n\nclass TestRandAdjustContrastd(NumpyImageTestCase2D):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_correct_results(self, gamma):\n        adjuster = RandAdjustContrastd(\"img\", prob=1.0, gamma=gamma)\n        for p in TEST_NDARRAYS:\n            result = adjuster({\"img\": p(self.imt)})\n            epsilon = 1e-7\n            img_min = self.imt.min()\n            img_range = self.imt.max() - img_min\n            expected = (\n                np.power(((self.imt - img_min) / float(img_range + epsilon)), adjuster.adjuster.gamma_value) * img_range\n                + img_min\n            )\n            assert_allclose(result[\"img\"], expected, rtol=1e-05, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_affine.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandAffine\nfrom monai.utils.type_conversion import convert_data_type\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose, is_tf32_env\n\n_rtol = 1e-3 if is_tf32_env() else 1e-4\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    for device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n        TESTS.append(\n            [dict(device=device), {\"img\": p(torch.arange(27).reshape((3, 3, 3)))}, p(np.arange(27).reshape((3, 3, 3)))]\n        )\n        TESTS.append(\n            [\n                dict(device=device, spatial_size=-1),\n                {\"img\": p(torch.arange(27).reshape((3, 3, 3)))},\n                p(np.arange(27).reshape((3, 3, 3))),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(device=device),\n                {\"img\": p(torch.arange(27).reshape((3, 3, 3))), \"spatial_size\": (2, 2)},\n                p(np.array([[[2.0, 3.0], [5.0, 6.0]], [[11.0, 12.0], [14.0, 15.0]], [[20.0, 21.0], [23.0, 24.0]]])),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(device=device),\n                {\"img\": p(torch.ones((1, 3, 3, 
3))), \"spatial_size\": (2, 2, 2)},\n                p(torch.ones((1, 2, 2, 2))),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(device=device, spatial_size=(2, 2, 2), cache_grid=True),\n                {\"img\": p(torch.ones((1, 3, 3, 3)))},\n                p(torch.ones((1, 2, 2, 2))),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(\n                    prob=0.9,\n                    rotate_range=(np.pi / 2,),\n                    shear_range=[1, 2],\n                    translate_range=[2, 1],\n                    padding_mode=\"zeros\",\n                    spatial_size=(2, 2, 2),\n                    device=device,\n                ),\n                {\"img\": p(torch.ones((1, 3, 3, 3))), \"mode\": \"bilinear\"},\n                p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(\n                    prob=0.9,\n                    rotate_range=(np.pi / 2,),\n                    shear_range=[1, 2],\n                    translate_range=[2, 1],\n                    padding_mode=\"zeros\",\n                    spatial_size=(2, 2, 2),\n                    cache_grid=True,\n                    device=device,\n                ),\n                {\"img\": p(torch.ones((1, 3, 3, 3))), \"mode\": \"bilinear\"},\n                p(torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]])),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(\n                    prob=0.9,\n                    rotate_range=(np.pi / 2,),\n                    shear_range=[1, 2],\n                    translate_range=[2, 1],\n                    scale_range=[0.1, 0.2],\n                    device=device,\n                ),\n                {\"img\": p(torch.arange(64).reshape((1, 8, 8))), \"spatial_size\": (3, 
3)},\n                p(\n                    torch.tensor(\n                        [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                dict(\n                    prob=0.9,\n                    rotate_range=(np.pi / 2,),\n                    shear_range=[1, 2],\n                    translate_range=[2, 1],\n                    scale_range=[0.1, 0.2],\n                    spatial_size=(3, 3),\n                    cache_grid=True,\n                    device=device,\n                ),\n                {\"img\": p(torch.arange(64).reshape((1, 8, 8)))},\n                p(\n                    torch.tensor(\n                        [[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]\n                    )\n                ),\n            ]\n        )\n\nTEST_CASES_SKIPPED_CONSISTENCY = []\nfor p in TEST_NDARRAYS_ALL:\n    for in_dtype in (np.int32, np.float32):\n        TEST_CASES_SKIPPED_CONSISTENCY.append((p(np.arange(9 * 10).reshape(1, 9, 10)), in_dtype))\n\nTEST_RANDOMIZE = []\nfor cache_grid in (False, True):\n    for initial_randomize in (False, True):\n        TEST_RANDOMIZE.append((initial_randomize, cache_grid))\n\n\nclass TestRandAffine(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_rand_affine(self, input_param, input_data, expected_val):\n        g = RandAffine(**input_param)\n        g.set_random_state(123)\n        result = g(**input_data)\n        g.rand_affine_grid.affine = torch.eye(4, dtype=torch.float64)  # reset affine\n        test_resampler_lazy(g, result, input_param, input_data, seed=123, rtol=_rtol)\n        if input_param.get(\"cache_grid\", False):\n            self.assertTrue(g._cached_grid is not None)\n        assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4, type_test=\"tensor\")\n\n    
@parameterized.expand([(None,), ((1, 1, -1),)])\n    def test_ill_cache(self, spatial_size):\n        with self.assertWarns(UserWarning):\n            RandAffine(cache_grid=True, spatial_size=spatial_size)\n\n    @parameterized.expand(TEST_CASES_SKIPPED_CONSISTENCY)\n    def test_skipped_transform_consistency(self, im, in_dtype):\n        t1 = RandAffine(prob=0)\n        t2 = RandAffine(prob=1, spatial_size=(10, 11))\n\n        im, *_ = convert_data_type(im, dtype=in_dtype)\n\n        out1 = t1(im)\n        out2 = t2(im)\n\n        # check same type\n        self.assertEqual(type(out1), type(out2))\n        # check matching dtype\n        self.assertEqual(out1.dtype, out2.dtype)\n\n    @parameterized.expand(TEST_RANDOMIZE)\n    def test_no_randomize(self, initial_randomize, cache_grid):\n        rand_affine = RandAffine(\n            prob=1,\n            rotate_range=(np.pi / 6, 0, 0),\n            translate_range=((-2, 2), (-2, 2), (-2, 2)),\n            scale_range=((-0.1, 0.1), (-0.1, 0.1), (-0.1, 0.1)),\n            spatial_size=(16, 16, 16),\n            cache_grid=cache_grid,\n            padding_mode=\"zeros\",\n        )\n        if initial_randomize:\n            rand_affine.randomize(None)\n\n        arr = torch.randn((1, 16, 16, 16)) * 100\n\n        arr1 = rand_affine(arr, randomize=False)\n        m1 = rand_affine.rand_affine_grid.get_transformation_matrix()\n\n        arr2 = rand_affine(arr, randomize=False)\n        m2 = rand_affine.rand_affine_grid.get_transformation_matrix()\n\n        assert_allclose(m1, m2)\n        assert_allclose(arr1, arr2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_affine_grid.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandAffineGrid\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose, is_tf32_env\n\n_rtol = 1e-1 if is_tf32_env() else 1e-4\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    for device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n        TESTS.append([{\"device\": device}, {\"grid\": p(torch.ones((3, 3, 3)))}, p(np.ones((3, 3, 3)))])\n        TESTS.append(\n            [\n                {\"rotate_range\": (1, 2), \"translate_range\": (3, 3, 3)},\n                {\"grid\": p(torch.arange(0, 27).reshape((3, 3, 3)))},\n                p(\n                    np.array(\n                        [\n                            [\n                                [-32.81998, -33.910976, -35.001972],\n                                [-36.092968, -37.183964, -38.27496],\n                                [-39.36596, -40.456955, -41.54795],\n                            ],\n                            [\n                                [2.1380205, 3.1015975, 4.0651755],\n                                [5.028752, 5.9923296, 6.955907],\n                                [7.919484, 8.883063, 9.84664],\n                            ],\n                            [[18.0, 19.0, 20.0], 
[21.0, 22.0, 23.0], [24.0, 25.0, 26.0]],\n                        ]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                {\"translate_range\": (3, 3, 3), \"device\": device},\n                {\"spatial_size\": (3, 3, 3)},\n                np.array(\n                    [\n                        [\n                            [\n                                [0.17881513, 0.17881513, 0.17881513],\n                                [0.17881513, 0.17881513, 0.17881513],\n                                [0.17881513, 0.17881513, 0.17881513],\n                            ],\n                            [\n                                [1.1788151, 1.1788151, 1.1788151],\n                                [1.1788151, 1.1788151, 1.1788151],\n                                [1.1788151, 1.1788151, 1.1788151],\n                            ],\n                            [\n                                [2.1788151, 2.1788151, 2.1788151],\n                                [2.1788151, 2.1788151, 2.1788151],\n                                [2.1788151, 2.1788151, 2.1788151],\n                            ],\n                        ],\n                        [\n                            [\n                                [-2.283164, -2.283164, -2.283164],\n                                [-1.283164, -1.283164, -1.283164],\n                                [-0.28316402, -0.28316402, -0.28316402],\n                            ],\n                            [\n                                [-2.283164, -2.283164, -2.283164],\n                                [-1.283164, -1.283164, -1.283164],\n                                [-0.28316402, -0.28316402, -0.28316402],\n                            ],\n                            [\n                                [-2.283164, -2.283164, -2.283164],\n                                [-1.283164, -1.283164, -1.283164],\n                                [-0.28316402, 
-0.28316402, -0.28316402],\n                            ],\n                        ],\n                        [\n                            [\n                                [-2.6388912, -1.6388912, -0.6388912],\n                                [-2.6388912, -1.6388912, -0.6388912],\n                                [-2.6388912, -1.6388912, -0.6388912],\n                            ],\n                            [\n                                [-2.6388912, -1.6388912, -0.6388912],\n                                [-2.6388912, -1.6388912, -0.6388912],\n                                [-2.6388912, -1.6388912, -0.6388912],\n                            ],\n                            [\n                                [-2.6388912, -1.6388912, -0.6388912],\n                                [-2.6388912, -1.6388912, -0.6388912],\n                                [-2.6388912, -1.6388912, -0.6388912],\n                            ],\n                        ],\n                        [\n                            [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],\n                            [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],\n                            [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],\n                        ],\n                    ]\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                {\"device\": device, \"rotate_range\": (1.0, 1.0, 1.0), \"shear_range\": (0.1,), \"scale_range\": (1.2,)},\n                {\"grid\": p(torch.arange(0, 108).reshape((4, 3, 3, 3)))},\n                p(\n                    np.array(\n                        [\n                            [\n                                [\n                                    [-9.4201e00, -8.1672e00, -6.9143e00],\n                                    [-5.6614e00, -4.4085e00, -3.1556e00],\n                                    [-1.9027e00, -6.4980e-01, 6.0310e-01],\n                                ],\n               
                 [\n                                    [1.8560e00, 3.1089e00, 4.3618e00],\n                                    [5.6147e00, 6.8676e00, 8.1205e00],\n                                    [9.3734e00, 1.0626e01, 1.1879e01],\n                                ],\n                                [\n                                    [1.3132e01, 1.4385e01, 1.5638e01],\n                                    [1.6891e01, 1.8144e01, 1.9397e01],\n                                    [2.0650e01, 2.1902e01, 2.3155e01],\n                                ],\n                            ],\n                            [\n                                [\n                                    [9.9383e-02, -4.8845e-01, -1.0763e00],\n                                    [-1.6641e00, -2.2519e00, -2.8398e00],\n                                    [-3.4276e00, -4.0154e00, -4.6032e00],\n                                ],\n                                [\n                                    [-5.1911e00, -5.7789e00, -6.3667e00],\n                                    [-6.9546e00, -7.5424e00, -8.1302e00],\n                                    [-8.7180e00, -9.3059e00, -9.8937e00],\n                                ],\n                                [\n                                    [-1.0482e01, -1.1069e01, -1.1657e01],\n                                    [-1.2245e01, -1.2833e01, -1.3421e01],\n                                    [-1.4009e01, -1.4596e01, -1.5184e01],\n                                ],\n                            ],\n                            [\n                                [\n                                    [5.9635e01, 6.1199e01, 6.2764e01],\n                                    [6.4328e01, 6.5892e01, 6.7456e01],\n                                    [6.9021e01, 7.0585e01, 7.2149e01],\n                                ],\n                                [\n                                    [7.3714e01, 7.5278e01, 7.6842e01],\n                                
    [7.8407e01, 7.9971e01, 8.1535e01],\n                                    [8.3099e01, 8.4664e01, 8.6228e01],\n                                ],\n                                [\n                                    [8.7792e01, 8.9357e01, 9.0921e01],\n                                    [9.2485e01, 9.4049e01, 9.5614e01],\n                                    [9.7178e01, 9.8742e01, 1.0031e02],\n                                ],\n                            ],\n                            [\n                                [\n                                    [8.1000e01, 8.2000e01, 8.3000e01],\n                                    [8.4000e01, 8.5000e01, 8.6000e01],\n                                    [8.7000e01, 8.8000e01, 8.9000e01],\n                                ],\n                                [\n                                    [9.0000e01, 9.1000e01, 9.2000e01],\n                                    [9.3000e01, 9.4000e01, 9.5000e01],\n                                    [9.6000e01, 9.7000e01, 9.8000e01],\n                                ],\n                                [\n                                    [9.9000e01, 1.0000e02, 1.0100e02],\n                                    [1.0200e02, 1.0300e02, 1.0400e02],\n                                    [1.0500e02, 1.0600e02, 1.0700e02],\n                                ],\n                            ],\n                        ]\n                    )\n                ),\n            ]\n        )\n\n\nclass TestRandAffineGrid(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_rand_affine_grid(self, input_param, input_data, expected_val):\n        g = RandAffineGrid(**input_param)\n        g.set_random_state(123)\n        result = g(**input_data)\n        if \"device\" in input_data:\n            self.assertEqual(result.device, input_data[device])\n        assert_allclose(result, expected_val, type_test=False, rtol=_rtol, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_affined.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport itertools\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import RandAffined\nfrom monai.utils import GridSampleMode, ensure_tuple_rep\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import assert_allclose, is_tf32_env\n\n_rtol = 1e-3 if is_tf32_env() else 1e-4\n\nTESTS = []\n\nfor device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n    TESTS.append(\n        [\n            dict(device=device, spatial_size=None, keys=(\"img\", \"seg\")),\n            {\n                \"img\": MetaTensor(torch.arange(27).reshape((3, 3, 3))),\n                \"seg\": MetaTensor(torch.arange(27).reshape((3, 3, 3))),\n            },\n            torch.arange(27).reshape((3, 3, 3)),\n        ]\n    )\n    TESTS.append(\n        [\n            dict(device=device, spatial_size=(2, 2), keys=(\"img\", \"seg\")),\n            {\"img\": MetaTensor(torch.ones((3, 3, 3))), \"seg\": MetaTensor(torch.ones((3, 3, 3)))},\n            torch.ones((3, 2, 2)),\n        ]\n    )\n    TESTS.append(\n        [\n            dict(device=device, spatial_size=(2, 2), cache_grid=True, keys=(\"img\", \"seg\")),\n            {\"img\": MetaTensor(torch.ones((3, 3, 3))), \"seg\": 
MetaTensor(torch.ones((3, 3, 3)))},\n            torch.ones((3, 2, 2)),\n        ]\n    )\n    TESTS.append(\n        [\n            dict(device=device, spatial_size=(2, 2, 2), keys=(\"img\", \"seg\")),\n            {\"img\": MetaTensor(torch.ones((1, 3, 3, 3))), \"seg\": MetaTensor(torch.ones((1, 3, 3, 3)))},\n            torch.ones((1, 2, 2, 2)),\n        ]\n    )\n    TESTS.append(\n        [\n            dict(\n                prob=0.9,\n                rotate_range=(np.pi / 2,),\n                shear_range=[1, 2],\n                translate_range=[2, 1],\n                spatial_size=(2, 2, 2),\n                padding_mode=\"zeros\",\n                device=device,\n                keys=(\"img\", \"seg\"),\n                mode=\"bilinear\",\n            ),\n            {\"img\": MetaTensor(torch.ones((1, 3, 3, 3))), \"seg\": MetaTensor(torch.ones((1, 3, 3, 3)))},\n            torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]),\n        ]\n    )\n    TESTS.append(\n        [\n            dict(\n                prob=0.9,\n                rotate_range=(np.pi / 2,),\n                shear_range=[1, 2],\n                translate_range=[2, 1],\n                scale_range=[0.1, 0.2],\n                spatial_size=(3, 3),\n                keys=(\"img\", \"seg\"),\n                device=device,\n            ),\n            {\n                \"img\": MetaTensor(torch.arange(64).reshape((1, 8, 8))),\n                \"seg\": MetaTensor(torch.arange(64).reshape((1, 8, 8))),\n            },\n            torch.tensor([[[18.7362, 15.5820, 12.4278], [27.3988, 24.2446, 21.0904], [36.0614, 32.9072, 29.7530]]]),\n        ]\n    )\n    TESTS.append(\n        [\n            dict(\n                prob=0.9,\n                mode=(\"bilinear\", \"nearest\"),\n                rotate_range=(np.pi / 2,),\n                shear_range=[1, 2],\n                translate_range=[2, 1],\n                scale_range=[0.1, 0.2],\n                
spatial_size=(3, 3),\n                keys=(\"img\", \"seg\"),\n                device=device,\n            ),\n            {\n                \"img\": MetaTensor(torch.arange(64).reshape((1, 8, 8))),\n                \"seg\": MetaTensor(torch.arange(64).reshape((1, 8, 8))),\n            },\n            {\n                \"img\": MetaTensor(\n                    torch.tensor(\n                        [\n                            [\n                                [18.736153, 15.581954, 12.4277525],\n                                [27.398798, 24.244598, 21.090399],\n                                [36.061443, 32.90724, 29.753046],\n                            ]\n                        ]\n                    )\n                ),\n                \"seg\": MetaTensor(torch.tensor([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])),\n            },\n        ]\n    )\n    TESTS.append(\n        [\n            dict(\n                prob=0.9,\n                rotate_range=(np.pi / 2,),\n                shear_range=[1, 2],\n                translate_range=[2, 1],\n                spatial_size=(2, 2, 2),\n                padding_mode=\"zeros\",\n                device=device,\n                keys=(\"img\", \"seg\"),\n                mode=GridSampleMode.BILINEAR,\n            ),\n            {\"img\": MetaTensor(torch.ones((1, 3, 3, 3))), \"seg\": MetaTensor(torch.ones((1, 3, 3, 3)))},\n            torch.tensor([[[[0.3658, 1.0000], [1.0000, 1.0000]], [[1.0000, 1.0000], [1.0000, 0.9333]]]]),\n        ]\n    )\n    TESTS.append(\n        [\n            dict(\n                prob=0.9,\n                mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST),\n                rotate_range=(np.pi / 2,),\n                shear_range=[1, 2],\n                translate_range=[2, 1],\n                scale_range=[0.1, 0.2],\n                spatial_size=(3, 3),\n                keys=(\"img\", \"seg\"),\n                device=device,\n            ),\n            
{\n                \"img\": MetaTensor(torch.arange(64).reshape((1, 8, 8))),\n                \"seg\": MetaTensor(torch.arange(64).reshape((1, 8, 8))),\n            },\n            {\n                \"img\": MetaTensor(\n                    np.array(\n                        [\n                            [\n                                [18.736153, 15.581954, 12.4277525],\n                                [27.398798, 24.244598, 21.090399],\n                                [36.061443, 32.90724, 29.753046],\n                            ]\n                        ]\n                    )\n                ),\n                \"seg\": MetaTensor(np.array([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])),\n            },\n        ]\n    )\n    TESTS.append(\n        [\n            dict(\n                prob=0.9,\n                mode=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST),\n                rotate_range=(np.pi / 2,),\n                shear_range=[1, 2],\n                translate_range=[2, 1],\n                scale_range=[0.1, 0.2],\n                spatial_size=(3, 3),\n                cache_grid=True,\n                keys=(\"img\", \"seg\"),\n                device=device,\n            ),\n            {\n                \"img\": MetaTensor(torch.arange(64).reshape((1, 8, 8))),\n                \"seg\": MetaTensor(torch.arange(64).reshape((1, 8, 8))),\n            },\n            {\n                \"img\": MetaTensor(\n                    torch.tensor(\n                        [\n                            [\n                                [18.736153, 15.581954, 12.4277525],\n                                [27.398798, 24.244598, 21.090399],\n                                [36.061443, 32.90724, 29.753046],\n                            ]\n                        ]\n                    )\n                ),\n                \"seg\": MetaTensor(torch.tensor([[[19.0, 20.0, 12.0], [27.0, 28.0, 20.0], [35.0, 36.0, 29.0]]])),\n            
},\n        ]\n    )\n\n\nclass TestRandAffined(unittest.TestCase):\n    @parameterized.expand(x + [y] for x, y in itertools.product(TESTS, (False, True)))\n    def test_rand_affined(self, input_param, input_data, expected_val, track_meta):\n        set_track_meta(track_meta)\n        g = RandAffined(**input_param).set_random_state(123)\n        call_param = {\"data\": input_data}\n        res = g(**call_param)\n        # test lazy\n        if track_meta and input_data[\"img\"].ndim in (3, 4):\n            if \"mode\" not in input_param.keys():\n                input_param[\"mode\"] = \"bilinear\"\n            if not isinstance(input_param[\"keys\"], str):\n                input_param[\"mode\"] = ensure_tuple_rep(input_param[\"mode\"], len(input_param[\"keys\"]))\n            lazy_init_param = input_param.copy()\n            for key, mode in zip(input_param[\"keys\"], input_param[\"mode\"]):\n                lazy_init_param[\"keys\"], lazy_init_param[\"mode\"] = key, mode\n                resampler = RandAffined(**lazy_init_param).set_random_state(123)\n                expected_output = resampler(**call_param)\n                test_resampler_lazy(\n                    resampler, expected_output, lazy_init_param, call_param, seed=123, output_key=key, rtol=_rtol\n                )\n            resampler.lazy = False\n\n        if input_param.get(\"cache_grid\", False):\n            self.assertIsNotNone(g.rand_affine._cached_grid)\n        for key in res:\n            if isinstance(key, str) and key.endswith(\"_transforms\"):\n                continue\n            result = res[key]\n            if track_meta:\n                self.assertIsInstance(result, MetaTensor)\n                self.assertEqual(len(result.applied_operations), 1)\n            expected = expected_val[key] if isinstance(expected_val, dict) else expected_val\n            assert_allclose(result, expected, rtol=_rtol, atol=1e-3, type_test=False)\n\n        g.set_random_state(4)\n        res = 
g(**call_param)\n        if not track_meta:\n            return\n\n        # affine should be tensor because the resampler only supports pytorch backend\n        if isinstance(res[\"img\"], MetaTensor) and \"extra_info\" in res[\"img\"].applied_operations[0]:\n            if not res[\"img\"].applied_operations[-1][\"extra_info\"]:\n                return\n            if not res[\"img\"].applied_operations[-1][\"extra_info\"][\"extra_info\"][\"do_resampling\"]:\n                return\n            affine_img = res[\"img\"].applied_operations[0][\"extra_info\"][\"extra_info\"][\"affine\"]\n            affine_seg = res[\"seg\"].applied_operations[0][\"extra_info\"][\"extra_info\"][\"affine\"]\n            assert_allclose(affine_img, affine_seg, rtol=_rtol, atol=1e-3)\n\n        res_inv = g.inverse(res)\n        for k, v in res_inv.items():\n            self.assertIsInstance(v, MetaTensor)\n            self.assertEqual(len(v.applied_operations), 0)\n            self.assertTupleEqual(v.shape, input_data[k].shape)\n\n    @parameterized.expand([(None,), ((2, -1),)])  # spatial size is None  # spatial size is dynamic\n    def test_ill_cache(self, spatial_size):\n        with self.assertWarns(UserWarning):\n            RandAffined(device=device, spatial_size=spatial_size, prob=1.0, cache_grid=True, keys=(\"img\", \"seg\"))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_axis_flip.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import RandAxisFlip\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion\n\n\nclass TestRandAxisFlip(NumpyImageTestCase2D):\n    def test_correct_results(self):\n        for p in TEST_NDARRAYS_ALL:\n            flip = RandAxisFlip(prob=1.0)\n            flip.set_random_state(seed=321)\n            im = p(self.imt[0])\n            call_param = {\"img\": im}\n            result = flip(**call_param)\n\n            # test lazy\n            test_resampler_lazy(flip, result, call_param=call_param, seed=321)\n            flip.lazy = False\n\n            expected = [np.flip(channel, flip._axis) for channel in self.imt[0]]\n            assert_allclose(result, p(np.stack(expected)), type_test=\"tensor\")\n            test_local_inversion(flip, result, im)\n\n            set_track_meta(False)\n            result = flip(im)\n            self.assertNotIsInstance(result, MetaTensor)\n            self.assertIsInstance(result, torch.Tensor)\n            set_track_meta(True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_axis_flipd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import RandAxisFlipd\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase3D, assert_allclose, test_local_inversion\n\n\nclass TestRandAxisFlip(NumpyImageTestCase3D):\n    def test_correct_results(self):\n        for p in TEST_NDARRAYS_ALL:\n            flip = RandAxisFlipd(keys=\"img\", prob=1.0)\n            flip.set_random_state(seed=1234)\n            im = p(self.imt[0])\n            call_param = {\"data\": {\"img\": im}}\n            result = flip(**call_param)\n\n            # test lazy\n            test_resampler_lazy(flip, result, call_param=call_param, output_key=\"img\", seed=1234)\n            flip.lazy = False\n\n            test_local_inversion(flip, result, {\"img\": im}, \"img\")\n            expected = [np.flip(channel, flip.flipper._axis) for channel in self.imt[0]]\n            assert_allclose(result[\"img\"], p(np.stack(expected)), type_test=\"tensor\")\n\n            set_track_meta(False)\n            result = flip({\"img\": im})[\"img\"]\n            self.assertNotIsInstance(result, MetaTensor)\n            self.assertIsInstance(result, torch.Tensor)\n            set_track_meta(True)\n\n\nif __name__ == 
\"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_bias_field.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandBiasField\nfrom tests.test_utils import TEST_NDARRAYS\n\nTEST_CASES_2D = [{\"prob\": 1.0}, (3, 32, 32)]\nTEST_CASES_3D = [{\"prob\": 1.0}, (3, 32, 32, 32)]\nTEST_CASES_2D_ZERO_RANGE = [{\"prob\": 1.0, \"coeff_range\": (0.0, 0.0)}, (2, 3, 3)]\nTEST_CASES_2D_ONES = [\n    {\"prob\": 1.0, \"coeff_range\": (1.0, 1.0)},\n    np.asarray([[[7.389056, 0.1353353], [7.389056, 22026.46]]]),\n]\n\n\nclass TestRandBiasField(unittest.TestCase):\n    @parameterized.expand([TEST_CASES_2D, TEST_CASES_3D])\n    def test_output_shape(self, class_args, img_shape):\n        for p in TEST_NDARRAYS:\n            for degree in [1, 2, 3]:\n                bias_field = RandBiasField(degree=degree, **class_args)\n                img = p(np.random.rand(*img_shape))\n                output = bias_field(img)\n                np.testing.assert_equal(output.shape, img_shape)\n                self.assertIn(output.dtype, (np.float32, torch.float32))\n\n                img_zero = np.zeros([*img_shape])\n                output_zero = bias_field(img_zero)\n                np.testing.assert_equal(output_zero, img_zero)\n\n    @parameterized.expand([TEST_CASES_2D_ZERO_RANGE])\n    def test_zero_range(self, class_args, img_shape):\n        
bias_field = RandBiasField(**class_args)\n        img = np.ones(img_shape)\n        output = bias_field(img)\n        np.testing.assert_allclose(output, np.ones(img_shape), rtol=1e-3)\n\n    @parameterized.expand([TEST_CASES_2D_ONES])\n    def test_one_range_input(self, class_args, expected):\n        bias_field = RandBiasField(**class_args)\n        img = np.ones([1, 2, 2])\n        output = bias_field(img)\n        np.testing.assert_allclose(output, expected.astype(bias_field.dtype), rtol=1e-3)\n\n    def test_zero_prob(self):\n        bias_field = RandBiasField(prob=0.0)\n        img = np.random.rand(3, 32, 32)\n        output = bias_field(img)\n        np.testing.assert_equal(output, img)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_bias_fieldd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandBiasFieldd\n\nTEST_CASES_2D = [{\"prob\": 1.0}, (3, 32, 32)]\nTEST_CASES_3D = [{\"prob\": 1.0}, (3, 32, 32, 32)]\nTEST_CASES_2D_ZERO_RANGE = [{\"prob\": 1.0, \"coeff_range\": (0.0, 0.0)}, (3, 32, 32)]\nTEST_CASES_2D_ONES = [\n    {\"prob\": 1.0, \"coeff_range\": (1.0, 1.0)},\n    np.asarray([[[7.3890562e00, 1.3533528e-01], [7.3890562e00, 2.2026465e04]]]),\n]\n\n\nclass TestRandBiasFieldd(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASES_2D, TEST_CASES_3D])\n    def test_output_shape(self, class_args, img_shape):\n        key = \"img\"\n        bias_field = RandBiasFieldd(keys=[key], **class_args)\n        img = np.random.rand(*img_shape)\n        output = bias_field({key: img})\n        np.testing.assert_equal(output[key].shape, img_shape)\n\n    @parameterized.expand([TEST_CASES_2D_ZERO_RANGE])\n    def test_zero_range(self, class_args, img_shape):\n        key = \"img\"\n        bias_field = RandBiasFieldd(keys=[key], **class_args)\n        img = np.ones(img_shape)\n        output = bias_field({key: img})\n        np.testing.assert_allclose(output[key], np.ones(img_shape))\n\n    @parameterized.expand([TEST_CASES_2D_ONES])\n    def test_one_range_input(self, class_args, expected):\n        key = \"img\"\n     
   bias_field = RandBiasFieldd(keys=[key], **class_args)\n        img = np.ones([1, 2, 2])\n        output = bias_field({key: img})\n        np.testing.assert_allclose(output[key], expected.astype(bias_field.rand_bias_field.dtype), rtol=1e-3)\n\n    def test_zero_prob(self):\n        key = \"img\"\n        bias_field = RandBiasFieldd(keys=[key], prob=0.0)\n        img = np.random.rand(3, 32, 32)\n        output = bias_field({key: img})\n        np.testing.assert_equal(output[key], img)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_coarse_dropout.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandCoarseDropout\nfrom monai.utils import fall_back_tuple\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTEST_CASE_0 = [\n    {\"holes\": 2, \"spatial_size\": [2, 2, 2], \"fill_value\": 5, \"prob\": 1.0},\n    np.random.randint(0, 2, size=[3, 3, 3, 4]),\n]\n\nTEST_CASE_1 = [\n    {\"holes\": 1, \"spatial_size\": [1, 2, 3], \"fill_value\": 5, \"max_holes\": 5, \"prob\": 1.0},\n    np.random.randint(0, 2, size=[3, 3, 3, 4]),\n]\n\nTEST_CASE_2 = [\n    {\"holes\": 2, \"spatial_size\": [2, 2, 2], \"fill_value\": 5, \"max_spatial_size\": [4, 4, 3], \"prob\": 1.0},\n    np.random.randint(0, 2, size=[3, 3, 3, 4]),\n]\n\nTEST_CASE_3 = [\n    {\"holes\": 2, \"spatial_size\": [2, -1, 2], \"fill_value\": 5, \"max_spatial_size\": [4, 4, -1], \"prob\": 1.0},\n    np.random.randint(0, 2, size=[3, 3, 3, 4]),\n]\n\nTEST_CASE_4 = [\n    {\"holes\": 2, \"spatial_size\": [2, 2, 2], \"fill_value\": (3, 6), \"prob\": 1.0},\n    np.random.randint(0, 2, size=[3, 3, 3, 4]),\n]\n\nTEST_CASE_5 = [\n    {\"holes\": 2, \"spatial_size\": [2, 2, 2], \"fill_value\": None, \"prob\": 1.0},\n    np.random.randint(0, 2, size=[3, 3, 3, 4]),\n]\n\nTEST_CASE_6 = [\n    {\"holes\": 2, \"spatial_size\": [2, 2, 2], \"dropout_holes\": 
False, \"fill_value\": (3, 6), \"prob\": 1.0},\n    np.random.randint(0, 2, size=[3, 3, 3, 4]),\n]\n\nTEST_CASE_7 = [\n    {\"holes\": 2, \"spatial_size\": [2, 2, 2], \"dropout_holes\": False, \"fill_value\": (3, 6), \"prob\": 1.0},\n    torch.randint(0, 2, size=[3, 3, 3, 4]),\n]\n\n\nclass TestRandCoarseDropout(unittest.TestCase):\n    @parameterized.expand(\n        [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]\n    )\n    def test_value(self, input_param, input_data):\n        for p in TEST_NDARRAYS:\n            dropout = RandCoarseDropout(**input_param)\n            im = p(input_data)\n            result = dropout(im)\n        holes = input_param.get(\"holes\")\n        max_holes = input_param.get(\"max_holes\")\n        spatial_size = fall_back_tuple(input_param.get(\"spatial_size\"), input_data.shape[1:])\n        max_spatial_size = fall_back_tuple(input_param.get(\"max_spatial_size\"), input_data.shape[1:])\n\n        if max_holes is None:\n            self.assertEqual(len(dropout.hole_coords), holes)\n        else:\n            self.assertGreaterEqual(len(dropout.hole_coords), holes)\n            self.assertLessEqual(len(dropout.hole_coords), max_holes)\n\n        for h in dropout.hole_coords:\n            data = result[h]\n            # test hole value\n            if input_param.get(\"dropout_holes\", True):\n                fill_value = input_param.get(\"fill_value\", None)\n                if isinstance(fill_value, (int, float)):\n                    assert_allclose(data, fill_value, type_test=False)\n                elif fill_value is not None:\n                    min_value = data.min()\n                    max_value = data.max()\n                    self.assertGreaterEqual(max_value, min_value)\n                    self.assertGreaterEqual(min_value, fill_value[0])\n                    self.assertLess(max_value, fill_value[1])\n            else:\n                assert_allclose(data, 
input_data[h], type_test=False)\n\n            if max_spatial_size is None:\n                self.assertTupleEqual(data.shape[1:], tuple(spatial_size))\n            else:\n                for d, s, m in zip(data.shape[1:], spatial_size, max_spatial_size):\n                    self.assertGreaterEqual(d, s)\n                    self.assertLessEqual(d, m)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_coarse_dropoutd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandCoarseDropoutd\nfrom monai.utils import fall_back_tuple\n\nTEST_CASE_0 = [\n    {\"keys\": \"img\", \"holes\": 2, \"spatial_size\": [2, 2, 2], \"fill_value\": 5, \"prob\": 1.0},\n    {\"img\": np.random.randint(0, 2, size=[3, 3, 3, 4])},\n]\n\nTEST_CASE_1 = [\n    {\"keys\": \"img\", \"holes\": 1, \"spatial_size\": [1, 2, 3], \"fill_value\": 5, \"max_holes\": 5, \"prob\": 1.0},\n    {\"img\": np.random.randint(0, 2, size=[3, 3, 3, 4])},\n]\n\nTEST_CASE_2 = [\n    {\"keys\": \"img\", \"holes\": 2, \"spatial_size\": [2, 2, 2], \"fill_value\": 5, \"max_spatial_size\": [4, 4, 3], \"prob\": 1.0},\n    {\"img\": np.random.randint(0, 2, size=[3, 3, 3, 4])},\n]\n\nTEST_CASE_3 = [\n    {\n        \"keys\": \"img\",\n        \"holes\": 2,\n        \"spatial_size\": [2, -1, 2],\n        \"fill_value\": 5,\n        \"max_spatial_size\": [4, 4, -1],\n        \"prob\": 1.0,\n    },\n    {\"img\": np.random.randint(0, 2, size=[3, 3, 3, 4])},\n]\n\nTEST_CASE_4 = [\n    {\"keys\": \"img\", \"holes\": 2, \"spatial_size\": [2, 2, 2], \"fill_value\": (0.2, 0.6), \"prob\": 1.0},\n    {\"img\": np.random.rand(3, 3, 3, 4)},\n]\n\nTEST_CASE_5 = [\n    {\"keys\": \"img\", \"holes\": 2, \"spatial_size\": [2, 2, 2], \"fill_value\": None, \"prob\": 
1.0},\n    {\"img\": np.random.rand(3, 3, 3, 4)},\n]\n\nTEST_CASE_6 = [\n    {\"keys\": \"img\", \"holes\": 2, \"spatial_size\": [2, 2, 2], \"dropout_holes\": False, \"fill_value\": 0.5, \"prob\": 1.0},\n    {\"img\": np.random.rand(3, 3, 3, 4)},\n]\n\n\nclass TestRandCoarseDropoutd(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])\n    def test_value(self, input_param, input_data):\n        dropout = RandCoarseDropoutd(**input_param)\n        result = dropout(input_data)[\"img\"]\n        holes = input_param.get(\"holes\")\n        max_holes = input_param.get(\"max_holes\")\n        spatial_size = fall_back_tuple(input_param.get(\"spatial_size\"), input_data[\"img\"].shape[1:])\n        max_spatial_size = fall_back_tuple(input_param.get(\"max_spatial_size\"), input_data[\"img\"].shape[1:])\n\n        if max_holes is None:\n            self.assertEqual(len(dropout.dropper.hole_coords), holes)\n        else:\n            self.assertGreaterEqual(len(dropout.dropper.hole_coords), holes)\n            self.assertLessEqual(len(dropout.dropper.hole_coords), max_holes)\n\n        for h in dropout.dropper.hole_coords:\n            data = result[h]\n            # test hole value\n            if input_param.get(\"dropout_holes\", True):\n                fill_value = input_param.get(\"fill_value\", 0)\n                if isinstance(fill_value, (int, float)):\n                    np.testing.assert_allclose(data, fill_value)\n                elif fill_value is not None:\n                    min_value = data.min()\n                    max_value = data.max()\n                    self.assertGreaterEqual(max_value, min_value)\n                    self.assertGreaterEqual(min_value, fill_value[0])\n                    self.assertLess(max_value, fill_value[1])\n            else:\n                np.testing.assert_allclose(data, input_data[\"img\"][h])\n\n            if max_spatial_size is 
None:\n                self.assertTupleEqual(data.shape[1:], tuple(spatial_size))\n            else:\n                for d, s, m in zip(data.shape[1:], spatial_size, max_spatial_size):\n                    self.assertGreaterEqual(d, s)\n                    self.assertLessEqual(d, m)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_coarse_shuffle.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandCoarseShuffle\n\nTEST_CASES = [\n    [\n        {\"holes\": 5, \"spatial_size\": 1, \"max_spatial_size\": -1, \"prob\": 0.0},\n        {\"img\": np.arange(8).reshape((1, 2, 2, 2))},\n        np.arange(8).reshape((1, 2, 2, 2)),\n    ],\n    [\n        {\"holes\": 10, \"spatial_size\": 1, \"max_spatial_size\": -1, \"prob\": 1.0},\n        {\"img\": np.arange(27).reshape((1, 3, 3, 3))},\n        np.asarray(\n            [\n                [\n                    [[8, 19, 26], [24, 6, 15], [0, 13, 25]],\n                    [[17, 3, 5], [10, 1, 12], [22, 4, 11]],\n                    [[21, 20, 23], [14, 2, 16], [18, 9, 7]],\n                ]\n            ]\n        ),\n    ],\n    [\n        {\"holes\": 2, \"spatial_size\": 1, \"max_spatial_size\": -1, \"prob\": 1.0},\n        {\"img\": np.arange(16).reshape((2, 2, 2, 2))},\n        np.asarray([[[[6, 1], [4, 3]], [[0, 2], [7, 5]]], [[[14, 10], [9, 8]], [[12, 15], [13, 11]]]]),\n    ],\n    [\n        {\"holes\": 2, \"spatial_size\": 1, \"max_spatial_size\": -1, \"prob\": 1.0},\n        {\"img\": torch.arange(16).reshape((2, 2, 2, 2))},\n        torch.as_tensor([[[[6, 1], [4, 3]], [[0, 2], [7, 5]]], [[[14, 10], [9, 8]], [[12, 15], [13, 11]]]]),\n    
],\n]\n\n\nclass TestRandCoarseShuffle(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_shuffle(self, input_param, input_data, expected_val):\n        g = RandCoarseShuffle(**input_param)\n        g.set_random_state(seed=12)\n        result = g(**input_data)\n        np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_coarse_shuffled.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandCoarseShuffled\n\nTEST_CASES = [\n    [\n        {\"keys\": \"img\", \"holes\": 5, \"spatial_size\": 1, \"max_spatial_size\": -1, \"prob\": 0.0},\n        {\"img\": np.arange(8).reshape((1, 2, 2, 2))},\n        np.arange(8).reshape((1, 2, 2, 2)),\n    ],\n    [\n        {\"keys\": \"img\", \"holes\": 10, \"spatial_size\": 1, \"max_spatial_size\": -1, \"prob\": 1.0},\n        {\"img\": np.arange(27).reshape((1, 3, 3, 3))},\n        np.asarray(\n            [\n                [\n                    [[8, 19, 26], [24, 6, 15], [0, 13, 25]],\n                    [[17, 3, 5], [10, 1, 12], [22, 4, 11]],\n                    [[21, 20, 23], [14, 2, 16], [18, 9, 7]],\n                ]\n            ]\n        ),\n    ],\n    [\n        {\"keys\": \"img\", \"holes\": 2, \"spatial_size\": 1, \"max_spatial_size\": -1, \"prob\": 1.0},\n        {\"img\": np.arange(16).reshape((2, 2, 2, 2))},\n        np.asarray([[[[6, 1], [4, 3]], [[0, 2], [7, 5]]], [[[14, 10], [9, 8]], [[12, 15], [13, 11]]]]),\n    ],\n]\n\n\nclass TestRandCoarseShuffled(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_shuffle(self, input_param, input_data, expected_val):\n        g = RandCoarseShuffled(**input_param)\n        
g.set_random_state(seed=12)\n        result = g(input_data)\n        np.testing.assert_allclose(result[\"img\"], expected_val, rtol=1e-4, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_crop_by_label_classes.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import ClassesToIndices, RandCropByLabelClasses\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTESTS_INDICES, TESTS_SHAPE = [], []\nfor p in TEST_NDARRAYS_ALL:\n    # One-Hot label\n    TESTS_INDICES.append(\n        [\n            {\n                \"label\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"num_classes\": None,\n                \"spatial_size\": [2, 2, -1],\n                \"ratios\": [1, 1, 1],\n                \"num_samples\": 2,\n                \"image\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"image_threshold\": 0,\n            },\n            {\"img\": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))},\n            list,\n            (3, 2, 2, 3),\n        ]\n    )\n\n    TESTS_INDICES.append(\n        [\n            # Argmax label\n            {\n                \"label\": p(np.random.randint(0, 2, size=[1, 3, 3, 3])),\n                \"num_classes\": 2,\n                \"spatial_size\": [2, 2, 2],\n                \"ratios\": [1, 1],\n                \"num_samples\": 2,\n                \"image\": p(np.random.randint(0, 2, size=[3, 3, 3, 
3])),\n                \"image_threshold\": 0,\n            },\n            {\"img\": p(np.random.randint(0, 2, size=[3, 3, 3, 3]))},\n            list,\n            (3, 2, 2, 2),\n        ]\n    )\n\n    TESTS_SHAPE.append(\n        [\n            # provide label at runtime\n            {\n                \"label\": None,\n                \"num_classes\": 2,\n                \"spatial_size\": [2, 2, 2],\n                \"ratios\": [1, 1],\n                \"num_samples\": 2,\n                \"image\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"image_threshold\": 0,\n            },\n            {\n                \"img\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"label\": p(np.random.randint(0, 2, size=[1, 3, 3, 3])),\n                \"image\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n            },\n            list,\n            (3, 2, 2, 2),\n        ]\n    )\n    TESTS_SHAPE.append(\n        [\n            # provide label at runtime\n            {\n                \"label\": None,\n                \"num_classes\": 2,\n                \"spatial_size\": [4, 4, 2],\n                \"ratios\": [1, 1],\n                \"num_samples\": 2,\n                \"image\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"image_threshold\": 0,\n                \"allow_smaller\": True,\n            },\n            {\n                \"img\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"label\": p(np.random.randint(0, 2, size=[1, 3, 3, 3])),\n                \"image\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n            },\n            list,\n            (3, 3, 3, 2),\n        ]\n    )\n    TESTS_SHAPE.append(\n        [\n            # provide label at runtime\n            {\n                \"label\": None,\n                \"num_classes\": 2,\n                \"spatial_size\": [4, 4, 4],\n                \"ratios\": (1, 1),  # test no assignment\n                
\"num_samples\": 2,\n                \"image\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"image_threshold\": 0,\n                \"allow_smaller\": True,\n            },\n            {\n                \"img\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"label\": p(np.random.randint(0, 1, size=[1, 3, 3, 3])),\n                \"image\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n            },\n            list,\n            (3, 3, 3, 3),\n        ]\n    )\n\n\nclass TestRandCropByLabelClasses(unittest.TestCase):\n    @parameterized.expand(TESTS_INDICES + TESTS_SHAPE)\n    def test_type_shape(self, input_param, input_data, expected_type, expected_shape):\n        result = RandCropByLabelClasses(**input_param)(**input_data)\n        self.assertIsInstance(result, expected_type)\n        self.assertTupleEqual(result[0].shape, expected_shape)\n\n    @parameterized.expand(TESTS_INDICES)\n    def test_indices(self, input_param, input_data, expected_type, expected_shape):\n        input_param[\"indices\"] = ClassesToIndices(num_classes=input_param[\"num_classes\"])(input_param[\"label\"])\n        result = RandCropByLabelClasses(**input_param)(**input_data)\n        self.assertIsInstance(result, expected_type)\n        self.assertTupleEqual(result[0].shape, expected_shape)\n        # test set indices at runtime\n        input_data[\"indices\"] = input_param[\"indices\"]\n        result = RandCropByLabelClasses(**input_param)(**input_data)\n        self.assertIsInstance(result, expected_type)\n        self.assertTupleEqual(result[0].shape, expected_shape)\n\n    @parameterized.expand(TESTS_INDICES + TESTS_SHAPE)\n    def test_pending_ops(self, input_param, input_data, _expected_type, _expected_shape):\n        cropper = RandCropByLabelClasses(**input_param)\n        # non-lazy\n        cropper.set_random_state(0)\n        expected = cropper(**input_data)\n        self.assertIsInstance(expected[0], MetaTensor)\n        # 
lazy\n        cropper.set_random_state(0)\n        cropper.lazy = True\n        pending_result = cropper(**input_data)\n        for i, _pending_result in enumerate(pending_result):\n            self.assertIsInstance(_pending_result, MetaTensor)\n            assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine)\n            assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:])\n            # only support nearest\n            result = apply_pending(_pending_result, overrides={\"mode\": \"nearest\", \"align_corners\": False})[0]\n            # compare\n            assert_allclose(result, expected[i], rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_crop_by_label_classesd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import ClassesToIndicesd, RandCropByLabelClassesd\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    TESTS.append(\n        [\n            # One-Hot label\n            {\n                \"keys\": \"img\",\n                \"label_key\": \"label\",\n                \"num_classes\": None,\n                \"spatial_size\": [2, 2, -1],\n                \"ratios\": [1, 1, 1],\n                \"num_samples\": 2,\n                \"image_key\": \"image\",\n                \"image_threshold\": 0,\n            },\n            {\n                \"img\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"image\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"label\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n            },\n            list,\n            (3, 2, 2, 3),\n        ]\n    )\n\n    TESTS.append(\n        [\n            # Argmax label\n            {\n                \"keys\": \"img\",\n                \"label_key\": \"label\",\n                \"num_classes\": 2,\n                \"spatial_size\": [2, 2, 2],\n               
 \"ratios\": [1, 1],\n                \"num_samples\": 2,\n                \"image_key\": \"image\",\n                \"image_threshold\": 0,\n            },\n            {\n                \"img\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"image\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"label\": p(np.random.randint(0, 2, size=[1, 3, 3, 3])),\n            },\n            list,\n            (3, 2, 2, 2),\n        ]\n    )\n\n    TESTS.append(\n        [\n            # Argmax label\n            {\n                \"keys\": \"img\",\n                \"label_key\": \"label\",\n                \"num_classes\": 2,\n                \"spatial_size\": [4, 4, 2],\n                \"ratios\": (1, 1),  # test no assignment\n                \"num_samples\": 2,\n                \"image_key\": \"image\",\n                \"image_threshold\": 0,\n                \"allow_smaller\": True,\n            },\n            {\n                \"img\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"image\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"label\": p(np.random.randint(0, 1, size=[1, 3, 3, 3])),\n            },\n            list,\n            (3, 3, 3, 2),\n        ]\n    )\n\n    TESTS.append(\n        [\n            # Argmax label\n            {\n                \"keys\": \"img\",\n                \"label_key\": \"label\",\n                \"num_classes\": 2,\n                \"spatial_size\": [4, 4, 4],\n                \"ratios\": [1, 1],\n                \"num_samples\": 2,\n                \"image_key\": \"image\",\n                \"image_threshold\": 0,\n                \"allow_smaller\": True,\n                \"max_samples_per_class\": 10,\n            },\n            {\n                \"img\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"image\": p(np.random.randint(0, 2, size=[3, 3, 3, 3])),\n                \"label\": p(np.random.randint(0, 2, size=[1, 3, 
3, 3])),\n            },\n            list,\n            (3, 3, 3, 3),\n        ]\n    )\n\n\nclass TestRandCropByLabelClassesd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_type_shape(self, input_param, input_data, expected_type, expected_shape):\n        result = RandCropByLabelClassesd(**input_param)(input_data)\n        self.assertIsInstance(result, expected_type)\n        self.assertTupleEqual(result[0][\"img\"].shape, expected_shape)\n        # test with pre-computed indices\n        input_data = ClassesToIndicesd(keys=\"label\", num_classes=input_param[\"num_classes\"])(input_data)\n        input_param[\"indices_key\"] = \"label_cls_indices\"\n        result = RandCropByLabelClassesd(**input_param)(input_data)\n        self.assertIsInstance(result, expected_type)\n        self.assertTupleEqual(result[0][\"img\"].shape, expected_shape)\n        _len = len(tuple(input_data.keys())) - 1  # except for the indices_key\n        self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())[:-1])\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, input_param, input_data, _expected_type, _expected_shape):\n        cropper = RandCropByLabelClassesd(**input_param)\n        # non-lazy\n        cropper.set_random_state(0)\n        expected = cropper(input_data)\n        self.assertIsInstance(expected[0][\"img\"], MetaTensor)\n        # lazy\n        cropper.set_random_state(0)\n        cropper.lazy = True\n        pending_result = cropper(input_data)\n        for i, _pending_result in enumerate(pending_result):\n            self.assertIsInstance(_pending_result[\"img\"], MetaTensor)\n            assert_allclose(_pending_result[\"img\"].peek_pending_affine(), expected[i][\"img\"].affine)\n            assert_allclose(_pending_result[\"img\"].peek_pending_shape(), expected[i][\"img\"].shape[1:])\n            # only support nearest\n            result = apply_pending(_pending_result[\"img\"], overrides={\"mode\": 
\"nearest\", \"align_corners\": False})[0]\n            # compare\n            assert_allclose(result, expected[i][\"img\"], rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_crop_by_pos_neg_label.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import RandCropByPosNegLabel\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTESTS = [\n    [\n        {\n            \"label\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"spatial_size\": [2, 2, -1],\n            \"pos\": 1,\n            \"neg\": 1,\n            \"num_samples\": 2,\n            \"image\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"image_threshold\": 0,\n        },\n        {\"img\": np.random.randint(0, 2, size=[3, 3, 3, 3])},\n        (3, 2, 2, 3),\n    ],\n    [\n        {\n            \"label\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"spatial_size\": [2, 2, 2],\n            \"pos\": 1,\n            \"neg\": 1,\n            \"num_samples\": 2,\n            \"image\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"image_threshold\": 0,\n        },\n        {\"img\": np.random.randint(0, 2, size=[3, 3, 3, 3])},\n        (3, 2, 2, 2),\n    ],\n    [\n        {\n            \"label\": None,\n            \"spatial_size\": [2, 2, 2],\n            \"pos\": 1,\n            \"neg\": 1,\n            
\"num_samples\": 2,\n            \"image\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"image_threshold\": 0,\n        },\n        {\n            \"img\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"label\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"image\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n        },\n        (3, 2, 2, 2),\n    ],\n    [\n        {\n            \"label\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"spatial_size\": [4, 4, 2],\n            \"pos\": 1,\n            \"neg\": 1,\n            \"num_samples\": 2,\n            \"image\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"allow_smaller\": True,\n        },\n        {\"img\": np.random.randint(0, 2, size=[3, 3, 3, 3])},\n        (3, 3, 3, 2),\n    ],\n    [\n        {\n            \"label\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"spatial_size\": [4, 4, 4],\n            \"pos\": 1,\n            \"neg\": 1,\n            \"num_samples\": 2,\n            \"image\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"allow_smaller\": True,\n        },\n        {\"img\": np.random.randint(0, 2, size=[3, 3, 3, 3])},\n        (3, 3, 3, 3),\n    ],\n]\n\n\nclass TestRandCropByPosNegLabel(unittest.TestCase):\n    @staticmethod\n    def convert_data_type(im_type, d, keys=(\"img\", \"image\", \"label\")):\n        out = deepcopy(d)\n        for k, v in out.items():\n            if k in keys and isinstance(v, np.ndarray):\n                out[k] = im_type(v)\n        return out\n\n    @parameterized.expand(TESTS)\n    def test_type_shape(self, input_param, input_data, expected_shape):\n        results = []\n        for p in TEST_NDARRAYS_ALL:\n            input_param_mod = self.convert_data_type(p, input_param)\n            input_data_mod = self.convert_data_type(p, input_data)\n            cropper = RandCropByPosNegLabel(**input_param_mod)\n            cropper.set_random_state(0)\n            
result = cropper(**input_data_mod)\n            self.assertListEqual(cropper.spatial_size, input_param[\"spatial_size\"])\n\n            self.assertIsInstance(result, list)\n            self.assertTupleEqual(result[0].shape, expected_shape)\n\n            # check for same results across numpy, torch.Tensor and torch.cuda.Tensor\n            result = np.asarray([i if isinstance(i, np.ndarray) else i.cpu().numpy() for i in result])\n            results.append(np.asarray(result))\n            if len(results) > 1:\n                np.testing.assert_allclose(results[0], results[-1])\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, input_param, input_data, _expected_shape):\n        for p in TEST_NDARRAYS_ALL:\n            input_param_mod = self.convert_data_type(p, input_param)\n            input_data_mod = self.convert_data_type(p, input_data)\n            cropper = RandCropByPosNegLabel(**input_param_mod)\n            # non-lazy\n            cropper.set_random_state(0)\n            expected = cropper(**input_data_mod)\n            self.assertIsInstance(expected[0], MetaTensor)\n            # lazy\n            cropper.set_random_state(0)\n            cropper.lazy = True\n            pending_result = cropper(**input_data_mod)\n            for i, _pending_result in enumerate(pending_result):\n                self.assertIsInstance(_pending_result, MetaTensor)\n                assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine)\n                assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:])\n                # only support nearest\n                result = apply_pending(_pending_result, overrides={\"mode\": \"nearest\", \"align_corners\": False})[0]\n                # compare\n                assert_allclose(result, expected[i], rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_crop_by_pos_neg_labeld.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import RandCropByPosNegLabeld\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTESTS = [\n    [\n        {\n            \"keys\": [\"image\", \"extra\", \"label\"],\n            \"label_key\": \"label\",\n            \"spatial_size\": [-1, 2, 2],\n            \"pos\": 1,\n            \"neg\": 1,\n            \"num_samples\": 2,\n            \"image_key\": None,\n            \"image_threshold\": 0,\n        },\n        {\n            \"image\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"extra\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"label\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n        },\n        (3, 3, 2, 2),\n    ],\n    [\n        {\n            \"keys\": [\"image\", \"extra\", \"label\"],\n            \"label_key\": \"label\",\n            \"spatial_size\": [2, 2, 2],\n            \"pos\": 1,\n            \"neg\": 1,\n            \"num_samples\": 2,\n            \"image_key\": None,\n            \"image_threshold\": 0,\n        },\n        {\n            \"image\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"extra\": 
np.random.randint(0, 2, size=[3, 3, 3, 3]),\n            \"label\": np.random.randint(0, 2, size=[3, 3, 3, 3]),\n        },\n        (3, 2, 2, 2),\n    ],\n    [\n        {\n            \"keys\": [\"image\", \"extra\", \"label\"],\n            \"label_key\": \"label\",\n            \"spatial_size\": [2, 2, 2],\n            \"pos\": 1,\n            \"neg\": 1,\n            \"num_samples\": 2,\n            \"image_key\": None,\n            \"image_threshold\": 0,\n        },\n        {\"image\": np.zeros([3, 3, 3, 3]) - 1, \"extra\": np.zeros([3, 3, 3, 3]), \"label\": np.ones([3, 3, 3, 3])},\n        (3, 2, 2, 2),\n    ],\n    [\n        {\n            \"keys\": [\"image\", \"extra\", \"label\"],\n            \"label_key\": \"label\",\n            \"spatial_size\": [4, 4, 2],\n            \"pos\": 1,\n            \"neg\": 1,\n            \"num_samples\": 2,\n            \"image_key\": None,\n            \"image_threshold\": 0,\n            \"allow_smaller\": True,\n        },\n        {\"image\": np.zeros([3, 3, 3, 3]) - 1, \"extra\": np.zeros([3, 3, 3, 3]), \"label\": np.ones([3, 3, 3, 3])},\n        (3, 3, 3, 2),\n    ],\n    [\n        {\n            \"keys\": [\"image\", \"extra\", \"label\"],\n            \"label_key\": \"label\",\n            \"spatial_size\": [4, 4, 4],\n            \"pos\": 1,\n            \"neg\": 1,\n            \"num_samples\": 2,\n            \"image_key\": None,\n            \"image_threshold\": 0,\n            \"allow_smaller\": True,\n        },\n        {\"image\": np.zeros([3, 3, 3, 3]) - 1, \"extra\": np.zeros([3, 3, 3, 3]), \"label\": np.ones([3, 3, 3, 3])},\n        (3, 3, 3, 3),\n    ],\n]\n\n\nclass TestRandCropByPosNegLabeld(unittest.TestCase):\n    @staticmethod\n    def convert_data_type(im_type, d, keys=(\"img\", \"image\", \"label\")):\n        out = deepcopy(d)\n        for k, v in out.items():\n            if k in keys and isinstance(v, np.ndarray):\n                out[k] = im_type(v)\n        return out\n\n    
@parameterized.expand(TESTS)\n    def test_type_shape(self, input_param, input_data, expected_shape):\n        for p in TEST_NDARRAYS_ALL:\n            input_param_mod = self.convert_data_type(p, input_param)\n            input_data_mod = self.convert_data_type(p, input_data)\n            cropper = RandCropByPosNegLabeld(**input_param_mod)\n            cropper.set_random_state(0)\n            result = cropper(input_data_mod)\n            self.assertListEqual(cropper.cropper.spatial_size, input_param[\"spatial_size\"])\n\n            self.assertIsInstance(result, list)\n\n            _len = len(tuple(input_data.keys()))\n            self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys()))\n            for k in (\"image\", \"extra\", \"label\"):\n                self.assertTupleEqual(result[0][k].shape, expected_shape)\n                for i, item in enumerate(result):\n                    self.assertEqual(item[k].meta[\"patch_index\"], i)\n\n    def test_correct_center(self):\n        cropper = RandCropByPosNegLabeld(keys=\"label\", label_key=\"label\", spatial_size=[3, 3])\n        cropper.set_random_state(0)\n        test_image = {\"label\": np.asarray([[[1, 0, 0, 1], [0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 1]]])}\n        result = cropper(test_image)\n        np.testing.assert_allclose(result[0][\"label\"], np.asarray([[[0, 0, 1], [0, 0, 0], [0, 0, 0]]]))\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, input_param, input_data, _expected_shape):\n        for p in TEST_NDARRAYS_ALL:\n            input_param_mod = self.convert_data_type(p, input_param)\n            input_data_mod = self.convert_data_type(p, input_data)\n            cropper = RandCropByPosNegLabeld(**input_param_mod)\n            # non-lazy\n            cropper.set_random_state(0)\n            expected = cropper(input_data_mod)\n            self.assertIsInstance(expected[0][\"image\"], MetaTensor)\n            # lazy\n            cropper.set_random_state(0)\n  
          cropper.lazy = True\n            pending_result = cropper(input_data_mod)\n            for i, _pending_result in enumerate(pending_result):\n                self.assertIsInstance(_pending_result[\"image\"], MetaTensor)\n                assert_allclose(_pending_result[\"image\"].peek_pending_affine(), expected[i][\"image\"].affine)\n                assert_allclose(_pending_result[\"image\"].peek_pending_shape(), expected[i][\"image\"].shape[1:])\n                # only support nearest\n                overrides = {\"mode\": \"nearest\", \"align_corners\": False}\n                result_image = apply_pending(_pending_result[\"image\"], overrides=overrides)[0]\n                result_extra = apply_pending(_pending_result[\"extra\"], overrides=overrides)[0]\n                # compare\n                assert_allclose(result_image, expected[i][\"image\"], rtol=1e-5)\n                assert_allclose(result_extra, expected[i][\"extra\"], rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_cucim_dict_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandCuCIMd\nfrom monai.utils import optional_import, set_determinism\nfrom tests.test_utils import HAS_CUPY, skip_if_no_cuda\n\n_, has_cut = optional_import(\"cucim.core.operations.expose.transform\")\ncp, _ = optional_import(\"cupy\")\n\nset_determinism(seed=0)\n\nTEST_CASE_COLOR_JITTER_1 = [\n    {\"name\": \"color_jitter\", \"brightness\": 0.0, \"contrast\": 0.0, \"saturation\": 0.0, \"hue\": 0.0},\n    np.array([[[0, 1], [2, 3]], [[0, 10], [20, 30]], [[0, 50], [100, 150]]], dtype=np.uint8),\n    np.array([[[0, 1], [2, 3]], [[0, 10], [20, 30]], [[0, 50], [100, 150]]], dtype=np.uint8),\n]\n\nTEST_CASE_FLIP_1 = [\n    {\"name\": \"image_flip\", \"spatial_axis\": -1},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], dtype=np.float32),\n]\n\nTEST_CASE_RAND_ROTATE_1 = [\n    {\"name\": \"rand_image_rotate_90\", \"prob\": 1.0, \"max_k\": 1, \"spatial_axis\": (-2, -1)},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], 
dtype=np.float32),\n]\n\nTEST_CASE_RAND_ROTATE_2 = [\n    {\"name\": \"rand_image_rotate_90\", \"prob\": 0.0, \"max_k\": 1, \"spatial_axis\": (-2, -1)},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n]\n\nTEST_CASE_SCALE_INTENSITY_1 = [\n    {\"name\": \"scale_intensity_range\", \"a_min\": 0.0, \"a_max\": 4.0, \"b_min\": 0.0, \"b_max\": 1.0, \"clip\": False},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]]], dtype=np.float32),\n]\n\nTEST_CASE_ZOOM_1 = [\n    {\"name\": \"zoom\", \"zoom_factor\": (0.5, 0.5)},\n    np.mgrid[:3, 1:4].astype(dtype=np.float32),\n    np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]),\n]\n\nTEST_CASE_RAND_ZOOM_1 = [\n    {\"name\": \"rand_zoom\", \"prob\": 1.0, \"min_zoom\": 0.5, \"max_zoom\": 0.5},\n    np.mgrid[:3, 1:4].astype(dtype=np.float32),\n    np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]),\n]\n\nTEST_CASE_RAND_ZOOM_2 = [\n    {\"name\": \"rand_zoom\", \"prob\": 0.0, \"min_zoom\": 0.5, \"max_zoom\": 0.5},\n    np.mgrid[:3, 1:4].astype(dtype=np.float32),\n    np.mgrid[:3, 1:4].astype(dtype=np.float32),\n]\n\n\n@skip_if_no_cuda\n@unittest.skipUnless(HAS_CUPY, \"CuPy is required.\")\n@unittest.skipUnless(has_cut, \"cuCIM transforms are required.\")\nclass TestRandCuCIMDict(unittest.TestCase):\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_RAND_ROTATE_1,\n            TEST_CASE_RAND_ROTATE_2,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_1,\n    
        TEST_CASE_RAND_ZOOM_2,\n        ]\n    )\n    def test_tramsforms_numpy_single(self, params, input, expected):\n        input = {\"image\": np.copy(input)}\n        output = RandCuCIMd(keys=\"image\", **params)(input)[\"image\"]\n        self.assertTrue(output.dtype == expected.dtype)\n        self.assertTrue(isinstance(output, np.ndarray))\n        cp.testing.assert_allclose(output, expected)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_RAND_ROTATE_1,\n            TEST_CASE_RAND_ROTATE_2,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_2,\n        ]\n    )\n    def test_tramsforms_numpy_batch(self, params, input, expected):\n        input = {\"image\": np.copy(input[cp.newaxis, ...])}\n        expected = expected[cp.newaxis, ...]\n        output = RandCuCIMd(keys=\"image\", **params)(input)[\"image\"]\n        self.assertTrue(output.dtype == expected.dtype)\n        self.assertTrue(isinstance(output, np.ndarray))\n        cp.testing.assert_allclose(output, expected)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_RAND_ROTATE_1,\n            TEST_CASE_RAND_ROTATE_2,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_2,\n        ]\n    )\n    def test_tramsforms_cupy_single(self, params, input, expected):\n        input = {\"image\": cp.asarray(input)}\n        expected = cp.asarray(expected)\n        output = RandCuCIMd(keys=\"image\", **params)(input)[\"image\"]\n        self.assertTrue(output.dtype == expected.dtype)\n        self.assertTrue(isinstance(output, cp.ndarray))\n        cp.testing.assert_allclose(output, expected)\n\n    @parameterized.expand(\n        [\n            
TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_RAND_ROTATE_1,\n            TEST_CASE_RAND_ROTATE_2,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_2,\n        ]\n    )\n    def test_tramsforms_cupy_batch(self, params, input, expected):\n        input = {\"image\": cp.asarray(input)[cp.newaxis, ...]}\n        expected = cp.asarray(expected)[cp.newaxis, ...]\n        output = RandCuCIMd(keys=\"image\", **params)(input)[\"image\"]\n        self.assertTrue(output.dtype == expected.dtype)\n        self.assertTrue(isinstance(output, cp.ndarray))\n        cp.testing.assert_allclose(output, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_cucim_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandCuCIM\nfrom monai.utils import optional_import, set_determinism\nfrom tests.test_utils import HAS_CUPY, skip_if_no_cuda\n\n_, has_cut = optional_import(\"cucim.core.operations.expose.transform\")\ncp, _ = optional_import(\"cupy\")\n\nset_determinism(seed=0)\n\nTEST_CASE_COLOR_JITTER_1 = [\n    {\"name\": \"color_jitter\", \"brightness\": 0.0, \"contrast\": 0.0, \"saturation\": 0.0, \"hue\": 0.0},\n    np.array([[[0, 1], [2, 3]], [[0, 10], [20, 30]], [[0, 50], [100, 150]]], dtype=np.uint8),\n    np.array([[[0, 1], [2, 3]], [[0, 10], [20, 30]], [[0, 50], [100, 150]]], dtype=np.uint8),\n]\n\nTEST_CASE_FLIP_1 = [\n    {\"name\": \"image_flip\", \"spatial_axis\": -1},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], dtype=np.float32),\n]\n\nTEST_CASE_RAND_ROTATE_1 = [\n    {\"name\": \"rand_image_rotate_90\", \"prob\": 1.0, \"max_k\": 1, \"spatial_axis\": (-2, -1)},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], 
dtype=np.float32),\n]\n\nTEST_CASE_RAND_ROTATE_2 = [\n    {\"name\": \"rand_image_rotate_90\", \"prob\": 0.0, \"max_k\": 1, \"spatial_axis\": (-2, -1)},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n]\n\nTEST_CASE_SCALE_INTENSITY_1 = [\n    {\"name\": \"scale_intensity_range\", \"a_min\": 0.0, \"a_max\": 4.0, \"b_min\": 0.0, \"b_max\": 1.0, \"clip\": False},\n    np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),\n    np.array([[[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]], [[0.0, 0.25], [0.5, 0.75]]], dtype=np.float32),\n]\n\nTEST_CASE_ZOOM_1 = [\n    {\"name\": \"zoom\", \"zoom_factor\": (0.5, 0.5)},\n    np.mgrid[:3, 1:4].astype(dtype=np.float32),\n    np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]),\n]\n\nTEST_CASE_RAND_ZOOM_1 = [\n    {\"name\": \"rand_zoom\", \"prob\": 1.0, \"min_zoom\": 0.5, \"max_zoom\": 0.5},\n    np.mgrid[:3, 1:4].astype(dtype=np.float32),\n    np.concatenate([np.ones((1, 3, 3), dtype=np.float32) * 1.0, np.ones((1, 3, 3), dtype=np.float32) * 2.0]),\n]\n\nTEST_CASE_RAND_ZOOM_2 = [\n    {\"name\": \"rand_zoom\", \"prob\": 0.0, \"min_zoom\": 0.5, \"max_zoom\": 0.5},\n    np.mgrid[:3, 1:4].astype(dtype=np.float32),\n    np.mgrid[:3, 1:4].astype(dtype=np.float32),\n]\n\n\n@skip_if_no_cuda\n@unittest.skipUnless(HAS_CUPY, \"CuPy is required.\")\n@unittest.skipUnless(has_cut, \"cuCIM transforms are required.\")\nclass TestRandCuCIM(unittest.TestCase):\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_RAND_ROTATE_1,\n            TEST_CASE_RAND_ROTATE_2,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_1,\n        
    TEST_CASE_RAND_ZOOM_2,\n        ]\n    )\n    def test_tramsforms_numpy_single(self, params, input, expected):\n        input = np.copy(input)\n        output = RandCuCIM(**params)(input)\n        self.assertTrue(output.dtype == expected.dtype)\n        self.assertTrue(isinstance(output, np.ndarray))\n        cp.testing.assert_allclose(output, expected)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_RAND_ROTATE_1,\n            TEST_CASE_RAND_ROTATE_2,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_2,\n        ]\n    )\n    def test_tramsforms_numpy_batch(self, params, input, expected):\n        input = np.copy(input[cp.newaxis, ...])\n        expected = expected[cp.newaxis, ...]\n        output = RandCuCIM(**params)(input)\n        self.assertTrue(output.dtype == expected.dtype)\n        self.assertTrue(isinstance(output, np.ndarray))\n        cp.testing.assert_allclose(output, expected)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_RAND_ROTATE_1,\n            TEST_CASE_RAND_ROTATE_2,\n            TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_2,\n        ]\n    )\n    def test_tramsforms_cupy_single(self, params, input, expected):\n        input = cp.asarray(input)\n        expected = cp.asarray(expected)\n        output = RandCuCIM(**params)(input)\n        self.assertTrue(output.dtype == expected.dtype)\n        self.assertTrue(isinstance(output, cp.ndarray))\n        cp.testing.assert_allclose(output, expected)\n\n    @parameterized.expand(\n        [\n            TEST_CASE_COLOR_JITTER_1,\n            TEST_CASE_FLIP_1,\n            TEST_CASE_RAND_ROTATE_1,\n            TEST_CASE_RAND_ROTATE_2,\n            
TEST_CASE_SCALE_INTENSITY_1,\n            TEST_CASE_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_1,\n            TEST_CASE_RAND_ZOOM_2,\n        ]\n    )\n    def test_tramsforms_cupy_batch(self, params, input, expected):\n        input = cp.asarray(input)[cp.newaxis, ...]\n        expected = cp.asarray(expected)[cp.newaxis, ...]\n        output = RandCuCIM(**params)(input)\n        self.assertTrue(output.dtype == expected.dtype)\n        self.assertTrue(isinstance(output, cp.ndarray))\n        cp.testing.assert_allclose(output, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_deform_grid.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandDeformGrid\nfrom tests.test_utils import assert_allclose\n\nTEST_CASES = [\n    [\n        dict(spacing=(1, 2), magnitude_range=(1.0, 2.0), device=None),\n        {\"spatial_size\": (3, 3)},\n        np.array(\n            [\n                [\n                    [-3.45774551, -0.6608006, -1.62002671, -4.02259806, -2.77692349],\n                    [1.21748926, -4.25845712, -1.57592837, 0.69985342, -2.16382767],\n                    [-0.91158377, -0.12717178, 2.00258405, -0.85789449, -0.59616292],\n                    [0.41676882, 3.96204313, 3.93633727, 2.34820726, 1.51855713],\n                    [2.99011186, 4.00170105, 0.74339613, 3.57886072, 0.31633439],\n                ],\n                [\n                    [-4.85634965, -0.78197195, -1.91838077, 1.81192079, 2.84286669],\n                    [-4.34323645, -5.75784424, -2.37875058, 1.06023016, 5.24536301],\n                    [-4.23315172, -1.99617861, 0.92412057, 0.81899041, 4.38084451],\n                    [-5.08141703, -4.31985211, -0.52488611, 2.77048576, 4.45464513],\n                    [-4.01588556, 1.21238156, 0.55444352, 3.31421131, 7.00529793],\n                ],\n                [\n                    [1.0, 1.0, 1.0, 1.0, 1.0],\n                  
  [1.0, 1.0, 1.0, 1.0, 1.0],\n                    [1.0, 1.0, 1.0, 1.0, 1.0],\n                    [1.0, 1.0, 1.0, 1.0, 1.0],\n                    [1.0, 1.0, 1.0, 1.0, 1.0],\n                ],\n            ]\n        ),\n    ],\n    [\n        dict(spacing=(1, 2, 2), magnitude_range=(1.0, 3.0), device=None),\n        {\"spatial_size\": (1, 2, 2)},\n        np.array(\n            [\n                [\n                    [\n                        [-2.81748977, 0.66968869, -0.52625642, -3.52173734],\n                        [-1.96865364, 1.76472402, -5.06258324, -1.71805669],\n                        [1.11934537, -2.45103851, -2.13654555, -1.15855539],\n                        [1.49678424, -2.06960677, -1.74328475, -1.7271617],\n                    ],\n                    [\n                        [3.69301983, 3.66097025, 1.68091953, 0.6465273],\n                        [1.23445289, 2.49568333, -1.56671014, 1.96849393],\n                        [-2.09916271, -1.06768069, 1.51861453, -2.39180117],\n                        [-0.23449363, -1.44269211, -0.42794076, -4.68520972],\n                    ],\n                    [\n                        [-1.96578162, -0.17168741, 2.55269525, 0.70931081],\n                        [1.00476444, 2.15217619, -0.47246061, 1.4748298],\n                        [-0.34829048, -1.89234811, 0.34558185, 1.9606272],\n                        [1.56684302, 0.98019418, 5.00513708, 1.69126978],\n                    ],\n                ],\n                [\n                    [\n                        [-1.36146598, 0.7469491, -5.16647064, -4.73906938],\n                        [1.91920577, -2.33606298, -0.95030633, 0.7901769],\n                        [2.49116076, 3.93791246, 3.50390686, 2.79030531],\n                        [1.70638302, 4.33070564, 3.52613304, 0.77965554],\n                    ],\n                    [\n                        [-0.62725323, -1.64857887, -2.92384357, -3.39022706],\n                        [-3.00611521, 
-0.66597021, -0.21577072, -2.39146379],\n                        [2.94568388, -0.83686357, -2.55435186, 2.74064119],\n                        [2.3247117, 2.78900974, 1.59788581, 0.31140512],\n                    ],\n                    [\n                        [-0.89856598, -4.15325814, -0.21934502, -1.64845891],\n                        [-1.52694693, -2.81794479, -2.22623861, -3.0299247],\n                        [4.49410486, 1.27529645, 2.92559679, -1.12171559],\n                        [3.30307684, 4.97189727, 2.43914751, 4.7262225],\n                    ],\n                ],\n                [\n                    [\n                        [-4.81571068, -3.28263239, 1.635167, 2.36520831],\n                        [-1.92511521, -4.311247, 2.19242556, 7.34990574],\n                        [-3.04122716, -0.94284154, 1.30058968, -0.11719455],\n                        [-2.28657395, -3.68766906, 0.28400757, 5.08072864],\n                    ],\n                    [\n                        [-4.2308508, -0.16084264, 2.69545963, 3.4666492],\n                        [-5.29514976, -1.55660775, 4.28031473, -0.39019547],\n                        [-3.4617024, -1.92430221, 1.20214712, 4.25261228],\n                        [-0.30683774, -1.4524049, 2.35996724, 3.83663135],\n                    ],\n                    [\n                        [-2.20587965, -1.94408353, -0.66964855, 1.15838178],\n                        [-4.26637632, -0.46145396, 2.27393031, 3.5415298],\n                        [-3.91902371, 2.02343374, 3.54278271, 2.40735681],\n                        [-4.3785335, -0.78200288, 3.12162619, 3.55709275],\n                    ],\n                ],\n                [\n                    [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]],\n                    [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]],\n                    [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 
1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]],\n                ],\n            ]\n        ),\n    ],\n]\n\n\nclass TestRandDeformGrid(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_rand_deform_grid(self, input_param, input_data, expected_val):\n        g = RandDeformGrid(**input_param)\n        g.set_random_state(123)\n        result = g(**input_data)\n        assert_allclose(result, expected_val, type_test=False, rtol=1e-3, atol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_elastic_2d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import Rand2DElastic\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose, is_tf32_env\n\n_rtol = 5e-3 if is_tf32_env() else 1e-4\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    for device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n        TESTS.append(\n            [\n                {\"spacing\": (0.3, 0.3), \"magnitude_range\": (1.0, 2.0), \"prob\": 0.0, \"device\": device},\n                {\"img\": p(torch.ones((3, 3, 3))), \"spatial_size\": (2, 2)},\n                p(np.ones((3, 2, 2))),\n            ]\n        )\n        TESTS.append(\n            [\n                {\"spacing\": (0.3, 0.3), \"magnitude_range\": (1.0, 2.0), \"prob\": 0.0, \"device\": device},\n                {\"img\": p(torch.arange(27).reshape((3, 3, 3)))},\n                p(np.arange(27).reshape((3, 3, 3))),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"spacing\": (0.3, 0.3),\n                    \"magnitude_range\": (1.0, 2.0),\n                    \"prob\": 0.9,\n                    \"device\": device,\n                    \"padding_mode\": \"zeros\",\n                },\n        
        {\"img\": p(torch.ones((3, 3, 3))), \"spatial_size\": (2, 2), \"mode\": \"bilinear\"},\n                p(\n                    np.array(\n                        [\n                            [[0.45531988, 0.0], [0.0, 0.71558857]],\n                            [[0.45531988, 0.0], [0.0, 0.71558857]],\n                            [[0.45531988, 0.0], [0.0, 0.71558857]],\n                        ]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"spacing\": (1.0, 1.0),\n                    \"magnitude_range\": (1.0, 1.0),\n                    \"scale_range\": [1.2, 2.2],\n                    \"prob\": 0.9,\n                    \"padding_mode\": \"border\",\n                    \"device\": device,\n                    \"spatial_size\": (2, 2),\n                },\n                {\"img\": p(torch.arange(27).reshape((3, 3, 3)))},\n                p(\n                    torch.tensor(\n                        [\n                            [[3.0793, 2.6141], [4.0568, 5.9978]],\n                            [[12.0793, 11.6141], [13.0568, 14.9978]],\n                            [[21.0793, 20.6141], [22.0568, 23.9978]],\n                        ]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"spacing\": (0.3, 0.3),\n                    \"magnitude_range\": (0.1, 0.2),\n                    \"translate_range\": [-0.01, 0.01],\n                    \"scale_range\": [0.01, 0.02],\n                    \"prob\": 0.9,\n                    \"device\": device,\n                    \"spatial_size\": (2, 2),\n                },\n                {\"img\": p(torch.arange(27).reshape((3, 3, 3)))},\n                p(\n                    np.array(\n                        [\n                            [[1.3584113, 1.9251312], [5.626623, 6.642721]],\n                            
[[10.358411, 10.925131], [14.626623, 15.642721]],\n                            [[19.358412, 19.92513], [23.626623, 24.642721]],\n                        ]\n                    )\n                ),\n            ]\n        )\n\n\nclass TestRand2DElastic(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_rand_2d_elastic(self, input_param, input_data, expected_val):\n        g = Rand2DElastic(**input_param)\n        set_track_meta(False)\n        result = g(**input_data)\n        self.assertNotIsInstance(result, MetaTensor)\n        self.assertIsInstance(result, torch.Tensor)\n        set_track_meta(True)\n        g.set_random_state(123)\n        result = g(**input_data)\n        assert_allclose(result, expected_val, type_test=False, rtol=_rtol, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_elastic_3d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import Rand3DElastic\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    for device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n        TESTS.append(\n            [\n                {\n                    \"magnitude_range\": (0.3, 2.3),\n                    \"sigma_range\": (1.0, 20.0),\n                    \"prob\": 0.0,\n                    \"device\": device,\n                    \"spatial_size\": -1,\n                },\n                {\"img\": p(torch.arange(72).reshape((2, 3, 3, 4)))},\n                p(np.arange(72).reshape((2, 3, 3, 4))),\n            ]\n        )\n        TESTS.append(\n            [\n                {\"magnitude_range\": (0.3, 2.3), \"sigma_range\": (1.0, 20.0), \"prob\": 0.0, \"device\": device},\n                {\"img\": p(torch.ones((2, 3, 3, 3))), \"spatial_size\": (2, 2, 2)},\n                p(np.ones((2, 2, 2, 2))),\n            ]\n        )\n        TESTS.append(\n            [\n                {\"magnitude_range\": (0.3, 0.3), \"sigma_range\": (1.0, 2.0), \"prob\": 0.9, \"device\": device},\n                {\"img\": 
p(torch.arange(27).reshape((1, 3, 3, 3))), \"spatial_size\": (2, 2, 2)},\n                p(\n                    np.array(\n                        [\n                            [\n                                [[6.4939356, 7.50289], [9.518351, 10.522849]],\n                                [[15.512375, 16.523542], [18.531467, 19.53646]],\n                            ]\n                        ]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"magnitude_range\": (0.3, 0.3),\n                    \"sigma_range\": (1.0, 2.0),\n                    \"prob\": 0.9,\n                    \"rotate_range\": [1, 1, 1],\n                    \"device\": device,\n                    \"spatial_size\": (2, 2, 2),\n                },\n                {\"img\": p(torch.arange(27).reshape((1, 3, 3, 3))), \"mode\": \"bilinear\"},\n                p(\n                    np.array(\n                        [\n                            [\n                                [[5.0069294, 9.463932], [9.287769, 13.739735]],\n                                [[12.319424, 16.777205], [16.594296, 21.045748]],\n                            ]\n                        ]\n                    )\n                ),\n            ]\n        )\n\n\nclass TestRand3DElastic(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_rand_3d_elastic(self, input_param, input_data, expected_val):\n        g = Rand3DElastic(**input_param)\n        set_track_meta(False)\n        g.set_random_state(123)\n        result = g(**input_data)\n        self.assertNotIsInstance(result, MetaTensor)\n        self.assertIsInstance(result, torch.Tensor)\n        set_track_meta(True)\n        g.set_random_state(123)\n        result = g(**input_data)\n        assert_allclose(result, expected_val, type_test=False, rtol=1e-1, atol=1e-1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_elasticd_2d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import Rand2DElasticd\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose, is_tf32_env\n\n_rtol = 5e-3 if is_tf32_env() else 1e-4\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    for device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n        TESTS.append(\n            [\n                {\n                    \"keys\": (\"img\", \"seg\"),\n                    \"spacing\": (0.3, 0.3),\n                    \"magnitude_range\": (1.0, 2.0),\n                    \"prob\": 0.0,\n                    \"device\": device,\n                    \"spatial_size\": (2, 2),\n                },\n                {\"img\": p(torch.ones((3, 3, 3))), \"seg\": p(torch.ones((3, 3, 3)))},\n                p(np.ones((3, 2, 2))),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"keys\": (\"img\", \"seg\"),\n                    \"spacing\": (0.3, 0.3),\n                    \"magnitude_range\": (0.3, 0.3),\n                    \"prob\": 0.0,\n                    \"device\": device,\n                    \"spatial_size\": -1,\n                },\n                {\"img\": p(torch.arange(4).reshape((1, 2, 2))), \"seg\": 
p(torch.arange(4).reshape((1, 2, 2)))},\n                p(np.arange(4).reshape((1, 2, 2))),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"keys\": (\"img\", \"seg\"),\n                    \"spacing\": (0.3, 0.3),\n                    \"magnitude_range\": (1.0, 2.0),\n                    \"prob\": 0.9,\n                    \"padding_mode\": \"zeros\",\n                    \"device\": device,\n                    \"spatial_size\": (2, 2),\n                    \"mode\": \"bilinear\",\n                },\n                {\"img\": p(torch.ones((3, 3, 3))), \"seg\": p(torch.ones((3, 3, 3)))},\n                p(\n                    np.array(\n                        [\n                            [[0.45531988, 0.0], [0.0, 0.71558857]],\n                            [[0.45531988, 0.0], [0.0, 0.71558857]],\n                            [[0.45531988, 0.0], [0.0, 0.71558857]],\n                        ]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"keys\": (\"img\", \"seg\"),\n                    \"spacing\": (1.0, 1.0),\n                    \"magnitude_range\": (1.0, 1.0),\n                    \"scale_range\": [1.2, 2.2],\n                    \"prob\": 0.9,\n                    \"padding_mode\": \"border\",\n                    \"device\": device,\n                    \"spatial_size\": (2, 2),\n                },\n                {\"img\": p(torch.arange(27).reshape((3, 3, 3))), \"seg\": p(torch.arange(27).reshape((3, 3, 3)))},\n                p(\n                    torch.tensor(\n                        [\n                            [[3.0793, 2.6141], [4.0568, 5.9978]],\n                            [[12.0793, 11.6141], [13.0568, 14.9978]],\n                            [[21.0793, 20.6141], [22.0568, 23.9978]],\n                        ]\n                    )\n                ),\n            ]\n        )\n      
  TESTS.append(\n            [\n                {\n                    \"keys\": (\"img\", \"seg\"),\n                    \"spacing\": (0.3, 0.3),\n                    \"magnitude_range\": (0.1, 0.2),\n                    \"translate_range\": [-0.01, 0.01],\n                    \"scale_range\": [0.01, 0.02],\n                    \"prob\": 0.9,\n                    \"device\": device,\n                    \"spatial_size\": (2, 2),\n                },\n                {\"img\": p(torch.arange(27).reshape((3, 3, 3))), \"seg\": p(torch.arange(27).reshape((3, 3, 3)))},\n                p(\n                    np.array(\n                        [\n                            [[1.3584113, 1.9251312], [5.626623, 6.642721]],\n                            [[10.358411, 10.925131], [14.626623, 15.642721]],\n                            [[19.358412, 19.92513], [23.626623, 24.642721]],\n                        ]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"keys\": (\"img\", \"seg\"),\n                    \"mode\": (\"bilinear\", \"nearest\"),\n                    \"spacing\": (0.3, 0.3),\n                    \"magnitude_range\": (0.1, 0.2),\n                    \"translate_range\": [-0.01, 0.01],\n                    \"scale_range\": [0.01, 0.02],\n                    \"prob\": 0.9,\n                    \"device\": device,\n                    \"spatial_size\": (2, 2),\n                },\n                {\"img\": p(torch.arange(27).reshape((3, 3, 3))), \"seg\": p(torch.arange(27).reshape((3, 3, 3)))},\n                {\n                    \"img\": p(\n                        torch.tensor(\n                            [\n                                [[1.3584, 1.9251], [5.6266, 6.6427]],\n                                [[10.3584, 10.9251], [14.6266, 15.6427]],\n                                [[19.3584, 19.9251], [23.6266, 24.6427]],\n                            ]\n        
                )\n                    ),\n                    \"seg\": p(\n                        torch.tensor(\n                            [[[0.0, 2.0], [6.0, 8.0]], [[9.0, 11.0], [15.0, 17.0]], [[18.0, 20.0], [24.0, 26.0]]]\n                        )\n                    ),\n                },\n            ]\n        )\n\n\nclass TestRand2DElasticd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_rand_2d_elasticd(self, input_param, input_data, expected_val):\n        g = Rand2DElasticd(**input_param)\n        if input_param.get(\"device\", None) is None and isinstance(input_data[\"img\"], torch.Tensor):\n            input_data[\"img\"].to(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n        g.set_random_state(123)\n        res = g(input_data)\n        for key in res:\n            result = res[key]\n            expected = expected_val[key] if isinstance(expected_val, dict) else expected_val\n            assert_allclose(result, expected, rtol=_rtol, atol=5e-3, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_elasticd_3d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import Rand3DElasticd\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    for device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n        TESTS.append(\n            [\n                {\n                    \"keys\": (\"img\", \"seg\"),\n                    \"magnitude_range\": (0.3, 2.3),\n                    \"sigma_range\": (1.0, 20.0),\n                    \"prob\": 0.0,\n                    \"device\": device,\n                    \"spatial_size\": (2, 2, 2),\n                },\n                {\"img\": p(torch.ones((2, 3, 3, 3))), \"seg\": p(torch.ones((2, 3, 3, 3)))},\n                p(np.ones((2, 2, 2, 2))),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"keys\": (\"img\", \"seg\"),\n                    \"magnitude_range\": (0.3, 2.3),\n                    \"sigma_range\": (1.0, 20.0),\n                    \"prob\": 0.0,\n                    \"device\": device,\n                    \"spatial_size\": (2, -1, -1),\n                },\n                {\"img\": p(torch.ones((2, 3, 3, 3))), \"seg\": p(torch.ones((2, 3, 3, 3)))},\n                p(np.ones((2, 
2, 3, 3))),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"keys\": (\"img\", \"seg\"),\n                    \"magnitude_range\": (0.3, 2.3),\n                    \"sigma_range\": (1.0, 20.0),\n                    \"prob\": 0.0,\n                    \"device\": device,\n                    \"spatial_size\": -1,\n                },\n                {\"img\": p(torch.arange(8).reshape((1, 2, 2, 2))), \"seg\": p(torch.arange(8).reshape((1, 2, 2, 2)))},\n                p(np.arange(8).reshape((1, 2, 2, 2))),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"keys\": (\"img\", \"seg\"),\n                    \"magnitude_range\": (0.3, 0.3),\n                    \"sigma_range\": (1.0, 2.0),\n                    \"prob\": 0.9,\n                    \"device\": device,\n                    \"spatial_size\": (2, 2, 2),\n                },\n                {\"img\": p(torch.arange(27).reshape((1, 3, 3, 3))), \"seg\": p(torch.arange(27).reshape((1, 3, 3, 3)))},\n                p(\n                    np.array(\n                        [\n                            [\n                                [[6.4939356, 7.50289], [9.518351, 10.522849]],\n                                [[15.512375, 16.523542], [18.531467, 19.53646]],\n                            ]\n                        ]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"keys\": (\"img\", \"seg\"),\n                    \"magnitude_range\": (0.3, 0.3),\n                    \"sigma_range\": (1.0, 2.0),\n                    \"prob\": 0.9,\n                    \"rotate_range\": [1, 1, 1],\n                    \"device\": device,\n                    \"spatial_size\": (2, 2, 2),\n                    \"mode\": \"bilinear\",\n                },\n                {\"img\": p(torch.arange(27).reshape((1, 3, 3, 3))), 
\"seg\": p(torch.arange(27).reshape((1, 3, 3, 3)))},\n                p(\n                    np.array(\n                        [\n                            [\n                                [[5.0069294, 9.463932], [9.287769, 13.739735]],\n                                [[12.319424, 16.777205], [16.594296, 21.045748]],\n                            ]\n                        ]\n                    )\n                ),\n            ]\n        )\n        TESTS.append(\n            [\n                {\n                    \"keys\": (\"img\", \"seg\"),\n                    \"mode\": (\"bilinear\", \"nearest\"),\n                    \"magnitude_range\": (0.3, 0.3),\n                    \"sigma_range\": (1.0, 2.0),\n                    \"prob\": 0.9,\n                    \"rotate_range\": [1, 1, 1],\n                    \"device\": device,\n                    \"spatial_size\": (2, 2, 2),\n                },\n                {\"img\": p(torch.arange(27).reshape((1, 3, 3, 3))), \"seg\": p(torch.arange(27).reshape((1, 3, 3, 3)))},\n                {\n                    \"img\": p(\n                        torch.tensor(\n                            [[[[5.0069, 9.4639], [9.2878, 13.7397]], [[12.3194, 16.7772], [16.5943, 21.0457]]]]\n                        )\n                    ),\n                    \"seg\": p(torch.tensor([[[[4.0, 14.0], [7.0, 14.0]], [[9.0, 19.0], [12.0, 22.0]]]])),\n                },\n            ]\n        )\n\n\nclass TestRand3DElasticd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_rand_3d_elasticd(self, input_param, input_data, expected_val):\n        g = Rand3DElasticd(**input_param)\n        g.set_random_state(123)\n        if input_param.get(\"device\", None) is None and isinstance(input_data[\"img\"], torch.Tensor):\n            input_data[\"img\"].to(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n        res = g(input_data)\n        for key in res:\n            result = res[key]\n            expected = 
expected_val[key] if isinstance(expected_val, dict) else expected_val\n            assert_allclose(result, expected, type_test=False, rtol=1e-2, atol=1e-2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_flip.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import RandFlip\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion\n\nINVALID_CASES = [(\"wrong_axis\", [\"s\", 1], TypeError), (\"not_numbers\", \"s\", TypeError)]\n\nVALID_CASES = [(\"no_axis\", None), (\"one_axis\", 1), (\"many_axis\", [0, 1])]\n\n\nclass TestRandFlip(NumpyImageTestCase2D):\n    @parameterized.expand(INVALID_CASES)\n    def test_invalid_inputs(self, _, spatial_axis, raises):\n        with self.assertRaises(raises):\n            flip = RandFlip(prob=1.0, spatial_axis=spatial_axis)\n            flip(self.imt[0])\n\n    @parameterized.expand(VALID_CASES)\n    def test_correct_results(self, _, spatial_axis):\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            init_param = {\"prob\": 1.0, \"spatial_axis\": spatial_axis}\n            flip = RandFlip(**init_param)\n            set_track_meta(False)\n            result = flip(im)\n            self.assertNotIsInstance(result, MetaTensor)\n            self.assertIsInstance(result, torch.Tensor)\n            set_track_meta(True)\n            expected = 
[np.flip(channel, spatial_axis) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            call_param = {\"img\": im}\n            result = flip(**call_param)\n            assert_allclose(result, p(expected), type_test=\"tensor\")\n            test_local_inversion(flip, result, im)\n\n            # test lazy\n            test_resampler_lazy(flip, result, init_param, call_param)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_flipd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import RandFlipd\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion\n\nVALID_CASES = [(\"no_axis\", None), (\"one_axis\", 1), (\"many_axis\", [0, 1])]\n\n\nclass TestRandFlipd(NumpyImageTestCase2D):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_results(self, _, spatial_axis):\n        for p in TEST_NDARRAYS_ALL:\n            init_param = {\"keys\": \"img\", \"prob\": 1.0, \"spatial_axis\": spatial_axis}\n            flip = RandFlipd(**init_param)\n            im = p(self.imt[0])\n            call_param = {\"data\": {\"img\": im}}\n            result = flip(**call_param)\n\n            # test lazy\n            test_resampler_lazy(flip, result, init_param, call_param, output_key=\"img\")\n            flip.lazy = False\n\n            expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(result[\"img\"], p(expected), type_test=\"tensor\")\n            test_local_inversion(flip, {\"img\": result[\"img\"]}, {\"img\": im}, \"img\")\n\n            
set_track_meta(False)\n            result = flip({\"img\": im})[\"img\"]\n            self.assertNotIsInstance(result, MetaTensor)\n            self.assertIsInstance(result, torch.Tensor)\n            set_track_meta(True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_gaussian_noise.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandGaussianNoise\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append((\"test_zero_mean\", p, 0, 0.1, True))\n    TESTS.append((\"test_non_zero_mean\", p, 1, 0.5, True))\n    TESTS.append((\"test_no_sample_std\", p, 1, 0.5, False))\n\n\nclass TestRandGaussianNoise(NumpyImageTestCase2D):\n    @parameterized.expand(TESTS)\n    def test_correct_results(self, _, im_type, mean, std, sample_std):\n        seed = 0\n        gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std, sample_std=sample_std)\n        gaussian_fn.set_random_state(seed)\n        im = im_type(self.imt)\n        noised = gaussian_fn(im)\n        np.random.seed(seed)\n        np.random.random()\n        _std = np.random.uniform(0, std) if sample_std else std\n        expected = self.imt + np.random.normal(mean, _std, size=self.imt.shape)\n        if isinstance(noised, torch.Tensor):\n            noised = noised.cpu()\n        np.testing.assert_allclose(expected, noised, atol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_gaussian_noised.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandGaussianNoised\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([\"test_zero_mean\", p, [\"img1\", \"img2\"], 0, 0.1, True])\n    TESTS.append([\"test_non_zero_mean\", p, [\"img1\", \"img2\"], 1, 0.5, True])\n    TESTS.append([\"test_no_sample_std\", p, [\"img1\", \"img2\"], 1, 0.5, False])\n\nseed = 0\n\n\nclass TestRandGaussianNoised(NumpyImageTestCase2D):\n    @parameterized.expand(TESTS)\n    def test_correct_results(self, _, im_type, keys, mean, std, sample_std):\n        gaussian_fn = RandGaussianNoised(\n            keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64, sample_std=sample_std\n        )\n        gaussian_fn.set_random_state(seed)\n        im = im_type(self.imt)\n        noised = gaussian_fn({k: im for k in keys})\n        np.random.seed(seed)\n        # simulate the randomize() of transform\n        np.random.random()\n        _std = np.random.uniform(0, std) if sample_std else std\n        noise = np.random.normal(mean, _std, size=self.imt.shape)\n        for k in keys:\n            expected = self.imt + noise\n            if isinstance(noised[k], torch.Tensor):\n                noised[k] = noised[k].cpu()\n 
           np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_gaussian_sharpen.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandGaussianSharpen\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\n\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"prob\": 1.0},\n            p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n            p(\n                [\n                    [\n                        [5.2919216, 5.5854445, 5.29192],\n                        [11.3982, 12.62332, 11.398202],\n                        [14.870525, 17.323769, 14.870527],\n                    ],\n                    [\n                        [20.413757, 22.767355, 20.413757],\n                        [28.495504, 31.558315, 28.495499],\n                        [29.99236, 34.505676, 29.992361],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\n                \"sigma1_x\": (0.5, 0.75),\n                \"sigma1_y\": (0.5, 0.75),\n                \"sigma1_z\": (0.5, 0.75),\n                \"sigma2_x\": 0.4,\n                \"sigma2_y\": 0.4,\n                \"sigma2_z\": 0.4,\n                \"prob\": 1.0,\n            },\n            p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n            p(\n                [\n 
                   [\n                        [4.1071496, 3.597953, 4.1071477],\n                        [10.062014, 9.825114, 10.0620165],\n                        [14.698058, 15.818766, 14.698058],\n                    ],\n                    [\n                        [18.211048, 18.16049, 18.211048],\n                        [25.155039, 24.56279, 25.155039],\n                        [28.801964, 30.381308, 28.801964],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\n                \"sigma1_x\": (0.5, 0.75),\n                \"sigma1_y\": (0.5, 0.75),\n                \"sigma1_z\": (0.5, 0.75),\n                \"sigma2_x\": (0.5, 0.75),\n                \"sigma2_y\": (0.5, 0.75),\n                \"sigma2_z\": (0.5, 0.75),\n                \"prob\": 1.0,\n            },\n            p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n            p(\n                [\n                    [\n                        [4.81077, 4.4237204, 4.81077],\n                        [12.061236, 12.298177, 12.061236],\n                        [17.362553, 19.201174, 17.362553],\n                    ],\n                    [\n                        [21.440754, 22.142393, 21.440754],\n                        [30.15308, 30.745445, 30.153086],\n                        [33.99255, 36.919838, 33.99255],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\n                \"sigma1_x\": (0.5, 0.75),\n                \"sigma1_y\": (0.5, 0.75),\n                \"sigma1_z\": (0.5, 0.75),\n                \"sigma2_x\": (0.5, 0.75),\n                \"sigma2_y\": (0.5, 0.75),\n                \"sigma2_z\": (0.5, 0.75),\n                \"approx\": \"scalespace\",\n                \"prob\": 1.0,\n            },\n            p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),\n            p(\n       
         [\n                    [\n                        [4.430213, 3.2278745, 4.4302144],\n                        [10.325399, 8.507457, 10.325399],\n                        [17.494898, 16.5609, 17.494894],\n                    ],\n                    [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]],\n                ]\n            ),\n        ]\n    )\n\n\nclass TestRandGaussianSharpen(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        converter = RandGaussianSharpen(**arguments)\n        converter.set_random_state(seed=0)\n        result = converter(image)\n        assert_allclose(result, expected_data, atol=0, rtol=1e-4, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_gaussian_sharpend.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandGaussianSharpend\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"prob\": 1.0},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [5.2919216, 5.5854445, 5.29192],\n                        [11.3982, 12.62332, 11.398202],\n                        [14.870525, 17.323769, 14.870527],\n                    ],\n                    [\n                        [20.413757, 22.767355, 20.413757],\n                        [28.495504, 31.558315, 28.495499],\n                        [29.99236, 34.505676, 29.992361],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\n                \"keys\": \"img\",\n                \"sigma1_x\": (0.5, 0.75),\n                \"sigma1_y\": (0.5, 0.75),\n                \"sigma1_z\": (0.5, 0.75),\n                \"sigma2_x\": 0.4,\n                \"sigma2_y\": 0.4,\n                \"sigma2_z\": 0.4,\n                \"prob\": 1.0,\n            },\n            {\"img\": 
p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [4.1071496, 3.597953, 4.1071477],\n                        [10.062014, 9.825114, 10.0620165],\n                        [14.698058, 15.818766, 14.698058],\n                    ],\n                    [\n                        [18.211048, 18.16049, 18.211048],\n                        [25.155039, 24.56279, 25.155039],\n                        [28.801964, 30.381308, 28.801964],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\n                \"keys\": \"img\",\n                \"sigma1_x\": (0.5, 0.75),\n                \"sigma1_y\": (0.5, 0.75),\n                \"sigma1_z\": (0.5, 0.75),\n                \"sigma2_x\": (0.5, 0.75),\n                \"sigma2_y\": (0.5, 0.75),\n                \"sigma2_z\": (0.5, 0.75),\n                \"prob\": 1.0,\n            },\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [4.81077, 4.4237204, 4.81077],\n                        [12.061236, 12.298177, 12.061236],\n                        [17.362553, 19.201174, 17.362553],\n                    ],\n                    [\n                        [21.440754, 22.142393, 21.440754],\n                        [30.15308, 30.745445, 30.153086],\n                        [33.99255, 36.919838, 33.99255],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\n                \"keys\": \"img\",\n                \"sigma1_x\": (0.5, 0.75),\n                \"sigma1_y\": (0.5, 0.75),\n                \"sigma1_z\": (0.5, 0.75),\n                \"sigma2_x\": (0.5, 0.75),\n                \"sigma2_y\": (0.5, 0.75),\n                
\"sigma2_z\": (0.5, 0.75),\n                \"approx\": \"scalespace\",\n                \"prob\": 1.0,\n            },\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [4.430213, 3.2278745, 4.4302144],\n                        [10.325399, 8.507457, 10.325399],\n                        [17.494898, 16.5609, 17.494894],\n                    ],\n                    [[20.87405, 18.06946, 20.87405], [25.813503, 21.268656, 25.8135], [33.93874, 31.402481, 33.938725]],\n                ]\n            ),\n        ]\n    )\n\n\nclass TestRandGaussianSharpend(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        converter = RandGaussianSharpend(**arguments)\n        converter.set_random_state(seed=0)\n        result = converter(image)\n        assert_allclose(result[\"img\"], expected_data, rtol=1e-4, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_gaussian_smooth.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandGaussianSmooth\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"sigma_x\": (0.5, 1.5), \"prob\": 1.0},\n            p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])),\n            np.array(\n                [\n                    [\n                        [0.71806467, 0.9074683, 0.71806467],\n                        [1.0718315, 1.3545481, 1.0718315],\n                        [1.0337002, 1.306359, 1.0337002],\n                    ],\n                    [\n                        [2.0318885, 2.5678391, 2.0318885],\n                        [2.6795788, 3.3863702, 2.6795788],\n                        [2.3475242, 2.9667296, 2.3475242],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"sigma_x\": (0.5, 1.5), \"sigma_y\": (0.5, 1.0), \"prob\": 1.0},\n            p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])),\n            np.array(\n                [\n                    [\n                        [0.7686928, 0.9848021, 0.7686928],\n                        [1.1474025, 1.4699818, 
1.1474024],\n                        [1.1065826, 1.4176859, 1.1065826],\n                    ],\n                    [\n                        [2.1751494, 2.7866683, 2.1751497],\n                        [2.8685062, 3.6749542, 2.8685062],\n                        [2.5130394, 3.219552, 2.5130394],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"sigma_x\": (0.5, 1.5), \"sigma_y\": (0.5, 1.0), \"approx\": \"scalespace\", \"prob\": 1.0},\n            p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])),\n            np.array(\n                [\n                    [\n                        [0.8128456, 0.96736777, 0.8128456],\n                        [1.2742369, 1.5164697, 1.2742369],\n                        [1.2800367, 1.5233722, 1.2800368],\n                    ],\n                    [\n                        [2.3825073, 2.8354228, 2.3825073],\n                        [3.1855922, 3.7911744, 3.1855922],\n                        [2.8496985, 3.391427, 2.8496985],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n\nclass TestRandGaussianSmooth(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        converter = RandGaussianSmooth(**arguments)\n        converter.set_random_state(seed=0)\n        result = converter(image)\n        assert_allclose(result, expected_data, rtol=1e-4, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_gaussian_smoothd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandGaussianSmoothd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"sigma_x\": (0.5, 1.5), \"prob\": 1.0},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [0.71806467, 0.9074683, 0.71806467],\n                        [1.0718315, 1.3545481, 1.0718315],\n                        [1.0337002, 1.306359, 1.0337002],\n                    ],\n                    [\n                        [2.0318885, 2.5678391, 2.0318885],\n                        [2.6795788, 3.3863702, 2.6795788],\n                        [2.3475242, 2.9667296, 2.3475242],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"sigma_x\": (0.5, 1.5), \"sigma_y\": (0.5, 1.0), \"prob\": 1.0},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [0.7686928, 0.9848021, 
0.7686928],\n                        [1.1474025, 1.4699818, 1.1474024],\n                        [1.1065826, 1.4176859, 1.1065826],\n                    ],\n                    [\n                        [2.1751494, 2.7866683, 2.1751497],\n                        [2.8685062, 3.6749542, 2.8685062],\n                        [2.5130394, 3.219552, 2.5130394],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": \"img\", \"sigma_x\": (0.5, 1.5), \"sigma_y\": (0.5, 1.0), \"approx\": \"scalespace\", \"prob\": 1.0},\n            {\"img\": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},\n            np.array(\n                [\n                    [\n                        [0.8128456, 0.96736777, 0.8128456],\n                        [1.2742369, 1.5164697, 1.2742369],\n                        [1.2800367, 1.5233722, 1.2800368],\n                    ],\n                    [\n                        [2.3825073, 2.8354228, 2.3825073],\n                        [3.1855922, 3.7911744, 3.1855922],\n                        [2.8496985, 3.391427, 2.8496985],\n                    ],\n                ]\n            ),\n        ]\n    )\n\n\nclass TestRandGaussianSmoothd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        converter = RandGaussianSmoothd(**arguments)\n        converter.set_random_state(seed=0)\n        result = converter(image)\n        assert_allclose(result[\"img\"], expected_data, rtol=1e-4, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_gibbs_noise.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.synthetic import create_test_image_2d, create_test_image_3d\nfrom monai.transforms import RandGibbsNoise\nfrom monai.utils.misc import set_determinism\nfrom monai.utils.module import optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_torch_fft = optional_import(\"torch.fft\", name=\"fftshift\")\n\nTEST_CASES = []\nfor shape in ((128, 64), (64, 48, 80)):\n    for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]:\n        TEST_CASES.append((shape, input_type))\n\n\nclass TestRandGibbsNoise(unittest.TestCase):\n    def setUp(self):\n        set_determinism(0)\n        super().setUp()\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @staticmethod\n    def get_data(im_shape, input_type):\n        create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d\n        im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None]\n        return input_type(im)\n\n    @parameterized.expand(TEST_CASES)\n    def test_0_prob(self, im_shape, input_type):\n        im = self.get_data(im_shape, input_type)\n        alpha = [0.5, 1.0]\n        t = RandGibbsNoise(0.0, alpha)\n        out = t(im)\n        
assert_allclose(out, im, rtol=1e-7, atol=0, type_test=\"tensor\")\n\n    @parameterized.expand(TEST_CASES)\n    def test_same_result(self, im_shape, input_type):\n        im = self.get_data(im_shape, input_type)\n        alpha = [0.5, 0.8]\n        t = RandGibbsNoise(1.0, alpha)\n        t.set_random_state(42)\n        out1 = t(deepcopy(im))\n        t.set_random_state(42)\n        out2 = t(deepcopy(im))\n        assert_allclose(out1, out2, rtol=1e-7, atol=1e-2, type_test=\"tensor\")\n\n    @parameterized.expand(TEST_CASES)\n    def test_identity(self, im_shape, input_type):\n        im = self.get_data(im_shape, input_type)\n        alpha = [0.0, 0.0]\n        t = RandGibbsNoise(1.0, alpha)\n        out = t(deepcopy(im))\n        assert_allclose(out, im, atol=1e-2, rtol=1e-7, type_test=\"tensor\")\n\n    @parameterized.expand(TEST_CASES)\n    def test_alpha_1(self, im_shape, input_type):\n        im = self.get_data(im_shape, input_type)\n        alpha = [1.0, 1.0]\n        t = RandGibbsNoise(1.0, alpha)\n        out = t(deepcopy(im))\n        assert_allclose(out, 0 * im, rtol=1e-7, atol=1e-2, type_test=\"tensor\")\n\n    @parameterized.expand(TEST_CASES)\n    def test_alpha(self, im_shape, input_type):\n        im = self.get_data(im_shape, input_type)\n        alpha = [0.5, 0.51]\n        t = RandGibbsNoise(1.0, alpha)\n        _ = t(deepcopy(im))\n        self.assertGreaterEqual(t.sampled_alpha, 0.5)\n        self.assertLessEqual(t.sampled_alpha, 0.51)\n\n    @parameterized.expand(TEST_CASES)\n    def test_alpha_single_value(self, im_shape, input_type):\n        im = self.get_data(im_shape, input_type)\n        alpha = 0.01\n        t = RandGibbsNoise(1.0, alpha)\n        _ = t(deepcopy(im))\n        self.assertGreaterEqual(t.sampled_alpha, 0)\n        self.assertLessEqual(t.sampled_alpha, 0.01)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_gibbs_noised.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.synthetic import create_test_image_2d, create_test_image_3d\nfrom monai.transforms import RandGibbsNoised\nfrom monai.utils.misc import set_determinism\nfrom monai.utils.module import optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n_, has_torch_fft = optional_import(\"torch.fft\", name=\"fftshift\")\n\nTEST_CASES = []\nfor shape in ((128, 64), (64, 48, 80)):\n    for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]:\n        TEST_CASES.append((shape, input_type))\n\nKEYS = [\"im\", \"label\"]\n\n\nclass TestRandGibbsNoised(unittest.TestCase):\n    def setUp(self):\n        set_determinism(0)\n        super().setUp()\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @staticmethod\n    def get_data(im_shape, input_type):\n        create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d\n        ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)\n        return {k: input_type(v) for k, v in zip(KEYS, ims)}\n\n    @parameterized.expand(TEST_CASES)\n    def test_0_prob(self, im_shape, input_type):\n        data = self.get_data(im_shape, input_type)\n        alpha = [0.5, 1.0]\n        t = 
RandGibbsNoised(KEYS, 0.0, alpha)\n        out = t(data)\n        for k in KEYS:\n            assert_allclose(data[k], out[k], rtol=1e-7, atol=0, type_test=False)\n\n    @parameterized.expand(TEST_CASES)\n    def test_same_result(self, im_shape, input_type):\n        data = self.get_data(im_shape, input_type)\n        alpha = [0.5, 0.8]\n        t = RandGibbsNoised(KEYS, 1.0, alpha)\n        t.set_random_state(42)\n        out1 = t(deepcopy(data))\n        t.set_random_state(42)\n        out2 = t(deepcopy(data))\n        for k in KEYS:\n            assert_allclose(out1[k], out2[k], rtol=1e-7, atol=0, type_test=\"tensor\")\n\n    @parameterized.expand(TEST_CASES)\n    def test_identity(self, im_shape, input_type):\n        data = self.get_data(im_shape, input_type)\n        alpha = [0.0, 0.0]\n        t = RandGibbsNoised(KEYS, 1.0, alpha)\n        out = t(deepcopy(data))\n        for k in KEYS:\n            assert_allclose(out[k], data[k], atol=1e-2, type_test=\"tensor\")\n\n    @parameterized.expand(TEST_CASES)\n    def test_alpha_1(self, im_shape, input_type):\n        data = self.get_data(im_shape, input_type)\n        alpha = [1.0, 1.0]\n        t = RandGibbsNoised(KEYS, 1.0, alpha)\n        out = t(deepcopy(data))\n        for k in KEYS:\n            assert_allclose(out[k], 0.0 * data[k], atol=1e-2, type_test=\"tensor\")\n\n    @parameterized.expand(TEST_CASES)\n    def test_dict_matches(self, im_shape, input_type):\n        data = self.get_data(im_shape, input_type)\n        # use same image for both dictionary entries to check same trans is applied to them\n        data = {KEYS[0]: deepcopy(data[KEYS[0]]), KEYS[1]: deepcopy(data[KEYS[0]])}\n        alpha = [0.5, 1.0]\n        t = RandGibbsNoised(KEYS, 1.0, alpha)\n        out = t(deepcopy(data))\n        assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0, type_test=False)\n\n    @parameterized.expand(TEST_CASES)\n    def test_alpha(self, im_shape, input_type):\n        data = 
self.get_data(im_shape, input_type)\n        alpha = [0.5, 0.51]\n        t = RandGibbsNoised(KEYS, 1.0, alpha)\n        _ = t(deepcopy(data))\n        self.assertTrue(0.5 <= t.rand_gibbs_noise.sampled_alpha <= 0.51)\n\n    @parameterized.expand(TEST_CASES)\n    def test_alpha_single_value(self, im_shape, input_type):\n        data = self.get_data(im_shape, input_type)\n        alpha = 0.01\n        t = RandGibbsNoised(KEYS, 1.0, alpha)\n        _ = t(deepcopy(data))\n        self.assertTrue(0 <= t.rand_gibbs_noise.sampled_alpha <= 0.01)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_grid_distortion.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandGridDistortion\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    seed = 0\n    TESTS.append(\n        [\n            dict(num_cells=2, prob=1.0, distort_limit=0.5, mode=\"nearest\", padding_mode=\"zeros\"),\n            seed,\n            p(np.indices([6, 6]).astype(np.float32)),\n            p(\n                np.array(\n                    [\n                        [\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                            [2.0, 2.0, 2.0, 2.0, 2.0, 0.0],\n                            [4.0, 4.0, 4.0, 4.0, 4.0, 0.0],\n                            [4.0, 4.0, 4.0, 4.0, 4.0, 0.0],\n                            [5.0, 5.0, 5.0, 5.0, 5.0, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                        ],\n                        [\n                            [0.0, 1.0, 3.0, 3.0, 4.0, 0.0],\n                            [0.0, 1.0, 3.0, 3.0, 4.0, 0.0],\n                            [0.0, 1.0, 3.0, 3.0, 4.0, 0.0],\n                            [0.0, 1.0, 3.0, 3.0, 4.0, 0.0],\n                            [0.0, 1.0, 3.0, 3.0, 4.0, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                 
       ],\n                    ]\n                ).astype(np.float32)\n            ),\n        ]\n    )\n    seed = 1\n    TESTS.append(\n        [\n            dict(num_cells=(2, 2), prob=1.0, distort_limit=0.1, mode=\"bilinear\", padding_mode=\"reflection\"),\n            seed,\n            p(np.indices([6, 6]).astype(np.float32)),\n            p(\n                np.array(\n                    [\n                        [\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                            [1.5660975, 1.5660975, 1.5660975, 1.5660975, 1.5660974, 1.5660975],\n                            [3.132195, 3.132195, 3.132195, 3.132195, 3.132195, 3.132195],\n                            [3.132195, 3.132195, 3.132195, 3.132195, 3.132195, 3.132195],\n                            [4.482229, 4.482229, 4.482229, 4.482229, 4.482229, 4.482229],\n                            [5.0, 5.0, 5.0, 5.0, 5.0, 5.0],\n                        ],\n                        [\n                            [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 5.0],\n                            [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 5.0],\n                            [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 5.0],\n                            [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 5.0],\n                            [0.0, 1.3940268, 2.7880535, 2.7880535, 4.1657553, 5.0],\n                            [0.0, 1.3940266, 2.7880538, 2.7880538, 4.1657557, 5.0],\n                        ],\n                    ]\n                ).astype(np.float32)\n            ),\n        ]\n    )\n\n\nclass TestRandGridDistortion(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_rand_grid_distortion(self, input_param, seed, input_data, expected_val):\n        g = RandGridDistortion(**input_param)\n        g.set_random_state(seed=seed)\n        result = g(input_data)\n        if input_param[\"padding_mode\"] != \"reflection\":\n            
assert_allclose(result, expected_val, type_test=\"tensor\", rtol=1e-4, atol=1e-4)\n        else:\n            assert_allclose(result.shape, expected_val.shape, type_test=False, rtol=1e-4, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_grid_distortiond.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandGridDistortiond\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTESTS = []\nnum_cells = 2\nseed = 0\nfor p in TEST_NDARRAYS_ALL:\n    img = np.indices([6, 6]).astype(np.float32)\n    TESTS.append(\n        [\n            dict(\n                keys=[\"img\", \"mask\"],\n                num_cells=num_cells,\n                prob=1.0,\n                distort_limit=(-0.1, 0.1),\n                mode=[\"bilinear\", \"nearest\"],\n                padding_mode=\"zeros\",\n            ),\n            seed,\n            {\"img\": p(img), \"mask\": p(np.ones_like(img[:1]))},\n            p(\n                np.array(\n                    [\n                        [\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                            [1.5645568, 1.5645568, 1.5645568, 1.5645568, 1.5645568, 0.0],\n                            [3.1291137, 3.1291137, 3.1291137, 3.1291137, 3.1291137, 0.0],\n                            [3.1291137, 3.1291137, 3.1291137, 3.1291137, 3.1291137, 0.0],\n                            [4.6599426, 4.6599426, 4.6599426, 4.6599426, 4.6599426, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                        ],\n                        [\n             
               [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0],\n                            [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0],\n                            [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0],\n                            [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0],\n                            [0.0, 1.4770963, 2.9541926, 2.9541926, 4.497961, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                        ],\n                    ]\n                ).astype(np.float32)\n            ),\n            p(\n                np.array(\n                    [\n                        [\n                            [1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n                            [1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n                            [1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n                            [1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n                            [1.0, 1.0, 1.0, 1.0, 1.0, 0.0],\n                            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n                        ]\n                    ]\n                )\n            ),\n        ]\n    )\n\n\nclass TestRandGridDistortiond(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_rand_grid_distortiond(self, input_param, seed, input_data, expected_val_img, expected_val_mask):\n        g = RandGridDistortiond(**input_param)\n        g.set_random_state(seed=seed)\n        result = g(input_data)\n        assert_allclose(result[\"img\"], expected_val_img, type_test=False, rtol=1e-4, atol=1e-4)\n        assert_allclose(result[\"mask\"], expected_val_mask, type_test=False, rtol=1e-4, atol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_histogram_shift.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandHistogramShift\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"num_control_points\": 5, \"prob\": 0.0},\n            {\"img\": p(np.arange(8).reshape((1, 2, 2, 2)))},\n            np.arange(8).reshape((1, 2, 2, 2)),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"num_control_points\": 5, \"prob\": 0.9},\n            {\"img\": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32))},\n            np.array([[[[0.0, 0.57227867], [1.1391707, 1.68990281]], [[2.75833219, 4.34445884], [5.70913743, 7.0]]]]),\n        ]\n    )\n    TESTS.append(\n        [\n            {\"num_control_points\": (5, 20), \"prob\": 0.9},\n            {\"img\": p(np.arange(8).reshape((1, 2, 2, 2)).astype(np.float32))},\n            np.array([[[[0.0, 1.17472492], [2.21553091, 2.88292011]], [[3.98407301, 5.01302123], [6.09275004, 7.0]]]]),\n        ]\n    )\n\nWARN_TESTS = []\nfor p in TEST_NDARRAYS:\n    WARN_TESTS.append(\n        [\n            {\"num_control_points\": 5, \"prob\": 1.0},\n            {\"img\": p(np.zeros(8).reshape((1, 2, 2, 2)))},\n            np.zeros(8).reshape((1, 2, 2, 2)),\n        ]\n    )\n\n\nclass 
TestRandHistogramShift(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_rand_histogram_shift(self, input_param, input_data, expected_val):\n        g = RandHistogramShift(**input_param)\n        g.set_random_state(123)\n        result = g(**input_data)\n        assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4, type_test=\"tensor\")\n\n    def test_interp(self):\n        tr = RandHistogramShift()\n        for array_type in (torch.tensor, np.array):\n            x = array_type([0.0, 4.0, 6.0, 10.0])\n            y = array_type([1.0, -1.0, 3.0, 5.0])\n\n            yi = tr.interp(array_type([0, 2, 4, 8, 10]), x, y)\n            self.assertEqual(yi.shape, (5,))\n            assert_allclose(yi, array_type([1.0, 0.0, -1.0, 4.0, 5.0]))\n\n            yi = tr.interp(array_type([-1, 11, 10.001, -0.001]), x, y)\n            self.assertEqual(yi.shape, (4,))\n            assert_allclose(yi, array_type([1.0, 5.0, 5.0, 1.0]))\n\n            yi = tr.interp(array_type([[-2, 11], [1, 3], [8, 10]]), x, y)\n            self.assertEqual(yi.shape, (3, 2))\n            assert_allclose(yi, array_type([[1.0, 5.0], [0.5, -0.5], [4.0, 5.0]]))\n\n    @parameterized.expand(WARN_TESTS)\n    def test_warn(self, input_param, input_data, expected_val):\n        with self.assertWarns(Warning):\n            result = RandHistogramShift(**input_param)(**input_data)\n            assert_allclose(result, expected_val, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_k_space_spike_noise.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nfrom parameterized import parameterized\n\nfrom monai.data.synthetic import create_test_image_2d, create_test_image_3d\nfrom monai.transforms import KSpaceSpikeNoise, RandKSpaceSpikeNoise\nfrom monai.utils.misc import set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor shape in ((128, 64), (64, 48, 80)):\n    for p in TEST_NDARRAYS:\n        for channel_wise in (True, False):\n            TESTS.append((shape, p, channel_wise))\n\n\nclass TestRandKSpaceSpikeNoise(unittest.TestCase):\n    def setUp(self):\n        set_determinism(0)\n        super().setUp()\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @staticmethod\n    def get_data(im_shape, im_type):\n        create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d\n        im = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None]\n        return im_type(im)\n\n    @parameterized.expand(TESTS)\n    def test_0_prob(self, im_shape, im_type, channel_wise):\n        im = self.get_data(im_shape, im_type)\n        intensity_range = [14, 15]\n        t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise)\n        out = t(im)\n        assert_allclose(out, im, type_test=\"tensor\")\n\n    
@parameterized.expand(TESTS)\n    def test_1_prob(self, im_shape, im_type, channel_wise):\n        im = self.get_data(im_shape, im_type)\n        intensity_range = [14, 14]\n        t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise)\n        out = t(im)\n        base_t = KSpaceSpikeNoise(t.sampled_locs, [14])\n        out = out - base_t(im)\n        assert_allclose(out, im * 0, type_test=\"tensor\")\n\n    @parameterized.expand(TESTS)\n    def test_same_result(self, im_shape, im_type, channel_wise):\n        im = self.get_data(im_shape, im_type)\n        intensity_range = [14, 15]\n        t = RandKSpaceSpikeNoise(0.0, intensity_range, channel_wise)\n        t.set_random_state(42)\n        out1 = t(deepcopy(im))\n        t.set_random_state(42)\n        out2 = t(deepcopy(im))\n        assert_allclose(out1, out2, type_test=\"tensor\")\n\n    @parameterized.expand(TESTS)\n    def test_intensity(self, im_shape, im_type, channel_wise):\n        im = self.get_data(im_shape, im_type)\n        intensity_range = [14, 14.1]\n        t = RandKSpaceSpikeNoise(1.0, intensity_range, channel_wise)\n        _ = t(deepcopy(im))\n        self.assertGreaterEqual(t.sampled_k_intensity[0], 14)\n        self.assertLessEqual(t.sampled_k_intensity[0], 14.1)\n\n    @parameterized.expand(TESTS)\n    def test_default_intensity(self, im_shape, im_type, channel_wise):\n        im = self.get_data(im_shape, im_type)\n        t = RandKSpaceSpikeNoise(1.0, intensity_range=None, channel_wise=channel_wise)\n        out = t(deepcopy(im))\n        self.assertEqual(out.shape, im.shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_k_space_spike_noised.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nfrom parameterized import parameterized\n\nfrom monai.data.synthetic import create_test_image_2d, create_test_image_3d\nfrom monai.transforms import RandKSpaceSpikeNoised\nfrom monai.utils.misc import set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor shape in ((128, 64), (64, 48, 80)):\n    for p in TEST_NDARRAYS:\n        TESTS.append((shape, p))\n\nKEYS = [\"image\", \"label\"]\n\n\nclass TestKSpaceSpikeNoised(unittest.TestCase):\n    def setUp(self):\n        set_determinism(0)\n        super().setUp()\n\n    def tearDown(self):\n        set_determinism(None)\n\n    @staticmethod\n    def get_data(im_shape, im_type):\n        create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d\n        ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)\n        ims = [im_type(im[None]) for im in ims]\n        return dict(zip(KEYS, ims))\n\n    @parameterized.expand(TESTS)\n    def test_same_result(self, im_shape, im_type):\n        data = self.get_data(im_shape, im_type)\n\n        t = RandKSpaceSpikeNoised(KEYS, prob=1.0, intensity_range=(13, 15), channel_wise=True)\n        t.set_random_state(42)\n        out1 = t(deepcopy(data))\n\n        t.set_random_state(42)\n        out2 = 
t(deepcopy(data))\n\n        for k in KEYS:\n            assert_allclose(out1[k], out2[k], atol=1e-10, type_test=\"tensor\")\n\n    @parameterized.expand(TESTS)\n    def test_0_prob(self, im_shape, im_type):\n        data = self.get_data(im_shape, im_type)\n\n        t1 = RandKSpaceSpikeNoised(KEYS, prob=0.0, intensity_range=(13, 15), channel_wise=True)\n        t2 = RandKSpaceSpikeNoised(KEYS, prob=0.0, intensity_range=(13, 15), channel_wise=True)\n        out1 = t1(data)\n        out2 = t2(data)\n\n        for k in KEYS:\n            assert_allclose(out1[k], data[k], type_test=\"tensor\")\n            assert_allclose(out2[k], data[k], type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_rician_noise.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandRicianNoise\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append((\"test_zero_mean\", p, 0, 0.1))\n    TESTS.append((\"test_non_zero_mean\", p, 1, 0.5))\n\n\nclass TestRandRicianNoise(NumpyImageTestCase2D):\n    @parameterized.expand(TESTS)\n    def test_correct_results(self, _, in_type, mean, std):\n        seed = 0\n        rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std)\n        rician_fn.set_random_state(seed)\n        im = in_type(self.imt)\n        noised = rician_fn(im)\n        if isinstance(im, torch.Tensor):\n            self.assertEqual(im.dtype, noised.dtype)\n        np.random.seed(seed)\n        np.random.random()\n        _std = np.random.uniform(0, std)\n        expected = np.sqrt(\n            (self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2\n            + np.random.normal(mean, _std, size=self.imt.shape) ** 2\n        )\n        if isinstance(noised, torch.Tensor):\n            noised = noised.cpu()\n        np.testing.assert_allclose(expected, noised, atol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_rician_noised.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandRicianNoised\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([\"test_zero_mean\", p, [\"img1\", \"img2\"], 0, 0.1])\n    TESTS.append([\"test_non_zero_mean\", p, [\"img1\", \"img2\"], 1, 0.5])\n\nseed = 0\n\n\nclass TestRandRicianNoisedNumpy(NumpyImageTestCase2D):\n    @parameterized.expand(TESTS)\n    def test_correct_results(self, _, in_type, keys, mean, std):\n        rician_fn = RandRicianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64)\n        rician_fn.set_random_state(seed)\n        noised = rician_fn({k: in_type(self.imt) for k in keys})\n        np.random.seed(seed)\n        for k in keys:\n            # simulate the `randomize` function of transform\n            np.random.random()\n            _std = np.random.uniform(0, std)\n            expected = np.sqrt(\n                (self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2\n                + np.random.normal(mean, _std, size=self.imt.shape) ** 2\n            )\n            if isinstance(noised[k], torch.Tensor):\n                noised[k] = noised[k].cpu()\n            np.testing.assert_allclose(expected, noised[k], atol=1e-5, 
rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_rotate.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport scipy.ndimage\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.config import USE_COMPILED\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import RandRotate\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import (\n    TEST_NDARRAYS_ALL,\n    NumpyImageTestCase2D,\n    NumpyImageTestCase3D,\n    assert_allclose,\n    test_local_inversion,\n)\n\nTEST_CASES_2D: list[tuple] = []\nfor p in TEST_NDARRAYS_ALL:\n    TEST_CASES_2D.append((p, np.pi / 2, True, \"bilinear\", \"border\", False))\n    TEST_CASES_2D.append((p, np.pi / 4, True, \"nearest\", \"border\", False))\n    TEST_CASES_2D.append((p, np.pi, False, \"nearest\", \"zeros\", True))\n    if not USE_COMPILED:\n        TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, \"nearest\", \"zeros\", True))\n\nTEST_CASES_3D: list[tuple] = []\nfor p in TEST_NDARRAYS_ALL:\n    TEST_CASES_3D.append(\n        (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, \"bilinear\", \"border\", False, (1, 81, 110, 112))\n    )\n    TEST_CASES_3D.append(\n        (\n            p,\n            np.pi / 4,\n            (-np.pi / 9, np.pi / 4.5),\n            (np.pi / 9, np.pi / 6),\n            False,\n            \"nearest\",\n            \"border\",\n            True,\n            (1, 
97, 100, 97),\n        )\n    )\n    TEST_CASES_3D.append(\n        (\n            p,\n            0.0,\n            (2 * np.pi, 2.06 * np.pi),\n            (-np.pi / 180, np.pi / 180),\n            True,\n            \"nearest\",\n            \"zeros\",\n            True,\n            (1, 64, 48, 80),\n        )\n    )\n    TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, \"nearest\", \"zeros\", False, (1, 64, 61, 87)))\n\n\nclass TestRandRotate2D(NumpyImageTestCase2D):\n    @parameterized.expand(TEST_CASES_2D)\n    def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners):\n        init_param = {\n            \"range_x\": degrees,\n            \"prob\": 1.0,\n            \"keep_size\": keep_size,\n            \"mode\": mode,\n            \"padding_mode\": padding_mode,\n            \"align_corners\": align_corners,\n            \"dtype\": np.float64,\n        }\n        rotate_fn = RandRotate(**init_param)\n        rotate_fn.set_random_state(243)\n        call_param = {\"img\": im_type(self.imt[0])}\n        rotated = rotate_fn(**call_param)\n\n        # test lazy\n        test_resampler_lazy(rotate_fn, rotated, init_param=init_param, call_param=call_param, seed=243)\n        rotate_fn.lazy = False\n\n        _order = 0 if mode == \"nearest\" else 1\n        if mode == \"border\":\n            _mode = \"nearest\"\n        elif mode == \"reflection\":\n            _mode = \"reflect\"\n        else:\n            _mode = \"constant\"\n        angle = rotate_fn.x\n        expected = scipy.ndimage.rotate(\n            self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False\n        )\n        expected = np.stack(expected).astype(np.float32)\n        rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated\n        good = np.sum(np.isclose(expected, rotated[0], atol=1e-3))\n        self.assertLessEqual(np.abs(good - expected.size), 40, \"diff at most 40 
pixels\")\n\n\n@unittest.skipIf(USE_COMPILED, \"unit tests not for compiled version.\")\nclass TestRandRotate3D(NumpyImageTestCase3D):\n    @parameterized.expand(TEST_CASES_3D)\n    def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected):\n        init_param = {\n            \"range_x\": x,\n            \"range_y\": y,\n            \"range_z\": z,\n            \"prob\": 1.0,\n            \"keep_size\": keep_size,\n            \"mode\": mode,\n            \"padding_mode\": padding_mode,\n            \"align_corners\": align_corners,\n            \"dtype\": np.float64,\n        }\n        rotate_fn = RandRotate(**init_param)\n        rotate_fn.set_random_state(243)\n        im = im_type(self.imt[0])\n        call_param = {\"img\": im}\n        rotated = rotate_fn(**call_param)\n\n        # test lazy\n        test_resampler_lazy(rotate_fn, rotated, init_param=init_param, call_param=call_param, seed=243)\n        rotate_fn.lazy = False\n\n        assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0)\n        test_local_inversion(rotate_fn, rotated, im)\n\n        set_track_meta(False)\n        rotated = rotate_fn(im)\n        self.assertNotIsInstance(rotated, MetaTensor)\n        self.assertIsInstance(rotated, torch.Tensor)\n        set_track_meta(True)\n\n\nclass TestRandRotateDtype(NumpyImageTestCase2D):\n    @parameterized.expand(TEST_CASES_2D)\n    def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners):\n        rotate_fn = RandRotate(\n            range_x=1.0,\n            prob=0.5,\n            keep_size=keep_size,\n            mode=mode,\n            padding_mode=padding_mode,\n            align_corners=align_corners,\n            dtype=np.float64,\n        )\n        im = im_type(self.imt[0])\n        rotated = rotate_fn(im)\n        self.assertEqual(rotated.dtype, torch.float32)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_rotate90.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import RandRotate90\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion\n\n\nclass TestRandRotate90(NumpyImageTestCase2D):\n    def test_default(self):\n        rotate = RandRotate90()\n        for p in TEST_NDARRAYS_ALL:\n            rotate.set_random_state(123)\n            im = p(self.imt[0])\n            call_param = {\"img\": im}\n            rotated = rotate(**call_param)\n            test_local_inversion(rotate, rotated, im)\n            expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=\"tensor\")\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param, seed=123)\n            rotate.lazy = False\n\n    def test_k(self):\n        init_param = {\"max_k\": 2}\n        rotate = RandRotate90(**init_param)\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            set_track_meta(False)\n            rotated = rotate(im)\n            self.assertNotIsInstance(rotated, MetaTensor)\n  
          self.assertIsInstance(rotated, torch.Tensor)\n\n            set_track_meta(True)\n            rotate.set_random_state(123)\n            call_param = {\"img\": im}\n            rotated = rotate(**call_param)\n            test_local_inversion(rotate, rotated, im)\n            expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=\"tensor\")\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param, seed=123)\n            rotate.lazy = False\n\n    def test_spatial_axes(self):\n        rotate = RandRotate90(spatial_axes=(0, 1), prob=1.0)\n        for p in TEST_NDARRAYS_ALL:\n            rotate.set_random_state(1234)\n            im = p(self.imt[0])\n            call_param = {\"img\": im}\n            rotated = rotate(**call_param)\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param, seed=1234)\n            rotate.lazy = False\n\n            self.assertEqual(len(rotated.applied_operations), 1)\n            expected = [np.rot90(channel, rotate._rand_k, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=\"tensor\")\n            test_local_inversion(rotate, rotated, im)\n\n    def test_prob_k_spatial_axes(self):\n        rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1))\n        for p in TEST_NDARRAYS_ALL:\n            rotate.set_random_state(234)\n            im = p(self.imt[0])\n            call_param = {\"img\": im}\n            rotated = rotate(**call_param)\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, im)\n            expected = [np.rot90(channel, 1, (0, 1)) 
for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_rotate90d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import RandRotate90d\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion\n\n\nclass TestRandRotate90d(NumpyImageTestCase2D):\n    def test_default(self):\n        key = \"test\"\n        rotate = RandRotate90d(keys=key)\n        for p in TEST_NDARRAYS_ALL:\n            rotate.set_random_state(1323)\n            im = {key: p(self.imt[0])}\n            call_param = {\"data\": im}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param, seed=1323, output_key=key)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, im, key)\n            expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated[key], p(expected), type_test=\"tensor\")\n\n            set_track_meta(False)\n            rotated = rotate(im)[key]\n            self.assertNotIsInstance(rotated, MetaTensor)\n            self.assertIsInstance(rotated, torch.Tensor)\n            set_track_meta(True)\n\n    def test_k(self):\n        
key = \"test\"\n        rotate = RandRotate90d(keys=key, max_k=2)\n        for p in TEST_NDARRAYS_ALL:\n            rotate.set_random_state(234)\n            im = {key: p(self.imt[0])}\n            call_param = {\"data\": im}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234, output_key=key)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, im, key)\n            expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated[key], p(expected), type_test=\"tensor\")\n\n    def test_spatial_axes(self):\n        key = \"test\"\n        rotate = RandRotate90d(keys=key, spatial_axes=(0, 1))\n        for p in TEST_NDARRAYS_ALL:\n            rotate.set_random_state(234)\n            im = {key: p(self.imt[0])}\n            call_param = {\"data\": im}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234, output_key=key)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, im, key)\n            expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated[key], p(expected), type_test=\"tensor\")\n\n    def test_prob_k_spatial_axes(self):\n        key = \"test\"\n        rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, spatial_axes=(0, 1))\n        for p in TEST_NDARRAYS_ALL:\n            rotate.set_random_state(234)\n            im = {key: p(self.imt[0])}\n            call_param = {\"data\": im}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param, seed=234, output_key=key)\n            rotate.lazy = False\n\n            expected = 
[np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated[key], p(expected), type_test=\"tensor\")\n            test_local_inversion(rotate, rotated, im, key)\n\n    def test_no_key(self):\n        key = \"unknown\"\n        rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, spatial_axes=(0, 1))\n        with self.assertRaisesRegex(KeyError, \"\"):\n            rotate({\"test\": self.imt[0]})\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_rotated.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport scipy.ndimage\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.config import USE_COMPILED\nfrom monai.transforms import RandRotated\nfrom monai.utils import GridSampleMode, GridSamplePadMode\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion\n\nTEST_CASES_2D: list[tuple] = []\nfor p in TEST_NDARRAYS_ALL:\n    TEST_CASES_2D.append((p, np.pi / 2, True, \"bilinear\", \"border\", False))\n    TEST_CASES_2D.append((p, np.pi / 4, True, \"nearest\", \"border\", False))\n    TEST_CASES_2D.append((p, np.pi, False, \"nearest\", \"zeros\", True))\n    if not USE_COMPILED:\n        TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, \"nearest\", \"zeros\", True))\n\nTEST_CASES_3D: list[tuple] = []\nfor p in TEST_NDARRAYS_ALL:\n    TEST_CASES_3D.append(\n        (p, np.pi / 2, -np.pi / 6, (0.0, np.pi), False, \"bilinear\", \"border\", False, (1, 81, 110, 112))\n    )\n    TEST_CASES_3D.append(\n        (\n            p,\n            np.pi / 2,\n            -np.pi / 6,\n            (0.0, np.pi),\n            False,\n            GridSampleMode.NEAREST,\n            GridSamplePadMode.BORDER,\n            False,\n            (1, 81, 110, 112),\n        )\n    )\n    
TEST_CASES_3D.append(\n        (\n            p,\n            np.pi / 4,\n            (-np.pi / 9, np.pi / 4.5),\n            (np.pi / 9, np.pi / 6),\n            False,\n            \"nearest\",\n            \"border\",\n            True,\n            (1, 97, 100, 97),\n        )\n    )\n    TEST_CASES_3D.append(\n        (\n            p,\n            np.pi / 4,\n            (-np.pi / 9, np.pi / 4.5),\n            (np.pi / 9, np.pi / 6),\n            False,\n            GridSampleMode.NEAREST,\n            GridSamplePadMode.BORDER,\n            True,\n            (1, 97, 100, 97),\n        )\n    )\n    TEST_CASES_3D.append(\n        (\n            p,\n            0.0,\n            (2 * np.pi, 2.06 * np.pi),\n            (-np.pi / 180, np.pi / 180),\n            True,\n            \"nearest\",\n            \"zeros\",\n            True,\n            (1, 64, 48, 80),\n        )\n    )\n    TEST_CASES_3D.append(\n        (\n            p,\n            0.0,\n            (2 * np.pi, 2.06 * np.pi),\n            (-np.pi / 180, np.pi / 180),\n            True,\n            GridSampleMode.NEAREST,\n            GridSamplePadMode.ZEROS,\n            True,\n            (1, 64, 48, 80),\n        )\n    )\n    TEST_CASES_3D.append((p, (-np.pi / 4, 0), 0, 0, False, \"nearest\", \"zeros\", False, (1, 64, 61, 87)))\n    TEST_CASES_3D.append(\n        (p, (-np.pi / 4, 0), 0, 0, False, GridSampleMode.NEAREST, GridSamplePadMode.ZEROS, False, (1, 64, 61, 87))\n    )\n\n\nclass TestRandRotated2D(NumpyImageTestCase2D):\n    @parameterized.expand(TEST_CASES_2D)\n    def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners):\n        init_param = {\n            \"keys\": \"img\",\n            \"range_x\": degrees,\n            \"prob\": 1.0,\n            \"keep_size\": keep_size,\n            \"mode\": mode,\n            \"padding_mode\": padding_mode,\n            \"align_corners\": align_corners,\n            \"dtype\": np.float64,\n        }\n      
  rotate_fn = RandRotated(**init_param)\n        im = im_type(self.imt[0])\n        rotate_fn.set_random_state(243)\n        call_param = {\"data\": {\"img\": im, \"seg\": im_type(self.segn[0])}}\n        rotated = rotate_fn(**call_param)\n\n        # test lazy\n        test_resampler_lazy(\n            rotate_fn, rotated, init_param=init_param, call_param=call_param, seed=243, output_key=\"img\"\n        )\n\n        _order = 0 if mode == \"nearest\" else 1\n        if padding_mode == \"border\":\n            _mode = \"nearest\"\n        elif padding_mode == \"reflection\":\n            _mode = \"reflect\"\n        else:\n            _mode = \"constant\"\n        angle = rotate_fn.rand_rotate.x\n        expected = scipy.ndimage.rotate(\n            self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False\n        )\n        test_local_inversion(rotate_fn, rotated, {\"img\": im}, \"img\")\n        for k, v in rotated.items():\n            rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v\n        expected = np.stack(expected).astype(np.float32)\n        good = np.sum(np.isclose(expected, rotated[\"img\"][0], atol=1e-3))\n        self.assertLessEqual(np.abs(good - expected.size), 40, \"diff at most 40 pixels\")\n\n\n@unittest.skipIf(USE_COMPILED, \"unit tests not for compiled version.\")\nclass TestRandRotated3D(NumpyImageTestCase3D):\n    @parameterized.expand(TEST_CASES_3D)\n    def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected):\n        init_param = {\n            \"keys\": (\"img\", \"seg\"),\n            \"range_x\": x,\n            \"range_y\": y,\n            \"range_z\": z,\n            \"prob\": 1.0,\n            \"keep_size\": keep_size,\n            \"mode\": mode,\n            \"padding_mode\": padding_mode,\n            \"align_corners\": align_corners,\n            \"dtype\": np.float64,\n        }\n        rotate_fn = RandRotated(**init_param)\n  
      rotate_fn.set_random_state(243)\n        call_param = {\"data\": {\"img\": im_type(self.imt[0]), \"seg\": im_type(self.segn[0])}}\n        rotated = rotate_fn(**call_param)\n\n        # test lazy\n        test_resampler_lazy(\n            rotate_fn, rotated, init_param=init_param, call_param=call_param, seed=243, output_key=\"img\"\n        )\n        np.testing.assert_allclose(rotated[\"img\"].shape, expected)\n\n        rotate_fn.prob = 0.0\n        rotated = rotate_fn({\"img\": im_type(self.imt[0]), \"seg\": im_type(self.segn[0])})\n        self.assertEqual(rotated[\"seg\"].dtype, torch.float32)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_scale_crop.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandScaleCrop\nfrom tests.croppers import CropTest\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTEST_SHAPES = [\n    [{\"roi_scale\": [1.0, 1.0, -1.0], \"random_center\": True}, (3, 3, 3, 4), (3, 3, 3, 4)],\n    [{\"roi_scale\": [1.0, 1.0, 1.0], \"random_center\": False}, (3, 3, 3, 3), (3, 3, 3, 3)],\n]\n\nTEST_VALUES = [\n    [\n        {\"roi_scale\": [0.6, 0.6], \"random_center\": False},\n        np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),\n    ]\n]\n\nTEST_RANDOM_SHAPES = [\n    [\n        {\"roi_scale\": [0.75, 0.6, 0.5], \"max_roi_scale\": [1.0, -1.0, 0.6], \"random_center\": True, \"random_size\": True},\n        (1, 4, 5, 6),\n        (1, 3, 4, 3),\n    ],\n    [{\"roi_scale\": 0.6, \"max_roi_scale\": 0.8, \"random_center\": True, \"random_size\": True}, (1, 4, 5, 6), (1, 3, 4, 4)],\n    [{\"roi_scale\": 0.2, \"max_roi_scale\": 0.8, \"random_center\": True, \"random_size\": True}, (1, 4, 5, 6), (1, 3, 2, 4)],\n]\n\n\nclass TestRandScaleCrop(CropTest):\n    Cropper = RandScaleCrop\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        self.crop_test(input_param, input_shape, 
expected_shape)\n\n    @parameterized.expand(TEST_VALUES)\n    def test_value(self, input_param, input_data):\n        for im_type in TEST_NDARRAYS_ALL:\n            with self.subTest(im_type=im_type):\n                cropper = RandScaleCrop(**input_param)\n                result = cropper(im_type(input_data))\n                roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size]\n                assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=\"tensor\")\n\n    @parameterized.expand(TEST_RANDOM_SHAPES)\n    def test_random_shape(self, input_param, input_shape, expected_shape):\n        for im_type in TEST_NDARRAYS_ALL:\n            with self.subTest(im_type=im_type):\n                cropper = RandScaleCrop(**input_param)\n                cropper.set_random_state(seed=123)\n                input_data = im_type(np.random.randint(0, 2, input_shape))\n                result = cropper(input_data)\n                self.assertTupleEqual(result.shape, expected_shape)\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_pending_ops(self, input_param, input_shape, _):\n        self.crop_test_pending_ops(input_param, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_scale_cropd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandScaleCropd\nfrom tests.croppers import CropTest\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTEST_SHAPES = [\n    [{\"keys\": \"img\", \"roi_scale\": [1.0, 1.0, -1.0], \"random_center\": True}, (3, 3, 3, 4), (3, 3, 3, 4)],\n    [\n        # test `allow_missing_keys` with key \"label\"\n        {\"keys\": [\"label\", \"img\"], \"roi_scale\": [1.0, 1.0, 1.0], \"random_center\": False, \"allow_missing_keys\": True},\n        (3, 3, 3, 3),\n        (3, 3, 3, 3),\n    ],\n]\n\nTEST_VALUES = [\n    [\n        {\"keys\": \"img\", \"roi_scale\": [0.6, 0.6], \"random_center\": False},\n        np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),\n    ]\n]\n\nTEST_RANDOM_SHAPES = [\n    [\n        {\n            \"keys\": \"img\",\n            \"roi_scale\": [0.75, 0.6, 0.5],\n            \"max_roi_scale\": [1.0, -1.0, 0.6],\n            \"random_center\": True,\n            \"random_size\": True,\n        },\n        (1, 4, 5, 6),\n        (1, 3, 4, 3),\n    ],\n    [\n        {\"keys\": \"img\", \"roi_scale\": 0.6, \"max_roi_scale\": 0.8, \"random_center\": True, \"random_size\": True},\n        (1, 4, 5, 6),\n        (1, 3, 4, 4),\n    ],\n    [\n        
{\"keys\": \"img\", \"roi_scale\": 0.2, \"max_roi_scale\": 0.8, \"random_center\": True, \"random_size\": True},\n        (1, 4, 5, 6),\n        (1, 3, 2, 4),\n    ],\n]\n\n\nclass TestRandScaleCropd(CropTest):\n    Cropper = RandScaleCropd\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        self.crop_test(input_param, input_shape, expected_shape)\n\n    @parameterized.expand(TEST_VALUES)\n    def test_value(self, input_param, input_im):\n        for im_type in TEST_NDARRAYS_ALL:\n            with self.subTest(im_type=im_type):\n                cropper = self.Cropper(**input_param)\n                input_data = {\"img\": im_type(input_im)}\n                result = cropper(input_data)[\"img\"]\n                roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper.cropper._size]\n                assert_allclose(result, input_im[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=\"tensor\")\n\n    @parameterized.expand(TEST_RANDOM_SHAPES)\n    def test_random_shape(self, input_param, input_shape, expected_shape):\n        for im_type in TEST_NDARRAYS_ALL:\n            with self.subTest(im_type=im_type):\n                cropper = self.Cropper(**input_param)\n                cropper.set_random_state(seed=123)\n                input_data = {\"img\": im_type(np.random.randint(0, 2, input_shape))}\n                result = cropper(input_data)[\"img\"]\n                self.assertTupleEqual(result.shape, expected_shape)\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_pending_ops(self, input_param, input_shape, _):\n        self.crop_test_pending_ops(input_param, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_scale_intensity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandScaleIntensity\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestRandScaleIntensity(NumpyImageTestCase2D):\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_value(self, p):\n        scaler = RandScaleIntensity(factors=0.5, prob=1.0)\n        scaler.set_random_state(seed=0)\n        im = p(self.imt)\n        result = scaler(im)\n        np.random.seed(0)\n        # simulate the randomize() of transform\n        np.random.random()\n        expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32))\n        assert_allclose(result, p(expected), rtol=1e-7, atol=0, type_test=\"tensor\")\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_channel_wise(self, p):\n        scaler = RandScaleIntensity(factors=0.5, channel_wise=True, prob=1.0)\n        scaler.set_random_state(seed=0)\n        im = p(self.imt)\n        result = scaler(im)\n        np.random.seed(0)\n        # simulate the randomize() of transform\n        np.random.random()\n        channel_num = self.imt.shape[0]\n        factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)]\n        expected = p(\n            
np.stack([np.asarray((self.imt[i]) * (1 + factor[i])) for i in range(channel_num)]).astype(np.float32)\n        )\n        assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_scale_intensity_fixed_mean.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandScaleIntensityFixedMean\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestRandScaleIntensity(NumpyImageTestCase2D):\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_value(self, p):\n        scaler = RandScaleIntensityFixedMean(prob=1.0, factors=0.5)\n        scaler.set_random_state(seed=0)\n        im = p(self.imt)\n        result = scaler(im)\n        np.random.seed(0)\n        # simulate the randomize() of transform\n        np.random.random()\n        mn = im.mean()\n        im = im - mn\n        expected = (1 + np.random.uniform(low=-0.5, high=0.5)) * im\n        expected = expected + mn\n        assert_allclose(result, expected, type_test=\"tensor\", atol=1e-7)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_scale_intensity_fixed_meand.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms import RandScaleIntensityFixedMeand\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestRandScaleIntensityFixedMeand(NumpyImageTestCase2D):\n    def test_value(self):\n        key = \"img\"\n        for p in TEST_NDARRAYS:\n            scaler = RandScaleIntensityFixedMeand(keys=[key], factors=0.5, prob=1.0)\n            scaler.set_random_state(seed=0)\n            result = scaler({key: p(self.imt)})\n            np.random.seed(0)\n            # simulate the randomize function of transform\n            np.random.random()\n            im = self.imt\n            mn = im.mean()\n            im = im - mn\n            expected = (1 + np.random.uniform(low=-0.5, high=0.5)) * im\n            expected = expected + mn\n            assert_allclose(result[key], p(expected), type_test=\"tensor\", atol=1e-6)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_scale_intensityd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms import RandScaleIntensityd\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestRandScaleIntensityd(NumpyImageTestCase2D):\n    def test_value(self):\n        key = \"img\"\n        for p in TEST_NDARRAYS:\n            scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0)\n            scaler.set_random_state(seed=0)\n            result = scaler({key: p(self.imt)})\n            np.random.seed(0)\n            # simulate the randomize function of transform\n            np.random.random()\n            expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)\n            assert_allclose(result[key], p(expected), type_test=\"tensor\")\n\n    def test_channel_wise(self):\n        key = \"img\"\n        for p in TEST_NDARRAYS:\n            scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0, channel_wise=True)\n            scaler.set_random_state(seed=0)\n            result = scaler({key: p(self.imt)})\n            np.random.seed(0)\n            # simulate the randomize function of transform\n            np.random.random()\n            channel_num = self.imt.shape[0]\n            factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)]\n            expected = p(\n    
            np.stack([np.asarray((self.imt[i]) * (1 + factor[i])) for i in range(channel_num)]).astype(np.float32)\n            )\n            assert_allclose(result[key], p(expected), type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_shift_intensity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandShiftIntensity\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestRandShiftIntensity(NumpyImageTestCase2D):\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_value(self, p):\n        shifter = RandShiftIntensity(offsets=1.0, prob=1.0)\n        shifter.set_random_state(seed=0)\n        im = p(self.imt)\n        result = shifter(im, factor=1.0)\n        np.random.seed(0)\n        # simulate the randomize() of transform\n        np.random.random()\n        expected = self.imt + np.random.uniform(low=-1.0, high=1.0)\n        assert_allclose(result, expected, type_test=\"tensor\")\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_channel_wise(self, p):\n        scaler = RandShiftIntensity(offsets=3.0, channel_wise=True, prob=1.0)\n        scaler.set_random_state(seed=0)\n        im = p(self.imt)\n        result = scaler(im)\n        np.random.seed(0)\n        # simulate the randomize() of transform\n        np.random.random()\n        channel_num = self.imt.shape[0]\n        factor = [np.random.uniform(low=-3.0, high=3.0) for _ in range(channel_num)]\n        expected = p(np.stack([np.asarray((self.imt[i]) + factor[i]) for i in 
range(channel_num)]).astype(np.float32))\n        assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_shift_intensityd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms import IntensityStatsd, RandShiftIntensityd\nfrom monai.utils.enums import PostFix\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestRandShiftIntensityd(NumpyImageTestCase2D):\n    def test_value(self):\n        key = \"img\"\n        for p in TEST_NDARRAYS:\n            shifter = RandShiftIntensityd(keys=[key], offsets=1.0, prob=1.0)\n            shifter.set_random_state(seed=0)\n            result = shifter({key: p(self.imt)})\n            np.random.seed(0)\n            # simulate the randomize() of transform\n            np.random.random()\n            expected = self.imt + np.random.uniform(low=-1.0, high=1.0)\n            assert_allclose(result[key], p(expected), type_test=\"tensor\")\n\n    def test_factor(self):\n        key = \"img\"\n        stats = IntensityStatsd(keys=key, ops=\"max\", key_prefix=\"orig\")\n        shifter = RandShiftIntensityd(keys=[key], offsets=1.0, factor_key=[\"orig_max\"], prob=1.0)\n        data = {key: self.imt, PostFix.meta(key): {\"affine\": None}}\n        shifter.set_random_state(seed=0)\n        result = shifter(stats(data))\n        np.random.seed(0)\n        # simulate the randomize() of transform\n        np.random.random()\n        expected = self.imt + np.random.uniform(low=-1.0, 
high=1.0) * np.nanmax(self.imt)\n        np.testing.assert_allclose(result[key], expected)\n\n    def test_channel_wise(self):\n        key = \"img\"\n        for p in TEST_NDARRAYS:\n            scaler = RandShiftIntensityd(keys=[key], offsets=3.0, prob=1.0, channel_wise=True)\n            scaler.set_random_state(seed=0)\n            result = scaler({key: p(self.imt)})\n            np.random.seed(0)\n            # simulate the randomize function of transform\n            np.random.random()\n            channel_num = self.imt.shape[0]\n            factor = [np.random.uniform(low=-3.0, high=3.0) for _ in range(channel_num)]\n            expected = p(\n                np.stack([np.asarray((self.imt[i]) + factor[i]) for i in range(channel_num)]).astype(np.float32)\n            )\n            assert_allclose(result[key], p(expected), type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_simulate_low_resolution.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandSimulateLowResolution\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            dict(prob=1.0, zoom_range=(0.8, 0.81)),\n            p(\n                np.array(\n                    [\n                        [\n                            [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]],\n                            [[16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]],\n                            [[32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]],\n                            [[48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59], [60, 61, 62, 63]],\n                        ]\n                    ]\n                )\n            ),\n            np.array(\n                [\n                    [\n                        [\n                            [0.0000, 0.6250, 1.3750, 2.0000],\n                            [2.5000, 3.1250, 3.8750, 4.5000],\n                            [5.5000, 6.1250, 6.8750, 7.5000],\n                            [8.0000, 8.6250, 9.3750, 10.0000],\n                        ],\n                        [\n                            [10.0000, 10.6250, 11.3750, 
12.0000],\n                            [12.5000, 13.1250, 13.8750, 14.5000],\n                            [15.5000, 16.1250, 16.8750, 17.5000],\n                            [18.0000, 18.6250, 19.3750, 20.0000],\n                        ],\n                        [\n                            [22.0000, 22.6250, 23.3750, 24.0000],\n                            [24.5000, 25.1250, 25.8750, 26.5000],\n                            [27.5000, 28.1250, 28.8750, 29.5000],\n                            [30.0000, 30.6250, 31.3750, 32.0000],\n                        ],\n                        [\n                            [32.0000, 32.6250, 33.3750, 34.0000],\n                            [34.5000, 35.1250, 35.8750, 36.5000],\n                            [37.5000, 38.1250, 38.8750, 39.5000],\n                            [40.0000, 40.6250, 41.3750, 42.0000],\n                        ],\n                    ]\n                ]\n            ),\n        ]\n    )\n\n\nclass TestRandGaussianSmooth(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        randsimlowres = RandSimulateLowResolution(**arguments)\n        randsimlowres.set_random_state(seed=0)\n        result = randsimlowres(image)\n        assert_allclose(result, expected_data, rtol=1e-4, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_simulate_low_resolutiond.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandSimulateLowResolutiond\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            dict(keys=[\"img\", \"seg\"], prob=1.0, zoom_range=(0.8, 0.81)),\n            {\"img\": p(np.arange(64).reshape(1, 4, 4, 4)), \"seg\": p(np.arange(64).reshape(1, 4, 4, 4))},\n            np.array(\n                [\n                    [\n                        [\n                            [0.0000, 0.6250, 1.3750, 2.0000],\n                            [2.5000, 3.1250, 3.8750, 4.5000],\n                            [5.5000, 6.1250, 6.8750, 7.5000],\n                            [8.0000, 8.6250, 9.3750, 10.0000],\n                        ],\n                        [\n                            [10.0000, 10.6250, 11.3750, 12.0000],\n                            [12.5000, 13.1250, 13.8750, 14.5000],\n                            [15.5000, 16.1250, 16.8750, 17.5000],\n                            [18.0000, 18.6250, 19.3750, 20.0000],\n                        ],\n                        [\n                            [22.0000, 22.6250, 23.3750, 24.0000],\n                            [24.5000, 25.1250, 25.8750, 26.5000],\n                            [27.5000, 
28.1250, 28.8750, 29.5000],\n                            [30.0000, 30.6250, 31.3750, 32.0000],\n                        ],\n                        [\n                            [32.0000, 32.6250, 33.3750, 34.0000],\n                            [34.5000, 35.1250, 35.8750, 36.5000],\n                            [37.5000, 38.1250, 38.8750, 39.5000],\n                            [40.0000, 40.6250, 41.3750, 42.0000],\n                        ],\n                    ]\n                ]\n            ),\n        ]\n    )\n\n\nclass TestRandGaussianSmoothd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, arguments, image, expected_data):\n        converter = RandSimulateLowResolutiond(**arguments)\n        converter.set_random_state(seed=0)\n        result = converter(image)\n        assert_allclose(result[\"img\"], expected_data, rtol=1e-4, type_test=False)\n        assert_allclose(result[\"seg\"], expected_data, rtol=1e-4, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_spatial_crop.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import RandScaleCrop, RandSpatialCrop\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.croppers import CropTest\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTEST_SHAPES = [\n    [{\"roi_size\": [3, 3, -1], \"random_center\": True}, (3, 3, 3, 4), (3, 3, 3, 4)],\n    [{\"roi_size\": [3, 3, 3], \"random_center\": True}, (3, 3, 3, 3), (3, 3, 3, 3)],\n    [{\"roi_size\": [3, 3, 3], \"random_center\": False}, (3, 3, 3, 3), (3, 3, 3, 3)],\n    [{\"roi_size\": [3, 3, 2], \"random_center\": False, \"random_size\": False}, (3, 3, 3, 3), (3, 3, 3, 2)],\n]\n\nTEST_VALUES = [\n    [\n        {\"roi_size\": [3, 3], \"random_center\": False},\n        np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),\n    ]\n]\n\nTEST_RANDOM_SHAPES = [\n    [\n        {\"roi_size\": [3, 3, 3], \"max_roi_size\": [5, -1, 4], \"random_center\": True, \"random_size\": True},\n        (1, 4, 5, 6),\n        (1, 4, 4, 3),\n    ],\n    [{\"roi_size\": 3, \"max_roi_size\": 4, \"random_center\": True, \"random_size\": True}, (1, 4, 5, 6), (1, 3, 4, 3)],\n]\n\nfunc1 = {RandSpatialCrop: {\"roi_size\": [8, 7, -1], \"random_center\": 
True, \"random_size\": False}}\nfunc2 = {RandScaleCrop: {\"roi_scale\": [0.5, 0.6, -1.0], \"random_center\": True, \"random_size\": True}}\nfunc3 = {RandScaleCrop: {\"roi_scale\": [1.0, 0.5, -1.0], \"random_center\": False, \"random_size\": False}}\n\nTESTS_COMBINE = []\nTESTS_COMBINE.append([[func1, func2, func3], (3, 10, 10, 8)])\nTESTS_COMBINE.append([[func1, func2], (3, 8, 8, 4)])\nTESTS_COMBINE.append([[func2, func2], (3, 8, 8, 4)])\n\n\nclass TestRandSpatialCrop(CropTest):\n    Cropper = RandSpatialCrop\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        self.crop_test(input_param, input_shape, expected_shape)\n\n    @parameterized.expand(TEST_VALUES)\n    def test_value(self, input_param, input_data):\n        for im_type in TEST_NDARRAYS_ALL:\n            with self.subTest(im_type=im_type):\n                cropper = RandSpatialCrop(**input_param)\n                result = cropper(im_type(input_data))\n                roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper._size]\n                assert_allclose(result, input_data[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=\"tensor\")\n\n    @parameterized.expand(TEST_RANDOM_SHAPES)\n    def test_random_shape(self, input_param, input_shape, expected_shape):\n        for im_type in TEST_NDARRAYS_ALL:\n            with self.subTest(im_type=im_type):\n                cropper = RandSpatialCrop(**input_param)\n                cropper.set_random_state(seed=123)\n                input_data = im_type(np.random.randint(0, 2, input_shape))\n                expected = cropper(input_data)\n                self.assertTupleEqual(expected.shape, expected_shape)\n\n                # lazy\n                # reset random seed to ensure the same results\n                cropper.set_random_state(seed=123)\n                cropper.lazy = True\n                pending_result = cropper(input_data)\n                
self.assertIsInstance(pending_result, MetaTensor)\n                assert_allclose(pending_result.peek_pending_affine(), expected.affine)\n                assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])\n                # only support nearest\n                result = apply_pending(pending_result, overrides={\"mode\": \"nearest\", \"align_corners\": False})[0]\n                # compare\n                assert_allclose(result, expected, rtol=1e-5)\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_pending_ops(self, input_param, input_shape, _):\n        self.crop_test_pending_ops(input_param, input_shape)\n\n    @parameterized.expand(TESTS_COMBINE)\n    def test_combine_ops(self, funcs, input_shape):\n        self.crop_test_combine_ops(funcs, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_spatial_crop_samples.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import RandSpatialCropSamples\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.croppers import CropTest\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTEST_CASE_1 = [\n    {\"roi_size\": [3, 3, 3], \"num_samples\": 4, \"random_center\": True, \"random_size\": False},\n    (3, 4, 4, 4),\n    [(3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)],\n    np.array(\n        [\n            [\n                [[21, 22, 23], [25, 26, 27], [29, 30, 31]],\n                [[37, 38, 39], [41, 42, 43], [45, 46, 47]],\n                [[53, 54, 55], [57, 58, 59], [61, 62, 63]],\n            ],\n            [\n                [[85, 86, 87], [89, 90, 91], [93, 94, 95]],\n                [[101, 102, 103], [105, 106, 107], [109, 110, 111]],\n                [[117, 118, 119], [121, 122, 123], [125, 126, 127]],\n            ],\n            [\n                [[149, 150, 151], [153, 154, 155], [157, 158, 159]],\n                [[165, 166, 167], [169, 170, 171], [173, 174, 175]],\n                [[181, 182, 183], [185, 186, 187], [189, 190, 191]],\n            ],\n        ]\n    ),\n]\n\nTEST_CASE_2 = [\n    {\"roi_size\": [3, 3, 3], \"num_samples\": 8, 
\"random_center\": False, \"random_size\": True},\n    (3, 4, 4, 4),\n    [(3, 4, 4, 3), (3, 4, 3, 3), (3, 3, 4, 4), (3, 4, 4, 4), (3, 3, 3, 4), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)],\n    np.array(\n        [\n            [\n                [[21, 22, 23], [25, 26, 27], [29, 30, 31]],\n                [[37, 38, 39], [41, 42, 43], [45, 46, 47]],\n                [[53, 54, 55], [57, 58, 59], [61, 62, 63]],\n            ],\n            [\n                [[85, 86, 87], [89, 90, 91], [93, 94, 95]],\n                [[101, 102, 103], [105, 106, 107], [109, 110, 111]],\n                [[117, 118, 119], [121, 122, 123], [125, 126, 127]],\n            ],\n            [\n                [[149, 150, 151], [153, 154, 155], [157, 158, 159]],\n                [[165, 166, 167], [169, 170, 171], [173, 174, 175]],\n                [[181, 182, 183], [185, 186, 187], [189, 190, 191]],\n            ],\n        ]\n    ),\n]\n\nTEST_INVERSE_LIST = [\n    [(1, 2, 2), {\"roi_size\": (1, 1), \"num_samples\": 4, \"random_size\": False}],\n    [(1, 3, 2), {\"roi_size\": (1, 1), \"num_samples\": 100, \"random_size\": False}],\n    [(3, 10, 11, 12), {\"roi_size\": (3, 5, 4), \"num_samples\": 7, \"random_size\": False}],\n    [(3, 10, 11, 12), {\"roi_size\": (10, 11, 12), \"num_samples\": 3, \"random_size\": False}],\n    [(3, 10, 11, 12), {\"roi_size\": (3, 4, 5), \"num_samples\": 100, \"random_size\": False}],\n]\n\n\nclass TestRandSpatialCropSamples(CropTest):\n    Cropper = RandSpatialCropSamples\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_shape(self, input_param, input_shape, expected_shape, expected_last_item):\n        input_data = np.arange(192).reshape(*input_shape)\n\n        for p in TEST_NDARRAYS_ALL:\n            xform = RandSpatialCropSamples(**input_param)\n            xform.set_random_state(1234)\n            result = xform(p(input_data))\n\n            np.testing.assert_equal(len(result), input_param[\"num_samples\"])\n            for i, (item, 
expected) in enumerate(zip(result, expected_shape)):\n                self.assertTupleEqual(item.shape, expected)\n                self.assertEqual(item.meta[\"patch_index\"], i)\n            assert_allclose(result[-1], expected_last_item, type_test=\"tensor\")\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])\n    def test_pending_ops(self, input_param, input_shape, _expected_shape, _expected_last_item):\n        input_data = np.arange(192).reshape(*input_shape)\n\n        for p in TEST_NDARRAYS_ALL:\n            xform = RandSpatialCropSamples(**input_param)\n            image = p(input_data)\n            # non-lazy\n            xform.set_random_state(1234)\n            expected = xform(image)\n            self.assertIsInstance(expected[0], MetaTensor)\n            # lazy\n            xform.set_random_state(1234)\n            xform.lazy = True\n            pending_result = xform(image)\n            for i, _pending_result in enumerate(pending_result):\n                self.assertIsInstance(_pending_result, MetaTensor)\n                assert_allclose(_pending_result.peek_pending_affine(), expected[i].affine)\n                assert_allclose(_pending_result.peek_pending_shape(), expected[i].shape[1:])\n                # only support nearest\n                result = apply_pending(_pending_result, overrides={\"mode\": \"nearest\", \"align_corners\": False})[0]\n                # compare\n                assert_allclose(result, expected[i], rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_spatial_crop_samplesd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import Compose, DivisiblePadd, RandSpatialCropSamplesd\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTEST_CASE_1 = [\n    {\"keys\": [\"img\", \"seg\"], \"num_samples\": 4, \"roi_size\": [2, 2, 2], \"random_center\": True, \"random_size\": True},\n    {\"img\": np.arange(81).reshape(3, 3, 3, 3), \"seg\": np.arange(81, 0, -1).reshape(3, 3, 3, 3)},\n    [(3, 2, 2, 2), (3, 2, 3, 3), (3, 2, 3, 2), (3, 2, 3, 2)],\n    {\n        \"img\": np.array(\n            [\n                [[[1, 2], [4, 5], [7, 8]], [[10, 11], [13, 14], [16, 17]]],\n                [[[28, 29], [31, 32], [34, 35]], [[37, 38], [40, 41], [43, 44]]],\n                [[[55, 56], [58, 59], [61, 62]], [[64, 65], [67, 68], [70, 71]]],\n            ]\n        ),\n        \"seg\": np.array(\n            [\n                [[[80, 79], [77, 76], [74, 73]], [[71, 70], [68, 67], [65, 64]]],\n                [[[53, 52], [50, 49], [47, 46]], [[44, 43], [41, 40], [38, 37]]],\n                [[[26, 25], [23, 22], [20, 19]], [[17, 16], [14, 13], [11, 10]]],\n            ]\n        ),\n    },\n]\n\nTEST_CASE_2 = []\nfor p in TEST_NDARRAYS_ALL:\n    
TEST_CASE_2.append(\n        [\n            {\n                \"keys\": [\"img\", \"seg\"],\n                \"num_samples\": 8,\n                \"roi_size\": [2, 2, 3],\n                \"random_center\": False,\n                \"random_size\": True,\n            },\n            {\"img\": p(np.arange(81).reshape(3, 3, 3, 3)), \"seg\": p(np.arange(81, 0, -1).reshape(3, 3, 3, 3))},\n            [\n                (3, 2, 2, 3),\n                (3, 2, 2, 3),\n                (3, 3, 3, 3),\n                (3, 2, 3, 3),\n                (3, 3, 3, 3),\n                (3, 2, 3, 3),\n                (3, 2, 3, 3),\n                (3, 3, 2, 3),\n            ],\n            {\n                \"img\": p(\n                    np.array(\n                        [\n                            [[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]],\n                            [[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]],\n                            [[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]],\n                        ]\n                    )\n                ),\n                \"seg\": p(\n                    np.array(\n                        [\n                            [[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]],\n                            [[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]],\n                            [[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]],\n                        ]\n                    )\n                ),\n            },\n        ]\n    )\n\n\nclass TestRandSpatialCropSamplesd(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, *TEST_CASE_2])\n    def test_shape(self, input_param, input_data, expected_shape, expected_last):\n        xform = RandSpatialCropSamplesd(**input_param)\n   
     xform.set_random_state(1234)\n        result = xform(input_data)\n        _len = len(tuple(input_data.keys()))\n        self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys()))\n        for item, expected in zip(result, expected_shape):\n            self.assertTupleEqual(item[\"img\"].shape, expected)\n            self.assertTupleEqual(item[\"seg\"].shape, expected)\n        for i, item in enumerate(result):\n            self.assertEqual(item[\"img\"].meta[\"patch_index\"], i)\n            self.assertEqual(item[\"seg\"].meta[\"patch_index\"], i)\n        assert_allclose(item[\"img\"], expected_last[\"img\"], type_test=False)\n        assert_allclose(item[\"seg\"], expected_last[\"seg\"], type_test=False)\n\n    def test_deep_copy(self):\n        data = {\"img\": np.ones((1, 10, 11, 12))}\n        num_samples = 3\n        sampler = RandSpatialCropSamplesd(\n            keys=[\"img\"], roi_size=(3, 3, 3), num_samples=num_samples, random_center=True, random_size=False\n        )\n        transform = Compose([DivisiblePadd(keys=\"img\", k=5), sampler])\n        samples = transform(data)\n        self.assertEqual(len(samples), num_samples)\n        for sample in samples:\n            self.assertEqual(len(sample[\"img\"].applied_operations), len(transform))\n\n    @parameterized.expand([TEST_CASE_1, *TEST_CASE_2])\n    def test_pending_ops(self, input_param, input_data, _expected_shape, _expected_last):\n        xform = RandSpatialCropSamplesd(**input_param)\n        # non-lazy\n        xform.set_random_state(1234)\n        expected = xform(input_data)\n        self.assertIsInstance(expected[0][\"img\"], MetaTensor)\n\n        # lazy\n        xform.set_random_state(1234)\n        xform.lazy = True\n        pending_result = xform(input_data)\n        for i, _pending_result in enumerate(pending_result):\n            self.assertIsInstance(_pending_result[\"img\"], MetaTensor)\n            
assert_allclose(_pending_result[\"img\"].peek_pending_affine(), expected[i][\"img\"].affine)\n            assert_allclose(_pending_result[\"img\"].peek_pending_shape(), expected[i][\"img\"].shape[1:])\n            # only support nearest\n            overrides = {\"mode\": \"nearest\", \"align_corners\": False}\n            result_img = apply_pending(_pending_result[\"img\"], overrides=overrides)[0]\n            result_seg = apply_pending(_pending_result[\"seg\"], overrides=overrides)[0]\n            # compare\n            assert_allclose(result_img, expected[i][\"img\"], rtol=1e-5)\n            assert_allclose(result_seg, expected[i][\"seg\"], rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_spatial_cropd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import RandScaleCropd, RandSpatialCropd\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.croppers import CropTest\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTEST_SHAPES = [\n    [{\"keys\": \"img\", \"roi_size\": [3, 3, -1], \"random_center\": True}, (3, 3, 3, 5), (3, 3, 3, 5)],\n    [{\"keys\": \"img\", \"roi_size\": [3, 3, 3], \"random_center\": True}, (3, 3, 3, 3), (3, 3, 3, 3)],\n    [{\"keys\": \"img\", \"roi_size\": [3, 3, 3], \"random_center\": False}, (3, 3, 3, 3), (3, 3, 3, 3)],\n    [{\"keys\": \"img\", \"roi_size\": [3, 2, 3], \"random_center\": False, \"random_size\": False}, (3, 3, 3, 3), (3, 3, 2, 3)],\n]\n\nTEST_VALUES = [\n    [\n        {\"keys\": \"img\", \"roi_size\": [3, 3], \"random_center\": False},\n        np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 1, 2, 1, 0], [0, 0, 0, 0, 0]]]),\n    ]\n]\n\nTEST_RANDOM_SHAPES = [\n    [\n        {\"keys\": \"img\", \"roi_size\": [3, 3, 3], \"max_roi_size\": [5, -1, 4], \"random_center\": True, \"random_size\": True},\n        (1, 4, 5, 6),\n        (1, 4, 4, 3),\n    ],\n    [\n        {\"keys\": \"img\", \"roi_size\": 3, \"max_roi_size\": 4, 
\"random_center\": True, \"random_size\": True},\n        (1, 4, 5, 6),\n        (1, 3, 4, 3),\n    ],\n]\n\nfunc1 = {RandSpatialCropd: {\"keys\": \"img\", \"roi_size\": [8, 7, -1], \"random_center\": True, \"random_size\": False}}\nfunc2 = {RandScaleCropd: {\"keys\": \"img\", \"roi_scale\": [0.5, 0.6, -1.0], \"random_center\": True, \"random_size\": True}}\nfunc3 = {RandScaleCropd: {\"keys\": \"img\", \"roi_scale\": [1.0, 0.5, -1.0], \"random_center\": False, \"random_size\": False}}\n\nTESTS_COMBINE = []\nTESTS_COMBINE.append([[func1, func2, func3], (3, 10, 10, 8)])\nTESTS_COMBINE.append([[func1, func2], (3, 8, 8, 4)])\nTESTS_COMBINE.append([[func2, func2], (3, 8, 8, 4)])\n\n\nclass TestRandSpatialCropd(CropTest):\n    Cropper = RandSpatialCropd\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        self.crop_test(input_param, input_shape, expected_shape)\n\n    @parameterized.expand(TEST_VALUES)\n    def test_value(self, input_param, input_im):\n        for im_type in TEST_NDARRAYS_ALL:\n            with self.subTest(im_type=im_type):\n                cropper = self.Cropper(**input_param)\n                input_data = {\"img\": im_type(input_im)}\n                result = cropper(input_data)[\"img\"]\n                roi = [(2 - i // 2, 2 + i - i // 2) for i in cropper.cropper._size]\n                assert_allclose(result, input_im[:, roi[0][0] : roi[0][1], roi[1][0] : roi[1][1]], type_test=\"tensor\")\n\n    @parameterized.expand(TEST_RANDOM_SHAPES)\n    def test_random_shape(self, input_param, input_shape, expected_shape):\n        for im_type in TEST_NDARRAYS_ALL:\n            with self.subTest(im_type=im_type):\n                cropper = self.Cropper(**input_param)\n                cropper.set_random_state(seed=123)\n                input_data = {\"img\": im_type(np.random.randint(0, 2, input_shape))}\n                expected = cropper(input_data)[\"img\"]\n                
self.assertTupleEqual(expected.shape, expected_shape)\n\n                # lazy\n                # reset random seed to ensure the same results\n                cropper.set_random_state(seed=123)\n                cropper.lazy = True\n                pending_result = cropper(input_data)[\"img\"]\n                self.assertIsInstance(pending_result, MetaTensor)\n                assert_allclose(pending_result.peek_pending_affine(), expected.affine)\n                assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])\n                # only support nearest\n                result = apply_pending(pending_result, overrides={\"mode\": \"nearest\", \"align_corners\": False})[0]\n                # compare\n                assert_allclose(result, expected, rtol=1e-5)\n\n    @parameterized.expand(TEST_SHAPES)\n    def test_pending_ops(self, input_param, input_shape, _):\n        self.crop_test_pending_ops(input_param, input_shape)\n\n    @parameterized.expand(TESTS_COMBINE)\n    def test_combine_ops(self, funcs, input_shape):\n        self.crop_test_combine_ops(funcs, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_std_shift_intensity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandStdShiftIntensity\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestRandStdShiftIntensity(NumpyImageTestCase2D):\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_value(self, p):\n        np.random.seed(0)\n        # simulate the randomize() of transform\n        np.random.random()\n        factor = np.random.uniform(low=-1.0, high=1.0)\n        offset = factor * np.std(self.imt)\n        expected = p(self.imt + offset)\n        shifter = RandStdShiftIntensity(factors=1.0, prob=1.0)\n        shifter.set_random_state(seed=0)\n        _imt = p(self.imt)\n        result = shifter(_imt)\n        if isinstance(_imt, torch.Tensor):\n            self.assertEqual(result.dtype, _imt.dtype)\n        assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_std_shift_intensityd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms import RandStdShiftIntensityd\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestRandStdShiftIntensityd(NumpyImageTestCase2D):\n    def test_value(self):\n        for p in TEST_NDARRAYS:\n            key = \"img\"\n            np.random.seed(0)\n            # simulate the randomize() of transform\n            np.random.random()\n            factor = np.random.uniform(low=-1.0, high=1.0)\n            expected = self.imt + factor * np.std(self.imt)\n            shifter = RandStdShiftIntensityd(keys=[key], factors=1.0, prob=1.0)\n            shifter.set_random_state(seed=0)\n            result = shifter({key: p(self.imt)})[key]\n            assert_allclose(result, expected, rtol=1e-5, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_torchio.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandTorchIO\nfrom monai.utils import optional_import, set_determinism\n\n_, has_torchio = optional_import(\"torchio\")\n\nTEST_DIMS = [3, 128, 160, 160]\nTESTS = [\n    [{\"name\": \"RandomAffine\"}, torch.rand(TEST_DIMS)],\n    [{\"name\": \"RandomElasticDeformation\"}, torch.rand(TEST_DIMS)],\n    [{\"name\": \"RandomAnisotropy\"}, torch.rand(TEST_DIMS)],\n    [{\"name\": \"RandomMotion\"}, torch.rand(TEST_DIMS)],\n    [{\"name\": \"RandomGhosting\"}, torch.rand(TEST_DIMS)],\n    [{\"name\": \"RandomSpike\"}, torch.rand(TEST_DIMS)],\n    [{\"name\": \"RandomBiasField\"}, torch.rand(TEST_DIMS)],\n    [{\"name\": \"RandomBlur\"}, torch.rand(TEST_DIMS)],\n    [{\"name\": \"RandomNoise\"}, torch.rand(TEST_DIMS)],\n    [{\"name\": \"RandomSwap\"}, torch.rand(TEST_DIMS)],\n    [{\"name\": \"RandomGamma\"}, torch.rand(TEST_DIMS)],\n]\n\n\n@skipUnless(has_torchio, \"Requires torchio\")\nclass TestRandTorchIO(unittest.TestCase):\n\n    @parameterized.expand(TESTS)\n    def test_value(self, input_param, input_data):\n        set_determinism(seed=0)\n        result = RandTorchIO(**input_param)(input_data)\n        self.assertIsNotNone(result)\n        
self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f\"{input_param} failed\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_torchiod.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandTorchIOd\nfrom monai.utils import optional_import, set_determinism\nfrom tests.test_utils import assert_allclose\n\n_, has_torchio = optional_import(\"torchio\")\n\nTEST_DIMS = [3, 128, 160, 160]\nTEST_TENSOR = torch.rand(TEST_DIMS)\nTEST_PARAMS = [[{\"keys\": [\"img1\", \"img2\"], \"name\": \"RandomAffine\"}, {\"img1\": TEST_TENSOR, \"img2\": TEST_TENSOR}]]\n\n\n@skipUnless(has_torchio, \"Requires torchio\")\nclass TestRandTorchIOd(unittest.TestCase):\n    @parameterized.expand(TEST_PARAMS)\n    def test_random_transform(self, input_param, input_data):\n        set_determinism(seed=0)\n        result = RandTorchIOd(**input_param)(input_data)\n        self.assertFalse(np.allclose(input_data[\"img1\"], result[\"img1\"], atol=1e-6, rtol=1e-6))\n        assert_allclose(result[\"img1\"], result[\"img2\"], atol=1e-4, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_zoom.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom scipy.ndimage import zoom as zoom_scipy\n\nfrom monai.config import USE_COMPILED\nfrom monai.transforms import RandZoom\nfrom monai.utils import InterpolateMode\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion\n\nVALID_CASES = [\n    (0.8, 1.2, \"nearest\", False),\n    (0.8, 1.2, InterpolateMode.NEAREST, False),\n    (0.8, 1.2, InterpolateMode.BILINEAR, False, True),\n    (0.8, 1.2, InterpolateMode.BILINEAR, False, False),\n]\n\n\nclass TestRandZoom(NumpyImageTestCase2D):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_results(self, min_zoom, max_zoom, mode, keep_size, align_corners=None):\n        for p in TEST_NDARRAYS_ALL:\n            init_param = {\n                \"prob\": 1.0,\n                \"min_zoom\": min_zoom,\n                \"max_zoom\": max_zoom,\n                \"mode\": mode,\n                \"keep_size\": keep_size,\n                \"dtype\": torch.float64,\n                \"align_corners\": align_corners,\n            }\n            random_zoom = RandZoom(**init_param)\n            random_zoom.set_random_state(1234)\n            im = p(self.imt[0])\n            call_param = 
{\"img\": im}\n            zoomed = random_zoom(**call_param)\n\n            # test lazy\n            # TODO: temporarily skip \"nearest\" test\n            if mode == InterpolateMode.BILINEAR:\n                test_resampler_lazy(\n                    random_zoom, zoomed, init_param, call_param, seed=1234, atol=1e-4 if USE_COMPILED else 1e-6\n                )\n\n            test_local_inversion(random_zoom, zoomed, im)\n            expected = [\n                zoom_scipy(channel, zoom=random_zoom._zoom, mode=\"nearest\", order=0, prefilter=False)\n                for channel in self.imt[0]\n            ]\n\n            expected = np.stack(expected).astype(np.float32)\n            assert_allclose(zoomed, p(expected), atol=1.0, type_test=False)\n\n    def test_keep_size(self):\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True)\n            random_zoom.set_random_state(12)\n            zoomed = random_zoom(im)\n            test_local_inversion(random_zoom, zoomed, im)\n            self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))\n            zoomed = random_zoom(im)\n            self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))\n            zoomed = random_zoom(im)\n            self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))\n            random_zoom.prob = 0.0\n            self.assertEqual(random_zoom(im).dtype, torch.float32)\n\n    @parameterized.expand(\n        [(\"no_min_zoom\", None, 1.1, \"bilinear\", TypeError), (\"invalid_mode\", 0.9, 1.1, \"s\", ValueError)]\n    )\n    def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises):\n        for p in TEST_NDARRAYS_ALL:\n            with self.assertRaises(raises):\n                random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode)\n                random_zoom(p(self.imt[0]))\n\n    def test_auto_expand_3d(self):\n  
      for p in TEST_NDARRAYS_ALL:\n            random_zoom = RandZoom(prob=1.0, min_zoom=[0.8, 0.7], max_zoom=[1.2, 1.3], mode=\"nearest\", keep_size=False)\n            random_zoom.set_random_state(1234)\n            test_data = p(np.random.randint(0, 2, size=[2, 2, 3, 4]))\n            zoomed = random_zoom(test_data)\n            assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2, type_test=False)\n            assert_allclose(zoomed.shape, (2, 2, 3, 3), type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rand_zoomd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom scipy.ndimage import zoom as zoom_scipy\n\nfrom monai.config import USE_COMPILED\nfrom monai.transforms import RandZoomd\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion\n\nVALID_CASES = [\n    (0.8, 1.2, \"nearest\", None, False),\n    (0.8, 1.2, \"bilinear\", None, False),\n    (0.8, 1.2, \"bilinear\", False, False),\n]\n\n\nclass TestRandZoomd(NumpyImageTestCase2D):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_size):\n        key = \"img\"\n        init_param = {\n            \"keys\": key,\n            \"prob\": 1.0,\n            \"min_zoom\": min_zoom,\n            \"max_zoom\": max_zoom,\n            \"mode\": mode,\n            \"align_corners\": align_corners,\n            \"keep_size\": keep_size,\n            \"dtype\": torch.float64,\n        }\n        random_zoom = RandZoomd(**init_param)\n        for p in TEST_NDARRAYS_ALL:\n            random_zoom.set_random_state(1234)\n\n            im = p(self.imt[0])\n            call_param = {\"data\": {key: im}}\n            zoomed = random_zoom(**call_param)\n\n            # test lazy\n   
         # TODO: temporarily skip \"nearest\" test\n            if mode == \"bilinear\":\n                test_resampler_lazy(\n                    random_zoom, zoomed, init_param, call_param, key, seed=1234, atol=1e-4 if USE_COMPILED else 1e-6\n                )\n                random_zoom.lazy = False\n\n            test_local_inversion(random_zoom, zoomed, {key: im}, key)\n            expected = [\n                zoom_scipy(channel, zoom=random_zoom.rand_zoom._zoom, mode=\"nearest\", order=0, prefilter=False)\n                for channel in self.imt[0]\n            ]\n\n            expected = np.stack(expected).astype(np.float32)\n            assert_allclose(zoomed[key], p(expected), atol=1.0, type_test=False)\n\n    def test_keep_size(self):\n        key = \"img\"\n        random_zoom = RandZoomd(\n            keys=key, prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True, padding_mode=\"constant\", constant_values=2\n        )\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            zoomed = random_zoom({key: im})\n            test_local_inversion(random_zoom, zoomed, {key: im}, key)\n            np.testing.assert_array_equal(zoomed[key].shape, self.imt.shape[1:])\n        random_zoom.prob = 0.0\n        self.assertEqual(random_zoom({key: p(self.imt[0])})[key].dtype, torch.float32)\n\n    @parameterized.expand(\n        [(\"no_min_zoom\", None, 1.1, \"bilinear\", TypeError), (\"invalid_order\", 0.9, 1.1, \"s\", ValueError)]\n    )\n    def test_invalid_inputs(self, _, min_zoom, max_zoom, mode, raises):\n        key = \"img\"\n        for p in TEST_NDARRAYS_ALL:\n            with self.assertRaises(raises):\n                random_zoom = RandZoomd(key, prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode)\n                random_zoom({key: p(self.imt[0])})\n\n    def test_auto_expand_3d(self):\n        random_zoom = RandZoomd(\n            keys=\"img\", prob=1.0, min_zoom=[0.8, 0.7], max_zoom=[1.2, 1.3], mode=\"nearest\", 
keep_size=False\n        )\n        for p in TEST_NDARRAYS_ALL:\n            random_zoom.set_random_state(1234)\n            test_data = {\"img\": p(np.random.randint(0, 2, size=[2, 2, 3, 4]))}\n            zoomed = random_zoom(test_data)\n            assert_allclose(random_zoom.rand_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2)\n            assert_allclose(zoomed[\"img\"].shape, (2, 2, 3, 3))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_randidentity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport monai.transforms as mt\nfrom monai.data import CacheDataset\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass T(mt.Transform):\n    def __call__(self, x):\n        return x * 2\n\n\nclass TestIdentity(NumpyImageTestCase2D):\n    def test_identity(self):\n        for p in TEST_NDARRAYS:\n            img = p(self.imt)\n            identity = mt.RandIdentity()\n            assert_allclose(img, identity(img))\n\n    def test_caching(self, init=1, expect=4, expect_pre_cache=2):\n        # check that we get the correct result (two lots of T so should get 4)\n        x = init\n        transforms = mt.Compose([T(), mt.RandIdentity(), T()])\n        self.assertEqual(transforms(x), expect)\n\n        # check we get correct result with CacheDataset\n        x = [init]\n        ds = CacheDataset(x, transforms)\n        self.assertEqual(ds[0], expect)\n\n        # check that the cached value is correct\n        self.assertEqual(ds._cache[0], expect_pre_cache)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_random_order.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nimport monai.transforms.intensity.array as ia\nimport monai.transforms.spatial.array as sa\nimport monai.transforms.spatial.dictionary as sd\nfrom monai.data import MetaTensor\nfrom monai.transforms import RandomOrder\nfrom monai.transforms.compose import Compose\nfrom monai.utils import set_determinism\nfrom monai.utils.enums import TraceKeys\nfrom tests.integration.test_one_of import A, B, C, Inv, NonInv, X, Y\n\n\nclass InvC(Inv):\n    def __init__(self, keys):\n        super().__init__(keys)\n        self.fwd_fn = lambda x: x + 1\n        self.inv_fn = lambda x: x - 1\n\n\nclass InvD(Inv):\n    def __init__(self, keys):\n        super().__init__(keys)\n        self.fwd_fn = lambda x: x * 100\n        self.inv_fn = lambda x: x / 100\n\n\nset_determinism(seed=123)\nKEYS = [\"x\", \"y\"]\nTEST_INVERSES = [\n    (RandomOrder((InvC(KEYS), InvD(KEYS))), True, True),\n    (Compose((RandomOrder((InvC(KEYS), InvD(KEYS))), RandomOrder((InvD(KEYS), InvC(KEYS))))), True, False),\n    (RandomOrder((RandomOrder((InvC(KEYS), InvD(KEYS))), RandomOrder((InvD(KEYS), InvC(KEYS))))), True, False),\n    (RandomOrder((Compose((InvC(KEYS), InvD(KEYS))), Compose((InvD(KEYS), InvC(KEYS))))), True, False),\n    
(RandomOrder((NonInv(KEYS), NonInv(KEYS))), False, False),\n]\n\n\nclass TestRandomOrder(unittest.TestCase):\n    def test_empty_compose(self):\n        c = RandomOrder()\n        i = 1\n        self.assertEqual(c(i), 1)\n\n    def test_compose_flatten_does_not_affect_random_order(self):\n        p = Compose([A(), B(), RandomOrder([C(), Inv(KEYS), Compose([X(), Y()])])])\n        f = p.flatten()\n\n        # in this case the flattened transform should be the same.\n        def _match(a, b):\n            self.assertEqual(type(a), type(b))\n            for a_, b_ in zip(a.transforms, b.transforms):\n                self.assertEqual(type(a_), type(b_))\n                if isinstance(a_, (Compose, RandomOrder)):\n                    _match(a_, b_)\n\n        _match(p, f)\n\n    @parameterized.expand(TEST_INVERSES)\n    def test_inverse(self, transform, invertible, use_metatensor):\n        data = {k: MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)}\n        fwd_data1 = transform(data)\n        # test call twice won't affect inverse\n        fwd_data2 = transform(data)\n\n        if invertible:\n            for k in KEYS:\n                t = fwd_data1[k].applied_operations[-1]\n                # make sure the RandomOrder applied_order was stored\n                self.assertEqual(t[TraceKeys.CLASS_NAME], RandomOrder.__name__)\n\n        # call the inverse\n        fwd_inv_data1 = transform.inverse(fwd_data1)\n        fwd_inv_data2 = transform.inverse(fwd_data2)\n\n        fwd_data = [fwd_data1, fwd_data2]\n        fwd_inv_data = [fwd_inv_data1, fwd_inv_data2]\n        for i, _fwd_inv_data in enumerate(fwd_inv_data):\n            if invertible:\n                for k in KEYS:\n                    # check data is same as original (and different from forward)\n                    self.assertEqual(_fwd_inv_data[k], data[k])\n                    self.assertNotEqual(_fwd_inv_data[k], fwd_data[i][k])\n            else:\n                # if not invertible, should not 
change the data\n                self.assertDictEqual(fwd_data[i], _fwd_inv_data)\n\n\nTEST_RANDOM_ORDER_EXTENDED_TEST_CASES = [\n    [None, tuple()],\n    [None, (sa.Rotate(np.pi / 8),)],\n    [None, (sa.Flip(0), sa.Flip(1), sa.Rotate90(1), sa.Zoom(0.8), ia.NormalizeIntensity())],\n    [(\"a\",), (sd.Rotated((\"a\",), np.pi / 8),)],\n]\n\n\nclass TestRandomOrderAPITests(unittest.TestCase):\n    @staticmethod\n    def data_from_keys(keys):\n        if keys is None:\n            data = torch.unsqueeze(torch.tensor(np.arange(12 * 16).reshape(12, 16)), dim=0)\n        else:\n            data = {}\n            for i_k, k in enumerate(keys):\n                data[k] = torch.unsqueeze(torch.tensor(np.arange(12 * 16)).reshape(12, 16) + i_k * 192, dim=0)\n        return data\n\n    @parameterized.expand(TEST_RANDOM_ORDER_EXTENDED_TEST_CASES)\n    def test_execute_change_start_end(self, keys, pipeline):\n        data = self.data_from_keys(keys)\n\n        c = RandomOrder(deepcopy(pipeline))\n        with self.assertRaises(ValueError):\n            c(data, start=1)\n        with self.assertRaises(ValueError):\n            c(data, start=1)\n\n        c = RandomOrder(deepcopy(pipeline))\n        with self.assertRaises(ValueError):\n            c(data, end=1)\n        with self.assertRaises(ValueError):\n            c(data, end=1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_randtorchvisiond.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import RandomizableTrait, RandTorchVisiond\nfrom monai.utils import set_determinism\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_1 = [\n    {\"keys\": \"img\", \"name\": \"ColorJitter\"},\n    {\"img\": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},\n    torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),\n]\n\nTEST_CASE_2 = [\n    {\"keys\": \"img\", \"name\": \"ColorJitter\", \"brightness\": 0.5, \"contrast\": 0.5, \"saturation\": [0.1, 0.8], \"hue\": 0.5},\n    {\"img\": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},\n    torch.tensor(\n        [\n            [[0.1090, 0.6193], [0.6193, 0.9164]],\n            [[0.1090, 0.6193], [0.6193, 0.9164]],\n            [[0.1090, 0.6193], [0.6193, 0.9164]],\n        ]\n    ),\n]\n\nTEST_CASE_3 = [\n    {\"keys\": \"img\", \"name\": \"Pad\", \"padding\": [1, 1, 1, 1]},\n    {\"img\": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},\n    torch.tensor(\n        [\n            [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]],\n            [[0.0, 0.0, 0.0, 0.0], [0.0, 
0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]],\n            [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]],\n        ]\n    ),\n]\n\n\nclass TestRandTorchVisiond(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_value(self, input_param, input_data, expected_value):\n        set_determinism(seed=0)\n        transform = RandTorchVisiond(**input_param)\n        result = transform(input_data)\n        self.assertTrue(isinstance(transform, RandomizableTrait))\n        assert_allclose(result[\"img\"], expected_value, atol=1e-4, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_regularization.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms import CutMix, CutMixd, CutOut, CutOutd, MixUp, MixUpd\nfrom tests.test_utils import assert_allclose\n\n\nclass TestMixup(unittest.TestCase):\n    def test_mixup(self):\n        for dims in [2, 3]:\n            shape = (6, 3) + (32,) * dims\n            sample = torch.rand(*shape, dtype=torch.float32)\n            mixup = MixUp(6, 1.0)\n            mixup.set_random_state(seed=0)\n            output = mixup(sample)\n            np.random.seed(0)\n            # simulate the randomize() of transform\n            np.random.random()\n            weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)\n            perm = np.random.permutation(6)\n            self.assertEqual(output.shape, sample.shape)\n            mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)]\n            expected = mixweight * sample + (1 - mixweight) * sample[perm, ...]\n            assert_allclose(output, expected, type_test=False, atol=1e-7)\n\n        with self.assertRaises(ValueError):\n            MixUp(6, -0.5)\n\n        mixup = MixUp(6, 0.5)\n        for dims in [2, 3]:\n            with self.assertRaises(ValueError):\n                shape = (5, 3) + (32,) * dims\n                sample = torch.rand(*shape, dtype=torch.float32)\n                
mixup(sample)\n\n    def test_mixupd(self):\n        for dims in [2, 3]:\n            shape = (6, 3) + (32,) * dims\n            t = torch.rand(*shape, dtype=torch.float32)\n            sample = {\"a\": t, \"b\": t}\n            mixup = MixUpd([\"a\", \"b\"], 6)\n            mixup.set_random_state(seed=0)\n            output = mixup(sample)\n            np.random.seed(0)\n            # simulate the randomize() of transform\n            np.random.random()\n            weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)\n            perm = np.random.permutation(6)\n            self.assertEqual(output[\"a\"].shape, sample[\"a\"].shape)\n            mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)]\n            expected = mixweight * sample[\"a\"] + (1 - mixweight) * sample[\"a\"][perm, ...]\n            assert_allclose(output[\"a\"], expected, type_test=False, atol=1e-7)\n            assert_allclose(output[\"a\"], output[\"b\"], type_test=False, atol=1e-7)\n            # self.assertTrue(torch.allclose(output[\"a\"], output[\"b\"]))\n\n        with self.assertRaises(ValueError):\n            MixUpd([\"k1\", \"k2\"], 6, -0.5)\n\n\nclass TestCutMix(unittest.TestCase):\n    def test_cutmix(self):\n        for dims in [2, 3]:\n            shape = (6, 3) + (32,) * dims\n            sample = torch.rand(*shape, dtype=torch.float32)\n            cutmix = CutMix(6, 1.0)\n            cutmix.set_random_state(seed=0)\n            output = cutmix(sample)\n            self.assertEqual(output.shape, sample.shape)\n            self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10)))\n\n    def test_cutmixd(self):\n        for dims in [2, 3]:\n            shape = (6, 3) + (32,) * dims\n            t = torch.rand(*shape, dtype=torch.float32)\n            label = torch.randint(0, 1, shape)\n            sample = {\"a\": t, \"b\": t, \"lbl1\": label, \"lbl2\": label}\n            cutmix = CutMixd([\"a\", \"b\"], 6, 
label_keys=(\"lbl1\", \"lbl2\"))\n            cutmix.set_random_state(seed=123)\n            output = cutmix(sample)\n            # but mixing of labels is not affected by it\n            self.assertTrue(torch.allclose(output[\"lbl1\"], output[\"lbl2\"]))\n\n\nclass TestCutOut(unittest.TestCase):\n    def test_cutout(self):\n        for dims in [2, 3]:\n            shape = (6, 3) + (32,) * dims\n            sample = torch.rand(*shape, dtype=torch.float32)\n            cutout = CutOut(6, 1.0)\n            cutout.set_random_state(seed=123)\n            output = cutout(sample)\n            np.random.seed(123)\n            # simulate the randomize() of transform\n            np.random.random()\n            weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)\n            perm = np.random.permutation(6)\n            coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in sample.shape[2:]]\n            assert_allclose(weight, cutout._params[0])\n            assert_allclose(perm, cutout._params[1])\n            self.assertSequenceEqual(coords, cutout._params[2])\n            self.assertEqual(output.shape, sample.shape)\n\n    def test_cutoutd(self):\n        for dims in [2, 3]:\n            shape = (6, 3) + (32,) * dims\n            t = torch.rand(*shape, dtype=torch.float32)\n            sample = {\"a\": t, \"b\": t}\n            cutout = CutOutd([\"a\", \"b\"], 6, 1.0)\n            cutout.set_random_state(seed=123)\n            output = cutout(sample)\n            np.random.seed(123)\n            # simulate the randomize() of transform\n            np.random.random()\n            weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)\n            perm = np.random.permutation(6)\n            coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in t.shape[2:]]\n            assert_allclose(weight, cutout.cutout._params[0])\n            assert_allclose(perm, cutout.cutout._params[1])\n            
self.assertSequenceEqual(coords, cutout.cutout._params[2])\n            self.assertEqual(output[\"a\"].shape, sample[\"a\"].shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_remove_repeated_channel.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import RemoveRepeatedChannel\nfrom tests.test_utils import TEST_NDARRAYS\n\nTEST_CASES = []\nfor q in TEST_NDARRAYS:\n    TEST_CASES.append([{\"repeats\": 2}, q([[1, 2], [1, 2], [3, 4], [3, 4]]), (2, 2)])\n\n\nclass TestRemoveRepeatedChannel(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_shape(self, input_param, input_data, expected_shape):\n        result = RemoveRepeatedChannel(**input_param)(input_data)\n        self.assertEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_remove_repeated_channeld.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RemoveRepeatedChanneld\nfrom tests.test_utils import TEST_NDARRAYS\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"keys\": [\"img\"], \"repeats\": 2},\n            {\n                \"img\": p(np.array([[1, 2], [1, 2], [3, 4], [3, 4]])),\n                \"seg\": p(np.array([[1, 2], [1, 2], [3, 4], [3, 4]])),\n            },\n            (2, 2),\n        ]\n    )\n\n\nclass TestRemoveRepeatedChanneld(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, input_param, input_data, expected_shape):\n        result = RemoveRepeatedChanneld(**input_param)(input_data)\n        self.assertEqual(result[\"img\"].shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_repeat_channel.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import RepeatChannel\nfrom tests.test_utils import TEST_NDARRAYS\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([{\"repeats\": 3}, p([[[0, 1], [1, 2]]]), (3, 2, 2)])\n\n\nclass TestRepeatChannel(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, input_param, input_data, expected_shape):\n        result = RepeatChannel(**input_param)(input_data)\n        self.assertEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_repeat_channeld.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import RepeatChanneld\nfrom tests.test_utils import TEST_NDARRAYS\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"keys\": [\"img\"], \"repeats\": 3},\n            {\"img\": p(np.array([[[0, 1], [1, 2]]])), \"seg\": p(np.array([[[0, 1], [1, 2]]]))},\n            (3, 2, 2),\n        ]\n    )\n\n\nclass TestRepeatChanneld(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, input_param, input_data, expected_shape):\n        result = RepeatChanneld(**input_param)(input_data)\n        self.assertEqual(result[\"img\"].shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_resample_backends.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.config import USE_COMPILED\nfrom monai.data import MetaTensor\nfrom monai.transforms import Resample\nfrom monai.transforms.utils import create_grid\nfrom monai.utils import GridSampleMode, GridSamplePadMode, NdimageMode, SplineMode, convert_to_numpy\nfrom tests.test_utils import assert_allclose, is_tf32_env\n\n_rtol = 1e-3 if is_tf32_env() else 1e-4\n\nTEST_IDENTITY = []\nfor interp in GridSampleMode if not USE_COMPILED else (\"nearest\", \"bilinear\"):  # type: ignore\n    for pad in GridSamplePadMode:\n        for p in (np.float32, np.float64):\n            for device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n                TEST_IDENTITY.append([dict(device=device), p, interp, pad, (1, 3, 4)])\n                if interp != \"bicubic\":\n                    TEST_IDENTITY.append([dict(device=device), p, interp, pad, (1, 3, 5, 8)])\nfor interp_s in SplineMode if not USE_COMPILED else []:  # type: ignore\n    for pad_s in NdimageMode:\n        for p_s in (int, float, np.float32, np.float64):\n            for device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n                TEST_IDENTITY.append([dict(device=device), p_s, interp_s, pad_s, (1, 20, 21)])\n   
             TEST_IDENTITY.append([dict(device=device), p_s, interp_s, pad_s, (1, 21, 23, 24)])\n\n\nclass TestResampleBackends(unittest.TestCase):\n    @parameterized.expand(TEST_IDENTITY)\n    def test_resample_identity(self, input_param, im_type, interp, padding, input_shape):\n        \"\"\"test resampling of an identity grid with padding 2, im_type, interp, padding, input_shape\"\"\"\n        xform = Resample(dtype=im_type, **input_param)\n        n_elem = np.prod(input_shape)\n        img = convert_to_numpy(np.arange(n_elem).reshape(input_shape), dtype=im_type)\n        grid = create_grid(input_shape[1:], homogeneous=True, backend=\"numpy\")\n        grid_p = np.stack([np.pad(g, 2, \"constant\") for g in grid])  # testing pad\n        output = xform(img=img, grid=grid_p, mode=interp, padding_mode=padding)\n        self.assertTrue(not torch.any(torch.isinf(output) | torch.isnan(output)))\n        self.assertIsInstance(output, MetaTensor)\n        slices = [slice(None)]\n        slices.extend([slice(2, -2) for _ in img.shape[1:]])\n        output_c = output[slices]\n        assert_allclose(output_c, img, rtol=_rtol, atol=1e-3, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_resample_to_match.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport itertools\nimport os\nimport random\nimport shutil\nimport string\nimport tempfile\nimport unittest\n\nimport nibabel as nib\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor\nfrom monai.data.image_reader import ITKReader, NibabelReader\nfrom monai.data.image_writer import ITKWriter\nfrom monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ResampleToMatch, SaveImage, SaveImaged\nfrom monai.utils import optional_import\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import assert_allclose, download_url_or_skip_test, testing_data_config\n\n_, has_itk = optional_import(\"itk\", allow_namespace_pkg=True)\n\nTEST_CASES = [\"itkreader\", \"nibabelreader\"]\n\n\ndef get_rand_fname(len=10, suffix=\".nii.gz\"):\n    letters = string.ascii_letters\n    out = \"\".join(random.choice(letters) for _ in range(len))\n    out += suffix\n    return out\n\n\n@unittest.skipUnless(has_itk, \"itk not installed\")\nclass TestResampleToMatch(unittest.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        super().setUpClass()\n        cls.fnames = []\n        cls.tmpdir = tempfile.mkdtemp()\n        for key in (\"0000_t2_tse_tra_4\", \"0000_ep2d_diff_tra_7\"):\n            fname = os.path.join(cls.tmpdir, f\"test_{key}.nii.gz\")\n   
         url = testing_data_config(\"images\", key, \"url\")\n            hash_type = testing_data_config(\"images\", key, \"hash_type\")\n            hash_val = testing_data_config(\"images\", key, \"hash_val\")\n            download_url_or_skip_test(url=url, filepath=fname, hash_type=hash_type, hash_val=hash_val)\n            cls.fnames.append(fname)\n\n    @classmethod\n    def tearDownClass(cls):\n        shutil.rmtree(cls.tmpdir)\n        super().tearDownClass()\n\n    @parameterized.expand(itertools.product([NibabelReader, ITKReader], [\"monai.data.NibabelWriter\", ITKWriter]))\n    def test_correct(self, reader, writer):\n        loader = Compose([LoadImaged((\"im1\", \"im2\"), reader=reader), EnsureChannelFirstd((\"im1\", \"im2\"))])\n        data = loader({\"im1\": self.fnames[0], \"im2\": self.fnames[1]})\n        tr = ResampleToMatch()\n        im_mod = tr(data[\"im2\"], data[\"im1\"])\n\n        # check lazy resample\n        call_param = {\"img\": data[\"im2\"], \"img_dst\": data[\"im1\"]}\n        test_resampler_lazy(tr, im_mod, init_param={}, call_param=call_param)\n\n        saver = SaveImaged(\n            \"im3\", output_dir=self.tmpdir, output_postfix=\"\", separate_folder=False, writer=writer, resample=False\n        )\n        im_mod.meta[\"filename_or_obj\"] = get_rand_fname()\n        saver({\"im3\": im_mod})\n\n        saved = nib.load(os.path.join(self.tmpdir, im_mod.meta[\"filename_or_obj\"]))\n        assert_allclose(data[\"im1\"].shape[1:], saved.shape)\n        assert_allclose(saved.header[\"dim\"][:4], np.array([3, 384, 384, 19]))\n\n    def test_inverse(self):\n        loader = Compose([LoadImaged((\"im1\", \"im2\"), image_only=True), EnsureChannelFirstd((\"im1\", \"im2\"))])\n        data = loader({\"im1\": self.fnames[0], \"im2\": self.fnames[1]})\n        tr = ResampleToMatch()\n        im_mod = tr(data[\"im2\"], data[\"im1\"])\n        self.assertNotEqual(im_mod.shape, data[\"im2\"].shape)\n        
self.assertGreater(((im_mod.affine - data[\"im2\"].affine) ** 2).sum() ** 0.5, 1e-2)\n        # inverse\n        im_mod2 = tr.inverse(im_mod)\n        self.assertEqual(im_mod2.shape, data[\"im2\"].shape)\n        self.assertLess(((im_mod2.affine - data[\"im2\"].affine) ** 2).sum() ** 0.5, 1e-2)\n        self.assertEqual(im_mod2.applied_operations, [])\n\n    def test_no_name(self):\n        img_1 = MetaTensor(torch.zeros(1, 2, 2, 2))\n        img_2 = MetaTensor(torch.zeros(1, 3, 3, 3))\n        im_mod = ResampleToMatch()(img_1, img_2)\n        self.assertEqual(im_mod.meta[\"filename_or_obj\"], \"resample_to_match_source\")\n        SaveImage(output_dir=self.tmpdir, output_postfix=\"\", separate_folder=False, resample=False)(im_mod)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_resample_to_matchd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport shutil\nimport tempfile\nimport unittest\n\nfrom monai.transforms import (\n    Compose,\n    CopyItemsd,\n    EnsureChannelFirstd,\n    Invertd,\n    Lambda,\n    LoadImaged,\n    ResampleToMatchd,\n    SaveImaged,\n)\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import assert_allclose, download_url_or_skip_test, testing_data_config\n\n\ndef update_fname(d):\n    d[\"im3\"].meta[\"filename_or_obj\"] = \"file3.nii.gz\"\n    return d\n\n\nclass TestResampleToMatchd(unittest.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        super().setUpClass()\n        cls.fnames = []\n        cls.tmpdir = tempfile.mkdtemp()\n        for key in (\"0000_t2_tse_tra_4\", \"0000_ep2d_diff_tra_7\"):\n            fname = os.path.join(cls.tmpdir, f\"test_{key}.nii.gz\")\n            url = testing_data_config(\"images\", key, \"url\")\n            hash_type = testing_data_config(\"images\", key, \"hash_type\")\n            hash_val = testing_data_config(\"images\", key, \"hash_val\")\n            download_url_or_skip_test(url=url, filepath=fname, hash_type=hash_type, hash_val=hash_val)\n            cls.fnames.append(fname)\n\n    @classmethod\n    def tearDownClass(cls):\n        shutil.rmtree(cls.tmpdir)\n        super().tearDownClass()\n\n    def test_correct(self):\n        transforms = 
Compose(\n            [\n                LoadImaged((\"im1\", \"im2\")),\n                EnsureChannelFirstd((\"im1\", \"im2\")),\n                CopyItemsd((\"im2\"), names=(\"im3\")),\n                ResampleToMatchd(\"im3\", \"im1\"),\n                Lambda(update_fname),\n                SaveImaged(\"im3\", output_dir=self.tmpdir, output_postfix=\"\", separate_folder=False, resample=False),\n            ]\n        )\n        data = transforms({\"im1\": self.fnames[0], \"im2\": self.fnames[1]})\n        # check that output sizes match\n        assert_allclose(data[\"im1\"].shape, data[\"im3\"].shape)\n        # and that the meta data has been updated accordingly\n        assert_allclose(data[\"im3\"].affine, data[\"im1\"].affine)\n        # check we're different from the original\n        self.assertTrue(any(i != j for i, j in zip(data[\"im3\"].shape, data[\"im2\"].shape)))\n        self.assertTrue(any(i != j for i, j in zip(data[\"im3\"].affine.flatten(), data[\"im2\"].affine.flatten())))\n        # test the inverse\n        data = Invertd(\"im3\", transforms)(data)\n        assert_allclose(data[\"im2\"].shape, data[\"im3\"].shape)\n\n    def test_lazy(self):\n        pre_transforms = Compose(\n            [LoadImaged((\"im1\", \"im2\")), EnsureChannelFirstd((\"im1\", \"im2\")), CopyItemsd((\"im2\"), names=(\"im3\"))]\n        )\n        data = pre_transforms({\"im1\": self.fnames[0], \"im2\": self.fnames[1]})\n        init_param = {\"keys\": \"im3\", \"key_dst\": \"im1\"}\n        resampler = ResampleToMatchd(**init_param)\n        call_param = {\"data\": data}\n        non_lazy_out = resampler(**call_param)\n        test_resampler_lazy(resampler, non_lazy_out, init_param, call_param, output_key=\"im3\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_resampler.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import Resample\nfrom monai.transforms.utils import create_grid\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS_ALL:\n    for q in TEST_NDARRAYS_ALL:\n        for device in [None, \"cpu\", \"cuda\"] if torch.cuda.is_available() else [None, \"cpu\"]:\n            TESTS.append(\n                [\n                    dict(padding_mode=\"zeros\", device=device),\n                    {\"grid\": p(create_grid((2, 2))), \"img\": q(np.arange(4).reshape((1, 2, 2)))},\n                    q(np.array([[[0.0, 1.0], [2.0, 3.0]]])),\n                ]\n            )\n            TESTS.append(\n                [\n                    dict(padding_mode=\"zeros\", device=device),\n                    {\"grid\": p(create_grid((4, 4))), \"img\": q(np.arange(4).reshape((1, 2, 2)))},\n                    q(\n                        np.array(\n                            [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 2.0, 3.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]\n                        )\n                    ),\n                ]\n            )\n            TESTS.append(\n                [\n                    dict(padding_mode=\"border\", device=device),\n                    {\"grid\": 
p(create_grid((4, 4))), \"img\": q(np.arange(4).reshape((1, 2, 2)))},\n                    q(\n                        np.array(\n                            [[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3, 3.0], [2.0, 2.0, 3.0, 3.0]]]\n                        )\n                    ),\n                ]\n            )\n            # TESTS.append(  # not well defined nearest + reflection resampling\n            #     [\n            #         dict(padding_mode=\"reflection\", device=device),\n            #         {\"grid\": p(create_grid((4, 4))), \"img\": q(np.arange(4).reshape((1, 2, 2))), \"mode\": \"nearest\"},\n            #         q(\n            #             np.array(\n            #                 [[[3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0], [3.0, 2.0, 3.0, 2.0], [1.0, 0.0, 1.0, 0.0]]]\n            #             )\n            #         ),\n            #     ]\n            # )\n            TESTS.append(\n                [\n                    dict(padding_mode=\"zeros\", device=device),\n                    {\n                        \"grid\": p(create_grid((4, 4, 4))),\n                        \"img\": q(np.arange(8).reshape((1, 2, 2, 2))),\n                        \"mode\": \"bilinear\",\n                    },\n                    q(\n                        np.array(\n                            [\n                                [\n                                    [\n                                        [0.0, 0.0, 0.0, 0.0],\n                                        [0.0, 0.0, 0.0, 0.0],\n                                        [0.0, 0.0, 0.0, 0.0],\n                                        [0.0, 0.0, 0.0, 0.0],\n                                    ],\n                                    [\n                                        [0.0, 0.0, 0.0, 0.0],\n                                        [0.0, 0.0, 1.0, 0.0],\n                                        [0.0, 2.0, 3.0, 0.0],\n                                        [0.0, 0.0, 0.0, 
0.0],\n                                    ],\n                                    [\n                                        [0.0, 0.0, 0.0, 0.0],\n                                        [0.0, 4.0, 5.0, 0.0],\n                                        [0.0, 6.0, 7.0, 0.0],\n                                        [0.0, 0.0, 0.0, 0.0],\n                                    ],\n                                    [\n                                        [0.0, 0.0, 0.0, 0.0],\n                                        [0.0, 0.0, 0.0, 0.0],\n                                        [0.0, 0.0, 0.0, 0.0],\n                                        [0.0, 0.0, 0.0, 0.0],\n                                    ],\n                                ]\n                            ]\n                        )\n                    ),\n                ]\n            )\n            TESTS.append(\n                [\n                    dict(padding_mode=\"border\", device=device),\n                    {\n                        \"grid\": p(create_grid((4, 4, 4))),\n                        \"img\": q(np.arange(8).reshape((1, 2, 2, 2))),\n                        \"mode\": \"bilinear\",\n                    },\n                    q(\n                        np.array(\n                            [\n                                [\n                                    [\n                                        [0.0, 0.0, 1.0, 1.0],\n                                        [0.0, 0.0, 1.0, 1.0],\n                                        [2.0, 2.0, 3.0, 3.0],\n                                        [2.0, 2.0, 3.0, 3.0],\n                                    ],\n                                    [\n                                        [0.0, 0.0, 1.0, 1.0],\n                                        [0.0, 0.0, 1.0, 1.0],\n                                        [2.0, 2.0, 3.0, 3.0],\n                                        [2.0, 2.0, 3.0, 3.0],\n                                    ],\n         
                           [\n                                        [4.0, 4.0, 5.0, 5.0],\n                                        [4.0, 4.0, 5.0, 5.0],\n                                        [6.0, 6.0, 7.0, 7.0],\n                                        [6.0, 6.0, 7.0, 7.0],\n                                    ],\n                                    [\n                                        [4.0, 4.0, 5.0, 5.0],\n                                        [4.0, 4.0, 5.0, 5.0],\n                                        [6.0, 6.0, 7.0, 7.0],\n                                        [6.0, 6.0, 7.0, 7.0],\n                                    ],\n                                ]\n                            ]\n                        )\n                    ),\n                ]\n            )\n\n\nclass TestResample(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_resample(self, input_param, input_data, expected_val):\n        g = Resample(**input_param)\n        result = g(**input_data)\n        if \"device\" in input_data:\n            self.assertEqual(result.device, input_data[\"device\"])\n        assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_resize.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport skimage.transform\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import Resize\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, is_tf32_env\n\nTEST_CASE_0 = [{\"spatial_size\": 15}, (6, 10, 15)]\n\nTEST_CASE_1 = [{\"spatial_size\": 15, \"mode\": \"area\"}, (6, 10, 15)]\n\nTEST_CASE_2 = [{\"spatial_size\": 6, \"mode\": \"trilinear\", \"align_corners\": True}, (2, 4, 6)]\n\nTEST_CASE_2_1 = [{\"spatial_size\": 6, \"mode\": 1, \"align_corners\": True}, (2, 4, 6)]\n\nTEST_CASE_3 = [{\"spatial_size\": 15, \"anti_aliasing\": True}, (6, 10, 15)]\n\nTEST_CASE_4 = [{\"spatial_size\": 6, \"anti_aliasing\": True, \"anti_aliasing_sigma\": 2.0}, (2, 4, 6)]\n\ndiff_t = 0.3 if is_tf32_env() else 0.2\n\n\nclass TestResize(NumpyImageTestCase2D):\n    def test_invalid_inputs(self):\n        with self.assertRaises(ValueError):\n            resize = Resize(spatial_size=(128, 128, 3), mode=\"order\")\n            resize(self.imt[0])\n\n        with self.assertRaises(ValueError):\n            resize = Resize(spatial_size=(128,), mode=\"order\")\n            resize(self.imt[0])\n\n    def test_unchange(self):\n        resize = 
Resize(spatial_size=(128, 64), mode=\"bilinear\")\n        set_track_meta(False)\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            result = resize(im)\n            assert_allclose(im, result, type_test=False)\n        set_track_meta(True)\n\n    @parameterized.expand(\n        [\n            ((32, -1), \"area\", True),\n            ((32, 32), \"area\", False),\n            ((32, 32, 32), \"trilinear\", True),\n            ((256, 256), \"bilinear\", False),\n            ((256, 256), \"nearest-exact\", False),\n            ((128, 128), \"nearest\", False),\n            ((128, 64), \"area\", True),  # already in a good shape\n        ]\n    )\n    def test_correct_results(self, spatial_size, mode, anti_aliasing):\n        \"\"\"resize 'spatial_size' and 'mode'\"\"\"\n        init_param = {\"spatial_size\": spatial_size, \"mode\": mode, \"anti_aliasing\": anti_aliasing, \"dtype\": np.float64}\n        resize = Resize(**init_param)\n        _order = 0\n        if mode.endswith(\"linear\"):\n            _order = 1\n        if spatial_size == (32, -1):\n            spatial_size = (32, 64)\n\n        expected = [\n            skimage.transform.resize(\n                channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=anti_aliasing\n            )\n            for channel in self.imt[0]\n        ]\n\n        expected = np.stack(expected).astype(np.float32)\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"img\": im}\n            out = resize(**call_param)\n            if init_param[\"mode\"] in (\"bilinear\", \"nearest\") and anti_aliasing is False:\n                test_resampler_lazy(resize, out, init_param, call_param)\n            if isinstance(im, MetaTensor):\n                im_inv = resize.inverse(out)\n                self.assertTrue(not im_inv.applied_operations)\n                assert_allclose(im_inv.shape, im.shape)\n                
assert_allclose(im_inv.affine, im.affine, atol=1e-3, rtol=1e-3)\n            if not anti_aliasing:\n                assert_allclose(out, expected, type_test=False, atol=0.9)\n                return\n            # skimage uses reflect padding for anti-aliasing filter.\n            # Our implementation reuses GaussianSmooth() as anti-aliasing filter, which uses zero padding instead.\n            # Thus their results near the image boundary will be different.\n            if isinstance(out, torch.Tensor):\n                out = out.cpu().detach().numpy()\n            good = np.sum(np.isclose(expected, out, atol=0.9))\n            self.assertLessEqual(\n                np.abs(good - expected.size) / float(expected.size), diff_t, f\"at most {diff_t} percent mismatch \"\n            )\n\n    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_2_1, TEST_CASE_3, TEST_CASE_4])\n    def test_longest_shape(self, input_param, expected_shape):\n        input_data = np.random.randint(0, 2, size=[3, 4, 7, 10])\n        input_param[\"size_mode\"] = \"longest\"\n        result = Resize(**input_param)(input_data)\n        np.testing.assert_allclose(result.shape[1:], expected_shape)\n\n        set_track_meta(False)\n        result = Resize(**input_param)(input_data)\n        self.assertNotIsInstance(result, MetaTensor)\n        np.testing.assert_allclose(result.shape[1:], expected_shape)\n        set_track_meta(True)\n\n    def test_longest_infinite_decimals(self):\n        resize = Resize(spatial_size=1008, size_mode=\"longest\", mode=\"bilinear\", align_corners=False)\n        ret = resize(np.random.randint(0, 2, size=[1, 2544, 3032]))\n        self.assertTupleEqual(ret.shape, (1, 846, 1008))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_resize_with_pad_or_crop.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import ResizeWithPadOrCrop\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTEST_CASES = [\n    [{\"spatial_size\": [15, 8, 8], \"mode\": \"constant\"}, (3, 8, 8, 4), (3, 15, 8, 8), True],\n    [{\"spatial_size\": [15, 4, -1], \"mode\": \"constant\"}, (3, 8, 8, 4), (3, 15, 4, 4), True],\n    [{\"spatial_size\": [15, 4, -1], \"mode\": \"reflect\"}, (3, 8, 8, 4), (3, 15, 4, 4), True],\n    [{\"spatial_size\": [-1, -1, -1], \"mode\": \"reflect\"}, (3, 8, 8, 4), (3, 8, 8, 4), True],\n    [\n        {\"spatial_size\": [15, 4, 8], \"mode\": \"constant\", \"method\": \"end\", \"constant_values\": 1},\n        (3, 8, 8, 4),\n        (3, 15, 4, 8),\n        True,\n    ],\n]\nTESTS_PENDING_MODE = {\"constant\": \"zeros\", \"edge\": \"border\", \"reflect\": \"reflection\"}\n\n\nclass TestResizeWithPadOrCrop(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_pad_shape(self, input_param, input_shape, expected_shape, _):\n        for p in TEST_NDARRAYS_ALL:\n            if isinstance(p(0), torch.Tensor) and (\n                \"constant_values\" in input_param or input_param[\"mode\"] == 
\"reflect\"\n            ):\n                continue\n            padcropper = ResizeWithPadOrCrop(**input_param)\n            result = padcropper(p(np.zeros(input_shape)))\n            np.testing.assert_allclose(result.shape, expected_shape)\n            result = padcropper(p(np.zeros(input_shape)), mode=\"constant\")\n            np.testing.assert_allclose(result.shape, expected_shape)\n            self.assertIsInstance(result, MetaTensor)\n            self.assertEqual(len(result.applied_operations), 1)\n            inv = padcropper.inverse(result)\n            self.assertTupleEqual(inv.shape, input_shape)\n            self.assertIsInstance(inv, MetaTensor)\n            self.assertEqual(inv.applied_operations, [])\n\n    @parameterized.expand(TEST_CASES)\n    def test_pending_ops(self, input_param, input_shape, _expected_data, align_corners):\n        for p in TEST_NDARRAYS_ALL:\n            # grid sample only support constant value to be zero\n            if \"constant_values\" in input_param and input_param[\"constant_values\"] != 0:\n                continue\n            padcropper = ResizeWithPadOrCrop(**input_param)\n            image = p(np.zeros(input_shape))\n            # non-lazy\n            expected = padcropper(image)\n            self.assertIsInstance(expected, MetaTensor)\n            # lazy\n            padcropper.lazy = True\n            pending_result = padcropper(image)\n            self.assertIsInstance(pending_result, MetaTensor)\n            assert_allclose(pending_result.peek_pending_affine(), expected.affine)\n            assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])\n            # only support nearest\n            overrides = {\n                \"mode\": \"nearest\",\n                \"padding_mode\": TESTS_PENDING_MODE[input_param[\"mode\"]],\n                \"align_corners\": align_corners,\n            }\n            result = apply_pending(pending_result, overrides=overrides)[0]\n            # compare\n     
       assert_allclose(result, expected, rtol=1e-5)\n            inverted = padcropper.inverse(result)\n            self.assertEqual(inverted.shape, image.shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_resize_with_pad_or_cropd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import ResizeWithPadOrCropd\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\nfrom tests.transforms.test_resize_with_pad_or_crop import TESTS_PENDING_MODE\n\nTEST_CASES = [\n    [{\"keys\": \"img\", \"spatial_size\": [15, 8, 8], \"mode\": \"constant\"}, {\"img\": np.zeros((3, 8, 8, 4))}, (3, 15, 8, 8)],\n    [{\"keys\": \"img\", \"spatial_size\": [15, 4, -1], \"mode\": \"constant\"}, {\"img\": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 4)],\n    [{\"keys\": \"img\", \"spatial_size\": [15, 4, -1], \"mode\": \"reflect\"}, {\"img\": np.zeros((3, 8, 8, 4))}, (3, 15, 4, 4)],\n    [{\"keys\": \"img\", \"spatial_size\": [-1, -1, -1], \"mode\": \"reflect\"}, {\"img\": np.zeros((3, 8, 8, 4))}, (3, 8, 8, 4)],\n    [\n        {\"keys\": \"img\", \"spatial_size\": [15, 4, 8], \"mode\": \"constant\", \"method\": \"end\", \"constant_values\": 1},\n        {\"img\": np.zeros((3, 8, 8, 4))},\n        (3, 15, 4, 8),\n    ],\n]\n\n\nclass TestResizeWithPadOrCropd(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_pad_shape(self, input_param, input_data, expected_val):\n      
  for p in TEST_NDARRAYS_ALL:\n            if isinstance(p(0), torch.Tensor) and (\n                \"constant_values\" in input_param or input_param[\"mode\"] == \"reflect\"\n            ):\n                continue\n            padcropper = ResizeWithPadOrCropd(**input_param)\n            input_data_ = deepcopy(input_data)\n            input_data_[\"img\"] = p(input_data[\"img\"])\n            result = padcropper(input_data_)\n            np.testing.assert_allclose(result[\"img\"].shape, expected_val)\n            inv = padcropper.inverse(result)\n            for k in input_data_:\n                self.assertTupleEqual(inv[k].shape, input_data_[k].shape)\n\n    @parameterized.expand(TEST_CASES)\n    def test_pending_ops(self, input_param, input_data, _expected_data):\n        for p in TEST_NDARRAYS_ALL:\n            # grid sample only support constant value to be zero\n            if \"constant_values\" in input_param and input_param[\"constant_values\"] != 0:\n                continue\n            padcropper = ResizeWithPadOrCropd(**input_param)\n            input_data[\"img\"] = p(input_data[\"img\"])\n            # non-lazy\n            expected = padcropper(input_data)[\"img\"]\n            self.assertIsInstance(expected, MetaTensor)\n            # lazy\n            padcropper.lazy = True\n            pending_result = padcropper(input_data)[\"img\"]\n            self.assertIsInstance(pending_result, MetaTensor)\n            assert_allclose(pending_result.peek_pending_affine(), expected.affine)\n            assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])\n            # only support nearest\n            overrides = {\n                \"mode\": \"nearest\",\n                \"padding_mode\": TESTS_PENDING_MODE[input_param[\"mode\"]],\n                \"align_corners\": True,\n            }\n            result = apply_pending(pending_result, overrides=overrides)[0]\n            # compare\n            assert_allclose(result, expected, 
rtol=1e-5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_resized.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport skimage.transform\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import Invertd, Resize, Resized\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion\n\nTEST_CASE_0 = [{\"keys\": \"img\", \"spatial_size\": 15}, (6, 10, 15)]\n\nTEST_CASE_1 = [\n    {\"keys\": \"img\", \"spatial_size\": 15, \"mode\": \"area\", \"anti_aliasing\": True, \"anti_aliasing_sigma\": None},\n    (6, 10, 15),\n]\n\nTEST_CASE_2 = [\n    {\"keys\": \"img\", \"spatial_size\": 6, \"mode\": \"trilinear\", \"align_corners\": True, \"anti_aliasing_sigma\": 2.0},\n    (2, 4, 6),\n]\n\nTEST_CASE_3 = [\n    {\n        \"keys\": [\"img\", \"label\"],\n        \"spatial_size\": 6,\n        \"mode\": [\"trilinear\", \"nearest\"],\n        \"align_corners\": [True, None],\n        \"anti_aliasing\": [False, True],\n        \"anti_aliasing_sigma\": (None, 2.0),\n    },\n    (2, 4, 6),\n]\n\nTEST_CORRECT_CASES = [\n    ((32, -1), \"area\", False),\n    ((64, 64), \"area\", True),\n    ((32, 32, 32), \"area\", True),\n    ((256, 256), \"bilinear\", False),\n    ((256, 256), \"bilinear\", True),\n    ((128, 128), \"nearest\", False),\n   
 ((128, 128), \"nearest\", True),\n]\n\n\nclass TestResized(NumpyImageTestCase2D):\n    def test_invalid_inputs(self):\n        with self.assertRaises(ValueError):\n            resize = Resized(keys=\"img\", spatial_size=(128, 128, 3), mode=\"order\")\n            resize({\"img\": self.imt[0]})\n\n        with self.assertRaises(ValueError):\n            resize = Resized(keys=\"img\", spatial_size=(128,), mode=\"order\")\n            resize({\"img\": self.imt[0]})\n\n    def test_unchange(self):\n        resize = Resized(keys=\"img\", spatial_size=(128, 64), mode=\"bilinear\")\n        set_track_meta(False)\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            result = resize({\"img\": im})[\"img\"]\n            assert_allclose(im, result, type_test=False)\n        set_track_meta(True)\n\n    @parameterized.expand(TEST_CORRECT_CASES)\n    def test_correct_results(self, spatial_size, mode, anti_aliasing):\n        init_param = {\n            \"keys\": \"img\",\n            \"spatial_size\": spatial_size,\n            \"mode\": mode,\n            \"anti_aliasing\": anti_aliasing,\n            \"dtype\": np.float32,\n        }\n        resize = Resized(**init_param)\n        _order = 0\n        if mode.endswith(\"linear\"):\n            _order = 1\n        if spatial_size == (32, -1):\n            spatial_size = (32, 64)\n        expected = [\n            skimage.transform.resize(\n                channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=anti_aliasing\n            )\n            for channel in self.imt[0]\n        ]\n\n        expected = np.stack(expected).astype(np.float32)\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"data\": {\"img\": im}}\n            out = resize(**call_param)\n            lazy_resize = Resized(**init_param)\n            if init_param[\"mode\"] in (\"bilinear\", \"nearest\"):\n                
test_resampler_lazy(lazy_resize, out, init_param, call_param, output_key=\"img\", atol=1e-5)\n            test_local_inversion(resize, out, {\"img\": im}, \"img\")\n            assert_allclose(out[\"img\"], expected, type_test=False, atol=1.0)\n\n    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_longest_shape(self, input_param, expected_shape):\n        input_data = {\n            \"img\": np.random.randint(0, 2, size=[3, 4, 7, 10]),\n            \"label\": np.random.randint(0, 2, size=[3, 4, 7, 10]),\n        }\n        input_param[\"size_mode\"] = \"longest\"\n        rescaler = Resized(**input_param)\n        result = rescaler(input_data)\n        for k in rescaler.keys:\n            np.testing.assert_allclose(result[k].shape[1:], expected_shape)\n        set_track_meta(False)\n        result = Resized(**input_param)(input_data)\n        self.assertNotIsInstance(result[\"img\"], MetaTensor)\n        np.testing.assert_allclose(result[\"img\"].shape[1:], expected_shape)\n        set_track_meta(True)\n\n    def test_identical_spatial(self):\n        test_input = {\"X\": np.ones((1, 10, 16, 17))}\n        xform = Resized(\"X\", (-1, 16, 17))\n        out = xform(test_input)\n        out[\"Y\"] = 2 * out[\"X\"]\n        transform_inverse = Invertd(keys=\"Y\", transform=xform, orig_keys=\"X\")\n        assert_allclose(transform_inverse(out)[\"Y\"].array, np.ones((1, 10, 16, 17)) * 2)\n\n    def test_consistent_resize(self):\n        spatial_size = (16, 16, 16)\n        rescaler_1 = Resize(spatial_size=spatial_size, anti_aliasing=True, anti_aliasing_sigma=(0.5, 1.0, 2.0))\n        rescaler_2 = Resize(spatial_size=spatial_size, anti_aliasing=True, anti_aliasing_sigma=None)\n        rescaler_dict = Resized(\n            keys=[\"img1\", \"img2\"],\n            spatial_size=spatial_size,\n            anti_aliasing=(True, True),\n            anti_aliasing_sigma=[(0.5, 1.0, 2.0), None],\n        )\n        test_input_1 = 
torch.randn([3, 32, 32, 32])\n        test_input_2 = torch.randn([3, 32, 32, 32])\n        test_input_dict = {\"img1\": test_input_1, \"img2\": test_input_2}\n        assert_allclose(rescaler_1(test_input_1), rescaler_dict(test_input_dict)[\"img1\"])\n        assert_allclose(rescaler_2(test_input_2), rescaler_dict(test_input_dict)[\"img2\"])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rotate.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport scipy.ndimage\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.config import USE_COMPILED\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import Rotate\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import (\n    HAS_CUPY,\n    TEST_NDARRAYS_ALL,\n    NumpyImageTestCase2D,\n    NumpyImageTestCase3D,\n    test_local_inversion,\n)\n\nTEST_CASES_2D: list[tuple] = []\nfor p in TEST_NDARRAYS_ALL:\n    TEST_CASES_2D.append((p, np.pi / 6, False, \"bilinear\", \"border\", False))\n    TEST_CASES_2D.append((p, np.pi / 4, True, \"bilinear\", \"border\", False))\n    TEST_CASES_2D.append((p, -np.pi / 4.5, True, \"nearest\", \"border\" if USE_COMPILED else \"reflection\", False))\n    TEST_CASES_2D.append((p, np.pi, False, \"nearest\", \"zeros\", False))\n    TEST_CASES_2D.append((p, -np.pi / 2, False, \"bilinear\", \"zeros\", True))\n    if HAS_CUPY:  # 1 and cuda image requires cupy\n        TEST_CASES_2D.append((p, -np.pi / 2, False, 1, \"constant\", True))\n\nTEST_CASES_3D: list[tuple] = []\nfor p in TEST_NDARRAYS_ALL:\n    TEST_CASES_3D.append((p, -np.pi / 2, True, \"nearest\", \"border\", False))\n    TEST_CASES_3D.append((p, np.pi / 4, True, \"bilinear\", \"border\", False))\n    TEST_CASES_3D.append((p, -np.pi / 
4.5, True, \"nearest\", \"border\" if USE_COMPILED else \"reflection\", False))\n    TEST_CASES_3D.append((p, np.pi, False, \"nearest\", \"zeros\", False))\n    TEST_CASES_3D.append((p, -np.pi / 2, False, \"bilinear\", \"zeros\", False))\n    if HAS_CUPY:\n        TEST_CASES_3D.append((p, -np.pi / 2, False, 1, \"zeros\", False))\n\nTEST_CASES_SHAPE_3D: list[tuple] = []\nfor p in TEST_NDARRAYS_ALL:\n    TEST_CASES_SHAPE_3D.append((p, [-np.pi / 2, 1.0, 2.0], \"nearest\", \"border\", False))\n    TEST_CASES_SHAPE_3D.append((p, [np.pi / 4, 0, 0], \"bilinear\", \"border\", False))\n    TEST_CASES_SHAPE_3D.append((p, [-np.pi / 4.5, -20, 20], \"nearest\", \"reflection\", False))\n\n\nclass TestRotate2D(NumpyImageTestCase2D):\n    @parameterized.expand(TEST_CASES_2D)\n    def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):\n        init_param = {\n            \"angle\": angle,\n            \"keep_size\": keep_size,\n            \"mode\": mode,\n            \"padding_mode\": padding_mode,\n            \"align_corners\": align_corners,\n            \"dtype\": np.float64,\n        }\n        rotate_fn = Rotate(**init_param)\n        call_param = {\"img\": im_type(self.imt[0])}\n        rotated = rotate_fn(**call_param)\n        test_resampler_lazy(rotate_fn, rotated, init_param, call_param, atol=1e-4 if USE_COMPILED else 1e-6)\n        if keep_size:\n            np.testing.assert_allclose(self.imt[0].shape, rotated.shape)\n        _order = 0 if mode == \"nearest\" else 1\n        if padding_mode == \"border\":\n            _mode = \"nearest\"\n        elif padding_mode == \"reflection\":\n            _mode = \"reflect\"\n        else:\n            _mode = \"constant\"\n\n        expected = []\n        for channel in self.imt[0]:\n            expected.append(\n                scipy.ndimage.rotate(\n                    channel, -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False\n                )\n      
      )\n        expected = np.stack(expected).astype(np.float32)\n        rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated\n        good = np.sum(np.isclose(expected, rotated, atol=1e-3))\n        self.assertLessEqual(np.abs(good - expected.size), 5, \"diff at most 5 pixels\")\n\n\nclass TestRotate3D(NumpyImageTestCase3D):\n    @parameterized.expand(TEST_CASES_3D)\n    def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):\n        init_param = {\n            \"angle\": [angle, 0, 0],\n            \"keep_size\": keep_size,\n            \"mode\": mode,\n            \"padding_mode\": padding_mode,\n            \"align_corners\": align_corners,\n            \"dtype\": np.float64,\n        }\n        rotate_fn = Rotate(**init_param)\n        call_param = {\"img\": im_type(self.imt[0])}\n        rotated = rotate_fn(**call_param)\n        test_resampler_lazy(rotate_fn, rotated, init_param, call_param)\n        if keep_size:\n            np.testing.assert_allclose(self.imt[0].shape, rotated.shape)\n        _order = 0 if mode == \"nearest\" else 1\n        if padding_mode == \"border\":\n            _mode = \"nearest\"\n        elif padding_mode == \"reflection\":\n            _mode = \"reflect\"\n        else:\n            _mode = \"constant\"\n\n        expected = []\n        for channel in self.imt[0]:\n            expected.append(\n                scipy.ndimage.rotate(\n                    channel, -np.rad2deg(angle), (1, 2), not keep_size, order=_order, mode=_mode, prefilter=False\n                )\n            )\n        expected = np.stack(expected).astype(np.float32)\n        rotated = rotated.cpu() if isinstance(rotated, torch.Tensor) else rotated\n        n_good = np.sum(np.isclose(expected, rotated, atol=1e-3))\n        self.assertLessEqual(expected.size - n_good, 5, \"diff at most 5 pixels\")\n\n    @parameterized.expand(TEST_CASES_SHAPE_3D)\n    def test_correct_shape(self, im_type, angle, mode, 
padding_mode, align_corners):\n        rotate_fn = Rotate(angle, True, align_corners=align_corners, dtype=np.float64)\n        im = im_type(self.imt[0])\n        set_track_meta(False)\n        rotated = rotate_fn(im, mode=mode, padding_mode=padding_mode)\n        self.assertNotIsInstance(rotated, MetaTensor)\n        np.testing.assert_allclose(self.imt[0].shape, rotated.shape)\n        set_track_meta(True)\n        rotated = rotate_fn(im, mode=mode, padding_mode=padding_mode)\n        np.testing.assert_allclose(self.imt[0].shape, rotated.shape)\n        test_local_inversion(rotate_fn, rotated, im)\n\n    def test_ill_case(self):\n        for p in TEST_NDARRAYS_ALL:\n            rotate_fn = Rotate(10, True)\n            with self.assertRaises(ValueError):  # wrong shape\n                rotate_fn(p(self.imt))\n\n            rotate_fn = Rotate(10, keep_size=False)\n            with self.assertRaises(ValueError):  # wrong mode\n                rotate_fn(p(self.imt[0]), mode=\"trilinear_spell_error\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rotate90.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import Affine, Rotate90\nfrom monai.transforms.lazy.functional import apply_pending\nfrom monai.utils import optional_import\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import (\n    TEST_NDARRAYS_ALL,\n    NumpyImageTestCase2D,\n    NumpyImageTestCase3D,\n    assert_allclose,\n    test_local_inversion,\n)\n\n\nclass TestRotate90(NumpyImageTestCase2D):\n    def test_rotate90_default(self):\n        rotate = Rotate90()\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            set_track_meta(True)\n            call_param = {\"img\": im}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, im)\n            expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=\"tensor\")\n            set_track_meta(False)\n            rotated = rotate(im)\n            self.assertNotIsInstance(rotated, MetaTensor)\n         
   set_track_meta(True)\n\n    def test_k(self):\n        rotate = Rotate90(k=2)\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"img\": im}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, im)\n            expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=\"tensor\")\n\n    def test_spatial_axes(self):\n        rotate = Rotate90(spatial_axes=(0, -1))\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"img\": im}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, im)\n            expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=\"tensor\")\n\n    def test_prob_k_spatial_axes(self):\n        rotate = Rotate90(k=2, spatial_axes=(0, 1))\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"img\": im}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, im)\n            expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=\"tensor\")\n\n\nclass 
TestRotate903d(NumpyImageTestCase3D):\n    def test_rotate90_default(self):\n        rotate = Rotate90()\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"img\": im}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, im)\n            expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=\"tensor\")\n\n    def test_k(self):\n        rotate = Rotate90(k=2)\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"img\": im}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, im)\n            expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=\"tensor\")\n\n    def test_spatial_axes(self):\n        rotate = Rotate90(spatial_axes=(0, -1))\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"img\": im}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, im)\n            expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=\"tensor\")\n\n    def test_prob_k_spatial_axes(self):\n    
    rotate = Rotate90(k=2, spatial_axes=(0, 1))\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"img\": im}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, im)\n            expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=\"tensor\")\n\n\n@unittest.skipUnless(optional_import(\"scipy\")[1], \"Requires scipy library.\")\nclass TestRot90Consistency(unittest.TestCase):\n    @parameterized.expand([[2], [3], [4]])\n    def test_affine_rot90(self, s):\n        \"\"\"s\"\"\"\n        im = np.arange(int(s * s)).reshape(1, s, s).astype(float)\n        mat = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])\n\n        def method_0(im, ac):\n            xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=s)\n            xform.lazy = True\n            out = xform(im)\n            out = apply_pending(out, overrides={\"padding_mode\": \"border\", \"align_corners\": ac})[0]\n            return out\n\n        def method_1(im, ac):\n            xform = Affine(align_corners=ac, affine=mat, image_only=True, spatial_size=s)\n            xform.lazy = True\n            out = xform(im)\n            out = apply_pending(out, overrides={\"mode\": 1, \"padding_mode\": \"nearest\", \"align_corners\": ac})[0]\n            return out\n\n        def method_2(im, ac):\n            xform = Affine(align_corners=ac, affine=mat, padding_mode=\"border\", image_only=True, spatial_size=s)\n            out = xform(im)\n            return out\n\n        def method_3(im, ac):\n            xform = Affine(\n                align_corners=ac, affine=mat, mode=1, padding_mode=\"nearest\", image_only=True, 
spatial_size=s\n            )\n            out = xform(im)\n            return out\n\n        for call in (method_0, method_1, method_2, method_3):\n            for ac in (False, True):\n                out = call(im, ac)\n                ref = Rotate90()(im)\n                assert_allclose(out, ref, rtol=1e-4, atol=1e-4, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rotate90d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import Rotate90d\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion\n\n\nclass TestRotate90d(NumpyImageTestCase2D):\n    def test_rotate90_default(self):\n        key = \"test\"\n        rotate = Rotate90d(keys=key)\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            set_track_meta(True)\n            call_param = {\"data\": {key: im}}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, {key: im}, key)\n            expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated[key], p(expected), type_test=\"tensor\")\n            set_track_meta(False)\n            rotated = rotate({key: im})\n            self.assertNotIsInstance(rotated[key], MetaTensor)\n            set_track_meta(True)\n\n    def test_k(self):\n        key = \"test\"\n        rotate = Rotate90d(keys=key, k=2)\n        for p in 
TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"data\": {key: im}}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, {key: im}, key)\n            expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated[key], p(expected), type_test=\"tensor\")\n\n    def test_spatial_axes(self):\n        key = \"test\"\n        rotate = Rotate90d(keys=key, spatial_axes=(0, 1))\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"data\": {key: im}}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, {key: im}, key)\n            expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated[key], p(expected), type_test=\"tensor\")\n\n    def test_prob_k_spatial_axes(self):\n        key = \"test\"\n        rotate = Rotate90d(keys=key, k=2, spatial_axes=(0, 1))\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"data\": {key: im}}\n            rotated = rotate(**call_param)\n\n            # test lazy\n            test_resampler_lazy(rotate, rotated, call_param=call_param, output_key=key)\n            rotate.lazy = False\n\n            test_local_inversion(rotate, rotated, {key: im}, key)\n            expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]\n            expected = np.stack(expected)\n            assert_allclose(rotated[key], p(expected), type_test=\"tensor\")\n\n    def 
test_no_key(self):\n        key = \"unknown\"\n        rotate = Rotate90d(keys=key)\n        with self.assertRaisesRegex(KeyError, \"\"):\n            rotate({\"test\": self.imt[0]})\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_rotated.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport scipy.ndimage\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.config import USE_COMPILED\nfrom monai.data import MetaTensor\nfrom monai.transforms import Rotated\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion\n\nTEST_CASES_2D: list[tuple] = []\nfor p in TEST_NDARRAYS_ALL:\n    TEST_CASES_2D.append((p, -np.pi / 6, False, \"bilinear\", \"border\", False))\n    TEST_CASES_2D.append((p, -np.pi / 4, True, \"bilinear\", \"border\", False))\n    TEST_CASES_2D.append((p, np.pi / 4.5, True, \"nearest\", \"reflection\", False))\n    TEST_CASES_2D.append((p, -np.pi, False, \"nearest\", \"zeros\", False))\n    TEST_CASES_2D.append((p, np.pi / 2, False, \"bilinear\", \"zeros\", True))\n\nTEST_CASES_3D: list[tuple] = []\nfor p in TEST_NDARRAYS_ALL:\n    TEST_CASES_3D.append((p, -np.pi / 6, False, \"bilinear\", \"border\", False))\n    TEST_CASES_3D.append((p, -np.pi / 4, True, \"bilinear\", \"border\", False))\n    TEST_CASES_3D.append((p, np.pi / 4.5, True, \"nearest\", \"reflection\", False))\n    TEST_CASES_3D.append((p, -np.pi, False, \"nearest\", \"zeros\", False))\n    TEST_CASES_3D.append((p, np.pi / 2, False, \"bilinear\", \"zeros\", 
True))\n\n\n@unittest.skipIf(USE_COMPILED, \"unittests are not designed for both USE_COMPILED=True/False\")\nclass TestRotated2D(NumpyImageTestCase2D):\n    @parameterized.expand(TEST_CASES_2D)\n    def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):\n        init_param = {\n            \"keys\": (\"img\", \"seg\"),\n            \"angle\": angle,\n            \"keep_size\": keep_size,\n            \"mode\": (mode, \"nearest\"),\n            \"padding_mode\": padding_mode,\n            \"align_corners\": align_corners,\n            \"dtype\": np.float64,\n        }\n        rotate_fn = Rotated(**init_param)\n        im = im_type(self.imt[0])\n        call_param = {\"data\": {\"img\": im, \"seg\": im_type(self.segn[0])}}\n        rotated = rotate_fn(**call_param)\n        # test lazy\n        lazy_init_param = init_param.copy()\n        for k, m in zip(init_param[\"keys\"], init_param[\"mode\"]):\n            lazy_init_param[\"keys\"], lazy_init_param[\"mode\"] = k, m\n            test_resampler_lazy(\n                rotate_fn, rotated, lazy_init_param, call_param, output_key=k, atol=1e-4 if USE_COMPILED else 1e-6\n            )\n        if keep_size:\n            np.testing.assert_allclose(self.imt[0].shape, rotated[\"img\"].shape)\n        _order = 0 if mode == \"nearest\" else 1\n        if padding_mode == \"border\":\n            _mode = \"nearest\"\n        elif padding_mode == \"reflection\":\n            _mode = \"reflect\"\n        else:\n            _mode = \"constant\"\n        expected = scipy.ndimage.rotate(\n            self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False\n        )\n        for k, v in rotated.items():\n            rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v\n        good = np.sum(np.isclose(expected, rotated[\"img\"][0], atol=1e-3))\n        self.assertLessEqual(np.abs(good - expected.size), 5, \"diff at most 5 pixels\")\n        
test_local_inversion(rotate_fn, rotated, {\"img\": im}, \"img\")\n\n        expected = scipy.ndimage.rotate(\n            self.segn[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=0, mode=_mode, prefilter=False\n        )\n        expected = np.stack(expected).astype(int)\n        if isinstance(rotated[\"seg\"], MetaTensor):\n            rotated[\"seg\"] = rotated[\"seg\"].as_tensor()  # pytorch 1.7 compatible\n        self.assertLessEqual(np.count_nonzero(expected != rotated[\"seg\"][0]), 30)\n\n\n@unittest.skipIf(USE_COMPILED, \"unittests are not designed for both USE_COMPILED=True/False\")\nclass TestRotated3D(NumpyImageTestCase3D):\n    @parameterized.expand(TEST_CASES_3D)\n    def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):\n        init_param = {\n            \"keys\": (\"img\", \"seg\"),\n            \"angle\": [0, angle, 0],\n            \"keep_size\": keep_size,\n            \"mode\": (mode, \"nearest\"),\n            \"padding_mode\": padding_mode,\n            \"align_corners\": align_corners,\n            \"dtype\": np.float64,\n        }\n        rotate_fn = Rotated(**init_param)\n        call_param = {\"data\": {\"img\": im_type(self.imt[0]), \"seg\": im_type(self.segn[0])}}\n        rotated = rotate_fn(**call_param)\n        # test lazy\n        lazy_init_param = init_param.copy()\n        for k, m in zip(init_param[\"keys\"], init_param[\"mode\"]):\n            lazy_init_param[\"keys\"], lazy_init_param[\"mode\"] = k, m\n            test_resampler_lazy(\n                rotate_fn, rotated, lazy_init_param, call_param, output_key=k, atol=1e-4 if USE_COMPILED else 1e-6\n            )\n        if keep_size:\n            np.testing.assert_allclose(self.imt[0].shape, rotated[\"img\"].shape)\n        _order = 0 if mode == \"nearest\" else 1\n        if padding_mode == \"border\":\n            _mode = \"nearest\"\n        elif padding_mode == \"reflection\":\n            _mode = \"reflect\"\n        
else:\n            _mode = \"constant\"\n        expected = scipy.ndimage.rotate(\n            self.imt[0, 0], np.rad2deg(angle), (0, 2), not keep_size, order=_order, mode=_mode, prefilter=False\n        )\n        for k, v in rotated.items():\n            rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v\n        good = np.sum(np.isclose(expected.astype(np.float32), rotated[\"img\"][0], atol=1e-3))\n        self.assertLessEqual(np.abs(good - expected.size), 5, \"diff at most 5 voxels.\")\n\n        expected = scipy.ndimage.rotate(\n            self.segn[0, 0], np.rad2deg(angle), (0, 2), not keep_size, order=0, mode=_mode, prefilter=False\n        )\n        expected = np.stack(expected).astype(int)\n        if isinstance(rotated[\"seg\"], MetaTensor):\n            rotated[\"seg\"] = rotated[\"seg\"].as_tensor()  # pytorch 1.7 compatible\n        self.assertLessEqual(np.count_nonzero(expected != rotated[\"seg\"][0]), 160)\n\n\n@unittest.skipIf(USE_COMPILED, \"unittests are not designed for both USE_COMPILED=True/False\")\nclass TestRotated3DXY(NumpyImageTestCase3D):\n    @parameterized.expand(TEST_CASES_3D)\n    def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, align_corners):\n        rotate_fn = Rotated(\n            (\"img\", \"seg\"), [0, 0, angle], keep_size, (mode, \"nearest\"), padding_mode, align_corners, dtype=np.float64\n        )\n        rotated = rotate_fn({\"img\": im_type(self.imt[0]), \"seg\": im_type(self.segn[0])})\n        if keep_size:\n            np.testing.assert_allclose(self.imt[0].shape, rotated[\"img\"].shape)\n        _order = 0 if mode == \"nearest\" else 1\n        if padding_mode == \"border\":\n            _mode = \"nearest\"\n        elif padding_mode == \"reflection\":\n            _mode = \"reflect\"\n        else:\n            _mode = \"constant\"\n        expected = scipy.ndimage.rotate(\n            self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, 
prefilter=False\n        )\n        for k, v in rotated.items():\n            rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v\n        good = np.sum(np.isclose(expected, rotated[\"img\"][0], atol=1e-3))\n        self.assertLessEqual(np.abs(good - expected.size), 5, \"diff at most 5 voxels\")\n\n        expected = scipy.ndimage.rotate(\n            self.segn[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=0, mode=_mode, prefilter=False\n        )\n        expected = np.stack(expected).astype(int)\n        if isinstance(rotated[\"seg\"], MetaTensor):\n            rotated[\"seg\"] = rotated[\"seg\"].as_tensor()  # pytorch 1.7 compatible\n        self.assertLessEqual(np.count_nonzero(expected != rotated[\"seg\"][0]), 160)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_save_classificationd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport csv\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\n\nfrom monai.data import CSVSaver, decollate_batch\nfrom monai.transforms import Compose, CopyItemsd, SaveClassificationd\nfrom monai.utils.enums import PostFix\n\n\nclass TestSaveClassificationd(unittest.TestCase):\n\n    def test_saved_content(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            data = [\n                {\n                    \"pred\": torch.zeros(8),\n                    PostFix.meta(\"image\"): {\"filename_or_obj\": [\"testfile\" + str(i) for i in range(8)]},\n                },\n                {\n                    \"pred\": torch.zeros(8),\n                    PostFix.meta(\"image\"): {\"filename_or_obj\": [\"testfile\" + str(i) for i in range(8, 16)]},\n                },\n                {\n                    \"pred\": torch.zeros(8),\n                    PostFix.meta(\"image\"): {\"filename_or_obj\": [\"testfile\" + str(i) for i in range(16, 24)]},\n                },\n            ]\n\n            saver = CSVSaver(\n                output_dir=Path(tempdir), filename=\"predictions2.csv\", overwrite=False, flush=False, delimiter=\"\\t\"\n            )\n            # set up test transforms\n            post_trans = Compose(\n                [\n                    
CopyItemsd(keys=PostFix.meta(\"image\"), times=1, names=PostFix.meta(\"pred\")),\n                    # 1st saver saves data into CSV file\n                    SaveClassificationd(\n                        keys=\"pred\",\n                        saver=None,\n                        meta_keys=None,\n                        output_dir=Path(tempdir),\n                        filename=\"predictions1.csv\",\n                        delimiter=\"\\t\",\n                        overwrite=True,\n                    ),\n                    # 2nd saver only saves data into the cache, manually finalize later\n                    SaveClassificationd(keys=\"pred\", saver=saver, meta_key_postfix=PostFix.meta()),\n                ]\n            )\n            # simulate inference 2 iterations\n            d = decollate_batch(data[0])\n            for i in d:\n                post_trans(i)\n            d = decollate_batch(data[1])\n            for i in d:\n                post_trans(i)\n            # write into CSV file\n            saver.finalize()\n\n            # 3rd saver will not delete previous data due to `overwrite=False`\n            trans2 = SaveClassificationd(\n                keys=\"pred\",\n                saver=None,\n                meta_keys=PostFix.meta(\"image\"),  # specify meta key, so no need to copy anymore\n                output_dir=tempdir,\n                filename=\"predictions1.csv\",\n                delimiter=\"\\t\",\n                overwrite=False,\n            )\n            d = decollate_batch(data[2])\n            for i in d:\n                trans2(i)\n\n            def _test_file(filename, count):\n                filepath = os.path.join(tempdir, filename)\n                self.assertTrue(os.path.exists(filepath))\n                with open(filepath) as f:\n                    reader = csv.reader(f, delimiter=\"\\t\")\n                    i = 0\n                    for row in reader:\n                        self.assertEqual(row[0], 
\"testfile\" + str(i))\n                        self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0)\n                        i += 1\n                    self.assertEqual(i, count)\n\n            _test_file(\"predictions1.csv\", 24)\n            _test_file(\"predictions2.csv\", 16)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_save_image.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import SaveImage\nfrom monai.utils import optional_import\n\n_, has_itk = optional_import(\"itk\", allow_namespace_pkg=True)\n\nTEST_CASE_1 = [torch.randint(0, 255, (1, 2, 3, 4)), {\"filename_or_obj\": \"testfile0.nii.gz\"}, \".nii.gz\", False]\n\nTEST_CASE_2 = [torch.randint(0, 255, (1, 2, 3, 4)), None, \".nii.gz\", False]\n\nTEST_CASE_3 = [torch.randint(0, 255, (1, 2, 3, 4)), {\"filename_or_obj\": \"testfile0.nrrd\"}, \".nrrd\", False]\n\nTEST_CASE_4 = [\n    torch.randint(0, 255, (3, 2, 4, 5), dtype=torch.uint8),\n    {\"filename_or_obj\": \"testfile0.dcm\"},\n    \".dcm\",\n    False,\n]\n\nTEST_CASE_5 = [torch.randint(0, 255, (3, 2, 4, 5), dtype=torch.uint8), \".dcm\", False]\n\n\n@unittest.skipUnless(has_itk, \"itk not installed\")\nclass TestSaveImage(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_saved_content(self, test_data, meta_data, output_ext, resample):\n        if meta_data is not None:\n            test_data = MetaTensor(test_data, meta=meta_data)\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            trans = SaveImage(\n                output_dir=tempdir,\n       
         output_ext=output_ext,\n                resample=resample,\n                separate_folder=False,  # test saving into the same folder\n                output_name_formatter=lambda x, xform: dict(subject=x[\"filename_or_obj\"] if x else \"0\"),\n            )\n            trans(test_data)\n\n            filepath = \"testfile0\" if meta_data is not None else \"0\"\n            self.assertTrue(os.path.exists(os.path.join(tempdir, filepath + \"_trans\" + output_ext)))\n\n    @parameterized.expand([TEST_CASE_5])\n    def test_saved_content_with_filename(self, test_data, output_ext, resample):\n        with tempfile.TemporaryDirectory() as tempdir:\n            trans = SaveImage(\n                output_dir=tempdir,\n                output_ext=output_ext,\n                resample=resample,\n                separate_folder=False,  # test saving into the same folder\n            )\n            filename = str(os.path.join(tempdir, \"test\"))\n            trans(test_data, filename=filename)\n\n            self.assertTrue(os.path.exists(filename + output_ext))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_save_imaged.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.config import PathLike\nfrom monai.data.folder_layout import FolderLayoutBase\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import SaveImaged\nfrom monai.utils import optional_import\n\n_, has_itk = optional_import(\"itk\", allow_namespace_pkg=True)\n\nTEST_CASE_1 = [\n    {\"img\": MetaTensor(torch.randint(0, 255, (1, 2, 3, 4)), meta={\"filename_or_obj\": \"testfile0.nii.gz\"})},\n    \".nii.gz\",\n    False,\n]\n\nTEST_CASE_2 = [\n    {\n        \"img\": MetaTensor(torch.randint(0, 255, (1, 2, 3, 4)), meta={\"filename_or_obj\": \"testfile0.nii.gz\"}),\n        \"patch_index\": 6,\n    },\n    \".nii.gz\",\n    False,\n]\n\nTEST_CASE_3 = [\n    {\n        \"img\": MetaTensor(torch.randint(0, 255, (1, 2, 3, 4)), meta={\"filename_or_obj\": \"testfile0.nrrd\"}),\n        \"patch_index\": 6,\n    },\n    \".nrrd\",\n    False,\n]\n\n\n@unittest.skipUnless(has_itk, \"itk not installed\")\nclass TestSaveImaged(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_saved_content(self, test_data, output_ext, resample):\n        with tempfile.TemporaryDirectory() as tempdir:\n            trans = SaveImaged(\n                
keys=[\"img\", \"pred\"],\n                output_dir=tempdir,\n                output_ext=output_ext,\n                resample=resample,\n                allow_missing_keys=True,\n            )\n            trans(test_data)\n\n            patch_index = test_data[\"img\"].meta.get(\"patch_index\", None)\n            patch_index = f\"_{patch_index}\" if patch_index is not None else \"\"\n            filepath = os.path.join(\"testfile0\", \"testfile0\" + \"_trans\" + patch_index + output_ext)\n            self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_custom_folderlayout(self, test_data, output_ext, resample):\n\n        class TestFolderLayout(FolderLayoutBase):\n\n            def __init__(self, basepath: Path, extension: str, makedirs: bool):\n                self.basepath = basepath\n                self.ext = extension\n                self.makedirs = makedirs\n\n            def filename(self, **kwargs) -> PathLike:\n                p = self.basepath / str(kwargs[\"subdirectory\"])\n                if not p.exists() and self.makedirs:\n                    p.mkdir()\n\n                return p / (str(kwargs[\"filename\"]) + self.ext)\n\n        def name_formatter(metadict: dict, _) -> dict:\n            # \"[filename].[ext]\"\n            # quick and dirty split on .\n            base_filename = metadict[\"filename_or_obj\"].split(\".\")[0]\n\n            return {\"subdirectory\": base_filename, \"filename\": \"image\"}\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            trans = SaveImaged(\n                keys=[\"img\", \"pred\"],\n                resample=resample,\n                allow_missing_keys=True,\n                output_name_formatter=name_formatter,\n                folder_layout=TestFolderLayout(basepath=Path(tempdir), extension=output_ext, makedirs=True),\n            )\n            trans(test_data)\n\n            filepath = 
os.path.join(\"testfile0\", \"image\" + output_ext)\n            self.assertTrue(os.path.exists(os.path.join(tempdir, filepath)))\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_includes_metadata(self, test_data, output_ext, resample):\n        with tempfile.TemporaryDirectory() as tempdir:\n            trans = SaveImaged(\n                keys=[\"img\", \"pred\"],\n                output_dir=tempdir,\n                output_ext=output_ext,\n                resample=resample,\n                allow_missing_keys=True,\n                savepath_in_metadict=True,\n            )\n            trans(test_data)\n\n            self.assertTrue(\"saved_to\" in test_data[\"img\"].meta.keys())\n            self.assertTrue(os.path.exists(test_data[\"img\"].meta[\"saved_to\"]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_savitzky_golay_smooth.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import SavitzkyGolaySmooth\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n# Zero-padding trivial tests\n\nTEST_CASE_SINGLE_VALUE = [\n    {\"window_length\": 3, \"order\": 1},\n    np.expand_dims(np.array([1.0]), 0),  # Input data: Single value\n    np.expand_dims(np.array([1 / 3]), 0),  # Expected output: With a window length of 3 and polyorder 1\n    # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed)\n    1e-5,  # absolute tolerance\n]\n\nTEST_CASE_2D_AXIS_2 = [\n    {\"window_length\": 3, \"order\": 1, \"axis\": 2},  # along axis 2 (second spatial dim)\n    np.expand_dims(np.ones((2, 3)), 0),\n    np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0),\n    1e-5,  # absolute tolerance\n]\n\n# Replicated-padding trivial tests\n\nTEST_CASE_SINGLE_VALUE_REP = [\n    {\"window_length\": 3, \"order\": 1, \"mode\": \"replicate\"},\n    np.expand_dims(np.array([1.0]), 0),  # Input data: Single value\n    np.expand_dims(np.array([1.0]), 0),  # Expected output: With a window length of 3 and polyorder 1\n    # output will be equal to mean of [1, 1, 1] = 1 (input will be nearest-neighbour-padded and a linear fit performed)\n    1e-5,  # 
absolute tolerance\n]\n\n# Sine smoothing\n\nTEST_CASE_SINE_SMOOTH = [\n    {\"window_length\": 3, \"order\": 1},\n    # Sine wave with period equal to savgol window length (windowed to reduce edge effects).\n    np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0),\n    # Should be smoothed out to zeros\n    np.expand_dims(np.zeros(100), 0),\n    # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input\n    2e-2,  # absolute tolerance\n]\n\n\nclass TestSavitzkyGolaySmooth(unittest.TestCase):\n    @parameterized.expand(\n        [TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH, TEST_CASE_SINGLE_VALUE_REP]\n    )\n    def test_value(self, arguments, image, expected_data, atol):\n        for p in TEST_NDARRAYS:\n            result = SavitzkyGolaySmooth(**arguments)(p(image.astype(np.float32)))\n            assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-4, atol=atol, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_savitzky_golay_smoothd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import SavitzkyGolaySmoothd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n# Zero-padding trivial tests\n\nTEST_CASE_SINGLE_VALUE = [\n    {\"keys\": \"img\", \"window_length\": 3, \"order\": 1},\n    np.expand_dims(np.array([1.0]), 0),  # Input data: Single value\n    np.expand_dims(np.array([1 / 3]), 0),  # Expected output: With a window length of 3 and polyorder 1\n    # output should be equal to mean of 0, 1 and 0 = 1/3 (because input will be zero-padded and a linear fit performed)\n    1e-5,  # absolute tolerance\n]\n\nTEST_CASE_2D_AXIS_2 = [\n    {\"keys\": \"img\", \"window_length\": 3, \"order\": 1, \"axis\": 2},  # along axis 2 (second spatial dim)\n    np.expand_dims(np.ones((2, 3)), 0),\n    np.expand_dims(np.array([[2 / 3, 1.0, 2 / 3], [2 / 3, 1.0, 2 / 3]]), 0),\n    1e-5,  # absolute tolerance\n]\n\n# Replicated-padding trivial tests\n\nTEST_CASE_SINGLE_VALUE_REP = [\n    {\"keys\": \"img\", \"window_length\": 3, \"order\": 1, \"mode\": \"replicate\"},\n    np.expand_dims(np.array([1.0]), 0),  # Input data: Single value\n    np.expand_dims(np.array([1.0]), 0),  # Expected output: With a window length of 3 and polyorder 1\n    # output will be equal to mean of [1, 1, 1] = 1 (input will be 
nearest-neighbour-padded and a linear fit performed)\n    1e-5,  # absolute tolerance\n]\n\n# Sine smoothing\n\nTEST_CASE_SINE_SMOOTH = [\n    {\"keys\": \"img\", \"window_length\": 3, \"order\": 1},\n    # Sine wave with period equal to savgol window length (windowed to reduce edge effects).\n    np.expand_dims(np.sin(2 * np.pi * 1 / 3 * np.arange(100)) * np.hanning(100), 0),\n    # Should be smoothed out to zeros\n    np.expand_dims(np.zeros(100), 0),\n    # tolerance chosen by examining output of SciPy.signal.savgol_filter() when provided the above input\n    2e-2,  # absolute tolerance\n]\n\n\nclass TestSavitzkyGolaySmoothd(unittest.TestCase):\n    @parameterized.expand(\n        [TEST_CASE_SINGLE_VALUE, TEST_CASE_2D_AXIS_2, TEST_CASE_SINE_SMOOTH, TEST_CASE_SINGLE_VALUE_REP]\n    )\n    def test_value(self, arguments, image, expected_data, atol):\n        for p in TEST_NDARRAYS:\n            result = SavitzkyGolaySmoothd(**arguments)({\"img\": p(image.astype(np.float32))})[\"img\"]\n            assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-4, atol=atol, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_scale_intensity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import ScaleIntensity\nfrom monai.transforms.utils import rescale_array\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestScaleIntensity(NumpyImageTestCase2D):\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_range_scale(self, p):\n        scaler = ScaleIntensity(minv=1.0, maxv=2.0)\n        im = p(self.imt)\n        result = scaler(im)\n        mina = self.imt.min()\n        maxa = self.imt.max()\n        norm = (self.imt - mina) / (maxa - mina)\n        expected = p((norm * (2.0 - 1.0)) + 1.0)\n        assert_allclose(result, expected, type_test=\"tensor\", rtol=1e-7, atol=0)\n\n    def test_factor_scale(self):\n        for p in TEST_NDARRAYS:\n            scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1)\n            result = scaler(p(self.imt))\n            expected = p((self.imt * (1 + 0.1)).astype(np.float32))\n            assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-7, atol=0)\n\n    def test_max_none(self):\n        for p in TEST_NDARRAYS:\n            scaler = ScaleIntensity(minv=0.0, maxv=None, factor=0.1)\n            result = scaler(p(self.imt))\n            expected = rescale_array(p(self.imt), minv=0.0, maxv=None)\n      
      assert_allclose(result, expected, type_test=\"tensor\", rtol=1e-3, atol=1e-3)\n\n    def test_int(self):\n        \"\"\"integers should be handled by converting them to floats first.\"\"\"\n        for p in TEST_NDARRAYS:\n            scaler = ScaleIntensity(minv=1.0, maxv=2.0)\n            result = scaler(p(self.imt.astype(int)))\n            _imt = self.imt.astype(int).astype(np.float32)\n            mina = _imt.min()\n            maxa = _imt.max()\n            norm = (_imt - mina) / (maxa - mina)\n            expected = p((norm * (2.0 - 1.0)) + 1.0)\n            assert_allclose(result, expected, type_test=\"tensor\", rtol=1e-7, atol=0)\n\n    def test_channel_wise(self):\n        for p in TEST_NDARRAYS:\n            scaler = ScaleIntensity(minv=1.0, maxv=2.0, channel_wise=True)\n            data = p(np.tile(self.imt, (3, 1, 1, 1)))\n            result = scaler(data)\n            mina = self.imt.min()\n            maxa = self.imt.max()\n            for i, c in enumerate(data):\n                norm = (c - mina) / (maxa - mina)\n                expected = p((norm * (2.0 - 1.0)) + 1.0)\n                assert_allclose(result[i], expected, type_test=\"tensor\", rtol=1e-7, atol=0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_scale_intensity_fixed_mean.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import ScaleIntensityFixedMean\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestScaleIntensityFixedMean(NumpyImageTestCase2D):\n    def test_factor_scale(self):\n        for p in TEST_NDARRAYS:\n            scaler = ScaleIntensityFixedMean(factor=0.1, fixed_mean=False)\n            result = scaler(p(self.imt))\n            expected = p((self.imt * (1 + 0.1)).astype(np.float32))\n            assert_allclose(result, p(expected), type_test=\"tensor\", rtol=1e-7, atol=0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_preserve_range(self, p):\n        for channel_wise in [False, True]:\n            factor = 0.9\n            scaler = ScaleIntensityFixedMean(\n                factor=factor, preserve_range=True, channel_wise=channel_wise, fixed_mean=False\n            )\n            im = p(self.imt)\n            result = scaler(im)\n\n            if False:  # channel_wise:\n                out = []\n                for d in im:\n                    clip_min = d.min()\n                    clip_max = d.max()\n                    d = (1 + factor) * d\n                    d[d < clip_min] = clip_min\n                    d[d > clip_max] = clip_max\n                    
out.append(d)\n                expected = p(out)\n            else:\n                clip_min = im.min()\n                clip_max = im.max()\n                im = (1 + factor) * im\n                im[im < clip_min] = clip_min\n                im[im > clip_max] = clip_max\n                expected = im\n            assert_allclose(result, expected, type_test=\"tensor\", atol=1e-7)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_fixed_mean(self, p):\n        for channel_wise in [False, True]:\n            factor = 0.9\n            scaler = ScaleIntensityFixedMean(factor=factor, fixed_mean=True, channel_wise=channel_wise)\n            im = p(self.imt)\n            result = scaler(im)\n            mn = im.mean()\n            im = im - mn\n            expected = (1 + factor) * im\n            expected = expected + mn\n            assert_allclose(result, expected, type_test=\"tensor\", atol=1e-7)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_fixed_mean_preserve_range(self, p):\n        for channel_wise in [False, True]:\n            factor = 0.9\n            scaler = ScaleIntensityFixedMean(\n                factor=factor, preserve_range=True, fixed_mean=True, channel_wise=channel_wise\n            )\n            im = p(self.imt)\n            clip_min = im.min()\n            clip_max = im.max()\n            result = scaler(im)\n            mn = im.mean()\n            im = im - mn\n            expected = (1 + factor) * im\n            expected = expected + mn\n            expected[expected < clip_min] = clip_min\n            expected[expected > clip_max] = clip_max\n            assert_allclose(result, expected, type_test=\"tensor\", atol=1e-7)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_scale_intensity_range.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms import ScaleIntensityRange\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass IntensityScaleIntensityRange(NumpyImageTestCase2D):\n    def test_image_scale_intensity_range(self):\n        scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80, dtype=np.uint8)\n        for p in TEST_NDARRAYS:\n            scaled = scaler(p(self.imt))\n            self.assertTrue(scaled.dtype, np.uint8)\n            expected = (((self.imt - 20) / 88) * 30 + 50).astype(np.uint8)\n            assert_allclose(scaled, p(expected), type_test=\"tensor\")\n\n    def test_image_scale_intensity_range_none_clip(self):\n        scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=None, b_max=80, clip=True, dtype=np.uint8)\n        for p in TEST_NDARRAYS:\n            scaled = scaler(p(self.imt))\n            self.assertTrue(scaled.dtype, np.uint8)\n            expected = (np.clip((self.imt - 20) / 88, None, 80)).astype(np.uint8)\n            assert_allclose(scaled, p(expected), type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_scale_intensity_ranged.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.transforms import ScaleIntensityRanged\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass IntensityScaleIntensityRanged(NumpyImageTestCase2D):\n    def test_image_scale_intensity_ranged(self):\n        key = \"img\"\n        scaler = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=50, b_max=80)\n        for p in TEST_NDARRAYS:\n            scaled = scaler({key: p(self.imt)})\n            expected = (self.imt - 20) / 88\n            expected = expected * 30 + 50\n            assert_allclose(scaled[key], p(expected), type_test=\"tensor\")\n\n    def test_image_scale_intensity_ranged_none(self):\n        key = \"img\"\n        scaler = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=None, b_max=None)\n        for p in TEST_NDARRAYS:\n            scaled = scaler({key: p(self.imt)})\n            expected = (self.imt - 20) / 88\n            assert_allclose(scaled[key], p(expected), type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_scale_intensityd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms import ScaleIntensityd\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestScaleIntensityd(NumpyImageTestCase2D):\n    def test_range_scale(self):\n        key = \"img\"\n        for p in TEST_NDARRAYS:\n            scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0)\n            result = scaler({key: p(self.imt)})\n            mina = np.min(self.imt)\n            maxa = np.max(self.imt)\n            norm = (self.imt - mina) / (maxa - mina)\n            expected = (norm * (2.0 - 1.0)) + 1.0\n            assert_allclose(result[key], p(expected), type_test=\"tensor\")\n\n    def test_factor_scale(self):\n        key = \"img\"\n        for p in TEST_NDARRAYS:\n            scaler = ScaleIntensityd(keys=[key], minv=None, maxv=None, factor=0.1)\n            result = scaler({key: p(self.imt)})\n            expected = (self.imt * (1 + 0.1)).astype(np.float32)\n            assert_allclose(result[key], p(expected), type_test=\"tensor\")\n\n    def test_channel_wise(self):\n        key = \"img\"\n        for p in TEST_NDARRAYS:\n            scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0, channel_wise=True)\n            data = p(self.imt)\n            result = scaler({key: data})\n            mina = self.imt.min()\n           
 maxa = self.imt.max()\n            for i, c in enumerate(data):\n                norm = (c - mina) / (maxa - mina)\n                expected = p((norm * (2.0 - 1.0)) + 1.0)\n                assert_allclose(result[key][i], expected, type_test=\"tensor\", rtol=1e-7, atol=0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_select_itemsd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport sys\nimport time\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import SelectItemsd\n\nTEST_CASE_1 = [{\"keys\": [str(i) for i in range(30)]}, 30]\n\n\nclass TestSelectItemsd(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1])\n    def test_memory(self, input_param, expected_key_size):\n        input_data = {}\n        for i in range(50):\n            input_data[str(i)] = [time.time()] * 100000\n        result = SelectItemsd(**input_param)(input_data)\n        self.assertEqual(len(result.keys()), expected_key_size)\n        self.assertSetEqual(set(result.keys()), set(input_param[\"keys\"]))\n        self.assertGreaterEqual(\n            sys.getsizeof(input_data) * float(expected_key_size) / len(input_data), sys.getsizeof(result)\n        )\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_shift_intensity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms import ShiftIntensity\nfrom tests.test_utils import NumpyImageTestCase2D\n\n\nclass TestShiftIntensity(NumpyImageTestCase2D):\n    def test_value(self):\n        shifter = ShiftIntensity(offset=1.0)\n        result = shifter(self.imt)\n        expected = self.imt + 1.0\n        np.testing.assert_allclose(result, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_shift_intensityd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms import IntensityStatsd, ShiftIntensityd\nfrom monai.utils.enums import PostFix\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestShiftIntensityd(NumpyImageTestCase2D):\n    def test_value(self):\n        key = \"img\"\n        for p in TEST_NDARRAYS:\n            shifter = ShiftIntensityd(keys=[key], offset=1.0)\n            result = shifter({key: p(self.imt)})\n            expected = self.imt + 1.0\n            assert_allclose(result[key], p(expected), type_test=\"tensor\")\n\n    def test_factor(self):\n        key = \"img\"\n        stats = IntensityStatsd(keys=key, ops=\"max\", key_prefix=\"orig\")\n        shifter = ShiftIntensityd(keys=[key], offset=1.0, factor_key=[\"orig_max\"])\n        data = {key: self.imt, PostFix.meta(key): {\"affine\": None}}\n\n        result = shifter(stats(data))\n        expected = self.imt + 1.0 * np.nanmax(self.imt)\n        np.testing.assert_allclose(result[key], expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_signal_continuouswavelet.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\nfrom unittest import skipUnless\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import SignalContinuousWavelet\nfrom monai.utils import optional_import\n\n_, has_pywt = optional_import(\"pywt\")\nTESTS_PATH = Path(__file__).parents[1]\nTEST_SIGNAL = os.path.join(TESTS_PATH, \"testing_data\", \"signal.npy\")\nVALID_CASES = [(\"mexh\", 150, 500)]\nEXPECTED_RESULTS = [(6, 150, 2000)]\n\n\n@skipUnless(has_pywt, \"pywt required\")\nclass TestSignalContinousWavelet(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, type, length, frequency):\n        self.assertIsInstance(SignalContinuousWavelet(type, length, frequency), SignalContinuousWavelet)\n        sig = np.load(TEST_SIGNAL)\n        cwt = SignalContinuousWavelet(type, length, frequency)\n        cwtsignal = cwt(sig)\n        self.assertEqual(cwtsignal.shape, EXPECTED_RESULTS[0])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_signal_fillempty.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms import SignalFillEmpty\nfrom monai.utils.type_conversion import convert_to_tensor\n\nTESTS_PATH = Path(__file__).parents[1]\nTEST_SIGNAL = os.path.join(TESTS_PATH, \"testing_data\", \"signal.npy\")\n\n\nclass TestSignalFillEmptyNumpy(unittest.TestCase):\n    def test_correct_parameters_multi_channels(self):\n        self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty)\n        sig = np.load(TEST_SIGNAL)\n        sig[:, 123] = np.nan\n        fillempty = SignalFillEmpty(replacement=0.0)\n        fillemptysignal = fillempty(sig)\n        self.assertTrue(not np.isnan(fillemptysignal).any())\n\n\nclass TestSignalFillEmptyTorch(unittest.TestCase):\n    def test_correct_parameters_multi_channels(self):\n        self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty)\n        sig = convert_to_tensor(np.load(TEST_SIGNAL))\n        sig[:, 123] = convert_to_tensor(np.nan)\n        fillempty = SignalFillEmpty(replacement=0.0)\n        fillemptysignal = fillempty(sig)\n        self.assertTrue(not torch.isnan(fillemptysignal).any())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_signal_fillemptyd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms import SignalFillEmptyd\nfrom monai.utils.type_conversion import convert_to_tensor\n\nTESTS_PATH = Path(__file__).parents[1]\nTEST_SIGNAL = os.path.join(TESTS_PATH, \"testing_data\", \"signal.npy\")\n\n\nclass TestSignalFillEmptyNumpy(unittest.TestCase):\n    def test_correct_parameters_multi_channels(self):\n        self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd)\n        sig = np.load(TEST_SIGNAL)\n        sig[:, 123] = np.nan\n        data = {}\n        data[\"signal\"] = sig\n        fillempty = SignalFillEmptyd(keys=(\"signal\",), replacement=0.0)\n        data_ = fillempty(data)\n\n        self.assertTrue(np.isnan(sig).any())\n        self.assertTrue(not np.isnan(data_[\"signal\"]).any())\n\n\nclass TestSignalFillEmptyTorch(unittest.TestCase):\n    def test_correct_parameters_multi_channels(self):\n        self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd)\n        sig = convert_to_tensor(np.load(TEST_SIGNAL))\n        sig[:, 123] = convert_to_tensor(np.nan)\n        data = {}\n        data[\"signal\"] = sig\n        fillempty = SignalFillEmptyd(keys=(\"signal\",), replacement=0.0)\n        data_ = fillempty(data)\n\n        self.assertTrue(np.isnan(sig).any())\n     
   self.assertTrue(not torch.isnan(data_[\"signal\"]).any())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_signal_rand_add_gaussiannoise.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import SignalRandAddGaussianNoise\nfrom monai.utils.type_conversion import convert_to_tensor\n\nTESTS_PATH = Path(__file__).parents[1]\nTEST_SIGNAL = os.path.join(TESTS_PATH, \"testing_data\", \"signal.npy\")\nVALID_CASES = [([0.0, 0.02],)]\n\n\nclass TestSignalRandAddGaussianNoiseNumpy(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries):\n        self.assertIsInstance(SignalRandAddGaussianNoise(boundaries), SignalRandAddGaussianNoise)\n        sig = np.load(TEST_SIGNAL)\n        gaussian = SignalRandAddGaussianNoise(boundaries)\n        gaussiansignal = gaussian(sig)\n        self.assertEqual(gaussiansignal.shape[1], sig.shape[1])\n\n\nclass TestSignalRandAddGaussianNoiseTorch(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries):\n        self.assertIsInstance(SignalRandAddGaussianNoise(boundaries), SignalRandAddGaussianNoise)\n        sig = convert_to_tensor(np.load(TEST_SIGNAL))\n        gaussian = SignalRandAddGaussianNoise(boundaries)\n        gaussiansignal = gaussian(sig)\n        self.assertEqual(gaussiansignal.shape[1], sig.shape[1])\n\n\nif 
__name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_signal_rand_add_sine.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import SignalRandAddSine\nfrom monai.utils.type_conversion import convert_to_tensor\n\nTESTS_PATH = Path(__file__).parents[1]\nTEST_SIGNAL = os.path.join(TESTS_PATH, \"testing_data\", \"signal.npy\")\nVALID_CASES = [([0.0, 1.0], [0.0, 0.5]), ([0.0, 1.0], [0.01, 0.1])]\n\n\nclass TestSignalRandAddSineNumpy(unittest.TestCase):\n\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries, freqs):\n        self.assertIsInstance(SignalRandAddSine(boundaries, freqs), SignalRandAddSine)\n        sig = np.load(TEST_SIGNAL)\n        sine = SignalRandAddSine(boundaries, freqs)\n        sinesignal = sine(sig)\n        self.assertEqual(sinesignal.shape[1], sig.shape[1])\n\n\nclass TestSignalRandAddSineTorch(unittest.TestCase):\n\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries, freqs):\n        self.assertIsInstance(SignalRandAddSine(boundaries, freqs), SignalRandAddSine)\n        sig = convert_to_tensor(np.load(TEST_SIGNAL))\n        sine = SignalRandAddSine(boundaries, freqs)\n        sinesignal = sine(sig)\n        self.assertEqual(sinesignal.shape[1], sig.shape[1])\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_signal_rand_add_sine_partial.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import SignalRandAddSinePartial\nfrom monai.utils.type_conversion import convert_to_tensor\n\nTESTS_PATH = Path(__file__).parents[1]\nTEST_SIGNAL = os.path.join(TESTS_PATH, \"testing_data\", \"signal.npy\")\nVALID_CASES = [([0.0, 1.0], [0.1, 0.6], [0.0, 0.4])]\n\n\nclass TestSignalRandAddSinePartialNumpy(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction):\n        self.assertIsInstance(SignalRandAddSinePartial(boundaries, frequencies, fraction), SignalRandAddSinePartial)\n        sig = np.load(TEST_SIGNAL)\n        partialsine = SignalRandAddSinePartial(boundaries, frequencies, fraction)\n        partialsinesignal = partialsine(sig)\n        self.assertEqual(partialsinesignal.shape[1], sig.shape[1])\n\n\nclass TestSignalRandAddSinePartialTorch(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction):\n        self.assertIsInstance(SignalRandAddSinePartial(boundaries, frequencies, fraction), SignalRandAddSinePartial)\n        sig = convert_to_tensor(np.load(TEST_SIGNAL))\n        partialsine = 
SignalRandAddSinePartial(boundaries, frequencies, fraction)\n        partialsinesignal = partialsine(sig)\n        self.assertEqual(partialsinesignal.shape[1], sig.shape[1])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_signal_rand_add_squarepulse.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\nfrom unittest import skipUnless\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import SignalRandAddSquarePulse\nfrom monai.utils import optional_import\nfrom monai.utils.type_conversion import convert_to_tensor\n\n_, has_scipy = optional_import(\"scipy\")\nTESTS_PATH = Path(__file__).parents[1]\nTEST_SIGNAL = os.path.join(TESTS_PATH, \"testing_data\", \"signal.npy\")\nVALID_CASES = [([0.0, 1.0], [0.001, 0.2])]\n\n\n@skipUnless(has_scipy, \"scipy required\")\nclass TestSignalRandAddSquarePulseNumpy(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries, frequencies):\n        self.assertIsInstance(SignalRandAddSquarePulse(boundaries, frequencies), SignalRandAddSquarePulse)\n        sig = np.load(TEST_SIGNAL)\n        squared = SignalRandAddSquarePulse(boundaries, frequencies)\n        squaredsignal = squared(sig)\n        self.assertEqual(squaredsignal.shape[1], sig.shape[1])\n\n\n@skipUnless(has_scipy, \"scipy required\")\nclass TestSignalRandAddSquarePulseTorch(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries, frequencies):\n        self.assertIsInstance(SignalRandAddSquarePulse(boundaries, 
frequencies), SignalRandAddSquarePulse)\n        sig = convert_to_tensor(np.load(TEST_SIGNAL))\n        squared = SignalRandAddSquarePulse(boundaries, frequencies)\n        squaredsignal = squared(sig)\n        self.assertEqual(squaredsignal.shape[1], sig.shape[1])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_signal_rand_add_squarepulse_partial.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\nfrom unittest import skipUnless\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import SignalRandAddSquarePulsePartial\nfrom monai.utils import optional_import\nfrom monai.utils.type_conversion import convert_to_tensor\n\n_, has_scipy = optional_import(\"scipy\")\nTESTS_PATH = Path(__file__).parents[1]\nTEST_SIGNAL = os.path.join(TESTS_PATH, \"testing_data\", \"signal.npy\")\nVALID_CASES = [([0.0, 1.0], [0.001, 0.2], [0.0, 0.4])]\n\n\n@skipUnless(has_scipy, \"scipy required\")\nclass TestSignalRandAddSquarePulsePartialNumpy(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction):\n        self.assertIsInstance(\n            SignalRandAddSquarePulsePartial(boundaries, frequencies, fraction), SignalRandAddSquarePulsePartial\n        )\n        sig = np.load(TEST_SIGNAL)\n        partialsquare = SignalRandAddSquarePulsePartial(boundaries, frequencies, fraction)\n        partialsquaresignal = partialsquare(sig)\n        self.assertEqual(partialsquaresignal.shape[1], sig.shape[1])\n\n\n@skipUnless(has_scipy, \"scipy required\")\nclass TestSignalRandAddSquarePulsePartialTorch(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def 
test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction):\n        self.assertIsInstance(\n            SignalRandAddSquarePulsePartial(boundaries, frequencies, fraction), SignalRandAddSquarePulsePartial\n        )\n        sig = convert_to_tensor(np.load(TEST_SIGNAL))\n        partialsquare = SignalRandAddSquarePulsePartial(boundaries, frequencies, fraction)\n        partialsquaresignal = partialsquare(sig)\n        self.assertEqual(partialsquaresignal.shape[1], sig.shape[1])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_signal_rand_drop.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import SignalRandDrop\nfrom monai.utils.type_conversion import convert_to_tensor\n\nTESTS_PATH = Path(__file__).parents[1]\nTEST_SIGNAL = os.path.join(TESTS_PATH, \"testing_data\", \"signal.npy\")\nVALID_CASES = [([0.0, 1.0],), ([0.01, 0.1],)]\n\n\nclass TestSignalRandDropNumpy(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries):\n        self.assertIsInstance(SignalRandDrop(boundaries), SignalRandDrop)\n        sig = np.load(TEST_SIGNAL)\n        droped = SignalRandDrop(boundaries)\n        dropedsignal = droped(sig)\n        self.assertEqual(dropedsignal.shape[1], sig.shape[1])\n\n\nclass TestSignalRandDropTorch(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries):\n        self.assertIsInstance(SignalRandDrop(boundaries), SignalRandDrop)\n        sig = convert_to_tensor(np.load(TEST_SIGNAL))\n        droped = SignalRandDrop(boundaries)\n        dropedsignal = droped(sig)\n        self.assertEqual(dropedsignal.shape[1], sig.shape[1])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_signal_rand_scale.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import SignalRandScale\nfrom monai.utils.type_conversion import convert_to_tensor\n\nTESTS_PATH = Path(__file__).parents[1]\nTEST_SIGNAL = os.path.join(TESTS_PATH, \"testing_data\", \"signal.npy\")\nVALID_CASES = [([-1.0, 1.0],), ([0.01, 0.1],)]\n\n\nclass TestSignalRandScaleNumpy(unittest.TestCase):\n\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries):\n        self.assertIsInstance(SignalRandScale(boundaries), SignalRandScale)\n        sig = np.load(TEST_SIGNAL)\n        scaled = SignalRandScale(boundaries)\n        scaledsignal = scaled(sig)\n        self.assertEqual(scaledsignal.shape[1], sig.shape[1])\n\n\nclass TestSignalRandScaleTorch(unittest.TestCase):\n\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, boundaries):\n        self.assertIsInstance(SignalRandScale(boundaries), SignalRandScale)\n        sig = convert_to_tensor(np.load(TEST_SIGNAL))\n        scaled = SignalRandScale(boundaries)\n        scaledsignal = scaled(sig)\n        self.assertEqual(scaledsignal.shape[1], sig.shape[1])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_signal_rand_shift.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\nfrom unittest import skipUnless\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms.signal.array import SignalRandShift\nfrom monai.utils import optional_import\nfrom monai.utils.type_conversion import convert_to_tensor\n\n_, has_scipy = optional_import(\"scipy\")\nTESTS_PATH = Path(__file__).parents[1]\nTEST_SIGNAL = os.path.join(TESTS_PATH, \"testing_data\", \"signal.npy\")\nVALID_CASES = [(\"wrap\", 0.0, [-1.0, 1.0])]\n\n\n@skipUnless(has_scipy, \"scipy required\")\nclass TestSignalRandShiftNumpy(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, mode, filling, boundaries):\n        self.assertIsInstance(SignalRandShift(mode, filling, boundaries), SignalRandShift)\n        sig = np.load(TEST_SIGNAL)\n        shifted = SignalRandShift(mode, filling, boundaries)\n        shiftedsignal = shifted(sig)\n        self.assertEqual(shiftedsignal.shape[1], sig.shape[1])\n\n\n@skipUnless(has_scipy, \"scipy required\")\nclass TestSignalRandShiftTorch(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, mode, filling, boundaries):\n        self.assertIsInstance(SignalRandShift(mode, filling, boundaries), SignalRandShift)\n        sig 
= convert_to_tensor(np.load(TEST_SIGNAL))\n        shifted = SignalRandShift(mode, filling, boundaries)\n        shiftedsignal = shifted(sig)\n        self.assertEqual(shiftedsignal.shape[1], sig.shape[1])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_signal_remove_frequency.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import SignalRemoveFrequency\nfrom monai.utils import optional_import\nfrom monai.utils.type_conversion import convert_to_tensor\n\n_, has_scipy = optional_import(\"scipy\")\n_, has_torchaudio = optional_import(\"torchaudio\")\nTESTS_PATH = Path(__file__).parents[1]\nTEST_SIGNAL = os.path.join(TESTS_PATH, \"testing_data\", \"signal.npy\")\nVALID_CASES = [(60, 1, 500)]\n\n\n@skipUnless(has_scipy and has_torchaudio, \"scipy and torchaudio are required\")\nclass TestSignalRemoveFrequencyNumpy(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, frequency, quality_factor, sampling_freq):\n        self.assertIsInstance(SignalRemoveFrequency(frequency, quality_factor, sampling_freq), SignalRemoveFrequency)\n        sig = np.load(TEST_SIGNAL)\n        t = sig.shape[1] / sampling_freq\n        composite_sig = sig + np.sin(2 * np.pi * frequency * t)\n        freqremove = SignalRemoveFrequency(frequency, quality_factor, sampling_freq)\n        freqremovesignal = freqremove(composite_sig)\n        y = np.fft.fft(composite_sig) / composite_sig.shape[1]\n        y = y[: composite_sig.shape[1] // 2]\n    
    y2 = np.fft.fft(freqremovesignal) / freqremovesignal.shape[1]\n        y2 = y2[: freqremovesignal.shape[1] // 2]\n        self.assertEqual(composite_sig.shape[1], sig.shape[1])\n        self.assertAlmostEqual(y.all(), y2.all())\n\n\n@skipUnless(has_scipy and has_torchaudio, \"scipy and torchaudio are required\")\nclass TestSignalRemoveFrequencyTorch(unittest.TestCase):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_parameters_multi_channels(self, frequency, quality_factor, sampling_freq):\n        self.assertIsInstance(SignalRemoveFrequency(frequency, quality_factor, sampling_freq), SignalRemoveFrequency)\n        sig = convert_to_tensor(np.load(TEST_SIGNAL))\n        t = sig.shape[1] / sampling_freq\n        composite_sig = convert_to_tensor(sig + np.sin(2 * np.pi * frequency * t))\n        freqremove = SignalRemoveFrequency(frequency, quality_factor, sampling_freq)\n        freqremovesignal = freqremove(composite_sig)\n        y = torch.fft.fft(composite_sig) / composite_sig.shape[1]\n        y = y[: composite_sig.shape[1] // 2]\n        y2 = torch.fft.fft(freqremovesignal) / freqremovesignal.shape[1]\n        y2 = y2[: freqremovesignal.shape[1] // 2]\n        self.assertEqual(composite_sig.shape[1], sig.shape[1])\n        self.assertAlmostEqual(y.all(), y2.all())\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_smooth_field.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom itertools import product\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.utils import meshgrid_xy\nfrom monai.transforms import RandSmoothDeformd, RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose, is_tf32_env\n\n_rtol = 5e-3 if is_tf32_env() else 1e-4\n\nx, y = meshgrid_xy(torch.linspace(-1, 2, 11), torch.linspace(-2.1, 1.2, 8))\npattern2d = x.pow(2).add_(y.pow(2)).sqrt_()\n\nx, y, z = meshgrid_xy(torch.linspace(-1, 2, 11), torch.linspace(-2.1, 1.2, 8), torch.linspace(-0.1, 10.2, 6))\npattern3d = x.pow(2).add_(y.pow(2)).add_(z.pow(2)).sqrt_()\n\nINPUT_SHAPES = ((1, 8, 8), (1, 12, 7), (2, 8, 8), (2, 13, 8), (1, 8, 8, 8), (3, 7, 4, 5))\n\nTESTS_CONTRAST = []\nTESTS_INTENSITY = []\nTESTS_DEFORM = []\n\nKEY = \"test\"\n\nfor arr_type, shape in product(TEST_NDARRAYS, INPUT_SHAPES):\n    in_arr = arr_type(np.ones(shape, np.float32))\n    exp_arr = arr_type(np.ones(shape, np.float32))\n    rand_size = (4,) * (len(shape) - 1)\n\n    device = torch.device(\"cpu\")\n\n    if isinstance(in_arr, torch.Tensor) and in_arr.get_device() >= 0:\n        device = torch.device(in_arr.get_device())\n\n    TESTS_CONTRAST.append(\n        (\n            {\"keys\": (KEY,), \"spatial_size\": shape[1:], 
\"rand_size\": rand_size, \"prob\": 1.0, \"device\": device},\n            {KEY: in_arr},\n            {KEY: exp_arr},\n        )\n    )\n\n    TESTS_INTENSITY.append(\n        (\n            {\n                \"keys\": (KEY,),\n                \"spatial_size\": shape[1:],\n                \"rand_size\": rand_size,\n                \"prob\": 1.0,\n                \"device\": device,\n                \"gamma\": (0.9, 1),\n            },\n            {KEY: in_arr},\n            {KEY: exp_arr},\n        )\n    )\n\n    TESTS_DEFORM.append(\n        (\n            {\n                \"keys\": (KEY,),\n                \"spatial_size\": shape[1:],\n                \"rand_size\": rand_size,\n                \"prob\": 1.0,\n                \"device\": device,\n                \"def_range\": 0.1,\n            },\n            {KEY: in_arr},\n            {KEY: exp_arr},\n        )\n    )\n\n\nclass TestSmoothField(unittest.TestCase):\n    @parameterized.expand(TESTS_CONTRAST)\n    def test_rand_smooth_field_adjust_contrastd(self, input_param, input_data, expected_val):\n        g = RandSmoothFieldAdjustContrastd(**input_param)\n        g.set_random_state(123)\n\n        res = g(input_data)\n        for key, result in res.items():\n            expected = expected_val[key]\n            assert_allclose(result, expected, rtol=_rtol, atol=1e-1, type_test=\"tensor\")\n\n    def test_rand_smooth_field_adjust_contrastd_pad(self):\n        input_param, input_data, expected_val = TESTS_CONTRAST[0]\n\n        g = RandSmoothFieldAdjustContrastd(pad=1, **input_param)\n        g.set_random_state(123)\n\n        res = g(input_data)\n        for key, result in res.items():\n            expected = expected_val[key]\n            assert_allclose(result, expected, rtol=_rtol, atol=1e-1, type_test=\"tensor\")\n\n    @parameterized.expand(TESTS_INTENSITY)\n    def test_rand_smooth_field_adjust_intensityd(self, input_param, input_data, expected_val):\n        g = 
RandSmoothFieldAdjustIntensityd(**input_param)\n        g.set_random_state(123)\n\n        res = g(input_data)\n        for key, result in res.items():\n            expected = expected_val[key]\n            assert_allclose(result, expected, rtol=_rtol, atol=1e-1, type_test=\"tensor\")\n\n    def test_rand_smooth_field_adjust_intensityd_pad(self):\n        input_param, input_data, expected_val = TESTS_INTENSITY[0]\n\n        g = RandSmoothFieldAdjustIntensityd(pad=1, **input_param)\n        g.set_random_state(123)\n\n        res = g(input_data)\n        for key, result in res.items():\n            expected = expected_val[key]\n            assert_allclose(result, expected, rtol=_rtol, atol=1e-1, type_test=\"tensor\")\n\n    @parameterized.expand(TESTS_DEFORM)\n    def test_rand_smooth_deformd(self, input_param, input_data, expected_val):\n        g = RandSmoothDeformd(**input_param)\n        g.set_random_state(123)\n\n        res = g(input_data)\n        for key, result in res.items():\n            expected = expected_val[key]\n            assert_allclose(result, expected, rtol=_rtol, atol=1e-1, type_test=\"tensor\")\n\n    def test_rand_smooth_nodeformd(self):\n        \"\"\"Test input is very close to output when deformation is very low, verifies there's no transposition.\"\"\"\n\n        for label, im in zip((\"2D\", \"3D\"), (pattern2d, pattern3d)):\n            with self.subTest(f\"Testing {label} case with shape {im.shape}\"):\n                rsize = (3,) * len(im.shape)\n                g = RandSmoothDeformd(\n                    keys=(KEY,), spatial_size=im.shape, rand_size=rsize, prob=1.0, device=device, def_range=1e-20\n                )\n                g.set_random_state(123)\n\n                expected_val = {KEY: im[None]}\n\n                res = g(expected_val)\n                for key, result in res.items():\n                    expected = expected_val[key]\n\n                    self.assertSequenceEqual(tuple(result.shape), 
tuple(expected.shape))\n\n                    assert_allclose(result, expected, rtol=_rtol, atol=1e-1, type_test=\"tensor\")\n\n    def test_rand_smooth_deformd_pad(self):\n        input_param, input_data, expected_val = TESTS_DEFORM[0]\n\n        g = RandSmoothDeformd(pad=1, **input_param)\n        g.set_random_state(123)\n\n        res = g(input_data)\n        for key, result in res.items():\n            expected = expected_val[key]\n            assert_allclose(result, expected, rtol=_rtol, atol=1e-1, type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_sobel_gradient.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import SobelGradients\nfrom tests.test_utils import assert_allclose\n\nIMAGE = torch.zeros(1, 16, 16, dtype=torch.float32)\nIMAGE[0, 8, :] = 1\n\n# Output with reflect padding\nOUTPUT_3x3 = torch.zeros(2, 16, 16, dtype=torch.float32)\nOUTPUT_3x3[0, 7, :] = 0.5\nOUTPUT_3x3[0, 9, :] = -0.5\n\n# Output with zero padding\nOUTPUT_3x3_ZERO_PAD = OUTPUT_3x3.clone()\nOUTPUT_3x3_ZERO_PAD[1, 7, 0] = OUTPUT_3x3_ZERO_PAD[1, 9, 0] = 0.125\nOUTPUT_3x3_ZERO_PAD[1, 8, 0] = 0.25\nOUTPUT_3x3_ZERO_PAD[1, 7, -1] = OUTPUT_3x3_ZERO_PAD[1, 9, -1] = -0.125\nOUTPUT_3x3_ZERO_PAD[1, 8, -1] = -0.25\nOUTPUT_3x3_ZERO_PAD[0, 7, 0] = OUTPUT_3x3_ZERO_PAD[0, 7, -1] = 3.0 / 8.0\nOUTPUT_3x3_ZERO_PAD[0, 9, 0] = OUTPUT_3x3_ZERO_PAD[0, 9, -1] = -3.0 / 8.0\n\nTEST_CASE_0 = [IMAGE, {\"kernel_size\": 3, \"dtype\": torch.float32}, OUTPUT_3x3]\nTEST_CASE_1 = [IMAGE, {\"kernel_size\": 3, \"dtype\": torch.float64}, OUTPUT_3x3]\nTEST_CASE_2 = [IMAGE, {\"kernel_size\": 3, \"spatial_axes\": 0, \"dtype\": torch.float64}, OUTPUT_3x3[0:1]]\nTEST_CASE_3 = [IMAGE, {\"kernel_size\": 3, \"spatial_axes\": 1, \"dtype\": torch.float64}, OUTPUT_3x3[1:2]]\nTEST_CASE_4 = [IMAGE, {\"kernel_size\": 3, \"spatial_axes\": [1], \"dtype\": torch.float64}, OUTPUT_3x3[1:2]]\nTEST_CASE_5 = [\n    IMAGE,\n    
{\"kernel_size\": 3, \"spatial_axes\": [0, 1], \"normalize_kernels\": True, \"dtype\": torch.float64},\n    OUTPUT_3x3,\n]\nTEST_CASE_6 = [\n    IMAGE,\n    {\"kernel_size\": 3, \"spatial_axes\": (0, 1), \"padding_mode\": \"reflect\", \"dtype\": torch.float64},\n    OUTPUT_3x3,\n]\nTEST_CASE_7 = [\n    IMAGE,\n    {\"kernel_size\": 3, \"spatial_axes\": (0, 1), \"padding_mode\": \"zeros\", \"dtype\": torch.float64},\n    OUTPUT_3x3_ZERO_PAD,\n]\nTEST_CASE_8 = [  # Non-normalized kernels\n    IMAGE,\n    {\"kernel_size\": 3, \"normalize_kernels\": False, \"dtype\": torch.float32},\n    OUTPUT_3x3 * 8.0,\n]\nTEST_CASE_9 = [  # Normalized gradients and normalized kernels\n    IMAGE,\n    {\n        \"kernel_size\": 3,\n        \"normalize_kernels\": True,\n        \"normalize_gradients\": True,\n        \"spatial_axes\": (0, 1),\n        \"dtype\": torch.float64,\n    },\n    torch.cat([OUTPUT_3x3[0:1] + 0.5, OUTPUT_3x3[1:2]]),\n]\nTEST_CASE_10 = [  # Normalized gradients but non-normalized kernels\n    IMAGE,\n    {\n        \"kernel_size\": 3,\n        \"normalize_kernels\": False,\n        \"normalize_gradients\": True,\n        \"spatial_axes\": (0, 1),\n        \"dtype\": torch.float64,\n    },\n    torch.cat([OUTPUT_3x3[0:1] + 0.5, OUTPUT_3x3[1:2]]),\n]\n\nTEST_CASE_KERNEL_0 = [\n    {\"kernel_size\": 3, \"dtype\": torch.float64},\n    (torch.tensor([-0.5, 0.0, 0.5], dtype=torch.float64), torch.tensor([0.25, 0.5, 0.25], dtype=torch.float64)),\n]\nTEST_CASE_KERNEL_1 = [\n    {\"kernel_size\": 5, \"dtype\": torch.float64},\n    (\n        torch.tensor([-0.1250, -0.2500, 0.0000, 0.2500, 0.1250], dtype=torch.float64),\n        torch.tensor([0.0625, 0.2500, 0.3750, 0.2500, 0.0625], dtype=torch.float64),\n    ),\n]\nTEST_CASE_KERNEL_2 = [\n    {\"kernel_size\": 7, \"dtype\": torch.float64},\n    (\n        torch.tensor([-0.03125, -0.125, -0.15625, 0.0, 0.15625, 0.125, 0.03125], dtype=torch.float64),\n        torch.tensor([0.015625, 0.09375, 0.234375, 0.3125, 0.234375, 
0.09375, 0.015625], dtype=torch.float64),\n    ),\n]\nTEST_CASE_KERNEL_NON_NORMALIZED_0 = [\n    {\"kernel_size\": 3, \"normalize_kernels\": False, \"dtype\": torch.float64},\n    (torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float64), torch.tensor([1.0, 2.0, 1.0], dtype=torch.float64)),\n]\nTEST_CASE_KERNEL_NON_NORMALIZED_1 = [\n    {\"kernel_size\": 5, \"normalize_kernels\": False, \"dtype\": torch.float64},\n    (\n        torch.tensor([-1.0, -2.0, 0.0, 2.0, 1.0], dtype=torch.float64),\n        torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0], dtype=torch.float64),\n    ),\n]\nTEST_CASE_KERNEL_NON_NORMALIZED_2 = [\n    {\"kernel_size\": 7, \"normalize_kernels\": False, \"dtype\": torch.float64},\n    (\n        torch.tensor([-1.0, -4.0, -5.0, 0.0, 5.0, 4.0, 1.0], dtype=torch.float64),\n        torch.tensor([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0], dtype=torch.float64),\n    ),\n]\n\nTEST_CASE_ERROR_0 = [IMAGE, {\"kernel_size\": 1}]  # kernel size less than 3\nTEST_CASE_ERROR_1 = [IMAGE, {\"kernel_size\": 4}]  # even kernel size\nTEST_CASE_ERROR_2 = [IMAGE, {\"spatial_axes\": \"horizontal\"}]  # wrong type direction\nTEST_CASE_ERROR_3 = [IMAGE, {\"spatial_axes\": 3}]  # wrong direction\nTEST_CASE_ERROR_4 = [IMAGE, {\"spatial_axes\": [3]}]  # wrong direction in a list\nTEST_CASE_ERROR_5 = [IMAGE, {\"spatial_axes\": [0, 4]}]  # correct and wrong direction in a list\n\n\nclass SobelGradientTests(unittest.TestCase):\n    backend = None\n\n    @parameterized.expand(\n        [\n            TEST_CASE_0,\n            TEST_CASE_1,\n            TEST_CASE_2,\n            TEST_CASE_3,\n            TEST_CASE_4,\n            TEST_CASE_5,\n            TEST_CASE_6,\n            TEST_CASE_7,\n            TEST_CASE_8,\n            TEST_CASE_9,\n            TEST_CASE_10,\n        ]\n    )\n    def test_sobel_gradients(self, image, arguments, expected_grad):\n        sobel = SobelGradients(**arguments)\n        grad = sobel(image)\n        assert_allclose(grad, expected_grad)\n\n    
@parameterized.expand(\n        [\n            TEST_CASE_KERNEL_0,\n            TEST_CASE_KERNEL_1,\n            TEST_CASE_KERNEL_2,\n            TEST_CASE_KERNEL_NON_NORMALIZED_0,\n            TEST_CASE_KERNEL_NON_NORMALIZED_1,\n            TEST_CASE_KERNEL_NON_NORMALIZED_2,\n        ]\n    )\n    def test_sobel_kernels(self, arguments, expected_kernels):\n        sobel = SobelGradients(**arguments)\n        self.assertEqual(sobel.kernel_diff.dtype, expected_kernels[0].dtype)\n        self.assertEqual(sobel.kernel_smooth.dtype, expected_kernels[0].dtype)\n        assert_allclose(sobel.kernel_diff, expected_kernels[0])\n        assert_allclose(sobel.kernel_smooth, expected_kernels[1])\n\n    @parameterized.expand(\n        [\n            TEST_CASE_ERROR_0,\n            TEST_CASE_ERROR_1,\n            TEST_CASE_ERROR_2,\n            TEST_CASE_ERROR_3,\n            TEST_CASE_ERROR_4,\n            TEST_CASE_ERROR_5,\n        ]\n    )\n    def test_sobel_gradients_error(self, image, arguments):\n        with self.assertRaises(ValueError):\n            sobel = SobelGradients(**arguments)\n            sobel(image)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_sobel_gradientd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import SobelGradientsd\nfrom tests.test_utils import assert_allclose\n\nIMAGE = torch.zeros(1, 16, 16, dtype=torch.float32)\nIMAGE[0, 8, :] = 1\n\n# Output with reflect padding\nOUTPUT_3x3 = torch.zeros(2, 16, 16, dtype=torch.float32)\nOUTPUT_3x3[0, 7, :] = 0.5\nOUTPUT_3x3[0, 9, :] = -0.5\n\n# Output with zero padding\nOUTPUT_3x3_ZERO_PAD = OUTPUT_3x3.clone()\nOUTPUT_3x3_ZERO_PAD[1, 7, 0] = OUTPUT_3x3_ZERO_PAD[1, 9, 0] = 0.125\nOUTPUT_3x3_ZERO_PAD[1, 8, 0] = 0.25\nOUTPUT_3x3_ZERO_PAD[1, 7, -1] = OUTPUT_3x3_ZERO_PAD[1, 9, -1] = -0.125\nOUTPUT_3x3_ZERO_PAD[1, 8, -1] = -0.25\nOUTPUT_3x3_ZERO_PAD[0, 7, 0] = OUTPUT_3x3_ZERO_PAD[0, 7, -1] = 3.0 / 8.0\nOUTPUT_3x3_ZERO_PAD[0, 9, 0] = OUTPUT_3x3_ZERO_PAD[0, 9, -1] = -3.0 / 8.0\n\nTEST_CASE_0 = [{\"image\": IMAGE}, {\"keys\": \"image\", \"kernel_size\": 3, \"dtype\": torch.float32}, {\"image\": OUTPUT_3x3}]\nTEST_CASE_1 = [{\"image\": IMAGE}, {\"keys\": \"image\", \"kernel_size\": 3, \"dtype\": torch.float64}, {\"image\": OUTPUT_3x3}]\nTEST_CASE_2 = [\n    {\"image\": IMAGE},\n    {\"keys\": \"image\", \"kernel_size\": 3, \"dtype\": torch.float32, \"new_key_prefix\": \"sobel_\"},\n    {\"sobel_image\": OUTPUT_3x3},\n]\nTEST_CASE_3 = [\n    {\"image\": IMAGE},\n    {\"keys\": \"image\", 
\"kernel_size\": 3, \"spatial_axes\": 0, \"dtype\": torch.float32},\n    {\"image\": OUTPUT_3x3[0][None, ...]},\n]\nTEST_CASE_4 = [\n    {\"image\": IMAGE},\n    {\"keys\": \"image\", \"kernel_size\": 3, \"spatial_axes\": 1, \"dtype\": torch.float32},\n    {\"image\": OUTPUT_3x3[1][None, ...]},\n]\nTEST_CASE_5 = [\n    {\"image\": IMAGE},\n    {\"keys\": \"image\", \"kernel_size\": 3, \"spatial_axes\": [1], \"dtype\": torch.float32},\n    {\"image\": OUTPUT_3x3[1][None, ...]},\n]\nTEST_CASE_6 = [\n    {\"image\": IMAGE},\n    {\"keys\": \"image\", \"kernel_size\": 3, \"spatial_axes\": [0, 1], \"normalize_kernels\": True, \"dtype\": torch.float32},\n    {\"image\": OUTPUT_3x3},\n]\nTEST_CASE_7 = [\n    {\"image\": IMAGE},\n    {\"keys\": \"image\", \"kernel_size\": 3, \"spatial_axes\": (0, 1), \"padding_mode\": \"reflect\", \"dtype\": torch.float32},\n    {\"image\": OUTPUT_3x3},\n]\nTEST_CASE_8 = [\n    {\"image\": IMAGE},\n    {\"keys\": \"image\", \"kernel_size\": 3, \"spatial_axes\": (0, 1), \"padding_mode\": \"zeros\", \"dtype\": torch.float32},\n    {\"image\": OUTPUT_3x3_ZERO_PAD},\n]\nTEST_CASE_9 = [  # Non-normalized kernels\n    {\"image\": IMAGE},\n    {\"keys\": \"image\", \"kernel_size\": 3, \"spatial_axes\": (0, 1), \"normalize_kernels\": False, \"dtype\": torch.float32},\n    {\"image\": OUTPUT_3x3 * 8.0},\n]\nTEST_CASE_10 = [  # Normalized gradients and normalized kernels\n    {\"image\": IMAGE},\n    {\n        \"keys\": \"image\",\n        \"kernel_size\": 3,\n        \"spatial_axes\": (0, 1),\n        \"normalize_kernels\": True,\n        \"normalize_gradients\": True,\n        \"dtype\": torch.float32,\n    },\n    {\"image\": torch.cat([OUTPUT_3x3[0:1] + 0.5, OUTPUT_3x3[1:2]])},\n]\nTEST_CASE_11 = [  # Normalized gradients but non-normalized kernels\n    {\"image\": IMAGE},\n    {\n        \"keys\": \"image\",\n        \"kernel_size\": 3,\n        \"spatial_axes\": (0, 1),\n        \"normalize_kernels\": False,\n        \"normalize_gradients\": 
True,\n        \"dtype\": torch.float32,\n    },\n    {\"image\": torch.cat([OUTPUT_3x3[0:1] + 0.5, OUTPUT_3x3[1:2]])},\n]\n\nTEST_CASE_KERNEL_0 = [\n    {\"keys\": \"image\", \"kernel_size\": 3, \"dtype\": torch.float64},\n    (torch.tensor([-0.5, 0.0, 0.5], dtype=torch.float64), torch.tensor([0.25, 0.5, 0.25], dtype=torch.float64)),\n]\nTEST_CASE_KERNEL_1 = [\n    {\"keys\": \"image\", \"kernel_size\": 5, \"dtype\": torch.float64},\n    (\n        torch.tensor([-0.1250, -0.2500, 0.0000, 0.2500, 0.1250], dtype=torch.float64),\n        torch.tensor([0.0625, 0.2500, 0.3750, 0.2500, 0.0625], dtype=torch.float64),\n    ),\n]\nTEST_CASE_KERNEL_2 = [\n    {\"keys\": \"image\", \"kernel_size\": 7, \"dtype\": torch.float64},\n    (\n        torch.tensor([-0.03125, -0.125, -0.15625, 0.0, 0.15625, 0.125, 0.03125], dtype=torch.float64),\n        torch.tensor([0.015625, 0.09375, 0.234375, 0.3125, 0.234375, 0.09375, 0.015625], dtype=torch.float64),\n    ),\n]\nTEST_CASE_KERNEL_NON_NORMALIZED_0 = [\n    {\"keys\": \"image\", \"kernel_size\": 3, \"normalize_kernels\": False, \"dtype\": torch.float64},\n    (torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float64), torch.tensor([1.0, 2.0, 1.0], dtype=torch.float64)),\n]\nTEST_CASE_KERNEL_NON_NORMALIZED_1 = [\n    {\"keys\": \"image\", \"kernel_size\": 5, \"normalize_kernels\": False, \"dtype\": torch.float64},\n    (\n        torch.tensor([-1.0, -2.0, 0.0, 2.0, 1.0], dtype=torch.float64),\n        torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0], dtype=torch.float64),\n    ),\n]\nTEST_CASE_KERNEL_NON_NORMALIZED_2 = [\n    {\"keys\": \"image\", \"kernel_size\": 7, \"normalize_kernels\": False, \"dtype\": torch.float64},\n    (\n        torch.tensor([-1.0, -4.0, -5.0, 0.0, 5.0, 4.0, 1.0], dtype=torch.float64),\n        torch.tensor([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0], dtype=torch.float64),\n    ),\n]\nTEST_CASE_ERROR_0 = [{\"image\": IMAGE}, {\"keys\": \"image\", \"kernel_size\": 1}]  # kernel size less than 3\nTEST_CASE_ERROR_1 = [{\"image\": 
IMAGE}, {\"keys\": \"image\", \"kernel_size\": 4}]  # even kernel size\nTEST_CASE_ERROR_2 = [{\"image\": IMAGE}, {\"keys\": \"image\", \"spatial_axes\": \"horizontal\"}]  # wrong type direction\nTEST_CASE_ERROR_3 = [{\"image\": IMAGE}, {\"keys\": \"image\", \"spatial_axes\": 3}]  # wrong direction\nTEST_CASE_ERROR_4 = [{\"image\": IMAGE}, {\"keys\": \"image\", \"spatial_axes\": [3]}]  # wrong direction in a list\nTEST_CASE_ERROR_5 = [\n    {\"image\": IMAGE},\n    {\"keys\": \"image\", \"spatial_axes\": [0, 4]},\n]  # correct and wrong direction in a list\n\n\nclass SobelGradientTests(unittest.TestCase):\n    backend = None\n\n    @parameterized.expand(\n        [\n            TEST_CASE_0,\n            TEST_CASE_1,\n            TEST_CASE_2,\n            TEST_CASE_3,\n            TEST_CASE_4,\n            TEST_CASE_5,\n            TEST_CASE_6,\n            TEST_CASE_7,\n            TEST_CASE_8,\n            TEST_CASE_9,\n            TEST_CASE_10,\n            TEST_CASE_11,\n        ]\n    )\n    def test_sobel_gradients(self, image_dict, arguments, expected_grad):\n        sobel = SobelGradientsd(**arguments)\n        grad = sobel(image_dict)\n        key = \"image\" if \"new_key_prefix\" not in arguments else arguments[\"new_key_prefix\"] + arguments[\"keys\"]\n        assert_allclose(grad[key], expected_grad[key])\n\n    @parameterized.expand(\n        [\n            TEST_CASE_KERNEL_0,\n            TEST_CASE_KERNEL_1,\n            TEST_CASE_KERNEL_2,\n            TEST_CASE_KERNEL_NON_NORMALIZED_0,\n            TEST_CASE_KERNEL_NON_NORMALIZED_1,\n            TEST_CASE_KERNEL_NON_NORMALIZED_2,\n        ]\n    )\n    def test_sobel_kernels(self, arguments, expected_kernels):\n        sobel = SobelGradientsd(**arguments)\n        self.assertEqual(sobel.kernel_diff.dtype, expected_kernels[0].dtype)\n        self.assertEqual(sobel.kernel_smooth.dtype, expected_kernels[0].dtype)\n        assert_allclose(sobel.kernel_diff, expected_kernels[0])\n        
assert_allclose(sobel.kernel_smooth, expected_kernels[1])\n\n    @parameterized.expand(\n        [\n            TEST_CASE_ERROR_0,\n            TEST_CASE_ERROR_1,\n            TEST_CASE_ERROR_2,\n            TEST_CASE_ERROR_3,\n            TEST_CASE_ERROR_4,\n            TEST_CASE_ERROR_5,\n        ]\n    )\n    def test_sobel_gradients_error(self, image_dict, arguments):\n        with self.assertRaises(ValueError):\n            sobel = SobelGradientsd(**arguments)\n            sobel(image_dict)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_spacing.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.config import USE_COMPILED\nfrom monai.data.meta_obj import set_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import affine_to_spacing\nfrom monai.transforms import Spacing\nfrom monai.utils import fall_back_tuple\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose, dict_product, skip_if_quick\n\n# Define the static parts of each test case\n_template_5_expected_output = (\n    torch.tensor([[[[0.75, 0.75]], [[0.75, 0.75]], [[0.75, 0.75]]]])\n    if USE_COMPILED\n    else torch.tensor([[[[0.95527864, 0.95527864]], [[1.0, 1.0]], [[1.0, 1.0]]]])\n)\n\nall_template_parts = [\n    [\n        {\"pixdim\": (1.0, 1.5), \"padding_mode\": \"zeros\", \"dtype\": float},\n        torch.arange(4).reshape((1, 2, 2)) + 1.0,\n        torch.eye(4),\n        {},\n        torch.tensor([[[1.0, 1.0], [3.0, 2.0]]]),\n    ],\n    [\n        {\"pixdim\": 1.0, \"padding_mode\": \"zeros\", \"dtype\": float},\n        torch.ones((1, 2, 1, 2)),\n        torch.eye(4),\n        {},\n        torch.tensor([[[[1.0, 1.0]], [[1.0, 1.0]]]]),\n    ],\n    [\n        {\"pixdim\": 2.0, \"padding_mode\": \"zeros\", \"dtype\": float},\n        
torch.arange(4).reshape((1, 2, 2)) + 1.0,\n        torch.eye(4),\n        {},\n        torch.tensor([[[1.0, 0.0], [0.0, 0.0]]]),\n    ],\n    [\n        {\"pixdim\": (1.0, 1.0, 1.0), \"padding_mode\": \"zeros\", \"dtype\": float},\n        torch.ones((1, 2, 1, 2)),\n        torch.eye(4),\n        {},\n        torch.tensor([[[[1.0, 1.0]], [[1.0, 1.0]]]]),\n    ],\n    [\n        {\"pixdim\": (1.0, 0.2, 1.5), \"diagonal\": False, \"padding_mode\": \"zeros\", \"align_corners\": True},\n        torch.ones((1, 2, 1, 2)),\n        torch.tensor([[2, 1, 0, 4], [-1, -3, 0, 5], [0, 0, 2.0, 5], [0, 0, 0, 1]]),\n        {},\n        _template_5_expected_output,\n    ],\n    [\n        {\"pixdim\": (3.0, 1.0), \"padding_mode\": \"zeros\"},\n        torch.arange(24).reshape((2, 3, 4)),\n        torch.as_tensor(np.diag([-3.0, 0.2, 1.5, 1])),\n        {},\n        torch.tensor([[[0, 0], [4, 0], [8, 0]], [[12, 0], [16, 0], [20, 0]]]),\n    ],\n    [\n        {\"pixdim\": (3.0, 1.0), \"padding_mode\": \"zeros\"},\n        torch.arange(24).reshape((2, 3, 4)),\n        torch.eye(4),\n        {},\n        torch.tensor([[[0, 1, 2, 3], [0, 0, 0, 0]], [[12, 13, 14, 15], [0, 0, 0, 0]]]),\n    ],\n    [\n        {\"pixdim\": (1.0, 1.0), \"align_corners\": True},\n        torch.arange(24).reshape((2, 3, 4)),\n        torch.eye(4),\n        {},\n        torch.tensor(\n            [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]\n        ),\n    ],\n    [\n        {\"pixdim\": (4.0, 5.0, 6.0)},\n        torch.arange(24).reshape((1, 2, 3, 4)),\n        torch.tensor([[-4, 0, 0, 4], [0, 5, 0, -5], [0, 0, 6, -6], [0, 0, 0, 1]]),\n        {},\n        torch.arange(24).reshape((1, 2, 3, 4)),\n    ],\n    [\n        {\"pixdim\": (4.0, 5.0, 6.0), \"diagonal\": True},\n        torch.arange(24).reshape((1, 2, 3, 4)),\n        torch.tensor([[-4, 0, 0, 4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]),\n        {},\n        torch.tensor(\n            
[[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]]\n        ),\n    ],\n    [\n        {\"pixdim\": (4.0, 5.0, 6.0), \"padding_mode\": \"border\", \"diagonal\": True},\n        torch.arange(24).reshape((1, 2, 3, 4)),\n        torch.tensor([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]),\n        {},\n        torch.tensor(\n            [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]]\n        ),\n    ],\n    [\n        {\"pixdim\": (4.0, 5.0, 6.0), \"padding_mode\": \"border\", \"diagonal\": True},\n        torch.arange(24).reshape((1, 2, 3, 4)),\n        torch.tensor([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]),\n        {\"mode\": \"nearest\"},\n        torch.tensor(\n            [[[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]]\n        ),\n    ],\n    [\n        {\"pixdim\": (1.9, 4.0), \"padding_mode\": \"zeros\", \"diagonal\": True},\n        torch.arange(24).reshape((1, 4, 6)),\n        torch.tensor([[-4, 0, 0, -4], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]),\n        {\"mode\": \"nearest\"},\n        torch.tensor(\n            [\n                [\n                    [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0],\n                    [18.0, 19.0, 20.0, 20.0, 21.0, 22.0, 23.0],\n                    [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0],\n                    [12.0, 13.0, 14.0, 14.0, 15.0, 16.0, 17.0],\n                    [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0],\n                    [6.0, 7.0, 8.0, 8.0, 9.0, 10.0, 11.0],\n                    [0.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0],\n                ]\n            ]\n        ),\n    ],\n    [\n        {\"pixdim\": (5.0, 3.0), \"padding_mode\": \"border\", \"diagonal\": True, \"dtype\": torch.float32},\n        torch.arange(24).reshape((1, 4, 6)),\n        torch.tensor([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 
1]]),\n        {\"mode\": \"bilinear\"},\n        torch.tensor(\n            [\n                [\n                    [18.0, 18.6, 19.2, 19.8, 20.400002, 21.0, 21.6, 22.2, 22.8],\n                    [10.5, 11.1, 11.700001, 12.299999, 12.900001, 13.5, 14.1, 14.700001, 15.3],\n                    [3.0, 3.6000001, 4.2000003, 4.8, 5.4000006, 6.0, 6.6000004, 7.200001, 7.8],\n                ]\n            ]\n        ),\n    ],\n    [\n        {\"pixdim\": (5.0, 3.0), \"padding_mode\": \"zeros\", \"diagonal\": True, \"dtype\": torch.float32},\n        torch.arange(24).reshape((1, 4, 6)),\n        torch.tensor([[-4, 0, 0, 0], [0, 5, 0, 0], [0, 0, 6, 0], [0, 0, 0, 1]]),\n        {\"mode\": \"bilinear\"},\n        torch.tensor(\n            [\n                [\n                    [18.0000, 18.6000, 19.2000, 19.8000, 20.4000, 21.0000, 21.6000, 22.2000, 22.8000],\n                    [10.5000, 11.1000, 11.7000, 12.3000, 12.9000, 13.5000, 14.1000, 14.7000, 15.3000],\n                    [3.0000, 3.6000, 4.2000, 4.8000, 5.4000, 6.0000, 6.6000, 7.2000, 7.8000],\n                ]\n            ]\n        ),\n    ],\n    [\n        {\"pixdim\": [-1, -1, 0.5], \"padding_mode\": \"zeros\", \"dtype\": float},\n        torch.ones((1, 2, 1, 2)),\n        torch.eye(4),\n        {},\n        torch.tensor([[[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]]]),\n    ],\n    [  # 5D input\n        {\"pixdim\": [-1, -1, 0.5], \"padding_mode\": \"zeros\", \"dtype\": float, \"align_corners\": True},\n        torch.ones((1, 2, 2, 2, 1)),\n        torch.eye(4),\n        {},\n        torch.ones((1, 2, 2, 3, 1)),\n    ],\n    [  # 5D input\n        {\"pixdim\": 0.5, \"padding_mode\": \"constant\", \"mode\": \"nearest\", \"scale_extent\": True},\n        torch.ones((1, 368, 336, 368)),\n        torch.tensor(\n            [\n                [0.41, 0.005, 0.008, -79.7],\n                [-0.0049, 0.592, 0.0664, -57.4],\n                [-0.0073, -0.0972, 0.404, -32.1],\n                [0.0, 0.0, 0.0, 1.0],\n 
           ]\n        ),\n        {},\n        torch.ones((1, 302, 403, 301)),\n    ],\n]\nTESTS: list[list] = [\n    params[\"template\"] + [*params[\"device_val\"]]\n    for params in dict_product(template=all_template_parts, device_val=TEST_DEVICES)\n]\n\nTESTS_TORCH = [\n    [[1.2, 1.3, 0.9], params[\"p\"](torch.zeros((1, 3, 4, 5))), params[\"track_meta\"]]\n    for params in dict_product(track_meta=[False, True], p=TEST_NDARRAYS_ALL)\n]\n\nTEST_INVERSE = [\n    [*params[\"d\"], params[\"recompute\"], params[\"align\"], params[\"scale_extent\"]]\n    for params in dict_product(d=TEST_DEVICES, recompute=[False, True], align=[False, True], scale_extent=[False, True])\n]\n\n\n@skip_if_quick\nclass TestSpacingCase(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_spacing(\n        self,\n        init_param: dict,\n        img: torch.Tensor,\n        affine: torch.Tensor,\n        data_param: dict,\n        expected_output: torch.Tensor,\n        device: torch.device,\n    ):\n        img = MetaTensor(img, affine=affine).to(device)\n        tr = Spacing(**init_param)\n        call_param = data_param.copy()\n        call_param[\"data_array\"] = img\n        res: MetaTensor = tr(**call_param)  # type: ignore\n        self.assertEqual(img.device, res.device)\n\n        test_resampler_lazy(tr, res, init_param=init_param, call_param=call_param)\n\n        assert_allclose(res, expected_output, atol=1e-1, rtol=1e-1)\n        sr = min(len(res.shape) - 1, 3)\n        if isinstance(init_param[\"pixdim\"], float):\n            init_param[\"pixdim\"] = [init_param[\"pixdim\"]] * sr\n        init_pixdim = init_param[\"pixdim\"][:sr]\n        norm = affine_to_spacing(res.affine, sr).cpu().numpy()\n        assert_allclose(fall_back_tuple(init_pixdim, norm), norm, type_test=False)  # type: ignore\n\n    @parameterized.expand(TESTS_TORCH)\n    def test_spacing_torch(self, pixdim, img, track_meta: bool):\n        set_track_meta(track_meta)\n        init_param = 
{\"pixdim\": pixdim}\n        tr = Spacing(**init_param)\n        call_param = {\"data_array\": img}\n        res = tr(**call_param)\n\n        if track_meta:\n            self.assertIsInstance(res, MetaTensor)\n            new_spacing = affine_to_spacing(res.affine, 3)  # type: ignore\n            assert_allclose(new_spacing, pixdim, type_test=False)\n            self.assertNotEqual(img.shape, res.shape)\n            test_resampler_lazy(tr, res, init_param=init_param, call_param=call_param)\n        else:\n            self.assertIsInstance(res, torch.Tensor)\n            self.assertNotIsInstance(res, MetaTensor)\n            self.assertNotEqual(img.shape, res.shape)\n        set_track_meta(True)\n\n    @parameterized.expand(TEST_INVERSE)\n    def test_inverse(self, device, recompute, align, scale_extent):\n        img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)\n        affine = torch.tensor(\n            [[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device=\"cpu\"\n        )\n        meta = {\"fname\": \"somewhere\"}\n        img = MetaTensor(img_t, affine=affine, meta=meta)\n        tr = Spacing(pixdim=[1.1, 1.2, 0.9], recompute_affine=recompute, align_corners=align, scale_extent=scale_extent)\n        # check that image and affine have changed\n        img = tr(img)\n        self.assertNotEqual(img.shape, img_t.shape)\n        l2_norm_affine = ((affine - img.affine) ** 2).sum() ** 0.5\n        self.assertGreater(l2_norm_affine, 5e-2)\n        # check that with inverse, image affine are back to how they were\n        img = tr.inverse(img)\n        self.assertEqual(img.applied_operations, [])\n        self.assertEqual(img.shape, img_t.shape)\n        l2_norm_affine = ((affine - img.affine) ** 2).sum() ** 0.5\n        self.assertLess(l2_norm_affine, 5e-2)\n\n    @parameterized.expand(TEST_INVERSE)\n    def test_inverse_mn_mx(self, device, recompute, align, scale_extent):\n        img_t = torch.rand((1, 10, 9, 
8), dtype=torch.float32, device=device)\n        affine = torch.tensor(\n            [[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device=\"cpu\"\n        )\n        img = MetaTensor(img_t, affine=affine, meta={\"fname\": \"somewhere\"})\n        choices = [(None, None), [1.2, None], [None, 0.7], [0.7, 0.9]]\n        idx = np.random.choice(range(len(choices)), size=1)[0]\n        tr = Spacing(\n            pixdim=[1.1, 1.2, 0.9],\n            recompute_affine=recompute,\n            align_corners=align,\n            scale_extent=scale_extent,\n            min_pixdim=[0.9, None, choices[idx][0]],\n            max_pixdim=[1.1, 1.1, choices[idx][1]],\n        )\n        img_out = tr(img)\n        if isinstance(img_out, MetaTensor):\n            assert_allclose(\n                img_out.pixdim, [1.0, 1.125, 0.888889] if recompute else [1.0, 1.2, 0.9], type_test=False, rtol=1e-4\n            )\n        img_out = tr.inverse(img_out)\n        self.assertEqual(img_out.applied_operations, [])\n        self.assertEqual(img_out.shape, img_t.shape)\n        self.assertLess(((affine - img_out.affine) ** 2).sum() ** 0.5, 5e-2)\n\n    def test_property_no_change(self):\n        affine = torch.tensor(\n            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=torch.float32, device=\"cpu\"\n        )\n        affine[:3] *= 1 - 1e-4  # make sure it's not exactly target but close to the target\n        img = MetaTensor(torch.rand((1, 10, 9, 8), dtype=torch.float32), affine=affine, meta={\"fname\": \"somewhere\"})\n        tr = Spacing(pixdim=[1.0, 1.0, 1.0])\n        tr(img)\n        assert_allclose(tr.pixdim, [1.0, 1.0, 1.0], type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_spacingd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_obj import set_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import affine_to_spacing\nfrom monai.transforms import Spacingd\nfrom monai.utils import ensure_tuple_rep\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_DEVICES, assert_allclose, skip_if_quick\n\nTESTS: list[tuple] = []\nfor device in TEST_DEVICES:\n    TESTS.append(\n        (\n            \"spacing 3d\",\n            {\"image\": MetaTensor(torch.ones((2, 10, 15, 20)), affine=torch.eye(4))},\n            dict(keys=\"image\", pixdim=(1, 2, 1.4)),\n            (2, 10, 8, 15),\n            torch.as_tensor(np.diag([1, 2, 1.4, 1.0])),\n            *device,\n        )\n    )\n    TESTS.append(\n        (\n            \"spacing 2d\",\n            {\"image\": MetaTensor(torch.ones((2, 10, 20)), affine=torch.eye(3))},\n            dict(keys=\"image\", pixdim=(1, 2)),\n            (2, 10, 10),\n            torch.as_tensor(np.diag((1, 2, 1))),\n            *device,\n        )\n    )\n    TESTS.append(\n        (\n            \"spacing 2d no metadata\",\n            {\"image\": MetaTensor(torch.ones((2, 10, 20)))},\n            dict(keys=\"image\", pixdim=(1, 2)),\n            (2, 10, 
10),\n            torch.as_tensor(np.diag((1, 2, 1, 1))),\n            *device,\n        )\n    )\n    TESTS.append(\n        (\n            \"interp all\",\n            {\n                \"image\": MetaTensor(np.arange(20).reshape((2, 1, 10)), affine=torch.eye(4)),\n                \"seg\": MetaTensor(torch.ones((2, 1, 10)), affine=torch.eye(4)),\n            },\n            dict(keys=(\"image\", \"seg\"), mode=\"nearest\", pixdim=(1, 0.2)),\n            (2, 1, 46),\n            torch.as_tensor(np.diag((1, 0.2, 1, 1))),\n            *device,\n        )\n    )\n    TESTS.append(\n        (\n            \"interp sep\",\n            {\n                \"image\": MetaTensor(torch.ones((2, 1, 10)), affine=torch.eye(4)),\n                \"seg\": MetaTensor(torch.ones((2, 1, 10)), affine=torch.eye(4)),\n            },\n            dict(keys=(\"image\", \"seg\"), mode=(\"bilinear\", \"nearest\"), pixdim=(1, 0.2)),\n            (2, 1, 46),\n            torch.as_tensor(np.diag((1, 0.2, 1, 1))),\n            *device,\n        )\n    )\n    TESTS.append(\n        (\n            \"interp sep\",\n            {\n                \"image\": MetaTensor(torch.ones((2, 1, 10)), affine=torch.eye(4)),\n                \"seg1\": MetaTensor(torch.ones((2, 1, 10)), affine=torch.diag(torch.tensor([2, 2, 2, 1]))),\n                \"seg2\": MetaTensor(torch.ones((2, 1, 10)), affine=torch.eye(4)),\n            },\n            dict(keys=(\"image\", \"seg1\", \"seg2\"), mode=(\"bilinear\", \"nearest\", \"nearest\"), pixdim=(1, 1, 1)),\n            (2, 1, 10),\n            torch.as_tensor(np.diag((1, 1, 1, 1))),\n            *device,\n        )\n    )\n\nTESTS_TORCH = []\nfor track_meta in (False, True):\n    for device in TEST_DEVICES:\n        TESTS_TORCH.append([{\"keys\": \"seg\", \"pixdim\": [0.2, 0.3, 1]}, torch.ones(2, 1, 2, 3), track_meta, *device])\n\n\nclass TestSpacingDCase(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_spacingd(self, _, data, kw_args, 
expected_shape, expected_affine, device):\n        data = {k: v.to(device) for k, v in data.items()}\n        tr = Spacingd(**kw_args)\n        call_param = {\"data\": data}\n        res = tr(**call_param)\n        # test lazy\n        if not isinstance(kw_args[\"keys\"], str):  # multiple keys\n            kw_args[\"mode\"] = ensure_tuple_rep(kw_args[\"mode\"], len(kw_args[\"keys\"]))\n            init_param = kw_args.copy()\n            for key, mode in zip(kw_args[\"keys\"], kw_args[\"mode\"]):\n                init_param[\"keys\"], init_param[\"mode\"] = key, mode\n                test_resampler_lazy(tr, res, init_param, call_param, output_key=key)\n        else:\n            test_resampler_lazy(tr, res, kw_args, call_param, output_key=kw_args[\"keys\"])\n        in_img = data[\"image\"]\n        out_img = res[\"image\"]\n        self.assertEqual(in_img.device, out_img.device)\n        # no change in number of keys\n        self.assertEqual(tuple(sorted(data)), tuple(sorted(res)))\n        np.testing.assert_allclose(out_img.shape, expected_shape)\n        assert_allclose(out_img.affine, expected_affine)\n\n    @parameterized.expand(TESTS_TORCH)\n    def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, device):\n        set_track_meta(track_meta)\n        tr = Spacingd(**init_param)\n        call_param = {\"data\": {\"seg\": img.to(device)}}\n        res_data = tr(**call_param)  # type: ignore\n        res = res_data[\"seg\"]\n\n        if track_meta:\n            test_resampler_lazy(tr, res_data, init_param, call_param, output_key=\"seg\")\n            self.assertIsInstance(res, MetaTensor)\n            assert isinstance(res, MetaTensor)  # for mypy type narrowing\n            new_spacing = affine_to_spacing(res.affine, 3)\n            assert_allclose(new_spacing, init_param[\"pixdim\"], type_test=False)\n            self.assertNotEqual(img.shape, res.shape)\n        else:\n            self.assertIsInstance(res, torch.Tensor)\n            
self.assertNotIsInstance(res, MetaTensor)\n            self.assertNotEqual(img.shape, res.shape)\n\n    @skip_if_quick\n    def test_space_same_shape(self):\n        affine_1 = np.array(\n            [\n                [1.499277e00, 2.699563e-02, 3.805804e-02, -1.948635e02],\n                [-2.685805e-02, 1.499757e00, -2.635604e-12, 4.438188e01],\n                [-3.805194e-02, -5.999028e-04, 1.499517e00, 4.036536e01],\n                [0.000000e00, 0.000000e00, 0.000000e00, 1.000000e00],\n            ]\n        )\n        affine_2 = np.array(\n            [\n                [1.499275e00, 2.692252e-02, 3.805728e-02, -1.948635e02],\n                [-2.693010e-02, 1.499758e00, -4.260525e-05, 4.438188e01],\n                [-3.805190e-02, -6.406730e-04, 1.499517e00, 4.036536e01],\n                [0.000000e00, 0.000000e00, 0.000000e00, 1.000000e00],\n            ]\n        )\n        img_1 = MetaTensor(np.zeros((1, 238, 145, 315)), affine=affine_1)\n        img_2 = MetaTensor(np.zeros((1, 238, 145, 315)), affine=affine_2)\n        out = Spacingd((\"img_1\", \"img_2\"), pixdim=1)({\"img_1\": img_1, \"img_2\": img_2})\n        self.assertEqual(out[\"img_1\"].shape, out[\"img_2\"].shape)  # ensure_same_shape True\n        out = Spacingd((\"img_1\", \"img_2\"), pixdim=1, ensure_same_shape=False)({\"img_1\": img_1, \"img_2\": img_2})\n        self.assertNotEqual(out[\"img_1\"].shape, out[\"img_2\"].shape)  # ensure_same_shape False\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_spatial_crop.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import CenterScaleCrop, CenterSpatialCrop, SpatialCrop\nfrom tests.croppers import CropTest\n\nTESTS = [\n    [{\"roi_center\": [1, 1, 1], \"roi_size\": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)],\n    [{\"roi_center\": [1, 1, 1], \"roi_size\": [2, 2, 2]}, (3, 1, 1, 1), (3, 1, 1, 1)],\n    [{\"roi_start\": [0, 0, 0], \"roi_end\": [2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)],\n    [{\"roi_start\": [0, 0], \"roi_end\": [2, 2]}, (3, 3, 3, 3), (3, 2, 2, 3)],\n    [{\"roi_start\": [0, 0, 0, 0, 0], \"roi_end\": [2, 2, 2, 2, 2]}, (3, 3, 3, 3), (3, 2, 2, 2)],\n    [{\"roi_start\": [0, 0, 0, 0, 0], \"roi_end\": [8, 8, 8, 2, 2]}, (3, 3, 3, 3), (3, 3, 3, 3)],\n    [\n        {\"roi_slices\": [slice(s, e) for s, e in zip([None, None, None], [None, None, None])]},\n        (3, 11, 12, 15),\n        (3, 11, 12, 15),\n    ],\n    [{\"roi_slices\": [slice(s, e) for s, e in zip([1, None, 0], [None, None, None])]}, (3, 7, 9, 11), (3, 6, 9, 11)],\n    [{\"roi_slices\": [slice(s, e) for s, e in zip([0, None, None], [-1, None, None])]}, (3, 7, 9, 11), (3, 6, 9, 11)],\n    [{\"roi_slices\": [slice(s, e) for s, e in zip([1, None, None], [None, None, None])]}, (3, 10, 8, 6), (3, 9, 8, 6)],\n    [{\"roi_slices\": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]}, (3, 15, 
17, 8), (3, 1, 2, 2)],\n    [{\"roi_slices\": [slice(s, e) for s, e in zip([None, None, None], [-2, -1, 2])]}, (3, 13, 8, 6), (3, 11, 7, 2)],\n    [{\"roi_start\": [-1, 0], \"roi_end\": [5, 5]}, (1, 5, 5), (1, 5, 5)],\n]\n\nTEST_ERRORS = [[{\"roi_slices\": [slice(s, e, 2) for s, e in zip([-1, -2, 0], [None, None, 2])]}]]\n\nTEST_LAZY_ERRORS = [[{\"roi_start\": [1, 0, 0], \"roi_end\": [1, 8, 8]}, (3, 3, 3, 3), (3, 0, 3, 3)]]\n\nfunc1 = {CenterSpatialCrop: {\"roi_size\": [8, 8, 6]}}\nfunc2 = {SpatialCrop: {\"roi_center\": [1, 1, 1], \"roi_size\": [3, 4, 3]}}\nfunc3 = {CenterScaleCrop: {\"roi_scale\": [0.6, 0.3, -1]}}\n\nTESTS_COMBINE = []\nTESTS_COMBINE.append([[func1, func2, func3], (3, 10, 10, 8)])\nTESTS_COMBINE.append([[func1, func2], (3, 8, 8, 4)])\nTESTS_COMBINE.append([[func1, func3], (3, 8, 8, 4)])\n\n\nclass TestSpatialCrop(CropTest):\n    Cropper = SpatialCrop\n\n    @parameterized.expand(TESTS)\n    def test_shape(self, input_param, input_shape, expected_shape):\n        self.crop_test(input_param, input_shape, expected_shape)\n\n    @parameterized.expand(TEST_ERRORS)\n    def test_error(self, input_param):\n        with self.assertRaises(ValueError):\n            SpatialCrop(**input_param)\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, input_param, input_shape, _):\n        self.crop_test_pending_ops(input_param, input_shape)\n\n    @parameterized.expand(TEST_LAZY_ERRORS)\n    def test_lazy_error(self, input_param, input_shape, _):\n        with self.assertRaises(ValueError):\n            return self.crop_test_pending_ops(input_param, input_shape)\n\n    @parameterized.expand(TESTS_COMBINE)\n    def test_combine_ops(self, funcs, input_shape):\n        self.crop_test_combine_ops(funcs, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_spatial_cropd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import SpatialCropd\nfrom tests.croppers import CropTest\n\nTESTS = [\n    [\n        {\"keys\": [\"img\"], \"roi_center\": [1, 1], \"roi_size\": [2, 2]},\n        (1, 3, 3),\n        (1, 2, 2),\n        (slice(None), slice(None, 2), slice(None, 2)),\n    ],\n    [\n        {\"keys\": [\"img\"], \"roi_center\": [1, 1, 1], \"roi_size\": [2, 2, 2]},\n        (3, 3, 3, 3),\n        (3, 2, 2, 2),\n        (slice(None), slice(None, 2), slice(None, 2), slice(None, 2)),\n    ],\n    [\n        {\"keys\": [\"img\"], \"roi_start\": [0, 0, 0], \"roi_end\": [2, 2, 2]},\n        (3, 3, 3, 3),\n        (3, 2, 2, 2),\n        (slice(None), slice(None, 2), slice(None, 2), slice(None, 2)),\n    ],\n    [\n        {\"keys\": [\"img\"], \"roi_start\": [0, 0], \"roi_end\": [2, 2]},\n        (3, 3, 3, 3),\n        (3, 2, 2, 3),\n        (slice(None), slice(None, 2), slice(None, 2), slice(None)),\n    ],\n    [\n        {\"keys\": [\"img\"], \"roi_start\": [0, 0, 0, 0, 0], \"roi_end\": [2, 2, 2, 2, 2]},\n        (3, 3, 3, 3),\n        (3, 2, 2, 2),\n        (slice(None), slice(None, 2), slice(None, 2), slice(None, 2)),\n    ],\n    [\n        {\"keys\": [\"img\"], \"roi_slices\": [slice(s, e) for s, e in zip([-1, -2, 0], [None, None, 2])]},\n        (3, 3, 3, 3),\n        
(3, 1, 2, 2),\n        (slice(None), slice(-1, None), slice(-2, None), slice(0, 2)),\n    ],\n]\n\n\nclass TestSpatialCropd(CropTest):\n    Cropper = SpatialCropd\n\n    @parameterized.expand(TESTS)\n    def test_shape(self, input_param, input_shape, expected_shape, same_area):\n        self.crop_test(input_param, input_shape, expected_shape, same_area)\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, input_param, input_shape, _expected_shape, _same_area):\n        self.crop_test_pending_ops(input_param, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_spatial_pad.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import BorderPad, DivisiblePad, SpatialPad\nfrom tests.padders import PadTest\n\nTESTS = []\nTESTS.append([{\"spatial_size\": [3, 4], \"method\": \"end\"}, (1, 2, 3), (1, 3, 4)])\nTESTS.append([{\"spatial_size\": [15, 4, -1], \"method\": \"symmetric\"}, (3, 8, 8, 4), (3, 15, 8, 4)])\n\nfunc1 = {SpatialPad: {\"spatial_size\": [15, 4, -1], \"method\": \"symmetric\"}}\nfunc2 = {BorderPad: {\"spatial_border\": 2}}\nfunc3 = {DivisiblePad: {\"k\": 5, \"method\": \"end\"}}\n\nTESTS_COMBINE = []\nTESTS_COMBINE.append([[func1, func2, func3], (3, 8, 8, 4), (3, 20, 15, 10)])\nTESTS_COMBINE.append([[func1, func2], (3, 8, 8, 4), (3, 19, 12, 8)])\nTESTS_COMBINE.append([[func2, func2], (3, 8, 8, 4), (3, 16, 16, 12)])\n\n\nclass TestSpatialPad(PadTest):\n    Padder = SpatialPad\n\n    @parameterized.expand(TESTS)\n    def test_pad(self, input_param, input_shape, expected_shape):\n        self.pad_test(input_param, input_shape, expected_shape)\n\n    def test_pad_kwargs(self):\n        kwargs = {\"spatial_size\": [15, 8], \"method\": \"end\", \"mode\": \"constant\"}\n        unchanged_slices = [slice(None), slice(None, 8), slice(None, 4)]\n        self.pad_test_kwargs(unchanged_slices, **kwargs)\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, 
input_param, input_shape, _):\n        self.pad_test_pending_ops(input_param, input_shape)\n\n    @parameterized.expand(TESTS_COMBINE)\n    def test_combine_ops(self, funcs, input_shape, expected_shape):\n        self.pad_test_combine_ops(funcs, input_shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_spatial_padd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import SpatialPadd\nfrom tests.padders import PadTest\n\nTESTS = [\n    [{\"keys\": [\"img\"], \"spatial_size\": [15, 8, 8], \"method\": \"symmetric\"}, (3, 8, 8, 5), (3, 15, 8, 8)],\n    [{\"keys\": [\"img\"], \"spatial_size\": [15, 8, 8], \"method\": \"end\"}, (3, 8, 8, 5), (3, 15, 8, 8)],\n    [{\"keys\": [\"img\"], \"spatial_size\": [15, 8, 8], \"method\": \"end\"}, (3, 8, 8, 5), (3, 15, 8, 8)],\n    [{\"keys\": [\"img\"], \"spatial_size\": [15, 8, -1], \"method\": \"end\"}, (3, 8, 5, 4), (3, 15, 8, 4)],\n]\n\n\nclass TestSpatialPadd(PadTest):\n    Padder = SpatialPadd\n\n    @parameterized.expand(TESTS)\n    def test_pad(self, input_param, input_shape, expected_shape):\n        modes = [\"constant\", {\"constant\"}]\n        self.pad_test(input_param, input_shape, expected_shape, modes)\n\n    @parameterized.expand(TESTS)\n    def test_pending_ops(self, input_param, input_shape, _):\n        self.pad_test_pending_ops(input_param, input_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_spatial_resample.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_obj import set_track_meta\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.data.utils import to_affine_nd\nfrom monai.transforms import SpatialResample\nfrom monai.utils import optional_import\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose, dict_product\n\nTESTS = []\n\ndestinations_3d = [\n    torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),\n    torch.tensor([[-1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),\n]\nexpected_3d = [\n    torch.tensor([[[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]], [[10.0, 11.0, 12.0], [7.0, 8.0, 9.0]]]]),\n    torch.tensor([[[[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]]),\n]\n\nfor dst, expct in zip(destinations_3d, expected_3d):\n    for device in TEST_DEVICES:\n        for align in (False, True):\n            interp = (\"nearest\", \"bilinear\")\n            for interp_mode in interp:\n                for padding_mode in (\"zeros\", \"border\", \"reflection\"):\n                    TESTS.append(\n                        [\n                       
     torch.arange(12).reshape((1, 2, 2, 3)) + 1.0,  # data\n                            *device,\n                            {\n                                \"dst_affine\": dst,\n                                \"dtype\": torch.float64,\n                                \"align_corners\": align,\n                                \"mode\": interp_mode,\n                                \"padding_mode\": padding_mode,\n                            },\n                            expct,\n                        ]\n                    )\nif optional_import(\"cupy\")[1] and optional_import(\"scipy.ndimage\")[1]:\n    TESTS.append(deepcopy(TESTS[-1]))\n    TESTS[-1][2].update({\"align_corners\": True, \"mode\": 1, \"padding_mode\": \"reflect\"})  # type: ignore\n\ndestinations_2d = [\n    torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0], [0.0, 0.0, 1.0]]),  # flip the second\n    torch.tensor([[-1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),  # flip the first\n]\nexpected_2d = [torch.tensor([[[2.0, 1.0], [4.0, 3.0]]]), torch.tensor([[[3.0, 4.0], [1.0, 2.0]]])]\n\nfor dst, expct in zip(destinations_2d, expected_2d):\n    TESTS.extend(\n        [\n            [\n                torch.arange(4).reshape((1, 2, 2)) + 1.0,\n                *params[\"device\"],\n                {\n                    \"dst_affine\": dst,\n                    \"dtype\": torch.float32,\n                    \"align_corners\": params[\"align\"],\n                    \"mode\": params[\"interp_mode\"],\n                    \"padding_mode\": \"zeros\",\n                },\n                expct,\n            ]\n            for params in dict_product(device=TEST_DEVICES, align=[False, True], interp_mode=[\"nearest\", \"bilinear\"])\n        ]\n    )\n\nTEST_4_5_D = []\nfor device in TEST_DEVICES:\n    for dtype in (torch.float32, torch.float64):\n        # 4D\n        TEST_4_5_D.append(\n            [\n                (1, 2, 2, 3, 1),\n                (1, 1, 1, 1, 2),\n                *device,\n    
            dtype,\n                torch.tensor(\n                    [\n                        [[[0.5, 0.0], [0.0, 2.0], [1.5, 1.0]], [[3.5, 3.0], [3.0, 5.0], [4.5, 4.0]]],\n                        [[[6.5, 6.0], [6.0, 8.0], [7.5, 7.0]], [[9.5, 9.0], [9.0, 11.0], [10.5, 10.0]]],\n                    ]\n                ),\n            ]\n        )\n        # 5D\n        TEST_4_5_D.append(\n            [\n                (1, 2, 2, 3, 1, 1),\n                (1, 1, 1, 1, 2, 2),\n                *device,\n                dtype,\n                torch.tensor(\n                    [\n                        [\n                            [[[0.0, 0.0], [0.0, 1.0]], [[0.5, 0.0], [1.5, 1.0]], [[1.0, 2.0], [2.0, 2.0]]],\n                            [[[3.0, 3.0], [3.0, 4.0]], [[3.5, 3.0], [4.5, 4.0]], [[4.0, 5.0], [5.0, 5.0]]],\n                        ],\n                        [\n                            [[[6.0, 6.0], [6.0, 7.0]], [[6.5, 6.0], [7.5, 7.0]], [[7.0, 8.0], [8.0, 8.0]]],\n                            [[[9.0, 9.0], [9.0, 10.0]], [[9.5, 9.0], [10.5, 10.0]], [[10.0, 11.0], [11.0, 11.0]]],\n                        ],\n                    ]\n                ),\n            ]\n        )\n\nTEST_TORCH_INPUT = []\nfor track_meta in (True,):\n    for t in TEST_4_5_D:\n        TEST_TORCH_INPUT.append(t + [track_meta])\n\n\nclass TestSpatialResample(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_flips(self, img, device, data_param, expected_output):\n        for p in TEST_NDARRAYS_ALL:\n            img = p(img)\n            if isinstance(img, MetaTensor):\n                img.affine = torch.eye(4)\n            if hasattr(img, \"to\"):\n                img = img.to(device)\n            resampler = SpatialResample()\n            call_param = data_param.copy()\n            call_param[\"img\"] = img\n            out = resampler(**call_param)\n            assert_allclose(out, expected_output, rtol=1e-2, atol=1e-2)\n            
assert_allclose(to_affine_nd(len(out.shape) - 1, out.affine), call_param[\"dst_affine\"])\n\n            test_resampler_lazy(resampler, out, init_param=None, call_param=call_param)\n\n    @parameterized.expand(TEST_4_5_D)\n    def test_4d_5d(self, new_shape, tile, device, dtype, expected_data):\n        img = np.arange(12).reshape(new_shape)\n        img = np.tile(img, tile)\n        img = MetaTensor(img).to(device)\n\n        dst = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]])\n        dst = dst.to(dtype)\n        init_param = {\"dtype\": dtype, \"align_corners\": True}\n        call_param = {\"img\": img, \"dst_affine\": dst, \"align_corners\": False}\n        resampler = SpatialResample(**init_param)\n        out = resampler(**call_param)\n        assert_allclose(out, expected_data[None], rtol=1e-2, atol=1e-2)\n        assert_allclose(out.affine, dst.to(torch.float32), rtol=1e-2, atol=1e-2)\n\n        test_resampler_lazy(resampler, out, init_param, call_param)\n\n    @parameterized.expand(TEST_DEVICES)\n    def test_ill_affine(self, device):\n        img = MetaTensor(torch.arange(12).reshape(1, 2, 2, 3)).to(device)\n        ill_affine = torch.tensor([[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, -1, 1.5], [0, 0, 0, 1]])\n        with self.assertRaises(ValueError):\n            img.affine = torch.eye(4)\n            dst_affine = ill_affine\n            SpatialResample()(img=img, dst_affine=dst_affine)\n        with self.assertRaises(ValueError):\n            img.affine = ill_affine\n            dst_affine = torch.eye(4)\n            SpatialResample()(img=img, dst_affine=dst_affine)\n        if not (optional_import(\"scipy\")[1] and optional_import(\"cupy\")[1]):\n            return\n        with self.assertRaises(ValueError):  # requires scipy\n            SpatialResample(mode=1, align_corners=True)(img=img, dst_affine=dst_affine)\n        with self.assertRaises(ValueError):\n            SpatialResample(mode=1, 
align_corners=False)(img=img, dst_affine=dst_affine)\n\n    @parameterized.expand(TEST_TORCH_INPUT)\n    def test_input_torch(self, new_shape, tile, device, dtype, expected_data, track_meta):\n        set_track_meta(track_meta)\n        img = np.arange(12).reshape(new_shape)\n        img = torch.as_tensor(np.tile(img, tile)).to(device)\n        dst = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]])\n        dst = dst.to(dtype).to(device)\n        init_param = {\"dtype\": dtype}\n        call_param = {\"img\": img, \"dst_affine\": dst}\n        resampler = SpatialResample(**init_param)\n        out = resampler(**call_param)\n        assert_allclose(out, expected_data[None], rtol=1e-2, atol=1e-2)\n\n        test_resampler_lazy(resampler, out, init_param, call_param)\n\n        if track_meta:\n            self.assertIsInstance(out, MetaTensor)\n            assert_allclose(out.affine, dst.to(torch.float32), rtol=1e-2, atol=1e-2)\n        else:\n            self.assertIsInstance(out, torch.Tensor)\n            self.assertNotIsInstance(out, MetaTensor)\n\n    @parameterized.expand(TESTS)\n    def test_inverse(self, img, device, data_param, expected_output):\n        img = MetaTensor(img, affine=torch.eye(4)).to(device)\n        tr = SpatialResample()\n        out = tr(img=img, **data_param)\n        assert_allclose(out, expected_output, rtol=1e-2, atol=1e-2)\n        assert_allclose(to_affine_nd(len(out.shape) - 1, out.affine), data_param[\"dst_affine\"])\n\n        # inverse\n        out = tr.inverse(out)\n        assert_allclose(img, out)\n        expected_affine = to_affine_nd(len(out.affine) - 1, torch.eye(4))\n        assert_allclose(out.affine, expected_affine)\n\n    def test_unchange(self):\n        for i, p in enumerate(TEST_NDARRAYS_ALL):\n            set_track_meta(i % 2)\n            img = p(np.arange(12).reshape(1, 3, 4))\n            result = SpatialResample()(img)\n            assert_allclose(result, 
img, type_test=False)\n        set_track_meta(True)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_squeezedim.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor\nfrom monai.transforms import SqueezeDim\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS, TESTS_FAIL = [], []\nfor p in TEST_NDARRAYS:\n    TESTS.append([{\"dim\": None}, p(np.random.rand(1, 2, 1, 3)), (2, 3)])\n    TESTS.append([{\"dim\": 2}, p(np.random.rand(1, 2, 1, 8, 16)), (1, 2, 8, 16)])\n    TESTS.append([{\"dim\": -1}, p(np.random.rand(1, 1, 16, 8, 1)), (1, 1, 16, 8)])\n    TESTS.append([{}, p(np.random.rand(1, 2, 1, 3)), (2, 1, 3)])\n\n    TESTS_FAIL.append([ValueError, {\"dim\": -2}, p(np.random.rand(1, 1, 16, 8, 1))])\n    TESTS_FAIL.append([TypeError, {\"dim\": 0.5}, p(np.random.rand(1, 1, 16, 8, 1))])\n\n\nclass TestSqueezeDim(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, input_param, test_data, expected_shape):\n        result = SqueezeDim(**input_param)(test_data)\n        self.assertTupleEqual(result.shape, expected_shape)\n        if \"dim\" in input_param and input_param[\"dim\"] == 2 and isinstance(result, MetaTensor):\n            assert_allclose(result.affine.shape, [3, 3])\n\n    @parameterized.expand(TESTS_FAIL)\n    def test_invalid_inputs(self, exception, input_param, test_data):\n        with self.assertRaises(exception):\n            
SqueezeDim(**input_param)(test_data)\n\n    def test_affine_ill_inputs(self):\n        img = MetaTensor(\n            np.random.rand(1, 2, 1, 3),\n            affine=[\n                [-0.7422, 0.0, 0.0, 186.3210],\n                [0.0, 0.0, -3.0, 70.6580],\n                [0.0, -0.7422, 0.0, 189.4130],\n                [0.0, 0.0, 0.0, 1.0],\n            ],\n        )\n        with self.assertWarns(UserWarning):\n            SqueezeDim(dim=2)(img)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_squeezedimd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor\nfrom monai.transforms import SqueezeDimd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS, TESTS_FAIL = [], []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            {\"keys\": [\"img\", \"seg\"], \"dim\": None},\n            {\"img\": p(np.random.rand(1, 2, 1, 3)), \"seg\": p(np.random.randint(0, 2, size=[1, 2, 1, 3]))},\n            (2, 3),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": [\"img\", \"seg\"], \"dim\": 2},\n            {\"img\": p(np.random.rand(1, 2, 1, 8, 16)), \"seg\": p(np.random.randint(0, 2, size=[1, 2, 1, 8, 16]))},\n            (1, 2, 8, 16),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": [\"img\", \"seg\"], \"dim\": -1},\n            {\"img\": p(np.random.rand(1, 1, 16, 8, 1)), \"seg\": p(np.random.randint(0, 2, size=[1, 1, 16, 8, 1]))},\n            (1, 1, 16, 8),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": [\"img\", \"seg\"]},\n            {\"img\": p(np.random.rand(1, 2, 1, 3)), \"seg\": p(np.random.randint(0, 2, size=[1, 2, 1, 3]))},\n            (2, 1, 3),\n        ]\n    )\n\n    TESTS.append(\n        [\n            {\"keys\": [\"img\", \"seg\"], \"dim\": 0},\n            
{\"img\": p(np.random.rand(1, 2, 1, 3)), \"seg\": p(np.random.randint(0, 2, size=[1, 2, 1, 3]))},\n            (2, 1, 3),\n        ]\n    )\n\n    TESTS_FAIL.append(\n        [\n            ValueError,\n            {\"keys\": [\"img\", \"seg\"], \"dim\": -2},\n            {\"img\": p(np.random.rand(1, 1, 16, 8, 1)), \"seg\": p(np.random.randint(0, 2, size=[1, 1, 16, 8, 1]))},\n        ]\n    )\n\n    TESTS_FAIL.append(\n        [\n            TypeError,\n            {\"keys\": [\"img\", \"seg\"], \"dim\": 0.5},\n            {\"img\": p(np.random.rand(1, 1, 16, 8, 1)), \"seg\": p(np.random.randint(0, 2, size=[1, 1, 16, 8, 1]))},\n        ]\n    )\n\n\nclass TestSqueezeDim(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, input_param, test_data, expected_shape):\n        result = SqueezeDimd(**input_param)(test_data)\n        self.assertTupleEqual(result[\"img\"].shape, expected_shape)\n        self.assertTupleEqual(result[\"seg\"].shape, expected_shape)\n        if \"dim\" in input_param and isinstance(result[\"img\"], MetaTensor) and input_param[\"dim\"] == 2:\n            assert_allclose(result[\"img\"].affine.shape, [3, 3])\n\n    @parameterized.expand(TESTS_FAIL)\n    def test_invalid_inputs(self, exception, input_param, test_data):\n        with self.assertRaises(exception):\n            SqueezeDimd(**input_param)(test_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_std_shift_intensity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms import ShiftIntensity, StdShiftIntensity\nfrom monai.utils import dtype_numpy_to_torch\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestStdShiftIntensity(NumpyImageTestCase2D):\n    def test_value(self):\n        for p in TEST_NDARRAYS:\n            imt = p(self.imt)\n            factor = np.random.rand()\n            offset = np.std(self.imt) * factor\n            shifter = ShiftIntensity(offset=offset)\n            expected = shifter(imt)\n            std_shifter = StdShiftIntensity(factor=factor)\n            result = std_shifter(imt)\n            assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False)\n\n    def test_zerostd(self):\n        for p in TEST_NDARRAYS:\n            image = p(np.ones([2, 3, 3], dtype=np.float32))\n            for nonzero in [True, False]:\n                for channel_wise in [True, False]:\n                    factor = np.random.rand()\n                    std_shifter = StdShiftIntensity(factor=factor, nonzero=nonzero, channel_wise=channel_wise)\n                    result = std_shifter(image)\n                    assert_allclose(result, image, atol=0, rtol=1e-5, type_test=False)\n\n    def test_nonzero(self):\n        for p in TEST_NDARRAYS:\n            image = 
p(np.asarray([[4.0, 0.0, 2.0], [0, 2, 4]]))  # std = 1\n            factor = np.random.rand()\n            std_shifter = StdShiftIntensity(factor=factor, nonzero=True)\n            result = std_shifter(image)\n            expected = p(np.asarray([[4 + factor, 0, 2 + factor], [0, 2 + factor, 4 + factor]], dtype=np.float32))\n            assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False)\n\n    def test_channel_wise(self):\n        for p in TEST_NDARRAYS:\n            image = p(np.stack((np.asarray([1.0, 2.0]), np.asarray([1.0, 1.0]))))  # std: 0.5, 0\n            factor = np.random.rand()\n            std_shifter = StdShiftIntensity(factor=factor, channel_wise=True)\n            result = std_shifter(image)\n            expected = p(\n                np.stack((np.asarray([1 + 0.5 * factor, 2 + 0.5 * factor]), np.asarray([1, 1]))).astype(np.float32)\n            )\n            assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False)\n\n    def test_dtype(self):\n        trans_dtype = np.float32\n        for dtype in [int, np.float32, np.float64]:\n            image = np.random.rand(2, 2, 2).astype(dtype)\n            factor = np.random.rand()\n            std_shifter = StdShiftIntensity(factor=factor, dtype=trans_dtype)\n            result = std_shifter(image)\n            np.testing.assert_equal(result.dtype, dtype_numpy_to_torch(trans_dtype))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_std_shift_intensityd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms import ShiftIntensityd, StdShiftIntensityd\nfrom monai.utils import dtype_numpy_to_torch\nfrom tests.test_utils import NumpyImageTestCase2D\n\n\nclass TestStdShiftIntensityd(NumpyImageTestCase2D):\n    def test_value(self):\n        key = \"img\"\n        factor = np.random.rand()\n        offset = np.std(self.imt) * factor\n        shifter = ShiftIntensityd(keys=[key], offset=offset)\n        expected = shifter({key: self.imt})\n        std_shifter = StdShiftIntensityd(keys=[key], factor=factor)\n        result = std_shifter({key: self.imt})\n        np.testing.assert_allclose(result[key], expected[key], rtol=1e-5)\n\n    def test_zerostd(self):\n        key = \"img\"\n        image = np.ones([2, 3, 3])\n        for nonzero in [True, False]:\n            for channel_wise in [True, False]:\n                factor = np.random.rand()\n                std_shifter = StdShiftIntensityd(keys=[key], factor=factor, nonzero=nonzero, channel_wise=channel_wise)\n                result = std_shifter({key: image})\n                np.testing.assert_allclose(result[key], image, rtol=1e-5)\n\n    def test_nonzero(self):\n        key = \"img\"\n        image = np.asarray([[4.0, 0.0, 2.0], [0, 2, 4]])  # std = 1\n        factor = np.random.rand()\n        std_shifter = 
StdShiftIntensityd(keys=[key], factor=factor, nonzero=True)\n        result = std_shifter({key: image})\n        expected = np.asarray([[4 + factor, 0, 2 + factor], [0, 2 + factor, 4 + factor]])\n        np.testing.assert_allclose(result[key], expected, rtol=1e-5)\n\n    def test_channel_wise(self):\n        key = \"img\"\n        image = np.stack((np.asarray([1.0, 2.0]), np.asarray([1.0, 1.0])))  # std: 0.5, 0\n        factor = np.random.rand()\n        std_shifter = StdShiftIntensityd(keys=[key], factor=factor, channel_wise=True)\n        result = std_shifter({key: image})\n        expected = np.stack((np.asarray([1 + 0.5 * factor, 2 + 0.5 * factor]), np.asarray([1, 1])))\n        np.testing.assert_allclose(result[key], expected, rtol=1e-5)\n\n    def test_dtype(self):\n        key = \"img\"\n        trans_dtype = np.float32\n        for dtype in [int, np.float32, np.float64]:\n            image = np.random.rand(2, 2, 2).astype(dtype)\n            factor = np.random.rand()\n            std_shifter = StdShiftIntensityd(keys=[key], factor=factor, dtype=trans_dtype)\n            result = std_shifter({key: image})\n            np.testing.assert_equal(result[key].dtype, dtype_numpy_to_torch(trans_dtype))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_threshold_intensity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import ThresholdIntensity\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([p, {\"threshold\": 5, \"above\": True, \"cval\": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)])\n    TESTS.append([p, {\"threshold\": 5, \"above\": False, \"cval\": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)])\n    TESTS.append([p, {\"threshold\": 5, \"above\": True, \"cval\": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)])\n\n\nclass TestThresholdIntensity(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, in_type, input_param, expected_value):\n        test_data = in_type(np.arange(10))\n        result = ThresholdIntensity(**input_param)(test_data)\n        assert_allclose(result, in_type(expected_value), type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_threshold_intensityd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import ThresholdIntensityd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append(\n        [\n            p,\n            {\"keys\": [\"image\", \"label\", \"extra\"], \"threshold\": 5, \"above\": True, \"cval\": 0},\n            (0, 0, 0, 0, 0, 0, 6, 7, 8, 9),\n        ]\n    )\n    TESTS.append(\n        [\n            p,\n            {\"keys\": [\"image\", \"label\", \"extra\"], \"threshold\": 5, \"above\": False, \"cval\": 0},\n            (0, 1, 2, 3, 4, 0, 0, 0, 0, 0),\n        ]\n    )\n    TESTS.append(\n        [\n            p,\n            {\"keys\": [\"image\", \"label\", \"extra\"], \"threshold\": 5, \"above\": True, \"cval\": 5},\n            (5, 5, 5, 5, 5, 5, 6, 7, 8, 9),\n        ]\n    )\n\n\nclass TestThresholdIntensityd(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, in_type, input_param, expected_value):\n        test_data = {\"image\": in_type(np.arange(10)), \"label\": in_type(np.arange(10)), \"extra\": in_type(np.arange(10))}\n        result = ThresholdIntensityd(**input_param)(test_data)\n        assert_allclose(result[\"image\"], in_type(expected_value), type_test=\"tensor\")\n        
assert_allclose(result[\"label\"], in_type(expected_value), type_test=\"tensor\")\n        assert_allclose(result[\"extra\"], in_type(expected_value), type_test=\"tensor\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_to_contiguous.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms import convert_to_contiguous\nfrom tests.test_utils import assert_allclose\n\n\nclass TestToContiguous(unittest.TestCase):\n    def test_contiguous_dict(self):\n        tochange = np.moveaxis(np.zeros((2, 3, 4)), 0, -1)\n        test_dict = {\"test_key\": [[1]], 0: np.array(0), 1: np.array([0]), \"nested\": {\"nested\": [tochange]}}\n        output = convert_to_contiguous(test_dict)\n        self.assertEqual(output[\"test_key\"], [[1]])\n        assert_allclose(output[0], np.array(0))\n        assert_allclose(output[1], np.array([0]))\n        self.assertTrue(output[\"nested\"][\"nested\"][0].flags.c_contiguous)\n\n    def test_contiguous_seq(self):\n        tochange = torch.zeros(2, 3, 4).transpose(0, 1)\n        test_seq = [[[1]], np.array(0), np.array([0]), torch.tensor(1.0), [[tochange]], \"test_string\", (1, 2, 3)]\n        output = convert_to_contiguous(test_seq)\n        self.assertEqual(output[0], [[1]])\n        assert_allclose(output[1], np.array(0))\n        assert_allclose(output[2], np.array([0]))\n        assert_allclose(output[3], torch.tensor(1.0))\n        self.assertTrue(output[4][0][0].is_contiguous())\n        self.assertEqual(output[5], \"test_string\")\n        self.assertEqual(output[6], (1, 2, 3))\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_to_cupy.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms import ToCupy\nfrom monai.utils import optional_import\nfrom tests.test_utils import HAS_CUPY, skip_if_no_cuda\n\ncp, _ = optional_import(\"cupy\")\n\n\n@skipUnless(HAS_CUPY, \"CuPy is required.\")\nclass TestToCupy(unittest.TestCase):\n    def test_cupy_input(self):\n        test_data = cp.array([[1, 2], [3, 4]], dtype=cp.float32)\n        test_data = cp.rot90(test_data)\n        self.assertFalse(test_data.flags[\"C_CONTIGUOUS\"])\n        result = ToCupy()(test_data)\n        self.assertTrue(result.dtype == cp.float32)\n        self.assertTrue(isinstance(result, cp.ndarray))\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        cp.testing.assert_allclose(result, test_data)\n\n    def test_cupy_input_dtype(self):\n        test_data = cp.array([[1, 2], [3, 4]], dtype=cp.float32)\n        test_data = cp.rot90(test_data)\n        self.assertFalse(test_data.flags[\"C_CONTIGUOUS\"])\n        result = ToCupy(cp.uint8)(test_data)\n        self.assertTrue(result.dtype == cp.uint8)\n        self.assertTrue(isinstance(result, cp.ndarray))\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        cp.testing.assert_allclose(result, test_data)\n\n    def test_numpy_input(self):\n        test_data = np.array([[1, 2], 
[3, 4]], dtype=np.float32)\n        test_data = np.rot90(test_data)\n        self.assertFalse(test_data.flags[\"C_CONTIGUOUS\"])\n        result = ToCupy()(test_data)\n        self.assertTrue(result.dtype == cp.float32)\n        self.assertTrue(isinstance(result, cp.ndarray))\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        cp.testing.assert_allclose(result, test_data)\n\n    def test_numpy_input_dtype(self):\n        test_data = np.array([[1, 2], [3, 4]], dtype=np.float32)\n        test_data = np.rot90(test_data)\n        self.assertFalse(test_data.flags[\"C_CONTIGUOUS\"])\n        result = ToCupy(np.uint8)(test_data)\n        self.assertEqual(result.dtype, cp.uint8)\n        self.assertIsInstance(result, cp.ndarray)\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        cp.testing.assert_allclose(result, test_data)\n\n    def test_tensor_input(self):\n        test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)\n        test_data = test_data.rot90()\n        self.assertFalse(test_data.is_contiguous())\n        result = ToCupy()(test_data)\n        self.assertEqual(result.dtype, cp.float32)\n        self.assertIsInstance(result, cp.ndarray)\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        cp.testing.assert_allclose(result, test_data)\n\n    @skip_if_no_cuda\n    def test_tensor_cuda_input(self):\n        test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32).cuda()\n        test_data = test_data.rot90()\n        self.assertFalse(test_data.is_contiguous())\n        result = ToCupy()(test_data)\n        self.assertEqual(result.dtype, cp.float32)\n        self.assertIsInstance(result, cp.ndarray)\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        cp.testing.assert_allclose(result, test_data)\n\n    @skip_if_no_cuda\n    def test_tensor_cuda_input_dtype(self):\n        test_data = torch.tensor([[1, 2], [3, 4]], dtype=torch.uint8).cuda()\n        test_data = test_data.rot90()\n        
self.assertFalse(test_data.is_contiguous())\n\n        result = ToCupy(dtype=\"float32\")(test_data)\n        self.assertEqual(result.dtype, cp.float32)\n        self.assertIsInstance(result, cp.ndarray)\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        cp.testing.assert_allclose(result, test_data)\n\n    def test_list_tuple(self):\n        test_data = [[1, 2], [3, 4]]\n        result = ToCupy(wrap_sequence=True)(test_data)\n        cp.testing.assert_allclose(result, cp.asarray(test_data))\n        test_data = ((1, 2), (3, 4))\n        result = ToCupy(wrap_sequence=True)(test_data)\n        cp.testing.assert_allclose(result, cp.asarray(test_data))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_to_cupyd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms import ToCupyd\nfrom monai.utils import optional_import\nfrom tests.test_utils import HAS_CUPY, skip_if_no_cuda\n\ncp, _ = optional_import(\"cupy\")\n\n\n@skipUnless(HAS_CUPY, \"CuPy is required.\")\nclass TestToCupyd(unittest.TestCase):\n    def test_cupy_input(self):\n        test_data = cp.array([[1, 2], [3, 4]])\n        test_data = cp.rot90(test_data)\n        self.assertFalse(test_data.flags[\"C_CONTIGUOUS\"])\n        result = ToCupyd(keys=\"img\")({\"img\": test_data})[\"img\"]\n        self.assertTrue(isinstance(result, cp.ndarray))\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        cp.testing.assert_allclose(result, test_data)\n\n    def test_numpy_input(self):\n        test_data = np.array([[1, 2], [3, 4]])\n        test_data = np.rot90(test_data)\n        self.assertFalse(test_data.flags[\"C_CONTIGUOUS\"])\n        result = ToCupyd(keys=\"img\")({\"img\": test_data})[\"img\"]\n        self.assertTrue(isinstance(result, cp.ndarray))\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        cp.testing.assert_allclose(result, test_data)\n\n    def test_tensor_input(self):\n        test_data = torch.tensor([[1, 2], [3, 4]])\n        test_data = test_data.rot90()\n        
self.assertFalse(test_data.is_contiguous())\n        result = ToCupyd(keys=\"img\")({\"img\": test_data})[\"img\"]\n        self.assertTrue(isinstance(result, cp.ndarray))\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        cp.testing.assert_allclose(result, test_data.numpy())\n\n    @skip_if_no_cuda\n    def test_tensor_cuda_input(self):\n        test_data = torch.tensor([[1, 2], [3, 4]]).cuda()\n        test_data = test_data.rot90()\n        self.assertFalse(test_data.is_contiguous())\n        result = ToCupyd(keys=\"img\")({\"img\": test_data})[\"img\"]\n        self.assertTrue(isinstance(result, cp.ndarray))\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        cp.testing.assert_allclose(result, test_data.cpu().numpy())\n\n    def test_list_tuple(self):\n        test_data = [[1, 2], [3, 4]]\n        result = ToCupyd(keys=\"img\", wrap_sequence=True)({\"img\": test_data})[\"img\"]\n        cp.testing.assert_allclose(result, cp.asarray(test_data))\n        test_data = ((1, 2), (3, 4))\n        result = ToCupyd(keys=\"img\", wrap_sequence=True)({\"img\": test_data})[\"img\"]\n        cp.testing.assert_allclose(result, cp.asarray(test_data))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_to_device.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import ToDevice\nfrom tests.test_utils import assert_allclose, skip_if_no_cuda\n\nTEST_CASE_1 = [\"cuda:0\"]\n\nTEST_CASE_2 = [\"cuda\"]\n\nTEST_CASE_3 = [torch.device(\"cpu:0\")]\n\nTEST_CASE_4 = [\"cpu\"]\n\n\n@skip_if_no_cuda\nclass TestToDevice(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n    def test_value(self, device):\n        converter = ToDevice(device=device, non_blocking=True)\n        data = torch.tensor([1, 2, 3, 4])\n        ret = converter(data)\n        assert_allclose(ret, data.to(device))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_to_deviced.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\n\nfrom monai.data import CacheDataset, ThreadDataLoader\nfrom monai.transforms import ToDeviced\nfrom tests.test_utils import assert_allclose, skip_if_no_cuda\n\n\n@skip_if_no_cuda\nclass TestToDeviced(unittest.TestCase):\n    def test_value(self):\n        device = \"cuda:0\"\n        data = [{\"img\": torch.tensor(i)} for i in range(4)]\n        dataset = CacheDataset(\n            data=data, transform=ToDeviced(keys=\"img\", device=device, non_blocking=True), cache_rate=1.0\n        )\n        dataloader = ThreadDataLoader(dataset=dataset, num_workers=0, batch_size=1)\n        for i, d in enumerate(dataloader):\n            assert_allclose(d[\"img\"], torch.tensor([i], device=device))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_to_numpy.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms import ToNumpy\nfrom monai.utils import optional_import\nfrom tests.test_utils import HAS_CUPY, assert_allclose, skip_if_no_cuda\n\ncp, _ = optional_import(\"cupy\")\n\n\nclass TestToNumpy(unittest.TestCase):\n    @skipUnless(HAS_CUPY, \"CuPy is required.\")\n    def test_cupy_input(self):\n        test_data = cp.array([[1, 2], [3, 4]])\n        test_data = cp.rot90(test_data)\n        self.assertFalse(test_data.flags[\"C_CONTIGUOUS\"])\n        result = ToNumpy()(test_data)\n        self.assertIsInstance(result, np.ndarray)\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        assert_allclose(result, test_data.get(), type_test=False)\n\n    def test_numpy_input(self):\n        test_data = np.array([[1, 2], [3, 4]])\n        test_data = np.rot90(test_data)\n        self.assertFalse(test_data.flags[\"C_CONTIGUOUS\"])\n        result = ToNumpy(dtype=\"float32\")(test_data)\n        self.assertIsInstance(result, np.ndarray)\n        self.assertEqual(result.dtype, np.float32)\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        assert_allclose(result, test_data, type_test=False)\n\n    def test_tensor_input(self):\n        test_data = torch.tensor([[1, 2], [3, 4]])\n        test_data = test_data.rot90()\n  
      self.assertFalse(test_data.is_contiguous())\n        result = ToNumpy(dtype=torch.uint8)(test_data)\n        self.assertIsInstance(result, np.ndarray)\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        assert_allclose(result, test_data, type_test=False)\n\n    @skip_if_no_cuda\n    def test_tensor_cuda_input(self):\n        test_data = torch.tensor([[1, 2], [3, 4]]).cuda()\n        test_data = test_data.rot90()\n        self.assertFalse(test_data.is_contiguous())\n        result = ToNumpy()(test_data)\n        self.assertIsInstance(result, np.ndarray)\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        assert_allclose(result, test_data, type_test=False)\n\n    def test_list_tuple(self):\n        test_data = [[1, 2], [3, 4]]\n        result = ToNumpy()(test_data)\n        assert_allclose(result, np.asarray(test_data), type_test=False)\n        test_data = ((1, 2), (3, 4))\n        result = ToNumpy(wrap_sequence=False)(test_data)\n        self.assertIsInstance(result, tuple)\n        assert_allclose(result, ((np.asarray(1), np.asarray(2)), (np.asarray(3), np.asarray(4))))\n\n    def test_single_value(self):\n        for test_data in [5, np.array(5), torch.tensor(5)]:\n            result = ToNumpy(dtype=np.uint8)(test_data)\n            self.assertIsInstance(result, np.ndarray)\n            assert_allclose(result, np.asarray(test_data), type_test=False)\n            self.assertEqual(result.ndim, 0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_to_numpyd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms import ToNumpyd\nfrom monai.utils import optional_import\nfrom tests.test_utils import HAS_CUPY, assert_allclose, skip_if_no_cuda\n\ncp, _ = optional_import(\"cupy\")\n\n\nclass TestToNumpyd(unittest.TestCase):\n    @skipUnless(HAS_CUPY, \"CuPy is required.\")\n    def test_cupy_input(self):\n        test_data = cp.array([[1, 2], [3, 4]])\n        test_data = cp.rot90(test_data)\n        self.assertFalse(test_data.flags[\"C_CONTIGUOUS\"])\n        result = ToNumpyd(keys=\"img\")({\"img\": test_data})[\"img\"]\n        self.assertTrue(isinstance(result, np.ndarray))\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        assert_allclose(result, test_data.get(), type_test=False)\n\n    def test_numpy_input(self):\n        test_data = np.array([[1, 2], [3, 4]])\n        test_data = np.rot90(test_data)\n        self.assertFalse(test_data.flags[\"C_CONTIGUOUS\"])\n        result = ToNumpyd(keys=\"img\")({\"img\": test_data})[\"img\"]\n        self.assertTrue(isinstance(result, np.ndarray))\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        assert_allclose(result, test_data, type_test=False)\n\n    def test_tensor_input(self):\n        test_data = torch.tensor([[1, 2], [3, 4]])\n        test_data = 
test_data.rot90()\n        self.assertFalse(test_data.is_contiguous())\n        result = ToNumpyd(keys=\"img\")({\"img\": test_data})[\"img\"]\n        self.assertTrue(isinstance(result, np.ndarray))\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        assert_allclose(result, test_data, type_test=False)\n\n    @skip_if_no_cuda\n    def test_tensor_cuda_input(self):\n        test_data = torch.tensor([[1, 2], [3, 4]]).cuda()\n        test_data = test_data.rot90()\n        self.assertFalse(test_data.is_contiguous())\n        result = ToNumpyd(keys=\"img\")({\"img\": test_data})[\"img\"]\n        self.assertTrue(isinstance(result, np.ndarray))\n        self.assertTrue(result.flags[\"C_CONTIGUOUS\"])\n        assert_allclose(result, test_data, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_to_pil.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom typing import TYPE_CHECKING\nfrom unittest import skipUnless\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import ToPIL\nfrom monai.utils import optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nif TYPE_CHECKING:\n    from PIL.Image import Image as PILImageImage\n    from PIL.Image import fromarray as pil_image_fromarray\n\n    has_pil = True\nelse:\n    pil_image_fromarray, has_pil = optional_import(\"PIL.Image\", name=\"fromarray\")\n    PILImageImage, _ = optional_import(\"PIL.Image\", name=\"Image\")\n\nim = [[1.0, 2.0], [3.0, 4.0]]\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([p(im)])\nif has_pil:\n    TESTS.append([pil_image_fromarray(np.array(im))])\n\n\nclass TestToPIL(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    @skipUnless(has_pil, \"Requires `pillow` package.\")\n    def test_value(self, test_data):\n        result = ToPIL()(test_data)\n        self.assertTrue(isinstance(result, PILImageImage))\n        assert_allclose(np.array(result), test_data, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_to_pild.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom typing import TYPE_CHECKING\nfrom unittest import skipUnless\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms import ToPILd\nfrom monai.utils import optional_import\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nif TYPE_CHECKING:\n    from PIL.Image import Image as PILImageImage\n    from PIL.Image import fromarray as pil_image_fromarray\n\n    has_pil = True\nelse:\n    pil_image_fromarray, has_pil = optional_import(\"PIL.Image\", name=\"fromarray\")\n    PILImageImage, _ = optional_import(\"PIL.Image\", name=\"Image\")\n\nim = [[1.0, 2.0], [3.0, 4.0]]\nTESTS = [[{\"keys\": \"image\"}, {\"image\": p(im)}] for p in TEST_NDARRAYS]\nif has_pil:\n    TESTS.append([{\"keys\": \"image\"}, {\"image\": pil_image_fromarray(np.array(im))}])\n\n\nclass TestToPIL(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    @skipUnless(has_pil, \"Requires `pillow` package.\")\n    def test_values(self, input_param, test_data):\n        result = ToPILd(**input_param)(test_data)[input_param[\"keys\"]]\n        self.assertTrue(isinstance(result, PILImageImage))\n        assert_allclose(np.array(result), test_data[input_param[\"keys\"]], type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_to_tensor.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import ToTensor\nfrom tests.test_utils import HAS_CUPY, TEST_NDARRAYS, assert_allclose, optional_import\n\ncp, _ = optional_import(\"cupy\")\n\nim = [[1, 2], [3, 4]]\n\nTESTS = [(im, (2, 2))]\nfor p in TEST_NDARRAYS:\n    TESTS.append((p(im), (2, 2)))\n\nTESTS_SINGLE = [[5]]\nfor p in TEST_NDARRAYS:\n    TESTS_SINGLE.append([p(5)])\n\n\nclass TestToTensor(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_array_input(self, test_data, expected_shape):\n        result = ToTensor(dtype=torch.float32, device=\"cpu\", wrap_sequence=True)(test_data)\n        self.assertTrue(isinstance(result, torch.Tensor))\n        assert_allclose(result, test_data, type_test=False)\n        self.assertTupleEqual(result.shape, expected_shape)\n\n    @parameterized.expand(TESTS_SINGLE)\n    def test_single_input(self, test_data):\n        result = ToTensor(track_meta=True)(test_data)\n        self.assertTrue(isinstance(result, torch.Tensor))\n        assert_allclose(result, test_data, type_test=False)\n        self.assertEqual(result.ndim, 0)\n\n    @unittest.skipUnless(HAS_CUPY, \"CuPy is required.\")\n    def test_cupy(self):\n        test_data = [[1, 2], [3, 4]]\n        cupy_array = cp.ascontiguousarray(cp.asarray(test_data))\n        
result = ToTensor()(cupy_array)\n        self.assertTrue(isinstance(result, torch.Tensor))\n        assert_allclose(result, test_data, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_to_tensord.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import ToTensord\nfrom tests.test_utils import HAS_CUPY, TEST_NDARRAYS, assert_allclose, optional_import\n\ncp, _ = optional_import(\"cupy\")\n\nim = [[1, 2], [3, 4]]\n\nTESTS = [(im, (2, 2))]\nfor p in TEST_NDARRAYS:\n    TESTS.append((p(im), (2, 2)))\n\nTESTS_SINGLE = [[5]]\nfor p in TEST_NDARRAYS:\n    TESTS_SINGLE.append([p(5)])\n\n\nclass TestToTensord(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_array_input(self, test_data, expected_shape):\n        test_data = {\"img\": test_data}\n        to_tensord = ToTensord(keys=\"img\", dtype=torch.float32, device=\"cpu\", wrap_sequence=True)\n        result = to_tensord(test_data)\n        out_img = result[\"img\"]\n        self.assertTrue(isinstance(out_img, torch.Tensor))\n        assert_allclose(out_img, test_data[\"img\"], type_test=False)\n        self.assertTupleEqual(out_img.shape, expected_shape)\n\n        # test inverse\n        inv_data = to_tensord.inverse(result)\n        self.assertTrue(isinstance(inv_data[\"img\"], np.ndarray))\n        assert_allclose(test_data[\"img\"], inv_data[\"img\"], type_test=False)\n\n    @parameterized.expand(TESTS_SINGLE)\n    def test_single_input(self, test_data):\n        test_data = {\"img\": 
test_data}\n        result = ToTensord(keys=\"img\", track_meta=True)(test_data)\n        out_img = result[\"img\"]\n        self.assertTrue(isinstance(out_img, torch.Tensor))\n        assert_allclose(out_img, test_data[\"img\"], type_test=False)\n        self.assertEqual(out_img.ndim, 0)\n\n    @unittest.skipUnless(HAS_CUPY, \"CuPy is required.\")\n    def test_cupy(self):\n        test_data = [[1, 2], [3, 4]]\n        cupy_array = cp.ascontiguousarray(cp.asarray(test_data))\n        result = ToTensord(keys=\"img\")({\"img\": cupy_array})\n        self.assertTrue(isinstance(result[\"img\"], torch.Tensor))\n        assert_allclose(result[\"img\"], test_data, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_torchio.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import TorchIO\nfrom monai.utils import optional_import\n\n_, has_torchio = optional_import(\"torchio\")\n\nTEST_DIMS = [3, 128, 160, 160]\nTESTS = [[{\"name\": \"RescaleIntensity\"}, torch.rand(TEST_DIMS)], [{\"name\": \"ZNormalization\"}, torch.rand(TEST_DIMS)]]\n\n\n@skipUnless(has_torchio, \"Requires torchio\")\nclass TestTorchIO(unittest.TestCase):\n\n    @parameterized.expand(TESTS)\n    def test_value(self, input_param, input_data):\n        result = TorchIO(**input_param)(input_data)\n        self.assertIsNotNone(result)\n        self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f\"{input_param} failed\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_torchiod.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest import skipUnless\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import TorchIOd\nfrom monai.utils import optional_import\nfrom tests.test_utils import assert_allclose\n\n_, has_torchio = optional_import(\"torchio\")\n\nTEST_DIMS = [3, 128, 160, 160]\nTEST_TENSOR = torch.rand(TEST_DIMS)\nTEST_PARAMS = [\n    [\n        {\"keys\": \"img\", \"name\": \"RescaleIntensity\", \"out_min_max\": (0, 42)},\n        {\"img\": TEST_TENSOR},\n        ((TEST_TENSOR - TEST_TENSOR.min()) / (TEST_TENSOR.max() - TEST_TENSOR.min())) * 42,\n    ]\n]\n\n\n@skipUnless(has_torchio, \"Requires torchio\")\nclass TestTorchIOd(unittest.TestCase):\n    @parameterized.expand(TEST_PARAMS)\n    def test_value(self, input_param, input_data, expected_value):\n        result = TorchIOd(**input_param)(input_data)\n        assert_allclose(result[\"img\"], expected_value, atol=1e-4, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_torchvision.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.transforms import TorchVision\nfrom monai.utils import set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.extend(\n        [\n            [\n                {\"name\": \"ColorJitter\"},\n                p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),\n                p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),\n            ],\n            [\n                {\"name\": \"ColorJitter\", \"brightness\": 0.5, \"contrast\": 0.5, \"saturation\": [0.1, 0.8], \"hue\": 0.5},\n                p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),\n                p(\n                    [\n                        [[0.1090, 0.6193], [0.6193, 0.9164]],\n                        [[0.1090, 0.6193], [0.6193, 0.9164]],\n                        [[0.1090, 0.6193], [0.6193, 0.9164]],\n                    ]\n                ),\n            ],\n            [\n                {\"name\": \"Pad\", \"padding\": [1, 1, 1, 1]},\n                p([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),\n                p(\n                    [\n                        [[0.0, 0.0, 0.0, 0.0], 
[0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]],\n                        [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]],\n                        [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]],\n                    ]\n                ),\n            ],\n        ]\n    )\n\n\nclass TestTorchVision(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, input_param, input_data, expected_value):\n        set_determinism(seed=0)\n        result = TorchVision(**input_param)(input_data)\n        assert_allclose(result, expected_value, rtol=1e-3)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_torchvisiond.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import TorchVisiond\nfrom monai.utils import set_determinism\nfrom tests.test_utils import assert_allclose\n\nTEST_CASE_1 = [\n    {\"keys\": \"img\", \"name\": \"ColorJitter\"},\n    {\"img\": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},\n    torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]]),\n]\n\nTEST_CASE_2 = [\n    {\"keys\": \"img\", \"name\": \"ColorJitter\", \"brightness\": 0.5, \"contrast\": 0.5, \"saturation\": [0.1, 0.8], \"hue\": 0.5},\n    {\"img\": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},\n    torch.tensor(\n        [\n            [[0.1090, 0.6193], [0.6193, 0.9164]],\n            [[0.1090, 0.6193], [0.6193, 0.9164]],\n            [[0.1090, 0.6193], [0.6193, 0.9164]],\n        ]\n    ),\n]\n\nTEST_CASE_3 = [\n    {\"keys\": \"img\", \"name\": \"Pad\", \"padding\": [1, 1, 1, 1]},\n    {\"img\": torch.tensor([[[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]], [[0.0, 1.0], [1.0, 2.0]]])},\n    torch.tensor(\n        [\n            [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]],\n            [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 
1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]],\n            [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 0.0, 0.0]],\n        ]\n    ),\n]\n\n\nclass TestTorchVisiond(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])\n    def test_value(self, input_param, input_data, expected_value):\n        set_determinism(seed=0)\n        result = TorchVisiond(**input_param)(input_data)\n        assert_allclose(result[\"img\"], expected_value, atol=1e-4, rtol=1e-4, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\n\nimport monai.transforms as mt\nfrom monai.data import Dataset\nfrom monai.utils.misc import MONAIEnvVars\n\n\nclass FaultyTransform(mt.Transform):\n\n    def __call__(self, _):\n        raise RuntimeError\n\n\ndef faulty_lambda(_):\n    raise RuntimeError\n\n\nclass TestTransform(unittest.TestCase):\n\n    @classmethod\n    def setUpClass(cls):\n        super().setUpClass()\n        cls.orig_value = str(MONAIEnvVars.debug())\n\n    @classmethod\n    def tearDownClass(cls):\n        if cls.orig_value is not None:\n            os.environ[\"MONAI_DEBUG\"] = cls.orig_value\n        else:\n            os.environ.pop(\"MONAI_DEBUG\")\n        super().tearDownClass()\n\n    def test_raise(self):\n        for transform in (FaultyTransform(), mt.Lambda(faulty_lambda)):\n            ds = Dataset([None] * 10, transform)\n            for debug in (\"False\", \"True\"):\n                os.environ[\"MONAI_DEBUG\"] = debug\n                try:\n                    ds[0]\n                except RuntimeError as re:\n                    if debug == \"False\":\n                        self.assertTrue(\"applying transform\" in str(re))\n                    else:\n                        self.assertFalse(\"applying transform\" in str(re))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_transpose.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import Transpose\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([p(np.arange(5 * 4).reshape(5, 4)), None])\n    TESTS.append([p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), [2, 0, 1]])\n\n\nclass TestTranspose(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_transpose(self, im, indices):\n        tr = Transpose(indices)\n        out1 = tr(im)\n        if isinstance(im, torch.Tensor):\n            im = im.cpu().numpy()\n        out2 = np.transpose(im, indices)\n        assert_allclose(out1, out2, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_transposed.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import Transposed\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    TESTS.append([p(np.arange(5 * 4).reshape(5, 4)), [1, 0]])\n    TESTS.append([p(np.arange(5 * 4).reshape(5, 4)), None])\n    TESTS.append([p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), [2, 0, 1]])\n    TESTS.append([p(np.arange(5 * 4 * 3).reshape(5, 4, 3)), None])\n\n\nclass TestTranspose(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_transpose(self, im, indices):\n        data = {\"i\": deepcopy(im), \"j\": deepcopy(im)}\n        tr = Transposed([\"i\", \"j\"], indices)\n        out_data = tr(data)\n        out_im1, out_im2 = out_data[\"i\"], out_data[\"j\"]\n        if isinstance(im, torch.Tensor):\n            im = im.cpu().numpy()\n        out_gt = np.transpose(im, indices)\n        assert_allclose(out_im1, out_gt, type_test=False)\n        assert_allclose(out_im2, out_gt, type_test=False)\n\n        # test inverse\n        fwd_inv_data = tr.inverse(out_data)\n        for i, j in zip(data.values(), fwd_inv_data.values()):\n            assert_allclose(i, j, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_ultrasound_confidence_map_transform.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom PIL import Image\n\nfrom monai.transforms import UltrasoundConfidenceMapTransform\nfrom monai.utils import optional_import\nfrom tests.test_utils import assert_allclose\n\n_, has_scipy = optional_import(\"scipy\")\n\nTESTS_PATH = Path(__file__).parents[1]\n\nTEST_INPUT = np.array(\n    [\n        [1, 2, 3, 23, 13, 22, 5, 1, 2, 3],\n        [1, 2, 3, 12, 4, 6, 9, 1, 2, 3],\n        [1, 2, 3, 8, 7, 10, 11, 1, 2, 3],\n        [1, 2, 3, 14, 15, 16, 17, 1, 2, 3],\n        [1, 2, 3, 18, 19, 20, 21, 1, 2, 3],\n        [1, 2, 3, 24, 25, 26, 27, 1, 2, 3],\n        [1, 2, 3, 28, 29, 30, 31, 1, 2, 3],\n        [1, 2, 3, 32, 33, 34, 35, 1, 2, 3],\n        [1, 2, 3, 36, 37, 38, 39, 1, 2, 3],\n        [1, 2, 3, 40, 41, 42, 43, 1, 2, 3],\n    ],\n    dtype=np.float32,\n)\n\nTEST_MASK = np.array(\n    [\n        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0],\n        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0],\n        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0],\n        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0],\n        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0],\n        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0],\n        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0],\n        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0],\n        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n    ],\n   
 dtype=np.float32,\n)\n\nSINK_ALL_OUTPUT = np.array(\n    [\n        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n        [\n            0.8884930952884654,\n            0.8626656901726876,\n            0.8301161870669913,\n            0.9757179300830185,\n            0.9989819637626414,\n            0.9994717624885747,\n            0.9954377526794013,\n            0.8898638133944221,\n            0.862604343021387,\n            0.8277862494812598,\n        ],\n        [\n            0.7765718877433174,\n            0.7363731552518268,\n            0.6871875923653379,\n            0.9753673327387775,\n            0.9893175316399789,\n            0.9944181334242039,\n            0.9936979128319371,\n            0.7778001700035326,\n            0.7362622619974832,\n            0.6848377775329241,\n        ],\n        [\n            0.6648416226360719,\n            0.6178079903692397,\n            0.5630152545966568,\n            0.8278402502498404,\n            0.82790391019578,\n            0.8289702087149963,\n            0.8286730258710652,\n            0.6658773633169731,\n            0.6176836507071695,\n            0.5609165245633834,\n        ],\n        [\n            0.5534420483956817,\n            0.5055401989946189,\n            0.451865872383879,\n            0.7541423053657541,\n            0.7544115886347456,\n            0.7536884376055174,\n            0.7524927915364896,\n            0.5542943466824017,\n            0.505422678400297,\n            0.4502051549732117,\n        ],\n        [\n            0.4423657561928356,\n            0.398221575954319,\n            0.35030055029978124,\n            0.4793202144786371,\n            0.48057175662074125,\n            0.4812057229564038,\n            0.48111949176149327,\n            0.44304092606050766,\n            0.39812149713417405,\n            0.34902458531143377,\n        ],\n        [\n            0.3315561576450342,\n            0.29476346732036784,\n            0.2558303772864961,\n  
          0.35090405668257535,\n            0.3515225984307705,\n            0.35176548159366317,\n            0.3516979775419521,\n            0.33205839061494885,\n            0.2946859567272435,\n            0.2549042599220772,\n        ],\n        [\n            0.22094175240967673,\n            0.19431840633358133,\n            0.16672448058324435,\n            0.22716195845848167,\n            0.22761996456848282,\n            0.22782525614780919,\n            0.22781876632199002,\n            0.22127471252104777,\n            0.19426593309729956,\n            0.16612306610996525,\n        ],\n        [\n            0.11044782531624744,\n            0.09623229814933323,\n            0.08174664901235043,\n            0.11081911718888311,\n            0.11102310514207447,\n            0.1111041051969924,\n            0.11108329076967229,\n            0.11061376973431204,\n            0.09620592927336903,\n            0.08145227209865454,\n        ],\n        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n    ],\n    dtype=np.float32,\n)\n\nSINK_MID_OUTPUT = np.array(\n    [\n        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n        [\n            0.9999957448889315,\n            0.9999781044114231,\n            0.9999142422442185,\n            0.999853253199584,\n            0.9999918403054282,\n            0.9999874855193227,\n            0.9999513619364747,\n            0.9999589247003497,\n            0.9999861765528631,\n            0.9999939213967494,\n        ],\n        [\n            0.9999918011366045,\n            0.9999588498417253,\n            0.9998388659316617,\n            0.9998496524281603,\n            0.9999154673258592,\n            0.9997827845182361,\n            0.9998160234579786,\n            0.9999163964511287,\n            0.9999743435786168,\n            0.9999894752861168,\n        ],\n        [\n            0.9999883847481621,\n            0.9999427334014465,\n            0.9997703972600652,\n            
0.9853967608835997,\n            0.9852517829915376,\n            0.9853308520519438,\n            0.9854102394414211,\n            0.9998728503298413,\n            0.9999642585978225,\n            0.999986204909933,\n        ],\n        [\n            0.999985544721449,\n            0.9999296195017368,\n            0.9997066149628903,\n            0.9753803016111353,\n            0.9750688049429371,\n            0.9749211929217173,\n            0.9750052047129354,\n            0.9998284130289159,\n            0.9999558481338295,\n            0.9999837966320273,\n        ],\n        [\n            0.9999832723447848,\n            0.9999192263814408,\n            0.9996472692076177,\n            0.90541293509353,\n            0.9049945536526819,\n            0.9051142437853055,\n            0.9057005861296792,\n            0.9997839348839027,\n            0.9999490318922627,\n            0.9999820419085812,\n        ],\n        [\n            0.9999815409510937,\n            0.9999113168889934,\n            0.9995930143319085,\n            0.8370025145062345,\n            0.8358345435164332,\n            0.8358231468627223,\n            0.8369430449157075,\n            0.9997408260265034,\n            0.9999437526409107,\n            0.9999808010740554,\n        ],\n        [\n            0.9999803198262347,\n            0.9999057164296593,\n            0.9995461103528891,\n            0.7047260555380003,\n            0.7023346743490383,\n            0.7022946969603594,\n            0.7045662738042475,\n            0.9997017258131392,\n            0.9999399744001316,\n            0.9999799785302944,\n        ],\n        [\n            0.9999795785255197,\n            0.9999022923125928,\n            0.999510772973329,\n            0.46283993237260707,\n            0.4577365087549323,\n            0.4571888733219068,\n            0.4614967878524538,\n            0.9996710272733927,\n            0.9999376682163403,\n            0.9999795067125865,\n        ],\n        
[\n            0.9999792877553907,\n            0.9999009179811408,\n            0.9994950057121632,\n            0.05049460567213739,\n            0.030946131978013824,\n            0.0,\n            0.019224121648385283,\n            0.9996568912408903,\n            0.9999367861122628,\n            0.9999793358521326,\n        ],\n    ],\n    dtype=np.float32,\n)\n\nSINK_MIN_OUTPUT = np.array(\n    [\n        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n        [\n            0.9999961997987318,\n            0.9999801752476248,\n            0.9999185667341594,\n            0.9993115972922259,\n            0.9999536433504382,\n            0.9997590064584757,\n            0.9963282396026231,\n            0.9020645423682648,\n            0.965641014946897,\n            0.9847003633599846,\n        ],\n        [\n            0.9999926824858815,\n            0.9999628275604145,\n            0.9998472915971415,\n            0.9992953054409239,\n            0.9995550237000549,\n            0.9972853256638443,\n            0.9958871482234863,\n            0.8006505271617617,\n            0.9360757301263053,\n            0.9734843475613124,\n        ],\n        [\n            0.9999896427490426,\n            0.9999484707116104,\n            0.9997841142091455,\n            0.9321779021295554,\n            0.9308591506422442,\n            0.9299937642438358,\n            0.9286536283468563,\n            0.6964658886602826,\n            0.9106656689679997,\n            0.9652109119709528,\n        ],\n        [\n            0.9999871227708508,\n            0.9999369646510842,\n            0.9997276125796202,\n            0.9006206490361908,\n            0.8987968702587018,\n            0.8965696900664386,\n            0.8941507574801211,\n            0.5892568658180841,\n            0.8892240419729905,\n            0.9590996257620853,\n        ],\n        [\n            0.9999851119906539,\n            0.9999280075234918,\n            0.9996788394671484,\n           
 0.778755271203017,\n            0.7763917808258874,\n            0.7737517385551721,\n            0.7707980517990098,\n            0.4788014936236403,\n            0.8715671104783401,\n            0.954632732759503,\n        ],\n        [\n            0.9999835837292402,\n            0.999921323618806,\n            0.9996389455307461,\n            0.7222961578407286,\n            0.7186158832946955,\n            0.7146983167265393,\n            0.7105768254632475,\n            0.3648911004360315,\n            0.8575943501305144,\n            0.9514642802768379,\n        ],\n        [\n            0.9999825081019064,\n            0.999916683268467,\n            0.9996093996776352,\n            0.6713490686473397,\n            0.6664914636518112,\n            0.6613110504728309,\n            0.6558325489984669,\n            0.247299682539502,\n            0.8473037957967624,\n            0.9493580587294981,\n        ],\n        [\n            0.999981856118739,\n            0.9999138938063622,\n            0.9995907248497593,\n            0.6331535096751639,\n            0.6271637176135582,\n            0.6206687804556549,\n            0.6136262027168252,\n            0.12576864809108962,\n            0.8407892431959736,\n            0.9481472656653798,\n        ],\n        [\n            0.9999816006081851,\n            0.9999127861527936,\n            0.9995832399159849,\n            0.6133274396648696,\n            0.6086364734302403,\n            0.6034602717119345,\n            0.5978473214165134,\n            0.0,\n            0.8382338778894218,\n            0.9477082231321966,\n        ],\n    ],\n    dtype=np.float32,\n)\n\nSINK_MASK_OUTPUT = np.array(\n    [\n        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n        [0.0, 0.0, 0.0, 0.9047934405899283, 0.9936046284605553, 0.9448690902377527, 0.0, 0.0, 0.0, 0.8363773255131761],\n        [0.0, 0.0, 0.0, 0.90375200446097, 0.9434594475474036, 0.4716831449516178, 0.0, 0.0, 0.0, 0.7364197333910302],\n  
      [\n            0.0,\n            0.0,\n            0.0,\n            0.09080438801405301,\n            0.06774182873204163,\n            0.038207095016625024,\n            0.0,\n            0.0,\n            0.0,\n            0.6745641479264269,\n        ],\n        [\n            0.0,\n            0.0,\n            0.0,\n            0.01731082802870267,\n            0.013540929458217351,\n            0.007321202161532623,\n            0.0,\n            0.0,\n            0.0,\n            0.6341231654271253,\n        ],\n        [\n            0.0,\n            0.0,\n            0.0,\n            0.0006444251665178544,\n            0.0005397129128756325,\n            0.0003048384803626333,\n            0.0,\n            0.0,\n            0.0,\n            0.6070178708536365,\n        ],\n        [\n            0.0,\n            0.0,\n            0.0,\n            5.406078586212675e-05,\n            4.416783924970537e-05,\n            2.4597362039020103e-05,\n            0.0,\n            0.0,\n            0.0,\n            0.5889413683184284,\n        ],\n        [\n            0.0,\n            0.0,\n            0.0,\n            4.39259327223233e-06,\n            3.6050656774754658e-06,\n            2.0127120155893425e-06,\n            0.0,\n            0.0,\n            0.0,\n            0.5774279920364456,\n        ],\n        [\n            0.0,\n            0.0,\n            0.0,\n            4.0740501726718113e-07,\n            3.374875487404489e-07,\n            1.9113630985667455e-07,\n            0.0,\n            0.0,\n            0.0,\n            0.5709897726747111,\n        ],\n        [\n            3.2266922388030425e-17,\n            1.801110982679718e-14,\n            9.325899448306927e-12,\n            3.913608442133728e-07,\n            3.9581822403393465e-07,\n            4.02383505118481e-07,\n            4.14820241328287e-07,\n            4.281640797396309e-06,\n            0.0023900192231620593,\n            0.5686882523793125,\n       
 ],\n    ],\n    dtype=np.float32,\n)\n\n\n@unittest.skipUnless(has_scipy, \"Requires scipy\")\nclass TestUltrasoundConfidenceMapTransform(unittest.TestCase):\n    def setUp(self):\n        self.input_img_np = np.expand_dims(TEST_INPUT, axis=0)  # mock image (numpy array)\n        self.input_mask_np = np.expand_dims(TEST_MASK, axis=0)  # mock mask (numpy array)\n\n        self.input_img_torch = torch.from_numpy(TEST_INPUT).unsqueeze(0)  # mock image (torch tensor)\n        self.input_mask_torch = torch.from_numpy(TEST_MASK).unsqueeze(0)  # mock mask (torch tensor)\n\n        self.real_input_img_paths = [\n            os.path.join(TESTS_PATH, \"testing_data\", \"ultrasound_confidence_map\", \"neck_input.png\"),\n            os.path.join(TESTS_PATH, \"testing_data\", \"ultrasound_confidence_map\", \"femur_input.png\"),\n        ]\n\n        self.real_result_npy_paths = [\n            os.path.join(TESTS_PATH, \"testing_data\", \"ultrasound_confidence_map\", \"neck_result.npy\"),\n            os.path.join(TESTS_PATH, \"testing_data\", \"ultrasound_confidence_map\", \"femur_result.npy\"),\n        ]\n\n        self.real_input_paramaters = [\n            {\"alpha\": 2.0, \"beta\": 90, \"gamma\": 0.03},\n            {\"alpha\": 2.0, \"beta\": 90, \"gamma\": 0.06},\n        ]\n\n    def test_parameters(self):\n        # Unknown mode\n        with self.assertRaises(ValueError):\n            UltrasoundConfidenceMapTransform(mode=\"unknown\")\n\n        # Unknown sink_mode\n        with self.assertRaises(ValueError):\n            UltrasoundConfidenceMapTransform(sink_mode=\"unknown\")\n\n    @parameterized.expand(\n        [(\"all\", SINK_ALL_OUTPUT), (\"mid\", SINK_MID_OUTPUT), (\"min\", SINK_MIN_OUTPUT), (\"mask\", SINK_MASK_OUTPUT, True)]\n    )\n    def test_ultrasound_confidence_map_transform(self, sink_mode, expected_output, use_mask=False):\n        # RGB image\n        input_img_rgb = np.expand_dims(np.repeat(self.input_img_np, 3, axis=0), axis=0)\n        
input_img_rgb_torch = torch.from_numpy(input_img_rgb)\n\n        transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode)\n\n        if use_mask:\n            result_torch = transform(input_img_rgb_torch, self.input_mask_torch)\n            result_np = transform(input_img_rgb, self.input_mask_np)\n        else:\n            result_torch = transform(input_img_rgb_torch)\n            result_np = transform(input_img_rgb)\n\n        self.assertIsInstance(result_torch, torch.Tensor)\n        assert_allclose(result_torch, torch.tensor(expected_output), rtol=1e-4, atol=1e-4)\n        self.assertIsInstance(result_np, np.ndarray)\n        assert_allclose(result_np, expected_output, rtol=1e-4, atol=1e-4)\n\n    @parameterized.expand(\n        [\n            (\"all\", SINK_ALL_OUTPUT),\n            (\"mid\", SINK_MID_OUTPUT),\n            (\"min\", SINK_MIN_OUTPUT),\n            (\"mask\", SINK_MASK_OUTPUT, True),  # Adding a flag for mask cases\n        ]\n    )\n    def test_multi_channel_2d(self, sink_mode, expected_output, use_mask=False):\n        input_img_rgb = np.expand_dims(np.repeat(self.input_img_np, 17, axis=0), axis=0)\n        input_img_rgb_torch = torch.from_numpy(input_img_rgb)\n\n        transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode)\n\n        if use_mask:\n            result_torch = transform(input_img_rgb_torch, self.input_mask_torch)\n            result_np = transform(input_img_rgb, self.input_mask_np)\n        else:\n            result_torch = transform(input_img_rgb_torch)\n            result_np = transform(input_img_rgb)\n\n        self.assertIsInstance(result_torch, torch.Tensor)\n        assert_allclose(result_torch, torch.tensor(expected_output), rtol=1e-4, atol=1e-4)\n        self.assertIsInstance(result_np, np.ndarray)\n        assert_allclose(result_np, expected_output, rtol=1e-4, atol=1e-4)\n\n    @parameterized.expand([(\"all\",), (\"mid\",), (\"min\",), (\"mask\",)])\n    def test_non_one_first_dim(self, 
sink_mode):\n        transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode)\n        input_img_rgb = np.repeat(self.input_img_np, 3, axis=0)\n        input_img_rgb_torch = torch.from_numpy(input_img_rgb)\n\n        if sink_mode == \"mask\":\n            with self.assertRaises(ValueError):\n                transform(input_img_rgb_torch, self.input_mask_torch)\n            with self.assertRaises(ValueError):\n                transform(input_img_rgb, self.input_mask_np)\n        else:\n            with self.assertRaises(ValueError):\n                transform(input_img_rgb_torch)\n            with self.assertRaises(ValueError):\n                transform(input_img_rgb)\n\n    @parameterized.expand([(\"all\",), (\"mid\",), (\"min\",), (\"mask\",)])\n    def test_no_first_dim(self, sink_mode):\n        input_img_rgb = self.input_img_np[0]\n        input_img_rgb_torch = torch.from_numpy(input_img_rgb)\n\n        transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode)\n\n        with self.assertRaises(ValueError):\n            transform(input_img_rgb_torch)\n        with self.assertRaises(ValueError):\n            transform(input_img_rgb)\n\n        if sink_mode == \"mask\":\n            with self.assertRaises(ValueError):\n                transform(input_img_rgb_torch, self.input_mask_torch)\n            with self.assertRaises(ValueError):\n                transform(input_img_rgb, self.input_mask_np)\n\n    @parameterized.expand([(\"all\",), (\"mid\",), (\"min\",)])\n    def test_sink_mode(self, mode):\n        transform = UltrasoundConfidenceMapTransform(sink_mode=mode)\n\n        # This should not raise an exception for torch tensor\n        result_torch = transform(self.input_img_torch)\n        self.assertIsInstance(result_torch, torch.Tensor)\n\n        # This should not raise an exception for numpy array\n        result_np = transform(self.input_img_np)\n        self.assertIsInstance(result_np, np.ndarray)\n\n    def test_sink_mask(self):\n     
   transform = UltrasoundConfidenceMapTransform(sink_mode=\"mask\")\n\n        # This should not raise an exception for torch tensor with mask\n        result_torch = transform(self.input_img_torch, self.input_mask_torch)\n        self.assertIsInstance(result_torch, torch.Tensor)\n\n        # This should not raise an exception for numpy array with mask\n        result_np = transform(self.input_img_np, self.input_mask_np)\n        self.assertIsInstance(result_np, np.ndarray)\n\n        # This should raise an exception for torch tensor without mask\n        with self.assertRaises(ValueError):\n            transform(self.input_img_torch)\n\n        # This should raise an exception for numpy array without mask\n        with self.assertRaises(ValueError):\n            transform(self.input_img_np)\n\n    def test_func(self):\n        transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode=\"B\", sink_mode=\"all\")\n        output = transform(self.input_img_np)\n        assert_allclose(output, SINK_ALL_OUTPUT, rtol=1e-4, atol=1e-4)\n\n        transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode=\"B\", sink_mode=\"mid\")\n        output = transform(self.input_img_np)\n        assert_allclose(output, SINK_MID_OUTPUT, rtol=1e-4, atol=1e-4)\n\n        transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode=\"B\", sink_mode=\"min\")\n        output = transform(self.input_img_np)\n        assert_allclose(output, SINK_MIN_OUTPUT, rtol=1e-4, atol=1e-4)\n\n        transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode=\"B\", sink_mode=\"mask\")\n        output = transform(self.input_img_np, self.input_mask_np)\n        assert_allclose(output, SINK_MASK_OUTPUT, rtol=1e-4, atol=1e-4)\n\n        transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode=\"B\", sink_mode=\"all\")\n        output = transform(self.input_img_torch)\n        
assert_allclose(output, torch.tensor(SINK_ALL_OUTPUT), rtol=1e-4, atol=1e-4)\n\n        transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode=\"B\", sink_mode=\"mid\")\n        output = transform(self.input_img_torch)\n        assert_allclose(output, torch.tensor(SINK_MID_OUTPUT), rtol=1e-4, atol=1e-4)\n\n        transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode=\"B\", sink_mode=\"min\")\n        output = transform(self.input_img_torch)\n        assert_allclose(output, torch.tensor(SINK_MIN_OUTPUT), rtol=1e-4, atol=1e-4)\n\n        transform = UltrasoundConfidenceMapTransform(alpha=2.0, beta=90.0, gamma=0.05, mode=\"B\", sink_mode=\"mask\")\n        output = transform(self.input_img_torch, self.input_mask_torch)\n        assert_allclose(output, torch.tensor(SINK_MASK_OUTPUT), rtol=1e-4, atol=1e-4)\n\n    def test_against_official_code(self):\n        # This test is to compare the output of the transform with the official code\n        # The official code is available at:\n        # https://campar.in.tum.de/Main/AthanasiosKaramalisCode\n\n        for input_img_path, result_npy_path, params in zip(\n            self.real_input_img_paths, self.real_result_npy_paths, self.real_input_paramaters\n        ):\n            input_img = np.array(Image.open(input_img_path))\n            input_img = np.expand_dims(input_img, axis=0)\n\n            result_img = np.load(result_npy_path)\n\n            transform = UltrasoundConfidenceMapTransform(sink_mode=\"all\", **params)\n            output = transform(input_img)\n\n            assert_allclose(output, result_img, rtol=1e-4, atol=1e-4)\n\n    def test_against_official_code_using_cg(self):\n        # This test is to compare the output of the transform with the official code\n        # The official code is available at:\n        # https://campar.in.tum.de/Main/AthanasiosKaramalisCode\n\n        for input_img_path, result_npy_path, params in zip(\n            
self.real_input_img_paths, self.real_result_npy_paths, self.real_input_paramaters\n        ):\n            input_img = np.array(Image.open(input_img_path))\n            input_img = np.expand_dims(input_img, axis=0)\n\n            result_img = np.load(result_npy_path)\n\n            transform = UltrasoundConfidenceMapTransform(\n                sink_mode=\"all\", use_cg=True, cg_tol=1.0e-6, cg_maxiter=300, **params\n            )\n            output = transform(input_img)\n\n            assert_allclose(output, result_img, rtol=1e-2, atol=1e-2)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_utils_pytorch_numpy_unification.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.utils_pytorch_numpy_unification import max, min, mode, percentile\nfrom monai.utils import set_determinism\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose, skip_if_quick\n\nTEST_MODE = []\nfor p in TEST_NDARRAYS:\n    TEST_MODE.append([p(np.array([1, 2, 3, 4, 4, 5])), p(4), False])\n    TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False])\n    TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True])\n\nTEST_MIN_MAX = []\nfor p in TEST_NDARRAYS:\n    TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, min, p(1)])\n    TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {\"dim\": 1}, min, p([3.1, 3])])\n    TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, max, p(5)])\n    TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {\"dim\": 1}, max, p([5.1, 5])])\n\n\nclass TestPytorchNumpyUnification(unittest.TestCase):\n    def setUp(self) -> None:\n        set_determinism(0)\n\n    def test_percentile(self):\n        for size in (1, 100):\n            q = np.random.randint(0, 100, size=size)\n            results = []\n            for idx, p in enumerate(TEST_NDARRAYS):\n                dtype = [np.float32, float][idx % 2]\n        
        arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(dtype))\n                results.append(percentile(arr, q))\n                assert_allclose(results[0], results[-1], type_test=False, atol=1e-4, rtol=1e-4)\n\n    @skip_if_quick\n    def test_many_elements_quantile(self):  # pytorch#64947\n        for p in TEST_NDARRAYS:\n            for elements in (1000, 17_000_000):\n                for t in [*TEST_NDARRAYS, list]:\n                    x = p(np.random.randn(elements))\n                    q = percentile(x, t([10, 50]))\n                    if isinstance(x, torch.Tensor):\n                        self.assertIsInstance(q, torch.Tensor)\n                    assert_allclose(q.shape, [2], type_test=False)\n\n    def test_fails(self):\n        for p in TEST_NDARRAYS:\n            for q in (-1, 101):\n                arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32))\n                with self.assertRaises(ValueError):\n                    percentile(arr, q)\n\n    def test_dim(self):\n        q = np.random.randint(0, 100, size=50)\n        results = []\n        for p in TEST_NDARRAYS:\n            arr = p(np.arange(6).reshape(1, 2, 3).astype(np.float32))\n            results.append(percentile(arr, q, dim=1))\n            assert_allclose(results[0], results[-1], type_test=False, atol=1e-4)\n\n    @parameterized.expand(TEST_MODE)\n    def test_mode(self, array, expected, to_long):\n        res = mode(array, to_long=to_long)\n        assert_allclose(res, expected)\n\n    @parameterized.expand(TEST_MIN_MAX)\n    def test_min_max(self, array, input_params, func, expected):\n        res = func(array, **input_params)\n        assert_allclose(res, expected, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_vote_ensemble.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import VoteEnsemble\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    # shape: [2, 1, 1]\n    TESTS.append(\n        [\n            {\"num_classes\": None},\n            [p(torch.tensor([[[1]], [[0]]])), p(torch.tensor([[[1]], [[0]]])), p(torch.tensor([[[0]], [[1]]]))],\n            p(torch.tensor([[[1.0]], [[0.0]]])),\n        ]\n    )\n\n    # shape: [1, 2, 1, 1]\n    TESTS.append(\n        [\n            {\"num_classes\": None},\n            p(\n                torch.stack(\n                    [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])]\n                )\n            ),\n            p(torch.tensor([[[[1.0]], [[0.0]]]])),\n        ]\n    )\n\n    # shape: [1, 2, 1]\n    TESTS.append(\n        [\n            {\"num_classes\": 3},\n            [p(torch.tensor([[[0], [2]]])), p(torch.tensor([[[0], [2]]])), p(torch.tensor([[[1], [1]]]))],\n            p(torch.tensor([[[0], [2]]])),\n        ]\n    )\n\n    # shape: [1, 2, 1]\n    TESTS.append(\n        [\n            {\"num_classes\": 5},\n            [p(torch.tensor([[[0], [2]]])), p(torch.tensor([[[0], [2]]])), p(torch.tensor([[[1], [1]]]))],\n            p(torch.tensor([[[0], 
[2]]])),\n        ]\n    )\n\n    # shape: [1]\n    TESTS.append(\n        [{\"num_classes\": 3}, [p(torch.tensor([2])), p(torch.tensor([2])), p(torch.tensor([1]))], p(torch.tensor([2]))]\n    )\n\n    # shape: 1\n    TESTS.append([{\"num_classes\": 3}, [p(torch.tensor(2)), p(torch.tensor(2)), p(torch.tensor(1))], p(torch.tensor(2))])\n\n\nclass TestVoteEnsemble(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, input_param, img, expected_value):\n        result = VoteEnsemble(**input_param)(img)\n        if isinstance(img, torch.Tensor):\n            self.assertIsInstance(result, torch.Tensor)\n            self.assertEqual(result.device, img.device)\n        assert_allclose(result, expected_value)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_vote_ensembled.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms import VoteEnsembled\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    # shape: [1, 2, 1, 1]\n    TESTS.append(\n        [\n            {\"keys\": [\"pred0\", \"pred1\", \"pred2\"], \"output_key\": \"output\", \"num_classes\": None},\n            {\n                \"pred0\": p(torch.tensor([[[[1]], [[0]]]])),\n                \"pred1\": p(torch.tensor([[[[1]], [[0]]]])),\n                \"pred2\": p(torch.tensor([[[[0]], [[1]]]])),\n            },\n            p(torch.tensor([[[[1.0]], [[0.0]]]])),\n        ]\n    )\n\n    # shape: [1, 2, 1, 1]\n    TESTS.append(\n        [\n            {\"keys\": \"output\", \"output_key\": \"output\", \"num_classes\": None},\n            {\n                \"output\": p(\n                    torch.stack(\n                        [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])]\n                    )\n                )\n            },\n            p(torch.tensor([[[[1.0]], [[0.0]]]])),\n        ]\n    )\n\n    # shape: [1, 2, 1]\n    TESTS.append(\n        [\n            {\"keys\": [\"pred0\", \"pred1\", \"pred2\"], \"output_key\": \"output\", \"num_classes\": 3},\n            {\n                
\"pred0\": p(torch.tensor([[[0], [2]]])),\n                \"pred1\": p(torch.tensor([[[0], [2]]])),\n                \"pred2\": p(torch.tensor([[[1], [1]]])),\n            },\n            p(torch.tensor([[[0], [2]]])),\n        ]\n    )\n\n    # shape: [1, 2, 1]\n    TESTS.append(\n        [\n            {\"keys\": [\"pred0\", \"pred1\", \"pred2\"], \"output_key\": \"output\", \"num_classes\": 5},\n            {\n                \"pred0\": p(torch.tensor([[[0], [2]]])),\n                \"pred1\": p(torch.tensor([[[0], [2]]])),\n                \"pred2\": p(torch.tensor([[[1], [1]]])),\n            },\n            p(torch.tensor([[[0], [2]]])),\n        ]\n    )\n\n    # shape: [1]\n    TESTS.append(\n        [\n            {\"keys\": [\"pred0\", \"pred1\", \"pred2\"], \"output_key\": \"output\", \"num_classes\": 3},\n            {\"pred0\": p(torch.tensor([2])), \"pred1\": p(torch.tensor([2])), \"pred2\": p(torch.tensor([1]))},\n            p(torch.tensor([2])),\n        ]\n    )\n\n\nclass TestVoteEnsembled(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, input_param, img, expected_value):\n        result = VoteEnsembled(**input_param)(img)\n        assert_allclose(result[\"output\"], expected_value)\n\n    def test_cuda_value(self):\n        img = torch.stack(\n            [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])]\n        )\n        expected_value = torch.tensor([[[[1.0]], [[0.0]]]])\n        if torch.cuda.is_available():\n            img = img.to(torch.device(\"cuda:0\"))\n            expected_value = expected_value.to(torch.device(\"cuda:0\"))\n        result = VoteEnsembled(keys=\"output\", num_classes=None)({\"output\": img})\n        assert_allclose(result[\"output\"], expected_value)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_with_allow_missing_keys.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms import Compose, SpatialPad, SpatialPadd, allow_missing_keys_mode\n\n\nclass TestWithAllowMissingKeysMode(unittest.TestCase):\n\n    def setUp(self):\n        self.data = {\"image\": np.arange(16, dtype=float).reshape(1, 4, 4)}\n\n    def test_map_transform(self):\n        for amk in [True, False]:\n            t = SpatialPadd([\"image\", \"label\"], 10, allow_missing_keys=amk)\n            with allow_missing_keys_mode(t):\n                # check state is True\n                self.assertTrue(t.allow_missing_keys)\n                # and that transform works even though key is missing\n                _ = t(self.data)\n            # check it has returned to original state\n            self.assertEqual(t.allow_missing_keys, amk)\n            if not amk:\n                # should fail because amks==False and key is missing\n                with self.assertRaises(KeyError):\n                    _ = t(self.data)\n\n    def test_compose(self):\n        amks = [True, False, True]\n        t = Compose([SpatialPadd([\"image\", \"label\"], 10, allow_missing_keys=amk) for amk in amks])\n        with allow_missing_keys_mode(t):\n            # check states are all True\n            for _t in t.transforms:\n                self.assertTrue(_t.allow_missing_keys)\n            # and that 
transform works even though key is missing\n            _ = t(self.data)\n        # check they've returned to original state\n        for _t, amk in zip(t.transforms, amks):\n            self.assertEqual(_t.allow_missing_keys, amk)\n        # should fail because not all amks==True and key is missing\n        with self.assertRaises((KeyError, RuntimeError)):\n            _ = t(self.data)\n\n    def test_array_transform(self):\n        for t in [SpatialPad(10), Compose([SpatialPad(10)])]:\n            with self.assertRaises(TypeError):\n                with allow_missing_keys_mode(t):\n                    pass\n\n    def test_multiple(self):\n        orig_states = [True, False]\n        ts = [SpatialPadd([\"image\", \"label\"], 10, allow_missing_keys=i) for i in orig_states]\n        with allow_missing_keys_mode(ts):\n            for t in ts:\n                self.assertTrue(t.allow_missing_keys)\n                # and that transform works even though key is missing\n                _ = t(self.data)\n        for t, o_s in zip(ts, orig_states):\n            self.assertEqual(t.allow_missing_keys, o_s)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_zoom.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom scipy.ndimage import zoom as zoom_scipy\n\nfrom monai.data import MetaTensor, set_track_meta\nfrom monai.transforms import Zoom\nfrom monai.transforms.lazy.functional import apply_pending\nfrom tests.test_utils import (\n    DEFAULT_TEST_AFFINE,\n    TEST_NDARRAYS_ALL,\n    NumpyImageTestCase2D,\n    assert_allclose,\n    test_local_inversion,\n)\n\nVALID_CASES = [\n    (1.5, \"nearest\", True),\n    (1.5, \"nearest\", False),\n    (0.8, \"bilinear\"),\n    (0.8, 1),\n    (0.8, \"area\"),\n    (1.5, \"nearest\", False, True),\n    (0.8, \"area\", False, True),\n]\n\nINVALID_CASES = [((None, None), \"bilinear\", TypeError), ((0.9, 0.9), \"s\", ValueError)]\n\n\nclass TestZoom(NumpyImageTestCase2D):\n    @parameterized.expand(VALID_CASES)\n    def test_pending_ops(self, zoom, mode, align_corners=False, keep_size=False):\n        im = MetaTensor(self.imt[0], meta={\"a\": \"b\", \"affine\": DEFAULT_TEST_AFFINE})\n        zoom_fn = Zoom(\n            zoom=zoom, mode=\"bilinear\", keep_size=keep_size, dtype=torch.float64, align_corners=align_corners\n        )\n        # non-lazy\n        expected = zoom_fn(im)\n        self.assertIsInstance(expected, MetaTensor)\n        # lazy\n        zoom_fn.lazy = True\n        pending_result = 
zoom_fn(im)\n        self.assertIsInstance(pending_result, MetaTensor)\n        assert_allclose(pending_result.peek_pending_affine(), expected.affine)\n        assert_allclose(pending_result.peek_pending_shape(), expected.shape[1:])\n        overrides = {\"mode\": \"bilinear\", \"dtype\": np.float64, \"align_corners\": align_corners}\n        result = apply_pending(pending_result, overrides=overrides)[0]\n        # compare\n        match_ratio = np.sum(np.isclose(result, expected)) / np.prod(result.shape)\n        self.assertGreater(match_ratio, 0.95)\n\n    @parameterized.expand(VALID_CASES)\n    def test_correct_results(self, zoom, mode, *_):\n        for p in TEST_NDARRAYS_ALL:\n            zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False)\n            im = p(self.imt[0])\n            zoomed = zoom_fn(im)\n            test_local_inversion(zoom_fn, zoomed, im)\n            _order = 0\n            if mode == 1 or mode.endswith(\"linear\"):\n                _order = 1\n            expected = []\n            for channel in self.imt[0]:\n                expected.append(zoom_scipy(channel, zoom=zoom, mode=\"nearest\", order=_order, prefilter=False))\n            expected = np.stack(expected).astype(np.float32)\n            assert_allclose(zoomed, p(expected), atol=1.0, type_test=False)\n\n    def test_keep_size(self):\n        for p in TEST_NDARRAYS_ALL:\n            zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True)\n            im = p(self.imt[0])\n            zoomed = zoom_fn(im, mode=\"bilinear\")\n            assert_allclose(zoomed.shape, self.imt.shape[1:], type_test=False)\n            test_local_inversion(zoom_fn, zoomed, im)\n\n            zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True)\n            im = p(self.imt[0])\n            zoomed = zoom_fn(im)\n            assert_allclose(zoomed.shape, self.imt.shape[1:], type_test=False)\n            test_local_inversion(zoom_fn, zoomed, p(self.imt[0]))\n\n            set_track_meta(False)\n     
       rotated = zoom_fn(im)\n            self.assertNotIsInstance(rotated, MetaTensor)\n            np.testing.assert_allclose(zoomed.shape, self.imt.shape[1:])\n            set_track_meta(True)\n\n    @parameterized.expand(INVALID_CASES)\n    def test_invalid_inputs(self, zoom, mode, raises):\n        for p in TEST_NDARRAYS_ALL:\n            with self.assertRaises(raises):\n                zoom_fn = Zoom(zoom=zoom, mode=mode)\n                zoom_fn(p(self.imt[0]))\n\n    def test_padding_mode(self):\n        for p in TEST_NDARRAYS_ALL:\n            zoom_fn = Zoom(zoom=0.5, mode=\"nearest\", padding_mode=\"constant\", keep_size=True)\n            test_data = p([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]])\n            zoomed = zoom_fn(test_data)\n            expected = p([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])\n            assert_allclose(zoomed, expected, type_test=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/test_zoomd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom scipy.ndimage import zoom as zoom_scipy\n\nfrom monai.config import USE_COMPILED\nfrom monai.transforms import Zoomd\nfrom tests.lazy_transforms_utils import test_resampler_lazy\nfrom tests.test_utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion\n\nVALID_CASES = [\n    (1.5, \"nearest\", False),\n    (0.3, \"bilinear\", False, True),\n    (0.8, \"bilinear\", False, False),\n    (1.3, \"bilinear\", False),\n]\n\nINVALID_CASES = [(\"no_zoom\", None, \"bilinear\", TypeError), (\"invalid_order\", 0.9, \"s\", ValueError)]\n\n\nclass TestZoomd(NumpyImageTestCase2D):\n    @parameterized.expand(VALID_CASES)\n    def test_correct_results(self, zoom, mode, keep_size, align_corners=None):\n        key = \"img\"\n        init_param = {\n            \"keys\": key,\n            \"zoom\": zoom,\n            \"mode\": mode,\n            \"keep_size\": keep_size,\n            \"dtype\": torch.float64,\n            \"align_corners\": align_corners,\n        }\n        zoom_fn = Zoomd(**init_param)\n        for p in TEST_NDARRAYS_ALL:\n            im = p(self.imt[0])\n            call_param = {\"data\": {key: im}}\n            zoomed = zoom_fn(**call_param)\n\n            # test lazy\n            # TODO: temporarily 
skip \"nearest\" test\n            if mode == \"bilinear\":\n                test_resampler_lazy(\n                    zoom_fn, zoomed, init_param, call_param, output_key=key, atol=1e-4 if USE_COMPILED else 1e-6\n                )\n                zoom_fn.lazy = False\n\n            test_local_inversion(zoom_fn, zoomed, {key: im}, key)\n            _order = 0\n            if mode.endswith(\"linear\"):\n                _order = 1\n            expected = [\n                zoom_scipy(channel, zoom=zoom, mode=\"nearest\", order=_order, prefilter=False) for channel in self.imt[0]\n            ]\n\n            expected = np.stack(expected).astype(np.float32)\n            assert_allclose(zoomed[key], p(expected), atol=1.0, type_test=False)\n\n    def test_keep_size(self):\n        key = \"img\"\n        zoom_fn = Zoomd(key, zoom=0.6, keep_size=True, padding_mode=\"constant\", constant_values=2)\n        for p in TEST_NDARRAYS_ALL:\n            zoomed = zoom_fn({key: p(self.imt[0])})\n            np.testing.assert_array_equal(zoomed[key].shape, self.imt.shape[1:])\n\n            zoom_fn = Zoomd(key, zoom=1.3, keep_size=True)\n            zoomed = zoom_fn({key: self.imt[0]})\n            self.assertTrue(np.array_equal(zoomed[key].shape, self.imt.shape[1:]))\n\n    @parameterized.expand(INVALID_CASES)\n    def test_invalid_inputs(self, _, zoom, mode, raises):\n        key = \"img\"\n        for p in TEST_NDARRAYS_ALL:\n            with self.assertRaises(raises):\n                zoom_fn = Zoomd(key, zoom=zoom, mode=mode)\n                zoom_fn({key: p(self.imt[0])})\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/transform/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/transforms/transform/test_randomizable.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\n\nfrom monai.transforms.transform import Randomizable\n\n\nclass RandTest(Randomizable):\n\n    def randomize(self, data=None):\n        pass\n\n\nclass TestRandomizable(unittest.TestCase):\n\n    def test_default(self):\n        inst = RandTest()\n        r1 = inst.R.rand()\n        self.assertTrue(isinstance(inst.R, np.random.RandomState))\n        inst.set_random_state()\n        r2 = inst.R.rand()\n        self.assertNotAlmostEqual(r1, r2)\n\n    def test_seed(self):\n        inst = RandTest()\n        inst.set_random_state(seed=123)\n        self.assertAlmostEqual(inst.R.rand(), 0.69646918)\n        inst.set_random_state(123)\n        self.assertAlmostEqual(inst.R.rand(), 0.69646918)\n\n    def test_state(self):\n        inst = RandTest()\n        inst_r = np.random.RandomState(123)\n        inst.set_random_state(state=inst_r)\n        self.assertAlmostEqual(inst.R.rand(), 0.69646918)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/transform/test_randomizable_transform_type.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.transforms.transform import RandomizableTrait, RandomizableTransform\n\n\nclass InheritsInterface(RandomizableTrait):\n    pass\n\n\nclass InheritsImplementation(RandomizableTransform):\n\n    def __call__(self, data):\n        return data\n\n\nclass TestRandomizableTransformType(unittest.TestCase):\n\n    def test_is_randomizable_transform_type(self):\n        inst = InheritsInterface()\n        self.assertIsInstance(inst, RandomizableTrait)\n\n    def test_set_random_state_randomizable_transform(self):\n        inst = InheritsImplementation()\n        inst.set_random_state(0)\n"
  },
  {
    "path": "tests/transforms/utility/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/transforms/utility/test_apply_transform_to_points.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor\nfrom monai.transforms.utility.array import ApplyTransformToPoints\nfrom monai.utils import set_determinism\n\nset_determinism(seed=0)\n\nDATA_2D = torch.rand(1, 64, 64)\nDATA_3D = torch.rand(1, 64, 64, 64)\nPOINT_2D_WORLD = torch.tensor([[[2, 2], [2, 4], [4, 6]]])\nPOINT_2D_IMAGE = torch.tensor([[[1, 1], [1, 2], [2, 3]]])\nPOINT_2D_IMAGE_RAS = torch.tensor([[[-1, -1], [-1, -2], [-2, -3]]])\nPOINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]])\nPOINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]])\nPOINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]])\nAFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])\nAFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])\n\nTEST_CASES = [\n    [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE],\n    [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD],\n    [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD],\n    [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, True, 
POINT_2D_IMAGE_RAS],\n    [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],\n    [\n        MetaTensor(DATA_3D, affine=AFFINE_2),\n        MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2),\n        None,\n        False,\n        False,\n        POINT_3D_WORLD,\n    ],\n    [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],\n]\n\nTEST_CASES_WRONG = [\n    [POINT_2D_WORLD, True, None],\n    [POINT_2D_WORLD.unsqueeze(0), False, None],\n    [POINT_3D_WORLD[..., 0:1], False, None],\n    [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])],\n]\n\n\nclass TestCoordinateTransform(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_transform_coordinates(self, image, points, affine, invert_affine, affine_lps_to_ras, expected_output):\n        transform = ApplyTransformToPoints(\n            dtype=torch.int64, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras\n        )\n        affine = image.affine if image is not None else None\n        output = transform(points, affine)\n        self.assertTrue(torch.allclose(output, expected_output))\n        invert_out = transform.inverse(output)\n        self.assertTrue(torch.allclose(invert_out, points))\n\n    @parameterized.expand(TEST_CASES_WRONG)\n    def test_wrong_input(self, input, invert_affine, affine):\n        transform = ApplyTransformToPoints(dtype=torch.int64, invert_affine=invert_affine)\n        with self.assertRaises(ValueError):\n            transform(input, affine)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utility/test_apply_transform_to_pointsd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor\nfrom monai.transforms.utility.dictionary import ApplyTransformToPointsd\nfrom monai.utils import set_determinism\n\nset_determinism(seed=0)\n\nDATA_2D = torch.rand(1, 64, 64)\nDATA_3D = torch.rand(1, 64, 64, 64)\nPOINT_2D_WORLD = torch.tensor([[[2, 2], [2, 4], [4, 6]]])\nPOINT_2D_IMAGE = torch.tensor([[[1, 1], [1, 2], [2, 3]]])\nPOINT_2D_IMAGE_RAS = torch.tensor([[[-1, -1], [-1, -2], [-2, -3]]])\nPOINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]])\nPOINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]])\nPOINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]])\nAFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])\nAFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])\n\nTEST_CASES = [\n    [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE],  # use image affine\n    [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD],  # use point affine\n    [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD],  # use input affine\n    [None, POINT_2D_WORLD, 
AFFINE_1, True, False, POINT_2D_IMAGE],  # use input affine\n    [\n        MetaTensor(DATA_2D, affine=AFFINE_1),\n        POINT_2D_WORLD,\n        None,\n        True,\n        True,\n        POINT_2D_IMAGE_RAS,\n    ],  # test affine_lps_to_ras\n    [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],\n    [\"affine\", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],  # use refer_data itself\n    [\n        MetaTensor(DATA_3D, affine=AFFINE_2),\n        MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2),\n        None,\n        False,\n        False,\n        POINT_3D_WORLD,\n    ],\n    [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],\n    [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],\n]\nTEST_CASES_SEQUENCE = [\n    [\n        (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),\n        [POINT_2D_WORLD, POINT_3D_WORLD],\n        None,\n        True,\n        False,\n        [\"image_1\", \"image_2\"],\n        [POINT_2D_IMAGE, POINT_3D_IMAGE],\n    ],  # use image affine\n    [\n        (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),\n        [POINT_2D_WORLD, POINT_3D_WORLD],\n        None,\n        True,\n        True,\n        [\"image_1\", \"image_2\"],\n        [POINT_2D_IMAGE_RAS, POINT_3D_IMAGE_RAS],\n    ],  # test affine_lps_to_ras\n    [\n        (None, None),\n        [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],\n        None,\n        False,\n        False,\n        None,\n        [POINT_2D_WORLD, POINT_3D_WORLD],\n    ],  # use point affine\n    [\n        (None, None),\n        [POINT_2D_WORLD, POINT_2D_WORLD],\n        AFFINE_1,\n        True,\n        False,\n        None,\n        [POINT_2D_IMAGE, POINT_2D_IMAGE],\n    ],  # use input affine\n    [\n        (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, 
affine=AFFINE_2)),\n        [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],\n        None,\n        False,\n        False,\n        [\"image_1\", \"image_2\"],\n        [POINT_2D_WORLD, POINT_3D_WORLD],\n    ],\n]\n\nTEST_CASES_WRONG = [\n    [POINT_2D_WORLD, True, None, None],\n    [POINT_2D_WORLD.unsqueeze(0), False, None, None],\n    [POINT_3D_WORLD[..., 0:1], False, None, None],\n    [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]]), None],\n    [POINT_3D_WORLD, False, None, \"image\"],\n    [POINT_3D_WORLD, False, None, []],\n]\n\n\nclass TestCoordinateTransform(unittest.TestCase):\n    @parameterized.expand(TEST_CASES)\n    def test_transform_coordinates(self, image, points, affine, invert_affine, affine_lps_to_ras, expected_output):\n        data = {\n            \"image\": image,\n            \"point\": points,\n            \"affine\": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]),\n        }\n        refer_keys = \"image\" if (image is not None and image != \"affine\") else image\n        transform = ApplyTransformToPointsd(\n            keys=\"point\",\n            refer_keys=refer_keys,\n            dtype=torch.int64,\n            affine=affine,\n            invert_affine=invert_affine,\n            affine_lps_to_ras=affine_lps_to_ras,\n        )\n        output = transform(data)\n\n        self.assertTrue(torch.allclose(output[\"point\"], expected_output))\n        invert_out = transform.inverse(output)\n        self.assertTrue(torch.allclose(invert_out[\"point\"], points))\n\n    @parameterized.expand(TEST_CASES_SEQUENCE)\n    def test_transform_coordinates_sequences(\n        self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output\n    ):\n        data = {\"image_1\": image[0], \"image_2\": image[1], \"point_1\": points[0], \"point_2\": points[1]}\n        keys = [\"point_1\", \"point_2\"]\n        
transform = ApplyTransformToPointsd(\n            keys=keys,\n            refer_keys=refer_keys,\n            dtype=torch.int64,\n            affine=affine,\n            invert_affine=invert_affine,\n            affine_lps_to_ras=affine_lps_to_ras,\n        )\n        output = transform(data)\n\n        self.assertTrue(torch.allclose(output[\"point_1\"], expected_output[0]))\n        self.assertTrue(torch.allclose(output[\"point_2\"], expected_output[1]))\n        invert_out = transform.inverse(output)\n        self.assertTrue(torch.allclose(invert_out[\"point_1\"], points[0]))\n\n    @parameterized.expand(TEST_CASES_WRONG)\n    def test_wrong_input(self, input, invert_affine, affine, refer_keys):\n        if refer_keys == []:\n            with self.assertRaises(ValueError):\n                ApplyTransformToPointsd(\n                    keys=\"point\", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys\n                )\n        else:\n            transform = ApplyTransformToPointsd(\n                keys=\"point\", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys\n            )\n            data = {\"point\": input}\n            if refer_keys == \"image\":\n                with self.assertRaises(KeyError):\n                    transform(data)\n            else:\n                with self.assertRaises(ValueError):\n                    transform(data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utility/test_identity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.transforms.utility.array import Identity\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestIdentity(NumpyImageTestCase2D):\n    def test_identity(self):\n        for p in TEST_NDARRAYS:\n            img = p(self.imt)\n            identity = Identity()\n            assert_allclose(img, identity(img))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utility/test_identityd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.transforms.utility.dictionary import Identityd\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestIdentityd(NumpyImageTestCase2D):\n    def test_identityd(self):\n        for p in TEST_NDARRAYS:\n            img = p(self.imt)\n            data = {\"img\": img}\n            identity = Identityd(keys=data.keys())\n            assert_allclose(img, identity(data)[\"img\"])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utility/test_lambda.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom numpy import ndarray\nfrom torch import Tensor\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms.utility.array import Lambda\nfrom monai.utils.type_conversion import convert_to_numpy, convert_to_tensor\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestLambda(NumpyImageTestCase2D):\n    def test_lambda_identity(self):\n        for p in TEST_NDARRAYS:\n            img = p(self.imt)\n\n            def identity_func(x):\n                return x\n\n            lambd = Lambda(func=identity_func)\n            assert_allclose(identity_func(img), lambd(img), type_test=False)\n\n    def test_lambda_slicing(self):\n        for p in TEST_NDARRAYS:\n            img = p(self.imt)\n\n            def slice_func(x):\n                return x[:, :, :6, ::2]\n\n            lambd = Lambda(func=slice_func)\n            out = lambd(img)\n            assert_allclose(slice_func(img), out, type_test=False)\n            self.assertIsInstance(out, MetaTensor)\n            self.assertEqual(len(out.applied_operations), 1)\n            out = lambd.inverse(out)\n            self.assertEqual(len(out.applied_operations), 0)\n\n    def test_lambda_track_meta_false(self):\n        for p in TEST_NDARRAYS:\n            img = p(self.imt)\n\n            def to_numpy(x):\n         
       return convert_to_numpy(x)\n\n            lambd = Lambda(func=to_numpy, track_meta=False)\n            out = lambd(img)\n            self.assertIsInstance(out, ndarray)\n\n            def to_tensor(x):\n                return convert_to_tensor(x)\n\n            lambd = Lambda(func=to_tensor, track_meta=False)\n            out = lambd(img)\n            self.assertIsInstance(out, Tensor)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utility/test_lambdad.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom numpy import ndarray\nfrom torch import Tensor\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms.utility.dictionary import Lambdad\nfrom monai.utils.type_conversion import convert_to_numpy, convert_to_tensor\nfrom tests.test_utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose\n\n\nclass TestLambdad(NumpyImageTestCase2D):\n    def test_lambdad_identity(self):\n        for p in TEST_NDARRAYS:\n            img = p(self.imt)\n            data = {\"img\": img, \"prop\": 1.0, \"label\": 1.0}\n\n            def noise_func(x):\n                return x + 1.0\n\n            expected = {\"img\": noise_func(data[\"img\"]), \"prop\": 1.0, \"new_label\": 2.0}\n            ret = Lambdad(keys=[\"img\", \"prop\", \"label\"], func=noise_func, overwrite=[True, False, \"new_label\"])(data)\n            assert_allclose(expected[\"img\"], ret[\"img\"], type_test=False)\n            assert_allclose(expected[\"prop\"], ret[\"prop\"], type_test=False)\n            assert_allclose(expected[\"new_label\"], ret[\"new_label\"], type_test=False)\n\n    def test_lambdad_slicing(self):\n        for p in TEST_NDARRAYS:\n            img = p(self.imt)\n            data = {\"img\": img}\n\n            def slice_func(x):\n                return x[:, :, :6, ::2]\n\n            lambd = Lambdad(keys=data.keys(), 
func=slice_func)\n            expected = {}\n            expected = slice_func(data[\"img\"])\n            out = lambd(data)\n            out_img = out[\"img\"]\n            assert_allclose(expected, out_img, type_test=False)\n            self.assertIsInstance(out_img, MetaTensor)\n            self.assertEqual(len(out_img.applied_operations), 1)\n            inv_img = lambd.inverse(out)[\"img\"]\n            self.assertIsInstance(inv_img, MetaTensor)\n            self.assertEqual(len(inv_img.applied_operations), 0)\n\n    def test_lambdad_track_meta_false(self):\n        for p in TEST_NDARRAYS:\n            img = p(self.imt)\n            data = {\"img\": img}\n\n            def to_numpy(x):\n                return convert_to_numpy(x)\n\n            lambd = Lambdad(keys=data.keys(), func=to_numpy, track_meta=False)\n            out = lambd(data)\n            out_img = out[\"img\"]\n            self.assertIsInstance(out_img, ndarray)\n\n            def to_tensor(x):\n                return convert_to_tensor(x)\n\n            lambd = Lambdad(keys=data.keys(), func=to_tensor, track_meta=False)\n            out = lambd(data)\n            out_img = out[\"img\"]\n            self.assertIsInstance(out_img, Tensor)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utility/test_rand_lambda.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms.transform import Randomizable\nfrom monai.transforms.utility.array import RandLambda\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n\nclass RandTest(Randomizable):\n    \"\"\"\n    randomisable transform for testing.\n    \"\"\"\n\n    def randomize(self, data=None):\n        self._a = self.R.random()\n\n    def __call__(self, data):\n        self.randomize()\n        return deepcopy(data) + self._a\n\n\nclass TestRandLambda(unittest.TestCase):\n    def check(self, tr: RandLambda, img, img_orig_type, out, expected=None):\n        # input shouldn't change\n        self.assertIsInstance(img, img_orig_type)\n        if isinstance(img, MetaTensor):\n            self.assertEqual(len(img.applied_operations), 0)\n        # output data matches expected\n        assert_allclose(expected, out, type_test=False)\n        # output type is MetaTensor with 1 appended operation\n        self.assertIsInstance(out, MetaTensor)\n        self.assertEqual(len(out.applied_operations), 1)\n\n        # inverse\n        inv = tr.inverse(out)\n        # after inverse, input image remains unchanged\n        self.assertIsInstance(img, img_orig_type)\n        if 
isinstance(img, MetaTensor):\n            self.assertEqual(len(img.applied_operations), 0)\n        # after inverse, output is MetaTensor with 0 applied operations\n        self.assertIsInstance(inv, MetaTensor)\n        self.assertEqual(len(inv.applied_operations), 0)\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    def test_rand_lambdad_identity(self, t):\n        img = t(np.zeros((10, 10)))\n        img_t = type(img)\n\n        test_func = RandTest()\n        test_func.set_random_state(seed=134)\n        expected = test_func(img)\n        test_func.set_random_state(seed=134)\n\n        # default prob\n        tr = RandLambda(func=test_func)\n        ret = tr(img)\n        self.check(tr, img, img_t, ret, expected)\n\n        tr = RandLambda(func=test_func, prob=0.0)\n        ret = tr(img)\n        self.check(tr, img, img_t, ret, expected=img)\n\n        trans = RandLambda(func=test_func, prob=0.5)\n        trans.set_random_state(seed=123)\n        ret = trans(img)\n        self.check(trans, img, img_t, ret, expected=img)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utility/test_rand_lambdad.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms.transform import Randomizable\nfrom monai.transforms.utility.dictionary import RandLambdad\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n\nclass RandTest(Randomizable):\n    \"\"\"\n    randomisable transform for testing.\n    \"\"\"\n\n    def randomize(self, data=None):\n        self._a = self.R.random()\n\n    def __call__(self, data):\n        self.randomize()\n        return data + self._a\n\n\nclass TestRandLambdad(unittest.TestCase):\n    def check(self, tr: RandLambdad, input: dict, out: dict, expected: dict):\n        if isinstance(input[\"img\"], MetaTensor):\n            self.assertEqual(len(input[\"img\"].applied_operations), 0)\n        self.assertIsInstance(out[\"img\"], MetaTensor)\n        self.assertEqual(len(out[\"img\"].applied_operations), 1)\n        assert_allclose(expected[\"img\"], out[\"img\"], type_test=False)\n        assert_allclose(expected[\"prop\"], out[\"prop\"], type_test=False)\n        inv = tr.inverse(out)\n        self.assertIsInstance(inv[\"img\"], MetaTensor)\n        self.assertEqual(len(inv[\"img\"].applied_operations), 0)  # type: ignore\n\n    @parameterized.expand([[p] for p in TEST_NDARRAYS])\n    
def test_rand_lambdad_identity(self, t):\n        img = t(np.zeros((10, 10)))\n        data = {\"img\": img, \"prop\": 1.0}\n\n        test_func = RandTest()\n        test_func.set_random_state(seed=134)\n        expected = {\"img\": test_func(data[\"img\"]), \"prop\": 1.0}\n        test_func.set_random_state(seed=134)\n\n        # default prob\n        tr = RandLambdad(keys=[\"img\", \"prop\"], func=test_func, overwrite=[True, False])\n        ret = tr(deepcopy(data))\n        self.check(tr, data, ret, expected)\n\n        tr = RandLambdad(keys=[\"img\", \"prop\"], func=test_func, prob=0.0)\n        ret = tr(deepcopy(data))\n        self.check(tr, data, ret, expected=data)\n\n        trans = RandLambdad(keys=[\"img\", \"prop\"], func=test_func, prob=0.5)\n        trans.set_random_state(seed=123)\n        ret = trans(deepcopy(data))\n        self.check(trans, data, ret, expected=data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utility/test_simulatedelay.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport time\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms.utility.array import SimulateDelay\nfrom tests.test_utils import NumpyImageTestCase2D\n\n\nclass TestSimulateDelay(NumpyImageTestCase2D):\n    @parameterized.expand([(0.45,), (1,)])\n    def test_value(self, delay_test_time: float):\n        resize = SimulateDelay(delay_time=delay_test_time)\n        start: float = time.time()\n        _ = resize(self.imt[0])\n        stop: float = time.time()\n        measured_approximate: float = stop - start\n        np.testing.assert_allclose(delay_test_time, measured_approximate, rtol=0.5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utility/test_simulatedelayd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport time\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms.utility.dictionary import SimulateDelayd\nfrom tests.test_utils import NumpyImageTestCase2D\n\n\nclass TestSimulateDelay(NumpyImageTestCase2D):\n    @parameterized.expand([(0.45,), (1,)])\n    def test_value(self, delay_test_time: float):\n        resize = SimulateDelayd(keys=\"imgd\", delay_time=delay_test_time)\n        start: float = time.time()\n        _ = resize({\"imgd\": self.imt[0]})\n        stop: float = time.time()\n        measured_approximate: float = stop - start\n        np.testing.assert_allclose(delay_test_time, measured_approximate, rtol=0.5)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utility/test_splitdim.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.transforms.utility.array import SplitDim\nfrom tests.test_utils import TEST_NDARRAYS\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    for keepdim in (True, False):\n        TESTS.append(((2, 10, 8, 7), keepdim, p))\n\n\nclass TestSplitDim(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_correct_shape(self, shape, keepdim, im_type):\n        arr = im_type(np.random.rand(*shape))\n        for dim in range(arr.ndim):\n            out = SplitDim(dim, keepdim)(arr)\n            self.assertIsInstance(out, (list, tuple))\n            self.assertEqual(type(out[0]), type(arr))\n            self.assertEqual(len(out), arr.shape[dim])\n            expected_ndim = arr.ndim if keepdim else arr.ndim - 1\n            self.assertEqual(out[0].ndim, expected_ndim)\n            # assert is a shallow copy\n            arr[0, 0, 0, 0] *= 2\n            self.assertEqual(arr.flatten()[0], out[0].flatten()[0])\n\n    def test_singleton(self):\n        shape = (2, 1, 8, 7)\n        for p in TEST_NDARRAYS:\n            arr = p(np.random.rand(*shape))\n            out = SplitDim(dim=1)(arr)\n            self.assertEqual(out[0].shape, shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utility/test_splitdimd.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.meta_tensor import MetaTensor\nfrom monai.transforms import LoadImaged\nfrom monai.transforms.utility.dictionary import SplitDimd\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose, dict_product, make_nifti_image, make_rand_affine\n\nTESTS = list(dict_product(keepdim=[True, False], p=TEST_NDARRAYS, update_meta=[True, False], list_output=[True, False]))\n\n\nclass TestSplitDimd(unittest.TestCase):\n    data: MetaTensor\n\n    @classmethod\n    def setUpClass(cls) -> None:\n        arr = np.random.rand(2, 10, 8, 7)\n        affine = make_rand_affine()\n        data = {\"i\": make_nifti_image(arr, affine)}\n\n        loader = LoadImaged(\"i\", image_only=True)\n        cls.data = loader(data)\n\n    @parameterized.expand(TESTS)\n    def test_correct(self, keepdim, _, update_meta, list_output):\n        data = deepcopy(self.data)\n        arr = data[\"i\"]\n        for dim in range(arr.ndim):\n            out = SplitDimd(\"i\", dim=dim, keepdim=keepdim, update_meta=update_meta, list_output=list_output)(data)\n            if list_output:\n                self.assertIsInstance(out, list)\n                self.assertEqual(len(out), arr.shape[dim])\n            else:\n                
self.assertIsInstance(out, dict)\n                self.assertEqual(len(out.keys()), len(data.keys()) + arr.shape[dim])\n            # if updating metadata, pick some random points and\n            # check same world coordinates between input and output\n            if update_meta:\n                for _ in range(10):\n                    idx = [np.random.choice(i) for i in arr.shape]\n                    split_im_idx = idx[dim]\n                    split_idx = deepcopy(idx)\n                    split_idx[dim] = 0\n                    if list_output:\n                        split_im = out[split_im_idx][\"i\"]\n                    else:\n                        split_im = out[f\"i_{split_im_idx}\"]\n                    if isinstance(data, MetaTensor) and isinstance(split_im, MetaTensor):\n                        # idx[1:] to remove channel and then add 1 for 4th element\n                        real_world = data.affine @ torch.tensor(idx[1:] + [1]).double()\n                        real_world2 = split_im.affine @ torch.tensor(split_idx[1:] + [1]).double()\n                        assert_allclose(real_world, real_world2)\n\n            if list_output:\n                out = out[0][\"i\"]\n            else:\n                out = out[\"i_0\"]\n            expected_ndim = arr.ndim if keepdim else arr.ndim - 1\n            self.assertEqual(out.ndim, expected_ndim)\n            # assert is a shallow copy\n            arr[0, 0, 0, 0] *= 2\n            self.assertEqual(arr.flatten()[0], out.flatten()[0])\n\n    def test_singleton(self):\n        shape = (2, 1, 8, 7)\n        for p in TEST_NDARRAYS:\n            arr = p(np.random.rand(*shape))\n            out = SplitDimd(\"i\", dim=1)({\"i\": arr})\n            self.assertEqual(out[\"i\"].shape, shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utils/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/transforms/utils/test_correct_crop_centers.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.utils import correct_crop_centers\nfrom tests.test_utils import assert_allclose\n\nTESTS = [[[1, 5, 0], [2, 2, 2], [10, 10, 10]], [[4, 4, 4], [2, 2, 1], [10, 10, 10]]]\n\n\nclass TestCorrectCropCenters(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_torch(self, spatial_size, centers, label_spatial_shape):\n        result1 = correct_crop_centers(centers, spatial_size, label_spatial_shape)\n        centers = [torch.tensor(i) for i in centers]\n        result2 = correct_crop_centers(centers, spatial_size, label_spatial_shape)\n        assert_allclose(result1, result2)\n        self.assertEqual(type(result1[0]), type(result2[0]))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utils/test_get_unique_labels.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nimport torch.nn.functional as F\nfrom parameterized import parameterized\n\nfrom monai.transforms.utils import get_unique_labels\nfrom monai.transforms.utils_pytorch_numpy_unification import moveaxis\nfrom tests.test_utils import TEST_NDARRAYS\n\ngrid_raw = [[0, 0, 0], [0, 0, 1], [2, 2, 3], [5, 5, 6], [3, 6, 2], [5, 6, 6]]\ngrid = torch.Tensor(grid_raw).unsqueeze(0).to(torch.int64)\ngrid_onehot = moveaxis(F.one_hot(grid)[0], -1, 0)\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    for o_h in (False, True):\n        im = grid_onehot if o_h else grid\n        TESTS.append([dict(img=p(im), is_onehot=o_h), {0, 1, 2, 3, 5, 6}])\n        TESTS.append([dict(img=p(im), is_onehot=o_h, discard=0), {1, 2, 3, 5, 6}])\n        TESTS.append([dict(img=p(im), is_onehot=o_h, discard=[1, 2]), {0, 3, 5, 6}])\n\n\nclass TestGetUniqueLabels(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_correct_results(self, args, expected):\n        result = get_unique_labels(**args)\n        self.assertEqual(result, expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/transforms/utils/test_print_transform_backends.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.transforms.utils import get_transform_backends, print_transform_backends\n\n\nclass TestPrintTransformBackends(unittest.TestCase):\n\n    def test_get_number_of_conversions(self):\n        tr_t_or_np, *_ = get_transform_backends()\n        self.assertGreater(len(tr_t_or_np), 0)\n        print_transform_backends()\n\n\nif __name__ == \"__main__\":\n    a = TestPrintTransformBackends()\n    a.test_get_number_of_conversions()\n"
  },
  {
    "path": "tests/transforms/utils/test_soft_clip.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.transforms.utils import soft_clip\n\nTEST_CASES = [\n    [\n        {\"minv\": 2, \"maxv\": 8, \"sharpness_factor\": 10},\n        {\n            \"input\": torch.arange(10).float(),\n            \"clipped\": torch.tensor([2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 7.9307, 8.0000]),\n        },\n    ],\n    [\n        {\"minv\": 2, \"maxv\": None, \"sharpness_factor\": 10},\n        {\n            \"input\": torch.arange(10).float(),\n            \"clipped\": torch.tensor([2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000]),\n        },\n    ],\n    [\n        {\"minv\": None, \"maxv\": 7, \"sharpness_factor\": 10},\n        {\n            \"input\": torch.arange(10).float(),\n            \"clipped\": torch.tensor([0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 6.9307, 7.0000, 7.0000]),\n        },\n    ],\n    [\n        {\"minv\": 2, \"maxv\": 8, \"sharpness_factor\": 1.0},\n        {\n            \"input\": torch.arange(10).float(),\n            \"clipped\": torch.tensor([2.1266, 2.3124, 2.6907, 3.3065, 4.1088, 5.0000, 5.8912, 6.6935, 7.3093, 7.6877]),\n        },\n    ],\n    [\n        {\"minv\": 2, \"maxv\": 8, \"sharpness_factor\": 3.0},\n        {\n            
\"input\": torch.arange(10).float(),\n            \"clipped\": torch.tensor([2.0008, 2.0162, 2.2310, 3.0162, 4.0008, 5.0000, 5.9992, 6.9838, 7.7690, 7.9838]),\n        },\n    ],\n    [\n        {\"minv\": 2, \"maxv\": 8, \"sharpness_factor\": 5.0},\n        {\n            \"input\": torch.arange(10).float(),\n            \"clipped\": torch.tensor([2.0000, 2.0013, 2.1386, 3.0013, 4.0000, 5.0000, 6.0000, 6.9987, 7.8614, 7.9987]),\n        },\n    ],\n    [\n        {\"minv\": 2, \"maxv\": 8, \"sharpness_factor\": 10},\n        {\n            \"input\": np.arange(10).astype(np.float32),\n            \"clipped\": np.array([2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 7.9307, 8.0000]),\n        },\n    ],\n    [\n        {\"minv\": 2, \"maxv\": None, \"sharpness_factor\": 10},\n        {\n            \"input\": np.arange(10).astype(float),\n            \"clipped\": np.array([2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000]),\n        },\n    ],\n    [\n        {\"minv\": None, \"maxv\": 7, \"sharpness_factor\": 10},\n        {\n            \"input\": np.arange(10).astype(float),\n            \"clipped\": np.array([0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 6.9307, 7.0000, 7.0000]),\n        },\n    ],\n    [\n        {\"minv\": 2, \"maxv\": 8, \"sharpness_factor\": 1.0},\n        {\n            \"input\": np.arange(10).astype(float),\n            \"clipped\": np.array([2.1266, 2.3124, 2.6907, 3.3065, 4.1088, 5.0000, 5.8912, 6.6935, 7.3093, 7.6877]),\n        },\n    ],\n    [\n        {\"minv\": 2, \"maxv\": 8, \"sharpness_factor\": 3.0},\n        {\n            \"input\": np.arange(10).astype(float),\n            \"clipped\": np.array([2.0008, 2.0162, 2.2310, 3.0162, 4.0008, 5.0000, 5.9992, 6.9838, 7.7690, 7.9838]),\n        },\n    ],\n    [\n        {\"minv\": 2, \"maxv\": 8, \"sharpness_factor\": 5.0},\n        {\n            \"input\": np.arange(10).astype(float),\n            \"clipped\": 
np.array([2.0000, 2.0013, 2.1386, 3.0013, 4.0000, 5.0000, 6.0000, 6.9987, 7.8614, 7.9987]),\n        },\n    ],\n]\n\n\nclass TestSoftClip(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_result(self, input_param, input_data):\n        outputs = soft_clip(input_data[\"input\"], **input_param)\n        expected_val = input_data[\"clipped\"]\n        if isinstance(outputs, torch.Tensor):\n            np.testing.assert_allclose(\n                outputs.detach().cpu().numpy(), expected_val.detach().cpu().numpy(), atol=1e-4, rtol=1e-4\n            )\n        else:\n            np.testing.assert_allclose(outputs, expected_val, atol=1e-4, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/utils/enums/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/utils/enums/test_hovernet_loss.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport random\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\nfrom torch.nn import functional as F\n\nfrom monai.apps.pathology.losses import HoVerNetLoss\nfrom monai.transforms import GaussianSmooth, Rotate\nfrom monai.transforms.intensity.array import ComputeHoVerMaps\nfrom monai.utils.enums import HoVerNetBranch\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\ns = 10e-8\nt = 1.0 - s\nH = 40\nW = 40\nN = 5\nB = 2\n\n\nclass PrepareTestInputs:\n\n    def __init__(self, inputs):\n        self.inputs = {HoVerNetBranch.NP: inputs[1], HoVerNetBranch.HV: inputs[3]}\n        self.targets = {HoVerNetBranch.NP: inputs[0], HoVerNetBranch.HV: inputs[2]}\n\n        if len(inputs) > 4:\n            self.targets[HoVerNetBranch.NC] = inputs[4]\n            self.inputs[HoVerNetBranch.NC] = inputs[5]\n\n\ndef test_shape_generator(num_classes=1, num_objects=3, batch_size=1, height=5, width=5, rotation=0.0, smoothing=False):\n    t_g = torch.zeros((batch_size, height, width), dtype=torch.int64)\n    t_p = None\n    hv_g = torch.zeros((batch_size, 2, height, width))\n    hv_p = torch.zeros((batch_size, 2, height, width))\n\n    rad_min = 2\n    rad_max = min(max(height // 3, width // 3, rad_min), 5)\n\n    for b in range(batch_size):\n        random.seed(10 + b)\n        inst_map = 
torch.zeros((height, width), dtype=torch.int64)\n        for inst_id in range(1, num_objects + 1):\n            x = random.randint(rad_max, width - rad_max)\n            y = random.randint(rad_max, height - rad_max)\n            rad = random.randint(rad_min, rad_max)\n            spy, spx = np.ogrid[-x : height - x, -y : width - y]\n            circle = torch.tensor((spx * spx + spy * spy) <= rad * rad)\n\n            if num_classes > 1:\n                t_g[b, circle] = np.ceil(random.random() * num_classes)\n            else:\n                t_g[b, circle] = 1\n\n            inst_map[circle] = inst_id\n\n        hv_g[b] = ComputeHoVerMaps()(inst_map[None])\n        hv_g[b] = hv_g[b].squeeze(0)\n        if rotation > 0.0:\n            hv_p[b] = Rotate(angle=rotation, keep_size=True, mode=\"bilinear\")(hv_g[b])\n\n    n_g = t_g > 0\n    if rotation == 0.0:\n        hv_p = hv_g * 0.99\n\n    # rotation of prediction needs to happen before one-hot encoding\n    if rotation > 0.0:\n        n_p = Rotate(angle=rotation, keep_size=True, mode=\"nearest\")(n_g)\n        n_p = F.one_hot(n_p.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)\n        if num_classes > 1:\n            t_p = Rotate(angle=rotation, keep_size=True, mode=\"nearest\")(t_g)\n            t_p = F.one_hot(t_p.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)\n            t_g = F.one_hot(t_g.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)\n        else:\n            t_g = None\n    else:\n        n_p = F.one_hot(n_g.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)\n        if num_classes > 1:\n            t_p = F.one_hot(t_g.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)\n            t_g = F.one_hot(t_g.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)\n        else:\n            t_g = None\n\n    n_g = F.one_hot(n_g.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)\n\n    if smoothing:\n        n_p = GaussianSmooth()(n_p)\n        if num_classes > 1:\n           
 t_p = GaussianSmooth()(t_p)\n        hv_p = hv_p * 0.1\n    else:\n        n_p = torch.clamp(n_p, s, t)\n        if num_classes > 1:\n            t_p = torch.clamp(t_p, s, t)\n\n    # Apply log to emulate logits\n    if t_p is not None:\n        return n_g, n_p.log(), hv_g, hv_p, t_g, t_p.log()\n    else:\n        return n_g, n_p.log(), hv_g, hv_p\n\n\ninputs_test = [\n    PrepareTestInputs(test_shape_generator(height=H, width=W)),\n    PrepareTestInputs(test_shape_generator(num_classes=N, height=H, width=W)),\n    PrepareTestInputs(test_shape_generator(num_classes=N, batch_size=B, height=H, width=W)),\n    PrepareTestInputs(test_shape_generator(num_classes=N, batch_size=B, height=H, width=W, rotation=0.15)),\n    PrepareTestInputs(test_shape_generator(num_classes=N, batch_size=B, height=H, width=W, rotation=0.2)),\n    PrepareTestInputs(test_shape_generator(num_classes=N, batch_size=B, height=H, width=W, rotation=0.25)),\n]\n\nTEST_CASE_0 = [  # batch size of 1, no type prediction\n    {\"prediction\": inputs_test[0].inputs, \"target\": inputs_test[0].targets},\n    0.003,\n]\n\nTEST_CASE_1 = [  # batch size of 1, 2 classes with type prediction\n    {\"prediction\": inputs_test[1].inputs, \"target\": inputs_test[1].targets},\n    0.2762,\n]\n\nTEST_CASE_2 = [  # batch size of 2, 2 classes with type prediction\n    {\"prediction\": inputs_test[2].inputs, \"target\": inputs_test[2].targets},\n    0.4852,\n]\n\nTEST_CASE_3 = [  # batch size of 2, 3 classes with minor rotation of nuclear prediction\n    {\"prediction\": inputs_test[3].inputs, \"target\": inputs_test[3].targets},\n    3.6348,\n]\n\nTEST_CASE_4 = [  # batch size of 2, 3 classes with medium rotation of nuclear prediction\n    {\"prediction\": inputs_test[4].inputs, \"target\": inputs_test[4].targets},\n    4.5312,\n]\n\nTEST_CASE_5 = [  # batch size of 2, 3 classes with medium rotation of nuclear prediction\n    {\"prediction\": inputs_test[5].inputs, \"target\": inputs_test[5].targets},\n    
5.4929,\n]\n\nCASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]\n\nILL_CASES = [\n    [\n        {\n            \"prediction\": {\"np\": inputs_test[0].inputs[HoVerNetBranch.NP]},\n            \"target\": {\n                \"np\": inputs_test[0].targets[HoVerNetBranch.NP],\n                HoVerNetBranch.HV: inputs_test[0].targets[HoVerNetBranch.HV],\n            },\n        }\n    ]\n]\n\n\nclass TestHoverNetLoss(unittest.TestCase):\n\n    @parameterized.expand(CASES)\n    def test_shape(self, input_param, expected_loss):\n        loss = HoVerNetLoss()\n        result = loss(**input_param).to(device)\n        self.assertAlmostEqual(float(result), expected_loss, places=2)\n\n    @parameterized.expand(ILL_CASES)\n    def test_ill_input_hyper_params(self, input_param):\n        with self.assertRaises(ValueError):\n            loss = HoVerNetLoss()\n            _ = loss(**input_param).to(device)\n\n\nif __name__ == \"__main__\":\n    unittest.main(argv=[\"first-arg-is-ignored\"], exit=False)\n"
  },
  {
    "path": "tests/utils/enums/test_ordering.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nfrom parameterized import parameterized\n\nfrom monai.utils.enums import OrderingTransformations, OrderingType\nfrom monai.utils.ordering import Ordering\n\nTEST_2D_NON_RANDOM = [\n    [\n        {\n            \"ordering_type\": OrderingType.RASTER_SCAN,\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            \"reflected_spatial_dims\": (),\n            \"transpositions_axes\": (),\n            \"rot90_axes\": (),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        },\n        [0, 1, 2, 3],\n    ],\n    [\n        {\n            \"ordering_type\": OrderingType.S_CURVE,\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            \"reflected_spatial_dims\": (),\n            \"transpositions_axes\": (),\n            \"rot90_axes\": (),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        },\n        [0, 1, 3, 2],\n    ],\n    [\n        {\n            \"ordering_type\": 
OrderingType.RASTER_SCAN,\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            \"reflected_spatial_dims\": (True, False),\n            \"transpositions_axes\": (),\n            \"rot90_axes\": (),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        },\n        [2, 3, 0, 1],\n    ],\n    [\n        {\n            \"ordering_type\": OrderingType.S_CURVE,\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            \"reflected_spatial_dims\": (True, False),\n            \"transpositions_axes\": (),\n            \"rot90_axes\": (),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        },\n        [2, 3, 1, 0],\n    ],\n    [\n        {\n            \"ordering_type\": OrderingType.RASTER_SCAN,\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            \"reflected_spatial_dims\": (),\n            \"transpositions_axes\": ((1, 0),),\n            \"rot90_axes\": (),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        },\n        [0, 2, 1, 3],\n    ],\n    [\n        {\n            \"ordering_type\": OrderingType.S_CURVE,\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            \"reflected_spatial_dims\": (),\n            \"transpositions_axes\": ((1, 0),),\n            \"rot90_axes\": (),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                
OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        },\n        [0, 2, 3, 1],\n    ],\n    [\n        {\n            \"ordering_type\": OrderingType.RASTER_SCAN,\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            \"reflected_spatial_dims\": (),\n            \"transpositions_axes\": (),\n            \"rot90_axes\": ((0, 1),),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        },\n        [1, 3, 0, 2],\n    ],\n    [\n        {\n            \"ordering_type\": OrderingType.S_CURVE,\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            \"reflected_spatial_dims\": (),\n            \"transpositions_axes\": (),\n            \"rot90_axes\": ((0, 1),),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        },\n        [1, 3, 2, 0],\n    ],\n    [\n        {\n            \"ordering_type\": OrderingType.RASTER_SCAN,\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            \"reflected_spatial_dims\": (True, False),\n            \"transpositions_axes\": ((1, 0),),\n            \"rot90_axes\": ((0, 1),),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        },\n        [0, 1, 2, 3],\n    ],\n    [\n        {\n            \"ordering_type\": OrderingType.S_CURVE,\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            
\"reflected_spatial_dims\": (True, False),\n            \"transpositions_axes\": ((1, 0),),\n            \"rot90_axes\": ((0, 1),),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        },\n        [0, 1, 3, 2],\n    ],\n]\n\n\nTEST_3D = [\n    [\n        {\n            \"ordering_type\": OrderingType.RASTER_SCAN,\n            \"spatial_dims\": 3,\n            \"dimensions\": (1, 2, 2, 2),\n            \"reflected_spatial_dims\": (),\n            \"transpositions_axes\": (),\n            \"rot90_axes\": (),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        },\n        [0, 1, 2, 3, 4, 5, 6, 7],\n    ]\n]\n\nTEST_ORDERING_TYPE_FAILURE = [\n    [\n        {\n            \"ordering_type\": \"hilbert\",\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            \"reflected_spatial_dims\": (True, False),\n            \"transpositions_axes\": ((1, 0),),\n            \"rot90_axes\": ((0, 1),),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        }\n    ]\n]\n\nTEST_ORDERING_TRANSFORMATION_FAILURE = [\n    [\n        {\n            \"ordering_type\": OrderingType.S_CURVE,\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            \"reflected_spatial_dims\": (True, False),\n            \"transpositions_axes\": ((1, 0),),\n            \"rot90_axes\": ((0, 1),),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                
OrderingTransformations.ROTATE_90.value,\n                \"flip\",\n            ),\n        }\n    ]\n]\n\nTEST_REVERT = [\n    [\n        {\n            \"ordering_type\": OrderingType.S_CURVE,\n            \"spatial_dims\": 2,\n            \"dimensions\": (1, 2, 2),\n            \"reflected_spatial_dims\": (True, False),\n            \"transpositions_axes\": (),\n            \"rot90_axes\": (),\n            \"transformation_order\": (\n                OrderingTransformations.TRANSPOSE.value,\n                OrderingTransformations.ROTATE_90.value,\n                OrderingTransformations.REFLECT.value,\n            ),\n        }\n    ]\n]\n\n\nclass TestOrdering(unittest.TestCase):\n    @parameterized.expand(TEST_2D_NON_RANDOM + TEST_3D)\n    def test_ordering(self, input_param, expected_sequence_ordering):\n        ordering = Ordering(**input_param)\n        self.assertTrue(np.array_equal(ordering.get_sequence_ordering(), expected_sequence_ordering, equal_nan=True))\n\n    @parameterized.expand(TEST_ORDERING_TYPE_FAILURE)\n    def test_ordering_type_failure(self, input_param):\n        with self.assertRaises(ValueError):\n            Ordering(**input_param)\n\n    @parameterized.expand(TEST_ORDERING_TRANSFORMATION_FAILURE)\n    def test_ordering_transformation_failure(self, input_param):\n        with self.assertRaises(ValueError):\n            Ordering(**input_param)\n\n    @parameterized.expand(TEST_REVERT)\n    def test_revert(self, input_param):\n        sequence = np.random.randint(0, 100, size=input_param[\"dimensions\"]).flatten()\n\n        ordering = Ordering(**input_param)\n\n        reverted_sequence = sequence[ordering.get_sequence_ordering()]\n        reverted_sequence = reverted_sequence[ordering.get_revert_sequence_ordering()]\n\n        self.assertTrue(np.array_equal(sequence, reverted_sequence, equal_nan=True))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/enums/test_wsireader.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\nfrom pathlib import Path\nfrom typing import Any\nfrom unittest import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.config import PathLike\nfrom monai.data import DataLoader, Dataset\nfrom monai.data.wsi_reader import WSIReader\nfrom monai.transforms import Compose, LoadImaged, ToTensord\nfrom monai.utils import first, optional_import\nfrom monai.utils.enums import PostFix, WSIPatchKeys\nfrom tests.test_utils import assert_allclose, download_url_or_skip_test, skip_if_no_cuda, testing_data_config\n\ncucim, has_cucim = optional_import(\"cucim\")\nhas_cucim = has_cucim and hasattr(cucim, \"CuImage\")\nopenslide, has_osl = optional_import(\"openslide\")\nimwrite, has_tiff = optional_import(\"tifffile\", name=\"imwrite\")\n_, has_codec = optional_import(\"imagecodecs\")\nhas_tiff = has_tiff and has_codec\n\nTESTS_PATH = Path(__file__).parents[2]\nWSI_GENERIC_TIFF_KEY = \"wsi_generic_tiff\"\nWSI_GENERIC_TIFF_PATH = os.path.join(TESTS_PATH, \"testing_data\", f\"temp_{WSI_GENERIC_TIFF_KEY}.tiff\")\n\nWSI_APERIO_SVS_KEY = \"wsi_aperio_svs\"\nWSI_APERIO_SVS_PATH = os.path.join(TESTS_PATH, \"testing_data\", f\"temp_{WSI_APERIO_SVS_KEY}.svs\")\n\nWSI_GENERIC_TIFF_HEIGHT = 32914\nWSI_GENERIC_TIFF_WIDTH = 46000\n\nTEST_CASE_WHOLE_0 = [WSI_GENERIC_TIFF_PATH, 2, (3, 
WSI_GENERIC_TIFF_HEIGHT // 4, WSI_GENERIC_TIFF_WIDTH // 4)]\n\nTEST_CASE_TRANSFORM_0 = [\n    WSI_GENERIC_TIFF_PATH,\n    4,\n    (WSI_GENERIC_TIFF_HEIGHT // 16, WSI_GENERIC_TIFF_WIDTH // 16),\n    (1, 3, WSI_GENERIC_TIFF_HEIGHT // 16, WSI_GENERIC_TIFF_WIDTH // 16),\n]\n\n# ----------------------------------------------------------------------------\n# Test cases for reading patches\n# ----------------------------------------------------------------------------\n\nTEST_CASE_0 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8, \"dtype\": None},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.float64),\n]\n\nTEST_CASE_1 = [\n    WSI_GENERIC_TIFF_PATH,\n    {},\n    {\"location\": (WSI_GENERIC_TIFF_HEIGHT // 2, WSI_GENERIC_TIFF_WIDTH // 2), \"size\": (2, 1), \"level\": 0},\n    np.array([[[246], [246]], [[246], [246]], [[246], [246]]], dtype=np.uint8),\n]\n\nTEST_CASE_2 = [\n    WSI_GENERIC_TIFF_PATH,\n    {},\n    {\"location\": (0, 0), \"size\": (2, 1), \"level\": 8},\n    np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.uint8),\n]\n\nTEST_CASE_3 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"channel_dim\": -1},\n    {\"location\": (WSI_GENERIC_TIFF_HEIGHT // 2, WSI_GENERIC_TIFF_WIDTH // 2), \"size\": (4, 1), \"level\": 0},\n    np.moveaxis(\n        np.array(\n            [[[246], [246], [246], [246]], [[246], [246], [246], [246]], [[246], [246], [246], [246]]], dtype=np.uint8\n        ),\n        0,\n        -1,\n    ),\n]\n\nTEST_CASE_4 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"channel_dim\": 2},\n    {\"location\": (0, 0), \"size\": (4, 1), \"level\": 8},\n    np.moveaxis(\n        np.array(\n            [[[242], [242], [242], [242]], [[242], [242], [242], [242]], [[242], [242], [242], [242]]], dtype=np.uint8\n        ),\n        0,\n        -1,\n    ),\n]\n\nTEST_CASE_5 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    
np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.uint8),\n]\n\nTEST_CASE_6 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8, \"dtype\": np.int32},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.int32),\n]\n\nTEST_CASE_7 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8, \"dtype\": np.float32},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.float32),\n]\n\nTEST_CASE_8 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8, \"dtype\": torch.uint8},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    torch.tensor([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=torch.uint8),\n]\n\nTEST_CASE_9 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8, \"dtype\": torch.float32},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    torch.tensor([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=torch.float32),\n]\n\n# exact mpp in get_data\nTEST_CASE_10_MPP = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"mpp_atol\": 0.0, \"mpp_rtol\": 0.0},\n    {\"location\": (WSI_GENERIC_TIFF_HEIGHT // 2, WSI_GENERIC_TIFF_WIDTH // 2), \"size\": (2, 1), \"mpp\": 1000},\n    np.array([[[246], [246]], [[246], [246]], [[246], [246]]], dtype=np.uint8),\n    {\"level\": 0},\n]\n\n# exact mpp as default\nTEST_CASE_11_MPP = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"mpp_atol\": 0.0, \"mpp_rtol\": 0.0, \"mpp\": 1000},\n    {\"location\": (WSI_GENERIC_TIFF_HEIGHT // 2, WSI_GENERIC_TIFF_WIDTH // 2), \"size\": (2, 1)},\n    np.array([[[246], [246]], [[246], [246]], [[246], [246]]], dtype=np.uint8),\n    {\"level\": 0},\n]\n\n# exact mpp as default (Aperio SVS)\nTEST_CASE_12_MPP = [\n    WSI_APERIO_SVS_PATH,\n    {\"mpp_atol\": 0.0, \"mpp_rtol\": 0.0, \"mpp\": 0.499},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    np.array([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=np.uint8),\n    {\"level\": 0},\n]\n# acceptable mpp within 
default tolerances\nTEST_CASE_13_MPP = [\n    WSI_GENERIC_TIFF_PATH,\n    {},\n    {\"location\": (0, 0), \"size\": (2, 1), \"mpp\": 256000},\n    np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.uint8),\n    {\"level\": 8},\n]\n\n# acceptable mpp within default tolerances (Aperio SVS)\nTEST_CASE_14_MPP = [\n    WSI_APERIO_SVS_PATH,\n    {\"mpp\": 8.0},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    np.array([[[238], [240]], [[239], [241]], [[240], [241]]], dtype=np.uint8),\n    {\"level\": 2},\n]\n\n# acceptable mpp within absolute tolerance (Aperio SVS)\nTEST_CASE_15_MPP = [\n    WSI_APERIO_SVS_PATH,\n    {\"mpp\": 7.0, \"mpp_atol\": 1.0, \"mpp_rtol\": 0.0},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    np.array([[[238], [240]], [[239], [241]], [[240], [241]]], dtype=np.uint8),\n    {\"level\": 2},\n]\n\n# acceptable mpp within relative tolerance (Aperio SVS)\nTEST_CASE_16_MPP = [\n    WSI_APERIO_SVS_PATH,\n    {\"mpp\": 7.8, \"mpp_atol\": 0.0, \"mpp_rtol\": 0.1},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    np.array([[[238], [240]], [[239], [241]], [[240], [241]]], dtype=np.uint8),\n    {\"level\": 2},\n]\n\n# exact power\nTEST_CASE_17_POWER = [\n    WSI_APERIO_SVS_PATH,\n    {\"power_atol\": 0.0, \"power_rtol\": 0.0},\n    {\"location\": (0, 0), \"size\": (2, 1), \"power\": 20},\n    np.array([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=np.uint8),\n    {\"level\": 0},\n]\n\n# exact power\nTEST_CASE_18_POWER = [\n    WSI_APERIO_SVS_PATH,\n    {\"power\": 20, \"power_atol\": 0.0, \"power_rtol\": 0.0},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    np.array([[[239], [239]], [[239], [239]], [[239], [239]]], dtype=np.uint8),\n    {\"level\": 0},\n]\n\n# acceptable power within default tolerances (Aperio SVS)\nTEST_CASE_19_POWER = [\n    WSI_APERIO_SVS_PATH,\n    {},\n    {\"location\": (0, 0), \"size\": (2, 1), \"power\": 1.25},\n    np.array([[[238], [240]], [[239], [241]], [[240], [241]]], dtype=np.uint8),\n    
{\"level\": 2},\n]\n\n# acceptable power within absolute tolerance (Aperio SVS)\nTEST_CASE_20_POWER = [\n    WSI_APERIO_SVS_PATH,\n    {\"power_atol\": 0.3, \"power_rtol\": 0.0},\n    {\"location\": (0, 0), \"size\": (2, 1), \"power\": 1.0},\n    np.array([[[238], [240]], [[239], [241]], [[240], [241]]], dtype=np.uint8),\n    {\"level\": 2},\n]\n\n# acceptable power within relative tolerance (Aperio SVS)\nTEST_CASE_21_POWER = [\n    WSI_APERIO_SVS_PATH,\n    {\"power_atol\": 0.0, \"power_rtol\": 0.3},\n    {\"location\": (0, 0), \"size\": (2, 1), \"power\": 1.0},\n    np.array([[[238], [240]], [[239], [241]], [[240], [241]]], dtype=np.uint8),\n    {\"level\": 2},\n]\n# device tests\nTEST_CASE_DEVICE_1 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8, \"dtype\": torch.float32, \"device\": \"cpu\"},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    torch.tensor([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=torch.float32),\n    \"cpu\",\n]\n\nTEST_CASE_DEVICE_2 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8, \"dtype\": torch.float32, \"device\": \"cuda\"},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    torch.tensor([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=torch.float32),\n    \"cuda\",\n]\n\nTEST_CASE_DEVICE_3 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8, \"dtype\": np.float32, \"device\": \"cpu\"},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.float32),\n    \"cpu\",\n]\n\nTEST_CASE_DEVICE_4 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8, \"dtype\": np.float32, \"device\": \"cuda\"},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    torch.tensor([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=torch.float32),\n    \"cuda\",\n]\n\nTEST_CASE_DEVICE_5 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8, \"device\": \"cuda\"},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    torch.tensor([[[242], [242]], [[242], [242]], [[242], [242]]], 
dtype=torch.uint8),\n    \"cuda\",\n]\n\nTEST_CASE_DEVICE_6 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.uint8),\n    \"cpu\",\n]\n\nTEST_CASE_DEVICE_7 = [\n    WSI_GENERIC_TIFF_PATH,\n    {\"level\": 8, \"device\": None},\n    {\"location\": (0, 0), \"size\": (2, 1)},\n    np.array([[[242], [242]], [[242], [242]], [[242], [242]]], dtype=np.uint8),\n    \"cpu\",\n]\n\nTEST_CASE_MULTI_WSI = [\n    [WSI_GENERIC_TIFF_PATH, WSI_GENERIC_TIFF_PATH],\n    {\"location\": (0, 0), \"size\": (2, 1), \"level\": 8},\n    np.concatenate(\n        [\n            np.array([[[242], [242]], [[242], [242]], [[242], [242]]]),\n            np.array([[[242], [242]], [[242], [242]], [[242], [242]]]),\n        ],\n        axis=0,\n    ),\n]\n\nTEST_CASE_RGB_0 = [np.ones((3, 2, 2), dtype=np.uint8)]  # CHW\n\nTEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)]  # CHW\n\nTEST_CASE_ERROR_0C = [np.ones((16, 16), dtype=np.uint8)]  # no color channel\nTEST_CASE_ERROR_1C = [np.ones((16, 16, 1), dtype=np.uint8)]  # one color channel\nTEST_CASE_ERROR_2C = [np.ones((16, 16, 2), dtype=np.uint8)]  # two color channels\nTEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)]  # 3D + color\n\n# mpp not within default\nTEST_CASE_ERROR_0_MPP = [\n    WSI_GENERIC_TIFF_PATH,\n    {},\n    {\"location\": (WSI_GENERIC_TIFF_HEIGHT // 2, WSI_GENERIC_TIFF_WIDTH // 2), \"size\": (2, 1), \"mpp\": 1200},\n    ValueError,\n]\n\n# mpp is not exact (no tolerance)\nTEST_CASE_ERROR_1_MPP = [\n    WSI_APERIO_SVS_PATH,\n    {\"mpp_atol\": 0.0, \"mpp_rtol\": 0.0},\n    {\"location\": (0, 0), \"size\": (2, 1), \"mpp\": 8.0},\n    ValueError,\n]\n\n# power not within default\nTEST_CASE_ERROR_2_POWER = [WSI_APERIO_SVS_PATH, {}, {\"location\": (0, 0), \"size\": (2, 1), \"power\": 40}, ValueError]\n\n# power is not exact (no tolerance)\nTEST_CASE_ERROR_3_POWER = [\n    
WSI_APERIO_SVS_PATH,\n    {\"power_atol\": 0.0, \"power_rtol\": 0.0},\n    {\"location\": (0, 0), \"size\": (2, 1), \"power\": 1.25},\n    ValueError,\n]\n\nTEST_CASE_MPP_0 = [WSI_GENERIC_TIFF_PATH, 0, (1000.0, 1000.0)]\n\n\ndef save_rgba_tiff(array: np.ndarray, filename: str, mode: str):\n    \"\"\"\n    Save numpy array into a TIFF RGB/RGBA file\n\n    Args:\n        array: numpy ndarray with the shape of CxHxW and C==3 representing a RGB image\n        filename: the filename to be used for the tiff file. '_RGB.tiff' or '_RGBA.tiff' will be appended to this filename.\n        mode: RGB or RGBA\n    \"\"\"\n    if mode == \"RGBA\":\n        array = np.concatenate([array, 255 * np.ones_like(array[0])[np.newaxis]]).astype(np.uint8)\n\n    img_rgb = array.transpose(1, 2, 0)\n    imwrite(filename, img_rgb, shape=img_rgb.shape, tile=(16, 16))\n\n    return filename\n\n\ndef save_gray_tiff(array: np.ndarray, filename: str):\n    \"\"\"\n    Save numpy array into a TIFF file\n\n    Args:\n        array: numpy ndarray with any shape\n        filename: the filename to be used for the tiff file.\n    \"\"\"\n    img_gray = array\n    imwrite(filename, img_gray, shape=img_gray.shape)\n\n    return filename\n\n\n@skipUnless(has_cucim or has_osl or has_tiff, \"Requires cucim, openslide, or tifffile!\")\ndef setUpModule():\n    download_url_or_skip_test(\n        testing_data_config(\"images\", WSI_GENERIC_TIFF_KEY, \"url\"),\n        WSI_GENERIC_TIFF_PATH,\n        hash_type=testing_data_config(\"images\", WSI_GENERIC_TIFF_KEY, \"hash_type\"),\n        hash_val=testing_data_config(\"images\", WSI_GENERIC_TIFF_KEY, \"hash_val\"),\n    )\n    download_url_or_skip_test(\n        testing_data_config(\"images\", WSI_APERIO_SVS_KEY, \"url\"),\n        WSI_APERIO_SVS_PATH,\n        hash_type=testing_data_config(\"images\", WSI_APERIO_SVS_KEY, \"hash_type\"),\n        hash_val=testing_data_config(\"images\", WSI_APERIO_SVS_KEY, \"hash_val\"),\n    )\n\n\nclass WSIReaderTests:\n    
class Tests(unittest.TestCase):\n        backend = None\n\n        @parameterized.expand([TEST_CASE_WHOLE_0])\n        def test_read_whole_image(self, file_path, level, expected_shape):\n            reader = WSIReader(self.backend, level=level)\n            with reader.read(file_path) as img_obj:\n                img, meta = reader.get_data(img_obj)\n            self.assertTupleEqual(img.shape, expected_shape)\n            self.assertEqual(meta[\"backend\"], self.backend)\n            self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower())\n            self.assertEqual(meta[WSIPatchKeys.LEVEL], level)\n            assert_allclose(meta[WSIPatchKeys.SIZE], expected_shape[1:], type_test=False)\n            assert_allclose(meta[WSIPatchKeys.LOCATION], (0, 0), type_test=False)\n\n        @parameterized.expand(\n            [\n                TEST_CASE_0,\n                TEST_CASE_1,\n                TEST_CASE_2,\n                TEST_CASE_3,\n                TEST_CASE_4,\n                TEST_CASE_5,\n                TEST_CASE_6,\n                TEST_CASE_7,\n                TEST_CASE_8,\n                TEST_CASE_9,\n                TEST_CASE_10_MPP,\n                TEST_CASE_11_MPP,\n                TEST_CASE_12_MPP,\n                TEST_CASE_13_MPP,\n                TEST_CASE_14_MPP,\n                TEST_CASE_15_MPP,\n                TEST_CASE_16_MPP,\n                TEST_CASE_17_POWER,\n                TEST_CASE_18_POWER,\n                TEST_CASE_19_POWER,\n                TEST_CASE_20_POWER,\n                TEST_CASE_21_POWER,\n            ]\n        )\n        def test_read_region(self, file_path, reader_kwargs, patch_info, expected_img, *args):\n            reader = WSIReader(self.backend, **reader_kwargs)\n            level = patch_info.get(\"level\", reader_kwargs.get(\"level\"))\n            # Skip mpp, power tests for TiffFile backend\n            if self.backend == \"tifffile\" and (level is None or level < 2 or 
file_path == WSI_APERIO_SVS_PATH):\n                return\n            if level is None:\n                level = args[0].get(\"level\")\n            with reader.read(file_path) as img_obj:\n                # Read twice to check multiple calls\n                img, meta = reader.get_data(img_obj, **patch_info)\n                img2 = reader.get_data(img_obj, **patch_info)[0]\n            self.assertTupleEqual(img.shape, img2.shape)\n            assert_allclose(img, img2)\n            self.assertTupleEqual(img.shape, expected_img.shape)\n            assert_allclose(img, expected_img)\n            self.assertEqual(img.dtype, expected_img.dtype)\n            self.assertEqual(meta[\"backend\"], self.backend)\n            self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower())\n            self.assertEqual(meta[WSIPatchKeys.LEVEL], level)\n            assert_allclose(meta[WSIPatchKeys.SIZE], patch_info[\"size\"], type_test=False)\n            assert_allclose(meta[WSIPatchKeys.LOCATION], patch_info[\"location\"], type_test=False)\n\n        @parameterized.expand([TEST_CASE_MULTI_WSI])\n        def test_read_region_multi_wsi(self, file_path_list, patch_info, expected_img):\n            kwargs = {\"name\": None, \"offset\": None} if self.backend == \"tifffile\" else {}\n            reader = WSIReader(self.backend, **kwargs)\n            img_obj_list = reader.read(file_path_list, **kwargs)\n            # Read twice to check multiple calls\n            img, meta = reader.get_data(img_obj_list, **patch_info)\n            img2 = reader.get_data(img_obj_list, **patch_info)[0]\n            for img_obj in img_obj_list:\n                img_obj.close()\n            self.assertTupleEqual(img.shape, img2.shape)\n            assert_allclose(img, img2)\n            self.assertTupleEqual(img.shape, expected_img.shape)\n            assert_allclose(img, expected_img)\n            self.assertEqual(meta[\"backend\"], self.backend)\n            
self.assertEqual(meta[WSIPatchKeys.PATH][0].lower(), str(os.path.abspath(file_path_list[0])).lower())\n            self.assertEqual(meta[WSIPatchKeys.LEVEL][0], patch_info[\"level\"])\n            assert_allclose(meta[WSIPatchKeys.SIZE][0], expected_img.shape[1:], type_test=False)\n            assert_allclose(meta[WSIPatchKeys.LOCATION][0], patch_info[\"location\"], type_test=False)\n\n        @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1])\n        @skipUnless(has_tiff, \"Requires tifffile.\")\n        def test_read_rgba(self, img_expected):\n            # skip for OpenSlide since not working with images without tiles\n            if self.backend == \"openslide\":\n                return\n            image = {}\n            reader = WSIReader(self.backend)\n            for mode in [\"RGB\", \"RGBA\"]:\n                file_path = save_rgba_tiff(\n                    img_expected, os.path.join(TESTS_PATH, \"testing_data\", f\"temp_tiff_image_{mode}.tiff\"), mode=mode\n                )\n                with reader.read(file_path) as img_obj:\n                    image[mode], _ = reader.get_data(img_obj)\n\n            assert_allclose(image[\"RGB\"], img_expected)\n            assert_allclose(image[\"RGBA\"], img_expected)\n\n        @parameterized.expand([TEST_CASE_ERROR_0C, TEST_CASE_ERROR_1C, TEST_CASE_ERROR_2C, TEST_CASE_ERROR_3D])\n        @skipUnless(has_tiff, \"Requires tifffile.\")\n        def test_read_malformats(self, img_expected):\n            if self.backend == \"cucim\" and (len(img_expected.shape) < 3 or img_expected.shape[2] == 1):\n                # Until cuCIM addresses https://github.com/rapidsai/cucim/issues/230\n                return\n            reader = WSIReader(self.backend)\n            file_path = os.path.join(TESTS_PATH, \"testing_data\", \"temp_tiff_image_gray.tiff\")\n            imwrite(file_path, img_expected, shape=img_expected.shape)\n            with self.assertRaises((RuntimeError, ValueError, openslide.OpenSlideError 
if has_osl else ValueError)):\n                with reader.read(file_path) as img_obj:\n                    reader.get_data(img_obj)\n\n        @parameterized.expand([TEST_CASE_TRANSFORM_0])\n        def test_with_dataloader(\n            self, file_path: PathLike, level: int, expected_spatial_shape: Any, expected_shape: tuple[int, ...]\n        ):\n            train_transform = Compose(\n                [\n                    LoadImaged(keys=[\"image\"], reader=WSIReader, backend=self.backend, level=level, image_only=False),\n                    ToTensord(keys=[\"image\"]),\n                ]\n            )\n            dataset = Dataset([{\"image\": file_path}], transform=train_transform)\n            data_loader = DataLoader(dataset)\n            data: dict = first(data_loader, {})\n            for s in data[PostFix.meta(\"image\")][\"spatial_shape\"]:\n                assert_allclose(s, expected_spatial_shape, type_test=False)\n            self.assertTupleEqual(data[\"image\"].shape, expected_shape)\n\n        @parameterized.expand([TEST_CASE_TRANSFORM_0])\n        def test_with_dataloader_batch(\n            self, file_path: PathLike, level: int, expected_spatial_shape: Any, expected_shape: tuple[int, ...]\n        ):\n            train_transform = Compose(\n                [\n                    LoadImaged(keys=[\"image\"], reader=WSIReader, backend=self.backend, level=level, image_only=False),\n                    ToTensord(keys=[\"image\"]),\n                ]\n            )\n            dataset = Dataset([{\"image\": file_path}, {\"image\": file_path}], transform=train_transform)\n            batch_size = 2\n            data_loader = DataLoader(dataset, batch_size=batch_size)\n            data: dict = first(data_loader, {})\n            for s in data[PostFix.meta(\"image\")][\"spatial_shape\"]:\n                assert_allclose(s, expected_spatial_shape, type_test=False)\n            self.assertTupleEqual(data[\"image\"].shape, (batch_size, 
*expected_shape[1:]))\n\n        @parameterized.expand([TEST_CASE_WHOLE_0])\n        def test_read_whole_image_multi_thread(self, file_path, level, expected_shape):\n            if self.backend == \"cucim\":\n                reader = WSIReader(self.backend, level=level, num_workers=4)\n                with reader.read(file_path) as img_obj:\n                    img, meta = reader.get_data(img_obj)\n                self.assertTupleEqual(img.shape, expected_shape)\n                self.assertEqual(meta[\"backend\"], self.backend)\n                self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower())\n                self.assertEqual(meta[WSIPatchKeys.LEVEL], level)\n                assert_allclose(meta[WSIPatchKeys.SIZE], expected_shape[1:], type_test=False)\n                assert_allclose(meta[WSIPatchKeys.LOCATION], (0, 0), type_test=False)\n\n        @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])\n        def test_read_region_multi_thread(self, file_path, kwargs, patch_info, expected_img):\n            if self.backend == \"cucim\":\n                reader = WSIReader(backend=self.backend, num_workers=2, **kwargs)\n                with reader.read(file_path) as img_obj:\n                    # Read twice to check multiple calls\n                    img, meta = reader.get_data(img_obj, **patch_info)\n                    img2 = reader.get_data(img_obj, **patch_info)[0]\n                    self.assertTupleEqual(img.shape, img2.shape)\n                    assert_allclose(img, img2)\n                    self.assertTupleEqual(img.shape, expected_img.shape)\n                    assert_allclose(img, expected_img)\n                    self.assertEqual(meta[\"backend\"], self.backend)\n                    self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower())\n                    self.assertEqual(meta[WSIPatchKeys.LEVEL], patch_info[\"level\"])\n                    
assert_allclose(meta[WSIPatchKeys.SIZE], patch_info[\"size\"], type_test=False)\n                    assert_allclose(meta[WSIPatchKeys.LOCATION], patch_info[\"location\"], type_test=False)\n\n        @parameterized.expand([TEST_CASE_MPP_0])\n        def test_resolution_mpp(self, file_path, level, expected_mpp):\n            reader = WSIReader(self.backend, level=level)\n            with reader.read(file_path) as img_obj:\n                mpp = reader.get_mpp(img_obj, level)\n            self.assertTupleEqual(mpp, expected_mpp)\n\n        @parameterized.expand(\n            [\n                TEST_CASE_DEVICE_1,\n                TEST_CASE_DEVICE_2,\n                TEST_CASE_DEVICE_3,\n                TEST_CASE_DEVICE_4,\n                TEST_CASE_DEVICE_5,\n                TEST_CASE_DEVICE_6,\n                TEST_CASE_DEVICE_7,\n            ]\n        )\n        @skip_if_no_cuda\n        def test_read_region_device(self, file_path, kwargs, patch_info, expected_img, device):\n            reader = WSIReader(self.backend, **kwargs)\n            level = patch_info.get(\"level\", kwargs.get(\"level\"))\n            if self.backend == \"tifffile\" and level < 2:\n                return\n            with reader.read(file_path) as img_obj:\n                # Read twice to check multiple calls\n                img, meta = reader.get_data(img_obj, **patch_info)\n                img2 = reader.get_data(img_obj, **patch_info)[0]\n            self.assertTupleEqual(img.shape, img2.shape)\n            assert_allclose(img, img2)\n            self.assertTupleEqual(img.shape, expected_img.shape)\n            assert_allclose(img, expected_img)\n            self.assertEqual(img.dtype, expected_img.dtype)\n            if isinstance(img, torch.Tensor):\n                self.assertEqual(img.device.type, device)\n            else:\n                self.assertEqual(\"cpu\", device)\n            self.assertEqual(meta[\"backend\"], self.backend)\n            
self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower())\n            self.assertEqual(meta[WSIPatchKeys.LEVEL], level)\n            assert_allclose(meta[WSIPatchKeys.SIZE], patch_info[\"size\"], type_test=False)\n            assert_allclose(meta[WSIPatchKeys.LOCATION], patch_info[\"location\"], type_test=False)\n\n        @parameterized.expand(\n            [TEST_CASE_ERROR_0_MPP, TEST_CASE_ERROR_1_MPP, TEST_CASE_ERROR_2_POWER, TEST_CASE_ERROR_3_POWER]\n        )\n        def test_errors(self, file_path, reader_kwargs, patch_info, exception):\n            with self.assertRaises(exception):\n                reader = WSIReader(self.backend, **reader_kwargs)\n                with reader.read(file_path) as img_obj:\n                    reader.get_data(img_obj, **patch_info)\n\n\n@skipUnless(has_cucim, \"Requires cucim\")\nclass TestCuCIM(WSIReaderTests.Tests):\n    @classmethod\n    def setUpClass(cls):\n        cls.backend = \"cucim\"\n\n\n@skipUnless(has_osl, \"Requires openslide\")\nclass TestOpenSlide(WSIReaderTests.Tests):\n    @classmethod\n    def setUpClass(cls):\n        cls.backend = \"openslide\"\n\n\n@skipUnless(has_tiff, \"Requires tifffile\")\nclass TestTiffFile(WSIReaderTests.Tests):\n    @classmethod\n    def setUpClass(cls):\n        cls.backend = \"tifffile\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/misc/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/utils/misc/test_ensure_tuple.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.utils.misc import ensure_tuple\nfrom tests.test_utils import assert_allclose\n\nTESTS = [\n    [\"test\", (\"test\",)],\n    [[\"test1\", \"test2\"], (\"test1\", \"test2\")],\n    [123, (123,)],\n    [(1, [2], 3), (1, [2], 3)],\n    [(1, 2, 3), (1, 2, 3), True],\n    [np.array([1, 2]), (np.array([1, 2]),), True],\n    [np.array([1, 2]), (1, 2), False],\n    [torch.tensor([1, 2]), (torch.tensor([1, 2]),), True],\n    [np.array([]), (np.array([]),)],\n    [torch.tensor([]), (torch.tensor([]),)],\n    [np.array(123), (np.array(123),), True],\n    [torch.tensor(123), (torch.tensor(123),)],\n]\n\n\nclass TestEnsureTuple(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_value(self, input, expected_value, wrap_array=False):\n        result = ensure_tuple(input, wrap_array)\n\n        self.assertTrue(isinstance(result, tuple))\n        if isinstance(input, (np.ndarray, torch.Tensor)):\n            for i, j in zip(result, expected_value):\n                assert_allclose(i, j)\n        else:\n            self.assertTupleEqual(result, expected_value)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/misc/test_monai_env_vars.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\n\nfrom monai.utils.misc import MONAIEnvVars\n\n\nclass TestMONAIEnvVars(unittest.TestCase):\n\n    @classmethod\n    def setUpClass(cls):\n        super().setUpClass()\n        cls.orig_value = str(MONAIEnvVars.debug())\n\n    @classmethod\n    def tearDownClass(cls):\n        if cls.orig_value is not None:\n            os.environ[\"MONAI_DEBUG\"] = cls.orig_value\n        else:\n            os.environ.pop(\"MONAI_DEBUG\")\n        print(\"MONAI debug value:\", str(MONAIEnvVars.debug()))\n        super().tearDownClass()\n\n    def test_monai_env_vars(self):\n        for debug in (False, True):\n            os.environ[\"MONAI_DEBUG\"] = str(debug)\n            self.assertEqual(os.environ.get(\"MONAI_DEBUG\"), str(debug))\n            self.assertEqual(MONAIEnvVars.debug(), debug)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/misc/test_monai_utils_misc.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.utils.misc import MONAIEnvVars, check_kwargs_exist_in_class_init, run_cmd, to_tuple_of_dictionaries\n\nTO_TUPLE_OF_DICTIONARIES_TEST_CASES = [\n    ({}, tuple(), tuple()),\n    ({}, (\"x\",), ({},)),\n    ({}, (\"x\", \"y\"), ({}, {})),\n    ({\"a\": 1}, tuple(), tuple()),\n    ({\"a\": 1}, (\"x\",), ({\"a\": 1},)),\n    ({\"a\": (1,)}, (\"x\",), ({\"a\": 1},)),\n    ({\"a\": (1,)}, (\"x\", \"y\"), ValueError()),\n    ({\"a\": 1}, (\"x\", \"y\"), ({\"a\": 1}, {\"a\": 1})),\n    ({\"a\": (1, 2)}, tuple(), tuple()),\n    ({\"a\": (1, 2)}, (\"x\", \"y\"), ({\"a\": 1}, {\"a\": 2})),\n    ({\"a\": (1, 2, 3)}, (\"x\", \"y\"), ValueError()),\n    ({\"b\": (2,), \"a\": 1}, tuple(), tuple()),\n    ({\"b\": (2,), \"a\": 1}, (\"x\",), ({\"b\": 2, \"a\": 1},)),\n    ({\"b\": (2,), \"a\": 1}, (\"x\", \"y\"), ValueError()),\n    ({\"b\": (3, 2), \"a\": 1}, tuple(), tuple()),\n    ({\"b\": (3, 2), \"a\": 1}, (\"x\",), ValueError()),\n    ({\"b\": (3, 2), \"a\": 1}, (\"x\", \"y\"), ({\"b\": 3, \"a\": 1}, {\"b\": 2, \"a\": 1})),\n]\n\n\nclass MiscClass:\n\n    def __init__(self, arg1, arg2, kwargs1=None, kwargs2=None):\n        pass\n\n\nclass TestToTupleOfDictionaries(unittest.TestCase):\n\n    @parameterized.expand(TO_TUPLE_OF_DICTIONARIES_TEST_CASES)\n    def 
test_to_tuple_of_dictionaries(self, dictionary, keys, expected):\n        self._test_to_tuple_of_dictionaries(dictionary, keys, expected)\n\n    def _test_to_tuple_of_dictionaries(self, dictionary, keys, expected):\n        if isinstance(expected, Exception):\n            with self.assertRaises(type(expected)):\n                to_tuple_of_dictionaries(dictionary, keys)\n            print(type(expected))\n        else:\n            actual = to_tuple_of_dictionaries(dictionary, keys)\n            print(actual, expected)\n            self.assertTupleEqual(actual, expected)\n\n\nclass TestMiscKwargs(unittest.TestCase):\n\n    def test_kwargs(self):\n        present, extra_args = self._custom_user_function(MiscClass, 1, kwargs1=\"value1\", kwargs2=\"value2\")\n        self.assertEqual(present, True)\n        self.assertEqual(extra_args, set())\n        present, extra_args = self._custom_user_function(MiscClass, 1, kwargs1=\"value1\", kwargs3=\"value3\")\n        self.assertEqual(present, False)\n        self.assertEqual(extra_args, {\"kwargs3\"})\n\n    def _custom_user_function(self, cls, *args, **kwargs):\n        return check_kwargs_exist_in_class_init(cls, kwargs)\n\n\nclass TestCommandRunner(unittest.TestCase):\n\n    def setUp(self):\n        self.orig_flag = str(MONAIEnvVars.debug())\n\n    def tearDown(self):\n        if self.orig_flag is not None:\n            os.environ[\"MONAI_DEBUG\"] = self.orig_flag\n        else:\n            os.environ.pop(\"MONAI_DEBUG\")\n\n    def test_run_cmd(self):\n        cmd1 = \"python\"\n        cmd2 = \"-c\"\n        cmd3 = 'import sys; print(\"\\\\tThis is on stderr\\\\n\", file=sys.stderr); sys.exit(1)'\n        os.environ[\"MONAI_DEBUG\"] = str(True)\n        with self.assertRaises(RuntimeError) as cm:\n            run_cmd([cmd1, cmd2, cmd3], check=True)\n        self.assertIn(\"This is on stderr\", str(cm.exception))\n        self.assertNotIn(\"\\\\n\", str(cm.exception))\n        self.assertNotIn(\"\\\\t\", 
str(cm.exception))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/misc/test_str2bool.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.utils.misc import str2bool\n\n\nclass TestStr2Bool(unittest.TestCase):\n\n    def test_str_2_bool(self):\n        for i in (\"yes\", \"true\", \"t\", \"y\", \"1\", True):\n            self.assertTrue(str2bool(i))\n        for i in (\"no\", \"false\", \"f\", \"n\", \"0\", False):\n            self.assertFalse(str2bool(i))\n        for bad_value in (\"test\", 0, 1, 2, None):\n            self.assertFalse(str2bool(bad_value, default=False, raise_exc=False))\n            self.assertTrue(str2bool(bad_value, default=True, raise_exc=False))\n            with self.assertRaises(ValueError):\n                self.assertTrue(str2bool(bad_value))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/misc/test_str2list.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.utils.misc import str2list\n\n\nclass TestStr2List(unittest.TestCase):\n\n    def test_str_2_list(self):\n        for i in (\"1,2,3\", \"1, 2, 3\", \"1,2e-0,3.0\", [1, 2, 3]):\n            self.assertEqual(str2list(i), [1, 2, 3])\n        for i in (\"1,2,3\", \"1,2,3,4.3\", [1, 2, 3, 4.001]):\n            self.assertNotEqual(str2list(i), [1, 2, 3, 4])\n        for bad_value in ((1, 3), int):\n            self.assertIsNone(str2list(bad_value, raise_exc=False))\n            with self.assertRaises(ValueError):\n                self.assertIsNone(str2list(bad_value))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_alias.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport glob\nimport inspect\nimport os\nimport unittest\nfrom pathlib import Path\n\nfrom monai.utils import optional_import\n\nTESTS_PATH = Path(__file__).parents[1]\n\n\nclass TestModuleAlias(unittest.TestCase):\n    \"\"\"check that 'import monai.xx.file_name' returns a module\"\"\"\n\n    def test_files(self):\n        src_dir = os.path.dirname(TESTS_PATH)\n        monai_dir = os.path.join(src_dir, \"monai\")\n        py_files = glob.glob(os.path.join(monai_dir, \"**\", \"*.py\"), recursive=True)\n        for x in py_files:\n            if os.path.basename(x).startswith(\"_\"):\n                continue\n            mod_name = x[len(src_dir) : -3]  # create relative path\n            mod_name = mod_name[1:].replace(mod_name[0], \".\")\n            mod, cls = mod_name.rsplit(\".\", 1)\n            obj, exist = optional_import(mod, name=cls)\n            if exist:\n                self.assertTrue(inspect.ismodule(obj), msg=mod_name)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_component_store.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.utils import ComponentStore\n\n\nclass TestComponentStore(unittest.TestCase):\n\n    def setUp(self):\n        self.cs = ComponentStore(\"TestStore\", \"I am a test store, please ignore\")\n\n    def test_empty(self):\n        self.assertEqual(len(self.cs), 0)\n        self.assertEqual(list(self.cs), [])\n\n    def test_add(self):\n        test_obj = object()\n\n        self.assertFalse(\"test_obj\" in self.cs)\n\n        self.cs.add(\"test_obj\", \"Test object\", test_obj)\n\n        self.assertTrue(\"test_obj\" in self.cs)\n\n        self.assertEqual(len(self.cs), 1)\n        self.assertEqual(list(self.cs), [(\"test_obj\", test_obj)])\n\n        self.assertEqual(self.cs.test_obj, test_obj)\n        self.assertEqual(self.cs[\"test_obj\"], test_obj)\n\n    def test_add2(self):\n        test_obj1 = object()\n        test_obj2 = object()\n\n        self.cs.add(\"test_obj1\", \"Test object\", test_obj1)\n        self.cs.add(\"test_obj2\", \"Test object\", test_obj2)\n\n        self.assertEqual(len(self.cs), 2)\n        self.assertIn(\"test_obj1\", self.cs)\n        self.assertIn(\"test_obj2\", self.cs)\n\n    def test_add_def(self):\n        self.assertNotIn(\"test_func\", self.cs)\n\n        @self.cs.add_def(\"test_func\", \"Test function\")\n        def test_func():\n            return 123\n\n        
self.assertIn(\"test_func\", self.cs)\n\n        self.assertEqual(len(self.cs), 1)\n        self.assertEqual(list(self.cs), [(\"test_func\", test_func)])\n\n        self.assertEqual(self.cs.test_func, test_func)\n        self.assertEqual(self.cs[\"test_func\"], test_func)\n\n        # try adding the same function again\n        self.cs.add_def(\"test_func\", \"Test function but with new description\")(test_func)\n\n        self.assertEqual(len(self.cs), 1)\n        self.assertEqual(self.cs.test_func, test_func)\n"
  },
  {
    "path": "tests/utils/test_deprecated.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nimport warnings\n\nfrom monai.utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default\n\n\nclass TestDeprecatedRC(unittest.TestCase):\n\n    def setUp(self):\n        self.test_version_rc = \"0.6.0rc1\"\n        self.test_version = \"0.6.0\"\n        self.next_version = \"0.7.0\"\n\n    def test_warning(self):\n        \"\"\"Test deprecated decorator with `since` and `removed` set for an RC version\"\"\"\n\n        @deprecated(since=self.test_version, removed=self.next_version, version_val=self.test_version_rc)\n        def foo2():\n            pass\n\n        foo2()  # should not raise any warnings\n\n    def test_warning_milestone(self):\n        \"\"\"Test deprecated decorator with `since` and `removed` set for a milestone version\"\"\"\n\n        @deprecated(since=self.test_version, removed=self.next_version, version_val=self.test_version)\n        def foo2():\n            pass\n\n        self.assertWarns(FutureWarning, foo2)\n\n    def test_warning_last(self):\n        \"\"\"Test deprecated decorator with `since` and `removed` set, for the last version\"\"\"\n\n        @deprecated(since=self.test_version, removed=self.next_version, version_val=self.next_version)\n        def foo3():\n            pass\n\n        self.assertRaises(DeprecatedError, foo3)\n\n    def 
test_warning_beyond(self):\n        \"\"\"Test deprecated decorator with `since` and `removed` set, beyond the last version\"\"\"\n\n        @deprecated(since=self.test_version_rc, removed=self.test_version, version_val=self.next_version)\n        def foo3():\n            pass\n\n        self.assertRaises(DeprecatedError, foo3)\n\n\nclass TestDeprecated(unittest.TestCase):\n\n    def setUp(self):\n        self.test_version = \"0.5.3+96.g1fa03c2.dirty\"\n        self.prev_version = \"0.4.3+96.g1fa03c2.dirty\"\n        self.next_version = \"0.6.3+96.g1fa03c2.dirty\"\n\n    def test_warning1(self):\n        \"\"\"Test deprecated decorator with just `since` set.\"\"\"\n\n        @deprecated(since=self.prev_version, version_val=self.test_version)\n        def foo1():\n            pass\n\n        self.assertWarns(FutureWarning, foo1)\n\n    def test_warning2(self):\n        \"\"\"Test deprecated decorator with `since` and `removed` set.\"\"\"\n\n        @deprecated(since=self.prev_version, removed=self.next_version, version_val=self.test_version)\n        def foo2():\n            pass\n\n        self.assertWarns(FutureWarning, foo2)\n\n    def test_except1(self):\n        \"\"\"Test deprecated decorator raises exception with no versions set.\"\"\"\n\n        @deprecated(version_val=self.test_version)\n        def foo3():\n            pass\n\n        self.assertRaises(DeprecatedError, foo3)\n\n    def test_except2(self):\n        \"\"\"Test deprecated decorator raises exception with `removed` set in the past.\"\"\"\n\n        @deprecated(removed=self.prev_version, version_val=self.test_version)\n        def foo4():\n            pass\n\n        self.assertRaises(DeprecatedError, foo4)\n\n    def test_class_warning1(self):\n        \"\"\"Test deprecated decorator with just `since` set.\"\"\"\n\n        @deprecated(since=self.prev_version, version_val=self.test_version)\n        class Foo1:\n            pass\n\n        self.assertWarns(FutureWarning, Foo1)\n\n    def 
test_class_warning2(self):\n        \"\"\"Test deprecated decorator with `since` and `removed` set.\"\"\"\n\n        @deprecated(since=self.prev_version, removed=self.next_version, version_val=self.test_version)\n        class Foo2:\n            pass\n\n        self.assertWarns(FutureWarning, Foo2)\n\n    def test_class_except1(self):\n        \"\"\"Test deprecated decorator raises exception with no versions set.\"\"\"\n\n        @deprecated(version_val=self.test_version)\n        class Foo3:\n            pass\n\n        self.assertRaises(DeprecatedError, Foo3)\n\n    def test_class_except2(self):\n        \"\"\"Test deprecated decorator raises exception with `removed` set in the past.\"\"\"\n\n        @deprecated(removed=self.prev_version, version_val=self.test_version)\n        class Foo4:\n            pass\n\n        self.assertRaises(DeprecatedError, Foo4)\n\n    def test_meth_warning1(self):\n        \"\"\"Test deprecated decorator with just `since` set.\"\"\"\n\n        class Foo5:\n\n            @deprecated(since=self.prev_version, version_val=self.test_version)\n            def meth1(self):\n                pass\n\n        self.assertWarns(FutureWarning, lambda: Foo5().meth1())\n\n    def test_meth_except1(self):\n        \"\"\"Test deprecated decorator with just `since` set.\"\"\"\n\n        class Foo6:\n\n            @deprecated(version_val=self.test_version)\n            def meth1(self):\n                pass\n\n        self.assertRaises(DeprecatedError, lambda: Foo6().meth1())\n\n    def test_arg_warn1(self):\n        \"\"\"Test deprecated_arg decorator with just `since` set.\"\"\"\n\n        @deprecated_arg(\"b\", since=self.prev_version, version_val=self.test_version)\n        def afoo1(a, b=None):\n            pass\n\n        afoo1(1)  # ok when no b provided\n\n        self.assertWarns(FutureWarning, lambda: afoo1(1, 2))\n\n    def test_arg_warn2(self):\n        \"\"\"Test deprecated_arg decorator with just `since` set.\"\"\"\n\n        
@deprecated_arg(\"b\", since=self.prev_version, version_val=self.test_version)\n        def afoo2(a, **kw):\n            pass\n\n        afoo2(1)  # ok when no b provided\n\n        self.assertWarns(FutureWarning, lambda: afoo2(1, b=2))\n\n    def test_arg_except1(self):\n        \"\"\"Test deprecated_arg decorator raises exception with no versions set.\"\"\"\n\n        @deprecated_arg(\"b\", version_val=self.test_version)\n        def afoo3(a, b=None):\n            pass\n\n        self.assertRaises(DeprecatedError, lambda: afoo3(1, b=2))\n\n    def test_arg_except2(self):\n        \"\"\"Test deprecated_arg decorator raises exception with `removed` set in the past.\"\"\"\n\n        @deprecated_arg(\"b\", removed=self.prev_version, version_val=self.test_version)\n        def afoo4(a, b=None):\n            pass\n\n        self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2))\n\n    def test_2arg_warn1(self):\n        \"\"\"Test deprecated_arg decorator applied twice with just `since` set.\"\"\"\n\n        @deprecated_arg(\"b\", since=self.prev_version, version_val=self.test_version)\n        @deprecated_arg(\"c\", since=self.prev_version, version_val=self.test_version)\n        def afoo5(a, b=None, c=None):\n            pass\n\n        afoo5(1)  # ok when no b or c provided\n\n        self.assertWarns(FutureWarning, lambda: afoo5(1, 2))\n        self.assertWarns(FutureWarning, lambda: afoo5(1, 2, 3))\n\n    def test_future(self):\n        \"\"\"Test deprecated decorator with `since` set to a future version.\"\"\"\n\n        @deprecated(since=self.next_version, version_val=self.test_version)\n        def future1():\n            pass\n\n        with self.assertWarns(FutureWarning) as aw:\n            future1()\n            warnings.warn(\"fake warning\", FutureWarning)\n\n        self.assertEqual(aw.warning.args[0], \"fake warning\")\n\n    def test_arg_except2_unknown(self):\n        \"\"\"\n        Test deprecated_arg decorator raises exception with `removed` set 
in the past.\n        with unknown version\n        \"\"\"\n\n        @deprecated_arg(\"b\", removed=self.prev_version, version_val=\"0+untagged.1.g3131155\")\n        def afoo4(a, b=None):\n            pass\n\n        afoo4(1, b=2)\n\n    def test_arg_except3_unknown(self):\n        \"\"\"\n        Test deprecated_arg decorator raises exception with `removed` set in the past.\n        with unknown version and kwargs\n        \"\"\"\n\n        @deprecated_arg(\"b\", removed=self.prev_version, version_val=\"0+untagged.1.g3131155\")\n        def afoo4(a, b=None, **kwargs):\n            pass\n\n        afoo4(1, b=2)\n        afoo4(1, b=2, c=3)\n\n    def test_replacement_arg(self):\n        \"\"\"\n        Test deprecated arg being replaced.\n        \"\"\"\n\n        @deprecated_arg(\"b\", new_name=\"a\", since=self.prev_version, version_val=self.test_version)\n        def afoo4(a, b=None):\n            return a\n\n        self.assertEqual(afoo4(b=2), 2)\n        self.assertEqual(afoo4(1, b=2), 1)  # new name is in use\n        self.assertEqual(afoo4(a=1, b=2), 1)  # prefers the new arg\n\n    def test_replacement_arg1(self):\n        \"\"\"\n        Test deprecated arg being replaced with kwargs.\n        \"\"\"\n\n        @deprecated_arg(\"b\", new_name=\"a\", since=self.prev_version, version_val=self.test_version)\n        def afoo4(a, *args, **kwargs):\n            return a\n\n        self.assertEqual(afoo4(b=2), 2)\n        self.assertEqual(afoo4(1, b=2, c=3), 1)  # new name is in use\n        self.assertEqual(afoo4(a=1, b=2, c=3), 1)  # prefers the new arg\n\n    def test_replacement_arg2(self):\n        \"\"\"\n        Test deprecated arg (with a default value) being replaced.\n        \"\"\"\n\n        @deprecated_arg(\"b\", new_name=\"a\", since=self.prev_version, version_val=self.test_version)\n        def afoo4(a, b=None, **kwargs):\n            return a, kwargs\n\n        self.assertEqual(afoo4(b=2, c=3), (2, {\"c\": 3}))\n        
self.assertEqual(afoo4(1, b=2, c=3), (1, {\"c\": 3}))  # new name is in use\n        self.assertEqual(afoo4(a=1, b=2, c=3), (1, {\"c\": 3}))  # prefers the new arg\n        self.assertEqual(afoo4(1, 2, c=3), (1, {\"c\": 3}))  # prefers the new positional arg\n\n    def test_deprecated_arg_default_explicit_default(self):\n        \"\"\"\n        Test deprecated arg default, where the default is explicitly set (no warning).\n        \"\"\"\n\n        @deprecated_arg_default(\n            \"b\", old_default=\"a\", new_default=\"b\", since=self.prev_version, version_val=self.test_version\n        )\n        def foo(a, b=\"a\"):\n            return a, b\n\n        with self.assertWarns(FutureWarning) as aw:\n            self.assertEqual(foo(\"a\", \"a\"), (\"a\", \"a\"))\n            self.assertEqual(foo(\"a\", \"b\"), (\"a\", \"b\"))\n            self.assertEqual(foo(\"a\", \"c\"), (\"a\", \"c\"))\n            warnings.warn(\"fake warning\", FutureWarning)\n\n        self.assertEqual(aw.warning.args[0], \"fake warning\")\n\n    def test_deprecated_arg_default_version_less_than_since(self):\n        \"\"\"\n        Test deprecated arg default, where the current version is less than `since` (no warning).\n        \"\"\"\n\n        @deprecated_arg_default(\n            \"b\", old_default=\"a\", new_default=\"b\", since=self.test_version, version_val=self.prev_version\n        )\n        def foo(a, b=\"a\"):\n            return a, b\n\n        with self.assertWarns(FutureWarning) as aw:\n            self.assertEqual(foo(\"a\"), (\"a\", \"a\"))\n            self.assertEqual(foo(\"a\", \"a\"), (\"a\", \"a\"))\n            warnings.warn(\"fake warning\", FutureWarning)\n\n        self.assertEqual(aw.warning.args[0], \"fake warning\")\n\n    def test_deprecated_arg_default_warning_deprecated(self):\n        \"\"\"\n        Test deprecated arg default, where the default is used.\n        \"\"\"\n\n        @deprecated_arg_default(\n            \"b\", old_default=\"a\", 
new_default=\"b\", since=self.prev_version, version_val=self.test_version\n        )\n        def foo(a, b=\"a\"):\n            return a, b\n\n        self.assertWarns(FutureWarning, lambda: foo(\"a\"))\n\n    def test_deprecated_arg_default_warning_replaced(self):\n        \"\"\"\n        Test deprecated arg default, where the default is used.\n        \"\"\"\n\n        @deprecated_arg_default(\n            \"b\",\n            old_default=\"a\",\n            new_default=\"b\",\n            since=self.prev_version,\n            replaced=self.prev_version,\n            version_val=self.test_version,\n        )\n        def foo(a, b=\"a\"):\n            return a, b\n\n        self.assertWarns(FutureWarning, lambda: foo(\"a\"))\n\n    def test_deprecated_arg_default_warning_with_none_as_placeholder(self):\n        \"\"\"\n        Test deprecated arg default, where the default is used.\n        \"\"\"\n\n        @deprecated_arg_default(\n            \"b\", old_default=\"a\", new_default=\"b\", since=self.prev_version, version_val=self.test_version\n        )\n        def foo(a, b=None):\n            if b is None:\n                b = \"a\"\n            return a, b\n\n        self.assertWarns(FutureWarning, lambda: foo(\"a\"))\n\n        @deprecated_arg_default(\n            \"b\", old_default=\"a\", new_default=\"b\", since=self.prev_version, version_val=self.test_version\n        )\n        def foo2(a, b=None):\n            if b is None:\n                b = \"b\"\n            return a, b\n\n        self.assertWarns(FutureWarning, lambda: foo2(\"a\"))\n\n    def test_deprecated_arg_default_errors(self):\n        \"\"\"\n        Test deprecated arg default, where the decorator is wrongly used.\n        \"\"\"\n\n        # since > replaced\n        def since_grater_than_replaced():\n\n            @deprecated_arg_default(\n                \"b\",\n                old_default=\"a\",\n                new_default=\"b\",\n                since=self.test_version,\n             
   replaced=self.prev_version,\n                version_val=self.test_version,\n            )\n            def foo(a, b=None):\n                return a, b\n\n        self.assertRaises(ValueError, since_grater_than_replaced)\n\n        # argname doesnt exist\n        def argname_doesnt_exist():\n\n            @deprecated_arg_default(\n                \"other\", old_default=\"a\", new_default=\"b\", since=self.test_version, version_val=self.test_version\n            )\n            def foo(a, b=None):\n                return a, b\n\n        self.assertRaises(ValueError, argname_doesnt_exist)\n\n        # argname has no default\n        def argname_has_no_default():\n\n            @deprecated_arg_default(\n                \"a\",\n                old_default=\"a\",\n                new_default=\"b\",\n                since=self.prev_version,\n                replaced=self.test_version,\n                version_val=self.test_version,\n            )\n            def foo(a):\n                return a\n\n        self.assertRaises(ValueError, argname_has_no_default)\n\n        # new default is used but version < replaced\n        def argname_was_replaced_before_specified_version():\n\n            @deprecated_arg_default(\n                \"a\",\n                old_default=\"a\",\n                new_default=\"b\",\n                since=self.prev_version,\n                replaced=self.next_version,\n                version_val=self.test_version,\n            )\n            def foo(a, b=\"b\"):\n                return a, b\n\n        self.assertRaises(ValueError, argname_was_replaced_before_specified_version)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_enum_bound_interp.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.utils import optional_import\nfrom tests.test_utils import skip_if_no_cpp_extension\n\nb, _ = optional_import(\"monai._C\", name=\"BoundType\")\np, _ = optional_import(\"monai._C\", name=\"InterpolationType\")\n\n\n@skip_if_no_cpp_extension\nclass TestEnumBoundInterp(unittest.TestCase):\n    def test_bound(self):\n        self.assertEqual(str(b.replicate), \"BoundType.replicate\")\n        self.assertEqual(str(b.nearest), \"BoundType.replicate\")\n        self.assertEqual(str(b.dct1), \"BoundType.dct1\")\n        self.assertEqual(str(b.mirror), \"BoundType.dct1\")\n        self.assertEqual(str(b.dct2), \"BoundType.dct2\")\n        self.assertEqual(str(b.reflect), \"BoundType.dct2\")\n        self.assertEqual(str(b.dst1), \"BoundType.dst1\")\n        self.assertEqual(str(b.antimirror), \"BoundType.dst1\")\n        self.assertEqual(str(b.dst2), \"BoundType.dst2\")\n        self.assertEqual(str(b.antireflect), \"BoundType.dst2\")\n        self.assertEqual(str(b.dft), \"BoundType.dft\")\n        self.assertEqual(str(b.wrap), \"BoundType.dft\")\n        self.assertEqual(str(b.zero), \"BoundType.zero\")\n\n        self.assertEqual(int(b.replicate), 0)\n        self.assertEqual(int(b.nearest), 0)\n        self.assertEqual(int(b.dct1), 1)\n        self.assertEqual(int(b.mirror), 1)\n        
self.assertEqual(int(b.dct2), 2)\n        self.assertEqual(int(b.reflect), 2)\n        self.assertEqual(int(b.dst1), 3)\n        self.assertEqual(int(b.antimirror), 3)\n        self.assertEqual(int(b.dst2), 4)\n        self.assertEqual(int(b.antireflect), 4)\n        self.assertEqual(int(b.dft), 5)\n        self.assertEqual(int(b.wrap), 5)\n        self.assertEqual(int(b.zero), 7)\n\n    def test_interp(self):\n        self.assertEqual(str(p.nearest), \"InterpolationType.nearest\")\n        self.assertEqual(str(p.linear), \"InterpolationType.linear\")\n        self.assertEqual(str(p.quadratic), \"InterpolationType.quadratic\")\n        self.assertEqual(str(p.cubic), \"InterpolationType.cubic\")\n        self.assertEqual(str(p.fourth), \"InterpolationType.fourth\")\n        self.assertEqual(str(p.fifth), \"InterpolationType.fifth\")\n        self.assertEqual(str(p.sixth), \"InterpolationType.sixth\")\n        self.assertEqual(str(p.seventh), \"InterpolationType.seventh\")\n\n        self.assertEqual(int(p.nearest), 0)\n        self.assertEqual(int(p.linear), 1)\n        self.assertEqual(int(p.quadratic), 2)\n        self.assertEqual(int(p.cubic), 3)\n        self.assertEqual(int(p.fourth), 4)\n        self.assertEqual(int(p.fifth), 5)\n        self.assertEqual(int(p.sixth), 6)\n        self.assertEqual(int(p.seventh), 7)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_evenly_divisible_all_gather_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nimport torch.distributed as dist\n\nfrom monai.utils import evenly_divisible_all_gather\nfrom tests.test_utils import DistCall, DistTestCase, assert_allclose, skip_if_windows\n\n\n@skip_if_windows\nclass DistributedEvenlyDivisibleAllGather(DistTestCase):\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_data(self):\n        self._run()\n\n    def _run(self):\n        # if dist.get_rank() == 0\n        data1 = torch.tensor([[1, 2], [3, 4]])\n        data2 = torch.tensor([[1.0, 2.0]])\n        data3 = torch.tensor(7)\n\n        if dist.get_rank() == 1:\n            data1 = torch.tensor([[5, 6]])\n            data2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]])\n            data3 = torch.tensor(8)\n\n        result1 = evenly_divisible_all_gather(data=data1, concat=True)\n        assert_allclose(result1, torch.tensor([[1, 2], [3, 4], [5, 6]]))\n        result2 = evenly_divisible_all_gather(data=data2, concat=False)\n        for r, e in zip(result2, [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0], [5.0, 6.0]])]):\n            assert_allclose(r, e)\n        result3 = evenly_divisible_all_gather(data=data3, concat=False)\n        for r in result3:\n            self.assertEqual(r.ndimension(), 0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_get_package_version.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.utils.module import get_package_version\n\n\nclass TestGetVersion(unittest.TestCase):\n\n    def test_default(self):\n        output = get_package_version(\"42foobarnoexist\")\n        self.assertIn(\"UNKNOWN\", output)\n\n        output = get_package_version(\"numpy\")\n        self.assertNotIn(\"UNKNOWN\", output)\n\n    def test_msg(self):\n        output = get_package_version(\"42foobarnoexist\", \"test\")\n        self.assertIn(\"test\", output)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_handler_logfile.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport torch\n\nfrom monai.utils import optional_import\nfrom tests.test_utils import SkipIfNoModule\n\ntry:\n    _, has_ignite = optional_import(\"ignite\")\n    from ignite.engine import Engine\n\n    from monai.handlers import LogfileHandler\nexcept ImportError:\n    has_ignite = False\n\n\nclass TestHandlerLogfile(unittest.TestCase):\n    def setUp(self):\n        if has_ignite:\n            # set up engine\n            def _train_func(engine, batch):\n                return torch.tensor(0.0)\n\n            self.engine = Engine(_train_func)\n\n            logger = self.engine.logger\n\n            # remove all other handlers to prevent output\n            while logger is not None:\n                del logger.handlers[:]\n                logger = logger.parent\n\n    @SkipIfNoModule(\"ignite\")\n    def test_logfile(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            handler = LogfileHandler(output_dir=tempdir)\n            handler.attach(self.engine)\n\n            self.engine.run(range(3))\n\n            self.assertTrue(os.path.isfile(os.path.join(tempdir, \"log.txt\")))\n\n    @SkipIfNoModule(\"ignite\")\n    def test_filename(self):\n        filename = \"something_else.txt\"\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            handler = 
LogfileHandler(output_dir=tempdir, filename=filename)\n            handler.attach(self.engine)\n\n            self.engine.run(range(3))\n\n            self.assertTrue(os.path.isfile(os.path.join(tempdir, filename)))\n\n    @SkipIfNoModule(\"ignite\")\n    def test_createdir(self):\n        with tempfile.TemporaryDirectory() as tempdir:\n            output_dir = os.path.join(tempdir, \"new_dir\")\n\n            handler = LogfileHandler(output_dir=output_dir)\n            handler.attach(self.engine)\n\n            self.engine.run(range(3))\n\n            self.assertTrue(os.path.isfile(os.path.join(output_dir, \"log.txt\")))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_handler_metric_logger.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\n\nfrom monai.utils import optional_import\nfrom tests.test_utils import SkipIfNoModule\n\ntry:\n    _, has_ignite = optional_import(\"ignite\")\n    from ignite.engine import Engine, Events\n\n    from monai.handlers import MetricLogger\nexcept ImportError:\n    has_ignite = False\n\n\nclass TestHandlerMetricLogger(unittest.TestCase):\n    @SkipIfNoModule(\"ignite\")\n    def test_metric_logging(self):\n        dummy_name = \"dummy\"\n\n        # set up engine\n        def _train_func(engine, batch):\n            return torch.tensor(0.0)\n\n        engine = Engine(_train_func)\n\n        # set up dummy metric\n        @engine.on(Events.EPOCH_COMPLETED)\n        def _update_metric(engine):\n            engine.state.metrics[dummy_name] = 1\n\n        # set up testing handler\n        handler = MetricLogger(loss_transform=lambda output: output.item())\n        handler.attach(engine)\n\n        engine.run(range(3), max_epochs=2)\n\n        expected_loss = [(1, 0.0), (2, 0.0), (3, 0.0), (4, 0.0), (5, 0.0), (6, 0.0)]\n        expected_metric = [(4, 1), (5, 1), (6, 1)]\n\n        self.assertSetEqual({dummy_name}, set(handler.metrics))\n\n        self.assertListEqual(expected_loss, handler.loss)\n        self.assertListEqual(expected_metric, handler.metrics[dummy_name])\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/utils/test_list_to_dict.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.utils import list_to_dict\n\nTEST_CASE_1 = [[\"a=1\", \"b=2\", \"c=3\", \"d=4\"], {\"a\": 1, \"b\": 2, \"c\": 3, \"d\": 4}]\n\nTEST_CASE_2 = [[\"a=a\", \"b=b\", \"c=c\", \"d=d\"], {\"a\": \"a\", \"b\": \"b\", \"c\": \"c\", \"d\": \"d\"}]\n\nTEST_CASE_3 = [[\"a=0.1\", \"b=0.2\", \"c=0.3\", \"d=0.4\"], {\"a\": 0.1, \"b\": 0.2, \"c\": 0.3, \"d\": 0.4}]\n\nTEST_CASE_4 = [[\"a=True\", \"b=TRUE\", \"c=false\", \"d=FALSE\"], {\"a\": True, \"b\": True, \"c\": False, \"d\": False}]\n\nTEST_CASE_5 = [\n    [\"a='1'\", \"b=2 \", \" c = 3\", \"d='test'\", \"'e'=0\", \"f\", \"g=None\"],\n    {\"a\": 1, \"b\": 2, \"c\": 3, \"d\": \"test\", \"e\": 0, \"f\": None, \"g\": None},\n]\n\n\nclass TestListToDict(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])\n    def test_value_shape(self, input, output):\n        result = list_to_dict(input)\n        self.assertDictEqual(result, output)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_look_up_option.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom enum import Enum\n\nfrom parameterized import parameterized\n\nfrom monai.utils import StrEnum, look_up_option\n\n\nclass _CaseEnum(Enum):\n    CONST = \"constant\"\n    EMPTY = \"empty\"\n\n\nclass _CaseEnum1(Enum):\n    CONST = \"constant\"\n    EMPTY = \"empty\"\n\n\nclass _CaseStrEnum(StrEnum):\n    MODE_A = \"A\"\n    MODE_B = \"B\"\n\n\nTEST_CASES = (\n    (\"test\", (\"test\", \"test1\"), \"test\"),\n    (\"test1\", {\"test1\", \"test\"}, \"test1\"),\n    (2, {1: \"test\", 2: \"valid\"}, \"valid\"),\n    (_CaseEnum.EMPTY, _CaseEnum, _CaseEnum.EMPTY),\n    (\"empty\", _CaseEnum, _CaseEnum.EMPTY),\n)\n\n\nclass TestLookUpOption(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_look_up(self, input_str, supported, expected):\n        output = look_up_option(input_str, supported)\n        self.assertEqual(output, expected)\n\n    def test_default(self):\n        output = look_up_option(\"not here\", {\"a\", \"b\"}, default=None)\n        self.assertEqual(output, None)\n\n    def test_str_enum(self):\n        output = look_up_option(\"C\", {\"A\", \"B\"}, default=None)\n        self.assertIsNone(output)\n        self.assertEqual(list(_CaseStrEnum), [\"A\", \"B\"])\n        self.assertEqual(_CaseStrEnum.MODE_A, \"A\")\n        self.assertEqual(str(_CaseStrEnum.MODE_A), \"A\")\n        
self.assertEqual(look_up_option(\"A\", _CaseStrEnum), \"A\")\n\n    def test_no_found(self):\n        with self.assertRaisesRegex(ValueError, \"Unsupported\"):\n            look_up_option(\"not here\", {\"a\", \"b\"})\n        with self.assertRaisesRegex(ValueError, \"Unsupported\"):\n            look_up_option(\"not here\", [\"a\", \"b\"])\n        with self.assertRaisesRegex(ValueError, \"Unsupported\"):\n            look_up_option(\"not here\", {\"a\": 1, \"b\": 2})\n        with self.assertRaisesRegex(ValueError, \"did you mean\"):\n            look_up_option(3, {1: \"a\", 2: \"b\", \"c\": 3})\n        with self.assertRaisesRegex(ValueError, \"did.*empty\"):\n            look_up_option(\"empy\", _CaseEnum)\n        with self.assertRaisesRegex(ValueError, \"Unsupported\"):\n            look_up_option(_CaseEnum1.EMPTY, _CaseEnum)\n        with self.assertRaisesRegex(ValueError, \"Unsupported\"):\n            look_up_option(None, _CaseEnum)\n        with self.assertRaisesRegex(ValueError, \"No\"):\n            look_up_option(None, None)\n        with self.assertRaisesRegex(ValueError, \"No\"):\n            look_up_option(\"test\", None)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_optional_import.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.utils import OptionalImportError, exact_version, optional_import\n\n\nclass TestOptionalImport(unittest.TestCase):\n\n    @parameterized.expand([\"not_a_module\", \"torch.randint\"])\n    def test_default(self, import_module):\n        my_module, flag = optional_import(import_module)\n        self.assertFalse(flag)\n        with self.assertRaises(OptionalImportError):\n            my_module.test\n\n    def test_import_valid(self):\n        my_module, flag = optional_import(\"torch\")\n        self.assertTrue(flag)\n        print(my_module.randint(1, 2, (1, 2)))\n\n    def test_import_wrong_number(self):\n        my_module, flag = optional_import(\"torch\", \"42\")\n        with self.assertRaisesRegex(OptionalImportError, \"version\"):\n            my_module.nn\n        self.assertFalse(flag)\n        with self.assertRaisesRegex(OptionalImportError, \"version\"):\n            my_module.randint(1, 2, (1, 2))\n        with self.assertRaisesRegex(ValueError, \"invalid literal\"):\n            my_module, flag = optional_import(\"torch\", \"test\")  # version should be number.number\n            my_module.nn\n            self.assertTrue(flag)\n            print(my_module.randint(1, 2, (1, 2)))\n\n    @parameterized.expand([\"0\", \"0.0.0.1\", \"1.1.0\"])\n    def 
test_import_good_number(self, version_number):\n        my_module, flag = optional_import(\"torch\", version_number)\n        my_module.nn\n        self.assertTrue(flag)\n        print(my_module.randint(1, 2, (1, 2)))\n\n    def test_import_exact(self):\n        my_module, flag = optional_import(\"torch\", \"0\", exact_version)\n        with self.assertRaisesRegex(OptionalImportError, \"exact_version\"):\n            my_module.nn\n        self.assertFalse(flag)\n        with self.assertRaisesRegex(OptionalImportError, \"exact_version\"):\n            my_module.randint(1, 2, (1, 2))\n\n    def test_import_method(self):\n        nn, flag = optional_import(\"torch\", \"1.1\", name=\"nn\")\n        self.assertTrue(flag)\n        print(nn.functional)\n\n    def test_additional(self):\n        test_args = {\"a\": \"test\", \"b\": \"test\"}\n\n        def versioning(module, ver, a):\n            self.assertEqual(a, test_args)\n            return True\n\n        nn, flag = optional_import(\"torch\", \"1.1\", version_checker=versioning, name=\"nn\", version_args=test_args)\n        self.assertTrue(flag)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_pad_mode.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.transforms import CastToType, Pad\nfrom monai.utils import NumpyPadMode, PytorchPadMode\nfrom tests.test_utils import dict_product\n\n\nclass TestPadMode(unittest.TestCase):\n    def test_pad(self):\n        expected_shapes = {3: (1, 15, 10), 4: (1, 10, 6, 7)}\n        devices = (\"cuda:0\", \"cpu\") if torch.cuda.is_available() else (\"cpu\",)\n        shapes = ((1, 10, 10), (1, 5, 6, 7))\n        types = (float, int, np.uint8, np.int16, np.float32, bool)\n        modes = list(PytorchPadMode) + list(NumpyPadMode)\n\n        for params in dict_product(t=types, d=devices, s=shapes, m=modes):\n            t = params[\"t\"]\n            d = params[\"d\"]\n            s = params[\"s\"]\n            m = params[\"m\"]\n            a = torch.rand(s)\n            to_pad = [(0, 0), (2, 3)] if len(s) == 3 else [(0, 0), (2, 3), (0, 0), (0, 0)]\n            out = Pad(to_pad=to_pad, mode=m)(CastToType(dtype=t)(a).to(d))\n            self.assertEqual(out.shape, expected_shapes[len(s)])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_profiling.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport datetime\nimport os\nimport unittest\nfrom io import StringIO\n\nimport torch\n\nimport monai.transforms as mt\nfrom monai.data import Dataset, ThreadDataLoader\nfrom monai.utils import first, optional_import\nfrom monai.utils.enums import CommonKeys\nfrom monai.utils.profiling import ProfileHandler, ProfileResult, WorkflowProfiler\nfrom tests.test_utils import SkipIfNoModule\n\npd, _ = optional_import(\"pandas\")\n\n\nclass TestWorkflowProfiler(unittest.TestCase):\n    def setUp(self):\n        super().setUp()\n\n        self.scale = mt.ScaleIntensity()\n        self.scale_call_name = \"ScaleIntensity.__call__\"\n        self.compose_call_name = \"Compose.__call__\"\n        self.test_comp = mt.Compose([mt.ScaleIntensity(), mt.RandAxisFlip(0.5)])\n        self.test_image = torch.rand(1, 16, 16, 16)\n        self.pid = os.getpid()\n\n    def test_empty(self):\n        \"\"\"Test that the profiler correctly produces an empty result when nothing happens in a context.\"\"\"\n        wp = WorkflowProfiler()\n\n        with wp:\n            pass\n\n        self.assertEqual(wp.get_results(), {})\n\n    def test_profile_transforms(self):\n        \"\"\"Test basic reporting when invoking a single transform directly.\"\"\"\n        with WorkflowProfiler() as wp:\n            self.scale(self.test_image)\n\n        results = 
wp.get_results()\n        self.assertSequenceEqual(list(results), [self.scale_call_name])\n\n        prs = results[self.scale_call_name]\n\n        self.assertEqual(len(prs), 1)\n\n        pr = prs[0]\n\n        self.assertIsInstance(pr, ProfileResult)\n        self.assertEqual(pr.name, self.scale_call_name)\n        self.assertEqual(pr.pid, self.pid)\n        self.assertGreater(pr.time, 0)\n\n        dt = datetime.datetime.fromisoformat(pr.timestamp)\n\n        self.assertIsInstance(dt, datetime.datetime)\n\n    def test_profile_multithread(self):\n        \"\"\"Test results are gathered from multiple threads using ThreadDataLoader.\"\"\"\n        ds = Dataset([self.test_image] * 4, self.scale)\n        dl = ThreadDataLoader(ds, batch_size=4, num_workers=4, use_thread_workers=True)\n\n        with WorkflowProfiler() as wp:\n            batch = first(dl)\n\n        self.assertSequenceEqual(batch.shape, (4, 1, 16, 16, 16))\n\n        results = wp.get_results()\n        self.assertSequenceEqual(list(results), [self.scale_call_name, self.compose_call_name])\n\n        prs = results[self.scale_call_name]\n\n        self.assertEqual(len(prs), 4)\n\n    def test_profile_context(self):\n        \"\"\"Test results from profiling contexts with the same name accumulate correctly.\"\"\"\n        with WorkflowProfiler() as wp:\n            with wp.profile_ctx(\"context\"):\n                self.scale(self.test_image)\n\n            with wp.profile_ctx(\"context\"):\n                self.scale(self.test_image)\n\n        results = wp.get_results()\n\n        self.assertSequenceEqual(set(results), {\"ScaleIntensity.__call__\", \"context\"})\n\n        prs = results[\"context\"]\n\n        self.assertEqual(len(prs), 2)\n\n    def test_profile_callable(self):\n        \"\"\"Test profiling functions with default or set names.\"\"\"\n\n        def funca():\n            pass\n\n        with WorkflowProfiler() as wp:\n            funca = wp.profile_callable()(funca)\n\n            
funca()\n\n            @wp.profile_callable(\"funcb\")\n            def _func():\n                pass\n\n            _func()\n            _func()\n\n        results = wp.get_results()\n        self.assertSequenceEqual(set(results), {\"funca\", \"funcb\"})\n\n        self.assertEqual(len(results[\"funca\"]), 1)\n        self.assertEqual(len(results[\"funcb\"]), 2)\n\n    def test_profile_iteration(self):\n        \"\"\"Test iterables are profiled correctly, producing the right output and number of results.\"\"\"\n        with WorkflowProfiler() as wp:\n            range_vals = []\n\n            for i in wp.profile_iter(\"range5\", range(5)):\n                range_vals.append(i)\n\n            self.assertSequenceEqual(range_vals, list(range(5)))\n\n        results = wp.get_results()\n        self.assertSequenceEqual(set(results), {\"range5\"})\n\n        self.assertEqual(len(results[\"range5\"]), 5)\n\n    def test_times_summary(self):\n        \"\"\"Test generating the summary report dictionary.\"\"\"\n        with WorkflowProfiler() as wp:\n            self.scale(self.test_image)\n\n        tsum = wp.get_times_summary()\n\n        self.assertSequenceEqual(list(tsum), [self.scale_call_name])\n\n        times = tsum[self.scale_call_name]\n\n        self.assertEqual(len(times), 6)\n        self.assertEqual(times[0], 1)\n\n    @SkipIfNoModule(\"pandas\")\n    def test_times_summary_pd(self):\n        \"\"\"Test generating the Pandas result works if Pandas is present.\"\"\"\n        with WorkflowProfiler() as wp:\n            self.scale(self.test_image)\n\n        df = wp.get_times_summary_pd()\n\n        self.assertIsInstance(df, pd.DataFrame)\n\n    def test_csv_dump(self):\n        \"\"\"Test dumping the results to csv file in a local StringIO object.\"\"\"\n        with WorkflowProfiler() as wp:\n            self.scale(self.test_image)\n\n        sio = StringIO()\n        wp.dump_csv(sio)\n        self.assertGreater(sio.tell(), 0)\n\n    
@SkipIfNoModule(\"ignite\")\n    def test_handler(self):\n        \"\"\"Test profiling Engine objects works if Ignite is present.\"\"\"\n        from ignite.engine import Events\n\n        from monai.engines import SupervisedTrainer\n\n        net = torch.nn.Conv2d(1, 1, 3, padding=1)\n        im = torch.rand(1, 1, 16, 16)\n\n        with WorkflowProfiler(None) as wp:\n            trainer = SupervisedTrainer(\n                device=torch.device(\"cpu\"),\n                max_epochs=2,\n                train_data_loader=[{CommonKeys.IMAGE: im, CommonKeys.LABEL: im}] * 3,\n                epoch_length=3,\n                network=net,\n                optimizer=torch.optim.Adam(net.parameters()),\n                loss_function=torch.nn.L1Loss(),\n            )\n\n            _ = ProfileHandler(\"Epoch\", wp, Events.EPOCH_STARTED, Events.EPOCH_COMPLETED).attach(trainer)\n\n            trainer.run()\n\n        results = wp.get_results()\n\n        self.assertSequenceEqual(set(results), {\"Epoch\"})\n        self.assertEqual(len(results[\"Epoch\"]), 2)\n"
  },
  {
    "path": "tests/utils/test_rankfilter_dist.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport logging\nimport os\nimport tempfile\nimport unittest\n\nimport torch.distributed as dist\n\nfrom monai.utils import RankFilter\nfrom tests.test_utils import DistCall, DistTestCase\n\n\nclass DistributedRankFilterTest(DistTestCase):\n    def setUp(self):\n        self.log_dir = tempfile.TemporaryDirectory()\n\n    @DistCall(nnodes=1, nproc_per_node=2)\n    def test_rankfilter(self):\n        logger = logging.getLogger(__name__)\n        log_filename = os.path.join(self.log_dir.name, \"records.log\")\n        h1 = logging.FileHandler(filename=log_filename)\n        h1.setLevel(logging.WARNING)\n\n        logger.addHandler(h1)\n\n        logger.addFilter(RankFilter())\n        logger.warning(\"test_warnings\")\n\n        dist.barrier()\n        if dist.get_rank() == 0:\n            with open(log_filename) as file:\n                lines = [line.rstrip() for line in file]\n            log_message = \" \".join(lines)\n            self.assertEqual(log_message.count(\"test_warnings\"), 1)\n\n    def tearDown(self) -> None:\n        self.log_dir.cleanup()\n\n\nclass SingleRankFilterTest(unittest.TestCase):\n    def tearDown(self) -> None:\n        self.log_dir.cleanup()\n\n    def setUp(self):\n        self.log_dir = tempfile.TemporaryDirectory()\n\n    def test_rankfilter_single_proc(self):\n        logger = 
logging.getLogger(__name__)\n        log_filename = os.path.join(self.log_dir.name, \"records_sp.log\")\n        h1 = logging.FileHandler(filename=log_filename)\n        h1.setLevel(logging.WARNING)\n        logger.addHandler(h1)\n        logger.addFilter(RankFilter())\n        logger.warning(\"test_warnings\")\n\n        with open(log_filename) as file:\n            lines = [line.rstrip() for line in file]\n        logger.removeHandler(h1)\n        h1.close()\n        log_message = \" \".join(lines)\n        self.assertEqual(log_message.count(\"test_warnings\"), 1)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_require_pkg.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom monai.utils import OptionalImportError, min_version, require_pkg\n\n\nclass TestRequirePkg(unittest.TestCase):\n\n    def test_class(self):\n\n        @require_pkg(pkg_name=\"torch\", version=\"1.4\", version_checker=min_version)\n        class TestClass:\n            pass\n\n        TestClass()\n\n    def test_function(self):\n\n        @require_pkg(pkg_name=\"torch\", version=\"1.4\", version_checker=min_version)\n        def test_func(x):\n            return x\n\n        test_func(x=None)\n\n    def test_warning(self):\n\n        @require_pkg(pkg_name=\"test123\", raise_error=False)\n        def test_func(x):\n            return x\n\n        test_func(x=None)\n\n    def test_class_exception(self):\n        with self.assertRaises(OptionalImportError):\n\n            @require_pkg(pkg_name=\"test123\")\n            class TestClass:\n                pass\n\n            TestClass()\n\n    def test_class_version_exception(self):\n        with self.assertRaises(OptionalImportError):\n\n            @require_pkg(pkg_name=\"torch\", version=\"10000\", version_checker=min_version)\n            class TestClass:\n                pass\n\n            TestClass()\n\n    def test_func_exception(self):\n        with self.assertRaises(OptionalImportError):\n\n            @require_pkg(pkg_name=\"test123\")\n            def 
test_func(x):\n                return x\n\n            test_func(x=None)\n\n    def test_func_versions_exception(self):\n        with self.assertRaises(OptionalImportError):\n\n            @require_pkg(pkg_name=\"torch\", version=\"10000\", version_checker=min_version)\n            def test_func(x):\n                return x\n\n            test_func(x=None)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_sample_slices.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.utils import sample_slices\nfrom tests.test_utils import TEST_NDARRAYS, assert_allclose\n\n# test data[:, [1, ], ...]\nTEST_CASE_1 = [torch.tensor([[[0, 2], [1, 0]]]), 1, True, (1,), torch.tensor([[[1, 0]]])]\n# test data[:, [0, 2], ...]\nTEST_CASE_2 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, True, (0, 2), torch.tensor([[[0, 2], [4, 5]]])]\n# test data[:, [0: 2], ...]\nTEST_CASE_3 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (0, 2), torch.tensor([[[0, 2], [1, 0]]])]\n# test data[:, [1: ], ...]\nTEST_CASE_4 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (1, None), torch.tensor([[[1, 0], [4, 5]]])]\n# test data[:, [0: 3: 2], ...]\nTEST_CASE_5 = [torch.tensor([[[0, 2], [1, 0], [4, 5]]]), 1, False, (0, 3, 2), torch.tensor([[[0, 2], [4, 5]]])]\n\n\nclass TestSampleSlices(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])\n    def test_shape(self, input_data, dim, as_indices, vals, expected_result):\n        for p in TEST_NDARRAYS:\n            result = sample_slices(p(input_data), dim, as_indices, *vals)\n            assert_allclose(p(expected_result), result)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_set_determinism.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\n\nfrom monai.utils import get_seed, set_determinism\nfrom tests.test_utils import skip_if_no_cuda\n\n\nclass TestSetDeterminism(unittest.TestCase):\n    def test_values(self):\n        # check system default flags\n        set_determinism(None)\n        self.assertTrue(not torch.backends.cudnn.deterministic)\n        self.assertTrue(get_seed() is None)\n        # set default seed\n        set_determinism()\n        self.assertTrue(get_seed() is not None)\n        self.assertTrue(torch.backends.cudnn.deterministic)\n        self.assertTrue(not torch.backends.cudnn.benchmark)\n        # resume default\n        set_determinism(None)\n        self.assertTrue(not torch.backends.cudnn.deterministic)\n        self.assertTrue(not torch.backends.cudnn.benchmark)\n        self.assertTrue(get_seed() is None)\n        # test seeds\n        seed = 255\n        set_determinism(seed=seed)\n        self.assertEqual(seed, get_seed())\n        a = np.random.randint(seed)\n        b = torch.randint(seed, (1,))\n\n        # test when global flag support is disabled\n        torch.backends.disable_global_flags()\n        set_determinism(seed=seed)\n        c = np.random.randint(seed)\n        d = torch.randint(seed, (1,))\n        self.assertEqual(a, c)\n        self.assertEqual(b, d)\n        
self.assertTrue(torch.backends.cudnn.deterministic)\n        self.assertTrue(not torch.backends.cudnn.benchmark)\n        set_determinism(seed=None)\n\n\nclass TestSetFlag(unittest.TestCase):\n    def setUp(self):\n        set_determinism(1, use_deterministic_algorithms=True)\n\n    @skip_if_no_cuda\n    def test_algo_not_deterministic(self):\n        \"\"\"\n        Test `avg_pool3d_backward_cuda` correctly raises an exception since it lacks a deterministic implementation.\n        \"\"\"\n        with self.assertRaises(RuntimeError):\n            x = torch.randn(20, 16, 50, 44, 31, requires_grad=True, device=\"cuda:0\")\n            y = torch.nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2))(x)\n            y.sum().backward()\n\n    @skip_if_no_cuda\n    def test_algo_cublas_env(self):\n        \"\"\"\n        Test `torch.mm` does not raise an exception with the CUBLAS_WORKSPACE_CONFIG environment variable correctly set.\n        \"\"\"\n        x = torch.rand(5, 5, device=\"cuda:0\")\n        _ = torch.mm(x, x)\n\n    def tearDown(self):\n        set_determinism(None, use_deterministic_algorithms=False)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_squeeze_unsqueeze.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.utils import unsqueeze_left, unsqueeze_right\n\nRIGHT_CASES = [\n    (np.random.rand(3, 4).astype(np.float32), 5, (3, 4, 1, 1, 1)),\n    (torch.rand(3, 4).type(torch.float32), 5, (3, 4, 1, 1, 1)),\n    (np.random.rand(3, 4).astype(np.float64), 5, (3, 4, 1, 1, 1)),\n    (torch.rand(3, 4).type(torch.float64), 5, (3, 4, 1, 1, 1)),\n    (np.random.rand(3, 4).astype(np.int32), 5, (3, 4, 1, 1, 1)),\n    (torch.rand(3, 4).type(torch.int32), 5, (3, 4, 1, 1, 1)),\n]\n\nLEFT_CASES = [\n    (np.random.rand(3, 4).astype(np.float32), 5, (1, 1, 1, 3, 4)),\n    (torch.rand(3, 4).type(torch.float32), 5, (1, 1, 1, 3, 4)),\n    (np.random.rand(3, 4).astype(np.float64), 5, (1, 1, 1, 3, 4)),\n    (torch.rand(3, 4).type(torch.float64), 5, (1, 1, 1, 3, 4)),\n    (np.random.rand(3, 4).astype(np.int32), 5, (1, 1, 1, 3, 4)),\n    (torch.rand(3, 4).type(torch.int32), 5, (1, 1, 1, 3, 4)),\n]\nALL_CASES = [\n    (np.random.rand(3, 4), 2, (3, 4)),\n    (np.random.rand(3, 4), 0, (3, 4)),\n    (np.random.rand(3, 4), -1, (3, 4)),\n    (np.array(3), 4, (1, 1, 1, 1)),\n    (np.array(3), 0, ()),\n    (np.random.rand(3, 4).astype(np.int32), 2, (3, 4)),\n    (np.random.rand(3, 4).astype(np.int32), 0, (3, 4)),\n    (np.random.rand(3, 4).astype(np.int32), -1, (3, 
4)),\n    (np.array(3).astype(np.int32), 4, (1, 1, 1, 1)),\n    (np.array(3).astype(np.int32), 0, ()),\n    (torch.rand(3, 4), 2, (3, 4)),\n    (torch.rand(3, 4), 0, (3, 4)),\n    (torch.rand(3, 4), -1, (3, 4)),\n    (torch.tensor(3), 4, (1, 1, 1, 1)),\n    (torch.tensor(3), 0, ()),\n    (torch.rand(3, 4).type(torch.int32), 2, (3, 4)),\n    (torch.rand(3, 4).type(torch.int32), 0, (3, 4)),\n    (torch.rand(3, 4).type(torch.int32), -1, (3, 4)),\n    (torch.tensor(3).type(torch.int32), 4, (1, 1, 1, 1)),\n    (torch.tensor(3).type(torch.int32), 0, ()),\n]\n\n\nclass TestUnsqueeze(unittest.TestCase):\n\n    @parameterized.expand(RIGHT_CASES + ALL_CASES)\n    def test_unsqueeze_right(self, arr, ndim, shape):\n        self.assertEqual(unsqueeze_right(arr, ndim).shape, shape)\n\n    @parameterized.expand(LEFT_CASES + ALL_CASES)\n    def test_unsqueeze_left(self, arr, ndim, shape):\n        self.assertEqual(unsqueeze_left(arr, ndim).shape, shape)\n"
  },
  {
    "path": "tests/utils/test_state_cacher.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport pickle\nimport unittest\nfrom os.path import exists, join\nfrom pathlib import Path\nfrom tempfile import gettempdir\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.utils import StateCacher\n\nDEVICE = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n\nTEST_CASE_0 = [torch.Tensor([1]).to(DEVICE), {\"in_memory\": True}]\nTEST_CASE_1 = [\n    torch.Tensor([1]).to(DEVICE),\n    {\n        \"in_memory\": False,\n        \"cache_dir\": gettempdir(),\n        \"pickle_module\": None,\n        # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility\n        \"pickle_protocol\": torch.serialization.DEFAULT_PROTOCOL,\n    },\n]\nTEST_CASE_2 = [torch.Tensor([1]).to(DEVICE), {\"in_memory\": False, \"allow_overwrite\": False}]\nTEST_CASE_3 = [torch.Tensor([1]).to(DEVICE), {\"in_memory\": False, \"cache_dir\": Path(gettempdir())}]\n\nTEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]\n\n\nclass TestStateCacher(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_state_cacher(self, data_obj, params):\n        key = \"data_obj\"\n\n        state_cacher = StateCacher(**params)\n        # store it\n        state_cacher.store(key, data_obj, pickle_module=pickle)\n        # create clone then modify original\n        data_obj_orig 
= data_obj.clone()\n        data_obj += 1\n        # Restore and check nothing has changed\n        data_obj_restored = state_cacher.retrieve(key)\n        self.assertEqual(data_obj_orig, data_obj_restored)\n\n        # If not allow overwrite, check an attempt would raise exception\n        if \"allow_overwrite\" in params and params[\"allow_overwrite\"]:\n            with self.assertRaises(RuntimeError):\n                state_cacher.store(key, data_obj)\n\n        # If using a cache dir, check file has been deleted at end\n        if \"cache_dir\" in params:\n            i = id(state_cacher)\n            del state_cacher\n            self.assertFalse(exists(join(params[\"cache_dir\"], f\"state_{key}_{i}.pt\")))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_torchscript_utils.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\n\nimport torch\n\nfrom monai.config import get_config_values\nfrom monai.data import load_net_with_metadata, save_net_with_metadata\nfrom monai.utils import JITMetadataKeys\n\n\nclass TestModule(torch.nn.Module):\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n    def forward(self, x):\n        return x + 10\n\n\nclass TestTorchscript(unittest.TestCase):\n\n    def test_save_net_with_metadata(self):\n        \"\"\"Save a network without metadata to a file.\"\"\"\n        m = torch.jit.script(TestModule())\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            save_net_with_metadata(m, f\"{tempdir}/test\")\n\n            self.assertTrue(os.path.isfile(f\"{tempdir}/test.ts\"))\n\n    def test_save_net_with_metadata_ext(self):\n        \"\"\"Save a network without metadata to a file.\"\"\"\n        m = torch.jit.script(TestModule())\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            save_net_with_metadata(m, f\"{tempdir}/test.zip\")\n\n            self.assertTrue(os.path.isfile(f\"{tempdir}/test.zip\"))\n\n    def test_save_net_with_metadata_with_extra(self):\n        \"\"\"Save a network with simple metadata to a file.\"\"\"\n        m = torch.jit.script(TestModule())\n\n        test_metadata = {\"foo\": [1, 2], 
\"bar\": \"string\"}\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            save_net_with_metadata(m, f\"{tempdir}/test\", meta_values=test_metadata)\n\n            self.assertTrue(os.path.isfile(f\"{tempdir}/test.ts\"))\n\n    def test_load_net_with_metadata(self):\n        \"\"\"Save then load a network with no metadata or other extra files.\"\"\"\n        m = torch.jit.script(TestModule())\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            save_net_with_metadata(m, f\"{tempdir}/test\")\n            _, meta, extra_files = load_net_with_metadata(f\"{tempdir}/test.ts\")\n\n        del meta[JITMetadataKeys.TIMESTAMP.value]  # no way of knowing precisely what this value would be\n\n        self.assertEqual(meta, get_config_values())\n        self.assertEqual(extra_files, {})\n\n    def test_load_net_with_metadata_with_extra(self):\n        \"\"\"Save then load a network with basic metadata.\"\"\"\n        m = torch.jit.script(TestModule())\n\n        test_metadata = {\"foo\": [1, 2], \"bar\": \"string\"}\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            save_net_with_metadata(m, f\"{tempdir}/test\", meta_values=test_metadata)\n            _, meta, extra_files = load_net_with_metadata(f\"{tempdir}/test.ts\")\n\n        del meta[JITMetadataKeys.TIMESTAMP.value]  # no way of knowing precisely what this value would be\n\n        test_compare = get_config_values()\n        test_compare.update(test_metadata)\n\n        self.assertEqual(meta, test_compare)\n        self.assertEqual(extra_files, {})\n\n    def test_save_load_more_extra_files(self):\n        \"\"\"Save then load extra file data from a torchscript file.\"\"\"\n        m = torch.jit.script(TestModule())\n\n        test_metadata = {\"foo\": [1, 2], \"bar\": \"string\"}\n\n        more_extra_files = {\"test.txt\": b\"This is test data\"}\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            save_net_with_metadata(m, f\"{tempdir}/test\", 
meta_values=test_metadata, more_extra_files=more_extra_files)\n\n            self.assertTrue(os.path.isfile(f\"{tempdir}/test.ts\"))\n\n            _, _, loaded_extra_files = load_net_with_metadata(f\"{tempdir}/test.ts\", more_extra_files=(\"test.txt\",))\n\n            self.assertEqual(more_extra_files[\"test.txt\"], loaded_extra_files[\"test.txt\"])\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_version.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport itertools\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.utils import version_geq, version_leq\n\n\n# from pkg_resources\ndef _pairwise(iterable):\n    \"s -> (s0,s1), (s1,s2), (s2, s3), ...\"\n    a, b = itertools.tee(iterable)\n    next(b, None)\n    return zip(a, b)\n\n\n# from pkg_resources\ntorture = \"\"\"\n    0.80.1-3 0.80.1-2 0.80.1-1 0.79.9999+0.80.0pre4-1\n    0.79.9999+0.80.0pre2-3 0.79.9999+0.80.0pre2-2\n    0.77.2-1 0.77.1-1 0.77.0-1\n    \"\"\"\n\nTEST_CASES = (\n    (\"1.6.0\", \"1.6.0\"),\n    (\"1.6.0a0+9907a3e\", \"1.6.0\"),\n    (\"0+unknown\", \"0.6\"),\n    (\"ab\", \"abc\"),\n    (\"0.6rc1\", \"0.6\"),\n    (\"0.6\", \"0.7\"),\n    (\"1.2.a\", \"1.2a\"),\n    (\"1.2-rc1\", \"1.2rc1\"),\n    (\"0.4\", \"0.4.0\"),\n    (\"0.4.0.0\", \"0.4.0\"),\n    (\"0.4.0-0\", \"0.4-0\"),\n    (\"0post1\", \"0.0post1\"),\n    (\"0pre1\", \"0.0c1\"),\n    (\"0.0.0preview1\", \"0c1\"),\n    (\"0.0c1\", \"0-rc1\"),\n    (\"1.2a1\", \"1.2.a.1\"),\n    (\"1.2.a\", \"1.2a\"),\n    (\"2.1\", \"2.1.1\"),\n    (\"2a1\", \"2b0\"),\n    (\"2a1\", \"2.1\"),\n    (\"2.3a1\", \"2.3\"),\n    (\"2.1-1\", \"2.1-2\"),\n    (\"2.1-1\", \"2.1.1\"),\n    (\"2.1\", \"2.1post4\"),\n    (\"2.1a0-20040501\", \"2.1\"),\n    (\"1.1\", \"02.1\"),\n    (\"3.2\", \"3.2.post0\"),\n    (\"3.2post1\", \"3.2post2\"),\n    
(\"0.4\", \"4.0\"),\n    (\"0.0.4\", \"0.4.0\"),\n    (\"0post1\", \"0.4post1\"),\n    (\"2.1.0-rc1\", \"2.1.0\"),\n    (\"2.1dev\", \"2.1a0\"),\n    (1.6, \"1.6.0\"),\n    (\"1.6.0\", 1.6),\n    (1.6, 1.7),\n) + tuple(_pairwise(reversed(torture.split())))\n\n\nclass TestVersionCompare(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES)\n    def test_compare_leq(self, a, b, expected=True):\n        \"\"\"Test version_leq with `a` and `b`\"\"\"\n        self.assertEqual(version_leq(a, b), expected)\n\n    @parameterized.expand(TEST_CASES)\n    def test_compare_geq(self, a, b, expected=True):\n        \"\"\"Test version_geq with `b` and `a`\"\"\"\n        self.assertEqual(version_geq(b, a), expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_version_after.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nfrom parameterized import parameterized\n\nfrom monai.utils import compute_capabilities_after, pytorch_after\n\nTEST_CASES_PT = (\n    (1, 5, 9, \"1.6.0\"),\n    (1, 6, 0, \"1.6.0\"),\n    (1, 6, 1, \"1.6.0\", False),\n    (1, 7, 0, \"1.6.0\", False),\n    (2, 6, 0, \"1.6.0\", False),\n    (0, 6, 0, \"1.6.0a0+3fd9dcf\"),\n    (1, 5, 9, \"1.6.0a0+3fd9dcf\"),\n    (1, 6, 0, \"1.6.0a0+3fd9dcf\", False),\n    (1, 6, 1, \"1.6.0a0+3fd9dcf\", False),\n    (2, 6, 0, \"1.6.0a0+3fd9dcf\", False),\n    (1, 6, 0, \"1.6.0-rc0+3fd9dcf\", False),  # defaults to prerelease\n    (1, 6, 0, \"1.6.0rc0\", False),\n    (1, 6, 0, \"1.6\", True),\n    (1, 6, 0, \"1\", False),\n    (1, 6, 0, \"1.6.0+cpu\", True),\n    (1, 6, 1, \"1.6.0+cpu\", False),\n)\n\nTEST_CASES_SM = [\n    # (major, minor, sm, expected)\n    (6, 1, \"6.1\", True),\n    (6, 1, \"6.0\", False),\n    (6, 0, \"8.6\", True),\n    (7, 0, \"8\", True),\n    (8, 6, \"8\", False),\n]\n\n\nclass TestPytorchVersionCompare(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_PT)\n    def test_compare(self, a, b, p, current, expected=True):\n        \"\"\"Test pytorch_after with a and b\"\"\"\n        self.assertEqual(pytorch_after(a, b, p, current), expected)\n\n\nclass TestComputeCapabilitiesAfter(unittest.TestCase):\n\n    @parameterized.expand(TEST_CASES_SM)\n    def 
test_compute_capabilities_after(self, major, minor, sm, expected):\n        self.assertEqual(compute_capabilities_after(major, minor, sm), expected)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/type_conversion/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/utils/type_conversion/test_convert_data_type.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data import MetaTensor\nfrom monai.utils.type_conversion import convert_data_type, convert_to_dst_type, get_equivalent_dtype\nfrom tests.test_utils import TEST_NDARRAYS_ALL, assert_allclose\n\nTESTS: list[tuple] = []\nfor in_type in TEST_NDARRAYS_ALL + (int, float):\n    for out_type in TEST_NDARRAYS_ALL:\n        TESTS.append((in_type(np.array(1.0)), out_type(np.array(1.0)), None, False))  # type: ignore\n        if in_type is not float:\n            TESTS.append((in_type(np.array(256)), out_type(np.array(255)), np.uint8, True))  # type: ignore\n\nTESTS_LIST: list[tuple] = []\nfor in_type in TEST_NDARRAYS_ALL + (int, float):\n    for out_type in TEST_NDARRAYS_ALL:\n        TESTS_LIST.append(\n            (\n                [in_type(np.array(1.0)), in_type(np.array(1.0))],  # type: ignore\n                out_type(np.array([1.0, 1.0])),\n                True,\n                None,\n                False,\n            )\n        )\n        TESTS_LIST.append(\n            (\n                [in_type(np.array(1.0)), in_type(np.array(1.0))],  # type: ignore\n                [out_type(np.array(1.0)), out_type(np.array(1.0))],\n                False,\n                None,\n                False,\n            )\n        )\n        if 
in_type is not float:\n            TESTS_LIST.append(\n                (\n                    [in_type(np.array(257)), in_type(np.array(1))],  # type: ignore\n                    out_type(np.array([255, 1])),\n                    True,\n                    np.uint8,\n                    True,\n                )\n            )\n            TESTS_LIST.append(\n                (\n                    [in_type(np.array(257)), in_type(np.array(-12))],  # type: ignore\n                    [out_type(np.array(255)), out_type(np.array(0))],\n                    False,\n                    np.uint8,\n                    True,\n                )\n            )\n\nUNSUPPORTED_TYPES = {np.dtype(\"uint16\"): torch.int32, np.dtype(\"uint32\"): torch.int64, np.dtype(\"uint64\"): torch.int64}\n\n\nclass TestTensor(torch.Tensor):\n    __test__ = False  # indicate to pytest that this class is not intended for collection\n\n\nclass TestConvertDataType(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_convert_data_type(self, in_image, im_out, out_dtype, safe):\n        converted_im, orig_type, orig_device = convert_data_type(in_image, type(im_out), dtype=out_dtype, safe=safe)\n        # check input is unchanged\n        self.assertEqual(type(in_image), orig_type)\n        if isinstance(in_image, torch.Tensor):\n            self.assertEqual(in_image.device, orig_device)\n        # check output is desired type\n        self.assertEqual(type(converted_im), type(im_out))\n        # check data has been clipped\n        assert_allclose(converted_im, im_out)\n        # check dtype is unchanged\n        if out_dtype is None:\n            if isinstance(in_image, (np.ndarray, torch.Tensor)):\n                self.assertEqual(converted_im.dtype, im_out.dtype)\n\n    def test_neg_stride(self):\n        _ = convert_data_type(np.array((1, 2))[::-1], torch.Tensor)\n\n    @parameterized.expand(list(UNSUPPORTED_TYPES.items()))\n    def test_unsupported_np_types(self, np_type, 
pt_type):\n        in_image = np.ones(13, dtype=np_type)  # choose a prime size so as to be indivisible by the size of any dtype\n        converted_im, orig_type, orig_device = convert_data_type(in_image, torch.Tensor)\n\n        self.assertEqual(converted_im.dtype, pt_type)\n\n    @parameterized.expand(TESTS_LIST)\n    def test_convert_list(self, in_image, im_out, wrap, out_dtype, safe):\n        output_type = type(im_out) if wrap else type(im_out[0])\n        converted_im, *_ = convert_data_type(in_image, output_type, wrap_sequence=wrap, dtype=out_dtype, safe=safe)\n        # check output is desired type\n        if not wrap:\n            converted_im = converted_im[0]\n            im_out = im_out[0]\n        self.assertEqual(type(converted_im), type(im_out))\n        assert_allclose(converted_im, im_out)\n        # check dtype is unchanged\n        if isinstance(in_image[0], (np.ndarray, torch.Tensor)):\n            if out_dtype is None:\n                self.assertEqual(converted_im.dtype, im_out.dtype)\n            else:\n                _out_dtype = get_equivalent_dtype(out_dtype, output_type)\n                self.assertEqual(converted_im.dtype, _out_dtype)\n\n\nclass TestConvertDataSame(unittest.TestCase):\n    # add test for subclass of Tensor\n    @parameterized.expand(TESTS + [(np.array(256), TestTensor(np.array([255])), torch.uint8, True)])\n    def test_convert_data_type(self, in_image, im_out, out_dtype, safe):\n        converted_im, orig_type, orig_device = convert_to_dst_type(in_image, im_out, dtype=out_dtype, safe=safe)\n        # check input is unchanged\n        self.assertEqual(type(in_image), orig_type)\n        assert_allclose(converted_im, im_out)\n        if isinstance(in_image, torch.Tensor):\n            self.assertEqual(in_image.device, orig_device)\n\n        # check output is desired type\n        if isinstance(im_out, MetaTensor):\n            output_type = MetaTensor\n        elif isinstance(im_out, torch.Tensor):\n            
output_type = torch.Tensor\n        else:\n            output_type = np.ndarray\n        self.assertEqual(type(converted_im), output_type)\n        # check dtype is unchanged\n        if out_dtype is None:\n            if isinstance(in_image, (np.ndarray, torch.Tensor, MetaTensor)):\n                self.assertEqual(converted_im.dtype, im_out.dtype)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/type_conversion/test_get_equivalent_dtype.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.utils.type_conversion import get_equivalent_dtype, get_numpy_dtype_from_string, get_torch_dtype_from_string\nfrom tests.test_utils import TEST_NDARRAYS\n\nDTYPES = [torch.float32, np.float32, np.dtype(np.float32)]\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    for im_dtype in DTYPES:\n        TESTS.append((p(np.array(1.0, dtype=np.float32)), im_dtype))\n\n\nclass TestGetEquivalentDtype(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_get_equivalent_dtype(self, im, input_dtype):\n        out_dtype = get_equivalent_dtype(input_dtype, type(im))\n        self.assertEqual(out_dtype, im.dtype)\n\n    def test_native_type(self):\n        \"\"\"the get_equivalent_dtype currently doesn't change the built-in type\"\"\"\n        n_type = [float, int, bool]\n        for n in n_type:\n            for im_dtype in DTYPES:\n                out_dtype = get_equivalent_dtype(n, type(im_dtype))\n                self.assertEqual(out_dtype, n)\n\n    @parameterized.expand(\n        [\n            [\"float\", np.float64],\n            [\"float32\", np.float32],\n            [\"np.float32\", np.float32],\n            [\"float64\", np.float64],\n            [\"torch.float64\", np.float64],\n        ]\n    )\n    def test_from_string(self, 
dtype_str, expected_np):\n        expected_pt = get_equivalent_dtype(expected_np, torch.Tensor)\n        # numpy\n        dtype = get_numpy_dtype_from_string(dtype_str)\n        self.assertEqual(dtype, expected_np)\n        # torch\n        dtype = get_torch_dtype_from_string(dtype_str)\n        self.assertEqual(dtype, expected_pt)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/type_conversion/test_safe_dtype_range.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.utils import optional_import\nfrom monai.utils.type_conversion import get_equivalent_dtype, safe_dtype_range\nfrom tests.test_utils import HAS_CUPY, TEST_NDARRAYS_ALL, assert_allclose\n\ncp, _ = optional_import(\"cupy\")\n\nTESTS: list[tuple] = []\nfor in_type in TEST_NDARRAYS_ALL + (int, float):\n    TESTS.append((in_type(np.array(1.0)), in_type(np.array(1.0)), None))  # type: ignore\n    if in_type is not float:\n        TESTS.append((in_type(np.array(256)), in_type(np.array(255)), np.uint8))  # type: ignore\n        TESTS.append((in_type(np.array(-12)), in_type(np.array(0)), np.uint8))  # type: ignore\nfor in_type in TEST_NDARRAYS_ALL:\n    TESTS.append((in_type(np.array([[256, 255], [-12, 0]])), in_type(np.array([[255, 255], [0, 0]])), np.uint8))\n\nTESTS_LIST: list[tuple] = []\nfor in_type in TEST_NDARRAYS_ALL + (int, float):\n    TESTS_LIST.append(\n        (\n            [in_type(np.array(1.0)), in_type(np.array(1.0))],  # type: ignore\n            [in_type(np.array(1.0)), in_type(np.array(1.0))],  # type: ignore\n            None,\n        )\n    )\n    if in_type is not float:\n        TESTS_LIST.append(\n            (\n                [in_type(np.array(257)), in_type(np.array(-12))],  # type: ignore\n                
[in_type(np.array(255)), in_type(np.array(0))],  # type: ignore\n                np.uint8,\n            )\n        )\n\nTESTS_CUPY = [[np.array(1.0), np.array(1.0), None], [np.array([-12]), np.array([0]), np.uint8]]\n\n\nclass TesSafeDtypeRange(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_safe_dtype_range(self, in_image, im_out, out_dtype):\n        result = safe_dtype_range(in_image, out_dtype)\n        # check type is unchanged\n        self.assertEqual(type(in_image), type(result))\n        # check dtype is unchanged\n        if isinstance(in_image, (np.ndarray, torch.Tensor)):\n            self.assertEqual(in_image.dtype, result.dtype)\n        # check output\n        assert_allclose(result, im_out)\n\n    @parameterized.expand(TESTS_LIST)\n    def test_safe_dtype_range_list(self, in_image, im_out, out_dtype):\n        output_type = type(im_out[0])\n        result = safe_dtype_range(in_image, dtype=out_dtype)\n        # check type is unchanged\n        self.assertEqual(type(result), type(im_out))\n        # check output\n        for i, _result in enumerate(result):\n            assert_allclose(_result, im_out[i])\n        # check dtype is unchanged\n        if isinstance(in_image, (np.ndarray, torch.Tensor)):\n            if out_dtype is None:\n                self.assertEqual(result[0].dtype, im_out[0].dtype)\n            else:\n                _out_dtype = get_equivalent_dtype(out_dtype, output_type)\n                self.assertEqual(result[0].dtype, _out_dtype)\n\n    @parameterized.expand(TESTS_CUPY)\n    @unittest.skipUnless(HAS_CUPY, \"Requires CuPy\")\n    def test_type_cupy(self, in_image, im_out, out_dtype):\n        in_image = cp.asarray(in_image)\n        result = safe_dtype_range(in_image, dtype=out_dtype)\n        # check type is unchanged\n        self.assertEqual(type(in_image), type(result))\n        # check dtype is unchanged\n        self.assertEqual(result.dtype, in_image.dtype)\n        # check output\n        
self.assertEqual(result, cp.asarray(im_out))\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/visualize/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/visualize/test_img2tensorboard.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport numpy as np\nimport tensorboard\nimport torch\n\nfrom monai.visualize import make_animated_gif_summary\n\n\nclass TestImg2Tensorboard(unittest.TestCase):\n\n    def test_write_gray(self):\n        nparr = np.ones(shape=(1, 32, 32, 32), dtype=np.float32)\n        summary_object_np = make_animated_gif_summary(\n            tag=\"test_summary_nparr.png\", image=nparr, max_out=1, scale_factor=253.0\n        )\n        for s in summary_object_np:\n            assert isinstance(\n                s, tensorboard.compat.proto.summary_pb2.Summary\n            ), \"make_animated_gif_summary must return a tensorboard.summary object from numpy array\"\n\n        tensorarr = torch.tensor(nparr)\n        summary_object_tensor = make_animated_gif_summary(\n            tag=\"test_summary_tensorarr.png\", image=tensorarr, max_out=1, frame_dim=-1, scale_factor=253.0\n        )\n        for s in summary_object_tensor:\n            assert isinstance(\n                s, tensorboard.compat.proto.summary_pb2.Summary\n            ), \"make_animated_gif_summary must return a tensorboard.summary object from tensor input\"\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/visualize/test_occlusion_sensitivity.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom typing import Any\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.nets import DenseNet, DenseNet121\nfrom monai.visualize import OcclusionSensitivity\n\n\nclass DenseNetAdjoint(DenseNet121):\n\n    def __call__(self, x, adjoint_info):\n        if adjoint_info != 42:\n            raise ValueError\n        return super().__call__(x)\n\n\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\nout_channels_2d = 4\nout_channels_3d = 3\nmodel_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=out_channels_2d).to(device)\nmodel_2d_2c = DenseNet121(spatial_dims=2, in_channels=2, out_channels=out_channels_2d).to(device)\nmodel_3d = DenseNet(\n    spatial_dims=3, in_channels=1, out_channels=out_channels_3d, init_features=2, growth_rate=2, block_config=(6,)\n).to(device)\nmodel_2d_adjoint = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=out_channels_2d).to(device)\nmodel_2d.eval()\nmodel_2d_2c.eval()\nmodel_3d.eval()\nmodel_2d_adjoint.eval()\n\nTESTS: list[Any] = []\nTESTS_FAIL: list[Any] = []\n\n# 2D w/ bounding box with all modes\nfor mode in (\"gaussian\", \"mean_patch\", \"mean_img\"):\n    TESTS.append(\n        [\n            {\"nn_module\": model_2d, \"mode\": mode},\n            {\"x\": torch.rand(1, 1, 48, 64).to(device), \"b_box\": [2, 
40, 1, 62]},\n            (1, out_channels_2d, 38, 61),\n            (1, 1, 38, 61),\n        ]\n    )\n# 3D w/ bounding box\nTESTS.append(\n    [\n        {\"nn_module\": model_3d, \"n_batch\": 10, \"mask_size\": (16, 15, 14)},\n        {\"x\": torch.rand(1, 1, 64, 32, 16).to(device), \"b_box\": [2, 43, -1, -1, -1, -1]},\n        (1, out_channels_3d, 41, 32, 16),\n        (1, 1, 41, 32, 16),\n    ]\n)\nTESTS.append(\n    [\n        {\"nn_module\": model_3d, \"n_batch\": 10},\n        {\"x\": torch.rand(1, 1, 6, 7, 8).to(device), \"b_box\": [1, 3, -1, -1, -1, -1]},\n        (1, out_channels_3d, 2, 7, 8),\n        (1, 1, 2, 7, 8),\n    ]\n)\nTESTS.append(\n    [\n        {\"nn_module\": model_2d_2c},\n        {\"x\": torch.rand(1, 2, 48, 64).to(device)},\n        (1, out_channels_2d, 48, 64),\n        (1, 1, 48, 64),\n    ]\n)\n# 2D w/ bounding box and adjoint\nTESTS.append(\n    [\n        {\"nn_module\": model_2d_adjoint},\n        {\"x\": torch.rand(1, 1, 48, 64).to(device), \"b_box\": [2, 40, 1, 62], \"adjoint_info\": 42},\n        (1, out_channels_2d, 38, 61),\n        (1, 1, 38, 61),\n    ]\n)\n# 2D should fail: bbox makes image too small\nTESTS_FAIL.append(\n    [{\"nn_module\": model_2d, \"n_batch\": 10, \"mask_size\": 200}, {\"x\": torch.rand(1, 1, 48, 64).to(device)}, ValueError]\n)\n# 2D should fail: batch > 1\nTESTS_FAIL.append(\n    [{\"nn_module\": model_2d, \"n_batch\": 10, \"mask_size\": 100}, {\"x\": torch.rand(2, 1, 48, 64).to(device)}, ValueError]\n)\n# 2D should fail: unknown mode\nTESTS_FAIL.append(\n    [{\"nn_module\": model_2d, \"mode\": \"test\"}, {\"x\": torch.rand(1, 1, 48, 64).to(device)}, NotImplementedError]\n)\n\n\nclass TestComputeOcclusionSensitivity(unittest.TestCase):\n\n    @parameterized.expand(TESTS)\n    def test_shape(self, init_data, call_data, map_expected_shape, most_prob_expected_shape):\n        occ_sens = OcclusionSensitivity(**init_data)\n        m, most_prob = occ_sens(**call_data)\n        
self.assertTupleEqual(m.shape, map_expected_shape)\n        self.assertTupleEqual(most_prob.shape, most_prob_expected_shape)\n        # most probable class should be of type int, and should have min>=0, max<num_classes\n        self.assertEqual(most_prob.dtype, torch.int64)\n        self.assertGreaterEqual(most_prob.min(), 0)\n        self.assertLess(most_prob.max(), m.shape[-1])\n\n    @parameterized.expand(TESTS_FAIL)\n    def test_fail(self, init_data, call_data, error_type):\n        with self.assertRaises(error_type):\n            occ_sens = OcclusionSensitivity(**init_data)\n            occ_sens(**call_data)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/visualize/test_plot_2d_or_3d_image.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport glob\nimport tempfile\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.utils import optional_import\nfrom monai.visualize import plot_2d_or_3d_image\nfrom tests.test_utils import SkipIfNoModule\n\nSummaryWriter, has_tb = optional_import(\"torch.utils.tensorboard\", name=\"SummaryWriter\")\n\nSummaryWriterX, _ = optional_import(\"tensorboardX\", name=\"SummaryWriter\")\n\nTEST_CASE_1 = [(1, 1, 10, 10)]\n\nTEST_CASE_2 = [(1, 3, 10, 10)]\n\nTEST_CASE_3 = [(1, 4, 10, 10)]\n\nTEST_CASE_4 = [(1, 1, 10, 10, 10)]\n\nTEST_CASE_5 = [(1, 3, 10, 10, 10)]\n\n\n@unittest.skipUnless(has_tb, \"Requires SummaryWriter installation\")\nclass TestPlot2dOr3dImage(unittest.TestCase):\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])\n    def test_tb_image(self, shape):\n        with tempfile.TemporaryDirectory() as tempdir:\n            writer = SummaryWriter(log_dir=tempdir)\n            plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=3, frame_dim=-1)\n            writer.flush()\n            writer.close()\n            self.assertTrue(len(glob.glob(tempdir)) > 0)\n\n    @SkipIfNoModule(\"tensorboardX\")\n    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])\n    def test_tbx_image(self, shape):\n        with 
tempfile.TemporaryDirectory() as tempdir:\n            writer = SummaryWriterX(log_dir=tempdir)\n            plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=2)\n            writer.flush()\n            writer.close()\n            self.assertTrue(len(glob.glob(tempdir)) > 0)\n\n    @SkipIfNoModule(\"tensorboardX\")\n    @parameterized.expand([TEST_CASE_5])\n    def test_tbx_video(self, shape):\n        with tempfile.TemporaryDirectory() as tempdir:\n            writer = SummaryWriterX(log_dir=tempdir)\n            plot_2d_or_3d_image(torch.rand(shape), 0, writer, max_channels=3)\n            writer.flush()\n            writer.close()\n            self.assertTrue(len(glob.glob(tempdir)) > 0)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/visualize/test_vis_cam.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\n\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.nets import DenseNet, DenseNet121, SEResNet50\nfrom monai.visualize import CAM\n\n# 2D\nTEST_CASE_0 = [\n    {\n        \"model\": \"densenet2d\",\n        \"shape\": (2, 1, 48, 64),\n        \"feature_shape\": (2, 1, 1, 2),\n        \"target_layers\": \"class_layers.relu\",\n        \"fc_layers\": \"class_layers.out\",\n    },\n    (2, 1, 48, 64),\n]\n# 3D\nTEST_CASE_1 = [\n    {\n        \"model\": \"densenet3d\",\n        \"shape\": (2, 1, 6, 6, 6),\n        \"feature_shape\": (2, 1, 2, 2, 2),\n        \"target_layers\": \"class_layers.relu\",\n        \"fc_layers\": \"class_layers.out\",\n    },\n    (2, 1, 6, 6, 6),\n]\n# 2D\nTEST_CASE_2 = [\n    {\n        \"model\": \"senet2d\",\n        \"shape\": (2, 3, 64, 64),\n        \"feature_shape\": (2, 1, 2, 2),\n        \"target_layers\": \"layer4\",\n        \"fc_layers\": \"last_linear\",\n    },\n    (2, 1, 64, 64),\n]\n\n# 3D\nTEST_CASE_3 = [\n    {\n        \"model\": \"senet3d\",\n        \"shape\": (2, 3, 8, 8, 48),\n        \"feature_shape\": (2, 1, 1, 1, 2),\n        \"target_layers\": \"layer4\",\n        \"fc_layers\": \"last_linear\",\n    },\n    (2, 1, 8, 8, 48),\n]\n\n\nclass TestClassActivationMap(unittest.TestCase):\n\n    @parameterized.expand([TEST_CASE_0, TEST_CASE_1, 
TEST_CASE_2, TEST_CASE_3])\n    def test_shape(self, input_data, expected_shape):\n        model = None\n\n        if input_data[\"model\"] == \"densenet2d\":\n            model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)\n        if input_data[\"model\"] == \"densenet3d\":\n            model = DenseNet(\n                spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)\n            )\n        if input_data[\"model\"] == \"senet2d\":\n            model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)\n        if input_data[\"model\"] == \"senet3d\":\n            model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        model.to(device)\n        model.eval()\n        cam = CAM(nn_module=model, target_layers=input_data[\"target_layers\"], fc_layers=input_data[\"fc_layers\"])\n        image = torch.rand(input_data[\"shape\"], device=device)\n        result = cam(x=image, layer_idx=-1)\n        fea_shape = cam.feature_map_size(input_data[\"shape\"], device=device)\n        self.assertTupleEqual(fea_shape, input_data[\"feature_shape\"])\n        self.assertTupleEqual(result.shape, expected_shape)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/visualize/test_vis_gradcam.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom typing import Any\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.networks.nets import DenseNet, DenseNet121, SEResNet50\nfrom monai.visualize import GradCAM, GradCAMpp\nfrom tests.test_utils import assert_allclose, skip_if_quick\n\n\nclass DenseNetAdjoint(DenseNet121):\n    def __call__(self, x, adjoint_info):\n        if adjoint_info != 42:\n            raise ValueError\n        return super().__call__(x)\n\n\nTESTS: list[Any] = []\nTESTS_ILL: list[Any] = []\n\nfor cam in (GradCAM, GradCAMpp):\n    # 2D\n    TESTS.append(\n        [\n            cam,\n            {\n                \"model\": \"densenet2d\",\n                \"shape\": (2, 1, 48, 64),\n                \"feature_shape\": (2, 1, 1, 2),\n                \"target_layers\": \"class_layers.relu\",\n            },\n            (2, 1, 48, 64),\n        ]\n    )\n    # 2D binary classification (out_channels=1)\n    TESTS.append(\n        [\n            cam,\n            {\n                \"model\": \"densenet2d_bin\",\n                \"shape\": (2, 1, 48, 64),\n                \"feature_shape\": (2, 1, 1, 2),\n                \"target_layers\": \"class_layers.relu\",\n            },\n            (2, 1, 48, 64),\n        ]\n    )\n    # 3D\n    TESTS.append(\n        [\n            cam,\n            {\n        
        \"model\": \"densenet3d\",\n                \"shape\": (2, 1, 6, 6, 6),\n                \"feature_shape\": (2, 1, 2, 2, 2),\n                \"target_layers\": \"class_layers.relu\",\n            },\n            (2, 1, 6, 6, 6),\n        ]\n    )\n    # 3D binary classification (out_channels=1)\n    TESTS.append(\n        [\n            cam,\n            {\n                \"model\": \"densenet3d_bin\",\n                \"shape\": (2, 1, 6, 6, 6),\n                \"feature_shape\": (2, 1, 2, 2, 2),\n                \"target_layers\": \"class_layers.relu\",\n            },\n            (2, 1, 6, 6, 6),\n        ]\n    )\n    # 2D\n    TESTS.append(\n        [\n            cam,\n            {\"model\": \"senet2d\", \"shape\": (2, 3, 64, 64), \"feature_shape\": (2, 1, 2, 2), \"target_layers\": \"layer4\"},\n            (2, 1, 64, 64),\n        ]\n    )\n    # 2D binary classification (num_classes=1)\n    TESTS.append(\n        [\n            cam,\n            {\"model\": \"senet2d_bin\", \"shape\": (2, 3, 64, 64), \"feature_shape\": (2, 1, 2, 2), \"target_layers\": \"layer4\"},\n            (2, 1, 64, 64),\n        ]\n    )\n\n    # 3D\n    TESTS.append(\n        [\n            cam,\n            {\n                \"model\": \"senet3d\",\n                \"shape\": (2, 3, 8, 8, 48),\n                \"feature_shape\": (2, 1, 1, 1, 2),\n                \"target_layers\": \"layer4\",\n            },\n            (2, 1, 8, 8, 48),\n        ]\n    )\n    # 3D binary classification (num_classes=1)\n    TESTS.append(\n        [\n            cam,\n            {\n                \"model\": \"senet3d_bin\",\n                \"shape\": (2, 3, 8, 8, 48),\n                \"feature_shape\": (2, 1, 1, 1, 2),\n                \"target_layers\": \"layer4\",\n            },\n            (2, 1, 8, 8, 48),\n        ]\n    )\n\n    # adjoint info\n    TESTS.append(\n        [\n            cam,\n            {\n                \"model\": \"adjoint\",\n                \"shape\": 
(2, 1, 48, 64),\n                \"feature_shape\": (2, 1, 1, 2),\n                \"target_layers\": \"class_layers.relu\",\n            },\n            (2, 1, 48, 64),\n        ]\n    )\n\n    TESTS_ILL.append([cam])\n\n\n@skip_if_quick\nclass TestGradientClassActivationMap(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_shape(self, cam_class, input_data, expected_shape):\n        model = None\n\n        if input_data[\"model\"] == \"densenet2d\":\n            model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)\n        elif input_data[\"model\"] == \"densenet2d_bin\":\n            model = DenseNet(spatial_dims=2, in_channels=1, out_channels=1)\n        elif input_data[\"model\"] == \"densenet3d\":\n            model = DenseNet(\n                spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)\n            )\n        elif input_data[\"model\"] == \"densenet3d_bin\":\n            model = DenseNet(\n                spatial_dims=3, in_channels=1, out_channels=1, init_features=2, growth_rate=2, block_config=(6,)\n            )\n        elif input_data[\"model\"] == \"senet2d\":\n            model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)\n        elif input_data[\"model\"] == \"senet2d_bin\":\n            model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=1)\n        elif input_data[\"model\"] == \"senet3d\":\n            model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)\n        elif input_data[\"model\"] == \"senet3d_bin\":\n            model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=1)\n        elif input_data[\"model\"] == \"adjoint\":\n            model = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=3)\n\n        # optionally test for adjoint info\n        kwargs = {\"adjoint_info\": 42} if input_data[\"model\"] == \"adjoint\" else {}\n\n        device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n   
     model.to(device)\n        model.eval()\n        cam = cam_class(nn_module=model, target_layers=input_data[\"target_layers\"])\n        image = torch.rand(input_data[\"shape\"], device=device)\n        inferred = model(image, **kwargs).max(1)[-1].cpu()\n        result = cam(x=image, layer_idx=-1, **kwargs)\n        np.testing.assert_array_equal(cam.nn_module.class_idx.cpu(), inferred)\n\n        fea_shape = cam.feature_map_size(input_data[\"shape\"], device=device, **kwargs)\n        self.assertTupleEqual(fea_shape, input_data[\"feature_shape\"])\n        self.assertTupleEqual(result.shape, expected_shape)\n        # check result is same whether class_idx=None is used or not\n        result2 = cam(x=image, layer_idx=-1, class_idx=inferred, **kwargs)\n        assert_allclose(result, result2)\n\n    @parameterized.expand(TESTS_ILL)\n    def test_ill(self, cam_class):\n        model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)\n        for name, x in model.named_parameters():\n            if \"features\" in name:\n                x.requires_grad = False\n        cam = cam_class(nn_module=model, target_layers=\"class_layers.relu\")\n        image = torch.rand((2, 1, 48, 64))\n        with self.assertRaises(IndexError):\n            cam(x=image)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/visualize/utils/__init__.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "tests/visualize/utils/test_blend_images.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport unittest\nfrom unittest.case import skipUnless\n\nimport numpy as np\nimport torch\nfrom parameterized import parameterized\n\nfrom monai.data.synthetic import create_test_image_2d, create_test_image_3d\nfrom monai.transforms.utils_pytorch_numpy_unification import moveaxis\nfrom monai.utils.module import optional_import\nfrom monai.visualize.utils import blend_images\nfrom tests.test_utils import TEST_NDARRAYS\n\nplt, has_matplotlib = optional_import(\"matplotlib.pyplot\")\n\n\ndef get_alpha(img):\n    return 0.5 * np.arange(img.size).reshape(img.shape) / img.size\n\n\nTESTS = []\nfor p in TEST_NDARRAYS:\n    image, label = create_test_image_2d(100, 101, channel_dim=0)\n    TESTS.append((p(image), p(label), 0.5))\n    TESTS.append((p(image), p(label), p(get_alpha(image))))\n\n    image, label = create_test_image_3d(100, 101, 102, channel_dim=0)\n    TESTS.append((p(image), p(label), 0.5))\n    TESTS.append((p(image), p(label), p(get_alpha(image))))\n\n\n@skipUnless(has_matplotlib, \"Matplotlib required\")\nclass TestBlendImages(unittest.TestCase):\n    @parameterized.expand(TESTS)\n    def test_blend(self, image, label, alpha):\n        blended = blend_images(image, label, alpha)\n        self.assertEqual(type(image), type(blended))\n        if isinstance(blended, torch.Tensor):\n            self.assertEqual(blended.device, 
image.device)\n            blended = blended.cpu().numpy()\n        self.assertEqual((3,) + image[0].shape, blended.shape)\n\n        blended = moveaxis(blended, 0, -1)  # move RGB component to end\n        if blended.ndim > 3:\n            blended = blended[blended.shape[0] // 2]\n        plt.imshow(blended)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/visualize/utils/test_matshow3d.py",
    "content": "# Copyright (c) MONAI Consortium\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#     http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\n\nfrom monai.transforms import (\n    Compose,\n    EnsureChannelFirstd,\n    LoadImaged,\n    RandSpatialCropSamplesd,\n    RepeatChanneld,\n    ScaleIntensityd,\n)\nfrom monai.utils import optional_import\nfrom monai.visualize.utils import matshow3d\nfrom tests.test_utils import SkipIfNoModule\n\ncompare_images, _ = optional_import(\"matplotlib.testing.compare\", name=\"compare_images\")\npyplot, has_pyplot = optional_import(\"matplotlib\", name=\"pyplot\")\n\n\n@SkipIfNoModule(\"matplotlib\")\nclass TestMatshow3d(unittest.TestCase):\n    def test_3d(self):\n        test_root = Path(__file__).parents[2]\n        testing_dir = os.path.join(test_root, \"testing_data\")\n        print(\"test_root: \", testing_dir)\n        keys = \"image\"\n        xforms = Compose(\n            [\n                LoadImaged(keys=keys),\n                EnsureChannelFirstd(keys=keys, channel_dim=\"no_channel\"),\n                ScaleIntensityd(keys=keys),\n            ]\n        )\n        image_path = os.path.join(testing_dir, \"anatomical.nii\")\n        ims = xforms({keys: image_path})\n\n        fig = pyplot.figure()  # external figure\n        fig, _ = matshow3d(ims[keys], fig=fig, figsize=(2, 2), frames_per_row=5, every_n=2, frame_dim=-1, show=False)\n\n 
       with tempfile.TemporaryDirectory() as tempdir:\n            tempimg = f\"{tempdir}/matshow3d_test.png\"\n            fig.savefig(tempimg)\n            comp = compare_images(f\"{testing_dir}/matshow3d_test.png\", tempimg, 5e-2)\n            self.assertIsNone(comp, f\"value of comp={comp}\")  # None indicates test passed\n\n        _, axes = pyplot.subplots()\n        matshow3d(ims[keys], fig=axes, figsize=(2, 2), frames_per_row=5, every_n=2, frame_dim=-1, show=False)\n\n    def test_samples(self):\n        test_root = Path(__file__).parents[2]\n        testing_dir = os.path.join(test_root, \"testing_data\")\n        keys = \"image\"\n        xforms = Compose(\n            [\n                LoadImaged(keys=keys),\n                EnsureChannelFirstd(keys=keys),\n                ScaleIntensityd(keys=keys),\n                RandSpatialCropSamplesd(keys=keys, roi_size=(8, 8, 5), random_size=True, num_samples=10),\n            ]\n        )\n        image_path = os.path.join(testing_dir, \"anatomical.nii\")\n        xforms.set_random_state(0)\n        ims = xforms({keys: image_path})\n        fig, mat = matshow3d(\n            [im[keys] for im in ims], title=f\"testing {keys}\", figsize=(2, 2), frames_per_row=5, every_n=2, show=False\n        )\n        self.assertEqual(mat.dtype, np.float32)\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            tempimg = f\"{tempdir}/matshow3d_patch_test.png\"\n            fig.savefig(tempimg)\n            comp = compare_images(f\"{testing_dir}/matshow3d_patch_test.png\", tempimg, 5e-2, in_decorator=True)\n            if comp:\n                print(\"not none comp: \", comp)  # matplotlib 3.2.2\n                np.testing.assert_allclose(comp[\"rms\"], 30.786983, atol=1e-3, rtol=1e-3)\n            else:\n                self.assertIsNone(comp, f\"value of comp={comp}\")  # None indicates test passed\n\n    def test_3d_rgb(self):\n        test_dir = Path(__file__).parents[2].as_posix()\n        testing_dir = 
os.path.join(test_dir, \"testing_data\")\n        keys = \"image\"\n        xforms = Compose(\n            [\n                LoadImaged(keys=keys),\n                EnsureChannelFirstd(keys=keys, channel_dim=\"no_channel\"),\n                ScaleIntensityd(keys=keys),\n                # change to RGB color image\n                RepeatChanneld(keys=keys, repeats=3),\n            ]\n        )\n        image_path = os.path.join(testing_dir, \"anatomical.nii\")\n        ims = xforms({keys: image_path})\n\n        fig = pyplot.figure()  # external figure\n        fig, _ = matshow3d(\n            volume=ims[keys],\n            fig=fig,\n            figsize=(2, 2),\n            frames_per_row=5,\n            every_n=2,\n            frame_dim=-1,\n            channel_dim=0,\n            fill_value=0,\n            show=False,\n        )\n\n        with tempfile.TemporaryDirectory() as tempdir:\n            tempimg = f\"{tempdir}/matshow3d_rgb_test.png\"\n            fig.savefig(tempimg)\n            comp = compare_images(f\"{testing_dir}/matshow3d_rgb_test.png\", tempimg, 5e-2)\n            self.assertIsNone(comp, f\"value of comp={comp}\")  # None indicates test passed\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "versioneer.py",
    "content": "# Version: 0.23\n\n\"\"\"The Versioneer - like a rocketeer, but for versions.\n\nThe Versioneer\n==============\n\n* like a rocketeer, but for versions!\n* https://github.com/python-versioneer/python-versioneer\n* Brian Warner\n* License: Public Domain (CC0-1.0)\n* Compatible with: Python 3.7, 3.8, 3.9, 3.10 and pypy3\n* [![Latest Version][pypi-image]][pypi-url]\n* [![Build Status][travis-image]][travis-url]\n\nThis is a tool for managing a recorded version number in distutils/setuptools-based\npython projects. The goal is to remove the tedious and error-prone \"update\nthe embedded version string\" step from your release process. Making a new\nrelease should be as easy as recording a new tag in your version-control\nsystem, and maybe making new tarballs.\n\n\n## Quick Install\n\n* `pip install versioneer` to somewhere in your $PATH\n* add a `[versioneer]` section to your setup.cfg (see [Install](INSTALL.md))\n* run `versioneer install` in your source tree, commit the results\n* Verify version information with `python setup.py version`\n\n## Version Identifiers\n\nSource trees come from a variety of places:\n\n* a version-control system checkout (mostly used by developers)\n* a nightly tarball, produced by build automation\n* a snapshot tarball, produced by a web-based VCS browser, like github's\n  \"tarball from tag\" feature\n* a release tarball, produced by \"setup.py sdist\", distributed through PyPI\n\nWithin each source tree, the version identifier (either a string or a number,\nthis tool is format-agnostic) can come from a variety of places:\n\n* ask the VCS tool itself, e.g. \"git describe\" (for checkouts), which knows\n  about recent \"tags\" and an absolute revision-id\n* the name of the directory into which the tarball was unpacked\n* an expanded VCS keyword ($Id$, etc)\n* a `_version.py` created by some earlier build step\n\nFor released software, the version identifier is closely related to a VCS\ntag. 
Some projects use tag names that include more than just the version\nstring (e.g. \"myproject-1.2\" instead of just \"1.2\"), in which case the tool\nneeds to strip the tag prefix to extract the version identifier. For\nunreleased software (between tags), the version identifier should provide\nenough information to help developers recreate the same tree, while also\ngiving them an idea of roughly how old the tree is (after version 1.2, before\nversion 1.3). Many VCS systems can report a description that captures this,\nfor example `git describe --tags --dirty --always` reports things like\n\"0.7-1-g574ab98-dirty\" to indicate that the checkout is one revision past the\n0.7 tag, has a unique revision id of \"574ab98\", and is \"dirty\" (it has\nuncommitted changes).\n\nThe version identifier is used for multiple purposes:\n\n* to allow the module to self-identify its version: `myproject.__version__`\n* to choose a name and prefix for a 'setup.py sdist' tarball\n\n## Theory of Operation\n\nVersioneer works by adding a special `_version.py` file into your source\ntree, where your `__init__.py` can import it. This `_version.py` knows how to\ndynamically ask the VCS tool for version information at import time.\n\n`_version.py` also contains `$Revision$` markers, and the installation\nprocess marks `_version.py` to have this marker rewritten with a tag name\nduring the `git archive` command. As a result, generated tarballs will\ncontain enough information to get the proper version.\n\nTo allow `setup.py` to compute a version too, a `versioneer.py` is added to\nthe top level of your source tree, next to `setup.py` and the `setup.cfg`\nthat configures it. 
This overrides several distutils/setuptools commands to\ncompute the version when invoked, and changes `setup.py build` and `setup.py\nsdist` to replace `_version.py` with a small static file that contains just\nthe generated version data.\n\n## Installation\n\nSee [INSTALL.md](./INSTALL.md) for detailed installation instructions.\n\n## Version-String Flavors\n\nCode which uses Versioneer can learn about its version string at runtime by\nimporting `_version` from your main `__init__.py` file and running the\n`get_versions()` function. From the \"outside\" (e.g. in `setup.py`), you can\nimport the top-level `versioneer.py` and run `get_versions()`.\n\nBoth functions return a dictionary with different flavors of version\ninformation:\n\n* `['version']`: A condensed version string, rendered using the selected\n  style. This is the most commonly used value for the project's version\n  string. The default \"pep440\" style yields strings like `0.11`,\n  `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the \"Styles\" section\n  below for alternative styles.\n\n* `['full-revisionid']`: detailed revision identifier. For Git, this is the\n  full SHA1 commit id, e.g. \"1076c978a8d3cfc70f408fe5974aa6c092c949ac\".\n\n* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the\n  commit date in ISO 8601 format. This will be None if the date is not\n  available.\n\n* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that\n  this is only accurate if run in a VCS checkout, otherwise it is likely to\n  be False or None\n\n* `['error']`: if the version string could not be computed, this will be set\n  to a string describing the problem, otherwise it will be None. It may be\n  useful to throw an exception in setup.py if this is set, to avoid e.g.\n  creating tarballs with a version string of \"unknown\".\n\nSome variants are more useful than others. 
Including `full-revisionid` in a\nbug report should allow developers to reconstruct the exact code being tested\n(or indicate the presence of local changes that should be shared with the\ndevelopers). `version` is suitable for display in an \"about\" box or a CLI\n`--version` output: it can be easily compared against release notes and lists\nof bugs fixed in various releases.\n\nThe installer adds the following text to your `__init__.py` to place a basic\nversion in `YOURPROJECT.__version__`:\n\n    from ._version import get_versions\n    __version__ = get_versions()['version']\n    del get_versions\n\n## Styles\n\nThe setup.cfg `style=` configuration controls how the VCS information is\nrendered into a version string.\n\nThe default style, \"pep440\", produces a PEP440-compliant string, equal to the\nun-prefixed tag name for actual releases, and containing an additional \"local\nversion\" section with more detail for in-between builds. For Git, this is\nTAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags\n--dirty --always`. For example \"0.11+2.g1076c97.dirty\" indicates that the\ntree is like the \"1076c97\" commit but has uncommitted changes (\".dirty\"), and\nthat this commit is two revisions (\"+2\") beyond the \"0.11\" tag. For released\nsoftware (exactly equal to a known tag), the identifier will only contain the\nstripped tag, e.g. \"0.11\".\n\nOther styles are available. See [details.md](details.md) in the Versioneer\nsource tree for descriptions.\n\n## Debugging\n\nVersioneer tries to avoid fatal errors: if something goes wrong, it will tend\nto return a version of \"0+unknown\". To investigate the problem, run `setup.py\nversion`, which will run the version-lookup code in a verbose mode, and will\ndisplay the full contents of `get_versions()` (including the `error` string,\nwhich may help identify what went wrong).\n\n## Known Limitations\n\nSome situations are known to cause problems for Versioneer. 
This details the\nmost significant ones. More can be found on Github\n[issues page](https://github.com/python-versioneer/python-versioneer/issues).\n\n### Subprojects\n\nVersioneer has limited support for source trees in which `setup.py` is not in\nthe root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are\ntwo common reasons why `setup.py` might not be in the root:\n\n* Source trees which contain multiple subprojects, such as\n  [Buildbot](https://github.com/buildbot/buildbot), which contains both\n  \"master\" and \"slave\" subprojects, each with their own `setup.py`,\n  `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI\n  distributions (and upload multiple independently-installable tarballs).\n* Source trees whose main purpose is to contain a C library, but which also\n  provide bindings to Python (and perhaps other languages) in subdirectories.\n\nVersioneer will look for `.git` in parent directories, and most operations\nshould get the right version string. However `pip` and `setuptools` have bugs\nand implementation details which frequently cause `pip install .` from a\nsubproject directory to fail to find a correct version string (so it usually\ndefaults to `0+unknown`).\n\n`pip install --editable .` should work correctly. `setup.py install` might\nwork too.\n\nPip-8.1.1 is known to have this problem, but hopefully it will get fixed in\nsome later version.\n\n[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking\nthis issue. 
The discussion in\n[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the\nissue from the Versioneer side in more detail.\n[pip PR#3176](https://github.com/pypa/pip/pull/3176) and\n[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve\npip to let Versioneer work correctly.\n\nVersioneer-0.16 and earlier only looked for a `.git` directory next to the\n`setup.cfg`, so subprojects were completely unsupported with those releases.\n\n### Editable installs with setuptools <= 18.5\n\n`setup.py develop` and `pip install --editable .` allow you to install a\nproject into a virtualenv once, then continue editing the source code (and\ntest) without re-installing after every change.\n\n\"Entry-point scripts\" (`setup(entry_points={\"console_scripts\": ..})`) are a\nconvenient way to specify executable scripts that should be installed along\nwith the python package.\n\nThese both work as expected when using modern setuptools. When using\nsetuptools-18.5 or earlier, however, certain operations will cause\n`pkg_resources.DistributionNotFound` errors when running the entrypoint\nscript, which must be resolved by re-installing the package. This happens\nwhen the install happens with one version, then the egg_info data is\nregenerated while a different version is checked out. Many setup.py commands\ncause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into\na different virtualenv), so this can be surprising.\n\n[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes\nthis one, but upgrading to a newer version of setuptools should probably\nresolve it.\n\n\n## Updating Versioneer\n\nTo upgrade your project to a new release of Versioneer, do the following:\n\n* install the new Versioneer (`pip install -U versioneer` or equivalent)\n* edit `setup.cfg`, if necessary, to include any new configuration settings\n  indicated by the release notes. 
See [UPGRADING](./UPGRADING.md) for details.\n* re-run `versioneer install` in your source tree, to replace\n  `SRC/_version.py`\n* commit any changed files\n\n## Future Directions\n\nThis tool is designed to make it easily extended to other version-control\nsystems: all VCS-specific components are in separate directories like\nsrc/git/ . The top-level `versioneer.py` script is assembled from these\ncomponents by running make-versioneer.py . In the future, make-versioneer.py\nwill take a VCS name as an argument, and will construct a version of\n`versioneer.py` that is specific to the given VCS. It might also take the\nconfiguration arguments that are currently provided manually during\ninstallation by editing setup.py . Alternatively, it might go the other\ndirection and include code from all supported VCS systems, reducing the\nnumber of intermediate scripts.\n\n## Similar projects\n\n* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time\n  dependency\n* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of\n  versioneer\n* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools\n  plugin\n\n## License\n\nTo make Versioneer easier to embed, all its code is dedicated to the public\ndomain. 
The `_version.py` that it creates is also in the public domain.\nSpecifically, both are released under the Creative Commons \"Public Domain\nDedication\" license (CC0-1.0), as described in\nhttps://creativecommons.org/publicdomain/zero/1.0/ .\n\n[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg\n[pypi-url]: https://pypi.python.org/pypi/versioneer/\n[travis-image]:\nhttps://img.shields.io/travis/com/python-versioneer/python-versioneer.svg\n[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer\n\n\"\"\"\n# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring\n# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements\n# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error\n# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with\n# pylint:disable=attribute-defined-outside-init,too-many-arguments\n\nimport configparser\nimport errno\nimport json\nimport os\nimport re\nimport subprocess\nimport sys\nfrom typing import Callable, Dict\nimport functools\n\n\nclass VersioneerConfig:\n    \"\"\"Container for Versioneer configuration parameters.\"\"\"\n\n\ndef get_root():\n    \"\"\"Get the project root directory.\n\n    We require that all commands are run from the project root, i.e. 
the\n    directory that contains setup.py, setup.cfg, and versioneer.py .\n    \"\"\"\n    root = os.path.realpath(os.path.abspath(os.getcwd()))\n    setup_py = os.path.join(root, \"setup.py\")\n    versioneer_py = os.path.join(root, \"versioneer.py\")\n    if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)):\n        # allow 'python path/to/setup.py COMMAND'\n        root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0])))\n        setup_py = os.path.join(root, \"setup.py\")\n        versioneer_py = os.path.join(root, \"versioneer.py\")\n    if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)):\n        err = (\n            \"Versioneer was unable to run the project root directory. \"\n            \"Versioneer requires setup.py to be executed from \"\n            \"its immediate directory (like 'python setup.py COMMAND'), \"\n            \"or in a way that lets it use sys.argv[0] to find the root \"\n            \"(like 'python path/to/setup.py COMMAND').\"\n        )\n        raise VersioneerBadRootError(err)\n    try:\n        # Certain runtime workflows (setup.py install/develop in a setuptools\n        # tree) execute all dependencies in a single python process, so\n        # \"versioneer\" may be imported multiple times, and python's shared\n        # module-import table will cache the first one. 
So we can't use\n        # os.path.dirname(__file__), as that will find whichever\n        # versioneer.py was first imported, even in later projects.\n        my_path = os.path.realpath(os.path.abspath(__file__))\n        me_dir = os.path.normcase(os.path.splitext(my_path)[0])\n        vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0])\n        if me_dir != vsr_dir:\n            print(\"Warning: build in %s is using versioneer.py from %s\" % (os.path.dirname(my_path), versioneer_py))\n    except NameError:\n        pass\n    return root\n\n\ndef get_config_from_root(root):\n    \"\"\"Read the project setup.cfg file to determine Versioneer config.\"\"\"\n    # This might raise OSError (if setup.cfg is missing), or\n    # configparser.NoSectionError (if it lacks a [versioneer] section), or\n    # configparser.NoOptionError (if it lacks \"VCS=\"). See the docstring at\n    # the top of versioneer.py for instructions on writing your setup.cfg .\n    setup_cfg = os.path.join(root, \"setup.cfg\")\n    parser = configparser.ConfigParser()\n    with open(setup_cfg, \"r\") as cfg_file:\n        parser.read_file(cfg_file)\n    VCS = parser.get(\"versioneer\", \"VCS\")  # mandatory\n\n    # Dict-like interface for non-mandatory entries\n    section = parser[\"versioneer\"]\n\n    cfg = VersioneerConfig()\n    cfg.VCS = VCS\n    cfg.style = section.get(\"style\", \"\")\n    cfg.versionfile_source = section.get(\"versionfile_source\")\n    cfg.versionfile_build = section.get(\"versionfile_build\")\n    cfg.tag_prefix = section.get(\"tag_prefix\")\n    if cfg.tag_prefix in (\"''\", '\"\"', None):\n        cfg.tag_prefix = \"\"\n    cfg.parentdir_prefix = section.get(\"parentdir_prefix\")\n    cfg.verbose = section.get(\"verbose\")\n    return cfg\n\n\nclass NotThisMethod(Exception):\n    \"\"\"Exception raised if a method is not valid for the current scenario.\"\"\"\n\n\n# these dictionaries contain VCS-specific tools\nLONG_VERSION_PY: Dict[str, str] = {}\nHANDLERS: 
Dict[str, Dict[str, Callable]] = {}\n\n\ndef register_vcs_handler(vcs, method):  # decorator\n    \"\"\"Create decorator to mark a method as the handler of a VCS.\"\"\"\n\n    def decorate(f):\n        \"\"\"Store f in HANDLERS[vcs][method].\"\"\"\n        HANDLERS.setdefault(vcs, {})[method] = f\n        return f\n\n    return decorate\n\n\ndef run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):\n    \"\"\"Call the given command(s).\"\"\"\n    assert isinstance(commands, list)\n    process = None\n\n    popen_kwargs = {}\n    if sys.platform == \"win32\":\n        # This hides the console window if pythonw.exe is used\n        startupinfo = subprocess.STARTUPINFO()\n        startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW\n        popen_kwargs[\"startupinfo\"] = startupinfo\n\n    for command in commands:\n        try:\n            dispcmd = str([command] + args)\n            # remember shell=False, so use git.cmd on windows, not just git\n            process = subprocess.Popen(\n                [command] + args,\n                cwd=cwd,\n                env=env,\n                stdout=subprocess.PIPE,\n                stderr=(subprocess.PIPE if hide_stderr else None),\n                **popen_kwargs,\n            )\n            break\n        except OSError:\n            e = sys.exc_info()[1]\n            if e.errno == errno.ENOENT:\n                continue\n            if verbose:\n                print(\"unable to run %s\" % dispcmd)\n                print(e)\n            return None, None\n    else:\n        if verbose:\n            print(\"unable to find command, tried %s\" % (commands,))\n        return None, None\n    stdout = process.communicate()[0].strip().decode()\n    if process.returncode != 0:\n        if verbose:\n            print(\"unable to run %s (error)\" % dispcmd)\n            print(\"stdout was %s\" % stdout)\n        return None, process.returncode\n    return stdout, 
process.returncode\n\n\nLONG_VERSION_PY[\n    \"git\"\n] = r'''\n# This file helps to compute a version number in source trees obtained from\n# git-archive tarball (such as those provided by githubs download-from-tag\n# feature). Distribution tarballs (built by setup.py sdist) and build\n# directories (produced by setup.py build) will contain a much shorter file\n# that just contains the computed version number.\n\n# This file is released into the public domain. Generated by\n# versioneer-0.23 (https://github.com/python-versioneer/python-versioneer)\n\n\"\"\"Git implementation of _version.py.\"\"\"\n\nimport errno\nimport os\nimport re\nimport subprocess\nimport sys\nfrom typing import Callable, Dict\nimport functools\n\n\ndef get_keywords():\n    \"\"\"Get the keywords needed to look up the version information.\"\"\"\n    # these strings will be replaced by git during git-archive.\n    # setup.py/versioneer.py will grep for the variable names, so they must\n    # each be defined on a line of their own. 
_version.py will just call\n    # get_keywords().\n    git_refnames = \"%(DOLLAR)sFormat:%%d%(DOLLAR)s\"\n    git_full = \"%(DOLLAR)sFormat:%%H%(DOLLAR)s\"\n    git_date = \"%(DOLLAR)sFormat:%%ci%(DOLLAR)s\"\n    keywords = {\"refnames\": git_refnames, \"full\": git_full, \"date\": git_date}\n    return keywords\n\n\nclass VersioneerConfig:\n    \"\"\"Container for Versioneer configuration parameters.\"\"\"\n\n\ndef get_config():\n    \"\"\"Create, populate and return the VersioneerConfig() object.\"\"\"\n    # these strings are filled in when 'setup.py versioneer' creates\n    # _version.py\n    cfg = VersioneerConfig()\n    cfg.VCS = \"git\"\n    cfg.style = \"%(STYLE)s\"\n    cfg.tag_prefix = \"%(TAG_PREFIX)s\"\n    cfg.parentdir_prefix = \"%(PARENTDIR_PREFIX)s\"\n    cfg.versionfile_source = \"%(VERSIONFILE_SOURCE)s\"\n    cfg.verbose = False\n    return cfg\n\n\nclass NotThisMethod(Exception):\n    \"\"\"Exception raised if a method is not valid for the current scenario.\"\"\"\n\n\nLONG_VERSION_PY: Dict[str, str] = {}\nHANDLERS: Dict[str, Dict[str, Callable]] = {}\n\n\ndef register_vcs_handler(vcs, method):  # decorator\n    \"\"\"Create decorator to mark a method as the handler of a VCS.\"\"\"\n    def decorate(f):\n        \"\"\"Store f in HANDLERS[vcs][method].\"\"\"\n        if vcs not in HANDLERS:\n            HANDLERS[vcs] = {}\n        HANDLERS[vcs][method] = f\n        return f\n    return decorate\n\n\ndef run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,\n                env=None):\n    \"\"\"Call the given command(s).\"\"\"\n    assert isinstance(commands, list)\n    process = None\n\n    popen_kwargs = {}\n    if sys.platform == \"win32\":\n        # This hides the console window if pythonw.exe is used\n        startupinfo = subprocess.STARTUPINFO()\n        startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW\n        popen_kwargs[\"startupinfo\"] = startupinfo\n\n    for command in commands:\n        try:\n            
dispcmd = str([command] + args)\n            # remember shell=False, so use git.cmd on windows, not just git\n            process = subprocess.Popen([command] + args, cwd=cwd, env=env,\n                                       stdout=subprocess.PIPE,\n                                       stderr=(subprocess.PIPE if hide_stderr\n                                               else None), **popen_kwargs)\n            break\n        except OSError:\n            e = sys.exc_info()[1]\n            if e.errno == errno.ENOENT:\n                continue\n            if verbose:\n                print(\"unable to run %%s\" %% dispcmd)\n                print(e)\n            return None, None\n    else:\n        if verbose:\n            print(\"unable to find command, tried %%s\" %% (commands,))\n        return None, None\n    stdout = process.communicate()[0].strip().decode()\n    if process.returncode != 0:\n        if verbose:\n            print(\"unable to run %%s (error)\" %% dispcmd)\n            print(\"stdout was %%s\" %% stdout)\n        return None, process.returncode\n    return stdout, process.returncode\n\n\ndef versions_from_parentdir(parentdir_prefix, root, verbose):\n    \"\"\"Try to determine the version from the parent directory name.\n\n    Source tarballs conventionally unpack into a directory that includes both\n    the project name and a version string. 
We will also support searching up\n    two directory levels for an appropriately named parent directory\n    \"\"\"\n    rootdirs = []\n\n    for _ in range(3):\n        dirname = os.path.basename(root)\n        if dirname.startswith(parentdir_prefix):\n            return {\"version\": dirname[len(parentdir_prefix):],\n                    \"full-revisionid\": None,\n                    \"dirty\": False, \"error\": None, \"date\": None}\n        rootdirs.append(root)\n        root = os.path.dirname(root)  # up a level\n\n    if verbose:\n        print(\"Tried directories %%s but none started with prefix %%s\" %%\n              (str(rootdirs), parentdir_prefix))\n    raise NotThisMethod(\"rootdir doesn't start with parentdir_prefix\")\n\n\n@register_vcs_handler(\"git\", \"get_keywords\")\ndef git_get_keywords(versionfile_abs):\n    \"\"\"Extract version information from the given file.\"\"\"\n    # the code embedded in _version.py can just fetch the value of these\n    # keywords. When used from setup.py, we don't want to import _version.py,\n    # so we do it with a regexp instead. 
This function is not used from\n    # _version.py.\n    keywords = {}\n    try:\n        with open(versionfile_abs, \"r\") as fobj:\n            for line in fobj:\n                if line.strip().startswith(\"git_refnames =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"refnames\"] = mo.group(1)\n                if line.strip().startswith(\"git_full =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"full\"] = mo.group(1)\n                if line.strip().startswith(\"git_date =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"date\"] = mo.group(1)\n    except OSError:\n        pass\n    return keywords\n\n\n@register_vcs_handler(\"git\", \"keywords\")\ndef git_versions_from_keywords(keywords, tag_prefix, verbose):\n    \"\"\"Get version information from git keywords.\"\"\"\n    if \"refnames\" not in keywords:\n        raise NotThisMethod(\"Short version file found\")\n    date = keywords.get(\"date\")\n    if date is not None:\n        # Use only the last line.  Previous lines may contain GPG signature\n        # information.\n        date = date.splitlines()[-1]\n\n        # git-2.2.0 added \"%%cI\", which expands to an ISO-8601 -compliant\n        # datestamp. 
However we prefer \"%%ci\" (which expands to an \"ISO-8601\n        # -like\" string, which we must then edit to make compliant), because\n        # it's been around since git-1.5.3, and it's too difficult to\n        # discover which version we're using, or to work around using an\n        # older one.\n        date = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n    refnames = keywords[\"refnames\"].strip()\n    if refnames.startswith(\"$Format\"):\n        if verbose:\n            print(\"keywords are unexpanded, not using\")\n        raise NotThisMethod(\"unexpanded keywords, not a git-archive tarball\")\n    refs = {r.strip() for r in refnames.strip(\"()\").split(\",\")}\n    # starting in git-1.8.3, tags are listed as \"tag: foo-1.0\" instead of\n    # just \"foo-1.0\". If we see a \"tag: \" prefix, prefer those.\n    TAG = \"tag: \"\n    tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}\n    if not tags:\n        # Either we're using git < 1.8.3, or there really are no tags. We use\n        # a heuristic: assume all version tags have a digit. The old git %%d\n        # expansion behaves like git log --decorate=short and strips out the\n        # refs/heads/ and refs/tags/ prefixes that would let us distinguish\n        # between branches and tags. By ignoring refnames without digits, we\n        # filter out many common branch names like \"release\" and\n        # \"stabilization\", as well as \"HEAD\" and \"master\".\n        tags = {r for r in refs if re.search(r'\\d', r)}\n        if verbose:\n            print(\"discarding '%%s', no digits\" %% \",\".join(refs - tags))\n    if verbose:\n        print(\"likely tags: %%s\" %% \",\".join(sorted(tags)))\n    for ref in sorted(tags):\n        # sorting will prefer e.g. 
\"2.0\" over \"2.0rc1\"\n        if ref.startswith(tag_prefix):\n            r = ref[len(tag_prefix):]\n            # Filter out refs that exactly match prefix or that don't start\n            # with a number once the prefix is stripped (mostly a concern\n            # when prefix is '')\n            if not re.match(r'\\d', r):\n                continue\n            if verbose:\n                print(\"picking %%s\" %% r)\n            return {\"version\": r,\n                    \"full-revisionid\": keywords[\"full\"].strip(),\n                    \"dirty\": False, \"error\": None,\n                    \"date\": date}\n    # no suitable tags, so version is \"0+unknown\", but full hex is still there\n    if verbose:\n        print(\"no suitable tags, using unknown + full revision id\")\n    return {\"version\": \"0+unknown\",\n            \"full-revisionid\": keywords[\"full\"].strip(),\n            \"dirty\": False, \"error\": \"no suitable tags\", \"date\": None}\n\n\n@register_vcs_handler(\"git\", \"pieces_from_vcs\")\ndef git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):\n    \"\"\"Get version from 'git describe' in the root of the source tree.\n\n    This only gets called if the git-archive 'subst' keywords were *not*\n    expanded, and _version.py hasn't already been rewritten with a short\n    version string, meaning we're inside a checked out source tree.\n    \"\"\"\n    GITS = [\"git\"]\n    if sys.platform == \"win32\":\n        GITS = [\"git.cmd\", \"git.exe\"]\n\n    # GIT_DIR can interfere with correct operation of Versioneer.\n    # It may be intended to be passed to the Versioneer-versioned project,\n    # but that should not change where we get our version from.\n    env = os.environ.copy()\n    env.pop(\"GIT_DIR\", None)\n    runner = functools.partial(runner, env=env)\n\n    _, rc = runner(GITS, [\"rev-parse\", \"--git-dir\"], cwd=root,\n                   hide_stderr=True)\n    if rc != 0:\n        if verbose:\n            
print(\"Directory %%s not under git control\" %% root)\n        raise NotThisMethod(\"'git rev-parse --git-dir' returned error\")\n\n    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]\n    # if there isn't one, this yields HEX[-dirty] (no NUM)\n    describe_out, rc = runner(GITS, [\n        \"describe\", \"--tags\", \"--dirty\", \"--always\", \"--long\",\n        \"--match\", f\"{tag_prefix}[[:digit:]]*\"\n        ], cwd=root)\n    # --long was added in git-1.5.5\n    if describe_out is None:\n        raise NotThisMethod(\"'git describe' failed\")\n    describe_out = describe_out.strip()\n    full_out, rc = runner(GITS, [\"rev-parse\", \"HEAD\"], cwd=root)\n    if full_out is None:\n        raise NotThisMethod(\"'git rev-parse' failed\")\n    full_out = full_out.strip()\n\n    pieces = {}\n    pieces[\"long\"] = full_out\n    pieces[\"short\"] = full_out[:7]  # maybe improved later\n    pieces[\"error\"] = None\n\n    branch_name, rc = runner(GITS, [\"rev-parse\", \"--abbrev-ref\", \"HEAD\"],\n                             cwd=root)\n    # --abbrev-ref was added in git-1.6.3\n    if rc != 0 or branch_name is None:\n        raise NotThisMethod(\"'git rev-parse --abbrev-ref' returned error\")\n    branch_name = branch_name.strip()\n\n    if branch_name == \"HEAD\":\n        # If we aren't exactly on a branch, pick a branch which represents\n        # the current commit. 
If all else fails, we are on a branchless\n        # commit.\n        branches, rc = runner(GITS, [\"branch\", \"--contains\"], cwd=root)\n        # --contains was added in git-1.5.4\n        if rc != 0 or branches is None:\n            raise NotThisMethod(\"'git branch --contains' returned error\")\n        branches = branches.split(\"\\n\")\n\n        # Remove the first line if we're running detached\n        if \"(\" in branches[0]:\n            branches.pop(0)\n\n        # Strip off the leading \"* \" from the list of branches.\n        branches = [branch[2:] for branch in branches]\n        if \"master\" in branches:\n            branch_name = \"master\"\n        elif not branches:\n            branch_name = None\n        else:\n            # Pick the first branch that is returned. Good or bad.\n            branch_name = branches[0]\n\n    pieces[\"branch\"] = branch_name\n\n    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]\n    # TAG might have hyphens.\n    git_describe = describe_out\n\n    # look for -dirty suffix\n    dirty = git_describe.endswith(\"-dirty\")\n    pieces[\"dirty\"] = dirty\n    if dirty:\n        git_describe = git_describe[:git_describe.rindex(\"-dirty\")]\n\n    # now we have TAG-NUM-gHEX or HEX\n\n    if \"-\" in git_describe:\n        # TAG-NUM-gHEX\n        mo = re.search(r'^(.+)-(\\d+)-g([0-9a-f]+)$', git_describe)\n        if not mo:\n            # unparsable. 
Maybe git-describe is misbehaving?\n            pieces[\"error\"] = (\"unable to parse git-describe output: '%%s'\"\n                               %% describe_out)\n            return pieces\n\n        # tag\n        full_tag = mo.group(1)\n        if not full_tag.startswith(tag_prefix):\n            if verbose:\n                fmt = \"tag '%%s' doesn't start with prefix '%%s'\"\n                print(fmt %% (full_tag, tag_prefix))\n            pieces[\"error\"] = (\"tag '%%s' doesn't start with prefix '%%s'\"\n                               %% (full_tag, tag_prefix))\n            return pieces\n        pieces[\"closest-tag\"] = full_tag[len(tag_prefix):]\n\n        # distance: number of commits since tag\n        pieces[\"distance\"] = int(mo.group(2))\n\n        # commit: short hex revision ID\n        pieces[\"short\"] = mo.group(3)\n\n    else:\n        # HEX: no tags\n        pieces[\"closest-tag\"] = None\n        out, rc = runner(GITS, [\"rev-list\", \"HEAD\", \"--left-right\"], cwd=root)\n        pieces[\"distance\"] = len(out.split())  # total number of commits\n\n    # commit date: see ISO-8601 comment in git_versions_from_keywords()\n    date = runner(GITS, [\"show\", \"-s\", \"--format=%%ci\", \"HEAD\"], cwd=root)[0].strip()\n    # Use only the last line.  Previous lines may contain GPG signature\n    # information.\n    date = date.splitlines()[-1]\n    pieces[\"date\"] = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n\n    return pieces\n\n\ndef plus_or_dot(pieces):\n    \"\"\"Return a + if we don't already have one, else return a .\"\"\"\n    if \"+\" in pieces.get(\"closest-tag\", \"\"):\n        return \".\"\n    return \"+\"\n\n\ndef render_pep440(pieces):\n    \"\"\"Build up version string, with post-release \"local version identifier\".\n\n    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you\n    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty\n\n    Exceptions:\n    1: no tags. 
git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += plus_or_dot(pieces)\n            rendered += \"%%d.g%%s\" %% (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0+untagged.%%d.g%%s\" %% (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_branch(pieces):\n    \"\"\"TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .\n\n    The \".dev0\" means not master branch. Note that .dev0 sorts backwards\n    (a feature branch will appear \"older\" than the master branch).\n\n    Exceptions:\n    1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            if pieces[\"branch\"] != \"master\":\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"%%d.g%%s\" %% (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0\"\n        if pieces[\"branch\"] != \"master\":\n            rendered += \".dev0\"\n        rendered += \"+untagged.%%d.g%%s\" %% (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef pep440_split_post(ver):\n    \"\"\"Split pep440 version string at the post-release segment.\n\n    Returns the release segments before the post-release and the\n    post-release version number (or -1 if no post-release 
segment is present).\n    \"\"\"\n    vc = str.split(ver, \".post\")\n    return vc[0], int(vc[1] or 0) if len(vc) == 2 else None\n\n\ndef render_pep440_pre(pieces):\n    \"\"\"TAG[.postN.devDISTANCE] -- No -dirty.\n\n    Exceptions:\n    1: no tags. 0.post0.devDISTANCE\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        if pieces[\"distance\"]:\n            # update the post release segment\n            tag_version, post_version = pep440_split_post(pieces[\"closest-tag\"])\n            rendered = tag_version\n            if post_version is not None:\n                rendered += \".post%%d.dev%%d\" %% (post_version + 1, pieces[\"distance\"])\n            else:\n                rendered += \".post0.dev%%d\" %% (pieces[\"distance\"])\n        else:\n            # no commits, use the tag as the version\n            rendered = pieces[\"closest-tag\"]\n    else:\n        # exception #1\n        rendered = \"0.post0.dev%%d\" %% pieces[\"distance\"]\n    return rendered\n\n\ndef render_pep440_post(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX] .\n\n    The \".dev0\" means dirty. Note that .dev0 sorts backwards\n    (a dirty tree will appear \"older\" than the corresponding clean one),\n    but you shouldn't be releasing software with -dirty anyways.\n\n    Exceptions:\n    1: no tags. 
0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%%d\" %% pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%%s\" %% pieces[\"short\"]\n    else:\n        # exception #1\n        rendered = \"0.post%%d\" %% pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n        rendered += \"+g%%s\" %% pieces[\"short\"]\n    return rendered\n\n\ndef render_pep440_post_branch(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .\n\n    The \".dev0\" means not master branch.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%%d\" %% pieces[\"distance\"]\n            if pieces[\"branch\"] != \"master\":\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%%s\" %% pieces[\"short\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0.post%%d\" %% pieces[\"distance\"]\n        if pieces[\"branch\"] != \"master\":\n            rendered += \".dev0\"\n        rendered += \"+g%%s\" %% pieces[\"short\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_old(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]] .\n\n    The \".dev0\" means dirty.\n\n    Exceptions:\n    1: no tags. 
0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%%d\" %% pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n    else:\n        # exception #1\n        rendered = \"0.post%%d\" %% pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n    return rendered\n\n\ndef render_git_describe(pieces):\n    \"\"\"TAG[-DISTANCE-gHEX][-dirty].\n\n    Like 'git describe --tags --dirty --always'.\n\n    Exceptions:\n    1: no tags. HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"]:\n            rendered += \"-%%d-g%%s\" %% (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render_git_describe_long(pieces):\n    \"\"\"TAG-DISTANCE-gHEX[-dirty].\n\n    Like 'git describe --tags --dirty --always -long'.\n    The distance/hash is unconditional.\n\n    Exceptions:\n    1: no tags. 
HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        rendered += \"-%%d-g%%s\" %% (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render(pieces, style):\n    \"\"\"Render the given version pieces into the requested style.\"\"\"\n    if pieces[\"error\"]:\n        return {\"version\": \"unknown\",\n                \"full-revisionid\": pieces.get(\"long\"),\n                \"dirty\": None,\n                \"error\": pieces[\"error\"],\n                \"date\": None}\n\n    if not style or style == \"default\":\n        style = \"pep440\"  # the default\n\n    if style == \"pep440\":\n        rendered = render_pep440(pieces)\n    elif style == \"pep440-branch\":\n        rendered = render_pep440_branch(pieces)\n    elif style == \"pep440-pre\":\n        rendered = render_pep440_pre(pieces)\n    elif style == \"pep440-post\":\n        rendered = render_pep440_post(pieces)\n    elif style == \"pep440-post-branch\":\n        rendered = render_pep440_post_branch(pieces)\n    elif style == \"pep440-old\":\n        rendered = render_pep440_old(pieces)\n    elif style == \"git-describe\":\n        rendered = render_git_describe(pieces)\n    elif style == \"git-describe-long\":\n        rendered = render_git_describe_long(pieces)\n    else:\n        raise ValueError(\"unknown style '%%s'\" %% style)\n\n    return {\"version\": rendered, \"full-revisionid\": pieces[\"long\"],\n            \"dirty\": pieces[\"dirty\"], \"error\": None,\n            \"date\": pieces.get(\"date\")}\n\n\ndef get_versions():\n    \"\"\"Get version information or return default if unable to do so.\"\"\"\n    # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have\n    # __file__, we can work backwards from there to the root. 
Some\n    # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which\n    # case we can only use expanded keywords.\n\n    cfg = get_config()\n    verbose = cfg.verbose\n\n    try:\n        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,\n                                          verbose)\n    except NotThisMethod:\n        pass\n\n    try:\n        root = os.path.realpath(__file__)\n        # versionfile_source is the relative path from the top of the source\n        # tree (where the .git directory might live) to this file. Invert\n        # this to find the root from __file__.\n        for _ in cfg.versionfile_source.split('/'):\n            root = os.path.dirname(root)\n    except NameError:\n        return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n                \"dirty\": None,\n                \"error\": \"unable to find root of source tree\",\n                \"date\": None}\n\n    try:\n        pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)\n        return render(pieces, cfg.style)\n    except NotThisMethod:\n        pass\n\n    try:\n        if cfg.parentdir_prefix:\n            return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)\n    except NotThisMethod:\n        pass\n\n    return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n            \"dirty\": None,\n            \"error\": \"unable to compute version\", \"date\": None}\n'''\n\n\n@register_vcs_handler(\"git\", \"get_keywords\")\ndef git_get_keywords(versionfile_abs):\n    \"\"\"Extract version information from the given file.\"\"\"\n    # the code embedded in _version.py can just fetch the value of these\n    # keywords. When used from setup.py, we don't want to import _version.py,\n    # so we do it with a regexp instead. 
This function is not used from\n    # _version.py.\n    keywords = {}\n    try:\n        with open(versionfile_abs, \"r\") as fobj:\n            for line in fobj:\n                if line.strip().startswith(\"git_refnames =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"refnames\"] = mo.group(1)\n                if line.strip().startswith(\"git_full =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"full\"] = mo.group(1)\n                if line.strip().startswith(\"git_date =\"):\n                    mo = re.search(r'=\\s*\"(.*)\"', line)\n                    if mo:\n                        keywords[\"date\"] = mo.group(1)\n    except OSError:\n        pass\n    return keywords\n\n\n@register_vcs_handler(\"git\", \"keywords\")\ndef git_versions_from_keywords(keywords, tag_prefix, verbose):\n    \"\"\"Get version information from git keywords.\"\"\"\n    if \"refnames\" not in keywords:\n        raise NotThisMethod(\"Short version file found\")\n    date = keywords.get(\"date\")\n    if date is not None:\n        # Use only the last line.  Previous lines may contain GPG signature\n        # information.\n        date = date.splitlines()[-1]\n\n        # git-2.2.0 added \"%cI\", which expands to an ISO-8601 -compliant\n        # datestamp. 
However we prefer \"%ci\" (which expands to an \"ISO-8601\n        # -like\" string, which we must then edit to make compliant), because\n        # it's been around since git-1.5.3, and it's too difficult to\n        # discover which version we're using, or to work around using an\n        # older one.\n        date = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n    refnames = keywords[\"refnames\"].strip()\n    if refnames.startswith(\"$Format\"):\n        if verbose:\n            print(\"keywords are unexpanded, not using\")\n        raise NotThisMethod(\"unexpanded keywords, not a git-archive tarball\")\n    refs = {r.strip() for r in refnames.strip(\"()\").split(\",\")}\n    # starting in git-1.8.3, tags are listed as \"tag: foo-1.0\" instead of\n    # just \"foo-1.0\". If we see a \"tag: \" prefix, prefer those.\n    TAG = \"tag: \"\n    tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}\n    if not tags:\n        # Either we're using git < 1.8.3, or there really are no tags. We use\n        # a heuristic: assume all version tags have a digit. The old git %d\n        # expansion behaves like git log --decorate=short and strips out the\n        # refs/heads/ and refs/tags/ prefixes that would let us distinguish\n        # between branches and tags. By ignoring refnames without digits, we\n        # filter out many common branch names like \"release\" and\n        # \"stabilization\", as well as \"HEAD\" and \"master\".\n        tags = {r for r in refs if re.search(r\"\\d\", r)}\n        if verbose:\n            print(\"discarding '%s', no digits\" % \",\".join(refs - tags))\n    if verbose:\n        print(\"likely tags: %s\" % \",\".join(sorted(tags)))\n    for ref in sorted(tags):\n        # sorting will prefer e.g. 
\"2.0\" over \"2.0rc1\"\n        if ref.startswith(tag_prefix):\n            r = ref[len(tag_prefix) :]\n            # Filter out refs that exactly match prefix or that don't start\n            # with a number once the prefix is stripped (mostly a concern\n            # when prefix is '')\n            if not re.match(r\"\\d\", r):\n                continue\n            if verbose:\n                print(\"picking %s\" % r)\n            return {\n                \"version\": r,\n                \"full-revisionid\": keywords[\"full\"].strip(),\n                \"dirty\": False,\n                \"error\": None,\n                \"date\": date,\n            }\n    # no suitable tags, so version is \"0+unknown\", but full hex is still there\n    if verbose:\n        print(\"no suitable tags, using unknown + full revision id\")\n    return {\n        \"version\": \"0+unknown\",\n        \"full-revisionid\": keywords[\"full\"].strip(),\n        \"dirty\": False,\n        \"error\": \"no suitable tags\",\n        \"date\": None,\n    }\n\n\n@register_vcs_handler(\"git\", \"pieces_from_vcs\")\ndef git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):\n    \"\"\"Get version from 'git describe' in the root of the source tree.\n\n    This only gets called if the git-archive 'subst' keywords were *not*\n    expanded, and _version.py hasn't already been rewritten with a short\n    version string, meaning we're inside a checked out source tree.\n    \"\"\"\n    GITS = [\"git\"]\n    if sys.platform == \"win32\":\n        GITS = [\"git.cmd\", \"git.exe\"]\n\n    # GIT_DIR can interfere with correct operation of Versioneer.\n    # It may be intended to be passed to the Versioneer-versioned project,\n    # but that should not change where we get our version from.\n    env = os.environ.copy()\n    env.pop(\"GIT_DIR\", None)\n    runner = functools.partial(runner, env=env)\n\n    _, rc = runner(GITS, [\"rev-parse\", \"--git-dir\"], cwd=root, hide_stderr=True)\n    if rc 
!= 0:\n        if verbose:\n            print(\"Directory %s not under git control\" % root)\n        raise NotThisMethod(\"'git rev-parse --git-dir' returned error\")\n\n    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]\n    # if there isn't one, this yields HEX[-dirty] (no NUM)\n    describe_out, rc = runner(\n        GITS, [\"describe\", \"--tags\", \"--dirty\", \"--always\", \"--long\", \"--match\", f\"{tag_prefix}[[:digit:]]*\"], cwd=root\n    )\n    # --long was added in git-1.5.5\n    if describe_out is None:\n        raise NotThisMethod(\"'git describe' failed\")\n    describe_out = describe_out.strip()\n    full_out, rc = runner(GITS, [\"rev-parse\", \"HEAD\"], cwd=root)\n    if full_out is None:\n        raise NotThisMethod(\"'git rev-parse' failed\")\n    full_out = full_out.strip()\n\n    pieces = {}\n    pieces[\"long\"] = full_out\n    pieces[\"short\"] = full_out[:7]  # maybe improved later\n    pieces[\"error\"] = None\n\n    branch_name, rc = runner(GITS, [\"rev-parse\", \"--abbrev-ref\", \"HEAD\"], cwd=root)\n    # --abbrev-ref was added in git-1.6.3\n    if rc != 0 or branch_name is None:\n        raise NotThisMethod(\"'git rev-parse --abbrev-ref' returned error\")\n    branch_name = branch_name.strip()\n\n    if branch_name == \"HEAD\":\n        # If we aren't exactly on a branch, pick a branch which represents\n        # the current commit. 
If all else fails, we are on a branchless\n        # commit.\n        branches, rc = runner(GITS, [\"branch\", \"--contains\"], cwd=root)\n        # --contains was added in git-1.5.4\n        if rc != 0 or branches is None:\n            raise NotThisMethod(\"'git branch --contains' returned error\")\n        branches = branches.split(\"\\n\")\n\n        # Remove the first line if we're running detached\n        if \"(\" in branches[0]:\n            branches.pop(0)\n\n        # Strip off the leading \"* \" from the list of branches.\n        branches = [branch[2:] for branch in branches]\n        if \"master\" in branches:\n            branch_name = \"master\"\n        elif not branches:\n            branch_name = None\n        else:\n            # Pick the first branch that is returned. Good or bad.\n            branch_name = branches[0]\n\n    pieces[\"branch\"] = branch_name\n\n    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]\n    # TAG might have hyphens.\n    git_describe = describe_out\n\n    # look for -dirty suffix\n    dirty = git_describe.endswith(\"-dirty\")\n    pieces[\"dirty\"] = dirty\n    if dirty:\n        git_describe = git_describe[: git_describe.rindex(\"-dirty\")]\n\n    # now we have TAG-NUM-gHEX or HEX\n\n    if \"-\" in git_describe:\n        # TAG-NUM-gHEX\n        mo = re.search(r\"^(.+)-(\\d+)-g([0-9a-f]+)$\", git_describe)\n        if not mo:\n            # unparsable. 
Maybe git-describe is misbehaving?\n            pieces[\"error\"] = \"unable to parse git-describe output: '%s'\" % describe_out\n            return pieces\n\n        # tag\n        full_tag = mo.group(1)\n        if not full_tag.startswith(tag_prefix):\n            if verbose:\n                fmt = \"tag '%s' doesn't start with prefix '%s'\"\n                print(fmt % (full_tag, tag_prefix))\n            pieces[\"error\"] = \"tag '%s' doesn't start with prefix '%s'\" % (full_tag, tag_prefix)\n            return pieces\n        pieces[\"closest-tag\"] = full_tag[len(tag_prefix) :]\n\n        # distance: number of commits since tag\n        pieces[\"distance\"] = int(mo.group(2))\n\n        # commit: short hex revision ID\n        pieces[\"short\"] = mo.group(3)\n\n    else:\n        # HEX: no tags\n        pieces[\"closest-tag\"] = None\n        out, rc = runner(GITS, [\"rev-list\", \"HEAD\", \"--left-right\"], cwd=root)\n        pieces[\"distance\"] = len(out.split())  # total number of commits\n\n    # commit date: see ISO-8601 comment in git_versions_from_keywords()\n    date = runner(GITS, [\"show\", \"-s\", \"--format=%ci\", \"HEAD\"], cwd=root)[0].strip()\n    # Use only the last line.  
Previous lines may contain GPG signature\n    # information.\n    date = date.splitlines()[-1]\n    pieces[\"date\"] = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n\n    return pieces\n\n\ndef do_vcs_install(versionfile_source, ipy):\n    \"\"\"Git-specific installation logic for Versioneer.\n\n    For Git, this means creating/changing .gitattributes to mark _version.py\n    for export-subst keyword substitution.\n    \"\"\"\n    GITS = [\"git\"]\n    if sys.platform == \"win32\":\n        GITS = [\"git.cmd\", \"git.exe\"]\n    files = [versionfile_source]\n    if ipy:\n        files.append(ipy)\n    try:\n        my_path = __file__\n        if my_path.endswith(\".pyc\") or my_path.endswith(\".pyo\"):\n            my_path = os.path.splitext(my_path)[0] + \".py\"\n        versioneer_file = os.path.relpath(my_path)\n    except NameError:\n        versioneer_file = \"versioneer.py\"\n    files.append(versioneer_file)\n    present = False\n    try:\n        with open(\".gitattributes\", \"r\") as fobj:\n            for line in fobj:\n                if line.strip().startswith(versionfile_source):\n                    if \"export-subst\" in line.strip().split()[1:]:\n                        present = True\n                        break\n    except OSError:\n        pass\n    if not present:\n        with open(\".gitattributes\", \"a+\") as fobj:\n            fobj.write(f\"{versionfile_source} export-subst\\n\")\n        files.append(\".gitattributes\")\n    run_command(GITS, [\"add\", \"--\"] + files)\n\n\ndef versions_from_parentdir(parentdir_prefix, root, verbose):\n    \"\"\"Try to determine the version from the parent directory name.\n\n    Source tarballs conventionally unpack into a directory that includes both\n    the project name and a version string. 
We will also support searching up\n    two directory levels for an appropriately named parent directory\n    \"\"\"\n    rootdirs = []\n\n    for _ in range(3):\n        dirname = os.path.basename(root)\n        if dirname.startswith(parentdir_prefix):\n            return {\n                \"version\": dirname[len(parentdir_prefix) :],\n                \"full-revisionid\": None,\n                \"dirty\": False,\n                \"error\": None,\n                \"date\": None,\n            }\n        rootdirs.append(root)\n        root = os.path.dirname(root)  # up a level\n\n    if verbose:\n        print(\"Tried directories %s but none started with prefix %s\" % (str(rootdirs), parentdir_prefix))\n    raise NotThisMethod(\"rootdir doesn't start with parentdir_prefix\")\n\n\nSHORT_VERSION_PY = \"\"\"\n# This file was generated by 'versioneer.py' (0.23) from\n# revision-control system data, or from the parent directory name of an\n# unpacked source archive. Distribution tarballs contain a pre-generated copy\n# of this file.\n\nimport json\n\nversion_json = '''\n%s\n'''  # END VERSION_JSON\n\n\ndef get_versions():\n    return json.loads(version_json)\n\"\"\"\n\n\ndef versions_from_file(filename):\n    \"\"\"Try to determine the version from _version.py if present.\"\"\"\n    try:\n        with open(filename) as f:\n            contents = f.read()\n    except OSError:\n        raise NotThisMethod(\"unable to read _version.py\")\n    mo = re.search(r\"version_json = '''\\n(.*)'''  # END VERSION_JSON\", contents, re.M | re.S)\n    if not mo:\n        mo = re.search(r\"version_json = '''\\r\\n(.*)'''  # END VERSION_JSON\", contents, re.M | re.S)\n    if not mo:\n        raise NotThisMethod(\"no version_json in _version.py\")\n    return json.loads(mo.group(1))\n\n\ndef write_to_version_file(filename, versions):\n    \"\"\"Write the given version number to the given _version.py file.\"\"\"\n    os.unlink(filename)\n    contents = json.dumps(versions, sort_keys=True, 
indent=1, separators=(\",\", \": \"))\n    with open(filename, \"w\") as f:\n        f.write(SHORT_VERSION_PY % contents)\n\n    print(\"set %s to '%s'\" % (filename, versions[\"version\"]))\n\n\ndef plus_or_dot(pieces):\n    \"\"\"Return a + if we don't already have one, else return a .\"\"\"\n    if \"+\" in pieces.get(\"closest-tag\", \"\"):\n        return \".\"\n    return \"+\"\n\n\ndef render_pep440(pieces):\n    \"\"\"Build up version string, with post-release \"local version identifier\".\n\n    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you\n    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty\n\n    Exceptions:\n    1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += plus_or_dot(pieces)\n            rendered += \"%d.g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0+untagged.%d.g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_branch(pieces):\n    \"\"\"TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .\n\n    The \".dev0\" means not master branch. Note that .dev0 sorts backwards\n    (a feature branch will appear \"older\" than the master branch).\n\n    Exceptions:\n    1: no tags. 
0[.dev0]+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            if pieces[\"branch\"] != \"master\":\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"%d.g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0\"\n        if pieces[\"branch\"] != \"master\":\n            rendered += \".dev0\"\n        rendered += \"+untagged.%d.g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef pep440_split_post(ver):\n    \"\"\"Split pep440 version string at the post-release segment.\n\n    Returns the release segments before the post-release and the\n    post-release version number (or -1 if no post-release segment is present).\n    \"\"\"\n    vc = str.split(ver, \".post\")\n    return vc[0], int(vc[1] or 0) if len(vc) == 2 else None\n\n\ndef render_pep440_pre(pieces):\n    \"\"\"TAG[.postN.devDISTANCE] -- No -dirty.\n\n    Exceptions:\n    1: no tags. 
0.post0.devDISTANCE\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        if pieces[\"distance\"]:\n            # update the post release segment\n            tag_version, post_version = pep440_split_post(pieces[\"closest-tag\"])\n            rendered = tag_version\n            if post_version is not None:\n                rendered += \".post%d.dev%d\" % (post_version + 1, pieces[\"distance\"])\n            else:\n                rendered += \".post0.dev%d\" % (pieces[\"distance\"])\n        else:\n            # no commits, use the tag as the version\n            rendered = pieces[\"closest-tag\"]\n    else:\n        # exception #1\n        rendered = \"0.post0.dev%d\" % pieces[\"distance\"]\n    return rendered\n\n\ndef render_pep440_post(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX] .\n\n    The \".dev0\" means dirty. Note that .dev0 sorts backwards\n    (a dirty tree will appear \"older\" than the corresponding clean one),\n    but you shouldn't be releasing software with -dirty anyways.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%s\" % pieces[\"short\"]\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n        rendered += \"+g%s\" % pieces[\"short\"]\n    return rendered\n\n\ndef render_pep440_post_branch(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .\n\n    The \".dev0\" means not master branch.\n\n    Exceptions:\n    1: no tags. 
0.postDISTANCE[.dev0]+gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"branch\"] != \"master\":\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%s\" % pieces[\"short\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"branch\"] != \"master\":\n            rendered += \".dev0\"\n        rendered += \"+g%s\" % pieces[\"short\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_old(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]] .\n\n    The \".dev0\" means dirty.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n    return rendered\n\n\ndef render_git_describe(pieces):\n    \"\"\"TAG[-DISTANCE-gHEX][-dirty].\n\n    Like 'git describe --tags --dirty --always'.\n\n    Exceptions:\n    1: no tags. 
HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"]:\n            rendered += \"-%d-g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render_git_describe_long(pieces):\n    \"\"\"TAG-DISTANCE-gHEX[-dirty].\n\n    Like 'git describe --tags --dirty --always -long'.\n    The distance/hash is unconditional.\n\n    Exceptions:\n    1: no tags. HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        rendered += \"-%d-g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render(pieces, style):\n    \"\"\"Render the given version pieces into the requested style.\"\"\"\n    if pieces[\"error\"]:\n        return {\n            \"version\": \"unknown\",\n            \"full-revisionid\": pieces.get(\"long\"),\n            \"dirty\": None,\n            \"error\": pieces[\"error\"],\n            \"date\": None,\n        }\n\n    if not style or style == \"default\":\n        style = \"pep440\"  # the default\n\n    if style == \"pep440\":\n        rendered = render_pep440(pieces)\n    elif style == \"pep440-branch\":\n        rendered = render_pep440_branch(pieces)\n    elif style == \"pep440-pre\":\n        rendered = render_pep440_pre(pieces)\n    elif style == \"pep440-post\":\n        rendered = render_pep440_post(pieces)\n    elif style == \"pep440-post-branch\":\n        rendered = render_pep440_post_branch(pieces)\n    elif style == \"pep440-old\":\n        rendered = render_pep440_old(pieces)\n    elif style == \"git-describe\":\n        rendered = render_git_describe(pieces)\n  
  elif style == \"git-describe-long\":\n        rendered = render_git_describe_long(pieces)\n    else:\n        raise ValueError(\"unknown style '%s'\" % style)\n\n    return {\n        \"version\": rendered,\n        \"full-revisionid\": pieces[\"long\"],\n        \"dirty\": pieces[\"dirty\"],\n        \"error\": None,\n        \"date\": pieces.get(\"date\"),\n    }\n\n\nclass VersioneerBadRootError(Exception):\n    \"\"\"The project root directory is unknown or missing key files.\"\"\"\n\n\ndef get_versions(verbose=False):\n    \"\"\"Get the project version from whatever source is available.\n\n    Returns dict with two keys: 'version' and 'full'.\n    \"\"\"\n    if \"versioneer\" in sys.modules:\n        # see the discussion in cmdclass.py:get_cmdclass()\n        del sys.modules[\"versioneer\"]\n\n    root = get_root()\n    cfg = get_config_from_root(root)\n\n    assert cfg.VCS is not None, \"please set [versioneer]VCS= in setup.cfg\"\n    handlers = HANDLERS.get(cfg.VCS)\n    assert handlers, \"unrecognized VCS '%s'\" % cfg.VCS\n    verbose = verbose or cfg.verbose\n    assert cfg.versionfile_source is not None, \"please set versioneer.versionfile_source\"\n    assert cfg.tag_prefix is not None, \"please set versioneer.tag_prefix\"\n\n    versionfile_abs = os.path.join(root, cfg.versionfile_source)\n\n    # extract version from first of: _version.py, VCS command (e.g. 'git\n    # describe'), parentdir. 
This is meant to work for developers using a\n    # source checkout, for users of a tarball created by 'setup.py sdist',\n    # and for users of a tarball/zipball created by 'git archive' or github's\n    # download-from-tag feature or the equivalent in other VCSes.\n\n    get_keywords_f = handlers.get(\"get_keywords\")\n    from_keywords_f = handlers.get(\"keywords\")\n    if get_keywords_f and from_keywords_f:\n        try:\n            keywords = get_keywords_f(versionfile_abs)\n            ver = from_keywords_f(keywords, cfg.tag_prefix, verbose)\n            if verbose:\n                print(\"got version from expanded keyword %s\" % ver)\n            return ver\n        except NotThisMethod:\n            pass\n\n    try:\n        ver = versions_from_file(versionfile_abs)\n        if verbose:\n            print(\"got version from file %s %s\" % (versionfile_abs, ver))\n        return ver\n    except NotThisMethod:\n        pass\n\n    from_vcs_f = handlers.get(\"pieces_from_vcs\")\n    if from_vcs_f:\n        try:\n            pieces = from_vcs_f(cfg.tag_prefix, root, verbose)\n            ver = render(pieces, cfg.style)\n            if verbose:\n                print(\"got version from VCS %s\" % ver)\n            return ver\n        except NotThisMethod:\n            pass\n\n    try:\n        if cfg.parentdir_prefix:\n            ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose)\n            if verbose:\n                print(\"got version from parentdir %s\" % ver)\n            return ver\n    except NotThisMethod:\n        pass\n\n    if verbose:\n        print(\"unable to compute version\")\n\n    return {\n        \"version\": \"0+unknown\",\n        \"full-revisionid\": None,\n        \"dirty\": None,\n        \"error\": \"unable to compute version\",\n        \"date\": None,\n    }\n\n\ndef get_version():\n    \"\"\"Get the short version string for this project.\"\"\"\n    return get_versions()[\"version\"]\n\n\ndef 
get_cmdclass(cmdclass=None):\n    \"\"\"Get the custom setuptools subclasses used by Versioneer.\n\n    If the package uses a different cmdclass (e.g. one from numpy), it\n    should be provide as an argument.\n    \"\"\"\n    if \"versioneer\" in sys.modules:\n        del sys.modules[\"versioneer\"]\n        # this fixes the \"python setup.py develop\" case (also 'install' and\n        # 'easy_install .'), in which subdependencies of the main project are\n        # built (using setup.py bdist_egg) in the same python process. Assume\n        # a main project A and a dependency B, which use different versions\n        # of Versioneer. A's setup.py imports A's Versioneer, leaving it in\n        # sys.modules by the time B's setup.py is executed, causing B to run\n        # with the wrong versioneer. Setuptools wraps the sub-dep builds in a\n        # sandbox that restores sys.modules to it's pre-build state, so the\n        # parent is protected against the child's \"import versioneer\". By\n        # removing ourselves from sys.modules here, before the child build\n        # happens, we protect the child from the parent's versioneer too.\n        # Also see https://github.com/python-versioneer/python-versioneer/issues/52\n\n    cmds = {} if cmdclass is None else cmdclass.copy()\n\n    # we add \"version\" to setuptools\n    from setuptools import Command\n\n    class cmd_version(Command):\n        description = \"report generated version string\"\n        user_options = []\n        boolean_options = []\n\n        def initialize_options(self):\n            pass\n\n        def finalize_options(self):\n            pass\n\n        def run(self):\n            vers = get_versions(verbose=True)\n            print(\"Version: %s\" % vers[\"version\"])\n            print(\" full-revisionid: %s\" % vers.get(\"full-revisionid\"))\n            print(\" dirty: %s\" % vers.get(\"dirty\"))\n            print(\" date: %s\" % vers.get(\"date\"))\n            if vers[\"error\"]:\n     
           print(\" error: %s\" % vers[\"error\"])\n\n    cmds[\"version\"] = cmd_version\n\n    # we override \"build_py\" in setuptools\n    #\n    # most invocation pathways end up running build_py:\n    #  distutils/build -> build_py\n    #  distutils/install -> distutils/build ->..\n    #  setuptools/bdist_wheel -> distutils/install ->..\n    #  setuptools/bdist_egg -> distutils/install_lib -> build_py\n    #  setuptools/install -> bdist_egg ->..\n    #  setuptools/develop -> ?\n    #  pip install:\n    #   copies source tree to a tempdir before running egg_info/etc\n    #   if .git isn't copied too, 'git describe' will fail\n    #   then does setup.py bdist_wheel, or sometimes setup.py install\n    #  setup.py egg_info -> ?\n\n    # pip install -e . and setuptool/editable_wheel will invoke build_py\n    # but the build_py command is not expected to copy any files.\n\n    # we override different \"build_py\" commands for both environments\n    if \"build_py\" in cmds:\n        _build_py = cmds[\"build_py\"]\n    else:\n        from setuptools.command.build_py import build_py as _build_py\n\n    class cmd_build_py(_build_py):\n        def run(self):\n            root = get_root()\n            cfg = get_config_from_root(root)\n            versions = get_versions()\n            _build_py.run(self)\n            if getattr(self, \"editable_mode\", False):\n                # During editable installs `.py` and data files are\n                # not copied to build_lib\n                return\n            # now locate _version.py in the new build/ directory and replace\n            # it with an updated value\n            if cfg.versionfile_build:\n                target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build)\n                print(\"UPDATING %s\" % target_versionfile)\n                write_to_version_file(target_versionfile, versions)\n\n    cmds[\"build_py\"] = cmd_build_py\n\n    if \"build_ext\" in cmds:\n        _build_ext = 
cmds[\"build_ext\"]\n    else:\n        from setuptools.command.build_ext import build_ext as _build_ext\n\n    class cmd_build_ext(_build_ext):\n        def run(self):\n            root = get_root()\n            cfg = get_config_from_root(root)\n            versions = get_versions()\n            _build_ext.run(self)\n            if self.inplace:\n                # build_ext --inplace will only build extensions in\n                # build/lib<..> dir with no _version.py to write to.\n                # As in place builds will already have a _version.py\n                # in the module dir, we do not need to write one.\n                return\n            # now locate _version.py in the new build/ directory and replace\n            # it with an updated value\n            target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build)\n            if not os.path.exists(target_versionfile):\n                print(\n                    f\"Warning: {target_versionfile} does not exist, skipping \"\n                    \"version update. 
This can happen if you are running build_ext \"\n                    \"without first running build_py.\"\n                )\n                return\n            print(\"UPDATING %s\" % target_versionfile)\n            write_to_version_file(target_versionfile, versions)\n\n    cmds[\"build_ext\"] = cmd_build_ext\n\n    if \"cx_Freeze\" in sys.modules:  # cx_freeze enabled?\n        from cx_Freeze.dist import build_exe as _build_exe\n\n        # nczeczulin reports that py2exe won't like the pep440-style string\n        # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g.\n        # setup(console=[{\n        #   \"version\": versioneer.get_version().split(\"+\", 1)[0], # FILEVERSION\n        #   \"product_version\": versioneer.get_version(),\n        #   ...\n\n        class cmd_build_exe(_build_exe):\n            def run(self):\n                root = get_root()\n                cfg = get_config_from_root(root)\n                versions = get_versions()\n                target_versionfile = cfg.versionfile_source\n                print(\"UPDATING %s\" % target_versionfile)\n                write_to_version_file(target_versionfile, versions)\n\n                _build_exe.run(self)\n                os.unlink(target_versionfile)\n                with open(cfg.versionfile_source, \"w\") as f:\n                    LONG = LONG_VERSION_PY[cfg.VCS]\n                    f.write(\n                        LONG\n                        % {\n                            \"DOLLAR\": \"$\",\n                            \"STYLE\": cfg.style,\n                            \"TAG_PREFIX\": cfg.tag_prefix,\n                            \"PARENTDIR_PREFIX\": cfg.parentdir_prefix,\n                            \"VERSIONFILE_SOURCE\": cfg.versionfile_source,\n                        }\n                    )\n\n        cmds[\"build_exe\"] = cmd_build_exe\n        del cmds[\"build_py\"]\n\n    if \"py2exe\" in sys.modules:  # py2exe enabled?\n        from py2exe.distutils_buildexe 
import py2exe as _py2exe\n\n        class cmd_py2exe(_py2exe):\n            def run(self):\n                root = get_root()\n                cfg = get_config_from_root(root)\n                versions = get_versions()\n                target_versionfile = cfg.versionfile_source\n                print(\"UPDATING %s\" % target_versionfile)\n                write_to_version_file(target_versionfile, versions)\n\n                _py2exe.run(self)\n                os.unlink(target_versionfile)\n                with open(cfg.versionfile_source, \"w\") as f:\n                    LONG = LONG_VERSION_PY[cfg.VCS]\n                    f.write(\n                        LONG\n                        % {\n                            \"DOLLAR\": \"$\",\n                            \"STYLE\": cfg.style,\n                            \"TAG_PREFIX\": cfg.tag_prefix,\n                            \"PARENTDIR_PREFIX\": cfg.parentdir_prefix,\n                            \"VERSIONFILE_SOURCE\": cfg.versionfile_source,\n                        }\n                    )\n\n        cmds[\"py2exe\"] = cmd_py2exe\n\n    # sdist farms its file list building out to egg_info\n    if \"egg_info\" in cmds:\n        _sdist = cmds[\"egg_info\"]\n    else:\n        from setuptools.command.egg_info import egg_info as _egg_info\n\n    class cmd_egg_info(_egg_info):\n        def find_sources(self):\n            # egg_info.find_sources builds the manifest list and writes it\n            # in one shot\n            super().find_sources()\n\n            # Modify the filelist and normalize it\n            root = get_root()\n            cfg = get_config_from_root(root)\n            self.filelist.append(\"versioneer.py\")\n            if cfg.versionfile_source:\n                # There are rare cases where versionfile_source might not be\n                # included by default, so we must be explicit\n                self.filelist.append(cfg.versionfile_source)\n            self.filelist.sort()\n            
self.filelist.remove_duplicates()\n\n            # The write method is hidden in the manifest_maker instance that\n            # generated the filelist and was thrown away\n            # We will instead replicate their final normalization (to unicode,\n            # and POSIX-style paths)\n            from setuptools import unicode_utils\n\n            normalized = [unicode_utils.filesys_decode(f).replace(os.sep, \"/\") for f in self.filelist.files]\n\n            manifest_filename = os.path.join(self.egg_info, \"SOURCES.txt\")\n            with open(manifest_filename, \"w\") as fobj:\n                fobj.write(\"\\n\".join(normalized))\n\n    cmds[\"egg_info\"] = cmd_egg_info\n\n    # we override different \"sdist\" commands for both environments\n    if \"sdist\" in cmds:\n        _sdist = cmds[\"sdist\"]\n    else:\n        from setuptools.command.sdist import sdist as _sdist\n\n    class cmd_sdist(_sdist):\n        def run(self):\n            versions = get_versions()\n            self._versioneer_generated_versions = versions\n            # unless we update this, the command will keep using the old\n            # version\n            self.distribution.metadata.version = versions[\"version\"]\n            return _sdist.run(self)\n\n        def make_release_tree(self, base_dir, files):\n            root = get_root()\n            cfg = get_config_from_root(root)\n            _sdist.make_release_tree(self, base_dir, files)\n            # now locate _version.py in the new base_dir directory\n            # (remembering that it may be a hardlink) and replace it with an\n            # updated value\n            target_versionfile = os.path.join(base_dir, cfg.versionfile_source)\n            print(\"UPDATING %s\" % target_versionfile)\n            write_to_version_file(target_versionfile, self._versioneer_generated_versions)\n\n    cmds[\"sdist\"] = cmd_sdist\n\n    return cmds\n\n\nCONFIG_ERROR = \"\"\"\nsetup.cfg is missing the necessary Versioneer configuration. 
You need\na section like:\n\n [versioneer]\n VCS = git\n style = pep440\n versionfile_source = src/myproject/_version.py\n versionfile_build = myproject/_version.py\n tag_prefix =\n parentdir_prefix = myproject-\n\nYou will also need to edit your setup.py to use the results:\n\n import versioneer\n setup(version=versioneer.get_version(),\n       cmdclass=versioneer.get_cmdclass(), ...)\n\nPlease read the docstring in ./versioneer.py for configuration instructions,\nedit setup.cfg, and re-run the installer or 'python versioneer.py setup'.\n\"\"\"\n\nSAMPLE_CONFIG = \"\"\"\n# See the docstring in versioneer.py for instructions. Note that you must\n# re-run 'versioneer.py setup' after changing this section, and commit the\n# resulting files.\n\n[versioneer]\n#VCS = git\n#style = pep440\n#versionfile_source =\n#versionfile_build =\n#tag_prefix =\n#parentdir_prefix =\n\n\"\"\"\n\nOLD_SNIPPET = \"\"\"\nfrom ._version import get_versions\n__version__ = get_versions()['version']\ndel get_versions\n\"\"\"\n\nINIT_PY_SNIPPET = \"\"\"\nfrom . 
import {0}\n__version__ = {0}.get_versions()['version']\n\"\"\"\n\n\ndef do_setup():\n    \"\"\"Do main VCS-independent setup function for installing Versioneer.\"\"\"\n    root = get_root()\n    try:\n        cfg = get_config_from_root(root)\n    except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e:\n        if isinstance(e, (OSError, configparser.NoSectionError)):\n            print(\"Adding sample versioneer config to setup.cfg\", file=sys.stderr)\n            with open(os.path.join(root, \"setup.cfg\"), \"a\") as f:\n                f.write(SAMPLE_CONFIG)\n        print(CONFIG_ERROR, file=sys.stderr)\n        return 1\n\n    print(\" creating %s\" % cfg.versionfile_source)\n    with open(cfg.versionfile_source, \"w\") as f:\n        LONG = LONG_VERSION_PY[cfg.VCS]\n        f.write(\n            LONG\n            % {\n                \"DOLLAR\": \"$\",\n                \"STYLE\": cfg.style,\n                \"TAG_PREFIX\": cfg.tag_prefix,\n                \"PARENTDIR_PREFIX\": cfg.parentdir_prefix,\n                \"VERSIONFILE_SOURCE\": cfg.versionfile_source,\n            }\n        )\n\n    ipy = os.path.join(os.path.dirname(cfg.versionfile_source), \"__init__.py\")\n    if os.path.exists(ipy):\n        try:\n            with open(ipy, \"r\") as f:\n                old = f.read()\n        except OSError:\n            old = \"\"\n        module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0]\n        snippet = INIT_PY_SNIPPET.format(module)\n        if OLD_SNIPPET in old:\n            print(\" replacing boilerplate in %s\" % ipy)\n            with open(ipy, \"w\") as f:\n                f.write(old.replace(OLD_SNIPPET, snippet))\n        elif snippet not in old:\n            print(\" appending to %s\" % ipy)\n            with open(ipy, \"a\") as f:\n                f.write(snippet)\n        else:\n            print(\" %s unmodified\" % ipy)\n    else:\n        print(\" %s doesn't exist, ok\" % ipy)\n        ipy = 
None\n\n    # Make VCS-specific changes. For git, this means creating/changing\n    # .gitattributes to mark _version.py for export-subst keyword\n    # substitution.\n    do_vcs_install(cfg.versionfile_source, ipy)\n    return 0\n\n\ndef scan_setup_py():\n    \"\"\"Validate the contents of setup.py against Versioneer's expectations.\"\"\"\n    found = set()\n    setters = False\n    errors = 0\n    with open(\"setup.py\", \"r\") as f:\n        for line in f.readlines():\n            if \"import versioneer\" in line:\n                found.add(\"import\")\n            if \"versioneer.get_cmdclass()\" in line:\n                found.add(\"cmdclass\")\n            if \"versioneer.get_version()\" in line:\n                found.add(\"get_version\")\n            if \"versioneer.VCS\" in line:\n                setters = True\n            if \"versioneer.versionfile_source\" in line:\n                setters = True\n    if len(found) != 3:\n        print(\"\")\n        print(\"Your setup.py appears to be missing some important items\")\n        print(\"(but I might be wrong). Please make sure it has something\")\n        print(\"roughly like the following:\")\n        print(\"\")\n        print(\" import versioneer\")\n        print(\" setup( version=versioneer.get_version(),\")\n        print(\"        cmdclass=versioneer.get_cmdclass(),  ...)\")\n        print(\"\")\n        errors += 1\n    if setters:\n        print(\"You should remove lines like 'versioneer.VCS = ' and\")\n        print(\"'versioneer.versionfile_source = ' . This configuration\")\n        print(\"now lives in setup.cfg, and should be removed from setup.py\")\n        print(\"\")\n        errors += 1\n    return errors\n\n\nif __name__ == \"__main__\":\n    cmd = sys.argv[1]\n    if cmd == \"setup\":\n        errors = do_setup()\n        errors += scan_setup_py()\n        if errors:\n            sys.exit(1)\n"
  }
]